# 00_diagnose_ragas_llm.py
from __future__ import annotations
import json
import sys
import time
from typing import Any
import _bootstrap # noqa: F401
import requests
from langchain_core.prompt_values import StringPromptValue
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, ValidationError
from ragas.llms import LangchainLLMWrapper
from ragas.run_config import RunConfig
from ragas.testset.persona import Persona
from ragas.testset.synthesizers.single_hop.prompts import (
GeneratedQueryAnswer,
QueryAnswerGenerationPrompt,
QueryCondition,
)
from weknora_eval.config import load_config
class SimpleQA(BaseModel):
    """Minimal query/answer schema used to verify the LLM can emit plain JSON."""

    query: str
    answer: str
def main() -> int:
    """Run a sequence of probes against the configured Ragas generator LLM.

    Probes, in order: plain-text completion, raw JSON output, LangChain
    generation metadata, and the real RAGAS testset prompt; finishes with a
    printed diagnosis summary. Returns 0 (diagnosis is conveyed via stdout).
    """
    config = load_config()
    ragas = config["ragas"]
    testset = config.get("testset", {})
    base_url = require_value(ragas, "llm_base_url").rstrip("/")
    api_key = require_value(ragas, "llm_api_key")
    model = require_value(ragas, "generator_model")
    max_tokens = int(testset.get("generator_max_tokens", ragas.get("max_tokens", 4096)))
    temperature = float(ragas.get("temperature", 0))
    timeout = int(ragas.get("timeout_seconds", 600))

    print("Diagnosing Ragas generator LLM compatibility\n")
    print(f"model={model}")
    print(f"base_url={base_url}")
    print(f"max_tokens={max_tokens}")
    print(f"temperature={temperature}\n")

    # Connection/generation settings shared by every probe below.
    common = {
        "base_url": base_url,
        "api_key": api_key,
        "model": model,
        "temperature": temperature,
        "timeout": timeout,
    }

    plain = run_raw_chat(
        title="plain_text",
        messages=[{"role": "user", "content": "Reply with exactly: OK"}],
        max_tokens=min(max_tokens, 256),
        **common,
    )

    json_prompt = (
        "Return only valid JSON, with no markdown and no extra text. "
        'The JSON schema is {"query": "string", "answer": "string"}. '
        "Use this content: 合同约定乙方享有作品的著作权,甲方应按约支付费用。"
    )
    structured = run_raw_chat(
        title="raw_json_contract",
        messages=[{"role": "user", "content": json_prompt}],
        max_tokens=max_tokens,
        **common,
    )
    validate_json_payload(structured.get("content") or "")

    langchain_result = run_langchain_probe(
        prompt=json_prompt,
        max_tokens=max_tokens,
        **common,
    )
    run_ragas_prompt_probe(
        max_tokens=max_tokens,
        **common,
    )

    explain_result(plain, structured, langchain_result)
    return 0
def run_raw_chat(
    *,
    title: str,
    base_url: str,
    api_key: str,
    model: str,
    messages: list[dict[str, str]],
    max_tokens: int,
    temperature: float,
    timeout: int,
) -> dict[str, Any]:
    """POST a raw OpenAI-style /chat/completions request and print key fields.

    Returns a dict with ``ok`` plus, on success, ``finish_reason``, ``usage``
    and ``content``; on HTTP error or an undecodable body, ``ok`` is False and
    ``content`` is empty.
    """
    print(f"[RAW] {title}")
    started = time.monotonic()
    response = requests.post(
        base_url + "/chat/completions",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        },
        timeout=timeout,
    )
    elapsed = time.monotonic() - started
    print(f"status={response.status_code} elapsed={elapsed:.2f}s")
    if response.status_code >= 400:
        print(response.text[:1000])
        return {"ok": False, "status": response.status_code, "content": ""}
    try:
        payload = response.json()
    except ValueError:
        # Fix: a 2xx response with a non-JSON body (e.g. a proxy/HTML error
        # page) previously raised JSONDecodeError and crashed the diagnostic.
        # A diagnosis tool should report the anomaly instead.
        print(f"non_json_body_preview={response.text[:500]!r}\n")
        return {"ok": False, "status": response.status_code, "content": ""}
    choice = (payload.get("choices") or [{}])[0]
    message = choice.get("message") or {}
    content = str(message.get("content") or "")
    finish_reason = choice.get("finish_reason")
    usage = payload.get("usage")
    print(f"finish_reason={finish_reason!r}")
    print(f"usage={usage}")
    print(f"content_len={len(content)}")
    print(f"content_preview={content[:500]!r}\n")
    return {
        "ok": True,
        "finish_reason": finish_reason,
        "usage": usage,
        "content": content,
    }
def validate_json_payload(content: str) -> None:
    """Report whether *content* parses as JSON matching the SimpleQA schema."""
    print("[PARSE] raw_json_contract")
    text = content.strip()
    try:
        parsed = SimpleQA.model_validate(json.loads(text))
    except (json.JSONDecodeError, ValidationError) as exc:
        print(f"json_or_pydantic_parse=failed: {exc}")
        print("This indicates structured-output incompatibility or extra non-JSON text.\n")
    else:
        print("json_parse=ok")
        print(f"pydantic_parse=ok query={parsed.query[:80]!r}\n")
