diff --git a/pyproject.toml b/pyproject.toml
index e14aa8b4..a3777a99 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,12 +11,12 @@
 requires-python = ">=3.10,<3.11"
 dependencies = [
     "click==8.1.7",
-    "ray[default]==2.7.0",
     "torch==2.1.0",
     "scipy==1.10.1",
     "wandb==0.16.1",
     "pydantic-yaml==1.2.0",
     "pydantic==1.10.8",
+    "ray[default]==2.7.0",
 ]
 
 [project.optional-dependencies]
diff --git a/src/flamingo/cli.py b/src/flamingo/cli.py
index 80743421..275ae2ea 100644
--- a/src/flamingo/cli.py
+++ b/src/flamingo/cli.py
@@ -1,14 +1,40 @@
 import click
 
+from flamingo.jobs import run_finetuning, run_lm_harness, run_simple
+from flamingo.jobs.configs import FinetuningJobConfig, LMHarnessJobConfig, SimpleJobConfig
+
 
 @click.group()
-def main():
+def cli():
     pass
 
 
+@click.command("simple")
+@click.option("--config", type=str)
+def run_simple_cli(config: str) -> None:
+    config = SimpleJobConfig.from_yaml_file(config)
+    run_simple.main(config)
+
+
+@click.command("finetune")
+@click.option("--config", type=str)
+def run_finetuning_cli(config: str) -> None:
+    config = FinetuningJobConfig.from_yaml_file(config)
+    run_finetuning.main(config)
+
+
+@click.command("lm-harness")
+@click.option("--config", type=str)
+def run_lm_harness_cli(config: str) -> None:
+    config = LMHarnessJobConfig.from_yaml_file(config)
+    run_lm_harness.main(config)
+
+
 # need to add the group / command function itself, not the module
-main.add_command(simple.driver)
+cli.add_command(run_simple_cli)
+cli.add_command(run_finetuning_cli)
+cli.add_command(run_lm_harness_cli)
 
 
 if __name__ == "__main__":
-    main()
+    cli()
diff --git a/src/flamingo/integrations/wandb/wandb_mixin.py b/src/flamingo/integrations/wandb/wandb_mixin.py
deleted file mode 100644
index 56266b81..00000000
--- a/src/flamingo/integrations/wandb/wandb_mixin.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from flamingo.integrations.wandb import WandbEnvironment
-from flamingo.types import BaseFlamingoConfig
-
-
-class WandbEnvironmentMixin(BaseFlamingoConfig):
-    """Mixin for a config that contains W&B environment settings."""
-
-    wandb_env: WandbEnvironment | None = None
-
-    @property
-    def env_vars(self) -> dict[str, str]:
-        return self.wandb_env.env_vars if self.wandb_env else {}
-
-    @property
-    def wandb_name(self) -> str | None:
-        """Return the W&B run name, if it exists."""
-        return self.wandb_env.name if self.wandb_env else None
-
-    @property
-    def wandb_project(self) -> str | None:
-        """Return the W&B project name, if it exists."""
-        return self.wandb_env.project if self.wandb_env else None
diff --git a/src/flamingo/jobs/__init__.py b/src/flamingo/jobs/__init__.py
index 427b7422..a4e72c0d 100644
--- a/src/flamingo/jobs/__init__.py
+++ b/src/flamingo/jobs/__init__.py
@@ -1,11 +1,3 @@
-from .base_config import BaseJobConfig
-from .evaluation_config import EvaluationJobConfig
-from .finetuning_config import FinetuningJobConfig
-from .simple_config import SimpleJobConfig
-
-__all__ = [
-    "BaseJobConfig",
-    "SimpleJobConfig",
-    "FinetuningJobConfig",
-    "EvaluationJobConfig",
-]
+# ruff: noqa
+from .configs import *
+from .entrypoints import *
diff --git a/src/flamingo/jobs/configs/__init__.py b/src/flamingo/jobs/configs/__init__.py
new file mode 100644
index 00000000..dc5c0bfc
--- /dev/null
+++ b/src/flamingo/jobs/configs/__init__.py
@@ -0,0 +1,11 @@
+from .base_config import BaseJobConfig
+from .finetuning_config import FinetuningJobConfig
+from .lm_harness_config import LMHarnessJobConfig
+from .simple_config import SimpleJobConfig
+
+__all__ = [
+    "BaseJobConfig",
+    "SimpleJobConfig",
+ "FinetuningJobConfig", + "LMHarnessJobConfig", +] diff --git a/src/flamingo/jobs/base_config.py b/src/flamingo/jobs/configs/base_config.py similarity index 100% rename from src/flamingo/jobs/base_config.py rename to src/flamingo/jobs/configs/base_config.py diff --git a/src/flamingo/jobs/finetuning_config.py b/src/flamingo/jobs/configs/finetuning_config.py similarity index 85% rename from src/flamingo/jobs/finetuning_config.py rename to src/flamingo/jobs/configs/finetuning_config.py index 51a733e4..2ffcb6c9 100644 --- a/src/flamingo/jobs/finetuning_config.py +++ b/src/flamingo/jobs/configs/finetuning_config.py @@ -4,12 +4,12 @@ from flamingo.integrations.huggingface import QuantizationConfig from flamingo.integrations.huggingface.utils import is_valid_huggingface_model_name -from flamingo.integrations.wandb import WandbEnvironmentMixin +from flamingo.integrations.wandb import WandbEnvironment from flamingo.jobs import BaseJobConfig from flamingo.types import SerializableTorchDtype -class FinetuningJobConfig(WandbEnvironmentMixin, BaseJobConfig): +class FinetuningJobConfig(BaseJobConfig): """Configuration to submit an LLM finetuning job.""" model: str @@ -32,6 +32,7 @@ class FinetuningJobConfig(WandbEnvironmentMixin, BaseJobConfig): logging_steps: float = 100 save_strategy: str = "steps" save_steps: int = 500 + wandb_env: WandbEnvironment | None = None # Lora/quantization lora_config: LoraConfig | None = None # TODO: Create our own config type quantization_config: QuantizationConfig | None = None @@ -45,7 +46,3 @@ def _validate_modelname(cls, v): # noqa: N805 return v else: raise (ValueError(f"`{v}` is not a valid HuggingFace model name.")) - - @property - def entrypoint_command(self) -> str: - return f"python run_finetuning.py --config_json '{self.json()}'" diff --git a/src/flamingo/jobs/evaluation_config.py b/src/flamingo/jobs/configs/lm_harness_config.py similarity index 93% rename from src/flamingo/jobs/evaluation_config.py rename to src/flamingo/jobs/configs/lm_harness_config.py index a2d8c85b..a7f6de38 100644 --- a/src/flamingo/jobs/evaluation_config.py +++ b/src/flamingo/jobs/configs/lm_harness_config.py @@ -8,7 +8,7 @@ from flamingo.integrations.huggingface import QuantizationConfig from flamingo.integrations.huggingface.utils import is_valid_huggingface_model_name -from flamingo.integrations.wandb import WandbEnvironment, WandbEnvironmentMixin +from flamingo.integrations.wandb import WandbEnvironment from flamingo.jobs import BaseJobConfig from flamingo.types import SerializableTorchDtype @@ -42,7 +42,7 @@ def __post_init__(self, checkpoint): raise (ValueError(f"{self.name} is not a valid checkpoint path or HF model name")) -class EvaluationJobConfig(WandbEnvironmentMixin, BaseJobConfig): +class LMHarnessJobConfig(BaseJobConfig): """Configuration to run an lm-evaluation-harness evaluation job. 
 
     This job loads an existing checkpoint path
@@ -66,6 +66,7 @@ class Config:
     trust_remote_code: bool = False
     torch_dtype: SerializableTorchDtype = None
    quantization_config: QuantizationConfig | None = None
+    wandb_env: WandbEnvironment | None = None
     num_cpus: int = 1
     num_gpus: int = 1
     timeout: datetime.timedelta | None = None
@@ -112,7 +113,3 @@ def _validate_modelname_or_checkpoint(cls, values) -> Any:
             raise (ValueError(f"{mnp} is not a valid HuggingFaceModel or checkpoint path."))
 
         return values
-
-    @property
-    def entrypoint_command(self) -> str:
-        return f"python run_evaluation.py --config_json '{self.json()}'"
diff --git a/src/flamingo/jobs/configs/simple_config.py b/src/flamingo/jobs/configs/simple_config.py
new file mode 100644
index 00000000..373cb514
--- /dev/null
+++ b/src/flamingo/jobs/configs/simple_config.py
@@ -0,0 +1,7 @@
+from .base_config import BaseJobConfig
+
+
+class SimpleJobConfig(BaseJobConfig):
+    """A simple job to demonstrate the submission interface."""
+
+    magic_number: int
diff --git a/src/flamingo/jobs/entrypoints/__init__.py b/src/flamingo/jobs/entrypoints/__init__.py
new file mode 100644
index 00000000..8fa9df64
--- /dev/null
+++ b/src/flamingo/jobs/entrypoints/__init__.py
@@ -0,0 +1,5 @@
+from . import run_finetuning
+from . import run_lm_harness
+from . import run_simple
+
+__all__ = ["run_finetuning", "run_lm_harness", "run_simple"]
diff --git a/src/flamingo/jobs/entrypoints/run_finetuning.py b/src/flamingo/jobs/entrypoints/run_finetuning.py
new file mode 100644
index 00000000..8b2b90e6
--- /dev/null
+++ b/src/flamingo/jobs/entrypoints/run_finetuning.py
@@ -0,0 +1,136 @@
+import json
+
+import torch
+import wandb
+from accelerate import Accelerator
+from datasets import DatasetDict
+from ray import train
+from ray.train import CheckpointConfig, RunConfig
+from ray.train.huggingface.transformers import RayTrainReportCallback, prepare_trainer
+from ray.train.torch import TorchTrainer
+from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, TrainingArguments
+from trl import SFTTrainer
+
+from flamingo.integrations.wandb import update_wandb_summary
+from flamingo.jobs import FinetuningJobConfig
+
+
+def is_wandb_enabled(config: FinetuningJobConfig):
+    # Only report to WandB on the rank 0 worker
+    # Reference: https://docs.ray.io/en/latest/train/user-guides/experiment-tracking.html
+    return config.wandb_env and train.get_context().get_world_rank() == 0
+
+
+def get_training_args(config: FinetuningJobConfig) -> TrainingArguments:
+    """Get TrainingArguments appropriate for the worker rank and job config."""
+    return TrainingArguments(
+        output_dir="out",  # Local checkpoint path on a worker
+        num_train_epochs=config.num_train_epochs,
+        learning_rate=config.learning_rate,
+        per_device_train_batch_size=config.batch_size,
+        per_device_eval_batch_size=config.batch_size,
+        gradient_accumulation_steps=config.gradient_accumulation_steps,
+        gradient_checkpointing=config.gradient_checkpointing,
+        weight_decay=config.weight_decay,
+        evaluation_strategy=config.evaluation_strategy,
+        eval_steps=config.eval_steps,
+        logging_strategy=config.logging_strategy,
+        logging_steps=config.logging_steps,
+        save_strategy=config.save_strategy,
+        save_steps=config.save_steps,
+        run_name=config.wandb_env.name if config.wandb_env else None,
+        report_to="wandb" if is_wandb_enabled(config) else "none",
+        no_cuda=not config.scaling_config.use_gpu,
+        push_to_hub=False,
+        disable_tqdm=True,
+        logging_dir=None,
+    )
+
+
+def get_datasets(config: FinetuningJobConfig) -> DatasetDict:
+    # TODO: Refactor me somehow
+    ...
+
+
+def get_model(config: FinetuningJobConfig) -> PreTrainedModel:
+    device_map, bnb_config = None, None
+    if config.quantization_config:
+        bnb_config = config.quantization_config.as_huggingface()
+        # When quantization is enabled, model must all be on same GPU to work with DDP
+        # If a device_map is not specified we will get accelerate errors downstream
+        # Reference: https://github.com/huggingface/accelerate/issues/1840#issuecomment-1683105994
+        current_device = Accelerator().local_process_index if torch.cuda.is_available() else "cpu"
+        device_map = {"": current_device}
+        print(f"Setting model device_map = {device_map} to enable quantization")
+
+    return AutoModelForCausalLM.from_pretrained(
+        config.model,
+        trust_remote_code=config.trust_remote_code,
+        torch_dtype=config.torch_dtype,
+        quantization_config=bnb_config,
+        device_map=device_map,
+    )
+
+
+def get_tokenizer(config: FinetuningJobConfig):
+    tokenizer = AutoTokenizer.from_pretrained(
+        config.tokenizer or config.model,
+        trust_remote_code=config.trust_remote_code,
+        use_fast=True,
+    )
+    if not tokenizer.pad_token_id:
+        # Pad token required for generating consistent batch sizes
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+    return tokenizer
+
+
+def train_func(config_data: dict):
+    config = FinetuningJobConfig(**config_data)
+    model = get_model(config)
+    tokenizer = get_tokenizer(config)
+
+    datasets = get_datasets(config)
+    training_args = get_training_args(config)
+
+    trainer = SFTTrainer(
+        model=model,
+        tokenizer=tokenizer,
+        args=training_args,
+        peft_config=config.lora_config,
+        max_seq_length=config.max_seq_length,
+        train_dataset=datasets["train"],
+        eval_dataset=datasets["test"],
+        dataset_text_field="text",
+    )
+    trainer.add_callback(RayTrainReportCallback())
+    trainer = prepare_trainer(trainer)
+    trainer.train()
+
+    # Force WandB finish on rank 0 worker
+    if is_wandb_enabled(config):
+        wandb.finish()
+
+
+def main(config: FinetuningJobConfig):
+    print(f"Received job configuration: {config}")
+
+    run_config = RunConfig(
+        name=config.wandb_env.name if config.wandb_env else None,
+        storage_path=config.storage_path,
+        checkpoint_config=CheckpointConfig(num_to_keep=1),
+    )
+    trainer = TorchTrainer(
+        train_loop_per_worker=train_func,
+        train_loop_config=json.loads(config.json()),
+        scaling_config=config.scaling_config,
+        run_config=run_config,
+    )
+    result = trainer.fit()
+    print(f"Training result: {result}")
+
+    # Log additional training metrics to completed WandB run
+    if config.wandb_env:
+        result_paths = {"ray/result_path": result.path}
+        if result.checkpoint:
+            result_paths["ray/checkpoint_path"] = f"{result.checkpoint.path}/checkpoint"
+        update_wandb_summary(config.wandb_env, result_paths)
diff --git a/src/flamingo/jobs/entrypoints/run_lm_harness.py b/src/flamingo/jobs/entrypoints/run_lm_harness.py
new file mode 100644
index 00000000..cc2286b6
--- /dev/null
+++ b/src/flamingo/jobs/entrypoints/run_lm_harness.py
@@ -0,0 +1,118 @@
+from pathlib import Path
+
+import lm_eval
+import ray
+from lm_eval.models.huggingface import HFLM
+from peft import PeftConfig
+
+from flamingo.integrations.wandb import get_wandb_summary, update_wandb_summary
+from flamingo.jobs import LMHarnessJobConfig, ModelNameOrCheckpointPath
+
+
+def resolve_model_or_path(config: LMHarnessJobConfig) -> str:
+    mn_or_path = None
+    match config.model_name_or_path:
+        case None:
+            print("Attempting to resolve checkpoint path from existing W&B run...")
+            run_summary = get_wandb_summary(config.wandb_env)
+            cp = Path(run_summary["ray/checkpoint_path"])
+            print(f"Using checkpoint path from wandb run: {cp}")
from wandb run: {cp}") + if not cp.exists(): + raise (FileNotFoundError(f"{mn_or_path} cannot be found.")) + mn_or_path = str(cp) + case ModelNameOrCheckpointPath(checkpoint=None) as x: + print("No checkpoint; will attempt to load model from HuggingFace") + mn_or_path = x.name + case ModelNameOrCheckpointPath(checkpoint=ckpt): + print(f"Checkpoint found; will attempt to load model from {ckpt}") + mn_or_path = ckpt + case _: + raise ( + ValueError( + "Something is wrong with the passed " + f"model_name_or_path: {config.model_name_or_path}" + ) + ) + return mn_or_path + + +def load_harness_model(config: LMHarnessJobConfig, model_to_load: str) -> HFLM: + # We don't know if the checkpoint is adapter weights or merged model weights + # Try to load as an adapter and fall back to the checkpoint containing the full model + try: + adapter_config = PeftConfig.from_pretrained(model_to_load) + pretrained = adapter_config.base_model_name_or_path + peft = model_to_load + except ValueError as e: + print( + f"Unable to load model as adapter: {e}. " + "This is expected if the checkpoint does not contain adapter weights." + ) + pretrained = model_to_load + peft = None + + # Return lm-harness model wrapper class + quantization_kwargs = config.quantization_config.dict() if config.quantization_config else {} + return HFLM( + pretrained=pretrained, + tokenizer=pretrained, + peft=peft, + device="cuda" if config.num_gpus > 0 else None, + trust_remote_code=config.trust_remote_code, + dtype=config.torch_dtype if config.torch_dtype else "auto", + **quantization_kwargs, + ) + + +@ray.remote +def run_evaluation(config: LMHarnessJobConfig, model_to_load: str) -> None: + print("Initializing lm-harness tasks...") + lm_eval.tasks.initialize_tasks() + + print("Running lm-harness evaluation inside remote function...") + llm = load_harness_model(config, model_to_load) + raw_results = lm_eval.simple_evaluate( + model=llm, + tasks=config.tasks, + num_fewshot=config.num_fewshot, + batch_size=config.batch_size, + limit=config.limit, + log_samples=False, + ) + print("Finished lm-harness evaluation inside remote function") + + formatted_results = {} + for task_name, metrics in raw_results["results"].items(): + task_metrics = { + f"{task_name}/{metric.replace(',', '_')}": value for metric, value in metrics.items() + } + formatted_results.update(task_metrics) + print(f"Obtained evaluation results: {formatted_results}") + + if config.wandb_env: + print("Logging results to W&B...") + update_wandb_summary(config.wandb_env, formatted_results) + + +def main(config: LMHarnessJobConfig): + print(f"Received job configuration: {config}") + + # Resolve path and ensure exists + model_to_load = resolve_model_or_path(config) + + # Using .options() to dynamically specify resource requirements + eval_func = run_evaluation.options(num_cpus=config.num_cpus, num_gpus=config.num_gpus) + eval_future = eval_func.remote(config, model_to_load) + + timeout_seconds = config.timeout.seconds if config.timeout else None + try: + print("Waiting on evaluation task...") + ray.get(eval_future, timeout=timeout_seconds) + print("Evaluation successfully completed") + except TimeoutError: + print( + f"Evaluation task timed out after {timeout_seconds} sec. " + "If the evaluation runner finished but the task failed to shut down, " + "please check if your results were still generated and persisted." 
+        )
+        raise
diff --git a/src/flamingo/jobs/entrypoints/run_simple.py b/src/flamingo/jobs/entrypoints/run_simple.py
new file mode 100644
index 00000000..897002ca
--- /dev/null
+++ b/src/flamingo/jobs/entrypoints/run_simple.py
@@ -0,0 +1,5 @@
+from flamingo.jobs.configs import SimpleJobConfig
+
+
+def main(config: SimpleJobConfig):
+    print(f"The magic number is {config.magic_number}")
diff --git a/src/flamingo/jobs/simple_config.py b/src/flamingo/jobs/simple_config.py
deleted file mode 100644
index 04cb270b..00000000
--- a/src/flamingo/jobs/simple_config.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from flamingo.jobs import BaseJobConfig
-
-
-class SimpleJobConfig(BaseJobConfig):
-    """A simple job to demonstrate the submission interface."""
-
-    magic_number: int
-
-    @property
-    def env_vars(self) -> dict[str, str]:
-        return {}
-
-    @property
-    def entrypoint_command(self) -> str:
-        return f"python simple.py --magic_number '{self.magic_number}'"
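Example usage (reviewer note, not part of the patch): a minimal sketch of how the reworked click CLI is intended to be driven, assuming the package is installed and that `from_yaml_file` reads a plain YAML config; the file name and magic number below are illustrative only.

    # simple_config.yaml (hypothetical)
    magic_number: 42

    # submit the simple job through the new "simple" command
    python -m flamingo.cli simple --config simple_config.yaml

The "finetune" and "lm-harness" commands take their FinetuningJobConfig / LMHarnessJobConfig YAML files through the same `--config` option.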