# ragas_runner.py
from __future__ import annotations

import os
from pathlib import Path
from typing import Any

import pandas as pd

from weknora_eval.config import require_config
from weknora_eval.loaders import read_jsonl


def run_ragas_eval(
    config: dict[str, Any],
    *,
    input_path: str = "data/runs/ragas_input.jsonl",
    output_csv_path: str = "data/reports/ragas_scores.csv",
) -> pd.DataFrame:
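    """Run the configured Ragas metrics over a JSONL evaluation set.

    Each line of ``input_path`` must provide the fields mapped into the
    dataset below, e.g. (illustrative values)::

        {"sample_id": "q-001", "user_input": "...", "response": "...",
         "retrieved_contexts": ["..."], "reference": "...",
         "reference_contexts": ["..."]}

    Writes per-sample scores to ``output_csv_path`` and returns them as a
    DataFrame.
    """
    # Heavy dependencies are imported lazily so the rest of the package can
    # be used without the Ragas/LangChain stack installed.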
    from datasets import Dataset
    from langchain_openai import ChatOpenAI, OpenAIEmbeddings
    from ragas import evaluate
    from ragas.run_config import RunConfig

    ragas_config = config["ragas"]
    llm_api_key = _required_ragas_value(ragas_config, "llm_api_key")
    llm_base_url = _required_ragas_value(ragas_config, "llm_base_url")
    embedding_api_key = _required_ragas_value(ragas_config, "embedding_api_key")
    embedding_base_url = _required_ragas_value(ragas_config, "embedding_base_url")
    judge_model = str(require_config(config, "ragas.judge_model"))
    embedding_model = str(require_config(config, "ragas.embedding_model"))
    temperature = float(ragas_config.get("temperature", 0))
    max_tokens = int(ragas_config.get("max_tokens", 4096))
    timeout_seconds = int(ragas_config.get("timeout_seconds", 600))
    max_workers = int(ragas_config.get("max_workers", 1))

    # Also export the credentials for components that read them from the
    # environment rather than from the client objects constructed below.
    os.environ["OPENAI_API_KEY"] = llm_api_key
    if llm_base_url:
        os.environ["OPENAI_BASE_URL"] = llm_base_url

    rows = read_jsonl(input_path)
    # Map each JSONL row onto the column names Ragas expects; a missing
    # reference_contexts field becomes an empty list.
    dataset = Dataset.from_list(
        [
            {
                "user_input": row["user_input"],
                "response": row["response"],
                "retrieved_contexts": row["retrieved_contexts"],
                "reference": row["reference"],
                "reference_contexts": row.get("reference_contexts") or [],
            }
            for row in rows
        ]
    )

    # Resolve the configured metric names; unrecognised names are skipped,
    # and omitting ragas.metrics selects every supported metric.
    metric_map = _metric_map()
    selected_metrics = [
        metric_map[name]
        for name in ragas_config.get("metrics", metric_map.keys())
        if name in metric_map
    ]

    # Judge model used by the LLM-graded metrics.
    llm = ChatOpenAI(
        model=judge_model,
        api_key=llm_api_key,
        base_url=llm_base_url or None,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    embeddings = OpenAIEmbeddings(
        model=embedding_model,
        api_key=embedding_api_key,
        base_url=embedding_base_url or None,
        # Skip tiktoken-based token counting and context-length checks so
        # OpenAI-compatible endpoints with other tokenizers still work.
        tiktoken_enabled=False,
        check_embedding_ctx_length=False,
    )
    ragas_llm, ragas_embeddings = _wrap_langchain_models(llm, embeddings)

    run_config = RunConfig(timeout=timeout_seconds, max_workers=max_workers)
    result = evaluate(
        dataset,
        metrics=selected_metrics,
        llm=ragas_llm,
        embeddings=ragas_embeddings,
        run_config=run_config,
    )
    scores = result.to_pandas()
    # Attach sample ids positionally; assumes evaluate() keeps input row order.
    for index, row in enumerate(rows):
        scores.loc[index, "sample_id"] = row.get("sample_id")

    target = Path(output_csv_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    scores.to_csv(target, index=False)
    return scores


def _metric_map() -> dict[str, Any]:
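    """Map metric names to Ragas metric objects.

    Tries the module-level metric instances first and instantiates the
    metric classes when those names are unavailable, so that multiple
    Ragas versions are covered.
    """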
    try:
        from ragas.metrics import (
            context_precision,
            context_recall,
            faithfulness,
            factual_correctness,
            response_relevancy,
        )

        return {
            "faithfulness": faithfulness,
            "response_relevancy": response_relevancy,
            "context_precision": context_precision,
            "context_recall": context_recall,
            "factual_correctness": factual_correctness,
        }
    except ImportError:
        from ragas.metrics import (
            Faithfulness,
            FactualCorrectness,
            LLMContextPrecisionWithReference,
            LLMContextRecall,
            ResponseRelevancy,
        )

        return {
            "faithfulness": Faithfulness(),
            "response_relevancy": ResponseRelevancy(),
            "context_precision": LLMContextPrecisionWithReference(),
            "context_recall": LLMContextRecall(),
            "factual_correctness": FactualCorrectness(),
        }


def _required_ragas_value(config: dict[str, Any], key: str) -> str:
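    """Return a required string from the ``ragas`` config section.

    Raises ValueError when the key is absent or empty.
    """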
    value = config.get(key)
    if value in {None, ""}:
        raise ValueError(f"Missing required Ragas config value: ragas.{key}")
    return str(value)


def _wrap_langchain_models(llm: Any, embeddings: Any) -> tuple[Any, Any]:
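    """Adapt LangChain models for Ragas when the wrapper classes exist.

    Falls back to returning the models unchanged when the wrapper imports
    are unavailable.
    """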
    try:
        from ragas.embeddings import LangchainEmbeddingsWrapper
        from ragas.llms import LangchainLLMWrapper
    except ImportError:
        return llm, embeddings

    return LangchainLLMWrapper(llm), LangchainEmbeddingsWrapper(embeddings)
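

if __name__ == "__main__":
    # Minimal usage sketch: the key names mirror the lookups above, while the
    # endpoint, credentials, and model names are illustrative placeholders.
    example_config = {
        "ragas": {
            "llm_api_key": "sk-...",
            "llm_base_url": "https://api.example.com/v1",
            "embedding_api_key": "sk-...",
            "embedding_base_url": "https://api.example.com/v1",
            "judge_model": "gpt-4o-mini",
            "embedding_model": "text-embedding-3-small",
            "metrics": ["faithfulness", "context_recall"],
        }
    }
    print(run_ragas_eval(example_config))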