# hybrid_engine.py (4.18 KB)
"""
Hybrid ACR Engine: Chromaprint fast pre-filter + ECAPA-TDNN deep re-ranking.
"""

import numpy as np
import librosa
from typing import List, Tuple, Optional, Dict
from pathlib import Path
import json
import time


class Candidate:
    """A song hypothesis carrying one score per matcher.

    Args:
        song_id: Identifier of the candidate song.
        chroma_score: Chromaprint match score (0.0 when not scored).
        ecapa_score: ECAPA embedding similarity (0.0 when not scored).
        chroma_weight: Weight of ``chroma_score`` in ``combined_score``.
        ecapa_weight: Weight of ``ecapa_score`` in ``combined_score``.
            The defaults (0.3 / 0.7) equal ``HybridEngine``'s default weights.
    """

    def __init__(
        self,
        song_id: str,
        chroma_score: float = 0.0,
        ecapa_score: float = 0.0,
        chroma_weight: float = 0.3,
        ecapa_weight: float = 0.7,
    ):
        self.song_id = song_id
        self.chroma_score = chroma_score
        self.ecapa_score = ecapa_score
        self.chroma_weight = chroma_weight
        self.ecapa_weight = ecapa_weight
        # Per-song metadata attached by the caller; starts empty.
        self.metadata: Dict = {}

    @property
    def combined_score(self) -> float:
        """Weighted sum of the two matcher scores."""
        return self.chroma_weight * self.chroma_score + self.ecapa_weight * self.ecapa_score

    def __repr__(self):
        return f"Candidate({self.song_id}, chroma={self.chroma_score:.3f}, ecapa={self.ecapa_score:.3f})"


