From d726bfcb1507b97075e82a2cfc805fb6cb4f3063 Mon Sep 17 00:00:00 2001 From: Sean Friedowitz Date: Tue, 16 Jan 2024 18:56:15 -0800 Subject: [PATCH 01/10] copy over new config classes --- .../huggingface/dataset_config.py | 16 ++++ .../integrations/huggingface/model_config.py | 19 +++++ .../huggingface/model_name_or_path.py | 34 --------- .../huggingface/tokenizer_config.py | 25 +++++++ .../huggingface/trainer_config.py | 54 +++++++++---- .../integrations/huggingface/utils.py | 26 ++++++- src/flamingo/integrations/wandb/__init__.py | 6 +- .../integrations/wandb/artifact_link.py | 17 +++++ .../{wandb_environment.py => run_link.py} | 63 +++++++++------- src/flamingo/integrations/wandb/utils.py | 38 +++------- src/flamingo/jobs/base_config.py | 7 -- src/flamingo/jobs/finetuning_config.py | 75 ++++++++++++++----- src/flamingo/jobs/lm_harness_config.py | 49 +++++------- src/flamingo/jobs/simple_config.py | 4 +- tests/conftest.py | 6 +- .../wandb/test_wandb_environment.py | 8 +- tests/{configs => jobs}/__init__.py | 0 .../test_finetuning_config.py | 0 .../test_lm_harness_config.py | 0 19 files changed, 275 insertions(+), 172 deletions(-) create mode 100644 src/flamingo/integrations/huggingface/dataset_config.py create mode 100644 src/flamingo/integrations/huggingface/model_config.py delete mode 100644 src/flamingo/integrations/huggingface/model_name_or_path.py create mode 100644 src/flamingo/integrations/huggingface/tokenizer_config.py create mode 100644 src/flamingo/integrations/wandb/artifact_link.py rename src/flamingo/integrations/wandb/{wandb_environment.py => run_link.py} (54%) delete mode 100644 src/flamingo/jobs/base_config.py rename tests/{configs => jobs}/__init__.py (100%) rename tests/{configs => jobs}/test_finetuning_config.py (100%) rename tests/{configs => jobs}/test_lm_harness_config.py (100%) diff --git a/src/flamingo/integrations/huggingface/dataset_config.py b/src/flamingo/integrations/huggingface/dataset_config.py new file mode 100644 index 00000000..e31014eb --- /dev/null +++ b/src/flamingo/integrations/huggingface/dataset_config.py @@ -0,0 +1,16 @@ +from pydantic import validator + +from flamingo.integrations.huggingface.utils import repo_id_validator +from flamingo.integrations.wandb import WandbArtifactLink +from flamingo.types import BaseFlamingoConfig + + +class DatasetConfig(BaseFlamingoConfig): + """Settings passed to load a HuggingFace dataset.""" + + path: str | WandbArtifactLink + split: str | None = None + test_size: float | None = None + seed: int | None = None + + _path_validator = validator("path", allow_reuse=True, pre=True)(repo_id_validator) diff --git a/src/flamingo/integrations/huggingface/model_config.py b/src/flamingo/integrations/huggingface/model_config.py new file mode 100644 index 00000000..60f54077 --- /dev/null +++ b/src/flamingo/integrations/huggingface/model_config.py @@ -0,0 +1,19 @@ +from pydantic import validator + +from flamingo.integrations.huggingface.utils import repo_id_validator +from flamingo.integrations.wandb import WandbArtifactLink +from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype + + +class AutoModelConfig(BaseFlamingoConfig): + """Settings passed to a HuggingFace AutoModel instantiation. + + The model path can either be a string corresponding to a HuggingFace repo ID, + or an artifact link to a reference artifact on W&B. + """ + + path: str | WandbArtifactLink + trust_remote_code: bool = False + torch_dtype: SerializableTorchDtype = None + + _path_validator = validator("path", allow_reuse=True, pre=True)(repo_id_validator) diff --git a/src/flamingo/integrations/huggingface/model_name_or_path.py b/src/flamingo/integrations/huggingface/model_name_or_path.py deleted file mode 100644 index f51a69dd..00000000 --- a/src/flamingo/integrations/huggingface/model_name_or_path.py +++ /dev/null @@ -1,34 +0,0 @@ -from pathlib import Path - -from pydantic.dataclasses import dataclass - -from flamingo.integrations.huggingface.utils import is_valid_huggingface_model_name - - -@dataclass -class ModelNameOrCheckpointPath: - """ - This class is explicitly used to validate if a string is - a valid HuggingFace model or can be used as a checkpoint. - - Checkpoint will be automatically assigned if it's a valid checkpoint; - it will be None if it's not valid. - """ - - # explictly needed for matching - __match_args__ = ("name", "checkpoint") - - name: str - checkpoint: str | None = None - - def __post_init__(self): - if isinstance(self.name, Path): - self.name = str(self.name) - - if Path(self.name).is_absolute(): - self.checkpoint = self.name - else: - self.checkpoint = None - - if self.checkpoint is None and not is_valid_huggingface_model_name(self.name): - raise ValueError(f"{self.name} is not a valid checkpoint path or HF model name.") diff --git a/src/flamingo/integrations/huggingface/tokenizer_config.py b/src/flamingo/integrations/huggingface/tokenizer_config.py new file mode 100644 index 00000000..3f16b342 --- /dev/null +++ b/src/flamingo/integrations/huggingface/tokenizer_config.py @@ -0,0 +1,25 @@ +from typing import Any + +from pydantic import validator + +from flamingo.integrations.huggingface.utils import repo_id_validator +from flamingo.integrations.wandb import WandbArtifactLink +from flamingo.types import BaseFlamingoConfig + + +class AutoTokenizerConfig(BaseFlamingoConfig): + """Settings passed to a HuggingFace AutoTokenizer instantiation.""" + + path: str | WandbArtifactLink + trust_remote_code: bool | None = None + use_fast: bool | None = None + + _path_validator = validator("path", allow_reuse=True, pre=True)(repo_id_validator) + + def get_tokenizer_args(self) -> dict[str, Any]: + args = dict( + trust_remote_code=self.trust_remote_code, + use_fast=self.use_fast, + ) + # Only return non-None values so we get HuggingFace defaults when not specified + return {k: v for k, v in args.items() if v is not None} diff --git a/src/flamingo/integrations/huggingface/trainer_config.py b/src/flamingo/integrations/huggingface/trainer_config.py index e78e1391..e3cf4b64 100644 --- a/src/flamingo/integrations/huggingface/trainer_config.py +++ b/src/flamingo/integrations/huggingface/trainer_config.py @@ -1,21 +1,45 @@ -from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype +from typing import Any + +from flamingo.types import BaseFlamingoConfig class TrainerConfig(BaseFlamingoConfig): - """Configuration for a HuggingFace trainer/training arguments.""" + """Configuration for a HuggingFace trainer/training arguments. + + This mainly encompasses arguments passed to the HuggingFace `TrainingArguments` class, + but also contains some additional parameters for the `Trainer` or `SFTTrainer` classes. + """ max_seq_length: int | None = None - num_train_epochs: int = 1 - batch_size: int = 16 - learning_rate: float = 1e-5 - weight_decay: float = 1e-3 - gradient_accumulation_steps: int = 1 - gradient_checkpointing: bool = False - trust_remote_code: bool = False - torch_dtype: SerializableTorchDtype = None - evaluation_strategy: str = "epoch" + num_train_epochs: int | None = None + per_device_train_batch_size: int | None = None + per_device_eval_batch_size: int | None = None + learning_rate: float | None = None + weight_decay: float | None = None + gradient_accumulation_steps: int | None = None + gradient_checkpointing: bool | None = None + evaluation_strategy: str | None = None eval_steps: float | None = None - logging_strategy: str = "steps" - logging_steps: float = 100 - save_strategy: str = "steps" - save_steps: int = 500 + logging_strategy: str | None = None + logging_steps: float | None = None + save_strategy: str | None = None + save_steps: int | None = None + + def get_training_args(self) -> dict[str, Any]: + args = dict( + num_train_epochs=self.num_train_epochs, + learning_rate=self.learning_rate, + per_device_train_batch_size=self.per_device_train_batch_size, + per_device_eval_batch_size=self.per_device_eval_batch_size, + gradient_accumulation_steps=self.gradient_accumulation_steps, + gradient_checkpointing=self.gradient_checkpointing, + weight_decay=self.weight_decay, + evaluation_strategy=self.evaluation_strategy, + eval_steps=self.eval_steps, + logging_strategy=self.logging_strategy, + logging_steps=self.logging_steps, + save_strategy=self.save_strategy, + save_steps=self.save_steps, + ) + # Only return non-None values so we use the HuggingFace defaults when not specified + return {k: v for k, v in args.items() if v is not None} diff --git a/src/flamingo/integrations/huggingface/utils.py b/src/flamingo/integrations/huggingface/utils.py index 21d4ce54..7622e3c8 100644 --- a/src/flamingo/integrations/huggingface/utils.py +++ b/src/flamingo/integrations/huggingface/utils.py @@ -1,7 +1,16 @@ +from typing import Any + +from datasets import DatasetDict, load_dataset from huggingface_hub.utils import HFValidationError, validate_repo_id -def is_valid_huggingface_model_name(s: str): +def repo_id_validator(x: Any): + if isinstance(x, str) and not is_valid_huggingface_repo_id(x): + raise ValueError(f"{x} is not a valid HuggingFace repo ID.") + return x + + +def is_valid_huggingface_repo_id(s: str): """ Simple test to check if an HF model is valid using HuggingFace's tools. Sadly, theirs throws an exception and has no return. @@ -14,3 +23,18 @@ def is_valid_huggingface_model_name(s: str): return True except HFValidationError: return False + + +def load_and_split_dataset( + path: str, + *, + split: str | None = None, + test_size: float | None, + seed: int | None = None, +) -> DatasetDict: + dataset = load_dataset(path, split=split) + if test_size is not None: + datasets = dataset.train_test_split(test_size=test_size, seed=seed) + else: + datasets = DatasetDict({"train": dataset}) + return datasets diff --git a/src/flamingo/integrations/wandb/__init__.py b/src/flamingo/integrations/wandb/__init__.py index 8e7c8a52..44b455d0 100644 --- a/src/flamingo/integrations/wandb/__init__.py +++ b/src/flamingo/integrations/wandb/__init__.py @@ -1,8 +1,10 @@ -from .wandb_environment import WandbEnvironment # noqa: I001 +from .artifact_link import WandbArtifactLink +from .run_link import WandbRunLink from .utils import get_wandb_summary, update_wandb_summary __all__ = [ - "WandbEnvironment", + "WandbArtifactLink", + "WandbRunLink", "get_wandb_summary", "update_wandb_summary", ] diff --git a/src/flamingo/integrations/wandb/artifact_link.py b/src/flamingo/integrations/wandb/artifact_link.py new file mode 100644 index 00000000..1f3c9d53 --- /dev/null +++ b/src/flamingo/integrations/wandb/artifact_link.py @@ -0,0 +1,17 @@ +from flamingo.types import BaseFlamingoConfig + + +class WandbArtifactLink(BaseFlamingoConfig): + """Data required to retrieve an artifact from W&B.""" + + name: str + version: str = "latest" + project: str | None = None + entity: str | None = None + + @property + def wandb_path(self) -> str: + """String identifier for the asset on the W&B platform.""" + path = "/".join(x for x in [self.entity, self.project, self.name] if x is not None) + path = f"{path}:{self.version}" + return path diff --git a/src/flamingo/integrations/wandb/wandb_environment.py b/src/flamingo/integrations/wandb/run_link.py similarity index 54% rename from src/flamingo/integrations/wandb/wandb_environment.py rename to src/flamingo/integrations/wandb/run_link.py index 24d3eedc..c6d168c7 100644 --- a/src/flamingo/integrations/wandb/wandb_environment.py +++ b/src/flamingo/integrations/wandb/run_link.py @@ -1,33 +1,30 @@ import os import warnings -from pydantic import Extra, root_validator +from pydantic import root_validator from wandb.apis.public import Run +from wandb.util import random_string from flamingo.types import BaseFlamingoConfig -class WandbEnvironment(BaseFlamingoConfig): +class WandbRunLink(BaseFlamingoConfig): """Settings required to log to a W&B run. - The fields on this class map to the environment variables - that are used to control the W&B logging locations. + A W&B Run is uniquely identified by the combination of `entity/project/run_id`. + The W&B platform will auto-generate values for these variables if they are not provided. - The `name` and `project` are required as they are the minimum information - required to identify a run. The `name` is the human-readable name that appears in the W&B UI. - `name` is different than the `run_id` which must be unique within a project. - Although the `name` is not mandatorily unique, it is generally best practice to use a - unique and descriptive name to later identify the run. + However, based on how these attributes are passed between jobs it is often necessary + to know the run ID before initializing a run. + For this reason, the run ID field is made non-optional and auto-generated locally + if it is not provided. """ - class Config: - extra = Extra.forbid # Error on extra kwargs - - __match_args__ = ("name", "project", "run_id", "run_group", "entity") + __match_args__ = ("run_id", "name", "project", "run_group", "entity") + run_id: str name: str | None = None project: str | None = None - run_id: str | None = None run_group: str | None = None entity: str | None = None @@ -40,22 +37,15 @@ def warn_missing_api_key(cls, values): ) return values - @property - def env_vars(self) -> dict[str, str]: - # WandB w/ HuggingFace is weird. You can specify the run name inline, - # but the rest must be injected as environment variables - env_vars = { - "WANDB_NAME": self.name, - "WANDB_PROJECT": self.project, - "WANDB_RUN_ID": self.run_id, - "WANDB_RUN_GROUP": self.run_group, - "WANDB_ENTITY": self.entity, - "WANDB_API_KEY": os.environ.get("WANDB_API_KEY", None), - } - return {k: v for k, v in env_vars.items() if v is not None} + @root_validator(pre=True) + def ensure_run_id(cls, values): + if values.get("run_id", None) is None: + # Generate an random 8-digit alphanumeric string, analogous to W&B platform + values["run_id"] = random_string(length=8) + return values @classmethod - def from_run(cls, run: Run) -> "WandbEnvironment": + def from_run(cls, run: Run) -> "WandbRunLink": """Extract environment settings from a W&B Run object. Useful when listing runs from the W&B API and extracting their settings for a job. @@ -67,3 +57,20 @@ def from_run(cls, run: Run) -> "WandbEnvironment": entity=run.entity, run_id=run.id, ) + + @property + def wandb_path(self) -> str: + """String identifier for the asset on the W&B platform.""" + path = "/".join(x for x in [self.entity, self.project, self.run_id] if x is not None) + return path + + def get_env_vars(self) -> dict[str, str]: + env_vars = { + "WANDB_RUN_ID": self.run_id, + "WANDB_NAME": self.name, + "WANDB_PROJECT": self.project, + "WANDB_RUN_GROUP": self.run_group, + "WANDB_ENTITY": self.entity, + "WANDB_API_KEY": os.environ.get("WANDB_API_KEY", None), + } + return {k: v for k, v in env_vars.items() if v is not None} diff --git a/src/flamingo/integrations/wandb/utils.py b/src/flamingo/integrations/wandb/utils.py index 1b00d9ea..6ea2afc4 100644 --- a/src/flamingo/integrations/wandb/utils.py +++ b/src/flamingo/integrations/wandb/utils.py @@ -3,39 +3,23 @@ import wandb from wandb.apis.public import Run -from flamingo.integrations.wandb import WandbEnvironment +from flamingo.integrations.wandb import WandbRunLink -def get_wandb_summary(env: WandbEnvironment) -> dict[str, Any]: +def get_wandb_run(env: WandbRunLink) -> Run: + """Retrieve a run from the W&B API.""" + api = wandb.Api() + return api.run(env.wandb_path) + + +def get_wandb_summary(env: WandbRunLink) -> dict[str, Any]: """Get the summary dictionary attached to a W&B run.""" - run = _resolve_wandb_run(env) + run = get_wandb_run(env) return dict(run.summary) -def update_wandb_summary(env: WandbEnvironment, metrics: dict[str, Any]) -> None: +def update_wandb_summary(env: WandbRunLink, metrics: dict[str, Any]) -> None: """Update a run's summary with the provided metrics.""" - run = _resolve_wandb_run(env) + run = get_wandb_run(env) run.summary.update(metrics) run.update() - - -def _resolve_wandb_run(env: WandbEnvironment) -> Run: - """Resolve a WandB run object from the provided environment settings. - - An exception is raised if a Run cannot be found, - or if multiple runs exist in scope with the same name. - """ - api = wandb.Api() - base_path = "/".join(x for x in (env.entity, env.project) if x) - if env.run_id is not None: - full_path = f"{base_path}/{env.run_id}" - return api.run(full_path) - else: - match [run for run in api.runs(base_path) if run.name == env.name]: - case []: - raise RuntimeError(f"No WandB runs found at {base_path}/{env.name}") - case [Run(), _]: - raise RuntimeError(f"Multiple WandB runs found at {base_path}/{env.name}") - case [Run()] as mr: - # we have a single one, hurray - return mr[0] diff --git a/src/flamingo/jobs/base_config.py b/src/flamingo/jobs/base_config.py deleted file mode 100644 index 3fec2edc..00000000 --- a/src/flamingo/jobs/base_config.py +++ /dev/null @@ -1,7 +0,0 @@ -from flamingo.types import BaseFlamingoConfig - - -class BaseJobConfig(BaseFlamingoConfig): - """Configuration defining a job to submit to the Ray cluster.""" - - pass diff --git a/src/flamingo/jobs/finetuning_config.py b/src/flamingo/jobs/finetuning_config.py index 1356d95a..40906020 100644 --- a/src/flamingo/jobs/finetuning_config.py +++ b/src/flamingo/jobs/finetuning_config.py @@ -1,28 +1,63 @@ +from typing import Any + from peft import LoraConfig -from pydantic import validator -from ray.train import ScalingConfig +from pydantic import Field, validator + +from flamingo.integrations.huggingface import ( + AutoModelConfig, + AutoTokenizerConfig, + DatasetConfig, + QuantizationConfig, + TrainerConfig, +) +from flamingo.integrations.wandb import WandbRunLink +from flamingo.types import BaseFlamingoConfig + + +class RayTrainConfig(BaseFlamingoConfig): + """Misc settings passed to Ray train. + + Includes information for scaling, checkpointing, and runtime storage. + """ -from flamingo.integrations.huggingface import QuantizationConfig -from flamingo.integrations.huggingface.trainer_config import TrainerConfig -from flamingo.integrations.huggingface.utils import is_valid_huggingface_model_name -from flamingo.jobs import BaseJobConfig + use_gpu: bool = True + num_workers: int | None = None + storage_path: str | None = None + + def get_scaling_args(self) -> dict[str, Any]: + args = dict(use_gpu=self.use_gpu, num_workers=self.num_workers) + return {k: v for k, v in args.items() if v is not None} -class FinetuningJobConfig(BaseJobConfig): +class FinetuningJobConfig(BaseFlamingoConfig): """Configuration to submit an LLM finetuning job.""" - model: str - dataset: str - tokenizer: str | None = None - trainer: TrainerConfig | None = None - lora: LoraConfig | None = None # TODO: Create our own config type + model: AutoModelConfig + dataset: DatasetConfig + tokenizer: AutoTokenizerConfig | None = None quantization: QuantizationConfig | None = None - scaling: ScalingConfig | None = None # TODO: Create our own config type - storage_path: str | None = None + adapter: LoraConfig | None = None # TODO: Create own dataclass here + tracking: WandbRunLink | None = None + trainer: TrainerConfig = Field(default_factory=TrainerConfig) + ray: RayTrainConfig = Field(default_factory=RayTrainConfig) + + @validator("model", pre=True, always=True) + def validate_model_arg(cls, x): + """Allow for passing just a path string as the model argument.""" + if isinstance(x, str): + return AutoModelConfig(path=x) + return x + + @validator("dataset", pre=True, always=True) + def validate_dataset_arg(cls, x): + """Allow for passing just a path string as the dataset argument.""" + if isinstance(x, str): + return DatasetConfig(path=x) + return x - @validator("model") - def _validate_model_name(cls, v): - if is_valid_huggingface_model_name(v): - return v - else: - raise ValueError(f"`{v}` is not a valid HuggingFace model name.") + @validator("tokenizer", pre=True, always=True) + def validate_tokenizer_arg(cls, x): + """Allow for passing just a path string as the tokenizer argument.""" + if isinstance(x, str): + return AutoTokenizerConfig(name_or_artifact=x) + return x diff --git a/src/flamingo/jobs/lm_harness_config.py b/src/flamingo/jobs/lm_harness_config.py index 905c1078..0f5332e5 100644 --- a/src/flamingo/jobs/lm_harness_config.py +++ b/src/flamingo/jobs/lm_harness_config.py @@ -1,43 +1,34 @@ import datetime -from pathlib import Path -from pydantic import validator +from pydantic import Field -from flamingo.integrations.huggingface import ModelNameOrCheckpointPath, QuantizationConfig -from flamingo.jobs import BaseJobConfig -from flamingo.types import SerializableTorchDtype +from flamingo.integrations.huggingface import AutoModelConfig, QuantizationConfig +from flamingo.integrations.wandb import WandbRunLink +from flamingo.types import BaseFlamingoConfig -class LMHarnessJobConfig(BaseJobConfig): - """Configuration to run an lm-evaluation-harness evaluation job. +class RayComputeSettings(BaseFlamingoConfig): + """Misc settings for Ray compute in the LM harness job.""" - This job loads an existing checkpoint path from Ray storage to run evaluation against, - OR a huggingface Model and logs the evaluation results to W&B. + use_gpu: bool = True + num_workers: int = 1 + timeout: datetime.timedelta | None = None - This can be manually overwritten by specifying the `model_name_or_path` variable - which will take prescedence over the W&B checkpoint path. - """ - class Config: - validate_assignment = True +class LMHarnessEvaluatorSettings(BaseFlamingoConfig): + """Misc settings provided to an lm-harness evaluation job.""" tasks: list[str] batch_size: int | None = None num_fewshot: int | None = None limit: int | float | None = None - trust_remote_code: bool = False - torch_dtype: SerializableTorchDtype = None - model_name_or_path: str | Path | ModelNameOrCheckpointPath | None = None - quantization: QuantizationConfig | None = None - num_cpus: int = 1 - num_gpus: int = 1 - timeout: datetime.timedelta | None = None - @validator("model_name_or_path", pre=True, always=True) - def _validate_model_name_or_path(cls, v): - if isinstance(v, dict): - return ModelNameOrCheckpointPath(**v) - elif v is None: - return None - else: - return ModelNameOrCheckpointPath(name=v) + +class LMHarnessJobConfig(BaseFlamingoConfig): + """Configuration to run an lm-evaluation-harness evaluation job.""" + + model: AutoModelConfig + evaluator: LMHarnessEvaluatorSettings + quantization: QuantizationConfig | None = None + tracking: WandbRunLink | None = None + ray: RayComputeSettings = Field(default_factory=RayComputeSettings) diff --git a/src/flamingo/jobs/simple_config.py b/src/flamingo/jobs/simple_config.py index bc18fa96..24fa60a8 100644 --- a/src/flamingo/jobs/simple_config.py +++ b/src/flamingo/jobs/simple_config.py @@ -1,5 +1,5 @@ -from flamingo.jobs import BaseJobConfig +from flamingo.types import BaseFlamingoConfig -class SimpleJobConfig(BaseJobConfig): +class SimpleJobConfig(BaseFlamingoConfig): magic_number: int diff --git a/tests/conftest.py b/tests/conftest.py index ccd5abfd..c6630295 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,7 @@ import pytest -from flamingo.integrations.wandb import WandbEnvironment +from flamingo.integrations.wandb import WandbRunLink from flamingo.jobs import LMHarnessJobConfig @@ -28,14 +28,14 @@ def mock_environment_without_keys(): @pytest.fixture(scope="function") def default_wandb_env(): - def generator(**kwargs) -> WandbEnvironment: + def generator(**kwargs) -> WandbRunLink: mine = { "name": "my-run", "project": "my-project", "entity": "mozilla-ai", "run_id": "gabbagool-123", } - return WandbEnvironment(**{**mine, **kwargs}) + return WandbRunLink(**{**mine, **kwargs}) yield generator diff --git a/tests/integrations/wandb/test_wandb_environment.py b/tests/integrations/wandb/test_wandb_environment.py index 69e815e8..55cfd397 100644 --- a/tests/integrations/wandb/test_wandb_environment.py +++ b/tests/integrations/wandb/test_wandb_environment.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from flamingo.integrations.wandb import WandbEnvironment +from flamingo.integrations.wandb import WandbRunLink def test_env_vars(default_wandb_env): @@ -13,15 +13,15 @@ def test_env_vars(default_wandb_env): def test_serde_round_trip(default_wandb_env): - assert WandbEnvironment.parse_raw(default_wandb_env().json()) == default_wandb_env() + assert WandbRunLink.parse_raw(default_wandb_env().json()) == default_wandb_env() def test_disallowed_kwargs(): with pytest.raises(ValidationError): - WandbEnvironment(name="name", project="project", old_name="I will throw") + WandbRunLink(name="name", project="project", old_name="I will throw") def test_missing_key_warning(mock_environment_without_keys): with pytest.warns(UserWarning): - env = WandbEnvironment(name="I am missing an API key", project="I should warn the user") + env = WandbRunLink(name="I am missing an API key", project="I should warn the user") assert "WANDB_API_KEY" not in env.env_vars diff --git a/tests/configs/__init__.py b/tests/jobs/__init__.py similarity index 100% rename from tests/configs/__init__.py rename to tests/jobs/__init__.py diff --git a/tests/configs/test_finetuning_config.py b/tests/jobs/test_finetuning_config.py similarity index 100% rename from tests/configs/test_finetuning_config.py rename to tests/jobs/test_finetuning_config.py diff --git a/tests/configs/test_lm_harness_config.py b/tests/jobs/test_lm_harness_config.py similarity index 100% rename from tests/configs/test_lm_harness_config.py rename to tests/jobs/test_lm_harness_config.py From 4941fc9c25a516b9d06f4ced44b8c23e9eb0d49c Mon Sep 17 00:00:00 2001 From: Sean Friedowitz Date: Tue, 16 Jan 2024 18:56:38 -0800 Subject: [PATCH 02/10] update gitignore --- .gitignore | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.gitignore b/.gitignore index f845eb4a..de7d4a82 100644 --- a/.gitignore +++ b/.gitignore @@ -161,8 +161,3 @@ cython_debug/ # Ruff .ruff_cache - - -# ignore local wandb cache files. Not perfect -**/wandb/*.log -**/wandb/*run* From 7c8ff23c64422b32c8df5878211e4aeaa10f3cf5 Mon Sep 17 00:00:00 2001 From: Sean Friedowitz Date: Tue, 16 Jan 2024 19:30:07 -0800 Subject: [PATCH 03/10] starting to add new config tests --- .../integrations/huggingface/__init__.py | 8 +++- .../huggingface/dataset_config.py | 4 +- .../integrations/huggingface/model_config.py | 4 +- .../huggingface/tokenizer_config.py | 4 +- src/flamingo/integrations/wandb/__init__.py | 8 ++-- .../{artifact_link.py => artifact_config.py} | 4 +- .../wandb/{run_link.py => run_config.py} | 11 ++--- src/flamingo/integrations/wandb/utils.py | 8 ++-- src/flamingo/jobs/finetuning_config.py | 6 +-- src/flamingo/jobs/lm_harness_config.py | 4 +- src/flamingo/types.py | 2 +- tests/conftest.py | 22 +++++++-- tests/integrations/wandb/test_run_config.py | 27 +++++++++++ .../wandb/test_wandb_environment.py | 27 ----------- tests/resources/finetuning_config.yaml | 45 +++++++++++++++++++ tests/resources/lm_harness_config.yaml | 31 +++++++++++++ 16 files changed, 155 insertions(+), 60 deletions(-) rename src/flamingo/integrations/wandb/{artifact_link.py => artifact_config.py} (78%) rename src/flamingo/integrations/wandb/{run_link.py => run_config.py} (92%) create mode 100644 tests/integrations/wandb/test_run_config.py delete mode 100644 tests/integrations/wandb/test_wandb_environment.py create mode 100644 tests/resources/finetuning_config.yaml create mode 100644 tests/resources/lm_harness_config.yaml diff --git a/src/flamingo/integrations/huggingface/__init__.py b/src/flamingo/integrations/huggingface/__init__.py index 7380bd37..2c97958d 100644 --- a/src/flamingo/integrations/huggingface/__init__.py +++ b/src/flamingo/integrations/huggingface/__init__.py @@ -1,9 +1,13 @@ -from .model_name_or_path import ModelNameOrCheckpointPath +from .dataset_config import DatasetConfig +from .model_config import AutoModelConfig from .quantization_config import QuantizationConfig +from .tokenizer_config import AutoTokenizerConfig from .trainer_config import TrainerConfig __all__ = [ - "ModelNameOrCheckpointPath", + "AutoModelConfig", + "AutoTokenizerConfig", + "DatasetConfig", "QuantizationConfig", "TrainerConfig", ] diff --git a/src/flamingo/integrations/huggingface/dataset_config.py b/src/flamingo/integrations/huggingface/dataset_config.py index e31014eb..352a30b6 100644 --- a/src/flamingo/integrations/huggingface/dataset_config.py +++ b/src/flamingo/integrations/huggingface/dataset_config.py @@ -1,14 +1,14 @@ from pydantic import validator from flamingo.integrations.huggingface.utils import repo_id_validator -from flamingo.integrations.wandb import WandbArtifactLink +from flamingo.integrations.wandb import WandbArtifactConfig from flamingo.types import BaseFlamingoConfig class DatasetConfig(BaseFlamingoConfig): """Settings passed to load a HuggingFace dataset.""" - path: str | WandbArtifactLink + path: str | WandbArtifactConfig split: str | None = None test_size: float | None = None seed: int | None = None diff --git a/src/flamingo/integrations/huggingface/model_config.py b/src/flamingo/integrations/huggingface/model_config.py index 60f54077..6fb2f683 100644 --- a/src/flamingo/integrations/huggingface/model_config.py +++ b/src/flamingo/integrations/huggingface/model_config.py @@ -1,7 +1,7 @@ from pydantic import validator from flamingo.integrations.huggingface.utils import repo_id_validator -from flamingo.integrations.wandb import WandbArtifactLink +from flamingo.integrations.wandb import WandbArtifactConfig from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype @@ -12,7 +12,7 @@ class AutoModelConfig(BaseFlamingoConfig): or an artifact link to a reference artifact on W&B. """ - path: str | WandbArtifactLink + path: str | WandbArtifactConfig trust_remote_code: bool = False torch_dtype: SerializableTorchDtype = None diff --git a/src/flamingo/integrations/huggingface/tokenizer_config.py b/src/flamingo/integrations/huggingface/tokenizer_config.py index 3f16b342..750a4d94 100644 --- a/src/flamingo/integrations/huggingface/tokenizer_config.py +++ b/src/flamingo/integrations/huggingface/tokenizer_config.py @@ -3,14 +3,14 @@ from pydantic import validator from flamingo.integrations.huggingface.utils import repo_id_validator -from flamingo.integrations.wandb import WandbArtifactLink +from flamingo.integrations.wandb import WandbArtifactConfig from flamingo.types import BaseFlamingoConfig class AutoTokenizerConfig(BaseFlamingoConfig): """Settings passed to a HuggingFace AutoTokenizer instantiation.""" - path: str | WandbArtifactLink + path: str | WandbArtifactConfig trust_remote_code: bool | None = None use_fast: bool | None = None diff --git a/src/flamingo/integrations/wandb/__init__.py b/src/flamingo/integrations/wandb/__init__.py index 44b455d0..e41077df 100644 --- a/src/flamingo/integrations/wandb/__init__.py +++ b/src/flamingo/integrations/wandb/__init__.py @@ -1,10 +1,10 @@ -from .artifact_link import WandbArtifactLink -from .run_link import WandbRunLink +from .artifact_config import WandbArtifactConfig +from .run_config import WandbRunConfig from .utils import get_wandb_summary, update_wandb_summary __all__ = [ - "WandbArtifactLink", - "WandbRunLink", + "WandbArtifactConfig", + "WandbRunConfig", "get_wandb_summary", "update_wandb_summary", ] diff --git a/src/flamingo/integrations/wandb/artifact_link.py b/src/flamingo/integrations/wandb/artifact_config.py similarity index 78% rename from src/flamingo/integrations/wandb/artifact_link.py rename to src/flamingo/integrations/wandb/artifact_config.py index 1f3c9d53..4f445330 100644 --- a/src/flamingo/integrations/wandb/artifact_link.py +++ b/src/flamingo/integrations/wandb/artifact_config.py @@ -1,8 +1,8 @@ from flamingo.types import BaseFlamingoConfig -class WandbArtifactLink(BaseFlamingoConfig): - """Data required to retrieve an artifact from W&B.""" +class WandbArtifactConfig(BaseFlamingoConfig): + """Configuration required to retrieve an artifact from W&B.""" name: str version: str = "latest" diff --git a/src/flamingo/integrations/wandb/run_link.py b/src/flamingo/integrations/wandb/run_config.py similarity index 92% rename from src/flamingo/integrations/wandb/run_link.py rename to src/flamingo/integrations/wandb/run_config.py index c6d168c7..bf62cf4d 100644 --- a/src/flamingo/integrations/wandb/run_link.py +++ b/src/flamingo/integrations/wandb/run_config.py @@ -8,11 +8,12 @@ from flamingo.types import BaseFlamingoConfig -class WandbRunLink(BaseFlamingoConfig): - """Settings required to log to a W&B run. +class WandbRunConfig(BaseFlamingoConfig): + """Configuration required to log to a W&B run. A W&B Run is uniquely identified by the combination of `entity/project/run_id`. - The W&B platform will auto-generate values for these variables if they are not provided. + The W&B platform will auto-generate values for these variables if they are not provided + when you initialize a run. However, based on how these attributes are passed between jobs it is often necessary to know the run ID before initializing a run. @@ -45,7 +46,7 @@ def ensure_run_id(cls, values): return values @classmethod - def from_run(cls, run: Run) -> "WandbRunLink": + def from_run(cls, run: Run) -> "WandbRunConfig": """Extract environment settings from a W&B Run object. Useful when listing runs from the W&B API and extracting their settings for a job. @@ -63,7 +64,7 @@ def wandb_path(self) -> str: """String identifier for the asset on the W&B platform.""" path = "/".join(x for x in [self.entity, self.project, self.run_id] if x is not None) return path - + def get_env_vars(self) -> dict[str, str]: env_vars = { "WANDB_RUN_ID": self.run_id, diff --git a/src/flamingo/integrations/wandb/utils.py b/src/flamingo/integrations/wandb/utils.py index 6ea2afc4..72a3b40c 100644 --- a/src/flamingo/integrations/wandb/utils.py +++ b/src/flamingo/integrations/wandb/utils.py @@ -3,22 +3,22 @@ import wandb from wandb.apis.public import Run -from flamingo.integrations.wandb import WandbRunLink +from flamingo.integrations.wandb import WandbRunConfig -def get_wandb_run(env: WandbRunLink) -> Run: +def get_wandb_run(env: WandbRunConfig) -> Run: """Retrieve a run from the W&B API.""" api = wandb.Api() return api.run(env.wandb_path) -def get_wandb_summary(env: WandbRunLink) -> dict[str, Any]: +def get_wandb_summary(env: WandbRunConfig) -> dict[str, Any]: """Get the summary dictionary attached to a W&B run.""" run = get_wandb_run(env) return dict(run.summary) -def update_wandb_summary(env: WandbRunLink, metrics: dict[str, Any]) -> None: +def update_wandb_summary(env: WandbRunConfig, metrics: dict[str, Any]) -> None: """Update a run's summary with the provided metrics.""" run = get_wandb_run(env) run.summary.update(metrics) diff --git a/src/flamingo/jobs/finetuning_config.py b/src/flamingo/jobs/finetuning_config.py index 40906020..63ae62b1 100644 --- a/src/flamingo/jobs/finetuning_config.py +++ b/src/flamingo/jobs/finetuning_config.py @@ -10,7 +10,7 @@ QuantizationConfig, TrainerConfig, ) -from flamingo.integrations.wandb import WandbRunLink +from flamingo.integrations.wandb import WandbRunConfig from flamingo.types import BaseFlamingoConfig @@ -37,7 +37,7 @@ class FinetuningJobConfig(BaseFlamingoConfig): tokenizer: AutoTokenizerConfig | None = None quantization: QuantizationConfig | None = None adapter: LoraConfig | None = None # TODO: Create own dataclass here - tracking: WandbRunLink | None = None + tracking: WandbRunConfig | None = None trainer: TrainerConfig = Field(default_factory=TrainerConfig) ray: RayTrainConfig = Field(default_factory=RayTrainConfig) @@ -59,5 +59,5 @@ def validate_dataset_arg(cls, x): def validate_tokenizer_arg(cls, x): """Allow for passing just a path string as the tokenizer argument.""" if isinstance(x, str): - return AutoTokenizerConfig(name_or_artifact=x) + return AutoTokenizerConfig(path=x) return x diff --git a/src/flamingo/jobs/lm_harness_config.py b/src/flamingo/jobs/lm_harness_config.py index 0f5332e5..d638fd1e 100644 --- a/src/flamingo/jobs/lm_harness_config.py +++ b/src/flamingo/jobs/lm_harness_config.py @@ -3,7 +3,7 @@ from pydantic import Field from flamingo.integrations.huggingface import AutoModelConfig, QuantizationConfig -from flamingo.integrations.wandb import WandbRunLink +from flamingo.integrations.wandb import WandbRunConfig from flamingo.types import BaseFlamingoConfig @@ -30,5 +30,5 @@ class LMHarnessJobConfig(BaseFlamingoConfig): model: AutoModelConfig evaluator: LMHarnessEvaluatorSettings quantization: QuantizationConfig | None = None - tracking: WandbRunLink | None = None + tracking: WandbRunConfig | None = None ray: RayComputeSettings = Field(default_factory=RayComputeSettings) diff --git a/src/flamingo/types.py b/src/flamingo/types.py index ca877751..055fb983 100644 --- a/src/flamingo/types.py +++ b/src/flamingo/types.py @@ -26,7 +26,7 @@ class Config: } @validator("*", pre=True) - def validate_serializable_dtype(cls, x: Any, field: ModelField) -> Any: # noqa: N805 + def validate_serializable_dtype(cls, x: Any, field: ModelField) -> Any: """Extract the torch.dtype corresponding to a string value, else return the value. Inspired by the HuggingFace `BitsAndBytesConfig` logic. diff --git a/tests/conftest.py b/tests/conftest.py index c6630295..0f98deae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,7 @@ import pytest -from flamingo.integrations.wandb import WandbRunLink +from flamingo.integrations.wandb import WandbArtifactConfig, WandbRunConfig from flamingo.jobs import LMHarnessJobConfig @@ -27,15 +27,29 @@ def mock_environment_without_keys(): @pytest.fixture(scope="function") -def default_wandb_env(): - def generator(**kwargs) -> WandbRunLink: +def default_wandb_run_config(): + def generator(**kwargs) -> WandbRunConfig: mine = { "name": "my-run", "project": "my-project", "entity": "mozilla-ai", "run_id": "gabbagool-123", } - return WandbRunLink(**{**mine, **kwargs}) + return WandbRunConfig(**{**mine, **kwargs}) + + yield generator + + +@pytest.fixture(scope="function") +def default_wandb_artifact_config(): + def generator(**kwargs) -> WandbArtifactConfig: + mine = { + "name": "my-run", + "version": "latest", + "project": "research-project", + "entity": "mozilla-corporation", + } + return WandbArtifactConfig(**{**mine, **kwargs}) yield generator diff --git a/tests/integrations/wandb/test_run_config.py b/tests/integrations/wandb/test_run_config.py new file mode 100644 index 00000000..385e40a3 --- /dev/null +++ b/tests/integrations/wandb/test_run_config.py @@ -0,0 +1,27 @@ +import pytest +from pydantic import ValidationError + +from flamingo.integrations.wandb import WandbRunConfig + + +def test_env_vars(default_wandb_run_config): + env_vars = default_wandb_run_config().get_env_vars() + expected = ["WANDB_NAME", "WANDB_PROJECT", "WANDB_ENTITY", "WANDB_RUN_ID"] + for key in expected: + assert key in env_vars + assert "WANDB_RUN_GROUP" not in env_vars + + +def test_serde_round_trip(default_wandb_run_config): + assert WandbRunConfig.parse_raw(default_wandb_run_config().json()) == default_wandb_run_config() + + +def test_disallowed_kwargs(): + with pytest.raises(ValidationError): + WandbRunConfig(name="name", project="project", old_name="I will throw") + + +def test_missing_key_warning(mock_environment_without_keys): + with pytest.warns(UserWarning): + env = WandbRunConfig(name="I am missing an API key", project="I should warn the user") + assert "WANDB_API_KEY" not in env.env_vars diff --git a/tests/integrations/wandb/test_wandb_environment.py b/tests/integrations/wandb/test_wandb_environment.py deleted file mode 100644 index 55cfd397..00000000 --- a/tests/integrations/wandb/test_wandb_environment.py +++ /dev/null @@ -1,27 +0,0 @@ -import pytest -from pydantic import ValidationError - -from flamingo.integrations.wandb import WandbRunLink - - -def test_env_vars(default_wandb_env): - env_vars = default_wandb_env().env_vars - expected = ["WANDB_NAME", "WANDB_PROJECT", "WANDB_ENTITY", "WANDB_RUN_ID"] - for key in expected: - assert key in env_vars - assert "WANDB_RUN_GROUP" not in env_vars - - -def test_serde_round_trip(default_wandb_env): - assert WandbRunLink.parse_raw(default_wandb_env().json()) == default_wandb_env() - - -def test_disallowed_kwargs(): - with pytest.raises(ValidationError): - WandbRunLink(name="name", project="project", old_name="I will throw") - - -def test_missing_key_warning(mock_environment_without_keys): - with pytest.warns(UserWarning): - env = WandbRunLink(name="I am missing an API key", project="I should warn the user") - assert "WANDB_API_KEY" not in env.env_vars diff --git a/tests/resources/finetuning_config.yaml b/tests/resources/finetuning_config.yaml new file mode 100644 index 00000000..8ce53870 --- /dev/null +++ b/tests/resources/finetuning_config.yaml @@ -0,0 +1,45 @@ +# Tokenizer defined by only a string repo ID +tokenizer: "mistral-ai/other-repo-with-special-tokenizer" + +# Model defined as an object with additional settings beyond the repo ID +model: + path: "mistral-ai/mistral-7b" + trust_remote_code: True + torch_dtype: "bfloat16" + +# Dataset defined as an object with a path linking to a W&B artifact +dataset: + path: + name: "dataset-artifact" + version: "latest" + project: "research-project" + split: "train" + test_size: 0.2 + +# HuggingFace Trainer/TrainingArguments +trainer: + max_seq_length: 512 + learning_rate: 0.1 + num_train_epochs: 2 + +# HuggingFace quantization settings +quantization: + load_in_4bit: True + bnb_4bit_quant_type: "fp4" + +# LORA adapter settings +adapter: + r: 16 + lora_alpha: 32 + lora_dropout: 0.2 + +# W&B run for logging results +tracking: + name: "location-to-log-results" + project: "another-project" + entity: "another-entity" + +# Ray compute settings +ray: + use_gpu: True + num_workers: 4 diff --git a/tests/resources/lm_harness_config.yaml b/tests/resources/lm_harness_config.yaml new file mode 100644 index 00000000..ac9bad14 --- /dev/null +++ b/tests/resources/lm_harness_config.yaml @@ -0,0 +1,31 @@ +# Model to evaluate, specified as a W&B artifact +model: + path: + name: "training-run-model-artifact" + version: "v4" + project: "research-project" + entity: "twitter.com" + trust_remote_code: True + torch_dtype: "float16" + +# Settings specific to lm_harness.evaluate +evaluator: + tasks: ["task1", "task2", "...", "taskN"] + num_fewshot: 5 + +# HuggingFace quantization settings +quantization: + load_in_4bit: True + bnb_4bit_quant_type: "fp4" + +# W&B run for logging results +tracking: + name: "location-to-log-results" + project: "another-project" + entity: "another-entity" + +# Ray compute settings +ray: + use_gpu: True + num_workers: 4 + timeout: 3600 From 5e54d107b496757f369a957461e9896561074eed Mon Sep 17 00:00:00 2001 From: Sean Friedowitz Date: Tue, 16 Jan 2024 19:41:57 -0800 Subject: [PATCH 04/10] backup --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 88198fc9..5647b16e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,9 +33,9 @@ ludwig = ["ludwig==0.9.1"] evaluate = ["lm-eval==0.4.0", "einops"] -test = ["ruff==0.1.4", "pytest==7.4.3", "pytest-cov==4.1.0"] +test = ["ruff==0.1.7", "pytest==7.4.3", "pytest-cov==4.1.0"] -all = ["flamingo[finetune,ludwig,evaluate,test]"] +all = ["flamingo[finetune,evaluate,test]"] [project.scripts] flamingo = "flamingo.cli:main" From 5ae6bacdd7bdca314d7eebd33f283e77b62ea014 Mon Sep 17 00:00:00 2001 From: Sean Friedowitz Date: Wed, 17 Jan 2024 11:37:41 -0800 Subject: [PATCH 05/10] working on tests --- src/flamingo/jobs/__init__.py | 5 +- src/flamingo/jobs/finetuning_config.py | 26 +++++++-- src/flamingo/jobs/lm_harness_config.py | 17 ++++-- tests/__init__.py | 1 + tests/conftest.py | 49 +---------------- tests/integrations/wandb/test_run_config.py | 11 ++-- tests/jobs/conftest.py | 58 ++++++++++++++++++++ tests/jobs/test_lm_harness_config.py | 61 ++++++++++++++++----- tests/resources/finetuning_config.yaml | 5 -- tests/resources/lm_harness_config.yaml | 3 - 10 files changed, 147 insertions(+), 89 deletions(-) create mode 100644 tests/jobs/conftest.py diff --git a/src/flamingo/jobs/__init__.py b/src/flamingo/jobs/__init__.py index 4a616d75..5c003cd0 100644 --- a/src/flamingo/jobs/__init__.py +++ b/src/flamingo/jobs/__init__.py @@ -1,12 +1,9 @@ -from .base_config import BaseJobConfig from .finetuning_config import FinetuningJobConfig -from .lm_harness_config import LMHarnessJobConfig, ModelNameOrCheckpointPath +from .lm_harness_config import LMHarnessJobConfig from .simple_config import SimpleJobConfig __all__ = [ - "BaseJobConfig", "SimpleJobConfig", "FinetuningJobConfig", "LMHarnessJobConfig", - "ModelNameOrCheckpointPath", ] diff --git a/src/flamingo/jobs/finetuning_config.py b/src/flamingo/jobs/finetuning_config.py index 63ae62b1..902c82db 100644 --- a/src/flamingo/jobs/finetuning_config.py +++ b/src/flamingo/jobs/finetuning_config.py @@ -1,7 +1,7 @@ from typing import Any from peft import LoraConfig -from pydantic import Field, validator +from pydantic import Field, root_validator, validator from flamingo.integrations.huggingface import ( AutoModelConfig, @@ -14,15 +14,15 @@ from flamingo.types import BaseFlamingoConfig -class RayTrainConfig(BaseFlamingoConfig): - """Misc settings passed to Ray train. +class FinetuningRayConfig(BaseFlamingoConfig): + """Misc settings passed to Ray train for finetuning. Includes information for scaling, checkpointing, and runtime storage. """ use_gpu: bool = True num_workers: int | None = None - storage_path: str | None = None + storage_path: str | None = None # TODO: This should be set globally somehow def get_scaling_args(self) -> dict[str, Any]: args = dict(use_gpu=self.use_gpu, num_workers=self.num_workers) @@ -34,12 +34,26 @@ class FinetuningJobConfig(BaseFlamingoConfig): model: AutoModelConfig dataset: DatasetConfig - tokenizer: AutoTokenizerConfig | None = None + tokenizer: AutoTokenizerConfig quantization: QuantizationConfig | None = None adapter: LoraConfig | None = None # TODO: Create own dataclass here tracking: WandbRunConfig | None = None trainer: TrainerConfig = Field(default_factory=TrainerConfig) - ray: RayTrainConfig = Field(default_factory=RayTrainConfig) + ray: FinetuningRayConfig = Field(default_factory=FinetuningRayConfig) + + @root_validator(pre=True) + def ensure_tokenizer_config(cls, values): + """Set the tokenizer to the model path when not explicitly provided.""" + if values.get("tokenizer", None) is None: + match values["model"]: + case str() as model_path: + values["tokenizer"] = model_path + case dict() as model_data: + values["tokenizer"] = model_data["path"] + case AutoModelConfig() as model_config: + values["tokenizer"] = model_config.path + # No fallback necessary, downstream validation will flag invalid model types + return values @validator("model", pre=True, always=True) def validate_model_arg(cls, x): diff --git a/src/flamingo/jobs/lm_harness_config.py b/src/flamingo/jobs/lm_harness_config.py index d638fd1e..cce3459b 100644 --- a/src/flamingo/jobs/lm_harness_config.py +++ b/src/flamingo/jobs/lm_harness_config.py @@ -1,13 +1,13 @@ import datetime -from pydantic import Field +from pydantic import Field, validator from flamingo.integrations.huggingface import AutoModelConfig, QuantizationConfig from flamingo.integrations.wandb import WandbRunConfig from flamingo.types import BaseFlamingoConfig -class RayComputeSettings(BaseFlamingoConfig): +class LMHarnessRayConfig(BaseFlamingoConfig): """Misc settings for Ray compute in the LM harness job.""" use_gpu: bool = True @@ -15,7 +15,7 @@ class RayComputeSettings(BaseFlamingoConfig): timeout: datetime.timedelta | None = None -class LMHarnessEvaluatorSettings(BaseFlamingoConfig): +class LMHarnessEvaluatorConfig(BaseFlamingoConfig): """Misc settings provided to an lm-harness evaluation job.""" tasks: list[str] @@ -28,7 +28,14 @@ class LMHarnessJobConfig(BaseFlamingoConfig): """Configuration to run an lm-evaluation-harness evaluation job.""" model: AutoModelConfig - evaluator: LMHarnessEvaluatorSettings + evaluator: LMHarnessEvaluatorConfig quantization: QuantizationConfig | None = None tracking: WandbRunConfig | None = None - ray: RayComputeSettings = Field(default_factory=RayComputeSettings) + ray: LMHarnessRayConfig = Field(default_factory=LMHarnessRayConfig) + + @validator("model", pre=True, always=True) + def validate_model_arg(cls, x): + """Allow for passing just a path string as the model argument.""" + if isinstance(x, str): + return AutoModelConfig(path=x) + return x diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..2bb88df4 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1 @@ +from pathlib import Path diff --git a/tests/conftest.py b/tests/conftest.py index 0f98deae..3ecdaf8d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,12 +4,12 @@ This file is used to provide fixtures for the test session that are accessible to all submodules. """ import os +from pathlib import Path from unittest import mock import pytest -from flamingo.integrations.wandb import WandbArtifactConfig, WandbRunConfig -from flamingo.jobs import LMHarnessJobConfig +TEST_RESOURCES = Path(__file__) / "resources" @pytest.fixture(autouse=True, scope="function") @@ -24,48 +24,3 @@ def mock_environment_without_keys(): """Mocks an environment missing common API keys.""" with mock.patch.dict(os.environ, clear=True): yield - - -@pytest.fixture(scope="function") -def default_wandb_run_config(): - def generator(**kwargs) -> WandbRunConfig: - mine = { - "name": "my-run", - "project": "my-project", - "entity": "mozilla-ai", - "run_id": "gabbagool-123", - } - return WandbRunConfig(**{**mine, **kwargs}) - - yield generator - - -@pytest.fixture(scope="function") -def default_wandb_artifact_config(): - def generator(**kwargs) -> WandbArtifactConfig: - mine = { - "name": "my-run", - "version": "latest", - "project": "research-project", - "entity": "mozilla-corporation", - } - return WandbArtifactConfig(**{**mine, **kwargs}) - - yield generator - - -@pytest.fixture(scope="function") -def default_lm_harness_config(): - def generator(**kwargs) -> LMHarnessJobConfig: - mine = { - "tasks": ["task1", "task2"], - "num_fewshot": 5, - "batch_size": 16, - "torch_dtype": "bfloat16", - "model_name_or_path": None, - "quantization": None, - "timeout": 3600, - } - return LMHarnessJobConfig(**{**mine, **kwargs}) - - yield generator diff --git a/tests/integrations/wandb/test_run_config.py b/tests/integrations/wandb/test_run_config.py index 385e40a3..3f6d1c28 100644 --- a/tests/integrations/wandb/test_run_config.py +++ b/tests/integrations/wandb/test_run_config.py @@ -4,16 +4,19 @@ from flamingo.integrations.wandb import WandbRunConfig -def test_env_vars(default_wandb_run_config): - env_vars = default_wandb_run_config().get_env_vars() +def test_env_vars(wandb_run_config_generator): + env_vars = wandb_run_config_generator().get_env_vars() expected = ["WANDB_NAME", "WANDB_PROJECT", "WANDB_ENTITY", "WANDB_RUN_ID"] for key in expected: assert key in env_vars assert "WANDB_RUN_GROUP" not in env_vars -def test_serde_round_trip(default_wandb_run_config): - assert WandbRunConfig.parse_raw(default_wandb_run_config().json()) == default_wandb_run_config() +def test_serde_round_trip(wandb_run_config_generator): + assert ( + WandbRunConfig.parse_raw(wandb_run_config_generator().json()) + == wandb_run_config_generator() + ) def test_disallowed_kwargs(): diff --git a/tests/jobs/conftest.py b/tests/jobs/conftest.py new file mode 100644 index 00000000..6987a4b3 --- /dev/null +++ b/tests/jobs/conftest.py @@ -0,0 +1,58 @@ +import pytest +from peft import LoraConfig + +from flamingo.integrations.huggingface import ( + AutoModelConfig, + AutoTokenizerConfig, + DatasetConfig, + QuantizationConfig, +) +from flamingo.integrations.wandb import WandbArtifactConfig, WandbRunConfig + + +@pytest.fixture +def model_config_with_path(): + return AutoModelConfig("mistral-ai/mistral-7", trust_remote_code=True) + + +@pytest.fixture +def model_config_with_artifact(): + artifact = WandbArtifactConfig(name="model") + return AutoModelConfig(artifact, trust_remote_code=True) + + +@pytest.fixture +def tokenizer_config_with_path(): + return AutoTokenizerConfig("mistral-ai/mistral-7", trust_remote_code=True) + + +@pytest.fixture +def tokenizer_config_with_artifact(): + artifact = WandbArtifactConfig(name="tokenizer") + return AutoTokenizerConfig(artifact, trust_remote_code=True) + + +@pytest.fixture +def dataset_config_with_path(): + return DatasetConfig("databricks/dolly7b", split="train") + + +@pytest.fixture +def dataset_config_with_artifact(): + artifact = WandbArtifactConfig(name="dataset") + return DatasetConfig(artifact, split="train") + + +@pytest.fixture +def quantization_config(): + return QuantizationConfig(load_in_8bit=True) + + +@pytest.fixture +def lora_config(): + return LoraConfig(r=8, lora_alpha=32, lora_dropout=0.2) + + +@pytest.fixture +def wandb_run_config(): + return WandbRunConfig(name="run", run_id="12345", project="research", entity="mozilla-ai") diff --git a/tests/jobs/test_lm_harness_config.py b/tests/jobs/test_lm_harness_config.py index 32b32098..805e761f 100644 --- a/tests/jobs/test_lm_harness_config.py +++ b/tests/jobs/test_lm_harness_config.py @@ -1,28 +1,59 @@ -from pathlib import Path - import pytest from pydantic import ValidationError +from flamingo.integrations.huggingface import AutoModelConfig from flamingo.jobs import LMHarnessJobConfig +from flamingo.jobs.lm_harness_config import LMHarnessEvaluatorConfig, LMHarnessRayConfig +from tests.conftest import TEST_RESOURCES -def test_bad_hf_name(default_lm_harness_config): - with pytest.raises(ValidationError): - default_lm_harness_config(model_name_or_path="dfa../invalid") +@pytest.fixture +def lm_harness_evaluator_config(): + return LMHarnessEvaluatorConfig( + tasks=["task1", "task2", "task3"], + num_fewshot=5, + ) -def test_serde_round_trip_default_config(default_lm_harness_config): - config = default_lm_harness_config() - assert LMHarnessJobConfig.parse_raw(config.json()) == config +@pytest.fixture +def lm_harness_ray_config(): + return LMHarnessRayConfig( + num_workers=4, + use_gpu=True, + ) + +def test_model_validation(lm_harness_evaluator_config): + allowed_config = LMHarnessJobConfig(model="hf_repo_id", evaluator=lm_harness_evaluator_config) + assert allowed_config.model == AutoModelConfig(path="hf_repo_id") -def test_serde_round_trip_with_path(default_lm_harness_config): - config = default_lm_harness_config(model_name_or_path=Path("fake/path")) + with pytest.raises(ValidationError): + LMHarnessJobConfig(model="invalid...hf..repo", evaluator=lm_harness_evaluator_config) + + with pytest.raises(ValidationError): + LMHarnessJobConfig(model=12345, evaluator=lm_harness_evaluator_config) + + +def test_serde_round_trip( + model_config_with_artifact, + quantization_config, + wandb_run_config, + lm_harness_evaluator_config, + lm_harness_ray_config, +): + config = LMHarnessJobConfig( + model=model_config_with_artifact, + evaluator=lm_harness_evaluator_config, + ray=lm_harness_ray_config, + tracking=wandb_run_config, + quantization=quantization_config, + ) assert LMHarnessJobConfig.parse_raw(config.json()) == config -def test_parse_from_yaml(default_lm_harness_config, tmp_path_factory): - config = default_lm_harness_config(model_name_or_path="not_a_real_model") - p = tmp_path_factory.mktemp("test_yaml") / "eval.yaml" - config.to_yaml_file(p) - assert config == LMHarnessJobConfig.from_yaml_file(p) +def test_parse_yaml_file(tmp_path_factory): + load_path = TEST_RESOURCES / "lm_harness_config.yaml" + config = LMHarnessJobConfig.from_yaml_file(load_path) + write_path = tmp_path_factory.mktemp("flamingo_tests") / "harness_config.yaml" + config.to_yaml_file(write_path) + assert config == LMHarnessJobConfig.from_yaml_file(write_path) diff --git a/tests/resources/finetuning_config.yaml b/tests/resources/finetuning_config.yaml index 8ce53870..e43baa22 100644 --- a/tests/resources/finetuning_config.yaml +++ b/tests/resources/finetuning_config.yaml @@ -16,30 +16,25 @@ dataset: split: "train" test_size: 0.2 -# HuggingFace Trainer/TrainingArguments trainer: max_seq_length: 512 learning_rate: 0.1 num_train_epochs: 2 -# HuggingFace quantization settings quantization: load_in_4bit: True bnb_4bit_quant_type: "fp4" -# LORA adapter settings adapter: r: 16 lora_alpha: 32 lora_dropout: 0.2 -# W&B run for logging results tracking: name: "location-to-log-results" project: "another-project" entity: "another-entity" -# Ray compute settings ray: use_gpu: True num_workers: 4 diff --git a/tests/resources/lm_harness_config.yaml b/tests/resources/lm_harness_config.yaml index ac9bad14..64592ba4 100644 --- a/tests/resources/lm_harness_config.yaml +++ b/tests/resources/lm_harness_config.yaml @@ -13,18 +13,15 @@ evaluator: tasks: ["task1", "task2", "...", "taskN"] num_fewshot: 5 -# HuggingFace quantization settings quantization: load_in_4bit: True bnb_4bit_quant_type: "fp4" -# W&B run for logging results tracking: name: "location-to-log-results" project: "another-project" entity: "another-entity" -# Ray compute settings ray: use_gpu: True num_workers: 4 From 81440e107c882fade4c966e83179b4b7ee94182e Mon Sep 17 00:00:00 2001 From: Sean Friedowitz Date: Wed, 17 Jan 2024 12:20:36 -0800 Subject: [PATCH 06/10] add more test coverage --- .../configs}/finetuning_config.yaml | 0 .../configs}/lm_harness_config.yaml | 0 tests/__init__.py | 1 - tests/conftest.py | 3 - .../wandb/test_artifact_config.py | 21 +++++ tests/integrations/wandb/test_run_config.py | 33 +++++--- tests/jobs/conftest.py | 12 +-- tests/jobs/test_finetuning_config.py | 82 +++++++++++++++---- tests/jobs/test_lm_harness_config.py | 43 +++++----- 9 files changed, 137 insertions(+), 58 deletions(-) rename {tests/resources => examples/configs}/finetuning_config.yaml (100%) rename {tests/resources => examples/configs}/lm_harness_config.yaml (100%) create mode 100644 tests/integrations/wandb/test_artifact_config.py diff --git a/tests/resources/finetuning_config.yaml b/examples/configs/finetuning_config.yaml similarity index 100% rename from tests/resources/finetuning_config.yaml rename to examples/configs/finetuning_config.yaml diff --git a/tests/resources/lm_harness_config.yaml b/examples/configs/lm_harness_config.yaml similarity index 100% rename from tests/resources/lm_harness_config.yaml rename to examples/configs/lm_harness_config.yaml diff --git a/tests/__init__.py b/tests/__init__.py index 2bb88df4..e69de29b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +0,0 @@ -from pathlib import Path diff --git a/tests/conftest.py b/tests/conftest.py index 3ecdaf8d..e9fd2d21 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,13 +4,10 @@ This file is used to provide fixtures for the test session that are accessible to all submodules. """ import os -from pathlib import Path from unittest import mock import pytest -TEST_RESOURCES = Path(__file__) / "resources" - @pytest.fixture(autouse=True, scope="function") def mock_environment_with_keys(): diff --git a/tests/integrations/wandb/test_artifact_config.py b/tests/integrations/wandb/test_artifact_config.py new file mode 100644 index 00000000..6fdb98ec --- /dev/null +++ b/tests/integrations/wandb/test_artifact_config.py @@ -0,0 +1,21 @@ +import pytest + +from flamingo.integrations.wandb import WandbArtifactConfig + + +@pytest.fixture +def wandb_artifact_config(): + return WandbArtifactConfig( + name="artifact-name", + version="latest", + project="cortex-research", + entity="twitter.com", + ) + + +def test_serde_round_trip(wandb_artifact_config): + assert WandbArtifactConfig.parse_raw(wandb_artifact_config.json()) == wandb_artifact_config + + +def test_wandb_path(wandb_artifact_config): + assert wandb_artifact_config.wandb_path == "twitter.com/cortex-research/artifact-name:latest" diff --git a/tests/integrations/wandb/test_run_config.py b/tests/integrations/wandb/test_run_config.py index 3f6d1c28..0cf64bb0 100644 --- a/tests/integrations/wandb/test_run_config.py +++ b/tests/integrations/wandb/test_run_config.py @@ -4,21 +4,32 @@ from flamingo.integrations.wandb import WandbRunConfig -def test_env_vars(wandb_run_config_generator): - env_vars = wandb_run_config_generator().get_env_vars() +@pytest.fixture +def wandb_run_config(): + return WandbRunConfig( + name="run-name", + run_id="run-id", + project="cortex-research", + entity="twitter.com", + ) + + +def test_serde_round_trip(wandb_run_config): + assert WandbRunConfig.parse_raw(wandb_run_config.json()) == wandb_run_config + + +def test_wandb_path(wandb_run_config): + assert wandb_run_config.wandb_path == "twitter.com/cortex-research/run-id" + + +def test_env_vars(wandb_run_config): + env_vars = wandb_run_config.get_env_vars() expected = ["WANDB_NAME", "WANDB_PROJECT", "WANDB_ENTITY", "WANDB_RUN_ID"] for key in expected: assert key in env_vars assert "WANDB_RUN_GROUP" not in env_vars -def test_serde_round_trip(wandb_run_config_generator): - assert ( - WandbRunConfig.parse_raw(wandb_run_config_generator().json()) - == wandb_run_config_generator() - ) - - def test_disallowed_kwargs(): with pytest.raises(ValidationError): WandbRunConfig(name="name", project="project", old_name="I will throw") @@ -26,5 +37,5 @@ def test_disallowed_kwargs(): def test_missing_key_warning(mock_environment_without_keys): with pytest.warns(UserWarning): - env = WandbRunConfig(name="I am missing an API key", project="I should warn the user") - assert "WANDB_API_KEY" not in env.env_vars + config = WandbRunConfig(name="I am missing an API key", project="I should warn the user") + assert "WANDB_API_KEY" not in config.get_env_vars() diff --git a/tests/jobs/conftest.py b/tests/jobs/conftest.py index 6987a4b3..5dd3c812 100644 --- a/tests/jobs/conftest.py +++ b/tests/jobs/conftest.py @@ -12,35 +12,35 @@ @pytest.fixture def model_config_with_path(): - return AutoModelConfig("mistral-ai/mistral-7", trust_remote_code=True) + return AutoModelConfig(path="mistral-ai/mistral-7", trust_remote_code=True) @pytest.fixture def model_config_with_artifact(): artifact = WandbArtifactConfig(name="model") - return AutoModelConfig(artifact, trust_remote_code=True) + return AutoModelConfig(path=artifact, trust_remote_code=True) @pytest.fixture def tokenizer_config_with_path(): - return AutoTokenizerConfig("mistral-ai/mistral-7", trust_remote_code=True) + return AutoTokenizerConfig(path="mistral-ai/mistral-7", trust_remote_code=True) @pytest.fixture def tokenizer_config_with_artifact(): artifact = WandbArtifactConfig(name="tokenizer") - return AutoTokenizerConfig(artifact, trust_remote_code=True) + return AutoTokenizerConfig(path=artifact, trust_remote_code=True) @pytest.fixture def dataset_config_with_path(): - return DatasetConfig("databricks/dolly7b", split="train") + return DatasetConfig(path="databricks/dolly15k", split="train") @pytest.fixture def dataset_config_with_artifact(): artifact = WandbArtifactConfig(name="dataset") - return DatasetConfig(artifact, split="train") + return DatasetConfig(path=artifact, split="train") @pytest.fixture diff --git a/tests/jobs/test_finetuning_config.py b/tests/jobs/test_finetuning_config.py index ca5377a9..099a1b81 100644 --- a/tests/jobs/test_finetuning_config.py +++ b/tests/jobs/test_finetuning_config.py @@ -1,22 +1,72 @@ -from peft import LoraConfig -from ray.train import ScalingConfig +import pytest +from pydantic import ValidationError -from flamingo.integrations.huggingface import QuantizationConfig, TrainerConfig +from flamingo.integrations.huggingface import AutoModelConfig, AutoTokenizerConfig, DatasetConfig from flamingo.jobs import FinetuningJobConfig +from flamingo.jobs.finetuning_config import FinetuningRayConfig -def test_serde_round_trip(): - trainer_config = TrainerConfig(torch_dtype="bfloat16") - lora_config = LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM") - quantization_config = QuantizationConfig(load_in_8bit=True) - scaling_config = ScalingConfig(num_workers=2, use_gpu=True) - config = FinetuningJobConfig( - model="test-model", - dataset="test-dataset", - trainer=trainer_config, - lora=lora_config, +@pytest.fixture +def finetuning_ray_config(): + return FinetuningRayConfig( + num_workers=4, + use_gpu=True, + ) + + +@pytest.fixture +def finetuning_job_config( + model_config_with_artifact, + dataset_config_with_artifact, + tokenizer_config_with_artifact, + quantization_config, + lora_config, + wandb_run_config, + finetuning_ray_config, +): + return FinetuningJobConfig( + model=model_config_with_artifact, + dataset=dataset_config_with_artifact, + tokenizer=tokenizer_config_with_artifact, quantization=quantization_config, - scaling=scaling_config, - storage_path="/mnt/data/ray_results", + adapter=lora_config, + tracking=wandb_run_config, + ray=finetuning_ray_config, + ) + + +def test_serde_round_trip(finetuning_job_config): + assert FinetuningJobConfig.parse_raw(finetuning_job_config.json()) == finetuning_job_config + + +def test_parse_yaml_file(finetuning_job_config, tmp_path_factory): + config_path = tmp_path_factory.mktemp("flamingo_tests") / "finetuning_config.yaml" + finetuning_job_config.to_yaml_file(config_path) + assert finetuning_job_config == FinetuningJobConfig.from_yaml_file(config_path) + + +def test_argument_validation(): + # Strings should be upcast to configs as the path argument + allowed_config = FinetuningJobConfig( + model="model_path", + tokenizer="tokenizer_path", + dataset="dataset_path", + ) + assert allowed_config.model == AutoModelConfig(path="model_path") + assert allowed_config.tokenizer == AutoTokenizerConfig(path="tokenizer_path") + assert allowed_config.dataset == DatasetConfig(path="dataset_path") + + # Check passing invalid arguments is validated for each asset type + with pytest.raises(ValidationError): + FinetuningJobConfig(model=12345, tokenizer="tokenizer_path", dataset="dataset_path") + with pytest.raises(ValidationError): + FinetuningJobConfig(model="model_path", tokenizer=12345, dataset="dataset_path") + with pytest.raises(ValidationError): + FinetuningJobConfig(model="model_path", tokenizer="tokenizer_path", dataset=12345) + + # Check that tokenizer is set to model path when absent + missing_tokenizer_config = FinetuningJobConfig( + model="model_path", + dataset="dataset_path", ) - assert FinetuningJobConfig.parse_raw(config.json()) == config + assert missing_tokenizer_config.tokenizer.path == "model_path" diff --git a/tests/jobs/test_lm_harness_config.py b/tests/jobs/test_lm_harness_config.py index 805e761f..0ef5fcea 100644 --- a/tests/jobs/test_lm_harness_config.py +++ b/tests/jobs/test_lm_harness_config.py @@ -4,7 +4,6 @@ from flamingo.integrations.huggingface import AutoModelConfig from flamingo.jobs import LMHarnessJobConfig from flamingo.jobs.lm_harness_config import LMHarnessEvaluatorConfig, LMHarnessRayConfig -from tests.conftest import TEST_RESOURCES @pytest.fixture @@ -23,37 +22,39 @@ def lm_harness_ray_config(): ) -def test_model_validation(lm_harness_evaluator_config): - allowed_config = LMHarnessJobConfig(model="hf_repo_id", evaluator=lm_harness_evaluator_config) - assert allowed_config.model == AutoModelConfig(path="hf_repo_id") - - with pytest.raises(ValidationError): - LMHarnessJobConfig(model="invalid...hf..repo", evaluator=lm_harness_evaluator_config) - - with pytest.raises(ValidationError): - LMHarnessJobConfig(model=12345, evaluator=lm_harness_evaluator_config) - - -def test_serde_round_trip( +@pytest.fixture +def lm_harness_job_config( model_config_with_artifact, quantization_config, wandb_run_config, lm_harness_evaluator_config, lm_harness_ray_config, ): - config = LMHarnessJobConfig( + return LMHarnessJobConfig( model=model_config_with_artifact, evaluator=lm_harness_evaluator_config, ray=lm_harness_ray_config, tracking=wandb_run_config, quantization=quantization_config, ) - assert LMHarnessJobConfig.parse_raw(config.json()) == config -def test_parse_yaml_file(tmp_path_factory): - load_path = TEST_RESOURCES / "lm_harness_config.yaml" - config = LMHarnessJobConfig.from_yaml_file(load_path) - write_path = tmp_path_factory.mktemp("flamingo_tests") / "harness_config.yaml" - config.to_yaml_file(write_path) - assert config == LMHarnessJobConfig.from_yaml_file(write_path) +def test_serde_round_trip(lm_harness_job_config): + assert LMHarnessJobConfig.parse_raw(lm_harness_job_config.json()) == lm_harness_job_config + + +def test_parse_yaml_file(lm_harness_job_config, tmp_path_factory): + config_path = tmp_path_factory.mktemp("flamingo_tests") / "lm_harness_config.yaml" + lm_harness_job_config.to_yaml_file(config_path) + assert lm_harness_job_config == LMHarnessJobConfig.from_yaml_file(config_path) + + +def test_model_validation(lm_harness_evaluator_config): + allowed_config = LMHarnessJobConfig(model="hf_repo_id", evaluator=lm_harness_evaluator_config) + assert allowed_config.model == AutoModelConfig(path="hf_repo_id") + + with pytest.raises(ValidationError): + LMHarnessJobConfig(model="invalid...hf..repo", evaluator=lm_harness_evaluator_config) + + with pytest.raises(ValidationError): + LMHarnessJobConfig(model=12345, evaluator=lm_harness_evaluator_config) From 3a40abd387a0719bbb0369eb02af4ab6f20d1664 Mon Sep 17 00:00:00 2001 From: Sean Friedowitz Date: Wed, 17 Jan 2024 12:27:18 -0800 Subject: [PATCH 07/10] add a test to load example configs --- tests/conftest.py | 6 ++++++ tests/jobs/test_finetuning_config.py | 7 +++++++ tests/jobs/test_lm_harness_config.py | 7 +++++++ 3 files changed, 20 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index e9fd2d21..b523d7db 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,11 +4,17 @@ This file is used to provide fixtures for the test session that are accessible to all submodules. """ import os +from pathlib import Path from unittest import mock import pytest +@pytest.fixture +def examples_folder(): + return Path(__file__).parents[1] / "examples" + + @pytest.fixture(autouse=True, scope="function") def mock_environment_with_keys(): """Mocks an API key-like mechanism for the environment.""" diff --git a/tests/jobs/test_finetuning_config.py b/tests/jobs/test_finetuning_config.py index 099a1b81..d1180f86 100644 --- a/tests/jobs/test_finetuning_config.py +++ b/tests/jobs/test_finetuning_config.py @@ -45,6 +45,13 @@ def test_parse_yaml_file(finetuning_job_config, tmp_path_factory): assert finetuning_job_config == FinetuningJobConfig.from_yaml_file(config_path) +def test_load_example_config(examples_folder): + """Load the example configs to make sure they stay up to date.""" + config_file = examples_folder / "configs" / "finetuning_config.yaml" + config = FinetuningJobConfig.from_yaml_file(config_file) + assert FinetuningJobConfig.parse_raw(config.json()) == config + + def test_argument_validation(): # Strings should be upcast to configs as the path argument allowed_config = FinetuningJobConfig( diff --git a/tests/jobs/test_lm_harness_config.py b/tests/jobs/test_lm_harness_config.py index 0ef5fcea..3e2938ae 100644 --- a/tests/jobs/test_lm_harness_config.py +++ b/tests/jobs/test_lm_harness_config.py @@ -49,6 +49,13 @@ def test_parse_yaml_file(lm_harness_job_config, tmp_path_factory): assert lm_harness_job_config == LMHarnessJobConfig.from_yaml_file(config_path) +def test_load_example_config(examples_folder): + """Load the example configs to make sure they stay up to date.""" + config_file = examples_folder / "configs" / "lm_harness_config.yaml" + config = LMHarnessJobConfig.from_yaml_file(config_file) + assert LMHarnessJobConfig.parse_raw(config.json()) == config + + def test_model_validation(lm_harness_evaluator_config): allowed_config = LMHarnessJobConfig(model="hf_repo_id", evaluator=lm_harness_evaluator_config) assert allowed_config.model == AutoModelConfig(path="hf_repo_id") From e086c417ef22cf8b8474c3e1e9a28a22559d7ea2 Mon Sep 17 00:00:00 2001 From: Sean Friedowitz Date: Wed, 17 Jan 2024 12:31:38 -0800 Subject: [PATCH 08/10] add run id test --- tests/integrations/wandb/test_run_config.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/integrations/wandb/test_run_config.py b/tests/integrations/wandb/test_run_config.py index 0cf64bb0..e71387d5 100644 --- a/tests/integrations/wandb/test_run_config.py +++ b/tests/integrations/wandb/test_run_config.py @@ -22,6 +22,11 @@ def test_wandb_path(wandb_run_config): assert wandb_run_config.wandb_path == "twitter.com/cortex-research/run-id" +def test_ensure_run_id(): + env = WandbRunConfig(name="defined", project="defined", entity="defined") + assert env.run_id is not None # Pydantic validator fills this in + + def test_env_vars(wandb_run_config): env_vars = wandb_run_config.get_env_vars() expected = ["WANDB_NAME", "WANDB_PROJECT", "WANDB_ENTITY", "WANDB_RUN_ID"] From 31a905d2ec9f7183df7775cf0c9541396355369b Mon Sep 17 00:00:00 2001 From: Sean Friedowitz Date: Wed, 17 Jan 2024 12:33:45 -0800 Subject: [PATCH 09/10] rename some api stuff --- src/flamingo/integrations/wandb/utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/flamingo/integrations/wandb/utils.py b/src/flamingo/integrations/wandb/utils.py index 72a3b40c..f4303486 100644 --- a/src/flamingo/integrations/wandb/utils.py +++ b/src/flamingo/integrations/wandb/utils.py @@ -1,25 +1,25 @@ from typing import Any import wandb -from wandb.apis.public import Run +from wandb.apis.public import Run as ApiRun from flamingo.integrations.wandb import WandbRunConfig -def get_wandb_run(env: WandbRunConfig) -> Run: +def get_wandb_api_run(run_config: WandbRunConfig) -> ApiRun: """Retrieve a run from the W&B API.""" api = wandb.Api() - return api.run(env.wandb_path) + return api.run(run_config.wandb_path) -def get_wandb_summary(env: WandbRunConfig) -> dict[str, Any]: +def get_wandb_summary(run_config: WandbRunConfig) -> dict[str, Any]: """Get the summary dictionary attached to a W&B run.""" - run = get_wandb_run(env) + run = get_wandb_api_run(run_config) return dict(run.summary) -def update_wandb_summary(env: WandbRunConfig, metrics: dict[str, Any]) -> None: +def update_wandb_summary(run_config: WandbRunConfig, metrics: dict[str, Any]) -> None: """Update a run's summary with the provided metrics.""" - run = get_wandb_run(env) + run = get_wandb_api_run(run_config) run.summary.update(metrics) run.update() From 5368178663bda9fae6890142fa7347deb0f1c38a Mon Sep 17 00:00:00 2001 From: Sean Friedowitz Date: Wed, 17 Jan 2024 15:03:17 -0800 Subject: [PATCH 10/10] respond to reviewer comments --- .../huggingface/tokenizer_config.py | 10 --------- .../huggingface/trainer_config.py | 21 ------------------- .../integrations/huggingface/utils.py | 16 -------------- .../integrations/wandb/artifact_config.py | 3 +-- src/flamingo/integrations/wandb/run_config.py | 3 +-- src/flamingo/integrations/wandb/utils.py | 2 +- .../wandb/test_artifact_config.py | 6 +++--- tests/integrations/wandb/test_run_config.py | 6 +++--- 8 files changed, 9 insertions(+), 58 deletions(-) diff --git a/src/flamingo/integrations/huggingface/tokenizer_config.py b/src/flamingo/integrations/huggingface/tokenizer_config.py index 750a4d94..2f8fe70c 100644 --- a/src/flamingo/integrations/huggingface/tokenizer_config.py +++ b/src/flamingo/integrations/huggingface/tokenizer_config.py @@ -1,5 +1,3 @@ -from typing import Any - from pydantic import validator from flamingo.integrations.huggingface.utils import repo_id_validator @@ -15,11 +13,3 @@ class AutoTokenizerConfig(BaseFlamingoConfig): use_fast: bool | None = None _path_validator = validator("path", allow_reuse=True, pre=True)(repo_id_validator) - - def get_tokenizer_args(self) -> dict[str, Any]: - args = dict( - trust_remote_code=self.trust_remote_code, - use_fast=self.use_fast, - ) - # Only return non-None values so we get HuggingFace defaults when not specified - return {k: v for k, v in args.items() if v is not None} diff --git a/src/flamingo/integrations/huggingface/trainer_config.py b/src/flamingo/integrations/huggingface/trainer_config.py index e3cf4b64..9447f4bf 100644 --- a/src/flamingo/integrations/huggingface/trainer_config.py +++ b/src/flamingo/integrations/huggingface/trainer_config.py @@ -1,5 +1,3 @@ -from typing import Any - from flamingo.types import BaseFlamingoConfig @@ -24,22 +22,3 @@ class TrainerConfig(BaseFlamingoConfig): logging_steps: float | None = None save_strategy: str | None = None save_steps: int | None = None - - def get_training_args(self) -> dict[str, Any]: - args = dict( - num_train_epochs=self.num_train_epochs, - learning_rate=self.learning_rate, - per_device_train_batch_size=self.per_device_train_batch_size, - per_device_eval_batch_size=self.per_device_eval_batch_size, - gradient_accumulation_steps=self.gradient_accumulation_steps, - gradient_checkpointing=self.gradient_checkpointing, - weight_decay=self.weight_decay, - evaluation_strategy=self.evaluation_strategy, - eval_steps=self.eval_steps, - logging_strategy=self.logging_strategy, - logging_steps=self.logging_steps, - save_strategy=self.save_strategy, - save_steps=self.save_steps, - ) - # Only return non-None values so we use the HuggingFace defaults when not specified - return {k: v for k, v in args.items() if v is not None} diff --git a/src/flamingo/integrations/huggingface/utils.py b/src/flamingo/integrations/huggingface/utils.py index 7622e3c8..af8d5e50 100644 --- a/src/flamingo/integrations/huggingface/utils.py +++ b/src/flamingo/integrations/huggingface/utils.py @@ -1,6 +1,5 @@ from typing import Any -from datasets import DatasetDict, load_dataset from huggingface_hub.utils import HFValidationError, validate_repo_id @@ -23,18 +22,3 @@ def is_valid_huggingface_repo_id(s: str): return True except HFValidationError: return False - - -def load_and_split_dataset( - path: str, - *, - split: str | None = None, - test_size: float | None, - seed: int | None = None, -) -> DatasetDict: - dataset = load_dataset(path, split=split) - if test_size is not None: - datasets = dataset.train_test_split(test_size=test_size, seed=seed) - else: - datasets = DatasetDict({"train": dataset}) - return datasets diff --git a/src/flamingo/integrations/wandb/artifact_config.py b/src/flamingo/integrations/wandb/artifact_config.py index 4f445330..7dfff2ba 100644 --- a/src/flamingo/integrations/wandb/artifact_config.py +++ b/src/flamingo/integrations/wandb/artifact_config.py @@ -9,8 +9,7 @@ class WandbArtifactConfig(BaseFlamingoConfig): project: str | None = None entity: str | None = None - @property - def wandb_path(self) -> str: + def get_wandb_path(self) -> str: """String identifier for the asset on the W&B platform.""" path = "/".join(x for x in [self.entity, self.project, self.name] if x is not None) path = f"{path}:{self.version}" diff --git a/src/flamingo/integrations/wandb/run_config.py b/src/flamingo/integrations/wandb/run_config.py index bf62cf4d..bacb99cc 100644 --- a/src/flamingo/integrations/wandb/run_config.py +++ b/src/flamingo/integrations/wandb/run_config.py @@ -59,8 +59,7 @@ def from_run(cls, run: Run) -> "WandbRunConfig": run_id=run.id, ) - @property - def wandb_path(self) -> str: + def get_wandb_path(self) -> str: """String identifier for the asset on the W&B platform.""" path = "/".join(x for x in [self.entity, self.project, self.run_id] if x is not None) return path diff --git a/src/flamingo/integrations/wandb/utils.py b/src/flamingo/integrations/wandb/utils.py index f4303486..8c0862be 100644 --- a/src/flamingo/integrations/wandb/utils.py +++ b/src/flamingo/integrations/wandb/utils.py @@ -9,7 +9,7 @@ def get_wandb_api_run(run_config: WandbRunConfig) -> ApiRun: """Retrieve a run from the W&B API.""" api = wandb.Api() - return api.run(run_config.wandb_path) + return api.run(run_config.get_wandb_path()) def get_wandb_summary(run_config: WandbRunConfig) -> dict[str, Any]: diff --git a/tests/integrations/wandb/test_artifact_config.py b/tests/integrations/wandb/test_artifact_config.py index 6fdb98ec..af4b9bb0 100644 --- a/tests/integrations/wandb/test_artifact_config.py +++ b/tests/integrations/wandb/test_artifact_config.py @@ -8,8 +8,8 @@ def wandb_artifact_config(): return WandbArtifactConfig( name="artifact-name", version="latest", - project="cortex-research", - entity="twitter.com", + project="cortex", + entity="twitter", ) @@ -18,4 +18,4 @@ def test_serde_round_trip(wandb_artifact_config): def test_wandb_path(wandb_artifact_config): - assert wandb_artifact_config.wandb_path == "twitter.com/cortex-research/artifact-name:latest" + assert wandb_artifact_config.get_wandb_path() == "twitter/cortex/artifact-name:latest" diff --git a/tests/integrations/wandb/test_run_config.py b/tests/integrations/wandb/test_run_config.py index e71387d5..7e0bec62 100644 --- a/tests/integrations/wandb/test_run_config.py +++ b/tests/integrations/wandb/test_run_config.py @@ -9,8 +9,8 @@ def wandb_run_config(): return WandbRunConfig( name="run-name", run_id="run-id", - project="cortex-research", - entity="twitter.com", + project="cortex", + entity="twitter", ) @@ -19,7 +19,7 @@ def test_serde_round_trip(wandb_run_config): def test_wandb_path(wandb_run_config): - assert wandb_run_config.wandb_path == "twitter.com/cortex-research/run-id" + assert wandb_run_config.get_wandb_path() == "twitter/cortex/run-id" def test_ensure_run_id():