run_embedding_vector_table_negative_matrix_live.py 5.01 KB
#!/usr/bin/env /usr/local/miniconda3/bin/python
from __future__ import annotations

import argparse
import json
import subprocess
from pathlib import Path
import sys
from typing import Any

import psycopg

# Repository root: the parent of the directory that holds this script.
ROOT = Path(__file__).resolve().parents[1]
# Prepend the root to the import path (only once) so that the project-local
# `workers` package can be imported when this file is executed directly.
if str(ROOT) not in sys.path:
    sys.path[:0] = [str(ROOT)]

from workers._job_common import validate_schema

# Interpreter used to launch the bootstrap script and the embedding worker.
PYTHON_BIN = '/usr/local/miniconda3/bin/python'
# Where the combined report is written when --output is not supplied.
DEFAULT_OUTPUT = ROOT.joinpath(
    'data',
    'pgvector_eval',
    'music20',
    'embedding_vector_table_negative_matrix_report.json',
)
# Schema the dim-mismatch / not-allowlisted cases run against, and the
# schema the missing-table clone copies its tables from.
SOURCE_SCHEMA = 'acr_test'
# Tables copied into the scratch schema for the missing-table case. No
# audio_embedding_vector_* table is listed, so none exists in the clone.
MINIMAL_TABLES = [
    'canonical_song',
    'work',
    'recording',
    'recording_asset',
    'audio_window',
    'model_registry',
    'feature_set_registry',
    'feature_extraction_job',
    'reference_set_registry',
    'reference_set_member',
]


def run_cmd(cmd: list[str]) -> subprocess.CompletedProcess[str]:
    """Execute *cmd* with the repository root as working directory.

    stdout and stderr are captured and decoded as text. The exit status is
    not checked here; callers inspect ``returncode`` themselves.
    """
    completed = subprocess.run(
        cmd,
        cwd=ROOT,
        text=True,
        capture_output=True,
    )
    return completed


def reset_source_jobs(dsn: str) -> None:
    """Re-run the phase-1 extraction-job bootstrap script against SOURCE_SCHEMA.

    Runs ``scripts/bootstrap_phase1_extraction_jobs_live.py --dsn DSN
    --schema SOURCE_SCHEMA`` from the repository root via ``run_cmd``.

    Args:
        dsn: PostgreSQL connection string passed through to the script.

    Raises:
        SystemExit: if the script exits non-zero. The message is its stderr,
            else its stdout, else a line naming the exit status.
    """
    script = 'scripts/bootstrap_phase1_extraction_jobs_live.py'
    proc = run_cmd([
        PYTHON_BIN,
        script,
        '--dsn', dsn,
        '--schema', SOURCE_SCHEMA,
    ])
    if proc.returncode != 0:
        # Fall back to the exit status so a failure with no output is still
        # explained instead of exiting with a blank message.
        raise SystemExit(
            proc.stderr or proc.stdout or f'{script} exited with status {proc.returncode}'
        )


def clone_minimal_schema_without_vectors(dsn: str, target_schema: str) -> None:
    """(Re)create *target_schema* as a data copy of MINIMAL_TABLES from SOURCE_SCHEMA.

    The target schema is dropped with CASCADE and then recreated. Each table in
    MINIMAL_TABLES is copied with ``CREATE TABLE ... AS TABLE ... WITH DATA``.
    Only the listed tables are copied, so no ``audio_embedding_vector_*`` table
    exists in the result. In PostgreSQL, CREATE TABLE AS copies columns and rows
    but not constraints, indexes or defaults.

    Args:
        dsn: PostgreSQL connection string.
        target_schema: Name of the scratch schema; validated by ``validate_schema``.

    Raises:
        SystemExit: if *target_schema* names SOURCE_SCHEMA, because the DROP
            below would destroy the source data.
    """
    target_schema = validate_schema(target_schema)
    # Unquoted identifiers fold to lower case in PostgreSQL, so compare
    # case-insensitively before issuing a DROP ... CASCADE.
    if target_schema.lower() == SOURCE_SCHEMA.lower():
        raise SystemExit(f'refusing to clone into the source schema {SOURCE_SCHEMA!r}')
    with psycopg.connect(dsn, autocommit=True) as conn:
        # Rebuild atomically: if any copy fails, everything rolls back instead
        # of leaving a half-populated schema behind.
        with conn.transaction():
            conn.execute(f'DROP SCHEMA IF EXISTS {target_schema} CASCADE;')
            conn.execute(f'CREATE SCHEMA {target_schema};')
            for table_name in MINIMAL_TABLES:
                conn.execute(
                    f'CREATE TABLE {target_schema}.{table_name} '
                    f'AS TABLE {SOURCE_SCHEMA}.{table_name} WITH DATA;'
                )


