Commit 6c7e5043fdf840d97768792120e6cb7a6a2a6df2 by 沈秋雨

Add LLM output format diagnostic script

1 parent 3c79a5fd
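"""Diagnose why the Ragas generator LLM produces output Ragas cannot consume.

The script runs four probes against the configured OpenAI-compatible endpoint:
a plain-text completion, a JSON-only completion, a LangChain ChatOpenAI
generation (checking the finish metadata Ragas inspects), and the actual Ragas
QueryAnswerGenerationPrompt. It then prints a diagnosis pointing at truncation,
an unrecognised finish_reason, or non-JSON output.
"""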
from __future__ import annotations

import json
import sys
import time
from typing import Any

import _bootstrap  # noqa: F401

import requests
from langchain_core.prompt_values import StringPromptValue
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, ValidationError
from ragas.llms import LangchainLLMWrapper
from ragas.run_config import RunConfig
from ragas.testset.persona import Persona
from ragas.testset.synthesizers.single_hop.prompts import (
    GeneratedQueryAnswer,
    QueryAnswerGenerationPrompt,
    QueryCondition,
)

from weknora_eval.config import load_config


class SimpleQA(BaseModel):
    query: str
    answer: str


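# Expected config shape for the values read in main() below (a sketch only;
# the real values come from weknora_eval.config.load_config()):
#
#   {
#       "ragas": {
#           "llm_base_url": "http://<host>/v1",   # OpenAI-compatible endpoint
#           "llm_api_key": "<key>",
#           "generator_model": "<model name>",
#           "max_tokens": 4096,         # optional, default 4096
#           "temperature": 0,           # optional, default 0
#           "timeout_seconds": 600,     # optional, default 600
#       },
#       "testset": {
#           "generator_max_tokens": 4096,  # optional, overrides ragas.max_tokens
#       },
#   }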
def main() -> int:
    config = load_config()
    ragas = config["ragas"]
    testset = config.get("testset", {})
    base_url = require_value(ragas, "llm_base_url").rstrip("/")
    api_key = require_value(ragas, "llm_api_key")
    model = require_value(ragas, "generator_model")
    max_tokens = int(testset.get("generator_max_tokens", ragas.get("max_tokens", 4096)))
    temperature = float(ragas.get("temperature", 0))
    timeout = int(ragas.get("timeout_seconds", 600))

    print("Diagnosing Ragas generator LLM compatibility\n")
    print(f"model={model}")
    print(f"base_url={base_url}")
    print(f"max_tokens={max_tokens}")
    print(f"temperature={temperature}\n")

    plain = run_raw_chat(
        title="plain_text",
        base_url=base_url,
        api_key=api_key,
        model=model,
        messages=[{"role": "user", "content": "Reply with exactly: OK"}],
        max_tokens=min(max_tokens, 256),
        temperature=temperature,
        timeout=timeout,
    )

    # The embedded Chinese sentence is a contract clause: "The contract
    # stipulates that Party B owns the copyright of the work, and Party A
    # shall pay the fee as agreed."
    json_prompt = (
        "Return only valid JSON, with no markdown and no extra text. "
        'The JSON schema is {"query": "string", "answer": "string"}. '
        "Use this content: 合同约定乙方享有作品的著作权,甲方应按约支付费用。"
    )
    structured = run_raw_chat(
        title="raw_json_contract",
        base_url=base_url,
        api_key=api_key,
        model=model,
        messages=[{"role": "user", "content": json_prompt}],
        max_tokens=max_tokens,
        temperature=temperature,
        timeout=timeout,
    )
    validate_json_payload(structured.get("content") or "")

    langchain_result = run_langchain_probe(
        base_url=base_url,
        api_key=api_key,
        model=model,
        prompt=json_prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        timeout=timeout,
    )

    run_ragas_prompt_probe(
        base_url=base_url,
        api_key=api_key,
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        timeout=timeout,
    )

    explain_result(plain, structured, langchain_result)
    return 0


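# run_raw_chat calls the endpoint directly with requests, bypassing LangChain
# and Ragas, so finish_reason/usage/content reflect exactly what the server
# returns. For an OpenAI-compatible /chat/completions response, the fields
# read below sit roughly here (illustrative shape only):
#
#   {"choices": [{"message": {"content": "..."}, "finish_reason": "stop"}],
#    "usage": {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15}}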
def run_raw_chat(
    *,
    title: str,
    base_url: str,
    api_key: str,
    model: str,
    messages: list[dict[str, str]],
    max_tokens: int,
    temperature: float,
    timeout: int,
) -> dict[str, Any]:
    print(f"[RAW] {title}")
    started = time.monotonic()
    response = requests.post(
        base_url + "/chat/completions",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        },
        timeout=timeout,
    )
    elapsed = time.monotonic() - started
    print(f"status={response.status_code} elapsed={elapsed:.2f}s")
    if response.status_code >= 400:
        print(response.text[:1000])
        return {"ok": False, "status": response.status_code, "content": ""}

    payload = response.json()
    choice = (payload.get("choices") or [{}])[0]
    message = choice.get("message") or {}
    content = str(message.get("content") or "")
    finish_reason = choice.get("finish_reason")
    usage = payload.get("usage")
    print(f"finish_reason={finish_reason!r}")
    print(f"usage={usage}")
    print(f"content_len={len(content)}")
    print(f"content_preview={content[:500]!r}\n")
    return {
        "ok": True,
        "finish_reason": finish_reason,
        "usage": usage,
        "content": content,
    }


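# validate_json_payload checks whether the JSON-only probe output parses with
# json.loads plus a small pydantic model. Illustrative examples of content
# that passes or fails (the failing shapes match the "thinking text or
# markdown" diagnosis printed in explain_result):
#
#   passes:  {"query": "...", "answer": "..."}
#   fails:   ```json ... ```          (markdown-fenced JSON)
#   fails:   <think>...</think>{...}  (reasoning text before the JSON)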
def validate_json_payload(content: str) -> None:
    print("[PARSE] raw_json_contract")
    stripped = content.strip()
    try:
        data = json.loads(stripped)
        parsed = SimpleQA.model_validate(data)
        print("json_parse=ok")
        print(f"pydantic_parse=ok query={parsed.query[:80]!r}\n")
    except (json.JSONDecodeError, ValidationError) as exc:
        print(f"json_or_pydantic_parse=failed: {exc}")
        print("This indicates structured-output incompatibility or extra non-JSON text.\n")


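# run_langchain_probe repeats the JSON prompt through ChatOpenAI and then asks
# LangchainLLMWrapper.is_finished() about the result, i.e. the check Ragas uses
# to decide whether a generation completed. If the raw probe ends with
# finish_reason "stop" but is_finished() is False here, the finish metadata is
# being lost or mangled in the LangChain layer rather than by the model.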
def run_langchain_probe(
    *,
    base_url: str,
    api_key: str,
    model: str,
    prompt: str,
    max_tokens: int,
    temperature: float,
    timeout: int,
) -> dict[str, Any]:
    print("[LANGCHAIN] generation metadata")
    llm = ChatOpenAI(
        model=model,
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=timeout,
    )
    prompt_value = StringPromptValue(text=prompt)
    result = llm.generate_prompt([prompt_value])
    generation = result.generations[0][0]
    metadata = getattr(generation, "generation_info", None)
    response_metadata = getattr(getattr(generation, "message", None), "response_metadata", None)
    wrapper = LangchainLLMWrapper(llm)
    is_finished = wrapper.is_finished(result)
    print(f"generation_info={metadata}")
    print(f"response_metadata={response_metadata}")
    print(f"ragas_is_finished={is_finished}")
    print(f"text_preview={generation.text[:500]!r}\n")
    return {
        "generation_info": metadata,
        "response_metadata": response_metadata,
        "ragas_is_finished": is_finished,
        "content": generation.text,
    }


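# run_ragas_prompt_probe runs the real single-hop testset prompt
# (QueryAnswerGenerationPrompt) end to end through the Ragas wrapper, using a
# Chinese contract-review persona and context. If the earlier probes pass but
# this one fails, the issue is specific to the Ragas prompt path (prompt
# length, structured-output parsing) rather than basic endpoint compatibility.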
def run_ragas_prompt_probe(
    *,
    base_url: str,
    api_key: str,
    model: str,
    max_tokens: int,
    temperature: float,
    timeout: int,
) -> None:
    print("[RAGAS] QueryAnswerGenerationPrompt")
    llm = ChatOpenAI(
        model=model,
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=timeout,
    )
    ragas_llm = LangchainLLMWrapper(llm)
    ragas_llm.set_run_config(RunConfig(timeout=timeout, max_workers=1))
    prompt = QueryAnswerGenerationPrompt()
    condition = QueryCondition(
        persona=Persona(
            name="合同审核人员",  # "contract reviewer"
            role_description="关注合同条款、权利归属、授权范围和履约义务。",  # focuses on contract terms, ownership of rights, scope of authorization, and performance obligations
        ),
        term="著作权",  # "copyright"
        query_style="Perfect grammar",
        query_length="medium",
        context="合同约定乙方享有作品的著作权,甲方应在收到发票后30日内支付授权费用。",  # Party B owns the copyright; Party A pays the licensing fee within 30 days of receiving the invoice
    )
    try:
        import asyncio

        result: GeneratedQueryAnswer = asyncio.run(prompt.generate(llm=ragas_llm, data=condition))
        print(f"ragas_prompt=ok query={result.query[:120]!r}")
        print(f"answer={result.answer[:120]!r}\n")
    except Exception as exc:  # noqa: BLE001
        print(f"ragas_prompt=failed: {type(exc).__name__}: {exc}\n")


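# explain_result maps the probe results onto the most likely cause, checked in
# this order: truncation (finish_reason == "length"), a finish_reason value the
# pinned Ragas version does not accept, finish metadata lost between the raw
# and LangChain layers, and finally non-JSON output from a JSON-only prompt.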
def explain_result(
    plain: dict[str, Any],
    structured: dict[str, Any],
    langchain_result: dict[str, Any],
) -> None:
    print("[DIAGNOSIS]")
    raw_finish = structured.get("finish_reason")
    ragas_finished = langchain_result.get("ragas_is_finished")
    content = structured.get("content") or ""
    if raw_finish == "length":
        print("- Raw vLLM finish_reason is 'length': output hit max_tokens.")
        print("- Lower TESTSET_MAX_DOCUMENT_CHARS or increase TESTSET_GENERATOR_MAX_TOKENS moderately.")
    elif raw_finish not in {"stop", "STOP", "MAX_TOKENS", "eos_token"}:
        print("- Raw vLLM finish_reason is not one of the values RAGAS 0.4.3 accepts.")
        print("- This is finish_reason compatibility, not necessarily context length.")
    elif ragas_finished is False:
        print("- Raw finish_reason looks acceptable, but RAGAS/LangChain metadata is not accepted.")
        print("- This points to LangChain/RAGAS finish_reason adapter compatibility.")
    elif not content.strip().startswith("{"):
        print("- The model did not return pure JSON for a JSON-only prompt.")
        print("- This points to structured-output incompatibility, often caused by thinking text or markdown.")
    else:
        print("- Basic raw and LangChain completion checks look compatible.")
        print("- If RAGAS still fails, inspect the RAGAS prompt output preview above for invalid JSON or truncation.")

    if plain.get("finish_reason") == "stop" and structured.get("finish_reason") != "stop":
        print("- Plain text works but structured output differs; focus on JSON/structured output settings.")


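# require_value fails fast with the offending key name, so a missing ragas.*
# setting surfaces before any network call is made.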
def require_value(config: dict[str, Any], key: str) -> str:
    value = config.get(key)
    if value in {None, ""}:
        raise ValueError(f"Missing required config value: ragas.{key}")
    return str(value)


if __name__ == "__main__":
    sys.exit(main())