# app.py (13.8 KB)
from __future__ import annotations

from pathlib import Path
from tempfile import TemporaryDirectory
from threading import Lock
from typing import Optional

import faiss
import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from pydantic import BaseModel

from src.data.voice_chunker import voice_to_chunks
from src.service.settings import ServiceSettings
from src.utils.context_exporter import export_match_context, find_best_matching_window
from scripts.local_music20_acr import REFERENCE_TYPE, SUPPORTED_QUERY_TYPES, embed_chroma, first_file


class RecognizeRequest(BaseModel):
    """Request body for ``POST /recognize``.

    The optional path/device fields override the service settings for this
    one request; when they are left unset the configured defaults are used.
    """

    # Path of the audio file to identify; the handler rejects the request
    # with a 400 when this path does not exist.
    query_path: str
    # Directory that is searched for the catalog/train/val/test manifests.
    data_dir: Optional[str] = None
    # Path of the embedding model file.
    model_path: Optional[str] = None
    # Prefix of the index files: '<prefix>_embs.npy' and '<prefix>_ids.npy',
    # with 'chromaprint.pkl' expected next to them.
    index_prefix: Optional[str] = None
    # Forwarded to the engine's recognize() call as top_n.
    top_n: int = 5
    # Device string handed to the embedder.
    device: Optional[str] = None


class BuildIndexRequest(BaseModel):
    """Request body for ``POST /index/build``.

    ``data_dir``, ``model_path`` and ``device`` fall back to the service
    settings when they are left unset.
    """

    # Directory handed to the index builders as their data source.
    data_dir: Optional[str] = None
    # Path of the embedding model used to build the embedding index.
    model_path: Optional[str] = None
    # Destination directory; the handler creates it (with parents) if needed.
    output_dir: str
    # Device string handed to the embedding index builder.
    device: Optional[str] = None


# FastAPI application object; the route decorators below register handlers on it.
app = FastAPI(title='ACR Service', version='0.5.0')
# Service-wide defaults used whenever a request omits a path or device.
settings = ServiceSettings()
# Loaded engines keyed by (data_dir, model_path, index_prefix, device);
# the three path components are stored in resolved form.
_engine_cache: dict[tuple[str, str, str, str], object] = {}
# Guards every read and write of _engine_cache.
_cache_lock = Lock()


def _resolve(req_data_dir=None, req_model_path=None, req_index_prefix=None, req_device=None):
    return {
        'data_dir': req_data_dir or settings.data_dir,
        'model_path': req_model_path or settings.model_path,
        'index_prefix': req_index_prefix or settings.index_prefix,
        'device': req_device or settings.device,
    }


def _readiness_snapshot(data_dir: str, model_path: str, index_prefix: str) -> dict:
    chroma_path = str(Path(index_prefix).parent / 'chromaprint.pkl')
    embs_path = f'{index_prefix}_embs.npy'
    ids_path = f'{index_prefix}_ids.npy'
    manifest_candidates = [
        str((Path(data_dir) / split).resolve())
        for split in ['catalog.json', 'train.json', 'val.json', 'test.json']
        if (Path(data_dir) / split).exists()
    ]
    files = {
        'data_dir': {'path': str(Path(data_dir).resolve()), 'exists': Path(data_dir).exists()},
        'model': {'path': str(Path(model_path).resolve()), 'exists': Path(model_path).exists()},
        'chromaprint_index': {'path': str(Path(chroma_path).resolve()), 'exists': Path(chroma_path).exists()},
        'embedding_index': {'path': str(Path(embs_path).resolve()), 'exists': Path(embs_path).exists()},
        'id_index': {'path': str(Path(ids_path).resolve()), 'exists': Path(ids_path).exists()},
    }
    return {'ready': all(item['exists'] for item in files.values()), 'files': files, 'manifests': manifest_candidates}


