# dataset.py 12 KB
import json
import random
from pathlib import Path
from typing import Dict, List, Optional

import librosa
import numpy as np
import torch
from torch.utils.data import Dataset


class ACRDataset(Dataset):
    """Fixed-length log-mel crops of the audio files listed in a JSON manifest.

    Every manifest entry whose audio file exists on disk contributes
    ``n_crops_per_song`` dataset items; each item is an independently chosen
    crop of that file.
    """

    def __init__(
        self,
        data_dir: str,
        split: str = "train",
        sr: int = 16000,
        n_mels: int = 80,
        n_fft: int = 512,
        hop_length: int = 160,
        segment_dur: float = 5.0,
        augment: bool = True,
        n_crops_per_song: int = 4,
        song_to_idx: Optional[Dict[str, int]] = None,
        references_only: bool = False,
        segment_strategy: str = "random",
        silence_top_db: int = 30,
    ):
        """Read ``<data_dir>/<split>.json`` and index the entries found on disk.

        Args:
            data_dir: Directory holding the manifest. When it is named
                ``manifests`` the audio paths are resolved against its
                parent directory, otherwise against ``data_dir`` itself.
            split: Manifest file stem.
            sr: Sample rate audio is loaded at.
            n_mels: Number of mel bands.
            n_fft: FFT size of the mel spectrogram.
            hop_length: Hop length of the mel spectrogram.
            segment_dur: Crop length in seconds.
            augment: Whether non-reference items are passed through
                ``AugmentPipeline``.
            n_crops_per_song: Dataset items generated per manifest entry.
            song_to_idx: Optional song-id -> class-index mapping; built from
                the sorted song ids when omitted (or empty).
            references_only: Keep only entries whose ``type`` is
                ``"reference"``.
            segment_strategy: ``"random"``, ``"silence_aware"`` or
                ``"hybrid"``; any other value behaves like ``"random"``.
            silence_top_db: ``top_db`` passed to ``librosa.effects.split``.

        Raises:
            OSError: If the manifest cannot be opened.
            KeyError: If a kept entry lacks ``audio_path`` or ``song_id``.
        """
        self.sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.segment_len = int(segment_dur * sr)
        self.augment = augment
        self.n_crops = n_crops_per_song
        self.segment_strategy = segment_strategy
        self.silence_top_db = silence_top_db
        self.data_dir = Path(data_dir)
        self.asset_root = self.data_dir.parent if self.data_dir.name == "manifests" else self.data_dir

        meta_path = self.data_dir / f"{split}.json"
        with open(meta_path, encoding="utf-8") as f:
            self.metadata = json.load(f)

        # Entries whose audio file is missing are dropped silently.
        self.samples = [
            item
            for item in self.metadata
            if not (references_only and item.get("type") != "reference")
            and (self.asset_root / item["audio_path"]).exists()
        ]

        self.song_ids = sorted(set(s["song_id"] for s in self.samples))
        self.song_to_idx = song_to_idx or {sid: i for i, sid in enumerate(self.song_ids)}

    @property
    def _segment_seconds(self) -> float:
        """Crop length in seconds, derived from the crop length in samples."""
        return self.segment_len / self.sr

    def __len__(self):
        """Return the number of items: kept manifest entries times crops per entry."""
        return len(self.samples) * self.n_crops

    def _load_segment(self, path: str, offset: float, duration: float) -> np.ndarray:
        """Load ``duration`` seconds from ``offset`` and force exactly ``segment_len`` samples.

        Short reads are zero-padded at the end; long reads are truncated.
        """
        y, _ = librosa.load(path, sr=self.sr, mono=True, offset=offset, duration=duration)
        if len(y) < self.segment_len:
            y = np.pad(y, (0, self.segment_len - len(y)))
        else:
            y = y[: self.segment_len]
        return y

    def _to_mel(self, y: np.ndarray) -> np.ndarray:
        """Return the mel power spectrogram of ``y`` in dB relative to its own maximum."""
        mel = librosa.feature.melspectrogram(
            y=y,
            sr=self.sr,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
        )
        return librosa.power_to_db(mel, ref=np.max)

    def _find_non_silent_intervals(self, y: np.ndarray) -> List[tuple[int, int]]:
        """Return ``(start, end)`` sample ranges of non-silent audio.

        Falls back to a single interval spanning all of ``y`` when
        ``librosa.effects.split`` finds nothing.
        """
        intervals = librosa.effects.split(y, top_db=self.silence_top_db)
        if intervals is None or len(intervals) == 0:
            return [(0, len(y))]
        return [(int(s), int(e)) for s, e in intervals]

    def _offset_in_interval(self, interval: tuple[int, int]) -> float:
        """Pick a random crop start (in seconds) so the crop fits inside ``interval``."""
        start, end = interval
        last_start = max(start, end - self.segment_len)
        chosen = random.randint(start, last_start) if last_start > start else start
        return chosen / self.sr

    def _choose_offset(self, sample: Dict, audio_path: Path) -> float:
        """Choose the crop start offset, in seconds, for one manifest entry.

        The latest allowed offset is the entry's ``duration`` minus the crop
        length, so the crop length configured via ``segment_dur`` is honoured.
        ``"silence_aware"`` always crops inside a sufficiently long non-silent
        interval when one exists; ``"hybrid"`` does so with probability 0.7.
        In every other case the offset is uniform over the allowed range.

        Raises:
            KeyError: If ``sample`` has no ``duration``.
        """
        duration = float(sample["duration"])
        max_offset = max(0.0, duration - self._segment_seconds)
        if max_offset <= 0:
            return 0.0

        # Only the interval-based strategies need the decoded waveform; every
        # other value ends in the uniform draw, so skip the full-file load.
        if self.segment_strategy not in ("silence_aware", "hybrid"):
            return random.uniform(0, max_offset)

        y, _ = librosa.load(str(audio_path), sr=self.sr, mono=True)
        candidates = [
            (start, end)
            for start, end in self._find_non_silent_intervals(y)
            if end - start >= self.segment_len
        ]
        # The 0.7 draw is taken only for "hybrid" and only when a candidate exists.
        if candidates and (self.segment_strategy == "silence_aware" or random.random() < 0.7):
            return min(self._offset_in_interval(random.choice(candidates)), max_offset)
        return random.uniform(0, max_offset)

    def __getitem__(self, idx):
        """Return one crop as a dict with ``mel``, ``song_id``, ``song_name`` and ``type``.

        Consecutive indices map to the same manifest entry in groups of
        ``n_crops``; the crop position is chosen afresh on every call.
        """
        sample = self.samples[idx // self.n_crops]

        audio_path = self.asset_root / sample["audio_path"]
        offset = self._choose_offset(sample, audio_path)
        y = self._load_segment(str(audio_path), offset, self._segment_seconds)

        if self.augment and sample.get("type") != "reference":
            # Local import: only needed when augmentation actually runs.
            from src.utils.augment import AugmentPipeline

            y = AugmentPipeline(self.sr)(y)

        mel_tensor = torch.FloatTensor(self._to_mel(y))

        song_id = sample["song_id"]
        class_id = self.song_to_idx[song_id]

        return {
            "mel": mel_tensor,
            "song_id": torch.tensor(class_id, dtype=torch.long),
            "song_name": song_id,
            "type": sample.get("type", "unknown"),
        }


class ACRTestDataset(Dataset):
    """Deterministic evaluation dataset: the leading segment of every manifest entry.

    Unlike the training datasets there is no random cropping and no
    augmentation; each item is the first ``segment_dur`` seconds of its file.
    """

    def __init__(
        self,
        data_dir: str,
        split: str = "test",
        sr: int = 16000,
        n_mels: int = 80,
        n_fft: int = 512,
        hop_length: int = 160,
        song_to_idx: Optional[Dict[str, int]] = None,
        segment_dur: float = 5.0,
    ):
        """Read ``<data_dir>/<split>.json`` and index the entries found on disk.

        Args:
            data_dir: Directory holding the manifest. When it is named
                ``manifests`` the audio paths are resolved against its
                parent directory, otherwise against ``data_dir`` itself.
            split: Manifest file stem.
            sr: Sample rate audio is loaded at.
            n_mels: Number of mel bands.
            n_fft: FFT size of the mel spectrogram.
            hop_length: Hop length of the mel spectrogram.
            song_to_idx: Optional song-id -> class-index mapping; built from
                the sorted song ids when omitted (or empty).
            segment_dur: Segment length in seconds (previously fixed at 5).

        Raises:
            OSError: If the manifest cannot be opened.
            KeyError: If an entry lacks ``audio_path`` or ``song_id``.
        """
        self.sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.segment_len = int(segment_dur * sr)
        self.data_dir = Path(data_dir)
        self.asset_root = self.data_dir.parent if self.data_dir.name == "manifests" else self.data_dir

        meta_path = self.data_dir / f"{split}.json"
        with open(meta_path, encoding="utf-8") as f:
            self.metadata = json.load(f)

        # Entries whose audio file is missing are dropped silently.
        self.samples = [
            item for item in self.metadata if (self.asset_root / item["audio_path"]).exists()
        ]

        self.song_ids = sorted(set(s["song_id"] for s in self.samples))
        self.song_to_idx = song_to_idx or {sid: i for i, sid in enumerate(self.song_ids)}

    def __len__(self):
        """Return the number of manifest entries found on disk."""
        return len(self.samples)

    def _to_mel(self, y: np.ndarray) -> np.ndarray:
        """Return the mel power spectrogram of ``y`` in dB relative to its own maximum."""
        mel = librosa.feature.melspectrogram(
            y=y,
            sr=self.sr,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
        )
        return librosa.power_to_db(mel, ref=np.max)

    def __getitem__(self, idx):
        """Return the leading segment of entry ``idx``.

        The result is a dict with ``mel``, ``song_id``, ``song_name`` and
        ``type``. Audio shorter than the segment is zero-padded at the end.

        Raises:
            KeyError: If the entry has no ``duration`` or its song id is not
                in ``song_to_idx``.
        """
        sample = self.samples[idx]
        audio_path = self.asset_root / sample["audio_path"]
        seg_seconds = self.segment_len / self.sr
        y, _ = librosa.load(
            str(audio_path),
            sr=self.sr,
            mono=True,
            offset=0,
            duration=min(float(sample["duration"]), seg_seconds),
        )
        if len(y) < self.segment_len:
            y = np.pad(y, (0, self.segment_len - len(y)))
        else:
            y = y[: self.segment_len]

        class_id = self.song_to_idx[sample["song_id"]]
        return {
            "mel": torch.FloatTensor(self._to_mel(y)),
            "song_id": torch.tensor(class_id, dtype=torch.long),
            "song_name": sample["song_id"],
            "type": sample.get("type", "unknown"),
        }


class SongPairDataset(Dataset):
    """Pairs of log-mel crops taken from two non-reference recordings of one song.

    Songs that have ``confused`` or ``humming_like`` recordings are listed
    several times in the index so they are drawn more often.
    """

    # Per-recording weight reported in ``hard_weight``; other types get 1.0.
    _PAIR_TYPE_WEIGHTS = {
        "confused": 4.0,
        "humming_like": 2.5,
        "augmented": 1.4,
    }

    def __init__(
        self,
        data_dir: str,
        split: str = "train",
        sr: int = 16000,
        n_mels: int = 80,
        n_fft: int = 512,
        hop_length: int = 160,
        segment_dur: float = 5.0,
        augment: bool = True,
        segment_strategy: str = "random",
        silence_top_db: int = 30,
    ):
        """Read ``<data_dir>/<split>.json`` and group non-reference entries by song.

        Args:
            data_dir: Directory holding the manifest. When it is named
                ``manifests`` the audio paths are resolved against its
                parent directory, otherwise against ``data_dir`` itself.
            split: Manifest file stem.
            sr: Sample rate audio is loaded at.
            n_mels: Number of mel bands.
            n_fft: FFT size of the mel spectrogram.
            hop_length: Hop length of the mel spectrogram.
            segment_dur: Crop length in seconds.
            augment: Whether every crop is passed through ``AugmentPipeline``.
            segment_strategy: ``"random"``, ``"silence_aware"`` or
                ``"hybrid"``; any other value behaves like ``"random"``.
            silence_top_db: ``top_db`` passed to ``librosa.effects.split``.

        Raises:
            OSError: If the manifest cannot be opened.
            KeyError: If a kept entry lacks ``audio_path`` or ``song_id``.
        """
        self.sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.segment_len = int(segment_dur * sr)
        self.augment = augment
        self.segment_strategy = segment_strategy
        self.silence_top_db = silence_top_db
        self.data_dir = Path(data_dir)
        self.asset_root = self.data_dir.parent if self.data_dir.name == "manifests" else self.data_dir

        with open(self.data_dir / f"{split}.json", encoding="utf-8") as f:
            metadata = json.load(f)

        # Reference entries and entries with a missing audio file are skipped.
        self.by_song: Dict[str, List[Dict]] = {}
        for item in metadata:
            if item.get("type") == "reference":
                continue
            if (self.asset_root / item["audio_path"]).exists():
                self.by_song.setdefault(item["song_id"], []).append(item)

        self.song_ids = sorted(self.by_song)

        # One index slot per unit of weight: 5 for songs with a "confused"
        # recording, 3 for songs with a "humming_like" one, 1 otherwise.
        self.sample_song_ids = []
        for sid, items in self.by_song.items():
            item_types = {x.get("type") for x in items}
            if "confused" in item_types:
                weight = 5
            elif "humming_like" in item_types:
                weight = 3
            else:
                weight = 1
            self.sample_song_ids.extend([sid] * weight)
        self.song_to_idx = {sid: i for i, sid in enumerate(self.song_ids)}

    def __len__(self):
        """Return the number of weighted index slots."""
        return len(self.sample_song_ids)

    def _pick_offset(self, full_y: np.ndarray, max_offset: float) -> float:
        """Choose a crop start in seconds within ``[0, max_offset]``.

        ``"silence_aware"`` always crops inside a sufficiently long non-silent
        interval when one exists; ``"hybrid"`` does so with probability 0.7.
        In every other case the offset is uniform over the allowed range.
        """
        if self.segment_strategy in ("silence_aware", "hybrid"):
            intervals = librosa.effects.split(full_y, top_db=self.silence_top_db)
            valid = [
                (int(s), int(e))
                for s, e in intervals
                if int(e) - int(s) >= self.segment_len
            ]
            # The 0.7 draw is taken only for "hybrid" and only when a candidate exists.
            if valid and (self.segment_strategy == "silence_aware" or random.random() < 0.7):
                start, end = random.choice(valid)
                last_start = max(start, end - self.segment_len)
                chosen = random.randint(start, last_start) if last_start > start else start
                return min(chosen / self.sr, max_offset)
        return random.uniform(0, max_offset)

    def _load_clip(self, sample: Dict) -> np.ndarray:
        """Load the file of ``sample`` and return one crop of ``segment_len`` samples.

        The manifest ``duration`` is clamped to the decoded length so a
        too-large value cannot push the crop past the end of the audio. Crops
        shorter than ``segment_len`` are zero-padded at the end.
        """
        path = self.asset_root / sample["audio_path"]
        full_y, _ = librosa.load(str(path), sr=self.sr, mono=True)

        decoded_seconds = len(full_y) / self.sr
        declared = sample.get("duration")
        duration = decoded_seconds if declared is None else min(float(declared), decoded_seconds)

        # Latest start that still leaves a full crop, using the configured
        # crop length rather than a fixed 5 seconds.
        max_offset = max(0.0, duration - self.segment_len / self.sr)
        offset = self._pick_offset(full_y, max_offset) if max_offset > 0 else 0.0

        start = int(offset * self.sr)
        y = full_y[start : start + self.segment_len]
        if len(y) < self.segment_len:
            y = np.pad(y, (0, self.segment_len - len(y)))
        return y

    def _to_mel(self, y: np.ndarray) -> torch.Tensor:
        """Return the mel power spectrogram of ``y`` in dB (relative to its max) as a float tensor."""
        mel = librosa.feature.melspectrogram(
            y=y,
            sr=self.sr,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
        )
        mel = librosa.power_to_db(mel, ref=np.max)
        return torch.FloatTensor(mel)

    def __getitem__(self, idx):
        """Return a pair of crops for the song at index slot ``idx``.

        Two distinct recordings are drawn when the song has more than one;
        otherwise the single recording is cropped twice. The result is a dict
        with ``mel`` (the two spectrograms stacked on a new first axis),
        ``song_id`` (the class index repeated twice), ``song_name`` and
        ``hard_weight`` (one weight per recording, by its ``type``).
        """
        song_id = self.sample_song_ids[idx]
        choices = self.by_song[song_id]
        if len(choices) == 1:
            a = b = choices[0]
        else:
            a, b = random.sample(choices, 2)

        pair_weights = [
            self._PAIR_TYPE_WEIGHTS.get(a.get("type", "unknown"), 1.0),
            self._PAIR_TYPE_WEIGHTS.get(b.get("type", "unknown"), 1.0),
        ]

        wavs = []
        for sample in (a, b):
            y = self._load_clip(sample)
            if self.augment:
                # Local import: only needed when augmentation actually runs.
                from src.utils.augment import AugmentPipeline

                aggressive = sample.get("type") in {"confused", "humming_like"}
                y = AugmentPipeline(self.sr, aggressive=aggressive)(y)
            wavs.append(self._to_mel(y))

        # Right-pad the shorter spectrogram along its last axis so both stack.
        # NOTE(review): the pad value is 0.0, which is the maximum on this dB
        # scale - confirm whether padding with the minimum is intended instead.
        max_t = max(w.shape[1] for w in wavs)
        wavs = [
            torch.nn.functional.pad(w, (0, max_t - w.shape[1])) if w.shape[1] < max_t else w
            for w in wavs
        ]

        label = self.song_to_idx[song_id]
        return {
            "mel": torch.stack(wavs, dim=0),
            "song_id": torch.tensor([label, label], dtype=torch.long),
            "song_name": song_id,
            "hard_weight": torch.tensor(pair_weights, dtype=torch.float32),
        }