# import_library_postgres.py
"""Import normalized lyric library records into PostgreSQL."""

from __future__ import annotations

import argparse
import csv
import hashlib
import sys
from pathlib import Path
from typing import Any


# Make the project root importable: parents[1] of the resolved file is the
# parent of the directory containing this script. It is prepended to sys.path
# (once) so the ``lyric_dedup`` imports below resolve when this file is run
# directly as a script. This must execute before those imports.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from lyric_dedup.file_import import iter_lyric_files
from lyric_dedup.file_import import record_from_file
from lyric_dedup.normalization import fingerprint_text
from lyric_dedup.normalization import normalize_lyrics


def main(argv: list[str] | None = None) -> None:
    """Import lyric files into PostgreSQL, then optionally dedupe and report.

    Files found under ``--lyrics-dir`` (optionally truncated to ``--limit``) are
    upserted in batches of ``--batch-size``; each batch runs inside its own
    ``conn.transaction()`` block. Unless ``--skip-dedup-exact`` is given,
    exact-hash duplicates are then soft-deleted and written to
    ``--duplicate-report``. If ``--line-duplicate-report`` is set, a CSV of
    high line-coverage candidate pairs is written as well. A summary dict is
    printed to stdout at the end.

    Args:
        argv: Arguments to parse instead of the process command line. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, as before.
    """
    parser = argparse.ArgumentParser(description="Import lyric library into PostgreSQL.")
    parser.add_argument("--dsn", required=True)
    parser.add_argument("--lyrics-dir", required=True)
    parser.add_argument("--batch-size", type=int, default=500)
    parser.add_argument("--limit", type=int, default=0)
    parser.add_argument("--skip-dedup-exact", action="store_true", help="Skip exact-hash duplicate soft deletion after import.")
    parser.add_argument("--duplicate-report", default="outputs/results/postgres_exact_duplicates.csv")
    parser.add_argument("--line-duplicate-report", default="", help="Optional CSV report for high line-coverage duplicate candidates.")
    parser.add_argument("--line-coverage-threshold", type=float, default=0.95)
    parser.add_argument("--line-duplicate-limit", type=int, default=10000)
    args = parser.parse_args(argv)
    # A zero batch size would crash range() (step 0) and the modulo in
    # _progress; a negative one would silently import nothing. A negative
    # limit is rejected by PostgreSQL's LIMIT. Fail fast with a usage error.
    if args.batch_size <= 0:
        parser.error("--batch-size must be a positive integer")
    if args.line_duplicate_limit < 0:
        parser.error("--line-duplicate-limit must not be negative")

    psycopg = _import_psycopg()
    lyrics_dir = Path(args.lyrics_dir)
    paths = iter_lyric_files(lyrics_dir)
    if args.limit > 0:
        paths = paths[: args.limit]
    print(f"[pg-import] files: {len(paths)}", file=sys.stderr, flush=True)

    imported = 0
    exact_deleted = 0
    line_reported = 0
    nul_cleaned = 0
    with psycopg.connect(args.dsn) as conn:
        for start in range(0, len(paths), args.batch_size):
            batch = paths[start : start + args.batch_size]
            with conn.transaction():
                with conn.cursor() as cursor:
                    for path in batch:
                        lyric_id, line_rows, cleaned = _upsert_lyric(cursor, path, lyrics_dir)
                        nul_cleaned += cleaned
                        _replace_lyric_lines(cursor, lyric_id, line_rows)
                        imported += 1
            _progress("import", imported, len(paths), step=args.batch_size)
        if not args.skip_dedup_exact:
            exact_deleted = _soft_delete_exact_duplicates(conn, Path(args.duplicate_report))
        if args.line_duplicate_report:
            line_reported = _write_line_duplicate_report(
                conn,
                Path(args.line_duplicate_report),
                threshold=args.line_coverage_threshold,
                limit=args.line_duplicate_limit,
            )
    print(
        {
            "imported": imported,
            "records_with_nul_cleaned": nul_cleaned,
            "exact_duplicates_soft_deleted": exact_deleted,
            "line_duplicate_candidates_reported": line_reported,
        }
    )


