run_chromaprint_job.py 11.1 KB
#!/usr/bin/env /usr/local/miniconda3/bin/python
from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path

# Two directory levels above this file (``parents[0]`` is the script's own
# directory) -- presumably the repository root, given the ``src.engines``
# import below. It is prepended to ``sys.path`` (once) so that import resolves
# when this script is executed directly, and it anchors the default
# ``--artifact-dir`` used by ``main``.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from src.engines.chromaprint_matcher import ChromaprintMatcher, load_audio_mono

from _job_common import connect, emit_payload, fetch_job_context, resolve_scope_summary, update_job_status


def fetch_scope_assets(conn, target_scope: str) -> list[dict[str, object]]:
    """Return the ready recording assets that belong to a reference-set scope.

    Only scopes of the form ``reference_set:<set_name>`` are supported; any
    other ``target_scope`` aborts the worker with ``SystemExit``. Everything
    after the first ``:`` is treated as the set name.

    Each returned dict carries ``asset_id``, ``storage_uri``,
    ``ingest_status`` (always ``'ready'`` because of the WHERE clause) and the
    ``recording_id`` / ``work_id`` / ``canonical_song_id`` of the owning
    recording, with the id columns coerced to ``int``. Rows are ordered by
    ``asset_id``.
    """
    scope_kind, separator, set_name = target_scope.partition(':')
    # Equivalent to ``startswith('reference_set:')``: the kind must be exactly
    # ``reference_set`` and a colon must actually be present.
    if not separator or scope_kind != 'reference_set':
        raise SystemExit(f'unsupported target_scope for chromaprint worker: {target_scope}')

    query = """
        SELECT
            ra.asset_id,
            ra.storage_uri,
            ra.ingest_status,
            r.recording_id,
            r.work_id,
            r.canonical_song_id
        FROM reference_set_registry rs
        JOIN reference_set_member rsm ON rsm.reference_set_id = rs.reference_set_id
        JOIN recording_asset ra ON ra.recording_id = rsm.recording_id
        JOIN recording r ON r.recording_id = ra.recording_id
        WHERE rs.set_name = %s
          AND ra.ingest_status = 'ready'
        ORDER BY ra.asset_id;
        """
    fetched = conn.execute(query, (set_name,)).fetchall()

    assets: list[dict[str, object]] = []
    for asset_id, storage_uri, ingest_status, recording_id, work_id, canonical_song_id in fetched:
        assets.append(
            {
                'asset_id': int(asset_id),
                'storage_uri': storage_uri,
                'ingest_status': ingest_status,
                'recording_id': int(recording_id),
                'work_id': int(work_id),
                'canonical_song_id': int(canonical_song_id),
            }
        )
    return assets


def upsert_audio_fingerprint(
    conn,
    *,
    feature_set_id: int,
    asset: dict[str, object],
    fingerprint_uri: str,
    hash_count: int,
    metadata_json: dict[str, object],
) -> tuple[int, str]:
    """Insert or refresh the ``audio_fingerprint`` row for one asset.

    The row is keyed on ``(feature_set_id, asset_id)``. A new row is inserted
    with ``window_id`` NULL and ``is_indexed`` TRUE. On conflict, only
    ``fingerprint_uri``, ``hash_count``, ``is_indexed`` and ``metadata_json``
    are overwritten; the recording/work/song ids of an existing row are left
    as they were.

    Args:
        conn: connection exposing ``execute(sql, params)``.
        feature_set_id: feature set the fingerprint belongs to.
        asset: must provide ``asset_id``, ``recording_id``, ``work_id`` and
            ``canonical_song_id``.
        fingerprint_uri: location of the serialized hash artifact.
        hash_count: number of hashes in the artifact.
        metadata_json: serialized with ``ensure_ascii=False`` and cast to jsonb.

    Returns:
        ``(fingerprint_id, 'upserted')``. The operation label is always
        ``'upserted'``; insert and update are not distinguished.
    """
    statement = """
        INSERT INTO audio_fingerprint (
            feature_set_id, asset_id, window_id, recording_id, work_id, canonical_song_id,
            fingerprint_uri, hash_count, is_indexed, metadata_json
        ) VALUES (
            %s, %s, NULL, %s, %s, %s,
            %s, %s, TRUE, %s::jsonb
        )
        ON CONFLICT (feature_set_id, asset_id)
        DO UPDATE SET
            fingerprint_uri = EXCLUDED.fingerprint_uri,
            hash_count = EXCLUDED.hash_count,
            is_indexed = EXCLUDED.is_indexed,
            metadata_json = EXCLUDED.metadata_json
        RETURNING fingerprint_id;
        """
    # Placeholder order mirrors the column list above (window_id is a literal NULL).
    id_values = tuple(asset[key] for key in ('asset_id', 'recording_id', 'work_id', 'canonical_song_id'))
    params = (
        feature_set_id,
        *id_values,
        fingerprint_uri,
        hash_count,
        json.dumps(metadata_json, ensure_ascii=False),
    )
    returned = conn.execute(statement, params).fetchone()
    fingerprint_id = int(returned[0])
    return fingerprint_id, 'upserted'


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the CLI parser; options fall back to env vars or built-in defaults."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--dsn', default=os.environ.get('PG_DSN'))
    ap.add_argument('--schema', default=os.environ.get('PG_SCHEMA', 'acr_test'))
    # The env value is handed to argparse as a *string* default. argparse runs
    # ``type=int`` on a string default only when --job-id is absent, and it
    # reports a bad value as a normal usage error. Calling int() here instead
    # would raise a bare ValueError while building the parser -- even when a
    # valid --job-id was passed on the command line.
    ap.add_argument('--job-id', type=int, default=os.environ.get('EXTRACTION_JOB_ID', '0'))
    ap.add_argument('--output-target', default=os.environ.get('OUTPUT_TARGET', 'audio_fingerprint'))
    ap.add_argument('--complete-dry-run', action='store_true')
    ap.add_argument('--artifact-dir', default=str(ROOT / 'data' / 'pgvector_eval' / 'music20' / 'phase1_fingerprints'))
    ap.add_argument('--output')
    return ap


