Commit
adding all new configs and such
Sean Friedowitz committed Jan 15, 2024
1 parent 2f2306f commit 169907a
Showing 17 changed files with 173 additions and 135 deletions.
8 changes: 6 additions & 2 deletions src/flamingo/integrations/huggingface/__init__.py
@@ -1,9 +1,13 @@
-from .model_name_or_path import ModelNameOrCheckpointPath
+from .dataset_config import DatasetConfig
+from .model_config import AutoModelConfig
 from .quantization_config import QuantizationConfig
+from .tokenizer_config import AutoTokenizerConfig
 from .trainer_config import TrainerConfig

 __all__ = [
-    "ModelNameOrCheckpointPath",
+    "AutoModelConfig",
+    "AutoTokenizerConfig",
+    "DatasetConfig",
     "QuantizationConfig",
     "TrainerConfig",
 ]
10 changes: 10 additions & 0 deletions src/flamingo/integrations/huggingface/dataset_config.py
@@ -0,0 +1,10 @@
from flamingo.integrations.wandb import WandbArtifactLink
from flamingo.types import BaseFlamingoConfig


class DatasetConfig(BaseFlamingoConfig):
    """Settings passed to load a HuggingFace dataset."""

    artifact: str | WandbArtifactLink
    split_size: float | None = None
    seed: int | None = None
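For reference, a minimal usage sketch of the new config (assuming the flamingo package is importable; the dataset name, split, and seed below are illustrative placeholders, not values from this commit):

from flamingo.integrations.huggingface import DatasetConfig

# "imdb" is a hypothetical plain-string artifact; a WandbArtifactLink is also accepted
dataset_config = DatasetConfig(artifact="imdb", split_size=0.2, seed=42)
print(dataset_config.json())  # BaseFlamingoConfig is pydantic-based, so .json() serializes it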
19 changes: 19 additions & 0 deletions src/flamingo/integrations/huggingface/model_config.py
@@ -0,0 +1,19 @@
from pydantic import validator

from flamingo.integrations.huggingface.utils import is_valid_huggingface_repo_id
from flamingo.integrations.wandb import WandbArtifactLink
from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype


class AutoModelConfig(BaseFlamingoConfig):
    """Settings passed to a HuggingFace AutoModel instantiation."""

    artifact: str | WandbArtifactLink
    trust_remote_code: bool = False
    torch_dtype: SerializableTorchDtype = None

    @validator("artifact", pre=True, always=True)
    def _validate_model_name(cls, v):
        if isinstance(v, str) and not is_valid_huggingface_repo_id(v):
            raise ValueError(f"{v} is not a valid HuggingFace model name.")
        return v
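A sketch of the validator's behavior (repo-id strings are illustrative; under pydantic v1, which this commit's imports suggest, the raised ValueError surfaces as a ValidationError, itself a ValueError subclass):

from flamingo.integrations.huggingface import AutoModelConfig

AutoModelConfig(artifact="mistralai/Mistral-7B-v0.1")  # a well-formed repo-id string passes
try:
    AutoModelConfig(artifact="not a//valid repo id")  # assumed to fail the repo-id check
except ValueError as err:
    print(err)

Note that WandbArtifactLink values bypass the check entirely because of the isinstance guard.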
34 changes: 0 additions & 34 deletions src/flamingo/integrations/huggingface/model_name_or_path.py

This file was deleted.

19 changes: 19 additions & 0 deletions src/flamingo/integrations/huggingface/tokenizer_config.py
@@ -0,0 +1,19 @@
from typing import Any

from flamingo.types import BaseFlamingoConfig


class AutoTokenizerConfig(BaseFlamingoConfig):
    """Settings passed to a HuggingFace AutoTokenizer instantiation."""

    name: str
    trust_remote_code: bool | None = None
    use_fast: bool | None = None

    def get_tokenizer_args(self) -> dict[str, Any]:
        args = dict(
            trust_remote_code=self.trust_remote_code,
            use_fast=self.use_fast,
        )
        # Only return non-None values so we get HuggingFace defaults when not specified
        return {k: v for k, v in args.items() if v is not None}
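A sketch of the None-filtering (values illustrative, assuming flamingo is importable):

from flamingo.integrations.huggingface import AutoTokenizerConfig

tokenizer_config = AutoTokenizerConfig(name="gpt2", use_fast=False)
# trust_remote_code was left as None, so it is dropped and HuggingFace's default applies
print(tokenizer_config.get_tokenizer_args())  # {'use_fast': False}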
48 changes: 34 additions & 14 deletions src/flamingo/integrations/huggingface/trainer_config.py
@@ -1,21 +1,41 @@
-from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype
+from typing import Any
+
+from flamingo.types import BaseFlamingoConfig


 class TrainerConfig(BaseFlamingoConfig):
     """Configuration for a HuggingFace trainer/training arguments."""

     max_seq_length: int | None = None
-    num_train_epochs: int = 1
-    batch_size: int = 16
-    learning_rate: float = 1e-5
-    weight_decay: float = 1e-3
-    gradient_accumulation_steps: int = 1
-    gradient_checkpointing: bool = False
-    trust_remote_code: bool = False
-    torch_dtype: SerializableTorchDtype = None
-    evaluation_strategy: str = "epoch"
+    num_train_epochs: int | None = None
+    per_device_train_batch_size: int | None = None
+    per_device_eval_batch_size: int | None = None
+    learning_rate: float | None = None
+    weight_decay: float | None = None
+    gradient_accumulation_steps: int | None = None
+    gradient_checkpointing: bool | None = None
+    evaluation_strategy: str | None = None
     eval_steps: float | None = None
-    logging_strategy: str = "steps"
-    logging_steps: float = 100
-    save_strategy: str = "steps"
-    save_steps: int = 500
+    logging_strategy: str | None = None
+    logging_steps: float | None = None
+    save_strategy: str | None = None
+    save_steps: int | None = None
+
+    def get_training_args(self) -> dict[str, Any]:
+        args = dict(
+            num_train_epochs=self.num_train_epochs,
+            learning_rate=self.learning_rate,
+            per_device_train_batch_size=self.per_device_train_batch_size,
+            per_device_eval_batch_size=self.per_device_eval_batch_size,
+            gradient_accumulation_steps=self.gradient_accumulation_steps,
+            gradient_checkpointing=self.gradient_checkpointing,
+            weight_decay=self.weight_decay,
+            evaluation_strategy=self.evaluation_strategy,
+            eval_steps=self.eval_steps,
+            logging_strategy=self.logging_strategy,
+            logging_steps=self.logging_steps,
+            save_strategy=self.save_strategy,
+            save_steps=self.save_steps,
+        )
+        # Only return non-None values so we get HuggingFace defaults when not specified
+        return {k: v for k, v in args.items() if v is not None}
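A sketch of how the filtered dict splats into transformers.TrainingArguments, mirroring get_training_arguments() in the finetuning driver further down (values illustrative):

from transformers import TrainingArguments

from flamingo.integrations.huggingface import TrainerConfig

trainer_config = TrainerConfig(learning_rate=2e-5, num_train_epochs=3)
# Omitted fields keep TrainingArguments' own defaults rather than passing None
args = TrainingArguments(output_dir="out", **trainer_config.get_training_args())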
10 changes: 7 additions & 3 deletions src/flamingo/integrations/wandb/wandb_environment.py
@@ -1,7 +1,7 @@
 import os
+import secrets
 import warnings

-import wandb
 from pydantic import Extra, root_validator
 from wandb.apis.public import Run

@@ -69,6 +69,10 @@ def from_run(cls, run: Run) -> "WandbEnvironment":
             run_id=run.id,
         )

-    def force_run_id(self) -> None:
+    def ensure_run_id(self, provided_run_id: str | None = None) -> None:
+        """Ensure that the run_id is set in the configuration.
+        If None, the run_id is set to the passed value or a random 8-digit hexadecimal string.
+        """
         if self.run_id is None:
-            self.run_id = wandb.run
+            self.run_id = provided_run_id or secrets.token_hex(nbytes=4)
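The fallback logic in isolation (standalone sketch; only the secrets.token_hex call comes from the diff):

import secrets

provided_run_id = None
run_id = provided_run_id or secrets.token_hex(nbytes=4)
print(run_id)  # e.g. '9f86d081' -- 4 random bytes rendered as 8 hex digits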
2 changes: 0 additions & 2 deletions src/flamingo/jobs/__init__.py
@@ -1,10 +1,8 @@
-from .base_config import BaseJobConfig
 from .finetuning_config import FinetuningJobConfig
 from .lm_harness_config import LMHarnessJobConfig, ModelNameOrCheckpointPath
 from .simple_config import SimpleJobConfig

 __all__ = [
-    "BaseJobConfig",
     "SimpleJobConfig",
     "FinetuningJobConfig",
     "LMHarnessJobConfig",
7 changes: 0 additions & 7 deletions src/flamingo/jobs/base_config.py

This file was deleted.

8 changes: 4 additions & 4 deletions src/flamingo/jobs/drivers/__init__.py
@@ -1,6 +1,6 @@
-from .finetuning import run_finetuning
-from .lm_harness import run_lm_harness
-from .ludwig import run_ludwig
-from .simple import run_simple
+from .finetuning_job import run_finetuning
+from .lm_harness_job import run_lm_harness
+from .ludwig_job import run_ludwig
+from .simple_job import run_simple

 __all__ = ["run_finetuning", "run_lm_harness", "run_ludwig", "run_simple"]
src/flamingo/jobs/drivers/{finetuning.py → finetuning_job.py}
@@ -11,51 +11,38 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, TrainingArguments
 from trl import SFTTrainer

-from flamingo.integrations.wandb import update_wandb_summary
 from flamingo.jobs import FinetuningJobConfig


 def is_wandb_enabled(config: FinetuningJobConfig):
     # Only report to WandB on the rank 0 worker
     # Reference: https://docs.ray.io/en/latest/train/user-guides/experiment-tracking.html
-    return config.wandb_env and train.get_context().get_world_rank() == 0
+    return config.tracking is not None and train.get_context().get_world_rank() == 0


-def get_training_args(config: FinetuningJobConfig) -> TrainingArguments:
+def get_training_arguments(config: FinetuningJobConfig) -> TrainingArguments:
+    """Get TrainingArguments appropriate for the worker rank and job config."""
+    training_args = config.trainer.get_training_args() if config.trainer else {}
     return TrainingArguments(
         output_dir="out",  # Local checkpoint path on a worker
-        num_train_epochs=config.num_train_epochs,
-        learning_rate=config.learning_rate,
-        per_device_train_batch_size=config.batch_size,
-        per_device_eval_batch_size=config.batch_size,
-        gradient_accumulation_steps=config.gradient_accumulation_steps,
-        gradient_checkpointing=config.gradient_checkpointing,
-        weight_decay=config.weight_decay,
-        evaluation_strategy=config.evaluation_strategy,
-        eval_steps=config.eval_steps,
-        logging_strategy=config.logging_strategy,
-        logging_steps=config.logging_steps,
-        save_strategy=config.save_strategy,
-        save_steps=config.save_steps,
-        run_name=config.wandb_name,
         report_to="wandb" if is_wandb_enabled(config) else "none",
-        no_cuda=not config.scaling_config.use_gpu,
+        no_cuda=not config.scaling.use_gpu,
         push_to_hub=False,
         disable_tqdm=True,
         logging_dir=None,
+        **training_args,
     )


 def get_datasets(config: FinetuningJobConfig) -> DatasetDict:
-    # TODO: Refactor me somehow
-    ...
+    # TODO: Implement me
+    return DatasetDict()


 def get_model(config: FinetuningJobConfig) -> PreTrainedModel:
     device_map, bnb_config = None, None
-    if config.quantization_config:
-        bnb_config = config.quantization_config.as_huggingface()
+    if config.quantization is not None:
+        bnb_config = config.quantization.as_huggingface()
     # When quantization is enabled, model must all be on same GPU to work with DDP
     # If a device_map is not specified we will get accelerate errors downstream
     # Reference: https://github.com/huggingface/accelerate/issues/1840#issuecomment-1683105994
@@ -64,20 +51,18 @@ def get_model(config: FinetuningJobConfig) -> PreTrainedModel:
print(f"Setting model device_map = {device_map} to enable quantization")

return AutoModelForCausalLM.from_pretrained(
config.model,
trust_remote_code=config.trust_remote_code,
torch_dtype=config.torch_dtype,
pretrained_model_name_or_path=config.model.name,
trust_remote_code=config.model.trust_remote_code,
torch_dtype=config.model.torch_dtype,
quantization_config=bnb_config,
device_map=device_map,
)


def get_tokenizer(config: FinetuningJobConfig):
tokenizer = AutoTokenizer.from_pretrained(
config.tokenizer or config.model,
trust_remote_code=config.trust_remote_code,
use_fast=True,
)
tokenizer_name = config.tokenizer.name if config.tokenizer else config.model.name
tokenizer_args = config.tokenizer.get_tokenizer_args() if config.tokenizer else {}
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_args)
if not tokenizer.pad_token_id:
# Pad token required for generating consistent batch sizes
tokenizer.pad_token_id = tokenizer.eos_token_id
@@ -90,14 +75,25 @@ def train_func(config_data: dict):
     tokenizer = get_tokenizer(config)

     datasets = get_datasets(config)
-    training_args = get_training_args(config)
+    training_args = get_training_arguments(config)

+    # Manually initialize run in order to control run ID
+    if is_wandb_enabled(config):
+        env = config.tracking
+        wandb.init(
+            id=env.run_id,
+            name=env.name,
+            project=env.project,
+            entity=env.entity,
+            group=env.run_group,
+        )
+
     trainer = SFTTrainer(
         model=model,
         tokenizer=tokenizer,
         args=training_args,
-        peft_config=config.lora_config,
-        max_seq_length=config.max_seq_length,
+        peft_config=config.lora,
+        max_seq_length=config.trainer.max_seq_length if config.trainer else None,
         train_dataset=datasets["train"],
         eval_dataset=datasets["test"],
         dataset_text_field="text",
@@ -114,23 +110,24 @@ def train_func(config_data: dict):
 def run_finetuning(config: FinetuningJobConfig):
     print(f"Received job configuration: {config}")

+    if config.tracking:
+        # Ensure the run_id is set so that the W&B run can be initialized deterministically
+        config.tracking.ensure_run_id()
+
     run_config = RunConfig(
-        name=config.wandb_name,
-        storage_path=config.storage_path,
+        name=config.tracking.name if config.tracking else None,
+        storage_path=config.ray.storage_path,
         checkpoint_config=CheckpointConfig(num_to_keep=1),
     )
     trainer = TorchTrainer(
         train_loop_per_worker=train_func,
         train_loop_config=json.loads(config.json()),
-        scaling_config=config.scaling_config,
+        scaling_config=config.ray.get_scaling_config(),
         run_config=run_config,
     )
     result = trainer.fit()
     print(f"Training result: {result}")

-    # Log additional training metrics to completed WandB run
-    if config.wandb_env:
-        result_paths = {"ray/result_path": result.path}
-        if result.checkpoint:
-            result_paths["ray/checkpoint_path"] = f"{result.checkpoint.path}/checkpoint"
-        update_wandb_summary(config.wandb_env, result_paths)
+    if config.tracking:
+        # TODO: Add ref artifact here
+        pass
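For context, a sketch of the round trip that ships the config to Ray workers via train_loop_config (the stub model below is illustrative; the real FinetuningJobConfig defines its own schema):

import json

from pydantic import BaseModel


class StubJobConfig(BaseModel):  # hypothetical stand-in for FinetuningJobConfig
    learning_rate: float = 1e-5


config_data = json.loads(StubJobConfig(learning_rate=2e-5).json())  # plain dict for Ray
restored = StubJobConfig(**config_data)  # a worker can rebuild the config from the dict
assert restored.learning_rate == 2e-5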
src/flamingo/jobs/drivers/{lm_harness.py → lm_harness_job.py}
@@ -5,8 +5,8 @@
 from lm_eval.models.huggingface import HFLM
 from peft import PeftConfig

-from flamingo.configs import LMHarnessJobConfig, ModelNameOrCheckpointPath
 from flamingo.integrations.wandb import get_wandb_summary, update_wandb_summary
+from flamingo.jobs import LMHarnessJobConfig, ModelNameOrCheckpointPath


 def resolve_model_or_path(config: LMHarnessJobConfig) -> str:
@@ -97,6 +97,10 @@ def evaluation_task(config: LMHarnessJobConfig, model_to_load: str) -> None:
 def run_lm_harness(config: LMHarnessJobConfig):
     print(f"Received job configuration: {config}")

+    if config.tracking:
+        # Ensure the run_id is set so that the W&B run can be initialized deterministically
+        config.tracking.ensure_run_id()
+
     # Resolve path and ensure exists
     model_to_load = resolve_model_or_path(config)
