Commit 67dc9bec 67dc9becc907b5a267c3bd6d5f9884abef5e60bd by 沈秋雨

对齐报告生成步骤

1 parent f73e2a81
......@@ -115,7 +115,7 @@ def generate_summary_report(
lines.append(
f"| {row.get('sample_id', '')} | {_cell(sample.get('user_input'))} | "
f"{_cell(sample.get('response'))} | {_cell(sample.get('reference'))} | "
f"{_score(row.get('faithfulness'))} | {_score(row.get('factual_correctness'))} |"
f"{_score(_metric_value(row, 'faithfulness'))} | {_score(_metric_value(row, 'factual_correctness'))} |"
)
empty_retrievals = sum(1 for row in ragas_rows if not row.get("retrieved_contexts"))
......@@ -149,9 +149,27 @@ def generate_summary_report(
def _worst_rows(scores: pd.DataFrame, column: str, *, limit: int = 10) -> list[dict[str, Any]]:
    """Return up to *limit* score records with the lowest values for a metric.

    The metric label is resolved through ``_metric_column`` so that both an
    exact column name and a parameterised ``"<name>(...)"`` label are found.

    Args:
        scores: Score table, one row per evaluated sample.
        column: Base metric name to rank by.
        limit: Maximum number of records to return (keyword-only).

    Returns:
        Row dictionaries ordered from lowest to highest metric value, or an
        empty list when the table has no rows or no matching column.
    """
    # Nothing to rank: bail out before touching the column index at all.
    if scores.empty:
        return []

    metric_column = _metric_column(scores, column)
    if metric_column is None:
        return []

    # Ascending sort puts the weakest samples first; pandas places NaN last
    # by default, so unscored rows are not reported as the worst ones.
    ranked = scores.sort_values(metric_column, ascending=True)
    return ranked.head(limit).to_dict(orient="records")
def _metric_column(scores: pd.DataFrame, name: str) -> str | None:
if name in scores.columns:
return name
prefix = f"{name}("
return next((column for column in scores.columns if column.startswith(prefix)), None)
def _metric_value(row: dict[str, Any], name: str) -> Any:
    """Look up the value of metric *name* in a single score record.

    The exact key is preferred. When it is absent, the value of the first key
    (in the mapping's iteration order) whose string form starts with
    ``"<name>("`` is used instead.

    Args:
        row: One score record keyed by column label.
        name: Base metric name.

    Returns:
        The stored value, or ``None`` when neither form of the key exists.
    """
    if name not in row:
        marker = f"{name}("
        # Keys go through str() so non-string labels cannot break the match.
        candidates = (
            stored for label, stored in row.items() if str(label).startswith(marker)
        )
        return next(candidates, None)
    return row[name]
def _sample_by_id(rows: list[dict[str, Any]], sample_id: Any) -> dict[str, Any]:
......
......@@ -140,20 +140,7 @@ def _generate_ragas_direct_rows(
if hasattr(ragas_llm, "set_run_config"):
ragas_llm.set_run_config(run_config)
personas = [
Persona(
name="合同审核人员",
role_description="关注合同条款、权利归属、授权范围和履约义务。",
),
Persona(
name="业务运营人员",
role_description="关注文档中可用于业务执行和信息核验的事实。",
),
Persona(
name="法务合规人员",
role_description="关注协议、版权、授权、责任和风险表述。",
),
]
personas = _default_personas()
synthesizer = SingleHopSpecificQuerySynthesizer(llm=ragas_llm)
rows = asyncio.run(
_generate_direct_samples(synthesizer, documents, source_rows, personas, size)
......@@ -235,7 +222,11 @@ def _generate_ragas_prechunked(
) -> Any:
embeddings = _build_embeddings(config, ragas_config)
ragas_llm, ragas_embeddings = _wrap_langchain_models(llm, embeddings)
generator = TestsetGenerator(llm=ragas_llm, embedding_model=ragas_embeddings)
generator = TestsetGenerator(
llm=ragas_llm,
embedding_model=ragas_embeddings,
persona_list=_default_personas(),
)
return generator.generate_with_chunks(
documents,
testset_size=size,
......@@ -256,7 +247,11 @@ def _generate_ragas_langchain_docs(
) -> Any:
embeddings = _build_embeddings(config, ragas_config)
ragas_llm, ragas_embeddings = _wrap_langchain_models(llm, embeddings)
generator = TestsetGenerator(llm=ragas_llm, embedding_model=ragas_embeddings)
generator = TestsetGenerator(
llm=ragas_llm,
embedding_model=ragas_embeddings,
persona_list=_default_personas(),
)
generate_kwargs: dict[str, Any] = {
"testset_size": size,
"query_distribution": [(SingleHopSpecificQuerySynthesizer(llm=ragas_llm), 1.0)],
......@@ -279,6 +274,23 @@ def _is_chunk_node(node: Any) -> bool:
return getattr(getattr(node, "type", None), "name", "") == "CHUNK"
def _default_personas() -> list[Persona]:
    """Build the fixed set of three personas used for test-set generation.

    Every call returns a new list containing newly constructed ``Persona``
    objects, in the same order as the table below.
    """
    # (name, role_description) pairs; the text is passed to Persona verbatim.
    profiles = (
        # Contract reviewer.
        ("合同审核人员", "关注合同条款、权利归属、授权范围和履约义务。"),
        # Business operations staff.
        ("业务运营人员", "关注文档中可用于业务执行和信息核验的事实。"),
        # Legal / compliance staff.
        ("法务合规人员", "关注协议、版权、授权、责任和风险表述。"),
    )
    return [
        Persona(name=title, role_description=focus)
        for title, focus in profiles
    ]
def _build_embeddings(
config: dict[str, Any], ragas_config: dict[str, Any]
) -> OpenAIEmbeddings:
......