smoke_songcentric_schema_live.py 5.52 KB
#!/usr/bin/env /usr/local/miniconda3/bin/python
from __future__ import annotations

import argparse
import json
from pathlib import Path

import psycopg
from psycopg.rows import dict_row


def quote_ident(name: str) -> str:
    """Return *name* wrapped as a double-quoted PostgreSQL identifier.

    Any double quote inside *name* is escaped by doubling it, which is how
    PostgreSQL spells a literal ``"`` inside a quoted identifier.
    """
    escaped = name.replace('"', '""')
    return f'"{escaped}"'


def _display_path(path: Path, root: Path) -> str:
    """Return *path* relative to *root* when it lies under it, else unchanged.

    ``Path.relative_to`` raises ValueError for paths outside *root*, which
    happens when an absolute ``--sql`` is given; fall back to the full path.
    """
    try:
        return str(path.relative_to(root))
    except ValueError:
        return str(path)


def _json_default(value: object) -> str:
    """Serialize database scalar types the json module cannot encode itself.

    NOTE(review): the column types come from the DDL file, which is not
    visible here. If ids are ``uuid`` or a value is numeric/timestamp, a plain
    ``json.dumps`` would raise TypeError after the transaction had already
    committed, and the report would be lost. Unknown types still raise.
    """
    import datetime
    import decimal
    import uuid

    if isinstance(value, (datetime.date, datetime.time)):
        return value.isoformat()
    if isinstance(value, (uuid.UUID, decimal.Decimal)):
        return str(value)
    raise TypeError(f'Object of type {type(value).__name__} is not JSON serializable')


def _insert_returning(conn: psycopg.Connection, query: str, column: str, params: tuple | None = None):
    """Execute an ``INSERT ... RETURNING`` and return *column* from its row."""
    return conn.execute(query, params).fetchone()[column]


def _insert_fixture(conn: psycopg.Connection) -> dict[str, object]:
    """Insert the minimal song -> asset -> window -> feature/membership chain.

    Returns the generated ids, keyed in insertion order.
    """
    song_id = _insert_returning(
        conn,
        """
        insert into media_entity (entity_type, biz_key, title, artist_name)
        values ('song', 'song-9001', 'Smoke Song', 'Smoke Artist')
        returning entity_id
        """,
        'entity_id',
    )
    asset_id = _insert_returning(
        conn,
        """
        insert into audio_object (
            object_type, song_id, source_type, storage_uri, storage_scheme,
            checksum, codec, sample_rate, channels, duration_ms
        ) values (
            'asset', %s, 'official', 's3://bucket/smoke-song.wav', 's3',
            'sha256:smoke-asset', 'wav', 44100, 2, 180000
        ) returning object_id
        """,
        'object_id',
        (song_id,),
    )
    # A 5 s window (30000-35000 ms) whose parent is the asset above.
    window_id = _insert_returning(
        conn,
        """
        insert into audio_object (
            object_type, song_id, parent_object_id, start_ms, end_ms, duration_ms
        ) values ('window', %s, %s, 30000, 35000, 5000)
        returning object_id
        """,
        'object_id',
        (song_id, asset_id),
    )
    fingerprint_id = _insert_returning(
        conn,
        """
        insert into feature_fact (
            feature_type, object_id, song_id, model_name, model_version,
            feature_set_name, fingerprint_value
        ) values (
            'fingerprint', %s, %s, 'chromaprint', 'phase1', 'chromaprint_5s', 'fp-smoke'
        ) returning feature_id
        """,
        'feature_id',
        (window_id, song_id),
    )
    embedding_id = _insert_returning(
        conn,
        """
        insert into feature_fact (
            feature_type, object_id, song_id, model_name, model_version,
            feature_set_name, embedding_dim, embedding_uri, vector_table_name
        ) values (
            'embedding', %s, %s, 'mert', 'v1-95m',
            'mert_5s_hop2.5_meanpool', 768, 's3://bucket/smoke-song-win.npy', 'audio_embedding_vector_768'
        ) returning feature_id
        """,
        'feature_id',
        (window_id, song_id),
    )
    membership_id = _insert_returning(
        conn,
        """
        insert into set_membership (
            set_type, set_name, member_type, member_id, song_id, priority
        ) values (
            'reference_set', 'phase1_hot_reference_v1', 'asset', %s, %s, 100
        ) returning membership_id
        """,
        'membership_id',
        (asset_id, song_id),
    )
    return {
        'song_id': song_id,
        'asset_id': asset_id,
        'window_id': window_id,
        'fingerprint_feature_id': fingerprint_id,
        'embedding_feature_id': embedding_id,
        'membership_id': membership_id,
    }


