# 00_check_models.py (5.68 KB)
from __future__ import annotations

import sys
import time
from typing import Any

import _bootstrap  # noqa: F401

import requests
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from weknora_eval.config import load_config
from weknora_eval.llm_options import chat_openai_kwargs


def main() -> int:
    """Probe every configured Ragas model service and report the outcome.

    Reads the ``ragas`` section of the evaluation config and checks, in order,
    the generator LLM, the judge LLM, the embedding model, and -- only when both
    ``reranker_base_url`` and ``reranker_model`` are set -- the reranker.

    A missing or malformed setting for one service is recorded as a failure for
    that service instead of aborting the run with a traceback. The remaining
    services are therefore still probed, and every problem is listed in one run.

    Returns:
        Process exit code: ``0`` when every check passed, ``1`` otherwise.
    """
    config = load_config()
    ragas = config["ragas"]
    failures: list[str] = []

    print("Checking configured Ragas model services...\n")

    # Generator and judge share the same endpoint and credentials; they differ
    # only in which model name is read from the config.
    for title, model_key in (
        ("Generator LLM", "generator_model"),
        ("Judge LLM", "judge_model"),
    ):
        try:
            base_url = require_value(ragas, "llm_base_url")
            api_key = require_value(ragas, "llm_api_key")
            model = require_value(ragas, model_key)
            temperature = float(ragas.get("temperature", 0))
            # Cap the completion budget at 1024: the probe only needs a short reply.
            max_tokens = min(int(ragas.get("max_tokens", 1024)), 1024)
        except (TypeError, ValueError) as exc:
            # Missing key (require_value) or a value float()/int() cannot parse.
            failures.append(f"{title} config error: {exc}")
            continue
        failures.extend(
            check_chat_model(
                title=title,
                base_url=base_url,
                api_key=api_key,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
                extra_kwargs=chat_openai_kwargs(ragas),
            )
        )

    try:
        embedding_base_url = require_value(ragas, "embedding_base_url")
        embedding_api_key = require_value(ragas, "embedding_api_key")
        embedding_model = require_value(ragas, "embedding_model")
    except ValueError as exc:
        failures.append(f"Embedding config error: {exc}")
    else:
        failures.extend(
            check_embedding_model(
                base_url=embedding_base_url,
                api_key=embedding_api_key,
                model=embedding_model,
            )
        )

    # The reranker is optional: probe it only when it is fully configured.
    reranker_base_url = str(ragas.get("reranker_base_url") or "")
    reranker_model = str(ragas.get("reranker_model") or "")
    if reranker_base_url and reranker_model:
        failures.extend(
            check_reranker_model(
                base_url=reranker_base_url,
                api_key=str(ragas.get("reranker_api_key") or ""),
                model=reranker_model,
            )
        )
    else:
        print("[SKIP] Reranker: RAGAS_RERANKER_BASE_URL or RAGAS_RERANKER_MODEL is empty\n")

    if failures:
        print("Model service check failed:")
        for failure in failures:
            print(f"- {failure}")
        return 1

    print("All configured model services are reachable.")
    return 0


def check_chat_model(
    *,
    title: str,
    base_url: str,
    api_key: str,
    model: str,
    temperature: float,
    max_tokens: int,
    extra_kwargs: dict[str, Any],
) -> list[str]:
    """Send a one-line prompt to an OpenAI-compatible chat endpoint.

    Prints a ``[CHECK]`` line before the request. On success it also prints an
    ``[OK]`` line with the elapsed wall time and the first 80 characters of the
    reply.

    Args:
        title: Human-readable label used in every printed/returned message.
        base_url: Endpoint base URL passed to ``ChatOpenAI``.
        api_key: API key passed to ``ChatOpenAI``.
        model: Model name to invoke.
        temperature: Sampling temperature for the probe.
        max_tokens: Completion token limit for the probe.
        extra_kwargs: Additional keyword arguments forwarded to ``ChatOpenAI``.

    Returns:
        An empty list on success. Otherwise a single-element list describing
        the failure: a blank reply, or any exception raised while building or
        invoking the client.
    """
    print(f"[CHECK] {title}: model={model} base_url={base_url}")
    t0 = time.monotonic()
    try:
        client = ChatOpenAI(
            model=model,
            api_key=api_key,
            base_url=base_url,
            temperature=temperature,
            max_tokens=max_tokens,
            timeout=120,
            **extra_kwargs,
        )
        reply = client.invoke("Reply with exactly: OK")
        # ``content`` may be falsy; normalise to a stripped string either way.
        text = str(reply.content or "").strip()
        duration = time.monotonic() - t0
        if text:
            print(f"[OK] {title}: {duration:.2f}s response={text[:80]!r}\n")
            return []
        return [f"{title} returned an empty response"]
    except Exception as exc:  # noqa: BLE001
        return [f"{title} failed: {exc}"]


