ragas_runner.py
4.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from __future__ import annotations
import os
from pathlib import Path
from typing import Any
import pandas as pd
from weknora_eval.config import require_config
from weknora_eval.loaders import read_jsonl
def run_ragas_eval(
    config: dict[str, Any],
    *,
    input_path: str = "data/runs/ragas_input.jsonl",
    output_csv_path: str = "data/reports/ragas_scores.csv",
) -> pd.DataFrame:
    """Run a Ragas evaluation over a JSONL file of samples and write scores to CSV.

    Args:
        config: Full evaluation config; its ``ragas`` section supplies the
            judge/embedding endpoints, model names, and runtime limits.
        input_path: JSONL file with one sample per line; each row must carry
            ``user_input``, ``response``, ``retrieved_contexts`` and
            ``reference`` (``reference_contexts`` and ``sample_id`` optional).
        output_csv_path: Destination CSV; parent directories are created.

    Returns:
        The per-sample score DataFrame (also written to ``output_csv_path``),
        with a ``sample_id`` column copied from the input rows.

    Raises:
        ValueError: If a required ``ragas.*`` config value is missing, or the
            number of scored rows does not match the number of input rows.
        KeyError: If ``config`` has no ``ragas`` section or a row lacks a
            required field.
    """
    # Heavy third-party imports are deferred so importing this module stays cheap.
    from datasets import Dataset
    from langchain_openai import ChatOpenAI, OpenAIEmbeddings
    from ragas import evaluate
    from ragas.run_config import RunConfig

    ragas_config = config["ragas"]
    llm_api_key = _required_ragas_value(ragas_config, "llm_api_key")
    llm_base_url = _required_ragas_value(ragas_config, "llm_base_url")
    embedding_api_key = _required_ragas_value(ragas_config, "embedding_api_key")
    embedding_base_url = _required_ragas_value(ragas_config, "embedding_base_url")
    judge_model = str(require_config(config, "ragas.judge_model"))
    embedding_model = str(require_config(config, "ragas.embedding_model"))
    temperature = float(ragas_config.get("temperature", 0))
    max_tokens = int(ragas_config.get("max_tokens", 4096))
    timeout_seconds = int(ragas_config.get("timeout_seconds", 600))
    max_workers = int(ragas_config.get("max_workers", 1))

    # Some ragas/langchain code paths read OpenAI settings from the
    # environment rather than the client objects, so mirror them there.
    os.environ["OPENAI_API_KEY"] = llm_api_key
    if llm_base_url:
        os.environ["OPENAI_BASE_URL"] = llm_base_url

    rows = read_jsonl(input_path)
    dataset = Dataset.from_list(
        [
            {
                "user_input": row["user_input"],
                "response": row["response"],
                "retrieved_contexts": row["retrieved_contexts"],
                "reference": row["reference"],
                "reference_contexts": row.get("reference_contexts") or [],
            }
            for row in rows
        ]
    )

    # Honor an optional metric allow-list; unknown names are ignored.
    metric_map = _metric_map()
    selected_metrics = [
        metric_map[name]
        for name in ragas_config.get("metrics", metric_map.keys())
        if name in metric_map
    ]

    llm = ChatOpenAI(
        model=judge_model,
        api_key=llm_api_key,
        base_url=llm_base_url or None,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    embeddings = OpenAIEmbeddings(
        model=embedding_model,
        api_key=embedding_api_key,
        base_url=embedding_base_url or None,
        # Non-OpenAI endpoints often reject tiktoken-based token counting.
        tiktoken_enabled=False,
        check_embedding_ctx_length=False,
    )
    ragas_llm, ragas_embeddings = _wrap_langchain_models(llm, embeddings)
    run_config = RunConfig(timeout=timeout_seconds, max_workers=max_workers)
    result = evaluate(
        dataset,
        metrics=selected_metrics,
        llm=ragas_llm,
        embeddings=ragas_embeddings,
        run_config=run_config,
    )

    scores = result.to_pandas()
    # Fail loudly on a row-count mismatch: the previous per-index `.loc`
    # assignment would silently append NaN-padded rows for extra indices.
    if len(scores) != len(rows):
        raise ValueError(
            f"Ragas returned {len(scores)} score rows for {len(rows)} input samples"
        )
    scores["sample_id"] = [row.get("sample_id") for row in rows]

    target = Path(output_csv_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    scores.to_csv(target, index=False)
    return scores
def _metric_map() -> dict[str, Any]:
    """Return the name -> Ragas metric mapping, across Ragas API versions.

    Older Ragas releases export ready-made singleton metric objects; newer
    ones export classes that must be instantiated. The legacy names are
    attempted first, with the class-based API as the ImportError fallback.
    """
    try:
        from ragas.metrics import (
            context_precision,
            context_recall,
            faithfulness,
            factual_correctness,
            response_relevancy,
        )
    except ImportError:
        # Newer Ragas: metrics are classes, so instantiate each one.
        from ragas.metrics import (
            Faithfulness,
            FactualCorrectness,
            LLMContextPrecisionWithReference,
            LLMContextRecall,
            ResponseRelevancy,
        )
        return {
            "faithfulness": Faithfulness(),
            "response_relevancy": ResponseRelevancy(),
            "context_precision": LLMContextPrecisionWithReference(),
            "context_recall": LLMContextRecall(),
            "factual_correctness": FactualCorrectness(),
        }
    return {
        "faithfulness": faithfulness,
        "response_relevancy": response_relevancy,
        "context_precision": context_precision,
        "context_recall": context_recall,
        "factual_correctness": factual_correctness,
    }
def _required_ragas_value(config: dict[str, Any], key: str) -> str:
value = config.get(key)
if value in {None, ""}:
raise ValueError(f"Missing required Ragas config value: ragas.{key}")
return str(value)
def _wrap_langchain_models(llm: Any, embeddings: Any) -> tuple[Any, Any]:
try:
from ragas.embeddings import LangchainEmbeddingsWrapper
from ragas.llms import LangchainLLMWrapper
except ImportError:
return llm, embeddings
return LangchainLLMWrapper(llm), LangchainEmbeddingsWrapper(embeddings)