Commit 3c79a5fdc197c183caf744426fd290018310a85f by 沈秋雨

Bypass Ragas scenario generation in direct mode

1 parent abad6fce
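
In direct mode, the generator no longer builds a KnowledgeGraph and hands it to TestsetGenerator.generate with a SingleHopSpecificQuerySynthesizer query distribution. Instead it builds one SingleHopScenario per source document, awaits SingleHopQuerySynthesizer.generate_sample directly, and writes the resulting TestsetRecord rows to JSONL. A minimal sketch of that pattern in isolation (the ragas_llm wrapper, persona, term, and document values below are illustrative placeholders, not taken from the commit):

import asyncio

from langchain_core.documents import Document
from ragas.testset.graph import Node, NodeType
from ragas.testset.persona import Persona
from ragas.testset.synthesizers.base import QueryLength, QueryStyle
from ragas.testset.synthesizers.single_hop.base import (
    SingleHopQuerySynthesizer,
    SingleHopScenario,
)


async def generate_one(ragas_llm, document: Document):
    # Wrap the chunk as a single CHUNK node; no knowledge graph is built.
    node = Node(
        type=NodeType.CHUNK,
        properties={
            "page_content": document.page_content,
            "document_metadata": document.metadata,
        },
    )
    # Hand the synthesizer a fully specified scenario instead of letting
    # TestsetGenerator sample one from a knowledge graph.
    scenario = SingleHopScenario(
        nodes=[node],
        term="licensing scope",  # placeholder; the commit derives terms via _generation_terms()
        persona=Persona(name="contract reviewer", role_description="Reviews contract terms."),
        style=QueryStyle.PERFECT_GRAMMAR,
        length=QueryLength.MEDIUM,
    )
    synthesizer = SingleHopQuerySynthesizer(llm=ragas_llm)
    # The returned sample exposes user_input and reference, which the commit
    # maps into TestsetRecord rows.
    return await synthesizer.generate_sample(scenario)

In the actual change, personas, query styles, and query lengths rotate round-robin across documents, failed generations are logged and skipped, and rows missing a question, reference, or context are filtered out before being written.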
@@ -1,17 +1,22 @@
 from __future__ import annotations
 
+import asyncio
+import inspect
 import json
 import logging
-import inspect
 from typing import Any
 
 from langchain_core.documents import Document
 from langchain_openai import ChatOpenAI, OpenAIEmbeddings
 from ragas.run_config import RunConfig
-from ragas.testset.graph import KnowledgeGraph, Node, NodeType
-from ragas.testset.persona import Persona
 from ragas.testset import TestsetGenerator
-from ragas.testset.synthesizers.single_hop.specific import SingleHopSpecificQuerySynthesizer
+from ragas.testset.graph import Node, NodeType
+from ragas.testset.persona import Persona
+from ragas.testset.synthesizers.base import QueryLength, QueryStyle
+from ragas.testset.synthesizers.single_hop.base import (
+    SingleHopQuerySynthesizer,
+    SingleHopScenario,
+)
 
 from weknora_eval.config import require_config
 from weknora_eval.loaders import read_jsonl, write_jsonl
@@ -96,7 +101,11 @@ def generate_ragas_testset(
         max_workers=int(ragas_config.get("max_workers", 1)),
     )
     if ragas_mode == "direct":
-        result = _generate_ragas_direct(llm, documents, size, run_config)
+        rows = _generate_ragas_direct_rows(
+            llm, documents, selected_source_rows, size, run_config
+        )
+        write_jsonl(output_path, rows)
+        return rows
     elif ragas_mode == "prechunked":
         result = _generate_ragas_prechunked(
             config, ragas_config, llm, documents, size, run_config
@@ -114,57 +123,100 @@
     return rows
 
 
-def _generate_ragas_direct(
+def _generate_ragas_direct_rows(
     llm: ChatOpenAI,
     documents: list[Document],
+    source_rows: list[dict[str, Any]],
     size: int,
     run_config: RunConfig,
-) -> Any:
+) -> list[dict[str, Any]]:
     ragas_llm = _wrap_langchain_llm(llm)
-    kg = KnowledgeGraph(
-        nodes=[
-            Node(
-                type=NodeType.CHUNK,
-                properties={
-                    "page_content": document.page_content,
-                    "document_metadata": document.metadata,
-                    "entities": _generation_terms(document),
-                    "themes": _generation_terms(document),
-                },
-            )
-            for document in documents
-            if document.page_content.strip()
-        ]
-    )
-    generator = TestsetGenerator(
-        llm=ragas_llm,
-        embedding_model=None,
-        knowledge_graph=kg,
-        persona_list=[
-            Persona(
-                name="合同审核人员",
-                role_description="关注合同条款、权利归属、授权范围和履约义务。",
-            ),
-            Persona(
-                name="业务运营人员",
-                role_description="关注文档中可用于业务执行和信息核验的事实。",
-            ),
-            Persona(
-                name="法务合规人员",
-                role_description="关注协议、版权、授权、责任和风险表述。",
-            ),
-        ],
-    )
-    generate_kwargs: dict[str, Any] = {
-        "testset_size": size,
-        "query_distribution": [(SingleHopSpecificQuerySynthesizer(llm=ragas_llm), 1.0)],
-        "num_personas": 3,
-        "run_config": run_config,
-        "raise_exceptions": False,
-    }
-    if "batch_size" in inspect.signature(generator.generate).parameters:
-        generate_kwargs["batch_size"] = 1
-    return generator.generate(**generate_kwargs)
+    if hasattr(ragas_llm, "set_run_config"):
+        ragas_llm.set_run_config(run_config)
+
+    personas = [
+        Persona(
+            name="合同审核人员",
+            role_description="关注合同条款、权利归属、授权范围和履约义务。",
+        ),
+        Persona(
+            name="业务运营人员",
+            role_description="关注文档中可用于业务执行和信息核验的事实。",
+        ),
+        Persona(
+            name="法务合规人员",
+            role_description="关注协议、版权、授权、责任和风险表述。",
+        ),
+    ]
+    synthesizer = SingleHopQuerySynthesizer(llm=ragas_llm)
+    rows = asyncio.run(
+        _generate_direct_samples(synthesizer, documents, source_rows, personas, size)
+    )
+    logger.info("Generated %s Ragas direct QA samples", len(rows))
+    return rows
+
+
+async def _generate_direct_samples(
+    synthesizer: SingleHopQuerySynthesizer,
+    documents: list[Document],
+    source_rows: list[dict[str, Any]],
+    personas: list[Persona],
+    size: int,
+) -> list[dict[str, Any]]:
+    rows: list[dict[str, Any]] = []
+    styles = [QueryStyle.PERFECT_GRAMMAR, QueryStyle.WEB_SEARCH_LIKE]
+    lengths = [QueryLength.MEDIUM, QueryLength.SHORT]
+    for index, (document, source) in enumerate(zip(documents, source_rows), start=1):
+        if len(rows) >= size:
+            break
+        term = _generation_terms(document)[0]
+        node = Node(
+            type=NodeType.CHUNK,
+            properties={
+                "page_content": document.page_content,
+                "document_metadata": document.metadata,
+            },
+        )
+        scenario = SingleHopScenario(
+            nodes=[node],
+            term=term,
+            persona=personas[(index - 1) % len(personas)],
+            style=styles[(index - 1) % len(styles)],
+            length=lengths[(index - 1) % len(lengths)],
+        )
+        try:
+            sample = await synthesizer.generate_sample(scenario)
+        except Exception as exc:  # noqa: BLE001
+            logger.warning(
+                "Ragas direct QA generation failed for source_file=%s doc_id=%s: %s",
+                source.get("source_file"),
+                source.get("doc_id"),
+                exc,
+            )
+            continue
+
+        chunk_id = (source.get("metadata") or {}).get("chunk_id") or source.get(
+            "doc_id"
+        )
+        rows.append(
+            TestsetRecord(
+                sample_id=f"qa-{len(rows) + 1:04d}",
+                user_input=str(sample.user_input or "").strip(),
+                reference=str(sample.reference or "").strip(),
+                reference_contexts=[document.page_content],
+                source_file=source.get("source_file"),
+                gold_chunk_ids=[str(chunk_id)] if chunk_id else [],
+                question_type="ragas_single_hop_direct",
+                review_status="pending",
+            ).to_dict()
+        )
+    return [
+        row
+        for row in rows
+        if row.get("user_input")
+        and row.get("reference")
+        and row.get("reference_contexts")
+    ]
 
 
 def _generate_ragas_prechunked(
......
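
Each element returned by _generate_direct_samples (and written by write_jsonl above) is a TestsetRecord serialized via to_dict(). A hypothetical output row, assuming to_dict() emits exactly the constructor fields shown in the hunk (the question, answer, file path, and chunk id are invented for illustration):

row = {
    "sample_id": "qa-0001",                  # f"qa-{len(rows) + 1:04d}"
    "user_input": "What usage rights does the agreement grant?",        # sample.user_input, stripped
    "reference": "A non-exclusive right to use the licensed material.",  # sample.reference, stripped
    "reference_contexts": ["<page_content of the source chunk>"],
    "source_file": "docs/contract_001.md",   # from the matching source row
    "gold_chunk_ids": ["chunk-0007"],        # metadata chunk_id, falling back to doc_id
    "question_type": "ragas_single_hop_direct",
    "review_status": "pending",
}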