# hybrid_engine.py (4.99 KB)
"""
Hybrid ACR Engine: Chromaprint fast pre-filter + ECAPA-TDNN deep re-ranking.
"""

import json
import time
from typing import Dict, List, Optional

import librosa
import numpy as np


class Candidate:
    """A single song hypothesis holding one score per recognition engine.

    Attributes are plain and mutable: ``song_id``, ``chroma_score``,
    ``ecapa_score`` and a ``metadata`` dict that starts out empty.
    """

    def __init__(self, song_id: str, chroma_score: float = 0.0, ecapa_score: float = 0.0):
        self.song_id, self.chroma_score, self.ecapa_score = song_id, chroma_score, ecapa_score
        # Each instance gets its own dict; callers fill it in after construction.
        self.metadata: Dict = dict()

    def combined_score(self, chroma_weight: float, ecapa_weight: float) -> float:
        """Return the weighted sum of the two engine scores."""
        chroma_part = chroma_weight * self.chroma_score
        ecapa_part = ecapa_weight * self.ecapa_score
        return chroma_part + ecapa_part

    def __repr__(self):
        """Render the id and both scores (three decimals) for debugging."""
        template = "Candidate({}, chroma={:.3f}, ecapa={:.3f})"
        return template.format(self.song_id, self.chroma_score, self.ecapa_score)


