Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

Permalink
copy over new config classes
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Friedowitz committed Jan 17, 2024
1 parent 7200af9 commit d726bfc
Show file tree
Hide file tree
Showing 19 changed files with 275 additions and 172 deletions.
16 changes: 16 additions & 0 deletions src/flamingo/integrations/huggingface/dataset_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from pydantic import validator

from flamingo.integrations.huggingface.utils import repo_id_validator
from flamingo.integrations.wandb import WandbArtifactLink
from flamingo.types import BaseFlamingoConfig


class DatasetConfig(BaseFlamingoConfig):
    """Configuration describing a HuggingFace dataset to load.

    The dataset location is given either as a string or as a link to a W&B artifact.
    String values are checked by `repo_id_validator` before field parsing.
    """

    # Dataset location: a string, or a W&B artifact link
    path: str | WandbArtifactLink
    # Optional loading/splitting settings; each is `None` when not specified
    split: str | None = None
    test_size: float | None = None
    seed: int | None = None

    @validator("path", pre=True, allow_reuse=True)
    def _path_validator(cls, value):
        # `pre=True`: runs on the raw input, before pydantic coerces the field type
        return repo_id_validator(value)
19 changes: 19 additions & 0 deletions src/flamingo/integrations/huggingface/model_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from pydantic import validator

from flamingo.integrations.huggingface.utils import repo_id_validator
from flamingo.integrations.wandb import WandbArtifactLink
from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype


class AutoModelConfig(BaseFlamingoConfig):
    """Configuration for instantiating a HuggingFace AutoModel.

    `path` is either a string naming a HuggingFace repo ID,
    or an artifact link pointing at a reference artifact on W&B.
    """

    # Model location: HuggingFace repo ID string, or a W&B artifact link
    path: str | WandbArtifactLink
    # Disabled unless explicitly requested
    trust_remote_code: bool = False
    # No dtype is set unless one is provided
    torch_dtype: SerializableTorchDtype = None

    @validator("path", pre=True, allow_reuse=True)
    def _path_validator(cls, value):
        # `pre=True`: runs on the raw input, before pydantic coerces the field type
        return repo_id_validator(value)
34 changes: 0 additions & 34 deletions src/flamingo/integrations/huggingface/model_name_or_path.py

This file was deleted.

25 changes: 25 additions & 0 deletions src/flamingo/integrations/huggingface/tokenizer_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Any

from pydantic import validator

from flamingo.integrations.huggingface.utils import repo_id_validator
from flamingo.integrations.wandb import WandbArtifactLink
from flamingo.types import BaseFlamingoConfig


class AutoTokenizerConfig(BaseFlamingoConfig):
    """Configuration for instantiating a HuggingFace AutoTokenizer."""

    # Tokenizer location: a string, or a W&B artifact link
    path: str | WandbArtifactLink
    # Left as `None` to defer to the HuggingFace default
    trust_remote_code: bool | None = None
    use_fast: bool | None = None

    @validator("path", pre=True, allow_reuse=True)
    def _path_validator(cls, value):
        # `pre=True`: runs on the raw input, before pydantic coerces the field type
        return repo_id_validator(value)

    def get_tokenizer_args(self) -> dict[str, Any]:
        """Return the tokenizer keyword arguments that were explicitly set.

        Fields left as `None` are omitted so that the HuggingFace defaults
        apply for anything not specified.
        """
        tokenizer_args: dict[str, Any] = {}
        for field_name in ("trust_remote_code", "use_fast"):
            field_value = getattr(self, field_name)
            if field_value is not None:
                tokenizer_args[field_name] = field_value
        return tokenizer_args
54 changes: 39 additions & 15 deletions src/flamingo/integrations/huggingface/trainer_config.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,45 @@
from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype
from typing import Any

from flamingo.types import BaseFlamingoConfig


