Commit 67dc9bec 67dc9becc907b5a267c3bd6d5f9884abef5e60bd by 沈秋雨

对齐报告生成步骤

1 parent f73e2a81
...@@ -115,7 +115,7 @@ def generate_summary_report( ...@@ -115,7 +115,7 @@ def generate_summary_report(
115 lines.append( 115 lines.append(
116 f"| {row.get('sample_id', '')} | {_cell(sample.get('user_input'))} | " 116 f"| {row.get('sample_id', '')} | {_cell(sample.get('user_input'))} | "
117 f"{_cell(sample.get('response'))} | {_cell(sample.get('reference'))} | " 117 f"{_cell(sample.get('response'))} | {_cell(sample.get('reference'))} | "
118 f"{_score(row.get('faithfulness'))} | {_score(row.get('factual_correctness'))} |" 118 f"{_score(_metric_value(row, 'faithfulness'))} | {_score(_metric_value(row, 'factual_correctness'))} |"
119 ) 119 )
120 120
121 empty_retrievals = sum(1 for row in ragas_rows if not row.get("retrieved_contexts")) 121 empty_retrievals = sum(1 for row in ragas_rows if not row.get("retrieved_contexts"))
...@@ -149,9 +149,27 @@ def generate_summary_report( ...@@ -149,9 +149,27 @@ def generate_summary_report(
149 149
150 150
151 def _worst_rows(scores: pd.DataFrame, column: str, *, limit: int = 10) -> list[dict[str, Any]]: 151 def _worst_rows(scores: pd.DataFrame, column: str, *, limit: int = 10) -> list[dict[str, Any]]:
152 if scores.empty or column not in scores.columns: 152 metric_column = _metric_column(scores, column)
153 if scores.empty or metric_column is None:
153 return [] 154 return []
154 return scores.sort_values(column, ascending=True).head(limit).to_dict(orient="records") 155 return scores.sort_values(metric_column, ascending=True).head(limit).to_dict(orient="records")
156
157
158 def _metric_column(scores: pd.DataFrame, name: str) -> str | None:
159 if name in scores.columns:
160 return name
161 prefix = f"{name}("
162 return next((column for column in scores.columns if column.startswith(prefix)), None)
163
164
165 def _metric_value(row: dict[str, Any], name: str) -> Any:
166 if name in row:
167 return row[name]
168 prefix = f"{name}("
169 for key, value in row.items():
170 if str(key).startswith(prefix):
171 return value
172 return None
155 173
156 174
157 def _sample_by_id(rows: list[dict[str, Any]], sample_id: Any) -> dict[str, Any]: 175 def _sample_by_id(rows: list[dict[str, Any]], sample_id: Any) -> dict[str, Any]:
......
...@@ -140,20 +140,7 @@ def _generate_ragas_direct_rows( ...@@ -140,20 +140,7 @@ def _generate_ragas_direct_rows(
140 if hasattr(ragas_llm, "set_run_config"): 140 if hasattr(ragas_llm, "set_run_config"):
141 ragas_llm.set_run_config(run_config) 141 ragas_llm.set_run_config(run_config)
142 142
143 personas = [ 143 personas = _default_personas()
144 Persona(
145 name="合同审核人员",
146 role_description="关注合同条款、权利归属、授权范围和履约义务。",
147 ),
148 Persona(
149 name="业务运营人员",
150 role_description="关注文档中可用于业务执行和信息核验的事实。",
151 ),
152 Persona(
153 name="法务合规人员",
154 role_description="关注协议、版权、授权、责任和风险表述。",
155 ),
156 ]
157 synthesizer = SingleHopSpecificQuerySynthesizer(llm=ragas_llm) 144 synthesizer = SingleHopSpecificQuerySynthesizer(llm=ragas_llm)
158 rows = asyncio.run( 145 rows = asyncio.run(
159 _generate_direct_samples(synthesizer, documents, source_rows, personas, size) 146 _generate_direct_samples(synthesizer, documents, source_rows, personas, size)
...@@ -235,7 +222,11 @@ def _generate_ragas_prechunked( ...@@ -235,7 +222,11 @@ def _generate_ragas_prechunked(
235 ) -> Any: 222 ) -> Any:
236 embeddings = _build_embeddings(config, ragas_config) 223 embeddings = _build_embeddings(config, ragas_config)
237 ragas_llm, ragas_embeddings = _wrap_langchain_models(llm, embeddings) 224 ragas_llm, ragas_embeddings = _wrap_langchain_models(llm, embeddings)
238 generator = TestsetGenerator(llm=ragas_llm, embedding_model=ragas_embeddings) 225 generator = TestsetGenerator(
226 llm=ragas_llm,
227 embedding_model=ragas_embeddings,
228 persona_list=_default_personas(),
229 )
239 return generator.generate_with_chunks( 230 return generator.generate_with_chunks(
240 documents, 231 documents,
241 testset_size=size, 232 testset_size=size,
...@@ -256,7 +247,11 @@ def _generate_ragas_langchain_docs( ...@@ -256,7 +247,11 @@ def _generate_ragas_langchain_docs(
256 ) -> Any: 247 ) -> Any:
257 embeddings = _build_embeddings(config, ragas_config) 248 embeddings = _build_embeddings(config, ragas_config)
258 ragas_llm, ragas_embeddings = _wrap_langchain_models(llm, embeddings) 249 ragas_llm, ragas_embeddings = _wrap_langchain_models(llm, embeddings)
259 generator = TestsetGenerator(llm=ragas_llm, embedding_model=ragas_embeddings) 250 generator = TestsetGenerator(
251 llm=ragas_llm,
252 embedding_model=ragas_embeddings,
253 persona_list=_default_personas(),
254 )
260 generate_kwargs: dict[str, Any] = { 255 generate_kwargs: dict[str, Any] = {
261 "testset_size": size, 256 "testset_size": size,
262 "query_distribution": [(SingleHopSpecificQuerySynthesizer(llm=ragas_llm), 1.0)], 257 "query_distribution": [(SingleHopSpecificQuerySynthesizer(llm=ragas_llm), 1.0)],
...@@ -279,6 +274,23 @@ def _is_chunk_node(node: Any) -> bool: ...@@ -279,6 +274,23 @@ def _is_chunk_node(node: Any) -> bool:
279 return getattr(getattr(node, "type", None), "name", "") == "CHUNK" 274 return getattr(getattr(node, "type", None), "name", "") == "CHUNK"
280 275
281 276
277 def _default_personas() -> list[Persona]:
278 return [
279 Persona(
280 name="合同审核人员",
281 role_description="关注合同条款、权利归属、授权范围和履约义务。",
282 ),
283 Persona(
284 name="业务运营人员",
285 role_description="关注文档中可用于业务执行和信息核验的事实。",
286 ),
287 Persona(
288 name="法务合规人员",
289 role_description="关注协议、版权、授权、责任和风险表述。",
290 ),
291 ]
292
293
282 def _build_embeddings( 294 def _build_embeddings(
283 config: dict[str, Any], ragas_config: dict[str, Any] 295 config: dict[str, Any], ragas_config: dict[str, Any]
284 ) -> OpenAIEmbeddings: 296 ) -> OpenAIEmbeddings:
......