ab_smoke_bucketed.py 6.7 KB
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import os
import shutil
import subprocess
from pathlib import Path
from statistics import mean

# Interpreter used to launch scripts/ab_smoke_segmentation.py for each bucket.
# Defaults to the historical miniconda install; set AB_SMOKE_PYTHON to run the
# child benchmark under a different interpreter (e.g. a virtualenv).
PYTHON = os.environ.get("AB_SMOKE_PYTHON", "/usr/local/miniconda3/bin/python")


def run(cmd: list[str], cwd: Path) -> str:
    """Execute *cmd* with working directory *cwd* and return its stdout as text.

    stderr is not captured and flows through to this process's stderr.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero status.
    """
    completed = subprocess.run(
        cmd,
        cwd=str(cwd),
        stdout=subprocess.PIPE,
        text=True,
        check=True,
    )
    return completed.stdout


def collect_files(
    input_dir: Path,
    patterns: list[str],
    limit: int | None = None,
    *,
    suffixes: tuple[str, ...] = (".mp3",),
) -> list[Path]:
    """Return resolved files under *input_dir* that match any of *patterns*.

    Patterns are applied in the given order, and the matches of each pattern are
    visited in sorted order. Directories and paths whose suffix (compared
    case-insensitively) is not in *suffixes* are skipped. A file matched by
    several patterns is returned only once; duplicates are detected on the
    resolved path.

    Args:
        input_dir: Directory the glob patterns are evaluated against.
        patterns: ``Path.glob`` patterns, e.g. ``"*.mp3"`` or ``"rock/**/*.mp3"``.
        limit: Maximum number of files to return. ``None`` means no limit; a
            value ``<= 0`` yields an empty list.
        suffixes: Accepted file suffixes, including the leading dot.

    Returns:
        Resolved paths in first-match order.
    """
    # Without this guard, limit=0 would still return one file, because the limit
    # is only checked after a file has been appended.
    if limit is not None and limit <= 0:
        return []
    wanted = {suffix.lower() for suffix in suffixes}
    seen: set[Path] = set()
    files: list[Path] = []
    for pattern in patterns:
        for path in sorted(input_dir.glob(pattern)):
            if not path.is_file() or path.suffix.lower() not in wanted:
                continue
            resolved = path.resolve()
            if resolved in seen:
                continue
            seen.add(resolved)
            files.append(resolved)
            if limit is not None and len(files) >= limit:
                return files
    return files


def ensure_bucket_subset(input_dir: Path, bucket_dir: Path, patterns: list[str], limit: int | None) -> dict:
    """Mirror the files selected by *patterns* from *input_dir* into *bucket_dir*.

    Files are chosen with ``collect_files`` (at most *limit* of them) and copied
    to the same relative path under *bucket_dir*. A destination file that
    already exists is reused without copying, so repeated runs are cheap. Files
    left in *bucket_dir* by earlier runs with different patterns are NOT removed.

    Returns:
        ``{"num_files": <number of selected files>, "sample_files": <up to 5
        destination paths as strings>}``.
    """
    bucket_dir.mkdir(parents=True, exist_ok=True)
    # collect_files returns resolved paths, so the root must be resolved too.
    # Otherwise relative_to() raises ValueError for a relative or symlinked
    # input_dir.
    input_root = input_dir.resolve()
    files = collect_files(input_root, patterns, limit=limit)
    staged: list[str] = []
    for src in files:
        dst = bucket_dir / src.relative_to(input_root)
        dst.parent.mkdir(parents=True, exist_ok=True)
        if not dst.exists():
            # Stream the copy to a temporary name and rename it into place, so an
            # interrupted run never leaves a truncated file that a later run
            # would mistake for a finished copy.
            tmp = dst.with_name(dst.name + ".part")
            shutil.copyfile(src, tmp)
            tmp.replace(dst)
        staged.append(str(dst))
    return {
        "num_files": len(staged),
        "sample_files": staged[:5],
    }


def _load_bucket_specs(config_path: Path) -> list[dict]:
    """Parse the bucket config file into a list of ``{"name", "patterns", ...}`` specs.

    Two layouts are accepted:

    * ``{"buckets": [{"name": ..., "patterns": [...], "subset_size": ...}, ...]}``
    * ``{"<bucket_name>": [patterns], ...}`` (shorthand, no per-bucket options)

    A bare string in place of a pattern list is treated as a single pattern.

    Raises:
        ValueError: if the top-level JSON value is not an object.
    """
    config = json.loads(config_path.read_text(encoding="utf-8"))
    if not isinstance(config, dict):
        raise ValueError("bucket config must be an object")
    if "buckets" in config:
        specs = [dict(spec) for spec in config["buckets"]]
    else:
        specs = [{"name": name, "patterns": patterns} for name, patterns in config.items()]
    for spec in specs:
        # A lone string would otherwise be globbed character by character.
        if isinstance(spec.get("patterns"), str):
            spec["patterns"] = [spec["patterns"]]
    return specs


def _aggregate_strategies(bucket_reports: list[dict]) -> dict[str, dict]:
    """Average per-strategy metrics over all non-skipped bucket runs.

    Each bucket run counts once (an unweighted mean across buckets). Each row in
    a bucket's ``"strategies"`` list is assumed to carry ``strategy``, ``top1``,
    ``topk`` and ``num_queries``. That is the shape read here; confirm it
    against ab_smoke_segmentation.py's report.
    """
    per_strategy: dict[str, dict[str, list[float]]] = {}
    for bucket in bucket_reports:
        if bucket.get("skipped"):
            continue
        for row in bucket["strategies"]:
            agg = per_strategy.setdefault(row["strategy"], {"top1": [], "topk": [], "num_queries": []})
            agg["top1"].append(row["top1"])
            agg["topk"].append(row["topk"])
            agg["num_queries"].append(row["num_queries"])
    # Every entry has at least one value: an entry is only created when a row is seen.
    return {
        strategy: {
            "bucket_runs": len(vals["top1"]),
            "mean_top1": round(mean(vals["top1"]), 4),
            "mean_topk": round(mean(vals["topk"]), 4),
            "mean_num_queries": round(mean(vals["num_queries"]), 4),
        }
        for strategy, vals in per_strategy.items()
    }


