RD2024-89: Use artifact loader in finetuning #34
Changes from 4 commits

@@ -1,3 +1,6 @@
from typing import Any

import ray
from ray import train
from ray.train import CheckpointConfig, RunConfig, ScalingConfig
from ray.train.huggingface.transformers import RayTrainReportCallback, prepare_trainer

@@ -7,8 +10,8 @@
from flamingo.integrations.huggingface import HuggingFaceAssetLoader
from flamingo.integrations.wandb import (
    ArtifactLoader,
    ArtifactType,
    WandbArtifactLoader,
    WandbResumeMode,
    build_directory_artifact,
    default_artifact_name,

@@ -24,10 +27,10 @@ def is_tracking_enabled(config: FinetuningJobConfig):
    return config.tracking is not None and train.get_context().get_world_rank() == 0


def load_and_train(config: FinetuningJobConfig):
    # Load the input artifacts, potentially linking them to the active W&B run
    # TODO(RD2024-89): Inject this into Ray workers somehow
    hf_loader = HuggingFaceAssetLoader(WandbArtifactLoader())
def load_and_train(config: FinetuningJobConfig, artifact_loader: ArtifactLoader):
    # Load the HF assets from configurations
    # Internally, artifact lineages are declared for the active training run
    hf_loader = HuggingFaceAssetLoader(artifact_loader)
    model = hf_loader.load_pretrained_model(config.model, config.quantization)
    tokenizer = hf_loader.load_pretrained_tokenizer(config.tokenizer)
    datasets = hf_loader.load_and_split_dataset(config.dataset)

@@ -59,18 +62,25 @@ def load_and_train(config: FinetuningJobConfig):
    trainer.train()


def training_function(config_data: dict):
    config = FinetuningJobConfig(**config_data)
    if is_tracking_enabled(config):
        with wandb_init_from_config(
            config.tracking, resume=WandbResumeMode.NEVER, job_type=FlamingoJobType.FINETUNING
        ):
            load_and_train(config)
    else:
        load_and_train(config)
def run_finetuning(config: FinetuningJobConfig, artifact_loader: ArtifactLoader):
    # Place the artifact loader in Ray object store
    artifact_loader_ref = ray.put(artifact_loader)

    # Define training function internally to capture the artifact loader ref as a closure
    # Reference: https://docs.ray.io/en/latest/ray-core/objects.html#closure-capture-of-objects
    def training_function(config_data: dict[str, Any]):

Review comment: is there a way we could unnest these nested functions?

Reply: Nesting the functions allows the training function to act as a closure and capture the artifact loader ref.

        artifact_loader = ray.get(artifact_loader_ref)
        config = FinetuningJobConfig(**config_data)
        if is_tracking_enabled(config):
            with wandb_init_from_config(
                config.tracking,
                resume=WandbResumeMode.NEVER,
                job_type=FlamingoJobType.FINETUNING,
            ):
                load_and_train(config, artifact_loader)
        else:
            load_and_train(config, artifact_loader)

def run_finetuning(config: FinetuningJobConfig):
    # Construct Ray train configurations from input config
    scaling_config = ScalingConfig(
        use_gpu=config.ray.use_gpu,

@@ -101,5 +111,4 @@ def run_finetuning(config: FinetuningJobConfig):
        reference=True,
    )
    print("Logging artifact for model checkpoint...")
    artifact_loader = WandbArtifactLoader()
    artifact_loader.log_artifact(model_artifact)

@@ -0,0 +1,46 @@
import pytest

from flamingo.integrations.huggingface import AutoModelConfig, TextDatasetConfig, TrainerConfig
from flamingo.integrations.wandb import WandbArtifactConfig, WandbRunConfig
from flamingo.jobs.finetuning import FinetuningJobConfig, FinetuningRayConfig, run_finetuning
from tests.test_utils import FakeArtifactLoader


@pytest.fixture
def job_config(llm_model_artifact, text_dataset_artifact):
    model_config = AutoModelConfig(
        load_from=WandbArtifactConfig(name=llm_model_artifact.name, project="test")
    )
    dataset_config = TextDatasetConfig(
        load_from=WandbArtifactConfig(name=text_dataset_artifact.name, project="test"),
        text_field="text",
        split="train",
    )
    trainer_config = TrainerConfig(
        max_seq_length=8,
        num_train_epochs=1,
        save_steps=1,
        save_strategy="epoch",
    )
    tracking_config = WandbRunConfig(name="test-finetuning-job")
    ray_config = FinetuningRayConfig(num_workers=1, use_gpu=False)
    return FinetuningJobConfig(
        model=model_config,
        dataset=dataset_config,
        trainer=trainer_config,
        tracking=tracking_config,
        ray=ray_config,
    )


def test_finetuning_job(llm_model_artifact, text_dataset_artifact, job_config):
    # Preload input artifact in loader
    artifact_loader = FakeArtifactLoader()
    artifact_loader.log_artifact(llm_model_artifact)
    artifact_loader.log_artifact(text_dataset_artifact)

    # Run test job
    run_finetuning(job_config, artifact_loader)

    # Two input artifacts, and one output model artifact produced
    assert artifact_loader.num_artifacts() == 3
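
The `FakeArtifactLoader` imported from `tests.test_utils` is not part of this diff. For orientation, a minimal sketch of what such an in-memory stub could look like, assuming the `ArtifactLoader` interface amounts to logging and retrieving W&B artifacts (the `use_artifact` method name and signature below are assumptions, not taken from the repository):

```python
# Hypothetical in-memory stand-in for an ArtifactLoader, for tests only.
# Method names other than log_artifact/num_artifacts are assumed, not confirmed by this PR.
import wandb


class FakeArtifactLoader:
    """Stores artifacts in a dict instead of calling the W&B API."""

    def __init__(self):
        self._artifacts: dict[str, wandb.Artifact] = {}

    def use_artifact(self, artifact_name: str) -> wandb.Artifact:
        # Resolve a previously logged artifact by name
        return self._artifacts[artifact_name]

    def log_artifact(self, artifact: wandb.Artifact) -> None:
        # Record the artifact in memory so tests can assert on counts later
        self._artifacts[artifact.name] = artifact

    def num_artifacts(self) -> int:
        return len(self._artifacts)
```

Keeping artifacts in a plain dict is what lets the test above assert `num_artifacts() == 3` without any W&B API calls.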

@@ -7,9 +7,10 @@


@pytest.fixture
def job_config(gpt2_model_artifact):
    artifact_config = WandbArtifactConfig(name=gpt2_model_artifact.name, project="test")
    model_config = AutoModelConfig(load_from=artifact_config)
def job_config(llm_model_artifact):
    model_config = AutoModelConfig(
        load_from=WandbArtifactConfig(name=llm_model_artifact.name, project="test")
    )

    tracking_config = WandbRunConfig(name="test-lm-harness-job")
    evaluator_config = LMHarnessEvaluatorConfig(tasks=["hellaswag"], limit=5)

@@ -20,10 +21,10 @@ def job_config(gpt2_model_artifact):
    )


def test_lm_harness_job_with_tracking(gpt2_model_artifact, job_config):
def test_lm_harness_job_with_tracking(llm_model_artifact, job_config):

Review comment: might be helpful to type annotate

    # Preload input artifact in loader
    artifact_loader = FakeArtifactLoader()
    artifact_loader.log_artifact(gpt2_model_artifact)
    artifact_loader.log_artifact(llm_model_artifact)

    # Run test job
    run_lm_harness(job_config, artifact_loader)

@@ -32,13 +33,13 @@ def test_lm_harness_job_with_tracking(gpt2_model_artifact, job_config):
    assert artifact_loader.num_artifacts() == 2


def test_lm_harness_job_no_tracking(gpt2_model_artifact, job_config):
def test_lm_harness_job_no_tracking(llm_model_artifact, job_config):
    # Disable tracking on job config
    job_config.tracking = None

    # Preload input artifact in loader
    artifact_loader = FakeArtifactLoader()
    artifact_loader.log_artifact(gpt2_model_artifact)
    artifact_loader.log_artifact(llm_model_artifact)

    # Run test job
    run_lm_harness(job_config, artifact_loader)

@@ -0,0 +1,10 @@
"""Script to generate the `tiny_shakespeare` dataset files."""

from pathlib import Path

from datasets import load_dataset

if __name__ == "__main__":
    repo_id = "Trelis/tiny-shakespeare"
    dataset = load_dataset(repo_id, split="train[:10]")
    dataset = dataset.rename_column("Text", "text")
    dataset.save_to_disk(dataset_path=Path(__file__).parent)
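
The files this script writes could presumably be read back with `datasets.load_from_disk`; a small usage sketch under that assumption (the path simply mirrors where the script saves):

```python
# Hypothetical usage: reload the saved dataset files from the same directory.
from pathlib import Path

from datasets import load_from_disk

dataset = load_from_disk(str(Path(__file__).parent))
print(dataset.num_rows)  # 10 examples, from the train[:10] slice
```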

@@ -0,0 +1,48 @@
{
  "builder_name": "csv",
  "citation": "",
  "config_name": "default",
  "dataset_name": "tiny-shakespeare",
  "dataset_size": 1343458,
  "description": "",
  "download_checksums": {
    "hf://datasets/Trelis/tiny-shakespeare@2225a5674b252d623cfb151ef099d179629c2f49/train.csv": {
      "num_bytes": 1224246,
      "checksum": null
    },
    "hf://datasets/Trelis/tiny-shakespeare@2225a5674b252d623cfb151ef099d179629c2f49/test.csv": {
      "num_bytes": 119222,
      "checksum": null
    }
  },
  "download_size": 1343468,
  "features": {
    "text": {
      "dtype": "string",
      "_type": "Value"
    }
  },
  "homepage": "",
  "license": "",
  "size_in_bytes": 2686926,
  "splits": {
    "train": {
      "name": "train",
      "num_bytes": 1224242,
      "num_examples": 472,
      "dataset_name": "tiny-shakespeare"
    },
    "test": {
      "name": "test",
      "num_bytes": 119216,
      "num_examples": 49,
      "dataset_name": "tiny-shakespeare"
    }
  },
  "version": {
    "version_str": "0.0.0",
    "major": 0,
    "minor": 0,
    "patch": 0
  }
}

@@ -0,0 +1,13 @@
{
  "_data_files": [
    {
      "filename": "data-00000-of-00001.arrow"
    }
  ],
  "_fingerprint": "72f06f97cc145672",
  "_format_columns": null,
  "_format_kwargs": {},
  "_format_type": null,
  "_output_all_columns": false,
  "_split": "train[:10]"
}

@@ -23,13 +23,13 @@ def test_dataset_loading(xyz_dataset_artifact):
    assert "train" in datasets and "test" in datasets


def test_model_loading(gpt2_model_artifact):
def test_model_loading(llm_model_artifact):

Review comment: might be helpful to type annotate this

    # Preload fake artifact for testing
    artifact_loader = FakeArtifactLoader()
    artifact_loader.log_artifact(gpt2_model_artifact)
    artifact_loader.log_artifact(llm_model_artifact)
    hf_loader = HuggingFaceAssetLoader(artifact_loader)

    artifact_config = WandbArtifactConfig(name=gpt2_model_artifact.name, project="project")
    artifact_config = WandbArtifactConfig(name=llm_model_artifact.name, project="project")
    model_config = AutoModelConfig(load_from=artifact_config, torch_dtype=torch.bfloat16)

    hf_config = hf_loader.load_pretrained_config(model_config)

Comment: This ended up working out pretty easily. I can place the `ArtifactLoader` class in the Ray object store, and then retrieve it from the object ref within each worker. The caveat is that the `ArtifactLoader` instance needs to be serializable via cloudpickle, but I don't think it's an antipattern to be doing this according to the Ray docs.

Comment: does this work on a live cluster?

Reply: Yes it does.

Reply: I don't think the job logs give any info about the object ref retrieval, but here is an example run off this branch: http://10.145.55.132:8265/#/jobs/04000000

Comment: is the ref something we might need? if so, would be helpful to log

Reply: The object ref is essentially a UID string. It's not really useful for the logs, so I don't know if it's helpful to add an additional statement there.
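
As a rough illustration of the pattern discussed in this thread (not the exact Flamingo code): an object placed in the Ray object store with `ray.put` can be captured by a task's closure and resolved with `ray.get` inside each worker, as long as the object survives cloudpickle serialization. The `SimpleLoader` class below is hypothetical.

```python
# Minimal sketch of closure capture of an object ref, assuming a local Ray runtime.
# See https://docs.ray.io/en/latest/ray-core/objects.html#closure-capture-of-objects
import ray


class SimpleLoader:
    """Hypothetical stand-in for an ArtifactLoader; must be cloudpickle-serializable."""

    def __init__(self, artifacts: dict[str, str]):
        self.artifacts = artifacts


ray.init(ignore_reinit_error=True)

loader_ref = ray.put(SimpleLoader({"model": "my-model:v0"}))


@ray.remote
def worker_task() -> str:
    # The object ref is captured by the closure; each worker resolves it locally.
    loader = ray.get(loader_ref)
    return loader.artifacts["model"]


print(ray.get(worker_task.remote()))  # -> my-model:v0
```

The same trade-off called out above applies here: whatever goes through `ray.put` has to be cloudpickle-serializable.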