# import_songcentric_manifest_live.py (9.02 KB)
#!/usr/bin/env /usr/local/miniconda3/bin/python
from __future__ import annotations

import argparse
import json
from pathlib import Path

import psycopg
from psycopg.rows import dict_row


def quote_ident(name: str) -> str:
    """Return *name* wrapped in double quotes, with embedded double quotes doubled."""
    escaped = name.replace('"', '""')
    return f'"{escaped}"'


def load_jsonl(path: Path):
    """Yield one decoded JSON value per non-blank line of the file at *path*.

    The file is read as UTF-8 and streamed line by line. Lines are stripped
    of surrounding whitespace and blank lines are skipped.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # Iterate the file object instead of using str.splitlines(): splitlines()
    # also breaks on U+2028, U+2029, U+0085 and similar characters, which may
    # legally appear unescaped inside a JSON string and would cut a record
    # into invalid fragments. File iteration only splits on \n, \r and \r\n.
    with path.open(encoding='utf-8') as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if line:
                yield json.loads(line)


def ensure_song(cur, song: dict) -> int:
    """Return the ``entity_id`` of the song identified by ``song['biz_key']``.

    Looks for a ``media_entity`` row with ``entity_type='song'`` and the given
    ``biz_key``; if none is found, inserts one using ``title`` and the
    optional ``artist_name`` and returns the new id.
    """
    biz_key = song['biz_key']
    existing = cur.execute(
        "select entity_id from media_entity where entity_type='song' and biz_key=%s",
        (biz_key,),
    ).fetchone()
    if existing:
        return existing['entity_id']

    # No matching row: create the song and hand back the generated id.
    insert_params = (biz_key, song['title'], song.get('artist_name'))
    created = cur.execute(
        "insert into media_entity (entity_type,biz_key,title,artist_name) values ('song',%s,%s,%s) returning entity_id",
        insert_params,
    ).fetchone()
    return created['entity_id']


def ensure_asset(cur, song_id: int, asset: dict) -> int:
    """Return the ``object_id`` of the asset for *song_id* with the given checksum.

    An ``audio_object`` row of type ``'asset'`` is matched on
    ``(song_id, asset['checksum'])``. When no row matches, a new one is
    inserted; every descriptive column is read with ``asset.get`` so absent
    keys are passed as ``None``.
    """
    existing = cur.execute(
        "select object_id from audio_object where object_type='asset' and song_id=%s and checksum=%s",
        (song_id, asset['checksum']),
    ).fetchone()
    if existing:
        return existing['object_id']

    # Column order here must mirror the column list in the insert statement.
    optional_columns = (
        'source_type',
        'storage_uri',
        'storage_scheme',
        'checksum',
        'codec',
        'sample_rate',
        'channels',
        'duration_ms',
    )
    insert_params = (song_id, *(asset.get(column) for column in optional_columns))
    insert_sql = """
        insert into audio_object (
            object_type,song_id,source_type,storage_uri,storage_scheme,checksum,codec,sample_rate,channels,duration_ms
        ) values ('asset',%s,%s,%s,%s,%s,%s,%s,%s,%s) returning object_id
        """
    created = cur.execute(insert_sql, insert_params).fetchone()
    return created['object_id']


def ensure_window(cur, song_id: int, asset_id: int, win: dict) -> int:
    """Return the ``object_id`` of the window ``[start_ms, end_ms]`` under *asset_id*.

    A window is an ``audio_object`` row of type ``'window'`` whose
    ``parent_object_id`` is the asset. If no row matches the bounds, one is
    inserted with ``duration_ms`` set to ``end_ms - start_ms``.
    """
    start_ms = win['start_ms']
    end_ms = win['end_ms']
    existing = cur.execute(
        "select object_id from audio_object where object_type='window' and parent_object_id=%s and start_ms=%s and end_ms=%s",
        (asset_id, start_ms, end_ms),
    ).fetchone()
    if existing:
        return existing['object_id']

    # Duration is derived only when a new row is actually needed.
    insert_params = (song_id, asset_id, start_ms, end_ms, end_ms - start_ms)
    created = cur.execute(
        "insert into audio_object (object_type,song_id,parent_object_id,start_ms,end_ms,duration_ms) values ('window',%s,%s,%s,%s,%s) returning object_id",
        insert_params,
    ).fetchone()
    return created['object_id']


def ensure_feature(cur, feature: dict, object_id: int, song_id: int) -> int:
    """Return the ``feature_id`` of *feature* attached to *object_id*.

    A ``feature_fact`` row is matched on object, model name/version, feature
    set name and feature type. When none exists a row is inserted:
    ``'fingerprint'`` features store ``fingerprint_value``; every other
    feature type stores the embedding descriptor columns
    (``feature_schema_ver`` defaulting to ``'v1'``). ``metadata_json`` is
    serialised with ``json.dumps`` and defaults to an empty object.
    """
    lookup_params = (
        object_id,
        feature['model_name'],
        feature['model_version'],
        feature['feature_set_name'],
        feature['feature_type'],
    )
    existing = cur.execute(
        "select feature_id from feature_fact where object_id=%s and model_name=%s and model_version=%s and feature_set_name=%s and feature_type=%s",
        lookup_params,
    ).fetchone()
    if existing:
        return existing['feature_id']

    # The leading six insert parameters are identical for both feature shapes.
    if feature['feature_type'] == 'fingerprint':
        shared_params = (
            feature['feature_type'],
            object_id,
            song_id,
            feature['model_name'],
            feature['model_version'],
            feature['feature_set_name'],
        )
        insert_sql = """
            insert into feature_fact (
                feature_type, object_id, song_id, model_name, model_version,
                feature_set_name, fingerprint_value, checksum, metadata_json
            ) values (%s,%s,%s,%s,%s,%s,%s,%s,%s::jsonb)
            returning feature_id
            """
        specific_params = (
            feature['fingerprint_value'],
            feature.get('checksum'),
            json.dumps(feature.get('metadata_json', {})),
        )
    else:
        shared_params = (
            feature['feature_type'],
            object_id,
            song_id,
            feature['model_name'],
            feature['model_version'],
            feature['feature_set_name'],
        )
        insert_sql = """
        insert into feature_fact (
            feature_type, object_id, song_id, model_name, model_version,
            feature_set_name, feature_schema_ver, embedding_dim, embedding_uri, vector_table_name, checksum, metadata_json
        ) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s::jsonb)
        returning feature_id
        """
        specific_params = (
            feature.get('feature_schema_ver', 'v1'),
            feature.get('embedding_dim'),
            feature.get('embedding_uri'),
            feature.get('vector_table_name'),
            feature.get('checksum'),
            json.dumps(feature.get('metadata_json', {})),
        )

    created = cur.execute(insert_sql, shared_params + specific_params).fetchone()
    return created['feature_id']


def ensure_membership(cur, membership: dict, member_id: int, song_id: int) -> int:
    """Return the ``membership_id`` linking *member_id* to the named set.

    A ``set_membership`` row is matched on set type, set name, member type and
    member id. If absent, one is inserted with *song_id* and the membership's
    ``priority`` (``100`` when the key is missing).
    """
    identity = (
        membership['set_type'],
        membership['set_name'],
        membership['member_type'],
        member_id,
    )
    existing = cur.execute(
        "select membership_id from set_membership where set_type=%s and set_name=%s and member_type=%s and member_id=%s",
        identity,
    ).fetchone()
    if existing:
        return existing['membership_id']

    # Insert reuses the identity columns and appends song_id and priority.
    insert_params = identity + (song_id, membership.get('priority', 100))
    created = cur.execute(
        "insert into set_membership (set_type,set_name,member_type,member_id,song_id,priority) values (%s,%s,%s,%s,%s,%s) returning membership_id",
        insert_params,
    ).fetchone()
    return created['membership_id']


def _import_record(cur, row: dict) -> dict:
    """Import one manifest record and return the ids it resolved to."""
    song_id = ensure_song(cur, row['song'])
    asset_id = ensure_asset(cur, song_id, row['asset'])

    window_ids = []
    feature_ids = []
    for w in row.get('windows', []):
        window_id = ensure_window(cur, song_id, asset_id, w)
        window_ids.append(window_id)
        # Features are attached to the window they are listed under.
        feature_ids.extend(
            ensure_feature(cur, feature, window_id, song_id)
            for feature in w.get('features', [])
        )

    membership_ids = []
    for m in row.get('memberships', []):
        # Any member_type other than 'asset' is registered against the song id.
        member_id = asset_id if m['member_type'] == 'asset' else song_id
        membership_ids.append(ensure_membership(cur, m, member_id, song_id))

    return {
        'song_id': song_id,
        'asset_id': asset_id,
        'window_ids': window_ids,
        'feature_ids': feature_ids,
        'membership_ids': membership_ids,
    }


def _collect_summary(cur) -> dict:
    """Return table row counts plus the newest window and feature lineage rows."""
    counts = {}
    # Table names come from this fixed list, never from user input.
    for table in ['media_entity', 'audio_object', 'feature_fact', 'set_membership']:
        counts[table] = cur.execute(f'select count(*) as c from {table}').fetchone()['c']

    window_sample = cur.execute(
        """
                select win.object_id as window_id,
                       ast.object_id as asset_id,
                       song.entity_id as song_id,
                       song.title,
                       win.start_ms,
                       win.end_ms
                from audio_object win
                join audio_object ast on ast.object_id = win.parent_object_id and ast.object_type='asset'
                join media_entity song on song.entity_id = win.song_id and song.entity_type='song'
                where win.object_type='window'
                order by win.object_id desc
                limit 1
                """
    ).fetchone()
    feature_sample = cur.execute(
        """
                select ff.feature_type,
                       ff.model_name,
                       ff.model_version,
                       ff.feature_set_name,
                       win.object_id as window_id,
                       song.entity_id as song_id,
                       song.title
                from feature_fact ff
                join audio_object win on win.object_id = ff.object_id and win.object_type='window'
                join media_entity song on song.entity_id = ff.song_id and song.entity_type='song'
                order by ff.feature_id desc
                limit 1
                """
    ).fetchone()
    return {
        'counts': counts,
        'window_lineage_sample': window_sample,
        'feature_lineage_sample': feature_sample,
    }


def main() -> int:
    """Import a JSONL manifest into the target schema and write a JSON report.

    Command-line arguments:
        --dsn       connection string passed to ``psycopg.connect`` (required)
        --schema    schema placed first on ``search_path``
                    (default ``acr_songcentric_test``)
        --manifest  path of the JSONL manifest to import (required)
        --output    path of the JSON report to write (required)

    Each manifest record is imported via the ``ensure_*`` helpers. The report
    lists the ids per record, row counts of the four tables and one sample
    lineage row each for windows and features. It is written to ``--output``
    as UTF-8 and echoed to stdout.

    Returns:
        ``0`` on success; errors propagate as exceptions.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dsn', required=True)
    parser.add_argument('--schema', default='acr_songcentric_test')
    parser.add_argument('--manifest', required=True)
    parser.add_argument('--output', required=True)
    args = parser.parse_args()

    manifest_path = Path(args.manifest)
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    qschema = quote_ident(args.schema)

    report = {'schema': args.schema, 'manifest': str(manifest_path), 'imported': []}

    with psycopg.connect(args.dsn, row_factory=dict_row) as conn:
        with conn.cursor() as cur:
            # SET cannot take a bound parameter, so the schema name is
            # interpolated after identifier quoting.
            cur.execute(f'set search_path to {qschema}, public')
            for row in load_jsonl(manifest_path):
                report['imported'].append(_import_record(cur, row))
            report.update(_collect_summary(cur))
        conn.commit()

    # Render once and write explicitly as UTF-8: the report is produced with
    # ensure_ascii=False, so the locale default encoding could fail or yield
    # a non-UTF-8 file for non-ASCII titles.
    rendered = json.dumps(report, ensure_ascii=False, indent=2)
    output_path.write_text(rendered, encoding='utf-8')
    print(rendered)
    return 0


if __name__ == '__main__':
    raise SystemExit(main())