support grain checkpoint (#133)
aireenmei authored Nov 13, 2024
1 parent 269b621 commit f4b9042
Showing 5 changed files with 46 additions and 8 deletions.
9 changes: 8 additions & 1 deletion docs/data_README.md
@@ -6,7 +6,8 @@ Currently MaxDiffusion supports 3 data input pipelines, controlled by the flag `
| -------- | ---------------- | --------------- | ----------------------- |
| HuggingFace (hf) | datasets in HuggingFace Hub or local/Cloud Storage | Formats supported in HF Hub: parquet, arrow, json, csv, txt | data is not loaded into memory but streamed from the saved location; good for large datasets |
| tf | dataset is downloaded from the HuggingFace Hub to disk | Formats supported in HF Hub: parquet, arrow, json, csv, txt | reads the whole dataset into memory; works for small datasets |
| tfrecord | local/Cloud Storage | tfrecord | data is not loaded into memory but streamed from the saved location; good for large datasets |
| tfrecord | local/Cloud Storage | TFRecord | data is not loaded into memory but streamed from the saved location; good for large datasets |
| Grain | local/Cloud Storage | ArrayRecord (or any random access format) | data is not loaded into memory but streamed from the saved location; good for large datasets; supports global shuffle and data iterator checkpointing for determinism (see details in [doc](https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-pipeline---for-determinism)) |

## Usage examples

@@ -45,6 +46,12 @@ dataset_type: tfrecord
train_data_dir: gs://<bucket>/<folder> # will use all TFRecord files under the directory
```

### Grain (dataset_type=grain)
```
dataset_type: grain
grain_train_files: gs://<bucket>/<folder>/*.arrayrecord # match the file pattern
```
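
For illustration, the sketch below shows how a Grain pipeline over ArrayRecord files provides a global shuffle and a checkpointable iterator. It is a minimal, hypothetical example, not MaxDiffusion's internal pipeline: the file path, batch size, and sampler settings are placeholder assumptions, and it assumes the public `grain.python` API (`ArrayRecordDataSource`, `IndexSampler`, `ShardByJaxProcess`, `DataLoader`).

```python
# Minimal, hypothetical Grain pipeline sketch -- not MaxDiffusion's implementation.
import glob

import grain.python as grain

files = sorted(glob.glob("/data/my_dataset/*.arrayrecord"))  # placeholder local file pattern
source = grain.ArrayRecordDataSource(files)                  # random-access data source

sampler = grain.IndexSampler(
    num_records=len(source),
    shard_options=grain.ShardByJaxProcess(drop_remainder=True),  # one shard per host
    shuffle=True,    # global shuffle over record indices
    num_epochs=1,
    seed=0,
)

loader = grain.DataLoader(
    data_source=source,
    sampler=sampler,
    operations=[grain.Batch(batch_size=8, drop_remainder=True)],  # placeholder batch size
    worker_count=1,
)

# The iterator's position is serializable state; this is what
# grain.PyGrainCheckpointSave captures for deterministic resumption.
data_iter = iter(loader)
first_batch = next(data_iter)
```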

## Best Practice
### Multihost Dataloading
In a multihost environment, when using a streaming input pipeline whose data format only supports sequential reads (`dataset_type` of `hf` or `tfrecord` in MaxDiffusion), the most performant approach is to have each data file accessed by only one host, with each host reading its own subset of the data files (shuffling happens within that subset). This requires (# of data files) > (# of hosts loading data). We recommend resharding the dataset if this requirement is not met.
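
As a hypothetical illustration of that layout (the file pattern, batching, and shuffle buffers are placeholder assumptions, not MaxDiffusion code), each host can be assigned a disjoint subset of files keyed on `jax.process_index()`:

```python
# Hypothetical sketch: each host reads a disjoint subset of sequential-read files.
import jax
import tensorflow as tf

all_files = sorted(tf.io.gfile.glob("gs://my-bucket/my-data/*.tfrecord"))  # placeholder pattern
num_hosts = jax.process_count()
assert len(all_files) >= num_hosts, "need at least one file per data-loading host"

# Host i reads files i, i + num_hosts, i + 2 * num_hosts, ...
my_files = all_files[jax.process_index()::num_hosts]

# Shuffle the order of this host's files, then interleave and shuffle records within them.
ds = tf.data.Dataset.from_tensor_slices(my_files).shuffle(len(my_files), seed=0)
ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=4, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.shuffle(1_000).batch(8, drop_remainder=True).prefetch(tf.data.AUTOTUNE)
```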
@@ -22,6 +22,7 @@
import jax
from jax.sharding import Mesh
import orbax.checkpoint as ocp
import grain.python as grain
from maxdiffusion import (
max_utils,
FlaxStableDiffusionPipeline,
@@ -57,7 +58,11 @@ def __init__(self, config, checkpoint_type):
self.total_train_batch_size = self.config.total_train_batch_size

self.checkpoint_manager = create_orbax_checkpoint_manager(
self.config.checkpoint_dir, enable_checkpointing=True, save_interval_steps=1, checkpoint_type=checkpoint_type
self.config.checkpoint_dir,
enable_checkpointing=True,
save_interval_steps=1,
checkpoint_type=checkpoint_type,
dataset_type=config.dataset_type,
)

def _create_optimizer(self, config, learning_rate):
@@ -157,6 +162,22 @@ def create_text_encoder_2_state(self, pipeline, params, checkpoint_item_name, is
training=is_training,
)

def restore_data_iterator_state(self, data_iterator):
if (
self.config.dataset_type == "grain"
and data_iterator is not None
and (self.checkpoint_manager.directory / str(self.checkpoint_manager.latest_step()) / "iter").exists()
):
max_logging.log("Restoring data iterator from checkpoint")
restored = self.checkpoint_manager.restore(
self.checkpoint_manager.latest_step(),
args=ocp.args.Composite(iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator)),
)
data_iterator.local_iterator = restored["iter"]
else:
max_logging.log("data iterator checkpoint not found")
return data_iterator

def _get_pipeline_class(self):
if self.checkpoint_type == STABLE_DIFFUSION_CHECKPOINT:
pipeline_class = FlaxStableDiffusionPipeline
@@ -212,7 +233,7 @@ def load_diffusers_checkpoint(self):
params = jax.tree_util.tree_map(lambda x: x.astype(self.config.weights_dtype), params)
return pipeline, params

def save_checkpoint(self, train_step, pipeline, params, train_states):
def save_checkpoint(self, train_step, pipeline, params, train_states, data_iterator=None):
def config_to_json(model_or_config):
return json.loads(model_or_config.to_json_string())

@@ -233,7 +254,8 @@ def config_to_json(model_or_config):

tokenizer_config = {"path": self.config.tokenizer_model_name_or_path}
items["tokenizer_config"] = ocp.args.JsonSave(tokenizer_config)

if self.config.dataset_type == "grain":
items["iter"] = grain.PyGrainCheckpointSave(data_iterator.local_iterator)
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

def load_params(self, step=None):
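
For context on the `iter` item used above, here is a self-contained, hypothetical sketch of saving and restoring a PyGrain iterator through Orbax `Composite` args. The directory, data source, and sampler settings are placeholders; the real trainer bundles the iterator with the model states as shown in the diff.

```python
# Hypothetical round trip of a Grain data-iterator checkpoint via Orbax.
import grain.python as grain
import orbax.checkpoint as ocp

source = grain.ArrayRecordDataSource(["/tmp/data-00000.arrayrecord"])  # placeholder file
sampler = grain.IndexSampler(
    num_records=len(source),
    shard_options=grain.NoSharding(),
    shuffle=True,
    num_epochs=1,
    seed=0,
)
data_iterator = iter(grain.DataLoader(data_source=source, sampler=sampler))
_ = next(data_iterator)  # advance so there is iterator state worth saving

manager = ocp.CheckpointManager("/tmp/ckpt", item_names=("iter",))  # placeholder directory

# Save: the iterator goes in the "iter" item (alongside model states in the real trainer).
manager.save(0, args=ocp.args.Composite(iter=grain.PyGrainCheckpointSave(data_iterator)))
manager.wait_until_finished()

# Restore: pass an iterator built from the same pipeline; Orbax restores its position.
restored = manager.restore(
    manager.latest_step(),
    args=ocp.args.Composite(iter=grain.PyGrainCheckpointRestore(data_iterator)),
)
data_iterator = restored["iter"]
```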
3 changes: 3 additions & 0 deletions src/maxdiffusion/checkpointing/checkpointing_utils.py
@@ -40,6 +40,7 @@ def create_orbax_checkpoint_manager(
enable_checkpointing: bool,
save_interval_steps,
checkpoint_type: str,
dataset_type: str = "tf",
use_async: bool = True,
orbax_logger: Optional[abstract_logger.AbstractLogger] = None,
):
Expand Down Expand Up @@ -70,6 +71,8 @@ def create_orbax_checkpoint_manager(
"text_encoder_2_state",
"text_encoder_2_config",
)
if dataset_type == "grain":
item_names += ("iter",)

print("item_names: ", item_names)

2 changes: 2 additions & 0 deletions src/maxdiffusion/trainers/base_stable_diffusion_trainer.py
@@ -128,6 +128,8 @@ def start_training(self):

# Load dataset
data_iterator = self.load_dataset(pipeline, params, train_states)
if self.config.dataset_type == "grain":
data_iterator = self.restore_data_iterator_state(data_iterator)

data_shardings = self.get_data_shardings()
# Compile train_step
12 changes: 8 additions & 4 deletions src/maxdiffusion/trainers/stable_diffusion_trainer.py
@@ -15,6 +15,7 @@
"""

import os
import sys
from functools import partial
import datetime
import time
@@ -211,7 +212,6 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
unet_state, text_encoder_state, train_metric, train_rngs = p_train_step(
unet_state, vae_state, text_encoder_state, example_batch, train_rngs
)
samples_count = self.total_train_batch_size * (step + 1)
new_time = datetime.datetime.now()

train_utils.record_scalar_metrics(
@@ -221,11 +221,15 @@
train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)
last_step_completion = new_time

if step != 0 and self.config.checkpoint_every != -1 and samples_count % self.config.checkpoint_every == 0:
if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0:
train_states["unet_state"] = unet_state
train_states["vae_state"] = vae_state
train_states["text_encoder"] = text_encoder_state
self.save_checkpoint(step, pipeline, params, train_states)
self.save_checkpoint(step, pipeline, params, train_states, data_iterator)

if self.checkpoint_manager.reached_preemption(step):
self.checkpoint_manager.wait_until_finished()
sys.exit()

if self.config.enable_profiler and step == last_profiling_step:
max_utils.deactivate_profiler(self.config)
@@ -239,7 +243,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
train_states["vae_state"] = vae_state
train_states["text_encoder"] = text_encoder_state
# save the inference states of the last checkpoint so they can be easily loaded during gen.
self.save_checkpoint(self.config.max_train_steps - 1, pipeline, params, train_states)
self.save_checkpoint(self.config.max_train_steps - 1, pipeline, params, train_states, data_iterator)
self.checkpoint_manager.wait_until_finished()

