# app.py
"""FastAPI application for lyric duplicate checking."""

from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from .config import ServerConfig
from .service import DedupService

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# App lifecycle
# ---------------------------------------------------------------------------

app = FastAPI(title="Lyric Dedup API", version="0.1.0")

# Module-level singletons assigned by the ``_startup`` hook. They remain
# ``None`` until startup has run: ``check_lyric`` answers 503 while
# ``_service`` is None, and ``_download_lyrics`` falls back to a 10 s timeout
# while ``_config`` is None.
_config: ServerConfig | None = None
_service: DedupService | None = None


def _redact_dsn(dsn: object) -> str:
    """Return ``str(dsn)`` with any embedded password replaced by ``***``.

    Handles URL-style DSNs (``scheme://user:password@host/db``) and
    ``password=...`` key/value pairs (libpq-style strings or URL query
    parameters), so the value can be written to logs without credentials.
    """
    import re
    from urllib.parse import urlsplit

    text = str(dsn)
    try:
        parts = urlsplit(text)
        password = parts.password
    except ValueError:  # e.g. a malformed IPv6 host; rely on the regex below
        parts, password = None, None
    if parts is not None and password is not None:
        # Rebuild the netloc as "user:***@host[:port]" (userinfo ends at the last '@').
        _, _, host_part = parts.netloc.rpartition("@")
        user = parts.username or ""
        text = parts._replace(netloc=f"{user}:***@{host_part}").geturl()
    # Mask "password=<value>" pairs; the value may be single-quoted (libpq) or
    # terminated by whitespace, '&' (URL query) or ';'.
    return re.sub(
        r"(password\s*=\s*)('(?:[^'\\]|\\.)*'|[^\s&;]+)",
        r"\1***",
        text,
        flags=re.IGNORECASE,
    )


@app.on_event("startup")
def _startup() -> None:
    """Create the shared ``ServerConfig`` / ``DedupService`` and log the settings.

    The DSN is logged through ``_redact_dsn`` so database credentials never end
    up in the application log.

    NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    favour of a ``lifespan`` handler passed to ``FastAPI(...)``.
    """
    global _config, _service
    _config = ServerConfig()
    _service = DedupService(config=_config)
    logger.info(
        "Lyric Dedup API started (DSN=%s, trgm=%s)",
        _redact_dsn(_config.dsn),
        _config.enable_trgm,
    )


# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------


# Request body for ``POST /api/v1/check`` (consumed by ``check_lyric``).
# Comments are used instead of a docstring so the generated OpenAPI schema
# stays unchanged.
class CheckRequest(BaseModel):
    # Lyric file location; ``check_lyric`` rejects it with 400 unless
    # ``_is_valid_lyric_url`` accepts its path extension.
    url: str = Field(..., description="URL of the LRC/TXT lyric file")
    # Optional metadata forwarded unchanged to ``DedupService.check``.
    title: str | None = Field(None, description="Song title (optional)")
    artist: str | None = Field(None, description="Artist name (optional)")


# Response body for ``POST /api/v1/check``; ``check_lyric`` copies each field
# from the result object returned by ``DedupService.check``.
class CheckResponse(BaseModel):
    # Duplicate verdict, copied from ``result.duplicate``.
    duplicate: bool
    # NOTE(review): the possible values/meaning of decision, confidence and
    # reason are defined by DedupService.check, not here — confirm there.
    decision: str | None = None
    confidence: float | None = None
    reason: str | None = None
    # Presumably IDs of the matching stored records — verify against
    # DedupService. A literal ``[]`` default is safe on a pydantic model
    # because pydantic copies defaults per instance.
    record_ids: list[str] = []


# Response body for ``GET /health``; ``health`` always returns status "ok".
class HealthResponse(BaseModel):
    status: str


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

@app.get("/health", response_model=HealthResponse)
def health() -> dict[str, str]:
    """Liveness probe: unconditionally reports ``{"status": "ok"}``."""
    payload: dict[str, str] = {"status": "ok"}
    return payload


