bootstrap_phase1_extraction_jobs_live.py 6.87 KB
#!/usr/bin/env /usr/local/miniconda3/bin/python
from __future__ import annotations

import argparse
import json
from pathlib import Path
import sys
from typing import Any

import psycopg

# Repository root: the directory two levels above this file.
ROOT = Path(__file__).resolve().parents[1]
# Prepend the root to the import path (once) so the `workers` package below
# can be imported when this file is executed directly as a script.
_ROOT_ENTRY = str(ROOT)
if _ROOT_ENTRY not in sys.path:
    sys.path.insert(0, _ROOT_ENTRY)

from workers._job_common import validate_schema

# Default destination of the JSON run report, relative to the repository root.
DEFAULT_OUTPUT = ROOT.joinpath('data', 'pgvector_eval', 'music20', 'phase1_extraction_jobs_report.json')

def _reference_job(
    model_name: str,
    model_version: str,
    feature_name: str,
    window_sec: float,
    hop_sec: float,
    shard_key: str,
    metadata_json: dict[str, str],
) -> dict[str, Any]:
    """Build one ``pending`` job spec targeting the phase-1 hot reference set."""
    return {
        'model_name': model_name,
        'model_version': model_version,
        'feature_name': feature_name,
        'window_sec': window_sec,
        'hop_sec': hop_sec,
        'target_scope': 'reference_set:phase1_hot_reference_v1',
        'job_status': 'pending',
        'shard_key': shard_key,
        'metadata_json': metadata_json,
    }


# Extraction jobs to bootstrap: one exact-match fingerprint lane plus the
# semantic-embedding candidates. Key order inside each ``metadata_json`` is
# kept as written because the dict is JSON-serialized as-is into the job row.
JOB_SPECS = [
    _reference_job(
        'chromaprint', 'v1', 'fingerprint_asset',
        window_sec=5.0, hop_sec=2.5,
        shard_key='phase1/reference/chromaprint/v1',
        metadata_json={'lane': 'exact', 'phase': 'phase1', 'priority': 'p0'},
    ),
    _reference_job(
        'mert', 'v1-95m', 'semantic_embedding',
        window_sec=5.0, hop_sec=2.5,
        shard_key='phase1/reference/mert/v1-95m/5s_2.5s',
        metadata_json={'lane': 'semantic', 'role': 'primary_baseline', 'phase': 'phase1'},
    ),
    _reference_job(
        'mert', 'v1-95m', 'semantic_embedding',
        window_sec=10.0, hop_sec=5.0,
        shard_key='phase1/reference/mert/v1-95m/10s_5s',
        metadata_json={'lane': 'semantic', 'role': 'long_context_validation', 'phase': 'phase1'},
    ),
    _reference_job(
        'muq', 'large-msd-iter', 'semantic_embedding',
        window_sec=5.0, hop_sec=2.5,
        shard_key='phase1/reference/muq/large-msd-iter/5s_2.5s',
        metadata_json={'lane': 'semantic', 'role': 'challenger', 'phase': 'phase1'},
    ),
    _reference_job(
        'ecapa', 'acr-baseline-v1', 'semantic_embedding',
        window_sec=5.0, hop_sec=2.5,
        shard_key='phase1/reference/ecapa/acr-baseline-v1/5s_2.5s',
        metadata_json={'lane': 'semantic', 'role': 'historical_baseline', 'phase': 'phase1'},
    ),
]