def _extract_scope_hashes(matcher, scope_assets):
    """Decode every scope asset and extract its hashes with *matcher*.

    Returns ``(extracted_assets, missing_assets)``:

    - ``extracted_assets``: ``{'asset': ..., 'hashes': ...}`` per readable asset.
    - ``missing_assets``: one record per asset whose file does not exist
      (``reason='missing_audio'``) or whose decode/extraction raised
      (``reason='decode_or_extract_failure'`` plus the error text).

    Nothing is written here, so the caller can fail the whole job before any
    artifact or database row is produced.
    """
    extracted_assets: list[dict[str, object]] = []
    missing_assets: list[dict[str, object]] = []
    for asset in scope_assets:
        asset_path = Path(str(asset['storage_uri']))
        if not asset_path.exists():
            missing_assets.append({
                'asset_id': asset['asset_id'],
                'storage_uri': str(asset_path),
                'reason': 'missing_audio',
            })
            continue
        try:
            y, _ = load_audio_mono(str(asset_path), sr=matcher.sr)
            hashes = matcher.extract_hashes(y)
        # Deliberately broad: any per-asset failure is recorded and scanning continues.
        except Exception as exc:  # noqa: BLE001
            missing_assets.append({
                'asset_id': asset['asset_id'],
                'storage_uri': str(asset_path),
                'reason': 'decode_or_extract_failure',
                'error': str(exc),
            })
            continue
        extracted_assets.append({
            'asset': asset,
            'hashes': hashes,
        })
    return extracted_assets, missing_assets


def _write_fingerprints(conn, job, extracted_assets, artifact_dir: Path) -> list[dict[str, object]]:
    """Persist each extracted asset as a JSON artifact plus an audio_fingerprint row.

    For every entry, ``job<extraction_job_id>_asset<asset_id>.json`` is written
    under *artifact_dir* first. The ``audio_fingerprint`` row is then upserted
    with ``fingerprint_uri`` pointing at that file. Returns one summary dict per
    asset, in input order.
    """
    processed_assets: list[dict[str, object]] = []
    for extracted in extracted_assets:
        asset = extracted['asset']
        hashes = extracted['hashes']
        hash_count = len(hashes)
        artifact_path = artifact_dir / f"job{job.extraction_job_id}_asset{asset['asset_id']}.json"
        artifact_payload = {
            'feature_set_id': job.feature_set_id,
            'extraction_job_id': job.extraction_job_id,
            'asset_id': asset['asset_id'],
            'recording_id': asset['recording_id'],
            'hash_count': hash_count,
            'hashes': [[int(h), int(offset)] for h, offset in hashes],
        }
        artifact_path.write_text(json.dumps(artifact_payload, ensure_ascii=False, indent=2), encoding='utf-8')
        fingerprint_id, operation = upsert_audio_fingerprint(
            conn,
            feature_set_id=job.feature_set_id,
            asset=asset,
            fingerprint_uri=str(artifact_path),
            hash_count=hash_count,
            metadata_json={
                'worker': 'run_chromaprint_job',
                'model_name': job.model_name,
                'model_version': job.model_version,
                'extraction_job_id': job.extraction_job_id,
                'hash_encoding': 'repo-local-chromaprint-matcher',
                'artifact_format': 'json_hash_pairs_v1',
            },
        )
        processed_assets.append({
            'asset_id': asset['asset_id'],
            'recording_id': asset['recording_id'],
            'fingerprint_id': fingerprint_id,
            'hash_count': hash_count,
            'fingerprint_uri': str(artifact_path),
            'operation': operation,
        })
    return processed_assets


