From a03d272d3622540901e5e1581b01143a223dfb25 Mon Sep 17 00:00:00 2001
From: Sean Friedowitz
Date: Tue, 16 Jan 2024 11:47:12 -0800
Subject: [PATCH] working on w&b artifact links

---
 pyproject.toml                                |  2 +-
 .../integrations/huggingface/model_config.py  | 10 ++---
 .../huggingface/trainer_config.py             |  8 +++-
 src/flamingo/integrations/wandb/utils.py      | 45 +++++++++----------
 .../integrations/wandb/wandb_artifact_link.py |  4 +-
 .../integrations/wandb/wandb_environment.py   | 32 +++++++------
 src/flamingo/jobs/__init__.py                 |  3 +-
 src/flamingo/jobs/drivers/finetuning.py       | 28 +++++++-----
 src/flamingo/jobs/finetuning_config.py        | 35 +++++++++++++--
 src/flamingo/jobs/lm_harness_config.py        |  3 +-
 tests/{configs => jobs}/__init__.py           |  0
 .../test_finetuning_config.py                 |  0
 .../test_lm_harness_config.py                 |  0
 13 files changed, 105 insertions(+), 65 deletions(-)
 rename tests/{configs => jobs}/__init__.py (100%)
 rename tests/{configs => jobs}/test_finetuning_config.py (100%)
 rename tests/{configs => jobs}/test_lm_harness_config.py (100%)

diff --git a/pyproject.toml b/pyproject.toml
index 6b2b2d75..4b738173 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,7 @@ evaluate = ["lm-eval==0.4.0", "einops"]

 test = ["ruff==0.1.4", "pytest==7.4.3", "pytest-cov==4.1.0"]

-all = ["flamingo[finetune,ludwig,evaluate,test]"]
+all = ["flamingo[finetune,evaluate,test]"]

 [project.scripts]
 flamingo = "flamingo.cli:main"
diff --git a/src/flamingo/integrations/huggingface/model_config.py b/src/flamingo/integrations/huggingface/model_config.py
index 9978b22f..60d8e43d 100644
--- a/src/flamingo/integrations/huggingface/model_config.py
+++ b/src/flamingo/integrations/huggingface/model_config.py
@@ -1,19 +1,19 @@
-from peft import LoraConfig
 from pydantic import validator

-from flamingo.integrations.huggingface import QuantizationConfig
 from flamingo.integrations.huggingface.utils import repo_name_validator
 from flamingo.integrations.wandb import WandbArtifactLink
 from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype


 class AutoModelConfig(BaseFlamingoConfig):
-    """Settings passed to a HuggingFace AutoModel instantiation."""
+    """Settings passed to a HuggingFace AutoModel instantiation.
+
+    The model path can either be a string corresponding to a HuggingFace repo ID,
+    or an artifact link to a reference artifact on W&B.
+    """

     path: str | WandbArtifactLink
     trust_remote_code: bool = False
     torch_dtype: SerializableTorchDtype = None
-    quantization: QuantizationConfig | None = None
-    lora: LoraConfig | None = None  # TODO: Create own dataclass here

     _path_validator = validator("path", allow_reuse=True, pre=True)(repo_name_validator)
diff --git a/src/flamingo/integrations/huggingface/trainer_config.py b/src/flamingo/integrations/huggingface/trainer_config.py
index cc4cad6d..e3cf4b64 100644
--- a/src/flamingo/integrations/huggingface/trainer_config.py
+++ b/src/flamingo/integrations/huggingface/trainer_config.py
@@ -4,7 +4,11 @@


 class TrainerConfig(BaseFlamingoConfig):
-    """Configuration for a HuggingFace trainer/training arguments."""
+    """Configuration for a HuggingFace trainer/training arguments.
+
+    This mainly encompasses arguments passed to the HuggingFace `TrainingArguments` class,
+    but also contains some additional parameters for the `Trainer` or `SFTTrainer` classes.
+    """

     max_seq_length: int | None = None
     num_train_epochs: int | None = None