class TrainerConfig(BaseFlamingoConfig):
"""Configuration for a HuggingFace trainer/training arguments."""
"""Configuration for a HuggingFace trainer/training arguments.
This mainly encompasses arguments passed to the HuggingFace `TrainingArguments` class,
but also contains some additional parameters for the `Trainer` or `SFTTrainer` classes.
"""

max_seq_length: int | None = None
num_train_epochs: int = 1
batch_size: int = 16
learning_rate: float = 1e-5
weight_decay: float = 1e-3
gradient_accumulation_steps: int = 1
gradient_checkpointing: bool = False
trust_remote_code: bool = False
torch_dtype: SerializableTorchDtype = None
evaluation_strategy: str = "epoch"
num_train_epochs: int | None = None
per_device_train_batch_size: int | None = None
per_device_eval_batch_size: int | None = None
learning_rate: float | None = None
weight_decay: float | None = None
gradient_accumulation_steps: int | None = None
gradient_checkpointing: bool | None = None
evaluation_strategy: str | None = None
eval_steps: float | None = None
logging_strategy: str = "steps"
logging_steps: float = 100
save_strategy: str = "steps"
save_steps: int = 500
logging_strategy: str | None = None
logging_steps: float | None = None
save_strategy: str | None = None
save_steps: int | None = None

def get_training_args(self) -> dict[str, Any]:
    """Return the training arguments that were explicitly set.

    Fields left as `None` are omitted so that the HuggingFace defaults
    are used for anything not specified. Falsy but non-`None` values
    (e.g. `False` or `0`) are kept.
    """
    # Order matches the order in which the arguments are emitted
    field_names = (
        "num_train_epochs",
        "learning_rate",
        "per_device_train_batch_size",
        "per_device_eval_batch_size",
        "gradient_accumulation_steps",
        "gradient_checkpointing",
        "weight_decay",
        "evaluation_strategy",
        "eval_steps",
        "logging_strategy",
        "logging_steps",
        "save_strategy",
        "save_steps",
    )
    return {
        name: value
        for name in field_names
        if (value := getattr(self, name)) is not None
    }
26 changes: 25 additions & 1 deletion src/flamingo/integrations/huggingface/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from typing import Any

from datasets import DatasetDict, load_dataset
from huggingface_hub.utils import HFValidationError, validate_repo_id


def is_valid_huggingface_model_name(s: str):
def repo_id_validator(x: Any):
    """Validate that a string value is a HuggingFace repo ID.

    Non-string values are passed through unchanged. String values are returned
    as-is when `is_valid_huggingface_repo_id` accepts them.

    Raises:
        ValueError: If `x` is a string that is not a valid HuggingFace repo ID.
    """
    if not isinstance(x, str):
        return x
    if is_valid_huggingface_repo_id(x):
        return x
    raise ValueError(f"{x} is not a valid HuggingFace repo ID.")


