hybrid_engine.py 6.17 KB
"""Hybrid ACR Engine: Chromaprint + ECAPA + melody-aware re-ranking."""

import json
import time
from pathlib import Path
from typing import Dict, List, Optional

import librosa
import numpy as np

from src.utils.audio import AudioProcessor


class Candidate:
    """One song hypothesis produced during recognition.

    Holds the per-signal similarity scores for ``song_id`` together with a
    free-form ``metadata`` dict, and fuses the scores with caller-supplied
    weights via :meth:`combined_score`.
    """

    def __init__(self, song_id: str, chroma_score: float = 0.0, ecapa_score: float = 0.0, melody_score: float = 0.0):
        self.song_id = song_id
        self.chroma_score, self.ecapa_score, self.melody_score = chroma_score, ecapa_score, melody_score
        # Starts empty; callers attach song metadata after construction.
        self.metadata: Dict = {}

    def combined_score(self, chroma_weight: float, ecapa_weight: float, melody_weight: float) -> float:
        """Return the weighted sum of the chroma, ECAPA and melody scores.

        Terms are accumulated in chroma, ECAPA, melody order.
        """
        weighted_chroma = chroma_weight * self.chroma_score
        weighted_ecapa = ecapa_weight * self.ecapa_score
        weighted_melody = melody_weight * self.melody_score
        return weighted_chroma + weighted_ecapa + weighted_melody


