# 00_diagnose_ragas_llm.py: diagnose RAGAS generator-LLM compatibility (raw HTTP, LangChain and RAGAS prompt probes).
from __future__ import annotations

import json
import sys
import time
from typing import Any

import _bootstrap  # noqa: F401

import requests
from langchain_core.prompt_values import StringPromptValue
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, ValidationError
from ragas.llms import LangchainLLMWrapper
from ragas.run_config import RunConfig
from ragas.testset.persona import Persona
from ragas.testset.synthesizers.single_hop.prompts import (
    GeneratedQueryAnswer,
    QueryAnswerGenerationPrompt,
    QueryCondition,
)

from weknora_eval.config import load_config
from weknora_eval.llm_options import chat_extra_body, chat_openai_kwargs


# Minimal query/answer schema that validate_json_payload uses to check the
# output of the raw JSON-only probe. Documented with comments rather than a
# docstring: pydantic copies a class docstring into the model's JSON-schema
# "description", so comments keep the schema unchanged.
class SimpleQA(BaseModel):
    query: str  # the "query" string requested by the JSON-only prompt
    answer: str  # the "answer" string requested by the JSON-only prompt


def main() -> int:
    """Run every compatibility probe against the configured RAGAS generator model.

    Reads the ``ragas`` section (required) and the ``testset`` section
    (optional) of the loaded config, then runs, in order:

    1. a plain-text raw ``/chat/completions`` request,
    2. a JSON-only raw request plus local JSON/pydantic validation,
    3. a LangChain ``ChatOpenAI`` probe (finish metadata + RAGAS ``is_finished``),
    4. RAGAS' own ``QueryAnswerGenerationPrompt``,

    and finally prints a heuristic diagnosis.

    Returns:
        The process exit code; always 0 because findings are reported on stdout.

    Raises:
        KeyError: if the config has no ``ragas`` section.
        ValueError: if a required ``ragas.*`` value is missing.
    """
    config = load_config()
    ragas = config["ragas"]
    # `or {}` also covers a "testset" section that is present but null.
    testset = config.get("testset") or {}
    base_url = require_value(ragas, "llm_base_url").rstrip("/")
    api_key = require_value(ragas, "llm_api_key")
    model = require_value(ragas, "generator_model")
    max_tokens = int(testset.get("generator_max_tokens", ragas.get("max_tokens", 4096)))
    temperature = float(ragas.get("temperature", 0))
    timeout = int(ragas.get("timeout_seconds", 600))
    extra_body = chat_extra_body(ragas)

    print("Diagnosing Ragas generator LLM compatibility\n")
    print(f"model={model}")
    print(f"base_url={base_url}")
    print(f"max_tokens={max_tokens}")
    print(f"temperature={temperature}\n")

    plain = run_raw_chat(
        title="plain_text",
        base_url=base_url,
        api_key=api_key,
        model=model,
        messages=[{"role": "user", "content": "Reply with exactly: OK"}],
        # A short reply is expected, so cap the budget for this probe.
        max_tokens=min(max_tokens, 256),
        temperature=temperature,
        timeout=timeout,
        extra_body=extra_body,
    )

    json_prompt = (
        "Return only valid JSON, with no markdown and no extra text. "
        'The JSON schema is {"query": "string", "answer": "string"}. '
        "Use this content: 合同约定乙方享有作品的著作权,甲方应按约支付费用。"
    )
    structured = run_raw_chat(
        title="raw_json_contract",
        base_url=base_url,
        api_key=api_key,
        model=model,
        messages=[{"role": "user", "content": json_prompt}],
        max_tokens=max_tokens,
        temperature=temperature,
        timeout=timeout,
        extra_body=extra_body,
    )
    validate_json_payload(structured.get("content") or "")

    # Built once and shared by both LangChain-based probes.
    openai_kwargs = chat_openai_kwargs(ragas)

    langchain_result = run_langchain_probe(
        base_url=base_url,
        api_key=api_key,
        model=model,
        prompt=json_prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        timeout=timeout,
        extra_kwargs=openai_kwargs,
    )

    run_ragas_prompt_probe(
        base_url=base_url,
        api_key=api_key,
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        timeout=timeout,
        extra_kwargs=openai_kwargs,
    )

    explain_result(plain, structured, langchain_result)
    return 0


