# hybrid_engine.py
"""
Hybrid ACR Engine: Chromaprint fast pre-filter + ECAPA-TDNN deep re-ranking.
"""
import json
import time
from typing import Dict, List, Optional
import librosa
import numpy as np
class Candidate:
    """A single song under consideration, with one score per matching stage.

    Attributes:
        song_id: Identifier of the song this candidate refers to.
        chroma_score: Score contributed by the Chromaprint stage.
        ecapa_score: Score contributed by the ECAPA stage.
        metadata: Free-form dictionary attached by the caller; starts empty.
    """

    def __init__(self, song_id: str, chroma_score: float = 0.0, ecapa_score: float = 0.0):
        """Record the song id and its two stage scores (both default to 0.0)."""
        self.song_id = song_id
        self.chroma_score = chroma_score
        self.ecapa_score = ecapa_score
        # A fresh dict per instance, so candidates never share metadata.
        self.metadata: Dict = {}

    def combined_score(self, chroma_weight: float, ecapa_weight: float) -> float:
        """Return the weighted sum of the two stage scores."""
        chroma_part = chroma_weight * self.chroma_score
        ecapa_part = ecapa_weight * self.ecapa_score
        return chroma_part + ecapa_part

    def __repr__(self):
        """Return a compact debug string with both scores to three decimals."""
        return f"Candidate({self.song_id}, chroma={self.chroma_score:.3f}, ecapa={self.ecapa_score:.3f})"
