train.py 9.17 KB
#!/usr/bin/env python3
"""ACR Engine - Training script."""

import argparse
import json
import sys
from pathlib import Path

import torch
import yaml
from torch.utils.data import DataLoader
from tqdm import tqdm

# Put the directory containing this script at the front of sys.path so the
# `src.*` imports that follow resolve when the script is run directly.
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

from src.data.dataset import SongPairDataset, ACRDataset
from src.models.ecapa_tdnn import ECAPA_ACR
from src.models.losses import CombinedLoss


def _take_crop(value, index=None):
    """Return ``value`` as a tensor, selecting element ``index`` when it is per-crop.

    Scalars (Python numbers or 0-dim tensors) are shared by every crop of a
    sample, so they are returned as-is instead of being indexed. Converting via
    ``torch.as_tensor`` lets plain ``int``/``float`` values be stacked later.
    """
    tensor = torch.as_tensor(value)
    if index is not None and tensor.dim() > 0:
        return tensor[index]
    return tensor


def collate_fn(batch):
    """Collate dataset samples into one zero-padded batch.

    Each sample is a dict with ``"mel"``, ``"song_id"``, ``"song_name"`` and an
    optional ``"hard_weight"`` (defaults to ``torch.tensor(1.0)``). A 2-D
    ``mel`` is a single example; a 3-D ``mel`` holds several crops along dim 0
    and is flattened into one example per crop. For multi-crop samples,
    ``song_id`` / ``hard_weight`` may be per-crop (indexed by crop) or a single
    scalar shared by all crops; ``song_name`` is repeated for every crop.

    Every mel is right-padded with zeros along its last dimension to the
    longest one in the batch (presumably the time axis — verify against the
    dataset).

    Returns:
        dict with ``"mel"`` (stacked padded mels), ``"song_id"`` and
        ``"hard_weight"`` (stacked tensors) and ``"song_name"`` (list).

    Raises:
        ValueError: if the batch contains no examples.
    """
    mels, song_ids, song_names, hard_weights = [], [], [], []
    for sample in batch:
        mel = sample["mel"]
        hard_weight = sample.get("hard_weight", torch.tensor(1.0))
        # A 3-D mel stacks several crops along dim 0; a 2-D mel is one example.
        crop_indices = range(mel.shape[0]) if mel.dim() == 3 else [None]
        for index in crop_indices:
            mels.append(mel if index is None else mel[index])
            song_ids.append(_take_crop(sample["song_id"], index))
            song_names.append(sample["song_name"])
            hard_weights.append(_take_crop(hard_weight, index))

    if not mels:
        raise ValueError("collate_fn received an empty batch")

    max_t = max(m.shape[-1] for m in mels)
    padded = [torch.nn.functional.pad(m, (0, max_t - m.shape[-1])) for m in mels]

    return {
        "mel": torch.stack(padded, dim=0),
        "song_id": torch.stack(song_ids),
        "song_name": song_names,
        "hard_weight": torch.stack(hard_weights),
    }


def train_epoch(model, loader, optimizer, criterion, scaler, device, epoch, cfg):
    """Run one training epoch and return averaged metrics.

    Args:
        model: Module called as ``model(mel, labels)``, returning
            ``(embedding, logits)``; ``logits`` may be ``None``.
        loader: Iterable of batch dicts with ``"mel"``, ``"song_id"`` and
            optionally ``"hard_weight"``.
        optimizer: Optimizer stepped once per batch.
        criterion: Callable returning a dict with ``"loss"``,
            ``"supcon_loss"`` and ``"ce_loss"``.
        scaler: Optional gradient scaler; when falsy a plain
            backward/clip/step is used instead.
        device: Device the batch tensors are moved to.
        epoch: Epoch number, used only for the progress-bar label.
        cfg: Config dict; reads ``cfg["training"]["mixed_precision"]`` and
            ``cfg["training"]["gradient_clip"]``.

    Returns:
        dict with per-step means of ``loss``, ``supcon_loss`` and ``ce_loss``,
        and ``accuracy`` over all samples (0.0 when ``logits`` is ``None``).
    """
    model.train()
    train_cfg = cfg["training"]
    # Constant for the whole epoch; autocast is only enabled on CUDA devices.
    use_amp = train_cfg["mixed_precision"] and device.type == "cuda"
    grad_clip = train_cfg["gradient_clip"]

    total_loss = total_supcon = total_ce = 0.0
    correct = total = steps = 0
    pbar = tqdm(loader, desc=f"Epoch {epoch}")
    for batch in pbar:
        mel = batch["mel"].to(device)
        labels = batch["song_id"].to(device)
        hard_weight = batch.get("hard_weight")
        if hard_weight is not None:
            hard_weight = hard_weight.to(device)

        with torch.amp.autocast("cuda", enabled=use_amp):
            embedding, logits = model(mel, labels)
            loss_dict = criterion(embedding, logits, labels, labels, hard_weight)
        loss = loss_dict["loss"]

        optimizer.zero_grad()
        if scaler:
            scaler.scale(loss).backward()
            # Unscale before clipping so the norm is computed on true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        # Fetch the scalar once and reuse it for the running total and the progress bar.
        loss_value = float(loss.item())
        total_loss += loss_value
        total_supcon += float(loss_dict["supcon_loss"])
        total_ce += float(loss_dict["ce_loss"])
        if logits is not None:
            correct += int((logits.argmax(dim=1) == labels).sum().item())
        total += labels.size(0)
        steps += 1
        pbar.set_postfix({"loss": f"{loss_value:.4f}", "acc": f"{correct / max(total, 1):.3f}"})

    denom = max(steps, 1)
    return {
        "loss": total_loss / denom,
        "supcon_loss": total_supcon / denom,
        "ce_loss": total_ce / denom,
        "accuracy": correct / max(total, 1),
    }


