# voice_chunker.py 4.48 KB
#!/usr/bin/env /usr/local/miniconda3/bin/python
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import List, Dict

import librosa
import numpy as np
import soundfile as sf


def normalize_audio(audio_path: str, sr: int = 16000) -> np.ndarray:
    """Load an audio file as a mono float32 waveform.

    The file is read with ``librosa.load`` using ``sr`` as the requested
    sample rate and ``mono=True``; the sample rate librosa returns is
    discarded.

    Args:
        audio_path: Path of the audio file to read.
        sr: Sample rate handed to ``librosa.load``.

    Returns:
        The loaded samples converted to ``np.float32``.
    """
    samples, _loaded_sr = librosa.load(audio_path, sr=sr, mono=True)
    return samples.astype(dtype=np.float32)


def detect_voiced_intervals(y: np.ndarray, sr: int, top_db: int = 30, min_voiced_sec: float = 2.0) -> List[tuple[int, int]]:
    """Return the non-silent sample intervals of ``y`` that are long enough.

    Intervals come from ``librosa.effects.split``; any interval shorter than
    ``min_voiced_sec`` seconds (at sample rate ``sr``) is dropped.

    Args:
        y: Audio samples to analyse.
        sr: Sample rate of ``y``, used to convert seconds to samples.
        top_db: Forwarded to ``librosa.effects.split`` as its ``top_db``.
        min_voiced_sec: Minimum interval duration, in seconds, to keep.

    Returns:
        ``(start, end)`` sample-index pairs as plain Python ints, in the
        order librosa reported them.
    """
    segments = librosa.effects.split(y, top_db=top_db)
    shortest_allowed = int(sr * min_voiced_sec)
    return [
        (int(begin), int(finish))
        for begin, finish in segments
        if finish - begin >= shortest_allowed
    ]


def chunk_intervals(intervals: List[tuple[int, int]], sr: int, target_chunk_sec: float = 8.0, stride_sec: float = 4.0, max_chunks: int = 3, min_tail_sec: float = 2.0) -> List[tuple[int, int, bool]]:
    """Cut sample intervals into fixed-length, overlapping windows.

    Each interval is handled independently:

    * An interval shorter than one chunk is emitted whole and flagged as
      needing padding.
    * A longer interval is covered by windows of ``target_chunk_sec`` seconds
      that advance by ``stride_sec`` seconds. If at least ``min_tail_sec``
      seconds remain past the start of the next (non-fitting) window, one
      extra full-length window aligned to the interval's end is added so the
      tail is not lost.

    Windows with identical ``(start, end)`` are emitted once (first
    occurrence wins), and the result is truncated to the first ``max_chunks``
    entries when ``max_chunks`` is positive.

    Args:
        intervals: ``(start, end)`` sample-index pairs.
        sr: Sample rate used to convert the ``*_sec`` arguments to samples.
        target_chunk_sec: Window length in seconds.
        stride_sec: Hop between consecutive window starts, in seconds.
        max_chunks: Maximum number of chunks to return; ``<= 0`` means no
            limit.
        min_tail_sec: Minimum leftover duration, in seconds, that triggers the
            extra end-aligned window.

    Returns:
        ``(start, end, padded)`` tuples, where ``padded`` is True for chunks
        shorter than the target length.

    Raises:
        ValueError: If the chunk length or the stride is not at least one
            sample. A non-positive stride would otherwise never terminate.
    """
    chunk_len = int(sr * target_chunk_sec)
    stride = int(sr * stride_sec)
    if chunk_len <= 0:
        raise ValueError('target_chunk_sec * sr must be at least one sample')
    if stride <= 0:
        raise ValueError('stride_sec * sr must be at least one sample')
    min_tail = int(sr * min_tail_sec)

    # Keyed by (start, end); dicts keep insertion order, so this both
    # de-duplicates and preserves the order in which windows were produced.
    unique: Dict[tuple[int, int], tuple[int, int, bool]] = {}
    for start, end in intervals:
        if end - start < chunk_len:
            unique.setdefault((start, end), (start, end, True))
            continue
        pos = start
        while pos + chunk_len <= end:
            unique.setdefault((pos, pos + chunk_len), (pos, pos + chunk_len, False))
            pos += stride
        if pos < end and end - pos >= min_tail:
            # The interval is at least chunk_len long here, so an end-aligned
            # window always fits inside it and is never short.
            tail_start = end - chunk_len
            unique.setdefault((tail_start, end), (tail_start, end, False))

    chunks = list(unique.values())
    if max_chunks > 0:
        return chunks[:max_chunks]
    return chunks


