Commit 128018acbf2d5c9fc9a757aec3069a64628a22bf by 沈秋雨

Add configured model connectivity check

1 parent f7777e43
...@@ -49,6 +49,7 @@ cp .env.example .env ...@@ -49,6 +49,7 @@ cp .env.example .env
49 49
50 ```bash 50 ```bash
51 python scripts/00_create_kb.py 51 python scripts/00_create_kb.py
52 python scripts/00_check_models.py
52 python scripts/01_upload_docs.py 53 python scripts/01_upload_docs.py
53 python scripts/02_wait_ingestion.py 54 python scripts/02_wait_ingestion.py
54 python scripts/03_export_chunks.py 55 python scripts/03_export_chunks.py
......
1 from __future__ import annotations
2
3 import sys
4 import time
5 from typing import Any
6
7 import _bootstrap # noqa: F401
8
9 import requests
10 from langchain_openai import ChatOpenAI, OpenAIEmbeddings
11
12 from weknora_eval.config import load_config
13
14
def main() -> int:
    """Probe every configured Ragas model endpoint and report reachability.

    Returns 0 when all configured services answered, 1 when any check failed.
    The reranker check is optional and skipped when its config is empty.
    """
    ragas = load_config()["ragas"]
    problems: list[str] = []

    print("Checking configured Ragas model services...\n")

    # Generator and judge share the same endpoint/credentials and sampling
    # settings; only the model name differs between the two chat checks.
    shared_chat_kwargs = dict(
        base_url=require_value(ragas, "llm_base_url"),
        api_key=require_value(ragas, "llm_api_key"),
        temperature=float(ragas.get("temperature", 0)),
        max_tokens=min(int(ragas.get("max_tokens", 1024)), 1024),
    )
    for title, model_key in (("Generator LLM", "generator_model"), ("Judge LLM", "judge_model")):
        problems += check_chat_model(
            title=title,
            model=require_value(ragas, model_key),
            **shared_chat_kwargs,
        )

    problems += check_embedding_model(
        base_url=require_value(ragas, "embedding_base_url"),
        api_key=require_value(ragas, "embedding_api_key"),
        model=require_value(ragas, "embedding_model"),
    )

    reranker_base_url = str(ragas.get("reranker_base_url") or "")
    reranker_model = str(ragas.get("reranker_model") or "")
    if not (reranker_base_url and reranker_model):
        # The reranker is optional; both values must be present to probe it.
        print("[SKIP] Reranker: RAGAS_RERANKER_BASE_URL or RAGAS_RERANKER_MODEL is empty\n")
    else:
        problems += check_reranker_model(
            base_url=reranker_base_url,
            api_key=str(ragas.get("reranker_api_key") or ""),
            model=reranker_model,
        )

    if problems:
        print("Model service check failed:")
        for problem in problems:
            print(f"- {problem}")
        return 1

    print("All configured model services are reachable.")
    return 0
72
def check_chat_model(
    *,
    title: str,
    base_url: str,
    api_key: str,
    model: str,
    temperature: float,
    max_tokens: int,
) -> list[str]:
    """Send a one-shot probe prompt to an OpenAI-compatible chat endpoint.

    Returns an empty list on success, or a single-item list with a failure
    description. Never raises: all exceptions are folded into the result.
    """
    print(f"[CHECK] {title}: model={model} base_url={base_url}")
    started = time.monotonic()
    try:
        client = ChatOpenAI(
            model=model,
            api_key=api_key,
            base_url=base_url,
            temperature=temperature,
            max_tokens=max_tokens,
            timeout=120,
        )
        text = str(client.invoke("Reply with exactly: OK").content or "").strip()
        elapsed = time.monotonic() - started
    except Exception as exc:  # noqa: BLE001
        return [f"{title} failed: {exc}"]
    # A 2xx response with no content still counts as a failure.
    if not text:
        return [f"{title} returned an empty response"]
    print(f"[OK] {title}: {elapsed:.2f}s response={text[:80]!r}\n")
    return []