def run_raw_chat(
    *,
    title: str,
    base_url: str,
    api_key: str,
    model: str,
    messages: list[dict[str, str]],
    max_tokens: int,
    temperature: float,
    timeout: int,
    extra_body: dict[str, Any],
) -> dict[str, Any]:
    """Send one OpenAI-compatible ``/chat/completions`` request and print diagnostics.

    The request bypasses LangChain and RAGAS, so the server's raw status,
    ``finish_reason``, ``usage`` and content can be inspected directly.
    ``extra_body`` is merged into the JSON payload last, so its keys override
    the standard fields of the same name.

    Returns:
        A dict that always has ``ok`` (bool) and ``content`` (str).
        - On success it also has ``finish_reason`` and ``usage``.
        - On failure it has ``status``: the HTTP status code, or ``None`` when no
          response was received.
        - Transport and decoding failures also carry ``error``.
        Failures are reported instead of raised, so the remaining probes can
        still run.
    """
    print(f"[RAW] {title}")
    body = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        # Spread last: extra_body keys override the standard fields above.
        **extra_body,
    }
    started = time.monotonic()
    try:
        response = requests.post(
            base_url + "/chat/completions",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            json=body,
            timeout=timeout,
        )
    except requests.RequestException as exc:
        # Connection errors / timeouts would otherwise abort every later probe.
        elapsed = time.monotonic() - started
        error = f"{type(exc).__name__}: {exc}"
        print(f"request_failed elapsed={elapsed:.2f}s {error}\n")
        return {"ok": False, "status": None, "content": "", "error": error}
    elapsed = time.monotonic() - started
    print(f"status={response.status_code} elapsed={elapsed:.2f}s")
    if response.status_code >= 400:
        print(response.text[:1000])
        return {"ok": False, "status": response.status_code, "content": ""}

    try:
        payload = response.json()
    except ValueError as exc:
        # A success status with a non-JSON body (e.g. an HTML page from an
        # intermediate proxy) must not crash the whole diagnosis.
        print(f"invalid_json_response: {exc}")
        print(f"body_preview={response.text[:1000]!r}\n")
        return {
            "ok": False,
            "status": response.status_code,
            "content": "",
            "error": f"invalid JSON response: {exc}",
        }
    if not isinstance(payload, dict):
        print(f"unexpected_response_type={type(payload).__name__}\n")
        return {
            "ok": False,
            "status": response.status_code,
            "content": "",
            "error": "response body is not a JSON object",
        }

    # A missing/empty "choices" list degrades to an empty choice, not a crash.
    choice = (payload.get("choices") or [{}])[0]
    message = choice.get("message") or {}
    content = str(message.get("content") or "")
    finish_reason = choice.get("finish_reason")
    usage = payload.get("usage")
    print(f"finish_reason={finish_reason!r}")
    print(f"usage={usage}")
    print(f"content_len={len(content)}")
    print(f"content_preview={content[:500]!r}\n")
    return {
        "ok": True,
        "finish_reason": finish_reason,
        "usage": usage,
        "content": content,
    }


def validate_json_payload(content: str) -> None:
    """Check that *content* is a bare JSON object matching ``SimpleQA``.

    Only surrounding whitespace is tolerated. Markdown fences or any other
    extra text make the check fail, which is what this probe is meant to
    reveal. Results are printed; nothing is returned or raised.
    """
    print("[PARSE] raw_json_contract")
    stripped = content.strip()
    if not stripped:
        # Empty content usually means the raw request failed or the model
        # produced nothing; parsing "" would report a misleading JSON error.
        print("json_or_pydantic_parse=skipped: no content to parse.\n")
        return
    try:
        data = json.loads(stripped)
        parsed = SimpleQA.model_validate(data)
        print("json_parse=ok")
        print(f"pydantic_parse=ok query={parsed.query[:80]!r}\n")
    except (json.JSONDecodeError, ValidationError) as exc:
        print(f"json_or_pydantic_parse=failed: {exc}")
        print("This indicates structured-output incompatibility or extra non-JSON text.\n")


