# evaluate_selected20_songid_retrieval.py
#!/usr/bin/env /usr/local/miniconda3/bin/python
from __future__ import annotations

import argparse
import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import librosa
import numpy as np

# Directory one level above the folder that contains this script.
ROOT = Path(__file__).resolve().parents[1]
import sys
# Put ROOT on the import path (once) so the ``src.`` and ``scripts.`` imports
# below resolve when this file is executed directly.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from src.engines.chromaprint_matcher import ChromaprintMatcher, load_audio_mono
from scripts.enrich_songcentric_manifest_with_local_features import load_mert_runtime


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the evaluation script.

    ``--downloads-dir``, ``--output-json`` and ``--output-md`` are required;
    every other option falls back to the default listed in the table below.

    Returns:
        The populated ``argparse.Namespace`` read from ``sys.argv``.
    """
    # (flag, add_argument keyword arguments), in the order they are registered.
    option_specs = (
        ('--downloads-dir', {'required': True}),
        ('--reference-type', {'type': int, 'default': 11}),
        ('--query-types', {'nargs': '+', 'type': int, 'default': [1, 7, 12, 16]}),
        ('--duration', {'type': float, 'default': 8.0}),
        ('--topk', {'type': int, 'default': 3}),
        ('--exact-weight', {'type': float, 'default': 0.6}),
        ('--semantic-weight', {'type': float, 'default': 0.4}),
        ('--output-json', {'required': True}),
        ('--output-md', {'required': True}),
    )
    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def audio_files(path: Path) -> List[Path]:
    """Return the regular files directly inside ``path``, sorted by path.

    No extension filtering is applied: every regular file in the directory is
    returned. Sub-directories are skipped and the listing is not recursive.

    Args:
        path: Directory expected to hold the audio files.

    Returns:
        Sorted list of file paths, or ``[]`` when ``path`` is missing or is
        not a directory.
    """
    # ``is_dir`` (instead of ``exists``) also rejects a regular file sitting
    # at ``path``, for which ``iterdir`` would raise NotADirectoryError.
    if not path.is_dir():
        return []
    return sorted(entry for entry in path.iterdir() if entry.is_file())


def collect_dataset(downloads_dir: Path, reference_type: int, query_types: List[int]) -> tuple[list[dict], list[dict]]:
    """Scan ``downloads_dir`` for per-song reference and query files.

    Each immediate sub-directory is treated as one song whose id is the
    directory name. Files are looked up in ``type_<N>`` sub-folders of every
    song directory via ``audio_files``.

    Args:
        downloads_dir: Root directory containing one folder per song.
        reference_type: Type number whose folder holds the reference files.
        query_types: Type numbers whose folders hold the query files.

    Returns:
        ``(references, queries)``; each entry is a dict with the keys
        ``song_id``, ``type`` and ``path`` (path as a string).
    """
    song_dirs = [entry for entry in downloads_dir.iterdir() if entry.is_dir()]
    song_dirs.sort()

    references: list[dict] = []
    queries: list[dict] = []
    for song_dir in song_dirs:
        song_id = song_dir.name
        references.extend(
            {'song_id': song_id, 'type': reference_type, 'path': str(ref_path)}
            for ref_path in audio_files(song_dir / f'type_{reference_type}')
        )
        for query_type in query_types:
            queries.extend(
                {'song_id': song_id, 'type': query_type, 'path': str(query_path)}
                for query_path in audio_files(song_dir / f'type_{query_type}')
            )
    return references, queries


def load_semantic_embedding(path: str, duration: float) -> np.ndarray:
    """Compute an L2-normalised embedding for the start of an audio file.

    The audio is loaded as mono at the sample rate reported by
    ``load_mert_runtime()``, limited to ``duration`` seconds, passed through
    the runtime's feature extractor and model, and the model's
    ``last_hidden_state`` is averaged over dimension 1.

    Args:
        path: Audio file to embed.
        duration: Maximum number of seconds to load.

    Returns:
        A float32 vector; unit length unless its norm is zero.

    Raises:
        ValueError: If the loaded audio contains no samples.
    """
    runtime = load_mert_runtime()
    torch_module = runtime['torch']
    sample_rate = int(runtime['sample_rate'])

    waveform, _ = librosa.load(path, sr=sample_rate, mono=True, duration=duration)
    if waveform.size == 0:
        raise ValueError(f'empty audio: {path}')

    model_inputs = runtime['feature_extractor'](
        waveform.astype(np.float32),
        sampling_rate=sample_rate,
        return_tensors='pt',
    )
    # Inference only: no gradients are needed for the forward pass.
    with torch_module.no_grad():
        model_outputs = runtime['model'](**model_inputs)

    pooled = model_outputs.last_hidden_state.mean(dim=1).squeeze(0)
    embedding = pooled.cpu().numpy().astype(np.float32)
    length = np.linalg.norm(embedding)
    # An all-zero vector is returned unchanged to avoid dividing by zero.
    return embedding / length if length > 0 else embedding


def normalize_score_pairs(score_pairs: List[Tuple[str, float]]) -> Dict[str, float]:
    """Min-max scale ``(id, score)`` pairs into an ``{id: value}`` mapping.

    Scores are handled as float32. The lowest score maps to 0.0 and the
    highest to 1.0. A single pair, or pairs whose scores are all equal, map
    every id to 1.0. If an id occurs more than once, its last value wins.

    Args:
        score_pairs: Sequence of ``(id, score)`` pairs.

    Returns:
        Mapping from id to scaled score; empty when the input is empty.
    """
    if not score_pairs:
        return {}

    scores = np.asarray([float(pair[1]) for pair in score_pairs], dtype=np.float32)
    song_ids = [pair[0] for pair in score_pairs]
    if scores.size == 1:
        return {song_ids[0]: 1.0}

    lowest = float(scores.min())
    highest = float(scores.max())
    spread = highest - lowest
    # A flat score list carries no ranking signal; avoid dividing by ~0.
    if abs(spread) < 1e-12:
        return dict.fromkeys(song_ids, 1.0)

    scaled = (scores - lowest) / spread
    return dict(zip(song_ids, scaled.tolist()))


def topk_from_scores(score_map: Dict[str, float], topk: int) -> List[Dict]:
    """Return the ``topk`` highest-scoring entries as ranked records.

    Entries are ordered by descending score; equal scores keep the order in
    which they appear in ``score_map``.

    Args:
        score_map: Mapping from song id to score.
        topk: Number of leading entries to keep.

    Returns:
        List of ``{'rank', 'song_id', 'score'}`` dicts with 1-based ranks.
    """
    def by_score(entry):
        return entry[1]

    best_first = sorted(score_map.items(), key=by_score, reverse=True)
    ranked = []
    for position, (song_id, score) in enumerate(best_first[:topk], start=1):
        ranked.append({'rank': position, 'song_id': song_id, 'score': float(score)})
    return ranked


def rank_of(song_id: str, ranked: List[Dict], default_rank: int) -> int:
    """Look up the rank of ``song_id`` in a list of ranked records.

    Args:
        song_id: Song id to search for.
        ranked: Records carrying ``song_id`` and ``rank`` keys.
        default_rank: Value returned when the song is not in ``ranked``.

    Returns:
        The rank of the first matching record, else ``default_rank``.
    """
    matching_ranks = (int(entry['rank']) for entry in ranked if entry['song_id'] == song_id)
    return next(matching_ranks, default_rank)


def build_reference_assets(references: list[dict], duration: float):
    """Index all reference files and build one embedding per song.

    Every reference is added to a ``ChromaprintMatcher`` (created with
    ``sr=16000``) and embedded with ``load_semantic_embedding``. Embeddings
    of the same song are averaged and the mean is re-normalised to unit
    length when its norm is non-zero.

    Args:
        references: Records with ``song_id`` and ``path`` keys.
        duration: Seconds of audio used for each semantic embedding.

    Returns:
        ``(matcher, ref_song_embeddings)`` where the second item maps each
        song id to its float32 mean embedding.
    """
    matcher = ChromaprintMatcher(sr=16000)

    per_song: dict[str, list[np.ndarray]] = defaultdict(list)
    for record in references:
        audio_path = record['path']
        samples, _ = load_audio_mono(audio_path, sr=matcher.sr)
        song_id = record['song_id']
        matcher.index_song(song_id, samples)
        per_song[song_id].append(load_semantic_embedding(audio_path, duration=duration))

    centroids: dict[str, np.ndarray] = {}
    for song_id, embeddings in per_song.items():
        centroid = np.vstack(embeddings).mean(axis=0)
        length = np.linalg.norm(centroid)
        # Averaging unit vectors shortens them; restore unit length.
        if length > 0:
            centroid = centroid / length
        centroids[song_id] = centroid.astype(np.float32)
    return matcher, centroids


def evaluate(references: list[dict], queries: list[dict], duration: float, topk: int, exact_weight: float, semantic_weight: float) -> dict:
    """Score every query against the reference index and aggregate hit rates.

    Three lanes are evaluated per query:

    * ``exact``: scores returned by the Chromaprint matcher,
    * ``semantic``: dot product between the query embedding and each
      per-song reference embedding,
    * ``fused``: weighted sum of the two min-max normalised score maps
      (a song missing from a lane contributes 0.0 for that lane).

    Args:
        references: Reference records with ``song_id`` and ``path`` keys.
        queries: Query records with ``song_id``, ``type`` and ``path`` keys.
        duration: Seconds of audio passed to the semantic embedding.
        topk: Number of candidates stored per lane in each result.
        exact_weight: Weight of the normalised exact score in the fusion.
        semantic_weight: Weight of the normalised semantic score.

    Returns:
        A report dict with file/song counts, the weights, ``overall`` and
        ``per_type`` top1/top3 rates, up to 20 ``failed_fused_examples`` and
        the per-query ``results``. A query whose true song is not found
        within ``max(topk, 3)`` candidates gets rank ``max(topk, 3) + 1``.

    Raises:
        ValueError: If ``references`` is empty.
    """
    # Fail early with a readable message instead of letting np.vstack raise
    # "need at least one array to concatenate" further down.
    if not references:
        raise ValueError('no reference audio files found; cannot build the reference index')

    matcher, ref_song_embeddings = build_reference_assets(references, duration=duration)
    ref_song_ids = sorted(ref_song_embeddings.keys())
    ref_matrix = np.vstack([ref_song_embeddings[sid] for sid in ref_song_ids]).astype(np.float32)

    # The top3 metric below is fixed at depth 3, so ranks are resolved at least
    # that deep. With a miss rank of topk + 1 and topk < 3, every miss would
    # satisfy ``rank <= 3`` and be counted as a top-3 hit.
    rank_depth = max(topk, 3)
    default_rank = rank_depth + 1

    results = []
    by_type: dict[int, list[dict]] = defaultdict(list)
    lane_rank_lists = {
        'exact': [],
        'semantic': [],
        'fused': [],
    }
    for q in queries:
        qy, _ = load_audio_mono(q['path'], sr=matcher.sr)
        exact_pairs = matcher.match(qy, top_k=max(topk * 5, 20))
        exact_norm = normalize_score_pairs(exact_pairs)

        qemb = load_semantic_embedding(q['path'], duration=duration)
        sims = ref_matrix @ qemb
        semantic_pairs = [(sid, float(score)) for sid, score in zip(ref_song_ids, sims.tolist())]
        semantic_norm = normalize_score_pairs(semantic_pairs)

        # Sort the candidate ids: set iteration order depends on string hash
        # randomisation, which would make tie-breaking differ between runs.
        candidate_ids = sorted(set(exact_norm) | set(semantic_norm) | set(ref_song_ids), key=str)
        fused_scores = {
            sid: exact_weight * exact_norm.get(sid, 0.0) + semantic_weight * semantic_norm.get(sid, 0.0)
            for sid in candidate_ids
        }

        exact_ranked = topk_from_scores(exact_norm, rank_depth)
        semantic_ranked = topk_from_scores(semantic_norm, rank_depth)
        fused_ranked = topk_from_scores(fused_scores, rank_depth)

        exact_rank = rank_of(q['song_id'], exact_ranked, default_rank)
        semantic_rank = rank_of(q['song_id'], semantic_ranked, default_rank)
        fused_rank = rank_of(q['song_id'], fused_ranked, default_rank)
        lane_rank_lists['exact'].append(exact_rank)
        lane_rank_lists['semantic'].append(semantic_rank)
        lane_rank_lists['fused'].append(fused_rank)

        item = {
            'song_id': q['song_id'],
            'query_type': q['type'],
            'query_path': q['path'],
            'exact_rank': exact_rank,
            'semantic_rank': semantic_rank,
            'fused_rank': fused_rank,
            # Only the requested number of candidates is stored in the report.
            'exact_topk': exact_ranked[:topk],
            'semantic_topk': semantic_ranked[:topk],
            'fused_topk': fused_ranked[:topk],
        }
        results.append(item)
        by_type[q['type']].append(item)

    def metric_block(rank_list: list[int]) -> dict:
        """Return count plus top1/top3 hit rates for a list of ranks."""
        n = len(rank_list)
        if n == 0:
            return {'count': 0, 'top1': 0.0, 'top3': 0.0}
        return {
            'count': n,
            'top1': sum(1 for r in rank_list if r == 1) / n,
            'top3': sum(1 for r in rank_list if r <= 3) / n,
        }

    overall = {lane: metric_block(ranks) for lane, ranks in lane_rank_lists.items()}
    per_type = {}
    for qtype, items in sorted(by_type.items()):
        per_type[qtype] = {
            'exact': metric_block([x['exact_rank'] for x in items]),
            'semantic': metric_block([x['semantic_rank'] for x in items]),
            'fused': metric_block([x['fused_rank'] for x in items]),
        }

    failed_fused = [x for x in results if x['fused_rank'] != 1]
    failed_fused.sort(key=lambda x: (x['query_type'], x['fused_rank'], x['song_id'], x['query_path']))

    return {
        'reference_count': len(references),
        'query_count': len(queries),
        'reference_song_count': len(ref_song_ids),
        'query_type_counts': {str(k): len(v) for k, v in sorted(by_type.items())},
        'weights': {'exact': exact_weight, 'semantic': semantic_weight},
        'overall': overall,
        'per_type': per_type,
        'failed_fused_examples': failed_fused[:20],
        'results': results,
    }


def render_md(report: dict, downloads_dir: Path, reference_type: int, query_types: List[int], duration: float) -> str:
    """Render the evaluation report as a Markdown document.

    The document has a header with the run settings followed by four
    sections: overall metrics, metrics per query type, up to ten fused-lane
    failures, and fixed closing notes.

    Args:
        report: Report dict; reads ``reference_count``, ``query_count``,
            ``overall``, ``per_type`` and ``failed_fused_examples``.
        downloads_dir: Dataset root shown in the header.
        reference_type: Type number of the reference folder.
        query_types: Type numbers of the query folders.
        duration: Embedding duration in seconds shown in the header.

    Returns:
        The Markdown text, terminated by a newline.
    """
    lanes = ('exact', 'semantic', 'fused')
    query_label = ', '.join(f'type_{x}' for x in query_types)

    def metric_cells(block: dict) -> str:
        """Format the count/top1/top3 cells shared by both tables."""
        return f"{block['count']} | {block['top1']:.4f} | {block['top3']:.4f}"

    # Header with the run settings.
    out = [
        '# Selected 20 Songs 实战检索评测',
        '',
        f'- 数据目录:`{downloads_dir}`',
        f'- reference:`type_{reference_type}`',
        f'- queries:`{query_label}`',
        '- 当前方案:`chromaprint_matcher + mert-v1-95m`',
        f'- 语义截断时长:`{duration:.1f}s`',
        f"- reference 文件数:`{report['reference_count']}`",
        f"- query 文件数:`{report['query_count']}`",
        '',
    ]

    # Section 1: overall metrics per lane.
    out += [
        '## 1. 总体结果',
        '',
        '| lane | count | top1 | top3 |',
        '|---|---:|---:|---:|',
    ]
    out.extend(f"| {lane} | {metric_cells(report['overall'][lane])} |" for lane in lanes)
    out.append('')

    # Section 2: metrics per query type and lane.
    out += [
        '## 2. 分 query type 结果',
        '',
        '| query_type | lane | count | top1 | top3 |',
        '|---|---|---:|---:|---:|',
    ]
    for qtype, block in report['per_type'].items():
        out.extend(f'| type_{qtype} | {lane} | {metric_cells(block[lane])} |' for lane in lanes)
    out.append('')

    # Section 3: at most ten queries whose fused rank is not 1.
    out += ['## 3. 失败样例(fused rank != 1)', '']
    failures = report['failed_fused_examples']
    if failures:
        for item in failures[:10]:
            candidates = item['fused_topk']
            best_guess = candidates[0]['song_id'] if candidates else 'NA'
            out.append(
                f"- `type_{item['query_type']}` / true=`{item['song_id']}` / fused_rank=`{item['fused_rank']}`"
                f" / file=`{item['query_path']}` / top1=`{best_guess}`"
            )
    else:
        out.append('- 无,fused 全部 top1 正确。')
    out.append('')

    # Section 4: fixed closing notes.
    out += [
        '## 4. 结论',
        '',
        '- `exact` 代表当前指纹链路单独表现。',
        '- `semantic` 代表当前 MERT 单独表现。',
        '- `fused` 代表当前版权保护场景里更接近实战的合并结果。',
    ]
    return '\n'.join(out) + '\n'


def main() -> int:
    """Run the evaluation end to end and write the JSON and Markdown reports.

    Parses the command line, collects the dataset, evaluates it, adds the run
    settings to the report, writes both output files (creating their parent
    directories) and prints a JSON summary to stdout.

    Returns:
        Always 0.
    """
    args = parse_args()
    downloads_dir = Path(args.downloads_dir).resolve()
    references, queries = collect_dataset(downloads_dir, args.reference_type, args.query_types)
    report = evaluate(
        references=references,
        queries=queries,
        duration=args.duration,
        topk=args.topk,
        exact_weight=args.exact_weight,
        semantic_weight=args.semantic_weight,
    )

    json_path = Path(args.output_json)
    md_path = Path(args.output_md)
    for target in (json_path, md_path):
        target.parent.mkdir(parents=True, exist_ok=True)

    # Record the run settings next to the metrics.
    report['downloads_dir'] = str(downloads_dir)
    report['reference_type'] = args.reference_type
    report['query_types'] = args.query_types
    report['duration_sec'] = args.duration
    report['topk'] = args.topk
    report['solution'] = 'chromaprint_matcher + mert-v1-95m'

    json_text = json.dumps(report, ensure_ascii=False, indent=2)
    json_path.write_text(json_text, encoding='utf-8')
    md_text = render_md(report, downloads_dir, args.reference_type, args.query_types, args.duration)
    md_path.write_text(md_text, encoding='utf-8')

    # Short console summary: output locations plus the headline metrics.
    summary = {'output_json': str(json_path), 'output_md': str(md_path)}
    for key in ('overall', 'per_type', 'query_count', 'reference_count'):
        summary[key] = report[key]
    print(json.dumps(summary, ensure_ascii=False, indent=2))
    return 0


# Script entry point: exit the process with the status returned by main().
if __name__ == '__main__':
    raise SystemExit(main())