def _load_engine_uncached(data_dir: str, model_path: str, index_prefix: str, device: str):
    """Build a fresh ``HybridEngine`` from the on-disk model and index files.

    All required files are checked before anything is loaded, so a request
    with a missing file fails before the matcher or embedder is constructed.

    Args:
        data_dir: Directory searched for the manifest files whose metadata
            is loaded into the engine.
        model_path: Path of the embedder's model file.
        index_prefix: Prefix of the ``_embs.npy`` / ``_ids.npy`` files;
            ``chromaprint.pkl`` is expected in the prefix's parent directory.
        device: Device string handed to the embedder.

    Returns:
        The constructed engine with every available manifest loaded.

    Raises:
        HTTPException: 500 when the engine modules cannot be imported,
            400 when a required model or index file is missing.
    """
    try:
        from src.engines.chromaprint_matcher import ChromaprintMatcher
        from src.engines.ecapa_embedder import ECAPAEmbedder
        from src.engines.hybrid_engine import HybridEngine
    except Exception as exc:
        # Kept broad on purpose: any failure while importing the engine stack
        # is reported the same way. Chain the cause so the traceback survives.
        raise HTTPException(status_code=500, detail=f'Engine dependencies unavailable: {exc}') from exc

    chroma_path = str(Path(index_prefix).parent / 'chromaprint.pkl')
    embs_path = f'{index_prefix}_embs.npy'
    ids_path = f'{index_prefix}_ids.npy'

    # Validate every input up front, in the same order as before, so the
    # first missing file still determines the error message.
    if not Path(chroma_path).exists():
        raise HTTPException(status_code=400, detail=f'Missing chromaprint index: {chroma_path}')
    if not Path(model_path).exists():
        raise HTTPException(status_code=400, detail=f'Missing model: {model_path}')
    if not Path(embs_path).exists() or not Path(ids_path).exists():
        raise HTTPException(status_code=400, detail='Missing embedding index files')

    matcher = ChromaprintMatcher()
    matcher.load(chroma_path)
    embedder = ECAPAEmbedder(model_path=model_path, device=device)
    ref_embs = np.load(embs_path)
    # NOTE(security): allow_pickle=True unpickles the ids file, and
    # index_prefix may come from a request body. Only point this service at
    # index files from a trusted source.
    ref_ids = np.load(ids_path, allow_pickle=True).tolist()

    engine = HybridEngine(matcher, embedder, ref_embs, ref_ids)
    for manifest_name in ('catalog.json', 'train.json', 'val.json', 'test.json'):
        manifest = Path(data_dir) / manifest_name
        if manifest.exists():
            engine.load_metadata(str(manifest))
    return engine


def _load_engine(data_dir: str, model_path: str, index_prefix: str, device: str):
    """Return the engine for this configuration, loading it on a cache miss.

    The cache key is the resolved data dir, model path and index prefix plus
    the device string. Loading happens outside the lock so that a slow load
    does not block lookups for other configurations.

    Returns:
        A ``(engine, cache_hit)`` tuple; ``cache_hit`` is True only when the
        engine was already cached at lookup time.

    Raises:
        HTTPException: Propagated from ``_load_engine_uncached``.
    """
    key = (
        str(Path(data_dir).resolve()),
        str(Path(model_path).resolve()),
        str(Path(index_prefix).resolve()),
        device,
    )
    with _cache_lock:
        cached = _engine_cache.get(key)
    if cached is not None:
        return cached, True

    engine = _load_engine_uncached(data_dir, model_path, index_prefix, device)
    with _cache_lock:
        # Another request may have loaded the same configuration while this
        # one was loading. Keep whichever engine was stored first so every
        # caller shares one instance instead of overwriting the cached one.
        engine = _engine_cache.setdefault(key, engine)
    return engine, False


def _cache_stats() -> dict:
    """Return the number of cached engines and a copy of their cache keys."""
    # Copy the keys while holding the lock; build the report after releasing it.
    with _cache_lock:
        snapshot = [*_engine_cache]
    return {'engine_cache_size': len(snapshot), 'cache_keys': snapshot}