def run_worker_case(*, dsn: str, schema: str, vector_table: str, output_name: str) -> dict[str, Any]:
    """Run the embedding worker once against *vector_table* and summarise the failure record.

    Launches ``workers/run_embedding_job.py`` for job ``2`` with model ``mert``
    version ``v1-95m``. The worker writes its JSON artifact to
    ``data/pgvector_eval/music20/<output_name>``. This function then reads the
    ``status_after_failed`` section of that artifact.

    Args:
        dsn: PostgreSQL connection string.
        schema: Schema the worker operates on.
        vector_table: Vector table name handed to the worker.
        output_name: File name of the worker's JSON artifact.

    Returns:
        A flat dict containing:
          * the schema and vector table;
          * the job status;
          * ``failure_reason``, ``preflight_blockers`` and ``vector_table_report``
            from the failed job's ``metadata_json`` (each ``None`` when absent);
          * the artifact path relative to ROOT.

    Raises:
        SystemExit: if the worker exits non-zero, or exits zero without
            producing the artifact.
    """
    worker = 'workers/run_embedding_job.py'
    out = ROOT / 'data' / 'pgvector_eval' / 'music20' / output_name
    # Remove any artifact from an earlier run so that a worker which exits 0
    # without writing one cannot make this case report stale results.
    out.unlink(missing_ok=True)
    proc = run_cmd([
        PYTHON_BIN,
        worker,
        '--dsn', dsn,
        '--schema', schema,
        '--job-id', '2',
        '--model-name', 'mert',
        '--model-version', 'v1-95m',
        '--vector-table', vector_table,
        '--output', str(out),
    ])
    if proc.returncode != 0:
        raise SystemExit(
            proc.stderr or proc.stdout or f'{worker} exited with status {proc.returncode}'
        )
    if not out.is_file():
        raise SystemExit(f'{worker} exited 0 but wrote no artifact at {out}')
    payload = json.loads(out.read_text(encoding='utf-8'))
    failed = payload.get('status_after_failed') or {}
    metadata = failed.get('metadata_json') or {}
    return {
        'schema': schema,
        'vector_table': vector_table,
        'job_status': failed.get('job_status'),
        'failure_reason': metadata.get('failure_reason'),
        'preflight_blockers': metadata.get('preflight_blockers'),
        'vector_table_report': metadata.get('vector_table_report'),
        'artifact': str(out.relative_to(ROOT)),
    }


def _redact_dsn(dsn: str) -> str:
    """Return *dsn* with any embedded password replaced by ``***``.

    Handles two DSN forms:
      * URI DSNs such as ``postgres://user:secret@host:5432/db``, including a
        ``password=`` query parameter;
      * libpq keyword/value strings such as ``host=... password=secret``.

    The report therefore records which server and database were used without
    leaking the credential.
    """
    import re
    from urllib.parse import urlsplit, urlunsplit

    # Keyword/value form (or a `password=` URI query parameter); quoted values
    # such as password='a b' are masked as a whole.
    redacted = re.sub(r"\b(password\s*=\s*)('(?:[^'\\]|\\.)*'|[^\s&]+)", r'\1***', dsn)
    parts = urlsplit(redacted)
    if parts.scheme and '@' in parts.netloc:
        # rpartition: the host part follows the last '@'.
        userinfo, _, hostinfo = parts.netloc.rpartition('@')
        user, has_password, _ = userinfo.partition(':')
        if has_password:
            redacted = urlunsplit(parts._replace(netloc=f'{user}:***@{hostinfo}'))
    return redacted


def _vector_table_reason(result: dict[str, Any]) -> Any:
    """Return ``result['vector_table_report']['reason']``, or None when no report was recorded."""
    report = result.get('vector_table_report')
    return report.get('reason') if isinstance(report, dict) else None


def main() -> None:
    """Run the negative vector-table cases and write/print a JSON report.

    Each case is preceded by a re-run of the extraction-job bootstrap on
    SOURCE_SCHEMA:
      * vector_table_dim_mismatch: SOURCE_SCHEMA with audio_embedding_vector_192.
      * vector_table_not_allowlisted: SOURCE_SCHEMA with audio_embedding_vector_1024.
      * vector_table_missing_in_schema: a fresh clone of MINIMAL_TABLES (which
        has no vector tables) with audio_embedding_vector_768.

    The report is written to --output and echoed to stdout.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--dsn', required=True)
    ap.add_argument('--output', default=str(DEFAULT_OUTPUT))
    ap.add_argument('--missing-table-schema', default='acr_vector_table_missing_test')
    args = ap.parse_args()
    # The missing-table case drops and recreates this schema, so it must never
    # be the source schema.
    if args.missing_table_schema.lower() == SOURCE_SCHEMA.lower():
        ap.error(f'--missing-table-schema must differ from the source schema {SOURCE_SCHEMA!r}')

    # (case name, schema, vector table, artifact file name, clone schema first?)
    case_specs = [
        ('vector_table_dim_mismatch', SOURCE_SCHEMA, 'audio_embedding_vector_192',
         'embedding_vector_table_dim_mismatch_attempt.json', False),
        ('vector_table_not_allowlisted', SOURCE_SCHEMA, 'audio_embedding_vector_1024',
         'embedding_vector_table_not_allowlisted_attempt.json', False),
        ('vector_table_missing_in_schema', args.missing_table_schema, 'audio_embedding_vector_768',
         'embedding_vector_table_missing_in_schema_attempt.json', True),
    ]
    results: dict[str, dict[str, Any]] = {}
    for case_name, schema, vector_table, output_name, needs_clone in case_specs:
        # Re-bootstrap before every case so each worker attempt starts from
        # the bootstrap script's job state.
        reset_source_jobs(args.dsn)
        if needs_clone:
            clone_minimal_schema_without_vectors(args.dsn, schema)
        results[case_name] = run_worker_case(
            dsn=args.dsn,
            schema=schema,
            vector_table=vector_table,
            output_name=output_name,
        )

    payload = {
        'source_schema': SOURCE_SCHEMA,
        'missing_table_schema': args.missing_table_schema,
        'dsn_redacted': _redact_dsn(args.dsn),
        'cases': [{'case': name, **result} for name, result in results.items()],
        'summary': {
            # A case without a vector_table_report yields None instead of crashing.
            'expected_reasons': {name: _vector_table_reason(result) for name, result in results.items()},
            'all_failed': all(result['job_status'] == 'failed' for result in results.values()),
        },
    }
    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding='utf-8')
    print(json.dumps(payload, ensure_ascii=False, indent=2))


if __name__ == '__main__':
    # Script entry point; main() returns None, so this exits with status 0 on success.
    sys.exit(main())