# ecapa_embedder.py (11.1 KB)
import json
from pathlib import Path
from typing import List, Optional, Tuple
import time

import librosa
import numpy as np
import torch


class ECAPAEmbedder:
    """Embed audio with an ``ECAPA_ACR`` checkpoint and search a windowed reference index.

    Audio is loaded at ``sr``, cut into fixed-length windows, converted to log-mel
    spectrograms and passed through the model one window at a time.
    """

    # A progress line is printed for the first ref, the last ref and every N-th ref.
    _PROGRESS_LOG_EVERY = 250

    def __init__(
        self,
        model_path: str,
        device: str = "cpu",
        sr: int = 16000,
        n_mels: int = 80,
        n_fft: int = 512,
        hop_length: int = 160,
    ):
        """Load the checkpoint at ``model_path`` and put the model in eval mode.

        Args:
            model_path: Checkpoint readable by ``torch.load``. It must contain a
                ``"model_state_dict"`` entry and may contain a ``"config"`` dict
                with ``"model"`` and ``"data"`` sections.
            device: Torch device string the model is moved to.
            sr: Sample rate used when loading audio and computing spectrograms.
            n_mels: Fallback when the checkpoint config has no ``model.n_mels``.
            n_fft: Fallback when the checkpoint config has no ``data.n_fft``.
            hop_length: Fallback when the checkpoint config has no ``data.hop_length``.

        Raises:
            OSError: If ``model_path`` cannot be stat'ed or read.
            KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
        """
        self.device = torch.device(device)
        self.model_path = Path(model_path)
        self.sr = sr
        self.model_signature = self._build_model_signature(self.model_path)

        # Imported here rather than at module level, so importing this module
        # does not by itself import the model package.
        from src.models.ecapa_tdnn import ECAPA_ACR

        state = torch.load(model_path, map_location="cpu", weights_only=True)
        # ``or {}`` also covers a checkpoint that stores an explicit None section.
        cfg = state.get("config") or {}
        model_cfg = cfg.get("model") or {}
        data_cfg = cfg.get("data") or {}
        # Values recorded in the checkpoint take precedence over the arguments.
        # NOTE(review): the sample rate is never read from the checkpoint config;
        # confirm whether the config carries one that should override ``sr``.
        self.n_mels = model_cfg.get("n_mels", n_mels)
        self.n_fft = data_cfg.get("n_fft", n_fft)
        self.hop_length = data_cfg.get("hop_length", hop_length)
        self.model = ECAPA_ACR(
            n_mels=self.n_mels,
            embed_dim=model_cfg.get("embed_dim", 192),
            channels=model_cfg.get("channels", 512),
            se_channels=model_cfg.get("se_channels", 128),
            res2net_scale=model_cfg.get("res2net_scale", 8),
            num_blocks=model_cfg.get("num_blocks", 3),
            num_classes=None,
            use_band_split=model_cfg.get("use_band_split", True),
            band_split_channels=model_cfg.get("band_split_channels", 128),
        )
        load_result = self.model.load_state_dict(state["model_state_dict"], strict=False)
        if load_result.unexpected_keys:
            print(f"[warn] unexpected keys while loading model: {load_result.unexpected_keys}", flush=True)
        if load_result.missing_keys:
            # With strict=False these parameters silently keep their initial
            # values, so surface them instead of producing bad embeddings quietly.
            print(f"[warn] missing keys while loading model: {load_result.missing_keys}", flush=True)
        self.model.to(self.device)
        self.model.eval()

    def _load_audio(self, path: str) -> np.ndarray:
        """Load ``path`` as a mono waveform at ``self.sr``."""
        y, _ = librosa.load(path, sr=self.sr, mono=True)
        return y

    def _build_model_signature(self, model_path: Path) -> dict:
        """Return path, size and mtime of the checkpoint file.

        The result is stored in progress files and compared on resume.
        """
        stat = model_path.stat()
        return {
            "path": str(model_path),
            "size_bytes": int(stat.st_size),
            "mtime_ns": int(stat.st_mtime_ns),
        }

    def _resolve_audio_path(self, songs_dir: Path, rel_path: str) -> Path:
        """Resolve ``rel_path`` under ``songs_dir``, else under its parent.

        The returned fallback path is not checked for existence.
        """
        candidate = songs_dir / rel_path
        if candidate.exists():
            return candidate
        return songs_dir.parent / rel_path

    def _to_mel(self, y: np.ndarray) -> torch.Tensor:
        """Convert a waveform to a dB-scaled mel spectrogram tensor.

        The spectrogram is scaled relative to its own maximum (``ref=np.max``)
        and returned as float32 with one extra leading dimension of size 1
        (presumably the batch axis the model expects; verify against the model).
        """
        mel = librosa.feature.melspectrogram(
            y=y,
            sr=self.sr,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
        )
        mel = librosa.power_to_db(mel, ref=np.max)
        return torch.as_tensor(mel, dtype=torch.float32).unsqueeze(0)

    def _windows(self, y: np.ndarray, window_sec: float = 5.0, stride_sec: float = 2.5) -> List[np.ndarray]:
        """Slice ``y`` into windows of ``window_sec`` seconds every ``stride_sec`` seconds.

        A signal shorter than one window is zero-padded at the end to exactly
        one window. A trailing remainder shorter than a full window is dropped.

        Raises:
            ValueError: If the window or stride is shorter than one sample.
        """
        win_len = int(window_sec * self.sr)
        stride = int(stride_sec * self.sr)
        if win_len <= 0 or stride <= 0:
            raise ValueError(
                f"window_sec={window_sec} and stride_sec={stride_sec} must each span "
                f"at least one sample at sr={self.sr}"
            )
        if len(y) < win_len:
            y = np.pad(y, (0, win_len - len(y)))
        # After padding len(y) >= win_len, so start=0 is always produced.
        last_start = len(y) - win_len
        return [y[start : start + win_len] for start in range(0, last_start + 1, stride)]

    def _embed_segments(self, segments: List[np.ndarray]) -> List[np.ndarray]:
        """Run the model on each segment and return one flattened embedding per segment."""
        embeddings = []
        with torch.no_grad():
            for seg in segments:
                mel = self._to_mel(seg).to(self.device)
                # The model returns a pair; only the first element is used here.
                emb, _ = self.model(mel)
                embeddings.append(emb.cpu().numpy().flatten())
        return embeddings

    def extract_embedding(
        self,
        audio_path: str,
        window_sec: float = 5.0,
        stride_sec: float = 2.5,
    ) -> np.ndarray:
        """Load ``audio_path`` and return the mean of its window embeddings."""
        y = self._load_audio(audio_path)
        return self.extract_embedding_from_wave(y, window_sec=window_sec, stride_sec=stride_sec)

    def extract_embedding_from_wave(
        self,
        y: np.ndarray,
        window_sec: float = 5.0,
        stride_sec: float = 2.5,
    ) -> np.ndarray:
        """Return the mean embedding over all windows of waveform ``y``.

        Args:
            y: Waveform, expected at ``self.sr`` (not checked here).
            window_sec: Window length in seconds.
            stride_sec: Hop between window starts in seconds.
        """
        segments = self._windows(y, window_sec=window_sec, stride_sec=stride_sec)
        return np.mean(self._embed_segments(segments), axis=0)

    def _load_partial_checkpoint(
        self,
        progress_path: Path,
        partial_embs_path: Path,
        partial_ids_path: Path,
        window_sec: float,
        stride_sec: float,
    ) -> Tuple[int, List[np.ndarray], List[str]]:
        """Load and validate a partial checkpoint; raise if it cannot be trusted.

        Returns:
            ``(refs_done, embeddings, ids)`` as stored by the checkpoint.

        Raises:
            ValueError: If the checkpoint was written by a different model or
                with different window settings, or if its files disagree.
        """
        progress = json.loads(progress_path.read_text())
        saved_sig = progress.get("model_signature")
        if saved_sig and saved_sig != self.model_signature:
            raise ValueError(
                f"model signature mismatch: checkpoint={saved_sig} current={self.model_signature}"
            )
        # Windows cut with other settings must not be mixed into one index.
        for key, current in (("window_sec", window_sec), ("stride_sec", stride_sec)):
            saved = progress.get(key)
            if saved is not None and saved != current:
                raise ValueError(f"{key} mismatch: checkpoint={saved} current={current}")
        refs_done = int(progress.get("refs_done", 0) or 0)
        if refs_done < 0:
            raise ValueError(f"negative refs_done in checkpoint: {refs_done}")
        partial_embs = np.load(partial_embs_path)
        # NOTE(review): allow_pickle=True unpickles object arrays; only resume
        # from checkpoint files you trust.
        partial_ids = np.load(partial_ids_path, allow_pickle=True).tolist()
        if len(partial_embs) != len(partial_ids):
            raise ValueError(
                f"partial arrays disagree: embs={len(partial_embs)} ids={len(partial_ids)}"
            )
        # The progress file is written after the arrays. If the run died in
        # between, the arrays hold refs that refs_done does not count, and
        # resuming would embed those refs a second time.
        saved_windows = progress.get("windows_done")
        if saved_windows is not None and int(saved_windows) != len(partial_ids):
            raise ValueError(
                f"progress file is stale: windows_done={saved_windows} partial_ids={len(partial_ids)}"
            )
        return refs_done, list(partial_embs), partial_ids

    def build_reference_index(
        self,
        songs_dir: str,
        metadata_path: str,
        output_path: str,
        window_sec: float = 5.0,
        stride_sec: float = 2.5,
        checkpoint_every_refs: int = 250,
        resume: bool = False,
    ) -> Tuple[np.ndarray, List[str]]:
        """Embed every reference track window and save the index to disk.

        Metadata items whose ``"type"`` is ``"reference"`` are processed in
        file order. Files written, all prefixed with ``output_path``:
        ``_embs.npy`` / ``_ids.npy`` (final index), ``_embs.partial.npy`` /
        ``_ids.partial.npy`` (checkpoint) and ``_progress.json`` (status).

        Args:
            songs_dir: Directory against which each item's ``"audio_path"`` is resolved.
            metadata_path: JSON file holding a list of item dicts.
            output_path: Path prefix for all output files.
            window_sec: Window length in seconds.
            stride_sec: Hop between window starts in seconds.
            checkpoint_every_refs: Write a checkpoint every N refs; ``<= 0`` disables it.
            resume: Reuse a complete index if present, else continue from a
                valid partial checkpoint.

        Returns:
            ``(embeddings, ids)`` with one row and one song id per window.

        Raises:
            ValueError: If no embedding was produced at all.
        """
        with open(metadata_path, encoding="utf-8") as f:
            meta = json.load(f)

        all_embs: List[np.ndarray] = []
        all_ids: List[str] = []
        songs_dir = Path(songs_dir)
        refs = [item for item in meta if item.get("type") == "reference"]
        total_refs = len(refs)
        start_time = time.time()
        output_prefix = Path(output_path)
        progress_path = output_prefix.parent / f"{output_prefix.name}_progress.json"
        partial_embs_path = Path(f"{output_path}_embs.partial.npy")
        partial_ids_path = Path(f"{output_path}_ids.partial.npy")
        final_embs_path = Path(f"{output_path}_embs.npy")
        final_ids_path = Path(f"{output_path}_ids.npy")
        refs_done = 0

        if resume and final_embs_path.exists() and final_ids_path.exists():
            print(f"[build-reference-index] resume hit complete index: {final_embs_path} / {final_ids_path}", flush=True)
            # The index is still returned, but warn when the progress file says
            # it was built by a different checkpoint.
            try:
                saved_sig = json.loads(progress_path.read_text()).get("model_signature")
            except (OSError, ValueError, AttributeError):
                saved_sig = None
            if saved_sig and saved_sig != self.model_signature:
                print(
                    f"[warn] complete index was built with a different model: "
                    f"index={saved_sig} current={self.model_signature}",
                    flush=True,
                )
            final_embs = np.load(final_embs_path)
            # NOTE(review): allow_pickle=True unpickles object arrays; only
            # load index files you trust.
            final_ids = np.load(final_ids_path, allow_pickle=True).tolist()
            return final_embs, final_ids

        if resume and progress_path.exists() and partial_embs_path.exists() and partial_ids_path.exists():
            # Best effort: any problem with the checkpoint means a clean restart.
            try:
                refs_done, all_embs, all_ids = self._load_partial_checkpoint(
                    progress_path, partial_embs_path, partial_ids_path, window_sec, stride_sec
                )
                print(
                    f"[build-reference-index] resuming from checkpoint: refs_done={refs_done}/{total_refs} "
                    f"windows_done={len(all_ids)}"
                , flush=True)
            except Exception as exc:
                print(f"[build-reference-index] resume checkpoint ignored due to load failure: {exc}", flush=True)
                refs_done = 0
                all_embs = []
                all_ids = []
                for stale_path in (partial_embs_path, partial_ids_path):
                    try:
                        if stale_path.exists():
                            stale_path.unlink()
                    except OSError:
                        # Cleanup is optional; the next checkpoint overwrites the files.
                        pass

        print(
            f"[build-reference-index] start: refs={total_refs} device={self.device.type} "
            f"window_sec={window_sec} stride_sec={stride_sec} resume={resume} refs_done={refs_done}"
        , flush=True)

        if refs_done > total_refs:
            print(f"[build-reference-index] resume refs_done={refs_done} exceeds refs_total={total_refs}; restarting", flush=True)
            refs_done = 0
            all_embs = []
            all_ids = []

        # Fail at the first write, not after hours of embedding, if the
        # output directory is missing.
        output_prefix.parent.mkdir(parents=True, exist_ok=True)

        resumed_from = refs_done

        def timing(ref_idx: int) -> Tuple[float, float]:
            """Return (elapsed_sec, eta_sec) using only refs processed in this run."""
            elapsed = max(time.time() - start_time, 1e-6)
            refs_per_sec = (ref_idx - resumed_from) / elapsed
            eta_sec = (total_refs - ref_idx) / refs_per_sec if refs_per_sec > 0 else 0.0
            return elapsed, eta_sec

        def write_checkpoint(ref_idx: int) -> None:
            """Persist the embeddings collected so far plus a progress record."""
            if not all_embs:
                return
            elapsed, eta_sec = timing(ref_idx)
            np.save(partial_embs_path, np.vstack(all_embs))
            np.save(partial_ids_path, np.array(all_ids))
            # Written last: resume compares windows_done with the arrays to
            # detect a run that died between these writes.
            progress_path.write_text(json.dumps({
                "status": "building",
                "refs_done": ref_idx,
                "refs_total": total_refs,
                "windows_done": len(all_ids),
                "elapsed_sec": round(elapsed, 3),
                "eta_sec": round(eta_sec, 3),
                "device": self.device.type,
                "window_sec": window_sec,
                "stride_sec": stride_sec,
                "model_signature": self.model_signature,
                "partial_embs_path": str(partial_embs_path),
                "partial_ids_path": str(partial_ids_path),
            }, indent=2))

        def write_complete(total_windows: int, emb_shape: Tuple[int, ...]) -> None:
            """Mark the build as finished in the progress file."""
            elapsed = max(time.time() - start_time, 1e-6)
            progress_path.write_text(json.dumps({
                "status": "complete",
                "refs_done": total_refs,
                "refs_total": total_refs,
                "windows_done": total_windows,
                "elapsed_sec": round(elapsed, 3),
                "device": self.device.type,
                "window_sec": window_sec,
                "stride_sec": stride_sec,
                "model_signature": self.model_signature,
                "final_embs_path": str(final_embs_path),
                "final_ids_path": str(final_ids_path),
                "embedding_shape": list(emb_shape),
            }, indent=2))

        skipped = 0
        for ref_idx, item in enumerate(refs[refs_done:], start=refs_done + 1):
            audio_path = self._resolve_audio_path(songs_dir, item["audio_path"])
            if audio_path.exists():
                y = self._load_audio(str(audio_path))
                segments = self._windows(y, window_sec=window_sec, stride_sec=stride_sec)
                song_embs = self._embed_segments(segments)
                all_embs.extend(song_embs)
                all_ids.extend([item["song_id"]] * len(song_embs))
            else:
                # A missing file is skipped, but progress and checkpoint
                # bookkeeping below still run for this ref.
                skipped += 1
                print(f"[build-reference-index] skipping missing audio: {audio_path}", flush=True)

            is_last = ref_idx == total_refs
            if ref_idx == 1 or ref_idx % self._PROGRESS_LOG_EVERY == 0 or is_last:
                elapsed, eta_sec = timing(ref_idx)
                print(
                    f"[build-reference-index] progress: refs={ref_idx}/{total_refs} "
                    f"windows={len(all_ids)} elapsed_sec={elapsed:.1f} eta_sec={eta_sec:.1f}"
                , flush=True)
            if checkpoint_every_refs > 0 and (ref_idx % checkpoint_every_refs == 0 or is_last):
                write_checkpoint(ref_idx)

        if skipped:
            print(f"[build-reference-index] skipped {skipped} refs with missing audio", flush=True)

        if not all_embs:
            raise ValueError(
                f"No reference embeddings were produced from metadata={metadata_path} songs_dir={songs_dir}"
            )

        emb_matrix = np.vstack(all_embs)
        np.save(final_embs_path, emb_matrix)
        np.save(final_ids_path, np.array(all_ids))
        write_complete(len(all_ids), emb_matrix.shape)
        print(f"Built reference index: {len(all_ids)} windows, embeddings shape {emb_matrix.shape}", flush=True)
        return emb_matrix, all_ids

    def search(
        self,
        query_emb: np.ndarray,
        ref_embs: np.ndarray,
        ref_ids: List[str],
        top_k: int = 10,
    ) -> List[Tuple[str, float]]:
        """Rank reference rows by cosine similarity to ``query_emb``.

        Args:
            query_emb: One query embedding; any shape is flattened to 1-D, so a
                ``(1, D)`` array is accepted as well as ``(D,)``.
            ref_embs: Reference embeddings, one row per entry of ``ref_ids``.
            ref_ids: Identifier for each row of ``ref_embs``.
            top_k: Maximum number of results; values ``<= 0`` return an empty list.

        Returns:
            ``(ref_id, score)`` pairs, highest score first.
        """
        refs = np.atleast_2d(np.asarray(ref_embs))
        if top_k <= 0 or refs.size == 0:
            return []
        query = np.asarray(query_emb).reshape(-1)
        # The small epsilon keeps all-zero vectors from dividing by zero.
        query_unit = query / (np.linalg.norm(query) + 1e-12)
        refs_unit = refs / (np.linalg.norm(refs, axis=1, keepdims=True) + 1e-12)
        scores = query_unit @ refs_unit.T
        best = np.argsort(-scores)[:top_k]
        return [(ref_ids[i], float(scores[i])) for i in best]