def _aggregate_chunk_results(chunk_results: list[dict], top_n: int) -> list[dict]:
    by_song: dict[str, dict] = {}
    for chunk in chunk_results:
        for cand in chunk.get('candidates', []):
            song_id = cand['song_id']
            entry = by_song.setdefault(song_id, {
                'song_id': song_id,
                'best_confidence': -1.0,
                'match_count': 0,
                'best_chunk': None,
                'best_candidate': None,
            })
            entry['match_count'] += 1
            if cand['confidence'] > entry['best_confidence']:
                entry['best_confidence'] = cand['confidence']
                entry['best_chunk'] = chunk
                entry['best_candidate'] = cand
    ranked = []
    for entry in by_song.values():
        combined = float(entry['best_confidence']) + 0.05 * float(entry['match_count'])
        ranked.append({
            'song_id': entry['song_id'],
            'combined_confidence': round(combined, 4),
            'best_confidence': round(float(entry['best_confidence']), 4),
            'match_count': entry['match_count'],
            'best_chunk': entry['best_chunk'],
            'best_candidate': entry['best_candidate'],
        })
    ranked.sort(key=lambda x: x['combined_confidence'], reverse=True)
    return ranked[:top_n]


def _reference_audio_for_song(engine, song_id: str) -> str | None:
    return getattr(engine, 'song_audio_paths', {}).get(song_id)


def _workspace_reference_map(downloads_dir: Path, song_limit: int = 20) -> list[dict]:
    refs = []
    for song_dir in sorted(p for p in downloads_dir.iterdir() if p.is_dir()):
        ref = first_file(song_dir / f'type_{REFERENCE_TYPE}')
        if ref:
            refs.append({'song_id': song_dir.name, 'reference_path': str(ref)})
        if len(refs) >= song_limit:
            break
    return refs


def _workspace_faiss_candidates(query_audio_path: str, downloads_dir: Path, song_limit: int, sr: int, duration: float, top_n: int) -> list[dict]:
    refs = _workspace_reference_map(downloads_dir, song_limit)
    if not refs:
        return []
    ref_vecs = [embed_chroma(item['reference_path'], sr, duration) for item in refs]
    qry_vec = embed_chroma(query_audio_path, sr, duration).reshape(1, -1).astype(np.float32)
    ref_matrix = np.vstack(ref_vecs).astype(np.float32)
    index = faiss.IndexFlatIP(ref_matrix.shape[1])
    index.add(ref_matrix)
    sims, idxs = index.search(qry_vec, top_n)
    results = []
    for j in range(top_n):
        ref_idx = int(idxs[0, j])
        results.append({
            'song_id': refs[ref_idx]['song_id'],
            'confidence': float(sims[0, j]),
            'reference_path': refs[ref_idx]['reference_path'],
        })
    return results


@app.get('/health')
def health():
    """Liveness probe: always reports ``ok`` plus the current readiness flag."""
    defaults = _resolve()
    snapshot = _readiness_snapshot(defaults['data_dir'], defaults['model_path'], defaults['index_prefix'])
    payload = {'status': 'ok', 'service': 'acr', 'version': '0.5.0'}
    payload['ready'] = snapshot['ready']
    return payload


@app.get('/ready')
def ready():
    """Readiness probe: per-file readiness details plus engine cache stats."""
    defaults = _resolve()
    snapshot = _readiness_snapshot(defaults['data_dir'], defaults['model_path'], defaults['index_prefix'])
    payload = {'service': 'acr', 'version': '0.5.0'}
    # Later updates win on key clashes, matching dict-unpacking order.
    payload.update(snapshot)
    payload.update(_cache_stats())
    return payload


@app.get('/config')
def config():
    """Expose the service settings as a plain dict."""
    current_settings = settings.model_dump()
    return current_settings


@app.get('/cache')
def cache_status():
    """Report how many engines are cached and under which keys."""
    stats = _cache_stats()
    return stats


@app.post('/recognize')
def recognize(req: RecognizeRequest):
    """Identify the audio file named in the request.

    Returns the engine's result together with the effective configuration
    and whether the engine came from the cache.

    Raises:
        HTTPException: 400 when ``query_path`` does not exist; errors from
            engine loading are propagated unchanged.
    """
    resolved = _resolve(req.data_dir, req.model_path, req.index_prefix, req.device)
    query_file = Path(req.query_path)
    if not query_file.exists():
        raise HTTPException(status_code=400, detail=f'Missing query file: {req.query_path}')
    engine, cache_hit = _load_engine(**resolved)
    return {
        'cache_hit': cache_hit,
        'resolved': resolved,
        'result': engine.recognize(req.query_path, top_n=req.top_n),
    }


