This repository has been archived by the owner on Sep 24, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from mozilla-ai/sfriedowitz/config-layout
Add more structured configurations for jobs
- Loading branch information
Showing
31 changed files
with
598 additions
and
292 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -161,8 +161,3 @@ cython_debug/ | |
|
||
# Ruff | ||
.ruff_cache | ||
|
||
|
||
# ignore local wandb cache files. Not perfect | ||
**/wandb/*.log | ||
**/wandb/*run* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Tokenizer defined by only a string repo ID | ||
tokenizer: "mistral-ai/other-repo-with-special-tokenizer" | ||
|
||
# Model defined as an object with additional settings beyond the repo ID | ||
model: | ||
path: "mistral-ai/mistral-7b" | ||
trust_remote_code: True | ||
torch_dtype: "bfloat16" | ||
|
||
# Dataset defined as an object with a path linking to a W&B artifact | ||
dataset: | ||
path: | ||
name: "dataset-artifact" | ||
version: "latest" | ||
project: "research-project" | ||
split: "train" | ||
test_size: 0.2 | ||
|
||
trainer: | ||
max_seq_length: 512 | ||
learning_rate: 0.1 | ||
num_train_epochs: 2 | ||
|
||
quantization: | ||
load_in_4bit: True | ||
bnb_4bit_quant_type: "fp4" | ||
|
||
adapter: | ||
r: 16 | ||
lora_alpha: 32 | ||
lora_dropout: 0.2 | ||
|
||
tracking: | ||
name: "location-to-log-results" | ||
project: "another-project" | ||
entity: "another-entity" | ||
|
||
ray: | ||
use_gpu: True | ||
num_workers: 4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Model to evaluate, specified as a W&B artifact | ||
model: | ||
path: | ||
name: "training-run-model-artifact" | ||
version: "v4" | ||
project: "research-project" | ||
entity: "twitter.com" | ||
trust_remote_code: True | ||
torch_dtype: "float16" | ||
|
||
# Settings specific to lm_harness.evaluate | ||
evaluator: | ||
tasks: ["task1", "task2", "...", "taskN"] | ||
num_fewshot: 5 | ||
|
||
quantization: | ||
load_in_4bit: True | ||
bnb_4bit_quant_type: "fp4" | ||
|
||
tracking: | ||
name: "location-to-log-results" | ||
project: "another-project" | ||
entity: "another-entity" | ||
|
||
ray: | ||
use_gpu: True | ||
num_workers: 4 | ||
timeout: 3600 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,13 @@ | ||
from .model_name_or_path import ModelNameOrCheckpointPath | ||
from .dataset_config import DatasetConfig | ||
from .model_config import AutoModelConfig | ||
from .quantization_config import QuantizationConfig | ||
from .tokenizer_config import AutoTokenizerConfig | ||
from .trainer_config import TrainerConfig | ||
|
||
# Names listed as the public exports of this package.
# NOTE(review): this span is a rendered diff (hunk header `@@ -1,9 +1,13 @@`), so
# removed and added entries appear together without +/- markers.
# "ModelNameOrCheckpointPath" is presumably a removed entry, since
# model_name_or_path.py is reported as deleted in this commit - confirm against the file.
__all__ = [ | ||
"ModelNameOrCheckpointPath", | ||
"AutoModelConfig", | ||
"AutoTokenizerConfig", | ||
"DatasetConfig", | ||
"QuantizationConfig", | ||
"TrainerConfig", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from pydantic import validator | ||
|
||
from flamingo.integrations.huggingface.utils import repo_id_validator | ||
from flamingo.integrations.wandb import WandbArtifactConfig | ||
from flamingo.types import BaseFlamingoConfig | ||
|
||
|
||
class DatasetConfig(BaseFlamingoConfig):
    """Options used when loading a HuggingFace dataset."""

    # Dataset location: a plain string or a W&B artifact reference.
    path: str | WandbArtifactConfig
    # Remaining options are all optional and default to None.
    split: str | None = None
    test_size: float | None = None
    seed: int | None = None

    # Registered as a pre-validator on `path`; `allow_reuse=True` because the same
    # `repo_id_validator` function is also attached to other config classes.
    _path_validator = validator("path", pre=True, allow_reuse=True)(repo_id_validator)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from pydantic import validator | ||
|
||
from flamingo.integrations.huggingface.utils import repo_id_validator | ||
from flamingo.integrations.wandb import WandbArtifactConfig | ||
from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype | ||
|
||
|
||
class AutoModelConfig(BaseFlamingoConfig):
    """Options forwarded to a HuggingFace AutoModel instantiation.

    ``path`` is either a string corresponding to a HuggingFace repo ID or an
    artifact link (``WandbArtifactConfig``) to a reference artifact on W&B.
    """

    path: str | WandbArtifactConfig
    trust_remote_code: bool = False
    # NOTE(review): the default is None although the annotation does not include
    # `| None` - confirm that SerializableTorchDtype accepts None.
    torch_dtype: SerializableTorchDtype = None

    # Registered as a pre-validator on `path`; `allow_reuse=True` because the same
    # `repo_id_validator` function is also attached to other config classes.
    _path_validator = validator("path", pre=True, allow_reuse=True)(repo_id_validator)
34 changes: 0 additions & 34 deletions
34
src/flamingo/integrations/huggingface/model_name_or_path.py
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from pydantic import validator | ||
|
||
from flamingo.integrations.huggingface.utils import repo_id_validator | ||
from flamingo.integrations.wandb import WandbArtifactConfig | ||
from flamingo.types import BaseFlamingoConfig | ||
|
||
|
||
class AutoTokenizerConfig(BaseFlamingoConfig):
    """Options forwarded to a HuggingFace AutoTokenizer instantiation."""

    # Tokenizer location: a plain string or a W&B artifact reference.
    path: str | WandbArtifactConfig
    # Both flags are optional and default to None.
    trust_remote_code: bool | None = None
    use_fast: bool | None = None

    # Registered as a pre-validator on `path`; `allow_reuse=True` because the same
    # `repo_id_validator` function is also attached to other config classes.
    _path_validator = validator("path", pre=True, allow_reuse=True)(repo_id_validator)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,24 @@ | ||
from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype | ||
from flamingo.types import BaseFlamingoConfig | ||
|
||
|
||
# Configuration model for HuggingFace trainer/training arguments.
# NOTE(review): this span is a rendered diff (hunk header `@@ -1,21 +1,24 @@`), so
# removed and added lines appear together without +/- markers. That is why there
# are two docstrings and why several fields are declared twice. The
# `| None = None` declarations look like the post-change version - confirm
# against the actual file before relying on either set.
class TrainerConfig(BaseFlamingoConfig): | ||
"""Configuration for a HuggingFace trainer/training arguments.""" | ||
"""Configuration for a HuggingFace trainer/training arguments. | ||
This mainly encompasses arguments passed to the HuggingFace `TrainingArguments` class, | ||
but also contains some additional parameters for the `Trainer` or `SFTTrainer` classes. | ||
""" | ||
|
||
max_seq_length: int | None = None | ||
# Presumably the pre-change declarations (concrete defaults) - TODO confirm.
num_train_epochs: int = 1 | ||
batch_size: int = 16 | ||
learning_rate: float = 1e-5 | ||
weight_decay: float = 1e-3 | ||
gradient_accumulation_steps: int = 1 | ||
gradient_checkpointing: bool = False | ||
trust_remote_code: bool = False | ||
torch_dtype: SerializableTorchDtype = None | ||
evaluation_strategy: str = "epoch" | ||
# Presumably the post-change declarations: each option is optional and defaults to None.
num_train_epochs: int | None = None | ||
per_device_train_batch_size: int | None = None | ||
per_device_eval_batch_size: int | None = None | ||
learning_rate: float | None = None | ||
weight_decay: float | None = None | ||
gradient_accumulation_steps: int | None = None | ||
gradient_checkpointing: bool | None = None | ||
evaluation_strategy: str | None = None | ||
eval_steps: float | None = None | ||
# Presumably the pre-change logging/saving declarations - TODO confirm.
logging_strategy: str = "steps" | ||
logging_steps: float = 100 | ||
save_strategy: str = "steps" | ||
save_steps: int = 500 | ||
# Presumably the post-change logging/saving declarations, defaulting to None.
logging_strategy: str | None = None | ||
logging_steps: float | None = None | ||
save_strategy: str | None = None | ||
save_steps: int | None = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,10 @@ | ||
from .wandb_environment import WandbEnvironment # noqa: I001 | ||
from .artifact_config import WandbArtifactConfig | ||
from .run_config import WandbRunConfig | ||
from .utils import get_wandb_summary, update_wandb_summary | ||
|
||
# Names listed as the public exports of this package.
# NOTE(review): this span is a rendered diff (hunk header `@@ -1,8 +1,10 @@`) and
# shows more lines than the post-change count, so some entries are presumably
# removed lines shown without +/- markers - confirm against the actual file.
__all__ = [ | ||
"WandbEnvironment", | ||
"WandbArtifactConfig", | ||
"WandbRunConfig", | ||
"get_wandb_summary", | ||
"update_wandb_summary", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from flamingo.types import BaseFlamingoConfig | ||
|
||
|
||
class WandbArtifactConfig(BaseFlamingoConfig):
    """Fields needed to look up an artifact on W&B."""

    name: str
    version: str = "latest"
    project: str | None = None
    entity: str | None = None

    def get_wandb_path(self) -> str:
        """Return the artifact identifier as ``[entity/][project/]name:version``.

        ``entity`` and ``project`` are left out of the path when they are None.
        """
        # Order matters: entity first, then project, then the artifact name.
        segments = [
            part for part in (self.entity, self.project, self.name) if part is not None
        ]
        prefix = "/".join(segments)
        return f"{prefix}:{self.version}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.