This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit 0306f3b
moving some configs around
Sean Friedowitz committed Jan 16, 2024
1 parent 169907a commit 0306f3b
Showing 12 changed files with 75 additions and 67 deletions.
7 changes: 6 additions & 1 deletion src/flamingo/integrations/huggingface/dataset_config.py
@@ -1,10 +1,15 @@
+from pydantic import validator
+
+from flamingo.integrations.huggingface.utils import repo_name_validator
 from flamingo.integrations.wandb import WandbArtifactLink
 from flamingo.types import BaseFlamingoConfig


 class DatasetConfig(BaseFlamingoConfig):
     """Settings passed to load a HuggingFace dataset."""

-    artifact: str | WandbArtifactLink
+    path: str | WandbArtifactLink
     split_size: float | None = None
     seed: int | None = None
+
+    _path_validator = validator("path", allow_reuse=True, pre=True)(repo_name_validator)
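
A minimal usage sketch of the reworked DatasetConfig (illustrative only; it assumes BaseFlamingoConfig behaves like a pydantic v1 BaseModel and that the class is importable from flamingo.integrations.huggingface):

from flamingo.integrations.huggingface import DatasetConfig

# A plain string is checked as a HuggingFace repo ID by the reused validator
config = DatasetConfig(path="imdb", split_size=0.2, seed=42)

# An obviously malformed repo ID should fail pre-validation
try:
    DatasetConfig(path="not a valid repo id")
except ValueError as err:  # pydantic v1 ValidationError subclasses ValueError
    print(err)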
14 changes: 7 additions & 7 deletions src/flamingo/integrations/huggingface/model_config.py
@@ -1,19 +1,19 @@
+from peft import LoraConfig
 from pydantic import validator

-from flamingo.integrations.huggingface.utils import is_valid_huggingface_repo_id
+from flamingo.integrations.huggingface import QuantizationConfig
+from flamingo.integrations.huggingface.utils import repo_name_validator
 from flamingo.integrations.wandb import WandbArtifactLink
 from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype


 class AutoModelConfig(BaseFlamingoConfig):
     """Settings passed to a HuggingFace AutoModel instantiation."""

-    artifact: str | WandbArtifactLink
+    path: str | WandbArtifactLink
     trust_remote_code: bool = False
     torch_dtype: SerializableTorchDtype = None
+    quantization: QuantizationConfig | None = None
+    lora: LoraConfig | None = None  # TODO: Create own dataclass here

-    @validator("artifact", pre=True, always=True)
-    def _validate_model_name(cls, v):
-        if isinstance(v, str) and not is_valid_huggingface_repo_id(v):
-            raise ValueError(f"{v} is not a valid HuggingFace model name.")
-        return v
+    _path_validator = validator("path", allow_reuse=True, pre=True)(repo_name_validator)
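
A comparable sketch for AutoModelConfig (illustrative; QuantizationConfig's fields are not shown in this diff, so it is left unset, and parsing a string into a dtype is an assumption based on the SerializableTorchDtype name):

from flamingo.integrations.huggingface import AutoModelConfig

# path accepts either a HuggingFace repo ID or a WandbArtifactLink
model = AutoModelConfig(path="mistralai/Mistral-7B-v0.1", torch_dtype="bfloat16")
print(model.quantization, model.lora)  # both default to None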
7 changes: 6 additions & 1 deletion src/flamingo/integrations/huggingface/tokenizer_config.py
@@ -1,15 +1,20 @@
 from typing import Any

+from pydantic import validator
+
+from flamingo.integrations.huggingface.utils import repo_name_validator
 from flamingo.types import BaseFlamingoConfig


 class AutoTokenizerConfig(BaseFlamingoConfig):
     """Settings passed to a HuggingFace AutoTokenizer instantiation."""

-    name: str
+    path: str
     trust_remote_code: bool | None = None
     use_fast: bool | None = None
+
+    _path_validator = validator("path", allow_reuse=True, pre=True)(repo_name_validator)

     def get_tokenizer_args(self) -> dict[str, Any]:
         args = dict(
             trust_remote_code=self.trust_remote_code,
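
A short sketch of the renamed field in use (illustrative; the elided tail of get_tokenizer_args presumably finishes assembling the kwargs shown above):

from flamingo.integrations.huggingface import AutoTokenizerConfig

tok_config = AutoTokenizerConfig(path="gpt2", use_fast=True)
print(tok_config.get_tokenizer_args())  # kwargs to forward to AutoTokenizer.from_pretrained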
8 changes: 8 additions & 0 deletions src/flamingo/integrations/huggingface/utils.py
@@ -1,6 +1,14 @@
+from typing import Any
+
 from huggingface_hub.utils import HFValidationError, validate_repo_id


+def repo_name_validator(x: Any):
+    if isinstance(x, str) and not is_valid_huggingface_repo_id(x):
+        raise ValueError(f"{x} is not a valid HuggingFace repo ID.")
+    return x
+
+
 def is_valid_huggingface_repo_id(s: str):
     """
     Simple test to check if an HF model is valid using HuggingFace's tools.
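
The new helper can be exercised on its own; a quick sketch of the intended behavior (non-strings pass through untouched, which is what lets W&B artifact links survive pre-validation):

from flamingo.integrations.huggingface.utils import repo_name_validator

repo_name_validator("facebook/opt-350m")   # valid repo ID, returned unchanged
repo_name_validator({"name": "artifact"})  # non-string, passed through as-is
try:
    repo_name_validator("bad//repo//name")
except ValueError as err:
    print(err)  # bad//repo//name is not a valid HuggingFace repo ID.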
41 changes: 17 additions & 24 deletions src/flamingo/integrations/wandb/wandb_environment.py
@@ -2,7 +2,7 @@
 import secrets
 import warnings

-from pydantic import Extra, root_validator
+from pydantic import root_validator, validator
 from wandb.apis.public import Run

 from flamingo.types import BaseFlamingoConfig
@@ -11,24 +11,20 @@
 class WandbEnvironment(BaseFlamingoConfig):
     """Settings required to log to a W&B run.
     The fields on this class map to the environment variables
     that are used to control the W&B logging locations.
+    A W&B Run is uniquely identified by the combination of `entity/project/run_id`.
+    The W&B platform will auto-generate values for these variables if they are not provided.
-    The `name` and `project` are required as they are the minimum information
-    required to identify a run. The `name` is the human-readable name that appears in the W&B UI.
-    `name` is different than the `run_id` which must be unique within a project.
-    Although the `name` is not mandatorily unique, it is generally best practice to use a
-    unique and descriptive name to later identify the run.
+    However, based on how these attributes are passed between jobs it is often necessary
+    to know the run ID before initializing a run.
+    For this reason, the run ID field is made non-optional and auto-generated locally
+    if it is not provided.
     """

-    class Config:
-        extra = Extra.forbid  # Error on extra kwargs
-
-    __match_args__ = ("name", "project", "run_id", "run_group", "entity")
+    __match_args__ = ("run_id", "name", "project", "run_group", "entity")

+    run_id: str
     name: str | None = None
     project: str | None = None
-    run_id: str | None = None
     run_group: str | None = None
     entity: str | None = None

@@ -41,14 +37,19 @@ def warn_missing_api_key(cls, values):
             )
         return values

+    @validator("run_id", pre=True, always=True)
+    def ensure_run_id(cls, run_id):
+        if run_id is None:
+            # Generates an 8-digit random hexadecimal string, analogous to W&B platform
+            run_id = secrets.token_hex(nbytes=4)
+        return run_id
+
     @property
     def env_vars(self) -> dict[str, str]:
         # WandB w/ HuggingFace is weird. You can specify the run name inline,
         # but the rest must be injected as environment variables
         env_vars = {
+            "WANDB_RUN_ID": self.run_id,
             "WANDB_NAME": self.name,
             "WANDB_PROJECT": self.project,
-            "WANDB_RUN_ID": self.run_id,
             "WANDB_RUN_GROUP": self.run_group,
             "WANDB_ENTITY": self.entity,
             "WANDB_API_KEY": os.environ.get("WANDB_API_KEY", None),
@@ -68,11 +69,3 @@ def from_run(cls, run: Run) -> "WandbEnvironment":
             entity=run.entity,
             run_id=run.id,
         )
-
-    def ensure_run_id(self, provided_run_id: str | None = None) -> None:
-        """Ensure that the run_id is set in the configuration.
-        If None, the run_id is set to the passed value or a random 8-digit hexadecimal string.
-        """
-        if self.run_id is None:
-            self.run_id = provided_run_id or secrets.token_hex(nbytes=4)
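
A sketch of the new run_id behavior (illustrative; it assumes the always=True validator fires when the field is omitted, and the None-filtering below may already happen in the elided portion of env_vars):

import os

from flamingo.integrations.wandb import WandbEnvironment

env = WandbEnvironment(name="my-run", project="my-project")
print(env.run_id)  # e.g. "a1b2c3d4", auto-generated when not provided

# Inject the W&B settings into the environment, skipping unset values
os.environ.update({k: v for k, v in env.env_vars.items() if v is not None})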
8 changes: 4 additions & 4 deletions src/flamingo/jobs/drivers/__init__.py
@@ -1,6 +1,6 @@
-from .finetuning_job import run_finetuning
-from .lm_harness_job import run_lm_harness
-from .ludwig_job import run_ludwig
-from .simple_job import run_simple
+from .finetuning import run_finetuning
+from .lm_harness import run_lm_harness
+from .ludwig import run_ludwig
+from .simple import run_simple

 __all__ = ["run_finetuning", "run_lm_harness", "run_ludwig", "run_simple"]
21 changes: 7 additions & 14 deletions src/flamingo/jobs/finetuning_config.py
@@ -1,20 +1,22 @@
-from peft import LoraConfig
-from pydantic import Field, validator
+from pydantic import Field
 from ray.train import ScalingConfig

 from flamingo.integrations.huggingface import (
     AutoModelConfig,
     AutoTokenizerConfig,
     DatasetConfig,
     QuantizationConfig,
     TrainerConfig,
 )
-from flamingo.integrations.huggingface.utils import is_valid_huggingface_repo_id
 from flamingo.integrations.wandb import WandbEnvironment
 from flamingo.types import BaseFlamingoConfig


+class RayTrainConfig(BaseFlamingoConfig):
+    """Misc settings passed to Ray train.
+    Includes information for scaling, checkpointing, and runtime storage.
+    """
+
+    use_gpu: bool = True
+    num_workers: int | None = None
+    storage_path: str | None = None
@@ -29,15 +31,6 @@ class FinetuningJobConfig(BaseFlamingoConfig):
     model: AutoModelConfig
     dataset: DatasetConfig
     tokenizer: AutoTokenizerConfig | None = None
-    trainer: TrainerConfig | None = None
-    lora: LoraConfig | None = None  # TODO: Create our own config type
-    quantization: QuantizationConfig | None = None
     tracking: WandbEnvironment | None = None
+    trainer: TrainerConfig = Field(default_factory=TrainerConfig)
+    ray: RayTrainConfig = Field(default_factory=RayTrainConfig)
-
-    @validator("model")
-    def _validate_model_name(cls, v):
-        if is_valid_huggingface_repo_id(v):
-            return v
-        else:
-            raise ValueError(f"`{v}` is not a valid HuggingFace model name.")
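
With the whole-model validator gone and defaults supplied via Field(default_factory=...), a minimal job config needs only a model and a dataset; a sketch (the module path is inferred from the file path above):

from flamingo.integrations.huggingface import AutoModelConfig, DatasetConfig
from flamingo.jobs.finetuning_config import FinetuningJobConfig

job = FinetuningJobConfig(
    model=AutoModelConfig(path="distilgpt2"),
    dataset=DatasetConfig(path="imdb"),
)
print(job.trainer)  # TrainerConfig() from the default factory
print(job.ray)      # RayTrainConfig(use_gpu=True, ...)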
36 changes: 20 additions & 16 deletions src/flamingo/jobs/lm_harness_config.py
@@ -1,29 +1,33 @@
 import datetime

-from flamingo.integrations.huggingface import AutoModelConfig, QuantizationConfig
+from pydantic import Field
+
+from flamingo.integrations.huggingface import AutoModelConfig
 from flamingo.integrations.wandb import WandbEnvironment
-from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype
+from flamingo.types import BaseFlamingoConfig


-class LMHarnessJobConfig(BaseFlamingoConfig):
-    """Configuration to run an lm-evaluation-harness evaluation job.
+class RayComputeSettings(BaseFlamingoConfig):
+    """Misc settings for Ray compute in the LM harness job."""
+
+    use_gpu: bool = True
+    num_workers: int = 1
+    timeout: datetime.timedelta | None = None
+
-    This job loads an existing checkpoint path from Ray storage to run evaluation against,
-    OR a huggingface Model and logs the evaluation results to W&B.
-
-    This can be manually overwritten by specifying the `model_name_or_path` variable
-    which will take prescedence over the W&B checkpoint path.
-    """
+class LMHarnessEvaluatorSettings(BaseFlamingoConfig):
+    """Misc settings provided to an lm-harness evaluation job."""

-    model: AutoModelConfig
     tasks: list[str]
     batch_size: int | None = None
     num_fewshot: int | None = None
     limit: int | float | None = None
-    trust_remote_code: bool = False
-    torch_dtype: SerializableTorchDtype = None
-    quantization: QuantizationConfig | None = None
+
+
+class LMHarnessJobConfig(BaseFlamingoConfig):
+    """Configuration to run an lm-evaluation-harness evaluation job."""
+
+    model: AutoModelConfig
+    evaluator: LMHarnessEvaluatorSettings
     tracking: WandbEnvironment | None = None
-    num_cpus: int = 1
-    num_gpus: int = 1
-    timeout: datetime.timedelta | None = None
+    ray: RayComputeSettings = Field(default_factory=RayComputeSettings)
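
A sketch of the regrouped config (illustrative; the module path is inferred from the file path above). The flat num_cpus/num_gpus/timeout knobs now live in a nested RayComputeSettings group:

from flamingo.integrations.huggingface import AutoModelConfig
from flamingo.jobs.lm_harness_config import (
    LMHarnessEvaluatorSettings,
    LMHarnessJobConfig,
    RayComputeSettings,
)

job = LMHarnessJobConfig(
    model=AutoModelConfig(path="distilgpt2"),
    evaluator=LMHarnessEvaluatorSettings(tasks=["hellaswag"], num_fewshot=5),
    ray=RayComputeSettings(use_gpu=False, num_workers=2),
)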
