# app.py (6.05 KB)
from __future__ import annotations

from pathlib import Path
from threading import Lock
from typing import Optional

import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from src.engines.chromaprint_matcher import ChromaprintMatcher
from src.engines.ecapa_embedder import ECAPAEmbedder
from src.engines.hybrid_engine import HybridEngine
from src.service.settings import ServiceSettings


class RecognizeRequest(BaseModel):
    """Body of ``POST /recognize``.

    Every optional field overrides the matching service setting; a missing or
    falsy value falls back to that setting (see ``_resolve``).
    """

    # Server-side path of the file to identify; checked for existence before
    # the engine is loaded.
    query_path: str
    # Directory searched for catalog/train/val/test JSON manifests.
    data_dir: Optional[str] = None
    # Model file handed to ECAPAEmbedder.
    model_path: Optional[str] = None
    # Prefix of the `{prefix}_embs.npy` / `{prefix}_ids.npy` index files;
    # chromaprint.pkl is read from the prefix's parent directory.
    index_prefix: Optional[str] = None
    # Passed as `top_n` to the engine's recognize call.
    # NOTE(review): not range-checked here.
    top_n: int = 5
    # Device string handed to ECAPAEmbedder (accepted values not visible here).
    device: Optional[str] = None


class BuildIndexRequest(BaseModel):
    """Body of ``POST /index/build``.

    ``data_dir``, ``model_path`` and ``device`` fall back to the service
    settings when missing or falsy (see ``_resolve``); ``output_dir`` is required.
    """

    # Directory the chromaprint and embedding indexes are built from.
    data_dir: Optional[str] = None
    # Model passed to build_embedding_index.
    model_path: Optional[str] = None
    # Destination directory (created if absent). The embedding index is built
    # under the "<output_dir>/reference" prefix, so presumably that is the
    # index_prefix to use in later /recognize calls — confirm.
    output_dir: str
    # Device passed to build_embedding_index.
    device: Optional[str] = None


# FastAPI application; /health and /ready report this same version string.
app = FastAPI(title="ACR Service", version="0.3.0")
# Service-wide defaults used for any request field left unset (see _resolve).
settings = ServiceSettings()
# Loaded engines keyed by (resolved data_dir, resolved model_path,
# resolved index_prefix, device); populated by _load_engine.
_engine_cache: dict[tuple[str, str, str, str], HybridEngine] = {}
# Guards every read and write of _engine_cache.
_cache_lock = Lock()


def _resolve(req_data_dir=None, req_model_path=None, req_index_prefix=None, req_device=None):
    """Merge per-request overrides with the service settings.

    Each argument wins when it is truthy; otherwise the same-named attribute
    of ``settings`` is used. The returned dict is keyed by ``data_dir``,
    ``model_path``, ``index_prefix`` and ``device``, which are exactly the
    keyword names ``_load_engine`` accepts.
    """
    overrides = {
        "data_dir": req_data_dir,
        "model_path": req_model_path,
        "index_prefix": req_index_prefix,
        "device": req_device,
    }
    # `or` short-circuits, so a setting is only read when the override is falsy.
    return {field: value or getattr(settings, field) for field, value in overrides.items()}


def _readiness_snapshot(data_dir: str, model_path: str, index_prefix: str) -> dict:
    chroma_path = str(Path(index_prefix).parent / "chromaprint.pkl")
    embs_path = f"{index_prefix}_embs.npy"
    ids_path = f"{index_prefix}_ids.npy"
    manifest_candidates = [str((Path(data_dir) / split).resolve()) for split in ["catalog.json", "train.json", "val.json", "test.json"] if (Path(data_dir) / split).exists()]
    files = {
        "data_dir": {"path": str(Path(data_dir).resolve()), "exists": Path(data_dir).exists()},
        "model": {"path": str(Path(model_path).resolve()), "exists": Path(model_path).exists()},
        "chromaprint_index": {"path": str(Path(chroma_path).resolve()), "exists": Path(chroma_path).exists()},
        "embedding_index": {"path": str(Path(embs_path).resolve()), "exists": Path(embs_path).exists()},
        "id_index": {"path": str(Path(ids_path).resolve()), "exists": Path(ids_path).exists()},
    }
    return {
        "ready": all(item["exists"] for item in files.values()),
        "files": files,
        "manifests": manifest_candidates,
    }