def _replace_lyric_lines(cursor: Any, lyric_id: int, line_rows: list[tuple[object, ...]]) -> None:
    """Replace every ``lyric_lines`` row of ``lyric_id`` with ``line_rows``."""
    # Delete first so a re-import never leaves stale lines from an older
    # version of the same record.
    cursor.execute("delete from lyric_lines where lyric_id = %s", (lyric_id,))
    if line_rows:
        cursor.executemany(
            """
            insert into lyric_lines
              (lyric_id, role, line_no, normalized_line, line_hash)
            values (%s, %s, %s, %s, %s)
            """,
            line_rows,
        )


def _upsert_lyric(cursor: Any, path: Path, lyrics_dir: Path) -> tuple[int, list[tuple[object, ...]], int]:
    """Upsert one lyric file into ``lyrics`` and build its ``lyric_lines`` rows.

    The file is loaded with ``record_from_file`` and normalized with
    ``normalize_lyrics``. Every stored text value is passed through
    ``_pg_text`` to strip NUL characters, which PostgreSQL text values cannot
    hold. The row is keyed on ``record_id``: re-importing an existing record
    overwrites its columns, bumps ``updated_at`` and resets ``deleted_at`` to
    NULL, undoing any earlier soft delete.

    Args:
        cursor: Open database cursor used for the upsert.
        path: Lyric file to import.
        lyrics_dir: Library root passed to ``record_from_file`` as ``base_dir``.

    Returns:
        ``(lyric_id, line_rows, cleaned)``:
        - ``lyric_id`` is the ``lyrics.id`` returned by the upsert.
        - ``line_rows`` holds the primary/translation/unknown tuples for
          ``lyric_lines``.
        - ``cleaned`` is ``1`` if NUL characters were stripped from any
          stored text field, else ``0``.
    """
    record = record_from_file(path, base_dir=lyrics_dir)
    stripped_flags: list[bool] = []

    def clean(value: str | None) -> str | None:
        # Track NUL stripping for every field (title, artist, ...), not just
        # the raw lyrics, so the per-record count reflects all cleaned data.
        text, was_stripped = _pg_text(value)
        stripped_flags.append(was_stripped)
        return text

    raw_text = clean(record.lyrics)
    normalized = normalize_lyrics(raw_text)
    exact_text = fingerprint_text(normalized)
    exact_hash = hashlib.sha256(exact_text.encode("utf-8")).hexdigest()
    params = {
        "record_id": record.record_id,
        # as_posix() yields forward slashes on every platform; on POSIX it
        # equals str(path). The exact-duplicate ranking matches '/None_'.
        "source_path": path.as_posix(),
        "title": clean(record.title),
        "artist": clean(record.artist),
        "raw_text": raw_text,
        "normalized_text": clean(normalized.normalized_full_text),
        "primary_text": clean("\n".join(normalized.primary_lines)),
        # An empty translation is stored as NULL rather than "".
        "translation_text": clean("\n".join(normalized.translation_lines)) or None,
        "exact_hash": exact_hash,
        "split_confidence": clean(normalized.split_confidence),
        "split_reason": clean(normalized.split_reason),
        # Fall back to unique_lines when no primary lines were split out.
        "line_count": len(normalized.primary_lines or normalized.unique_lines),
    }
    cursor.execute(
        """
        insert into lyrics (
          record_id, source_path, title, artist, raw_text, normalized_text,
          primary_text, translation_text, exact_hash, split_confidence,
          split_reason, line_count, updated_at, deleted_at
        )
        values (
          %(record_id)s, %(source_path)s, %(title)s, %(artist)s, %(raw_text)s,
          %(normalized_text)s, %(primary_text)s, %(translation_text)s,
          %(exact_hash)s, %(split_confidence)s, %(split_reason)s,
          %(line_count)s, now(), null
        )
        on conflict (record_id) do update set
          source_path = excluded.source_path,
          title = excluded.title,
          artist = excluded.artist,
          raw_text = excluded.raw_text,
          normalized_text = excluded.normalized_text,
          primary_text = excluded.primary_text,
          translation_text = excluded.translation_text,
          exact_hash = excluded.exact_hash,
          split_confidence = excluded.split_confidence,
          split_reason = excluded.split_reason,
          line_count = excluded.line_count,
          updated_at = now(),
          deleted_at = null
        returning id
        """,
        params,
    )
    lyric_id = cursor.fetchone()[0]
    line_rows: list[tuple[object, ...]] = []
    line_rows.extend(_line_rows(lyric_id, "primary", normalized.primary_lines))
    line_rows.extend(_line_rows(lyric_id, "translation", normalized.translation_lines))
    line_rows.extend(_line_rows(lyric_id, "unknown", normalized.unknown_lines))
    return lyric_id, line_rows, int(any(stripped_flags))