@app.post('/index/build')
def build_index(req: BuildIndexRequest):
    """Build the chromaprint and embedding indexes into ``req.output_dir``.

    The output directory is created if needed. The embedding index is built
    with the prefix ``<output_dir>/reference``.

    Returns:
        A summary with the number of reference windows, the embedding
        dimension (0 when the embedding array is not at least 2-D) and the
        resolved output directory.
    """
    # Imported lazily so the builders are only needed when this route is hit.
    from run_demo import build_chroma_index, build_embedding_index

    resolved = _resolve(req.data_dir, req.model_path, None, req.device)
    source_dir = Path(resolved['data_dir'])
    target_dir = Path(req.output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    build_chroma_index(source_dir, target_dir)
    _, ref_embs, ref_ids = build_embedding_index(
        source_dir,
        Path(resolved['model_path']),
        target_dir / 'reference',
        resolved['device'],
    )
    # NOTE(review): engines already cached for this output location are not
    # evicted here -- confirm whether a rebuild should invalidate them.
    window_count = len(ref_ids)
    if len(ref_embs.shape) > 1:
        embedding_dim = int(ref_embs.shape[1])
    else:
        embedding_dim = 0
    return {
        'status': 'ok',
        'num_reference_windows': window_count,
        'embedding_dim': embedding_dim,
        'output_dir': str(target_dir.resolve()),
    }


def _build_voice_candidates(
    ranked_with_refs: list[tuple[dict, Optional[str]]],
    contexts_dir: Path,
    include_context: bool,
    context_sec: float,
    output_format: str,
) -> list[dict]:
    """Turn ranked songs into response candidates, exporting context clips on request.

    Args:
        ranked_with_refs: ``(ranked item, reference audio path or None)``
            pairs, in ranking order.
        contexts_dir: Directory the context clips are written under.
        include_context: When False no clip is exported for any candidate.
        context_sec: Forwarded to ``export_match_context``.
        output_format: File extension and format of the exported clip.

    Returns:
        One response dict per ranked item, in the same order.
    """
    response_candidates = []
    for item, ref_audio in ranked_with_refs:
        song_id = item['song_id']
        best_chunk = item['best_chunk']
        context_info = None
        # A clip needs both a reference recording and the chunk that matched it.
        if include_context and ref_audio and best_chunk is not None:
            match = find_best_matching_window(
                query_audio_path=best_chunk['chunk']['audio_path'],
                reference_audio_path=ref_audio,
            )
            out_path = contexts_dir / f'{song_id}.{output_format}'
            context_info = export_match_context(
                audio_path=ref_audio,
                window_start_sec=match['window_start_sec'],
                window_end_sec=match['window_end_sec'],
                output_path=str(out_path),
                context_sec=context_sec,
                output_format=output_format,
            )
            context_info['match'] = match
        response_candidates.append({
            'song_id': song_id,
            'combined_confidence': item['combined_confidence'],
            'best_confidence': item['best_confidence'],
            'match_count': item['match_count'],
            'reference_audio_path': ref_audio,
            'best_candidate': item['best_candidate'],
            'best_chunk': best_chunk['chunk'] if best_chunk else None,
            'context_clip': context_info,
        })
    return response_candidates


@app.post('/recognize/voice')
async def recognize_voice(
    file: UploadFile = File(...),
    top_n: int = 5,
    data_dir: Optional[str] = None,
    model_path: Optional[str] = None,
    index_prefix: Optional[str] = None,
    device: Optional[str] = None,
    context_sec: float = 10.0,
    output_format: str = 'mp3',
    max_chunks: int = 3,
    include_context: bool = True,
    corpus: str = 'synthetic',
    downloads_dir: str = '/workspace/downloads',
    song_limit: int = 20,
    local_duration_sec: float = 8.0,
    local_sr: int = 22050,
):
    """Recognize an uploaded voice recording, chunk by chunk.

    The upload is written to a temporary directory and split with
    ``voice_to_chunks``. Each chunk is matched either against the workspace
    reference recordings (``corpus == 'workspace_music20'``) or with the
    cached engine (any other ``corpus`` value). Per-chunk candidates are
    merged with ``_aggregate_chunk_results``, and a context clip is exported
    for each ranked song when ``include_context`` is set.

    Returns:
        A dict with ``cache_hit``, ``corpus``, ``query_audio_filename``,
        ``chunk_count``, ``chunk_results`` and ``candidates``; the engine
        path additionally includes ``resolved``.

    Raises:
        HTTPException: 400 when ``output_format`` is not a plain
            alphanumeric extension (only checked when context clips are
            requested), when the workspace ``downloads_dir`` is not a
            directory, or when no voiced chunks are detected. Errors from
            engine loading are propagated unchanged.
    """
    use_workspace = corpus == 'workspace_music20'

    # output_format is caller-supplied and ends up inside a file name, so
    # refuse anything that could carry path separators or dots.
    if include_context and not output_format.isalnum():
        raise HTTPException(status_code=400, detail=f'Unsupported output format: {output_format}')
    # Report a missing reference directory as a client error rather than
    # letting the directory scan fail with an unhandled exception.
    if use_workspace and not Path(downloads_dir).is_dir():
        raise HTTPException(status_code=400, detail=f'Missing downloads directory: {downloads_dir}')

    # NOTE(review): everything under the temporary directory is deleted when
    # this block exits, so any path inside it that appears in the response
    # no longer exists once the client receives it -- confirm this is intended.
    with TemporaryDirectory(prefix='acr_voice_') as tmpdir:
        tmp = Path(tmpdir)
        suffix = Path(file.filename or 'upload.wav').suffix or '.wav'
        raw_path = tmp / f'input{suffix}'
        raw_path.write_bytes(await file.read())
        chunk_dir = tmp / 'chunks'
        # NOTE(review): chunking and recognition are called synchronously
        # inside this coroutine; if they are slow they hold the event loop.
        chunks = voice_to_chunks(str(raw_path), str(chunk_dir), max_chunks=max_chunks)
        if not chunks:
            raise HTTPException(status_code=400, detail='No voiced chunks detected from input audio')

        contexts_dir = tmp / 'contexts'
        chunk_results = []

        if use_workspace:
            for chunk in chunks:
                candidates = _workspace_faiss_candidates(chunk['audio_path'], Path(downloads_dir), song_limit, local_sr, local_duration_sec, top_n)
                chunk_results.append({'chunk': chunk, 'candidates': candidates, 'processing_time_ms': None})
            ranked = _aggregate_chunk_results(chunk_results, top_n=top_n)
            # Workspace candidates carry their own reference path.
            ranked_with_refs = [
                (item, item['best_candidate']['reference_path'] if item.get('best_candidate') else None)
                for item in ranked
            ]
            response_candidates = _build_voice_candidates(ranked_with_refs, contexts_dir, include_context, context_sec, output_format)
            return {
                'cache_hit': False,
                'corpus': corpus,
                'query_audio_filename': file.filename,
                'chunk_count': len(chunks),
                'chunk_results': chunk_results,
                'candidates': response_candidates,
            }

        resolved = _resolve(data_dir, model_path, index_prefix, device)
        engine, cache_hit = _load_engine(**resolved)
        for chunk in chunks:
            result = engine.recognize(chunk['audio_path'], top_n=top_n)
            chunk_results.append({'chunk': chunk, 'candidates': result['candidates'], 'processing_time_ms': result['processing_time_ms']})
        ranked = _aggregate_chunk_results(chunk_results, top_n=top_n)
        # Engine candidates are mapped to reference audio through the engine.
        ranked_with_refs = [(item, _reference_audio_for_song(engine, item['song_id'])) for item in ranked]
        response_candidates = _build_voice_candidates(ranked_with_refs, contexts_dir, include_context, context_sec, output_format)
        return {
            'cache_hit': cache_hit,
            'resolved': resolved,
            'corpus': corpus,
            'query_audio_filename': file.filename,
            'chunk_count': len(chunks),
            'chunk_results': chunk_results,
            'candidates': response_candidates,
        }