Commit ba39ce6a ba39ce6aa50b5bc45d24bc32bcbc7c51b870922a by 沈秋雨

添加测试集内部去重

1 parent f8ad329c
...@@ -80,6 +80,7 @@ python -m lyric_dedup.cli generate-eval-set \ ...@@ -80,6 +80,7 @@ python -m lyric_dedup.cli generate-eval-set \
80 --lyrics-dir data/generated_eval/incoming \ 80 --lyrics-dir data/generated_eval/incoming \
81 --csv data/generated_eval/eval_50000.csv \ 81 --csv data/generated_eval/eval_50000.csv \
82 --index outputs/indexes/lyrics.pkl \ 82 --index outputs/indexes/lyrics.pkl \
83 --eval-index data/generated_eval/eval_50000.csv.index.pkl \
83 --size 50000 \ 84 --size 50000 \
84 --positive-ratio 0.3 85 --positive-ratio 0.3
85 ``` 86 ```
...@@ -88,10 +89,10 @@ python -m lyric_dedup.cli generate-eval-set \ ...@@ -88,10 +89,10 @@ python -m lyric_dedup.cli generate-eval-set \
88 89
89 - 先扫描整个曲库,按有效歌词行数、语言类型、文件来源前缀做分层采样,不再按排序前缀取样。 90 - 先扫描整个曲库,按有效歌词行数、语言类型、文件来源前缀做分层采样,不再按排序前缀取样。
90 - `应去重` 样本只生成全曲歌词的样式变化,例如时间戳、标点、平台噪声、空行、重复副歌次数变化、附加中文翻译。 91 - `应去重` 样本只生成全曲歌词的样式变化,例如时间戳、标点、平台噪声、空行、重复副歌次数变化、附加中文翻译。
91 - `不应去重` 样本包含同主题新歌词、hard negative、片段歌词、重复副歌碰撞、仅翻译相似、短歌词/占位边界样本。 92 - `不应去重` 样本以真实 holdout 完整歌词为主,也包含片段歌词、重复副歌碰撞、仅翻译相似、同主题新歌词、短歌词/占位边界样本。
92 - 片段歌词即使命中已有歌曲的一部分,也不应该输出 `duplicate`;最多进入 `review` 93 - 片段歌词即使命中已有歌曲的一部分,也不应该输出 `duplicate`;最多进入 `review`
93 - 如果传入 `--index`,生成器会用现有索引构造更接近线上召回风险的 hard negative 94 - 生成器会额外写出 `--eval-index`,这个索引排除了 holdout 歌,评估生成 CSV 时应使用它
94 - 同时会生成 `*.manifest.json`,记录 seed、曲库规模、样本类型分布、语言/来源分桶和样本来源覆盖数。 95 - 同时会生成 `*.manifest.json`,记录 seed、曲库规模、holdout 数、样本类型分布、语言/来源分桶和样本来源覆盖数。
95 96
96 先准备一个 CSV,例如 `data/eval/eval.csv` 97 先准备一个 CSV,例如 `data/eval/eval.csv`
97 98
......
...@@ -103,6 +103,7 @@ python -m lyric_dedup.cli generate-eval-set \ ...@@ -103,6 +103,7 @@ python -m lyric_dedup.cli generate-eval-set \
103 --lyrics-dir data/generated_eval/incoming \ 103 --lyrics-dir data/generated_eval/incoming \
104 --csv data/generated_eval/eval_50000.csv \ 104 --csv data/generated_eval/eval_50000.csv \
105 --index outputs/indexes/library_lyrics.pkl \ 105 --index outputs/indexes/library_lyrics.pkl \
106 --eval-index data/generated_eval/eval_50000.csv.index.pkl \
106 --size 50000 \ 107 --size 50000 \
107 --positive-ratio 0.3 108 --positive-ratio 0.3
108 ``` 109 ```
...@@ -120,24 +121,26 @@ python -m lyric_dedup.cli generate-eval-set \ ...@@ -120,24 +121,26 @@ python -m lyric_dedup.cli generate-eval-set \
120 121
121 ```text 122 ```text
122 positive_* = 应去重,全曲歌词样式变化 123 positive_* = 应去重,全曲歌词样式变化
123 negative_random_unrelated = 不应去重,同主题新歌词 124 negative_real_holdout_full_song = 不应去重,完整真实歌词,已从评估索引中排除
124 negative_hard_candidate = 不应去重,系统容易召回的短句/局部重合样本
125 negative_fragment = 不应去重,单曲片段 125 negative_fragment = 不应去重,单曲片段
126 negative_shared_chorus = 不应去重,重复副歌碰撞 126 negative_shared_chorus = 不应去重,重复副歌碰撞
127 negative_translation_only = 不应去重,仅翻译相似 127 negative_translation_only = 不应去重,仅翻译相似
128 negative_same_theme_synthetic = 不应去重,同主题新歌词
128 edge_short_or_placeholder = 不应去重,短歌词/占位边界样本 129 edge_short_or_placeholder = 不应去重,短歌词/占位边界样本
129 ``` 130 ```
130 131
131 生成器会扫描整个曲库并按有效歌词行数、语言类型、文件来源前缀分层采样。传入 `--index` 后会用现有索引生成 hard negative。每次还会输出: 132 生成器会扫描整个曲库并按有效歌词行数、语言类型、文件来源前缀分层采样。它会分出一批 holdout 完整歌词作为真实新歌负样本,并生成一个排除 holdout 的评估索引。每次还会输出:
132 133
133 ```text 134 ```text
134 data/generated_eval/eval_50000.csv.manifest.json 135 data/generated_eval/eval_50000.csv.manifest.json
136 data/generated_eval/eval_50000.csv.index.pkl
135 ``` 137 ```
136 138
137 manifest 里重点看: 139 manifest 里重点看:
138 140
139 ```text 141 ```text
140 library_files 曲库歌词文件数 142 library_files 曲库歌词文件数
143 holdout_records 从评估索引中排除、作为真实新歌负样本的数量
141 sample_type_counts 各样本类型数量 144 sample_type_counts 各样本类型数量
142 line_count_bucket_counts / language_bucket_counts / source_bucket_counts 145 line_count_bucket_counts / language_bucket_counts / source_bucket_counts
143 unique_source_records 本次评估覆盖了多少真实源文件 146 unique_source_records 本次评估覆盖了多少真实源文件
...@@ -147,7 +150,7 @@ unique_source_records 本次评估覆盖了多少真实源文件 ...@@ -147,7 +150,7 @@ unique_source_records 本次评估覆盖了多少真实源文件
147 150
148 ```bash 151 ```bash
149 python -m lyric_dedup.cli evaluate-csv \ 152 python -m lyric_dedup.cli evaluate-csv \
150 --index outputs/indexes/library_lyrics.pkl \ 153 --index data/generated_eval/eval_50000.csv.index.pkl \
151 --csv data/generated_eval/eval_50000.csv \ 154 --csv data/generated_eval/eval_50000.csv \
152 --base-dir data/generated_eval \ 155 --base-dir data/generated_eval \
153 --out outputs/results/library_eval_50000.csv 156 --out outputs/results/library_eval_50000.csv
...@@ -171,7 +174,7 @@ false_positive ...@@ -171,7 +174,7 @@ false_positive
171 174
172 ```bash 175 ```bash
173 python -m lyric_dedup.cli evaluate-csv \ 176 python -m lyric_dedup.cli evaluate-csv \
174 --index outputs/indexes/library_lyrics.pkl \ 177 --index data/generated_eval/eval_50000.csv.index.pkl \
175 --csv data/generated_eval/eval_50000.csv \ 178 --csv data/generated_eval/eval_50000.csv \
176 --base-dir data/generated_eval \ 179 --base-dir data/generated_eval \
177 --positive-decisions duplicate,review \ 180 --positive-decisions duplicate,review \
......
...@@ -96,16 +96,24 @@ class DuplicateChecker: ...@@ -96,16 +96,24 @@ class DuplicateChecker:
96 96
97 def add_record(self, record: LyricRecord) -> None: 97 def add_record(self, record: LyricRecord) -> None:
98 indexed = self._index(record) 98 indexed = self._index(record)
99 self._records[record.record_id] = indexed 99 self._add_indexed(record.record_id, indexed)
100 self._exact_hash_to_ids.setdefault(indexed.exact_hash, set()).add(record.record_id) 100
101 def add_normalized_record(self, record: LyricRecord, normalized: NormalizedLyrics) -> None:
102 """Add a record when normalized lyrics have already been computed."""
103 indexed = self._index_normalized(record, normalized)
104 self._add_indexed(record.record_id, indexed)
105
106 def _add_indexed(self, record_id: str, indexed: _IndexedRecord) -> None:
107 self._records[record_id] = indexed
108 self._exact_hash_to_ids.setdefault(indexed.exact_hash, set()).add(record_id)
101 for line in indexed.normalized.unique_lines: 109 for line in indexed.normalized.unique_lines:
102 if len(line) >= 4: 110 if len(line) >= 4:
103 self._line_to_ids.setdefault(line, set()).add(record.record_id) 111 self._line_to_ids.setdefault(line, set()).add(record_id)
104 for token in indexed.tokens: 112 for token in indexed.tokens:
105 self._token_to_ids.setdefault(token, set()).add(record.record_id) 113 self._token_to_ids.setdefault(token, set()).add(record_id)
106 for token in indexed.fallback_tokens: 114 for token in indexed.fallback_tokens:
107 self._token_to_ids.setdefault(token, set()).add(record.record_id) 115 self._token_to_ids.setdefault(token, set()).add(record_id)
108 self._lsh.add(record.record_id, indexed.signature) 116 self._lsh.add(record_id, indexed.signature)
109 117
110 def save(self, path: str | Path) -> None: 118 def save(self, path: str | Path) -> None:
111 """Persist the in-memory index for later checks.""" 119 """Persist the in-memory index for later checks."""
...@@ -187,6 +195,9 @@ class DuplicateChecker: ...@@ -187,6 +195,9 @@ class DuplicateChecker:
187 195
188 def _index(self, record: LyricRecord) -> _IndexedRecord: 196 def _index(self, record: LyricRecord) -> _IndexedRecord:
189 normalized = normalize_lyrics(record.lyrics) 197 normalized = normalize_lyrics(record.lyrics)
198 return self._index_normalized(record, normalized)
199
200 def _index_normalized(self, record: LyricRecord, normalized: NormalizedLyrics) -> _IndexedRecord:
190 tokens = lyric_tokens(normalized) 201 tokens = lyric_tokens(normalized)
191 primary_tokens = lyric_tokens(normalized, lines=normalized.primary_lines) 202 primary_tokens = lyric_tokens(normalized, lines=normalized.primary_lines)
192 translation_tokens = lyric_tokens(normalized, lines=normalized.translation_lines) 203 translation_tokens = lyric_tokens(normalized, lines=normalized.translation_lines)
......
...@@ -5,6 +5,7 @@ from __future__ import annotations ...@@ -5,6 +5,7 @@ from __future__ import annotations
5 import argparse 5 import argparse
6 import csv 6 import csv
7 import json 7 import json
8 import sys
8 from pathlib import Path 9 from pathlib import Path
9 10
10 from lyric_dedup.checker import DuplicateChecker 11 from lyric_dedup.checker import DuplicateChecker
...@@ -50,7 +51,8 @@ def main() -> None: ...@@ -50,7 +51,8 @@ def main() -> None:
50 generate.add_argument("--size", type=int, default=100) 51 generate.add_argument("--size", type=int, default=100)
51 generate.add_argument("--positive-ratio", type=float, default=0.3) 52 generate.add_argument("--positive-ratio", type=float, default=0.3)
52 generate.add_argument("--seed", type=int, default=20260602) 53 generate.add_argument("--seed", type=int, default=20260602)
53 generate.add_argument("--index", default="", help="optional existing index for hard-negative generation") 54 generate.add_argument("--index", default="", help="optional source index path recorded in the manifest")
55 generate.add_argument("--eval-index", default="", help="output index built from non-holdout records for this eval set")
54 56
55 args = parser.parse_args() 57 args = parser.parse_args()
56 if args.command == "build-index": 58 if args.command == "build-index":
...@@ -77,6 +79,7 @@ def main() -> None: ...@@ -77,6 +79,7 @@ def main() -> None:
77 positive_ratio=args.positive_ratio, 79 positive_ratio=args.positive_ratio,
78 seed=args.seed, 80 seed=args.seed,
79 index_path=Path(args.index) if args.index else None, 81 index_path=Path(args.index) if args.index else None,
82 eval_index_path=Path(args.eval_index) if args.eval_index else None,
80 ) 83 )
81 print(json.dumps(summary, ensure_ascii=False)) 84 print(json.dumps(summary, ensure_ascii=False))
82 85
...@@ -155,52 +158,58 @@ def evaluate_csv( ...@@ -155,52 +158,58 @@ def evaluate_csv(
155 positive_decisions: set[str], 158 positive_decisions: set[str],
156 max_candidates: int, 159 max_candidates: int,
157 ) -> None: 160 ) -> None:
161 _progress(f"load index: {index_path}")
158 checker = DuplicateChecker.load(index_path) 162 checker = DuplicateChecker.load(index_path)
159 rows: list[dict[str, object]] = [] 163 rows: list[dict[str, object]] = []
164 total = _csv_data_row_count(csv_path)
165 _progress(f"evaluate csv: 0/{total}")
166 out_path.parent.mkdir(parents=True, exist_ok=True)
160 with csv_path.open(encoding="utf-8-sig", newline="") as file: 167 with csv_path.open(encoding="utf-8-sig", newline="") as file:
161 reader = csv.DictReader(file) 168 reader = csv.DictReader(file)
162 if reader.fieldnames is None: 169 if reader.fieldnames is None:
163 raise ValueError("评估 CSV 需要表头") 170 raise ValueError("评估 CSV 需要表头")
164 for row_number, row in enumerate(reader, start=2): 171 fieldnames = [
165 sample_id = row.get("id") or row.get("sample_id") or str(row_number) 172 "id",
166 record, source = _record_from_eval_row(row, csv_path=csv_path, base_dir=base_dir) 173 "source",
167 expected_duplicate = _parse_expected(row.get("expected") or row.get("label") or row.get("target")) 174 "expected_duplicate",
168 result = checker.check_record(record, max_candidates=max_candidates) 175 "decision",
169 predicted_duplicate = result.decision.value in positive_decisions 176 "predicted_duplicate",
170 best = result.candidates[0] if result.candidates else None 177 "correct",
171 rows.append( 178 "confidence",
172 { 179 "reason",
173 "id": sample_id, 180 "best_candidate_id",
174 "source": source, 181 "best_candidate_decision",
175 "expected_duplicate": expected_duplicate, 182 "best_candidate_confidence",
176 "decision": result.decision.value, 183 "best_candidate_jaccard",
177 "predicted_duplicate": predicted_duplicate, 184 "best_candidate_line_coverage",
178 "correct": expected_duplicate == predicted_duplicate, 185 "best_candidate_primary_jaccard",
179 "confidence": result.confidence, 186 "best_candidate_primary_line_coverage",
180 "reason": result.reason, 187 "best_candidate_translation_jaccard",
181 "best_candidate_id": best.record_id if best else "", 188 "best_candidate_translation_line_coverage",
182 "best_candidate_decision": best.decision.value if best else "", 189 "best_candidate_reason",
183 "best_candidate_confidence": best.confidence if best else "", 190 "matched_unique_lines",
184 "best_candidate_jaccard": best.jaccard if best else "", 191 ]
185 "best_candidate_line_coverage": best.line_coverage if best else "", 192 with out_path.open("w", encoding="utf-8", newline="") as out_file:
186 "best_candidate_primary_jaccard": best.primary_jaccard if best else "", 193 writer = csv.DictWriter(out_file, fieldnames=fieldnames)
187 "best_candidate_primary_line_coverage": best.primary_line_coverage if best else "",
188 "best_candidate_translation_jaccard": best.translation_jaccard if best else "",
189 "best_candidate_translation_line_coverage": best.translation_line_coverage if best else "",
190 "best_candidate_reason": best.reason if best else "",
191 "matched_unique_lines": " | ".join(best.matched_unique_lines) if best else "",
192 }
193 )
194
195 out_path.parent.mkdir(parents=True, exist_ok=True)
196 with out_path.open("w", encoding="utf-8", newline="") as file:
197 writer = csv.DictWriter(file, fieldnames=list(rows[0].keys()) if rows else ["id"])
198 writer.writeheader() 194 writer.writeheader()
199 writer.writerows(rows) 195 for index, row in enumerate(reader, start=1):
196 row_out = _evaluate_row(
197 row,
198 row_number=index + 1,
199 checker=checker,
200 csv_path=csv_path,
201 base_dir=base_dir,
202 positive_decisions=positive_decisions,
203 max_candidates=max_candidates,
204 )
205 rows.append(row_out)
206 writer.writerow(row_out)
207 _progress_count("evaluate csv", index, total, step=1000)
200 208
201 summary = _evaluation_summary(rows, positive_decisions=positive_decisions, out_path=out_path) 209 summary = _evaluation_summary(rows, positive_decisions=positive_decisions, out_path=out_path)
202 summary_path = out_path.with_suffix(out_path.suffix + ".summary.json") 210 summary_path = out_path.with_suffix(out_path.suffix + ".summary.json")
203 summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") 211 summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
212 _progress("evaluation complete")
204 print(json.dumps(summary, ensure_ascii=False)) 213 print(json.dumps(summary, ensure_ascii=False))
205 214
206 215
...@@ -229,6 +238,45 @@ def _result_to_dict(result, *, source: str) -> dict[str, object]: ...@@ -229,6 +238,45 @@ def _result_to_dict(result, *, source: str) -> dict[str, object]:
229 } 238 }
230 239
231 240
241 def _evaluate_row(
242 row: dict[str, str],
243 *,
244 row_number: int,
245 checker: DuplicateChecker,
246 csv_path: Path,
247 base_dir: Path | None,
248 positive_decisions: set[str],
249 max_candidates: int,
250 ) -> dict[str, object]:
251 sample_id = row.get("id") or row.get("sample_id") or str(row_number)
252 record, source = _record_from_eval_row(row, csv_path=csv_path, base_dir=base_dir)
253 expected_duplicate = _parse_expected(row.get("expected") or row.get("label") or row.get("target"))
254 result = checker.check_record(record, max_candidates=max_candidates)
255 predicted_duplicate = result.decision.value in positive_decisions
256 best = result.candidates[0] if result.candidates else None
257 return {
258 "id": sample_id,
259 "source": source,
260 "expected_duplicate": expected_duplicate,
261 "decision": result.decision.value,
262 "predicted_duplicate": predicted_duplicate,
263 "correct": expected_duplicate == predicted_duplicate,
264 "confidence": result.confidence,
265 "reason": result.reason,
266 "best_candidate_id": best.record_id if best else "",
267 "best_candidate_decision": best.decision.value if best else "",
268 "best_candidate_confidence": best.confidence if best else "",
269 "best_candidate_jaccard": best.jaccard if best else "",
270 "best_candidate_line_coverage": best.line_coverage if best else "",
271 "best_candidate_primary_jaccard": best.primary_jaccard if best else "",
272 "best_candidate_primary_line_coverage": best.primary_line_coverage if best else "",
273 "best_candidate_translation_jaccard": best.translation_jaccard if best else "",
274 "best_candidate_translation_line_coverage": best.translation_line_coverage if best else "",
275 "best_candidate_reason": best.reason if best else "",
276 "matched_unique_lines": " | ".join(best.matched_unique_lines) if best else "",
277 }
278
279
232 def _lyrics_from_eval_row(row: dict[str, str], *, csv_path: Path, base_dir: Path | None) -> tuple[str, str]: 280 def _lyrics_from_eval_row(row: dict[str, str], *, csv_path: Path, base_dir: Path | None) -> tuple[str, str]:
233 lyrics = (row.get("lyrics") or "").strip() 281 lyrics = (row.get("lyrics") or "").strip()
234 if lyrics: 282 if lyrics:
...@@ -322,5 +370,23 @@ def _evaluation_summary( ...@@ -322,5 +370,23 @@ def _evaluation_summary(
322 } 370 }
323 371
324 372
373 def _csv_data_row_count(csv_path: Path) -> int:
374 with csv_path.open(encoding="utf-8-sig", newline="") as file:
375 reader = csv.reader(file)
376 next(reader, None)
377 return sum(1 for _ in reader)
378
379
380 def _progress(message: str) -> None:
381 print(f"[eval] {message}", file=sys.stderr, flush=True)
382
383
384 def _progress_count(label: str, current: int, total: int, *, step: int = 1000) -> None:
385 if total <= 0:
386 return
387 if current == 1 or current == total or current % step == 0:
388 _progress(f"{label}: {current}/{total}")
389
390
325 if __name__ == "__main__": 391 if __name__ == "__main__":
326 main() 392 main()
......
...@@ -308,9 +308,11 @@ def test_generated_eval_set_uses_stratified_production_mix(tmp_path) -> None: ...@@ -308,9 +308,11 @@ def test_generated_eval_set_uses_stratified_production_mix(tmp_path) -> None:
308 assert manifest["library_files"] == 12 308 assert manifest["library_files"] == 12
309 assert manifest["sample_size"] == 30 309 assert manifest["sample_size"] == 30
310 assert manifest["unique_source_records"] > 1 310 assert manifest["unique_source_records"] > 1
311 assert manifest["holdout_records"] > 1
312 assert (tmp_path / "generated" / "eval.csv.index.pkl").exists()
311 assert "positive_full_duplicate" in manifest["plan"] 313 assert "positive_full_duplicate" in manifest["plan"]
314 assert "negative_real_holdout_full_song" in negative_types
312 assert "negative_fragment" in negative_types 315 assert "negative_fragment" in negative_types
313 assert "negative_hard_candidate" in negative_types
314 assert all(row["expected"] == "不应去重" for row in rows if row["sample_type"].startswith("negative_")) 316 assert all(row["expected"] == "不应去重" for row in rows if row["sample_type"].startswith("negative_"))
315 317
316 318
......