Commit 4b16286e 4b16286e20856132abbe8cfeafab1af1ee23c0ce by cnb.bofCdSsphPA

Raise ACR robustness with retrieval-first structure and music-aware inputs

Shift the prototype toward music-retrieval behavior by documenting dataset contracts, upgrading the frontend to 128-bin Mel plus band splitting, and adding retrieval evaluation plus harder confusion-oriented augmentation.

Constraint: The previous pipeline mixed train splits with the searchable catalog and hid real retrieval quality
Rejected: Keep classification-centric validation and whole-song averaged references | it masked structural accuracy failures
Confidence: medium
Scope-risk: moderate
Directive: Next iterations should target humming/confused top1 with specialized melody-aware retrieval and stronger real-data calibration
Tested: synthetic_v2 generation; 3-epoch CPU training; index build; evaluate.py top1=0.65 top5=0.95 on test split
Not-tested: external open-dataset ingestion; foundation-model baselines; production latency
1 parent 62688d3b
......@@ -5,9 +5,11 @@ model:
se_channels: 128
res2net_scale: 8
num_blocks: 3
n_mels: 80
n_mels: 128
aam_m: 0.3
aam_s: 30.0
use_band_split: true
band_split_channels: 128
data:
sample_rate: 16000
......@@ -39,3 +41,8 @@ engine:
chroma_weight: 0.3
ecapa_weight: 0.7
reject_threshold: 0.4
augmentation:
pro_wgan_balance: true
minority_noise_scale: 0.35
minority_pitch_shift: 8
......
......@@ -229,9 +229,12 @@ class SongPairDataset(Dataset):
y = self._load_clip(sample)
if self.augment:
from src.utils.augment import AugmentPipeline
y = AugmentPipeline(self.sr)(y)
y = AugmentPipeline(self.sr, aggressive=sample.get("type") in {"confused", "humming_like"})(y)
wavs.append(self._to_mel(y))
max_t = max(w.shape[1] for w in wavs)
wavs = [torch.nn.functional.pad(w, (0, max_t - w.shape[1])) if w.shape[1] < max_t else w for w in wavs]
label = self.song_to_idx[song_id]
return {
"mel": torch.stack(wavs, dim=0),
......
......@@ -28,14 +28,20 @@ class ECAPAEmbedder:
state = torch.load(model_path, map_location="cpu", weights_only=True)
cfg = state.get("config", {})
model_cfg = cfg.get("model", {})
data_cfg = cfg.get("data", {})
self.n_mels = model_cfg.get("n_mels", n_mels)
self.n_fft = data_cfg.get("n_fft", n_fft)
self.hop_length = data_cfg.get("hop_length", hop_length)
self.model = ECAPA_ACR(
n_mels=model_cfg.get("n_mels", n_mels),
n_mels=self.n_mels,
embed_dim=model_cfg.get("embed_dim", 192),
channels=model_cfg.get("channels", 512),
se_channels=model_cfg.get("se_channels", 128),
res2net_scale=model_cfg.get("res2net_scale", 8),
num_blocks=model_cfg.get("num_blocks", 3),
num_classes=None,
use_band_split=model_cfg.get("use_band_split", True),
band_split_channels=model_cfg.get("band_split_channels", 128),
)
missing = self.model.load_state_dict(state["model_state_dict"], strict=False)
if missing.unexpected_keys:
......
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from typing import Optional, Tuple, List
class SEModule(nn.Module):
......@@ -19,13 +19,43 @@ class SEModule(nn.Module):
return x * self.se(x)
class BandSplitBlock(nn.Module):
def __init__(self, n_mels: int, split_points: Optional[List[int]] = None, out_channels: int = 128):
super().__init__()
self.split_points = split_points or [16, 32, 64, 96, n_mels]
starts = [0] + self.split_points[:-1]
widths = [end - start for start, end in zip(starts, self.split_points)]
self.band_projs = nn.ModuleList(
[
nn.Sequential(
nn.Conv1d(width, out_channels, kernel_size=1),
nn.ReLU(),
nn.BatchNorm1d(out_channels),
)
for width in widths
]
)
self.fuse = nn.Sequential(
nn.Conv1d(out_channels * len(widths), out_channels * len(widths), kernel_size=1),
nn.ReLU(),
nn.BatchNorm1d(out_channels * len(widths)),
)
def forward(self, x):
starts = [0] + self.split_points[:-1]
bands = []
for proj, start, end in zip(self.band_projs, starts, self.split_points):
bands.append(proj(x[:, start:end, :]))
return self.fuse(torch.cat(bands, dim=1))
class Res2Block(nn.Module):
def __init__(self, channels, kernel_size=3, dilation=1, scale=8, se_channels=128):
super().__init__()
self.width = channels // scale
self.num_split = scale
self.convs = nn.ModuleList()
for i in range(self.num_split):
for _ in range(self.num_split):
self.convs.append(
nn.Sequential(
nn.Conv1d(
......@@ -54,7 +84,7 @@ class Res2Block(nn.Module):
if i == 0:
out.append(conv(part))
else:
out.append(conv(out[-1] if len(out) else part + part))
out.append(conv(part + out[-1]))
x = torch.cat(out, dim=1)
x = self.conv1x1(x)
x = self.se(x)
......@@ -96,7 +126,7 @@ class AAMSoftmax(nn.Module):
class ECAPA_ACR(nn.Module):
def __init__(
self,
n_mels: int = 80,
n_mels: int = 128,
embed_dim: int = 192,
channels: int = 512,
se_channels: int = 128,
......@@ -105,20 +135,23 @@ class ECAPA_ACR(nn.Module):
num_classes: Optional[int] = None,
aam_m: float = 0.3,
aam_s: float = 30.0,
use_band_split: bool = True,
band_split_channels: int = 128,
):
super().__init__()
self.embed_dim = embed_dim
front_channels = band_split_channels * 5 if use_band_split else channels
self.band_split = BandSplitBlock(n_mels=n_mels, out_channels=band_split_channels) if use_band_split else None
self.conv1 = nn.Sequential(
nn.Conv1d(n_mels, channels, kernel_size=5, stride=1, padding=2),
nn.Conv1d(front_channels, channels, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.BatchNorm1d(channels),
)
dilations = [1, 2, 3] if num_blocks == 3 else [d * 1 for d in range(1, num_blocks + 1)]
self.blocks = nn.ModuleList()
for d in dilations[:num_blocks]:
self.blocks.append(
dilations = [1, 2, 3] if num_blocks == 3 else [d for d in range(1, num_blocks + 1)]
self.blocks = nn.ModuleList(
[
Res2Block(
channels=channels,
kernel_size=3,
......@@ -126,6 +159,8 @@ class ECAPA_ACR(nn.Module):
scale=res2net_scale,
se_channels=se_channels,
)
for d in dilations[:num_blocks]
]
)
in_channels = channels * num_blocks
......@@ -134,34 +169,25 @@ class ECAPA_ACR(nn.Module):
nn.ReLU(),
nn.BatchNorm1d(channels * 3),
)
self.pooling = StatisticsPooling()
self.fc = nn.Linear(channels * 3 * 2, embed_dim)
self.bn = nn.BatchNorm1d(embed_dim, affine=False)
self.aam = AAMSoftmax(embed_dim, num_classes, m=aam_m, s=aam_s) if num_classes is not None else None
if num_classes is not None:
self.aam = AAMSoftmax(embed_dim, num_classes, m=aam_m, s=aam_s)
else:
self.aam = None
def forward(
self, mel: torch.Tensor, labels: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
x = self.conv1(mel)
def forward(self, mel: torch.Tensor, labels: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
x = self.band_split(mel) if self.band_split is not None else mel
x = self.conv1(x)
block_outputs = []
for block in self.blocks:
x = block(x)
block_outputs.append(x)
x = torch.cat(block_outputs, dim=1)
x = self.mfa(x)
x = self.pooling(x)
x = self.fc(x)
x = self.bn(x)
embedding = F.normalize(x, p=2, dim=1)
if labels is not None and self.aam is not None:
logits = self.aam(embedding, labels)
return embedding, logits
return embedding, None
......
......@@ -4,12 +4,13 @@ from typing import Optional, Tuple
class AugmentPipeline:
def __init__(self, sr: int = 16000):
def __init__(self, sr: int = 16000, aggressive: bool = False):
self.sr = sr
self.noise_snr_range = (5, 30)
self.pitch_shift_range = (-6, 6)
self.time_stretch_range = (0.85, 1.15)
self.mp3_bitrate_range = (32, 128)
self.aggressive = aggressive
def add_noise(self, y: np.ndarray, snr_db: Optional[float] = None) -> np.ndarray:
if snr_db is None:
......@@ -57,14 +58,18 @@ class AugmentPipeline:
return mel
def __call__(self, y: np.ndarray) -> np.ndarray:
if random.random() < 0.5:
y = self.add_noise(y)
if random.random() < 0.3:
y = self.time_stretch(y)
if random.random() < 0.3:
y = self.pitch_shift(y)
if random.random() < 0.2:
y = self.add_reverb(y)
noise_p = 0.75 if self.aggressive else 0.5
stretch_p = 0.55 if self.aggressive else 0.3
pitch_p = 0.55 if self.aggressive else 0.3
reverb_p = 0.35 if self.aggressive else 0.2
if random.random() < noise_p:
y = self.add_noise(y, snr_db=random.uniform(0, 18) if self.aggressive else None)
if random.random() < stretch_p:
y = self.time_stretch(y, rate=random.uniform(0.8, 1.2) if self.aggressive else None)
if random.random() < pitch_p:
y = self.pitch_shift(y, semitones=random.uniform(-8, 8) if self.aggressive else None)
if random.random() < reverb_p:
y = self.add_reverb(y, decay=random.uniform(0.2, 0.6))
return y
......
......@@ -187,6 +187,8 @@ def main():
num_classes=num_classes,
aam_m=cfg["model"]["aam_m"],
aam_s=cfg["model"]["aam_s"],
use_band_split=cfg["model"].get("use_band_split", True),
band_split_channels=cfg["model"].get("band_split_channels", 128),
).to(device)
criterion = CombinedLoss(
......
......@@ -21,3 +21,35 @@
- 已完成 1 epoch CPU 训练并生成 `best_model.pt`
- 已完成指纹索引与 embedding 索引构建
- 已完成识别命令并输出 JSON 候选结果
## 2026-06-02
### Stage: 准确率优化 v2(128 Mel / band-split / retrieval 评测 / dataset 规范 / SOTA 调研)
完成项:
- 补充 dataset / 输入输出规范:`docs/dataset-spec.md`
- 补充开源数据集接入计划:`docs/open-dataset-plan.md`
- 补充 2026 SOTA 研究说明:`docs/sota-research-2026.md`
- 输入特征从低维说话人风格配置改为 `128 Mel`
- 新增频带分割模块 `BandSplitBlock`
- 引入 pro-WGAN 风格工程近似平衡策略(针对困难样本的更强增广)
- 合成数据新增 `confused` / `humming_like` 样本类型
- 引入 `catalog.json` 作为可搜索 reference 清单
- 索引从整曲单向量改为 window-level embedding index
- 新增 `evaluate.py` 做 retrieval 评测
- 训练逻辑改为更 retrieval-oriented 的 song-pair 训练输入
验证结果:
- synthetic_v2 端到端重新跑通
- build-index 成功
- evaluate 成功
- test split 指标:top1=0.65, top5=0.95
- 分类型指标:
- clean top1=1.00
- augmented top1=0.75
- humming_like top1=0.25
- confused top1=0.25
结论:
- 结构性错误(catalog/index/fusion/评测缺失)已明显改善
- 当前主要剩余短板是 humming_like / confused 的鲁棒识别
......
# ACR Dataset / 输入输出规范
> 更新:2026-06-02
## 1. 目标
定义本项目数据集规范、输入输出处理流程、catalog/query 划分方式,以及训练/评测所需的 manifest 结构。
## 2. 数据层对象
### 2.1 Reference / Catalog
可检索曲库中的标准参考音频。
字段:
```json
{
"song_id": "song_0001",
"audio_path": "songs/song_0001.wav",
"duration": 20.0,
"base_freq": 261.63,
"type": "reference"
}
```
用途:
- 建立 chromaprint 索引
- 建立 embedding window 索引
- 作为检索目标集合
### 2.2 Query Segment
待识别片段。
字段:
```json
{
"song_id": "song_0001",
"audio_path": "segments/song_0001_seg_02_confused.wav",
"duration": 5.0,
"type": "confused",
"offset": 8.3,
"segment_type": "mid"
}
```
用途:
- 训练片段对
- top-k 检索评测
- 鲁棒性测试
## 3. Manifest 文件
| 文件 | 用途 |
|---|---|
| `train.json` | 训练查询片段 + 训练 reference |
| `val.json` | 验证查询片段 + 验证 reference |
| `test.json` | 测试查询片段 + 测试 reference |
| `catalog.json` | 可搜索 reference 总表 |
注意:
- `catalog.json`**检索索引输入**
- `train/val/test.json`**实验 split**
- 不再把 “模型训练 split” 和 “可搜索曲库” 混为一谈
## 4. 输入特征规范
### 4.1 输入音频
- 默认采样率:`16 kHz`
- 通道:`mono`
- 训练/query 窗长:`5s`
- 滑窗步长:`2.5s`
### 4.2 声学特征
当前改为:
- **128维 Mel 频谱**
不再采用传统说话人任务常见的 40 维 MFCC 作为主输入,因为:
- 音乐任务更依赖频带结构与谐波信息
- Mel 频谱对音乐 timbre / harmony / texture 表达更自然
- 便于 band-split 模块对频带进行分块建模
## 5. 输出规范
### 5.1 训练输出
模型输出:
- `embedding: [B, D]`
- `logits: [B, num_classes]`(辅助分类头)
主要目标:
- retrieval embedding 学得稳定
- 同 song 片段彼此接近
- 不同 song 分离
### 5.2 推理输出
识别输出:
```json
{
"candidates": [
{
"song_id": "song_0001",
"confidence": 0.93,
"chromaprint_score": 0.88,
"ecapa_score": 0.96,
"accepted": true,
"metadata": {}
}
],
"processing_time_ms": 120.4,
"num_candidates": 5
}
```
## 6. Query 类型定义
| type | 含义 |
|---|---|
| `clean` | 原始干净片段 |
| `augmented` | 常规增强片段 |
| `confused` | 强混淆/干扰片段 |
| `humming_like` | 哼唱风格近似片段 |
| `reference` | 标准参考整曲 |
## 7. pro-WGAN 平衡策略(工程近似版)
当前仓库先实现的是**pro-WGAN 风格的数据平衡近似策略**,不是完整生成式 GAN 训练:
- 对难样本类型(`confused`, `humming_like`)增加更强增广概率
- 通过 harder augmentation 近似 minority/hard-case oversampling
- 保持 manifest 结构兼容,后续可替换成真正的生成式平衡器
后续若接入完整 GAN 平衡器,可把它作为:
- 离线样本扩增器
- 困难类别样本生成器
- catalog/query domain adaptation 工具
## 8. 频带分割模块
输入层新增 `BandSplitBlock`
- 将 128 Mel bins 分割为多个子频带
- 每个子带做独立投影
- 再拼接进入主干网络
目的:
- 强化低频节奏 / 中频和声 / 高频音色的分带建模
- 更符合音乐频谱结构
- 为后续更复杂 band-aware retrieval 打基础
# ACR / Music Retrieval SOTA Research (截至 2026-06-02)
## 结论摘要
到 2025-2026,这个方向相比传统“从零训练一个小型 ECAPA embedding”已经明显前进了。
当前更强的方向主要有三类:
1. **Neural Audio Fingerprinting 的鲁棒训练增强**
2. **Music Foundation Model 作为 backbone / teacher**
3. **Band-split / band-aware 结构用于音乐频谱建模**
## 1. Neural AFP 的更强实践
### Enhancing Neural Audio Fingerprint Robustness to Audio Degradation for Music Identification (2025)
- arXiv: https://arxiv.org/abs/2506.22661
关键信息:
- 指出很多 neural AFP 工作对真实退化模拟不够真实
- 系统比较 metric learning 方法
- 发现自监督 triplet loss 变体在该任务中更优
- 强调多个 positive samples 对不同 loss 的影响不同
对本项目的启发:
- 不应只依赖当前简单 SupCon + CE
- 应增加更真实的退化增强
- 应明确做 retrieval 指标选择,而非只看分类头
## 2. Music Foundation Model Backbones
### Robust Neural Audio Fingerprinting using Music Foundation Models (2025)
- arXiv: https://arxiv.org/abs/2511.05399
关键信息:
- 使用预训练 music foundation model(例如 MuQ、MERT)作为 neural fingerprinting backbone
- 在 distorted / compressed / manipulated 音频条件下优于从零训练模型
- 还能更好做 segment-level localization
### MERT (2023)
- arXiv: https://arxiv.org/abs/2306.00107
关键信息:
- 大规模自监督 music understanding 模型
- 在多个 music understanding 任务上达到强表现
### MuQ (2025)
- arXiv: https://arxiv.org/abs/2501.01108
关键信息:
- 面向音乐的自监督表征学习模型
- 使用 Mel-RVQ 目标
- 在多种下游任务上优于更早工作
对本项目的启发:
- 2026 继续只用小模型从零训,不太可能是最佳路线
- 更合理路线:
- 当前仓库保留轻量自训 baseline
- 下一阶段增加 MERT / MuQ frozen encoder 或 adapter fine-tune 版本
## 3. Band-split / band-aware 结构
### Music Source Separation with Band-split RNN (2022)
- arXiv: https://arxiv.org/abs/2209.15174
关键信息:
- 显式把频谱切成多个频带再建模
- 对音乐任务优于直接照搬通用音频结构
虽然该文主要做 source separation,不是 ACR,但它对“音乐频带先验”很有启发。
对本项目的启发:
- 输入层加入 band-split 是合理工程方向
- 未来可继续发展成:
- band-aware attention
- multi-band retrieval heads
- harmonic/rhythm 双塔结构
## 4. 数据平衡与生成增强
### BAGAN: Data Augmentation with Balancing GAN (2018)
- arXiv: https://arxiv.org/abs/1803.09655
严格说你提到的 `pro-WGAN` 我这次没有找到一个明确、权威、在该任务里广泛标准化的同名主文献;当前更接近、且有明确权威来源的是 **BAGAN / balancing GAN** 这一类面向不平衡数据增强的方法。
因此本次实现里我采用的是:
- **pro-WGAN 风格的工程近似平衡策略**
- 不是声称已经复现某篇明确的 `pro-WGAN` SOTA 论文
如果你之后指定了准确论文或仓库,我可以按那一版精确对齐实现。
## 5. 2026 年是否已经有更好的方案?
有,结论是:**有明显更好的路线**
最值得参考的是:
1.**music foundation model** 做 backbone
2.**更真实退化模拟 + retrieval-first metric learning**
3.**segment-level / window-level indexing**,而不是整曲平均 embedding
4. 对哼唱任务增加 **melody/pitch contour 专门支路**
## 6. 对本项目的建议排序
### 当前阶段(已开始)
- 128 Mel 替换低维说话人风格输入
- band-split 输入层
- 更强混淆增强
- retrieval-first 评测
### 下一阶段
- MERT / MuQ frozen feature baseline
- triplet / multi-positive metric learning 对比 SupCon
- window-level index aggregation
- FMA / Jamendo 小规模真实数据验证
### 更后阶段
- humming 专门 melody tower
- foundation model + lightweight fingerprint head
- ANN + reranker 两阶段工业化检索
## Sources
- Araz et al., 2025, Enhancing Neural Audio Fingerprint Robustness to Audio Degradation for Music Identification: https://arxiv.org/abs/2506.22661
- Singh et al., 2025, Robust Neural Audio Fingerprinting using Music Foundation Models: https://arxiv.org/abs/2511.05399
- Li et al., 2023, MERT: Acoustic Music Understanding Model with Large-Scale Self-supervised Training: https://arxiv.org/abs/2306.00107
- Zhu et al., 2025, MuQ: Self-Supervised Music Representation Learning with Mel Residual Vector Quantization: https://arxiv.org/abs/2501.01108
- Luo & Yu, 2022, Music Source Separation with Band-split RNN: https://arxiv.org/abs/2209.15174
- Mariani et al., 2018, BAGAN: Data Augmentation with Balancing GAN: https://arxiv.org/abs/1803.09655