This repository has been archived by the owner on Sep 24, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Sean Friedowitz
committed
Jan 10, 2024
1 parent
6c2a5a7
commit 661ede5
Showing
31 changed files
with
931 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
name: Dependent/Blocking PRs | ||
|
||
on: | ||
pull_request_target: | ||
types: [opened, edited, closed, reopened] | ||
|
||
jobs: | ||
check_dependencies: | ||
runs-on: ubuntu-latest | ||
name: Check Dependencies | ||
steps: | ||
- uses: gregsdennis/dependencies-action@main | ||
env: | ||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
name: PR Checks | ||
|
||
on: [push] | ||
|
||
jobs: | ||
pytest_ruff: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v3 | ||
|
||
- name: Set up Python 3.10 | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: "3.10" | ||
|
||
- name: Install test dependencies | ||
run: | | ||
pip install -r requirements/test.txt | ||
continue-on-error: true | ||
|
||
- name: Lint with Ruff | ||
run: | | ||
ruff --output-format=github . | ||
continue-on-error: false | ||
|
||
- name: Install full dependencies | ||
run: | | ||
pip install ".[test]" | ||
continue-on-error: true | ||
|
||
- name: Run unit tests | ||
run: | | ||
pytest | ||
continue-on-error: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
{ | ||
"python.analysis.importFormat": "absolute", | ||
"[python]": { | ||
"editor.defaultFormatter": "charliermarsh.ruff", | ||
"editor.formatOnSave": true, | ||
"editor.codeActionsOnSave": { | ||
"source.fixAll": "never", | ||
"source.organizeImports.ruff": "explicit" | ||
} | ||
}, | ||
"python.testing.pytestEnabled": true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
[build-system] | ||
requires = ["setuptools", "setuptools-scm"] | ||
build-backend = "setuptools.build_meta" | ||
|
||
[project] | ||
name = "flamingo" | ||
version = "0.1.0" | ||
description = "Ray-centric job library for training and evaluation" | ||
readme = "README.md" | ||
requires-python = ">=3.10,<3.11" | ||
|
||
dependencies = [ | ||
"click==8.1.7", | ||
"ray[default]==2.7.0", | ||
"torch==2.1.0", | ||
"scipy==1.10.1", | ||
"wandb==0.16.1", | ||
"pydantic-yaml==1.2.0", | ||
"pydantic==1.10.8", | ||
] | ||
|
||
[project.optional-dependencies] | ||
finetune = [ | ||
"datasets==2.15.0", | ||
"transformers==4.36.2", | ||
"accelerate==0.25.0", | ||
"peft==0.7.1", | ||
"trl==0.7.4", | ||
"bitsandbytes==0.41.3", | ||
] | ||
|
||
evaluate = ["lm-eval==0.4.0", "einops"] | ||
|
||
test = ["ruff==0.1.4", "pytest==7.4.3", "pytest-cov==4.1.0"] | ||
|
||
all = ["flamingo[finetune,evaluate,test]"] | ||
|
||
|
||
[tool.pytest.ini_options] | ||
addopts = "-v --cov src --no-cov-on-fail --disable-warnings" | ||
testpaths = ["tests"] | ||
|
||
[tool.ruff] | ||
target-version = "py310" | ||
exclude = [ | ||
".bzr", | ||
".direnv", | ||
".eggs", | ||
".git", | ||
".git-rewrite", | ||
".hg", | ||
".mypy_cache", | ||
".nox", | ||
".pants.d", | ||
".pytype", | ||
".ruff_cache", | ||
".svn", | ||
".tox", | ||
".venv", | ||
"__pypackages__", | ||
"_build", | ||
"buck-out", | ||
"build", | ||
"dist", | ||
"node_modules", | ||
"venv", | ||
] | ||
line-length = 100 | ||
|
||
|
||
[tool.ruff.lint] | ||
select = [ | ||
"E", # pycodestyle | ||
"F", # pyflakes | ||
"UP", # pyupgrade | ||
"I", # import sorting | ||
"N", # pep8 naming | ||
"ISC", # flake8 implicit string concat | ||
"PTH", # flake8-use-pathlib use Path library | ||
"PD", # pandas-vet | ||
] | ||
|
||
ignore = [ | ||
"D417", # documentation for every function parameter. | ||
"N806", # ignore uppercased variables | ||
"N812", # import as uppercased | ||
"N803", # lowercased args | ||
"N817", # imported as acronym | ||
"B023", # doesn't bind loop var, we do this a lot in torch | ||
"D100", # module-level docstrings | ||
"N805", # first param needs to be self; pydantic breaks this sometimes | ||
] | ||
|
||
# Avoid trying to fix some violations | ||
unfixable = ["B", "SIM", "TRY", "RUF"] | ||
|
||
[tool.ruff.format] | ||
quote-style = "double" | ||
indent-style = "space" | ||
line-ending = "auto" | ||
skip-magic-trailing-comma = false |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import click | ||
|
||
|
||
@click.group()
def main():
    # Root CLI group; it does nothing itself and only hosts the subcommands
    # registered on it below. (Kept as a comment rather than a docstring, since
    # click would surface a docstring as the group's help text.)
    pass


# Register the click command/group object itself (`simple.driver`), not the module
# that defines it.
# NOTE(review): `simple` is not imported in the visible part of this file — confirm
# an import for it exists, otherwise this line raises NameError at import time.
main.add_command(simple.driver)


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .quantization_config import QuantizationConfig | ||
from .utils import is_valid_huggingface_model_name | ||
|
||
__all__ = ["QuantizationConfig", "is_valid_huggingface_model_name"] |
25 changes: 25 additions & 0 deletions
25
src/flamingo/integrations/huggingface/quantization_config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype | ||
from transformers import BitsAndBytesConfig | ||
|
||
|
||
class QuantizationConfig(BaseFlamingoConfig):
    """Basic quantization settings to pass to training and evaluation jobs.

    Note that in order to use BitsAndBytes quantization on Ray,
    you must ensure that the runtime environment is installed with GPU support.
    This can be configured by setting `entrypoint_num_gpus > 0` when submitting a job
    to the cluster.
    """

    load_in_8bit: bool | None = None
    load_in_4bit: bool | None = None
    bnb_4bit_quant_type: str = "fp4"
    bnb_4bit_compute_dtype: SerializableTorchDtype = None

    def as_huggingface(self) -> BitsAndBytesConfig:
        """Build a HuggingFace `BitsAndBytesConfig` carrying these settings."""
        # Each field is forwarded under the identically named keyword argument.
        bnb_kwargs = {
            "load_in_4bit": self.load_in_4bit,
            "load_in_8bit": self.load_in_8bit,
            "bnb_4bit_compute_dtype": self.bnb_4bit_compute_dtype,
            "bnb_4bit_quant_type": self.bnb_4bit_quant_type,
        }
        return BitsAndBytesConfig(**bnb_kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from huggingface_hub.utils import HFValidationError, validate_repo_id | ||
|
||
|
||
def is_valid_huggingface_model_name(s: str) -> bool:
    """Report whether `s` passes HuggingFace's repo-id validation.

    `validate_repo_id` signals an invalid name by raising instead of returning
    a value, so the exception is translated into a boolean here.

    Args:
        s: string to test.

    Returns:
        True if `validate_repo_id(s)` does not raise `HFValidationError`,
        False if it does.
    """
    try:
        validate_repo_id(s)
    except HFValidationError:
        return False
    else:
        return True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from .wandb_environment import WandbEnvironment # noqa: I001 | ||
from .wandb_mixin import WandbEnvironmentMixin | ||
from .utils import get_wandb_summary, update_wandb_summary | ||
|
||
__all__ = [ | ||
"WandbEnvironment", | ||
"WandbEnvironmentMixin", | ||
"get_wandb_summary", | ||
"update_wandb_summary", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from typing import Any | ||
|
||
from flamingo.integrations.wandb import WandbEnvironment | ||
|
||
import wandb | ||
from wandb.apis.public import Run | ||
|
||
|
||
def get_wandb_summary(env: WandbEnvironment) -> dict[str, Any]:
    """Return a plain-dict copy of the summary attached to the W&B run for `env`."""
    wandb_run = _resolve_wandb_run(env)
    summary = dict(wandb_run.summary)
    return summary
|
||
|
||
def update_wandb_summary(env: WandbEnvironment, metrics: dict[str, Any]) -> None:
    """Merge `metrics` into the summary of the W&B run for `env`, then update the run."""
    wandb_run = _resolve_wandb_run(env)
    wandb_run.summary.update(metrics)
    # `update()` is called on the run itself after its summary has been modified.
    wandb_run.update()
|
||
|
||
def _resolve_wandb_run(env: WandbEnvironment) -> Run: | ||
"""Resolve a WandB run object from the provided environment settings. | ||
An exception is raised if a Run cannot be found, | ||
or if multiple runs exist in scope with the same name. | ||
""" | ||
api = wandb.Api() | ||
base_path = "/".join(x for x in (env.entity, env.project) if x) | ||
if env.run_id is not None: | ||
full_path = f"{base_path}/{env.run_id}" | ||
return api.run(full_path) | ||
else: | ||
match [run for run in api.runs(base_path) if run.name == env.name]: | ||
case []: | ||
raise RuntimeError(f"No WandB runs found at {base_path}/{env.name}") | ||
case [Run(), _]: | ||
raise RuntimeError(f"Multiple WandB runs found at {base_path}/{env.name}") | ||
case [Run()] as mr: | ||
# we have a single one, hurray | ||
return mr[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import os | ||
import warnings | ||
|
||
from flamingo.types import BaseFlamingoConfig | ||
from pydantic import Extra, root_validator | ||
|
||
from wandb.apis.public import Run | ||
|
||
|
||
class WandbEnvironment(BaseFlamingoConfig):
    """Settings required to log to a W&B run.

    The fields on this class map to the environment variables
    that are used to control the W&B logging locations.

    The `name` and `project` are required as they are the minimum information
    required to identify a run. The `name` is the human-readable name that appears in the W&B UI.
    `name` is different than the `run_id` which must be unique within a project.
    Although the `name` is not required to be unique, it is generally best practice to use a
    unique and descriptive name to later identify the run.
    """

    class Config:
        extra = Extra.forbid  # Error on extra kwargs

    # Positional order used when this class appears in a `match` class pattern.
    __match_args__ = ("name", "project", "run_id", "run_group", "entity")

    name: str
    project: str
    run_id: str | None = None
    run_group: str | None = None
    entity: str | None = None

    @root_validator(pre=True)
    def warn_missing_api_key(cls, values):
        """Warn when `WANDB_API_KEY` is unset or empty; `values` pass through untouched."""
        api_key = os.environ.get("WANDB_API_KEY")
        if not api_key:
            message = (
                "Cannot find `WANDB_API_KEY` in your environment. "
                "Tracking will fail if a default key does not exist on the Ray cluster."
            )
            warnings.warn(message)
        return values

    @property
    def env_vars(self) -> dict[str, str]:
        """Return the `WANDB_*` environment variables for these settings, omitting unset ones."""
        # WandB w/ HuggingFace is weird. You can specify the run name inline,
        # but the rest must be injected as environment variables
        candidates = {
            "WANDB_NAME": self.name,
            "WANDB_PROJECT": self.project,
            "WANDB_RUN_ID": self.run_id,
            "WANDB_RUN_GROUP": self.run_group,
            "WANDB_ENTITY": self.entity,
            "WANDB_API_KEY": os.environ.get("WANDB_API_KEY"),
        }
        return {key: value for key, value in candidates.items() if value is not None}

    @classmethod
    def from_run(cls, run: Run) -> "WandbEnvironment":
        """Extract environment settings from a W&B Run object.

        Useful when listing runs from the W&B API and extracting their settings for a job.
        """
        # TODO: Can we get the run group from this when it exists?
        settings = {
            "name": run.name,
            "project": run.project,
            "entity": run.entity,
            "run_id": run.id,
        }
        return cls(**settings)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from flamingo.integrations.wandb import WandbEnvironment | ||
from flamingo.types import BaseFlamingoConfig | ||
|
||
|
||
class WandbEnvironmentMixin(BaseFlamingoConfig):
    """Mixin for a config that contains W&B environment settings."""

    wandb_env: WandbEnvironment | None = None

    @property
    def env_vars(self) -> dict[str, str]:
        """Return the W&B environment variables, or an empty dict without a W&B env."""
        if not self.wandb_env:
            return {}
        return self.wandb_env.env_vars

    @property
    def wandb_name(self) -> str | None:
        """Return the W&B run name, if it exists."""
        if not self.wandb_env:
            return None
        return self.wandb_env.name

    @property
    def wandb_project(self) -> str | None:
        """Return the W&B project name, if it exists."""
        if not self.wandb_env:
            return None
        return self.wandb_env.project
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from .base_config import BaseJobConfig | ||
from .evaluation_config import EvaluationJobConfig | ||
from .finetuning_config import FinetuningJobConfig | ||
from .simple_config import SimpleJobConfig | ||
|
||
__all__ = [ | ||
"BaseJobConfig", | ||
"SimpleJobConfig", | ||
"FinetuningJobConfig", | ||
"EvaluationJobConfig", | ||
] |
Oops, something went wrong.