From 8206b31fefcec6df008f3fd0b5fabc9169874b88 Mon Sep 17 00:00:00 2001
From: Sean Friedowitz
Date: Wed, 17 Jan 2024 13:00:35 -0800
Subject: [PATCH] copy over artifact loading logic

---
 .../huggingface/tokenizer_config.py       |  10 --
 .../integrations/wandb/artifact_config.py |  18 +++
 src/flamingo/integrations/wandb/utils.py  |  10 ++
 src/flamingo/jobs/drivers/finetuning.py   | 133 +++++++++++-------
 4 files changed, 109 insertions(+), 62 deletions(-)

diff --git a/src/flamingo/integrations/huggingface/tokenizer_config.py b/src/flamingo/integrations/huggingface/tokenizer_config.py
index 750a4d94..2f8fe70c 100644
--- a/src/flamingo/integrations/huggingface/tokenizer_config.py
+++ b/src/flamingo/integrations/huggingface/tokenizer_config.py
@@ -1,5 +1,3 @@
-from typing import Any
-
 from pydantic import validator
 
 from flamingo.integrations.huggingface.utils import repo_id_validator
@@ -15,11 +13,3 @@ class AutoTokenizerConfig(BaseFlamingoConfig):
     use_fast: bool | None = None
 
     _path_validator = validator("path", allow_reuse=True, pre=True)(repo_id_validator)
-
-    def get_tokenizer_args(self) -> dict[str, Any]:
-        args = dict(
-            trust_remote_code=self.trust_remote_code,
-            use_fast=self.use_fast,
-        )
-        # Only return non-None values so we get HuggingFace defaults when not specified
-        return {k: v for k, v in args.items() if v is not None}
diff --git a/src/flamingo/integrations/wandb/artifact_config.py b/src/flamingo/integrations/wandb/artifact_config.py
index 4f445330..889870db 100644
--- a/src/flamingo/integrations/wandb/artifact_config.py
+++ b/src/flamingo/integrations/wandb/artifact_config.py
@@ -1,3 +1,5 @@
+import wandb
+
 from flamingo.types import BaseFlamingoConfig
 
 
@@ -15,3 +17,19 @@ def wandb_path(self) -> str:
         path = "/".join(x for x in [self.entity, self.project, self.name] if x is not None)
         path = f"{path}:{self.version}"
         return path
+
+
+class WandbArtifactLoader:
+    """Helper class for loading W&B artifacts and linking them to runs."""
+
+    def __init__(self, run: wandb.run | None = None):
+        self._run = run
+
+    def load_artifact(self, link: WandbArtifactConfig) -> wandb.Artifact:
+        if self._run is not None:
+            # Retrieves the artifact and links it as an input to the run
+            return self._run.use_artifact(link.wandb_path)
+        else:
+            # Retrieves the artifact outside of the run
+            api = wandb.Api()
+            return api.artifact(link.wandb_path)
diff --git a/src/flamingo/integrations/wandb/utils.py b/src/flamingo/integrations/wandb/utils.py
index f4303486..cb4520f3 100644
--- a/src/flamingo/integrations/wandb/utils.py
+++ b/src/flamingo/integrations/wandb/utils.py
@@ -1,3 +1,4 @@
+from pathlib import Path
 from typing import Any
 
 import wandb
@@ -23,3 +24,12 @@ def update_wandb_summary(run_config: WandbRunConfig, metrics: dict[str, Any]) ->
     run = get_wandb_api_run(run_config)
     run.summary.update(metrics)
     run.update()
+
+
+def get_reference_filesystem_path(artifact: wandb.Artifact) -> str:
+    for entry in artifact.manifest.entries.values():
+        if entry.ref.startswith("file://"):
+            # TODO: What if there are entries with different base paths in the artifact manifest?
+            entry_path = Path(entry.ref.replace("file://", ""))
+            return str(entry_path.parent.absolute())
+    raise ValueError("Artifact does not contain a filesystem reference.")
diff --git a/src/flamingo/jobs/drivers/finetuning.py b/src/flamingo/jobs/drivers/finetuning.py
index 7475804e..300a02ae 100644
--- a/src/flamingo/jobs/drivers/finetuning.py
+++ b/src/flamingo/jobs/drivers/finetuning.py
@@ -5,57 +5,72 @@
 from accelerate import Accelerator
 from datasets import DatasetDict
 from ray import train
-from ray.train import CheckpointConfig, RunConfig
+from ray.train import CheckpointConfig, RunConfig, ScalingConfig
 from ray.train.huggingface.transformers import RayTrainReportCallback, prepare_trainer
 from ray.train.torch import TorchTrainer
 from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, TrainingArguments
 from trl import SFTTrainer
 
-from flamingo.integrations.wandb import update_wandb_summary
+from flamingo.integrations.huggingface.utils import load_and_split_dataset
+from flamingo.integrations.wandb import WandbArtifactConfig, WandbArtifactLoader
+from flamingo.integrations.wandb.utils import get_reference_filesystem_path
 from flamingo.jobs import FinetuningJobConfig
 
 
-def is_wandb_enabled(config: FinetuningJobConfig):
+def is_tracking_enabled(config: FinetuningJobConfig):
     # Only report to WandB on the rank 0 worker
     # Reference: https://docs.ray.io/en/latest/train/user-guides/experiment-tracking.html
-    return config.wandb_env and train.get_context().get_world_rank() == 0
+    return config.tracking is not None and train.get_context().get_world_rank() == 0
 
 
-def get_training_args(config: FinetuningJobConfig) -> TrainingArguments:
+def resolve_artifact_path(path: str | WandbArtifactConfig, loader: WandbArtifactLoader) -> str:
+    """Resolve the actual filesystem path for a path/artifact asset.
+
+    The artifact loader internally handles linking the artifact-to-load to an in-progress run.
+    """
+    match path:
+        case str():
+            return path
+        case WandbArtifactConfig() as artifact_config:
+            artifact = loader.load_artifact(artifact_config)
+            return get_reference_filesystem_path(artifact)
+        case _:
+            raise ValueError(f"Invalid artifact path: {path}")
+
+
+def get_training_arguments(config: FinetuningJobConfig) -> TrainingArguments:
     """Get TrainingArguments appropriate for the worker rank and job config."""
+    provided_args = config.trainer.get_training_args()
     return TrainingArguments(
         output_dir="out",  # Local checkpoint path on a worker
-        num_train_epochs=config.num_train_epochs,
-        learning_rate=config.learning_rate,
-        per_device_train_batch_size=config.batch_size,
-        per_device_eval_batch_size=config.batch_size,
-        gradient_accumulation_steps=config.gradient_accumulation_steps,
-        gradient_checkpointing=config.gradient_checkpointing,
-        weight_decay=config.weight_decay,
-        evaluation_strategy=config.evaluation_strategy,
-        eval_steps=config.eval_steps,
-        logging_strategy=config.logging_strategy,
-        logging_steps=config.logging_steps,
-        save_strategy=config.save_strategy,
-        save_steps=config.save_steps,
-        run_name=config.wandb_name,
-        report_to="wandb" if is_wandb_enabled(config) else "none",
-        no_cuda=not config.scaling_config.use_gpu,
+        report_to="wandb" if is_tracking_enabled(config) else "none",
+        no_cuda=not config.scaling.use_gpu,
         push_to_hub=False,
         disable_tqdm=True,
         logging_dir=None,
+        **provided_args,
     )
 
 
-def get_datasets(config: FinetuningJobConfig) -> DatasetDict:
-    # TODO: Refactor me somehow
-    ...
+def load_datasets(config: FinetuningJobConfig, loader: WandbArtifactLoader) -> DatasetDict:
+    dataset_path = resolve_artifact_path(config.dataset.path, loader)
+    # We need to specify a fixed seed to load the datasets on each worker
+    # Under the hood, HuggingFace uses `accelerate` to create a data loader shard for each worker
+    # If the datasets are not seeded here, the ordering will be inconsistent between workers
+    # TODO: Get rid of this logic once data loading occurs once outside of the workers
+    split_seed = config.dataset.seed or 0
+    return load_and_split_dataset(
+        dataset_path,
+        split=config.dataset.split,
+        test_size=config.dataset.test_size,
+        seed=split_seed,
+    )
 
 
-def get_model(config: FinetuningJobConfig) -> PreTrainedModel:
+def load_model(config: FinetuningJobConfig, loader: WandbArtifactLoader) -> PreTrainedModel:
     device_map, bnb_config = None, None
-    if config.quantization_config:
-        bnb_config = config.quantization_config.as_huggingface()
+    if config.quantization is not None:
+        bnb_config = config.quantization.as_huggingface()
         # When quantization is enabled, model must all be on same GPU to work with DDP
         # If a device_map is not specified we will get accelerate errors downstream
         # Reference: https://github.com/huggingface/accelerate/issues/1840#issuecomment-1683105994
@@ -63,20 +78,22 @@
         device_map = {"": current_device}
         print(f"Setting model device_map = {device_map} to enable quantization")
 
+    model_path = resolve_artifact_path(config.model.path, loader)
     return AutoModelForCausalLM.from_pretrained(
-        config.model,
-        trust_remote_code=config.trust_remote_code,
-        torch_dtype=config.torch_dtype,
+        pretrained_model_name_or_path=model_path,
+        trust_remote_code=config.model.trust_remote_code,
+        torch_dtype=config.model.torch_dtype,
         quantization_config=bnb_config,
         device_map=device_map,
     )
 
 
-def get_tokenizer(config: FinetuningJobConfig):
+def load_tokenizer(config: FinetuningJobConfig, loader: WandbArtifactLoader):
+    tokenizer_path = resolve_artifact_path(config.tokenizer.path, loader)
     tokenizer = AutoTokenizer.from_pretrained(
-        config.tokenizer or config.model,
-        trust_remote_code=config.trust_remote_code,
-        use_fast=True,
+        pretrained_model_name_or_path=tokenizer_path,
+        trust_remote_code=config.tokenizer.trust_remote_code,
+        use_fast=config.tokenizer.use_fast,
     )
     if not tokenizer.pad_token_id:
         # Pad token required for generating consistent batch sizes
@@ -86,20 +103,34 @@ def train_func(config_data: dict):
     config = FinetuningJobConfig(**config_data)
-    model = get_model(config)
-    tokenizer = get_tokenizer(config)
-
-    datasets = get_datasets(config)
-    training_args = get_training_args(config)
+    training_args = get_training_arguments(config)
+
+    # Manually initialize run in order to set the run ID and link artifacts
+    wandb_run = None
+    if is_tracking_enabled(config):
+        env = config.tracking
+        wandb_run = wandb.init(
+            id=env.run_id,
+            name=env.name,
+            project=env.project,
+            entity=env.entity,
+            group=env.run_group,
+        )
+
+    # Load the input artifacts, potentially linking them to the active W&B run
+    artifact_loader = WandbArtifactLoader(wandb_run)
+    datasets = load_datasets(config, artifact_loader)
+    model = load_model(config, artifact_loader)
+    tokenizer = load_tokenizer(config, artifact_loader)
 
     trainer = SFTTrainer(
         model=model,
         tokenizer=tokenizer,
         args=training_args,
-        peft_config=config.lora_config,
-        max_seq_length=config.max_seq_length,
+        peft_config=config.adapter,
+        max_seq_length=config.trainer.max_seq_length,
         train_dataset=datasets["train"],
-        eval_dataset=datasets["test"],
+        eval_dataset=datasets.get("test"),
         dataset_text_field="text",
     )
     trainer.add_callback(RayTrainReportCallback())
@@ -107,30 +138,28 @@
     trainer.train()
 
     # Force WandB finish on rank 0 worker
-    if is_wandb_enabled(config):
+    if is_tracking_enabled(config):
         wandb.finish()
 
 
 def run_finetuning(config: FinetuningJobConfig):
     print(f"Received job configuration: {config}")
 
+    scaling_config = ScalingConfig(**config.ray.get_scaling_args())
     run_config = RunConfig(
-        name=config.wandb_name,
-        storage_path=config.storage_path,
+        name=config.tracking.name if config.tracking else None,
+        storage_path=config.ray.storage_path,
         checkpoint_config=CheckpointConfig(num_to_keep=1),
     )
     trainer = TorchTrainer(
         train_loop_per_worker=train_func,
         train_loop_config=json.loads(config.json()),
-        scaling_config=config.scaling_config,
+        scaling_config=scaling_config,
         run_config=run_config,
     )
     result = trainer.fit()
    print(f"Training result: {result}")
 
-    # Log additional training metrics to completed WandB run
-    if config.wandb_env:
-        result_paths = {"ray/result_path": result.path}
-        if result.checkpoint:
-            result_paths["ray/checkpoint_path"] = f"{result.checkpoint.path}/checkpoint"
-        update_wandb_summary(config.wandb_env, result_paths)
+    if config.tracking:
+        # TODO: Add ref artifact here
+        pass
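
The usage sketch below is not part of the patch; it illustrates how the new resolve_artifact_path helper is meant to behave, based on the code above. A plain string path is returned unchanged, while a WandbArtifactConfig is fetched through WandbArtifactLoader and resolved to the local directory backing the artifact's file:// references. The entity/project/name/version values are placeholders.

from flamingo.integrations.wandb import WandbArtifactConfig, WandbArtifactLoader
from flamingo.jobs.drivers.finetuning import resolve_artifact_path

# A plain filesystem or repo path passes straight through
loader = WandbArtifactLoader()  # no active run, so artifacts are fetched via wandb.Api()
print(resolve_artifact_path("/data/datasets/my_dataset", loader))

# An artifact reference is loaded from W&B and resolved to a local directory
artifact_config = WandbArtifactConfig(
    name="my-dataset",  # placeholder values
    project="my-project",
    entity="my-team",
    version="v0",
)
print(resolve_artifact_path(artifact_config, loader))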
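
For context, get_reference_filesystem_path only applies to reference-style artifacts, i.e. artifacts whose manifest entries point at file:// URIs rather than uploaded files. A minimal sketch (also not part of the patch, with placeholder names and paths) of producing such an artifact and resolving it back to a directory:

import wandb

from flamingo.integrations.wandb.utils import get_reference_filesystem_path

# Log an artifact whose entries reference local files instead of uploading them
with wandb.init(project="my-project") as run:
    artifact = wandb.Artifact("my-dataset", type="dataset")
    artifact.add_reference("file:///data/datasets/my_dataset")
    run.log_artifact(artifact)

# The helper scans artifact.manifest.entries and returns the parent directory of
# the first file:// entry, assuming the referenced files sit directly under it
artifact = wandb.Api().artifact("my-team/my-project/my-dataset:latest")
print(get_reference_filesystem_path(artifact))  # -> /data/datasets/my_dataset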