dataset.py 4.49 KB
import torch
from torch.utils.data import Dataset
import numpy as np
import librosa
import random
from pathlib import Path
from typing import Dict, List, Tuple
import json
import os


class ACRDataset(Dataset):
    """Training dataset that yields random fixed-length log-mel crops of songs.

    Reads ``<data_dir>/<split>.json``, a list of records each carrying
    ``audio_path`` (relative to ``data_dir``), ``song_id`` and ``duration``
    (seconds). Records whose audio file does not exist are dropped. Every
    remaining record is exposed ``n_crops_per_song`` times, and each access
    draws a fresh random offset, so repeated indices give different crops.
    """

    def __init__(
        self,
        data_dir: str,
        split: str = "train",
        sr: int = 16000,
        n_mels: int = 80,
        n_fft: int = 512,
        hop_length: int = 160,
        segment_dur: float = 5.0,
        augment: bool = True,
        n_crops_per_song: int = 4,
    ):
        """Load split metadata and build the song_id -> class index mapping.

        Args:
            data_dir: Root directory holding ``<split>.json`` and the audio.
            split: Metadata file stem, e.g. ``"train"``.
            sr: Target sample rate passed to ``librosa.load``.
            n_mels, n_fft, hop_length: Mel-spectrogram parameters.
            segment_dur: Crop length in seconds.
            augment: Apply ``AugmentPipeline`` to each waveform crop.
            n_crops_per_song: How many dataset indices map to each song.
        """
        self.sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        # Keep the crop length in seconds as well as in samples: __getitem__
        # needs it for the offset range and the load duration.
        self.segment_dur = segment_dur
        self.segment_len = int(segment_dur * sr)
        self.augment = augment
        self.n_crops = n_crops_per_song
        self.data_dir = Path(data_dir)

        meta_path = self.data_dir / f"{split}.json"
        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(meta_path, encoding="utf-8") as f:
            self.metadata = json.load(f)

        # Entries whose audio file is missing on disk are skipped silently.
        self.samples = [
            item for item in self.metadata
            if (self.data_dir / item["audio_path"]).exists()
        ]
        # Class indices follow sorted song_id order over this split's samples.
        self.song_ids = sorted(set(s["song_id"] for s in self.samples))
        self.song_to_idx = {sid: i for i, sid in enumerate(self.song_ids)}

    def __len__(self):
        """Return the number of samples times the crops taken per song."""
        return len(self.samples) * self.n_crops

    def _load_segment(self, path: str, offset: float, duration: float) -> np.ndarray:
        """Load ``duration`` seconds of mono audio starting at ``offset``.

        The result is zero-padded or truncated to exactly ``self.segment_len``
        samples.
        """
        y, _ = librosa.load(
            path, sr=self.sr, mono=True,
            offset=offset, duration=duration
        )
        if len(y) < self.segment_len:
            y = np.pad(y, (0, self.segment_len - len(y)))
        else:
            y = y[:self.segment_len]
        return y

    def _to_mel(self, y: np.ndarray) -> np.ndarray:
        """Return the dB-scaled mel spectrogram of ``y``, referenced to its own max."""
        mel = librosa.feature.melspectrogram(
            y=y, sr=self.sr, n_mels=self.n_mels,
            n_fft=self.n_fft, hop_length=self.hop_length
        )
        return librosa.power_to_db(mel, ref=np.max)

    def __getitem__(self, idx: int) -> Dict[str, object]:
        """Return a random crop of the song behind ``idx``.

        Returns:
            dict with ``"mel"`` (float tensor of the log-mel spectrogram),
            ``"song_id"`` (long tensor class index) and ``"song_name"``
            (the raw ``song_id`` from the metadata).
        """
        sample = self.samples[idx // self.n_crops]
        # Pick a start so the whole crop lies inside the song. Songs no
        # longer than one segment start at 0 and are zero-padded by
        # _load_segment.
        max_offset = max(0, sample["duration"] - self.segment_dur)
        offset = random.uniform(0, max_offset) if max_offset > 0 else 0

        audio_path = self.data_dir / sample["audio_path"]
        y = self._load_segment(str(audio_path), offset, self.segment_dur)

        if self.augment:
            # Imported lazily: only needed when augmentation is enabled.
            # NOTE(review): a new pipeline is built per item; hoisting it to
            # __init__ is only safe if AugmentPipeline's constructor holds no
            # per-sample random state — confirm before changing.
            from src.utils.augment import AugmentPipeline
            aug = AugmentPipeline(self.sr)
            y = aug(y)

        mel_tensor = torch.FloatTensor(self._to_mel(y))

        song_id = sample["song_id"]
        return {
            "mel": mel_tensor,
            "song_id": torch.tensor(self.song_to_idx[song_id], dtype=torch.long),
            "song_name": song_id,
        }


class ACRTestDataset(Dataset):
    """Evaluation dataset: one deterministic log-mel crop per song.

    Reads ``<data_dir>/<split>.json``, a list of records each carrying
    ``audio_path`` (relative to ``data_dir``), ``song_id`` and ``duration``
    (seconds), and drops records whose audio file is missing. Each item is
    the first ``segment_dur`` seconds of the song, zero-padded if shorter.
    """

    def __init__(
        self,
        data_dir: str,
        split: str = "test",
        sr: int = 16000,
        n_mels: int = 80,
        n_fft: int = 512,
        hop_length: int = 160,
        segment_dur: float = 5.0,
    ):
        """Load split metadata and build the song_id -> class index mapping.

        Args:
            data_dir: Root directory holding ``<split>.json`` and the audio.
            split: Metadata file stem, e.g. ``"test"``.
            sr: Target sample rate passed to ``librosa.load``.
            n_mels, n_fft, hop_length: Mel-spectrogram parameters.
            segment_dur: Crop length in seconds (previously fixed at 5).
        """
        self.sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.segment_dur = segment_dur
        self.segment_len = int(segment_dur * sr)
        self.data_dir = Path(data_dir)

        meta_path = self.data_dir / f"{split}.json"
        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(meta_path, encoding="utf-8") as f:
            self.metadata = json.load(f)

        # Entries whose audio file is missing on disk are skipped silently.
        self.samples = [
            item for item in self.metadata
            if (self.data_dir / item["audio_path"]).exists()
        ]

        # NOTE(review): indices are derived from this split's own song set, so
        # they line up with a model trained on another split only if both
        # splits contain exactly the same song_ids — confirm how callers use
        # the returned "song_id".
        self.song_ids = sorted(set(s["song_id"] for s in self.samples))
        self.song_to_idx = {sid: i for i, sid in enumerate(self.song_ids)}

    def __len__(self):
        """Return the number of available samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, object]:
        """Return the leading crop of sample ``idx``.

        Returns:
            dict with ``"mel"`` (float tensor of the log-mel spectrogram,
            dB relative to the crop's own max), ``"song_id"`` (long tensor
            class index) and ``"song_name"`` (raw ``song_id``).
        """
        sample = self.samples[idx]
        audio_path = self.data_dir / sample["audio_path"]
        # Always start at 0 so evaluation inputs are reproducible.
        y, _ = librosa.load(
            str(audio_path), sr=self.sr, mono=True,
            offset=0, duration=min(sample["duration"], self.segment_dur)
        )
        if len(y) < self.segment_len:
            y = np.pad(y, (0, self.segment_len - len(y)))
        else:
            y = y[:self.segment_len]

        mel = librosa.power_to_db(
            librosa.feature.melspectrogram(
                y=y, sr=self.sr, n_mels=self.n_mels,
                n_fft=self.n_fft, hop_length=self.hop_length,
            ),
            ref=np.max,
        )
        song_id = sample["song_id"]
        return {
            "mel": torch.FloatTensor(mel),
            "song_id": torch.tensor(self.song_to_idx[song_id], dtype=torch.long),
            "song_name": song_id,
        }