# app.py 10.6 KB
from __future__ import annotations

from pathlib import Path
from tempfile import TemporaryDirectory
from threading import Lock
from typing import Optional

import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from pydantic import BaseModel

from src.data.voice_chunker import voice_to_chunks
from src.engines.chromaprint_matcher import ChromaprintMatcher
from src.engines.ecapa_embedder import ECAPAEmbedder
from src.engines.hybrid_engine import HybridEngine
from src.service.settings import ServiceSettings
from src.utils.context_exporter import export_match_context, find_best_matching_window


class RecognizeRequest(BaseModel):
    """JSON body accepted by ``POST /recognize``.

    Optional fields that are omitted (or otherwise falsy) are replaced by the
    matching ``ServiceSettings`` value when the request is resolved.
    """

    # Path of the audio file to identify; the endpoint rejects the request
    # with HTTP 400 when this path does not exist.
    query_path: str
    # Directory scanned for catalog/train/val/test manifest JSON files.
    data_dir: Optional[str] = None
    # Model file handed to the ECAPA embedder.
    model_path: Optional[str] = None
    # Prefix of the embedding index: '<prefix>_embs.npy' and '<prefix>_ids.npy'
    # are loaded, and 'chromaprint.pkl' is looked up in the prefix's parent dir.
    index_prefix: Optional[str] = None
    # Forwarded to the engine's recognize() call as top_n.
    top_n: int = 5
    # Device string forwarded to the ECAPA embedder.
    device: Optional[str] = None


class BuildIndexRequest(BaseModel):
    """JSON body accepted by ``POST /index/build``.

    Optional fields that are omitted (or otherwise falsy) are replaced by the
    matching ``ServiceSettings`` value when the request is resolved.
    """

    # Source data directory passed to both index builders.
    data_dir: Optional[str] = None
    # Model file passed to the embedding index builder.
    model_path: Optional[str] = None
    # Directory that receives the build output; created (with parents) if it
    # does not exist. The embedding builder is given '<output_dir>/reference'.
    output_dir: str
    # Device string passed to the embedding index builder.
    device: Optional[str] = None


# FastAPI application object; the route decorators below register endpoints on it.
app = FastAPI(title='ACR Service', version='0.4.0')
# Service-wide defaults that _resolve() falls back to when a request omits a value.
settings = ServiceSettings()
# Loaded engines keyed by (resolved data_dir, resolved model_path,
# resolved index_prefix, device), as built in _load_engine().
_engine_cache: dict[tuple[str, str, str, str], HybridEngine] = {}
# Held around every read and write of _engine_cache.
_cache_lock = Lock()


def _resolve(req_data_dir=None, req_model_path=None, req_index_prefix=None, req_device=None):
    """Merge per-request overrides with the service-wide defaults.

    Each argument is used when it is truthy; otherwise the attribute of the
    same name on the module-level ``settings`` object is used instead.

    Returns:
        A dict with the keys ``data_dir``, ``model_path``, ``index_prefix``
        and ``device``, in that order.
    """
    data_dir = req_data_dir if req_data_dir else settings.data_dir
    model_path = req_model_path if req_model_path else settings.model_path
    index_prefix = req_index_prefix if req_index_prefix else settings.index_prefix
    device = req_device if req_device else settings.device
    return dict(
        data_dir=data_dir,
        model_path=model_path,
        index_prefix=index_prefix,
        device=device,
    )


