修复hard数据集误判
Showing
2 changed files
with
7 additions
and
3 deletions
| ... | @@ -450,9 +450,9 @@ python -m lyric_dedup.cli generate-eval-set \ | ... | @@ -450,9 +450,9 @@ python -m lyric_dedup.cli generate-eval-set \ |
| 450 | ```bash | 450 | ```bash |
| 451 | python scripts/evaluate_postgres.py \ | 451 | python scripts/evaluate_postgres.py \ |
| 452 | --dsn postgresql:///lyric_dedup \ | 452 | --dsn postgresql:///lyric_dedup \ |
| 453 | --csv data/generated_eval/eval_5000.csv \ | 453 | --csv data/generated_eval/eval_hard_5000.csv \ |
| 454 | --base-dir data/generated_eval \ | 454 | --base-dir data/generated_eval \ |
| 455 | --out outputs/results/postgres_eval_5000.csv | 455 | --out outputs/results/postgres_eval_hard_5000.csv |
| 456 | ``` | 456 | ``` |
| 457 | 457 | ||
| 458 | 它会: | 458 | 它会: | ... | ... |
| ... | @@ -220,7 +220,11 @@ def _recall_candidates( | ... | @@ -220,7 +220,11 @@ def _recall_candidates( |
| 220 | 220 | ||
| 221 | 221 | ||
| 222 | def _exclude_record_ids_for_eval_row(row: dict[str, str]) -> list[str]: | 222 | def _exclude_record_ids_for_eval_row(row: dict[str, str]) -> list[str]: |
| 223 | if row.get("sample_type") == "negative_real_holdout_full_song" and row.get("source_record_id"): | 223 | holdout_sample_types = { |
| 224 | "negative_real_holdout_full_song", | ||
| 225 | "negative_near_neighbor_holdout_full_song", | ||
| 226 | } | ||
| 227 | if row.get("sample_type") in holdout_sample_types and row.get("source_record_id"): | ||
| 224 | return [row["source_record_id"]] | 228 | return [row["source_record_id"]] |
| 225 | return [] | 229 | return [] |
| 226 | 230 | ... | ... |
-
Please register or sign in to post a comment