Commit 3c79a5fd 3c79a5fdc197c183caf744426fd290018310a85f by 沈秋雨

Bypass Ragas scenario generation in direct mode

1 parent abad6fce
from __future__ import annotations

import asyncio
import inspect
import json
import logging
from typing import Any

from langchain_core.documents import Document
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from ragas.run_config import RunConfig
from ragas.testset import TestsetGenerator
from ragas.testset.graph import KnowledgeGraph, Node, NodeType
from ragas.testset.persona import Persona
from ragas.testset.synthesizers.base import QueryLength, QueryStyle
from ragas.testset.synthesizers.single_hop.base import (
    SingleHopQuerySynthesizer,
    SingleHopScenario,
)
from ragas.testset.synthesizers.single_hop.specific import SingleHopSpecificQuerySynthesizer

from weknora_eval.config import require_config
from weknora_eval.loaders import read_jsonl, write_jsonl
......@@ -96,7 +101,11 @@ def generate_ragas_testset(
max_workers=int(ragas_config.get("max_workers", 1)),
)
if ragas_mode == "direct":
result = _generate_ragas_direct(llm, documents, size, run_config)
rows = _generate_ragas_direct_rows(
llm, documents, selected_source_rows, size, run_config
)
write_jsonl(output_path, rows)
return rows
elif ragas_mode == "prechunked":
result = _generate_ragas_prechunked(
config, ragas_config, llm, documents, size, run_config
......@@ -114,57 +123,100 @@ def generate_ragas_testset(
return rows
def _generate_ragas_direct_rows(
    llm: ChatOpenAI,
    documents: list[Document],
    source_rows: list[dict[str, Any]],
    size: int,
    run_config: RunConfig,
) -> list[dict[str, Any]]:
    """Generate testset rows directly with a single-hop synthesizer.

    Bypasses Ragas' knowledge-graph / TestsetGenerator pipeline: each
    document is turned into one single-hop scenario and synthesized into a
    QA row by ``_generate_direct_samples``.

    Args:
        llm: Chat model, wrapped for Ragas via ``_wrap_langchain_llm``.
        documents: Candidate chunks, paired one-to-one with ``source_rows``.
        source_rows: Provenance metadata for each document.
        size: Maximum number of QA rows to produce.
        run_config: Ragas run configuration, applied to the wrapped LLM
            when the wrapper supports ``set_run_config``.

    Returns:
        Validated testset rows as plain dicts.
    """
    ragas_llm = _wrap_langchain_llm(llm)
    # Not every wrapper version exposes set_run_config; apply when available.
    if hasattr(ragas_llm, "set_run_config"):
        ragas_llm.set_run_config(run_config)
    personas = [
        Persona(
            name="合同审核人员",
            role_description="关注合同条款、权利归属、授权范围和履约义务。",
        ),
        Persona(
            name="业务运营人员",
            role_description="关注文档中可用于业务执行和信息核验的事实。",
        ),
        Persona(
            name="法务合规人员",
            role_description="关注协议、版权、授权、责任和风险表述。",
        ),
    ]
    synthesizer = SingleHopQuerySynthesizer(llm=ragas_llm)
    rows = asyncio.run(
        _generate_direct_samples(synthesizer, documents, source_rows, personas, size)
    )
    logger.info("Generated %s Ragas direct QA samples", len(rows))
    return rows
async def _generate_direct_samples(
    synthesizer: SingleHopQuerySynthesizer,
    documents: list[Document],
    source_rows: list[dict[str, Any]],
    personas: list[Persona],
    size: int,
) -> list[dict[str, Any]]:
    """Synthesize up to ``size`` validated single-hop QA rows.

    Documents and source rows are consumed pairwise; persona, query style
    and query length rotate round-robin across documents. Documents with
    empty content or no extractable generation term are skipped, as are
    samples whose generation fails or whose question/answer is blank, so
    the returned list contains only complete rows (possibly fewer than
    ``size`` when too few documents yield usable samples).

    Args:
        synthesizer: Single-hop synthesizer used to generate each sample.
        documents: Candidate chunks to seed scenarios from.
        source_rows: Provenance dicts aligned with ``documents``.
        personas: Personas rotated across generated scenarios.
        size: Maximum number of rows to return.

    Returns:
        Testset rows as plain dicts with contiguous ``sample_id`` values.
    """
    rows: list[dict[str, Any]] = []
    styles = [QueryStyle.PERFECT_GRAMMAR, QueryStyle.WEB_SEARCH_LIKE]
    lengths = [QueryLength.MEDIUM, QueryLength.SHORT]
    for index, (document, source) in enumerate(zip(documents, source_rows)):
        if len(rows) >= size:
            break
        # Skip documents that cannot seed a scenario: blank content, or no
        # generation terms (indexing [0] on an empty list would raise).
        if not document.page_content.strip():
            continue
        terms = _generation_terms(document)
        if not terms:
            continue
        node = Node(
            type=NodeType.CHUNK,
            properties={
                "page_content": document.page_content,
                "document_metadata": document.metadata,
            },
        )
        scenario = SingleHopScenario(
            nodes=[node],
            term=terms[0],
            persona=personas[index % len(personas)],
            style=styles[index % len(styles)],
            length=lengths[index % len(lengths)],
        )
        try:
            sample = await synthesizer.generate_sample(scenario)
        except Exception as exc:  # noqa: BLE001
            logger.warning(
                "Ragas direct QA generation failed for source_file=%s doc_id=%s: %s",
                source.get("source_file"),
                source.get("doc_id"),
                exc,
            )
            continue
        user_input = str(sample.user_input or "").strip()
        reference = str(sample.reference or "").strip()
        # Validate before appending so invalid samples neither consume the
        # size quota nor leave gaps in the sample_id sequence.
        if not user_input or not reference:
            continue
        chunk_id = (source.get("metadata") or {}).get("chunk_id") or source.get(
            "doc_id"
        )
        rows.append(
            TestsetRecord(
                sample_id=f"qa-{len(rows) + 1:04d}",
                user_input=user_input,
                reference=reference,
                reference_contexts=[document.page_content],
                source_file=source.get("source_file"),
                gold_chunk_ids=[str(chunk_id)] if chunk_id else [],
                question_type="ragas_single_hop_direct",
                review_status="pending",
            ).to_dict()
        )
    return rows
def _generate_ragas_prechunked(
......