class HybridEngine:
    """Combine Chromaprint and ECAPA-TDNN similarity scores to identify a song.

    Candidates come from two sources: the Chromaprint matcher's hits and the
    ``top_n`` songs most similar to the query by ECAPA embedding. Every
    candidate is then scored by both matchers and ranked by
    ``chroma_weight * chroma_score + ecapa_weight * ecapa_score``.

    Args:
        chroma_matcher: Object exposing ``match(y, top_k=...)`` that yields
            ``(song_id, score)`` pairs; ``None`` disables Chromaprint scoring.
        ecapa_embedder: Object exposing ``extract_embedding_from_wave(y)``;
            ``None`` disables ECAPA scoring.
        ref_embs: Reference embeddings, one row per entry of ``ref_ids``.
        ref_ids: Song id for each row of ``ref_embs``. Ids may repeat; the
            best-matching row for a song is used.
        sr: Sample rate the query audio is resampled to when loaded.
        chroma_weight: Weight of the Chromaprint score in the combined score.
        ecapa_weight: Weight of the ECAPA score in the combined score.
        reject_threshold: Stored but not applied by this class.
            NOTE(review): presumably meant as a minimum confidence; confirm.
        chroma_top_k: Number of matches requested from the Chromaprint matcher.
    """

    def __init__(
        self,
        chroma_matcher=None,
        ecapa_embedder=None,
        ref_embs: Optional[np.ndarray] = None,
        ref_ids: Optional[List[str]] = None,
        sr: int = 16000,
        chroma_weight: float = 0.3,
        ecapa_weight: float = 0.7,
        reject_threshold: float = 0.4,
        chroma_top_k: int = 50,
    ):
        self.chroma = chroma_matcher
        self.ecapa = ecapa_embedder
        self.ref_embs = ref_embs
        self.ref_ids = ref_ids
        self.sr = sr
        self.chroma_weight = chroma_weight
        self.ecapa_weight = ecapa_weight
        self.reject_threshold = reject_threshold
        self.chroma_top_k = chroma_top_k

        # song_id -> {"song_id", "base_freq", "audio_path"}
        self.song_metadata: Dict[str, Dict] = {}

    def load_metadata(self, metadata_path: str):
        """Register per-song metadata from a JSON file holding a list of items.

        Each item must have a ``"song_id"``. ``"base_freq"`` defaults to ``0``
        and ``"audio_path"`` defaults to ``""``. If a song id appears more than
        once, or was already loaded, the first entry seen is kept.
        """
        with open(metadata_path, encoding="utf-8") as f:
            items = json.load(f)
        for item in items:
            sid = item["song_id"]
            if sid in self.song_metadata:
                continue  # first occurrence wins
            self.song_metadata[sid] = {
                "song_id": sid,
                "base_freq": item.get("base_freq", 0),
                "audio_path": item.get("audio_path", ""),
            }

    def _combined(self, chroma_score: float, ecapa_score: float) -> float:
        """Weighted sum of the two matcher scores using the engine's weights."""
        return self.chroma_weight * chroma_score + self.ecapa_weight * ecapa_score

    def _chroma_scores(self, y: np.ndarray) -> Dict[str, float]:
        """Return the best Chromaprint score per song id (empty when disabled)."""
        if self.chroma is None:
            return {}
        best: Dict[str, float] = {}
        for song_id, score in self.chroma.match(y, top_k=self.chroma_top_k):
            score = float(score)  # plain float keeps the output JSON-serialisable
            # Keep the highest score if the matcher reports a song more than once.
            if song_id not in best or score > best[song_id]:
                best[song_id] = score
        return best

    def _ecapa_scores(self, y: np.ndarray) -> Dict[str, float]:
        """Return the best cosine similarity per song id over all reference rows.

        Returns an empty dict when no embedder or no reference embeddings are
        configured.

        Raises:
            ValueError: if ``ref_ids`` is missing or does not have one entry
                per row of ``ref_embs``.
        """
        if self.ecapa is None or self.ref_embs is None:
            return {}
        refs = np.asarray(self.ref_embs)
        if self.ref_ids is None or len(self.ref_ids) != refs.shape[0]:
            raise ValueError("ref_ids must list one song id per row of ref_embs")
        # ravel() accepts both (D,) and (1, D) query embeddings.
        query = np.ravel(np.asarray(self.ecapa.extract_embedding_from_wave(y)))
        ref_norm = refs / (np.linalg.norm(refs, axis=1, keepdims=True) + 1e-12)
        query_norm = query / (np.linalg.norm(query) + 1e-12)
        sims = ref_norm @ query_norm
        best: Dict[str, float] = {}
        for song_id, sim in zip(self.ref_ids, sims.tolist()):
            # Several rows may share a song id; keep the best one, never a worse one.
            if song_id not in best or sim > best[song_id]:
                best[song_id] = sim
        return best

    def _rank(
        self,
        chroma_scores: Dict[str, float],
        ecapa_scores: Dict[str, float],
        top_n: int,
    ) -> List[Tuple[str, float, float, float]]:
        """Merge per-matcher scores and return the ``top_n`` best candidates.

        The candidate pool is every Chromaprint hit plus the ``top_n`` songs
        with the highest ECAPA similarity. Each candidate receives both
        scores, using 0.0 when a matcher has none for that song. A Chromaprint
        hit is therefore re-ranked with its real ECAPA similarity even when it
        falls outside the ECAPA shortlist.

        Returns:
            ``(song_id, chroma_score, ecapa_score, combined_score)`` tuples,
            best first; ties keep pool order (Chromaprint hits first).
        """
        limit = max(top_n, 0)
        shortlist = sorted(ecapa_scores, key=ecapa_scores.__getitem__, reverse=True)[:limit]
        pool = list(chroma_scores)
        pool.extend(sid for sid in shortlist if sid not in chroma_scores)

        rows = []
        for sid in pool:
            chroma = chroma_scores.get(sid, 0.0)
            ecapa = ecapa_scores.get(sid, 0.0)
            rows.append((sid, chroma, ecapa, self._combined(chroma, ecapa)))
        rows.sort(key=lambda row: row[3], reverse=True)  # stable sort
        return rows[:limit]

    def recognize(
        self,
        audio_path: str,
        top_n: int = 5,
        mode: str = "auto",
    ) -> Dict:
        """Identify the song contained in ``audio_path``.

        Args:
            audio_path: Path to the query audio; loaded as mono at ``self.sr``.
            top_n: Maximum number of candidates returned. It is also the size
                of the ECAPA shortlist.
            mode: Currently unused; accepted for API compatibility.

        Returns:
            A dict with three keys:
            ``"candidates"``: a list of dicts, best first. Each holds
            ``song_id``, ``confidence``, ``chromaprint_score``,
            ``ecapa_score`` and a copy of the song's ``metadata``.
            ``"processing_time_ms"``: elapsed time for loading and scoring.
            ``"num_candidates"``: the length of ``"candidates"``.

        Raises:
            ValueError: if ECAPA scoring is enabled but ``ref_ids`` does not
                match ``ref_embs``.
        """
        start = time.perf_counter()
        y, _ = librosa.load(audio_path, sr=self.sr, mono=True)
        rows = self._rank(self._chroma_scores(y), self._ecapa_scores(y), top_n)
        elapsed_ms = (time.perf_counter() - start) * 1000

        candidates = [
            {
                "song_id": sid,
                "confidence": round(total, 4),
                "chromaprint_score": round(chroma, 4),
                "ecapa_score": round(ecapa, 4),
                # Copy so callers cannot mutate the engine's metadata store.
                "metadata": dict(self.song_metadata.get(sid, {})),
            }
            for sid, chroma, ecapa, total in rows
        ]
        return {
            "candidates": candidates,
            "processing_time_ms": round(elapsed_ms, 1),
            "num_candidates": len(candidates),
        }