train.py 8.94 KB
#!/usr/bin/env python3
"""
ACR Engine - Training script.
"""

import os
import sys
import json
import yaml
import time
import argparse
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

# Put this file's directory at the front of sys.path so the `src.*` imports
# below are resolved from it. This must run before those imports.
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

from src.models.ecapa_tdnn import ECAPA_ACR
from src.models.losses import CombinedLoss
from src.data.dataset import ACRDataset, ACRTestDataset


def collate_fn(batch):
    """Merge dataset items into one batch, zero-padding mels to a common length.

    Args:
        batch: Sequence of dicts, each holding a "mel" tensor, a "song_id"
            tensor and a "song_name" entry.

    Returns:
        Dict with:
            "mel": the padded mels stacked along a new leading dimension,
            "song_id": the song ids stacked along a new leading dimension,
            "song_name": the song names as a list, in batch order.
    """
    specs = [item["mel"] for item in batch]
    ids = [item["song_id"] for item in batch]
    names = [item["song_name"] for item in batch]

    # The target length is the largest size of dim 1 found in the batch.
    # NOTE(review): assumes each mel is 2-D so dim 1 is also the last
    # (padded) dimension -- confirm against the dataset.
    target_len = max(spec.shape[1] for spec in specs)

    def _pad_to_target(spec):
        """Right-pad the last dimension of *spec* with zeros up to target_len."""
        missing = target_len - spec.shape[1]
        if missing <= 0:
            return spec
        return torch.nn.functional.pad(spec, (0, missing))

    return {
        "mel": torch.stack([_pad_to_target(spec) for spec in specs], dim=0),
        "song_id": torch.stack(ids),
        "song_name": names,
    }


def train_epoch(model, loader, optimizer, criterion, scaler, device, epoch, cfg):
    """Train the model for one pass over *loader* and return averaged metrics.

    Args:
        model: Module called as ``model(mel, labels)`` returning
            ``(embedding, logits)``; ``logits`` may be ``None``.
        loader: Iterable of batches with "mel" and "song_id" tensors.
        optimizer: Optimizer stepping ``model.parameters()``.
        criterion: Callable returning a dict with a "loss" tensor plus
            "supcon_loss" and "ce_loss" scalars (tensor or float).
        scaler: Gradient scaler used for the backward/step when truthy;
            a plain backward/step is used otherwise.
        device: ``torch.device`` the batches are moved to.
        epoch: Epoch number, used only for the progress-bar label.
        cfg: Config dict; reads ``cfg["training"]["mixed_precision"]`` and
            ``cfg["training"]["gradient_clip"]``.

    Returns:
        Dict with per-batch means "loss", "supcon_loss", "ce_loss" and the
        overall "accuracy". All values are 0 when the loader yields no batch.
    """
    model.train()

    # Loop-invariant settings, looked up once.
    use_amp = bool(cfg["training"]["mixed_precision"] and device.type == "cuda")
    max_grad_norm = cfg["training"]["gradient_clip"]

    total_loss = 0.0
    total_supcon = 0.0
    total_ce = 0.0
    correct = 0
    total = 0
    steps = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch}")
    for batch in pbar:
        mel = batch["mel"].to(device)
        labels = batch["song_id"].to(device)

        with torch.amp.autocast("cuda", enabled=use_amp):
            embedding, logits = model(mel, labels)
            loss_dict = criterion(embedding, logits, labels, labels)
        loss = loss_dict["loss"]

        optimizer.zero_grad()
        if scaler:
            scaler.scale(loss).backward()
            # Gradients must be unscaled before clipping so the threshold
            # applies to their true magnitude.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

        # Accumulate plain Python floats: float() accepts both numbers and
        # one-element tensors, so the running totals never hold on to an
        # autograd graph if the criterion returns tensors.
        batch_loss = loss.item()
        total_loss += batch_loss
        total_supcon += float(loss_dict["supcon_loss"])
        total_ce += float(loss_dict["ce_loss"])

        if logits is not None:
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
        total += labels.size(0)
        steps += 1

        pbar.set_postfix({
            "loss": f"{batch_loss:.4f}",
            "acc": f"{correct / total if total > 0 else 0:.3f}",
        })

    # An empty loader (e.g. drop_last=True with fewer samples than one batch)
    # must not raise ZeroDivisionError; report zeros like validate() does.
    step_denom = max(steps, 1)
    return {
        "loss": total_loss / step_denom,
        "supcon_loss": total_supcon / step_denom,
        "ce_loss": total_ce / step_denom,
        "accuracy": correct / total if total > 0 else 0,
    }