def is_valid_huggingface_repo_id(s: str):
"""
Simple test to check if an HF model is valid using HuggingFace's tools.
Sadly, theirs throws an exception and has no return.
Expand All @@ -14,3 +23,18 @@ def is_valid_huggingface_model_name(s: str):
return True
except HFValidationError:
return False


def load_and_split_dataset(
    path: str,
    *,
    split: str | None = None,
    test_size: float | None = None,
    seed: int | None = None,
) -> DatasetDict:
    """Load a HuggingFace dataset and return it as a `DatasetDict`.

    Args:
        path: Dataset path forwarded to `load_dataset`.
        split: Split forwarded to `load_dataset`.
        test_size: When provided, the loaded dataset is divided with
            `train_test_split` using this size. Defaults to `None` (no splitting).
        seed: Seed forwarded to `train_test_split`.

    Returns:
        The result of `train_test_split` when `test_size` is provided.
        Otherwise the loaded splits if `load_dataset` already produced a
        `DatasetDict`, or else the single loaded dataset under the "train" key.

    Raises:
        ValueError: If `test_size` is provided but `load_dataset` returned
            multiple splits, so there is no single dataset to divide.
    """
    dataset = load_dataset(path, split=split)

    if isinstance(dataset, DatasetDict):
        # `load_dataset` hands back all splits at once when no single split is selected.
        # A `DatasetDict` has no `train_test_split`, and must not be re-wrapped under "train".
        if test_size is not None:
            raise ValueError(
                "Cannot apply `test_size` when multiple splits were loaded; "
                "specify a `split` to select a single dataset."
            )
        return dataset

    if test_size is not None:
        return dataset.train_test_split(test_size=test_size, seed=seed)
    return DatasetDict({"train": dataset})
6 changes: 4 additions & 2 deletions src/flamingo/integrations/wandb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from .wandb_environment import WandbEnvironment # noqa: I001
from .artifact_link import WandbArtifactLink
from .run_link import WandbRunLink
from .utils import get_wandb_summary, update_wandb_summary

__all__ = [
"WandbEnvironment",
"WandbArtifactLink",
"WandbRunLink",
"get_wandb_summary",
"update_wandb_summary",
]
17 changes: 17 additions & 0 deletions src/flamingo/integrations/wandb/artifact_link.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from flamingo.types import BaseFlamingoConfig


class WandbArtifactLink(BaseFlamingoConfig):
    """Information needed to retrieve an artifact from W&B."""

    # Only the artifact name is required
    name: str
    version: str = "latest"
    project: str | None = None
    entity: str | None = None

    @property
    def wandb_path(self) -> str:
        """String identifier for the asset on the W&B platform.

        Built as `entity/project/name:version`, leaving out the entity
        and/or project segments when they are `None`.
        """
        segments = [
            segment
            for segment in (self.entity, self.project, self.name)
            if segment is not None
        ]
        return "/".join(segments) + f":{self.version}"
Original file line number Diff line number Diff line change
@@ -1,33 +1,30 @@
import os
import warnings

from pydantic import Extra, root_validator
from pydantic import root_validator
from wandb.apis.public import Run
from wandb.util import random_string

from flamingo.types import BaseFlamingoConfig


class WandbEnvironment(BaseFlamingoConfig):
class WandbRunLink(BaseFlamingoConfig):
"""Settings required to log to a W&B run.
The fields on this class map to the environment variables
that are used to control the W&B logging locations.
A W&B Run is uniquely identified by the combination of `entity/project/run_id`.
The W&B platform will auto-generate values for these variables if they are not provided.
The `name` and `project` are required as they are the minimum information
required to identify a run. The `name` is the human-readable name that appears in the W&B UI.
`name` is different than the `run_id` which must be unique within a project.
Although the `name` is not mandatorily unique, it is generally best practice to use a
unique and descriptive name to later identify the run.
However, based on how these attributes are passed between jobs it is often necessary
to know the run ID before initializing a run.
For this reason, the run ID field is made non-optional and auto-generated locally
if it is not provided.
"""

class Config:
extra = Extra.forbid # Error on extra kwargs

__match_args__ = ("name", "project", "run_id", "run_group", "entity")
__match_args__ = ("run_id", "name", "project", "run_group", "entity")

run_id: str
name: str | None = None
project: str | None = None
run_id: str | None = None
run_group: str | None = None
entity: str | None = None

Expand All @@ -40,22 +37,15 @@ def warn_missing_api_key(cls, values):
)
return values

@property
def env_vars(self) -> dict[str, str]:
    """Environment variables describing this W&B logging location.

    Variables whose value is `None` are left out. `WANDB_API_KEY` is
    taken from the current process environment when it is set there.
    """
    # WandB w/ HuggingFace is weird. You can specify the run name inline,
    # but the rest must be injected as environment variables
    candidates = (
        ("WANDB_NAME", self.name),
        ("WANDB_PROJECT", self.project),
        ("WANDB_RUN_ID", self.run_id),
        ("WANDB_RUN_GROUP", self.run_group),
        ("WANDB_ENTITY", self.entity),
        ("WANDB_API_KEY", os.environ.get("WANDB_API_KEY")),
    )
    return {key: value for key, value in candidates if value is not None}
@root_validator(pre=True)
def ensure_run_id(cls, values):
    """Fill in `run_id` on the raw input values when it is missing or `None`."""
    if values.get("run_id") is not None:
        return values
    # Generate a random 8-character alphanumeric string, analogous to W&B platform
    values["run_id"] = random_string(length=8)
    return values

@classmethod
def from_run(cls, run: Run) -> "WandbEnvironment":
def from_run(cls, run: Run) -> "WandbRunLink":
"""Extract environment settings from a W&B Run object.
Useful when listing runs from the W&B API and extracting their settings for a job.
Expand All @@ -67,3 +57,20 @@ def from_run(cls, run: Run) -> "WandbEnvironment":
entity=run.entity,
run_id=run.id,
)

