# chromaprint_matcher.py
"""
Simplified Chromaprint-style fingerprint matcher.

Implements landmark-based audio fingerprinting:
1. Extract spectral peaks from spectrogram
2. Build hash table from peak pairs
3. Match queries via hash lookup + time offset histogram voting
"""

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
from collections import defaultdict
from typing import Dict, List, Tuple, Optional
import pickle
import json
from pathlib import Path
import time
import wave

try:
    import librosa  # type: ignore
except ImportError:  # pragma: no cover - optional dependency
    librosa = None


def _resample_linear(y: np.ndarray, src_sr: int, target_sr: int) -> np.ndarray:
    """Linearly interpolate ``y`` from ``src_sr`` Hz onto a ``target_sr`` Hz grid.

    Both grids start at t=0 and cover the same duration (endpoint excluded).
    Output samples past the last input sample take the last input value
    (``np.interp`` clamping). The result is always float32 and never empty
    unless the input is empty.
    """
    if src_sr != target_sr and y.size:
        seconds = y.shape[0] / float(src_sr)
        n_out = max(int(round(seconds * target_sr)), 1)
        t_in = np.linspace(0.0, seconds, num=y.shape[0], endpoint=False)
        t_out = np.linspace(0.0, seconds, num=n_out, endpoint=False)
        resampled = np.interp(t_out, t_in, y)
        return resampled.astype(np.float32, copy=False)
    # Same rate, or nothing to resample: only normalise the dtype.
    return y.astype(np.float32, copy=False)


def _pcm_bytes_to_float32(raw: bytes, sample_width: int) -> np.ndarray:
    """Decode interleaved little-endian PCM bytes into float32 samples in [-1, 1).

    Supports 8-bit (unsigned), 16-, 24- and 32-bit (signed) PCM.

    Raises:
        ValueError: ``sample_width`` is not 1, 2, 3 or 4 bytes.
    """
    if sample_width == 1:
        # 8-bit WAV PCM is unsigned, centred on 128.
        y = np.frombuffer(raw, dtype=np.uint8).astype(np.float32)
        return (y - 128.0) / 128.0
    if sample_width == 2:
        # WAV data is little-endian regardless of host byte order.
        return np.frombuffer(raw, dtype="<i2").astype(np.float32) / 32768.0
    if sample_width == 3:
        # numpy has no 24-bit integer: assemble from bytes, then sign-extend.
        b = np.frombuffer(raw, dtype=np.uint8).reshape(-1, 3).astype(np.int32)
        ints = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16)
        ints = np.where(ints & 0x800000, ints - 0x1000000, ints)
        return ints.astype(np.float32) / 8388608.0
    if sample_width == 4:
        return np.frombuffer(raw, dtype="<i4").astype(np.float32) / 2147483648.0
    raise ValueError(f'unsupported wav sample width: {sample_width}')


def load_audio_mono(path: str, sr: int) -> Tuple[np.ndarray, int]:
    """Load ``path`` as mono float32 audio at sample rate ``sr``.

    Uses librosa when it is importable. Otherwise it falls back to the stdlib
    ``wave`` reader, which handles PCM WAV only (8/16/24/32-bit). The fallback
    averages the channels and linearly resamples to ``sr``.

    Returns:
        ``(samples, sr)``; the returned rate is always the requested ``sr``.

    Raises:
        ValueError: the fallback reader found an unsupported sample width.
        Errors from ``wave``/librosa (unreadable or non-PCM files) propagate.
    """
    if librosa is not None:
        y, _ = librosa.load(path, sr=sr, mono=True)
        return y.astype(np.float32, copy=False), sr

    with wave.open(path, 'rb') as wav_file:
        src_sr = wav_file.getframerate()
        channels = wav_file.getnchannels()
        sample_width = wav_file.getsampwidth()
        raw = wav_file.readframes(wav_file.getnframes())

    y = _pcm_bytes_to_float32(raw, sample_width)
    if channels > 1:
        # Frames are interleaved, so each row of the reshape is one frame.
        y = y.reshape(-1, channels).mean(axis=1)
    return _resample_linear(y, src_sr, sr), sr


class Fingerprint:
    """One index posting: hash ``hash`` seen in ``song_id`` at frame ``offset``."""

    # The index stores one instance per posting; slots avoid a __dict__ each.
    __slots__ = ("song_id", "offset", "hash")

    def __init__(self, song_id: str, offset: int, hash_val: int):
        self.song_id = song_id
        self.offset = offset
        self.hash = hash_val

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(song_id={self.song_id!r}, "
            f"offset={self.offset!r}, hash_val={self.hash!r})"
        )