def _load_engine_uncached(data_dir: str, model_path: str, index_prefix: str, device: str) -> HybridEngine:
    """Build a HybridEngine from on-disk artifacts, bypassing the engine cache.

    All required files are checked before anything is loaded, so a missing
    artifact is reported without first paying for the chromaprint index and
    embedder construction.

    Args:
        data_dir: directory searched for catalog/train/val/test JSON manifests;
            each one present is passed to ``engine.load_metadata``.
        model_path: model file handed to ECAPAEmbedder.
        index_prefix: prefix of ``{prefix}_embs.npy`` / ``{prefix}_ids.npy``;
            ``chromaprint.pkl`` is read from the prefix's parent directory.
        device: device string handed to ECAPAEmbedder.

    Returns:
        The assembled HybridEngine.

    Raises:
        HTTPException: status 400 when the chromaprint index, the model, or
            either embedding index file is missing (checked in that order).
    """
    chroma_path = str(Path(index_prefix).parent / "chromaprint.pkl")
    embs_path = f"{index_prefix}_embs.npy"
    ids_path = f"{index_prefix}_ids.npy"

    # Validate every input before any of the potentially slow loaders run.
    if not Path(chroma_path).exists():
        raise HTTPException(status_code=400, detail=f"Missing chromaprint index: {chroma_path}")
    if not Path(model_path).exists():
        raise HTTPException(status_code=400, detail=f"Missing model: {model_path}")
    missing = [p for p in (embs_path, ids_path) if not Path(p).exists()]
    if missing:
        raise HTTPException(status_code=400, detail=f"Missing embedding index files: {', '.join(missing)}")

    matcher = ChromaprintMatcher()
    # NOTE(review): chromaprint.pkl is presumably unpickled by matcher.load, and
    # index_prefix may come straight from the request body. Loading a
    # client-chosen pickle can execute arbitrary code. Confirm, and restrict
    # paths to a trusted directory.
    matcher.load(chroma_path)
    embedder = ECAPAEmbedder(model_path=model_path, device=device)

    ref_embs = np.load(embs_path)
    # SECURITY: allow_pickle=True lets a crafted .npy execute code on load. This
    # is only safe if ids_path is trusted, which it is not when index_prefix is
    # supplied by the client.
    ref_ids = np.load(ids_path, allow_pickle=True).tolist()

    engine = HybridEngine(matcher, embedder, ref_embs, ref_ids)
    for split in ("catalog.json", "train.json", "val.json", "test.json"):
        manifest = Path(data_dir) / split
        if manifest.exists():
            engine.load_metadata(str(manifest))
    return engine


# Per-key locks so that concurrent first requests for the same engine
# configuration trigger a single load instead of one full load per request.
_engine_load_locks: dict[tuple[str, str, str, str], Lock] = {}


def _load_engine(data_dir: str, model_path: str, index_prefix: str, device: str) -> tuple[HybridEngine, bool]:
    """Return the engine for these artifacts, loading and caching it on first use.

    The cache key uses resolved paths, so different spellings of the same file
    share one engine. Concurrent callers with the same key wait on a per-key
    lock while a single thread loads; the others then read the cached result.
    Callers with different keys load independently.

    Returns:
        ``(engine, cache_hit)``, where ``cache_hit`` is False only for the call
        that actually performed the load.

    Raises:
        Whatever ``_load_engine_uncached`` raises. Nothing is cached on failure,
        so the next call retries.
    """
    key = (str(Path(data_dir).resolve()), str(Path(model_path).resolve()), str(Path(index_prefix).resolve()), device)
    with _cache_lock:
        cached = _engine_cache.get(key)
        if cached is not None:
            return cached, True
        key_lock = _engine_load_locks.setdefault(key, Lock())

    with key_lock:
        # Re-check: another thread may have finished the load while we waited.
        with _cache_lock:
            cached = _engine_cache.get(key)
        if cached is not None:
            return cached, True
        # Loading happens outside _cache_lock so hits for other keys are never blocked.
        engine = _load_engine_uncached(data_dir, model_path, index_prefix, device)
        with _cache_lock:
            _engine_cache[key] = engine
    return engine, False


def _cache_stats() -> dict:
    """Return the number of cached engines and a snapshot of their keys.

    The keys are copied while holding the cache lock, so the result is a
    consistent view even if another request modifies the cache afterwards.
    """
    with _cache_lock:
        snapshot = [*_engine_cache]
    return {"engine_cache_size": len(snapshot), "cache_keys": snapshot}


