Commit ed19c4ee ed19c4ee32fb1dcd5fddf847b7e6d295bc4f5d17 by 沈秋雨

修复hard数据集误判

1 parent 49008962
...@@ -450,9 +450,9 @@ python -m lyric_dedup.cli generate-eval-set \ ...@@ -450,9 +450,9 @@ python -m lyric_dedup.cli generate-eval-set \
450 ```bash 450 ```bash
451 python scripts/evaluate_postgres.py \ 451 python scripts/evaluate_postgres.py \
452 --dsn postgresql:///lyric_dedup \ 452 --dsn postgresql:///lyric_dedup \
453 --csv data/generated_eval/eval_5000.csv \ 453 --csv data/generated_eval/eval_hard_5000.csv \
454 --base-dir data/generated_eval \ 454 --base-dir data/generated_eval \
455 --out outputs/results/postgres_eval_5000.csv 455 --out outputs/results/postgres_eval_hard_5000.csv
456 ``` 456 ```
457 457
458 它会: 458 它会:
......
...@@ -220,7 +220,11 @@ def _recall_candidates( ...@@ -220,7 +220,11 @@ def _recall_candidates(
220 220
221 221
222 def _exclude_record_ids_for_eval_row(row: dict[str, str]) -> list[str]: 222 def _exclude_record_ids_for_eval_row(row: dict[str, str]) -> list[str]:
223 if row.get("sample_type") == "negative_real_holdout_full_song" and row.get("source_record_id"): 223 holdout_sample_types = {
224 "negative_real_holdout_full_song",
225 "negative_near_neighbor_holdout_full_song",
226 }
227 if row.get("sample_type") in holdout_sample_types and row.get("source_record_id"):
224 return [row["source_record_id"]] 228 return [row["source_record_id"]]
225 return [] 229 return []
226 230
......