Commit

working on tests
Sean Friedowitz committed Jan 17, 2024
1 parent 5e54d10 commit 5ae6bac
Showing 10 changed files with 147 additions and 89 deletions.
5 changes: 1 addition & 4 deletions src/flamingo/jobs/__init__.py
@@ -1,12 +1,9 @@
from .base_config import BaseJobConfig
from .finetuning_config import FinetuningJobConfig
from .lm_harness_config import LMHarnessJobConfig, ModelNameOrCheckpointPath
from .lm_harness_config import LMHarnessJobConfig
from .simple_config import SimpleJobConfig

__all__ = [
"BaseJobConfig",
"SimpleJobConfig",
"FinetuningJobConfig",
"LMHarnessJobConfig",
"ModelNameOrCheckpointPath",
]
26 changes: 20 additions & 6 deletions src/flamingo/jobs/finetuning_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any

from peft import LoraConfig
from pydantic import Field, validator
from pydantic import Field, root_validator, validator

from flamingo.integrations.huggingface import (
AutoModelConfig,
@@ -14,15 +14,15 @@
from flamingo.types import BaseFlamingoConfig


class RayTrainConfig(BaseFlamingoConfig):
"""Misc settings passed to Ray train.
class FinetuningRayConfig(BaseFlamingoConfig):
"""Misc settings passed to Ray train for finetuning.
Includes information for scaling, checkpointing, and runtime storage.
"""

use_gpu: bool = True
num_workers: int | None = None
storage_path: str | None = None
storage_path: str | None = None # TODO: This should be set globally somehow

def get_scaling_args(self) -> dict[str, Any]:
args = dict(use_gpu=self.use_gpu, num_workers=self.num_workers)
@@ -34,12 +34,26 @@ class FinetuningJobConfig(BaseFlamingoConfig):

model: AutoModelConfig
dataset: DatasetConfig
tokenizer: AutoTokenizerConfig | None = None
tokenizer: AutoTokenizerConfig
quantization: QuantizationConfig | None = None
adapter: LoraConfig | None = None # TODO: Create own dataclass here
tracking: WandbRunConfig | None = None
trainer: TrainerConfig = Field(default_factory=TrainerConfig)
ray: RayTrainConfig = Field(default_factory=RayTrainConfig)
ray: FinetuningRayConfig = Field(default_factory=FinetuningRayConfig)

@root_validator(pre=True)
def ensure_tokenizer_config(cls, values):
"""Set the tokenizer to the model path when not explicitly provided."""
if values.get("tokenizer", None) is None:
match values["model"]:
case str() as model_path:
values["tokenizer"] = model_path
case dict() as model_data:
values["tokenizer"] = model_data["path"]
case AutoModelConfig() as model_config:
values["tokenizer"] = model_config.path
# No fallback necessary, downstream validation will flag invalid model types
return values

@validator("model", pre=True, always=True)
def validate_model_arg(cls, x):
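For context on the hunk above: the new ensure_tokenizer_config root validator lets callers omit the tokenizer entirely, falling back to the model path. A minimal sketch of the intended usage, with illustrative paths and assuming DatasetConfig exposes a `path` field like the model config does:

from flamingo.jobs import FinetuningJobConfig

# Tokenizer omitted: the root validator copies the model path into the tokenizer
# field before field-level validation runs.
config = FinetuningJobConfig(
    model="example-org/example-model",  # illustrative repo path, coerced by validate_model_arg
    dataset={"path": "example-org/example-dataset", "split": "train"},  # assumes a `path` field
)
print(config.tokenizer)  # resolves to the same path as the model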
17 changes: 12 additions & 5 deletions src/flamingo/jobs/lm_harness_config.py
@@ -1,21 +1,21 @@
import datetime

from pydantic import Field
from pydantic import Field, validator

from flamingo.integrations.huggingface import AutoModelConfig, QuantizationConfig
from flamingo.integrations.wandb import WandbRunConfig
from flamingo.types import BaseFlamingoConfig


class RayComputeSettings(BaseFlamingoConfig):
class LMHarnessRayConfig(BaseFlamingoConfig):
"""Misc settings for Ray compute in the LM harness job."""

use_gpu: bool = True
num_workers: int = 1
timeout: datetime.timedelta | None = None


class LMHarnessEvaluatorSettings(BaseFlamingoConfig):
class LMHarnessEvaluatorConfig(BaseFlamingoConfig):
"""Misc settings provided to an lm-harness evaluation job."""

tasks: list[str]
@@ -28,7 +28,14 @@ class LMHarnessJobConfig(BaseFlamingoConfig):
"""Configuration to run an lm-evaluation-harness evaluation job."""

model: AutoModelConfig
evaluator: LMHarnessEvaluatorSettings
evaluator: LMHarnessEvaluatorConfig
quantization: QuantizationConfig | None = None
tracking: WandbRunConfig | None = None
ray: RayComputeSettings = Field(default_factory=RayComputeSettings)
ray: LMHarnessRayConfig = Field(default_factory=LMHarnessRayConfig)

@validator("model", pre=True, always=True)
def validate_model_arg(cls, x):
"""Allow for passing just a path string as the model argument."""
if isinstance(x, str):
return AutoModelConfig(path=x)
return x
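The validate_model_arg validator added here mirrors the one on FinetuningJobConfig: a bare repo-path string is wrapped in an AutoModelConfig, which the new tests below exercise directly. A short usage sketch (task names are illustrative):

from flamingo.integrations.huggingface import AutoModelConfig
from flamingo.jobs import LMHarnessJobConfig
from flamingo.jobs.lm_harness_config import LMHarnessEvaluatorConfig

evaluator = LMHarnessEvaluatorConfig(tasks=["task1", "task2"], num_fewshot=5)

# A plain string for `model` is coerced into an AutoModelConfig by the validator.
config = LMHarnessJobConfig(model="hf_repo_id", evaluator=evaluator)
assert config.model == AutoModelConfig(path="hf_repo_id")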
1 change: 1 addition & 0 deletions tests/__init__.py
@@ -0,0 +1 @@
from pathlib import Path

Check failure on line 1 in tests/__init__.py (GitHub Actions / pytest_ruff, Ruff F401):
tests/__init__.py:1:21: F401 `pathlib.Path` imported but unused
49 changes: 2 additions & 47 deletions tests/conftest.py
@@ -4,12 +4,12 @@
This file is used to provide fixtures for the test session that are accessible to all submodules.
"""
import os
from pathlib import Path
from unittest import mock

import pytest

from flamingo.integrations.wandb import WandbArtifactConfig, WandbRunConfig
from flamingo.jobs import LMHarnessJobConfig
TEST_RESOURCES = Path(__file__) / "resources"


@pytest.fixture(autouse=True, scope="function")
@@ -24,48 +24,3 @@ def mock_environment_without_keys():
"""Mocks an environment missing common API keys."""
with mock.patch.dict(os.environ, clear=True):
yield


@pytest.fixture(scope="function")
def default_wandb_run_config():
def generator(**kwargs) -> WandbRunConfig:
mine = {
"name": "my-run",
"project": "my-project",
"entity": "mozilla-ai",
"run_id": "gabbagool-123",
}
return WandbRunConfig(**{**mine, **kwargs})

yield generator


@pytest.fixture(scope="function")
def default_wandb_artifact_config():
def generator(**kwargs) -> WandbArtifactConfig:
mine = {
"name": "my-run",
"version": "latest",
"project": "research-project",
"entity": "mozilla-corporation",
}
return WandbArtifactConfig(**{**mine, **kwargs})

yield generator


@pytest.fixture(scope="function")
def default_lm_harness_config():
def generator(**kwargs) -> LMHarnessJobConfig:
mine = {
"tasks": ["task1", "task2"],
"num_fewshot": 5,
"batch_size": 16,
"torch_dtype": "bfloat16",
"model_name_or_path": None,
"quantization": None,
"timeout": 3600,
}
return LMHarnessJobConfig(**{**mine, **kwargs})

yield generator
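One note on the new TEST_RESOURCES constant introduced at the top of this file: Path(__file__) refers to conftest.py itself, so joining "resources" onto it yields a path underneath the file rather than the sibling directory. If the sibling tests/resources directory is the intent, as the YAML-parsing test below suggests, the usual idiom would be:

from pathlib import Path

# Resolve tests/resources relative to the directory containing conftest.py.
TEST_RESOURCES = Path(__file__).parent / "resources"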
11 changes: 7 additions & 4 deletions tests/integrations/wandb/test_run_config.py
@@ -4,16 +4,19 @@
from flamingo.integrations.wandb import WandbRunConfig


def test_env_vars(default_wandb_run_config):
env_vars = default_wandb_run_config().get_env_vars()
def test_env_vars(wandb_run_config_generator):
env_vars = wandb_run_config_generator().get_env_vars()
expected = ["WANDB_NAME", "WANDB_PROJECT", "WANDB_ENTITY", "WANDB_RUN_ID"]
for key in expected:
assert key in env_vars
assert "WANDB_RUN_GROUP" not in env_vars


def test_serde_round_trip(default_wandb_run_config):
assert WandbRunConfig.parse_raw(default_wandb_run_config().json()) == default_wandb_run_config()
def test_serde_round_trip(wandb_run_config_generator):
assert (
WandbRunConfig.parse_raw(wandb_run_config_generator().json())
== wandb_run_config_generator()
)


def test_disallowed_kwargs():
58 changes: 58 additions & 0 deletions tests/jobs/conftest.py
@@ -0,0 +1,58 @@
import pytest
from peft import LoraConfig

from flamingo.integrations.huggingface import (
AutoModelConfig,
AutoTokenizerConfig,
DatasetConfig,
QuantizationConfig,
)
from flamingo.integrations.wandb import WandbArtifactConfig, WandbRunConfig


@pytest.fixture
def model_config_with_path():
return AutoModelConfig("mistral-ai/mistral-7", trust_remote_code=True)


@pytest.fixture
def model_config_with_artifact():
artifact = WandbArtifactConfig(name="model")
return AutoModelConfig(artifact, trust_remote_code=True)


@pytest.fixture
def tokenizer_config_with_path():
return AutoTokenizerConfig("mistral-ai/mistral-7", trust_remote_code=True)


@pytest.fixture
def tokenizer_config_with_artifact():
artifact = WandbArtifactConfig(name="tokenizer")
return AutoTokenizerConfig(artifact, trust_remote_code=True)


@pytest.fixture
def dataset_config_with_path():
return DatasetConfig("databricks/dolly7b", split="train")


@pytest.fixture
def dataset_config_with_artifact():
artifact = WandbArtifactConfig(name="dataset")
return DatasetConfig(artifact, split="train")


@pytest.fixture
def quantization_config():
return QuantizationConfig(load_in_8bit=True)


@pytest.fixture
def lora_config():
return LoraConfig(r=8, lora_alpha=32, lora_dropout=0.2)


@pytest.fixture
def wandb_run_config():
return WandbRunConfig(name="run", run_id="12345", project="research", entity="mozilla-ai")
61 changes: 46 additions & 15 deletions tests/jobs/test_lm_harness_config.py
@@ -1,28 +1,59 @@
from pathlib import Path

import pytest
from pydantic import ValidationError

from flamingo.integrations.huggingface import AutoModelConfig
from flamingo.jobs import LMHarnessJobConfig
from flamingo.jobs.lm_harness_config import LMHarnessEvaluatorConfig, LMHarnessRayConfig
from tests.conftest import TEST_RESOURCES


def test_bad_hf_name(default_lm_harness_config):
with pytest.raises(ValidationError):
default_lm_harness_config(model_name_or_path="dfa../invalid")
@pytest.fixture
def lm_harness_evaluator_config():
return LMHarnessEvaluatorConfig(
tasks=["task1", "task2", "task3"],
num_fewshot=5,
)


def test_serde_round_trip_default_config(default_lm_harness_config):
config = default_lm_harness_config()
assert LMHarnessJobConfig.parse_raw(config.json()) == config
@pytest.fixture
def lm_harness_ray_config():
return LMHarnessRayConfig(
num_workers=4,
use_gpu=True,
)


def test_model_validation(lm_harness_evaluator_config):
allowed_config = LMHarnessJobConfig(model="hf_repo_id", evaluator=lm_harness_evaluator_config)
assert allowed_config.model == AutoModelConfig(path="hf_repo_id")

def test_serde_round_trip_with_path(default_lm_harness_config):
config = default_lm_harness_config(model_name_or_path=Path("fake/path"))
with pytest.raises(ValidationError):
LMHarnessJobConfig(model="invalid...hf..repo", evaluator=lm_harness_evaluator_config)

with pytest.raises(ValidationError):
LMHarnessJobConfig(model=12345, evaluator=lm_harness_evaluator_config)


def test_serde_round_trip(
model_config_with_artifact,
quantization_config,
wandb_run_config,
lm_harness_evaluator_config,
lm_harness_ray_config,
):
config = LMHarnessJobConfig(
model=model_config_with_artifact,
evaluator=lm_harness_evaluator_config,
ray=lm_harness_ray_config,
tracking=wandb_run_config,
quantization=quantization_config,
)
assert LMHarnessJobConfig.parse_raw(config.json()) == config


def test_parse_from_yaml(default_lm_harness_config, tmp_path_factory):
config = default_lm_harness_config(model_name_or_path="not_a_real_model")
p = tmp_path_factory.mktemp("test_yaml") / "eval.yaml"
config.to_yaml_file(p)
assert config == LMHarnessJobConfig.from_yaml_file(p)
def test_parse_yaml_file(tmp_path_factory):
load_path = TEST_RESOURCES / "lm_harness_config.yaml"
config = LMHarnessJobConfig.from_yaml_file(load_path)
write_path = tmp_path_factory.mktemp("flamingo_tests") / "harness_config.yaml"
config.to_yaml_file(write_path)
assert config == LMHarnessJobConfig.from_yaml_file(write_path)
5 changes: 0 additions & 5 deletions tests/resources/finetuning_config.yaml
@@ -16,30 +16,25 @@ dataset:
split: "train"
test_size: 0.2

# HuggingFace Trainer/TrainingArguments
trainer:
max_seq_length: 512
learning_rate: 0.1
num_train_epochs: 2

# HuggingFace quantization settings
quantization:
load_in_4bit: True
bnb_4bit_quant_type: "fp4"

# LORA adapter settings
adapter:
r: 16
lora_alpha: 32
lora_dropout: 0.2

# W&B run for logging results
tracking:
name: "location-to-log-results"
project: "another-project"
entity: "another-entity"

# Ray compute settings
ray:
use_gpu: True
num_workers: 4
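The trimmed YAML above still maps one-to-one onto FinetuningJobConfig; only the five inline comment lines were dropped, not any fields. Assuming FinetuningJobConfig inherits the same YAML helpers that LMHarnessJobConfig.from_yaml_file uses in the test above, loading it would look like:

from flamingo.jobs import FinetuningJobConfig

# Assumes the YAML helpers live on the shared BaseFlamingoConfig parent class.
config = FinetuningJobConfig.from_yaml_file("tests/resources/finetuning_config.yaml")
print(config.trainer)  # trainer settings from the resource file, e.g. learning_rate 0.1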
3 changes: 0 additions & 3 deletions tests/resources/lm_harness_config.yaml
@@ -13,18 +13,15 @@ evaluator:
tasks: ["task1", "task2", "...", "taskN"]
num_fewshot: 5

# HuggingFace quantization settings
quantization:
load_in_4bit: True
bnb_4bit_quant_type: "fp4"

# W&B run for logging results
tracking:
name: "location-to-log-results"
project: "another-project"
entity: "another-entity"

# Ray compute settings
ray:
use_gpu: True
num_workers: 4
