ab_smoke_segmentation.py 6.06 KB
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import shutil
import subprocess
from pathlib import Path


# Interpreter used to launch the adapter script for each strategy run.
# NOTE(review): hard-coded absolute path; confirm it exists on the machine
# that runs this script.
PYTHON = "/usr/local/miniconda3/bin/python"
# Strategy names compared when --strategies is not supplied on the command line.
DEFAULT_STRATEGIES = [
    "random",
    "silence_aware",
    "high_energy",
    "beat_aware",
    "repeated_section_aware",
    "hybrid",
]


def run(cmd: list[str], cwd: Path) -> str:
    """Execute ``cmd`` with ``cwd`` as working directory and return its stdout.

    Output is decoded as text. A non-zero exit status propagates as
    ``subprocess.CalledProcessError`` (raised by ``subprocess.check_output``).
    """
    working_dir = str(cwd)
    captured = subprocess.check_output(cmd, cwd=working_dir, text=True)
    return captured


def parse_last_json(text: str) -> dict:
    """Extract the last JSON object embedded in ``text``.

    The function first looks for the right-most ``{`` from which the rest of
    ``text`` is one complete JSON object (optionally followed by JSON
    whitespace). If no such suffix exists, e.g. because more output was
    printed after the object, it falls back to scanning left to right and
    returns the last top-level object that decodes successfully.

    Args:
        text: Arbitrary command output that may contain JSON objects mixed
            with other text.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: If ``text`` contains no decodable JSON object.
    """
    decoder = json.JSONDecoder()

    # Pass 1: right-to-left, accept only an object that runs to the end of
    # the text. raw_decode parses in place, so no suffix copies are made.
    start = text.rfind("{")
    while start != -1:
        try:
            obj, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            # Same whitespace set that json.loads tolerates after a document.
            if not text[end:].strip(" \t\n\r"):
                return obj
        start = text.rfind("{", 0, start)

    # Pass 2: tolerate trailing non-JSON output. Walking forward and jumping
    # past each decoded object ensures nested objects are never mistaken for
    # the top-level one.
    last_obj = None
    pos = text.find("{")
    while pos != -1:
        try:
            obj, end = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            pos = text.find("{", pos + 1)
        else:
            last_obj = obj
            pos = text.find("{", end)

    if last_obj is None:
        raise ValueError("No JSON object found in command output")
    return last_obj


def prepare_subset(src_dir: Path, subset_dir: Path, limit: int) -> dict:
    """Mirror the first ``limit`` MP3 files of ``src_dir`` into ``subset_dir``.

    Files are found recursively, ordered by sorted path, and copied with
    their directory layout relative to ``src_dir`` preserved. A destination
    that already exists is left untouched.

    Returns:
        A dict with the source directory, the subset directory, the number of
        selected files and up to five of the destination paths.
    """
    sources = sorted(src_dir.rglob("*.mp3"))[:limit]
    targets = [subset_dir / path.relative_to(src_dir) for path in sources]

    subset_dir.mkdir(parents=True, exist_ok=True)
    for source, target in zip(sources, targets):
        target.parent.mkdir(parents=True, exist_ok=True)
        if target.exists():
            # Already mirrored; keep the existing copy as-is.
            continue
        shutil.copy2(source, target)

    mirrored = [str(target) for target in targets]
    summary = {
        "source_dir": str(src_dir),
        "subset_dir": str(subset_dir),
        "num_files": len(mirrored),
        "sample_files": mirrored[:5],
    }
    return summary


