This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

RD2024-89: Use artifact loader in finetuning #34

Merged
Changes from 4 commits
5 changes: 3 additions & 2 deletions src/flamingo/cli/run.py
@@ -23,14 +23,15 @@ def run_finetuning(config: str) -> None:
from flamingo.jobs.finetuning import FinetuningJobConfig, run_finetuning

config = FinetuningJobConfig.from_yaml_file(config)
run_finetuning(config)
artifact_loader = WandbArtifactLoader()
run_finetuning(config, artifact_loader)


@group.command("lm-harness", help="Run the lm-harness evaluation job.")
@click.option("--config", type=str)
def run_lm_harness(config: str) -> None:
from flamingo.jobs.lm_harness import LMHarnessJobConfig, run_lm_harness

artifact_loader = WandbArtifactLoader()
config = LMHarnessJobConfig.from_yaml_file(config)
artifact_loader = WandbArtifactLoader()
run_lm_harness(config, artifact_loader)
2 changes: 1 addition & 1 deletion src/flamingo/integrations/huggingface/asset_loader.py
@@ -106,7 +106,7 @@ def load_pretrained_model(

An exception is raised if the HuggingFace repo does not contain a `config.json` file.

TODO(RD2024-87): Handle PEFT adapter loading directly in this method
TODO(RD2024-87): This fails if the checkpoint only contains a PEFT adapter config
"""
device_map, bnb_config = None, None
if quantization is not None:
41 changes: 25 additions & 16 deletions src/flamingo/jobs/finetuning/entrypoint.py
@@ -1,3 +1,6 @@
from typing import Any

import ray
from ray import train
from ray.train import CheckpointConfig, RunConfig, ScalingConfig
from ray.train.huggingface.transformers import RayTrainReportCallback, prepare_trainer
@@ -7,8 +10,8 @@

from flamingo.integrations.huggingface import HuggingFaceAssetLoader
from flamingo.integrations.wandb import (
ArtifactLoader,
ArtifactType,
WandbArtifactLoader,
WandbResumeMode,
build_directory_artifact,
default_artifact_name,
@@ -24,10 +27,10 @@ def is_tracking_enabled(config: FinetuningJobConfig):
return config.tracking is not None and train.get_context().get_world_rank() == 0


def load_and_train(config: FinetuningJobConfig):
# Load the input artifacts, potentially linking them to the active W&B run
# TODO(RD2024-89): Inject this into Ray workers somehow
hf_loader = HuggingFaceAssetLoader(WandbArtifactLoader())
def load_and_train(config: FinetuningJobConfig, artifact_loader: ArtifactLoader):
# Load the HF assets from configurations
# Internally, artifact lineages are declared for the active training run
hf_loader = HuggingFaceAssetLoader(artifact_loader)
model = hf_loader.load_pretrained_model(config.model, config.quantization)
tokenizer = hf_loader.load_pretrained_tokenizer(config.tokenizer)
datasets = hf_loader.load_and_split_dataset(config.dataset)
@@ -59,18 +62,25 @@ def load_and_train(config: FinetuningJobConfig):
trainer.train()


def training_function(config_data: dict):
config = FinetuningJobConfig(**config_data)
if is_tracking_enabled(config):
with wandb_init_from_config(
config.tracking, resume=WandbResumeMode.NEVER, job_type=FlamingoJobType.FINETUNING
):
load_and_train(config)
else:
load_and_train(config)
def run_finetuning(config: FinetuningJobConfig, artifact_loader: ArtifactLoader):
# Place the artifact loader in Ray object store
artifact_loader_ref = ray.put(artifact_loader)
Contributor Author:
This ended up working out pretty easily. I can place the ArtifactLoader instance in the Ray object store and then retrieve it from the object ref within each worker.

The caveat is that the ArtifactLoader instance needs to be serializable via cloudpickle, but I don't think it's an antipattern to be doing this according to the Ray docs.

Member:
does this work on a live cluster?

Contributor Author:
Yes it does

Contributor Author (@sfriedowitz, Feb 9, 2024):
I don't think the job logs give any info about the object ref retrieval, but here is an example run off this branch: http://10.145.55.132:8265/#/jobs/04000000

Member:
is the ref something we might need? if so, would be helpful to log

Contributor Author:
The object ref is essentially a UID string. It's not really useful for the logs, so I don't know if it's helpful to add an additional statement there.


# Define training function internally to capture the artifact loader ref as a closure
# Reference: https://docs.ray.io/en/latest/ray-core/objects.html#closure-capture-of-objects
def training_function(config_data: dict[str, Any]):
Member:
is there a way we could unnest these nested functions?

Contributor Author:
Nesting the functions lets the training function act as a closure and capture artifact_loader_ref from the enclosing scope; that is how the loader is passed into the training function (see the sketch after this diff).

artifact_loader = ray.get(artifact_loader_ref)
config = FinetuningJobConfig(**config_data)
if is_tracking_enabled(config):
with wandb_init_from_config(
config.tracking,
resume=WandbResumeMode.NEVER,
job_type=FlamingoJobType.FINETUNING,
):
load_and_train(config, artifact_loader)
else:
load_and_train(config, artifact_loader)

def run_finetuning(config: FinetuningJobConfig):
# Construct Ray train configurations from input config
scaling_config = ScalingConfig(
use_gpu=config.ray.use_gpu,
@@ -101,5 +111,4 @@ def run_finetuning(config: FinetuningJobConfig):
reference=True,
)
print("Logging artifact for model checkpoint...")
artifact_loader = WandbArtifactLoader()
artifact_loader.log_artifact(model_artifact)
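
A minimal, self-contained sketch of the closure-capture pattern discussed in the review thread above (not part of the diff; class and config names are illustrative, and it assumes ray[train] and torch are installed — the real job wires in Flamingo's ArtifactLoader and FinetuningJobConfig instead):

import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


class DummyArtifactLoader:
    """Illustrative stand-in for ArtifactLoader; must be serializable via cloudpickle."""

    def log_artifact(self, name: str) -> None:
        print(f"logged artifact: {name}")


def run_job(artifact_loader: DummyArtifactLoader) -> None:
    # Place the loader instance in the Ray object store once. ray.put returns a
    # small ObjectRef (essentially a UID) that is cheap to ship to each worker.
    artifact_loader_ref = ray.put(artifact_loader)

    # Nested on purpose: the function closes over artifact_loader_ref, so every
    # Ray Train worker can materialize the loader with ray.get inside the worker.
    def training_function(config_data: dict) -> None:
        loader = ray.get(artifact_loader_ref)
        loader.log_artifact(config_data["artifact_name"])

    trainer = TorchTrainer(
        train_loop_per_worker=training_function,
        train_loop_config={"artifact_name": "example"},
        scaling_config=ScalingConfig(num_workers=1, use_gpu=False),
    )
    trainer.fit()


if __name__ == "__main__":
    run_job(DummyArtifactLoader())

Because only the ObjectRef is captured by the closure, each worker fetches and deserializes the loader on demand via ray.get rather than receiving it as part of the training config.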
13 changes: 12 additions & 1 deletion tests/conftest.py
@@ -30,7 +30,18 @@ def xyz_dataset_artifact(resources_dir):


@pytest.fixture
def gpt2_model_artifact(resources_dir):
def text_dataset_artifact(resources_dir):
dataset_path = resources_dir / "datasets" / "tiny_shakespeare"
return build_directory_artifact(
artifact_name="tiny-shakespeare-dataset",
artifact_type=ArtifactType.DATASET,
dir_path=dataset_path,
reference=True,
)


@pytest.fixture
def llm_model_artifact(resources_dir):
model_path = resources_dir / "models" / "fake_gpt2"
return build_directory_artifact(
artifact_name="fake-gpt2-model",
46 changes: 46 additions & 0 deletions tests/integration/test_finetuning.py
@@ -0,0 +1,46 @@
import pytest

from flamingo.integrations.huggingface import AutoModelConfig, TextDatasetConfig, TrainerConfig
from flamingo.integrations.wandb import WandbArtifactConfig, WandbRunConfig
from flamingo.jobs.finetuning import FinetuningJobConfig, FinetuningRayConfig, run_finetuning
from tests.test_utils import FakeArtifactLoader


@pytest.fixture
def job_config(llm_model_artifact, text_dataset_artifact):
model_config = AutoModelConfig(
load_from=WandbArtifactConfig(name=llm_model_artifact.name, project="test")
)
dataset_config = TextDatasetConfig(
load_from=WandbArtifactConfig(name=text_dataset_artifact.name, project="test"),
text_field="text",
split="train",
)
trainer_config = TrainerConfig(
max_seq_length=8,
num_train_epochs=1,
save_steps=1,
save_strategy="epoch",
)
tracking_config = WandbRunConfig(name="test-finetuning-job")
ray_config = FinetuningRayConfig(num_workers=1, use_gpu=False)
return FinetuningJobConfig(
model=model_config,
dataset=dataset_config,
trainer=trainer_config,
tracking=tracking_config,
ray=ray_config,
)


def test_finetuning_job(llm_model_artifact, text_dataset_artifact, job_config):
# Preload input artifact in loader
artifact_loader = FakeArtifactLoader()
artifact_loader.log_artifact(llm_model_artifact)
artifact_loader.log_artifact(text_dataset_artifact)

# Run test job
run_finetuning(job_config, artifact_loader)

# Two input artifacts, and one output model artifact produced
assert artifact_loader.num_artifacts() == 3
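
FakeArtifactLoader lives in tests/test_utils and is not shown in this diff. A hypothetical sketch of what such an in-memory loader could look like, inferred only from the log_artifact and num_artifacts calls used in the tests (the retrieval method name is assumed):

class InMemoryArtifactLoader:
    """Hypothetical in-memory loader, standing in for tests.test_utils.FakeArtifactLoader."""

    def __init__(self) -> None:
        self._artifacts: dict[str, object] = {}

    def log_artifact(self, artifact) -> None:
        # Store artifacts by name so inputs and outputs can be counted in assertions.
        self._artifacts[artifact.name] = artifact

    def use_artifact(self, name: str):  # assumed retrieval hook on the ArtifactLoader interface
        return self._artifacts[name]

    def num_artifacts(self) -> int:
        return len(self._artifacts)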
15 changes: 8 additions & 7 deletions tests/integration/test_lm_harness.py
@@ -7,9 +7,10 @@


@pytest.fixture
def job_config(gpt2_model_artifact):
artifact_config = WandbArtifactConfig(name=gpt2_model_artifact.name, project="test")
model_config = AutoModelConfig(load_from=artifact_config)
def job_config(llm_model_artifact):
model_config = AutoModelConfig(
load_from=WandbArtifactConfig(name=llm_model_artifact.name, project="test")
)

tracking_config = WandbRunConfig(name="test-lm-harness-job")
evaluator_config = LMHarnessEvaluatorConfig(tasks=["hellaswag"], limit=5)
@@ -20,10 +21,10 @@ def job_config(gpt2_model_artifact):
)


def test_lm_harness_job_with_tracking(gpt2_model_artifact, job_config):
def test_lm_harness_job_with_tracking(llm_model_artifact, job_config):
Member:
might be helpful to type annotate

# Preload input artifact in loader
artifact_loader = FakeArtifactLoader()
artifact_loader.log_artifact(gpt2_model_artifact)
artifact_loader.log_artifact(llm_model_artifact)

# Run test job
run_lm_harness(job_config, artifact_loader)
@@ -32,13 +33,13 @@ def test_lm_harness_job_with_tracking(gpt2_model_artifact, job_config):
assert artifact_loader.num_artifacts() == 2


def test_lm_harness_job_no_tracking(gpt2_model_artifact, job_config):
def test_lm_harness_job_no_tracking(llm_model_artifact, job_config):
# Disable tracking on job config
job_config.tracking = None

# Preload input artifact in loader
artifact_loader = FakeArtifactLoader()
artifact_loader.log_artifact(gpt2_model_artifact)
artifact_loader.log_artifact(llm_model_artifact)

# Run test job
run_lm_harness(job_config, artifact_loader)
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,10 @@
"""Script to generate the `tiny_shakespeare` dataset files."""
from pathlib import Path

from datasets import load_dataset

if __name__ == "__main__":
repo_id = "Trelis/tiny-shakespeare"
dataset = load_dataset(repo_id, split="train[:10]")
dataset = dataset.rename_column("Text", "text")
dataset.save_to_disk(dataset_path=Path(__file__).parent)
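
For reference, a small sketch of reading the saved split back from this directory with the datasets library (the path assumes the repository layout shown above):

from datasets import load_from_disk

# Load the split that the generation script saved with save_to_disk.
dataset = load_from_disk("tests/resources/datasets/tiny_shakespeare")
print(dataset.column_names)      # ["text"]
print(dataset[0]["text"][:80])   # first characters of the first example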
Binary file not shown.
48 changes: 48 additions & 0 deletions tests/resources/datasets/tiny_shakespeare/dataset_info.json
@@ -0,0 +1,48 @@
{
"builder_name": "csv",
"citation": "",
"config_name": "default",
"dataset_name": "tiny-shakespeare",
"dataset_size": 1343458,
"description": "",
"download_checksums": {
"hf://datasets/Trelis/tiny-shakespeare@2225a5674b252d623cfb151ef099d179629c2f49/train.csv": {
"num_bytes": 1224246,
"checksum": null
},
"hf://datasets/Trelis/tiny-shakespeare@2225a5674b252d623cfb151ef099d179629c2f49/test.csv": {
"num_bytes": 119222,
"checksum": null
}
},
"download_size": 1343468,
"features": {
"text": {
"dtype": "string",
"_type": "Value"
}
},
"homepage": "",
"license": "",
"size_in_bytes": 2686926,
"splits": {
"train": {
"name": "train",
"num_bytes": 1224242,
"num_examples": 472,
"dataset_name": "tiny-shakespeare"
},
"test": {
"name": "test",
"num_bytes": 119216,
"num_examples": 49,
"dataset_name": "tiny-shakespeare"
}
},
"version": {
"version_str": "0.0.0",
"major": 0,
"minor": 0,
"patch": 0
}
}
13 changes: 13 additions & 0 deletions tests/resources/datasets/tiny_shakespeare/state.json
@@ -0,0 +1,13 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "72f06f97cc145672",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": null,
"_output_all_columns": false,
"_split": "train[:10]"
}
6 changes: 3 additions & 3 deletions tests/unit/integrations/huggingface/test_loading_utils.py
@@ -23,13 +23,13 @@ def test_dataset_loading(xyz_dataset_artifact):
assert "train" in datasets and "test" in datasets


def test_model_loading(gpt2_model_artifact):
def test_model_loading(llm_model_artifact):
Member:
might be helpful to type annotate this

# Preload fake artifact for testing
artifact_loader = FakeArtifactLoader()
artifact_loader.log_artifact(gpt2_model_artifact)
artifact_loader.log_artifact(llm_model_artifact)
hf_loader = HuggingFaceAssetLoader(artifact_loader)

artifact_config = WandbArtifactConfig(name=gpt2_model_artifact.name, project="project")
artifact_config = WandbArtifactConfig(name=llm_model_artifact.name, project="project")
model_config = AutoModelConfig(load_from=artifact_config, torch_dtype=torch.bfloat16)

hf_config = hf_loader.load_pretrained_config(model_config)