support grain checkpoint #133

Merged · 1 commit · Nov 13, 2024
9 changes: 8 additions & 1 deletion docs/data_README.md
@@ -6,7 +6,8 @@ Currently MaxDiffusion supports 3 data input pipelines, controlled by the flag `dataset_type`.
| -------- | ---------------- | --------------- | ----------------------- |
| HuggingFace (hf) | datasets in HuggingFace Hub or local/Cloud Storage | Formats supported in HF Hub: parquet, arrow, json, csv, txt | data is not loaded into memory but streamed from the saved location; good for big datasets |
| tf | dataset will be downloaded from HuggingFace Hub to disk | Formats supported in HF Hub: parquet, arrow, json, csv, txt | reads the whole dataset into memory; works for small datasets |
-| tfrecord | local/Cloud Storage | tfrecord | data is not loaded into memory but streamed from the saved location; good for big datasets |
+| tfrecord | local/Cloud Storage | TFRecord | data is not loaded into memory but streamed from the saved location; good for big datasets |
+| Grain | local/Cloud Storage | ArrayRecord (or any random-access format) | data is not loaded into memory but streamed from the saved location; good for big datasets; supports global shuffle and data-iterator checkpointing for determinism (see details in [doc](https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-pipeline---for-determinism)) |

## Usage examples

@@ -45,6 +46,12 @@ dataset_type: tfrecord
train_data_dir: gs://<bucket>/<folder> # will use all TFRecord files under the directory
```

+### Grain (dataset_type=grain)
+```
+dataset_type: grain
+grain_train_files: gs://<bucket>/<folder>/*.arrayrecord # glob pattern matching the data files
+```
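
To produce ArrayRecord files for this pipeline, records can be written with the open-source `array_record` package. A minimal sketch, assuming `pip install array-record`, an iterable `examples`, and a hypothetical `serialize_example` that returns bytes:

```
from array_record.python.array_record_module import ArrayRecordWriter

# Write one ArrayRecord shard; "group_size:1" stores one record per chunk,
# trading compression for cheaper random-access reads, which is what Grain wants.
writer = ArrayRecordWriter("data_0.arrayrecord", "group_size:1")
for example in examples:
    writer.write(serialize_example(example))  # write() expects bytes
writer.close()
```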

## Best Practice
### Multihost Dataloading
In a multihost environment, when using a streaming input pipeline whose data format only supports sequential reads (dataset_type in (hf, tfrecord) in MaxDiffusion), the most performant setup is for each data file to be read by exactly one host, with each host reading a subset of the data files (shuffling happens within that subset of files). This requires (# of data files) > (# of hosts loading data). We recommend resharding the dataset if this requirement is not met.
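
A hedged sketch of this per-host file assignment (illustrative only; MaxDiffusion's pipelines handle this internally), assuming JAX for host identity and TensorFlow for globbing:

```
import jax
import tensorflow as tf

# Each host takes a strided, disjoint subset of the files;
# shuffling then happens only within that subset.
files = sorted(tf.io.gfile.glob("gs://<bucket>/<folder>/*.tfrecord"))
assert len(files) >= jax.process_count(), "need at least as many files as hosts"
host_files = files[jax.process_index()::jax.process_count()]
```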
@@ -22,6 +22,7 @@
import jax
from jax.sharding import Mesh
import orbax.checkpoint as ocp
+import grain.python as grain
from maxdiffusion import (
max_utils,
FlaxStableDiffusionPipeline,
@@ -57,7 +58,11 @@ def __init__(self, config, checkpoint_type):
    self.total_train_batch_size = self.config.total_train_batch_size

    self.checkpoint_manager = create_orbax_checkpoint_manager(
-        self.config.checkpoint_dir, enable_checkpointing=True, save_interval_steps=1, checkpoint_type=checkpoint_type
+        self.config.checkpoint_dir,
+        enable_checkpointing=True,
+        save_interval_steps=1,
+        checkpoint_type=checkpoint_type,
+        dataset_type=config.dataset_type,
    )

  def _create_optimizer(self, config, learning_rate):
@@ -157,6 +162,22 @@ def create_text_encoder_2_state(self, pipeline, params, checkpoint_item_name, is_training):
        training=is_training,
    )

+  def restore_data_iterator_state(self, data_iterator):
+    if (
+        self.config.dataset_type == "grain"
+        and data_iterator is not None
+        and (self.checkpoint_manager.directory / str(self.checkpoint_manager.latest_step()) / "iter").exists()
+    ):
+      max_logging.log("Restoring data iterator from checkpoint")
+      restored = self.checkpoint_manager.restore(
+          self.checkpoint_manager.latest_step(),
+          args=ocp.args.Composite(iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator)),
+      )
+      data_iterator.local_iterator = restored["iter"]
+    else:
+      max_logging.log("data iterator checkpoint not found")
+    return data_iterator

  def _get_pipeline_class(self):
    if self.checkpoint_type == STABLE_DIFFUSION_CHECKPOINT:
      pipeline_class = FlaxStableDiffusionPipeline
@@ -212,7 +233,7 @@ def load_diffusers_checkpoint(self):
    params = jax.tree_util.tree_map(lambda x: x.astype(self.config.weights_dtype), params)
    return pipeline, params

-  def save_checkpoint(self, train_step, pipeline, params, train_states):
+  def save_checkpoint(self, train_step, pipeline, params, train_states, data_iterator=None):
    def config_to_json(model_or_config):
      return json.loads(model_or_config.to_json_string())

@@ -233,7 +254,8 @@ def config_to_json(model_or_config):

    tokenizer_config = {"path": self.config.tokenizer_model_name_or_path}
    items["tokenizer_config"] = ocp.args.JsonSave(tokenizer_config)
+    if self.config.dataset_type == "grain":
+      items["iter"] = grain.PyGrainCheckpointSave(data_iterator.local_iterator)
    self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

  def load_params(self, step=None):
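For reference, a standalone sketch of the iterator save/restore round trip that the two methods above implement, assuming a `grain.python` DataLoader `loader`, a step number `step`, and an orbax CheckpointManager `mngr` created with `"iter"` among its item_names:

```
import grain.python as grain
import orbax.checkpoint as ocp

it = iter(loader)  # a Grain iterator whose position is checkpointable
mngr.save(step, args=ocp.args.Composite(iter=grain.PyGrainCheckpointSave(it)))
mngr.wait_until_finished()  # block until the (possibly async) save lands

restored = mngr.restore(
    mngr.latest_step(),
    args=ocp.args.Composite(iter=grain.PyGrainCheckpointRestore(it)),
)
it = restored["iter"]  # resumes from the exact saved position
```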
3 changes: 3 additions & 0 deletions src/maxdiffusion/checkpointing/checkpointing_utils.py
@@ -40,6 +40,7 @@ def create_orbax_checkpoint_manager(
    enable_checkpointing: bool,
    save_interval_steps,
    checkpoint_type: str,
+    dataset_type: str = "tf",
    use_async: bool = True,
    orbax_logger: Optional[abstract_logger.AbstractLogger] = None,
):
@@ -70,6 +71,8 @@
"text_encoder_2_state",
"text_encoder_2_config",
)
if dataset_type == "grain":
item_names += ("iter",)

print("item_names: ", item_names)

2 changes: 2 additions & 0 deletions src/maxdiffusion/trainers/base_stable_diffusion_trainer.py
@@ -128,6 +128,8 @@ def start_training(self):

    # Load dataset
    data_iterator = self.load_dataset(pipeline, params, train_states)
+    if self.config.dataset_type == "grain":
+      data_iterator = self.restore_data_iterator_state(data_iterator)

    data_shardings = self.get_data_shardings()
    # Compile train_step
12 changes: 8 additions & 4 deletions src/maxdiffusion/trainers/stable_diffusion_trainer.py
@@ -15,6 +15,7 @@
"""

import os
+import sys
from functools import partial
import datetime
import time
@@ -211,7 +212,6 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator):
      unet_state, text_encoder_state, train_metric, train_rngs = p_train_step(
          unet_state, vae_state, text_encoder_state, example_batch, train_rngs
      )
-      samples_count = self.total_train_batch_size * (step + 1)
      new_time = datetime.datetime.now()

      train_utils.record_scalar_metrics(
@@ -221,11 +221,15 @@
      train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)
      last_step_completion = new_time

-      if step != 0 and self.config.checkpoint_every != -1 and samples_count % self.config.checkpoint_every == 0:
+      if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0:
        train_states["unet_state"] = unet_state
        train_states["vae_state"] = vae_state
        train_states["text_encoder"] = text_encoder_state
-        self.save_checkpoint(step, pipeline, params, train_states)
+        self.save_checkpoint(step, pipeline, params, train_states, data_iterator)

+      if self.checkpoint_manager.reached_preemption(step):
+        self.checkpoint_manager.wait_until_finished()
+        sys.exit()

      if self.config.enable_profiler and step == last_profiling_step:
        max_utils.deactivate_profiler(self.config)
@@ -239,7 +243,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator):
    train_states["unet_state"] = unet_state
    train_states["vae_state"] = vae_state
    train_states["text_encoder"] = text_encoder_state
    # save the inference states of the last checkpoint so they can be easily loaded during gen.
-    self.save_checkpoint(self.config.max_train_steps - 1, pipeline, params, train_states)
+    self.save_checkpoint(self.config.max_train_steps - 1, pipeline, params, train_states, data_iterator)
    self.checkpoint_manager.wait_until_finished()

