This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit 85790ca
ragas endpoint saving to summary metrics
imtihan committed Jan 18, 2024
1 parent a1056f1 commit 85790ca
Showing 2 changed files with 82 additions and 5 deletions.
78 changes: 78 additions & 0 deletions src/flamingo/jobs/drivers/ragas.py
@@ -0,0 +1,78 @@
from pathlib import Path

import ray
from datasets import load_dataset
from ragas import evaluate

from flamingo.configs import RagasEvaluationConfig
from flamingo.integrations.wandb import get_wandb_summary, update_wandb_summary


def resolve_data_path(config: RagasEvaluationConfig) -> str:
    # Prefer an explicitly configured path; otherwise fall back to the
    # dataset path recorded on the linked W&B data-processing run.
    if config.data_path is None:
        print("Attempting to resolve data path from existing W&B data processing run...")
        run_summary = get_wandb_summary(config.wandb_env)
        path = Path(run_summary["dataset_path"])
        print(f"Using data path from W&B run: {path}")
        if not path.exists():
            raise FileNotFoundError(f"{path} cannot be found.")
        data_path = str(path)
    else:
        data_path = str(config.data_path)
    print(f"Dataset directory path: {data_path}")
    return data_path


@ray.remote
def evaluation_task(config: RagasEvaluationConfig, evaluation_dataset_to_load: str) -> None:
    print(f"Loading dataset from {evaluation_dataset_to_load}")
    # Load a single split: ragas.evaluate expects a Dataset, not a DatasetDict
    # (this assumes the evaluation data lives in the default "train" split)
    dataset = (
        load_dataset(evaluation_dataset_to_load, split="train")
        if config.is_hf_dataset
        else load_dataset("parquet", data_files=evaluation_dataset_to_load, split="train")
    )

    print(f"Remapping data columns to ragas-compatible names: {config.data_column_names}")
    # rename_column returns a new Dataset, so reassign the result; skip
    # no-op renames, since datasets rejects renaming a column to itself
    for ragas_name, source_name in config.data_column_names.items():
        if source_name != ragas_name:
            dataset = dataset.rename_column(source_name, ragas_name)

    print("Initializing ragas eval task...")
    result = evaluate(
        dataset=dataset,
        metrics=config.metrics,
    )

    print(f"Obtained evaluation results: {result}")

    if config.wandb_env:
        print("Logging results to W&B...")
        update_wandb_summary(config.wandb_env, result)


def run_ragas_evaluation(config: RagasEvaluationConfig):
    print(f"Received job configuration: {config}")

    # Resolve the dataset path and ensure it exists
    evaluation_dataset_to_load = resolve_data_path(config)

    # Use .options() to dynamically specify resource requirements
    eval_func = evaluation_task.options(num_cpus=config.num_cpus, num_gpus=config.num_gpus)
    eval_future = eval_func.remote(config, evaluation_dataset_to_load)

    timeout_seconds = config.timeout.seconds if config.timeout else None
    try:
        print("Waiting on evaluation task...")
        ray.get(eval_future, timeout=timeout_seconds)
        print("Evaluation successfully completed")
    except TimeoutError:
        print(
            f"Evaluation task timed out after {timeout_seconds} sec. "
            "If the evaluation runner finished but the task failed to shut down, "
            "please check if your results were still generated and persisted."
        )
        raise
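
For reference, a minimal usage sketch of the new driver (not part of the commit) might look like the following. It assumes RagasEvaluationConfig accepts these fields as keyword arguments and that wandb_env is optional; the parquet path, resource counts, and timeout value are hypothetical:

from datetime import timedelta

from flamingo.configs import RagasEvaluationConfig
from flamingo.jobs.drivers.ragas import run_ragas_evaluation

# Evaluate a local parquet file rather than a Hugging Face Hub dataset;
# leaving wandb_env unset skips the W&B summary-logging branch.
config = RagasEvaluationConfig(
    data_path="data/rag_eval_outputs.parquet",  # hypothetical path
    is_hf_dataset=False,
    num_cpus=4,
    num_gpus=0,
    timeout=timedelta(minutes=30),  # assumes a timedelta-like field
)
run_ragas_evaluation(config)
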
9 changes: 4 additions & 5 deletions src/flamingo/jobs/ragas_config.py
@@ -5,7 +5,6 @@
 from ragas.metrics import (
     answer_relevancy,
     context_precision,
-    context_recall,
     faithfulness,
 )
 from ragas.metrics.base import MetricWithLLM
@@ -21,20 +20,20 @@ class RagasEvaluationConfig(BaseJobConfig):
     the contexts (retrieved), and optionally a ground truth field.
     """
 
     is_hf_dataset: bool | None = False
     data_path: str | Path | None = None
 
     # remap columns so data table does not need to be edited
     data_column_names: dict = {
         "question": "question",
         "answer": "answer",
-        "context": "context",
-        "ground_truth": "ground_truth",
+        "contexts": "contexts",
     }
 
     metrics: list[MetricWithLLM] = [
         faithfulness,
         answer_relevancy,
-        context_recall,
+        # context_recall,
         context_precision,
     ]

@@ -47,7 +46,7 @@
 
     @validator("data_column_names")
     def _validate_data_column_names(cls, v):
-        required_keys = {"question", "answer", "context", "ground_truth"}
+        required_keys = {"question", "answer", "contexts"}
         provided_keys = set(v.keys())
         missing_keys = required_keys.difference(provided_keys)
         if not missing_keys:
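
To illustrate the column-remapping contract (an example, not from the commit): since the driver renames each configured source column onto the ragas-expected name, a table with differently named columns can be evaluated without editing the data. The source column names below are hypothetical:

from flamingo.configs import RagasEvaluationConfig

# Map each ragas-required name to the column that holds it in the source table
config = RagasEvaluationConfig(
    data_path="data/rag_eval_outputs.parquet",
    data_column_names={
        "question": "user_query",
        "answer": "model_response",
        "contexts": "retrieved_passages",
    },
)

The validator shown above only requires that the three ragas keys ("question", "answer", "contexts") be present in the mapping.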
