copy over artifact loading logic
Sean Friedowitz committed Jan 17, 2024
1 parent 31a905d commit 8206b31
Showing 4 changed files with 109 additions and 62 deletions.
10 changes: 0 additions & 10 deletions src/flamingo/integrations/huggingface/tokenizer_config.py
@@ -1,5 +1,3 @@
from typing import Any

from pydantic import validator

from flamingo.integrations.huggingface.utils import repo_id_validator
@@ -15,11 +13,3 @@ class AutoTokenizerConfig(BaseFlamingoConfig):
use_fast: bool | None = None

_path_validator = validator("path", allow_reuse=True, pre=True)(repo_id_validator)

def get_tokenizer_args(self) -> dict[str, Any]:
args = dict(
trust_remote_code=self.trust_remote_code,
use_fast=self.use_fast,
)
# Only return non-None values so we get HuggingFace defaults when not specified
return {k: v for k, v in args.items() if v is not None}
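
The deleted helper shows a pattern worth keeping: pass only explicitly-set options so the library's own defaults apply otherwise. A minimal standalone sketch of the same idea (the helper name and gpt2 checkpoint are illustrative, not part of this commit):

from transformers import AutoTokenizer

def non_null_kwargs(**kwargs) -> dict:
    # Drop None values so AutoTokenizer falls back to its own defaults
    return {k: v for k, v in kwargs.items() if v is not None}

# use_fast is left unset, so HuggingFace's default (use_fast=True) applies
tokenizer = AutoTokenizer.from_pretrained("gpt2", **non_null_kwargs(trust_remote_code=None))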
18 changes: 18 additions & 0 deletions src/flamingo/integrations/wandb/artifact_config.py
@@ -1,3 +1,5 @@
import wandb

from flamingo.types import BaseFlamingoConfig


@@ -15,3 +17,19 @@ def wandb_path(self) -> str:
path = "/".join(x for x in [self.entity, self.project, self.name] if x is not None)
path = f"{path}:{self.version}"
return path
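
For concreteness, wandb_path renders W&B's standard "entity/project/name:version" reference: entity="my-team", project="lm-jobs", name="my-dataset", version="v3" yields "my-team/lm-jobs/my-dataset:v3", while entity=None degrades gracefully to "lm-jobs/my-dataset:v3" (example values are illustrative).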


class WandbArtifactLoader:
"""Helper class for loading W&B artifacts and linking them to runs."""

def __init__(self, run: wandb.sdk.wandb_run.Run | None = None):
self._run = run

def load_artifact(self, link: WandbArtifactConfig) -> wandb.Artifact:
if self._run is not None:
# Retrieves the artifact and links it as an input to the run
return self._run.use_artifact(link.wandb_path)
else:
# Retrieves the artifact outside of the run
api = wandb.Api()
return api.artifact(link.wandb_path)
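
A usage sketch for the new loader, assuming the pydantic fields shown above (entity/project/name/version) and an artifact that already exists in W&B; all names are illustrative:

import wandb

config = WandbArtifactConfig(name="my-dataset", project="lm-jobs", version="v3")

# Outside a run: fetched through the public API, no lineage recorded
artifact = WandbArtifactLoader().load_artifact(config)

# Inside a run: use_artifact() additionally records the artifact as a run input
with wandb.init(project="lm-jobs") as run:
    artifact = WandbArtifactLoader(run).load_artifact(config)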
10 changes: 10 additions & 0 deletions src/flamingo/integrations/wandb/utils.py
@@ -1,3 +1,4 @@
from pathlib import Path
from typing import Any

import wandb
@@ -23,3 +24,12 @@ def update_wandb_summary(run_config: WandbRunConfig, metrics: dict[str, Any]) ->
run = get_wandb_api_run(run_config)
run.summary.update(metrics)
run.update()


def get_reference_filesystem_path(artifact: wandb.Artifact) -> str:
for entry in artifact.manifest.entries.values():
if entry.ref and entry.ref.startswith("file://"):
# TODO: What if there are entries with different base paths in the artifact manifest?
entry_path = Path(entry.ref.replace("file://", ""))
return str(entry_path.parent.absolute())
raise ValueError("Artifact does not contain a filesystem reference.")
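
For this helper to succeed, the artifact's manifest must carry file:// references, which is what W&B records when a reference artifact is built from a local path. A sketch of producing such an artifact (project, names, and paths are illustrative):

import wandb

with wandb.init(project="lm-jobs") as run:
    artifact = wandb.Artifact("my-dataset", type="dataset")
    # Stored in the manifest as refs like "file:///data/my-dataset/train.jsonl"
    artifact.add_reference("file:///data/my-dataset")
    run.log_artifact(artifact)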
133 changes: 81 additions & 52 deletions src/flamingo/jobs/drivers/finetuning.py
@@ -5,78 +5,95 @@
from accelerate import Accelerator
from datasets import DatasetDict
from ray import train
from ray.train import CheckpointConfig, RunConfig
from ray.train import CheckpointConfig, RunConfig, ScalingConfig
from ray.train.huggingface.transformers import RayTrainReportCallback, prepare_trainer
from ray.train.torch import TorchTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, TrainingArguments
from trl import SFTTrainer

from flamingo.integrations.wandb import update_wandb_summary
from flamingo.integrations.huggingface.utils import load_and_split_dataset
from flamingo.integrations.wandb import WandbArtifactConfig, WandbArtifactLoader
from flamingo.integrations.wandb.utils import get_reference_filesystem_path
from flamingo.jobs import FinetuningJobConfig


def is_wandb_enabled(config: FinetuningJobConfig):
def is_tracking_enabled(config: FinetuningJobConfig):
# Only report to WandB on the rank 0 worker
# Reference: https://docs.ray.io/en/latest/train/user-guides/experiment-tracking.html
return config.wandb_env and train.get_context().get_world_rank() == 0
return config.tracking is not None and train.get_context().get_world_rank() == 0


def get_training_args(config: FinetuningJobConfig) -> TrainingArguments:
def resolve_artifact_path(path: str | WandbArtifactConfig, loader: WandbArtifactLoader) -> str:
"""Resolve the actual filesystem path for a path/artifact asset.
The artifact loader internally handles linking the artifact-to-load to an in-progress run.
"""
match path:
case str():
return path
case WandbArtifactConfig() as artifact_config:
artifact = loader.load_artifact(artifact_config)
return get_reference_filesystem_path(artifact)
case _:
raise ValueError(f"Invalid artifact path: {path}")
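
A minimal sketch of the two branches this helper handles (illustrative values, not part of this diff):

loader = WandbArtifactLoader()  # no active run

# Plain string paths pass straight through
resolve_artifact_path("/data/my-dataset", loader)  # -> "/data/my-dataset"

# Artifact configs are fetched from W&B and mapped back to their file:// location
config = WandbArtifactConfig(name="my-dataset", project="lm-jobs", version="v3")
resolve_artifact_path(config, loader)  # -> local directory backing the artifact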


def get_training_arguments(config: FinetuningJobConfig) -> TrainingArguments:
"""Get TrainingArguments appropriate for the worker rank and job config."""
provided_args = config.trainer.get_training_args()
return TrainingArguments(
output_dir="out", # Local checkpoint path on a worker
num_train_epochs=config.num_train_epochs,
learning_rate=config.learning_rate,
per_device_train_batch_size=config.batch_size,
per_device_eval_batch_size=config.batch_size,
gradient_accumulation_steps=config.gradient_accumulation_steps,
gradient_checkpointing=config.gradient_checkpointing,
weight_decay=config.weight_decay,
evaluation_strategy=config.evaluation_strategy,
eval_steps=config.eval_steps,
logging_strategy=config.logging_strategy,
logging_steps=config.logging_steps,
save_strategy=config.save_strategy,
save_steps=config.save_steps,
run_name=config.wandb_name,
report_to="wandb" if is_wandb_enabled(config) else "none",
no_cuda=not config.scaling_config.use_gpu,
report_to="wandb" if is_tracking_enabled(config) else "none",
no_cuda=not config.scaling.use_gpu,
push_to_hub=False,
disable_tqdm=True,
logging_dir=None,
**provided_args,
)
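
One subtlety above: TrainingArguments raises a TypeError if provided_args repeats any of the fixed keyword arguments, so config.trainer.get_training_args() presumably never emits keys like output_dir or report_to. A defensive alternative (not this commit's code) merges dicts so user-provided values win:

fixed_args = {"output_dir": "out", "push_to_hub": False, "disable_tqdm": True, "logging_dir": None}
training_args = TrainingArguments(**{**fixed_args, **provided_args})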


def get_datasets(config: FinetuningJobConfig) -> DatasetDict:
# TODO: Refactor me somehow
...
def load_datasets(config: FinetuningJobConfig, loader: WandbArtifactLoader) -> DatasetDict:
dataset_path = resolve_artifact_path(config.dataset.path, loader)
# We need to specify a fixed seed to load the datasets on each worker
# Under the hood, HuggingFace uses `accelerate` to create a data loader shard for each worker
# If the datasets are not seeded here, the ordering will be inconsistent between workers
# TODO: Get rid of this logic once data loading occurs once outside of the workers
split_seed = config.dataset.seed or 0
return load_and_split_dataset(
dataset_path,
split=config.dataset.split,
test_size=config.dataset.test_size,
seed=split_seed,
)
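
load_and_split_dataset comes from flamingo.integrations.huggingface.utils and is not shown in this diff; given the call site, a plausible shape is the following (an assumption, not the actual implementation):

from datasets import DatasetDict, load_dataset

def load_and_split_dataset(path: str, split: str | None = None, test_size: float | None = None, seed: int = 0) -> DatasetDict:
    dataset = load_dataset(path, split=split or "train")
    if test_size:
        # A fixed seed keeps the train/test split identical on every Ray worker
        return dataset.train_test_split(test_size=test_size, seed=seed)
    return DatasetDict({"train": dataset})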


def get_model(config: FinetuningJobConfig) -> PreTrainedModel:
def load_model(config: FinetuningJobConfig, loader: WandbArtifactLoader) -> PreTrainedModel:
device_map, bnb_config = None, None
if config.quantization_config:
bnb_config = config.quantization_config.as_huggingface()
if config.quantization is not None:
bnb_config = config.quantization.as_huggingface()
# When quantization is enabled, model must all be on same GPU to work with DDP
# If a device_map is not specified we will get accelerate errors downstream
# Reference: https://github.com/huggingface/accelerate/issues/1840#issuecomment-1683105994
current_device = Accelerator().local_process_index if torch.cuda.is_available() else "cpu"
device_map = {"": current_device}
print(f"Setting model device_map = {device_map} to enable quantization")

model_path = resolve_artifact_path(config.model.path, loader)
return AutoModelForCausalLM.from_pretrained(
config.model,
trust_remote_code=config.trust_remote_code,
torch_dtype=config.torch_dtype,
pretrained_model_name_or_path=model_path,
trust_remote_code=config.model.trust_remote_code,
torch_dtype=config.model.torch_dtype,
quantization_config=bnb_config,
device_map=device_map,
)
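
A note on the device_map above: in HuggingFace's device_map convention, the empty-string key assigns the entire model to one device, so each DDP rank pins the whole quantized model to its own GPU, e.g.:

device_map = {"": 0}  # whole model on cuda:0 for the rank with local_process_index 0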


def get_tokenizer(config: FinetuningJobConfig):
def load_tokenizer(config: FinetuningJobConfig, loader: WandbArtifactLoader):
tokenizer_path = resolve_artifact_path(config.tokenizer.path, loader)
tokenizer = AutoTokenizer.from_pretrained(
config.tokenizer or config.model,
trust_remote_code=config.trust_remote_code,
use_fast=True,
pretrained_model_name_or_path=tokenizer_path,
trust_remote_code=config.tokenizer.trust_remote_code,
use_fast=config.tokenizer.use_fast,
)
if not tokenizer.pad_token_id:
# Pad token required for generating consistent batch sizes
@@ -86,51 +103,63 @@ def get_tokenizer(config: FinetuningJobConfig):

def train_func(config_data: dict):
config = FinetuningJobConfig(**config_data)
model = get_model(config)
tokenizer = get_tokenizer(config)

datasets = get_datasets(config)
training_args = get_training_args(config)
training_args = get_training_arguments(config)

# Manually initialize run in order to set the run ID and link artifacts
wandb_run = None
if is_tracking_enabled(config):
env = config.tracking
wandb_run = wandb.init(
id=env.run_id,
name=env.name,
project=env.project,
entity=env.entity,
group=env.run_group,
)

# Load the input artifacts, potentially linking them to the active W&B run
artifact_loader = WandbArtifactLoader(wandb_run)
datasets = load_datasets(config, artifact_loader)
model = load_model(config, artifact_loader)
tokenizer = load_tokenizer(config, artifact_loader)

trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=config.lora_config,
max_seq_length=config.max_seq_length,
peft_config=config.adapter,
max_seq_length=config.trainer.max_seq_length,
train_dataset=datasets["train"],
eval_dataset=datasets["test"],
eval_dataset=datasets.get("test"),
dataset_text_field="text",
)
trainer.add_callback(RayTrainReportCallback())
trainer = prepare_trainer(trainer)
trainer.train()

# Force WandB finish on rank 0 worker
if is_wandb_enabled(config):
if is_tracking_enabled(config):
wandb.finish()


def run_finetuning(config: FinetuningJobConfig):
print(f"Received job configuration: {config}")

scaling_config = ScalingConfig(**config.ray.get_scaling_args())
run_config = RunConfig(
name=config.wandb_name,
storage_path=config.storage_path,
name=config.tracking.name if config.tracking else None,
storage_path=config.ray.storage_path,
checkpoint_config=CheckpointConfig(num_to_keep=1),
)
trainer = TorchTrainer(
train_loop_per_worker=train_func,
train_loop_config=json.loads(config.json()),
scaling_config=config.scaling_config,
scaling_config=scaling_config,
run_config=run_config,
)
result = trainer.fit()
print(f"Training result: {result}")

# Log additional training metrics to completed WandB run
if config.wandb_env:
result_paths = {"ray/result_path": result.path}
if result.checkpoint:
result_paths["ray/checkpoint_path"] = f"{result.checkpoint.path}/checkpoint"
update_wandb_summary(config.wandb_env, result_paths)
if config.tracking:
# TODO: Add ref artifact here
pass
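
A possible shape for that TODO, assuming the intent is to resume the tracking run and attach the Ray checkpoint as a W&B reference artifact; the wandb calls are standard, but the artifact name, type, and resume settings are assumptions:

if config.tracking and result.checkpoint:
    checkpoint_path = f"{result.checkpoint.path}/checkpoint"
    with wandb.init(id=config.tracking.run_id, resume="allow") as run:
        artifact = wandb.Artifact(f"{config.tracking.name}-checkpoint", type="model")
        artifact.add_reference(f"file://{checkpoint_path}")
        run.log_artifact(artifact)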
