# service.py (9.84 KB)
"""Core deduplication service: PostgreSQL recall + DuplicateChecker."""

from __future__ import annotations

import hashlib
import logging
from dataclasses import dataclass, field
from typing import Any

import psycopg

from lyric_dedup.checker import DuplicateChecker
from lyric_dedup.checker import DuplicateDecision
from lyric_dedup.checker import LyricRecord
from lyric_dedup.normalization import fingerprint_text
from lyric_dedup.normalization import normalize_lyrics

from .config import ServerConfig

# Module-level logger named after this module; DedupService uses it as the
# default value of its ``_logger`` field.
logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class CheckResult:
    """Immutable outcome of a single duplicate check.

    Only ``duplicate`` is required; every other field has a neutral default.
    """

    # True when the lyric was flagged (DedupService sets this for both the
    # DUPLICATE and REVIEW decisions).
    duplicate: bool
    # String value of the checker's decision.
    decision: str = ""
    # Confidence reported by the checker.
    confidence: float = 0.0
    # Explanation reported by the checker.
    reason: str = ""
    # Number of candidates attached to the checker's result.
    candidate_count: int = 0

@dataclass
class DedupService:
    """Duplicate-check front end: PostgreSQL candidate recall + DuplicateChecker.

    Attributes:
        config: Server settings. The methods here read ``dsn``,
            ``statement_timeout_ms``, ``trgm_threshold``, ``enable_trgm``,
            ``recall_limit`` and ``max_candidates`` from it.
        _logger: Logger instance; defaults to the module-level ``logger``.
            None of the methods below currently write to it.
    """

    config: ServerConfig
    _logger: logging.Logger = field(default_factory=lambda: logger, repr=False)

    def check(
        self,
        lyrics_text: str,
        title: str | None = None,
        artist: str | None = None,
        source_url: str | None = None,
    ) -> CheckResult:
        """Check one lyric against the database and store it when it is new.

        A fresh connection is opened for every call. Candidates are recalled
        from PostgreSQL, judged by ``DuplicateChecker``, and the lyric is
        inserted only when the decision string is ``"new"`` and a non-empty
        ``source_url`` was supplied.

        Args:
            lyrics_text: Raw lyric text to check.
            title: Optional song title passed on to the checker.
            artist: Optional artist name passed on to the checker.
            source_url: Optional origin URL; required for the lyric to be stored.

        Returns:
            The ``CheckResult`` produced for this lyric.
        """
        query = LyricRecord(
            record_id="__query__",
            lyrics=lyrics_text,
            title=title,
            artist=artist,
        )
        settings = self.config
        with psycopg.connect(settings.dsn) as conn:
            # Configure the session before running any recall query. The third
            # set_config argument (false) makes the setting session-wide rather
            # than local to the current transaction.
            with conn.cursor() as cursor:
                timeout_param = (str(settings.statement_timeout_ms),)
                cursor.execute("select set_config('statement_timeout', %s, false)", timeout_param)
                threshold_param = (str(settings.trgm_threshold),)
                cursor.execute("select set_config('pg_trgm.similarity_threshold', %s, false)", threshold_param)

            recalled = self._recall_candidates(conn, query)
            outcome = self._check_against_candidates(query, recalled)

            is_new = outcome.decision == "new"
            if is_new and source_url:
                self._insert_new_record(conn, query, source_url)
        return outcome

    def _insert_new_record(self, conn: Any, record: LyricRecord, source_url: str) -> None:
        """Upsert a lyric and rebuild its per-line rows, then commit.

        The ``lyrics`` row is keyed by the record id derived from
        ``source_url``; an existing row with the same id is overwritten and
        un-deleted. All ``lyric_lines`` rows for that lyric are removed and
        re-inserted from the freshly normalized text.

        Args:
            conn: Open database connection (cursor/commit interface).
            record: The lyric to persist.
            source_url: Origin URL; stored as ``source_path`` and used to
                build ``record_id``.
        """
        # NUL characters are stripped from every text value before it is bound.
        clean_lyrics = _pg_text(record.lyrics)[0] or ""
        parsed = normalize_lyrics(clean_lyrics)
        primary_block = _pg_text("\n".join(parsed.primary_lines))[0]
        # An empty translation is stored as NULL rather than as "".
        translation_block = _pg_text("\n".join(parsed.translation_lines))[0] or None
        normalized_block = _pg_text(parsed.normalized_full_text)[0]
        fingerprint = fingerprint_text(parsed)
        fingerprint_hash = hashlib.sha256(fingerprint.encode("utf-8")).hexdigest()

        with conn.cursor() as cursor:
            params = {
                "record_id": _build_record_id(source_url),
                "source_path": source_url,
                "title": _pg_text(record.title)[0],
                "artist": _pg_text(record.artist)[0],
                "raw_text": clean_lyrics,
                "normalized_text": normalized_block,
                "primary_text": primary_block,
                "translation_text": translation_block,
                "exact_hash": fingerprint_hash,
                "split_confidence": _pg_text(parsed.split_confidence)[0],
                "split_reason": _pg_text(parsed.split_reason)[0],
                # Falls back to unique_lines when there are no primary lines.
                "line_count": len(parsed.primary_lines or parsed.unique_lines),
            }
            cursor.execute(
                """
                insert into lyrics (
                  record_id, source_path, title, artist, raw_text, normalized_text,
                  primary_text, translation_text, exact_hash, split_confidence,
                  split_reason, line_count, updated_at, deleted_at
                ) values (
                  %(record_id)s, %(source_path)s, %(title)s, %(artist)s, %(raw_text)s,
                  %(normalized_text)s, %(primary_text)s, %(translation_text)s,
                  %(exact_hash)s, %(split_confidence)s, %(split_reason)s,
                  %(line_count)s, now(), null
                )
                on conflict (record_id) do update set
                  source_path = excluded.source_path, title = excluded.title,
                  artist = excluded.artist, raw_text = excluded.raw_text,
                  normalized_text = excluded.normalized_text, primary_text = excluded.primary_text,
                  translation_text = excluded.translation_text, exact_hash = excluded.exact_hash,
                  split_confidence = excluded.split_confidence, split_reason = excluded.split_reason,
                  line_count = excluded.line_count, updated_at = now(), deleted_at = null
                returning id
                """,
                params,
            )
            inserted = cursor.fetchone()
            lyric_id = inserted[0]

            # Replace, rather than merge, any line rows left from a previous version.
            cursor.execute("delete from lyric_lines where lyric_id = %s", (lyric_id,))
            lines_by_role = (
                ("primary", parsed.primary_lines),
                ("translation", parsed.translation_lines),
                ("unknown", parsed.unknown_lines),
            )
            line_rows: list[tuple] = []
            for role, role_lines in lines_by_role:
                line_rows.extend(_line_rows(lyric_id, role, role_lines))
            if line_rows:
                cursor.executemany(
                    "insert into lyric_lines (lyric_id, role, line_no, normalized_line, line_hash) values (%s, %s, %s, %s, %s)",
                    line_rows,
                )
        conn.commit()

    def _recall_candidates(self, conn: Any, record: LyricRecord) -> list[LyricRecord]:
        """Collect candidate records via exact hash, trigram and line-hash lookups.

        Each tier is capped at ``config.recall_limit`` rows. Results are merged
        by ``record_id``; the first tier to return a given id wins.

        Args:
            conn: Open database connection.
            record: The query lyric.

        Returns:
            De-duplicated candidates in first-seen order.
        """
        cfg = self.config
        clean_lyrics = _pg_text(record.lyrics)[0] or ""
        parsed = normalize_lyrics(clean_lyrics)
        fingerprint = fingerprint_text(parsed)
        fingerprint_hash = hashlib.sha256(fingerprint.encode("utf-8")).hexdigest()
        primary_block = "\n".join(parsed.primary_lines)
        # Empty primary lines are skipped; they carry no matching signal.
        primary_line_hashes = [
            hashlib.sha256(text.encode("utf-8")).hexdigest()
            for text in parsed.primary_lines
            if text
        ]

        found: dict[str, LyricRecord] = {}
        # Always empty here, so the "not ... any(...)" filters exclude nothing.
        exclude_record_ids: list[str] = []

        with conn.cursor() as cursor:
            # Tier 1: rows whose exact_hash equals the query fingerprint hash.
            cursor.execute(
                """
                select record_id, raw_text, title, artist
                from lyrics
                where deleted_at is null
                  and exact_hash = %s
                  and not (record_id = any(%s))
                limit %s
                """,
                (fingerprint_hash, exclude_record_ids, cfg.recall_limit),
            )
            _add_rows(found, cursor.fetchall())

            # Tier 2: trigram similarity on primary_text, only when enabled and
            # the query has primary text. "%%" is the escaped "%" operator.
            if cfg.enable_trgm and primary_block:
                cursor.execute(
                    """
                    select record_id, raw_text, title, artist
                    from lyrics
                    where deleted_at is null
                      and not (record_id = any(%s))
                      and primary_text %% %s
                    order by similarity(primary_text, %s) desc
                    limit %s
                    """,
                    (exclude_record_ids, primary_block, primary_block, cfg.recall_limit),
                )
                _add_rows(found, cursor.fetchall())

            # Tier 3: lyrics sharing primary line hashes, most matching rows first.
            if primary_line_hashes:
                cursor.execute(
                    """
                    select l.record_id, l.raw_text, l.title, l.artist
                    from lyric_lines ll
                    join lyrics l on l.id = ll.lyric_id
                    where l.deleted_at is null
                      and not (l.record_id = any(%s))
                      and ll.role = 'primary'
                      and ll.line_hash = any(%s)
                    group by l.id
                    order by count(*) desc
                    limit %s
                    """,
                    (exclude_record_ids, primary_line_hashes, cfg.recall_limit),
                )
                _add_rows(found, cursor.fetchall())

        return [*found.values()]

    def _check_against_candidates(
        self,
        record: LyricRecord,
        candidates: list[LyricRecord],
    ) -> CheckResult:
        """Judge ``record`` against the recalled candidates.

        A new ``DuplicateChecker`` is loaded with the candidates and asked to
        check the record, limited to ``config.max_candidates``.

        Args:
            record: The query lyric.
            candidates: Records recalled from the database.

        Returns:
            A ``CheckResult`` whose ``duplicate`` flag is set for both the
            DUPLICATE and REVIEW decisions.
        """
        checker = DuplicateChecker()
        for known in candidates:
            checker.add_record(known)
        verdict = checker.check_record(record, max_candidates=self.config.max_candidates)

        flagged_decisions = (DuplicateDecision.DUPLICATE, DuplicateDecision.REVIEW)
        return CheckResult(
            duplicate=verdict.decision in flagged_decisions,
            decision=verdict.decision.value,
            confidence=verdict.confidence,
            reason=verdict.reason,
            candidate_count=len(verdict.candidates),
        )