def resolve_feature_set_id(conn: psycopg.Connection, job: dict[str, Any]) -> int:
    """Return the ``feature_set_id`` registered for *job*'s model and feature.

    ``feature_set_registry`` is joined to ``model_registry`` on ``model_id`` and
    filtered on model name/version, feature name, window and hop. Window and hop
    are compared NULL-safely (``IS NOT DISTINCT FROM``), so a ``None`` in the
    spec matches a NULL column without a magic sentinel value. If several rows
    match, the lowest ``feature_set_id`` is returned.

    Args:
        conn: Open connection whose ``search_path`` already points at the target schema.
        job: One entry of ``JOB_SPECS``; reads ``model_name``, ``model_version``,
            ``feature_name``, ``window_sec`` and ``hop_sec``.

    Raises:
        RuntimeError: if no matching feature set is registered.
    """
    row = conn.execute(
        """
        SELECT fs.feature_set_id
        FROM feature_set_registry fs
        JOIN model_registry mr ON mr.model_id = fs.model_id
        WHERE mr.model_name = %s
          AND mr.model_version = %s
          AND fs.feature_name = %s
          AND fs.window_sec IS NOT DISTINCT FROM %s
          AND fs.hop_sec IS NOT DISTINCT FROM %s
        ORDER BY fs.feature_set_id
        LIMIT 1;
        """,
        (
            job['model_name'],
            job['model_version'],
            job['feature_name'],
            job['window_sec'],
            job['hop_sec'],
        ),
    ).fetchone()
    if not row:
        raise RuntimeError(
            f"Feature set not found for {job['model_name']} {job['model_version']} {job['feature_name']} {job['window_sec']}/{job['hop_sec']}"
        )
    return int(row[0])


def ensure_job(conn: psycopg.Connection, feature_set_id: int, job: dict[str, Any]) -> tuple[int, str]:
    """Insert the extraction job described by *job*, or reset the matching one.

    A job is identified by ``(feature_set_id, target_scope, shard_key)``, with
    ``shard_key`` compared NULL-safely (NULL and ``''`` are distinct keys). If a
    matching row exists (the lowest ``extraction_job_id`` when there are
    several), its ``job_status`` and ``metadata_json`` are overwritten and its
    ``input_count``, ``output_count``, ``started_at``, ``finished_at`` and
    ``log_uri`` are cleared to NULL. Otherwise a new row is inserted with those
    columns NULL.

    Args:
        conn: Open connection whose ``search_path`` already points at the target schema.
        feature_set_id: Id returned by ``resolve_feature_set_id`` for this spec.
        job: One entry of ``JOB_SPECS``; reads ``target_scope``, ``shard_key``,
            ``job_status`` and ``metadata_json``.

    Returns:
        ``(extraction_job_id, operation)`` where *operation* is ``'reused'`` for
        an updated row or ``'inserted'`` for a new one.
    """
    # Serialized once and shared by both the UPDATE and INSERT paths.
    metadata = json.dumps(job['metadata_json'])

    existing = conn.execute(
        """
        SELECT extraction_job_id
        FROM feature_extraction_job
        WHERE feature_set_id = %s
          AND target_scope = %s
          AND shard_key IS NOT DISTINCT FROM %s
        ORDER BY extraction_job_id
        LIMIT 1;
        """,
        (feature_set_id, job['target_scope'], job['shard_key']),
    ).fetchone()
    if existing:
        job_id = int(existing[0])
        # NOTE(review): this puts a job that may already be running or finished
        # back to job['job_status'] and discards its counts/timestamps/log_uri.
        # Confirm that re-running the bootstrap is meant to restart existing jobs.
        conn.execute(
            """
            UPDATE feature_extraction_job
            SET job_status = %s,
                input_count = NULL,
                output_count = NULL,
                started_at = NULL,
                finished_at = NULL,
                log_uri = NULL,
                metadata_json = %s::jsonb
            WHERE extraction_job_id = %s;
            """,
            (job['job_status'], metadata, job_id),
        )
        return job_id, 'reused'

    row = conn.execute(
        """
        INSERT INTO feature_extraction_job (
            feature_set_id, target_scope, job_status, shard_key,
            input_count, output_count, started_at, finished_at,
            log_uri, metadata_json
        ) VALUES (
            %s, %s, %s, %s,
            NULL, NULL, NULL, NULL,
            NULL, %s::jsonb
        )
        RETURNING extraction_job_id;
        """,
        (feature_set_id, job['target_scope'], job['job_status'], job['shard_key'], metadata),
    ).fetchone()
    return int(row[0]), 'inserted'


