sse.py 1.83 KB
from __future__ import annotations

import json
from collections.abc import Iterable, Iterator
from typing import Any


def parse_sse_events(lines: Iterable[str | bytes]) -> Iterator[dict[str, Any]]:
    """Incrementally parse Server-Sent Events from an iterable of lines.

    Each item of *lines* is one line of the stream, as ``str`` or as UTF-8
    encoded ``bytes``; trailing CR/LF characters are removed before parsing.

    A blank line terminates the current event. Events that collected at
    least one ``data:`` line are yielded as ``{"event": name, "data": value}``
    where ``name`` defaults to ``"message"`` and ``value`` is produced by
    ``_build_event``. Lines starting with ``:`` are comments and are skipped,
    as is any field other than ``event`` and ``data``.

    If the input ends without a terminating blank line, a pending event that
    already has data is still yielded.

    Raises:
        UnicodeDecodeError: if a ``bytes`` line is not valid UTF-8.
    """
    event_name = "message"
    data_lines: list[str] = []

    for raw_line in lines:
        line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
        line = line.rstrip("\r\n")

        if not line:
            if data_lines:
                yield _build_event(event_name, data_lines)
            # A blank line always ends the current event, even one without
            # data; otherwise its "event:" name would leak onto the next,
            # unrelated event.
            event_name = "message"
            data_lines = []
            continue

        if line.startswith(":"):
            continue
        if line.startswith("event:"):
            # Event names are matched leniently: surrounding whitespace is
            # never meaningful in a name.
            event_name = line.removeprefix("event:").strip()
            continue
        if line.startswith("data:"):
            # Only the single optional space after the colon is a separator;
            # any further whitespace belongs to the payload and is kept.
            data_lines.append(line.removeprefix("data:").removeprefix(" "))

    if data_lines:
        yield _build_event(event_name, data_lines)


def _build_event(event_name: str, data_lines: list[str]) -> dict[str, Any]:
    """Assemble one event dict from its name and collected data lines.

    The data lines are joined with ``"\\n"``. An empty payload and the literal
    ``"[DONE]"`` sentinel are returned as plain strings. Anything else is
    decoded as JSON when possible and returned as the raw string otherwise.
    """
    raw_data = "\n".join(data_lines)
    if not raw_data or raw_data == "[DONE]":
        return {"event": event_name, "data": raw_data}

    parsed_data: Any
    try:
        parsed_data = json.loads(raw_data)
    except json.JSONDecodeError:
        # Not JSON: hand the text through untouched.
        parsed_data = raw_data
    return {"event": event_name, "data": parsed_data}


def normalize_reference(reference: dict[str, Any]) -> dict[str, Any]:
    """Project *reference* onto a fixed set of keys.

    The result always contains exactly the keys ``id``, ``content``,
    ``knowledge_id``, ``chunk_index``, ``score``, ``knowledge_filename``,
    ``match_type`` and ``chunk_type``, in that order. Keys missing from the
    input map to ``None``, with two exceptions:

    * ``content`` becomes ``""`` when it is missing or falsy;
    * ``knowledge_filename`` falls back to the input's ``knowledge_title``
      when it is missing or falsy.

    Any other keys of *reference* are dropped; the input is not modified.
    """
    lookup = reference.get
    text = lookup("content") or ""
    source_name = lookup("knowledge_filename") or lookup("knowledge_title")
    return dict(
        id=lookup("id"),
        content=text,
        knowledge_id=lookup("knowledge_id"),
        chunk_index=lookup("chunk_index"),
        score=lookup("score"),
        knowledge_filename=source_name,
        match_type=lookup("match_type"),
        chunk_type=lookup("chunk_type"),
    )