Add LLM output format diagnostic script
Showing 1 changed file with 279 additions and 0 deletions
scripts/00_diagnose_ragas_llm.py
0 → 100644
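"""Diagnose Ragas generator LLM output compatibility.

Runs four probes against the configured OpenAI-compatible endpoint: a plain
text completion, a raw JSON completion, a LangChain generation, and the real
Ragas query-generation prompt, then prints a plain-language diagnosis.
Typically run directly, e.g. `python scripts/00_diagnose_ragas_llm.py`
(the exact invocation depends on how the repo's `_bootstrap` sets up paths).
"""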
from __future__ import annotations

import asyncio
import json
import sys
import time
from typing import Any

import _bootstrap  # noqa: F401

import requests
from langchain_core.prompt_values import StringPromptValue
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, ValidationError
from ragas.llms import LangchainLLMWrapper
from ragas.run_config import RunConfig
from ragas.testset.persona import Persona
from ragas.testset.synthesizers.single_hop.prompts import (
    GeneratedQueryAnswer,
    QueryAnswerGenerationPrompt,
    QueryCondition,
)

from weknora_eval.config import load_config

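# Minimal schema mirroring the {"query": ..., "answer": ...} JSON contract
# requested by the raw JSON probe below.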
class SimpleQA(BaseModel):
    query: str
    answer: str

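# Probes run from the bottom of the stack upward: raw HTTP, then LangChain,
# then the real Ragas prompt, so the first failing layer is easy to spot.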
def main() -> int:
    config = load_config()
    ragas = config["ragas"]
    testset = config.get("testset", {})
    base_url = require_value(ragas, "llm_base_url").rstrip("/")
    api_key = require_value(ragas, "llm_api_key")
    model = require_value(ragas, "generator_model")
    max_tokens = int(testset.get("generator_max_tokens", ragas.get("max_tokens", 4096)))
    temperature = float(ragas.get("temperature", 0))
    timeout = int(ragas.get("timeout_seconds", 600))

    print("Diagnosing Ragas generator LLM compatibility\n")
    print(f"model={model}")
    print(f"base_url={base_url}")
    print(f"max_tokens={max_tokens}")
    print(f"temperature={temperature}\n")

    plain = run_raw_chat(
        title="plain_text",
        base_url=base_url,
        api_key=api_key,
        model=model,
        messages=[{"role": "user", "content": "Reply with exactly: OK"}],
        max_tokens=min(max_tokens, 256),
        temperature=temperature,
        timeout=timeout,
    )

    json_prompt = (
        "Return only valid JSON, with no markdown and no extra text. "
        'The JSON schema is {"query": "string", "answer": "string"}. '
        "Use this content: 合同约定乙方享有作品的著作权,甲方应按约支付费用。"
    )
    structured = run_raw_chat(
        title="raw_json_contract",
        base_url=base_url,
        api_key=api_key,
        model=model,
        messages=[{"role": "user", "content": json_prompt}],
        max_tokens=max_tokens,
        temperature=temperature,
        timeout=timeout,
    )
    validate_json_payload(structured.get("content") or "")

    langchain_result = run_langchain_probe(
        base_url=base_url,
        api_key=api_key,
        model=model,
        prompt=json_prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        timeout=timeout,
    )

    run_ragas_prompt_probe(
        base_url=base_url,
        api_key=api_key,
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        timeout=timeout,
    )

    explain_result(plain, structured, langchain_result)
    return 0


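# Call the OpenAI-compatible /chat/completions endpoint directly with
# requests, so finish_reason, usage, and raw content can be inspected
# without any client-library post-processing.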
def run_raw_chat(
    *,
    title: str,
    base_url: str,
    api_key: str,
    model: str,
    messages: list[dict[str, str]],
    max_tokens: int,
    temperature: float,
    timeout: int,
) -> dict[str, Any]:
    print(f"[RAW] {title}")
    started = time.monotonic()
    response = requests.post(
        base_url + "/chat/completions",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        },
        timeout=timeout,
    )
    elapsed = time.monotonic() - started
    print(f"status={response.status_code} elapsed={elapsed:.2f}s")
    if response.status_code >= 400:
        print(response.text[:1000])
        return {"ok": False, "status": response.status_code, "content": ""}

    payload = response.json()
    choice = (payload.get("choices") or [{}])[0]
    message = choice.get("message") or {}
    content = str(message.get("content") or "")
    finish_reason = choice.get("finish_reason")
    usage = payload.get("usage")
    print(f"finish_reason={finish_reason!r}")
    print(f"usage={usage}")
    print(f"content_len={len(content)}")
    print(f"content_preview={content[:500]!r}\n")
    return {
        "ok": True,
        "finish_reason": finish_reason,
        "usage": usage,
        "content": content,
    }


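# Parse the raw JSON probe output with json + pydantic; failure here points
# to extra non-JSON text (markdown fences, thinking text) around the payload.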
def validate_json_payload(content: str) -> None:
    print("[PARSE] raw_json_contract")
    stripped = content.strip()
    try:
        data = json.loads(stripped)
        parsed = SimpleQA.model_validate(data)
        print("json_parse=ok")
        print(f"pydantic_parse=ok query={parsed.query[:80]!r}\n")
    except (json.JSONDecodeError, ValidationError) as exc:
        print(f"json_or_pydantic_parse=failed: {exc}")
        print("This indicates structured-output incompatibility or extra non-JSON text.\n")


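# Re-run the JSON prompt through LangChain's ChatOpenAI and report the
# generation metadata that Ragas' LangchainLLMWrapper.is_finished() checks.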
def run_langchain_probe(
    *,
    base_url: str,
    api_key: str,
    model: str,
    prompt: str,
    max_tokens: int,
    temperature: float,
    timeout: int,
) -> dict[str, Any]:
    print("[LANGCHAIN] generation metadata")
    llm = ChatOpenAI(
        model=model,
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=timeout,
    )
    prompt_value = StringPromptValue(text=prompt)
    result = llm.generate_prompt([prompt_value])
    generation = result.generations[0][0]
    metadata = getattr(generation, "generation_info", None)
    response_metadata = getattr(getattr(generation, "message", None), "response_metadata", None)
    wrapper = LangchainLLMWrapper(llm)
    is_finished = wrapper.is_finished(result)
    print(f"generation_info={metadata}")
    print(f"response_metadata={response_metadata}")
    print(f"ragas_is_finished={is_finished}")
    print(f"text_preview={generation.text[:500]!r}\n")
    return {
        "generation_info": metadata,
        "response_metadata": response_metadata,
        "ragas_is_finished": is_finished,
        "content": generation.text,
    }


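# Drive the real single-hop QueryAnswerGenerationPrompt end to end, using a
# fixed Chinese contract-review persona and context as the test fixture.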
def run_ragas_prompt_probe(
    *,
    base_url: str,
    api_key: str,
    model: str,
    max_tokens: int,
    temperature: float,
    timeout: int,
) -> None:
    print("[RAGAS] QueryAnswerGenerationPrompt")
    llm = ChatOpenAI(
        model=model,
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=timeout,
    )
    ragas_llm = LangchainLLMWrapper(llm)
    ragas_llm.set_run_config(RunConfig(timeout=timeout, max_workers=1))
    prompt = QueryAnswerGenerationPrompt()
    condition = QueryCondition(
        persona=Persona(
            name="合同审核人员",
            role_description="关注合同条款、权利归属、授权范围和履约义务。",
        ),
        term="著作权",
        query_style="Perfect grammar",
        query_length="medium",
        context="合同约定乙方享有作品的著作权,甲方应在收到发票后30日内支付授权费用。",
    )
    try:
        result: GeneratedQueryAnswer = asyncio.run(prompt.generate(llm=ragas_llm, data=condition))
        print(f"ragas_prompt=ok query={result.query[:120]!r}")
        print(f"answer={result.answer[:120]!r}\n")
    except Exception as exc:  # noqa: BLE001
        print(f"ragas_prompt=failed: {type(exc).__name__}: {exc}\n")


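# Turn the probe observations into a short diagnosis, checked in order:
# truncation, unrecognized finish_reason, wrapper metadata rejection,
# then non-JSON output.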
def explain_result(
    plain: dict[str, Any],
    structured: dict[str, Any],
    langchain_result: dict[str, Any],
) -> None:
    print("[DIAGNOSIS]")
    raw_finish = structured.get("finish_reason")
    ragas_finished = langchain_result.get("ragas_is_finished")
    content = structured.get("content") or ""
    if raw_finish == "length":
        print("- Raw vLLM finish_reason is 'length': output hit max_tokens.")
        print("- Lower TESTSET_MAX_DOCUMENT_CHARS or increase TESTSET_GENERATOR_MAX_TOKENS moderately.")
    elif raw_finish not in {"stop", "STOP", "MAX_TOKENS", "eos_token"}:
        print("- Raw vLLM finish_reason is not one of the values RAGAS 0.4.3 accepts.")
        print("- This is finish_reason compatibility, not necessarily context length.")
    elif ragas_finished is False:
        print("- Raw finish_reason looks acceptable, but RAGAS/LangChain metadata is not accepted.")
        print("- This points to LangChain/RAGAS finish_reason adapter compatibility.")
    elif not content.strip().startswith("{"):
        print("- The model did not return pure JSON for a JSON-only prompt.")
        print("- This points to structured-output incompatibility, often caused by thinking text or markdown.")
    else:
        print("- Basic raw and LangChain completion checks look compatible.")
        print("- If RAGAS still fails, inspect the RAGAS prompt output preview above for invalid JSON or truncation.")

    if plain.get("finish_reason") == "stop" and structured.get("finish_reason") != "stop":
        print("- Plain text works but structured output differs; focus on JSON/structured output settings.")


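# Fail fast with an explicit error when a required ragas.* config key is
# missing or empty.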
def require_value(config: dict[str, Any], key: str) -> str:
    value = config.get(key)
    if value in {None, ""}:
        raise ValueError(f"Missing required config value: ragas.{key}")
    return str(value)


if __name__ == "__main__":
    sys.exit(main())