From f4b904248a49cc8b5e114cf996be2e817bc862f8 Mon Sep 17 00:00:00 2001
From: aireenmei <12836798+aireenmei@users.noreply.github.com>
Date: Tue, 12 Nov 2024 16:01:34 -0800
Subject: [PATCH] support grain checkpoint (#133)

---
 docs/data_README.md                           |  9 +++++-
 .../base_stable_diffusion_checkpointer.py     | 28 +++++++++++++++++--
 .../checkpointing/checkpointing_utils.py      |  3 ++
 .../trainers/base_stable_diffusion_trainer.py |  2 ++
 .../trainers/stable_diffusion_trainer.py      | 12 +++++---
 5 files changed, 46 insertions(+), 8 deletions(-)

diff --git a/docs/data_README.md b/docs/data_README.md
index ffa5751..2b726dc 100644
--- a/docs/data_README.md
+++ b/docs/data_README.md
@@ -6,7 +6,8 @@ Currently MaxDiffusion supports 3 data input pipelines, controlled by the flag `
 | -------- | ---------------- | --------------- | ----------------------- |
 | HuggingFace (hf) | datasets in HuggingFace Hub or local/Cloud Storage | Formats supported in HF Hub: parquet, arrow, json, csv, txt | data are not loaded in memory but streamed from the saved location, good for big datasets |
 | tf | dataset will be downloaded from HuggingFace Hub to disk | Formats supported in HF Hub: parquet, arrow, json, csv, txt | Will read the whole dataset into memory, works for small datasets |
-| tfrecord | local/Cloud Storage | tfrecord | data are not loaded in memory but streamed from the saved location, good for big datasets |
+| tfrecord | local/Cloud Storage | TFRecord | data are not loaded in memory but streamed from the saved location, good for big datasets |
+| Grain | local/Cloud Storage | ArrayRecord (or any random access format) | data are not loaded in memory but streamed from the saved location, good for big datasets, supports global shuffle and data iterator checkpointing for determinism (see details in [doc](https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-pipeline---for-determinism)) |
 
 ## Usage examples
 
@@ -45,6 +46,12 @@ dataset_type: tfrecord
 train_data_dir: gs:/// # will use all TFRecord files under the directory
 ```
 
+### Grain (dataset_type=grain)
+```
+dataset_type: grain
+grain_train_files: gs:////*.arrayrecord # match the file pattern
+```
+
 ## Best Practice
 ### Multihost Dataloading
 In a multihost environment, if using a streaming input pipeline whose data format only supports sequential reads (dataset_type in (hf, tfrecord) in MaxDiffusion), the most performant way is to have each data file accessed by only one host, with each host reading a subset of the data files (shuffling happens within that subset of files). This requires (# of data files) > (# of hosts loading data). We recommend users reshard the dataset if this requirement is not met.
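For context on the new pipeline: Grain reads ArrayRecord files by record index, which is what enables the global shuffle and the resumable iterator that the rest of this patch checkpoints. The sketch below shows roughly how a PyGrain DataLoader over ArrayRecord files is typically assembled with `grain.python`; it is not MaxDiffusion's actual input pipeline, and the file glob, batch size, worker count, and the `ParseExample` transform are placeholder assumptions.

```python
# Minimal sketch of a PyGrain DataLoader over ArrayRecord files.
# Paths, batch size, and the parse step are hypothetical; MaxDiffusion's real
# pipeline is configured via grain_train_files in the config.
import glob

import grain.python as grain
import jax


class ParseExample(grain.MapTransform):
  """Placeholder decode step; a real pipeline would deserialize each record."""

  def map(self, record_bytes):
    return record_bytes


data_files = glob.glob("/data/train/*.arrayrecord")  # hypothetical local copy of the dataset
data_source = grain.ArrayRecordDataSource(data_files)

index_sampler = grain.IndexSampler(
    num_records=len(data_source),
    num_epochs=1,
    shard_options=grain.ShardOptions(
        shard_index=jax.process_index(), shard_count=jax.process_count(), drop_remainder=True
    ),
    shuffle=True,  # global shuffle over record indices, not just within one file
    seed=0,
)

data_iterator = iter(
    grain.DataLoader(
        data_source=data_source,
        operations=[ParseExample(), grain.Batch(batch_size=8, drop_remainder=True)],
        sampler=index_sampler,
        worker_count=1,
    )
)
```

Because the sampler addresses records by index, the iterator's position is a small piece of state that can be checkpointed and restored, which is what the checkpointer changes below rely on.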
diff --git a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py
index 0ff1045..7221554 100644
--- a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py
+++ b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py
@@ -22,6 +22,7 @@ import jax
 
 from jax.sharding import Mesh
 import orbax.checkpoint as ocp
+import grain.python as grain
 from maxdiffusion import (
     max_utils,
     FlaxStableDiffusionPipeline,
@@ -57,7 +58,11 @@ def __init__(self, config, checkpoint_type):
 
     self.total_train_batch_size = self.config.total_train_batch_size
     self.checkpoint_manager = create_orbax_checkpoint_manager(
-        self.config.checkpoint_dir, enable_checkpointing=True, save_interval_steps=1, checkpoint_type=checkpoint_type
+        self.config.checkpoint_dir,
+        enable_checkpointing=True,
+        save_interval_steps=1,
+        checkpoint_type=checkpoint_type,
+        dataset_type=config.dataset_type,
     )
 
   def _create_optimizer(self, config, learning_rate):
@@ -157,6 +162,22 @@ def create_text_encoder_2_state(self, pipeline, params, checkpoint_item_name, is
         training=is_training,
     )
 
+  def restore_data_iterator_state(self, data_iterator):
+    if (
+        self.config.dataset_type == "grain"
+        and data_iterator is not None
+        and (self.checkpoint_manager.directory / str(self.checkpoint_manager.latest_step()) / "iter").exists()
+    ):
+      max_logging.log("Restoring data iterator from checkpoint")
+      restored = self.checkpoint_manager.restore(
+          self.checkpoint_manager.latest_step(),
+          args=ocp.args.Composite(iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator)),
+      )
+      data_iterator.local_iterator = restored["iter"]
+    else:
+      max_logging.log("data iterator checkpoint not found")
+    return data_iterator
+
   def _get_pipeline_class(self):
     if self.checkpoint_type == STABLE_DIFFUSION_CHECKPOINT:
       pipeline_class = FlaxStableDiffusionPipeline
@@ -212,7 +233,7 @@ def load_diffusers_checkpoint(self):
     params = jax.tree_util.tree_map(lambda x: x.astype(self.config.weights_dtype), params)
     return pipeline, params
 
-  def save_checkpoint(self, train_step, pipeline, params, train_states):
+  def save_checkpoint(self, train_step, pipeline, params, train_states, data_iterator=None):
     def config_to_json(model_or_config):
       return json.loads(model_or_config.to_json_string())
 
@@ -233,7 +254,8 @@ def config_to_json(model_or_config):
 
     tokenizer_config = {"path": self.config.tokenizer_model_name_or_path}
     items["tokenizer_config"] = ocp.args.JsonSave(tokenizer_config)
-
+    if self.config.dataset_type == "grain":
+      items["iter"] = grain.PyGrainCheckpointSave(data_iterator.local_iterator)
     self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
 
   def load_params(self, step=None):
diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py
index 6768177..b8710e1 100644
--- a/src/maxdiffusion/checkpointing/checkpointing_utils.py
+++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py
@@ -40,6 +40,7 @@ def create_orbax_checkpoint_manager(
     enable_checkpointing: bool,
     save_interval_steps,
     checkpoint_type: str,
+    dataset_type: str = "tf",
    use_async: bool = True,
     orbax_logger: Optional[abstract_logger.AbstractLogger] = None,
 ):
@@ -70,6 +71,8 @@ def create_orbax_checkpoint_manager(
         "text_encoder_2_state",
         "text_encoder_2_config",
     )
+  if dataset_type == "grain":
+    item_names += ("iter",)
 
   print("item_names: ", item_names)
 
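The checkpointer changes lean on Grain's Orbax integration: the iterator is saved as one more Composite item ("iter") alongside the model states, and restored in place with `PyGrainCheckpointRestore`; the existence check on the "iter" subdirectory presumably keeps restores from older checkpoints (written before this change) working. Below is a stripped-down sketch of just that round trip, assuming a recent Orbax; the directory and the bare `CheckpointManager` construction are illustrative, since MaxDiffusion builds its manager via `create_orbax_checkpoint_manager`.

```python
# Stand-alone sketch of the Grain iterator checkpoint round trip wired in above.
# The path and the bare CheckpointManager are illustrative assumptions.
import grain.python as grain
import orbax.checkpoint as ocp

# create_orbax_checkpoint_manager now adds "iter" to item_names when dataset_type == "grain".
manager = ocp.CheckpointManager("/tmp/ckpt", item_names=("iter",))  # assumed path


def save_iterator_state(step, local_iterator):
  """local_iterator is a PyGrain iterator, e.g. iter(grain.DataLoader(...))."""
  manager.save(step, args=ocp.args.Composite(iter=grain.PyGrainCheckpointSave(local_iterator)))


def restore_iterator_state(local_iterator):
  """Restores in place and returns the same iterator, advanced to the saved position."""
  restored = manager.restore(
      manager.latest_step(),
      args=ocp.args.Composite(iter=grain.PyGrainCheckpointRestore(local_iterator)),
  )
  return restored["iter"]
```

In the patch itself the iterator being saved and restored is the multihost iterator's `.local_iterator`, so each host checkpoints only its own shard of the data stream.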
diff --git a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py
index 3091161..f6cecb5 100644
--- a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py
+++ b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py
@@ -128,6 +128,8 @@ def start_training(self):
 
     # Load dataset
     data_iterator = self.load_dataset(pipeline, params, train_states)
+    if self.config.dataset_type == "grain":
+      data_iterator = self.restore_data_iterator_state(data_iterator)
     data_shardings = self.get_data_shardings()
 
     # Compile train_step
diff --git a/src/maxdiffusion/trainers/stable_diffusion_trainer.py b/src/maxdiffusion/trainers/stable_diffusion_trainer.py
index f8d56e1..8315864 100644
--- a/src/maxdiffusion/trainers/stable_diffusion_trainer.py
+++ b/src/maxdiffusion/trainers/stable_diffusion_trainer.py
@@ -15,6 +15,7 @@
 """
 
 import os
+import sys
 from functools import partial
 import datetime
 import time
@@ -211,7 +212,6 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
       unet_state, text_encoder_state, train_metric, train_rngs = p_train_step(
          unet_state, vae_state, text_encoder_state, example_batch, train_rngs
       )
-      samples_count = self.total_train_batch_size * (step + 1)
 
       new_time = datetime.datetime.now()
       train_utils.record_scalar_metrics(
@@ -221,11 +221,15 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
       train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)
       last_step_completion = new_time
 
-      if step != 0 and self.config.checkpoint_every != -1 and samples_count % self.config.checkpoint_every == 0:
+      if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0:
         train_states["unet_state"] = unet_state
         train_states["vae_state"] = vae_state
         train_states["text_encoder"] = text_encoder_state
-        self.save_checkpoint(step, pipeline, params, train_states)
+        self.save_checkpoint(step, pipeline, params, train_states, data_iterator)
+
+      if self.checkpoint_manager.reached_preemption(step):
+        self.checkpoint_manager.wait_until_finished()
+        sys.exit()
 
       if self.config.enable_profiler and step == last_profiling_step:
         max_utils.deactivate_profiler(self.config)
@@ -239,7 +243,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
     train_states["vae_state"] = vae_state
     train_states["text_encoder"] = text_encoder_state
     # save the inference states of the last checkpoint so they can be easily loaded during gen.
-    self.save_checkpoint(self.config.max_train_steps - 1, pipeline, params, train_states)
+    self.save_checkpoint(self.config.max_train_steps - 1, pipeline, params, train_states, data_iterator)
 
     self.checkpoint_manager.wait_until_finished()
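Two behavioral notes on the trainer hunks, with a condensed sketch of the resulting loop shape. First, with `samples_count` removed, `checkpoint_every` is now compared against the step counter, so a value previously interpreted as a sample count is now an interval in steps. Second, the `reached_preemption` / `wait_until_finished` / `sys.exit()` sequence lets a preempted job flush any in-flight async save (including the Grain iterator item) and exit cleanly so it can resume from the latest step and iterator position. The names below (`config`, the `save_checkpoint` callable) are placeholders, not the trainer's real signature.

```python
# Condensed sketch of the checkpoint/preemption pattern the trainer now follows.
# save_checkpoint stands in for self.save_checkpoint; config and checkpoint_manager
# are placeholders for the trainer's own attributes.
import sys


def training_loop_sketch(config, checkpoint_manager, save_checkpoint, data_iterator, max_train_steps):
  for step in range(max_train_steps):
    # ... run one training step and record metrics ...

    # checkpoint_every is now an interval in steps, not samples.
    if step != 0 and config.checkpoint_every != -1 and step % config.checkpoint_every == 0:
      save_checkpoint(step, data_iterator)

    # On a preemption signal, block until the async save lands, then exit so the
    # job can restart from the latest checkpoint, data iterator included.
    if checkpoint_manager.reached_preemption(step):
      checkpoint_manager.wait_until_finished()
      sys.exit()

  # Final save so the last step's inference states (and iterator) are available.
  save_checkpoint(max_train_steps - 1, data_iterator)
  checkpoint_manager.wait_until_finished()
```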