Commit 62688d3b 62688d3bccd07a3a03a98d6ed698f1980e1e298d by cnb.bofCdSsphPA

period upload

1 parent 44d8268c
...@@ -7,3 +7,4 @@ ...@@ -7,3 +7,4 @@
7 .codex/skills/.system/** 7 .codex/skills/.system/**
8 !.codex/prompts/ 8 !.codex/prompts/
9 !.codex/prompts/** 9 !.codex/prompts/**
10 .venv
......
...@@ -65,3 +65,16 @@ python run_demo.py full-demo --device cpu ...@@ -65,3 +65,16 @@ python run_demo.py full-demo --device cpu
65 ## 当前定位 65 ## 当前定位
66 66
67 这是一个**原型仓库**,目标是验证 ACR 主链路能否跑通,不是生产级服务。 67 这是一个**原型仓库**,目标是验证 ACR 主链路能否跑通,不是生产级服务。
68
69 ## 评测
70
71 ```bash
72 python evaluate.py --data data/synthetic --model data/models/best_model.pt --index-prefix data/index/reference --split test --device cpu
73 ```
74
75 ## 当前提升方向
76
77 - 更强合成混淆样本(confused / humming_like)
78 - Hybrid 分数归一化后再融合
79 - full-demo 自动训练
80 - 后续可接入开源数据集
......
1 #!/usr/bin/env python3
2 import argparse
3 import json
4 from pathlib import Path
5
6 import numpy as np
7
8 from src.engines.chromaprint_matcher import ChromaprintMatcher
9 from src.engines.ecapa_embedder import ECAPAEmbedder
10 from src.engines.hybrid_engine import HybridEngine
11
12
def load_items(meta_path: Path):
    """Load and return the parsed JSON metadata stored at ``meta_path``.

    The file is decoded as UTF-8 explicitly. JSON is UTF-8 by specification,
    and relying on the platform default encoding makes metadata containing
    non-ASCII text fail to load on systems whose locale is not UTF-8.

    Args:
        meta_path: Path to a JSON metadata file.

    Returns:
        Whatever ``json.load`` produces for the file contents.
    """
    with open(meta_path, encoding="utf-8") as f:
        return json.load(f)
16
17
def main():
    """Command-line entry point: score the hybrid recognizer on one split.

    Loads the chromaprint index, the embedding model and the reference
    embeddings, runs every ``segments/`` entry of the chosen split through
    ``HybridEngine.recognize`` and prints a JSON report with overall and
    per-type top-1 / top-k hit rates plus up to ten top-k misses.

    Raises:
        SystemExit: If the chosen split has no entry whose ``audio_path``
            starts with ``segments/``.
    """
    cli = argparse.ArgumentParser(description="Evaluate ACR recognition quality")
    cli.add_argument("--data", default="data/synthetic")
    cli.add_argument("--model", required=True)
    cli.add_argument("--index-prefix", default="data/index/reference")
    cli.add_argument("--split", default="test")
    cli.add_argument("--top-k", type=int, default=5)
    cli.add_argument("--device", default="cpu")
    args = cli.parse_args()

    data_root = Path(args.data)
    index_prefix = args.index_prefix

    # The chromaprint cache is looked up next to the embedding index files.
    matcher = ChromaprintMatcher()
    matcher.load(str(Path(index_prefix).parent / "chromaprint.pkl"))
    embedder = ECAPAEmbedder(model_path=args.model, device=args.device)
    ref_embs = np.load(f"{index_prefix}_embs.npy")
    # NOTE(review): allow_pickle=True executes pickled payloads on load; only
    # point --index-prefix at index files you built yourself.
    ref_ids = np.load(f"{index_prefix}_ids.npy", allow_pickle=True).tolist()
    engine = HybridEngine(matcher, embedder, ref_embs, ref_ids)

    # Feed the engine the metadata of every split file that is present.
    for split_file in ("train.json", "val.json", "test.json"):
        meta_file = data_root / split_file
        if meta_file.exists():
            engine.load_metadata(str(meta_file))

    # Only segment clips are used as queries.
    queries = [
        entry
        for entry in load_items(data_root / f"{args.split}.json")
        if str(entry.get("audio_path", "")).startswith("segments/")
    ]
    if not queries:
        raise SystemExit("No segment queries found for evaluation")

    hits_at_1 = 0
    hits_at_k = 0
    per_type = {}
    misses = []

    for entry in queries:
        outcome = engine.recognize(str(data_root / entry["audio_path"]), top_n=args.top_k)
        ranked = [candidate["song_id"] for candidate in outcome["candidates"]]
        expected = entry["song_id"]
        kind = entry.get("type", "unknown")

        if kind not in per_type:
            per_type[kind] = {"n": 0, "top1": 0, "topk": 0}
        bucket = per_type[kind]
        bucket["n"] += 1

        if ranked and ranked[0] == expected:
            hits_at_1 += 1
            bucket["top1"] += 1

        if expected not in ranked:
            # Keep the miss for the failure sample in the report.
            misses.append({
                "truth": expected,
                "query": entry["audio_path"],
                "type": kind,
                "preds": ranked,
            })
            continue
        hits_at_k += 1
        bucket["topk"] += 1

    total = len(queries)
    breakdown = {}
    for kind, bucket in per_type.items():
        count = bucket["n"]
        breakdown[kind] = {
            "n": count,
            "top1": round(bucket["top1"] / count, 4) if count else 0.0,
            "topk": round(bucket["topk"] / count, 4) if count else 0.0,
        }

    report = {
        "split": args.split,
        "num_queries": total,
        "top1": round(hits_at_1 / total, 4),
        "topk": round(hits_at_k / total, 4),
        "by_type": breakdown,
        "sample_failures": misses[:10],
    }
    print(json.dumps(report, ensure_ascii=False, indent=2))
90
91
92 if __name__ == "__main__":
93 main()
...@@ -31,7 +31,7 @@ def build_chroma_index(data_dir: Path, output_dir: Path): ...@@ -31,7 +31,7 @@ def build_chroma_index(data_dir: Path, output_dir: Path):
31 matcher = ChromaprintMatcher() 31 matcher = ChromaprintMatcher()
32 matcher.index_songs_from_dir( 32 matcher.index_songs_from_dir(
33 songs_dir=str(data_dir / 'songs'), 33 songs_dir=str(data_dir / 'songs'),
34 metadata_path=str(data_dir / 'train.json'), 34 metadata_path=str(data_dir / 'catalog.json' if (data_dir / 'catalog.json').exists() else data_dir / 'train.json'),
35 cache_path=str(output_dir / 'chromaprint.pkl'), 35 cache_path=str(output_dir / 'chromaprint.pkl'),
36 ) 36 )
37 print(f"[done] chromaprint index built: hashes={matcher.num_hashes}, postings={matcher.index_size}") 37 print(f"[done] chromaprint index built: hashes={matcher.num_hashes}, postings={matcher.index_size}")
...@@ -42,7 +42,7 @@ def build_embedding_index(data_dir: Path, model_path: Path, output_prefix: Path, ...@@ -42,7 +42,7 @@ def build_embedding_index(data_dir: Path, model_path: Path, output_prefix: Path,
42 embedder = ECAPAEmbedder(model_path=str(model_path), device=device) 42 embedder = ECAPAEmbedder(model_path=str(model_path), device=device)
43 ref_embs, ref_ids = embedder.build_reference_index( 43 ref_embs, ref_ids = embedder.build_reference_index(
44 songs_dir=str(data_dir / 'songs'), 44 songs_dir=str(data_dir / 'songs'),
45 metadata_path=str(data_dir / 'train.json'), 45 metadata_path=str(data_dir / 'catalog.json' if (data_dir / 'catalog.json').exists() else data_dir / 'train.json'),
46 output_path=str(output_prefix), 46 output_path=str(output_prefix),
47 ) 47 )
48 print(f"[done] embedding index built: {len(ref_ids)} refs") 48 print(f"[done] embedding index built: {len(ref_ids)} refs")
...@@ -104,16 +104,20 @@ def cmd_full_demo(args): ...@@ -104,16 +104,20 @@ def cmd_full_demo(args):
104 104
105 model_path = model_dir / 'best_model.pt' 105 model_path = model_dir / 'best_model.pt'
106 if not model_path.exists(): 106 if not model_path.exists():
107 raise SystemExit( 107 import subprocess
108 'full-demo requires a trained model at data/models/best_model.pt. '\ 108 model_dir.mkdir(parents=True, exist_ok=True)
109 'Run train.py first or provide one.' 109 cmd = [
110 ) 110 '/usr/local/miniconda3/bin/python', 'train.py',
111 '--data', str(data_dir), '--output', str(model_dir),
112 '--device', args.device, '--epochs', '3', '--batch-size', '8'
113 ]
114 print('[full-demo] training model:', ' '.join(cmd))
115 subprocess.run(cmd, check=True)
111 116
112 index_dir.mkdir(parents=True, exist_ok=True) 117 index_dir.mkdir(parents=True, exist_ok=True)
113 matcher = build_chroma_index(data_dir, index_dir) 118 matcher = build_chroma_index(data_dir, index_dir)
114 embedder, ref_embs, ref_ids = build_embedding_index(data_dir, model_path, index_dir / 'reference', args.device) 119 embedder, ref_embs, ref_ids = build_embedding_index(data_dir, model_path, index_dir / 'reference', args.device)
115 120
116 query = sorted((data_dir / 'test.json').read_text() and [] )
117 with open(data_dir / 'test.json') as f: 121 with open(data_dir / 'test.json') as f:
118 test_meta = json.load(f) 122 test_meta = json.load(f)
119 query_item = next((x for x in test_meta if 'segments/' in x['audio_path']), test_meta[0]) 123 query_item = next((x for x in test_meta if 'segments/' in x['audio_path']), test_meta[0])
......
1 #!/usr/bin/env python3
2 """Helpers for optional open music dataset integration."""
3
4 import argparse
5 import json
6 from pathlib import Path
7
8 DATASETS = {
9 "fma_small": {
10 "url": "https://github.com/mdeff/fma",
11 "notes": "Use FMA small subset first; convert clips into catalog/query JSON for local experiments.",
12 },
13 "mtg_jamendo": {
14 "url": "https://github.com/MTG/mtg-jamendo-dataset",
15 "notes": "Use upstream download scripts; sample a small subset into catalog/query structure.",
16 },
17 }
18
19
def main():
    """Write the integration note for one known dataset to a JSON file.

    The positional ``dataset`` argument must be a key of ``DATASETS``;
    ``--output`` selects the destination file, whose parent directories are
    created when missing.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("dataset", choices=sorted(DATASETS))
    cli.add_argument("--output", default="../docs/open-datasets.json")
    opts = cli.parse_args()

    target = Path(opts.output)
    target.parent.mkdir(parents=True, exist_ok=True)

    with open(target, "w") as handle:
        note = {opts.dataset: DATASETS[opts.dataset]}
        json.dump(note, handle, indent=2)

    print(f"Wrote dataset integration note to {target}")
30
31
32 if __name__ == "__main__":
33 main()
1 import torch 1 import json
2 from torch.utils.data import Dataset
3 import numpy as np
4 import librosa
5 import random 2 import random
6 from pathlib import Path 3 from pathlib import Path
7 from typing import Dict, List, Tuple 4 from typing import Dict, List, Optional
8 import json 5
9 import os 6 import librosa
7 import numpy as np
8 import torch
9 from torch.utils.data import Dataset
10 10
11 11
12 class ACRDataset(Dataset): 12 class ACRDataset(Dataset):
...@@ -21,6 +21,8 @@ class ACRDataset(Dataset): ...@@ -21,6 +21,8 @@ class ACRDataset(Dataset):
21 segment_dur: float = 5.0, 21 segment_dur: float = 5.0,
22 augment: bool = True, 22 augment: bool = True,
23 n_crops_per_song: int = 4, 23 n_crops_per_song: int = 4,
24 song_to_idx: Optional[Dict[str, int]] = None,
25 references_only: bool = False,
24 ): 26 ):
25 self.sr = sr 27 self.sr = sr
26 self.n_mels = n_mels 28 self.n_mels = n_mels
...@@ -31,36 +33,39 @@ class ACRDataset(Dataset): ...@@ -31,36 +33,39 @@ class ACRDataset(Dataset):
31 self.n_crops = n_crops_per_song 33 self.n_crops = n_crops_per_song
32 self.data_dir = Path(data_dir) 34 self.data_dir = Path(data_dir)
33 35
34 meta_path = Path(data_dir) / f"{split}.json" 36 meta_path = self.data_dir / f"{split}.json"
35 with open(meta_path) as f: 37 with open(meta_path) as f:
36 self.metadata = json.load(f) 38 self.metadata = json.load(f)
37 39
38 self.samples = [] 40 self.samples = []
39 for item in self.metadata: 41 for item in self.metadata:
40 song_path = Path(data_dir) / item["audio_path"] 42 if references_only and item.get("type") != "reference":
43 continue
44 song_path = self.data_dir / item["audio_path"]
41 if song_path.exists(): 45 if song_path.exists():
42 self.samples.append(item) 46 self.samples.append(item)
47
43 self.song_ids = sorted(set(s["song_id"] for s in self.samples)) 48 self.song_ids = sorted(set(s["song_id"] for s in self.samples))
44 self.song_to_idx = {sid: i for i, sid in enumerate(self.song_ids)} 49 self.song_to_idx = song_to_idx or {sid: i for i, sid in enumerate(self.song_ids)}
45 50
46 def __len__(self): 51 def __len__(self):
47 return len(self.samples) * self.n_crops 52 return len(self.samples) * self.n_crops
48 53
49 def _load_segment(self, path: str, offset: float, duration: float) -> np.ndarray: 54 def _load_segment(self, path: str, offset: float, duration: float) -> np.ndarray:
50 y, _ = librosa.load( 55 y, _ = librosa.load(path, sr=self.sr, mono=True, offset=offset, duration=duration)
51 path, sr=self.sr, mono=True,
52 offset=offset, duration=duration
53 )
54 if len(y) < self.segment_len: 56 if len(y) < self.segment_len:
55 y = np.pad(y, (0, self.segment_len - len(y))) 57 y = np.pad(y, (0, self.segment_len - len(y)))
56 else: 58 else:
57 y = y[:self.segment_len] 59 y = y[: self.segment_len]
58 return y 60 return y
59 61
60 def _to_mel(self, y: np.ndarray) -> np.ndarray: 62 def _to_mel(self, y: np.ndarray) -> np.ndarray:
61 mel = librosa.feature.melspectrogram( 63 mel = librosa.feature.melspectrogram(
62 y=y, sr=self.sr, n_mels=self.n_mels, 64 y=y,
63 n_fft=self.n_fft, hop_length=self.hop_length 65 sr=self.sr,
66 n_mels=self.n_mels,
67 n_fft=self.n_fft,
68 hop_length=self.hop_length,
64 ) 69 )
65 return librosa.power_to_db(mel, ref=np.max) 70 return librosa.power_to_db(mel, ref=np.max)
66 71
...@@ -73,7 +78,7 @@ class ACRDataset(Dataset): ...@@ -73,7 +78,7 @@ class ACRDataset(Dataset):
73 audio_path = self.data_dir / sample["audio_path"] 78 audio_path = self.data_dir / sample["audio_path"]
74 y = self._load_segment(str(audio_path), offset, 5.0) 79 y = self._load_segment(str(audio_path), offset, 5.0)
75 80
76 if self.augment: 81 if self.augment and sample.get("type") != "reference":
77 from src.utils.augment import AugmentPipeline 82 from src.utils.augment import AugmentPipeline
78 aug = AugmentPipeline(self.sr) 83 aug = AugmentPipeline(self.sr)
79 y = aug(y) 84 y = aug(y)
...@@ -88,6 +93,7 @@ class ACRDataset(Dataset): ...@@ -88,6 +93,7 @@ class ACRDataset(Dataset):
88 "mel": mel_tensor, 93 "mel": mel_tensor,
89 "song_id": torch.tensor(class_id, dtype=torch.long), 94 "song_id": torch.tensor(class_id, dtype=torch.long),
90 "song_name": song_id, 95 "song_name": song_id,
96 "type": sample.get("type", "unknown"),
91 } 97 }
92 98
93 99
...@@ -100,6 +106,7 @@ class ACRTestDataset(Dataset): ...@@ -100,6 +106,7 @@ class ACRTestDataset(Dataset):
100 n_mels: int = 80, 106 n_mels: int = 80,
101 n_fft: int = 512, 107 n_fft: int = 512,
102 hop_length: int = 160, 108 hop_length: int = 160,
109 song_to_idx: Optional[Dict[str, int]] = None,
103 ): 110 ):
104 self.sr = sr 111 self.sr = sr
105 self.n_mels = n_mels 112 self.n_mels = n_mels
...@@ -107,18 +114,18 @@ class ACRTestDataset(Dataset): ...@@ -107,18 +114,18 @@ class ACRTestDataset(Dataset):
107 self.hop_length = hop_length 114 self.hop_length = hop_length
108 self.data_dir = Path(data_dir) 115 self.data_dir = Path(data_dir)
109 116
110 meta_path = Path(data_dir) / f"{split}.json" 117 meta_path = self.data_dir / f"{split}.json"
111 with open(meta_path) as f: 118 with open(meta_path) as f:
112 self.metadata = json.load(f) 119 self.metadata = json.load(f)
113 120
114 self.samples = [] 121 self.samples = []
115 for item in self.metadata: 122 for item in self.metadata:
116 p = Path(data_dir) / item["audio_path"] 123 p = self.data_dir / item["audio_path"]
117 if p.exists(): 124 if p.exists():
118 self.samples.append(item) 125 self.samples.append(item)
119 126
120 self.song_ids = sorted(set(s["song_id"] for s in self.samples)) 127 self.song_ids = sorted(set(s["song_id"] for s in self.samples))
121 self.song_to_idx = {sid: i for i, sid in enumerate(self.song_ids)} 128 self.song_to_idx = song_to_idx or {sid: i for i, sid in enumerate(self.song_ids)}
122 129
123 def __len__(self): 130 def __len__(self):
124 return len(self.samples) 131 return len(self.samples)
...@@ -126,10 +133,7 @@ class ACRTestDataset(Dataset): ...@@ -126,10 +133,7 @@ class ACRTestDataset(Dataset):
126 def __getitem__(self, idx): 133 def __getitem__(self, idx):
127 sample = self.samples[idx] 134 sample = self.samples[idx]
128 audio_path = self.data_dir / sample["audio_path"] 135 audio_path = self.data_dir / sample["audio_path"]
129 y, _ = librosa.load( 136 y, _ = librosa.load(str(audio_path), sr=self.sr, mono=True, offset=0, duration=min(sample["duration"], 5.0))
130 str(audio_path), sr=self.sr, mono=True,
131 offset=0, duration=min(sample["duration"], 5.0)
132 )
133 seg_len = 5 * self.sr 137 seg_len = 5 * self.sr
134 if len(y) < seg_len: 138 if len(y) < seg_len:
135 y = np.pad(y, (0, seg_len - len(y))) 139 y = np.pad(y, (0, seg_len - len(y)))
...@@ -137,13 +141,100 @@ class ACRTestDataset(Dataset): ...@@ -137,13 +141,100 @@ class ACRTestDataset(Dataset):
137 y = y[:seg_len] 141 y = y[:seg_len]
138 142
139 mel = librosa.power_to_db( 143 mel = librosa.power_to_db(
140 librosa.feature.melspectrogram(y=y, sr=self.sr, n_mels=self.n_mels, 144 librosa.feature.melspectrogram(
141 n_fft=self.n_fft, hop_length=self.hop_length), 145 y=y,
142 ref=np.max 146 sr=self.sr,
147 n_mels=self.n_mels,
148 n_fft=self.n_fft,
149 hop_length=self.hop_length,
150 ),
151 ref=np.max,
143 ) 152 )
144 class_id = self.song_to_idx[sample["song_id"]] 153 class_id = self.song_to_idx[sample["song_id"]]
145 return { 154 return {
146 "mel": torch.FloatTensor(mel), 155 "mel": torch.FloatTensor(mel),
147 "song_id": torch.tensor(class_id, dtype=torch.long), 156 "song_id": torch.tensor(class_id, dtype=torch.long),
148 "song_name": sample["song_id"], 157 "song_name": sample["song_id"],
158 "type": sample.get("type", "unknown"),
159 }
160
161
class SongPairDataset(Dataset):
    """Dataset that yields two mel spectrograms of the same song per item.

    Items are indexed by song: every non-reference metadata entry whose audio
    file exists is grouped under its ``song_id``, and ``__getitem__`` draws two
    of a song's clips (or the same clip twice when only one is available).

    Args:
        data_dir: Directory holding ``<split>.json`` and the audio files it
            references through relative ``audio_path`` values.
        split: Name of the metadata file to read (without ``.json``).
        sr: Sample rate the audio is loaded at.
        n_mels: Number of mel bands.
        n_fft: FFT size for the mel spectrogram.
        hop_length: Hop length for the mel spectrogram.
        segment_dur: Clip length in seconds; clips are padded or cut to it.
        augment: When true, each clip goes through ``AugmentPipeline``.
    """

    def __init__(
        self,
        data_dir: str,
        split: str = "train",
        sr: int = 16000,
        n_mels: int = 80,
        n_fft: int = 512,
        hop_length: int = 160,
        segment_dur: float = 5.0,
        augment: bool = True,
    ):
        self.sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        # Kept so that loading reads exactly as much audio as segment_len needs.
        self.segment_dur = segment_dur
        self.segment_len = int(segment_dur * sr)
        self.augment = augment
        self.data_dir = Path(data_dir)

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(self.data_dir / f"{split}.json", encoding="utf-8") as f:
            metadata = json.load(f)

        self.by_song: Dict[str, List[Dict]] = {}
        for item in metadata:
            # Reference entries are excluded; only the other entries form pairs.
            if item.get("type") == "reference":
                continue
            p = self.data_dir / item["audio_path"]
            if p.exists():
                self.by_song.setdefault(item["song_id"], []).append(item)

        self.song_ids = sorted(self.by_song)
        self.song_to_idx = {sid: i for i, sid in enumerate(self.song_ids)}

    def __len__(self):
        """Return the number of songs that have at least one usable clip."""
        return len(self.song_ids)

    def _load_clip(self, sample: Dict) -> np.ndarray:
        """Load one clip and pad or cut it to exactly ``segment_len`` samples."""
        path = self.data_dir / sample["audio_path"]
        # Read segment_dur seconds (previously a fixed 5.0 s), so a longer
        # configured segment is filled with audio instead of zero padding.
        y, _ = librosa.load(str(path), sr=self.sr, mono=True, duration=self.segment_dur)
        if len(y) < self.segment_len:
            y = np.pad(y, (0, self.segment_len - len(y)))
        else:
            y = y[: self.segment_len]
        return y

    def _to_mel(self, y: np.ndarray) -> torch.Tensor:
        """Convert a waveform into a dB-scaled mel spectrogram tensor."""
        mel = librosa.feature.melspectrogram(
            y=y,
            sr=self.sr,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
        )
        mel = librosa.power_to_db(mel, ref=np.max)
        return torch.FloatTensor(mel)

    def __getitem__(self, idx):
        """Return a stacked pair of mels for song ``idx`` with its label twice."""
        song_id = self.song_ids[idx]
        choices = self.by_song[song_id]
        if len(choices) == 1:
            # Only one clip: use it twice (augmentation, if on, runs per copy).
            a = b = choices[0]
        else:
            a, b = random.sample(choices, 2)

        wavs = []
        for sample in (a, b):
            y = self._load_clip(sample)
            if self.augment:
                # Imported lazily so the module loads without the augment code.
                from src.utils.augment import AugmentPipeline
                y = AugmentPipeline(self.sr)(y)
            wavs.append(self._to_mel(y))

        label = self.song_to_idx[song_id]
        return {
            "mel": torch.stack(wavs, dim=0),
            "song_id": torch.tensor([label, label], dtype=torch.long),
            "song_name": song_id,
        }
......
...@@ -5,6 +5,7 @@ Generates melodies from fundamental frequencies, simulates: ...@@ -5,6 +5,7 @@ Generates melodies from fundamental frequencies, simulates:
5 - Different "songs" (unique note sequences at different base frequencies) 5 - Different "songs" (unique note sequences at different base frequencies)
6 - Song fragments (random crops from songs) 6 - Song fragments (random crops from songs)
7 - Humming variants (pitch shifted, time stretched versions) 7 - Humming variants (pitch shifted, time stretched versions)
8 - Hard negatives / confusing variants for robustness testing
8 9
9 This allows the full pipeline to be validated without external data. 10 This allows the full pipeline to be validated without external data.
10 """ 11 """
...@@ -13,9 +14,8 @@ import numpy as np ...@@ -13,9 +14,8 @@ import numpy as np
13 import soundfile as sf 14 import soundfile as sf
14 import json 15 import json
15 import random 16 import random
16 import os
17 from pathlib import Path 17 from pathlib import Path
18 from typing import List, Tuple 18 from typing import Tuple
19 from tqdm import tqdm 19 from tqdm import tqdm
20 20
21 21
...@@ -33,7 +33,10 @@ def harmonic_tone(freq: float, duration: float, sr: int = _SR, n_harmonics: int ...@@ -33,7 +33,10 @@ def harmonic_tone(freq: float, duration: float, sr: int = _SR, n_harmonics: int
33 for h in range(1, n_harmonics + 1): 33 for h in range(1, n_harmonics + 1):
34 amp = 0.5 / h 34 amp = 0.5 / h
35 y += amp * np.sin(2 * np.pi * freq * h * t) 35 y += amp * np.sin(2 * np.pi * freq * h * t)
36 return y / np.max(np.abs(y)) * 0.5 36 peak = np.max(np.abs(y))
37 if peak > 0:
38 y = y / peak * 0.5
39 return y
37 40
38 41
39 def generate_melody( 42 def generate_melody(
...@@ -44,9 +47,8 @@ def generate_melody( ...@@ -44,9 +47,8 @@ def generate_melody(
44 timbre: str = "harmonic", 47 timbre: str = "harmonic",
45 ) -> np.ndarray: 48 ) -> np.ndarray:
46 notes = [] 49 notes = []
47 freq = base_freq 50 for _ in range(note_count):
48 for i in range(note_count): 51 interval = random.choice([0, 2, 4, 5, 7, 9, 11, 12])
49 interval = random.choice([0, 2, 4, 5, 7, 9, 11, 12]) # diatonic intervals
50 freq = base_freq * (2 ** (interval / 12)) 52 freq = base_freq * (2 ** (interval / 12))
51 dur = note_dur * random.uniform(0.8, 1.2) 53 dur = note_dur * random.uniform(0.8, 1.2)
52 54
...@@ -57,7 +59,7 @@ def generate_melody( ...@@ -57,7 +59,7 @@ def generate_melody(
57 59
58 if random.random() < 0.15: 60 if random.random() < 0.15:
59 fade = np.linspace(0, 1, min(int(sr * 0.02), len(note))) 61 fade = np.linspace(0, 1, min(int(sr * 0.02), len(note)))
60 note[:len(fade)] *= fade 62 note[: len(fade)] *= fade
61 63
62 notes.append(note) 64 notes.append(note)
63 65
...@@ -65,15 +67,35 @@ def generate_melody( ...@@ -65,15 +67,35 @@ def generate_melody(
65 67
66 68
67 _CHORD_PROGRESSIONS = [ 69 _CHORD_PROGRESSIONS = [
68 [0, 3, 7], # Cm 70 [0, 3, 7],
69 [0, 4, 7], # C 71 [0, 4, 7],
70 [0, 3, 7, 10], # Cm7 72 [0, 3, 7, 10],
71 [0, 4, 7, 11], # Cmaj7 73 [0, 4, 7, 11],
72 [0, 4, 9], # Csus4 → C 74 [0, 4, 9],
73 [0, 5, 7], # Csus2 75 [0, 5, 7],
74 ] 76 ]
75 77
76 78
def apply_confusion_mix(y: np.ndarray, sr: int, strength: float = 0.22) -> np.ndarray:
    """Mix a three-tone distractor into ``y`` and renormalise the peak to 0.5.

    The distractor is the sum of 220 Hz, 330 Hz and 440 Hz sines, each with a
    random phase drawn from ``[0, pi]``, scaled to unit peak and added at
    ``strength``.

    Args:
        y: Input waveform.
        sr: Sample rate of ``y`` in Hz.
        strength: Gain applied to the unit-peak distractor before mixing.

    Returns:
        The mixed waveform with peak amplitude 0.5, or the unscaled mix when
        it is all zeros. An empty input yields an empty float array.
    """
    if len(y) == 0:
        # Nothing to mix; the peak reductions below reject zero-size arrays.
        return np.array(y, dtype=float)

    t = np.linspace(0, len(y) / sr, len(y), endpoint=False)
    distractor = np.zeros(len(y))
    for f in (220.0, 330.0, 440.0):
        distractor += np.sin(2 * np.pi * f * t + random.uniform(0, np.pi))
    # The floor avoids dividing by zero if the tones cancel completely.
    distractor /= max(np.max(np.abs(distractor)), 1e-8)

    mixed = y + strength * distractor
    peak = np.max(np.abs(mixed))
    return mixed / peak * 0.5 if peak > 0 else mixed
88
89
def apply_humming_style(y: np.ndarray, sr: int) -> np.ndarray:
    """Smooth ``y`` with a rising gain ramp and a moving average.

    A linear gain ramp from 0.7 to 1.0 is applied, the result is smoothed by a
    box filter ``max(5, sr // 400)`` samples wide, and the peak is rescaled
    to 0.5.

    Args:
        y: Input waveform.
        sr: Sample rate of ``y`` in Hz; sets the smoothing width.

    Returns:
        A waveform with the same number of samples as ``y``. The peak is 0.5
        unless the smoothed signal is all zeros, which is returned unscaled.
        An empty input yields an empty float array.
    """
    n = len(y)
    if n == 0:
        # np.convolve rejects empty operands.
        return np.array(y, dtype=float)

    env = np.linspace(0.7, 1.0, n)
    hum = y * env

    width = max(5, sr // 400)
    kernel = np.ones(width) / width
    if n >= width:
        hum = np.convolve(hum, kernel, mode="same")
    else:
        # mode="same" returns max(n, width) samples, which would lengthen a
        # clip shorter than the kernel; take the centred n samples of the
        # full convolution so the output length always matches the input.
        start = (width - 1) // 2
        hum = np.convolve(hum, kernel, mode="full")[start : start + n]

    peak = np.max(np.abs(hum))
    return hum / peak * 0.5 if peak > 0 else hum
97
98
77 def generate_song( 99 def generate_song(
78 song_id: str, 100 song_id: str,
79 base_freq: float, 101 base_freq: float,
...@@ -98,7 +120,7 @@ def generate_song( ...@@ -98,7 +120,7 @@ def generate_song(
98 env = np.exp(-np.linspace(0, 3, seg_len)) 120 env = np.exp(-np.linspace(0, 3, seg_len))
99 note = harmonic_tone(freq, seg_len / sr, sr) * env * 0.3 121 note = harmonic_tone(freq, seg_len / sr, sr) * env * 0.3
100 min_len = min(seg_len, len(note)) 122 min_len = min(seg_len, len(note))
101 y[start_sample:start_sample + min_len] += note[:min_len] 123 y[start_sample : start_sample + min_len] += note[:min_len]
102 124
103 if with_vocals: 125 if with_vocals:
104 melody = generate_melody(base_freq * 2, note_count=int(duration * 2), note_dur=0.5, sr=sr) 126 melody = generate_melody(base_freq * 2, note_count=int(duration * 2), note_dur=0.5, sr=sr)
...@@ -130,9 +152,11 @@ def generate_dataset( ...@@ -130,9 +152,11 @@ def generate_dataset(
130 songs_dir.mkdir(parents=True, exist_ok=True) 152 songs_dir.mkdir(parents=True, exist_ok=True)
131 segs_dir.mkdir(parents=True, exist_ok=True) 153 segs_dir.mkdir(parents=True, exist_ok=True)
132 154
133 base_freqs = [130.81, 146.83, 164.81, 174.61, 196.0, 220.0, 246.94, 155 base_freqs = [
134 261.63, 293.66, 329.63, 349.23, 392.0, 440.0, 493.88, 156 130.81, 146.83, 164.81, 174.61, 196.0, 220.0, 246.94,
135 523.25, 587.33, 659.25, 698.46, 783.99, 880.0, 987.77] 157 261.63, 293.66, 329.63, 349.23, 392.0, 440.0, 493.88,
158 523.25, 587.33, 659.25, 698.46, 783.99, 880.0, 987.77,
159 ]
136 160
137 train_meta = [] 161 train_meta = []
138 val_meta = [] 162 val_meta = []
...@@ -143,7 +167,7 @@ def generate_dataset( ...@@ -143,7 +167,7 @@ def generate_dataset(
143 song_id = f"song_{i:04d}" 167 song_id = f"song_{i:04d}"
144 base_freq = base_freqs[i % len(base_freqs)] 168 base_freq = base_freqs[i % len(base_freqs)]
145 key_offset = (i // len(base_freqs)) * 2 169 key_offset = (i // len(base_freqs)) * 2
146 base_freq *= (2 ** (key_offset / 12)) 170 base_freq *= 2 ** (key_offset / 12)
147 171
148 y, dur = generate_song(song_id, base_freq, duration=song_duration, sr=sr) 172 y, dur = generate_song(song_id, base_freq, duration=song_duration, sr=sr)
149 song_path = songs_dir / f"{song_id}.wav" 173 song_path = songs_dir / f"{song_id}.wav"
...@@ -155,42 +179,41 @@ def generate_dataset( ...@@ -155,42 +179,41 @@ def generate_dataset(
155 start_s = int(offset * sr) 179 start_s = int(offset * sr)
156 end_s = start_s + int(segment_duration * sr) 180 end_s = start_s + int(segment_duration * sr)
157 seg = y[start_s:end_s] 181 seg = y[start_s:end_s]
182 target_len = int(segment_duration * sr)
158 183
159 if len(seg) < int(segment_duration * sr): 184 if len(seg) < target_len:
160 seg = np.pad(seg, (0, int(segment_duration * sr) - len(seg))) 185 seg = np.pad(seg, (0, target_len - len(seg)))
161
162 is_augmented = (j >= num_segments_per_song // 2)
163 186
164 if is_augmented: 187 variant_type = "clean"
188 out_seg = seg.copy()
189 if j >= num_segments_per_song // 2:
165 from src.utils.augment import AugmentPipeline 190 from src.utils.augment import AugmentPipeline
166 aug = AugmentPipeline(sr) 191 aug = AugmentPipeline(sr)
167 seg_aug = aug(seg.copy()) 192 out_seg = aug(out_seg)
168 seg_name = f"{song_id}_seg_{j:02d}_aug.wav" 193 variant_type = "augmented"
169 seg_path = segs_dir / seg_name 194
170 sf.write(str(seg_path), seg_aug, sr) 195 if j == num_segments_per_song - 1:
171 meta_entry = { 196 out_seg = apply_confusion_mix(out_seg, sr)
172 "song_id": song_id, 197 variant_type = "confused"
173 "audio_path": f"segments/{seg_name}", 198 elif j == num_segments_per_song - 2 and num_segments_per_song >= 4:
174 "duration": segment_duration, 199 out_seg = apply_humming_style(out_seg, sr)
175 "type": "augmented", 200 variant_type = "humming_like"
176 "offset": offset, 201
177 } 202 seg_name = f"{song_id}_seg_{j:02d}_{variant_type}.wav" if variant_type != "clean" else f"{song_id}_seg_{j:02d}.wav"
178 else: 203 seg_path = segs_dir / seg_name
179 seg_name = f"{song_id}_seg_{j:02d}.wav" 204 sf.write(str(seg_path), out_seg, sr)
180 seg_path = segs_dir / seg_name 205
181 sf.write(str(seg_path), seg, sr) 206 meta_entry = {
182 meta_entry = { 207 "song_id": song_id,
183 "song_id": song_id, 208 "audio_path": f"segments/{seg_name}",
184 "audio_path": f"segments/{seg_name}", 209 "duration": segment_duration,
185 "duration": segment_duration, 210 "type": variant_type,
186 "type": "clean", 211 "offset": offset,
187 "offset": offset, 212 }
188 } 213
189 214 if offset < dur * 0.2:
190 offset_sec = offset
191 if offset_sec < dur * 0.2:
192 seg_type = "intro" 215 seg_type = "intro"
193 elif offset_sec > dur * 0.7: 216 elif offset > dur * 0.7:
194 seg_type = "outro" 217 seg_type = "outro"
195 else: 218 else:
196 seg_type = "mid" 219 seg_type = "mid"
...@@ -208,6 +231,7 @@ def generate_dataset( ...@@ -208,6 +231,7 @@ def generate_dataset(
208 "audio_path": f"songs/{song_id}.wav", 231 "audio_path": f"songs/{song_id}.wav",
209 "duration": dur, 232 "duration": dur,
210 "base_freq": base_freq, 233 "base_freq": base_freq,
234 "type": "reference",
211 } 235 }
212 if i < int(num_songs * 0.7): 236 if i < int(num_songs * 0.7):
213 train_meta.append(song_meta) 237 train_meta.append(song_meta)
...@@ -216,6 +240,10 @@ def generate_dataset( ...@@ -216,6 +240,10 @@ def generate_dataset(
216 else: 240 else:
217 test_meta.append(song_meta) 241 test_meta.append(song_meta)
218 242
243 catalog_meta = [item for item in train_meta + val_meta + test_meta if item.get("type") == "reference"]
244 with open(output_dir / "catalog.json", "w") as f:
245 json.dump(catalog_meta, f, indent=2)
246
219 for name, data in [("train", train_meta), ("val", val_meta), ("test", test_meta)]: 247 for name, data in [("train", train_meta), ("val", val_meta), ("test", test_meta)]:
220 with open(output_dir / f"{name}.json", "w") as f: 248 with open(output_dir / f"{name}.json", "w") as f:
221 json.dump(data, f, indent=2) 249 json.dump(data, f, indent=2)
...@@ -229,6 +257,7 @@ def generate_dataset( ...@@ -229,6 +257,7 @@ def generate_dataset(
229 257
230 if __name__ == "__main__": 258 if __name__ == "__main__":
231 import argparse 259 import argparse
260
232 parser = argparse.ArgumentParser() 261 parser = argparse.ArgumentParser()
233 parser.add_argument("--output", type=str, default="data/synthetic") 262 parser.add_argument("--output", type=str, default="data/synthetic")
234 parser.add_argument("--num-songs", type=int, default=50) 263 parser.add_argument("--num-songs", type=int, default=50)
......
1 import torch 1 import json
2 import torch.nn.functional as F
3 import numpy as np
4 import librosa
5 from pathlib import Path 2 from pathlib import Path
6 from typing import List, Optional, Tuple 3 from typing import List, Optional, Tuple
7 import json 4
5 import librosa
6 import numpy as np
7 import torch
8 8
9 9
10 class ECAPAEmbedder: 10 class ECAPAEmbedder:
...@@ -24,11 +24,22 @@ class ECAPAEmbedder: ...@@ -24,11 +24,22 @@ class ECAPAEmbedder:
24 self.hop_length = hop_length 24 self.hop_length = hop_length
25 25
26 from src.models.ecapa_tdnn import ECAPA_ACR 26 from src.models.ecapa_tdnn import ECAPA_ACR
27 self.model = ECAPA_ACR(n_mels=n_mels, embed_dim=192) 27
28 state = torch.load(model_path, map_location="cpu", weights_only=True) 28 state = torch.load(model_path, map_location="cpu", weights_only=True)
29 if "model_state_dict" in state: 29 cfg = state.get("config", {})
30 state = state["model_state_dict"] 30 model_cfg = cfg.get("model", {})
31 self.model.load_state_dict(state, strict=False) 31 self.model = ECAPA_ACR(
32 n_mels=model_cfg.get("n_mels", n_mels),
33 embed_dim=model_cfg.get("embed_dim", 192),
34 channels=model_cfg.get("channels", 512),
35 se_channels=model_cfg.get("se_channels", 128),
36 res2net_scale=model_cfg.get("res2net_scale", 8),
37 num_blocks=model_cfg.get("num_blocks", 3),
38 num_classes=None,
39 )
40 missing = self.model.load_state_dict(state["model_state_dict"], strict=False)
41 if missing.unexpected_keys:
42 print(f"[warn] unexpected keys while loading model: {missing.unexpected_keys}")
32 self.model.to(self.device) 43 self.model.to(self.device)
33 self.model.eval() 44 self.model.eval()
34 45
...@@ -38,26 +49,37 @@ class ECAPAEmbedder: ...@@ -38,26 +49,37 @@ class ECAPAEmbedder:
38 49
39 def _to_mel(self, y: np.ndarray) -> torch.Tensor: 50 def _to_mel(self, y: np.ndarray) -> torch.Tensor:
40 mel = librosa.feature.melspectrogram( 51 mel = librosa.feature.melspectrogram(
41 y=y, sr=self.sr, n_mels=self.n_mels, 52 y=y,
42 n_fft=self.n_fft, hop_length=self.hop_length 53 sr=self.sr,
54 n_mels=self.n_mels,
55 n_fft=self.n_fft,
56 hop_length=self.hop_length,
43 ) 57 )
44 mel = librosa.power_to_db(mel, ref=np.max) 58 mel = librosa.power_to_db(mel, ref=np.max)
45 return torch.FloatTensor(mel).unsqueeze(0) 59 return torch.FloatTensor(mel).unsqueeze(0)
46 60
def _windows(self, y: np.ndarray, window_sec: float = 5.0, stride_sec: float = 2.5) -> List[np.ndarray]:
    """Cut ``y`` into fixed-length windows taken every ``stride_sec`` seconds.

    A waveform shorter than one window is zero-padded at the end to exactly
    one window. Only full-length windows are produced, so at least one window
    is always returned.
    """
    size = int(window_sec * self.sr)
    hop = int(stride_sec * self.sr)

    shortfall = size - len(y)
    if shortfall > 0:
        y = np.pad(y, (0, shortfall))

    # One past the last start offset that still leaves room for a full window.
    start_bound = max(len(y) - size + 1, 1)
    chunks = [y[begin : begin + size] for begin in range(0, start_bound, hop)]
    if not chunks:
        chunks = [y[:size]]
    return chunks
70
47 def extract_embedding(self, audio_path: str) -> np.ndarray: 71 def extract_embedding(self, audio_path: str) -> np.ndarray:
48 y = self._load_audio(audio_path) 72 y = self._load_audio(audio_path)
49 mel = self._to_mel(y).to(self.device) 73 return self.extract_embedding_from_wave(y)
50 with torch.no_grad():
51 emb, _ = self.model(mel)
52 return emb.cpu().numpy().flatten()
53 74
54 def extract_embedding_from_wave(self, y: np.ndarray) -> np.ndarray: 75 def extract_embedding_from_wave(self, y: np.ndarray) -> np.ndarray:
55 if len(y) < self.sr: 76 window_embs = []
56 y = np.pad(y, (0, self.sr - len(y))) 77 for seg in self._windows(y):
57 mel = self._to_mel(y[:self.sr * 5]).to(self.device) 78 mel = self._to_mel(seg).to(self.device)
58 with torch.no_grad(): 79 with torch.no_grad():
59 emb, _ = self.model(mel) 80 emb, _ = self.model(mel)
60 return emb.cpu().numpy().flatten() 81 window_embs.append(emb.cpu().numpy().flatten())
82 return np.mean(window_embs, axis=0)
61 83
62 def build_reference_index( 84 def build_reference_index(
63 self, 85 self,
...@@ -75,7 +97,7 @@ class ECAPAEmbedder: ...@@ -75,7 +97,7 @@ class ECAPAEmbedder:
75 songs_dir = Path(songs_dir) 97 songs_dir = Path(songs_dir)
76 98
77 for item in meta: 99 for item in meta:
78 if "songs/" not in item.get("audio_path", ""): 100 if item.get("type") != "reference" and "songs/" not in item.get("audio_path", ""):
79 continue 101 continue
80 audio_path = songs_dir.parent / item["audio_path"] 102 audio_path = songs_dir.parent / item["audio_path"]
81 if not audio_path.exists(): 103 if not audio_path.exists():
...@@ -83,35 +105,20 @@ class ECAPAEmbedder: ...@@ -83,35 +105,20 @@ class ECAPAEmbedder:
83 song_id = item["song_id"] 105 song_id = item["song_id"]
84 y, _ = librosa.load(str(audio_path), sr=self.sr, mono=True) 106 y, _ = librosa.load(str(audio_path), sr=self.sr, mono=True)
85 107
86 win_len = int(window_sec * self.sr) 108 for seg in self._windows(y, window_sec=window_sec, stride_sec=stride_sec):
87 stride = int(stride_sec * self.sr)
88
89 window_embs = []
90 for start in range(0, len(y) - win_len + 1, stride):
91 seg = y[start:start + win_len]
92 mel = self._to_mel(seg).to(self.device) 109 mel = self._to_mel(seg).to(self.device)
93 with torch.no_grad(): 110 with torch.no_grad():
94 emb, _ = self.model(mel) 111 emb, _ = self.model(mel)
95 window_embs.append(emb.cpu().numpy().flatten()) 112 all_embs.append(emb.cpu().numpy().flatten())
96
97 if window_embs:
98 song_emb = np.mean(window_embs, axis=0)
99 all_embs.append(song_emb)
100 all_ids.append(song_id) 113 all_ids.append(song_id)
101 114
102 all_embs = np.vstack(all_embs) 115 all_embs = np.vstack(all_embs)
103 np.save(f"{output_path}_embs.npy", all_embs) 116 np.save(f"{output_path}_embs.npy", all_embs)
104 np.save(f"{output_path}_ids.npy", np.array(all_ids)) 117 np.save(f"{output_path}_ids.npy", np.array(all_ids))
105 print(f"Built reference index: {len(all_ids)} songs, embeddings shape {all_embs.shape}") 118 print(f"Built reference index: {len(all_ids)} windows, embeddings shape {all_embs.shape}")
106 return all_embs, all_ids 119 return all_embs, all_ids
107 120
108 def search( 121 def search(self, query_emb: np.ndarray, ref_embs: np.ndarray, ref_ids: List[str], top_k: int = 10):
109 self,
110 query_emb: np.ndarray,
111 ref_embs: np.ndarray,
112 ref_ids: List[str],
113 top_k: int = 10,
114 ) -> List[Tuple[str, float]]:
115 query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-12) 122 query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-12)
116 ref_norm = ref_embs / (np.linalg.norm(ref_embs, axis=1, keepdims=True) + 1e-12) 123 ref_norm = ref_embs / (np.linalg.norm(ref_embs, axis=1, keepdims=True) + 1e-12)
117 scores = query_norm @ ref_norm.T 124 scores = query_norm @ ref_norm.T
......
...@@ -2,12 +2,12 @@ ...@@ -2,12 +2,12 @@
2 Hybrid ACR Engine: Chromaprint fast pre-filter + ECAPA-TDNN deep re-ranking. 2 Hybrid ACR Engine: Chromaprint fast pre-filter + ECAPA-TDNN deep re-ranking.
3 """ 3 """
4 4
5 import numpy as np
6 import librosa
7 from typing import List, Tuple, Optional, Dict
8 from pathlib import Path
9 import json 5 import json
10 import time 6 import time
7 from typing import Dict, List, Optional
8
9 import librosa
10 import numpy as np
11 11
12 12
13 class Candidate: 13 class Candidate:
...@@ -17,9 +17,8 @@ class Candidate: ...@@ -17,9 +17,8 @@ class Candidate:
17 self.ecapa_score = ecapa_score 17 self.ecapa_score = ecapa_score
18 self.metadata: Dict = {} 18 self.metadata: Dict = {}
19 19
20 @property 20 def combined_score(self, chroma_weight: float, ecapa_weight: float) -> float:
21 def combined_score(self) -> float: 21 return chroma_weight * self.chroma_score + ecapa_weight * self.ecapa_score
22 return 0.3 * self.chroma_score + 0.7 * self.ecapa_score
23 22
24 def __repr__(self): 23 def __repr__(self):
25 return f"Candidate({self.song_id}, chroma={self.chroma_score:.3f}, ecapa={self.ecapa_score:.3f})" 24 return f"Candidate({self.song_id}, chroma={self.chroma_score:.3f}, ecapa={self.ecapa_score:.3f})"
...@@ -33,9 +32,9 @@ class HybridEngine: ...@@ -33,9 +32,9 @@ class HybridEngine:
33 ref_embs: Optional[np.ndarray] = None, 32 ref_embs: Optional[np.ndarray] = None,
34 ref_ids: Optional[List[str]] = None, 33 ref_ids: Optional[List[str]] = None,
35 sr: int = 16000, 34 sr: int = 16000,
36 chroma_weight: float = 0.3, 35 chroma_weight: float = 0.35,
37 ecapa_weight: float = 0.7, 36 ecapa_weight: float = 0.65,
38 reject_threshold: float = 0.4, 37 reject_threshold: float = 0.35,
39 ): 38 ):
40 self.chroma = chroma_matcher 39 self.chroma = chroma_matcher
41 self.ecapa = ecapa_embedder 40 self.ecapa = ecapa_embedder
...@@ -45,7 +44,6 @@ class HybridEngine: ...@@ -45,7 +44,6 @@ class HybridEngine:
45 self.chroma_weight = chroma_weight 44 self.chroma_weight = chroma_weight
46 self.ecapa_weight = ecapa_weight 45 self.ecapa_weight = ecapa_weight
47 self.reject_threshold = reject_threshold 46 self.reject_threshold = reject_threshold
48
49 self.song_metadata: Dict[str, Dict] = {} 47 self.song_metadata: Dict[str, Dict] = {}
50 48
51 def load_metadata(self, metadata_path: str): 49 def load_metadata(self, metadata_path: str):
...@@ -53,75 +51,83 @@ class HybridEngine: ...@@ -53,75 +51,83 @@ class HybridEngine:
53 items = json.load(f) 51 items = json.load(f)
54 for item in items: 52 for item in items:
55 sid = item["song_id"] 53 sid = item["song_id"]
56 if sid not in self.song_metadata: 54 existing = self.song_metadata.get(sid, {})
57 base = item.get("base_freq", 0) 55 if item.get("type") == "reference" or not existing:
58 self.song_metadata[sid] = { 56 self.song_metadata[sid] = {
59 "song_id": sid, 57 "song_id": sid,
60 "base_freq": base, 58 "base_freq": item.get("base_freq", existing.get("base_freq", 0)),
61 "audio_path": item.get("audio_path", ""), 59 "audio_path": item.get("audio_path", existing.get("audio_path", "")),
60 "type": item.get("type", existing.get("type", "unknown")),
62 } 61 }
63 62
63 @staticmethod
64 def _normalize_scores(score_pairs: List[tuple], invert: bool = False) -> Dict[str, float]:
65 if not score_pairs:
66 return {}
67 ids = [sid for sid, _ in score_pairs]
68 values = np.array([float(score) for _, score in score_pairs], dtype=np.float32)
69 if invert:
70 values = -values
71 if len(values) == 1:
72 return {ids[0]: 1.0}
73 vmin = float(values.min())
74 vmax = float(values.max())
75 if abs(vmax - vmin) < 1e-8:
76 return {sid: 1.0 for sid in ids}
77 norm = (values - vmin) / (vmax - vmin)
78 return {sid: float(score) for sid, score in zip(ids, norm)}
79
64 def recognize( 80 def recognize(
65 self, 81 self,
66 audio_path: str, 82 audio_path: str,
67 top_n: int = 5, 83 top_n: int = 5,
68 mode: str = "auto", 84 mode: str = "auto",
69 ) -> List[Dict]: 85 ) -> Dict:
86 del mode
70 start = time.time() 87 start = time.time()
71 y, _ = librosa.load(audio_path, sr=self.sr, mono=True) 88 y, _ = librosa.load(audio_path, sr=self.sr, mono=True)
72 89
73 chroma_candidates: List[Candidate] = [] 90 chroma_matches = self.chroma.match(y, top_k=max(50, top_n * 5)) if self.chroma is not None else []
74 if self.chroma is not None: 91 chroma_norm = self._normalize_scores(chroma_matches)
75 chroma_matches = self.chroma.match(y, top_k=50) 92
76 seen = set() 93 ecapa_matches = []
77 for song_id, score in chroma_matches: 94 if self.ecapa is not None and self.ref_embs is not None and self.ref_ids is not None:
78 if song_id not in seen:
79 seen.add(song_id)
80 c = Candidate(song_id, chroma_score=score)
81 chroma_candidates.append(c)
82
83 ecapa_candidates: List[Candidate] = []
84 if self.ecapa is not None and self.ref_embs is not None:
85 query_emb = self.ecapa.extract_embedding_from_wave(y) 95 query_emb = self.ecapa.extract_embedding_from_wave(y)
86 ref_norm = self.ref_embs / ( 96 ref_norm = self.ref_embs / (np.linalg.norm(self.ref_embs, axis=1, keepdims=True) + 1e-12)
87 np.linalg.norm(self.ref_embs, axis=1, keepdims=True) + 1e-12
88 )
89 query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-12) 97 query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-12)
90 scores = query_norm @ ref_norm.T 98 scores = query_norm @ ref_norm.T
91 top_indices = np.argsort(-scores)[:top_n] 99 top_indices = np.argsort(-scores)[: max(top_n * 5, 20)]
92 for idx in top_indices: 100 ecapa_matches = [(self.ref_ids[idx], float(scores[idx])) for idx in top_indices]
93 c = Candidate(self.ref_ids[idx], ecapa_score=float(scores[idx])) 101 ecapa_norm = self._normalize_scores(ecapa_matches)
94 ecapa_candidates.append(c) 102
95 103 all_song_ids = set(chroma_norm) | set(ecapa_norm)
96 combined: Dict[str, Candidate] = {} 104 combined: List[Candidate] = []
97 for c in chroma_candidates: 105 for song_id in all_song_ids:
98 combined[c.song_id] = c 106 candidate = Candidate(
99 for c in ecapa_candidates: 107 song_id=song_id,
100 if c.song_id in combined: 108 chroma_score=chroma_norm.get(song_id, 0.0),
101 combined[c.song_id].ecapa_score = c.ecapa_score 109 ecapa_score=ecapa_norm.get(song_id, 0.0),
102 else: 110 )
103 combined[c.song_id] = c 111 candidate.metadata = self.song_metadata.get(song_id, {})
104 112 combined.append(candidate)
105 for sid in list(combined.keys()):
106 combined[sid].metadata = self.song_metadata.get(sid, {})
107
108 results = sorted(
109 combined.values(),
110 key=lambda c: c.combined_score,
111 reverse=True,
112 )[:top_n]
113 113
114 combined.sort(key=lambda c: c.combined_score(self.chroma_weight, self.ecapa_weight), reverse=True)
115 results = combined[:top_n]
114 elapsed = (time.time() - start) * 1000 116 elapsed = (time.time() - start) * 1000
115 117
116 output = [] 118 output = []
117 for c in results: 119 for c in results:
118 output.append({ 120 fused = c.combined_score(self.chroma_weight, self.ecapa_weight)
119 "song_id": c.song_id, 121 output.append(
120 "confidence": round(c.combined_score, 4), 122 {
121 "chromaprint_score": round(c.chroma_score, 4), 123 "song_id": c.song_id,
122 "ecapa_score": round(c.ecapa_score, 4), 124 "confidence": round(fused, 4),
123 "metadata": c.metadata, 125 "chromaprint_score": round(c.chroma_score, 4),
124 }) 126 "ecapa_score": round(c.ecapa_score, 4),
127 "accepted": fused >= self.reject_threshold,
128 "metadata": c.metadata,
129 }
130 )
125 131
126 return { 132 return {
127 "candidates": output, 133 "candidates": output,
......
1 # Open Dataset Integration Plan
2
3 ## Recommended order
4
5 1. **FMA small**
6 - URL: https://github.com/mdeff/fma
7 - Why: easiest small realistic music subset for retrieval experiments
8 2. **MTG-Jamendo**
9 - URL: https://github.com/MTG/mtg-jamendo-dataset
10 - Why: larger CC-licensed corpus with scriptable upstream tooling
11 3. **QBSH / humming corpora**
12 - Why: add after retrieval baseline is stable
13
14 ## Repo strategy
15
16 - Keep external dataset ingestion optional
17 - Convert external tracks into:
18 - `catalog.json` for searchable references
19 - query segment manifests for evaluation
20 - Start with small local subsets before full-corpus scaling