Commit 463ef51b 463ef51b76d78ca7c99fd46eb1530d33ad9369a0 by 沈秋雨

Handle Ragas testset batch size compatibility

1 parent 7ce899a9
...@@ -2,6 +2,7 @@ from __future__ import annotations ...@@ -2,6 +2,7 @@ from __future__ import annotations
2 2
3 import json 3 import json
4 import logging 4 import logging
5 import inspect
5 from typing import Any 6 from typing import Any
6 7
7 from langchain_core.documents import Document 8 from langchain_core.documents import Document
...@@ -95,16 +96,17 @@ def generate_ragas_testset( ...@@ -95,16 +96,17 @@ def generate_ragas_testset(
95 ) 96 )
96 ragas_llm, ragas_embeddings = _wrap_langchain_models(llm, embeddings) 97 ragas_llm, ragas_embeddings = _wrap_langchain_models(llm, embeddings)
97 generator = TestsetGenerator(llm=ragas_llm, embedding_model=ragas_embeddings) 98 generator = TestsetGenerator(llm=ragas_llm, embedding_model=ragas_embeddings)
98 result = generator.generate_with_langchain_docs( 99 generate_kwargs: dict[str, Any] = {
99 documents, 100 "testset_size": size,
100 testset_size=size, 101 "run_config": RunConfig(
101 run_config=RunConfig(
102 timeout=int(ragas_config.get("timeout_seconds", 600)), 102 timeout=int(ragas_config.get("timeout_seconds", 600)),
103 max_workers=int(ragas_config.get("max_workers", 1)), 103 max_workers=int(ragas_config.get("max_workers", 1)),
104 ), 104 ),
105 batch_size=1, 105 "raise_exceptions": False,
106 raise_exceptions=False, 106 }
107 ) 107 if "batch_size" in inspect.signature(generator.generate_with_langchain_docs).parameters:
108 generate_kwargs["batch_size"] = 1
109 result = generator.generate_with_langchain_docs(documents, **generate_kwargs)
108 110
109 ragas_rows = result.to_list() 111 ragas_rows = result.to_list()
110 rows = _normalize_ragas_rows(ragas_rows, selected_source_rows) 112 rows = _normalize_ragas_rows(ragas_rows, selected_source_rows)
......