Commit 44d8268c 44d8268ccb015842859f040e802698bfa566c5c2 by cnb.bofCdSsphPA

Make the ACR prototype explainable and runnable

Add missing project documentation and a minimal executable demo flow so the repository can be understood and validated end to end.

Constraint: The existing repo had design fragments but no verified runnable path
Rejected: Delay documentation until after full productization | would keep scope opaque and slow iteration
Confidence: medium
Scope-risk: moderate
Directive: Keep future stages checkpointed with changelog entries and runnable verification commands
Tested: synthetic dataset generation; train.py --dry-run; 1 epoch CPU training; index build; recognition JSON output
Not-tested: production-scale retrieval; real copyrighted audio; API serving
1 parent e25a16be
1 # ACR Engine
2
3 一个可运行的听歌识曲原型,包含:
4
5 - 合成数据集生成
6 - 传统音频指纹(landmark hash)匹配
7 - 深度 embedding 检索(ECAPA-TDNN)
8 - Hybrid 混合识别入口
9
10 ## 快速开始
11
12 ```bash
13 cd acr-engine
14 python -m venv .venv
15 source .venv/bin/activate
16 pip install -r requirements.txt
17 python run_demo.py full-demo
18 ```
19
20 ## 常用命令
21
22 ### 1. 生成合成数据
23 ```bash
24 python run_demo.py generate-data --output data/synthetic --num-songs 24
25 ```
26
27 ### 2. 训练前做干跑校验
28 ```bash
29 python train.py --data data/synthetic --dry-run --device cpu
30 ```
31
32 ### 3. 训练一个最小模型
33 ```bash
34 python train.py --data data/synthetic --output data/models --device cpu --epochs 1 --batch-size 8
35 ```
36
37 ### 4. 构建指纹与 embedding 索引
38 ```bash
39 python run_demo.py build-index --data data/synthetic --model data/models/best_model.pt --output data/index
40 ```
41
42 ### 5. 跑识别
43 ```bash
44 python run_demo.py recognize \
45 --query data/synthetic/segments/song_0020_seg_00.wav \
46 --data data/synthetic \
47 --model data/models/best_model.pt \
48 --index-prefix data/index/reference
49 ```
50
51 ### 6. 一键最小闭环
52 ```bash
53 python run_demo.py full-demo --device cpu
54 ```
55
56 ## 目录
57
58 - `train.py`:训练入口
59 - `run_demo.py`:数据生成 / 建索引 / 识别 / 一键 demo
60 - `src/data`:数据集和合成数据生成
61 - `src/models`:ECAPA 模型与损失
62 - `src/engines`:指纹、embedding、hybrid 检索
63 - `configs/default.yaml`:默认配置
64
65 ## 当前定位
66
67 这是一个**原型仓库**,目标是验证 ACR 主链路能否跑通,不是生产级服务。
1 model:
2 name: ecapa_tdnn
3 embed_dim: 192
4 channels: 512
5 se_channels: 128
6 res2net_scale: 8
7 num_blocks: 3
8 n_mels: 80
9 aam_m: 0.3
10 aam_s: 30.0
11
12 data:
13 sample_rate: 16000
14 n_fft: 512
15 hop_length: 160
16 segment_dur: 5.0
17 crop_per_song: 4
18
19 training:
20 batch_size: 32
21 epochs: 50
22 lr: 0.001
23 weight_decay: 0.0001
24 warmup_epochs: 5
25 temperature: 0.07
26 supcon_weight: 1.0
27 aam_weight: 0.3
28 mixed_precision: true
29 gradient_clip: 1.0
30 save_every: 10
31 log_every: 10
32
33 engine:
34 chromaprint:
35 enabled: true
36 n_fft: 1024
37 hop_length: 256
38 hybrid:
39 chroma_weight: 0.3
40 ecapa_weight: 0.7
41 reject_threshold: 0.4
1 numpy>=1.26
2 PyYAML>=6.0
3 soundfile>=0.12
4 librosa>=0.10
5 tqdm>=4.66
6 torch>=2.3
1 #!/usr/bin/env python3
2 import argparse
3 import json
4 import sys
5 from pathlib import Path
6
7 import numpy as np
8
9 ROOT = Path(__file__).parent
10 sys.path.insert(0, str(ROOT))
11
12 from src.data.synthetic import generate_dataset
13 from src.engines.chromaprint_matcher import ChromaprintMatcher
14 from src.engines.ecapa_embedder import ECAPAEmbedder
15 from src.engines.hybrid_engine import HybridEngine
16
17
18 def cmd_generate_data(args):
19 generate_dataset(
20 output_dir=args.output,
21 num_songs=args.num_songs,
22 song_duration=args.song_duration,
23 num_segments_per_song=args.num_segments,
24 segment_duration=args.segment_duration,
25 seed=args.seed,
26 )
27 print(f"[done] dataset generated at {args.output}")
28
29
30 def build_chroma_index(data_dir: Path, output_dir: Path):
31 matcher = ChromaprintMatcher()
32 matcher.index_songs_from_dir(
33 songs_dir=str(data_dir / 'songs'),
34 metadata_path=str(data_dir / 'train.json'),
35 cache_path=str(output_dir / 'chromaprint.pkl'),
36 )
37 print(f"[done] chromaprint index built: hashes={matcher.num_hashes}, postings={matcher.index_size}")
38 return matcher
39
40
41 def build_embedding_index(data_dir: Path, model_path: Path, output_prefix: Path, device: str):
42 embedder = ECAPAEmbedder(model_path=str(model_path), device=device)
43 ref_embs, ref_ids = embedder.build_reference_index(
44 songs_dir=str(data_dir / 'songs'),
45 metadata_path=str(data_dir / 'train.json'),
46 output_path=str(output_prefix),
47 )
48 print(f"[done] embedding index built: {len(ref_ids)} refs")
49 return embedder, ref_embs, ref_ids
50
51
52 def cmd_build_index(args):
53 data_dir = Path(args.data)
54 out_dir = Path(args.output)
55 out_dir.mkdir(parents=True, exist_ok=True)
56
57 build_chroma_index(data_dir, out_dir)
58 build_embedding_index(data_dir, Path(args.model), out_dir / 'reference', args.device)
59
60
61 def load_index(prefix: Path):
62 ref_embs = np.load(f"{prefix}_embs.npy")
63 ref_ids = np.load(f"{prefix}_ids.npy", allow_pickle=True).tolist()
64 return ref_embs, ref_ids
65
66
67 def cmd_recognize(args):
68 data_dir = Path(args.data)
69 matcher = ChromaprintMatcher()
70 matcher.load(str(Path(args.index_prefix).parent / 'chromaprint.pkl'))
71 embedder = ECAPAEmbedder(model_path=args.model, device=args.device)
72 ref_embs, ref_ids = load_index(Path(args.index_prefix))
73
74 engine = HybridEngine(
75 chroma_matcher=matcher,
76 ecapa_embedder=embedder,
77 ref_embs=ref_embs,
78 ref_ids=ref_ids,
79 )
80 for split in ['train.json', 'val.json', 'test.json']:
81 p = data_dir / split
82 if p.exists():
83 engine.load_metadata(str(p))
84
85 result = engine.recognize(args.query, top_n=args.top_n)
86 print(json.dumps(result, ensure_ascii=False, indent=2))
87
88
89 def cmd_full_demo(args):
90 data_dir = Path(args.data)
91 model_dir = Path(args.model_dir)
92 index_dir = Path(args.index_dir)
93
94 if not data_dir.exists() or not (data_dir / 'train.json').exists():
95 generate_dataset(
96 output_dir=str(data_dir),
97 num_songs=args.num_songs,
98 song_duration=args.song_duration,
99 num_segments_per_song=args.num_segments,
100 segment_duration=args.segment_duration,
101 seed=args.seed,
102 )
103 print(f"[done] dataset generated at {data_dir}")
104
105 model_path = model_dir / 'best_model.pt'
106 if not model_path.exists():
107 raise SystemExit(
108 'full-demo requires a trained model at data/models/best_model.pt. '\
109 'Run train.py first or provide one.'
110 )
111
112 index_dir.mkdir(parents=True, exist_ok=True)
113 matcher = build_chroma_index(data_dir, index_dir)
114 embedder, ref_embs, ref_ids = build_embedding_index(data_dir, model_path, index_dir / 'reference', args.device)
115
116 query = sorted((data_dir / 'test.json').read_text() and [] )
117 with open(data_dir / 'test.json') as f:
118 test_meta = json.load(f)
119 query_item = next((x for x in test_meta if 'segments/' in x['audio_path']), test_meta[0])
120 query_path = data_dir / query_item['audio_path']
121
122 engine = HybridEngine(matcher, embedder, ref_embs, ref_ids)
123 for split in ['train.json', 'val.json', 'test.json']:
124 engine.load_metadata(str(data_dir / split))
125 result = engine.recognize(str(query_path), top_n=5)
126 print('[demo-query]', query_item['song_id'], query_item['audio_path'])
127 print(json.dumps(result, ensure_ascii=False, indent=2))
128
129
130 if __name__ == '__main__':
131 parser = argparse.ArgumentParser(description='ACR demo utilities')
132 sub = parser.add_subparsers(dest='cmd', required=True)
133
134 p = sub.add_parser('generate-data')
135 p.add_argument('--output', default='data/synthetic')
136 p.add_argument('--num-songs', type=int, default=24)
137 p.add_argument('--song-duration', type=float, default=20.0)
138 p.add_argument('--num-segments', type=int, default=4)
139 p.add_argument('--segment-duration', type=float, default=5.0)
140 p.add_argument('--seed', type=int, default=42)
141 p.set_defaults(func=cmd_generate_data)
142
143 p = sub.add_parser('build-index')
144 p.add_argument('--data', default='data/synthetic')
145 p.add_argument('--model', required=True)
146 p.add_argument('--output', default='data/index')
147 p.add_argument('--device', default='cpu')
148 p.set_defaults(func=cmd_build_index)
149
150 p = sub.add_parser('recognize')
151 p.add_argument('--query', required=True)
152 p.add_argument('--data', default='data/synthetic')
153 p.add_argument('--model', required=True)
154 p.add_argument('--index-prefix', default='data/index/reference')
155 p.add_argument('--top-n', type=int, default=5)
156 p.add_argument('--device', default='cpu')
157 p.set_defaults(func=cmd_recognize)
158
159 p = sub.add_parser('full-demo')
160 p.add_argument('--data', default='data/synthetic')
161 p.add_argument('--model-dir', default='data/models')
162 p.add_argument('--index-dir', default='data/index')
163 p.add_argument('--num-songs', type=int, default=24)
164 p.add_argument('--song-duration', type=float, default=20.0)
165 p.add_argument('--num-segments', type=int, default=4)
166 p.add_argument('--segment-duration', type=float, default=5.0)
167 p.add_argument('--seed', type=int, default=42)
168 p.add_argument('--device', default='cpu')
169 p.set_defaults(func=cmd_full_demo)
170
171 args = parser.parse_args()
172 args.func(args)
1 import torch
2 from torch.utils.data import Dataset
3 import numpy as np
4 import librosa
5 import random
6 from pathlib import Path
7 from typing import Dict, List, Tuple
8 import json
9 import os
10
11
12 class ACRDataset(Dataset):
13 def __init__(
14 self,
15 data_dir: str,
16 split: str = "train",
17 sr: int = 16000,
18 n_mels: int = 80,
19 n_fft: int = 512,
20 hop_length: int = 160,
21 segment_dur: float = 5.0,
22 augment: bool = True,
23 n_crops_per_song: int = 4,
24 ):
25 self.sr = sr
26 self.n_mels = n_mels
27 self.n_fft = n_fft
28 self.hop_length = hop_length
29 self.segment_len = int(segment_dur * sr)
30 self.augment = augment
31 self.n_crops = n_crops_per_song
32 self.data_dir = Path(data_dir)
33
34 meta_path = Path(data_dir) / f"{split}.json"
35 with open(meta_path) as f:
36 self.metadata = json.load(f)
37
38 self.samples = []
39 for item in self.metadata:
40 song_path = Path(data_dir) / item["audio_path"]
41 if song_path.exists():
42 self.samples.append(item)
43 self.song_ids = sorted(set(s["song_id"] for s in self.samples))
44 self.song_to_idx = {sid: i for i, sid in enumerate(self.song_ids)}
45
46 def __len__(self):
47 return len(self.samples) * self.n_crops
48
49 def _load_segment(self, path: str, offset: float, duration: float) -> np.ndarray:
50 y, _ = librosa.load(
51 path, sr=self.sr, mono=True,
52 offset=offset, duration=duration
53 )
54 if len(y) < self.segment_len:
55 y = np.pad(y, (0, self.segment_len - len(y)))
56 else:
57 y = y[:self.segment_len]
58 return y
59
60 def _to_mel(self, y: np.ndarray) -> np.ndarray:
61 mel = librosa.feature.melspectrogram(
62 y=y, sr=self.sr, n_mels=self.n_mels,
63 n_fft=self.n_fft, hop_length=self.hop_length
64 )
65 return librosa.power_to_db(mel, ref=np.max)
66
67 def __getitem__(self, idx):
68 sample = self.samples[idx // self.n_crops]
69 duration = sample["duration"]
70 max_offset = max(0, duration - 5.0)
71 offset = random.uniform(0, max_offset) if max_offset > 0 else 0
72
73 audio_path = self.data_dir / sample["audio_path"]
74 y = self._load_segment(str(audio_path), offset, 5.0)
75
76 if self.augment:
77 from src.utils.augment import AugmentPipeline
78 aug = AugmentPipeline(self.sr)
79 y = aug(y)
80
81 mel = self._to_mel(y)
82 mel_tensor = torch.FloatTensor(mel)
83
84 song_id = sample["song_id"]
85 class_id = self.song_to_idx[song_id]
86
87 return {
88 "mel": mel_tensor,
89 "song_id": torch.tensor(class_id, dtype=torch.long),
90 "song_name": song_id,
91 }
92
93
94 class ACRTestDataset(Dataset):
95 def __init__(
96 self,
97 data_dir: str,
98 split: str = "test",
99 sr: int = 16000,
100 n_mels: int = 80,
101 n_fft: int = 512,
102 hop_length: int = 160,
103 ):
104 self.sr = sr
105 self.n_mels = n_mels
106 self.n_fft = n_fft
107 self.hop_length = hop_length
108 self.data_dir = Path(data_dir)
109
110 meta_path = Path(data_dir) / f"{split}.json"
111 with open(meta_path) as f:
112 self.metadata = json.load(f)
113
114 self.samples = []
115 for item in self.metadata:
116 p = Path(data_dir) / item["audio_path"]
117 if p.exists():
118 self.samples.append(item)
119
120 self.song_ids = sorted(set(s["song_id"] for s in self.samples))
121 self.song_to_idx = {sid: i for i, sid in enumerate(self.song_ids)}
122
123 def __len__(self):
124 return len(self.samples)
125
126 def __getitem__(self, idx):
127 sample = self.samples[idx]
128 audio_path = self.data_dir / sample["audio_path"]
129 y, _ = librosa.load(
130 str(audio_path), sr=self.sr, mono=True,
131 offset=0, duration=min(sample["duration"], 5.0)
132 )
133 seg_len = 5 * self.sr
134 if len(y) < seg_len:
135 y = np.pad(y, (0, seg_len - len(y)))
136 else:
137 y = y[:seg_len]
138
139 mel = librosa.power_to_db(
140 librosa.feature.melspectrogram(y=y, sr=self.sr, n_mels=self.n_mels,
141 n_fft=self.n_fft, hop_length=self.hop_length),
142 ref=np.max
143 )
144 class_id = self.song_to_idx[sample["song_id"]]
145 return {
146 "mel": torch.FloatTensor(mel),
147 "song_id": torch.tensor(class_id, dtype=torch.long),
148 "song_name": sample["song_id"],
149 }
1 """
2 Synthetic audio dataset generator for ACR demo.
3
4 Generates melodies from fundamental frequencies, simulates:
5 - Different "songs" (unique note sequences at different base frequencies)
6 - Song fragments (random crops from songs)
7 - Humming variants (pitch shifted, time stretched versions)
8
9 This allows the full pipeline to be validated without external data.
10 """
11
12 import numpy as np
13 import soundfile as sf
14 import json
15 import random
16 import os
17 from pathlib import Path
18 from typing import List, Tuple
19 from tqdm import tqdm
20
21
22 _SR = 16000
23
24
25 def sine_wave(freq: float, duration: float, sr: int = _SR, amp: float = 0.5) -> np.ndarray:
26 t = np.linspace(0, duration, int(sr * duration), endpoint=False)
27 return amp * np.sin(2 * np.pi * freq * t)
28
29
30 def harmonic_tone(freq: float, duration: float, sr: int = _SR, n_harmonics: int = 4) -> np.ndarray:
31 t = np.linspace(0, duration, int(sr * duration), endpoint=False)
32 y = np.zeros_like(t)
33 for h in range(1, n_harmonics + 1):
34 amp = 0.5 / h
35 y += amp * np.sin(2 * np.pi * freq * h * t)
36 return y / np.max(np.abs(y)) * 0.5
37
38
39 def generate_melody(
40 base_freq: float,
41 note_count: int = 16,
42 note_dur: float = 0.5,
43 sr: int = _SR,
44 timbre: str = "harmonic",
45 ) -> np.ndarray:
46 notes = []
47 freq = base_freq
48 for i in range(note_count):
49 interval = random.choice([0, 2, 4, 5, 7, 9, 11, 12]) # diatonic intervals
50 freq = base_freq * (2 ** (interval / 12))
51 dur = note_dur * random.uniform(0.8, 1.2)
52
53 if timbre == "sine":
54 note = sine_wave(freq, dur, sr)
55 else:
56 note = harmonic_tone(freq, dur, sr)
57
58 if random.random() < 0.15:
59 fade = np.linspace(0, 1, min(int(sr * 0.02), len(note)))
60 note[:len(fade)] *= fade
61
62 notes.append(note)
63
64 return np.concatenate(notes)
65
66
67 _CHORD_PROGRESSIONS = [
68 [0, 3, 7], # Cm
69 [0, 4, 7], # C
70 [0, 3, 7, 10], # Cm7
71 [0, 4, 7, 11], # Cmaj7
72 [0, 4, 9], # Csus4 → C
73 [0, 5, 7], # Csus2
74 ]
75
76
77 def generate_song(
78 song_id: str,
79 base_freq: float,
80 duration: float = 30.0,
81 sr: int = _SR,
82 with_vocals: bool = True,
83 ) -> Tuple[np.ndarray, float]:
84 segments_per_sec = 2
85 total_segments = int(duration * segments_per_sec)
86 y = np.zeros(int(sr * duration))
87
88 for i in range(total_segments):
89 t_start = i / segments_per_sec
90 t_end = (i + 1) / segments_per_sec
91 start_sample = int(t_start * sr)
92 end_sample = int(t_end * sr)
93 seg_len = end_sample - start_sample
94
95 chord = random.choice(_CHORD_PROGRESSIONS)
96 for interval in chord:
97 freq = base_freq * (2 ** (interval / 12))
98 env = np.exp(-np.linspace(0, 3, seg_len))
99 note = harmonic_tone(freq, seg_len / sr, sr) * env * 0.3
100 min_len = min(seg_len, len(note))
101 y[start_sample:start_sample + min_len] += note[:min_len]
102
103 if with_vocals:
104 melody = generate_melody(base_freq * 2, note_count=int(duration * 2), note_dur=0.5, sr=sr)
105 min_len = min(len(y), len(melody))
106 y[:min_len] += melody[:min_len] * 0.2
107
108 peak = np.max(np.abs(y))
109 if peak > 0:
110 y = y / peak * 0.5
111
112 return y, duration
113
114
115 def generate_dataset(
116 output_dir: str,
117 num_songs: int = 50,
118 song_duration: float = 30.0,
119 num_segments_per_song: int = 6,
120 segment_duration: float = 5.0,
121 sr: int = _SR,
122 seed: int = 42,
123 ):
124 random.seed(seed)
125 np.random.seed(seed)
126
127 output_dir = Path(output_dir)
128 songs_dir = output_dir / "songs"
129 segs_dir = output_dir / "segments"
130 songs_dir.mkdir(parents=True, exist_ok=True)
131 segs_dir.mkdir(parents=True, exist_ok=True)
132
133 base_freqs = [130.81, 146.83, 164.81, 174.61, 196.0, 220.0, 246.94,
134 261.63, 293.66, 329.63, 349.23, 392.0, 440.0, 493.88,
135 523.25, 587.33, 659.25, 698.46, 783.99, 880.0, 987.77]
136
137 train_meta = []
138 val_meta = []
139 test_meta = []
140
141 print(f"Generating {num_songs} synthetic songs...")
142 for i in tqdm(range(num_songs)):
143 song_id = f"song_{i:04d}"
144 base_freq = base_freqs[i % len(base_freqs)]
145 key_offset = (i // len(base_freqs)) * 2
146 base_freq *= (2 ** (key_offset / 12))
147
148 y, dur = generate_song(song_id, base_freq, duration=song_duration, sr=sr)
149 song_path = songs_dir / f"{song_id}.wav"
150 sf.write(str(song_path), y, sr)
151
152 for j in range(num_segments_per_song):
153 max_offset = max(0, dur - segment_duration)
154 offset = random.uniform(0, max_offset)
155 start_s = int(offset * sr)
156 end_s = start_s + int(segment_duration * sr)
157 seg = y[start_s:end_s]
158
159 if len(seg) < int(segment_duration * sr):
160 seg = np.pad(seg, (0, int(segment_duration * sr) - len(seg)))
161
162 is_augmented = (j >= num_segments_per_song // 2)
163
164 if is_augmented:
165 from src.utils.augment import AugmentPipeline
166 aug = AugmentPipeline(sr)
167 seg_aug = aug(seg.copy())
168 seg_name = f"{song_id}_seg_{j:02d}_aug.wav"
169 seg_path = segs_dir / seg_name
170 sf.write(str(seg_path), seg_aug, sr)
171 meta_entry = {
172 "song_id": song_id,
173 "audio_path": f"segments/{seg_name}",
174 "duration": segment_duration,
175 "type": "augmented",
176 "offset": offset,
177 }
178 else:
179 seg_name = f"{song_id}_seg_{j:02d}.wav"
180 seg_path = segs_dir / seg_name
181 sf.write(str(seg_path), seg, sr)
182 meta_entry = {
183 "song_id": song_id,
184 "audio_path": f"segments/{seg_name}",
185 "duration": segment_duration,
186 "type": "clean",
187 "offset": offset,
188 }
189
190 offset_sec = offset
191 if offset_sec < dur * 0.2:
192 seg_type = "intro"
193 elif offset_sec > dur * 0.7:
194 seg_type = "outro"
195 else:
196 seg_type = "mid"
197 meta_entry["segment_type"] = seg_type
198
199 if i < int(num_songs * 0.7):
200 train_meta.append(meta_entry)
201 elif i < int(num_songs * 0.85):
202 val_meta.append(meta_entry)
203 else:
204 test_meta.append(meta_entry)
205
206 song_meta = {
207 "song_id": song_id,
208 "audio_path": f"songs/{song_id}.wav",
209 "duration": dur,
210 "base_freq": base_freq,
211 }
212 if i < int(num_songs * 0.7):
213 train_meta.append(song_meta)
214 elif i < int(num_songs * 0.85):
215 val_meta.append(song_meta)
216 else:
217 test_meta.append(song_meta)
218
219 for name, data in [("train", train_meta), ("val", val_meta), ("test", test_meta)]:
220 with open(output_dir / f"{name}.json", "w") as f:
221 json.dump(data, f, indent=2)
222 print(f" {name}: {len(data)} entries")
223
224 print(f"\nDataset generated at {output_dir}")
225 print(f" Songs: {num_songs}")
226 print(f" Total segments: {len(train_meta) + len(val_meta) + len(test_meta)}")
227 return output_dir
228
229
230 if __name__ == "__main__":
231 import argparse
232 parser = argparse.ArgumentParser()
233 parser.add_argument("--output", type=str, default="data/synthetic")
234 parser.add_argument("--num-songs", type=int, default=50)
235 parser.add_argument("--song-duration", type=float, default=30.0)
236 parser.add_argument("--segments-per-song", type=int, default=6)
237 parser.add_argument("--segment-duration", type=float, default=5.0)
238 args = parser.parse_args()
239
240 generate_dataset(
241 output_dir=args.output,
242 num_songs=args.num_songs,
243 song_duration=args.song_duration,
244 num_segments_per_song=args.segments_per_song,
245 segment_duration=args.segment_duration,
246 )
1 """
2 Simplified Chromaprint-style fingerprint matcher.
3
4 Implements landmark-based audio fingerprinting:
5 1. Extract spectral peaks from spectrogram
6 2. Build hash table from peak pairs
7 3. Match queries via hash lookup + time offset histogram voting
8 """
9
10 import numpy as np
11 import librosa
12 from collections import defaultdict
13 from typing import Dict, List, Tuple, Optional
14 import pickle
15 import json
16 from pathlib import Path
17
18
19 class Fingerprint:
20 def __init__(self, song_id: str, offset: int, hash_val: int):
21 self.song_id = song_id
22 self.offset = offset
23 self.hash = hash_val
24
25
26 class ChromaprintMatcher:
27 def __init__(
28 self,
29 sr: int = 16000,
30 n_fft: int = 1024,
31 hop_length: int = 256,
32 peak_neighborhood: int = 20,
33 target_zone_width: int = 50,
34 min_peak_energy: float = 0.01,
35 ):
36 self.sr = sr
37 self.n_fft = n_fft
38 self.hop_length = hop_length
39 self.peak_neighborhood = peak_neighborhood
40 self.target_zone_width = target_zone_width
41 self.min_peak_energy = min_peak_energy
42 self.hash_db: Dict[int, List[Fingerprint]] = defaultdict(list)
43
44 def _spectrogram(self, y: np.ndarray) -> np.ndarray:
45 S = np.abs(librosa.stft(y, n_fft=self.n_fft, hop_length=self.hop_length))
46 return S
47
48 def _find_peaks(self, S: np.ndarray) -> List[Tuple[int, int, float]]:
49 peaks = []
50 for t in range(0, S.shape[1] - self.peak_neighborhood):
51 for f in range(0, S.shape[0] - self.peak_neighborhood):
52 region = S[f:f + self.peak_neighborhood, t:t + self.peak_neighborhood]
53 center = S[f, t]
54 if center == np.max(region) and center > self.min_peak_energy:
55 peaks.append((t, f, center))
56 peaks.sort(key=lambda x: x[2], reverse=True)
57 return peaks[:200]
58
59 def _hash_peaks(self, peaks: List[Tuple[int, int, float]]) -> List[Tuple[int, int, int]]:
60 hashes = []
61 for i in range(len(peaks)):
62 for j in range(i + 1, len(peaks)):
63 t1, f1, _ = peaks[i]
64 t2, f2, _ = peaks[j]
65 if 0 < t2 - t1 < self.target_zone_width:
66 h = (f1 << 16) | (f2 << 8) | (t2 - t1)
67 hashes.append((h, t1))
68 return hashes
69
70 def index_song(self, song_id: str, y: np.ndarray):
71 S = self._spectrogram(y)
72 peaks = self._find_peaks(S)
73 hashes = self._hash_peaks(peaks)
74 for h, offset in hashes:
75 self.hash_db[h].append(Fingerprint(song_id, offset, h))
76
77 def index_songs_from_dir(
78 self, songs_dir: str, metadata_path: str, cache_path: Optional[str] = None
79 ):
80 with open(metadata_path) as f:
81 meta = json.load(f)
82
83 songs_dir = Path(songs_dir)
84 for item in meta:
85 if "songs" not in item.get("audio_path", ""):
86 continue
87 audio_path = songs_dir.parent / item["audio_path"]
88 if not audio_path.exists():
89 continue
90 song_id = item["song_id"]
91 y, _ = librosa.load(str(audio_path), sr=self.sr, mono=True)
92 self.index_song(song_id, y)
93
94 if cache_path:
95 self.save(cache_path)
96
97 def match(self, y: np.ndarray, top_k: int = 10) -> List[Tuple[str, float]]:
98 S = self._spectrogram(y)
99 peaks = self._find_peaks(S)
100 hashes = self._hash_peaks(peaks)
101
102 song_votes: Dict[str, Dict[int, int]] = defaultdict(lambda: defaultdict(int))
103 for h, q_offset in hashes:
104 for fp in self.hash_db.get(h, []):
105 delta = fp.offset - q_offset
106 song_votes[fp.song_id][delta] += 1
107
108 results = []
109 for song_id, deltas in song_votes.items():
110 peak_score = max(deltas.values())
111 total_score = sum(deltas.values())
112 combined = peak_score * 1.0 + total_score * 0.1
113 results.append((song_id, combined))
114
115 results.sort(key=lambda x: x[1], reverse=True)
116 return results[:top_k]
117
118 def save(self, path: str):
119 data = {}
120 for h, fps in self.hash_db.items():
121 data[h] = [(fp.song_id, fp.offset) for fp in fps]
122 with open(path, "wb") as f:
123 pickle.dump(data, f)
124
125 def load(self, path: str):
126 with open(path, "rb") as f:
127 data = pickle.load(f)
128 self.hash_db.clear()
129 for h, items in data.items():
130 self.hash_db[h] = [Fingerprint(sid, off, h) for sid, off in items]
131
132 @property
133 def index_size(self) -> int:
134 return sum(len(v) for v in self.hash_db.values())
135
136 @property
137 def num_hashes(self) -> int:
138 return len(self.hash_db)
1 import torch
2 import torch.nn.functional as F
3 import numpy as np
4 import librosa
5 from pathlib import Path
6 from typing import List, Optional, Tuple
7 import json
8
9
10 class ECAPAEmbedder:
11 def __init__(
12 self,
13 model_path: str,
14 device: str = "cpu",
15 sr: int = 16000,
16 n_mels: int = 80,
17 n_fft: int = 512,
18 hop_length: int = 160,
19 ):
20 self.device = torch.device(device)
21 self.sr = sr
22 self.n_mels = n_mels
23 self.n_fft = n_fft
24 self.hop_length = hop_length
25
26 from src.models.ecapa_tdnn import ECAPA_ACR
27 self.model = ECAPA_ACR(n_mels=n_mels, embed_dim=192)
28 state = torch.load(model_path, map_location="cpu", weights_only=True)
29 if "model_state_dict" in state:
30 state = state["model_state_dict"]
31 self.model.load_state_dict(state, strict=False)
32 self.model.to(self.device)
33 self.model.eval()
34
35 def _load_audio(self, path: str) -> np.ndarray:
36 y, _ = librosa.load(path, sr=self.sr, mono=True)
37 return y
38
39 def _to_mel(self, y: np.ndarray) -> torch.Tensor:
40 mel = librosa.feature.melspectrogram(
41 y=y, sr=self.sr, n_mels=self.n_mels,
42 n_fft=self.n_fft, hop_length=self.hop_length
43 )
44 mel = librosa.power_to_db(mel, ref=np.max)
45 return torch.FloatTensor(mel).unsqueeze(0)
46
47 def extract_embedding(self, audio_path: str) -> np.ndarray:
48 y = self._load_audio(audio_path)
49 mel = self._to_mel(y).to(self.device)
50 with torch.no_grad():
51 emb, _ = self.model(mel)
52 return emb.cpu().numpy().flatten()
53
54 def extract_embedding_from_wave(self, y: np.ndarray) -> np.ndarray:
55 if len(y) < self.sr:
56 y = np.pad(y, (0, self.sr - len(y)))
57 mel = self._to_mel(y[:self.sr * 5]).to(self.device)
58 with torch.no_grad():
59 emb, _ = self.model(mel)
60 return emb.cpu().numpy().flatten()
61
62 def build_reference_index(
63 self,
64 songs_dir: str,
65 metadata_path: str,
66 output_path: str,
67 window_sec: float = 5.0,
68 stride_sec: float = 2.5,
69 ) -> Tuple[np.ndarray, List[str]]:
70 with open(metadata_path) as f:
71 meta = json.load(f)
72
73 all_embs = []
74 all_ids = []
75 songs_dir = Path(songs_dir)
76
77 for item in meta:
78 if "songs/" not in item.get("audio_path", ""):
79 continue
80 audio_path = songs_dir.parent / item["audio_path"]
81 if not audio_path.exists():
82 continue
83 song_id = item["song_id"]
84 y, _ = librosa.load(str(audio_path), sr=self.sr, mono=True)
85
86 win_len = int(window_sec * self.sr)
87 stride = int(stride_sec * self.sr)
88
89 window_embs = []
90 for start in range(0, len(y) - win_len + 1, stride):
91 seg = y[start:start + win_len]
92 mel = self._to_mel(seg).to(self.device)
93 with torch.no_grad():
94 emb, _ = self.model(mel)
95 window_embs.append(emb.cpu().numpy().flatten())
96
97 if window_embs:
98 song_emb = np.mean(window_embs, axis=0)
99 all_embs.append(song_emb)
100 all_ids.append(song_id)
101
102 all_embs = np.vstack(all_embs)
103 np.save(f"{output_path}_embs.npy", all_embs)
104 np.save(f"{output_path}_ids.npy", np.array(all_ids))
105 print(f"Built reference index: {len(all_ids)} songs, embeddings shape {all_embs.shape}")
106 return all_embs, all_ids
107
108 def search(
109 self,
110 query_emb: np.ndarray,
111 ref_embs: np.ndarray,
112 ref_ids: List[str],
113 top_k: int = 10,
114 ) -> List[Tuple[str, float]]:
115 query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-12)
116 ref_norm = ref_embs / (np.linalg.norm(ref_embs, axis=1, keepdims=True) + 1e-12)
117 scores = query_norm @ ref_norm.T
118 top_indices = np.argsort(-scores)[:top_k]
119 return [(ref_ids[i], float(scores[i])) for i in top_indices]
1 """
2 Hybrid ACR Engine: Chromaprint fast pre-filter + ECAPA-TDNN deep re-ranking.
3 """
4
5 import numpy as np
6 import librosa
7 from typing import List, Tuple, Optional, Dict
8 from pathlib import Path
9 import json
10 import time
11
12
13 class Candidate:
14 def __init__(self, song_id: str, chroma_score: float = 0.0, ecapa_score: float = 0.0):
15 self.song_id = song_id
16 self.chroma_score = chroma_score
17 self.ecapa_score = ecapa_score
18 self.metadata: Dict = {}
19
20 @property
21 def combined_score(self) -> float:
22 return 0.3 * self.chroma_score + 0.7 * self.ecapa_score
23
24 def __repr__(self):
25 return f"Candidate({self.song_id}, chroma={self.chroma_score:.3f}, ecapa={self.ecapa_score:.3f})"
26
27
28 class HybridEngine:
29 def __init__(
30 self,
31 chroma_matcher=None,
32 ecapa_embedder=None,
33 ref_embs: Optional[np.ndarray] = None,
34 ref_ids: Optional[List[str]] = None,
35 sr: int = 16000,
36 chroma_weight: float = 0.3,
37 ecapa_weight: float = 0.7,
38 reject_threshold: float = 0.4,
39 ):
40 self.chroma = chroma_matcher
41 self.ecapa = ecapa_embedder
42 self.ref_embs = ref_embs
43 self.ref_ids = ref_ids
44 self.sr = sr
45 self.chroma_weight = chroma_weight
46 self.ecapa_weight = ecapa_weight
47 self.reject_threshold = reject_threshold
48
49 self.song_metadata: Dict[str, Dict] = {}
50
51 def load_metadata(self, metadata_path: str):
52 with open(metadata_path) as f:
53 items = json.load(f)
54 for item in items:
55 sid = item["song_id"]
56 if sid not in self.song_metadata:
57 base = item.get("base_freq", 0)
58 self.song_metadata[sid] = {
59 "song_id": sid,
60 "base_freq": base,
61 "audio_path": item.get("audio_path", ""),
62 }
63
64 def recognize(
65 self,
66 audio_path: str,
67 top_n: int = 5,
68 mode: str = "auto",
69 ) -> List[Dict]:
70 start = time.time()
71 y, _ = librosa.load(audio_path, sr=self.sr, mono=True)
72
73 chroma_candidates: List[Candidate] = []
74 if self.chroma is not None:
75 chroma_matches = self.chroma.match(y, top_k=50)
76 seen = set()
77 for song_id, score in chroma_matches:
78 if song_id not in seen:
79 seen.add(song_id)
80 c = Candidate(song_id, chroma_score=score)
81 chroma_candidates.append(c)
82
83 ecapa_candidates: List[Candidate] = []
84 if self.ecapa is not None and self.ref_embs is not None:
85 query_emb = self.ecapa.extract_embedding_from_wave(y)
86 ref_norm = self.ref_embs / (
87 np.linalg.norm(self.ref_embs, axis=1, keepdims=True) + 1e-12
88 )
89 query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-12)
90 scores = query_norm @ ref_norm.T
91 top_indices = np.argsort(-scores)[:top_n]
92 for idx in top_indices:
93 c = Candidate(self.ref_ids[idx], ecapa_score=float(scores[idx]))
94 ecapa_candidates.append(c)
95
96 combined: Dict[str, Candidate] = {}
97 for c in chroma_candidates:
98 combined[c.song_id] = c
99 for c in ecapa_candidates:
100 if c.song_id in combined:
101 combined[c.song_id].ecapa_score = c.ecapa_score
102 else:
103 combined[c.song_id] = c
104
105 for sid in list(combined.keys()):
106 combined[sid].metadata = self.song_metadata.get(sid, {})
107
108 results = sorted(
109 combined.values(),
110 key=lambda c: c.combined_score,
111 reverse=True,
112 )[:top_n]
113
114 elapsed = (time.time() - start) * 1000
115
116 output = []
117 for c in results:
118 output.append({
119 "song_id": c.song_id,
120 "confidence": round(c.combined_score, 4),
121 "chromaprint_score": round(c.chroma_score, 4),
122 "ecapa_score": round(c.ecapa_score, 4),
123 "metadata": c.metadata,
124 })
125
126 return {
127 "candidates": output,
128 "processing_time_ms": round(elapsed, 1),
129 "num_candidates": len(results),
130 }
1 import torch
2 import torch.nn as nn
3 import torch.nn.functional as F
4 from typing import Optional, Tuple
5
6
7 class SEModule(nn.Module):
8 def __init__(self, channels, se_channels=128):
9 super().__init__()
10 self.se = nn.Sequential(
11 nn.Conv1d(channels, se_channels, kernel_size=1),
12 nn.ReLU(),
13 nn.BatchNorm1d(se_channels),
14 nn.Conv1d(se_channels, channels, kernel_size=1),
15 nn.Sigmoid(),
16 )
17
18 def forward(self, x):
19 return x * self.se(x)
20
21
22 class Res2Block(nn.Module):
23 def __init__(self, channels, kernel_size=3, dilation=1, scale=8, se_channels=128):
24 super().__init__()
25 self.width = channels // scale
26 self.num_split = scale
27 self.convs = nn.ModuleList()
28 for i in range(self.num_split):
29 self.convs.append(
30 nn.Sequential(
31 nn.Conv1d(
32 self.width,
33 self.width,
34 kernel_size,
35 padding=dilation * (kernel_size - 1) // 2,
36 dilation=dilation,
37 ),
38 nn.ReLU(),
39 nn.BatchNorm1d(self.width),
40 )
41 )
42 self.conv1x1 = nn.Sequential(
43 nn.Conv1d(channels, channels, kernel_size=1),
44 nn.ReLU(),
45 nn.BatchNorm1d(channels),
46 )
47 self.se = SEModule(channels, se_channels)
48
49 def forward(self, x):
50 residual = x
51 split_x = torch.split(x, self.width, dim=1)
52 out = []
53 for i, (part, conv) in enumerate(zip(split_x, self.convs)):
54 if i == 0:
55 out.append(conv(part))
56 else:
57 out.append(conv(out[-1] if len(out) else part + part))
58 x = torch.cat(out, dim=1)
59 x = self.conv1x1(x)
60 x = self.se(x)
61 return x + residual
62
63
64 class StatisticsPooling(nn.Module):
65 def forward(self, x):
66 mean = torch.mean(x, dim=2)
67 std = torch.sqrt(torch.var(x, dim=2, unbiased=False) + 1e-12)
68 return torch.cat([mean, std], dim=1)
69
70
71 class AAMSoftmax(nn.Module):
72 def __init__(self, in_features, out_features, m=0.3, s=30.0):
73 super().__init__()
74 self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
75 nn.init.xavier_normal_(self.weight)
76 self.m = m
77 self.s = s
78 self.cos_m = torch.cos(torch.tensor(m))
79 self.sin_m = torch.sin(torch.tensor(m))
80 self.th = torch.cos(torch.tensor(torch.pi - m))
81 self.mm = torch.sin(torch.tensor(torch.pi - m)) * m
82
83 def forward(self, x, labels):
84 x = F.normalize(x, dim=1)
85 w = F.normalize(self.weight, dim=1)
86 cos_theta = F.linear(x, w)
87 sin_theta = torch.sqrt(1.0 - torch.clamp(cos_theta ** 2, 0, 1))
88 phi = cos_theta * self.cos_m - sin_theta * self.sin_m
89 phi = torch.where(cos_theta > self.th, phi, cos_theta - self.mm)
90 one_hot = F.one_hot(labels, num_classes=self.weight.size(0)).float()
91 output = (one_hot * phi) + ((1.0 - one_hot) * cos_theta)
92 output *= self.s
93 return output
94
95
96 class ECAPA_ACR(nn.Module):
97 def __init__(
98 self,
99 n_mels: int = 80,
100 embed_dim: int = 192,
101 channels: int = 512,
102 se_channels: int = 128,
103 res2net_scale: int = 8,
104 num_blocks: int = 3,
105 num_classes: Optional[int] = None,
106 aam_m: float = 0.3,
107 aam_s: float = 30.0,
108 ):
109 super().__init__()
110 self.embed_dim = embed_dim
111
112 self.conv1 = nn.Sequential(
113 nn.Conv1d(n_mels, channels, kernel_size=5, stride=1, padding=2),
114 nn.ReLU(),
115 nn.BatchNorm1d(channels),
116 )
117
118 dilations = [1, 2, 3] if num_blocks == 3 else [d * 1 for d in range(1, num_blocks + 1)]
119 self.blocks = nn.ModuleList()
120 for d in dilations[:num_blocks]:
121 self.blocks.append(
122 Res2Block(
123 channels=channels,
124 kernel_size=3,
125 dilation=d,
126 scale=res2net_scale,
127 se_channels=se_channels,
128 )
129 )
130
131 in_channels = channels * num_blocks
132 self.mfa = nn.Sequential(
133 nn.Conv1d(in_channels, channels * 3, kernel_size=1),
134 nn.ReLU(),
135 nn.BatchNorm1d(channels * 3),
136 )
137
138 self.pooling = StatisticsPooling()
139 self.fc = nn.Linear(channels * 3 * 2, embed_dim)
140 self.bn = nn.BatchNorm1d(embed_dim, affine=False)
141
142 if num_classes is not None:
143 self.aam = AAMSoftmax(embed_dim, num_classes, m=aam_m, s=aam_s)
144 else:
145 self.aam = None
146
147 def forward(
148 self, mel: torch.Tensor, labels: Optional[torch.Tensor] = None
149 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
150 x = self.conv1(mel)
151 block_outputs = []
152 for block in self.blocks:
153 x = block(x)
154 block_outputs.append(x)
155
156 x = torch.cat(block_outputs, dim=1)
157 x = self.mfa(x)
158 x = self.pooling(x)
159 x = self.fc(x)
160 x = self.bn(x)
161 embedding = F.normalize(x, p=2, dim=1)
162
163 if labels is not None and self.aam is not None:
164 logits = self.aam(embedding, labels)
165 return embedding, logits
166
167 return embedding, None
1 import torch
2 import torch.nn as nn
3 import torch.nn.functional as F
4
5
6 class SupConLoss(nn.Module):
7 def __init__(self, temperature: float = 0.07):
8 super().__init__()
9 self.temperature = temperature
10
11 def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
12 batch_size = features.shape[0]
13 labels = labels.contiguous().view(-1, 1)
14 mask = torch.eq(labels, labels.T).float().to(features.device)
15 mask = mask - torch.eye(batch_size, device=features.device)
16
17 features = F.normalize(features, dim=1)
18 sim = torch.matmul(features, features.T) / self.temperature
19 sim_max, _ = torch.max(sim, dim=1, keepdim=True)
20 sim = sim - sim_max.detach()
21
22 exp_sim = torch.exp(sim) * (1 - torch.eye(batch_size, device=features.device))
23 log_prob = sim - torch.log(exp_sim.sum(dim=1, keepdim=True))
24
25 pos_mask = mask
26 pos_count = pos_mask.sum(dim=1)
27 loss = -(log_prob * pos_mask).sum(dim=1)
28 loss = loss / pos_count.clamp(min=1)
29 return loss.mean()
30
31
32 class CombinedLoss(nn.Module):
33 def __init__(
34 self,
35 temperature: float = 0.07,
36 supcon_weight: float = 1.0,
37 aam_weight: float = 0.3,
38 ):
39 super().__init__()
40 self.supcon = SupConLoss(temperature)
41 self.ce = nn.CrossEntropyLoss()
42 self.supcon_weight = supcon_weight
43 self.aam_weight = aam_weight
44
45 def forward(
46 self,
47 embedding: torch.Tensor,
48 logits: torch.Tensor,
49 labels: torch.Tensor,
50 supcon_labels: torch.Tensor,
51 ) -> dict:
52 loss_supcon = self.supcon(embedding, supcon_labels)
53 loss_ce = self.ce(logits, labels)
54
55 total = self.supcon_weight * loss_supcon + self.aam_weight * loss_ce
56 return {
57 "loss": total,
58 "supcon_loss": loss_supcon.item(),
59 "ce_loss": loss_ce.item(),
60 }
1 import torch
2 import torch.nn as nn
3 import torch.nn.functional as F
4 import numpy as np
5 import librosa
6 from typing import List, Optional, Tuple
7
8
9 class AudioProcessor:
10 def __init__(self, sr: int = 16000, n_mels: int = 80, n_fft: int = 512, hop_length: int = 160):
11 self.sr = sr
12 self.n_mels = n_mels
13 self.n_fft = n_fft
14 self.hop_length = hop_length
15
16 def load(self, path: str, sr: Optional[int] = None, duration: Optional[float] = None) -> np.ndarray:
17 y, _ = librosa.load(path, sr=sr or self.sr, mono=True, duration=duration)
18 return y
19
20 def to_mel(self, y: np.ndarray) -> np.ndarray:
21 mel = librosa.feature.melspectrogram(
22 y=y, sr=self.sr, n_mels=self.n_mels,
23 n_fft=self.n_fft, hop_length=self.hop_length
24 )
25 return librosa.power_to_db(mel, ref=np.max)
26
27 def to_mel_tensor(self, y: np.ndarray) -> torch.Tensor:
28 mel = self.to_mel(y)
29 return torch.FloatTensor(mel).unsqueeze(0)
30
31 def sliding_windows(self, y: np.ndarray, window_sec: float = 5.0, stride_sec: float = 2.5) -> List[np.ndarray]:
32 win_len = int(window_sec * self.sr)
33 stride = int(stride_sec * self.sr)
34 if len(y) < win_len:
35 pad = win_len - len(y)
36 y = np.pad(y, (0, pad))
37 windows = []
38 for start in range(0, len(y) - win_len + 1, stride):
39 windows.append(y[start:start + win_len])
40 if not windows:
41 windows.append(y[:win_len])
42 return windows
43
44 def mel_from_path(self, path: str) -> Tuple[torch.Tensor, float]:
45 y = self.load(path)
46 duration = len(y) / self.sr
47 return self.to_mel_tensor(y), duration
48
49 def extract_chroma(self, y: np.ndarray) -> np.ndarray:
50 chroma = librosa.feature.chroma_cqt(y=y, sr=self.sr)
51 return chroma
52
53 def extract_f0(self, y: np.ndarray, fmin=65, fmax=2093) -> np.ndarray:
54 f0, _, _ = librosa.pyin(y, sr=self.sr, fmin=fmin, fmax=fmax)
55 f0 = np.nan_to_num(f0, nan=0.0)
56 return f0
1 import numpy as np
2 import random
3 from typing import Optional, Tuple
4
5
6 class AugmentPipeline:
7 def __init__(self, sr: int = 16000):
8 self.sr = sr
9 self.noise_snr_range = (5, 30)
10 self.pitch_shift_range = (-6, 6)
11 self.time_stretch_range = (0.85, 1.15)
12 self.mp3_bitrate_range = (32, 128)
13
14 def add_noise(self, y: np.ndarray, snr_db: Optional[float] = None) -> np.ndarray:
15 if snr_db is None:
16 snr_db = random.uniform(*self.noise_snr_range)
17 signal_power = np.mean(y ** 2)
18 noise_power = signal_power / (10 ** (snr_db / 10))
19 noise = np.random.randn(len(y)) * np.sqrt(noise_power)
20 return y + noise
21
22 def pitch_shift(self, y: np.ndarray, semitones: Optional[float] = None) -> np.ndarray:
23 if semitones is None:
24 semitones = random.uniform(*self.pitch_shift_range)
25 return librosa_shift(y, sr=self.sr, n_steps=semitones)
26
27 def time_stretch(self, y: np.ndarray, rate: Optional[float] = None) -> np.ndarray:
28 if rate is None:
29 rate = random.uniform(*self.time_stretch_range)
30 return librosa_ts(y, sr=self.sr, rate=rate)
31
32 def add_reverb(self, y: np.ndarray, decay: float = 0.3) -> np.ndarray:
33 ir_len = int(0.1 * self.sr)
34 ir = np.exp(-np.arange(ir_len) * decay / ir_len) * np.random.randn(ir_len)
35 ir /= np.sqrt(np.sum(ir ** 2))
36 return np.convolve(y, ir, mode='same')[:len(y)]
37
38 def apply_spec_augment(self, mel: np.ndarray, max_time_mask: int = 20, max_freq_mask: int = 8) -> np.ndarray:
39 mel = mel.copy()
40 t = mel.shape[1]
41 f = mel.shape[0]
42 for _ in range(2):
43 t_mask = random.randint(0, max_time_mask)
44 t_start = random.randint(0, max(0, t - t_mask))
45 if t_start < t:
46 mel[:, t_start:t_start + t_mask] = 0
47 for _ in range(2):
48 f_mask = random.randint(0, max_freq_mask)
49 f_start = random.randint(0, max(0, f - f_mask))
50 if f_start < f:
51 mel[f_start:f_start + f_mask, :] = 0
52 return mel
53
54 def apply_to_mel(self, mel: np.ndarray) -> np.ndarray:
55 if random.random() < 0.3:
56 mel = self.apply_spec_augment(mel)
57 return mel
58
59 def __call__(self, y: np.ndarray) -> np.ndarray:
60 if random.random() < 0.5:
61 y = self.add_noise(y)
62 if random.random() < 0.3:
63 y = self.time_stretch(y)
64 if random.random() < 0.3:
65 y = self.pitch_shift(y)
66 if random.random() < 0.2:
67 y = self.add_reverb(y)
68 return y
69
70
71 def librosa_shift(y, sr=16000, n_steps=0):
72 return librosa_impl(y, lambda: __import__('librosa').effects.pitch_shift(y, sr=sr, n_steps=n_steps))
73
74
75 def librosa_ts(y, sr=16000, rate=1.0):
76 return librosa_impl(y, lambda: __import__('librosa').effects.time_stretch(y, rate=rate))
77
78
79 def librosa_impl(y, fn):
80 try:
81 return fn()
82 except Exception:
83 return y
1 #!/usr/bin/env python3
2 """
3 ACR Engine - Training script.
4 """
5
6 import os
7 import sys
8 import json
9 import yaml
10 import time
11 import argparse
12 from pathlib import Path
13
14 import torch
15 import torch.nn as nn
16 from torch.utils.data import DataLoader
17 from tqdm import tqdm
18 import numpy as np
19
20 project_root = Path(__file__).parent
21 sys.path.insert(0, str(project_root))
22
23 from src.models.ecapa_tdnn import ECAPA_ACR
24 from src.models.losses import CombinedLoss
25 from src.data.dataset import ACRDataset, ACRTestDataset
26
27
28 def collate_fn(batch):
29 mels = [b["mel"] for b in batch]
30 song_ids = [b["song_id"] for b in batch]
31 song_names = [b["song_name"] for b in batch]
32
33 max_t = max(m.shape[1] for m in mels)
34 mels_padded = []
35 for m in mels:
36 pad = max_t - m.shape[1]
37 if pad > 0:
38 m = torch.nn.functional.pad(m, (0, pad))
39 mels_padded.append(m.unsqueeze(0))
40
41 return {
42 "mel": torch.cat(mels_padded, dim=0),
43 "song_id": torch.stack(song_ids),
44 "song_name": song_names,
45 }
46
47
48 def train_epoch(model, loader, optimizer, criterion, scaler, device, epoch, cfg):
49 model.train()
50 total_loss = 0
51 total_supcon = 0
52 total_ce = 0
53 correct = 0
54 total = 0
55 steps = 0
56
57 pbar = tqdm(loader, desc=f"Epoch {epoch}")
58 for batch in pbar:
59 mel = batch["mel"].to(device)
60 labels = batch["song_id"].to(device)
61
62 with torch.amp.autocast("cuda", enabled=cfg["training"]["mixed_precision"] and device.type == "cuda"):
63 embedding, logits = model(mel, labels)
64 loss_dict = criterion(embedding, logits, labels, labels)
65
66 optimizer.zero_grad()
67 if scaler:
68 scaler.scale(loss_dict["loss"]).backward()
69 scaler.unscale_(optimizer)
70 torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["training"]["gradient_clip"])
71 scaler.step(optimizer)
72 scaler.update()
73 else:
74 loss_dict["loss"].backward()
75 torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["training"]["gradient_clip"])
76 optimizer.step()
77
78 total_loss += loss_dict["loss"].item()
79 total_supcon += loss_dict["supcon_loss"]
80 total_ce += loss_dict["ce_loss"]
81
82 if logits is not None:
83 preds = logits.argmax(dim=1)
84 correct += (preds == labels).sum().item()
85 total += labels.size(0)
86 steps += 1
87
88 pbar.set_postfix({
89 "loss": f"{loss_dict['loss']:.4f}",
90 "acc": f"{correct/total:.3f}",
91 })
92
93 return {
94 "loss": total_loss / steps,
95 "supcon_loss": total_supcon / steps,
96 "ce_loss": total_ce / steps,
97 "accuracy": correct / total,
98 }
99
100
101 def validate(model, loader, criterion, device):
102 model.eval()
103 total_loss = 0
104 correct = 0
105 total = 0
106
107 with torch.no_grad():
108 for batch in tqdm(loader, desc="Validating"):
109 mel = batch["mel"].to(device)
110 labels = batch["song_id"].to(device)
111
112 embedding, logits = model(mel, labels)
113 loss_dict = criterion(embedding, logits, labels, labels)
114
115 total_loss += loss_dict["loss"].item()
116 if logits is not None:
117 preds = logits.argmax(dim=1)
118 correct += (preds == labels).sum().item()
119 total += labels.size(0)
120
121 acc = correct / total if total > 0 else 0
122 print(f" Validation: loss={total_loss:.4f}, accuracy={acc:.4f}")
123 return {"loss": total_loss, "accuracy": acc}
124
125
126 def main():
127 parser = argparse.ArgumentParser()
128 parser.add_argument("--config", type=str, default="configs/default.yaml")
129 parser.add_argument("--data", type=str, default="data/synthetic")
130 parser.add_argument("--output", type=str, default="data/models")
131 parser.add_argument("--resume", type=str, default=None)
132 parser.add_argument("--device", type=str, default="auto")
133 parser.add_argument("--epochs", type=int, default=None)
134 parser.add_argument("--batch-size", type=int, default=None)
135 parser.add_argument("--lr", type=float, default=None)
136 parser.add_argument("--dry-run", action="store_true", help="Run one batch to verify pipeline")
137 args = parser.parse_args()
138
139 with open(args.config) as f:
140 cfg = yaml.safe_load(f)
141
142 if args.epochs:
143 cfg["training"]["epochs"] = args.epochs
144 if args.batch_size:
145 cfg["training"]["batch_size"] = args.batch_size
146 if args.lr:
147 cfg["training"]["lr"] = args.lr
148
149 device_name = args.device
150 if device_name == "auto":
151 device_name = "cuda" if torch.cuda.is_available() else "cpu"
152 device = torch.device(device_name)
153 print(f"Device: {device}")
154
155 print("Loading datasets...")
156 train_dataset = ACRDataset(
157 args.data, split="train",
158 sr=cfg["data"]["sample_rate"],
159 n_mels=cfg["model"]["n_mels"],
160 n_fft=cfg["data"]["n_fft"],
161 hop_length=cfg["data"]["hop_length"],
162 segment_dur=cfg["data"]["segment_dur"],
163 augment=True,
164 n_crops_per_song=cfg["data"]["crop_per_song"],
165 )
166 val_dataset = ACRDataset(
167 args.data, split="val",
168 sr=cfg["data"]["sample_rate"],
169 n_mels=cfg["model"]["n_mels"],
170 n_fft=cfg["data"]["n_fft"],
171 hop_length=cfg["data"]["hop_length"],
172 segment_dur=cfg["data"]["segment_dur"],
173 augment=False,
174 n_crops_per_song=1,
175 )
176
177 train_loader = DataLoader(
178 train_dataset,
179 batch_size=cfg["training"]["batch_size"],
180 shuffle=True,
181 num_workers=2,
182 collate_fn=collate_fn,
183 drop_last=True,
184 )
185 val_loader = DataLoader(
186 val_dataset,
187 batch_size=cfg["training"]["batch_size"],
188 shuffle=False,
189 num_workers=2,
190 collate_fn=collate_fn,
191 )
192
193 num_classes = len(train_dataset.song_ids)
194 print(f"Classes: {num_classes}")
195 print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
196
197 model = ECAPA_ACR(
198 n_mels=cfg["model"]["n_mels"],
199 embed_dim=cfg["model"]["embed_dim"],
200 channels=cfg["model"]["channels"],
201 se_channels=cfg["model"]["se_channels"],
202 res2net_scale=cfg["model"]["res2net_scale"],
203 num_blocks=cfg["model"]["num_blocks"],
204 num_classes=num_classes,
205 aam_m=cfg["model"]["aam_m"],
206 aam_s=cfg["model"]["aam_s"],
207 ).to(device)
208
209 criterion = CombinedLoss(
210 temperature=cfg["training"]["temperature"],
211 supcon_weight=cfg["training"]["supcon_weight"],
212 aam_weight=cfg["training"]["aam_weight"],
213 )
214 optimizer = torch.optim.AdamW(
215 model.parameters(),
216 lr=cfg["training"]["lr"],
217 weight_decay=cfg["training"]["weight_decay"],
218 )
219
220 scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")
221
222 start_epoch = 1
223 if args.resume:
224 ckpt = torch.load(args.resume, map_location=device, weights_only=True)
225 model.load_state_dict(ckpt["model_state_dict"])
226 optimizer.load_state_dict(ckpt["optimizer_state_dict"])
227 start_epoch = ckpt["epoch"] + 1
228 print(f"Resumed from epoch {ckpt['epoch']}")
229
230 if args.dry_run:
231 print("Dry run: running one batch through forward/backward...")
232 batch = next(iter(train_loader))
233 mel = batch["mel"].to(device)
234 labels = batch["song_id"].to(device)
235 embedding, logits = model(mel, labels)
236 loss_dict = criterion(embedding, logits, labels, labels)
237 loss_dict["loss"].backward()
238 print(f" Forward/backward OK. Loss: {loss_dict['loss']:.4f}")
239 print(f" Embedding shape: {embedding.shape}")
240 print("Dry run passed! Pipeline is working.")
241 return
242
243 scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
244 optimizer, T_max=cfg["training"]["epochs"]
245 )
246
247 best_acc = float("-inf")
248 output_dir = Path(args.output)
249 output_dir.mkdir(parents=True, exist_ok=True)
250
251 print("Starting training...")
252 for epoch in range(start_epoch, cfg["training"]["epochs"] + 1):
253 train_metrics = train_epoch(model, train_loader, optimizer, criterion, scaler, device, epoch, cfg)
254 val_metrics = validate(model, val_loader, criterion, device)
255 scheduler.step()
256
257 lr = optimizer.param_groups[0]["lr"]
258 print(f" LR: {lr:.2e}")
259
260 if epoch % cfg["training"]["save_every"] == 0 or val_metrics["accuracy"] > best_acc:
261 if val_metrics["accuracy"] > best_acc:
262 best_acc = val_metrics["accuracy"]
263 path = output_dir / "best_model.pt"
264 else:
265 path = output_dir / f"checkpoint_epoch_{epoch}.pt"
266
267 torch.save({
268 "epoch": epoch,
269 "model_state_dict": model.state_dict(),
270 "optimizer_state_dict": optimizer.state_dict(),
271 "best_acc": best_acc,
272 "config": cfg,
273 }, path)
274 print(f" Saved: {path}")
275
276 print(f"\nTraining complete. Best validation accuracy: {best_acc:.4f}")
277 print(f"Model saved to: {output_dir / 'best_model.pt'}")
278
279
280 if __name__ == "__main__":
281 main()
1 # Changelog
2
3 ## 2026-06-02
4
5 ### Stage: 文档补全 + ACR 最小可运行链路
6
7 完成项:
8 - 补充项目职责图:`docs/project-responsibility-map.md`
9 - 补充系统架构图:`docs/acr-architecture.md`
10 - 补充阶段路线图:`docs/roadmap.md`
11 - 补充运行手册:`docs/runbook.md`
12 - 补充引擎说明:`acr-engine/README.md`
13 - 新增依赖清单:`acr-engine/requirements.txt`
14 - 新增 demo CLI:`acr-engine/run_demo.py`
15 - 修复数据集读取路径问题:`acr-engine/src/data/dataset.py`
16 - 修复首次训练不落 best checkpoint 的问题:`acr-engine/train.py`
17
18 验证结果:
19 - 已生成 synthetic dataset
20 - 已通过 `train.py --dry-run`
21 - 已完成 1 epoch CPU 训练并生成 `best_model.pt`
22 - 已完成指纹索引与 embedding 索引构建
23 - 已完成识别命令并输出 JSON 候选结果
1 # ACR 项目架构图
2
3 > 更新:2026-06-02
4
5 ## 1. 总体架构
6
7 ```mermaid
8 flowchart LR
9 Q[Query 音频] --> P[预处理]
10 P --> F1[传统指纹特征]
11 P --> F2[Mel 特征]
12
13 F1 --> M1[Chromaprint Matcher]
14 F2 --> M2[ECAPA Embedder]
15
16 R[Reference 曲库] --> I1[指纹索引]
17 R --> I2[Embedding 索引]
18
19 M1 --> C[候选集合]
20 M2 --> C
21 C --> H[Hybrid 重排序]
22 H --> O[Top-K 识别结果]
23 ```
24
25 ## 2. 训练架构
26
27 ```mermaid
28 flowchart TD
29 A[原始/合成音频] --> B[随机裁剪]
30 B --> C[增强: 噪声/变速/移调/混响]
31 C --> D[Mel Spectrogram]
32 D --> E[ECAPA-TDNN]
33 E --> F[Embedding]
34 F --> G[SupCon Loss]
35 F --> H[AAM Softmax]
36 G --> I[联合优化]
37 H --> I
38 ```
39
40 ## 3. 推理架构
41
42 ```mermaid
43 sequenceDiagram
44 participant U as User Query
45 participant P as Preprocessor
46 participant C as Chroma Matcher
47 participant E as ECAPA Embedder
48 participant H as Hybrid Engine
49
50 U->>P: 输入音频
51 P->>C: 指纹特征
52 P->>E: Mel 特征
53 C-->>H: Top-N 指纹候选
54 E-->>H: Top-N embedding 候选
55 H-->>U: 融合后的识别结果
56 ```
57
58 ## 4. 当前可运行闭环
59
60 1.`synthetic.py` 生成合成曲库
61 2.`train.py` 训练 ECAPA 原型模型
62 3.`run_demo.py build-index` 构建:
63 - 指纹索引
64 - embedding 索引
65 4.`run_demo.py recognize` 对片段做识别
66
67 ## 5. 后续生产化架构建议
68
69 - API Gateway
70 - 异步音频入库流水线
71 - Faiss/HNSW 向量服务
72 - Postgres/MySQL 元数据服务
73 - 对象存储保存原始音频
74 - 模型服务与索引服务解耦
1 # ACR 项目职责图
2
3 > 更新:2026-06-02
4
5 ## 1. 项目定位
6
7 本项目是一个**听歌识曲 / 音频内容识别(ACR)原型系统**,目标是先跑通:
8
9 - 数据生成
10 - 特征提取
11 - 模型训练
12 - 指纹检索
13 - embedding 检索
14 - hybrid 混合识别
15
16 当前不以生产服务为目标,重点是**算法链路验证**
17
18 ## 2. 仓库职责分层
19
20 ```text
21 /workspace
22 ├── acr-engine/ # ACR 核心算法与可运行 demo
23 │ ├── configs/ # 训练/推理参数配置
24 │ ├── src/data/ # 数据集读取、合成数据生成
25 │ ├── src/models/ # 声学模型、损失函数
26 │ ├── src/engines/ # 指纹/embedding/hybrid 检索引擎
27 │ ├── train.py # 模型训练入口
28 │ ├── run_demo.py # 数据生成、建索引、识别入口
29 │ └── requirements.txt # Python 依赖
30 ├── docs/ # 设计、架构、路线图、使用说明
31 ├── scripts/ # 环境安装与工具 bootstrap
32 ├── container/ # 容器环境定义
33 └── .codex/.omx/ # Codex / OMX 协作与运行时元数据
34 ```
35
36 ## 3. 模块职责图
37
38 ```mermaid
39 flowchart TD
40 A[音频输入] --> B[数据层]
41 B --> B1[合成数据生成 synthetic.py]
42 B --> B2[训练/验证数据集 dataset.py]
43
44 A --> C[特征层]
45 C --> C1[Mel Spectrogram]
46 C --> C2[Chroma / F0]
47 C --> C3[增强 augment.py]
48
49 C --> D[模型层]
50 D --> D1[ECAPA-TDNN]
51 D --> D2[SupCon + AAM Loss]
52
53 A --> E[检索层]
54 E --> E1[ChromaprintMatcher]
55 E --> E2[ECAPAEmbedder]
56 E --> E3[HybridEngine]
57
58 D --> F[训练入口 train.py]
59 E --> G[推理入口 run_demo.py]
60 ```
61
62 ## 4. 角色职责
63
64 | 模块 | 职责 | 当前状态 |
65 |---|---|---|
66 | `src/data/synthetic.py` | 生成可控的合成歌曲与片段 | 已实现 |
67 | `src/data/dataset.py` | 训练/验证数据装载 | 已实现 |
68 | `src/utils/audio.py` | Mel、滑窗、F0、Chroma | 已实现 |
69 | `src/utils/augment.py` | 噪声、变速、移调、混响增强 | 已实现 |
70 | `src/models/ecapa_tdnn.py` | embedding 编码器 | 已实现 |
71 | `src/models/losses.py` | 对比学习 + 分类训练目标 | 已实现 |
72 | `src/engines/chromaprint_matcher.py` | 传统哈希指纹检索 | 已实现 |
73 | `src/engines/ecapa_embedder.py` | embedding 提取与向量检索 | 已实现 |
74 | `src/engines/hybrid_engine.py` | 融合匹配结果 | 已实现 |
75 | `train.py` | 训练入口 | 已实现 |
76 | `run_demo.py` | demo 入口 | 本次补齐 |
77
78 ## 5. 当前边界
79
80 当前项目**负责**
81
82 - 原型级算法验证
83 - 小规模曲库识别
84 - 本地训练与本地识别 demo
85
86 当前项目**暂不负责**
87
88 - 在线 API 服务
89 - 海量曲库 ANN 线上部署
90 - 权限、账号、计费
91 - 真正版权音频数据治理
92 - 生产监控告警
1 # ACR 项目 Roadmap
2
3 > 更新:2026-06-02
4
5 ## Phase 0:原型跑通(当前阶段)
6
7 ### 目标
8 完成一个端到端可运行的本地 demo。
9
10 ### 范围
11 - [x] 合成数据生成
12 - [x] 数据增强
13 - [x] ECAPA embedding 模型
14 - [x] 传统指纹匹配器
15 - [x] HybridEngine
16 - [x] 最小训练入口
17 - [x] 最小识别入口
18 - [x] 文档补全
19
20 ### 验收标准
21 - 能生成数据
22 - 能训练至少 1 epoch
23 - 能建立 reference 索引
24 - 能对测试片段输出 Top-K 候选
25
26 ---
27
28 ## Phase 1:研究验证
29
30 ### 目标
31 验证不同场景下识别效果是否可接受。
32
33 ### 任务
34 - [ ] 增加 top-1 / top-5 / MRR 评估脚本
35 - [ ] 对 clean / noisy / stretched / pitch-shifted 分开评测
36 - [ ] 增加 query-by-humming 专项评测集
37 - [ ] 加入更稳健的 negative sampling
38 - [ ] 补充 checkpoint / config versioning
39
40 ---
41
42 ## Phase 2:工程化
43
44 ### 目标
45 把原型升级为可复现实验项目。
46
47 ### 任务
48 - [ ] 增加 `Makefile``justfile`
49 - [ ] 增加 `pytest` 基础测试
50 - [ ] 增加日志与指标记录
51 - [ ] 增加模型导出与加载规范
52 - [ ] 增加 CLI 参数校验
53 - [ ] 增加 Docker 运行方式
54
55 ---
56
57 ## Phase 3:产品化 PoC
58
59 ### 目标
60 提供可被业务方调用的服务接口。
61
62 ### 任务
63 - [ ] FastAPI 服务化
64 - [ ] 上传音频并返回候选歌曲
65 - [ ] 曲库增量入库命令
66 - [ ] 元数据管理接口
67 - [ ] 结果缓存与批量检索
68
69 ---
70
71 ## Phase 4:大规模检索
72
73 ### 目标
74 支持百万级以上曲库。
75
76 ### 任务
77 - [ ] 接入 Faiss / HNSW
78 - [ ] embedding 分片与压缩
79 - [ ] 双层召回 + 精排
80 - [ ] 在线索引更新
81 - [ ] 冷热分层存储
82
83 ---
84
85 ## Phase 5:真实业务能力
86
87 ### 目标
88 逼近真实听歌识曲产品。
89
90 ### 任务
91 - [ ] 真实版权音频数据接入
92 - [ ] 哼唱专项模型/旋律塔
93 - [ ] 多模态融合(旋律 + 声纹 + 指纹)
94 - [ ] 在线 A/B 评估
95 - [ ] 监控与质量回流
1 # ACR 项目运行手册
2
3 ## 1. 环境
4
5 ```bash
6 cd acr-engine
7 python -m venv .venv
8 source .venv/bin/activate
9 pip install -r requirements.txt
10 ```
11
12 ## 2. 生成数据
13
14 ```bash
15 python run_demo.py generate-data --output data/synthetic --num-songs 24
16 ```
17
18 ## 3. 校验训练链路
19
20 ```bash
21 python train.py --data data/synthetic --dry-run --device cpu
22 ```
23
24 ## 4. 最小训练
25
26 ```bash
27 python train.py --data data/synthetic --output data/models --device cpu --epochs 1 --batch-size 8
28 ```
29
30 ## 5. 建索引
31
32 ```bash
33 python run_demo.py build-index --data data/synthetic --model data/models/best_model.pt --output data/index --device cpu
34 ```
35
36 ## 6. 跑识别
37
38 ```bash
39 python run_demo.py recognize \
40 --query data/synthetic/segments/song_0020_seg_00.wav \
41 --data data/synthetic \
42 --model data/models/best_model.pt \
43 --index-prefix data/index/reference \
44 --device cpu
45 ```
46
47 ## 7. 成功判定
48
49 至少满足:
50
51 - 能输出 JSON 结果
52 - 返回 `candidates`
53 - 结果中包含 `song_id``confidence`