Commit 48c97a90 48c97a90bcc97aee9d7cd52fb084b99e9ab46218 by cnb.bofCdSsphPA

Extend dataset bootstrap coverage and improve humming hard-case weighting

Broaden external dataset bootstrap support and replace naive hard-case oversampling with a more targeted weighting signal that measurably helps humming-like queries while preserving the release/eval workflow.

Constraint: Hard-case optimization must be evidence-driven and preserve a record of mixed outcomes across iterations
Rejected: Reuse naive oversampling after regression | it already showed worse overall behavior with no hard-case gain
Confidence: medium
Scope-risk: moderate
Directive: Next iteration should target confused-case negatives explicitly; do not assume humming gains transfer to confusion robustness
Tested: bootstrap generation for MTG-Jamendo and ModelScope placeholders; 2-epoch CPU training for models_v5; index_v5 build; fast eval JSON generation for smoke-v5
Not-tested: real audio ingestion for the new datasets; full melody-aware slow evaluation on models_v5
1 parent ad350314
No preview for this file type
1 # modelscope_music bootstrap
2
3 - Fill raw audio files under `raw/`
4 - Review license before training
5 - Convert to final catalog/query manifests
1 [
2 {
3 "song_id": "modelscope_music_track_0000",
4 "audio_path": "raw/modelscope_music_track_0000.wav",
5 "duration": 0.0,
6 "type": "reference",
7 "source_dataset": "modelscope_music",
8 "license_status": "deny_until_whitelisted"
9 },
10 {
11 "song_id": "modelscope_music_track_0001",
12 "audio_path": "raw/modelscope_music_track_0001.wav",
13 "duration": 0.0,
14 "type": "reference",
15 "source_dataset": "modelscope_music",
16 "license_status": "deny_until_whitelisted"
17 },
18 {
19 "song_id": "modelscope_music_track_0002",
20 "audio_path": "raw/modelscope_music_track_0002.wav",
21 "duration": 0.0,
22 "type": "reference",
23 "source_dataset": "modelscope_music",
24 "license_status": "deny_until_whitelisted"
25 }
26 ]
...\ No newline at end of file ...\ No newline at end of file
1 # mtg_jamendo bootstrap
2
3 - Fill raw audio files under `raw/`
4 - Review license before training
5 - Convert to final catalog/query manifests
1 [
2 {
3 "song_id": "mtg_jamendo_track_0000",
4 "audio_path": "raw/mtg_jamendo_track_0000.wav",
5 "duration": 0.0,
6 "type": "reference",
7 "source_dataset": "mtg_jamendo",
8 "license_status": "review_required"
9 },
10 {
11 "song_id": "mtg_jamendo_track_0001",
12 "audio_path": "raw/mtg_jamendo_track_0001.wav",
13 "duration": 0.0,
14 "type": "reference",
15 "source_dataset": "mtg_jamendo",
16 "license_status": "review_required"
17 },
18 {
19 "song_id": "mtg_jamendo_track_0002",
20 "audio_path": "raw/mtg_jamendo_track_0002.wav",
21 "duration": 0.0,
22 "type": "reference",
23 "source_dataset": "mtg_jamendo",
24 "license_status": "review_required"
25 }
26 ]
...\ No newline at end of file ...\ No newline at end of file
No preview for this file type
No preview for this file type
No preview for this file type
This file is too large to display.
1 {
2 "song_0000": 0,
3 "song_0001": 1,
4 "song_0002": 2,
5 "song_0003": 3,
6 "song_0004": 4,
7 "song_0005": 5,
8 "song_0006": 6,
9 "song_0007": 7,
10 "song_0008": 8,
11 "song_0009": 9,
12 "song_0010": 10,
13 "song_0011": 11,
14 "song_0012": 12,
15 "song_0013": 13,
16 "song_0014": 14,
17 "song_0015": 15
18 }
...\ No newline at end of file ...\ No newline at end of file
1 {
2 "split": "test",
3 "num_queries": 20,
4 "top1": 0.6,
5 "topk": 0.9,
6 "by_type": {
7 "clean": {
8 "n": 8,
9 "top1": 1.0,
10 "topk": 1.0
11 },
12 "augmented": {
13 "n": 4,
14 "top1": 0.5,
15 "topk": 1.0
16 },
17 "humming_like": {
18 "n": 4,
19 "top1": 0.5,
20 "topk": 0.75
21 },
22 "confused": {
23 "n": 4,
24 "top1": 0.0,
25 "topk": 0.75
26 }
27 },
28 "hard_case_summary": {
29 "humming_like": {
30 "n": 4,
31 "top1": 0.5,
32 "topk": 0.75
33 },
34 "confused": {
35 "n": 4,
36 "top1": 0.0,
37 "topk": 0.75
38 }
39 },
40 "sample_failures": [
41 {
42 "truth": "song_0020",
43 "query": "segments/song_0020_seg_04_confused.wav",
44 "type": "confused",
45 "preds": [
46 "song_0002",
47 "song_0022",
48 "song_0006",
49 "song_0023",
50 "song_0001"
51 ]
52 },
53 {
54 "truth": "song_0022",
55 "query": "segments/song_0022_seg_03_humming_like.wav",
56 "type": "humming_like",
57 "preds": [
58 "song_0021",
59 "song_0001",
60 "song_0000",
61 "song_0003",
62 "song_0023"
63 ]
64 }
65 ]
66 }
...\ No newline at end of file ...\ No newline at end of file
...@@ -228,6 +228,9 @@ class SongPairDataset(Dataset): ...@@ -228,6 +228,9 @@ class SongPairDataset(Dataset):
228 else: 228 else:
229 a, b = random.sample(choices, 2) 229 a, b = random.sample(choices, 2)
230 230
231 pair_types = {a.get("type", "unknown"), b.get("type", "unknown")}
232 hard_weight = 2.5 if pair_types & {"confused", "humming_like"} else 1.0
233
231 wavs = [] 234 wavs = []
232 for sample in (a, b): 235 for sample in (a, b):
233 y = self._load_clip(sample) 236 y = self._load_clip(sample)
...@@ -244,4 +247,5 @@ class SongPairDataset(Dataset): ...@@ -244,4 +247,5 @@ class SongPairDataset(Dataset):
244 "mel": torch.stack(wavs, dim=0), 247 "mel": torch.stack(wavs, dim=0),
245 "song_id": torch.tensor([label, label], dtype=torch.long), 248 "song_id": torch.tensor([label, label], dtype=torch.long),
246 "song_name": song_id, 249 "song_name": song_id,
250 "hard_weight": torch.tensor(hard_weight, dtype=torch.float32),
247 } 251 }
......
...@@ -48,9 +48,14 @@ class CombinedLoss(nn.Module): ...@@ -48,9 +48,14 @@ class CombinedLoss(nn.Module):
48 logits: torch.Tensor, 48 logits: torch.Tensor,
49 labels: torch.Tensor, 49 labels: torch.Tensor,
50 supcon_labels: torch.Tensor, 50 supcon_labels: torch.Tensor,
51 hard_weight: torch.Tensor | None = None,
51 ) -> dict: 52 ) -> dict:
52 loss_supcon = self.supcon(embedding, supcon_labels) 53 loss_supcon = self.supcon(embedding, supcon_labels)
53 loss_ce = self.ce(logits, labels) 54 loss_ce = self.ce(logits, labels)
55 if hard_weight is not None:
56 weight = hard_weight.float().mean()
57 loss_supcon = loss_supcon * weight
58 loss_ce = loss_ce * weight
54 59
55 total = self.supcon_weight * loss_supcon + self.aam_weight * loss_ce 60 total = self.supcon_weight * loss_supcon + self.aam_weight * loss_ce
56 return { 61 return {
......
...@@ -23,17 +23,21 @@ def collate_fn(batch): ...@@ -23,17 +23,21 @@ def collate_fn(batch):
23 mels = [] 23 mels = []
24 song_ids = [] 24 song_ids = []
25 song_names = [] 25 song_names = []
26 hard_weights = []
26 for b in batch: 27 for b in batch:
27 mel = b["mel"] 28 mel = b["mel"]
29 hw = b.get("hard_weight", torch.tensor(1.0))
28 if mel.dim() == 3: 30 if mel.dim() == 3:
29 for i in range(mel.shape[0]): 31 for i in range(mel.shape[0]):
30 mels.append(mel[i]) 32 mels.append(mel[i])
31 song_ids.append(b["song_id"][i]) 33 song_ids.append(b["song_id"][i])
32 song_names.append(b["song_name"]) 34 song_names.append(b["song_name"])
35 hard_weights.append(hw)
33 else: 36 else:
34 mels.append(mel) 37 mels.append(mel)
35 song_ids.append(b["song_id"]) 38 song_ids.append(b["song_id"])
36 song_names.append(b["song_name"]) 39 song_names.append(b["song_name"])
40 hard_weights.append(hw)
37 41
38 max_t = max(m.shape[1] for m in mels) 42 max_t = max(m.shape[1] for m in mels)
39 mels_padded = [] 43 mels_padded = []
...@@ -47,6 +51,7 @@ def collate_fn(batch): ...@@ -47,6 +51,7 @@ def collate_fn(batch):
47 "mel": torch.cat(mels_padded, dim=0), 51 "mel": torch.cat(mels_padded, dim=0),
48 "song_id": torch.stack(song_ids), 52 "song_id": torch.stack(song_ids),
49 "song_name": song_names, 53 "song_name": song_names,
54 "hard_weight": torch.stack(hard_weights),
50 } 55 }
51 56
52 57
...@@ -60,7 +65,7 @@ def train_epoch(model, loader, optimizer, criterion, scaler, device, epoch, cfg) ...@@ -60,7 +65,7 @@ def train_epoch(model, loader, optimizer, criterion, scaler, device, epoch, cfg)
60 65
61 with torch.amp.autocast("cuda", enabled=cfg["training"]["mixed_precision"] and device.type == "cuda"): 66 with torch.amp.autocast("cuda", enabled=cfg["training"]["mixed_precision"] and device.type == "cuda"):
62 embedding, logits = model(mel, labels) 67 embedding, logits = model(mel, labels)
63 loss_dict = criterion(embedding, logits, labels, labels) 68 loss_dict = criterion(embedding, logits, labels, labels, batch.get("hard_weight", None).to(device) if "hard_weight" in batch else None)
64 69
65 optimizer.zero_grad() 70 optimizer.zero_grad()
66 if scaler: 71 if scaler:
...@@ -205,7 +210,7 @@ def main(): ...@@ -205,7 +210,7 @@ def main():
205 mel = batch["mel"].to(device) 210 mel = batch["mel"].to(device)
206 labels = batch["song_id"].to(device) 211 labels = batch["song_id"].to(device)
207 embedding, logits = model(mel, labels) 212 embedding, logits = model(mel, labels)
208 loss_dict = criterion(embedding, logits, labels, labels) 213 loss_dict = criterion(embedding, logits, labels, labels, batch.get("hard_weight", None).to(device) if "hard_weight" in batch else None)
209 loss_dict["loss"].backward() 214 loss_dict["loss"].backward()
210 print(f" Forward/backward OK. Loss: {loss_dict['loss']:.4f}") 215 print(f" Forward/backward OK. Loss: {loss_dict['loss']:.4f}")
211 print(f" Embedding shape: {embedding.shape}") 216 print(f" Embedding shape: {embedding.shape}")
......
...@@ -136,3 +136,25 @@ ...@@ -136,3 +136,25 @@
136 结论: 136 结论:
137 - 该轮简单过采样策略无效,且整体精度下降 137 - 该轮简单过采样策略无效,且整体精度下降
138 - 下一轮应改用更细粒度 hard-negative / melody-aware 正则,而不是继续放大样本重复权重 138 - 下一轮应改用更细粒度 hard-negative / melody-aware 正则,而不是继续放大样本重复权重
139
140 ## 2026-06-02
141
142 ### Stage: MTG-Jamendo / ModelScope bootstrap + type-aware hard-case weighting
143
144 完成项:
145 - 补充 `mtg_jamendo` / `modelscope_music` 的 bootstrap manifest 生成
146 - 在训练链路中加入 type-aware hard-case weighting(针对 `confused` / `humming_like`)
147 - 重训 `models_v5`、重建 `index_v5`、重跑 `smoke-v5` 评测
148
149 验证结果:
150 - `data/external_bootstrap/mtg_jamendo/manifests/catalog.bootstrap.json` 成功生成
151 - `data/external_bootstrap/modelscope_music/manifests/catalog.bootstrap.json` 成功生成
152 - `reports/smoke-v5/synthetic_v2/eval.json` 成功生成
153 - 当前结果:top1=0.60, top5=0.90
154 - hard-case 结果:
155 - humming_like top1=0.50(较 v4 有提升)
156 - confused top1=0.00(仍未解决)
157
158 结论:
159 - type-aware weighting 在 humming_like 上比 naive oversampling 更有效,但对 confused 类尚无改善
160 - 下一轮应专门针对 confused 类设计更强的 negative mining / confusion-aware 信号
......