Commit 7ce899a9 7ce899a9b07c4ede9971a5608503808dc9ae1c24 by 沈秋雨

Constrain Ragas testset generation budget

1 parent 0813ef9c
...@@ -22,4 +22,7 @@ RAGAS_JUDGE_MODEL=gpt-4o-mini ...@@ -22,4 +22,7 @@ RAGAS_JUDGE_MODEL=gpt-4o-mini
22 RAGAS_EMBEDDING_MODEL=text-embedding-3-small 22 RAGAS_EMBEDDING_MODEL=text-embedding-3-small
23 23
24 TESTSET_SIZE=50 24 TESTSET_SIZE=50
25 TESTSET_MAX_DOCUMENT_CHARS=2000
26 TESTSET_SOURCE_MULTIPLIER=3
27 TESTSET_GENERATOR_MAX_TOKENS=4096
25 REQUEST_INTERVAL_SECONDS=0.2 28 REQUEST_INTERVAL_SECONDS=0.2
......
...@@ -68,7 +68,7 @@ python scripts/10_report.py ...@@ -68,7 +68,7 @@ python scripts/10_report.py
68 68
69 首轮建议只使用 2 个 PDF、1 个 XLSX 和 10 条审核通过 QA,确认 `retrieved_contexts`、`response`、Ragas 输入字段都正常后再扩展样本量。 69 首轮建议只使用 2 个 PDF、1 个 XLSX 和 10 条审核通过 QA,确认 `retrieved_contexts`、`response`、Ragas 输入字段都正常后再扩展样本量。
70 70
71 默认 `04_parse_docs.py` 从 WeKnora 导出的 `data/exported/chunks.jsonl` 构造测试集来源,不再重复调用外部 PDF 解析器。`05_generate_testset.py` 默认使用 Ragas 结合评估侧 LLM 自动生成 QA;`local`、`mineru`、`rule_based` 只作为可选实验/兜底配置保留。 71 默认 `04_parse_docs.py` 从 WeKnora 导出的 `data/exported/chunks.jsonl` 构造测试集来源,不再重复调用外部 PDF 解析器。`05_generate_testset.py` 默认使用 Ragas 结合评估侧 LLM 自动生成 QA;生成阶段会用 `TESTSET_MAX_DOCUMENT_CHARS` 限制单条来源上下文长度,并用 `TESTSET_GENERATOR_MAX_TOKENS` 控制生成输出预算,避免和后续评测用的 `ragas.max_tokens` 混在一起。`local`、`mineru`、`rule_based` 只作为可选实验/兜底配置保留。
72 72
73 ## 主要产物 73 ## 主要产物
74 74
......
...@@ -290,6 +290,15 @@ max_workers: 1 ...@@ -290,6 +290,15 @@ max_workers: 1
290 max_tokens: 4096 290 max_tokens: 4096
291 ``` 291 ```
292 292
293 如果 `05_generate_testset.py` 在生成 QA 时出现 `LLMDidNotFinishException`,优先不要继续盲目调大 `ragas.max_tokens`。`05` 有独立的生成预算和输入长度:
294
295 ```bash
296 TESTSET_GENERATOR_MAX_TOKENS=4096
297 TESTSET_MAX_DOCUMENT_CHARS=2000
298 ```
299
300 如果 vLLM 仍然报生成未完成,先把 `TESTSET_SIZE` 降到 3,再把 `TESTSET_MAX_DOCUMENT_CHARS` 调到 1000-1500 验证链路;`ragas.max_tokens` 主要用于后续评测阶段,不应该拿来无限放大测试集生成阶段的输出长度。
301
293 ### WeKnora 问答没有 retrieved_contexts 302 ### WeKnora 问答没有 retrieved_contexts
294 303
295 检查: 304 检查:
......
...@@ -13,6 +13,9 @@ testset: ...@@ -13,6 +13,9 @@ testset:
13 include_pdf: true 13 include_pdf: true
14 include_xlsx: true 14 include_xlsx: true
15 min_context_chars: 80 15 min_context_chars: 80
16 max_document_chars: "${TESTSET_MAX_DOCUMENT_CHARS:-2000}"
17 source_multiplier: "${TESTSET_SOURCE_MULTIPLIER:-3}"
18 generator_max_tokens: "${TESTSET_GENERATOR_MAX_TOKENS:-4096}"
16 require_manual_review: true 19 require_manual_review: true
17 20
18 parsing: 21 parsing:
...@@ -69,7 +72,7 @@ ragas: ...@@ -69,7 +72,7 @@ ragas:
69 judge_model: "${RAGAS_JUDGE_MODEL}" 72 judge_model: "${RAGAS_JUDGE_MODEL}"
70 embedding_model: "${RAGAS_EMBEDDING_MODEL}" 73 embedding_model: "${RAGAS_EMBEDDING_MODEL}"
71 temperature: 0 74 temperature: 0
72 max_tokens: 8192 75 max_tokens: 4096
73 timeout_seconds: 600 76 timeout_seconds: 600
74 max_workers: 1 77 max_workers: 1
75 metrics: 78 metrics:
......
1 from __future__ import annotations 1 from __future__ import annotations
2 2
3 import json 3 import json
4 import logging
4 from typing import Any 5 from typing import Any
5 6
6 from langchain_core.documents import Document 7 from langchain_core.documents import Document
...@@ -13,6 +14,8 @@ from weknora_eval.loaders import read_jsonl, write_jsonl ...@@ -13,6 +14,8 @@ from weknora_eval.loaders import read_jsonl, write_jsonl
13 from weknora_eval.ragas_runner import _wrap_langchain_models 14 from weknora_eval.ragas_runner import _wrap_langchain_models
14 from weknora_eval.schemas import TestsetRecord 15 from weknora_eval.schemas import TestsetRecord
15 16
17 logger = logging.getLogger(__name__)
18
16 19
17 def generate_testset(config: dict[str, Any]) -> list[dict[str, Any]]: 20 def generate_testset(config: dict[str, Any]) -> list[dict[str, Any]]:
18 testset = config.get("testset", {}) 21 testset = config.get("testset", {})
...@@ -37,6 +40,11 @@ def generate_ragas_testset( ...@@ -37,6 +40,11 @@ def generate_ragas_testset(
37 ragas_config = config["ragas"] 40 ragas_config = config["ragas"]
38 size = int(testset_config.get("size", 50)) 41 size = int(testset_config.get("size", 50))
39 min_context_chars = int(testset_config.get("min_context_chars", 80)) 42 min_context_chars = int(testset_config.get("min_context_chars", 80))
43 max_document_chars = int(testset_config.get("max_document_chars", 2000))
44 source_multiplier = max(int(testset_config.get("source_multiplier", 3)), 1)
45 generator_max_tokens = int(
46 testset_config.get("generator_max_tokens", ragas_config.get("max_tokens", 4096))
47 )
40 48
41 source_rows = [ 49 source_rows = [
42 row 50 row
...@@ -47,24 +55,34 @@ def generate_ragas_testset( ...@@ -47,24 +55,34 @@ def generate_ragas_testset(
47 write_jsonl(output_path, []) 55 write_jsonl(output_path, [])
48 return [] 56 return []
49 57
58 source_limit = min(len(source_rows), max(size * source_multiplier, size, 1))
59 selected_source_rows = source_rows[:source_limit]
50 documents = [ 60 documents = [
51 Document( 61 Document(
52 page_content=row["content"], 62 page_content=_truncate_for_generation(row["content"], max_document_chars),
53 metadata={ 63 metadata={
54 "source_file": row.get("source_file"), 64 "source_file": row.get("source_file"),
55 "doc_id": row.get("doc_id"), 65 "doc_id": row.get("doc_id"),
66 "content_chars": len(row.get("content") or ""),
56 **(row.get("metadata") or {}), 67 **(row.get("metadata") or {}),
57 }, 68 },
58 ) 69 )
59 for row in source_rows 70 for row in selected_source_rows
60 ] 71 ]
72 logger.info(
73 "Generating Ragas testset: target_size=%s source_documents=%s max_document_chars=%s generator_max_tokens=%s",
74 size,
75 len(documents),
76 max_document_chars,
77 generator_max_tokens,
78 )
61 79
62 llm = ChatOpenAI( 80 llm = ChatOpenAI(
63 model=str(require_config(config, "ragas.generator_model")), 81 model=str(require_config(config, "ragas.generator_model")),
64 api_key=_required_ragas_value(ragas_config, "llm_api_key"), 82 api_key=_required_ragas_value(ragas_config, "llm_api_key"),
65 base_url=_required_ragas_value(ragas_config, "llm_base_url"), 83 base_url=_required_ragas_value(ragas_config, "llm_base_url"),
66 temperature=float(ragas_config.get("temperature", 0)), 84 temperature=float(ragas_config.get("temperature", 0)),
67 max_tokens=int(ragas_config.get("max_tokens", 4096)), 85 max_tokens=generator_max_tokens,
68 timeout=int(ragas_config.get("timeout_seconds", 600)), 86 timeout=int(ragas_config.get("timeout_seconds", 600)),
69 ) 87 )
70 embeddings = OpenAIEmbeddings( 88 embeddings = OpenAIEmbeddings(
...@@ -78,21 +96,29 @@ def generate_ragas_testset( ...@@ -78,21 +96,29 @@ def generate_ragas_testset(
78 ragas_llm, ragas_embeddings = _wrap_langchain_models(llm, embeddings) 96 ragas_llm, ragas_embeddings = _wrap_langchain_models(llm, embeddings)
79 generator = TestsetGenerator(llm=ragas_llm, embedding_model=ragas_embeddings) 97 generator = TestsetGenerator(llm=ragas_llm, embedding_model=ragas_embeddings)
80 result = generator.generate_with_langchain_docs( 98 result = generator.generate_with_langchain_docs(
81 documents[: max(size, 1)], 99 documents,
82 testset_size=size, 100 testset_size=size,
83 run_config=RunConfig( 101 run_config=RunConfig(
84 timeout=int(ragas_config.get("timeout_seconds", 600)), 102 timeout=int(ragas_config.get("timeout_seconds", 600)),
85 max_workers=int(ragas_config.get("max_workers", 1)), 103 max_workers=int(ragas_config.get("max_workers", 1)),
86 ), 104 ),
105 batch_size=1,
87 raise_exceptions=False, 106 raise_exceptions=False,
88 ) 107 )
89 108
90 ragas_rows = result.to_list() 109 ragas_rows = result.to_list()
91 rows = _normalize_ragas_rows(ragas_rows, source_rows) 110 rows = _normalize_ragas_rows(ragas_rows, selected_source_rows)
92 write_jsonl(output_path, rows) 111 write_jsonl(output_path, rows)
93 return rows 112 return rows
94 113
95 114
115 def _truncate_for_generation(content: str, max_chars: int) -> str:
116 text = " ".join((content or "").split())
117 if max_chars <= 0 or len(text) <= max_chars:
118 return text
119 return text[:max_chars].rstrip()
120
121
96 def _normalize_ragas_rows( 122 def _normalize_ragas_rows(
97 ragas_rows: list[dict[str, Any]], 123 ragas_rows: list[dict[str, Any]],
98 source_rows: list[dict[str, Any]], 124 source_rows: list[dict[str, Any]],
......