# local_music20_acr.py 6.61 KB
#!/usr/bin/env /usr/local/miniconda3/bin/python
"""Run a FAISS-first local ACR eval on up to 20 songs from /workspace/downloads.

Purpose:
- keep small-sample validation inside acr-engine
- default to FAISS for local dev
- optionally allow ChromaDB when installed
- preserve pgvector as the production path (not used here)
"""
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, List

import faiss
import librosa
import numpy as np


# Root directory scanned for per-song sub-directories; default of --downloads-dir.
DEFAULT_DOWNLOADS = Path('/workspace/downloads')
# Where the JSON summary is written; default of --output.
DEFAULT_OUTPUT = Path('/root/vprecog/acr-engine/data/local_eval/music20_summary.json')
# Query variants evaluated one after another by main(); variant N is read from
# the `type_N` sub-directory of each song directory.
SUPPORTED_QUERY_TYPES = (1, 7, 8, 16)
# Variant whose `type_N` sub-directory supplies each song's reference file.
REFERENCE_TYPE = 11


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the local evaluation script.

    Returns:
        Namespace exposing ``downloads_dir``, ``song_limit``, ``duration``,
        ``sr``, ``topk``, ``backend`` and ``output``.
    """
    parser = argparse.ArgumentParser()
    # (flag, add_argument keyword arguments), registered in this order.
    options = (
        ('--downloads-dir', {'default': str(DEFAULT_DOWNLOADS)}),
        ('--song-limit', {'type': int, 'default': 20}),
        ('--duration', {'type': float, 'default': 8.0}),
        ('--sr', {'type': int, 'default': 22050}),
        ('--topk', {'type': int, 'default': 3}),
        ('--backend', {'choices': ['faiss', 'chromadb'], 'default': 'faiss'}),
        ('--output', {'default': str(DEFAULT_OUTPUT)}),
    )
    for flag, settings in options:
        parser.add_argument(flag, **settings)
    return parser.parse_args()


def first_file(path: Path) -> Path | None:
    """Return the first regular file directly inside *path*, in sorted order.

    Args:
        path: Directory to inspect.

    Returns:
        The smallest ``Path`` among the regular files in *path*, or ``None``
        when *path* is missing, is not a directory, or holds no regular file.
    """
    # is_dir() (rather than exists()) also covers the case where *path* is a
    # regular file, on which iterdir() would raise NotADirectoryError.
    if not path.is_dir():
        return None
    # min() picks the same entry as sorted(...)[0] without building the list.
    return min((p for p in path.iterdir() if p.is_file()), default=None)


def collect_pairs(downloads_dir: Path, song_limit: int, query_type: int) -> List[Dict[str, str]]:
    """Collect up to *song_limit* (reference, query) file pairs.

    Each sub-directory of *downloads_dir* is treated as one song. A song
    contributes a pair only when both its ``type_<REFERENCE_TYPE>`` and its
    ``type_<query_type>`` sub-directories yield a file via ``first_file``.

    Args:
        downloads_dir: Directory containing one sub-directory per song.
        song_limit: Maximum number of pairs to return; values <= 0 yield
            an empty list.
        query_type: Numeric suffix of the query sub-directory.

    Returns:
        A list of dicts with keys ``song_id``, ``reference_path`` and
        ``query_path``, ordered by song directory name.
    """
    pairs: List[Dict[str, str]] = []
    # The limit is otherwise only checked after appending, which would let
    # a non-positive limit still return one pair.
    if song_limit <= 0:
        return pairs

    for song_dir in sorted(p for p in downloads_dir.iterdir() if p.is_dir()):
        ref = first_file(song_dir / f'type_{REFERENCE_TYPE}')
        qry = first_file(song_dir / f'type_{query_type}')
        if not (ref and qry):
            continue
        pairs.append({
            'song_id': song_dir.name,
            'reference_path': str(ref),
            'query_path': str(qry),
        })
        if len(pairs) >= song_limit:
            break
    return pairs


def load_audio(path: str, sr: int, duration: float) -> np.ndarray:
    """Load *path* as mono audio of exactly ``int(sr * duration)`` samples.

    Args:
        path: Audio file to decode with ``librosa.load``.
        sr: Sample rate passed to ``librosa.load``.
        duration: Number of seconds to read.

    Returns:
        A float32 array, zero-padded at the end when the decoded signal is
        shorter than the target length and truncated otherwise.
    """
    samples, _ = librosa.load(path, sr=sr, mono=True, duration=duration)
    wanted = int(sr * duration)
    missing = wanted - len(samples)
    if missing > 0:
        # Short clip: append zeros up to the target length.
        samples = np.pad(samples, (0, missing))
    else:
        samples = samples[:wanted]
    return samples.astype(np.float32)


def embed_chroma(path: str, sr: int, duration: float) -> np.ndarray:
    """Embed an audio file as L2-normalised chroma statistics.

    The embedding concatenates the per-bin mean and standard deviation
    (taken along axis 1) of a 12-bin ``chroma_stft`` of the loaded clip.

    Args:
        path: Audio file to embed.
        sr: Sample rate used for loading and for the chroma transform.
        duration: Number of seconds of audio to use.

    Returns:
        A float32 vector, scaled to unit L2 norm unless its norm is zero.
    """
    signal = load_audio(path, sr=sr, duration=duration)
    chroma = librosa.feature.chroma_stft(y=signal, sr=sr, n_chroma=12)
    stats = (chroma.mean(axis=1), chroma.std(axis=1))
    vec = np.concatenate(stats, axis=0).astype(np.float32)
    length = np.linalg.norm(vec)
    # A zero vector cannot be normalised; return it unchanged.
    return vec / length if length > 0 else vec


def run_faiss(ref_matrix: np.ndarray, qry_matrix: np.ndarray, topk: int):
    """Search *qry_matrix* against *ref_matrix* with a flat inner-product index.

    Args:
        ref_matrix: Reference embeddings, one row per reference.
        qry_matrix: Query embeddings, one row per query.
        topk: Number of neighbours requested per query.

    Returns:
        The ``(scores, indices)`` pair produced by ``index.search``.
    """
    dim = ref_matrix.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(ref_matrix)
    scores, neighbours = index.search(qry_matrix, topk)
    return scores, neighbours


def run_chromadb(ref_matrix: np.ndarray, qry_matrix: np.ndarray, topk: int):
    """Search *qry_matrix* against *ref_matrix* using an in-memory ChromaDB.

    Args:
        ref_matrix: Reference embeddings, one row per reference.
        qry_matrix: Query embeddings, one row per query.
        topk: Number of neighbours requested per query; capped at the
            number of references.

    Returns:
        ``(sims, idxs)`` where ``sims`` is a float32 array of
        ``1 - cosine distance`` and ``idxs`` is an int32 array of row
        indices into *ref_matrix*.

    Raises:
        SystemExit: If ``chromadb`` cannot be imported.
    """
    try:
        import chromadb  # type: ignore
    except Exception as exc:  # pragma: no cover - env-dependent
        raise SystemExit(f'ChromaDB backend requested but unavailable: {exc}') from exc

    import uuid

    client = chromadb.EphemeralClient()
    # Ephemeral clients created in the same process can share one in-memory
    # store, so a fixed collection name collides on the second call (this
    # function runs once per query type). Use a unique name and clean up.
    name = f'music20_local_eval_{uuid.uuid4().hex}'
    # Cosine space makes `1 - distance` a cosine similarity, comparable to the
    # FAISS inner product on unit vectors; Chroma's default is squared L2.
    collection = client.create_collection(name, metadata={'hnsw:space': 'cosine'})
    try:
        ref_ids = [str(i) for i in range(len(ref_matrix))]
        collection.add(ids=ref_ids, embeddings=ref_matrix.tolist())
        result = collection.query(
            query_embeddings=qry_matrix.tolist(),
            # Never ask for more neighbours than there are references.
            n_results=min(topk, len(ref_matrix)),
        )
    finally:
        client.delete_collection(name)

    distances = np.asarray(result['distances'], dtype=np.float32)
    # Ids were stored as stringified row numbers; convert them back.
    idxs = np.asarray([[int(x) for x in row] for row in result['ids']], dtype=np.int32)
    sims = 1.0 - distances
    return sims, idxs


def evaluate_query_type(downloads_dir: Path, song_limit: int, query_type: int, sr: int, duration: float, topk: int, backend: str):
    """Evaluate retrieval of one query variant against the reference variant.

    Collects (reference, query) pairs, embeds both sides with
    ``embed_chroma``, searches the references with the selected backend and
    records at which rank each query retrieves its own song.

    Args:
        downloads_dir: Directory containing one sub-directory per song.
        song_limit: Maximum number of songs to evaluate.
        query_type: Numeric suffix of the query sub-directory.
        sr: Sample rate used when loading audio.
        duration: Seconds of audio embedded per file.
        topk: Number of candidates requested per query.
        backend: ``'faiss'`` selects ``run_faiss``; any other value selects
            ``run_chromadb``.

    Returns:
        A dict with ``query_type``, ``reference_type``, ``song_count``,
        ``file_count``, ``topk``, ``metrics`` (``top1``/``top3`` hit rates)
        and per-query ``results``. When no pairs are found the metrics are
        0.0 and a ``note`` key explains why. A query whose own song is not
        among the candidates gets rank ``topk + 1``.
    """
    pairs = collect_pairs(downloads_dir, song_limit, query_type=query_type)
    if not pairs:
        return {
            'query_type': query_type,
            'reference_type': REFERENCE_TYPE,
            'song_count': 0,
            'file_count': 0,
            'topk': topk,
            'metrics': {'top1': 0.0, 'top3': 0.0},
            'results': [],
            'note': 'No matching query/reference pairs found.',
        }

    ref_vecs = [embed_chroma(item['reference_path'], sr, duration) for item in pairs]
    qry_vecs = [embed_chroma(item['query_path'], sr, duration) for item in pairs]
    ref_ids = [item['song_id'] for item in pairs]

    ref_matrix = np.vstack(ref_vecs).astype(np.float32)
    qry_matrix = np.vstack(qry_vecs).astype(np.float32)

    # Asking for more neighbours than there are references makes FAISS pad
    # its output with -1 labels (which would index ref_ids[-1]) and makes
    # ChromaDB return short rows, so cap the search depth.
    search_k = min(topk, len(pairs))
    search = run_faiss if backend == 'faiss' else run_chromadb
    sims, idxs = search(ref_matrix, qry_matrix, search_k)

    miss_rank = topk + 1
    ranks = []
    results = []
    for item, row_idxs, row_sims in zip(pairs, idxs, sims):
        candidates = []
        rank = None
        for raw_idx, raw_score in zip(row_idxs, row_sims):
            ref_idx = int(raw_idx)
            if ref_idx < 0:
                # Defensive: a negative label means "no result", not a row.
                continue
            position = len(candidates) + 1
            cand_song_id = ref_ids[ref_idx]
            candidates.append({'rank': position, 'song_id': cand_song_id, 'score': float(raw_score)})
            if rank is None and cand_song_id == item['song_id']:
                rank = position
        if rank is None:
            rank = miss_rank
        ranks.append(rank)
        results.append({
            'song_id': item['song_id'],
            'query_path': item['query_path'],
            'reference_path': item['reference_path'],
            'rank': rank,
            'candidates': candidates,
        })

    top1 = sum(1 for r in ranks if r == 1) / len(ranks)
    # With topk < 3 the "top3" metric degrades to a top-k hit rate.
    top3 = sum(1 for r in ranks if r <= min(3, topk)) / len(ranks)
    return {
        'query_type': query_type,
        'reference_type': REFERENCE_TYPE,
        'song_count': len(pairs),
        'file_count': len(pairs) * 2,
        'topk': topk,
        'metrics': {'top1': top1, 'top3': top3},
        'results': results,
    }


def main() -> None:
    """Run every supported query-type evaluation and emit a JSON summary.

    The summary is written to the ``--output`` path (parent directories are
    created first) and the same JSON text is printed to stdout.
    """
    args = parse_args()
    downloads_dir = Path(args.downloads_dir)
    out = Path(args.output)
    # Create the output directory before the evaluations run.
    out.parent.mkdir(parents=True, exist_ok=True)

    summary = {
        'backend': args.backend,
        'purpose': 'Local 20-song ACR sanity flow for development; production remains pgvector.',
        'downloads_dir': str(downloads_dir),
        'song_limit': args.song_limit,
        'duration_sec': args.duration,
        'sr': args.sr,
        'evaluations': [
            evaluate_query_type(
                downloads_dir,
                args.song_limit,
                query_type,
                args.sr,
                args.duration,
                args.topk,
                args.backend,
            )
            for query_type in SUPPORTED_QUERY_TYPES
        ],
    }

    # Serialise once; the file and stdout receive the same text.
    rendered = json.dumps(summary, ensure_ascii=False, indent=2)
    out.write_text(rendered, encoding='utf-8')
    print(rendered)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()