Commit 49008962 4900896283e6b52190437fd467f52ab75caf2530 by 沈秋雨

新增 PostgreSQL 去重检索链路与 hard 评估集支持

- 新增 PostgreSQL 导入脚本、评估脚本和 schema 定义,支持基于 exact_hash、pg_trgm 和行级 hash 的三层召回策略
- 评估 CLI 新增 hard profile,覆盖错别字、OCR 错误、整段翻译、medley 片段等更贴近业务边界的场景
- 调整 checker.py 复核阈值与匹配理由文案,优化翻译行相似与仅副歌重复场景的判定逻辑
- 同步更新 README、TEST_WORKFLOW 和单元测试

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent ba39ce6a
1 # PostgreSQL 迁移说明
2
3 本文档说明为什么要从 `.pkl` 索引迁移到 PostgreSQL、PostgreSQL 在这个项目里承担什么角色,以及从本地初始化到导入曲库的具体步骤。
4
5 ## 1. 为什么迁移
6
7 当前 `outputs/indexes/library_lyrics.pkl` 是一个 Python `DuplicateChecker` 对象快照。它适合实验和离线评估,但不适合线上服务:
8
9 - 增量新增、删除、回滚不方便。
10 - 多实例部署时很难同步同一个 `.pkl`
11 - 没有数据库事务、审计、版本管理。
12 - 索引结构绑定 Python pickle,不利于运维和长期维护。
13
14 PostgreSQL 的目标不是替代全部判重逻辑,而是替代“数据存储 + 候选召回”部分。最终判定仍保留在 Python 服务层。
15
16 ## 2. PostgreSQL 在本项目里的角色
17
18 推荐分工:
19
20 ```text
21 PostgreSQL:
22 保存原始歌词、规范化歌词、行级特征
23 用 exact_hash 做精确召回
24 用 pg_trgm 做文本近似召回
25 用 lyric_lines 做行级重合召回
26
27 Python:
28 normalize_lyrics
29 计算 Jaccard / line coverage
30 处理翻译行、副歌碰撞、片段保护
31 最终输出 duplicate / review / new
32 ```
33
34 不要让 PostgreSQL 的 `similarity()` 直接决定 `duplicate`。数据库只负责找候选,最终判定仍应复用当前算法规则。
35
36 ## 3. pkl 索引和 PostgreSQL 的对应关系
37
38 当前 `.pkl` 里主要有:
39
40 ```text
41 _records 每首歌词的完整索引记录
42 _exact_hash_to_ids exact hash -> record ids
43 _line_to_ids normalized line -> record ids
44 _token_to_ids n-gram token -> record ids
45 _lsh MinHash LSH 桶
46 ```
47
48 迁移到 PostgreSQL 后:
49
50 ```text
51 lyrics.exact_hash 对应 _exact_hash_to_ids
52 lyrics.primary_text + pg_trgm 对应近似文本召回
53 lyric_lines.line_hash 对应行级倒排召回
54 lyrics / lyric_lines 对应 _records 的可持久化部分
55 ```
56
57 第一阶段不迁移 MinHash LSH。先用 `exact_hash + pg_trgm + line_hash` 验证召回效果。
58
59 ## 4. 本地 PostgreSQL 基础
60
61 PostgreSQL 是关系型数据库,常用概念:
62
63 ```text
64 database 数据库,例如 lyric_dedup
65 schema 命名空间,默认 public
66 table 表,例如 lyrics
67 index 索引,用来加速查询
68 extension 扩展,例如 pg_trgm
69 DSN 连接字符串
70 ```
71
72 本机数据库连接字符串常见写法:
73
74 ```bash
75 postgresql:///lyric_dedup
76 ```
77
78 含义是:使用当前系统用户名,连接本机 PostgreSQL 的 `lyric_dedup` 数据库。
79
80 如果有用户名密码:
81
82 ```bash
83 postgresql://postgres:postgres@localhost:5432/lyric_dedup
84 ```
85
86 ## 5. 当前已新增脚本
87
88 项目里已经加入:
89
90 ```text
91 scripts/postgres_schema.sql
92 scripts/init_postgres.py
93 scripts/import_library_postgres.py
94 scripts/evaluate_postgres.py
95 ```
96
97 用途:
98
99 ```text
100 postgres_schema.sql 建表、建索引、启用 pg_trgm
101 init_postgres.py 自动执行 schema SQL
102 import_library_postgres.py 扫描 data/library,规范化后导入 PostgreSQL
103 evaluate_postgres.py 使用 PostgreSQL 召回候选并评测 CSV
104 ```
105
106 ## 6. 安装依赖
107
108 当前 Python 环境需要安装 PostgreSQL 驱动:
109
110 ```bash
111 python -m pip install 'psycopg[binary]'
112 ```
113
114 如果你使用 conda 环境,确认命令运行在当前项目所用的 `(base)` 或目标环境里。
115
116 验证:
117
118 ```bash
119 python - <<'PY'
120 import psycopg
121 print(psycopg.__version__)
122 PY
123 ```
124
125 ## 7. 创建数据库
126
127 你已经执行过:
128
129 ```bash
130 createdb lyric_dedup
131 ```
132
133 如果需要确认数据库存在:
134
135 ```bash
136 psql -l | grep lyric_dedup
137 ```
138
139 进入数据库:
140
141 ```bash
142 psql postgresql:///lyric_dedup
143 ```
144
145 退出 `psql`
146
147 ```text
148 \q
149 ```
150
151 ## 8. 初始化表结构
152
153 执行:
154
155 ```bash
156 python scripts/init_postgres.py \
157 --dsn postgresql:///lyric_dedup
158 ```
159
160 它会执行:
161
162 ```sql
163 create extension if not exists pg_trgm;
164 create table if not exists lyrics (...);
165 create table if not exists lyric_lines (...);
166 create index if not exists ...;
167 ```
168
169 成功输出类似:
170
171 ```text
172 initialized schema from scripts/postgres_schema.sql
173 ```
174
175 可以检查表:
176
177 ```bash
178 psql postgresql:///lyric_dedup -c '\dt'
179 ```
180
181 检查扩展:
182
183 ```bash
184 psql postgresql:///lyric_dedup -c 'select * from pg_extension;'
185 ```
186
187 ## 9. 表结构说明
188
189 ### lyrics
190
191 保存每首歌词的主记录:
192
193 ```text
194 record_id 当前文件生成的稳定 id
195 source_path 原始文件路径
196 title / artist 从文件名解析的元数据
197 raw_text 原始歌词
198 normalized_text 清洗后的全文
199 primary_text 原文行拼接文本,主要用于自动判重
200 translation_text 翻译行拼接文本
201 exact_hash 规范化原文 hash
202 split_confidence 翻译拆分置信度
203 split_reason 翻译拆分原因
204 line_count 有效歌词行数
205 deleted_at 软删除字段
206 ```
207
208 ### lyric_lines
209
210 保存行级特征:
211
212 ```text
213 lyric_id 对应 lyrics.id
214 role primary / translation / unknown
215 line_no 行号
216 normalized_line 规范化歌词行
217 line_hash 行 hash
218 ```
219
220 用途:快速找“哪些歌包含相同行”。
221
222 ## 10. 小批量导入测试
223
224 先导入 1000 条,确认环境和 schema 都正常:
225
226 ```bash
227 python scripts/import_library_postgres.py \
228 --dsn postgresql:///lyric_dedup \
229 --lyrics-dir data/library \
230 --limit 1000
231 ```
232
233 导入脚本默认会在导入结束后执行一次低风险库内去重:
234
235 ```text
236 exact_hash 完全一致的记录只保留一条,其余记录 soft delete,即设置 lyrics.deleted_at。
237 ```
238
239 重复清理报告默认写到:
240
241 ```text
242 outputs/results/postgres_exact_duplicates.csv
243 ```
244
245 如果只是想导入,不做 exact 去重:
246
247 ```bash
248 python scripts/import_library_postgres.py \
249 --dsn postgresql:///lyric_dedup \
250 --lyrics-dir data/library \
251 --limit 1000 \
252 --skip-dedup-exact
253 ```
254
255 查看数量:
256
257 ```bash
258 psql postgresql:///lyric_dedup -c 'select count(*) from lyrics;'
259 psql postgresql:///lyric_dedup -c 'select count(*) from lyric_lines;'
260 ```
261
262 查看几条数据:
263
264 ```bash
265 psql postgresql:///lyric_dedup -c \
266 'select id, record_id, title, artist, line_count from lyrics limit 5;'
267 ```
268
269 ## 11. 全量导入
270
271 确认小批量没问题后,导入全量:
272
273 ```bash
274 python scripts/import_library_postgres.py \
275 --dsn postgresql:///lyric_dedup \
276 --lyrics-dir data/library
277 ```
278
279 脚本会显示进度:
280
281 ```text
282 [pg-import] files: 70295
283 [pg-import] import: 500/70295
284 ...
285 ```
286
287 导入是 upsert,同一个 `record_id` 再导入会更新,不会重复插入。
288
289 如果想额外生成“高行覆盖率近重复候选”报告,但不自动删除:
290
291 ```bash
292 python scripts/import_library_postgres.py \
293 --dsn postgresql:///lyric_dedup \
294 --lyrics-dir data/library \
295 --line-duplicate-report outputs/results/postgres_line_duplicates.csv \
296 --line-coverage-threshold 0.95
297 ```
298
299 注意:行覆盖率近重复报告可能较慢,且只用于抽查。当前脚本不会自动 soft delete 这些近重复候选。
300
301 ## 12. 基础 SQL 验证
302
303 ### exact hash 重复
304
305 找规范化 hash 重复:
306
307 ```bash
308 psql postgresql:///lyric_dedup -c "
309 select exact_hash, count(*)
310 from lyrics
311 where deleted_at is null
312 group by exact_hash
313 having count(*) > 1
314 order by count(*) desc
315 limit 20;
316 "
317 ```
318
319 如果导入时没有加 `--skip-dedup-exact`,这里理论上不应该再出现 active exact 重复;已经清理的重复记录可以这样查看:
320
321 ```bash
322 psql postgresql:///lyric_dedup -c "
323 select count(*)
324 from lyrics
325 where deleted_at is not null;
326 "
327 ```
328
329 ### pg_trgm 相似查询
330
331 测试 `pg_trgm`
332
333 ```bash
334 psql postgresql:///lyric_dedup -c "
335 select id, title, similarity(primary_text, '我爱你在每个夜里') as sim
336 from lyrics
337 where primary_text % '我爱你在每个夜里'
338 order by sim desc
339 limit 10;
340 "
341 ```
342
343 ### 行级重合
344
345 找某一行出现在哪些歌:
346
347 ```bash
348 psql postgresql:///lyric_dedup -c "
349 select l.id, l.title, ll.normalized_line
350 from lyric_lines ll
351 join lyrics l on l.id = ll.lyric_id
352 where ll.normalized_line = '我爱你在每个夜里'
353 limit 20;
354 "
355 ```
356
357 ## 13. 后续查重查询应该怎么做
358
359 未来 PostgreSQL 版查重流程:
360
361 ```text
362 1. Python 读取新增歌词
363 2. normalize_lyrics
364 3. SQL exact_hash 召回
365 4. SQL pg_trgm 召回
366 5. SQL lyric_lines 行级召回
367 6. 合并候选 id
368 7. 拉候选 normalized 数据
369 8. Python 复用当前打分规则
370 9. 输出 duplicate / review / new
371 ```
372
373 示意 SQL:
374
375 ```sql
376 select id
377 from lyrics
378 where exact_hash = $1
379 and deleted_at is null;
380 ```
381
382 ```sql
383 select id, similarity(primary_text, $1) as sim
384 from lyrics
385 where deleted_at is null
386 and primary_text % $1
387 order by sim desc
388 limit 200;
389 ```
390
391 ```sql
392 select lyric_id, count(*) as matched_lines
393 from lyric_lines
394 where role = 'primary'
395 and line_hash = any($1)
396 group by lyric_id
397 order by matched_lines desc
398 limit 200;
399 ```
400
401 ## 14. 增量更新设计
402
403 新增一首歌:
404
405 ```text
406 1. normalize
407 2. 用 PostgreSQL 召回候选
408 3. Python 判定
409 4. duplicate: 拒绝或关联已有记录
410 5. review: 进入人工复核
411 6. new: 写入 lyrics 和 lyric_lines
412 ```
413
414 删除一首歌:
415
416 ```sql
417 update lyrics
418 set deleted_at = now(), updated_at = now()
419 where id = ...;
420 ```
421
422 不建议物理删除,除非确认不需要审计。
423
424 更新一首歌:
425
426 ```text
427 1. 更新 lyrics.raw_text / normalized_text / primary_text / exact_hash
428 2. 删除旧 lyric_lines
429 3. 插入新 lyric_lines
430 4. 整个过程放在一个事务里
431 ```
432
433 ## 15. PostgreSQL 版评测
434
435 评测仍然需要先生成测试集。测试集是“输入样本 + 期望标签”,PostgreSQL 版评测只负责用 PostgreSQL 数据库召回候选并计算指标。
436
437 如果还没有测试集,先生成:
438
439 ```bash
440 python -m lyric_dedup.cli generate-eval-set \
441 --library-dir data/library \
442 --lyrics-dir data/generated_eval/incoming \
443 --csv data/generated_eval/eval_5000.csv \
444 --size 5000 \
445 --positive-ratio 0.3
446 ```
447
448 然后跑 PostgreSQL 版评测:
449
450 ```bash
451 python scripts/evaluate_postgres.py \
452 --dsn postgresql:///lyric_dedup \
453 --csv data/generated_eval/eval_5000.csv \
454 --base-dir data/generated_eval \
455 --out outputs/results/postgres_eval_5000.csv
456 ```
457
458 它会:
459
460 ```text
461 1. 对 eval 样本 normalize
462 2. 用 PostgreSQL exact_hash 召回
463 3. 用 pg_trgm primary_text 召回
464 4. 用 lyric_lines.line_hash 召回
465 5. 合并候选
466 6. 用 Python DuplicateChecker 对候选重新打分
467 7. 输出 duplicate / review / new 和指标
468 ```
469
470 如果想把 `review` 也算作“抓到可疑样本”:
471
472 ```bash
473 python scripts/evaluate_postgres.py \
474 --dsn postgresql:///lyric_dedup \
475 --csv data/generated_eval/eval_50000.csv \
476 --base-dir data/generated_eval \
477 --positive-decisions duplicate,review \
478 --out outputs/results/postgres_eval_50000_review_positive.csv
479 ```
480
481 可调参数:
482
483 ```text
484 --recall-limit 每类 SQL 召回最多取多少候选,默认 100
485 --enable-trgm 打开 pg_trgm 整段文本召回;默认关闭,避免评测过慢
486 --trgm-threshold pg_trgm 的 % 匹配阈值,默认 0.3,仅 --enable-trgm 时使用
487 --max-candidates 最终输出多少候选,默认 5
488 --statement-timeout-ms 单条 SQL 超时时间,默认 5000
489 ```
490
491 注意:当前 PostgreSQL 版是原型评测脚本。默认只用 `exact_hash + lyric_lines.line_hash` 召回,速度更可控。`pg_trgm` 可以作为补充召回,但整段歌词 trigram 查询在 5 万评测集上可能很慢,建议单独开小样本验证后再用于全量。
492
493 ## 16. 迁移验证标准
494
495 迁移不是导入完就结束。需要单独验证 PostgreSQL 版查重链路:
496
497 ```text
498 1. exact duplicate 是否能查到
499 2. punctuation / timestamp / platform noise 正例是否能召回
500 3. fragment / shared chorus 负例是否不会被直接判 duplicate
501 4. PostgreSQL 召回候选数量是否合理
502 5. PostgreSQL 版 evaluate 指标是否达到业务要求
503 ```
504
505 第一阶段目标:
506
507 ```text
508 PostgreSQL 负责召回,Python 仍负责判定。
509 ```
510
511 ## 17. 常见问题
512
513 ### 提示 `Missing dependency: psycopg`
514
515 运行:
516
517 ```bash
518 python -m pip install 'psycopg[binary]'
519 ```
520
521 ### 连接失败
522
523 检查 PostgreSQL 是否启动:
524
525 ```bash
526 pg_isready
527 ```
528
529 检查数据库是否存在:
530
531 ```bash
532 psql -l | grep lyric_dedup
533 ```
534
535 ### `pg_trgm` 创建失败
536
537 确认连接用户有创建 extension 权限。本机默认用户一般可以。
538
539 手动测试:
540
541 ```bash
542 psql postgresql:///lyric_dedup -c 'create extension if not exists pg_trgm;'
543 ```
544
545 ### 想清空重新导入
546
547 谨慎执行:
548
549 ```bash
550 psql postgresql:///lyric_dedup -c 'truncate lyric_lines, lyrics restart identity cascade;'
551 ```
552
553 然后重新运行导入脚本。
554
555 ## 18. 当前建议执行顺序
556
557 你现在已经完成:
558
559 ```bash
560 createdb lyric_dedup
561 ```
562
563 接下来执行:
564
565 ```bash
566 python -m pip install 'psycopg[binary]'
567 ```
568
569 ```bash
570 python scripts/init_postgres.py \
571 --dsn postgresql:///lyric_dedup
572 ```
573
574 ```bash
575 python scripts/import_library_postgres.py \
576 --dsn postgresql:///lyric_dedup \
577 --lyrics-dir data/library \
578 --limit 1000
579 ```
580
581 确认数量:
582
583 ```bash
584 psql postgresql:///lyric_dedup -c 'select count(*) from lyrics;'
585 ```
586
587 确认后全量导入:
588
589 ```bash
590 python scripts/import_library_postgres.py \
591 --dsn postgresql:///lyric_dedup \
592 --lyrics-dir data/library
593 ```
...@@ -85,15 +85,33 @@ python -m lyric_dedup.cli generate-eval-set \ ...@@ -85,15 +85,33 @@ python -m lyric_dedup.cli generate-eval-set \
85 --positive-ratio 0.3 85 --positive-ratio 0.3
86 ``` 86 ```
87 87
88 生成器的业务口径: 88 默认 `--profile standard` 生成常规生产评估集。也可以生成更贴近业务边界的 hard 集:
89
90 ```bash
91 python -m lyric_dedup.cli generate-eval-set \
92 --profile hard \
93 --library-dir data/library \
94 --lyrics-dir data/generated_eval/hard_incoming \
95 --csv data/generated_eval/eval_hard_5000.csv \
96 --eval-index data/generated_eval/eval_hard_5000.csv.index.pkl \
97 --size 5000 \
98 --positive-ratio 0.3
99 ```
100
101 standard 业务口径:
89 102
90 - 先扫描整个曲库,按有效歌词行数、语言类型、文件来源前缀做分层采样,不再按排序前缀取样。 103 - 先扫描整个曲库,按有效歌词行数、语言类型、文件来源前缀做分层采样,不再按排序前缀取样。
91 - `应去重` 样本只生成全曲歌词的样式变化,例如时间戳、标点、平台噪声、空行、重复副歌次数变化、附加中文翻译。 104 - `应去重` 样本只生成全曲歌词的样式变化,例如时间戳、标点、平台噪声、空行、重复副歌次数变化、附加中文翻译、少量错别字/英文拼写错误
92 - `不应去重` 样本以真实 holdout 完整歌词为主,也包含片段歌词、重复副歌碰撞、仅翻译相似、同主题新歌词、短歌词/占位边界样本。 105 - `不应去重` 样本以真实 holdout 完整歌词为主,也包含片段歌词、重复副歌碰撞、仅翻译相似、同主题新歌词、短歌词/占位边界样本。
93 - 片段歌词即使命中已有歌曲的一部分,也不应该输出 `duplicate`;最多进入 `review` 106 - 片段歌词即使命中已有歌曲的一部分,也不应该输出 `duplicate`;最多进入 `review`
94 - 生成器会额外写出 `--eval-index`,这个索引排除了 holdout 歌,评估生成 CSV 时应使用它。 107 - 生成器会额外写出 `--eval-index`,这个索引排除了 holdout 歌,评估生成 CSV 时应使用它。
95 - 同时会生成 `*.manifest.json`,记录 seed、曲库规模、holdout 数、样本类型分布、语言/来源分桶和样本来源覆盖数。 108 - 同时会生成 `*.manifest.json`,记录 seed、曲库规模、holdout 数、样本类型分布、语言/来源分桶和样本来源覆盖数。
96 109
110 hard 业务口径不故意制造反常输入,主要覆盖上线更容易踩边界的情况:
111
112 - `应去重`: 同曲平台版本噪声、较完整歌词缺少一段、整段中文翻译附加、较真实的录入/OCR 错别字、时间戳和平台元信息混合。
113 - `不应去重`: 真实 holdout 新歌、从 holdout 中优先挑选和曲库有行重合的近邻新歌、较长但不完整的单曲片段、多曲 medley/串烧式片段、重复副歌碰撞、仅翻译相似、短歌词边界。
114
97 先准备一个 CSV,例如 `data/eval/eval.csv` 115 先准备一个 CSV,例如 `data/eval/eval.csv`
98 116
99 ```csv 117 ```csv
......
...@@ -108,6 +108,20 @@ python -m lyric_dedup.cli generate-eval-set \ ...@@ -108,6 +108,20 @@ python -m lyric_dedup.cli generate-eval-set \
108 --positive-ratio 0.3 108 --positive-ratio 0.3
109 ``` 109 ```
110 110
111 如需生成更贴近业务边界的 hard 口径测试集:
112
113 ```bash
114 python -m lyric_dedup.cli generate-eval-set \
115 --profile hard \
116 --library-dir data/library \
117 --lyrics-dir data/generated_eval/hard_incoming \
118 --csv data/generated_eval/eval_hard_5000.csv \
119 --index outputs/indexes/library_lyrics.pkl \
120 --eval-index data/generated_eval/eval_hard_5000.csv.index.pkl \
121 --size 5000 \
122 --positive-ratio 0.3
123 ```
124
111 默认生产评估口径: 125 默认生产评估口径:
112 126
113 ```text 127 ```text
...@@ -120,7 +134,7 @@ python -m lyric_dedup.cli generate-eval-set \ ...@@ -120,7 +134,7 @@ python -m lyric_dedup.cli generate-eval-set \
120 业务口径: 134 业务口径:
121 135
122 ```text 136 ```text
123 positive_* = 应去重,全曲歌词样式变化 137 positive_* = 应去重,全曲歌词样式变化,包括少量错别字/英文拼写错误扰动
124 negative_real_holdout_full_song = 不应去重,完整真实歌词,已从评估索引中排除 138 negative_real_holdout_full_song = 不应去重,完整真实歌词,已从评估索引中排除
125 negative_fragment = 不应去重,单曲片段 139 negative_fragment = 不应去重,单曲片段
126 negative_shared_chorus = 不应去重,重复副歌碰撞 140 negative_shared_chorus = 不应去重,重复副歌碰撞
...@@ -129,6 +143,15 @@ negative_same_theme_synthetic = 不应去重,同主题新歌词 ...@@ -129,6 +143,15 @@ negative_same_theme_synthetic = 不应去重,同主题新歌词
129 edge_short_or_placeholder = 不应去重,短歌词/占位边界样本 143 edge_short_or_placeholder = 不应去重,短歌词/占位边界样本
130 ``` 144 ```
131 145
146 hard 口径额外强调真实业务边界,而不是故意制造反常难题:
147
148 ```text
149 positive_realistic_variant = 应去重,同曲平台版本噪声、较完整缺段、整段翻译附加、真实录入/OCR 错
150 negative_near_neighbor_holdout_full_song = 不应去重,和曲库有较多行重合的真实 holdout 新歌
151 negative_long_fragment = 不应去重,较长但不完整的单曲片段
152 negative_catalog_mashup = 不应去重,多首真实歌词片段组成的串烧/混剪式输入
153 ```
154
132 生成器会扫描整个曲库并按有效歌词行数、语言类型、文件来源前缀分层采样。它会分出一批 holdout 完整歌词作为真实新歌负样本,并生成一个排除 holdout 的评估索引。每次还会输出: 155 生成器会扫描整个曲库并按有效歌词行数、语言类型、文件来源前缀分层采样。它会分出一批 holdout 完整歌词作为真实新歌负样本,并生成一个排除 holdout 的评估索引。每次还会输出:
133 156
134 ```text 157 ```text
......
...@@ -5,7 +5,7 @@ from __future__ import annotations ...@@ -5,7 +5,7 @@ from __future__ import annotations
5 import hashlib 5 import hashlib
6 import pickle 6 import pickle
7 from dataclasses import dataclass 7 from dataclasses import dataclass
8 from enum import StrEnum 8 from enum import Enum
9 from pathlib import Path 9 from pathlib import Path
10 10
11 from lyric_dedup.minhash_lsh import MinHashConfig 11 from lyric_dedup.minhash_lsh import MinHashConfig
...@@ -16,7 +16,7 @@ from lyric_dedup.normalization import lyric_tokens ...@@ -16,7 +16,7 @@ from lyric_dedup.normalization import lyric_tokens
16 from lyric_dedup.normalization import normalize_lyrics 16 from lyric_dedup.normalization import normalize_lyrics
17 17
18 18
19 class DuplicateDecision(StrEnum): 19 class DuplicateDecision(str, Enum):
20 DUPLICATE = "duplicate" 20 DUPLICATE = "duplicate"
21 REVIEW = "review" 21 REVIEW = "review"
22 NEW = "new" 22 NEW = "new"
......
...@@ -53,6 +53,12 @@ def main() -> None: ...@@ -53,6 +53,12 @@ def main() -> None:
53 generate.add_argument("--seed", type=int, default=20260602) 53 generate.add_argument("--seed", type=int, default=20260602)
54 generate.add_argument("--index", default="", help="optional source index path recorded in the manifest") 54 generate.add_argument("--index", default="", help="optional source index path recorded in the manifest")
55 generate.add_argument("--eval-index", default="", help="output index built from non-holdout records for this eval set") 55 generate.add_argument("--eval-index", default="", help="output index built from non-holdout records for this eval set")
56 generate.add_argument(
57 "--profile",
58 choices=("standard", "hard"),
59 default="standard",
60 help="evaluation sample profile: standard production mix or harder business-realistic edge mix",
61 )
56 62
57 args = parser.parse_args() 63 args = parser.parse_args()
58 if args.command == "build-index": 64 if args.command == "build-index":
...@@ -80,6 +86,7 @@ def main() -> None: ...@@ -80,6 +86,7 @@ def main() -> None:
80 seed=args.seed, 86 seed=args.seed,
81 index_path=Path(args.index) if args.index else None, 87 index_path=Path(args.index) if args.index else None,
82 eval_index_path=Path(args.eval_index) if args.eval_index else None, 88 eval_index_path=Path(args.eval_index) if args.eval_index else None,
89 profile=args.profile,
83 ) 90 )
84 print(json.dumps(summary, ensure_ascii=False)) 91 print(json.dumps(summary, ensure_ascii=False))
85 92
......
...@@ -21,7 +21,7 @@ from lyric_dedup.normalization import fingerprint_text ...@@ -21,7 +21,7 @@ from lyric_dedup.normalization import fingerprint_text
21 from lyric_dedup.normalization import normalize_lyrics 21 from lyric_dedup.normalization import normalize_lyrics
22 22
23 23
24 DEFAULT_SAMPLE_MIX = { 24 STANDARD_SAMPLE_MIX = {
25 "positive_full_duplicate": 0.30, 25 "positive_full_duplicate": 0.30,
26 "negative_real_holdout_full_song": 0.40, 26 "negative_real_holdout_full_song": 0.40,
27 "negative_fragment": 0.10, 27 "negative_fragment": 0.10,
...@@ -30,6 +30,18 @@ DEFAULT_SAMPLE_MIX = { ...@@ -30,6 +30,18 @@ DEFAULT_SAMPLE_MIX = {
30 "negative_same_theme_synthetic": 0.05, 30 "negative_same_theme_synthetic": 0.05,
31 "edge_short_or_placeholder": 0.05, 31 "edge_short_or_placeholder": 0.05,
32 } 32 }
33 DEFAULT_SAMPLE_MIX = STANDARD_SAMPLE_MIX
34
35 HARD_SAMPLE_MIX = {
36 "positive_realistic_variant": 0.30,
37 "negative_real_holdout_full_song": 0.20,
38 "negative_near_neighbor_holdout_full_song": 0.20,
39 "negative_long_fragment": 0.15,
40 "negative_shared_chorus": 0.05,
41 "negative_translation_only": 0.04,
42 "negative_catalog_mashup": 0.04,
43 "edge_short_or_placeholder": 0.02,
44 }
33 45
34 46
35 def _progress(message: str) -> None: 47 def _progress(message: str) -> None:
...@@ -87,6 +99,7 @@ def generate_eval_set( ...@@ -87,6 +99,7 @@ def generate_eval_set(
87 seed: int = 20260602, 99 seed: int = 20260602,
88 index_path: Path | None = None, 100 index_path: Path | None = None,
89 eval_index_path: Path | None = None, 101 eval_index_path: Path | None = None,
102 profile: str = "standard",
90 ) -> dict[str, object]: 103 ) -> dict[str, object]:
91 """Generate a stratified production evaluation set. 104 """Generate a stratified production evaluation set.
92 105
...@@ -96,7 +109,10 @@ def generate_eval_set( ...@@ -96,7 +109,10 @@ def generate_eval_set(
96 if size <= 0: 109 if size <= 0:
97 raise ValueError("size must be positive") 110 raise ValueError("size must be positive")
98 111
99 _progress(f"start generation: size={size}, positive_ratio={positive_ratio}, seed={seed}") 112 if profile not in {"standard", "hard"}:
113 raise ValueError("profile must be 'standard' or 'hard'")
114
115 _progress(f"start generation: profile={profile}, size={size}, positive_ratio={positive_ratio}, seed={seed}")
100 rng = random.Random(seed) 116 rng = random.Random(seed)
101 profiles = profile_library(library_dir) 117 profiles = profile_library(library_dir)
102 if not profiles: 118 if not profiles:
...@@ -107,9 +123,9 @@ def generate_eval_set( ...@@ -107,9 +123,9 @@ def generate_eval_set(
107 _progress(f"clean output dir: {output_dir}") 123 _progress(f"clean output dir: {output_dir}")
108 _clean_generated_output_dir(output_dir) 124 _clean_generated_output_dir(output_dir)
109 125
110 plan = _sample_plan(size, positive_ratio=positive_ratio) 126 plan = _sample_plan(size, positive_ratio=positive_ratio, profile=profile)
111 _progress(f"sample plan: {plan}") 127 _progress(f"sample plan: {plan}")
112 holdout_count = min(plan["negative_real_holdout_full_song"], max(1, len(profiles) // 2)) 128 holdout_count = min(_holdout_plan_count(plan), max(1, len(profiles) // 2))
113 holdout_profiles = _stratified_unique_sample( 129 holdout_profiles = _stratified_unique_sample(
114 profiles, 130 profiles,
115 holdout_count, 131 holdout_count,
...@@ -122,6 +138,20 @@ def generate_eval_set( ...@@ -122,6 +138,20 @@ def generate_eval_set(
122 groups = _profile_groups(indexed_profiles) 138 groups = _profile_groups(indexed_profiles)
123 samples: list[GeneratedSample] = [] 139 samples: list[GeneratedSample] = []
124 140
141 if profile == "hard":
142 samples.extend(
143 _build_hard_samples(
144 plan,
145 groups=groups,
146 holdout_profiles=holdout_profiles,
147 indexed_profiles=indexed_profiles,
148 output_dir=output_dir,
149 csv_base=csv_path.parent,
150 rng=rng,
151 start_index=len(samples) + 1,
152 )
153 )
154 else:
125 _progress("build positive_full_duplicate samples") 155 _progress("build positive_full_duplicate samples")
126 samples.extend( 156 samples.extend(
127 _build_positive_samples( 157 _build_positive_samples(
...@@ -136,7 +166,7 @@ def generate_eval_set( ...@@ -136,7 +166,7 @@ def generate_eval_set(
136 _progress("build negative_real_holdout_full_song samples") 166 _progress("build negative_real_holdout_full_song samples")
137 samples.extend( 167 samples.extend(
138 _build_holdout_full_song_samples( 168 _build_holdout_full_song_samples(
139 holdout_profiles, 169 holdout_profiles[: plan["negative_real_holdout_full_song"]],
140 output_dir, 170 output_dir,
141 csv_path.parent, 171 csv_path.parent,
142 start_index=len(samples) + 1, 172 start_index=len(samples) + 1,
...@@ -226,6 +256,7 @@ def generate_eval_set( ...@@ -226,6 +256,7 @@ def generate_eval_set(
226 index_path=index_path, 256 index_path=index_path,
227 eval_index_path=eval_index_path, 257 eval_index_path=eval_index_path,
228 holdout_count=len(holdout_profiles), 258 holdout_count=len(holdout_profiles),
259 profile=profile,
229 ) 260 )
230 _progress("generation complete") 261 _progress("generation complete")
231 return manifest 262 return manifest
...@@ -264,14 +295,16 @@ def profile_library(library_dir: Path) -> list[LyricProfile]: ...@@ -264,14 +295,16 @@ def profile_library(library_dir: Path) -> list[LyricProfile]:
264 return profiles 295 return profiles
265 296
266 297
267 def _sample_plan(size: int, *, positive_ratio: float) -> dict[str, int]: 298 def _sample_plan(size: int, *, positive_ratio: float, profile: str) -> dict[str, int]:
268 positive_ratio = max(0.0, min(1.0, positive_ratio)) 299 positive_ratio = max(0.0, min(1.0, positive_ratio))
269 mix = dict(DEFAULT_SAMPLE_MIX) 300 mix = dict(HARD_SAMPLE_MIX if profile == "hard" else STANDARD_SAMPLE_MIX)
270 negative_total = sum(value for key, value in mix.items() if key != "positive_full_duplicate") 301 positive_key = "positive_realistic_variant" if profile == "hard" else "positive_full_duplicate"
271 mix["positive_full_duplicate"] = positive_ratio 302 negative_total = sum(value for key, value in mix.items() if key != positive_key)
303 mix[positive_key] = positive_ratio
272 for key in list(mix): 304 for key in list(mix):
273 if key != "positive_full_duplicate": 305 if key != positive_key:
274 mix[key] = (1.0 - positive_ratio) * (DEFAULT_SAMPLE_MIX[key] / negative_total) 306 base_mix = HARD_SAMPLE_MIX if profile == "hard" else STANDARD_SAMPLE_MIX
307 mix[key] = (1.0 - positive_ratio) * (base_mix[key] / negative_total)
275 308
276 plan = {key: int(size * value) for key, value in mix.items()} 309 plan = {key: int(size * value) for key, value in mix.items()}
277 remainder = size - sum(plan.values()) 310 remainder = size - sum(plan.values())
...@@ -283,6 +316,10 @@ def _sample_plan(size: int, *, positive_ratio: float) -> dict[str, int]: ...@@ -283,6 +316,10 @@ def _sample_plan(size: int, *, positive_ratio: float) -> dict[str, int]:
283 return plan 316 return plan
284 317
285 318
319 def _holdout_plan_count(plan: dict[str, int]) -> int:
320 return plan.get("negative_real_holdout_full_song", 0) + plan.get("negative_near_neighbor_holdout_full_song", 0)
321
322
286 def _profile_groups(profiles: list[LyricProfile]) -> dict[str, list[LyricProfile]]: 323 def _profile_groups(profiles: list[LyricProfile]) -> dict[str, list[LyricProfile]]:
287 normal = [profile for profile in profiles if profile.line_count >= 6] 324 normal = [profile for profile in profiles if profile.line_count >= 6]
288 edge = [profile for profile in profiles if profile.line_count <= 5] 325 edge = [profile for profile in profiles if profile.line_count <= 5]
...@@ -375,6 +412,7 @@ def _build_positive_samples( ...@@ -375,6 +412,7 @@ def _build_positive_samples(
375 ("positive_blank_line_noise", _add_blank_line_noise(lines)), 412 ("positive_blank_line_noise", _add_blank_line_noise(lines)),
376 ("positive_chorus_count_changed", _change_repeated_line_counts(lines)), 413 ("positive_chorus_count_changed", _change_repeated_line_counts(lines)),
377 ("positive_translation_added", _translation_added(lines)), 414 ("positive_translation_added", _translation_added(lines)),
415 ("positive_typo_noise", _add_typo_noise(lines, rng)),
378 ] 416 ]
379 sample_type, text = variants[offset % len(variants)] 417 sample_type, text = variants[offset % len(variants)]
380 index = start_index + offset 418 index = start_index + offset
...@@ -384,31 +422,181 @@ def _build_positive_samples( ...@@ -384,31 +422,181 @@ def _build_positive_samples(
384 return samples 422 return samples
385 423
386 424
425 def _build_hard_samples(
426 plan: dict[str, int],
427 *,
428 groups: dict[str, list[LyricProfile]],
429 holdout_profiles: list[LyricProfile],
430 indexed_profiles: list[LyricProfile],
431 output_dir: Path,
432 csv_base: Path,
433 rng: random.Random,
434 start_index: int,
435 ) -> list[GeneratedSample]:
436 samples: list[GeneratedSample] = []
437
438 _progress("build positive_realistic_variant samples")
439 samples.extend(
440 _build_realistic_positive_samples(
441 _stratified_sample(groups["normal"], plan["positive_realistic_variant"], rng),
442 output_dir,
443 csv_base,
444 rng,
445 start_index=start_index + len(samples),
446 )
447 )
448 _progress(f"built samples: {len(samples)}")
449
450 real_holdout_count = plan.get("negative_real_holdout_full_song", 0)
451 _progress("build negative_real_holdout_full_song samples")
452 samples.extend(
453 _build_holdout_full_song_samples(
454 holdout_profiles[:real_holdout_count],
455 output_dir,
456 csv_base,
457 start_index=start_index + len(samples),
458 )
459 )
460 _progress(f"built samples: {len(samples)}")
461
462 near_count = plan.get("negative_near_neighbor_holdout_full_song", 0)
463 _progress("build negative_near_neighbor_holdout_full_song samples")
464 near_holdouts = _near_neighbor_holdouts(
465 holdout_profiles[real_holdout_count:],
466 indexed_profiles,
467 near_count,
468 )
469 samples.extend(
470 _build_holdout_full_song_samples(
471 near_holdouts,
472 output_dir,
473 csv_base,
474 start_index=start_index + len(samples),
475 sample_type="negative_near_neighbor_holdout_full_song",
476 notes="full real holdout lyric selected for catalog line overlap with indexed songs",
477 )
478 )
479 _progress(f"built samples: {len(samples)}")
480
481 _progress("build negative_long_fragment samples")
482 samples.extend(
483 _build_fragment_samples(
484 _stratified_sample(groups["fragmentable"], plan.get("negative_long_fragment", 0), rng),
485 output_dir,
486 csv_base,
487 rng,
488 start_index=start_index + len(samples),
489 sample_type="negative_long_fragment",
490 long_fragment=True,
491 notes="realistic long partial lyric upload, not a full-song duplicate",
492 )
493 )
494 _progress(f"built samples: {len(samples)}")
495
496 _progress("build negative_shared_chorus samples")
497 samples.extend(
498 _build_shared_chorus_samples(
499 _stratified_sample(groups["normal"], plan.get("negative_shared_chorus", 0), rng),
500 output_dir,
501 csv_base,
502 rng,
503 start_index=start_index + len(samples),
504 )
505 )
506 _progress(f"built samples: {len(samples)}")
507
508 _progress("build negative_translation_only samples")
509 samples.extend(
510 _build_translation_only_samples(
511 _stratified_sample(groups["foreign"], plan.get("negative_translation_only", 0), rng),
512 output_dir,
513 csv_base,
514 rng,
515 start_index=start_index + len(samples),
516 )
517 )
518 _progress(f"built samples: {len(samples)}")
519
520 _progress("build negative_catalog_mashup samples")
521 samples.extend(
522 _build_catalog_mashup_samples(
523 _stratified_sample(groups["normal"], plan.get("negative_catalog_mashup", 0) * 3, rng),
524 plan.get("negative_catalog_mashup", 0),
525 output_dir,
526 csv_base,
527 rng,
528 start_index=start_index + len(samples),
529 )
530 )
531 _progress(f"built samples: {len(samples)}")
532
533 _progress("build edge_short_or_placeholder samples")
534 samples.extend(
535 _build_edge_samples(
536 _stratified_sample(groups["edge"], plan.get("edge_short_or_placeholder", 0), rng),
537 output_dir,
538 csv_base,
539 rng,
540 start_index=start_index + len(samples),
541 )
542 )
543 return samples
544
545
546 def _build_realistic_positive_samples(
547 profiles: list[LyricProfile],
548 output_dir: Path,
549 csv_base: Path,
550 rng: random.Random,
551 *,
552 start_index: int,
553 ) -> list[GeneratedSample]:
554 samples: list[GeneratedSample] = []
555 for offset, profile in enumerate(profiles):
556 content_lines = _content_lines(profile.raw_text)
557 primary_lines = list(profile.normalized.primary_lines or profile.normalized.unique_lines) or content_lines
558 variants = [
559 ("positive_platform_mixed_noise", _platform_mixed_noise(content_lines, rng)),
560 ("positive_near_full_missing_section", _near_full_missing_section(primary_lines, rng)),
561 ("positive_block_translation_added", _block_translation_added(primary_lines)),
562 ("positive_typo_and_punctuation_noise", _stronger_typo_and_punctuation_noise(content_lines, rng)),
563 ("positive_timestamped_platform_variant", _timestamped_platform_variant(content_lines)),
564 ("positive_chorus_count_changed", _change_repeated_line_counts(content_lines)),
565 ]
566 sample_type, text = variants[offset % len(variants)]
567 index = start_index + offset
568 path = _write_sample_file(output_dir, f"pos_{index:05d}_{sample_type}.txt", text)
569 samples.append(_sample_from_profile(index, path, csv_base, "应去重", sample_type, profile))
570 _progress_count("positive_realistic_variant", len(samples), len(profiles))
571 return samples
572
573
387 def _build_holdout_full_song_samples( 574 def _build_holdout_full_song_samples(
388 profiles: list[LyricProfile], 575 profiles: list[LyricProfile],
389 output_dir: Path, 576 output_dir: Path,
390 csv_base: Path, 577 csv_base: Path,
391 *, 578 *,
392 start_index: int, 579 start_index: int,
580 sample_type: str = "negative_real_holdout_full_song",
581 notes: str = "full real lyric held out from the generated eval index",
393 ) -> list[GeneratedSample]: 582 ) -> list[GeneratedSample]:
394 _progress("build negative_real_holdout_full_song samples")
395 samples: list[GeneratedSample] = [] 583 samples: list[GeneratedSample] = []
396 for offset, profile in enumerate(profiles): 584 for offset, profile in enumerate(profiles):
397 index = start_index + offset 585 index = start_index + offset
398 text = profile.raw_text 586 text = profile.raw_text
399 path = _write_sample_file(output_dir, f"neg_{index:05d}_negative_real_holdout_full_song.txt", text) 587 path = _write_sample_file(output_dir, f"neg_{index:05d}_{sample_type}.txt", text)
400 samples.append( 588 samples.append(
401 _sample_from_profile( 589 _sample_from_profile(
402 index, 590 index,
403 path, 591 path,
404 csv_base, 592 csv_base,
405 "不应去重", 593 "不应去重",
406 "negative_real_holdout_full_song", 594 sample_type,
407 profile, 595 profile,
408 notes="full real lyric held out from the generated eval index", 596 notes=notes,
409 ) 597 )
410 ) 598 )
411 _progress_count("negative_real_holdout_full_song", len(samples), len(profiles)) 599 _progress_count(sample_type, len(samples), len(profiles))
412 return samples 600 return samples
413 601
414 602
...@@ -446,25 +634,59 @@ def _build_fragment_samples( ...@@ -446,25 +634,59 @@ def _build_fragment_samples(
446 rng: random.Random, 634 rng: random.Random,
447 *, 635 *,
448 start_index: int, 636 start_index: int,
637 sample_type: str = "negative_fragment",
638 long_fragment: bool = False,
639 notes: str = "partial lyric fragment only",
449 ) -> list[GeneratedSample]: 640 ) -> list[GeneratedSample]:
450 samples: list[GeneratedSample] = [] 641 samples: list[GeneratedSample] = []
451 for offset, profile in enumerate(profiles): 642 for offset, profile in enumerate(profiles):
452 lines = list(profile.normalized.primary_lines or profile.normalized.unique_lines) 643 lines = list(profile.normalized.primary_lines or profile.normalized.unique_lines)
453 text = _single_song_fragment(lines, rng) 644 text = _long_song_fragment(lines, rng) if long_fragment else _single_song_fragment(lines, rng)
454 index = start_index + offset 645 index = start_index + offset
455 path = _write_sample_file(output_dir, f"neg_{index:05d}_negative_fragment.txt", text) 646 path = _write_sample_file(output_dir, f"neg_{index:05d}_{sample_type}.txt", text)
456 samples.append( 647 samples.append(
457 _sample_from_profile( 648 _sample_from_profile(
458 index, 649 index,
459 path, 650 path,
460 csv_base, 651 csv_base,
461 "不应去重", 652 "不应去重",
462 "negative_fragment", 653 sample_type,
463 profile, 654 profile,
464 notes="partial lyric fragment only", 655 notes=notes,
656 )
657 )
658 _progress_count(sample_type, len(samples), len(profiles))
659 return samples
660
661
662 def _build_catalog_mashup_samples(
663 profiles: list[LyricProfile],
664 count: int,
665 output_dir: Path,
666 csv_base: Path,
667 rng: random.Random,
668 *,
669 start_index: int,
670 ) -> list[GeneratedSample]:
671 samples: list[GeneratedSample] = []
672 if count <= 0 or not profiles:
673 return samples
674 for offset in range(count):
675 index = start_index + offset
676 picked = rng.sample(profiles, k=min(3, len(profiles)))
677 text = _catalog_mashup_text(picked, rng)
678 path = _write_sample_file(output_dir, f"neg_{index:05d}_negative_catalog_mashup.txt", text)
679 samples.append(
680 GeneratedSample(
681 sample_id=f"sample-{index:05d}",
682 file=str(path.relative_to(csv_base)),
683 expected="不应去重",
684 sample_type="negative_catalog_mashup",
685 source=" | ".join(str(profile.path) for profile in picked),
686 notes="medley-style partial lyric assembled from multiple catalog songs",
465 ) 687 )
466 ) 688 )
467 _progress_count("negative_fragment", len(samples), len(profiles)) 689 _progress_count("negative_catalog_mashup", len(samples), count)
468 return samples 690 return samples
469 691
470 692
...@@ -658,8 +880,10 @@ def _write_manifest( ...@@ -658,8 +880,10 @@ def _write_manifest(
658 index_path: Path | None, 880 index_path: Path | None,
659 eval_index_path: Path, 881 eval_index_path: Path,
660 holdout_count: int, 882 holdout_count: int,
883 profile: str,
661 ) -> dict[str, object]: 884 ) -> dict[str, object]:
662 manifest = { 885 manifest = {
886 "profile": profile,
663 "seed": seed, 887 "seed": seed,
664 "library_files": len(profiles), 888 "library_files": len(profiles),
665 "sample_size": len(samples), 889 "sample_size": len(samples),
...@@ -684,6 +908,38 @@ def _write_manifest( ...@@ -684,6 +908,38 @@ def _write_manifest(
684 return manifest 908 return manifest
685 909
686 910
911 def _near_neighbor_holdouts(
912 holdout_profiles: list[LyricProfile],
913 indexed_profiles: list[LyricProfile],
914 count: int,
915 ) -> list[LyricProfile]:
916 if count <= 0 or not holdout_profiles:
917 return []
918 if not indexed_profiles:
919 return holdout_profiles[:count]
920
921 line_to_indexed_count: Counter[str] = Counter()
922 for profile in indexed_profiles:
923 for line in set(profile.normalized.primary_lines or profile.normalized.unique_lines):
924 if len(line) >= 4:
925 line_to_indexed_count[line] += 1
926
927 scored: list[tuple[float, LyricProfile]] = []
928 for profile in holdout_profiles:
929 lines = set(profile.normalized.primary_lines or profile.normalized.unique_lines)
930 useful_lines = {line for line in lines if len(line) >= 4}
931 if not useful_lines:
932 score = 0.0
933 else:
934 shared = sum(1 for line in useful_lines if line_to_indexed_count[line] > 0)
935 common_weight = sum(min(line_to_indexed_count[line], 5) for line in useful_lines)
936 score = (shared / len(useful_lines)) + (common_weight / (len(useful_lines) * 20))
937 scored.append((score, profile))
938
939 scored.sort(key=lambda item: item[0], reverse=True)
940 return [profile for _, profile in scored[:count]]
941
942
687 def _content_lines(text: str) -> list[str]: 943 def _content_lines(text: str) -> list[str]:
688 lines = [line.strip() for line in text.splitlines() if line.strip()] 944 lines = [line.strip() for line in text.splitlines() if line.strip()]
689 return lines or [text.strip()] 945 return lines or [text.strip()]
...@@ -735,6 +991,18 @@ def _add_timestamps(lines: list[str]) -> str: ...@@ -735,6 +991,18 @@ def _add_timestamps(lines: list[str]) -> str:
735 return "\n".join(f"[00:{idx % 60:02d}.00]{line}" for idx, line in enumerate(lines, start=1)) 991 return "\n".join(f"[00:{idx % 60:02d}.00]{line}" for idx, line in enumerate(lines, start=1))
736 992
737 993
994 def _platform_mixed_noise(lines: list[str], rng: random.Random) -> str:
995 noisy = _add_blank_line_noise(lines).splitlines()
996 if noisy:
997 noisy = _add_punctuation_noise(noisy, rng).splitlines()
998 return "\n".join(["作词:未知", "歌词来自平台同步", *noisy, "未经著作权人许可 不得商业使用"])
999
1000
1001 def _timestamped_platform_variant(lines: list[str]) -> str:
1002 timestamped = _add_timestamps(lines).splitlines()
1003 return "\n".join(["[00:00.00]歌词贡献者:用户上传", *timestamped])
1004
1005
738 def _add_punctuation_noise(lines: list[str], rng: random.Random) -> str: 1006 def _add_punctuation_noise(lines: list[str], rng: random.Random) -> str:
739 marks = ["!", "?", "...", ",", "。"] 1007 marks = ["!", "?", "...", ",", "。"]
740 return "\n".join(f"{line}{rng.choice(marks)}" for line in lines) 1008 return "\n".join(f"{line}{rng.choice(marks)}" for line in lines)
...@@ -773,6 +1041,97 @@ def _translation_added(lines: list[str]) -> str: ...@@ -773,6 +1041,97 @@ def _translation_added(lines: list[str]) -> str:
773 return "\n".join(result) 1041 return "\n".join(result)
774 1042
775 1043
1044 def _block_translation_added(lines: list[str]) -> str:
1045 body = "\n".join(lines)
1046 translation_count = min(8, max(4, len(lines) // 4))
1047 translations = [_pseudo_translation(index) for index in range(1, translation_count + 1)]
1048 return "\n".join([body, "", *translations])
1049
1050
1051 def _near_full_missing_section(lines: list[str], rng: random.Random) -> str:
1052 if len(lines) <= 8:
1053 return "\n".join(lines)
1054 drop_count = max(1, min(max(1, len(lines) // 5), 8))
1055 start = rng.randrange(0, max(1, len(lines) - drop_count + 1))
1056 kept = lines[:start] + lines[start + drop_count :]
1057 return "\n".join(kept or lines)
1058
1059
1060 def _add_typo_noise(lines: list[str], rng: random.Random) -> str:
1061 if not lines:
1062 return ""
1063 result = list(lines)
1064 editable_indexes = [index for index, line in enumerate(result) if _can_typo_line(line)]
1065 if not editable_indexes:
1066 return "\n".join(result)
1067 typo_count = max(1, min(4, len(editable_indexes) // 8 or 1))
1068 for index in rng.sample(editable_indexes, k=min(typo_count, len(editable_indexes))):
1069 result[index] = _typo_line(result[index], rng)
1070 return "\n".join(result)
1071
1072
1073 def _stronger_typo_and_punctuation_noise(lines: list[str], rng: random.Random) -> str:
1074 if not lines:
1075 return ""
1076 result = _add_punctuation_noise(lines, rng).splitlines()
1077 editable_indexes = [index for index, line in enumerate(result) if _can_typo_line(line)]
1078 typo_count = max(1, min(8, len(editable_indexes) // 6 or 1))
1079 for index in rng.sample(editable_indexes, k=min(typo_count, len(editable_indexes))):
1080 result[index] = _typo_line(result[index], rng)
1081 return "\n".join(result)
1082
1083
1084 def _can_typo_line(line: str) -> bool:
1085 return bool(re.search(r"[A-Za-z]{4,}|[\u4e00-\u9fff]{4,}", line))
1086
1087
1088 def _typo_line(line: str, rng: random.Random) -> str:
1089 words = list(re.finditer(r"[A-Za-z]{4,}", line))
1090 if words and rng.random() < 0.65:
1091 match = rng.choice(words)
1092 typo = _typo_english_word(match.group(0), rng)
1093 return line[: match.start()] + typo + line[match.end() :]
1094 cjk_positions = [index for index, char in enumerate(line) if "\u4e00" <= char <= "\u9fff"]
1095 if cjk_positions:
1096 index = rng.choice(cjk_positions)
1097 return line[:index] + _typo_cjk_char(line[index]) + line[index + 1 :]
1098 return line
1099
1100
1101 def _typo_english_word(word: str, rng: random.Random) -> str:
1102 if len(word) <= 4 or rng.random() < 0.55:
1103 remove_at = rng.randrange(1, max(2, len(word) - 1))
1104 return word[:remove_at] + word[remove_at + 1 :]
1105 swap_at = rng.randrange(1, max(2, len(word) - 2))
1106 chars = list(word)
1107 chars[swap_at], chars[swap_at + 1] = chars[swap_at + 1], chars[swap_at]
1108 return "".join(chars)
1109
1110
1111 def _typo_cjk_char(char: str) -> str:
1112 replacements = {
1113 "你": "妳",
1114 "爱": "爰",
1115 "夜": "液",
1116 "里": "裏",
1117 "风": "凤",
1118 "雨": "兩",
1119 "听": "昕",
1120 "说": "説",
1121 "想": "相",
1122 "梦": "夣",
1123 "心": "芯",
1124 "光": "先",
1125 "城": "诚",
1126 "远": "迩",
1127 "回": "囬",
1128 "走": "赱",
1129 "海": "毎",
1130 "天": "夭",
1131 }
1132 return replacements.get(char, char)
1133
1134
776 def _single_song_fragment(lines: list[str], rng: random.Random) -> str: 1135 def _single_song_fragment(lines: list[str], rng: random.Random) -> str:
777 if len(lines) <= 4: 1136 if len(lines) <= 4:
778 return "\n".join(lines[: max(1, len(lines) // 2)]) 1137 return "\n".join(lines[: max(1, len(lines) // 2)])
...@@ -781,6 +1140,28 @@ def _single_song_fragment(lines: list[str], rng: random.Random) -> str: ...@@ -781,6 +1140,28 @@ def _single_song_fragment(lines: list[str], rng: random.Random) -> str:
781 return "\n".join(lines[start : start + fragment_len]) 1140 return "\n".join(lines[start : start + fragment_len])
782 1141
783 1142
1143 def _long_song_fragment(lines: list[str], rng: random.Random) -> str:
1144 if len(lines) <= 8:
1145 return _single_song_fragment(lines, rng)
1146 fragment_len = max(6, min(len(lines) - 1, int(len(lines) * rng.uniform(0.35, 0.60))))
1147 start = rng.randrange(0, max(1, len(lines) - fragment_len + 1))
1148 return "\n".join(lines[start : start + fragment_len])
1149
1150
1151 def _catalog_mashup_text(profiles: list[LyricProfile], rng: random.Random) -> str:
1152 sections: list[str] = []
1153 for profile in profiles:
1154 lines = list(profile.normalized.primary_lines or profile.normalized.unique_lines)
1155 if not lines:
1156 continue
1157 section_len = min(max(2, len(lines) // 8), 5)
1158 start = rng.randrange(0, max(1, len(lines) - section_len + 1))
1159 sections.extend(lines[start : start + section_len])
1160 if not sections:
1161 return _same_theme_synthetic(0, rng)
1162 return "\n".join(sections)
1163
1164
784 def _short_shared_snippet(lines: list[str], rng: random.Random) -> str: 1165 def _short_shared_snippet(lines: list[str], rng: random.Random) -> str:
785 snippet = rng.sample(lines, k=min(2, len(lines))) if lines else [] 1166 snippet = rng.sample(lines, k=min(2, len(lines))) if lines else []
786 synthetic = [ 1167 synthetic = [
......
1 # Test runner
2 pytest>=8.0
3
4 # PostgreSQL storage prototype
5 psycopg[binary]>=3.2
6
7 # Existing MySQL/COS lyric download utilities
8 pymysql>=1.1
9 cos-python-sdk-v5>=1.9
10 tqdm>=4.66
1 """Evaluate lyric duplicate checking with PostgreSQL-backed candidate recall."""
2
3 from __future__ import annotations
4
5 import argparse
6 import csv
7 import hashlib
8 import json
9 import sys
10 import time
11 from pathlib import Path
12 from typing import Any
13
14
15 PROJECT_ROOT = Path(__file__).resolve().parents[1]
16 if str(PROJECT_ROOT) not in sys.path:
17 sys.path.insert(0, str(PROJECT_ROOT))
18
19 from lyric_dedup.checker import DuplicateChecker
20 from lyric_dedup.checker import LyricRecord
21 from lyric_dedup.file_import import read_lyric_file
22 from lyric_dedup.file_import import record_from_file
23 from lyric_dedup.normalization import fingerprint_text
24 from lyric_dedup.normalization import normalize_lyrics
25
26
27 def main() -> None:
28 parser = argparse.ArgumentParser(description="Evaluate duplicate checking using PostgreSQL recall.")
29 parser.add_argument("--dsn", required=True)
30 parser.add_argument("--csv", required=True)
31 parser.add_argument("--out", required=True)
32 parser.add_argument("--base-dir", default="")
33 parser.add_argument("--positive-decisions", default="duplicate")
34 parser.add_argument("--max-candidates", type=int, default=5)
35 parser.add_argument("--recall-limit", type=int, default=100)
36 parser.add_argument("--enable-trgm", action="store_true", help="Enable pg_trgm full-text recall. Slower; exact + line recall is used by default.")
37 parser.add_argument("--trgm-threshold", type=float, default=0.3)
38 parser.add_argument("--statement-timeout-ms", type=int, default=5000)
39 parser.add_argument("--profile-every", type=int, default=100)
40 args = parser.parse_args()
41
42 psycopg = _import_psycopg()
43 csv_path = Path(args.csv)
44 out_path = Path(args.out)
45 base_dir = Path(args.base_dir) if args.base_dir else None
46 positive_decisions = {item.strip() for item in args.positive_decisions.split(",") if item.strip()}
47
48 total = _csv_data_row_count(csv_path)
49 rows: list[dict[str, object]] = []
50 profile_stats = _new_profile_stats()
51 out_path.parent.mkdir(parents=True, exist_ok=True)
52 _progress(f"evaluate postgres csv: 0/{total}")
53 with psycopg.connect(args.dsn) as conn:
54 with conn.cursor() as cursor:
55 cursor.execute("select set_config('statement_timeout', %s, false)", (str(args.statement_timeout_ms),))
56 cursor.execute("select set_config('pg_trgm.similarity_threshold', %s, false)", (str(args.trgm_threshold),))
57 with csv_path.open(encoding="utf-8-sig", newline="") as in_file, out_path.open(
58 "w", encoding="utf-8", newline=""
59 ) as out_file:
60 reader = csv.DictReader(in_file)
61 if reader.fieldnames is None:
62 raise ValueError("评估 CSV 需要表头")
63 writer = csv.DictWriter(out_file, fieldnames=_fieldnames())
64 writer.writeheader()
65 for index, row in enumerate(reader, start=1):
66 row_out = _evaluate_row(
67 conn,
68 row,
69 row_number=index + 1,
70 csv_path=csv_path,
71 base_dir=base_dir,
72 positive_decisions=positive_decisions,
73 max_candidates=args.max_candidates,
74 recall_limit=args.recall_limit,
75 enable_trgm=args.enable_trgm,
76 )
77 rows.append(row_out)
78 writer.writerow(row_out)
79 _progress_count("evaluate postgres csv", index, total, step=10)
80 _update_profile_stats(profile_stats, row_out)
81 if args.profile_every > 0 and index % args.profile_every == 0:
82 _progress(_format_profile_stats(profile_stats, index))
83
84 summary = _evaluation_summary(rows, positive_decisions=positive_decisions, out_path=out_path)
85 summary_path = out_path.with_suffix(out_path.suffix + ".summary.json")
86 summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
87 _progress("postgres evaluation complete")
88 print(json.dumps(summary, ensure_ascii=False))
89
90
91 def _evaluate_row(
92 conn: Any,
93 row: dict[str, str],
94 *,
95 row_number: int,
96 csv_path: Path,
97 base_dir: Path | None,
98 positive_decisions: set[str],
99 max_candidates: int,
100 recall_limit: int,
101 enable_trgm: bool,
102 ) -> dict[str, object]:
103 parse_started = time.perf_counter()
104 sample_id = row.get("id") or row.get("sample_id") or str(row_number)
105 record, source = _record_from_eval_row(row, csv_path=csv_path, base_dir=base_dir)
106 expected_duplicate = _parse_expected(row.get("expected") or row.get("label") or row.get("target"))
107 parse_ms = round((time.perf_counter() - parse_started) * 1000, 2)
108 candidates, timings = _recall_candidates(
109 conn,
110 record,
111 recall_limit=recall_limit,
112 enable_trgm=enable_trgm,
113 exclude_record_ids=_exclude_record_ids_for_eval_row(row),
114 )
115 rank_started = time.perf_counter()
116 result = _check_against_candidates(record, candidates, max_candidates=max_candidates)
117 rank_ms = round((time.perf_counter() - rank_started) * 1000, 2)
118 recall_ms = round(timings["exact_ms"] + timings["trgm_ms"] + timings["line_ms"], 2)
119 predicted_duplicate = result.decision.value in positive_decisions
120 best = result.candidates[0] if result.candidates else None
121 return {
122 "id": sample_id,
123 "source": source,
124 "expected_duplicate": expected_duplicate,
125 "decision": result.decision.value,
126 "predicted_duplicate": predicted_duplicate,
127 "correct": expected_duplicate == predicted_duplicate,
128 "confidence": result.confidence,
129 "reason": result.reason,
130 "candidate_count": len(candidates),
131 "parse_ms": parse_ms,
132 "recall_ms": recall_ms,
133 "exact_ms": timings["exact_ms"],
134 "trgm_ms": timings["trgm_ms"],
135 "line_ms": timings["line_ms"],
136 "rank_ms": rank_ms,
137 "best_candidate_id": best.record_id if best else "",
138 "best_candidate_decision": best.decision.value if best else "",
139 "best_candidate_confidence": best.confidence if best else "",
140 "best_candidate_jaccard": best.jaccard if best else "",
141 "best_candidate_line_coverage": best.line_coverage if best else "",
142 "best_candidate_primary_jaccard": best.primary_jaccard if best else "",
143 "best_candidate_primary_line_coverage": best.primary_line_coverage if best else "",
144 "best_candidate_translation_jaccard": best.translation_jaccard if best else "",
145 "best_candidate_translation_line_coverage": best.translation_line_coverage if best else "",
146 "best_candidate_reason": best.reason if best else "",
147 "matched_unique_lines": " | ".join(best.matched_unique_lines) if best else "",
148 }
149
150
151 def _recall_candidates(
152 conn: Any,
153 record: LyricRecord,
154 *,
155 recall_limit: int,
156 enable_trgm: bool,
157 exclude_record_ids: list[str],
158 ) -> tuple[list[LyricRecord], dict[str, float]]:
159 query_lyrics = _pg_text(record.lyrics) or ""
160 normalized = normalize_lyrics(query_lyrics)
161 exact_text = fingerprint_text(normalized)
162 exact_hash = hashlib.sha256(exact_text.encode("utf-8")).hexdigest()
163 primary_text = "\n".join(normalized.primary_lines)
164 line_hashes = [hashlib.sha256(line.encode("utf-8")).hexdigest() for line in normalized.primary_lines if line]
165 candidates: dict[str, LyricRecord] = {}
166 timings = {"exact_ms": 0.0, "trgm_ms": 0.0, "line_ms": 0.0}
167 with conn.cursor() as cursor:
168 started = time.perf_counter()
169 cursor.execute(
170 """
171 select record_id, raw_text, title, artist
172 from lyrics
173 where deleted_at is null
174 and exact_hash = %s
175 and not (record_id = any(%s))
176 limit %s
177 """,
178 (exact_hash, exclude_record_ids, recall_limit),
179 )
180 _add_rows(candidates, cursor.fetchall())
181 timings["exact_ms"] = round((time.perf_counter() - started) * 1000, 2)
182
183 if enable_trgm and primary_text:
184 started = time.perf_counter()
185 cursor.execute(
186 """
187 select record_id, raw_text, title, artist
188 from lyrics
189 where deleted_at is null
190 and not (record_id = any(%s))
191 and primary_text %% %s
192 order by similarity(primary_text, %s) desc
193 limit %s
194 """,
195 (exclude_record_ids, primary_text, primary_text, recall_limit),
196 )
197 _add_rows(candidates, cursor.fetchall())
198 timings["trgm_ms"] = round((time.perf_counter() - started) * 1000, 2)
199
200 if line_hashes:
201 started = time.perf_counter()
202 cursor.execute(
203 """
204 select l.record_id, l.raw_text, l.title, l.artist
205 from lyric_lines ll
206 join lyrics l on l.id = ll.lyric_id
207 where l.deleted_at is null
208 and not (l.record_id = any(%s))
209 and ll.role = 'primary'
210 and ll.line_hash = any(%s)
211 group by l.id
212 order by count(*) desc
213 limit %s
214 """,
215 (exclude_record_ids, line_hashes, recall_limit),
216 )
217 _add_rows(candidates, cursor.fetchall())
218 timings["line_ms"] = round((time.perf_counter() - started) * 1000, 2)
219 return list(candidates.values()), timings
220
221
222 def _exclude_record_ids_for_eval_row(row: dict[str, str]) -> list[str]:
223 if row.get("sample_type") == "negative_real_holdout_full_song" and row.get("source_record_id"):
224 return [row["source_record_id"]]
225 return []
226
227
228 def _add_rows(candidates: dict[str, LyricRecord], rows: list[tuple[object, ...]]) -> None:
229 for record_id, raw_text, title, artist in rows:
230 candidates.setdefault(
231 str(record_id),
232 LyricRecord(
233 record_id=str(record_id),
234 lyrics=str(raw_text),
235 title=str(title) if title is not None else None,
236 artist=str(artist) if artist is not None else None,
237 ),
238 )
239
240
241 def _check_against_candidates(
242 record: LyricRecord,
243 candidates: list[LyricRecord],
244 *,
245 max_candidates: int,
246 ):
247 checker = DuplicateChecker()
248 for candidate in candidates:
249 checker.add_record(candidate)
250 return checker.check_record(record, max_candidates=max_candidates)
251
252
253 def _record_from_eval_row(row: dict[str, str], *, csv_path: Path, base_dir: Path | None) -> tuple[LyricRecord, str]:
254 lyrics = (row.get("lyrics") or "").strip()
255 if lyrics:
256 return (
257 LyricRecord(
258 record_id=row.get("id") or row.get("sample_id") or "__eval__",
259 lyrics=_pg_text(lyrics.replace("\\n", "\n")) or "",
260 title=_pg_text(row.get("title") or None),
261 artist=_pg_text(row.get("artist") or None),
262 ),
263 "inline",
264 )
265
266 file_value = (row.get("file") or row.get("path") or row.get("source") or "").strip()
267 if not file_value:
268 raise ValueError("评估 CSV 每行需要 lyrics,或 file/path/source 文件路径")
269
270 file_path = Path(file_value)
271 if not file_path.is_absolute():
272 file_path = (base_dir or csv_path.parent) / file_path
273 record = record_from_file(file_path)
274 record = LyricRecord(
275 record_id=record.record_id,
276 lyrics=_pg_text(record.lyrics) or "",
277 title=_pg_text(record.title),
278 artist=_pg_text(record.artist),
279 )
280 if row.get("title") or row.get("artist"):
281 record = LyricRecord(
282 record_id=record.record_id,
283 lyrics=record.lyrics,
284 title=_pg_text(row.get("title") or record.title),
285 artist=_pg_text(row.get("artist") or record.artist),
286 )
287 return record, str(file_path)
288
289
290 def _parse_expected(value: str | None) -> bool:
291 if value is None:
292 raise ValueError("评估 CSV 每行需要 expected/label/target 列")
293 normalized = value.strip().lower()
294 positives = {"1", "true", "yes", "y", "duplicate", "dup", "重复", "应去重", "去重", "是"}
295 negatives = {"0", "false", "no", "n", "new", "not_duplicate", "non_duplicate", "不重复", "不应去重", "新歌", "否"}
296 if normalized in positives:
297 return True
298 if normalized in negatives:
299 return False
300 raise ValueError(f"无法识别 expected 值: {value!r}")
301
302
303 def _evaluation_summary(
304 rows: list[dict[str, object]],
305 *,
306 positive_decisions: set[str],
307 out_path: Path,
308 ) -> dict[str, object]:
309 tp = sum(1 for row in rows if row["expected_duplicate"] is True and row["predicted_duplicate"] is True)
310 fp = sum(1 for row in rows if row["expected_duplicate"] is False and row["predicted_duplicate"] is True)
311 tn = sum(1 for row in rows if row["expected_duplicate"] is False and row["predicted_duplicate"] is False)
312 fn = sum(1 for row in rows if row["expected_duplicate"] is True and row["predicted_duplicate"] is False)
313 total = len(rows)
314 precision = tp / (tp + fp) if tp + fp else 0.0
315 recall = tp / (tp + fn) if tp + fn else 0.0
316 accuracy = (tp + tn) / total if total else 0.0
317 f1 = (2 * precision * recall / (precision + recall)) if precision + recall else 0.0
318 return {
319 "total": total,
320 "positive_decisions": sorted(positive_decisions),
321 "accuracy": round(accuracy, 4),
322 "precision": round(precision, 4),
323 "recall": round(recall, 4),
324 "f1": round(f1, 4),
325 "true_positive": tp,
326 "false_positive": fp,
327 "true_negative": tn,
328 "false_negative": fn,
329 "duplicate": sum(1 for row in rows if row["decision"] == "duplicate"),
330 "review": sum(1 for row in rows if row["decision"] == "review"),
331 "new": sum(1 for row in rows if row["decision"] == "new"),
332 "out": str(out_path),
333 "summary": str(out_path.with_suffix(out_path.suffix + ".summary.json")),
334 }
335
336
337 def _fieldnames() -> list[str]:
338 return [
339 "id",
340 "source",
341 "expected_duplicate",
342 "decision",
343 "predicted_duplicate",
344 "correct",
345 "confidence",
346 "reason",
347 "candidate_count",
348 "parse_ms",
349 "recall_ms",
350 "exact_ms",
351 "trgm_ms",
352 "line_ms",
353 "rank_ms",
354 "best_candidate_id",
355 "best_candidate_decision",
356 "best_candidate_confidence",
357 "best_candidate_jaccard",
358 "best_candidate_line_coverage",
359 "best_candidate_primary_jaccard",
360 "best_candidate_primary_line_coverage",
361 "best_candidate_translation_jaccard",
362 "best_candidate_translation_line_coverage",
363 "best_candidate_reason",
364 "matched_unique_lines",
365 ]
366
367
368 def _csv_data_row_count(csv_path: Path) -> int:
369 with csv_path.open(encoding="utf-8-sig", newline="") as file:
370 reader = csv.reader(file)
371 next(reader, None)
372 return sum(1 for _ in reader)
373
374
375 def _progress(message: str) -> None:
376 print(f"[pg-eval] {message}", file=sys.stderr, flush=True)
377
378
379 def _progress_count(label: str, current: int, total: int, *, step: int = 1000) -> None:
380 if total <= 0:
381 return
382 if current == 1 or current == total or current % step == 0:
383 _progress(f"{label}: {current}/{total}")
384
385
386 def _new_profile_stats() -> dict[str, float]:
387 return {
388 "parse_ms": 0.0,
389 "exact_ms": 0.0,
390 "trgm_ms": 0.0,
391 "line_ms": 0.0,
392 "rank_ms": 0.0,
393 "recall_ms": 0.0,
394 "candidate_count": 0.0,
395 }
396
397
398 def _update_profile_stats(stats: dict[str, float], row: dict[str, object]) -> None:
399 for key in stats:
400 try:
401 stats[key] += float(row.get(key) or 0)
402 except (TypeError, ValueError):
403 pass
404
405
406 def _format_profile_stats(stats: dict[str, float], count: int) -> str:
407 if count <= 0:
408 return "profile: no rows"
409 return (
410 "profile avg "
411 f"parse={stats['parse_ms'] / count:.2f}ms "
412 f"exact={stats['exact_ms'] / count:.2f}ms "
413 f"line={stats['line_ms'] / count:.2f}ms "
414 f"trgm={stats['trgm_ms'] / count:.2f}ms "
415 f"rank={stats['rank_ms'] / count:.2f}ms "
416 f"recall={stats['recall_ms'] / count:.2f}ms "
417 f"candidates={stats['candidate_count'] / count:.1f}"
418 )
419
420
421 def _pg_text(value: str | None) -> str | None:
422 if value is None:
423 return None
424 return value.replace("\x00", "")
425
426
427 def _import_psycopg():
428 try:
429 import psycopg
430
431 return psycopg
432 except ModuleNotFoundError:
433 print(
434 "Missing dependency: psycopg. Install it with:\n"
435 " python -m pip install 'psycopg[binary]'",
436 file=sys.stderr,
437 )
438 raise SystemExit(1)
439
440
441 if __name__ == "__main__":
442 main()
1 """Import normalized lyric library records into PostgreSQL."""
2
3 from __future__ import annotations
4
5 import argparse
6 import csv
7 import hashlib
8 import sys
9 from pathlib import Path
10 from typing import Any
11
12
13 PROJECT_ROOT = Path(__file__).resolve().parents[1]
14 if str(PROJECT_ROOT) not in sys.path:
15 sys.path.insert(0, str(PROJECT_ROOT))
16
17 from lyric_dedup.file_import import iter_lyric_files
18 from lyric_dedup.file_import import record_from_file
19 from lyric_dedup.normalization import fingerprint_text
20 from lyric_dedup.normalization import normalize_lyrics
21
22
23 def main() -> None:
24 parser = argparse.ArgumentParser(description="Import lyric library into PostgreSQL.")
25 parser.add_argument("--dsn", required=True)
26 parser.add_argument("--lyrics-dir", required=True)
27 parser.add_argument("--batch-size", type=int, default=500)
28 parser.add_argument("--limit", type=int, default=0)
29 parser.add_argument("--skip-dedup-exact", action="store_true", help="Skip exact-hash duplicate soft deletion after import.")
30 parser.add_argument("--duplicate-report", default="outputs/results/postgres_exact_duplicates.csv")
31 parser.add_argument("--line-duplicate-report", default="", help="Optional CSV report for high line-coverage duplicate candidates.")
32 parser.add_argument("--line-coverage-threshold", type=float, default=0.95)
33 parser.add_argument("--line-duplicate-limit", type=int, default=10000)
34 args = parser.parse_args()
35
36 psycopg = _import_psycopg()
37 lyrics_dir = Path(args.lyrics_dir)
38 paths = iter_lyric_files(lyrics_dir)
39 if args.limit > 0:
40 paths = paths[: args.limit]
41 print(f"[pg-import] files: {len(paths)}", file=sys.stderr, flush=True)
42
43 imported = 0
44 exact_deleted = 0
45 line_reported = 0
46 nul_cleaned = 0
47 with psycopg.connect(args.dsn) as conn:
48 for start in range(0, len(paths), args.batch_size):
49 batch = paths[start : start + args.batch_size]
50 with conn.transaction():
51 with conn.cursor() as cursor:
52 for path in batch:
53 lyric_id, line_rows, cleaned = _upsert_lyric(cursor, path, lyrics_dir)
54 nul_cleaned += cleaned
55 cursor.execute("delete from lyric_lines where lyric_id = %s", (lyric_id,))
56 if line_rows:
57 cursor.executemany(
58 """
59 insert into lyric_lines
60 (lyric_id, role, line_no, normalized_line, line_hash)
61 values (%s, %s, %s, %s, %s)
62 """,
63 line_rows,
64 )
65 imported += 1
66 _progress("import", imported, len(paths), step=args.batch_size)
67 if not args.skip_dedup_exact:
68 exact_deleted = _soft_delete_exact_duplicates(conn, Path(args.duplicate_report))
69 if args.line_duplicate_report:
70 line_reported = _write_line_duplicate_report(
71 conn,
72 Path(args.line_duplicate_report),
73 threshold=args.line_coverage_threshold,
74 limit=args.line_duplicate_limit,
75 )
76 print(
77 {
78 "imported": imported,
79 "records_with_nul_cleaned": nul_cleaned,
80 "exact_duplicates_soft_deleted": exact_deleted,
81 "line_duplicate_candidates_reported": line_reported,
82 }
83 )
84
85
86 def _upsert_lyric(cursor: Any, path: Path, lyrics_dir: Path) -> tuple[int, list[tuple[object, ...]], int]:
87 record = record_from_file(path, base_dir=lyrics_dir)
88 raw_text, raw_cleaned = _pg_text(record.lyrics)
89 normalized = normalize_lyrics(raw_text)
90 primary_text = _pg_text("\n".join(normalized.primary_lines))[0]
91 translation_text = _pg_text("\n".join(normalized.translation_lines))[0] or None
92 normalized_text = _pg_text(normalized.normalized_full_text)[0]
93 exact_text = fingerprint_text(normalized)
94 exact_hash = hashlib.sha256(exact_text.encode("utf-8")).hexdigest()
95 cursor.execute(
96 """
97 insert into lyrics (
98 record_id, source_path, title, artist, raw_text, normalized_text,
99 primary_text, translation_text, exact_hash, split_confidence,
100 split_reason, line_count, updated_at, deleted_at
101 )
102 values (
103 %(record_id)s, %(source_path)s, %(title)s, %(artist)s, %(raw_text)s,
104 %(normalized_text)s, %(primary_text)s, %(translation_text)s,
105 %(exact_hash)s, %(split_confidence)s, %(split_reason)s,
106 %(line_count)s, now(), null
107 )
108 on conflict (record_id) do update set
109 source_path = excluded.source_path,
110 title = excluded.title,
111 artist = excluded.artist,
112 raw_text = excluded.raw_text,
113 normalized_text = excluded.normalized_text,
114 primary_text = excluded.primary_text,
115 translation_text = excluded.translation_text,
116 exact_hash = excluded.exact_hash,
117 split_confidence = excluded.split_confidence,
118 split_reason = excluded.split_reason,
119 line_count = excluded.line_count,
120 updated_at = now(),
121 deleted_at = null
122 returning id
123 """,
124 {
125 "record_id": record.record_id,
126 "source_path": str(path),
127 "title": _pg_text(record.title)[0],
128 "artist": _pg_text(record.artist)[0],
129 "raw_text": raw_text,
130 "normalized_text": normalized_text,
131 "primary_text": primary_text,
132 "translation_text": translation_text,
133 "exact_hash": exact_hash,
134 "split_confidence": _pg_text(normalized.split_confidence)[0],
135 "split_reason": _pg_text(normalized.split_reason)[0],
136 "line_count": len(normalized.primary_lines or normalized.unique_lines),
137 },
138 )
139 lyric_id = cursor.fetchone()[0]
140 line_rows: list[tuple[object, ...]] = []
141 line_rows.extend(_line_rows(lyric_id, "primary", normalized.primary_lines))
142 line_rows.extend(_line_rows(lyric_id, "translation", normalized.translation_lines))
143 line_rows.extend(_line_rows(lyric_id, "unknown", normalized.unknown_lines))
144 return lyric_id, line_rows, int(raw_cleaned)
145
146
147 def _line_rows(lyric_id: int, role: str, lines: tuple[str, ...]) -> list[tuple[object, ...]]:
148 rows: list[tuple[object, ...]] = []
149 for index, line in enumerate(lines):
150 line = _pg_text(line)[0] or ""
151 line_hash = hashlib.sha256(line.encode("utf-8")).hexdigest()
152 rows.append((lyric_id, role, index, line, line_hash))
153 return rows
154
155
156 def _pg_text(value: str | None) -> tuple[str | None, bool]:
157 if value is None:
158 return None, False
159 if "\x00" not in value:
160 return value, False
161 return value.replace("\x00", ""), True
162
163
164 def _soft_delete_exact_duplicates(conn: Any, report_path: Path) -> int:
165 print("[pg-import] deduplicate exact_hash duplicates", file=sys.stderr, flush=True)
166 with conn.transaction():
167 with conn.cursor() as cursor:
168 cursor.execute(
169 """
170 with ranked as (
171 select
172 id,
173 exact_hash,
174 first_value(id) over (
175 partition by exact_hash
176 order by
177 case when source_path like '%/None_%' then 1 else 0 end,
178 line_count desc,
179 length(primary_text) desc,
180 id
181 ) as kept_id,
182 row_number() over (
183 partition by exact_hash
184 order by
185 case when source_path like '%/None_%' then 1 else 0 end,
186 line_count desc,
187 length(primary_text) desc,
188 id
189 ) as rn
190 from lyrics
191 where deleted_at is null
192 ),
193 to_delete as (
194 select id, exact_hash, kept_id
195 from ranked
196 where rn > 1
197 ),
198 updated as (
199 update lyrics l
200 set deleted_at = now(), updated_at = now()
201 from to_delete d
202 where l.id = d.id
203 returning
204 l.id as duplicate_id,
205 l.record_id as duplicate_record_id,
206 l.source_path as duplicate_source_path,
207 d.exact_hash,
208 d.kept_id
209 )
210 select
211 u.duplicate_id,
212 u.duplicate_record_id,
213 u.duplicate_source_path,
214 k.id as kept_id,
215 k.record_id as kept_record_id,
216 k.source_path as kept_source_path,
217 u.exact_hash
218 from updated u
219 join lyrics k on k.id = u.kept_id
220 order by u.exact_hash, u.duplicate_id
221 """
222 )
223 rows = cursor.fetchall()
224 _write_rows(
225 report_path,
226 [
227 "duplicate_id",
228 "duplicate_record_id",
229 "duplicate_source_path",
230 "kept_id",
231 "kept_record_id",
232 "kept_source_path",
233 "exact_hash",
234 ],
235 rows,
236 )
237 print(f"[pg-import] exact duplicates soft-deleted: {len(rows)}", file=sys.stderr, flush=True)
238 return len(rows)
239
240
241 def _write_line_duplicate_report(conn: Any, report_path: Path, *, threshold: float, limit: int) -> int:
242 print("[pg-import] report high line-coverage duplicate candidates", file=sys.stderr, flush=True)
243 with conn.cursor() as cursor:
244 cursor.execute(
245 """
246 with pairs as (
247 select
248 a.lyric_id as left_id,
249 b.lyric_id as right_id,
250 count(*) as matched_lines
251 from lyric_lines a
252 join lyric_lines b
253 on a.line_hash = b.line_hash
254 and a.lyric_id < b.lyric_id
255 join lyrics la on la.id = a.lyric_id and la.deleted_at is null
256 join lyrics lb on lb.id = b.lyric_id and lb.deleted_at is null
257 where a.role = 'primary'
258 and b.role = 'primary'
259 group by a.lyric_id, b.lyric_id
260 )
261 select
262 p.left_id,
263 l1.record_id as left_record_id,
264 l1.source_path as left_source_path,
265 p.right_id,
266 l2.record_id as right_record_id,
267 l2.source_path as right_source_path,
268 p.matched_lines,
269 l1.line_count as left_line_count,
270 l2.line_count as right_line_count,
271 p.matched_lines::float / greatest(l1.line_count, l2.line_count) as line_coverage
272 from pairs p
273 join lyrics l1 on l1.id = p.left_id
274 join lyrics l2 on l2.id = p.right_id
275 where p.matched_lines::float / greatest(l1.line_count, l2.line_count) >= %s
276 order by line_coverage desc, matched_lines desc
277 limit %s
278 """,
279 (threshold, limit),
280 )
281 rows = cursor.fetchall()
282 _write_rows(
283 report_path,
284 [
285 "left_id",
286 "left_record_id",
287 "left_source_path",
288 "right_id",
289 "right_record_id",
290 "right_source_path",
291 "matched_lines",
292 "left_line_count",
293 "right_line_count",
294 "line_coverage",
295 ],
296 rows,
297 )
298 print(f"[pg-import] line duplicate candidates reported: {len(rows)}", file=sys.stderr, flush=True)
299 return len(rows)
300
301
302 def _write_rows(report_path: Path, fieldnames: list[str], rows: list[tuple[object, ...]]) -> None:
303 report_path.parent.mkdir(parents=True, exist_ok=True)
304 with report_path.open("w", encoding="utf-8", newline="") as file:
305 writer = csv.writer(file)
306 writer.writerow(fieldnames)
307 writer.writerows(rows)
308
309
310 def _progress(label: str, current: int, total: int, *, step: int) -> None:
311 if current == total or current % step == 0:
312 print(f"[pg-import] {label}: {current}/{total}", file=sys.stderr, flush=True)
313
314
315 def _import_psycopg():
316 try:
317 import psycopg
318
319 return psycopg
320 except ModuleNotFoundError:
321 print(
322 "Missing dependency: psycopg. Install it with:\n"
323 " python -m pip install 'psycopg[binary]'",
324 file=sys.stderr,
325 )
326 raise SystemExit(1)
327
328
329 if __name__ == "__main__":
330 main()
1 """Initialize PostgreSQL schema for lyric dedup storage."""
2
3 from __future__ import annotations
4
5 import argparse
6 import sys
7 from pathlib import Path
8
9
10 PROJECT_ROOT = Path(__file__).resolve().parents[1]
11 SCHEMA_PATH = PROJECT_ROOT / "scripts" / "postgres_schema.sql"
12
13
14 def main() -> None:
15 parser = argparse.ArgumentParser(description="Initialize PostgreSQL schema for lyric dedup.")
16 parser.add_argument("--dsn", required=True, help="PostgreSQL DSN, e.g. postgresql://user:pass@localhost:5432/lyric_dedup")
17 parser.add_argument("--schema", default=str(SCHEMA_PATH))
18 args = parser.parse_args()
19
20 psycopg = _import_psycopg()
21 schema_sql = Path(args.schema).read_text(encoding="utf-8")
22 with psycopg.connect(args.dsn) as conn:
23 with conn.cursor() as cursor:
24 cursor.execute(schema_sql)
25 conn.commit()
26 print(f"initialized schema from {args.schema}")
27
28
29 def _import_psycopg():
30 try:
31 import psycopg
32
33 return psycopg
34 except ModuleNotFoundError:
35 print(
36 "Missing dependency: psycopg. Install it with:\n"
37 " python -m pip install 'psycopg[binary]'",
38 file=sys.stderr,
39 )
40 raise SystemExit(1)
41
42
43 if __name__ == "__main__":
44 main()
1 create extension if not exists pg_trgm;
2
3 create table if not exists lyrics (
4 id bigserial primary key,
5 record_id text not null unique,
6 source_path text not null,
7 title text,
8 artist text,
9 raw_text text not null,
10 normalized_text text not null,
11 primary_text text not null,
12 translation_text text,
13 exact_hash text not null,
14 split_confidence text,
15 split_reason text,
16 line_count integer not null,
17 created_at timestamptz not null default now(),
18 updated_at timestamptz not null default now(),
19 deleted_at timestamptz
20 );
21
22 create index if not exists lyrics_exact_hash_idx
23 on lyrics (exact_hash)
24 where deleted_at is null;
25
26 create index if not exists lyrics_primary_text_trgm_idx
27 on lyrics using gin (primary_text gin_trgm_ops);
28
29 create table if not exists lyric_lines (
30 lyric_id bigint not null references lyrics(id) on delete cascade,
31 role text not null,
32 line_no integer not null,
33 normalized_line text not null,
34 line_hash text not null,
35 primary key (lyric_id, role, line_no)
36 );
37
38 create index if not exists lyric_lines_hash_idx
39 on lyric_lines (line_hash);
40
41 create index if not exists lyric_lines_lyric_id_idx
42 on lyric_lines (lyric_id);
...@@ -316,6 +316,40 @@ def test_generated_eval_set_uses_stratified_production_mix(tmp_path) -> None: ...@@ -316,6 +316,40 @@ def test_generated_eval_set_uses_stratified_production_mix(tmp_path) -> None:
316 assert all(row["expected"] == "不应去重" for row in rows if row["sample_type"].startswith("negative_")) 316 assert all(row["expected"] == "不应去重" for row in rows if row["sample_type"].startswith("negative_"))
317 317
318 318
319 def test_generated_hard_eval_set_uses_business_realistic_edge_mix(tmp_path) -> None:
320 library = tmp_path / "library"
321 incoming = tmp_path / "generated" / "incoming"
322 eval_csv = tmp_path / "generated" / "eval_hard.csv"
323 library.mkdir()
324 for idx in range(24):
325 prefix = "AY" if idx % 3 == 0 else "WHHY"
326 lyric = BASE_LYRIC.replace("我爱你", f"我想你{idx}").replace("城市", f"城市{idx}")
327 if idx % 4 == 0:
328 lyric += "\nI miss you tonight\nUnder the moonlight\nNever let me go\n"
329 (library / f"{idx}_{prefix}{idx:06d}.txt").write_text(lyric, encoding="utf-8")
330
331 generate_eval_set(
332 library_dir=library,
333 output_dir=incoming,
334 csv_path=eval_csv,
335 size=40,
336 positive_ratio=0.3,
337 profile="hard",
338 )
339
340 rows = list(csv.DictReader(eval_csv.open(encoding="utf-8")))
341 manifest = json.loads((tmp_path / "generated" / "eval_hard.csv.manifest.json").read_text(encoding="utf-8"))
342 sample_types = {row["sample_type"] for row in rows}
343
344 assert len(rows) == 40
345 assert manifest["profile"] == "hard"
346 assert "positive_realistic_variant" in manifest["plan"]
347 assert "negative_near_neighbor_holdout_full_song" in manifest["plan"]
348 assert "negative_long_fragment" in sample_types
349 assert "negative_catalog_mashup" in sample_types
350 assert any(row["sample_type"].startswith("positive_") for row in rows)
351
352
319 def test_foreign_original_with_added_chinese_translation_is_duplicate() -> None: 353 def test_foreign_original_with_added_chinese_translation_is_duplicate() -> None:
320 checker = DuplicateChecker() 354 checker = DuplicateChecker()
321 checker.add_record( 355 checker.add_record(
......