Commit ba39ce6a ba39ce6aa50b5bc45d24bc32bcbc7c51b870922a by 沈秋雨

添加测试集内部去重

1 parent f8ad329c
......@@ -80,6 +80,7 @@ python -m lyric_dedup.cli generate-eval-set \
--lyrics-dir data/generated_eval/incoming \
--csv data/generated_eval/eval_50000.csv \
--index outputs/indexes/lyrics.pkl \
--eval-index data/generated_eval/eval_50000.csv.index.pkl \
--size 50000 \
--positive-ratio 0.3
```
......@@ -88,10 +89,10 @@ python -m lyric_dedup.cli generate-eval-set \
- 先扫描整个曲库,按有效歌词行数、语言类型、文件来源前缀做分层采样,不再按排序前缀取样。
- `应去重` 样本只生成全曲歌词的样式变化,例如时间戳、标点、平台噪声、空行、重复副歌次数变化、附加中文翻译。
- `不应去重` 样本包含同主题新歌词、hard negative、片段歌词、重复副歌碰撞、仅翻译相似、短歌词/占位边界样本。
- `不应去重` 样本以真实 holdout 完整歌词为主,也包含片段歌词、重复副歌碰撞、仅翻译相似、同主题新歌词、短歌词/占位边界样本。
- 片段歌词即使命中已有歌曲的一部分,也不应该输出 `duplicate`;最多进入 `review`
- 如果传入 `--index`,生成器会用现有索引构造更接近线上召回风险的 hard negative
- 同时会生成 `*.manifest.json`,记录 seed、曲库规模、样本类型分布、语言/来源分桶和样本来源覆盖数。
- 生成器会额外写出 `--eval-index`,这个索引排除了 holdout 歌,评估生成 CSV 时应使用它
- 同时会生成 `*.manifest.json`,记录 seed、曲库规模、holdout 数、样本类型分布、语言/来源分桶和样本来源覆盖数。
先准备一个 CSV,例如 `data/eval/eval.csv`
......
......@@ -103,6 +103,7 @@ python -m lyric_dedup.cli generate-eval-set \
--lyrics-dir data/generated_eval/incoming \
--csv data/generated_eval/eval_50000.csv \
--index outputs/indexes/library_lyrics.pkl \
--eval-index data/generated_eval/eval_50000.csv.index.pkl \
--size 50000 \
--positive-ratio 0.3
```
......@@ -120,24 +121,26 @@ python -m lyric_dedup.cli generate-eval-set \
```text
positive_* = 应去重,全曲歌词样式变化
negative_random_unrelated = 不应去重,同主题新歌词
negative_hard_candidate = 不应去重,系统容易召回的短句/局部重合样本
negative_real_holdout_full_song = 不应去重,完整真实歌词,已从评估索引中排除
negative_fragment = 不应去重,单曲片段
negative_shared_chorus = 不应去重,重复副歌碰撞
negative_translation_only = 不应去重,仅翻译相似
negative_same_theme_synthetic = 不应去重,同主题新歌词
edge_short_or_placeholder = 不应去重,短歌词/占位边界样本
```
生成器会扫描整个曲库并按有效歌词行数、语言类型、文件来源前缀分层采样。传入 `--index` 后会用现有索引生成 hard negative。每次还会输出:
生成器会扫描整个曲库并按有效歌词行数、语言类型、文件来源前缀分层采样。它会分出一批 holdout 完整歌词作为真实新歌负样本,并生成一个排除 holdout 的评估索引。每次还会输出:
```text
data/generated_eval/eval_50000.csv.manifest.json
data/generated_eval/eval_50000.csv.index.pkl
```
manifest 里重点看:
```text
library_files 曲库歌词文件数
holdout_records 从评估索引中排除、作为真实新歌负样本的数量
sample_type_counts 各样本类型数量
line_count_bucket_counts / language_bucket_counts / source_bucket_counts
unique_source_records 本次评估覆盖了多少真实源文件
......@@ -147,7 +150,7 @@ unique_source_records 本次评估覆盖了多少真实源文件
```bash
python -m lyric_dedup.cli evaluate-csv \
--index outputs/indexes/library_lyrics.pkl \
--index data/generated_eval/eval_50000.csv.index.pkl \
--csv data/generated_eval/eval_50000.csv \
--base-dir data/generated_eval \
--out outputs/results/library_eval_50000.csv
......@@ -171,7 +174,7 @@ false_positive
```bash
python -m lyric_dedup.cli evaluate-csv \
--index outputs/indexes/library_lyrics.pkl \
--index data/generated_eval/eval_50000.csv.index.pkl \
--csv data/generated_eval/eval_50000.csv \
--base-dir data/generated_eval \
--positive-decisions duplicate,review \
......
......@@ -96,16 +96,24 @@ class DuplicateChecker:
def add_record(self, record: LyricRecord) -> None:
indexed = self._index(record)
self._records[record.record_id] = indexed
self._exact_hash_to_ids.setdefault(indexed.exact_hash, set()).add(record.record_id)
self._add_indexed(record.record_id, indexed)
def add_normalized_record(self, record: LyricRecord, normalized: NormalizedLyrics) -> None:
"""Add a record when normalized lyrics have already been computed."""
indexed = self._index_normalized(record, normalized)
self._add_indexed(record.record_id, indexed)
def _add_indexed(self, record_id: str, indexed: _IndexedRecord) -> None:
self._records[record_id] = indexed
self._exact_hash_to_ids.setdefault(indexed.exact_hash, set()).add(record_id)
for line in indexed.normalized.unique_lines:
if len(line) >= 4:
self._line_to_ids.setdefault(line, set()).add(record.record_id)
self._line_to_ids.setdefault(line, set()).add(record_id)
for token in indexed.tokens:
self._token_to_ids.setdefault(token, set()).add(record.record_id)
self._token_to_ids.setdefault(token, set()).add(record_id)
for token in indexed.fallback_tokens:
self._token_to_ids.setdefault(token, set()).add(record.record_id)
self._lsh.add(record.record_id, indexed.signature)
self._token_to_ids.setdefault(token, set()).add(record_id)
self._lsh.add(record_id, indexed.signature)
def save(self, path: str | Path) -> None:
"""Persist the in-memory index for later checks."""
......@@ -187,6 +195,9 @@ class DuplicateChecker:
def _index(self, record: LyricRecord) -> _IndexedRecord:
normalized = normalize_lyrics(record.lyrics)
return self._index_normalized(record, normalized)
def _index_normalized(self, record: LyricRecord, normalized: NormalizedLyrics) -> _IndexedRecord:
tokens = lyric_tokens(normalized)
primary_tokens = lyric_tokens(normalized, lines=normalized.primary_lines)
translation_tokens = lyric_tokens(normalized, lines=normalized.translation_lines)
......
......@@ -5,6 +5,7 @@ from __future__ import annotations
import argparse
import csv
import json
import sys
from pathlib import Path
from lyric_dedup.checker import DuplicateChecker
......@@ -50,7 +51,8 @@ def main() -> None:
generate.add_argument("--size", type=int, default=100)
generate.add_argument("--positive-ratio", type=float, default=0.3)
generate.add_argument("--seed", type=int, default=20260602)
generate.add_argument("--index", default="", help="optional existing index for hard-negative generation")
generate.add_argument("--index", default="", help="optional source index path recorded in the manifest")
generate.add_argument("--eval-index", default="", help="output index built from non-holdout records for this eval set")
args = parser.parse_args()
if args.command == "build-index":
......@@ -77,6 +79,7 @@ def main() -> None:
positive_ratio=args.positive_ratio,
seed=args.seed,
index_path=Path(args.index) if args.index else None,
eval_index_path=Path(args.eval_index) if args.eval_index else None,
)
print(json.dumps(summary, ensure_ascii=False))
......@@ -155,52 +158,58 @@ def evaluate_csv(
positive_decisions: set[str],
max_candidates: int,
) -> None:
_progress(f"load index: {index_path}")
checker = DuplicateChecker.load(index_path)
rows: list[dict[str, object]] = []
total = _csv_data_row_count(csv_path)
_progress(f"evaluate csv: 0/{total}")
out_path.parent.mkdir(parents=True, exist_ok=True)
with csv_path.open(encoding="utf-8-sig", newline="") as file:
reader = csv.DictReader(file)
if reader.fieldnames is None:
raise ValueError("评估 CSV 需要表头")
for row_number, row in enumerate(reader, start=2):
sample_id = row.get("id") or row.get("sample_id") or str(row_number)
record, source = _record_from_eval_row(row, csv_path=csv_path, base_dir=base_dir)
expected_duplicate = _parse_expected(row.get("expected") or row.get("label") or row.get("target"))
result = checker.check_record(record, max_candidates=max_candidates)
predicted_duplicate = result.decision.value in positive_decisions
best = result.candidates[0] if result.candidates else None
rows.append(
{
"id": sample_id,
"source": source,
"expected_duplicate": expected_duplicate,
"decision": result.decision.value,
"predicted_duplicate": predicted_duplicate,
"correct": expected_duplicate == predicted_duplicate,
"confidence": result.confidence,
"reason": result.reason,
"best_candidate_id": best.record_id if best else "",
"best_candidate_decision": best.decision.value if best else "",
"best_candidate_confidence": best.confidence if best else "",
"best_candidate_jaccard": best.jaccard if best else "",
"best_candidate_line_coverage": best.line_coverage if best else "",
"best_candidate_primary_jaccard": best.primary_jaccard if best else "",
"best_candidate_primary_line_coverage": best.primary_line_coverage if best else "",
"best_candidate_translation_jaccard": best.translation_jaccard if best else "",
"best_candidate_translation_line_coverage": best.translation_line_coverage if best else "",
"best_candidate_reason": best.reason if best else "",
"matched_unique_lines": " | ".join(best.matched_unique_lines) if best else "",
}
)
out_path.parent.mkdir(parents=True, exist_ok=True)
with out_path.open("w", encoding="utf-8", newline="") as file:
writer = csv.DictWriter(file, fieldnames=list(rows[0].keys()) if rows else ["id"])
fieldnames = [
"id",
"source",
"expected_duplicate",
"decision",
"predicted_duplicate",
"correct",
"confidence",
"reason",
"best_candidate_id",
"best_candidate_decision",
"best_candidate_confidence",
"best_candidate_jaccard",
"best_candidate_line_coverage",
"best_candidate_primary_jaccard",
"best_candidate_primary_line_coverage",
"best_candidate_translation_jaccard",
"best_candidate_translation_line_coverage",
"best_candidate_reason",
"matched_unique_lines",
]
with out_path.open("w", encoding="utf-8", newline="") as out_file:
writer = csv.DictWriter(out_file, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(rows)
for index, row in enumerate(reader, start=1):
row_out = _evaluate_row(
row,
row_number=index + 1,
checker=checker,
csv_path=csv_path,
base_dir=base_dir,
positive_decisions=positive_decisions,
max_candidates=max_candidates,
)
rows.append(row_out)
writer.writerow(row_out)
_progress_count("evaluate csv", index, total, step=1000)
summary = _evaluation_summary(rows, positive_decisions=positive_decisions, out_path=out_path)
summary_path = out_path.with_suffix(out_path.suffix + ".summary.json")
summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
_progress("evaluation complete")
print(json.dumps(summary, ensure_ascii=False))
......@@ -229,6 +238,45 @@ def _result_to_dict(result, *, source: str) -> dict[str, object]:
}
def _evaluate_row(
row: dict[str, str],
*,
row_number: int,
checker: DuplicateChecker,
csv_path: Path,
base_dir: Path | None,
positive_decisions: set[str],
max_candidates: int,
) -> dict[str, object]:
sample_id = row.get("id") or row.get("sample_id") or str(row_number)
record, source = _record_from_eval_row(row, csv_path=csv_path, base_dir=base_dir)
expected_duplicate = _parse_expected(row.get("expected") or row.get("label") or row.get("target"))
result = checker.check_record(record, max_candidates=max_candidates)
predicted_duplicate = result.decision.value in positive_decisions
best = result.candidates[0] if result.candidates else None
return {
"id": sample_id,
"source": source,
"expected_duplicate": expected_duplicate,
"decision": result.decision.value,
"predicted_duplicate": predicted_duplicate,
"correct": expected_duplicate == predicted_duplicate,
"confidence": result.confidence,
"reason": result.reason,
"best_candidate_id": best.record_id if best else "",
"best_candidate_decision": best.decision.value if best else "",
"best_candidate_confidence": best.confidence if best else "",
"best_candidate_jaccard": best.jaccard if best else "",
"best_candidate_line_coverage": best.line_coverage if best else "",
"best_candidate_primary_jaccard": best.primary_jaccard if best else "",
"best_candidate_primary_line_coverage": best.primary_line_coverage if best else "",
"best_candidate_translation_jaccard": best.translation_jaccard if best else "",
"best_candidate_translation_line_coverage": best.translation_line_coverage if best else "",
"best_candidate_reason": best.reason if best else "",
"matched_unique_lines": " | ".join(best.matched_unique_lines) if best else "",
}
def _lyrics_from_eval_row(row: dict[str, str], *, csv_path: Path, base_dir: Path | None) -> tuple[str, str]:
lyrics = (row.get("lyrics") or "").strip()
if lyrics:
......@@ -322,5 +370,23 @@ def _evaluation_summary(
}
def _csv_data_row_count(csv_path: Path) -> int:
with csv_path.open(encoding="utf-8-sig", newline="") as file:
reader = csv.reader(file)
next(reader, None)
return sum(1 for _ in reader)
def _progress(message: str) -> None:
print(f"[eval] {message}", file=sys.stderr, flush=True)
def _progress_count(label: str, current: int, total: int, *, step: int = 1000) -> None:
if total <= 0:
return
if current == 1 or current == total or current % step == 0:
_progress(f"{label}: {current}/{total}")
if __name__ == "__main__":
main()
......
......@@ -308,9 +308,11 @@ def test_generated_eval_set_uses_stratified_production_mix(tmp_path) -> None:
assert manifest["library_files"] == 12
assert manifest["sample_size"] == 30
assert manifest["unique_source_records"] > 1
assert manifest["holdout_records"] > 1
assert (tmp_path / "generated" / "eval.csv.index.pkl").exists()
assert "positive_full_duplicate" in manifest["plan"]
assert "negative_real_holdout_full_song" in negative_types
assert "negative_fragment" in negative_types
assert "negative_hard_candidate" in negative_types
assert all(row["expected"] == "不应去重" for row in rows if row["sample_type"].startswith("negative_"))
......