Commit 49008962 4900896283e6b52190437fd467f52ab75caf2530 by 沈秋雨

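Add PostgreSQL dedup retrieval pipeline and hard evaluation-set support

- Add a PostgreSQL import script, evaluation script, and schema definition, supporting a three-tier recall strategy based on exact_hash, pg_trgm, and line-level hashes
- Add a hard profile to the evaluation CLI, covering typos, OCR errors, whole-block translations, medley fragments, and other scenarios closer to real business edge cases
- Adjust the review thresholds and match-reason wording in checker.py, and improve the decision logic for translation-line similarity and chorus-only repetition
- Update README, TEST_WORKFLOW, and unit tests accordingly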

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent ba39ce6a
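# PostgreSQL Migration Guide
This document explains why the project is migrating from the `.pkl` index to PostgreSQL, what role PostgreSQL plays in the project, and the concrete steps from local initialization to importing the lyric library.
## 1. Why migrate
The current `outputs/indexes/library_lyrics.pkl` is a snapshot of a Python `DuplicateChecker` object. It suits experiments and offline evaluation, but not an online service:
- Incremental inserts, deletes, and rollbacks are awkward.
- With multiple instances deployed, it is hard to keep the same `.pkl` in sync.
- There are no database transactions, auditing, or version management.
- The index structure is tied to Python pickle, which hurts operations and long-term maintenance.
PostgreSQL is not meant to replace all of the duplicate-detection logic; it replaces the "data storage + candidate recall" part. The final decision stays in the Python service layer.
## 2. The role of PostgreSQL in this project
Recommended division of labor:
```text
PostgreSQL:
  store raw lyrics, normalized lyrics, and line-level features
  exact recall via exact_hash
  approximate text recall via pg_trgm
  line-overlap recall via lyric_lines
Python:
  normalize_lyrics
  compute Jaccard / line coverage
  handle translation lines, chorus collisions, and fragment protection
  produce the final duplicate / review / new output
```
Do not let PostgreSQL's `similarity()` decide `duplicate` directly. The database only finds candidates; the final decision should still reuse the current algorithm's rules.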
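To make the split concrete, here is a minimal sketch of the kind of Python-side scoring that would run on recalled candidates. It is not the project's `DuplicateChecker`: the function names, the use of whole lines as the Jaccard tokens, and the thresholds `0.9` and `0.5` are all assumptions chosen for illustration.

```python
def jaccard(a: set[str], b: set[str]) -> float:
    """Jaccard similarity of two sets; 0.0 when both are empty."""
    union = a | b
    return len(a & b) / len(union) if union else 0.0


def line_coverage(incoming: list[str], candidate: list[str]) -> float:
    """Share of distinct incoming lines that also appear in the candidate."""
    incoming_set = set(incoming)
    if not incoming_set:
        return 0.0
    return len(incoming_set & set(candidate)) / len(incoming_set)


def decide(incoming: list[str], candidate: list[str]) -> str:
    """Map scores to a decision. Thresholds are hypothetical, not the project's."""
    score = jaccard(set(incoming), set(candidate))
    coverage = line_coverage(incoming, candidate)
    if score >= 0.9 and coverage >= 0.9:
        return "duplicate"
    if coverage >= 0.5:
        # High coverage but low Jaccard is typical of a fragment:
        # every incoming line matches, yet most of the song is missing.
        return "review"
    return "new"
```

In this sketch a two-line fragment of a six-line song gets full line coverage but a Jaccard score of about 0.33, so it lands in `review` rather than `duplicate`. That is the kind of fragment protection a raw database similarity score cannot express on its own.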
## 3. Mapping between the pkl index and PostgreSQL
The current `.pkl` mainly contains:
```text
_records              full index record for each lyric
_exact_hash_to_ids    exact hash -> record ids
_line_to_ids          normalized line -> record ids
_token_to_ids         n-gram token -> record ids
_lsh                  MinHash LSH buckets
```
After migrating to PostgreSQL:
```text
lyrics.exact_hash                corresponds to _exact_hash_to_ids
lyrics.primary_text + pg_trgm    provides approximate text recall
lyric_lines.line_hash            provides line-level inverted-index recall
lyrics / lyric_lines             correspond to the persistable part of _records
```
The first phase does not migrate MinHash LSH. Recall quality is validated first with `exact_hash + pg_trgm + line_hash`.
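As a rough illustration of what the two hash columns could hold, the sketch below derives one hash for the whole normalized text and one per line. The choice of SHA-256, the newline join, and the helper names are assumptions; the project's actual hashing is defined in its own normalization code, which this section does not show.

```python
import hashlib


def sha256_hex(text: str) -> str:
    """Hex digest of the UTF-8 encoded text."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def build_hashes(normalized_lines: list[str]) -> tuple[str, list[str]]:
    """Return (exact_hash, line_hashes) for already-normalized lines.

    exact_hash covers the whole text, so it only matches identical songs.
    line_hashes cover single lines, so they support line-overlap recall.
    """
    exact_hash = sha256_hex("\n".join(normalized_lines))
    line_hashes = [sha256_hex(line) for line in normalized_lines]
    return exact_hash, line_hashes
```

Two inputs that normalize to the same lines get the same `exact_hash`, while songs that merely share a chorus share only some `line_hash` values.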
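## 4. Local PostgreSQL basics
PostgreSQL is a relational database. Common concepts:
```text
database     a database, e.g. lyric_dedup
schema       a namespace, public by default
table        a table, e.g. lyrics
index        an index, used to speed up queries
extension    an extension, e.g. pg_trgm
DSN          a connection string
```
A common way to write the connection string for a local database:
```bash
postgresql:///lyric_dedup
```
This means: connect to the `lyric_dedup` database on the local PostgreSQL server as the current operating-system user.
With a username and password: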
```bash
postgresql://postgres:postgres@localhost:5432/lyric_dedup
```
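To see which parts the two DSN forms actually specify, the following sketch splits them with Python's standard `urllib.parse`. This is only an inspection aid: the helper name `describe_dsn` is made up here, and the PostgreSQL client library does its own parsing and fills in any missing parts from its defaults.

```python
from urllib.parse import urlsplit


def describe_dsn(dsn: str) -> dict[str, object]:
    """Split a postgresql:// URI into the parts it explicitly specifies."""
    parts = urlsplit(dsn)
    return {
        "user": parts.username,  # None when the DSN names no user
        "host": parts.hostname,  # None when the DSN names no host
        "port": parts.port,      # None when the DSN names no port
        "database": parts.path.lstrip("/"),
    }


short_form = describe_dsn("postgresql:///lyric_dedup")
full_form = describe_dsn("postgresql://postgres:postgres@localhost:5432/lyric_dedup")
```

The short form leaves user, host, and port unset, which is why it depends on the current system user and a local server.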
## 5. Scripts added so far
The project now includes:
```text
scripts/postgres_schema.sql
scripts/init_postgres.py
scripts/import_library_postgres.py
scripts/evaluate_postgres.py
```
Purpose:
```text
postgres_schema.sql           creates tables and indexes, enables pg_trgm
init_postgres.py              executes the schema SQL automatically
import_library_postgres.py    scans data/library, normalizes, and imports into PostgreSQL
evaluate_postgres.py          recalls candidates from PostgreSQL and evaluates a CSV
```
## 6. Install dependencies
The current Python environment needs the PostgreSQL driver installed:
```bash
python -m pip install 'psycopg[binary]'
```
If you use a conda environment, make sure the command runs in the `(base)` environment used by the project or in your target environment.
Verify:
```bash
python - <<'PY'
import psycopg
print(psycopg.__version__)
PY
```
## 7. Create the database
You have already run:
```bash
createdb lyric_dedup
```
To confirm the database exists:
```bash
psql -l | grep lyric_dedup
```
Enter the database:
```bash
psql postgresql:///lyric_dedup
```
Exit `psql`:
```text
\q
```
## 8. Initialize the table structure
Run:
```bash
python scripts/init_postgres.py \
--dsn postgresql:///lyric_dedup
```
It executes:
```sql
create extension if not exists pg_trgm;
create table if not exists lyrics (...);
create table if not exists lyric_lines (...);
create index if not exists ...;
```
On success the output looks like:
```text
initialized schema from scripts/postgres_schema.sql
```
To check the tables:
```bash
psql postgresql:///lyric_dedup -c '\dt'
```
To check the extensions:
```bash
psql postgresql:///lyric_dedup -c 'select * from pg_extension;'
```
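## 9. Table structure
### lyrics
Stores the main record for each lyric:
```text
record_id           stable id generated from the current file
source_path         original file path
title / artist      metadata parsed from the file name
raw_text            raw lyrics
normalized_text     cleaned full text
primary_text        concatenated original-language lines, mainly used for automatic duplicate detection
translation_text    concatenated translation lines
exact_hash          hash of the normalized original text
split_confidence    confidence of the translation split
split_reason        reason for the translation split
line_count          number of valid lyric lines
deleted_at          soft-delete field
```
### lyric_lines
Stores line-level features:
```text
lyric_id            references lyrics.id
role                primary / translation / unknown
line_no             line number
normalized_line     normalized lyric line
line_hash           line hash
```
Purpose: quickly find which songs contain the same line.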
## 10. Small-batch import test
Import 1000 records first to confirm that the environment and schema work:
```bash
python scripts/import_library_postgres.py \
--dsn postgresql:///lyric_dedup \
--lyrics-dir data/library \
--limit 1000
```
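By default, the import script runs one low-risk in-library dedup pass after the import finishes:
```text
Among records with an identical exact_hash, only one is kept; the rest are soft deleted, i.e. lyrics.deleted_at is set.
```
By default, the duplicate-cleanup report is written to: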
```text
outputs/results/postgres_exact_duplicates.csv
```
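The exact-dedup pass described above can be pictured with a small in-memory sketch. It groups `(id, exact_hash)` rows and plans which ids to soft delete. The rule of keeping the smallest id is an assumption for illustration; this document does not say which record the import script keeps.

```python
from collections import defaultdict


def plan_exact_dedup(rows: list[tuple[int, str]]) -> dict[int, list[int]]:
    """Given (id, exact_hash) rows, return {kept_id: [ids to soft delete]}.

    Assumption: the smallest id in each group is kept. The real script's
    choice of survivor may differ.
    """
    by_hash: dict[str, list[int]] = defaultdict(list)
    for lyric_id, exact_hash in rows:
        by_hash[exact_hash].append(lyric_id)

    plan: dict[int, list[int]] = {}
    for ids in by_hash.values():
        if len(ids) > 1:
            keep, *drop = sorted(ids)
            plan[keep] = drop
    return plan


plan = plan_exact_dedup([(1, "a"), (2, "b"), (3, "a"), (4, "a")])
```

Here ids 3 and 4 would get `deleted_at` set while id 1 stays active, and id 2 is untouched because its hash is unique.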
To import only, without exact dedup:
```bash
python scripts/import_library_postgres.py \
--dsn postgresql:///lyric_dedup \
--lyrics-dir data/library \
--limit 1000 \
--skip-dedup-exact
```
Check the counts:
```bash
psql postgresql:///lyric_dedup -c 'select count(*) from lyrics;'
psql postgresql:///lyric_dedup -c 'select count(*) from lyric_lines;'
```
Look at a few rows:
```bash
psql postgresql:///lyric_dedup -c \
'select id, record_id, title, artist, line_count from lyrics limit 5;'
```
## 11. Full import
Once the small batch looks fine, import the full library:
```bash
python scripts/import_library_postgres.py \
--dsn postgresql:///lyric_dedup \
--lyrics-dir data/library
```
The script shows progress:
```text
[pg-import] files: 70295
[pg-import] import: 500/70295
...
```
The import is an upsert: importing the same `record_id` again updates the existing row instead of inserting a duplicate.
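The upsert behavior can be demonstrated without a PostgreSQL server. The sketch below uses Python's built-in `sqlite3` purely as a stand-in, because SQLite accepts the same `insert ... on conflict ... do update` form. The two-column table and the statement are assumptions; the import script's real SQL is not shown in this document.

```python
import sqlite3

# Stand-in database: SQLite in memory, NOT the project's PostgreSQL schema.
conn = sqlite3.connect(":memory:")
conn.execute("create table lyrics (record_id text primary key, title text)")

# Hypothetical upsert keyed on record_id. With psycopg the placeholders
# would be %s instead of ?.
UPSERT = """
insert into lyrics (record_id, title) values (?, ?)
on conflict (record_id) do update set title = excluded.title
"""


def upsert(record_id: str, title: str) -> None:
    """Insert the row, or update the title if record_id already exists."""
    with conn:
        conn.execute(UPSERT, (record_id, title))


upsert("rec-1", "old title")
upsert("rec-1", "new title")  # same record_id: updates, does not duplicate
rows = conn.execute("select record_id, title from lyrics").fetchall()
```

After both calls the table holds a single row for `rec-1` with the newer title, which is the property that makes re-running the import safe.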
To additionally generate a report of "high line-coverage near-duplicate candidates" without deleting anything automatically:
```bash
python scripts/import_library_postgres.py \
--dsn postgresql:///lyric_dedup \
--lyrics-dir data/library \
--line-duplicate-report outputs/results/postgres_line_duplicates.csv \
--line-coverage-threshold 0.95
```
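Note: the line-coverage near-duplicate report can be slow and is intended only for spot checks. The current script does not automatically soft delete these near-duplicate candidates.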
## 12. Basic SQL verification
### exact hash duplicates
Find duplicated normalized hashes:
```bash
psql postgresql:///lyric_dedup -c "
select exact_hash, count(*)
from lyrics
where deleted_at is null
group by exact_hash
having count(*) > 1
order by count(*) desc
limit 20;
"
```
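If the import ran without `--skip-dedup-exact`, there should in theory be no active exact duplicates left here. Records that have already been cleaned up can be inspected like this: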
```bash
psql postgresql:///lyric_dedup -c "
select count(*)
from lyrics
where deleted_at is not null;
"
```
### pg_trgm similarity query
Test `pg_trgm`:
```bash
psql postgresql:///lyric_dedup -c "
select id, title, similarity(primary_text, '我爱你在每个夜里') as sim
from lyrics
where primary_text % '我爱你在每个夜里'
order by sim desc
limit 10;
"
```
### Line-level overlap
Find which songs a given line appears in:
```bash
psql postgresql:///lyric_dedup -c "
select l.id, l.title, ll.normalized_line
from lyric_lines ll
join lyrics l on l.id = ll.lyric_id
where ll.normalized_line = '我爱你在每个夜里'
limit 20;
"
```
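## 13. How future duplicate-check queries should work
Planned PostgreSQL-based duplicate-check flow:
```text
1. Python reads the incoming lyrics
2. normalize_lyrics
3. SQL exact_hash recall
4. SQL pg_trgm recall
5. SQL lyric_lines line-level recall
6. merge candidate ids
7. fetch normalized data for the candidates
8. Python reuses the current scoring rules
9. output duplicate / review / new
```
Illustrative SQL: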
```sql
select id
from lyrics
where exact_hash = $1
and deleted_at is null;
```
```sql
select id, similarity(primary_text, $1) as sim
from lyrics
where deleted_at is null
and primary_text % $1
order by sim desc
limit 200;
```
```sql
select lyric_id, count(*) as matched_lines
from lyric_lines
where role = 'primary'
and line_hash = any($1)
group by lyric_id
order by matched_lines desc
limit 200;
```
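The "merge candidate ids" step that follows these three queries might look like the sketch below. The function name, the channel order (exact first, then trigram, then line overlap), and the `limit` default are assumptions, not the project's code.

```python
def merge_candidates(*channels: list[int], limit: int = 200) -> list[int]:
    """Merge id lists from several recall channels.

    Ids keep the order in which they are first seen, so earlier channels
    take priority, and each id appears at most once.
    """
    merged: list[int] = []
    seen: set[int] = set()
    for channel in channels:
        for lyric_id in channel:
            if lyric_id not in seen:
                seen.add(lyric_id)
                merged.append(lyric_id)
    return merged[:limit]


exact_ids = [7]
trgm_ids = [3, 7, 9]
line_ids = [9, 4]
candidates = merge_candidates(exact_ids, trgm_ids, line_ids)
```

Deduplicating here means each candidate song is fetched and scored by the Python layer only once, even when several recall channels return it.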
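## 14. Incremental update design
Adding a song:
```text
1. normalize
2. recall candidates with PostgreSQL
3. decide in Python
4. duplicate: reject, or link to the existing record
5. review: send to manual review
6. new: write to lyrics and lyric_lines
```
Deleting a song: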
```sql
update lyrics
set deleted_at = now(), updated_at = now()
where id = ...;
```
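Physical deletion is not recommended unless you are sure no audit trail is needed.
Updating a song:
```text
1. update lyrics.raw_text / normalized_text / primary_text / exact_hash
2. delete the old lyric_lines
3. insert the new lyric_lines
4. run the whole process in a single transaction
```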
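The update steps can be sketched as one transaction. As before, Python's built-in `sqlite3` stands in for PostgreSQL so the example runs anywhere, and the reduced tables and the function `replace_lyric` are assumptions rather than the project's schema or code.

```python
import sqlite3

# Stand-in database with a reduced, hypothetical version of the two tables.
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    create table lyrics (id integer primary key, primary_text text);
    create table lyric_lines (lyric_id integer, line_no integer, normalized_line text);
    insert into lyrics values (1, 'old line');
    insert into lyric_lines values (1, 1, 'old line');
    """
)


def replace_lyric(lyric_id: int, lines: list[str]) -> None:
    """Update the main row and swap its lines inside one transaction."""
    with conn:  # commits on success, rolls back if any statement raises
        conn.execute(
            "update lyrics set primary_text = ? where id = ?",
            ("\n".join(lines), lyric_id),
        )
        conn.execute("delete from lyric_lines where lyric_id = ?", (lyric_id,))
        conn.executemany(
            "insert into lyric_lines values (?, ?, ?)",
            [(lyric_id, no, line) for no, line in enumerate(lines, start=1)],
        )


replace_lyric(1, ["first line", "second line"])
stored = conn.execute(
    "select line_no, normalized_line from lyric_lines order by line_no"
).fetchall()
```

Keeping all three statements in one transaction means a reader never sees a song whose `lyrics` row is new while its `lyric_lines` rows are still old or missing.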
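## 15. PostgreSQL-based evaluation
Evaluation still requires generating a test set first. The test set is "input samples + expected labels"; the PostgreSQL-based evaluation only recalls candidates from the PostgreSQL database and computes metrics.
If there is no test set yet, generate one first: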
```bash
python -m lyric_dedup.cli generate-eval-set \
--library-dir data/library \
--lyrics-dir data/generated_eval/incoming \
--csv data/generated_eval/eval_5000.csv \
--size 5000 \
--positive-ratio 0.3
```
Then run the PostgreSQL-based evaluation:
```bash
python scripts/evaluate_postgres.py \
--dsn postgresql:///lyric_dedup \
--csv data/generated_eval/eval_5000.csv \
--base-dir data/generated_eval \
--out outputs/results/postgres_eval_5000.csv
```
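It will:
```text
1. normalize the eval samples
2. recall via PostgreSQL exact_hash
3. recall via pg_trgm on primary_text (only with --enable-trgm)
4. recall via lyric_lines.line_hash
5. merge the candidates
6. re-score the candidates with the Python DuplicateChecker
7. output duplicate / review / new and the metrics
```
To also count `review` as "caught a suspicious sample":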
```bash
python scripts/evaluate_postgres.py \
--dsn postgresql:///lyric_dedup \
--csv data/generated_eval/eval_50000.csv \
--base-dir data/generated_eval \
--positive-decisions duplicate,review \
--out outputs/results/postgres_eval_50000_review_positive.csv
```
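The effect of `--positive-decisions` can be illustrated with a small precision/recall calculation. This is a sketch under assumptions: the row format and the choice of precision and recall as the metrics are illustrative, since this document does not list exactly which metrics the script reports.

```python
def precision_recall(
    rows: list[tuple[bool, str]],
    positive_decisions: set[str],
) -> tuple[float, float]:
    """rows are (expected_should_dedupe, decision) pairs."""
    tp = fp = fn = 0
    for expected_positive, decision in rows:
        predicted_positive = decision in positive_decisions
        if predicted_positive and expected_positive:
            tp += 1
        elif predicted_positive:
            fp += 1
        elif expected_positive:
            fn += 1
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


rows = [(True, "duplicate"), (True, "review"), (False, "review"), (False, "new")]
strict = precision_recall(rows, {"duplicate"})
loose = precision_recall(rows, {"duplicate", "review"})
```

On this toy data, counting `review` as positive raises recall from 0.5 to 1.0 but lowers precision from 1.0 to about 0.67, which is the trade-off to watch when comparing the two runs.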
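Tunable parameters:
```text
--recall-limit            maximum candidates per SQL recall channel, default 100
--enable-trgm             enable pg_trgm whole-text recall; off by default to keep evaluation from being too slow
--trgm-threshold          pg_trgm % match threshold, default 0.3, only used with --enable-trgm
--max-candidates          number of candidates in the final output, default 5
--statement-timeout-ms    timeout for a single SQL statement, default 5000
```
Note: the current PostgreSQL version is a prototype evaluation script. By default it recalls only via `exact_hash + lyric_lines.line_hash`, which keeps the speed more predictable. `pg_trgm` can serve as supplementary recall, but whole-lyric trigram queries can be very slow on the 50,000-sample evaluation set, so validate on a small sample first before using it on the full set.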
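## 16. Migration acceptance criteria
Migration does not end when the import finishes. The PostgreSQL duplicate-check pipeline needs to be verified separately:
```text
1. whether exact duplicates can be found
2. whether punctuation / timestamp / platform noise positives are recalled
3. whether fragment / shared chorus negatives avoid being judged duplicate directly
4. whether the number of candidates recalled by PostgreSQL is reasonable
5. whether the PostgreSQL evaluate metrics meet business requirements
```
First-phase goal:
```text
PostgreSQL handles recall; Python still makes the decision.
```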
## 17. FAQ
### Error `Missing dependency: psycopg`
Run:
```bash
python -m pip install 'psycopg[binary]'
```
### Connection failure
Check whether PostgreSQL is running:
```bash
pg_isready
```
Check whether the database exists:
```bash
psql -l | grep lyric_dedup
```
### `pg_trgm` creation fails
Make sure the connecting user has permission to create extensions. The default local user usually does.
Test manually:
```bash
psql postgresql:///lyric_dedup -c 'create extension if not exists pg_trgm;'
```
### Clearing everything and re-importing
Run with caution:
```bash
psql postgresql:///lyric_dedup -c 'truncate lyric_lines, lyrics restart identity cascade;'
```
Then run the import script again.
## 18. Recommended execution order
You have already completed:
```bash
createdb lyric_dedup
```
Next, run:
```bash
python -m pip install 'psycopg[binary]'
```
```bash
python scripts/init_postgres.py \
--dsn postgresql:///lyric_dedup
```
```bash
python scripts/import_library_postgres.py \
--dsn postgresql:///lyric_dedup \
--lyrics-dir data/library \
--limit 1000
```
Confirm the count:
```bash
psql postgresql:///lyric_dedup -c 'select count(*) from lyrics;'
```
After confirming, run the full import:
```bash
python scripts/import_library_postgres.py \
--dsn postgresql:///lyric_dedup \
--lyrics-dir data/library
```
......@@ -85,15 +85,33 @@ python -m lyric_dedup.cli generate-eval-set \
--positive-ratio 0.3
```
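Business definition used by the generator:
The default `--profile standard` generates the regular production evaluation set. You can also generate a hard set that is closer to real business edge cases: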
```bash
python -m lyric_dedup.cli generate-eval-set \
--profile hard \
--library-dir data/library \
--lyrics-dir data/generated_eval/hard_incoming \
--csv data/generated_eval/eval_hard_5000.csv \
--eval-index data/generated_eval/eval_hard_5000.csv.index.pkl \
--size 5000 \
--positive-ratio 0.3
```
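Business definition for standard:
- The whole library is scanned first and sampled in strata by valid lyric line count, language type, and file-source prefix, rather than by taking a sorted prefix.
- `应去重` (should dedupe) samples only contain style variations of full-song lyrics, such as timestamps, punctuation, platform noise, blank lines, changes in chorus repeat count, and added Chinese translations.
- `应去重` (should dedupe) samples only contain style variations of full-song lyrics, such as timestamps, punctuation, platform noise, blank lines, changes in chorus repeat count, added Chinese translations, and a small number of typos or English misspellings.
- `不应去重` (should not dedupe) samples are mainly real holdout full lyrics, and also include lyric fragments, repeated-chorus collisions, translation-only similarity, new lyrics on the same theme, and short-lyric/placeholder edge samples.
- A lyric fragment must not produce `duplicate` even when it matches part of an existing song; at most it goes to `review`.
- The generator additionally writes `--eval-index`; this index excludes the holdout songs and should be used when evaluating the generated CSV.
- It also generates `*.manifest.json`, recording the seed, library size, holdout count, sample-type distribution, language/source buckets, and sample-source coverage count.
The hard business definition does not deliberately construct abnormal input; it mainly covers cases that are more likely to hit edge conditions in production:
- `应去重` (should dedupe): platform-version noise for the same song, nearly complete lyrics missing one section, a whole block of Chinese translation appended, fairly realistic data-entry/OCR typos, and mixed timestamps and platform metadata.
- `不应去重` (should not dedupe): real holdout new songs, near-neighbor new songs picked from the holdout with priority given to line overlap with the library, long but incomplete single-song fragments, multi-song medley-style fragments, repeated-chorus collisions, translation-only similarity, and short-lyric edge cases.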
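The "near-neighbor" selection can be illustrated with the scoring used by `_near_neighbor_holdouts` in this commit, reduced here to a standalone function that takes plain sets of normalized lines. The function name and signature are simplifications for illustration; the constants (minimum line length 4, per-line cap 5, divisor 20) are taken from the committed code.

```python
from collections import Counter


def near_neighbor_score(holdout_lines: set[str], indexed_songs: list[set[str]]) -> float:
    """Score how strongly a holdout song overlaps the indexed catalog."""
    # Count in how many indexed songs each sufficiently long line appears.
    line_counts: Counter[str] = Counter()
    for song in indexed_songs:
        for line in song:
            if len(line) >= 4:
                line_counts[line] += 1

    useful = {line for line in holdout_lines if len(line) >= 4}
    if not useful:
        return 0.0
    # Share of useful lines that occur anywhere in the catalog...
    shared = sum(1 for line in useful if line_counts[line] > 0)
    # ...plus a small bonus for lines that occur in several songs (capped at 5).
    common_weight = sum(min(line_counts[line], 5) for line in useful)
    return shared / len(useful) + common_weight / (len(useful) * 20)


indexed = [{"hello world", "la la la la"}, {"la la la la", "goodbye moon"}]
score = near_neighbor_score({"la la la la", "brand new line", "oh"}, indexed)
```

In this example one of the two useful lines is shared and it appears in two indexed songs, so the score is 1/2 + 2/40 = 0.55. Holdout songs with the highest scores are the ones picked as near-neighbor negatives.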
First prepare a CSV, for example `data/eval/eval.csv`:
```csv
......
......@@ -108,6 +108,20 @@ python -m lyric_dedup.cli generate-eval-set \
--positive-ratio 0.3
```
To generate a hard test set that is closer to real business edge cases:
```bash
python -m lyric_dedup.cli generate-eval-set \
--profile hard \
--library-dir data/library \
--lyrics-dir data/generated_eval/hard_incoming \
--csv data/generated_eval/eval_hard_5000.csv \
--index outputs/indexes/library_lyrics.pkl \
--eval-index data/generated_eval/eval_hard_5000.csv.index.pkl \
--size 5000 \
--positive-ratio 0.3
```
Default production evaluation definition:
```text
......@@ -120,7 +134,7 @@ python -m lyric_dedup.cli generate-eval-set \
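Business definition:
```text
positive_* = should dedupe (应去重); full-song lyric style variations
positive_* = should dedupe (应去重); full-song lyric style variations, including light typo / English misspelling perturbations
negative_real_holdout_full_song = should not dedupe (不应去重); full real lyrics, already excluded from the eval index
negative_fragment = should not dedupe (不应去重); single-song fragment
negative_shared_chorus = should not dedupe (不应去重); repeated-chorus collision
......@@ -129,6 +143,15 @@ negative_same_theme_synthetic = should not dedupe (不应去重); new lyrics on the same theme
edge_short_or_placeholder = should not dedupe (不应去重); short-lyric/placeholder edge sample
```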
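The hard definition additionally emphasizes real business edge cases rather than deliberately constructed abnormal puzzles:
```text
positive_realistic_variant = should dedupe (应去重); same-song platform-version noise, nearly complete lyrics missing a section, whole-block translation appended, real data-entry/OCR errors
negative_near_neighbor_holdout_full_song = should not dedupe (不应去重); real holdout new songs with substantial line overlap with the library
negative_long_fragment = should not dedupe (不应去重); long but incomplete single-song fragment
negative_catalog_mashup = should not dedupe (不应去重); medley/mashup-style input assembled from fragments of several real lyrics
```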
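How `--positive-ratio` interacts with the hard category weights can be sketched as follows. The weights are the `HARD_SAMPLE_MIX` values from this commit and the rescaling mirrors `_sample_plan`; the helper name `rescale_mix` is made up for illustration, and the integer rounding and remainder handling of the real function are left out.

```python
HARD_SAMPLE_MIX = {
    "positive_realistic_variant": 0.30,
    "negative_real_holdout_full_song": 0.20,
    "negative_near_neighbor_holdout_full_song": 0.20,
    "negative_long_fragment": 0.15,
    "negative_shared_chorus": 0.05,
    "negative_translation_only": 0.04,
    "negative_catalog_mashup": 0.04,
    "edge_short_or_placeholder": 0.02,
}


def rescale_mix(
    base_mix: dict[str, float], positive_key: str, positive_ratio: float
) -> dict[str, float]:
    """Give the positive bucket positive_ratio and spread the rest
    over the other buckets in proportion to their base weights."""
    positive_ratio = max(0.0, min(1.0, positive_ratio))
    negative_total = sum(v for k, v in base_mix.items() if k != positive_key)
    return {
        key: positive_ratio
        if key == positive_key
        else (1.0 - positive_ratio) * value / negative_total
        for key, value in base_mix.items()
    }


mix = rescale_mix(HARD_SAMPLE_MIX, "positive_realistic_variant", 0.5)
```

With `--positive-ratio 0.3` the weights stay at their base values, because the non-positive weights already sum to 0.70. With 0.5, each non-positive bucket shrinks by the same factor, for example `negative_real_holdout_full_song` goes from 0.20 to about 0.143.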
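The generator scans the whole library and samples in strata by valid lyric line count, language type, and file-source prefix. It sets aside a batch of holdout full lyrics as real new-song negatives and generates an eval index that excludes the holdout. Each run also outputs: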
```text
......
......@@ -5,7 +5,7 @@ from __future__ import annotations
import hashlib
import pickle
from dataclasses import dataclass
from enum import StrEnum
from enum import Enum
from pathlib import Path
from lyric_dedup.minhash_lsh import MinHashConfig
......@@ -16,7 +16,7 @@ from lyric_dedup.normalization import lyric_tokens
from lyric_dedup.normalization import normalize_lyrics
class DuplicateDecision(StrEnum):
class DuplicateDecision(str, Enum):
DUPLICATE = "duplicate"
REVIEW = "review"
NEW = "new"
......
......@@ -53,6 +53,12 @@ def main() -> None:
generate.add_argument("--seed", type=int, default=20260602)
generate.add_argument("--index", default="", help="optional source index path recorded in the manifest")
generate.add_argument("--eval-index", default="", help="output index built from non-holdout records for this eval set")
generate.add_argument(
"--profile",
choices=("standard", "hard"),
default="standard",
help="evaluation sample profile: standard production mix or harder business-realistic edge mix",
)
args = parser.parse_args()
if args.command == "build-index":
......@@ -80,6 +86,7 @@ def main() -> None:
seed=args.seed,
index_path=Path(args.index) if args.index else None,
eval_index_path=Path(args.eval_index) if args.eval_index else None,
profile=args.profile,
)
print(json.dumps(summary, ensure_ascii=False))
......
......@@ -21,7 +21,7 @@ from lyric_dedup.normalization import fingerprint_text
from lyric_dedup.normalization import normalize_lyrics
DEFAULT_SAMPLE_MIX = {
STANDARD_SAMPLE_MIX = {
"positive_full_duplicate": 0.30,
"negative_real_holdout_full_song": 0.40,
"negative_fragment": 0.10,
......@@ -30,6 +30,18 @@ DEFAULT_SAMPLE_MIX = {
"negative_same_theme_synthetic": 0.05,
"edge_short_or_placeholder": 0.05,
}
DEFAULT_SAMPLE_MIX = STANDARD_SAMPLE_MIX
HARD_SAMPLE_MIX = {
"positive_realistic_variant": 0.30,
"negative_real_holdout_full_song": 0.20,
"negative_near_neighbor_holdout_full_song": 0.20,
"negative_long_fragment": 0.15,
"negative_shared_chorus": 0.05,
"negative_translation_only": 0.04,
"negative_catalog_mashup": 0.04,
"edge_short_or_placeholder": 0.02,
}
def _progress(message: str) -> None:
......@@ -87,6 +99,7 @@ def generate_eval_set(
seed: int = 20260602,
index_path: Path | None = None,
eval_index_path: Path | None = None,
profile: str = "standard",
) -> dict[str, object]:
"""Generate a stratified production evaluation set.
......@@ -96,7 +109,10 @@ def generate_eval_set(
if size <= 0:
raise ValueError("size must be positive")
_progress(f"start generation: size={size}, positive_ratio={positive_ratio}, seed={seed}")
if profile not in {"standard", "hard"}:
raise ValueError("profile must be 'standard' or 'hard'")
_progress(f"start generation: profile={profile}, size={size}, positive_ratio={positive_ratio}, seed={seed}")
rng = random.Random(seed)
profiles = profile_library(library_dir)
if not profiles:
......@@ -107,9 +123,9 @@ def generate_eval_set(
_progress(f"clean output dir: {output_dir}")
_clean_generated_output_dir(output_dir)
plan = _sample_plan(size, positive_ratio=positive_ratio)
plan = _sample_plan(size, positive_ratio=positive_ratio, profile=profile)
_progress(f"sample plan: {plan}")
holdout_count = min(plan["negative_real_holdout_full_song"], max(1, len(profiles) // 2))
holdout_count = min(_holdout_plan_count(plan), max(1, len(profiles) // 2))
holdout_profiles = _stratified_unique_sample(
profiles,
holdout_count,
......@@ -122,81 +138,95 @@ def generate_eval_set(
groups = _profile_groups(indexed_profiles)
samples: list[GeneratedSample] = []
_progress("build positive_full_duplicate samples")
samples.extend(
_build_positive_samples(
_stratified_sample(groups["normal"], plan["positive_full_duplicate"], rng),
output_dir,
csv_path.parent,
rng,
start_index=len(samples) + 1,
if profile == "hard":
samples.extend(
_build_hard_samples(
plan,
groups=groups,
holdout_profiles=holdout_profiles,
indexed_profiles=indexed_profiles,
output_dir=output_dir,
csv_base=csv_path.parent,
rng=rng,
start_index=len(samples) + 1,
)
)
)
_progress(f"built samples: {len(samples)}/{size}")
_progress("build negative_real_holdout_full_song samples")
samples.extend(
_build_holdout_full_song_samples(
holdout_profiles,
output_dir,
csv_path.parent,
start_index=len(samples) + 1,
else:
_progress("build positive_full_duplicate samples")
samples.extend(
_build_positive_samples(
_stratified_sample(groups["normal"], plan["positive_full_duplicate"], rng),
output_dir,
csv_path.parent,
rng,
start_index=len(samples) + 1,
)
)
)
_progress(f"built samples: {len(samples)}/{size}")
_progress("build negative_fragment samples")
samples.extend(
_build_fragment_samples(
_stratified_sample(groups["fragmentable"], plan["negative_fragment"], rng),
output_dir,
csv_path.parent,
rng,
start_index=len(samples) + 1,
_progress(f"built samples: {len(samples)}/{size}")
_progress("build negative_real_holdout_full_song samples")
samples.extend(
_build_holdout_full_song_samples(
holdout_profiles[: plan["negative_real_holdout_full_song"]],
output_dir,
csv_path.parent,
start_index=len(samples) + 1,
)
)
)
_progress(f"built samples: {len(samples)}/{size}")
_progress("build negative_shared_chorus samples")
samples.extend(
_build_shared_chorus_samples(
_stratified_sample(groups["normal"], plan["negative_shared_chorus"], rng),
output_dir,
csv_path.parent,
rng,
start_index=len(samples) + 1,
_progress(f"built samples: {len(samples)}/{size}")
_progress("build negative_fragment samples")
samples.extend(
_build_fragment_samples(
_stratified_sample(groups["fragmentable"], plan["negative_fragment"], rng),
output_dir,
csv_path.parent,
rng,
start_index=len(samples) + 1,
)
)
)
_progress(f"built samples: {len(samples)}/{size}")
_progress("build negative_translation_only samples")
samples.extend(
_build_translation_only_samples(
_stratified_sample(groups["foreign"], plan["negative_translation_only"], rng),
output_dir,
csv_path.parent,
rng,
start_index=len(samples) + 1,
_progress(f"built samples: {len(samples)}/{size}")
_progress("build negative_shared_chorus samples")
samples.extend(
_build_shared_chorus_samples(
_stratified_sample(groups["normal"], plan["negative_shared_chorus"], rng),
output_dir,
csv_path.parent,
rng,
start_index=len(samples) + 1,
)
)
)
_progress(f"built samples: {len(samples)}/{size}")
_progress("build negative_same_theme_synthetic samples")
samples.extend(
_build_same_theme_synthetic_samples(
plan["negative_same_theme_synthetic"],
output_dir,
csv_path.parent,
rng,
start_index=len(samples) + 1,
_progress(f"built samples: {len(samples)}/{size}")
_progress("build negative_translation_only samples")
samples.extend(
_build_translation_only_samples(
_stratified_sample(groups["foreign"], plan["negative_translation_only"], rng),
output_dir,
csv_path.parent,
rng,
start_index=len(samples) + 1,
)
)
)
_progress(f"built samples: {len(samples)}/{size}")
_progress("build edge_short_or_placeholder samples")
samples.extend(
_build_edge_samples(
_stratified_sample(groups["edge"], plan["edge_short_or_placeholder"], rng),
output_dir,
csv_path.parent,
rng,
start_index=len(samples) + 1,
_progress(f"built samples: {len(samples)}/{size}")
_progress("build negative_same_theme_synthetic samples")
samples.extend(
_build_same_theme_synthetic_samples(
plan["negative_same_theme_synthetic"],
output_dir,
csv_path.parent,
rng,
start_index=len(samples) + 1,
)
)
_progress(f"built samples: {len(samples)}/{size}")
_progress("build edge_short_or_placeholder samples")
samples.extend(
_build_edge_samples(
_stratified_sample(groups["edge"], plan["edge_short_or_placeholder"], rng),
output_dir,
csv_path.parent,
rng,
start_index=len(samples) + 1,
)
)
)
_progress(f"built samples: {len(samples)}/{size}")
if len(samples) < size:
......@@ -226,6 +256,7 @@ def generate_eval_set(
index_path=index_path,
eval_index_path=eval_index_path,
holdout_count=len(holdout_profiles),
profile=profile,
)
_progress("generation complete")
return manifest
......@@ -264,14 +295,16 @@ def profile_library(library_dir: Path) -> list[LyricProfile]:
return profiles
def _sample_plan(size: int, *, positive_ratio: float) -> dict[str, int]:
def _sample_plan(size: int, *, positive_ratio: float, profile: str) -> dict[str, int]:
positive_ratio = max(0.0, min(1.0, positive_ratio))
mix = dict(DEFAULT_SAMPLE_MIX)
negative_total = sum(value for key, value in mix.items() if key != "positive_full_duplicate")
mix["positive_full_duplicate"] = positive_ratio
mix = dict(HARD_SAMPLE_MIX if profile == "hard" else STANDARD_SAMPLE_MIX)
positive_key = "positive_realistic_variant" if profile == "hard" else "positive_full_duplicate"
negative_total = sum(value for key, value in mix.items() if key != positive_key)
mix[positive_key] = positive_ratio
for key in list(mix):
if key != "positive_full_duplicate":
mix[key] = (1.0 - positive_ratio) * (DEFAULT_SAMPLE_MIX[key] / negative_total)
if key != positive_key:
base_mix = HARD_SAMPLE_MIX if profile == "hard" else STANDARD_SAMPLE_MIX
mix[key] = (1.0 - positive_ratio) * (base_mix[key] / negative_total)
plan = {key: int(size * value) for key, value in mix.items()}
remainder = size - sum(plan.values())
......@@ -283,6 +316,10 @@ def _sample_plan(size: int, *, positive_ratio: float) -> dict[str, int]:
return plan
def _holdout_plan_count(plan: dict[str, int]) -> int:
return plan.get("negative_real_holdout_full_song", 0) + plan.get("negative_near_neighbor_holdout_full_song", 0)
def _profile_groups(profiles: list[LyricProfile]) -> dict[str, list[LyricProfile]]:
normal = [profile for profile in profiles if profile.line_count >= 6]
edge = [profile for profile in profiles if profile.line_count <= 5]
......@@ -375,6 +412,7 @@ def _build_positive_samples(
("positive_blank_line_noise", _add_blank_line_noise(lines)),
("positive_chorus_count_changed", _change_repeated_line_counts(lines)),
("positive_translation_added", _translation_added(lines)),
("positive_typo_noise", _add_typo_noise(lines, rng)),
]
sample_type, text = variants[offset % len(variants)]
index = start_index + offset
......@@ -384,31 +422,181 @@ def _build_positive_samples(
return samples
def _build_hard_samples(
plan: dict[str, int],
*,
groups: dict[str, list[LyricProfile]],
holdout_profiles: list[LyricProfile],
indexed_profiles: list[LyricProfile],
output_dir: Path,
csv_base: Path,
rng: random.Random,
start_index: int,
) -> list[GeneratedSample]:
samples: list[GeneratedSample] = []
_progress("build positive_realistic_variant samples")
samples.extend(
_build_realistic_positive_samples(
_stratified_sample(groups["normal"], plan["positive_realistic_variant"], rng),
output_dir,
csv_base,
rng,
start_index=start_index + len(samples),
)
)
_progress(f"built samples: {len(samples)}")
real_holdout_count = plan.get("negative_real_holdout_full_song", 0)
_progress("build negative_real_holdout_full_song samples")
samples.extend(
_build_holdout_full_song_samples(
holdout_profiles[:real_holdout_count],
output_dir,
csv_base,
start_index=start_index + len(samples),
)
)
_progress(f"built samples: {len(samples)}")
near_count = plan.get("negative_near_neighbor_holdout_full_song", 0)
_progress("build negative_near_neighbor_holdout_full_song samples")
near_holdouts = _near_neighbor_holdouts(
holdout_profiles[real_holdout_count:],
indexed_profiles,
near_count,
)
samples.extend(
_build_holdout_full_song_samples(
near_holdouts,
output_dir,
csv_base,
start_index=start_index + len(samples),
sample_type="negative_near_neighbor_holdout_full_song",
notes="full real holdout lyric selected for catalog line overlap with indexed songs",
)
)
_progress(f"built samples: {len(samples)}")
_progress("build negative_long_fragment samples")
samples.extend(
_build_fragment_samples(
_stratified_sample(groups["fragmentable"], plan.get("negative_long_fragment", 0), rng),
output_dir,
csv_base,
rng,
start_index=start_index + len(samples),
sample_type="negative_long_fragment",
long_fragment=True,
notes="realistic long partial lyric upload, not a full-song duplicate",
)
)
_progress(f"built samples: {len(samples)}")
_progress("build negative_shared_chorus samples")
samples.extend(
_build_shared_chorus_samples(
_stratified_sample(groups["normal"], plan.get("negative_shared_chorus", 0), rng),
output_dir,
csv_base,
rng,
start_index=start_index + len(samples),
)
)
_progress(f"built samples: {len(samples)}")
_progress("build negative_translation_only samples")
samples.extend(
_build_translation_only_samples(
_stratified_sample(groups["foreign"], plan.get("negative_translation_only", 0), rng),
output_dir,
csv_base,
rng,
start_index=start_index + len(samples),
)
)
_progress(f"built samples: {len(samples)}")
_progress("build negative_catalog_mashup samples")
samples.extend(
_build_catalog_mashup_samples(
_stratified_sample(groups["normal"], plan.get("negative_catalog_mashup", 0) * 3, rng),
plan.get("negative_catalog_mashup", 0),
output_dir,
csv_base,
rng,
start_index=start_index + len(samples),
)
)
_progress(f"built samples: {len(samples)}")
_progress("build edge_short_or_placeholder samples")
samples.extend(
_build_edge_samples(
_stratified_sample(groups["edge"], plan.get("edge_short_or_placeholder", 0), rng),
output_dir,
csv_base,
rng,
start_index=start_index + len(samples),
)
)
return samples
def _build_realistic_positive_samples(
profiles: list[LyricProfile],
output_dir: Path,
csv_base: Path,
rng: random.Random,
*,
start_index: int,
) -> list[GeneratedSample]:
samples: list[GeneratedSample] = []
for offset, profile in enumerate(profiles):
content_lines = _content_lines(profile.raw_text)
primary_lines = list(profile.normalized.primary_lines or profile.normalized.unique_lines) or content_lines
variants = [
("positive_platform_mixed_noise", _platform_mixed_noise(content_lines, rng)),
("positive_near_full_missing_section", _near_full_missing_section(primary_lines, rng)),
("positive_block_translation_added", _block_translation_added(primary_lines)),
("positive_typo_and_punctuation_noise", _stronger_typo_and_punctuation_noise(content_lines, rng)),
("positive_timestamped_platform_variant", _timestamped_platform_variant(content_lines)),
("positive_chorus_count_changed", _change_repeated_line_counts(content_lines)),
]
sample_type, text = variants[offset % len(variants)]
index = start_index + offset
path = _write_sample_file(output_dir, f"pos_{index:05d}_{sample_type}.txt", text)
samples.append(_sample_from_profile(index, path, csv_base, "应去重", sample_type, profile))
_progress_count("positive_realistic_variant", len(samples), len(profiles))
return samples
def _build_holdout_full_song_samples(
profiles: list[LyricProfile],
output_dir: Path,
csv_base: Path,
*,
start_index: int,
sample_type: str = "negative_real_holdout_full_song",
notes: str = "full real lyric held out from the generated eval index",
) -> list[GeneratedSample]:
_progress("build negative_real_holdout_full_song samples")
samples: list[GeneratedSample] = []
for offset, profile in enumerate(profiles):
index = start_index + offset
text = profile.raw_text
path = _write_sample_file(output_dir, f"neg_{index:05d}_negative_real_holdout_full_song.txt", text)
path = _write_sample_file(output_dir, f"neg_{index:05d}_{sample_type}.txt", text)
samples.append(
_sample_from_profile(
index,
path,
csv_base,
"不应去重",
"negative_real_holdout_full_song",
sample_type,
profile,
notes="full real lyric held out from the generated eval index",
notes=notes,
)
)
_progress_count("negative_real_holdout_full_song", len(samples), len(profiles))
_progress_count(sample_type, len(samples), len(profiles))
return samples
......@@ -446,25 +634,59 @@ def _build_fragment_samples(
rng: random.Random,
*,
start_index: int,
sample_type: str = "negative_fragment",
long_fragment: bool = False,
notes: str = "partial lyric fragment only",
) -> list[GeneratedSample]:
samples: list[GeneratedSample] = []
for offset, profile in enumerate(profiles):
lines = list(profile.normalized.primary_lines or profile.normalized.unique_lines)
text = _single_song_fragment(lines, rng)
text = _long_song_fragment(lines, rng) if long_fragment else _single_song_fragment(lines, rng)
index = start_index + offset
path = _write_sample_file(output_dir, f"neg_{index:05d}_negative_fragment.txt", text)
path = _write_sample_file(output_dir, f"neg_{index:05d}_{sample_type}.txt", text)
samples.append(
_sample_from_profile(
index,
path,
csv_base,
"不应去重",
"negative_fragment",
sample_type,
profile,
notes="partial lyric fragment only",
notes=notes,
)
)
_progress_count(sample_type, len(samples), len(profiles))
return samples
def _build_catalog_mashup_samples(
profiles: list[LyricProfile],
count: int,
output_dir: Path,
csv_base: Path,
rng: random.Random,
*,
start_index: int,
) -> list[GeneratedSample]:
samples: list[GeneratedSample] = []
if count <= 0 or not profiles:
return samples
for offset in range(count):
index = start_index + offset
picked = rng.sample(profiles, k=min(3, len(profiles)))
text = _catalog_mashup_text(picked, rng)
path = _write_sample_file(output_dir, f"neg_{index:05d}_negative_catalog_mashup.txt", text)
samples.append(
GeneratedSample(
sample_id=f"sample-{index:05d}",
file=str(path.relative_to(csv_base)),
expected="不应去重",
sample_type="negative_catalog_mashup",
source=" | ".join(str(profile.path) for profile in picked),
notes="medley-style partial lyric assembled from multiple catalog songs",
)
)
_progress_count("negative_fragment", len(samples), len(profiles))
_progress_count("negative_catalog_mashup", len(samples), count)
return samples
......@@ -658,8 +880,10 @@ def _write_manifest(
index_path: Path | None,
eval_index_path: Path,
holdout_count: int,
profile: str,
) -> dict[str, object]:
manifest = {
"profile": profile,
"seed": seed,
"library_files": len(profiles),
"sample_size": len(samples),
......@@ -684,6 +908,38 @@ def _write_manifest(
return manifest
def _near_neighbor_holdouts(
holdout_profiles: list[LyricProfile],
indexed_profiles: list[LyricProfile],
count: int,
) -> list[LyricProfile]:
if count <= 0 or not holdout_profiles:
return []
if not indexed_profiles:
return holdout_profiles[:count]
line_to_indexed_count: Counter[str] = Counter()
for profile in indexed_profiles:
for line in set(profile.normalized.primary_lines or profile.normalized.unique_lines):
if len(line) >= 4:
line_to_indexed_count[line] += 1
scored: list[tuple[float, LyricProfile]] = []
for profile in holdout_profiles:
lines = set(profile.normalized.primary_lines or profile.normalized.unique_lines)
useful_lines = {line for line in lines if len(line) >= 4}
if not useful_lines:
score = 0.0
else:
shared = sum(1 for line in useful_lines if line_to_indexed_count[line] > 0)
common_weight = sum(min(line_to_indexed_count[line], 5) for line in useful_lines)
score = (shared / len(useful_lines)) + (common_weight / (len(useful_lines) * 20))
scored.append((score, profile))
scored.sort(key=lambda item: item[0], reverse=True)
return [profile for _, profile in scored[:count]]
def _content_lines(text: str) -> list[str]:
lines = [line.strip() for line in text.splitlines() if line.strip()]
return lines or [text.strip()]
......@@ -735,6 +991,18 @@ def _add_timestamps(lines: list[str]) -> str:
return "\n".join(f"[00:{idx % 60:02d}.00]{line}" for idx, line in enumerate(lines, start=1))
def _platform_mixed_noise(lines: list[str], rng: random.Random) -> str:
noisy = _add_blank_line_noise(lines).splitlines()
if noisy:
noisy = _add_punctuation_noise(noisy, rng).splitlines()
return "\n".join(["作词:未知", "歌词来自平台同步", *noisy, "未经著作权人许可 不得商业使用"])
def _timestamped_platform_variant(lines: list[str]) -> str:
timestamped = _add_timestamps(lines).splitlines()
return "\n".join(["[00:00.00]歌词贡献者:用户上传", *timestamped])
def _add_punctuation_noise(lines: list[str], rng: random.Random) -> str:
marks = ["!", "?", "...", ",", "。"]
return "\n".join(f"{line}{rng.choice(marks)}" for line in lines)
......@@ -773,6 +1041,97 @@ def _translation_added(lines: list[str]) -> str:
return "\n".join(result)
def _block_translation_added(lines: list[str]) -> str:
body = "\n".join(lines)
translation_count = min(8, max(4, len(lines) // 4))
translations = [_pseudo_translation(index) for index in range(1, translation_count + 1)]
return "\n".join([body, "", *translations])
def _near_full_missing_section(lines: list[str], rng: random.Random) -> str:
if len(lines) <= 8:
return "\n".join(lines)
drop_count = max(1, min(max(1, len(lines) // 5), 8))
start = rng.randrange(0, max(1, len(lines) - drop_count + 1))
kept = lines[:start] + lines[start + drop_count :]
return "\n".join(kept or lines)
def _add_typo_noise(lines: list[str], rng: random.Random) -> str:
if not lines:
return ""
result = list(lines)
editable_indexes = [index for index, line in enumerate(result) if _can_typo_line(line)]
if not editable_indexes:
return "\n".join(result)
typo_count = max(1, min(4, len(editable_indexes) // 8 or 1))
for index in rng.sample(editable_indexes, k=min(typo_count, len(editable_indexes))):
result[index] = _typo_line(result[index], rng)
return "\n".join(result)
def _stronger_typo_and_punctuation_noise(lines: list[str], rng: random.Random) -> str:
if not lines:
return ""
result = _add_punctuation_noise(lines, rng).splitlines()
editable_indexes = [index for index, line in enumerate(result) if _can_typo_line(line)]
typo_count = max(1, min(8, len(editable_indexes) // 6 or 1))
for index in rng.sample(editable_indexes, k=min(typo_count, len(editable_indexes))):
result[index] = _typo_line(result[index], rng)
return "\n".join(result)
def _can_typo_line(line: str) -> bool:
return bool(re.search(r"[A-Za-z]{4,}|[\u4e00-\u9fff]{4,}", line))
def _typo_line(line: str, rng: random.Random) -> str:
words = list(re.finditer(r"[A-Za-z]{4,}", line))
if words and rng.random() < 0.65:
match = rng.choice(words)
typo = _typo_english_word(match.group(0), rng)
return line[: match.start()] + typo + line[match.end() :]
cjk_positions = [index for index, char in enumerate(line) if "\u4e00" <= char <= "\u9fff"]
if cjk_positions:
index = rng.choice(cjk_positions)
return line[:index] + _typo_cjk_char(line[index]) + line[index + 1 :]
return line
def _typo_english_word(word: str, rng: random.Random) -> str:
    if len(word) <= 4 or rng.random() < 0.55:
        # Deletion typo: drop one interior character (first and last are kept).
        remove_at = rng.randrange(1, max(2, len(word) - 1))
        return word[:remove_at] + word[remove_at + 1 :]
    # Transposition typo: swap two adjacent interior characters.
    swap_at = rng.randrange(1, max(2, len(word) - 2))
    chars = list(word)
    chars[swap_at], chars[swap_at + 1] = chars[swap_at + 1], chars[swap_at]
    return "".join(chars)
def _typo_cjk_char(char: str) -> str:
replacements = {
"你": "妳",
"爱": "爰",
"夜": "液",
"里": "裏",
"风": "凤",
"雨": "兩",
"听": "昕",
"说": "説",
"想": "相",
"梦": "夣",
"心": "芯",
"光": "先",
"城": "诚",
"远": "迩",
"回": "囬",
"走": "赱",
"海": "毎",
"天": "夭",
}
return replacements.get(char, char)
def _single_song_fragment(lines: list[str], rng: random.Random) -> str:
if len(lines) <= 4:
return "\n".join(lines[: max(1, len(lines) // 2)])
......@@ -781,6 +1140,28 @@ def _single_song_fragment(lines: list[str], rng: random.Random) -> str:
return "\n".join(lines[start : start + fragment_len])
def _long_song_fragment(lines: list[str], rng: random.Random) -> str:
if len(lines) <= 8:
return _single_song_fragment(lines, rng)
fragment_len = max(6, min(len(lines) - 1, int(len(lines) * rng.uniform(0.35, 0.60))))
start = rng.randrange(0, max(1, len(lines) - fragment_len + 1))
return "\n".join(lines[start : start + fragment_len])
def _catalog_mashup_text(profiles: list[LyricProfile], rng: random.Random) -> str:
sections: list[str] = []
for profile in profiles:
lines = list(profile.normalized.primary_lines or profile.normalized.unique_lines)
if not lines:
continue
section_len = min(max(2, len(lines) // 8), 5)
start = rng.randrange(0, max(1, len(lines) - section_len + 1))
sections.extend(lines[start : start + section_len])
if not sections:
return _same_theme_synthetic(0, rng)
return "\n".join(sections)
def _short_shared_snippet(lines: list[str], rng: random.Random) -> str:
snippet = rng.sample(lines, k=min(2, len(lines))) if lines else []
synthetic = [
......
# Test runner
pytest>=8.0
# PostgreSQL storage prototype
psycopg[binary]>=3.2
# Existing MySQL/COS lyric download utilities
pymysql>=1.1
cos-python-sdk-v5>=1.9
tqdm>=4.66
"""Evaluate lyric duplicate checking with PostgreSQL-backed candidate recall."""
from __future__ import annotations
import argparse
import csv
import hashlib
import json
import sys
import time
from pathlib import Path
from typing import Any
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from lyric_dedup.checker import DuplicateChecker
from lyric_dedup.checker import LyricRecord
from lyric_dedup.file_import import read_lyric_file
from lyric_dedup.file_import import record_from_file
from lyric_dedup.normalization import fingerprint_text
from lyric_dedup.normalization import normalize_lyrics
def main() -> None:
parser = argparse.ArgumentParser(description="Evaluate duplicate checking using PostgreSQL recall.")
parser.add_argument("--dsn", required=True)
parser.add_argument("--csv", required=True)
parser.add_argument("--out", required=True)
parser.add_argument("--base-dir", default="")
parser.add_argument("--positive-decisions", default="duplicate")
parser.add_argument("--max-candidates", type=int, default=5)
parser.add_argument("--recall-limit", type=int, default=100)
parser.add_argument("--enable-trgm", action="store_true", help="Enable pg_trgm full-text recall. Slower; exact + line recall is used by default.")
parser.add_argument("--trgm-threshold", type=float, default=0.3)
parser.add_argument("--statement-timeout-ms", type=int, default=5000)
parser.add_argument("--profile-every", type=int, default=100)
args = parser.parse_args()
psycopg = _import_psycopg()
csv_path = Path(args.csv)
out_path = Path(args.out)
base_dir = Path(args.base_dir) if args.base_dir else None
positive_decisions = {item.strip() for item in args.positive_decisions.split(",") if item.strip()}
total = _csv_data_row_count(csv_path)
rows: list[dict[str, object]] = []
profile_stats = _new_profile_stats()
out_path.parent.mkdir(parents=True, exist_ok=True)
_progress(f"evaluate postgres csv: 0/{total}")
with psycopg.connect(args.dsn) as conn:
with conn.cursor() as cursor:
cursor.execute("select set_config('statement_timeout', %s, false)", (str(args.statement_timeout_ms),))
cursor.execute("select set_config('pg_trgm.similarity_threshold', %s, false)", (str(args.trgm_threshold),))
with csv_path.open(encoding="utf-8-sig", newline="") as in_file, out_path.open(
"w", encoding="utf-8", newline=""
) as out_file:
reader = csv.DictReader(in_file)
if reader.fieldnames is None:
raise ValueError("The evaluation CSV needs a header row")
writer = csv.DictWriter(out_file, fieldnames=_fieldnames())
writer.writeheader()
for index, row in enumerate(reader, start=1):
row_out = _evaluate_row(
conn,
row,
row_number=index + 1,
csv_path=csv_path,
base_dir=base_dir,
positive_decisions=positive_decisions,
max_candidates=args.max_candidates,
recall_limit=args.recall_limit,
enable_trgm=args.enable_trgm,
)
rows.append(row_out)
writer.writerow(row_out)
_progress_count("evaluate postgres csv", index, total, step=10)
_update_profile_stats(profile_stats, row_out)
if args.profile_every > 0 and index % args.profile_every == 0:
_progress(_format_profile_stats(profile_stats, index))
summary = _evaluation_summary(rows, positive_decisions=positive_decisions, out_path=out_path)
summary_path = out_path.with_suffix(out_path.suffix + ".summary.json")
summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
_progress("postgres evaluation complete")
print(json.dumps(summary, ensure_ascii=False))
def _evaluate_row(
conn: Any,
row: dict[str, str],
*,
row_number: int,
csv_path: Path,
base_dir: Path | None,
positive_decisions: set[str],
max_candidates: int,
recall_limit: int,
enable_trgm: bool,
) -> dict[str, object]:
parse_started = time.perf_counter()
sample_id = row.get("id") or row.get("sample_id") or str(row_number)
record, source = _record_from_eval_row(row, csv_path=csv_path, base_dir=base_dir)
expected_duplicate = _parse_expected(row.get("expected") or row.get("label") or row.get("target"))
parse_ms = round((time.perf_counter() - parse_started) * 1000, 2)
candidates, timings = _recall_candidates(
conn,
record,
recall_limit=recall_limit,
enable_trgm=enable_trgm,
exclude_record_ids=_exclude_record_ids_for_eval_row(row),
)
rank_started = time.perf_counter()
result = _check_against_candidates(record, candidates, max_candidates=max_candidates)
rank_ms = round((time.perf_counter() - rank_started) * 1000, 2)
recall_ms = round(timings["exact_ms"] + timings["trgm_ms"] + timings["line_ms"], 2)
predicted_duplicate = result.decision.value in positive_decisions
best = result.candidates[0] if result.candidates else None
return {
"id": sample_id,
"source": source,
"expected_duplicate": expected_duplicate,
"decision": result.decision.value,
"predicted_duplicate": predicted_duplicate,
"correct": expected_duplicate == predicted_duplicate,
"confidence": result.confidence,
"reason": result.reason,
"candidate_count": len(candidates),
"parse_ms": parse_ms,
"recall_ms": recall_ms,
"exact_ms": timings["exact_ms"],
"trgm_ms": timings["trgm_ms"],
"line_ms": timings["line_ms"],
"rank_ms": rank_ms,
"best_candidate_id": best.record_id if best else "",
"best_candidate_decision": best.decision.value if best else "",
"best_candidate_confidence": best.confidence if best else "",
"best_candidate_jaccard": best.jaccard if best else "",
"best_candidate_line_coverage": best.line_coverage if best else "",
"best_candidate_primary_jaccard": best.primary_jaccard if best else "",
"best_candidate_primary_line_coverage": best.primary_line_coverage if best else "",
"best_candidate_translation_jaccard": best.translation_jaccard if best else "",
"best_candidate_translation_line_coverage": best.translation_line_coverage if best else "",
"best_candidate_reason": best.reason if best else "",
"matched_unique_lines": " | ".join(best.matched_unique_lines) if best else "",
}
def _recall_candidates(
    conn: Any,
    record: LyricRecord,
    *,
    recall_limit: int,
    enable_trgm: bool,
    exclude_record_ids: list[str],
) -> tuple[list[LyricRecord], dict[str, float]]:
    """Collect candidates from exact-hash, optional pg_trgm, and line-hash recall."""
    query_lyrics = _pg_text(record.lyrics) or ""
    normalized = normalize_lyrics(query_lyrics)
    exact_text = fingerprint_text(normalized)
    exact_hash = hashlib.sha256(exact_text.encode("utf-8")).hexdigest()
    primary_text = "\n".join(normalized.primary_lines)
    line_hashes = [hashlib.sha256(line.encode("utf-8")).hexdigest() for line in normalized.primary_lines if line]
    # Keyed by record_id so a record found by several layers is kept only once.
    candidates: dict[str, LyricRecord] = {}
    timings = {"exact_ms": 0.0, "trgm_ms": 0.0, "line_ms": 0.0}
    with conn.cursor() as cursor:
        # Layer 1: exact match on the hash of the normalized fingerprint.
        started = time.perf_counter()
        cursor.execute(
            """
            select record_id, raw_text, title, artist
            from lyrics
            where deleted_at is null
              and exact_hash = %s
              and not (record_id = any(%s))
            limit %s
            """,
            (exact_hash, exclude_record_ids, recall_limit),
        )
        _add_rows(candidates, cursor.fetchall())
        timings["exact_ms"] = round((time.perf_counter() - started) * 1000, 2)

        # Layer 2 (opt-in): pg_trgm similarity on the primary text.
        # "%%" is the psycopg escape for the literal pg_trgm "%" operator.
        if enable_trgm and primary_text:
            started = time.perf_counter()
            cursor.execute(
                """
                select record_id, raw_text, title, artist
                from lyrics
                where deleted_at is null
                  and not (record_id = any(%s))
                  and primary_text %% %s
                order by similarity(primary_text, %s) desc
                limit %s
                """,
                (exclude_record_ids, primary_text, primary_text, recall_limit),
            )
            _add_rows(candidates, cursor.fetchall())
            timings["trgm_ms"] = round((time.perf_counter() - started) * 1000, 2)

        # Layer 3: line-level recall, ranked by the number of matching line rows.
        if line_hashes:
            started = time.perf_counter()
            cursor.execute(
                """
                select l.record_id, l.raw_text, l.title, l.artist
                from lyric_lines ll
                join lyrics l on l.id = ll.lyric_id
                where l.deleted_at is null
                  and not (l.record_id = any(%s))
                  and ll.role = 'primary'
                  and ll.line_hash = any(%s)
                group by l.id
                order by count(*) desc
                limit %s
                """,
                (exclude_record_ids, line_hashes, recall_limit),
            )
            _add_rows(candidates, cursor.fetchall())
            timings["line_ms"] = round((time.perf_counter() - started) * 1000, 2)
    return list(candidates.values()), timings
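To make the difference between the `exact_hash` key and the `line_hash` keys concrete, here is a minimal sketch. `toy_normalize` is a hypothetical stand-in for `normalize_lyrics`, and joining the lines with newlines is only an assumed fingerprint for illustration; the real `fingerprint_text` may differ.

```python
import hashlib


def sha256_hex(text: str) -> str:
    """Hex SHA-256 of UTF-8 text, the same hashing call used for exact_hash and line_hash."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def toy_normalize(lyrics: str) -> list[str]:
    # Hypothetical stand-in for normalize_lyrics: trim, lowercase, drop blank lines.
    return [line.strip().lower() for line in lyrics.splitlines() if line.strip()]


library_lines = toy_normalize("Under the moonlight\nNever let me go\n")
query_lines = toy_normalize("  under the moonlight \n\nA different second line\n")

# Whole-text hash: any changed line breaks the match.
exact_match = sha256_hex("\n".join(library_lines)) == sha256_hex("\n".join(query_lines))

# Per-line hashes: unchanged lines still collide, so the song is recalled as a candidate.
shared_line_hashes = {sha256_hex(line) for line in library_lines} & {
    sha256_hex(line) for line in query_lines
}
```

In this toy case the whole-text hashes differ while one line hash is shared, which is why line-level recall can still surface a song that was edited, truncated or extended.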
def _exclude_record_ids_for_eval_row(row: dict[str, str]) -> list[str]:
if row.get("sample_type") == "negative_real_holdout_full_song" and row.get("source_record_id"):
return [row["source_record_id"]]
return []
def _add_rows(candidates: dict[str, LyricRecord], rows: list[tuple[object, ...]]) -> None:
for record_id, raw_text, title, artist in rows:
candidates.setdefault(
str(record_id),
LyricRecord(
record_id=str(record_id),
lyrics=str(raw_text),
title=str(title) if title is not None else None,
artist=str(artist) if artist is not None else None,
),
)
def _check_against_candidates(
record: LyricRecord,
candidates: list[LyricRecord],
*,
max_candidates: int,
):
checker = DuplicateChecker()
for candidate in candidates:
checker.add_record(candidate)
return checker.check_record(record, max_candidates=max_candidates)
def _record_from_eval_row(row: dict[str, str], *, csv_path: Path, base_dir: Path | None) -> tuple[LyricRecord, str]:
lyrics = (row.get("lyrics") or "").strip()
if lyrics:
return (
LyricRecord(
record_id=row.get("id") or row.get("sample_id") or "__eval__",
lyrics=_pg_text(lyrics.replace("\\n", "\n")) or "",
title=_pg_text(row.get("title") or None),
artist=_pg_text(row.get("artist") or None),
),
"inline",
)
file_value = (row.get("file") or row.get("path") or row.get("source") or "").strip()
if not file_value:
raise ValueError("Each evaluation CSV row needs lyrics, or a file path in file/path/source")
file_path = Path(file_value)
if not file_path.is_absolute():
file_path = (base_dir or csv_path.parent) / file_path
record = record_from_file(file_path)
record = LyricRecord(
record_id=record.record_id,
lyrics=_pg_text(record.lyrics) or "",
title=_pg_text(record.title),
artist=_pg_text(record.artist),
)
if row.get("title") or row.get("artist"):
record = LyricRecord(
record_id=record.record_id,
lyrics=record.lyrics,
title=_pg_text(row.get("title") or record.title),
artist=_pg_text(row.get("artist") or record.artist),
)
return record, str(file_path)
def _parse_expected(value: str | None) -> bool:
    """Map an expected/label/target cell to True (duplicate) or False (new)."""
    if value is None:
        raise ValueError("Each evaluation CSV row needs an expected/label/target column")
    normalized = value.strip().lower()
    # The Chinese labels are accepted input values and must stay as they are.
    positives = {"1", "true", "yes", "y", "duplicate", "dup", "重复", "应去重", "去重", "是"}
    negatives = {"0", "false", "no", "n", "new", "not_duplicate", "non_duplicate", "不重复", "不应去重", "新歌", "否"}
    if normalized in positives:
        return True
    if normalized in negatives:
        return False
    raise ValueError(f"Unrecognized expected value: {value!r}")
def _evaluation_summary(
rows: list[dict[str, object]],
*,
positive_decisions: set[str],
out_path: Path,
) -> dict[str, object]:
tp = sum(1 for row in rows if row["expected_duplicate"] is True and row["predicted_duplicate"] is True)
fp = sum(1 for row in rows if row["expected_duplicate"] is False and row["predicted_duplicate"] is True)
tn = sum(1 for row in rows if row["expected_duplicate"] is False and row["predicted_duplicate"] is False)
fn = sum(1 for row in rows if row["expected_duplicate"] is True and row["predicted_duplicate"] is False)
total = len(rows)
precision = tp / (tp + fp) if tp + fp else 0.0
recall = tp / (tp + fn) if tp + fn else 0.0
accuracy = (tp + tn) / total if total else 0.0
f1 = (2 * precision * recall / (precision + recall)) if precision + recall else 0.0
return {
"total": total,
"positive_decisions": sorted(positive_decisions),
"accuracy": round(accuracy, 4),
"precision": round(precision, 4),
"recall": round(recall, 4),
"f1": round(f1, 4),
"true_positive": tp,
"false_positive": fp,
"true_negative": tn,
"false_negative": fn,
"duplicate": sum(1 for row in rows if row["decision"] == "duplicate"),
"review": sum(1 for row in rows if row["decision"] == "review"),
"new": sum(1 for row in rows if row["decision"] == "new"),
"out": str(out_path),
"summary": str(out_path.with_suffix(out_path.suffix + ".summary.json")),
}
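The summary metrics above follow the usual confusion-matrix definitions. A small worked example with made-up counts (the helper name `confusion_metrics` is hypothetical) shows how the four counters turn into the reported values:

```python
def confusion_metrics(tp: int, fp: int, tn: int, fn: int) -> dict[str, float]:
    """Illustrative helper: derive the summary metrics from confusion-matrix counts."""
    total = tp + fp + tn + fn
    # Each ratio falls back to 0.0 when its denominator is empty.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    accuracy = (tp + tn) / total if total else 0.0
    f1 = (2 * precision * recall / (precision + recall)) if precision + recall else 0.0
    return {
        "accuracy": round(accuracy, 4),
        "precision": round(precision, 4),
        "recall": round(recall, 4),
        "f1": round(f1, 4),
    }


# Made-up counts: 20 evaluation rows, 9 of which are true duplicates.
metrics = confusion_metrics(tp=8, fp=2, tn=9, fn=1)
```

Which decisions count as a positive prediction depends on `--positive-decisions`. With the default `duplicate`, a `review` decision on a true duplicate is counted as a false negative; passing `duplicate,review` would count it as a true positive.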
def _fieldnames() -> list[str]:
return [
"id",
"source",
"expected_duplicate",
"decision",
"predicted_duplicate",
"correct",
"confidence",
"reason",
"candidate_count",
"parse_ms",
"recall_ms",
"exact_ms",
"trgm_ms",
"line_ms",
"rank_ms",
"best_candidate_id",
"best_candidate_decision",
"best_candidate_confidence",
"best_candidate_jaccard",
"best_candidate_line_coverage",
"best_candidate_primary_jaccard",
"best_candidate_primary_line_coverage",
"best_candidate_translation_jaccard",
"best_candidate_translation_line_coverage",
"best_candidate_reason",
"matched_unique_lines",
]
def _csv_data_row_count(csv_path: Path) -> int:
with csv_path.open(encoding="utf-8-sig", newline="") as file:
reader = csv.reader(file)
next(reader, None)
return sum(1 for _ in reader)
def _progress(message: str) -> None:
print(f"[pg-eval] {message}", file=sys.stderr, flush=True)
def _progress_count(label: str, current: int, total: int, *, step: int = 1000) -> None:
if total <= 0:
return
if current == 1 or current == total or current % step == 0:
_progress(f"{label}: {current}/{total}")
def _new_profile_stats() -> dict[str, float]:
return {
"parse_ms": 0.0,
"exact_ms": 0.0,
"trgm_ms": 0.0,
"line_ms": 0.0,
"rank_ms": 0.0,
"recall_ms": 0.0,
"candidate_count": 0.0,
}
def _update_profile_stats(stats: dict[str, float], row: dict[str, object]) -> None:
for key in stats:
try:
stats[key] += float(row.get(key) or 0)
except (TypeError, ValueError):
pass
def _format_profile_stats(stats: dict[str, float], count: int) -> str:
if count <= 0:
return "profile: no rows"
return (
"profile avg "
f"parse={stats['parse_ms'] / count:.2f}ms "
f"exact={stats['exact_ms'] / count:.2f}ms "
f"line={stats['line_ms'] / count:.2f}ms "
f"trgm={stats['trgm_ms'] / count:.2f}ms "
f"rank={stats['rank_ms'] / count:.2f}ms "
f"recall={stats['recall_ms'] / count:.2f}ms "
f"candidates={stats['candidate_count'] / count:.1f}"
)
def _pg_text(value: str | None) -> str | None:
if value is None:
return None
return value.replace("\x00", "")
def _import_psycopg():
try:
import psycopg
return psycopg
except ModuleNotFoundError:
print(
"Missing dependency: psycopg. Install it with:\n"
" python -m pip install 'psycopg[binary]'",
file=sys.stderr,
)
raise SystemExit(1)
if __name__ == "__main__":
main()
"""Import normalized lyric library records into PostgreSQL."""
from __future__ import annotations
import argparse
import csv
import hashlib
import sys
from pathlib import Path
from typing import Any
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from lyric_dedup.file_import import iter_lyric_files
from lyric_dedup.file_import import record_from_file
from lyric_dedup.normalization import fingerprint_text
from lyric_dedup.normalization import normalize_lyrics
def main() -> None:
parser = argparse.ArgumentParser(description="Import lyric library into PostgreSQL.")
parser.add_argument("--dsn", required=True)
parser.add_argument("--lyrics-dir", required=True)
parser.add_argument("--batch-size", type=int, default=500)
parser.add_argument("--limit", type=int, default=0)
parser.add_argument("--skip-dedup-exact", action="store_true", help="Skip exact-hash duplicate soft deletion after import.")
parser.add_argument("--duplicate-report", default="outputs/results/postgres_exact_duplicates.csv")
parser.add_argument("--line-duplicate-report", default="", help="Optional CSV report for high line-coverage duplicate candidates.")
parser.add_argument("--line-coverage-threshold", type=float, default=0.95)
parser.add_argument("--line-duplicate-limit", type=int, default=10000)
args = parser.parse_args()
psycopg = _import_psycopg()
lyrics_dir = Path(args.lyrics_dir)
paths = iter_lyric_files(lyrics_dir)
if args.limit > 0:
paths = paths[: args.limit]
print(f"[pg-import] files: {len(paths)}", file=sys.stderr, flush=True)
imported = 0
exact_deleted = 0
line_reported = 0
nul_cleaned = 0
with psycopg.connect(args.dsn) as conn:
for start in range(0, len(paths), args.batch_size):
batch = paths[start : start + args.batch_size]
with conn.transaction():
with conn.cursor() as cursor:
for path in batch:
lyric_id, line_rows, cleaned = _upsert_lyric(cursor, path, lyrics_dir)
nul_cleaned += cleaned
cursor.execute("delete from lyric_lines where lyric_id = %s", (lyric_id,))
if line_rows:
cursor.executemany(
"""
insert into lyric_lines
(lyric_id, role, line_no, normalized_line, line_hash)
values (%s, %s, %s, %s, %s)
""",
line_rows,
)
imported += 1
_progress("import", imported, len(paths), step=args.batch_size)
if not args.skip_dedup_exact:
exact_deleted = _soft_delete_exact_duplicates(conn, Path(args.duplicate_report))
if args.line_duplicate_report:
line_reported = _write_line_duplicate_report(
conn,
Path(args.line_duplicate_report),
threshold=args.line_coverage_threshold,
limit=args.line_duplicate_limit,
)
print(
{
"imported": imported,
"records_with_nul_cleaned": nul_cleaned,
"exact_duplicates_soft_deleted": exact_deleted,
"line_duplicate_candidates_reported": line_reported,
}
)
def _upsert_lyric(cursor: Any, path: Path, lyrics_dir: Path) -> tuple[int, list[tuple[object, ...]], int]:
record = record_from_file(path, base_dir=lyrics_dir)
raw_text, raw_cleaned = _pg_text(record.lyrics)
normalized = normalize_lyrics(raw_text)
primary_text = _pg_text("\n".join(normalized.primary_lines))[0]
translation_text = _pg_text("\n".join(normalized.translation_lines))[0] or None
normalized_text = _pg_text(normalized.normalized_full_text)[0]
exact_text = fingerprint_text(normalized)
exact_hash = hashlib.sha256(exact_text.encode("utf-8")).hexdigest()
cursor.execute(
"""
insert into lyrics (
record_id, source_path, title, artist, raw_text, normalized_text,
primary_text, translation_text, exact_hash, split_confidence,
split_reason, line_count, updated_at, deleted_at
)
values (
%(record_id)s, %(source_path)s, %(title)s, %(artist)s, %(raw_text)s,
%(normalized_text)s, %(primary_text)s, %(translation_text)s,
%(exact_hash)s, %(split_confidence)s, %(split_reason)s,
%(line_count)s, now(), null
)
on conflict (record_id) do update set
source_path = excluded.source_path,
title = excluded.title,
artist = excluded.artist,
raw_text = excluded.raw_text,
normalized_text = excluded.normalized_text,
primary_text = excluded.primary_text,
translation_text = excluded.translation_text,
exact_hash = excluded.exact_hash,
split_confidence = excluded.split_confidence,
split_reason = excluded.split_reason,
line_count = excluded.line_count,
updated_at = now(),
deleted_at = null
returning id
""",
{
"record_id": record.record_id,
"source_path": str(path),
"title": _pg_text(record.title)[0],
"artist": _pg_text(record.artist)[0],
"raw_text": raw_text,
"normalized_text": normalized_text,
"primary_text": primary_text,
"translation_text": translation_text,
"exact_hash": exact_hash,
"split_confidence": _pg_text(normalized.split_confidence)[0],
"split_reason": _pg_text(normalized.split_reason)[0],
"line_count": len(normalized.primary_lines or normalized.unique_lines),
},
)
lyric_id = cursor.fetchone()[0]
line_rows: list[tuple[object, ...]] = []
line_rows.extend(_line_rows(lyric_id, "primary", normalized.primary_lines))
line_rows.extend(_line_rows(lyric_id, "translation", normalized.translation_lines))
line_rows.extend(_line_rows(lyric_id, "unknown", normalized.unknown_lines))
return lyric_id, line_rows, int(raw_cleaned)
def _line_rows(lyric_id: int, role: str, lines: tuple[str, ...]) -> list[tuple[object, ...]]:
rows: list[tuple[object, ...]] = []
for index, line in enumerate(lines):
line = _pg_text(line)[0] or ""
line_hash = hashlib.sha256(line.encode("utf-8")).hexdigest()
rows.append((lyric_id, role, index, line, line_hash))
return rows
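def _pg_text(value: str | None) -> tuple[str | None, bool]:
    """Strip NUL characters, which PostgreSQL text columns cannot store.

    Returns the cleaned value and whether anything was removed.
    """
    if value is None:
        return None, False
    if "\x00" not in value:
        return value, False
    return value.replace("\x00", ""), True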
def _soft_delete_exact_duplicates(conn: Any, report_path: Path) -> int:
print("[pg-import] deduplicate exact_hash duplicates", file=sys.stderr, flush=True)
with conn.transaction():
with conn.cursor() as cursor:
cursor.execute(
"""
with ranked as (
select
id,
exact_hash,
first_value(id) over (
partition by exact_hash
order by
case when source_path like '%/None_%' then 1 else 0 end,
line_count desc,
length(primary_text) desc,
id
) as kept_id,
row_number() over (
partition by exact_hash
order by
case when source_path like '%/None_%' then 1 else 0 end,
line_count desc,
length(primary_text) desc,
id
) as rn
from lyrics
where deleted_at is null
),
to_delete as (
select id, exact_hash, kept_id
from ranked
where rn > 1
),
updated as (
update lyrics l
set deleted_at = now(), updated_at = now()
from to_delete d
where l.id = d.id
returning
l.id as duplicate_id,
l.record_id as duplicate_record_id,
l.source_path as duplicate_source_path,
d.exact_hash,
d.kept_id
)
select
u.duplicate_id,
u.duplicate_record_id,
u.duplicate_source_path,
k.id as kept_id,
k.record_id as kept_record_id,
k.source_path as kept_source_path,
u.exact_hash
from updated u
join lyrics k on k.id = u.kept_id
order by u.exact_hash, u.duplicate_id
"""
)
rows = cursor.fetchall()
_write_rows(
report_path,
[
"duplicate_id",
"duplicate_record_id",
"duplicate_source_path",
"kept_id",
"kept_record_id",
"kept_source_path",
"exact_hash",
],
rows,
)
print(f"[pg-import] exact duplicates soft-deleted: {len(rows)}", file=sys.stderr, flush=True)
return len(rows)
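The window ordering in the query above decides which row of each `exact_hash` group survives. The same rule can be sketched in plain Python; `LyricRow`, `keeper_sort_key` and the sample rows are hypothetical names and data used only for illustration.

```python
import re
from dataclasses import dataclass


@dataclass(frozen=True)
class LyricRow:
    id: int
    source_path: str
    line_count: int
    primary_text: str


def keeper_sort_key(row: LyricRow) -> tuple[bool, int, int, int]:
    # SQL LIKE '%/None_%': "_" is a single-character wildcard, so the pattern
    # matches "/None" followed by at least one more character.
    placeholder_path = re.search(r"/None.", row.source_path, re.DOTALL) is not None
    # False sorts before True, so non-placeholder paths win; then more lines,
    # then longer primary text, then the lowest id.
    return (placeholder_path, -row.line_count, -len(row.primary_text), row.id)


# One group of rows that share the same exact_hash (made-up data).
group = [
    LyricRow(id=1, source_path="lib/None_A.txt", line_count=40, primary_text="x" * 900),
    LyricRow(id=2, source_path="lib/123_A.txt", line_count=30, primary_text="x" * 700),
    LyricRow(id=3, source_path="lib/456_A.txt", line_count=30, primary_text="x" * 800),
]
ranked = sorted(group, key=keeper_sort_key)
kept, soft_deleted = ranked[0], ranked[1:]
```

Here row 3 is kept: row 1 has the most lines but loses on the placeholder-path rule, and row 3 beats row 2 on primary text length. The other two rows would get `deleted_at` set and appear in the duplicate report.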
def _write_line_duplicate_report(conn: Any, report_path: Path, *, threshold: float, limit: int) -> int:
print("[pg-import] report high line-coverage duplicate candidates", file=sys.stderr, flush=True)
with conn.cursor() as cursor:
cursor.execute(
"""
with pairs as (
select
a.lyric_id as left_id,
b.lyric_id as right_id,
count(*) as matched_lines
from lyric_lines a
join lyric_lines b
on a.line_hash = b.line_hash
and a.lyric_id < b.lyric_id
join lyrics la on la.id = a.lyric_id and la.deleted_at is null
join lyrics lb on lb.id = b.lyric_id and lb.deleted_at is null
where a.role = 'primary'
and b.role = 'primary'
group by a.lyric_id, b.lyric_id
)
select
p.left_id,
l1.record_id as left_record_id,
l1.source_path as left_source_path,
p.right_id,
l2.record_id as right_record_id,
l2.source_path as right_source_path,
p.matched_lines,
l1.line_count as left_line_count,
l2.line_count as right_line_count,
p.matched_lines::float / greatest(l1.line_count, l2.line_count) as line_coverage
from pairs p
join lyrics l1 on l1.id = p.left_id
join lyrics l2 on l2.id = p.right_id
where p.matched_lines::float / greatest(l1.line_count, l2.line_count) >= %s
order by line_coverage desc, matched_lines desc
limit %s
""",
(threshold, limit),
)
rows = cursor.fetchall()
_write_rows(
report_path,
[
"left_id",
"left_record_id",
"left_source_path",
"right_id",
"right_record_id",
"right_source_path",
"matched_lines",
"left_line_count",
"right_line_count",
"line_coverage",
],
rows,
)
print(f"[pg-import] line duplicate candidates reported: {len(rows)}", file=sys.stderr, flush=True)
return len(rows)
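The `line_coverage` column divides the number of joined line rows by the larger of the two line counts. The sketch below mimics that arithmetic in Python under two assumptions that the code shown here does not confirm: that `line_count` equals the number of stored primary rows, and that `primary_lines` may contain repeated lines such as a chorus. The helper names are hypothetical.

```python
from collections import Counter


def join_matched_lines(left: list[str], right: list[str]) -> int:
    """Count the rows an equality join on line text would produce."""
    left_counts, right_counts = Counter(left), Counter(right)
    # A line occurring m times on the left and n times on the right yields m * n rows.
    return sum(left_counts[line] * right_counts[line] for line in left_counts.keys() & right_counts.keys())


def line_coverage(left: list[str], right: list[str]) -> float:
    return join_matched_lines(left, right) / max(len(left), len(right))


# No repeated lines: coverage behaves like a fraction of shared lines.
plain = line_coverage(["a", "b", "c", "d"], ["a", "b", "c", "x"])

# Repeated chorus line: the join counts 2 * 2 = 4 rows for only 3 lines per song.
with_repeats = line_coverage(["hook", "hook", "verse 1"], ["hook", "hook", "verse 2"])
```

Under these assumptions, pairs that share a repeated line can score above 1.0, so values in the report are best read as a ranking signal for manual review rather than a strict percentage.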
def _write_rows(report_path: Path, fieldnames: list[str], rows: list[tuple[object, ...]]) -> None:
report_path.parent.mkdir(parents=True, exist_ok=True)
with report_path.open("w", encoding="utf-8", newline="") as file:
writer = csv.writer(file)
writer.writerow(fieldnames)
writer.writerows(rows)
def _progress(label: str, current: int, total: int, *, step: int) -> None:
if current == total or current % step == 0:
print(f"[pg-import] {label}: {current}/{total}", file=sys.stderr, flush=True)
def _import_psycopg():
try:
import psycopg
return psycopg
except ModuleNotFoundError:
print(
"Missing dependency: psycopg. Install it with:\n"
" python -m pip install 'psycopg[binary]'",
file=sys.stderr,
)
raise SystemExit(1)
if __name__ == "__main__":
main()
"""Initialize PostgreSQL schema for lyric dedup storage."""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
PROJECT_ROOT = Path(__file__).resolve().parents[1]
SCHEMA_PATH = PROJECT_ROOT / "scripts" / "postgres_schema.sql"
def main() -> None:
    parser = argparse.ArgumentParser(description="Initialize PostgreSQL schema for lyric dedup.")
    parser.add_argument("--dsn", required=True, help="PostgreSQL DSN, e.g. postgresql://user:pass@localhost:5432/lyric_dedup")
    parser.add_argument("--schema", default=str(SCHEMA_PATH))
    args = parser.parse_args()
    psycopg = _import_psycopg()
    schema_sql = Path(args.schema).read_text(encoding="utf-8")
    with psycopg.connect(args.dsn) as conn:
        with conn.cursor() as cursor:
            cursor.execute(schema_sql)
        conn.commit()
    print(f"initialized schema from {args.schema}")
def _import_psycopg():
try:
import psycopg
return psycopg
except ModuleNotFoundError:
print(
"Missing dependency: psycopg. Install it with:\n"
" python -m pip install 'psycopg[binary]'",
file=sys.stderr,
)
raise SystemExit(1)
if __name__ == "__main__":
main()
create extension if not exists pg_trgm;

create table if not exists lyrics (
    id bigserial primary key,
    record_id text not null unique,
    source_path text not null,
    title text,
    artist text,
    raw_text text not null,
    normalized_text text not null,
    primary_text text not null,
    translation_text text,
    exact_hash text not null,
    split_confidence text,
    split_reason text,
    line_count integer not null,
    created_at timestamptz not null default now(),
    updated_at timestamptz not null default now(),
    deleted_at timestamptz
);

create index if not exists lyrics_exact_hash_idx
    on lyrics (exact_hash)
    where deleted_at is null;

create index if not exists lyrics_primary_text_trgm_idx
    on lyrics using gin (primary_text gin_trgm_ops);

create table if not exists lyric_lines (
    lyric_id bigint not null references lyrics(id) on delete cascade,
    role text not null,
    line_no integer not null,
    normalized_line text not null,
    line_hash text not null,
    primary key (lyric_id, role, line_no)
);

create index if not exists lyric_lines_hash_idx
    on lyric_lines (line_hash);

create index if not exists lyric_lines_lyric_id_idx
    on lyric_lines (lyric_id);
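For intuition about what the `gin_trgm_ops` index and `similarity()` compare, here is a rough pure-Python approximation of trigram similarity. It is a sketch for ASCII words only: it follows the documented padding rule (two spaces before each word, one after) but does not reproduce how PostgreSQL handles CJK text or locale-dependent characters, and the function names are hypothetical.

```python
import re


def trigrams(text: str) -> set[str]:
    """Approximate pg_trgm trigram extraction for ASCII alphanumeric words."""
    grams: set[str] = set()
    for word in re.findall(r"[a-z0-9]+", text.lower()):
        padded = f"  {word} "  # two leading spaces, one trailing space
        grams.update(padded[i : i + 3] for i in range(len(padded) - 2))
    return grams


def trigram_similarity(left: str, right: str) -> float:
    """Shared trigrams divided by all distinct trigrams of both strings."""
    left_grams, right_grams = trigrams(left), trigrams(right)
    if not left_grams or not right_grams:
        return 0.0
    return len(left_grams & right_grams) / len(left_grams | right_grams)


score = trigram_similarity("word", "two words")  # 4 shared of 11 distinct trigrams
```

With the evaluation script's default `--trgm-threshold` of 0.3, a pair scoring about 0.36 like this one would pass the `%` filter, while unrelated texts typically fall well below it.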
......@@ -316,6 +316,40 @@ def test_generated_eval_set_uses_stratified_production_mix(tmp_path) -> None:
assert all(row["expected"] == "不应去重" for row in rows if row["sample_type"].startswith("negative_"))
def test_generated_hard_eval_set_uses_business_realistic_edge_mix(tmp_path) -> None:
    library = tmp_path / "library"
    incoming = tmp_path / "generated" / "incoming"
    eval_csv = tmp_path / "generated" / "eval_hard.csv"
    library.mkdir()
    for idx in range(24):
        prefix = "AY" if idx % 3 == 0 else "WHHY"
        lyric = BASE_LYRIC.replace("我爱你", f"我想你{idx}").replace("城市", f"城市{idx}")
        if idx % 4 == 0:
            lyric += "\nI miss you tonight\nUnder the moonlight\nNever let me go\n"
        (library / f"{idx}_{prefix}{idx:06d}.txt").write_text(lyric, encoding="utf-8")

    generate_eval_set(
        library_dir=library,
        output_dir=incoming,
        csv_path=eval_csv,
        size=40,
        positive_ratio=0.3,
        profile="hard",
    )

    # Read through a context manager so the CSV file handle is closed promptly.
    with eval_csv.open(encoding="utf-8", newline="") as eval_file:
        rows = list(csv.DictReader(eval_file))
    manifest = json.loads((tmp_path / "generated" / "eval_hard.csv.manifest.json").read_text(encoding="utf-8"))
    sample_types = {row["sample_type"] for row in rows}
    assert len(rows) == 40
    assert manifest["profile"] == "hard"
    assert "positive_realistic_variant" in manifest["plan"]
    assert "negative_near_neighbor_holdout_full_song" in manifest["plan"]
    assert "negative_long_fragment" in sample_types
    assert "negative_catalog_mashup" in sample_types
    assert any(row["sample_type"].startswith("positive_") for row in rows)
def test_foreign_original_with_added_chinese_translation_is_duplicate() -> None:
checker = DuplicateChecker()
checker.add_record(
......