bootstrap_songcentric_phase1_live.py 8.8 KB
#!/usr/bin/env /usr/local/miniconda3/bin/python
from __future__ import annotations

import argparse
import json
from pathlib import Path

import psycopg
from psycopg.rows import dict_row


def quote_ident(name: str) -> str:
    """Return *name* as a double-quoted SQL identifier.

    Embedded double quotes are escaped by doubling them, so the result can be
    interpolated into SQL text as a single identifier token.
    """
    escaped = name.replace('"', '""')
    return f'"{escaped}"'


def ensure_song(cur, biz_key: str, title: str, artist_name: str) -> int:
    """Get-or-create the ``media_entity`` row of type 'song' keyed by *biz_key*.

    Args:
        cur: Cursor whose ``execute`` returns an object with ``fetchone()`` and
            whose rows support column-name indexing (e.g. psycopg 3 with
            ``dict_row``).
        biz_key: Business key used for the lookup.
        title: Stored only when a new row is inserted.
        artist_name: Stored only when a new row is inserted.

    Returns:
        ``entity_id`` of the existing or newly inserted row. An existing row is
        returned unchanged, even if its title/artist differ from the arguments.

    NOTE(review): lookup-then-insert is not atomic; concurrent callers could both
    insert unless a unique constraint (not visible here) prevents it.
    """
    find_sql = """
        select entity_id from media_entity
        where entity_type = 'song' and biz_key = %s
        """
    create_sql = """
        insert into media_entity (entity_type, biz_key, title, artist_name)
        values ('song', %s, %s, %s)
        returning entity_id
        """

    match = cur.execute(find_sql, (biz_key,)).fetchone()
    if not match:
        match = cur.execute(create_sql, (biz_key, title, artist_name)).fetchone()
    return match['entity_id']


def ensure_asset(cur, song_id: int, source_type: str, storage_uri: str, checksum: str, duration_ms: int) -> int:
    """Get-or-create the 'asset' ``audio_object`` for (*song_id*, *checksum*).

    Only *song_id* and *checksum* identify an existing asset. When one is found,
    its ``object_id`` is returned without updating any other column.

    On insert, *source_type*, *storage_uri* and *duration_ms* are stored together
    with hard-coded media attributes: storage_scheme 'file', codec 'wav',
    sample_rate 16000, channels 1.

    Args:
        cur: Cursor with chained ``execute(...).fetchone()`` and column-name row access.

    Returns:
        ``object_id`` of the existing or newly inserted asset.

    NOTE(review): lookup-then-insert is not atomic; concurrent runs may race.
    """
    find_sql = """
        select object_id from audio_object
        where object_type = 'asset' and song_id = %s and checksum = %s
        """
    create_sql = """
        insert into audio_object (
            object_type, song_id, source_type, storage_uri, storage_scheme,
            checksum, codec, sample_rate, channels, duration_ms
        ) values (
            'asset', %s, %s, %s, 'file', %s, 'wav', 16000, 1, %s
        ) returning object_id
        """

    match = cur.execute(find_sql, (song_id, checksum)).fetchone()
    if not match:
        match = cur.execute(
            create_sql, (song_id, source_type, storage_uri, checksum, duration_ms),
        ).fetchone()
    return match['object_id']


def ensure_window(cur, song_id: int, asset_id: int, start_ms: int, end_ms: int) -> int:
    """Get-or-create a 'window' ``audio_object`` spanning [*start_ms*, *end_ms*] of *asset_id*.

    The lookup uses only (parent_object_id=*asset_id*, start_ms, end_ms);
    *song_id* is written only when a new window is inserted. On insert,
    duration_ms is stored as ``end_ms - start_ms``.

    Returns:
        ``object_id`` of the existing or newly inserted window.

    NOTE(review): lookup-then-insert is not atomic; concurrent runs may race.
    """
    find_sql = """
        select object_id from audio_object
        where object_type = 'window' and parent_object_id = %s and start_ms = %s and end_ms = %s
        """
    create_sql = """
        insert into audio_object (
            object_type, song_id, parent_object_id, start_ms, end_ms, duration_ms
        ) values ('window', %s, %s, %s, %s, %s)
        returning object_id
        """

    match = cur.execute(find_sql, (asset_id, start_ms, end_ms)).fetchone()
    if match:
        return match['object_id']

    span_ms = end_ms - start_ms
    created = cur.execute(
        create_sql, (song_id, asset_id, start_ms, end_ms, span_ms),
    ).fetchone()
    return created['object_id']