def check_embedding_model(*, base_url: str, api_key: str, model: str) -> list[str]:
    """Embed a short probe string through an OpenAI-compatible embeddings API.

    Prints a ``[CHECK]`` line before the request. On success it also prints an
    ``[OK]`` line with the elapsed wall time, the vector dimensionality and the
    first three components.

    Returns:
        An empty list on success. Otherwise a single-element list describing
        the failure: an empty vector, or any exception raised while building
        or calling the client.
    """
    print(f"[CHECK] Embedding: model={model} base_url={base_url}")
    t0 = time.monotonic()
    try:
        client = OpenAIEmbeddings(
            model=model,
            api_key=api_key,
            base_url=base_url,
            # Skip tiktoken tokenisation and context-length splitting so the raw
            # text is sent as-is (presumably because the backend may not be an
            # OpenAI model tiktoken knows -- confirm).
            tiktoken_enabled=False,
            check_embedding_ctx_length=False,
            request_timeout=120,
        )
        vec = client.embed_query("hello")
        duration = time.monotonic() - t0
        if vec:
            print(f"[OK] Embedding: {duration:.2f}s dimensions={len(vec)} first3={vec[:3]}\n")
            return []
        return ["Embedding returned an empty vector"]
    except Exception as exc:  # noqa: BLE001
        return [f"Embedding failed: {exc}"]


def check_reranker_model(*, base_url: str, api_key: str, model: str) -> list[str]:
    """POST a small rerank request to ``<base_url>/rerank`` and validate the reply.

    The request body carries ``model``, ``query`` and ``documents`` fields. A
    bearer token is sent only when *api_key* is non-empty.

    Two reply shapes are accepted:

    * a JSON object whose ``results`` or ``data`` key holds a non-empty list;
    * a non-empty top-level JSON array.

    The array form is accepted defensively, since some servers may reply that
    way. Previously it crashed on ``.get``/``.keys``.

    Returns:
        An empty list on success. Otherwise a single-element list describing
        the failure: transport error, HTTP error status, non-JSON body, or a
        body with no recognisable results.
    """
    print(f"[CHECK] Reranker: model={model} base_url={base_url}")
    url = base_url.rstrip("/") + "/rerank"
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    request_body = {
        "model": model,
        # "What is the payment term?"
        "query": "付款期限是什么?",
        "documents": [
            # Relevant: "The buyer shall pay within 30 days of receiving a valid invoice."
            "买方应在收到合法有效发票后30日内完成付款。",
            # Irrelevant: "This contract takes effect when both parties sign and seal it."
            "本合同自双方签字盖章之日起生效。",
        ],
    }
    started = time.monotonic()
    try:
        response = requests.post(url, headers=headers, json=request_body, timeout=120)
        elapsed = time.monotonic() - started
        if response.status_code >= 400:
            return [f"Reranker failed with HTTP {response.status_code}: {response.text[:500]}"]
        try:
            body = response.json()
        except ValueError:
            # requests' JSON decode errors subclass ValueError.
            return [f"Reranker returned a non-JSON response: {response.text[:500]}"]

        if isinstance(body, list):
            has_results = bool(body)
            summary = f"results={len(body)}"
        elif isinstance(body, dict):
            has_results = _has_rerank_results(body)
            summary = f"response_keys={list(body.keys())}"
        else:
            has_results = False
            summary = ""

        if not has_results:
            # Truncate like the HTTP-error branch so a large body cannot flood output.
            return [f"Reranker returned no recognizable results: {str(body)[:500]}"]
        print(f"[OK] Reranker: {elapsed:.2f}s {summary}\n")
        return []
    except Exception as exc:  # noqa: BLE001
        return [f"Reranker failed: {exc}"]


def _has_rerank_results(payload: dict[str, Any]) -> bool:
    for key in ("results", "data"):
        if isinstance(payload.get(key), list) and payload[key]:
            return True
    return False


def require_value(config: dict[str, Any], key: str) -> str:
    """Return ``config[key]`` converted to ``str``, failing loudly if it is unset.

    Args:
        config: The ``ragas`` config section to read from.
        key: Name of the setting inside that section.

    Returns:
        ``str(config[key])``.

    Raises:
        ValueError: If *key* is absent, maps to ``None``, or maps to an empty
            or whitespace-only string.
    """
    value = config.get(key)
    # Compare explicitly instead of ``value in {None, ""}``: set membership hashes
    # ``value`` and raises TypeError for unhashable config values (lists, dicts).
    if value is None or (isinstance(value, str) and not value.strip()):
        raise ValueError(f"Missing required config value: ragas.{key}")
    return str(value)


if __name__ == "__main__":
    # Use main()'s return value (0 = all services OK, 1 = failures) as the exit status.
    exit_code = main()
    sys.exit(exit_code)