app.py
2.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from pathlib import Path
from typing import Optional
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from src.engines.chromaprint_matcher import ChromaprintMatcher
from src.engines.ecapa_embedder import ECAPAEmbedder
from src.engines.hybrid_engine import HybridEngine
class RecognizeRequest(BaseModel):
    """Request body for ``POST /recognize``.

    Only ``query_path`` is required; every other field defaults to the
    project's standard data, model, and index locations.
    """

    # Audio file to identify; the endpoint rejects it with 400 if it does not exist.
    query_path: str
    # Directory searched for catalog/train/val/test JSON metadata files.
    data_dir: str = "data/synthetic_v2"
    # Checkpoint handed to ECAPAEmbedder(model_path=...).
    model_path: str = "data/models_v3/best_model.pt"
    # Prefix of the "<prefix>_embs.npy" / "<prefix>_ids.npy" embedding index.
    index_prefix: str = "data/index_v3/reference"
    # Forwarded to HybridEngine.recognize as top_n.
    top_n: int = 5
    # Device string forwarded to the embedder.
    device: str = "cpu"
class BuildIndexRequest(BaseModel):
    """Request body for ``POST /index/build``."""

    # Dataset directory both reference indexes are built from.
    data_dir: str
    # Model checkpoint used by the embedding-index builder.
    model_path: str
    # Destination for the index files; created (with parents) if absent.
    output_dir: str
    # Device string forwarded to the embedding-index builder.
    device: str = "cpu"
# Module-level application object; route handlers register on it via decorators.
app = FastAPI(
    title="ACR Service",
    version="0.1.0",
)
def _load_engine(data_dir: str, model_path: str, index_prefix: str, device: str) -> HybridEngine:
    """Assemble a HybridEngine from on-disk indexes and a model checkpoint.

    Every required file is checked for existence before anything is loaded,
    so a request that is going to be rejected never pays for loading the
    Chromaprint index or constructing the ECAPA embedder first.

    Args:
        data_dir: Directory searched for ``catalog.json``, ``train.json``,
            ``val.json`` and ``test.json``; each one present is passed to
            ``HybridEngine.load_metadata``, absent ones are skipped.
        model_path: Checkpoint passed to ``ECAPAEmbedder``.
        index_prefix: Embedding-index prefix; ``<prefix>_embs.npy`` and
            ``<prefix>_ids.npy`` are read, and ``chromaprint.pkl`` is read
            from the prefix's parent directory.
        device: Device string forwarded to ``ECAPAEmbedder``.

    Returns:
        The constructed ``HybridEngine`` with metadata loaded.

    Raises:
        HTTPException: 400 if the Chromaprint index, the model, or any
            embedding index file is missing (checked in that order).
    """
    chroma_path = str(Path(index_prefix).parent / "chromaprint.pkl")
    embs_path = f"{index_prefix}_embs.npy"
    ids_path = f"{index_prefix}_ids.npy"

    # Validate all inputs up front, in the same priority order as the loads,
    # so no expensive load happens for a request that will fail anyway.
    if not Path(chroma_path).exists():
        raise HTTPException(status_code=400, detail=f"Missing chromaprint index: {chroma_path}")
    if not Path(model_path).exists():
        raise HTTPException(status_code=400, detail=f"Missing model: {model_path}")
    missing = [p for p in (embs_path, ids_path) if not Path(p).exists()]
    if missing:
        # Name the missing file(s), like the two checks above do.
        raise HTTPException(
            status_code=400,
            detail=f"Missing embedding index files: {', '.join(missing)}",
        )

    # NOTE(security): callers pass request-supplied paths here. The ids file is
    # loaded with allow_pickle=True and the ".pkl" Chromaprint index presumably
    # via pickle too; unpickling a client-chosen file can execute arbitrary
    # code. Restrict these paths to a trusted index root before exposing this
    # service to untrusted clients.
    matcher = ChromaprintMatcher()
    matcher.load(chroma_path)
    embedder = ECAPAEmbedder(model_path=model_path, device=device)
    ref_embs = np.load(embs_path)
    ref_ids = np.load(ids_path, allow_pickle=True).tolist()

    engine = HybridEngine(matcher, embedder, ref_embs, ref_ids)
    for split in ("catalog.json", "train.json", "val.json", "test.json"):
        meta_path = Path(data_dir) / split
        if meta_path.exists():
            engine.load_metadata(str(meta_path))
    return engine
@app.get("/health")
def health():
    """Return a constant status payload; no dependencies are checked."""
    payload = {"status": "ok"}
    return payload
@app.post("/recognize")
def recognize(req: RecognizeRequest):
    """Identify the audio file at ``req.query_path``.

    Returns whatever ``HybridEngine.recognize`` produces for the query,
    limited by ``req.top_n``.

    Raises:
        HTTPException: 400 if the query file does not exist; ``_load_engine``
            also raises 400 when one of its index or model files is missing.
    """
    query = Path(req.query_path)
    if query.exists():
        # NOTE(review): the model and indexes are reloaded from disk on every
        # request. Consider caching the engine if latency matters, keeping in
        # mind that /index/build may rewrite the index files underneath it.
        engine = _load_engine(req.data_dir, req.model_path, req.index_prefix, req.device)
        return engine.recognize(req.query_path, top_n=req.top_n)
    raise HTTPException(status_code=400, detail=f"Missing query file: {req.query_path}")
@app.post("/index/build")
def build_index(req: BuildIndexRequest):
    """Build the Chromaprint and embedding reference indexes into ``output_dir``.

    The embedding index is written under the prefix ``<output_dir>/reference``.

    Returns:
        ``{"status": "ok", "num_reference_windows": <number of reference ids>,
        "embedding_dim": <second dimension of the embeddings>}``, where
        ``embedding_dim`` is 0 when the embeddings array is not at least 2-D.

    Raises:
        HTTPException: 400 if ``req.data_dir`` is not an existing directory
            or ``req.model_path`` does not exist. Both are checked before
            ``output_dir`` is created, so a bad request leaves nothing behind.
    """
    # Local import: only this endpoint uses run_demo.
    from run_demo import build_chroma_index, build_embedding_index

    data_dir = Path(req.data_dir)
    model_path = Path(req.model_path)
    out_dir = Path(req.output_dir)

    # Reject bad inputs with a client error, like _load_engine does. Otherwise
    # out_dir would be created and the builders would presumably fail with an
    # unhandled exception (HTTP 500).
    if not data_dir.is_dir():
        raise HTTPException(status_code=400, detail=f"Missing data directory: {req.data_dir}")
    if not model_path.exists():
        raise HTTPException(status_code=400, detail=f"Missing model: {req.model_path}")

    out_dir.mkdir(parents=True, exist_ok=True)
    build_chroma_index(data_dir, out_dir)
    _, ref_embs, ref_ids = build_embedding_index(data_dir, model_path, out_dir / "reference", req.device)

    embedding_dim = int(ref_embs.shape[1]) if len(ref_embs.shape) > 1 else 0
    return {"status": "ok", "num_reference_windows": len(ref_ids), "embedding_dim": embedding_dim}