diff --git a/src/flamingo/integrations/huggingface/__init__.py b/src/flamingo/integrations/huggingface/__init__.py index 7380bd37..2c97958d 100644 --- a/src/flamingo/integrations/huggingface/__init__.py +++ b/src/flamingo/integrations/huggingface/__init__.py @@ -1,9 +1,13 @@ -from .model_name_or_path import ModelNameOrCheckpointPath +from .dataset_config import DatasetConfig +from .model_config import AutoModelConfig from .quantization_config import QuantizationConfig +from .tokenizer_config import AutoTokenizerConfig from .trainer_config import TrainerConfig __all__ = [ - "ModelNameOrCheckpointPath", + "AutoModelConfig", + "AutoTokenizerConfig", + "DatasetConfig", "QuantizationConfig", "TrainerConfig", ] diff --git a/src/flamingo/integrations/huggingface/dataset_config.py b/src/flamingo/integrations/huggingface/dataset_config.py new file mode 100644 index 00000000..7db253b9 --- /dev/null +++ b/src/flamingo/integrations/huggingface/dataset_config.py @@ -0,0 +1,10 @@ +from flamingo.integrations.wandb import WandbArtifactLink +from flamingo.types import BaseFlamingoConfig + + +class DatasetConfig(BaseFlamingoConfig): + """Settings passed to load a HuggingFace dataset.""" + + artifact: str | WandbArtifactLink + split_size: float | None = None + seed: int | None = None diff --git a/src/flamingo/integrations/huggingface/model_config.py b/src/flamingo/integrations/huggingface/model_config.py new file mode 100644 index 00000000..69561b86 --- /dev/null +++ b/src/flamingo/integrations/huggingface/model_config.py @@ -0,0 +1,19 @@ +from pydantic import validator + +from flamingo.integrations.huggingface.utils import is_valid_huggingface_repo_id +from flamingo.integrations.wandb import WandbArtifactLink +from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype + + +class AutoModelConfig(BaseFlamingoConfig): + """Settings passed to a HuggingFace AutoModel instantiation.""" + + artifact: str | WandbArtifactLink + trust_remote_code: bool = False + torch_dtype: SerializableTorchDtype = None + + @validator("artifact", pre=True, always=True) + def _validate_model_name(cls, v): + if isinstance(v, str) and not is_valid_huggingface_repo_id(v): + raise ValueError(f"{v} is not a valid HuggingFace model name.") + return v diff --git a/src/flamingo/integrations/huggingface/model_name_or_path.py b/src/flamingo/integrations/huggingface/model_name_or_path.py deleted file mode 100644 index bff347ff..00000000 --- a/src/flamingo/integrations/huggingface/model_name_or_path.py +++ /dev/null @@ -1,34 +0,0 @@ -from pathlib import Path - -from pydantic.dataclasses import dataclass - -from flamingo.integrations.huggingface.utils import is_valid_huggingface_repo_id - - -@dataclass -class ModelNameOrCheckpointPath: - """ - This class is explicitly used to validate if a string is - a valid HuggingFace model or can be used as a checkpoint. - - Checkpoint will be automatically assigned if it's a valid checkpoint; - it will be None if it's not valid. - """ - - # explictly needed for matching - __match_args__ = ("name", "checkpoint") - - name: str - checkpoint: str | None = None - - def __post_init__(self): - if isinstance(self.name, Path): - self.name = str(self.name) - - if Path(self.name).is_absolute(): - self.checkpoint = self.name - else: - self.checkpoint = None - - if self.checkpoint is None and not is_valid_huggingface_repo_id(self.name): - raise ValueError(f"{self.name} is not a valid checkpoint path or HF model name.") diff --git a/src/flamingo/integrations/huggingface/tokenizer_config.py b/src/flamingo/integrations/huggingface/tokenizer_config.py new file mode 100644 index 00000000..889d9f2c --- /dev/null +++ b/src/flamingo/integrations/huggingface/tokenizer_config.py @@ -0,0 +1,19 @@ +from typing import Any + +from flamingo.types import BaseFlamingoConfig + + +class AutoTokenizerConfig(BaseFlamingoConfig): + """Settings passed to a HuggingFace AutoTokenizer instantiation.""" + + name: str + trust_remote_code: bool | None = None + use_fast: bool | None = None + + def get_tokenizer_args(self) -> dict[str, Any]: + args = dict( + trust_remote_code=self.trust_remote_code, + use_fast=self.use_fast, + ) + # Only return non-None values so we get HuggingFace defaults when not specified + return {k: v for k, v in args.items() if v is not None} diff --git a/src/flamingo/integrations/huggingface/trainer_config.py b/src/flamingo/integrations/huggingface/trainer_config.py index e78e1391..cc4cad6d 100644 --- a/src/flamingo/integrations/huggingface/trainer_config.py +++ b/src/flamingo/integrations/huggingface/trainer_config.py @@ -1,21 +1,41 @@ -from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype +from typing import Any + +from flamingo.types import BaseFlamingoConfig class TrainerConfig(BaseFlamingoConfig): """Configuration for a HuggingFace trainer/training arguments.""" max_seq_length: int | None = None - num_train_epochs: int = 1 - batch_size: int = 16 - learning_rate: float = 1e-5 - weight_decay: float = 1e-3 - gradient_accumulation_steps: int = 1 - gradient_checkpointing: bool = False - trust_remote_code: bool = False - torch_dtype: SerializableTorchDtype = None - evaluation_strategy: str = "epoch" + num_train_epochs: int | None = None + per_device_train_batch_size: int | None = None + per_device_eval_batch_size: int | None = None + learning_rate: float | None = None + weight_decay: float | None = None + gradient_accumulation_steps: int | None = None + gradient_checkpointing: bool | None = None + evaluation_strategy: str | None = None eval_steps: float | None = None - logging_strategy: str = "steps" - logging_steps: float = 100 - save_strategy: str = "steps" - save_steps: int = 500 + logging_strategy: str | None = None + logging_steps: float | None = None + save_strategy: str | None = None + save_steps: int | None = None + + def get_training_args(self) -> dict[str, Any]: + args = dict( + num_train_epochs=self.num_train_epochs, + learning_rate=self.learning_rate, + per_device_train_batch_size=self.per_device_train_batch_size, + per_device_eval_batch_size=self.per_device_eval_batch_size, + gradient_accumulation_steps=self.gradient_accumulation_steps, + gradient_checkpointing=self.gradient_checkpointing, + weight_decay=self.weight_decay, + evaluation_strategy=self.evaluation_strategy, + eval_steps=self.eval_steps, + logging_strategy=self.logging_strategy, + logging_steps=self.logging_steps, + save_strategy=self.save_strategy, + save_steps=self.save_steps, + ) + # Only return non-None values so we get HuggingFace defaults when not specified + return {k: v for k, v in args.items() if v is not None} diff --git a/src/flamingo/integrations/wandb/wandb_environment.py b/src/flamingo/integrations/wandb/wandb_environment.py index a9f04afb..95af9202 100644 --- a/src/flamingo/integrations/wandb/wandb_environment.py +++ b/src/flamingo/integrations/wandb/wandb_environment.py @@ -1,7 +1,7 @@ import os +import secrets import warnings -import wandb from pydantic import Extra, root_validator from wandb.apis.public import Run @@ -69,6 +69,10 @@ def from_run(cls, run: Run) -> "WandbEnvironment": run_id=run.id, ) - def force_run_id(self) -> None: + def ensure_run_id(self, provided_run_id: str | None = None) -> None: + """Ensure that the run_id is set in the configuration. + + If None, the run_id is set to the passed value or a random 8-digit hexadecimal string. + """ if self.run_id is None: - self.run_id = wandb.run + self.run_id = provided_run_id or secrets.token_hex(nbytes=4) diff --git a/src/flamingo/jobs/__init__.py b/src/flamingo/jobs/__init__.py index 4a616d75..cb4fa167 100644 --- a/src/flamingo/jobs/__init__.py +++ b/src/flamingo/jobs/__init__.py @@ -1,10 +1,8 @@ -from .base_config import BaseJobConfig from .finetuning_config import FinetuningJobConfig from .lm_harness_config import LMHarnessJobConfig, ModelNameOrCheckpointPath from .simple_config import SimpleJobConfig __all__ = [ - "BaseJobConfig", "SimpleJobConfig", "FinetuningJobConfig", "LMHarnessJobConfig", diff --git a/src/flamingo/jobs/base_config.py b/src/flamingo/jobs/base_config.py deleted file mode 100644 index 3fec2edc..00000000 --- a/src/flamingo/jobs/base_config.py +++ /dev/null @@ -1,7 +0,0 @@ -from flamingo.types import BaseFlamingoConfig - - -class BaseJobConfig(BaseFlamingoConfig): - """Configuration defining a job to submit to the Ray cluster.""" - - pass diff --git a/src/flamingo/jobs/drivers/__init__.py b/src/flamingo/jobs/drivers/__init__.py index e6f97d79..bc3b2c47 100644 --- a/src/flamingo/jobs/drivers/__init__.py +++ b/src/flamingo/jobs/drivers/__init__.py @@ -1,6 +1,6 @@ -from .finetuning import run_finetuning -from .lm_harness import run_lm_harness -from .ludwig import run_ludwig -from .simple import run_simple +from .finetuning_job import run_finetuning +from .lm_harness_job import run_lm_harness +from .ludwig_job import run_ludwig +from .simple_job import run_simple __all__ = ["run_finetuning", "run_lm_harness", "run_ludwig", "run_simple"] diff --git a/src/flamingo/jobs/drivers/finetuning.py b/src/flamingo/jobs/drivers/finetuning_job.py similarity index 61% rename from src/flamingo/jobs/drivers/finetuning.py rename to src/flamingo/jobs/drivers/finetuning_job.py index 7475804e..617525e0 100644 --- a/src/flamingo/jobs/drivers/finetuning.py +++ b/src/flamingo/jobs/drivers/finetuning_job.py @@ -11,51 +11,38 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, TrainingArguments from trl import SFTTrainer -from flamingo.integrations.wandb import update_wandb_summary from flamingo.jobs import FinetuningJobConfig def is_wandb_enabled(config: FinetuningJobConfig): # Only report to WandB on the rank 0 worker # Reference: https://docs.ray.io/en/latest/train/user-guides/experiment-tracking.html - return config.wandb_env and train.get_context().get_world_rank() == 0 + return config.tracking is not None and train.get_context().get_world_rank() == 0 -def get_training_args(config: FinetuningJobConfig) -> TrainingArguments: +def get_training_arguments(config: FinetuningJobConfig) -> TrainingArguments: """Get TrainingArguments appropriate for the worker rank and job config.""" + training_args = config.trainer.get_training_args() if config.trainer else {} return TrainingArguments( output_dir="out", # Local checkpoint path on a worker - num_train_epochs=config.num_train_epochs, - learning_rate=config.learning_rate, - per_device_train_batch_size=config.batch_size, - per_device_eval_batch_size=config.batch_size, - gradient_accumulation_steps=config.gradient_accumulation_steps, - gradient_checkpointing=config.gradient_checkpointing, - weight_decay=config.weight_decay, - evaluation_strategy=config.evaluation_strategy, - eval_steps=config.eval_steps, - logging_strategy=config.logging_strategy, - logging_steps=config.logging_steps, - save_strategy=config.save_strategy, - save_steps=config.save_steps, - run_name=config.wandb_name, report_to="wandb" if is_wandb_enabled(config) else "none", - no_cuda=not config.scaling_config.use_gpu, + no_cuda=not config.scaling.use_gpu, push_to_hub=False, disable_tqdm=True, logging_dir=None, + **training_args, ) def get_datasets(config: FinetuningJobConfig) -> DatasetDict: - # TODO: Refactor me somehow - ... + # TODO: Implement me + return DatasetDict() def get_model(config: FinetuningJobConfig) -> PreTrainedModel: device_map, bnb_config = None, None - if config.quantization_config: - bnb_config = config.quantization_config.as_huggingface() + if config.quantization is not None: + bnb_config = config.quantization.as_huggingface() # When quantization is enabled, model must all be on same GPU to work with DDP # If a device_map is not specified we will get accelerate errors downstream # Reference: https://github.com/huggingface/accelerate/issues/1840#issuecomment-1683105994 @@ -64,20 +51,18 @@ def get_model(config: FinetuningJobConfig) -> PreTrainedModel: print(f"Setting model device_map = {device_map} to enable quantization") return AutoModelForCausalLM.from_pretrained( - config.model, - trust_remote_code=config.trust_remote_code, - torch_dtype=config.torch_dtype, + pretrained_model_name_or_path=config.model.name, + trust_remote_code=config.model.trust_remote_code, + torch_dtype=config.model.torch_dtype, quantization_config=bnb_config, device_map=device_map, ) def get_tokenizer(config: FinetuningJobConfig): - tokenizer = AutoTokenizer.from_pretrained( - config.tokenizer or config.model, - trust_remote_code=config.trust_remote_code, - use_fast=True, - ) + tokenizer_name = config.tokenizer.name if config.tokenizer else config.model.name + tokenizer_args = config.tokenizer.get_tokenizer_args() if config.tokenizer else {} + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_args) if not tokenizer.pad_token_id: # Pad token required for generating consistent batch sizes tokenizer.pad_token_id = tokenizer.eos_token_id @@ -90,14 +75,25 @@ def train_func(config_data: dict): tokenizer = get_tokenizer(config) datasets = get_datasets(config) - training_args = get_training_args(config) + training_args = get_training_arguments(config) + + # Manually initialize run in order to control run ID + if is_wandb_enabled(config): + env = config.tracking + wandb.init( + id=env.run_id, + name=env.name, + project=env.project, + entity=env.entity, + group=env.run_group, + ) trainer = SFTTrainer( model=model, tokenizer=tokenizer, args=training_args, - peft_config=config.lora_config, - max_seq_length=config.max_seq_length, + peft_config=config.lora, + max_seq_length=config.trainer.max_seq_length if config.trainer else None, train_dataset=datasets["train"], eval_dataset=datasets["test"], dataset_text_field="text", @@ -114,23 +110,24 @@ def train_func(config_data: dict): def run_finetuning(config: FinetuningJobConfig): print(f"Received job configuration: {config}") + if config.tracking: + # Ensure the run_id is set so that the W&B run can be initialized deterministically + config.tracking.ensure_run_id() + run_config = RunConfig( - name=config.wandb_name, - storage_path=config.storage_path, + name=config.tracking.name if config.tracking else None, + storage_path=config.ray.storage_path, checkpoint_config=CheckpointConfig(num_to_keep=1), ) trainer = TorchTrainer( train_loop_per_worker=train_func, train_loop_config=json.loads(config.json()), - scaling_config=config.scaling_config, + scaling_config=config.ray.get_scaling_config(), run_config=run_config, ) result = trainer.fit() print(f"Training result: {result}") - # Log additional training metrics to completed WandB run - if config.wandb_env: - result_paths = {"ray/result_path": result.path} - if result.checkpoint: - result_paths["ray/checkpoint_path"] = f"{result.checkpoint.path}/checkpoint" - update_wandb_summary(config.wandb_env, result_paths) + if config.tracking: + # TODO: Add ref artifact here + pass diff --git a/src/flamingo/jobs/drivers/lm_harness.py b/src/flamingo/jobs/drivers/lm_harness_job.py similarity index 95% rename from src/flamingo/jobs/drivers/lm_harness.py rename to src/flamingo/jobs/drivers/lm_harness_job.py index 03b3066e..4ac263a0 100644 --- a/src/flamingo/jobs/drivers/lm_harness.py +++ b/src/flamingo/jobs/drivers/lm_harness_job.py @@ -5,8 +5,8 @@ from lm_eval.models.huggingface import HFLM from peft import PeftConfig -from flamingo.configs import LMHarnessJobConfig, ModelNameOrCheckpointPath from flamingo.integrations.wandb import get_wandb_summary, update_wandb_summary +from flamingo.jobs import LMHarnessJobConfig, ModelNameOrCheckpointPath def resolve_model_or_path(config: LMHarnessJobConfig) -> str: @@ -97,6 +97,10 @@ def evaluation_task(config: LMHarnessJobConfig, model_to_load: str) -> None: def run_lm_harness(config: LMHarnessJobConfig): print(f"Received job configuration: {config}") + if config.tracking: + # Ensure the run_id is set so that the W&B run can be initialized deterministically + config.tracking.ensure_run_id() + # Resolve path and ensure exists model_to_load = resolve_model_or_path(config) diff --git a/src/flamingo/jobs/drivers/ludwig.py b/src/flamingo/jobs/drivers/ludwig_job.py similarity index 100% rename from src/flamingo/jobs/drivers/ludwig.py rename to src/flamingo/jobs/drivers/ludwig_job.py diff --git a/src/flamingo/jobs/drivers/simple.py b/src/flamingo/jobs/drivers/simple_job.py similarity index 100% rename from src/flamingo/jobs/drivers/simple.py rename to src/flamingo/jobs/drivers/simple_job.py diff --git a/src/flamingo/jobs/finetuning_config.py b/src/flamingo/jobs/finetuning_config.py index 39a578f3..bbb68a74 100644 --- a/src/flamingo/jobs/finetuning_config.py +++ b/src/flamingo/jobs/finetuning_config.py @@ -1,25 +1,39 @@ from peft import LoraConfig -from pydantic import validator +from pydantic import Field, validator from ray.train import ScalingConfig -from flamingo.integrations.huggingface import QuantizationConfig, TrainerConfig +from flamingo.integrations.huggingface import ( + AutoModelConfig, + AutoTokenizerConfig, + DatasetConfig, + QuantizationConfig, + TrainerConfig, +) from flamingo.integrations.huggingface.utils import is_valid_huggingface_repo_id from flamingo.integrations.wandb import WandbEnvironment -from flamingo.jobs import BaseJobConfig +from flamingo.types import BaseFlamingoConfig -class FinetuningJobConfig(BaseJobConfig): +class RayTrainConfig(BaseFlamingoConfig): + use_gpu: bool = True + num_workers: int | None = None + storage_path: str | None = None + + def get_scaling_config(self) -> ScalingConfig: + return ScalingConfig(use_gpu=self.use_gpu, num_workers=self.num_workers) + + +class FinetuningJobConfig(BaseFlamingoConfig): """Configuration to submit an LLM finetuning job.""" - model: str - # TODO: Add back dataset config with new data implementation - tokenizer: str | None = None + model: AutoModelConfig + dataset: DatasetConfig + tokenizer: AutoTokenizerConfig | None = None trainer: TrainerConfig | None = None lora: LoraConfig | None = None # TODO: Create our own config type quantization: QuantizationConfig | None = None tracking: WandbEnvironment | None = None - scaling: ScalingConfig | None = None # TODO: Create our own config type - storage_path: str | None = None + ray: RayTrainConfig = Field(default_factory=RayTrainConfig) @validator("model") def _validate_model_name(cls, v): diff --git a/src/flamingo/jobs/lm_harness_config.py b/src/flamingo/jobs/lm_harness_config.py index 2fbc5377..fd8cc090 100644 --- a/src/flamingo/jobs/lm_harness_config.py +++ b/src/flamingo/jobs/lm_harness_config.py @@ -1,15 +1,11 @@ import datetime -from pydantic import validator +from flamingo.integrations.huggingface import AutoModelConfig, QuantizationConfig +from flamingo.integrations.wandb import WandbEnvironment +from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype -from flamingo.integrations.huggingface import QuantizationConfig -from flamingo.integrations.huggingface.utils import is_valid_huggingface_repo_id -from flamingo.integrations.wandb import WandbArtifactLink, WandbEnvironment -from flamingo.jobs import BaseJobConfig -from flamingo.types import SerializableTorchDtype - -class LMHarnessJobConfig(BaseJobConfig): +class LMHarnessJobConfig(BaseFlamingoConfig): """Configuration to run an lm-evaluation-harness evaluation job. This job loads an existing checkpoint path from Ray storage to run evaluation against, @@ -19,7 +15,7 @@ class LMHarnessJobConfig(BaseJobConfig): which will take prescedence over the W&B checkpoint path. """ - model: str | WandbArtifactLink + model: AutoModelConfig tasks: list[str] batch_size: int | None = None num_fewshot: int | None = None @@ -31,9 +27,3 @@ class LMHarnessJobConfig(BaseJobConfig): num_cpus: int = 1 num_gpus: int = 1 timeout: datetime.timedelta | None = None - - @validator("model", pre=True, always=True) - def _validate_model_name(cls, v): - if isinstance(v, str) and not is_valid_huggingface_repo_id(v): - raise ValueError(f"{v} is not a valid HuggingFace model name.") - return v diff --git a/src/flamingo/jobs/simple_config.py b/src/flamingo/jobs/simple_config.py index bc18fa96..24fa60a8 100644 --- a/src/flamingo/jobs/simple_config.py +++ b/src/flamingo/jobs/simple_config.py @@ -1,5 +1,5 @@ -from flamingo.jobs import BaseJobConfig +from flamingo.types import BaseFlamingoConfig -class SimpleJobConfig(BaseJobConfig): +class SimpleJobConfig(BaseFlamingoConfig): magic_number: int