def _line_rows(lyric_id: int, role: str, lines: tuple[str, ...]) -> list[tuple[object, ...]]:
    rows: list[tuple[object, ...]] = []
    for index, line in enumerate(lines):
        line = _pg_text(line)[0] or ""
        line_hash = hashlib.sha256(line.encode("utf-8")).hexdigest()
        rows.append((lyric_id, role, index, line, line_hash))
    return rows


def _pg_text(value: str | None) -> tuple[str | None, bool]:
    if value is None:
        return None, False
    if "\x00" not in value:
        return value, False
    return value.replace("\x00", ""), True


def _soft_delete_exact_duplicates(conn: Any, report_path: Path) -> int:
    """Soft-delete all but one live row per ``exact_hash`` and report them.

    Within each hash group of rows whose ``deleted_at`` is NULL, the kept row
    is chosen by, in order:
    1. source paths that do not contain ``/None_``;
    2. higher ``line_count``;
    3. longer ``primary_text``;
    4. lowest ``id``.
    Every other row in the group gets ``deleted_at`` and ``updated_at`` set to
    ``now()``. Rows with ``line_count = 0`` are excluded: records without
    any lyric lines presumably share the fingerprint of empty text, and two
    empty songs are not duplicates of each other.

    Args:
        conn: Open database connection. The update runs in its own
            ``conn.transaction()`` block.
        report_path: CSV destination, written with one row per soft-deleted
            duplicate and the row it was kept in favour of.

    Returns:
        Number of rows soft-deleted.
    """
    print("[pg-import] deduplicate exact_hash duplicates", file=sys.stderr, flush=True)
    with conn.transaction():
        with conn.cursor() as cursor:
            # position() matches '/None_' literally; under LIKE the '_' would be
            # a single-character wildcard and also match e.g. '/Nonesuch'.
            # Both window functions share one named WINDOW so their ordering
            # cannot drift apart.
            cursor.execute(
                """
                with ranked as (
                  select
                    id,
                    exact_hash,
                    first_value(id) over keep_order as kept_id,
                    row_number() over keep_order as rn
                  from lyrics
                  where deleted_at is null
                    and line_count > 0
                  window keep_order as (
                    partition by exact_hash
                    order by
                      case when position('/None_' in source_path) > 0 then 1 else 0 end,
                      line_count desc,
                      length(primary_text) desc,
                      id
                  )
                ),
                to_delete as (
                  select id, exact_hash, kept_id
                  from ranked
                  where rn > 1
                ),
                updated as (
                  update lyrics l
                  set deleted_at = now(), updated_at = now()
                  from to_delete d
                  where l.id = d.id
                  returning
                    l.id as duplicate_id,
                    l.record_id as duplicate_record_id,
                    l.source_path as duplicate_source_path,
                    d.exact_hash,
                    d.kept_id
                )
                select
                  u.duplicate_id,
                  u.duplicate_record_id,
                  u.duplicate_source_path,
                  k.id as kept_id,
                  k.record_id as kept_record_id,
                  k.source_path as kept_source_path,
                  u.exact_hash
                from updated u
                join lyrics k on k.id = u.kept_id
                order by u.exact_hash, u.duplicate_id
                """
            )
            rows = cursor.fetchall()
    _write_rows(
        report_path,
        [
            "duplicate_id",
            "duplicate_record_id",
            "duplicate_source_path",
            "kept_id",
            "kept_record_id",
            "kept_source_path",
            "exact_hash",
        ],
        rows,
    )
    print(f"[pg-import] exact duplicates soft-deleted: {len(rows)}", file=sys.stderr, flush=True)
    return len(rows)


