# ecapa_embedder.py
import json
from pathlib import Path
from typing import List, Optional, Tuple

import librosa
import numpy as np
import torch


class ECAPAEmbedder:
    """Turn audio into embeddings with a pretrained ``ECAPA_ACR`` model.

    Audio is cut into fixed-length overlapping windows, each window is
    converted to a log-mel spectrogram and passed through the model, and the
    per-window embeddings are either averaged (queries) or stored one by one
    (reference index). ``search`` ranks reference windows by cosine
    similarity.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cpu",
        sr: int = 16000,
        n_mels: int = 80,
        n_fft: int = 512,
        hop_length: int = 160,
    ):
        """Load the checkpoint at *model_path* and prepare the model for inference.

        Args:
            model_path: Checkpoint readable by ``torch.load``. It must contain
                a ``"model_state_dict"`` entry and may contain a ``"config"``
                dict whose ``"model"`` section overrides the architecture
                defaults used below.
            device: Torch device string the model is moved to.
            sr: Sample rate audio is resampled to when loaded.
            n_mels: Number of mel bands; only used when the checkpoint config
                does not specify ``n_mels``.
            n_fft: FFT size for the mel spectrogram.
            hop_length: Hop length (in samples) for the mel spectrogram.

        Raises:
            KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
        """
        self.device = torch.device(device)
        self.sr = sr
        self.n_fft = n_fft
        self.hop_length = hop_length

        # Local import: the project model package is only needed once an
        # embedder is actually constructed.
        from src.models.ecapa_tdnn import ECAPA_ACR

        state = torch.load(model_path, map_location="cpu", weights_only=True)
        cfg = state.get("config") or {}
        model_cfg = cfg.get("model") or {}

        # The feature extractor must produce exactly as many mel bands as the
        # model was built with, so the checkpoint value (when present) wins
        # over the constructor argument for both.
        self.n_mels = model_cfg.get("n_mels", n_mels)

        self.model = ECAPA_ACR(
            n_mels=self.n_mels,
            embed_dim=model_cfg.get("embed_dim", 192),
            channels=model_cfg.get("channels", 512),
            se_channels=model_cfg.get("se_channels", 128),
            res2net_scale=model_cfg.get("res2net_scale", 8),
            num_blocks=model_cfg.get("num_blocks", 3),
            num_classes=None,
        )

        # strict=False tolerates key mismatches between the checkpoint and
        # this model; report both directions so that weights which were NOT
        # loaded do not go unnoticed.
        load_result = self.model.load_state_dict(state["model_state_dict"], strict=False)
        if load_result.missing_keys:
            print(f"[warn] missing keys while loading model: {load_result.missing_keys}")
        if load_result.unexpected_keys:
            print(f"[warn] unexpected keys while loading model: {load_result.unexpected_keys}")

        self.model.to(self.device)
        self.model.eval()

    def _load_audio(self, path: str) -> np.ndarray:
        """Load *path* as a mono waveform resampled to ``self.sr``."""
        y, _ = librosa.load(path, sr=self.sr, mono=True)
        return y

    def _to_mel(self, y: np.ndarray) -> torch.Tensor:
        """Return the log-mel spectrogram of *y* with a leading axis of size 1.

        The power spectrogram is converted to dB relative to its own maximum
        (``ref=np.max``), so the loudest bin of every segment maps to 0 dB.
        """
        mel = librosa.feature.melspectrogram(
            y=y,
            sr=self.sr,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
        )
        mel = librosa.power_to_db(mel, ref=np.max)
        return torch.FloatTensor(mel).unsqueeze(0)

    def _windows(self, y: np.ndarray, window_sec: float = 5.0, stride_sec: float = 2.5) -> List[np.ndarray]:
        """Split *y* into equal-length windows.

        Input shorter than one window is zero-padded at the end to exactly one
        window. For longer input only full windows are produced: trailing
        samples that do not fill a complete window are not covered.

        Args:
            y: Waveform sampled at ``self.sr``.
            window_sec: Window length in seconds.
            stride_sec: Distance between consecutive window starts in seconds.

        Returns:
            A non-empty list of arrays, each ``int(window_sec * self.sr)``
            samples long.

        Raises:
            ValueError: If the window or the stride is shorter than one sample.
        """
        win_len = int(window_sec * self.sr)
        stride = int(stride_sec * self.sr)
        if win_len < 1:
            raise ValueError(f"window_sec={window_sec} is shorter than one sample at sr={self.sr}")
        if stride < 1:
            raise ValueError(f"stride_sec={stride_sec} is shorter than one sample at sr={self.sr}")

        if len(y) < win_len:
            y = np.pad(y, (0, win_len - len(y)))

        # After padding len(y) >= win_len, so start 0 is always included.
        last_start = len(y) - win_len
        return [y[start : start + win_len] for start in range(0, last_start + 1, stride)]

    def _embed_segment(self, seg: np.ndarray) -> np.ndarray:
        """Run one window through the model and return its embedding as a flat array."""
        mel = self._to_mel(seg).to(self.device)
        with torch.no_grad():
            # The model returns a pair; only the first element (the embedding) is used.
            emb, _ = self.model(mel)
        return emb.cpu().numpy().flatten()

    def extract_embedding(
        self,
        audio_path: str,
        window_sec: float = 5.0,
        stride_sec: float = 2.5,
    ) -> np.ndarray:
        """Load *audio_path* and return the mean embedding over its windows.

        See ``extract_embedding_from_wave`` for the meaning of the windowing
        parameters.
        """
        y = self._load_audio(audio_path)
        return self.extract_embedding_from_wave(y, window_sec=window_sec, stride_sec=stride_sec)

    def extract_embedding_from_wave(
        self,
        y: np.ndarray,
        window_sec: float = 5.0,
        stride_sec: float = 2.5,
    ) -> np.ndarray:
        """Return the mean of the per-window embeddings of waveform *y*.

        Args:
            y: Waveform sampled at ``self.sr``.
            window_sec: Window length in seconds. Use the same value the
                reference index was built with.
            stride_sec: Stride between windows in seconds.

        Returns:
            A 1-D array: the element-wise mean of all window embeddings.
        """
        segments = self._windows(y, window_sec=window_sec, stride_sec=stride_sec)
        window_embs = [self._embed_segment(seg) for seg in segments]
        return np.mean(window_embs, axis=0)

    def build_reference_index(
        self,
        songs_dir: str,
        metadata_path: str,
        output_path: str,
        window_sec: float = 5.0,
        stride_sec: float = 2.5,
    ) -> Tuple[np.ndarray, List[str]]:
        """Embed every reference song window and save the result to disk.

        An entry of the metadata list is used when its ``"type"`` is
        ``"reference"`` or its ``"audio_path"`` contains ``"songs/"``. Audio
        paths are resolved relative to the parent directory of *songs_dir*;
        entries whose file does not exist are skipped.

        Args:
            songs_dir: Directory whose parent is the root for the metadata's
                ``"audio_path"`` values.
            metadata_path: JSON file holding a list of dicts with at least
                ``"audio_path"`` and ``"song_id"``.
            output_path: Path prefix; ``<prefix>_embs.npy`` and
                ``<prefix>_ids.npy`` are written.
            window_sec: Window length in seconds.
            stride_sec: Stride between windows in seconds.

        Returns:
            ``(embeddings, ids)`` where ``embeddings`` has one row per window
            and ``ids[i]`` is the song id of row ``i``.

        Raises:
            ValueError: If no reference audio could be embedded.
        """
        with open(metadata_path, encoding="utf-8") as f:
            meta = json.load(f)

        data_root = Path(songs_dir).parent
        all_embs = []
        all_ids = []
        skipped = 0

        for item in meta:
            if item.get("type") != "reference" and "songs/" not in item.get("audio_path", ""):
                continue
            audio_path = data_root / item["audio_path"]
            if not audio_path.exists():
                skipped += 1
                continue
            song_id = item["song_id"]
            y = self._load_audio(str(audio_path))

            for seg in self._windows(y, window_sec=window_sec, stride_sec=stride_sec):
                all_embs.append(self._embed_segment(seg))
                all_ids.append(song_id)

        if skipped:
            print(f"[warn] skipped {skipped} reference entries whose audio file was not found")
        if not all_embs:
            # np.vstack([]) would raise a ValueError with an unhelpful message.
            raise ValueError(
                f"No reference audio found for metadata {metadata_path!r} under {str(data_root)!r}"
            )

        all_embs = np.vstack(all_embs)
        np.save(f"{output_path}_embs.npy", all_embs)
        np.save(f"{output_path}_ids.npy", np.array(all_ids))
        print(f"Built reference index: {len(all_ids)} windows, embeddings shape {all_embs.shape}")
        return all_embs, all_ids

    def search(
        self,
        query_emb: np.ndarray,
        ref_embs: np.ndarray,
        ref_ids: List[str],
        top_k: int = 10,
    ) -> List[Tuple[str, float]]:
        """Rank reference embeddings by cosine similarity to *query_emb*.

        Args:
            query_emb: 1-D query embedding.
            ref_embs: 2-D array with one reference embedding per row.
            ref_ids: Identifier for each row of *ref_embs*.
            top_k: Maximum number of hits to return; values ``<= 0`` yield an
                empty list.

        Returns:
            Up to *top_k* ``(ref_id, score)`` pairs, best match first.
        """
        # A negative top_k would otherwise slice as "all but the last k".
        limit = max(top_k, 0)

        # The small epsilon keeps all-zero vectors from dividing by zero.
        query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-12)
        ref_norm = ref_embs / (np.linalg.norm(ref_embs, axis=1, keepdims=True) + 1e-12)
        scores = query_norm @ ref_norm.T

        top_indices = np.argsort(-scores)[:limit]
        return [(ref_ids[i], float(scores[i])) for i in top_indices]