class ChromaprintMatcher:
    """Landmark fingerprint index: spectral peak pairs, hashed and voted by time offset.

    Each signal is processed in four steps:

    1. It becomes a magnitude spectrogram.
    2. The spectrogram is reduced to its strongest local-maximum peaks.
    3. Every pair of peaks fewer than ``target_zone_width`` frames apart becomes
       a hash of ``(f1, f2, dt)``, anchored at the earlier peak's frame.
    4. ``match`` scores each song by how many query hashes agree on a single
       reference-minus-query frame offset.

    NOTE: saved caches store raw hash values. A cache is therefore only valid
    for queries produced with the same peak picker, hash layout and analysis
    parameters (``n_fft``, ``hop_length``, ...); rebuild it after changing any
    of them.
    """

    def __init__(
        self,
        sr: int = 16000,
        n_fft: int = 1024,
        hop_length: int = 256,
        peak_neighborhood: int = 20,
        target_zone_width: int = 50,
        min_peak_energy: float = 0.01,
        max_peaks: Optional[int] = 200,
    ):
        """Store analysis parameters and create an empty hash index.

        Args:
            sr: Sample rate that audio is loaded/resampled to.
            n_fft: STFT size; frequency bins run from 0 to ``n_fft // 2``.
            hop_length: STFT hop in samples. Offsets are measured in frames.
            peak_neighborhood: Side length, in bins and frames, of the square
                window a point must be the maximum of to count as a peak.
            target_zone_width: Two peaks are paired only when
                ``0 < dt < target_zone_width`` frames.
            min_peak_energy: Peaks must have magnitude strictly above this.
            max_peaks: Keep at most this many of the strongest peaks per
                signal (``None`` keeps all).
        """
        self.sr = sr
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.peak_neighborhood = peak_neighborhood
        self.target_zone_width = target_zone_width
        self.min_peak_energy = min_peak_energy
        self.max_peaks = max_peaks
        self.hash_db: Dict[int, List[Fingerprint]] = defaultdict(list)

    def _resolve_audio_path(self, songs_dir: Path, rel_path: str) -> Path:
        """Resolve ``rel_path`` against ``songs_dir``, then its parent.

        Returns the parent-based candidate without checking that it exists.
        """
        candidate = songs_dir / rel_path
        if candidate.exists():
            return candidate
        return songs_dir.parent / rel_path

    def _spectrogram(self, y: np.ndarray) -> np.ndarray:
        """Return the magnitude STFT of ``y``, shaped ``(n_fft // 2 + 1, n_frames)``.

        With librosa available, ``librosa.stft`` is called with its defaults.
        Per librosa's docs those defaults centre the frames and use a periodic
        Hann window. The numpy fallback uses uncentred frames and a symmetric
        Hann window, so the two backends do not produce identical frames.
        Build the index and run queries with the same backend.
        """
        if librosa is not None:
            return np.abs(librosa.stft(y, n_fft=self.n_fft, hop_length=self.hop_length))

        if y.shape[0] < self.n_fft:
            y = np.pad(y, (0, self.n_fft - y.shape[0]))
        # Frame i is y[i*hop : i*hop + n_fft], with one frame per column.
        frames = sliding_window_view(y, self.n_fft)[:: self.hop_length].T
        window = np.hanning(self.n_fft).astype(np.float32)
        return np.abs(np.fft.rfft(frames * window[:, None], axis=0))

    def _find_peaks(self, S: np.ndarray) -> List[Tuple[int, int, float]]:
        """Return the strongest local maxima of ``S`` as ``(frame, bin, magnitude)``.

        A point is a peak when it equals the maximum of the
        ``peak_neighborhood``-sized square window centred on it and exceeds
        ``min_peak_energy``. Points outside ``S`` never win, so edges and clips
        smaller than the window are handled too. Peaks are returned strongest
        first; ties keep (bin, frame) row-major order. At most ``max_peaks``
        are returned.

        Args:
            S: 2-D magnitude array indexed ``[freq_bin, frame]``.
        """
        if S.ndim != 2 or S.shape[0] == 0 or S.shape[1] == 0:
            return []
        if S.dtype.kind != "f":
            S = S.astype(np.float64)

        size = max(int(self.peak_neighborhood), 1)
        before = size // 2
        after = size - 1 - before
        padded = np.pad(
            S, ((before, after), (before, after)),
            mode="constant", constant_values=-np.inf,
        )
        # A max over a rectangle is separable: filter along frames, then along bins.
        row_max = sliding_window_view(padded, size, axis=1).max(axis=-1)
        region_max = sliding_window_view(row_max, size, axis=0).max(axis=-1)

        mask = (S == region_max) & (S > self.min_peak_energy)
        freq_idx, time_idx = np.nonzero(mask)
        mags = S[freq_idx, time_idx]
        order = np.argsort(-mags, kind="stable")[: self.max_peaks]
        return [(int(time_idx[k]), int(freq_idx[k]), float(mags[k])) for k in order]

    def _hash_layout(self) -> Tuple[int, int]:
        """Return ``(freq_bits, dt_bits)``, wide enough that packed hashes never collide."""
        freq_bits = max((self.n_fft // 2).bit_length(), 1)
        dt_bits = max((self.target_zone_width - 1).bit_length(), 1)
        return freq_bits, dt_bits

    def _hash_peaks(self, peaks: List[Tuple[int, int, float]]) -> List[Tuple[int, int]]:
        """Pair peaks forward in time and return ``(hash, anchor_frame)`` tuples.

        Every pair with ``0 < t2 - t1 < target_zone_width`` is hashed as
        ``f1 | f2 | dt``, using bit fields sized by ``_hash_layout``. The
        anchor frame is ``t1``.
        """
        freq_bits, dt_bits = self._hash_layout()
        f1_shift = freq_bits + dt_bits
        by_time = sorted(peaks, key=lambda p: (p[0], p[1]))
        hashes = []
        for i, (t1, f1, _) in enumerate(by_time):
            for t2, f2, _ in by_time[i + 1:]:
                dt = t2 - t1
                if dt >= self.target_zone_width:
                    break  # sorted by time: all remaining peaks are further away
                if dt > 0:
                    hashes.append(((f1 << f1_shift) | (f2 << dt_bits) | dt, t1))
        return hashes

    def index_song(self, song_id: str, y: np.ndarray):
        """Add one posting per hash of ``y`` to the index under ``song_id``."""
        for h, offset in self.extract_hashes(y):
            self.hash_db[h].append(Fingerprint(song_id, offset, h))

    def extract_hashes(self, y: np.ndarray) -> List[Tuple[int, int]]:
        """Compute ``(hash, anchor_frame)`` landmarks for the signal ``y``."""
        S = self._spectrogram(y)
        peaks = self._find_peaks(S)
        return self._hash_peaks(peaks)

    def _index_reference(self, songs_dir: Path, item: dict) -> bool:
        """Load and index one metadata entry; log and return False if it is skipped."""
        audio_path = self._resolve_audio_path(songs_dir, item["audio_path"])
        if not audio_path.exists():
            print(
                f"[chromaprint-index] skip missing audio: song_id={item.get('song_id')} path={audio_path}",
                flush=True,
            )
            return False
        song_id = item["song_id"]
        try:
            y, _ = load_audio_mono(str(audio_path), sr=self.sr)
        except Exception as exc:  # deliberate: one bad file must not abort the build
            print(
                f"[chromaprint-index] skip decode failure: song_id={song_id} path={audio_path} error={exc}",
                flush=True,
            )
            return False
        self.index_song(song_id, y)
        return True

    def index_songs_from_dir(
        self,
        songs_dir: str,
        metadata_path: str,
        cache_path: Optional[str] = None,
        checkpoint_every_refs: int = 0,
        progress_path: Optional[str] = None,
    ):
        """Index every ``"type": "reference"`` entry listed in a JSON metadata file.

        Args:
            songs_dir: Directory that ``audio_path`` entries are resolved against.
                The directory's parent is also tried.
            metadata_path: JSON list of dicts with ``type``, ``song_id`` and
                ``audio_path`` keys.
            cache_path: If set, the index is saved here at each checkpoint and
                at the end.
            checkpoint_every_refs: If > 0, log, save and write progress every
                this many references.
            progress_path: If set, a JSON progress report is written here at
                each checkpoint and at the end.

        Entries whose audio is missing or fails to decode are logged, counted
        in ``skipped_refs`` and skipped. A skipped entry does not suppress the
        checkpoint that falls on its index.
        """
        with open(metadata_path, encoding="utf-8") as f:
            meta = json.load(f)

        songs_dir = Path(songs_dir)
        refs = [item for item in meta if item.get("type") == "reference"]
        total_refs = len(refs)
        start_time = time.time()
        skipped_refs = 0

        progress_file = Path(progress_path) if progress_path else None
        cache_file = Path(cache_path) if cache_path else None

        def write_progress(refs_done: int, status: str):
            if progress_file is None:
                return
            elapsed = max(time.time() - start_time, 1e-6)
            refs_per_sec = refs_done / elapsed if refs_done > 0 else 0.0
            eta_sec = (total_refs - refs_done) / refs_per_sec if refs_per_sec > 0 else 0.0
            progress_file.write_text(json.dumps({
                "status": status,
                "refs_done": refs_done,
                "refs_total": total_refs,
                "elapsed_sec": round(elapsed, 3),
                "eta_sec": round(eta_sec, 3),
                "hashes": self.num_hashes,
                "postings": self.index_size,
                "skipped_refs": skipped_refs,
                "cache_path": str(cache_file) if cache_file else None,
            }, indent=2), encoding="utf-8")

        for ref_idx, item in enumerate(refs, start=1):
            if not self._index_reference(songs_dir, item):
                skipped_refs += 1
            at_checkpoint = checkpoint_every_refs > 0 and ref_idx % checkpoint_every_refs == 0
            if ref_idx == 1 or ref_idx == total_refs or at_checkpoint:
                elapsed = max(time.time() - start_time, 1e-6)
                refs_per_sec = ref_idx / elapsed
                eta_sec = (total_refs - ref_idx) / refs_per_sec if refs_per_sec > 0 else 0.0
                print(
                    f"[chromaprint-index] progress: refs={ref_idx}/{total_refs} "
                    f"hashes={self.num_hashes} postings={self.index_size} "
                    f"elapsed_sec={elapsed:.1f} eta_sec={eta_sec:.1f} skipped_refs={skipped_refs}",
                    flush=True,
                )
            if at_checkpoint:
                if cache_file is not None:
                    self.save(str(cache_file))
                write_progress(ref_idx, "building")

        if cache_file is not None:
            self.save(str(cache_file))
        write_progress(total_refs, "complete")

    def match(self, y: np.ndarray, top_k: int = 10) -> List[Tuple[str, float]]:
        """Rank indexed songs against the query signal ``y``.

        For each song, hits are histogrammed by ``reference_frame - query_frame``.
        The score is ``max_bin + 0.1 * total_hits``. At most ``top_k``
        ``(song_id, score)`` pairs are returned, best first.
        """
        song_votes: Dict[str, Dict[int, int]] = defaultdict(lambda: defaultdict(int))
        for h, q_offset in self.extract_hashes(y):
            for fp in self.hash_db.get(h, ()):
                song_votes[fp.song_id][fp.offset - q_offset] += 1

        results = []
        for song_id, deltas in song_votes.items():
            counts = deltas.values()
            results.append((song_id, max(counts) * 1.0 + sum(counts) * 0.1))

        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]

    def save(self, path: str):
        """Pickle the index to ``path`` as ``{hash: [(song_id, offset), ...]}``.

        The data is written to a sibling ``.tmp`` file and then renamed over
        ``path``. An interrupted checkpoint therefore never leaves a truncated
        cache.
        """
        data = {h: [(fp.song_id, fp.offset) for fp in fps] for h, fps in self.hash_db.items()}
        target = Path(path)
        tmp = target.with_name(target.name + ".tmp")
        with open(tmp, "wb") as f:
            pickle.dump(data, f)
        tmp.replace(target)

    def load(self, path: str):
        """Replace the in-memory index with one written by ``save``.

        SECURITY: this uses ``pickle.load``, which can execute arbitrary code.
        Only load cache files this tool produced, never untrusted input.
        """
        with open(path, "rb") as f:
            data = pickle.load(f)
        self.hash_db.clear()
        for h, items in data.items():
            self.hash_db[h] = [Fingerprint(sid, off, h) for sid, off in items]

    @property
    def index_size(self) -> int:
        """Total number of postings across all hashes."""
        return sum(len(v) for v in self.hash_db.values())

    @property
    def num_hashes(self) -> int:
        """Number of distinct hash values in the index."""
        return len(self.hash_db)