api.py 12.3 KB
from __future__ import annotations

import logging
import time
from collections import Counter
from pathlib import Path
from typing import Any
from urllib.parse import urljoin

import requests

from weknora_eval.config import require_config
from weknora_eval.loaders import append_jsonl
from weknora_eval.sse import normalize_reference, parse_sse_events

logger = logging.getLogger(__name__)


class WeKnoraApiError(RuntimeError):
    """Raised when a WeKnora API call fails or returns an unusable response."""


class WeKnoraClient:
    """HTTP client for the WeKnora REST API used by the evaluation scripts.

    All calls share one ``requests.Session`` carrying the ``X-API-Key`` header.
    JSON endpoints go through ``_request``, which retries transport errors and
    HTTP 429/5xx. The streaming chat endpoint (``knowledge_chat_sse``) is not
    retried. Failed HTTP responses are appended to ``error_log_path`` as JSON
    lines.
    """

    # HTTP statuses treated as transient and retried by ``_request``.
    _RETRY_STATUS_CODES = frozenset({429, 500, 502, 503, 504})

    def __init__(
        self,
        *,
        base_url: str,
        api_key: str,
        knowledge_base_id: str,
        timeout_seconds: int = 300,
        request_interval_seconds: float = 0.2,
        error_log_path: str | Path = "data/runs/api_errors.jsonl",
        max_retries: int = 3,
    ) -> None:
        """Configure the client.

        Args:
            base_url: API root. A trailing slash is enforced so relative paths
                are joined beneath it rather than replacing its last segment.
            api_key: Sent as the ``X-API-Key`` header on every request.
            knowledge_base_id: Default knowledge base. May be empty for
                bootstrap flows that create one first.
            timeout_seconds: Per-request timeout passed to ``requests``. Also
                the default deadline for ``wait_ingestion_completed``.
            request_interval_seconds: Pause after each successful JSON request.
            error_log_path: JSONL file that receives failed HTTP responses.
            max_retries: Total attempts per JSON request (at least one is made).
        """
        self.base_url = base_url.rstrip("/") + "/"
        self.api_key = api_key
        self.knowledge_base_id = knowledge_base_id
        self.timeout_seconds = timeout_seconds
        self.request_interval_seconds = request_interval_seconds
        self.error_log_path = Path(error_log_path)
        self.max_retries = max_retries
        self.session = requests.Session()
        self.session.headers.update({"X-API-Key": api_key})

    def create_knowledge_base(self, *, name: str) -> dict[str, Any]:
        """POST ``knowledge-bases`` and return the decoded ``data`` payload."""
        return self._json_request("POST", "knowledge-bases", json={"name": name})

    def create_session(
        self,
        title: str,
        description: str = "Ragas evaluation session",
    ) -> dict[str, Any]:
        """POST ``sessions`` with a title/description and return the decoded payload."""
        payload = {"title": title, "description": description}
        return self._json_request("POST", "sessions", json=payload)

    def upload_file(self, file_path: str | Path, *, enable_multimodel: bool = False) -> dict[str, Any]:
        """Upload a local file into the configured knowledge base.

        The file is read into memory before sending. ``_request`` may retry the
        POST, and a file handle would already be at EOF on the second attempt,
        so the retry would silently upload an empty body.

        Raises:
            WeKnoraApiError: if no knowledge base id is configured or the
                request fails.
            OSError: if the file cannot be read.
        """
        self._ensure_knowledge_base_id()
        target = Path(file_path)
        files = {"file": (target.name, target.read_bytes())}
        data = {"enable_multimodel": str(enable_multimodel).lower()}
        return self._json_request(
            "POST",
            f"knowledge-bases/{self.knowledge_base_id}/knowledge/file",
            files=files,
            data=data,
        )

    def list_knowledge(self, *, page_size: int = 100) -> list[dict[str, Any]]:
        """Return every knowledge row in the configured knowledge base (all pages)."""
        self._ensure_knowledge_base_id()
        return self._paginate(
            f"knowledge-bases/{self.knowledge_base_id}/knowledge",
            page_size=page_size,
        )

    def wait_ingestion_completed(
        self,
        *,
        knowledge_ids: set[str] | None = None,
        timeout_seconds: int | None = None,
        poll_interval_seconds: float = 5.0,
    ) -> dict[str, list[dict[str, Any]]]:
        """Poll ``list_knowledge`` until ingestion finishes, fails, or times out.

        Args:
            knowledge_ids: Restrict the wait to these ids. When given, success
                also requires that every id has appeared in the listing. When
                empty or None, all rows in the knowledge base are considered.
            timeout_seconds: Overall deadline. ``None`` means use
                ``self.timeout_seconds``. ``0`` means take one snapshot
                without polling.
            poll_interval_seconds: Sleep between polls.

        Returns:
            A dict with ``completed``, ``failed`` and ``pending`` row lists.
            The function returns early as soon as any row has
            ``parse_status == "failed"``. It also returns once every
            (targeted) row is completed. Otherwise it returns the snapshot
            taken at the deadline.
        """
        budget = self.timeout_seconds if timeout_seconds is None else timeout_seconds
        deadline = time.monotonic() + budget
        target_ids = set(knowledge_ids or ())

        while time.monotonic() < deadline:
            snapshot, rows = self._ingestion_snapshot(target_ids)
            if snapshot["failed"]:
                return snapshot
            visible_ids = {row.get("id") for row in rows}
            if rows and not snapshot["pending"] and target_ids <= visible_ids:
                return snapshot

            logger.info(
                "Waiting for ingestion: completed=%s pending=%s parse_status=%s enable_status=%s",
                len(snapshot["completed"]),
                len(snapshot["pending"]),
                dict(Counter(str(row.get("parse_status")) for row in rows)),
                dict(Counter(str(row.get("enable_status")) for row in rows)),
            )
            time.sleep(poll_interval_seconds)

        snapshot, _ = self._ingestion_snapshot(target_ids)
        return snapshot

    def _ingestion_snapshot(
        self, target_ids: set[str]
    ) -> tuple[dict[str, list[dict[str, Any]]], list[dict[str, Any]]]:
        """Fetch knowledge rows once and partition them into completed/failed/pending.

        Returns the partition dict plus the (optionally id-filtered) rows it was
        built from.
        """
        rows = self.list_knowledge()
        if target_ids:
            rows = [row for row in rows if row.get("id") in target_ids]

        completed: list[dict[str, Any]] = []
        failed: list[dict[str, Any]] = []
        pending: list[dict[str, Any]] = []
        for row in rows:
            if self._is_ingestion_completed(row):
                completed.append(row)
            elif row.get("parse_status") == "failed":
                failed.append(row)
            else:
                pending.append(row)
        return {"completed": completed, "failed": failed, "pending": pending}, rows

    def list_chunks(self, knowledge_id: str, *, page_size: int = 100) -> list[dict[str, Any]]:
        """Return every chunk of one knowledge item (all pages of ``chunks/{id}``)."""
        return self._paginate(f"chunks/{knowledge_id}", page_size=page_size)

    def _is_ingestion_completed(self, row: dict[str, Any]) -> bool:
        """Return True when a row is both parsed and enabled.

        Both string labels and numeric codes are accepted. The numeric codes
        (2 for parse, 1 or 2 for enable) are presumably WeKnora enum values;
        confirm against the server.
        """
        parse_status = row.get("parse_status")
        enable_status = row.get("enable_status")
        parsed = parse_status in {"completed", "success", "done"} or parse_status in {2, "2"}
        enabled = enable_status in {"enabled", "success", "done"} or enable_status in {1, 2, "1", "2"}
        return parsed and enabled

    def knowledge_chat_sse(
        self,
        *,
        session_id: str,
        query: str,
        knowledge_ids: list[str] | None = None,
        knowledge_base_ids: list[str] | None = None,
        disable_title: bool = True,
        enable_memory: bool = False,
        channel: str = "api",
    ) -> dict[str, Any]:
        """Run a streaming knowledge chat and collect its answer and references.

        When ``knowledge_ids`` is given, the query is scoped to those items.
        Otherwise it is scoped to ``knowledge_base_ids`` (default: the
        configured knowledge base).

        Returns:
            A dict with these keys:
            - ``request_id``: first ``id`` seen in the events.
            - ``response``: concatenated, stripped answer text.
            - ``retrieved_contexts``: non-empty reference contents.
            - ``weknora_references``: de-duplicated normalized references.
            - ``raw_events``: every parsed SSE event.

        Raises:
            WeKnoraApiError: on an HTTP error status. No retry is attempted.
        """
        payload: dict[str, Any] = {
            "query": query,
            "disable_title": disable_title,
            "enable_memory": enable_memory,
            "channel": channel,
        }
        if knowledge_ids:
            payload["knowledge_ids"] = knowledge_ids
        else:
            self._ensure_knowledge_base_id()
            payload["knowledge_base_ids"] = knowledge_base_ids or [self.knowledge_base_id]

        url = self._url(f"knowledge-chat/{session_id}")
        # The ``with`` block releases the streamed connection back to the pool
        # even if event parsing raises part-way through the stream.
        with self.session.post(
            url,
            json=payload,
            timeout=self.timeout_seconds,
            stream=True,
            headers={"Accept": "text/event-stream"},
        ) as response:
            if response.status_code >= 400:
                self._log_error("POST", url, response)
                body = response.text[:1000]
                raise WeKnoraApiError(f"POST {url} failed with HTTP {response.status_code}: {body}")

            # SSE streams are always UTF-8. For a ``text/*`` response without a
            # charset, requests falls back to ISO-8859-1, which garbles any
            # non-ASCII answer text when decode_unicode=True.
            response.encoding = "utf-8"

            answer_parts: list[str] = []
            references: list[dict[str, Any]] = []
            raw_events: list[dict[str, Any]] = []
            request_id: str | None = None
            seen_reference_ids: set[str] = set()

            for event in parse_sse_events(response.iter_lines(decode_unicode=True)):
                raw_events.append(event)
                data = event.get("data")
                if not isinstance(data, dict):
                    continue
                request_id = request_id or data.get("id")
                response_type = data.get("response_type")
                if response_type == "references":
                    for reference in data.get("knowledge_references") or []:
                        normalized = normalize_reference(reference)
                        reference_id = str(normalized.get("id") or "")
                        # References without an id are always kept. Those with
                        # an id are kept only the first time the id is seen.
                        if reference_id and reference_id in seen_reference_ids:
                            continue
                        if reference_id:
                            seen_reference_ids.add(reference_id)
                        references.append(normalized)
                elif response_type == "answer" and not data.get("done"):
                    answer_parts.append(data.get("content") or "")

        retrieved_contexts = [ref["content"] for ref in references if ref.get("content")]
        return {
            "request_id": request_id,
            "response": "".join(answer_parts).strip(),
            "retrieved_contexts": retrieved_contexts,
            "weknora_references": references,
            "raw_events": raw_events,
        }

    def load_messages(self, session_id: str, *, limit: int = 10) -> list[dict[str, Any]]:
        """Return up to ``limit`` messages of a session, or [] if the payload is not a list."""
        payload = self._json_request("GET", f"messages/{session_id}/load", params={"limit": limit})
        if isinstance(payload, list):
            return payload
        return []

    def knowledge_search(
        self,
        query: str,
        *,
        knowledge_ids: list[str] | None = None,
        knowledge_base_ids: list[str] | None = None,
    ) -> list[dict[str, Any]]:
        """Run a retrieval-only search. Returns [] if the payload is not a list.

        Scoping follows ``knowledge_chat_sse``: explicit ``knowledge_ids`` win.
        Otherwise ``knowledge_base_ids`` is used, falling back to the
        configured knowledge base.
        """
        payload: dict[str, Any] = {"query": query}
        if knowledge_ids:
            payload["knowledge_ids"] = knowledge_ids
        else:
            self._ensure_knowledge_base_id()
            payload["knowledge_base_ids"] = knowledge_base_ids or [self.knowledge_base_id]
        data = self._json_request("POST", "knowledge-search", json=payload)
        return data if isinstance(data, list) else []

    def _paginate(self, path: str, *, page_size: int = 100) -> list[dict[str, Any]]:
        """GET ``path`` page by page (1-based) and concatenate the list payloads.

        Stops when the accumulated row count reaches the envelope's ``total``
        (or the rows seen so far if ``total`` is absent/zero), or on an empty
        page.

        Raises:
            WeKnoraApiError: if a page's payload is not a list.
        """
        page = 1
        rows: list[dict[str, Any]] = []
        while True:
            envelope = self._request("GET", path, params={"page": page, "page_size": page_size})
            payload = self._decode_envelope(envelope)
            if not isinstance(payload, list):
                raise WeKnoraApiError(f"Expected list response for {path}, got {type(payload).__name__}")
            rows.extend(payload)

            # A bare JSON list carries no ``total``; treat it as a single page.
            reported_total = envelope.get("total") if isinstance(envelope, dict) else None
            total = int(reported_total or len(rows))
            if len(rows) >= total or not payload:
                return rows
            page += 1

    def _json_request(self, method: str, path: str, **kwargs: Any) -> Any:
        """Send a request via ``_request`` and unwrap its response envelope."""
        envelope = self._request(method, path, **kwargs)
        return self._decode_envelope(envelope)

    def _request(self, method: str, path: str, **kwargs: Any) -> Any:
        """Send a JSON request with retries and return the parsed JSON body.

        Retry behaviour:
        - Transport errors (``requests.RequestException``) and HTTP 429/5xx
          are retried, up to ``max_retries`` attempts in total (minimum one).
        - Back-off is linear: 1s, 2s, and so on.
        - Transport retries apply to POSTs too, so a request that timed out
          after the server processed it may be duplicated.
        - Any other HTTP error status is logged and raised immediately.
        - A successful response whose body is not JSON is raised without
          retrying, because the server has already acted on the request.

        Raises:
            WeKnoraApiError: on HTTP errors, non-JSON bodies, or once retries
                are exhausted.
        """
        url = self._url(path)
        attempts = max(1, self.max_retries)
        last_error: Exception | None = None
        for attempt in range(1, attempts + 1):
            try:
                response = self.session.request(
                    method,
                    url,
                    timeout=self.timeout_seconds,
                    **kwargs,
                )
            except requests.RequestException as exc:
                last_error = exc
                if attempt < attempts:
                    time.sleep(attempt)
                continue

            if response.status_code in self._RETRY_STATUS_CODES and attempt < attempts:
                time.sleep(attempt)
                continue
            if response.status_code >= 400:
                self._log_error(method, url, response)
                body = response.text[:1000]
                raise WeKnoraApiError(
                    f"{method} {url} failed with HTTP {response.status_code}: {body}"
                )
            time.sleep(self.request_interval_seconds)
            try:
                return response.json()
            except ValueError as exc:
                body = response.text[:1000]
                raise WeKnoraApiError(f"{method} {url} returned a non-JSON body: {body}") from exc

        raise WeKnoraApiError(f"{method} {url} failed: {last_error}") from last_error

    def _decode_envelope(self, envelope: Any) -> Any:
        """Unwrap a ``{"success": ..., "data": ...}`` envelope.

        Non-dict bodies are returned unchanged. A dict without ``data`` is
        returned as-is.

        Raises:
            WeKnoraApiError: if the envelope reports ``success: false``.
        """
        if not isinstance(envelope, dict):
            return envelope
        if envelope.get("success") is False:
            raise WeKnoraApiError(str(envelope))
        return envelope.get("data", envelope)

    def _url(self, path: str) -> str:
        """Resolve ``path`` relative to ``base_url``; a leading slash is ignored."""
        return urljoin(self.base_url, path.lstrip("/"))

    def _ensure_knowledge_base_id(self) -> None:
        """Raise WeKnoraApiError if no knowledge base id is configured."""
        if not self.knowledge_base_id:
            raise WeKnoraApiError("Missing knowledge_base_id. Run scripts/00_create_kb.py first.")

    def _log_error(self, method: str, url: str, response: requests.Response) -> None:
        """Append a failed response (body truncated to 5000 chars) to the error JSONL log."""
        body = response.text[:5000]
        append_jsonl(
            self.error_log_path,
            {
                "method": method,
                "url": url,
                "status_code": response.status_code,
                "response_body": body,
            },
        )