def main() -> None:
    """Run ``scripts/ab_smoke_segmentation.py`` once per bucket and aggregate results.

    For each bucket in ``--bucket-config``:

    1. The matching files under ``--input-dir`` are copied into
       ``<work-root>/<bucket>/bucket_input``.
    2. The segmentation smoke benchmark is run on that subset, and its JSON
       report is read back.

    Buckets with fewer than ``--min-files`` files are recorded as skipped. The
    combined report holds the per-bucket results plus per-strategy means across
    buckets. It is printed, and also written to ``--output-json`` when given.
    """
    parser = argparse.ArgumentParser(description="Run bucket/style-aware segmented smoke benchmarks")
    parser.add_argument("--dataset", default="fma")
    parser.add_argument("--input-dir", default="data/raw/fma_small_audio")
    parser.add_argument("--bucket-config", required=True, help="JSON file with {buckets:[{name,patterns,subset_size?}]} or {bucket_name:[patterns]}")
    parser.add_argument("--work-root", default="/tmp/ab_smoke_bucketed")
    parser.add_argument("--query-duration", type=float, default=8.0)
    parser.add_argument("--query-stride", type=float, default=None)
    parser.add_argument("--train-epochs", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max-test-queries", type=int, default=None)
    parser.add_argument("--default-subset-size", type=int, default=16)
    parser.add_argument("--min-files", type=int, default=2)
    parser.add_argument("--strategies", nargs="*", default=["high_energy", "hybrid"])
    parser.add_argument("--output-json", default=None)
    args = parser.parse_args()

    repo = Path(__file__).resolve().parents[1]
    # A relative --input-dir is taken relative to the repo root. --work-root and
    # --bucket-config are taken relative to the current working directory.
    input_dir = (repo / args.input_dir).resolve()
    work_root = Path(args.work_root).resolve()
    bucket_specs = _load_bucket_specs(Path(args.bucket_config))

    bucket_reports = []
    for spec in bucket_specs:
        name = spec["name"]
        patterns = spec["patterns"]
        subset_size = spec.get("subset_size", args.default_subset_size)
        bucket_root = work_root / name
        subset_dir = bucket_root / "bucket_input"
        subset_info = ensure_bucket_subset(input_dir, subset_dir, patterns, subset_size)
        num_files = subset_info["num_files"]
        if num_files < args.min_files:
            bucket_reports.append({
                "bucket": name,
                "patterns": patterns,
                "subset_size": num_files,
                "skipped": True,
                "reason": f"num_files<{args.min_files}",
                "subset": subset_info,
            })
            continue

        out_json = bucket_root / "bucket_report.json"
        cmd = [
            PYTHON,
            "scripts/ab_smoke_segmentation.py",
            "--dataset", args.dataset,
            "--input-dir", str(subset_dir),
            "--work-root", str(bucket_root / "run"),
            "--subset-size", str(num_files),
            "--query-duration", str(args.query_duration),
            "--train-epochs", str(args.train_epochs),
            "--batch-size", str(args.batch_size),
            "--device", args.device,
            "--seed", str(args.seed),
            "--strategies", *args.strategies,
        ]
        if args.max_test_queries is not None:
            cmd += ["--max-test-queries", str(args.max_test_queries)]
        if args.query_stride is not None:
            cmd += ["--query-stride", str(args.query_stride)]
        cmd += ["--output-json", str(out_json)]
        # Remove any report left by a previous run. Otherwise a child that exits
        # 0 without writing one would make us silently reuse stale results.
        out_json.unlink(missing_ok=True)
        # The child's stdout is captured by run() and discarded; results come
        # from the JSON file it writes.
        run(cmd, cwd=repo)
        result = json.loads(out_json.read_text(encoding="utf-8"))
        bucket_reports.append({
            "bucket": name,
            "patterns": patterns,
            "subset_size": num_files,
            "skipped": False,
            "subset": subset_info,
            "winner": result.get("winner"),
            "strategies": result.get("strategies", []),
        })

    report = {
        "dataset": args.dataset,
        "input_dir": str(input_dir),
        "bucket_config": str(Path(args.bucket_config).resolve()),
        "query_duration": args.query_duration,
        "query_stride": args.query_stride,
        "train_epochs": args.train_epochs,
        "batch_size": args.batch_size,
        "device": args.device,
        "seed": args.seed,
        "max_test_queries": args.max_test_queries,
        "strategies": args.strategies,
        "buckets": bucket_reports,
        "aggregate": _aggregate_strategies(bucket_reports),
    }
    text = json.dumps(report, ensure_ascii=False, indent=2)
    if args.output_json:
        out = Path(args.output_json)
        out.parent.mkdir(parents=True, exist_ok=True)
        # ensure_ascii=False may emit non-ASCII text, so write it as UTF-8
        # explicitly instead of relying on the locale encoding.
        out.write_text(text, encoding="utf-8")
    print(text)


if __name__ == "__main__":
    # Script entry point: the benchmark runs only when this file is executed
    # directly, not when it is imported.
    main()