app.py
3.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from pathlib import Path
from typing import Optional
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from src.engines.chromaprint_matcher import ChromaprintMatcher
from src.engines.ecapa_embedder import ECAPAEmbedder
from src.engines.hybrid_engine import HybridEngine
from src.service.settings import ServiceSettings
class RecognizeRequest(BaseModel):
    """Request body for POST /recognize.

    Optional fields left as None (or empty) fall back to the service settings.
    """
    query_path: str  # path, on the server, of the audio file to identify
    data_dir: Optional[str] = None  # directory scanned for catalog/train/val/test JSON metadata
    model_path: Optional[str] = None  # ECAPA embedder model file
    index_prefix: Optional[str] = None  # prefix of <prefix>_embs.npy / <prefix>_ids.npy; chromaprint.pkl lives in its parent dir
    top_n: int = 5  # number of candidates requested from the engine
    device: Optional[str] = None  # device handed to the ECAPA embedder
class BuildIndexRequest(BaseModel):
    """Request body for POST /index/build.

    Optional fields left as None (or empty) fall back to the service settings.
    """
    data_dir: Optional[str] = None  # dataset directory the indexes are built from
    model_path: Optional[str] = None  # ECAPA embedder model file
    output_dir: str  # created if missing; index files are written here (embedding index under the "reference" prefix)
    device: Optional[str] = None  # device used when computing embeddings
# FastAPI application object. The version string is also hard-coded in health().
app = FastAPI(version="0.2.0", title="ACR Service")
# Service-wide settings; supply the defaults used when a request omits a path/device.
settings = ServiceSettings()
def _resolve(req_data_dir=None, req_model_path=None, req_index_prefix=None, req_device=None):
return {
"data_dir": req_data_dir or settings.data_dir,
"model_path": req_model_path or settings.model_path,
"index_prefix": req_index_prefix or settings.index_prefix,
"device": req_device or settings.device,
}
def _load_engine(data_dir: str, model_path: str, index_prefix: str, device: str) -> HybridEngine:
    """Assemble a HybridEngine from the on-disk index and model artifacts.

    Required artifacts:
      * ``<parent dir of index_prefix>/chromaprint.pkl`` - Chromaprint index,
      * ``model_path`` - ECAPA embedder model,
      * ``<index_prefix>_embs.npy`` and ``<index_prefix>_ids.npy`` - embedding index.
    Any of ``catalog.json``, ``train.json``, ``val.json`` and ``test.json``
    present in ``data_dir`` is additionally loaded as metadata.

    Raises:
        HTTPException: 400 if any required artifact is missing.
    """
    chroma_path = str(Path(index_prefix).parent / "chromaprint.pkl")
    embs_path = f"{index_prefix}_embs.npy"
    ids_path = f"{index_prefix}_ids.npy"

    # Validate every required artifact before loading anything, so a missing
    # file is reported without first paying for index/model loading. Checks
    # keep the original priority: chromaprint, model, embedding files.
    if not Path(chroma_path).exists():
        raise HTTPException(status_code=400, detail=f"Missing chromaprint index: {chroma_path}")
    if not Path(model_path).exists():
        raise HTTPException(status_code=400, detail=f"Missing model: {model_path}")
    missing = [p for p in (embs_path, ids_path) if not Path(p).exists()]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Missing embedding index files: {', '.join(missing)}",
        )

    # SECURITY NOTE(review): index_prefix and model_path may come straight from
    # the request body. The ids array is loaded with allow_pickle=True and the
    # ".pkl" chromaprint index is presumably unpickled by matcher.load; loading
    # pickles from a client-chosen path can execute arbitrary code. Consider
    # restricting these paths to a trusted directory.
    matcher = ChromaprintMatcher()
    matcher.load(chroma_path)
    embedder = ECAPAEmbedder(model_path=model_path, device=device)
    ref_embs = np.load(embs_path)
    ref_ids = np.load(ids_path, allow_pickle=True).tolist()

    engine = HybridEngine(matcher, embedder, ref_embs, ref_ids)
    for split in ("catalog.json", "train.json", "val.json", "test.json"):
        meta_path = Path(data_dir) / split
        if meta_path.exists():
            engine.load_metadata(str(meta_path))
    return engine
@app.get("/health")
def health():
    """Static health check: always reports "ok" with the service name and version."""
    payload = dict(status="ok", service="acr", version="0.2.0")
    return payload
@app.get("/config")
def config():
    """Return the service settings, serialized via ``model_dump()``.

    NOTE(review): every ServiceSettings field is returned to any caller;
    confirm none of them hold secrets.
    """
    current = settings.model_dump()
    return current
@app.post("/recognize")
def recognize(req: RecognizeRequest):
    """Identify the audio file at ``req.query_path`` against the reference index.

    Paths and device omitted from the request fall back to the service
    settings. The engine is loaded from disk on every call.

    Raises:
        HTTPException: 400 if ``top_n`` is not positive, if ``query_path`` is
            not an existing regular file, or if any index/model artifact is
            missing.
    """
    # Reject a meaningless ranking size before any (expensive) engine loading.
    if req.top_n < 1:
        raise HTTPException(status_code=400, detail=f"top_n must be >= 1, got {req.top_n}")
    # is_file() rather than exists(): a directory would pass exists() and
    # presumably only fail later, inside the engine.
    # NOTE(review): query_path is an arbitrary server-side path chosen by the client.
    if not Path(req.query_path).is_file():
        raise HTTPException(status_code=400, detail=f"Missing query file: {req.query_path}")
    resolved = _resolve(req.data_dir, req.model_path, req.index_prefix, req.device)
    # NOTE(review): consider caching the engine; reloading the model per request is costly.
    engine = _load_engine(**resolved)
    return engine.recognize(req.query_path, top_n=req.top_n)
@app.post("/index/build")
def build_index(req: BuildIndexRequest):
    """Build the Chromaprint and embedding indexes into ``req.output_dir``.

    The embedding index is written under the ``<output_dir>/reference``
    prefix. Paths and device omitted from the request fall back to the
    service settings.

    Returns:
        dict with ``status``, ``num_reference_windows`` (number of reference
        ids) and ``embedding_dim`` (second dimension of the embedding matrix,
        or 0 if it is not at least 2-D).

    Raises:
        HTTPException: 400 if the data directory or the model file is missing.
    """
    # Local import: run_demo's builders are only needed by this endpoint.
    from run_demo import build_chroma_index, build_embedding_index

    resolved = _resolve(req.data_dir, req.model_path, None, req.device)
    data_dir = Path(resolved["data_dir"])
    model_path = Path(resolved["model_path"])

    # Fail fast with a client error (matching the 400s of /recognize) instead
    # of creating output_dir and only failing inside the build helpers.
    if not data_dir.is_dir():
        raise HTTPException(status_code=400, detail=f"Missing data dir: {data_dir}")
    if not model_path.exists():
        raise HTTPException(status_code=400, detail=f"Missing model: {model_path}")

    # NOTE(review): output_dir is a client-chosen server path that gets created/written.
    out_dir = Path(req.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    build_chroma_index(data_dir, out_dir)
    _, ref_embs, ref_ids = build_embedding_index(
        data_dir, model_path, out_dir / "reference", resolved["device"]
    )
    embedding_dim = int(ref_embs.shape[1]) if len(ref_embs.shape) > 1 else 0
    return {"status": "ok", "num_reference_windows": len(ref_ids), "embedding_dim": embedding_dim}