# chromaprint_matcher.py (4.89 KB)
"""
Simplified Chromaprint-style fingerprint matcher.

Implements landmark-based audio fingerprinting:
1. Extract spectral peaks from spectrogram
2. Build hash table from peak pairs
3. Match queries via hash lookup + time offset histogram voting
"""

import json
import pickle
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import librosa
import numpy as np
from scipy.ndimage import maximum_filter


class Fingerprint:
    """One indexed landmark hash belonging to a reference song.

    Attributes are plain and public:
        song_id: identifier of the song the hash was extracted from.
        offset: time position (spectrogram frame index) of the anchor peak.
        hash: the integer hash value of the peak pair.
    """

    def __init__(self, song_id: str, offset: int, hash_val: int):
        self.song_id = song_id
        self.offset = offset
        self.hash = hash_val

    def __repr__(self) -> str:
        # Readable representation for debugging index contents.
        return (
            f"{type(self).__name__}(song_id={self.song_id!r}, "
            f"offset={self.offset!r}, hash={self.hash!r})"
        )


class ChromaprintMatcher:
    """Landmark-based audio fingerprint index and matcher.

    Pipeline (shared by indexing and querying):
      1. magnitude STFT of the signal,
      2. local-maximum peak picking on the spectrogram,
      3. hashing of (anchor, target) peak pairs that are close in time,
      4. for queries: hash lookup plus voting on the time offset between
         the reference and the query.

    The fingerprints depend on every constructor parameter and on the hash
    bit layout, so a saved index is only valid for the configuration and
    code that produced it; rebuild the index after changing either.
    """

    def __init__(
        self,
        sr: int = 16000,
        n_fft: int = 1024,
        hop_length: int = 256,
        peak_neighborhood: int = 20,
        target_zone_width: int = 50,
        min_peak_energy: float = 0.01,
        max_peaks: Optional[int] = 200,
    ):
        """Configure the matcher with an empty index.

        Args:
            sr: sample rate audio files are resampled to when loaded.
            n_fft: STFT window size.
            hop_length: STFT hop size in samples.
            peak_neighborhood: side length (in bins/frames) of the square
                window in which a peak must be the maximum.
            target_zone_width: a peak pair is hashed only if the target is
                more than 0 and fewer than this many frames after the anchor.
            min_peak_energy: peaks must be strictly above this magnitude.
            max_peaks: keep at most this many strongest peaks per signal;
                ``None`` keeps all of them.

        Raises:
            ValueError: if ``peak_neighborhood`` is smaller than 1 or
                ``max_peaks`` is negative.
        """
        if peak_neighborhood < 1:
            raise ValueError("peak_neighborhood must be >= 1")
        if max_peaks is not None and max_peaks < 0:
            raise ValueError("max_peaks must be None or >= 0")
        self.sr = sr
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.peak_neighborhood = peak_neighborhood
        self.target_zone_width = target_zone_width
        self.min_peak_energy = min_peak_energy
        self.max_peaks = max_peaks
        self.hash_db: Dict[int, List[Fingerprint]] = defaultdict(list)

    def _resolve_audio_path(self, songs_dir: Path, rel_path: str) -> Path:
        """Return ``songs_dir / rel_path`` if it exists, else the same
        relative path resolved against the parent of ``songs_dir``.

        The fallback path is returned without checking that it exists.
        """
        candidate = songs_dir / rel_path
        if candidate.exists():
            return candidate
        return songs_dir.parent / rel_path

    def _spectrogram(self, y: np.ndarray) -> np.ndarray:
        """Return the magnitude STFT of ``y`` (frequency bins x frames)."""
        return np.abs(librosa.stft(y, n_fft=self.n_fft, hop_length=self.hop_length))

    def _find_peaks(self, S: np.ndarray) -> List[Tuple[int, int, float]]:
        """Pick the strongest local maxima of a 2-D spectrogram.

        A cell is a peak when it equals the maximum of the
        ``peak_neighborhood``-sized window centred on it and is strictly
        above ``min_peak_energy``. Windows are truncated at the borders, so
        border cells can be peaks too.

        Returns:
            ``(frame, frequency_bin, magnitude)`` tuples sorted by magnitude,
            strongest first, truncated to ``max_peaks`` entries.
        """
        if S.size == 0:
            return []
        # mode="nearest" only replicates edge values that are already inside
        # the truncated window, so it never changes the window maximum.
        local_max = maximum_filter(S, size=self.peak_neighborhood, mode="nearest")
        is_peak = (S == local_max) & (S > self.min_peak_energy)
        # Transposing makes nonzero() enumerate candidates frame by frame,
        # which keeps the order of equal-magnitude peaks deterministic.
        t_idx, f_idx = np.nonzero(is_peak.T)
        energies = S[f_idx, t_idx]
        peaks = list(zip(t_idx.tolist(), f_idx.tolist(), energies.tolist()))
        peaks.sort(key=lambda p: p[2], reverse=True)  # stable sort
        if self.max_peaks is not None:
            peaks = peaks[: self.max_peaks]
        return peaks

    def _hash_peaks(self, peaks: List[Tuple[int, int, float]]) -> List[Tuple[int, int]]:
        """Hash every (anchor, target) peak pair that is close in time.

        A pair is hashed when the target lies more than 0 and fewer than
        ``target_zone_width`` frames after the anchor, regardless of which
        of the two peaks is stronger.

        The hash packs ``f_anchor | f_target | dt`` into disjoint bit
        fields. Field widths are derived from ``n_fft`` (bins range over
        ``0..n_fft // 2``) and ``target_zone_width`` so the fields cannot
        overlap; each field is at least 8 bits wide.

        Returns:
            ``(hash, anchor_frame)`` tuples.
        """
        dt_bits = max(8, (self.target_zone_width - 1).bit_length())
        f_bits = max(8, (self.n_fft // 2).bit_length())
        f1_shift = f_bits + dt_bits

        # Pair in time order so the result does not depend on the energy
        # ranking of the peaks, and so the inner scan can stop early.
        by_time = sorted(peaks, key=lambda p: (p[0], p[1]))
        count = len(by_time)
        hashes: List[Tuple[int, int]] = []
        for i, (t1, f1, _) in enumerate(by_time):
            for j in range(i + 1, count):
                t2, f2, _ = by_time[j]
                dt = t2 - t1
                if dt >= self.target_zone_width:
                    break  # every later peak is even further away
                if dt > 0:
                    h = (int(f1) << f1_shift) | (int(f2) << dt_bits) | int(dt)
                    hashes.append((h, t1))
        return hashes

    def index_song(self, song_id: str, y: np.ndarray):
        """Fingerprint the signal ``y`` and add its hashes to the index
        under ``song_id``. Existing entries are kept, not replaced."""
        hashes = self._hash_peaks(self._find_peaks(self._spectrogram(y)))
        for h, offset in hashes:
            self.hash_db[h].append(Fingerprint(song_id, offset, h))

    def index_songs_from_dir(
        self, songs_dir: str, metadata_path: str, cache_path: Optional[str] = None
    ):
        """Index every ``"reference"`` entry listed in a JSON metadata file.

        The metadata is iterated as a sequence of mappings; entries whose
        ``"type"`` is not ``"reference"`` are ignored, the others must
        provide ``"audio_path"`` and ``"song_id"``. Entries whose audio file
        cannot be found are skipped with a warning.

        Args:
            songs_dir: directory the audio paths are relative to.
            metadata_path: path of the JSON metadata file.
            cache_path: if given (and non-empty), the index is saved there
                after all songs have been processed.
        """
        with open(metadata_path, encoding="utf-8") as f:
            meta = json.load(f)

        root = Path(songs_dir)
        for item in meta:
            if item.get("type") != "reference":
                continue
            audio_path = self._resolve_audio_path(root, item["audio_path"])
            if not audio_path.exists():
                warnings.warn(
                    f"Skipping {item.get('song_id')!r}: "
                    f"audio file not found at {audio_path}",
                    stacklevel=2,
                )
                continue
            y, _ = librosa.load(str(audio_path), sr=self.sr, mono=True)
            self.index_song(item["song_id"], y)

        if cache_path:
            self.save(cache_path)

    def _score_hashes(
        self, hashes: List[Tuple[int, int]], top_k: int
    ) -> List[Tuple[str, float]]:
        """Rank indexed songs against query ``(hash, query_frame)`` pairs."""
        # song_id -> {reference_frame - query_frame: number of matching hashes}
        song_votes: Dict[str, Dict[int, int]] = defaultdict(lambda: defaultdict(int))
        for h, q_offset in hashes:
            for fp in self.hash_db.get(h, ()):
                song_votes[fp.song_id][fp.offset - q_offset] += 1

        results = []
        for song_id, deltas in song_votes.items():
            peak_score = max(deltas.values())  # votes for the best single offset
            total_score = sum(deltas.values())  # all matching hashes
            results.append((song_id, peak_score * 1.0 + total_score * 0.1))

        results.sort(key=lambda r: r[1], reverse=True)
        # Clamp so a negative top_k yields nothing instead of slicing from the end.
        return results[: max(top_k, 0)]

    def match(self, y: np.ndarray, top_k: int = 10) -> List[Tuple[str, float]]:
        """Return up to ``top_k`` ``(song_id, score)`` pairs for the query
        signal ``y``, best first.

        The score is the size of the largest time-offset histogram bin plus
        0.1 times the total number of matching hashes for that song. Songs
        without any matching hash are not returned.
        """
        hashes = self._hash_peaks(self._find_peaks(self._spectrogram(y)))
        return self._score_hashes(hashes, top_k)

    def save(self, path: str):
        """Pickle the index to ``path`` as ``{hash: [(song_id, offset), ...]}``."""
        data = {
            h: [(fp.song_id, fp.offset) for fp in fps]
            for h, fps in self.hash_db.items()
        }
        with open(path, "wb") as f:
            pickle.dump(data, f)

    def load(self, path: str):
        """Replace the current index with the one pickled at ``path``."""
        # SECURITY: unpickling can execute arbitrary code. Only load cache
        # files that this application wrote itself.
        with open(path, "rb") as f:
            data = pickle.load(f)
        self.hash_db.clear()
        for h, items in data.items():
            self.hash_db[h] = [Fingerprint(sid, off, h) for sid, off in items]

    @property
    def index_size(self) -> int:
        """Total number of fingerprints stored across all hashes."""
        return sum(len(fps) for fps in self.hash_db.values())

    @property
    def num_hashes(self) -> int:
        """Number of distinct hash values in the index."""
        return len(self.hash_db)