diff --git a/README.md b/README.md
index 84be41d..e4f6a02 100644
--- a/README.md
+++ b/README.md
@@ -17,8 +17,8 @@
 [![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)
 
 # What's new?
-
-- **`2024/8/1`**: Orbax is the new default checkpointer for Stable Diffusion 1.X, 2.x. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
+- **`2024/10/22`**: LoRA support for Hyper-SDXL.
+- **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
 - **`2024/7/20`**: Dreambooth training for Stable Diffusion 1.x,2.x is now supported.
 
 # Overview
@@ -32,6 +32,7 @@ MaxDiffusion supports
 * Stable Diffusion 2.1 (training and inference)
 * Stable Diffusion XL (training and inference).
 * Stable Diffusion Lightning (inference).
+* Hyper-SDXL LoRA loading (inference).
 * ControlNet inference (Stable Diffusion 1.4 & SDXL).
 * Dreambooth training support for Stable Diffusion 1.x,2.x.
 
@@ -43,6 +44,7 @@ MaxDiffusion supports
   * [Training](#training)
   * [Dreambooth](#dreambooth)
 * [Inference](#inference)
+  * [Hyper-SDXL LoRA](#hyper-sdxl-lora)
   * [SDXL Lightning](#sdxl-lightning)
   * [ControlNet](#controlnet)
 * [Comparison To Alternatives](#comparison-to-alternatives)
@@ -129,6 +131,14 @@
 To generate images, run the following command:
 
 ```bash
 python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run"
 ```
 
+## Hyper-SDXL LoRA
+
+Supports loading Hyper-SDXL LoRA models from [ByteDance](https://huggingface.co/ByteDance/Hyper-SD).
+
+```bash
+python src/maxdiffusion/generate_sdxl.py src/maxdiffusion/configs/base_xl.yml run_name="test-lora" output_dir=/tmp/ jax_cache_dir=/tmp/cache_dir/ num_inference_steps=2 do_classifier_free_guidance=False prompt="a photograph of a cat wearing a hat riding a skateboard in a park." per_device_batch_size=1 pretrained_model_name_or_path="Lykon/AAM_XL_AnimeMix" from_pt=True revision=main diffusion_scheduler_config='{"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}' lora_config='{"lora_model_name_or_path" : ["ByteDance/Hyper-SD"], "weight_name" : ["Hyper-SDXL-2steps-lora.safetensors"], "adapter_name" : ["hyper-sdxl"], "scale": [0.7], "from_pt": ["true"]}'
+```
+
 ## SDXL Lightning
 
 Single and Multi host inference is supported with sharding annotations:
diff --git a/generated_image.png b/generated_image.png
deleted file mode 100644
index 99f6928..0000000
Binary files a/generated_image.png and /dev/null differ
diff --git a/requirements.txt b/requirements.txt
index 0692bb1..5202dc2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -26,4 +26,5 @@
 git+https://github.com/mlperf/logging.git
 opencv-python-headless==4.10.0.84
 orbax-checkpoint>=0.5.20
 tokenizers==0.20.0
+huggingface_hub==0.24.7
\ No newline at end of file
diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml
index 5bcead5..b22af2c 100644
--- a/src/maxdiffusion/configs/base_xl.yml
+++ b/src/maxdiffusion/configs/base_xl.yml
@@ -75,11 +75,10 @@ timestep_bias: {
 
 # Override parameters from checkpoints's scheduler.
 diffusion_scheduler_config: {
-  _class_name: '',
-  # values are v_prediction or leave empty to use scheduler's default.
- prediction_type: '', + _class_name: 'FlaxEulerDiscreteScheduler', + prediction_type: 'epsilon', rescale_zero_terminal_snr: False, - timestep_spacing: '' + timestep_spacing: 'trailing' } # Output directory @@ -197,7 +196,7 @@ profiler_steps: 10 prompt: "A magical castle in the middle of a forest, artistic drawing" negative_prompt: "purple, red" do_classifier_free_guidance: True -guidance_scale: 9 +guidance_scale: 9.0 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 20 @@ -209,6 +208,24 @@ lightning_repo: "" # Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. lightning_ckpt: "" +# LoRA parameters +# Values are lists to support multiple LoRA loading during inference in the future. +lora_config: { + lora_model_name_or_path: [], + weight_name: [], + adapter_name: [], + scale: [], + from_pt: [] +} +# Ex with values: +# lora_config : { +# lora_model_name_or_path: ["ByteDance/Hyper-SD"], +# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], +# adapter_name: ["hyper-sdxl"], +# scale: [0.7], +# from_pt: [True] +# } + enable_mllog: False #controlnet diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml index 2d9f4a9..6dd6a4e 100644 --- a/src/maxdiffusion/configs/base_xl_lightning.yml +++ b/src/maxdiffusion/configs/base_xl_lightning.yml @@ -56,7 +56,7 @@ text_encoder_learning_rate: 4.25e-6 diffusion_scheduler_config: { _class_name: 'DDIMScheduler', # values are v_prediction or leave empty to use scheduler's default. - prediction_type: '', + prediction_type: 'epsilon', rescale_zero_terminal_snr: False, timestep_spacing: 'trailing' } @@ -156,7 +156,7 @@ profiler_steps: 5 prompt: "portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal, elegant, sharp focus, soft lighting, vibrant colors" negative_prompt: "purple, red" do_classifier_free_guidance: False -guidance_scale: 2 +guidance_scale: 2.0 guidance_rescale: -1 num_inference_steps: 4 @@ -165,4 +165,22 @@ lightning_from_pt: True lightning_repo: "ByteDance/SDXL-Lightning" lightning_ckpt: "sdxl_lightning_4step_unet.safetensors" +# LoRA parameters +# Values are lists to support multiple LoRA loading during inference in the future. 
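+# Leave the lists empty to run without LoRA.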
+lora_config: { + lora_model_name_or_path: [], + weight_name: [], + adapter_name: [], + scale: [], + from_pt: [] +} +# Ex with values: +# lora_config : { +# lora_model_name_or_path: ["ByteDance/Hyper-SD"], +# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], +# adapter_name: ["hyper-sdxl"], +# scale: [0.7], +# from_pt: [True] +# } + enable_mllog: False diff --git a/src/maxdiffusion/generate_sdxl.py b/src/maxdiffusion/generate_sdxl.py index 271d02e..be750fb 100644 --- a/src/maxdiffusion/generate_sdxl.py +++ b/src/maxdiffusion/generate_sdxl.py @@ -23,16 +23,18 @@ import jax import jax.numpy as jnp from jax.sharding import PartitionSpec as P +import flax.linen as nn from flax.linen import partitioning as nn_partitioning -from maxdiffusion import ( - FlaxEulerDiscreteScheduler, -) - - from maxdiffusion import pyconfig, max_utils from maxdiffusion.image_processor import VaeImageProcessor -from maxdiffusion.maxdiffusion_utils import (get_add_time_ids, rescale_noise_cfg, load_sdxllightning_unet) +from maxdiffusion.maxdiffusion_utils import ( + get_add_time_ids, + rescale_noise_cfg, + load_sdxllightning_unet, + maybe_load_lora, + create_scheduler, +) from maxdiffusion.trainers.sdxl_trainer import (StableDiffusionXLTrainer) @@ -82,7 +84,6 @@ def apply_classifier_free_guidance(noise_pred, guidance_scale): lambda _: noise_pred, operand=None, ) - latents, scheduler_state = pipeline.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents, scheduler_state, state @@ -217,6 +218,8 @@ def run(config): checkpoint_loader = GenerateSDXL(config) pipeline, params = checkpoint_loader.load_checkpoint() + noise_scheduler, noise_scheduler_state = create_scheduler(pipeline.scheduler.config, config) + weights_init_fn = functools.partial(pipeline.unet.init_weights, rng=checkpoint_loader.rng) unboxed_abstract_state, _, _ = max_utils.get_abstract_state( pipeline.unet, None, config, checkpoint_loader.mesh, weights_init_fn, False @@ -228,20 +231,24 @@ def run(config): if unet_params: params["unet"] = unet_params + # maybe load lora and create interceptor + params, lora_interceptor = maybe_load_lora(config, pipeline, params) + if config.lightning_repo: pipeline, params = load_sdxllightning_unet(config, pipeline, params) - # Don't restore the train state to save memory, just restore params + # Don't restore the full train state, instead, just restore params # and create an inference state. 
- unet_state, unet_state_shardings = max_utils.setup_initial_state( - model=pipeline.unet, - tx=None, - config=config, - mesh=checkpoint_loader.mesh, - weights_init_fn=weights_init_fn, - model_params=params.get("unet", None), - training=False, - ) + with nn.intercept_methods(lora_interceptor): + unet_state, unet_state_shardings = max_utils.setup_initial_state( + model=pipeline.unet, + tx=None, + config=config, + mesh=checkpoint_loader.mesh, + weights_init_fn=weights_init_fn, + model_params=params.get("unet", None), + training=False, + ) vae_state, vae_state_shardings = checkpoint_loader.create_vae_state( pipeline, params, checkpoint_item_name="vae_state", is_training=False @@ -267,14 +274,6 @@ def run(config): states["text_encoder_state"] = text_encoder_state states["text_encoder_2_state"] = text_encoder_2_state - noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - subfolder="scheduler", - dtype=jnp.float32, - timestep_spacing="trailing", - ) - pipeline.scheduler = noise_scheduler params["scheduler"] = noise_scheduler_state @@ -293,10 +292,12 @@ def run(config): ) s = time.time() - p_run_inference(states).block_until_ready() + with nn.intercept_methods(lora_interceptor): + p_run_inference(states).block_until_ready() print("compile time: ", (time.time() - s)) s = time.time() - images = p_run_inference(states).block_until_ready() + with nn.intercept_methods(lora_interceptor): + images = p_run_inference(states).block_until_ready() print("inference time: ", (time.time() - s)) images = jax.experimental.multihost_utils.process_allgather(images) numpy_images = np.array(images) diff --git a/src/maxdiffusion/loaders/__init__.py b/src/maxdiffusion/loaders/__init__.py new file mode 100644 index 0000000..3413318 --- /dev/null +++ b/src/maxdiffusion/loaders/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .lora_pipeline import StableDiffusionLoraLoaderMixin diff --git a/src/maxdiffusion/loaders/lora_base.py b/src/maxdiffusion/loaders/lora_base.py new file mode 100644 index 0000000..f22696d --- /dev/null +++ b/src/maxdiffusion/loaders/lora_base.py @@ -0,0 +1,106 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
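+
+# Shared helpers for LoRA loading: fetch a LoRA state dict from a local file
+# or the HuggingFace Hub, preferring `.safetensors` weights over pickled ones.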
+
+from ..models.modeling_utils import load_state_dict
+from ..utils import _get_model_file
+
+import safetensors
+
+
+class LoRABaseMixin:
+  """Utility class for handling LoRAs."""
+
+  _lora_loadable_modules = []
+  num_fused_loras = 0
+
+  def load_lora_weights(self, **kwargs):
+    raise NotImplementedError("`load_lora_weights()` is not implemented.")
+
+  @classmethod
+  def _fetch_state_dict(
+      cls,
+      pretrained_model_name_or_path_or_dict,
+      weight_name,
+      use_safetensors,
+      local_files_only,
+      cache_dir,
+      force_download,
+      resume_download,
+      proxies,
+      use_auth_token,
+      revision,
+      subfolder,
+      user_agent,
+      allow_pickle,
+  ):
+    from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
+
+    model_file = None
+    if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+      # Let's first try to load .safetensors weights.
+      if (use_safetensors and weight_name is None) or (weight_name is not None and weight_name.endswith(".safetensors")):
+        try:
+          # Here we're relaxing the loading check to enable more Inference API
+          # friendliness where sometimes, it's not at all possible to automatically
+          # determine `weight_name`.
+          if weight_name is None:
+            weight_name = cls._best_guess_weight_name(
+                pretrained_model_name_or_path_or_dict,
+                file_extension=".safetensors",
+                local_files_only=local_files_only,
+            )
+          model_file = _get_model_file(
+              pretrained_model_name_or_path_or_dict,
+              weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
+              cache_dir=cache_dir,
+              force_download=force_download,
+              resume_download=resume_download,
+              proxies=proxies,
+              local_files_only=local_files_only,
+              use_auth_token=use_auth_token,
+              revision=revision,
+              subfolder=subfolder,
+              user_agent=user_agent,
+          )
+          state_dict = safetensors.torch.load_file(model_file, device="cpu")
+        except (IOError, safetensors.SafetensorError) as e:
+          if not allow_pickle:
+            raise e
+          # Fall back to non-safetensors weights.
+          model_file = None
+
+      if model_file is None:
+        if weight_name is None:
+          weight_name = cls._best_guess_weight_name(
+              pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
+          )
+        model_file = _get_model_file(
+            pretrained_model_name_or_path_or_dict,
+            weights_name=weight_name or LORA_WEIGHT_NAME,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+        )
+        state_dict = load_state_dict(model_file)
+    else:
+      state_dict = pretrained_model_name_or_path_or_dict
+
+    return state_dict
diff --git a/src/maxdiffusion/loaders/lora_conversion_utils.py b/src/maxdiffusion/loaders/lora_conversion_utils.py
new file mode 100644
index 0000000..854aeaf
--- /dev/null
+++ b/src/maxdiffusion/loaders/lora_conversion_utils.py
@@ -0,0 +1,610 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+from ..
import max_logging + +import torch + + +def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5): + # 1. get all state_dict_keys + all_keys = list(state_dict.keys()) + sgm_patterns = ["input_blocks", "middle_block", "output_blocks"] + + # 2. check if needs remapping, if not return original dict + is_in_sgm_format = False + for key in all_keys: + if any(p in key for p in sgm_patterns): + is_in_sgm_format = True + break + + if not is_in_sgm_format: + return state_dict + + # 3. Else remap from SGM patterns + new_state_dict = {} + inner_block_map = ["resnets", "attentions", "upsamplers"] + + # Retrieves # of down, mid and up blocks + input_block_ids, middle_block_ids, output_block_ids = set(), set(), set() + + for layer in all_keys: + if "text" in layer: + new_state_dict[layer] = state_dict.pop(layer) + else: + layer_id = int(layer.split(delimiter)[:block_slice_pos][-1]) + if sgm_patterns[0] in layer: + input_block_ids.add(layer_id) + elif sgm_patterns[1] in layer: + middle_block_ids.add(layer_id) + elif sgm_patterns[2] in layer: + output_block_ids.add(layer_id) + else: + raise ValueError(f"Checkpoint not supported because layer {layer} not supported.") + + input_blocks = { + layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key] for layer_id in input_block_ids + } + middle_blocks = { + layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key] for layer_id in middle_block_ids + } + output_blocks = { + layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key] for layer_id in output_block_ids + } + + # Rename keys accordingly + for i in input_block_ids: + block_id = (i - 1) // (unet_config.layers_per_block + 1) + layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1) + + for key in input_blocks[i]: + inner_block_id = int(key.split(delimiter)[block_slice_pos]) + inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers" + inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0" + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + + [str(block_id), inner_block_key, inner_layers_in_block] + + key.split(delimiter)[block_slice_pos + 1 :] + ) + new_state_dict[new_key] = state_dict.pop(key) + + for i in middle_block_ids: + key_part = None + if i == 0: + key_part = [inner_block_map[0], "0"] + elif i == 1: + key_part = [inner_block_map[1], "0"] + elif i == 2: + key_part = [inner_block_map[0], "1"] + else: + raise ValueError(f"Invalid middle block id {i}.") + + for key in middle_blocks[i]: + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:] + ) + new_state_dict[new_key] = state_dict.pop(key) + + for i in output_block_ids: + block_id = i // (unet_config.layers_per_block + 1) + layer_in_block_id = i % (unet_config.layers_per_block + 1) + + for key in output_blocks[i]: + inner_block_id = int(key.split(delimiter)[block_slice_pos]) + inner_block_key = inner_block_map[inner_block_id] + inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0" + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + + [str(block_id), inner_block_key, inner_layers_in_block] + + key.split(delimiter)[block_slice_pos + 1 :] + ) + new_state_dict[new_key] = state_dict.pop(key) + + if len(state_dict) > 0: + raise ValueError("At this point all state dict entries have to be converted.") + + return new_state_dict + + +def 
_convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"): + """ + Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict. + + Args: + state_dict (`dict`): The state dict to convert. + unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet". + text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to + "text_encoder". + + Returns: + `tuple`: A tuple containing the converted state dict and a dictionary of alphas. + """ + unet_state_dict = {} + te_state_dict = {} + te2_state_dict = {} + network_alphas = {} + + # Check for DoRA-enabled LoRAs. + dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict) + dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict) + dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict) + if dora_present_in_unet or dora_present_in_te or dora_present_in_te2: + raise ValueError("DoRA is not currently supported") + + # Iterate over all LoRA weights. + all_lora_keys = list(state_dict.keys()) + for key in all_lora_keys: + if not key.endswith("lora_down.weight"): + continue + + # Extract LoRA name. + lora_name = key.split(".")[0] + + # Find corresponding up weight and alpha. + lora_name_up = lora_name + ".lora_up.weight" + lora_name_alpha = lora_name + ".alpha" + + # Handle U-Net LoRAs. + if lora_name.startswith("lora_unet_"): + diffusers_name = _convert_unet_lora_key(key) + + # Store down and up weights. + unet_state_dict[diffusers_name] = state_dict.pop(key) + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + # Store DoRA scale if present. + if dora_present_in_unet: + dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down." + unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = state_dict.pop( + key.replace("lora_down.weight", "dora_scale") + ) + + # Handle text encoder LoRAs. + elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")): + diffusers_name = _convert_text_encoder_lora_key(key, lora_name) + + # Store down and up weights for te or te2. + if lora_name.startswith(("lora_te_", "lora_te1_")): + te_state_dict[diffusers_name] = state_dict.pop(key) + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + else: + te2_state_dict[diffusers_name] = state_dict.pop(key) + te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + # Store DoRA scale if present. + if dora_present_in_te or dora_present_in_te2: + dora_scale_key_to_replace_te = "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer." + if lora_name.startswith(("lora_te_", "lora_te1_")): + te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = state_dict.pop( + key.replace("lora_down.weight", "dora_scale") + ) + elif lora_name.startswith("lora_te2_"): + te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = state_dict.pop( + key.replace("lora_down.weight", "dora_scale") + ) + + # Store alpha if present. + if lora_name_alpha in state_dict: + alpha = state_dict.pop(lora_name_alpha).item() + network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha)) + + # Check if any keys remain. 
+ if len(state_dict) > 0: + raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}") + + max_logging.log("Non-diffusers checkpoint detected.") + + # Construct final state dict. + unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()} + te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()} + te2_state_dict = ( + {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()} + if len(te2_state_dict) > 0 + else None + ) + if te2_state_dict is not None: + te_state_dict.update(te2_state_dict) + + new_state_dict = {**unet_state_dict, **te_state_dict} + return new_state_dict, network_alphas + + +def _convert_unet_lora_key(key): + """ + Converts a U-Net LoRA key to a Diffusers compatible key. + """ + diffusers_name = key.replace("lora_unet_", "").replace("_", ".") + + # Replace common U-Net naming patterns. + diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") + diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") + diffusers_name = diffusers_name.replace("middle.block", "mid_block") + diffusers_name = diffusers_name.replace("mid.block", "mid_block") + diffusers_name = diffusers_name.replace("output.blocks", "up_blocks") + diffusers_name = diffusers_name.replace("up.blocks", "up_blocks") + diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks") + diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") + diffusers_name = diffusers_name.replace("proj.in", "proj_in") + diffusers_name = diffusers_name.replace("proj.out", "proj_out") + diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj") + + # SDXL specific conversions. + if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name: + pattern = r"\.\d+(?=\D*$)" + diffusers_name = re.sub(pattern, "", diffusers_name, count=1) + if ".in." in diffusers_name: + diffusers_name = diffusers_name.replace("in.layers.2", "conv1") + if ".out." in diffusers_name: + diffusers_name = diffusers_name.replace("out.layers.3", "conv2") + if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name: + diffusers_name = diffusers_name.replace("op", "conv") + if "skip" in diffusers_name: + diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut") + + # LyCORIS specific conversions. + if "time.emb.proj" in diffusers_name: + diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj") + if "conv.shortcut" in diffusers_name: + diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut") + + # General conversions. + if "transformer_blocks" in diffusers_name: + if "attn1" in diffusers_name or "attn2" in diffusers_name: + diffusers_name = diffusers_name.replace("attn1", "attn1.processor") + diffusers_name = diffusers_name.replace("attn2", "attn2.processor") + elif "ff" in diffusers_name: + pass + elif any(key in diffusers_name for key in ("proj_in", "proj_out")): + pass + else: + pass + + return diffusers_name + + +def _convert_text_encoder_lora_key(key, lora_name): + """ + Converts a text encoder LoRA key to a Diffusers compatible key. 
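+
+  For example, `lora_te1_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight`
+  becomes `text_model.encoder.layers.0.self_attn.to_q_lora.down.weight`.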
+ """ + if lora_name.startswith(("lora_te_", "lora_te1_")): + key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_" + else: + key_to_replace = "lora_te2_" + + diffusers_name = key.replace(key_to_replace, "").replace("_", ".") + diffusers_name = diffusers_name.replace("text.model", "text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + diffusers_name = diffusers_name.replace("text.projection", "text_projection") + + if "self_attn" in diffusers_name or "text_projection" in diffusers_name: + pass + elif "mlp" in diffusers_name: + # Be aware that this is the new diffusers convention and the rest of the code might + # not utilize it yet. + diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + return diffusers_name + + +def _get_alpha_name(lora_name_alpha, diffusers_name, alpha): + """ + Gets the correct alpha name for the Diffusers model. + """ + if lora_name_alpha.startswith("lora_unet_"): + prefix = "unet." + elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")): + prefix = "text_encoder." + else: + prefix = "text_encoder_2." + new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha" + return {new_name: alpha} + + +# The utilities under `_convert_kohya_flux_lora_to_diffusers()` +# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py +# All credits go to `kohya-ss`. +def _convert_kohya_flux_lora_to_diffusers(state_dict): + def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + + # scale weight by alpha and dim + rank = down_weight.shape[0] + alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar + scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here + + # calculate scale_down and scale_up to keep the same value. 
if scale is 4, scale_down is 2 and scale_up is 2 + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down + ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up + + def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + up_weight = sds_sd.pop(sds_key + ".lora_up.weight") + sd_lora_rank = down_weight.shape[0] + + # scale weight by alpha and dim + alpha = sds_sd.pop(sds_key + ".alpha") + scale = alpha / sd_lora_rank + + # calculate scale_down and scale_up + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + down_weight = down_weight * scale_down + up_weight = up_weight * scale_up + + # calculate dims if not provided + num_splits = len(ait_keys) + if dims is None: + dims = [up_weight.shape[0] // num_splits] * num_splits + else: + assert sum(dims) == up_weight.shape[0] + + # check upweight is sparse or not + is_sparse = False + if sd_lora_rank % num_splits == 0: + ait_rank = sd_lora_rank // num_splits + is_sparse = True + i = 0 + for j in range(len(dims)): + for k in range(len(dims)): + if j == k: + continue + is_sparse = is_sparse and torch.all(up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0) + i += dims[j] + if is_sparse: + max_logging.log(f"weight is sparse: {sds_key}") + + # make ai-toolkit weight + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + if not is_sparse: + # down_weight is copied to each split + ait_sd.update({k: down_weight for k in ait_down_keys}) + + # up_weight is split to each split + ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 + else: + # down_weight is chunked to each split + ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416 + + # up_weight is sparse: only non-zero values are copied to each split + i = 0 + for j in range(len(dims)): + ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous() + i += dims[j] + + def _convert_sd_scripts_to_ai_toolkit(sds_sd): + ait_sd = {} + for i in range(19): + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_proj", + f"transformer.transformer_blocks.{i}.attn.to_out.0", + ) + _convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.to_q", + f"transformer.transformer_blocks.{i}.attn.to_k", + f"transformer.transformer_blocks.{i}.attn.to_v", + ], + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_mlp_0", + f"transformer.transformer_blocks.{i}.ff.net.0.proj", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_mlp_2", + f"transformer.transformer_blocks.{i}.ff.net.2", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_mod_lin", + f"transformer.transformer_blocks.{i}.norm1.linear", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_proj", + f"transformer.transformer_blocks.{i}.attn.to_add_out", + ) + _convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_qkv", + [ + 
f"transformer.transformer_blocks.{i}.attn.add_q_proj", + f"transformer.transformer_blocks.{i}.attn.add_k_proj", + f"transformer.transformer_blocks.{i}.attn.add_v_proj", + ], + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_mlp_0", + f"transformer.transformer_blocks.{i}.ff_context.net.0.proj", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_mlp_2", + f"transformer.transformer_blocks.{i}.ff_context.net.2", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_mod_lin", + f"transformer.transformer_blocks.{i}.norm1_context.linear", + ) + + for i in range(38): + _convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear1", + [ + f"transformer.single_transformer_blocks.{i}.attn.to_q", + f"transformer.single_transformer_blocks.{i}.attn.to_k", + f"transformer.single_transformer_blocks.{i}.attn.to_v", + f"transformer.single_transformer_blocks.{i}.proj_mlp", + ], + dims=[3072, 3072, 3072, 12288], + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear2", + f"transformer.single_transformer_blocks.{i}.proj_out", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_modulation_lin", + f"transformer.single_transformer_blocks.{i}.norm.linear", + ) + + if len(sds_sd) > 0: + max_logging.log(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}") + + return ait_sd + + return _convert_sd_scripts_to_ai_toolkit(state_dict) + + +# Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6 +# Some utilities were reused from +# https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py +def _convert_xlabs_flux_lora_to_diffusers(old_state_dict): + new_state_dict = {} + orig_keys = list(old_state_dict.keys()) + + def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): + down_weight = sds_sd.pop(sds_key) + up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight")) + + # calculate dims if not provided + num_splits = len(ait_keys) + if dims is None: + dims = [up_weight.shape[0] // num_splits] * num_splits + else: + assert sum(dims) == up_weight.shape[0] + + # make ai-toolkit weight + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + + # down_weight is copied to each split + ait_sd.update({k: down_weight for k in ait_down_keys}) + + # up_weight is split to each split + ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 + + for old_key in orig_keys: + # Handle double_blocks + if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")): + block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1) + new_key = f"transformer.transformer_blocks.{block_num}" + + if "processor.proj_lora1" in old_key: + new_key += ".attn.to_out.0" + elif "processor.proj_lora2" in old_key: + new_key += ".attn.to_add_out" + # Handle text latents. + elif "processor.qkv_lora2" in old_key and "up" not in old_key: + handle_qkv( + old_state_dict, + new_state_dict, + old_key, + [ + f"transformer.transformer_blocks.{block_num}.attn.add_q_proj", + f"transformer.transformer_blocks.{block_num}.attn.add_k_proj", + f"transformer.transformer_blocks.{block_num}.attn.add_v_proj", + ], + ) + # continue + # Handle image latents. 
+      elif "processor.qkv_lora1" in old_key and "up" not in old_key:
+        handle_qkv(
+            old_state_dict,
+            new_state_dict,
+            old_key,
+            [
+                f"transformer.transformer_blocks.{block_num}.attn.to_q",
+                f"transformer.transformer_blocks.{block_num}.attn.to_k",
+                f"transformer.transformer_blocks.{block_num}.attn.to_v",
+            ],
+        )
+        # continue
+
+      if "down" in old_key:
+        new_key += ".lora_A.weight"
+      elif "up" in old_key:
+        new_key += ".lora_B.weight"
+
+    # Handle single_blocks
+    elif old_key.startswith(("diffusion_model.single_blocks", "single_blocks")):
+      block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
+      new_key = f"transformer.single_transformer_blocks.{block_num}"
+
+      if "proj_lora1" in old_key or "proj_lora2" in old_key:
+        new_key += ".proj_out"
+      elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
+        new_key += ".norm.linear"
+
+      if "down" in old_key:
+        new_key += ".lora_A.weight"
+      elif "up" in old_key:
+        new_key += ".lora_B.weight"
+
+    else:
+      # Handle other potential key patterns here
+      new_key = old_key
+
+    # Since we already handle qkv above.
+    if "qkv" not in old_key:
+      new_state_dict[new_key] = old_state_dict.pop(old_key)
+
+  if len(old_state_dict) > 0:
+    raise ValueError(f"`old_state_dict` should be empty at this point but has: {list(old_state_dict.keys())}.")
+
+  return new_state_dict
diff --git a/src/maxdiffusion/loaders/lora_pipeline.py b/src/maxdiffusion/loaders/lora_pipeline.py
new file mode 100644
index 0000000..4b50743
--- /dev/null
+++ b/src/maxdiffusion/loaders/lora_pipeline.py
@@ -0,0 +1,280 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union, Dict
+import flax
+import jax.numpy as jnp
+from .lora_base import LoRABaseMixin
+from ..models.lora import LoRALinearLayer, LoRAConv2DLayer, BaseLoRALayer
+from .lora_conversion_utils import (
+    _convert_non_diffusers_lora_to_diffusers,
+    _maybe_map_sgm_blocks_to_diffusers,
+)
+from ..models.modeling_flax_pytorch_utils import convert_lora_pytorch_state_dict_to_flax
+from huggingface_hub.utils import validate_hf_hub_args
+
+TEXT_ENCODER_NAME = "text_encoder"
+UNET_NAME = "unet"
+TRANSFORMER_NAME = "transformer"
+
+LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
+LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+
+
+class StableDiffusionLoraLoaderMixin(LoRABaseMixin):
+  r"""
+  Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
+  [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
+  """
+
+  _lora_loadable_modules = ["unet", "text_encoder"]
+  unet_name = UNET_NAME
+  text_encoder_name = TEXT_ENCODER_NAME
+
+  def load_lora_weights(
+      self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, jnp.ndarray]], params, adapter_name=None, **kwargs
+  ):
+    """
+    Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+    `self.text_encoder`.
+
+    All kwargs are forwarded to `self.lora_state_dict`.
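+
+    Example (repository and weight names as used in the README's Hyper-SDXL command):
+
+    ```py
+    params, rank, network_alphas = pipeline.load_lora_weights(
+        "ByteDance/Hyper-SD",
+        weight_name="Hyper-SDXL-2steps-lora.safetensors",
+        params=params,
+        adapter_name="hyper-sdxl",
+    )
+    ```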
+ + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is + loaded. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is + loaded into `self.unet`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state + dict is loaded into `self.text_encoder`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + params, rank, network_alphas = self.load_lora( + state_dict, + network_alphas=network_alphas, + params=params, + adapter_name=adapter_name, + _pipeline=self, + ) + return params, rank, network_alphas + + @classmethod + def _get_lora_layer(cls, module_path, module, rank, network_alphas): + is_conv = any("conv" in str_ for str_ in module_path) + network_alpha = network_alphas.get(module_path, None) + if is_conv: + lora_module = LoRAConv2DLayer( + out_features=module.features, + rank=rank, + network_alpha=network_alpha, + kernel_size=module.kernel_size, + strides=module.strides, + padding=module.padding, + input_dilation=module.input_dilation, + kernel_dilation=module.kernel_dilation, + feature_group_count=module.feature_group_count, + dtype=module.dtype, + weights_dtype=module.param_dtype, + precision=module.precision, + name="lora", + ) + else: + lora_module = LoRALinearLayer( + out_features=module.features, + rank=rank, + network_alpha=network_alpha, + dtype=module.dtype, + weights_dtype=module.param_dtype, + precision=module.precision, + name="lora", + ) + return lora_module + + def rename_for_interceptor(params_keys, network_alphas): + new_params_keys = [] + for layer_lora in params_keys: + if "lora" in layer_lora: + new_layer_lora = layer_lora[: layer_lora.index("lora")] + if new_layer_lora not in new_params_keys: + new_params_keys.append(new_layer_lora) + network_alpha = network_alphas[layer_lora] + del network_alphas[layer_lora] + network_alphas[new_layer_lora] = network_alpha + return new_params_keys, network_alphas + + @classmethod + def make_lora_interceptor(cls, params, rank, network_alphas): + # Only unet interceptor supported for now. 
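+    # `nn.intercept_methods` calls this interceptor around every module
+    # `__call__`; when a module's path matches a flattened LoRA param key, the
+    # module output is additionally passed through the corresponding LoRA layer.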
+ unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys() + unet_lora_keys, network_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas) + + def _intercept(next_fn, args, kwargs, context): + mod = context.module + while mod is not None: + if isinstance(mod, BaseLoRALayer): + return next_fn(*args, **kwargs) + mod = mod.parent + h = next_fn(*args, **kwargs) + if context.method_name == "__call__": + module_path = context.module.path + if module_path in unet_lora_keys: + lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas) + return lora_layer(h, *args, **kwargs) + return h + + return _intercept + + @classmethod + @validate_hf_hub_args + def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + weight_name (`str`, *optional*, defaults to None): + Name of the serialized state dict file. + """ + # Load the main state dict first which has the LoRA layers for either of + # UNet and text encoder or both. 
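+    # `.safetensors` weights are tried first; when `use_safetensors` is not
+    # explicitly set, pickled `.bin` weights are allowed as a fallback.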
+ cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + unet_config = kwargs.pop("unet_config", None) + use_safetensors = kwargs.pop("use_safetensors", None) + resume_download = kwargs.pop("resume_download", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = cls._fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + network_alphas = None + if all( + (k.startswith("lora_te_") or k.startswith("lora_unet_") or k.startswith("lora_te1_") or k.startswith("lora_te2_")) + for k in state_dict.keys() + ): + # Map SDXL blocks correctly. + if unet_config is not None: + # use unet config to remap block numbers + state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) + state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) + + return state_dict, network_alphas + + @classmethod + def load_lora(cls, state_dict, network_alphas, params, adapter_name=None, _pipeline=None): + """ + This will load the LoRA layers specified in `state_dict` into `unet`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + unet (`UNet2DConditionModel`): + The UNet model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + # Load the layers corresponding to Unet. 
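+    # The PyTorch tensors are renamed and transposed into the Flax parameter
+    # tree; the LoRA rank is inferred from the `lora.up` weight shapes.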
+ params, rank, network_alphas = convert_lora_pytorch_state_dict_to_flax(state_dict, params, network_alphas) + return params, rank, network_alphas diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 075569a..b14c5bb 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -401,7 +401,6 @@ def setup_initial_state( state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") - init_train_state_partial = functools.partial( init_train_state, model=model, diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py index 7a0337d..fb1a827 100644 --- a/src/maxdiffusion/maxdiffusion_utils.py +++ b/src/maxdiffusion/maxdiffusion_utils.py @@ -39,6 +39,27 @@ def load_sdxllightning_unet(config, pipeline, params): return pipeline, params +def maybe_load_lora(config, pipeline, params): + + def _noop_interceptor(next_fn, args, kwargs, context): + return next_fn(*args, **kwargs) + + lora_config = config.lora_config + interceptor = _noop_interceptor + if len(lora_config["lora_model_name_or_path"]) > 0: + # For now only first lora supported. In the future, they will be merged + # before being loaded. + params, rank, network_alphas = pipeline.load_lora_weights( + lora_config["lora_model_name_or_path"][0], + weight_name=lora_config["weight_name"][0], + params=params, + adapter_name=lora_config["adapter_name"][0], + ) + interceptor = pipeline.make_lora_interceptor(params, rank, network_alphas) + + return params, interceptor + + def vae_apply(images, sample_rng, vae, vae_params): """Apply vae encoder to images.""" vae_outputs = vae.apply({"params": vae_params}, images, deterministic=True, method=vae.encode) diff --git a/src/maxdiffusion/models/__init__.py b/src/maxdiffusion/models/__init__.py index e459be8..ec09a5e 100644 --- a/src/maxdiffusion/models/__init__.py +++ b/src/maxdiffusion/models/__init__.py @@ -30,6 +30,7 @@ from .controlnet_flax import FlaxControlNetModel from .unet_2d_condition_flax import FlaxUNet2DConditionModel from .vae_flax import FlaxAutoencoderKL + from .lora import * else: import sys diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 1aa800c..81f774d 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -435,7 +435,7 @@ def setup(self): ) self.dropout_layer = nn.Dropout(rate=self.dropout) - def __call__(self, hidden_states, context=None, deterministic=True): + def __call__(self, hidden_states, context=None, deterministic=True, cross_attention_kwargs=None): context = hidden_states if context is None else context query_proj = self.query(hidden_states) key_proj = self.key(context) @@ -542,18 +542,25 @@ def setup(self): self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype, param_dtype=self.weights_dtype) self.dropout_layer = nn.Dropout(rate=self.dropout) - def __call__(self, hidden_states, context, deterministic=True): + def __call__(self, hidden_states, context, deterministic=True, cross_attention_kwargs=None): # self attention residual = hidden_states if self.only_cross_attention: - hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic) + hidden_states = self.attn1( + self.norm1(hidden_states), context, deterministic=deterministic, cross_attention_kwargs=cross_attention_kwargs + ) else: - hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic) + hidden_states = self.attn1( + 
self.norm1(hidden_states), deterministic=deterministic, cross_attention_kwargs=cross_attention_kwargs + ) + hidden_states = hidden_states + residual # cross attention residual = hidden_states - hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic) + hidden_states = self.attn2( + self.norm2(hidden_states), context, deterministic=deterministic, cross_attention_kwargs=cross_attention_kwargs + ) hidden_states = hidden_states + residual # feed forward @@ -689,7 +696,7 @@ def setup(self): self.dropout_layer = nn.Dropout(rate=self.dropout) - def __call__(self, hidden_states, context, deterministic=True): + def __call__(self, hidden_states, context, deterministic=True, cross_attention_kwargs=None): batch, height, width, channels = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) @@ -701,7 +708,9 @@ def __call__(self, hidden_states, context, deterministic=True): hidden_states = hidden_states.reshape(batch, height * width, channels) for transformer_block in self.transformer_blocks: - hidden_states = transformer_block(hidden_states, context, deterministic=deterministic) + hidden_states = transformer_block( + hidden_states, context, deterministic=deterministic, cross_attention_kwargs=cross_attention_kwargs + ) if self.use_linear_projection: hidden_states = self.proj_out(hidden_states) diff --git a/src/maxdiffusion/models/lora.py b/src/maxdiffusion/models/lora.py new file mode 100644 index 0000000..82d32e8 --- /dev/null +++ b/src/maxdiffusion/models/lora.py @@ -0,0 +1,125 @@ +""" + Copyright 2024 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" + +from typing import Union, Tuple, Optional +import jax +import jax.numpy as jnp +import flax.linen as nn + + +class BaseLoRALayer: + """ + Base LoRA layer class for all LoRA layer implementation + """ + + pass + + +class LoRALinearLayer(nn.Module, BaseLoRALayer): + """ + Implements LoRA linear layer + """ + + out_features: int + rank: int = 0 + network_alpha: float = None + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + lora_scale: float = 1.0 + + @nn.compact + def __call__(self, h, hidden_states): + if self.rank > self.out_features: + raise ValueError(f"LoRA rank {self.rank} must be less or equal to {min(self.in_features, self.out_features)}") + + down_hidden_states = nn.Dense( + features=self.rank, + use_bias=False, + kernel_init=nn.initializers.kaiming_uniform(), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + name="down", + )(hidden_states) + up_hidden_states = nn.Dense( + features=self.out_features, + use_bias=False, + kernel_init=nn.initializers.zeros_init(), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + name="up", + )(down_hidden_states) + if self.network_alpha: + up_hidden_states *= self.network_alpha / self.rank + + return h + (up_hidden_states * self.lora_scale) + + +class LoRAConv2DLayer(nn.Module, BaseLoRALayer): + """ + Implements LoRA Conv layer + """ + + out_features: int + rank: int = 4 + kernel_size: Union[int, Tuple[int, int]] = (1, 1) + strides: Union[int, Tuple[int, int]] = (1, 1) + padding: Union[int, Tuple[int, int], str] = 0 + input_dilation: int = 1 + kernel_dilation: int = 1 + feature_group_count: int = 1 + network_alpha: Optional[float] = None + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + lora_scale: float = 1.0 + + @nn.compact + def __call__(self, h, hidden_states): + down_hidden_states = nn.Conv( + self.rank, + kernel_size=self.kernel_size, + strides=self.strides, + padding=self.padding, + input_dilation=self.input_dilation, + kernel_dilation=self.kernel_dilation, + feature_group_count=self.feature_group_count, + use_bias=False, + dtype=self.dtype, + param_dtype=self.weights_dtype, + kernel_init=nn.initializers.kaiming_uniform(), + precision=self.precision, + name="down", + )(hidden_states) + + up_hidden_states = nn.Conv( + self.out_features, + use_bias=False, + kernel_size=(1, 1), + strides=(1, 1), + dtype=self.dtype, + param_dtype=self.weights_dtype, + kernel_init=nn.initializers.zeros_init(), + precision=self.precision, + name="up", + )(down_hidden_states) + + if self.network_alpha: + up_hidden_states *= self.network_alpha / self.rank + + return h + (up_hidden_states * self.lora_scale) diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 54da2fb..5f02ec8 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -19,6 +19,7 @@ import jax.numpy as jnp from flax.linen import Partitioned from flax.traverse_util import flatten_dict, unflatten_dict +from flax.core.frozen_dict import unfreeze from jax.random import PRNGKey from ..utils import logging @@ -108,6 +109,112 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic return pt_tuple_key, pt_tensor +def get_network_alpha_value(pt_key, network_alphas): + network_alpha_value = -1 + network_alpha_key = tuple(pt_key.split(".")) 
+ for item in network_alpha_key: + # alpha names for LoRA follow different convention for qkv values. + # Ex: + # conv layer - unet.down_blocks.0.downsamplers.0.conv.alpha + # to_k_lora - unet.down_blocks.1.attentions.0.transformer_blocks.1.attn1.processor.to_k_lora.down.weight.alpha + if "lora" == item: + network_alpha_key = network_alpha_key[: network_alpha_key.index(item)] + ("alpha",) + break + elif "lora" in item: + network_alpha_key = network_alpha_key + ("alpha",) + break + network_alpha_key = ".".join(network_alpha_key) + if network_alpha_key in network_alphas: + network_alpha_value = network_alphas[network_alpha_key] + return network_alpha_value + + +def create_flax_params_from_pytorch_state( + pt_state_dict, unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, network_alphas, is_lora=False +): + rank = None + renamed_network_alphas = {} + # Need to change some parameters name to match Flax names + for pt_key, pt_tensor in pt_state_dict.items(): + network_alpha_value = get_network_alpha_value(pt_key, network_alphas) + renamed_pt_key = rename_key(pt_key) + pt_tuple_key = tuple(renamed_pt_key.split(".")) + # conv + if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: + pt_tensor = pt_tensor.transpose(2, 3, 1, 0) + pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) + flax_key_list = [*pt_tuple_key] + flax_tensor = pt_tensor + else: + flax_key_list = [*pt_tuple_key] + for rename_from, rename_to in ( + ("to_k_lora", ("to_k", "lora")), + ("to_q_lora", ("to_q", "lora")), + ("to_v_lora", ("to_v", "lora")), + ("to_out_lora", ("to_out_0", "lora")), + ("weight", "kernel"), + ): + tmp = [] + for s in flax_key_list: + if s == rename_from: + if type(rename_to) is tuple: + for s_in_tuple in rename_to: + tmp.append(s_in_tuple) + else: + tmp.append(rename_to) + else: + tmp.append(s) + flax_key_list = tmp + + flax_tensor = pt_tensor.T + + if is_lora: + if "lora.up" in renamed_pt_key: + rank = pt_tensor.shape[1] + if "processor" in flax_key_list: + flax_key_list.remove("processor") + if "unet" in flax_key_list: + flax_key_list.remove("unet") + unet_state_dict[tuple(flax_key_list)] = jnp.asarray(flax_tensor) + + if "text_encoder" in flax_key_list: + flax_key_list.remove("text_encoder") + text_encoder_state_dict[tuple(flax_key_list)] = jnp.asarray(flax_tensor) + + if "text_encoder_2" in flax_key_list: + flax_key_list.remove("text_encoder_2") + text_encoder_2_state_dict[tuple(flax_key_list)] = jnp.asarray(flax_tensor) + + if network_alpha_value >= 0: + renamed_network_alphas[tuple(flax_key_list)] = network_alpha_value + + return unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank, renamed_network_alphas + + +def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alphas): + # Step 1: Convert pytorch tensor to numpy + # sometimes we load weights in bf16 and numpy doesn't support it + pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()} + + unet_params = flatten_dict(unfreeze(params["unet"])) + text_encoder_params = flatten_dict(unfreeze(params["text_encoder"])) + if "text_encoder_2" in params.keys(): + text_encoder_2_params = flatten_dict(unfreeze(params["text_encoder_2"])) + else: + text_encoder_2_params = None + (unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank, network_alphas) = ( + create_flax_params_from_pytorch_state( + pt_state_dict, unet_params, text_encoder_params, text_encoder_2_params, network_alphas, is_lora=True + ) + ) + params["unet"] = unflatten_dict(unet_state_dict) + 
params["text_encoder"] = unflatten_dict(text_encoder_state_dict) + if text_encoder_2_state_dict is not None: + params["text_encoder_2"] = unflatten_dict(text_encoder_2_state_dict) + + return params, rank, network_alphas + + def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): # Step 1: Convert pytorch tensor to numpy pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} diff --git a/src/maxdiffusion/models/resnet_flax.py b/src/maxdiffusion/models/resnet_flax.py index b9e9ac0..79ddcb3 100644 --- a/src/maxdiffusion/models/resnet_flax.py +++ b/src/maxdiffusion/models/resnet_flax.py @@ -148,6 +148,7 @@ def setup(self): ) self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision) + self.conv2 = nn.Conv( out_channels, kernel_size=(3, 3), @@ -161,15 +162,17 @@ def setup(self): precision=self.precision, ) - def __call__(self, hidden_states, temb, deterministic=True): + def __call__(self, hidden_states, temb, deterministic=True, cross_attention_kwargs={}): residual = hidden_states hidden_states = self.norm1(hidden_states) hidden_states = nn.swish(hidden_states) hidden_states = self.conv1(hidden_states) hidden_states = nn.with_logical_constraint(hidden_states, ("conv_batch", "height", "keep_2", "out_channels")) + temb = nn.swish(temb) + temb = self.time_emb_proj(temb) - temb = self.time_emb_proj(nn.swish(temb)) temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1) + hidden_states = hidden_states + temb hidden_states = self.norm2(hidden_states) diff --git a/src/maxdiffusion/models/unet_2d_blocks_flax.py b/src/maxdiffusion/models/unet_2d_blocks_flax.py index 08a6449..8bc47fe 100644 --- a/src/maxdiffusion/models/unet_2d_blocks_flax.py +++ b/src/maxdiffusion/models/unet_2d_blocks_flax.py @@ -120,12 +120,14 @@ def setup(self): if self.add_downsample: self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype, weights_dtype=self.weights_dtype) - def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): + def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True, cross_attention_kwargs={}): output_states = () for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb, deterministic=deterministic) - hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + hidden_states = attn( + hidden_states, encoder_hidden_states, deterministic=deterministic, cross_attention_kwargs=cross_attention_kwargs + ) output_states += (hidden_states,) if self.add_downsample: @@ -300,7 +302,15 @@ def setup(self): if self.add_upsample: self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype, weights_dtype=self.weights_dtype) - def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True): + def __call__( + self, + hidden_states, + res_hidden_states_tuple, + temb, + encoder_hidden_states, + deterministic=True, + cross_attention_kwargs=None, + ): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -308,7 +318,9 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_ hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) hidden_states = resnet(hidden_states, temb, deterministic=deterministic) - hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + hidden_states = attn( + 
diff --git a/src/maxdiffusion/models/unet_2d_blocks_flax.py b/src/maxdiffusion/models/unet_2d_blocks_flax.py
index 08a6449..8bc47fe 100644
--- a/src/maxdiffusion/models/unet_2d_blocks_flax.py
+++ b/src/maxdiffusion/models/unet_2d_blocks_flax.py
@@ -120,12 +120,14 @@ def setup(self):
     if self.add_downsample:
       self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype, weights_dtype=self.weights_dtype)
 
-  def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
+  def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True, cross_attention_kwargs=None):
     output_states = ()
 
     for resnet, attn in zip(self.resnets, self.attentions):
       hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
-      hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
+      hidden_states = attn(
+          hidden_states, encoder_hidden_states, deterministic=deterministic, cross_attention_kwargs=cross_attention_kwargs
+      )
       output_states += (hidden_states,)
 
     if self.add_downsample:
@@ -300,7 +302,15 @@ def setup(self):
     if self.add_upsample:
       self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype, weights_dtype=self.weights_dtype)
 
-  def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
+  def __call__(
+      self,
+      hidden_states,
+      res_hidden_states_tuple,
+      temb,
+      encoder_hidden_states,
+      deterministic=True,
+      cross_attention_kwargs=None,
+  ):
     for resnet, attn in zip(self.resnets, self.attentions):
       # pop res hidden states
       res_hidden_states = res_hidden_states_tuple[-1]
@@ -308,7 +318,9 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_
       hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
 
       hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
-      hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
+      hidden_states = attn(
+          hidden_states, encoder_hidden_states, deterministic=deterministic, cross_attention_kwargs=cross_attention_kwargs
+      )
 
     if self.add_upsample:
       hidden_states = self.upsamplers_0(hidden_states)
@@ -483,10 +495,12 @@ def setup(self):
     self.resnets = resnets
     self.attentions = attentions
 
-  def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
+  def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True, cross_attention_kwargs=None):
     hidden_states = self.resnets[0](hidden_states, temb)
     for attn, resnet in zip(self.attentions, self.resnets[1:]):
-      hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
+      hidden_states = attn(
+          hidden_states, encoder_hidden_states, deterministic=deterministic, cross_attention_kwargs=cross_attention_kwargs
+      )
       hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
     return hidden_states
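With all three cross-attention block types now accepting `cross_attention_kwargs`, the UNet entry point in the next file can thread a caller-supplied dictionary down to every attention layer. A hedged sketch of the resulting call shape; the `"scale"` key is an assumption about what the attention layers consume, not something this diff pins down:

```python
# Hypothetical call; `unet`, `params`, `latents`, `t`, and `text_embeddings`
# are assumed to come from an already-initialized pipeline.
out = unet.apply(
    {"params": params["unet"]},
    latents,
    t,
    encoder_hidden_states=text_embeddings,
    cross_attention_kwargs={"scale": 0.7},  # assumed key: per-call LoRA scale
).sample
```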
diff --git a/src/maxdiffusion/models/unet_2d_condition_flax.py b/src/maxdiffusion/models/unet_2d_condition_flax.py
index a93af0a..d526063 100644
--- a/src/maxdiffusion/models/unet_2d_condition_flax.py
+++ b/src/maxdiffusion/models/unet_2d_condition_flax.py
@@ -392,6 +392,7 @@ def __call__(
       mid_block_additional_residual=None,
       return_dict: bool = True,
       train: bool = False,
+      cross_attention_kwargs: Optional[Union[Dict, FrozenDict]] = None,
   ) -> Union[FlaxUNet2DConditionOutput, Tuple]:
     r"""
     Args:
@@ -410,6 +411,8 @@
             plain tuple.
         train (`bool`, *optional*, defaults to `False`):
             Use deterministic functions and disable dropout when not training.
+        cross_attention_kwargs (`dict`, *optional*):
+            A kwargs dictionary that, if specified, is passed along to the `FlaxAttention` layers.
 
     Returns:
         [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
@@ -461,7 +464,9 @@
     down_block_res_samples = (sample,)
     for down_block in self.down_blocks:
       if isinstance(down_block, FlaxCrossAttnDownBlock2D):
-        sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
+        sample, res_samples = down_block(
+            sample, t_emb, encoder_hidden_states, deterministic=not train, cross_attention_kwargs=cross_attention_kwargs
+        )
       else:
         sample, res_samples = down_block(sample, t_emb, deterministic=not train)
       down_block_res_samples += res_samples
@@ -478,7 +483,9 @@
       down_block_res_samples = new_down_block_res_samples
 
     # 4. mid
-    sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
+    sample = self.mid_block(
+        sample, t_emb, encoder_hidden_states, deterministic=not train, cross_attention_kwargs=cross_attention_kwargs
+    )
 
     if mid_block_additional_residual is not None:
       sample += mid_block_additional_residual
@@ -494,6 +501,7 @@
             encoder_hidden_states=encoder_hidden_states,
             res_hidden_states_tuple=res_samples,
             deterministic=not train,
+            cross_attention_kwargs=cross_attention_kwargs,
         )
       else:
         sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train)
diff --git a/src/maxdiffusion/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/maxdiffusion/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index c32872c..d438ebe 100644
--- a/src/maxdiffusion/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/maxdiffusion/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -30,6 +30,7 @@
     FlaxLMSDiscreteScheduler,
     FlaxPNDMScheduler,
 )
+from ...loaders import StableDiffusionLoraLoaderMixin
 from ...utils import deprecate, logging, replace_example_docstring
 from ..pipeline_flax_utils import FlaxDiffusionPipeline
 from .pipeline_output import FlaxStableDiffusionPipelineOutput
@@ -73,7 +74,7 @@
 """
 
-class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
+class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline, StableDiffusionLoraLoaderMixin):
   r"""
   Flax-based pipeline for text-to-image generation using Stable Diffusion.
 
diff --git a/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
index ca410e9..c3512ec 100644
--- a/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
+++ b/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
@@ -29,6 +29,7 @@
     FlaxLMSDiscreteScheduler,
     FlaxPNDMScheduler,
 )
+from ...loaders import StableDiffusionLoraLoaderMixin
 from ..pipeline_flax_utils import FlaxDiffusionPipeline
 from .pipeline_output import FlaxStableDiffusionXLPipelineOutput
 
@@ -39,7 +40,7 @@
 DEBUG = False
 
-class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
+class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline, StableDiffusionLoraLoaderMixin):
 
   def __init__(
       self,
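Both Flax pipelines now mix in `StableDiffusionLoraLoaderMixin`, which is what the `lora_config` block in `base_xl.yml` ultimately drives. A hedged usage sketch; the import path, method name, and return signature are assumed to mirror diffusers' LoRA loader mixin rather than taken from this diff:

```python
from maxdiffusion import FlaxStableDiffusionXLPipeline  # assumed import path

pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0"
)
# Assumed mixin entry point, mirroring diffusers' load_lora_weights:
params, rank, network_alphas = pipeline.load_lora_weights(
    "ByteDance/Hyper-SD",
    weight_name="Hyper-SDXL-2steps-lora.safetensors",
    params=params,
    adapter_name="hyper-sdxl",
    from_pt=True,
)
```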
diff --git a/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py b/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py
index 5b739a2..bc9013c 100644
--- a/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py
+++ b/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py
@@ -222,7 +222,6 @@ def step(
     step_index = step_index[0]
 
     sigma = state.sigmas[step_index]
-    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
     if self.config.prediction_type == "epsilon":
       pred_original_sample = sample - sigma * model_output
diff --git a/src/maxdiffusion/tests/generate_sdxl_smoke_test.py b/src/maxdiffusion/tests/generate_sdxl_smoke_test.py
index 80e3ad2..dfbeb12 100644
--- a/src/maxdiffusion/tests/generate_sdxl_smoke_test.py
+++ b/src/maxdiffusion/tests/generate_sdxl_smoke_test.py
@@ -14,10 +14,38 @@
 
 THIS_DIR = os.path.dirname(os.path.abspath(__file__))
 
+JAX_CACHE_DIR = "gs://maxdiffusion-github-runner-test-assets/cache_dir"
+
 
 class Generate(unittest.TestCase):
   """Smoke test."""
 
+  @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
+  def test_hyper_sdxl_lora(self):
+    img_url = os.path.join(THIS_DIR, "images", "test_hyper_sdxl.png")
+    base_image = np.array(Image.open(img_url)).astype(np.uint8)
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_xl.yml"),
+            "pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0",
+            "output_dir=gs://maxdiffusion-github-runner-test-assets",
+            "run_name=test-hypersdxl-lora",
+            "num_inference_steps=2",
+            "per_device_batch_size=1",
+            "do_classifier_free_guidance=False",
+            'diffusion_scheduler_config={"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}',
+            'lora_config={"lora_model_name_or_path" : ["ByteDance/Hyper-SD"], "weight_name" : ["Hyper-SDXL-2steps-lora.safetensors"], "adapter_name" : ["hyper-sdxl"], "scale": [0.7], "from_pt": ["true"]}',
+            f"jax_cache_dir={JAX_CACHE_DIR}",
+        ],
+        unittest=True,
+    )
+    images = generate_run_xl(pyconfig.config)
+    test_image = np.array(images[0]).astype(np.uint8)
+    assert base_image.shape == test_image.shape
+    ssim_compare = ssim(base_image, test_image, channel_axis=-1, data_range=255)
+    assert ssim_compare >= 0.80
+
   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_sdxl_config(self):
     img_url = os.path.join(THIS_DIR, "images", "test_sdxl.png")
@@ -39,6 +67,7 @@ def test_sdxl_config(self):
             "per_device_batch_size=1",
             "run_name=sdxl-inference-test",
             "split_head_dim=False",
+            f"jax_cache_dir={JAX_CACHE_DIR}",
         ],
         unittest=True,
     )
@@ -70,6 +99,7 @@ def test_sdxl_from_gcs(self):
             "per_device_batch_size=1",
             "run_name=sdxl-inference-test",
             "split_head_dim=False",
+            f"jax_cache_dir={JAX_CACHE_DIR}",
         ],
         unittest=True,
     )
@@ -92,6 +122,7 @@ def test_controlnet_sdxl(self):
             "pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0",
             "activations_dtype=bfloat16",
             "weights_dtype=bfloat16",
+            f"jax_cache_dir={JAX_CACHE_DIR}",
         ],
         unittest=True,
     )
@@ -106,7 +137,12 @@ def test_sdxl_lightning(self):
     img_url = os.path.join(THIS_DIR, "images", "test_lightning.png")
     base_image = np.array(Image.open(img_url)).astype(np.uint8)
     pyconfig.initialize(
-        [None, os.path.join(THIS_DIR, "..", "configs", "base_xl_lightning.yml"), "run_name=sdxl-lightning-test"],
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_xl_lightning.yml"),
+            "run_name=sdxl-lightning-test",
+            f"jax_cache_dir={JAX_CACHE_DIR}",
+        ],
         unittest=True,
    )
     images = generate_run_xl(pyconfig.config)
diff --git a/src/maxdiffusion/tests/images/test_hyper_sdxl.png b/src/maxdiffusion/tests/images/test_hyper_sdxl.png
new file mode 100644
index 0000000..cd0cc60
Binary files /dev/null and b/src/maxdiffusion/tests/images/test_hyper_sdxl.png differ
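The new smoke test gates on an SSIM score of at least 0.80 against a golden image. For reference, skimage's `structural_similarity` computes the standard SSIM over local windows (7x7 by default) and averages the per-window scores; a score of 1.0 means the images are identical:

```latex
\mathrm{SSIM}(x, y) =
  \frac{(2\mu_x \mu_y + C_1)\,(2\sigma_{xy} + C_2)}
       {(\mu_x^2 + \mu_y^2 + C_1)\,(\sigma_x^2 + \sigma_y^2 + C_2)}
```

Here the means, variances, and covariance are computed per window, and C1, C2 are small constants that stabilize the division when the denominator terms are near zero.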
diff --git a/src/maxdiffusion/tests/images/test_lightning.png b/src/maxdiffusion/tests/images/test_lightning.png
index 0f50b35..36e844c 100644
Binary files a/src/maxdiffusion/tests/images/test_lightning.png and b/src/maxdiffusion/tests/images/test_lightning.png differ
diff --git a/src/maxdiffusion/tests/images/test_sd15.png b/src/maxdiffusion/tests/images/test_sd15.png
index 2e173cb..7f3fbee 100644
Binary files a/src/maxdiffusion/tests/images/test_sd15.png and b/src/maxdiffusion/tests/images/test_sd15.png differ