evaluate_composition.py 11.7 KB
"""曲去重评估脚本。

对 queries.csv 中每条查询音频调用 CompositionDedupService.query(),
按最终接口语义用 top1 分数阈值输出 predicted_duplicate true/false。
expected_song_id 的 top-k/top1 命中只作为诊断字段。
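输出 accuracy/precision/recall/F1(另含剔除无效样本后的 valid_only 指标,以及按 variant、sample_class 分组的准确率);
逐条结果写入 --out,汇总写入同名的 .summary.json 并打印到 stdout。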

用法:
python scripts/evaluate_composition.py \
    --dsn "postgresql:///lyric_dedup" \
    --queries composition_dedup/composition_testset4/queries.csv \
    --out composition_dedup/composition_eval/composition_eval_result_v3.csv
"""

import argparse
import csv
import json
import logging
import sys
from pathlib import Path

# Put the parent of this script's directory on sys.path so the project package
# `composition_dedup` imported below resolves when the script is run directly
# (e.g. `python scripts/evaluate_composition.py`). This must precede that import.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from composition_dedup.service import CompositionConfig, CompositionDedupService

logger = logging.getLogger(__name__)


def _parse_csv_filter(value: str | None) -> set[str] | None:
    if value is None:
        return None
    items = {item.strip() for item in value.split(",") if item.strip()}
    return items or None


def _song_id_from_audio_path(audio_path: str) -> str:
    """从音频文件名开头提取 song_id。"""
    return Path(audio_path).stem.split("_", 1)[0]


# Column order of the per-query result CSV.
_RESULT_FIELDNAMES = [
    "query_song_id", "audio_song_id", "audio_path", "variant", "sample_class",
    "expected_song_id", "expected", "top1_song_id", "top1_similarity", "top1_source",
    "top1_hit", "topk_hit", "expected_rank", "expected_similarity",
    "invalid_negative_pair", "invalid_boolean_sample",
    "expected_duplicate", "predicted_duplicate", "correct", "error",
]

# Labels accepted in the `expected` column (same set as the --expected choices).
# Any other value is evaluated as not_duplicate, with a warning.
_EXPECTED_LABELS = ("duplicate", "not_duplicate")


def _parse_args() -> argparse.Namespace:
    """Parse and validate the command-line options of the evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dsn", required=True)
    parser.add_argument("--queries", required=True, help="queries.csv 路径")
    parser.add_argument("--out", required=True, help="逐条结果输出 CSV")
    parser.add_argument("--top-k", type=int, default=10)
    parser.add_argument("--duplicate-threshold", type=float, help="覆盖 COMPOSITION_DUPLICATE_THRESHOLD")
    parser.add_argument("--variants", help="只评测指定 variant,逗号分隔,如 pitch_up1,pitch_down1")
    parser.add_argument("--sample-classes", help="只评测指定 sample_class,逗号分隔,如 dsp,negative")
    parser.add_argument("--expected", choices=["duplicate", "not_duplicate"], help="只评测指定 expected 类型")
    args = parser.parse_args()
    # A non-positive top-k would evaluate against an empty candidate list.
    if args.top_k < 1:
        parser.error("--top-k 必须 >= 1")
    return args


def _expected_label(row: dict) -> str:
    """Return the normalized (stripped, lower-cased) `expected` label of a query row.

    Raises KeyError if the CSV has no `expected` column at all.
    """
    # csv.DictReader fills cells missing from short rows with None; guard before strip().
    return (row["expected"] or "").strip().lower()


def _load_rows(
    queries_path: str,
    variant_filter: set[str] | None,
    sample_class_filter: set[str] | None,
    expected: str | None,
) -> tuple[list[dict], int]:
    """Read queries.csv and apply the optional filters.

    Returns:
        (kept rows, number of rows before filtering).
    """
    with open(queries_path, newline="", encoding="utf-8") as f:
        rows = list(csv.DictReader(f))
    original_count = len(rows)
    if variant_filter is not None:
        rows = [r for r in rows if (r.get("variant") or "") in variant_filter]
    if sample_class_filter is not None:
        rows = [r for r in rows if (r.get("sample_class") or "") in sample_class_filter]
    if expected is not None:
        rows = [r for r in rows if _expected_label(r) == expected]
    return rows, original_count


def _evaluate_row(service: "CompositionDedupService", row: dict, top_k: int, index: int, total: int) -> dict:
    """Query one row's audio and build its result record (keys = _RESULT_FIELDNAMES).

    A query that raises is recorded with `error` set and predicted_duplicate=False
    instead of aborting the whole evaluation.
    """
    audio_path = row["audio_path"]
    audio_song_id = _song_id_from_audio_path(audio_path)
    query_song_id = row.get("song_id") or audio_song_id
    # str(None) would yield the bogus id "None" for a missing cell, which then looks
    # like a real id to every `if expected_song_id` check; also trim stray whitespace
    # so it can still match str(candidate.song_id).
    expected_song_id = (row["expected_song_id"] or "").strip()
    label = _expected_label(row)
    if label not in _EXPECTED_LABELS:
        logger.warning("[%d/%d] 未知 expected 标签 %r,按 not_duplicate 处理: %s", index, total, row["expected"], audio_path)
    expected_dup = label == "duplicate"
    # A not_duplicate sample whose audio comes from expected_song_id itself is an
    # inconsistent pair; such rows are excluded from the valid_only metrics.
    invalid_negative_pair = (not expected_dup) and bool(expected_song_id) and audio_song_id == expected_song_id

    result = {
        "query_song_id": query_song_id,
        "audio_song_id": audio_song_id,
        "audio_path": audio_path,
        "variant": row.get("variant", ""),
        "sample_class": row.get("sample_class", ""),
        "expected_song_id": expected_song_id,
        "expected": row["expected"],
        "invalid_negative_pair": invalid_negative_pair,
        "expected_duplicate": expected_dup,
    }

    try:
        candidates = service.query(audio_path, top_k=top_k)
    except Exception as e:  # evaluation boundary: record the failure and keep going
        logger.error("[%d/%d] 查询失败: %s, %s", index, total, audio_path, e)
        result.update({
            "top1_song_id": "",
            "top1_similarity": "",
            "top1_source": "",
            "top1_hit": False,
            "topk_hit": False,
            "expected_rank": "",
            "expected_similarity": "",
            "invalid_boolean_sample": False,
            "predicted_duplicate": False,
            "correct": not expected_dup,  # a failed query counts as predicted not_duplicate
            "error": str(e),
        })
        return result

    top1 = candidates[0] if candidates else None
    top1_song_id = str(top1.song_id) if top1 is not None else ""
    top1_sim = round(top1.similarity, 4) if top1 is not None else ""
    top1_source = top1.source if top1 is not None else ""

    # Recall diagnostics only: where (if anywhere) expected_song_id ranks in the candidates.
    expected_rank: int | str = ""
    expected_similarity: float | str = ""
    if expected_song_id:
        for rank, candidate in enumerate(candidates, 1):
            if str(candidate.song_id) == expected_song_id:
                expected_rank = rank
                expected_similarity = round(candidate.similarity, 4)
                break
    top1_hit = expected_rank == 1
    topk_hit = expected_rank != ""

    # Final interface semantics: a single duplicate true/false decision.
    predicted_dup = service.candidates_indicate_duplicate(candidates)
    correct = expected_dup == predicted_dup
    # A not_duplicate query flagged as duplicate because top1 is its own source song.
    invalid_boolean_sample = (
        (not expected_dup)
        and top1 is not None
        and top1_song_id == audio_song_id
        and predicted_dup
    )

    result.update({
        "top1_song_id": top1_song_id,
        "top1_similarity": top1_sim,
        "top1_source": top1_source,
        "top1_hit": top1_hit,
        "topk_hit": topk_hit,
        "expected_rank": expected_rank,
        "expected_similarity": expected_similarity,
        "invalid_boolean_sample": invalid_boolean_sample,
        "predicted_duplicate": predicted_dup,
        "correct": correct,
        "error": "",
    })

    logger.info(
        "[%d/%d] variant=%s source=%s expected=%s predicted_duplicate=%s threshold=%.4f expected_song_id=%s top1=%s sim=%s top1_hit=%s topk_hit=%s expected_rank=%s expected_sim=%s correct=%s",
        index,
        total,
        row.get("variant", ""),
        top1_source or "-",
        row["expected"],
        predicted_dup,
        service.config.duplicate_threshold,
        expected_song_id,
        top1_song_id or "-",
        top1_sim if top1_sim != "" else "-",
        top1_hit,
        topk_hit,
        expected_rank if expected_rank != "" else "-",
        expected_similarity if expected_similarity != "" else "-",
        correct,
    )
    return result


def _write_results(out_path: Path, result_rows: list[dict]) -> None:
    """Write the per-query result records to *out_path* as CSV."""
    with out_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=_RESULT_FIELDNAMES)
        writer.writeheader()
        writer.writerows(result_rows)


def _confusion_metrics(rows: list[dict]) -> dict:
    """Binary duplicate-detection metrics over result records.

    Precision, recall, F1 and accuracy fall back to 0.0 when their denominator is 0.
    """
    tp = sum(1 for r in rows if r["expected_duplicate"] and r["predicted_duplicate"])
    fp = sum(1 for r in rows if not r["expected_duplicate"] and r["predicted_duplicate"])
    tn = sum(1 for r in rows if not r["expected_duplicate"] and not r["predicted_duplicate"])
    fn = sum(1 for r in rows if r["expected_duplicate"] and not r["predicted_duplicate"])
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = (tp + tn) / len(rows) if rows else 0.0
    return {
        "total": len(rows),
        "accuracy": round(accuracy, 4),
        "precision": round(precision, 4),
        "recall": round(recall, 4),
        "f1": round(f1, 4),
        "tp": tp,
        "fp": fp,
        "tn": tn,
        "fn": fn,
    }


def _group_accuracy(result_rows: list[dict], key: str) -> dict:
    """Accuracy per distinct value of *key*; empty/None values are grouped as "unknown"."""
    groups: dict[str, list[int]] = {}  # label -> [correct, total]
    for r in result_rows:
        stats = groups.setdefault(r.get(key) or "unknown", [0, 0])
        stats[1] += 1
        if r["correct"]:
            stats[0] += 1
    # Every group has total >= 1, so the division is safe.
    return {
        label: {"accuracy": round(n_correct / n_total, 4), "total": n_total}
        for label, (n_correct, n_total) in sorted(groups.items())
    }


def _build_summary(
    result_rows: list[dict],
    *,
    variant_filter: set[str] | None,
    sample_class_filter: set[str] | None,
    expected: str | None,
    original_count: int,
    threshold: float,
    out_path: Path,
) -> dict:
    """Aggregate result records into the JSON summary (overall, valid-only and grouped metrics)."""
    metrics = _confusion_metrics(result_rows)
    valid_rows = [
        r for r in result_rows
        if not r["invalid_negative_pair"] and not r["invalid_boolean_sample"]
    ]
    return {
        "total": len(result_rows),
        "filters": {
            "variants": sorted(variant_filter) if variant_filter else None,
            "sample_classes": sorted(sample_class_filter) if sample_class_filter else None,
            "expected": expected,
            "original_total": original_count,
        },
        "duplicate_threshold": threshold,
        "invalid_negative_pairs": sum(1 for r in result_rows if r["invalid_negative_pair"]),
        "invalid_boolean_samples": sum(1 for r in result_rows if r["invalid_boolean_sample"]),
        "accuracy": metrics["accuracy"],
        "precision": metrics["precision"],
        "recall": metrics["recall"],
        "f1": metrics["f1"],
        "tp": metrics["tp"], "fp": metrics["fp"], "tn": metrics["tn"], "fn": metrics["fn"],
        "valid_only": _confusion_metrics(valid_rows),
        "out": str(out_path),
        # Per-group accuracy, to see the pass rate of each kind of transformation.
        "by_variant": _group_accuracy(result_rows, "variant"),
        "by_sample_class": _group_accuracy(result_rows, "sample_class"),
    }


def main() -> None:
    """Run the composition-dedup evaluation end to end.

    Queries every (optionally filtered) row of --queries through
    CompositionDedupService, writes per-query results to --out, and writes the
    JSON summary to the same path with a `.summary.json` suffix (also printed).
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
    args = _parse_args()

    config = CompositionConfig(dsn=args.dsn)
    if args.duplicate_threshold is not None:
        config.duplicate_threshold = args.duplicate_threshold
    service = CompositionDedupService(config=config)

    variant_filter = _parse_csv_filter(args.variants)
    sample_class_filter = _parse_csv_filter(args.sample_classes)
    rows, original_count = _load_rows(args.queries, variant_filter, sample_class_filter, args.expected)

    logger.info(
        "评测样本过滤: 原始 %d 条,保留 %d 条 (variants=%s, sample_classes=%s, expected=%s)",
        original_count,
        len(rows),
        ",".join(sorted(variant_filter)) if variant_filter else "ALL",
        ",".join(sorted(sample_class_filter)) if sample_class_filter else "ALL",
        args.expected or "ALL",
    )

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    total = len(rows)
    result_rows = []
    for i, row in enumerate(rows, 1):
        result_rows.append(_evaluate_row(service, row, args.top_k, i, total))
        # Progress is logged for failed queries too, so the final [n/n] always appears.
        if i % 10 == 0 or i == total:
            logger.info("[%d/%d]", i, total)

    _write_results(out_path, result_rows)

    summary = _build_summary(
        result_rows,
        variant_filter=variant_filter,
        sample_class_filter=sample_class_filter,
        expected=args.expected,
        original_count=original_count,
        threshold=service.config.duplicate_threshold,
        out_path=out_path,
    )
    summary_text = json.dumps(summary, ensure_ascii=False, indent=2)
    summary_path = out_path.with_suffix(".summary.json")
    summary_path.write_text(summary_text, encoding="utf-8")
    print(summary_text)


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()