def run_langchain_probe(
    *,
    base_url: str,
    api_key: str,
    model: str,
    prompt: str,
    max_tokens: int,
    temperature: float,
    timeout: int,
    extra_kwargs: dict[str, Any],
) -> dict[str, Any]:
    """Send *prompt* through LangChain ``ChatOpenAI`` and report finish metadata.

    Prints three things:
    - the first generation's ``generation_info``,
    - its ``message.response_metadata``,
    - the verdict of RAGAS' ``LangchainLLMWrapper.is_finished`` on the result.

    The diagnosis step compares that verdict with the raw ``finish_reason``.

    Returns:
        A dict with ``generation_info``, ``response_metadata``,
        ``ragas_is_finished`` and ``content``.
        - If the LangChain call fails, both metadata fields are ``None``,
          ``content`` is empty and ``error`` holds the exception summary.
        - ``ragas_is_finished`` is ``None`` whenever it could not be determined.
        Errors are reported instead of raised so later probes still run.
    """
    print("[LANGCHAIN] generation metadata")
    try:
        llm = ChatOpenAI(
            model=model,
            api_key=api_key,
            base_url=base_url,
            temperature=temperature,
            max_tokens=max_tokens,
            timeout=timeout,
            **extra_kwargs,
        )
        result = llm.generate_prompt([StringPromptValue(text=prompt)])
    except Exception as exc:  # noqa: BLE001 - diagnostic boundary: report and continue
        error = f"{type(exc).__name__}: {exc}"
        print(f"langchain_probe=failed: {error}\n")
        return {
            "generation_info": None,
            "response_metadata": None,
            "ragas_is_finished": None,
            "content": "",
            "error": error,
        }

    generation = result.generations[0][0]
    metadata = getattr(generation, "generation_info", None)
    response_metadata = getattr(getattr(generation, "message", None), "response_metadata", None)
    try:
        is_finished = LangchainLLMWrapper(llm).is_finished(result)
    except Exception as exc:  # noqa: BLE001 - keep the metadata output below
        print(f"ragas_is_finished_error={type(exc).__name__}: {exc}")
        is_finished = None
    print(f"generation_info={metadata}")
    print(f"response_metadata={response_metadata}")
    print(f"ragas_is_finished={is_finished}")
    print(f"text_preview={generation.text[:500]!r}\n")
    return {
        "generation_info": metadata,
        "response_metadata": response_metadata,
        "ragas_is_finished": is_finished,
        "content": generation.text,
    }


def run_ragas_prompt_probe(
    *,
    base_url: str,
    api_key: str,
    model: str,
    max_tokens: int,
    temperature: float,
    timeout: int,
    extra_kwargs: dict[str, Any],
) -> None:
    """Exercise RAGAS' own ``QueryAnswerGenerationPrompt`` end to end.

    Builds a ``ChatOpenAI`` client and wraps it in ``LangchainLLMWrapper``
    with a single-worker ``RunConfig``. It then asks the prompt to generate a
    query/answer pair for a short contract-review condition.

    On success, the truncated query and answer are printed. Any exception
    raised while generating or printing is reported on stdout instead of
    propagating. Client construction sits outside that guard, so
    configuration errors still raise.
    """
    import asyncio

    print("[RAGAS] QueryAnswerGenerationPrompt")
    chat_model = ChatOpenAI(
        model=model,
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=timeout,
        **extra_kwargs,
    )
    wrapped_llm = LangchainLLMWrapper(chat_model)
    wrapped_llm.set_run_config(RunConfig(timeout=timeout, max_workers=1))
    qa_prompt = QueryAnswerGenerationPrompt()
    reviewer = Persona(
        name="合同审核人员",
        role_description="关注合同条款、权利归属、授权范围和履约义务。",
    )
    sample_condition = QueryCondition(
        persona=reviewer,
        term="著作权",
        query_style="Perfect grammar",
        query_length="medium",
        context="合同约定乙方享有作品的著作权,甲方应在收到发票后30日内支付授权费用。",
    )
    try:
        generated: GeneratedQueryAnswer = asyncio.run(
            qa_prompt.generate(llm=wrapped_llm, data=sample_condition)
        )
        print(f"ragas_prompt=ok query={generated.query[:120]!r}")
        print(f"answer={generated.answer[:120]!r}\n")
    except Exception as exc:  # noqa: BLE001
        print(f"ragas_prompt=failed: {type(exc).__name__}: {exc}\n")


