cli.py 14.7 KB
"""Command line tools for lyric duplicate checking."""

from __future__ import annotations

import argparse
import csv
import json
from pathlib import Path

from lyric_dedup.checker import DuplicateChecker
from lyric_dedup.checker import LyricRecord
from lyric_dedup.eval_dataset import generate_eval_set
from lyric_dedup.file_import import iter_lyric_files
from lyric_dedup.file_import import read_lyric_file
from lyric_dedup.file_import import record_from_file
from lyric_dedup.file_import import records_from_dir


def main() -> None:
    """Run the ``lyric-dedup`` command line interface.

    Parses ``sys.argv`` with argparse and hands the parsed options to the
    function implementing the chosen sub-command (build-index, check-file,
    batch-check, evaluate-csv or generate-eval-set).
    """
    parser = argparse.ArgumentParser(prog="lyric-dedup")
    # required=True makes argparse reject an invocation without a sub-command.
    commands = parser.add_subparsers(dest="command", required=True)

    build_cmd = commands.add_parser("build-index", help="build an index from .lrc/.txt files")
    for flag in ("--lyrics-dir", "--index"):
        build_cmd.add_argument(flag, required=True)

    check_cmd = commands.add_parser("check-file", help="check one .lrc/.txt file against an index")
    for flag in ("--index", "--file"):
        check_cmd.add_argument(flag, required=True)
    check_cmd.add_argument("--max-candidates", type=int, default=10)

    batch_cmd = commands.add_parser("batch-check", help="check a directory of .lrc/.txt files against an index")
    for flag in ("--index", "--lyrics-dir", "--out"):
        batch_cmd.add_argument(flag, required=True)
    batch_cmd.add_argument("--max-candidates", type=int, default=5)

    eval_cmd = commands.add_parser("evaluate-csv", help="evaluate labeled duplicate samples from a CSV file")
    for flag in ("--index", "--csv", "--out"):
        eval_cmd.add_argument(flag, required=True)
    eval_cmd.add_argument("--base-dir", default="")
    eval_cmd.add_argument("--positive-decisions", default="duplicate")
    eval_cmd.add_argument("--max-candidates", type=int, default=5)

    gen_cmd = commands.add_parser("generate-eval-set", help="generate labeled eval samples from a lyric library")
    for flag in ("--library-dir", "--lyrics-dir", "--csv"):
        gen_cmd.add_argument(flag, required=True)
    gen_cmd.add_argument("--size", type=int, default=100)
    gen_cmd.add_argument("--positive-ratio", type=float, default=0.3)
    gen_cmd.add_argument("--seed", type=int, default=20260602)
    gen_cmd.add_argument("--index", default="", help="optional existing index for hard-negative generation")

    opts = parser.parse_args()

    def run_evaluate_csv() -> None:
        # Comma-separated decision names; blank entries and surrounding spaces are dropped.
        positives = {name.strip() for name in opts.positive_decisions.split(",") if name.strip()}
        evaluate_csv(
            Path(opts.index),
            Path(opts.csv),
            Path(opts.out),
            base_dir=Path(opts.base_dir) if opts.base_dir else None,
            positive_decisions=positives,
            max_candidates=opts.max_candidates,
        )

    def run_generate_eval_set() -> None:
        summary = generate_eval_set(
            library_dir=Path(opts.library_dir),
            output_dir=Path(opts.lyrics_dir),
            csv_path=Path(opts.csv),
            size=opts.size,
            positive_ratio=opts.positive_ratio,
            seed=opts.seed,
            # An empty --index means "no existing index".
            index_path=Path(opts.index) if opts.index else None,
        )
        print(json.dumps(summary, ensure_ascii=False))

    dispatch = {
        "build-index": lambda: build_index(Path(opts.lyrics_dir), Path(opts.index)),
        "check-file": lambda: check_file(Path(opts.index), Path(opts.file), opts.max_candidates),
        "batch-check": lambda: batch_check(
            Path(opts.index), Path(opts.lyrics_dir), Path(opts.out), opts.max_candidates
        ),
        "evaluate-csv": run_evaluate_csv,
        "generate-eval-set": run_generate_eval_set,
    }
    handler = dispatch.get(opts.command)
    if handler is not None:
        handler()


