# ecapa_embedder.py 3.98 KB
import json
import logging
from pathlib import Path
from typing import List, Optional, Tuple

import librosa
import numpy as np
import torch
import torch.nn.functional as F


logger = logging.getLogger(__name__)


class ECAPAEmbedder:
    """Compute ECAPA embeddings for audio and search them by cosine similarity.

    The embedder loads an ``ECAPA_ACR`` checkpoint, converts waveforms to
    log-mel spectrograms with librosa, and exposes helpers to embed files or
    raw waveforms, build a per-song reference index, and rank references
    against a query embedding.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cpu",
        sr: int = 16000,
        n_mels: int = 80,
        n_fft: int = 512,
        hop_length: int = 160,
    ):
        """Load the model checkpoint and store the feature-extraction settings.

        Args:
            model_path: Path to a checkpoint readable by ``torch.load``. Either
                a bare state dict or a dict holding it under
                ``"model_state_dict"``.
            device: Torch device string the model is moved to.
            sr: Sample rate audio is resampled to when loaded.
            n_mels: Number of mel bands (also passed to the model).
            n_fft: FFT size for the mel spectrogram.
            hop_length: Hop length (in samples) for the mel spectrogram.
        """
        self.device = torch.device(device)
        self.sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length

        # Project-local model definition; imported at construction time.
        from src.models.ecapa_tdnn import ECAPA_ACR
        self.model = ECAPA_ACR(n_mels=n_mels, embed_dim=192)
        state = torch.load(model_path, map_location="cpu", weights_only=True)
        if "model_state_dict" in state:
            state = state["model_state_dict"]

        # strict=False tolerates key mismatches; surface them instead of
        # silently running with partially initialised weights.
        load_result = self.model.load_state_dict(state, strict=False)
        missing = list(getattr(load_result, "missing_keys", []) or [])
        unexpected = list(getattr(load_result, "unexpected_keys", []) or [])
        if missing or unexpected:
            logger.warning(
                "Checkpoint %s did not match the model exactly "
                "(missing keys: %s; unexpected keys: %s)",
                model_path, missing, unexpected,
            )

        self.model.to(self.device)
        self.model.eval()

    def _load_audio(self, path: str) -> np.ndarray:
        """Load ``path`` as a mono waveform resampled to ``self.sr``."""
        y, _ = librosa.load(path, sr=self.sr, mono=True)
        return y

    def _to_mel(self, y: np.ndarray) -> torch.Tensor:
        """Convert a waveform to a dB-scaled mel spectrogram tensor.

        The spectrogram is scaled relative to its own maximum
        (``ref=np.max``) and gets one extra leading dimension.
        """
        mel = librosa.feature.melspectrogram(
            y=y, sr=self.sr, n_mels=self.n_mels,
            n_fft=self.n_fft, hop_length=self.hop_length
        )
        mel = librosa.power_to_db(mel, ref=np.max)
        return torch.as_tensor(mel, dtype=torch.float32).unsqueeze(0)

    def _embed_wave(self, y: np.ndarray) -> np.ndarray:
        """Run the model on one waveform and return its flattened embedding."""
        mel = self._to_mel(y).to(self.device)
        with torch.no_grad():
            # The model returns a pair; only the first element is used here.
            emb, _ = self.model(mel)
        return emb.cpu().numpy().flatten()

    def extract_embedding(self, audio_path: str) -> np.ndarray:
        """Embed the entire audio file at ``audio_path``.

        Returns:
            The flattened embedding as a 1-D numpy array.
        """
        return self._embed_wave(self._load_audio(audio_path))

    def extract_embedding_from_wave(
        self, y: np.ndarray, max_sec: float = 5.0
    ) -> np.ndarray:
        """Embed an in-memory waveform sampled at ``self.sr``.

        Waveforms shorter than one second are zero-padded to one second, and
        only the first ``max_sec`` seconds are used.

        Args:
            y: 1-D waveform.
            max_sec: Maximum duration (seconds) taken from the start of ``y``.

        Returns:
            The flattened embedding as a 1-D numpy array.

        Raises:
            ValueError: If ``max_sec`` does not cover at least one sample.
        """
        max_len = int(max_sec * self.sr)
        if max_len <= 0:
            raise ValueError(f"max_sec must be positive, got {max_sec!r}")
        if len(y) < self.sr:
            y = np.pad(y, (0, self.sr - len(y)))
        return self._embed_wave(y[:max_len])

    def build_reference_index(
        self,
        songs_dir: str,
        metadata_path: str,
        output_path: str,
        window_sec: float = 5.0,
        stride_sec: float = 2.5,
    ) -> Tuple[np.ndarray, List[str]]:
        """Build and save one mean embedding per reference song.

        Each metadata entry whose ``"audio_path"`` contains ``"songs/"`` is
        resolved relative to the parent of ``songs_dir``, cut into sliding
        windows, embedded window by window, and averaged. Entries whose file
        is missing or shorter than one window are skipped with a warning.

        Args:
            songs_dir: Directory of songs; its parent is the root that
                ``"audio_path"`` values are joined to.
            metadata_path: JSON file containing a list of entries with
                ``"audio_path"`` and ``"song_id"`` keys.
            output_path: Prefix for the saved ``*_embs.npy`` and ``*_ids.npy``.
            window_sec: Window length in seconds.
            stride_sec: Step between window starts in seconds.

        Returns:
            ``(embeddings, ids)`` where ``embeddings`` has one row per song and
            ``ids`` holds the matching ``song_id`` values.

        Raises:
            ValueError: If the window or stride is not positive, or if no song
                produced an embedding.
        """
        # Loop-invariant; validated up front so a zero stride does not surface
        # later as an opaque ``range()`` error.
        win_len = int(window_sec * self.sr)
        stride = int(stride_sec * self.sr)
        if win_len <= 0 or stride <= 0:
            raise ValueError(
                "window_sec and stride_sec must each span at least one sample "
                f"(got window_sec={window_sec!r}, stride_sec={stride_sec!r})"
            )

        with open(metadata_path, encoding="utf-8") as f:
            meta = json.load(f)

        all_embs = []
        all_ids = []
        root = Path(songs_dir).parent

        for item in meta:
            rel_path = item.get("audio_path", "")
            if "songs/" not in rel_path:
                continue
            audio_path = root / rel_path
            if not audio_path.exists():
                logger.warning("Skipping missing audio file: %s", audio_path)
                continue
            song_id = item["song_id"]
            y = self._load_audio(str(audio_path))

            window_embs = [
                self._embed_wave(y[start:start + win_len])
                for start in range(0, len(y) - win_len + 1, stride)
            ]
            if not window_embs:
                logger.warning(
                    "Skipping %s (song_id=%s): audio is shorter than one "
                    "%.2f s window",
                    audio_path, song_id, window_sec,
                )
                continue

            all_embs.append(np.mean(window_embs, axis=0))
            all_ids.append(song_id)

        if not all_embs:
            # np.vstack([]) would fail with an unhelpful message; nothing is
            # written to disk in this case.
            raise ValueError(
                "No reference embeddings were produced; check that metadata "
                "audio paths exist and songs are at least window_sec long"
            )

        ref_matrix = np.vstack(all_embs)
        np.save(f"{output_path}_embs.npy", ref_matrix)
        np.save(f"{output_path}_ids.npy", np.array(all_ids))
        print(f"Built reference index: {len(all_ids)} songs, embeddings shape {ref_matrix.shape}")
        return ref_matrix, all_ids

    def search(
        self,
        query_emb: np.ndarray,
        ref_embs: np.ndarray,
        ref_ids: List[str],
        top_k: int = 10,
    ) -> List[Tuple[str, float]]:
        """Rank reference embeddings by cosine similarity to a query.

        Args:
            query_emb: Query embedding; any shape holding a single vector
                (it is flattened to 1-D).
            ref_embs: Reference embeddings, one per row. A single 1-D vector
                is treated as one row.
            ref_ids: Identifier for each row of ``ref_embs``.
            top_k: Maximum number of results; non-positive values yield ``[]``.

        Returns:
            Up to ``top_k`` ``(ref_id, score)`` pairs, best match first.
        """
        query = np.asarray(query_emb).reshape(-1)
        refs = np.asarray(ref_embs)
        if refs.ndim == 1:
            refs = refs[np.newaxis, :]
        if top_k <= 0 or refs.size == 0:
            return []

        # Small epsilon guards against division by zero for all-zero vectors.
        query_norm = query / (np.linalg.norm(query) + 1e-12)
        ref_norm = refs / (np.linalg.norm(refs, axis=1, keepdims=True) + 1e-12)
        scores = query_norm @ ref_norm.T
        top_indices = np.argsort(-scores)[:top_k]
        return [(ref_ids[i], float(scores[i])) for i in top_indices]