plan_phase1_extraction_jobs_live.py 7.97 KB
#!/usr/bin/env /usr/local/miniconda3/bin/python
from __future__ import annotations

import argparse
import json
from pathlib import Path
import re
import shlex
import sys
from typing import Any

import psycopg

# Repository root: the parent of the directory that contains this script. It is
# prepended to sys.path (only if not already present) so the `workers` package
# can be imported when this file is executed directly as a script.
ROOT = Path(__file__).resolve().parents[1]
if all(entry != str(ROOT) for entry in sys.path):
    sys.path.insert(0, str(ROOT))

from workers._job_common import validate_schema

# Default location of the JSON plan report, under the repository root.
DEFAULT_OUTPUT = ROOT.joinpath('data', 'pgvector_eval', 'music20', 'phase1_extraction_plan_report.json')

# Scheduling rank per lane: lower ranks sort earlier in the plan. Lanes not
# listed here are given rank 99 by main().
LANE_PRIORITY = {lane: rank for rank, lane in enumerate(('exact', 'semantic', 'cover'))}

# Interpreter used in the generated shell commands (same path as the shebang).
PYTHON_BIN = '/usr/local/miniconda3/bin/python'


def parse_target_scope(target_scope: str) -> dict[str, Any]:
    """Split a ``type:value`` target scope at its first colon.

    Args:
        target_scope: Raw scope string, e.g. ``"track:abc"``.

    Returns:
        ``{'scope_type': ..., 'scope_value': ...}``. Everything after the first
        colon is the value (it may itself contain colons). A scope without any
        colon is reported with ``scope_type`` ``'unknown'`` and the whole
        string as its value.
    """
    if ':' not in target_scope:
        return {'scope_type': 'unknown', 'scope_value': target_scope}
    kind, _sep, value = target_scope.partition(':')
    return {'scope_type': kind, 'scope_value': value}


def build_command_suggestions(job: dict[str, Any], schema: str) -> list[str]:
    """Return copy-pasteable shell command suggestions for one extraction job.

    Every command starts with the same environment preamble: it cds into
    ``/workspace/acr-engine``, requires ``PG_DSN`` to already be set, and
    exports the job id, feature set id, target scope and schema.

    Args:
        job: Planned job item. Always reads ``extraction_job_id``,
            ``feature_set_id``, ``target_scope`` and ``lane``. For any lane
            other than ``'exact'``, also reads ``model_name``, ``model_version``,
            ``vector_table`` and ``physical_target``.
        schema: Schema name exported as ``PG_SCHEMA``.

    Returns:
        A two-element list:
        ``[worker_dry_run_command, mark_running_command]``. The worker is
        ``workers/run_chromaprint_job.py`` for the ``'exact'`` lane and
        ``workers/run_embedding_job.py`` otherwise, both with
        ``--complete-dry-run``. The second command calls
        ``workers/mark_job_status.py --status running --expected-status pending``.
    """
    quote = shlex.quote

    # TARGET_SCOPE has always been emitted inside single quotes. Keep that
    # format, but escape embedded single quotes ('"'"') so a scope containing
    # one cannot terminate the quoting and inject shell syntax.
    scope_literal = "'" + str(job['target_scope']).replace("'", "'\"'\"'") + "'"

    base_env = (
        'cd /workspace/acr-engine && '
        'PG_DSN="${PG_DSN:?set PG_DSN}" '
        f"EXTRACTION_JOB_ID={quote(str(job['extraction_job_id']))} "
        f"FEATURE_SET_ID={quote(str(job['feature_set_id']))} "
        f"TARGET_SCOPE={scope_literal} "
        f"PG_SCHEMA={quote(str(schema))}"
    )

    if job['lane'] == 'exact':
        worker_tail = (
            " OUTPUT_TARGET=audio_fingerprint"
            f" \\\n{PYTHON_BIN} workers/run_chromaprint_job.py --complete-dry-run"
        )
    else:
        # NOTE(review): when job['vector_table'] is None this still renders
        # VECTOR_TABLE=None, which the worker receives as the literal string
        # "None" — confirm how run_embedding_job.py treats that value.
        worker_tail = (
            f" MODEL_NAME={quote(str(job['model_name']))}"
            f" MODEL_VERSION={quote(str(job['model_version']))}"
            f" VECTOR_TABLE={quote(str(job['vector_table']))}"
            f" OUTPUT_TARGET={quote(str(job['physical_target']))}"
            f" \\\n{PYTHON_BIN} workers/run_embedding_job.py --complete-dry-run"
        )

    mark_running_tail = (
        f" \\\n{PYTHON_BIN} workers/mark_job_status.py --status running --expected-status pending"
    )
    return [base_env + worker_tail, base_env + mark_running_tail]