class HybridEngine:
    """Recogniser that fuses Chromaprint match scores with ECAPA similarity.

    Either stage is optional: the Chromaprint stage is skipped when
    ``chroma_matcher`` is ``None``; the ECAPA stage is skipped unless
    ``ecapa_embedder``, ``ref_embs`` and ``ref_ids`` are all provided.

    Scores of each stage are min-max normalised *per query* before they are
    combined, so the best candidate of a stage always receives 1.0 from that
    stage. The reported confidence is therefore relative to the other
    candidates of the same query, not an absolute similarity.

    Attributes:
        chroma: Object exposing ``match(y, top_k=...)``.
            NOTE(review): presumably returns ``(song_id, score)`` pairs with
            higher scores being better -- confirm against the matcher.
        ecapa: Object exposing ``extract_embedding_from_wave(y)``.
        ref_embs: Reference embeddings, one row per entry of ``ref_ids``.
        ref_ids: Song id for each row of ``ref_embs``.
        sr: Sample rate the query audio is resampled to when loaded.
        chroma_weight: Multiplier for the normalised Chromaprint score.
        ecapa_weight: Multiplier for the normalised ECAPA score.
        reject_threshold: Minimum fused score for ``accepted`` to be true.
        song_metadata: Per-song metadata filled by :meth:`load_metadata`.
    """

    def __init__(
        self,
        chroma_matcher=None,
        ecapa_embedder=None,
        ref_embs: Optional[np.ndarray] = None,
        ref_ids: Optional[List[str]] = None,
        sr: int = 16000,
        chroma_weight: float = 0.35,
        ecapa_weight: float = 0.65,
        reject_threshold: float = 0.35,
    ):
        """Store the stage back-ends and fusion parameters; no I/O is done."""
        self.chroma = chroma_matcher
        self.ecapa = ecapa_embedder
        self.ref_embs = ref_embs
        self.ref_ids = ref_ids
        self.sr = sr
        self.chroma_weight = chroma_weight
        self.ecapa_weight = ecapa_weight
        self.reject_threshold = reject_threshold
        self.song_metadata: Dict[str, Dict] = {}

    def load_metadata(self, metadata_path: str):
        """Load per-song metadata from a JSON file containing a list of items.

        Every item must have a ``"song_id"`` key. The first item seen for a
        song is stored; a later item replaces it only when its ``"type"`` is
        ``"reference"``, in which case missing fields fall back to the values
        already stored for that song.

        Args:
            metadata_path: Path of the JSON file to read.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If an item lacks ``"song_id"``.
        """
        # JSON text is UTF-8 by specification; do not rely on the locale default.
        with open(metadata_path, encoding="utf-8") as f:
            items = json.load(f)
        for item in items:
            sid = item["song_id"]
            existing = self.song_metadata.get(sid, {})
            if item.get("type") == "reference" or not existing:
                self.song_metadata[sid] = {
                    "song_id": sid,
                    "base_freq": item.get("base_freq", existing.get("base_freq", 0)),
                    "audio_path": item.get("audio_path", existing.get("audio_path", "")),
                    "type": item.get("type", existing.get("type", "unknown")),
                }

    @staticmethod
    def _normalize_scores(score_pairs: List[tuple], invert: bool = False) -> Dict[str, float]:
        """Min-max scale ``(song_id, score)`` pairs into the range [0, 1].

        When a song id occurs more than once, only its best score is kept
        (highest, or lowest when ``invert`` is true) so that a weaker duplicate
        can never overwrite a stronger match.

        Args:
            score_pairs: Iterable of ``(song_id, score)`` pairs.
            invert: Treat lower raw scores as better (e.g. distances).

        Returns:
            Mapping from song id to normalised score. Empty input gives an
            empty dict; a single id, or all-equal scores, map to 1.0.
        """
        if not score_pairs:
            return {}
        sign = -1.0 if invert else 1.0
        best: Dict[str, float] = {}
        for sid, score in score_pairs:
            value = sign * float(score)
            if sid not in best or value > best[sid]:
                best[sid] = value
        if len(best) == 1:
            return {sid: 1.0 for sid in best}
        values = np.array(list(best.values()), dtype=np.float32)
        vmin = float(values.min())
        vmax = float(values.max())
        if abs(vmax - vmin) < 1e-8:
            # All candidates tie; avoid dividing by (almost) zero.
            return {sid: 1.0 for sid in best}
        norm = (values - vmin) / (vmax - vmin)
        return {sid: float(score) for sid, score in zip(best, norm)}

    def _ecapa_matches(self, y, top_n: int) -> List[tuple]:
        """Return ``(song_id, cosine_similarity)`` pairs for the closest references.

        Returns an empty list when the ECAPA stage is not fully configured or
        there are no reference embeddings.
        """
        if self.ecapa is None or self.ref_embs is None or self.ref_ids is None:
            return []
        ref_embs = np.asarray(self.ref_embs)
        if ref_embs.size == 0:
            return []
        # Flatten so an embedder that returns shape (1, D) still yields 1-D scores.
        query_emb = np.asarray(self.ecapa.extract_embedding_from_wave(y)).reshape(-1)
        ref_norm = ref_embs / (np.linalg.norm(ref_embs, axis=1, keepdims=True) + 1e-12)
        query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-12)
        scores = query_norm @ ref_norm.T
        top_indices = np.argsort(-scores)[: max(top_n * 5, 20)]
        return [(self.ref_ids[idx], float(scores[idx])) for idx in top_indices]

    def recognize(
        self,
        audio_path: str,
        top_n: int = 5,
        mode: str = "auto",
    ) -> Dict:
        """Identify the song in an audio file and return ranked candidates.

        Args:
            audio_path: Path of the audio file to load.
            top_n: Maximum number of candidates to return; values below zero
                are treated as zero.
            mode: Accepted for interface compatibility; currently ignored.

        Returns:
            A dict with ``"candidates"`` (list of dicts holding ``song_id``,
            ``confidence``, ``chromaprint_score``, ``ecapa_score``,
            ``accepted`` and ``metadata``, best first),
            ``"processing_time_ms"`` and ``"num_candidates"``.
        """
        del mode
        # perf_counter is monotonic, so the duration cannot be skewed by
        # system clock adjustments.
        start = time.perf_counter()
        y, _ = librosa.load(audio_path, sr=self.sr, mono=True)

        chroma_matches = self.chroma.match(y, top_k=max(50, top_n * 5)) if self.chroma is not None else []
        chroma_norm = self._normalize_scores(chroma_matches)
        ecapa_norm = self._normalize_scores(self._ecapa_matches(y, top_n))

        # Dict merge keeps insertion order (chroma ids first, then ECAPA-only
        # ids); together with the stable sort below, ties are ordered
        # deterministically instead of by randomised set iteration order.
        scored = []
        for song_id in {**chroma_norm, **ecapa_norm}:
            candidate = Candidate(
                song_id=song_id,
                chroma_score=chroma_norm.get(song_id, 0.0),
                ecapa_score=ecapa_norm.get(song_id, 0.0),
            )
            candidate.metadata = self.song_metadata.get(song_id, {})
            fused = float(candidate.combined_score(self.chroma_weight, self.ecapa_weight))
            scored.append((fused, candidate))
        scored.sort(key=lambda pair: pair[0], reverse=True)
        results = scored[: max(top_n, 0)]
        elapsed = (time.perf_counter() - start) * 1000

        output = [
            {
                "song_id": c.song_id,
                "confidence": round(fused, 4),
                "chromaprint_score": round(c.chroma_score, 4),
                "ecapa_score": round(c.ecapa_score, 4),
                "accepted": bool(fused >= self.reject_threshold),
                "metadata": c.metadata,
            }
            for fused, c in results
        ]
        return {
            "candidates": output,
            "processing_time_ms": round(elapsed, 1),
            "num_candidates": len(results),
        }