# augment.py 3.41 KB
import numpy as np
import random
from typing import Optional, Tuple


class AugmentPipeline:
    """Randomised augmentations for audio arrays and mel spectrograms.

    Calling an instance applies additive noise, time stretch, pitch shift and
    reverb in that order, each gated by its own probability. Randomness is
    drawn from the global ``random`` and ``np.random`` generators, so seed
    both for reproducible output.
    """

    def __init__(self, sr: int = 16000, aggressive: bool = False):
        """Store the sample rate and the default sampling ranges.

        Args:
            sr: Sample rate; forwarded to pitch shifting and used to size the
                reverb impulse response.
            aggressive: When true, ``__call__`` uses higher probabilities and
                wider parameter ranges.
        """
        self.sr = sr
        # (low, high) bounds for uniform sampling when a method is called
        # without an explicit value.
        self.noise_snr_range = (5, 30)
        self.pitch_shift_range = (-6, 6)
        self.time_stretch_range = (0.85, 1.15)
        # Not read by any method of this class.
        self.mp3_bitrate_range = (32, 128)
        self.aggressive = aggressive

    def add_noise(self, y: np.ndarray, snr_db: Optional[float] = None) -> np.ndarray:
        """Add white Gaussian noise at the given signal-to-noise ratio.

        Args:
            y: Input array of any shape.
            snr_db: Target SNR in dB; sampled uniformly from
                ``self.noise_snr_range`` when ``None``.

        Returns:
            A new array ``y + noise`` with the same shape as ``y``.
        """
        y = np.asarray(y)
        if snr_db is None:
            snr_db = random.uniform(*self.noise_snr_range)
        signal_power = np.mean(y ** 2)
        noise_power = signal_power / (10 ** (snr_db / 10))
        # Shape the noise like y (not len(y)) so multi-dimensional input
        # gets independent noise per element instead of failing to broadcast.
        noise = np.random.randn(*y.shape) * np.sqrt(noise_power)
        return y + noise

    def pitch_shift(self, y: np.ndarray, semitones: Optional[float] = None) -> np.ndarray:
        """Pitch-shift ``y`` via the module-level ``librosa_shift`` helper.

        Args:
            y: Input array.
            semitones: Shift amount passed as ``n_steps``; sampled uniformly
                from ``self.pitch_shift_range`` when ``None``.
        """
        if semitones is None:
            semitones = random.uniform(*self.pitch_shift_range)
        return librosa_shift(y, sr=self.sr, n_steps=semitones)

    def time_stretch(self, y: np.ndarray, rate: Optional[float] = None) -> np.ndarray:
        """Time-stretch ``y`` via the module-level ``librosa_ts`` helper.

        Args:
            y: Input array.
            rate: Stretch rate; sampled uniformly from
                ``self.time_stretch_range`` when ``None``.
        """
        if rate is None:
            rate = random.uniform(*self.time_stretch_range)
        return librosa_ts(y, sr=self.sr, rate=rate)

    def add_reverb(self, y: np.ndarray, decay: float = 0.3) -> np.ndarray:
        """Convolve ``y`` with a random, exponentially weighted impulse response.

        The impulse response is ``int(0.1 * sr)`` samples of Gaussian noise
        scaled by ``exp(-k * decay / ir_len)`` and normalised to unit energy.

        Args:
            y: 1-D input array (``np.convolve`` only accepts 1-D input).
            decay: Exponent scale of the envelope across the impulse response.

        Returns:
            An array of length ``len(y)``. Empty input (or a zero-length
            impulse response) yields an unmodified copy of ``y``.
        """
        ir_len = int(0.1 * self.sr)
        if len(y) == 0 or ir_len <= 0:
            # np.convolve raises on empty operands; nothing to reverberate.
            return np.array(y, copy=True)
        envelope = np.exp(-np.arange(ir_len) * decay / ir_len)
        ir = envelope * np.random.randn(ir_len)
        ir /= np.sqrt(np.sum(ir ** 2))
        # 'full' + truncation keeps the filter causal: output[n] depends only
        # on y[:n + 1]. 'same' would centre the result and shift the signal
        # earlier by about half the impulse response.
        return np.convolve(y, ir, mode='full')[:len(y)]

    @staticmethod
    def _random_band(axis_len: int, max_width: int) -> Tuple[int, int]:
        """Pick a random ``(start, width)`` band that fits inside ``axis_len``."""
        # Clamp the width to the axis so short inputs are not wiped out just
        # because the configured maximum exceeds their size.
        width = random.randint(0, min(max_width, axis_len))
        start = random.randint(0, axis_len - width)
        return start, width

    def apply_spec_augment(self, mel: np.ndarray, max_time_mask: int = 20, max_freq_mask: int = 8) -> np.ndarray:
        """Zero out two random bands along axis 1 and two along axis 0.

        Args:
            mel: Array with at least two dimensions; axis 0 is treated as
                frequency and axis 1 as time.
            max_time_mask: Maximum width of each axis-1 mask.
            max_freq_mask: Maximum width of each axis-0 mask.

        Returns:
            A masked copy; ``mel`` itself is not modified.
        """
        mel = mel.copy()
        n_freq, n_time = mel.shape[0], mel.shape[1]
        for _ in range(2):
            start, width = self._random_band(n_time, max_time_mask)
            mel[:, start:start + width] = 0
        for _ in range(2):
            start, width = self._random_band(n_freq, max_freq_mask)
            mel[start:start + width, :] = 0
        return mel

    def apply_to_mel(self, mel: np.ndarray) -> np.ndarray:
        """Apply ``apply_spec_augment`` with probability 0.3, else return ``mel`` as is."""
        if random.random() < 0.3:
            mel = self.apply_spec_augment(mel)
        return mel

    def __call__(self, y: np.ndarray) -> np.ndarray:
        """Run the augmentation chain on ``y``.

        Each stage fires independently. Probabilities (normal / aggressive):
        noise 0.5 / 0.75, stretch 0.3 / 0.55, pitch 0.3 / 0.55,
        reverb 0.2 / 0.35. In aggressive mode the noise, stretch and pitch
        parameters are drawn from wider hard-coded ranges; otherwise each
        method falls back to its own default range. Reverb decay is always
        drawn from ``uniform(0.2, 0.6)``.
        """
        noise_p = 0.75 if self.aggressive else 0.5
        stretch_p = 0.55 if self.aggressive else 0.3
        pitch_p = 0.55 if self.aggressive else 0.3
        reverb_p = 0.35 if self.aggressive else 0.2
        if random.random() < noise_p:
            y = self.add_noise(y, snr_db=random.uniform(0, 18) if self.aggressive else None)
        if random.random() < stretch_p:
            y = self.time_stretch(y, rate=random.uniform(0.8, 1.2) if self.aggressive else None)
        if random.random() < pitch_p:
            y = self.pitch_shift(y, semitones=random.uniform(-8, 8) if self.aggressive else None)
        if random.random() < reverb_p:
            y = self.add_reverb(y, decay=random.uniform(0.2, 0.6))
        return y


