# test_context_exporter.py — unit tests for src.utils.context_exporter.
import tempfile
import unittest
from pathlib import Path

import test_bootstrap

import numpy as np
import soundfile as sf

from src.utils.context_exporter import export_match_context, find_best_matching_window


def _sine_tone(sr, seconds, freq=440.0, amplitude=0.2):
    """Return a float32 sine tone lasting ``seconds`` seconds, sampled at ``sr`` Hz.

    Args:
        sr: Sample rate in Hz.
        seconds: Duration of the tone in seconds.
        freq: Tone frequency in Hz.
        amplitude: Peak amplitude of the tone.

    Returns:
        A 1-D float32 numpy array of ``int(sr * seconds)`` samples.
    """
    n_samples = int(sr * seconds)
    t = np.linspace(0, seconds, n_samples, endpoint=False)
    return amplitude * np.sin(2 * np.pi * freq * t).astype(np.float32)


class ContextExporterTests(unittest.TestCase):
    """Tests for ``src.utils.context_exporter`` using synthetic WAV files."""

    def test_find_best_matching_window_returns_valid_range(self):
        """The best-match window is a non-empty span that lies inside the reference."""
        sr = 16000
        with tempfile.TemporaryDirectory() as tmp:
            query = Path(tmp) / 'query.wav'
            ref = Path(tmp) / 'ref.wav'
            tone = _sine_tone(sr, 3)
            # Reference = 1 s silence + the 3 s query tone + 1 s silence.
            ref_y = np.concatenate([np.zeros(sr), tone, np.zeros(sr)]).astype(np.float32)
            sf.write(query, tone, sr)
            sf.write(ref, ref_y, sr)
            match = find_best_matching_window(str(query), str(ref), sr=sr, stride_sec=0.5)
            ref_duration_sec = len(ref_y) / sr
            self.assertGreaterEqual(match['window_start_sec'], 0.0)
            self.assertGreater(match['window_end_sec'], match['window_start_sec'])
            # A reported window must not run past the end of the reference audio
            # (small tolerance for float rounding in the time conversion).
            self.assertLessEqual(match['window_end_sec'], ref_duration_sec + 1e-3)

    def test_export_match_context_writes_audio(self):
        """Exporting a match with surrounding context writes a readable, non-empty WAV."""
        sr = 16000
        with tempfile.TemporaryDirectory() as tmp:
            ref = Path(tmp) / 'ref.wav'
            out = Path(tmp) / 'context.wav'
            y = _sine_tone(sr, 12)
            sf.write(ref, y, sr)
            # Assuming context_sec pads each side of the 4-7 s match (TODO confirm),
            # 10 s of context reaches past both ends of the 12 s reference, so this
            # also exercises edge handling in the exporter.
            info = export_match_context(str(ref), 4.0, 7.0, str(out), context_sec=10.0, output_format='wav', sr=sr)
            output_path = Path(info['output_path'])
            self.assertTrue(output_path.exists())
            self.assertEqual(info['output_format'], 'wav')
            # Existence alone would accept an empty or truncated file; require
            # that the written file parses as audio with at least one frame.
            self.assertGreater(sf.info(str(output_path)).frames, 0)


if __name__ == '__main__':
    # Run this module's tests directly, e.g. `python test_context_exporter.py`;
    # module='__main__' is unittest.main's default, spelled out for clarity.
    unittest.main(module='__main__')