
Commit

add more test coverage
Sean Friedowitz committed Jan 17, 2024
1 parent 5ae6bac commit 81440e1
Showing 9 changed files with 137 additions and 58 deletions.
1 change: 0 additions & 1 deletion tests/__init__.py
@@ -1 +0,0 @@
from pathlib import Path
3 changes: 0 additions & 3 deletions tests/conftest.py
@@ -4,13 +4,10 @@
This file is used to provide fixtures for the test session that are accessible to all submodules.
"""
import os
from pathlib import Path
from unittest import mock

import pytest

TEST_RESOURCES = Path(__file__) / "resources"


@pytest.fixture(autouse=True, scope="function")
def mock_environment_with_keys():
21 changes: 21 additions & 0 deletions tests/integrations/wandb/test_artifact_config.py
@@ -0,0 +1,21 @@
import pytest

from flamingo.integrations.wandb import WandbArtifactConfig


@pytest.fixture
def wandb_artifact_config():
    return WandbArtifactConfig(
        name="artifact-name",
        version="latest",
        project="cortex-research",
        entity="twitter.com",
    )


def test_serde_round_trip(wandb_artifact_config):
    assert WandbArtifactConfig.parse_raw(wandb_artifact_config.json()) == wandb_artifact_config


def test_wandb_path(wandb_artifact_config):
    assert wandb_artifact_config.wandb_path == "twitter.com/cortex-research/artifact-name:latest"
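
Note: the assertions above pin down the expected behavior of WandbArtifactConfig: Pydantic JSON round-tripping plus a wandb_path of the form entity/project/name:version. A minimal sketch of a class satisfying those tests is shown below; it is inferred from the assertions, not taken from this commit, so the base class, field types, and defaults are assumptions.

from pydantic import BaseModel


class WandbArtifactConfig(BaseModel):
    # Assumed fields, inferred from the fixture arguments above.
    name: str
    version: str = "latest"
    project: str
    entity: str

    @property
    def wandb_path(self) -> str:
        # Produces "entity/project/name:version", as asserted in test_wandb_path.
        return f"{self.entity}/{self.project}/{self.name}:{self.version}"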
33 changes: 22 additions & 11 deletions tests/integrations/wandb/test_run_config.py
@@ -4,27 +4,38 @@
from flamingo.integrations.wandb import WandbRunConfig


def test_env_vars(wandb_run_config_generator):
    env_vars = wandb_run_config_generator().get_env_vars()
@pytest.fixture
def wandb_run_config():
    return WandbRunConfig(
        name="run-name",
        run_id="run-id",
        project="cortex-research",
        entity="twitter.com",
    )


def test_serde_round_trip(wandb_run_config):
    assert WandbRunConfig.parse_raw(wandb_run_config.json()) == wandb_run_config


def test_wandb_path(wandb_run_config):
    assert wandb_run_config.wandb_path == "twitter.com/cortex-research/run-id"


def test_env_vars(wandb_run_config):
    env_vars = wandb_run_config.get_env_vars()
    expected = ["WANDB_NAME", "WANDB_PROJECT", "WANDB_ENTITY", "WANDB_RUN_ID"]
    for key in expected:
        assert key in env_vars
    assert "WANDB_RUN_GROUP" not in env_vars


def test_serde_round_trip(wandb_run_config_generator):
    assert (
        WandbRunConfig.parse_raw(wandb_run_config_generator().json())
        == wandb_run_config_generator()
    )


def test_disallowed_kwargs():
    with pytest.raises(ValidationError):
        WandbRunConfig(name="name", project="project", old_name="I will throw")


def test_missing_key_warning(mock_environment_without_keys):
    with pytest.warns(UserWarning):
        env = WandbRunConfig(name="I am missing an API key", project="I should warn the user")
        assert "WANDB_API_KEY" not in env.env_vars
        config = WandbRunConfig(name="I am missing an API key", project="I should warn the user")
        assert "WANDB_API_KEY" not in config.get_env_vars()
12 changes: 6 additions & 6 deletions tests/jobs/conftest.py
@@ -12,35 +12,35 @@

@pytest.fixture
def model_config_with_path():
return AutoModelConfig("mistral-ai/mistral-7", trust_remote_code=True)
return AutoModelConfig(path="mistral-ai/mistral-7", trust_remote_code=True)


@pytest.fixture
def model_config_with_artifact():
artifact = WandbArtifactConfig(name="model")
return AutoModelConfig(artifact, trust_remote_code=True)
return AutoModelConfig(path=artifact, trust_remote_code=True)


@pytest.fixture
def tokenizer_config_with_path():
return AutoTokenizerConfig("mistral-ai/mistral-7", trust_remote_code=True)
return AutoTokenizerConfig(path="mistral-ai/mistral-7", trust_remote_code=True)


@pytest.fixture
def tokenizer_config_with_artifact():
artifact = WandbArtifactConfig(name="tokenizer")
return AutoTokenizerConfig(artifact, trust_remote_code=True)
return AutoTokenizerConfig(path=artifact, trust_remote_code=True)


@pytest.fixture
def dataset_config_with_path():
return DatasetConfig("databricks/dolly7b", split="train")
return DatasetConfig(path="databricks/dolly15k", split="train")


@pytest.fixture
def dataset_config_with_artifact():
artifact = WandbArtifactConfig(name="dataset")
return DatasetConfig(artifact, split="train")
return DatasetConfig(path=artifact, split="train")


@pytest.fixture
82 changes: 66 additions & 16 deletions tests/jobs/test_finetuning_config.py
@@ -1,22 +1,72 @@
from peft import LoraConfig
from ray.train import ScalingConfig
import pytest
from pydantic import ValidationError

from flamingo.integrations.huggingface import QuantizationConfig, TrainerConfig
from flamingo.integrations.huggingface import AutoModelConfig, AutoTokenizerConfig, DatasetConfig
from flamingo.jobs import FinetuningJobConfig
from flamingo.jobs.finetuning_config import FinetuningRayConfig


def test_serde_round_trip():
    trainer_config = TrainerConfig(torch_dtype="bfloat16")
    lora_config = LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM")
    quantization_config = QuantizationConfig(load_in_8bit=True)
    scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
    config = FinetuningJobConfig(
        model="test-model",
        dataset="test-dataset",
        trainer=trainer_config,
        lora=lora_config,
@pytest.fixture
def finetuning_ray_config():
    return FinetuningRayConfig(
        num_workers=4,
        use_gpu=True,
    )


@pytest.fixture
def finetuning_job_config(
    model_config_with_artifact,
    dataset_config_with_artifact,
    tokenizer_config_with_artifact,
    quantization_config,
    lora_config,
    wandb_run_config,
    finetuning_ray_config,
):
    return FinetuningJobConfig(
        model=model_config_with_artifact,
        dataset=dataset_config_with_artifact,
        tokenizer=tokenizer_config_with_artifact,
        quantization=quantization_config,
        scaling=scaling_config,
        storage_path="/mnt/data/ray_results",
        adapter=lora_config,
        tracking=wandb_run_config,
        ray=finetuning_ray_config,
    )


def test_serde_round_trip(finetuning_job_config):
    assert FinetuningJobConfig.parse_raw(finetuning_job_config.json()) == finetuning_job_config


def test_parse_yaml_file(finetuning_job_config, tmp_path_factory):
    config_path = tmp_path_factory.mktemp("flamingo_tests") / "finetuning_config.yaml"
    finetuning_job_config.to_yaml_file(config_path)
    assert finetuning_job_config == FinetuningJobConfig.from_yaml_file(config_path)


def test_argument_validation():
    # Strings should be upcast to configs as the path argument
    allowed_config = FinetuningJobConfig(
        model="model_path",
        tokenizer="tokenizer_path",
        dataset="dataset_path",
    )
    assert allowed_config.model == AutoModelConfig(path="model_path")
    assert allowed_config.tokenizer == AutoTokenizerConfig(path="tokenizer_path")
    assert allowed_config.dataset == DatasetConfig(path="dataset_path")

    # Check passing invalid arguments is validated for each asset type
    with pytest.raises(ValidationError):
        FinetuningJobConfig(model=12345, tokenizer="tokenizer_path", dataset="dataset_path")
    with pytest.raises(ValidationError):
        FinetuningJobConfig(model="model_path", tokenizer=12345, dataset="dataset_path")
    with pytest.raises(ValidationError):
        FinetuningJobConfig(model="model_path", tokenizer="tokenizer_path", dataset=12345)

    # Check that tokenizer is set to model path when absent
    missing_tokenizer_config = FinetuningJobConfig(
        model="model_path",
        dataset="dataset_path",
    )
    assert FinetuningJobConfig.parse_raw(config.json()) == config
    assert missing_tokenizer_config.tokenizer.path == "model_path"
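
Note: test_argument_validation relies on bare strings being upcast into config objects (with the string used as the path) and on the tokenizer falling back to the model path when omitted. One plausible way to express that coercion with Pydantic v1 validators is sketched below; the class is a simplified, hypothetical stand-in, not the FinetuningJobConfig source changed in this commit.

from typing import Optional

from pydantic import BaseModel, validator


class AutoModelConfig(BaseModel):
    path: str


class ExampleJobConfig(BaseModel):
    # Hypothetical, simplified stand-in for FinetuningJobConfig.
    model: AutoModelConfig
    tokenizer: Optional[AutoModelConfig] = None

    @validator("model", pre=True)
    def _coerce_model(cls, value):
        # Upcast a bare string into a config using the string as its path.
        return AutoModelConfig(path=value) if isinstance(value, str) else value

    @validator("tokenizer", pre=True, always=True)
    def _default_tokenizer(cls, value, values):
        # Fall back to the model config when no tokenizer is given.
        if value is None:
            return values.get("model")
        return AutoModelConfig(path=value) if isinstance(value, str) else value

With this sketch, ExampleJobConfig(model="model_path").tokenizer.path evaluates to "model_path", mirroring the final assertion in the test above.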
43 changes: 22 additions & 21 deletions tests/jobs/test_lm_harness_config.py
@@ -4,7 +4,6 @@
from flamingo.integrations.huggingface import AutoModelConfig
from flamingo.jobs import LMHarnessJobConfig
from flamingo.jobs.lm_harness_config import LMHarnessEvaluatorConfig, LMHarnessRayConfig
from tests.conftest import TEST_RESOURCES


@pytest.fixture
@@ -23,37 +22,39 @@ def lm_harness_ray_config():
    )


def test_model_validation(lm_harness_evaluator_config):
    allowed_config = LMHarnessJobConfig(model="hf_repo_id", evaluator=lm_harness_evaluator_config)
    assert allowed_config.model == AutoModelConfig(path="hf_repo_id")

    with pytest.raises(ValidationError):
        LMHarnessJobConfig(model="invalid...hf..repo", evaluator=lm_harness_evaluator_config)

    with pytest.raises(ValidationError):
        LMHarnessJobConfig(model=12345, evaluator=lm_harness_evaluator_config)


def test_serde_round_trip(
@pytest.fixture
def lm_harness_job_config(
    model_config_with_artifact,
    quantization_config,
    wandb_run_config,
    lm_harness_evaluator_config,
    lm_harness_ray_config,
):
    config = LMHarnessJobConfig(
    return LMHarnessJobConfig(
        model=model_config_with_artifact,
        evaluator=lm_harness_evaluator_config,
        ray=lm_harness_ray_config,
        tracking=wandb_run_config,
        quantization=quantization_config,
    )
    assert LMHarnessJobConfig.parse_raw(config.json()) == config


def test_parse_yaml_file(tmp_path_factory):
    load_path = TEST_RESOURCES / "lm_harness_config.yaml"
    config = LMHarnessJobConfig.from_yaml_file(load_path)
    write_path = tmp_path_factory.mktemp("flamingo_tests") / "harness_config.yaml"
    config.to_yaml_file(write_path)
    assert config == LMHarnessJobConfig.from_yaml_file(write_path)
def test_serde_round_trip(lm_harness_job_config):
    assert LMHarnessJobConfig.parse_raw(lm_harness_job_config.json()) == lm_harness_job_config


def test_parse_yaml_file(lm_harness_job_config, tmp_path_factory):
    config_path = tmp_path_factory.mktemp("flamingo_tests") / "lm_harness_config.yaml"
    lm_harness_job_config.to_yaml_file(config_path)
    assert lm_harness_job_config == LMHarnessJobConfig.from_yaml_file(config_path)


def test_model_validation(lm_harness_evaluator_config):
    allowed_config = LMHarnessJobConfig(model="hf_repo_id", evaluator=lm_harness_evaluator_config)
    assert allowed_config.model == AutoModelConfig(path="hf_repo_id")

    with pytest.raises(ValidationError):
        LMHarnessJobConfig(model="invalid...hf..repo", evaluator=lm_harness_evaluator_config)

    with pytest.raises(ValidationError):
        LMHarnessJobConfig(model=12345, evaluator=lm_harness_evaluator_config)
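
Note: both versions of test_model_validation expect a bare model string to be checked as a plausible Hugging Face repo ID before it is upcast, since "invalid...hf..repo" must raise a ValidationError. The check below only illustrates that idea with a deliberately simplified, hypothetical rule; it is not the validation logic the library actually uses.

import re


def is_valid_repo_id(repo_id: str) -> bool:
    # Simplified rule: optional "namespace/" prefix, no "..", URL-safe characters only.
    pattern = r"^[A-Za-z0-9][\w.-]*(/[A-Za-z0-9][\w.-]*)?$"
    return ".." not in repo_id and re.fullmatch(pattern, repo_id) is not None


assert is_valid_repo_id("hf_repo_id")
assert is_valid_repo_id("mistral-ai/mistral-7")
assert not is_valid_repo_id("invalid...hf..repo")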
