Commit 81440e1
Sean Friedowitz committed on Jan 17, 2024 (1 parent: 5ae6bac)
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 9 changed files with 137 additions and 58 deletions.
File renamed without changes.
File renamed without changes.
@@ -1 +0,0 @@
-from pathlib import Path
@@ -0,0 +1,21 @@
+import pytest
+
+from flamingo.integrations.wandb import WandbArtifactConfig
+
+
+@pytest.fixture
+def wandb_artifact_config():
+    return WandbArtifactConfig(
+        name="artifact-name",
+        version="latest",
+        project="cortex-research",
+        entity="twitter.com",
+    )
+
+
+def test_serde_round_trip(wandb_artifact_config):
+    assert WandbArtifactConfig.parse_raw(wandb_artifact_config.json()) == wandb_artifact_config
+
+
+def test_wandb_path(wandb_artifact_config):
+    assert wandb_artifact_config.wandb_path == "twitter.com/cortex-research/artifact-name:latest"
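For context, the assertions above pin down the shape of the artifact config: four string fields plus a wandb_path property of the form entity/project/name:version. A minimal sketch that would satisfy these tests, assuming a pydantic v1 BaseModel (an inference from the tests, not the actual flamingo class):

from typing import Optional

from pydantic import BaseModel


class WandbArtifactConfig(BaseModel):
    """Pointer to a W&B artifact, addressable as 'entity/project/name:version'."""

    name: str
    version: str = "latest"
    project: Optional[str] = None
    entity: Optional[str] = None

    @property
    def wandb_path(self) -> str:
        # Matches test_wandb_path, e.g. "twitter.com/cortex-research/artifact-name:latest"
        return f"{self.entity}/{self.project}/{self.name}:{self.version}"

Because parse_raw and json are used in the round-trip test, the class is presumably a plain pydantic model whose field-wise equality makes the serde assertion hold.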
@@ -1,22 +1,72 @@
-from peft import LoraConfig
-from ray.train import ScalingConfig
+import pytest
+from pydantic import ValidationError
 
-from flamingo.integrations.huggingface import QuantizationConfig, TrainerConfig
+from flamingo.integrations.huggingface import AutoModelConfig, AutoTokenizerConfig, DatasetConfig
 from flamingo.jobs import FinetuningJobConfig
+from flamingo.jobs.finetuning_config import FinetuningRayConfig
 
 
-def test_serde_round_trip():
-    trainer_config = TrainerConfig(torch_dtype="bfloat16")
-    lora_config = LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM")
-    quantization_config = QuantizationConfig(load_in_8bit=True)
-    scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
-    config = FinetuningJobConfig(
-        model="test-model",
-        dataset="test-dataset",
-        trainer=trainer_config,
-        lora=lora_config,
+@pytest.fixture
+def finetuning_ray_config():
+    return FinetuningRayConfig(
+        num_workers=4,
+        use_gpu=True,
+    )
+
+
+@pytest.fixture
+def finetuning_job_config(
+    model_config_with_artifact,
+    dataset_config_with_artifact,
+    tokenizer_config_with_artifact,
+    quantization_config,
+    lora_config,
+    wandb_run_config,
+    finetuning_ray_config,
+):
+    return FinetuningJobConfig(
+        model=model_config_with_artifact,
+        dataset=dataset_config_with_artifact,
+        tokenizer=tokenizer_config_with_artifact,
         quantization=quantization_config,
-        scaling=scaling_config,
-        storage_path="/mnt/data/ray_results",
+        adapter=lora_config,
+        tracking=wandb_run_config,
+        ray=finetuning_ray_config,
     )
-    assert FinetuningJobConfig.parse_raw(config.json()) == config
+
+
+def test_serde_round_trip(finetuning_job_config):
+    assert FinetuningJobConfig.parse_raw(finetuning_job_config.json()) == finetuning_job_config
+
+
+def test_parse_yaml_file(finetuning_job_config, tmp_path_factory):
+    config_path = tmp_path_factory.mktemp("flamingo_tests") / "finetuning_config.yaml"
+    finetuning_job_config.to_yaml_file(config_path)
+    assert finetuning_job_config == FinetuningJobConfig.from_yaml_file(config_path)
+
+
+def test_argument_validation():
+    # Strings should be upcast to configs as the path argument
+    allowed_config = FinetuningJobConfig(
+        model="model_path",
+        tokenizer="tokenizer_path",
+        dataset="dataset_path",
+    )
+    assert allowed_config.model == AutoModelConfig(path="model_path")
+    assert allowed_config.tokenizer == AutoTokenizerConfig(path="tokenizer_path")
+    assert allowed_config.dataset == DatasetConfig(path="dataset_path")
+
+    # Check passing invalid arguments is validated for each asset type
+    with pytest.raises(ValidationError):
+        FinetuningJobConfig(model=12345, tokenizer="tokenizer_path", dataset="dataset_path")
+    with pytest.raises(ValidationError):
+        FinetuningJobConfig(model="model_path", tokenizer=12345, dataset="dataset_path")
+    with pytest.raises(ValidationError):
+        FinetuningJobConfig(model="model_path", tokenizer="tokenizer_path", dataset=12345)
+
+    # Check that tokenizer is set to model path when absent
+    missing_tokenizer_config = FinetuningJobConfig(
+        model="model_path",
+        dataset="dataset_path",
+    )
+    assert missing_tokenizer_config.tokenizer.path == "model_path"
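Two behaviors in test_argument_validation are worth calling out: bare strings are upcast into the corresponding asset config (with the string as its path), and a missing tokenizer falls back to the model path. A minimal sketch of how that could be wired up, assuming pydantic v1 validators and simplified placeholder asset configs; this is inferred from the tests, not flamingo's actual implementation:

from typing import Optional

from pydantic import BaseModel, root_validator, validator


class AutoModelConfig(BaseModel):
    path: str


class AutoTokenizerConfig(BaseModel):
    path: str


class DatasetConfig(BaseModel):
    path: str


class FinetuningJobConfig(BaseModel):
    model: AutoModelConfig
    dataset: DatasetConfig
    tokenizer: Optional[AutoTokenizerConfig] = None

    @root_validator(pre=True)
    def _default_tokenizer_to_model(cls, values):
        # When no tokenizer is provided, fall back to the model path
        if values.get("tokenizer") is None:
            model = values.get("model")
            if isinstance(model, str):
                values["tokenizer"] = model
            elif isinstance(model, dict):
                values["tokenizer"] = model.get("path")
        return values

    @validator("model", "tokenizer", "dataset", pre=True)
    def _upcast_strings_to_configs(cls, value):
        # Bare strings become {"path": value}; anything else (e.g. 12345)
        # falls through to pydantic's normal type validation and fails there
        return {"path": value} if isinstance(value, str) else value

Running the coercion in pre=True validators means strings are rewritten before pydantic's type checking, which is why the tests can pass bare paths everywhere yet still get a ValidationError for non-string junk like 12345.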