Skip to content

Commit

Permalink
add multislice support (#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssusie authored Sep 25, 2024
1 parent 5de63af commit 0015075
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/UploadDockerImages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ jobs:
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack_0.4.33 MODE=stable_stack PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack_0.4.33 BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.33-rev1
- name: build maxdiffusion jax nightly image
run: |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly
4 changes: 2 additions & 2 deletions .github/workflows/build_and_upload_images.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# If you want to create and push your own images you should instead use docker_build_dependency_image and
# docker_upload_runner in the maxdiffusion root directory.

# Example command:
# Example command:
# bash build_and_upload_images.sh PROJECT=<project> MODE=stable CLOUD_IMAGE_NAME=${USER}_runner


Expand Down Expand Up @@ -55,4 +55,4 @@ docker build --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} -f ./maxdiffusion_runner.
docker tag ${LOCAL_IMAGE_NAME}_runner gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}:latest
docker push gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}:latest
docker tag ${LOCAL_IMAGE_NAME}_runner gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}:${image_date}
docker push gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}:${image_date}
docker push gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}:${image_date}
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@
- **`2024/8/1`**: Orbax is the new default checkpointer for Stable Diffusion 1.X, 2.x. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
- **`2024/7/20`**: Dreambooth training for Stable Diffusion 1.x,2.x is now supported.

# Overview
# Overview

MaxDiffusion is a collection of reference implementations of various latent diffusion models written in pure Python/Jax that run on XLA devices including Cloud TPUs and GPUs. MaxDiffusion aims to be a launching off point for ambitious Diffusion projects both in research and production. We encourage you to start by experimenting with MaxDiffusion out of the box and then fork and modify MaxDiffusion to meet your needs.

The goal of this project is to provide reference implementations for latent diffusion models that help developers get started with training, tuning, and serving solutions on XLA devices including Cloud TPUs and GPUs. We started with Stable Diffusion inference on TPUs, but welcome code contributions to grow.

MaxDiffusion supports
MaxDiffusion supports
* Stable Diffusion 2 base (training and inference)
* Stable Diffusion 2.1 (training and inference)
* Stable Diffusion 2.1 (training and inference)
* Stable Diffusion XL (training and inference).
* Stable Diffusion Lightning (inference).
* ControlNet inference (Stable Diffusion 1.4 & SDXL).
Expand All @@ -54,7 +54,7 @@ We recommend starting with a single TPU host and then moving to multihost.

Minimum requirements: Ubuntu Version 22.04, Python 3.10 and Tensorflow >= 2.12.0.

## Getting Started:
## Getting Started:

For your first time running Maxdiffusion, we provide specific [instructions](docs/getting_started/first_run.md).

Expand Down Expand Up @@ -115,7 +115,7 @@ To generate images, run the following command:
```

Single host pmap version:

```bash
python -m src.maxdiffusion.generate_sdxl_replicated
```
Expand Down Expand Up @@ -152,7 +152,7 @@ To generate images, run the following command:
```bash
python src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py
```


## Getting Started: Multihost development
Multihost training for Stable Diffusion 2 base can be run using the following command:
Expand All @@ -172,7 +172,7 @@ python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_na

# Comparison to Alternatives

MaxDiffusion started as a fork of [Diffusers](https://github.com/huggingface/diffusers), a Hugging Face diffusion library written in Python, Pytorch and Jax. MaxDiffusion is compatible with Hugging Face Jax models. MaxDiffusion is more complex and was designed to run distributed across TPU Pods.
MaxDiffusion started as a fork of [Diffusers](https://github.com/huggingface/diffusers), a Hugging Face diffusion library written in Python, Pytorch and Jax. MaxDiffusion is compatible with Hugging Face Jax models. MaxDiffusion is more complex and was designed to run distributed across TPU Pods.

# Development

Expand Down
2 changes: 1 addition & 1 deletion docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ This folder contains documentation for getting started with and using MaxDiffusi

## Training

* **[Common Training Guide](train_README.md)** - Provides a comprehensive guide to training MaxDiffusion models, including script usage, configuration options, and sharding strategies.
* **[Common Training Guide](train_README.md)** - Provides a comprehensive guide to training MaxDiffusion models, including script usage, configuration options, and sharding strategies.
17 changes: 12 additions & 5 deletions src/maxdiffusion/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,18 @@ def create_device_mesh(config, devices=None, logging=True):
num_devices_per_slice = num_devices//num_slices
max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")

multi_slice_env = num_slices > 1

dcn_parallelism = [config.dcn_data_parallelism, config.dcn_fsdp_parallelism, config.dcn_tensor_parallelism]
ici_parallelism = [config.ici_data_parallelism, config.ici_fsdp_parallelism, config.ici_tensor_parallelism]

# Find possible unspecified parallelisms
ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, 'ICI')
mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)
if multi_slice_env:
dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, 'DCN')
mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
else:
mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)

if logging:
max_logging.log(f"Decided on mesh: {mesh}")
Expand Down Expand Up @@ -313,9 +320,9 @@ def get_abstract_state(model, tx, config, mesh, weights_init_fn, training=True):
)
with nn_partitioning.axis_rules(config.logical_axis_rules):
abstract_state = jax.eval_shape(init_state_partial)

state_logical_annotations = nn.get_partition_spec(abstract_state)

state_mesh_shardings = nn.logical_to_mesh_sharding(
state_logical_annotations, mesh, config.logical_axis_rules
)
Expand Down Expand Up @@ -476,7 +483,7 @@ def get_memory_allocations():


# Taking inspiration from flax's https://flax.readthedocs.io/en/v0.5.3/_modules/flax/linen/summary.html#tabulate
# to retrieve layer parameters and calculate
# to retrieve layer parameters and calculate
def calculate_model_tflops(
module: module_lib.Module,
rngs: Union[PRNGKey, RNGSequences],
Expand Down Expand Up @@ -541,4 +548,4 @@ def get_global_batch_size(config):
return config.per_device_batch_size * jax.device_count()

def maybe_initialize_jax_distributed_system(raw_keys):
jax.distributed.initialize()
jax.distributed.initialize()

0 comments on commit 0015075

Please sign in to comment.