From 81440e107c882fade4c966e83179b4b7ee94182e Mon Sep 17 00:00:00 2001
From: Sean Friedowitz
Date: Wed, 17 Jan 2024 12:20:36 -0800
Subject: [PATCH] add more test coverage

---
 .../configs}/finetuning_config.yaml | 0
 .../configs}/lm_harness_config.yaml | 0
 tests/__init__.py | 1 -
 tests/conftest.py | 3 -
 .../wandb/test_artifact_config.py | 21 +++++
 tests/integrations/wandb/test_run_config.py | 33 +++++---
 tests/jobs/conftest.py | 12 +--
 tests/jobs/test_finetuning_config.py | 82 +++++++++++++++----
 tests/jobs/test_lm_harness_config.py | 43 +++++-----
 9 files changed, 137 insertions(+), 58 deletions(-)
 rename {tests/resources => examples/configs}/finetuning_config.yaml (100%)
 rename {tests/resources => examples/configs}/lm_harness_config.yaml (100%)
 create mode 100644 tests/integrations/wandb/test_artifact_config.py

diff --git a/tests/resources/finetuning_config.yaml b/examples/configs/finetuning_config.yaml
similarity index 100%
rename from tests/resources/finetuning_config.yaml
rename to examples/configs/finetuning_config.yaml
diff --git a/tests/resources/lm_harness_config.yaml b/examples/configs/lm_harness_config.yaml
similarity index 100%
rename from tests/resources/lm_harness_config.yaml
rename to examples/configs/lm_harness_config.yaml
diff --git a/tests/__init__.py b/tests/__init__.py
index 2bb88df4..e69de29b 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1 +0,0 @@
-from pathlib import Path
diff --git a/tests/conftest.py b/tests/conftest.py
index 3ecdaf8d..e9fd2d21 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -4,13 +4,10 @@
 This file is used to provide fixtures for the test session that are accessible to all submodules.
 """
 import os
-from pathlib import Path
 from unittest import mock
 
 import pytest
 
-TEST_RESOURCES = Path(__file__) / "resources"
-
 
 @pytest.fixture(autouse=True, scope="function")
 def mock_environment_with_keys():
diff --git a/tests/integrations/wandb/test_artifact_config.py b/tests/integrations/wandb/test_artifact_config.py
new file mode 100644
index 00000000..6fdb98ec
--- /dev/null
+++ b/tests/integrations/wandb/test_artifact_config.py
@@ -0,0 +1,21 @@
+import pytest
+
+from flamingo.integrations.wandb import WandbArtifactConfig
+
+
+@pytest.fixture
+def wandb_artifact_config():
+    return WandbArtifactConfig(
+        name="artifact-name",
+        version="latest",
+        project="cortex-research",
+        entity="twitter.com",
+    )
+
+
+def test_serde_round_trip(wandb_artifact_config):
+    assert WandbArtifactConfig.parse_raw(wandb_artifact_config.json()) == wandb_artifact_config
+
+
+def test_wandb_path(wandb_artifact_config):
+    assert wandb_artifact_config.wandb_path == "twitter.com/cortex-research/artifact-name:latest"
diff --git a/tests/integrations/wandb/test_run_config.py b/tests/integrations/wandb/test_run_config.py
index 3f6d1c28..0cf64bb0 100644
--- a/tests/integrations/wandb/test_run_config.py
+++ b/tests/integrations/wandb/test_run_config.py
@@ -4,21 +4,32 @@
 from flamingo.integrations.wandb import WandbRunConfig
 
 
-def test_env_vars(wandb_run_config_generator):
-    env_vars = wandb_run_config_generator().get_env_vars()
+@pytest.fixture
+def wandb_run_config():
+    return WandbRunConfig(
+        name="run-name",
+        run_id="run-id",
+        project="cortex-research",
+        entity="twitter.com",
+    )
+
+
+def test_serde_round_trip(wandb_run_config):
+    assert WandbRunConfig.parse_raw(wandb_run_config.json()) == wandb_run_config
+
+
+def test_wandb_path(wandb_run_config):
+    assert wandb_run_config.wandb_path == "twitter.com/cortex-research/run-id"
+
+
+def test_env_vars(wandb_run_config):
+    env_vars = wandb_run_config.get_env_vars()
     expected = ["WANDB_NAME", "WANDB_PROJECT", "WANDB_ENTITY", "WANDB_RUN_ID"]
     for key in expected:
         assert key in env_vars
     assert "WANDB_RUN_GROUP" not in env_vars
 
 
-def test_serde_round_trip(wandb_run_config_generator):
-    assert (
-        WandbRunConfig.parse_raw(wandb_run_config_generator().json())
-        == wandb_run_config_generator()
-    )
-
-
 def test_disallowed_kwargs():
     with pytest.raises(ValidationError):
         WandbRunConfig(name="name", project="project", old_name="I will throw")
@@ -26,5 +37,5 @@ def test_disallowed_kwargs():
 
 def test_missing_key_warning(mock_environment_without_keys):
     with pytest.warns(UserWarning):
-        env = WandbRunConfig(name="I am missing an API key", project="I should warn the user")
-        assert "WANDB_API_KEY" not in env.env_vars
+        config = WandbRunConfig(name="I am missing an API key", project="I should warn the user")
+        assert "WANDB_API_KEY" not in config.get_env_vars()
diff --git a/tests/jobs/conftest.py b/tests/jobs/conftest.py
index 6987a4b3..5dd3c812 100644
--- a/tests/jobs/conftest.py
+++ b/tests/jobs/conftest.py
@@ -12,35 +12,35 @@
 
 @pytest.fixture
 def model_config_with_path():
-    return AutoModelConfig("mistral-ai/mistral-7", trust_remote_code=True)
+    return AutoModelConfig(path="mistral-ai/mistral-7", trust_remote_code=True)
 
 
 @pytest.fixture
 def model_config_with_artifact():
     artifact = WandbArtifactConfig(name="model")
-    return AutoModelConfig(artifact, trust_remote_code=True)
+    return AutoModelConfig(path=artifact, trust_remote_code=True)
 
 
 @pytest.fixture
 def tokenizer_config_with_path():
-    return AutoTokenizerConfig("mistral-ai/mistral-7", trust_remote_code=True)
+    return AutoTokenizerConfig(path="mistral-ai/mistral-7", trust_remote_code=True)
 
 
 @pytest.fixture
 def tokenizer_config_with_artifact():
     artifact = WandbArtifactConfig(name="tokenizer")
-    return AutoTokenizerConfig(artifact, trust_remote_code=True)
+    return AutoTokenizerConfig(path=artifact, trust_remote_code=True)
 
 
 @pytest.fixture
 def dataset_config_with_path():
-    return DatasetConfig("databricks/dolly7b", split="train")
+    return DatasetConfig(path="databricks/dolly15k", split="train")
 
 
 @pytest.fixture
 def dataset_config_with_artifact():
     artifact = WandbArtifactConfig(name="dataset")
-    return DatasetConfig(artifact, split="train")
+    return DatasetConfig(path=artifact, split="train")
 
 
 @pytest.fixture
diff --git a/tests/jobs/test_finetuning_config.py b/tests/jobs/test_finetuning_config.py
index ca5377a9..099a1b81 100644
--- a/tests/jobs/test_finetuning_config.py
+++ b/tests/jobs/test_finetuning_config.py
@@ -1,22 +1,72 @@
-from peft import LoraConfig
-from ray.train import ScalingConfig
+import pytest
+from pydantic import ValidationError
 
-from flamingo.integrations.huggingface import QuantizationConfig, TrainerConfig
+from flamingo.integrations.huggingface import AutoModelConfig, AutoTokenizerConfig, DatasetConfig
 from flamingo.jobs import FinetuningJobConfig
+from flamingo.jobs.finetuning_config import FinetuningRayConfig
 
 
-def test_serde_round_trip():
-    trainer_config = TrainerConfig(torch_dtype="bfloat16")
-    lora_config = LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM")
-    quantization_config = QuantizationConfig(load_in_8bit=True)
-    scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
-    config = FinetuningJobConfig(
-        model="test-model",
-        dataset="test-dataset",
-        trainer=trainer_config,
-        lora=lora_config,
+@pytest.fixture
+def finetuning_ray_config():
+    return FinetuningRayConfig(
+        num_workers=4,
+        use_gpu=True,
+    )
+
+
+@pytest.fixture
+def finetuning_job_config(
+    model_config_with_artifact,
+    dataset_config_with_artifact,
+    tokenizer_config_with_artifact,
+    quantization_config,
+    lora_config,
+    wandb_run_config,
+    finetuning_ray_config,
+):
+    return FinetuningJobConfig(
+        model=model_config_with_artifact,
+        dataset=dataset_config_with_artifact,
+        tokenizer=tokenizer_config_with_artifact,
         quantization=quantization_config,
-        scaling=scaling_config,
-        storage_path="/mnt/data/ray_results",
+        adapter=lora_config,
+        tracking=wandb_run_config,
+        ray=finetuning_ray_config,
+    )
+
+
+def test_serde_round_trip(finetuning_job_config):
+    assert FinetuningJobConfig.parse_raw(finetuning_job_config.json()) == finetuning_job_config
+
+
+def test_parse_yaml_file(finetuning_job_config, tmp_path_factory):
+    config_path = tmp_path_factory.mktemp("flamingo_tests") / "finetuning_config.yaml"
+    finetuning_job_config.to_yaml_file(config_path)
+    assert finetuning_job_config == FinetuningJobConfig.from_yaml_file(config_path)
+
+
+def test_argument_validation():
+    # Strings should be upcast to configs as the path argument
+    allowed_config = FinetuningJobConfig(
+        model="model_path",
+        tokenizer="tokenizer_path",
+        dataset="dataset_path",
+    )
+    assert allowed_config.model == AutoModelConfig(path="model_path")
+    assert allowed_config.tokenizer == AutoTokenizerConfig(path="tokenizer_path")
+    assert allowed_config.dataset == DatasetConfig(path="dataset_path")
+
+    # Check passing invalid arguments is validated for each asset type
+    with pytest.raises(ValidationError):
+        FinetuningJobConfig(model=12345, tokenizer="tokenizer_path", dataset="dataset_path")
+    with pytest.raises(ValidationError):
+        FinetuningJobConfig(model="model_path", tokenizer=12345, dataset="dataset_path")
+    with pytest.raises(ValidationError):
+        FinetuningJobConfig(model="model_path", tokenizer="tokenizer_path", dataset=12345)
+
+    # Check that tokenizer is set to model path when absent
+    missing_tokenizer_config = FinetuningJobConfig(
+        model="model_path",
+        dataset="dataset_path",
     )
-    assert FinetuningJobConfig.parse_raw(config.json()) == config
+    assert missing_tokenizer_config.tokenizer.path == "model_path"
diff --git a/tests/jobs/test_lm_harness_config.py b/tests/jobs/test_lm_harness_config.py
index 805e761f..0ef5fcea 100644
--- a/tests/jobs/test_lm_harness_config.py
+++ b/tests/jobs/test_lm_harness_config.py
@@ -4,7 +4,6 @@
 from flamingo.integrations.huggingface import AutoModelConfig
 from flamingo.jobs import LMHarnessJobConfig
 from flamingo.jobs.lm_harness_config import LMHarnessEvaluatorConfig, LMHarnessRayConfig
-from tests.conftest import TEST_RESOURCES
 
 
 @pytest.fixture
@@ -23,37 +22,39 @@ def lm_harness_ray_config():
     )
 
 
-def test_model_validation(lm_harness_evaluator_config):
-    allowed_config = LMHarnessJobConfig(model="hf_repo_id", evaluator=lm_harness_evaluator_config)
-    assert allowed_config.model == AutoModelConfig(path="hf_repo_id")
-
-    with pytest.raises(ValidationError):
-        LMHarnessJobConfig(model="invalid...hf..repo", evaluator=lm_harness_evaluator_config)
-
-    with pytest.raises(ValidationError):
-        LMHarnessJobConfig(model=12345, evaluator=lm_harness_evaluator_config)
-
-
-def test_serde_round_trip(
+@pytest.fixture
+def lm_harness_job_config(
     model_config_with_artifact,
     quantization_config,
     wandb_run_config,
     lm_harness_evaluator_config,
     lm_harness_ray_config,
 ):
-    config = LMHarnessJobConfig(
+    return LMHarnessJobConfig(
         model=model_config_with_artifact,
         evaluator=lm_harness_evaluator_config,
         ray=lm_harness_ray_config,
         tracking=wandb_run_config,
         quantization=quantization_config,
     )
-    assert LMHarnessJobConfig.parse_raw(config.json()) == config
 
 
-def test_parse_yaml_file(tmp_path_factory):
-    load_path = TEST_RESOURCES / "lm_harness_config.yaml"
-    config = LMHarnessJobConfig.from_yaml_file(load_path)
-    write_path = tmp_path_factory.mktemp("flamingo_tests") / "harness_config.yaml"
-    config.to_yaml_file(write_path)
-    assert config == LMHarnessJobConfig.from_yaml_file(write_path)
+def test_serde_round_trip(lm_harness_job_config):
+    assert LMHarnessJobConfig.parse_raw(lm_harness_job_config.json()) == lm_harness_job_config
+
+
+def test_parse_yaml_file(lm_harness_job_config, tmp_path_factory):
+    config_path = tmp_path_factory.mktemp("flamingo_tests") / "lm_harness_config.yaml"
+    lm_harness_job_config.to_yaml_file(config_path)
+    assert lm_harness_job_config == LMHarnessJobConfig.from_yaml_file(config_path)
+
+
+def test_model_validation(lm_harness_evaluator_config):
+    allowed_config = LMHarnessJobConfig(model="hf_repo_id", evaluator=lm_harness_evaluator_config)
+    assert allowed_config.model == AutoModelConfig(path="hf_repo_id")
+
+    with pytest.raises(ValidationError):
+        LMHarnessJobConfig(model="invalid...hf..repo", evaluator=lm_harness_evaluator_config)
+
+    with pytest.raises(ValidationError):
+        LMHarnessJobConfig(model=12345, evaluator=lm_harness_evaluator_config)