class HybridEngine:
    """Fuse a fingerprint matcher's scores with embedding cosine similarity.

    Both sources are optional: a missing ``chroma_matcher`` contributes no
    candidates, and the embedding branch only runs when ``ecapa_embedder``,
    ``ref_embs`` and ``ref_ids`` are all provided. Each source's scores are
    min-max normalized independently and then combined as
    ``chroma_weight * chroma + ecapa_weight * ecapa``.
    """

    def __init__(
        self,
        chroma_matcher=None,
        ecapa_embedder=None,
        ref_embs: Optional[np.ndarray] = None,
        ref_ids: Optional[List[str]] = None,
        sr: int = 16000,
        chroma_weight: float = 0.35,
        ecapa_weight: float = 0.65,
        reject_threshold: float = 0.35,
    ):
        """Store the collaborators and fusion parameters.

        Args:
            chroma_matcher: Object exposing ``match(y, top_k=...)`` that yields
                ``(song_id, score)`` pairs, or ``None`` to disable that source.
            ecapa_embedder: Object exposing ``extract_embedding_from_wave(y)``,
                or ``None`` to disable the embedding source.
            ref_embs: 2-D array of reference embeddings, one row per entry of
                ``ref_ids``.
            ref_ids: Song id for each row of ``ref_embs``.
            sr: Sample rate the query audio is resampled to when loaded.
            chroma_weight: Weight of the normalized fingerprint score.
            ecapa_weight: Weight of the normalized embedding score.
            reject_threshold: Minimum fused score for a candidate to be
                flagged as accepted.
        """
        self.chroma = chroma_matcher
        self.ecapa = ecapa_embedder
        self.ref_embs = ref_embs
        self.ref_ids = ref_ids
        self.sr = sr
        self.chroma_weight = chroma_weight
        self.ecapa_weight = ecapa_weight
        self.reject_threshold = reject_threshold
        self.song_metadata: Dict[str, Dict] = {}

    def load_metadata(self, metadata_path: str):
        """Load per-song metadata from a JSON file containing a list of items.

        Every item must carry a ``"song_id"`` key. An item replaces the stored
        entry for its song when its ``"type"`` is ``"reference"`` or when
        nothing is stored for that song yet; other items are ignored once an
        entry exists. Missing fields fall back to the previously stored value
        and then to a default (``0``, ``""``, ``"unknown"``).

        Args:
            metadata_path: Path of the JSON file to read.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If an item has no ``"song_id"``.
        """
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(metadata_path, encoding="utf-8") as f:
            items = json.load(f)
        for item in items:
            sid = item["song_id"]
            existing = self.song_metadata.get(sid, {})
            if item.get("type") == "reference" or not existing:
                self.song_metadata[sid] = {
                    "song_id": sid,
                    "base_freq": item.get("base_freq", existing.get("base_freq", 0)),
                    "audio_path": item.get("audio_path", existing.get("audio_path", "")),
                    "type": item.get("type", existing.get("type", "unknown")),
                }

    @staticmethod
    def _normalize_scores(score_pairs: List[tuple], invert: bool = False) -> Dict[str, float]:
        """Min-max scale ``(song_id, score)`` pairs into the range [0, 1].

        When a song id occurs more than once, only its best score is kept, so
        a weaker duplicate can never overwrite a stronger match.

        Args:
            score_pairs: Pairs of ``(song_id, score)``.
            invert: Negate the scores first, for sources where a lower raw
                score means a better match.

        Returns:
            Mapping of song id to normalized score, in order of first
            appearance. Empty input gives an empty dict; a single distinct id,
            or all-equal scores, maps every id to ``1.0``.
        """
        if not score_pairs:
            return {}

        sign = -1.0 if invert else 1.0
        best: Dict[str, float] = {}
        for sid, score in score_pairs:
            value = sign * float(score)
            # After the optional sign flip, larger always means better.
            if sid not in best or value > best[sid]:
                best[sid] = value

        ids = list(best)
        if len(ids) == 1:
            return {ids[0]: 1.0}

        values = np.array([best[sid] for sid in ids], dtype=np.float32)
        vmin = float(values.min())
        vmax = float(values.max())
        if abs(vmax - vmin) < 1e-8:
            # Degenerate range: avoid dividing by ~0 and treat all as equal.
            return {sid: 1.0 for sid in ids}
        norm = (values - vmin) / (vmax - vmin)
        return {sid: float(score) for sid, score in zip(ids, norm)}

    def recognize(
        self,
        audio_path: str,
        top_n: int = 5,
        mode: str = "auto",
    ) -> Dict:
        """Identify the song in an audio file and return ranked candidates.

        Args:
            audio_path: Path of the query audio; loaded as mono at ``self.sr``.
            top_n: Maximum number of candidates to return.
            mode: Accepted for interface compatibility; currently ignored.

        Returns:
            A dict with ``"candidates"`` (list of per-song dicts holding
            ``song_id``, ``confidence``, ``chromaprint_score``,
            ``ecapa_score``, ``accepted`` and ``metadata``),
            ``"processing_time_ms"`` and ``"num_candidates"``.
        """
        del mode
        # perf_counter is monotonic, so the timing survives wall-clock changes.
        start = time.perf_counter()
        y, _ = librosa.load(audio_path, sr=self.sr, mono=True)

        chroma_matches = self.chroma.match(y, top_k=max(50, top_n * 5)) if self.chroma is not None else []
        chroma_norm = self._normalize_scores(chroma_matches)

        ecapa_matches = []
        if self.ecapa is not None and self.ref_embs is not None and self.ref_ids is not None:
            # Flatten so an embedding returned with a leading batch axis,
            # e.g. (1, D), still yields a 1-D score vector below.
            query_emb = np.asarray(self.ecapa.extract_embedding_from_wave(y)).reshape(-1)
            # Cosine similarity; the epsilon guards against zero-norm vectors.
            ref_norm = self.ref_embs / (np.linalg.norm(self.ref_embs, axis=1, keepdims=True) + 1e-12)
            query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-12)
            scores = query_norm @ ref_norm.T
            top_indices = np.argsort(-scores)[: max(top_n * 5, 20)]
            ecapa_matches = [(self.ref_ids[idx], float(scores[idx])) for idx in top_indices]
        ecapa_norm = self._normalize_scores(ecapa_matches)

        # Keep ids in matcher order (chroma first, then ECAPA-only ids) rather
        # than in a set, so equal fused scores rank the same on every run.
        song_ids = list(dict.fromkeys([*chroma_norm, *ecapa_norm]))
        scored = []
        for song_id in song_ids:
            candidate = Candidate(
                song_id=song_id,
                chroma_score=chroma_norm.get(song_id, 0.0),
                ecapa_score=ecapa_norm.get(song_id, 0.0),
            )
            candidate.metadata = self.song_metadata.get(song_id, {})
            fused = candidate.combined_score(self.chroma_weight, self.ecapa_weight)
            scored.append((fused, candidate))

        # list.sort is stable even with reverse=True, so ties keep matcher order.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        results = scored[:top_n]
        elapsed = (time.perf_counter() - start) * 1000

        output = [
            {
                "song_id": c.song_id,
                "confidence": round(fused, 4),
                "chromaprint_score": round(c.chroma_score, 4),
                "ecapa_score": round(c.ecapa_score, 4),
                "accepted": fused >= self.reject_threshold,
                "metadata": c.metadata,
            }
            for fused, c in results
        ]

        return {
            "candidates": output,
            "processing_time_ms": round(elapsed, 1),
            "num_candidates": len(results),
        }