Commit 5a879284 5a8792846d6cf6e590f6148fee68cd7cd89b69a1 by 沈秋雨

改为均衡抽样生成测试集

1 parent 67dc9bec
...@@ -71,7 +71,7 @@ python scripts/10_report.py ...@@ -71,7 +71,7 @@ python scripts/10_report.py
71 71
72 首轮建议只使用 2 个 PDF、1 个 XLSX 和 10 条审核通过 QA,确认 `retrieved_contexts``response`、Ragas 输入字段都正常后再扩展样本量。 72 首轮建议只使用 2 个 PDF、1 个 XLSX 和 10 条审核通过 QA,确认 `retrieved_contexts``response`、Ragas 输入字段都正常后再扩展样本量。
73 73
74 默认 `04_parse_docs.py` 从 WeKnora 导出的 `data/exported/chunks.jsonl` 构造测试集来源,不再重复调用外部 PDF 解析器。`05_generate_testset.py` 默认使用 Ragas 结合评估侧 LLM 自动生成 QA;生成阶段使用 `TESTSET_RAGAS_MODE=direct`,直接把 WeKnora chunks 组装成 Ragas KnowledgeGraph 并生成单跳 QA,避免 Ragas 默认文档预处理链路重新抽标题、摘要和实体。生成阶段还会用 `TESTSET_MAX_DOCUMENT_CHARS` 限制单条来源上下文长度,并用 `TESTSET_GENERATOR_MAX_TOKENS` 控制生成输出预算,避免和后续评测用的 `ragas.max_tokens` 混在一起`local``mineru``rule_based` 只作为可选实验/兜底配置保留。 74 默认 `04_parse_docs.py` 从 WeKnora 导出的 `data/exported/chunks.jsonl` 构造测试集来源,不再重复调用外部 PDF 解析器。`05_generate_testset.py` 默认使用 Ragas 结合评估侧 LLM 自动生成 QA;生成阶段使用 `TESTSET_RAGAS_MODE=direct`,直接把 WeKnora chunks 组装成 Ragas KnowledgeGraph 并生成单跳 QA,避免 Ragas 默认文档预处理链路重新抽标题、摘要和实体。生成阶段还会用 `TESTSET_MAX_DOCUMENT_CHARS` 限制单条来源上下文长度,`TESTSET_GENERATOR_MAX_TOKENS` 控制生成输出预算,并按来源文件轮询抽样,避免测试集集中在单个文件`local``mineru``rule_based` 只作为可选实验/兜底配置保留。
75 75
76 ## 主要产物 76 ## 主要产物
77 77
......
...@@ -298,7 +298,7 @@ TESTSET_GENERATOR_MAX_TOKENS=4096 ...@@ -298,7 +298,7 @@ TESTSET_GENERATOR_MAX_TOKENS=4096
298 TESTSET_MAX_DOCUMENT_CHARS=2000 298 TESTSET_MAX_DOCUMENT_CHARS=2000
299 RAGAS_ENABLE_THINKING=false 299 RAGAS_ENABLE_THINKING=false
300 RAGAS_HTTP_KEEPALIVE=false 300 RAGAS_HTTP_KEEPALIVE=false
301 RAGAS_TESTSET_TRANSFORMS=default 301 RAGAS_TESTSET_TRANSFORMS=single_hop_entities
302 ``` 302 ```
303 303
304 `direct` 模式会跳过 Ragas 默认的 `HeadlinesExtractor``SummaryExtractor``NERExtractor` 文档预处理链路,直接把 WeKnora chunks 组装成 Ragas KnowledgeGraph 并生成单跳 QA。`prechunked``langchain_docs` 仅用于对比实验,遇到本地 vLLM 结构化输出不稳定时不建议使用。 304 `direct` 模式会跳过 Ragas 默认的 `HeadlinesExtractor``SummaryExtractor``NERExtractor` 文档预处理链路,直接把 WeKnora chunks 组装成 Ragas KnowledgeGraph 并生成单跳 QA。`prechunked``langchain_docs` 仅用于对比实验,遇到本地 vLLM 结构化输出不稳定时不建议使用。
......
1 from __future__ import annotations 1 from __future__ import annotations
2 2
3 import asyncio 3 import asyncio
4 from collections import defaultdict
4 import inspect 5 import inspect
5 import json 6 import json
6 import logging 7 import logging
...@@ -71,7 +72,7 @@ def generate_ragas_testset( ...@@ -71,7 +72,7 @@ def generate_ragas_testset(
71 return [] 72 return []
72 73
73 source_limit = min(len(source_rows), max(size * source_multiplier, size, 1)) 74 source_limit = min(len(source_rows), max(size * source_multiplier, size, 1))
74 selected_source_rows = source_rows[:source_limit] 75 selected_source_rows = _select_source_rows(source_rows, source_limit)
75 documents = [ 76 documents = [
76 Document( 77 Document(
77 page_content=_truncate_for_generation(row["content"], max_document_chars), 78 page_content=_truncate_for_generation(row["content"], max_document_chars),
...@@ -85,9 +86,10 @@ def generate_ragas_testset( ...@@ -85,9 +86,10 @@ def generate_ragas_testset(
85 for row in selected_source_rows 86 for row in selected_source_rows
86 ] 87 ]
87 logger.info( 88 logger.info(
88 "Generating Ragas testset: target_size=%s source_documents=%s max_document_chars=%s generator_max_tokens=%s ragas_mode=%s", 89 "Generating Ragas testset: target_size=%s source_documents=%s source_files=%s max_document_chars=%s generator_max_tokens=%s ragas_mode=%s",
89 size, 90 size,
90 len(documents), 91 len(documents),
92 len({row.get("source_file") or "unknown" for row in selected_source_rows}),
91 max_document_chars, 93 max_document_chars,
92 generator_max_tokens, 94 generator_max_tokens,
93 ragas_mode, 95 ragas_mode,
...@@ -129,6 +131,35 @@ def generate_ragas_testset( ...@@ -129,6 +131,35 @@ def generate_ragas_testset(
129 return rows 131 return rows
130 132
131 133
134 def _select_source_rows(
135 source_rows: list[dict[str, Any]],
136 limit: int,
137 ) -> list[dict[str, Any]]:
138 grouped: dict[str, list[dict[str, Any]]] = defaultdict(list)
139 source_order: list[str] = []
140 for row in source_rows:
141 source = str(row.get("source_file") or "unknown")
142 if source not in grouped:
143 source_order.append(source)
144 grouped[source].append(row)
145
146 selected: list[dict[str, Any]] = []
147 cursor = 0
148 while len(selected) < limit and source_order:
149 progressed = False
150 for source in source_order:
151 rows = grouped[source]
152 if cursor < len(rows):
153 selected.append(rows[cursor])
154 progressed = True
155 if len(selected) >= limit:
156 break
157 if not progressed:
158 break
159 cursor += 1
160 return selected
161
162
132 def _generate_ragas_direct_rows( 163 def _generate_ragas_direct_rows(
133 llm: ChatOpenAI, 164 llm: ChatOpenAI,
134 documents: list[Document], 165 documents: list[Document],
......