Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

Permalink
copy over files
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Friedowitz committed Jan 10, 2024
1 parent 6c2a5a7 commit 661ede5
Show file tree
Hide file tree
Showing 31 changed files with 931 additions and 0 deletions.
14 changes: 14 additions & 0 deletions .github/workflows/dependent_prs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
name: Dependent/Blocking PRs

on:
pull_request_target:
types: [opened, edited, closed, reopened]

jobs:
check_dependencies:
runs-on: ubuntu-latest
name: Check Dependencies
steps:
- uses: gregsdennis/dependencies-action@main
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
34 changes: 34 additions & 0 deletions .github/workflows/pr_checks.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
name: PR Checks

on: [push]

jobs:
pytest_ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

- name: Set up Python 3.10
uses: actions/setup-python@v4
with:
python-version: "3.10"

- name: Install test dependencies
run: |
pip install -r requirements/test.txt
continue-on-error: true

- name: Lint with Ruff
run: |
ruff --output-format=github .
continue-on-error: false

- name: Install full dependencies
run: |
pip install ".[test]"
continue-on-error: true

- name: Run unit tests
run: |
pytest
continue-on-error: false
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,11 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Ruff
.ruff_cache


# ignore local wandb cache files. Not perfect
**/wandb/*.log
**/wandb/*run*
12 changes: 12 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"python.analysis.importFormat": "absolute",
"[python]": {
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.fixAll": "never",
"source.organizeImports.ruff": "explicit"
}
},
"python.testing.pytestEnabled": true
}
101 changes: 101 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
[build-system]
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"

[project]
name = "flamingo"
version = "0.1.0"
description = "Ray-centric job library for training and evaluation"
readme = "README.md"
requires-python = ">=3.10,<3.11"

dependencies = [
"click==8.1.7",
"ray[default]==2.7.0",
"torch==2.1.0",
"scipy==1.10.1",
"wandb==0.16.1",
"pydantic-yaml==1.2.0",
"pydantic==1.10.8",
]

[project.optional-dependencies]
finetune = [
"datasets==2.15.0",
"transformers==4.36.2",
"accelerate==0.25.0",
"peft==0.7.1",
"trl==0.7.4",
"bitsandbytes==0.41.3",
]

evaluate = ["lm-eval==0.4.0", "einops"]

test = ["ruff==0.1.4", "pytest==7.4.3", "pytest-cov==4.1.0"]

all = ["flamingo[finetune,evaluate,test]"]


[tool.pytest.ini_options]
addopts = "-v --cov src --no-cov-on-fail --disable-warnings"
testpaths = ["tests"]

[tool.ruff]
target-version = "py310"
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".mypy_cache",
".nox",
".pants.d",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"venv",
]
line-length = 100


[tool.ruff.lint]
select = [
"E", # pycodestyle
"F", # pyflakes
"UP", # pyupgrade
"I", # import sorting
"N", # pep8 naming
"ISC", # flake8 implicit string concat
"PTH", # flake8-use-pathlib use Path library
"PD", # pandas-vet
]

ignore = [
"D417", # documentation for every function parameter.
"N806", # ignore uppercased variables
"N812", # import as uppercased
"N803", # lowercased args
"N817", # imported as acryonym
"B023", # doesn't bind loop var, we do this a lot in torch
"D100", # module-level docstrings
"N805", # first param needs to be self; pydantic breaks this sometimes
]

# Avoid trying to fix some violations
unfixable = ["B", "SIM", "TRY", "RUF"]

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
line-ending = "auto"
skip-magic-trailing-comma = false
Empty file added src/flamingo/__init__.py
Empty file.
14 changes: 14 additions & 0 deletions src/flamingo/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import click


@click.group()
def main():
    # Root click command group for this CLI; it performs no work of its own.
    # Kept as a comment (not a docstring) so click's generated help text is unchanged.
    return None


# need to add the group / command function itself, not the module
# NOTE(review): `simple` is not imported or defined anywhere in this file as shown,
# so this line would raise NameError as soon as the module is imported.
# Presumably a job entrypoint module exposing a click command named `driver` is
# meant to be imported above — TODO confirm which module and add the import.
main.add_command(simple.driver)


# Invoke the click group when this file is executed directly as a script.
if __name__ == "__main__":
    main()
4 changes: 4 additions & 0 deletions src/flamingo/integrations/huggingface/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .quantization_config import QuantizationConfig
from .utils import is_valid_huggingface_model_name

__all__ = ["QuantizationConfig", "is_valid_huggingface_model_name"]
25 changes: 25 additions & 0 deletions src/flamingo/integrations/huggingface/quantization_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype
from transformers import BitsAndBytesConfig


class QuantizationConfig(BaseFlamingoConfig):
    """Quantization options shared by training and evaluation jobs.

    The fields mirror the matching keyword arguments of the HuggingFace
    `BitsAndBytesConfig`, which `as_huggingface` builds from them.

    Note that in order to use BitsAndBytes quantization on Ray,
    the runtime environment must be installed with GPU support.
    This can be configured by setting `entrypoint_num_gpus > 0` when submitting
    a job to the cluster.
    """

    # All four fields are forwarded unchanged to `BitsAndBytesConfig`.
    load_in_8bit: bool | None = None
    load_in_4bit: bool | None = None
    bnb_4bit_quant_type: str = "fp4"
    bnb_4bit_compute_dtype: SerializableTorchDtype = None

    def as_huggingface(self) -> BitsAndBytesConfig:
        """Build the equivalent HuggingFace `BitsAndBytesConfig` from this config."""
        bnb_kwargs = {
            "load_in_4bit": self.load_in_4bit,
            "load_in_8bit": self.load_in_8bit,
            "bnb_4bit_compute_dtype": self.bnb_4bit_compute_dtype,
            "bnb_4bit_quant_type": self.bnb_4bit_quant_type,
        }
        return BitsAndBytesConfig(**bnb_kwargs)
16 changes: 16 additions & 0 deletions src/flamingo/integrations/huggingface/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from huggingface_hub.utils import HFValidationError, validate_repo_id


def is_valid_huggingface_model_name(s: str) -> bool:
    """Report whether `s` passes HuggingFace's repo id validation.

    HuggingFace's `validate_repo_id` signals failure by raising and returns
    nothing on success, so this wrapper converts that into a boolean.

    Args:
        s: string to test.

    Returns:
        True if `validate_repo_id` accepts the string, False if it raises
        `HFValidationError`.
    """
    try:
        validate_repo_id(s)
    except HFValidationError:
        return False
    return True
10 changes: 10 additions & 0 deletions src/flamingo/integrations/wandb/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from .wandb_environment import WandbEnvironment # noqa: I001
from .wandb_mixin import WandbEnvironmentMixin
from .utils import get_wandb_summary, update_wandb_summary

__all__ = [
"WandbEnvironment",
"WandbEnvironmentMixin",
"get_wandb_summary",
"update_wandb_summary",
]
41 changes: 41 additions & 0 deletions src/flamingo/integrations/wandb/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Any

from flamingo.integrations.wandb import WandbEnvironment

import wandb
from wandb.apis.public import Run


def get_wandb_summary(env: WandbEnvironment) -> dict[str, Any]:
    """Return the summary of the W&B run identified by `env` as a plain dict."""
    summary = _resolve_wandb_run(env).summary
    return dict(summary)


def update_wandb_summary(env: WandbEnvironment, metrics: dict[str, Any]) -> None:
    """Merge `metrics` into the summary of the run identified by `env`.

    After updating the summary mapping, `update()` is called on the run itself.
    """
    target_run = _resolve_wandb_run(env)
    summary = target_run.summary
    summary.update(metrics)
    target_run.update()


def _resolve_wandb_run(env: WandbEnvironment) -> Run:
    """Resolve a WandB run object from the provided environment settings.

    When `env.run_id` is set, the run is fetched directly from the path
    `<entity>/<project>/<run_id>` (the entity is omitted when unset).
    Otherwise every run under `<entity>/<project>` is listed and filtered
    by `env.name`.

    Args:
        env: W&B environment settings identifying the run.

    Returns:
        The single W&B `Run` matching the environment.

    Raises:
        RuntimeError: If no run with the requested name is found, or if more
            than one run in scope shares that name.
    """
    api = wandb.Api()
    # The entity is optional; falsy components are dropped from the path.
    base_path = "/".join(x for x in (env.entity, env.project) if x)
    if env.run_id is not None:
        return api.run(f"{base_path}/{env.run_id}")

    matching_runs = [run for run in api.runs(base_path) if run.name == env.name]
    if not matching_runs:
        raise RuntimeError(f"No WandB runs found at {base_path}/{env.name}")
    if len(matching_runs) > 1:
        # Any count above one is ambiguous. A fixed two-element pattern would let
        # three or more matches fall through and implicitly return None.
        raise RuntimeError(f"Multiple WandB runs found at {base_path}/{env.name}")
    return matching_runs[0]
69 changes: 69 additions & 0 deletions src/flamingo/integrations/wandb/wandb_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os
import warnings

from flamingo.types import BaseFlamingoConfig
from pydantic import Extra, root_validator

from wandb.apis.public import Run


class WandbEnvironment(BaseFlamingoConfig):
    """Settings that determine where a W&B run is logged.

    Each field corresponds to one of the `WANDB_*` environment variables
    produced by `env_vars`.

    `name` and `project` are mandatory because they are the minimum needed to
    identify a run. `name` is the human-readable label shown in the W&B UI and is
    distinct from `run_id`, which must be unique within a project.
    A `name` does not have to be unique, but a unique, descriptive one makes the
    run easier to find later.
    """

    class Config:
        extra = Extra.forbid  # Reject unknown keyword arguments

    __match_args__ = ("name", "project", "run_id", "run_group", "entity")

    name: str
    project: str
    run_id: str | None = None
    run_group: str | None = None
    entity: str | None = None

    @root_validator(pre=True)
    def warn_missing_api_key(cls, values):
        # Only warns; the incoming values are always passed through unchanged.
        api_key = os.environ.get("WANDB_API_KEY", None)
        if not api_key:
            warnings.warn(
                "Cannot find `WANDB_API_KEY` in your environment. "
                "Tracking will fail if a default key does not exist on the Ray cluster."
            )
        return values

    @property
    def env_vars(self) -> dict[str, str]:
        # WandB w/ HuggingFace is weird. The run name can be given inline,
        # but everything else has to be injected through environment variables.
        candidates = (
            ("WANDB_NAME", self.name),
            ("WANDB_PROJECT", self.project),
            ("WANDB_RUN_ID", self.run_id),
            ("WANDB_RUN_GROUP", self.run_group),
            ("WANDB_ENTITY", self.entity),
            ("WANDB_API_KEY", os.environ.get("WANDB_API_KEY", None)),
        )
        # Unset (None) settings are left out of the resulting mapping.
        return {key: value for key, value in candidates if value is not None}

    @classmethod
    def from_run(cls, run: Run) -> "WandbEnvironment":
        """Build environment settings from an existing W&B Run object.

        Handy when runs are listed through the W&B API and their settings are
        needed for a job. The run group is not populated here.
        """
        # TODO: Can we get the run group from this when it exists?
        run_settings = {
            "name": run.name,
            "project": run.project,
            "entity": run.entity,
            "run_id": run.id,
        }
        return cls(**run_settings)
22 changes: 22 additions & 0 deletions src/flamingo/integrations/wandb/wandb_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from flamingo.integrations.wandb import WandbEnvironment
from flamingo.types import BaseFlamingoConfig


class WandbEnvironmentMixin(BaseFlamingoConfig):
    """Config mixin carrying optional W&B environment settings.

    Every accessor falls back to an empty value when `wandb_env` is not set.
    """

    wandb_env: WandbEnvironment | None = None

    @property
    def env_vars(self) -> dict[str, str]:
        """Return the W&B environment variables, or an empty dict without `wandb_env`."""
        if not self.wandb_env:
            return {}
        return self.wandb_env.env_vars

    @property
    def wandb_name(self) -> str | None:
        """Return the W&B run name, or None without `wandb_env`."""
        if not self.wandb_env:
            return None
        return self.wandb_env.name

    @property
    def wandb_project(self) -> str | None:
        """Return the W&B project name, or None without `wandb_env`."""
        if not self.wandb_env:
            return None
        return self.wandb_env.project
11 changes: 11 additions & 0 deletions src/flamingo/jobs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from .base_config import BaseJobConfig
from .evaluation_config import EvaluationJobConfig
from .finetuning_config import FinetuningJobConfig
from .simple_config import SimpleJobConfig

__all__ = [
"BaseJobConfig",
"SimpleJobConfig",
"FinetuningJobConfig",
"EvaluationJobConfig",
]
Loading

0 comments on commit 661ede5

Please sign in to comment.