# report.py 7.22 KB
from __future__ import annotations

import math
from pathlib import Path
from typing import Any

import pandas as pd

from weknora_eval.loaders import read_jsonl


def retrieval_metrics(
    ragas_rows: list[dict[str, Any]],
    *,
    ks: tuple[int, ...] = (1, 3, 5),
    ndcg_k: int = 5,
) -> dict[str, float]:
    """Compute averaged chunk-ID retrieval metrics over labelled rows.

    Only rows with a non-empty ``gold_chunk_ids`` value are scored. For each
    such row the ranked prediction list is built from the ``id`` field of the
    entries in ``weknora_references`` (in their stored order).

    Args:
        ragas_rows: Sample dicts; the keys read here are ``gold_chunk_ids``
            and ``weknora_references``.
        ks: Cutoffs for which ``hit@k`` and ``recall@k`` are reported.
        ndcg_k: Cutoff for the binary-relevance nDCG, reported as
            ``ndcg@{ndcg_k}``.

    Returns:
        Mapping of metric name to its mean over the scored rows, rounded to
        four decimals. Empty when no row carries gold chunk IDs.
    """
    samples = [row for row in ragas_rows if row.get("gold_chunk_ids")]
    if not samples:
        return {}

    ndcg_key = f"ndcg@{ndcg_k}"
    totals: dict[str, float] = {f"hit@{k}": 0.0 for k in ks}
    totals.update({f"recall@{k}": 0.0 for k in ks})
    totals["mrr"] = 0.0
    totals[ndcg_key] = 0.0

    for row in samples:
        raw_gold = row["gold_chunk_ids"]
        if isinstance(raw_gold, str):
            # A bare string would otherwise be split into single characters.
            raw_gold = [raw_gold]
        # Predicted IDs are stringified below, so gold IDs must be too;
        # otherwise integer gold IDs could never match.
        gold = {str(chunk_id) for chunk_id in raw_gold}

        refs = row.get("weknora_references") or []
        # Drop only missing/blank IDs; a falsy-but-valid ID such as 0 is kept.
        predicted = [
            str(ref.get("id")) for ref in refs if ref.get("id") not in (None, "")
        ]

        for k in ks:
            hits = len(gold.intersection(predicted[:k]))
            totals[f"hit@{k}"] += 1.0 if hits else 0.0
            totals[f"recall@{k}"] += hits / len(gold)

        first_rank = next(
            (rank for rank, chunk_id in enumerate(predicted, start=1) if chunk_id in gold),
            None,
        )
        if first_rank:
            totals["mrr"] += 1 / first_rank

        # Credit each gold chunk at most once so a duplicated reference cannot
        # push DCG above the ideal DCG (which assumes distinct gold chunks).
        credited: set[str] = set()
        dcg = 0.0
        for rank, chunk_id in enumerate(predicted[:ndcg_k], start=1):
            if chunk_id in gold and chunk_id not in credited:
                credited.add(chunk_id)
                dcg += 1 / math.log2(rank + 1)
        ideal_hits = min(len(gold), ndcg_k)
        idcg = sum(1 / math.log2(rank + 1) for rank in range(1, ideal_hits + 1))
        totals[ndcg_key] += dcg / idcg if idcg else 0.0

    return {key: round(value / len(samples), 4) for key, value in totals.items()}


