run_demo.py 6.51 KB
#!/usr/bin/env python3
import argparse
import json
import sys
from pathlib import Path

import numpy as np

ROOT = Path(__file__).parent
sys.path.insert(0, str(ROOT))

from src.data.synthetic import generate_dataset
from src.engines.chromaprint_matcher import ChromaprintMatcher
from src.engines.ecapa_embedder import ECAPAEmbedder
from src.engines.hybrid_engine import HybridEngine


def cmd_generate_data(args):
    """Handle the ``generate-data`` subcommand.

    Forwards the parsed CLI options to ``generate_dataset`` and reports the
    output directory when done.
    """
    options = dict(
        output_dir=args.output,
        num_songs=args.num_songs,
        song_duration=args.song_duration,
        num_segments_per_song=args.num_segments,
        segment_duration=args.segment_duration,
        seed=args.seed,
    )
    generate_dataset(**options)
    print(f"[done] dataset generated at {args.output}")


def build_chroma_index(data_dir: Path, output_dir: Path):
    """Index every song under ``<data_dir>/songs`` with a ChromaprintMatcher.

    Song metadata is taken from ``catalog.json`` when that file exists in
    ``data_dir``, otherwise from ``train.json``. ``<output_dir>/chromaprint.pkl``
    is handed to the matcher as its ``cache_path``.

    Returns:
        The populated ``ChromaprintMatcher``.
    """
    fingerprinter = ChromaprintMatcher()

    catalog = data_dir / 'catalog.json'
    metadata = catalog if catalog.exists() else data_dir / 'train.json'

    fingerprinter.index_songs_from_dir(
        songs_dir=str(data_dir / 'songs'),
        metadata_path=str(metadata),
        cache_path=str(output_dir / 'chromaprint.pkl'),
    )
    print(f"[done] chromaprint index built: hashes={fingerprinter.num_hashes}, postings={fingerprinter.index_size}")
    return fingerprinter


def build_embedding_index(data_dir: Path, model_path: Path, output_prefix: Path, device: str):
    """Embed every song under ``<data_dir>/songs`` with the ECAPA model.

    Metadata selection mirrors ``build_chroma_index``: ``catalog.json`` if it
    exists, else ``train.json``. ``output_prefix`` is passed through as the
    embedder's ``output_path`` (``load_index`` later reads
    ``<prefix>_embs.npy`` / ``<prefix>_ids.npy``, so presumably that is what
    gets written -- confirm in ECAPAEmbedder).

    Returns:
        ``(embedder, ref_embs, ref_ids)`` as produced by
        ``ECAPAEmbedder.build_reference_index``.
    """
    model = ECAPAEmbedder(model_path=str(model_path), device=device)

    catalog = data_dir / 'catalog.json'
    metadata = catalog if catalog.exists() else data_dir / 'train.json'

    embeddings, song_ids = model.build_reference_index(
        songs_dir=str(data_dir / 'songs'),
        metadata_path=str(metadata),
        output_path=str(output_prefix),
    )
    print(f"[done] embedding index built: {len(song_ids)} refs")
    return model, embeddings, song_ids


def cmd_build_index(args):
    """Handle the ``build-index`` subcommand.

    Ensures ``args.output`` exists, then builds the chromaprint index there
    and the embedding index under the ``<output>/reference`` prefix.
    """
    source = Path(args.data)
    index_root = Path(args.output)
    index_root.mkdir(parents=True, exist_ok=True)

    build_chroma_index(source, index_root)
    build_embedding_index(
        source,
        Path(args.model),
        index_root / 'reference',
        args.device,
    )


def load_index(prefix: Path):
    """Load a saved embedding index from ``<prefix>_embs.npy`` / ``<prefix>_ids.npy``.

    The ids file is loaded with ``allow_pickle=True`` (it may hold an object
    array), so only load index files you produced yourself.

    Returns:
        ``(ref_embs, ref_ids)`` -- the embeddings array as stored and the ids
        converted to a plain Python list via ``tolist()``.
    """
    base = str(prefix)
    embeddings = np.load(base + "_embs.npy")
    id_array = np.load(base + "_ids.npy", allow_pickle=True)
    return embeddings, id_array.tolist()


def cmd_recognize(args):
    """Handle the ``recognize`` subcommand for a single query clip.

    Loads the chromaprint index from ``chromaprint.pkl`` in the directory of
    ``args.index_prefix`` and the embedding index via ``load_index``, then
    registers whichever of ``train.json`` / ``val.json`` / ``test.json``
    exist under ``args.data`` before running the hybrid engine. The result
    is printed as indented JSON.
    """
    data_root = Path(args.data)
    prefix = Path(args.index_prefix)

    chroma = ChromaprintMatcher()
    chroma.load(str(prefix.parent / 'chromaprint.pkl'))
    ecapa = ECAPAEmbedder(model_path=args.model, device=args.device)
    embeddings, song_ids = load_index(prefix)

    engine = HybridEngine(
        chroma_matcher=chroma,
        ecapa_embedder=ecapa,
        ref_embs=embeddings,
        ref_ids=song_ids,
    )

    candidates = [data_root / name for name in ('train.json', 'val.json', 'test.json')]
    for meta_file in filter(Path.exists, candidates):
        engine.load_metadata(str(meta_file))

    answer = engine.recognize(args.query, top_n=args.top_n)
    print(json.dumps(answer, ensure_ascii=False, indent=2))


