From 5ae6bacdd7bdca314d7eebd33f283e77b62ea014 Mon Sep 17 00:00:00 2001
From: Sean Friedowitz
Date: Wed, 17 Jan 2024 11:37:41 -0800
Subject: [PATCH] working on tests

---
 src/flamingo/jobs/__init__.py               |  5 +-
 src/flamingo/jobs/finetuning_config.py      | 26 +++++++--
 src/flamingo/jobs/lm_harness_config.py      | 17 ++++--
 tests/__init__.py                           |  1 +
 tests/conftest.py                           | 49 +----------------
 tests/integrations/wandb/test_run_config.py | 11 ++--
 tests/jobs/conftest.py                      | 58 ++++++++++++++++++++
 tests/jobs/test_lm_harness_config.py        | 61 ++++++++++++++++-----
 tests/resources/finetuning_config.yaml      |  5 -
 tests/resources/lm_harness_config.yaml      |  3 -
 10 files changed, 147 insertions(+), 89 deletions(-)
 create mode 100644 tests/jobs/conftest.py

diff --git a/src/flamingo/jobs/__init__.py b/src/flamingo/jobs/__init__.py
index 4a616d75..5c003cd0 100644
--- a/src/flamingo/jobs/__init__.py
+++ b/src/flamingo/jobs/__init__.py
@@ -1,12 +1,9 @@
-from .base_config import BaseJobConfig
 from .finetuning_config import FinetuningJobConfig
-from .lm_harness_config import LMHarnessJobConfig, ModelNameOrCheckpointPath
+from .lm_harness_config import LMHarnessJobConfig
 from .simple_config import SimpleJobConfig
 
 __all__ = [
-    "BaseJobConfig",
     "SimpleJobConfig",
     "FinetuningJobConfig",
     "LMHarnessJobConfig",
-    "ModelNameOrCheckpointPath",
 ]
diff --git a/src/flamingo/jobs/finetuning_config.py b/src/flamingo/jobs/finetuning_config.py
index 63ae62b1..902c82db 100644
--- a/src/flamingo/jobs/finetuning_config.py
+++ b/src/flamingo/jobs/finetuning_config.py
@@ -1,7 +1,7 @@
 from typing import Any
 
 from peft import LoraConfig
-from pydantic import Field, validator
+from pydantic import Field, root_validator, validator
 
 from flamingo.integrations.huggingface import (
     AutoModelConfig,
@@ -14,15 +14,15 @@
 from flamingo.types import BaseFlamingoConfig
 
 
-class RayTrainConfig(BaseFlamingoConfig):
-    """Misc settings passed to Ray train.
+class FinetuningRayConfig(BaseFlamingoConfig):
+    """Misc settings passed to Ray train for finetuning.
 
     Includes information for scaling, checkpointing, and runtime storage.
""" use_gpu: bool = True num_workers: int | None = None - storage_path: str | None = None + storage_path: str | None = None # TODO: This should be set globally somehow def get_scaling_args(self) -> dict[str, Any]: args = dict(use_gpu=self.use_gpu, num_workers=self.num_workers) @@ -34,12 +34,26 @@ class FinetuningJobConfig(BaseFlamingoConfig): model: AutoModelConfig dataset: DatasetConfig - tokenizer: AutoTokenizerConfig | None = None + tokenizer: AutoTokenizerConfig quantization: QuantizationConfig | None = None adapter: LoraConfig | None = None # TODO: Create own dataclass here tracking: WandbRunConfig | None = None trainer: TrainerConfig = Field(default_factory=TrainerConfig) - ray: RayTrainConfig = Field(default_factory=RayTrainConfig) + ray: FinetuningRayConfig = Field(default_factory=FinetuningRayConfig) + + @root_validator(pre=True) + def ensure_tokenizer_config(cls, values): + """Set the tokenizer to the model path when not explicitly provided.""" + if values.get("tokenizer", None) is None: + match values["model"]: + case str() as model_path: + values["tokenizer"] = model_path + case dict() as model_data: + values["tokenizer"] = model_data["path"] + case AutoModelConfig() as model_config: + values["tokenizer"] = model_config.path + # No fallback necessary, downstream validation will flag invalid model types + return values @validator("model", pre=True, always=True) def validate_model_arg(cls, x): diff --git a/src/flamingo/jobs/lm_harness_config.py b/src/flamingo/jobs/lm_harness_config.py index d638fd1e..cce3459b 100644 --- a/src/flamingo/jobs/lm_harness_config.py +++ b/src/flamingo/jobs/lm_harness_config.py @@ -1,13 +1,13 @@ import datetime -from pydantic import Field +from pydantic import Field, validator from flamingo.integrations.huggingface import AutoModelConfig, QuantizationConfig from flamingo.integrations.wandb import WandbRunConfig from flamingo.types import BaseFlamingoConfig -class RayComputeSettings(BaseFlamingoConfig): +class LMHarnessRayConfig(BaseFlamingoConfig): """Misc settings for Ray compute in the LM harness job.""" use_gpu: bool = True @@ -15,7 +15,7 @@ class RayComputeSettings(BaseFlamingoConfig): timeout: datetime.timedelta | None = None -class LMHarnessEvaluatorSettings(BaseFlamingoConfig): +class LMHarnessEvaluatorConfig(BaseFlamingoConfig): """Misc settings provided to an lm-harness evaluation job.""" tasks: list[str] @@ -28,7 +28,14 @@ class LMHarnessJobConfig(BaseFlamingoConfig): """Configuration to run an lm-evaluation-harness evaluation job.""" model: AutoModelConfig - evaluator: LMHarnessEvaluatorSettings + evaluator: LMHarnessEvaluatorConfig quantization: QuantizationConfig | None = None tracking: WandbRunConfig | None = None - ray: RayComputeSettings = Field(default_factory=RayComputeSettings) + ray: LMHarnessRayConfig = Field(default_factory=LMHarnessRayConfig) + + @validator("model", pre=True, always=True) + def validate_model_arg(cls, x): + """Allow for passing just a path string as the model argument.""" + if isinstance(x, str): + return AutoModelConfig(path=x) + return x diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..2bb88df4 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1 @@ +from pathlib import Path diff --git a/tests/conftest.py b/tests/conftest.py index 0f98deae..3ecdaf8d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,12 +4,12 @@ This file is used to provide fixtures for the test session that are accessible to all submodules. 
""" import os +from pathlib import Path from unittest import mock import pytest -from flamingo.integrations.wandb import WandbArtifactConfig, WandbRunConfig -from flamingo.jobs import LMHarnessJobConfig +TEST_RESOURCES = Path(__file__) / "resources" @pytest.fixture(autouse=True, scope="function") @@ -24,48 +24,3 @@ def mock_environment_without_keys(): """Mocks an environment missing common API keys.""" with mock.patch.dict(os.environ, clear=True): yield - - -@pytest.fixture(scope="function") -def default_wandb_run_config(): - def generator(**kwargs) -> WandbRunConfig: - mine = { - "name": "my-run", - "project": "my-project", - "entity": "mozilla-ai", - "run_id": "gabbagool-123", - } - return WandbRunConfig(**{**mine, **kwargs}) - - yield generator - - -@pytest.fixture(scope="function") -def default_wandb_artifact_config(): - def generator(**kwargs) -> WandbArtifactConfig: - mine = { - "name": "my-run", - "version": "latest", - "project": "research-project", - "entity": "mozilla-corporation", - } - return WandbArtifactConfig(**{**mine, **kwargs}) - - yield generator - - -@pytest.fixture(scope="function") -def default_lm_harness_config(): - def generator(**kwargs) -> LMHarnessJobConfig: - mine = { - "tasks": ["task1", "task2"], - "num_fewshot": 5, - "batch_size": 16, - "torch_dtype": "bfloat16", - "model_name_or_path": None, - "quantization": None, - "timeout": 3600, - } - return LMHarnessJobConfig(**{**mine, **kwargs}) - - yield generator diff --git a/tests/integrations/wandb/test_run_config.py b/tests/integrations/wandb/test_run_config.py index 385e40a3..3f6d1c28 100644 --- a/tests/integrations/wandb/test_run_config.py +++ b/tests/integrations/wandb/test_run_config.py @@ -4,16 +4,19 @@ from flamingo.integrations.wandb import WandbRunConfig -def test_env_vars(default_wandb_run_config): - env_vars = default_wandb_run_config().get_env_vars() +def test_env_vars(wandb_run_config_generator): + env_vars = wandb_run_config_generator().get_env_vars() expected = ["WANDB_NAME", "WANDB_PROJECT", "WANDB_ENTITY", "WANDB_RUN_ID"] for key in expected: assert key in env_vars assert "WANDB_RUN_GROUP" not in env_vars -def test_serde_round_trip(default_wandb_run_config): - assert WandbRunConfig.parse_raw(default_wandb_run_config().json()) == default_wandb_run_config() +def test_serde_round_trip(wandb_run_config_generator): + assert ( + WandbRunConfig.parse_raw(wandb_run_config_generator().json()) + == wandb_run_config_generator() + ) def test_disallowed_kwargs(): diff --git a/tests/jobs/conftest.py b/tests/jobs/conftest.py new file mode 100644 index 00000000..6987a4b3 --- /dev/null +++ b/tests/jobs/conftest.py @@ -0,0 +1,58 @@ +import pytest +from peft import LoraConfig + +from flamingo.integrations.huggingface import ( + AutoModelConfig, + AutoTokenizerConfig, + DatasetConfig, + QuantizationConfig, +) +from flamingo.integrations.wandb import WandbArtifactConfig, WandbRunConfig + + +@pytest.fixture +def model_config_with_path(): + return AutoModelConfig("mistral-ai/mistral-7", trust_remote_code=True) + + +@pytest.fixture +def model_config_with_artifact(): + artifact = WandbArtifactConfig(name="model") + return AutoModelConfig(artifact, trust_remote_code=True) + + +@pytest.fixture +def tokenizer_config_with_path(): + return AutoTokenizerConfig("mistral-ai/mistral-7", trust_remote_code=True) + + +@pytest.fixture +def tokenizer_config_with_artifact(): + artifact = WandbArtifactConfig(name="tokenizer") + return AutoTokenizerConfig(artifact, trust_remote_code=True) + + +@pytest.fixture +def 
+def dataset_config_with_path():
+    return DatasetConfig("databricks/dolly7b", split="train")
+
+
+@pytest.fixture
+def dataset_config_with_artifact():
+    artifact = WandbArtifactConfig(name="dataset")
+    return DatasetConfig(artifact, split="train")
+
+
+@pytest.fixture
+def quantization_config():
+    return QuantizationConfig(load_in_8bit=True)
+
+
+@pytest.fixture
+def lora_config():
+    return LoraConfig(r=8, lora_alpha=32, lora_dropout=0.2)
+
+
+@pytest.fixture
+def wandb_run_config():
+    return WandbRunConfig(name="run", run_id="12345", project="research", entity="mozilla-ai")
diff --git a/tests/jobs/test_lm_harness_config.py b/tests/jobs/test_lm_harness_config.py
index 32b32098..805e761f 100644
--- a/tests/jobs/test_lm_harness_config.py
+++ b/tests/jobs/test_lm_harness_config.py
@@ -1,28 +1,59 @@
-from pathlib import Path
-
 import pytest
 from pydantic import ValidationError
 
+from flamingo.integrations.huggingface import AutoModelConfig
 from flamingo.jobs import LMHarnessJobConfig
+from flamingo.jobs.lm_harness_config import LMHarnessEvaluatorConfig, LMHarnessRayConfig
+from tests.conftest import TEST_RESOURCES
 
 
-def test_bad_hf_name(default_lm_harness_config):
-    with pytest.raises(ValidationError):
-        default_lm_harness_config(model_name_or_path="dfa../invalid")
+@pytest.fixture
+def lm_harness_evaluator_config():
+    return LMHarnessEvaluatorConfig(
+        tasks=["task1", "task2", "task3"],
+        num_fewshot=5,
+    )
 
 
-def test_serde_round_trip_default_config(default_lm_harness_config):
-    config = default_lm_harness_config()
-    assert LMHarnessJobConfig.parse_raw(config.json()) == config
+@pytest.fixture
+def lm_harness_ray_config():
+    return LMHarnessRayConfig(
+        num_workers=4,
+        use_gpu=True,
+    )
+
 
+def test_model_validation(lm_harness_evaluator_config):
+    allowed_config = LMHarnessJobConfig(model="hf_repo_id", evaluator=lm_harness_evaluator_config)
+    assert allowed_config.model == AutoModelConfig(path="hf_repo_id")
 
-def test_serde_round_trip_with_path(default_lm_harness_config):
-    config = default_lm_harness_config(model_name_or_path=Path("fake/path"))
+    with pytest.raises(ValidationError):
+        LMHarnessJobConfig(model="invalid...hf..repo", evaluator=lm_harness_evaluator_config)
+
+    with pytest.raises(ValidationError):
+        LMHarnessJobConfig(model=12345, evaluator=lm_harness_evaluator_config)
+
+
+def test_serde_round_trip(
+    model_config_with_artifact,
+    quantization_config,
+    wandb_run_config,
+    lm_harness_evaluator_config,
+    lm_harness_ray_config,
+):
+    config = LMHarnessJobConfig(
+        model=model_config_with_artifact,
+        evaluator=lm_harness_evaluator_config,
+        ray=lm_harness_ray_config,
+        tracking=wandb_run_config,
+        quantization=quantization_config,
+    )
     assert LMHarnessJobConfig.parse_raw(config.json()) == config
 
 
-def test_parse_from_yaml(default_lm_harness_config, tmp_path_factory):
-    config = default_lm_harness_config(model_name_or_path="not_a_real_model")
-    p = tmp_path_factory.mktemp("test_yaml") / "eval.yaml"
-    config.to_yaml_file(p)
-    assert config == LMHarnessJobConfig.from_yaml_file(p)
+def test_parse_yaml_file(tmp_path_factory):
+    load_path = TEST_RESOURCES / "lm_harness_config.yaml"
+    config = LMHarnessJobConfig.from_yaml_file(load_path)
+    write_path = tmp_path_factory.mktemp("flamingo_tests") / "harness_config.yaml"
+    config.to_yaml_file(write_path)
+    assert config == LMHarnessJobConfig.from_yaml_file(write_path)
diff --git a/tests/resources/finetuning_config.yaml b/tests/resources/finetuning_config.yaml
index 8ce53870..e43baa22 100644
--- a/tests/resources/finetuning_config.yaml
+++ b/tests/resources/finetuning_config.yaml
@@ -16,30 +16,25 @@ dataset:
   split: "train"
   test_size: 0.2
 
-# HuggingFace Trainer/TrainingArguments
 trainer:
   max_seq_length: 512
   learning_rate: 0.1
   num_train_epochs: 2
 
-# HuggingFace quantization settings
 quantization:
   load_in_4bit: True
   bnb_4bit_quant_type: "fp4"
 
-# LORA adapter settings
 adapter:
   r: 16
   lora_alpha: 32
   lora_dropout: 0.2
 
-# W&B run for logging results
 tracking:
   name: "location-to-log-results"
   project: "another-project"
   entity: "another-entity"
 
-# Ray compute settings
 ray:
   use_gpu: True
   num_workers: 4
diff --git a/tests/resources/lm_harness_config.yaml b/tests/resources/lm_harness_config.yaml
index ac9bad14..64592ba4 100644
--- a/tests/resources/lm_harness_config.yaml
+++ b/tests/resources/lm_harness_config.yaml
@@ -13,18 +13,15 @@ evaluator:
   tasks: ["task1", "task2", "...", "taskN"]
   num_fewshot: 5
 
-# HuggingFace quantization settings
 quantization:
   load_in_4bit: True
   bnb_4bit_quant_type: "fp4"
 
-# W&B run for logging results
 tracking:
   name: "location-to-log-results"
   project: "another-project"
   entity: "another-entity"
 
-# Ray compute settings
 ray:
   use_gpu: True
   num_workers: 4
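
For context, a minimal sketch of the behavior these changes pin down (a sketch only, assuming the flamingo package is importable; "hf_repo_id" and the /tmp output path are placeholders mirroring the tests above):

    from pathlib import Path

    from flamingo.integrations.huggingface import AutoModelConfig
    from flamingo.jobs import LMHarnessJobConfig
    from flamingo.jobs.lm_harness_config import LMHarnessEvaluatorConfig

    # A bare repo-id string is coerced into an AutoModelConfig by the new "model" validator.
    config = LMHarnessJobConfig(
        model="hf_repo_id",
        evaluator=LMHarnessEvaluatorConfig(tasks=["task1", "task2"], num_fewshot=5),
    )
    assert config.model == AutoModelConfig(path="hf_repo_id")

    # Job configs round-trip through YAML, which test_parse_yaml_file exercises end to end.
    out_path = Path("/tmp/lm_harness_config.yaml")
    config.to_yaml_file(out_path)
    assert LMHarnessJobConfig.from_yaml_file(out_path) == config

The finetuning side follows the same pattern: when no tokenizer is provided, the new ensure_tokenizer_config root validator falls back to the model path.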