test_context_exporter.py
1.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import tempfile
import unittest
from pathlib import Path
import test_bootstrap
import numpy as np
import soundfile as sf
from src.utils.context_exporter import export_match_context, find_best_matching_window
class ContextExporterTests(unittest.TestCase):
    """Smoke tests for ``find_best_matching_window`` and ``export_match_context``.

    Each test synthesises a 440 Hz sine tone, writes it to a temporary WAV
    file, and checks the basic shape of the dictionary returned by the helper
    under test.
    """

    @staticmethod
    def _sine_tone(sample_rate, seconds):
        """Return ``seconds`` of a 440 Hz, 0.2-amplitude sine as float32."""
        timeline = np.linspace(0, seconds, sample_rate * seconds, endpoint=False)
        return 0.2 * np.sin(2 * np.pi * 440 * timeline).astype(np.float32)

    def test_find_best_matching_window_returns_valid_range(self):
        """The reported window starts at >= 0 and ends after it starts."""
        sample_rate = 16000
        with tempfile.TemporaryDirectory() as tmp:
            workdir = Path(tmp)
            query_path = workdir / 'query.wav'
            ref_path = workdir / 'ref.wav'

            # Query: three seconds of tone. Reference: the same tone with one
            # second of silence on either side.
            tone = self._sine_tone(sample_rate, 3)
            padded = np.concatenate(
                [np.zeros(sample_rate), tone, np.zeros(sample_rate)]
            ).astype(np.float32)
            sf.write(query_path, tone, sample_rate)
            sf.write(ref_path, padded, sample_rate)

            match = find_best_matching_window(
                str(query_path),
                str(ref_path),
                sr=sample_rate,
                stride_sec=0.5,
            )

            self.assertGreaterEqual(match['window_start_sec'], 0.0)
            self.assertGreater(
                match['window_end_sec'], match['window_start_sec']
            )

    def test_export_match_context_writes_audio(self):
        """Exporting a context clip creates the file and reports 'wav'."""
        sample_rate = 16000
        with tempfile.TemporaryDirectory() as tmp:
            workdir = Path(tmp)
            ref_path = workdir / 'ref.wav'
            out_path = workdir / 'context.wav'

            # Twelve seconds of tone act as the reference recording.
            sf.write(ref_path, self._sine_tone(sample_rate, 12), sample_rate)

            info = export_match_context(
                str(ref_path),
                4.0,
                7.0,
                str(out_path),
                context_sec=10.0,
                output_format='wav',
                sr=sample_rate,
            )

            exported = Path(info['output_path'])
            self.assertTrue(exported.exists())
            self.assertEqual(info['output_format'], 'wav')
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()