def write_chunks(y: np.ndarray, sr: int, chunks: List[tuple[int, int, bool]], output_dir: str, source_audio_path: str, target_chunk_sec: float = 0.0) -> List[Dict]:
    """Write each chunk of ``y`` to ``chunk_NNN.wav`` and describe it.

    Chunks flagged ``padded`` are zero-padded at the end up to the target
    length before being written; unflagged chunks are written as sliced.

    Args:
        y: Full audio signal the chunks index into.
        sr: Sample rate used for the written files and for converting sample
            counts to seconds.
        chunks: ``(start, end, padded)`` sample-index triples.
        output_dir: Directory to write into; created if missing.
        source_audio_path: Recorded verbatim in each metadata entry.
        target_chunk_sec: Length, in seconds, that flagged chunks are padded
            to. When not positive, the length of the first clip is used
            instead (the historical behaviour), which leaves a short first
            chunk unpadded.

    Returns:
        One dict per chunk with ``chunk_id``, ``audio_path``, ``start_sec``,
        ``end_sec``, ``duration_sec`` (after padding), ``padded`` and
        ``source_audio_path``.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # None means "not known yet": fall back to the first clip's length below.
    pad_len = int(sr * target_chunk_sec) if target_chunk_sec > 0 else None
    results = []
    for idx, (start, end, padded) in enumerate(chunks):
        clip = y[start:end]
        if pad_len is None:
            pad_len = max(len(clip), 1)
        shortfall = pad_len - len(clip)
        if padded and shortfall > 0:
            clip = np.pad(clip, (0, shortfall))
        chunk_path = out_dir / f'chunk_{idx:03d}.wav'
        sf.write(str(chunk_path), clip, sr)
        results.append({
            'chunk_id': f'chunk_{idx:03d}',
            'audio_path': str(chunk_path),
            # start/end describe the source span; duration reflects padding.
            'start_sec': round(start / sr, 4),
            'end_sec': round(end / sr, 4),
            'duration_sec': round(len(clip) / sr, 4),
            'padded': padded,
            'source_audio_path': source_audio_path,
        })
    return results


def voice_to_chunks(audio_path: str, output_dir: str, target_chunk_sec: float = 8.0, stride_sec: float = 4.0, min_voiced_sec: float = 2.0, top_db: int = 30, sr: int = 16000, max_chunks: int = 3) -> List[Dict]:
    """Run the full pipeline: load audio, find intervals, window, write.

    Args:
        audio_path: Audio file to process.
        output_dir: Directory that receives the chunk WAV files.
        target_chunk_sec: Chunk length in seconds; also the length short
            chunks are padded to when written.
        stride_sec: Hop between consecutive chunk starts, in seconds.
        min_voiced_sec: Intervals shorter than this are ignored.
        top_db: Threshold forwarded to the interval detector.
        sr: Sample rate the audio is loaded at and written with.
        max_chunks: Maximum number of chunks to produce; ``<= 0`` for no
            limit.

    Returns:
        The per-chunk metadata dicts produced by ``write_chunks``.
    """
    y = normalize_audio(audio_path, sr=sr)
    intervals = detect_voiced_intervals(y, sr=sr, top_db=top_db, min_voiced_sec=min_voiced_sec)
    chunks = chunk_intervals(intervals, sr=sr, target_chunk_sec=target_chunk_sec, stride_sec=stride_sec, max_chunks=max_chunks)
    # Pass the real target length so short chunks are padded consistently,
    # regardless of which chunk happens to come first.
    return write_chunks(y, sr, chunks, output_dir, source_audio_path=audio_path, target_chunk_sec=target_chunk_sec)


def main() -> None:
    """Command-line entry point: chunk one audio file and emit a manifest.

    Parses the options, runs ``voice_to_chunks``, writes the resulting
    ``{'chunks': [...]}`` JSON document to ``<output-dir>/<output-json>``
    (UTF-8) and prints the same JSON to stdout.
    """
    parser = argparse.ArgumentParser()
    option_specs = (
        ('--input', {'required': True}),
        ('--output-dir', {'required': True}),
        ('--target-chunk-sec', {'type': float, 'default': 8.0}),
        ('--stride-sec', {'type': float, 'default': 4.0}),
        ('--min-voiced-sec', {'type': float, 'default': 2.0}),
        ('--top-db', {'type': int, 'default': 30}),
        ('--sr', {'type': int, 'default': 16000}),
        ('--max-chunks', {'type': int, 'default': 3}),
        ('--output-json', {'default': 'chunks.json'}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    opts = parser.parse_args()

    records = voice_to_chunks(
        audio_path=opts.input,
        output_dir=opts.output_dir,
        target_chunk_sec=opts.target_chunk_sec,
        stride_sec=opts.stride_sec,
        min_voiced_sec=opts.min_voiced_sec,
        top_db=opts.top_db,
        sr=opts.sr,
        max_chunks=opts.max_chunks,
    )

    # Serialise once; the file and stdout receive the same text.
    manifest = json.dumps({'chunks': records}, ensure_ascii=False, indent=2)
    manifest_path = Path(opts.output_dir) / opts.output_json
    manifest_path.write_text(manifest, encoding='utf-8')
    print(manifest)


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()