def save_checkpoint(output_dir, epoch, model, optimizer, best_metric, cfg, name):
    """Serialize the training state to ``output_dir / name``.

    The saved dict has the keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict``, ``best_metric`` and ``config``. The data is first
    written to a sibling ``<name>.tmp`` file and then moved over the target.
    An interrupted save therefore never replaces an existing checkpoint with
    a truncated one.

    Args:
        output_dir: Existing destination directory (``str`` or ``Path``).
        epoch: Epoch number stored in the checkpoint.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        best_metric: Best metric value so far, stored as-is.
        cfg: Configuration object stored as-is.
        name: File name of the checkpoint inside ``output_dir``.
    """
    path = Path(output_dir) / name
    tmp_path = path.with_name(path.name + ".tmp")
    state = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "best_metric": best_metric,
        "config": cfg,
    }
    try:
        torch.save(state, tmp_path)
        # os.replace semantics: overwrites the target in one rename (atomic on POSIX).
        tmp_path.replace(path)
    finally:
        # No-op after a successful replace; removes a partial temp file on failure.
        tmp_path.unlink(missing_ok=True)
    print(f"  Saved: {path}")


def _resolve_device(requested):
    """Return the torch.device for the ``--device`` flag.

    ``"auto"`` selects CUDA when ``torch.cuda.is_available()``, otherwise CPU;
    any other value is passed to ``torch.device`` unchanged.
    """
    if requested == "auto":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(requested)


def _apply_cli_overrides(cfg, args):
    """Copy explicitly given ``--epochs``/``--batch-size``/``--lr`` into ``cfg["training"]``."""
    overrides = {"epochs": args.epochs, "batch_size": args.batch_size, "lr": args.lr}
    for key, value in overrides.items():
        # Compare against None (the argparse default) so an explicit 0 is not ignored.
        if value is not None:
            cfg["training"][key] = value


