# augment.py 4.06 KB
import numpy as np
import random
from pathlib import Path
from typing import Iterable, Optional, Tuple

import librosa
import soundfile as sf

try:
    from audiomentations import AddBackgroundNoise, AddGaussianNoise, BandPassFilter, Compose, Mp3Compression, PitchShift, TimeStretch
    HAS_AUDIO_AUG = True
except Exception:
    AddBackgroundNoise = AddGaussianNoise = BandPassFilter = Compose = Mp3Compression = PitchShift = TimeStretch = None
    HAS_AUDIO_AUG = False


class NoiseLibrary:
    """Index of audio files found recursively under a set of root directories.

    Attributes:
        paths: ``Path`` objects for every audio file discovered, in directory
            traversal order (one root after another).
    """

    # File suffixes (lower-case) that count as audio; matching ignores case.
    AUDIO_SUFFIXES = (".wav", ".mp3", ".flac", ".ogg", ".m4a")

    def __init__(self, roots: Optional[Iterable[str]] = None):
        """Scan ``roots`` for audio files.

        Args:
            roots: Directories to search recursively. ``None`` or an empty
                iterable yields an empty library. Roots that do not exist are
                skipped. A single ``str``/``Path`` is treated as one root
                rather than being iterated character by character.
        """
        # Iterating a bare string would treat each character as a root; on
        # POSIX the "/" character exists and would trigger a full-disk scan.
        if isinstance(roots, (str, Path)):
            roots = [roots]

        self.paths = []
        for root in roots or []:
            base = Path(root)
            if not base.exists():
                continue
            # One traversal per root; suffix comparison is case-insensitive so
            # files such as "clip.WAV" are found on case-sensitive filesystems.
            # is_file() rejects directories that merely look like audio names.
            self.paths.extend(
                candidate
                for candidate in base.rglob("*")
                if candidate.suffix.lower() in self.AUDIO_SUFFIXES and candidate.is_file()
            )

    def directories(self) -> list[str]:
        """Return the sorted, de-duplicated parent directories of all indexed files."""
        return sorted({str(path.parent) for path in self.paths})