def validate(model, loader, criterion, device):
    """Evaluate the model on *loader* without gradient tracking.

    Args:
        model: Module called as ``model(mel, labels)`` returning
            ``(embedding, logits)``; ``logits`` may be ``None``.
        loader: Iterable of batches with "mel" and "song_id" tensors.
        criterion: Callable returning a dict with a "loss" tensor.
        device: ``torch.device`` the batches are moved to.

    Returns:
        Dict with "loss" (mean loss per batch, matching how train_epoch
        reports its loss) and "accuracy" (fraction of correct argmax
        predictions). Both are 0 when the loader yields no batch.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    steps = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validating"):
            mel = batch["mel"].to(device)
            labels = batch["song_id"].to(device)

            embedding, logits = model(mel, labels)
            loss_dict = criterion(embedding, logits, labels, labels)

            total_loss += loss_dict["loss"].item()
            if logits is not None:
                preds = logits.argmax(dim=1)
                correct += (preds == labels).sum().item()
            total += labels.size(0)
            steps += 1

    # Report the per-batch mean rather than the raw sum, so the value does not
    # grow with the number of validation batches and is comparable with the
    # training loss.
    mean_loss = total_loss / steps if steps > 0 else 0.0
    acc = correct / total if total > 0 else 0
    print(f"  Validation: loss={mean_loss:.4f}, accuracy={acc:.4f}")
    return {"loss": mean_loss, "accuracy": acc}


def _save_checkpoint(path, epoch, model, optimizer, scheduler, best_acc, cfg):
    """Write a resumable checkpoint (model, optimizer, scheduler, metadata) to *path*."""
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "best_acc": best_acc,
        "config": cfg,
    }, path)
    print(f"  Saved: {path}")


def main():
    """Parse CLI arguments, build data/model/optimizer and run training.

    Command-line options override the matching ``training`` entries of the
    YAML config. With ``--resume`` the model, optimizer, best accuracy and LR
    schedule position are restored from the checkpoint. With ``--dry-run`` a
    single forward/backward pass is executed and the function returns.

    Each epoch that improves validation accuracy overwrites ``best_model.pt``;
    in addition, every ``save_every``-th epoch writes
    ``checkpoint_epoch_<n>.pt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/default.yaml")
    parser.add_argument("--data", type=str, default="data/synthetic")
    parser.add_argument("--output", type=str, default="data/models")
    parser.add_argument("--resume", type=str, default=None)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--batch-size", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--dry-run", action="store_true", help="Run one batch to verify pipeline")
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # Compare against None so that explicit falsy values (e.g. --lr 0) are
    # honoured instead of being silently ignored.
    if args.epochs is not None:
        cfg["training"]["epochs"] = args.epochs
    if args.batch_size is not None:
        cfg["training"]["batch_size"] = args.batch_size
    if args.lr is not None:
        cfg["training"]["lr"] = args.lr

    device_name = args.device
    if device_name == "auto":
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    print(f"Device: {device}")

    print("Loading datasets...")
    train_dataset = ACRDataset(
        args.data, split="train",
        sr=cfg["data"]["sample_rate"],
        n_mels=cfg["model"]["n_mels"],
        n_fft=cfg["data"]["n_fft"],
        hop_length=cfg["data"]["hop_length"],
        segment_dur=cfg["data"]["segment_dur"],
        augment=True,
        n_crops_per_song=cfg["data"]["crop_per_song"],
    )
    val_dataset = ACRDataset(
        args.data, split="val",
        sr=cfg["data"]["sample_rate"],
        n_mels=cfg["model"]["n_mels"],
        n_fft=cfg["data"]["n_fft"],
        hop_length=cfg["data"]["hop_length"],
        segment_dur=cfg["data"]["segment_dur"],
        augment=False,
        n_crops_per_song=1,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg["training"]["batch_size"],
        shuffle=True,
        num_workers=2,
        collate_fn=collate_fn,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg["training"]["batch_size"],
        shuffle=False,
        num_workers=2,
        collate_fn=collate_fn,
    )

    num_classes = len(train_dataset.song_ids)
    print(f"Classes: {num_classes}")
    print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    model = ECAPA_ACR(
        n_mels=cfg["model"]["n_mels"],
        embed_dim=cfg["model"]["embed_dim"],
        channels=cfg["model"]["channels"],
        se_channels=cfg["model"]["se_channels"],
        res2net_scale=cfg["model"]["res2net_scale"],
        num_blocks=cfg["model"]["num_blocks"],
        num_classes=num_classes,
        aam_m=cfg["model"]["aam_m"],
        aam_s=cfg["model"]["aam_s"],
    ).to(device)

    criterion = CombinedLoss(
        temperature=cfg["training"]["temperature"],
        supcon_weight=cfg["training"]["supcon_weight"],
        aam_weight=cfg["training"]["aam_weight"],
    )
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg["training"]["lr"],
        weight_decay=cfg["training"]["weight_decay"],
    )

    # Loss scaling is only useful together with autocast, so gate it on the
    # same condition train_epoch uses for autocast. A missing config key keeps
    # the previous behaviour (enabled whenever running on CUDA).
    use_amp = bool(cfg["training"].get("mixed_precision", True)) and device.type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    start_epoch = 1
    best_acc = float("-inf")
    resume_ckpt = None
    if args.resume:
        resume_ckpt = torch.load(args.resume, map_location=device, weights_only=True)
        model.load_state_dict(resume_ckpt["model_state_dict"])
        optimizer.load_state_dict(resume_ckpt["optimizer_state_dict"])
        start_epoch = resume_ckpt["epoch"] + 1
        # Carry the best score over; otherwise the first resumed epoch would
        # always overwrite best_model.pt, even with a worse model.
        best_acc = resume_ckpt.get("best_acc", best_acc)
        print(f"Resumed from epoch {resume_ckpt['epoch']}")

    if args.dry_run:
        print("Dry run: running one batch through forward/backward...")
        batch = next(iter(train_loader))
        mel = batch["mel"].to(device)
        labels = batch["song_id"].to(device)
        embedding, logits = model(mel, labels)
        loss_dict = criterion(embedding, logits, labels, labels)
        loss_dict["loss"].backward()
        print(f"  Forward/backward OK. Loss: {loss_dict['loss']:.4f}")
        print(f"  Embedding shape: {embedding.shape}")
        print("Dry run passed! Pipeline is working.")
        return

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg["training"]["epochs"]
    )
    if resume_ckpt is not None:
        # Continue the cosine schedule where it stopped instead of restarting
        # it from step 0 with the already-decayed learning rate.
        if "scheduler_state_dict" in resume_ckpt:
            # The stored scheduler settings (including T_max) take precedence.
            scheduler.load_state_dict(resume_ckpt["scheduler_state_dict"])
        else:
            # Checkpoints written before the scheduler state was stored: the
            # scheduler was stepped once per completed epoch, so the epoch
            # number is its position.
            scheduler.last_epoch = resume_ckpt["epoch"]

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    save_every = cfg["training"]["save_every"]

    print("Starting training...")
    for epoch in range(start_epoch, cfg["training"]["epochs"] + 1):
        train_metrics = train_epoch(model, train_loader, optimizer, criterion, scaler, device, epoch, cfg)
        val_metrics = validate(model, val_loader, criterion, device)
        scheduler.step()

        lr = optimizer.param_groups[0]["lr"]
        print(f"  LR: {lr:.2e}")

        # Best-model and periodic checkpoints are independent: an epoch that
        # both improves accuracy and hits the save interval writes both files.
        if val_metrics["accuracy"] > best_acc:
            best_acc = val_metrics["accuracy"]
            _save_checkpoint(
                output_dir / "best_model.pt",
                epoch, model, optimizer, scheduler, best_acc, cfg,
            )
        if epoch % save_every == 0:
            _save_checkpoint(
                output_dir / f"checkpoint_epoch_{epoch}.pt",
                epoch, model, optimizer, scheduler, best_acc, cfg,
            )

    print(f"\nTraining complete. Best validation accuracy: {best_acc:.4f}")
    print(f"Model saved to: {output_dir / 'best_model.pt'}")


# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()