@@ -37,5 +41,5 @@ def get_training_args(self) -> dict[str, Any]:
             save_strategy=self.save_strategy,
             save_steps=self.save_steps,
         )
-        # Only return non-None values so we get HuggingFace defaults when not specified
+        # Only return non-None values so we use the HuggingFace defaults when not specified
         return {k: v for k, v in args.items() if v is not None}
diff --git a/src/flamingo/integrations/wandb/utils.py b/src/flamingo/integrations/wandb/utils.py
index 0c9299be..bbeeb0e5 100644
--- a/src/flamingo/integrations/wandb/utils.py
+++ b/src/flamingo/integrations/wandb/utils.py
@@ -1,47 +1,42 @@
+from pathlib import Path
 from typing import Any

 import wandb
 from wandb.apis.public import Run

-from flamingo.integrations.wandb import WandbEnvironment
-from flamingo.integrations.wandb.wandb_artifact_link import WandbArtifactLink
+from flamingo.integrations.wandb import WandbArtifactLink, WandbEnvironment


-def get_wandb_artifact(link: WandbArtifactLink):
+def get_wandb_run(env: WandbEnvironment) -> Run:
+    """Retrieve a run from the W&B API."""
     api = wandb.Api()
-    return api.artifact(link.artifact_path())
+    return api.run(env.wandb_path)


 def get_wandb_summary(env: WandbEnvironment) -> dict[str, Any]:
     """Get the summary dictionary attached to a W&B run."""
-    run = _resolve_wandb_run(env)
+    run = get_wandb_run(env)
     return dict(run.summary)


 def update_wandb_summary(env: WandbEnvironment, metrics: dict[str, Any]) -> None:
     """Update a run's summary with the provided metrics."""
-    run = _resolve_wandb_run(env)
+    run = get_wandb_run(env)
     run.summary.update(metrics)
     run.update()


-def _resolve_wandb_run(env: WandbEnvironment) -> Run:
-    """Resolve a WandB run object from the provided environment settings.
-
-    An exception is raised if a Run cannot be found,
-    or if multiple runs exist in scope with the same name.
-    """
+def get_wandb_artifact(link: WandbArtifactLink) -> wandb.Artifact:
+    """Retrieve an artifact from the W&B API."""
     api = wandb.Api()
-    base_path = "/".join(x for x in (env.entity, env.project) if x)
-    if env.run_id is not None:
-        full_path = f"{base_path}/{env.run_id}"
-        return api.run(full_path)
-    else:
-        match [run for run in api.runs(base_path) if run.name == env.name]:
-            case []:
-                raise RuntimeError(f"No WandB runs found at {base_path}/{env.name}")
-            case [Run(), _]:
-                raise RuntimeError(f"Multiple WandB runs found at {base_path}/{env.name}")
-            case [Run()] as mr:
-                # we have a single one, hurray
-                return mr[0]
+    return api.artifact(link.wandb_path)
+
+
+def get_artifact_filesystem_path(link: WandbArtifactLink) -> str:
+    # TODO: What if there are multiple folder paths in the artifact manifest?
+    artifact = get_wandb_artifact(link)
+    for entry in artifact.manifest.entries.values():
+        if entry.ref.startswith("file://"):
+            entry_path = Path(entry.ref.replace("file://", ""))
+            return str(entry_path.parent.absolute())
+    raise ValueError("Artifact does not contain reference to filesystem files.")
diff --git a/src/flamingo/integrations/wandb/wandb_artifact_link.py b/src/flamingo/integrations/wandb/wandb_artifact_link.py
index aa825ed9..d533dee9 100644
--- a/src/flamingo/integrations/wandb/wandb_artifact_link.py
+++ b/src/flamingo/integrations/wandb/wandb_artifact_link.py
@@ -9,7 +9,9 @@ class WandbArtifactLink(BaseFlamingoConfig):
     project: str | None = None
     entity: str | None = None