def train_strategy_for_query(strategy: str) -> str:
    """Return the training segment strategy paired with a query strategy.

    ``"sliding"`` is mapped to ``"random"``; any other name is returned
    unchanged.
    """
    return "random" if strategy == "sliding" else strategy


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the strategy comparison."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="fma")
    parser.add_argument("--input-dir", default="data/raw/fma_small_audio")
    parser.add_argument("--work-root", default="data/ab_smoke_segmentation")
    parser.add_argument("--subset-size", type=int, default=12)
    parser.add_argument("--query-duration", type=float, default=8.0)
    parser.add_argument("--query-stride", type=float, default=None)
    parser.add_argument("--train-epochs", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max-test-queries", type=int, default=None)
    parser.add_argument("--strategies", nargs="*", default=DEFAULT_STRATEGIES)
    parser.add_argument("--output-json", default=None)
    parser.add_argument("--resume", action="store_true")
    return parser


def _load_cached_results(progress_path: Path) -> dict:
    """Load per-strategy results from a previous run, keyed by strategy name.

    Resuming is best-effort: an unreadable or malformed progress file yields
    an empty mapping (with a warning on stderr) instead of aborting the run.
    Entries lacking the fields needed for the final ranking are skipped so
    the corresponding strategies are simply executed again.
    """
    import sys  # local import: only needed for the warning below

    required = ("strategy", "top1", "topk", "num_queries")
    try:
        payload = json.loads(progress_path.read_text(encoding="utf-8"))
        return {
            item["strategy"]: item
            for item in payload.get("strategies", [])
            if isinstance(item, dict) and all(key in item for key in required)
        }
    except (OSError, ValueError, TypeError, AttributeError) as exc:
        # ValueError covers both JSONDecodeError and UnicodeDecodeError.
        print(
            f"warning: ignoring unusable progress file {progress_path}: {exc}",
            file=sys.stderr,
        )
        return {}


def _build_command(
    args: argparse.Namespace, strategy: str, subset_dir: Path, smoke_root: Path
) -> list[str]:
    """Assemble the adapter command line for one strategy run."""
    cmd = [
        PYTHON,
        "src/data/external_adapters.py",
        "smoke-local",
        args.dataset,
        str(subset_dir),
        "--output-root",
        str(smoke_root),
        "--eval-ratio",
        "0.2",
        "--query-duration",
        str(args.query_duration),
        "--query-strategy",
        strategy,
        "--segment-strategy",
        train_strategy_for_query(strategy),
        "--train-epochs",
        str(args.train_epochs),
        "--batch-size",
        str(args.batch_size),
        "--device",
        args.device,
    ]
    # Optional flags are only passed when the user supplied a value.
    if args.max_test_queries is not None:
        cmd.extend(["--max-test-queries", str(args.max_test_queries)])
    cmd.extend(["--seed", str(args.seed)])
    if args.query_stride is not None:
        cmd.extend(["--query-stride", str(args.query_stride)])
    return cmd


def main():
    """Run the adapter once per strategy and print a ranked JSON report.

    Steps:
      1. Copy a subset of the input audio into ``<work-root>/subset_audio``.
      2. For each requested strategy, wipe ``<work-root>/<strategy>``, run the
         adapter script there and read the evaluation report it points to.
         With ``--resume``, strategies already present in
         ``<work-root>/progress.json`` are reused instead of re-run.
      3. After every completed strategy, rewrite ``progress.json``.
      4. Sort results by ``top1``, ``topk`` and ``num_queries`` (descending),
         print the report to stdout and, if ``--output-json`` is given, also
         write it to that file.

    Raises:
        subprocess.CalledProcessError: If an adapter run exits non-zero.
        ValueError: If an adapter run prints no JSON summary.
    """
    args = _build_arg_parser().parse_args()

    # Paths given on the command line are interpreted relative to the
    # directory one level above the one containing this script.
    repo = Path(__file__).resolve().parents[1]
    input_dir = (repo / args.input_dir).resolve()
    work_root = (repo / args.work_root).resolve()
    subset_dir = work_root / "subset_audio"
    subset_info = prepare_subset(input_dir, subset_dir, args.subset_size)

    progress_path = work_root / "progress.json"
    cached_results = {}
    if args.resume and progress_path.exists():
        cached_results = _load_cached_results(progress_path)

    # Settings shared by the progress file and the final report.
    run_config = {
        "dataset": args.dataset,
        "subset": subset_info,
        "query_duration": args.query_duration,
        "query_stride": args.query_stride,
        "train_epochs": args.train_epochs,
        "batch_size": args.batch_size,
        "device": args.device,
    }

    results = []
    for strategy in args.strategies:
        if strategy in cached_results:
            results.append(cached_results[strategy])
            continue

        # Start every strategy from an empty output directory.
        smoke_root = work_root / strategy
        if smoke_root.exists():
            shutil.rmtree(smoke_root)
        smoke_root.mkdir(parents=True, exist_ok=True)

        output = run(_build_command(args, strategy, subset_dir, smoke_root), cwd=repo)
        summary = parse_last_json(output)

        eval_json = Path(summary["eval_json"])
        # The child ran with cwd=repo, so a relative path it reports must be
        # resolved against repo; joining leaves an absolute path unchanged.
        # NOTE(review): the report's encoding is decided by the adapter, which
        # is not visible here, so the platform default is kept for this read.
        eval_report = json.loads((repo / eval_json).read_text())
        results.append({
            "strategy": strategy,
            "train_segment_strategy": train_strategy_for_query(strategy),
            "num_queries": eval_report["num_queries"],
            "top1": eval_report["top1"],
            "topk": eval_report["topk"],
            "eval_json": str(eval_json),
            "report_dir": summary["report_dir"],
            "sample_failures": eval_report.get("sample_failures", [])[:3],
        })

        # Checkpoint after each strategy so --resume can skip finished work.
        # ensure_ascii=False emits raw non-ASCII characters, so the file must
        # be written as UTF-8 regardless of the platform default encoding.
        progress_payload = {**run_config, "strategies": results}
        progress_path.write_text(
            json.dumps(progress_payload, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )

    results.sort(key=lambda x: (x["top1"], x["topk"], x["num_queries"]), reverse=True)
    report = {
        **run_config,
        "max_test_queries": args.max_test_queries,
        "strategies": results,
        "winner": results[0] if results else None,
    }
    text = json.dumps(report, ensure_ascii=False, indent=2)
    if args.output_json:
        out = Path(args.output_json)
        out.parent.mkdir(parents=True, exist_ok=True)
        out.write_text(text, encoding="utf-8")
    print(text)


# Run the comparison only when this file is executed as a script.
if __name__ == "__main__":
    main()