dataset.py 8.33 KB
import json
import random
from pathlib import Path
from typing import Dict, List, Optional

import librosa
import numpy as np
import torch
from torch.utils.data import Dataset


class ACRDataset(Dataset):
    """Random fixed-length mel-spectrogram crops of the songs listed in ``{split}.json``.

    Every metadata entry whose ``audio_path`` exists under ``data_dir`` becomes a
    sample. Each sample is mapped to ``n_crops_per_song`` consecutive dataset
    indices, and every access draws a fresh random crop offset.

    Args:
        data_dir: Directory holding ``{split}.json``; ``audio_path`` entries are
            resolved relative to it.
        split: Metadata file stem (``"train"`` -> ``train.json``).
        sr: Sample rate audio is resampled to on load.
        n_mels: Number of mel bands.
        n_fft: FFT window size.
        hop_length: STFT hop size in samples.
        segment_dur: Crop length in seconds.
        augment: Apply ``AugmentPipeline`` to non-reference samples.
        n_crops_per_song: Dataset indices per sample (the epoch multiplier).
        song_to_idx: Optional external ``song_id -> class index`` mapping, e.g.
            to share class indices across splits. Every sample's ``song_id`` must
            be a key, otherwise ``__getitem__`` raises ``KeyError``. When ``None``,
            the mapping is built from the sorted visible song ids.
        references_only: Keep only entries whose ``type`` is ``"reference"``.
    """

    def __init__(
        self,
        data_dir: str,
        split: str = "train",
        sr: int = 16000,
        n_mels: int = 80,
        n_fft: int = 512,
        hop_length: int = 160,
        segment_dur: float = 5.0,
        augment: bool = True,
        n_crops_per_song: int = 4,
        song_to_idx: Optional[Dict[str, int]] = None,
        references_only: bool = False,
    ):
        self.sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        # Keep the duration in seconds: it is needed for offset sampling and loading.
        self.segment_dur = segment_dur
        self.segment_len = int(segment_dur * sr)
        self.augment = augment
        self.n_crops = n_crops_per_song
        self.data_dir = Path(data_dir)

        meta_path = self.data_dir / f"{split}.json"
        with open(meta_path) as f:
            self.metadata = json.load(f)

        self.samples = []
        for item in self.metadata:
            if references_only and item.get("type") != "reference":
                continue
            # Silently skip entries whose audio file is missing on disk.
            if (self.data_dir / item["audio_path"]).exists():
                self.samples.append(item)

        self.song_ids = sorted({s["song_id"] for s in self.samples})
        # `is not None` (not `or`) so an explicitly supplied mapping is always honoured.
        if song_to_idx is not None:
            self.song_to_idx = song_to_idx
        else:
            self.song_to_idx = {sid: i for i, sid in enumerate(self.song_ids)}

    def __len__(self):
        return len(self.samples) * self.n_crops

    def _sample_offset(self, duration: float) -> float:
        """Return a uniform random start time (seconds) such that one segment fits.

        Returns 0 when ``duration`` is not longer than ``segment_dur``.
        """
        max_offset = max(0.0, duration - self.segment_dur)
        return random.uniform(0, max_offset) if max_offset > 0 else 0

    def _load_segment(self, path: str, offset: float, duration: float) -> np.ndarray:
        """Load ``duration`` seconds from ``offset`` as mono, then pad/trim to ``segment_len``."""
        y, _ = librosa.load(path, sr=self.sr, mono=True, offset=offset, duration=duration)
        if len(y) < self.segment_len:
            y = np.pad(y, (0, self.segment_len - len(y)))
        else:
            y = y[: self.segment_len]
        return y

    def _to_mel(self, y: np.ndarray) -> np.ndarray:
        """Return the dB-scaled mel spectrogram of ``y``, referenced to its own maximum."""
        mel = librosa.feature.melspectrogram(
            y=y,
            sr=self.sr,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
        )
        return librosa.power_to_db(mel, ref=np.max)

    def __getitem__(self, idx):
        """Return a dict with ``mel``, ``song_id`` (class index), ``song_name`` and ``type``."""
        sample = self.samples[idx // self.n_crops]
        offset = self._sample_offset(sample["duration"])

        audio_path = self.data_dir / sample["audio_path"]
        # Load exactly one segment's worth (previously hard-coded to 5 s).
        y = self._load_segment(str(audio_path), offset, self.segment_dur)

        # Reference recordings are never augmented.
        if self.augment and sample.get("type") != "reference":
            from src.utils.augment import AugmentPipeline
            aug = AugmentPipeline(self.sr)
            y = aug(y)

        mel_tensor = torch.FloatTensor(self._to_mel(y))

        song_id = sample["song_id"]
        class_id = self.song_to_idx[song_id]

        return {
            "mel": mel_tensor,
            "song_id": torch.tensor(class_id, dtype=torch.long),
            "song_name": song_id,
            "type": sample.get("type", "unknown"),
        }


class ACRTestDataset(Dataset):
    """Deterministic evaluation clips: the first ``segment_dur`` seconds of each file.

    Entries of ``{split}.json`` whose ``audio_path`` does not exist under
    ``data_dir`` are skipped. There is no augmentation and no random cropping.

    Args:
        data_dir: Directory holding ``{split}.json``; ``audio_path`` entries are
            resolved relative to it.
        split: Metadata file stem (``"test"`` -> ``test.json``).
        sr: Sample rate audio is resampled to on load.
        n_mels: Number of mel bands.
        n_fft: FFT window size.
        hop_length: STFT hop size in samples.
        song_to_idx: Optional external ``song_id -> class index`` mapping, e.g.
            the training split's. Every sample's ``song_id`` must be a key,
            otherwise ``__getitem__`` raises ``KeyError``. When ``None``, the
            mapping is built from the sorted visible song ids.
        segment_dur: Clip length in seconds. The default of 5.0 matches the
            previously fixed length. It is placed last so existing positional
            calls keep working.
    """

    def __init__(
        self,
        data_dir: str,
        split: str = "test",
        sr: int = 16000,
        n_mels: int = 80,
        n_fft: int = 512,
        hop_length: int = 160,
        song_to_idx: Optional[Dict[str, int]] = None,
        segment_dur: float = 5.0,
    ):
        self.sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.segment_dur = segment_dur
        self.segment_len = int(segment_dur * sr)
        self.data_dir = Path(data_dir)

        meta_path = self.data_dir / f"{split}.json"
        with open(meta_path) as f:
            self.metadata = json.load(f)

        self.samples = [
            item for item in self.metadata if (self.data_dir / item["audio_path"]).exists()
        ]

        self.song_ids = sorted({s["song_id"] for s in self.samples})
        # `is not None` (not `or`) so an explicitly supplied mapping is always honoured.
        if song_to_idx is not None:
            self.song_to_idx = song_to_idx
        else:
            self.song_to_idx = {sid: i for i, sid in enumerate(self.song_ids)}

    def __len__(self):
        return len(self.samples)

    def _fix_length(self, y: np.ndarray) -> np.ndarray:
        """Zero-pad at the end, or truncate, ``y`` to exactly ``segment_len`` samples."""
        if len(y) < self.segment_len:
            return np.pad(y, (0, self.segment_len - len(y)))
        return y[: self.segment_len]

    def _to_mel(self, y: np.ndarray) -> np.ndarray:
        """Return the dB-scaled mel spectrogram of ``y``, referenced to its own maximum."""
        mel = librosa.feature.melspectrogram(
            y=y,
            sr=self.sr,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
        )
        return librosa.power_to_db(mel, ref=np.max)

    def __getitem__(self, idx):
        """Return a dict with ``mel``, ``song_id`` (class index), ``song_name`` and ``type``."""
        sample = self.samples[idx]
        audio_path = self.data_dir / sample["audio_path"]
        # Never request more than the metadata says the file contains.
        y, _ = librosa.load(
            str(audio_path),
            sr=self.sr,
            mono=True,
            offset=0,
            duration=min(sample["duration"], self.segment_dur),
        )
        mel = self._to_mel(self._fix_length(y))

        class_id = self.song_to_idx[sample["song_id"]]
        return {
            "mel": torch.FloatTensor(mel),
            "song_id": torch.tensor(class_id, dtype=torch.long),
            "song_name": sample["song_id"],
            "type": sample.get("type", "unknown"),
        }


class SongPairDataset(Dataset):
    """Positive pairs: two clips of the same song, each turned into a mel spectrogram.

    Entries with ``type == "reference"`` and entries whose audio file is missing
    are excluded. Songs that contain ``"confused"`` clips appear 5x in the
    sampling list, songs with ``"humming_like"`` clips 3x, and all others 1x.
    This oversamples the hard songs within an epoch.

    Args:
        data_dir: Directory holding ``{split}.json``; ``audio_path`` entries are
            resolved relative to it.
        split: Metadata file stem.
        sr: Sample rate audio is resampled to on load.
        n_mels: Number of mel bands.
        n_fft: FFT window size.
        hop_length: STFT hop size in samples.
        segment_dur: Clip length in seconds. Clips are read from the start of
            the file.
        augment: Apply ``AugmentPipeline`` to both clips. The aggressive variant
            is used for ``"confused"``/``"humming_like"`` clips.
    """

    # Per-clip loss weights returned as ``hard_weight``; unknown types get 1.0.
    _PAIR_TYPE_WEIGHTS: Dict[str, float] = {
        "confused": 4.0,
        "humming_like": 2.5,
        "augmented": 1.4,
    }
    # Clip types that trigger aggressive augmentation.
    _AGGRESSIVE_TYPES = frozenset({"confused", "humming_like"})

    def __init__(
        self,
        data_dir: str,
        split: str = "train",
        sr: int = 16000,
        n_mels: int = 80,
        n_fft: int = 512,
        hop_length: int = 160,
        segment_dur: float = 5.0,
        augment: bool = True,
    ):
        self.sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        # Keep the duration in seconds: it is needed when loading clips.
        self.segment_dur = segment_dur
        self.segment_len = int(segment_dur * sr)
        self.augment = augment
        self.data_dir = Path(data_dir)

        with open(self.data_dir / f"{split}.json") as f:
            metadata = json.load(f)

        # song_id -> list of its (non-reference, existing) metadata entries.
        self.by_song: Dict[str, List[Dict]] = {}
        for item in metadata:
            if item.get("type") == "reference":
                continue
            if (self.data_dir / item["audio_path"]).exists():
                self.by_song.setdefault(item["song_id"], []).append(item)

        self.song_ids = sorted(self.by_song)
        # Repeat each song id according to its hardest clip type (oversampling).
        self.sample_song_ids = []
        for sid, items in self.by_song.items():
            item_types = {x.get("type") for x in items}
            if "confused" in item_types:
                weight = 5
            elif "humming_like" in item_types:
                weight = 3
            else:
                weight = 1
            self.sample_song_ids.extend([sid] * weight)
        self.song_to_idx = {sid: i for i, sid in enumerate(self.song_ids)}

    def __len__(self):
        return len(self.sample_song_ids)

    @classmethod
    def _pair_weights(cls, a: Dict, b: Dict) -> List[float]:
        """Return the loss weights of clips ``a`` and ``b`` based on their ``type``."""
        return [cls._PAIR_TYPE_WEIGHTS.get(s.get("type", "unknown"), 1.0) for s in (a, b)]

    def _load_clip(self, sample: Dict) -> np.ndarray:
        """Load the first ``segment_dur`` seconds of ``sample`` as mono, padded/trimmed to ``segment_len``."""
        path = self.data_dir / sample["audio_path"]
        # Previously hard-coded to 5 s regardless of segment_dur.
        y, _ = librosa.load(str(path), sr=self.sr, mono=True, duration=self.segment_dur)
        if len(y) < self.segment_len:
            y = np.pad(y, (0, self.segment_len - len(y)))
        else:
            y = y[: self.segment_len]
        return y

    def _to_mel(self, y: np.ndarray) -> torch.Tensor:
        """Return the dB-scaled mel spectrogram of ``y`` as a float tensor."""
        mel = librosa.feature.melspectrogram(
            y=y,
            sr=self.sr,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
        )
        mel = librosa.power_to_db(mel, ref=np.max)
        return torch.FloatTensor(mel)

    def __getitem__(self, idx):
        """Return a stacked pair of mels (dim 0 has size 2), labels, song name and hard weights."""
        song_id = self.sample_song_ids[idx]
        choices = self.by_song[song_id]
        # A song with a single clip is paired with itself; augmentation is the
        # only thing that makes the two views differ.
        if len(choices) == 1:
            a = b = choices[0]
        else:
            a, b = random.sample(choices, 2)

        pair_weights = self._pair_weights(a, b)

        if self.augment:
            from src.utils.augment import AugmentPipeline

        mels = []
        for sample in (a, b):
            y = self._load_clip(sample)
            if self.augment:
                aggressive = sample.get("type") in self._AGGRESSIVE_TYPES
                y = AugmentPipeline(self.sr, aggressive=aggressive)(y)
            mels.append(self._to_mel(y))

        # Defensive: right-pad along the time axis so the two mels can be stacked.
        max_t = max(m.shape[1] for m in mels)
        mels = [
            torch.nn.functional.pad(m, (0, max_t - m.shape[1])) if m.shape[1] < max_t else m
            for m in mels
        ]

        label = self.song_to_idx[song_id]
        return {
            "mel": torch.stack(mels, dim=0),
            "song_id": torch.tensor([label, label], dtype=torch.long),
            "song_name": song_id,
            "hard_weight": torch.tensor(pair_weights, dtype=torch.float32),
        }