def _readiness_snapshot(data_dir: str, model_path: str, index_prefix: str) -> dict:
    """Report which on-disk artifacts the service needs are currently present.

    Args:
        data_dir: Directory expected to hold the manifest JSON files.
        model_path: Path of the model file.
        index_prefix: Embedding index prefix; ``chromaprint.pkl`` is expected
            in the prefix's parent directory.

    Returns:
        A dict with ``ready`` (True only when every required path exists),
        ``files`` (per-artifact resolved ``path`` and ``exists`` flag) and
        ``manifests`` (resolved paths of the manifest files that exist).
    """
    data_root = Path(data_dir)

    # Required artifacts, in the order they are reported to the caller.
    required = {
        'data_dir': data_root,
        'model': Path(model_path),
        'chromaprint_index': Path(index_prefix).parent / 'chromaprint.pkl',
        'embedding_index': Path(f'{index_prefix}_embs.npy'),
        'id_index': Path(f'{index_prefix}_ids.npy'),
    }
    files = {}
    for label, location in required.items():
        files[label] = {'path': str(location.resolve()), 'exists': location.exists()}

    # Manifests are optional: they are listed but do not affect readiness.
    manifests = []
    for manifest_name in ('catalog.json', 'train.json', 'val.json', 'test.json'):
        manifest = data_root / manifest_name
        if manifest.exists():
            manifests.append(str(manifest.resolve()))

    is_ready = all(entry['exists'] for entry in files.values())
    return {'ready': is_ready, 'files': files, 'manifests': manifests}


def _load_engine_uncached(data_dir: str, model_path: str, index_prefix: str, device: str) -> HybridEngine:
    """Build a fresh ``HybridEngine`` from on-disk artifacts, bypassing the cache.

    Every required file is checked before anything is loaded, so a missing
    artifact is reported without first paying for a chromaprint or model load.
    The checks run in the order chromaprint index, model, embedding index.

    Args:
        data_dir: Directory scanned for manifest JSON files to load as metadata.
        model_path: Model file handed to ``ECAPAEmbedder``.
        index_prefix: Prefix of the ``_embs.npy`` / ``_ids.npy`` index files;
            ``chromaprint.pkl`` is read from the prefix's parent directory.
        device: Device string handed to ``ECAPAEmbedder``.

    Returns:
        The assembled engine, with metadata loaded from each manifest found.

    Raises:
        HTTPException: Status 400 when the chromaprint index, the model file,
            or either embedding index file does not exist.
    """
    chroma_path = Path(index_prefix).parent / 'chromaprint.pkl'
    embs_path = Path(f'{index_prefix}_embs.npy')
    ids_path = Path(f'{index_prefix}_ids.npy')

    # Validate everything up front; loading only starts once all inputs exist.
    if not chroma_path.exists():
        raise HTTPException(status_code=400, detail=f'Missing chromaprint index: {chroma_path}')
    if not Path(model_path).exists():
        raise HTTPException(status_code=400, detail=f'Missing model: {model_path}')
    if not (embs_path.exists() and ids_path.exists()):
        raise HTTPException(status_code=400, detail='Missing embedding index files')

    matcher = ChromaprintMatcher()
    matcher.load(str(chroma_path))
    embedder = ECAPAEmbedder(model_path=model_path, device=device)

    ref_embs = np.load(str(embs_path))
    # NOTE(security): allow_pickle=True makes NumPy unpickle object arrays,
    # which can execute arbitrary code. index_prefix can be supplied in the
    # request, so this must only ever point at trusted index files.
    ref_ids = np.load(str(ids_path), allow_pickle=True).tolist()

    engine = HybridEngine(matcher, embedder, ref_embs, ref_ids)
    for split in ('catalog.json', 'train.json', 'val.json', 'test.json'):
        manifest = Path(data_dir) / split
        if manifest.exists():
            engine.load_metadata(str(manifest))
    return engine


def _load_engine(data_dir: str, model_path: str, index_prefix: str, device: str) -> tuple[HybridEngine, bool]:
    """Return the engine for this configuration, loading it on a cache miss.

    The cache key is built from the resolved (absolute) forms of the three
    paths plus the device string.

    Returns:
        ``(engine, cache_hit)`` where ``cache_hit`` is True when the engine
        came from ``_engine_cache`` and False when it was just loaded.
    """
    cache_key = (
        str(Path(data_dir).resolve()),
        str(Path(model_path).resolve()),
        str(Path(index_prefix).resolve()),
        device,
    )

    with _cache_lock:
        engine = _engine_cache.get(cache_key)

    if engine is None:
        # The load runs without holding the lock, so concurrent misses for the
        # same key each load an engine and the last one stored stays cached.
        engine = _load_engine_uncached(data_dir, model_path, index_prefix, device)
        with _cache_lock:
            _engine_cache[cache_key] = engine
        return engine, False

    return engine, True