def run_langchain_probe(
    *,
    base_url: str,
    api_key: str,
    model: str,
    prompt: str,
    max_tokens: int,
    temperature: float,
    timeout: int,
) -> dict[str, Any]:
    """Generate once through LangChain and report the metadata RAGAS inspects."""
    print("[LANGCHAIN] generation metadata")
    llm = ChatOpenAI(
        model=model,
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=timeout,
    )
    result = llm.generate_prompt([StringPromptValue(text=prompt)])
    generation = result.generations[0][0]
    gen_info = getattr(generation, "generation_info", None)
    message = getattr(generation, "message", None)
    resp_meta = getattr(message, "response_metadata", None)
    # RAGAS decides completion state through LangchainLLMWrapper.is_finished.
    finished = LangchainLLMWrapper(llm).is_finished(result)
    print(f"generation_info={gen_info}")
    print(f"response_metadata={resp_meta}")
    print(f"ragas_is_finished={finished}")
    print(f"text_preview={generation.text[:500]!r}\n")
    return {
        "generation_info": gen_info,
        "response_metadata": resp_meta,
        "ragas_is_finished": finished,
        "content": generation.text,
    }
def run_ragas_prompt_probe(
    *,
    base_url: str,
    api_key: str,
    model: str,
    max_tokens: int,
    temperature: float,
    timeout: int,
) -> None:
    """Drive the real RAGAS single-hop QA prompt end to end and report the outcome."""
    print("[RAGAS] QueryAnswerGenerationPrompt")
    ragas_llm = LangchainLLMWrapper(
        ChatOpenAI(
            model=model,
            api_key=api_key,
            base_url=base_url,
            temperature=temperature,
            max_tokens=max_tokens,
            timeout=timeout,
        )
    )
    ragas_llm.set_run_config(RunConfig(timeout=timeout, max_workers=1))
    qa_prompt = QueryAnswerGenerationPrompt()
    condition = QueryCondition(
        persona=Persona(
            name="合同审核人员",
            role_description="关注合同条款、权利归属、授权范围和履约义务。",
        ),
        term="著作权",
        query_style="Perfect grammar",
        query_length="medium",
        context="合同约定乙方享有作品的著作权,甲方应在收到发票后30日内支付授权费用。",
    )
    try:
        import asyncio

        generated: GeneratedQueryAnswer = asyncio.run(
            qa_prompt.generate(llm=ragas_llm, data=condition)
        )
        print(f"ragas_prompt=ok query={generated.query[:120]!r}")
        print(f"answer={generated.answer[:120]!r}\n")
    except Exception as exc:  # noqa: BLE001
        print(f"ragas_prompt=failed: {type(exc).__name__}: {exc}\n")
def explain_result(
    plain: dict[str, Any],
    structured: dict[str, Any],
    langchain_result: dict[str, Any],
) -> None:
    """Print a human-readable diagnosis derived from the three probe results.

    Branches in priority order: truncation ('length'), an unrecognized
    finish_reason, RAGAS rejecting otherwise-acceptable metadata, non-JSON
    output for a JSON-only prompt, and finally "looks compatible".
    """
    print("[DIAGNOSIS]")
    raw_finish = structured.get("finish_reason")
    ragas_finished = langchain_result.get("ragas_is_finished")
    content = structured.get("content") or ""
    # finish_reason values that RAGAS 0.4.3 treats as a completed generation.
    accepted = {"stop", "STOP", "MAX_TOKENS", "eos_token"}
    if raw_finish == "length":
        notes = [
            "- Raw vLLM finish_reason is 'length': output hit max_tokens.",
            "- Lower TESTSET_MAX_DOCUMENT_CHARS or increase TESTSET_GENERATOR_MAX_TOKENS moderately.",
        ]
    elif raw_finish not in accepted:
        notes = [
            "- Raw vLLM finish_reason is not one of the values RAGAS 0.4.3 accepts.",
            "- This is finish_reason compatibility, not necessarily context length.",
        ]
    elif ragas_finished is False:
        notes = [
            "- Raw finish_reason looks acceptable, but RAGAS/LangChain metadata is not accepted.",
            "- This points to LangChain/RAGAS finish_reason adapter compatibility.",
        ]
    elif not content.strip().startswith("{"):
        notes = [
            "- The model did not return pure JSON for a JSON-only prompt.",
            "- This points to structured-output incompatibility, often caused by thinking text or markdown.",
        ]
    else:
        notes = [
            "- Basic raw and LangChain completion checks look compatible.",
            "- If RAGAS still fails, inspect the RAGAS prompt output preview above for invalid JSON or truncation.",
        ]
    for note in notes:
        print(note)
    if plain.get("finish_reason") == "stop" and structured.get("finish_reason") != "stop":
        print("- Plain text works but structured output differs; focus on JSON/structured output settings.")
def require_value(config: dict[str, Any], key: str) -> str:
    """Return ``config[key]`` coerced to str, raising ValueError when missing/empty."""
    value = config.get(key)
    if value not in {None, ""}:
        return str(value)
    raise ValueError(f"Missing required config value: ragas.{key}")
if __name__ == "__main__":
    # Propagate main()'s status code to the shell.
    sys.exit(main())