Commit 6c7e5043 (6c7e5043fdf840d97768792120e6cb7a6a2a6df2) by Shen Qiuyu (沈秋雨)

Add a diagnostic script for LLM output-format compatibility.

1 parent 3c79a5fd
from __future__ import annotations
import json
import sys
import time
from typing import Any
import _bootstrap # noqa: F401
import requests
from langchain_core.prompt_values import StringPromptValue
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, ValidationError
from ragas.llms import LangchainLLMWrapper
from ragas.run_config import RunConfig
from ragas.testset.persona import Persona
from ragas.testset.synthesizers.single_hop.prompts import (
GeneratedQueryAnswer,
QueryAnswerGenerationPrompt,
QueryCondition,
)
from weknora_eval.config import load_config
class SimpleQA(BaseModel):
    """Minimal query/answer schema used to validate the model's raw JSON output."""

    # Question text expected under the "query" key of the JSON reply.
    query: str
    # Answer text expected under the "answer" key of the JSON reply.
    answer: str
def main() -> int:
    """Run all generator-LLM compatibility probes in order and print a diagnosis.

    Returns 0 unconditionally; individual probe failures are reported on stdout.
    """
    cfg = load_config()
    ragas_cfg = cfg["ragas"]
    testset_cfg = cfg.get("testset", {})
    base_url = require_value(ragas_cfg, "llm_base_url").rstrip("/")
    api_key = require_value(ragas_cfg, "llm_api_key")
    model = require_value(ragas_cfg, "generator_model")
    max_tokens = int(testset_cfg.get("generator_max_tokens", ragas_cfg.get("max_tokens", 4096)))
    temperature = float(ragas_cfg.get("temperature", 0))
    timeout = int(ragas_cfg.get("timeout_seconds", 600))

    print("Diagnosing Ragas generator LLM compatibility\n")
    print(f"model={model}")
    print(f"base_url={base_url}")
    print(f"max_tokens={max_tokens}")
    print(f"temperature={temperature}\n")

    # Probe 1: trivial plain-text completion against the raw endpoint.
    plain_result = run_raw_chat(
        title="plain_text",
        base_url=base_url,
        api_key=api_key,
        model=model,
        messages=[{"role": "user", "content": "Reply with exactly: OK"}],
        max_tokens=min(max_tokens, 256),
        temperature=temperature,
        timeout=timeout,
    )

    json_prompt = (
        "Return only valid JSON, with no markdown and no extra text. "
        'The JSON schema is {"query": "string", "answer": "string"}. '
        "Use this content: 合同约定乙方享有作品的著作权,甲方应按约支付费用。"
    )
    # Probe 2: JSON-only prompt against the raw endpoint, then parse the reply.
    json_result = run_raw_chat(
        title="raw_json_contract",
        base_url=base_url,
        api_key=api_key,
        model=model,
        messages=[{"role": "user", "content": json_prompt}],
        max_tokens=max_tokens,
        temperature=temperature,
        timeout=timeout,
    )
    validate_json_payload(json_result.get("content") or "")

    # Probe 3: same JSON prompt through LangChain, inspecting finish metadata.
    langchain_result = run_langchain_probe(
        base_url=base_url,
        api_key=api_key,
        model=model,
        prompt=json_prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        timeout=timeout,
    )

    # Probe 4: a real RAGAS structured prompt end to end.
    run_ragas_prompt_probe(
        base_url=base_url,
        api_key=api_key,
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        timeout=timeout,
    )

    explain_result(plain_result, json_result, langchain_result)
    return 0
def run_raw_chat(
    *,
    title: str,
    base_url: str,
    api_key: str,
    model: str,
    messages: list[dict[str, str]],
    max_tokens: int,
    temperature: float,
    timeout: int,
) -> dict[str, Any]:
    """Send one raw OpenAI-compatible ``/chat/completions`` request and report it.

    Returns a dict:
      ok            -- True only when the server answered < 400 with a JSON body
      status        -- HTTP status on failure (None when the request never completed)
      finish_reason -- echoed from the first choice (success only)
      usage         -- echoed token usage (success only)
      content       -- assistant message content ("" on any failure)
    """
    print(f"[RAW] {title}")
    started = time.monotonic()
    try:
        response = requests.post(
            base_url + "/chat/completions",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
            },
            timeout=timeout,
        )
    except requests.RequestException as exc:
        # A diagnostic tool should report transport failures (DNS, refused
        # connection, read timeout) instead of dying with a traceback.
        elapsed = time.monotonic() - started
        print(f"request_failed={type(exc).__name__}: {exc} elapsed={elapsed:.2f}s\n")
        return {"ok": False, "status": None, "content": ""}
    elapsed = time.monotonic() - started
    print(f"status={response.status_code} elapsed={elapsed:.2f}s")
    if response.status_code >= 400:
        print(response.text[:1000])
        return {"ok": False, "status": response.status_code, "content": ""}
    try:
        payload = response.json()
    except ValueError as exc:
        # Some gateways return HTML/plain-text error pages with a 200 status.
        print(f"non_json_body={exc}")
        print(f"body_preview={response.text[:500]!r}\n")
        return {"ok": False, "status": response.status_code, "content": ""}
    choice = (payload.get("choices") or [{}])[0]
    message = choice.get("message") or {}
    content = str(message.get("content") or "")
    finish_reason = choice.get("finish_reason")
    usage = payload.get("usage")
    print(f"finish_reason={finish_reason!r}")
    print(f"usage={usage}")
    print(f"content_len={len(content)}")
    print(f"content_preview={content[:500]!r}\n")
    return {
        "ok": True,
        "finish_reason": finish_reason,
        "usage": usage,
        "content": content,
    }