def _cache_stats() -> dict:
    """Summarise the engine cache: entry count plus the list of cache keys."""
    # Copy the keys while holding the lock; build the response afterwards.
    with _cache_lock:
        cached_keys = [*_engine_cache]
    return {'engine_cache_size': len(cached_keys), 'cache_keys': cached_keys}


def _aggregate_chunk_results(chunk_results: list[dict], top_n: int) -> list[dict]:
    """Collapse per-chunk candidate lists into one ranked entry per song.

    For every song the highest candidate confidence is kept, along with the
    chunk result and candidate that produced it, and the number of times the
    song appeared across all chunks. The ranking score is
    ``best_confidence + 0.05 * match_count``.

    Args:
        chunk_results: Dicts that may carry a ``candidates`` list; each
            candidate needs ``song_id`` and ``confidence``.
        top_n: Number of leading entries to keep after sorting.

    Returns:
        Entries sorted by ``combined_confidence`` (descending), cut to
        ``top_n``. Confidence values are rounded to four decimals.
    """
    per_song: dict[str, dict] = {}
    for chunk_result in chunk_results:
        for candidate in chunk_result.get('candidates', []):
            sid = candidate['song_id']
            if sid not in per_song:
                per_song[sid] = {
                    'song_id': sid,
                    'best_confidence': -1.0,
                    'match_count': 0,
                    'best_chunk': None,
                    'best_candidate': None,
                }
            stats = per_song[sid]
            stats['match_count'] += 1
            # Strictly greater: on a tie the earliest chunk stays the best one.
            if candidate['confidence'] > stats['best_confidence']:
                stats.update(
                    best_confidence=candidate['confidence'],
                    best_chunk=chunk_result,
                    best_candidate=candidate,
                )

    def _summarise(stats: dict) -> dict:
        # Score and round a single song's accumulated statistics.
        best = float(stats['best_confidence'])
        score = best + 0.05 * float(stats['match_count'])
        return {
            'song_id': stats['song_id'],
            'combined_confidence': round(score, 4),
            'best_confidence': round(best, 4),
            'match_count': stats['match_count'],
            'best_chunk': stats['best_chunk'],
            'best_candidate': stats['best_candidate'],
        }

    summaries = sorted(
        (_summarise(stats) for stats in per_song.values()),
        key=lambda row: row['combined_confidence'],
        reverse=True,
    )
    return summaries[:top_n]


def _reference_audio_for_song(engine: HybridEngine, song_id: str) -> str | None:
    """Look up a song's reference audio path in ``engine.song_audio_paths``.

    Returns:
        The stored value for ``song_id``, or ``None`` when there is none.
    """
    audio_paths = engine.song_audio_paths
    return audio_paths.get(song_id)


@app.get('/health')
def health():
    """Liveness endpoint: always reports ``status: ok`` plus the readiness flag.

    Readiness is computed from the default settings only (no overrides).
    """
    defaults = _resolve()
    snapshot = _readiness_snapshot(
        defaults['data_dir'],
        defaults['model_path'],
        defaults['index_prefix'],
    )
    return dict(status='ok', service='acr', version='0.4.0', ready=snapshot['ready'])


@app.get('/ready')
def ready():
    """Readiness endpoint: the full artifact snapshot plus engine cache statistics.

    The snapshot is computed from the default settings only (no overrides).
    """
    defaults = _resolve()
    payload = {'service': 'acr', 'version': '0.4.0'}
    payload.update(
        _readiness_snapshot(defaults['data_dir'], defaults['model_path'], defaults['index_prefix'])
    )
    payload.update(_cache_stats())
    return payload