def _redact_dsn(dsn: str) -> str:
    """Return *dsn* with any embedded password replaced by ``***``.

    Handles URI DSNs (``postgres://user:secret@host:port/db``), libpq key/value
    conninfo strings (``host=... password=secret``) and a ``password=`` query
    parameter on a URI. Other parts of the DSN are returned unchanged.
    """
    import re
    from urllib.parse import urlsplit, urlunsplit

    parts = urlsplit(dsn)
    if parts.scheme in ('postgres', 'postgresql'):
        # rpartition: a raw '@' inside the password must not be taken as the host separator.
        userinfo, at, hostinfo = parts.netloc.rpartition('@')
        if at and ':' in userinfo:
            user = userinfo.split(':', 1)[0]
            dsn = urlunsplit(parts._replace(netloc=f'{user}:***@{hostinfo}'))
    # Key/value form (optionally single-quoted value) and URI query parameter form.
    return re.sub(
        r"(password\s*=\s*)('(?:[^'\\]|\\.)*'|[^\s&]+)",
        r'\1***',
        dsn,
        flags=re.IGNORECASE,
    )


def main() -> None:
    """Create or reset the phase-1 extraction jobs and write a JSON report.

    For every entry in ``JOB_SPECS`` the matching feature set is resolved and the
    corresponding ``feature_extraction_job`` row is inserted or reset. All job
    writes run inside a single transaction, so a spec whose feature set is missing
    aborts the run without leaving earlier jobs half-bootstrapped. The summary
    (including table counts) is written to ``--output`` and echoed to stdout.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--dsn', required=True)
    ap.add_argument('--schema', default='acr_test')
    ap.add_argument('--output', default=str(DEFAULT_OUTPUT))
    args = ap.parse_args()
    schema = validate_schema(args.schema)

    summary: dict[str, Any] = {
        'schema': schema,
        # Derived from the DSN actually used (previously a hard-coded literal).
        'dsn_redacted': _redact_dsn(args.dsn),
        'jobs': [],
    }
    with psycopg.connect(args.dsn, autocommit=True) as conn:
        # An identifier cannot be bound as a query parameter, so it is interpolated.
        # NOTE(review): this relies on validate_schema() rejecting anything that is
        # not a safe identifier — confirm it enforces a strict pattern.
        conn.execute(f'SET search_path TO {schema}, public;')
        # Explicit transaction on the autocommit connection: all-or-nothing job writes.
        with conn.transaction():
            for job in JOB_SPECS:
                feature_set_id = resolve_feature_set_id(conn, job)
                extraction_job_id, operation = ensure_job(conn, feature_set_id, job)
                summary['jobs'].append({
                    'extraction_job_id': extraction_job_id,
                    'feature_set_id': feature_set_id,
                    'model_name': job['model_name'],
                    'model_version': job['model_version'],
                    'feature_name': job['feature_name'],
                    'window_sec': job['window_sec'],
                    'hop_sec': job['hop_sec'],
                    'target_scope': job['target_scope'],
                    'job_status': job['job_status'],
                    'operation': operation,
                })
        # Counted after commit, across the whole table (not only this run's jobs).
        summary['counts'] = {
            'feature_extraction_job': int(conn.execute('SELECT count(*) FROM feature_extraction_job;').fetchone()[0]),
            'pending_jobs': int(conn.execute("SELECT count(*) FROM feature_extraction_job WHERE job_status = 'pending';").fetchone()[0]),
        }

    report = json.dumps(summary, ensure_ascii=False, indent=2)
    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(report, encoding='utf-8')
    print(report)


if __name__ == '__main__':
    # Run the bootstrap only when executed as a script, not when imported.
    main()