This repository has been archived by the owner on Sep 24, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Sean Friedowitz
committed
Jan 17, 2024
1 parent
5e54d10
commit 5ae6bac
Showing
10 changed files
with
147 additions
and
89 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,9 @@ | ||
from .base_config import BaseJobConfig | ||
from .finetuning_config import FinetuningJobConfig | ||
from .lm_harness_config import LMHarnessJobConfig, ModelNameOrCheckpointPath | ||
from .lm_harness_config import LMHarnessJobConfig | ||
from .simple_config import SimpleJobConfig | ||
|
||
__all__ = [ | ||
"BaseJobConfig", | ||
"SimpleJobConfig", | ||
"FinetuningJobConfig", | ||
"LMHarnessJobConfig", | ||
"ModelNameOrCheckpointPath", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from pathlib import Path | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import pytest | ||
from peft import LoraConfig | ||
|
||
from flamingo.integrations.huggingface import ( | ||
AutoModelConfig, | ||
AutoTokenizerConfig, | ||
DatasetConfig, | ||
QuantizationConfig, | ||
) | ||
from flamingo.integrations.wandb import WandbArtifactConfig, WandbRunConfig | ||
|
||
|
||
@pytest.fixture | ||
def model_config_with_path(): | ||
return AutoModelConfig("mistral-ai/mistral-7", trust_remote_code=True) | ||
|
||
|
||
@pytest.fixture | ||
def model_config_with_artifact(): | ||
artifact = WandbArtifactConfig(name="model") | ||
return AutoModelConfig(artifact, trust_remote_code=True) | ||
|
||
|
||
@pytest.fixture | ||
def tokenizer_config_with_path(): | ||
return AutoTokenizerConfig("mistral-ai/mistral-7", trust_remote_code=True) | ||
|
||
|
||
@pytest.fixture | ||
def tokenizer_config_with_artifact(): | ||
artifact = WandbArtifactConfig(name="tokenizer") | ||
return AutoTokenizerConfig(artifact, trust_remote_code=True) | ||
|
||
|
||
@pytest.fixture | ||
def dataset_config_with_path(): | ||
return DatasetConfig("databricks/dolly7b", split="train") | ||
|
||
|
||
@pytest.fixture | ||
def dataset_config_with_artifact(): | ||
artifact = WandbArtifactConfig(name="dataset") | ||
return DatasetConfig(artifact, split="train") | ||
|
||
|
||
@pytest.fixture | ||
def quantization_config(): | ||
return QuantizationConfig(load_in_8bit=True) | ||
|
||
|
||
@pytest.fixture | ||
def lora_config(): | ||
return LoraConfig(r=8, lora_alpha=32, lora_dropout=0.2) | ||
|
||
|
||
@pytest.fixture | ||
def wandb_run_config(): | ||
return WandbRunConfig(name="run", run_id="12345", project="research", entity="mozilla-ai") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,59 @@ | ||
from pathlib import Path | ||
|
||
import pytest | ||
from pydantic import ValidationError | ||
|
||
from flamingo.integrations.huggingface import AutoModelConfig | ||
from flamingo.jobs import LMHarnessJobConfig | ||
from flamingo.jobs.lm_harness_config import LMHarnessEvaluatorConfig, LMHarnessRayConfig | ||
from tests.conftest import TEST_RESOURCES | ||
|
||
|
||
def test_bad_hf_name(default_lm_harness_config): | ||
with pytest.raises(ValidationError): | ||
default_lm_harness_config(model_name_or_path="dfa../invalid") | ||
@pytest.fixture | ||
def lm_harness_evaluator_config(): | ||
return LMHarnessEvaluatorConfig( | ||
tasks=["task1", "task2", "task3"], | ||
num_fewshot=5, | ||
) | ||
|
||
|
||
def test_serde_round_trip_default_config(default_lm_harness_config): | ||
config = default_lm_harness_config() | ||
assert LMHarnessJobConfig.parse_raw(config.json()) == config | ||
@pytest.fixture | ||
def lm_harness_ray_config(): | ||
return LMHarnessRayConfig( | ||
num_workers=4, | ||
use_gpu=True, | ||
) | ||
|
||
|
||
def test_model_validation(lm_harness_evaluator_config): | ||
allowed_config = LMHarnessJobConfig(model="hf_repo_id", evaluator=lm_harness_evaluator_config) | ||
assert allowed_config.model == AutoModelConfig(path="hf_repo_id") | ||
|
||
def test_serde_round_trip_with_path(default_lm_harness_config): | ||
config = default_lm_harness_config(model_name_or_path=Path("fake/path")) | ||
with pytest.raises(ValidationError): | ||
LMHarnessJobConfig(model="invalid...hf..repo", evaluator=lm_harness_evaluator_config) | ||
|
||
with pytest.raises(ValidationError): | ||
LMHarnessJobConfig(model=12345, evaluator=lm_harness_evaluator_config) | ||
|
||
|
||
def test_serde_round_trip( | ||
model_config_with_artifact, | ||
quantization_config, | ||
wandb_run_config, | ||
lm_harness_evaluator_config, | ||
lm_harness_ray_config, | ||
): | ||
config = LMHarnessJobConfig( | ||
model=model_config_with_artifact, | ||
evaluator=lm_harness_evaluator_config, | ||
ray=lm_harness_ray_config, | ||
tracking=wandb_run_config, | ||
quantization=quantization_config, | ||
) | ||
assert LMHarnessJobConfig.parse_raw(config.json()) == config | ||
|
||
|
||
def test_parse_from_yaml(default_lm_harness_config, tmp_path_factory): | ||
config = default_lm_harness_config(model_name_or_path="not_a_real_model") | ||
p = tmp_path_factory.mktemp("test_yaml") / "eval.yaml" | ||
config.to_yaml_file(p) | ||
assert config == LMHarnessJobConfig.from_yaml_file(p) | ||
def test_parse_yaml_file(tmp_path_factory): | ||
load_path = TEST_RESOURCES / "lm_harness_config.yaml" | ||
config = LMHarnessJobConfig.from_yaml_file(load_path) | ||
write_path = tmp_path_factory.mktemp("flamingo_tests") / "harness_config.yaml" | ||
config.to_yaml_file(write_path) | ||
assert config == LMHarnessJobConfig.from_yaml_file(write_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters