validate_audio_embedding_asset_upsert_live.py 11.4 KB
#!/usr/bin/env /usr/local/miniconda3/bin/python
from __future__ import annotations

import argparse
import json
from pathlib import Path
import sys
from typing import Any

import psycopg

# Repository root, taken as the parent of this file's directory (presumably the
# script lives in ``scripts/`` under the repo root — see the model_uri seeded below).
ROOT = Path(__file__).resolve().parents[1]
# Put the repo root on the import path so the project-local ``workers`` package
# below can be imported when this file is executed directly as a script.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

# Must stay after the sys.path adjustment above.
from workers._job_common import validate_schema

# Default DDL file applied to the freshly recreated test schema (``--schema-sql``).
DEFAULT_SCHEMA_SQL = ROOT / 'sql' / 'acr_pg_schema_v2.sql'
# Default path of the JSON report written by main() (``--output``).
DEFAULT_OUTPUT = ROOT / 'data' / 'pgvector_eval' / 'music20' / 'audio_embedding_asset_upsert_live_report.json'


def vec_literal(vec: list[float]) -> str:
    """Render *vec* as a bracketed, comma-separated vector literal.

    Every component is written in fixed-point notation with ten decimal
    places, e.g. ``[0.1000000000,0.2000000000]``; an empty input gives ``[]``.
    """
    body = ','.join(map('{:.10f}'.format, vec))
    return f'[{body}]'


def reset_schema(conn: psycopg.Connection, schema: str) -> None:
    """Drop and recreate *schema*, then make it first on the search path.

    The name is passed through ``validate_schema`` before being interpolated
    into the DDL, because identifiers cannot be bound as query parameters.
    Everything already in the schema is removed (``CASCADE``).
    """
    safe_name = validate_schema(schema)
    statements = (
        f'DROP SCHEMA IF EXISTS {safe_name} CASCADE;',
        f'CREATE SCHEMA {safe_name};',
        f'SET search_path TO {safe_name}, public;',
    )
    for statement in statements:
        conn.execute(statement)


def apply_schema(conn: psycopg.Connection, schema_sql: Path) -> None:
    """Execute the whole UTF-8 DDL file *schema_sql* as a single statement batch.

    No query parameters are passed, so the file's text is sent verbatim.
    """
    ddl_text = schema_sql.read_text(encoding='utf-8')
    conn.execute(ddl_text)


def seed_minimal_graph(conn: psycopg.Connection) -> dict[str, int]:
    """Create one row in every parent table an asset-level embedding points to.

    Inserts run in dependency order: model_registry -> feature_set_registry,
    and canonical_song -> work -> recording -> recording_asset. Each
    ``RETURNING`` key is fed into the inserts that reference it.

    Returns:
        The generated primary keys, coerced to ``int`` and keyed by column name.
    """

    def first_value(query: str, params: tuple[Any, ...] | None = None) -> Any:
        # First column of the single row produced by the statement's RETURNING.
        return conn.execute(query, params).fetchone()[0]

    model_id = first_value(
        """
        INSERT INTO model_registry (
            model_name, model_family, model_version, model_source, model_uri,
            license_name, input_sample_rate, default_window_sec, default_hop_sec,
            output_embedding_dim, pooling_supported, metadata_json
        ) VALUES (
            'asset_level_probe', 'probe', 'v1', 'live-test',
            'scripts/validate_audio_embedding_asset_upsert_live.py', 'internal-eval',
            16000, 5.0, 2.5, 192, ARRAY['none'], '{}'::jsonb
        )
        RETURNING model_id;
        """
    )
    feature_set_id = first_value(
        """
        INSERT INTO feature_set_registry (
            model_id, feature_name, feature_level, extraction_granularity,
            window_sec, hop_sec, embedding_dim, pooling_strategy, layer_selection,
            normalize_l2, distance_metric, quantization_type, feature_schema_version,
            config_json, status
        ) VALUES (
            %s, 'semantic_embedding', 'asset', 'whole_asset',
            5.0, 2.5, 192, 'none', 'na', TRUE, 'cosine', NULL, 'v1',
            '{"probe":"asset_level_upsert"}'::jsonb, 'active'
        )
        RETURNING feature_set_id;
        """,
        (model_id,),
    )
    canonical_song_id = first_value(
        """
        INSERT INTO canonical_song (biz_song_code, title, rights_status, metadata_json)
        VALUES ('asset-probe-song', 'Asset Probe Song', 'protected', '{}'::jsonb)
        RETURNING canonical_song_id;
        """
    )
    work_id = first_value(
        """
        INSERT INTO work (canonical_song_id, work_code, work_title, metadata_json)
        VALUES (%s, 'asset-probe-work', 'Asset Probe Work', '{}'::jsonb)
        RETURNING work_id;
        """,
        (canonical_song_id,),
    )
    recording_id = first_value(
        """
        INSERT INTO recording (
            work_id, canonical_song_id, recording_code, recording_title,
            version_type, is_reference, duration_sec, metadata_json
        ) VALUES (%s, %s, 'asset-probe-rec', 'Asset Probe Recording', 'master_reference', TRUE, 5.0, '{}'::jsonb)
        RETURNING recording_id;
        """,
        (work_id, canonical_song_id),
    )
    asset_id = first_value(
        """
        INSERT INTO recording_asset (
            recording_id, asset_role, storage_uri, storage_scheme, file_ext,
            mime_type, sample_rate, channels, codec_name, duration_sec,
            normalized_storage_uri, ingest_status, metadata_json
        ) VALUES (
            %s, 'reference_audio', '/tmp/asset-probe.wav', 'file', 'wav',
            'audio/wav', 16000, 1, 'pcm_s16le', 5.0,
            '/tmp/asset-probe.wav', 'ready', '{}'::jsonb
        )
        RETURNING asset_id;
        """,
        (recording_id,),
    )

    generated = {
        'model_id': model_id,
        'feature_set_id': feature_set_id,
        'canonical_song_id': canonical_song_id,
        'work_id': work_id,
        'recording_id': recording_id,
        'asset_id': asset_id,
    }
    return {column: int(value) for column, value in generated.items()}


def insert_asset_embedding(conn: psycopg.Connection, ids: dict[str, int], *, checksum: str, metadata: dict[str, Any], vec: list[float]) -> int:
    """Plain-INSERT one asset-level embedding and its ``audio_embedding_vector_192`` row.

    Args:
        conn: Open connection whose search_path points at the test schema.
        ids: Parent keys as returned by ``seed_minimal_graph``.
        checksum: Value stored in ``audio_embedding.checksum``.
        metadata: JSON-serialisable dict stored in ``metadata_json``.
        vec: Embedding components written to the vector table.

    Returns:
        The new ``embedding_id``.

    Both rows are written inside one ``conn.transaction()`` block. Without it,
    each statement on an autocommit connection commits on its own, and a failure
    on the vector insert would leave an ``audio_embedding`` row with no vector.
    """
    with conn.transaction():
        embedding_id = conn.execute(
            """
            INSERT INTO audio_embedding (
                feature_set_id, extraction_job_id, asset_id, window_id, recording_id, work_id,
                canonical_song_id, embedding_storage_mode, embedding_uri, vector_norm, checksum,
                is_indexed, metadata_json
            ) VALUES (
                %s, NULL, %s, NULL, %s, %s,
                %s, 'pgvector_inline_192', 'inline://asset-probe', 1.0, %s,
                TRUE, %s::jsonb
            )
            RETURNING embedding_id;
            """,
            (
                ids['feature_set_id'],
                ids['asset_id'],
                ids['recording_id'],
                ids['work_id'],
                ids['canonical_song_id'],
                checksum,
                json.dumps(metadata, ensure_ascii=False),
            ),
        ).fetchone()[0]
        conn.execute(
            'INSERT INTO audio_embedding_vector_192 (embedding_id, embedding) VALUES (%s, %s::vector);',
            (embedding_id, vec_literal(vec)),
        )
    return int(embedding_id)


def expect_duplicate_insert_failure(conn: psycopg.Connection, ids: dict[str, int]) -> dict[str, Any]:
    """Try a second plain INSERT for the same (feature set, asset) and expect rejection.

    The attempt runs in its own ``conn.transaction()`` so a failure is rolled
    back without disturbing the connection.

    Returns:
        ``{'passed': False, 'note': ...}`` if the insert unexpectedly succeeds.
        Otherwise ``{'passed', 'error_type', 'message'}``, where ``passed`` is True
        only when the error is attributed to ``uq_audio_embedding_feature_asset``.
        ``message`` is the first line of the error text, or ``''`` if the error
        has no text.
    """
    expected_constraint = 'uq_audio_embedding_feature_asset'
    try:
        with conn.transaction():
            conn.execute(
                """
                INSERT INTO audio_embedding (
                    feature_set_id, extraction_job_id, asset_id, window_id, recording_id, work_id,
                    canonical_song_id, embedding_storage_mode, embedding_uri, vector_norm, checksum,
                    is_indexed, metadata_json
                ) VALUES (
                    %s, NULL, %s, NULL, %s, %s,
                    %s, 'pgvector_inline_192', 'inline://asset-probe-duplicate', 1.0, 'dup-checksum',
                    TRUE, '{"probe":"duplicate_insert"}'::jsonb
                );
                """,
                (
                    ids['feature_set_id'],
                    ids['asset_id'],
                    ids['recording_id'],
                    ids['work_id'],
                    ids['canonical_song_id'],
                ),
            )
    except Exception as exc:  # noqa: BLE001 - every failure is recorded in the report
        message = str(exc)
        # Prefer the server-reported constraint name (psycopg errors expose it
        # on ``exc.diag``). A plain substring test would also accept a longer
        # constraint name that merely starts with the expected one.
        diag = getattr(exc, 'diag', None)
        constraint_name = getattr(diag, 'constraint_name', None)
        if constraint_name is not None:
            passed = constraint_name == expected_constraint
        else:
            passed = expected_constraint in message
        # splitlines() of an empty message is [], so indexing [0] would raise here.
        lines = message.splitlines()
        return {
            'passed': passed,
            'error_type': type(exc).__name__,
            'message': lines[0] if lines else '',
        }
    return {'passed': False, 'note': 'duplicate asset-level insert unexpectedly succeeded'}


