From 99f608218caa069a2f16dcf9efab46959b15aec0 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 3 Oct 2024 08:36:26 -1000 Subject: [PATCH 01/20] [sd3] make sure height and size are divisible by `16` (#9573) * check size * up --- ...line_stable_diffusion_3_controlnet_inpainting.py | 13 +++++++++++-- src/diffusers/pipelines/pag/pipeline_pag_sd_3.py | 13 +++++++++++-- .../pipeline_stable_diffusion_3.py | 13 +++++++++++-- 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py index 47fc6d6daf15..b17d9687952e 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py @@ -251,6 +251,9 @@ def __init__( if hasattr(self, "transformer") and self.transformer is not None else 128 ) + self.patch_size = ( + self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2 + ) # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( @@ -577,8 +580,14 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if ( + height % (self.vae_scale_factor * self.patch_size) != 0 + or width % (self.vae_scale_factor * self.patch_size) != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}." + f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py index 3035509843c0..174d58022270 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py @@ -212,6 +212,9 @@ def __init__( if hasattr(self, "transformer") and self.transformer is not None else 128 ) + self.patch_size = ( + self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2 + ) self.set_pag_applied_layers( pag_applied_layers, pag_attn_processors=(PAGCFGJointAttnProcessor2_0(), PAGJointAttnProcessor2_0()) @@ -542,8 +545,14 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if ( + height % (self.vae_scale_factor * self.patch_size) != 0 + or width % (self.vae_scale_factor * self.patch_size) != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}." + f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}." 
+ ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 5a10f329a0af..a1420b8e1e82 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -203,6 +203,9 @@ def __init__( if hasattr(self, "transformer") and self.transformer is not None else 128 ) + self.patch_size = ( + self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2 + ) def _get_t5_prompt_embeds( self, @@ -525,8 +528,14 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if ( + height % (self.vae_scale_factor * self.patch_size) != 0 + or width % (self.vae_scale_factor * self.patch_size) != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}." + f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs From 3159e60d59819ae874ea3cdbd28e02d9e6c57321 Mon Sep 17 00:00:00 2001 From: Clem <70368164+Clement-Lelievre@users.noreply.github.com> Date: Mon, 7 Oct 2024 07:17:54 +0200 Subject: [PATCH 02/20] fix xlabs FLUX lora conversion typo (#9581) * fix startswith syntax in xlabs lora conversion * Trigger CI https://github.com/huggingface/diffusers/pull/9581#issuecomment-2395530360 --- src/diffusers/loaders/lora_conversion_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index d829cc3a844b..d0ca40213b14 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -632,7 +632,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): new_key += ".lora_B.weight" # Handle single_blocks - elif old_key.startswith("diffusion_model.single_blocks", "single_blocks"): + elif old_key.startswith(("diffusion_model.single_blocks", "single_blocks")): block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1) new_key = f"transformer.single_transformer_blocks.{block_num}" From 31010ecc45324c9ee294eb34cf5ad3eaf2cb8c3f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 7 Oct 2024 17:43:48 +0530 Subject: [PATCH 03/20] [Chore] add a note on the versions in Flux LoRA integration tests (#9598) add a note on the versions. 
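
The pinned versions matter because these integration tests compare against hard-coded
output slices. A minimal sketch of how such an environment pin could be enforced in a
test suite (the decorator below is hypothetical and not part of this patch):

```python
import unittest

import torch

# Hypothetical guard: skip slice assertions unless the environment matches
# the one the reference slices were recorded on.
requires_pinned_torch = unittest.skipUnless(
    torch.__version__.startswith("2.6.0.dev20241006"),
    "Integration slices were recorded on torch 2.6.0.dev20241006+cu124 (CUDA 12.5).",
)
```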
--- tests/lora/test_lora_layers_flux.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index a75f9df91047..4629c24c8cd8 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -169,7 +169,11 @@ def test_modify_padding_mode(self): @unittest.skip("We cannot run inference on this model with the current CI hardware") # TODO (DN6, sayakpaul): move these tests to a beefier GPU class FluxLoRAIntegrationTests(unittest.TestCase): - """internal note: The integration slices were obtained on audace.""" + """internal note: The integration slices were obtained on audace. + + torch: 2.6.0.dev20241006+cu124 with CUDA 12.5. Need the same setup for the + assertions to pass. + """ num_inference_steps = 10 seed = 0 From 2cb383f591a66d5c9b5a5ac2b4b0824269f9db6c Mon Sep 17 00:00:00 2001 From: captainzz <73270275+xduzhangjiayu@users.noreply.github.com> Date: Mon, 7 Oct 2024 23:30:25 +0800 Subject: [PATCH 04/20] fix vae dtype when accelerate config using --mixed_precision="fp16" (#9601) * fix vae dtype when accelerate config using --mixed_precision="fp16" * Add param for upcast vae --- examples/controlnet/train_controlnet_sd3.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py index 9ea78370f5e0..4fae8a072c6f 100644 --- a/examples/controlnet/train_controlnet_sd3.py +++ b/examples/controlnet/train_controlnet_sd3.py @@ -357,6 +357,11 @@ def parse_args(input_args=None): action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) + parser.add_argument( + "--upcast_vae", + action="store_true", + help="Whether or not to upcast vae to fp32", + ) parser.add_argument( "--learning_rate", type=float, @@ -1094,7 +1099,10 @@ def load_model_hook(models, input_dir): weight_dtype = torch.bfloat16 # Move vae, transformer and text_encoder to device and cast to weight_dtype - vae.to(accelerator.device, dtype=torch.float32) + if args.upcast_vae: + vae.to(accelerator.device, dtype=torch.float32) + else: + vae.to(accelerator.device, dtype=weight_dtype) transformer.to(accelerator.device, dtype=weight_dtype) text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) From a80f6892003e102f56bc956e9f8707b52c5d4487 Mon Sep 17 00:00:00 2001 From: Yijun Lee <119404328+yijun-lee@users.noreply.github.com> Date: Tue, 8 Oct 2024 05:27:35 +0900 Subject: [PATCH 05/20] refac: docstrings in import_utils.py (#9583) * refac: docstrings in import_utils.py * Update import_utils.py --- src/diffusers/utils/import_utils.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 34cc5fcc8605..daecec4aa258 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -668,8 +668,9 @@ def __getattr__(cls, key): # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319 def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str): """ - Args: Compares a library version to some requirement using a given operation. 
+ + Args: library_or_version (`str` or `packaging.version.Version`): A library name or a version to check. operation (`str`): @@ -688,8 +689,9 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338 def is_torch_version(operation: str, version: str): """ - Args: Compares the current PyTorch version to a given reference with an operation. + + Args: operation (`str`): A string representation of an operator, such as `">"` or `"<="` version (`str`): @@ -700,8 +702,9 @@ def is_torch_version(operation: str, version: str): def is_transformers_version(operation: str, version: str): """ - Args: Compares the current Transformers version to a given reference with an operation. + + Args: operation (`str`): A string representation of an operator, such as `">"` or `"<="` version (`str`): @@ -714,8 +717,9 @@ def is_transformers_version(operation: str, version: str): def is_accelerate_version(operation: str, version: str): """ - Args: Compares the current Accelerate version to a given reference with an operation. + + Args: operation (`str`): A string representation of an operator, such as `">"` or `"<="` version (`str`): @@ -728,8 +732,9 @@ def is_accelerate_version(operation: str, version: str): def is_peft_version(operation: str, version: str): """ - Args: Compares the current PEFT version to a given reference with an operation. + + Args: operation (`str`): A string representation of an operator, such as `">"` or `"<="` version (`str`): @@ -742,8 +747,9 @@ def is_peft_version(operation: str, version: str): def is_k_diffusion_version(operation: str, version: str): """ - Args: Compares the current k-diffusion version to a given reference with an operation. + + Args: operation (`str`): A string representation of an operator, such as `">"` or `"<="` version (`str`): @@ -756,8 +762,9 @@ def is_k_diffusion_version(operation: str, version: str): def get_objects_from_module(module): """ - Args: Returns a dict of object names and values in a module, while skipping private/internal objects + + Args: module (ModuleType): Module to extract the objects from. @@ -775,7 +782,9 @@ def get_objects_from_module(module): class OptionalDependencyNotAvailable(BaseException): - """An error indicating that an optional dependency of Diffusers was not found in the environment.""" + """ + An error indicating that an optional dependency of Diffusers was not found in the environment. 
+ """ class _LazyModule(ModuleType): From 1287822973e87fc8739a608b7296fe1797b1b074 Mon Sep 17 00:00:00 2001 From: Eliseu Silva Date: Mon, 7 Oct 2024 17:41:32 -0300 Subject: [PATCH 06/20] Fix for use_safetensors parameters, allow use of parameter on loading submodels (#9576) (#9587) * Fix for use_safetensors parameters, allow use of parameter on loading submodels (#9576) --- src/diffusers/pipelines/pipeline_loading_utils.py | 2 ++ src/diffusers/pipelines/pipeline_utils.py | 1 + 2 files changed, 3 insertions(+) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 318599f56063..0a744264b7a6 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -601,6 +601,7 @@ def load_sub_model( variant: str, low_cpu_mem_usage: bool, cached_folder: Union[str, os.PathLike], + use_safetensors: bool, ): """Helper method to load the module `name` from `library_name` and `class_name`""" @@ -670,6 +671,7 @@ def load_sub_model( loading_kwargs["offload_folder"] = offload_folder loading_kwargs["offload_state_dict"] = offload_state_dict loading_kwargs["variant"] = model_variants.pop(name, None) + loading_kwargs["use_safetensors"] = use_safetensors if from_flax: loading_kwargs["from_flax"] = True diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 6721706b5689..857a13147cfe 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -905,6 +905,7 @@ def load_module(name, value): variant=variant, low_cpu_mem_usage=low_cpu_mem_usage, cached_folder=cached_folder, + use_safetensors=use_safetensors, ) logger.info( f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}." From 63a5c8742a51a58083f7d9ba15518cce5bb1688b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 8 Oct 2024 08:03:51 +0530 Subject: [PATCH 07/20] Update distributed_inference.md to include `transformer.device_map` (#9553) * Update distributed_inference.md to include `transformer.device_map` * Update docs/source/en/training/distributed_inference.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/training/distributed_inference.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md index cd642d6aca07..0e1eb7962bf7 100644 --- a/docs/source/en/training/distributed_inference.md +++ b/docs/source/en/training/distributed_inference.md @@ -177,7 +177,7 @@ transformer = FluxTransformer2DModel.from_pretrained( ``` > [!TIP] -> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models. +> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models. You can also try `print(transformer.hf_device_map)` to see how the transformer model is sharded across devices. Add the transformer model to the pipeline for denoising, but set the other model-level components like the text encoders and VAE to `None` because you don't need them yet. 
From 66eef9a6dc8a97815a69fdf97aa20c8ece63d3f6 Mon Sep 17 00:00:00 2001 From: glide-the Date: Tue, 8 Oct 2024 15:22:52 +0800 Subject: [PATCH 08/20] fix: CogVideox train dataset _preprocess_data crop video (#9574) * Removed int8 to float32 conversion (`* 2.0 - 1.0`) from `train_transforms` as it caused image overexposure. Added `_resize_for_rectangle_crop` function to enable video cropping functionality. The cropping mode can be configured via `video_reshape_mode`, supporting options: ['center', 'random', 'none']. * The number 127.5 may experience precision loss during division operations. * wandb request pil image Type * Resizing bug * del jupyter * make style * Update examples/cogvideo/README.md * make style --------- Co-authored-by: --unset <--unset> Co-authored-by: Aryan --- examples/cogvideo/README.md | 1 + examples/cogvideo/train_cogvideox_lora.py | 91 +++++++++++++++++++---- 2 files changed, 78 insertions(+), 14 deletions(-) diff --git a/examples/cogvideo/README.md b/examples/cogvideo/README.md index 398ae9543150..8e57c2b19188 100644 --- a/examples/cogvideo/README.md +++ b/examples/cogvideo/README.md @@ -180,6 +180,7 @@ Note that setting the `` is not necessary. From some limited experimen > [!TIP] > You can pass `--use_8bit_adam` to reduce the memory requirements of training. +> You can pass `--video_reshape_mode` video cropping functionality, supporting options: ['center', 'random', 'none']. See [this](https://gist.github.com/glide-the/7658dbfd5f555be0a1a687a4139dba40) notebook for examples. > [!IMPORTANT] > The following settings have been tested at the time of adding CogVideoX LoRA training support: diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index 6787c37f93a8..2fc05bf692bb 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ b/examples/cogvideo/train_cogvideox_lora.py @@ -21,7 +21,9 @@ from pathlib import Path from typing import List, Optional, Tuple, Union +import numpy as np import torch +import torchvision.transforms as TT import transformers from accelerate import Accelerator from accelerate.logging import get_logger @@ -29,12 +31,14 @@ from huggingface_hub import create_repo, upload_folder from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict from torch.utils.data import DataLoader, Dataset -from torchvision import transforms +from torchvision.transforms import InterpolationMode +from torchvision.transforms.functional import resize from tqdm.auto import tqdm from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer import diffusers from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel +from diffusers.image_processor import VaeImageProcessor from diffusers.models.embeddings import get_3d_rotary_pos_embed from diffusers.optimization import get_scheduler from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid @@ -214,6 +218,12 @@ def get_args(): default=720, help="All input videos are resized to this width.", ) + parser.add_argument( + "--video_reshape_mode", + type=str, + default="center", + help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']", + ) parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.") parser.add_argument( "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames." 
@@ -413,6 +423,7 @@ def __init__( video_column: str = "video", height: int = 480, width: int = 720, + video_reshape_mode: str = "center", fps: int = 8, max_num_frames: int = 49, skip_frames_start: int = 0, @@ -429,6 +440,7 @@ def __init__( self.video_column = video_column self.height = height self.width = width + self.video_reshape_mode = video_reshape_mode self.fps = fps self.max_num_frames = max_num_frames self.skip_frames_start = skip_frames_start @@ -532,6 +544,38 @@ def _load_dataset_from_local_path(self): return instance_prompts, instance_videos + def _resize_for_rectangle_crop(self, arr): + image_size = self.height, self.width + reshape_mode = self.video_reshape_mode + if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: + arr = resize( + arr, + size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], + interpolation=InterpolationMode.BICUBIC, + ) + else: + arr = resize( + arr, + size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], + interpolation=InterpolationMode.BICUBIC, + ) + + h, w = arr.shape[2], arr.shape[3] + arr = arr.squeeze(0) + + delta_h = h - image_size[0] + delta_w = w - image_size[1] + + if reshape_mode == "random" or reshape_mode == "none": + top = np.random.randint(0, delta_h + 1) + left = np.random.randint(0, delta_w + 1) + elif reshape_mode == "center": + top, left = delta_h // 2, delta_w // 2 + else: + raise NotImplementedError + arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) + return arr + def _preprocess_data(self): try: import decord @@ -542,15 +586,14 @@ def _preprocess_data(self): decord.bridge.set_bridge("torch") - videos = [] - train_transforms = transforms.Compose( - [ - transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0), - ] + progress_dataset_bar = tqdm( + range(0, len(self.instance_video_paths)), + desc="Loading progress resize and crop videos", ) + videos = [] for filename in self.instance_video_paths: - video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height) + video_reader = decord.VideoReader(uri=filename.as_posix()) video_num_frames = len(video_reader) start_frame = min(self.skip_frames_start, video_num_frames) @@ -576,10 +619,16 @@ def _preprocess_data(self): assert (selected_num_frames - 1) % 4 == 0 # Training transforms - frames = frames.float() - frames = torch.stack([train_transforms(frame) for frame in frames], dim=0) - videos.append(frames.permute(0, 3, 1, 2).contiguous()) # [F, C, H, W] + frames = (frames - 127.5) / 127.5 + frames = frames.permute(0, 3, 1, 2) # [F, C, H, W] + progress_dataset_bar.set_description( + f"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}" + ) + frames = self._resize_for_rectangle_crop(frames) + videos.append(frames.contiguous()) # [F, C, H, W] + progress_dataset_bar.update(1) + progress_dataset_bar.close() return videos @@ -694,8 +743,13 @@ def log_validation( videos = [] for _ in range(args.num_validation_videos): - video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] - videos.append(video) + pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0] + pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])]) + + image_np = VaeImageProcessor.pt_to_numpy(pt_images) + image_pil = VaeImageProcessor.numpy_to_pil(image_np) + + videos.append(image_pil) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" @@ -1171,6 +1225,7 @@ 
def load_model_hook(models, input_dir): video_column=args.video_column, height=args.height, width=args.width, + video_reshape_mode=args.video_reshape_mode, fps=args.fps, max_num_frames=args.max_num_frames, skip_frames_start=args.skip_frames_start, @@ -1179,13 +1234,21 @@ def load_model_hook(models, input_dir): id_token=args.id_token, ) - def encode_video(video): + def encode_video(video, bar): + bar.update(1) video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0) video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] latent_dist = vae.encode(video).latent_dist return latent_dist - train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos] + progress_encode_bar = tqdm( + range(0, len(train_dataset.instance_videos)), + desc="Loading Encode videos", + ) + train_dataset.instance_videos = [ + encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos + ] + progress_encode_bar.close() def collate_fn(examples): videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples] From 02eeb8e77e5067eaf9fc1953c5e832ec894424aa Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 8 Oct 2024 21:47:44 +0530 Subject: [PATCH 09/20] [LoRA] Handle DoRA better (#9547) * handle dora. * print test * debug * fix * fix-copies * update logits * add warning in the test. * make is_dora check consistent. * fix-copies --- src/diffusers/loaders/lora_pipeline.py | 39 ++++++++++++++++++++++---- tests/lora/test_lora_layers_sdxl.py | 20 ++++++++----- 2 files changed, 46 insertions(+), 13 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index ba1435a8cbdc..8c8f2dfa84f8 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -99,7 +99,7 @@ def load_lora_weights( # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") @@ -211,6 +211,11 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
+ logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} network_alphas = None # TODO: replace it with a method from `state_dict_utils` @@ -562,7 +567,8 @@ def load_lora_weights( unet_config=self.unet.config, **kwargs, ) - is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + + is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") @@ -684,6 +690,11 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} network_alphas = None # TODO: replace it with a method from `state_dict_utils` @@ -1089,6 +1100,12 @@ def lora_state_dict( allow_pickle=allow_pickle, ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + return state_dict def load_lora_weights( @@ -1125,7 +1142,7 @@ def load_lora_weights( # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") @@ -1587,9 +1604,13 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions. 
- is_kohya = any(".lora_down.weight" in k for k in state_dict) if is_kohya: state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict) @@ -1659,7 +1680,7 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs ) - is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") @@ -2374,6 +2395,12 @@ def lora_state_dict( allow_pickle=allow_pickle, ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + return state_dict def load_lora_weights( @@ -2405,7 +2432,7 @@ def load_lora_weights( # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index 4ec7ef897485..8deecd770c31 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ b/tests/lora/test_lora_layers_sdxl.py @@ -33,8 +33,10 @@ StableDiffusionXLPipeline, T2IAdapter, ) +from diffusers.utils import logging from diffusers.utils.import_utils import is_accelerate_available from diffusers.utils.testing_utils import ( + CaptureLogger, load_image, nightly, numpy_cosine_similarity_distance, @@ -620,14 +622,18 @@ def test_integration_logits_for_dora_lora(self): pipeline.load_lora_weights("hf-internal-testing/dora-trained-on-kohya") pipeline.enable_model_cpu_offload() - images = pipeline( - "photo of ohwx dog", - num_inference_steps=10, - generator=torch.manual_seed(0), - output_type="np", - ).images + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(30) + with CaptureLogger(logger) as cap_logger: + images = pipeline( + "photo of ohwx dog", + num_inference_steps=10, + generator=torch.manual_seed(0), + output_type="np", + ).images + assert "It seems like you are using a DoRA checkpoint" in cap_logger.out predicted_slice = images[0, -3:, -3:, -1].flatten() - expected_slice_scale = np.array([0.3932, 0.3742, 0.4429, 0.3737, 0.3504, 0.433, 0.3948, 0.3769, 0.4516]) + expected_slice_scale = np.array([0.1817, 0.0697, 0.2346, 0.0900, 0.1261, 0.2279, 0.1767, 0.1991, 0.2886]) max_diff = numpy_cosine_similarity_distance(expected_slice_scale, predicted_slice) assert max_diff < 1e-3 From 86bd991ee5c9c669e22f09693d68b60d0ec59dd1 Mon Sep 17 00:00:00 2001 From: v2ray <60914079+LagPixelLOL@users.noreply.github.com> Date: Wed, 9 Oct 2024 03:27:10 +0800 Subject: [PATCH 10/20] Fixed noise_pred_text referenced before assignment. (#9537) * Fixed local variable noise_pred_text referenced before assignment when using PAG with guidance scale and guidance rescale at the same time. * Fixed style. * Made returning text pred noise an argument. 
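
A minimal sketch of the scenario this fixes, assuming an SDXL pipeline with PAG enabled
(the model id and scale values are illustrative): combining `pag_scale` with
`guidance_rescale` previously referenced `noise_pred_text` before assignment.

```python
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    enable_pag=True,
    torch_dtype=torch.float16,
).to("cuda")

# guidance_rescale > 0 routes through rescale_noise_cfg, which consumes
# noise_pred_text; before this patch that variable was unbound whenever the
# PAG branch was taken, raising UnboundLocalError.
image = pipe(
    "a photo of an astronaut riding a horse",
    guidance_scale=7.0,
    guidance_rescale=0.7,
    pag_scale=3.0,
).images[0]
```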
--- src/diffusers/pipelines/pag/pag_utils.py | 10 ++++++++-- src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py | 4 ++-- src/diffusers/pipelines/pag/pipeline_pag_sd.py | 4 ++-- src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py | 4 ++-- .../pipelines/pag/pipeline_pag_sd_xl_img2img.py | 4 ++-- .../pipelines/pag/pipeline_pag_sd_xl_inpaint.py | 4 ++-- 6 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index 728f730c9904..7a6e30a3c6be 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -98,7 +98,9 @@ def _get_pag_scale(self, t): else: return self.pag_scale - def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t): + def _apply_perturbed_attention_guidance( + self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False + ): r""" Apply perturbed attention guidance to the noise prediction. @@ -107,9 +109,11 @@ def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_gui do_classifier_free_guidance (bool): Whether to apply classifier-free guidance. guidance_scale (float): The scale factor for the guidance term. t (int): The current time step. + return_pred_text (bool): Whether to return the text noise prediction. Returns: - torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance. + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applying + perturbed attention guidance and the text noise prediction. """ pag_scale = self._get_pag_scale(t) if do_classifier_free_guidance: @@ -122,6 +126,8 @@ def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_gui else: noise_pred_text, noise_pred_perturb = noise_pred.chunk(2) noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) + if return_pred_text: + return noise_pred, noise_pred_text return noise_pred def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance): diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py index 63126cc5aae9..4663db3a15a1 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -893,8 +893,8 @@ def __call__( # perform guidance if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True ) elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_sd.py index c6a4f7f42c84..e9742b08af50 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd.py @@ -993,8 +993,8 @@ def __call__( # perform guidance if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True ) elif 
self.do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py index 18fc06c1f9b8..8da4349594b4 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py @@ -1237,8 +1237,8 @@ def __call__( # perform guidance if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True ) elif self.do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py index dc85aaaca37f..4c2c4e5aa3fa 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -1437,8 +1437,8 @@ def denoising_value_valid(dnv): # perform guidance if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True ) elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py index f5ebf4300934..49e4c5ffd50c 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py @@ -1649,8 +1649,8 @@ def denoising_value_valid(dnv): # perform guidance if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True ) elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) From acd6d2c42f0fa4fade262e8814279748a544b0ce Mon Sep 17 00:00:00 2001 From: sanaka <50254737+HorizonWind2004@users.noreply.github.com> Date: Wed, 9 Oct 2024 05:25:48 +0800 Subject: [PATCH 11/20] Fix the bug that `joint_attention_kwargs` is not passed to the FLUX's transformer attention processors (#9517) * Update transformer_flux.py --- src/diffusers/models/transformers/transformer_flux.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index e38efe668c6c..6238ab8044bb 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -83,14 +83,16 @@ def forward( hidden_states: torch.FloatTensor, temb: torch.FloatTensor, image_rotary_emb=None, + joint_attention_kwargs=None, ): residual = hidden_states norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) - + joint_attention_kwargs = joint_attention_kwargs or {} attn_output = self.attn( hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, ) hidden_states = 
torch.cat([attn_output, mlp_hidden_states], dim=2) @@ -161,18 +163,20 @@ def forward( encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, image_rotary_emb=None, + joint_attention_kwargs=None, ): norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( encoder_hidden_states, emb=temb ) - + joint_attention_kwargs = joint_attention_kwargs or {} # Attention. attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, ) # Process attention outputs for the `hidden_states`. @@ -497,6 +501,7 @@ def custom_forward(*inputs): encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, ) # controlnet residual @@ -533,6 +538,7 @@ def custom_forward(*inputs): hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, ) # controlnet residual From ec9e5264c0865c739602f6cab0bef118892a50f3 Mon Sep 17 00:00:00 2001 From: Yijun Lee <119404328+yijun-lee@users.noreply.github.com> Date: Wed, 9 Oct 2024 11:11:13 +0900 Subject: [PATCH 12/20] refac/pipeline_output (#9582) --- .../pipelines/deepfloyd_if/pipeline_output.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_output.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_output.py index 7f39ab5ba70c..b8bae89cec03 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_output.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_output.py @@ -9,16 +9,17 @@ @dataclass class IFPipelineOutput(BaseOutput): - """ - Args: + r""" Output class for Stable Diffusion pipelines. - images (`List[PIL.Image.Image]` or `np.ndarray`) + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`): List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. - nsfw_detected (`List[bool]`) + nsfw_detected (`List[bool]`): List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content or a watermark. `None` if safety checking could not be performed. - watermark_detected (`List[bool]`) + watermark_detected (`List[bool]`): List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety checking could not be performed. """ From 31058cdaef63ca660a1a045281d156239fba8192 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 9 Oct 2024 10:57:16 +0530 Subject: [PATCH 13/20] [LoRA] allow loras to be loaded with low_cpu_mem_usage. (#9510) * allow loras to be loaded with low_cpu_mem_usage. * add flux support but note https://github.com/huggingface/diffusers/pull/9510\#issuecomment-2378316687 * low_cpu_mem_usage. * fix-copies * fix-copies again * tests * _LOW_CPU_MEM_USAGE_DEFAULT_LORA * _peft_version default. * version checks. * version check. * version check. * version check. * require peft 0.13.1. * explicitly specify low_cpu_mem_usage=False. * docs. * transformers version 4.45.2. * update * fix * empty * better name initialize_dummy_state_dict. * doc todos. 
* Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * style * fix-copies --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- .../en/tutorials/using_peft_for_inference.md | 6 + src/diffusers/loaders/lora_pipeline.py | 251 ++++++++++++++++-- src/diffusers/loaders/unet.py | 21 +- src/diffusers/utils/testing_utils.py | 18 ++ tests/lora/utils.py | 139 +++++++++- 5 files changed, 411 insertions(+), 24 deletions(-) diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md index 907f93d573a0..615af55ef5b5 100644 --- a/docs/source/en/tutorials/using_peft_for_inference.md +++ b/docs/source/en/tutorials/using_peft_for_inference.md @@ -75,6 +75,12 @@ image ![pixel-art](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_12_1.png) + + +By default, if the most up-to-date versions of PEFT and Transformers are detected, `low_cpu_mem_usage` is set to `True` to speed up the loading time of LoRA checkpoints. + + + ## Merge adapters You can also merge different adapter checkpoints for inference to blend their styles together. diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 8c8f2dfa84f8..2037bd787433 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -25,8 +25,11 @@ deprecate, get_adapter_name, get_peft_kwargs, + is_peft_available, is_peft_version, + is_torch_version, is_transformers_available, + is_transformers_version, logging, scale_lora_layers, ) @@ -39,6 +42,17 @@ ) +_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False +if is_torch_version(">=", "1.9.0"): + if ( + is_peft_available() + and is_peft_version(">=", "0.13.1") + and is_transformers_available() + and is_transformers_version(">", "4.45.2") + ): + _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True + + if is_transformers_available(): from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules @@ -83,15 +97,24 @@ def load_lora_weights( Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." 
+ ) + # if a dict is passed, copy it instead of modifying it inplace if isinstance(pretrained_model_name_or_path_or_dict, dict): pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() @@ -109,6 +132,7 @@ def load_lora_weights( unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) self.load_lora_into_text_encoder( state_dict, @@ -119,6 +143,7 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) @classmethod @@ -237,7 +262,9 @@ def lora_state_dict( return state_dict, network_alphas @classmethod - def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None): + def load_lora_into_unet( + cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -255,10 +282,16 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as # their prefixes. @@ -268,7 +301,11 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None # Load the layers corresponding to UNet. logger.info(f"Loading {cls.unet_name}.") unet.load_attn_procs( - state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, ) @classmethod @@ -281,6 +318,7 @@ def load_lora_into_text_encoder( lora_scale=1.0, adapter_name=None, _pipeline=None, + low_cpu_mem_usage=False, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -303,10 +341,25 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + peft_kwargs = {} + if low_cpu_mem_usage: + if not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + if not is_transformers_version(">", "4.45.2"): + # Note from sayakpaul: It's not in `transformers` stable yet. 
+ # https://github.com/huggingface/transformers/pull/33725/ + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." + ) + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + from peft import LoraConfig # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), @@ -377,6 +430,7 @@ def load_lora_into_text_encoder( adapter_name=adapter_name, adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config, + **peft_kwargs, ) # scale LoRA layers with `lora_scale` @@ -547,12 +601,19 @@ def load_lora_weights( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + # We could have accessed the unet config from `lora_state_dict()` too. We pass # it here explicitly to be able to tell that it's coming from an SDXL # pipeline. @@ -573,7 +634,12 @@ def load_lora_weights( raise ValueError("Invalid LoRA checkpoint.") self.load_lora_into_unet( - state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self + state_dict, + network_alphas=network_alphas, + unet=self.unet, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} if len(text_encoder_state_dict) > 0: @@ -585,6 +651,7 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} @@ -597,6 +664,7 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) @classmethod @@ -717,7 +785,9 @@ def lora_state_dict( @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet - def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None): + def load_lora_into_unet( + cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -735,10 +805,16 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. 
""" if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as # their prefixes. @@ -748,7 +824,11 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None # Load the layers corresponding to UNet. logger.info(f"Loading {cls.unet_name}.") unet.load_attn_procs( - state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, ) @classmethod @@ -762,6 +842,7 @@ def load_lora_into_text_encoder( lora_scale=1.0, adapter_name=None, _pipeline=None, + low_cpu_mem_usage=False, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -784,10 +865,25 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + peft_kwargs = {} + if low_cpu_mem_usage: + if not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + if not is_transformers_version(">", "4.45.2"): + # Note from sayakpaul: It's not in `transformers` stable yet. + # https://github.com/huggingface/transformers/pull/33725/ + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." + ) + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + from peft import LoraConfig # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), @@ -858,6 +954,7 @@ def load_lora_into_text_encoder( adapter_name=adapter_name, adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config, + **peft_kwargs, ) # scale LoRA layers with `lora_scale` @@ -1126,15 +1223,22 @@ def load_lora_weights( Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. 
""" if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + # if a dict is passed, copy it instead of modifying it inplace if isinstance(pretrained_model_name_or_path_or_dict, dict): pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() @@ -1151,6 +1255,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} @@ -1163,6 +1268,7 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} @@ -1175,10 +1281,13 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) @classmethod - def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1192,7 +1301,13 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict keys = list(state_dict.keys()) @@ -1236,8 +1351,12 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, # otherwise loading LoRA weights will lead to an error is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + peft_kwargs = {} + if is_peft_version(">=", "0.13.1"): + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs) if incompatible_keys is not None: # check only for unexpected keys @@ -1266,6 +1385,7 @@ def load_lora_into_text_encoder( lora_scale=1.0, adapter_name=None, _pipeline=None, + low_cpu_mem_usage=False, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1288,10 +1408,25 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. 
If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + peft_kwargs = {} + if low_cpu_mem_usage: + if not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + if not is_transformers_version(">", "4.45.2"): + # Note from sayakpaul: It's not in `transformers` stable yet. + # https://github.com/huggingface/transformers/pull/33725/ + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." + ) + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + from peft import LoraConfig # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), @@ -1362,6 +1497,7 @@ def load_lora_into_text_encoder( adapter_name=adapter_name, adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config, + **peft_kwargs, ) # scale LoRA layers with `lora_scale` @@ -1667,10 +1803,17 @@ def load_lora_weights( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + # if a dict is passed, copy it instead of modifying it inplace if isinstance(pretrained_model_name_or_path_or_dict, dict): pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() @@ -1690,6 +1833,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} @@ -1702,10 +1846,13 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) @classmethod - def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None): + def load_lora_into_transformer( + cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1723,7 +1870,13 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict keys = list(state_dict.keys()) @@ -1772,8 +1925,12 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada # otherwise loading LoRA weights will lead to an error is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + peft_kwargs = {} + if is_peft_version(">=", "0.13.1"): + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs) if incompatible_keys is not None: # check only for unexpected keys @@ -1802,6 +1959,7 @@ def load_lora_into_text_encoder( lora_scale=1.0, adapter_name=None, _pipeline=None, + low_cpu_mem_usage=False, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1824,10 +1982,25 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + peft_kwargs = {} + if low_cpu_mem_usage: + if not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + if not is_transformers_version(">", "4.45.2"): + # Note from sayakpaul: It's not in `transformers` stable yet. + # https://github.com/huggingface/transformers/pull/33725/ + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." + ) + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + from peft import LoraConfig # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), @@ -1898,6 +2071,7 @@ def load_lora_into_text_encoder( adapter_name=adapter_name, adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config, + **peft_kwargs, ) # scale LoRA layers with `lora_scale` @@ -2132,6 +2306,7 @@ def load_lora_into_text_encoder( lora_scale=1.0, adapter_name=None, _pipeline=None, + low_cpu_mem_usage=False, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -2154,10 +2329,25 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + peft_kwargs = {} + if low_cpu_mem_usage: + if not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + if not is_transformers_version(">", "4.45.2"): + # Note from sayakpaul: It's not in `transformers` stable yet. + # https://github.com/huggingface/transformers/pull/33725/ + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." + ) + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + from peft import LoraConfig # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), @@ -2228,6 +2418,7 @@ def load_lora_into_text_encoder( adapter_name=adapter_name, adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config, + **peft_kwargs, ) # scale LoRA layers with `lora_scale` @@ -2416,15 +2607,22 @@ def load_lora_weights( Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + # if a dict is passed, copy it instead of modifying it inplace if isinstance(pretrained_model_name_or_path_or_dict, dict): pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() @@ -2441,11 +2639,14 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer - def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -2459,7 +2660,13 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict keys = list(state_dict.keys()) @@ -2503,8 +2710,12 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, # otherwise loading LoRA weights will lead to an error is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + peft_kwargs = {} + if is_peft_version(">=", "0.13.1"): + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs) if incompatible_keys is not None: # check only for unexpected keys
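At the pipeline level, the `lora_pipeline.py` changes above all reduce to a single new keyword argument. A hedged usage sketch (the base checkpoint and LoRA repo ids below are placeholders; `peft>=0.13.1` is assumed, plus `transformers>4.45.2` when the LoRA also targets the text encoders):

```py
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # placeholder base checkpoint
    torch_dtype=torch.float16,
).to("cuda")

# Loads the checkpoint's LoRA weights directly instead of first materializing
# randomly initialized LoRA layers and then overwriting them.
pipe.load_lora_weights("some-user/some-lora", low_cpu_mem_usage=True)  # placeholder repo id
```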
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 32ace77b6224..eaac52df6202 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -115,6 +115,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict `default_{i}` where i is the total number of adapters being loaded. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. Example: @@ -142,8 +145,14 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict adapter_name = kwargs.pop("adapter_name", None) _pipeline = kwargs.pop("_pipeline", None) network_alphas = kwargs.pop("network_alphas", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False) allow_pickle = False + if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + if use_safetensors is None: use_safetensors = True allow_pickle = True @@ -209,6 +218,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, ) else: raise ValueError( @@ -268,7 +278,9 @@ def _process_custom_diffusion(self, state_dict): return attn_processors - def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline): + def _process_lora( + self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, low_cpu_mem_usage + ): # This method does the following things: # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy # format. For legacy format no filtering is applied. @@ -335,9 +347,12 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks # otherwise loading LoRA weights will lead to an error is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline) + peft_kwargs = {} + if is_peft_version(">=", "0.13.1"): + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - inject_adapter_in_model(lora_config, self, adapter_name=adapter_name) - incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name) + inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs) + incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) if incompatible_keys is not None: # check only for unexpected keys
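The `unet.py` changes above expose the same flag one level lower, on the attention-processor loader; a brief sketch under the same assumptions (repo ids are placeholders):

```py
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder base checkpoint
    subfolder="unet",
    torch_dtype=torch.float16,
)
# `load_attn_procs` forwards `low_cpu_mem_usage` to `_process_lora`, which only
# passes it through to `peft` when a new enough `peft` version is installed.
unet.load_attn_procs("some-user/some-unet-lora", low_cpu_mem_usage=True)  # placeholder repo id
```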
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 7dc3f414d55c..a2f283d0c4f5 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -388,6 +388,24 @@ def decorator(test_case): return decorator +def require_transformers_version_greater(transformers_version): + """ + Decorator marking a test that requires a `transformers` version greater than the one specified. Such tests + typically also depend on specific versions of `peft` and `transformers`. + """ + + def decorator(test_case): + correct_transformers_version = is_transformers_available() and version.parse( + version.parse(importlib.metadata.version("transformers")).base_version + ) > version.parse(transformers_version) + return unittest.skipUnless( + correct_transformers_version, + f"test requires a transformers version greater than {transformers_version}", + )(test_case) + + return decorator + + def require_accelerate_version_greater(accelerate_version): def decorator(test_case): correct_accelerate_version = is_peft_available() and version.parse( diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 5def867324f4..9c982e8de37f 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -32,13 +32,14 @@ floats_tensor, require_peft_backend, require_peft_version_greater, + require_transformers_version_greater, skip_mps, torch_device, ) if is_peft_available(): - from peft import LoraConfig + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer from peft.utils import get_peft_model_state_dict @@ -65,6 +66,12 @@ def check_if_lora_correctly_set(model) -> bool: return False +def initialize_dummy_state_dict(state_dict): + if not all(v.device.type == "meta" for _, v in state_dict.items()): + raise ValueError("`state_dict` has non-meta values.") + return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()} + + @require_peft_backend class PeftLoraLoaderMixinTests: pipeline_class = None @@ -272,6 +279,136 @@ def test_simple_inference_with_text_lora(self): not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" ) + @require_peft_version_greater("0.13.1") + def test_low_cpu_mem_usage_with_injection(self): + """Tests if we can inject LoRA state dict with low_cpu_mem_usage.""" + for scheduler_cls in self.scheduler_classes: + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + if "text_encoder" in
self.pipeline_class._lora_loadable_modules: + inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder." + ) + self.assertTrue( + "meta" in {p.device.type for p in pipe.text_encoder.parameters()}, + "The LoRA params should be on 'meta' device.", + ) + + te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder)) + set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True) + self.assertTrue( + "meta" not in {p.device.type for p in pipe.text_encoder.parameters()}, + "No param should be on 'meta' device.", + ) + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + self.assertTrue( + "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device." + ) + + denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser)) + set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True) + self.assertTrue( + "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device." + ) + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + self.assertTrue( + "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()}, + "The LoRA params should be on 'meta' device.", + ) + + te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2)) + set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True) + self.assertTrue( + "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()}, + "No param should be on 'meta' device.", + ) + + _, _, inputs = self.get_dummy_inputs() + output_lora = pipe(**inputs)[0] + self.assertTrue(output_lora.shape == self.output_shape) + + @require_peft_version_greater("0.13.1") + @require_transformers_version_greater("4.45.1") + def test_low_cpu_mem_usage_with_loading(self): + """Tests if we can load LoRA state dict with low_cpu_mem_usage.""" + + for scheduler_cls in self.scheduler_classes: + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(output_no_lora.shape == self.output_shape) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + ) + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in 
self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + + with tempfile.TemporaryDirectory() as tmpdirname: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts + ) + + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False) + + for module_name, module in modules_to_save.items(): + self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") + + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue( + np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), + "Loading from saved checkpoints should give same results.", + ) + + # Now, check for `low_cpu_mem_usage.` + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True) + + for module_name, module in modules_to_save.items(): + self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") + + images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue( + np.allclose( + images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3 + ), + "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.", + ) + def test_simple_inference_with_text_lora_and_scale(self): """ Tests a simple inference with lora attached on the text encoder + scale argument From af28ae2d5ba0ef80d99fff7859ebea730e1cf3f8 Mon Sep 17 00:00:00 2001 From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com> Date: Thu, 10 Oct 2024 02:10:58 +0530 Subject: [PATCH 14/20] add PAG support for SD Img2Img (#9463) * added pag to sd img2img pipeline --------- Co-authored-by: YiYi Xu --- docs/source/en/api/pipelines/pag.md | 5 + src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/auto_pipeline.py | 2 + src/diffusers/pipelines/pag/__init__.py | 2 + .../pipelines/pag/pipeline_pag_sd_img2img.py | 1091 +++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + tests/pipelines/pag/test_pag_sd_img2img.py | 282 +++++ 8 files changed, 1401 insertions(+) create mode 100644 src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py create mode 100644 tests/pipelines/pag/test_pag_sd_img2img.py diff --git a/docs/source/en/api/pipelines/pag.md b/docs/source/en/api/pipelines/pag.md index 8e3c82ea9e27..cc6d075f457f 100644 --- a/docs/source/en/api/pipelines/pag.md +++ b/docs/source/en/api/pipelines/pag.md @@ -53,6 +53,11 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial - all - __call__ +## StableDiffusionPAGImg2ImgPipeline +[[autodoc]] StableDiffusionPAGImg2ImgPipeline + - all + - __call__ + ## StableDiffusionControlNetPAGPipeline [[autodoc]] StableDiffusionControlNetPAGPipeline diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 4214a4699ec8..fadb234c6e10 100644 --- a/src/diffusers/__init__.py +++ 
b/src/diffusers/__init__.py @@ -344,6 +344,7 @@ "StableDiffusionLatentUpscalePipeline", "StableDiffusionLDM3DPipeline", "StableDiffusionModelEditingPipeline", + "StableDiffusionPAGImg2ImgPipeline", "StableDiffusionPAGPipeline", "StableDiffusionPanoramaPipeline", "StableDiffusionParadigmsPipeline", @@ -795,6 +796,7 @@ StableDiffusionLatentUpscalePipeline, StableDiffusionLDM3DPipeline, StableDiffusionModelEditingPipeline, + StableDiffusionPAGImg2ImgPipeline, StableDiffusionPAGPipeline, StableDiffusionPanoramaPipeline, StableDiffusionParadigmsPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3b6cde17c8a3..45a868eb5810 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -164,6 +164,7 @@ "HunyuanDiTPAGPipeline", "StableDiffusion3PAGPipeline", "StableDiffusionPAGPipeline", + "StableDiffusionPAGImg2ImgPipeline", "StableDiffusionControlNetPAGPipeline", "StableDiffusionXLPAGPipeline", "StableDiffusionXLPAGInpaintPipeline", @@ -569,6 +570,7 @@ StableDiffusion3PAGPipeline, StableDiffusionControlNetPAGInpaintPipeline, StableDiffusionControlNetPAGPipeline, + StableDiffusionPAGImg2ImgPipeline, StableDiffusionPAGPipeline, StableDiffusionXLControlNetPAGImg2ImgPipeline, StableDiffusionXLControlNetPAGPipeline, diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index e3e78d0663fa..ebeaf9f8aeec 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -63,6 +63,7 @@ StableDiffusion3PAGPipeline, StableDiffusionControlNetPAGInpaintPipeline, StableDiffusionControlNetPAGPipeline, + StableDiffusionPAGImg2ImgPipeline, StableDiffusionPAGPipeline, StableDiffusionXLControlNetPAGImg2ImgPipeline, StableDiffusionXLControlNetPAGPipeline, @@ -131,6 +132,7 @@ ("kandinsky22", KandinskyV22Img2ImgCombinedPipeline), ("kandinsky3", Kandinsky3Img2ImgPipeline), ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline), + ("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline), ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline), diff --git a/src/diffusers/pipelines/pag/__init__.py b/src/diffusers/pipelines/pag/__init__.py index a7ceb7e296d5..6a6723b58ca9 100644 --- a/src/diffusers/pipelines/pag/__init__.py +++ b/src/diffusers/pipelines/pag/__init__.py @@ -32,6 +32,7 @@ _import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"] _import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"] _import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"] + _import_structure["pipeline_pag_sd_img2img"] = ["StableDiffusionPAGImg2ImgPipeline"] _import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"] _import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"] _import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"] @@ -54,6 +55,7 @@ from .pipeline_pag_sd import StableDiffusionPAGPipeline from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline + from .pipeline_pag_sd_img2img import StableDiffusionPAGImg2ImgPipeline from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline from 
.pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py new file mode 100644 index 000000000000..e4f26494d5c3 --- /dev/null +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py @@ -0,0 +1,1091 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .pag_utils import PAGMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForImage2Image + >>> from diffusers.utils import load_image + + >>> pipe = AutoPipelineForImage2Image.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... 
) + >>> pipe = pipe.to("cuda") + >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png" + + >>> init_image = load_image(url).convert("RGB") + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, image=init_image, pag_scale=0.3).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler."
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionPAGImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, + PAGMixin, +): + r""" + Pipeline for text-guided image-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + pag_applied_layers: Union[str, List[str]] = "mid", # ["mid"], ["down.block_1", "up.block_0.attentions_0"] + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. 
Please make sure " + "to update the config accordingly as leaving `steps_offset` might lead to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions.
If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. 
+ """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. 
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, 
output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # 
Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot have both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicated to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning."
+ ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance.
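+ # Note: classifier free guidance is also disabled when the UNet was distilled with a guidance embedding + # (`time_cond_proj_dim` is not None, e.g. LCM-style models); in that case the guidance signal is passed + # through `timestep_cond` instead of a second forward pass (see `get_guidance_scale_embedding` above and + # step 7.2 of `__call__`).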
+ @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: int = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list + of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image + latents as `image`, but if latents are passed directly they are not encoded again. + strength (`float`, *optional*, defaults to 0.8): + Indicates the extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of the same length as the number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during inference with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
`callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + strength, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. 
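+ # (when perturbed-attention guidance is enabled, a perturbed batch is stacked on top of these as well)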
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. set timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 8. 
Denoising loop + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + self._num_timesteps = len(timesteps) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + if ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 1927fc8cd4d3..fee69d01ebff 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1592,6 +1592,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) + +class StableDiffusionPAGImg2ImgPipeline(metaclass=DummyObject): +
_backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionPAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/pag/test_pag_sd_img2img.py b/tests/pipelines/pag/test_pag_sd_img2img.py new file mode 100644 index 000000000000..ec8cde23c31d --- /dev/null +++ b/tests/pipelines/pag/test_pag_sd_img2img.py @@ -0,0 +1,282 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import inspect +import random +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + AutoencoderTiny, + AutoPipelineForImage2Image, + EulerDiscreteScheduler, + StableDiffusionImg2ImgPipeline, + StableDiffusionPAGImg2ImgPipeline, + UNet2DConditionModel, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + load_image, + require_torch_gpu, + slow, + torch_device, +) + +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, + TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, +) +from ..test_pipelines_common import ( + IPAdapterTesterMixin, + PipelineKarrasSchedulerTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, +) + + +enable_full_determinism() + + +class StableDiffusionPAGImg2ImgPipelineFastTests( + IPAdapterTesterMixin, + PipelineLatentTesterMixin, + PipelineKarrasSchedulerTesterMixin, + PipelineTesterMixin, + unittest.TestCase, +): + pipeline_class = StableDiffusionPAGImg2ImgPipeline + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS.union({"pag_scale", "pag_adaptive_scale"}) - {"height", "width"} + required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS + + def get_dummy_components(self, time_cond_proj_dim=None): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + time_cond_proj_dim=time_cond_proj_dim, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + scheduler = EulerDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + steps_offset=1, + beta_schedule="scaled_linear", + timestep_spacing="leading", + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + 
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + "image_encoder": None, + } + return components + + def get_dummy_tiny_autoencoder(self): + return AutoencoderTiny(in_channels=3, out_channels=3, latent_channels=4) + + def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image / 2 + 0.5 + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "pag_scale": 0.9, + "output_type": "np", + } + return inputs + + def test_pag_disable_enable(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + + # base pipeline (expect same output when pag is disabled) + pipe_sd = StableDiffusionImg2ImgPipeline(**components) + pipe_sd = pipe_sd.to(device) + pipe_sd.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["pag_scale"] + assert ( + "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters + ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." 
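+ # reference slice from the plain img2img pipeline: the PAG pipeline below should match it when + # pag_scale=0.0 and diverge from it once PAG is enabled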
+ out = pipe_sd(**inputs).images[0, -3:, -3:, -1] + + # pag disabled with pag_scale=0.0 + pipe_pag = self.pipeline_class(**components) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["pag_scale"] = 0.0 + out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] + + # pag enabled + pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"]) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] + + assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3 + assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3 + + def test_pag_inference(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + + pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"]) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe_pag(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == ( + 1, + 32, + 32, + 3, + ), f"the shape of the output image should be (1, 32, 32, 3) but got {image.shape}" + + expected_slice = np.array( + [0.44203848, 0.49598145, 0.42248967, 0.6707724, 0.5683791, 0.43603387, 0.58316565, 0.60077155, 0.5174199] + ) + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + self.assertLessEqual(max_diff, 1e-3) + + +@slow +@require_torch_gpu +class StableDiffusionPAGImg2ImgPipelineIntegrationTests(unittest.TestCase): + pipeline_class = StableDiffusionPAGImg2ImgPipeline + repo_id = "Jiali/stable-diffusion-1.5" + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0, guidance_scale=7.5): + generator = torch.Generator(device=generator_device).manual_seed(seed) + init_image = load_image( + "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main" + "/stable_diffusion_img2img/sketch-mountains-input.png" + ) + inputs = { + "prompt": "a fantasy landscape, concept art, high resolution", + "image": init_image, + "generator": generator, + "num_inference_steps": 3, + "strength": 0.75, + "guidance_scale": guidance_scale, + "pag_scale": 3.0, + "output_type": "np", + } + return inputs + + def test_pag_cfg(self): + pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16) + pipeline.enable_model_cpu_offload() + pipeline.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(torch_device) + image = pipeline(**inputs).images + + image_slice = image[0, -3:, -3:, -1].flatten() + assert image.shape == (1, 512, 512, 3) + print(image_slice.flatten()) + expected_slice = np.array( + [0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484] + ) + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + ), f"output is different from expected, {image_slice.flatten()}" + + def test_pag_uncond(self): + pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16) + pipeline.enable_model_cpu_offload() + pipeline.set_progress_bar_config(disable=None) + + inputs =
self.get_inputs(torch_device, guidance_scale=0.0) + image = pipeline(**inputs).images + + image_slice = image[0, -3:, -3:, -1].flatten() + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array( + [0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867] + ) + print(image_slice.flatten()) + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + ), f"output is different from expected, {image_slice.flatten()}" From 07bd2fabb60dbd1c7da6fd176c62e592402b19e3 Mon Sep 17 00:00:00 2001 From: Pakkapon Phongthawee Date: Thu, 10 Oct 2024 05:03:13 +0700 Subject: [PATCH 15/20] make controlnet support interrupt (#9620) * make controlnet support interrupt * remove white space in controlnet interrupt --- src/diffusers/pipelines/controlnet/pipeline_controlnet.py | 8 ++++++++ .../pipelines/controlnet/pipeline_controlnet_img2img.py | 8 ++++++++ .../pipelines/controlnet/pipeline_controlnet_inpaint.py | 8 ++++++++ .../controlnet/pipeline_controlnet_inpaint_sd_xl.py | 8 ++++++++ .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 8 ++++++++ .../controlnet/pipeline_controlnet_sd_xl_img2img.py | 8 ++++++++ 6 files changed, 48 insertions(+) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 9b2fefe7b0a4..60ad5eda8e0f 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -893,6 +893,10 @@ def cross_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1089,6 +1093,7 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1235,6 +1240,9 @@ def __call__( is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 2a4f46d61990..4cdec5b3cf5f 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -891,6 +891,10 @@ def cross_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1081,6 +1085,7 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1211,6 +1216,9 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 9f7d464f9a91..da5a02d14108 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -976,6 +976,10 @@ def cross_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1191,6 +1195,7 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1375,6 +1380,9 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 17fd2cb6c81d..496ad8d73c1d 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -1145,6 +1145,10 @@ def cross_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1427,6 +1431,7 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1695,6 +1700,9 @@ def denoising_value_valid(dnv): with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index fdebcdf83641..e480a87a70ce 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -990,6 +990,10 @@ def denoising_end(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1245,6 +1249,7 @@ def __call__( self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1442,6 +1447,9 @@ def __call__( is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index af19f3c309f8..21cd87f7570e 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -1070,6 +1070,10 @@ def cross_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1338,6 +1342,7 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1510,6 +1515,9 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) From e16fd93d0a40156c1f49fde07f6f2eb438983927 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 10 Oct 2024 11:47:49 +0530 Subject: [PATCH 16/20] [LoRA] fix dora test to catch the warning properly. (#9627) fix dora test. 
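The warning under test is emitted while the DoRA weights are being loaded, so the `load_lora_weights` call has to happen inside the `CaptureLogger` context for the test to see it. A minimal sketch of the pattern (not the full test; model and LoRA repos as used in the test below):

```py
from diffusers import StableDiffusionXLPipeline
from diffusers.utils import logging
from diffusers.utils.testing_utils import CaptureLogger

pipeline = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(30)  # only capture WARNING and above

with CaptureLogger(logger) as cap_logger:
    # the DoRA warning fires here, during loading, so this must sit inside the context
    pipeline.load_lora_weights("hf-internal-testing/dora-trained-on-kohya")
```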
--- tests/lora/test_lora_layers_sdxl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index 8deecd770c31..94a44ed8f9ec 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ b/tests/lora/test_lora_layers_sdxl.py @@ -619,12 +619,12 @@ def test_integration_logits_multi_adapter(self): @nightly def test_integration_logits_for_dora_lora(self): pipeline = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") - pipeline.load_lora_weights("hf-internal-testing/dora-trained-on-kohya") - pipeline.enable_model_cpu_offload() logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(30) with CaptureLogger(logger) as cap_logger: + pipeline.load_lora_weights("hf-internal-testing/dora-trained-on-kohya") + pipeline.enable_model_cpu_offload() images = pipeline( "photo of ohwx dog", num_inference_steps=10, From 38a3e4df926c59bc122191c0fc8066755e98b6d2 Mon Sep 17 00:00:00 2001 From: Subho Ghosh <93722719+ighoshsubho@users.noreply.github.com> Date: Thu, 10 Oct 2024 17:59:02 +0530 Subject: [PATCH 17/20] flux controlnet control_guidance_start and control_guidance_end implement (#9571) * flux controlnet control_guidance_start and control_guidance_end implement * minor fix - added docstrings, consistent controlnet scale flux and SD3 --- .../flux/pipeline_flux_controlnet.py | 42 +++++++++++++++++-- ...pipeline_flux_controlnet_image_to_image.py | 35 +++++++++++++++- .../pipeline_flux_controlnet_inpainting.py | 37 +++++++++++++++- 3 files changed, 108 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 6c072c482020..a301f6742c05 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -72,7 +72,9 @@ >>> image = pipe( ... prompt, ... control_image=control_image, - ... controlnet_conditioning_scale=0.6, + ... control_guidance_start=0.2, + ... control_guidance_end=0.8, + ... controlnet_conditioning_scale=1.0, ... num_inference_steps=28, ... guidance_scale=3.5, ... ).images[0] @@ -572,6 +574,8 @@ def __call__( num_inference_steps: int = 28, timesteps: List[int] = None, guidance_scale: float = 7.0, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, control_image: PipelineImageInput = None, control_mode: Optional[Union[int, List[int]]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, @@ -614,6 +618,10 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. 
If the type is @@ -674,6 +682,17 @@ def __call__( height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, @@ -839,7 +858,16 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - # 6. Denoising loop + # 6. Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps) + + # 7. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -856,12 +884,20 @@ def __call__( guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + # controlnet controlnet_block_samples, controlnet_single_block_samples = self.controlnet( hidden_states=latents, controlnet_cond=control_image, controlnet_mode=control_mode, - conditioning_scale=controlnet_conditioning_scale, + conditioning_scale=cond_scale, timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index deeb9e3f546a..b5ff2236c4d0 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -69,7 +69,9 @@ ... prompt, ... image=init_image, ... control_image=control_image, - ... controlnet_conditioning_scale=0.6, + ... control_guidance_start=0.2, + ... control_guidance_end=0.8, + ... controlnet_conditioning_scale=1.0, ... strength=0.7, ... num_inference_steps=2, ... 
guidance_scale=3.5, @@ -631,6 +633,8 @@ def __call__( num_inference_steps: int = 28, timesteps: List[int] = None, guidance_scale: float = 7.0, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, control_mode: Optional[Union[int, List[int]]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, num_images_per_prompt: Optional[int] = 1, @@ -710,6 +714,17 @@ def __call__( height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + self.check_inputs( prompt, prompt_2, @@ -862,6 +877,14 @@ def __call__( latents, ) + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) @@ -877,11 +900,19 @@ def __call__( ) guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( hidden_states=latents, controlnet_cond=control_image, controlnet_mode=control_mode, - conditioning_scale=controlnet_conditioning_scale, + conditioning_scale=cond_scale, timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index e763200155f6..1c1c25302100 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -71,6 +71,8 @@ ... image=init_image, ... mask_image=mask_image, ... control_image=control_image, + ... control_guidance_start=0.2, + ... control_guidance_end=0.8, ... controlnet_conditioning_scale=0.7, ... strength=0.7, ... 
num_inference_steps=28, @@ -737,6 +739,8 @@ def __call__( timesteps: List[int] = None, num_inference_steps: int = 28, guidance_scale: float = 7.0, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, control_mode: Optional[Union[int, List[int]]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, num_images_per_prompt: Optional[int] = 1, @@ -783,6 +787,10 @@ def __call__( Custom timesteps to use for the denoising process. guidance_scale (`float`, *optional*, defaults to 7.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. control_mode (`int` or `List[int]`, *optional*): The mode for the ControlNet. If multiple ControlNets are used, this should be a list. controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): @@ -826,6 +834,17 @@ def __call__( global_height = height global_width = width + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + # 1. Check inputs self.check_inputs( prompt, @@ -1031,6 +1050,14 @@ def __call__( generator, ) + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps) + # 9. 
Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) @@ -1049,11 +1076,19 @@ def __call__( else: guidance = None + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( hidden_states=latents, controlnet_cond=control_image, controlnet_mode=control_mode, - conditioning_scale=controlnet_conditioning_scale, + conditioning_scale=cond_scale, timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, From 164ec9f423f82474c6174c22448f1956823b4bbd Mon Sep 17 00:00:00 2001 From: GSSun <39679899+alaister123@users.noreply.github.com> Date: Fri, 11 Oct 2024 01:03:39 -0700 Subject: [PATCH 18/20] fix IsADirectoryError when running the training code for sd3_dreambooth_lora_16gb.ipynb (#9634) Add files via upload fix IsADirectoryError when running the training code --- .../sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb b/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb index 25fcb36a47d5..8e8190a59324 100644 --- a/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb +++ b/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb @@ -196,7 +196,7 @@ }, "outputs": [], "source": [ - "!rm -rf dog/.huggingface" + "!rm -rf dog/.cache" ] }, { From 3033f08201e4ec46d87c37098b066b8b1924b052 Mon Sep 17 00:00:00 2001 From: M Saqlain <118016760+saqlain2204@users.noreply.github.com> Date: Fri, 11 Oct 2024 19:17:31 +0530 Subject: [PATCH 19/20] Add Differential Diffusion to Kolors (#9423) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Added diff diff support for kolors img2img * Fixed relative imports * Fixed relative imports * Added diff diff support for Kolors * Fixed import issues * Added map * Fixed import issues * Fixed naming issues * Added diffdiff support for Kolors img2img pipeline * Removed example docstrings * Added map input * Updated latents Co-authored-by: Álvaro Somoza * Updated `original_with_noise` Co-authored-by: Álvaro Somoza * Improved code quality --------- Co-authored-by: Álvaro Somoza --- .../pipeline_kolors_differential_img2img.py | 1280 +++++++++++++++++ 1 file changed, 1280 insertions(+) create mode 100644 examples/community/pipeline_kolors_differential_img2img.py diff --git a/examples/community/pipeline_kolors_differential_img2img.py b/examples/community/pipeline_kolors_differential_img2img.py new file mode 100644 index 000000000000..e5570248d22b --- /dev/null +++ b/examples/community/pipeline_kolors_differential_img2img.py @@ -0,0 +1,1280 @@ +# Copyright 2024 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin +from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from diffusers.models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor +from diffusers.pipelines.kolors.pipeline_output import KolorsPipelineOutput +from diffusers.pipelines.kolors.text_encoder import ChatGLMModel +from diffusers.pipelines.kolors.tokenizer import ChatGLMTokenizer +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import KolorsDifferentialImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = KolorsDifferentialImg2ImgPipeline.from_pretrained( + ... "Kwai-Kolors/Kolors-diffusers", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + >>> url = ( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kolors/bunny_source.png" + ... ) + + + >>> init_image = load_image(url) + >>> prompt = "high quality image of a capybara wearing sunglasses. In the background of the image there are trees, poles, grass and other objects. At the bottom of the object there is the road., 8k, highly detailed." 
+ >>> image = pipe(prompt, image=init_image).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+class KolorsDifferentialImg2ImgPipeline(
+    DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin
+):
+    r"""
+    Pipeline for differential image-to-image generation using Kolors.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    The pipeline also inherits the following loading methods:
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`ChatGLMModel`]):
+            Frozen text-encoder. Kolors uses [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b).
+        tokenizer (`ChatGLMTokenizer`):
+            Tokenizer of class
+            [ChatGLMTokenizer](https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `False`):
+            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+            `Kwai-Kolors/Kolors-diffusers`.
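+        image_encoder ([`CLIPVisionModelWithProjection`], *optional*):
+            Pre-trained CLIP vision model used to obtain image features for IP-Adapter inputs.
+        feature_extractor ([`CLIPImageProcessor`], *optional*):
+            A `CLIPImageProcessor` to prepare images for the `image_encoder`.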
+    """
+
+    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+    _optional_components = [
+        "image_encoder",
+        "feature_extractor",
+    ]
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+        "add_text_embeds",
+        "add_time_ids",
+        "negative_pooled_prompt_embeds",
+        "negative_add_time_ids",
+    ]
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: ChatGLMModel,
+        tokenizer: ChatGLMTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: KarrasDiffusionSchedulers,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
+        force_zeros_for_empty_prompt: bool = False,
+    ):
+        super().__init__()
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+        )
+        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+        self.vae_scale_factor = (
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+        )
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+        self.mask_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_convert_grayscale=True
+        )
+
+        self.default_sample_size = self.unet.config.sample_size
+
+    # Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        max_sequence_length: int = 256,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            pooled_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
+        """
+        device = device or self._execution_device
+
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        # Define tokenizers and text encoders
+        tokenizers = [self.tokenizer]
+        text_encoders = [self.text_encoder]
+
+        if prompt_embeds is None:
+            prompt_embeds_list = []
+            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+                text_inputs = tokenizer(
+                    prompt,
+                    padding="max_length",
+                    max_length=max_sequence_length,
+                    truncation=True,
+                    return_tensors="pt",
+                ).to(device)
+                output = text_encoder(
+                    input_ids=text_inputs["input_ids"],
+                    attention_mask=text_inputs["attention_mask"],
+                    position_ids=text_inputs["position_ids"],
+                    output_hidden_states=True,
+                )
+
+                # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size]
+                # clone to have a contiguous tensor
+                prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
+                # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size]
+                pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()
+                bs_embed, seq_len, _ = prompt_embeds.shape
+                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+                prompt_embeds_list.append(prompt_embeds)
+
+            prompt_embeds = prompt_embeds_list[0]
+
+        # get unconditional embeddings for classifier free guidance
+        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+
+        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+        elif do_classifier_free_guidance and negative_prompt_embeds is None:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+ ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ).to(device) + output = text_encoder( + input_ids=uncond_input["input_ids"], + attention_mask=uncond_input["attention_mask"], + position_ids=uncond_input["position_ids"], + output_hidden_states=True, + ) + + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() + # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size] + negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = negative_prompt_embeds_list[0] + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not 
isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+
+                image_embeds.append(single_image_embeds[None, :])
+                if do_classifier_free_guidance:
+                    negative_image_embeds.append(single_negative_image_embeds[None, :])
+        else:
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    negative_image_embeds.append(single_negative_image_embeds)
+                image_embeds.append(single_image_embeds)
+
+        ip_adapter_image_embeds = []
+        for i, single_image_embeds in enumerate(image_embeds):
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+
+        return ip_adapter_image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def check_inputs(
+        self,
+        prompt,
+        strength,
+        num_inference_steps,
+        height,
+        width,
+        negative_prompt=None,
+        prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        negative_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        max_sequence_length=None,
+    ):
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+        if not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
+            raise ValueError(
+                f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
+                f" {type(num_inference_steps)}."
+ ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
+ ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + if max_sequence_length is not None and max_sequence_length > 256: + raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}") + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + else: + t_start = 0 + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + if denoising_start is not None: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. 
By adding 1
+                # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
+                num_inference_steps = num_inference_steps + 1
+
+            # because t_n+1 >= t_n, we slice the timesteps starting from the end
+            timesteps = timesteps[-num_inference_steps:]
+            return timesteps, num_inference_steps
+
+        return timesteps, num_inference_steps - t_start
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
+    def prepare_latents(
+        self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
+    ):
+        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+            raise ValueError(
+                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+            )
+
+        latents_mean = latents_std = None
+        if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
+            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
+        if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
+            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
+
+        # Offload text encoder if `enable_model_cpu_offload` was enabled
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.text_encoder.to("cpu")
+            torch.cuda.empty_cache()
+
+        image = image.to(device=device, dtype=dtype)
+
+        batch_size = batch_size * num_images_per_prompt
+
+        if image.shape[1] == 4:
+            init_latents = image
+
+        else:
+            # make sure the VAE is in float32 mode, as it overflows in float16
+            if self.vae.config.force_upcast:
+                image = image.float()
+                self.vae.to(dtype=torch.float32)
+
+            if isinstance(generator, list) and len(generator) != batch_size:
+                raise ValueError(
+                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+            )
+
+        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+        return add_time_ids
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
+    def upcast_vae(self):
+        dtype = self.vae.dtype
+        self.vae.to(dtype=torch.float32)
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                FusedAttnProcessor2_0,
+            ),
+        )
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if use_torch_2_0_or_xformers:
+            self.vae.post_quant_conv.to(dtype)
+            self.vae.decoder.conv_in.to(dtype)
+            self.vae.decoder.mid_block.to(dtype)
+
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(
+        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
+    ) -> torch.Tensor:
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
+            embedding_dim (`int`, *optional*, defaults to 512):
+                Dimension of the embeddings to generate.
+            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+                Data type of the generated embeddings.
+
+        Returns:
+            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
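+    # e.g. with `guidance_scale = 5.0`, the noise prediction in the denoising loop below becomes
+    # noise_pred_uncond + 5.0 * (noise_pred_text - noise_pred_uncond)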
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def denoising_start(self):
+        return self._denoising_start
+
+    @property
+    def denoising_end(self):
+        return self._denoising_end
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        image: PipelineImageInput = None,
+        strength: float = 0.3,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 50,
+        timesteps: List[int] = None,
+        sigmas: List[float] = None,
+        denoising_start: Optional[float] = None,
+        denoising_end: Optional[float] = None,
+        guidance_scale: float = 5.0,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        original_size: Optional[Tuple[int, int]] = None,
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
+        target_size: Optional[Tuple[int, int]] = None,
+        negative_original_size: Optional[Tuple[int, int]] = None,
+        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+        negative_target_size: Optional[Tuple[int, int]] = None,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 256,
+        map: PipelineImageInput = None,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            image (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):
+                The image(s) to modify with the pipeline.
+            strength (`float`, *optional*, defaults to 0.3):
+                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+                be maximum and the denoising process will run for the full number of iterations specified in
+                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of
+                `denoising_start` being declared as an integer, the value of `strength` will be ignored.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+                Anything below 512 pixels won't work well for
+                [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints
+                that are not specifically fine-tuned on low resolutions.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+                Anything below 512 pixels won't work well for
+                [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints
+                that are not specifically fine-tuned on low resolutions.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
+            denoising_start (`float`, *optional*):
+                When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+                bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+                it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+                strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+                is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image
+                Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
+            denoising_end (`float`, *optional*):
+                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+                completed before it is intentionally prematurely terminated. As a result, the returned sample will
+                still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+                "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
+                `prompt`, usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
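+            map (`PIL.Image.Image` or `torch.Tensor`, *optional*):
+                A grayscale change map for differential diffusion, resized internally to the latent resolution. At
+                step `i` of `N`, regions whose map value exceeds `i / N` are re-injected from the noised original
+                image, so brighter regions stay closer to `image` while darker regions are regenerated more freely.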
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            pooled_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list whose length equals the number of
+                IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.kolors.KolorsPipelineOutput`] instead of a plain tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+                explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                For most cases, `target_size` should be set to the desired height and width of the generated image. If
+                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+                micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+                micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a target image resolution. It should be the
+                same as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section
+                2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+                each denoising step during inference with the following arguments: `callback_on_step_end(self:
+                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.kolors.KolorsPipelineOutput`] or `tuple`: [`~pipelines.kolors.KolorsPipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+            generated images.
+        """
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        # 0.
Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + strength, + num_inference_steps, + height, + width, + negative_prompt, + prompt_embeds, + pooled_prompt_embeds, + negative_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._denoising_start = denoising_start + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Preprocess image + init_image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + map = self.mask_processor.preprocess( + map, height=height // self.vae_scale_factor, width=width // self.vae_scale_factor + ).to(device) + + # 5. Prepare timesteps + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # begin diff diff change + total_time_steps = num_inference_steps + # end diff diff change + + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None, + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + add_noise = True if self.denoising_start is None else False + + # 6. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + init_image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + add_noise, + ) + + # 7. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 8. 
Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 9. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # preparations for diff diff + original_with_noise = self.prepare_latents( + init_image, timesteps, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + thresholds = torch.arange(total_time_steps, dtype=map.dtype) / total_time_steps + thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(device) + masks = map.squeeze() > thresholds + # end diff diff preparations + + # 9.1 Apply denoising_end + if ( + self.denoising_end is not None + and self.denoising_start is not None + and denoising_value_valid(self.denoising_end) + and denoising_value_valid(self.denoising_start) + and self.denoising_start >= self.denoising_end + ): + raise ValueError( + f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {self.denoising_end} when using type float." 
+ ) + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # diff diff + if i == 0: + latents = original_with_noise[:1] + else: + mask = masks[i].unsqueeze(0).to(latents.dtype) + mask = mask.unsqueeze(1) # fit shape + latents = original_with_noise[i] * mask + latents * (1 - mask) + # end diff diff + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return KolorsPipelineOutput(images=image) From 0f8fb75c7b9c86d4381f51b66f33d0a3f56301e6 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 11 Oct 2024 18:39:19 +0100 Subject: [PATCH 20/20] FluxMultiControlNetModel (#9647) --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 2 ++ .../pipelines/flux/pipeline_flux_controlnet_image_to_image.py | 2 ++ .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index a301f6742c05..8770c231809f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -202,6 +202,8 @@ def __init__( ], ): super().__init__() + if isinstance(controlnet, (list, tuple)): + controlnet = FluxMultiControlNetModel(controlnet) self.register_modules( vae=vae, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index b5ff2236c4d0..a7f7c66a2cad 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -214,6 +214,8 @@ def __init__( ], ): super().__init__() + if isinstance(controlnet, (list, tuple)): + controlnet = 
FluxMultiControlNetModel(controlnet) self.register_modules( vae=vae, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 1c1c25302100..50d2fcaa7fa5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -216,6 +216,8 @@ def __init__( ], ): super().__init__() + if isinstance(controlnet, (list, tuple)): + controlnet = FluxMultiControlNetModel(controlnet) self.register_modules( scheduler=scheduler,