Bypass Ragas scenario generation in direct mode
Showing 1 changed file with 102 additions and 50 deletions
| 1 | from __future__ import annotations | 1 | from __future__ import annotations |
| 2 | 2 | ||
| 3 | import asyncio | ||
| 4 | import inspect | ||
| 3 | import json | 5 | import json |
| 4 | import logging | 6 | import logging |
| 5 | import inspect | ||
| 6 | from typing import Any | 7 | from typing import Any |
| 7 | 8 | ||
| 8 | from langchain_core.documents import Document | 9 | from langchain_core.documents import Document |
| 9 | from langchain_openai import ChatOpenAI, OpenAIEmbeddings | 10 | from langchain_openai import ChatOpenAI, OpenAIEmbeddings |
| 10 | from ragas.run_config import RunConfig | 11 | from ragas.run_config import RunConfig |
| 11 | from ragas.testset.graph import KnowledgeGraph, Node, NodeType | ||
| 12 | from ragas.testset.persona import Persona | ||
| 13 | from ragas.testset import TestsetGenerator | 12 | from ragas.testset import TestsetGenerator |
| 14 | from ragas.testset.synthesizers.single_hop.specific import SingleHopSpecificQuerySynthesizer | 13 | from ragas.testset.graph import Node, NodeType |
| 14 | from ragas.testset.persona import Persona | ||
| 15 | from ragas.testset.synthesizers.base import QueryLength, QueryStyle | ||
| 16 | from ragas.testset.synthesizers.single_hop.base import ( | ||
| 17 | SingleHopQuerySynthesizer, | ||
| 18 | SingleHopScenario, | ||
| 19 | ) | ||
| 15 | 20 | ||
| 16 | from weknora_eval.config import require_config | 21 | from weknora_eval.config import require_config |
| 17 | from weknora_eval.loaders import read_jsonl, write_jsonl | 22 | from weknora_eval.loaders import read_jsonl, write_jsonl |
| ... | @@ -96,7 +101,11 @@ def generate_ragas_testset( | ... | @@ -96,7 +101,11 @@ def generate_ragas_testset( |
| 96 | max_workers=int(ragas_config.get("max_workers", 1)), | 101 | max_workers=int(ragas_config.get("max_workers", 1)), |
| 97 | ) | 102 | ) |
| 98 | if ragas_mode == "direct": | 103 | if ragas_mode == "direct": |
| 99 | result = _generate_ragas_direct(llm, documents, size, run_config) | 104 | rows = _generate_ragas_direct_rows( |
| 105 | llm, documents, selected_source_rows, size, run_config | ||
| 106 | ) | ||
| 107 | write_jsonl(output_path, rows) | ||
| 108 | return rows | ||
| 100 | elif ragas_mode == "prechunked": | 109 | elif ragas_mode == "prechunked": |
| 101 | result = _generate_ragas_prechunked( | 110 | result = _generate_ragas_prechunked( |
| 102 | config, ragas_config, llm, documents, size, run_config | 111 | config, ragas_config, llm, documents, size, run_config |
| ... | @@ -114,57 +123,100 @@ def generate_ragas_testset( | ... | @@ -114,57 +123,100 @@ def generate_ragas_testset( |
| 114 | return rows | 123 | return rows |
| 115 | 124 | ||
| 116 | 125 | ||
def _generate_ragas_direct_rows(
    llm: ChatOpenAI,
    documents: list[Document],
    source_rows: list[dict[str, Any]],
    size: int,
    run_config: RunConfig,
) -> list[dict[str, Any]]:
    """Generate QA rows directly with a single-hop synthesizer, bypassing
    Ragas' built-in scenario generation pipeline.

    Args:
        llm: LangChain chat model used for question synthesis.
        documents: Chunk documents to synthesize questions from.
        source_rows: Source metadata rows aligned pairwise with ``documents``.
        size: Maximum number of QA rows to produce.
        run_config: Ragas run configuration (timeouts, worker limits).

    Returns:
        Testset record dicts ready to be written out as JSONL.
    """
    wrapped_llm = _wrap_langchain_llm(llm)
    # Not every ragas LLM wrapper exposes set_run_config; apply it best-effort.
    if hasattr(wrapped_llm, "set_run_config"):
        wrapped_llm.set_run_config(run_config)

    persona_pool = [
        Persona(
            name="合同审核人员",
            role_description="关注合同条款、权利归属、授权范围和履约义务。",
        ),
        Persona(
            name="业务运营人员",
            role_description="关注文档中可用于业务执行和信息核验的事实。",
        ),
        Persona(
            name="法务合规人员",
            role_description="关注协议、版权、授权、责任和风险表述。",
        ),
    ]

    synthesizer = SingleHopQuerySynthesizer(llm=wrapped_llm)
    generated = asyncio.run(
        _generate_direct_samples(
            synthesizer, documents, source_rows, persona_pool, size
        )
    )
    logger.info("Generated %s Ragas direct QA samples", len(generated))
    return generated
| 160 | "query_distribution": [(SingleHopSpecificQuerySynthesizer(llm=ragas_llm), 1.0)], | 157 | |
| 161 | "num_personas": 3, | 158 | |
async def _generate_direct_samples(
    synthesizer: SingleHopQuerySynthesizer,
    documents: list[Document],
    source_rows: list[dict[str, Any]],
    personas: list[Persona],
    size: int,
) -> list[dict[str, Any]]:
    """Synthesize up to ``size`` single-hop QA samples, one per document.

    Each document/source pair is turned into a hand-built ``SingleHopScenario``
    (rotating persona, style, and length) and fed to the synthesizer directly.
    Per-document failures are logged and skipped rather than aborting the run.

    Args:
        synthesizer: Ragas single-hop query synthesizer.
        documents: Chunk documents, iterated in lockstep with ``source_rows``.
        source_rows: Metadata rows providing ``source_file``/``doc_id``/chunk ids.
        personas: Personas cycled across generated scenarios.
        size: Maximum number of valid rows to return.

    Returns:
        Validated testset record dicts with contiguous ``qa-NNNN`` sample ids.
    """
    rows: list[dict[str, Any]] = []
    styles = [QueryStyle.PERFECT_GRAMMAR, QueryStyle.WEB_SEARCH_LIKE]
    lengths = [QueryLength.MEDIUM, QueryLength.SHORT]
    for index, (document, source) in enumerate(zip(documents, source_rows), start=1):
        if len(rows) >= size:
            break
        terms = _generation_terms(document)
        if not terms:
            # The synthesizer needs a term to anchor the question; indexing an
            # empty list here would raise IndexError outside the try below and
            # abort the whole run, so skip this document instead.
            logger.warning(
                "Skipping document without generation terms: source_file=%s doc_id=%s",
                source.get("source_file"),
                source.get("doc_id"),
            )
            continue
        node = Node(
            type=NodeType.CHUNK,
            properties={
                "page_content": document.page_content,
                "document_metadata": document.metadata,
            },
        )
        scenario = SingleHopScenario(
            nodes=[node],
            term=terms[0],
            persona=personas[(index - 1) % len(personas)],
            style=styles[(index - 1) % len(styles)],
            length=lengths[(index - 1) % len(lengths)],
        )
        try:
            sample = await synthesizer.generate_sample(scenario)
        except Exception as exc:  # noqa: BLE001
            logger.warning(
                "Ragas direct QA generation failed for source_file=%s doc_id=%s: %s",
                source.get("source_file"),
                source.get("doc_id"),
                exc,
            )
            continue

        user_input = str(sample.user_input or "").strip()
        reference = str(sample.reference or "").strip()
        # Validate BEFORE appending: the previous post-hoc filter let rejected
        # samples consume slots toward ``size`` and left gaps in the qa-NNNN
        # sample_id sequence.
        if not user_input or not reference:
            continue

        chunk_id = (source.get("metadata") or {}).get("chunk_id") or source.get(
            "doc_id"
        )
        rows.append(
            TestsetRecord(
                sample_id=f"qa-{len(rows) + 1:04d}",
                user_input=user_input,
                reference=reference,
                reference_contexts=[document.page_content],
                source_file=source.get("source_file"),
                gold_chunk_ids=[str(chunk_id)] if chunk_id else [],
                question_type="ragas_single_hop_direct",
                review_status="pending",
            ).to_dict()
        )
    return rows
| 168 | 220 | ||
| 169 | 221 | ||
| 170 | def _generate_ragas_prechunked( | 222 | def _generate_ragas_prechunked( | ... | ... |
-
Please register or sign in to post a comment