# testset.py 9.45 KB
from __future__ import annotations

import json
from typing import Any

from langchain_core.documents import Document
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from ragas.run_config import RunConfig
from ragas.testset import TestsetGenerator

from weknora_eval.config import require_config
from weknora_eval.loaders import read_jsonl, write_jsonl
from weknora_eval.ragas_runner import _wrap_langchain_models
from weknora_eval.schemas import TestsetRecord


def generate_testset(config: dict[str, Any]) -> list[dict[str, Any]]:
    """Build a raw testset with the generator selected in ``config["testset"]``.

    Args:
        config: Evaluation configuration. The optional ``testset`` section may
            provide ``generator`` (``"ragas"`` when absent, or
            ``"rule_based"``), ``size`` (default 50) and ``min_context_chars``
            (default 80). The last two are only forwarded here for the
            rule-based generator; the Ragas generator receives the whole
            ``config``.

    Returns:
        The rows returned by the selected generator.

    Raises:
        ValueError: If ``testset.generator`` is neither ``"ragas"`` nor
            ``"rule_based"``.
    """
    # ``or {}`` (rather than a ``.get`` default) also covers a ``testset`` key
    # that is present but set to None, which would otherwise fail on ``.get``.
    testset = config.get("testset") or {}
    generator = str(testset.get("generator", "ragas"))
    if generator == "ragas":
        return generate_ragas_testset(config)
    if generator == "rule_based":
        return generate_rule_based_testset(
            size=int(testset.get("size", 50)),
            min_context_chars=int(testset.get("min_context_chars", 80)),
        )
    raise ValueError(f"Unsupported testset.generator: {generator}")


def generate_ragas_testset(
    config: dict[str, Any],
    *,
    documents_path: str = "data/parsed_docs/documents.jsonl",
    output_path: str = "data/testsets/testset.raw.jsonl",
) -> list[dict[str, Any]]:
    """Generate a testset with Ragas' ``TestsetGenerator`` and write it as JSONL.

    Rows read from *documents_path* whose ``content`` is shorter than
    ``testset.min_context_chars`` (default 80) are ignored. When no row
    qualifies, an empty file is written to *output_path* and ``[]`` is
    returned before any model client is created.

    Args:
        config: Evaluation configuration. Must contain a ``ragas`` section
            with ``llm_api_key``, ``llm_base_url``, ``embedding_api_key`` and
            ``embedding_base_url``; ``temperature``, ``max_tokens``,
            ``timeout_seconds`` and ``max_workers`` are optional. Model names
            are looked up through ``require_config`` under
            ``ragas.generator_model`` and ``ragas.embedding_model``. The
            optional ``testset`` section supplies ``size`` (default 50) and
            ``min_context_chars``.
        documents_path: JSONL file holding the parsed source documents.
        output_path: JSONL file the normalized testset rows are written to.

    Returns:
        The normalized rows, identical to what is written to *output_path*.

    Raises:
        KeyError: If ``config`` has no ``ragas`` key.
        ValueError: If one of the required ``ragas`` values is missing or empty.
    """
    # ``or {}`` also tolerates a ``testset`` key that is present but None.
    testset_config = config.get("testset") or {}
    ragas_config = config["ragas"]
    size = int(testset_config.get("size", 50))
    min_context_chars = int(testset_config.get("min_context_chars", 80))

    source_rows = [
        row
        for row in read_jsonl(documents_path)
        if len(row.get("content") or "") >= min_context_chars
    ]
    if not source_rows:
        # Still produce the output file so downstream steps find it.
        write_jsonl(output_path, [])
        return []

    documents = [
        Document(
            page_content=row["content"],
            metadata={
                "source_file": row.get("source_file"),
                "doc_id": row.get("doc_id"),
                # Keys in the row's own metadata take precedence over the two above.
                **(row.get("metadata") or {}),
            },
        )
        for row in source_rows
    ]

    # One timeout is shared by the chat client, the embedding client and the
    # Ragas run configuration, so it is parsed a single time.
    timeout_seconds = int(ragas_config.get("timeout_seconds", 600))

    llm = ChatOpenAI(
        model=str(require_config(config, "ragas.generator_model")),
        api_key=_required_ragas_value(ragas_config, "llm_api_key"),
        base_url=_required_ragas_value(ragas_config, "llm_base_url"),
        temperature=float(ragas_config.get("temperature", 0)),
        max_tokens=int(ragas_config.get("max_tokens", 4096)),
        timeout=timeout_seconds,
    )
    embeddings = OpenAIEmbeddings(
        model=str(require_config(config, "ragas.embedding_model")),
        api_key=_required_ragas_value(ragas_config, "embedding_api_key"),
        base_url=_required_ragas_value(ragas_config, "embedding_base_url"),
        tiktoken_enabled=False,
        check_embedding_ctx_length=False,
        request_timeout=timeout_seconds,
    )
    ragas_llm, ragas_embeddings = _wrap_langchain_models(llm, embeddings)
    generator = TestsetGenerator(llm=ragas_llm, embedding_model=ragas_embeddings)
    result = generator.generate_with_langchain_docs(
        # At most ``size`` documents, but never fewer than one, are handed over.
        documents[: max(size, 1)],
        testset_size=size,
        run_config=RunConfig(
            timeout=timeout_seconds,
            max_workers=int(ragas_config.get("max_workers", 1)),
        ),
        # NOTE(review): presumably lets generation continue past per-sample
        # failures instead of aborting - confirm against the Ragas docs.
        raise_exceptions=False,
    )

    ragas_rows = result.to_list()
    rows = _normalize_ragas_rows(ragas_rows, source_rows)
    write_jsonl(output_path, rows)
    return rows