def _write_line_duplicate_report(conn: Any, report_path: Path, *, threshold: float, limit: int) -> int:
    """Write a CSV of live lyric pairs whose primary lines largely coincide.

    Lines are compared by ``line_hash`` among ``role = 'primary'`` rows of
    lyrics with ``deleted_at`` NULL. Each lyric contributes each distinct line
    hash once. A line repeated within a song (e.g. a chorus) therefore counts
    once instead of multiplying in the self-join. Without this,
    ``matched_lines`` could exceed either song's length and ``line_coverage``
    could exceed 1, flagging unrelated songs that share a refrain.

    ``line_coverage = matched_lines / max(left_line_count, right_line_count)``,
    where the reported line counts are the distinct primary-line counts used
    in that ratio.

    Args:
        conn: Open database connection.
        report_path: CSV destination (parent directories are created).
        threshold: Minimum ``line_coverage`` for a pair to be reported.
        limit: Maximum number of pairs written, highest coverage first.

    Returns:
        Number of pairs written.
    """
    print("[pg-import] report high line-coverage duplicate candidates", file=sys.stderr, flush=True)
    with conn.cursor() as cursor:
        # primary_lines: distinct (lyric, line_hash) among live lyrics.
        # line_totals: distinct primary lines per lyric (the coverage base).
        # pairs: shared distinct lines for each unordered pair (left_id < right_id).
        # The final ORDER BY ends with both ids so ties are reported deterministically.
        cursor.execute(
            """
            with primary_lines as (
              select distinct ll.lyric_id, ll.line_hash
              from lyric_lines ll
              join lyrics l on l.id = ll.lyric_id and l.deleted_at is null
              where ll.role = 'primary'
            ),
            line_totals as (
              select lyric_id, count(*) as distinct_lines
              from primary_lines
              group by lyric_id
            ),
            pairs as (
              select
                a.lyric_id as left_id,
                b.lyric_id as right_id,
                count(*) as matched_lines
              from primary_lines a
              join primary_lines b
                on a.line_hash = b.line_hash
               and a.lyric_id < b.lyric_id
              group by a.lyric_id, b.lyric_id
            ),
            scored as (
              select
                p.left_id,
                p.right_id,
                p.matched_lines,
                t1.distinct_lines as left_line_count,
                t2.distinct_lines as right_line_count,
                p.matched_lines::float / greatest(t1.distinct_lines, t2.distinct_lines) as line_coverage
              from pairs p
              join line_totals t1 on t1.lyric_id = p.left_id
              join line_totals t2 on t2.lyric_id = p.right_id
            )
            select
              s.left_id,
              l1.record_id as left_record_id,
              l1.source_path as left_source_path,
              s.right_id,
              l2.record_id as right_record_id,
              l2.source_path as right_source_path,
              s.matched_lines,
              s.left_line_count,
              s.right_line_count,
              s.line_coverage
            from scored s
            join lyrics l1 on l1.id = s.left_id
            join lyrics l2 on l2.id = s.right_id
            where s.line_coverage >= %s
            order by s.line_coverage desc, s.matched_lines desc, s.left_id, s.right_id
            limit %s
            """,
            (threshold, limit),
        )
        rows = cursor.fetchall()
    _write_rows(
        report_path,
        [
            "left_id",
            "left_record_id",
            "left_source_path",
            "right_id",
            "right_record_id",
            "right_source_path",
            "matched_lines",
            "left_line_count",
            "right_line_count",
            "line_coverage",
        ],
        rows,
    )
    print(f"[pg-import] line duplicate candidates reported: {len(rows)}", file=sys.stderr, flush=True)
    return len(rows)


def _write_rows(report_path: Path, fieldnames: list[str], rows: list[tuple[object, ...]]) -> None:
    report_path.parent.mkdir(parents=True, exist_ok=True)
    with report_path.open("w", encoding="utf-8", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(fieldnames)
        writer.writerows(rows)


def _progress(label: str, current: int, total: int, *, step: int) -> None:
    if current == total or current % step == 0:
        print(f"[pg-import] {label}: {current}/{total}", file=sys.stderr, flush=True)


def _import_psycopg():
    try:
        import psycopg

        return psycopg
    except ModuleNotFoundError:
        print(
            "Missing dependency: psycopg. Install it with:\n"
            "  python -m pip install 'psycopg[binary]'",
            file=sys.stderr,
        )
        raise SystemExit(1)


# Run the importer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()