# ecapa_tdnn.py 13.3 KB
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List

try:
    from transformers import AutoModel
except ImportError:
    AutoModel = None


class FrozenMERTFeatureExtractor(nn.Module):
    """Frozen feature front-end with an optional pretrained backbone.

    When ``model_name`` is given and the backbone loads, the backbone's
    parameters are frozen and a trainable 1x1 convolution maps its hidden
    states to ``hidden_dim`` channels. Otherwise a randomly initialised
    two-layer convolutional projection with frozen weights is applied to the
    input directly.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``. A falsy
            value selects the fallback projection.
        n_mels: Input channel count expected by the fallback projection.
        hidden_dim: Number of output channels.
    """

    def __init__(self, model_name: Optional[str], n_mels: int, hidden_dim: int):
        super().__init__()
        self.model_name = model_name
        self.hidden_dim = hidden_dim
        self.backbone = None
        # Fallback path: a fixed (never trained) projection of the raw input.
        # NOTE(review): requires_grad=False freezes the weights only; the
        # BatchNorm layers still update their running statistics in train()
        # mode. Confirm that this is intended.
        self.proj = nn.Sequential(
            nn.Conv1d(n_mels, hidden_dim, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm1d(hidden_dim),
        )
        for parameter in self.proj.parameters():
            parameter.requires_grad = False

        if not model_name:
            return

        # Local import keeps this block self-contained.
        import warnings

        if AutoModel is None:
            warnings.warn(
                f"transformers is not installed; ignoring model_name={model_name!r} "
                "and using the frozen fallback projection",
                RuntimeWarning,
                stacklevel=2,
            )
            return

        try:
            self.backbone = AutoModel.from_pretrained(model_name)
        except Exception as exc:
            # Loading stays best-effort, but the fallback is no longer silent:
            # a typo in the model name or a missing checkpoint would otherwise
            # quietly swap the pretrained backbone for a random projection.
            warnings.warn(
                f"could not load backbone {model_name!r} ({exc!r}); "
                "using the frozen fallback projection",
                RuntimeWarning,
                stacklevel=2,
            )
            self.backbone = None
            return

        for parameter in self.backbone.parameters():
            parameter.requires_grad = False
        backbone_dim = getattr(self.backbone.config, "hidden_size", hidden_dim)
        # With a backbone present the projection is replaced by a trainable
        # 1x1 adapter from the backbone width to hidden_dim.
        self.proj = nn.Sequential(
            nn.Conv1d(backbone_dim, hidden_dim, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm1d(hidden_dim),
        )

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_dim``-channel features for ``mel``.

        Args:
            mel: Tensor whose dimension 1 holds the feature channels
                (``n_mels`` on the fallback path).

        Returns:
            Tensor with ``hidden_dim`` channels on dimension 1. On the
            fallback path it is computed under ``torch.no_grad()``.
        """
        if self.backbone is None:
            with torch.no_grad():
                return self.proj(mel)

        # NOTE(review): the backbone is called with ``inputs_embeds=`` on the
        # channel-last input; confirm the chosen checkpoint's forward accepts
        # that keyword and this layout.
        waveform_like = mel.transpose(1, 2)
        with torch.no_grad():
            outputs = self.backbone(inputs_embeds=waveform_like)
            hidden = outputs.last_hidden_state.transpose(1, 2)
        return self.proj(hidden)


class SEModule(nn.Module):
    """Channel re-weighting block that scales its input by a sigmoid mask.

    The mask is produced by two 1x1 convolutions with a bottleneck of
    ``se_channels``. No pooling layer is involved, so the mask is computed
    independently for every position along the last axis.
    """

    def __init__(self, channels, se_channels=128):
        super().__init__()
        gate_layers = [
            nn.Conv1d(channels, se_channels, kernel_size=1),  # squeeze
            nn.ReLU(),
            nn.BatchNorm1d(se_channels),
            nn.Conv1d(se_channels, channels, kernel_size=1),  # expand
            nn.Sigmoid(),  # mask values in (0, 1)
        ]
        self.se = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` multiplied element-wise by its mask."""
        mask = self.se(x)
        return x * mask


class BandSplitBlock(nn.Module):
    """Project contiguous channel bands separately, then fuse them.

    Dimension 1 of the input is cut into bands at ``split_points``. Each band
    is mapped to ``out_channels`` channels by its own 1x1 convolution, the
    results are concatenated, and a final 1x1 convolution mixes them.

    Args:
        n_mels: Number of input channels (mel bins).
        split_points: Exclusive end index of each band, strictly increasing
            and not exceeding ``n_mels``. ``None`` or an empty sequence
            selects ``[16, 32, 64, 96, n_mels]``, which requires
            ``n_mels > 96``.
        out_channels: Channels produced per band; the block outputs
            ``out_channels * len(split_points)`` channels.

    Raises:
        ValueError: If the split points are not strictly increasing positive
            indices, or if the last one exceeds ``n_mels``.
    """

    def __init__(self, n_mels: int, split_points: Optional[List[int]] = None, out_channels: int = 128):
        super().__init__()
        # Copy so later mutation of the caller's list cannot desynchronise
        # forward() from the layers built here.
        points = list(split_points) if split_points else [16, 32, 64, 96, n_mels]
        starts = [0] + points[:-1]
        widths = [end - start for start, end in zip(starts, points)]
        # Fail early with a clear message instead of an obscure tensor-shape
        # error (e.g. the default split with n_mels <= 96 yields a band of
        # non-positive width).
        if any(width <= 0 for width in widths):
            raise ValueError(
                f"split_points must be strictly increasing positive indices, got {points} "
                f"(n_mels={n_mels})"
            )
        if points[-1] > n_mels:
            raise ValueError(
                f"last split point {points[-1]} exceeds n_mels={n_mels}"
            )
        self.split_points = points
        self.band_projs = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv1d(width, out_channels, kernel_size=1),
                    nn.ReLU(),
                    nn.BatchNorm1d(out_channels),
                )
                for width in widths
            ]
        )
        fused_channels = out_channels * len(widths)
        self.fuse = nn.Sequential(
            nn.Conv1d(fused_channels, fused_channels, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(fused_channels),
        )

    def forward(self, x):
        """Return the fused band features for ``x`` (bands taken on dim 1)."""
        starts = [0] + self.split_points[:-1]
        bands = [
            proj(x[:, start:end, :])
            for proj, start, end in zip(self.band_projs, starts, self.split_points)
        ]
        return self.fuse(torch.cat(bands, dim=1))


class Res2Block(nn.Module):
    """Res2Net-style residual block with a channel re-weighting stage.

    The input channels are split into ``scale`` groups of equal width. Each
    group passes through its own dilated convolution, and every group after
    the first is first summed with the previous group's output. The groups
    are concatenated, mixed by a 1x1 convolution, re-weighted by ``SEModule``
    and added back to the input.

    Args:
        channels: Input and output channel count; must be divisible by
            ``scale``.
        kernel_size: Kernel size of the per-group convolutions.
        dilation: Dilation of the per-group convolutions.
        scale: Number of channel groups.
        se_channels: Bottleneck width passed to ``SEModule``.

    Raises:
        ValueError: If ``scale`` is not positive or does not divide
            ``channels``.
    """

    def __init__(self, channels, kernel_size=3, dilation=1, scale=8, se_channels=128):
        super().__init__()
        # Without this check a non-divisible channel count constructs fine and
        # only fails inside forward() with a channel-mismatch error.
        if scale < 1 or channels % scale != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by a positive scale ({scale})"
            )
        self.width = channels // scale
        self.num_split = scale
        # The padding below keeps the sequence length unchanged for odd
        # kernel sizes.
        self.convs = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv1d(
                        self.width,
                        self.width,
                        kernel_size,
                        padding=dilation * (kernel_size - 1) // 2,
                        dilation=dilation,
                    ),
                    nn.ReLU(),
                    nn.BatchNorm1d(self.width),
                )
                for _ in range(self.num_split)
            ]
        )
        self.conv1x1 = nn.Sequential(
            nn.Conv1d(channels, channels, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(channels),
        )
        self.se = SEModule(channels, se_channels)

    def forward(self, x):
        """Return the block output; it has the same shape as ``x``."""
        residual = x
        previous = None
        outputs = []
        for part, conv in zip(torch.split(x, self.width, dim=1), self.convs):
            # Hierarchical connection: each group also sees the output of the
            # group before it. The first group is convolved on its own.
            previous = conv(part if previous is None else part + previous)
            outputs.append(previous)
        x = torch.cat(outputs, dim=1)
        x = self.conv1x1(x)
        x = self.se(x)
        return x + residual


class StatisticsPooling(nn.Module):
    """Pool dimension 2 into its mean and standard deviation.

    The output concatenates the per-channel mean and the per-channel
    (population) standard deviation along dimension 1, doubling the channel
    count.
    """

    def forward(self, x):
        mu = x.mean(dim=2)
        variance = x.var(dim=2, unbiased=False)
        # The small offset keeps sqrt away from an exact zero.
        sigma = (variance + 1e-12).sqrt()
        return torch.cat((mu, sigma), dim=1)


class AAMSoftmax(nn.Module):
    """Additive angular margin softmax head.

    Computes cosine similarities between L2-normalised inputs and
    L2-normalised class weights, adds the angular margin ``m`` to the target
    class and scales all logits by ``s``.

    Args:
        in_features: Embedding dimension.
        out_features: Number of classes.
        m: Angular margin in radians.
        s: Logit scale.
    """

    # Lower bound on sin^2(theta) before the square root. The derivative of
    # sqrt is infinite at 0, so an exact cosine of +/-1 would otherwise yield
    # NaN gradients, even for entries that are masked out afterwards.
    _SIN_SQ_FLOOR = 1e-7

    def __init__(self, in_features, out_features, m=0.3, s=30.0):
        super().__init__()
        # torch.empty replaces the legacy torch.FloatTensor constructor; the
        # values are overwritten by the initialiser on the next line.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_normal_(self.weight)
        self.m = m
        self.s = s
        self.cos_m = torch.cos(torch.tensor(m))
        self.sin_m = torch.sin(torch.tensor(m))
        # Threshold and fallback offset used where theta + m would pass pi.
        self.th = torch.cos(torch.tensor(torch.pi - m))
        self.mm = torch.sin(torch.tensor(torch.pi - m)) * m

    def forward(self, x, labels):
        """Return margin-adjusted, scaled logits.

        Args:
            x: Embeddings with ``in_features`` on dimension 1.
            labels: Integer class indices, one per row of ``x``.

        Returns:
            Logits with one column per class.
        """
        x = F.normalize(x, dim=1)
        w = F.normalize(self.weight, dim=1)
        cos_theta = F.linear(x, w)
        # 1 - cos^2 never exceeds 1, so only the lower bound needs clamping.
        sin_sq = torch.clamp(1.0 - cos_theta ** 2, min=self._SIN_SQ_FLOOR)
        sin_theta = torch.sqrt(sin_sq)
        # cos(theta + m)
        phi = cos_theta * self.cos_m - sin_theta * self.sin_m
        # Where theta + m would exceed pi, fall back to a linear penalty.
        phi = torch.where(cos_theta > self.th, phi, cos_theta - self.mm)
        one_hot = F.one_hot(labels, num_classes=self.weight.size(0)).float()
        # Margin applies to the target class only.
        output = (one_hot * phi) + ((1.0 - one_hot) * cos_theta)
        return output * self.s


class CoverHunterHead(nn.Module):
    """Transformer encoder with attentive pooling and a unit-norm embedding.

    The input sequence (batch first, features last) is encoded, collapsed
    over dimension 1 with learned softmax weights, projected to
    ``embed_dim``, batch-normalised without affine parameters and finally
    L2-normalised.
    """

    def __init__(self, input_dim: int, embed_dim: int, num_heads: int = 4, num_layers: int = 2, ff_mult: int = 4):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            dim_feedforward=input_dim * ff_mult,
            batch_first=True,
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        # Scores each sequence position with a single scalar.
        scorer = [
            nn.Linear(input_dim, input_dim),
            nn.Tanh(),
            nn.Linear(input_dim, 1),
        ]
        self.attention = nn.Sequential(*scorer)
        self.proj = nn.Linear(input_dim, embed_dim)
        self.norm = nn.BatchNorm1d(embed_dim, affine=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one L2-normalised embedding per batch element."""
        encoded = self.encoder(x)
        scores = self.attention(encoded).squeeze(-1)
        # Normalise the scores over the sequence axis.
        weights = torch.softmax(scores, dim=1).unsqueeze(-1)
        pooled = (encoded * weights).sum(dim=1)
        embedding = self.norm(self.proj(pooled))
        return F.normalize(embedding, p=2, dim=1)


class MERTMelodyBranch(nn.Module):
    """Fuse frozen-extractor features with melody and chroma features.

    The first input goes through ``FrozenMERTFeatureExtractor``. Melody and
    chroma are concatenated on dimension 1 and projected to ``hidden_dim``
    channels. Both ``hidden_dim``-channel streams are then concatenated and
    mixed back down to ``hidden_dim`` channels.
    """

    def __init__(
        self,
        n_mels: int,
        chroma_bins: int = 12,
        melody_bins: int = 1,
        hidden_dim: int = 256,
        mert_model_name: Optional[str] = None,
    ):
        super().__init__()
        self.mert = FrozenMERTFeatureExtractor(
            model_name=mert_model_name,
            n_mels=n_mels,
            hidden_dim=hidden_dim,
        )
        # Melody and chroma share one 1x1 projection over their stacked bins.
        melody_in = chroma_bins + melody_bins
        self.melody_proj = nn.Conv1d(melody_in, hidden_dim, kernel_size=1)
        fuse_layers = [
            nn.Conv1d(hidden_dim * 2, hidden_dim, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
        ]
        self.fuse = nn.Sequential(*fuse_layers)

    def forward(self, mert: torch.Tensor, melody: torch.Tensor, chroma: torch.Tensor) -> torch.Tensor:
        """Return the fused ``hidden_dim``-channel feature map.

        ``melody`` and ``chroma`` are stacked on dimension 1 in that order.
        Every channel-wise concatenation here requires the inputs to agree on
        all other dimensions.
        """
        semantic = self.mert(mert)
        stacked = torch.cat([melody, chroma], dim=1)
        melodic = self.melody_proj(stacked)
        combined = torch.cat([semantic, melodic], dim=1)
        return self.fuse(combined)


class ECAPABranch(nn.Module):
    """Mel front-end: optional band split followed by a width-5 convolution.

    With ``use_band_split`` the input first passes through a
    ``BandSplitBlock`` built with its default split points. Otherwise the
    mel input feeds the convolution directly.
    """

    def __init__(self, n_mels: int, channels: int, use_band_split: bool, band_split_channels: int):
        super().__init__()
        if use_band_split:
            # The factor 5 matches the five default bands of BandSplitBlock.
            front_channels = band_split_channels * 5
            self.band_split = BandSplitBlock(n_mels=n_mels, out_channels=band_split_channels)
        else:
            front_channels = n_mels
            self.band_split = None
        proj_layers = [
            nn.Conv1d(front_channels, channels, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.BatchNorm1d(channels),
        ]
        self.proj = nn.Sequential(*proj_layers)

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """Return a ``channels``-wide feature map for ``mel``."""
        if self.band_split is None:
            features = mel
        else:
            features = self.band_split(mel)
        return self.proj(features)


class DualStreamFusion(nn.Module):
    """Merge two feature streams into one ``hidden_dim``-channel map.

    Each stream gets its own 1x1 convolution to ``hidden_dim`` channels. The
    two results are concatenated on dimension 1 and mixed by a further 1x1
    convolution. Despite the ``*_gate`` names, no gating non-linearity is
    applied to either stream.
    """

    def __init__(self, mert_dim: int, ecapa_dim: int, hidden_dim: int):
        super().__init__()
        self.mert_gate = nn.Conv1d(mert_dim, hidden_dim, kernel_size=1)
        self.ecapa_gate = nn.Conv1d(ecapa_dim, hidden_dim, kernel_size=1)
        mixer = [
            nn.Conv1d(hidden_dim * 2, hidden_dim, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
        ]
        self.fuse = nn.Sequential(*mixer)

    def forward(self, mert_stream: torch.Tensor, ecapa_stream: torch.Tensor) -> torch.Tensor:
        """Return the fused map; both streams must agree outside dimension 1."""
        mert_proj = self.mert_gate(mert_stream)
        ecapa_proj = self.ecapa_gate(ecapa_stream)
        stacked = torch.cat([mert_proj, ecapa_proj], dim=1)
        return self.fuse(stacked)


class ECAPA_ACR(nn.Module):
    """ECAPA-style embedding network with an optional dual-stream front-end.

    In dual-stream mode a ``MERTMelodyBranch`` and an ``ECAPABranch`` are
    fused by ``DualStreamFusion``. Otherwise the mel input passes through an
    optional ``BandSplitBlock``. Either way the result goes through
    ``conv1``, a stack of ``Res2Block`` layers whose outputs are
    concatenated, a 1x1 aggregation convolution and a ``CoverHunterHead``
    that returns L2-normalised embeddings. An ``AAMSoftmax`` head is attached
    when ``num_classes`` is given.

    Args:
        n_mels: Number of mel channels in the input.
        embed_dim: Size of the output embedding.
        channels: Width of the convolutional trunk.
        se_channels: Bottleneck width of the SE stage in each ``Res2Block``.
        res2net_scale: Number of channel groups in each ``Res2Block``.
        num_blocks: Number of ``Res2Block`` layers (at least 1); block ``i``
            uses dilation ``i + 1``.
        num_classes: Number of training classes; ``None`` disables the
            classification head.
        aam_m: Angular margin of the classification head.
        aam_s: Logit scale of the classification head.
        use_band_split: Whether to use ``BandSplitBlock`` in the mel path.
        band_split_channels: Output channels per band of ``BandSplitBlock``.
        use_dual_stream: Whether to use the dual-stream front-end.
        coverhunter_heads: Attention heads in ``CoverHunterHead``.
        coverhunter_layers: Encoder layers in ``CoverHunterHead``.
        fusion_hidden_dim: Channel width of the MERT/melody branch.
        mert_model_name: Optional pretrained backbone name for the branch.

    Raises:
        ValueError: If ``num_blocks`` is smaller than 1.
    """

    def __init__(
        self,
        n_mels: int = 128,
        embed_dim: int = 192,
        channels: int = 512,
        se_channels: int = 128,
        res2net_scale: int = 8,
        num_blocks: int = 3,
        num_classes: Optional[int] = None,
        aam_m: float = 0.3,
        aam_s: float = 30.0,
        use_band_split: bool = True,
        band_split_channels: int = 128,
        use_dual_stream: bool = True,
        coverhunter_heads: int = 4,
        coverhunter_layers: int = 2,
        fusion_hidden_dim: int = 256,
        mert_model_name: Optional[str] = None,
    ):
        super().__init__()
        # forward() concatenates the block outputs, so at least one block is
        # needed; fail here rather than inside torch.cat.
        if num_blocks < 1:
            raise ValueError(f"num_blocks must be at least 1, got {num_blocks}")
        self.embed_dim = embed_dim
        self.use_dual_stream = use_dual_stream
        if use_dual_stream:
            self.mert_melody_branch = MERTMelodyBranch(
                n_mels=n_mels,
                chroma_bins=12,
                melody_bins=1,
                hidden_dim=fusion_hidden_dim,
                mert_model_name=mert_model_name,
            )
            self.ecapa_branch = ECAPABranch(
                n_mels=n_mels,
                channels=channels,
                use_band_split=use_band_split,
                band_split_channels=band_split_channels,
            )
            self.stream_fusion = DualStreamFusion(
                mert_dim=fusion_hidden_dim,
                ecapa_dim=channels,
                hidden_dim=channels,
            )
            front_channels = channels
        elif use_band_split:
            # The factor 5 matches the five default bands of BandSplitBlock.
            front_channels = band_split_channels * 5
            self.band_split = BandSplitBlock(n_mels=n_mels, out_channels=band_split_channels)
        else:
            # Without band split conv1 consumes the raw mel input, so its
            # input width is n_mels (previously `channels`, which made conv1
            # reject the input in this configuration).
            front_channels = n_mels
            self.band_split = None

        self.conv1 = nn.Sequential(
            nn.Conv1d(front_channels, channels, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.BatchNorm1d(channels),
        )

        # Dilation grows by one per block: 1, 2, ..., num_blocks.
        self.blocks = nn.ModuleList(
            [
                Res2Block(
                    channels=channels,
                    kernel_size=3,
                    dilation=dilation,
                    scale=res2net_scale,
                    se_channels=se_channels,
                )
                for dilation in range(1, num_blocks + 1)
            ]
        )

        # Aggregates the concatenated block outputs.
        in_channels = channels * num_blocks
        self.mfa = nn.Sequential(
            nn.Conv1d(in_channels, channels * 3, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(channels * 3),
        )
        self.coverhunter = CoverHunterHead(
            input_dim=channels * 3,
            embed_dim=embed_dim,
            num_heads=coverhunter_heads,
            num_layers=coverhunter_layers,
        )
        self.aam = AAMSoftmax(embed_dim, num_classes, m=aam_m, s=aam_s) if num_classes is not None else None

    def forward(
        self,
        mel: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        melody: Optional[torch.Tensor] = None,
        chroma: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute embeddings and, when possible, classification logits.

        Args:
            mel: Mel features with ``n_mels`` channels on dimension 1.
            labels: Optional class indices for the margin head.
            melody: Melody features; required in dual-stream mode.
            chroma: Chroma features; required in dual-stream mode.

        Returns:
            ``(embedding, logits)``. ``logits`` is ``None`` unless ``labels``
            is given and the model was built with ``num_classes``.

        Raises:
            ValueError: In dual-stream mode when ``melody`` or ``chroma`` is
                missing.
        """
        if self.use_dual_stream:
            if melody is None or chroma is None:
                raise ValueError("melody and chroma are required when dual-stream fusion is enabled")
            mert_stream = self.mert_melody_branch(mel, melody, chroma)
            ecapa_stream = self.ecapa_branch(mel)
            x = self.stream_fusion(mert_stream, ecapa_stream)
        elif self.band_split is not None:
            x = self.band_split(mel)
        else:
            x = mel
        # conv1 is shared by every front-end variant.
        x = self.conv1(x)

        block_outputs = []
        for block in self.blocks:
            x = block(x)
            block_outputs.append(x)
        x = torch.cat(block_outputs, dim=1)
        x = self.mfa(x)
        # The head expects the sequence axis before the feature axis.
        embedding = self.coverhunter(x.transpose(1, 2))
        if labels is not None and self.aam is not None:
            logits = self.aam(embedding, labels)
            return embedding, logits
        return embedding, None