def cmd_full_demo(args):
    """Run the pipeline end to end and recognize one test clip.

    Stages:

    1. Generate the synthetic dataset into ``args.data`` unless the
       directory and its ``train.json`` already exist.
    2. Train a model by running ``train.py`` in a subprocess unless
       ``<args.model_dir>/best_model.pt`` already exists.
    3. Build the chromaprint and embedding indexes into ``args.index_dir``
       (the builders are always invoked).
    4. Pick a query from ``test.json`` -- preferring an entry whose
       ``audio_path`` contains ``'segments/'`` -- run the hybrid engine on it
       and print the result as JSON.

    Raises:
        subprocess.CalledProcessError: if the training subprocess fails.
        ValueError: if ``test.json`` contains no entries.
    """
    import subprocess

    data_dir = Path(args.data)
    model_dir = Path(args.model_dir)
    index_dir = Path(args.index_dir)

    if not data_dir.exists() or not (data_dir / 'train.json').exists():
        generate_dataset(
            output_dir=str(data_dir),
            num_songs=args.num_songs,
            song_duration=args.song_duration,
            num_segments_per_song=args.num_segments,
            segment_duration=args.segment_duration,
            seed=args.seed,
        )
        print(f"[done] dataset generated at {data_dir}")

    model_path = model_dir / 'best_model.pt'
    if not model_path.exists():
        model_dir.mkdir(parents=True, exist_ok=True)
        # Prefer the train.py that sits next to this script so the demo works
        # from any working directory; fall back to the cwd-relative path the
        # script used previously.
        train_script = ROOT / 'train.py'
        if not train_script.exists():
            train_script = Path('train.py')
        # Train with the interpreter that is running this demo rather than a
        # hard-coded, machine-specific conda path.
        cmd = [
            sys.executable, str(train_script),
            '--data', str(data_dir), '--output', str(model_dir),
            '--device', args.device, '--epochs', '3', '--batch-size', '8',
        ]
        print('[full-demo] training model:', ' '.join(cmd))
        subprocess.run(cmd, check=True)

    index_dir.mkdir(parents=True, exist_ok=True)
    matcher = build_chroma_index(data_dir, index_dir)
    embedder, ref_embs, ref_ids = build_embedding_index(
        data_dir, model_path, index_dir / 'reference', args.device
    )

    test_json = data_dir / 'test.json'
    with open(test_json, encoding='utf-8') as f:
        test_meta = json.load(f)
    if not test_meta:
        # next()'s default argument (test_meta[0]) is evaluated eagerly, so an
        # empty file would otherwise surface as an unexplained IndexError.
        raise ValueError(f"no query candidates in {test_json}")
    query_item = next(
        (x for x in test_meta if 'segments/' in x['audio_path']),
        test_meta[0],
    )
    query_path = data_dir / query_item['audio_path']

    engine = HybridEngine(matcher, embedder, ref_embs, ref_ids)
    # Same policy as cmd_recognize: register only the splits that exist, so a
    # dataset without e.g. val.json does not abort the demo.
    for split in ('train.json', 'val.json', 'test.json'):
        split_path = data_dir / split
        if split_path.exists():
            engine.load_metadata(str(split_path))
    result = engine.recognize(str(query_path), top_n=5)
    print('[demo-query]', query_item['song_id'], query_item['audio_path'])
    print(json.dumps(result, ensure_ascii=False, indent=2))


if __name__ == '__main__':
    def _add_dataset_options(cmd_parser):
        """Attach the synthetic-dataset options shared by generate-data and full-demo."""
        cmd_parser.add_argument('--num-songs', type=int, default=24)
        cmd_parser.add_argument('--song-duration', type=float, default=20.0)
        cmd_parser.add_argument('--num-segments', type=int, default=4)
        cmd_parser.add_argument('--segment-duration', type=float, default=5.0)
        cmd_parser.add_argument('--seed', type=int, default=42)

    cli = argparse.ArgumentParser(description='ACR demo utilities')
    commands = cli.add_subparsers(dest='cmd', required=True)

    # generate-data: write a synthetic dataset to disk.
    gen_cmd = commands.add_parser('generate-data')
    gen_cmd.add_argument('--output', default='data/synthetic')
    _add_dataset_options(gen_cmd)
    gen_cmd.set_defaults(func=cmd_generate_data)

    # build-index: fingerprint + embedding indexes from an existing dataset.
    index_cmd = commands.add_parser('build-index')
    index_cmd.add_argument('--data', default='data/synthetic')
    index_cmd.add_argument('--model', required=True)
    index_cmd.add_argument('--output', default='data/index')
    index_cmd.add_argument('--device', default='cpu')
    index_cmd.set_defaults(func=cmd_build_index)

    # recognize: identify a single query clip.
    query_cmd = commands.add_parser('recognize')
    query_cmd.add_argument('--query', required=True)
    query_cmd.add_argument('--data', default='data/synthetic')
    query_cmd.add_argument('--model', required=True)
    query_cmd.add_argument('--index-prefix', default='data/index/reference')
    query_cmd.add_argument('--top-n', type=int, default=5)
    query_cmd.add_argument('--device', default='cpu')
    query_cmd.set_defaults(func=cmd_recognize)

    # full-demo: data -> model -> index -> one recognition, end to end.
    demo_cmd = commands.add_parser('full-demo')
    demo_cmd.add_argument('--data', default='data/synthetic')
    demo_cmd.add_argument('--model-dir', default='data/models')
    demo_cmd.add_argument('--index-dir', default='data/index')
    _add_dataset_options(demo_cmd)
    demo_cmd.add_argument('--device', default='cpu')
    demo_cmd.set_defaults(func=cmd_full_demo)

    parsed = cli.parse_args()
    parsed.func(parsed)