diff --git a/src/flamingo/cli/run.py b/src/flamingo/cli/run.py index 175f7cc7..59e59910 100644 --- a/src/flamingo/cli/run.py +++ b/src/flamingo/cli/run.py @@ -23,7 +23,8 @@ def run_finetuning(config: str) -> None: from flamingo.jobs.finetuning import FinetuningJobConfig, run_finetuning config = FinetuningJobConfig.from_yaml_file(config) - run_finetuning(config) + artifact_loader = WandbArtifactLoader() + run_finetuning(config, artifact_loader) @group.command("lm-harness", help="Run the lm-harness evaluation job.") @@ -31,6 +32,6 @@ def run_finetuning(config: str) -> None: def run_lm_harness(config: str) -> None: from flamingo.jobs.lm_harness import LMHarnessJobConfig, run_lm_harness - artifact_loader = WandbArtifactLoader() config = LMHarnessJobConfig.from_yaml_file(config) + artifact_loader = WandbArtifactLoader() run_lm_harness(config, artifact_loader) diff --git a/src/flamingo/integrations/huggingface/asset_loader.py b/src/flamingo/integrations/huggingface/asset_loader.py index 2cdf6dce..5128a44d 100644 --- a/src/flamingo/integrations/huggingface/asset_loader.py +++ b/src/flamingo/integrations/huggingface/asset_loader.py @@ -106,7 +106,7 @@ def load_pretrained_model( An exception is raised if the HuggingFace repo does not contain a `config.json` file. - TODO(RD2024-87): Handle PEFT adapter loading directly in this method + TODO(RD2024-87): This fails if the checkpoint only contains a PEFT adapter config """ device_map, bnb_config = None, None if quantization is not None: diff --git a/src/flamingo/jobs/finetuning/entrypoint.py b/src/flamingo/jobs/finetuning/entrypoint.py index debed613..56cedb37 100644 --- a/src/flamingo/jobs/finetuning/entrypoint.py +++ b/src/flamingo/jobs/finetuning/entrypoint.py @@ -1,3 +1,6 @@ +from typing import Any + +import ray from ray import train from ray.train import CheckpointConfig, RunConfig, ScalingConfig from ray.train.huggingface.transformers import RayTrainReportCallback, prepare_trainer @@ -7,8 +10,8 @@ from flamingo.integrations.huggingface import HuggingFaceAssetLoader from flamingo.integrations.wandb import ( + ArtifactLoader, ArtifactType, - WandbArtifactLoader, WandbResumeMode, build_directory_artifact, default_artifact_name, @@ -24,10 +27,10 @@ def is_tracking_enabled(config: FinetuningJobConfig): return config.tracking is not None and train.get_context().get_world_rank() == 0 -def load_and_train(config: FinetuningJobConfig): - # Load the input artifacts, potentially linking them to the active W&B run - # TODO(RD2024-89): Inject this into Ray workers somehow - hf_loader = HuggingFaceAssetLoader(WandbArtifactLoader()) +def load_and_train(config: FinetuningJobConfig, artifact_loader: ArtifactLoader): + # Load the HF assets from configurations + # Internally, artifact lineages are declared for the active training run + hf_loader = HuggingFaceAssetLoader(artifact_loader) model = hf_loader.load_pretrained_model(config.model, config.quantization) tokenizer = hf_loader.load_pretrained_tokenizer(config.tokenizer) datasets = hf_loader.load_and_split_dataset(config.dataset) @@ -59,18 +62,25 @@ def load_and_train(config: FinetuningJobConfig): trainer.train() -def training_function(config_data: dict): - config = FinetuningJobConfig(**config_data) - if is_tracking_enabled(config): - with wandb_init_from_config( - config.tracking, resume=WandbResumeMode.NEVER, job_type=FlamingoJobType.FINETUNING - ): - load_and_train(config) - else: - load_and_train(config) +def run_finetuning(config: FinetuningJobConfig, artifact_loader: ArtifactLoader): + # Place the artifact loader in Ray object store + artifact_loader_ref = ray.put(artifact_loader) + # Define training function internally to capture the artifact loader ref as a closure + # Reference: https://docs.ray.io/en/latest/ray-core/objects.html#closure-capture-of-objects + def training_function(config_data: dict[str, Any]): + artifact_loader = ray.get(artifact_loader_ref) + config = FinetuningJobConfig(**config_data) + if is_tracking_enabled(config): + with wandb_init_from_config( + config.tracking, + resume=WandbResumeMode.NEVER, + job_type=FlamingoJobType.FINETUNING, + ): + load_and_train(config, artifact_loader) + else: + load_and_train(config, artifact_loader) -def run_finetuning(config: FinetuningJobConfig): # Construct Ray train configurations from input config scaling_config = ScalingConfig( use_gpu=config.ray.use_gpu, @@ -101,5 +111,4 @@ def run_finetuning(config: FinetuningJobConfig): reference=True, ) print("Logging artifact for model checkpoint...") - artifact_loader = WandbArtifactLoader() artifact_loader.log_artifact(model_artifact) diff --git a/tests/conftest.py b/tests/conftest.py index 41e8f770..84102698 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,7 +30,18 @@ def xyz_dataset_artifact(resources_dir): @pytest.fixture -def gpt2_model_artifact(resources_dir): +def text_dataset_artifact(resources_dir): + dataset_path = resources_dir / "datasets" / "tiny_shakespeare" + return build_directory_artifact( + artifact_name="tiny-shakespeare-dataset", + artifact_type=ArtifactType.DATASET, + dir_path=dataset_path, + reference=True, + ) + + +@pytest.fixture +def llm_model_artifact(resources_dir): model_path = resources_dir / "models" / "fake_gpt2" return build_directory_artifact( artifact_name="fake-gpt2-model", diff --git a/tests/integration/README.md b/tests/integration/README.md new file mode 100644 index 00000000..308b41c6 --- /dev/null +++ b/tests/integration/README.md @@ -0,0 +1,29 @@ +# Integration Tests + +This folder houses tests that bring together the local package code +and external job dependencies. +Currently, the main external dependencies of the package are a Ray cluster +and tracking services (e.g., W&B). + +## Ray compute + +A Ray cluster is provided for testing as a `pytest` fixture (see `conftest.py`). +Currently, this is a tiny cluster with a fixed number of CPUs that runs on +the local test runner machine. +The [Ray documentation](https://docs.ray.io/en/latest/ray-core/examples/testing-tips.html) +provides helpful guides on how to set these clusters up for testing. + +## Tracking services + +Weights & Biases is currently used as the main experiment tracking service. +For testing, W&B can be disabled by setting the environment variable `WANDB_MODE="offline"`, +which is done automatically in a fixture for integration tests. +This causes the [W&B SDK to act like a no-op](https://docs.wandb.ai/guides/technical-faq/general#can-i-disable-wandb-when-testing-my-code) +so the actual service is not contacted during testing. + +However, when W&B is disabled, the loading and logging of artifacts is also disabled +which breaks the input/output data flow for the job entrypoints. +To work around this during testing, we use the `FakeArtifactLoader` class +that stores artifacts in in-memory storage to avoid calls to the W&B SDK. +This allows the full job entrypoints to be executed +and the output artifacts produced by the jobs to be verified as test assertions. diff --git a/tests/integration/test_finetuning.py b/tests/integration/test_finetuning.py new file mode 100644 index 00000000..c71c0cb7 --- /dev/null +++ b/tests/integration/test_finetuning.py @@ -0,0 +1,50 @@ +import pytest + +from flamingo.integrations.huggingface import AutoModelConfig, TextDatasetConfig, TrainerConfig +from flamingo.integrations.wandb import ArtifactType, WandbArtifactConfig, WandbRunConfig +from flamingo.jobs.finetuning import FinetuningJobConfig, FinetuningRayConfig, run_finetuning +from tests.test_utils import FakeArtifactLoader + + +@pytest.fixture +def job_config(llm_model_artifact, text_dataset_artifact): + model_config = AutoModelConfig( + load_from=WandbArtifactConfig(name=llm_model_artifact.name, project="test") + ) + dataset_config = TextDatasetConfig( + load_from=WandbArtifactConfig(name=text_dataset_artifact.name, project="test"), + text_field="text", + split="train", + ) + trainer_config = TrainerConfig( + max_seq_length=8, + num_train_epochs=1, + save_steps=1, + save_strategy="epoch", + ) + tracking_config = WandbRunConfig(name="test-finetuning-job") + ray_config = FinetuningRayConfig(num_workers=1, use_gpu=False) + return FinetuningJobConfig( + model=model_config, + dataset=dataset_config, + trainer=trainer_config, + tracking=tracking_config, + ray=ray_config, + ) + + +def test_finetuning_job(llm_model_artifact, text_dataset_artifact, job_config): + # Preload input artifact in loader + artifact_loader = FakeArtifactLoader() + artifact_loader.log_artifact(llm_model_artifact) + artifact_loader.log_artifact(text_dataset_artifact) + + # Run test job + run_finetuning(job_config, artifact_loader) + + # Two input artifacts, and one output model artifact produced + artifacts = artifact_loader.get_artifacts() + num_dataset_artifacts = len([a for a in artifacts if a.type == ArtifactType.DATASET]) + num_model_artifacts = len([a for a in artifacts if a.type == ArtifactType.MODEL]) + assert num_dataset_artifacts == 1 + assert num_model_artifacts == 2 diff --git a/tests/integration/test_lm_harness.py b/tests/integration/test_lm_harness.py index f3d4fab5..b1e4814f 100644 --- a/tests/integration/test_lm_harness.py +++ b/tests/integration/test_lm_harness.py @@ -7,9 +7,10 @@ @pytest.fixture -def job_config(gpt2_model_artifact): - artifact_config = WandbArtifactConfig(name=gpt2_model_artifact.name, project="test") - model_config = AutoModelConfig(load_from=artifact_config) +def job_config(llm_model_artifact): + model_config = AutoModelConfig( + load_from=WandbArtifactConfig(name=llm_model_artifact.name, project="test") + ) tracking_config = WandbRunConfig(name="test-lm-harness-job") evaluator_config = LMHarnessEvaluatorConfig(tasks=["hellaswag"], limit=5) @@ -20,10 +21,10 @@ def job_config(gpt2_model_artifact): ) -def test_lm_harness_job_with_tracking(gpt2_model_artifact, job_config): +def test_lm_harness_job_with_tracking(llm_model_artifact, job_config): # Preload input artifact in loader artifact_loader = FakeArtifactLoader() - artifact_loader.log_artifact(gpt2_model_artifact) + artifact_loader.log_artifact(llm_model_artifact) # Run test job run_lm_harness(job_config, artifact_loader) @@ -32,13 +33,13 @@ def test_lm_harness_job_with_tracking(gpt2_model_artifact, job_config): assert artifact_loader.num_artifacts() == 2 -def test_lm_harness_job_no_tracking(gpt2_model_artifact, job_config): +def test_lm_harness_job_no_tracking(llm_model_artifact, job_config): # Disable tracking on job config job_config.tracking = None # Preload input artifact in loader artifact_loader = FakeArtifactLoader() - artifact_loader.log_artifact(gpt2_model_artifact) + artifact_loader.log_artifact(llm_model_artifact) # Run test job run_lm_harness(job_config, artifact_loader) diff --git a/tests/resources/datasets/tiny_shakespeare/cache-1f365c79744bfcef.arrow b/tests/resources/datasets/tiny_shakespeare/cache-1f365c79744bfcef.arrow new file mode 100644 index 00000000..a238f246 Binary files /dev/null and b/tests/resources/datasets/tiny_shakespeare/cache-1f365c79744bfcef.arrow differ diff --git a/tests/resources/datasets/tiny_shakespeare/cache-295b735eadbd0e64.arrow b/tests/resources/datasets/tiny_shakespeare/cache-295b735eadbd0e64.arrow new file mode 100644 index 00000000..c8420995 Binary files /dev/null and b/tests/resources/datasets/tiny_shakespeare/cache-295b735eadbd0e64.arrow differ diff --git a/tests/resources/datasets/tiny_shakespeare/cache-be708fc06f7fa557.arrow b/tests/resources/datasets/tiny_shakespeare/cache-be708fc06f7fa557.arrow new file mode 100644 index 00000000..4790149c Binary files /dev/null and b/tests/resources/datasets/tiny_shakespeare/cache-be708fc06f7fa557.arrow differ diff --git a/tests/resources/datasets/tiny_shakespeare/create_tiny_shakespeare.py b/tests/resources/datasets/tiny_shakespeare/create_tiny_shakespeare.py new file mode 100644 index 00000000..8cd127b9 --- /dev/null +++ b/tests/resources/datasets/tiny_shakespeare/create_tiny_shakespeare.py @@ -0,0 +1,10 @@ +"""Script to generate the `tiny_shakespeare` dataset files.""" +from pathlib import Path + +from datasets import load_dataset + +if __name__ == "__main__": + repo_id = "Trelis/tiny-shakespeare" + dataset = load_dataset(repo_id, split="train[:10]") + dataset = dataset.rename_column("Text", "text") + dataset.save_to_disk(dataset_path=Path(__file__).parent) diff --git a/tests/resources/datasets/tiny_shakespeare/data-00000-of-00001.arrow b/tests/resources/datasets/tiny_shakespeare/data-00000-of-00001.arrow new file mode 100644 index 00000000..2fe65a25 Binary files /dev/null and b/tests/resources/datasets/tiny_shakespeare/data-00000-of-00001.arrow differ diff --git a/tests/resources/datasets/tiny_shakespeare/dataset_info.json b/tests/resources/datasets/tiny_shakespeare/dataset_info.json new file mode 100644 index 00000000..10c3f93d --- /dev/null +++ b/tests/resources/datasets/tiny_shakespeare/dataset_info.json @@ -0,0 +1,48 @@ +{ + "builder_name": "csv", + "citation": "", + "config_name": "default", + "dataset_name": "tiny-shakespeare", + "dataset_size": 1343458, + "description": "", + "download_checksums": { + "hf://datasets/Trelis/tiny-shakespeare@2225a5674b252d623cfb151ef099d179629c2f49/train.csv": { + "num_bytes": 1224246, + "checksum": null + }, + "hf://datasets/Trelis/tiny-shakespeare@2225a5674b252d623cfb151ef099d179629c2f49/test.csv": { + "num_bytes": 119222, + "checksum": null + } + }, + "download_size": 1343468, + "features": { + "text": { + "dtype": "string", + "_type": "Value" + } + }, + "homepage": "", + "license": "", + "size_in_bytes": 2686926, + "splits": { + "train": { + "name": "train", + "num_bytes": 1224242, + "num_examples": 472, + "dataset_name": "tiny-shakespeare" + }, + "test": { + "name": "test", + "num_bytes": 119216, + "num_examples": 49, + "dataset_name": "tiny-shakespeare" + } + }, + "version": { + "version_str": "0.0.0", + "major": 0, + "minor": 0, + "patch": 0 + } +} \ No newline at end of file diff --git a/tests/resources/datasets/tiny_shakespeare/state.json b/tests/resources/datasets/tiny_shakespeare/state.json new file mode 100644 index 00000000..ad98d9db --- /dev/null +++ b/tests/resources/datasets/tiny_shakespeare/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "72f06f97cc145672", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": "train[:10]" +} \ No newline at end of file diff --git a/tests/test_utils.py b/tests/test_utils.py index f0f08f1a..7cc7a5b5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -20,11 +20,14 @@ class FakeArtifactLoader: """ def __init__(self): - self._storage = dict() + self._storage: dict[str, wandb.Artifact] = dict() def num_artifacts(self) -> int: return len(self._storage) + def get_artifacts(self) -> list[wandb.Artifact]: + return list(self._storage.values()) + def use_artifact(self, config: WandbArtifactConfig) -> wandb.Artifact: return self._storage[config.name] diff --git a/tests/unit/integrations/huggingface/test_loading_utils.py b/tests/unit/integrations/huggingface/test_loading_utils.py index 4ebb73fe..184ef40b 100644 --- a/tests/unit/integrations/huggingface/test_loading_utils.py +++ b/tests/unit/integrations/huggingface/test_loading_utils.py @@ -23,13 +23,13 @@ def test_dataset_loading(xyz_dataset_artifact): assert "train" in datasets and "test" in datasets -def test_model_loading(gpt2_model_artifact): +def test_model_loading(llm_model_artifact): # Preload fake artifact for testing artifact_loader = FakeArtifactLoader() - artifact_loader.log_artifact(gpt2_model_artifact) + artifact_loader.log_artifact(llm_model_artifact) hf_loader = HuggingFaceAssetLoader(artifact_loader) - artifact_config = WandbArtifactConfig(name=gpt2_model_artifact.name, project="project") + artifact_config = WandbArtifactConfig(name=llm_model_artifact.name, project="project") model_config = AutoModelConfig(load_from=artifact_config, torch_dtype=torch.bfloat16) hf_config = hf_loader.load_pretrained_config(model_config)