From 25a57855f493735fcd335a00eaa7d22d05efbcc7 Mon Sep 17 00:00:00 2001 From: Sean Friedowitz Date: Thu, 11 Jan 2024 13:29:25 -0500 Subject: [PATCH] refactor mnp --- README.md | 2 +- pyproject.toml | 2 +- src/flamingo/cli.py | 8 +-- src/flamingo/integrations/__init__.py | 0 .../integrations/huggingface/__init__.py | 9 +++- .../huggingface/model_name_or_path.py | 49 +++++++++++++++++++ .../huggingface/quantization_config.py | 3 +- .../huggingface/trainer_config.py | 5 ++ .../integrations/huggingface/utils.py | 16 ------ .../integrations/wandb/wandb_environment.py | 11 +++++ .../jobs/configs/lm_harness_config.py | 34 +------------ .../jobs/entrypoints/finetuning_entrypoint.py | 2 +- .../jobs/entrypoints/lm_harness_entrypoint.py | 2 +- .../jobs/entrypoints/ludwig_entrypoint.py | 2 +- .../jobs/entrypoints/simple_entrypoint.py | 2 +- 15 files changed, 85 insertions(+), 62 deletions(-) create mode 100644 src/flamingo/integrations/__init__.py create mode 100644 src/flamingo/integrations/huggingface/model_name_or_path.py create mode 100644 src/flamingo/integrations/huggingface/trainer_config.py delete mode 100644 src/flamingo/integrations/huggingface/utils.py diff --git a/README.md b/README.md index 9eda0044..2ac7139a 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # flamingo

- +