def validate_json_payload(content: str) -> None:
    """Report whether *content* is a bare JSON object matching ``SimpleQA``."""
    print("[PARSE] raw_json_contract")
    candidate = content.strip()
    try:
        parsed = SimpleQA.model_validate(json.loads(candidate))
    except (json.JSONDecodeError, ValidationError) as exc:
        print(f"json_or_pydantic_parse=failed: {exc}")
        print("This indicates structured-output incompatibility or extra non-JSON text.\n")
        return
    print("json_parse=ok")
    print(f"pydantic_parse=ok query={parsed.query[:80]!r}\n")
def run_langchain_probe(
    *,
    base_url: str,
    api_key: str,
    model: str,
    prompt: str,
    max_tokens: int,
    temperature: float,
    timeout: int,
) -> dict[str, Any]:
    """Send *prompt* through LangChain's ChatOpenAI and inspect finish metadata."""
    print("[LANGCHAIN] generation metadata")
    chat = ChatOpenAI(
        model=model,
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=timeout,
    )
    outcome = chat.generate_prompt([StringPromptValue(text=prompt)])
    first = outcome.generations[0][0]
    gen_info = getattr(first, "generation_info", None)
    message = getattr(first, "message", None)
    resp_meta = getattr(message, "response_metadata", None)
    # Ask the RAGAS wrapper whether it would consider this generation finished.
    finished = LangchainLLMWrapper(chat).is_finished(outcome)
    print(f"generation_info={gen_info}")
    print(f"response_metadata={resp_meta}")
    print(f"ragas_is_finished={finished}")
    print(f"text_preview={first.text[:500]!r}\n")
    return {
        "generation_info": gen_info,
        "response_metadata": resp_meta,
        "ragas_is_finished": finished,
        "content": first.text,
    }
def run_ragas_prompt_probe(
    *,
    base_url: str,
    api_key: str,
    model: str,
    max_tokens: int,
    temperature: float,
    timeout: int,
) -> None:
    """Drive RAGAS's QueryAnswerGenerationPrompt end to end against the model."""
    import asyncio

    print("[RAGAS] QueryAnswerGenerationPrompt")
    chat = ChatOpenAI(
        model=model,
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=timeout,
    )
    wrapped = LangchainLLMWrapper(chat)
    wrapped.set_run_config(RunConfig(timeout=timeout, max_workers=1))
    probe_prompt = QueryAnswerGenerationPrompt()
    condition = QueryCondition(
        persona=Persona(
            name="合同审核人员",
            role_description="关注合同条款、权利归属、授权范围和履约义务。",
        ),
        term="著作权",
        query_style="Perfect grammar",
        query_length="medium",
        context="合同约定乙方享有作品的著作权,甲方应在收到发票后30日内支付授权费用。",
    )
    try:
        generated: GeneratedQueryAnswer = asyncio.run(
            probe_prompt.generate(llm=wrapped, data=condition)
        )
        print(f"ragas_prompt=ok query={generated.query[:120]!r}")
        print(f"answer={generated.answer[:120]!r}\n")
    except Exception as exc:  # noqa: BLE001 - probe must report, never crash
        print(f"ragas_prompt=failed: {type(exc).__name__}: {exc}\n")
def explain_result(
    plain: dict[str, Any],
    structured: dict[str, Any],
    langchain_result: dict[str, Any],
) -> None:
    """Summarize the three probes into a human-readable diagnosis on stdout."""
    print("[DIAGNOSIS]")
    finish = structured.get("finish_reason")
    wrapper_finished = langchain_result.get("ragas_is_finished")
    body = structured.get("content") or ""
    # Pick exactly one primary diagnosis, most specific cause first.
    if finish == "length":
        notes = [
            "- Raw vLLM finish_reason is 'length': output hit max_tokens.",
            "- Lower TESTSET_MAX_DOCUMENT_CHARS or increase TESTSET_GENERATOR_MAX_TOKENS moderately.",
        ]
    elif finish not in {"stop", "STOP", "MAX_TOKENS", "eos_token"}:
        notes = [
            "- Raw vLLM finish_reason is not one of the values RAGAS 0.4.3 accepts.",
            "- This is finish_reason compatibility, not necessarily context length.",
        ]
    elif wrapper_finished is False:
        notes = [
            "- Raw finish_reason looks acceptable, but RAGAS/LangChain metadata is not accepted.",
            "- This points to LangChain/RAGAS finish_reason adapter compatibility.",
        ]
    elif not body.strip().startswith("{"):
        notes = [
            "- The model did not return pure JSON for a JSON-only prompt.",
            "- This points to structured-output incompatibility, often caused by thinking text or markdown.",
        ]
    else:
        notes = [
            "- Basic raw and LangChain completion checks look compatible.",
            "- If RAGAS still fails, inspect the RAGAS prompt output preview above for invalid JSON or truncation.",
        ]
    for line in notes:
        print(line)
    # Independent secondary hint: plain text succeeded while structured did not.
    if plain.get("finish_reason") == "stop" and structured.get("finish_reason") != "stop":
        print("- Plain text works but structured output differs; focus on JSON/structured output settings.")
def require_value(config: dict[str, Any], key: str) -> str:
    """Fetch a required value from the ``ragas`` config section as a string.

    Raises:
        ValueError: if *key* is absent, ``None``, or an empty string.
    """
    value = config.get(key)
    # `is None or == ""` rather than `value in {None, ""}`: set membership
    # hashes the value and would raise TypeError for unhashable config values
    # (e.g. a YAML list) instead of handling them.
    if value is None or value == "":
        raise ValueError(f"Missing required config value: ragas.{key}")
    return str(value)
if __name__ == "__main__":
sys.exit(main())