@app.get('/config')
def config():
    """Return the service settings as produced by ``settings.model_dump()``."""
    current_settings = settings.model_dump()
    return current_settings


@app.get('/cache')
def cache_status():
    """Return the engine cache size and its keys."""
    stats = _cache_stats()
    return stats


@app.post('/recognize')
def recognize(req: RecognizeRequest):
    """Identify the audio file at ``req.query_path``.

    Args:
        req: Query path plus optional overrides for the data directory, model,
            index prefix, result count and device.

    Returns:
        A dict with ``cache_hit`` (whether the engine came from the cache),
        ``resolved`` (the effective configuration) and ``result`` (whatever
        the engine's ``recognize`` call returned).

    Raises:
        HTTPException: Status 400 when ``query_path`` is not an existing file,
            or when a required index or model file is missing.
    """
    # is_file() rather than exists(): a directory is not a usable query, so it
    # is rejected here as a client error instead of failing inside the engine.
    if not Path(req.query_path).is_file():
        raise HTTPException(status_code=400, detail=f'Missing query file: {req.query_path}')

    resolved = _resolve(req.data_dir, req.model_path, req.index_prefix, req.device)
    engine, cache_hit = _load_engine(**resolved)
    result = engine.recognize(req.query_path, top_n=req.top_n)
    return {'cache_hit': cache_hit, 'resolved': resolved, 'result': result}


