setup_coverhunter_env.py 2.38 KB
#!/usr/bin/env python3
import argparse
import json
import subprocess
from pathlib import Path

# Interpreter used for every pip/verification call unless --python overrides it.
PYTHON_DEFAULT = "/usr/local/miniconda3/bin/python"

# Arguments appended to `pip install` for the first step: install from the
# requirements file, resolved relative to the working directory of the call.
PACKAGES = ["-r", "requirements.txt"]

# Additional distributions installed in a second `pip install` call.
EXTRA_PACKAGES = [
    # Deep-learning stack
    "torch", "torchaudio",
    "transformers", "huggingface_hub",
    # Audio loading and augmentation
    "librosa", "soundfile",
    "audiomentations",
]


def run(command, cwd):
    """Run *command* in *cwd* and capture its output as text.

    Args:
        command: Argument list handed to ``subprocess.run`` (no shell).
        cwd: Working directory for the child process.

    Returns:
        A ``subprocess.CompletedProcess`` whose ``returncode``, ``stdout`` and
        ``stderr`` are populated. When the process cannot be started at all
        (for example the executable or the working directory does not exist),
        a synthetic result with return code 127 and the error text in
        ``stderr`` is returned instead of raising, so the caller can record
        the failure like any other failed step.
    """
    try:
        return subprocess.run(command, cwd=cwd, text=True, capture_output=True)
    except OSError as exc:
        # 127 mirrors the shell convention for "command not found".
        return subprocess.CompletedProcess(
            command,
            127,
            stdout="",
            stderr=f"{type(exc).__name__}: {exc}",
        )


def main():
    """Install dependencies, verify the imports, and write a JSON report.

    Command-line options:
        --python: interpreter used for pip and for the verification snippet
            (defaults to ``PYTHON_DEFAULT``).
        --skip-install: skip both ``pip install`` steps and only verify.

    Every command runs with the parent of this script's directory as its
    working directory. Each step's command, return code and the last 4000
    characters of stdout/stderr are stored in
    ``reports/coverhunter_env_setup_report.json`` under that directory, and
    the report path is printed.

    Raises:
        SystemExit: with code 1 if any recorded step returned non-zero.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--python", default=PYTHON_DEFAULT)
    cli.add_argument("--skip-install", action="store_true")
    options = cli.parse_args()

    project_root = Path(__file__).resolve().parents[1]
    steps = []
    report = {
        "python": options.python,
        "cwd": str(project_root),
        "steps": steps,
    }

    def execute_step(name, command):
        # Run one command from the project root and log a trimmed transcript.
        outcome = run(command, project_root)
        steps.append({
            "name": name,
            "command": command,
            "returncode": outcome.returncode,
            "stdout": outcome.stdout[-4000:],
            "stderr": outcome.stderr[-4000:],
        })

    if not options.skip_install:
        pip_install = [options.python, "-m", "pip", "install"]
        execute_step("install_requirements", [*pip_install, *PACKAGES])
        execute_step("install_extra_packages", [*pip_install, *EXTRA_PACKAGES])

    # Import check executed inside the target interpreter, not this one.
    verify_snippet = (
        "import torch, transformers, librosa, soundfile, audiomentations; "
        "print({'torch': torch.__version__, 'cuda': torch.cuda.is_available(), 'transformers': transformers.__version__})"
    )
    execute_step("verify_environment", [options.python, "-c", verify_snippet])

    report_path = project_root / "reports" / "coverhunter_env_setup_report.json"
    report_path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(report, indent=2)
    report_path.write_text(serialized)
    print(report_path)

    # The report is written first so failures are still inspectable on disk.
    failed_steps = [entry for entry in steps if entry["returncode"] != 0]
    if failed_steps:
        raise SystemExit(1)


# Run the setup only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()