class HybridEngine:
    """Hybrid song recogniser fusing Chromaprint, ECAPA and melody similarity.

    Each enabled signal yields ``(song_id, raw_score)`` pairs that are min-max
    normalised to ``[0, 1]`` independently (higher raw scores are treated as
    better).  The strongest chroma and ECAPA candidates are pooled, re-scored
    with a melody similarity against their reference audio, and ranked by
    ``chroma_weight * chroma + ecapa_weight * ecapa + melody_weight * melody``.

    Args:
        chroma_matcher: Object with ``match(y, top_k=...)`` returning
            ``(song_id, score)`` pairs, or ``None`` to disable this signal.
        ecapa_embedder: Object with ``extract_embedding_from_wave(y)``, or
            ``None`` to disable this signal.
        ref_embs: 2-D array of reference embeddings, one row per ``ref_ids``
            entry; rows are L2-normalised at query time.
        ref_ids: Song id for each row of ``ref_embs``.  Ids may repeat, in
            which case a song is scored by its best-matching row.
        sr: Sample rate used when loading query and reference audio.
        chroma_weight: Weight of the normalised chroma score.
        ecapa_weight: Weight of the normalised ECAPA score.
        melody_weight: Weight of the normalised melody score.
        reject_threshold: Minimum fused score for a result to be ``accepted``.
    """

    def __init__(
        self,
        chroma_matcher=None,
        ecapa_embedder=None,
        ref_embs: Optional[np.ndarray] = None,
        ref_ids: Optional[List[str]] = None,
        sr: int = 16000,
        chroma_weight: float = 0.25,
        ecapa_weight: float = 0.5,
        melody_weight: float = 0.25,
        reject_threshold: float = 0.35,
    ):
        self.chroma = chroma_matcher
        self.ecapa = ecapa_embedder
        self.ref_embs = ref_embs
        self.ref_ids = ref_ids
        self.sr = sr
        self.chroma_weight = chroma_weight
        self.ecapa_weight = ecapa_weight
        self.melody_weight = melody_weight
        self.reject_threshold = reject_threshold
        # song_id -> metadata dict, filled by load_metadata().
        self.song_metadata: Dict[str, Dict] = {}
        # song_id -> reference audio path used for melody re-ranking.
        self.song_audio_paths: Dict[str, str] = {}
        self.audio = AudioProcessor(sr=sr)

    def load_metadata(self, metadata_path: str) -> None:
        """Merge song metadata and reference-audio paths from a JSON list file.

        Every item needs a ``song_id``.  Items with ``type == "reference"``
        overwrite any earlier entry for the same song (falling back to the
        earlier values for missing keys); other items only create entries for
        songs not seen before.  A reference's ``audio_path`` is resolved
        relative to the metadata file's directory and registered for melody
        scoring; references without an ``audio_path`` are kept as metadata
        only instead of aborting the load.
        """
        with open(metadata_path, encoding="utf-8") as f:
            items = json.load(f)
        base_dir = Path(metadata_path).parent
        for item in items:
            sid = item["song_id"]
            is_reference = item.get("type") == "reference"
            existing = self.song_metadata.get(sid, {})
            if is_reference or not existing:
                self.song_metadata[sid] = {
                    "song_id": sid,
                    "base_freq": item.get("base_freq", existing.get("base_freq", 0)),
                    "audio_path": item.get("audio_path", existing.get("audio_path", "")),
                    "type": item.get("type", existing.get("type", "unknown")),
                }
            audio_path = item.get("audio_path")
            if is_reference and audio_path:
                self.song_audio_paths[sid] = str(base_dir / audio_path)

    @staticmethod
    def _normalize_scores(score_pairs: List[tuple]) -> Dict[str, float]:
        """Min-max normalise ``(song_id, score)`` pairs into ``{song_id: [0, 1]}``.

        If a song id occurs more than once, only its highest raw score is used,
        so a song with several reference entries is judged by its best match.
        Key order follows first appearance in ``score_pairs``.  A single song,
        or all-equal scores, map to ``1.0``; empty input gives ``{}``.
        """
        best: Dict[str, float] = {}
        for sid, score in score_pairs:
            score = float(score)
            if sid not in best or score > best[sid]:
                best[sid] = score
        if not best:
            return {}
        if len(best) == 1:
            return {sid: 1.0 for sid in best}
        ids = list(best)
        values = np.array([best[sid] for sid in ids], dtype=np.float32)
        vmin = float(values.min())
        vmax = float(values.max())
        if abs(vmax - vmin) < 1e-8:
            return {sid: 1.0 for sid in ids}
        norm = (values - vmin) / (vmax - vmin)
        return {sid: float(score) for sid, score in zip(ids, norm)}

    @staticmethod
    def _top_ids(norm_scores: Dict[str, float], limit: int) -> List[str]:
        """Return up to ``limit`` ids from ``norm_scores``, highest score first (stable on ties)."""
        return sorted(norm_scores, key=norm_scores.get, reverse=True)[:limit]

    def _ecapa_matches(self, y: np.ndarray, top_n: int) -> List[tuple]:
        """Return ``(song_id, cosine_similarity)`` pairs for the closest reference rows.

        Returns an empty list when the embedder or reference data is missing.
        At most ``max(top_n * 10, 30)`` pairs are returned, best first.
        """
        if self.ecapa is None or self.ref_embs is None or self.ref_ids is None:
            return []
        # Flatten in case the embedder returns a (1, dim) array: a 2-D query
        # would make ``scores`` 2-D and break the per-index lookup below.
        query_emb = np.ravel(self.ecapa.extract_embedding_from_wave(y))
        ref_norm = self.ref_embs / (np.linalg.norm(self.ref_embs, axis=1, keepdims=True) + 1e-12)
        query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-12)
        scores = query_norm @ ref_norm.T
        top_indices = np.argsort(-scores)[: max(top_n * 10, 30)]
        return [(self.ref_ids[idx], float(scores[idx])) for idx in top_indices]

    def _melody_scores(self, query_y: np.ndarray, candidate_ids: List[str]) -> Dict[str, float]:
        """Normalised melody similarity of ``query_y`` to each candidate's reference.

        Only the first 8 seconds of each reference file are loaded.  Candidates
        without a registered, existing reference file are skipped and therefore
        receive no melody score.
        """
        scores = []
        for song_id in candidate_ids:
            ref_path = self.song_audio_paths.get(song_id)
            if not ref_path or not Path(ref_path).exists():
                continue
            # NOTE(review): the reference is decoded from disk on every query;
            # this is likely the main latency cost — consider a bounded cache.
            ref_y, _ = librosa.load(ref_path, sr=self.sr, mono=True, duration=8.0)
            scores.append((song_id, self.audio.melody_similarity(query_y, ref_y)))
        return self._normalize_scores(scores)

    def recognize(self, audio_path: str, top_n: int = 5, mode: str = "auto") -> Dict:
        """Identify the song in ``audio_path`` and return the best ``top_n`` candidates.

        Args:
            audio_path: Query audio file, loaded mono at ``self.sr``.
            top_n: Number of candidates to return.
            mode: Accepted for interface compatibility; currently ignored.

        Returns:
            ``{"candidates": [...], "processing_time_ms": float,
            "num_candidates": int}``.  Each candidate dict has the fused
            ``confidence``, the three per-signal normalised scores, an
            ``accepted`` flag (``confidence >= reject_threshold``) and a copy of
            the song's metadata (``{}`` if unknown).  Candidates are ordered by
            descending confidence, ties broken by ``song_id``.
        """
        del mode  # no modes are implemented; kept so callers may still pass it
        start = time.perf_counter()
        y, _ = librosa.load(audio_path, sr=self.sr, mono=True)

        chroma_matches = self.chroma.match(y, top_k=max(50, top_n * 5)) if self.chroma is not None else []
        chroma_norm = self._normalize_scores(chroma_matches)
        ecapa_norm = self._normalize_scores(self._ecapa_matches(y, top_n))

        # Melody scoring loads reference audio per candidate, so only the
        # strongest ``top_n * 8`` ids from each signal are re-ranked.
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        pool_size = top_n * 8
        candidate_pool = list(
            dict.fromkeys(self._top_ids(chroma_norm, pool_size) + self._top_ids(ecapa_norm, pool_size))
        )
        melody_norm = self._melody_scores(y, candidate_pool)

        weights = (self.chroma_weight, self.ecapa_weight, self.melody_weight)
        ranked = []
        for song_id in candidate_pool:
            candidate = Candidate(
                song_id=song_id,
                chroma_score=chroma_norm.get(song_id, 0.0),
                ecapa_score=ecapa_norm.get(song_id, 0.0),
                melody_score=melody_norm.get(song_id, 0.0),
            )
            # Copy so callers mutating a result cannot alter engine state.
            candidate.metadata = dict(self.song_metadata.get(song_id, {}))
            ranked.append((candidate.combined_score(*weights), candidate))

        # Highest fused score first; song_id makes tie order reproducible.
        ranked.sort(key=lambda item: (-item[0], item[1].song_id))
        results = ranked[:top_n]
        elapsed = (time.perf_counter() - start) * 1000

        output = [
            {
                "song_id": c.song_id,
                "confidence": round(fused, 4),
                "chromaprint_score": round(c.chroma_score, 4),
                "ecapa_score": round(c.ecapa_score, 4),
                "melody_score": round(c.melody_score, 4),
                "accepted": fused >= self.reject_threshold,
                "metadata": c.metadata,
            }
            for fused, c in results
        ]

        return {"candidates": output, "processing_time_ms": round(elapsed, 1), "num_candidates": len(results)}