chromaprint_matcher.py
4.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
Simplified Chromaprint-style fingerprint matcher.

Implements landmark-based audio fingerprinting:
1. Extract spectral peaks from the magnitude spectrogram.
2. Build a hash table from pairs of peaks.
3. Match queries via hash lookup and time-offset histogram voting.
"""
import numpy as np
import librosa
from collections import defaultdict
from typing import Dict, List, Tuple, Optional
import pickle
import json
from pathlib import Path
class Fingerprint:
    """A single indexed landmark: which song it belongs to and where.

    One instance is created for every hash stored in the index, so the
    class declares ``__slots__`` to avoid a per-instance ``__dict__``.

    Attributes:
        song_id: Identifier of the song the landmark was extracted from.
        offset: Time-frame index of the landmark's anchor peak.
        hash: Integer hash of the peak pair.
    """

    __slots__ = ("song_id", "offset", "hash")

    def __init__(self, song_id: str, offset: int, hash_val: int):
        self.song_id = song_id
        self.offset = offset
        self.hash = hash_val

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(song_id={self.song_id!r}, "
            f"offset={self.offset!r}, hash_val={self.hash!r})"
        )
class ChromaprintMatcher:
    """Landmark-based audio fingerprint index and matcher.

    Songs are reduced to spectral peaks, pairs of peaks are packed into
    integer hashes, and the hashes are stored in ``hash_db`` together with
    the time frame of the pair's first peak.  A query is fingerprinted the
    same way; every hash hit votes for a ``(song, time-offset delta)`` bin
    and songs are ranked by their vote counts.
    """

    def __init__(
        self,
        sr: int = 16000,
        n_fft: int = 1024,
        hop_length: int = 256,
        peak_neighborhood: int = 20,
        target_zone_width: int = 50,
        min_peak_energy: float = 0.01,
        max_peaks: Optional[int] = 200,
    ):
        """Configure the matcher.

        Args:
            sr: Sample rate used when loading audio files.
            n_fft: FFT size passed to ``librosa.stft``.
            hop_length: Hop length passed to ``librosa.stft``.
            peak_neighborhood: Side length (in bins/frames) of the square
                window a spectrogram cell must dominate to count as a peak.
            target_zone_width: Exclusive upper bound, in frames, on the time
                distance between the two peaks of a pair.
            min_peak_energy: Magnitude a peak must strictly exceed.
            max_peaks: Keep only this many strongest peaks per signal;
                ``None`` keeps all of them.

        Raises:
            ValueError: If ``peak_neighborhood`` is below 1 or ``max_peaks``
                is negative.
        """
        if peak_neighborhood < 1:
            raise ValueError("peak_neighborhood must be >= 1")
        if max_peaks is not None and max_peaks < 0:
            raise ValueError("max_peaks must be >= 0 or None")
        self.sr = sr
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.peak_neighborhood = peak_neighborhood
        self.target_zone_width = target_zone_width
        self.min_peak_energy = min_peak_energy
        self.max_peaks = max_peaks
        # hash value -> every indexed landmark that produced that hash
        self.hash_db: Dict[int, List[Fingerprint]] = defaultdict(list)

    def _spectrogram(self, y: np.ndarray) -> np.ndarray:
        """Return the STFT magnitude of ``y``."""
        return np.abs(
            librosa.stft(y, n_fft=self.n_fft, hop_length=self.hop_length)
        )

    def _find_peaks(self, S: np.ndarray) -> List[Tuple[int, int, float]]:
        """Return the strongest local maxima of spectrogram ``S``.

        A cell ``S[f, t]`` is a peak when it equals the maximum of the
        window ``S[f:f + n, t:t + n]`` (``n = peak_neighborhood``; the cell
        is the window's low-index corner) and exceeds ``min_peak_energy``.
        Cells in the last ``n`` rows or columns are never considered.

        Returns:
            ``(time_frame, freq_bin, magnitude)`` tuples sorted by
            descending magnitude, truncated to ``max_peaks`` entries.
            Empty when ``S`` is not larger than the window in both
            dimensions.
        """
        n = self.peak_neighborhood
        n_freq, n_time = S.shape
        rows, cols = n_freq - n, n_time - n
        if rows <= 0 or cols <= 0:
            return []

        # The maximum over an n-by-n window is separable: take the running
        # maximum along frequency first, then along time.
        view = np.lib.stride_tricks.sliding_window_view
        max_over_freq = view(S, n, axis=0).max(axis=-1)
        window_max = view(max_over_freq, n, axis=1).max(axis=-1)[:rows, :cols]

        anchors = S[:rows, :cols]
        is_peak = (anchors == window_max) & (anchors > self.min_peak_energy)

        # Scanning the transposed mask yields candidates time-major, then by
        # frequency, so the stable sort below breaks magnitude ties in
        # (time, frequency) order.
        t_idx, f_idx = np.nonzero(is_peak.T)
        peaks = [
            (int(t), int(f), float(S[f, t])) for t, f in zip(t_idx, f_idx)
        ]
        peaks.sort(key=lambda p: p[2], reverse=True)
        return peaks[: self.max_peaks]

    def _hash_peaks(
        self, peaks: List[Tuple[int, int, float]]
    ) -> List[Tuple[int, int]]:
        """Combine peaks into ``(hash, anchor_time)`` landmarks.

        Every ordered pair of peaks whose time distance ``dt`` satisfies
        ``0 < dt < target_zone_width`` produces one landmark, anchored at
        the earlier peak.  Peaks are paired in time order so that the set
        of landmarks does not depend on the peaks' relative magnitudes.

        Args:
            peaks: ``(time_frame, freq_bin, magnitude)`` tuples in any order.

        Returns:
            ``(hash, time_frame_of_first_peak)`` tuples.
        """
        ordered = sorted(peaks, key=lambda p: (p[0], p[1]))
        width = self.target_zone_width
        count = len(ordered)
        hashes = []
        for i, (t1, f1, _) in enumerate(ordered):
            for j in range(i + 1, count):
                t2, f2, _ = ordered[j]
                dt = t2 - t1
                if dt >= width:
                    # Time-sorted: every later peak is outside the zone too.
                    break
                if dt > 0:
                    # NOTE(review): the packed fields overlap when
                    # f2 >= 256 or dt >= 256, so distinct pairs can share a
                    # hash.  The layout is left unchanged so that
                    # previously saved indexes remain usable.
                    hashes.append(((f1 << 16) | (f2 << 8) | dt, t1))
        return hashes

    def _fingerprint(self, y: np.ndarray) -> List[Tuple[int, int]]:
        """Run the spectrogram -> peaks -> hashes pipeline on ``y``."""
        return self._hash_peaks(self._find_peaks(self._spectrogram(y)))

    def index_song(self, song_id: str, y: np.ndarray):
        """Fingerprint ``y`` and add its landmarks to the index.

        Indexing the same ``song_id`` again appends further entries; nothing
        is de-duplicated.
        """
        for h, offset in self._fingerprint(y):
            self.hash_db[h].append(Fingerprint(song_id, offset, h))

    def index_songs_from_dir(
        self, songs_dir: str, metadata_path: str, cache_path: Optional[str] = None
    ) -> int:
        """Index the songs listed in a JSON metadata file.

        Entries are skipped when their ``audio_path`` does not contain
        ``"songs"`` or when the file does not exist.  Audio paths are
        resolved relative to the parent of ``songs_dir``.

        Args:
            songs_dir: Directory whose parent anchors the metadata paths.
            metadata_path: JSON file holding a list of entries with
                ``audio_path`` and ``song_id`` keys.
            cache_path: When given, the index is saved there afterwards.

        Returns:
            Number of songs that were indexed.

        Raises:
            KeyError: If a non-skipped entry has no ``song_id``.
        """
        with open(metadata_path, encoding="utf-8") as f:
            meta = json.load(f)
        root = Path(songs_dir).parent
        indexed = 0
        for item in meta:
            rel_path = item.get("audio_path", "")
            if "songs" not in rel_path:
                continue
            audio_path = root / rel_path
            if not audio_path.exists():
                continue
            # Look the id up before decoding so a malformed entry fails fast.
            song_id = item["song_id"]
            y, _ = librosa.load(str(audio_path), sr=self.sr, mono=True)
            self.index_song(song_id, y)
            indexed += 1
        if cache_path:
            self.save(cache_path)
        return indexed

    def match(self, y: np.ndarray, top_k: int = 10) -> List[Tuple[str, float]]:
        """Rank indexed songs by similarity to the query signal ``y``.

        Returns:
            Up to ``top_k`` ``(song_id, score)`` tuples, best first.
        """
        return self._score_hashes(self._fingerprint(y), top_k)

    def _score_hashes(
        self, hashes: List[Tuple[int, int]], top_k: int
    ) -> List[Tuple[str, float]]:
        """Vote query landmarks against the index and rank the songs.

        Each hit votes for ``indexed_offset - query_offset`` within its
        song.  A song's score is its largest single-delta vote count plus
        0.1 times its total number of votes.
        """
        if top_k <= 0:
            return []
        song_votes: Dict[str, Dict[int, int]] = defaultdict(
            lambda: defaultdict(int)
        )
        for h, q_offset in hashes:
            # .get() avoids inserting empty lists into the defaultdict.
            for fp in self.hash_db.get(h, ()):
                song_votes[fp.song_id][fp.offset - q_offset] += 1

        results = []
        for song_id, deltas in song_votes.items():
            counts = deltas.values()
            results.append((song_id, max(counts) * 1.0 + sum(counts) * 0.1))
        results.sort(key=lambda r: r[1], reverse=True)
        return results[:top_k]

    def save(self, path: str):
        """Pickle the index to ``path`` as ``{hash: [(song_id, offset), ...]}``."""
        data = {
            h: [(fp.song_id, fp.offset) for fp in fps]
            for h, fps in self.hash_db.items()
        }
        with open(path, "wb") as f:
            pickle.dump(data, f)

    def load(self, path: str):
        """Replace the index with the one stored at ``path`` by ``save``.

        SECURITY: this unpickles the file, and unpickling can execute
        arbitrary code.  Only load index files from a trusted source.
        """
        with open(path, "rb") as f:
            data = pickle.load(f)
        # Clear only after a successful read so a bad file keeps the
        # current index intact.
        self.hash_db.clear()
        for h, items in data.items():
            self.hash_db[h] = [Fingerprint(sid, off, h) for sid, off in items]

    @property
    def index_size(self) -> int:
        """Total number of landmarks stored across all hashes."""
        return sum(len(v) for v in self.hash_db.values())

    @property
    def num_hashes(self) -> int:
        """Number of distinct hash values in the index."""
        return len(self.hash_db)