def _fetch_embedding_lineage(conn: psycopg.Connection, feature_id: object) -> dict | None:
    """Join a feature back through its window and parent asset to its song.

    Returns the joined row, or None when any link in the chain is missing or
    has the wrong ``object_type``/``entity_type``.
    """
    return conn.execute(
        """
        select ff.feature_id,
               ff.feature_type,
               ff.model_name,
               ff.model_version,
               ff.feature_set_name,
               win.object_id as window_id,
               ast.object_id as asset_id,
               song.entity_id as song_id,
               song.title,
               song.artist_name
        from feature_fact ff
        join audio_object win
          on win.object_id = ff.object_id
         and win.object_type = 'window'
        join audio_object ast
          on ast.object_id = win.parent_object_id
         and ast.object_type = 'asset'
        join media_entity song
          on song.entity_id = ff.song_id
         and song.entity_type = 'song'
        where ff.feature_id = %s
        """,
        (feature_id,),
    ).fetchone()


def main() -> int:
    """Load the song-centric schema into a scratch schema and smoke-test it.

    All database work happens on one connection, and ``commit`` is reached
    only after every step succeeds:

    1. ``DROP SCHEMA ... CASCADE`` and ``CREATE SCHEMA`` for ``--schema``.
       Anything already in that schema is destroyed.
    2. Put that schema first on ``search_path`` and execute the ``--sql``
       DDL file.
    3. Insert one song, one asset, one window of that asset, a fingerprint
       feature and an embedding feature on the window, and a reference-set
       membership for the asset.
    4. Join the embedding back to its window, asset and song, and count the
       rows in the four core tables.

    The report is written as UTF-8 JSON to ``--output`` and echoed to stdout.
    Relative ``--sql``/``--output`` paths are resolved against ``--root``
    (default ``/workspace``); absolute paths are used as given.

    Returns:
        Process exit status: 0 when the lineage join found the embedding,
        1 when it returned no row (the report is still written).
    """
    parser = argparse.ArgumentParser(
        description='Smoke-test the song-centric PostgreSQL schema against a live database.'
    )
    parser.add_argument('--dsn', required=True, help='libpq connection string')
    parser.add_argument('--schema', default='acr_songcentric_test',
                        help='scratch schema that is dropped (CASCADE) and recreated')
    parser.add_argument('--sql', default='acr-engine/sql/acr_pg_schema_songcentric_v1.sql',
                        help='DDL file; relative to --root unless absolute')
    parser.add_argument('--output', default='acr-engine/data/pgvector_eval/music20/songcentric_schema_smoke_report.json',
                        help='report path; relative to --root unless absolute')
    parser.add_argument('--root', default='/workspace',
                        help='base directory for relative --sql/--output paths')
    args = parser.parse_args()

    root = Path(args.root)
    sql_path = root / args.sql  # an absolute args.sql replaces root entirely
    output_path = root / args.output
    # Read the DDL before connecting so a missing file fails without
    # touching the database.
    ddl = sql_path.read_text(encoding='utf-8')
    output_path.parent.mkdir(parents=True, exist_ok=True)

    schema = args.schema
    qschema = quote_ident(schema)
    report: dict[str, object] = {
        'schema': schema,
        'sql_path': _display_path(sql_path, root),
    }

    with psycopg.connect(args.dsn, row_factory=dict_row) as conn:
        conn.execute(f'drop schema if exists {qschema} cascade')
        conn.execute(f'create schema {qschema}')
        # Unqualified names in the DDL and in the queries below resolve to
        # the scratch schema first.
        conn.execute(f'set search_path to {qschema}, public')
        # No parameters are passed, so the whole file (possibly several
        # ';'-separated statements) is sent as a single query.
        conn.execute(ddl)

        inserted = _insert_fixture(conn)
        lineage = _fetch_embedding_lineage(conn, inserted['embedding_feature_id'])
        counts = {
            table: conn.execute(
                f'select count(*) as c from {quote_ident(table)}'
            ).fetchone()['c']
            for table in ('media_entity', 'audio_object', 'feature_fact', 'set_membership')
        }

        report.update(inserted=inserted, counts=counts, embedding_lineage=lineage)
        conn.commit()

    text = json.dumps(report, ensure_ascii=False, indent=2, default=_json_default)
    output_path.write_text(text, encoding='utf-8')
    print(text)
    # A smoke test must fail visibly: no lineage row means the joins broke.
    return 0 if lineage is not None else 1


if __name__ == '__main__':
    # Propagate main()'s integer return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)