def upsert_asset_embedding(conn: psycopg.Connection, ids: dict[str, int], *, checksum: str, metadata: dict[str, Any], vec: list[float]) -> int:
    """Insert or update the asset-level embedding for (feature_set_id, asset_id).

    The ``ON CONFLICT`` target names the columns and predicate
    (``window_id IS NULL AND asset_id IS NOT NULL``) of the partial unique index.
    On conflict, only checksum, embedding_uri, metadata_json, is_indexed and
    vector_norm are overwritten. The vector row is then upserted by embedding_id.

    Args:
        conn: Open connection whose search_path points at the test schema.
        ids: Parent keys as returned by ``seed_minimal_graph``.
        checksum: New ``checksum`` value.
        metadata: JSON-serialisable dict stored in ``metadata_json``.
        vec: Embedding components written to the vector table.

    Returns:
        The ``embedding_id`` that was inserted or updated.

    Both statements run inside one ``conn.transaction()`` block so the parent
    row and its vector are never left out of sync by a partial failure on an
    autocommit connection.
    """
    with conn.transaction():
        embedding_id = conn.execute(
            """
            INSERT INTO audio_embedding (
                feature_set_id, extraction_job_id, asset_id, window_id, recording_id, work_id,
                canonical_song_id, embedding_storage_mode, embedding_uri, vector_norm, checksum,
                is_indexed, metadata_json
            ) VALUES (
                %s, NULL, %s, NULL, %s, %s,
                %s, 'pgvector_inline_192', 'inline://asset-probe-upsert', 1.0, %s,
                TRUE, %s::jsonb
            )
            ON CONFLICT (feature_set_id, asset_id)
            WHERE window_id IS NULL AND asset_id IS NOT NULL
            DO UPDATE SET
                checksum = EXCLUDED.checksum,
                embedding_uri = EXCLUDED.embedding_uri,
                metadata_json = EXCLUDED.metadata_json,
                is_indexed = EXCLUDED.is_indexed,
                vector_norm = EXCLUDED.vector_norm
            RETURNING embedding_id;
            """,
            (
                ids['feature_set_id'],
                ids['asset_id'],
                ids['recording_id'],
                ids['work_id'],
                ids['canonical_song_id'],
                checksum,
                json.dumps(metadata, ensure_ascii=False),
            ),
        ).fetchone()[0]
        conn.execute(
            """
            INSERT INTO audio_embedding_vector_192 (embedding_id, embedding)
            VALUES (%s, %s::vector)
            ON CONFLICT (embedding_id)
            DO UPDATE SET embedding = EXCLUDED.embedding;
            """,
            (embedding_id, vec_literal(vec)),
        )
    return int(embedding_id)


def fetch_final_state(conn: psycopg.Connection, embedding_id: int) -> dict[str, Any]:
    """Read back one embedding row joined with its 192-d vector row.

    Args:
        conn: Open connection whose search_path points at the test schema.
        embedding_id: Primary key of the ``audio_embedding`` row to read.

    Returns:
        A dict with embedding_id, asset_id, window_id, checksum, embedding_uri,
        metadata_json (``{}`` when NULL/empty) and the vector cast to text.

    Raises:
        LookupError: If no row matches. Because of the inner JOIN, this also
            happens when the embedding exists but its vector row does not.
    """
    row = conn.execute(
        """
        SELECT ae.embedding_id, ae.asset_id, ae.window_id, ae.checksum, ae.embedding_uri, ae.metadata_json,
               aev.embedding::text
        FROM audio_embedding ae
        JOIN audio_embedding_vector_192 aev ON aev.embedding_id = ae.embedding_id
        WHERE ae.embedding_id = %s;
        """,
        (embedding_id,),
    ).fetchone()
    if row is None:
        raise LookupError(
            f'no audio_embedding/audio_embedding_vector_192 row for embedding_id={embedding_id}'
        )
    return {
        'embedding_id': int(row[0]),
        'asset_id': int(row[1]),
        'window_id': row[2],
        'checksum': row[3],
        'embedding_uri': row[4],
        'metadata_json': row[5] or {},
        'vector_literal': row[6],
    }


