对齐报告生成步骤
Showing 2 changed files with 49 additions and 19 deletions
| ... | @@ -115,7 +115,7 @@ def generate_summary_report( | ... | @@ -115,7 +115,7 @@ def generate_summary_report( |
| 115 | lines.append( | 115 | lines.append( |
| 116 | f"| {row.get('sample_id', '')} | {_cell(sample.get('user_input'))} | " | 116 | f"| {row.get('sample_id', '')} | {_cell(sample.get('user_input'))} | " |
| 117 | f"{_cell(sample.get('response'))} | {_cell(sample.get('reference'))} | " | 117 | f"{_cell(sample.get('response'))} | {_cell(sample.get('reference'))} | " |
| 118 | f"{_score(row.get('faithfulness'))} | {_score(row.get('factual_correctness'))} |" | 118 | f"{_score(_metric_value(row, 'faithfulness'))} | {_score(_metric_value(row, 'factual_correctness'))} |" |
| 119 | ) | 119 | ) |
| 120 | 120 | ||
| 121 | empty_retrievals = sum(1 for row in ragas_rows if not row.get("retrieved_contexts")) | 121 | empty_retrievals = sum(1 for row in ragas_rows if not row.get("retrieved_contexts")) |
| ... | @@ -149,9 +149,27 @@ def generate_summary_report( | ... | @@ -149,9 +149,27 @@ def generate_summary_report( |
| 149 | 149 | ||
| 150 | 150 | ||
| 151 | def _worst_rows(scores: pd.DataFrame, column: str, *, limit: int = 10) -> list[dict[str, Any]]: | 151 | def _worst_rows(scores: pd.DataFrame, column: str, *, limit: int = 10) -> list[dict[str, Any]]: |
| 152 | if scores.empty or column not in scores.columns: | 152 | metric_column = _metric_column(scores, column) |
| 153 | if scores.empty or metric_column is None: | ||
| 153 | return [] | 154 | return [] |
| 154 | return scores.sort_values(column, ascending=True).head(limit).to_dict(orient="records") | 155 | return scores.sort_values(metric_column, ascending=True).head(limit).to_dict(orient="records") |
| 156 | |||
| 157 | |||
| 158 | def _metric_column(scores: pd.DataFrame, name: str) -> str | None: | ||
| 159 | if name in scores.columns: | ||
| 160 | return name | ||
| 161 | prefix = f"{name}(" | ||
| 162 | return next((column for column in scores.columns if column.startswith(prefix)), None) | ||
| 163 | |||
| 164 | |||
| 165 | def _metric_value(row: dict[str, Any], name: str) -> Any: | ||
| 166 | if name in row: | ||
| 167 | return row[name] | ||
| 168 | prefix = f"{name}(" | ||
| 169 | for key, value in row.items(): | ||
| 170 | if str(key).startswith(prefix): | ||
| 171 | return value | ||
| 172 | return None | ||
| 155 | 173 | ||
| 156 | 174 | ||
| 157 | def _sample_by_id(rows: list[dict[str, Any]], sample_id: Any) -> dict[str, Any]: | 175 | def _sample_by_id(rows: list[dict[str, Any]], sample_id: Any) -> dict[str, Any]: | ... | ... |
| ... | @@ -140,20 +140,7 @@ def _generate_ragas_direct_rows( | ... | @@ -140,20 +140,7 @@ def _generate_ragas_direct_rows( |
| 140 | if hasattr(ragas_llm, "set_run_config"): | 140 | if hasattr(ragas_llm, "set_run_config"): |
| 141 | ragas_llm.set_run_config(run_config) | 141 | ragas_llm.set_run_config(run_config) |
| 142 | 142 | ||
| 143 | personas = [ | 143 | personas = _default_personas() |
| 144 | Persona( | ||
| 145 | name="合同审核人员", | ||
| 146 | role_description="关注合同条款、权利归属、授权范围和履约义务。", | ||
| 147 | ), | ||
| 148 | Persona( | ||
| 149 | name="业务运营人员", | ||
| 150 | role_description="关注文档中可用于业务执行和信息核验的事实。", | ||
| 151 | ), | ||
| 152 | Persona( | ||
| 153 | name="法务合规人员", | ||
| 154 | role_description="关注协议、版权、授权、责任和风险表述。", | ||
| 155 | ), | ||
| 156 | ] | ||
| 157 | synthesizer = SingleHopSpecificQuerySynthesizer(llm=ragas_llm) | 144 | synthesizer = SingleHopSpecificQuerySynthesizer(llm=ragas_llm) |
| 158 | rows = asyncio.run( | 145 | rows = asyncio.run( |
| 159 | _generate_direct_samples(synthesizer, documents, source_rows, personas, size) | 146 | _generate_direct_samples(synthesizer, documents, source_rows, personas, size) |
| ... | @@ -235,7 +222,11 @@ def _generate_ragas_prechunked( | ... | @@ -235,7 +222,11 @@ def _generate_ragas_prechunked( |
| 235 | ) -> Any: | 222 | ) -> Any: |
| 236 | embeddings = _build_embeddings(config, ragas_config) | 223 | embeddings = _build_embeddings(config, ragas_config) |
| 237 | ragas_llm, ragas_embeddings = _wrap_langchain_models(llm, embeddings) | 224 | ragas_llm, ragas_embeddings = _wrap_langchain_models(llm, embeddings) |
| 238 | generator = TestsetGenerator(llm=ragas_llm, embedding_model=ragas_embeddings) | 225 | generator = TestsetGenerator( |
| 226 | llm=ragas_llm, | ||
| 227 | embedding_model=ragas_embeddings, | ||
| 228 | persona_list=_default_personas(), | ||
| 229 | ) | ||
| 239 | return generator.generate_with_chunks( | 230 | return generator.generate_with_chunks( |
| 240 | documents, | 231 | documents, |
| 241 | testset_size=size, | 232 | testset_size=size, |
| ... | @@ -256,7 +247,11 @@ def _generate_ragas_langchain_docs( | ... | @@ -256,7 +247,11 @@ def _generate_ragas_langchain_docs( |
| 256 | ) -> Any: | 247 | ) -> Any: |
| 257 | embeddings = _build_embeddings(config, ragas_config) | 248 | embeddings = _build_embeddings(config, ragas_config) |
| 258 | ragas_llm, ragas_embeddings = _wrap_langchain_models(llm, embeddings) | 249 | ragas_llm, ragas_embeddings = _wrap_langchain_models(llm, embeddings) |
| 259 | generator = TestsetGenerator(llm=ragas_llm, embedding_model=ragas_embeddings) | 250 | generator = TestsetGenerator( |
| 251 | llm=ragas_llm, | ||
| 252 | embedding_model=ragas_embeddings, | ||
| 253 | persona_list=_default_personas(), | ||
| 254 | ) | ||
| 260 | generate_kwargs: dict[str, Any] = { | 255 | generate_kwargs: dict[str, Any] = { |
| 261 | "testset_size": size, | 256 | "testset_size": size, |
| 262 | "query_distribution": [(SingleHopSpecificQuerySynthesizer(llm=ragas_llm), 1.0)], | 257 | "query_distribution": [(SingleHopSpecificQuerySynthesizer(llm=ragas_llm), 1.0)], |
| ... | @@ -279,6 +274,23 @@ def _is_chunk_node(node: Any) -> bool: | ... | @@ -279,6 +274,23 @@ def _is_chunk_node(node: Any) -> bool: |
| 279 | return getattr(getattr(node, "type", None), "name", "") == "CHUNK" | 274 | return getattr(getattr(node, "type", None), "name", "") == "CHUNK" |
| 280 | 275 | ||
| 281 | 276 | ||
| 277 | def _default_personas() -> list[Persona]: | ||
| 278 | return [ | ||
| 279 | Persona( | ||
| 280 | name="合同审核人员", | ||
| 281 | role_description="关注合同条款、权利归属、授权范围和履约义务。", | ||
| 282 | ), | ||
| 283 | Persona( | ||
| 284 | name="业务运营人员", | ||
| 285 | role_description="关注文档中可用于业务执行和信息核验的事实。", | ||
| 286 | ), | ||
| 287 | Persona( | ||
| 288 | name="法务合规人员", | ||
| 289 | role_description="关注协议、版权、授权、责任和风险表述。", | ||
| 290 | ), | ||
| 291 | ] | ||
| 292 | |||
| 293 | |||
| 282 | def _build_embeddings( | 294 | def _build_embeddings( |
| 283 | config: dict[str, Any], ragas_config: dict[str, Any] | 295 | config: dict[str, Any], ragas_config: dict[str, Any] |
| 284 | ) -> OpenAIEmbeddings: | 296 | ) -> OpenAIEmbeddings: | ... | ... |
-
Please register or sign in to post a comment