def librosa_shift(y, sr=16000, n_steps=0):
    """Pitch-shift ``y`` by ``n_steps`` using ``librosa.effects.pitch_shift``.

    librosa is imported lazily at call time. The call is routed through
    ``librosa_impl``, which returns ``y`` unchanged if anything raises.
    """
    def _shift():
        librosa = __import__('librosa')
        return librosa.effects.pitch_shift(y, sr=sr, n_steps=n_steps)

    return librosa_impl(y, _shift)


def librosa_ts(y, sr=16000, rate=1.0):
    """Time-stretch ``y`` by ``rate`` using ``librosa.effects.time_stretch``.

    ``sr`` is accepted but not forwarded to librosa. librosa is imported
    lazily at call time. The call is routed through ``librosa_impl``, which
    returns ``y`` unchanged if anything raises.
    """
    def _stretch():
        librosa = __import__('librosa')
        return librosa.effects.time_stretch(y, rate=rate)

    return librosa_impl(y, _stretch)


def librosa_impl(y, fn):
    """Run ``fn`` and return its result, falling back to ``y`` on failure.

    This is a deliberate best-effort wrapper: any ``Exception`` raised by
    ``fn`` is caught and the unmodified input is returned instead. The
    failure is logged as a warning so that a missing dependency or a bad
    argument does not silently turn the effect into a no-op.

    Args:
        y: Value to return when ``fn`` fails.
        fn: Zero-argument callable that performs the effect.

    Returns:
        ``fn()`` on success, otherwise ``y``.
    """
    # Function-scope import: the module has no top-level logging import.
    import logging

    try:
        return fn()
    except Exception as exc:
        logging.getLogger(__name__).warning(
            "audio effect failed, returning input unchanged: %r", exc
        )
        return y