period upload
Showing
11 changed files
with
467 additions
and
170 deletions
| ... | @@ -65,3 +65,16 @@ python run_demo.py full-demo --device cpu | ... | @@ -65,3 +65,16 @@ python run_demo.py full-demo --device cpu |
| 65 | ## 当前定位 | 65 | ## 当前定位 |
| 66 | 66 | ||
| 67 | 这是一个**原型仓库**,目标是验证 ACR 主链路能否跑通,不是生产级服务。 | 67 | 这是一个**原型仓库**,目标是验证 ACR 主链路能否跑通,不是生产级服务。 |
| 68 | |||
| 69 | ## 评测 | ||
| 70 | |||
| 71 | ```bash | ||
| 72 | python evaluate.py --data data/synthetic --model data/models/best_model.pt --index-prefix data/index/reference --split test --device cpu | ||
| 73 | ``` | ||
| 74 | |||
| 75 | ## 当前提升方向 | ||
| 76 | |||
| 77 | - 更强合成混淆样本(confused / humming_like) | ||
| 78 | - Hybrid 分数归一化后再融合 | ||
| 79 | - full-demo 自动训练 | ||
| 80 | - 后续可接入开源数据集 | ... | ... |
acr-engine/evaluate.py
0 → 100644
| 1 | #!/usr/bin/env python3 | ||
| 2 | import argparse | ||
| 3 | import json | ||
| 4 | from pathlib import Path | ||
| 5 | |||
| 6 | import numpy as np | ||
| 7 | |||
| 8 | from src.engines.chromaprint_matcher import ChromaprintMatcher | ||
| 9 | from src.engines.ecapa_embedder import ECAPAEmbedder | ||
| 10 | from src.engines.hybrid_engine import HybridEngine | ||
| 11 | |||
| 12 | |||
| 13 | def load_items(meta_path: Path): | ||
| 14 | with open(meta_path) as f: | ||
| 15 | return json.load(f) | ||
| 16 | |||
| 17 | |||
| 18 | def main(): | ||
| 19 | parser = argparse.ArgumentParser(description="Evaluate ACR recognition quality") | ||
| 20 | parser.add_argument("--data", default="data/synthetic") | ||
| 21 | parser.add_argument("--model", required=True) | ||
| 22 | parser.add_argument("--index-prefix", default="data/index/reference") | ||
| 23 | parser.add_argument("--split", default="test") | ||
| 24 | parser.add_argument("--top-k", type=int, default=5) | ||
| 25 | parser.add_argument("--device", default="cpu") | ||
| 26 | args = parser.parse_args() | ||
| 27 | |||
| 28 | data_dir = Path(args.data) | ||
| 29 | matcher = ChromaprintMatcher() | ||
| 30 | matcher.load(str(Path(args.index_prefix).parent / "chromaprint.pkl")) | ||
| 31 | embedder = ECAPAEmbedder(model_path=args.model, device=args.device) | ||
| 32 | ref_embs = np.load(f"{args.index_prefix}_embs.npy") | ||
| 33 | ref_ids = np.load(f"{args.index_prefix}_ids.npy", allow_pickle=True).tolist() | ||
| 34 | |||
| 35 | engine = HybridEngine(matcher, embedder, ref_embs, ref_ids) | ||
| 36 | for split in ["train.json", "val.json", "test.json"]: | ||
| 37 | p = data_dir / split | ||
| 38 | if p.exists(): | ||
| 39 | engine.load_metadata(str(p)) | ||
| 40 | |||
| 41 | items = load_items(data_dir / f"{args.split}.json") | ||
| 42 | queries = [x for x in items if str(x.get("audio_path", "")).startswith("segments/")] | ||
| 43 | if not queries: | ||
| 44 | raise SystemExit("No segment queries found for evaluation") | ||
| 45 | |||
| 46 | top1 = 0 | ||
| 47 | topk = 0 | ||
| 48 | by_type = {} | ||
| 49 | failures = [] | ||
| 50 | |||
| 51 | for item in queries: | ||
| 52 | result = engine.recognize(str(data_dir / item["audio_path"]), top_n=args.top_k) | ||
| 53 | preds = [c["song_id"] for c in result["candidates"]] | ||
| 54 | truth = item["song_id"] | ||
| 55 | qtype = item.get("type", "unknown") | ||
| 56 | stats = by_type.setdefault(qtype, {"n": 0, "top1": 0, "topk": 0}) | ||
| 57 | stats["n"] += 1 | ||
| 58 | |||
| 59 | if preds and preds[0] == truth: | ||
| 60 | top1 += 1 | ||
| 61 | stats["top1"] += 1 | ||
| 62 | if truth in preds: | ||
| 63 | topk += 1 | ||
| 64 | stats["topk"] += 1 | ||
| 65 | else: | ||
| 66 | failures.append({ | ||
| 67 | "truth": truth, | ||
| 68 | "query": item["audio_path"], | ||
| 69 | "type": qtype, | ||
| 70 | "preds": preds, | ||
| 71 | }) | ||
| 72 | |||
| 73 | total = len(queries) | ||
| 74 | report = { | ||
| 75 | "split": args.split, | ||
| 76 | "num_queries": total, | ||
| 77 | "top1": round(top1 / total, 4), | ||
| 78 | "topk": round(topk / total, 4), | ||
| 79 | "by_type": { | ||
| 80 | k: { | ||
| 81 | "n": v["n"], | ||
| 82 | "top1": round(v["top1"] / v["n"], 4) if v["n"] else 0.0, | ||
| 83 | "topk": round(v["topk"] / v["n"], 4) if v["n"] else 0.0, | ||
| 84 | } | ||
| 85 | for k, v in by_type.items() | ||
| 86 | }, | ||
| 87 | "sample_failures": failures[:10], | ||
| 88 | } | ||
| 89 | print(json.dumps(report, ensure_ascii=False, indent=2)) | ||
| 90 | |||
| 91 | |||
| 92 | if __name__ == "__main__": | ||
| 93 | main() |
| ... | @@ -31,7 +31,7 @@ def build_chroma_index(data_dir: Path, output_dir: Path): | ... | @@ -31,7 +31,7 @@ def build_chroma_index(data_dir: Path, output_dir: Path): |
| 31 | matcher = ChromaprintMatcher() | 31 | matcher = ChromaprintMatcher() |
| 32 | matcher.index_songs_from_dir( | 32 | matcher.index_songs_from_dir( |
| 33 | songs_dir=str(data_dir / 'songs'), | 33 | songs_dir=str(data_dir / 'songs'), |
| 34 | metadata_path=str(data_dir / 'train.json'), | 34 | metadata_path=str(data_dir / 'catalog.json' if (data_dir / 'catalog.json').exists() else data_dir / 'train.json'), |
| 35 | cache_path=str(output_dir / 'chromaprint.pkl'), | 35 | cache_path=str(output_dir / 'chromaprint.pkl'), |
| 36 | ) | 36 | ) |
| 37 | print(f"[done] chromaprint index built: hashes={matcher.num_hashes}, postings={matcher.index_size}") | 37 | print(f"[done] chromaprint index built: hashes={matcher.num_hashes}, postings={matcher.index_size}") |
| ... | @@ -42,7 +42,7 @@ def build_embedding_index(data_dir: Path, model_path: Path, output_prefix: Path, | ... | @@ -42,7 +42,7 @@ def build_embedding_index(data_dir: Path, model_path: Path, output_prefix: Path, |
| 42 | embedder = ECAPAEmbedder(model_path=str(model_path), device=device) | 42 | embedder = ECAPAEmbedder(model_path=str(model_path), device=device) |
| 43 | ref_embs, ref_ids = embedder.build_reference_index( | 43 | ref_embs, ref_ids = embedder.build_reference_index( |
| 44 | songs_dir=str(data_dir / 'songs'), | 44 | songs_dir=str(data_dir / 'songs'), |
| 45 | metadata_path=str(data_dir / 'train.json'), | 45 | metadata_path=str(data_dir / 'catalog.json' if (data_dir / 'catalog.json').exists() else data_dir / 'train.json'), |
| 46 | output_path=str(output_prefix), | 46 | output_path=str(output_prefix), |
| 47 | ) | 47 | ) |
| 48 | print(f"[done] embedding index built: {len(ref_ids)} refs") | 48 | print(f"[done] embedding index built: {len(ref_ids)} refs") |
| ... | @@ -104,16 +104,20 @@ def cmd_full_demo(args): | ... | @@ -104,16 +104,20 @@ def cmd_full_demo(args): |
| 104 | 104 | ||
| 105 | model_path = model_dir / 'best_model.pt' | 105 | model_path = model_dir / 'best_model.pt' |
| 106 | if not model_path.exists(): | 106 | if not model_path.exists(): |
| 107 | raise SystemExit( | 107 | import subprocess |
| 108 | 'full-demo requires a trained model at data/models/best_model.pt. '\ | 108 | model_dir.mkdir(parents=True, exist_ok=True) |
| 109 | 'Run train.py first or provide one.' | 109 | cmd = [ |
| 110 | ) | 110 | '/usr/local/miniconda3/bin/python', 'train.py', |
| 111 | '--data', str(data_dir), '--output', str(model_dir), | ||
| 112 | '--device', args.device, '--epochs', '3', '--batch-size', '8' | ||
| 113 | ] | ||
| 114 | print('[full-demo] training model:', ' '.join(cmd)) | ||
| 115 | subprocess.run(cmd, check=True) | ||
| 111 | 116 | ||
| 112 | index_dir.mkdir(parents=True, exist_ok=True) | 117 | index_dir.mkdir(parents=True, exist_ok=True) |
| 113 | matcher = build_chroma_index(data_dir, index_dir) | 118 | matcher = build_chroma_index(data_dir, index_dir) |
| 114 | embedder, ref_embs, ref_ids = build_embedding_index(data_dir, model_path, index_dir / 'reference', args.device) | 119 | embedder, ref_embs, ref_ids = build_embedding_index(data_dir, model_path, index_dir / 'reference', args.device) |
| 115 | 120 | ||
| 116 | query = sorted((data_dir / 'test.json').read_text() and [] ) | ||
| 117 | with open(data_dir / 'test.json') as f: | 121 | with open(data_dir / 'test.json') as f: |
| 118 | test_meta = json.load(f) | 122 | test_meta = json.load(f) |
| 119 | query_item = next((x for x in test_meta if 'segments/' in x['audio_path']), test_meta[0]) | 123 | query_item = next((x for x in test_meta if 'segments/' in x['audio_path']), test_meta[0]) | ... | ... |
acr-engine/scripts/download_open_dataset.py
0 → 100644
| 1 | #!/usr/bin/env python3 | ||
| 2 | """Helpers for optional open music dataset integration.""" | ||
| 3 | |||
| 4 | import argparse | ||
| 5 | import json | ||
| 6 | from pathlib import Path | ||
| 7 | |||
| 8 | DATASETS = { | ||
| 9 | "fma_small": { | ||
| 10 | "url": "https://github.com/mdeff/fma", | ||
| 11 | "notes": "Use FMA small subset first; convert clips into catalog/query JSON for local experiments.", | ||
| 12 | }, | ||
| 13 | "mtg_jamendo": { | ||
| 14 | "url": "https://github.com/MTG/mtg-jamendo-dataset", | ||
| 15 | "notes": "Use upstream download scripts; sample a small subset into catalog/query structure.", | ||
| 16 | }, | ||
| 17 | } | ||
| 18 | |||
| 19 | |||
| 20 | def main(): | ||
| 21 | parser = argparse.ArgumentParser() | ||
| 22 | parser.add_argument("dataset", choices=sorted(DATASETS)) | ||
| 23 | parser.add_argument("--output", default="../docs/open-datasets.json") | ||
| 24 | args = parser.parse_args() | ||
| 25 | out = Path(args.output) | ||
| 26 | out.parent.mkdir(parents=True, exist_ok=True) | ||
| 27 | with open(out, "w") as f: | ||
| 28 | json.dump({args.dataset: DATASETS[args.dataset]}, f, indent=2) | ||
| 29 | print(f"Wrote dataset integration note to {out}") | ||
| 30 | |||
| 31 | |||
| 32 | if __name__ == "__main__": | ||
| 33 | main() |
| 1 | import torch | 1 | import json |
| 2 | from torch.utils.data import Dataset | ||
| 3 | import numpy as np | ||
| 4 | import librosa | ||
| 5 | import random | 2 | import random |
| 6 | from pathlib import Path | 3 | from pathlib import Path |
| 7 | from typing import Dict, List, Tuple | 4 | from typing import Dict, List, Optional |
| 8 | import json | 5 | |
| 9 | import os | 6 | import librosa |
| 7 | import numpy as np | ||
| 8 | import torch | ||
| 9 | from torch.utils.data import Dataset | ||
| 10 | 10 | ||
| 11 | 11 | ||
| 12 | class ACRDataset(Dataset): | 12 | class ACRDataset(Dataset): |
| ... | @@ -21,6 +21,8 @@ class ACRDataset(Dataset): | ... | @@ -21,6 +21,8 @@ class ACRDataset(Dataset): |
| 21 | segment_dur: float = 5.0, | 21 | segment_dur: float = 5.0, |
| 22 | augment: bool = True, | 22 | augment: bool = True, |
| 23 | n_crops_per_song: int = 4, | 23 | n_crops_per_song: int = 4, |
| 24 | song_to_idx: Optional[Dict[str, int]] = None, | ||
| 25 | references_only: bool = False, | ||
| 24 | ): | 26 | ): |
| 25 | self.sr = sr | 27 | self.sr = sr |
| 26 | self.n_mels = n_mels | 28 | self.n_mels = n_mels |
| ... | @@ -31,36 +33,39 @@ class ACRDataset(Dataset): | ... | @@ -31,36 +33,39 @@ class ACRDataset(Dataset): |
| 31 | self.n_crops = n_crops_per_song | 33 | self.n_crops = n_crops_per_song |
| 32 | self.data_dir = Path(data_dir) | 34 | self.data_dir = Path(data_dir) |
| 33 | 35 | ||
| 34 | meta_path = Path(data_dir) / f"{split}.json" | 36 | meta_path = self.data_dir / f"{split}.json" |
| 35 | with open(meta_path) as f: | 37 | with open(meta_path) as f: |
| 36 | self.metadata = json.load(f) | 38 | self.metadata = json.load(f) |
| 37 | 39 | ||
| 38 | self.samples = [] | 40 | self.samples = [] |
| 39 | for item in self.metadata: | 41 | for item in self.metadata: |
| 40 | song_path = Path(data_dir) / item["audio_path"] | 42 | if references_only and item.get("type") != "reference": |
| 43 | continue | ||
| 44 | song_path = self.data_dir / item["audio_path"] | ||
| 41 | if song_path.exists(): | 45 | if song_path.exists(): |
| 42 | self.samples.append(item) | 46 | self.samples.append(item) |
| 47 | |||
| 43 | self.song_ids = sorted(set(s["song_id"] for s in self.samples)) | 48 | self.song_ids = sorted(set(s["song_id"] for s in self.samples)) |
| 44 | self.song_to_idx = {sid: i for i, sid in enumerate(self.song_ids)} | 49 | self.song_to_idx = song_to_idx or {sid: i for i, sid in enumerate(self.song_ids)} |
| 45 | 50 | ||
| 46 | def __len__(self): | 51 | def __len__(self): |
| 47 | return len(self.samples) * self.n_crops | 52 | return len(self.samples) * self.n_crops |
| 48 | 53 | ||
| 49 | def _load_segment(self, path: str, offset: float, duration: float) -> np.ndarray: | 54 | def _load_segment(self, path: str, offset: float, duration: float) -> np.ndarray: |
| 50 | y, _ = librosa.load( | 55 | y, _ = librosa.load(path, sr=self.sr, mono=True, offset=offset, duration=duration) |
| 51 | path, sr=self.sr, mono=True, | ||
| 52 | offset=offset, duration=duration | ||
| 53 | ) | ||
| 54 | if len(y) < self.segment_len: | 56 | if len(y) < self.segment_len: |
| 55 | y = np.pad(y, (0, self.segment_len - len(y))) | 57 | y = np.pad(y, (0, self.segment_len - len(y))) |
| 56 | else: | 58 | else: |
| 57 | y = y[:self.segment_len] | 59 | y = y[: self.segment_len] |
| 58 | return y | 60 | return y |
| 59 | 61 | ||
| 60 | def _to_mel(self, y: np.ndarray) -> np.ndarray: | 62 | def _to_mel(self, y: np.ndarray) -> np.ndarray: |
| 61 | mel = librosa.feature.melspectrogram( | 63 | mel = librosa.feature.melspectrogram( |
| 62 | y=y, sr=self.sr, n_mels=self.n_mels, | 64 | y=y, |
| 63 | n_fft=self.n_fft, hop_length=self.hop_length | 65 | sr=self.sr, |
| 66 | n_mels=self.n_mels, | ||
| 67 | n_fft=self.n_fft, | ||
| 68 | hop_length=self.hop_length, | ||
| 64 | ) | 69 | ) |
| 65 | return librosa.power_to_db(mel, ref=np.max) | 70 | return librosa.power_to_db(mel, ref=np.max) |
| 66 | 71 | ||
| ... | @@ -73,7 +78,7 @@ class ACRDataset(Dataset): | ... | @@ -73,7 +78,7 @@ class ACRDataset(Dataset): |
| 73 | audio_path = self.data_dir / sample["audio_path"] | 78 | audio_path = self.data_dir / sample["audio_path"] |
| 74 | y = self._load_segment(str(audio_path), offset, 5.0) | 79 | y = self._load_segment(str(audio_path), offset, 5.0) |
| 75 | 80 | ||
| 76 | if self.augment: | 81 | if self.augment and sample.get("type") != "reference": |
| 77 | from src.utils.augment import AugmentPipeline | 82 | from src.utils.augment import AugmentPipeline |
| 78 | aug = AugmentPipeline(self.sr) | 83 | aug = AugmentPipeline(self.sr) |
| 79 | y = aug(y) | 84 | y = aug(y) |
| ... | @@ -88,6 +93,7 @@ class ACRDataset(Dataset): | ... | @@ -88,6 +93,7 @@ class ACRDataset(Dataset): |
| 88 | "mel": mel_tensor, | 93 | "mel": mel_tensor, |
| 89 | "song_id": torch.tensor(class_id, dtype=torch.long), | 94 | "song_id": torch.tensor(class_id, dtype=torch.long), |
| 90 | "song_name": song_id, | 95 | "song_name": song_id, |
| 96 | "type": sample.get("type", "unknown"), | ||
| 91 | } | 97 | } |
| 92 | 98 | ||
| 93 | 99 | ||
| ... | @@ -100,6 +106,7 @@ class ACRTestDataset(Dataset): | ... | @@ -100,6 +106,7 @@ class ACRTestDataset(Dataset): |
| 100 | n_mels: int = 80, | 106 | n_mels: int = 80, |
| 101 | n_fft: int = 512, | 107 | n_fft: int = 512, |
| 102 | hop_length: int = 160, | 108 | hop_length: int = 160, |
| 109 | song_to_idx: Optional[Dict[str, int]] = None, | ||
| 103 | ): | 110 | ): |
| 104 | self.sr = sr | 111 | self.sr = sr |
| 105 | self.n_mels = n_mels | 112 | self.n_mels = n_mels |
| ... | @@ -107,18 +114,18 @@ class ACRTestDataset(Dataset): | ... | @@ -107,18 +114,18 @@ class ACRTestDataset(Dataset): |
| 107 | self.hop_length = hop_length | 114 | self.hop_length = hop_length |
| 108 | self.data_dir = Path(data_dir) | 115 | self.data_dir = Path(data_dir) |
| 109 | 116 | ||
| 110 | meta_path = Path(data_dir) / f"{split}.json" | 117 | meta_path = self.data_dir / f"{split}.json" |
| 111 | with open(meta_path) as f: | 118 | with open(meta_path) as f: |
| 112 | self.metadata = json.load(f) | 119 | self.metadata = json.load(f) |
| 113 | 120 | ||
| 114 | self.samples = [] | 121 | self.samples = [] |
| 115 | for item in self.metadata: | 122 | for item in self.metadata: |
| 116 | p = Path(data_dir) / item["audio_path"] | 123 | p = self.data_dir / item["audio_path"] |
| 117 | if p.exists(): | 124 | if p.exists(): |
| 118 | self.samples.append(item) | 125 | self.samples.append(item) |
| 119 | 126 | ||
| 120 | self.song_ids = sorted(set(s["song_id"] for s in self.samples)) | 127 | self.song_ids = sorted(set(s["song_id"] for s in self.samples)) |
| 121 | self.song_to_idx = {sid: i for i, sid in enumerate(self.song_ids)} | 128 | self.song_to_idx = song_to_idx or {sid: i for i, sid in enumerate(self.song_ids)} |
| 122 | 129 | ||
| 123 | def __len__(self): | 130 | def __len__(self): |
| 124 | return len(self.samples) | 131 | return len(self.samples) |
| ... | @@ -126,10 +133,7 @@ class ACRTestDataset(Dataset): | ... | @@ -126,10 +133,7 @@ class ACRTestDataset(Dataset): |
| 126 | def __getitem__(self, idx): | 133 | def __getitem__(self, idx): |
| 127 | sample = self.samples[idx] | 134 | sample = self.samples[idx] |
| 128 | audio_path = self.data_dir / sample["audio_path"] | 135 | audio_path = self.data_dir / sample["audio_path"] |
| 129 | y, _ = librosa.load( | 136 | y, _ = librosa.load(str(audio_path), sr=self.sr, mono=True, offset=0, duration=min(sample["duration"], 5.0)) |
| 130 | str(audio_path), sr=self.sr, mono=True, | ||
| 131 | offset=0, duration=min(sample["duration"], 5.0) | ||
| 132 | ) | ||
| 133 | seg_len = 5 * self.sr | 137 | seg_len = 5 * self.sr |
| 134 | if len(y) < seg_len: | 138 | if len(y) < seg_len: |
| 135 | y = np.pad(y, (0, seg_len - len(y))) | 139 | y = np.pad(y, (0, seg_len - len(y))) |
| ... | @@ -137,13 +141,100 @@ class ACRTestDataset(Dataset): | ... | @@ -137,13 +141,100 @@ class ACRTestDataset(Dataset): |
| 137 | y = y[:seg_len] | 141 | y = y[:seg_len] |
| 138 | 142 | ||
| 139 | mel = librosa.power_to_db( | 143 | mel = librosa.power_to_db( |
| 140 | librosa.feature.melspectrogram(y=y, sr=self.sr, n_mels=self.n_mels, | 144 | librosa.feature.melspectrogram( |
| 141 | n_fft=self.n_fft, hop_length=self.hop_length), | 145 | y=y, |
| 142 | ref=np.max | 146 | sr=self.sr, |
| 147 | n_mels=self.n_mels, | ||
| 148 | n_fft=self.n_fft, | ||
| 149 | hop_length=self.hop_length, | ||
| 150 | ), | ||
| 151 | ref=np.max, | ||
| 143 | ) | 152 | ) |
| 144 | class_id = self.song_to_idx[sample["song_id"]] | 153 | class_id = self.song_to_idx[sample["song_id"]] |
| 145 | return { | 154 | return { |
| 146 | "mel": torch.FloatTensor(mel), | 155 | "mel": torch.FloatTensor(mel), |
| 147 | "song_id": torch.tensor(class_id, dtype=torch.long), | 156 | "song_id": torch.tensor(class_id, dtype=torch.long), |
| 148 | "song_name": sample["song_id"], | 157 | "song_name": sample["song_id"], |
| 158 | "type": sample.get("type", "unknown"), | ||
| 159 | } | ||
| 160 | |||
| 161 | |||
| 162 | class SongPairDataset(Dataset): | ||
| 163 | def __init__( | ||
| 164 | self, | ||
| 165 | data_dir: str, | ||
| 166 | split: str = "train", | ||
| 167 | sr: int = 16000, | ||
| 168 | n_mels: int = 80, | ||
| 169 | n_fft: int = 512, | ||
| 170 | hop_length: int = 160, | ||
| 171 | segment_dur: float = 5.0, | ||
| 172 | augment: bool = True, | ||
| 173 | ): | ||
| 174 | self.sr = sr | ||
| 175 | self.n_mels = n_mels | ||
| 176 | self.n_fft = n_fft | ||
| 177 | self.hop_length = hop_length | ||
| 178 | self.segment_len = int(segment_dur * sr) | ||
| 179 | self.augment = augment | ||
| 180 | self.data_dir = Path(data_dir) | ||
| 181 | |||
| 182 | with open(self.data_dir / f"{split}.json") as f: | ||
| 183 | metadata = json.load(f) | ||
| 184 | |||
| 185 | self.by_song: Dict[str, List[Dict]] = {} | ||
| 186 | for item in metadata: | ||
| 187 | if item.get("type") == "reference": | ||
| 188 | continue | ||
| 189 | p = self.data_dir / item["audio_path"] | ||
| 190 | if p.exists(): | ||
| 191 | self.by_song.setdefault(item["song_id"], []).append(item) | ||
| 192 | |||
| 193 | self.song_ids = sorted(self.by_song) | ||
| 194 | self.song_to_idx = {sid: i for i, sid in enumerate(self.song_ids)} | ||
| 195 | |||
| 196 | def __len__(self): | ||
| 197 | return len(self.song_ids) | ||
| 198 | |||
| 199 | def _load_clip(self, sample: Dict) -> np.ndarray: | ||
| 200 | path = self.data_dir / sample["audio_path"] | ||
| 201 | y, _ = librosa.load(str(path), sr=self.sr, mono=True, duration=5.0) | ||
| 202 | if len(y) < self.segment_len: | ||
| 203 | y = np.pad(y, (0, self.segment_len - len(y))) | ||
| 204 | else: | ||
| 205 | y = y[: self.segment_len] | ||
| 206 | return y | ||
| 207 | |||
| 208 | def _to_mel(self, y: np.ndarray) -> torch.Tensor: | ||
| 209 | mel = librosa.feature.melspectrogram( | ||
| 210 | y=y, | ||
| 211 | sr=self.sr, | ||
| 212 | n_mels=self.n_mels, | ||
| 213 | n_fft=self.n_fft, | ||
| 214 | hop_length=self.hop_length, | ||
| 215 | ) | ||
| 216 | mel = librosa.power_to_db(mel, ref=np.max) | ||
| 217 | return torch.FloatTensor(mel) | ||
| 218 | |||
| 219 | def __getitem__(self, idx): | ||
| 220 | song_id = self.song_ids[idx] | ||
| 221 | choices = self.by_song[song_id] | ||
| 222 | if len(choices) == 1: | ||
| 223 | a = b = choices[0] | ||
| 224 | else: | ||
| 225 | a, b = random.sample(choices, 2) | ||
| 226 | |||
| 227 | wavs = [] | ||
| 228 | for sample in (a, b): | ||
| 229 | y = self._load_clip(sample) | ||
| 230 | if self.augment: | ||
| 231 | from src.utils.augment import AugmentPipeline | ||
| 232 | y = AugmentPipeline(self.sr)(y) | ||
| 233 | wavs.append(self._to_mel(y)) | ||
| 234 | |||
| 235 | label = self.song_to_idx[song_id] | ||
| 236 | return { | ||
| 237 | "mel": torch.stack(wavs, dim=0), | ||
| 238 | "song_id": torch.tensor([label, label], dtype=torch.long), | ||
| 239 | "song_name": song_id, | ||
| 149 | } | 240 | } | ... | ... |
| ... | @@ -5,6 +5,7 @@ Generates melodies from fundamental frequencies, simulates: | ... | @@ -5,6 +5,7 @@ Generates melodies from fundamental frequencies, simulates: |
| 5 | - Different "songs" (unique note sequences at different base frequencies) | 5 | - Different "songs" (unique note sequences at different base frequencies) |
| 6 | - Song fragments (random crops from songs) | 6 | - Song fragments (random crops from songs) |
| 7 | - Humming variants (pitch shifted, time stretched versions) | 7 | - Humming variants (pitch shifted, time stretched versions) |
| 8 | - Hard negatives / confusing variants for robustness testing | ||
| 8 | 9 | ||
| 9 | This allows the full pipeline to be validated without external data. | 10 | This allows the full pipeline to be validated without external data. |
| 10 | """ | 11 | """ |
| ... | @@ -13,9 +14,8 @@ import numpy as np | ... | @@ -13,9 +14,8 @@ import numpy as np |
| 13 | import soundfile as sf | 14 | import soundfile as sf |
| 14 | import json | 15 | import json |
| 15 | import random | 16 | import random |
| 16 | import os | ||
| 17 | from pathlib import Path | 17 | from pathlib import Path |
| 18 | from typing import List, Tuple | 18 | from typing import Tuple |
| 19 | from tqdm import tqdm | 19 | from tqdm import tqdm |
| 20 | 20 | ||
| 21 | 21 | ||
| ... | @@ -33,7 +33,10 @@ def harmonic_tone(freq: float, duration: float, sr: int = _SR, n_harmonics: int | ... | @@ -33,7 +33,10 @@ def harmonic_tone(freq: float, duration: float, sr: int = _SR, n_harmonics: int |
| 33 | for h in range(1, n_harmonics + 1): | 33 | for h in range(1, n_harmonics + 1): |
| 34 | amp = 0.5 / h | 34 | amp = 0.5 / h |
| 35 | y += amp * np.sin(2 * np.pi * freq * h * t) | 35 | y += amp * np.sin(2 * np.pi * freq * h * t) |
| 36 | return y / np.max(np.abs(y)) * 0.5 | 36 | peak = np.max(np.abs(y)) |
| 37 | if peak > 0: | ||
| 38 | y = y / peak * 0.5 | ||
| 39 | return y | ||
| 37 | 40 | ||
| 38 | 41 | ||
| 39 | def generate_melody( | 42 | def generate_melody( |
| ... | @@ -44,9 +47,8 @@ def generate_melody( | ... | @@ -44,9 +47,8 @@ def generate_melody( |
| 44 | timbre: str = "harmonic", | 47 | timbre: str = "harmonic", |
| 45 | ) -> np.ndarray: | 48 | ) -> np.ndarray: |
| 46 | notes = [] | 49 | notes = [] |
| 47 | freq = base_freq | 50 | for _ in range(note_count): |
| 48 | for i in range(note_count): | 51 | interval = random.choice([0, 2, 4, 5, 7, 9, 11, 12]) |
| 49 | interval = random.choice([0, 2, 4, 5, 7, 9, 11, 12]) # diatonic intervals | ||
| 50 | freq = base_freq * (2 ** (interval / 12)) | 52 | freq = base_freq * (2 ** (interval / 12)) |
| 51 | dur = note_dur * random.uniform(0.8, 1.2) | 53 | dur = note_dur * random.uniform(0.8, 1.2) |
| 52 | 54 | ||
| ... | @@ -57,7 +59,7 @@ def generate_melody( | ... | @@ -57,7 +59,7 @@ def generate_melody( |
| 57 | 59 | ||
| 58 | if random.random() < 0.15: | 60 | if random.random() < 0.15: |
| 59 | fade = np.linspace(0, 1, min(int(sr * 0.02), len(note))) | 61 | fade = np.linspace(0, 1, min(int(sr * 0.02), len(note))) |
| 60 | note[:len(fade)] *= fade | 62 | note[: len(fade)] *= fade |
| 61 | 63 | ||
| 62 | notes.append(note) | 64 | notes.append(note) |
| 63 | 65 | ||
| ... | @@ -65,15 +67,35 @@ def generate_melody( | ... | @@ -65,15 +67,35 @@ def generate_melody( |
| 65 | 67 | ||
| 66 | 68 | ||
| 67 | _CHORD_PROGRESSIONS = [ | 69 | _CHORD_PROGRESSIONS = [ |
| 68 | [0, 3, 7], # Cm | 70 | [0, 3, 7], |
| 69 | [0, 4, 7], # C | 71 | [0, 4, 7], |
| 70 | [0, 3, 7, 10], # Cm7 | 72 | [0, 3, 7, 10], |
| 71 | [0, 4, 7, 11], # Cmaj7 | 73 | [0, 4, 7, 11], |
| 72 | [0, 4, 9], # Csus4 → C | 74 | [0, 4, 9], |
| 73 | [0, 5, 7], # Csus2 | 75 | [0, 5, 7], |
| 74 | ] | 76 | ] |
| 75 | 77 | ||
| 76 | 78 | ||
| 79 | def apply_confusion_mix(y: np.ndarray, sr: int, strength: float = 0.22) -> np.ndarray: | ||
| 80 | t = np.linspace(0, len(y) / sr, len(y), endpoint=False) | ||
| 81 | distractor = 0.0 | ||
| 82 | for f in [220.0, 330.0, 440.0]: | ||
| 83 | distractor += np.sin(2 * np.pi * f * t + random.uniform(0, np.pi)) | ||
| 84 | distractor /= max(np.max(np.abs(distractor)), 1e-8) | ||
| 85 | mixed = y + strength * distractor | ||
| 86 | peak = np.max(np.abs(mixed)) | ||
| 87 | return mixed / peak * 0.5 if peak > 0 else mixed | ||
| 88 | |||
| 89 | |||
| 90 | def apply_humming_style(y: np.ndarray, sr: int) -> np.ndarray: | ||
| 91 | env = np.linspace(0.7, 1.0, len(y)) | ||
| 92 | hum = y * env | ||
| 93 | kernel = np.ones(max(5, sr // 400)) / max(5, sr // 400) | ||
| 94 | hum = np.convolve(hum, kernel, mode="same") | ||
| 95 | peak = np.max(np.abs(hum)) | ||
| 96 | return hum / peak * 0.5 if peak > 0 else hum | ||
| 97 | |||
| 98 | |||
| 77 | def generate_song( | 99 | def generate_song( |
| 78 | song_id: str, | 100 | song_id: str, |
| 79 | base_freq: float, | 101 | base_freq: float, |
| ... | @@ -98,7 +120,7 @@ def generate_song( | ... | @@ -98,7 +120,7 @@ def generate_song( |
| 98 | env = np.exp(-np.linspace(0, 3, seg_len)) | 120 | env = np.exp(-np.linspace(0, 3, seg_len)) |
| 99 | note = harmonic_tone(freq, seg_len / sr, sr) * env * 0.3 | 121 | note = harmonic_tone(freq, seg_len / sr, sr) * env * 0.3 |
| 100 | min_len = min(seg_len, len(note)) | 122 | min_len = min(seg_len, len(note)) |
| 101 | y[start_sample:start_sample + min_len] += note[:min_len] | 123 | y[start_sample : start_sample + min_len] += note[:min_len] |
| 102 | 124 | ||
| 103 | if with_vocals: | 125 | if with_vocals: |
| 104 | melody = generate_melody(base_freq * 2, note_count=int(duration * 2), note_dur=0.5, sr=sr) | 126 | melody = generate_melody(base_freq * 2, note_count=int(duration * 2), note_dur=0.5, sr=sr) |
| ... | @@ -130,9 +152,11 @@ def generate_dataset( | ... | @@ -130,9 +152,11 @@ def generate_dataset( |
| 130 | songs_dir.mkdir(parents=True, exist_ok=True) | 152 | songs_dir.mkdir(parents=True, exist_ok=True) |
| 131 | segs_dir.mkdir(parents=True, exist_ok=True) | 153 | segs_dir.mkdir(parents=True, exist_ok=True) |
| 132 | 154 | ||
| 133 | base_freqs = [130.81, 146.83, 164.81, 174.61, 196.0, 220.0, 246.94, | 155 | base_freqs = [ |
| 156 | 130.81, 146.83, 164.81, 174.61, 196.0, 220.0, 246.94, | ||
| 134 | 261.63, 293.66, 329.63, 349.23, 392.0, 440.0, 493.88, | 157 | 261.63, 293.66, 329.63, 349.23, 392.0, 440.0, 493.88, |
| 135 | 523.25, 587.33, 659.25, 698.46, 783.99, 880.0, 987.77] | 158 | 523.25, 587.33, 659.25, 698.46, 783.99, 880.0, 987.77, |
| 159 | ] | ||
| 136 | 160 | ||
| 137 | train_meta = [] | 161 | train_meta = [] |
| 138 | val_meta = [] | 162 | val_meta = [] |
| ... | @@ -143,7 +167,7 @@ def generate_dataset( | ... | @@ -143,7 +167,7 @@ def generate_dataset( |
| 143 | song_id = f"song_{i:04d}" | 167 | song_id = f"song_{i:04d}" |
| 144 | base_freq = base_freqs[i % len(base_freqs)] | 168 | base_freq = base_freqs[i % len(base_freqs)] |
| 145 | key_offset = (i // len(base_freqs)) * 2 | 169 | key_offset = (i // len(base_freqs)) * 2 |
| 146 | base_freq *= (2 ** (key_offset / 12)) | 170 | base_freq *= 2 ** (key_offset / 12) |
| 147 | 171 | ||
| 148 | y, dur = generate_song(song_id, base_freq, duration=song_duration, sr=sr) | 172 | y, dur = generate_song(song_id, base_freq, duration=song_duration, sr=sr) |
| 149 | song_path = songs_dir / f"{song_id}.wav" | 173 | song_path = songs_dir / f"{song_id}.wav" |
| ... | @@ -155,42 +179,41 @@ def generate_dataset( | ... | @@ -155,42 +179,41 @@ def generate_dataset( |
| 155 | start_s = int(offset * sr) | 179 | start_s = int(offset * sr) |
| 156 | end_s = start_s + int(segment_duration * sr) | 180 | end_s = start_s + int(segment_duration * sr) |
| 157 | seg = y[start_s:end_s] | 181 | seg = y[start_s:end_s] |
| 182 | target_len = int(segment_duration * sr) | ||
| 158 | 183 | ||
| 159 | if len(seg) < int(segment_duration * sr): | 184 | if len(seg) < target_len: |
| 160 | seg = np.pad(seg, (0, int(segment_duration * sr) - len(seg))) | 185 | seg = np.pad(seg, (0, target_len - len(seg))) |
| 161 | |||
| 162 | is_augmented = (j >= num_segments_per_song // 2) | ||
| 163 | 186 | ||
| 164 | if is_augmented: | 187 | variant_type = "clean" |
| 188 | out_seg = seg.copy() | ||
| 189 | if j >= num_segments_per_song // 2: | ||
| 165 | from src.utils.augment import AugmentPipeline | 190 | from src.utils.augment import AugmentPipeline |
| 166 | aug = AugmentPipeline(sr) | 191 | aug = AugmentPipeline(sr) |
| 167 | seg_aug = aug(seg.copy()) | 192 | out_seg = aug(out_seg) |
| 168 | seg_name = f"{song_id}_seg_{j:02d}_aug.wav" | 193 | variant_type = "augmented" |
| 169 | seg_path = segs_dir / seg_name | 194 | |
| 170 | sf.write(str(seg_path), seg_aug, sr) | 195 | if j == num_segments_per_song - 1: |
| 171 | meta_entry = { | 196 | out_seg = apply_confusion_mix(out_seg, sr) |
| 172 | "song_id": song_id, | 197 | variant_type = "confused" |
| 173 | "audio_path": f"segments/{seg_name}", | 198 | elif j == num_segments_per_song - 2 and num_segments_per_song >= 4: |
| 174 | "duration": segment_duration, | 199 | out_seg = apply_humming_style(out_seg, sr) |
| 175 | "type": "augmented", | 200 | variant_type = "humming_like" |
| 176 | "offset": offset, | 201 | |
| 177 | } | 202 | seg_name = f"{song_id}_seg_{j:02d}_{variant_type}.wav" if variant_type != "clean" else f"{song_id}_seg_{j:02d}.wav" |
| 178 | else: | ||
| 179 | seg_name = f"{song_id}_seg_{j:02d}.wav" | ||
| 180 | seg_path = segs_dir / seg_name | 203 | seg_path = segs_dir / seg_name |
| 181 | sf.write(str(seg_path), seg, sr) | 204 | sf.write(str(seg_path), out_seg, sr) |
| 205 | |||
| 182 | meta_entry = { | 206 | meta_entry = { |
| 183 | "song_id": song_id, | 207 | "song_id": song_id, |
| 184 | "audio_path": f"segments/{seg_name}", | 208 | "audio_path": f"segments/{seg_name}", |
| 185 | "duration": segment_duration, | 209 | "duration": segment_duration, |
| 186 | "type": "clean", | 210 | "type": variant_type, |
| 187 | "offset": offset, | 211 | "offset": offset, |
| 188 | } | 212 | } |
| 189 | 213 | ||
| 190 | offset_sec = offset | 214 | if offset < dur * 0.2: |
| 191 | if offset_sec < dur * 0.2: | ||
| 192 | seg_type = "intro" | 215 | seg_type = "intro" |
| 193 | elif offset_sec > dur * 0.7: | 216 | elif offset > dur * 0.7: |
| 194 | seg_type = "outro" | 217 | seg_type = "outro" |
| 195 | else: | 218 | else: |
| 196 | seg_type = "mid" | 219 | seg_type = "mid" |
| ... | @@ -208,6 +231,7 @@ def generate_dataset( | ... | @@ -208,6 +231,7 @@ def generate_dataset( |
| 208 | "audio_path": f"songs/{song_id}.wav", | 231 | "audio_path": f"songs/{song_id}.wav", |
| 209 | "duration": dur, | 232 | "duration": dur, |
| 210 | "base_freq": base_freq, | 233 | "base_freq": base_freq, |
| 234 | "type": "reference", | ||
| 211 | } | 235 | } |
| 212 | if i < int(num_songs * 0.7): | 236 | if i < int(num_songs * 0.7): |
| 213 | train_meta.append(song_meta) | 237 | train_meta.append(song_meta) |
| ... | @@ -216,6 +240,10 @@ def generate_dataset( | ... | @@ -216,6 +240,10 @@ def generate_dataset( |
| 216 | else: | 240 | else: |
| 217 | test_meta.append(song_meta) | 241 | test_meta.append(song_meta) |
| 218 | 242 | ||
| 243 | catalog_meta = [item for item in train_meta + val_meta + test_meta if item.get("type") == "reference"] | ||
| 244 | with open(output_dir / "catalog.json", "w") as f: | ||
| 245 | json.dump(catalog_meta, f, indent=2) | ||
| 246 | |||
| 219 | for name, data in [("train", train_meta), ("val", val_meta), ("test", test_meta)]: | 247 | for name, data in [("train", train_meta), ("val", val_meta), ("test", test_meta)]: |
| 220 | with open(output_dir / f"{name}.json", "w") as f: | 248 | with open(output_dir / f"{name}.json", "w") as f: |
| 221 | json.dump(data, f, indent=2) | 249 | json.dump(data, f, indent=2) |
| ... | @@ -229,6 +257,7 @@ def generate_dataset( | ... | @@ -229,6 +257,7 @@ def generate_dataset( |
| 229 | 257 | ||
| 230 | if __name__ == "__main__": | 258 | if __name__ == "__main__": |
| 231 | import argparse | 259 | import argparse |
| 260 | |||
| 232 | parser = argparse.ArgumentParser() | 261 | parser = argparse.ArgumentParser() |
| 233 | parser.add_argument("--output", type=str, default="data/synthetic") | 262 | parser.add_argument("--output", type=str, default="data/synthetic") |
| 234 | parser.add_argument("--num-songs", type=int, default=50) | 263 | parser.add_argument("--num-songs", type=int, default=50) | ... | ... |
| 1 | import torch | 1 | import json |
| 2 | import torch.nn.functional as F | ||
| 3 | import numpy as np | ||
| 4 | import librosa | ||
| 5 | from pathlib import Path | 2 | from pathlib import Path |
| 6 | from typing import List, Optional, Tuple | 3 | from typing import List, Optional, Tuple |
| 7 | import json | 4 | |
| 5 | import librosa | ||
| 6 | import numpy as np | ||
| 7 | import torch | ||
| 8 | 8 | ||
| 9 | 9 | ||
| 10 | class ECAPAEmbedder: | 10 | class ECAPAEmbedder: |
| ... | @@ -24,11 +24,22 @@ class ECAPAEmbedder: | ... | @@ -24,11 +24,22 @@ class ECAPAEmbedder: |
| 24 | self.hop_length = hop_length | 24 | self.hop_length = hop_length |
| 25 | 25 | ||
| 26 | from src.models.ecapa_tdnn import ECAPA_ACR | 26 | from src.models.ecapa_tdnn import ECAPA_ACR |
| 27 | self.model = ECAPA_ACR(n_mels=n_mels, embed_dim=192) | 27 | |
| 28 | state = torch.load(model_path, map_location="cpu", weights_only=True) | 28 | state = torch.load(model_path, map_location="cpu", weights_only=True) |
| 29 | if "model_state_dict" in state: | 29 | cfg = state.get("config", {}) |
| 30 | state = state["model_state_dict"] | 30 | model_cfg = cfg.get("model", {}) |
| 31 | self.model.load_state_dict(state, strict=False) | 31 | self.model = ECAPA_ACR( |
| 32 | n_mels=model_cfg.get("n_mels", n_mels), | ||
| 33 | embed_dim=model_cfg.get("embed_dim", 192), | ||
| 34 | channels=model_cfg.get("channels", 512), | ||
| 35 | se_channels=model_cfg.get("se_channels", 128), | ||
| 36 | res2net_scale=model_cfg.get("res2net_scale", 8), | ||
| 37 | num_blocks=model_cfg.get("num_blocks", 3), | ||
| 38 | num_classes=None, | ||
| 39 | ) | ||
| 40 | missing = self.model.load_state_dict(state["model_state_dict"], strict=False) | ||
| 41 | if missing.unexpected_keys: | ||
| 42 | print(f"[warn] unexpected keys while loading model: {missing.unexpected_keys}") | ||
| 32 | self.model.to(self.device) | 43 | self.model.to(self.device) |
| 33 | self.model.eval() | 44 | self.model.eval() |
| 34 | 45 | ||
| ... | @@ -38,26 +49,37 @@ class ECAPAEmbedder: | ... | @@ -38,26 +49,37 @@ class ECAPAEmbedder: |
| 38 | 49 | ||
| 39 | def _to_mel(self, y: np.ndarray) -> torch.Tensor: | 50 | def _to_mel(self, y: np.ndarray) -> torch.Tensor: |
| 40 | mel = librosa.feature.melspectrogram( | 51 | mel = librosa.feature.melspectrogram( |
| 41 | y=y, sr=self.sr, n_mels=self.n_mels, | 52 | y=y, |
| 42 | n_fft=self.n_fft, hop_length=self.hop_length | 53 | sr=self.sr, |
| 54 | n_mels=self.n_mels, | ||
| 55 | n_fft=self.n_fft, | ||
| 56 | hop_length=self.hop_length, | ||
| 43 | ) | 57 | ) |
| 44 | mel = librosa.power_to_db(mel, ref=np.max) | 58 | mel = librosa.power_to_db(mel, ref=np.max) |
| 45 | return torch.FloatTensor(mel).unsqueeze(0) | 59 | return torch.FloatTensor(mel).unsqueeze(0) |
| 46 | 60 | ||
| 61 | def _windows(self, y: np.ndarray, window_sec: float = 5.0, stride_sec: float = 2.5) -> List[np.ndarray]: | ||
| 62 | win_len = int(window_sec * self.sr) | ||
| 63 | stride = int(stride_sec * self.sr) | ||
| 64 | if len(y) < win_len: | ||
| 65 | y = np.pad(y, (0, win_len - len(y))) | ||
| 66 | windows = [] | ||
| 67 | for start in range(0, max(len(y) - win_len + 1, 1), stride): | ||
| 68 | windows.append(y[start : start + win_len]) | ||
| 69 | return windows or [y[:win_len]] | ||
| 70 | |||
| 47 | def extract_embedding(self, audio_path: str) -> np.ndarray: | 71 | def extract_embedding(self, audio_path: str) -> np.ndarray: |
| 48 | y = self._load_audio(audio_path) | 72 | y = self._load_audio(audio_path) |
| 49 | mel = self._to_mel(y).to(self.device) | 73 | return self.extract_embedding_from_wave(y) |
| 50 | with torch.no_grad(): | ||
| 51 | emb, _ = self.model(mel) | ||
| 52 | return emb.cpu().numpy().flatten() | ||
| 53 | 74 | ||
| 54 | def extract_embedding_from_wave(self, y: np.ndarray) -> np.ndarray: | 75 | def extract_embedding_from_wave(self, y: np.ndarray) -> np.ndarray: |
| 55 | if len(y) < self.sr: | 76 | window_embs = [] |
| 56 | y = np.pad(y, (0, self.sr - len(y))) | 77 | for seg in self._windows(y): |
| 57 | mel = self._to_mel(y[:self.sr * 5]).to(self.device) | 78 | mel = self._to_mel(seg).to(self.device) |
| 58 | with torch.no_grad(): | 79 | with torch.no_grad(): |
| 59 | emb, _ = self.model(mel) | 80 | emb, _ = self.model(mel) |
| 60 | return emb.cpu().numpy().flatten() | 81 | window_embs.append(emb.cpu().numpy().flatten()) |
| 82 | return np.mean(window_embs, axis=0) | ||
| 61 | 83 | ||
| 62 | def build_reference_index( | 84 | def build_reference_index( |
| 63 | self, | 85 | self, |
| ... | @@ -75,7 +97,7 @@ class ECAPAEmbedder: | ... | @@ -75,7 +97,7 @@ class ECAPAEmbedder: |
| 75 | songs_dir = Path(songs_dir) | 97 | songs_dir = Path(songs_dir) |
| 76 | 98 | ||
| 77 | for item in meta: | 99 | for item in meta: |
| 78 | if "songs/" not in item.get("audio_path", ""): | 100 | if item.get("type") != "reference" and "songs/" not in item.get("audio_path", ""): |
| 79 | continue | 101 | continue |
| 80 | audio_path = songs_dir.parent / item["audio_path"] | 102 | audio_path = songs_dir.parent / item["audio_path"] |
| 81 | if not audio_path.exists(): | 103 | if not audio_path.exists(): |
| ... | @@ -83,35 +105,20 @@ class ECAPAEmbedder: | ... | @@ -83,35 +105,20 @@ class ECAPAEmbedder: |
| 83 | song_id = item["song_id"] | 105 | song_id = item["song_id"] |
| 84 | y, _ = librosa.load(str(audio_path), sr=self.sr, mono=True) | 106 | y, _ = librosa.load(str(audio_path), sr=self.sr, mono=True) |
| 85 | 107 | ||
| 86 | win_len = int(window_sec * self.sr) | 108 | for seg in self._windows(y, window_sec=window_sec, stride_sec=stride_sec): |
| 87 | stride = int(stride_sec * self.sr) | ||
| 88 | |||
| 89 | window_embs = [] | ||
| 90 | for start in range(0, len(y) - win_len + 1, stride): | ||
| 91 | seg = y[start:start + win_len] | ||
| 92 | mel = self._to_mel(seg).to(self.device) | 109 | mel = self._to_mel(seg).to(self.device) |
| 93 | with torch.no_grad(): | 110 | with torch.no_grad(): |
| 94 | emb, _ = self.model(mel) | 111 | emb, _ = self.model(mel) |
| 95 | window_embs.append(emb.cpu().numpy().flatten()) | 112 | all_embs.append(emb.cpu().numpy().flatten()) |
| 96 | |||
| 97 | if window_embs: | ||
| 98 | song_emb = np.mean(window_embs, axis=0) | ||
| 99 | all_embs.append(song_emb) | ||
| 100 | all_ids.append(song_id) | 113 | all_ids.append(song_id) |
| 101 | 114 | ||
| 102 | all_embs = np.vstack(all_embs) | 115 | all_embs = np.vstack(all_embs) |
| 103 | np.save(f"{output_path}_embs.npy", all_embs) | 116 | np.save(f"{output_path}_embs.npy", all_embs) |
| 104 | np.save(f"{output_path}_ids.npy", np.array(all_ids)) | 117 | np.save(f"{output_path}_ids.npy", np.array(all_ids)) |
| 105 | print(f"Built reference index: {len(all_ids)} songs, embeddings shape {all_embs.shape}") | 118 | print(f"Built reference index: {len(all_ids)} windows, embeddings shape {all_embs.shape}") |
| 106 | return all_embs, all_ids | 119 | return all_embs, all_ids |
| 107 | 120 | ||
| 108 | def search( | 121 | def search(self, query_emb: np.ndarray, ref_embs: np.ndarray, ref_ids: List[str], top_k: int = 10): |
| 109 | self, | ||
| 110 | query_emb: np.ndarray, | ||
| 111 | ref_embs: np.ndarray, | ||
| 112 | ref_ids: List[str], | ||
| 113 | top_k: int = 10, | ||
| 114 | ) -> List[Tuple[str, float]]: | ||
| 115 | query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-12) | 122 | query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-12) |
| 116 | ref_norm = ref_embs / (np.linalg.norm(ref_embs, axis=1, keepdims=True) + 1e-12) | 123 | ref_norm = ref_embs / (np.linalg.norm(ref_embs, axis=1, keepdims=True) + 1e-12) |
| 117 | scores = query_norm @ ref_norm.T | 124 | scores = query_norm @ ref_norm.T | ... | ... |
| ... | @@ -2,12 +2,12 @@ | ... | @@ -2,12 +2,12 @@ |
| 2 | Hybrid ACR Engine: Chromaprint fast pre-filter + ECAPA-TDNN deep re-ranking. | 2 | Hybrid ACR Engine: Chromaprint fast pre-filter + ECAPA-TDNN deep re-ranking. |
| 3 | """ | 3 | """ |
| 4 | 4 | ||
| 5 | import numpy as np | ||
| 6 | import librosa | ||
| 7 | from typing import List, Tuple, Optional, Dict | ||
| 8 | from pathlib import Path | ||
| 9 | import json | 5 | import json |
| 10 | import time | 6 | import time |
| 7 | from typing import Dict, List, Optional | ||
| 8 | |||
| 9 | import librosa | ||
| 10 | import numpy as np | ||
| 11 | 11 | ||
| 12 | 12 | ||
| 13 | class Candidate: | 13 | class Candidate: |
| ... | @@ -17,9 +17,8 @@ class Candidate: | ... | @@ -17,9 +17,8 @@ class Candidate: |
| 17 | self.ecapa_score = ecapa_score | 17 | self.ecapa_score = ecapa_score |
| 18 | self.metadata: Dict = {} | 18 | self.metadata: Dict = {} |
| 19 | 19 | ||
| 20 | @property | 20 | def combined_score(self, chroma_weight: float, ecapa_weight: float) -> float: |
| 21 | def combined_score(self) -> float: | 21 | return chroma_weight * self.chroma_score + ecapa_weight * self.ecapa_score |
| 22 | return 0.3 * self.chroma_score + 0.7 * self.ecapa_score | ||
| 23 | 22 | ||
| 24 | def __repr__(self): | 23 | def __repr__(self): |
| 25 | return f"Candidate({self.song_id}, chroma={self.chroma_score:.3f}, ecapa={self.ecapa_score:.3f})" | 24 | return f"Candidate({self.song_id}, chroma={self.chroma_score:.3f}, ecapa={self.ecapa_score:.3f})" |
| ... | @@ -33,9 +32,9 @@ class HybridEngine: | ... | @@ -33,9 +32,9 @@ class HybridEngine: |
| 33 | ref_embs: Optional[np.ndarray] = None, | 32 | ref_embs: Optional[np.ndarray] = None, |
| 34 | ref_ids: Optional[List[str]] = None, | 33 | ref_ids: Optional[List[str]] = None, |
| 35 | sr: int = 16000, | 34 | sr: int = 16000, |
| 36 | chroma_weight: float = 0.3, | 35 | chroma_weight: float = 0.35, |
| 37 | ecapa_weight: float = 0.7, | 36 | ecapa_weight: float = 0.65, |
| 38 | reject_threshold: float = 0.4, | 37 | reject_threshold: float = 0.35, |
| 39 | ): | 38 | ): |
| 40 | self.chroma = chroma_matcher | 39 | self.chroma = chroma_matcher |
| 41 | self.ecapa = ecapa_embedder | 40 | self.ecapa = ecapa_embedder |
| ... | @@ -45,7 +44,6 @@ class HybridEngine: | ... | @@ -45,7 +44,6 @@ class HybridEngine: |
| 45 | self.chroma_weight = chroma_weight | 44 | self.chroma_weight = chroma_weight |
| 46 | self.ecapa_weight = ecapa_weight | 45 | self.ecapa_weight = ecapa_weight |
| 47 | self.reject_threshold = reject_threshold | 46 | self.reject_threshold = reject_threshold |
| 48 | |||
| 49 | self.song_metadata: Dict[str, Dict] = {} | 47 | self.song_metadata: Dict[str, Dict] = {} |
| 50 | 48 | ||
| 51 | def load_metadata(self, metadata_path: str): | 49 | def load_metadata(self, metadata_path: str): |
| ... | @@ -53,75 +51,83 @@ class HybridEngine: | ... | @@ -53,75 +51,83 @@ class HybridEngine: |
| 53 | items = json.load(f) | 51 | items = json.load(f) |
| 54 | for item in items: | 52 | for item in items: |
| 55 | sid = item["song_id"] | 53 | sid = item["song_id"] |
| 56 | if sid not in self.song_metadata: | 54 | existing = self.song_metadata.get(sid, {}) |
| 57 | base = item.get("base_freq", 0) | 55 | if item.get("type") == "reference" or not existing: |
| 58 | self.song_metadata[sid] = { | 56 | self.song_metadata[sid] = { |
| 59 | "song_id": sid, | 57 | "song_id": sid, |
| 60 | "base_freq": base, | 58 | "base_freq": item.get("base_freq", existing.get("base_freq", 0)), |
| 61 | "audio_path": item.get("audio_path", ""), | 59 | "audio_path": item.get("audio_path", existing.get("audio_path", "")), |
| 60 | "type": item.get("type", existing.get("type", "unknown")), | ||
| 62 | } | 61 | } |
| 63 | 62 | ||
| 63 | @staticmethod | ||
| 64 | def _normalize_scores(score_pairs: List[tuple], invert: bool = False) -> Dict[str, float]: | ||
| 65 | if not score_pairs: | ||
| 66 | return {} | ||
| 67 | ids = [sid for sid, _ in score_pairs] | ||
| 68 | values = np.array([float(score) for _, score in score_pairs], dtype=np.float32) | ||
| 69 | if invert: | ||
| 70 | values = -values | ||
| 71 | if len(values) == 1: | ||
| 72 | return {ids[0]: 1.0} | ||
| 73 | vmin = float(values.min()) | ||
| 74 | vmax = float(values.max()) | ||
| 75 | if abs(vmax - vmin) < 1e-8: | ||
| 76 | return {sid: 1.0 for sid in ids} | ||
| 77 | norm = (values - vmin) / (vmax - vmin) | ||
| 78 | return {sid: float(score) for sid, score in zip(ids, norm)} | ||
| 79 | |||
| 64 | def recognize( | 80 | def recognize( |
| 65 | self, | 81 | self, |
| 66 | audio_path: str, | 82 | audio_path: str, |
| 67 | top_n: int = 5, | 83 | top_n: int = 5, |
| 68 | mode: str = "auto", | 84 | mode: str = "auto", |
| 69 | ) -> List[Dict]: | 85 | ) -> Dict: |
| 86 | del mode | ||
| 70 | start = time.time() | 87 | start = time.time() |
| 71 | y, _ = librosa.load(audio_path, sr=self.sr, mono=True) | 88 | y, _ = librosa.load(audio_path, sr=self.sr, mono=True) |
| 72 | 89 | ||
| 73 | chroma_candidates: List[Candidate] = [] | 90 | chroma_matches = self.chroma.match(y, top_k=max(50, top_n * 5)) if self.chroma is not None else [] |
| 74 | if self.chroma is not None: | 91 | chroma_norm = self._normalize_scores(chroma_matches) |
| 75 | chroma_matches = self.chroma.match(y, top_k=50) | 92 | |
| 76 | seen = set() | 93 | ecapa_matches = [] |
| 77 | for song_id, score in chroma_matches: | 94 | if self.ecapa is not None and self.ref_embs is not None and self.ref_ids is not None: |
| 78 | if song_id not in seen: | ||
| 79 | seen.add(song_id) | ||
| 80 | c = Candidate(song_id, chroma_score=score) | ||
| 81 | chroma_candidates.append(c) | ||
| 82 | |||
| 83 | ecapa_candidates: List[Candidate] = [] | ||
| 84 | if self.ecapa is not None and self.ref_embs is not None: | ||
| 85 | query_emb = self.ecapa.extract_embedding_from_wave(y) | 95 | query_emb = self.ecapa.extract_embedding_from_wave(y) |
| 86 | ref_norm = self.ref_embs / ( | 96 | ref_norm = self.ref_embs / (np.linalg.norm(self.ref_embs, axis=1, keepdims=True) + 1e-12) |
| 87 | np.linalg.norm(self.ref_embs, axis=1, keepdims=True) + 1e-12 | ||
| 88 | ) | ||
| 89 | query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-12) | 97 | query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-12) |
| 90 | scores = query_norm @ ref_norm.T | 98 | scores = query_norm @ ref_norm.T |
| 91 | top_indices = np.argsort(-scores)[:top_n] | 99 | top_indices = np.argsort(-scores)[: max(top_n * 5, 20)] |
| 92 | for idx in top_indices: | 100 | ecapa_matches = [(self.ref_ids[idx], float(scores[idx])) for idx in top_indices] |
| 93 | c = Candidate(self.ref_ids[idx], ecapa_score=float(scores[idx])) | 101 | ecapa_norm = self._normalize_scores(ecapa_matches) |
| 94 | ecapa_candidates.append(c) | 102 | |
| 95 | 103 | all_song_ids = set(chroma_norm) | set(ecapa_norm) | |
| 96 | combined: Dict[str, Candidate] = {} | 104 | combined: List[Candidate] = [] |
| 97 | for c in chroma_candidates: | 105 | for song_id in all_song_ids: |
| 98 | combined[c.song_id] = c | 106 | candidate = Candidate( |
| 99 | for c in ecapa_candidates: | 107 | song_id=song_id, |
| 100 | if c.song_id in combined: | 108 | chroma_score=chroma_norm.get(song_id, 0.0), |
| 101 | combined[c.song_id].ecapa_score = c.ecapa_score | 109 | ecapa_score=ecapa_norm.get(song_id, 0.0), |
| 102 | else: | 110 | ) |
| 103 | combined[c.song_id] = c | 111 | candidate.metadata = self.song_metadata.get(song_id, {}) |
| 104 | 112 | combined.append(candidate) | |
| 105 | for sid in list(combined.keys()): | ||
| 106 | combined[sid].metadata = self.song_metadata.get(sid, {}) | ||
| 107 | |||
| 108 | results = sorted( | ||
| 109 | combined.values(), | ||
| 110 | key=lambda c: c.combined_score, | ||
| 111 | reverse=True, | ||
| 112 | )[:top_n] | ||
| 113 | 113 | ||
| 114 | combined.sort(key=lambda c: c.combined_score(self.chroma_weight, self.ecapa_weight), reverse=True) | ||
| 115 | results = combined[:top_n] | ||
| 114 | elapsed = (time.time() - start) * 1000 | 116 | elapsed = (time.time() - start) * 1000 |
| 115 | 117 | ||
| 116 | output = [] | 118 | output = [] |
| 117 | for c in results: | 119 | for c in results: |
| 118 | output.append({ | 120 | fused = c.combined_score(self.chroma_weight, self.ecapa_weight) |
| 121 | output.append( | ||
| 122 | { | ||
| 119 | "song_id": c.song_id, | 123 | "song_id": c.song_id, |
| 120 | "confidence": round(c.combined_score, 4), | 124 | "confidence": round(fused, 4), |
| 121 | "chromaprint_score": round(c.chroma_score, 4), | 125 | "chromaprint_score": round(c.chroma_score, 4), |
| 122 | "ecapa_score": round(c.ecapa_score, 4), | 126 | "ecapa_score": round(c.ecapa_score, 4), |
| 127 | "accepted": fused >= self.reject_threshold, | ||
| 123 | "metadata": c.metadata, | 128 | "metadata": c.metadata, |
| 124 | }) | 129 | } |
| 130 | ) | ||
| 125 | 131 | ||
| 126 | return { | 132 | return { |
| 127 | "candidates": output, | 133 | "candidates": output, | ... | ... |
This diff is collapsed.
Click to expand it.
docs/open-dataset-plan.md
0 → 100644
| 1 | # Open Dataset Integration Plan | ||
| 2 | |||
| 3 | ## Recommended order | ||
| 4 | |||
| 5 | 1. **FMA small** | ||
| 6 | - URL: https://github.com/mdeff/fma | ||
| 7 | - Why: easiest small realistic music subset for retrieval experiments | ||
| 8 | 2. **MTG-Jamendo** | ||
| 9 | - URL: https://github.com/MTG/mtg-jamendo-dataset | ||
| 10 | - Why: larger CC-licensed corpus with scriptable upstream tooling | ||
| 11 | 3. **QBSH / humming corpora** | ||
| 12 | - Why: add after retrieval baseline is stable | ||
| 13 | |||
| 14 | ## Repo strategy | ||
| 15 | |||
| 16 | - Keep external dataset ingestion optional | ||
| 17 | - Convert external tracks into: | ||
| 18 | - `catalog.json` for searchable references | ||
| 19 | - query segment manifests for evaluation | ||
| 20 | - Start with small local subsets before full-corpus scaling |
-
Please register or sign in to post a comment