This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit
Merge pull request #3 from mozilla-ai/sfriedowitz/config-layout
Add more structured configurations for jobs
Sean Friedowitz authored Jan 17, 2024
2 parents 7200af9 + 5368178 commit e44dae9
Showing 31 changed files with 598 additions and 292 deletions.
5 changes: 0 additions & 5 deletions .gitignore
@@ -161,8 +161,3 @@ cython_debug/

 # Ruff
 .ruff_cache
-
-
-# ignore local wandb cache files. Not perfect
-**/wandb/*.log
-**/wandb/*run*
40 changes: 40 additions & 0 deletions examples/configs/finetuning_config.yaml
@@ -0,0 +1,40 @@
# Tokenizer defined by only a string repo ID
tokenizer: "mistral-ai/other-repo-with-special-tokenizer"

# Model defined as an object with additional settings beyond the repo ID
model:
  path: "mistral-ai/mistral-7b"
  trust_remote_code: True
  torch_dtype: "bfloat16"

# Dataset defined as an object with a path linking to a W&B artifact
dataset:
  path:
    name: "dataset-artifact"
    version: "latest"
    project: "research-project"
  split: "train"
  test_size: 0.2

trainer:
  max_seq_length: 512
  learning_rate: 0.1
  num_train_epochs: 2

quantization:
  load_in_4bit: True
  bnb_4bit_quant_type: "fp4"

adapter:
  r: 16
  lora_alpha: 32
  lora_dropout: 0.2

tracking:
  name: "location-to-log-results"
  project: "another-project"
  entity: "another-entity"

ray:
  use_gpu: True
  num_workers: 4
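
For context, a minimal sketch of how a YAML file like this could be parsed into typed settings objects with pydantic. The class names below are hypothetical stand-ins for illustration, not the job config classes added in this PR:

import yaml
from pydantic import BaseModel


class ModelSection(BaseModel):
    path: str
    trust_remote_code: bool = False
    torch_dtype: str | None = None


class FinetuningExampleConfig(BaseModel):  # hypothetical name, for illustration only
    tokenizer: str
    model: ModelSection


with open("examples/configs/finetuning_config.yaml") as f:
    raw = yaml.safe_load(f)

# Sections not modeled here (dataset, trainer, ...) are ignored by pydantic's default.
config = FinetuningExampleConfig(**raw)
print(config.model.path)  # "mistral-ai/mistral-7b"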
28 changes: 28 additions & 0 deletions examples/configs/lm_harness_config.yaml
@@ -0,0 +1,28 @@
# Model to evaluate, specified as a W&B artifact
model:
  path:
    name: "training-run-model-artifact"
    version: "v4"
    project: "research-project"
    entity: "twitter.com"
  trust_remote_code: True
  torch_dtype: "float16"

# Settings specific to lm_harness.evaluate
evaluator:
  tasks: ["task1", "task2", "...", "taskN"]
  num_fewshot: 5

quantization:
  load_in_4bit: True
  bnb_4bit_quant_type: "fp4"

tracking:
  name: "location-to-log-results"
  project: "another-project"
  entity: "another-entity"

ray:
  use_gpu: True
  num_workers: 4
  timeout: 3600
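
A rough sketch of how the `evaluator` block plausibly maps onto the lm-eval-harness API (pinned as lm-eval==0.4.0 in pyproject.toml). The model and task names are placeholders, and the actual wiring in flamingo's evaluation job may differ:

import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",                          # HuggingFace model backend
    model_args="pretrained=distilgpt2",  # placeholder model for illustration
    tasks=["hellaswag"],                 # stands in for evaluator.tasks
    num_fewshot=5,                       # evaluator.num_fewshot
)
print(results["results"])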
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -33,9 +33,9 @@ ludwig = ["ludwig==0.9.1"]
 
 evaluate = ["lm-eval==0.4.0", "einops"]
 
-test = ["ruff==0.1.4", "pytest==7.4.3", "pytest-cov==4.1.0"]
+test = ["ruff==0.1.7", "pytest==7.4.3", "pytest-cov==4.1.0"]
 
-all = ["flamingo[finetune,ludwig,evaluate,test]"]
+all = ["flamingo[finetune,evaluate,test]"]
 
 [project.scripts]
 flamingo = "flamingo.cli:main"
8 changes: 6 additions & 2 deletions src/flamingo/integrations/huggingface/__init__.py
@@ -1,9 +1,13 @@
-from .model_name_or_path import ModelNameOrCheckpointPath
+from .dataset_config import DatasetConfig
+from .model_config import AutoModelConfig
 from .quantization_config import QuantizationConfig
+from .tokenizer_config import AutoTokenizerConfig
 from .trainer_config import TrainerConfig
 
 __all__ = [
-    "ModelNameOrCheckpointPath",
+    "AutoModelConfig",
+    "AutoTokenizerConfig",
+    "DatasetConfig",
     "QuantizationConfig",
     "TrainerConfig",
 ]
16 changes: 16 additions & 0 deletions src/flamingo/integrations/huggingface/dataset_config.py
@@ -0,0 +1,16 @@
from pydantic import validator

from flamingo.integrations.huggingface.utils import repo_id_validator
from flamingo.integrations.wandb import WandbArtifactConfig
from flamingo.types import BaseFlamingoConfig


class DatasetConfig(BaseFlamingoConfig):
    """Settings passed to load a HuggingFace dataset."""

    path: str | WandbArtifactConfig
    split: str | None = None
    test_size: float | None = None
    seed: int | None = None

    _path_validator = validator("path", allow_reuse=True, pre=True)(repo_id_validator)
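
As an illustration (not part of the diff), `path` accepts either a plain repo ID string or a structured artifact reference; the repo ID below is only an example, and the snippet assumes the flamingo package is importable:

from flamingo.integrations.huggingface import DatasetConfig
from flamingo.integrations.wandb import WandbArtifactConfig

# String form: repo_id_validator rejects malformed HuggingFace repo IDs.
by_repo = DatasetConfig(path="databricks/databricks-dolly-15k", split="train")

# Artifact form: points at a W&B dataset artifact instead of a hub repo.
by_artifact = DatasetConfig(
    path=WandbArtifactConfig(name="dataset-artifact", project="research-project"),
    split="train",
    test_size=0.2,
)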
19 changes: 19 additions & 0 deletions src/flamingo/integrations/huggingface/model_config.py
@@ -0,0 +1,19 @@
from pydantic import validator

from flamingo.integrations.huggingface.utils import repo_id_validator
from flamingo.integrations.wandb import WandbArtifactConfig
from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype


class AutoModelConfig(BaseFlamingoConfig):
    """Settings passed to a HuggingFace AutoModel instantiation.
    The model path can either be a string corresponding to a HuggingFace repo ID,
    or an artifact link to a reference artifact on W&B.
    """

    path: str | WandbArtifactConfig
    trust_remote_code: bool = False
    torch_dtype: SerializableTorchDtype = None

    _path_validator = validator("path", allow_reuse=True, pre=True)(repo_id_validator)
34 changes: 0 additions & 34 deletions src/flamingo/integrations/huggingface/model_name_or_path.py

This file was deleted.

15 changes: 15 additions & 0 deletions src/flamingo/integrations/huggingface/tokenizer_config.py
@@ -0,0 +1,15 @@
from pydantic import validator

from flamingo.integrations.huggingface.utils import repo_id_validator
from flamingo.integrations.wandb import WandbArtifactConfig
from flamingo.types import BaseFlamingoConfig


class AutoTokenizerConfig(BaseFlamingoConfig):
    """Settings passed to a HuggingFace AutoTokenizer instantiation."""

    path: str | WandbArtifactConfig
    trust_remote_code: bool | None = None
    use_fast: bool | None = None

    _path_validator = validator("path", allow_reuse=True, pre=True)(repo_id_validator)
33 changes: 18 additions & 15 deletions src/flamingo/integrations/huggingface/trainer_config.py
@@ -1,21 +1,24 @@
-from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype
+from flamingo.types import BaseFlamingoConfig
 
 
 class TrainerConfig(BaseFlamingoConfig):
-    """Configuration for a HuggingFace trainer/training arguments."""
+    """Configuration for a HuggingFace trainer/training arguments.
+    This mainly encompasses arguments passed to the HuggingFace `TrainingArguments` class,
+    but also contains some additional parameters for the `Trainer` or `SFTTrainer` classes.
+    """
 
     max_seq_length: int | None = None
-    num_train_epochs: int = 1
-    batch_size: int = 16
-    learning_rate: float = 1e-5
-    weight_decay: float = 1e-3
-    gradient_accumulation_steps: int = 1
-    gradient_checkpointing: bool = False
-    trust_remote_code: bool = False
-    torch_dtype: SerializableTorchDtype = None
-    evaluation_strategy: str = "epoch"
-    logging_strategy: str = "steps"
-    logging_steps: float = 100
-    save_strategy: str = "steps"
-    save_steps: int = 500
+    num_train_epochs: int | None = None
+    per_device_train_batch_size: int | None = None
+    per_device_eval_batch_size: int | None = None
+    learning_rate: float | None = None
+    weight_decay: float | None = None
+    gradient_accumulation_steps: int | None = None
+    gradient_checkpointing: bool | None = None
+    evaluation_strategy: str | None = None
+    eval_steps: float | None = None
+    logging_strategy: str | None = None
+    logging_steps: float | None = None
+    save_strategy: str | None = None
+    save_steps: int | None = None
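
Since every field is now optional, one plausible way for downstream code to consume this config is to drop unset values so the HuggingFace defaults apply. This sketch is an assumption about usage, not code from the PR; `max_seq_length` is excluded because it is an `SFTTrainer` argument rather than a `TrainingArguments` field:

from transformers import TrainingArguments

from flamingo.integrations.huggingface import TrainerConfig


def build_training_arguments(config: TrainerConfig) -> TrainingArguments:
    kwargs = config.dict(exclude={"max_seq_length"})
    kwargs = {k: v for k, v in kwargs.items() if v is not None}  # keep library defaults
    return TrainingArguments(output_dir="./output", **kwargs)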
10 changes: 9 additions & 1 deletion src/flamingo/integrations/huggingface/utils.py
@@ -1,7 +1,15 @@
+from typing import Any
+
 from huggingface_hub.utils import HFValidationError, validate_repo_id
 
 
-def is_valid_huggingface_model_name(s: str):
+def repo_id_validator(x: Any):
+    if isinstance(x, str) and not is_valid_huggingface_repo_id(x):
+        raise ValueError(f"{x} is not a valid HuggingFace repo ID.")
+    return x
+
+
+def is_valid_huggingface_repo_id(s: str):
     """
     Simple test to check if an HF model is valid using HuggingFace's tools.
     Sadly, theirs throws an exception and has no return.
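
The body of `is_valid_huggingface_repo_id` is collapsed in this view; judging from the imports and docstring, it presumably wraps `validate_repo_id` in a try/except along these lines (a reconstruction, not the verbatim source):

from huggingface_hub.utils import HFValidationError, validate_repo_id


def is_valid_huggingface_repo_id(s: str) -> bool:
    try:
        validate_repo_id(s)  # raises HFValidationError on malformed IDs
        return True
    except HFValidationError:
        return False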
6 changes: 4 additions & 2 deletions src/flamingo/integrations/wandb/__init__.py
@@ -1,8 +1,10 @@
-from .wandb_environment import WandbEnvironment  # noqa: I001
+from .artifact_config import WandbArtifactConfig
+from .run_config import WandbRunConfig
 from .utils import get_wandb_summary, update_wandb_summary
 
 __all__ = [
-    "WandbEnvironment",
+    "WandbArtifactConfig",
+    "WandbRunConfig",
     "get_wandb_summary",
     "update_wandb_summary",
 ]
16 changes: 16 additions & 0 deletions src/flamingo/integrations/wandb/artifact_config.py
@@ -0,0 +1,16 @@
from flamingo.types import BaseFlamingoConfig


class WandbArtifactConfig(BaseFlamingoConfig):
    """Configuration required to retrieve an artifact from W&B."""

    name: str
    version: str = "latest"
    project: str | None = None
    entity: str | None = None

    def get_wandb_path(self) -> str:
        """String identifier for the asset on the W&B platform."""
        path = "/".join(x for x in [self.entity, self.project, self.name] if x is not None)
        path = f"{path}:{self.version}"
        return path
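
A quick illustration of the identifier format `get_wandb_path()` produces; the entity name here is made up for the example:

from flamingo.integrations.wandb import WandbArtifactConfig

artifact = WandbArtifactConfig(
    name="dataset-artifact",
    project="research-project",
    entity="research-entity",  # hypothetical entity
)
print(artifact.get_wandb_path())
# -> "research-entity/research-project/dataset-artifact:latest"

# None-valued segments are simply dropped from the joined path:
print(WandbArtifactConfig(name="dataset-artifact").get_wandb_path())
# -> "dataset-artifact:latest"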
@@ -1,33 +1,31 @@
 import os
 import warnings
 
-from pydantic import Extra, root_validator
+from pydantic import root_validator
 from wandb.apis.public import Run
+from wandb.util import random_string
 
 from flamingo.types import BaseFlamingoConfig
 
 
-class WandbEnvironment(BaseFlamingoConfig):
-    """Settings required to log to a W&B run.
-    The fields on this class map to the environment variables
-    that are used to control the W&B logging locations.
-    The `name` and `project` are required as they are the minimum information
-    required to identify a run. The `name` is the human-readable name that appears in the W&B UI.
-    `name` is different than the `run_id` which must be unique within a project.
-    Although the `name` is not mandatorily unique, it is generally best practice to use a
-    unique and descriptive name to later identify the run.
+class WandbRunConfig(BaseFlamingoConfig):
+    """Configuration required to log to a W&B run.
+    A W&B Run is uniquely identified by the combination of `entity/project/run_id`.
+    The W&B platform will auto-generate values for these variables if they are not provided
+    when you initialize a run.
+    However, based on how these attributes are passed between jobs it is often necessary
+    to know the run ID before initializing a run.
+    For this reason, the run ID field is made non-optional and auto-generated locally
+    if it is not provided.
     """
 
-    class Config:
-        extra = Extra.forbid  # Error on extra kwargs
-
-    __match_args__ = ("name", "project", "run_id", "run_group", "entity")
+    __match_args__ = ("run_id", "name", "project", "run_group", "entity")
 
+    run_id: str
     name: str | None = None
     project: str | None = None
-    run_id: str | None = None
     run_group: str | None = None
     entity: str | None = None

@@ -40,22 +38,15 @@ def warn_missing_api_key(cls, values):
         )
         return values
 
-    @property
-    def env_vars(self) -> dict[str, str]:
-        # WandB w/ HuggingFace is weird. You can specify the run name inline,
-        # but the rest must be injected as environment variables
-        env_vars = {
-            "WANDB_NAME": self.name,
-            "WANDB_PROJECT": self.project,
-            "WANDB_RUN_ID": self.run_id,
-            "WANDB_RUN_GROUP": self.run_group,
-            "WANDB_ENTITY": self.entity,
-            "WANDB_API_KEY": os.environ.get("WANDB_API_KEY", None),
-        }
-        return {k: v for k, v in env_vars.items() if v is not None}
+    @root_validator(pre=True)
+    def ensure_run_id(cls, values):
+        if values.get("run_id", None) is None:
+            # Generate a random 8-digit alphanumeric string, analogous to W&B platform
+            values["run_id"] = random_string(length=8)
+        return values

     @classmethod
-    def from_run(cls, run: Run) -> "WandbEnvironment":
+    def from_run(cls, run: Run) -> "WandbRunConfig":
         """Extract environment settings from a W&B Run object.
         Useful when listing runs from the W&B API and extracting their settings for a job.
@@ -67,3 +58,19 @@ def from_run(cls, run: Run) -> "WandbEnvironment":
             entity=run.entity,
             run_id=run.id,
         )
 
+    def get_wandb_path(self) -> str:
+        """String identifier for the asset on the W&B platform."""
+        path = "/".join(x for x in [self.entity, self.project, self.run_id] if x is not None)
+        return path
+
+    def get_env_vars(self) -> dict[str, str]:
+        env_vars = {
+            "WANDB_RUN_ID": self.run_id,
+            "WANDB_NAME": self.name,
+            "WANDB_PROJECT": self.project,
+            "WANDB_RUN_GROUP": self.run_group,
+            "WANDB_ENTITY": self.entity,
+            "WANDB_API_KEY": os.environ.get("WANDB_API_KEY", None),
+        }
+        return {k: v for k, v in env_vars.items() if v is not None}
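
To illustrate the new behavior (example values are made up): `run_id` is auto-generated when omitted, and the same ID flows into both the W&B path and the injected environment variables:

from flamingo.integrations.wandb import WandbRunConfig

config = WandbRunConfig(name="finetuning-run", project="research-project")
print(config.run_id)            # auto-generated 8-character string, e.g. "x7k2p9qa"
print(config.get_wandb_path())  # "research-project/<run_id>" (None entity is dropped)

env_vars = config.get_env_vars()
assert env_vars["WANDB_RUN_ID"] == config.run_id  # ready to inject into a job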
