# losses.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class SupConLoss(nn.Module):
    """Supervised contrastive loss computed per anchor within one batch.

    For each sample ``i`` the loss is the mean, over the other samples that
    share its label, of the negative log-softmax of their temperature-scaled
    cosine similarity to ``i``. The softmax denominator runs over every
    sample in the batch except ``i`` itself.

    Args:
        temperature: Divisor applied to the cosine similarities.
    """

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the un-reduced loss, one value per sample.

        Args:
            features: 2-D tensor ``(batch, dim)``; rows are L2-normalised
                inside this method.
            labels: One label per row of ``features`` (flattened internally).

        Returns:
            Tensor of shape ``(batch,)``. A sample with no same-label partner
            in the batch contributes 0. Reduction is left to the caller.
        """
        batch_size = features.shape[0]
        if batch_size < 2:
            # Nothing to contrast against: the denominator would be empty and
            # the result NaN. Return zeros that stay attached to the graph so
            # a caller's ``backward()`` still works.
            return features.sum(dim=1) * 0.0

        labels = labels.contiguous().view(-1, 1)
        self_mask = torch.eye(batch_size, dtype=torch.bool, device=features.device)
        same_label = torch.eq(labels, labels.T).to(features.device)
        # Positives: same label, excluding the anchor itself.
        pos_mask = (same_label & ~self_mask).float()

        features = F.normalize(features, dim=1)
        sim = torch.matmul(features, features.T) / self.temperature

        # Log of the softmax denominator over all non-self samples. The
        # diagonal (self-similarity) is always the row maximum, so shifting by
        # it and exponentiating can underflow every remaining term to zero at
        # small temperatures; logsumexp over the off-diagonal entries shifts
        # by the largest *included* term instead and stays finite.
        others = sim.masked_fill(self_mask, float("-inf"))
        log_denom = torch.logsumexp(others, dim=1, keepdim=True)
        log_prob = sim - log_denom

        pos_count = pos_mask.sum(dim=1)
        loss = -(log_prob * pos_mask).sum(dim=1)
        # clamp avoids 0/0 for anchors without positives (their numerator is 0).
        return loss / pos_count.clamp(min=1)


class CombinedLoss(nn.Module):
    def __init__(
        self,
        temperature: float = 0.07,
        supcon_weight: float = 1.0,
        aam_weight: float = 0.3,
    ):
        super().__init__()
        self.supcon = SupConLoss(temperature)
        self.ce = nn.CrossEntropyLoss()
        self.supcon_weight = supcon_weight
        self.aam_weight = aam_weight

    def forward(
        self,
        embedding: torch.Tensor,
        logits: torch.Tensor,
        labels: torch.Tensor,
        supcon_labels: torch.Tensor,
        hard_weight: torch.Tensor | None = None,
    ) -> dict:
        loss_supcon = self.supcon(embedding, supcon_labels)
        loss_ce = F.cross_entropy(logits, labels, reduction="none")
        if hard_weight is not None:
            weight = hard_weight.float()
            if weight.dim() == 0:
                weight = weight.unsqueeze(0)
            loss_supcon = loss_supcon * weight
            loss_ce = loss_ce * weight

        loss_supcon = loss_supcon.mean()
        loss_ce = loss_ce.mean()

        total = self.supcon_weight * loss_supcon + self.aam_weight * loss_ce
        return {
            "loss": total,
            "supcon_loss": loss_supcon.item(),
            "ce_loss": loss_ce.item(),
        }