Add configured model connectivity check

Showing 2 changed files with 171 additions and 0 deletions.
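The new `scripts/00_check_models.py` probes each service named in the Ragas configuration (generator LLM, judge LLM, embedding model, and, if configured, the reranker), collecting every failure rather than stopping at the first, so a single run reports all misconfigured endpoints.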
First changed file: the pipeline run instructions pick up the new check between knowledge-base creation and document upload.

````diff
@@ -49,6 +49,7 @@ cp .env.example .env
 
 ```bash
 python scripts/00_create_kb.py
+python scripts/00_check_models.py
 python scripts/01_upload_docs.py
 python scripts/02_wait_ingestion.py
 python scripts/03_export_chunks.py
````
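Placing the check immediately after knowledge-base creation makes it a fail-fast guard: the script exits 0 only when every configured service answers, and exits 1 with a per-service failure list otherwise, so a bad endpoint or API key surfaces before any documents are uploaded.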
scripts/00_check_models.py (new file, mode 100644):
```python
from __future__ import annotations

import sys
import time
from typing import Any

# _bootstrap is imported for its side effects (it runs at import time).
import _bootstrap  # noqa: F401

import requests
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from weknora_eval.config import load_config


def main() -> int:
    config = load_config()
    ragas = config["ragas"]
    failures: list[str] = []

    print("Checking configured Ragas model services...\n")

    # Probe the generator and judge LLMs with a one-line prompt each.
    failures.extend(
        check_chat_model(
            title="Generator LLM",
            base_url=require_value(ragas, "llm_base_url"),
            api_key=require_value(ragas, "llm_api_key"),
            model=require_value(ragas, "generator_model"),
            temperature=float(ragas.get("temperature", 0)),
            max_tokens=min(int(ragas.get("max_tokens", 1024)), 1024),  # keep the probe cheap
        )
    )
    failures.extend(
        check_chat_model(
            title="Judge LLM",
            base_url=require_value(ragas, "llm_base_url"),
            api_key=require_value(ragas, "llm_api_key"),
            model=require_value(ragas, "judge_model"),
            temperature=float(ragas.get("temperature", 0)),
            max_tokens=min(int(ragas.get("max_tokens", 1024)), 1024),
        )
    )
    failures.extend(
        check_embedding_model(
            base_url=require_value(ragas, "embedding_base_url"),
            api_key=require_value(ragas, "embedding_api_key"),
            model=require_value(ragas, "embedding_model"),
        )
    )

    # The reranker is optional: probe it only when both base URL and model are set.
    reranker_base_url = str(ragas.get("reranker_base_url") or "")
    reranker_model = str(ragas.get("reranker_model") or "")
    if reranker_base_url and reranker_model:
        failures.extend(
            check_reranker_model(
                base_url=reranker_base_url,
                api_key=str(ragas.get("reranker_api_key") or ""),
                model=reranker_model,
            )
        )
    else:
        print("[SKIP] Reranker: RAGAS_RERANKER_BASE_URL or RAGAS_RERANKER_MODEL is empty\n")

    if failures:
        print("Model service check failed:")
        for failure in failures:
            print(f"- {failure}")
        return 1

    print("All configured model services are reachable.")
    return 0


def check_chat_model(
    *,
    title: str,
    base_url: str,
    api_key: str,
    model: str,
    temperature: float,
    max_tokens: int,
) -> list[str]:
    """Invoke a chat model once; return failure messages (empty list on success)."""
    print(f"[CHECK] {title}: model={model} base_url={base_url}")
    started = time.monotonic()
    try:
        llm = ChatOpenAI(
            model=model,
            api_key=api_key,
            base_url=base_url,
            temperature=temperature,
            max_tokens=max_tokens,
            timeout=120,
        )
        response = llm.invoke("Reply with exactly: OK")
        content = str(response.content or "").strip()
        elapsed = time.monotonic() - started
        if not content:
            return [f"{title} returned an empty response"]
        print(f"[OK] {title}: {elapsed:.2f}s response={content[:80]!r}\n")
        return []
    except Exception as exc:  # noqa: BLE001
        return [f"{title} failed: {exc}"]


def check_embedding_model(*, base_url: str, api_key: str, model: str) -> list[str]:
    """Embed a short query and verify a non-empty vector comes back."""
    print(f"[CHECK] Embedding: model={model} base_url={base_url}")
    started = time.monotonic()
    try:
        embeddings = OpenAIEmbeddings(
            model=model,
            api_key=api_key,
            base_url=base_url,
            # Skip client-side tokenization so non-OpenAI models are accepted as-is.
            tiktoken_enabled=False,
            check_embedding_ctx_length=False,
            request_timeout=120,
        )
        vector = embeddings.embed_query("hello")
        elapsed = time.monotonic() - started
        if not vector:
            return ["Embedding returned an empty vector"]
        print(f"[OK] Embedding: {elapsed:.2f}s dimensions={len(vector)} first3={vector[:3]}\n")
        return []
    except Exception as exc:  # noqa: BLE001
        return [f"Embedding failed: {exc}"]


def check_reranker_model(*, base_url: str, api_key: str, model: str) -> list[str]:
    """POST a two-document rerank request and verify the response carries results."""
    print(f"[CHECK] Reranker: model={model} base_url={base_url}")
    url = base_url.rstrip("/") + "/rerank"
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    payload = {
        "model": model,
        # Chinese sample data: the query asks "What is the payment deadline?";
        # the documents are a payment-terms clause and an effective-date clause.
        "query": "付款期限是什么?",
        "documents": [
            "买方应在收到合法有效发票后30日内完成付款。",
            "本合同自双方签字盖章之日起生效。",
        ],
    }
    started = time.monotonic()
    try:
        response = requests.post(url, headers=headers, json=payload, timeout=120)
        elapsed = time.monotonic() - started
        if response.status_code >= 400:
            return [f"Reranker failed with HTTP {response.status_code}: {response.text[:500]}"]
        # Use a separate name for the response JSON so the request payload is not shadowed.
        body = response.json()
        if not _has_rerank_results(body):
            return [f"Reranker returned no recognizable results: {body}"]
        print(f"[OK] Reranker: {elapsed:.2f}s response_keys={list(body.keys())}\n")
        return []
    except Exception as exc:  # noqa: BLE001
        return [f"Reranker failed: {exc}"]


def _has_rerank_results(payload: dict[str, Any]) -> bool:
    """True when the response carries a non-empty "results" or "data" list."""
    for key in ("results", "data"):
        if isinstance(payload.get(key), list) and payload[key]:
            return True
    return False


def require_value(config: dict[str, Any], key: str) -> str:
    value = config.get(key)
    if value in (None, ""):  # tuple membership avoids hashing, unlike a set literal
        raise ValueError(f"Missing required config value: ragas.{key}")
    return str(value)


if __name__ == "__main__":
    sys.exit(main())
```
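For reference, here is a minimal sketch of the `ragas` configuration section this script reads. Only the key names are taken from the script; every value shown is a placeholder, and the actual loading logic (including how `.env` entries map onto these keys) lives in `weknora_eval.config.load_config`.

```python
# Hypothetical ragas config as consumed by scripts/00_check_models.py.
# Key names match the script's require_value()/get() calls; values are placeholders.
ragas_example: dict[str, object] = {
    "llm_base_url": "http://localhost:8000/v1",  # OpenAI-compatible chat endpoint
    "llm_api_key": "sk-placeholder",
    "generator_model": "my-generator-model",
    "judge_model": "my-judge-model",
    "temperature": 0,
    "max_tokens": 1024,  # the probe caps this at 1024 regardless
    "embedding_base_url": "http://localhost:8001/v1",
    "embedding_api_key": "sk-placeholder",
    "embedding_model": "my-embedding-model",
    # Optional: leave these empty to skip the reranker probe. The script's skip
    # message names RAGAS_RERANKER_BASE_URL / RAGAS_RERANKER_MODEL as the
    # corresponding environment variables.
    "reranker_base_url": "",
    "reranker_api_key": "",
    "reranker_model": "",
}
```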
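The reranker check deliberately accepts two response shapes, presumably to cover rerank endpoints that return either a top-level `results` or `data` list. A standalone sketch of what `_has_rerank_results` accepts follows; the inner field names (`relevance_score`, `score`) are illustrative only, since the helper inspects nothing below the top-level key.

```python
from typing import Any


def has_rerank_results(payload: dict[str, Any]) -> bool:
    # Same logic as _has_rerank_results in scripts/00_check_models.py.
    return any(
        isinstance(payload.get(key), list) and payload[key]
        for key in ("results", "data")
    )


# A non-empty "results" list passes, as does a non-empty "data" list.
assert has_rerank_results({"results": [{"index": 0, "relevance_score": 0.97}]})
assert has_rerank_results({"data": [{"index": 0, "score": 0.97}]})
# Empty lists and unrelated keys count as "no recognizable results".
assert not has_rerank_results({"results": []})
assert not has_rerank_results({"detail": "ok"})
```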