def build_validation_commands(schema: str) -> dict[str, str]:
    """Return named shell commands that run the Phase 1 validation scripts.

    Each command cds into ``/workspace/acr-engine``, aborts unless ``PG_DSN``
    is set (``${PG_DSN:?...}``), and writes its report under
    ``data/pgvector_eval/music20/``. Schema handling differs per script:

    - ``prereq_audit`` and ``worker_contract_smoke`` run against *schema*.
    - ``semantic_vector_negative_matrix`` is given no ``--schema`` flag.
    - ``asset_level_upsert_validation`` always targets the dedicated
      ``acr_asset_upsert_test`` schema.

    Args:
        schema: Schema passed to the schema-aware validation scripts.

    Returns:
        Mapping from check name to the full shell command string.
    """
    preamble = 'cd /workspace/acr-engine && PG_DSN="${PG_DSN:?set PG_DSN}" '
    report_dir = 'data/pgvector_eval/music20'

    # (check name, script under scripts/, --schema value or None, report file name)
    checks = (
        ('prereq_audit', 'run_phase1_prereq_audit_live.py', schema,
         'phase1_prereq_audit_report.json'),
        ('worker_contract_smoke', 'run_phase1_worker_contract_smoke_live.py', schema,
         'phase1_worker_contract_smoke_report.json'),
        ('semantic_vector_negative_matrix', 'run_embedding_vector_table_negative_matrix_live.py', None,
         'embedding_vector_table_negative_matrix_report.json'),
        ('asset_level_upsert_validation', 'validate_audio_embedding_asset_upsert_live.py', 'acr_asset_upsert_test',
         'audio_embedding_asset_upsert_live_report.json'),
    )

    commands: dict[str, str] = {}
    for name, script, target_schema, report_file in checks:
        schema_flag = '' if target_schema is None else f' --schema {target_schema}'
        commands[name] = (
            f'{preamble}{PYTHON_BIN} scripts/{script} --dsn "$PG_DSN"{schema_flag}'
            f' --output {report_dir}/{report_file}'
        )
    return commands


# Vector side-table used for each supported embedding dimension; other
# dimensions get no vector table (None).
_VECTOR_TABLE_BY_DIM = {
    192: 'audio_embedding_vector_192',
    768: 'audio_embedding_vector_768',
}


def _redact_dsn(dsn: str) -> str:
    """Return *dsn* with every password value replaced by ``***``.

    Handles both libpq connection-string forms:

    - URIs such as ``postgresql://user:secret@host/db``, including
      ``?password=`` / ``&sslpassword=`` query parameters.
    - Keyword/value strings such as ``host=h password=secret``, with quoted
      or unquoted values.

    Any key ending in ``password`` (e.g. ``sslpassword``) is redacted.
    """
    if '://' in dsn:
        # Password in the userinfo part: everything between "user:" and the
        # "@" that ends the credentials.
        dsn = re.sub(r'^([^:/?#]+://[^:@/]*:)[^@/]*@', r'\1***@', dsn)
        return re.sub(r'([?&]\w*password=)[^&#]*', r'\1***', dsn)
    return re.sub(r"(\b\w*password\s*=\s*)(?:'(?:[^'\\]|\\.)*'|\S+)", r'\1***', dsn)


def _fetch_job_rows(dsn: str, schema: str, job_status: str) -> list[tuple[Any, ...]]:
    """Fetch jobs with *job_status*, joined to their feature set and model.

    Rows are ordered by ``extraction_job_id``. Columns come back in the order
    of the SELECT list below, and ``_build_job_item`` reads them by position.
    """
    with psycopg.connect(dsn) as conn:
        # The schema is interpolated into the statement text, so callers must
        # pass a name that has already gone through validate_schema().
        conn.execute(f'SET search_path TO {schema}, public;')
        return conn.execute(
            """
            SELECT
                fej.extraction_job_id,
                fej.feature_set_id,
                fej.target_scope,
                fej.job_status,
                fej.shard_key,
                fej.metadata_json,
                fs.feature_name,
                fs.feature_level,
                fs.extraction_granularity,
                fs.window_sec,
                fs.hop_sec,
                fs.embedding_dim,
                fs.distance_metric,
                mr.model_name,
                mr.model_version,
                mr.model_family,
                mr.output_embedding_dim,
                mr.input_sample_rate,
                mr.default_window_sec,
                mr.default_hop_sec,
                mr.metadata_json
            FROM feature_extraction_job fej
            JOIN feature_set_registry fs ON fs.feature_set_id = fej.feature_set_id
            JOIN model_registry mr ON mr.model_id = fs.model_id
            WHERE fej.job_status = %s
            ORDER BY fej.extraction_job_id;
            """,
            (job_status,),
        ).fetchall()


