hybrid_engine.py
6.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""Hybrid ACR Engine: Chromaprint + ECAPA + melody-aware re-ranking."""
import json
import time
from pathlib import Path
from typing import Dict, List, Optional
import librosa
import numpy as np
from src.utils.audio import AudioProcessor
class Candidate:
    """A song under consideration together with its per-signal scores.

    Attributes are plain and mutable:
      song_id       -- identifier of the song.
      chroma_score  -- score attributed to the Chromaprint signal.
      ecapa_score   -- score attributed to the ECAPA signal.
      melody_score  -- score attributed to the melody signal.
      metadata      -- free-form dict, empty until a caller fills it in.
    """

    def __init__(
        self,
        song_id: str,
        chroma_score: float = 0.0,
        ecapa_score: float = 0.0,
        melody_score: float = 0.0,
    ):
        """Store the id and the three scores; every score defaults to 0.0."""
        self.song_id = song_id
        self.chroma_score, self.ecapa_score, self.melody_score = (
            chroma_score,
            ecapa_score,
            melody_score,
        )
        # A fresh dict per instance, so candidates never share metadata.
        self.metadata: Dict = {}

    def combined_score(self, chroma_weight: float, ecapa_weight: float, melody_weight: float) -> float:
        """Return the weighted sum of the three stored scores.

        The weights are used exactly as given; nothing is normalised or
        clamped here.
        """
        weighted_chroma = chroma_weight * self.chroma_score
        weighted_ecapa = ecapa_weight * self.ecapa_score
        weighted_melody = melody_weight * self.melody_score
        return weighted_chroma + weighted_ecapa + weighted_melody
class HybridEngine:
    """Rank songs for a query clip by fusing three similarity signals.

    The signals are a Chromaprint matcher, an ECAPA embedding comparison
    against a reference matrix, and a melody similarity computed against
    reference audio files. Each signal is min-max normalised on its own and
    the three are combined as a weighted sum.
    """

    def __init__(
        self,
        chroma_matcher=None,
        ecapa_embedder=None,
        ref_embs: Optional[np.ndarray] = None,
        ref_ids: Optional[List[str]] = None,
        sr: int = 16000,
        chroma_weight: float = 0.25,
        ecapa_weight: float = 0.5,
        melody_weight: float = 0.25,
        reject_threshold: float = 0.35,
        disable_melody: bool = False,
    ):
        """Configure the engine.

        Args:
            chroma_matcher: Object exposing ``match(y, top_k=...)`` that
                returns ``(song_id, score)`` pairs; ``None`` disables the
                Chromaprint signal.
            ecapa_embedder: Object exposing
                ``extract_embedding_from_wave(y)``; ``None`` disables the
                ECAPA signal.
            ref_embs: Reference embeddings, one row per reference entry.
            ref_ids: Song id for each row of ``ref_embs``.
            sr: Sample rate used when loading audio.
            chroma_weight: Weight of the Chromaprint score in the fusion.
            ecapa_weight: Weight of the ECAPA score in the fusion.
            melody_weight: Weight of the melody score in the fusion.
            reject_threshold: Minimum fused score for ``accepted`` to be true.
            disable_melody: When true, the melody signal is skipped.
        """
        self.chroma = chroma_matcher
        self.ecapa = ecapa_embedder
        self.ref_embs = ref_embs
        self.ref_ids = ref_ids
        self.sr = sr
        self.chroma_weight = chroma_weight
        self.ecapa_weight = ecapa_weight
        self.melody_weight = melody_weight
        self.reject_threshold = reject_threshold
        self.disable_melody = disable_melody
        # Filled in by load_metadata().
        self.song_metadata: Dict[str, Dict] = {}
        self.song_audio_paths: Dict[str, str] = {}
        self.audio = AudioProcessor(sr=sr)

    def load_metadata(self, metadata_path: str):
        """Load song metadata from a JSON file holding a list of dicts.

        Every entry must have a ``song_id``. An entry is stored when it has
        ``type == "reference"`` or when nothing is stored for that song yet,
        so reference entries take precedence over other entries. Reference
        entries that carry a non-empty ``audio_path`` also register that
        path, resolved relative to the directory of ``metadata_path``, for
        melody scoring.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If an entry lacks ``song_id``.
        """
        # JSON is UTF-8 by specification; do not depend on the locale.
        with open(metadata_path, encoding="utf-8") as f:
            items = json.load(f)

        base_dir = Path(metadata_path).parent
        for item in items:
            sid = item["song_id"]
            is_reference = item.get("type") == "reference"
            existing = self.song_metadata.get(sid, {})
            if is_reference or not existing:
                self.song_metadata[sid] = {
                    "song_id": sid,
                    "base_freq": item.get("base_freq", existing.get("base_freq", 0)),
                    "audio_path": item.get("audio_path", existing.get("audio_path", "")),
                    "type": item.get("type", existing.get("type", "unknown")),
                }
            # A reference entry without a path has no audio to register;
            # skipping it avoids a KeyError on the missing field.
            audio_rel = item.get("audio_path")
            if is_reference and audio_rel:
                self.song_audio_paths[sid] = str(base_dir / audio_rel)

    @staticmethod
    def _normalize_scores(score_pairs: List[tuple]) -> Dict[str, float]:
        """Min-max scale ``(song_id, score)`` pairs into a dict keyed by id.

        When an id occurs more than once its best (largest) raw score is
        kept before scaling, so a song with several hits is not represented
        by whichever hit happened to come last. Keys appear in order of
        first occurrence. A single id, or ids whose scores are all
        (near-)equal, map to 1.0. Empty input gives an empty dict.
        """
        if not score_pairs:
            return {}

        best: Dict[str, float] = {}
        for sid, score in score_pairs:
            value = float(score)
            if sid not in best or value > best[sid]:
                best[sid] = value

        ids = list(best)
        values = np.array([best[sid] for sid in ids], dtype=np.float32)
        if len(values) == 1:
            return {ids[0]: 1.0}

        vmin = float(values.min())
        vmax = float(values.max())
        if abs(vmax - vmin) < 1e-8:
            # Degenerate range: avoid dividing by ~0 and treat all as top.
            return {sid: 1.0 for sid in ids}

        norm = (values - vmin) / (vmax - vmin)
        return {sid: float(score) for sid, score in zip(ids, norm)}

    @staticmethod
    def _candidate_pool(chroma_norm: Dict[str, float], ecapa_norm: Dict[str, float], top_n: int) -> List[str]:
        """Merge the first ``top_n * 4`` ids of each signal, without duplicates.

        Order is deterministic: Chromaprint ids first, then ECAPA ids not
        already present, each in the order of its own dict. A set is not
        used because its iteration order for strings varies between runs
        and would leak into the tie-breaking of the final ranking.
        """
        limit = top_n * 4
        merged = [*list(chroma_norm)[:limit], *list(ecapa_norm)[:limit]]
        return list(dict.fromkeys(merged))

    def _ecapa_matches(self, query_y: np.ndarray, top_n: int) -> List[tuple]:
        """Return ``(song_id, cosine_similarity)`` pairs, best first.

        At most ``max(top_n * 10, 30)`` pairs are returned. The result is
        empty when the embedder, the reference matrix or the reference ids
        are missing.
        """
        if self.ecapa is None or self.ref_embs is None or self.ref_ids is None:
            return []

        query_emb = self.ecapa.extract_embedding_from_wave(query_y)
        # NOTE(review): assumes the embedder returns a 1-D vector matching
        # the row width of ref_embs -- confirm against the embedder.
        # The small epsilon guards against division by zero for null vectors.
        ref_norm = self.ref_embs / (np.linalg.norm(self.ref_embs, axis=1, keepdims=True) + 1e-12)
        query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-12)
        scores = query_norm @ ref_norm.T
        top_indices = np.argsort(-scores)[: max(top_n * 10, 30)]
        return [(self.ref_ids[idx], float(scores[idx])) for idx in top_indices]

    def _melody_scores(self, query_y: np.ndarray, candidate_ids: List[str]) -> Dict[str, float]:
        """Score the query against the reference audio of each candidate.

        Candidates with no registered path, or whose file does not exist,
        are skipped. Only the first 8 seconds of each reference file are
        loaded. Raw scores come from ``self.audio.melody_similarity`` and
        are min-max normalised before being returned.
        """
        scores = []
        for song_id in candidate_ids:
            ref_path = self.song_audio_paths.get(song_id)
            if not ref_path or not Path(ref_path).exists():
                continue
            ref_y, _ = librosa.load(ref_path, sr=self.sr, mono=True, duration=8.0)
            score = self.audio.melody_similarity(query_y, ref_y)
            scores.append((song_id, score))
        return self._normalize_scores(scores)

    def recognize(self, audio_path: str, top_n: int = 5, mode: str = "auto") -> Dict:
        """Identify the song in ``audio_path`` and return ranked candidates.

        Args:
            audio_path: Audio file to analyse; loaded mono at ``self.sr``.
            top_n: Number of candidates to return.
            mode: Accepted for interface compatibility; currently ignored.

        Returns:
            A dict with ``candidates`` (list of per-song dicts holding
            ``song_id``, ``confidence``, the three per-signal scores,
            ``accepted`` and ``metadata``), ``processing_time_ms`` and
            ``num_candidates``.
        """
        del mode  # unused, kept so existing callers keep working
        # perf_counter is monotonic, unlike time.time(), so the reported
        # duration cannot be skewed by wall-clock adjustments.
        start = time.perf_counter()
        y, _ = librosa.load(audio_path, sr=self.sr, mono=True)

        chroma_matches = self.chroma.match(y, top_k=max(50, top_n * 5)) if self.chroma is not None else []
        chroma_norm = self._normalize_scores(chroma_matches)
        ecapa_norm = self._normalize_scores(self._ecapa_matches(y, top_n))

        candidate_pool = self._candidate_pool(chroma_norm, ecapa_norm, top_n)
        melody_norm = {} if self.disable_melody else self._melody_scores(y, candidate_pool)

        # Melody ids are always drawn from the pool; merging them keeps the
        # original "pool union melody" contract while preserving order.
        all_song_ids = list(dict.fromkeys([*candidate_pool, *melody_norm]))

        scored = []
        for song_id in all_song_ids:
            candidate = Candidate(
                song_id=song_id,
                chroma_score=chroma_norm.get(song_id, 0.0),
                ecapa_score=ecapa_norm.get(song_id, 0.0),
                melody_score=melody_norm.get(song_id, 0.0),
            )
            candidate.metadata = self.song_metadata.get(song_id, {})
            fused = candidate.combined_score(self.chroma_weight, self.ecapa_weight, self.melody_weight)
            scored.append((fused, candidate))

        # Stable sort: equal fused scores keep the pool order.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        results = scored[:top_n]
        elapsed = (time.perf_counter() - start) * 1000

        output = [
            {
                "song_id": c.song_id,
                "confidence": round(fused, 4),
                "chromaprint_score": round(c.chroma_score, 4),
                "ecapa_score": round(c.ecapa_score, 4),
                "melody_score": round(c.melody_score, 4),
                "accepted": fused >= self.reject_threshold,
                "metadata": c.metadata,
            }
            for fused, c in results
        ]
        return {"candidates": output, "processing_time_ms": round(elapsed, 1), "num_candidates": len(results)}