class AugmentPipeline:
    """Random augmentation for audio waveforms and 2-D spectrogram arrays.

    Calling an instance runs the waveform transform chain built by
    ``_build_wave_augmenter`` (when the optional augmentation package was
    importable). ``apply_to_mel`` / ``apply_spec_augment`` zero out random
    bands of a 2-D array.

    Attributes:
        sr: Sample rate passed to the waveform transforms.
        aggressive: Selects the higher-probability / lower-SNR settings.
        freq_mask_prob: Probability that ``apply_to_mel`` applies masking.
        noise_library: ``NoiseLibrary`` built from ``noise_roots``.
        wave_augment: The composed waveform transform, or ``None`` when the
            augmentation package is unavailable.
    """

    def __init__(
        self,
        sr: int = 16000,
        aggressive: bool = False,
        noise_roots: Optional[Iterable[str]] = None,
        freq_mask_prob: float = 0.3,
    ):
        """Store the settings and build the waveform transform chain.

        Args:
            sr: Sample rate handed to the waveform transforms.
            aggressive: Use the stronger augmentation settings.
            noise_roots: Directories searched for background-noise files.
            freq_mask_prob: Probability used by ``apply_to_mel``.
        """
        self.sr = sr
        self.aggressive = aggressive
        self.freq_mask_prob = freq_mask_prob
        self.noise_library = NoiseLibrary(noise_roots)
        # Must come last: the builder reads ``aggressive`` and ``noise_library``.
        self.wave_augment = self._build_wave_augmenter()

    def _build_wave_augmenter(self):
        """Build the composed waveform transform.

        Returns:
            A ``Compose`` of the configured transforms, or ``None`` when the
            augmentation package could not be imported.
        """
        if not HAS_AUDIO_AUG:
            return None

        # Band-pass, MP3, pitch and stretch all share one probability.
        effect_p = 0.55 if self.aggressive else 0.35
        transforms = [
            AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.02, p=0.8 if self.aggressive else 0.5),
            BandPassFilter(
                min_center_freq=300.0,
                max_center_freq=3200.0,
                min_bandwidth_fraction=0.3,
                max_bandwidth_fraction=0.8,
                p=effect_p,
            ),
            Mp3Compression(min_bitrate=24, max_bitrate=96, p=effect_p),
            PitchShift(min_semitones=-5, max_semitones=5, p=effect_p),
            TimeStretch(min_rate=0.8, max_rate=1.2, p=effect_p),
        ]

        # Background noise is only added when the library found any files.
        noise_dirs = self.noise_library.directories()
        if noise_dirs:
            transforms.append(
                AddBackgroundNoise(
                    sounds_path=noise_dirs,
                    min_snr_db=3.0 if self.aggressive else 8.0,
                    max_snr_db=20.0 if self.aggressive else 30.0,
                    noise_transform=Compose([
                        BandPassFilter(
                            min_center_freq=250.0,
                            max_center_freq=4000.0,
                            min_bandwidth_fraction=0.2,
                            max_bandwidth_fraction=0.9,
                            p=0.5,
                        )
                    ]),
                    p=0.6 if self.aggressive else 0.35,
                )
            )
        return Compose(transforms)

    def apply_spec_augment(self, mel: np.ndarray, max_time_mask: int = 20, max_freq_mask: int = 12) -> np.ndarray:
        """Return a copy of ``mel`` with two column bands and two row bands zeroed.

        Axis 0 is treated as frequency and axis 1 as time. Band widths are
        capped at the corresponding axis length so that a short input is not
        blanked out merely because the requested maximum exceeds its size.
        For inputs at least as large as the requested maxima the random draws
        are the same as without the cap.

        Args:
            mel: 2-D array; it is not modified.
            max_time_mask: Largest width (columns) of one time band; ``0``
                disables time masking.
            max_freq_mask: Largest height (rows) of one frequency band; ``0``
                disables frequency masking.

        Returns:
            The masked copy.

        Raises:
            ValueError: If either maximum is negative.
        """
        if max_time_mask < 0 or max_freq_mask < 0:
            raise ValueError("max_time_mask and max_freq_mask must be non-negative")

        mel = mel.copy()
        n_freq, n_time = mel.shape[0], mel.shape[1]

        time_cap = min(max_time_mask, n_time)
        freq_cap = min(max_freq_mask, n_freq)
        # Frequency bands are at least a third of the requested maximum (and at
        # least one row), but never more than the cap; a cap of 0 gives width 0
        # instead of an empty randint range.
        freq_floor = min(max(1, max_freq_mask // 3), freq_cap)

        for _ in range(2):
            width = random.randint(0, time_cap)
            start = random.randint(0, n_time - width)
            mel[:, start:start + width] = 0
        for _ in range(2):
            height = random.randint(freq_floor, freq_cap)
            start = random.randint(0, n_freq - height)
            mel[start:start + height, :] = 0
        return mel

    def apply_to_mel(self, mel: np.ndarray) -> np.ndarray:
        """Apply ``apply_spec_augment`` with probability ``freq_mask_prob``.

        Returns the input object itself when masking is not selected.
        """
        if random.random() < self.freq_mask_prob:
            mel = self.apply_spec_augment(mel)
        return mel

    def __call__(self, y: np.ndarray) -> np.ndarray:
        """Augment waveform ``y`` on a best-effort basis.

        Args:
            y: Waveform samples.

        Returns:
            The output of the waveform transform chain (fed a ``float32``
            copy of ``y``), or ``y`` itself, unchanged, when no chain is
            available or the chain raised.
        """
        if self.wave_augment is None:
            return y
        try:
            return self.wave_augment(samples=y.astype(np.float32), sample_rate=self.sr)
        except Exception as exc:
            # Best effort: never fail the caller, but do not hide the problem.
            # warnings de-duplicates repeated identical messages by default.
            import warnings

            warnings.warn(
                f"Waveform augmentation failed; returning the input unchanged: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )
            return y