ecapa_embedder.py
3.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import json
import warnings
from pathlib import Path
from typing import List, Optional, Tuple

import librosa
import numpy as np
import torch
import torch.nn.functional as F
class ECAPAEmbedder:
    """Extract ECAPA embeddings from audio and search a reference index.

    Audio is resampled to ``sr`` (mono) and converted to a log-mel
    spectrogram (``n_mels`` bands, ``n_fft`` / ``hop_length`` STFT) before
    being fed to ``ECAPA_ACR``. The first element of the model's output tuple
    is used as the embedding. The model is built with ``embed_dim=192``.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cpu",
        sr: int = 16000,
        n_mels: int = 80,
        n_fft: int = 512,
        hop_length: int = 160,
    ):
        """Build the model, load weights from ``model_path`` and switch to eval mode.

        Args:
            model_path: Checkpoint loadable with ``torch.load(weights_only=True)``.
                It is either a state dict or a dict that holds one under
                ``"model_state_dict"``.
            device: Torch device string; the model and all inputs go there.
            sr: Sample rate that audio is resampled to.
            n_mels: Number of mel bands. It is also passed to the model.
            n_fft: FFT size for the mel spectrogram.
            hop_length: Hop size, in samples, for the mel spectrogram.

        Warns:
            RuntimeWarning: If the checkpoint's keys do not match the model.
        """
        # Project-local model import, kept inside the constructor as originally.
        from src.models.ecapa_tdnn import ECAPA_ACR

        self.device = torch.device(device)
        self.sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length

        self.model = ECAPA_ACR(n_mels=n_mels, embed_dim=192)
        state = torch.load(model_path, map_location="cpu", weights_only=True)
        if isinstance(state, dict) and "model_state_dict" in state:
            state = state["model_state_dict"]
        # strict=False tolerates partial checkpoints. Silently ignoring a
        # mismatch can leave layers randomly initialised, so report it.
        incompatible = self.model.load_state_dict(state, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            warnings.warn(
                f"Checkpoint {model_path!r} does not fully match ECAPA_ACR: "
                f"{len(incompatible.missing_keys)} missing key(s) "
                f"{incompatible.missing_keys[:5]}, "
                f"{len(incompatible.unexpected_keys)} unexpected key(s) "
                f"{incompatible.unexpected_keys[:5]}",
                RuntimeWarning,
                stacklevel=2,
            )
        self.model.to(self.device)
        self.model.eval()

    def _load_audio(self, path: str) -> np.ndarray:
        """Load ``path`` as a mono waveform resampled to ``self.sr``."""
        y, _ = librosa.load(path, sr=self.sr, mono=True)
        return y

    def _to_mel(self, y: np.ndarray) -> torch.Tensor:
        """Convert waveform ``y`` to a log-mel tensor with a leading batch dim of 1.

        The dB scale is relative to the clip's own loudest bin (``ref=np.max``).
        """
        mel = librosa.feature.melspectrogram(
            y=y, sr=self.sr, n_mels=self.n_mels,
            n_fft=self.n_fft, hop_length=self.hop_length
        )
        mel = librosa.power_to_db(mel, ref=np.max)
        return torch.FloatTensor(mel).unsqueeze(0)

    @staticmethod
    def _pad_to(y: np.ndarray, length: int) -> np.ndarray:
        """Right-pad ``y`` with zeros so it has at least ``length`` samples."""
        if len(y) < length:
            y = np.pad(y, (0, length - len(y)))
        return y

    def _embed(self, y: np.ndarray) -> np.ndarray:
        """Run one waveform through the mel front-end and the model; return a 1-D vector."""
        mel = self._to_mel(y).to(self.device)
        with torch.no_grad():
            emb, _ = self.model(mel)
        return emb.cpu().numpy().flatten()

    def extract_embedding(self, audio_path: str) -> np.ndarray:
        """Embed the whole audio file at ``audio_path``.

        Audio shorter than one second is zero-padded to one second, matching
        ``extract_embedding_from_wave``. Longer audio is embedded in full,
        without truncation.
        """
        y = self._pad_to(self._load_audio(audio_path), self.sr)
        return self._embed(y)

    def extract_embedding_from_wave(
        self, y: np.ndarray, max_sec: float = 5.0
    ) -> np.ndarray:
        """Embed an in-memory waveform, assumed to be already at ``self.sr``.

        Args:
            y: 1-D waveform samples.
            max_sec: Only the first ``max_sec`` seconds are embedded
                (default 5 s, the previously hard-coded limit).

        Returns:
            The flattened embedding.

        Raises:
            ValueError: If ``max_sec`` is not positive.
        """
        if max_sec <= 0:
            raise ValueError("max_sec must be positive")
        # Pad very short clips to one second so the model sees enough frames.
        y = self._pad_to(np.asarray(y), self.sr)
        return self._embed(y[: int(max_sec * self.sr)])

    def build_reference_index(
        self,
        songs_dir: str,
        metadata_path: str,
        output_path: str,
        window_sec: float = 5.0,
        stride_sec: float = 2.5,
    ) -> Tuple[np.ndarray, List[str]]:
        """Embed every reference song and save the index to disk.

        Each song is cut into ``window_sec`` windows every ``stride_sec``
        seconds. Its embedding is the mean of the window embeddings. Songs
        shorter than one window are zero-padded to one window instead of
        being skipped.

        Args:
            songs_dir: Directory of songs. Metadata ``audio_path`` values are
                resolved against its *parent* (presumably they already start
                with the songs folder name; verify against the metadata).
            metadata_path: JSON file holding a list of dicts with
                ``"audio_path"`` and ``"song_id"`` keys. Only entries whose
                ``audio_path`` contains ``"songs/"`` are indexed; missing
                files are skipped.
            output_path: Prefix for ``{output_path}_embs.npy`` and
                ``{output_path}_ids.npy``.
            window_sec: Window length in seconds.
            stride_sec: Hop between window starts in seconds.

        Returns:
            ``(embeddings, song_ids)``. The embeddings are stacked row-wise in
            the same order as ``song_ids``.

        Raises:
            ValueError: If the window or stride is shorter than one sample, or
                if no song could be indexed.
        """
        win_len = int(window_sec * self.sr)
        stride = int(stride_sec * self.sr)
        if win_len <= 0 or stride <= 0:
            raise ValueError(
                "window_sec and stride_sec must each span at least one sample"
            )

        with open(metadata_path, encoding="utf-8") as f:
            meta = json.load(f)

        songs_root = Path(songs_dir).parent
        all_embs: List[np.ndarray] = []
        all_ids: List[str] = []
        for item in meta:
            if "songs/" not in item.get("audio_path", ""):
                continue
            audio_path = songs_root / item["audio_path"]
            if not audio_path.exists():
                continue
            song_id = item["song_id"]
            y = self._load_audio(str(audio_path))
            if len(y) == 0:
                continue
            # Without padding, a song shorter than one window yields no
            # windows and silently disappears from the index.
            y = self._pad_to(y, win_len)
            window_embs = [
                self._embed(y[start:start + win_len])
                for start in range(0, len(y) - win_len + 1, stride)
            ]
            all_embs.append(np.mean(window_embs, axis=0))
            all_ids.append(song_id)

        if not all_embs:
            raise ValueError(
                f"No reference songs could be indexed from {metadata_path!r}"
            )
        all_embs = np.vstack(all_embs)
        np.save(f"{output_path}_embs.npy", all_embs)
        np.save(f"{output_path}_ids.npy", np.array(all_ids))
        print(f"Built reference index: {len(all_ids)} songs, embeddings shape {all_embs.shape}")
        return all_embs, all_ids

    def search(
        self,
        query_emb: np.ndarray,
        ref_embs: np.ndarray,
        ref_ids: List[str],
        top_k: int = 10,
    ) -> List[Tuple[str, float]]:
        """Rank reference songs by cosine similarity to a query embedding.

        Args:
            query_emb: Query embedding of shape ``(D,)``; ``(1, D)`` is also
                accepted and flattened.
            ref_embs: Reference matrix of shape ``(N, D)``.
            ref_ids: ``N`` ids aligned with the rows of ``ref_embs``.
            top_k: Maximum number of results. A value ``<= 0`` returns ``[]``.

        Returns:
            ``(song_id, cosine_score)`` pairs, best first.

        Raises:
            ValueError: If ``ref_ids`` and ``ref_embs`` differ in length.
        """
        if len(ref_ids) != len(ref_embs):
            raise ValueError(
                f"ref_ids has {len(ref_ids)} entries but ref_embs has "
                f"{len(ref_embs)} rows"
            )
        if top_k <= 0:
            # A slice like [:-1] would otherwise return "all but the last" hit.
            return []
        query = np.ravel(query_emb)
        # The epsilon avoids division by zero for all-zero vectors.
        query_norm = query / (np.linalg.norm(query) + 1e-12)
        ref_norm = ref_embs / (np.linalg.norm(ref_embs, axis=1, keepdims=True) + 1e-12)
        scores = query_norm @ ref_norm.T
        top_indices = np.argsort(-scores)[:top_k]
        return [(ref_ids[i], float(scores[i])) for i in top_indices]