def _normalize_ragas_rows(
    ragas_rows: list[dict[str, Any]],
    source_rows: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Convert raw Ragas rows into ``TestsetRecord`` dicts, dropping incomplete ones.

    Each Ragas row is paired with a source row via ``_match_source_row``; the
    match supplies ``source_file``, the gold chunk id (``metadata.chunk_id``,
    falling back to ``doc_id``) and, when the Ragas row carries no reference
    contexts of its own, the context text.

    Sample ids (``qa-0001``, ...) follow the position in *ragas_rows*, so rows
    that are dropped leave gaps in the numbering.

    Returns:
        Only the records whose ``user_input``, ``reference`` and
        ``reference_contexts`` are all non-empty.
    """
    doc_index = {
        str(candidate.get("doc_id")): candidate
        for candidate in source_rows
        if candidate.get("doc_id")
    }
    records: list[dict[str, Any]] = []

    for position, ragas_row in enumerate(ragas_rows, start=1):
        # Prefer the plural field; fall back to the singular spelling.
        contexts = _as_string_list(ragas_row.get("reference_contexts"))
        if not contexts and ragas_row.get("reference_context"):
            contexts = _as_string_list(ragas_row.get("reference_context"))

        matched = _match_source_row(ragas_row, source_rows, doc_index, contexts)

        chunk_ids = []
        if matched:
            chunk_id = (matched.get("metadata") or {}).get("chunk_id") or matched.get("doc_id")
            if chunk_id:
                chunk_ids = [str(chunk_id)]

        if contexts:
            record_contexts = contexts
        elif matched:
            record_contexts = [matched["content"]]
        else:
            record_contexts = []

        question = str(ragas_row.get("user_input") or ragas_row.get("query") or "").strip()
        answer = str(ragas_row.get("reference") or ragas_row.get("answer") or "").strip()

        record = TestsetRecord(
            sample_id=f"qa-{position:04d}",
            user_input=question,
            reference=answer,
            reference_contexts=record_contexts,
            source_file=matched.get("source_file") if matched else None,
            gold_chunk_ids=chunk_ids,
            question_type=str(ragas_row.get("synthesizer_name") or "ragas"),
            review_status="pending",
        ).to_dict()

        # Keep only records usable for evaluation.
        if (
            record.get("user_input")
            and record.get("reference")
            and record.get("reference_contexts")
        ):
            records.append(record)

    return records


def _match_source_row(
    ragas_row: dict[str, Any],
    source_rows: list[dict[str, Any]],
    source_by_doc_id: dict[str, dict[str, Any]],
    reference_contexts: list[str],
) -> dict[str, Any] | None:
    """Find the source row a generated Ragas row most likely came from.

    Resolution order:

    1. The first id listed under ``reference_context_ids`` or
       ``retrieved_context_ids`` that is a key of *source_by_doc_id*.
    2. The first source row whose ``content`` contains, or is contained in,
       one of *reference_contexts*.
    3. The first entry of *source_rows* as a last resort.

    Returns:
        The matched source row, or ``None`` when *source_rows* is empty.
    """
    for key in ("reference_context_ids", "retrieved_context_ids"):
        for doc_id in _as_string_list(ragas_row.get(key)):
            if doc_id in source_by_doc_id:
                return source_by_doc_id[doc_id]
    for context in reference_contexts:
        if not context:
            continue
        for source in source_rows:
            content = source.get("content") or ""
            # The empty string is a substring of every string, so a source row
            # without content must not be accepted as a containment match.
            if content and (context in content or content in context):
                return source
    # NOTE(review): with no real match the first source row is still returned,
    # so callers cannot tell a fallback from a genuine match.
    return source_rows[0] if source_rows else None


def _as_string_list(value: Any) -> list[str]:
    if value is None:
        return []
    if isinstance(value, str):
        try:
            parsed = json.loads(value)
            if parsed != value:
                return _as_string_list(parsed)
        except json.JSONDecodeError:
            pass
        return [value.strip()] if value.strip() else []
    if isinstance(value, list):
        result: list[str] = []
        for item in value:
            result.extend(_as_string_list(item))
        return result
    if isinstance(value, dict):
        for key in ("content", "text", "page_content"):
            if key in value:
                return _as_string_list(value[key])
        return []
    return [str(value)]


def _required_ragas_value(config: dict[str, Any], key: str) -> str:
    value = config.get(key)
    if value in {None, ""}:
        raise ValueError(f"Missing required Ragas config value: ragas.{key}")
    return str(value)


def generate_rule_based_testset(
    *,
    documents_path: str = "data/parsed_docs/documents.jsonl",
    output_path: str = "data/testsets/testset.raw.jsonl",
    size: int = 50,
    min_context_chars: int = 80,
) -> list[dict[str, Any]]:
    """Create a template-based testset without calling any model.

    Every document with at least *min_context_chars* characters of ``content``
    is eligible; the first *size* of them each yield one pending
    ``single_hop`` record whose question comes from ``_default_question`` and
    whose reference answer is a shortened copy of the content.

    Args:
        documents_path: JSONL file holding the parsed source documents.
        output_path: JSONL file the generated rows are written to.
        size: Maximum number of records to produce.
        min_context_chars: Minimum content length for a document to be used.

    Returns:
        The generated rows, identical to what is written to *output_path*.
    """
    eligible: list[dict[str, Any]] = []
    for candidate in read_jsonl(documents_path):
        if len(candidate.get("content") or "") >= min_context_chars:
            eligible.append(candidate)

    records: list[dict[str, Any]] = [
        TestsetRecord(
            sample_id=f"qa-{position:04d}",
            user_input=_default_question(doc),
            reference=_reference_from_context(doc["content"]),
            reference_contexts=[doc["content"]],
            source_file=doc.get("source_file"),
            question_type="single_hop",
            review_status="pending",
        ).to_dict()
        for position, doc in enumerate(eligible[:size], start=1)
    ]

    write_jsonl(output_path, records)
    return records


def approve_pending_testset(
    *,
    input_path: str = "data/testsets/testset.raw.jsonl",
    output_path: str = "data/testsets/testset.reviewed.jsonl",
) -> list[dict[str, Any]]:
    """Mark every row that was not rejected as approved and write the result.

    Rows whose ``review_status`` is ``"rejected"`` are left out; all others are
    copied and get ``review_status`` set to ``"approved"``, whatever it was
    before.

    Args:
        input_path: JSONL file with the raw testset rows.
        output_path: JSONL file the approved rows are written to.

    Returns:
        The approved rows, identical to what is written to *output_path*.
    """
    approved: list[dict[str, Any]] = []
    for raw in read_jsonl(input_path):
        # Work on a copy so the rows handed back by the reader stay untouched.
        candidate = dict(raw)
        if candidate.get("review_status") != "rejected":
            candidate["review_status"] = "approved"
            approved.append(candidate)

    write_jsonl(output_path, approved)
    return approved


def validate_reviewed_testset(path: str = "data/testsets/testset.reviewed.jsonl") -> list[str]:
    """Check a reviewed testset file and describe every problem found.

    A row is valid when its ``review_status`` is ``"approved"``, when
    ``sample_id``, ``user_input`` and ``reference`` are all non-empty, and when
    ``reference_contexts`` is non-empty.

    Args:
        path: JSONL file with the reviewed testset rows.

    Returns:
        One message per violation, each prefixed with ``<path>:<row number>``
        (rows are counted from 1). An empty list means no violations.
    """
    required_fields = ("sample_id", "user_input", "reference")
    problems: list[str] = []

    for position, record in enumerate(read_jsonl(path), start=1):
        location = f"{path}:{position}"
        if record.get("review_status") != "approved":
            problems.append(f"{location} review_status must be approved")
        problems.extend(
            f"{location} missing {field}"
            for field in required_fields
            if not record.get(field)
        )
        if not record.get("reference_contexts"):
            problems.append(f"{location} reference_contexts must be non-empty")

    return problems


def _default_question(document: dict[str, Any]) -> str:
    source = document.get("source_file") or "该文档"
    if document.get("file_type") == "xlsx" and document.get("sheet"):
        return f"请根据 {source} 的 {document['sheet']} 中对应记录回答:这条记录的主要内容是什么?"
    if document.get("page"):
        return f"请根据 {source} 第 {document['page']} 页回答:该片段的主要内容是什么?"
    return f"请根据 {source} 回答:该片段的主要内容是什么?"


def _reference_from_context(context: str, *, max_chars: int = 500) -> str:
    text = " ".join(context.split())
    if len(text) <= max_chars:
        return text
    return text[:max_chars].rstrip() + "..."