# manifest_tools.py
"""Tools for converting external audio datasets into catalog and train/test/val manifests."""

from __future__ import annotations

import argparse
import csv
import json
import random
import shutil
import sys
from pathlib import Path
from typing import List, Dict
import numpy as np
import soundfile as sf
import librosa

# Directory two levels above the folder containing this file (presumably the
# repository root). It is prepended to sys.path, if not already present, so the
# ``src`` package import that follows resolves when this file runs as a script.
ROOT = Path(__file__).resolve().parents[2]
if not any(entry == str(ROOT) for entry in sys.path):
    sys.path.insert(0, str(ROOT))

from src.data.dataset import compute_candidate_offsets


def write_catalog(records: List[Dict], output_path: Path) -> None:
    """Write ``records`` to ``output_path`` as pretty-printed UTF-8 JSON.

    Missing parent directories are created first. Non-ASCII characters are
    written verbatim (``ensure_ascii=False``). The file is therefore opened with
    an explicit UTF-8 encoding rather than the platform default, which may not
    be able to encode them (e.g. cp1252 on Windows).

    Args:
        records: JSON-serialisable manifest entries.
        output_path: Destination file; overwritten if it exists.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(records, f, indent=2, ensure_ascii=False)


def csv_to_catalog(csv_path: Path, output_path: Path, path_field: str = "audio_path", id_field: str = "song_id"):
    """Convert a CSV listing of audio files into a reference-catalog JSON file.

    Each CSV row becomes one ``"reference"`` record. ``song_id`` and
    ``audio_path`` come from the ``id_field`` / ``path_field`` columns. The
    optional ``duration`` and ``source_dataset`` columns default to ``0.0`` and
    ``"external"`` when absent or blank.

    Args:
        csv_path: Input CSV with a header row.
        output_path: Destination JSON file (written via ``write_catalog``).
        path_field: Column holding the audio path.
        id_field: Column holding the song identifier.

    Returns:
        Number of records written.

    Raises:
        KeyError: if the header lacks ``id_field`` or ``path_field``.
        ValueError: if a ``duration`` cell is not a number.
    """
    records = []
    # utf-8-sig decodes plain UTF-8 and also strips a leading BOM (as written
    # by Excel), which would otherwise be glued onto the first header name.
    with open(csv_path, newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            records.append(
                {
                    "song_id": row[id_field],
                    "audio_path": row[path_field],
                    # Missing cells (None, via DictReader's restval) and blank
                    # cells both fall back to the defaults.
                    "duration": float(row.get("duration") or 0.0),
                    "type": "reference",
                    "source_dataset": row.get("source_dataset") or "external",
                }
            )
    write_catalog(records, output_path)
    return len(records)


def build_train_eval_from_audio_dir(
    audio_dir: Path,
    output_dir: Path,
    source_dataset: str,
    exts: tuple[str, ...] = (".wav", ".mp3", ".flac", ".ogg"),
    eval_ratio: float = 0.2,
    query_duration: float = 8.0,
    query_stride: float | None = None,
    query_strategy: str = "random",
    silence_top_db: int = 30,
    seed: int = 42,
):
    """Copy an audio directory into a dataset layout and write split manifests.

    Every file under ``audio_dir`` (recursively) whose lower-cased suffix is in
    ``exts`` is handled as follows:

    * It is copied to ``output_dir/audio/<source_dataset>_<idx:05d><ext>``.
      Existing copies are left in place.
    * It becomes one ``"reference"`` record.
    * Files at least ``query_duration`` seconds long also yield one or more
      ``"clean"`` query records. Each query goes to the test split when
      ``rng.random() < eval_ratio``, otherwise to train.

    Query offsets (seconds, rounded to 3 decimals) depend on ``query_strategy``:

    * ``"silence_aware"`` / ``"high_energy"`` / ``"onset_aware"``: offsets
      proposed by ``compute_candidate_offsets``.
    * ``"hybrid"``: union of those three strategies. One extra random offset
      is added unless a positive ``query_stride`` is given.
    * Fallback when the above yield nothing, or for any other strategy:
      - with a positive ``query_stride``, offsets every ``query_stride``
        seconds, plus a final window ending at the file end;
      - otherwise, a single random offset.

    If at least two files were found and one split ended up empty, one query is
    moved across from the other split.

    Writes ``catalog.json`` (references), ``train.json`` (train queries plus
    references), ``test.json`` (test queries plus references) and an empty
    ``val.json`` under ``output_dir/manifests``.

    Returns:
        Summary dict with record counts and the query settings used.
    """
    rng = random.Random(seed)
    files = [p for p in sorted(audio_dir.rglob("*")) if p.suffix.lower() in exts]
    output_dir.mkdir(parents=True, exist_ok=True)
    manifests_dir = output_dir / "manifests"
    manifests_dir.mkdir(parents=True, exist_ok=True)
    audio_out_dir = output_dir / "audio"
    audio_out_dir.mkdir(parents=True, exist_ok=True)

    single_strategies = {"silence_aware", "high_energy", "onset_aware"}
    # Order in which "hybrid" evaluates its component strategies.
    hybrid_strategies = ("high_energy", "onset_aware", "silence_aware")
    use_stride = bool(query_stride and query_stride > 0)
    # NOTE(review): "sliding" is accepted by the CLI but has no dedicated
    # branch. Without a positive query_stride it behaves exactly like "random"
    # (one random offset). Confirm whether a default stride is intended.

    refs: List[Dict] = []
    train: List[Dict] = []
    test: List[Dict] = []

    def warn(message: str) -> None:
        # Processing is best-effort: report the problem and keep going.
        print(f"warning: {message}", file=sys.stderr)

    def strategy_offsets(path: Path, duration: float, strategies) -> List[float]:
        """Return sorted unique query offsets (s) proposed by ``strategies``.

        The file is decoded once and shared by all strategies. A decode failure
        yields no offsets. A strategy that raises contributes no offsets. Both
        cases are reported on stderr.
        """
        try:
            y, sr = librosa.load(str(path), sr=None, mono=True)
        except Exception as exc:
            warn(f"could not decode {path}: {exc}")
            return []
        target_len = int(query_duration * sr)
        max_offset = max(0.0, duration - query_duration)
        found = set()
        for strategy in strategies:
            try:
                candidates = compute_candidate_offsets(
                    y=y,
                    sr=sr,
                    segment_len=target_len,
                    strategy=strategy,
                    silence_top_db=silence_top_db,
                )
                # Materialise first so a failing strategy contributes nothing.
                seconds = [round(int(start) / sr, 3) for start in candidates]
            except Exception as exc:
                warn(f"strategy {strategy!r} failed on {path}: {exc}")
                continue
            found.update(seconds)
        # Drop offsets whose window would run past the end of the file.
        return sorted(x for x in found if x <= max_offset)

    def random_offset(max_offset: float) -> float:
        # The RNG is only advanced when there is room to move the window.
        return round(rng.uniform(0.0, max_offset), 3) if max_offset > 0 else 0.0

    for idx, path in enumerate(files):
        song_id = f"{source_dataset}_{idx:05d}"
        target_path = audio_out_dir / f"{song_id}{path.suffix.lower()}"
        if not target_path.exists():
            shutil.copy2(path, target_path)
        # POSIX separators keep manifests portable across operating systems.
        rel = target_path.relative_to(output_dir).as_posix()
        try:
            duration = float(sf.info(str(path)).duration)
        except Exception as exc:
            # Unreadable files stay in the catalog but get no queries.
            warn(f"could not read duration of {path}: {exc}")
            duration = 0.0

        refs.append(
            {
                "song_id": song_id,
                "audio_path": rel,
                "duration": duration,
                "type": "reference",
                "source_dataset": source_dataset,
            }
        )

        if duration < query_duration:
            continue  # too short to cut a full query window

        max_offset = max(0.0, duration - query_duration)
        proposed: List[float] = []
        if query_strategy in single_strategies:
            proposed = strategy_offsets(path, duration, (query_strategy,))
        elif query_strategy == "hybrid":
            proposed = strategy_offsets(path, duration, hybrid_strategies)

        if proposed:
            offsets = proposed
            if query_strategy == "hybrid" and not use_stride:
                offsets = sorted(set(proposed + [random_offset(max_offset)]))
        elif use_stride:
            grid = np.arange(0.0, max_offset + 1e-9, query_stride).tolist()
            offsets = [round(x, 3) for x in grid] or [0.0]
            # Ensure the last window ends exactly at the end of the file.
            if offsets[-1] < round(max_offset, 3):
                offsets.append(round(max_offset, 3))
        else:
            offsets = [random_offset(max_offset)]

        for seg_idx, offset in enumerate(offsets):
            query = {
                "song_id": song_id,
                "audio_path": rel,
                "duration": query_duration,
                "type": "clean",
                "offset": offset,
                "segment_type": "external_query",
                "source_dataset": source_dataset,
                "query_index": seg_idx,
            }
            if rng.random() < eval_ratio:
                test.append(query)
            else:
                train.append(query)

    # Guarantee both splits are non-empty when there is enough data.
    if len(files) >= 2 and not train and test:
        train.append(test.pop())
    if len(files) >= 2 and not test and train:
        test.append(train.pop())

    write_catalog(refs, manifests_dir / "catalog.json")
    write_catalog(train + refs, manifests_dir / "train.json")
    write_catalog(test + refs, manifests_dir / "test.json")
    write_catalog([], manifests_dir / "val.json")
    return {
        "catalog": len(refs),
        "train_queries": len(train),
        "test_queries": len(test),
        "query_duration": query_duration,
        "query_stride": query_stride,
        "query_strategy": query_strategy,
        "output_dir": str(manifests_dir),
    }


def inspect_audio_dir(
    audio_dir: Path,
    exts: tuple[str, ...] = (".wav", ".mp3", ".flac", ".ogg"),
    query_duration: float = 8.0,
    eval_ratio: float = 0.2,
):
    """Summarise an audio directory before building splits from it.

    Files under ``audio_dir`` (recursively) whose lower-cased suffix is in
    ``exts`` are probed with ``soundfile``. Unreadable files count as 0 s long.
    A file is "eligible" when it is at least ``query_duration`` seconds long.

    The recommended train/test query counts assume one query per eligible file.
    They always sum to the number of eligible files:

    * test gets ``round(eligible * eval_ratio)``, at least 1 when there are two
      or more eligible files, and never more than ``eligible``;
    * train gets the rest.

    Returns:
        Dict with file counts, recommended split sizes, and min/median/max
        durations in seconds. The median is the upper median for an even count.
    """
    files = [p for p in sorted(audio_dir.rglob("*")) if p.suffix.lower() in exts]
    durations = []
    for path in files:
        try:
            duration = float(sf.info(str(path)).duration)
        except Exception:
            duration = 0.0
        durations.append(duration)
    eligible = sum(1 for d in durations if d >= query_duration)

    if eligible == 0:
        test_queries = 0
    else:
        minimum = 1 if eligible >= 2 else eligible
        # Cap at `eligible` so test + train never exceeds what exists.
        test_queries = min(eligible, max(minimum, round(eligible * eval_ratio)))
    train_queries = eligible - test_queries

    durations_sorted = sorted(durations)
    return {
        "audio_dir": str(audio_dir),
        "num_audio_files": len(files),
        "eligible_query_files": eligible,
        "query_duration": query_duration,
        "recommended_train_queries": train_queries,
        "recommended_test_queries": test_queries,
        "duration_stats": {
            "min": round(durations_sorted[0], 3) if durations_sorted else 0.0,
            "median": round(durations_sorted[len(durations_sorted) // 2], 3) if durations_sorted else 0.0,
            "max": round(durations_sorted[-1], 3) if durations_sorted else 0.0,
        },
    }


def validate_splits(manifests_dir: Path):
    """Sanity-check a directory of split manifests.

    Expects ``catalog.json``, ``train.json``, ``test.json`` and ``val.json``,
    each a JSON list of records. The checks are:

    * the catalog has ``"reference"`` records;
    * train and test each contain at least one non-reference (query) record;
    * all references and queries share a single ``source_dataset`` value.

    Returns:
        ``{"ok": False, "missing_files": [...], "errors": ["missing_files"]}``
        when required files are absent. Otherwise a dict with ``ok``,
        ``errors``, per-split query counts and the sorted source datasets seen.
    """
    required = ["catalog.json", "train.json", "test.json", "val.json"]
    missing = [name for name in required if not (manifests_dir / name).exists()]
    if missing:
        return {"ok": False, "missing_files": missing, "errors": ["missing_files"]}

    def load(name: str):
        # Manifests may contain raw non-ASCII text, so decode as UTF-8
        # explicitly instead of using the platform default encoding.
        return json.loads((manifests_dir / name).read_text(encoding="utf-8"))

    catalog, train, test, val = (load(name) for name in required)

    catalog_refs = [x for x in catalog if x.get("type") == "reference"]
    train_queries = [x for x in train if x.get("type") != "reference"]
    test_queries = [x for x in test if x.get("type") != "reference"]
    val_queries = [x for x in val if x.get("type") != "reference"]

    source_values = {
        x.get("source_dataset", "unknown")
        for x in catalog_refs + train_queries + test_queries + val_queries
    }

    errors = []
    if not catalog_refs:
        errors.append("catalog_has_no_references")
    if not train_queries:
        errors.append("train_has_no_queries")
    if not test_queries:
        errors.append("test_has_no_queries")
    if len(source_values) > 1:
        errors.append("mixed_source_dataset_values")

    return {
        "ok": len(errors) == 0,
        "errors": errors,
        "catalog_references": len(catalog_refs),
        "train_queries": len(train_queries),
        "test_queries": len(test_queries),
        "val_queries": len(val_queries),
        "source_datasets": sorted(source_values),
    }


def main():
    """Parse the command line and run the selected subcommand.

    Subcommands: ``csv-to-catalog``, ``audio-dir-to-splits``,
    ``inspect-audio-dir`` and ``validate-splits``. Each prints exactly one JSON
    object to stdout. All but ``validate-splits`` prefix it with
    ``"status": "ok"``.
    """
    parser = argparse.ArgumentParser()
    subcommands = parser.add_subparsers(dest="cmd", required=True)

    csv_parser = subcommands.add_parser("csv-to-catalog")
    csv_parser.add_argument("csv_path")
    csv_parser.add_argument("output_path")
    csv_parser.add_argument("--path-field", default="audio_path")
    csv_parser.add_argument("--id-field", default="song_id")

    splits_parser = subcommands.add_parser("audio-dir-to-splits")
    splits_parser.add_argument("audio_dir")
    splits_parser.add_argument("output_dir")
    splits_parser.add_argument("--source-dataset", required=True)
    splits_parser.add_argument("--eval-ratio", type=float, default=0.2)
    splits_parser.add_argument("--query-duration", type=float, default=8.0)
    splits_parser.add_argument("--query-stride", type=float, default=None)
    splits_parser.add_argument(
        "--query-strategy",
        choices=["random", "sliding", "silence_aware", "high_energy", "onset_aware", "hybrid"],
        default="random",
    )
    splits_parser.add_argument("--silence-top-db", type=int, default=30)
    splits_parser.add_argument("--seed", type=int, default=42)

    inspect_parser = subcommands.add_parser("inspect-audio-dir")
    inspect_parser.add_argument("audio_dir")
    inspect_parser.add_argument("--query-duration", type=float, default=8.0)
    inspect_parser.add_argument("--eval-ratio", type=float, default=0.2)

    validate_parser = subcommands.add_parser("validate-splits")
    validate_parser.add_argument("manifests_dir")

    opts = parser.parse_args()

    def run_csv_to_catalog():
        written = csv_to_catalog(Path(opts.csv_path), Path(opts.output_path), opts.path_field, opts.id_field)
        return {"status": "ok", "records": written}

    def run_audio_dir_to_splits():
        result = build_train_eval_from_audio_dir(
            Path(opts.audio_dir),
            Path(opts.output_dir),
            source_dataset=opts.source_dataset,
            eval_ratio=opts.eval_ratio,
            query_duration=opts.query_duration,
            query_stride=opts.query_stride,
            query_strategy=opts.query_strategy,
            silence_top_db=opts.silence_top_db,
            seed=opts.seed,
        )
        return {"status": "ok", **result}

    def run_inspect_audio_dir():
        result = inspect_audio_dir(
            Path(opts.audio_dir),
            query_duration=opts.query_duration,
            eval_ratio=opts.eval_ratio,
        )
        return {"status": "ok", **result}

    def run_validate_splits():
        return validate_splits(Path(opts.manifests_dir))

    handlers = {
        "csv-to-catalog": run_csv_to_catalog,
        "audio-dir-to-splits": run_audio_dir_to_splits,
        "inspect-audio-dir": run_inspect_audio_dir,
        "validate-splits": run_validate_splits,
    }
    handler = handlers.get(opts.cmd)
    if handler is not None:
        print(json.dumps(handler(), ensure_ascii=False))


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()