def _add_rows(candidates: dict[str, LyricRecord], rows: list[tuple[object, ...]]) -> None:
    """Merge query rows into ``candidates`` without overwriting existing ids.

    Args:
        candidates: Mapping of record id to record; updated in place.
        rows: ``(record_id, raw_text, title, artist)`` tuples as fetched from
            the database.
    """
    for row in rows:
        record_id, raw_text, title, artist = row
        key = str(record_id)
        entry = LyricRecord(
            record_id=str(record_id),
            lyrics=str(raw_text),
            # NULL title/artist stay None instead of becoming the text "None".
            title=None if title is None else str(title),
            artist=None if artist is None else str(artist),
        )
        # The first row seen for a record id is the one that is kept.
        candidates.setdefault(key, entry)


def _build_record_id(source_url: str) -> str:
    """From URL to record_id, format url:{sha12}:{url}."""
    digest = hashlib.sha1(source_url.encode("utf-8")).hexdigest()[:12]
    return f"url:{digest}:{source_url}"


def _line_rows(lyric_id: int, role: str, lines: tuple[str, ...]) -> list[tuple]:
    """Build ``lyric_lines`` insert tuples for one role.

    Args:
        lyric_id: Id of the owning ``lyrics`` row.
        role: Role label stored with every line.
        lines: Normalized lines in order.

    Returns:
        One ``(lyric_id, role, line_no, line, line_hash)`` tuple per input
        line, where ``line_no`` is the zero-based position and ``line_hash``
        is the SHA-256 hex digest of the NUL-stripped line.
    """
    # Strip NULs first so the stored text and its hash always agree.
    cleaned = (_pg_text(raw)[0] or "" for raw in lines)
    return [
        (lyric_id, role, position, text, hashlib.sha256(text.encode("utf-8")).hexdigest())
        for position, text in enumerate(cleaned)
    ]


def _pg_text(value: str | None) -> tuple[str | None, bool]:
    """Return (text, had_nul)."""
    if value is None:
        return None, False
    if "\x00" not in value:
        return value, False
    return value.replace("\x00", ""), True