## Installation diff --git a/pyproject.toml b/pyproject.toml index 9ef61b7b..4a46bc67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ finetune = [ "accelerate==0.25.0", "peft==0.7.1", "trl==0.7.4", - "bitsandbytes==0.41.3", + "bitsandbytes==0.40.2", ] ludwig = ["ludwig==0.9.1"] diff --git a/src/flamingo/cli.py b/src/flamingo/cli.py index cdcb0c82..f2b23682 100644 --- a/src/flamingo/cli.py +++ b/src/flamingo/cli.py @@ -18,7 +18,7 @@ def run_simple(config: str) -> None: from flamingo.jobs.entrypoints import simple_entrypoint config = SimpleJobConfig.from_yaml_file(config) - simple_entrypoint.main(config) + simple_entrypoint.run(config) @run.command("finetuning") @@ -28,7 +28,7 @@ def run_finetuning(config: str) -> None: from flamingo.jobs.entrypoints import finetuning_entrypoint config = FinetuningJobConfig.from_yaml_file(config) - finetuning_entrypoint.main(config) + finetuning_entrypoint.run(config) @run.command("ludwig") @@ -37,7 +37,7 @@ def run_finetuning(config: str) -> None: def run_ludwig(config: str, dataset: str) -> None: from flamingo.jobs.entrypoints import ludwig_entrypoint - ludwig_entrypoint.main(config, dataset) + ludwig_entrypoint.run(config, dataset) @run.command("lm-harness") @@ -47,7 +47,7 @@ def run_lm_harness(config: str) -> None: from flamingo.jobs.entrypoints import lm_harness_entrypoint config = LMHarnessJobConfig.from_yaml_file(config) - lm_harness_entrypoint.main(config) + lm_harness_entrypoint.run(config) if __name__ == "__main__": diff --git a/src/flamingo/integrations/__init__.py b/src/flamingo/integrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/flamingo/integrations/huggingface/__init__.py b/src/flamingo/integrations/huggingface/__init__.py index f704e2c5..7380bd37 100644 --- a/src/flamingo/integrations/huggingface/__init__.py +++ b/src/flamingo/integrations/huggingface/__init__.py @@ -1,4 +1,9 @@ +from .model_name_or_path import ModelNameOrCheckpointPath from .quantization_config import QuantizationConfig -from .utils import is_valid_huggingface_model_name +from .trainer_config import TrainerConfig -__all__ = ["QuantizationConfig", "is_valid_huggingface_model_name"] +__all__ = [ + "ModelNameOrCheckpointPath", + "QuantizationConfig", + "TrainerConfig", +] diff --git a/src/flamingo/integrations/huggingface/model_name_or_path.py b/src/flamingo/integrations/huggingface/model_name_or_path.py new file mode 100644 index 00000000..47a4aa9f --- /dev/null +++ b/src/flamingo/integrations/huggingface/model_name_or_path.py @@ -0,0 +1,49 @@ +from dataclasses import InitVar +from pathlib import Path + +from huggingface_hub.utils import HFValidationError, validate_repo_id +from pydantic.dataclasses import dataclass + + +def is_valid_huggingface_model_name(s: str): + """ + Simple test to check if an HF model is valid using HuggingFace's tools. + Sadly, theirs throws an exception and has no return. + + Args: + s: string to test. + """ + try: + validate_repo_id(s) + return True + except HFValidationError: + return False + + +@dataclass +class ModelNameOrCheckpointPath: + """ + This class is explicitly used to validate if a string is + a valid HuggingFace model or can be used as a checkpoint. + + Checkpoint will be automatically assigned if it's a valid checkpoint; + it will be None if it's not valid. + """ + + # explictly needed for matching + __match_args__ = ("name", "checkpoint") + + name: str + checkpoint: InitVar[str | None] = None + + def __post_init__(self, checkpoint): + if isinstance(self.name, Path): + self.name = str(self.name) + + if Path(self.name).is_absolute(): + self.checkpoint = self.name + else: + self.checkpoint = None + + if self.checkpoint is None and not is_valid_huggingface_model_name(self.name): + raise (ValueError(f"{self.name} is not a valid checkpoint path or HF model name")) diff --git a/src/flamingo/integrations/huggingface/quantization_config.py b/src/flamingo/integrations/huggingface/quantization_config.py index 8443c397..bc842743 100644 --- a/src/flamingo/integrations/huggingface/quantization_config.py +++ b/src/flamingo/integrations/huggingface/quantization_config.py @@ -1,6 +1,7 @@ -from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype from transformers import BitsAndBytesConfig +from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype + class QuantizationConfig(BaseFlamingoConfig): """Basic quantization settings to pass to training and evaluation jobs. diff --git a/src/flamingo/integrations/huggingface/trainer_config.py b/src/flamingo/integrations/huggingface/trainer_config.py new file mode 100644 index 00000000..b9ef120b --- /dev/null +++ b/src/flamingo/integrations/huggingface/trainer_config.py @@ -0,0 +1,5 @@ +from flamingo.types import BaseFlamingoConfig + + +class TrainerConfig(BaseFlamingoConfig): + pass diff --git a/src/flamingo/integrations/huggingface/utils.py b/src/flamingo/integrations/huggingface/utils.py deleted file mode 100644 index 21d4ce54..00000000 --- a/src/flamingo/integrations/huggingface/utils.py +++ /dev/null @@ -1,16 +0,0 @@ -from huggingface_hub.utils import HFValidationError, validate_repo_id - - -def is_valid_huggingface_model_name(s: str): - """ - Simple test to check if an HF model is valid using HuggingFace's tools. - Sadly, theirs throws an exception and has no return. - - Args: - s: string to test. - """ - try: - validate_repo_id(s) - return True - except HFValidationError: - return False diff --git a/src/flamingo/integrations/wandb/wandb_environment.py b/src/flamingo/integrations/wandb/wandb_environment.py index 3fcf2281..3a1dd90a 100644 --- a/src/flamingo/integrations/wandb/wandb_environment.py +++ b/src/flamingo/integrations/wandb/wandb_environment.py @@ -54,6 +54,17 @@ def env_vars(self) -> dict[str, str]: } return {k: v for k, v in env_vars.items() if v is not None} + @classmethod + def from_env(cls) -> "WandbEnvironment": + """Extract W&B settings from the runtime environment.""" + return cls( + name=os.environ.get("WANDB_NAME"), + project=os.environ.get("WANDB_PROJECT"), + entity=os.environ.get("WANDB_ENTITY"), + run_id=os.environ.get("WANDB_RUN_ID"), + run_group=os.environ.get("WANDB_RUN_GROUP"), + ) + @classmethod def from_run(cls, run: Run) -> "WandbEnvironment": """Extract environment settings from a W&B Run object. diff --git a/src/flamingo/jobs/configs/lm_harness_config.py b/src/flamingo/jobs/configs/lm_harness_config.py index fce53cd6..edebf1c7 100644 --- a/src/flamingo/jobs/configs/lm_harness_config.py +++ b/src/flamingo/jobs/configs/lm_harness_config.py @@ -1,47 +1,15 @@ import datetime -from dataclasses import InitVar from pathlib import Path from typing import Any from pydantic import root_validator, validator -from pydantic.dataclasses import dataclass -from flamingo.integrations.huggingface import QuantizationConfig -from flamingo.integrations.huggingface.utils import is_valid_huggingface_model_name +from flamingo.integrations.huggingface import ModelNameOrCheckpointPath, QuantizationConfig from flamingo.integrations.wandb import WandbEnvironment from flamingo.jobs.configs import BaseJobConfig from flamingo.types import SerializableTorchDtype -@dataclass -class ModelNameOrCheckpointPath: - """ - This class is explicitly used to validate if a string is - a valid HuggingFace model or can be used as a checkpoint. - - Checkpoint will be automatically assigned if it's a valid checkpoint; - it will be None if it's not valid. - """ - - # explictly needed for matching - __match_args__ = ("name", "checkpoint") - - name: str - checkpoint: InitVar[str | None] = None - - def __post_init__(self, checkpoint): - if isinstance(self.name, Path): - self.name = str(self.name) - - if Path(self.name).is_absolute(): - self.checkpoint = self.name - else: - self.checkpoint = None - - if self.checkpoint is None and not is_valid_huggingface_model_name(self.name): - raise (ValueError(f"{self.name} is not a valid checkpoint path or HF model name")) - - class LMHarnessJobConfig(BaseJobConfig): """Configuration to run an lm-evaluation-harness evaluation job. diff --git a/src/flamingo/jobs/entrypoints/finetuning_entrypoint.py b/src/flamingo/jobs/entrypoints/finetuning_entrypoint.py index 838eb172..1d8c5137 100644 --- a/src/flamingo/jobs/entrypoints/finetuning_entrypoint.py +++ b/src/flamingo/jobs/entrypoints/finetuning_entrypoint.py @@ -111,7 +111,7 @@ def train_func(config_data: dict): wandb.finish() -def main(config: FinetuningJobConfig): +def run(config: FinetuningJobConfig): print(f"Received job configuration: {config}") run_config = RunConfig( diff --git a/src/flamingo/jobs/entrypoints/lm_harness_entrypoint.py b/src/flamingo/jobs/entrypoints/lm_harness_entrypoint.py index d9603acb..1d5dd947 100644 --- a/src/flamingo/jobs/entrypoints/lm_harness_entrypoint.py +++ b/src/flamingo/jobs/entrypoints/lm_harness_entrypoint.py @@ -94,7 +94,7 @@ def run_evaluation(config: LMHarnessJobConfig, model_to_load: str) -> None: update_wandb_summary(config.wandb_env, formatted_results) -def main(config: LMHarnessJobConfig): +def run(config: LMHarnessJobConfig): print(f"Received job configuration: {config}") # Resolve path and ensure exists diff --git a/src/flamingo/jobs/entrypoints/ludwig_entrypoint.py b/src/flamingo/jobs/entrypoints/ludwig_entrypoint.py index 4b7a557a..d6ee69aa 100644 --- a/src/flamingo/jobs/entrypoints/ludwig_entrypoint.py +++ b/src/flamingo/jobs/entrypoints/ludwig_entrypoint.py @@ -3,6 +3,6 @@ from ludwig.api import LudwigModel -def main(config_path: str | Path, dataset_path: str | Path): +def run(config_path: str | Path, dataset_path: str | Path): model = LudwigModel(str(config_path)) model.train(dataset=str(dataset_path)) diff --git a/src/flamingo/jobs/entrypoints/simple_entrypoint.py b/src/flamingo/jobs/entrypoints/simple_entrypoint.py index 897002ca..8541adb9 100644 --- a/src/flamingo/jobs/entrypoints/simple_entrypoint.py +++ b/src/flamingo/jobs/entrypoints/simple_entrypoint.py @@ -1,5 +1,5 @@ from flamingo.jobs.configs import SimpleJobConfig -def main(config: SimpleJobConfig): +def run(config: SimpleJobConfig): print(f"The magic number is {config.magic_number}")