# context_exporter.py (3.26 KB)
from __future__ import annotations

import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Dict, Tuple

import librosa
import numpy as np
import soundfile as sf


def load_audio(audio_path: str, sr: int = 16000) -> np.ndarray:
    """Read an audio file as mono float32 samples.

    Args:
        audio_path: Path of the audio file handed to ``librosa.load``.
        sr: Sample rate requested from ``librosa.load``.

    Returns:
        The loaded samples cast to ``np.float32``.
    """
    # librosa.load returns (samples, sample_rate); only the samples are needed
    # because the rate was requested explicitly.
    samples = librosa.load(audio_path, sr=sr, mono=True)[0]
    return samples.astype(np.float32)


def chroma_embedding(y: np.ndarray, sr: int) -> np.ndarray:
    """Summarise a signal as a normalised vector of chroma statistics.

    A 12-bin chromagram is computed with ``librosa.feature.chroma_stft`` and
    reduced along axis 1 to a mean and a standard deviation per bin. The means
    followed by the standard deviations form the embedding.

    Args:
        y: Audio samples.
        sr: Sample rate of ``y``.

    Returns:
        The float32 embedding divided by its L2 norm. When the norm is not
        positive (e.g. an all-zero embedding) the vector is returned unscaled.
    """
    chromagram = librosa.feature.chroma_stft(y=y, sr=sr, n_chroma=12)
    per_bin_mean = chromagram.mean(axis=1)
    per_bin_std = chromagram.std(axis=1)
    embedding = np.concatenate([per_bin_mean, per_bin_std], axis=0).astype(np.float32)

    length = np.linalg.norm(embedding)
    if length > 0:
        return embedding / length
    # Avoid dividing by zero; hand back the vector as-is.
    return embedding


def find_best_matching_window(
    query_audio_path: str,
    reference_audio_path: str,
    sr: int = 16000,
    stride_sec: float = 1.0,
) -> Dict:
    """Find the query-length window of the reference most similar to the query.

    Both files are loaded at ``sr``. A window as long as the query slides over
    the reference in steps of ``stride_sec``; each window is scored by the dot
    product of its chroma embedding with the query's embedding. The window
    aligned with the end of the reference is always scored as well, so the
    tail of the reference is searched even when the stride does not land on it.

    Args:
        query_audio_path: Path of the audio snippet to look for.
        reference_audio_path: Path of the audio to search in.
        sr: Sample rate used to load both files.
        stride_sec: Step between candidate windows, in seconds. Values that
            amount to less than one sample are treated as a one-sample step.

    Returns:
        A dict with ``window_start_sec``, ``window_end_sec``,
        ``window_score`` and ``query_duration_sec``; times are rounded to
        4 decimals and the score to 6. On ties the earliest window wins.

    Raises:
        ValueError: If the query audio contains no samples.
    """
    query_y = load_audio(query_audio_path, sr=sr)
    ref_y = load_audio(reference_audio_path, sr=sr)
    query_len = len(query_y)
    if query_len == 0:
        raise ValueError('Empty query audio')
    if len(ref_y) < query_len:
        # Zero-pad a short reference so that one full-length window exists.
        ref_y = np.pad(ref_y, (0, query_len - len(ref_y)))

    query_feat = chroma_embedding(query_y, sr)
    stride = max(1, int(sr * stride_sec))

    # After padding, last_start >= 0, so there is always at least one start
    # and every window below is exactly query_len samples long.
    last_start = len(ref_y) - query_len
    starts = list(range(0, last_start + 1, stride))
    if starts[-1] != last_start:
        # The stride skipped the end-aligned window; score it too so the
        # final partial stride of the reference is not ignored.
        starts.append(last_start)

    best_score = -1.0  # sentinel at the lower bound of a cosine similarity
    best_start = 0
    for start in starts:
        window = ref_y[start:start + query_len]
        score = float(np.dot(query_feat, chroma_embedding(window, sr)))
        if score > best_score:
            best_score = score
            best_start = start

    return {
        'window_start_sec': round(best_start / sr, 4),
        'window_end_sec': round((best_start + query_len) / sr, 4),
        'window_score': round(best_score, 6),
        'query_duration_sec': round(query_len / sr, 4),
    }


def export_match_context(
    audio_path: str,
    window_start_sec: float,
    window_end_sec: float,
    output_path: str,
    context_sec: float = 10.0,
    output_format: str = 'mp3',
    sr: int = 16000,
) -> Dict:
    """Write a clip of ``context_sec`` seconds centred on a matched window.

    The clip is centred on the midpoint of the window and clamped to the
    bounds of the source audio, so it is shorter than ``context_sec`` near the
    edges. At least one sample is always written, and the reported duration
    reflects the samples actually written.

    When ``output_format`` is ``'mp3'`` and ``ffmpeg`` is on ``PATH``, the clip
    is written to a temporary WAV file and converted by ffmpeg to
    ``output_path``. When ffmpeg is missing, the clip is written as WAV next
    to the requested path (suffix replaced by ``.wav``). Any other format is
    written directly with ``soundfile``.

    Args:
        audio_path: Path of the source audio.
        window_start_sec: Start of the matched window, in seconds.
        window_end_sec: End of the matched window, in seconds.
        output_path: Destination file; parent directories are created.
        context_sec: Total clip length to aim for; negative values are
            treated as zero.
        output_format: Requested output format.
        sr: Sample rate used to load the source and write the clip.

    Returns:
        A dict with ``source_audio_path``, ``clip_start_sec``,
        ``clip_end_sec``, ``duration_sec``, ``output_path`` and
        ``output_format`` (the format and path actually used).

    Raises:
        ValueError: If the source audio contains no samples.
        subprocess.CalledProcessError: If the ffmpeg conversion fails.
    """
    y = load_audio(audio_path, sr=sr)
    total_samples = len(y)
    if total_samples == 0:
        raise ValueError('Empty source audio')
    total_sec = total_samples / sr

    center = (window_start_sec + window_end_sec) / 2.0
    half = max(0.0, context_sec) / 2.0
    # Clamp both bounds into the audio and keep them ordered, even when the
    # window lies entirely before or after the available samples.
    clip_start_sec = min(max(0.0, center - half), total_sec)
    clip_end_sec = max(clip_start_sec, min(total_sec, center + half))
    # Keep start on a real sample and end within the array so the slice is
    # never empty.
    start = min(int(clip_start_sec * sr), total_samples - 1)
    end = min(total_samples, max(start + 1, int(clip_end_sec * sr)))
    clip = y[start:end]

    output = Path(output_path)
    output.parent.mkdir(parents=True, exist_ok=True)
    actual_format = output_format

    # Resolve ffmpeg once; it is only needed for mp3 output.
    ffmpeg_bin = shutil.which('ffmpeg') if output_format == 'mp3' else None
    if ffmpeg_bin:
        with tempfile.TemporaryDirectory() as tmp:
            wav_path = Path(tmp) / 'context.wav'
            sf.write(wav_path, clip, sr)
            cmd = [ffmpeg_bin, '-y', '-i', str(wav_path), str(output)]
            subprocess.run(cmd, check=True, capture_output=True)
    else:
        if output_format == 'mp3':
            # No ffmpeg available: fall back to WAV and report it.
            actual_format = 'wav'
            output = output.with_suffix('.wav')
        sf.write(output, clip, sr)

    return {
        'source_audio_path': audio_path,
        'clip_start_sec': round(clip_start_sec, 4),
        'clip_end_sec': round(clip_end_sec, 4),
        'duration_sec': round(len(clip) / sr, 4),
        'output_path': str(output),
        'output_format': actual_format,
    }