diff --git a/examples/configs/lm_harness_config.yaml b/examples/configs/lm_harness_config.yaml
index 64592ba4..6d91b607 100644
--- a/examples/configs/lm_harness_config.yaml
+++ b/examples/configs/lm_harness_config.yaml
@@ -23,6 +23,6 @@ tracking:
   entity: "another-entity"
 
 ray:
-  use_gpu: True
-  num_workers: 4
+  num_cpus: 1
+  num_gpus: 4
   timeout: 3600
diff --git a/src/flamingo/integrations/wandb/utils.py b/src/flamingo/integrations/wandb/utils.py
index eb12d27e..9f4b065b 100644
--- a/src/flamingo/integrations/wandb/utils.py
+++ b/src/flamingo/integrations/wandb/utils.py
@@ -40,6 +40,7 @@ def get_artifact_directory(artifact: wandb.Artifact) -> str:
         case 1:
             return list(dir_paths)[0]
         case _:
+            # TODO: Can this be resolved somehow else???
            dir_string = ",".join(dir_paths)
            raise ValueError(
                f"Artifact {artifact.name} references multiple directories: {dir_string}. "
diff --git a/src/flamingo/jobs/drivers/finetuning.py b/src/flamingo/jobs/drivers/finetuning.py
index a7f9cdfd..dfe5b80b 100644
--- a/src/flamingo/jobs/drivers/finetuning.py
+++ b/src/flamingo/jobs/drivers/finetuning.py
@@ -144,6 +144,8 @@ def run_finetuning(config: FinetuningJobConfig):
         # Must resume from the just-completed training run
         with wandb.init(**config.tracking.wandb_init_args(), resume="must") as run:
             artifact_type = ArtifactType.MODEL.value
+            print(f"Generating {artifact_type} artifact of training results...")
+
             artifact_name = f"{config.tracking.name or config.tracking.run_id}-{artifact_type}"
             artifact = wandb.Artifact(artifact_name, type=artifact_type)
             artifact.add_reference(f"file://{result.checkpoint.path}/checkpoint")
diff --git a/src/flamingo/jobs/drivers/lm_harness.py b/src/flamingo/jobs/drivers/lm_harness.py
index 4531dda1..fd23272a 100644
--- a/src/flamingo/jobs/drivers/lm_harness.py
+++ b/src/flamingo/jobs/drivers/lm_harness.py
@@ -1,77 +1,56 @@
-from pathlib import Path
-
 import lm_eval
 import ray
+import wandb
 from lm_eval.models.huggingface import HFLM
 from peft import PeftConfig
 
-from flamingo.integrations.wandb import get_wandb_summary, update_wandb_summary
-from flamingo.jobs import LMHarnessJobConfig, ModelNameOrCheckpointPath
-
+from flamingo.integrations.wandb import WandbArtifactLoader
+from flamingo.integrations.wandb.utils import resolve_artifact_path
+from flamingo.jobs import LMHarnessJobConfig
 
-def resolve_model_or_path(config: LMHarnessJobConfig) -> str:
-    mn_or_path = None
-    match config.model_name_or_path:
-        case None:
-            print("Attempting to resolve checkpoint path from existing W&B run...")
-            run_summary = get_wandb_summary(config.wandb_env)
-            cp = Path(run_summary["ray/checkpoint_path"])
-            print(f"Using checkpoint path from wandb run: {cp}")
-            if not cp.exists():
-                raise (FileNotFoundError(f"{mn_or_path} cannot be found."))
-            mn_or_path = str(cp)
-        case ModelNameOrCheckpointPath(checkpoint=None) as x:
-            print("No checkpoint; will attempt to load model from HuggingFace")
-            mn_or_path = x.name
-        case ModelNameOrCheckpointPath(checkpoint=ckpt):
-            print(f"Checkpoint found; will attempt to load model from {ckpt}")
-            mn_or_path = ckpt
-        case _:
-            raise (
-                ValueError(
-                    "Something is wrong with the passed "
-                    f"model_name_or_path: {config.model_name_or_path}"
-                )
-            )
-    return mn_or_path
 
+def load_harness_model(config: LMHarnessJobConfig, loader: WandbArtifactLoader) -> HFLM:
+    model_path = resolve_artifact_path(config.model.path, loader)
 
-def load_harness_model(config: LMHarnessJobConfig, model_to_load: str) -> HFLM:
     # We don't know if the checkpoint is adapter weights or merged model weights
     # Try to load as an adapter and fall back to the checkpoint containing the full model
     try:
-        adapter_config = PeftConfig.from_pretrained(model_to_load)
+        adapter_config = PeftConfig.from_pretrained(model_path)
         pretrained = adapter_config.base_model_name_or_path
-        peft = model_to_load
+        peft = model_path
     except ValueError as e:
         print(
             f"Unable to load model as adapter: {e}. "
             "This is expected if the checkpoint does not contain adapter weights."
         )
-        pretrained = model_to_load
+        pretrained = model_path
         peft = None
 
     # Return lm-harness model wrapper class
-    quantization_kwargs = config.quantization_config.dict() if config.quantization_config else {}
+    quantization_kwargs = config.quantization.dict() if config.quantization else {}
     return HFLM(
         pretrained=pretrained,
         tokenizer=pretrained,
         peft=peft,
-        device="cuda" if config.num_gpus > 0 else None,
-        trust_remote_code=config.trust_remote_code,
-        dtype=config.torch_dtype if config.torch_dtype else "auto",
+        device="cuda" if config.ray.num_gpus > 0 else None,
+        trust_remote_code=config.model.trust_remote_code,
+        dtype=config.model.torch_dtype if config.model.torch_dtype else "auto",
         **quantization_kwargs,
     )
 
 
 @ray.remote
-def evaluation_task(config: LMHarnessJobConfig, model_to_load: str) -> None:
+def evaluation_task(config: LMHarnessJobConfig) -> None:
     print("Initializing lm-harness tasks...")
     lm_eval.tasks.initialize_tasks()
 
-    print("Running lm-harness evaluation inside remote function...")
-    llm = load_harness_model(config, model_to_load)
-    raw_results = lm_eval.simple_evaluate(
+    wandb_run = None
+    if config.tracking is not None:
+        wandb_run = wandb.init(**config.tracking.wandb_init_args(), resume="never")
+    artifact_loader = WandbArtifactLoader(wandb_run)
+
+    llm = load_harness_model(config, artifact_loader)
+    eval_results = lm_eval.simple_evaluate(
         model=llm,
         tasks=config.tasks,
         num_fewshot=config.num_fewshot,
@@ -79,36 +58,26 @@ def evaluation_task(config: LMHarnessJobConfig, model_to_load: str) -> None:
         limit=config.limit,
         log_samples=False,
     )
-    print("Finished lm-harness evaluation inside remote function")
+    eval_results = eval_results["results"]
+    print(f"Obtained evaluation results: {eval_results}")
 
-    formatted_results = {}
-    for task_name, metrics in raw_results["results"].items():
-        task_metrics = {
-            f"{task_name}/{metric.replace(',', '_')}": value for metric, value in metrics.items()
-        }
-        formatted_results.update(task_metrics)
-    print(f"Obtained evaluation results: {formatted_results}")
-
-    if config.wandb_env:
-        print("Logging results to W&B...")
-        update_wandb_summary(config.wandb_env, formatted_results)
+    if config.tracking is not None:
+        print("Generating table artifact of evaluation results...")
+        pass
 
 
 def run_lm_harness(config: LMHarnessJobConfig):
     print(f"Received job configuration: {config}")
 
-    # Resolve path and ensure exists
-    model_to_load = resolve_model_or_path(config)
-    # Using .options() to dynamically specify resource requirements
-    eval_func = evaluation_task.options(num_cpus=config.num_cpus, num_gpus=config.num_gpus)
-    eval_future = eval_func.remote(config, model_to_load)
+    eval_func = evaluation_task.options(num_cpus=config.ray.num_cpus, num_gpus=config.ray.num_gpus)
+    eval_future = eval_func.remote(config)
 
     timeout_seconds = config.timeout.seconds if config.timeout else None
     try:
         print("Waiting on evaluation task...")
         ray.get(eval_future, timeout=timeout_seconds)
-        print("Evaluation successfully completed")
+        print("Evaluation successfully completed!")
     except TimeoutError:
         print(
             f"Evaluation task timed out after {timeout_seconds} sec. "
" diff --git a/src/flamingo/jobs/lm_harness_config.py b/src/flamingo/jobs/lm_harness_config.py index cce3459b..d61eb8c7 100644 --- a/src/flamingo/jobs/lm_harness_config.py +++ b/src/flamingo/jobs/lm_harness_config.py @@ -10,8 +10,8 @@ class LMHarnessRayConfig(BaseFlamingoConfig): """Misc settings for Ray compute in the LM harness job.""" - use_gpu: bool = True - num_workers: int = 1 + num_cpus: int | float = 1 + num_gpus: int | float = 1 timeout: datetime.timedelta | None = None diff --git a/tests/integrations/wandb/test_artifact_config.py b/tests/integrations/wandb/test_artifact_config.py index af4b9bb0..5d1c85cb 100644 --- a/tests/integrations/wandb/test_artifact_config.py +++ b/tests/integrations/wandb/test_artifact_config.py @@ -18,4 +18,4 @@ def test_serde_round_trip(wandb_artifact_config): def test_wandb_path(wandb_artifact_config): - assert wandb_artifact_config.get_wandb_path() == "twitter/cortex/artifact-name:latest" + assert wandb_artifact_config.wandb_path() == "twitter/cortex/artifact-name:latest" diff --git a/tests/jobs/test_lm_harness_config.py b/tests/jobs/test_lm_harness_config.py index 3e2938ae..2214ead1 100644 --- a/tests/jobs/test_lm_harness_config.py +++ b/tests/jobs/test_lm_harness_config.py @@ -17,8 +17,9 @@ def lm_harness_evaluator_config(): @pytest.fixture def lm_harness_ray_config(): return LMHarnessRayConfig( - num_workers=4, - use_gpu=True, + num_cpus=2, + num_gpus=4, + timeout=3600, )