
Hyper SDXL Lora support (#127)
* Adds Hyper SDXL LoRA loading for inference using Flax interceptor.

---------

Co-authored-by: Juan Acevedo <[email protected]>
entrpn and jfacevedo-google authored Oct 31, 2024
1 parent 9238bc9 commit 1deeca5
Showing 26 changed files with 1,441 additions and 58 deletions.
14 changes: 12 additions & 2 deletions README.md
@@ -17,8 +17,8 @@
[![Unit Tests](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)

# What's new?

- **`2024/8/1`**: Orbax is the new default checkpointer for Stable Diffusion 1.X, 2.x. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
- **`2024/10/22`**: LoRA support for Hyper SDXL.
- **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
- **`2024/7/20`**: Dreambooth training for Stable Diffusion 1.x, 2.x is now supported.

# Overview
@@ -32,6 +32,7 @@ MaxDiffusion supports
* Stable Diffusion 2.1 (training and inference)
* Stable Diffusion XL (training and inference).
* Stable Diffusion Lightning (inference).
* Hyper-SD XL LoRA loading (inference).
* ControlNet inference (Stable Diffusion 1.4 & SDXL).
* Dreambooth training support for Stable Diffusion 1.x, 2.x.

@@ -43,6 +44,7 @@ MaxDiffusion supports
* [Training](#training)
* [Dreambooth](#dreambooth)
* [Inference](#inference)
* [Hyper-SD XL LoRA](#hyper-sdxl-lora)
* [SDXL Lightning](#sdxl-lightning)
* [ControlNet](#controlnet)
* [Comparison To Alternatives](#comparison-to-alternatives)
@@ -129,6 +131,14 @@ To generate images, run the following command:
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run"
```

## Hyper SDXL LoRA

Supports Hyper-SDXL models from [ByteDance](https://huggingface.co/ByteDance/Hyper-SD).

```bash
python src/maxdiffusion/generate_sdxl.py src/maxdiffusion/configs/base_xl.yml run_name="test-lora" output_dir=/tmp/ jax_cache_dir=/tmp/cache_dir/ num_inference_steps=2 do_classifier_free_guidance=False prompt="a photograph of a cat wearing a hat riding a skateboard in a park." per_device_batch_size=1 pretrained_model_name_or_path="Lykon/AAM_XL_AnimeMix" from_pt=True revision=main diffusion_scheduler_config='{"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}' lora_config='{"lora_model_name_or_path" : ["ByteDance/Hyper-SD"], "weight_name" : ["Hyper-SDXL-2steps-lora.safetensors"], "adapter_name" : ["hyper-sdxl"], "scale": [0.7], "from_pt": ["true"]}'
```
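
The `lora_config` flag takes a JSON string. As a small convenience sketch (not part of the repo), the same argument can be built with `json.dumps` instead of hand-quoting it; the values below are copied from the command above:

```python
# Build the lora_config CLI argument programmatically; json.dumps emits the
# double-quoted JSON string that the command above passes by hand.
import json

lora_config = {
    "lora_model_name_or_path": ["ByteDance/Hyper-SD"],
    "weight_name": ["Hyper-SDXL-2steps-lora.safetensors"],
    "adapter_name": ["hyper-sdxl"],
    "scale": [0.7],
    "from_pt": ["true"],
}
print(f"lora_config='{json.dumps(lora_config)}'")
```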

## SDXL Lightning

Single and Multi host inference is supported with sharding annotations:
Binary file removed generated_image.png
2 changes: 2 additions & 0 deletions requirements.txt
@@ -26,4 +26,6 @@ git+https://github.com/mlperf/logging.git
opencv-python-headless==4.10.0.84
orbax-checkpoint>=0.5.20
tokenizers==0.20.0
huggingface_hub==0.24.7

huggingface_hub==0.24.7
27 changes: 22 additions & 5 deletions src/maxdiffusion/configs/base_xl.yml
@@ -75,11 +75,10 @@ timestep_bias: {

# Override parameters from the checkpoint's scheduler.
diffusion_scheduler_config: {
_class_name: '',
# values are v_prediction or leave empty to use scheduler's default.
prediction_type: '',
_class_name: 'FlaxEulerDiscreteScheduler',
prediction_type: 'epsilon',
rescale_zero_terminal_snr: False,
timestep_spacing: ''
timestep_spacing: 'trailing'
}

# Output directory
@@ -197,7 +196,7 @@ profiler_steps: 10
prompt: "A magical castle in the middle of a forest, artistic drawing"
negative_prompt: "purple, red"
do_classifier_free_guidance: True
guidance_scale: 9
guidance_scale: 9.0
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 20
@@ -209,6 +208,24 @@ lightning_repo: ""
# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning.
lightning_ckpt: ""

# LoRA parameters
# Values are lists to support loading multiple LoRAs during inference in the future.
lora_config: {
lora_model_name_or_path: [],
weight_name: [],
adapter_name: [],
scale: [],
from_pt: []
}
# Example with values:
# lora_config : {
# lora_model_name_or_path: ["ByteDance/Hyper-SD"],
# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"],
# adapter_name: ["hyper-sdxl"],
# scale: [0.7],
# from_pt: [True]
# }

enable_mllog: False

#controlnet
22 changes: 20 additions & 2 deletions src/maxdiffusion/configs/base_xl_lightning.yml
@@ -56,7 +56,7 @@ text_encoder_learning_rate: 4.25e-6
diffusion_scheduler_config: {
_class_name: 'DDIMScheduler',
# values are v_prediction or leave empty to use scheduler's default.
prediction_type: '',
prediction_type: 'epsilon',
rescale_zero_terminal_snr: False,
timestep_spacing: 'trailing'
}
@@ -156,7 +156,7 @@ profiler_steps: 5
prompt: "portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal, elegant, sharp focus, soft lighting, vibrant colors"
negative_prompt: "purple, red"
do_classifier_free_guidance: False
guidance_scale: 2
guidance_scale: 2.0
guidance_rescale: -1
num_inference_steps: 4

@@ -165,4 +165,22 @@ lightning_from_pt: True
lightning_repo: "ByteDance/SDXL-Lightning"
lightning_ckpt: "sdxl_lightning_4step_unet.safetensors"

# LoRA parameters
# Values are lists to support loading multiple LoRAs during inference in the future.
lora_config: {
lora_model_name_or_path: [],
weight_name: [],
adapter_name: [],
scale: [],
from_pt: []
}
# Example with values:
# lora_config : {
# lora_model_name_or_path: ["ByteDance/Hyper-SD"],
# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"],
# adapter_name: ["hyper-sdxl"],
# scale: [0.7],
# from_pt: [True]
# }

enable_mllog: False
55 changes: 28 additions & 27 deletions src/maxdiffusion/generate_sdxl.py
@@ -23,16 +23,18 @@
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning

from maxdiffusion import (
FlaxEulerDiscreteScheduler,
)


from maxdiffusion import pyconfig, max_utils
from maxdiffusion.image_processor import VaeImageProcessor
from maxdiffusion.maxdiffusion_utils import (get_add_time_ids, rescale_noise_cfg, load_sdxllightning_unet)
from maxdiffusion.maxdiffusion_utils import (
get_add_time_ids,
rescale_noise_cfg,
load_sdxllightning_unet,
maybe_load_lora,
create_scheduler,
)

from maxdiffusion.trainers.sdxl_trainer import (StableDiffusionXLTrainer)

@@ -82,7 +84,6 @@ def apply_classifier_free_guidance(noise_pred, guidance_scale):
lambda _: noise_pred,
operand=None,
)

latents, scheduler_state = pipeline.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()

return latents, scheduler_state, state
@@ -217,6 +218,8 @@ def run(config):
checkpoint_loader = GenerateSDXL(config)
pipeline, params = checkpoint_loader.load_checkpoint()

noise_scheduler, noise_scheduler_state = create_scheduler(pipeline.scheduler.config, config)

weights_init_fn = functools.partial(pipeline.unet.init_weights, rng=checkpoint_loader.rng)
unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
pipeline.unet, None, config, checkpoint_loader.mesh, weights_init_fn, False
@@ -228,20 +231,24 @@
if unet_params:
params["unet"] = unet_params

# Maybe load LoRA weights and create the Flax interceptor.
params, lora_interceptor = maybe_load_lora(config, pipeline, params)

if config.lightning_repo:
pipeline, params = load_sdxllightning_unet(config, pipeline, params)

# Don't restore the train state to save memory, just restore params
# Don't restore the full train state; instead, just restore params
# and create an inference state.
unet_state, unet_state_shardings = max_utils.setup_initial_state(
model=pipeline.unet,
tx=None,
config=config,
mesh=checkpoint_loader.mesh,
weights_init_fn=weights_init_fn,
model_params=params.get("unet", None),
training=False,
)
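# Wrap state creation in the LoRA interceptor so initialization traces the
# same LoRA-augmented forward pass that will run at inference.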
with nn.intercept_methods(lora_interceptor):
unet_state, unet_state_shardings = max_utils.setup_initial_state(
model=pipeline.unet,
tx=None,
config=config,
mesh=checkpoint_loader.mesh,
weights_init_fn=weights_init_fn,
model_params=params.get("unet", None),
training=False,
)

vae_state, vae_state_shardings = checkpoint_loader.create_vae_state(
pipeline, params, checkpoint_item_name="vae_state", is_training=False
@@ -267,14 +274,6 @@
states["text_encoder_state"] = text_encoder_state
states["text_encoder_2_state"] = text_encoder_2_state

noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained(
config.pretrained_model_name_or_path,
revision=config.revision,
subfolder="scheduler",
dtype=jnp.float32,
timestep_spacing="trailing",
)

pipeline.scheduler = noise_scheduler
params["scheduler"] = noise_scheduler_state

@@ -293,10 +292,12 @@
)

s = time.time()
p_run_inference(states).block_until_ready()
with nn.intercept_methods(lora_interceptor):
p_run_inference(states).block_until_ready()
print("compile time: ", (time.time() - s))
s = time.time()
images = p_run_inference(states).block_until_ready()
with nn.intercept_methods(lora_interceptor):
images = p_run_inference(states).block_until_ready()
print("inference time: ", (time.time() - s))
images = jax.experimental.multihost_utils.process_allgather(images)
numpy_images = np.array(images)
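
The key mechanism in this file is Flax's `nn.intercept_methods`, which wraps module method calls without modifying the model definition. Below is a minimal sketch of such an interceptor, not the repo's actual `maybe_load_lora` implementation; the `lora_params` layout, the path keying, and the restriction to `nn.Dense` are illustrative assumptions:

```python
# Sketch of a LoRA interceptor for flax.linen.intercept_methods.
# Assumes lora_params maps a module path string to low-rank factors
# (a: [features_in, r], b: [r, features_out]).
import flax.linen as nn


def make_lora_interceptor(lora_params, scale=0.7):
  def interceptor(next_fun, args, kwargs, context):
    out = next_fun(*args, **kwargs)  # original module output
    if isinstance(context.module, nn.Dense) and context.method_name == "__call__":
      path = "/".join(context.module.path)  # locate this layer in the module tree
      if path in lora_params:
        a, b = lora_params[path]
        out = out + (args[0] @ a @ b) * scale  # add the scaled low-rank delta
    return out
  return interceptor


# Any .apply() inside the context picks up the deltas:
# with nn.intercept_methods(make_lora_interceptor(lora_params)):
#   noise_pred = pipeline.unet.apply({"params": params["unet"]}, ...)
```
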
15 changes: 15 additions & 0 deletions src/maxdiffusion/loaders/__init__.py
@@ -0,0 +1,15 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .lora_pipeline import StableDiffusionLoraLoaderMixin
106 changes: 106 additions & 0 deletions src/maxdiffusion/loaders/lora_base.py
@@ -0,0 +1,106 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..models.modeling_utils import load_state_dict
from ..utils import _get_model_file

import safetensors
import safetensors.torch


class LoRABaseMixin:
"""Utility class for handing LoRAs"""

_lora_loadable_modules = []
num_fused_loras = 0

def load_lora_weights(self, **kwargs):
raise NotImplementedError("`load_lora_weights()` is not implemented.")

@classmethod
def _fetch_state_dict(
cls,
pretrained_model_name_or_path_or_dict,
weight_name,
use_safetensors,
local_files_only,
cache_dir,
force_download,
resume_download,
proxies,
use_auth_token,
revision,
subfolder,
user_agent,
allow_pickle,
):
from .lora_pipeline import LORA_WEIGHT_NAME_SAFE

model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (weight_name is not None and weight_name.endswith(".safetensors")):
try:
# Here we're relaxing the loading check to enable more Inference API
# friendliness where sometimes, it's not at all possible to automatically
# determine `weight_name`.
if weight_name is None:
weight_name = cls._best_guess_weight_name(
pretrained_model_name_or_path_or_dict,
file_extension=".safetensors",
local_files_only=local_files_only,
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
model_file = None
pass

if model_file is None:
if weight_name is None:
weight_name = cls._best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict

return state_dict
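
For context, a hedged sketch of calling `_fetch_state_dict` with the Hyper-SD weights from the README example; apart from the repo id and weight name, every argument value below is an illustrative default rather than a documented call pattern:

```python
# Hypothetical invocation; only the repo id and weight_name come from the
# README example, the remaining arguments are plausible defaults.
state_dict = LoRABaseMixin._fetch_state_dict(
    "ByteDance/Hyper-SD",
    weight_name="Hyper-SDXL-2steps-lora.safetensors",
    use_safetensors=True,
    local_files_only=False,
    cache_dir=None,
    force_download=False,
    resume_download=False,
    proxies=None,
    use_auth_token=None,
    revision=None,
    subfolder=None,
    user_agent=None,
    allow_pickle=False,
)
# state_dict maps LoRA tensor names to tensors loaded on CPU.
```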
