diff --git a/examples/configs/finetuning.yaml b/examples/configs/finetuning.yaml
new file mode 100644
index 00000000..08e94414
--- /dev/null
+++ b/examples/configs/finetuning.yaml
@@ -0,0 +1,50 @@
+# Base model to load for finetuning
+model:
+  load_from:
+    repo_id: "distilgpt2"
+  # Can also specify the asset to load as a W&B artifact
+  # load_from:
+  #   name: "artifact-name"
+  #   project: "artifact-project"
+  #   version: "v0"
+  torch_dtype: "bfloat16"
+
+# Tokenizer section (when not defined, will default to the model value)
+# tokenizer: "distilgpt2"
+
+# Text dataset to use for training
+dataset:
+  load_from:
+    repo_id: "imdb"
+    split: "train[:100]"
+  test_size: 0.2
+  text_field: "text"
+
+trainer:
+  max_seq_length: 512
+  learning_rate: 0.001
+  num_train_epochs: 2
+  save_steps: 1
+  save_strategy: "epoch"
+  logging_steps: 1
+  logging_strategy: "steps"
+
+# Quantization section (not necessary when using LoRA with built-in LoftQ)
+# quantization:
+
+adapter:
+  peft_type: "LORA"
+  task_type: "CAUSAL_LM"
+  r: 16
+  lora_alpha: 32
+  lora_dropout: 0.2
+
+# Tracking info for where to log the run results
+tracking:
+  name: "flamingo-example-finetuning"
+  project: "flamingo-examples"
+  entity: "mozilla-ai"
+
+ray:
+  use_gpu: True
+  num_workers: 2
diff --git a/examples/configs/finetuning_config.yaml b/examples/configs/finetuning_config.yaml
deleted file mode 100644
index 99a581eb..00000000
--- a/examples/configs/finetuning_config.yaml
+++ /dev/null
@@ -1,42 +0,0 @@
-# Tokenizer defined by only a string repo ID
-tokenizer: "mistral-ai/other-repo-with-special-tokenizer"
-
-# Model defined as an object with additional settings beyond the repo ID
-model:
-  load_from: "mistral-ai/mistral-7b"
-  trust_remote_code: True
-  torch_dtype: "bfloat16"
-
-# Dataset defined as an object with a path linking to a W&B artifact
-dataset:
-  load_from:
-    name: "dataset-artifact"
-    version: "latest"
-    project: "research-project"
-  split: "train"
-  test_size: 0.2
-
-trainer:
-  max_seq_length: 512
-  learning_rate: 0.1
-  num_train_epochs: 2
-
-quantization:
-  load_in_4bit: True
-  bnb_4bit_quant_type: "fp4"
-
-adapter:
-  peft_type: "LORA"
-  task_type: "CAUSAL_LM"
-  r: 16
-  lora_alpha: 32
-  lora_dropout: 0.2
-
-tracking:
-  name: "location-to-log-results"
-  project: "another-project"
-  entity: "another-entity"
-
-ray:
-  use_gpu: True
-  num_workers: 4
diff --git a/examples/configs/lm_harness.yaml b/examples/configs/lm_harness.yaml
new file mode 100644
index 00000000..1e8380ca
--- /dev/null
+++ b/examples/configs/lm_harness.yaml
@@ -0,0 +1,25 @@
+# Model to evaluate
+model:
+  load_from: "distilgpt2"
+  torch_dtype: "bfloat16"
+
+# Settings specific to lm_harness.evaluate
+evaluator:
+  tasks: ["hellaswag"]
+  num_fewshot: 5
+  limit: 10
+
+quantization:
+  load_in_4bit: True
+  bnb_4bit_quant_type: "fp4"
+
+# Tracking info for where to log the run results
+tracking:
+  name: "flamingo-example-lm-harness"
+  project: "flamingo-examples"
+  entity: "mozilla-ai"
+
+ray:
+  num_cpus: 1
+  num_gpus: 1
+  timeout: 3600
diff --git a/examples/configs/lm_harness_config.yaml b/examples/configs/lm_harness_config.yaml
deleted file mode 100644
index 46df88c5..00000000
--- a/examples/configs/lm_harness_config.yaml
+++ /dev/null
@@ -1,28 +0,0 @@
-# Model to evaluate, specified as a W&B artifact
-model:
-  load_from:
-    name: "training-run-model-artifact"
-    version: "v4"
-    project: "research-project"
-    entity: "mozilla.ai"
-  trust_remote_code: True
-  torch_dtype: "float16"
-
-# Settings specific to lm_harness.evaluate
-evaluator:
-  tasks: ["task1", "task2", "...", "taskN"]
-  num_fewshot: 5
-
-quantization:
-  load_in_4bit: True
-  bnb_4bit_quant_type: "fp4"
-
-tracking:
-  name: "location-to-log-results"
-  project: "another-project"
-  entity: "another-entity"
-
-ray:
-  num_cpus: 1
-  num_gpus: 4
-  timeout: 3600
diff --git a/examples/configs/simple_config.yaml b/examples/configs/simple.yaml
similarity index 100%
rename from examples/configs/simple_config.yaml
rename to examples/configs/simple.yaml
diff --git a/examples/dev_workflow.ipynb b/examples/dev_workflow.ipynb
index d7931b0e..5ee3bdf4 100644
--- a/examples/dev_workflow.ipynb
+++ b/examples/dev_workflow.ipynb
@@ -8,10 +8,30 @@
    "# Development Workflow"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "id": "9366fd9e",
+  "metadata": {},
+  "source": [
+   "## File-based Submission"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "id": "fcd5240e",
+  "metadata": {},
+  "source": [
+   "This demonstrates the basic workflow for submitting a Flamingo job to Ray\n",
+   "from a configuration stored as a local file.\n",
+   "\n",
+   "The job configuration is stored as a YAML file in the local `configs` directory,\n",
+   "and that directory is specified as the working directory of the Ray runtime environment upon submission."
+  ]
+ },
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "8c0f15ed-77dc-44ce-adb6-d1b59368f03c",
+  "id": "9b26d777",
   "metadata": {},
   "outputs": [],
   "source": [
@@ -19,10 +39,7 @@
    "import os\n",
    "from pathlib import Path\n",
    "\n",
-   "from ray.job_submission import JobSubmissionClient\n",
-   "\n",
-   "# flamingo should be installed in your development environment\n",
-   "import flamingo"
+   "from ray.job_submission import JobSubmissionClient"
   ]
  },
 {
@@ -45,25 +62,11 @@
   "outputs": [],
   "source": [
    "# Determine local module path for the flamingo repo\n",
-   "flamingo_module = Path(flamingo.__file__).parent"
-  ]
- },
- {
-  "cell_type": "code",
-  "execution_count": null,
-  "id": "1db3b9aa-99a4-49d9-8773-7b91ccf89c85",
-  "metadata": {},
-  "outputs": [],
-  "source": [
-   "# Load and inspect the config file\n",
-   "# Not mandatory for job submission, but helpful when debugging\n",
-   "from flamingo.jobs.simple import SimpleJobConfig\n",
-   "\n",
-   "CONFIG_DIR = Path(\"configs\")\n",
-   "CONFIG_FILE = \"simple_config.yaml\"\n",
+   "# In theory this workflow is possible without having the flamingo package installed locally,\n",
+   "# but this is a convenient means to access the local module path\n",
+   "import flamingo\n",
    "\n",
-   "config = SimpleJobConfig.from_yaml_file(CONFIG_DIR / CONFIG_FILE)\n",
-   "config"
+   "flamingo_module = Path(flamingo.__file__).parent"
   ]
  },
 {
@@ -77,10 +80,10 @@
    "# py_modules contains the path to the local flamingo module directory\n",
    "# pip contains an export of the dependencies for the flamingo package (see CONTRIBUTING.md for how to generate)\n",
    "runtime_env = {\n",
-   "    \"working_dir\": str(CONFIG_DIR),\n",
+   "    \"working_dir\": \"configs\",\n",
    "    \"env_vars\": {\"WANDB_API_KEY\": os.environ[\"WANDB_API_KEY\"]}, # If running a job that uses W&B\n",
    "    \"py_modules\": [str(flamingo_module)],\n",
-   "    \"pip\": \"/path/to/flamingo/requirements.txt\",\n",
+   "    \"pip\": \"requirements.txt\", # See CONTRIBUTING.md for how to generate this\n",
    "}"
   ]
  },
@@ -94,9 +97,111 @@
    "# Submit the job to the Ray cluster\n",
    "# Note: flamingo is invoked by 'python -m flamingo' since the CLI is not installed in the environment\n",
    "client.submit_job(\n",
-   "    entrypoint=f\"python -m flamingo run simple --config {CONFIG_FILE}\", runtime_env=runtime_env\n",
+   "    entrypoint=\"python -m flamingo run simple --config simple.yaml\",\n",
+   "    runtime_env=runtime_env,\n",
    ")"
   ]
+ },
+ {
+  "cell_type": "markdown",
+  "id": "425be140",
+  "metadata": {},
+  "source": [
+   "## Iterative Submission"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "id": "e99ce273",
+  "metadata": {},
+  "source": [
+   "It is also possible to submit Flamingo jobs using a fully Python/Jupyter driven workflow.\n",
+   "\n",
+   "In this case, the Flamingo job configuration is instantiated in your Python script\n",
+   "and written to a temporary directory for submission.\n",
+   "\n",
+   "The Ray working directory is based on this temporary YAML file location.\n",
+   "\n",
+   "This approach is convenient if you want to run sweeps over parameter ranges\n",
+   "and use a Python script/Jupyter notebook as your local \"driver\" for the workflow."
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "0cfccaa9",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "# Required imports\n",
+   "import os\n",
+   "from pathlib import Path\n",
+   "\n",
+   "from ray.job_submission import JobSubmissionClient"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "a51235ed",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "# Create a submission client bound to a Ray cluster\n",
+   "# Note: You will likely have to update the cluster address shown below\n",
+   "client = JobSubmissionClient(\"http://10.147.154.77:8265\")"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "a1216c43",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "# Determine local module path for the flamingo repo\n",
+   "# In theory this workflow is possible without having the flamingo package installed locally,\n",
+   "# but this is a convenient means to access the local module path\n",
+   "import flamingo\n",
+   "\n",
+   "flamingo_module = Path(flamingo.__file__).parent"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "e5715d09",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "import os\n",
+   "\n",
+   "from flamingo.jobs.simple import SimpleJobConfig\n",
+   "\n",
+   "# Generate job configs programmatically for sweeps over parameter ranges\n",
+   "magic_numbers = [0, 10, 20, 40]\n",
+   "\n",
+   "for number in magic_numbers:\n",
+   "    # Instantiate the config in your workflow script\n",
+   "    # You may also want to read a \"base\" config from file with some suitable defaults\n",
+   "    config = SimpleJobConfig(magic_number=number)\n",
+   "\n",
+   "    # `config_path` is the fully qualified path to the config file on your local filesystem\n",
+   "    with config.to_tempfile(name=\"config.yaml\") as config_path:\n",
+   "        # `config_path.parent` is the working directory\n",
+   "        runtime_env = {\n",
+   "            \"working_dir\": str(config_path.parent),\n",
+   "            \"env_vars\": {\"WANDB_API_KEY\": os.environ[\"WANDB_API_KEY\"]},\n",
+   "            \"py_modules\": [str(flamingo_module)],\n",
+   "            \"pip\": \"requirements.txt\", # See CONTRIBUTING.md for how to generate this\n",
+   "        }\n",
+   "\n",
+   "        # `config_path.name` is the file name within the working directory, i.e., \"config.yaml\"\n",
+   "        client.submit_job(\n",
+   "            entrypoint=f\"python -m flamingo run simple --config {config_path.name}\",\n",
+   "            runtime_env=runtime_env,\n",
+   "        )"
+  ]
  }
 ],
 "metadata": {
diff --git a/pyproject.toml b/pyproject.toml
index d729da83..1daea82d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ protobuf = "3.20.0"
 urllib3 = ">=1.26.18,<2"
 pydantic = "1.10.14"
 pydantic-yaml = "1.2.0"
-ray = { version = "2.8.0", extras = ["default"] }
+ray = { version = "2.9.1", extras = ["default"] }
 
 [tool.poetry.dev-dependencies]
 ruff = "0.1.7"
diff --git a/src/flamingo/integrations/wandb/artifact_config.py b/src/flamingo/integrations/wandb/artifact_config.py
index 797cacde..c38f115b 100644
--- a/src/flamingo/integrations/wandb/artifact_config.py
+++ b/src/flamingo/integrations/wandb/artifact_config.py
@@ -1,3 +1,5 @@
+import re
+
 from flamingo.types import BaseFlamingoConfig
 
 
@@ -5,12 +7,26 @@ class WandbArtifactConfig(BaseFlamingoConfig):
     """Configuration required to retrieve an artifact from W&B."""
 
     name: str
+    project: str
     version: str = "latest"
-    project: str | None = None
     entity: str | None = None
 
+    @classmethod
+    def from_wandb_path(cls, path: str) -> "WandbArtifactConfig":
+        """Construct an artifact configuration from a W&B path string.
+
+        The path should be of the form "entity/project/name:version",
+        with the "entity" component optional.
+        """
+        match = re.search(r"((.*)\/)?(.*)\/(.*)\:(.*)", path)
+        if match is not None:
+            entity, project, name, version = match.groups()[1:]
+            return cls(name=name, project=project, version=version, entity=entity)
+        raise ValueError(f"Invalid artifact path: {path}")
+
     def wandb_path(self) -> str:
         """String identifier for the asset on the W&B platform."""
-        path = "/".join(x for x in [self.entity, self.project, self.name] if x is not None)
-        path = f"{path}:{self.version}"
+        path = f"{self.project}/{self.name}:{self.version}"
+        if self.entity is not None:
+            path = f"{self.entity}/{path}"
         return path
diff --git a/src/flamingo/integrations/wandb/run_utils.py b/src/flamingo/integrations/wandb/run_utils.py
index afd55630..cd56ef17 100644
--- a/src/flamingo/integrations/wandb/run_utils.py
+++ b/src/flamingo/integrations/wandb/run_utils.py
@@ -1,4 +1,5 @@
 import contextlib
+from enum import Enum
 from typing import Any
 
 import wandb
@@ -8,20 +9,32 @@
 from flamingo.types import BaseFlamingoConfig
 
 
+class WandbResumeMode(str, Enum):
+    """Enumeration of modes for resuming a W&B run.
+
+    This is not an exhaustive list of the values that can be passed to the W&B SDK
+    (Docs: https://docs.wandb.ai/ref/python/init), but just those commonly used within the package.
+    """
+
+    ALLOW = "allow"
+    MUST = "must"
+    NEVER = "never"
+
+
 @contextlib.contextmanager
 def wandb_init_from_config(
     config: WandbRunConfig,
     *,
     parameters: BaseFlamingoConfig | None = None,
+    resume: WandbResumeMode | None = None,
     job_type: str | None = None,
-    resume: str | None = None,
 ):
     """Initialize a W&B run from the internal run configuration.
 
     This method can be entered as a context manager similar to `wandb.init` as follows:
 
     ```
-    with wandb_init_from_config(run_config, resume="must") as run:
+    with wandb_init_from_config(run_config, resume=WandbResumeMode.MUST) as run:
         # Use the initialized run here
         ...
     ```
diff --git a/src/flamingo/jobs/finetuning/entrypoint.py b/src/flamingo/jobs/finetuning/entrypoint.py
index c25949b3..52c7f792 100644
--- a/src/flamingo/jobs/finetuning/entrypoint.py
+++ b/src/flamingo/jobs/finetuning/entrypoint.py
@@ -13,6 +13,7 @@
 from flamingo.integrations.wandb import (
     ArtifactType,
     ArtifactURIScheme,
+    WandbResumeMode,
     default_artifact_name,
     log_directory_reference,
     wandb_init_from_config,
@@ -64,9 +65,7 @@ def training_function(config_data: dict):
     config = FinetuningJobConfig(**config_data)
     if is_tracking_enabled(config):
         with wandb_init_from_config(
-            config.tracking,
-            job_type=FlamingoJobType.FINETUNING,
-            resume="never",
+            config.tracking, resume=WandbResumeMode.NEVER, job_type=FlamingoJobType.FINETUNING
         ):
             load_and_train(config)
     else:
@@ -96,7 +95,7 @@ def run_finetuning(config: FinetuningJobConfig):
     # Register a model artifact if tracking is enabled and Ray saved a checkpoint
     if config.tracking and result.checkpoint:
         # Must resume from the just-completed training run
-        with wandb_init_from_config(config.tracking, resume="must") as run:
+        with wandb_init_from_config(config.tracking, resume=WandbResumeMode.MUST) as run:
             print("Logging artifact for model checkpoint...")
             log_directory_reference(
                 dir_path=f"{result.checkpoint.path}/checkpoint",
diff --git a/src/flamingo/jobs/lm_harness/entrypoint.py b/src/flamingo/jobs/lm_harness/entrypoint.py
index 22389c33..8429ddf6 100644
--- a/src/flamingo/jobs/lm_harness/entrypoint.py
+++ b/src/flamingo/jobs/lm_harness/entrypoint.py
@@ -8,13 +8,18 @@
 from peft import PeftConfig
 
 from flamingo.integrations.huggingface import resolve_loadable_path
-from flamingo.integrations.wandb import ArtifactType, default_artifact_name, wandb_init_from_config
+from flamingo.integrations.wandb import (
+    ArtifactType,
+    WandbResumeMode,
+    default_artifact_name,
+    wandb_init_from_config,
+)
 from flamingo.jobs.lm_harness import LMHarnessJobConfig
 from flamingo.jobs.utils import FlamingoJobType
 
 
 # TODO: Should this also be abstracted to a helper method like log_artifact_from_path?
-def build_evaluation_artifact(run_name: str, results: dict[str, dict[str, Any]]) -> wandb.Artifact:
+def log_evaluation_artifact(run_name: str, results: dict[str, dict[str, Any]]) -> wandb.Artifact:
     print("Building artifact for evaluation results...")
     artifact_name = default_artifact_name(run_name, ArtifactType.EVALUATION)
     artifact = wandb.Artifact(artifact_name, type=ArtifactType.EVALUATION)
@@ -23,12 +28,12 @@ def build_evaluation_artifact(run_name: str, results: dict[str, dict[str, Any]])
         task_data = [(k, v) for k, v in task_results.items() if isinstance(v, int | float)]
         task_table = wandb.Table(data=task_data, columns=["metric", "value"])
         artifact.add(task_table, name=f"task-{task_name}")
-    return artifact
+    return wandb.log_artifact(artifact)
 
 
 def load_harness_model(config: LMHarnessJobConfig) -> HFLM | OpenaiCompletionsLM:
     # Helper method to return lm-harness model wrapper
-    def _loader(model: str | None , tokenizer: str, peft: str | None):
+    def loader(pretrained: str | None, tokenizer: str, peft: str | None):
         """Load model directly from HF if HF path, otherwise from an inference server URL"""
@@ -52,12 +57,13 @@ def _loader(model: str | None , tokenizer: str, peft: str | None):
         )
 
+    # We don't know if the checkpoint is adapter weights or merged model weights
     # Try to load as an adapter and fall back to the checkpoint containing the full model
     load_path, revision = resolve_loadable_path(config.model.load_from)
     try:
         peft_config = PeftConfig.from_pretrained(load_path, revision=revision)
-        return _loader(
+        return loader(
             pretrained=peft_config.base_model_name_or_path,
             tokenizer=peft_config.base_model_name_or_path,
             peft=load_path,
@@ -67,7 +73,7 @@
             f"Unable to load model as adapter: {e}. "
             "This is expected if the checkpoint does not contain adapter weights."
         )
-    return _loader(pretrained=load_path, tokenizer=load_path, peft=None)
+    return loader(pretrained=load_path, tokenizer=load_path, peft=None)
 
 
 def load_and_evaluate(config: LMHarnessJobConfig) -> dict[str, Any]:
@@ -94,12 +100,11 @@ def evaluation_task(config: LMHarnessJobConfig) -> None:
         with wandb_init_from_config(
             config.tracking,
             parameters=config.evaluator, # Log eval settings in W&B run
+            resume=WandbResumeMode.ALLOW,
             job_type=FlamingoJobType.EVALUATION,
-            resume="allow",
         ) as run:
             eval_results = load_and_evaluate(config)
-            artifact = build_evaluation_artifact(run.name, eval_results)
-            run.log_artifact(artifact)
+            log_evaluation_artifact(run.name, eval_results)
     else:
         load_and_evaluate(config)
diff --git a/src/flamingo/types.py b/src/flamingo/types.py
index be8aa294..fee4ef82 100644
--- a/src/flamingo/types.py
+++ b/src/flamingo/types.py
@@ -1,3 +1,5 @@
+import contextlib
+import tempfile
 from pathlib import Path
 from typing import Any
 
@@ -56,3 +58,19 @@ def from_yaml_file(cls, path: Path | str):
 
     def to_yaml_file(self, path: Path | str):
         to_yaml_file(path, self, exclude_none=True)
+
+    @contextlib.contextmanager
+    def to_tempfile(self, *, name: str = "config.yaml", dir: str | Path | None = None):
+        """Enter a context manager with the config written to a temporary YAML file.
+
+        Args:
+            name (str): Name of the config file in the tmp directory.
+                Defaults to "config.yaml"
+            dir (str | Path, optional): Root path of the temporary directory
+
+        Returns:
+            Path to the temporary config file
+        """
+        with tempfile.TemporaryDirectory(dir=dir) as tmpdir:
+            config_path = Path(tmpdir) / name
+            self.to_yaml_file(config_path)
+            yield config_path
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index c341d223..934e5f1c 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -12,39 +12,35 @@
 
 
 @pytest.fixture
-def model_config_with_path():
+def model_config_with_repo_id():
     return AutoModelConfig(load_from="mistral-ai/mistral-7", trust_remote_code=True)
 
 
 @pytest.fixture
 def model_config_with_artifact():
-    artifact = WandbArtifactConfig(name="model")
+    artifact = WandbArtifactConfig(name="model", project="project")
     return AutoModelConfig(load_from=artifact, trust_remote_code=True)
 
 
 @pytest.fixture
-def tokenizer_config_with_path():
+def tokenizer_config_with_repo_id():
     return AutoTokenizerConfig(load_from="mistral-ai/mistral-7", trust_remote_code=True)
 
 
 @pytest.fixture
 def tokenizer_config_with_artifact():
-    artifact = WandbArtifactConfig(name="tokenizer")
+    artifact = WandbArtifactConfig(name="tokenizer", project="project")
     return AutoTokenizerConfig(load_from=artifact, trust_remote_code=True)
 
 
 @pytest.fixture
-def dataset_config_with_path():
-    return TextDatasetConfig(
-        load_from="databricks/dolly15k",
-        text_field="text",
-        split="train",
-    )
+def dataset_config_with_repo_id():
+    return TextDatasetConfig(load_from="databricks/dolly15k", text_field="text", split="train")
 
 
 @pytest.fixture
 def dataset_config_with_artifact():
-    artifact = WandbArtifactConfig(name="dataset")
+    artifact = WandbArtifactConfig(name="dataset", project="project")
     return TextDatasetConfig(load_from=artifact, split="train")
 
 
@@ -66,4 +62,4 @@ def adapter_config():
 
 @pytest.fixture
 def wandb_run_config():
-    return WandbRunConfig(name="run", run_id="12345", project="research", entity="mozilla-ai")
+    return WandbRunConfig(name="run", run_id="12345", project="research", entity="mzai")
diff --git a/tests/unit/integrations/huggingface/test_loading_utils.py b/tests/unit/integrations/huggingface/test_loading_utils.py
index e47e87b6..8a3e3be9 100644
--- a/tests/unit/integrations/huggingface/test_loading_utils.py
+++ b/tests/unit/integrations/huggingface/test_loading_utils.py
@@ -19,7 +19,7 @@ def test_dataset_loading(resources_dir):
         "flamingo.integrations.huggingface.loading_utils.get_artifact_filesystem_path",
         return_value=xyz_dataset_path,
     ):
-        artifact = WandbArtifactConfig(name="xyz-dataset")
+        artifact = WandbArtifactConfig(name="xyz-dataset", project="project")
         dataset_config = DatasetConfig(load_from=artifact, test_size=0.2, seed=0)
 
         dataset = load_dataset_from_config(dataset_config)
diff --git a/tests/unit/integrations/wandb/test_artifact_config.py b/tests/unit/integrations/wandb/test_artifact_config.py
index 0cdfdd73..adc5f923 100644
--- a/tests/unit/integrations/wandb/test_artifact_config.py
+++ b/tests/unit/integrations/wandb/test_artifact_config.py
@@ -19,3 +19,24 @@ def test_serde_round_trip(wandb_artifact_config):
 
 def test_wandb_path(wandb_artifact_config):
     assert wandb_artifact_config.wandb_path() == "team/research/artifact-name:latest"
+
+
+def test_from_wandb_path():
+    valid_path_with_entity = "entity/project/name:latest"
+    config_with_entity = WandbArtifactConfig.from_wandb_path(valid_path_with_entity)
+    assert config_with_entity.name == "name"
+    assert config_with_entity.project == "project"
+    assert config_with_entity.version == "latest"
+    assert config_with_entity.entity == "entity"
+
+    valid_path_without_entity = "project/name:latest"
+    config_without_entity = WandbArtifactConfig.from_wandb_path(valid_path_without_entity)
+    assert config_without_entity.name == "name"
+    assert config_without_entity.project == "project"
+    assert config_without_entity.version == "latest"
+    assert config_without_entity.entity is None
+
+    with pytest.raises(ValueError):
+        WandbArtifactConfig.from_wandb_path("entity/project/name")  # No version
+    with pytest.raises(ValueError):
+        WandbArtifactConfig.from_wandb_path("entity/project/name/version")  # Bad delimiter
diff --git a/tests/unit/jobs/test_finetuning_config.py b/tests/unit/jobs/test_finetuning_config.py
index 54b9c515..0e776cf6 100644
--- a/tests/unit/jobs/test_finetuning_config.py
+++ b/tests/unit/jobs/test_finetuning_config.py
@@ -38,15 +38,14 @@ def test_serde_round_trip(finetuning_job_config):
     assert FinetuningJobConfig.parse_raw(finetuning_job_config.json()) == finetuning_job_config
 
 
-def test_parse_yaml_file(finetuning_job_config, tmp_path_factory):
-    config_path = tmp_path_factory.mktemp("flamingo_tests") / "finetuning_config.yaml"
-    finetuning_job_config.to_yaml_file(config_path)
-    assert finetuning_job_config == FinetuningJobConfig.from_yaml_file(config_path)
+def test_parse_yaml_file(finetuning_job_config):
+    with finetuning_job_config.to_tempfile() as config_path:
+        assert finetuning_job_config == FinetuningJobConfig.from_yaml_file(config_path)
 
 
 def test_load_example_config(examples_dir):
     """Load the example configs to make sure they stay up to date."""
-    config_file = examples_dir / "configs" / "finetuning_config.yaml"
+    config_file = examples_dir / "configs" / "finetuning.yaml"
     config = FinetuningJobConfig.from_yaml_file(config_file)
     assert FinetuningJobConfig.parse_raw(config.json()) == config
diff --git a/tests/unit/jobs/test_lm_harness_config.py b/tests/unit/jobs/test_lm_harness_config.py
index fdbb2fa0..a2a07bdc 100644
--- a/tests/unit/jobs/test_lm_harness_config.py
+++ b/tests/unit/jobs/test_lm_harness_config.py
@@ -47,15 +47,14 @@ def test_serde_round_trip(lm_harness_job_config):
     assert LMHarnessJobConfig.parse_raw(lm_harness_job_config.json()) == lm_harness_job_config
 
 
-def test_parse_yaml_file(lm_harness_job_config, tmp_path_factory):
-    config_path = tmp_path_factory.mktemp("flamingo_tests") / "lm_harness_config.yaml"
-    lm_harness_job_config.to_yaml_file(config_path)
-    assert lm_harness_job_config == LMHarnessJobConfig.from_yaml_file(config_path)
+def test_parse_yaml_file(lm_harness_job_config):
+    with lm_harness_job_config.to_tempfile() as config_path:
+        assert lm_harness_job_config == LMHarnessJobConfig.from_yaml_file(config_path)
 
 
 def test_load_example_config(examples_dir):
     """Load the example configs to make sure they stay up to date."""
-    config_file = examples_dir / "configs" / "lm_harness_config.yaml"
+    config_file = examples_dir / "configs" / "lm_harness.yaml"
     config = LMHarnessJobConfig.from_yaml_file(config_file)
     assert LMHarnessJobConfig.parse_raw(config.json()) == config
diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py
index 1f03ad74..c4022384 100644
--- a/tests/unit/test_types.py
+++ b/tests/unit/test_types.py
@@ -1,6 +1,7 @@
 import pytest
 import torch
 
+from flamingo.jobs.simple import SimpleJobConfig
 from flamingo.types import TorchDtypeString
 
 
@@ -17,3 +18,11 @@ def test_torch_dtype_validation():
         TorchDtypeString.validate(5)
     with pytest.raises(ValueError):
         TorchDtypeString.validate("dogs")
+
+
+def test_config_as_tempfile():
+    config = SimpleJobConfig(magic_number=42)
+    config_name = "my-special-config.yaml"
+    with config.to_tempfile(name=config_name) as path:
+        assert path.name == config_name
+        assert SimpleJobConfig.from_yaml_file(path) == config
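
A quick sanity check of the new path helpers above — a minimal sketch, assuming the `flamingo` package from this branch is installed and that `WandbArtifactConfig` is importable from `flamingo.integrations.wandb` (the entity/project/artifact names below are made up for illustration):

```python
from flamingo.integrations.wandb import WandbArtifactConfig

# A fully qualified W&B path parses into all four fields...
config = WandbArtifactConfig.from_wandb_path("mzai/research/my-model:v4")
assert (config.entity, config.project, config.name) == ("mzai", "research", "my-model")
assert config.version == "v4"
# ...and wandb_path() reconstructs the same string
assert config.wandb_path() == "mzai/research/my-model:v4"

# Without the optional entity prefix, entity stays None and is omitted on output
config = WandbArtifactConfig.from_wandb_path("research/my-model:v4")
assert config.entity is None
assert config.wandb_path() == "research/my-model:v4"
```

As the new unit test shows, a path without an explicit `:version` suffix is rejected with a `ValueError` rather than silently falling back to `"latest"`.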
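Likewise, a minimal sketch of the `to_tempfile` helper added in `flamingo/types.py`, using the `SimpleJobConfig` exercised in the tests above; the key behavior is that the backing temporary directory (and hence the YAML file) is removed when the context exits:

```python
from flamingo.jobs.simple import SimpleJobConfig

config = SimpleJobConfig(magic_number=42)

with config.to_tempfile(name="sweep-config.yaml") as config_path:
    # Inside the context, the config has been serialized to disk...
    assert config_path.exists()
    # ...and it round-trips back to an equal config object
    assert SimpleJobConfig.from_yaml_file(config_path) == config

# On exit, the TemporaryDirectory and the file inside it are cleaned up
assert not config_path.exists()
```

This is what makes the "Iterative Submission" notebook cells work: `config_path.parent` serves as a throwaway Ray working directory that is packaged up at submission time, before the context manager deletes it.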