@property
def wandb_path(self) -> str:
    """String identifier for the asset on the W&B platform.

    Built as `entity/project/run_id`, leaving out any segment that is `None`.
    """
    segments = (self.entity, self.project, self.run_id)
    return "/".join(segment for segment in segments if segment is not None)

def get_env_vars(self) -> dict[str, str]:
    """Return the W&B environment variables for this run link.

    Variables whose value is `None` are left out. `WANDB_API_KEY` is
    taken from the current process environment when it is set there.
    """
    candidates = (
        ("WANDB_RUN_ID", self.run_id),
        ("WANDB_NAME", self.name),
        ("WANDB_PROJECT", self.project),
        ("WANDB_RUN_GROUP", self.run_group),
        ("WANDB_ENTITY", self.entity),
        ("WANDB_API_KEY", os.environ.get("WANDB_API_KEY")),
    )
    return {key: value for key, value in candidates if value is not None}
38 changes: 11 additions & 27 deletions src/flamingo/integrations/wandb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,23 @@
import wandb
from wandb.apis.public import Run

from flamingo.integrations.wandb import WandbEnvironment
from flamingo.integrations.wandb import WandbRunLink


def get_wandb_summary(env: WandbEnvironment) -> dict[str, Any]:
def get_wandb_run(env: WandbRunLink) -> Run:
    """Look up the run at `env.wandb_path` using a new W&B API client."""
    return wandb.Api().run(env.wandb_path)


def get_wandb_summary(env: WandbRunLink) -> dict[str, Any]:
"""Get the summary dictionary attached to a W&B run."""
run = _resolve_wandb_run(env)
run = get_wandb_run(env)
return dict(run.summary)


def update_wandb_summary(env: WandbEnvironment, metrics: dict[str, Any]) -> None:
def update_wandb_summary(env: WandbRunLink, metrics: dict[str, Any]) -> None:
"""Update a run's summary with the provided metrics."""
run = _resolve_wandb_run(env)
run = get_wandb_run(env)
run.summary.update(metrics)
run.update()


def _resolve_wandb_run(env: WandbEnvironment) -> Run:
"""Resolve a WandB run object from the provided environment settings.
An exception is raised if a Run cannot be found,
or if multiple runs exist in scope with the same name.
"""
api = wandb.Api()
base_path = "/".join(x for x in (env.entity, env.project) if x)
if env.run_id is not None:
full_path = f"{base_path}/{env.run_id}"
return api.run(full_path)
else:
match [run for run in api.runs(base_path) if run.name == env.name]:
case []:
raise RuntimeError(f"No WandB runs found at {base_path}/{env.name}")
case [Run(), _]:
raise RuntimeError(f"Multiple WandB runs found at {base_path}/{env.name}")
case [Run()] as mr:
# we have a single one, hurray
return mr[0]
7 changes: 0 additions & 7 deletions src/flamingo/jobs/base_config.py

This file was deleted.

Loading

0 comments on commit d726bfc

Please sign in to comment.