def _redact_dsn(dsn: str) -> str:
    """Return *dsn* as a libpq ``key=value`` string with any password masked.

    Both URL (``postgresql://...``) and ``key=value`` DSNs are parsed with
    psycopg's own conninfo parser. If parsing fails, a fixed placeholder is
    returned rather than risking echoing a secret into the report.
    """
    from psycopg.conninfo import conninfo_to_dict, make_conninfo

    try:
        params = conninfo_to_dict(dsn)
        if 'password' in params:
            params['password'] = '***'
        return make_conninfo(**params)
    except psycopg.Error:
        return '<unparseable dsn>'


def _parse_vector_literal(literal: str | None) -> list[float]:
    """Parse a bracketed vector text such as ``[0.2,0.2]`` into floats (``None`` -> ``[]``)."""
    if literal is None:
        return []
    inner = literal.strip().strip('[]')
    return [float(part) for part in inner.split(',') if part.strip()]


def main() -> None:
    """Run the asset-level upsert probe end to end and write a JSON report.

    The probe runs these steps in order:

    1. Recreate ``--schema`` (dropping it first) and apply ``--schema-sql``.
    2. Seed one parent row per table and insert an asset-level embedding.
    3. Check that a second plain INSERT is rejected by
       ``uq_audio_embedding_feature_asset``.
    4. Upsert the same (feature set, asset). Verify that the original row was
       reused and that its checksum, metadata and vector were all updated.

    The report is written to ``--output`` and echoed to stdout.
    """
    import math

    ap = argparse.ArgumentParser()
    ap.add_argument('--dsn', required=True)
    ap.add_argument('--schema', default='acr_asset_upsert_test')
    ap.add_argument('--schema-sql', default=str(DEFAULT_SCHEMA_SQL))
    ap.add_argument('--output', default=str(DEFAULT_OUTPUT))
    args = ap.parse_args()

    initial_vec = [0.1] * 192
    updated_vec = [0.2] * 192

    payload: dict[str, Any] = {
        'schema': args.schema,
        # Derived from the DSN actually used (password masked), not a fixed string.
        'dsn_redacted': _redact_dsn(args.dsn),
    }
    with psycopg.connect(args.dsn, autocommit=True) as conn:
        reset_schema(conn, args.schema)
        apply_schema(conn, Path(args.schema_sql))
        ids = seed_minimal_graph(conn)
        payload['seed_ids'] = ids

        first_embedding_id = insert_asset_embedding(
            conn,
            ids,
            checksum='checksum-v1',
            metadata={'probe': 'asset_level_insert_v1'},
            vec=initial_vec,
        )
        payload['first_insert_embedding_id'] = first_embedding_id
        payload['duplicate_insert_guard'] = expect_duplicate_insert_failure(conn, ids)

        upsert_embedding_id = upsert_asset_embedding(
            conn,
            ids,
            checksum='checksum-v2',
            metadata={'probe': 'asset_level_upsert_v2'},
            vec=updated_vec,
        )
        payload['upsert_embedding_id'] = upsert_embedding_id
        payload['same_embedding_id_reused'] = first_embedding_id == upsert_embedding_id
        counts = {
            'audio_embedding': int(conn.execute('SELECT count(*) FROM audio_embedding;').fetchone()[0]),
            'audio_embedding_vector_192': int(conn.execute('SELECT count(*) FROM audio_embedding_vector_192;').fetchone()[0]),
        }
        payload['counts'] = counts
        final_state = fetch_final_state(conn, upsert_embedding_id)
        payload['final_state'] = final_state

        # The upsert must also have replaced the stored vector, not just the metadata.
        # Tolerances allow for the database storing components at lower precision.
        stored_vec = _parse_vector_literal(final_state['vector_literal'])
        payload['vector_updated'] = len(stored_vec) == len(updated_vec) and all(
            math.isclose(got, want, rel_tol=1e-6, abs_tol=1e-6)
            for got, want in zip(stored_vec, updated_vec)
        )
        final_metadata = final_state['metadata_json']
        payload['passed'] = all((
            bool(payload['duplicate_insert_guard'].get('passed')),
            payload['same_embedding_id_reused'],
            counts['audio_embedding'] == 1,
            counts['audio_embedding_vector_192'] == 1,
            final_state['checksum'] == 'checksum-v2',
            isinstance(final_metadata, dict) and final_metadata.get('probe') == 'asset_level_upsert_v2',
            payload['vector_updated'],
        ))

    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding='utf-8')
    print(json.dumps(payload, ensure_ascii=False, indent=2))


if __name__ == '__main__':
    main()