# losses.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class SupConLoss(nn.Module):
    """Supervised contrastive loss computed within a single batch.

    Every sample acts as an anchor. Other samples carrying the same label are
    its positives; all other samples in the batch (positives and negatives,
    never the anchor itself) form the softmax denominator. Similarities are
    cosine similarities divided by ``temperature``.
    """

    def __init__(self, temperature: float = 0.07):
        """Store the softmax temperature used to scale the similarities."""
        super().__init__()
        self.temperature = temperature

    def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the mean per-anchor supervised contrastive loss.

        Args:
            features: Embeddings, one row per sample. The code L2-normalizes
                along ``dim=1`` and multiplies by ``features.T``, so a 2-D
                ``(batch, dim)`` tensor is expected.
            labels: One label per sample; flattened to a column internally.

        Returns:
            Scalar tensor. Anchors without any positive contribute ``0`` and
            are still counted in the mean. A batch with fewer than two samples
            yields a zero that remains attached to ``features``.

        Raises:
            ValueError: If the number of labels differs from the batch size.
        """
        batch_size = features.shape[0]
        labels = labels.to(features.device).reshape(-1, 1)
        if labels.shape[0] != batch_size:
            # Without this check a 1-element label tensor would silently
            # broadcast against the identity matrix and give a wrong mask.
            raise ValueError(
                f"Number of labels ({labels.shape[0]}) does not match "
                f"number of features ({batch_size})"
            )

        if batch_size < 2:
            # No other sample to contrast against: the denominator would be
            # empty (log(0)). Return a zero that still participates in autograd.
            return features.sum() * 0.0

        self_mask = torch.eye(batch_size, dtype=torch.bool, device=features.device)
        # Positives: same label, excluding the anchor itself.
        pos_mask = (torch.eq(labels, labels.T) & ~self_mask).float()

        features = F.normalize(features, dim=1)
        sim = torch.matmul(features, features.T) / self.temperature

        # Exclude self-similarity *before* the log-sum-exp. The diagonal is
        # always the row maximum, so stabilising against it lets every other
        # term underflow to zero at small temperatures, producing log(0) and
        # NaN. logsumexp stabilises against the largest remaining entry.
        log_denominator = torch.logsumexp(
            sim.masked_fill(self_mask, float("-inf")), dim=1, keepdim=True
        )
        log_prob = sim - log_denominator

        pos_count = pos_mask.sum(dim=1)
        loss = -(log_prob * pos_mask).sum(dim=1)
        # clamp avoids 0/0 for anchors that have no positive in the batch.
        loss = loss / pos_count.clamp(min=1)
        return loss.mean()


class CombinedLoss(nn.Module):
    """Weighted sum of a supervised contrastive loss and a cross-entropy loss.

    ``supcon_weight`` scales the contrastive term and ``aam_weight`` scales the
    cross-entropy term.
    NOTE(review): the ``aam_`` prefix suggests the logits come from an
    additive-angular-margin head — confirm against the caller.
    """

    def __init__(
        self,
        temperature: float = 0.07,
        supcon_weight: float = 1.0,
        aam_weight: float = 0.3,
    ):
        """Build both sub-losses and remember their mixing weights."""
        super().__init__()
        # Sub-modules are created in the same order as before so module
        # registration order is unchanged.
        self.supcon = SupConLoss(temperature)
        self.ce = nn.CrossEntropyLoss()
        self.supcon_weight, self.aam_weight = supcon_weight, aam_weight

    def forward(
        self,
        embedding: torch.Tensor,
        logits: torch.Tensor,
        labels: torch.Tensor,
        supcon_labels: torch.Tensor,
    ) -> dict:
        """Evaluate both losses and combine them.

        Args:
            embedding: Passed to the contrastive loss together with
                ``supcon_labels``.
            logits: Passed to the cross-entropy loss together with ``labels``.
            labels: Targets for the cross-entropy term.
            supcon_labels: Labels for the contrastive term.

        Returns:
            Dict with ``"loss"`` (the weighted total, a tensor) plus
            ``"supcon_loss"`` and ``"ce_loss"`` (Python numbers obtained via
            ``.item()``, for logging).
        """
        contrastive_term = self.supcon(embedding, supcon_labels)
        classification_term = self.ce(logits, labels)

        weighted_contrastive = self.supcon_weight * contrastive_term
        weighted_classification = self.aam_weight * classification_term

        return dict(
            loss=weighted_contrastive + weighted_classification,
            supcon_loss=contrastive_term.item(),
            ce_loss=classification_term.item(),
        )