def build_index(lyrics_dir: Path, index_path: Path) -> None:
    """Build a fresh duplicate index from *lyrics_dir* and save it to *index_path*.

    Every record yielded by ``records_from_dir(lyrics_dir)`` is added to a new
    ``DuplicateChecker``. The parent directory of *index_path* is created if
    missing. A one-line JSON summary (record count and index path) is printed.
    """
    checker = DuplicateChecker()
    for lyric_record in records_from_dir(lyrics_dir):
        checker.add_record(lyric_record)

    index_path.parent.mkdir(parents=True, exist_ok=True)
    checker.save(index_path)

    summary = {"indexed": checker.record_count, "index": str(index_path)}
    print(json.dumps(summary, ensure_ascii=False))


def check_file(index_path: Path, file_path: Path, max_candidates: int) -> None:
    """Check a single lyric file against a saved index and print the verdict.

    The result, including up to *max_candidates* candidates, is printed as
    indented JSON with ``source`` set to *file_path*.
    """
    index = DuplicateChecker.load(index_path)
    verdict = index.check_record(record_from_file(file_path), max_candidates=max_candidates)
    payload = _result_to_dict(verdict, source=str(file_path))
    print(json.dumps(payload, ensure_ascii=False, indent=2))


def batch_check(index_path: Path, lyrics_dir: Path, out_path: Path, max_candidates: int) -> None:
    """Check every lyric file under *lyrics_dir* against a saved index.

    Files come from ``iter_lyric_files(lyrics_dir)``. One row per file is
    written to *out_path*: JSON Lines when its suffix is ``.jsonl`` (any case),
    CSV otherwise. The parent directory of *out_path* is created if needed.
    Afterwards a one-line JSON summary with per-decision counts is printed.
    """
    checker = DuplicateChecker.load(index_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Columns describing the top-ranked candidate; blank when there is none.
    candidate_columns = (
        "best_candidate_id",
        "best_candidate_decision",
        "best_candidate_confidence",
        "best_candidate_jaccard",
        "best_candidate_line_coverage",
        "best_candidate_primary_jaccard",
        "best_candidate_primary_line_coverage",
        "best_candidate_translation_jaccard",
        "best_candidate_translation_line_coverage",
        "best_candidate_reason",
        "matched_unique_lines",
    )

    def row_for(path: Path) -> dict[str, object]:
        record = record_from_file(path, base_dir=lyrics_dir)
        result = checker.check_record(record, max_candidates=max_candidates)
        row: dict[str, object] = {
            "source": str(path),
            "record_id": record.record_id,
            "decision": result.decision.value,
            "confidence": result.confidence,
            "reason": result.reason,
        }
        top = result.candidates[0] if result.candidates else None
        if not top:
            row.update(dict.fromkeys(candidate_columns, ""))
            return row
        row.update(
            {
                "best_candidate_id": top.record_id,
                "best_candidate_decision": top.decision.value,
                "best_candidate_confidence": top.confidence,
                "best_candidate_jaccard": top.jaccard,
                "best_candidate_line_coverage": top.line_coverage,
                "best_candidate_primary_jaccard": top.primary_jaccard,
                "best_candidate_primary_line_coverage": top.primary_line_coverage,
                "best_candidate_translation_jaccard": top.translation_jaccard,
                "best_candidate_translation_line_coverage": top.translation_line_coverage,
                "best_candidate_reason": top.reason,
                "matched_unique_lines": " | ".join(top.matched_unique_lines),
            }
        )
        return row

    rows = [row_for(path) for path in iter_lyric_files(lyrics_dir)]

    if out_path.suffix.lower() == ".jsonl":
        with out_path.open("w", encoding="utf-8") as handle:
            handle.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)
    else:
        # With no rows there is nothing to derive a header from; fall back to "source".
        header = list(rows[0]) if rows else ["source"]
        with out_path.open("w", encoding="utf-8", newline="") as handle:
            writer = csv.DictWriter(handle, fieldnames=header)
            writer.writeheader()
            writer.writerows(rows)

    decision_counts = {
        name: sum(1 for row in rows if row["decision"] == name)
        for name in ("duplicate", "review", "new")
    }
    summary = {"checked": len(rows), **decision_counts, "out": str(out_path)}
    print(json.dumps(summary, ensure_ascii=False))