@app.post('/index/build')
def build_index(req: BuildIndexRequest):
    """Build the chromaprint and embedding indexes into ``req.output_dir``.

    After a successful build, cached engines whose index prefix lives in the
    output directory are evicted, so the next request reloads from the rebuilt
    files instead of serving an engine built from the previous ones.

    Args:
        req: Output directory plus optional overrides for the data directory,
            model and device.

    Returns:
        A dict with ``status``, ``num_reference_windows`` (number of reference
        ids), ``embedding_dim`` (second dimension of the embedding array, or 0
        when it has fewer than two dimensions) and the resolved ``output_dir``.
    """
    # Imported lazily, as in the original, so the module loads without run_demo.
    from run_demo import build_chroma_index, build_embedding_index

    resolved = _resolve(req.data_dir, req.model_path, None, req.device)
    data_dir = Path(resolved['data_dir'])
    out_dir = Path(req.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    build_chroma_index(data_dir, out_dir)
    _, ref_embs, ref_ids = build_embedding_index(
        data_dir,
        Path(resolved['model_path']),
        out_dir / 'reference',
        resolved['device'],
    )

    # Cache keys hold the resolved index prefix at position 2 (see
    # _load_engine); any prefix inside out_dir now refers to replaced files.
    rebuilt_dir = out_dir.resolve()
    with _cache_lock:
        stale_keys = [key for key in _engine_cache if Path(key[2]).parent == rebuilt_dir]
        for key in stale_keys:
            del _engine_cache[key]

    return {
        'status': 'ok',
        'num_reference_windows': len(ref_ids),
        'embedding_dim': int(ref_embs.shape[1]) if len(ref_embs.shape) > 1 else 0,
        'output_dir': str(rebuilt_dir),
    }


@app.post('/recognize/voice')
async def recognize_voice(
    file: UploadFile = File(...),
    top_n: int = 5,
    data_dir: Optional[str] = None,
    model_path: Optional[str] = None,
    index_prefix: Optional[str] = None,
    device: Optional[str] = None,
    context_sec: float = 10.0,
    output_format: str = 'mp3',
):
    """Identify songs in an uploaded recording and export a context clip per match.

    The upload is written to a temporary directory, split into chunks with
    ``voice_to_chunks``, and every chunk is run through the engine. Per-chunk
    candidates are merged by ``_aggregate_chunk_results``; for each ranked
    song with a known reference audio path, the best matching window is
    located and a context clip is exported.

    Args:
        file: Uploaded audio; its filename suffix is reused for the temp file
            (``.wav`` when there is none).
        top_n: Candidates requested per chunk and kept after aggregation;
            must be at least 1.
        data_dir: Optional override for the data directory.
        model_path: Optional override for the model file.
        index_prefix: Optional override for the index prefix.
        device: Optional override for the device.
        context_sec: Forwarded to ``export_match_context``.
        output_format: Clip file extension and export format; must be ASCII
            letters/digits only because it becomes part of a file name.

    Returns:
        A dict with ``cache_hit``, ``resolved``, ``query_audio_filename``,
        ``chunk_count``, ``chunk_results`` and ``candidates``.

    Raises:
        HTTPException: Status 400 for an invalid ``top_n`` or
            ``output_format``, for missing index/model files, or when no
            chunks are produced from the upload.
    """
    # A negative top_n would make the final slice silently drop candidates.
    if top_n < 1:
        raise HTTPException(status_code=400, detail='top_n must be >= 1')
    # output_format is interpolated into the clip file name below; limiting it
    # to ASCII alphanumerics keeps path separators and '..' out of that path.
    if not (output_format.isascii() and output_format.isalnum()):
        raise HTTPException(status_code=400, detail=f'Invalid output_format: {output_format!r}')

    resolved = _resolve(data_dir, model_path, index_prefix, device)
    engine, cache_hit = _load_engine(**resolved)

    # NOTE(review): everything under tmpdir is deleted when this block exits,
    # so any temp-dir path included in the response (chunk audio paths, and
    # presumably the exported clip path) no longer exists once the client
    # receives it. Confirm whether clips should be persisted elsewhere.
    # NOTE(review): the chunking/recognition calls below are synchronous and
    # run inside this async handler.
    with TemporaryDirectory(prefix='acr_voice_') as tmpdir:
        tmp = Path(tmpdir)
        suffix = Path(file.filename or 'upload.wav').suffix or '.wav'
        raw_path = tmp / f'input{suffix}'
        raw_path.write_bytes(await file.read())

        chunk_dir = tmp / 'chunks'
        chunks = voice_to_chunks(str(raw_path), str(chunk_dir))
        if not chunks:
            raise HTTPException(status_code=400, detail='No voiced chunks detected from input audio')

        chunk_results = []
        for chunk in chunks:
            result = engine.recognize(chunk['audio_path'], top_n=top_n)
            chunk_results.append({
                'chunk': chunk,
                'candidates': result['candidates'],
                'processing_time_ms': result['processing_time_ms'],
            })

        ranked = _aggregate_chunk_results(chunk_results, top_n=top_n)

        # Create the clip directory up front so the export has somewhere to write.
        context_dir = tmp / 'contexts'
        context_dir.mkdir(parents=True, exist_ok=True)

        response_candidates = []
        for item in ranked:
            song_id = item['song_id']
            ref_audio = _reference_audio_for_song(engine, song_id)
            context_info = None
            if ref_audio and item['best_chunk'] is not None:
                match = find_best_matching_window(
                    query_audio_path=item['best_chunk']['chunk']['audio_path'],
                    reference_audio_path=ref_audio,
                )
                out_path = context_dir / f'{song_id}.{output_format}'
                context_info = export_match_context(
                    audio_path=ref_audio,
                    window_start_sec=match['window_start_sec'],
                    window_end_sec=match['window_end_sec'],
                    output_path=str(out_path),
                    context_sec=context_sec,
                    output_format=output_format,
                )
                context_info['match'] = match

            response_candidates.append({
                'song_id': song_id,
                'combined_confidence': item['combined_confidence'],
                'best_confidence': item['best_confidence'],
                'match_count': item['match_count'],
                'reference_audio_path': ref_audio,
                'best_candidate': item['best_candidate'],
                'best_chunk': item['best_chunk']['chunk'] if item['best_chunk'] else None,
                'context_clip': context_info,
            })

        return {
            'cache_hit': cache_hit,
            'resolved': resolved,
            'query_audio_filename': file.filename,
            'chunk_count': len(chunks),
            'chunk_results': chunk_results,
            'candidates': response_candidates,
        }