# augment.py
import numpy as np
import random
from pathlib import Path
from typing import Iterable, Optional, Tuple
import librosa
import soundfile as sf
# audiomentations is an optional dependency. If importing it fails for any
# reason, every transform name is bound to None and HAS_AUDIO_AUG is False so
# the rest of the module can check the flag instead of re-attempting the import.
try:
    from audiomentations import (
        AddBackgroundNoise,
        AddGaussianNoise,
        BandPassFilter,
        Compose,
        Mp3Compression,
        PitchShift,
        TimeStretch,
    )
except Exception:
    AddBackgroundNoise = None
    AddGaussianNoise = None
    BandPassFilter = None
    Compose = None
    Mp3Compression = None
    PitchShift = None
    TimeStretch = None
    HAS_AUDIO_AUG = False
else:
    HAS_AUDIO_AUG = True
class NoiseLibrary:
    """Collection of audio file paths found beneath a set of root folders."""

    # Glob patterns searched (recursively) under every root, in this order.
    _AUDIO_PATTERNS = ("*.wav", "*.mp3", "*.flac", "*.ogg", "*.m4a")

    def __init__(self, roots: Optional[Iterable[str]] = None):
        """Scan each existing root recursively for audio files.

        Args:
            roots: Paths to search. ``None`` or an empty iterable produces an
                empty library; roots that do not exist are skipped.
        """
        # Matching paths, grouped by root and, within a root, by pattern.
        self.paths = []
        for candidate in roots or []:
            base_dir = Path(candidate)
            if base_dir.exists():
                for audio_glob in self._AUDIO_PATTERNS:
                    self.paths += base_dir.rglob(audio_glob)

    def directories(self) -> list[str]:
        """Return the sorted, de-duplicated parent directories of all paths.

        Returns:
            Directory names as strings; an empty list when nothing was found.
        """
        parent_dirs = set()
        for audio_path in self.paths:
            parent_dirs.add(str(audio_path.parent))
        return sorted(parent_dirs)
class AugmentPipeline:
    """Waveform and mel-spectrogram augmentation helpers.

    Waveform augmentation is delegated to an ``audiomentations`` ``Compose``
    chain when ``HAS_AUDIO_AUG`` is true; otherwise waveforms pass through
    untouched. Spectrogram masking only needs NumPy and the ``random`` module.
    """

    def __init__(
        self,
        sr: int = 16000,
        aggressive: bool = False,
        noise_roots: Optional[Iterable[str]] = None,
        freq_mask_prob: float = 0.3,
    ):
        """Store the configuration and build the waveform augmenter.

        Args:
            sr: Sample rate passed to the waveform augmenter on every call.
            aggressive: Selects the higher application probabilities and the
                lower background-noise SNR range.
            noise_roots: Folders scanned for background-noise recordings.
            freq_mask_prob: Probability that ``apply_to_mel`` masks its input.
                Despite the name, it gates both the time and frequency masks.
        """
        self.sr = sr
        self.aggressive = aggressive
        self.freq_mask_prob = freq_mask_prob
        self.noise_library = NoiseLibrary(noise_roots)
        self.wave_augment = self._build_wave_augmenter()

    def _build_wave_augmenter(self):
        """Build the waveform transform chain.

        Returns:
            A ``Compose`` of the configured transforms, or ``None`` when
            ``HAS_AUDIO_AUG`` is false.
        """
        if not HAS_AUDIO_AUG:
            return None

        # Probability shared by every transform except Gaussian noise and
        # background noise, which have their own values below.
        shared_p = 0.55 if self.aggressive else 0.35

        transforms = [
            AddGaussianNoise(
                min_amplitude=0.001,
                max_amplitude=0.02,
                p=0.8 if self.aggressive else 0.5,
            ),
            BandPassFilter(
                min_center_freq=300.0,
                max_center_freq=3200.0,
                min_bandwidth_fraction=0.3,
                max_bandwidth_fraction=0.8,
                p=shared_p,
            ),
            Mp3Compression(min_bitrate=24, max_bitrate=96, p=shared_p),
            PitchShift(min_semitones=-5, max_semitones=5, p=shared_p),
            TimeStretch(min_rate=0.8, max_rate=1.2, p=shared_p),
        ]

        # Background noise is only added when the library found recordings.
        noise_dirs = self.noise_library.directories()
        if noise_dirs:
            transforms.append(
                AddBackgroundNoise(
                    sounds_path=noise_dirs,
                    min_snr_db=3.0 if self.aggressive else 8.0,
                    max_snr_db=20.0 if self.aggressive else 30.0,
                    noise_transform=Compose([
                        BandPassFilter(
                            min_center_freq=250.0,
                            max_center_freq=4000.0,
                            min_bandwidth_fraction=0.2,
                            max_bandwidth_fraction=0.9,
                            p=0.5,
                        )
                    ]),
                    p=0.6 if self.aggressive else 0.35,
                )
            )
        return Compose(transforms)

    def apply_spec_augment(self, mel: np.ndarray, max_time_mask: int = 20, max_freq_mask: int = 12) -> np.ndarray:
        """Return a copy of ``mel`` with two time masks and two frequency masks.

        Axis 0 is treated as frequency and axis 1 as time; masked cells are
        set to 0. The input array is never modified.

        Args:
            mel: Array with at least two dimensions.
            max_time_mask: Largest time-mask width. Values <= 0 give
                zero-width (no-op) time masks.
            max_freq_mask: Largest frequency-mask width. Values < 1 disable
                frequency masking.

        Returns:
            The masked copy of ``mel``.
        """
        mel = mel.copy()
        n_freq, n_time = mel.shape[0], mel.shape[1]

        # A negative limit would make randint raise; treat it like 0. The
        # draws are still made so the random stream matches earlier behaviour.
        time_limit = max(0, max_time_mask)
        for _ in range(2):
            width = random.randint(0, time_limit)
            start = random.randint(0, max(0, n_time - width))
            if start < n_time:
                mel[:, start:start + width] = 0

        # The lower bound of a frequency mask is at least 1, so a limit below
        # 1 leaves no valid width (randint(1, 0) raises ValueError): skip.
        if max_freq_mask >= 1:
            min_width = max(1, max_freq_mask // 3)
            for _ in range(2):
                width = random.randint(min_width, max_freq_mask)
                start = random.randint(0, max(0, n_freq - width))
                if start < n_freq:
                    mel[start:start + width, :] = 0
        return mel

    def apply_to_mel(self, mel: np.ndarray) -> np.ndarray:
        """Apply ``apply_spec_augment`` with probability ``freq_mask_prob``.

        Returns:
            A masked copy when augmentation fires, otherwise ``mel`` itself.
        """
        if random.random() < self.freq_mask_prob:
            mel = self.apply_spec_augment(mel)
        return mel

    def __call__(self, y: np.ndarray) -> np.ndarray:
        """Augment waveform ``y``, falling back to ``y`` itself on failure.

        Args:
            y: Waveform samples; cast to float32 before augmentation.

        Returns:
            The augmenter's output, or the original ``y`` (uncast) when no
            augmenter is configured or the augmenter raises.
        """
        if self.wave_augment is None:
            return y
        try:
            return self.wave_augment(samples=y.astype(np.float32), sample_rate=self.sr)
        except Exception as exc:
            # Best effort: one bad sample must not stop the caller, but the
            # failure should be visible rather than silently swallowed.
            import logging

            logging.getLogger(__name__).warning(
                "Waveform augmentation failed; returning input unchanged: %r", exc
            )
            return y