# app.py (3.44 KB)
from pathlib import Path
from typing import Optional

import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from src.engines.chromaprint_matcher import ChromaprintMatcher
from src.engines.ecapa_embedder import ECAPAEmbedder
from src.engines.hybrid_engine import HybridEngine
from src.service.settings import ServiceSettings


class RecognizeRequest(BaseModel):
    # Request body for POST /recognize. Optional fields left as None (or any
    # falsy value) fall back to the matching ServiceSettings value in _resolve().
    query_path: str  # path to the query file on the server's filesystem
    data_dir: Optional[str] = None  # dir scanned for catalog/train/val/test JSON metadata
    model_path: Optional[str] = None  # model file handed to ECAPAEmbedder
    index_prefix: Optional[str] = None  # reads {prefix}_embs.npy / {prefix}_ids.npy
    top_n: int = 5  # forwarded as engine.recognize(..., top_n=...)
    device: Optional[str] = None  # device string forwarded to ECAPAEmbedder


class BuildIndexRequest(BaseModel):
    # Request body for POST /index/build. Optional fields left as None (or any
    # falsy value) fall back to the matching ServiceSettings value in _resolve().
    data_dir: Optional[str] = None  # source data directory passed to the run_demo builders
    model_path: Optional[str] = None  # model used by build_embedding_index
    output_dir: str  # created if absent; the embedding index uses prefix <output_dir>/reference
    device: Optional[str] = None  # device string forwarded to build_embedding_index


# ASGI application that every route below registers on; its version string is
# also reported by the /health endpoint.
app = FastAPI(version="0.2.0", title="ACR Service")
# Service defaults consulted whenever a request leaves a field unset.
settings = ServiceSettings()


def _resolve(req_data_dir=None, req_model_path=None, req_index_prefix=None, req_device=None):
    """Merge per-request overrides with the service-wide settings.

    Each request value wins when it is truthy. Otherwise (None, empty string,
    ...) the attribute of the same name on ``settings`` is used.

    Returns:
        dict with keys ``data_dir``, ``model_path``, ``index_prefix`` and
        ``device``, in that order, suitable for ``_load_engine(**result)``.
    """
    overrides = {
        "data_dir": req_data_dir,
        "model_path": req_model_path,
        "index_prefix": req_index_prefix,
        "device": req_device,
    }
    # `or` short-circuits, so settings are only consulted for falsy overrides.
    return {key: value or getattr(settings, key) for key, value in overrides.items()}


def _load_engine(data_dir: str, model_path: str, index_prefix: str, device: str) -> HybridEngine:
    """Assemble a ``HybridEngine`` from on-disk artifacts.

    Every required path is validated before anything is loaded. A request
    with a bad path therefore fails fast instead of first paying for loading
    the chromaprint index and the ECAPA model.

    Args:
        data_dir: Directory scanned for optional metadata files
            (``catalog.json``, ``train.json``, ``val.json``, ``test.json``).
        model_path: Model file passed to ``ECAPAEmbedder``.
        index_prefix: Embedding-index prefix. ``{prefix}_embs.npy`` and
            ``{prefix}_ids.npy`` are read, and ``chromaprint.pkl`` is expected
            in the prefix's parent directory.
        device: Device string forwarded to ``ECAPAEmbedder``.

    Returns:
        A ``HybridEngine`` holding the matcher, the embedder, the reference
        embeddings and ids, plus whichever metadata files exist in ``data_dir``.

    Raises:
        HTTPException: 400 when a required setting is unset or a required
            artifact is missing.
    """
    # Values come from the request or ServiceSettings, and either may be unset.
    # Without this guard, Path(None) would surface as an opaque 500 TypeError.
    required = (("data_dir", data_dir), ("model_path", model_path), ("index_prefix", index_prefix))
    for name, value in required:
        if value is None:
            raise HTTPException(status_code=400, detail=f"Missing configuration: {name}")

    chroma_path = str(Path(index_prefix).parent / "chromaprint.pkl")
    embs_path = f"{index_prefix}_embs.npy"
    ids_path = f"{index_prefix}_ids.npy"

    # Checks run in the original order, so the first reported problem is unchanged.
    if not Path(chroma_path).exists():
        raise HTTPException(status_code=400, detail=f"Missing chromaprint index: {chroma_path}")
    if not Path(model_path).exists():
        raise HTTPException(status_code=400, detail=f"Missing model: {model_path}")
    missing = [p for p in (embs_path, ids_path) if not Path(p).exists()]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Missing embedding index files: {', '.join(missing)}",
        )

    # NOTE(review): the ".pkl" suffix suggests ChromaprintMatcher.load unpickles
    # the file. Confirm this, and if so, only load indexes from trusted locations.
    matcher = ChromaprintMatcher()
    matcher.load(chroma_path)
    embedder = ECAPAEmbedder(model_path=model_path, device=device)

    ref_embs = np.load(embs_path)
    # SECURITY(review): allow_pickle=True can execute code embedded in the file,
    # and index_prefix may come straight from the request body. Restrict index
    # locations to trusted directories, or store ids in a non-object dtype so
    # allow_pickle can be dropped.
    ref_ids = np.load(ids_path, allow_pickle=True).tolist()
    engine = HybridEngine(matcher, embedder, ref_embs, ref_ids)

    # Metadata files are optional; load whichever ones are present.
    for split in ("catalog.json", "train.json", "val.json", "test.json"):
        p = Path(data_dir) / split
        if p.exists():
            engine.load_metadata(str(p))
    return engine


