# manifest_tools.py (9.38 KB)
"""External dataset manifest conversion templates."""

from __future__ import annotations

import argparse
import csv
import json
import random
import shutil
import statistics
from pathlib import Path
from typing import List, Dict

import numpy as np
import soundfile as sf


def write_catalog(records: List[Dict], output_path: Path):
    """Serialize manifest records to ``output_path`` as indented JSON.

    Missing parent directories are created first; an existing file at
    ``output_path`` is overwritten.

    Args:
        records: JSON-serializable manifest entries.
        output_path: Destination file path.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False writes non-ASCII characters verbatim, so the file
    # encoding is pinned to UTF-8 instead of following the platform locale
    # (which may be unable to encode them).
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(records, f, indent=2, ensure_ascii=False)


def csv_to_catalog(csv_path: Path, output_path: Path, path_field: str = "audio_path", id_field: str = "song_id"):
    """Convert a CSV track listing into a reference catalog JSON file.

    Every CSV row becomes one record of type ``"reference"``. The optional
    ``duration`` column defaults to ``0.0`` and the optional
    ``source_dataset`` column defaults to ``"external"``; both defaults also
    apply when the cell is present but empty.

    Args:
        csv_path: CSV file with a header row.
        output_path: Destination JSON file, written via ``write_catalog``.
        path_field: Name of the column holding the audio path.
        id_field: Name of the column holding the song identifier.

    Returns:
        The number of records written.

    Raises:
        KeyError: If the header lacks ``id_field`` or ``path_field``.
        ValueError: If a non-empty ``duration`` cell is not a number.
    """
    # utf-8-sig reads plain UTF-8 and additionally strips a leading BOM
    # (common in spreadsheet exports) that would otherwise be glued onto the
    # first column name and break the header lookup.
    with open(csv_path, newline="", encoding="utf-8-sig") as f:
        records = [
            {
                "song_id": row[id_field],
                "audio_path": row[path_field],
                "duration": float(row.get("duration", 0.0) or 0.0),
                "type": "reference",
                # `or` also covers blank cells and short rows (DictReader
                # fills missing trailing cells with None).
                "source_dataset": row.get("source_dataset") or "external",
            }
            for row in csv.DictReader(f)
        ]
    write_catalog(records, output_path)
    return len(records)


def _query_offsets(
    duration: float,
    query_duration: float,
    query_stride: float | None,
    rng: random.Random,
) -> list[float]:
    """Return the query start offsets (seconds, 3 decimals) for one track."""
    max_offset = max(0.0, duration - query_duration)
    if query_stride and query_stride > 0:
        # Regular grid from 0 up to the last valid start; the epsilon keeps a
        # grid point that lands exactly on max_offset inside the range.
        offsets = [round(x, 3) for x in np.arange(0.0, max_offset + 1e-9, query_stride).tolist()]
        if not offsets:
            offsets = [0.0]
        # Add a final window flush with the end of the track when the grid
        # stops short of it.
        tail = round(max_offset, 3)
        if offsets[-1] < tail:
            offsets.append(tail)
        return offsets
    # No stride: one window at a random start (0.0 when the track is exactly
    # one window long, which also avoids consuming a random number).
    return [round(rng.uniform(0.0, max_offset) if max_offset > 0 else 0.0, 3)]


def build_train_eval_from_audio_dir(
    audio_dir: Path,
    output_dir: Path,
    source_dataset: str,
    exts: tuple[str, ...] = (".wav", ".mp3", ".flac", ".ogg"),
    eval_ratio: float = 0.2,
    query_duration: float = 8.0,
    query_stride: float | None = None,
    seed: int = 42,
):
    """Copy an audio directory into ``output_dir`` and write split manifests.

    Audio files under ``audio_dir`` (searched recursively, sorted by path) are
    copied to ``output_dir/audio`` as ``{source_dataset}_{index:05d}{suffix}``;
    a copy is skipped when the target already exists. Each file yields one
    ``"reference"`` record, and each file at least ``query_duration`` seconds
    long yields one or more ``"clean"`` query records that are randomly
    assigned to the test split with probability ``eval_ratio``.

    Files written to ``output_dir/manifests``: ``catalog.json`` (references),
    ``train.json`` and ``test.json`` (queries followed by all references) and
    an empty ``val.json``.

    Args:
        audio_dir: Directory scanned recursively for audio files.
        output_dir: Root directory for the copied audio and the manifests.
        source_dataset: Label stored on every record and used as the prefix
            of song ids and copied file names.
        exts: Lower-case file suffixes treated as audio.
        eval_ratio: Probability that a query goes to the test split.
        query_duration: Length in seconds of every query window.
        query_stride: Step in seconds between query offsets. When ``None`` or
            not positive, a single randomly placed query is taken per file.
        seed: Seed for the random generator (offsets and split assignment).

    Returns:
        A summary dict with the keys ``catalog``, ``train_queries``,
        ``test_queries``, ``query_duration``, ``query_stride`` and
        ``output_dir`` (the manifests directory as a string).
    """
    rng = random.Random(seed)
    # is_file() keeps directories whose name merely ends in an audio suffix
    # away from the copy step, where they would raise.
    files = [p for p in sorted(audio_dir.rglob("*")) if p.suffix.lower() in exts and p.is_file()]
    output_dir.mkdir(parents=True, exist_ok=True)
    manifests_dir = output_dir / "manifests"
    manifests_dir.mkdir(parents=True, exist_ok=True)
    audio_out_dir = output_dir / "audio"
    audio_out_dir.mkdir(parents=True, exist_ok=True)

    refs = []
    train = []
    test = []

    for idx, path in enumerate(files):
        song_id = f"{source_dataset}_{idx:05d}"
        target_path = audio_out_dir / f"{song_id}{path.suffix.lower()}"
        if not target_path.exists():
            shutil.copy2(path, target_path)
        rel = target_path.relative_to(output_dir)

        # Best effort: a file the audio backend cannot read is still
        # catalogued, with duration 0.0.
        try:
            duration = float(sf.info(str(path)).duration)
        except Exception:
            duration = 0.0

        refs.append(
            {
                "song_id": song_id,
                "audio_path": str(rel),
                "duration": duration,
                "type": "reference",
                "source_dataset": source_dataset,
            }
        )

        if duration < query_duration:
            continue

        for seg_idx, offset in enumerate(_query_offsets(duration, query_duration, query_stride, rng)):
            query = {
                "song_id": song_id,
                "audio_path": str(rel),
                "duration": query_duration,
                "type": "clean",
                "offset": offset,
                "segment_type": "external_query",
                "source_dataset": source_dataset,
                "query_index": seg_idx,
            }
            if rng.random() < eval_ratio:
                test.append(query)
            else:
                train.append(query)

    # Keep both splits non-empty whenever there are at least two queries to
    # share. The donor must retain one query itself; moving a lone query
    # would only relocate the empty split.
    if not train and len(test) >= 2:
        train.append(test.pop())
    elif not test and len(train) >= 2:
        test.append(train.pop())

    write_catalog(refs, manifests_dir / "catalog.json")
    write_catalog(train + refs, manifests_dir / "train.json")
    write_catalog(test + refs, manifests_dir / "test.json")
    write_catalog([], manifests_dir / "val.json")
    return {
        "catalog": len(refs),
        "train_queries": len(train),
        "test_queries": len(test),
        "query_duration": query_duration,
        "query_stride": query_stride,
        "output_dir": str(manifests_dir),
    }


def inspect_audio_dir(
    audio_dir: Path,
    exts: tuple[str, ...] = (".wav", ".mp3", ".flac", ".ogg"),
    query_duration: float = 8.0,
    eval_ratio: float = 0.2,
):
    """Summarize an audio directory without copying or writing anything.

    Args:
        audio_dir: Directory scanned recursively for audio files.
        exts: Lower-case file suffixes treated as audio.
        query_duration: Minimum duration in seconds for a file to count as
            eligible for a query.
        eval_ratio: Fraction of eligible files suggested for the test split.

    Returns:
        A dict with ``audio_dir``, ``num_audio_files``,
        ``eligible_query_files``, ``query_duration``,
        ``recommended_train_queries``, ``recommended_test_queries`` and
        ``duration_stats`` (``min``/``median``/``max`` rounded to 3 decimals,
        all ``0.0`` when no audio file is found). The two recommended counts
        always add up to ``eligible_query_files``.
    """
    # is_file() excludes directories whose name merely ends in an audio suffix.
    files = [p for p in sorted(audio_dir.rglob("*")) if p.suffix.lower() in exts and p.is_file()]

    durations = []
    for path in files:
        # Best effort: a file the audio backend cannot read counts as 0.0 s.
        try:
            duration = float(sf.info(str(path)).duration)
        except Exception:
            duration = 0.0
        durations.append(duration)

    eligible = sum(1 for d in durations if d >= query_duration)

    if eligible >= 2:
        # At least one test query, but always leave one for training.
        test_queries = min(max(1, round(eligible * eval_ratio)), eligible - 1)
    else:
        # Zero or one eligible file: nothing to split, it all goes to test.
        test_queries = eligible
    train_queries = eligible - test_queries

    if durations:
        duration_stats = {
            "min": round(min(durations), 3),
            # True median: the mean of the two middle values for even counts.
            "median": round(statistics.median(durations), 3),
            "max": round(max(durations), 3),
        }
    else:
        duration_stats = {"min": 0.0, "median": 0.0, "max": 0.0}

    return {
        "audio_dir": str(audio_dir),
        "num_audio_files": len(files),
        "eligible_query_files": eligible,
        "query_duration": query_duration,
        "recommended_train_queries": train_queries,
        "recommended_test_queries": test_queries,
        "duration_stats": duration_stats,
    }


def validate_splits(manifests_dir: Path):
    """Check that a manifests directory holds a usable set of split files.

    Args:
        manifests_dir: Directory expected to contain ``catalog.json``,
            ``train.json``, ``test.json`` and ``val.json``.

    Returns:
        ``{"ok": False, "missing_files": [...]}`` when any of the four files
        is absent. Otherwise a report dict with ``ok``, ``errors``, the
        counts ``catalog_references``, ``train_queries``, ``test_queries``
        and ``val_queries``, and the sorted ``source_datasets`` seen.
        Possible errors: ``catalog_has_no_references``,
        ``train_has_no_queries``, ``test_has_no_queries`` and
        ``mixed_source_dataset_values``.
    """
    required = ["catalog.json", "train.json", "test.json", "val.json"]
    missing = [name for name in required if not (manifests_dir / name).exists()]
    if missing:
        return {"ok": False, "missing_files": missing}

    # Decode as UTF-8 explicitly (the standard encoding for JSON files)
    # rather than relying on the platform's locale default.
    catalog, train, test, val = (
        json.loads((manifests_dir / name).read_text(encoding="utf-8")) for name in required
    )

    # Only "reference" entries count in the catalog; in the split files every
    # entry that is not a reference counts as a query.
    catalog_refs = [x for x in catalog if x.get("type") == "reference"]
    train_queries = [x for x in train if x.get("type") != "reference"]
    test_queries = [x for x in test if x.get("type") != "reference"]
    val_queries = [x for x in val if x.get("type") != "reference"]

    source_values = {
        x.get("source_dataset", "unknown")
        for x in catalog_refs + train_queries + test_queries + val_queries
    }

    errors = []
    if not catalog_refs:
        errors.append("catalog_has_no_references")
    if not train_queries:
        errors.append("train_has_no_queries")
    if not test_queries:
        errors.append("test_has_no_queries")
    # An empty val split is accepted; only mixed provenance is flagged.
    if len(source_values) > 1:
        errors.append("mixed_source_dataset_values")

    return {
        "ok": not errors,
        "errors": errors,
        "catalog_references": len(catalog_refs),
        "train_queries": len(train_queries),
        "test_queries": len(test_queries),
        "val_queries": len(val_queries),
        "source_datasets": sorted(source_values),
    }


def main():
    """Command-line entry point: parse arguments, run one subcommand, print JSON.

    Subcommands are ``csv-to-catalog``, ``audio-dir-to-splits``,
    ``inspect-audio-dir`` and ``validate-splits``. The first three print their
    result behind a leading ``"status": "ok"`` key; ``validate-splits`` prints
    its report exactly as returned.
    """
    cli = argparse.ArgumentParser()
    commands = cli.add_subparsers(dest="cmd", required=True)

    to_catalog = commands.add_parser("csv-to-catalog")
    to_catalog.add_argument("csv_path")
    to_catalog.add_argument("output_path")
    to_catalog.add_argument("--path-field", default="audio_path")
    to_catalog.add_argument("--id-field", default="song_id")

    to_splits = commands.add_parser("audio-dir-to-splits")
    to_splits.add_argument("audio_dir")
    to_splits.add_argument("output_dir")
    to_splits.add_argument("--source-dataset", required=True)
    to_splits.add_argument("--eval-ratio", type=float, default=0.2)
    to_splits.add_argument("--query-duration", type=float, default=8.0)
    to_splits.add_argument("--query-stride", type=float, default=None)
    to_splits.add_argument("--seed", type=int, default=42)

    inspect = commands.add_parser("inspect-audio-dir")
    inspect.add_argument("audio_dir")
    inspect.add_argument("--query-duration", type=float, default=8.0)
    inspect.add_argument("--eval-ratio", type=float, default=0.2)

    validate = commands.add_parser("validate-splits")
    validate.add_argument("manifests_dir")

    opts = cli.parse_args()
    chosen = opts.cmd

    # Each branch builds the JSON payload; it is printed once at the end.
    if chosen == "csv-to-catalog":
        written = csv_to_catalog(
            Path(opts.csv_path),
            Path(opts.output_path),
            opts.path_field,
            opts.id_field,
        )
        payload = {"status": "ok", "records": written}
    elif chosen == "audio-dir-to-splits":
        split_summary = build_train_eval_from_audio_dir(
            Path(opts.audio_dir),
            Path(opts.output_dir),
            source_dataset=opts.source_dataset,
            eval_ratio=opts.eval_ratio,
            query_duration=opts.query_duration,
            query_stride=opts.query_stride,
            seed=opts.seed,
        )
        payload = {"status": "ok", **split_summary}
    elif chosen == "inspect-audio-dir":
        dir_summary = inspect_audio_dir(
            Path(opts.audio_dir),
            query_duration=opts.query_duration,
            eval_ratio=opts.eval_ratio,
        )
        payload = {"status": "ok", **dir_summary}
    elif chosen == "validate-splits":
        payload = validate_splits(Path(opts.manifests_dir))
    else:
        # Unknown command: nothing to run and nothing to print.
        return

    print(json.dumps(payload, ensure_ascii=False))


if __name__ == "__main__":
    main()