@app.post("/api/v1/check", response_model=CheckResponse)
def check_lyric(req: CheckRequest) -> Any:
    """Download the lyric file at ``req.url`` and check it for duplicates.

    Responses:
        200: ``CheckResponse`` built from ``DedupService.check``'s result.
        400: the URL does not point to a ``.txt``/``.lrc`` file, or
            ``_download_lyrics`` raised ``ValueError`` (its message is shown).
        500: an unexpected error; the exception is logged server-side only and
            the client gets a generic message (no internal details leaked).
        503: the startup hook has not initialised the service yet.
    """
    if _service is None:
        return JSONResponse(
            status_code=503,
            content={"detail": "service not initialized"},
        )

    # Validate the file format (only .txt / .lrc are accepted).
    if not _is_valid_lyric_url(req.url):
        return JSONResponse(
            status_code=400,
            content={"detail": "仅支持 .txt 或 .lrc 格式的歌词文件"},
        )

    try:
        lyrics = _download_lyrics(req.url)
    except ValueError as exc:
        # Client-side problem (bad URL, HTTP error, undecodable file): echo it.
        return JSONResponse(
            status_code=400,
            content={"detail": str(exc)},
        )
    except Exception:
        # Internal failure: keep the traceback in the log, not in the response.
        logger.exception("unexpected error during download")
        return JSONResponse(
            status_code=500,
            content={"detail": "下载歌词失败"},
        )

    try:
        result = _service.check(lyrics, title=req.title, artist=req.artist, source_url=req.url)
    except Exception:
        # Service/DB errors can carry SQL, hostnames or credentials; log only.
        logger.exception("unexpected error during dedup check")
        return JSONResponse(
            status_code=500,
            content={"detail": "歌词去重检测失败"},
        )

    return CheckResponse(
        duplicate=result.duplicate,
        decision=result.decision,
        confidence=result.confidence,
        reason=result.reason,
        record_ids=result.record_ids,
    )


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Encodings tried in order by ``_download_lyrics``; the first one that decodes
# the bytes without error is used. "utf-8-sig" strips a leading BOM and
# otherwise decodes exactly like "utf-8", so the plain "utf-8" entry can never
# succeed where "utf-8-sig" failed.
# NOTE(review): gb18030 accepts most double-byte sequences, so Big5 files may
# decode (as mojibake) under gb18030 before "big5" is ever tried — confirm.
_ENCODING_CHAIN = ("utf-8-sig", "utf-8", "gb18030", "big5")


# Lower-case file extensions (with leading dot) accepted by _is_valid_lyric_url.
_ALLOWED_EXTENSIONS = {".txt", ".lrc"}


def _is_valid_lyric_url(url: str) -> bool:
    """Tell whether the path component of *url* ends in an allowed extension.

    Query string and fragment are ignored, and the comparison is
    case-insensitive (``song.LRC`` is accepted).
    """
    from urllib.parse import urlparse

    url_path = urlparse(url).path
    extension = Path(url_path).suffix.lower()
    is_allowed = extension in _ALLOWED_EXTENSIONS
    return is_allowed


def _download_lyrics(url: str) -> str:
    """Download a lyric file and decode with encoding fallback chain."""
    import urllib.error
    import urllib.request

    try:
        with urllib.request.urlopen(url, timeout=_config.download_timeout if _config else 10) as resp:
            data = resp.read()
    except urllib.error.HTTPError as exc:
        raise ValueError(f"下载失败: HTTP {exc.code}") from exc
    except urllib.error.URLError as exc:
        raise ValueError(f"下载失败: {exc.reason}") from exc
    except TimeoutError as exc:
        raise ValueError("下载超时") from exc
    except Exception as exc:
        raise ValueError(f"下载失败: {exc}") from exc

    for encoding in _ENCODING_CHAIN:
        try:
            return data.decode(encoding)
        except UnicodeDecodeError:
            continue
    raise ValueError("无法解析文件编码,支持: utf-8-sig / utf-8 / gb18030 / big5")