reorganize repo layout
Sean Friedowitz committed Jan 15, 2024
1 parent 47d8d0b commit 96aa29b
Showing 26 changed files with 77 additions and 121 deletions.
5 changes: 0 additions & 5 deletions .gitignore
@@ -161,8 +161,3 @@ cython_debug/

# Ruff
.ruff_cache


# ignore local wandb cache files. Not perfect
**/wandb/*.log
**/wandb/*run*
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "flamingo"
version = "0.1.0"
description = "Ray-centric job library for training and evaluation"
description = "Ray-centric job library for ML training and evaluation"
readme = "README.md"
requires-python = ">=3.10,<3.11"

@@ -40,7 +40,6 @@ all = ["flamingo[finetune,ludwig,evaluate,test]"]
[project.scripts]
flamingo = "flamingo.cli:main"


[tool.pytest.ini_options]
addopts = "-v --cov src --no-cov-on-fail --disable-warnings"
testpaths = ["tests"]
@@ -99,6 +98,9 @@ ignore = [
# Avoid trying to fix some violations
unfixable = ["B", "SIM", "TRY", "RUF"]

[tool.ruff.lint.isort]
known-first-party = ["flamingo"]

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
19 changes: 8 additions & 11 deletions src/flamingo/cli.py
@@ -14,40 +14,37 @@ def run():
@run.command("simple")
@click.option("--config", type=str)
def run_simple(config: str) -> None:
from flamingo.jobs.configs import SimpleJobConfig
from flamingo.jobs.entrypoints import simple_entrypoint
from flamingo.jobs import SimpleJobConfig, simple_job

config = SimpleJobConfig.from_yaml_file(config)
simple_entrypoint.run(config)
simple_job.run(config)


@run.command("finetuning")
@click.option("--config", type=str)
def run_finetuning(config: str) -> None:
from flamingo.jobs.configs import FinetuningJobConfig
from flamingo.jobs.entrypoints import finetuning_entrypoint
from flamingo.jobs import FinetuningJobConfig, finetuning_job

config = FinetuningJobConfig.from_yaml_file(config)
finetuning_entrypoint.run(config)
finetuning_job.run(config)


@run.command("ludwig")
@click.option("--config", type=str)
@click.option("--dataset", type=str)
def run_ludwig(config: str, dataset: str) -> None:
from flamingo.jobs.entrypoints import ludwig_entrypoint
from flamingo.jobs import ludwig_job

ludwig_entrypoint.run(config, dataset)
ludwig_job.run(config, dataset)


@run.command("lm-harness")
@click.option("--config", type=str)
def run_lm_harness(config: str) -> None:
from flamingo.jobs.configs import LMHarnessJobConfig
from flamingo.jobs.entrypoints import lm_harness_entrypoint
from flamingo.jobs import LMHarnessJobConfig, lm_harness_job

config = LMHarnessJobConfig.from_yaml_file(config)
lm_harness_entrypoint.run(config)
lm_harness_job.run(config)


if __name__ == "__main__":
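For reference, a minimal sketch of what these reorganized commands do under the hood, mirroring run_simple above (the config path is hypothetical):

from flamingo.jobs import SimpleJobConfig, simple_job

# Load a YAML job config and hand it to the job module's run() entrypoint.
config = SimpleJobConfig.from_yaml_file("configs/simple.yaml")  # hypothetical path
simple_job.run(config)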
4 changes: 2 additions & 2 deletions src/flamingo/integrations/huggingface/__init__.py
@@ -1,9 +1,9 @@
from .model_name_or_path import ModelNameOrCheckpointPath
from .quantization_config import QuantizationConfig
from .trainer_config import TrainerConfig
from .utils import is_valid_huggingface_repo

__all__ = [
"ModelNameOrCheckpointPath",
"QuantizationConfig",
"TrainerConfig",
"is_valid_huggingface_repo",
]
34 changes: 0 additions & 34 deletions src/flamingo/integrations/huggingface/model_name_or_path.py

This file was deleted.

@@ -9,7 +9,7 @@ class QuantizationConfig(BaseFlamingoConfig):
Note that in order to use BitsAndBytes quantization on Ray,
you must ensure that the runtime environment is installed with GPU support.
This can be configured by setting the `entrypoint_num_gpus > 0` when submitting a job
to the cluster, e.g.,
to the cluster.
"""

load_in_8bit: bool | None = None
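For context, a hedged sketch (not from this repo) of setting `entrypoint_num_gpus > 0` when submitting a job to a Ray cluster, as the docstring above describes; the cluster address and entrypoint command are illustrative:

from ray.job_submission import JobSubmissionClient

client = JobSubmissionClient("http://127.0.0.1:8265")  # illustrative cluster address
client.submit_job(
    entrypoint="flamingo run finetuning --config finetuning.yaml",
    entrypoint_num_gpus=1,  # > 0 so BitsAndBytes quantization gets GPU support
)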
2 changes: 1 addition & 1 deletion src/flamingo/integrations/huggingface/utils.py
@@ -1,7 +1,7 @@
from huggingface_hub.utils import HFValidationError, validate_repo_id


def is_valid_huggingface_model_name(s: str):
def is_valid_huggingface_repo(s: str):
"""
Simple test to check if an HF model is valid using HuggingFace's tools.
Sadly, theirs throws an exception and has no return.
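An illustrative use of the renamed helper, assuming it returns a bool as the docstring implies (the repo names are made up):

from flamingo.integrations.huggingface import is_valid_huggingface_repo

assert is_valid_huggingface_repo("mistralai/Mistral-7B-v0.1")  # well-formed repo id
assert not is_valid_huggingface_repo("not/a/valid//repo")  # extra slashes fail HF validation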
5 changes: 4 additions & 1 deletion src/flamingo/integrations/wandb/__init__.py
@@ -1,8 +1,11 @@
from .wandb_environment import WandbEnvironment # noqa: I001
# ruff: noqa
from .wandb_artifact_link import WandbArtifactLink
from .wandb_environment import WandbEnvironment
from .utils import get_wandb_summary, update_wandb_summary

__all__ = [
"WandbEnvironment",
"WandbArtifactLink",
"get_wandb_summary",
"update_wandb_summary",
]
16 changes: 16 additions & 0 deletions src/flamingo/integrations/wandb/wandb_artifact_link.py
@@ -0,0 +1,16 @@
from flamingo.types import BaseFlamingoConfig


class WandbArtifactLink(BaseFlamingoConfig):
"""Data required to retrieve an artifact from W&B."""

name: str
alias: str | None = None
project: str | None = None
entity: str | None = None

def artifact_path(self) -> str:
path = "/".join(x for x in [self.entity, self.project, self.name] if x is not None)
if self.alias:
path = f"{path}:{self.alias}"
return path
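A quick check of how artifact_path() assembles the W&B path (the names are illustrative):

from flamingo.integrations.wandb import WandbArtifactLink

link = WandbArtifactLink(name="my-model", alias="latest", project="demo", entity="acme")
assert link.artifact_path() == "acme/demo/my-model:latest"  # entity/project/name:alias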
7 changes: 2 additions & 5 deletions src/flamingo/integrations/wandb/wandb_environment.py
@@ -1,14 +1,14 @@
import os
import warnings

from pydantic import Extra, root_validator
from pydantic import root_validator
from wandb.apis.public import Run

from flamingo.types import BaseFlamingoConfig


class WandbEnvironment(BaseFlamingoConfig):
"""Settings required to log to a W&B run.
"""Settings to associate with a W&B run.
The fields on this class map to the environment variables
that are used to control the W&B logging locations.
@@ -20,9 +20,6 @@ class WandbEnvironment(BaseFlamingoConfig):
unique and descriptive name to later identify the run.
"""

class Config:
extra = Extra.forbid # Error on extra kwargs

__match_args__ = ("name", "project", "run_id", "run_group", "entity")

name: str | None = None
1 change: 1 addition & 0 deletions src/flamingo/jobs/__init__.py
@@ -0,0 +1 @@
from .configs import * # noqa: F403
10 changes: 2 additions & 8 deletions src/flamingo/jobs/configs/__init__.py
@@ -1,12 +1,6 @@
from .base_config import BaseJobConfig
from .finetuning_config import FinetuningJobConfig
from .lm_harness_config import LMHarnessJobConfig, ModelNameOrCheckpointPath
from .lm_harness_config import LMHarnessJobConfig
from .simple_config import SimpleJobConfig

__all__ = [
"BaseJobConfig",
"SimpleJobConfig",
"FinetuningJobConfig",
"LMHarnessJobConfig",
"ModelNameOrCheckpointPath",
]
__all__ = ["BaseJobConfig", "SimpleJobConfig", "FinetuningJobConfig", "LMHarnessJobConfig"]
11 changes: 6 additions & 5 deletions src/flamingo/jobs/configs/finetuning_config.py
@@ -2,27 +2,28 @@
from pydantic import validator
from ray.train import ScalingConfig

from flamingo.integrations.huggingface import QuantizationConfig
from flamingo.integrations.huggingface.trainer_config import TrainerConfig
from flamingo.integrations.huggingface.utils import is_valid_huggingface_model_name
from flamingo.integrations.huggingface import QuantizationConfig, TrainerConfig
from flamingo.integrations.huggingface.utils import is_valid_huggingface_repo
from flamingo.integrations.wandb import WandbEnvironment
from flamingo.jobs.configs import BaseJobConfig


class FinetuningJobConfig(BaseJobConfig):
"""Configuration to submit an LLM finetuning job."""

model: str
dataset: str
# TODO: Add dataset config back when dataset loading is implemented
tokenizer: str | None = None
trainer: TrainerConfig | None = None
lora: LoraConfig | None = None # TODO: Create our own config type
quantization: QuantizationConfig | None = None
tracking: WandbEnvironment | None = None
scaling: ScalingConfig | None = None # TODO: Create our own config type
storage_path: str | None = None

@validator("model")
def _validate_model_name(cls, v):
if is_valid_huggingface_model_name(v):
if is_valid_huggingface_repo(v):
return v
else:
raise ValueError(f"`{v}` is not a valid HuggingFace model name.")
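A hedged sketch of this validator in action, assuming BaseJobConfig adds no other required fields (the model and dataset values are illustrative):

from flamingo.jobs import FinetuningJobConfig

config = FinetuningJobConfig(model="mistralai/Mistral-7B-v0.1", dataset="imdb")  # passes validation
FinetuningJobConfig(model="not a repo", dataset="imdb")  # raises ValueError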
23 changes: 9 additions & 14 deletions src/flamingo/jobs/configs/lm_harness_config.py
@@ -1,9 +1,9 @@
import datetime
from pathlib import Path

from pydantic import validator

from flamingo.integrations.huggingface import ModelNameOrCheckpointPath, QuantizationConfig
from flamingo.integrations.huggingface import QuantizationConfig, is_valid_huggingface_repo
from flamingo.integrations.wandb import WandbArtifactLink, WandbEnvironment
from flamingo.jobs.configs import BaseJobConfig
from flamingo.types import SerializableTorchDtype

@@ -18,26 +18,21 @@ class LMHarnessJobConfig(BaseJobConfig):
which will take precedence over the W&B checkpoint path.
"""

class Config:
validate_assignment = True

model: str | WandbArtifactLink
tasks: list[str]
batch_size: int | None = None
num_fewshot: int | None = None
limit: int | float | None = None
trust_remote_code: bool = False
torch_dtype: SerializableTorchDtype = None
model_name_or_path: str | Path | ModelNameOrCheckpointPath | None = None
quantization: QuantizationConfig | None = None
tracking: WandbEnvironment | None = None
num_cpus: int = 1
num_gpus: int = 1
timeout: datetime.timedelta | None = None

@validator("model_name_or_path", pre=True, always=True)
def _validate_model_name_or_path(cls, v):
if isinstance(v, dict):
return ModelNameOrCheckpointPath(**v)
elif v is None:
return None
else:
return ModelNameOrCheckpointPath(name=v)
@validator("model", pre=True, always=True)
def _validate_model_artifact(cls, v):
if isinstance(v, str) and not is_valid_huggingface_repo(v):
raise ValueError(f"{v} is not a valid HuggingFace model name.")
return v
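A hedged sketch of the two forms the `model` field accepts after this change, again assuming BaseJobConfig adds no other required fields (the values are illustrative):

from flamingo.integrations.wandb import WandbArtifactLink
from flamingo.jobs import LMHarnessJobConfig

# Either a plain HuggingFace repo id...
by_repo = LMHarnessJobConfig(model="mistralai/Mistral-7B-v0.1", tasks=["hellaswag"])
# ...or a link to a W&B artifact holding a checkpoint.
by_artifact = LMHarnessJobConfig(
    model=WandbArtifactLink(name="finetuned-model", alias="latest"),
    tasks=["hellaswag"],
)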
Empty file.
@@ -54,8 +54,8 @@ def get_datasets(config: FinetuningJobConfig) -> DatasetDict:

def get_model(config: FinetuningJobConfig) -> PreTrainedModel:
device_map, bnb_config = None, None
if config.quantization_config:
bnb_config = config.quantization_config.as_huggingface()
if config.quantization is not None:
bnb_config = config.quantization.as_huggingface()
# When quantization is enabled, model must all be on same GPU to work with DDP
# If a device_map is not specified we will get accelerate errors downstream
# Reference: https://github.com/huggingface/accelerate/issues/1840#issuecomment-1683105994
@@ -5,8 +5,9 @@
from lm_eval.models.huggingface import HFLM
from peft import PeftConfig

from flamingo.integrations.huggingface import ModelNameOrCheckpointPath
from flamingo.integrations.wandb import get_wandb_summary, update_wandb_summary
from flamingo.jobs.configs import LMHarnessJobConfig, ModelNameOrCheckpointPath
from flamingo.jobs.configs import LMHarnessJobConfig


def resolve_model_or_path(config: LMHarnessJobConfig) -> str:
4 changes: 2 additions & 2 deletions src/flamingo/types.py
@@ -10,7 +10,7 @@


class BaseFlamingoConfig(BaseModel):
"""Base class for all Pydnatic configs in the library.
"""Base class for all Pydantic configs in the library.
Defines some common settings used by all subclasses.
"""
@@ -26,7 +26,7 @@ class Config:
}

@validator("*", pre=True)
def validate_serializable_dtype(cls, x: Any, field: ModelField) -> Any: # noqa: N805
def validate_serializable_dtype(cls, x: Any, field: ModelField) -> Any:
"""Extract the torch.dtype corresponding to a string value, else return the value.
This is a Pydantic-specific construct that is run on all fields
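A sketch of the dtype coercion that docstring describes; the Example subclass is hypothetical:

import torch

from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype

class Example(BaseFlamingoConfig):
    torch_dtype: SerializableTorchDtype = None

# The validator resolves the string "bfloat16" to torch.bfloat16.
assert Example(torch_dtype="bfloat16").torch_dtype == torch.bfloat16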
11 changes: 0 additions & 11 deletions src/flamingo/utils.py

This file was deleted.

7 changes: 4 additions & 3 deletions tests/conftest.py
@@ -7,8 +7,9 @@
from unittest import mock

import pytest
from flamingo.integrations.wandb.wandb_environment import WandbEnvironment
from flamingo.jobs.configs import LMHarnessJobConfig

from flamingo.integrations.wandb import WandbEnvironment
from flamingo.jobs import LMHarnessJobConfig


@pytest.fixture(autouse=True, scope="function")
@@ -43,11 +44,11 @@ def generator(**kwargs) -> WandbEnvironment:
def default_lm_harness_config():
def generator(**kwargs) -> LMHarnessJobConfig:
mine = {
"model": "mistral",
"tasks": ["task1", "task2"],
"num_fewshot": 5,
"batch_size": 16,
"torch_dtype": "bfloat16",
"model_name_or_path": None,
"quantization": None,
"timeout": 3600,
}
1 change: 1 addition & 0 deletions tests/integrations/huggingface/test_quantization_config.py
@@ -1,5 +1,6 @@
import pytest
import torch

from flamingo.integrations.huggingface import QuantizationConfig


3 changes: 2 additions & 1 deletion tests/integrations/wandb/test_wandb_environment.py
@@ -1,7 +1,8 @@
import pytest
from flamingo.integrations.wandb import WandbEnvironment
from pydantic import ValidationError

from flamingo.integrations.wandb import WandbEnvironment


def test_env_vars(default_wandb_env):
env_vars = default_wandb_env().env_vars