diff --git a/src/flamingo/integrations/wandb/artifact_type.py b/src/flamingo/integrations/wandb/artifact_type.py
index f477e9e9..62227a01 100644
--- a/src/flamingo/integrations/wandb/artifact_type.py
+++ b/src/flamingo/integrations/wandb/artifact_type.py
@@ -2,8 +2,9 @@
 
 
 class ArtifactType(str, Enum):
-    """Enumeration of artifact types generated by the Flamingo."""
+    """Enumeration of artifact types used by the Flamingo."""
 
     DATASET = "dataset"
     MODEL = "model"
     TOKENIZER = "tokenizer"
+    EVALUATION = "evaluation"
diff --git a/src/flamingo/jobs/finetuning/entrypoint.py b/src/flamingo/jobs/finetuning/entrypoint.py
index 9b6011a3..2de8db6e 100644
--- a/src/flamingo/jobs/finetuning/entrypoint.py
+++ b/src/flamingo/jobs/finetuning/entrypoint.py
@@ -5,7 +5,7 @@
 from accelerate import Accelerator
 from datasets import DatasetDict
 from ray import train
-from ray.train import CheckpointConfig, RunConfig, ScalingConfig
+from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig
 from ray.train.huggingface.transformers import RayTrainReportCallback, prepare_trainer
 from ray.train.torch import TorchTrainer
 from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, TrainingArguments
@@ -17,6 +17,14 @@
 from flamingo.jobs.finetuning import FinetuningJobConfig
 
 
+def build_model_artifact(run_name: str, checkpoint: Checkpoint) -> wandb.Artifact:
+    print("Building artifact for model checkpoint...")
+    artifact_name = default_artifact_name(run_name, ArtifactType.MODEL)
+    artifact = wandb.Artifact(artifact_name, type=ArtifactType.MODEL.value)
+    artifact.add_reference(f"file://{checkpoint.path}/checkpoint")
+    return artifact
+
+
 def is_tracking_enabled(config: FinetuningJobConfig):
     # Only report to WandB on the rank 0 worker
     # Reference: https://docs.ray.io/en/latest/train/user-guides/experiment-tracking.html
@@ -143,8 +151,5 @@ def run_finetuning(config: FinetuningJobConfig):
     if config.tracking and result.checkpoint:
         # Must resume from the just-completed training run
         with wandb.init(**config.tracking.wandb_init_args(), resume="must") as run:
-            print("Generating artifact for training results...")
-            artifact_name = default_artifact_name(run.name, ArtifactType.MODEL)
-            artifact = wandb.Artifact(artifact_name, type=ArtifactType.MODEL.value)
-            artifact.add_reference(f"file://{result.checkpoint.path}/checkpoint")
+            artifact = build_model_artifact(run.name, result.checkpoint)
             run.log_artifact(artifact)
diff --git a/src/flamingo/jobs/lm_harness/entrypoint.py b/src/flamingo/jobs/lm_harness/entrypoint.py
index cedccb17..30f40e2f 100644
--- a/src/flamingo/jobs/lm_harness/entrypoint.py
+++ b/src/flamingo/jobs/lm_harness/entrypoint.py
@@ -1,18 +1,25 @@
 from typing import Any
 
 import lm_eval
-import pandas as pd
 import ray
 import wandb
 from lm_eval.models.huggingface import HFLM
 from peft import PeftConfig
 
-from flamingo.integrations.wandb import WandbArtifactLoader
+from flamingo.integrations.wandb import ArtifactType, WandbArtifactLoader
+from flamingo.integrations.wandb.utils import default_artifact_name
 from flamingo.jobs.lm_harness import LMHarnessJobConfig
 
 
-def log_results_artifacts(results: dict[str, dict[str, Any]]) -> list[pd.DataFrame]:
-    pass
+def build_evaluation_artifact(run_name: str, results: dict[str, dict[str, Any]]) -> wandb.Artifact:
+    print("Building artifact for evaluation results...")
+    artifact_name = default_artifact_name(run_name, ArtifactType.EVALUATION)
+    artifact = wandb.Artifact(artifact_name, type=ArtifactType.EVALUATION.value)
+    for task_name, task_results in results.items():
+        task_data = list(task_results.items())
+        task_table = wandb.Table(data=task_data, columns=["metric", "value"])
+        artifact.add(task_table, name=f"task-{task_name}")
+    return artifact
 
 
 def load_harness_model(config: LMHarnessJobConfig, loader: WandbArtifactLoader) -> HFLM:
@@ -45,19 +52,10 @@ def load_harness_model(config: LMHarnessJobConfig, loader: WandbArtifactLoader)
     )
 
 
-@ray.remote
-def evaluation_task(config: LMHarnessJobConfig) -> None:
+def evaluate_with_loader(config: LMHarnessJobConfig, loader: WandbArtifactLoader) -> dict[str, Any]:
     print("Initializing lm-harness tasks...")
     lm_eval.tasks.initialize_tasks()
-
-    wandb_run = None
-    if config.tracking is not None:
-        wandb_run = wandb.init(**config.tracking.wandb_init_args(), resume="never")
-
-    # Load the model to evaluate via the lm-harness interface
-    artifact_loader = WandbArtifactLoader(wandb_run)
-    llm = load_harness_model(config, artifact_loader)
-
+    llm = load_harness_model(config, loader)
     eval_results = lm_eval.simple_evaluate(
         model=llm,
         tasks=config.evaluator.tasks,
@@ -68,10 +66,20 @@ def evaluation_task(config: LMHarnessJobConfig) -> None:
     )
     eval_results = eval_results["results"]
     print(f"Obtained evaluation results: {eval_results}")
+    return eval_results
+
 
+@ray.remote
+def evaluation_task(config: LMHarnessJobConfig) -> None:
     if config.tracking is not None:
-        print("Generating artifacts for evaluation results...")
-        log_results_artifacts(eval_results, wandb_run)
+        with wandb.init(**config.tracking.wandb_init_args(), resume="never") as run:
+            artifact_loader = WandbArtifactLoader(run=run)
+            eval_results = evaluate_with_loader(config, artifact_loader)
+            artifact = build_evaluation_artifact(run.name, eval_results)
+            run.log_artifact(artifact)
+    else:
+        artifact_loader = WandbArtifactLoader(run=None)
+        evaluate_with_loader(config, artifact_loader)
 
 
 def run_lm_harness(config: LMHarnessJobConfig):