evaluate_songid_pgvector_path.py 3.46 KB
#!/usr/bin/env /usr/local/miniconda3/bin/python
from __future__ import annotations

import argparse
import json
import sys
from collections import defaultdict
from pathlib import Path
from statistics import median

# Resolve the directory one level above this script's folder (presumably the
# repository root) and prepend it to sys.path exactly once, so modules from
# that tree can be imported when the script is executed directly.
# NOTE(review): nothing visible in this file imports from ROOT; confirm the
# sys.path tweak is still needed.
ROOT = Path(__file__).resolve().parents[1]
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
del _root_entry

import faiss
import numpy as np


def load_jsonl(path: Path):
    """Parse a JSON Lines file into a list of decoded records.

    Blank (whitespace-only) lines are skipped. The file is streamed line by
    line instead of being read whole and split with ``str.splitlines``. That
    method also breaks on U+2028/U+2029, U+0085 and other separators, which
    JSON permits unescaped inside string values (e.g. output written with
    ``ensure_ascii=False``). Splitting on them would cut such records apart.

    Args:
        path: Location of the UTF-8 encoded ``.jsonl`` file (a ``str`` path
            is accepted too).

    Returns:
        The decoded JSON values, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # Text-mode iteration splits only on \n, \r and \r\n, never inside a
    # JSON string value.
    with Path(path).open(encoding='utf-8') as fh:
        return [json.loads(line) for line in fh if line.strip()]


def aggregate_song_scores(song_ids, sims, idxs):
    """Collapse row-level nearest-neighbour hits into a ranked list of songs.

    Each hit ``(sims[i], idxs[i])`` is credited to ``song_ids[idxs[i]]``. The
    per-song score is::

        combined = 0.6 * max_sim + 0.3 * top3_avg + 0.1 * min(vote / 10, 1)

    where ``max_sim`` is the song's best similarity, ``top3_avg`` is the mean
    of its (up to) three best similarities and ``vote`` is the number of hits
    it received.

    Negative indices are ignored. FAISS pads its results with label ``-1``
    when fewer than ``k`` neighbours exist. Python would resolve
    ``song_ids[-1]`` to the last song and silently credit it with the
    placeholder score.

    Args:
        song_ids: Sequence mapping a reference row index to its song id.
        sims: Similarity of each hit, aligned with *idxs*.
        idxs: Reference row index of each hit (ints or numpy integers).

    Returns:
        A list of ``(song_id, combined, max_sim, top3_avg, vote)`` tuples
        sorted by ``combined`` descending. Ties keep the order in which
        songs were first hit (stable sort over insertion-ordered dict).
    """
    per_song = defaultdict(list)
    for score, idx in zip(sims, idxs):
        row = int(idx)
        if row < 0:
            # FAISS "no result" padding, not a real reference row.
            continue
        per_song[song_ids[row]].append(float(score))

    ranked = []
    for song_id, vals in per_song.items():
        vals.sort(reverse=True)
        max_sim = vals[0]
        top3_avg = sum(vals[:3]) / min(3, len(vals))
        vote = len(vals)
        # Vote bonus saturates at 10 hits so many weak hits cannot dominate.
        combined = 0.6 * max_sim + 0.3 * top3_avg + 0.1 * min(vote / 10.0, 1.0)
        ranked.append((song_id, combined, max_sim, top3_avg, vote))
    ranked.sort(key=lambda x: x[1], reverse=True)
    return ranked


def compute_metrics(ranks, topk):
    """Summarise 1-based gold ranks into hit rates, MRR and rank statistics.

    Args:
        ranks: Sequence holding the 1-based rank of the gold item per query.
        topk: Additional cutoff reported under the key ``top{topk}``.

    Returns:
        ``{'count': 0}`` when *ranks* is empty. Otherwise a dict with, in
        order:
        - ``count``;
        - ``top1``, ``top3``, ``top{topk}`` (hit rates rounded to 6 places);
        - ``mrr`` (rounded to 6 places);
        - ``mean_rank`` (rounded to 4 places);
        - ``median_rank``.
        When *topk* is 1 or 3 the ``top{topk}`` entry coincides with the
        fixed one, so the dict has one key fewer.
    """
    if not ranks:
        return {'count': 0}
    n = len(ranks)

    def share(predicate):
        # Fraction of queries whose rank satisfies *predicate*.
        return round(sum(1 for r in ranks if predicate(r)) / n, 6)

    metrics = {'count': n}
    metrics['top1'] = share(lambda r: r == 1)
    metrics['top3'] = share(lambda r: r <= 3)
    metrics[f'top{topk}'] = share(lambda r: r <= topk)
    metrics['mrr'] = round(sum(1.0 / r for r in ranks) / n, 6)
    metrics['mean_rank'] = round(sum(ranks) / n, 4)
    metrics['median_rank'] = median(ranks)
    return metrics


def main():
    """Evaluate song-level retrieval with an exact FAISS inner-product index.

    Reads reference and query embedding JSONL files. Every record is read for
    ``embedding`` and ``song_id``; query records also need ``query_type``.
    The script then:
    - retrieves the ``--topn`` nearest reference rows for every query;
    - aggregates those rows per song with ``aggregate_song_scores``;
    - reports rank metrics overall and per query type.
    The JSON report is written to ``--output`` and echoed to stdout.

    A query whose gold song is absent from the aggregated results is given
    the fixed rank ``max(topn, topk, 3) + 1``. Real ranks never exceed
    ``topn``, so a miss can never count as a top-1/top-3/top-k hit.

    Exits with an error message on non-positive ``--topn``/``--topk``, an
    empty reference file, or query embeddings whose dimension differs from
    the references.
    """
    ap = argparse.ArgumentParser(
        description='Evaluate song-level retrieval with a FAISS stand-in for pgvector.'
    )
    ap.add_argument('--reference-embeddings-jsonl', required=True)
    ap.add_argument('--query-embeddings-jsonl', required=True)
    ap.add_argument('--topn', type=int, default=20,
                    help='nearest reference rows retrieved per query')
    ap.add_argument('--topk', type=int, default=10,
                    help='extra hit-rate cutoff reported as top{topk}')
    ap.add_argument('--output', required=True)
    args = ap.parse_args()
    if args.topn < 1 or args.topk < 1:
        ap.error('--topn and --topk must be positive integers')

    refs = load_jsonl(Path(args.reference_embeddings_jsonl))
    queries = load_jsonl(Path(args.query_embeddings_jsonl))
    if not refs:
        sys.exit('error: reference embeddings file contains no records')

    ref_matrix = np.asarray([r['embedding'] for r in refs], dtype=np.float32)
    if ref_matrix.ndim != 2:
        sys.exit('error: reference embeddings must be equal-length vectors')
    dim = ref_matrix.shape[1]
    song_ids = [r['song_id'] for r in refs]
    index = faiss.IndexFlatIP(dim)
    index.add(ref_matrix)

    # Never request more neighbours than the index holds. FAISS pads the
    # surplus with label -1, which would index song_ids from the end.
    k = min(args.topn, index.ntotal)
    # Using len(ranked) + 1 for a miss would let it fall inside the top-k
    # (or top-3) cutoff whenever the retrieved rows came from only a few
    # distinct songs. Use one rank larger than every cutoff and real rank.
    miss_rank = max(args.topn, args.topk, 3) + 1

    if queries:
        query_matrix = np.asarray([q['embedding'] for q in queries], dtype=np.float32)
        if query_matrix.ndim != 2 or query_matrix.shape[1] != dim:
            sys.exit(f'error: query embeddings must be {dim}-dimensional vectors '
                     'to match the reference embeddings')
        # One batched search instead of one FAISS call per query.
        all_sims, all_idxs = index.search(query_matrix, k)
    else:
        all_sims, all_idxs = [], []

    by_type = defaultdict(list)
    examples = defaultdict(list)
    for q, sims, idxs in zip(queries, all_sims, all_idxs):
        ranked = aggregate_song_scores(song_ids, sims, idxs)
        gold = q['song_id']
        rank = next((i + 1 for i, item in enumerate(ranked) if item[0] == gold), miss_rank)
        qtype = str(q['query_type'])
        by_type[qtype].append(rank)
        # Keep the first five queries of each type as illustrative examples.
        if len(examples[qtype]) < 5:
            examples[qtype].append({'song_id': gold, 'rank': rank, 'top3': ranked[:3]})

    report = {
        'backend': 'faiss-as-pgvector-standin',
        'note': 'Uses song-level aggregation compatible with a future pgvector online path.',
        'miss_rank': miss_rank,
        'overall': compute_metrics([r for ranks in by_type.values() for r in ranks], args.topk),
        'by_query_type': {qtype: compute_metrics(ranks, args.topk) for qtype, ranks in by_type.items()},
        'examples': examples,
    }
    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding='utf-8')
    print(json.dumps(report, ensure_ascii=False, indent=2))


if __name__ == '__main__':
    main()