Fsdp & multi-host fixes (#99)
* Makes fixes for FSDP multi-host.
* Use train_states to create generation states.
* Remove sharding from the VAE in order to preprocess the dataset correctly during FSDP.
* Option to load the pokemon dataset from GCS.

---------

Co-authored-by: Juan Acevedo <[email protected]>
entrpn and jfacevedo-google authored Sep 11, 2024
1 parent 57310aa commit 6234383
Showing 37 changed files with 357 additions and 369 deletions.
5 changes: 1 addition & 4 deletions .github/workflows/UnitTests.yml
@@ -36,9 +36,6 @@ jobs:
    runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
    steps:
    - uses: actions/checkout@v3
    - name: Install Ubuntu dependencies
      run: |
        sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 -y
    - name: Install dependencies
      run: |
        pip install -e .
@@ -54,7 +51,7 @@ jobs:
        ruff check .
    - name: PyTest
      run: |
        HF_HUB_CACHE=/mnt/disks/github-runner-disk/ python3 -m pytest
        HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest
  # add_pull_ready:
  #   if: github.ref != 'refs/heads/main'
  #   permissions:
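The PyTest step now exports `HF_HOME` alongside `HF_HUB_CACHE`, so Hugging Face artifacts that live outside the hub cache (datasets, tokens, module cache) also land on the runner's large disk. A minimal sketch of the same idea in Python — the path is illustrative, and the variables must be set before any Hugging Face import, since the libraries read them at import time:

```python
import os

# Point every Hugging Face cache at the big scratch disk (illustrative path).
# Set these before importing transformers/datasets/diffusers, otherwise the
# libraries fall back to ~/.cache/huggingface and can fill the root disk.
os.environ["HF_HUB_CACHE"] = "/mnt/disks/github-runner-disk/"
os.environ["HF_HOME"] = "/mnt/disks/github-runner-disk/"

import transformers  # noqa: E402  # now uses the redirected cache
```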
12 changes: 6 additions & 6 deletions README.md
@@ -88,14 +88,14 @@ After installation completes, run the training script.

```bash
export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml run_name="my_run" jax_cache_dir=gs://your-bucket/cache_dir activations_dtype=float32 weights_dtype=float32 per_device_batch_size=2 precision=DEFAULT dataset_save_location=/tmp/my_dataset/ output_dir=gs://your-bucket/ attention=flash
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name="my_run" jax_cache_dir=gs://your-bucket/cache_dir activations_dtype=float32 weights_dtype=float32 per_device_batch_size=2 precision=DEFAULT dataset_save_location=/tmp/my_dataset/ output_dir=gs://your-bucket/ attention=flash
```

- **Stable Diffusion 1.5**
- **Stable Diffusion 1.4**

```bash
export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base15.yml run_name="my_run" jax_cache_dir=gs://your-bucket/cache_dir activations_dtype=float32 weights_dtype=float32 per_device_batch_size=2 precision=DEFAULT dataset_save_location=/tmp/my_dataset/ output_dir=gs://your-bucket/ attention=flash
python -m src.maxdiffusion.train src/maxdiffusion/configs/base14.yml run_name="my_run" jax_cache_dir=gs://your-bucket/cache_dir activations_dtype=float32 weights_dtype=float32 per_device_batch_size=2 precision=DEFAULT dataset_save_location=/tmp/my_dataset/ output_dir=gs://your-bucket/ attention=flash
```

To generate images with a trained checkpoint, run:
@@ -109,7 +109,7 @@ After installation completes, run the training script.
**Stable Diffusion 1.x, 2.x**

```bash
python src/maxdiffusion/dreambooth/train_dreambooth.py src/maxdiffusion/configs/base15.yml class_data_dir=<your-class-dir> instance_data_dir=<your-instance-dir> instance_prompt="a photo of ohwx dog" class_prompt="photo of a dog" max_train_steps=150 jax_cache_dir=<your-cache-dir> class_prompt="a photo of a dog" activations_dtype=bfloat16 weights_dtype=float32 per_device_batch_size=1 enable_profiler=False precision=DEFAULT cache_dreambooth_dataset=False learning_rate=4e-6 num_class_images=100 run_name=<your-run-name> output_dir=gs://<your-bucket-name>
python src/maxdiffusion/dreambooth/train_dreambooth.py src/maxdiffusion/configs/base14.yml class_data_dir=<your-class-dir> instance_data_dir=<your-instance-dir> instance_prompt="a photo of ohwx dog" class_prompt="photo of a dog" max_train_steps=150 jax_cache_dir=<your-cache-dir> class_prompt="a photo of a dog" activations_dtype=bfloat16 weights_dtype=float32 per_device_batch_size=1 enable_profiler=False precision=DEFAULT cache_dreambooth_dataset=False learning_rate=4e-6 num_class_images=100 run_name=<your-run-name> output_dir=gs://<your-bucket-name>
```

## Inference
@@ -143,7 +143,7 @@ To generate images, run the following command:
Single- and multi-host inference is supported with sharding annotations:

```bash
python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run" lightning_repo="ByteDance/SDXL-Lightning" lightning_ckpt="sdxl_lightning_4step_unet.safetensors"
python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl_lightning.yml run_name="my_run" lightning_repo="ByteDance/SDXL-Lightning" lightning_ckpt="sdxl_lightning_4step_unet.safetensors"
```

## ControlNet
@@ -176,7 +176,7 @@ cd maxdiffusion
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip3 install -r requirements.txt
pip3 install .
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run output_dir=gs://your-bucket/
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run output_dir=gs://your-bucket/
```

# Comparison to Alternatives
12 changes: 6 additions & 6 deletions docs/train_README.md
@@ -28,7 +28,7 @@ In this section, we'll explain some of the core config parameters and how they a

| config | model | supports |
| ------ | ----- | -------- |
| [base15.yml](https://github.com/google/maxdiffusion/blob/main/src/maxdiffusion/configs/base15.yml) | [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) | training / inference
| [base14.yml](https://github.com/google/maxdiffusion/blob/main/src/maxdiffusion/configs/base14.yml) | [stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) | training / inference
| [base_2_base.yml](https://github.com/google/maxdiffusion/blob/main/src/maxdiffusion/configs/base_2_base.yml) | [stable-diffusion-2-base](https://huggingface.co/stabilityai/stable-diffusion-2-base) | training / inference
| [base21.yml](https://github.com/google/maxdiffusion/blob/main/src/maxdiffusion/configs/base21.yml) | [stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) | training / inference
| [base_xl.yml](https://github.com/google/maxdiffusion/blob/main/src/maxdiffusion/configs/base_xl.yml) | [stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) | training / inference
@@ -40,10 +40,10 @@ Let's start with a simple example. After setting up your environment, create a t

```bash
export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base15.yml run_name="my_run" jax_cache_dir=gs://your-bucket/cache_dir activations_dtype=float32 weights_dtype=float32 per_device_batch_size=2 precision=DEFAULT dataset_save_location=/tmp/my_dataset/ output_dir=gs://your-bucket/ attention=flash
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base14.yml run_name="my_run" jax_cache_dir=gs://your-bucket/cache_dir activations_dtype=float32 weights_dtype=float32 per_device_batch_size=2 precision=DEFAULT dataset_save_location=/tmp/my_dataset/ output_dir=gs://your-bucket/ attention=flash
```

The job will use the predefined parameters in base15.yml, overridden by any parameters passed on the CLI.
The job will use the predefined parameters in base14.yml, overridden by any parameters passed on the CLI.

### Changing The Base Model

@@ -53,7 +53,7 @@ To load PyTorch weights, set `from_pt=True` and set `revision=main`. Let's look

```bash
export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base15.yml run_name="my_run" output_dir="gs://your-bucket/" pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 from_pt=True revision=main
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base14.yml run_name="my_run" output_dir="gs://your-bucket/" pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 from_pt=True revision=main
```

After training, a new folder structure with weights and metrics has been created under the `output_dir` folder:
@@ -71,15 +71,15 @@ It is recommended to use a Google Cloud Storage bucket as the `output_dir`. This
To use the trained checkpoint, run:

```bash
python src/maxdiffusion/generate.py src/maxdiffusion/configs/base15.yml output_dir="gs://your-bucket/" run_name="my_run"
python src/maxdiffusion/generate.py src/maxdiffusion/configs/base14.yml output_dir="gs://your-bucket/" run_name="my_run"
```


### Changing The Sharding Strategy

MaxDiffusion models use logical axis annotations, which let users explore different sharding layouts without changing the model code. To learn more about distributed arrays and Flax partitioning, check out JAX's [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) and then Flax's [Scale up Flax Modules on multiple devices](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html#flax-and-jax-jit-scaled-up).

The main [config values](https://github.com/google/maxdiffusion/blob/main/src/maxdiffusion/configs/base15.yml#L74) for these are:
The main [config values](https://github.com/google/maxdiffusion/blob/main/src/maxdiffusion/configs/base14.yml#L74) for these are:

- mesh_axes
- logical_axis_rules
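As a quick illustration of how these values interact, here is a minimal, self-contained sketch of Flax's logical-axis machinery — the rule set is abbreviated and illustrative, not the full config:

```python
import jax
import numpy as np
import flax.linen as nn
from jax.sharding import Mesh

# A 1x1x1 mesh so the sketch runs on a single device; real runs reshape
# all devices into the ('data', 'fsdp', 'tensor') mesh_axes.
devices = np.array(jax.devices()[:1]).reshape((1, 1, 1))
mesh = Mesh(devices, ("data", "fsdp", "tensor"))

# Abbreviated rules in the spirit of the config's logical_axis_rules.
rules = (("batch", "data"), ("embed", "fsdp"), ("heads", "tensor"))

with mesh, nn.logical_axis_rules(rules):
    # A weight annotated with logical axes ('embed', 'heads') is laid out
    # on the mesh axes ('fsdp', 'tensor').
    print(nn.logical_to_mesh_axes(("embed", "heads")))
    # -> PartitionSpec('fsdp', 'tensor')
```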
@@ -226,7 +226,7 @@ def load_diffusers_checkpoint(self):

    return pipeline, params

  def save_checkpoint(self, train_step, pipeline, params, train_states, save_inference_states=False):
  def save_checkpoint(self, train_step, pipeline, params, train_states):
    def config_to_json(model_or_config):
      return json.loads(model_or_config.to_json_string())

@@ -248,35 +248,6 @@ def config_to_json(model_or_config):
    tokenizer_config = {"path": self.config.tokenizer_model_name_or_path}
    items["tokenizer_config"] = ocp.args.JsonSave(tokenizer_config)

    if save_inference_states:
      inference_unet_state, _, _ = self.create_unet_state(
          pipeline,
          {"unet": train_states["unet_state"].params},
          checkpoint_item_name="inference_unet_state",
          is_training=False,
      )
      items["inference_unet_state"] = ocp.args.StandardSave(inference_unet_state)

      if self.config.train_text_encoder:
        inference_text_encoder_state, _ = self.create_text_encoder_state(
            pipeline,
            {"text_encoder": train_states["text_encoder_state"].params},
            checkpoint_item_name="inference_text_encoder_state",
            is_training=False)

        items["inference_text_encoder_state"] = ocp.args.StandardSave(inference_text_encoder_state)

      # TODO - this is broken since create_text_encoder_state will create a text_encoder init weights
      # and not a text_encoder_2 init weights fn.
      if hasattr(pipeline, "text_encoder_2"):
        inference_text_encoder_2_state, _ = self.create_text_encoder_2_state(
            pipeline,
            {"text_encoder": train_states["text_encoder_2_state"].params},
            checkpoint_item_name="inference_text_encoder_2_state",
            is_training=False,
        )
        items["inference_text_encoder_2_state"] = ocp.args.StandardSave(inference_text_encoder_2_state)

    self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

  def load_params(self, step=None):
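The removed `save_inference_states` branch above used to persist separate `inference_*` checkpoint items; per the commit description, generation states are now derived from `train_states` instead. A rough sketch of that idea in plain Flax — the function name and structure are illustrative, not the repo's actual API:

```python
import optax
from flax.training import train_state


def to_generation_state(state: train_state.TrainState) -> train_state.TrainState:
    # Reuse the trained parameters directly; a no-op optimizer stands in for
    # the training tx, since a generation state is never stepped.
    return train_state.TrainState.create(
        apply_fn=state.apply_fn,
        params=state.params,
        tx=optax.set_to_zero(),
    )
```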
32 changes: 29 additions & 3 deletions src/maxdiffusion/checkpointing/checkpointing_utils.py
@@ -18,6 +18,7 @@
"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""

from typing import Optional, Any
import os
from maxdiffusion import max_logging
from etils import epath
from flax.training import train_state
@@ -55,17 +56,14 @@ def create_orbax_checkpoint_manager(
"text_encoder_config",
"scheduler_config",
"unet_state",
"inference_unet_state",
"vae_state",
"text_encoder_state",
"inference_text_encoder_state",
"tokenizer_config"
)
if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT:
item_names+= (
"text_encoder_2_state",
"text_encoder_2_config",
"inference_text_encoder_2_state"
)

print("item_names: ", item_names)
@@ -120,6 +118,34 @@ def load_stable_diffusion_configs(
          args=orbax.checkpoint.args.Composite(**restore_args)
      ), None)

def load_params_from_path(
    config,
    checkpoint_manager: CheckpointManager,
    unboxed_abstract_params,
    checkpoint_item: str,
    step: Optional[int] = None,
):
  ckptr = ocp.PyTreeCheckpointer()

  if step is None:
    step = checkpoint_manager.latest_step()
    if step is None:
      return None

  ckpt_path = os.path.join(config.checkpoint_dir, str(step), checkpoint_item)
  ckpt_path = epath.Path(ckpt_path)

  restore_args = ocp.checkpoint_utils.construct_restore_args(unboxed_abstract_params)
  restored = ckptr.restore(
      ckpt_path,
      item={"params": unboxed_abstract_params},
      transforms={},
      restore_args={"params": restore_args},
  )
  return restored["params"]


def load_state_if_possible(
    checkpoint_manager: CheckpointManager,
    abstract_unboxed_pre_state: train_state.TrainState,
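Hypothetical usage of the new `load_params_from_path` helper: the caller supplies a shape-only pytree (e.g. from `jax.eval_shape`) describing the parameters, and Orbax restores concrete values into that skeleton. The pytree below is a toy stand-in for real model params; `config` and `checkpoint_manager` are assumed to exist as in the code above:

```python
import jax
import jax.numpy as jnp


# Toy parameter initializer; a real caller would use the model's init fn.
def init_params():
    return {"dense": {"kernel": jnp.zeros((4, 4)), "bias": jnp.zeros((4,))}}


# Shape-only skeleton: jax.eval_shape allocates no device memory.
abstract_params = jax.eval_shape(init_params)

# Restore concrete values into that skeleton.
params = load_params_from_path(
    config,
    checkpoint_manager,
    abstract_params,
    checkpoint_item="unet_state",
)
```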
@@ -23,7 +23,7 @@ gcs_metrics: True
save_config_to_gcs: False
log_period: 10000000000 # Flushes Tensorboard

pretrained_model_name_or_path: 'runwayml/stable-diffusion-v1-5'
pretrained_model_name_or_path: 'CompVis/stable-diffusion-v1-4'
unet_checkpoint: ''
revision: 'flax'

@@ -105,6 +105,7 @@ logical_axis_rules: [
['activation_kv', 'tensor'],
['embed','fsdp'],
['heads', 'tensor'],
['conv_batch', ['data','fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
]
@@ -143,7 +144,6 @@ enable_data_shuffling: True

# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
upload_ckpts_to_gcs: False

# Prepare image latents and text encoder outputs
# during dataset creation to reduce memory consumption.
12 changes: 7 additions & 5 deletions src/maxdiffusion/configs/base21.yml
@@ -102,11 +102,14 @@ mesh_axes: ['data', 'fsdp', 'tensor']
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
['batch', 'data'],
['activation_batch', 'data'],
['activation_length', 'fsdp'],
['out_channels', 'fsdp'],
['activation_batch', ['data','fsdp']],
['activation_heads', 'tensor'],
['activation_kv', 'tensor'],
['embed','fsdp'],
['heads', 'tensor'],
['conv_batch', ['data','fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
['length', 'fsdp']
]
data_sharding: [['data', 'fsdp', 'tensor']]

Expand Down Expand Up @@ -143,7 +146,6 @@ enable_data_shuffling: True

# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
upload_ckpts_to_gcs: False

# Prepare image latents and text encoder outputs
# during dataset creation to reduce memory consumption.
12 changes: 7 additions & 5 deletions src/maxdiffusion/configs/base_2_base.yml
@@ -115,11 +115,14 @@ mesh_axes: ['data', 'fsdp', 'tensor']
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
['batch', 'data'],
['activation_batch', 'data'],
['activation_length', 'fsdp'],
['out_channels', 'fsdp'],
['activation_batch', ['data','fsdp']],
['activation_heads', 'tensor'],
['activation_kv', 'tensor'],
['embed','fsdp'],
['heads', 'tensor'],
['conv_batch', ['data','fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
['length', 'fsdp']
]
data_sharding: [['data', 'fsdp', 'tensor']]

Expand Down Expand Up @@ -156,7 +159,6 @@ enable_data_shuffling: True

# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
upload_ckpts_to_gcs: False

# Prepare image latents and text encoder outputs
# during dataset creation to reduce memory consumption.
5 changes: 3 additions & 2 deletions src/maxdiffusion/configs/base_xl.yml
@@ -108,7 +108,8 @@ logical_axis_rules: [
['activation_kv', 'tensor'],
['embed','fsdp'],
['heads', 'tensor'],
['out_channels', 'fsdp'],
['conv_batch', ['data','fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
@@ -146,7 +147,6 @@ enable_data_shuffling: True

# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
upload_ckpts_to_gcs: False

# Prepare image latents and text encoder outputs
# during dataset creation to reduce memory consumption.
@@ -185,6 +185,7 @@ profiler_steps: 10
# Generation parameters
prompt: "A magical castle in the middle of a forest, artistic drawing"
negative_prompt: "purple, red"
do_classifier_free_guidance: True
guidance_scale: 9
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
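The new `do_classifier_free_guidance` flag gates the usual two-pass UNet evaluation, and `guidance_rescale` applies the std-rescaling trick from Sec. 3.4 of arXiv:2305.08891. A hedged sketch of the combination step — the function and argument names are illustrative, not the repo's actual code:

```python
import jax.numpy as jnp


def apply_guidance(noise_uncond, noise_cond, guidance_scale=9.0, guidance_rescale=0.0):
    # Classifier-free guidance: push the prediction away from the
    # unconditional branch and toward the conditional one.
    noise = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
    if guidance_rescale > 0:
        # Rescale so the guided prediction's std matches the conditional
        # prediction's std, countering over-saturation at high scales.
        axes = tuple(range(1, noise.ndim))
        std_cond = jnp.std(noise_cond, axis=axes, keepdims=True)
        std_guided = jnp.std(noise, axis=axes, keepdims=True)
        rescaled = noise * (std_cond / std_guided)
        noise = guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise
    return noise
```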
12 changes: 8 additions & 4 deletions src/maxdiffusion/configs/base_xl_lightning.yml
@@ -82,11 +82,14 @@ mesh_axes: ['data', 'fsdp', 'tensor']
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
['batch', 'data'],
['activation_batch', 'data'],
['activation_length', 'fsdp'],
['out_channels', 'fsdp'],
['activation_batch', ['data','fsdp']],
['activation_heads', 'tensor'],
['activation_kv', 'tensor'],
['embed','fsdp'],
['heads', 'tensor'],
['conv_batch', ['data','fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
['length', 'fsdp']
]
data_sharding: [['data', 'fsdp', 'tensor']]

Expand Down Expand Up @@ -146,6 +149,7 @@ profiler_steps: 5
# Generation parameters
prompt: "portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal, elegant, sharp focus, soft lighting, vibrant colors"
negative_prompt: "purple, red"
do_classifier_free_guidance: False
guidance_scale: 2
guidance_rescale: -1
num_inference_steps: 4