def evaluate_csv(
    index_path: Path,
    csv_path: Path,
    out_path: Path,
    *,
    base_dir: Path | None,
    positive_decisions: set[str],
    max_candidates: int,
) -> None:
    """Evaluate labeled duplicate samples from a CSV file against a saved index.

    Each CSV row supplies either inline ``lyrics`` or a ``file``/``path``/``source``
    path (relative paths resolve against *base_dir*, or the CSV's directory when
    *base_dir* is None), plus an ``expected``/``label``/``target`` label. Each
    sample is checked with the loaded index. A decision whose value is in
    *positive_decisions* counts as a predicted duplicate.

    One CSV row per sample is written to *out_path*. The metrics summary is
    written next to it (``.summary.json`` appended to the file name) and also
    printed as one JSON line.

    Raises:
        ValueError: if the CSV has no header, or if a row cannot be turned into a
            sample (no lyrics/path, or an unrecognised label). Row errors are
            prefixed with the CSV path and row number.
    """
    checker = DuplicateChecker.load(index_path)
    rows: list[dict[str, object]] = []
    # utf-8-sig tolerates a leading byte-order mark in the CSV.
    with csv_path.open(encoding="utf-8-sig", newline="") as file:
        reader = csv.DictReader(file)
        if reader.fieldnames is None:
            raise ValueError("评估 CSV 需要表头")
        # Line 1 is the header, so data rows start at 2. This counts records,
        # not physical lines: a quoted field spanning several lines shifts it.
        for row_number, row in enumerate(reader, start=2):
            sample_id = row.get("id") or row.get("sample_id") or str(row_number)
            try:
                record, source = _record_from_eval_row(row, csv_path=csv_path, base_dir=base_dir)
                expected_duplicate = _parse_expected(
                    row.get("expected") or row.get("label") or row.get("target")
                )
            except ValueError as exc:
                # Point at the offending row; a bare message is useless for large CSVs.
                raise ValueError(f"{csv_path} 第 {row_number} 行: {exc}") from exc
            result = checker.check_record(record, max_candidates=max_candidates)
            predicted_duplicate = result.decision.value in positive_decisions
            best = result.candidates[0] if result.candidates else None
            rows.append(
                {
                    "id": sample_id,
                    "source": source,
                    "expected_duplicate": expected_duplicate,
                    "decision": result.decision.value,
                    "predicted_duplicate": predicted_duplicate,
                    "correct": expected_duplicate == predicted_duplicate,
                    "confidence": result.confidence,
                    "reason": result.reason,
                    "best_candidate_id": best.record_id if best else "",
                    "best_candidate_decision": best.decision.value if best else "",
                    "best_candidate_confidence": best.confidence if best else "",
                    "best_candidate_jaccard": best.jaccard if best else "",
                    "best_candidate_line_coverage": best.line_coverage if best else "",
                    "best_candidate_primary_jaccard": best.primary_jaccard if best else "",
                    "best_candidate_primary_line_coverage": best.primary_line_coverage if best else "",
                    "best_candidate_translation_jaccard": best.translation_jaccard if best else "",
                    "best_candidate_translation_line_coverage": best.translation_line_coverage if best else "",
                    "best_candidate_reason": best.reason if best else "",
                    "matched_unique_lines": " | ".join(best.matched_unique_lines) if best else "",
                }
            )

    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8", newline="") as file:
        # With no rows there is nothing to derive a header from; fall back to "id".
        writer = csv.DictWriter(file, fieldnames=list(rows[0].keys()) if rows else ["id"])
        writer.writeheader()
        writer.writerows(rows)

    summary = _evaluation_summary(rows, positive_decisions=positive_decisions, out_path=out_path)
    summary_path = out_path.with_suffix(out_path.suffix + ".summary.json")
    summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
    print(json.dumps(summary, ensure_ascii=False))