def main() -> None:
    """Run one chromaprint ``feature_extraction_job`` end to end.

    Flow:

    1. Parse options and load the job; abort unless its model is ``chromaprint``.
    2. Resolve the job's target-scope summary and its ready assets.
    3. Move the job ``pending -> running``.
    4. With ``--complete-dry-run``: move it straight to ``completed`` with
       ``output_count=0`` and touch no files. Otherwise extract hashes for
       every asset first. If any asset is missing or unreadable, mark the job
       ``failed`` and write nothing. Else write every artifact and row, then
       mark the job ``completed``.
    5. Hand a summary payload to ``emit_payload`` together with ``--output``.

    Raises:
        SystemExit: missing DSN or job id, a non-chromaprint job, or (from
            ``fetch_scope_assets``) an unsupported target scope.
    """
    args = _build_arg_parser().parse_args()

    if not args.dsn:
        raise SystemExit('missing --dsn or PG_DSN')
    if not args.job_id:
        raise SystemExit('missing --job-id or EXTRACTION_JOB_ID')

    artifact_dir = Path(args.artifact_dir)

    with connect(args.dsn, args.schema) as conn:
        job = fetch_job_context(conn, args.job_id)
        if job.model_name != 'chromaprint':
            raise SystemExit(f'feature_extraction_job={args.job_id} is not a chromaprint job')
        scope = resolve_scope_summary(conn, job.target_scope)
        scope_assets = fetch_scope_assets(conn, job.target_scope)

        if not args.complete_dry_run:
            # Create the output directory before the job flips to 'running',
            # so an unwritable location fails fast instead of stranding the
            # job. Dry runs never write artifacts, so they skip this.
            artifact_dir.mkdir(parents=True, exist_ok=True)

        # NOTE(review): if extraction or writing raises unexpectedly after this
        # transition, nothing here moves the job out of 'running'. Whether the
        # status change persists depends on connect()'s transaction handling
        # -- confirm.
        running = update_job_status(
            conn,
            job.extraction_job_id,
            status='running',
            expected_status='pending',
            input_count=scope['ready_asset_count'],
            metadata_patch={
                'worker': 'run_chromaprint_job',
                'output_target': args.output_target,
                'dry_run': bool(args.complete_dry_run),
                'target_scope_summary': scope,
                'execution_mode': 'dry_run' if args.complete_dry_run else 'write_attempt',
            },
            set_started_at=True,
        )
        completed = None
        failed = None
        processed_assets: list[dict[str, object]] = []
        missing_assets: list[dict[str, object]] = []

        if args.complete_dry_run:
            completed = update_job_status(
                conn,
                job.extraction_job_id,
                status='completed',
                expected_status='running',
                output_count=0,
                metadata_patch={
                    'worker': 'run_chromaprint_job',
                    'output_target': args.output_target,
                    'dry_run': True,
                    'dry_run_result': 'completed_without_feature_write',
                    'write_target_table': 'audio_fingerprint',
                },
                set_finished_at=True,
            )
        else:
            matcher = ChromaprintMatcher(sr=job.input_sample_rate or 16000)
            # Extract everything up front so a single bad asset fails the job
            # before any artifact or row is written (all-or-nothing).
            extracted_assets, missing_assets = _extract_scope_hashes(matcher, scope_assets)

            if missing_assets:
                failed = update_job_status(
                    conn,
                    job.extraction_job_id,
                    status='failed',
                    expected_status='running',
                    output_count=0,
                    metadata_patch={
                        'worker': 'run_chromaprint_job',
                        'output_target': args.output_target,
                        'dry_run': False,
                        'write_target_table': 'audio_fingerprint',
                        'artifact_dir': str(artifact_dir),
                        'failure_reason': 'unreadable_audio_assets',
                        'missing_asset_count': len(missing_assets),
                        'missing_asset_samples': missing_assets[:5],
                    },
                    set_finished_at=True,
                )
            else:
                processed_assets = _write_fingerprints(conn, job, extracted_assets, artifact_dir)
                completed = update_job_status(
                    conn,
                    job.extraction_job_id,
                    status='completed',
                    expected_status='running',
                    output_count=len(processed_assets),
                    metadata_patch={
                        'worker': 'run_chromaprint_job',
                        'output_target': args.output_target,
                        'dry_run': False,
                        'write_target_table': 'audio_fingerprint',
                        'artifact_dir': str(artifact_dir),
                        'processed_asset_count': len(processed_assets),
                        # Always 0 on this branch; kept for a stable metadata shape.
                        'missing_asset_count': len(missing_assets),
                    },
                    set_finished_at=True,
                )

    emit_payload(
        {
            'worker': 'run_chromaprint_job',
            'schema': args.schema,
            'job': job.__dict__,
            'target_scope_summary': scope,
            'scope_asset_count': len(scope_assets),
            'processed_assets': processed_assets,
            'missing_assets': missing_assets,
            'status_after_start': running,
            'status_after_complete': completed,
            'status_after_failed': failed,
            'next_write_target': 'audio_fingerprint',
            'notes': [
                'dry-run preserves the verified planner -> job -> PostgreSQL state flow',
                'non-dry-run now writes repo-local chromaprint-style hash artifacts plus audio_fingerprint rows when source audio is readable',
            ],
        },
        args.output,
    )


# Run the worker only when executed as a script, not when imported.
if __name__ == '__main__':
    main()