def generate_summary_report(
    config: dict[str, Any],
    *,
    scores_csv_path: str = "data/reports/ragas_scores.csv",
    ragas_input_path: str = "data/runs/ragas_input.jsonl",
    answers_path: str = "data/runs/weknora_answers.jsonl",
    output_path: str = "data/reports/summary.md",
) -> str:
    """Build the Markdown evaluation summary, write it to disk, and return it.

    The report contains run information, per-metric mean/median taken from the
    scores CSV, chunk-ID retrieval metrics, the lowest-scoring rows for
    ``context_recall`` and ``faithfulness``, and simple data-quality counts.

    Args:
        config: Run configuration; the ``weknora`` and ``ragas`` sections are
            read for display only.
        scores_csv_path: CSV of per-sample scores. A missing file is treated
            as an empty table.
        ragas_input_path: JSONL of evaluated samples (read with
            ``missing_ok=True``).
        answers_path: JSONL of raw answers (read with ``missing_ok=True``).
        output_path: Destination Markdown file; parent directories are
            created as needed.

    Returns:
        The full Markdown text that was written to ``output_path``.
    """
    ragas_rows = read_jsonl(ragas_input_path, missing_ok=True)
    answer_rows = read_jsonl(answers_path, missing_ok=True)
    scores = pd.read_csv(scores_csv_path) if Path(scores_csv_path).exists() else pd.DataFrame()

    # ``or {}`` also covers sections that are present but explicitly None.
    weknora_cfg = config.get("weknora") or {}
    ragas_cfg = config.get("ragas") or {}
    failed_answers = sum(1 for row in answer_rows if row.get("error"))

    # NOTE(review): the "测试集规模" and "审核通过样本数" lines both report
    # len(ragas_rows); no separate reviewed-sample count is available here.
    lines = [
        "# Ragas 评估报告",
        "",
        "## 运行信息",
        f"- WeKnora Base URL: {weknora_cfg.get('base_url', '')}",
        f"- 知识库 ID: {weknora_cfg.get('knowledge_base_id', '')}",
        f"- 测试集规模: {len(ragas_rows)}",
        f"- 审核通过样本数: {len(ragas_rows)}",
        f"- 失败样本数: {failed_answers}",
        f"- Judge 模型: {ragas_cfg.get('judge_model', '')}",
        "",
        "## 聚合指标",
        "| 指标 | 平均值 | P50 | 失败阈值 |",
        "| --- | --- | --- | --- |",
    ]

    # Every numeric column that is not an identifier/text column is a metric.
    metric_columns = [
        column
        for column in scores.columns
        if column not in {"sample_id", "user_input", "response", "reference"}
        and pd.api.types.is_numeric_dtype(scores[column])
    ]
    for column in metric_columns:
        lines.append(
            f"| {column} | {scores[column].mean():.4f} | {scores[column].median():.4f} | 0.50 |"
        )

    chunk_metrics = retrieval_metrics(ragas_rows)
    if chunk_metrics:
        lines.extend(["", "## Chunk ID 检索指标", "| 指标 | 平均值 |", "| --- | --- |"])
        for key, value in chunk_metrics.items():
            lines.append(f"| {key} | {value:.4f} |")

    lines.extend(["", "## 检索失败样本", "| sample_id | 问题 | 预期文件 | 实际召回文件 | context_recall | 备注 |", "| --- | --- | --- | --- | --- | --- |"])
    for row in _worst_rows(scores, "context_recall"):
        sample = _sample_by_id(ragas_rows, row.get("sample_id"))
        # The key may exist with a None value, so default after the lookup.
        references = sample.get("weknora_references") or []
        actual_files = sorted(
            {
                ref["knowledge_filename"]
                for ref in references
                if ref.get("knowledge_filename")
            }
        )
        # _worst_rows may have matched a suffixed column ("context_recall(...)"),
        # so resolve the value the same way instead of using the bare key.
        lines.append(
            f"| {row.get('sample_id', '')} | {_cell(sample.get('user_input'))} | "
            f"{_cell(sample.get('source_file'))} | {_cell(', '.join(actual_files))} | "
            f"{_score(_metric_value(row, 'context_recall'))} | |"
        )

    lines.extend(["", "## 生成失败样本", "| sample_id | 问题 | 模型答案 | 标准答案 | faithfulness | factual_correctness |", "| --- | --- | --- | --- | --- | --- |"])
    for row in _worst_rows(scores, "faithfulness"):
        sample = _sample_by_id(ragas_rows, row.get("sample_id"))
        lines.append(
            f"| {row.get('sample_id', '')} | {_cell(sample.get('user_input'))} | "
            f"{_cell(sample.get('response'))} | {_cell(sample.get('reference'))} | "
            f"{_score(_metric_value(row, 'faithfulness'))} | {_score(_metric_value(row, 'factual_correctness'))} |"
        )

    empty_retrievals = sum(1 for row in ragas_rows if not row.get("retrieved_contexts"))
    fallback_answers = sum(1 for row in answer_rows if row.get("is_fallback"))
    # Kept as a plain dict because its repr is embedded verbatim in the report.
    source_counts: dict[str, int] = {}
    for row in ragas_rows:
        source = row.get("source_file") or "unknown"
        source_counts[source] = source_counts.get(source, 0) + 1

    lines.extend(
        [
            "",
            "## 数据质量",
            f"- 空检索数量: {empty_retrievals}",
            f"- fallback 答案数量: {fallback_answers}",
            f"- 来源文件分布: {source_counts}",
            "",
            "## 改进建议",
            "- 优先检查 context_recall 低且 retrieved_contexts 为空的样本。",
            "- 对低 faithfulness 且 context_recall 正常的样本,重点检查生成模型和提示词。",
            "- 对 Chunk ID 指标低但 Ragas context 指标正常的样本,检查 chunk 切分或 gold_chunk_ids 标注。",
            "",
        ]
    )

    content = "\n".join(lines)
    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(content, encoding="utf-8")
    return content


def _worst_rows(scores: pd.DataFrame, column: str, *, limit: int = 10) -> list[dict[str, Any]]:
    """Return up to ``limit`` score rows with the lowest values of a metric.

    ``column`` is resolved through ``_metric_column``, so either an exact
    column name or a ``name(...)``-style column is accepted. An empty table
    or an unresolvable metric yields an empty list. Rows are returned as
    plain dicts, lowest value first.
    """
    resolved = _metric_column(scores, column)
    if resolved is None or scores.empty:
        return []
    ranked = scores.sort_values(resolved, ascending=True)
    lowest = ranked.head(limit)
    return lowest.to_dict(orient="records")


def _metric_column(scores: pd.DataFrame, name: str) -> str | None:
    if name in scores.columns:
        return name
    prefix = f"{name}("
    return next((column for column in scores.columns if column.startswith(prefix)), None)


def _metric_value(row: dict[str, Any], name: str) -> Any:
    if name in row:
        return row[name]
    prefix = f"{name}("
    for key, value in row.items():
        if str(key).startswith(prefix):
            return value
    return None


def _sample_by_id(rows: list[dict[str, Any]], sample_id: Any) -> dict[str, Any]:
    return next((row for row in rows if row.get("sample_id") == sample_id), {})


def _cell(value: Any, *, max_len: int = 120) -> str:
    text = "" if value is None else " ".join(str(value).split())
    text = text.replace("|", "\\|")
    if len(text) <= max_len:
        return text
    return text[:max_len].rstrip() + "..."


def _score(value: Any) -> str:
    try:
        if pd.isna(value):
            return ""
        return f"{float(value):.4f}"
    except (TypeError, ValueError):
        return ""