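Add PostgreSQL dedup retrieval pipeline and hard evaluation set support
- Add a PostgreSQL import script, evaluation script, and schema definition, supporting a three-layer recall strategy based on exact_hash, pg_trgm, and line-level hashes
- Add a hard profile to the evaluation CLI, covering scenarios closer to business boundaries such as typos, OCR errors, whole-passage translations, and medley fragments
- Adjust the review thresholds and match-reason wording in checker.py, and improve the decision logic for similar translation lines and chorus-only repetition
- Update README, TEST_WORKFLOW, and unit tests accordingly

Co-Authored-By: Claude <noreply@anthropic.com>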
Showing 12 changed files with 2017 additions and 93 deletions
POSTGRES_MIGRATION.md
0 → 100644
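| 1 | # PostgreSQL Migration Guide | ||
| 2 | |||
| 3 | This document explains why the project is migrating from the `.pkl` index to PostgreSQL, what role PostgreSQL plays in this project, and the concrete steps from local initialization to importing the song library. | ||
| 4 | |||
| 5 | ## 1. Why migrate | ||
| 6 | |||
| 7 | The current `outputs/indexes/library_lyrics.pkl` is a snapshot of a Python `DuplicateChecker` object. It is suitable for experiments and offline evaluation, but not for an online service: | ||
| 8 | |||
| 9 | - Incremental additions, deletions, and rollbacks are cumbersome. | ||
| 10 | - In a multi-instance deployment it is hard to keep the same `.pkl` in sync. | ||
| 11 | - There are no database transactions, auditing, or version management. | ||
| 12 | - The index structure is tied to Python pickle, which hinders operations and long-term maintenance. | ||
| 13 | |||
| 14 | PostgreSQL is not meant to replace all of the dedup logic, only the "data storage + candidate recall" part. The final decision stays in the Python service layer. | ||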
| 15 | |||
| 16 | ## 2. The role of PostgreSQL in this project | ||
| 17 | |||
| 18 | Recommended division of labor: | ||
| 19 | |||
| 20 | ```text | ||
| 21 | PostgreSQL: | ||
| 22 | Store raw lyrics, normalized lyrics, and line-level features | ||
| 23 | Exact recall via exact_hash | ||
| 24 | Approximate text recall via pg_trgm | ||
| 25 | Line-overlap recall via lyric_lines | ||
| 26 | |||
| 27 | Python: | ||
| 28 | normalize_lyrics | ||
| 29 | Compute Jaccard / line coverage | ||
| 30 | Handle translation lines, chorus collisions, and fragment protection | ||
| 31 | Output the final duplicate / review / new decision | ||
| 32 | ``` | ||
| 33 | |||
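| 34 | Do not let PostgreSQL's `similarity()` decide `duplicate` directly. The database is only responsible for finding candidates; the final decision should still reuse the current algorithm's rules. | ||
| 35 | |||
| 36 | ## 3. Mapping between the pkl index and PostgreSQL | ||
| 37 | |||
| 38 | The current `.pkl` mainly contains: | ||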
| 39 | |||
| 40 | ```text | ||
| 41 | _records full index record for each lyric | ||
| 42 | _exact_hash_to_ids exact hash -> record ids | ||
| 43 | _line_to_ids normalized line -> record ids | ||
| 44 | _token_to_ids n-gram token -> record ids | ||
| 45 | _lsh MinHash LSH buckets | ||
| 46 | ``` | ||
| 47 | |||
| 48 | After migrating to PostgreSQL: | ||
| 49 | |||
| 50 | ```text | ||
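| 51 | lyrics.exact_hash corresponds to _exact_hash_to_ids | ||
| 52 | lyrics.primary_text + pg_trgm corresponds to approximate text recall | ||
| 53 | lyric_lines.line_hash corresponds to line-level inverted-index recall | ||
| 54 | lyrics / lyric_lines correspond to the persistable part of _records | ||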
| 55 | ``` | ||
| 56 | |||
| 57 | Phase one does not migrate MinHash LSH. First validate recall quality with `exact_hash + pg_trgm + line_hash`. | ||
| 58 | |||
| 59 | ## 4. Local PostgreSQL basics | ||
| 60 | |||
| 61 | PostgreSQL is a relational database. Common concepts: | ||
| 62 | |||
| 63 | ```text | ||
| 64 | database a database, e.g. lyric_dedup | ||
| 65 | schema a namespace, public by default | ||
| 66 | table a table, e.g. lyrics | ||
| 67 | index an index, used to speed up queries | ||
| 68 | extension an extension, e.g. pg_trgm | ||
| 69 | DSN a connection string | ||
| 70 | ``` | ||
| 71 | |||
| 72 | A common connection string for a local database: | ||
| 73 | |||
| 74 | ```bash | ||
| 75 | postgresql:///lyric_dedup | ||
| 76 | ``` | ||
| 77 | |||
| 78 | This means: connect to the `lyric_dedup` database on the local PostgreSQL server using the current OS username. | ||
| 79 | |||
| 80 | With a username and password: | ||
| 81 | |||
| 82 | ```bash | ||
| 83 | postgresql://postgres:postgres@localhost:5432/lyric_dedup | ||
| 84 | ``` | ||
| 85 | |||
| 86 | ## 5. Scripts added so far | ||
| 87 | |||
| 88 | The project now includes: | ||
| 89 | |||
| 90 | ```text | ||
| 91 | scripts/postgres_schema.sql | ||
| 92 | scripts/init_postgres.py | ||
| 93 | scripts/import_library_postgres.py | ||
| 94 | scripts/evaluate_postgres.py | ||
| 95 | ``` | ||
| 96 | |||
| 97 | Purpose: | ||
| 98 | |||
| 99 | ```text | ||
| 100 | postgres_schema.sql creates tables and indexes, enables pg_trgm | ||
| 101 | init_postgres.py executes the schema SQL automatically | ||
| 102 | import_library_postgres.py scans data/library, normalizes, and imports into PostgreSQL | ||
| 103 | evaluate_postgres.py recalls candidates from PostgreSQL and evaluates a CSV | ||
| 104 | ``` | ||
| 105 | |||
| 106 | ## 6. Install dependencies | ||
| 107 | |||
| 108 | The current Python environment needs the PostgreSQL driver: | ||
| 109 | |||
| 110 | ```bash | ||
| 111 | python -m pip install 'psycopg[binary]' | ||
| 112 | ``` | ||
| 113 | |||
| 114 | If you use a conda environment, make sure the command runs in the `(base)` or target environment that this project uses. | ||
| 115 | |||
| 116 | Verify: | ||
| 117 | |||
| 118 | ```bash | ||
| 119 | python - <<'PY' | ||
| 120 | import psycopg | ||
| 121 | print(psycopg.__version__) | ||
| 122 | PY | ||
| 123 | ``` | ||
| 124 | |||
| 125 | ## 7. Create the database | ||
| 126 | |||
| 127 | You have already run: | ||
| 128 | |||
| 129 | ```bash | ||
| 130 | createdb lyric_dedup | ||
| 131 | ``` | ||
| 132 | |||
| 133 | To confirm that the database exists: | ||
| 134 | |||
| 135 | ```bash | ||
| 136 | psql -l | grep lyric_dedup | ||
| 137 | ``` | ||
| 138 | |||
| 139 | Enter the database: | ||
| 140 | |||
| 141 | ```bash | ||
| 142 | psql postgresql:///lyric_dedup | ||
| 143 | ``` | ||
| 144 | |||
| 145 | Exit `psql`: | ||
| 146 | |||
| 147 | ```text | ||
| 148 | \q | ||
| 149 | ``` | ||
| 150 | |||
| 151 | ## 8. Initialize the table structure | ||
| 152 | |||
| 153 | Run: | ||
| 154 | |||
| 155 | ```bash | ||
| 156 | python scripts/init_postgres.py \ | ||
| 157 | --dsn postgresql:///lyric_dedup | ||
| 158 | ``` | ||
| 159 | |||
| 160 | It executes: | ||
| 161 | |||
| 162 | ```sql | ||
| 163 | create extension if not exists pg_trgm; | ||
| 164 | create table if not exists lyrics (...); | ||
| 165 | create table if not exists lyric_lines (...); | ||
| 166 | create index if not exists ...; | ||
| 167 | ``` | ||
| 168 | |||
| 169 | On success the output looks like: | ||
| 170 | |||
| 171 | ```text | ||
| 172 | initialized schema from scripts/postgres_schema.sql | ||
| 173 | ``` | ||
| 174 | |||
| 175 | Check the tables: | ||
| 176 | |||
| 177 | ```bash | ||
| 178 | psql postgresql:///lyric_dedup -c '\dt' | ||
| 179 | ``` | ||
| 180 | |||
| 181 | Check the extensions: | ||
| 182 | |||
| 183 | ```bash | ||
| 184 | psql postgresql:///lyric_dedup -c 'select * from pg_extension;' | ||
| 185 | ``` | ||
| 186 | |||
| 187 | ## 9. Table structure | ||
| 188 | |||
| 189 | ### lyrics | ||
| 190 | |||
| 191 | Stores the main record for each lyric: | ||
| 192 | |||
| 193 | ```text | ||
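| 194 | record_id stable id generated from the current file | ||
| 195 | source_path original file path | ||
| 196 | title / artist metadata parsed from the file name | ||
| 197 | raw_text raw lyrics | ||
| 198 | normalized_text cleaned full text | ||
| 199 | primary_text concatenated original-language lines, mainly used for automatic dedup | ||
| 200 | translation_text concatenated translation lines | ||
| 201 | exact_hash hash of the normalized original text | ||
| 202 | split_confidence translation split confidence | ||
| 203 | split_reason translation split reason | ||
| 204 | line_count number of valid lyric lines | ||
| 205 | deleted_at soft-delete field | ||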
| 206 | ``` | ||
| 207 | |||
| 208 | ### lyric_lines | ||
| 209 | |||
| 210 | Stores line-level features: | ||
| 211 | |||
| 212 | ```text | ||
| 213 | lyric_id references lyrics.id | ||
| 214 | role primary / translation / unknown | ||
| 215 | line_no line number | ||
| 216 | normalized_line normalized lyric line | ||
| 217 | line_hash line hash | ||
| 218 | ``` | ||
| 219 | |||
| 220 | Purpose: quickly find "which songs contain the same line". | ||
| 221 | |||
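A minimal sketch of how line-level overlap might be measured once candidate rows come back from `lyric_lines`. The hashing scheme (SHA-1 over UTF-8) and the helper names `line_hash` and `line_coverage` are assumptions for illustration; the project's actual hash function and coverage formula are not shown in this document.

```python
from __future__ import annotations

import hashlib


def line_hash(normalized_line: str) -> str:
    """Hypothetical line hash: SHA-1 hex digest of the UTF-8 encoded line."""
    return hashlib.sha1(normalized_line.encode("utf-8")).hexdigest()


def line_coverage(incoming_lines: list[str], candidate_lines: list[str]) -> float:
    """Fraction of distinct incoming lines that also appear in the candidate."""
    incoming = {line_hash(line) for line in incoming_lines}
    if not incoming:
        return 0.0
    candidate = {line_hash(line) for line in candidate_lines}
    return len(incoming & candidate) / len(incoming)
```

Comparing hashes rather than raw text mirrors the `line_hash` column, so the same values could be passed to a SQL lookup and reused on the Python side.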
| 222 | ## 10. Small-batch import test | ||
| 223 | |||
| 224 | First import 1000 records to confirm that the environment and schema both work: | ||
| 225 | |||
| 226 | ```bash | ||
| 227 | python scripts/import_library_postgres.py \ | ||
| 228 | --dsn postgresql:///lyric_dedup \ | ||
| 229 | --lyrics-dir data/library \ | ||
| 230 | --limit 1000 | ||
| 231 | ``` | ||
| 232 | |||
| 233 | By default, the import script runs one low-risk in-library dedup pass after the import finishes: | ||
| 234 | |||
| 235 | ```text | ||
| 236 | Among records with an identical exact_hash, only one is kept; the rest are soft-deleted, i.e. lyrics.deleted_at is set. | ||
| 237 | ``` | ||
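The rule above can be sketched in Python as grouping active records by `exact_hash` and marking all but one per group for soft deletion. Which record is kept is not specified in this document; keeping the smallest `id` is an assumption, and `plan_exact_dedup` is a hypothetical helper, not the import script's code.

```python
from __future__ import annotations

from collections import defaultdict


def plan_exact_dedup(rows: list[tuple[int, str]]) -> dict[int, int]:
    """Map each id that should be soft-deleted to the id kept for its group.

    rows: (id, exact_hash) pairs for active records.
    Assumption: the record with the smallest id in each group is kept.
    """
    groups: dict[str, list[int]] = defaultdict(list)
    for record_id, exact_hash in rows:
        groups[exact_hash].append(record_id)

    to_delete: dict[int, int] = {}
    for ids in groups.values():
        keeper = min(ids)
        for record_id in ids:
            if record_id != keeper:
                to_delete[record_id] = keeper
    return to_delete
```

Returning the kept id alongside each deleted id gives enough information to write a cleanup report row per removed record.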
| 238 | |||
| 239 | The duplicate cleanup report is written by default to: | ||
| 240 | |||
| 241 | ```text | ||
| 242 | outputs/results/postgres_exact_duplicates.csv | ||
| 243 | ``` | ||
| 244 | |||
| 245 | To import only, without exact dedup: | ||
| 246 | |||
| 247 | ```bash | ||
| 248 | python scripts/import_library_postgres.py \ | ||
| 249 | --dsn postgresql:///lyric_dedup \ | ||
| 250 | --lyrics-dir data/library \ | ||
| 251 | --limit 1000 \ | ||
| 252 | --skip-dedup-exact | ||
| 253 | ``` | ||
| 254 | |||
| 255 | Check the counts: | ||
| 256 | |||
| 257 | ```bash | ||
| 258 | psql postgresql:///lyric_dedup -c 'select count(*) from lyrics;' | ||
| 259 | psql postgresql:///lyric_dedup -c 'select count(*) from lyric_lines;' | ||
| 260 | ``` | ||
| 261 | |||
| 262 | View a few rows: | ||
| 263 | |||
| 264 | ```bash | ||
| 265 | psql postgresql:///lyric_dedup -c \ | ||
| 266 | 'select id, record_id, title, artist, line_count from lyrics limit 5;' | ||
| 267 | ``` | ||
| 268 | |||
| 269 | ## 11. Full import | ||
| 270 | |||
| 271 | Once the small batch looks fine, import the full library: | ||
| 272 | |||
| 273 | ```bash | ||
| 274 | python scripts/import_library_postgres.py \ | ||
| 275 | --dsn postgresql:///lyric_dedup \ | ||
| 276 | --lyrics-dir data/library | ||
| 277 | ``` | ||
| 278 | |||
| 279 | The script prints progress: | ||
| 280 | |||
| 281 | ```text | ||
| 282 | [pg-import] files: 70295 | ||
| 283 | [pg-import] import: 500/70295 | ||
| 284 | ... | ||
| 285 | ``` | ||
| 286 | |||
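| 287 | The import is an upsert: re-importing the same `record_id` updates the row instead of inserting a duplicate. | ||
| 288 | |||
| 289 | To additionally generate a report of "near-duplicate candidates with high line coverage" without deleting anything automatically: | ||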
| 290 | |||
| 291 | ```bash | ||
| 292 | python scripts/import_library_postgres.py \ | ||
| 293 | --dsn postgresql:///lyric_dedup \ | ||
| 294 | --lyrics-dir data/library \ | ||
| 295 | --line-duplicate-report outputs/results/postgres_line_duplicates.csv \ | ||
| 296 | --line-coverage-threshold 0.95 | ||
| 297 | ``` | ||
| 298 | |||
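| 299 | Note: the line-coverage near-duplicate report can be slow and is only meant for spot checks. The current script does not automatically soft-delete these near-duplicate candidates. | ||
| 300 | |||
| 301 | ## 12. Basic SQL verification | ||
| 302 | |||
| 303 | ### exact hash duplicates | ||
| 304 | |||
| 305 | Find duplicated normalized hashes: | ||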
| 306 | |||
| 307 | ```bash | ||
| 308 | psql postgresql:///lyric_dedup -c " | ||
| 309 | select exact_hash, count(*) | ||
| 310 | from lyrics | ||
| 311 | where deleted_at is null | ||
| 312 | group by exact_hash | ||
| 313 | having count(*) > 1 | ||
| 314 | order by count(*) desc | ||
| 315 | limit 20; | ||
| 316 | " | ||
| 317 | ``` | ||
| 318 | |||
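| 319 | If the import ran without `--skip-dedup-exact`, no active exact duplicates should appear here in theory; records that were already cleaned up as duplicates can be counted like this: | ||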
| 320 | |||
| 321 | ```bash | ||
| 322 | psql postgresql:///lyric_dedup -c " | ||
| 323 | select count(*) | ||
| 324 | from lyrics | ||
| 325 | where deleted_at is not null; | ||
| 326 | " | ||
| 327 | ``` | ||
| 328 | |||
| 329 | ### pg_trgm similarity query | ||
| 330 | |||
| 331 | Test `pg_trgm`: | ||
| 332 | |||
| 333 | ```bash | ||
| 334 | psql postgresql:///lyric_dedup -c " | ||
| 335 | select id, title, similarity(primary_text, '我爱你在每个夜里') as sim | ||
| 336 | from lyrics | ||
| 337 | where primary_text % '我爱你在每个夜里' | ||
| 338 | order by sim desc | ||
| 339 | limit 10; | ||
| 340 | " | ||
| 341 | ``` | ||
| 342 | |||
| 343 | ### Line-level overlap | ||
| 344 | |||
| 345 | Find which songs a given line appears in: | ||
| 346 | |||
| 347 | ```bash | ||
| 348 | psql postgresql:///lyric_dedup -c " | ||
| 349 | select l.id, l.title, ll.normalized_line | ||
| 350 | from lyric_lines ll | ||
| 351 | join lyrics l on l.id = ll.lyric_id | ||
| 352 | where ll.normalized_line = '我爱你在每个夜里' | ||
| 353 | limit 20; | ||
| 354 | " | ||
| 355 | ``` | ||
| 356 | |||
| 357 | ## 13. How future dedup queries should work | ||
| 358 | |||
| 359 | Planned PostgreSQL-based dedup flow: | ||
| 360 | |||
| 361 | ```text | ||
| 362 | 1. Python reads the incoming lyrics | ||
| 363 | 2. normalize_lyrics | ||
| 364 | 3. SQL exact_hash recall | ||
| 365 | 4. SQL pg_trgm recall | ||
| 366 | 5. SQL lyric_lines line-level recall | ||
| 367 | 6. Merge candidate ids | ||
| 368 | 7. Fetch normalized data for the candidates | ||
| 369 | 8. Python reuses the current scoring rules | ||
| 370 | 9. Output duplicate / review / new | ||
| 371 | ``` | ||
| 372 | |||
| 373 | Illustrative SQL: | ||
| 374 | |||
| 375 | ```sql | ||
| 376 | select id | ||
| 377 | from lyrics | ||
| 378 | where exact_hash = $1 | ||
| 379 | and deleted_at is null; | ||
| 380 | ``` | ||
| 381 | |||
| 382 | ```sql | ||
| 383 | select id, similarity(primary_text, $1) as sim | ||
| 384 | from lyrics | ||
| 385 | where deleted_at is null | ||
| 386 | and primary_text % $1 | ||
| 387 | order by sim desc | ||
| 388 | limit 200; | ||
| 389 | ``` | ||
| 390 | |||
| 391 | ```sql | ||
| 392 | select lyric_id, count(*) as matched_lines | ||
| 393 | from lyric_lines | ||
| 394 | where role = 'primary' | ||
| 395 | and line_hash = any($1) | ||
| 396 | group by lyric_id | ||
| 397 | order by matched_lines desc | ||
| 398 | limit 200; | ||
| 399 | ``` | ||
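The candidate-merging step of the flow could be sketched along these lines. The function only merges id lists that the three queries have already returned; `merge_candidates` and its `limit` parameter are hypothetical names, and the priority order (exact, then trigram, then line-level) is an assumption rather than the project's documented behavior.

```python
from __future__ import annotations


def merge_candidates(
    exact_ids: list[int],
    trgm_ids: list[int],
    line_ids: list[int],
    limit: int = 200,
) -> list[int]:
    """Merge candidate ids from the three recall queries, dropping repeats.

    Order is preserved within each source; sources are consulted in the
    order exact -> trigram -> line-level (an assumed priority).
    """
    merged: list[int] = []
    seen: set[int] = set()
    for source in (exact_ids, trgm_ids, line_ids):
        for lyric_id in source:
            if lyric_id not in seen:
                seen.add(lyric_id)
                merged.append(lyric_id)
    return merged[:limit]
```

The merged list is what the Python side would then fetch normalized data for and re-score, so the database never decides `duplicate` on its own.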
| 400 | |||
| 401 | ## 14. Incremental update design | ||
| 402 | |||
| 403 | Adding a song: | ||
| 404 | |||
| 405 | ```text | ||
| 406 | 1. normalize | ||
| 407 | 2. Recall candidates with PostgreSQL | ||
| 408 | 3. Python makes the decision | ||
| 409 | 4. duplicate: reject or link to the existing record | ||
| 410 | 5. review: send to manual review | ||
| 411 | 6. new: write to lyrics and lyric_lines | ||
| 412 | ``` | ||
| 413 | |||
| 414 | Deleting a song: | ||
| 415 | |||
| 416 | ```sql | ||
| 417 | update lyrics | ||
| 418 | set deleted_at = now(), updated_at = now() | ||
| 419 | where id = ...; | ||
| 420 | ``` | ||
| 421 | |||
| 422 | Physical deletion is not recommended unless you are sure auditing is not needed. | ||
| 423 | |||
| 424 | Updating a song: | ||
| 425 | |||
| 426 | ```text | ||
| 427 | 1. Update lyrics.raw_text / normalized_text / primary_text / exact_hash | ||
| 428 | 2. Delete the old lyric_lines | ||
| 429 | 3. Insert the new lyric_lines | ||
| 430 | 4. Run the whole process in a single transaction | ||
| 431 | ``` | ||
| 432 | |||
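| 433 | ## 15. PostgreSQL-based evaluation | ||
| 434 | |||
| 435 | Evaluation still requires generating a test set first. The test set is "input samples + expected labels"; the PostgreSQL-based evaluation only recalls candidates from the PostgreSQL database and computes metrics. | ||
| 436 | |||
| 437 | If you do not have a test set yet, generate one first: | ||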
| 438 | |||
| 439 | ```bash | ||
| 440 | python -m lyric_dedup.cli generate-eval-set \ | ||
| 441 | --library-dir data/library \ | ||
| 442 | --lyrics-dir data/generated_eval/incoming \ | ||
| 443 | --csv data/generated_eval/eval_5000.csv \ | ||
| 444 | --size 5000 \ | ||
| 445 | --positive-ratio 0.3 | ||
| 446 | ``` | ||
| 447 | |||
| 448 | Then run the PostgreSQL-based evaluation: | ||
| 449 | |||
| 450 | ```bash | ||
| 451 | python scripts/evaluate_postgres.py \ | ||
| 452 | --dsn postgresql:///lyric_dedup \ | ||
| 453 | --csv data/generated_eval/eval_5000.csv \ | ||
| 454 | --base-dir data/generated_eval \ | ||
| 455 | --out outputs/results/postgres_eval_5000.csv | ||
| 456 | ``` | ||
| 457 | |||
| 458 | It will: | ||
| 459 | |||
| 460 | ```text | ||
| 461 | 1. Normalize the eval samples | ||
| 462 | 2. Recall via PostgreSQL exact_hash | ||
| 463 | 3. Recall via pg_trgm on primary_text (only when --enable-trgm is set) | ||
| 464 | 4. Recall via lyric_lines.line_hash | ||
| 465 | 5. Merge candidates | ||
| 466 | 6. Re-score the candidates with the Python DuplicateChecker | ||
| 467 | 7. Output duplicate / review / new and the metrics | ||
| 468 | ``` | ||
| 469 | |||
| 470 | To also count `review` as "caught a suspicious sample": | ||
| 471 | |||
| 472 | ```bash | ||
| 473 | python scripts/evaluate_postgres.py \ | ||
| 474 | --dsn postgresql:///lyric_dedup \ | ||
| 475 | --csv data/generated_eval/eval_50000.csv \ | ||
| 476 | --base-dir data/generated_eval \ | ||
| 477 | --positive-decisions duplicate,review \ | ||
| 478 | --out outputs/results/postgres_eval_50000_review_positive.csv | ||
| 479 | ``` | ||
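To make the effect of `--positive-decisions` concrete, here is a sketch of how precision and recall shift depending on which decisions count as positive. The function name, the boolean `expected_duplicate` label, and the tuple layout are assumptions; the actual metric code in `evaluate_postgres.py` is not shown in this document.

```python
from __future__ import annotations


def precision_recall(
    results: list[tuple[bool, str]],
    positive_decisions: frozenset[str] = frozenset({"duplicate"}),
) -> tuple[float, float]:
    """Compute (precision, recall) for (expected_duplicate, decision) pairs."""
    tp = fp = fn = 0
    for expected_duplicate, decision in results:
        predicted = decision in positive_decisions
        if predicted and expected_duplicate:
            tp += 1
        elif predicted:
            fp += 1
        elif expected_duplicate:
            fn += 1
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall
```

Treating `review` as positive typically raises recall at the cost of precision, since negatives that land in `review` then count as false positives.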
| 480 | |||
| 481 | Tunable parameters: | ||
| 482 | |||
| 483 | ```text | ||
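| 484 | --recall-limit maximum number of candidates per SQL recall type, default 100 | ||
| 485 | --enable-trgm enable pg_trgm whole-text recall; off by default to keep evaluation from being too slow | ||
| 486 | --trgm-threshold match threshold for the pg_trgm % operator, default 0.3, used only with --enable-trgm | ||
| 487 | --max-candidates number of candidates in the final output, default 5 | ||
| 488 | --statement-timeout-ms timeout for a single SQL statement, default 5000 | ||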
| 489 | ``` | ||
| 490 | |||
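| 491 | Note: the current PostgreSQL version is a prototype evaluation script. By default it recalls only via `exact_hash + lyric_lines.line_hash`, which keeps speed more predictable. `pg_trgm` can serve as supplementary recall, but whole-lyric trigram queries can be very slow on the 50,000-sample evaluation set, so validate it on a small sample first before using it on the full set. | ||
| 492 | |||
| 493 | ## 16. Migration verification criteria | ||
| 494 | |||
| 495 | Migration is not finished once the import completes. The PostgreSQL-based dedup pipeline needs to be verified separately: | ||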
| 496 | |||
| 497 | ```text | ||
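| 498 | 1. Whether exact duplicates are found | ||
| 499 | 2. Whether punctuation / timestamp / platform noise positives are recalled | ||
| 500 | 3. Whether fragment / shared chorus negatives avoid being judged duplicate directly | ||
| 501 | 4. Whether the number of candidates recalled by PostgreSQL is reasonable | ||
| 502 | 5. Whether the PostgreSQL evaluate metrics meet business requirements | ||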
| 503 | ``` | ||
| 504 | |||
| 505 | Phase-one goal: | ||
| 506 | |||
| 507 | ```text | ||
| 508 | PostgreSQL handles recall; Python still makes the decision. | ||
| 509 | ``` | ||
| 510 | |||
| 511 | ## 17. FAQ | ||
| 512 | |||
| 513 | ### Error: `Missing dependency: psycopg` | ||
| 514 | |||
| 515 | Run: | ||
| 516 | |||
| 517 | ```bash | ||
| 518 | python -m pip install 'psycopg[binary]' | ||
| 519 | ``` | ||
| 520 | |||
| 521 | ### Connection failure | ||
| 522 | |||
| 523 | Check whether PostgreSQL is running: | ||
| 524 | |||
| 525 | ```bash | ||
| 526 | pg_isready | ||
| 527 | ``` | ||
| 528 | |||
| 529 | Check whether the database exists: | ||
| 530 | |||
| 531 | ```bash | ||
| 532 | psql -l | grep lyric_dedup | ||
| 533 | ``` | ||
| 534 | |||
| 535 | ### `pg_trgm` creation fails | ||
| 536 | |||
| 537 | Make sure the connecting user has permission to create extensions. The default local user usually does. | ||
| 538 | |||
| 539 | Test manually: | ||
| 540 | |||
| 541 | ```bash | ||
| 542 | psql postgresql:///lyric_dedup -c 'create extension if not exists pg_trgm;' | ||
| 543 | ``` | ||
| 544 | |||
| 545 | ### Clearing everything and re-importing | ||
| 546 | |||
| 547 | Run with caution: | ||
| 548 | |||
| 549 | ```bash | ||
| 550 | psql postgresql:///lyric_dedup -c 'truncate lyric_lines, lyrics restart identity cascade;' | ||
| 551 | ``` | ||
| 552 | |||
| 553 | Then re-run the import script. | ||
| 554 | |||
| 555 | ## 18. Recommended execution order | ||
| 556 | |||
| 557 | You have already completed: | ||
| 558 | |||
| 559 | ```bash | ||
| 560 | createdb lyric_dedup | ||
| 561 | ``` | ||
| 562 | |||
| 563 | Next, run: | ||
| 564 | |||
| 565 | ```bash | ||
| 566 | python -m pip install 'psycopg[binary]' | ||
| 567 | ``` | ||
| 568 | |||
| 569 | ```bash | ||
| 570 | python scripts/init_postgres.py \ | ||
| 571 | --dsn postgresql:///lyric_dedup | ||
| 572 | ``` | ||
| 573 | |||
| 574 | ```bash | ||
| 575 | python scripts/import_library_postgres.py \ | ||
| 576 | --dsn postgresql:///lyric_dedup \ | ||
| 577 | --lyrics-dir data/library \ | ||
| 578 | --limit 1000 | ||
| 579 | ``` | ||
| 580 | |||
| 581 | Confirm the count: | ||
| 582 | |||
| 583 | ```bash | ||
| 584 | psql postgresql:///lyric_dedup -c 'select count(*) from lyrics;' | ||
| 585 | ``` | ||
| 586 | |||
| 587 | After confirming, run the full import: | ||
| 588 | |||
| 589 | ```bash | ||
| 590 | python scripts/import_library_postgres.py \ | ||
| 591 | --dsn postgresql:///lyric_dedup \ | ||
| 592 | --lyrics-dir data/library | ||
| 593 | ``` |
| ... | @@ -85,15 +85,33 @@ python -m lyric_dedup.cli generate-eval-set \ | ... | @@ -85,15 +85,33 @@ python -m lyric_dedup.cli generate-eval-set \ |
| 85 | --positive-ratio 0.3 | 85 | --positive-ratio 0.3 |
| 86 | ``` | 86 | ``` |
| 87 | 87 | ||
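| 88 | Business rules for the generator: | 88 | The default `--profile standard` generates the regular production evaluation set. You can also generate a hard set that is closer to business boundaries: |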
| 89 | |||
| 90 | ```bash | ||
| 91 | python -m lyric_dedup.cli generate-eval-set \ | ||
| 92 | --profile hard \ | ||
| 93 | --library-dir data/library \ | ||
| 94 | --lyrics-dir data/generated_eval/hard_incoming \ | ||
| 95 | --csv data/generated_eval/eval_hard_5000.csv \ | ||
| 96 | --eval-index data/generated_eval/eval_hard_5000.csv.index.pkl \ | ||
| 97 | --size 5000 \ | ||
| 98 | --positive-ratio 0.3 | ||
| 99 | ``` | ||
| 100 | |||
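| 101 | Business rules for the standard profile: | ||
| 89 | 102 | |
| 90 | - Scan the whole library first and do stratified sampling by valid lyric line count, language type, and file-source prefix, instead of sampling by sorted prefix. | 103 | - Scan the whole library first and do stratified sampling by valid lyric line count, language type, and file-source prefix, instead of sampling by sorted prefix. |
| 91 | - `应去重` (should deduplicate) samples only generate style variations of full-song lyrics, e.g. timestamps, punctuation, platform noise, blank lines, changes in the number of chorus repetitions, and appended Chinese translation. | 104 | - `应去重` (should deduplicate) samples only generate style variations of full-song lyrics, e.g. timestamps, punctuation, platform noise, blank lines, changes in the number of chorus repetitions, appended Chinese translation, and a small number of typos / English misspellings. |
| 92 | - `不应去重` (should not deduplicate) samples are mainly real holdout full lyrics, and also include lyric fragments, repeated-chorus collisions, translation-only similarity, new lyrics on the same theme, and short-lyric/placeholder edge samples. | 105 | - `不应去重` (should not deduplicate) samples are mainly real holdout full lyrics, and also include lyric fragments, repeated-chorus collisions, translation-only similarity, new lyrics on the same theme, and short-lyric/placeholder edge samples. |
| 93 | - A lyric fragment should not produce `duplicate` even if it matches part of an existing song; at most it goes to `review`. | 106 | - A lyric fragment should not produce `duplicate` even if it matches part of an existing song; at most it goes to `review`. |
| 94 | - The generator additionally writes `--eval-index`; this index excludes the holdout songs and should be used when evaluating the generated CSV. | 107 | - The generator additionally writes `--eval-index`; this index excludes the holdout songs and should be used when evaluating the generated CSV. |
| 95 | - It also generates `*.manifest.json`, recording the seed, library size, holdout count, sample-type distribution, language/source buckets, and sample-source coverage counts. | 108 | - It also generates `*.manifest.json`, recording the seed, library size, holdout count, sample-type distribution, language/source buckets, and sample-source coverage counts. |
| 96 | 109 | |
| 110 | The hard profile does not deliberately create abnormal input; it mainly covers cases that are more likely to hit boundaries in production: | ||
| 111 | |||
| 112 | - `应去重`: same-song platform-version noise, fairly complete lyrics missing one section, appended whole-passage Chinese translation, fairly realistic data-entry/OCR typos, and mixed timestamps and platform metadata. | ||
| 113 | - `不应去重`: real holdout new songs, near-neighbor new songs preferentially picked from the holdout for having line overlap with the library, long but incomplete single-song fragments, multi-song medley-style fragments, repeated-chorus collisions, translation-only similarity, and short-lyric edge cases. | ||
| 114 | |||
| 97 | First prepare a CSV, e.g. `data/eval/eval.csv`: | 115 | First prepare a CSV, e.g. `data/eval/eval.csv`: |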
| 98 | 116 | ||
| 99 | ```csv | 117 | ```csv | ... | ... |
| ... | @@ -108,6 +108,20 @@ python -m lyric_dedup.cli generate-eval-set \ | ... | @@ -108,6 +108,20 @@ python -m lyric_dedup.cli generate-eval-set \ |
| 108 | --positive-ratio 0.3 | 108 | --positive-ratio 0.3 |
| 109 | ``` | 109 | ``` |
| 110 | 110 | ||
| 111 | To generate a hard-profile test set that is closer to business boundaries: | ||
| 112 | |||
| 113 | ```bash | ||
| 114 | python -m lyric_dedup.cli generate-eval-set \ | ||
| 115 | --profile hard \ | ||
| 116 | --library-dir data/library \ | ||
| 117 | --lyrics-dir data/generated_eval/hard_incoming \ | ||
| 118 | --csv data/generated_eval/eval_hard_5000.csv \ | ||
| 119 | --index outputs/indexes/library_lyrics.pkl \ | ||
| 120 | --eval-index data/generated_eval/eval_hard_5000.csv.index.pkl \ | ||
| 121 | --size 5000 \ | ||
| 122 | --positive-ratio 0.3 | ||
| 123 | ``` | ||
| 124 | |||
| 111 | Default production evaluation rules: | 125 | Default production evaluation rules: |
| 112 | 126 | ||
| 113 | ```text | 127 | ```text |
| ... | @@ -120,7 +134,7 @@ python -m lyric_dedup.cli generate-eval-set \ | ... | @@ -120,7 +134,7 @@ python -m lyric_dedup.cli generate-eval-set \ |
| 120 | Business rules: | 134 | Business rules: |
| 121 | 135 | ||
| 122 | ```text | 136 | ```text |
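| 123 | positive_* = should deduplicate, style variations of full-song lyrics | 137 | positive_* = should deduplicate, style variations of full-song lyrics, including small typo / English misspelling perturbations |
| 124 | negative_real_holdout_full_song = should not deduplicate, complete real lyrics, already excluded from the evaluation index | 138 | negative_real_holdout_full_song = should not deduplicate, complete real lyrics, already excluded from the evaluation index |
| 125 | negative_fragment = should not deduplicate, single-song fragment | 139 | negative_fragment = should not deduplicate, single-song fragment |
| 126 | negative_shared_chorus = should not deduplicate, repeated-chorus collision | 140 | negative_shared_chorus = should not deduplicate, repeated-chorus collision |
| ... | @@ -129,6 +143,15 @@ negative_same_theme_synthetic = should not deduplicate, new lyrics on the same theme | ... | @@ -129,6 +143,15 @@ negative_same_theme_synthetic = should not deduplicate, new lyrics on the same theme |
| 129 | edge_short_or_placeholder = should not deduplicate, short-lyric/placeholder edge sample | 143 | edge_short_or_placeholder = should not deduplicate, short-lyric/placeholder edge sample |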
| 130 | ``` | 144 | ``` |
| 131 | 145 | ||
| 146 | The hard profile additionally emphasizes real business boundaries rather than deliberately contrived hard cases: | ||
| 147 | |||
| 148 | ```text | ||
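| 149 | positive_realistic_variant = should deduplicate, same-song platform-version noise, fairly complete lyrics with a missing section, appended whole-passage translation, real data-entry/OCR errors | ||
| 150 | negative_near_neighbor_holdout_full_song = should not deduplicate, real holdout new song with substantial line overlap with the library | ||
| 151 | negative_long_fragment = should not deduplicate, long but incomplete single-song fragment | ||
| 152 | negative_catalog_mashup = should not deduplicate, medley/mashup-style input composed of fragments from multiple real lyrics | ||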
| 153 | ``` | ||
| 154 | |||
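| 132 | The generator scans the whole library and does stratified sampling by valid lyric line count, language type, and file-source prefix. It sets aside a batch of holdout full lyrics as real new-song negative samples and builds an evaluation index that excludes the holdout. Each run also outputs: | 155 | The generator scans the whole library and does stratified sampling by valid lyric line count, language type, and file-source prefix. It sets aside a batch of holdout full lyrics as real new-song negative samples and builds an evaluation index that excludes the holdout. Each run also outputs: |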
| 133 | 156 | ||
| 134 | ```text | 157 | ```text | ... | ... |
| ... | @@ -5,7 +5,7 @@ from __future__ import annotations | ... | @@ -5,7 +5,7 @@ from __future__ import annotations |
| 5 | import hashlib | 5 | import hashlib |
| 6 | import pickle | 6 | import pickle |
| 7 | from dataclasses import dataclass | 7 | from dataclasses import dataclass |
| 8 | from enum import StrEnum | 8 | from enum import Enum |
| 9 | from pathlib import Path | 9 | from pathlib import Path |
| 10 | 10 | ||
| 11 | from lyric_dedup.minhash_lsh import MinHashConfig | 11 | from lyric_dedup.minhash_lsh import MinHashConfig |
| ... | @@ -16,7 +16,7 @@ from lyric_dedup.normalization import lyric_tokens | ... | @@ -16,7 +16,7 @@ from lyric_dedup.normalization import lyric_tokens |
| 16 | from lyric_dedup.normalization import normalize_lyrics | 16 | from lyric_dedup.normalization import normalize_lyrics |
| 17 | 17 | ||
| 18 | 18 | ||
| 19 | class DuplicateDecision(StrEnum): | 19 | class DuplicateDecision(str, Enum): |
| 20 | DUPLICATE = "duplicate" | 20 | DUPLICATE = "duplicate" |
| 21 | REVIEW = "review" | 21 | REVIEW = "review" |
| 22 | NEW = "new" | 22 | NEW = "new" | ... | ... |
| ... | @@ -53,6 +53,12 @@ def main() -> None: | ... | @@ -53,6 +53,12 @@ def main() -> None: |
| 53 | generate.add_argument("--seed", type=int, default=20260602) | 53 | generate.add_argument("--seed", type=int, default=20260602) |
| 54 | generate.add_argument("--index", default="", help="optional source index path recorded in the manifest") | 54 | generate.add_argument("--index", default="", help="optional source index path recorded in the manifest") |
| 55 | generate.add_argument("--eval-index", default="", help="output index built from non-holdout records for this eval set") | 55 | generate.add_argument("--eval-index", default="", help="output index built from non-holdout records for this eval set") |
| 56 | generate.add_argument( | ||
| 57 | "--profile", | ||
| 58 | choices=("standard", "hard"), | ||
| 59 | default="standard", | ||
| 60 | help="evaluation sample profile: standard production mix or harder business-realistic edge mix", | ||
| 61 | ) | ||
| 56 | 62 | ||
| 57 | args = parser.parse_args() | 63 | args = parser.parse_args() |
| 58 | if args.command == "build-index": | 64 | if args.command == "build-index": |
| ... | @@ -80,6 +86,7 @@ def main() -> None: | ... | @@ -80,6 +86,7 @@ def main() -> None: |
| 80 | seed=args.seed, | 86 | seed=args.seed, |
| 81 | index_path=Path(args.index) if args.index else None, | 87 | index_path=Path(args.index) if args.index else None, |
| 82 | eval_index_path=Path(args.eval_index) if args.eval_index else None, | 88 | eval_index_path=Path(args.eval_index) if args.eval_index else None, |
| 89 | profile=args.profile, | ||
| 83 | ) | 90 | ) |
| 84 | print(json.dumps(summary, ensure_ascii=False)) | 91 | print(json.dumps(summary, ensure_ascii=False)) |
| 85 | 92 | ... | ... |
| ... | @@ -21,7 +21,7 @@ from lyric_dedup.normalization import fingerprint_text | ... | @@ -21,7 +21,7 @@ from lyric_dedup.normalization import fingerprint_text |
| 21 | from lyric_dedup.normalization import normalize_lyrics | 21 | from lyric_dedup.normalization import normalize_lyrics |
| 22 | 22 | ||
| 23 | 23 | ||
| 24 | DEFAULT_SAMPLE_MIX = { | 24 | STANDARD_SAMPLE_MIX = { |
| 25 | "positive_full_duplicate": 0.30, | 25 | "positive_full_duplicate": 0.30, |
| 26 | "negative_real_holdout_full_song": 0.40, | 26 | "negative_real_holdout_full_song": 0.40, |
| 27 | "negative_fragment": 0.10, | 27 | "negative_fragment": 0.10, |
| ... | @@ -30,6 +30,18 @@ DEFAULT_SAMPLE_MIX = { | ... | @@ -30,6 +30,18 @@ DEFAULT_SAMPLE_MIX = { |
| 30 | "negative_same_theme_synthetic": 0.05, | 30 | "negative_same_theme_synthetic": 0.05, |
| 31 | "edge_short_or_placeholder": 0.05, | 31 | "edge_short_or_placeholder": 0.05, |
| 32 | } | 32 | } |
| 33 | DEFAULT_SAMPLE_MIX = STANDARD_SAMPLE_MIX | ||
| 34 | |||
| 35 | HARD_SAMPLE_MIX = { | ||
| 36 | "positive_realistic_variant": 0.30, | ||
| 37 | "negative_real_holdout_full_song": 0.20, | ||
| 38 | "negative_near_neighbor_holdout_full_song": 0.20, | ||
| 39 | "negative_long_fragment": 0.15, | ||
| 40 | "negative_shared_chorus": 0.05, | ||
| 41 | "negative_translation_only": 0.04, | ||
| 42 | "negative_catalog_mashup": 0.04, | ||
| 43 | "edge_short_or_placeholder": 0.02, | ||
| 44 | } | ||
| 33 | 45 | ||
| 34 | 46 | ||
| 35 | def _progress(message: str) -> None: | 47 | def _progress(message: str) -> None: |
| ... | @@ -87,6 +99,7 @@ def generate_eval_set( | ... | @@ -87,6 +99,7 @@ def generate_eval_set( |
| 87 | seed: int = 20260602, | 99 | seed: int = 20260602, |
| 88 | index_path: Path | None = None, | 100 | index_path: Path | None = None, |
| 89 | eval_index_path: Path | None = None, | 101 | eval_index_path: Path | None = None, |
| 102 | profile: str = "standard", | ||
| 90 | ) -> dict[str, object]: | 103 | ) -> dict[str, object]: |
| 91 | """Generate a stratified production evaluation set. | 104 | """Generate a stratified production evaluation set. |
| 92 | 105 | ||
| ... | @@ -96,7 +109,10 @@ def generate_eval_set( | ... | @@ -96,7 +109,10 @@ def generate_eval_set( |
| 96 | if size <= 0: | 109 | if size <= 0: |
| 97 | raise ValueError("size must be positive") | 110 | raise ValueError("size must be positive") |
| 98 | 111 | ||
| 99 | _progress(f"start generation: size={size}, positive_ratio={positive_ratio}, seed={seed}") | 112 | if profile not in {"standard", "hard"}: |
| 113 | raise ValueError("profile must be 'standard' or 'hard'") | ||
| 114 | |||
| 115 | _progress(f"start generation: profile={profile}, size={size}, positive_ratio={positive_ratio}, seed={seed}") | ||
| 100 | rng = random.Random(seed) | 116 | rng = random.Random(seed) |
| 101 | profiles = profile_library(library_dir) | 117 | profiles = profile_library(library_dir) |
| 102 | if not profiles: | 118 | if not profiles: |
| ... | @@ -107,9 +123,9 @@ def generate_eval_set( | ... | @@ -107,9 +123,9 @@ def generate_eval_set( |
| 107 | _progress(f"clean output dir: {output_dir}") | 123 | _progress(f"clean output dir: {output_dir}") |
| 108 | _clean_generated_output_dir(output_dir) | 124 | _clean_generated_output_dir(output_dir) |
| 109 | 125 | ||
| 110 | plan = _sample_plan(size, positive_ratio=positive_ratio) | 126 | plan = _sample_plan(size, positive_ratio=positive_ratio, profile=profile) |
| 111 | _progress(f"sample plan: {plan}") | 127 | _progress(f"sample plan: {plan}") |
| 112 | holdout_count = min(plan["negative_real_holdout_full_song"], max(1, len(profiles) // 2)) | 128 | holdout_count = min(_holdout_plan_count(plan), max(1, len(profiles) // 2)) |
| 113 | holdout_profiles = _stratified_unique_sample( | 129 | holdout_profiles = _stratified_unique_sample( |
| 114 | profiles, | 130 | profiles, |
| 115 | holdout_count, | 131 | holdout_count, |
| ... | @@ -122,81 +138,95 @@ def generate_eval_set( | ... | @@ -122,81 +138,95 @@ def generate_eval_set( |
| 122 | groups = _profile_groups(indexed_profiles) | 138 | groups = _profile_groups(indexed_profiles) |
| 123 | samples: list[GeneratedSample] = [] | 139 | samples: list[GeneratedSample] = [] |
| 124 | 140 | ||
| 125 | _progress("build positive_full_duplicate samples") | 141 | if profile == "hard": |
| 126 | samples.extend( | 142 | samples.extend( |
| 127 | _build_positive_samples( | 143 | _build_hard_samples( |
| 128 | _stratified_sample(groups["normal"], plan["positive_full_duplicate"], rng), | 144 | plan, |
| 129 | output_dir, | 145 | groups=groups, |
| 130 | csv_path.parent, | 146 | holdout_profiles=holdout_profiles, |
| 131 | rng, | 147 | indexed_profiles=indexed_profiles, |
| 132 | start_index=len(samples) + 1, | 148 | output_dir=output_dir, |
| 149 | csv_base=csv_path.parent, | ||
| 150 | rng=rng, | ||
| 151 | start_index=len(samples) + 1, | ||
| 152 | ) | ||
| 133 | ) | 153 | ) |
| 134 | ) | 154 | else: |
| 135 | _progress(f"built samples: {len(samples)}/{size}") | 155 | _progress("build positive_full_duplicate samples") |
| 136 | _progress("build negative_real_holdout_full_song samples") | 156 | samples.extend( |
| 137 | samples.extend( | 157 | _build_positive_samples( |
| 138 | _build_holdout_full_song_samples( | 158 | _stratified_sample(groups["normal"], plan["positive_full_duplicate"], rng), |
| 139 | holdout_profiles, | 159 | output_dir, |
| 140 | output_dir, | 160 | csv_path.parent, |
| 141 | csv_path.parent, | 161 | rng, |
| 142 | start_index=len(samples) + 1, | 162 | start_index=len(samples) + 1, |
| 163 | ) | ||
| 143 | ) | 164 | ) |
| 144 | ) | 165 | _progress(f"built samples: {len(samples)}/{size}") |
| 145 | _progress(f"built samples: {len(samples)}/{size}") | 166 | _progress("build negative_real_holdout_full_song samples") |
| 146 | _progress("build negative_fragment samples") | 167 | samples.extend( |
| 147 | samples.extend( | 168 | _build_holdout_full_song_samples( |
| 148 | _build_fragment_samples( | 169 | holdout_profiles[: plan["negative_real_holdout_full_song"]], |
| 149 | _stratified_sample(groups["fragmentable"], plan["negative_fragment"], rng), | 170 | output_dir, |
| 150 | output_dir, | 171 | csv_path.parent, |
| 151 | csv_path.parent, | 172 | start_index=len(samples) + 1, |
| 152 | rng, | 173 | ) |
| 153 | start_index=len(samples) + 1, | ||
| 154 | ) | 174 | ) |
| 155 | ) | 175 | _progress(f"built samples: {len(samples)}/{size}") |
| 156 | _progress(f"built samples: {len(samples)}/{size}") | 176 | _progress("build negative_fragment samples") |
| 157 | _progress("build negative_shared_chorus samples") | 177 | samples.extend( |
| 158 | samples.extend( | 178 | _build_fragment_samples( |
| 159 | _build_shared_chorus_samples( | 179 | _stratified_sample(groups["fragmentable"], plan["negative_fragment"], rng), |
| 160 | _stratified_sample(groups["normal"], plan["negative_shared_chorus"], rng), | 180 | output_dir, |
| 161 | output_dir, | 181 | csv_path.parent, |
| 162 | csv_path.parent, | 182 | rng, |
| 163 | rng, | 183 | start_index=len(samples) + 1, |
| 164 | start_index=len(samples) + 1, | 184 | ) |
| 165 | ) | 185 | ) |
| 166 | ) | 186 | _progress(f"built samples: {len(samples)}/{size}") |
| 167 | _progress(f"built samples: {len(samples)}/{size}") | 187 | _progress("build negative_shared_chorus samples") |
| 168 | _progress("build negative_translation_only samples") | 188 | samples.extend( |
| 169 | samples.extend( | 189 | _build_shared_chorus_samples( |
| 170 | _build_translation_only_samples( | 190 | _stratified_sample(groups["normal"], plan["negative_shared_chorus"], rng), |
| 171 | _stratified_sample(groups["foreign"], plan["negative_translation_only"], rng), | 191 | output_dir, |
| 172 | output_dir, | 192 | csv_path.parent, |
| 173 | csv_path.parent, | 193 | rng, |
| 174 | rng, | 194 | start_index=len(samples) + 1, |
| 175 | start_index=len(samples) + 1, | 195 | ) |
| 176 | ) | 196 | ) |
| 177 | ) | 197 | _progress(f"built samples: {len(samples)}/{size}") |
| 178 | _progress(f"built samples: {len(samples)}/{size}") | 198 | _progress("build negative_translation_only samples") |
| 179 | _progress("build negative_same_theme_synthetic samples") | 199 | samples.extend( |
| 180 | samples.extend( | 200 | _build_translation_only_samples( |
| 181 | _build_same_theme_synthetic_samples( | 201 | _stratified_sample(groups["foreign"], plan["negative_translation_only"], rng), |
| 182 | plan["negative_same_theme_synthetic"], | 202 | output_dir, |
| 183 | output_dir, | 203 | csv_path.parent, |
| 184 | csv_path.parent, | 204 | rng, |
| 185 | rng, | 205 | start_index=len(samples) + 1, |
| 186 | start_index=len(samples) + 1, | 206 | ) |
| 187 | ) | 207 | ) |
| 188 | ) | 208 | _progress(f"built samples: {len(samples)}/{size}") |
| 189 | _progress(f"built samples: {len(samples)}/{size}") | 209 | _progress("build negative_same_theme_synthetic samples") |
| 190 | _progress("build edge_short_or_placeholder samples") | 210 | samples.extend( |
| 191 | samples.extend( | 211 | _build_same_theme_synthetic_samples( |
| 192 | _build_edge_samples( | 212 | plan["negative_same_theme_synthetic"], |
| 193 | _stratified_sample(groups["edge"], plan["edge_short_or_placeholder"], rng), | 213 | output_dir, |
| 194 | output_dir, | 214 | csv_path.parent, |
| 195 | csv_path.parent, | 215 | rng, |
| 196 | rng, | 216 | start_index=len(samples) + 1, |
| 197 | start_index=len(samples) + 1, | 217 | ) |
| 218 | ) | ||
| 219 | _progress(f"built samples: {len(samples)}/{size}") | ||
| 220 | _progress("build edge_short_or_placeholder samples") | ||
| 221 | samples.extend( | ||
| 222 | _build_edge_samples( | ||
| 223 | _stratified_sample(groups["edge"], plan["edge_short_or_placeholder"], rng), | ||
| 224 | output_dir, | ||
| 225 | csv_path.parent, | ||
| 226 | rng, | ||
| 227 | start_index=len(samples) + 1, | ||
| 228 | ) | ||
| 198 | ) | 229 | ) |
| 199 | ) | ||
| 200 | _progress(f"built samples: {len(samples)}/{size}") | 230 | _progress(f"built samples: {len(samples)}/{size}") |
| 201 | 231 | ||
| 202 | if len(samples) < size: | 232 | if len(samples) < size: |
| ... | @@ -226,6 +256,7 @@ def generate_eval_set( | ... | @@ -226,6 +256,7 @@ def generate_eval_set( |
| 226 | index_path=index_path, | 256 | index_path=index_path, |
| 227 | eval_index_path=eval_index_path, | 257 | eval_index_path=eval_index_path, |
| 228 | holdout_count=len(holdout_profiles), | 258 | holdout_count=len(holdout_profiles), |
| 259 | profile=profile, | ||
| 229 | ) | 260 | ) |
| 230 | _progress("generation complete") | 261 | _progress("generation complete") |
| 231 | return manifest | 262 | return manifest |
| ... | @@ -264,14 +295,16 @@ def profile_library(library_dir: Path) -> list[LyricProfile]: | ... | @@ -264,14 +295,16 @@ def profile_library(library_dir: Path) -> list[LyricProfile]: |
| 264 | return profiles | 295 | return profiles |
| 265 | 296 | ||
| 266 | 297 | ||
| 267 | def _sample_plan(size: int, *, positive_ratio: float) -> dict[str, int]: | 298 | def _sample_plan(size: int, *, positive_ratio: float, profile: str) -> dict[str, int]: |
| 268 | positive_ratio = max(0.0, min(1.0, positive_ratio)) | 299 | positive_ratio = max(0.0, min(1.0, positive_ratio)) |
| 269 | mix = dict(DEFAULT_SAMPLE_MIX) | 300 | mix = dict(HARD_SAMPLE_MIX if profile == "hard" else STANDARD_SAMPLE_MIX) |
| 270 | negative_total = sum(value for key, value in mix.items() if key != "positive_full_duplicate") | 301 | positive_key = "positive_realistic_variant" if profile == "hard" else "positive_full_duplicate" |
| 271 | mix["positive_full_duplicate"] = positive_ratio | 302 | negative_total = sum(value for key, value in mix.items() if key != positive_key) |
| 303 | mix[positive_key] = positive_ratio | ||
| 272 | for key in list(mix): | 304 | for key in list(mix): |
| 273 | if key != "positive_full_duplicate": | 305 | if key != positive_key: |
| 274 | mix[key] = (1.0 - positive_ratio) * (DEFAULT_SAMPLE_MIX[key] / negative_total) | 306 | base_mix = HARD_SAMPLE_MIX if profile == "hard" else STANDARD_SAMPLE_MIX |
| 307 | mix[key] = (1.0 - positive_ratio) * (base_mix[key] / negative_total) | ||
| 275 | 308 | ||
| 276 | plan = {key: int(size * value) for key, value in mix.items()} | 309 | plan = {key: int(size * value) for key, value in mix.items()} |
| 277 | remainder = size - sum(plan.values()) | 310 | remainder = size - sum(plan.values()) |
| ... | @@ -283,6 +316,10 @@ def _sample_plan(size: int, *, positive_ratio: float) -> dict[str, int]: | ... | @@ -283,6 +316,10 @@ def _sample_plan(size: int, *, positive_ratio: float) -> dict[str, int]: |
| 283 | return plan | 316 | return plan |
| 284 | 317 | ||
| 285 | 318 | ||
| 319 | def _holdout_plan_count(plan: dict[str, int]) -> int: | ||
| 320 | return plan.get("negative_real_holdout_full_song", 0) + plan.get("negative_near_neighbor_holdout_full_song", 0) | ||
| 321 | |||
| 322 | |||
| 286 | def _profile_groups(profiles: list[LyricProfile]) -> dict[str, list[LyricProfile]]: | 323 | def _profile_groups(profiles: list[LyricProfile]) -> dict[str, list[LyricProfile]]: |
| 287 | normal = [profile for profile in profiles if profile.line_count >= 6] | 324 | normal = [profile for profile in profiles if profile.line_count >= 6] |
| 288 | edge = [profile for profile in profiles if profile.line_count <= 5] | 325 | edge = [profile for profile in profiles if profile.line_count <= 5] |
| ... | @@ -375,6 +412,7 @@ def _build_positive_samples( | ... | @@ -375,6 +412,7 @@ def _build_positive_samples( |
| 375 | ("positive_blank_line_noise", _add_blank_line_noise(lines)), | 412 | ("positive_blank_line_noise", _add_blank_line_noise(lines)), |
| 376 | ("positive_chorus_count_changed", _change_repeated_line_counts(lines)), | 413 | ("positive_chorus_count_changed", _change_repeated_line_counts(lines)), |
| 377 | ("positive_translation_added", _translation_added(lines)), | 414 | ("positive_translation_added", _translation_added(lines)), |
| 415 | ("positive_typo_noise", _add_typo_noise(lines, rng)), | ||
| 378 | ] | 416 | ] |
| 379 | sample_type, text = variants[offset % len(variants)] | 417 | sample_type, text = variants[offset % len(variants)] |
| 380 | index = start_index + offset | 418 | index = start_index + offset |
| ... | @@ -384,31 +422,181 @@ def _build_positive_samples( | ... | @@ -384,31 +422,181 @@ def _build_positive_samples( |
| 384 | return samples | 422 | return samples |
| 385 | 423 | ||
| 386 | 424 | ||
| 425 | def _build_hard_samples( | ||
| 426 | plan: dict[str, int], | ||
| 427 | *, | ||
| 428 | groups: dict[str, list[LyricProfile]], | ||
| 429 | holdout_profiles: list[LyricProfile], | ||
| 430 | indexed_profiles: list[LyricProfile], | ||
| 431 | output_dir: Path, | ||
| 432 | csv_base: Path, | ||
| 433 | rng: random.Random, | ||
| 434 | start_index: int, | ||
| 435 | ) -> list[GeneratedSample]: | ||
| 436 | samples: list[GeneratedSample] = [] | ||
| 437 | |||
| 438 | _progress("build positive_realistic_variant samples") | ||
| 439 | samples.extend( | ||
| 440 | _build_realistic_positive_samples( | ||
| 441 | _stratified_sample(groups["normal"], plan["positive_realistic_variant"], rng), | ||
| 442 | output_dir, | ||
| 443 | csv_base, | ||
| 444 | rng, | ||
| 445 | start_index=start_index + len(samples), | ||
| 446 | ) | ||
| 447 | ) | ||
| 448 | _progress(f"built samples: {len(samples)}") | ||
| 449 | |||
| 450 | real_holdout_count = plan.get("negative_real_holdout_full_song", 0) | ||
| 451 | _progress("build negative_real_holdout_full_song samples") | ||
| 452 | samples.extend( | ||
| 453 | _build_holdout_full_song_samples( | ||
| 454 | holdout_profiles[:real_holdout_count], | ||
| 455 | output_dir, | ||
| 456 | csv_base, | ||
| 457 | start_index=start_index + len(samples), | ||
| 458 | ) | ||
| 459 | ) | ||
| 460 | _progress(f"built samples: {len(samples)}") | ||
| 461 | |||
| 462 | near_count = plan.get("negative_near_neighbor_holdout_full_song", 0) | ||
| 463 | _progress("build negative_near_neighbor_holdout_full_song samples") | ||
| 464 | near_holdouts = _near_neighbor_holdouts( | ||
| 465 | holdout_profiles[real_holdout_count:], | ||
| 466 | indexed_profiles, | ||
| 467 | near_count, | ||
| 468 | ) | ||
| 469 | samples.extend( | ||
| 470 | _build_holdout_full_song_samples( | ||
| 471 | near_holdouts, | ||
| 472 | output_dir, | ||
| 473 | csv_base, | ||
| 474 | start_index=start_index + len(samples), | ||
| 475 | sample_type="negative_near_neighbor_holdout_full_song", | ||
| 476 | notes="full real holdout lyric selected for catalog line overlap with indexed songs", | ||
| 477 | ) | ||
| 478 | ) | ||
| 479 | _progress(f"built samples: {len(samples)}") | ||
| 480 | |||
| 481 | _progress("build negative_long_fragment samples") | ||
| 482 | samples.extend( | ||
| 483 | _build_fragment_samples( | ||
| 484 | _stratified_sample(groups["fragmentable"], plan.get("negative_long_fragment", 0), rng), | ||
| 485 | output_dir, | ||
| 486 | csv_base, | ||
| 487 | rng, | ||
| 488 | start_index=start_index + len(samples), | ||
| 489 | sample_type="negative_long_fragment", | ||
| 490 | long_fragment=True, | ||
| 491 | notes="realistic long partial lyric upload, not a full-song duplicate", | ||
| 492 | ) | ||
| 493 | ) | ||
| 494 | _progress(f"built samples: {len(samples)}") | ||
| 495 | |||
| 496 | _progress("build negative_shared_chorus samples") | ||
| 497 | samples.extend( | ||
| 498 | _build_shared_chorus_samples( | ||
| 499 | _stratified_sample(groups["normal"], plan.get("negative_shared_chorus", 0), rng), | ||
| 500 | output_dir, | ||
| 501 | csv_base, | ||
| 502 | rng, | ||
| 503 | start_index=start_index + len(samples), | ||
| 504 | ) | ||
| 505 | ) | ||
| 506 | _progress(f"built samples: {len(samples)}") | ||
| 507 | |||
| 508 | _progress("build negative_translation_only samples") | ||
| 509 | samples.extend( | ||
| 510 | _build_translation_only_samples( | ||
| 511 | _stratified_sample(groups["foreign"], plan.get("negative_translation_only", 0), rng), | ||
| 512 | output_dir, | ||
| 513 | csv_base, | ||
| 514 | rng, | ||
| 515 | start_index=start_index + len(samples), | ||
| 516 | ) | ||
| 517 | ) | ||
| 518 | _progress(f"built samples: {len(samples)}") | ||
| 519 | |||
| 520 | _progress("build negative_catalog_mashup samples") | ||
| 521 | samples.extend( | ||
| 522 | _build_catalog_mashup_samples( | ||
| 523 | _stratified_sample(groups["normal"], plan.get("negative_catalog_mashup", 0) * 3, rng), | ||
| 524 | plan.get("negative_catalog_mashup", 0), | ||
| 525 | output_dir, | ||
| 526 | csv_base, | ||
| 527 | rng, | ||
| 528 | start_index=start_index + len(samples), | ||
| 529 | ) | ||
| 530 | ) | ||
| 531 | _progress(f"built samples: {len(samples)}") | ||
| 532 | |||
| 533 | _progress("build edge_short_or_placeholder samples") | ||
| 534 | samples.extend( | ||
| 535 | _build_edge_samples( | ||
| 536 | _stratified_sample(groups["edge"], plan.get("edge_short_or_placeholder", 0), rng), | ||
| 537 | output_dir, | ||
| 538 | csv_base, | ||
| 539 | rng, | ||
| 540 | start_index=start_index + len(samples), | ||
| 541 | ) | ||
| 542 | ) | ||
| 543 | return samples | ||
| 544 | |||
| 545 | |||
| 546 | def _build_realistic_positive_samples( | ||
| 547 | profiles: list[LyricProfile], | ||
| 548 | output_dir: Path, | ||
| 549 | csv_base: Path, | ||
| 550 | rng: random.Random, | ||
| 551 | *, | ||
| 552 | start_index: int, | ||
| 553 | ) -> list[GeneratedSample]: | ||
| 554 | samples: list[GeneratedSample] = [] | ||
| 555 | for offset, profile in enumerate(profiles): | ||
| 556 | content_lines = _content_lines(profile.raw_text) | ||
| 557 | primary_lines = list(profile.normalized.primary_lines or profile.normalized.unique_lines) or content_lines | ||
| 558 | variants = [ | ||
| 559 | ("positive_platform_mixed_noise", _platform_mixed_noise(content_lines, rng)), | ||
| 560 | ("positive_near_full_missing_section", _near_full_missing_section(primary_lines, rng)), | ||
| 561 | ("positive_block_translation_added", _block_translation_added(primary_lines)), | ||
| 562 | ("positive_typo_and_punctuation_noise", _stronger_typo_and_punctuation_noise(content_lines, rng)), | ||
| 563 | ("positive_timestamped_platform_variant", _timestamped_platform_variant(content_lines)), | ||
| 564 | ("positive_chorus_count_changed", _change_repeated_line_counts(content_lines)), | ||
| 565 | ] | ||
| 566 | sample_type, text = variants[offset % len(variants)] | ||
| 567 | index = start_index + offset | ||
| 568 | path = _write_sample_file(output_dir, f"pos_{index:05d}_{sample_type}.txt", text) | ||
| 569 | samples.append(_sample_from_profile(index, path, csv_base, "应去重", sample_type, profile)) | ||
| 570 | _progress_count("positive_realistic_variant", len(samples), len(profiles)) | ||
| 571 | return samples | ||
| 572 | |||
| 573 | |||
| 387 | def _build_holdout_full_song_samples( | 574 | def _build_holdout_full_song_samples( |
| 388 | profiles: list[LyricProfile], | 575 | profiles: list[LyricProfile], |
| 389 | output_dir: Path, | 576 | output_dir: Path, |
| 390 | csv_base: Path, | 577 | csv_base: Path, |
| 391 | *, | 578 | *, |
| 392 | start_index: int, | 579 | start_index: int, |
| 580 | sample_type: str = "negative_real_holdout_full_song", | ||
| 581 | notes: str = "full real lyric held out from the generated eval index", | ||
| 393 | ) -> list[GeneratedSample]: | 582 | ) -> list[GeneratedSample]: |
| 394 | _progress("build negative_real_holdout_full_song samples") | ||
| 395 | samples: list[GeneratedSample] = [] | 583 | samples: list[GeneratedSample] = [] |
| 396 | for offset, profile in enumerate(profiles): | 584 | for offset, profile in enumerate(profiles): |
| 397 | index = start_index + offset | 585 | index = start_index + offset |
| 398 | text = profile.raw_text | 586 | text = profile.raw_text |
| 399 | path = _write_sample_file(output_dir, f"neg_{index:05d}_negative_real_holdout_full_song.txt", text) | 587 | path = _write_sample_file(output_dir, f"neg_{index:05d}_{sample_type}.txt", text) |
| 400 | samples.append( | 588 | samples.append( |
| 401 | _sample_from_profile( | 589 | _sample_from_profile( |
| 402 | index, | 590 | index, |
| 403 | path, | 591 | path, |
| 404 | csv_base, | 592 | csv_base, |
| 405 | "不应去重", | 593 | "不应去重", |
| 406 | "negative_real_holdout_full_song", | 594 | sample_type, |
| 407 | profile, | 595 | profile, |
| 408 | notes="full real lyric held out from the generated eval index", | 596 | notes=notes, |
| 409 | ) | 597 | ) |
| 410 | ) | 598 | ) |
| 411 | _progress_count("negative_real_holdout_full_song", len(samples), len(profiles)) | 599 | _progress_count(sample_type, len(samples), len(profiles)) |
| 412 | return samples | 600 | return samples |
| 413 | 601 | ||
| 414 | 602 | ||
| ... | @@ -446,25 +634,59 @@ def _build_fragment_samples( | ... | @@ -446,25 +634,59 @@ def _build_fragment_samples( |
| 446 | rng: random.Random, | 634 | rng: random.Random, |
| 447 | *, | 635 | *, |
| 448 | start_index: int, | 636 | start_index: int, |
| 637 | sample_type: str = "negative_fragment", | ||
| 638 | long_fragment: bool = False, | ||
| 639 | notes: str = "partial lyric fragment only", | ||
| 449 | ) -> list[GeneratedSample]: | 640 | ) -> list[GeneratedSample]: |
| 450 | samples: list[GeneratedSample] = [] | 641 | samples: list[GeneratedSample] = [] |
| 451 | for offset, profile in enumerate(profiles): | 642 | for offset, profile in enumerate(profiles): |
| 452 | lines = list(profile.normalized.primary_lines or profile.normalized.unique_lines) | 643 | lines = list(profile.normalized.primary_lines or profile.normalized.unique_lines) |
| 453 | text = _single_song_fragment(lines, rng) | 644 | text = _long_song_fragment(lines, rng) if long_fragment else _single_song_fragment(lines, rng) |
| 454 | index = start_index + offset | 645 | index = start_index + offset |
| 455 | path = _write_sample_file(output_dir, f"neg_{index:05d}_negative_fragment.txt", text) | 646 | path = _write_sample_file(output_dir, f"neg_{index:05d}_{sample_type}.txt", text) |
| 456 | samples.append( | 647 | samples.append( |
| 457 | _sample_from_profile( | 648 | _sample_from_profile( |
| 458 | index, | 649 | index, |
| 459 | path, | 650 | path, |
| 460 | csv_base, | 651 | csv_base, |
| 461 | "不应去重", | 652 | "不应去重", |
| 462 | "negative_fragment", | 653 | sample_type, |
| 463 | profile, | 654 | profile, |
| 464 | notes="partial lyric fragment only", | 655 | notes=notes, |
| 656 | ) | ||
| 657 | ) | ||
| 658 | _progress_count(sample_type, len(samples), len(profiles)) | ||
| 659 | return samples | ||
| 660 | |||
| 661 | |||
| 662 | def _build_catalog_mashup_samples( | ||
| 663 | profiles: list[LyricProfile], | ||
| 664 | count: int, | ||
| 665 | output_dir: Path, | ||
| 666 | csv_base: Path, | ||
| 667 | rng: random.Random, | ||
| 668 | *, | ||
| 669 | start_index: int, | ||
| 670 | ) -> list[GeneratedSample]: | ||
| 671 | samples: list[GeneratedSample] = [] | ||
| 672 | if count <= 0 or not profiles: | ||
| 673 | return samples | ||
| 674 | for offset in range(count): | ||
| 675 | index = start_index + offset | ||
| 676 | picked = rng.sample(profiles, k=min(3, len(profiles))) | ||
| 677 | text = _catalog_mashup_text(picked, rng) | ||
| 678 | path = _write_sample_file(output_dir, f"neg_{index:05d}_negative_catalog_mashup.txt", text) | ||
| 679 | samples.append( | ||
| 680 | GeneratedSample( | ||
| 681 | sample_id=f"sample-{index:05d}", | ||
| 682 | file=str(path.relative_to(csv_base)), | ||
| 683 | expected="不应去重", | ||
| 684 | sample_type="negative_catalog_mashup", | ||
| 685 | source=" | ".join(str(profile.path) for profile in picked), | ||
| 686 | notes="medley-style partial lyric assembled from multiple catalog songs", | ||
| 465 | ) | 687 | ) |
| 466 | ) | 688 | ) |
| 467 | _progress_count("negative_fragment", len(samples), len(profiles)) | 689 | _progress_count("negative_catalog_mashup", len(samples), count) |
| 468 | return samples | 690 | return samples |
| 469 | 691 | ||
| 470 | 692 | ||
| ... | @@ -658,8 +880,10 @@ def _write_manifest( | ... | @@ -658,8 +880,10 @@ def _write_manifest( |
| 658 | index_path: Path | None, | 880 | index_path: Path | None, |
| 659 | eval_index_path: Path, | 881 | eval_index_path: Path, |
| 660 | holdout_count: int, | 882 | holdout_count: int, |
| 883 | profile: str, | ||
| 661 | ) -> dict[str, object]: | 884 | ) -> dict[str, object]: |
| 662 | manifest = { | 885 | manifest = { |
| 886 | "profile": profile, | ||
| 663 | "seed": seed, | 887 | "seed": seed, |
| 664 | "library_files": len(profiles), | 888 | "library_files": len(profiles), |
| 665 | "sample_size": len(samples), | 889 | "sample_size": len(samples), |
| ... | @@ -684,6 +908,38 @@ def _write_manifest( | ... | @@ -684,6 +908,38 @@ def _write_manifest( |
| 684 | return manifest | 908 | return manifest |
| 685 | 909 | ||
| 686 | 910 | ||
| 911 | def _near_neighbor_holdouts( | ||
| 912 | holdout_profiles: list[LyricProfile], | ||
| 913 | indexed_profiles: list[LyricProfile], | ||
| 914 | count: int, | ||
| 915 | ) -> list[LyricProfile]: | ||
| 916 | if count <= 0 or not holdout_profiles: | ||
| 917 | return [] | ||
| 918 | if not indexed_profiles: | ||
| 919 | return holdout_profiles[:count] | ||
| 920 | |||
| 921 | line_to_indexed_count: Counter[str] = Counter() | ||
| 922 | for profile in indexed_profiles: | ||
| 923 | for line in set(profile.normalized.primary_lines or profile.normalized.unique_lines): | ||
| 924 | if len(line) >= 4: | ||
| 925 | line_to_indexed_count[line] += 1 | ||
| 926 | |||
| 927 | scored: list[tuple[float, LyricProfile]] = [] | ||
| 928 | for profile in holdout_profiles: | ||
| 929 | lines = set(profile.normalized.primary_lines or profile.normalized.unique_lines) | ||
| 930 | useful_lines = {line for line in lines if len(line) >= 4} | ||
| 931 | if not useful_lines: | ||
| 932 | score = 0.0 | ||
| 933 | else: | ||
| 934 | shared = sum(1 for line in useful_lines if line_to_indexed_count[line] > 0) | ||
| 935 | common_weight = sum(min(line_to_indexed_count[line], 5) for line in useful_lines) | ||
| 936 | score = (shared / len(useful_lines)) + (common_weight / (len(useful_lines) * 20)) | ||
| 937 | scored.append((score, profile)) | ||
| 938 | |||
| 939 | scored.sort(key=lambda item: item[0], reverse=True) | ||
| 940 | return [profile for _, profile in scored[:count]] | ||
| 941 | |||
| 942 | |||
| 687 | def _content_lines(text: str) -> list[str]: | 943 | def _content_lines(text: str) -> list[str]: |
| 688 | lines = [line.strip() for line in text.splitlines() if line.strip()] | 944 | lines = [line.strip() for line in text.splitlines() if line.strip()] |
| 689 | return lines or [text.strip()] | 945 | return lines or [text.strip()] |
| ... | @@ -735,6 +991,18 @@ def _add_timestamps(lines: list[str]) -> str: | ... | @@ -735,6 +991,18 @@ def _add_timestamps(lines: list[str]) -> str: |
| 735 | return "\n".join(f"[00:{idx % 60:02d}.00]{line}" for idx, line in enumerate(lines, start=1)) | 991 | return "\n".join(f"[00:{idx % 60:02d}.00]{line}" for idx, line in enumerate(lines, start=1)) |
| 736 | 992 | ||
| 737 | 993 | ||
| 994 | def _platform_mixed_noise(lines: list[str], rng: random.Random) -> str: | ||
| 995 | noisy = _add_blank_line_noise(lines).splitlines() | ||
| 996 | if noisy: | ||
| 997 | noisy = _add_punctuation_noise(noisy, rng).splitlines() | ||
| 998 | return "\n".join(["作词:未知", "歌词来自平台同步", *noisy, "未经著作权人许可 不得商业使用"]) | ||
| 999 | |||
| 1000 | |||
| 1001 | def _timestamped_platform_variant(lines: list[str]) -> str: | ||
| 1002 | timestamped = _add_timestamps(lines).splitlines() | ||
| 1003 | return "\n".join(["[00:00.00]歌词贡献者:用户上传", *timestamped]) | ||
| 1004 | |||
| 1005 | |||
| 738 | def _add_punctuation_noise(lines: list[str], rng: random.Random) -> str: | 1006 | def _add_punctuation_noise(lines: list[str], rng: random.Random) -> str: |
| 739 | marks = ["!", "?", "...", ",", "。"] | 1007 | marks = ["!", "?", "...", ",", "。"] |
| 740 | return "\n".join(f"{line}{rng.choice(marks)}" for line in lines) | 1008 | return "\n".join(f"{line}{rng.choice(marks)}" for line in lines) |
| ... | @@ -773,6 +1041,97 @@ def _translation_added(lines: list[str]) -> str: | ... | @@ -773,6 +1041,97 @@ def _translation_added(lines: list[str]) -> str: |
| 773 | return "\n".join(result) | 1041 | return "\n".join(result) |
| 774 | 1042 | ||
| 775 | 1043 | ||
| 1044 | def _block_translation_added(lines: list[str]) -> str: | ||
| 1045 | body = "\n".join(lines) | ||
| 1046 | translation_count = min(8, max(4, len(lines) // 4)) | ||
| 1047 | translations = [_pseudo_translation(index) for index in range(1, translation_count + 1)] | ||
| 1048 | return "\n".join([body, "", *translations]) | ||
| 1049 | |||
| 1050 | |||
| 1051 | def _near_full_missing_section(lines: list[str], rng: random.Random) -> str: | ||
| 1052 | if len(lines) <= 8: | ||
| 1053 | return "\n".join(lines) | ||
| 1054 | drop_count = max(1, min(max(1, len(lines) // 5), 8)) | ||
| 1055 | start = rng.randrange(0, max(1, len(lines) - drop_count + 1)) | ||
| 1056 | kept = lines[:start] + lines[start + drop_count :] | ||
| 1057 | return "\n".join(kept or lines) | ||
| 1058 | |||
| 1059 | |||
| 1060 | def _add_typo_noise(lines: list[str], rng: random.Random) -> str: | ||
| 1061 | if not lines: | ||
| 1062 | return "" | ||
| 1063 | result = list(lines) | ||
| 1064 | editable_indexes = [index for index, line in enumerate(result) if _can_typo_line(line)] | ||
| 1065 | if not editable_indexes: | ||
| 1066 | return "\n".join(result) | ||
| 1067 | typo_count = max(1, min(4, len(editable_indexes) // 8 or 1)) | ||
| 1068 | for index in rng.sample(editable_indexes, k=min(typo_count, len(editable_indexes))): | ||
| 1069 | result[index] = _typo_line(result[index], rng) | ||
| 1070 | return "\n".join(result) | ||
| 1071 | |||
| 1072 | |||
| 1073 | def _stronger_typo_and_punctuation_noise(lines: list[str], rng: random.Random) -> str: | ||
| 1074 | if not lines: | ||
| 1075 | return "" | ||
| 1076 | result = _add_punctuation_noise(lines, rng).splitlines() | ||
| 1077 | editable_indexes = [index for index, line in enumerate(result) if _can_typo_line(line)] | ||
| 1078 | typo_count = max(1, min(8, len(editable_indexes) // 6 or 1)) | ||
| 1079 | for index in rng.sample(editable_indexes, k=min(typo_count, len(editable_indexes))): | ||
| 1080 | result[index] = _typo_line(result[index], rng) | ||
| 1081 | return "\n".join(result) | ||
| 1082 | |||
| 1083 | |||
| 1084 | def _can_typo_line(line: str) -> bool: | ||
| 1085 | return bool(re.search(r"[A-Za-z]{4,}|[\u4e00-\u9fff]{4,}", line)) | ||
| 1086 | |||
| 1087 | |||
| 1088 | def _typo_line(line: str, rng: random.Random) -> str: | ||
| 1089 | words = list(re.finditer(r"[A-Za-z]{4,}", line)) | ||
| 1090 | if words and rng.random() < 0.65: | ||
| 1091 | match = rng.choice(words) | ||
| 1092 | typo = _typo_english_word(match.group(0), rng) | ||
| 1093 | return line[: match.start()] + typo + line[match.end() :] | ||
| 1094 | cjk_positions = [index for index, char in enumerate(line) if "\u4e00" <= char <= "\u9fff"] | ||
| 1095 | if cjk_positions: | ||
| 1096 | index = rng.choice(cjk_positions) | ||
| 1097 | return line[:index] + _typo_cjk_char(line[index]) + line[index + 1 :] | ||
| 1098 | return line | ||
| 1099 | |||
| 1100 | |||
| 1101 | def _typo_english_word(word: str, rng: random.Random) -> str: | ||
| 1102 | if len(word) <= 4 or rng.random() < 0.55: | ||
| 1103 | remove_at = rng.randrange(1, max(2, len(word) - 1)) | ||
| 1104 | return word[:remove_at] + word[remove_at + 1 :] | ||
| 1105 | swap_at = rng.randrange(1, max(2, len(word) - 2)) | ||
| 1106 | chars = list(word) | ||
| 1107 | chars[swap_at], chars[swap_at + 1] = chars[swap_at + 1], chars[swap_at] | ||
| 1108 | return "".join(chars) | ||
| 1109 | |||
| 1110 | |||
| 1111 | def _typo_cjk_char(char: str) -> str: | ||
| 1112 | replacements = { | ||
| 1113 | "你": "妳", | ||
| 1114 | "爱": "爰", | ||
| 1115 | "夜": "液", | ||
| 1116 | "里": "裏", | ||
| 1117 | "风": "凤", | ||
| 1118 | "雨": "兩", | ||
| 1119 | "听": "昕", | ||
| 1120 | "说": "説", | ||
| 1121 | "想": "相", | ||
| 1122 | "梦": "夣", | ||
| 1123 | "心": "芯", | ||
| 1124 | "光": "先", | ||
| 1125 | "城": "诚", | ||
| 1126 | "远": "迩", | ||
| 1127 | "回": "囬", | ||
| 1128 | "走": "赱", | ||
| 1129 | "海": "毎", | ||
| 1130 | "天": "夭", | ||
| 1131 | } | ||
| 1132 | return replacements.get(char, char) | ||
| 1133 | |||
| 1134 | |||
| 776 | def _single_song_fragment(lines: list[str], rng: random.Random) -> str: | 1135 | def _single_song_fragment(lines: list[str], rng: random.Random) -> str: |
| 777 | if len(lines) <= 4: | 1136 | if len(lines) <= 4: |
| 778 | return "\n".join(lines[: max(1, len(lines) // 2)]) | 1137 | return "\n".join(lines[: max(1, len(lines) // 2)]) |
| ... | @@ -781,6 +1140,28 @@ def _single_song_fragment(lines: list[str], rng: random.Random) -> str: | ... | @@ -781,6 +1140,28 @@ def _single_song_fragment(lines: list[str], rng: random.Random) -> str: |
| 781 | return "\n".join(lines[start : start + fragment_len]) | 1140 | return "\n".join(lines[start : start + fragment_len]) |
| 782 | 1141 | ||
| 783 | 1142 | ||
| 1143 | def _long_song_fragment(lines: list[str], rng: random.Random) -> str: | ||
| 1144 | if len(lines) <= 8: | ||
| 1145 | return _single_song_fragment(lines, rng) | ||
| 1146 | fragment_len = max(6, min(len(lines) - 1, int(len(lines) * rng.uniform(0.35, 0.60)))) | ||
| 1147 | start = rng.randrange(0, max(1, len(lines) - fragment_len + 1)) | ||
| 1148 | return "\n".join(lines[start : start + fragment_len]) | ||
| 1149 | |||
| 1150 | |||
| 1151 | def _catalog_mashup_text(profiles: list[LyricProfile], rng: random.Random) -> str: | ||
| 1152 | sections: list[str] = [] | ||
| 1153 | for profile in profiles: | ||
| 1154 | lines = list(profile.normalized.primary_lines or profile.normalized.unique_lines) | ||
| 1155 | if not lines: | ||
| 1156 | continue | ||
| 1157 | section_len = min(max(2, len(lines) // 8), 5) | ||
| 1158 | start = rng.randrange(0, max(1, len(lines) - section_len + 1)) | ||
| 1159 | sections.extend(lines[start : start + section_len]) | ||
| 1160 | if not sections: | ||
| 1161 | return _same_theme_synthetic(0, rng) | ||
| 1162 | return "\n".join(sections) | ||
| 1163 | |||
| 1164 | |||
| 784 | def _short_shared_snippet(lines: list[str], rng: random.Random) -> str: | 1165 | def _short_shared_snippet(lines: list[str], rng: random.Random) -> str: |
| 785 | snippet = rng.sample(lines, k=min(2, len(lines))) if lines else [] | 1166 | snippet = rng.sample(lines, k=min(2, len(lines))) if lines else [] |
| 786 | synthetic = [ | 1167 | synthetic = [ | ... | ... |
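
Review note: the scoring in `_near_neighbor_holdouts` above can be exercised in isolation. The sketch below is a minimal, standalone restatement of that scoring, not the project's code. It assumes plain lists of normalized lines in place of `LyricProfile`, and the names `overlap_score` and `rank_near_neighbors` are hypothetical. It keeps the same two terms as the diff: the share of a held-out song's lines (length >= 4) that also occur in indexed songs, plus a small bonus for lines that occur in many indexed songs, capped at 5 per line and scaled by 20.

```python
from collections import Counter


def overlap_score(holdout_lines: set[str], indexed_line_counts: Counter, min_len: int = 4) -> float:
    """Score how strongly one held-out lyric overlaps the indexed catalog."""
    # Very short lines are ignored, mirroring the len(line) >= 4 filter in the diff.
    useful = {line for line in holdout_lines if len(line) >= min_len}
    if not useful:
        return 0.0
    # Fraction of useful lines that appear in at least one indexed song.
    shared = sum(1 for line in useful if indexed_line_counts[line] > 0)
    # Bonus for lines common across the catalog, capped at 5 occurrences per line.
    common_weight = sum(min(indexed_line_counts[line], 5) for line in useful)
    return shared / len(useful) + common_weight / (len(useful) * 20)


def rank_near_neighbors(
    holdouts: dict[str, list[str]],
    indexed: dict[str, list[str]],
    count: int,
) -> list[str]:
    """Return the names of the `count` held-out songs with the highest overlap score."""
    counts: Counter = Counter()
    for lines in indexed.values():
        # Count each line once per indexed song, not once per repetition.
        for line in set(lines):
            if len(line) >= 4:
                counts[line] += 1
    scored = [(overlap_score(set(lines), counts), name) for name, lines in holdouts.items()]
    # Python's sort is stable, so ties keep their input order.
    scored.sort(key=lambda item: item[0], reverse=True)
    return [name for _, name in scored[:count]]
```

Because the first term is at most 1.0 and the second at most 0.25, the line-share term dominates and the frequency bonus mostly breaks ties between songs with similar coverage.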
requirements.txt
0 → 100644
scripts/evaluate_postgres.py
0 → 100644
| 1 | """Evaluate lyric duplicate checking with PostgreSQL-backed candidate recall.""" | ||
| 2 | |||
| 3 | from __future__ import annotations | ||
| 4 | |||
| 5 | import argparse | ||
| 6 | import csv | ||
| 7 | import hashlib | ||
| 8 | import json | ||
| 9 | import sys | ||
| 10 | import time | ||
| 11 | from pathlib import Path | ||
| 12 | from typing import Any | ||
| 13 | |||
| 14 | |||
| 15 | PROJECT_ROOT = Path(__file__).resolve().parents[1] | ||
| 16 | if str(PROJECT_ROOT) not in sys.path: | ||
| 17 | sys.path.insert(0, str(PROJECT_ROOT)) | ||
| 18 | |||
| 19 | from lyric_dedup.checker import DuplicateChecker | ||
| 20 | from lyric_dedup.checker import LyricRecord | ||
| 21 | from lyric_dedup.file_import import read_lyric_file | ||
| 22 | from lyric_dedup.file_import import record_from_file | ||
| 23 | from lyric_dedup.normalization import fingerprint_text | ||
| 24 | from lyric_dedup.normalization import normalize_lyrics | ||
| 25 | |||
| 26 | |||
| 27 | def main() -> None: | ||
| 28 | parser = argparse.ArgumentParser(description="Evaluate duplicate checking using PostgreSQL recall.") | ||
| 29 | parser.add_argument("--dsn", required=True) | ||
| 30 | parser.add_argument("--csv", required=True) | ||
| 31 | parser.add_argument("--out", required=True) | ||
| 32 | parser.add_argument("--base-dir", default="") | ||
| 33 | parser.add_argument("--positive-decisions", default="duplicate") | ||
| 34 | parser.add_argument("--max-candidates", type=int, default=5) | ||
| 35 | parser.add_argument("--recall-limit", type=int, default=100) | ||
| 36 | parser.add_argument("--enable-trgm", action="store_true", help="Enable pg_trgm trigram-similarity recall over the full lyric text. Slower; exact-hash and line-level recall are used by default.") | ||
| 37 | parser.add_argument("--trgm-threshold", type=float, default=0.3) | ||
| 38 | parser.add_argument("--statement-timeout-ms", type=int, default=5000) | ||
| 39 | parser.add_argument("--profile-every", type=int, default=100) | ||
| 40 | args = parser.parse_args() | ||
| 41 | |||
| 42 | psycopg = _import_psycopg() | ||
| 43 | csv_path = Path(args.csv) | ||
| 44 | out_path = Path(args.out) | ||
| 45 | base_dir = Path(args.base_dir) if args.base_dir else None | ||
| 46 | positive_decisions = {item.strip() for item in args.positive_decisions.split(",") if item.strip()} | ||
| 47 | |||
| 48 | total = _csv_data_row_count(csv_path) | ||
| 49 | rows: list[dict[str, object]] = [] | ||
| 50 | profile_stats = _new_profile_stats() | ||
| 51 | out_path.parent.mkdir(parents=True, exist_ok=True) | ||
| 52 | _progress(f"evaluate postgres csv: 0/{total}") | ||
| 53 | with psycopg.connect(args.dsn) as conn: | ||
| 54 | with conn.cursor() as cursor: | ||
| 55 | cursor.execute("select set_config('statement_timeout', %s, false)", (str(args.statement_timeout_ms),)) | ||
| 56 | cursor.execute("select set_config('pg_trgm.similarity_threshold', %s, false)", (str(args.trgm_threshold),)) | ||
| 57 | with csv_path.open(encoding="utf-8-sig", newline="") as in_file, out_path.open( | ||
| 58 | "w", encoding="utf-8", newline="" | ||
| 59 | ) as out_file: | ||
| 60 | reader = csv.DictReader(in_file) | ||
| 61 | if reader.fieldnames is None: | ||
| 62 | raise ValueError("Evaluation CSV requires a header row") | ||
| 63 | writer = csv.DictWriter(out_file, fieldnames=_fieldnames()) | ||
| 64 | writer.writeheader() | ||
| 65 | for index, row in enumerate(reader, start=1): | ||
| 66 | row_out = _evaluate_row( | ||
| 67 | conn, | ||
| 68 | row, | ||
| 69 | row_number=index + 1,  # 1-based line number in the CSV file; +1 skips the header row | ||
| 70 | csv_path=csv_path, | ||
| 71 | base_dir=base_dir, | ||
| 72 | positive_decisions=positive_decisions, | ||
| 73 | max_candidates=args.max_candidates, | ||
| 74 | recall_limit=args.recall_limit, | ||
| 75 | enable_trgm=args.enable_trgm, | ||
| 76 | ) | ||
| 77 | rows.append(row_out) | ||
| 78 | writer.writerow(row_out) | ||
| 79 | _progress_count("evaluate postgres csv", index, total, step=10) | ||
| 80 | _update_profile_stats(profile_stats, row_out) | ||
| 81 | if args.profile_every > 0 and index % args.profile_every == 0: | ||
| 82 | _progress(_format_profile_stats(profile_stats, index)) | ||
| 83 | |||
| 84 | summary = _evaluation_summary(rows, positive_decisions=positive_decisions, out_path=out_path) | ||
| 85 | summary_path = out_path.with_suffix(out_path.suffix + ".summary.json") | ||
| 86 | summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") | ||
| 87 | _progress("postgres evaluation complete") | ||
| 88 | print(json.dumps(summary, ensure_ascii=False)) | ||
| 89 | |||
| 90 | |||
| 91 | def _evaluate_row( | ||
| 92 | conn: Any, | ||
| 93 | row: dict[str, str], | ||
| 94 | *, | ||
| 95 | row_number: int, | ||
| 96 | csv_path: Path, | ||
| 97 | base_dir: Path | None, | ||
| 98 | positive_decisions: set[str], | ||
| 99 | max_candidates: int, | ||
| 100 | recall_limit: int, | ||
| 101 | enable_trgm: bool, | ||
| 102 | ) -> dict[str, object]: | ||
| 103 | parse_started = time.perf_counter() | ||
| 104 | sample_id = row.get("id") or row.get("sample_id") or str(row_number) | ||
| 105 | record, source = _record_from_eval_row(row, csv_path=csv_path, base_dir=base_dir) | ||
| 106 | expected_duplicate = _parse_expected(row.get("expected") or row.get("label") or row.get("target")) | ||
| 107 | parse_ms = round((time.perf_counter() - parse_started) * 1000, 2) | ||
| 108 | candidates, timings = _recall_candidates( | ||
| 109 | conn, | ||
| 110 | record, | ||
| 111 | recall_limit=recall_limit, | ||
| 112 | enable_trgm=enable_trgm, | ||
| 113 | exclude_record_ids=_exclude_record_ids_for_eval_row(row), | ||
| 114 | ) | ||
| 115 | rank_started = time.perf_counter() | ||
| 116 | result = _check_against_candidates(record, candidates, max_candidates=max_candidates) | ||
| 117 | rank_ms = round((time.perf_counter() - rank_started) * 1000, 2) | ||
| 118 | recall_ms = round(timings["exact_ms"] + timings["trgm_ms"] + timings["line_ms"], 2) | ||
| 119 | predicted_duplicate = result.decision.value in positive_decisions | ||
| 120 | best = result.candidates[0] if result.candidates else None | ||
| 121 | return { | ||
| 122 | "id": sample_id, | ||
| 123 | "source": source, | ||
| 124 | "expected_duplicate": expected_duplicate, | ||
| 125 | "decision": result.decision.value, | ||
| 126 | "predicted_duplicate": predicted_duplicate, | ||
| 127 | "correct": expected_duplicate == predicted_duplicate, | ||
| 128 | "confidence": result.confidence, | ||
| 129 | "reason": result.reason, | ||
| 130 | "candidate_count": len(candidates), | ||
| 131 | "parse_ms": parse_ms, | ||
| 132 | "recall_ms": recall_ms, | ||
| 133 | "exact_ms": timings["exact_ms"], | ||
| 134 | "trgm_ms": timings["trgm_ms"], | ||
| 135 | "line_ms": timings["line_ms"], | ||
| 136 | "rank_ms": rank_ms, | ||
| 137 | "best_candidate_id": best.record_id if best else "", | ||
| 138 | "best_candidate_decision": best.decision.value if best else "", | ||
| 139 | "best_candidate_confidence": best.confidence if best else "", | ||
| 140 | "best_candidate_jaccard": best.jaccard if best else "", | ||
| 141 | "best_candidate_line_coverage": best.line_coverage if best else "", | ||
| 142 | "best_candidate_primary_jaccard": best.primary_jaccard if best else "", | ||
| 143 | "best_candidate_primary_line_coverage": best.primary_line_coverage if best else "", | ||
| 144 | "best_candidate_translation_jaccard": best.translation_jaccard if best else "", | ||
| 145 | "best_candidate_translation_line_coverage": best.translation_line_coverage if best else "", | ||
| 146 | "best_candidate_reason": best.reason if best else "", | ||
| 147 | "matched_unique_lines": " | ".join(best.matched_unique_lines) if best else "", | ||
| 148 | } | ||
| 149 | |||
| 150 | |||
| 151 | def _recall_candidates( | ||
| 152 | conn: Any, | ||
| 153 | record: LyricRecord, | ||
| 154 | *, | ||
| 155 | recall_limit: int, | ||
| 156 | enable_trgm: bool, | ||
| 157 | exclude_record_ids: list[str], | ||
| 158 | ) -> tuple[list[LyricRecord], dict[str, float]]: | ||
| 159 | query_lyrics = _pg_text(record.lyrics) or "" | ||
| 160 | normalized = normalize_lyrics(query_lyrics) | ||
| 161 | exact_text = fingerprint_text(normalized) | ||
| 162 | exact_hash = hashlib.sha256(exact_text.encode("utf-8")).hexdigest() | ||
| 163 | primary_text = "\n".join(normalized.primary_lines) | ||
| 164 | line_hashes = [hashlib.sha256(line.encode("utf-8")).hexdigest() for line in normalized.primary_lines if line] | ||
| 165 | candidates: dict[str, LyricRecord] = {} | ||
| 166 | timings = {"exact_ms": 0.0, "trgm_ms": 0.0, "line_ms": 0.0} | ||
| 167 | with conn.cursor() as cursor: | ||
| 168 | started = time.perf_counter() | ||
| 169 | cursor.execute( | ||
| 170 | """ | ||
| 171 | select record_id, raw_text, title, artist | ||
| 172 | from lyrics | ||
| 173 | where deleted_at is null | ||
| 174 | and exact_hash = %s | ||
| 175 | and not (record_id = any(%s)) | ||
| 176 | limit %s | ||
| 177 | """, | ||
| 178 | (exact_hash, exclude_record_ids, recall_limit), | ||
| 179 | ) | ||
| 180 | _add_rows(candidates, cursor.fetchall()) | ||
| 181 | timings["exact_ms"] = round((time.perf_counter() - started) * 1000, 2) | ||
| 182 | |||
| 183 | if enable_trgm and primary_text: | ||
| 184 | started = time.perf_counter() | ||
| 185 | cursor.execute( | ||
| 186 | """ | ||
| 187 | select record_id, raw_text, title, artist | ||
| 188 | from lyrics | ||
| 189 | where deleted_at is null | ||
| 190 | and not (record_id = any(%s)) | ||
| 191 | and primary_text %% %s | ||
| 192 | order by similarity(primary_text, %s) desc | ||
| 193 | limit %s | ||
| 194 | """, | ||
| 195 | (exclude_record_ids, primary_text, primary_text, recall_limit), | ||
| 196 | ) | ||
| 197 | _add_rows(candidates, cursor.fetchall()) | ||
| 198 | timings["trgm_ms"] = round((time.perf_counter() - started) * 1000, 2) | ||
| 199 | |||
| 200 | if line_hashes: | ||
| 201 | started = time.perf_counter() | ||
| 202 | cursor.execute( | ||
| 203 | """ | ||
| 204 | select l.record_id, l.raw_text, l.title, l.artist | ||
| 205 | from lyric_lines ll | ||
| 206 | join lyrics l on l.id = ll.lyric_id | ||
| 207 | where l.deleted_at is null | ||
| 208 | and not (l.record_id = any(%s)) | ||
| 209 | and ll.role = 'primary' | ||
| 210 | and ll.line_hash = any(%s) | ||
| 211 | group by l.id  -- l.id is the primary key, so the other selected l columns are functionally dependent on it | ||
| 212 | order by count(*) desc | ||
| 213 | limit %s | ||
| 214 | """, | ||
| 215 | (exclude_record_ids, line_hashes, recall_limit), | ||
| 216 | ) | ||
| 217 | _add_rows(candidates, cursor.fetchall()) | ||
| 218 | timings["line_ms"] = round((time.perf_counter() - started) * 1000, 2) | ||
| 219 | return list(candidates.values()), timings | ||
| 220 | |||
| 221 | |||
| 222 | def _exclude_record_ids_for_eval_row(row: dict[str, str]) -> list[str]: | ||
| 223 | if row.get("sample_type") == "negative_real_holdout_full_song" and row.get("source_record_id"): | ||
| 224 | return [row["source_record_id"]] | ||
| 225 | return [] | ||
| 226 | |||
| 227 | |||
| 228 | def _add_rows(candidates: dict[str, LyricRecord], rows: list[tuple[object, ...]]) -> None: | ||
| 229 | for record_id, raw_text, title, artist in rows: | ||
| 230 | candidates.setdefault( | ||
| 231 | str(record_id), | ||
| 232 | LyricRecord( | ||
| 233 | record_id=str(record_id), | ||
| 234 | lyrics=str(raw_text), | ||
| 235 | title=str(title) if title is not None else None, | ||
| 236 | artist=str(artist) if artist is not None else None, | ||
| 237 | ), | ||
| 238 | ) | ||
| 239 | |||
| 240 | |||
| 241 | def _check_against_candidates( | ||
| 242 | record: LyricRecord, | ||
| 243 | candidates: list[LyricRecord], | ||
| 244 | *, | ||
| 245 | max_candidates: int, | ||
| 246 | ): | ||
| 247 | checker = DuplicateChecker() | ||
| 248 | for candidate in candidates: | ||
| 249 | checker.add_record(candidate) | ||
| 250 | return checker.check_record(record, max_candidates=max_candidates) | ||
| 251 | |||
| 252 | |||
| 253 | def _record_from_eval_row(row: dict[str, str], *, csv_path: Path, base_dir: Path | None) -> tuple[LyricRecord, str]: | ||
| 254 | lyrics = (row.get("lyrics") or "").strip() | ||
| 255 | if lyrics: | ||
| 256 | return ( | ||
| 257 | LyricRecord( | ||
| 258 | record_id=row.get("id") or row.get("sample_id") or "__eval__", | ||
| 259 | lyrics=_pg_text(lyrics.replace("\\n", "\n")) or "", | ||
| 260 | title=_pg_text(row.get("title") or None), | ||
| 261 | artist=_pg_text(row.get("artist") or None), | ||
| 262 | ), | ||
| 263 | "inline", | ||
| 264 | ) | ||
| 265 | |||
| 266 | file_value = (row.get("file") or row.get("path") or row.get("source") or "").strip() | ||
| 267 | if not file_value: | ||
| 268 | raise ValueError("Each evaluation CSV row needs a lyrics value or a file path in the file/path/source column") | ||
| 269 | |||
| 270 | file_path = Path(file_value) | ||
| 271 | if not file_path.is_absolute(): | ||
| 272 | file_path = (base_dir or csv_path.parent) / file_path | ||
| 273 | record = record_from_file(file_path) | ||
| 274 | record = LyricRecord( | ||
| 275 | record_id=record.record_id, | ||
| 276 | lyrics=_pg_text(record.lyrics) or "", | ||
| 277 | title=_pg_text(record.title), | ||
| 278 | artist=_pg_text(record.artist), | ||
| 279 | ) | ||
| 280 | if row.get("title") or row.get("artist"): | ||
| 281 | record = LyricRecord( | ||
| 282 | record_id=record.record_id, | ||
| 283 | lyrics=record.lyrics, | ||
| 284 | title=_pg_text(row.get("title") or record.title), | ||
| 285 | artist=_pg_text(row.get("artist") or record.artist), | ||
| 286 | ) | ||
| 287 | return record, str(file_path) | ||
| 288 | |||
| 289 | |||
| 290 | def _parse_expected(value: str | None) -> bool: | ||
| 291 | if value is None: | ||
| 292 | raise ValueError("Each evaluation CSV row needs an expected/label/target column") | ||
| 293 | normalized = value.strip().lower() | ||
| 294 | positives = {"1", "true", "yes", "y", "duplicate", "dup", "重复", "应去重", "去重", "是"} | ||
| 295 | negatives = {"0", "false", "no", "n", "new", "not_duplicate", "non_duplicate", "不重复", "不应去重", "新歌", "否"} | ||
| 296 | if normalized in positives: | ||
| 297 | return True | ||
| 298 | if normalized in negatives: | ||
| 299 | return False | ||
| 300 | raise ValueError(f"Unrecognized expected value: {value!r}") | ||
| 301 | |||
| 302 | |||
| 303 | def _evaluation_summary( | ||
| 304 | rows: list[dict[str, object]], | ||
| 305 | *, | ||
| 306 | positive_decisions: set[str], | ||
| 307 | out_path: Path, | ||
| 308 | ) -> dict[str, object]: | ||
| 309 | tp = sum(1 for row in rows if row["expected_duplicate"] is True and row["predicted_duplicate"] is True) | ||
| 310 | fp = sum(1 for row in rows if row["expected_duplicate"] is False and row["predicted_duplicate"] is True) | ||
| 311 | tn = sum(1 for row in rows if row["expected_duplicate"] is False and row["predicted_duplicate"] is False) | ||
| 312 | fn = sum(1 for row in rows if row["expected_duplicate"] is True and row["predicted_duplicate"] is False) | ||
| 313 | total = len(rows) | ||
| 314 | precision = tp / (tp + fp) if tp + fp else 0.0 | ||
| 315 | recall = tp / (tp + fn) if tp + fn else 0.0 | ||
| 316 | accuracy = (tp + tn) / total if total else 0.0 | ||
| 317 | f1 = (2 * precision * recall / (precision + recall)) if precision + recall else 0.0 | ||
| 318 | return { | ||
| 319 | "total": total, | ||
| 320 | "positive_decisions": sorted(positive_decisions), | ||
| 321 | "accuracy": round(accuracy, 4), | ||
| 322 | "precision": round(precision, 4), | ||
| 323 | "recall": round(recall, 4), | ||
| 324 | "f1": round(f1, 4), | ||
| 325 | "true_positive": tp, | ||
| 326 | "false_positive": fp, | ||
| 327 | "true_negative": tn, | ||
| 328 | "false_negative": fn, | ||
| 329 | "duplicate": sum(1 for row in rows if row["decision"] == "duplicate"), | ||
| 330 | "review": sum(1 for row in rows if row["decision"] == "review"), | ||
| 331 | "new": sum(1 for row in rows if row["decision"] == "new"), | ||
| 332 | "out": str(out_path), | ||
| 333 | "summary": str(out_path.with_suffix(out_path.suffix + ".summary.json")), | ||
| 334 | } | ||
| 335 | |||
| 336 | |||
| 337 | def _fieldnames() -> list[str]: | ||
| 338 | return [ | ||
| 339 | "id", | ||
| 340 | "source", | ||
| 341 | "expected_duplicate", | ||
| 342 | "decision", | ||
| 343 | "predicted_duplicate", | ||
| 344 | "correct", | ||
| 345 | "confidence", | ||
| 346 | "reason", | ||
| 347 | "candidate_count", | ||
| 348 | "parse_ms", | ||
| 349 | "recall_ms", | ||
| 350 | "exact_ms", | ||
| 351 | "trgm_ms", | ||
| 352 | "line_ms", | ||
| 353 | "rank_ms", | ||
| 354 | "best_candidate_id", | ||
| 355 | "best_candidate_decision", | ||
| 356 | "best_candidate_confidence", | ||
| 357 | "best_candidate_jaccard", | ||
| 358 | "best_candidate_line_coverage", | ||
| 359 | "best_candidate_primary_jaccard", | ||
| 360 | "best_candidate_primary_line_coverage", | ||
| 361 | "best_candidate_translation_jaccard", | ||
| 362 | "best_candidate_translation_line_coverage", | ||
| 363 | "best_candidate_reason", | ||
| 364 | "matched_unique_lines", | ||
| 365 | ] | ||
| 366 | |||
| 367 | |||
| 368 | def _csv_data_row_count(csv_path: Path) -> int: | ||
| 369 | with csv_path.open(encoding="utf-8-sig", newline="") as file: | ||
| 370 | reader = csv.reader(file) | ||
| 371 | next(reader, None) | ||
| 372 | return sum(1 for _ in reader) | ||
| 373 | |||
| 374 | |||
| 375 | def _progress(message: str) -> None: | ||
| 376 | print(f"[pg-eval] {message}", file=sys.stderr, flush=True) | ||
| 377 | |||
| 378 | |||
| 379 | def _progress_count(label: str, current: int, total: int, *, step: int = 1000) -> None: | ||
| 380 | if total <= 0: | ||
| 381 | return | ||
| 382 | if current == 1 or current == total or current % step == 0: | ||
| 383 | _progress(f"{label}: {current}/{total}") | ||
| 384 | |||
| 385 | |||
| 386 | def _new_profile_stats() -> dict[str, float]: | ||
| 387 | return { | ||
| 388 | "parse_ms": 0.0, | ||
| 389 | "exact_ms": 0.0, | ||
| 390 | "trgm_ms": 0.0, | ||
| 391 | "line_ms": 0.0, | ||
| 392 | "rank_ms": 0.0, | ||
| 393 | "recall_ms": 0.0, | ||
| 394 | "candidate_count": 0.0, | ||
| 395 | } | ||
| 396 | |||
| 397 | |||
| 398 | def _update_profile_stats(stats: dict[str, float], row: dict[str, object]) -> None: | ||
| 399 | for key in stats: | ||
| 400 | try: | ||
| 401 | stats[key] += float(row.get(key) or 0) | ||
| 402 | except (TypeError, ValueError): | ||
| 403 | pass | ||
| 404 | |||
| 405 | |||
| 406 | def _format_profile_stats(stats: dict[str, float], count: int) -> str: | ||
| 407 | if count <= 0: | ||
| 408 | return "profile: no rows" | ||
| 409 | return ( | ||
| 410 | "profile avg " | ||
| 411 | f"parse={stats['parse_ms'] / count:.2f}ms " | ||
| 412 | f"exact={stats['exact_ms'] / count:.2f}ms " | ||
| 413 | f"line={stats['line_ms'] / count:.2f}ms " | ||
| 414 | f"trgm={stats['trgm_ms'] / count:.2f}ms " | ||
| 415 | f"rank={stats['rank_ms'] / count:.2f}ms " | ||
| 416 | f"recall={stats['recall_ms'] / count:.2f}ms " | ||
| 417 | f"candidates={stats['candidate_count'] / count:.1f}" | ||
| 418 | ) | ||
| 419 | |||
| 420 | |||
| 421 | def _pg_text(value: str | None) -> str | None: | ||
| 422 | if value is None: | ||
| 423 | return None | ||
| 424 | return value.replace("\x00", "") | ||
| 425 | |||
| 426 | |||
| 427 | def _import_psycopg(): | ||
| 428 | try: | ||
| 429 | import psycopg | ||
| 430 | |||
| 431 | return psycopg | ||
| 432 | except ModuleNotFoundError: | ||
| 433 | print( | ||
| 434 | "Missing dependency: psycopg. Install it with:\n" | ||
| 435 | " python -m pip install 'psycopg[binary]'", | ||
| 436 | file=sys.stderr, | ||
| 437 | ) | ||
| 438 | raise SystemExit(1) | ||
| 439 | |||
| 440 | |||
| 441 | if __name__ == "__main__": | ||
| 442 | main() |
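
The summary step of the evaluation script above reduces every output row to one confusion-matrix cell and then derives accuracy, precision, recall, and F1, falling back to 0.0 whenever a denominator is empty. The arithmetic can be sketched on its own as below. The function name `summarize` and the sample pairs are hypothetical and exist only for illustration; this is not the script's actual helper.

```python
from __future__ import annotations


def summarize(pairs: list[tuple[bool, bool]]) -> dict[str, float | int]:
    """Compute confusion-matrix metrics from (expected, predicted) pairs."""
    tp = sum(1 for expected, predicted in pairs if expected and predicted)
    fp = sum(1 for expected, predicted in pairs if not expected and predicted)
    tn = sum(1 for expected, predicted in pairs if not expected and not predicted)
    fn = sum(1 for expected, predicted in pairs if expected and not predicted)
    total = len(pairs)
    # Guard every division so an empty or one-sided sample yields 0.0.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    accuracy = (tp + tn) / total if total else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "total": total,
        "accuracy": round(accuracy, 4),
        "precision": round(precision, 4),
        "recall": round(recall, 4),
        "f1": round(f1, 4),
        "true_positive": tp,
        "false_positive": fp,
        "true_negative": tn,
        "false_negative": fn,
    }


if __name__ == "__main__":
    # Hypothetical sample: 2 TP, 1 FN, 1 TN, 1 FP.
    sample = [(True, True), (True, False), (False, False), (False, True), (True, True)]
    print(summarize(sample))
```

Which decisions count as a positive prediction is controlled by `--positive-decisions` (default `duplicate`), so passing `duplicate,review` trades precision for recall without touching the checker itself.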
scripts/import_library_postgres.py
0 → 100644
| 1 | """Import normalized lyric library records into PostgreSQL.""" | ||
| 2 | |||
| 3 | from __future__ import annotations | ||
| 4 | |||
| 5 | import argparse | ||
| 6 | import csv | ||
| 7 | import hashlib | ||
| 8 | import sys | ||
| 9 | from pathlib import Path | ||
| 10 | from typing import Any | ||
| 11 | |||
| 12 | |||
| 13 | PROJECT_ROOT = Path(__file__).resolve().parents[1] | ||
| 14 | if str(PROJECT_ROOT) not in sys.path: | ||
| 15 | sys.path.insert(0, str(PROJECT_ROOT)) | ||
| 16 | |||
| 17 | from lyric_dedup.file_import import iter_lyric_files | ||
| 18 | from lyric_dedup.file_import import record_from_file | ||
| 19 | from lyric_dedup.normalization import fingerprint_text | ||
| 20 | from lyric_dedup.normalization import normalize_lyrics | ||
| 21 | |||
| 22 | |||
| 23 | def main() -> None: | ||
| 24 | parser = argparse.ArgumentParser(description="Import lyric library into PostgreSQL.") | ||
| 25 | parser.add_argument("--dsn", required=True) | ||
| 26 | parser.add_argument("--lyrics-dir", required=True) | ||
| 27 | parser.add_argument("--batch-size", type=int, default=500) | ||
| 28 | parser.add_argument("--limit", type=int, default=0) | ||
| 29 | parser.add_argument("--skip-dedup-exact", action="store_true", help="Skip exact-hash duplicate soft deletion after import.") | ||
| 30 | parser.add_argument("--duplicate-report", default="outputs/results/postgres_exact_duplicates.csv") | ||
| 31 | parser.add_argument("--line-duplicate-report", default="", help="Optional CSV report for high line-coverage duplicate candidates.") | ||
| 32 | parser.add_argument("--line-coverage-threshold", type=float, default=0.95) | ||
| 33 | parser.add_argument("--line-duplicate-limit", type=int, default=10000) | ||
| 34 | args = parser.parse_args() | ||
| 35 | |||
| 36 | psycopg = _import_psycopg() | ||
| 37 | lyrics_dir = Path(args.lyrics_dir) | ||
| 38 | paths = iter_lyric_files(lyrics_dir) | ||
| 39 | if args.limit > 0: | ||
| 40 | paths = paths[: args.limit] | ||
| 41 | print(f"[pg-import] files: {len(paths)}", file=sys.stderr, flush=True) | ||
| 42 | |||
| 43 | imported = 0 | ||
| 44 | exact_deleted = 0 | ||
| 45 | line_reported = 0 | ||
| 46 | nul_cleaned = 0 | ||
| 47 | with psycopg.connect(args.dsn) as conn: | ||
| 48 | for start in range(0, len(paths), args.batch_size): | ||
| 49 | batch = paths[start : start + args.batch_size] | ||
| 50 | with conn.transaction(): | ||
| 51 | with conn.cursor() as cursor: | ||
| 52 | for path in batch: | ||
| 53 | lyric_id, line_rows, cleaned = _upsert_lyric(cursor, path, lyrics_dir) | ||
| 54 | nul_cleaned += cleaned | ||
| 55 | cursor.execute("delete from lyric_lines where lyric_id = %s", (lyric_id,)) | ||
| 56 | if line_rows: | ||
| 57 | cursor.executemany( | ||
| 58 | """ | ||
| 59 | insert into lyric_lines | ||
| 60 | (lyric_id, role, line_no, normalized_line, line_hash) | ||
| 61 | values (%s, %s, %s, %s, %s) | ||
| 62 | """, | ||
| 63 | line_rows, | ||
| 64 | ) | ||
| 65 | imported += 1 | ||
| 66 | _progress("import", imported, len(paths), step=args.batch_size) | ||
| 67 | if not args.skip_dedup_exact: | ||
| 68 | exact_deleted = _soft_delete_exact_duplicates(conn, Path(args.duplicate_report)) | ||
| 69 | if args.line_duplicate_report: | ||
| 70 | line_reported = _write_line_duplicate_report( | ||
| 71 | conn, | ||
| 72 | Path(args.line_duplicate_report), | ||
| 73 | threshold=args.line_coverage_threshold, | ||
| 74 | limit=args.line_duplicate_limit, | ||
| 75 | ) | ||
| 76 | print( | ||
| 77 | { | ||
| 78 | "imported": imported, | ||
| 79 | "records_with_nul_cleaned": nul_cleaned, | ||
| 80 | "exact_duplicates_soft_deleted": exact_deleted, | ||
| 81 | "line_duplicate_candidates_reported": line_reported, | ||
| 82 | } | ||
| 83 | ) | ||
| 84 | |||
| 85 | |||
| 86 | def _upsert_lyric(cursor: Any, path: Path, lyrics_dir: Path) -> tuple[int, list[tuple[object, ...]], int]: | ||
| 87 | record = record_from_file(path, base_dir=lyrics_dir) | ||
| 88 | raw_text, raw_cleaned = _pg_text(record.lyrics) | ||
| 89 | normalized = normalize_lyrics(raw_text) | ||
| 90 | primary_text = _pg_text("\n".join(normalized.primary_lines))[0] | ||
| 91 | translation_text = _pg_text("\n".join(normalized.translation_lines))[0] or None | ||
| 92 | normalized_text = _pg_text(normalized.normalized_full_text)[0] | ||
| 93 | exact_text = fingerprint_text(normalized) | ||
| 94 | exact_hash = hashlib.sha256(exact_text.encode("utf-8")).hexdigest() | ||
| 95 | cursor.execute( | ||
| 96 | """ | ||
| 97 | insert into lyrics ( | ||
| 98 | record_id, source_path, title, artist, raw_text, normalized_text, | ||
| 99 | primary_text, translation_text, exact_hash, split_confidence, | ||
| 100 | split_reason, line_count, updated_at, deleted_at | ||
| 101 | ) | ||
| 102 | values ( | ||
| 103 | %(record_id)s, %(source_path)s, %(title)s, %(artist)s, %(raw_text)s, | ||
| 104 | %(normalized_text)s, %(primary_text)s, %(translation_text)s, | ||
| 105 | %(exact_hash)s, %(split_confidence)s, %(split_reason)s, | ||
| 106 | %(line_count)s, now(), null | ||
| 107 | ) | ||
| 108 | on conflict (record_id) do update set | ||
| 109 | source_path = excluded.source_path, | ||
| 110 | title = excluded.title, | ||
| 111 | artist = excluded.artist, | ||
| 112 | raw_text = excluded.raw_text, | ||
| 113 | normalized_text = excluded.normalized_text, | ||
| 114 | primary_text = excluded.primary_text, | ||
| 115 | translation_text = excluded.translation_text, | ||
| 116 | exact_hash = excluded.exact_hash, | ||
| 117 | split_confidence = excluded.split_confidence, | ||
| 118 | split_reason = excluded.split_reason, | ||
| 119 | line_count = excluded.line_count, | ||
| 120 | updated_at = now(), | ||
| 121 | deleted_at = null | ||
| 122 | returning id | ||
| 123 | """, | ||
| 124 | { | ||
| 125 | "record_id": record.record_id, | ||
| 126 | "source_path": str(path), | ||
| 127 | "title": _pg_text(record.title)[0], | ||
| 128 | "artist": _pg_text(record.artist)[0], | ||
| 129 | "raw_text": raw_text, | ||
| 130 | "normalized_text": normalized_text, | ||
| 131 | "primary_text": primary_text, | ||
| 132 | "translation_text": translation_text, | ||
| 133 | "exact_hash": exact_hash, | ||
| 134 | "split_confidence": _pg_text(normalized.split_confidence)[0], | ||
| 135 | "split_reason": _pg_text(normalized.split_reason)[0], | ||
| 136 | "line_count": len(normalized.primary_lines or normalized.unique_lines), | ||
| 137 | }, | ||
| 138 | ) | ||
| 139 | lyric_id = cursor.fetchone()[0] | ||
| 140 | line_rows: list[tuple[object, ...]] = [] | ||
| 141 | line_rows.extend(_line_rows(lyric_id, "primary", normalized.primary_lines)) | ||
| 142 | line_rows.extend(_line_rows(lyric_id, "translation", normalized.translation_lines)) | ||
| 143 | line_rows.extend(_line_rows(lyric_id, "unknown", normalized.unknown_lines)) | ||
| 144 | return lyric_id, line_rows, int(raw_cleaned) | ||
| 145 | |||
| 146 | |||
| 147 | def _line_rows(lyric_id: int, role: str, lines: tuple[str, ...]) -> list[tuple[object, ...]]: | ||
| 148 | rows: list[tuple[object, ...]] = [] | ||
| 149 | for index, line in enumerate(lines): | ||
| 150 | line = _pg_text(line)[0] or "" | ||
| 151 | line_hash = hashlib.sha256(line.encode("utf-8")).hexdigest() | ||
| 152 | rows.append((lyric_id, role, index, line, line_hash)) | ||
| 153 | return rows | ||
| 154 | |||
| 155 | |||
| 156 | def _pg_text(value: str | None) -> tuple[str | None, bool]: | ||
| 157 | if value is None: | ||
| 158 | return None, False | ||
| 159 | if "\x00" not in value: | ||
| 160 | return value, False | ||
| 161 | return value.replace("\x00", ""), True | ||
| 162 | |||
| 163 | |||
| 164 | def _soft_delete_exact_duplicates(conn: Any, report_path: Path) -> int: | ||
| 165 | print("[pg-import] deduplicate exact_hash duplicates", file=sys.stderr, flush=True) | ||
| 166 | with conn.transaction(): | ||
| 167 | with conn.cursor() as cursor: | ||
| 168 | cursor.execute( | ||
| 169 | """ | ||
| 170 | with ranked as ( | ||
| 171 | select | ||
| 172 | id, | ||
| 173 | exact_hash, | ||
| 174 | first_value(id) over ( | ||
| 175 | partition by exact_hash | ||
| 176 | order by | ||
| 177 | case when source_path like '%/None_%' then 1 else 0 end, | ||
| 178 | line_count desc, | ||
| 179 | length(primary_text) desc, | ||
| 180 | id | ||
| 181 | ) as kept_id, | ||
| 182 | row_number() over ( | ||
| 183 | partition by exact_hash | ||
| 184 | order by | ||
| 185 | case when source_path like '%/None_%' then 1 else 0 end, | ||
| 186 | line_count desc, | ||
| 187 | length(primary_text) desc, | ||
| 188 | id | ||
| 189 | ) as rn | ||
| 190 | from lyrics | ||
| 191 | where deleted_at is null | ||
| 192 | ), | ||
| 193 | to_delete as ( | ||
| 194 | select id, exact_hash, kept_id | ||
| 195 | from ranked | ||
| 196 | where rn > 1 | ||
| 197 | ), | ||
| 198 | updated as ( | ||
| 199 | update lyrics l | ||
| 200 | set deleted_at = now(), updated_at = now() | ||
| 201 | from to_delete d | ||
| 202 | where l.id = d.id | ||
| 203 | returning | ||
| 204 | l.id as duplicate_id, | ||
| 205 | l.record_id as duplicate_record_id, | ||
| 206 | l.source_path as duplicate_source_path, | ||
| 207 | d.exact_hash, | ||
| 208 | d.kept_id | ||
| 209 | ) | ||
| 210 | select | ||
| 211 | u.duplicate_id, | ||
| 212 | u.duplicate_record_id, | ||
| 213 | u.duplicate_source_path, | ||
| 214 | k.id as kept_id, | ||
| 215 | k.record_id as kept_record_id, | ||
| 216 | k.source_path as kept_source_path, | ||
| 217 | u.exact_hash | ||
| 218 | from updated u | ||
| 219 | join lyrics k on k.id = u.kept_id | ||
| 220 | order by u.exact_hash, u.duplicate_id | ||
| 221 | """ | ||
| 222 | ) | ||
| 223 | rows = cursor.fetchall() | ||
| 224 | _write_rows( | ||
| 225 | report_path, | ||
| 226 | [ | ||
| 227 | "duplicate_id", | ||
| 228 | "duplicate_record_id", | ||
| 229 | "duplicate_source_path", | ||
| 230 | "kept_id", | ||
| 231 | "kept_record_id", | ||
| 232 | "kept_source_path", | ||
| 233 | "exact_hash", | ||
| 234 | ], | ||
| 235 | rows, | ||
| 236 | ) | ||
| 237 | print(f"[pg-import] exact duplicates soft-deleted: {len(rows)}", file=sys.stderr, flush=True) | ||
| 238 | return len(rows) | ||
| 239 | |||
| 240 | |||
| 241 | def _write_line_duplicate_report(conn: Any, report_path: Path, *, threshold: float, limit: int) -> int: | ||
| 242 | print("[pg-import] report high line-coverage duplicate candidates", file=sys.stderr, flush=True) | ||
| 243 | with conn.cursor() as cursor: | ||
| 244 | cursor.execute( | ||
| 245 | """ | ||
| 246 | with pairs as ( | ||
| 247 | select | ||
| 248 | a.lyric_id as left_id, | ||
| 249 | b.lyric_id as right_id, | ||
| 250 | count(*) as matched_lines | ||
| 251 | from lyric_lines a | ||
| 252 | join lyric_lines b | ||
| 253 | on a.line_hash = b.line_hash | ||
| 254 | and a.lyric_id < b.lyric_id | ||
| 255 | join lyrics la on la.id = a.lyric_id and la.deleted_at is null | ||
| 256 | join lyrics lb on lb.id = b.lyric_id and lb.deleted_at is null | ||
| 257 | where a.role = 'primary' | ||
| 258 | and b.role = 'primary' | ||
| 259 | group by a.lyric_id, b.lyric_id | ||
| 260 | ) | ||
| 261 | select | ||
| 262 | p.left_id, | ||
| 263 | l1.record_id as left_record_id, | ||
| 264 | l1.source_path as left_source_path, | ||
| 265 | p.right_id, | ||
| 266 | l2.record_id as right_record_id, | ||
| 267 | l2.source_path as right_source_path, | ||
| 268 | p.matched_lines, | ||
| 269 | l1.line_count as left_line_count, | ||
| 270 | l2.line_count as right_line_count, | ||
| 271 | p.matched_lines::float / greatest(l1.line_count, l2.line_count) as line_coverage | ||
| 272 | from pairs p | ||
| 273 | join lyrics l1 on l1.id = p.left_id | ||
| 274 | join lyrics l2 on l2.id = p.right_id | ||
| 275 | where p.matched_lines::float / greatest(l1.line_count, l2.line_count) >= %s | ||
| 276 | order by line_coverage desc, matched_lines desc | ||
| 277 | limit %s | ||
| 278 | """, | ||
| 279 | (threshold, limit), | ||
| 280 | ) | ||
| 281 | rows = cursor.fetchall() | ||
| 282 | _write_rows( | ||
| 283 | report_path, | ||
| 284 | [ | ||
| 285 | "left_id", | ||
| 286 | "left_record_id", | ||
| 287 | "left_source_path", | ||
| 288 | "right_id", | ||
| 289 | "right_record_id", | ||
| 290 | "right_source_path", | ||
| 291 | "matched_lines", | ||
| 292 | "left_line_count", | ||
| 293 | "right_line_count", | ||
| 294 | "line_coverage", | ||
| 295 | ], | ||
| 296 | rows, | ||
| 297 | ) | ||
| 298 | print(f"[pg-import] line duplicate candidates reported: {len(rows)}", file=sys.stderr, flush=True) | ||
| 299 | return len(rows) | ||
| 300 | |||
| 301 | |||
| 302 | def _write_rows(report_path: Path, fieldnames: list[str], rows: list[tuple[object, ...]]) -> None: | ||
| 303 | report_path.parent.mkdir(parents=True, exist_ok=True) | ||
| 304 | with report_path.open("w", encoding="utf-8", newline="") as file: | ||
| 305 | writer = csv.writer(file) | ||
| 306 | writer.writerow(fieldnames) | ||
| 307 | writer.writerows(rows) | ||
| 308 | |||
| 309 | |||
| 310 | def _progress(label: str, current: int, total: int, *, step: int) -> None: | ||
| 311 | if current == total or current % step == 0: | ||
| 312 | print(f"[pg-import] {label}: {current}/{total}", file=sys.stderr, flush=True) | ||
| 313 | |||
| 314 | |||
| 315 | def _import_psycopg(): | ||
| 316 | try: | ||
| 317 | import psycopg | ||
| 318 | |||
| 319 | return psycopg | ||
| 320 | except ModuleNotFoundError: | ||
| 321 | print( | ||
| 322 | "Missing dependency: psycopg. Install it with:\n" | ||
| 323 | " python -m pip install 'psycopg[binary]'", | ||
| 324 | file=sys.stderr, | ||
| 325 | ) | ||
| 326 | raise SystemExit(1) | ||
| 327 | |||
| 328 | |||
| 329 | if __name__ == "__main__": | ||
| 330 | main() |
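
The import script above strips NUL characters before writing text to PostgreSQL (which rejects them in `text` values), stores a SHA-256 hash per normalized line, and can report pairs whose line coverage is high. The following is a minimal in-memory sketch of those three ideas. The names `strip_nul`, `line_hash`, and `line_coverage` are hypothetical, and the coverage function works on distinct line hashes, which is an assumption: the SQL report counts joined rows and divides by the stored `line_count`, so the two can differ when a lyric repeats lines.

```python
from __future__ import annotations

import hashlib


def strip_nul(value: str | None) -> tuple[str | None, bool]:
    """Return the text without NUL characters and whether any were removed."""
    if value is None:
        return None, False
    if "\x00" not in value:
        return value, False
    return value.replace("\x00", ""), True


def line_hash(line: str) -> str:
    """Hash one normalized line the same way for import and for lookup."""
    return hashlib.sha256(line.encode("utf-8")).hexdigest()


def line_coverage(left_lines: list[str], right_lines: list[str]) -> float:
    """Share of distinct line hashes common to both lyrics (assumed definition)."""
    left = {line_hash(line) for line in left_lines if line}
    right = {line_hash(line) for line in right_lines if line}
    longest = max(len(left), len(right))
    if longest == 0:
        return 0.0
    return len(left & right) / longest


if __name__ == "__main__":
    print(strip_nul("la\x00 la"))
    print(line_coverage(["a", "b", "c", "d"], ["a", "b", "c", "x"]))
```

Hashing lines rather than comparing text keeps the `lyric_lines` lookup a plain equality match on `line_hash`, which an ordinary B-tree index can serve.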
scripts/init_postgres.py
0 → 100644
| 1 | """Initialize PostgreSQL schema for lyric dedup storage.""" | ||
| 2 | |||
| 3 | from __future__ import annotations | ||
| 4 | |||
| 5 | import argparse | ||
| 6 | import sys | ||
| 7 | from pathlib import Path | ||
| 8 | |||
| 9 | |||
| 10 | PROJECT_ROOT = Path(__file__).resolve().parents[1] | ||
| 11 | SCHEMA_PATH = PROJECT_ROOT / "scripts" / "postgres_schema.sql" | ||
| 12 | |||
| 13 | |||
| 14 | def main() -> None: | ||
| 15 | parser = argparse.ArgumentParser(description="Initialize PostgreSQL schema for lyric dedup.") | ||
| 16 | parser.add_argument("--dsn", required=True, help="PostgreSQL DSN, e.g. postgresql://user:pass@localhost:5432/lyric_dedup") | ||
| 17 | parser.add_argument("--schema", default=str(SCHEMA_PATH)) | ||
| 18 | args = parser.parse_args() | ||
| 19 | |||
| 20 | psycopg = _import_psycopg() | ||
| 21 | schema_sql = Path(args.schema).read_text(encoding="utf-8") | ||
| 22 | with psycopg.connect(args.dsn) as conn: | ||
| 23 | with conn.cursor() as cursor: | ||
| 24 | cursor.execute(schema_sql) | ||
| 25 | conn.commit() | ||
| 26 | print(f"initialized schema from {args.schema}") | ||
| 27 | |||
| 28 | |||
| 29 | def _import_psycopg(): | ||
| 30 | try: | ||
| 31 | import psycopg | ||
| 32 | |||
| 33 | return psycopg | ||
| 34 | except ModuleNotFoundError: | ||
| 35 | print( | ||
| 36 | "Missing dependency: psycopg. Install it with:\n" | ||
| 37 | " python -m pip install 'psycopg[binary]'", | ||
| 38 | file=sys.stderr, | ||
| 39 | ) | ||
| 40 | raise SystemExit(1) | ||
| 41 | |||
| 42 | |||
| 43 | if __name__ == "__main__": | ||
| 44 | main() |
scripts/postgres_schema.sql
0 → 100644
| 1 | create extension if not exists pg_trgm; | ||
| 2 | |||
| 3 | create table if not exists lyrics ( | ||
| 4 | id bigserial primary key, | ||
| 5 | record_id text not null unique, | ||
| 6 | source_path text not null, | ||
| 7 | title text, | ||
| 8 | artist text, | ||
| 9 | raw_text text not null, | ||
| 10 | normalized_text text not null, | ||
| 11 | primary_text text not null, | ||
| 12 | translation_text text, | ||
| 13 | exact_hash text not null, | ||
| 14 | split_confidence text, | ||
| 15 | split_reason text, | ||
| 16 | line_count integer not null, | ||
| 17 | created_at timestamptz not null default now(), | ||
| 18 | updated_at timestamptz not null default now(), | ||
| 19 | deleted_at timestamptz | ||
| 20 | ); | ||
| 21 | |||
| 22 | create index if not exists lyrics_exact_hash_idx | ||
| 23 | on lyrics (exact_hash) | ||
| 24 | where deleted_at is null; | ||
| 25 | |||
| 26 | create index if not exists lyrics_primary_text_trgm_idx | ||
| 27 | on lyrics using gin (primary_text gin_trgm_ops); | ||
| 28 | |||
| 29 | create table if not exists lyric_lines ( | ||
| 30 | lyric_id bigint not null references lyrics(id) on delete cascade, | ||
| 31 | role text not null, | ||
| 32 | line_no integer not null, | ||
| 33 | normalized_line text not null, | ||
| 34 | line_hash text not null, | ||
| 35 | primary key (lyric_id, role, line_no) | ||
| 36 | ); | ||
| 37 | |||
| 38 | create index if not exists lyric_lines_hash_idx | ||
| 39 | on lyric_lines (line_hash); | ||
| 40 | |||
| 41 | create index if not exists lyric_lines_lyric_id_idx | ||
| 42 | on lyric_lines (lyric_id); |
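Reviewer note: the three indexes above correspond to the three recall layers named in the commit message (`exact_hash`, `pg_trgm`, line-level hash). The sketch below shows how candidate recall against this schema might be wired up. It is not the import or evaluation script from this commit: the function names, the query text, the `limit` default, and the use of SHA-256 for `line_hash` are all assumptions made for illustration. Only the table and column names come from the schema.

```python
from __future__ import annotations

import hashlib

# Layer 1: exact match on the whole-lyric hash (served by lyrics_exact_hash_idx).
EXACT_SQL = (
    "select record_id from lyrics "
    "where exact_hash = %s and deleted_at is null"
)
# Layer 2: trigram similarity. "%%" is a literal "%" (the pg_trgm operator)
# because "%" must be escaped in a parameterized psycopg query.
TRGM_SQL = (
    "select record_id from lyrics "
    "where deleted_at is null and primary_text %% %s "
    "order by similarity(primary_text, %s) desc limit %s"
)
# Layer 3: lyrics sharing the most distinct line hashes with the query.
LINE_SQL = (
    "select l.record_id from lyric_lines ll "
    "join lyrics l on l.id = ll.lyric_id "
    "where ll.line_hash = any(%s) and l.deleted_at is null "
    "group by l.record_id "
    "order by count(distinct ll.line_hash) desc limit %s"
)


def line_hashes(normalized_text: str) -> list[str]:
    """Hash each non-empty line. SHA-256 is an assumption, not the commit's choice."""
    lines = (line.strip() for line in normalized_text.splitlines())
    return [hashlib.sha256(line.encode("utf-8")).hexdigest() for line in lines if line]


def recall_candidates(cursor, exact_hash, primary_text, normalized_text, limit=50):
    """Run the three layers in order and return de-duplicated record_ids.

    The result is a candidate list only; the final duplicate / review / new
    decision is expected to stay in the Python checker.
    """
    queries = [
        (EXACT_SQL, (exact_hash,)),
        (TRGM_SQL, (primary_text, primary_text, limit)),
        (LINE_SQL, (line_hashes(normalized_text), limit)),
    ]
    seen: dict[str, None] = {}  # dict keeps first-seen order
    for sql, params in queries:
        cursor.execute(sql, params)
        for (record_id,) in cursor.fetchall():
            seen.setdefault(record_id, None)
    return list(seen)
```

The `%` operator filters by the session's `pg_trgm.similarity_threshold` (0.3 by default), so the trigram layer may need a lower threshold for heavily edited lyrics. Counting distinct line hashes keeps a repeated chorus from inflating the overlap score.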
| ... | @@ -316,6 +316,40 @@ def test_generated_eval_set_uses_stratified_production_mix(tmp_path) -> None: | ... | @@ -316,6 +316,40 @@ def test_generated_eval_set_uses_stratified_production_mix(tmp_path) -> None: |
| 316 | assert all(row["expected"] == "不应去重" for row in rows if row["sample_type"].startswith("negative_")) | 316 | assert all(row["expected"] == "不应去重" for row in rows if row["sample_type"].startswith("negative_")) |
| 317 | 317 | ||
| 318 | 318 | ||
| 319 | def test_generated_hard_eval_set_uses_business_realistic_edge_mix(tmp_path) -> None: | ||
| 320 | library = tmp_path / "library" | ||
| 321 | incoming = tmp_path / "generated" / "incoming" | ||
| 322 | eval_csv = tmp_path / "generated" / "eval_hard.csv" | ||
| 323 | library.mkdir() | ||
| 324 | for idx in range(24): | ||
| 325 | prefix = "AY" if idx % 3 == 0 else "WHHY" | ||
| 326 | lyric = BASE_LYRIC.replace("我爱你", f"我想你{idx}").replace("城市", f"城市{idx}") | ||
| 327 | if idx % 4 == 0: | ||
| 328 | lyric += "\nI miss you tonight\nUnder the moonlight\nNever let me go\n" | ||
| 329 | (library / f"{idx}_{prefix}{idx:06d}.txt").write_text(lyric, encoding="utf-8") | ||
| 330 | |||
| 331 | generate_eval_set( | ||
| 332 | library_dir=library, | ||
| 333 | output_dir=incoming, | ||
| 334 | csv_path=eval_csv, | ||
| 335 | size=40, | ||
| 336 | positive_ratio=0.3, | ||
| 337 | profile="hard", | ||
| 338 | ) | ||
| 339 | |||
| 340 | rows = list(csv.DictReader(eval_csv.open(encoding="utf-8"))) | ||
| 341 | manifest = json.loads((tmp_path / "generated" / "eval_hard.csv.manifest.json").read_text(encoding="utf-8")) | ||
| 342 | sample_types = {row["sample_type"] for row in rows} | ||
| 343 | |||
| 344 | assert len(rows) == 40 | ||
| 345 | assert manifest["profile"] == "hard" | ||
| 346 | assert "positive_realistic_variant" in manifest["plan"] | ||
| 347 | assert "negative_near_neighbor_holdout_full_song" in manifest["plan"] | ||
| 348 | assert "negative_long_fragment" in sample_types | ||
| 349 | assert "negative_catalog_mashup" in sample_types | ||
| 350 | assert any(row["sample_type"].startswith("positive_") for row in rows) | ||
| 351 | |||
| 352 | |||
| 319 | def test_foreign_original_with_added_chinese_translation_is_duplicate() -> None: | 353 | def test_foreign_original_with_added_chinese_translation_is_duplicate() -> None: |
| 320 | checker = DuplicateChecker() | 354 | checker = DuplicateChecker() |
| 321 | checker.add_record( | 355 | checker.add_record( | ... | ... |