diff --git a/src/lm_buddy/jobs/ragas/entrypoint.py b/src/lm_buddy/jobs/ragas/entrypoint.py
index 32bde96f..fa6b4b61 100644
--- a/src/lm_buddy/jobs/ragas/entrypoint.py
+++ b/src/lm_buddy/jobs/ragas/entrypoint.py
@@ -1,7 +1,10 @@
+import os
 from pathlib import Path
 
 import ray
 from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, load_dataset
+from langchain.chat_models import ChatOpenAI
+from langchain.embeddings import HuggingFaceEmbeddings
 from ragas import evaluate
 
 from lm_buddy.integrations.wandb import get_wandb_summary, update_wandb_summary
@@ -48,13 +51,33 @@ def _load_dataset_for_ragas_eval(
 
 
 def evaluation_task(config: RagasEvaluationJobConfig) -> None:
+    # keyword arguments for the custom ragas judge models
+    ragas_args = {}
+
+    # set up a custom embedding model for ragas
+    # (ragas only supports LangChain embedding models right now, so the
+    # configured model path is wrapped in a concrete LangChain class)
+    if config.judge_model.embedding_model:
+        ragas_args["embeddings"] = HuggingFaceEmbeddings(
+            model_name=config.judge_model.embedding_model.load_from
+        )
+
+    # set up the custom judge LLM, called through a vLLM server that exposes
+    # an OpenAI-compatible API
+    if config.judge_model.language_model:
+        # create a LangChain chat model pointed at the vLLM endpoint,
+        # which is read from an environment variable
+        vllm_entry = ChatOpenAI(
+            model=config.judge_model.language_model.load_from,
+            openai_api_key=config.judge_model.openai_api_key,
+            openai_api_base=os.environ.get("VLLM_JUDGE_ENDPOINT"),
+            max_tokens=config.judge_model.max_tokens,
+            temperature=config.judge_model.temperature,
+        )
+        ragas_args["llm"] = vllm_entry
+
     dataset = _load_dataset_for_ragas_eval(config)
 
     print("Initializing ragas eval task...")
-    result = evaluate(
-        dataset=dataset,
-        metrics=config.metrics,
-    )
+    result = evaluate(dataset=dataset, metrics=config.metrics, **ragas_args)
 
     print(f"Obtained evaluation results: {result}")
diff --git a/src/lm_buddy/jobs/ragas/ragas_config.py b/src/lm_buddy/jobs/ragas/ragas_config.py
index 0da30fde..41c07c32 100644
--- a/src/lm_buddy/jobs/ragas/ragas_config.py
+++ b/src/lm_buddy/jobs/ragas/ragas_config.py
@@ -32,17 +32,25 @@ class RagasConfig(BaseLMBuddyConfig):
 class RagasvLLMJudgeConfig(BaseLMBuddyConfig):
     """
     Configuration class for a vLLM hosted judge model
-    Requires a vLLM endpoint that the model will hit instead of the openAI default
+    Requires a vLLM endpoint that the model will hit instead of the OpenAI default;
+    the endpoint URL is passed via an environment variable
     """
 
-    model: AutoModelConfig
-    # inference_server_url: str | None = "http://localhost:8080/v1"
-    openai_api_key: str | None = "no-key"
+    language_model: AutoModelConfig | None = None
+    embedding_model: AutoModelConfig | None = None
+    openai_api_key: str | None = "nokey"
     max_tokens: int | None = 5
     temperature: float | None = 0
 
-    @field_validator("model", mode="before", always=True)
-    def validate_model_arg(cls, x):
+    @field_validator("language_model", mode="before")
+    def validate_language_model_arg(cls, x):
+        """Allow for passing just a path string as the model argument."""
+        if isinstance(x, str):
+            return AutoModelConfig(load_from=x)
+        return x
+
+    @field_validator("embedding_model", mode="before")
+    def validate_embedding_model_arg(cls, x):
         """Allow for passing just a path string as the model argument."""
         if isinstance(x, str):
             return AutoModelConfig(load_from=x)
         return x
@@ -77,6 +85,7 @@ class RagasEvaluationJobConfig(BaseLMBuddyConfig):
     """
 
     # evaluation settings for ragas
+    dataset: RagasEvaluationDatasetConfig
    evaluator: RagasConfig
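
A note on the judge wiring: `evaluate` only receives the custom `llm` and `embeddings` when the corresponding config fields are set, so the OpenAI defaults still apply otherwise. Below is a minimal, self-contained sketch of the same wiring outside the job, assuming a vLLM server is already running with its OpenAI-compatible API and that `VLLM_JUDGE_ENDPOINT` points at it; the model name, metric, and sample rows are illustrative, not part of this PR:

```python
import os

from datasets import Dataset
from langchain.chat_models import ChatOpenAI
from ragas import evaluate
from ragas.metrics import faithfulness

# LangChain chat model pointed at the vLLM OpenAI-compatible endpoint,
# e.g. VLLM_JUDGE_ENDPOINT="http://localhost:8000/v1"
judge = ChatOpenAI(
    model="mistralai/Mistral-7B-Instruct-v0.2",  # illustrative judge model
    openai_api_key="nokey",  # vLLM ignores the key unless auth is configured
    openai_api_base=os.environ["VLLM_JUDGE_ENDPOINT"],
    temperature=0,
)

# ragas metrics read question/answer/contexts columns from the dataset
dataset = Dataset.from_dict(
    {
        "question": ["What colour is the sky?"],
        "answer": ["The sky is blue."],
        "contexts": [["The sky appears blue because of Rayleigh scattering."]],
    }
)

# mirrors evaluate(dataset=dataset, metrics=config.metrics, **ragas_args)
result = evaluate(dataset=dataset, metrics=[faithfulness], llm=judge)
print(result)
```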
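On the embeddings side, ragas wants a LangChain embeddings instance, and `langchain_core.embeddings.Embeddings` is only the abstract interface, so a concrete implementation has to be constructed. A sketch using `HuggingFaceEmbeddings` as one possible choice, assuming `sentence-transformers` is installed; the model name is illustrative:

```python
from langchain.embeddings import HuggingFaceEmbeddings

# concrete LangChain embeddings instance backed by sentence-transformers;
# any other LangChain embeddings class would work the same way
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"  # illustrative model
)

# embedding-based metrics (e.g. answer_relevancy) pick this up through the
# `embeddings` keyword, mirroring ragas_args["embeddings"] in the diff above:
# result = evaluate(dataset=dataset, metrics=[answer_relevancy], embeddings=embeddings)
```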
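The `mode="before"` validators in `RagasvLLMJudgeConfig` implement a small coercion pattern: a config file can pass either a bare path string or a full mapping for each model field. A self-contained sketch of that pattern, using a stand-in `AutoModelConfig` since the real class lives elsewhere in `lm_buddy`:

```python
from pydantic import BaseModel, field_validator


class AutoModelConfig(BaseModel):
    """Stand-in for the lm_buddy AutoModelConfig."""

    load_from: str


class JudgeConfig(BaseModel):
    language_model: AutoModelConfig | None = None

    @field_validator("language_model", mode="before")
    @classmethod
    def validate_language_model_arg(cls, x):
        """Allow passing just a path string as the model argument."""
        if isinstance(x, str):
            return AutoModelConfig(load_from=x)
        return x


# both spellings produce the same validated config
a = JudgeConfig(language_model="hf-repo/judge-model")
b = JudgeConfig(language_model={"load_from": "hf-repo/judge-model"})
assert a == b
print(a.language_model.load_from)  # -> hf-repo/judge-model
```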