def main():
    """Parse CLI flags, build data/model/optimizer, and run the training loop.

    ``--dry-run`` pushes one batch through forward/backward and returns before
    any training or file output. ``--resume`` restores the following from a
    checkpoint written by ``save_checkpoint``:
    model/optimizer state, the best loss so far, and the position on the
    cosine LR schedule.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/default.yaml")
    parser.add_argument("--data", type=str, default="data/synthetic")
    parser.add_argument("--output", type=str, default="data/models")
    parser.add_argument("--resume", type=str, default=None)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--batch-size", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    _apply_cli_overrides(cfg, args)
    data_cfg = cfg["data"]
    model_cfg = cfg["model"]
    train_cfg = cfg["training"]

    device = _resolve_device(args.device)
    print(f"Device: {device}")

    # Feature-extraction settings shared by the training and catalog datasets.
    feature_kwargs = {
        "sr": data_cfg["sample_rate"],
        "n_mels": model_cfg["n_mels"],
        "n_fft": data_cfg["n_fft"],
        "hop_length": data_cfg["hop_length"],
        "segment_dur": data_cfg["segment_dur"],
    }
    train_dataset = SongPairDataset(args.data, split="train", augment=True, **feature_kwargs)
    catalog_dataset = ACRDataset(
        args.data,
        split="train",
        augment=False,
        n_crops_per_song=1,
        song_to_idx=train_dataset.song_to_idx,
        **feature_kwargs,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=train_cfg["batch_size"],
        shuffle=True,
        num_workers=0,
        collate_fn=collate_fn,
        drop_last=False,
    )

    # Load the dry-run batch once and reuse it for the forward/backward check below.
    dry_batch = None
    if args.dry_run:
        dry_batch = next(iter(train_loader))
        print("Dry batch shape:", dry_batch["mel"].shape, dry_batch["song_id"].shape)

    num_classes = len(train_dataset.song_ids)
    print(f"Classes: {num_classes}")
    print(f"Train songs: {len(train_dataset)}")

    model = ECAPA_ACR(
        n_mels=model_cfg["n_mels"],
        embed_dim=model_cfg["embed_dim"],
        channels=model_cfg["channels"],
        se_channels=model_cfg["se_channels"],
        res2net_scale=model_cfg["res2net_scale"],
        num_blocks=model_cfg["num_blocks"],
        num_classes=num_classes,
        aam_m=model_cfg["aam_m"],
        aam_s=model_cfg["aam_s"],
        use_band_split=model_cfg.get("use_band_split", True),
        band_split_channels=model_cfg.get("band_split_channels", 128),
    ).to(device)

    criterion = CombinedLoss(
        temperature=train_cfg["temperature"],
        supcon_weight=train_cfg["supcon_weight"],
        aam_weight=train_cfg["aam_weight"],
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=train_cfg["lr"], weight_decay=train_cfg["weight_decay"])
    scaler = torch.amp.GradScaler("cuda", enabled=device.type == "cuda")

    if args.dry_run:
        print("Dry run: running one batch through forward/backward...")
        mel = dry_batch["mel"].to(device)
        labels = dry_batch["song_id"].to(device)
        hard_weight = dry_batch.get("hard_weight")
        if hard_weight is not None:
            hard_weight = hard_weight.to(device)
        embedding, logits = model(mel, labels)
        loss_dict = criterion(embedding, logits, labels, labels, hard_weight)
        loss_dict["loss"].backward()
        print(f"  Forward/backward OK. Loss: {loss_dict['loss']:.4f}")
        print(f"  Embedding shape: {embedding.shape}")
        print("Dry run passed! Pipeline is working.")
        return

    start_epoch = 1
    best_loss = float("inf")
    if args.resume:
        ckpt = torch.load(args.resume, map_location=device, weights_only=True)
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        start_epoch = ckpt["epoch"] + 1
        # Restore the previous best; otherwise the first resumed epoch always
        # "improves" on inf and overwrites best_model.pt with a possibly worse model.
        best_loss = ckpt.get("best_metric", best_loss)
        print(f"Resumed from epoch {ckpt['epoch']}")

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=train_cfg["epochs"])
    if start_epoch > 1:
        # The loaded optimizer already holds the LR reached after the checkpointed
        # epoch. Move the scheduler's epoch counter there as well, so the next
        # step() continues the original cosine curve instead of restarting it at epoch 0.
        scheduler.last_epoch = start_epoch - 1

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Write the label mapping up front so checkpoints stay usable even if training crashes.
    with open(output_dir / "song_to_idx.json", "w", encoding="utf-8") as f:
        json.dump(train_dataset.song_to_idx, f, indent=2)

    print("Starting training...")
    for epoch in range(start_epoch, train_cfg["epochs"] + 1):
        train_metrics = train_epoch(model, train_loader, optimizer, criterion, scaler, device, epoch, cfg)
        scheduler.step()
        print(f"  Train loss={train_metrics['loss']:.4f} acc={train_metrics['accuracy']:.4f} lr={optimizer.param_groups[0]['lr']:.2e}")
        if train_metrics["loss"] < best_loss:
            best_loss = train_metrics["loss"]
            save_checkpoint(output_dir, epoch, model, optimizer, best_loss, cfg, "best_model.pt")
        if epoch % train_cfg["save_every"] == 0:
            save_checkpoint(output_dir, epoch, model, optimizer, best_loss, cfg, f"checkpoint_epoch_{epoch}.pt")

    print(f"\nTraining complete. Best training loss: {best_loss:.4f}")
    print(f"Model saved to: {output_dir / 'best_model.pt'}")
    print(f"Catalog references available: {len(catalog_dataset.samples)}")


# Run training only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()