def client_from_config(config: dict[str, Any]) -> WeKnoraClient:
    """Build a client for normal runs, where the knowledge base id is required.

    ``weknora.base_url``, ``weknora.api_key`` and ``weknora.knowledge_base_id``
    are fetched through ``require_config``. The timeout and request interval
    fall back to 300 s and 0.2 s when absent from the ``weknora`` section.
    """
    section = config["weknora"]
    base_url = require_config(config, "weknora.base_url")
    api_key = require_config(config, "weknora.api_key")
    knowledge_base_id = require_config(config, "weknora.knowledge_base_id")
    return WeKnoraClient(
        base_url=base_url,
        api_key=api_key,
        knowledge_base_id=knowledge_base_id,
        timeout_seconds=int(section.get("timeout_seconds", 300)),
        request_interval_seconds=float(section.get("request_interval_seconds", 0.2)),
    )


def bootstrap_client_from_config(config: dict[str, Any]) -> WeKnoraClient:
    """Build a client for bootstrap scripts, where the knowledge base may not exist yet.

    Identical to ``client_from_config`` except that a missing or falsy
    ``weknora.knowledge_base_id`` becomes an empty string instead of being
    required.
    """
    section = config["weknora"]
    base_url = require_config(config, "weknora.base_url")
    api_key = require_config(config, "weknora.api_key")
    knowledge_base_id = str(section.get("knowledge_base_id") or "")
    return WeKnoraClient(
        base_url=base_url,
        api_key=api_key,
        knowledge_base_id=knowledge_base_id,
        timeout_seconds=int(section.get("timeout_seconds", 300)),
        request_interval_seconds=float(section.get("request_interval_seconds", 0.2)),
    )