103
def check_embedding_model(*, base_url: str, api_key: str, model: str) -> list[str]:
    """Embed a short probe string against an OpenAI-compatible endpoint.

    Returns an empty list on success, or a single-item list describing the
    failure. Never raises: all exceptions are folded into the result.
    """
    print(f"[CHECK] Embedding: model={model} base_url={base_url}")
    started = time.monotonic()
    try:
        client = OpenAIEmbeddings(
            model=model,
            api_key=api_key,
            base_url=base_url,
            # Skip client-side tokenization/length checks so non-OpenAI
            # models behind a compatible gateway are accepted as-is.
            tiktoken_enabled=False,
            check_embedding_ctx_length=False,
            request_timeout=120,
        )
        vector = client.embed_query("hello")
        elapsed = time.monotonic() - started
    except Exception as exc:  # noqa: BLE001
        return [f"Embedding failed: {exc}"]
    if not vector:
        return ["Embedding returned an empty vector"]
    print(f"[OK] Embedding: {elapsed:.2f}s dimensions={len(vector)} first3={vector[:3]}\n")
    return []
125
def check_reranker_model(*, base_url: str, api_key: str, model: str) -> list[str]:
    """POST a two-document rerank probe to a Jina/Cohere-style /rerank endpoint.

    Returns an empty list on success, or a single-item list describing the
    failure. Never raises: all exceptions are folded into the result.
    """
    print(f"[CHECK] Reranker: model={model} base_url={base_url}")
    url = base_url.rstrip("/") + "/rerank"
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    # Fix: keep the request body and the parsed response in distinct names —
    # the original reused one `payload` variable for both, which obscured
    # which object each later line referred to.
    request_body = {
        "model": model,
        "query": "付款期限是什么?",
        "documents": [
            "买方应在收到合法有效发票后30日内完成付款。",
            "本合同自双方签字盖章之日起生效。",
        ],
    }
    started = time.monotonic()
    try:
        response = requests.post(url, headers=headers, json=request_body, timeout=120)
        elapsed = time.monotonic() - started
        if response.status_code >= 400:
            return [f"Reranker failed with HTTP {response.status_code}: {response.text[:500]}"]
        result = response.json()
        # Guard against non-dict JSON bodies (e.g. a bare list): previously
        # _has_rerank_results would hit the broad except with an opaque
        # AttributeError instead of this descriptive message.
        if not isinstance(result, dict) or not _has_rerank_results(result):
            return [f"Reranker returned no recognizable results: {result}"]
        print(f"[OK] Reranker: {elapsed:.2f}s response_keys={list(result.keys())}\n")
        return []
    except Exception as exc:  # noqa: BLE001
        return [f"Reranker failed: {exc}"]
154
155 def _has_rerank_results(payload: dict[str, Any]) -> bool:
156 for key in ("results", "data"):
157 if isinstance(payload.get(key), list) and payload[key]:
158 return True
159 return False
160
161
def require_value(config: dict[str, Any], key: str) -> str:
    """Fetch ``config[key]`` coerced to ``str``, rejecting missing/empty values.

    Args:
        config: The ``ragas`` section of the loaded configuration.
        key: Config key to look up (reported as ``ragas.<key>`` on failure).

    Raises:
        ValueError: If the key is absent, ``None``, or the empty string.
    """
    value = config.get(key)
    # Fix: the original tested `value in {None, ""}` — set membership hashes
    # the value, so an unhashable config value (e.g. a list) raised an
    # unrelated TypeError instead of being stringified and returned.
    if value is None or value == "":
        raise ValueError(f"Missing required config value: ragas.{key}")
    return str(value)
168
# Script entry point: propagate main()'s status code (0 = all checks passed,
# 1 = at least one model service failed) to the shell.
if __name__ == "__main__":
    sys.exit(main())