@app.get("/health")
def health():
    """Liveness endpoint that also reports whether the default artifacts exist.

    ``status`` is always ``"ok"``. ``ready`` is the readiness flag for the
    settings-derived data dir, model and index paths.
    """
    defaults = _resolve()
    snapshot = _readiness_snapshot(defaults["data_dir"], defaults["model_path"], defaults["index_prefix"])
    return {
        "status": "ok",
        "service": "acr",
        # Read from the app so the reported version cannot drift from the one
        # declared when the FastAPI app is constructed.
        "version": app.version,
        "ready": snapshot["ready"],
    }


@app.get("/ready")
def ready():
    """Detailed readiness report for the settings-derived artifacts.

    Merges the readiness snapshot (``ready``, ``files``, ``manifests``) with the
    current engine-cache statistics.
    """
    defaults = _resolve()
    snapshot = _readiness_snapshot(defaults["data_dir"], defaults["model_path"], defaults["index_prefix"])
    return {
        "service": "acr",
        # Single source of truth for the version: the FastAPI app itself.
        "version": app.version,
        **snapshot,
        **_cache_stats(),
    }


@app.get("/config")
def config():
    """Return the active service settings as a plain dict."""
    current = settings.model_dump()
    return current


@app.get("/cache")
def cache_status():
    """Expose the engine-cache size and keys reported by ``_cache_stats``."""
    stats = _cache_stats()
    return stats


@app.post("/recognize")
def recognize(req: RecognizeRequest):
    """Identify the query file against the resolved reference index.

    Request fields left unset fall back to service settings via ``_resolve``.
    The engine for the resolved configuration is loaded once and then reused
    from the cache.

    Returns:
        ``{"cache_hit", "resolved", "result"}``, where ``result`` is whatever
        the engine's ``recognize`` returns.

    Raises:
        HTTPException: 400 if ``top_n`` is not positive or ``query_path`` is not
            a regular file. Missing model/index files are also reported as 400
            by the engine loader.
    """
    if req.top_n < 1:
        raise HTTPException(status_code=400, detail=f"top_n must be >= 1, got {req.top_n}")
    # is_file() rather than exists(): a directory would otherwise be handed to
    # the engine as if it were a query file.
    # NOTE(review): query_path (like model_path/index_prefix) is an arbitrary
    # server path chosen by the client. Consider confining it to an allowed root.
    if not Path(req.query_path).is_file():
        raise HTTPException(status_code=400, detail=f"Missing query file: {req.query_path}")

    resolved = _resolve(req.data_dir, req.model_path, req.index_prefix, req.device)
    engine, cache_hit = _load_engine(**resolved)
    result = engine.recognize(req.query_path, top_n=req.top_n)
    return {
        "cache_hit": cache_hit,
        "resolved": resolved,
        "result": result,
    }


@app.post("/index/build")
def build_index(req: BuildIndexRequest):
    """Build the chromaprint and embedding indexes into ``req.output_dir``.

    The embedding index is written under the ``<output_dir>/reference`` prefix.
    Any cached engine whose index prefix lives in ``output_dir`` is dropped
    afterwards, so the next ``/recognize`` reloads the freshly built files
    instead of serving the old ones.

    Returns:
        ``status``, the number of reference windows, the embedding dimension
        (0 for a 1-D embedding array) and the resolved output directory.

    Raises:
        HTTPException: 400 if the resolved data directory or model path does
            not exist. Both are checked before anything is created on disk.
    """
    from run_demo import build_chroma_index, build_embedding_index

    resolved = _resolve(req.data_dir, req.model_path, None, req.device)
    data_dir = Path(resolved["data_dir"])
    model_path = Path(resolved["model_path"])
    # Fail fast with the same 400 style as the engine loader, before creating
    # the output directory or starting a long build.
    if not data_dir.exists():
        raise HTTPException(status_code=400, detail=f"Missing data dir: {data_dir}")
    if not model_path.exists():
        raise HTTPException(status_code=400, detail=f"Missing model: {model_path}")

    out_dir = Path(req.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    build_chroma_index(data_dir, out_dir)
    _, ref_embs, ref_ids = build_embedding_index(data_dir, model_path, out_dir / "reference", resolved["device"])

    # Cache keys hold the resolved index prefix (third element). Engines whose
    # index lives in out_dir were built from files that were just overwritten.
    out_root = out_dir.resolve()
    with _cache_lock:
        stale = [key for key in _engine_cache if Path(key[2]).parent == out_root]
        for key in stale:
            del _engine_cache[key]

    return {
        "status": "ok",
        "num_reference_windows": len(ref_ids),
        "embedding_dim": int(ref_embs.shape[1]) if len(ref_embs.shape) > 1 else 0,
        "output_dir": str(out_root),
    }