-    def artifact_path(self) -> str:
+    @property
+    def wandb_path(self) -> str:
+        """String identifier for retrieving the asset from the W&B platform."""
         path = "/".join(x for x in [self.entity, self.project, self.name] if x is not None)
         if self.alias:
             path = f"{path}:{self.alias}"
diff --git a/src/flamingo/integrations/wandb/wandb_environment.py b/src/flamingo/integrations/wandb/wandb_environment.py
index 5a425692..57790336 100644
--- a/src/flamingo/integrations/wandb/wandb_environment.py
+++ b/src/flamingo/integrations/wandb/wandb_environment.py
@@ -37,25 +37,13 @@ def warn_missing_api_key(cls, values):
             )
         return values

-    @validator("run_id", post=True, always=True)
+    @validator("run_id", always=True)
     def ensure_run_id(cls, run_id):
         if run_id is None:
             # Generates an 8-digit random hexadecimal string, analogous to W&B platform
             run_id = secrets.token_hex(nbytes=4)
         return run_id

-    @property
-    def env_vars(self) -> dict[str, str]:
-        env_vars = {
-            "WANDB_RUN_ID": self.run_id,
-            "WANDB_NAME": self.name,
-            "WANDB_PROJECT": self.project,
-            "WANDB_RUN_GROUP": self.run_group,
-            "WANDB_ENTITY": self.entity,
-            "WANDB_API_KEY": os.environ.get("WANDB_API_KEY", None),
-        }
-        return {k: v for k, v in env_vars.items() if v is not None}
-
     @classmethod
     def from_run(cls, run: Run) -> "WandbEnvironment":
         """Extract environment settings from a W&B Run object.
@@ -69,3 +57,21 @@ def from_run(cls, run: Run) -> "WandbEnvironment":
             entity=run.entity,
             run_id=run.id,
         )
+
+    @property
+    def env_vars(self) -> dict[str, str]:
+        env_vars = {
+            "WANDB_RUN_ID": self.run_id,
+            "WANDB_NAME": self.name,
+            "WANDB_PROJECT": self.project,
+            "WANDB_RUN_GROUP": self.run_group,
+            "WANDB_ENTITY": self.entity,
+            "WANDB_API_KEY": os.environ.get("WANDB_API_KEY", None),
+        }
+        return {k: v for k, v in env_vars.items() if v is not None}
+
+    @property
+    def wandb_path(self) -> str:
+        """String identifier for retrieving the asset from the W&B platform."""
+        path = "/".join(x for x in [self.entity, self.project, self.run_id] if x is not None)
+        return path
diff --git a/src/flamingo/jobs/__init__.py b/src/flamingo/jobs/__init__.py
index cb4fa167..5c003cd0 100644
--- a/src/flamingo/jobs/__init__.py
+++ b/src/flamingo/jobs/__init__.py
@@ -1,10 +1,9 @@
 from .finetuning_config import FinetuningJobConfig
-from .lm_harness_config import LMHarnessJobConfig, ModelNameOrCheckpointPath
+from .lm_harness_config import LMHarnessJobConfig
 from .simple_config import SimpleJobConfig

 __all__ = [
     "SimpleJobConfig",
     "FinetuningJobConfig",
     "LMHarnessJobConfig",
-    "ModelNameOrCheckpointPath",
 ]
diff --git a/src/flamingo/jobs/drivers/finetuning.py b/src/flamingo/jobs/drivers/finetuning.py
index 617525e0..149d5529 100644
--- a/src/flamingo/jobs/drivers/finetuning.py
+++ b/src/flamingo/jobs/drivers/finetuning.py
@@ -5,15 +5,24 @@
 from accelerate import Accelerator
 from datasets import DatasetDict
 from ray import train
-from ray.train import CheckpointConfig, RunConfig
+from ray.train import CheckpointConfig, RunConfig, ScalingConfig
 from ray.train.huggingface.transformers import RayTrainReportCallback, prepare_trainer
 from ray.train.torch import TorchTrainer
 from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, TrainingArguments
 from trl import SFTTrainer

+from flamingo.integrations.wandb import WandbArtifactLink
+from flamingo.integrations.wandb.utils import get_artifact_filesystem_path
 from flamingo.jobs import FinetuningJobConfig


+def resolve_file_path(name_or_artifact: str | WandbArtifactLink) -> str:
+    if isinstance(name_or_artifact, str):
+        return name_or_artifact
+    else:
+        return get_artifact_filesystem_path(name_or_artifact)
+
+
 def is_wandb_enabled(config: FinetuningJobConfig):
     # Only report to WandB on the rank 0 worker
     # Reference: https://docs.ray.io/en/latest/train/user-guides/experiment-tracking.html
@@ -22,7 +31,7 @@ def is_wandb_enabled(config: FinetuningJobConfig):

 def get_training_arguments(config: FinetuningJobConfig) -> TrainingArguments:
     """Get TrainingArguments appropriate for the worker rank and job config."""
-    training_args = config.trainer.get_training_args() if config.trainer else {}
+    provided_args = config.trainer.get_training_args()
     return TrainingArguments(
         output_dir="out",  # Local checkpoint path on a worker
         report_to="wandb" if is_wandb_enabled(config) else "none",
@@ -30,7 +39,7 @@ def get_training_arguments(config: FinetuningJobConfig) -> TrainingArguments:
         push_to_hub=False,
         disable_tqdm=True,
         logging_dir=None,
-        **training_args,
+        **provided_args,
     )


@@ -60,7 +69,7 @@ def get_model(config: FinetuningJobConfig) -> PreTrainedModel:


 def get_tokenizer(config: FinetuningJobConfig):
-    tokenizer_name = config.tokenizer.name if config.tokenizer else config.model.name
+    tokenizer_name = config.tokenizer.path if config.tokenizer else config.model.name
     tokenizer_args = config.tokenizer.get_tokenizer_args() if config.tokenizer else {}
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_args)
     if not tokenizer.pad_token_id:
@@ -92,8 +101,8 @@ def train_func(config_data: dict):
         model=model,
         tokenizer=tokenizer,
         args=training_args,
-        peft_config=config.lora,
-        max_seq_length=config.trainer.max_seq_length if config.trainer else None,
+        peft_config=config.adapter,
+        max_seq_length=config.trainer.max_seq_length,
         train_dataset=datasets["train"],
         eval_dataset=datasets["test"],
         dataset_text_field="text",
@@ -110,10 +119,7 @@ def train_func(config_data: dict):

 def run_finetuning(config: FinetuningJobConfig):
     print(f"Received job configuration: {config}")

-    if config.tracking:
-        # Ensure the run_id is set so that the W&B run can be initialized deterministically
-        config.tracking.ensure_run_id()
-
+    scaling_config = ScalingConfig(**config.ray.get_scaling_args())
     run_config = RunConfig(
         name=config.tracking.name if config.tracking else None,
         storage_path=config.ray.storage_path,
@@ -122,7 +128,7 @@ def run_finetuning(config: FinetuningJobConfig):
     trainer = TorchTrainer(
         train_loop_per_worker=train_func,
         train_loop_config=json.loads(config.json()),
-        scaling_config=config.ray.get_scaling_config(),
+        scaling_config=scaling_config,
         run_config=run_config,
     )
     result = trainer.fit()
diff --git a/src/flamingo/jobs/finetuning_config.py b/src/flamingo/jobs/finetuning_config.py
index f850dfbf..5c7fec15 100644
--- a/src/flamingo/jobs/finetuning_config.py
+++ b/src/flamingo/jobs/finetuning_config.py
@@ -1,10 +1,13 @@
-from pydantic import Field
-from ray.train import ScalingConfig
+from typing import Any
+
+from peft import LoraConfig
+from pydantic import Field, validator

 from flamingo.integrations.huggingface import (
     AutoModelConfig,
     AutoTokenizerConfig,
     DatasetConfig,
+    QuantizationConfig,
     TrainerConfig,
 )
 from flamingo.integrations.wandb import WandbEnvironment
@@ -21,8 +24,9 @@ class RayTrainConfig(BaseFlamingoConfig):
     num_workers: int | None = None
     storage_path: str | None = None

-    def get_scaling_config(self) -> ScalingConfig:
-        return ScalingConfig(use_gpu=self.use_gpu, num_workers=self.num_workers)
+    def get_scaling_args(self) -> dict[str, Any]:
+        args = dict(use_gpu=self.use_gpu, num_workers=self.num_workers)
+        return {k: v for k, v in args.items() if v is not None}


 class FinetuningJobConfig(BaseFlamingoConfig):
@@ -31,6 +35,29 @@ class FinetuningJobConfig(BaseFlamingoConfig):
     model: AutoModelConfig
     dataset: DatasetConfig
     tokenizer: AutoTokenizerConfig | None = None
+    quantization: QuantizationConfig | None = None
+    adapter: LoraConfig | None = None  # TODO: Create own dataclass here
     tracking: WandbEnvironment | None = None
     trainer: TrainerConfig = Field(default_factory=TrainerConfig)
     ray: RayTrainConfig = Field(default_factory=RayTrainConfig)
+
+    @validator("model", pre=True, always=True)
+    def validate_model_arg(cls, x):
+        """Allow for passing just a path string as the model argument."""
+        if isinstance(x, str):
+            return AutoModelConfig(path=x)
+        return x
+
+    @validator("dataset", pre=True, always=True)
+    def validate_dataset_arg(cls, x):
+        """Allow for passing just a path string as the dataset argument."""
+        if isinstance(x, str):
+            return DatasetConfig(path=x)
+        return x
+
+    @validator("tokenizer", pre=True, always=True)
+    def validate_tokenizer_arg(cls, x):
+        """Allow for passing just a path string as the tokenizer argument."""
+        if isinstance(x, str):
+            return AutoTokenizerConfig(path=x)
+        return x
diff --git a/src/flamingo/jobs/lm_harness_config.py b/src/flamingo/jobs/lm_harness_config.py
index 7de4c9be..d3f69901 100644
--- a/src/flamingo/jobs/lm_harness_config.py
+++ b/src/flamingo/jobs/lm_harness_config.py
@@ -2,7 +2,7 @@

 from pydantic import Field

-from flamingo.integrations.huggingface import AutoModelConfig
+from flamingo.integrations.huggingface import AutoModelConfig, QuantizationConfig
 from flamingo.integrations.wandb import WandbEnvironment
 from flamingo.types import BaseFlamingoConfig

@@ -29,5 +29,6 @@ class LMHarnessJobConfig(BaseFlamingoConfig):

     model: AutoModelConfig
     evaluator: LMHarnessEvaluatorSettings
+    quantization: QuantizationConfig | None = None
     tracking: WandbEnvironment | None = None
     ray: RayComputeSettings = Field(default_factory=RayComputeSettings)
diff --git a/tests/configs/__init__.py b/tests/jobs/__init__.py
similarity index 100%
rename from tests/configs/__init__.py
rename to tests/jobs/__init__.py
diff --git a/tests/configs/test_finetuning_config.py b/tests/jobs/test_finetuning_config.py
similarity index 100%
rename from tests/configs/test_finetuning_config.py
rename to tests/jobs/test_finetuning_config.py
diff --git a/tests/configs/test_lm_harness_config.py b/tests/jobs/test_lm_harness_config.py
similarity index 100%
rename from tests/configs/test_lm_harness_config.py
rename to tests/jobs/test_lm_harness_config.py