evaluate_songid_pgvector_path.py
3.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#!/usr/bin/env /usr/local/miniconda3/bin/python
from __future__ import annotations
import argparse
import json
import sys
from collections import defaultdict
from pathlib import Path
from statistics import median
# ROOT is the directory two levels above this file, i.e. the parent of the
# directory that contains this script. It is prepended to sys.path (once) so
# that imports can resolve relative to it.
# NOTE(review): presumably ROOT is the project root -- confirm against the
# repository layout.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
import faiss
import numpy as np
def load_jsonl(path: Path):
    """Load a JSON Lines file into a list of decoded records.

    The file is decoded as UTF-8 and streamed line by line; blank and
    whitespace-only lines are skipped.

    Args:
        path: Location of the JSONL file (a ``Path`` or a path string).

    Returns:
        A list holding one decoded JSON value per non-blank line, in file
        order.

    Raises:
        OSError: If the file cannot be opened or read.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Iterate the file object rather than ``read_text().splitlines()``:
    # ``str.splitlines`` also breaks on U+2028, U+2029, U+0085, form feed and
    # similar characters, which may legally appear unescaped inside JSON
    # strings (for example text written with ``ensure_ascii=False``) and would
    # cut a record in half. Text-mode file iteration only splits on real
    # newlines (\n, \r, \r\n), and it avoids holding the raw text in memory.
    with open(path, encoding='utf-8') as handle:
        return [json.loads(line) for line in handle if line.strip()]
def aggregate_song_scores(song_ids, sims, idxs):
    """Collapse per-vector search hits into one ranked entry per song.

    Args:
        song_ids: Sequence mapping a reference-vector position to the id of
            the song that vector belongs to.
        sims: Similarity scores of the retrieved hits.
        idxs: Reference-vector positions of the hits, parallel to ``sims``.
            Negative positions mean "no result" and are ignored.

    Returns:
        A list of ``(song_id, combined, max_sim, top3_avg, vote)`` tuples,
        sorted by ``combined`` in descending order, where for each song:

        * ``max_sim`` is its best hit score,
        * ``top3_avg`` is the mean of its (up to) three best hit scores,
        * ``vote`` is the number of hits it received,
        * ``combined`` is ``0.6 * max_sim + 0.3 * top3_avg
          + 0.1 * min(vote / 10, 1)``.

        Songs with equal ``combined`` keep the order in which they were first
        hit (the sort is stable).
    """
    scores_by_song = defaultdict(list)
    for score, idx in zip(sims, idxs):
        position = int(idx)
        # FAISS pads the result arrays with label -1 when fewer than k
        # neighbours exist. Indexing song_ids[-1] would silently credit the
        # LAST reference song with a sentinel score, so skip such slots.
        if position < 0:
            continue
        scores_by_song[song_ids[position]].append(float(score))

    ranked = []
    for song_id, vals in scores_by_song.items():
        vals.sort(reverse=True)
        max_sim = vals[0]
        best_three = vals[:3]
        top3_avg = sum(best_three) / len(best_three)
        vote = len(vals)
        # The vote bonus saturates at 10 hits so hit count cannot dominate.
        combined = 0.6 * max_sim + 0.3 * top3_avg + 0.1 * min(vote / 10.0, 1.0)
        ranked.append((song_id, combined, max_sim, top3_avg, vote))
    ranked.sort(key=lambda entry: entry[1], reverse=True)
    return ranked
def compute_metrics(ranks, topk):
    """Summarise a collection of 1-based ranks as retrieval metrics.

    Args:
        ranks: Ranks at which the expected item was found, one per query.
        topk: Extra cutoff reported under the key ``top{topk}``.

    Returns:
        ``{'count': 0}`` when ``ranks`` is empty. Otherwise a dict with
        ``count``, the hit rates ``top1``, ``top3`` and ``top{topk}``
        (rounded to 6 places), ``mrr`` (rounded to 6 places), ``mean_rank``
        (rounded to 4 places) and ``median_rank``.
    """
    if not ranks:
        return {'count': 0}

    total = len(ranks)

    def within(cutoff):
        # Fraction of queries whose rank is at or above the cutoff position.
        hits = sum(rank <= cutoff for rank in ranks)
        return round(hits / total, 6)

    first_place = len([rank for rank in ranks if rank == 1])
    reciprocal_sum = sum(1.0 / rank for rank in ranks)

    # Keys are inserted in a fixed order; when topk is 1 or 3 the dynamic key
    # overwrites the existing entry in place.
    summary = {'count': total}
    summary['top1'] = round(first_place / total, 6)
    summary['top3'] = within(3)
    summary[f'top{topk}'] = within(topk)
    summary['mrr'] = round(reciprocal_sum / total, 6)
    summary['mean_rank'] = round(sum(ranks) / total, 4)
    summary['median_rank'] = median(ranks)
    return summary
def main():
    """Evaluate song-level retrieval over embeddings and write a JSON report.

    Command-line arguments:
        --reference-embeddings-jsonl: JSONL file whose records carry an
            ``embedding`` vector and a ``song_id``.
        --query-embeddings-jsonl: JSONL file whose records carry an
            ``embedding`` vector, the expected ``song_id`` and a
            ``query_type``.
        --topn: Number of nearest reference vectors fetched per query
            (default 20).
        --topk: Extra hit-rate cutoff included in the metrics (default 10).
        --output: Path of the JSON report; parent directories are created.

    Every reference vector is added to an exact inner-product FAISS index.
    For each query the nearest vectors are aggregated per song with
    ``aggregate_song_scores`` and the 1-based rank of the expected song is
    recorded; when the expected song is not among the aggregated results the
    rank is one past the last result. Metrics are reported overall and per
    query type, together with up to five example queries per type. The report
    is written to ``--output`` and also printed to stdout.

    Exits through ``argparse`` (status 2) when ``--topn`` is not positive or
    the reference file holds no records.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--reference-embeddings-jsonl', required=True)
    ap.add_argument('--query-embeddings-jsonl', required=True)
    ap.add_argument('--topn', type=int, default=20)
    ap.add_argument('--topk', type=int, default=10)
    ap.add_argument('--output', required=True)
    args = ap.parse_args()
    if args.topn < 1:
        ap.error('--topn must be a positive integer')

    refs = load_jsonl(Path(args.reference_embeddings_jsonl))
    queries = load_jsonl(Path(args.query_embeddings_jsonl))
    if not refs:
        # Without this guard the failure is an opaque IndexError on shape[1].
        ap.error('reference embeddings file contains no records')

    ref_matrix = np.asarray([r['embedding'] for r in refs], dtype=np.float32)
    song_ids = [r['song_id'] for r in refs]
    index = faiss.IndexFlatIP(ref_matrix.shape[1])
    index.add(ref_matrix)

    # Never ask for more neighbours than the index holds: FAISS pads the
    # missing slots with label -1, which would otherwise be read as
    # song_ids[-1] and credit the last reference song with sentinel scores.
    search_k = min(args.topn, index.ntotal)

    by_type = defaultdict(list)
    examples = defaultdict(list)
    for q in queries:
        # FAISS expects a 2-D (n_queries, dim) float32 array.
        qvec = np.asarray(q['embedding'], dtype=np.float32).reshape(1, -1)
        sims, idxs = index.search(qvec, search_k)
        ranked = aggregate_song_scores(song_ids, sims[0], idxs[0])
        gold = q['song_id']
        # 1-based position of the expected song; one past the end if missing.
        rank = next(
            (i + 1 for i, item in enumerate(ranked) if item[0] == gold),
            len(ranked) + 1,
        )
        qtype = str(q['query_type'])
        by_type[qtype].append(rank)
        # Keep only a handful of examples per query type in the report.
        if len(examples[qtype]) < 5:
            examples[qtype].append({'song_id': gold, 'rank': rank, 'top3': ranked[:3]})

    report = {
        'backend': 'faiss-as-pgvector-standin',
        'note': 'Uses song-level aggregation compatible with a future pgvector online path.',
        'overall': compute_metrics([r for ranks in by_type.values() for r in ranks], args.topk),
        'by_query_type': {qtype: compute_metrics(ranks, args.topk) for qtype, ranks in by_type.items()},
        'examples': examples,
    }
    rendered = json.dumps(report, ensure_ascii=False, indent=2)
    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(rendered, encoding='utf-8')
    print(rendered)
# Run the evaluation only when this file is executed as a script, so that
# importing the module does not trigger argument parsing or any I/O.
if __name__ == '__main__':
    main()