def _build_job_item(row: tuple[Any, ...], schema: str) -> dict[str, Any]:
    """Convert one row from ``_fetch_job_rows`` into a plan item with command suggestions."""
    job_meta = row[5] or {}
    model_meta = row[20] or {}
    # Prefer the job's own 'lane' metadata, fall back to the model's, else 'unknown'.
    lane = job_meta.get('lane') or model_meta.get('lane') or 'unknown'
    # feature_name 'fingerprint_asset' writes fingerprints; everything else writes embeddings.
    physical_target = 'audio_fingerprint' if row[6] == 'fingerprint_asset' else 'audio_embedding'
    vector_table = _VECTOR_TABLE_BY_DIM.get(row[11])

    item = {
        'priority_rank': LANE_PRIORITY.get(lane, 99),
        'lane': lane,
        'extraction_job_id': row[0],
        'feature_set_id': row[1],
        'target_scope': row[2],
        'scope': parse_target_scope(row[2]),
        'job_status': row[3],
        'shard_key': row[4],
        'model_name': row[13],
        'model_version': row[14],
        'model_family': row[15],
        'input_sample_rate': row[17],
        'feature_name': row[6],
        'feature_level': row[7],
        'extraction_granularity': row[8],
        # float() so numeric/Decimal column values serialize as JSON numbers.
        'window_sec': float(row[9]) if row[9] is not None else None,
        'hop_sec': float(row[10]) if row[10] is not None else None,
        'embedding_dim': row[11],
        'distance_metric': row[12],
        'physical_target': physical_target,
        'vector_table': vector_table,
        'job_metadata': job_meta,
        'model_metadata': model_meta,
        'execution_notes': [
            f"run feature extraction for {row[13]} {row[14]}",
            f"write to {physical_target}" + (f" + {vector_table}" if vector_table else ''),
            f"target scope: {row[2]}",
        ],
    }
    item['command_suggestions'] = build_command_suggestions(item, schema)
    return item


def main() -> None:
    """Build the Phase 1 extraction plan report from the live database.

    Steps:

    1. Read extraction jobs with ``--job-status`` from ``--schema``.
    2. Turn each job into a plan item with target tables and suggested
       commands.
    3. Order the items by ``LANE_PRIORITY`` rank (unknown lanes rank 99) and
       then by ``extraction_job_id``.
    4. Write the JSON report to ``--output`` and echo it to stdout.

    The report records the DSN with passwords redacted.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--dsn', required=True,
                    help='PostgreSQL connection string (passwords are redacted in the report)')
    ap.add_argument('--schema', default='acr_test',
                    help='schema to plan against; validated, then placed first on search_path')
    ap.add_argument('--job-status', default='pending',
                    help='only include jobs whose job_status equals this value')
    ap.add_argument('--output', default=str(DEFAULT_OUTPUT),
                    help='path of the JSON report to write')
    args = ap.parse_args()

    schema = validate_schema(args.schema)
    rows = _fetch_job_rows(args.dsn, schema, args.job_status)

    jobs = [_build_job_item(row, schema) for row in rows]
    by_lane: dict[str, list[dict[str, Any]]] = {}
    for item in jobs:
        by_lane.setdefault(item['lane'], []).append(item)

    jobs.sort(key=lambda x: (x['priority_rank'], x['extraction_job_id']))
    for lane_jobs in by_lane.values():
        lane_jobs.sort(key=lambda x: x['extraction_job_id'])

    payload = {
        'schema': schema,
        # Previously a hard-coded placeholder; now derived from the DSN actually used.
        'dsn_redacted': _redact_dsn(args.dsn),
        'job_status_filter': args.job_status,
        'counts': {
            'jobs': len(jobs),
            'lanes': {lane: len(items) for lane, items in sorted(by_lane.items())},
        },
        'ordered_jobs': jobs,
        'by_lane': by_lane,
        'validation_commands': build_validation_commands(schema),
        'execution_order_summary': [
            {
                'order': idx,
                'extraction_job_id': job['extraction_job_id'],
                'lane': job['lane'],
                'model_name': job['model_name'],
                'feature_name': job['feature_name'],
                'window_sec': job['window_sec'],
                'hop_sec': job['hop_sec'],
                'physical_target': job['physical_target'],
                'primary_command': job['command_suggestions'][0],
            }
            for idx, job in enumerate(jobs, start=1)
        ],
    }

    rendered = json.dumps(payload, ensure_ascii=False, indent=2)
    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(rendered, encoding='utf-8')
    print(rendered)


if __name__ == '__main__':
    main()