def _result_to_dict(result, *, source: str) -> dict[str, object]:
    """Flatten a checker result into a plain, JSON-serialisable dict.

    Only the ``decision`` (read via ``.value``), ``confidence``, ``reason`` and
    ``candidates`` attributes of *result* are used. Each candidate becomes a
    dict of its scores, with ``matched_unique_lines`` copied into a list.
    """

    def candidate_payload(candidate) -> dict[str, object]:
        return {
            "record_id": candidate.record_id,
            "decision": candidate.decision.value,
            "confidence": candidate.confidence,
            "jaccard": candidate.jaccard,
            "line_coverage": candidate.line_coverage,
            "primary_jaccard": candidate.primary_jaccard,
            "primary_line_coverage": candidate.primary_line_coverage,
            "translation_jaccard": candidate.translation_jaccard,
            "translation_line_coverage": candidate.translation_line_coverage,
            "reason": candidate.reason,
            "matched_unique_lines": list(candidate.matched_unique_lines),
        }

    payload: dict[str, object] = {
        "source": source,
        "decision": result.decision.value,
        "confidence": result.confidence,
        "reason": result.reason,
    }
    payload["candidates"] = list(map(candidate_payload, result.candidates))
    return payload


def _lyrics_from_eval_row(row: dict[str, str], *, csv_path: Path, base_dir: Path | None) -> tuple[str, str]:
    """Return ``(lyrics_text, source)`` for one evaluation CSV row.

    Inline ``lyrics`` win: they are stripped, literal ``\\n`` sequences become
    real newlines, and the source is reported as ``"inline"``. Otherwise the
    first non-empty of ``file``/``path``/``source`` is read with
    ``read_lyric_file``. A relative path resolves against *base_dir*, or the
    CSV's directory when *base_dir* is None.

    NOTE(review): this helper appears unused within this module;
    ``_record_from_eval_row`` repeats its path-resolution logic.

    Raises:
        ValueError: if the row has neither lyrics nor a file path.
    """
    inline_text = (row.get("lyrics") or "").strip()
    if not inline_text:
        reference = (row.get("file") or row.get("path") or row.get("source") or "").strip()
        if not reference:
            raise ValueError("评估 CSV 每行需要提供 lyrics,或 file/path/source 文件路径")
        resolved = Path(reference)
        if not resolved.is_absolute():
            resolved = (base_dir or csv_path.parent) / resolved
        return read_lyric_file(resolved), str(resolved)
    return inline_text.replace("\\n", "\n"), "inline"


def _record_from_eval_row(row: dict[str, str], *, csv_path: Path, base_dir: Path | None):
    """Build the ``LyricRecord`` to check for one evaluation CSV row.

    Returns a ``(record, source)`` pair. With inline ``lyrics`` the record is
    built from the row itself. Its id is ``id``/``sample_id``, or
    ``"__eval__"`` when both are empty. Literal ``\\n`` sequences become
    newlines, and the source is ``"inline"``.

    Otherwise the first non-empty of ``file``/``path``/``source`` is loaded
    via ``record_from_file``. A relative path resolves against *base_dir*, or
    the CSV's directory when *base_dir* is None. A non-empty ``title`` or
    ``artist`` in the row overrides the value read from the file. The source
    is the resolved path.

    Raises:
        ValueError: if the row has neither lyrics nor a file path.
    """
    inline_text = (row.get("lyrics") or "").strip()
    if inline_text:
        inline_record = LyricRecord(
            record_id=row.get("id") or row.get("sample_id") or "__eval__",
            lyrics=inline_text.replace("\\n", "\n"),
            title=row.get("title") or None,
            artist=row.get("artist") or None,
        )
        return inline_record, "inline"

    reference = (row.get("file") or row.get("path") or row.get("source") or "").strip()
    if not reference:
        raise ValueError("评估 CSV 每行需要 lyrics,或 file/path/source 文件路径")

    resolved = Path(reference)
    if not resolved.is_absolute():
        resolved = (base_dir or csv_path.parent) / resolved

    loaded = record_from_file(resolved)
    title_override = row.get("title")
    artist_override = row.get("artist")
    if title_override or artist_override:
        loaded = LyricRecord(
            record_id=loaded.record_id,
            lyrics=loaded.lyrics,
            title=title_override or loaded.title,
            artist=artist_override or loaded.artist,
        )
    return loaded, str(resolved)


