Support multi lora loading (#137)
* adds multi lora support
* update readme
entrpn authored Dec 13, 2024
1 parent b1c63d7 commit efa11e4
Showing 5 changed files with 68 additions and 40 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -17,6 +17,7 @@
[![Unit Tests](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)

# What's new?
- **`2024/12/12`**: Load multiple LoRAs for inference.
- **`2024/10/22`**: LoRA support for Hyper SDXL.
- **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
- **`2024/7/20`**: Dreambooth training for Stable Diffusion 1.x, 2.x is now supported.
@@ -33,6 +34,7 @@ MaxDiffusion supports
* Stable Diffusion XL (training and inference).
* Stable Diffusion Lightning (inference).
* Hyper-SD XL LoRA loading (inference).
* Multiple LoRA loading (SDXL inference).
* ControlNet inference (Stable Diffusion 1.4 & SDXL).
* Dreambooth training support for Stable Diffusion 1.x, 2.x.

@@ -45,6 +47,7 @@ MaxDiffusion supports
* [Dreambooth](#dreambooth)
* [Inference](#inference)
* [Hyper-SD XL LoRA](#hyper-sdxl-lora)
* [Load Multiple LoRA](#load-multiple-lora)
* [SDXL Lightning](#sdxl-lightning)
* [ControlNet](#controlnet)
* [Comparison To Alternatives](#comparison-to-alternatives)
@@ -139,6 +142,14 @@ To generate images, run the following command:
python src/maxdiffusion/generate_sdxl.py src/maxdiffusion/configs/base_xl.yml run_name="test-lora" output_dir=/tmp/ jax_cache_dir=/tmp/cache_dir/ num_inference_steps=2 do_classifier_free_guidance=False prompt="a photograph of a cat wearing a hat riding a skateboard in a park." per_device_batch_size=1 pretrained_model_name_or_path="Lykon/AAM_XL_AnimeMix" from_pt=True revision=main diffusion_scheduler_config='{"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}' lora_config='{"lora_model_name_or_path" : ["ByteDance/Hyper-SD"], "weight_name" : ["Hyper-SDXL-2steps-lora.safetensors"], "adapter_name" : ["hyper-sdxl"], "scale": [0.7], "from_pt": ["true"]}'
```

## Load Multiple LoRA

Multiple LoRAs can be loaded for inference, either from local files or from the Hugging Face Hub.

```bash
python src/maxdiffusion/generate_sdxl.py src/maxdiffusion/configs/base_xl.yml run_name="test-lora" output_dir=/tmp/tmp/ jax_cache_dir=/tmp/cache_dir/ num_inference_steps=30 do_classifier_free_guidance=True prompt="ultra detailed diagram blueprint of a papercut Sitting MaineCoon cat, wide canvas, ampereart, electrical diagram, bl3uprint, papercut" per_device_batch_size=1 diffusion_scheduler_config='{"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}' lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/blueprintify-sd-xl-10.safetensors","TheLastBen/Papercut_SDXL"], "weight_name" : ["/home/jfacevedo/blueprintify-sd-xl-10.safetensors","papercut.safetensors"], "adapter_name" : ["blueprint","papercut"], "scale": [0.8, 0.7], "from_pt": ["true", "true"]}'
```

## SDXL Lightning

Single and Multi host inference is supported with sharding annotations:
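A minimal sketch (not part of this commit) of how the `lora_config` flag from the multi-LoRA example above parses. The lists are parallel: entry `i` of every list describes the i-th adapter. Only standard `json` parsing is used; the paths and adapter names are copied from the README example.

```python
import json

lora_config = json.loads(
    '{"lora_model_name_or_path": ["/home/jfacevedo/blueprintify-sd-xl-10.safetensors", "TheLastBen/Papercut_SDXL"], '
    '"weight_name": ["/home/jfacevedo/blueprintify-sd-xl-10.safetensors", "papercut.safetensors"], '
    '"adapter_name": ["blueprint", "papercut"], '
    '"scale": [0.8, 0.7], '
    '"from_pt": ["true", "true"]}'
)

# Entry i of every list describes the i-th adapter.
for i, adapter in enumerate(lora_config["adapter_name"]):
  print(adapter, lora_config["lora_model_name_or_path"][i], lora_config["scale"][i])
```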
15 changes: 10 additions & 5 deletions src/maxdiffusion/generate_sdxl.py
@@ -16,6 +16,7 @@

import functools
from absl import app
from contextlib import ExitStack
from typing import Sequence
import time

@@ -233,14 +234,15 @@ def run(config):
params["unet"] = unet_params

# maybe load lora and create interceptor
params, lora_interceptor = maybe_load_lora(config, pipeline, params)
params, lora_interceptors = maybe_load_lora(config, pipeline, params)

if config.lightning_repo:
pipeline, params = load_sdxllightning_unet(config, pipeline, params)

# Don't restore the full train state, instead, just restore params
# and create an inference state.
with nn.intercept_methods(lora_interceptor):
with ExitStack() as stack:
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
unet_state, unet_state_shardings = max_utils.setup_initial_state(
model=pipeline.unet,
tx=None,
@@ -254,7 +256,8 @@ def run(config):
vae_state, vae_state_shardings = checkpoint_loader.create_vae_state(
pipeline, params, checkpoint_item_name="vae_state", is_training=False
)
with nn.intercept_methods(lora_interceptor):
with ExitStack() as stack:
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
text_encoder_state, text_encoder_state_shardings = checkpoint_loader.create_text_encoder_state(
pipeline, params, checkpoint_item_name="text_encoder_state", is_training=False
)
@@ -293,11 +296,13 @@ def run(config):
)

s = time.time()
with nn.intercept_methods(lora_interceptor):
with ExitStack() as stack:
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
p_run_inference(states).block_until_ready()
print("compile time: ", (time.time() - s))
s = time.time()
with nn.intercept_methods(lora_interceptor):
with ExitStack() as stack:
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
images = p_run_inference(states).block_until_ready()
print("inference time: ", (time.time() - s))
images = jax.experimental.multihost_utils.process_allgather(images)
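Why `ExitStack` appears in `generate_sdxl.py` above: `maybe_load_lora` now returns a list of interceptors (one per adapter), and entering them all with `ExitStack` is equivalent to nesting one `with nn.intercept_methods(...)` block per adapter, for any number of adapters. A hedged, self-contained sketch using a toy model and no-op interceptors (the tags and `make_interceptor` helper are illustrative, not part of the commit):

```python
from contextlib import ExitStack

import flax.linen as nn
import jax
import jax.numpy as jnp


def make_interceptor(tag):
  # Stand-in for the per-adapter LoRA interceptors built in maybe_load_lora.
  def interceptor(next_fn, args, kwargs, context):
    print(f"[{tag}] intercepting {context.method_name}")
    return next_fn(*args, **kwargs)
  return interceptor


interceptors = [make_interceptor("lora-a"), make_interceptor("lora-b")]
model = nn.Dense(features=4)
x = jnp.ones((1, 8))

# Equivalent to nesting `with nn.intercept_methods(...)` once per adapter,
# but works for an arbitrary number of adapters.
with ExitStack() as stack:
  for interceptor in interceptors:
    stack.enter_context(nn.intercept_methods(interceptor))
  params = model.init(jax.random.PRNGKey(0), x)
```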
28 changes: 15 additions & 13 deletions src/maxdiffusion/loaders/lora_pipeline.py
@@ -88,7 +88,7 @@ def load_lora_weights(
return params, rank, network_alphas

@classmethod
def _get_lora_layer(cls, module_path, module, rank, network_alphas):
def _get_lora_layer(cls, module_path, module, rank, network_alphas, adapter_name):
is_conv = any("conv" in str_ for str_ in module_path)
network_alpha = network_alphas.get(module_path, None)
if is_conv:
@@ -105,7 +105,7 @@ def _get_lora_layer(cls, module_path, module, rank, network_alphas):
dtype=module.dtype,
weights_dtype=module.param_dtype,
precision=module.precision,
name="lora",
name=f"lora-{adapter_name}",
)
else:
lora_module = LoRALinearLayer(
@@ -115,39 +115,41 @@ def _get_lora_layer(cls, module_path, module, rank, network_alphas):
dtype=module.dtype,
weights_dtype=module.param_dtype,
precision=module.precision,
name="lora",
name=f"lora-{adapter_name}",
)
return lora_module

def rename_for_interceptor(params_keys, network_alphas):
def rename_for_interceptor(params_keys, network_alphas, adapter_name):
new_params_keys = []
new_network_alphas = {}
lora_name = f"lora-{adapter_name}"
for layer_lora in params_keys:
if "lora" in layer_lora:
new_layer_lora = layer_lora[: layer_lora.index("lora")]
if lora_name in layer_lora:
new_layer_lora = layer_lora[: layer_lora.index(lora_name)]
if new_layer_lora not in new_params_keys:
new_params_keys.append(new_layer_lora)
network_alpha = network_alphas[layer_lora]
new_network_alphas[new_layer_lora] = network_alpha
return new_params_keys, new_network_alphas

@classmethod
def make_lora_interceptor(cls, params, rank, network_alphas):
def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name):
# Only unet interceptor supported for now.
network_alphas_for_interceptor = {}

unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys()
lora_keys, unet_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas)
lora_keys, unet_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas, adapter_name)
network_alphas_for_interceptor.update(unet_alphas)

text_encoder_keys = flax.traverse_util.flatten_dict(params["text_encoder"]).keys()
text_encoder_keys, text_encoder_alphas = cls.rename_for_interceptor(text_encoder_keys, network_alphas)
text_encoder_keys, text_encoder_alphas = cls.rename_for_interceptor(text_encoder_keys, network_alphas, adapter_name)
lora_keys.extend(text_encoder_keys)
network_alphas_for_interceptor.update(text_encoder_alphas)

if "text_encoder_2" in params.keys():
text_encoder_2_keys = flax.traverse_util.flatten_dict(params["text_encoder_2"]).keys()
text_encoder_2_keys, text_encoder_2_alphas = cls.rename_for_interceptor(text_encoder_2_keys, network_alphas)
text_encoder_2_keys, text_encoder_2_alphas = cls.rename_for_interceptor(
text_encoder_2_keys, network_alphas, adapter_name
)
lora_keys.extend(text_encoder_2_keys)
network_alphas_for_interceptor.update(text_encoder_2_alphas)

@@ -161,7 +163,7 @@ def _intercept(next_fn, args, kwargs, context):
if context.method_name == "__call__":
module_path = context.module.path
if module_path in lora_keys:
lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor)
lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor, adapter_name)
return lora_layer(h, *args, **kwargs)
return h

@@ -290,5 +292,5 @@ def load_lora(cls, state_dict, network_alphas, params, adapter_name=None, _pipel
`default_{i}` where i is the total number of adapters being loaded.
"""
# Load the layers corresponding to Unet.
params, rank, network_alphas = convert_lora_pytorch_state_dict_to_flax(state_dict, params, network_alphas)
params, rank, network_alphas = convert_lora_pytorch_state_dict_to_flax(state_dict, params, network_alphas, adapter_name)
return params, rank, network_alphas
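Why the LoRA layer is registered as `f"lora-{adapter_name}"` instead of a fixed `"lora"`: each adapter then owns its own subtree in the parameter dict, so several adapters can sit side by side under the same attention projection, and `rename_for_interceptor` can match keys per adapter. A hedged illustration with a hand-written parameter tree (the `up`/`down` kernel layout inside each LoRA layer is assumed, matching a standard LoRA linear layer, and is not taken from the commit):

```python
from flax import traverse_util

# Illustrative tree only; shapes and the exact layer layout are assumptions.
unet_params = {
    "down_blocks_0": {
        "attentions_0": {
            "to_q": {"kernel": "..."},
            "lora-hyper-sdxl": {"down": {"kernel": "..."}, "up": {"kernel": "..."}},
            "lora-papercut": {"down": {"kernel": "..."}, "up": {"kernel": "..."}},
        }
    }
}

# rename_for_interceptor-style matching keys on the adapter-specific name.
for key in traverse_util.flatten_dict(unet_params):
  if any("lora-papercut" in part for part in key):
    print(key)
# -> ('down_blocks_0', 'attentions_0', 'lora-papercut', 'down', 'kernel'), then the 'up' kernel
```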
23 changes: 13 additions & 10 deletions src/maxdiffusion/maxdiffusion_utils.py
@@ -45,21 +45,24 @@ def _noop_interceptor(next_fn, args, kwargs, context):
return next_fn(*args, **kwargs)

lora_config = config.lora_config
interceptor = _noop_interceptor
interceptors = [_noop_interceptor]
if len(lora_config["lora_model_name_or_path"]) > 0:
# For now only first lora supported. In the future, they will be merged
# before being loaded.
# TODO - merge LoRAs here.
params, rank, network_alphas = pipeline.load_lora_weights(
lora_config["lora_model_name_or_path"][0],
weight_name=lora_config["weight_name"][0],
params=params,
adapter_name=lora_config["adapter_name"][0],
unet_config=pipeline.unet.config,
)
interceptor = pipeline.make_lora_interceptor(params, rank, network_alphas)
interceptors = []
for i in range(len(lora_config["lora_model_name_or_path"])):
params, rank, network_alphas = pipeline.load_lora_weights(
lora_config["lora_model_name_or_path"][i],
weight_name=lora_config["weight_name"][i],
params=params,
adapter_name=lora_config["adapter_name"][i],
unet_config=pipeline.unet.config,
)
interceptor = pipeline.make_lora_interceptor(params, rank, network_alphas, lora_config["adapter_name"][i])
interceptors.append(interceptor)

return params, interceptor
return params, interceptors


def vae_apply(images, sample_rng, vae, vae_params):
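A condensed sketch of the new loading loop in `maybe_load_lora` above. The `load_lora_weights` and `make_lora_interceptor` calls use the signatures shown in this commit; the `load_all_loras` name and the `zip`-based iteration are illustrative. The key point is that `params` is threaded through every iteration, so each adapter's weights are merged into the same tree while one interceptor per adapter is collected.

```python
def load_all_loras(pipeline, params, lora_config):
  interceptors = []
  for path, weight, adapter in zip(
      lora_config["lora_model_name_or_path"],
      lora_config["weight_name"],
      lora_config["adapter_name"],
  ):
    # Each call merges one adapter's weights into the shared params tree.
    params, rank, network_alphas = pipeline.load_lora_weights(
        path,
        weight_name=weight,
        params=params,
        adapter_name=adapter,
        unet_config=pipeline.unet.config,
    )
    # One interceptor per adapter; generate_sdxl.py enters them all via ExitStack.
    interceptors.append(pipeline.make_lora_interceptor(params, rank, network_alphas, adapter))
  return params, interceptors
```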
31 changes: 19 additions & 12 deletions src/maxdiffusion/models/modeling_flax_pytorch_utils.py
@@ -130,7 +130,13 @@ def get_network_alpha_value(pt_key, network_alphas):


def create_flax_params_from_pytorch_state(
pt_state_dict, unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, network_alphas, is_lora=False
pt_state_dict,
unet_state_dict,
text_encoder_state_dict,
text_encoder_2_state_dict,
network_alphas,
adapter_name,
is_lora=False,
):
rank = None
renamed_network_alphas = {}
@@ -157,19 +163,21 @@
flax_key_list = [*pt_tuple_key]
if "text_encoder" in pt_tuple_key or "text_encoder_2" in pt_tuple_key:
rename_from_to = (
("to_k_lora", ("k_proj", "lora")),
("to_q_lora", ("q_proj", "lora")),
("to_v_lora", ("v_proj", "lora")),
("to_out_lora", ("out_proj", "lora")),
("to_k_lora", ("k_proj", f"lora-{adapter_name}")),
("to_q_lora", ("q_proj", f"lora-{adapter_name}")),
("to_v_lora", ("v_proj", f"lora-{adapter_name}")),
("to_out_lora", ("out_proj", f"lora-{adapter_name}")),
("lora", f"lora-{adapter_name}"),
("weight", "kernel"),
)
# the unet
else:
rename_from_to = (
("to_k_lora", ("to_k", "lora")),
("to_q_lora", ("to_q", "lora")),
("to_v_lora", ("to_v", "lora")),
("to_out_lora", ("to_out_0", "lora")),
("to_k_lora", ("to_k", f"lora-{adapter_name}")),
("to_q_lora", ("to_q", f"lora-{adapter_name}")),
("to_v_lora", ("to_v", f"lora-{adapter_name}")),
("to_out_lora", ("to_out_0", f"lora-{adapter_name}")),
("lora", f"lora-{adapter_name}"),
("weight", "kernel"),
)
for rename_from, rename_to in rename_from_to:
@@ -206,11 +214,10 @@

if network_alpha_value >= 0:
renamed_network_alphas[tuple(flax_key_list)] = network_alpha_value

return unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank, renamed_network_alphas


def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alphas):
def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alphas, adapter_name):
# Step 1: Convert pytorch tensor to numpy
# sometimes we load weights in bf16 and numpy doesn't support it
pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()}
@@ -223,7 +230,7 @@ def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alpha
text_encoder_2_params = None
(unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank, network_alphas) = (
create_flax_params_from_pytorch_state(
pt_state_dict, unet_params, text_encoder_params, text_encoder_2_params, network_alphas, is_lora=True
pt_state_dict, unet_params, text_encoder_params, text_encoder_2_params, network_alphas, adapter_name, is_lora=True
)
)
params["unet"] = unflatten_dict(unet_state_dict)
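An approximation of the UNet branch of the `rename_from_to` table above, showing how a flattened PyTorch LoRA key becomes an adapter-specific Flax key. The commit matches key parts inside `create_flax_params_from_pytorch_state`; this standalone helper only illustrates the renaming rule (`to_k_lora` → `("to_k", "lora-<adapter>")`, `weight` → `kernel`) and is not the exact function from the diff.

```python
def rename_unet_lora_key(pt_tuple_key, adapter_name):
  """Illustrative renaming of one flattened PyTorch LoRA key for the UNet."""
  rename_from_to = (
      ("to_k_lora", ("to_k", f"lora-{adapter_name}")),
      ("to_q_lora", ("to_q", f"lora-{adapter_name}")),
      ("to_v_lora", ("to_v", f"lora-{adapter_name}")),
      ("to_out_lora", ("to_out_0", f"lora-{adapter_name}")),
      ("lora", f"lora-{adapter_name}"),
      ("weight", "kernel"),
  )
  flax_key = []
  for part in pt_tuple_key:
    for rename_from, rename_to in rename_from_to:
      if part == rename_from:
        part = rename_to
        break
    if isinstance(part, tuple):
      flax_key.extend(part)
    else:
      flax_key.append(part)
  return tuple(flax_key)


print(rename_unet_lora_key(("down_blocks_0", "attentions_0", "to_k_lora", "down", "weight"), "papercut"))
# -> ('down_blocks_0', 'attentions_0', 'to_k', 'lora-papercut', 'down', 'kernel')
```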