def ensure_feature(cur, feature_type: str, object_id: int, song_id: int, model_name: str, model_version: str,
                   feature_set_name: str, payload: dict) -> int:
    """Get-or-create a ``feature_fact`` row for an audio object.

    Identity is (object_id, model_name, model_version, feature_set_name,
    feature_type); *song_id* is written only on insert. *payload* is read only
    when a new row must be inserted:

    * ``feature_type == 'fingerprint'``: requires ``fingerprint_value`` and
      ``checksum``.
    * any other type: requires ``embedding_dim``, ``embedding_uri``,
      ``vector_table_name`` and ``checksum``.

    In both cases the optional ``metadata_json`` (default ``{}``) is serialized
    with ``json.dumps`` and cast to jsonb.

    Returns:
        ``feature_id`` of the existing or newly inserted row.

    Raises:
        KeyError: a required payload key is missing (insert path only).

    NOTE(review): lookup-then-insert is not atomic; concurrent runs may race.
    """
    found = cur.execute(
        """
        select feature_id from feature_fact
        where object_id = %s and model_name = %s and model_version = %s
          and feature_set_name = %s and feature_type = %s
        """,
        (object_id, model_name, model_version, feature_set_name, feature_type),
    ).fetchone()
    if found:
        return found['feature_id']

    # Columns shared by both insert shapes, in column order.
    shared = (feature_type, object_id, song_id, model_name, model_version, feature_set_name)
    if feature_type == 'fingerprint':
        sql = """
            insert into feature_fact (
                feature_type, object_id, song_id, model_name, model_version,
                feature_set_name, fingerprint_value, checksum, metadata_json
            ) values (%s, %s, %s, %s, %s, %s, %s, %s, %s::jsonb)
            returning feature_id
            """
        params = shared + (
            payload['fingerprint_value'],
            payload['checksum'],
            json.dumps(payload.get('metadata_json', {})),
        )
    else:
        sql = """
        insert into feature_fact (
            feature_type, object_id, song_id, model_name, model_version,
            feature_set_name, embedding_dim, embedding_uri, vector_table_name, checksum, metadata_json
        ) values (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s::jsonb)
        returning feature_id
        """
        params = shared + (
            payload['embedding_dim'],
            payload['embedding_uri'],
            payload['vector_table_name'],
            payload['checksum'],
            json.dumps(payload.get('metadata_json', {})),
        )
    return cur.execute(sql, params).fetchone()['feature_id']


def ensure_membership(cur, set_type: str, set_name: str, member_type: str, member_id: int, song_id: int, priority: int) -> int:
    """Get-or-create a ``set_membership`` row placing a member into a named set.

    Identity is (set_type, set_name, member_type, member_id). *song_id* and
    *priority* are written only on insert; an existing row keeps its stored
    values.

    Returns:
        ``membership_id`` of the existing or newly inserted row.

    NOTE(review): lookup-then-insert is not atomic; concurrent runs may race.
    """
    key = (set_type, set_name, member_type, member_id)
    find_sql = """
        select membership_id from set_membership
        where set_type = %s and set_name = %s and member_type = %s and member_id = %s
        """
    create_sql = """
        insert into set_membership (set_type, set_name, member_type, member_id, song_id, priority)
        values (%s, %s, %s, %s, %s, %s)
        returning membership_id
        """

    match = cur.execute(find_sql, key).fetchone()
    if not match:
        match = cur.execute(create_sql, key + (song_id, priority)).fetchone()
    return match['membership_id']


