# export_workspace_music20_embeddings_jsonl.py (2.6 KB)
#!/usr/bin/env /usr/local/miniconda3/bin/python
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

# ROOT is the grandparent of this file: the directory that contains the
# directory this script lives in.
ROOT = Path(__file__).resolve().parents[1]
# Put ROOT at the front of sys.path (unless it is already listed) before the
# project import that follows; presumably this lets the `scripts.` package
# resolve when the file is executed directly -- confirm against repo layout.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from scripts.local_music20_acr import REFERENCE_TYPE, SUPPORTED_QUERY_TYPES, embed_chroma, first_file


def parse_args(argv=None):
    """Parse the command-line options for the embedding export.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behavior existing callers rely on. Passing a list makes the
            function usable from other code and from tests.

    Returns:
        argparse.Namespace with the attributes ``downloads_dir`` (str),
        ``song_limit`` (int), ``duration`` (float), ``sr`` (int) and
        ``out_dir`` (str).
    """
    ap = argparse.ArgumentParser(
        # Show each option's default value in --help output.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    ap.add_argument(
        '--downloads-dir',
        default='/workspace/downloads',
        help='directory scanned for per-song sub-directories',
    )
    ap.add_argument(
        '--song-limit',
        type=int,
        default=20,
        help='maximum number of songs to export reference rows for',
    )
    ap.add_argument(
        '--duration',
        type=float,
        default=8.0,
        help='duration value passed to the embedding function',
    )
    ap.add_argument(
        '--sr',
        type=int,
        default=22050,
        help='sr value passed to the embedding function',
    )
    ap.add_argument(
        '--out-dir',
        default='data/pgvector_eval/music20',
        help='directory the JSONL files are written to (created if missing)',
    )
    return ap.parse_args(argv)


def main():
    """Export embeddings for reference and query clips to two JSONL files.

    The sub-directories of ``--downloads-dir`` are visited in sorted order.
    For each one, until ``--song-limit`` songs have been exported, the file
    returned by ``first_file`` for its ``type_<REFERENCE_TYPE>`` folder is
    embedded with ``embed_chroma`` and written as one row of
    ``reference_embeddings.jsonl``. Sub-directories for which ``first_file``
    returns nothing are skipped and do not count towards the limit.

    Then, for every entry of ``SUPPORTED_QUERY_TYPES``, the
    ``type_<query_type>`` file of each exported song is embedded and written
    as one row of ``query_embeddings.jsonl``; songs without such a file are
    skipped for that query type only.

    A ``query_type=... rows=...`` line is printed per query type, followed by
    a JSON summary holding both row counts and the resolved output directory.
    """
    args = parse_args()
    downloads_dir = Path(args.downloads_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    ref_path = out_dir / 'reference_embeddings.jsonl'
    qry_path = out_dir / 'query_embeddings.jsonl'
    ref_count = qry_count = 0
    refs_seen = set()

    # Both outputs are opened up front (truncating any earlier export) and
    # are closed by the context manager even if embedding or writing raises.
    with ref_path.open('w', encoding='utf-8') as ref_f, \
            qry_path.open('w', encoding='utf-8') as qry_f:
        for song_dir in sorted(p for p in downloads_dir.iterdir() if p.is_dir()):
            # Once the limit is reached no further row can be written, so
            # stop instead of probing the remaining directories.
            if len(refs_seen) >= args.song_limit:
                break
            ref = first_file(song_dir / f'type_{REFERENCE_TYPE}')
            if not ref:
                continue
            row = {
                'song_id': song_dir.name,
                'audio_path': str(ref),
                'type': 'reference',
                # .tolist() turns the embedding into plain lists for JSON
                # (presumably a numpy array -- confirm in embed_chroma).
                'embedding': embed_chroma(str(ref), args.sr, args.duration).tolist(),
            }
            ref_f.write(json.dumps(row, ensure_ascii=False) + '\n')
            ref_count += 1
            refs_seen.add(song_dir.name)

        # Only songs that produced a reference row are queried.
        for query_type in SUPPORTED_QUERY_TYPES:
            kept = 0
            for song_id in sorted(refs_seen):
                qry = first_file(downloads_dir / song_id / f'type_{query_type}')
                if not qry:
                    continue
                row = {
                    'song_id': song_id,
                    'audio_path': str(qry),
                    'query_type': query_type,
                    'embedding': embed_chroma(str(qry), args.sr, args.duration).tolist(),
                }
                qry_f.write(json.dumps(row, ensure_ascii=False) + '\n')
                qry_count += 1
                kept += 1
            print(f'query_type={query_type} rows={kept}')

    print(json.dumps({'reference_rows': ref_count, 'query_rows': qry_count, 'out_dir': str(out_dir.resolve())}, ensure_ascii=False, indent=2))


if __name__ == '__main__':
    main()