def explain_result(
    plain: dict[str, Any],
    structured: dict[str, Any],
    langchain_result: dict[str, Any],
) -> None:
    """Print a heuristic diagnosis from the probe results.

    Args:
        plain: Result dict of the plain-text raw request (``run_raw_chat`` shape).
        structured: Result dict of the JSON-only raw request (same shape).
        langchain_result: Result dict of the LangChain probe; only
            ``ragas_is_finished`` is read.

    The main checks stop at the first match, from most to least specific.
    A plain-vs-structured note may follow.
    """
    print("[DIAGNOSIS]")
    if structured.get("ok") is False:
        # Without a completion, finish_reason is missing and would otherwise be
        # misreported as a value RAGAS does not accept.
        print("- The raw JSON request did not return a completion (see the status/error output above).")
        print("- Fix connectivity, credentials, model name or request parameters before judging compatibility.")
        if plain.get("ok") is False:
            print("- The plain-text request failed too, so the problem is not specific to structured output.")
        return

    raw_finish = structured.get("finish_reason")
    ragas_finished = langchain_result.get("ragas_is_finished")
    content = structured.get("content") or ""
    # finish_reason values the message below attributes to RAGAS 0.4.3.
    ragas_accepted_reasons = {"stop", "STOP", "MAX_TOKENS", "eos_token"}
    if raw_finish == "length":
        print("- Raw vLLM finish_reason is 'length': output hit max_tokens.")
        print("- Lower TESTSET_MAX_DOCUMENT_CHARS or increase TESTSET_GENERATOR_MAX_TOKENS moderately.")
    elif raw_finish not in ragas_accepted_reasons:
        print("- Raw vLLM finish_reason is not one of the values RAGAS 0.4.3 accepts.")
        print("- This is finish_reason compatibility, not necessarily context length.")
    elif ragas_finished is False:
        print("- Raw finish_reason looks acceptable, but RAGAS/LangChain metadata is not accepted.")
        print("- This points to LangChain/RAGAS finish_reason adapter compatibility.")
    elif not content.strip().startswith("{"):
        print("- The model did not return pure JSON for a JSON-only prompt.")
        print("- This points to structured-output incompatibility, often caused by thinking text or markdown.")
    else:
        print("- Basic raw and LangChain completion checks look compatible.")
        print("- If RAGAS still fails, inspect the RAGAS prompt output preview above for invalid JSON or truncation.")

    if plain.get("ok") is False:
        print("- The plain-text request failed while the JSON request succeeded; see its status/error output above.")
    elif plain.get("finish_reason") == "stop" and raw_finish != "stop":
        print("- Plain text works but structured output differs; focus on JSON/structured output settings.")


def require_value(config: dict[str, Any], key: str, *, section: str = "ragas") -> str:
    """Return ``config[key]`` as a string, failing loudly when it is missing.

    A value counts as missing when the key is absent, the value is ``None``,
    or it is a string that is empty after stripping whitespace. Any other
    value is converted with ``str()``.

    Args:
        config: One config section (a mapping of option name to value).
        key: Option name to look up.
        section: Section name used only in the error message; defaults to
            ``"ragas"`` to match the existing callers.

    Raises:
        ValueError: if the value is missing.
    """
    value = config.get(key)
    # Identity/isinstance checks instead of `value in {None, ""}`: set
    # membership hashes `value` and raises TypeError for lists/dicts.
    if value is None or (isinstance(value, str) and not value.strip()):
        raise ValueError(f"Missing required config value: {section}.{key}")
    return str(value)


if __name__ == "__main__":
    # SystemExit(code) is exactly what sys.exit(code) raises.
    raise SystemExit(main())