testset.py
3.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from __future__ import annotations
from typing import Any
from weknora_eval.loaders import read_jsonl, write_jsonl
from weknora_eval.schemas import TestsetRecord
def generate_rule_based_testset(
    *,
    documents_path: str = "data/parsed_docs/documents.jsonl",
    output_path: str = "data/testsets/testset.raw.jsonl",
    size: int = 50,
    min_context_chars: int = 80,
) -> list[dict[str, Any]]:
    """Build a draft QA testset from parsed documents using fixed rules.

    Documents are read from ``documents_path``; those whose ``content`` is
    missing, empty, or shorter than ``min_context_chars`` are skipped. The
    first ``size`` remaining documents each yield one ``single_hop`` record
    with ``review_status`` set to ``"pending"``. The records are written to
    ``output_path`` and also returned.

    Args:
        documents_path: JSONL file of parsed documents to draw contexts from.
        output_path: JSONL file that receives the generated records.
        size: Maximum number of records to generate. Zero or a negative
            value produces an empty testset.
        min_context_chars: Minimum ``content`` length for a document to be
            eligible.

    Returns:
        The generated records as dictionaries, in the order they were written.
    """
    # Clamp the limit: a negative value in a slice counts from the end, so
    # ``documents[:-1]`` would return "all but the last" instead of nothing.
    limit = max(size, 0)

    eligible = [
        row
        for row in read_jsonl(documents_path)
        # ``or ""`` treats a missing or null content field as empty text.
        if len(row.get("content") or "") >= min_context_chars
    ]

    rows: list[dict[str, Any]] = []
    for index, document in enumerate(eligible[:limit], start=1):
        context = document["content"]
        record = TestsetRecord(
            sample_id=f"qa-{index:04d}",
            user_input=_default_question(document),
            reference=_reference_from_context(context),
            reference_contexts=[context],
            source_file=document.get("source_file"),
            question_type="single_hop",
            review_status="pending",
        )
        rows.append(record.to_dict())

    write_jsonl(output_path, rows)
    return rows
def approve_pending_testset(
    *,
    input_path: str = "data/testsets/testset.raw.jsonl",
    output_path: str = "data/testsets/testset.reviewed.jsonl",
) -> list[dict[str, Any]]:
    """Mark every non-rejected testset row as approved and save the result.

    Rows whose ``review_status`` equals ``"rejected"`` are dropped. Every
    other row is copied and its ``review_status`` is set to ``"approved"``.

    Args:
        input_path: JSONL file holding the raw testset rows.
        output_path: JSONL file that receives the approved rows.

    Returns:
        The approved rows, in input order.
    """
    approved_rows: list[dict[str, Any]] = []
    for raw in read_jsonl(input_path):
        # Work on a copy so the object handed back by the loader is untouched.
        record = dict(raw)
        if record.get("review_status") != "rejected":
            record["review_status"] = "approved"
            approved_rows.append(record)

    write_jsonl(output_path, approved_rows)
    return approved_rows
def validate_reviewed_testset(path: str = "data/testsets/testset.reviewed.jsonl") -> list[str]:
    """Check a reviewed testset file and collect human-readable problems.

    Each row must have ``review_status`` equal to ``"approved"``, truthy
    ``sample_id``, ``user_input`` and ``reference`` values, and a non-empty
    ``reference_contexts``.

    Args:
        path: JSONL file holding the reviewed testset rows.

    Returns:
        One message per violation, each prefixed with ``"<path>:<n>"`` where
        ``n`` is the 1-based position of the row. An empty list means no
        violations were found.
    """
    required_fields = ("sample_id", "user_input", "reference")
    problems: list[str] = []

    for position, record in enumerate(read_jsonl(path), start=1):
        location = f"{path}:{position}"

        if record.get("review_status") != "approved":
            problems.append(f"{location} review_status must be approved")

        problems.extend(
            f"{location} missing {field}"
            for field in required_fields
            if not record.get(field)
        )

        if not record.get("reference_contexts"):
            problems.append(f"{location} reference_contexts must be non-empty")

    return problems
def _default_question(document: dict[str, Any]) -> str:
source = document.get("source_file") or "该文档"
if document.get("file_type") == "xlsx" and document.get("sheet"):
return f"请根据 {source} 的 {document['sheet']} 中对应记录回答:这条记录的主要内容是什么?"
if document.get("page"):
return f"请根据 {source} 第 {document['page']} 页回答:该片段的主要内容是什么?"
return f"请根据 {source} 回答:该片段的主要内容是什么?"
def _reference_from_context(context: str, *, max_chars: int = 500) -> str:
text = " ".join(context.split())
if len(text) <= max_chars:
return text
return text[:max_chars].rstrip() + "..."