# Accepted label spellings, compared after lower-casing, trimming and folding
# runs of whitespace/hyphens into a single underscore.
_EXPECTED_POSITIVE_LABELS = frozenset(
    {"1", "true", "yes", "y", "duplicate", "dup", "重复", "应去重", "去重", "是"}
)
_EXPECTED_NEGATIVE_LABELS = frozenset(
    {"0", "false", "no", "n", "new", "not_duplicate", "non_duplicate", "不重复", "不应去重", "新歌", "否"}
)


def _parse_expected(value: str | None) -> bool:
    """Parse an expected/label/target cell into a boolean "is duplicate" flag.

    Matching is case-insensitive and ignores surrounding whitespace. Inner
    whitespace and hyphen runs are folded to ``_``, so ``not-duplicate`` and
    ``Non Duplicate`` are accepted alongside ``not_duplicate``.

    Raises:
        ValueError: if *value* is None, or is not a recognised label.
    """
    if value is None:
        raise ValueError("评估 CSV 每行需要 expected/label/target 列")
    # split() drops surrounding whitespace and collapses inner runs.
    normalized = "_".join(value.lower().replace("-", " ").split())
    if normalized in _EXPECTED_POSITIVE_LABELS:
        return True
    if normalized in _EXPECTED_NEGATIVE_LABELS:
        return False
    raise ValueError(f"无法识别 expected 值: {value!r}")


def _evaluation_summary(
    rows: list[dict[str, object]],
    *,
    positive_decisions: set[str],
    out_path: Path,
) -> dict[str, object]:
    """Aggregate per-sample evaluation rows into confusion-matrix metrics.

    Each row is read for its ``expected_duplicate`` and ``predicted_duplicate``
    flags (counted only when they are exactly ``True``/``False``) and its
    ``decision`` string. Ratios fall back to 0.0 when their denominator is
    zero, and are rounded to four decimal places. The result also records the
    output CSV path and the derived ``.summary.json`` path.
    """

    def matching(expected: bool, predicted: bool) -> int:
        # Identity comparison: only genuine bool flags are counted.
        return sum(
            1
            for row in rows
            if row["expected_duplicate"] is expected and row["predicted_duplicate"] is predicted
        )

    def safe_ratio(numerator: float, denominator: float) -> float:
        return numerator / denominator if denominator else 0.0

    def decision_count(decision: str) -> int:
        return sum(1 for row in rows if row["decision"] == decision)

    true_pos, false_pos = matching(True, True), matching(False, True)
    true_neg, false_neg = matching(False, False), matching(True, False)
    sample_count = len(rows)

    precision = safe_ratio(true_pos, true_pos + false_pos)
    recall = safe_ratio(true_pos, true_pos + false_neg)
    accuracy = safe_ratio(true_pos + true_neg, sample_count)
    f1 = safe_ratio(2 * precision * recall, precision + recall)

    summary_file = out_path.with_suffix(out_path.suffix + ".summary.json")
    return {
        "total": sample_count,
        "positive_decisions": sorted(positive_decisions),
        "accuracy": round(accuracy, 4),
        "precision": round(precision, 4),
        "recall": round(recall, 4),
        "f1": round(f1, 4),
        "true_positive": true_pos,
        "false_positive": false_pos,
        "true_negative": true_neg,
        "false_negative": false_neg,
        "duplicate": decision_count("duplicate"),
        "review": decision_count("review"),
        "new": decision_count("new"),
        "out": str(out_path),
        "summary": str(summary_file),
    }


# Run the command line interface when this file is executed as a script.
if __name__ == "__main__":
    main()