@app.get("/health")
def health():
    """Report that the service is up, along with its name and version.

    No dependencies are checked; the response is static apart from the
    version. The version is read from ``app.version``, so it cannot drift
    from the version declared on the FastAPI application.
    """
    return {"status": "ok", "service": "acr", "version": app.version}


@app.get("/config")
def config():
    """Return the active ``ServiceSettings`` serialized with ``model_dump()``.

    NOTE(review): every settings field is returned verbatim. Confirm that
    ServiceSettings holds nothing sensitive before exposing this route.
    """
    snapshot = settings.model_dump()
    return snapshot


@app.post("/recognize")
def recognize(req: RecognizeRequest):
    """Run the hybrid engine on a query file stored on the server.

    Request fields left unset fall back to ``ServiceSettings`` via
    ``_resolve``. The engine is rebuilt from disk by ``_load_engine`` on
    every call.

    Raises:
        HTTPException: 400 if ``top_n`` is below 1, if ``query_path`` is not a
            regular file, or if ``_load_engine`` rejects the resolved settings.
    """
    # A non-positive top_n has no meaningful result; reject it up front.
    if req.top_n < 1:
        raise HTTPException(status_code=400, detail=f"top_n must be >= 1, got {req.top_n}")
    # NOTE(review): query_path is an arbitrary server-side path chosen by the
    # client. Restrict it to an allowed directory if this service is exposed.
    # is_file() rather than exists(): a directory would pass exists() and only
    # fail later inside the engine.
    if not Path(req.query_path).is_file():
        raise HTTPException(status_code=400, detail=f"Missing query file: {req.query_path}")
    resolved = _resolve(req.data_dir, req.model_path, req.index_prefix, req.device)
    engine = _load_engine(**resolved)
    return engine.recognize(req.query_path, top_n=req.top_n)


@app.post("/index/build")
def build_index(req: BuildIndexRequest):
    """Build the chromaprint and embedding indexes into ``req.output_dir``.

    The data directory and model are validated before ``output_dir`` is
    created, so an invalid request leaves nothing behind. The embedding index
    is written with the prefix ``<output_dir>/reference``. Presumably this is
    the ``index_prefix`` a later /recognize call should pass to use the new
    index; confirm against ``run_demo``.

    Returns:
        dict with ``status``, ``num_reference_windows`` (the number of
        reference ids) and ``embedding_dim`` (0 when the embeddings are not 2-D).

    Raises:
        HTTPException: 400 when the data directory or the model is missing.
    """
    from run_demo import build_chroma_index, build_embedding_index

    resolved = _resolve(req.data_dir, req.model_path, None, req.device)
    raw_data_dir = resolved["data_dir"]
    raw_model_path = resolved["model_path"]
    # Validate before mkdir/building. Path(None) would otherwise raise an
    # opaque TypeError, and a bad data_dir would only fail deep in run_demo.
    if raw_data_dir is None or not Path(raw_data_dir).is_dir():
        raise HTTPException(status_code=400, detail=f"Missing data directory: {raw_data_dir}")
    if raw_model_path is None or not Path(raw_model_path).exists():
        raise HTTPException(status_code=400, detail=f"Missing model: {raw_model_path}")

    data_dir = Path(raw_data_dir)
    out_dir = Path(req.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    build_chroma_index(data_dir, out_dir)
    _, ref_embs, ref_ids = build_embedding_index(
        data_dir, Path(raw_model_path), out_dir / "reference", resolved["device"]
    )
    embedding_dim = int(ref_embs.shape[1]) if len(ref_embs.shape) > 1 else 0
    return {"status": "ok", "num_reference_windows": len(ref_ids), "embedding_dim": embedding_dim}