# ecapa_tdnn.py (6.7 KB)
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List


class SEModule(nn.Module):
    """Channel gating block: multiply the input by a learned sigmoid mask.

    The mask is produced by a 1x1-convolution bottleneck
    (``channels -> se_channels -> channels``) with ReLU and batch norm in
    between, followed by a sigmoid.  The bottleneck is applied to the input
    as given -- no pooling over the last axis is performed here -- so the
    mask has the same shape as the input.

    Args:
        channels: Number of input (and output) channels.
        se_channels: Width of the bottleneck.
    """

    def __init__(self, channels, se_channels=128):
        super().__init__()
        # Layers are created in execution order so that seeded
        # initialisation and ``state_dict`` indices stay stable.
        reduce = nn.Conv1d(channels, se_channels, kernel_size=1)
        activation = nn.ReLU()
        norm = nn.BatchNorm1d(se_channels)
        expand = nn.Conv1d(se_channels, channels, kernel_size=1)
        squash = nn.Sigmoid()
        self.se = nn.Sequential(reduce, activation, norm, expand, squash)

    def forward(self, x):
        """Return ``x`` scaled element-wise by its gate (values in (0, 1))."""
        gate = self.se(x)
        return torch.mul(x, gate)


class BandSplitBlock(nn.Module):
    """Project contiguous slices of the channel axis separately, then fuse.

    The channel axis (dim 1) of the input is cut at ``split_points`` into
    bands ``[0, p0), [p0, p1), ...``.  Each band goes through its own
    1x1 conv + ReLU + batch norm producing ``out_channels`` channels; the
    band outputs are concatenated on dim 1 and passed through a fusing
    1x1 conv + ReLU + batch norm of the same total width.

    Args:
        n_mels: Number of input channels.  Used as the last default split
            point and as the upper bound for custom split points.
        split_points: Exclusive end index of every band, strictly
            increasing and positive.  ``None`` (or an empty list) selects
            ``[16, 32, 64, 96, n_mels]``.  Channels at or above the last
            split point are ignored.
        out_channels: Channels produced per band.

    Raises:
        ValueError: If the split points are not strictly increasing and
            positive, or if the last one exceeds ``n_mels`` (this includes
            the default split points when ``n_mels <= 96``).
    """

    def __init__(self, n_mels: int, split_points: Optional[List[int]] = None, out_channels: int = 128):
        super().__init__()
        # Copy the caller's list so later mutation on their side cannot
        # silently desynchronise the slices from the projection layers.
        points = list(split_points) if split_points else [16, 32, 64, 96, n_mels]
        starts = [0] + points[:-1]
        widths = [end - start for start, end in zip(starts, points)]

        # Fail early with a readable message instead of an obscure tensor
        # allocation / shape error from Conv1d or from the forward pass.
        if any(width <= 0 for width in widths):
            raise ValueError(
                f"split_points must be strictly increasing and positive, got {points} (n_mels={n_mels})"
            )
        if points[-1] > n_mels:
            raise ValueError(
                f"last split point ({points[-1]}) exceeds n_mels ({n_mels})"
            )

        self.split_points = points
        self.band_projs = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv1d(width, out_channels, kernel_size=1),
                    nn.ReLU(),
                    nn.BatchNorm1d(out_channels),
                )
                for width in widths
            ]
        )
        fused_channels = out_channels * len(widths)
        self.fuse = nn.Sequential(
            nn.Conv1d(fused_channels, fused_channels, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(fused_channels),
        )

    def forward(self, x):
        """Return the fused band features.

        ``x`` is sliced on dim 1 and indexed with three axes, so it must be
        3-D with at least ``split_points[-1]`` channels.  The output has
        ``out_channels * len(split_points)`` channels.
        """
        starts = [0] + self.split_points[:-1]
        bands = [
            proj(x[:, start:end, :])
            for proj, start, end in zip(self.band_projs, starts, self.split_points)
        ]
        return self.fuse(torch.cat(bands, dim=1))


class Res2Block(nn.Module):
    """Residual block with chained per-group dilated convolutions and gating.

    The input channels are split into ``scale`` equal groups.  Group 0 is
    convolved on its own; every later group is first added to the previous
    group's output and then convolved, so information flows across groups.
    The group outputs are concatenated, mixed by a 1x1 conv + ReLU + batch
    norm, gated by an ``SEModule`` and finally added to the block input.

    Args:
        channels: Number of input and output channels; must be a positive
            multiple of ``scale``.
        kernel_size: Kernel size of the per-group convolutions.
        dilation: Dilation of the per-group convolutions.
        scale: Number of channel groups.
        se_channels: Bottleneck width of the gating module.

    Raises:
        ValueError: If ``channels`` is not a positive multiple of ``scale``,
            or if ``dilation * (kernel_size - 1)`` is odd (symmetric padding
            could then not preserve the sequence length, which the
            group-to-group and residual additions require).
    """

    def __init__(self, channels, kernel_size=3, dilation=1, scale=8, se_channels=128):
        super().__init__()
        if scale < 1 or channels < 1 or channels % scale != 0:
            raise ValueError(
                f"channels ({channels}) must be a positive multiple of scale ({scale})"
            )
        receptive_span = dilation * (kernel_size - 1)
        if receptive_span % 2 != 0:
            raise ValueError(
                f"dilation * (kernel_size - 1) must be even to preserve length, "
                f"got dilation={dilation}, kernel_size={kernel_size}"
            )

        self.width = channels // scale
        self.num_split = scale
        self.convs = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv1d(
                        self.width,
                        self.width,
                        kernel_size,
                        padding=receptive_span // 2,  # "same" padding
                        dilation=dilation,
                    ),
                    nn.ReLU(),
                    nn.BatchNorm1d(self.width),
                )
                for _ in range(self.num_split)
            ]
        )
        self.conv1x1 = nn.Sequential(
            nn.Conv1d(channels, channels, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(channels),
        )
        self.se = SEModule(channels, se_channels)

    def forward(self, x):
        """Apply the block; the output has the same shape as ``x``."""
        groups = torch.split(x, self.width, dim=1)
        outputs = []
        previous = None
        for group, conv in zip(groups, self.convs):
            # The first group has no predecessor; later groups are fed the
            # sum of their own slice and the previous group's output.
            previous = conv(group if previous is None else group + previous)
            outputs.append(previous)
        mixed = self.conv1x1(torch.cat(outputs, dim=1))
        return self.se(mixed) + x


class StatisticsPooling(nn.Module):
    """Collapse dim 2 of the input into its mean and standard deviation.

    For an input whose dim 1 has ``C`` entries, the result has ``2 * C``
    entries on dim 1: the means followed by the standard deviations.  The
    standard deviation uses the biased (population) variance, with
    ``1e-12`` added before the square root.
    """

    def forward(self, x):
        pooled_axis = 2
        avg = x.mean(dim=pooled_axis)
        spread = x.var(dim=pooled_axis, unbiased=False).add(1e-12).sqrt()
        return torch.cat((avg, spread), dim=1)


class AAMSoftmax(nn.Module):
    """Additive angular margin head producing scaled cosine logits.

    Inputs and class weights are L2-normalised, so their product is the
    cosine of the angle between them.  For the labelled class the angle is
    increased by the margin ``m`` (``cos(theta + m)``); where
    ``cos(theta) <= cos(pi - m)`` the linear fallback
    ``cos(theta) - m * sin(pi - m)`` is used instead.  All logits are then
    multiplied by ``s``.

    Args:
        in_features: Size of each input vector.
        out_features: Number of classes.
        m: Angular margin in radians.
        s: Scale applied to the logits.
    """

    # Lower bound for ``1 - cos^2`` before the square root.  The derivative
    # of ``sqrt`` is infinite at 0, which turns into NaN gradients as soon
    # as any input is exactly (anti-)parallel to a class weight.
    _SIN_SQ_FLOOR = 1e-7

    def __init__(self, in_features, out_features, m=0.3, s=30.0):
        super().__init__()
        # ``torch.empty`` replaces the legacy ``torch.FloatTensor(...)``
        # constructor; the values are overwritten by the initialiser below.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_normal_(self.weight)
        self.m = m
        self.s = s
        # Non-persistent buffers follow ``.to(device)`` with the module but
        # stay out of ``state_dict``, so existing checkpoints still load.
        self.register_buffer("cos_m", torch.cos(torch.tensor(m)), persistent=False)
        self.register_buffer("sin_m", torch.sin(torch.tensor(m)), persistent=False)
        self.register_buffer("th", torch.cos(torch.tensor(torch.pi - m)), persistent=False)
        self.register_buffer("mm", torch.sin(torch.tensor(torch.pi - m)) * m, persistent=False)

    def forward(self, x, labels):
        """Return margin-adjusted logits of shape ``(batch, out_features)``.

        Args:
            x: Input vectors, normalised along dim 1 before use.
            labels: Integer class indices accepted by ``F.one_hot``.
        """
        x = F.normalize(x, dim=1)
        w = F.normalize(self.weight, dim=1)
        cos_theta = F.linear(x, w)
        # Clamp *before* the sqrt: keeps the value in [floor, 1] despite
        # rounding and keeps the backward pass finite at |cos| == 1.
        sin_sq = torch.clamp(1.0 - cos_theta ** 2, min=self._SIN_SQ_FLOOR, max=1.0)
        sin_theta = torch.sqrt(sin_sq)
        # cos(theta + m) via the angle-addition identity.
        phi = cos_theta * self.cos_m - sin_theta * self.sin_m
        phi = torch.where(cos_theta > self.th, phi, cos_theta - self.mm)
        one_hot = F.one_hot(labels, num_classes=self.weight.size(0)).float()
        # Margin only on the labelled class; plain cosine everywhere else.
        output = (one_hot * phi) + ((1.0 - one_hot) * cos_theta)
        return output * self.s


class ECAPA_ACR(nn.Module):
    """Embedding network: optional band split, Res2 blocks, pooled statistics.

    Pipeline: (optional) ``BandSplitBlock`` -> 5-tap conv + ReLU + batch
    norm -> ``num_blocks`` chained ``Res2Block`` layers with dilations
    ``1 .. num_blocks`` -> concatenation of every block's output -> 1x1
    conv + ReLU + batch norm -> ``StatisticsPooling`` -> linear layer ->
    non-affine batch norm -> L2 normalisation.  When ``num_classes`` is
    given, an ``AAMSoftmax`` head is attached for training.

    Args:
        n_mels: Number of input channels of ``mel``.
        embed_dim: Size of the returned embedding.
        channels: Width of the convolutional trunk.
        se_channels: Bottleneck width of the gating module in each block.
        res2net_scale: Number of channel groups in each ``Res2Block``.
        num_blocks: Number of ``Res2Block`` layers (at least 1).
        num_classes: Number of classes for the margin head, or ``None`` for
            no head.
        aam_m: Margin passed to ``AAMSoftmax``.
        aam_s: Scale passed to ``AAMSoftmax``.
        use_band_split: Whether to run ``BandSplitBlock`` on the input.
        band_split_channels: Channels produced per band by the band split.

    Raises:
        ValueError: If ``num_blocks`` is smaller than 1.
    """

    def __init__(
        self,
        n_mels: int = 128,
        embed_dim: int = 192,
        channels: int = 512,
        se_channels: int = 128,
        res2net_scale: int = 8,
        num_blocks: int = 3,
        num_classes: Optional[int] = None,
        aam_m: float = 0.3,
        aam_s: float = 30.0,
        use_band_split: bool = True,
        band_split_channels: int = 128,
    ):
        super().__init__()
        if num_blocks < 1:
            # With no blocks there is nothing to concatenate in forward().
            raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")
        self.embed_dim = embed_dim

        if use_band_split:
            self.band_split = BandSplitBlock(n_mels=n_mels, out_channels=band_split_channels)
            # Derive the width from the module instead of assuming 5 bands.
            front_channels = band_split_channels * len(self.band_split.split_points)
        else:
            self.band_split = None
            # Without the band split, conv1 consumes the raw input, whose
            # channel count is n_mels (not the trunk width).
            front_channels = n_mels

        self.conv1 = nn.Sequential(
            nn.Conv1d(front_channels, channels, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.BatchNorm1d(channels),
        )

        self.blocks = nn.ModuleList(
            [
                Res2Block(
                    channels=channels,
                    kernel_size=3,
                    dilation=dilation,
                    scale=res2net_scale,
                    se_channels=se_channels,
                )
                for dilation in range(1, num_blocks + 1)
            ]
        )

        aggregated_channels = channels * 3
        self.mfa = nn.Sequential(
            nn.Conv1d(channels * num_blocks, aggregated_channels, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(aggregated_channels),
        )
        self.pooling = StatisticsPooling()
        # Pooling concatenates mean and std, doubling the channel count.
        self.fc = nn.Linear(aggregated_channels * 2, embed_dim)
        self.bn = nn.BatchNorm1d(embed_dim, affine=False)
        self.aam = AAMSoftmax(embed_dim, num_classes, m=aam_m, s=aam_s) if num_classes is not None else None

    def forward(self, mel: torch.Tensor, labels: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute the embedding and, when possible, the margin logits.

        Args:
            mel: Input features with ``n_mels`` entries on dim 1 and the
                sequence on dim 2.
            labels: Optional class indices for the margin head.

        Returns:
            ``(embedding, logits)``.  ``embedding`` is L2-normalised along
            dim 1.  ``logits`` is ``None`` unless both ``labels`` and the
            margin head are present.
        """
        x = self.band_split(mel) if self.band_split is not None else mel
        x = self.conv1(x)

        # Each block feeds the next; every intermediate output is kept for
        # the aggregation layer.
        block_outputs = []
        for block in self.blocks:
            x = block(x)
            block_outputs.append(x)

        x = self.mfa(torch.cat(block_outputs, dim=1))
        x = self.bn(self.fc(self.pooling(x)))
        embedding = F.normalize(x, p=2, dim=1)

        if labels is None or self.aam is None:
            return embedding, None
        return embedding, self.aam(embedding, labels)