# losses.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class InfoNCELoss(nn.Module):
    """Label-aware contrastive (InfoNCE-style) loss over one batch.

    Rows of ``features`` are L2-normalised and compared with each other via
    temperature-scaled dot products. For every anchor row, the other rows
    that carry the same label are treated as positives; all other rows
    (positives and negatives, but never the anchor itself) form the softmax
    denominator.

    Args:
        temperature: Divisor applied to the pairwise similarities.
    """

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the per-anchor loss (no reduction over the batch).

        Args:
            features: 2-D tensor with one embedding per row.
            labels: Tensor with one label per row of ``features``; it is
                flattened to a column before comparison.

        Returns:
            1-D tensor with one loss value per row of ``features``. Anchors
            that have no other row with the same label contribute ``0``.
        """
        features = F.normalize(features, dim=1)
        batch_size = features.size(0)
        logits = torch.matmul(features, features.T) / self.temperature

        labels = labels.contiguous().view(-1, 1)
        # Built once and reused: marks each anchor's similarity with itself.
        self_mask = torch.eye(batch_size, dtype=torch.bool, device=features.device)
        positive_mask = torch.eq(labels, labels.T).to(features.device) & ~self_mask

        # Stabilise with the largest *off-diagonal* logit per row. The
        # self-similarity (always 1 / temperature) is excluded from the
        # denominator, so shifting by it can make every remaining exp()
        # underflow to zero when the temperature is small.
        neg_inf = float("-inf")
        row_max = (
            logits.detach()
            .masked_fill(self_mask, neg_inf)
            .max(dim=1, keepdim=True)
            .values
        )
        # A batch of one has no off-diagonal entry; fall back to a zero shift.
        row_max = torch.where(
            torch.isfinite(row_max), row_max, torch.zeros_like(row_max)
        )
        logits = logits - row_max

        # Mask before exp(): the shifted diagonal may be large, and masking
        # afterwards would leave a 0 * inf term in the backward pass.
        exp_logits = torch.exp(logits.masked_fill(self_mask, neg_inf))
        denominator = exp_logits.sum(dim=1, keepdim=True)
        # With two or more rows the denominator is >= 1 after the shift, so
        # the clamp only protects the degenerate single-row case from log(0).
        denominator = denominator.clamp_min(torch.finfo(denominator.dtype).tiny)
        log_prob = logits - torch.log(denominator)

        positive_mask = positive_mask.to(log_prob.dtype)
        # clamp(min=1) avoids 0 / 0 for anchors without any positive.
        positives = positive_mask.sum(dim=1).clamp(min=1)
        return -((positive_mask * log_prob).sum(dim=1) / positives)


class CombinedLoss(nn.Module):
    def __init__(
        self,
        temperature: float = 0.07,
        supcon_weight: float = 1.0,
        aam_weight: float = 0.3,
    ):
        super().__init__()
        self.infonce = InfoNCELoss(temperature)
        self.supcon_weight = supcon_weight
        self.aam_weight = aam_weight

    def forward(
        self,
        embedding: torch.Tensor,
        logits: torch.Tensor,
        labels: torch.Tensor,
        supcon_labels: torch.Tensor,
        hard_weight: torch.Tensor | None = None,
    ) -> dict:
        loss_infonce = self.infonce(embedding, supcon_labels)
        loss_ce = F.cross_entropy(logits, labels, reduction="none")
        if hard_weight is not None:
            weight = hard_weight.float()
            if weight.dim() == 0:
                weight = weight.unsqueeze(0)
            loss_infonce = loss_infonce * weight
            loss_ce = loss_ce * weight

        loss_infonce = loss_infonce.mean()
        loss_ce = loss_ce.mean()
        total = self.supcon_weight * loss_infonce + self.aam_weight * loss_ce
        return {
            "loss": total,
            "supcon_loss": loss_infonce.item(),
            "ce_loss": loss_ce.item(),
        }