Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

Permalink
refactor mnp
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Friedowitz committed Jan 11, 2024
1 parent 9db31da commit 25a5785
Show file tree
Hide file tree
Showing 15 changed files with 85 additions and 62 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# flamingo

<p align="center">
<img src="https://github.com/mozilla-ai/flamingo/blob/main/assets/flamingo.png" width="450">
<img src="https://github.com/mozilla-ai/flamingo/blob/main/assets/flamingo.png" width="300">
</p>

## Installation
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ finetune = [
"accelerate==0.25.0",
"peft==0.7.1",
"trl==0.7.4",
"bitsandbytes==0.41.3",
"bitsandbytes==0.40.2",
]

ludwig = ["ludwig==0.9.1"]
Expand Down
8 changes: 4 additions & 4 deletions src/flamingo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def run_simple(config: str) -> None:
from flamingo.jobs.entrypoints import simple_entrypoint

config = SimpleJobConfig.from_yaml_file(config)
simple_entrypoint.main(config)
simple_entrypoint.run(config)


@run.command("finetuning")
Expand All @@ -28,7 +28,7 @@ def run_finetuning(config: str) -> None:
from flamingo.jobs.entrypoints import finetuning_entrypoint

config = FinetuningJobConfig.from_yaml_file(config)
finetuning_entrypoint.main(config)
finetuning_entrypoint.run(config)


@run.command("ludwig")
Expand All @@ -37,7 +37,7 @@ def run_finetuning(config: str) -> None:
def run_ludwig(config: str, dataset: str) -> None:
from flamingo.jobs.entrypoints import ludwig_entrypoint

ludwig_entrypoint.main(config, dataset)
ludwig_entrypoint.run(config, dataset)


@run.command("lm-harness")
Expand All @@ -47,7 +47,7 @@ def run_lm_harness(config: str) -> None:
from flamingo.jobs.entrypoints import lm_harness_entrypoint

config = LMHarnessJobConfig.from_yaml_file(config)
lm_harness_entrypoint.main(config)
lm_harness_entrypoint.run(config)


if __name__ == "__main__":
Expand Down
Empty file.
9 changes: 7 additions & 2 deletions src/flamingo/integrations/huggingface/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from .model_name_or_path import ModelNameOrCheckpointPath
from .quantization_config import QuantizationConfig
from .utils import is_valid_huggingface_model_name
from .trainer_config import TrainerConfig

__all__ = ["QuantizationConfig", "is_valid_huggingface_model_name"]
__all__ = [
"ModelNameOrCheckpointPath",
"QuantizationConfig",
"TrainerConfig",
]
49 changes: 49 additions & 0 deletions src/flamingo/integrations/huggingface/model_name_or_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from dataclasses import InitVar
from pathlib import Path

from huggingface_hub.utils import HFValidationError, validate_repo_id
from pydantic.dataclasses import dataclass


def is_valid_huggingface_model_name(s: str) -> bool:
    """Return True if `s` is a valid HuggingFace repo id.

    HuggingFace's own `validate_repo_id` signals an invalid name by raising
    `HFValidationError` instead of returning a boolean, so wrap the call and
    translate the outcome.

    Args:
        s: candidate model name to test.

    Returns:
        True if `s` passes HuggingFace's repo-id validation, False otherwise.
    """
    # Keep the try body minimal: only the call that can raise.
    try:
        validate_repo_id(s)
    except HFValidationError:
        return False
    return True


@dataclass
class ModelNameOrCheckpointPath:
    """Validated wrapper around a HuggingFace model name or a checkpoint path.

    `name` must be either an absolute filesystem path (interpreted as a local
    checkpoint) or a valid HuggingFace model name. After initialization,
    `checkpoint` holds the checkpoint path, or None when `name` refers to an
    HF model.

    Raises:
        ValueError: if no checkpoint was resolved and `name` is not a valid
            HuggingFace model name.
    """

    # Explicitly declared so instances can be destructured in `match` statements.
    __match_args__ = ("name", "checkpoint")

    name: str
    checkpoint: InitVar[str | None] = None

    def __post_init__(self, checkpoint):
        # Accept `Path` objects for convenience; normalize to the string form.
        if isinstance(self.name, Path):
            self.name = str(self.name)

        if checkpoint is not None:
            # Bug fix: an explicitly provided checkpoint used to be silently
            # discarded (the InitVar was never read). Honor it instead; the
            # default (None) path below is unchanged.
            self.checkpoint = checkpoint
        elif Path(self.name).is_absolute():
            # Absolute paths are treated as local checkpoints.
            self.checkpoint = self.name
        else:
            self.checkpoint = None

        if self.checkpoint is None and not is_valid_huggingface_model_name(self.name):
            raise ValueError(f"{self.name} is not a valid checkpoint path or HF model name")
3 changes: 2 additions & 1 deletion src/flamingo/integrations/huggingface/quantization_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype
from transformers import BitsAndBytesConfig

from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype


class QuantizationConfig(BaseFlamingoConfig):
"""Basic quantization settings to pass to training and evaluation jobs.
Expand Down
5 changes: 5 additions & 0 deletions src/flamingo/integrations/huggingface/trainer_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from flamingo.types import BaseFlamingoConfig


class TrainerConfig(BaseFlamingoConfig):
    """Configuration settings for a training run.

    Currently an empty placeholder — presumably fields mirroring HuggingFace
    `Trainer`/`TrainingArguments` options will be added; TODO confirm intent.
    """

    pass
16 changes: 0 additions & 16 deletions src/flamingo/integrations/huggingface/utils.py

This file was deleted.

11 changes: 11 additions & 0 deletions src/flamingo/integrations/wandb/wandb_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ def env_vars(self) -> dict[str, str]:
}
return {k: v for k, v in env_vars.items() if v is not None}

@classmethod
def from_env(cls) -> "WandbEnvironment":
    """Build a `WandbEnvironment` from the standard WANDB_* environment variables."""
    field_to_var = {
        "name": "WANDB_NAME",
        "project": "WANDB_PROJECT",
        "entity": "WANDB_ENTITY",
        "run_id": "WANDB_RUN_ID",
        "run_group": "WANDB_RUN_GROUP",
    }
    # Unset variables resolve to None, matching the explicit keyword form.
    return cls(**{field: os.environ.get(var) for field, var in field_to_var.items()})

@classmethod
def from_run(cls, run: Run) -> "WandbEnvironment":
"""Extract environment settings from a W&B Run object.
Expand Down
34 changes: 1 addition & 33 deletions src/flamingo/jobs/configs/lm_harness_config.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,15 @@
import datetime
from dataclasses import InitVar
from pathlib import Path
from typing import Any

from pydantic import root_validator, validator
from pydantic.dataclasses import dataclass

from flamingo.integrations.huggingface import QuantizationConfig
from flamingo.integrations.huggingface.utils import is_valid_huggingface_model_name
from flamingo.integrations.huggingface import ModelNameOrCheckpointPath, QuantizationConfig
from flamingo.integrations.wandb import WandbEnvironment
from flamingo.jobs.configs import BaseJobConfig
from flamingo.types import SerializableTorchDtype


@dataclass
class ModelNameOrCheckpointPath:
    """
    This class is explicitly used to validate if a string is
    a valid HuggingFace model or can be used as a checkpoint.
    Checkpoint will be automatically assigned if it's a valid checkpoint;
    it will be None if it's not valid.
    """

    # explicitly needed for matching
    __match_args__ = ("name", "checkpoint")

    name: str
    checkpoint: InitVar[str | None] = None

    def __post_init__(self, checkpoint):
        # Accept `Path` objects for `name`; normalize to the string form.
        if isinstance(self.name, Path):
            self.name = str(self.name)

        # NOTE(review): the `checkpoint` InitVar argument is never read here —
        # an explicitly passed checkpoint is silently discarded and the value
        # is always re-derived from `name`; confirm whether that is intended.
        if Path(self.name).is_absolute():
            # Absolute paths are treated as local checkpoints.
            self.checkpoint = self.name
        else:
            self.checkpoint = None

        # A non-checkpoint name must at least look like a valid HF repo id.
        if self.checkpoint is None and not is_valid_huggingface_model_name(self.name):
            raise (ValueError(f"{self.name} is not a valid checkpoint path or HF model name"))


class LMHarnessJobConfig(BaseJobConfig):
"""Configuration to run an lm-evaluation-harness evaluation job.
Expand Down
2 changes: 1 addition & 1 deletion src/flamingo/jobs/entrypoints/finetuning_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def train_func(config_data: dict):
wandb.finish()


def main(config: FinetuningJobConfig):
def run(config: FinetuningJobConfig):
print(f"Received job configuration: {config}")

run_config = RunConfig(
Expand Down
2 changes: 1 addition & 1 deletion src/flamingo/jobs/entrypoints/lm_harness_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def run_evaluation(config: LMHarnessJobConfig, model_to_load: str) -> None:
update_wandb_summary(config.wandb_env, formatted_results)


def main(config: LMHarnessJobConfig):
def run(config: LMHarnessJobConfig):
print(f"Received job configuration: {config}")

# Resolve path and ensure exists
Expand Down
2 changes: 1 addition & 1 deletion src/flamingo/jobs/entrypoints/ludwig_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
from ludwig.api import LudwigModel


def main(config_path: str | Path, dataset_path: str | Path):
def run(config_path: str | Path, dataset_path: str | Path):
    """Train a Ludwig model defined by `config_path` on the data at `dataset_path`."""
    LudwigModel(str(config_path)).train(dataset=str(dataset_path))
2 changes: 1 addition & 1 deletion src/flamingo/jobs/entrypoints/simple_entrypoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from flamingo.jobs.configs import SimpleJobConfig


def main(config: SimpleJobConfig):
def run(config: SimpleJobConfig):
    """Entrypoint for the 'simple' demo job: print the configured magic number."""
    print(f"The magic number is {config.magic_number}")

0 comments on commit 25a5785

Please sign in to comment.