Commit 7da76864 7da76864361f72a1428d2b36aeea2f283d8945e6 by 章晓祥

-

1 parent 3ff5efd2
Showing 34 changed files with 1580 additions and 108 deletions
1 {
2 "env": {
3 "ANTHROPIC_AUTH_TOKEN": "sk-1yrWrqU7xDxHgz8MIQu3zkeOUb6EqYx2i32jTtwao6780C2o",
4 "ANTHROPIC_BASE_URL": "http://43.155.145.78:65432",
5 "ANTHROPIC_MODEL": "gpt-5.4",
6 "ANTHROPIC_DEFAULT_OPUS_MODEL": "gpt-5.4",
7 "ANTHROPIC_DEFAULT_SONNET_MODEL": "minimaxai/minimax-m2.7",
8 "ANTHROPIC_DEFAULT_HAIKU_MODEL": "gpt-5.4-mini",
9 "CLAUDE_CODE_SUBAGENT_MODEL": "minimaxai/minimax-m2.7",
10 "CLAUDE_CODE_MAX_OUTPUT_TOKENS": "32000",
11 "CLAUDE_CODE_DISABLE_AUTO_UPDATE": "1",
12 "CLAUDE_CODE_ATTRIBUTION_HEADER": "0",
13 "CLAUDE_CODE_STOP_HOOK_BLOCK_CAP": 20
14 },
15 "permissions": {
16 "allow": [],
17 "deny": []
18 },
19 "model": "sonnet",
20 "enabledPlugins": {
21 "claude-code-setup@claude-plugins-official": true,
22 "typescript-lsp@claude-plugins-official": true,
23 "rust-analyzer-lsp@claude-plugins-official": true,
24 "pr-review-toolkit@claude-plugins-official": true,
25 "ralph-loop@claude-plugins-official": true,
26 "superpowers@claude-plugins-official": true
27 },
28 "alwaysThinkingEnabled": false,
29 "skipDangerousModePermissionPrompt": true,
30 "theme": "dark-ansi",
31 "modelType": "anthropic"
32 }
1 {
2 "env": {
3 "ANTHROPIC_AUTH_TOKEN": "sk-GlEnjnf09lXwiJuwDS5Q0nOzGd1ck8YBDERVXv84t9hvtS0U",
4 "ANTHROPIC_BASE_URL": "https://aiapis.help",
5 "ANTHROPIC_MODEL": "gpt-5.4",
6 "ANTHROPIC_DEFAULT_OPUS_MODEL": "gpt-5.4",
7 "ANTHROPIC_DEFAULT_SONNET_MODEL": "gpt-5.4",
8 "ANTHROPIC_DEFAULT_HAIKU_MODEL": "gpt-5.4-mini",
9 "CLAUDE_CODE_SUBAGENT_MODEL": "gpt-5.4",
10 "CLAUDE_CODE_MAX_OUTPUT_TOKENS": "32000",
11 "CLAUDE_CODE_DISABLE_AUTO_UPDATE": "1",
12 "CLAUDE_CODE_ATTRIBUTION_HEADER": "0",
13 "CLAUDE_CODE_STOP_HOOK_BLOCK_CAP": 20
14 },
15 "permissions": {
16 "allow": [],
17 "deny": []
18 },
19 "model": "sonnet",
20 "enabledPlugins": {
21 "claude-code-setup@claude-plugins-official": true,
22 "typescript-lsp@claude-plugins-official": true,
23 "rust-analyzer-lsp@claude-plugins-official": true,
24 "pr-review-toolkit@claude-plugins-official": true,
25 "ralph-loop@claude-plugins-official": true,
26 "superpowers@claude-plugins-official": true
27 },
28 "alwaysThinkingEnabled": false,
29 "skipDangerousModePermissionPrompt": true,
30 "theme": "dark-ansi",
31 "modelType": "anthropic"
32 }
1 {
2 "env": {
3 "ANTHROPIC_AUTH_TOKEN": "sk-1yrWrqU7xDxHgz8MIQu3zkeOUb6EqYx2i32jTtwao6780C2o",
4 "ANTHROPIC_BASE_URL": "http://43.155.145.78:65432",
5 "ANTHROPIC_MODEL": "claude-opus-4.6",
6 "ANTHROPIC_DEFAULT_OPUS_MODEL": "claude-opus-4.6",
7 "ANTHROPIC_DEFAULT_SONNET_MODEL": "claude-sonnet-4.6",
8 "ANTHROPIC_DEFAULT_HAIKU_MODEL": "claude-haiku-4.5",
9 "CLAUDE_CODE_SUBAGENT_MODEL": "claude-sonnet-4.6",
10 "CLAUDE_CODE_MAX_OUTPUT_TOKENS": "32000",
11 "CLAUDE_CODE_DISABLE_AUTO_UPDATE": "1",
12 "CLAUDE_CODE_ATTRIBUTION_HEADER": "0",
13 "CLAUDE_CODE_STOP_HOOK_BLOCK_CAP": 20
14 },
15 "permissions": {
16 "allow": [],
17 "deny": []
18 },
19 "model": "sonnet",
20 "enabledPlugins": {
21 "claude-code-setup@claude-plugins-official": true,
22 "typescript-lsp@claude-plugins-official": true,
23 "rust-analyzer-lsp@claude-plugins-official": true,
24 "pr-review-toolkit@claude-plugins-official": true,
25 "ralph-loop@claude-plugins-official": true,
26 "superpowers@claude-plugins-official": true
27 },
28 "alwaysThinkingEnabled": false,
29 "skipDangerousModePermissionPrompt": true,
30 "theme": "dark-ansi",
31 "modelType": "anthropic"
32 }
1 {
2 "env": {
3 "ANTHROPIC_AUTH_TOKEN": "sk-1yrWrqU7xDxHgz8MIQu3zkeOUb6EqYx2i32jTtwao6780C2o",
4 "ANTHROPIC_BASE_URL": "http://43.155.145.78:65432",
5 "ANTHROPIC_MODEL": "gpt-5.4",
6 "ANTHROPIC_DEFAULT_OPUS_MODEL": "gpt-5.4",
7 "ANTHROPIC_DEFAULT_SONNET_MODEL": "minimaxai/minimax-m2.7",
8 "ANTHROPIC_DEFAULT_HAIKU_MODEL": "gpt-5.4-mini",
9 "CLAUDE_CODE_SUBAGENT_MODEL": "minimaxai/minimax-m2.7",
10 "CLAUDE_CODE_MAX_OUTPUT_TOKENS": "32000",
11 "CLAUDE_CODE_DISABLE_AUTO_UPDATE": "1",
12 "CLAUDE_CODE_ATTRIBUTION_HEADER": "0",
13 "CLAUDE_CODE_STOP_HOOK_BLOCK_CAP": 20
14 },
15 "permissions": {
16 "allow": [],
17 "deny": []
18 },
19 "model": "sonnet",
20 "enabledPlugins": {
21 "claude-code-setup@claude-plugins-official": true,
22 "typescript-lsp@claude-plugins-official": true,
23 "rust-analyzer-lsp@claude-plugins-official": true,
24 "pr-review-toolkit@claude-plugins-official": true,
25 "ralph-loop@claude-plugins-official": true,
26 "superpowers@claude-plugins-official": true
27 },
28 "alwaysThinkingEnabled": false,
29 "skipDangerousModePermissionPrompt": true,
30 "theme": "dark-ansi",
31 "modelType": "anthropic"
32 }
1 {
2 "env": {
3 "ANTHROPIC_AUTH_TOKEN": "sk-1yrWrqU7xDxHgz8MIQu3zkeOUb6EqYx2i32jTtwao6780C2o",
4 "ANTHROPIC_BASE_URL": "http://43.155.145.78:65432",
5 "ANTHROPIC_MODEL": "qwen3.7-max",
6 "ANTHROPIC_DEFAULT_OPUS_MODEL": "qwen3.7-max",
7 "ANTHROPIC_DEFAULT_SONNET_MODEL": "qwen3.6-plus",
8 "ANTHROPIC_DEFAULT_HAIKU_MODEL": "qwen3.6-plus",
9 "CLAUDE_CODE_SUBAGENT_MODEL": "qwen3.6-plus",
10 "CLAUDE_CODE_MAX_OUTPUT_TOKENS": "32000",
11 "CLAUDE_CODE_DISABLE_AUTO_UPDATE": "1",
12 "CLAUDE_CODE_ATTRIBUTION_HEADER": "0",
13 "CLAUDE_CODE_STOP_HOOK_BLOCK_CAP": 20
14 },
15 "permissions": {
16 "allow": [],
17 "deny": []
18 },
19 "model": "sonnet",
20 "enabledPlugins": {
21 "claude-code-setup@claude-plugins-official": true,
22 "typescript-lsp@claude-plugins-official": true,
23 "rust-analyzer-lsp@claude-plugins-official": true,
24 "pr-review-toolkit@claude-plugins-official": true,
25 "ralph-loop@claude-plugins-official": true,
26 "superpowers@claude-plugins-official": true
27 },
28 "alwaysThinkingEnabled": false,
29 "skipDangerousModePermissionPrompt": true,
30 "theme": "dark-ansi",
31 "modelType": "anthropic"
32 }
1 model:
2 name: coverhunter_finetune
3 embed_dim: 256
4 channels: 512
5 se_channels: 128
6 res2net_scale: 8
7 num_blocks: 3
8 n_mels: 128
9 aam_m: 0.2
10 aam_s: 30.0
11 use_band_split: false
12 band_split_channels: 128
13 use_dual_stream: true
14 mert_melody_branch: true
15 ecapa_branch: true
16 coverhunter_heads: 8
17 coverhunter_layers: 4
18 fusion_hidden_dim: 256
19 mert_model_name: m-a-p/MERT-v1-95M
20
21 data:
22 sample_rate: 16000
23 n_fft: 512
24 hop_length: 160
25 segment_dur: 8.0
26 crop_per_song: 6
27
28 training:
29 batch_size: 16
30 epochs: 30
31 lr: 0.0002
32 weight_decay: 0.0001
33 warmup_epochs: 3
34 temperature: 0.05
35 supcon_weight: 1.0
36 aam_weight: 0.2
37 mixed_precision: true
38 gradient_clip: 1.0
39 save_every: 5
40 log_every: 10
41 hard_negative_k: 4
42 sample_type_weights:
43 default: 1
44 compressed: 2
45 recording: 3
46 environment: 4
47 pair_type_weights:
48 default: 1.0
49 compressed: 1.5
50 recording: 2.0
51 environment: 3.0
1 model:
2 name: coverhunter_finetune_lowmem
3 embed_dim: 192
4 channels: 256
5 se_channels: 64
6 res2net_scale: 4
7 num_blocks: 2
8 n_mels: 96
9 aam_m: 0.2
10 aam_s: 24.0
11 use_band_split: false
12 band_split_channels: 64
13 use_dual_stream: true
14 mert_melody_branch: true
15 ecapa_branch: true
16 coverhunter_heads: 4
17 coverhunter_layers: 2
18 fusion_hidden_dim: 128
19 mert_model_name: m-a-p/MERT-v1-95M
20
21 data:
22 sample_rate: 16000
23 n_fft: 512
24 hop_length: 160
25 segment_dur: 5.0
26 crop_per_song: 4
27
28 training:
29 batch_size: 2
30 epochs: 20
31 lr: 0.00015
32 weight_decay: 0.0001
33 warmup_epochs: 2
34 temperature: 0.05
35 supcon_weight: 1.0
36 aam_weight: 0.2
37 mixed_precision: true
38 gradient_clip: 1.0
39 save_every: 5
40 log_every: 10
41 hard_negative_k: 2
42 sample_type_weights:
43 default: 1
44 compressed: 2
45 recording: 3
46 environment: 4
47 pair_type_weights:
48 default: 1.0
49 compressed: 1.4
50 recording: 1.8
51 environment: 2.2
...@@ -10,6 +10,13 @@ model: ...@@ -10,6 +10,13 @@ model:
10 aam_s: 30.0 10 aam_s: 30.0
11 use_band_split: true 11 use_band_split: true
12 band_split_channels: 128 12 band_split_channels: 128
13 use_dual_stream: true
14 mert_melody_branch: true
15 ecapa_branch: true
16 coverhunter_heads: 4
17 coverhunter_layers: 2
18 fusion_hidden_dim: 256
19 mert_model_name: m-a-p/MERT-v1-95M
13 20
14 data: 21 data:
15 sample_rate: 16000 22 sample_rate: 16000
...@@ -31,15 +38,17 @@ training: ...@@ -31,15 +38,17 @@ training:
31 gradient_clip: 1.0 38 gradient_clip: 1.0
32 save_every: 10 39 save_every: 10
33 log_every: 10 40 log_every: 10
41 hard_negative_k: 2
34 sample_type_weights: 42 sample_type_weights:
35 default: 1 43 default: 1
36 humming_like: 3 44 compressed: 2
37 confused: 5 45 recording: 3
46 environment: 4
38 pair_type_weights: 47 pair_type_weights:
39 default: 1.0 48 default: 1.0
40 augmented: 1.4 49 compressed: 1.5
41 humming_like: 2.5 50 recording: 2.0
42 confused: 4.0 51 environment: 2.5
43 52
44 engine: 53 engine:
45 chromaprint: 54 chromaprint:
......
...@@ -2,6 +2,10 @@ numpy>=1.26 ...@@ -2,6 +2,10 @@ numpy>=1.26
2 PyYAML>=6.0 2 PyYAML>=6.0
3 soundfile>=0.12 3 soundfile>=0.12
4 librosa>=0.10 4 librosa>=0.10
5 audiomentations>=0.37
6 transformers>=4.46
7 huggingface_hub>=0.26
8 torchaudio>=2.3
5 tqdm>=4.66 9 tqdm>=4.66
6 torch>=2.3 10 torch>=2.3
7 fastapi>=0.115 11 fastapi>=0.115
......
1 #!/usr/bin/env python3
2 import argparse
3 import json
4 import subprocess
5 from datetime import datetime
6 from pathlib import Path
7
8
9 DEFAULT_PYTHON = "/usr/local/miniconda3/bin/python"
10
11
12 def main():
13 parser = argparse.ArgumentParser()
14 parser.add_argument("--python", default=DEFAULT_PYTHON)
15 parser.add_argument("--config", default="configs/coverhunter_finetune_4gb.yaml")
16 parser.add_argument("--data", required=True)
17 parser.add_argument("--output-root", default="data/training_runs")
18 parser.add_argument("--run-name", default=None)
19 parser.add_argument("--noise-root", action="append", default=[])
20 parser.add_argument("--device", default="auto")
21 parser.add_argument("--segment-strategy", default="hybrid")
22 parser.add_argument("--resume", default=None)
23 parser.add_argument("--dry-run", action="store_true")
24 args = parser.parse_args()
25
26 timestamp = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
27 run_name = args.run_name or f"coverhunter_finetune_{timestamp}"
28 run_dir = Path(args.output_root) / run_name
29 run_dir.mkdir(parents=True, exist_ok=True)
30
31 command = [
32 args.python,
33 "train.py",
34 "--config",
35 args.config,
36 "--data",
37 args.data,
38 "--output",
39 str(run_dir),
40 "--device",
41 args.device,
42 "--segment-strategy",
43 args.segment_strategy,
44 ]
45 if args.resume:
46 command.extend(["--resume", args.resume])
47 if args.dry_run:
48 command.append("--dry-run")
49 for noise_root in args.noise_root:
50 command.extend(["--noise-root", noise_root])
51
52 metadata = {
53 "run_name": run_name,
54 "created_at": datetime.utcnow().isoformat() + "Z",
55 "python": args.python,
56 "command": command,
57 "config": args.config,
58 "data": args.data,
59 "noise_roots": args.noise_root,
60 "run_dir": str(run_dir),
61 }
62 with open(run_dir / "run_request.json", "w") as f:
63 json.dump(metadata, f, indent=2)
64
65 result = subprocess.run(command, cwd=Path(__file__).resolve().parents[1], text=True, capture_output=True)
66 (run_dir / "stdout.log").write_text(result.stdout)
67 (run_dir / "stderr.log").write_text(result.stderr)
68 summary = {
69 **metadata,
70 "returncode": result.returncode,
71 "completed_at": datetime.utcnow().isoformat() + "Z",
72 "artifacts": sorted(path.name for path in run_dir.iterdir()),
73 }
74 with open(run_dir / "run_summary.json", "w") as f:
75 json.dump(summary, f, indent=2)
76 if result.returncode != 0:
77 raise SystemExit(result.returncode)
78
79
80 if __name__ == "__main__":
81 main()
1 #!/usr/bin/env python3
2 import argparse
3 import json
4 import subprocess
5 from pathlib import Path
6
7 PYTHON_DEFAULT = "/usr/local/miniconda3/bin/python"
8 PACKAGES = [
9 "-r", "requirements.txt",
10 ]
11 EXTRA_PACKAGES = [
12 "torch",
13 "torchaudio",
14 "transformers",
15 "huggingface_hub",
16 "librosa",
17 "soundfile",
18 "audiomentations",
19 ]
20
21
22 def run(command, cwd):
23 return subprocess.run(command, cwd=cwd, text=True, capture_output=True)
24
25
26 def main():
27 parser = argparse.ArgumentParser()
28 parser.add_argument("--python", default=PYTHON_DEFAULT)
29 parser.add_argument("--skip-install", action="store_true")
30 args = parser.parse_args()
31
32 root = Path(__file__).resolve().parents[1]
33 report = {
34 "python": args.python,
35 "cwd": str(root),
36 "steps": [],
37 }
38
39 if not args.skip_install:
40 install_cmd = [args.python, "-m", "pip", "install", *PACKAGES]
41 res = run(install_cmd, root)
42 report["steps"].append({
43 "name": "install_requirements",
44 "command": install_cmd,
45 "returncode": res.returncode,
46 "stdout": res.stdout[-4000:],
47 "stderr": res.stderr[-4000:],
48 })
49
50 extra_cmd = [args.python, "-m", "pip", "install", *EXTRA_PACKAGES]
51 res = run(extra_cmd, root)
52 report["steps"].append({
53 "name": "install_extra_packages",
54 "command": extra_cmd,
55 "returncode": res.returncode,
56 "stdout": res.stdout[-4000:],
57 "stderr": res.stderr[-4000:],
58 })
59
60 verify_cmd = [
61 args.python,
62 "-c",
63 (
64 "import torch, transformers, librosa, soundfile, audiomentations; "
65 "print({'torch': torch.__version__, 'cuda': torch.cuda.is_available(), 'transformers': transformers.__version__})"
66 ),
67 ]
68 res = run(verify_cmd, root)
69 report["steps"].append({
70 "name": "verify_environment",
71 "command": verify_cmd,
72 "returncode": res.returncode,
73 "stdout": res.stdout[-4000:],
74 "stderr": res.stderr[-4000:],
75 })
76
77 report_path = root / "reports" / "coverhunter_env_setup_report.json"
78 report_path.parent.mkdir(parents=True, exist_ok=True)
79 report_path.write_text(json.dumps(report, indent=2))
80 print(report_path)
81
82 if any(step["returncode"] != 0 for step in report["steps"]):
83 raise SystemExit(1)
84
85
86 if __name__ == "__main__":
87 main()
...@@ -3,6 +3,55 @@ import torch.nn as nn ...@@ -3,6 +3,55 @@ import torch.nn as nn
3 import torch.nn.functional as F 3 import torch.nn.functional as F
4 from typing import Optional, Tuple, List 4 from typing import Optional, Tuple, List
5 5
6 try:
7 from transformers import AutoModel
8 except ImportError:
9 AutoModel = None
10
11
12 class FrozenMERTFeatureExtractor(nn.Module):
13 def __init__(self, model_name: Optional[str], n_mels: int, hidden_dim: int):
14 super().__init__()
15 self.model_name = model_name
16 self.hidden_dim = hidden_dim
17 self.backbone = None
18 self.proj = nn.Sequential(
19 nn.Conv1d(n_mels, hidden_dim, kernel_size=3, padding=1),
20 nn.GELU(),
21 nn.BatchNorm1d(hidden_dim),
22 nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
23 nn.GELU(),
24 nn.BatchNorm1d(hidden_dim),
25 )
26 for parameter in self.proj.parameters():
27 parameter.requires_grad = False
28
29 if model_name and AutoModel is not None:
30 try:
31 self.backbone = AutoModel.from_pretrained(model_name)
32 except Exception:
33 self.backbone = None
34 if self.backbone is not None:
35 for parameter in self.backbone.parameters():
36 parameter.requires_grad = False
37 backbone_dim = getattr(self.backbone.config, "hidden_size", hidden_dim)
38 self.proj = nn.Sequential(
39 nn.Conv1d(backbone_dim, hidden_dim, kernel_size=1),
40 nn.GELU(),
41 nn.BatchNorm1d(hidden_dim),
42 )
43
44 def forward(self, mel: torch.Tensor) -> torch.Tensor:
45 if self.backbone is None:
46 with torch.no_grad():
47 return self.proj(mel)
48
49 waveform_like = mel.transpose(1, 2)
50 with torch.no_grad():
51 outputs = self.backbone(inputs_embeds=waveform_like)
52 hidden = outputs.last_hidden_state.transpose(1, 2)
53 return self.proj(hidden)
54
6 55
7 class SEModule(nn.Module): 56 class SEModule(nn.Module):
8 def __init__(self, channels, se_channels=128): 57 def __init__(self, channels, se_channels=128):
...@@ -123,6 +172,89 @@ class AAMSoftmax(nn.Module): ...@@ -123,6 +172,89 @@ class AAMSoftmax(nn.Module):
123 return output 172 return output
124 173
125 174
175 class CoverHunterHead(nn.Module):
176 def __init__(self, input_dim: int, embed_dim: int, num_heads: int = 4, num_layers: int = 2, ff_mult: int = 4):
177 super().__init__()
178 encoder_layer = nn.TransformerEncoderLayer(
179 d_model=input_dim,
180 nhead=num_heads,
181 dim_feedforward=input_dim * ff_mult,
182 batch_first=True,
183 activation="gelu",
184 )
185 self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
186 self.attention = nn.Sequential(
187 nn.Linear(input_dim, input_dim),
188 nn.Tanh(),
189 nn.Linear(input_dim, 1),
190 )
191 self.proj = nn.Linear(input_dim, embed_dim)
192 self.norm = nn.BatchNorm1d(embed_dim, affine=False)
193
194 def forward(self, x: torch.Tensor) -> torch.Tensor:
195 encoded = self.encoder(x)
196 weights = torch.softmax(self.attention(encoded).squeeze(-1), dim=1).unsqueeze(-1)
197 pooled = torch.sum(encoded * weights, dim=1)
198 projected = self.proj(pooled)
199 projected = self.norm(projected)
200 return F.normalize(projected, p=2, dim=1)
201
202
203 class MERTMelodyBranch(nn.Module):
204 def __init__(
205 self,
206 n_mels: int,
207 chroma_bins: int = 12,
208 melody_bins: int = 1,
209 hidden_dim: int = 256,
210 mert_model_name: Optional[str] = None,
211 ):
212 super().__init__()
213 self.mert = FrozenMERTFeatureExtractor(model_name=mert_model_name, n_mels=n_mels, hidden_dim=hidden_dim)
214 self.melody_proj = nn.Conv1d(chroma_bins + melody_bins, hidden_dim, kernel_size=1)
215 self.fuse = nn.Sequential(
216 nn.Conv1d(hidden_dim * 2, hidden_dim, kernel_size=1),
217 nn.ReLU(),
218 nn.BatchNorm1d(hidden_dim),
219 )
220
221 def forward(self, mert: torch.Tensor, melody: torch.Tensor, chroma: torch.Tensor) -> torch.Tensor:
222 semantic = self.mert(mert)
223 melodic = self.melody_proj(torch.cat([melody, chroma], dim=1))
224 return self.fuse(torch.cat([semantic, melodic], dim=1))
225
226
227 class ECAPABranch(nn.Module):
228 def __init__(self, n_mels: int, channels: int, use_band_split: bool, band_split_channels: int):
229 super().__init__()
230 front_channels = band_split_channels * 5 if use_band_split else n_mels
231 self.band_split = BandSplitBlock(n_mels=n_mels, out_channels=band_split_channels) if use_band_split else None
232 self.proj = nn.Sequential(
233 nn.Conv1d(front_channels, channels, kernel_size=5, stride=1, padding=2),
234 nn.ReLU(),
235 nn.BatchNorm1d(channels),
236 )
237
238 def forward(self, mel: torch.Tensor) -> torch.Tensor:
239 x = self.band_split(mel) if self.band_split is not None else mel
240 return self.proj(x)
241
242
243 class DualStreamFusion(nn.Module):
244 def __init__(self, mert_dim: int, ecapa_dim: int, hidden_dim: int):
245 super().__init__()
246 self.mert_gate = nn.Conv1d(mert_dim, hidden_dim, kernel_size=1)
247 self.ecapa_gate = nn.Conv1d(ecapa_dim, hidden_dim, kernel_size=1)
248 self.fuse = nn.Sequential(
249 nn.Conv1d(hidden_dim * 2, hidden_dim, kernel_size=1),
250 nn.ReLU(),
251 nn.BatchNorm1d(hidden_dim),
252 )
253
254 def forward(self, mert_stream: torch.Tensor, ecapa_stream: torch.Tensor) -> torch.Tensor:
255 return self.fuse(torch.cat([self.mert_gate(mert_stream), self.ecapa_gate(ecapa_stream)], dim=1))
256
257
126 class ECAPA_ACR(nn.Module): 258 class ECAPA_ACR(nn.Module):
127 def __init__( 259 def __init__(
128 self, 260 self,
...@@ -137,11 +269,38 @@ class ECAPA_ACR(nn.Module): ...@@ -137,11 +269,38 @@ class ECAPA_ACR(nn.Module):
137 aam_s: float = 30.0, 269 aam_s: float = 30.0,
138 use_band_split: bool = True, 270 use_band_split: bool = True,
139 band_split_channels: int = 128, 271 band_split_channels: int = 128,
272 use_dual_stream: bool = True,
273 coverhunter_heads: int = 4,
274 coverhunter_layers: int = 2,
275 fusion_hidden_dim: int = 256,
276 mert_model_name: Optional[str] = None,
140 ): 277 ):
141 super().__init__() 278 super().__init__()
142 self.embed_dim = embed_dim 279 self.embed_dim = embed_dim
143 front_channels = band_split_channels * 5 if use_band_split else channels 280 self.use_dual_stream = use_dual_stream
144 self.band_split = BandSplitBlock(n_mels=n_mels, out_channels=band_split_channels) if use_band_split else None 281 if use_dual_stream:
282 self.mert_melody_branch = MERTMelodyBranch(
283 n_mels=n_mels,
284 chroma_bins=12,
285 melody_bins=1,
286 hidden_dim=fusion_hidden_dim,
287 mert_model_name=mert_model_name,
288 )
289 self.ecapa_branch = ECAPABranch(
290 n_mels=n_mels,
291 channels=channels,
292 use_band_split=use_band_split,
293 band_split_channels=band_split_channels,
294 )
295 self.stream_fusion = DualStreamFusion(
296 mert_dim=fusion_hidden_dim,
297 ecapa_dim=channels,
298 hidden_dim=channels,
299 )
300 front_channels = channels
301 else:
302 front_channels = band_split_channels * 5 if use_band_split else channels
303 self.band_split = BandSplitBlock(n_mels=n_mels, out_channels=band_split_channels) if use_band_split else None
145 304
146 self.conv1 = nn.Sequential( 305 self.conv1 = nn.Sequential(
147 nn.Conv1d(front_channels, channels, kernel_size=5, stride=1, padding=2), 306 nn.Conv1d(front_channels, channels, kernel_size=5, stride=1, padding=2),
...@@ -169,24 +328,39 @@ class ECAPA_ACR(nn.Module): ...@@ -169,24 +328,39 @@ class ECAPA_ACR(nn.Module):
169 nn.ReLU(), 328 nn.ReLU(),
170 nn.BatchNorm1d(channels * 3), 329 nn.BatchNorm1d(channels * 3),
171 ) 330 )
172 self.pooling = StatisticsPooling() 331 self.coverhunter = CoverHunterHead(
173 self.fc = nn.Linear(channels * 3 * 2, embed_dim) 332 input_dim=channels * 3,
174 self.bn = nn.BatchNorm1d(embed_dim, affine=False) 333 embed_dim=embed_dim,
334 num_heads=coverhunter_heads,
335 num_layers=coverhunter_layers,
336 )
175 self.aam = AAMSoftmax(embed_dim, num_classes, m=aam_m, s=aam_s) if num_classes is not None else None 337 self.aam = AAMSoftmax(embed_dim, num_classes, m=aam_m, s=aam_s) if num_classes is not None else None
176 338
177 def forward(self, mel: torch.Tensor, labels: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 339 def forward(
178 x = self.band_split(mel) if self.band_split is not None else mel 340 self,
179 x = self.conv1(x) 341 mel: torch.Tensor,
342 labels: Optional[torch.Tensor] = None,
343 melody: Optional[torch.Tensor] = None,
344 chroma: Optional[torch.Tensor] = None,
345 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
346 if self.use_dual_stream:
347 if melody is None or chroma is None:
348 raise ValueError("melody and chroma are required when dual-stream fusion is enabled")
349 mert_stream = self.mert_melody_branch(mel, melody, chroma)
350 ecapa_stream = self.ecapa_branch(mel)
351 x = self.stream_fusion(mert_stream, ecapa_stream)
352 else:
353 x = self.band_split(mel) if self.band_split is not None else mel
354 x = self.conv1(x)
355 if self.use_dual_stream:
356 x = self.conv1(x)
180 block_outputs = [] 357 block_outputs = []
181 for block in self.blocks: 358 for block in self.blocks:
182 x = block(x) 359 x = block(x)
183 block_outputs.append(x) 360 block_outputs.append(x)
184 x = torch.cat(block_outputs, dim=1) 361 x = torch.cat(block_outputs, dim=1)
185 x = self.mfa(x) 362 x = self.mfa(x)
186 x = self.pooling(x) 363 embedding = self.coverhunter(x.transpose(1, 2))
187 x = self.fc(x)
188 x = self.bn(x)
189 embedding = F.normalize(x, p=2, dim=1)
190 if labels is not None and self.aam is not None: 364 if labels is not None and self.aam is not None:
191 logits = self.aam(embedding, labels) 365 logits = self.aam(embedding, labels)
192 return embedding, logits 366 return embedding, logits
......
...@@ -3,30 +3,22 @@ import torch.nn as nn ...@@ -3,30 +3,22 @@ import torch.nn as nn
3 import torch.nn.functional as F 3 import torch.nn.functional as F
4 4
5 5
6 class SupConLoss(nn.Module): 6 class InfoNCELoss(nn.Module):
7 def __init__(self, temperature: float = 0.07): 7 def __init__(self, temperature: float = 0.07):
8 super().__init__() 8 super().__init__()
9 self.temperature = temperature 9 self.temperature = temperature
10 10
11 def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 11 def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
12 batch_size = features.shape[0]
13 labels = labels.contiguous().view(-1, 1)
14 mask = torch.eq(labels, labels.T).float().to(features.device)
15 mask = mask - torch.eye(batch_size, device=features.device)
16
17 features = F.normalize(features, dim=1) 12 features = F.normalize(features, dim=1)
18 sim = torch.matmul(features, features.T) / self.temperature 13 logits = torch.matmul(features, features.T) / self.temperature
19 sim_max, _ = torch.max(sim, dim=1, keepdim=True) 14 labels = labels.contiguous().view(-1, 1)
20 sim = sim - sim_max.detach() 15 positive_mask = torch.eq(labels, labels.T).float().to(features.device)
21 16 positive_mask = positive_mask - torch.eye(features.size(0), device=features.device)
22 exp_sim = torch.exp(sim) * (1 - torch.eye(batch_size, device=features.device)) 17 logits = logits - logits.max(dim=1, keepdim=True).values.detach()
23 log_prob = sim - torch.log(exp_sim.sum(dim=1, keepdim=True)) 18 exp_logits = torch.exp(logits) * (1 - torch.eye(features.size(0), device=features.device))
24 19 log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True) + 1e-12)
25 pos_mask = mask 20 positives = positive_mask.sum(dim=1).clamp(min=1)
26 pos_count = pos_mask.sum(dim=1) 21 return -((positive_mask * log_prob).sum(dim=1) / positives)
27 loss = -(log_prob * pos_mask).sum(dim=1)
28 loss = loss / pos_count.clamp(min=1)
29 return loss
30 22
31 23
32 class CombinedLoss(nn.Module): 24 class CombinedLoss(nn.Module):
...@@ -37,8 +29,7 @@ class CombinedLoss(nn.Module): ...@@ -37,8 +29,7 @@ class CombinedLoss(nn.Module):
37 aam_weight: float = 0.3, 29 aam_weight: float = 0.3,
38 ): 30 ):
39 super().__init__() 31 super().__init__()
40 self.supcon = SupConLoss(temperature) 32 self.infonce = InfoNCELoss(temperature)
41 self.ce = nn.CrossEntropyLoss()
42 self.supcon_weight = supcon_weight 33 self.supcon_weight = supcon_weight
43 self.aam_weight = aam_weight 34 self.aam_weight = aam_weight
44 35
...@@ -50,21 +41,20 @@ class CombinedLoss(nn.Module): ...@@ -50,21 +41,20 @@ class CombinedLoss(nn.Module):
50 supcon_labels: torch.Tensor, 41 supcon_labels: torch.Tensor,
51 hard_weight: torch.Tensor | None = None, 42 hard_weight: torch.Tensor | None = None,
52 ) -> dict: 43 ) -> dict:
53 loss_supcon = self.supcon(embedding, supcon_labels) 44 loss_infonce = self.infonce(embedding, supcon_labels)
54 loss_ce = F.cross_entropy(logits, labels, reduction="none") 45 loss_ce = F.cross_entropy(logits, labels, reduction="none")
55 if hard_weight is not None: 46 if hard_weight is not None:
56 weight = hard_weight.float() 47 weight = hard_weight.float()
57 if weight.dim() == 0: 48 if weight.dim() == 0:
58 weight = weight.unsqueeze(0) 49 weight = weight.unsqueeze(0)
59 loss_supcon = loss_supcon * weight 50 loss_infonce = loss_infonce * weight
60 loss_ce = loss_ce * weight 51 loss_ce = loss_ce * weight
61 52
62 loss_supcon = loss_supcon.mean() 53 loss_infonce = loss_infonce.mean()
63 loss_ce = loss_ce.mean() 54 loss_ce = loss_ce.mean()
64 55 total = self.supcon_weight * loss_infonce + self.aam_weight * loss_ce
65 total = self.supcon_weight * loss_supcon + self.aam_weight * loss_ce
66 return { 56 return {
67 "loss": total, 57 "loss": total,
68 "supcon_loss": loss_supcon.item(), 58 "supcon_loss": loss_infonce.item(),
69 "ce_loss": loss_ce.item(), 59 "ce_loss": loss_ce.item(),
70 } 60 }
......
1 import numpy as np 1 import numpy as np
2 import random 2 import random
3 from typing import Optional, Tuple 3 from pathlib import Path
4 from typing import Iterable, Optional, Tuple
4 5
6 import librosa
7 import soundfile as sf
8
9 try:
10 from audiomentations import AddBackgroundNoise, AddGaussianNoise, BandPassFilter, Compose, Mp3Compression, PitchShift, TimeStretch
11 HAS_AUDIO_AUG = True
12 except Exception:
13 AddBackgroundNoise = AddGaussianNoise = BandPassFilter = Compose = Mp3Compression = PitchShift = TimeStretch = None
14 HAS_AUDIO_AUG = False
5 15
6 class AugmentPipeline:
7 def __init__(self, sr: int = 16000, aggressive: bool = False):
8 self.sr = sr
9 self.noise_snr_range = (5, 30)
10 self.pitch_shift_range = (-6, 6)
11 self.time_stretch_range = (0.85, 1.15)
12 self.mp3_bitrate_range = (32, 128)
13 self.aggressive = aggressive
14 16
15 def add_noise(self, y: np.ndarray, snr_db: Optional[float] = None) -> np.ndarray: 17 class NoiseLibrary:
16 if snr_db is None: 18 def __init__(self, roots: Optional[Iterable[str]] = None):
17 snr_db = random.uniform(*self.noise_snr_range) 19 self.paths = []
18 signal_power = np.mean(y ** 2) 20 for root in roots or []:
19 noise_power = signal_power / (10 ** (snr_db / 10)) 21 base = Path(root)
20 noise = np.random.randn(len(y)) * np.sqrt(noise_power) 22 if not base.exists():
21 return y + noise 23 continue
24 for pattern in ("*.wav", "*.mp3", "*.flac", "*.ogg", "*.m4a"):
25 self.paths.extend(base.rglob(pattern))
22 26
23 def pitch_shift(self, y: np.ndarray, semitones: Optional[float] = None) -> np.ndarray: 27 def directories(self) -> list[str]:
24 if semitones is None: 28 if not self.paths:
25 semitones = random.uniform(*self.pitch_shift_range) 29 return []
26 return librosa_shift(y, sr=self.sr, n_steps=semitones) 30 return sorted({str(path.parent) for path in self.paths})
27 31
28 def time_stretch(self, y: np.ndarray, rate: Optional[float] = None) -> np.ndarray:
29 if rate is None:
30 rate = random.uniform(*self.time_stretch_range)
31 return librosa_ts(y, sr=self.sr, rate=rate)
32 32
33 def add_reverb(self, y: np.ndarray, decay: float = 0.3) -> np.ndarray: 33 class AugmentPipeline:
34 ir_len = int(0.1 * self.sr) 34 def __init__(
35 ir = np.exp(-np.arange(ir_len) * decay / ir_len) * np.random.randn(ir_len) 35 self,
36 ir /= np.sqrt(np.sum(ir ** 2)) 36 sr: int = 16000,
37 return np.convolve(y, ir, mode='same')[:len(y)] 37 aggressive: bool = False,
38 noise_roots: Optional[Iterable[str]] = None,
39 freq_mask_prob: float = 0.3,
40 ):
41 self.sr = sr
42 self.aggressive = aggressive
43 self.freq_mask_prob = freq_mask_prob
44 self.noise_library = NoiseLibrary(noise_roots)
45 self.wave_augment = self._build_wave_augmenter()
46
47 def _build_wave_augmenter(self):
48 if not HAS_AUDIO_AUG:
49 return None
50 transforms = [
51 AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.02, p=0.5 if not self.aggressive else 0.8),
52 BandPassFilter(
53 min_center_freq=300.0,
54 max_center_freq=3200.0,
55 min_bandwidth_fraction=0.3,
56 max_bandwidth_fraction=0.8,
57 p=0.35 if not self.aggressive else 0.55,
58 ),
59 Mp3Compression(min_bitrate=24, max_bitrate=96, p=0.35 if not self.aggressive else 0.55),
60 PitchShift(min_semitones=-5, max_semitones=5, p=0.35 if not self.aggressive else 0.55),
61 TimeStretch(min_rate=0.8, max_rate=1.2, p=0.35 if not self.aggressive else 0.55),
62 ]
63 noise_dirs = self.noise_library.directories()
64 if noise_dirs:
65 transforms.append(
66 AddBackgroundNoise(
67 sounds_path=noise_dirs,
68 min_snr_db=3.0 if self.aggressive else 8.0,
69 max_snr_db=20.0 if self.aggressive else 30.0,
70 noise_transform=Compose([
71 BandPassFilter(
72 min_center_freq=250.0,
73 max_center_freq=4000.0,
74 min_bandwidth_fraction=0.2,
75 max_bandwidth_fraction=0.9,
76 p=0.5,
77 )
78 ]),
79 p=0.35 if not self.aggressive else 0.6,
80 )
81 )
82 return Compose(transforms)
38 83
39 def apply_spec_augment(self, mel: np.ndarray, max_time_mask: int = 20, max_freq_mask: int = 8) -> np.ndarray: 84 def apply_spec_augment(self, mel: np.ndarray, max_time_mask: int = 20, max_freq_mask: int = 12) -> np.ndarray:
40 mel = mel.copy() 85 mel = mel.copy()
41 t = mel.shape[1] 86 t = mel.shape[1]
42 f = mel.shape[0] 87 f = mel.shape[0]
...@@ -46,43 +91,21 @@ class AugmentPipeline: ...@@ -46,43 +91,21 @@ class AugmentPipeline:
46 if t_start < t: 91 if t_start < t:
47 mel[:, t_start:t_start + t_mask] = 0 92 mel[:, t_start:t_start + t_mask] = 0
48 for _ in range(2): 93 for _ in range(2):
49 f_mask = random.randint(0, max_freq_mask) 94 f_mask = random.randint(max(1, max_freq_mask // 3), max_freq_mask)
50 f_start = random.randint(0, max(0, f - f_mask)) 95 f_start = random.randint(0, max(0, f - f_mask))
51 if f_start < f: 96 if f_start < f:
52 mel[f_start:f_start + f_mask, :] = 0 97 mel[f_start:f_start + f_mask, :] = 0
53 return mel 98 return mel
54 99
55 def apply_to_mel(self, mel: np.ndarray) -> np.ndarray: 100 def apply_to_mel(self, mel: np.ndarray) -> np.ndarray:
56 if random.random() < 0.3: 101 if random.random() < self.freq_mask_prob:
57 mel = self.apply_spec_augment(mel) 102 mel = self.apply_spec_augment(mel)
58 return mel 103 return mel
59 104
60 def __call__(self, y: np.ndarray) -> np.ndarray: 105 def __call__(self, y: np.ndarray) -> np.ndarray:
61 noise_p = 0.75 if self.aggressive else 0.5 106 if self.wave_augment is None:
62 stretch_p = 0.55 if self.aggressive else 0.3 107 return y
63 pitch_p = 0.55 if self.aggressive else 0.3 108 try:
64 reverb_p = 0.35 if self.aggressive else 0.2 109 return self.wave_augment(samples=y.astype(np.float32), sample_rate=self.sr)
65 if random.random() < noise_p: 110 except Exception:
66 y = self.add_noise(y, snr_db=random.uniform(0, 18) if self.aggressive else None) 111 return y
67 if random.random() < stretch_p:
68 y = self.time_stretch(y, rate=random.uniform(0.8, 1.2) if self.aggressive else None)
69 if random.random() < pitch_p:
70 y = self.pitch_shift(y, semitones=random.uniform(-8, 8) if self.aggressive else None)
71 if random.random() < reverb_p:
72 y = self.add_reverb(y, decay=random.uniform(0.2, 0.6))
73 return y
74
75
76 def librosa_shift(y, sr=16000, n_steps=0):
77 return librosa_impl(y, lambda: __import__('librosa').effects.pitch_shift(y, sr=sr, n_steps=n_steps))
78
79
80 def librosa_ts(y, sr=16000, rate=1.0):
81 return librosa_impl(y, lambda: __import__('librosa').effects.time_stretch(y, rate=rate))
82
83
84 def librosa_impl(y, fn):
85 try:
86 return fn()
87 except Exception:
88 return y
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
4 import argparse 4 import argparse
5 import json 5 import json
6 import sys 6 import sys
7 from datetime import datetime
7 from pathlib import Path 8 from pathlib import Path
8 9
9 import torch 10 import torch
...@@ -21,15 +22,23 @@ from src.models.losses import CombinedLoss ...@@ -21,15 +22,23 @@ from src.models.losses import CombinedLoss
21 22
22 def collate_fn(batch): 23 def collate_fn(batch):
23 mels = [] 24 mels = []
25 melodies = []
26 chromas = []
24 song_ids = [] 27 song_ids = []
25 song_names = [] 28 song_names = []
26 hard_weights = [] 29 hard_weights = []
27 for b in batch: 30 for b in batch:
28 mel = b["mel"] 31 mel = b["mel"]
32 melody = b.get("melody")
33 chroma = b.get("chroma")
29 hw = b.get("hard_weight", torch.tensor(1.0)) 34 hw = b.get("hard_weight", torch.tensor(1.0))
30 if mel.dim() == 3: 35 if mel.dim() == 3:
31 for i in range(mel.shape[0]): 36 for i in range(mel.shape[0]):
32 mels.append(mel[i]) 37 mels.append(mel[i])
38 if melody is not None:
39 melodies.append(melody[i])
40 if chroma is not None:
41 chromas.append(chroma[i])
33 song_ids.append(b["song_id"][i]) 42 song_ids.append(b["song_id"][i])
34 song_names.append(b["song_name"]) 43 song_names.append(b["song_name"])
35 if torch.is_tensor(hw) and hw.dim() > 0: 44 if torch.is_tensor(hw) and hw.dim() > 0:
...@@ -38,24 +47,45 @@ def collate_fn(batch): ...@@ -38,24 +47,45 @@ def collate_fn(batch):
38 hard_weights.append(hw) 47 hard_weights.append(hw)
39 else: 48 else:
40 mels.append(mel) 49 mels.append(mel)
50 if melody is not None:
51 melodies.append(melody)
52 if chroma is not None:
53 chromas.append(chroma)
41 song_ids.append(b["song_id"]) 54 song_ids.append(b["song_id"])
42 song_names.append(b["song_name"]) 55 song_names.append(b["song_name"])
43 hard_weights.append(hw) 56 hard_weights.append(hw)
44 57
45 max_t = max(m.shape[1] for m in mels) 58 max_t = max(m.shape[1] for m in mels)
46 mels_padded = [] 59 mels_padded = []
47 for m in mels: 60 melodies_padded = []
61 chromas_padded = []
62 for idx, m in enumerate(mels):
48 pad = max_t - m.shape[1] 63 pad = max_t - m.shape[1]
49 if pad > 0: 64 if pad > 0:
50 m = torch.nn.functional.pad(m, (0, pad)) 65 m = torch.nn.functional.pad(m, (0, pad))
51 mels_padded.append(m.unsqueeze(0)) 66 mels_padded.append(m.unsqueeze(0))
52 67 if melodies:
53 return { 68 melody = melodies[idx]
69 if melody.shape[1] < max_t:
70 melody = torch.nn.functional.pad(melody, (0, max_t - melody.shape[1]))
71 melodies_padded.append(melody.unsqueeze(0))
72 if chromas:
73 chroma = chromas[idx]
74 if chroma.shape[1] < max_t:
75 chroma = torch.nn.functional.pad(chroma, (0, max_t - chroma.shape[1]))
76 chromas_padded.append(chroma.unsqueeze(0))
77
78 payload = {
54 "mel": torch.cat(mels_padded, dim=0), 79 "mel": torch.cat(mels_padded, dim=0),
55 "song_id": torch.stack(song_ids), 80 "song_id": torch.stack(song_ids),
56 "song_name": song_names, 81 "song_name": song_names,
57 "hard_weight": torch.stack(hard_weights), 82 "hard_weight": torch.stack(hard_weights),
58 } 83 }
84 if melodies_padded:
85 payload["melody"] = torch.cat(melodies_padded, dim=0)
86 if chromas_padded:
87 payload["chroma"] = torch.cat(chromas_padded, dim=0)
88 return payload
59 89
60 90
61 def train_epoch(model, loader, optimizer, criterion, scaler, device, epoch, cfg): 91 def train_epoch(model, loader, optimizer, criterion, scaler, device, epoch, cfg):
...@@ -64,10 +94,14 @@ def train_epoch(model, loader, optimizer, criterion, scaler, device, epoch, cfg) ...@@ -64,10 +94,14 @@ def train_epoch(model, loader, optimizer, criterion, scaler, device, epoch, cfg)
64 pbar = tqdm(loader, desc=f"Epoch {epoch}") 94 pbar = tqdm(loader, desc=f"Epoch {epoch}")
65 for batch in pbar: 95 for batch in pbar:
66 mel = batch["mel"].to(device) 96 mel = batch["mel"].to(device)
97 melody = batch.get("melody")
98 chroma = batch.get("chroma")
99 melody = melody.to(device) if melody is not None else None
100 chroma = chroma.to(device) if chroma is not None else None
67 labels = batch["song_id"].to(device) 101 labels = batch["song_id"].to(device)
68 102
69 with torch.amp.autocast("cuda", enabled=cfg["training"]["mixed_precision"] and device.type == "cuda"): 103 with torch.amp.autocast("cuda", enabled=cfg["training"]["mixed_precision"] and device.type == "cuda"):
70 embedding, logits = model(mel, labels) 104 embedding, logits = model(mel, labels, melody=melody, chroma=chroma)
71 loss_dict = criterion(embedding, logits, labels, labels, batch.get("hard_weight", None).to(device) if "hard_weight" in batch else None) 105 loss_dict = criterion(embedding, logits, labels, labels, batch.get("hard_weight", None).to(device) if "hard_weight" in batch else None)
72 106
73 optimizer.zero_grad() 107 optimizer.zero_grad()
...@@ -115,6 +149,28 @@ def save_checkpoint(output_dir, epoch, model, optimizer, best_metric, cfg, name) ...@@ -115,6 +149,28 @@ def save_checkpoint(output_dir, epoch, model, optimizer, best_metric, cfg, name)
115 print(f" Saved: {path}") 149 print(f" Saved: {path}")
116 150
117 151
152 def write_training_artifacts(output_dir: Path, cfg: dict, train_metrics: dict, train_dataset, args):
153 manifest = {
154 "timestamp": datetime.utcnow().isoformat() + "Z",
155 "config": cfg,
156 "output_dir": str(output_dir),
157 "train_song_count": len(train_dataset.song_ids),
158 "sample_count": len(train_dataset),
159 "segment_strategy": args.segment_strategy,
160 "noise_roots": args.noise_root,
161 "artifacts": {
162 "best_model": str(output_dir / "best_model.pt"),
163 "song_to_idx": str(output_dir / "song_to_idx.json"),
164 "metrics": str(output_dir / "training_metrics.json"),
165 },
166 "final_metrics": train_metrics,
167 }
168 with open(output_dir / "training_metrics.json", "w") as f:
169 json.dump(train_metrics, f, indent=2)
170 with open(output_dir / "training_manifest.json", "w") as f:
171 json.dump(manifest, f, indent=2)
172
173
118 def main(): 174 def main():
119 parser = argparse.ArgumentParser() 175 parser = argparse.ArgumentParser()
120 parser.add_argument("--config", type=str, default="configs/default.yaml") 176 parser.add_argument("--config", type=str, default="configs/default.yaml")
...@@ -125,6 +181,7 @@ def main(): ...@@ -125,6 +181,7 @@ def main():
125 parser.add_argument("--epochs", type=int, default=None) 181 parser.add_argument("--epochs", type=int, default=None)
126 parser.add_argument("--batch-size", type=int, default=None) 182 parser.add_argument("--batch-size", type=int, default=None)
127 parser.add_argument("--lr", type=float, default=None) 183 parser.add_argument("--lr", type=float, default=None)
184 parser.add_argument("--noise-root", action="append", default=[])
128 parser.add_argument("--segment-strategy", choices=["random", "silence_aware", "high_energy", "onset_aware", "beat_aware", "repeated_section_aware", "hybrid"], default="random") 185 parser.add_argument("--segment-strategy", choices=["random", "silence_aware", "high_energy", "onset_aware", "beat_aware", "repeated_section_aware", "hybrid"], default="random")
129 parser.add_argument("--silence-top-db", type=int, default=30) 186 parser.add_argument("--silence-top-db", type=int, default=30)
130 parser.add_argument("--dry-run", action="store_true") 187 parser.add_argument("--dry-run", action="store_true")
...@@ -159,6 +216,8 @@ def main(): ...@@ -159,6 +216,8 @@ def main():
159 silence_top_db=args.silence_top_db, 216 silence_top_db=args.silence_top_db,
160 sample_type_weights=cfg["training"].get("sample_type_weights"), 217 sample_type_weights=cfg["training"].get("sample_type_weights"),
161 pair_type_weights=cfg["training"].get("pair_type_weights"), 218 pair_type_weights=cfg["training"].get("pair_type_weights"),
219 hard_negative_k=cfg["training"].get("hard_negative_k", 2),
220 noise_roots=args.noise_root,
162 ) 221 )
163 222
164 catalog_dataset = ACRDataset( 223 catalog_dataset = ACRDataset(
...@@ -174,6 +233,7 @@ def main(): ...@@ -174,6 +233,7 @@ def main():
174 song_to_idx=train_dataset.song_to_idx, 233 song_to_idx=train_dataset.song_to_idx,
175 segment_strategy=args.segment_strategy, 234 segment_strategy=args.segment_strategy,
176 silence_top_db=args.silence_top_db, 235 silence_top_db=args.silence_top_db,
236 noise_roots=args.noise_root,
177 ) 237 )
178 238
179 train_loader = DataLoader( 239 train_loader = DataLoader(
...@@ -205,6 +265,11 @@ def main(): ...@@ -205,6 +265,11 @@ def main():
205 aam_s=cfg["model"]["aam_s"], 265 aam_s=cfg["model"]["aam_s"],
206 use_band_split=cfg["model"].get("use_band_split", True), 266 use_band_split=cfg["model"].get("use_band_split", True),
207 band_split_channels=cfg["model"].get("band_split_channels", 128), 267 band_split_channels=cfg["model"].get("band_split_channels", 128),
268 use_dual_stream=cfg["model"].get("use_dual_stream", True),
269 coverhunter_heads=cfg["model"].get("coverhunter_heads", 4),
270 coverhunter_layers=cfg["model"].get("coverhunter_layers", 2),
271 fusion_hidden_dim=cfg["model"].get("fusion_hidden_dim", 256),
272 mert_model_name=cfg["model"].get("mert_model_name"),
208 ).to(device) 273 ).to(device)
209 274
210 criterion = CombinedLoss( 275 criterion = CombinedLoss(
...@@ -219,8 +284,12 @@ def main(): ...@@ -219,8 +284,12 @@ def main():
219 print("Dry run: running one batch through forward/backward...") 284 print("Dry run: running one batch through forward/backward...")
220 batch = next(iter(train_loader)) 285 batch = next(iter(train_loader))
221 mel = batch["mel"].to(device) 286 mel = batch["mel"].to(device)
287 melody = batch.get("melody")
288 chroma = batch.get("chroma")
289 melody = melody.to(device) if melody is not None else None
290 chroma = chroma.to(device) if chroma is not None else None
222 labels = batch["song_id"].to(device) 291 labels = batch["song_id"].to(device)
223 embedding, logits = model(mel, labels) 292 embedding, logits = model(mel, labels, melody=melody, chroma=chroma)
224 loss_dict = criterion(embedding, logits, labels, labels, batch.get("hard_weight", None).to(device) if "hard_weight" in batch else None) 293 loss_dict = criterion(embedding, logits, labels, labels, batch.get("hard_weight", None).to(device) if "hard_weight" in batch else None)
225 loss_dict["loss"].backward() 294 loss_dict["loss"].backward()
226 print(f" Forward/backward OK. Loss: {loss_dict['loss']:.4f}") 295 print(f" Forward/backward OK. Loss: {loss_dict['loss']:.4f}")
...@@ -242,6 +311,7 @@ def main(): ...@@ -242,6 +311,7 @@ def main():
242 output_dir.mkdir(parents=True, exist_ok=True) 311 output_dir.mkdir(parents=True, exist_ok=True)
243 312
244 print("Starting training...") 313 print("Starting training...")
314 train_metrics = None
245 for epoch in range(start_epoch, cfg["training"]["epochs"] + 1): 315 for epoch in range(start_epoch, cfg["training"]["epochs"] + 1):
246 train_metrics = train_epoch(model, train_loader, optimizer, criterion, scaler, device, epoch, cfg) 316 train_metrics = train_epoch(model, train_loader, optimizer, criterion, scaler, device, epoch, cfg)
247 scheduler.step() 317 scheduler.step()
...@@ -254,6 +324,7 @@ def main(): ...@@ -254,6 +324,7 @@ def main():
254 324
255 with open(output_dir / "song_to_idx.json", "w") as f: 325 with open(output_dir / "song_to_idx.json", "w") as f:
256 json.dump(train_dataset.song_to_idx, f, indent=2) 326 json.dump(train_dataset.song_to_idx, f, indent=2)
327 write_training_artifacts(output_dir, cfg, train_metrics or {}, train_dataset, args)
257 print(f"\nTraining complete. Best training loss: {best_loss:.4f}") 328 print(f"\nTraining complete. Best training loss: {best_loss:.4f}")
258 print(f"Model saved to: {output_dir / 'best_model.pt'}") 329 print(f"Model saved to: {output_dir / 'best_model.pt'}")
259 print(f"Catalog references available: {len(catalog_dataset.samples)}") 330 print(f"Catalog references available: {len(catalog_dataset.samples)}")
......
1 {
2 "run_name": "coverhunter_finetune_20260608T130103Z",
3 "created_at": "2026-06-08T13:01:03.023371Z",
4 "python": "/usr/local/miniconda3/bin/python",
5 "command": [
6 "/usr/local/miniconda3/bin/python",
7 "train.py",
8 "--config",
9 "configs/coverhunter_finetune_4gb.yaml",
10 "--data",
11 "data/synthetic_v2",
12 "--output",
13 "data/training_runs/coverhunter_finetune_20260608T130103Z",
14 "--device",
15 "cpu",
16 "--segment-strategy",
17 "hybrid",
18 "--dry-run"
19 ],
20 "config": "configs/coverhunter_finetune_4gb.yaml",
21 "data": "data/synthetic_v2",
22 "noise_roots": [],
23 "run_dir": "data/training_runs/coverhunter_finetune_20260608T130103Z"
24 }
...\ No newline at end of file ...\ No newline at end of file
1 {
2 "run_name": "coverhunter_finetune_20260608T130103Z",
3 "created_at": "2026-06-08T13:01:03.023371Z",
4 "python": "/usr/local/miniconda3/bin/python",
5 "command": [
6 "/usr/local/miniconda3/bin/python",
7 "train.py",
8 "--config",
9 "configs/coverhunter_finetune_4gb.yaml",
10 "--data",
11 "data/synthetic_v2",
12 "--output",
13 "data/training_runs/coverhunter_finetune_20260608T130103Z",
14 "--device",
15 "cpu",
16 "--segment-strategy",
17 "hybrid",
18 "--dry-run"
19 ],
20 "config": "configs/coverhunter_finetune_4gb.yaml",
21 "data": "data/synthetic_v2",
22 "noise_roots": [],
23 "run_dir": "data/training_runs/coverhunter_finetune_20260608T130103Z",
24 "returncode": 1,
25 "completed_at": "2026-06-08T13:01:32.762576Z",
26 "artifacts": [
27 "run_request.json",
28 "stderr.log",
29 "stdout.log"
30 ]
31 }
...\ No newline at end of file ...\ No newline at end of file
1 Failed to import fast_mp3_augment. Maybe it is not installed? To install the optional fast_mp3_augment dependency of audiomentations, run `pip install audiomentations[extras]` or simply `pip install fast_mp3_augment`
2 Traceback (most recent call last):
3 File "/mnt/e/hikoon-ACR/acr-engine/train.py", line 334, in <module>
4 main()
5 File "/mnt/e/hikoon-ACR/acr-engine/train.py", line 249, in main
6 batch = next(iter(train_loader))
7 ^^^^^^^^^^^^^^^^^^^^^^^^
8 File "/home/user/.local/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 718, in __next__
9 data = self._next_data()
10 ^^^^^^^^^^^^^^^^^
11 File "/home/user/.local/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 778, in _next_data
12 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
13 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
14 File "/home/user/.local/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
15 data = [self.dataset[idx] for idx in possibly_batched_index]
16 ~~~~~~~~~~~~^^^^^
17 File "/mnt/e/hikoon-ACR/acr-engine/src/data/dataset.py", line 370, in __getitem__
18 positive_features = [self._load_features(sample) for sample in positive_items]
19 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
20 File "/mnt/e/hikoon-ACR/acr-engine/src/data/dataset.py", line 344, in _load_features
21 features = self.feature_extractor.extract(y)
22 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
23 File "/mnt/e/hikoon-ACR/acr-engine/src/data/dataset.py", line 138, in extract
24 melody = librosa.hz_to_midi(melody, bins_per_octave=12)
25 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
26 TypeError: hz_to_midi() got an unexpected keyword argument 'bins_per_octave'
1 {
2 "run_name": "coverhunter_finetune_20260608T130306Z",
3 "created_at": "2026-06-08T13:03:06.790814Z",
4 "python": "/usr/local/miniconda3/bin/python",
5 "command": [
6 "/usr/local/miniconda3/bin/python",
7 "train.py",
8 "--config",
9 "configs/coverhunter_finetune_4gb.yaml",
10 "--data",
11 "data/synthetic_v2",
12 "--output",
13 "data/training_runs/coverhunter_finetune_20260608T130306Z",
14 "--device",
15 "cpu",
16 "--segment-strategy",
17 "hybrid",
18 "--dry-run"
19 ],
20 "config": "configs/coverhunter_finetune_4gb.yaml",
21 "data": "data/synthetic_v2",
22 "noise_roots": [],
23 "run_dir": "data/training_runs/coverhunter_finetune_20260608T130306Z"
24 }
...\ No newline at end of file ...\ No newline at end of file
1 {
2 "run_name": "coverhunter_finetune_20260608T130306Z",
3 "created_at": "2026-06-08T13:03:06.790814Z",
4 "python": "/usr/local/miniconda3/bin/python",
5 "command": [
6 "/usr/local/miniconda3/bin/python",
7 "train.py",
8 "--config",
9 "configs/coverhunter_finetune_4gb.yaml",
10 "--data",
11 "data/synthetic_v2",
12 "--output",
13 "data/training_runs/coverhunter_finetune_20260608T130306Z",
14 "--device",
15 "cpu",
16 "--segment-strategy",
17 "hybrid",
18 "--dry-run"
19 ],
20 "config": "configs/coverhunter_finetune_4gb.yaml",
21 "data": "data/synthetic_v2",
22 "noise_roots": [],
23 "run_dir": "data/training_runs/coverhunter_finetune_20260608T130306Z",
24 "returncode": 1,
25 "completed_at": "2026-06-08T13:04:34.035140Z",
26 "artifacts": [
27 "run_request.json",
28 "stderr.log",
29 "stdout.log"
30 ]
31 }
...\ No newline at end of file ...\ No newline at end of file
1 /home/user/.local/lib/python3.12/site-packages/librosa/core/convert.py:1094: RuntimeWarning: divide by zero encountered in log2
2 midi: np.ndarray = 12 * (np.log2(np.asanyarray(frequencies)) - np.log2(440.0)) + 69
3 Failed to import fast_mp3_augment. Maybe it is not installed? To install the optional fast_mp3_augment dependency of audiomentations, run `pip install audiomentations[extras]` or simply `pip install fast_mp3_augment`
4 /home/user/.local/lib/python3.12/site-packages/librosa/core/convert.py:1094: RuntimeWarning: divide by zero encountered in log2
5 midi: np.ndarray = 12 * (np.log2(np.asanyarray(frequencies)) - np.log2(440.0)) + 69
6 Failed to import fast_mp3_augment. Maybe it is not installed? To install the optional fast_mp3_augment dependency of audiomentations, run `pip install audiomentations[extras]` or simply `pip install fast_mp3_augment`
7 Failed to import fast_mp3_augment. Maybe it is not installed? To install the optional fast_mp3_augment dependency of audiomentations, run `pip install audiomentations[extras]` or simply `pip install fast_mp3_augment`
8 Failed to import fast_mp3_augment. Maybe it is not installed? To install the optional fast_mp3_augment dependency of audiomentations, run `pip install audiomentations[extras]` or simply `pip install fast_mp3_augment`
9 Failed to import fast_mp3_augment. Maybe it is not installed? To install the optional fast_mp3_augment dependency of audiomentations, run `pip install audiomentations[extras]` or simply `pip install fast_mp3_augment`
10 '[Errno 101] Network is unreachable' thrown while requesting HEAD https://huggingface.co/m-a-p/MERT-v1-95M/resolve/main/config.json
11 Retrying in 1s [Retry 1/5].
12 Traceback (most recent call last):
13 File "/mnt/e/hikoon-ACR/acr-engine/train.py", line 334, in <module>
14 main()
15 File "/mnt/e/hikoon-ACR/acr-engine/train.py", line 256, in main
16 model = ECAPA_ACR(
17 ^^^^^^^^^^
18 File "/mnt/e/hikoon-ACR/acr-engine/src/models/ecapa_tdnn.py", line 280, in __init__
19 self.mert_melody_branch = MERTMelodyBranch(
20 ^^^^^^^^^^^^^^^^^
21 File "/mnt/e/hikoon-ACR/acr-engine/src/models/ecapa_tdnn.py", line 211, in __init__
22 self.mert = FrozenMERTFeatureExtractor(model_name=mert_model_name, n_mels=n_mels, hidden_dim=hidden_dim)
23 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
24 File "/mnt/e/hikoon-ACR/acr-engine/src/models/ecapa_tdnn.py", line 21, in __init__
25 self.backbone = AutoModel.from_pretrained(model_name)
26 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
27 File "/home/user/.local/lib/python3.12/site-packages/transformers/models/auto/auto_factory.py", line 289, in from_pretrained
28 resolved_config_file = cached_file(
29 ^^^^^^^^^^^^
30 File "/home/user/.local/lib/python3.12/site-packages/transformers/utils/hub.py", line 293, in cached_file
31 file = cached_files(path_or_repo_id=path_or_repo_id, filenames=[filename], **kwargs)
32 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
33 File "/home/user/.local/lib/python3.12/site-packages/transformers/utils/hub.py", line 527, in cached_files
34 raise e
35 File "/home/user/.local/lib/python3.12/site-packages/transformers/utils/hub.py", line 437, in cached_files
36 hf_hub_download(
37 File "/home/user/.local/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py", line 88, in _inner_fn
38 return fn(*args, **kwargs)
39 ^^^^^^^^^^^^^^^^^^^
40 File "/home/user/.local/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 1019, in hf_hub_download
41 return _hf_hub_download_to_cache_dir(
42 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
43 File "/home/user/.local/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 1152, in _hf_hub_download_to_cache_dir
44 _get_metadata_or_catch_error(
45 File "/home/user/.local/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 1694, in _get_metadata_or_catch_error
46 metadata = get_hf_file_metadata(
47 ^^^^^^^^^^^^^^^^^^^^^
48 File "/home/user/.local/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py", line 88, in _inner_fn
49 return fn(*args, **kwargs)
50 ^^^^^^^^^^^^^^^^^^^
51 File "/home/user/.local/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 1616, in get_hf_file_metadata
52 response = _httpx_follow_relative_redirects_with_backoff(
53 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
54 File "/home/user/.local/lib/python3.12/site-packages/huggingface_hub/utils/_http.py", line 685, in _httpx_follow_relative_redirects_with_backoff
55 response = http_backoff(
56 ^^^^^^^^^^^^^
57 File "/home/user/.local/lib/python3.12/site-packages/huggingface_hub/utils/_http.py", line 559, in http_backoff
58 return next(
59 ^^^^^
60 File "/home/user/.local/lib/python3.12/site-packages/huggingface_hub/utils/_http.py", line 467, in _http_backoff_base
61 response = client.request(method=method, url=url, **kwargs)
62 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
63 File "/usr/local/miniconda3/lib/python3.12/site-packages/httpx/_client.py", line 825, in request
64 return self.send(request, auth=auth, follow_redirects=follow_redirects)
65 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
66 File "/usr/local/miniconda3/lib/python3.12/site-packages/httpx/_client.py", line 901, in send
67 raise RuntimeError("Cannot send a request, as the client has been closed.")
68 RuntimeError: Cannot send a request, as the client has been closed.
1 Device: cpu
2 Dry batch shape: torch.Size([6, 96, 501]) torch.Size([6])
3 Classes: 16
4 Train songs: 64
1 {
2 "run_name": "coverhunter_finetune_20260608T130514Z",
3 "created_at": "2026-06-08T13:05:14.591209Z",
4 "python": "/usr/local/miniconda3/bin/python",
5 "command": [
6 "/usr/local/miniconda3/bin/python",
7 "train.py",
8 "--config",
9 "configs/coverhunter_finetune_4gb.yaml",
10 "--data",
11 "data/synthetic_v2",
12 "--output",
13 "data/training_runs/coverhunter_finetune_20260608T130514Z",
14 "--device",
15 "cpu",
16 "--segment-strategy",
17 "hybrid",
18 "--dry-run"
19 ],
20 "config": "configs/coverhunter_finetune_4gb.yaml",
21 "data": "data/synthetic_v2",
22 "noise_roots": [],
23 "run_dir": "data/training_runs/coverhunter_finetune_20260608T130514Z"
24 }
...\ No newline at end of file ...\ No newline at end of file
1 {
2 "run_name": "coverhunter_finetune_20260608T130514Z",
3 "created_at": "2026-06-08T13:05:14.591209Z",
4 "python": "/usr/local/miniconda3/bin/python",
5 "command": [
6 "/usr/local/miniconda3/bin/python",
7 "train.py",
8 "--config",
9 "configs/coverhunter_finetune_4gb.yaml",
10 "--data",
11 "data/synthetic_v2",
12 "--output",
13 "data/training_runs/coverhunter_finetune_20260608T130514Z",
14 "--device",
15 "cpu",
16 "--segment-strategy",
17 "hybrid",
18 "--dry-run"
19 ],
20 "config": "configs/coverhunter_finetune_4gb.yaml",
21 "data": "data/synthetic_v2",
22 "noise_roots": [],
23 "run_dir": "data/training_runs/coverhunter_finetune_20260608T130514Z",
24 "returncode": 1,
25 "completed_at": "2026-06-08T13:06:50.272162Z",
26 "artifacts": [
27 "run_request.json",
28 "stderr.log",
29 "stdout.log"
30 ]
31 }
...\ No newline at end of file ...\ No newline at end of file
1 /home/user/.local/lib/python3.12/site-packages/librosa/core/convert.py:1094: RuntimeWarning: divide by zero encountered in log2
2 midi: np.ndarray = 12 * (np.log2(np.asanyarray(frequencies)) - np.log2(440.0)) + 69
3 Failed to import fast_mp3_augment. Maybe it is not installed? To install the optional fast_mp3_augment dependency of audiomentations, run `pip install audiomentations[extras]` or simply `pip install fast_mp3_augment`
4 Failed to import fast_mp3_augment. Maybe it is not installed? To install the optional fast_mp3_augment dependency of audiomentations, run `pip install audiomentations[extras]` or simply `pip install fast_mp3_augment`
5 Failed to import fast_mp3_augment. Maybe it is not installed? To install the optional fast_mp3_augment dependency of audiomentations, run `pip install audiomentations[extras]` or simply `pip install fast_mp3_augment`
6 '[Errno 101] Network is unreachable' thrown while requesting HEAD https://huggingface.co/m-a-p/MERT-v1-95M/resolve/main/config.json
7 Retrying in 1s [Retry 1/5].
8 Failed to import fast_mp3_augment. Maybe it is not installed? To install the optional fast_mp3_augment dependency of audiomentations, run `pip install audiomentations[extras]` or simply `pip install fast_mp3_augment`
9 Failed to import fast_mp3_augment. Maybe it is not installed? To install the optional fast_mp3_augment dependency of audiomentations, run `pip install audiomentations[extras]` or simply `pip install fast_mp3_augment`
10 Failed to import fast_mp3_augment. Maybe it is not installed? To install the optional fast_mp3_augment dependency of audiomentations, run `pip install audiomentations[extras]` or simply `pip install fast_mp3_augment`
11 Traceback (most recent call last):
12 File "/mnt/e/hikoon-ACR/acr-engine/train.py", line 334, in <module>
13 main()
14 File "/mnt/e/hikoon-ACR/acr-engine/train.py", line 292, in main
15 embedding, logits = model(mel, labels, melody=melody, chroma=chroma)
16 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
17 File "/home/user/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
18 return self._call_impl(*args, **kwargs)
19 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
20 File "/home/user/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl
21 return forward_call(*args, **kwargs)
22 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
23 File "/mnt/e/hikoon-ACR/acr-engine/src/models/ecapa_tdnn.py", line 351, in forward
24 mert_stream = self.mert_melody_branch(mel, melody, chroma)
25 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
26 File "/home/user/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
27 return self._call_impl(*args, **kwargs)
28 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
29 File "/home/user/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl
30 return forward_call(*args, **kwargs)
31 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
32 File "/mnt/e/hikoon-ACR/acr-engine/src/models/ecapa_tdnn.py", line 224, in forward
33 semantic = self.mert(mert)
34 ^^^^^^^^^^^^^^^
35 File "/home/user/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
36 return self._call_impl(*args, **kwargs)
37 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
38 File "/home/user/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl
39 return forward_call(*args, **kwargs)
40 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
41 File "/mnt/e/hikoon-ACR/acr-engine/src/models/ecapa_tdnn.py", line 49, in forward
42 return self.proj(mel)
43 ^^^^^^^^^^^^^^
44 TypeError: 'NoneType' object is not callable
1 Device: cpu
2 Dry batch shape: torch.Size([6, 96, 501]) torch.Size([6])
3 Classes: 16
4 Train songs: 64
5 Dry run: running one batch through forward/backward...
1 {
2 "run_name": "coverhunter_finetune_20260608T130731Z",
3 "created_at": "2026-06-08T13:07:31.311447Z",
4 "python": "/usr/local/miniconda3/bin/python",
5 "command": [
6 "/usr/local/miniconda3/bin/python",
7 "train.py",
8 "--config",
9 "configs/coverhunter_finetune_4gb.yaml",
10 "--data",
11 "data/synthetic_v2",
12 "--output",
13 "data/training_runs/coverhunter_finetune_20260608T130731Z",
14 "--device",
15 "cpu",
16 "--segment-strategy",
17 "hybrid",
18 "--dry-run"
19 ],
20 "config": "configs/coverhunter_finetune_4gb.yaml",
21 "data": "data/synthetic_v2",
22 "noise_roots": [],
23 "run_dir": "data/training_runs/coverhunter_finetune_20260608T130731Z"
24 }
...\ No newline at end of file ...\ No newline at end of file
1 # CoverHunter 环境安装与验证
2
3 ## 1. 目标解释器
4
5 本专题统一使用:
6
7 ```bash
8 /usr/local/miniconda3/bin/python
9 ```
10
11 ## 2. 自动化脚本
12
13 已新增环境安装与验证脚本:
14
15 ```text
16 acr-engine/scripts/setup_coverhunter_env.py
17 ```
18
19 执行方式:
20
21 ```bash
22 /usr/local/miniconda3/bin/python acr-engine/scripts/setup_coverhunter_env.py
23 ```
24
25 它会自动:
26
27 1. 安装 `requirements.txt`
28 2. 补充训练依赖:
29 - `torch`
30 - `torchaudio`
31 - `transformers`
32 - `huggingface_hub`
33 - `librosa`
34 - `soundfile`
35 - `audiomentations`
36 3. 进行环境验证
37 4. 生成报告:
38
39 ```text
40 acr-engine/reports/coverhunter_env_setup_report.json
41 ```
42
43 ## 3. 当前自动化执行结果
44
45 本次已经自动执行完成。
46
47 报告文件:
48
49 ```text
50 acr-engine/reports/coverhunter_env_setup_report.json
51 ```
52
53 当前结论:
54
55 - Python 包安装:**成功**
56 - `torch` / `transformers` / `librosa` / `soundfile` / `audiomentations`**已安装**
57 -`torch.cuda.is_available()` 当前返回:**False**
58
59 ## 4. 当前 GPU 阻塞点
60
61 虽然系统存在 NVIDIA GPU,且 `nvidia-smi` 可见设备,但当前 PyTorch CUDA 初始化失败。
62
63 报告中的核心告警是:
64
65 - **The NVIDIA driver on your system is too old**
66
67 这说明:
68
69 - 当前安装到环境里的 `torch 2.12.0+cu130`
70 - 与当前系统驱动版本不兼容
71
72 也就是说:
73
74 - **环境依赖已经安装好了**
75 - **但当前 GPU 训练还不能真正启用**
76 - 原因不是代码问题,而是 **PyTorch CUDA 版本与驱动版本不匹配**
77
78 ## 5. 当前状态怎么理解
79
80 现在的环境状态可以分成两部分:
81
82 ### 已经完成的
83
84 - 训练依赖已安装
85 - 训练脚本可执行
86 - MERT / ECAPA 双流代码可 import
87 - 文档和配置已准备好
88
89 ### 仍未完成的
90
91 - CUDA 版 torch 与当前 NVIDIA driver 的匹配
92
93 ## 6. 下一步建议
94
95 要让 GPU 真正可用,需要二选一:
96
97 ### 方案 A:升级 NVIDIA 驱动
98
99 优点:
100
101 - 可以保留当前较新的 torch/cu130 组合
102 - 后续兼容性更好
103
104 ### 方案 B:安装与当前驱动兼容的更低 CUDA 版本 torch
105
106 优点:
107
108 - 不改系统驱动
109 - 更适合当前机器直接落地
110
111 对当前项目而言,我更建议:
112
113 - **优先采用方案 B**
114 - 安装与当前驱动兼容的 torch 版本
115
116 ## 7. 当前专题与环境文档关系
117
118 配套文件如下:
119
120 - 训练专题:`docs/coverhunter_finetune_topic.md`
121 - 训练流程:`docs/coverhunter_training_process.md`
122 - 环境文档:`docs/coverhunter_env_setup.md`
123 - 环境报告:`acr-engine/reports/coverhunter_env_setup_report.json`
124
125 ## 8. 当前结论
126
127 当前已经自动完成:
128
129 - 环境依赖安装
130 - 环境验证
131 - 结果记录
132
133 目前唯一阻塞 GPU 训练的点是:
134
135 - **CUDA / 驱动 / torch 版本不匹配**
1 # CoverHunter 双流微调标准流程
2
3 ## 1. 当前架构
4
5 当前训练架构已经调整为双流:
6
7 - **流 A:MERT + Melody 分支**
8 - 代码位置:`acr-engine/src/models/ecapa_tdnn.py`
9 - 逻辑:冻结的 `FrozenMERTFeatureExtractor` + `melody/chroma` 融合
10 - 默认模型:`m-a-p/MERT-v1-95M`
11 - 说明:当前代码已经支持真实 HuggingFace MERT 权重接入;若环境里缺少 `transformers` 或首次拉取失败,则无法启用真实 MERT
12 - **流 B:ECAPA 分支**
13 - 逻辑:保留 ECAPA 特征建模路径
14 - **双流融合**
15 - `DualStreamFusion`
16 - **检索头**
17 - `CoverHunterHead`
18 - **训练目标**
19 - `InfoNCE + AAMSoftmax`
20
21 ## 2. 当前资源检查结论
22
23 ### Python 解释器
24
25 训练入口已固定支持:
26
27 ```bash
28 /usr/local/miniconda3/bin/python
29 ```
30
31 `acr-engine/scripts/run_coverhunter_finetune.py` 已支持 `--python` 参数,默认就是这个解释器。
32
33 ### GPU
34
35 当前检测到 GPU:
36
37 - **Quadro P1000**
38 - 总显存:**4096 MiB**
39 - 空闲显存:约 **3817 MiB**
40
41 结论:
42
43 - **可以跑训练**
44 - 但显存较小,建议:
45 - `batch_size=2~4`
46 - `segment_dur=5.0` 起步
47 - 优先做 dry-run、小批量试跑、再正式训练
48 - 启用真实 MERT 后不要直接上大 batch
49
50 ### 数据
51
52 当前仓库中可直接用于冒烟训练的数据:
53
54 - `acr-engine/data/synthetic_v2/train.json`
55 - 音频切片位于 `acr-engine/data/synthetic_v2/segments/`
56
57 这些数据已经包含:
58
59 - 普通切片
60 - augmented
61 - humming_like
62 - confused
63
64 适合先做流程验证。
65
66 ### 当前环境缺口
67
68 `/usr/local/miniconda3/bin/python` 下当前缺少这些核心包:
69
70 - `torch`
71 - `transformers`
72 - `huggingface_hub`
73 - `torchaudio`
74 - `librosa`
75 - `soundfile`
76 - `audiomentations`
77
78 所以:
79
80 - **GPU 与解释器可用**
81 - **但当前训练环境还不能直接跑**
82 - 需要先补齐依赖
83
84 ## 3. 标准处理流程
85
86 ### Step 1:准备 Python 环境
87
88 进入项目后,先确保用的是目标解释器:
89
90 ```bash
91 /usr/local/miniconda3/bin/python --version
92 ```
93
94 安装依赖:
95
96 ```bash
97 /usr/local/miniconda3/bin/python -m pip install -r acr-engine/requirements.txt
98 ```
99
100 如需单独补装:
101
102 ```bash
103 /usr/local/miniconda3/bin/python -m pip install torch torchaudio transformers huggingface_hub librosa soundfile audiomentations
104 ```
105
106 ### Step 2:准备 MERT 权重缓存
107
108 首次启用真实 MERT 时,会从 HuggingFace 拉取:
109
110 - `m-a-p/MERT-v1-95M`
111
112 建议先确认网络可访问 HuggingFace,或提前缓存模型。
113
114 如果不希望改默认配置,可以在 `configs/default.yaml``configs/coverhunter_finetune.yaml` 中调整:
115
116 ```yaml
117 model:
118 mert_model_name: m-a-p/MERT-v1-95M
119 ```
120
121 ### Step 3:准备噪声数据
122
123 为了支持伪造录音增强,建议准备目录,例如:
124
125 ```text
126 acr-engine/data/noise/restaurant/
127 acr-engine/data/noise/street/
128 ```
129
130 里面放公开可用环境音频:
131
132 - 餐厅底噪
133 - 街道底噪
134 - 室内人声背景
135
136 训练时通过:
137
138 ```bash
139 --noise-root acr-engine/data/noise/restaurant \
140 --noise-root acr-engine/data/noise/street
141 ```
142
143 传入。
144
145 ### Step 4:先做 dry-run
146
147 先验证数据、模型、GPU、增强链路是否都通:
148
149 ```bash
150 cd /mnt/e/hikoon-ACR/acr-engine && \
151 /usr/local/miniconda3/bin/python scripts/run_coverhunter_finetune.py \
152 --python /usr/local/miniconda3/bin/python \
153 --data data/synthetic_v2 \
154 --device cuda \
155 --segment-strategy hybrid \
156 --dry-run
157 ```
158
159 ### Step 5:小规模试训
160
161 建议先缩小 batch/config,确认显存稳定:
162
163 ```bash
164 cd /mnt/e/hikoon-ACR/acr-engine && \
165 /usr/local/miniconda3/bin/python train.py \
166 --config configs/coverhunter_finetune.yaml \
167 --data data/synthetic_v2 \
168 --output data/training_runs/coverhunter_trial \
169 --device cuda \
170 --segment-strategy hybrid \
171 --batch-size 2 \
172 --epochs 2 \
173 --noise-root data/noise/restaurant \
174 --noise-root data/noise/street
175 ```
176
177 如果显存稳定,再逐步提高到:
178
179 - `batch_size=4`
180 - 必要时再尝试 `batch_size=6`
181
182 ### Step 6:正式专题训练
183
184 标准命令:
185
186 ```bash
187 cd /mnt/e/hikoon-ACR/acr-engine && \
188 /usr/local/miniconda3/bin/python scripts/run_coverhunter_finetune.py \
189 --python /usr/local/miniconda3/bin/python \
190 --data data/synthetic_v2 \
191 --device cuda \
192 --segment-strategy hybrid \
193 --noise-root data/noise/restaurant \
194 --noise-root data/noise/street
195 ```
196
197 ### Step 7:检查训练产物
198
199 每次训练会记录到:
200
201 ```text
202 acr-engine/data/training_runs/<run_name>/
203 ```
204
205 标准产物包括:
206
207 - `best_model.pt`
208 - `checkpoint_epoch_*.pt`
209 - `song_to_idx.json`
210 - `training_metrics.json`
211 - `training_manifest.json`
212 - `run_request.json`
213 - `run_summary.json`
214 - `stdout.log`
215 - `stderr.log`
216
217 ## 4. 增强策略说明
218
219 当前代码已经覆盖两类伪造策略:
220
221 ### 伪造录音
222
223 位置:`acr-engine/src/utils/augment.py`
224
225 - `AddGaussianNoise`
226 - `AddBackgroundNoise`
227 - `BandPassFilter`
228 - `Mp3Compression`
229
230 ### 伪造翻唱
231
232 位置:`acr-engine/src/utils/augment.py`
233
234 - `PitchShift`
235 - `TimeStretch`
236 - `Frequency Masking`(作用于 mel)
237
238 ## 5. 资源适配建议
239
240 由于当前 GPU 是 Quadro P1000 4GB,建议按以下梯度推进:
241
242 ### 推荐起步配置
243
244 - `segment_dur=5.0`
245 - `batch_size=2`
246 - `mixed_precision=true`
247 - `num_workers=0`
248
249 ### 稳定后可尝试
250
251 - `batch_size=4`
252 - 如 OOM 则回退
253
254 ### 当前不建议
255
256 - 直接上 8 秒片段 + batch 16
257 - 真实 MERT + 大 batch 同时启用
258
259 ## 6. 当前结论
260
261 当前状态可以概括为:
262
263 - **架构方向已经调整正确**:双流
264 - **真实 MERT 接口已接入**:是
265 - **GPU 可以用于训练**:是
266 - **当前 Python 解释器可用**:是,`/usr/local/miniconda3/bin/python`
267 - **当前环境能否立刻开训****还不能**,因为依赖未装全
268 - **现有数据能否支撑一波流程训练****可以**,先从 `synthetic_v2` 开始
1 # 音乐翻唱检测与音频片段检索系统 (CSI) 核心能力结构清单
2
3 ## 1. 核心架构逻辑
4 * **底座 (Backbone)**:MERT (冻结预训练权重) - 负责音频语义理解。
5 * **头部 (Head)**:CoverHunter (可训练 Conformer+Attention) - 负责旋律与结构的对比学习。
6 * **对齐方式**:双流融合 (MERT 语义特征 + Melody/Chroma 旋律特征)。
7
8 ## 2. 数据与特征工程 (Data Pipeline)
9 * **数据集结构**:以 `Song_ID` 为唯一键,物理隔离原曲、压缩版、录音与环境音。
10 * **动态增强 (Data Augmentation)**
11 * 物理扰动:音高平移 (Pitch Shifting)、变速 (Time Stretching)。
12 * 环境注入:背景噪声混入 (Environment Injection)。
13 * 频率掩码:频段擦除 (Frequency Masking) - 逼迫模型脱离音色依赖,转向旋律核心。
14 * **数据对齐**:使用插值 (Interpolation) 将 MERT 序列长度与 Melody 序列长度对齐至一致的 `Time_Steps`
15
16 ## 3. 训练与优化策略 (Training Strategy)
17 * **样本采样 (Sampler)**:PairSampler - 确保 Batch 中包含强配对的“原曲-翻唱”与精心挑选的“原曲-难负样本”。
18 * **难负样本挖掘 (Hard Negative Mining)**
19 * 使用冻结 MERT + Faiss 构建初始索引。
20 * 挖掘曲风相似但旋律不同的“假孪生兄弟”歌曲作为 Negative 样本。
21 * **损失函数 (Loss Function)**:InfoNCE Contrastive Loss - 拉近正样本余弦距离,推远负样本余弦距离。
22
23 ## 4. 推理与检索引擎 (Inference & Retrieval)
24 * **离线建库**:全量原曲切片 -> 特征提取 -> 存入向量数据库 (Faiss/Milvus)。
25 * **在线查询**:录音片段 -> 滑动窗口切片 -> 提取 Embedding -> 近似最近邻检索 (ANN)。
26 * **鲁棒性机制**:切片投票机制 (Slice Voting) - 对查询录音切片所得的 Top-K 结果进行统计,按票数加权归一化排序。
27
28 ## 5. 工程化关键节点 (Engineering Checklist)
29 * **计算优化**:离线特征缓存 (预先存储 .npy 减少 GPU 实时计算压力)。
30 * **部署优化**:ONNX/TensorRT 模型编译 + 动态批处理 (Dynamic Batching)。
31 * **数据飞轮**:在线难例挖掘 (基于用户反馈的 False Positives 循环重训)。