# testset.py
from __future__ import annotations

from typing import Any

from weknora_eval.loaders import read_jsonl, write_jsonl
from weknora_eval.schemas import TestsetRecord


def generate_rule_based_testset(
    *,
    documents_path: str = "data/parsed_docs/documents.jsonl",
    output_path: str = "data/testsets/testset.raw.jsonl",
    size: int = 50,
    min_context_chars: int = 80,
) -> list[dict[str, Any]]:
    """Build a rule-based, single-hop QA testset from parsed documents.

    Reads document rows from ``documents_path`` and keeps those whose
    ``content`` is a string with at least ``min_context_chars`` characters once
    leading/trailing whitespace is stripped. The first ``size`` of them become
    ``TestsetRecord`` rows. Each row gets a templated question, a
    whitespace-normalised reference excerpt, and the full content as its single
    reference context. Every generated row has ``review_status="pending"``. The
    rows are written to ``output_path`` and also returned.

    Args:
        documents_path: JSONL file of parsed document rows.
        output_path: JSONL file the generated testset is written to.
        size: Maximum number of samples to generate; must be non-negative.
        min_context_chars: Minimum length of the stripped ``content`` for a
            document to be used.

    Returns:
        The generated rows as plain dicts (the same rows written to disk).

    Raises:
        ValueError: If ``size`` is negative.
    """
    if size < 0:
        # A negative slice bound would silently drop documents from the end
        # instead of limiting the sample count, so reject it explicitly.
        raise ValueError(f"size must be >= 0, got {size}")

    documents: list[dict[str, Any]] = []
    for row in read_jsonl(documents_path):
        content = row.get("content")
        # Non-string content would crash the whitespace normalisation used to
        # build the reference. The threshold is measured on the stripped text,
        # so whitespace-only padding cannot produce an empty ``reference``
        # (validate_reviewed_testset reports that as "missing reference").
        if not isinstance(content, str):
            continue
        stripped = content.strip()
        if stripped and len(stripped) >= min_context_chars:
            documents.append(row)

    rows: list[dict[str, Any]] = []
    for index, document in enumerate(documents[:size], start=1):
        context = document["content"]
        rows.append(
            TestsetRecord(
                sample_id=f"qa-{index:04d}",
                user_input=_default_question(document),
                reference=_reference_from_context(context),
                reference_contexts=[context],
                source_file=document.get("source_file"),
                question_type="single_hop",
                review_status="pending",
            ).to_dict()
        )
    write_jsonl(output_path, rows)
    return rows


def approve_pending_testset(
    *,
    input_path: str = "data/testsets/testset.raw.jsonl",
    output_path: str = "data/testsets/testset.reviewed.jsonl",
) -> list[dict[str, Any]]:
    """Approve every testset row that has not been explicitly rejected.

    Rows whose ``review_status`` is ``"rejected"`` are dropped. Every other row,
    whatever its previous status, is copied with ``review_status`` set to
    ``"approved"``. The rows returned by ``read_jsonl`` are not mutated.

    Args:
        input_path: JSONL file of candidate testset rows.
        output_path: JSONL file the approved rows are written to.

    Returns:
        The approved rows in input order, exactly as written to ``output_path``.
    """
    approved = [
        # Dict unpacking copies the row; overriding an existing key keeps its
        # position, and a missing key is appended at the end.
        {**record, "review_status": "approved"}
        for record in read_jsonl(input_path)
        if record.get("review_status") != "rejected"
    ]
    write_jsonl(output_path, approved)
    return approved


def validate_reviewed_testset(path: str = "data/testsets/testset.reviewed.jsonl") -> list[str]:
    """Check that a reviewed testset file is ready for evaluation.

    Each row must have ``review_status == "approved"``. It must also have truthy
    ``sample_id``, ``user_input`` and ``reference`` values and a non-empty
    ``reference_contexts``. Each problem is reported with a
    ``"<path>:<n>"`` prefix, where ``n`` is the 1-based position of the row in
    the sequence returned by ``read_jsonl``.

    Returns:
        Every error message found, in row order; an empty list means the file
        passed validation.
    """
    required_fields = ("sample_id", "user_input", "reference")
    problems: list[str] = []
    for row_no, record in enumerate(read_jsonl(path), start=1):
        where = f"{path}:{row_no}"
        if record.get("review_status") != "approved":
            problems.append(f"{where} review_status must be approved")
        problems.extend(
            f"{where} missing {field}" for field in required_fields if not record.get(field)
        )
        if not record.get("reference_contexts"):
            problems.append(f"{where} reference_contexts must be non-empty")
    return problems


def _default_question(document: dict[str, Any]) -> str:
    source = document.get("source_file") or "该文档"
    if document.get("file_type") == "xlsx" and document.get("sheet"):
        return f"请根据 {source} 的 {document['sheet']} 中对应记录回答:这条记录的主要内容是什么?"
    if document.get("page"):
        return f"请根据 {source} 第 {document['page']} 页回答:该片段的主要内容是什么?"
    return f"请根据 {source} 回答:该片段的主要内容是什么?"


def _reference_from_context(context: str, *, max_chars: int = 500) -> str:
    text = " ".join(context.split())
    if len(text) <= max_chars:
        return text
    return text[:max_chars].rstrip() + "..."