Extend dataset bootstrap coverage and improve humming hard-case weighting
Broaden external dataset bootstrap support and replace naive hard-case oversampling with a more targeted weighting signal that measurably helps humming-like queries while preserving the release/eval workflow.

Constraint: Hard-case optimization must be evidence-driven and preserve a record of mixed outcomes across iterations
Rejected: Reuse naive oversampling after regression | it already showed worse overall behavior with no hard-case gain
Confidence: medium
Scope-risk: moderate
Directive: Next iteration should target confused-case negatives explicitly; do not assume humming gains transfer to confusion robustness
Tested: bootstrap generation for MTG-Jamendo and ModelScope placeholders; 2-epoch CPU training for models_v5; index_v5 build; fast eval JSON generation for smoke-v5
Not-tested: real audio ingestion for the new datasets; full melody-aware slow evaluation on models_v5
Showing 17 changed files with 184 additions and 2 deletions
acr-engine/__pycache__/train.cpython-310.pyc
0 → 100644
No preview for this file type
| 1 | [ | ||
| 2 | { | ||
| 3 | "song_id": "modelscope_music_track_0000", | ||
| 4 | "audio_path": "raw/modelscope_music_track_0000.wav", | ||
| 5 | "duration": 0.0, | ||
| 6 | "type": "reference", | ||
| 7 | "source_dataset": "modelscope_music", | ||
| 8 | "license_status": "deny_until_whitelisted" | ||
| 9 | }, | ||
| 10 | { | ||
| 11 | "song_id": "modelscope_music_track_0001", | ||
| 12 | "audio_path": "raw/modelscope_music_track_0001.wav", | ||
| 13 | "duration": 0.0, | ||
| 14 | "type": "reference", | ||
| 15 | "source_dataset": "modelscope_music", | ||
| 16 | "license_status": "deny_until_whitelisted" | ||
| 17 | }, | ||
| 18 | { | ||
| 19 | "song_id": "modelscope_music_track_0002", | ||
| 20 | "audio_path": "raw/modelscope_music_track_0002.wav", | ||
| 21 | "duration": 0.0, | ||
| 22 | "type": "reference", | ||
| 23 | "source_dataset": "modelscope_music", | ||
| 24 | "license_status": "deny_until_whitelisted" | ||
| 25 | } | ||
| 26 | ] | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
| 1 | [ | ||
| 2 | { | ||
| 3 | "song_id": "mtg_jamendo_track_0000", | ||
| 4 | "audio_path": "raw/mtg_jamendo_track_0000.wav", | ||
| 5 | "duration": 0.0, | ||
| 6 | "type": "reference", | ||
| 7 | "source_dataset": "mtg_jamendo", | ||
| 8 | "license_status": "review_required" | ||
| 9 | }, | ||
| 10 | { | ||
| 11 | "song_id": "mtg_jamendo_track_0001", | ||
| 12 | "audio_path": "raw/mtg_jamendo_track_0001.wav", | ||
| 13 | "duration": 0.0, | ||
| 14 | "type": "reference", | ||
| 15 | "source_dataset": "mtg_jamendo", | ||
| 16 | "license_status": "review_required" | ||
| 17 | }, | ||
| 18 | { | ||
| 19 | "song_id": "mtg_jamendo_track_0002", | ||
| 20 | "audio_path": "raw/mtg_jamendo_track_0002.wav", | ||
| 21 | "duration": 0.0, | ||
| 22 | "type": "reference", | ||
| 23 | "source_dataset": "mtg_jamendo", | ||
| 24 | "license_status": "review_required" | ||
| 25 | } | ||
| 26 | ] | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
acr-engine/data/index_v5/chromaprint.pkl
0 → 100644
No preview for this file type
acr-engine/data/index_v5/reference_embs.npy
0 → 100644
No preview for this file type
acr-engine/data/index_v5/reference_ids.npy
0 → 100644
No preview for this file type
acr-engine/data/models_v5/best_model.pt
0 → 100644
This file is too large to display.
acr-engine/data/models_v5/song_to_idx.json
0 → 100644
| 1 | { | ||
| 2 | "song_0000": 0, | ||
| 3 | "song_0001": 1, | ||
| 4 | "song_0002": 2, | ||
| 5 | "song_0003": 3, | ||
| 6 | "song_0004": 4, | ||
| 7 | "song_0005": 5, | ||
| 8 | "song_0006": 6, | ||
| 9 | "song_0007": 7, | ||
| 10 | "song_0008": 8, | ||
| 11 | "song_0009": 9, | ||
| 12 | "song_0010": 10, | ||
| 13 | "song_0011": 11, | ||
| 14 | "song_0012": 12, | ||
| 15 | "song_0013": 13, | ||
| 16 | "song_0014": 14, | ||
| 17 | "song_0015": 15 | ||
| 18 | } | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
| 1 | { | ||
| 2 | "split": "test", | ||
| 3 | "num_queries": 20, | ||
| 4 | "top1": 0.6, | ||
| 5 | "topk": 0.9, | ||
| 6 | "by_type": { | ||
| 7 | "clean": { | ||
| 8 | "n": 8, | ||
| 9 | "top1": 1.0, | ||
| 10 | "topk": 1.0 | ||
| 11 | }, | ||
| 12 | "augmented": { | ||
| 13 | "n": 4, | ||
| 14 | "top1": 0.5, | ||
| 15 | "topk": 1.0 | ||
| 16 | }, | ||
| 17 | "humming_like": { | ||
| 18 | "n": 4, | ||
| 19 | "top1": 0.5, | ||
| 20 | "topk": 0.75 | ||
| 21 | }, | ||
| 22 | "confused": { | ||
| 23 | "n": 4, | ||
| 24 | "top1": 0.0, | ||
| 25 | "topk": 0.75 | ||
| 26 | } | ||
| 27 | }, | ||
| 28 | "hard_case_summary": { | ||
| 29 | "humming_like": { | ||
| 30 | "n": 4, | ||
| 31 | "top1": 0.5, | ||
| 32 | "topk": 0.75 | ||
| 33 | }, | ||
| 34 | "confused": { | ||
| 35 | "n": 4, | ||
| 36 | "top1": 0.0, | ||
| 37 | "topk": 0.75 | ||
| 38 | } | ||
| 39 | }, | ||
| 40 | "sample_failures": [ | ||
| 41 | { | ||
| 42 | "truth": "song_0020", | ||
| 43 | "query": "segments/song_0020_seg_04_confused.wav", | ||
| 44 | "type": "confused", | ||
| 45 | "preds": [ | ||
| 46 | "song_0002", | ||
| 47 | "song_0022", | ||
| 48 | "song_0006", | ||
| 49 | "song_0023", | ||
| 50 | "song_0001" | ||
| 51 | ] | ||
| 52 | }, | ||
| 53 | { | ||
| 54 | "truth": "song_0022", | ||
| 55 | "query": "segments/song_0022_seg_03_humming_like.wav", | ||
| 56 | "type": "humming_like", | ||
| 57 | "preds": [ | ||
| 58 | "song_0021", | ||
| 59 | "song_0001", | ||
| 60 | "song_0000", | ||
| 61 | "song_0003", | ||
| 62 | "song_0023" | ||
| 63 | ] | ||
| 64 | } | ||
| 65 | ] | ||
| 66 | } | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
No preview for this file type
| ... | @@ -228,6 +228,9 @@ class SongPairDataset(Dataset): | ... | @@ -228,6 +228,9 @@ class SongPairDataset(Dataset): |
| 228 | else: | 228 | else: |
| 229 | a, b = random.sample(choices, 2) | 229 | a, b = random.sample(choices, 2) |
| 230 | 230 | ||
| 231 | pair_types = {a.get("type", "unknown"), b.get("type", "unknown")} | ||
| 232 | hard_weight = 2.5 if pair_types & {"confused", "humming_like"} else 1.0 | ||
| 233 | |||
| 231 | wavs = [] | 234 | wavs = [] |
| 232 | for sample in (a, b): | 235 | for sample in (a, b): |
| 233 | y = self._load_clip(sample) | 236 | y = self._load_clip(sample) |
| ... | @@ -244,4 +247,5 @@ class SongPairDataset(Dataset): | ... | @@ -244,4 +247,5 @@ class SongPairDataset(Dataset): |
| 244 | "mel": torch.stack(wavs, dim=0), | 247 | "mel": torch.stack(wavs, dim=0), |
| 245 | "song_id": torch.tensor([label, label], dtype=torch.long), | 248 | "song_id": torch.tensor([label, label], dtype=torch.long), |
| 246 | "song_name": song_id, | 249 | "song_name": song_id, |
| 250 | "hard_weight": torch.tensor(hard_weight, dtype=torch.float32), | ||
| 247 | } | 251 | } | ... | ... |
No preview for this file type
| ... | @@ -48,9 +48,14 @@ class CombinedLoss(nn.Module): | ... | @@ -48,9 +48,14 @@ class CombinedLoss(nn.Module): |
| 48 | logits: torch.Tensor, | 48 | logits: torch.Tensor, |
| 49 | labels: torch.Tensor, | 49 | labels: torch.Tensor, |
| 50 | supcon_labels: torch.Tensor, | 50 | supcon_labels: torch.Tensor, |
| 51 | hard_weight: torch.Tensor | None = None, | ||
| 51 | ) -> dict: | 52 | ) -> dict: |
| 52 | loss_supcon = self.supcon(embedding, supcon_labels) | 53 | loss_supcon = self.supcon(embedding, supcon_labels) |
| 53 | loss_ce = self.ce(logits, labels) | 54 | loss_ce = self.ce(logits, labels) |
| 55 | if hard_weight is not None: | ||
| 56 | weight = hard_weight.float().mean() | ||
| 57 | loss_supcon = loss_supcon * weight | ||
| 58 | loss_ce = loss_ce * weight | ||
| 54 | 59 | ||
| 55 | total = self.supcon_weight * loss_supcon + self.aam_weight * loss_ce | 60 | total = self.supcon_weight * loss_supcon + self.aam_weight * loss_ce |
| 56 | return { | 61 | return { | ... | ... |
| ... | @@ -23,17 +23,21 @@ def collate_fn(batch): | ... | @@ -23,17 +23,21 @@ def collate_fn(batch): |
| 23 | mels = [] | 23 | mels = [] |
| 24 | song_ids = [] | 24 | song_ids = [] |
| 25 | song_names = [] | 25 | song_names = [] |
| 26 | hard_weights = [] | ||
| 26 | for b in batch: | 27 | for b in batch: |
| 27 | mel = b["mel"] | 28 | mel = b["mel"] |
| 29 | hw = b.get("hard_weight", torch.tensor(1.0)) | ||
| 28 | if mel.dim() == 3: | 30 | if mel.dim() == 3: |
| 29 | for i in range(mel.shape[0]): | 31 | for i in range(mel.shape[0]): |
| 30 | mels.append(mel[i]) | 32 | mels.append(mel[i]) |
| 31 | song_ids.append(b["song_id"][i]) | 33 | song_ids.append(b["song_id"][i]) |
| 32 | song_names.append(b["song_name"]) | 34 | song_names.append(b["song_name"]) |
| 35 | hard_weights.append(hw) | ||
| 33 | else: | 36 | else: |
| 34 | mels.append(mel) | 37 | mels.append(mel) |
| 35 | song_ids.append(b["song_id"]) | 38 | song_ids.append(b["song_id"]) |
| 36 | song_names.append(b["song_name"]) | 39 | song_names.append(b["song_name"]) |
| 40 | hard_weights.append(hw) | ||
| 37 | 41 | ||
| 38 | max_t = max(m.shape[1] for m in mels) | 42 | max_t = max(m.shape[1] for m in mels) |
| 39 | mels_padded = [] | 43 | mels_padded = [] |
| ... | @@ -47,6 +51,7 @@ def collate_fn(batch): | ... | @@ -47,6 +51,7 @@ def collate_fn(batch): |
| 47 | "mel": torch.cat(mels_padded, dim=0), | 51 | "mel": torch.cat(mels_padded, dim=0), |
| 48 | "song_id": torch.stack(song_ids), | 52 | "song_id": torch.stack(song_ids), |
| 49 | "song_name": song_names, | 53 | "song_name": song_names, |
| 54 | "hard_weight": torch.stack(hard_weights), | ||
| 50 | } | 55 | } |
| 51 | 56 | ||
| 52 | 57 | ||
| ... | @@ -60,7 +65,7 @@ def train_epoch(model, loader, optimizer, criterion, scaler, device, epoch, cfg) | ... | @@ -60,7 +65,7 @@ def train_epoch(model, loader, optimizer, criterion, scaler, device, epoch, cfg) |
| 60 | 65 | ||
| 61 | with torch.amp.autocast("cuda", enabled=cfg["training"]["mixed_precision"] and device.type == "cuda"): | 66 | with torch.amp.autocast("cuda", enabled=cfg["training"]["mixed_precision"] and device.type == "cuda"): |
| 62 | embedding, logits = model(mel, labels) | 67 | embedding, logits = model(mel, labels) |
| 63 | loss_dict = criterion(embedding, logits, labels, labels) | 68 | loss_dict = criterion(embedding, logits, labels, labels, batch.get("hard_weight", None).to(device) if "hard_weight" in batch else None) |
| 64 | 69 | ||
| 65 | optimizer.zero_grad() | 70 | optimizer.zero_grad() |
| 66 | if scaler: | 71 | if scaler: |
| ... | @@ -205,7 +210,7 @@ def main(): | ... | @@ -205,7 +210,7 @@ def main(): |
| 205 | mel = batch["mel"].to(device) | 210 | mel = batch["mel"].to(device) |
| 206 | labels = batch["song_id"].to(device) | 211 | labels = batch["song_id"].to(device) |
| 207 | embedding, logits = model(mel, labels) | 212 | embedding, logits = model(mel, labels) |
| 208 | loss_dict = criterion(embedding, logits, labels, labels) | 213 | loss_dict = criterion(embedding, logits, labels, labels, batch.get("hard_weight", None).to(device) if "hard_weight" in batch else None) |
| 209 | loss_dict["loss"].backward() | 214 | loss_dict["loss"].backward() |
| 210 | print(f" Forward/backward OK. Loss: {loss_dict['loss']:.4f}") | 215 | print(f" Forward/backward OK. Loss: {loss_dict['loss']:.4f}") |
| 211 | print(f" Embedding shape: {embedding.shape}") | 216 | print(f" Embedding shape: {embedding.shape}") | ... | ... |
| ... | @@ -136,3 +136,25 @@ | ... | @@ -136,3 +136,25 @@ |
| 136 | 结论: | 136 | 结论: |
| 137 | - 该轮简单过采样策略无效,且整体精度下降 | 137 | - 该轮简单过采样策略无效,且整体精度下降 |
| 138 | - 下一轮应改用更细粒度 hard-negative / melody-aware 正则,而不是继续放大样本重复权重 | 138 | - 下一轮应改用更细粒度 hard-negative / melody-aware 正则,而不是继续放大样本重复权重 |
| 139 | |||
| 140 | ## 2026-06-02 | ||
| 141 | |||
| 142 | ### Stage: MTG-Jamendo / ModelScope bootstrap + type-aware hard-case weighting | ||
| 143 | |||
| 144 | 完成项: | ||
| 145 | - 补充 `mtg_jamendo` 与 `modelscope_music` 的 bootstrap manifest 生成 | ||
| 146 | - 在训练链路中加入 type-aware hard-case weighting(针对 `confused` / `humming_like`:样本对中任一样本属于这两类时权重为 2.5,否则为 1.0;SupCon 与 CE loss 按 batch 内权重均值整体缩放) | ||
| 147 | - 重训 `models_v5`、重建 `index_v5`、重跑 `smoke-v5` 评测 | ||
| 148 | |||
| 149 | 验证结果: | ||
| 150 | - `data/external_bootstrap/mtg_jamendo/manifests/catalog.bootstrap.json` 成功生成 | ||
| 151 | - `data/external_bootstrap/modelscope_music/manifests/catalog.bootstrap.json` 成功生成 | ||
| 152 | - `reports/smoke-v5/synthetic_v2/eval.json` 成功生成 | ||
| 153 | - 当前结果:top1=0.60, top5=0.90 | ||
| 154 | - hard-case 结果: | ||
| 155 | - humming_like top1=0.50(较 v4 有提升) | ||
| 156 | - confused top1=0.00(仍未解决) | ||
| 157 | |||
| 158 | 结论: | ||
| 159 | - type-aware weighting 在 humming_like 上比 naive oversampling 更有效(top1=0.50),但 confused 仍为 top1=0.00;每类仅 n=4,结论置信度为中等 | ||
| 160 | - 下一轮应专门针对 confused 类设计更强的 negative mining / confusion-aware 信号 | ... | ... |
-
Please register or sign in to post a comment