def _seed_song(cur, idx: int, song: dict[str, str]) -> dict[str, int]:
    """Ensure the song -> asset -> window -> features -> membership chain for one demo song.

    Args:
        cur: Cursor with column-name row access (the connection uses ``dict_row``).
        idx: 1-based position of the song; only used to vary the synthetic asset duration.
        song: Keyword arguments for ``ensure_song`` (``biz_key``, ``title``, ``artist_name``).

    Returns:
        The per-song report entry holding every id that was found or created.
    """
    biz_key = song['biz_key']
    song_id = ensure_song(cur, **song)
    # The checksum is derived from biz_key, and ensure_asset looks up by
    # (song_id, checksum), so reruns find the same asset instead of inserting.
    asset_id = ensure_asset(
        cur, song_id, 'official', f'/workspace/downloads/{biz_key}/master.wav',
        f'sha256:{biz_key}', 180000 + idx * 1000,
    )
    window_id = ensure_window(cur, song_id, asset_id, 30000, 35000)
    fingerprint_id = ensure_feature(
        cur, 'fingerprint', window_id, song_id,
        'chromaprint', 'phase1', 'chromaprint_5s',
        {'fingerprint_value': f'fp-{biz_key}', 'checksum': f'fpchk-{biz_key}', 'metadata_json': {'lane': 'exact'}},
    )
    embedding_id = ensure_feature(
        cur, 'embedding', window_id, song_id,
        'mert', 'v1-95m', 'mert_5s_hop2.5_meanpool',
        {
            'embedding_dim': 768,
            'embedding_uri': f's3://bucket/{biz_key}/win0001.npy',
            'vector_table_name': 'audio_embedding_vector_768',
            'checksum': f'embchk-{biz_key}',
            'metadata_json': {'lane': 'semantic'},
        },
    )
    membership_id = ensure_membership(
        cur, 'reference_set', 'phase1_hot_reference_v1', 'asset', asset_id, song_id, 100,
    )
    return {
        'song_id': song_id,
        'asset_id': asset_id,
        'window_id': window_id,
        'fingerprint_feature_id': fingerprint_id,
        'embedding_feature_id': embedding_id,
        'membership_id': membership_id,
    }


def _table_counts(cur) -> dict[str, int]:
    """Return ``{table: row_count}`` for the song-centric tables resolved via the current search_path."""
    return {
        # Quote the name for consistency with how the schema name is handled.
        table: cur.execute(f'select count(*) as c from {quote_ident(table)}').fetchone()['c']
        for table in ('media_entity', 'audio_object', 'feature_fact', 'set_membership')
    }


def _lineage_sample(cur):
    """Return the lowest-id embedding feature joined back to its window, asset and song.

    Returns:
        A single mapping row, or ``None`` when no embedding feature has a
        complete window -> asset -> song lineage.
    """
    return cur.execute(
        """
        select ff.feature_type,
               ff.model_name,
               win.object_id as window_id,
               ast.object_id as asset_id,
               song.entity_id as song_id,
               song.title
        from feature_fact ff
        join audio_object win on win.object_id = ff.object_id and win.object_type = 'window'
        join audio_object ast on ast.object_id = win.parent_object_id and ast.object_type = 'asset'
        join media_entity song on song.entity_id = ff.song_id and song.entity_type = 'song'
        where ff.feature_type = 'embedding'
        order by ff.feature_id asc
        limit 1
        """
    ).fetchone()


def main(argv: list[str] | None = None) -> int:
    """Seed the phase-1 song-centric fixture into a schema and write a JSON report.

    For each of two hard-coded demo songs, this idempotently ensures:
    - a song entity,
    - an 'official' asset,
    - a 30000-35000 ms window,
    - a fingerprint feature and an embedding feature,
    - a reference-set membership.

    Unqualified table names resolve through ``search_path`` set to
    ``--schema, public``. It then collects per-table row counts and one lineage
    sample, commits, writes the report as UTF-8 JSON to ``--output`` (resolved
    against ``/workspace``), and echoes the report to stdout.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit status (0); failures propagate as exceptions.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dsn', required=True)
    parser.add_argument('--schema', default='acr_songcentric_test')
    parser.add_argument('--output', default='acr-engine/data/pgvector_eval/music20/songcentric_phase1_bootstrap_report.json')
    args = parser.parse_args(argv)

    # pathlib semantics: an absolute --output replaces the /workspace prefix.
    output_path = Path('/workspace') / args.output
    # Create the directory before touching the database so a bad path fails early.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    qschema = quote_ident(args.schema)

    songs = [
        {'biz_key': 'song-10001', 'title': 'Song 10001', 'artist_name': 'Artist A'},
        {'biz_key': 'song-10002', 'title': 'Song 10002', 'artist_name': 'Artist B'},
    ]
    report = {'schema': args.schema, 'songs': []}

    with psycopg.connect(args.dsn, row_factory=dict_row) as conn:
        with conn.cursor() as cur:
            cur.execute(f'set search_path to {qschema}, public')
            for idx, song in enumerate(songs, start=1):
                report['songs'].append(_seed_song(cur, idx, song))
            report['counts'] = _table_counts(cur)
            report['lineage_sample'] = _lineage_sample(cur)
        conn.commit()

    rendered = json.dumps(report, ensure_ascii=False, indent=2)
    # ensure_ascii=False can emit non-ASCII text, so the encoding is pinned
    # instead of depending on the process locale.
    output_path.write_text(rendered, encoding='utf-8')
    print(rendered)
    return 0


if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)