From e2b3c248d85e1d9c16775d0093745e5690c9e50a Mon Sep 17 00:00:00 2001 From: Sookwan Han <80747187+jellyheadandrew@users.noreply.github.com> Date: Wed, 6 Nov 2024 10:05:58 +0900 Subject: [PATCH 01/11] Add new community pipeline for 'Adaptive Mask Inpainting', introduced in [ECCV2024] ComA (#9228) * Add new community pipeline for 'Adaptive Mask Inpainting', introduced in [ECCV2024] Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models --- examples/community/README.md | 156 ++ .../community/adaptive_mask_inpainting.py | 1465 +++++++++++++++++ 2 files changed, 1621 insertions(+) create mode 100644 examples/community/adaptive_mask_inpainting.py diff --git a/examples/community/README.md b/examples/community/README.md index 743993eb44c3..d2116c6dc4e3 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -10,6 +10,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif | Example | Description | Code Example | Colab | Author | |:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:| +|Adaptive Mask Inpainting|Adaptive Mask Inpainting algorithm from [Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models](https://github.com/snuvclab/coma) (ECCV '24, Oral) provides a way to insert human inside the scene image without altering the background, by inpainting with adapting mask.|[Adaptive Mask Inpainting](#adaptive-mask-inpainting)|-|[Hyeonwoo Kim](https://sshowbiz.xyz),[Sookwan Han](https://jellyheadandrew.github.io)| |Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|NA|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)| |Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/exx8/differential-diffusion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) 
and [Ohad Fried](https://www.ohadf.com/)|
| HD-Painter | [HD-Painter](https://github.com/Picsart-AI-Research/HD-Painter) enables prompt-faithfull and high resolution (up to 2k) image inpainting upon any diffusion-based image inpainting method. | [HD-Painter](#hd-painter) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/PAIR/HD-Painter) | [Manukyan Hayk](https://github.com/haikmanukyan) and [Sargsyan Andranik](https://github.com/AndranikSargsyan) |
@@ -85,6 +86,161 @@ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion
 ## Example usages
+### Adaptive Mask Inpainting
+
+**Hyeonwoo Kim\*, Sookwan Han\*, Patrick Kwon, Hanbyul Joo**
+
+**Seoul National University, Naver Webtoon**
+
+Adaptive Mask Inpainting, presented in the ECCV'24 oral paper [*Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models*](https://snuvclab.github.io/coma), is an algorithm designed to insert humans into scene images without altering the background. Traditional inpainting methods often fail to preserve object geometry and details within the masked region, leading to false affordances. Adaptive Mask Inpainting addresses this issue by progressively specifying the inpainting region over diffusion timesteps, ensuring that the inserted human integrates seamlessly with the existing scene.
+
+Here is a demonstration of Adaptive Mask Inpainting:
+
+
+![teaser-img](https://snuvclab.github.io/coma/static/images/example_result_adaptive_mask_inpainting.png)
+
+
+You can find additional information about Adaptive Mask Inpainting in the [paper](https://arxiv.org/pdf/2401.12978) or on the [project website](https://snuvclab.github.io/coma).
+
+#### Usage example
+First, clone the diffusers GitHub repository and run the following commands to set up the environment.
+```Shell
+git clone https://github.com/huggingface/diffusers.git
+cd diffusers
+
+conda create --name ami python=3.9 -y
+conda activate ami
+
+conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge -y
+python -m pip install detectron2==0.6 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
+pip install easydict
+pip install diffusers==0.20.2 accelerate safetensors transformers
+pip install setuptools==59.5.0
+pip install opencv-python
+pip install numpy==1.24.1
+```
+Then, run the code below from the `diffusers` directory. It downloads the PointRend checkpoint and config files, builds the `AdaptiveMaskInpaintPipeline` with a DDIM scheduler, and inpaints a person into the provided example image.
+```python +import numpy as np +import torch +from PIL import Image + +from diffusers import DDIMScheduler +from diffusers import DiffusionPipeline +from diffusers.utils import load_image + +from examples.community.adaptive_mask_inpainting import download_file, AdaptiveMaskInpaintPipeline, AMI_INSTALL_MESSAGE + +print(AMI_INSTALL_MESSAGE) + +from easydict import EasyDict + + + +if __name__ == "__main__": + """ + Download Necessary Files + """ + download_file( + url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/model_final_edd263.pkl?download=true", + output_file = "model_final_edd263.pkl", + exist_ok=True, + ) + download_file( + url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/pointrend_rcnn_R_50_FPN_3x_coco.yaml?download=true", + output_file = "pointrend_rcnn_R_50_FPN_3x_coco.yaml", + exist_ok=True, + ) + download_file( + url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/input_img.png?download=true", + output_file = "input_img.png", + exist_ok=True, + ) + download_file( + url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/input_mask.png?download=true", + output_file = "input_mask.png", + exist_ok=True, + ) + download_file( + url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/Base-PointRend-RCNN-FPN.yaml?download=true", + output_file = "Base-PointRend-RCNN-FPN.yaml", + exist_ok=True, + ) + download_file( + url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/Base-RCNN-FPN.yaml?download=true", + output_file = "Base-RCNN-FPN.yaml", + exist_ok=True, + ) + + """ + Prepare Adaptive Mask Inpainting Pipeline + """ + # device + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + num_steps = 50 + + # Scheduler + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False + ) + scheduler.set_timesteps(num_inference_steps=num_steps) + + ## load models as pipelines + pipeline = AdaptiveMaskInpaintPipeline.from_pretrained( + "Uminosachi/realisticVisionV51_v51VAE-inpainting", + scheduler=scheduler, + torch_dtype=torch.float16, + requires_safety_checker=False + ).to(device) + + ## disable safety checker + enable_safety_checker = False + if not enable_safety_checker: + pipeline.safety_checker = None + + """ + Run Adaptive Mask Inpainting + """ + default_mask_image = Image.open("./input_mask.png").convert("L") + init_image = Image.open("./input_img.png").convert("RGB") + + + seed = 59 + generator = torch.Generator(device=device) + generator.manual_seed(seed) + + image = pipeline( + prompt="a man sitting on a couch", + negative_prompt="worst quality, normal quality, low quality, bad anatomy, artifacts, blurry, cropped, watermark, greyscale, nsfw", + image=init_image, + default_mask_image=default_mask_image, + guidance_scale=11.0, + strength=0.98, + use_adaptive_mask=True, + generator=generator, + enforce_full_mask_ratio=0.0, + visualization_save_dir="./ECCV2024_adaptive_mask_inpainting_demo", # DON'T CHANGE THIS!!! 
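+        # The arguments above and below are specific to adaptive mask inpainting (their
+        # behavior is defined in examples/community/adaptive_mask_inpainting.py):
+        #   - use_adaptive_mask=True re-estimates the inpainting mask from the predicted
+        #     clean image at scheduled denoising steps, so the region being inpainted can
+        #     progressively tighten around the inserted human.
+        #   - enforce_full_mask_ratio=0.0 never switches back to the full default mask at
+        #     late timesteps (a ratio > 0 would re-enforce it once t falls below that
+        #     fraction of the training timesteps).
+        #   - visualization_save_dir receives per-step visualizations of the adapted mask
+        #     next to the predicted image.
+        #   - human_detection_thres (below) appears to be the detection threshold for the
+        #     PointRend human segmentation used when adapting the mask.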
+ human_detection_thres=0.015, + ).images[0] + + + image.save(f'final_img.png') +``` +#### [Troubleshooting] + +If you run into an error `cannot import name 'cached_download' from 'huggingface_hub'` (issue [1851](https://github.com/easydiffusion/easydiffusion/issues/1851)), remove `cached_download` from the import line in the file `diffusers/utils/dynamic_modules_utils.py`. + +For example, change the import line from `.../env/lib/python3.8/site-packages/diffusers/utils/dynamic_modules_utils.py`. + + ### Flux with CFG Know more about Flux [here](https://blackforestlabs.ai/announcing-black-forest-labs/). Since Flux doesn't use CFG, this implementation provides one, inspired by the [PuLID Flux adaptation](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md). diff --git a/examples/community/adaptive_mask_inpainting.py b/examples/community/adaptive_mask_inpainting.py new file mode 100644 index 000000000000..a9de26b29a89 --- /dev/null +++ b/examples/community/adaptive_mask_inpainting.py @@ -0,0 +1,1465 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/ + +import inspect +import os +import shutil +from glob import glob +from typing import Any, Callable, Dict, List, Optional, Union + +import cv2 +import numpy as np +import PIL.Image +import requests +import torch +from detectron2.config import get_cfg +from detectron2.data import MetadataCatalog +from detectron2.engine import DefaultPredictor +from detectron2.projects import point_rend +from detectron2.structures.instances import Instances +from detectron2.utils.visualizer import ColorMode, Visualizer +from packaging import version +from tqdm import tqdm +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +AMI_INSTALL_MESSAGE = """ + +Example Demo of Adaptive Mask Inpainting + +Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models +Kim et al. 
+ECCV-2024 (Oral) + + +Please prepare the environment via + +``` +conda create --name ami python=3.9 -y +conda activate ami + +conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge -y +python -m pip install detectron2==0.6 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html +pip install easydict +pip install diffusers==0.20.2 accelerate safetensors transformers +pip install setuptools==59.5.0 +pip install opencv-python +pip install numpy==1.24.1 +``` + + +Put the code inside the root of diffusers library (e.g., as '/home/username/diffusers/adaptive_mask_inpainting_example.py') and run the python code. + + + + +""" + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install transformers accelerate + >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" + ... ) + >>> init_image = init_image.resize((512, 512)) + + >>> generator = torch.Generator(device="cpu").manual_seed(1) + + >>> mask_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" + ... ) + >>> mask_image = mask_image.resize((512, 512)) + + + >>> def make_inpaint_condition(image, image_mask): + ... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0 + ... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0 + + ... assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size" + ... image[image_mask > 0.5] = -1.0 # set as masked pixel + ... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) + ... image = torch.from_numpy(image) + ... return image + + + >>> control_image = make_inpaint_condition(init_image, mask_image) + + >>> controlnet = ControlNetModel.from_pretrained( + ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16 + ... ) + >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> image = pipe( + ... "a handsome man with ray-ban sunglasses", + ... num_inference_steps=20, + ... generator=generator, + ... eta=1.0, + ... image=init_image, + ... mask_image=mask_image, + ... control_image=control_image, + ... 
).images[0] + ``` +""" + + +def download_file(url, output_file, exist_ok: bool): + if exist_ok and os.path.exists(output_file): + return + + response = requests.get(url, stream=True) + + with open(output_file, "wb") as file: + for chunk in tqdm(response.iter_content(chunk_size=8192), desc=f"Downloading '{output_file}'..."): + if chunk: + file.write(chunk) + + +def generate_video_from_imgs(images_save_directory, fps=15.0, delete_dir=True): + # delete videos if exists + if os.path.exists(f"{images_save_directory}.mp4"): + os.remove(f"{images_save_directory}.mp4") + if os.path.exists(f"{images_save_directory}_before_process.mp4"): + os.remove(f"{images_save_directory}_before_process.mp4") + + # assume there are "enumerated" images under "images_save_directory" + assert os.path.isdir(images_save_directory) + ImgPaths = sorted(glob(f"{images_save_directory}/*")) + + if len(ImgPaths) == 0: + print("\tSkipping, since there must be at least one image to create mp4\n") + else: + # mp4 configuration + video_path = images_save_directory + "_before_process.mp4" + + # Get height and width config + images = sorted([ImgPath.split("/")[-1] for ImgPath in ImgPaths if ImgPath.endswith(".png")]) + frame = cv2.imread(os.path.join(images_save_directory, images[0])) + height, width, channels = frame.shape + + # create mp4 video writer + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + video = cv2.VideoWriter(video_path, fourcc, fps, (width, height)) + for image in images: + video.write(cv2.imread(os.path.join(images_save_directory, image))) + cv2.destroyAllWindows() + video.release() + + # generated video is not compatible with HTML5. Post-process and change codec of video, so that it is applicable to HTML. + os.system( + f'ffmpeg -i "{images_save_directory}_before_process.mp4" -vcodec libx264 -f mp4 "{images_save_directory}.mp4" ' + ) + + # remove group of images, and remove video before post-process. + if delete_dir and os.path.exists(images_save_directory): + shutil.rmtree(images_save_directory) + # remove 'before-process' video + if os.path.exists(f"{images_save_directory}_before_process.mp4"): + os.remove(f"{images_save_directory}_before_process.mp4") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image +def prepare_mask_and_masked_image(image, mask, height, width, return_image=False): + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. 
ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + # n.b. 
ensure backwards compatibility as old function does not return image + if return_image: + return mask, masked_image, image + + return mask, masked_image + + +class AdaptiveMaskInpaintPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: Union[AutoencoderKL, AsymmetricAutoencoderKL], + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + # safety_checker: StableDiffusionSafetyChecker, + safety_checker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + self.register_adaptive_mask_model() + self.register_adaptive_mask_settings() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. 
If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 + if unet.config.in_channels != 9: + logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.") + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + """ Preparation for Adaptive Mask inpainting """ + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offload all models to CPU to reduce memory usage with a low impact on performance. Moves one whole model at a + time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. + Memory savings are lower than using `enable_sequential_cpu_offload`, but performance is much better due to the + iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). 
+ prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." 
+ ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." 
+ ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + default_mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 1.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + use_adaptive_mask: bool = True, + enforce_full_mask_ratio: float = 0.5, + human_detection_thres: float = 0.008, + visualization_save_dir: str = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to be inpainted (which parts of the image to be masked + out with `default_mask_image` and repainted according to `prompt`). + default_mask_image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to mask `image`. White pixels in the mask are repainted + while black pixels are preserved. If `default_mask_image` is a PIL image, it is converted to a single channel + (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the + expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. 
The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import AdaptiveMaskInpaintPipeline + + + >>> def download_image(url): + ... 
response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = download_image(img_url).resize((512, 512)) + >>> default_mask_image = download_image(mask_url).resize((512, 512)) + + >>> pipe = AdaptiveMaskInpaintPipeline.from_pretrained( + ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> image = pipe(prompt=prompt, image=init_image, default_mask_image=default_mask_image).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + width, height = image.size + # height = height or self.unet.config.sample_size * self.vae_scale_factor + # width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. 
Preprocess mask and image (will be used later, once again) + mask, masked_image, init_image = prepare_mask_and_masked_image( + image, default_mask_image, height, width, return_image=True + ) + default_mask_image_np = np.array(default_mask_image).astype(np.uint8) / 255 + mask_condition = mask.clone() + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `default_mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. 
Denoising loop + mask_image_np = default_mask_image_np + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + else: + raise NotImplementedError + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 & predicted original sample x_0 + outputs = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True) + latents = outputs["prev_sample"] # x_t-1 + pred_orig_latents = outputs["pred_original_sample"] # x_0 + + # run segmentation + if use_adaptive_mask: + if enforce_full_mask_ratio > 0.0: + use_default_mask = t < self.scheduler.config.num_train_timesteps * enforce_full_mask_ratio + elif enforce_full_mask_ratio == 0.0: + use_default_mask = False + else: + raise NotImplementedError + + pred_orig_image = self.decode_to_npuint8_image(pred_orig_latents) + dilate_num = self.adaptive_mask_settings.dilate_scheduler(i) + do_adapt_mask = self.adaptive_mask_settings.provoke_scheduler(i) + if do_adapt_mask: + mask, masked_image_latents, mask_image_np, vis_np = self.adapt_mask( + init_image, + pred_orig_image, + default_mask_image_np, + dilate_num=dilate_num, + use_default_mask=use_default_mask, + height=height, + width=width, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + device=device, + generator=generator, + do_classifier_free_guidance=do_classifier_free_guidance, + i=i, + human_detection_thres=human_detection_thres, + mask_image_np=mask_image_np, + ) + + if self.adaptive_mask_model.use_visualizer: + import matplotlib.pyplot as plt + + # mask_image_new_colormap = np.clip(0.6 + (1.0 - mask_image_np), a_min=0.0, a_max=1.0) * 255 + + os.makedirs(visualization_save_dir, exist_ok=True) + + # Image.fromarray(mask_image_new_colormap).convert("L").save(f"{visualization_save_dir}/masks/{i:05}.png") + plt.axis("off") + plt.subplot(1, 2, 1) + plt.imshow(mask_image_np) + plt.subplot(1, 2, 2) + plt.imshow(pred_orig_image) + plt.savefig(f"{visualization_save_dir}/{i:05}.png", bbox_inches="tight") + plt.close("all") + + if num_channels_unet == 4: + init_latents_proper = image_latents[:1] + init_mask = mask[:1] + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, 
t, latents) + + if not output_type == "latent": + condition_kwargs = {} + if isinstance(self.vae, AsymmetricAutoencoderKL): + init_image = init_image.to(device=device, dtype=masked_image_latents.dtype) + init_image_condition = init_image.clone() + init_image = self._encode_vae_image(init_image, generator=generator) + mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype) + condition_kwargs = {"image": init_image_condition, "mask": mask_condition} + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **condition_kwargs)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if self.adaptive_mask_model.use_visualizer: + generate_video_from_imgs(images_save_directory=visualization_save_dir, fps=10, delete_dir=True) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def decode_to_npuint8_image(self, latents): + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **{})[ + 0 + ] # torch, float32, -1.~1. + image = self.image_processor.postprocess(image, output_type="pt", do_denormalize=[True] * image.shape[0]) + image = (image.squeeze().permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8) # np, uint8, 0~255 + return image + + def register_adaptive_mask_settings(self): + from easydict import EasyDict + + num_steps = 50 + + step_num = int(num_steps * 0.1) + final_step_num = num_steps - step_num * 7 + # adaptive mask settings + self.adaptive_mask_settings = EasyDict( + dilate_scheduler=MaskDilateScheduler( + max_dilate_num=20, + num_inference_steps=num_steps, + schedule=[20] * step_num + + [10] * step_num + + [5] * step_num + + [4] * step_num + + [3] * step_num + + [2] * step_num + + [1] * step_num + + [0] * final_step_num, + ), + dilate_kernel=np.ones((3, 3), dtype=np.uint8), + provoke_scheduler=ProvokeScheduler( + num_inference_steps=num_steps, + schedule=list(range(2, 10 + 1, 2)) + list(range(12, 40 + 1, 2)) + [45], + is_zero_indexing=False, + ), + ) + + def register_adaptive_mask_model(self): + # declare segmentation model used for mask adaptation + use_visualizer = True + # assert not use_visualizer, \ + # """ + # If you plan to 'use_visualizer', USE WITH CAUTION. + # It creates a directory of images and masks, which is used for merging into a video. + # The procedure involves deleting the directory of images, which means that + # if you set the directory wrong you can have other important files blown away. 
+ # """ + + self.adaptive_mask_model = PointRendPredictor( + # pointrend_thres=0.2, + pointrend_thres=0.9, + device="cuda" if torch.cuda.is_available() else "cpu", + use_visualizer=use_visualizer, + config_pth="pointrend_rcnn_R_50_FPN_3x_coco.yaml", + weights_pth="model_final_edd263.pkl", + ) + + def adapt_mask(self, init_image, pred_orig_image, default_mask_image, dilate_num, use_default_mask, **kwargs): + ## predict mask to use for adaptation + adapt_output = self.adaptive_mask_model(pred_orig_image) # vis can be None if 'use_visualizer' is False + mask = adapt_output["mask"] + vis = adapt_output["vis"] + + ## if mask is empty or too small, use default_mask_image. else, use dilate and intersect with default_mask_image + if use_default_mask or mask.sum() < 512 * 512 * kwargs["human_detection_thres"]: # 0.005 + # set mask as default mask + mask = default_mask_image # HxW + + else: + ## timestep-adaptive mask + mask = cv2.dilate( + mask, self.adaptive_mask_settings.dilate_kernel, iterations=dilate_num + ) # dilate_kernel: np.ones((3,3), np.uint8) + mask = np.logical_and(mask, default_mask_image) # HxW + + ## prepare mask as pt tensor format + mask = torch.tensor(mask, dtype=torch.float32).to(kwargs["device"])[None, None] # 1 x 1 x H x W + mask, masked_image = prepare_mask_and_masked_image( + init_image.to(kwargs["device"]), mask, kwargs["height"], kwargs["width"], return_image=False + ) + + mask_image_np = mask.clone().squeeze().detach().cpu().numpy() + + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + kwargs["batch_size"] * kwargs["num_images_per_prompt"], + kwargs["height"], + kwargs["width"], + kwargs["prompt_embeds"].dtype, + kwargs["device"], + kwargs["generator"], + kwargs["do_classifier_free_guidance"], + ) + + return mask, masked_image_latents, mask_image_np, vis + + +def seg2bbox(seg_mask: np.ndarray): + nonzero_i, nonzero_j = seg_mask.nonzero() + min_i, max_i = nonzero_i.min(), nonzero_i.max() + min_j, max_j = nonzero_j.min(), nonzero_j.max() + + return np.array([min_j, min_i, max_j + 1, max_i + 1]) + + +def merge_bbox(bboxes: list): + assert len(bboxes) > 0 + + all_bboxes = np.stack(bboxes, axis=0) # shape: N_bbox X 4 + merged_bbox = np.zeros_like(all_bboxes[0]) # shape: 4, + + merged_bbox[0] = all_bboxes[:, 0].min() + merged_bbox[1] = all_bboxes[:, 1].min() + merged_bbox[2] = all_bboxes[:, 2].max() + merged_bbox[3] = all_bboxes[:, 3].max() + + return merged_bbox + + +class PointRendPredictor: + def __init__( + self, + cat_id_to_focus=0, + pointrend_thres=0.9, + device="cuda", + use_visualizer=False, + merge_mode="merge", + config_pth=None, + weights_pth=None, + ): + super().__init__() + + # category id to focus (default: 0, which is human) + self.cat_id_to_focus = cat_id_to_focus + + # setup coco metadata + self.coco_metadata = MetadataCatalog.get("coco_2017_val") + self.cfg = get_cfg() + + # get segmentation model config + point_rend.add_pointrend_config(self.cfg) # --> Add PointRend-specific config + self.cfg.merge_from_file(config_pth) + self.cfg.MODEL.WEIGHTS = weights_pth + self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = pointrend_thres + self.cfg.MODEL.DEVICE = device + + # get segmentation model + self.pointrend_seg_model = DefaultPredictor(self.cfg) + + # settings for visualizer + self.use_visualizer = use_visualizer + + # mask-merge mode + assert merge_mode in ["merge", "max-confidence"], f"'merge_mode': {merge_mode} not implemented." 
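# A minimal, standalone sketch (NumPy only; the names are toy values and not part of this file)
# of what the two merge modes asserted above mean for PointRend's per-instance masks; the
# actual combination is performed by `merge_mask` just below.
import numpy as np

toy_masks = np.zeros((2, 4, 4), dtype=bool)       # two hypothetical "person" instance masks
toy_masks[0, :2, :2] = True
toy_masks[1, 2:, 2:] = True
toy_scores = np.array([0.7, 0.9])                 # hypothetical detector confidences

union_mask = np.any(toy_masks, axis=0)            # "merge": union over all instances
top_mask = toy_masks[np.argmax(toy_scores)]       # "max-confidence": keep only the best-scoring instance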
+ self.merge_mode = merge_mode + + def merge_mask(self, masks, scores=None): + if self.merge_mode == "merge": + mask = np.any(masks, axis=0) + elif self.merge_mode == "max-confidence": + mask = masks[np.argmax(scores)] + return mask + + def vis_seg_on_img(self, image, mask): + if type(mask) == np.ndarray: + mask = torch.tensor(mask) + v = Visualizer(image, self.coco_metadata, scale=0.5, instance_mode=ColorMode.IMAGE_BW) + instances = Instances(image_size=image.shape[:2], pred_masks=mask if len(mask.shape) == 3 else mask[None]) + vis = v.draw_instance_predictions(instances.to("cpu")).get_image() + return vis + + def __call__(self, image): + # run segmentation + outputs = self.pointrend_seg_model(image) + instances = outputs["instances"] + + # merge instances for the category-id to focus + is_class = instances.pred_classes == self.cat_id_to_focus + masks = instances.pred_masks[is_class] + masks = masks.detach().cpu().numpy() # [N, img_size, img_size] + mask = self.merge_mask(masks, scores=instances.scores[is_class]) + + return { + "asset_mask": None, + "mask": mask.astype(np.uint8), + "vis": self.vis_seg_on_img(image, mask) if self.use_visualizer else None, + } + + +class MaskDilateScheduler: + def __init__(self, max_dilate_num=15, num_inference_steps=50, schedule=None): + super().__init__() + self.max_dilate_num = max_dilate_num + self.schedule = [num_inference_steps - i for i in range(num_inference_steps)] if schedule is None else schedule + assert len(self.schedule) == num_inference_steps + + def __call__(self, i): + return min(self.max_dilate_num, self.schedule[i]) + + +class ProvokeScheduler: + def __init__(self, num_inference_steps=50, schedule=None, is_zero_indexing=False): + super().__init__() + if len(schedule) > 0: + if is_zero_indexing: + assert max(schedule) <= num_inference_steps - 1 + else: + assert max(schedule) <= num_inference_steps + + # register as self + self.is_zero_indexing = is_zero_indexing + self.schedule = schedule + + def __call__(self, i): + if self.is_zero_indexing: + return i in self.schedule + else: + return i + 1 in self.schedule From 76b7d86a9a5c0c2186efa09c4a67b5f5666ac9e3 Mon Sep 17 00:00:00 2001 From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com> Date: Wed, 6 Nov 2024 06:38:50 +0530 Subject: [PATCH 02/11] Updated _encode_prompt_with_clip and encode_prompt in train_dreamboth_sd3 (#9800) * updated encode prompt and clip encod prompt --------- Co-authored-by: Sayak Paul --- examples/dreambooth/train_dreambooth_sd3.py | 26 ++++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 525a4cc906e9..865696855940 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -902,20 +902,26 @@ def _encode_prompt_with_clip( tokenizer, prompt: str, device=None, + text_input_ids=None, num_images_per_prompt: int = 1, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=77, - truncation=True, - return_tensors="pt", - ) + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") - text_input_ids = text_inputs.input_ids 
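# A hedged, standalone restatement of the guard introduced in this hunk (the helper name
# `resolve_text_input_ids` is illustrative and does not exist in train_dreambooth_sd3.py):
# with a tokenizer, the ids are computed as before; without one, precomputed `text_input_ids`
# must be supplied, matching the ValueError above.
def resolve_text_input_ids(tokenizer, prompt, text_input_ids=None):
    if tokenizer is not None:
        return tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
        ).input_ids
    if text_input_ids is None:
        raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
    return text_input_ids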
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) pooled_prompt_embeds = prompt_embeds[0] @@ -937,6 +943,7 @@ def encode_prompt( max_sequence_length, device=None, num_images_per_prompt: int = 1, + text_input_ids_list=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt @@ -945,13 +952,14 @@ def encode_prompt( clip_prompt_embeds_list = [] clip_pooled_prompt_embeds_list = [] - for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders): + for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)): prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip( text_encoder=text_encoder, tokenizer=tokenizer, prompt=prompt, device=device if device is not None else text_encoder.device, num_images_per_prompt=num_images_per_prompt, + text_input_ids=text_input_ids_list[i] if text_input_ids_list else None, ) clip_prompt_embeds_list.append(prompt_embeds) clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds) From ded3db164bb3c090871647f30ff9988c9c17fd83 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 7 Nov 2024 03:08:55 +0100 Subject: [PATCH 03/11] [Core] introduce `controlnet` module (#8768) * move vae flax module. * controlnet module. * prepare for PR. * revert a commit * gracefully deprecate controlnet deps. * fix * fix doc path * fix-copies * fix path * style * style * conflicts * fix * fix-copies * sparsectrl. * updates * fix * updates * updates * updates * fix --------- Co-authored-by: Dhruv Nair --- docs/source/en/api/models/controlnet.md | 4 +- docs/source/en/api/models/controlnet_sd3.md | 2 +- .../promptdiffusioncontrolnet.py | 6 +- src/diffusers/__init__.py | 4 +- src/diffusers/models/__init__.py | 39 +- src/diffusers/models/controlnet.py | 872 +----------------- src/diffusers/models/controlnet_flux.py | 529 +---------- src/diffusers/models/controlnet_sd3.py | 415 +-------- src/diffusers/models/controlnet_sparsectrl.py | 784 +--------------- src/diffusers/models/controlnets/__init__.py | 22 + .../models/controlnets/controlnet.py | 872 ++++++++++++++++++ .../{ => controlnets}/controlnet_flax.py | 10 +- .../models/controlnets/controlnet_flux.py | 536 +++++++++++ .../{ => controlnets}/controlnet_hunyuan.py | 14 +- .../models/controlnets/controlnet_sd3.py | 422 +++++++++ .../controlnets/controlnet_sparsectrl.py | 788 ++++++++++++++++ .../models/{ => controlnets}/controlnet_xs.py | 21 +- .../models/controlnets/multicontrolnet.py | 183 ++++ .../pipeline_animatediff_sparsectrl.py | 2 +- .../pipelines/controlnet/multicontrolnet.py | 185 +--- .../pipeline_stable_diffusion_3_controlnet.py | 2 +- ...table_diffusion_3_controlnet_inpainting.py | 2 +- .../flux/pipeline_flux_controlnet.py | 2 +- ...pipeline_flux_controlnet_image_to_image.py | 2 +- .../pipeline_flux_controlnet_inpainting.py | 2 +- tests/pipelines/test_pipelines_common.py | 2 +- 26 files changed, 2970 insertions(+), 2752 deletions(-) create mode 100644 src/diffusers/models/controlnets/__init__.py create mode 100644 src/diffusers/models/controlnets/controlnet.py rename src/diffusers/models/{ => controlnets}/controlnet_flax.py (98%) create mode 100644 src/diffusers/models/controlnets/controlnet_flux.py rename src/diffusers/models/{ => controlnets}/controlnet_hunyuan.py (98%) create mode 100644 src/diffusers/models/controlnets/controlnet_sd3.py create mode 100644 src/diffusers/models/controlnets/controlnet_sparsectrl.py rename src/diffusers/models/{ => controlnets}/controlnet_xs.py (99%) create mode 100644 
src/diffusers/models/controlnets/multicontrolnet.py diff --git a/docs/source/en/api/models/controlnet.md b/docs/source/en/api/models/controlnet.md index 966a0e53b496..5d4cac6658cc 100644 --- a/docs/source/en/api/models/controlnet.md +++ b/docs/source/en/api/models/controlnet.md @@ -39,7 +39,7 @@ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=contro ## ControlNetOutput -[[autodoc]] models.controlnet.ControlNetOutput +[[autodoc]] models.controlnets.controlnet.ControlNetOutput ## FlaxControlNetModel @@ -47,4 +47,4 @@ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=contro ## FlaxControlNetOutput -[[autodoc]] models.controlnet_flax.FlaxControlNetOutput +[[autodoc]] models.controlnets.controlnet_flax.FlaxControlNetOutput diff --git a/docs/source/en/api/models/controlnet_sd3.md b/docs/source/en/api/models/controlnet_sd3.md index 59db64546fa2..78564d238eea 100644 --- a/docs/source/en/api/models/controlnet_sd3.md +++ b/docs/source/en/api/models/controlnet_sd3.md @@ -38,5 +38,5 @@ pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-di ## SD3ControlNetOutput -[[autodoc]] models.controlnet_sd3.SD3ControlNetOutput +[[autodoc]] models.controlnets.controlnet_sd3.SD3ControlNetOutput diff --git a/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py b/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py index 46cabd863dfa..6b1826a1c92d 100644 --- a/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py +++ b/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py @@ -229,11 +229,11 @@ def forward( In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple. Returns: - [`~models.controlnet.ControlNetOutput`] **or** `tuple`: - If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + [`~models.controlnets.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. 
""" # check channel order diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index fb6d22084bd6..533aa5de1e87 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -487,7 +487,7 @@ else: - _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] + _import_structure["models.controlnets.controlnet_flax"] = ["FlaxControlNetModel"] _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] _import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] @@ -914,7 +914,7 @@ except OptionalDependencyNotAvailable: from .utils.dummy_flax_objects import * # noqa F403 else: - from .models.controlnet_flax import FlaxControlNetModel + from .models.controlnets.controlnet_flax import FlaxControlNetModel from .models.modeling_flax_utils import FlaxModelMixin from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel from .models.vae_flax import FlaxAutoencoderKL diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 518ab6df65c4..65e2418ac794 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -36,12 +36,16 @@ _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.vq_model"] = ["VQModel"] - _import_structure["controlnet"] = ["ControlNetModel"] - _import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] - _import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"] - _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"] - _import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"] - _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] + _import_structure["controlnets.controlnet"] = ["ControlNetModel"] + _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] + _import_structure["controlnets.controlnet_hunyuan"] = [ + "HunyuanDiT2DControlNetModel", + "HunyuanDiT2DMultiControlNetModel", + ] + _import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"] + _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"] + _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] + _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] @@ -74,7 +78,7 @@ _import_structure["unets.uvit_2d"] = ["UVit2DModel"] if is_flax_available(): - _import_structure["controlnet_flax"] = ["FlaxControlNetModel"] + _import_structure["controlnets.controlnet_flax"] = ["FlaxControlNetModel"] _import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] _import_structure["vae_flax"] = ["FlaxAutoencoderKL"] @@ -94,12 +98,19 @@ ConsistencyDecoderVAE, VQModel, ) - from .controlnet import ControlNetModel - from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel - from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel - from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel - 
from .controlnet_sparsectrl import SparseControlNetModel - from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel + from .controlnets import ( + ControlNetModel, + ControlNetXSAdapter, + FluxControlNetModel, + FluxMultiControlNetModel, + HunyuanDiT2DControlNetModel, + HunyuanDiT2DMultiControlNetModel, + MultiControlNetModel, + SD3ControlNetModel, + SD3MultiControlNetModel, + SparseControlNetModel, + UNetControlNetXSModel, + ) from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import ( @@ -137,7 +148,7 @@ ) if is_flax_available(): - from .controlnet_flax import FlaxControlNetModel + from .controlnets import FlaxControlNetModel from .unets import FlaxUNet2DConditionModel from .vae_flax import FlaxAutoencoderKL diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index d3ae96605077..174f2b9ada96 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -11,860 +11,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -from torch import nn -from torch.nn import functional as F - -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders.single_file_model import FromOriginalModelMixin -from ..utils import BaseOutput, logging -from .attention_processor import ( - ADDED_KV_ATTENTION_PROCESSORS, - CROSS_ATTENTION_PROCESSORS, - AttentionProcessor, - AttnAddedKVProcessor, - AttnProcessor, +from ..utils import deprecate +from .controlnets.controlnet import ( # noqa + BaseOutput, + ControlNetConditioningEmbedding, + ControlNetModel, + ControlNetOutput, + zero_module, ) -from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin -from .unets.unet_2d_blocks import ( - CrossAttnDownBlock2D, - DownBlock2D, - UNetMidBlock2D, - UNetMidBlock2DCrossAttn, - get_down_block, -) -from .unets.unet_2d_condition import UNet2DConditionModel - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class ControlNetOutput(BaseOutput): - """ - The output of [`ControlNetModel`]. - - Args: - down_block_res_samples (`tuple[torch.Tensor]`): - A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should - be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be - used to condition the original UNet's downsampling activations. - mid_down_block_re_sample (`torch.Tensor`): - The activation of the middle block (the lowest sample resolution). Each tensor should be of shape - `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. - Output can be used to condition the original UNet's middle block activation. - """ - - down_block_res_samples: Tuple[torch.Tensor] - mid_block_res_sample: torch.Tensor - - -class ControlNetConditioningEmbedding(nn.Module): - """ - Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN - [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized - training. 
This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the - convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides - (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full - model) to encode image-space conditions ... into feature maps ..." - """ - - def __init__( - self, - conditioning_embedding_channels: int, - conditioning_channels: int = 3, - block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), - ): - super().__init__() - - self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) - - self.blocks = nn.ModuleList([]) - - for i in range(len(block_out_channels) - 1): - channel_in = block_out_channels[i] - channel_out = block_out_channels[i + 1] - self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) - self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) - - self.conv_out = zero_module( - nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) - ) - - def forward(self, conditioning): - embedding = self.conv_in(conditioning) - embedding = F.silu(embedding) - - for block in self.blocks: - embedding = block(embedding) - embedding = F.silu(embedding) - - embedding = self.conv_out(embedding) - - return embedding - - -class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): - """ - A ControlNet model. - - Args: - in_channels (`int`, defaults to 4): - The number of channels in the input sample. - flip_sin_to_cos (`bool`, defaults to `True`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, defaults to 0): - The frequency shift to apply to the time embedding. - down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. - only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): - block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - layers_per_block (`int`, defaults to 2): - The number of layers per block. - downsample_padding (`int`, defaults to 1): - The padding to use for the downsampling convolution. - mid_block_scale_factor (`float`, defaults to 1): - The scale factor to use for the mid block. - act_fn (`str`, defaults to "silu"): - The activation function to use. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups to use for the normalization. If None, normalization and activation layers is skipped - in post-processing. - norm_eps (`float`, defaults to 1e-5): - The epsilon to use for the normalization. - cross_attention_dim (`int`, defaults to 1280): - The dimension of the cross attention features. - transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for - [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], - [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. - encoder_hid_dim (`int`, *optional*, defaults to None): - If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` - dimension to `cross_attention_dim`. 
- encoder_hid_dim_type (`str`, *optional*, defaults to `None`): - If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text - embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. - attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): - The dimension of the attention heads. - use_linear_projection (`bool`, defaults to `False`): - class_embed_type (`str`, *optional*, defaults to `None`): - The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, - `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. - addition_embed_type (`str`, *optional*, defaults to `None`): - Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or - "text". "text" will use the `TextTimeEmbedding` layer. - num_class_embeds (`int`, *optional*, defaults to 0): - Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing - class conditioning with `class_embed_type` equal to `None`. - upcast_attention (`bool`, defaults to `False`): - resnet_time_scale_shift (`str`, defaults to `"default"`): - Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. - projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): - The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when - `class_embed_type="projection"`. - controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): - The channel order of conditional image. Will convert to `rgb` if it's `bgr`. - conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): - The tuple of output channel for each block in the `conditioning_embedding` layer. - global_pool_conditions (`bool`, defaults to `False`): - TODO(Patrick) - unused parameter. - addition_embed_type_num_heads (`int`, defaults to 64): - The number of heads to use for the `TextTimeEmbedding` layer. - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - in_channels: int = 4, - conditioning_channels: int = 3, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str, ...] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int, ...] 
= (320, 640, 1280, 1280), - layers_per_block: int = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: Optional[int] = 32, - norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, - transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, - encoder_hid_dim: Optional[int] = None, - encoder_hid_dim_type: Optional[str] = None, - attention_head_dim: Union[int, Tuple[int, ...]] = 8, - num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, - use_linear_projection: bool = False, - class_embed_type: Optional[str] = None, - addition_embed_type: Optional[str] = None, - addition_time_embed_dim: Optional[int] = None, - num_class_embeds: Optional[int] = None, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - projection_class_embeddings_input_dim: Optional[int] = None, - controlnet_conditioning_channel_order: str = "rgb", - conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), - global_pool_conditions: bool = False, - addition_embed_type_num_heads: int = 64, - ): - super().__init__() - - # If `num_attention_heads` is not defined (which is the case for most models) - # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. - # The reason for this behavior is to correct for incorrectly named variables that were introduced - # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 - # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking - # which is why we correct for the naming here. - num_attention_heads = num_attention_heads or attention_head_dim - - # Check inputs - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
- ) - - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) - - # input - conv_in_kernel = 3 - conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = nn.Conv2d( - in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding - ) - - # time - time_embed_dim = block_out_channels[0] * 4 - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, - ) - - if encoder_hid_dim_type is None and encoder_hid_dim is not None: - encoder_hid_dim_type = "text_proj" - self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) - logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") - - if encoder_hid_dim is None and encoder_hid_dim_type is not None: - raise ValueError( - f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." - ) - - if encoder_hid_dim_type == "text_proj": - self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) - elif encoder_hid_dim_type == "text_image_proj": - # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` - self.encoder_hid_proj = TextImageProjection( - text_embed_dim=encoder_hid_dim, - image_embed_dim=cross_attention_dim, - cross_attention_dim=cross_attention_dim, - ) - - elif encoder_hid_dim_type is not None: - raise ValueError( - f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." - ) - else: - self.encoder_hid_proj = None - - # class embedding - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - elif class_embed_type == "projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" - ) - # The projection `class_embed_type` is the same as the timestep `class_embed_type` except - # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings - # 2. it projects from an arbitrary input dimension. - # - # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. - # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. - # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - else: - self.class_embedding = None - - if addition_embed_type == "text": - if encoder_hid_dim is not None: - text_time_embedding_from_dim = encoder_hid_dim - else: - text_time_embedding_from_dim = cross_attention_dim - - self.add_embedding = TextTimeEmbedding( - text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads - ) - elif addition_embed_type == "text_image": - # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` - self.add_embedding = TextImageTimeEmbedding( - text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim - ) - elif addition_embed_type == "text_time": - self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) - self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - - elif addition_embed_type is not None: - raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") - - # control net conditioning embedding - self.controlnet_cond_embedding = ControlNetConditioningEmbedding( - conditioning_embedding_channels=block_out_channels[0], - block_out_channels=conditioning_embedding_out_channels, - conditioning_channels=conditioning_channels, - ) - - self.down_blocks = nn.ModuleList([]) - self.controlnet_down_blocks = nn.ModuleList([]) - - if isinstance(only_cross_attention, bool): - only_cross_attention = [only_cross_attention] * len(down_block_types) - - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) - - if isinstance(num_attention_heads, int): - num_attention_heads = (num_attention_heads,) * len(down_block_types) - - # down - output_channel = block_out_channels[0] - - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_down_blocks.append(controlnet_block) - - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - transformer_layers_per_block=transformer_layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads[i], - attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, - downsample_padding=downsample_padding, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.down_blocks.append(down_block) - - for _ in range(layers_per_block): - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_down_blocks.append(controlnet_block) - - if not is_final_block: - controlnet_block = 
nn.Conv2d(output_channel, output_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_down_blocks.append(controlnet_block) - - # mid - mid_block_channel = block_out_channels[-1] - - controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_mid_block = controlnet_block - - if mid_block_type == "UNetMidBlock2DCrossAttn": - self.mid_block = UNetMidBlock2DCrossAttn( - transformer_layers_per_block=transformer_layers_per_block[-1], - in_channels=mid_block_channel, - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads[-1], - resnet_groups=norm_num_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - elif mid_block_type == "UNetMidBlock2D": - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - num_layers=0, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_groups=norm_num_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - add_attention=False, - ) - else: - raise ValueError(f"unknown mid_block_type : {mid_block_type}") - - @classmethod - def from_unet( - cls, - unet: UNet2DConditionModel, - controlnet_conditioning_channel_order: str = "rgb", - conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), - load_weights_from_unet: bool = True, - conditioning_channels: int = 3, - ): - r""" - Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. - - Parameters: - unet (`UNet2DConditionModel`): - The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied - where applicable. 
- """ - transformer_layers_per_block = ( - unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 - ) - encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None - encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None - addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None - addition_time_embed_dim = ( - unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None - ) - - controlnet = cls( - encoder_hid_dim=encoder_hid_dim, - encoder_hid_dim_type=encoder_hid_dim_type, - addition_embed_type=addition_embed_type, - addition_time_embed_dim=addition_time_embed_dim, - transformer_layers_per_block=transformer_layers_per_block, - in_channels=unet.config.in_channels, - flip_sin_to_cos=unet.config.flip_sin_to_cos, - freq_shift=unet.config.freq_shift, - down_block_types=unet.config.down_block_types, - only_cross_attention=unet.config.only_cross_attention, - block_out_channels=unet.config.block_out_channels, - layers_per_block=unet.config.layers_per_block, - downsample_padding=unet.config.downsample_padding, - mid_block_scale_factor=unet.config.mid_block_scale_factor, - act_fn=unet.config.act_fn, - norm_num_groups=unet.config.norm_num_groups, - norm_eps=unet.config.norm_eps, - cross_attention_dim=unet.config.cross_attention_dim, - attention_head_dim=unet.config.attention_head_dim, - num_attention_heads=unet.config.num_attention_heads, - use_linear_projection=unet.config.use_linear_projection, - class_embed_type=unet.config.class_embed_type, - num_class_embeds=unet.config.num_class_embeds, - upcast_attention=unet.config.upcast_attention, - resnet_time_scale_shift=unet.config.resnet_time_scale_shift, - projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, - mid_block_type=unet.config.mid_block_type, - controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, - conditioning_embedding_out_channels=conditioning_embedding_out_channels, - conditioning_channels=conditioning_channels, - ) - - if load_weights_from_unet: - controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) - controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) - controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) - - if controlnet.class_embedding: - controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) - - if hasattr(controlnet, "add_embedding"): - controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict()) - - controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) - controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) - - return controlnet - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. 
- """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. - """ - if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnAddedKVProcessor() - elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnProcessor() - else: - raise ValueError( - f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" - ) - - self.set_attn_processor(processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice - def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module splits the input tensor in slices to compute attention in - several steps. This is useful for saving some memory in exchange for a small decrease in speed. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If - `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. 
In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - - for child in module.children(): - fn_recursive_retrieve_sliceable_dims(child) - - # retrieve number of attention layers - for module in self.children(): - fn_recursive_retrieve_sliceable_dims(module) - - num_sliceable_layers = len(sliceable_head_dims) - - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = [dim // 2 for dim in sliceable_head_dims] - elif slice_size == "max": - # make smallest slice possible - slice_size = num_sliceable_layers * [1] - - slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size - - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" - f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." - ) - - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - # Recursively walk through all the children. - # Any children which exposes the set_attention_slice method - # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): - module.gradient_checkpointing = value - - def forward( - self, - sample: torch.Tensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - controlnet_cond: torch.Tensor, - conditioning_scale: float = 1.0, - class_labels: Optional[torch.Tensor] = None, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - guess_mode: bool = False, - return_dict: bool = True, - ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: - """ - The [`ControlNetModel`] forward method. - - Args: - sample (`torch.Tensor`): - The noisy input tensor. - timestep (`Union[torch.Tensor, float, int]`): - The number of timesteps to denoise an input. - encoder_hidden_states (`torch.Tensor`): - The encoder hidden states. - controlnet_cond (`torch.Tensor`): - The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. - conditioning_scale (`float`, defaults to `1.0`): - The scale factor for ControlNet outputs. - class_labels (`torch.Tensor`, *optional*, defaults to `None`): - Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. - timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): - Additional conditional embeddings for timestep. 
If provided, the embeddings will be summed with the - timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep - embeddings. - attention_mask (`torch.Tensor`, *optional*, defaults to `None`): - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask - is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large - negative values to the attention scores corresponding to "discard" tokens. - added_cond_kwargs (`dict`): - Additional conditions for the Stable Diffusion XL UNet. - cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): - A kwargs dictionary that if specified is passed along to the `AttnProcessor`. - guess_mode (`bool`, defaults to `False`): - In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if - you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. - - Returns: - [`~models.controlnet.ControlNetOutput`] **or** `tuple`: - If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is - returned where the first element is the sample tensor. - """ - # check channel order - channel_order = self.config.controlnet_conditioning_channel_order - - if channel_order == "rgb": - # in rgb order by default - ... - elif channel_order == "bgr": - controlnet_cond = torch.flip(controlnet_cond, dims=[1]) - else: - raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") - - # prepare attention_mask - if attention_mask is not None: - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. 
- t_emb = t_emb.to(dtype=sample.dtype) - - emb = self.time_embedding(t_emb, timestep_cond) - aug_emb = None - - if self.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) - emb = emb + class_emb - - if self.config.addition_embed_type is not None: - if self.config.addition_embed_type == "text": - aug_emb = self.add_embedding(encoder_hidden_states) - - elif self.config.addition_embed_type == "text_time": - if "text_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" - ) - text_embeds = added_cond_kwargs.get("text_embeds") - if "time_ids" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" - ) - time_ids = added_cond_kwargs.get("time_ids") - time_embeds = self.add_time_proj(time_ids.flatten()) - time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) - - add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) - add_embeds = add_embeds.to(emb.dtype) - aug_emb = self.add_embedding(add_embeds) - - emb = emb + aug_emb if aug_emb is not None else emb - - # 2. pre-process - sample = self.conv_in(sample) - - controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) - sample = sample + controlnet_cond - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - # 4. mid - if self.mid_block is not None: - if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - else: - sample = self.mid_block(sample, emb) - - # 5. Control net blocks - controlnet_down_block_res_samples = () - - for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): - down_block_res_sample = controlnet_block(down_block_res_sample) - controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) - - down_block_res_samples = controlnet_down_block_res_samples - - mid_block_res_sample = self.controlnet_mid_block(sample) - # 6. 
scaling - if guess_mode and not self.config.global_pool_conditions: - scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 - scales = scales * conditioning_scale - down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] - mid_block_res_sample = mid_block_res_sample * scales[-1] # last one - else: - down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] - mid_block_res_sample = mid_block_res_sample * conditioning_scale - if self.config.global_pool_conditions: - down_block_res_samples = [ - torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples - ] - mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) +class ControlNetOutput(ControlNetOutput): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetOutput`, instead." + deprecate("ControlNetOutput", "0.34", deprecation_message) + super().__init__(*args, **kwargs) - if not return_dict: - return (down_block_res_samples, mid_block_res_sample) - return ControlNetOutput( - down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample - ) +class ControlNetModel(ControlNetModel): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `ControlNetModel` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetModel`, instead." + deprecate("ControlNetModel", "0.34", deprecation_message) + super().__init__(*args, **kwargs) -def zero_module(module): - for p in module.parameters(): - nn.init.zeros_(p) - return module +class ControlNetConditioningEmbedding(ControlNetConditioningEmbedding): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `ControlNetConditioningEmbedding` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding`, instead." + deprecate("ControlNetConditioningEmbedding", "0.34", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py index 961e30155a3d..9b256239d712 100644 --- a/src/diffusers/models/controlnet_flux.py +++ b/src/diffusers/models/controlnet_flux.py @@ -12,525 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. 
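The guess-mode branch above damps shallow residuals more than deep ones with a log-spaced ramp from 0.1 to 1.0 before applying the user's `conditioning_scale`. A small sketch of that scaling in isolation (the number and shapes of the residuals are made up):

```python
import torch

conditioning_scale = 1.0
down_block_res_samples = [torch.randn(1, 320, 64, 64) for _ in range(12)]  # toy residuals
mid_block_res_sample = torch.randn(1, 1280, 8, 8)

# one scale per down-block residual plus one for the mid block, ramping 0.1 -> 1.0
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1) * conditioning_scale
down_block_res_samples = [s * scale for s, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1]  # deepest residual keeps the full scale

print([round(s.item(), 3) for s in scales])
```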
-from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union -import torch -import torch.nn as nn - -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import PeftAdapterMixin -from ..models.attention_processor import AttentionProcessor -from ..models.modeling_utils import ModelMixin -from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module -from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed -from .modeling_outputs import Transformer2DModelOutput -from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock +from ..utils import deprecate, logging +from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name -@dataclass -class FluxControlNetOutput(BaseOutput): - controlnet_block_samples: Tuple[torch.Tensor] - controlnet_single_block_samples: Tuple[torch.Tensor] - - -class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - patch_size: int = 1, - in_channels: int = 64, - num_layers: int = 19, - num_single_layers: int = 38, - attention_head_dim: int = 128, - num_attention_heads: int = 24, - joint_attention_dim: int = 4096, - pooled_projection_dim: int = 768, - guidance_embeds: bool = False, - axes_dims_rope: List[int] = [16, 56, 56], - num_mode: int = None, - conditioning_embedding_channels: int = None, - ): - super().__init__() - self.out_channels = in_channels - self.inner_dim = num_attention_heads * attention_head_dim - - self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) - text_time_guidance_cls = ( - CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings - ) - self.time_text_embed = text_time_guidance_cls( - embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim - ) - - self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) - self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) - - self.transformer_blocks = nn.ModuleList( - [ - FluxTransformerBlock( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - ) - for i in range(num_layers) - ] - ) - - self.single_transformer_blocks = nn.ModuleList( - [ - FluxSingleTransformerBlock( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - ) - for i in range(num_single_layers) - ] - ) - - # controlnet_blocks - self.controlnet_blocks = nn.ModuleList([]) - for _ in range(len(self.transformer_blocks)): - self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) - - self.controlnet_single_blocks = nn.ModuleList([]) - for _ in range(len(self.single_transformer_blocks)): - self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) - - self.union = num_mode is not None - if self.union: - self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim) - - if conditioning_embedding_channels is not None: - self.input_hint_block = ControlNetConditioningEmbedding( - conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16) - ) - 
self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim) - else: - self.input_hint_block = None - self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim)) - - self.gradient_checkpointing = False - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self): - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
- ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - @classmethod - def from_transformer( - cls, - transformer, - num_layers: int = 4, - num_single_layers: int = 10, - attention_head_dim: int = 128, - num_attention_heads: int = 24, - load_weights_from_transformer=True, - ): - config = transformer.config - config["num_layers"] = num_layers - config["num_single_layers"] = num_single_layers - config["attention_head_dim"] = attention_head_dim - config["num_attention_heads"] = num_attention_heads - - controlnet = cls(**config) - - if load_weights_from_transformer: - controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) - controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()) - controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()) - controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict()) - controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) - controlnet.single_transformer_blocks.load_state_dict( - transformer.single_transformer_blocks.state_dict(), strict=False - ) - - controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder) - - return controlnet - - def forward( - self, - hidden_states: torch.Tensor, - controlnet_cond: torch.Tensor, - controlnet_mode: torch.Tensor = None, - conditioning_scale: float = 1.0, - encoder_hidden_states: torch.Tensor = None, - pooled_projections: torch.Tensor = None, - timestep: torch.LongTensor = None, - img_ids: torch.Tensor = None, - txt_ids: torch.Tensor = None, - guidance: torch.Tensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: - """ - The [`FluxTransformer2DModel`] forward method. - - Args: - hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): - Input `hidden_states`. - controlnet_cond (`torch.Tensor`): - The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. - controlnet_mode (`torch.Tensor`): - The mode tensor of shape `(batch_size, 1)`. - conditioning_scale (`float`, defaults to `1.0`): - The scale factor for ControlNet outputs. - encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): - Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. - timestep ( `torch.LongTensor`): - Used to indicate denoising step. - block_controlnet_hidden_states: (`list` of `torch.Tensor`): - A list of tensors that if specified are added to the residuals of transformer blocks. 
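`from_transformer` above builds a (typically shallower) Flux ControlNet from a pretrained Flux transformer, copying the embedders and the block weights that line up and zero-initializing `controlnet_x_embedder`. A hedged usage sketch; the checkpoint id and layer counts are illustrative, and gated repositories may require logging in first:

```python
import torch
from diffusers import FluxControlNetModel, FluxTransformer2DModel

# illustrative checkpoint: any FLUX.1-style transformer checkpoint would do
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# a shallower ControlNet initialized from the transformer: matching embedders and
# block weights are copied, and `controlnet_x_embedder` starts zero-initialized
controlnet = FluxControlNetModel.from_transformer(
    transformer,
    num_layers=4,          # double-stream blocks (illustrative count)
    num_single_layers=10,  # single-stream blocks (illustrative count)
)
print(sum(p.numel() for p in controlnet.parameters()) / 1e9, "B params")
```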
- joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." - ) - hidden_states = self.x_embedder(hidden_states) - - if self.input_hint_block is not None: - controlnet_cond = self.input_hint_block(controlnet_cond) - batch_size, channels, height_pw, width_pw = controlnet_cond.shape - height = height_pw // self.config.patch_size - width = width_pw // self.config.patch_size - controlnet_cond = controlnet_cond.reshape( - batch_size, channels, height, self.config.patch_size, width, self.config.patch_size - ) - controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5) - controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1) - # add - hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) - - timestep = timestep.to(hidden_states.dtype) * 1000 - if guidance is not None: - guidance = guidance.to(hidden_states.dtype) * 1000 - else: - guidance = None - temb = ( - self.time_text_embed(timestep, pooled_projections) - if guidance is None - else self.time_text_embed(timestep, guidance, pooled_projections) - ) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) - - if self.union: - # union mode - if controlnet_mode is None: - raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union") - # union mode emb - controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode) - encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1) - txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0) - - if txt_ids.ndim == 3: - logger.warning( - "Passing `txt_ids` 3d torch.Tensor is deprecated." - "Please remove the batch dimension and pass it as a 2d torch Tensor" - ) - txt_ids = txt_ids[0] - if img_ids.ndim == 3: - logger.warning( - "Passing `img_ids` 3d torch.Tensor is deprecated." 
- "Please remove the batch dimension and pass it as a 2d torch Tensor" - ) - img_ids = img_ids[0] - - ids = torch.cat((txt_ids, img_ids), dim=0) - image_rotary_emb = self.pos_embed(ids) - - block_samples = () - for index_block, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - temb, - image_rotary_emb, - **ckpt_kwargs, - ) - - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - ) - block_samples = block_samples + (hidden_states,) - - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - single_block_samples = () - for index_block, block in enumerate(self.single_transformer_blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - temb, - image_rotary_emb, - **ckpt_kwargs, - ) - - else: - hidden_states = block( - hidden_states=hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - ) - single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) - - # controlnet block - controlnet_block_samples = () - for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): - block_sample = controlnet_block(block_sample) - controlnet_block_samples = controlnet_block_samples + (block_sample,) - - controlnet_single_block_samples = () - for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks): - single_block_sample = controlnet_block(single_block_sample) - controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,) - - # scaling - controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples] - controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples] - - controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples - controlnet_single_block_samples = ( - None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples - ) - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (controlnet_block_samples, controlnet_single_block_samples) - - return FluxControlNetOutput( - controlnet_block_samples=controlnet_block_samples, - controlnet_single_block_samples=controlnet_single_block_samples, - ) - - -class FluxMultiControlNetModel(ModelMixin): - r""" - `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel - - This module is 
a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be - compatible with `FluxControlNetModel`. - - Args: - controlnets (`List[FluxControlNetModel]`): - Provides additional conditioning to the unet during the denoising process. You must set multiple - `FluxControlNetModel` as a list. - """ - - def __init__(self, controlnets): - super().__init__() - self.nets = nn.ModuleList(controlnets) - - def forward( - self, - hidden_states: torch.FloatTensor, - controlnet_cond: List[torch.tensor], - controlnet_mode: List[torch.tensor], - conditioning_scale: List[float], - encoder_hidden_states: torch.Tensor = None, - pooled_projections: torch.Tensor = None, - timestep: torch.LongTensor = None, - img_ids: torch.Tensor = None, - txt_ids: torch.Tensor = None, - guidance: torch.Tensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> Union[FluxControlNetOutput, Tuple]: - # ControlNet-Union with multiple conditions - # only load one ControlNet for saving memories - if len(self.nets) == 1 and self.nets[0].union: - controlnet = self.nets[0] - - for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)): - block_samples, single_block_samples = controlnet( - hidden_states=hidden_states, - controlnet_cond=image, - controlnet_mode=mode[:, None], - conditioning_scale=scale, - timestep=timestep, - guidance=guidance, - pooled_projections=pooled_projections, - encoder_hidden_states=encoder_hidden_states, - txt_ids=txt_ids, - img_ids=img_ids, - joint_attention_kwargs=joint_attention_kwargs, - return_dict=return_dict, - ) - - # merge samples - if i == 0: - control_block_samples = block_samples - control_single_block_samples = single_block_samples - else: - control_block_samples = [ - control_block_sample + block_sample - for control_block_sample, block_sample in zip(control_block_samples, block_samples) - ] +class FluxControlNetOutput(FluxControlNetOutput): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `FluxControlNetOutput` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetOutput`, instead." + deprecate("FluxControlNetOutput", "0.34", deprecation_message) + super().__init__(*args, **kwargs) - control_single_block_samples = [ - control_single_block_sample + block_sample - for control_single_block_sample, block_sample in zip( - control_single_block_samples, single_block_samples - ) - ] - # Regular Multi-ControlNets - # load all ControlNets into memories - else: - for i, (image, mode, scale, controlnet) in enumerate( - zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets) - ): - block_samples, single_block_samples = controlnet( - hidden_states=hidden_states, - controlnet_cond=image, - controlnet_mode=mode[:, None], - conditioning_scale=scale, - timestep=timestep, - guidance=guidance, - pooled_projections=pooled_projections, - encoder_hidden_states=encoder_hidden_states, - txt_ids=txt_ids, - img_ids=img_ids, - joint_attention_kwargs=joint_attention_kwargs, - return_dict=return_dict, - ) +class FluxControlNetModel(FluxControlNetModel): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `FluxControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. 
Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel`, instead." + deprecate("FluxControlNetModel", "0.34", deprecation_message) + super().__init__(*args, **kwargs) - # merge samples - if i == 0: - control_block_samples = block_samples - control_single_block_samples = single_block_samples - else: - if block_samples is not None and control_block_samples is not None: - control_block_samples = [ - control_block_sample + block_sample - for control_block_sample, block_sample in zip(control_block_samples, block_samples) - ] - if single_block_samples is not None and control_single_block_samples is not None: - control_single_block_samples = [ - control_single_block_sample + block_sample - for control_single_block_sample, block_sample in zip( - control_single_block_samples, single_block_samples - ) - ] - return control_block_samples, control_single_block_samples +class FluxMultiControlNetModel(FluxMultiControlNetModel): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `FluxMultiControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel`, instead." + deprecate("FluxMultiControlNetModel", "0.34", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py index 43b52a645a0d..5e70559e9ac4 100644 --- a/src/diffusers/models/controlnet_sd3.py +++ b/src/diffusers/models/controlnet_sd3.py @@ -13,410 +13,29 @@ # limitations under the License. -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn - -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import FromOriginalModelMixin, PeftAdapterMixin -from ..models.attention import JointTransformerBlock -from ..models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 -from ..models.modeling_outputs import Transformer2DModelOutput -from ..models.modeling_utils import ModelMixin -from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from .controlnet import BaseOutput, zero_module -from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed +from ..utils import deprecate, logging +from .controlnets.controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name -@dataclass -class SD3ControlNetOutput(BaseOutput): - controlnet_block_samples: Tuple[torch.Tensor] - - -class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: int = 128, - patch_size: int = 2, - in_channels: int = 16, - num_layers: int = 18, - attention_head_dim: int = 64, - num_attention_heads: int = 18, - joint_attention_dim: int = 4096, - caption_projection_dim: int = 1152, - pooled_projection_dim: int = 2048, - out_channels: int = 16, - pos_embed_max_size: int = 96, - extra_conditioning_channels: int = 0, - ): - super().__init__() - default_out_channels = in_channels - self.out_channels = out_channels if out_channels is not None else default_out_channels - self.inner_dim = num_attention_heads * attention_head_dim - - self.pos_embed = PatchEmbed( - height=sample_size, - 
width=sample_size, - patch_size=patch_size, - in_channels=in_channels, - embed_dim=self.inner_dim, - pos_embed_max_size=pos_embed_max_size, - ) - self.time_text_embed = CombinedTimestepTextProjEmbeddings( - embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim - ) - self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim) - - # `attention_head_dim` is doubled to account for the mixing. - # It needs to crafted when we get the actual checkpoints. - self.transformer_blocks = nn.ModuleList( - [ - JointTransformerBlock( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=self.config.attention_head_dim, - context_pre_only=False, - ) - for i in range(num_layers) - ] - ) - - # controlnet_blocks - self.controlnet_blocks = nn.ModuleList([]) - for _ in range(len(self.transformer_blocks)): - controlnet_block = nn.Linear(self.inner_dim, self.inner_dim) - controlnet_block = zero_module(controlnet_block) - self.controlnet_blocks.append(controlnet_block) - pos_embed_input = PatchEmbed( - height=sample_size, - width=sample_size, - patch_size=patch_size, - in_channels=in_channels + extra_conditioning_channels, - embed_dim=self.inner_dim, - pos_embed_type=None, - ) - self.pos_embed_input = zero_module(pos_embed_input) - - self.gradient_checkpointing = False - - # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking - def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - """ - Sets the attention processor to use [feed forward - chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). - - Parameters: - chunk_size (`int`, *optional*): - The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually - over each tensor of dim=`dim`. - dim (`int`, *optional*, defaults to `0`): - The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) - or dim=1 (sequence length). - """ - if dim not in [0, 1]: - raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") - - # By default chunk size is 1 - chunk_size = chunk_size or 1 - - def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): - if hasattr(module, "set_chunk_feed_forward"): - module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) - - for child in module.children(): - fn_recursive_feed_forward(child, chunk_size, dim) - - for module in self.children(): - fn_recursive_feed_forward(module, chunk_size, dim) - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. 
- """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedJointAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. 
- - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - @classmethod - def from_transformer( - cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True - ): - config = transformer.config - config["num_layers"] = num_layers or config.num_layers - config["extra_conditioning_channels"] = num_extra_conditioning_channels - controlnet = cls(**config) - - if load_weights_from_transformer: - controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) - controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()) - controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()) - controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) - - controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input) - - return controlnet - - def forward( - self, - hidden_states: torch.FloatTensor, - controlnet_cond: torch.Tensor, - conditioning_scale: float = 1.0, - encoder_hidden_states: torch.FloatTensor = None, - pooled_projections: torch.FloatTensor = None, - timestep: torch.LongTensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: - """ - The [`SD3Transformer2DModel`] forward method. - - Args: - hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): - Input `hidden_states`. - controlnet_cond (`torch.Tensor`): - The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. - conditioning_scale (`float`, defaults to `1.0`): - The scale factor for ControlNet outputs. - encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): - Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. - timestep ( `torch.LongTensor`): - Used to indicate denoising step. - joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
- ) - - hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. - temb = self.time_text_embed(timestep, pooled_projections) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) - - # add - hidden_states = hidden_states + self.pos_embed_input(controlnet_cond) - - block_res_samples = () - - for block in self.transformer_blocks: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - temb, - **ckpt_kwargs, - ) - - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb - ) - - block_res_samples = block_res_samples + (hidden_states,) - - controlnet_block_res_samples = () - for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): - block_res_sample = controlnet_block(block_res_sample) - controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) - - # 6. scaling - controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples] - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (controlnet_block_res_samples,) - - return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples) - - -class SD3MultiControlNetModel(ModelMixin): - r""" - `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet - - This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be - compatible with `SD3ControlNetModel`. - - Args: - controlnets (`List[SD3ControlNetModel]`): - Provides additional conditioning to the unet during the denoising process. You must set multiple - `SD3ControlNetModel` as a list. - """ +class SD3ControlNetOutput(SD3ControlNetOutput): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SD3ControlNetOutput` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetOutput`, instead." 
+ deprecate("SD3ControlNetOutput", "0.34", deprecation_message) + super().__init__(*args, **kwargs) - def __init__(self, controlnets): - super().__init__() - self.nets = nn.ModuleList(controlnets) - def forward( - self, - hidden_states: torch.FloatTensor, - controlnet_cond: List[torch.tensor], - conditioning_scale: List[float], - pooled_projections: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - timestep: torch.LongTensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> Union[SD3ControlNetOutput, Tuple]: - for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): - block_samples = controlnet( - hidden_states=hidden_states, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - pooled_projections=pooled_projections, - controlnet_cond=image, - conditioning_scale=scale, - joint_attention_kwargs=joint_attention_kwargs, - return_dict=return_dict, - ) +class SD3ControlNetModel(SD3ControlNetModel): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SD3ControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetModel`, instead." + deprecate("SD3ControlNetModel", "0.34", deprecation_message) + super().__init__(*args, **kwargs) - # merge samples - if i == 0: - control_block_samples = block_samples - else: - control_block_samples = [ - control_block_sample + block_sample - for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0]) - ] - control_block_samples = (tuple(control_block_samples),) - return control_block_samples +class SD3MultiControlNetModel(SD3MultiControlNetModel): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SD3MultiControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3MultiControlNetModel`, instead." + deprecate("SD3MultiControlNetModel", "0.34", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py index fa37e1f9e393..1ccbd385b9a6 100644 --- a/src/diffusers/models/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnet_sparsectrl.py @@ -12,777 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union -import torch -from torch import nn -from torch.nn import functional as F - -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import FromOriginalModelMixin -from ..utils import BaseOutput, logging -from .attention_processor import ( - ADDED_KV_ATTENTION_PROCESSORS, - CROSS_ATTENTION_PROCESSORS, - AttentionProcessor, - AttnAddedKVProcessor, - AttnProcessor, +from ..utils import deprecate, logging +from .controlnets.controlnet_sparsectrl import ( # noqa + SparseControlNetConditioningEmbedding, + SparseControlNetModel, + SparseControlNetOutput, + zero_module, ) -from .embeddings import TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin -from .unets.unet_2d_blocks import UNetMidBlock2DCrossAttn -from .unets.unet_2d_condition import UNet2DConditionModel -from .unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion logger = logging.get_logger(__name__) # pylint: disable=invalid-name -@dataclass -class SparseControlNetOutput(BaseOutput): - """ - The output of [`SparseControlNetModel`]. - - Args: - down_block_res_samples (`tuple[torch.Tensor]`): - A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should - be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be - used to condition the original UNet's downsampling activations. - mid_down_block_re_sample (`torch.Tensor`): - The activation of the middle block (the lowest sample resolution). Each tensor should be of shape - `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. - Output can be used to condition the original UNet's middle block activation. - """ - - down_block_res_samples: Tuple[torch.Tensor] - mid_block_res_sample: torch.Tensor - - -class SparseControlNetConditioningEmbedding(nn.Module): - def __init__( - self, - conditioning_embedding_channels: int, - conditioning_channels: int = 3, - block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), - ): - super().__init__() - - self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) - self.blocks = nn.ModuleList([]) - - for i in range(len(block_out_channels) - 1): - channel_in = block_out_channels[i] - channel_out = block_out_channels[i + 1] - self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) - self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) - - self.conv_out = zero_module( - nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) - ) - - def forward(self, conditioning: torch.Tensor) -> torch.Tensor: - embedding = self.conv_in(conditioning) - embedding = F.silu(embedding) - - for block in self.blocks: - embedding = block(embedding) - embedding = F.silu(embedding) - - embedding = self.conv_out(embedding) - return embedding - - -class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): - """ - A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion - Models](https://arxiv.org/abs/2311.16933). - - Args: - in_channels (`int`, defaults to 4): - The number of channels in the input sample. - conditioning_channels (`int`, defaults to 4): - The number of input channels in the controlnet conditional embedding module. 
If - `concat_condition_embedding` is True, the value provided here is incremented by 1. - flip_sin_to_cos (`bool`, defaults to `True`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, defaults to 0): - The frequency shift to apply to the time embedding. - down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. - only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): - block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - layers_per_block (`int`, defaults to 2): - The number of layers per block. - downsample_padding (`int`, defaults to 1): - The padding to use for the downsampling convolution. - mid_block_scale_factor (`float`, defaults to 1): - The scale factor to use for the mid block. - act_fn (`str`, defaults to "silu"): - The activation function to use. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups to use for the normalization. If None, normalization and activation layers is skipped - in post-processing. - norm_eps (`float`, defaults to 1e-5): - The epsilon to use for the normalization. - cross_attention_dim (`int`, defaults to 1280): - The dimension of the cross attention features. - transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for - [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], - [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. - transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1): - The number of transformer layers to use in each layer in the middle block. - attention_head_dim (`int` or `Tuple[int]`, defaults to 8): - The dimension of the attention heads. - num_attention_heads (`int` or `Tuple[int]`, *optional*): - The number of heads to use for multi-head attention. - use_linear_projection (`bool`, defaults to `False`): - upcast_attention (`bool`, defaults to `False`): - resnet_time_scale_shift (`str`, defaults to `"default"`): - Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. - conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`): - The tuple of output channel for each block in the `conditioning_embedding` layer. - global_pool_conditions (`bool`, defaults to `False`): - TODO(Patrick) - unused parameter - controlnet_conditioning_channel_order (`str`, defaults to `rgb`): - motion_max_seq_length (`int`, defaults to `32`): - The maximum sequence length to use in the motion module. - motion_num_attention_heads (`int` or `Tuple[int]`, defaults to `8`): - The number of heads to use in each attention layer of the motion module. - concat_conditioning_mask (`bool`, defaults to `True`): - use_simplified_condition_embedding (`bool`, defaults to `True`): - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - in_channels: int = 4, - conditioning_channels: int = 4, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str, ...] = ( - "CrossAttnDownBlockMotion", - "CrossAttnDownBlockMotion", - "CrossAttnDownBlockMotion", - "DownBlockMotion", - ), - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int, ...] 
= (320, 640, 1280, 1280), - layers_per_block: int = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: Optional[int] = 32, - norm_eps: float = 1e-5, - cross_attention_dim: int = 768, - transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, - transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None, - temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, - attention_head_dim: Union[int, Tuple[int, ...]] = 8, - num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, - use_linear_projection: bool = False, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), - global_pool_conditions: bool = False, - controlnet_conditioning_channel_order: str = "rgb", - motion_max_seq_length: int = 32, - motion_num_attention_heads: int = 8, - concat_conditioning_mask: bool = True, - use_simplified_condition_embedding: bool = True, - ): - super().__init__() - self.use_simplified_condition_embedding = use_simplified_condition_embedding - - # If `num_attention_heads` is not defined (which is the case for most models) - # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. - # The reason for this behavior is to correct for incorrectly named variables that were introduced - # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 - # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking - # which is why we correct for the naming here. - num_attention_heads = num_attention_heads or attention_head_dim - - # Check inputs - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
- ) - - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) - if isinstance(temporal_transformer_layers_per_block, int): - temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types) - - # input - conv_in_kernel = 3 - conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = nn.Conv2d( - in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding - ) - - if concat_conditioning_mask: - conditioning_channels = conditioning_channels + 1 - - self.concat_conditioning_mask = concat_conditioning_mask - - # control net conditioning embedding - if use_simplified_condition_embedding: - self.controlnet_cond_embedding = zero_module( - nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) - ) - else: - self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding( - conditioning_embedding_channels=block_out_channels[0], - block_out_channels=conditioning_embedding_out_channels, - conditioning_channels=conditioning_channels, - ) - - # time - time_embed_dim = block_out_channels[0] * 4 - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - - self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, - ) - - self.down_blocks = nn.ModuleList([]) - self.controlnet_down_blocks = nn.ModuleList([]) - - if isinstance(cross_attention_dim, int): - cross_attention_dim = (cross_attention_dim,) * len(down_block_types) - - if isinstance(only_cross_attention, bool): - only_cross_attention = [only_cross_attention] * len(down_block_types) - - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) - - if isinstance(num_attention_heads, int): - num_attention_heads = (num_attention_heads,) * len(down_block_types) - - if isinstance(motion_num_attention_heads, int): - motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types) - - # down - output_channel = block_out_channels[0] - - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_down_blocks.append(controlnet_block) - - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - if down_block_type == "CrossAttnDownBlockMotion": - down_block = CrossAttnDownBlockMotion( - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - dropout=0, - num_layers=layers_per_block, - transformer_layers_per_block=transformer_layers_per_block[i], - resnet_eps=norm_eps, - resnet_time_scale_shift=resnet_time_scale_shift, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - resnet_pre_norm=True, - num_attention_heads=num_attention_heads[i], - cross_attention_dim=cross_attention_dim[i], - add_downsample=not is_final_block, - dual_cross_attention=False, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - temporal_num_attention_heads=motion_num_attention_heads[i], - temporal_max_seq_length=motion_max_seq_length, - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], - temporal_double_self_attention=False, - ) - elif down_block_type == "DownBlockMotion": - 
down_block = DownBlockMotion( - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - dropout=0, - num_layers=layers_per_block, - resnet_eps=norm_eps, - resnet_time_scale_shift=resnet_time_scale_shift, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - resnet_pre_norm=True, - add_downsample=not is_final_block, - temporal_num_attention_heads=motion_num_attention_heads[i], - temporal_max_seq_length=motion_max_seq_length, - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], - temporal_double_self_attention=False, - ) - else: - raise ValueError( - "Invalid `block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`" - ) - - self.down_blocks.append(down_block) - - for _ in range(layers_per_block): - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_down_blocks.append(controlnet_block) - - if not is_final_block: - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_down_blocks.append(controlnet_block) - - # mid - mid_block_channels = block_out_channels[-1] - - controlnet_block = nn.Conv2d(mid_block_channels, mid_block_channels, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_mid_block = controlnet_block - - if transformer_layers_per_mid_block is None: - transformer_layers_per_mid_block = ( - transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1 - ) - - self.mid_block = UNetMidBlock2DCrossAttn( - in_channels=mid_block_channels, - temb_channels=time_embed_dim, - dropout=0, - num_layers=1, - transformer_layers_per_block=transformer_layers_per_mid_block, - resnet_eps=norm_eps, - resnet_time_scale_shift=resnet_time_scale_shift, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - resnet_pre_norm=True, - num_attention_heads=num_attention_heads[-1], - output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim[-1], - dual_cross_attention=False, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - attention_type="default", - ) - - @classmethod - def from_unet( - cls, - unet: UNet2DConditionModel, - controlnet_conditioning_channel_order: str = "rgb", - conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), - load_weights_from_unet: bool = True, - conditioning_channels: int = 3, - ) -> "SparseControlNetModel": - r""" - Instantiate a [`SparseControlNetModel`] from [`UNet2DConditionModel`]. - - Parameters: - unet (`UNet2DConditionModel`): - The UNet model weights to copy to the [`SparseControlNetModel`]. All configuration options are also - copied where applicable. - """ - transformer_layers_per_block = ( - unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 - ) - down_block_types = unet.config.down_block_types - - for i in range(len(down_block_types)): - if "CrossAttn" in down_block_types[i]: - down_block_types[i] = "CrossAttnDownBlockMotion" - elif "Down" in down_block_types[i]: - down_block_types[i] = "DownBlockMotion" - else: - raise ValueError("Invalid `block_type` encountered. 
Must be a cross-attention or down block") - - controlnet = cls( - in_channels=unet.config.in_channels, - conditioning_channels=conditioning_channels, - flip_sin_to_cos=unet.config.flip_sin_to_cos, - freq_shift=unet.config.freq_shift, - down_block_types=unet.config.down_block_types, - only_cross_attention=unet.config.only_cross_attention, - block_out_channels=unet.config.block_out_channels, - layers_per_block=unet.config.layers_per_block, - downsample_padding=unet.config.downsample_padding, - mid_block_scale_factor=unet.config.mid_block_scale_factor, - act_fn=unet.config.act_fn, - norm_num_groups=unet.config.norm_num_groups, - norm_eps=unet.config.norm_eps, - cross_attention_dim=unet.config.cross_attention_dim, - transformer_layers_per_block=transformer_layers_per_block, - attention_head_dim=unet.config.attention_head_dim, - num_attention_heads=unet.config.num_attention_heads, - use_linear_projection=unet.config.use_linear_projection, - upcast_attention=unet.config.upcast_attention, - resnet_time_scale_shift=unet.config.resnet_time_scale_shift, - conditioning_embedding_out_channels=conditioning_embedding_out_channels, - controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, - ) - - if load_weights_from_unet: - controlnet.conv_in.load_state_dict(unet.conv_in.state_dict(), strict=False) - controlnet.time_proj.load_state_dict(unet.time_proj.state_dict(), strict=False) - controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict(), strict=False) - controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False) - controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False) - - return controlnet - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. 
Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. - """ - if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnAddedKVProcessor() - elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnProcessor() - else: - raise ValueError( - f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" - ) - - self.set_attn_processor(processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice - def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module splits the input tensor in slices to compute attention in - several steps. This is useful for saving some memory in exchange for a small decrease in speed. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If - `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - - for child in module.children(): - fn_recursive_retrieve_sliceable_dims(child) - - # retrieve number of attention layers - for module in self.children(): - fn_recursive_retrieve_sliceable_dims(module) - - num_sliceable_layers = len(sliceable_head_dims) - - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = [dim // 2 for dim in sliceable_head_dims] - elif slice_size == "max": - # make smallest slice possible - slice_size = num_sliceable_layers * [1] - - slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size - - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" - f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." - ) - - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - # Recursively walk through all the children. 
- # Any children which exposes the set_attention_slice method - # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, UNetMidBlock2DCrossAttn)): - module.gradient_checkpointing = value - - def forward( - self, - sample: torch.Tensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - controlnet_cond: torch.Tensor, - conditioning_scale: float = 1.0, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - conditioning_mask: Optional[torch.Tensor] = None, - guess_mode: bool = False, - return_dict: bool = True, - ) -> Union[SparseControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: - """ - The [`SparseControlNetModel`] forward method. - - Args: - sample (`torch.Tensor`): - The noisy input tensor. - timestep (`Union[torch.Tensor, float, int]`): - The number of timesteps to denoise an input. - encoder_hidden_states (`torch.Tensor`): - The encoder hidden states. - controlnet_cond (`torch.Tensor`): - The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. - conditioning_scale (`float`, defaults to `1.0`): - The scale factor for ControlNet outputs. - class_labels (`torch.Tensor`, *optional*, defaults to `None`): - Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. - timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): - Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the - timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep - embeddings. - attention_mask (`torch.Tensor`, *optional*, defaults to `None`): - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask - is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large - negative values to the attention scores corresponding to "discard" tokens. - added_cond_kwargs (`dict`): - Additional conditions for the Stable Diffusion XL UNet. - cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): - A kwargs dictionary that if specified is passed along to the `AttnProcessor`. - guess_mode (`bool`, defaults to `False`): - In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if - you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. - Returns: - [`~models.controlnet.ControlNetOutput`] **or** `tuple`: - If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is - returned where the first element is the sample tensor. 
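As a side note on the `set_attention_slice` helper above: the user-facing `slice_size` argument (`"auto"`, `"max"`, or an integer) is first resolved into one slice size per sliceable attention layer before being pushed down recursively. A minimal standalone sketch of that resolution step, using made-up head dimensions rather than a real model:

```python
from typing import List, Union


def resolve_slice_sizes(slice_size: Union[str, int, List[int]], sliceable_head_dims: List[int]) -> List[int]:
    """Sketch of the slice-size resolution performed in `set_attention_slice`."""
    num_layers = len(sliceable_head_dims)
    if slice_size == "auto":
        # halve each head dimension, so attention runs in two steps per layer
        return [dim // 2 for dim in sliceable_head_dims]
    if slice_size == "max":
        # one slice at a time, maximum memory savings
        return num_layers * [1]
    return num_layers * [slice_size] if not isinstance(slice_size, list) else slice_size


print(resolve_slice_sizes("auto", [8, 8, 16]))  # [4, 4, 8]
print(resolve_slice_sizes("max", [8, 8, 16]))   # [1, 1, 1]
print(resolve_slice_sizes(4, [8, 8, 16]))       # [4, 4, 4]
```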
- """ - sample_batch_size, sample_channels, sample_num_frames, sample_height, sample_width = sample.shape - sample = torch.zeros_like(sample) - - # check channel order - channel_order = self.config.controlnet_conditioning_channel_order - - if channel_order == "rgb": - # in rgb order by default - ... - elif channel_order == "bgr": - controlnet_cond = torch.flip(controlnet_cond, dims=[1]) - else: - raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") - - # prepare attention_mask - if attention_mask is not None: - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=sample.dtype) - - emb = self.time_embedding(t_emb, timestep_cond) - emb = emb.repeat_interleave(sample_num_frames, dim=0) - - # 2. pre-process - batch_size, channels, num_frames, height, width = sample.shape - - sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - sample = self.conv_in(sample) - - batch_frames, channels, height, width = sample.shape - sample = sample[:, None].reshape(sample_batch_size, sample_num_frames, channels, height, width) - - if self.concat_conditioning_mask: - controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1) - - batch_size, channels, num_frames, height, width = controlnet_cond.shape - controlnet_cond = controlnet_cond.permute(0, 2, 1, 3, 4).reshape( - batch_size * num_frames, channels, height, width - ) - controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) - batch_frames, channels, height, width = controlnet_cond.shape - controlnet_cond = controlnet_cond[:, None].reshape(batch_size, num_frames, channels, height, width) - - sample = sample + controlnet_cond - - batch_size, num_frames, channels, height, width = sample.shape - sample = sample.reshape(sample_batch_size * sample_num_frames, channels, height, width) - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - num_frames=num_frames, - cross_attention_kwargs=cross_attention_kwargs, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) - - down_block_res_samples += res_samples - - # 4. 
mid - if self.mid_block is not None: - if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - else: - sample = self.mid_block(sample, emb) - - # 5. Control net blocks - controlnet_down_block_res_samples = () - - for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): - down_block_res_sample = controlnet_block(down_block_res_sample) - controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) - - down_block_res_samples = controlnet_down_block_res_samples - mid_block_res_sample = self.controlnet_mid_block(sample) - - # 6. scaling - if guess_mode and not self.config.global_pool_conditions: - scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 - scales = scales * conditioning_scale - down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] - mid_block_res_sample = mid_block_res_sample * scales[-1] # last one - else: - down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] - mid_block_res_sample = mid_block_res_sample * conditioning_scale - - if self.config.global_pool_conditions: - down_block_res_samples = [ - torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples - ] - mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) +class SparseControlNetOutput(SparseControlNetOutput): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SparseControlNetOutput` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetOutput`, instead." + deprecate("SparseControlNetOutput", "0.34", deprecation_message) + super().__init__(*args, **kwargs) - if not return_dict: - return (down_block_res_samples, mid_block_res_sample) - return SparseControlNetOutput( - down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample - ) +class SparseControlNetConditioningEmbedding(SparseControlNetConditioningEmbedding): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SparseControlNetConditioningEmbedding` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetConditioningEmbedding`, instead." + deprecate("SparseControlNetConditioningEmbedding", "0.34", deprecation_message) + super().__init__(*args, **kwargs) -# Copied from diffusers.models.controlnet.zero_module -def zero_module(module: nn.Module) -> nn.Module: - for p in module.parameters(): - nn.init.zeros_(p) - return module +class SparseControlNetModel(SparseControlNetModel): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SparseControlNetModel` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetModel`, instead." 
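For downstream code, the shim classes in this hunk keep the old `diffusers.models.controlnet_sparsectrl` path importable while the real implementations move into the new `controlnets` subpackage; instantiating through the old path goes through `deprecate(..., "0.34", ...)`. A migration sketch (both imports resolve to the same model class):

```python
# Deprecated location: still works through the shim subclasses in this hunk,
# but emits a deprecation warning when instantiated (removal targeted for 0.34).
from diffusers.models.controlnet_sparsectrl import SparseControlNetModel as LegacySparseControlNetModel

# Preferred location after this refactor:
from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetModel
```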
+ deprecate("SparseControlNetModel", "0.34", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/src/diffusers/models/controlnets/__init__.py b/src/diffusers/models/controlnets/__init__.py new file mode 100644 index 000000000000..3e4b3561e839 --- /dev/null +++ b/src/diffusers/models/controlnets/__init__.py @@ -0,0 +1,22 @@ +from ...utils import is_flax_available, is_torch_available + + +if is_torch_available(): + from .controlnet import ControlNetModel, ControlNetOutput + from .controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel + from .controlnet_hunyuan import ( + HunyuanControlNetOutput, + HunyuanDiT2DControlNetModel, + HunyuanDiT2DMultiControlNetModel, + ) + from .controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel + from .controlnet_sparsectrl import ( + SparseControlNetConditioningEmbedding, + SparseControlNetModel, + SparseControlNetOutput, + ) + from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel + from .multicontrolnet import MultiControlNetModel + +if is_flax_available(): + from .controlnet_flax import FlaxControlNetModel diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py new file mode 100644 index 000000000000..bd00f6dd1906 --- /dev/null +++ b/src/diffusers/models/controlnets/controlnet.py @@ -0,0 +1,872 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders.single_file_model import FromOriginalModelMixin +from ...utils import BaseOutput, logging +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from ..embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..unets.unet_2d_blocks import ( + CrossAttnDownBlock2D, + DownBlock2D, + UNetMidBlock2D, + UNetMidBlock2DCrossAttn, + get_down_block, +) +from ..unets.unet_2d_condition import UNet2DConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetOutput(BaseOutput): + """ + The output of [`ControlNetModel`]. + + Args: + down_block_res_samples (`tuple[torch.Tensor]`): + A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's downsampling activations. + mid_down_block_re_sample (`torch.Tensor`): + The activation of the middle block (the lowest sample resolution). 
Each tensor should be of shape + `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. + Output can be used to condition the original UNet's middle block activation. + """ + + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): + """ + A ControlNet model. + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + cross_attention_dim (`int`, defaults to 1280): + The dimension of the cross attention features. 
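To make the `ControlNetConditioningEmbedding` shapes above concrete: with the default `block_out_channels=(16, 32, 96, 256)` there are three stride-2 convolutions, so a 512 × 512 RGB condition comes out at the 64 × 64 latent resolution. A quick shape check, assuming the new module path added by this patch (the width 320 is just the first UNet block of SD 1.5):

```python
import torch
from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding

embed = ControlNetConditioningEmbedding(
    conditioning_embedding_channels=320,   # block_out_channels[0] of the SD 1.5 UNet
    conditioning_channels=3,               # RGB condition image
    block_out_channels=(16, 32, 96, 256),  # default "tiny encoder" widths
)

cond = torch.randn(1, 3, 512, 512)  # pixel-space conditioning image
feat = embed(cond)
print(feat.shape)  # torch.Size([1, 320, 64, 64])
# Note: `conv_out` is zero-initialized via `zero_module`, so `feat` is all zeros
# until the embedding is trained.
```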
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): + The dimension of the attention heads. + use_linear_projection (`bool`, defaults to `False`): + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + num_class_embeds (`int`, *optional*, defaults to 0): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + upcast_attention (`bool`, defaults to `False`): + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): + The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when + `class_embed_type="projection"`. + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + global_pool_conditions (`bool`, defaults to `False`): + TODO(Patrick) - unused parameter. + addition_embed_type_num_heads (`int`, defaults to 64): + The number of heads to use for the `TextTimeEmbedding` layer. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 3, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int, ...] 
= (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int, ...]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + addition_embed_type_num_heads: int = 64, + ): + super().__init__() + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
+ ) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
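To illustrate the comment above: with `class_embed_type="projection"`, `TimestepEmbedding` acts as a plain projection MLP, so `class_labels` can be any feature vector of size `projection_class_embeddings_input_dim`. A tiny sketch with made-up dimensions:

```python
import torch
from diffusers.models.embeddings import TimestepEmbedding

# hypothetical dims: project a 1280-d "label" vector to a 1280-d time embedding
proj = TimestepEmbedding(in_channels=1280, time_embed_dim=1280)

class_labels = torch.randn(2, 1280)  # no sinusoidal encoding applied first
class_emb = proj(class_labels)
print(class_emb.shape)  # torch.Size([2, 1280])
```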
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + # control net conditioning embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = 
nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlock2D": + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + conditioning_channels: int = 3, + ): + r""" + Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied + where applicable. 
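The `from_unet` classmethod documented here follows the original ControlNet recipe: duplicate the UNet's encoder and mid-block configuration (and optionally its weights) into a fresh, trainable control branch. A usage sketch; the checkpoint id is only an example and the imports rely on the existing top-level re-exports:

```python
from diffusers import ControlNetModel, UNet2DConditionModel

# example checkpoint; any UNet2DConditionModel works
unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)

# copy the config and the encoder/mid-block weights into a new ControlNet
controlnet = ControlNetModel.from_unet(unet, load_weights_from_unet=True)
print(f"{sum(p.numel() for p in controlnet.parameters()) / 1e6:.0f}M params")
```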
+ """ + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None + encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None + addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None + addition_time_embed_dim = ( + unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None + ) + + controlnet = cls( + encoder_hid_dim=encoder_hid_dim, + encoder_hid_dim_type=encoder_hid_dim_type, + addition_embed_type=addition_embed_type, + addition_time_embed_dim=addition_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=unet.config.in_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + attention_head_dim=unet.config.attention_head_dim, + num_attention_heads=unet.config.num_attention_heads, + use_linear_projection=unet.config.use_linear_projection, + class_embed_type=unet.config.class_embed_type, + num_class_embeds=unet.config.num_class_embeds, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, + mid_block_type=unet.config.mid_block_type, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + if load_weights_from_unet: + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + if controlnet.class_embedding: + controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) + + if hasattr(controlnet, "add_embedding"): + controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict()) + + controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) + + return controlnet + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. 
In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: + """ + The [`ControlNetModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. 
If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + guess_mode (`bool`, defaults to `False`): + In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if + you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain + tuple. + + Returns: + [`~models.controlnets.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + """ + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + elif channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. 
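Two preprocessing steps in the forward pass here are easy to misread: the key-token attention mask is turned into an additive bias, and a scalar timestep is broadcast across the batch before `time_proj`. A standalone sketch of both (shapes are illustrative):

```python
import torch

# attention mask -> additive bias: kept tokens (1) contribute 0, dropped
# tokens (0) contribute a large negative value before the softmax
attention_mask = torch.tensor([[1, 1, 0, 0]])
bias = (1 - attention_mask.float()) * -10000.0  # tensor([[-0., -0., -10000., -10000.]])

# scalar timestep -> one entry per sample, in an ONNX/Core ML friendly way
sample = torch.randn(2, 4, 64, 64)
timesteps = torch.tensor([999], dtype=torch.int64, device=sample.device)
timesteps = timesteps.expand(sample.shape[0])   # tensor([999, 999])
```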
+ t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + if self.config.addition_embed_type is not None: + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + + # 2. pre-process + sample = self.conv_in(sample) + + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + sample = sample + controlnet_cond + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = self.mid_block(sample, emb) + + # 5. Control net blocks + + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. 
scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + scales = scales * conditioning_scale + down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + else: + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return ControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnets/controlnet_flax.py similarity index 98% rename from src/diffusers/models/controlnet_flax.py rename to src/diffusers/models/controlnets/controlnet_flax.py index 0540850a9e61..ab8d9b5f8cbb 100644 --- a/src/diffusers/models/controlnet_flax.py +++ b/src/diffusers/models/controlnets/controlnet_flax.py @@ -19,11 +19,11 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict -from ..configuration_utils import ConfigMixin, flax_register_to_config -from ..utils import BaseOutput -from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps -from .modeling_flax_utils import FlaxModelMixin -from .unets.unet_2d_blocks_flax import ( +from ...configuration_utils import ConfigMixin, flax_register_to_config +from ...utils import BaseOutput +from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps +from ..modeling_flax_utils import FlaxModelMixin +from ..unets.unet_2d_blocks_flax import ( FlaxCrossAttnDownBlock2D, FlaxDownBlock2D, FlaxUNetMidBlock2DCrossAttn, diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py new file mode 100644 index 000000000000..e6a3eceed9b4 --- /dev/null +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -0,0 +1,536 @@ +# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
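Putting the `ControlNetModel` above together: because every control projection passes through `zero_module`, a freshly created model adds nothing to the UNet at first, and at inference time its residuals are injected into the UNet through the `down_block_additional_residuals` / `mid_block_additional_residual` arguments. A single denoising-step sketch with SD 1.5 shapes (checkpoint ids are illustrative):

```python
import torch
from diffusers import ControlNetModel, UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

latents = torch.randn(1, 4, 64, 64)        # noisy latent sample
t = torch.tensor([500])                    # denoising timestep
prompt_embeds = torch.randn(1, 77, 768)    # CLIP text embeddings
cond_image = torch.rand(1, 3, 512, 512)    # pixel-space condition (e.g. Canny edges)

down_res, mid_res = controlnet(
    latents,
    t,
    encoder_hidden_states=prompt_embeds,
    controlnet_cond=cond_image,
    conditioning_scale=1.0,
    guess_mode=False,  # True rescales the residuals from 0.1 to 1.0 (logspace)
    return_dict=False,
)

noise_pred = unet(
    latents,
    t,
    encoder_hidden_states=prompt_embeds,
    down_block_additional_residuals=down_res,
    mid_block_additional_residual=mid_res,
).sample
print(noise_pred.shape)  # torch.Size([1, 4, 64, 64])
```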
+ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...models.attention_processor import AttentionProcessor +from ...models.modeling_utils import ModelMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ..controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module +from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed +from ..modeling_outputs import Transformer2DModelOutput +from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FluxControlNetOutput(BaseOutput): + controlnet_block_samples: Tuple[torch.Tensor] + controlnet_single_block_samples: Tuple[torch.Tensor] + + +class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + pooled_projection_dim: int = 768, + guidance_embeds: bool = False, + axes_dims_rope: List[int] = [16, 56, 56], + num_mode: int = None, + conditioning_embedding_channels: int = None, + ): + super().__init__() + self.out_channels = in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) + text_time_guidance_cls = ( + CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings + ) + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim + ) + + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for i in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for i in range(num_single_layers) + ] + ) + + # controlnet_blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(len(self.transformer_blocks)): + self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) + + self.controlnet_single_blocks = nn.ModuleList([]) + for _ in range(len(self.single_transformer_blocks)): + self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) + + self.union = num_mode is not None + if self.union: + self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim) + + if conditioning_embedding_channels is not None: + self.input_hint_block = ControlNetConditioningEmbedding( + conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16) + ) + self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim) + else: + self.input_hint_block = None + 
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim)) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self): + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @classmethod + def from_transformer( + cls, + transformer, + num_layers: int = 4, + num_single_layers: int = 10, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + load_weights_from_transformer=True, + ): + config = transformer.config + config["num_layers"] = num_layers + config["num_single_layers"] = num_single_layers + config["attention_head_dim"] = attention_head_dim + config["num_attention_heads"] = num_attention_heads + + controlnet = cls(**config) + + if load_weights_from_transformer: + controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) + controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()) + controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()) + controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict()) + controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) + controlnet.single_transformer_blocks.load_state_dict( + transformer.single_transformer_blocks.state_dict(), strict=False + ) + + controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder) + + return controlnet + + def forward( + self, + hidden_states: torch.Tensor, + controlnet_cond: torch.Tensor, + controlnet_mode: torch.Tensor = None, + conditioning_scale: float = 1.0, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + controlnet_cond (`torch.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + controlnet_mode (`torch.Tensor`): + The mode tensor of shape `(batch_size, 1)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. 
+ joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + hidden_states = self.x_embedder(hidden_states) + + if self.input_hint_block is not None: + controlnet_cond = self.input_hint_block(controlnet_cond) + batch_size, channels, height_pw, width_pw = controlnet_cond.shape + height = height_pw // self.config.patch_size + width = width_pw // self.config.patch_size + controlnet_cond = controlnet_cond.reshape( + batch_size, channels, height, self.config.patch_size, width, self.config.patch_size + ) + controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5) + controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1) + # add + hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + else: + guidance = None + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if self.union: + # union mode + if controlnet_mode is None: + raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union") + # union mode emb + controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode) + encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1) + txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." 
+ "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + block_samples = () + for index_block, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + block_samples = block_samples + (hidden_states,) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + single_block_samples = () + for index_block, block in enumerate(self.single_transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) + + # controlnet block + controlnet_block_samples = () + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): + block_sample = controlnet_block(block_sample) + controlnet_block_samples = controlnet_block_samples + (block_sample,) + + controlnet_single_block_samples = () + for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks): + single_block_sample = controlnet_block(single_block_sample) + controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,) + + # scaling + controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples] + controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples] + + controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples + controlnet_single_block_samples = ( + None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples + ) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (controlnet_block_samples, controlnet_single_block_samples) + + return FluxControlNetOutput( + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + ) + + +class FluxMultiControlNetModel(ModelMixin): + r""" + `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel + + This module is 
a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be + compatible with `FluxControlNetModel`. + + Args: + controlnets (`List[FluxControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `FluxControlNetModel` as a list. + """ + + def __init__(self, controlnets): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + controlnet_cond: List[torch.tensor], + controlnet_mode: List[torch.tensor], + conditioning_scale: List[float], + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[FluxControlNetOutput, Tuple]: + # ControlNet-Union with multiple conditions + # only load one ControlNet for saving memories + if len(self.nets) == 1 and self.nets[0].union: + controlnet = self.nets[0] + + for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)): + block_samples, single_block_samples = controlnet( + hidden_states=hidden_states, + controlnet_cond=image, + controlnet_mode=mode[:, None], + conditioning_scale=scale, + timestep=timestep, + guidance=guidance, + pooled_projections=pooled_projections, + encoder_hidden_states=encoder_hidden_states, + txt_ids=txt_ids, + img_ids=img_ids, + joint_attention_kwargs=joint_attention_kwargs, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + control_block_samples = block_samples + control_single_block_samples = single_block_samples + else: + control_block_samples = [ + control_block_sample + block_sample + for control_block_sample, block_sample in zip(control_block_samples, block_samples) + ] + + control_single_block_samples = [ + control_single_block_sample + block_sample + for control_single_block_sample, block_sample in zip( + control_single_block_samples, single_block_samples + ) + ] + + # Regular Multi-ControlNets + # load all ControlNets into memories + else: + for i, (image, mode, scale, controlnet) in enumerate( + zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets) + ): + block_samples, single_block_samples = controlnet( + hidden_states=hidden_states, + controlnet_cond=image, + controlnet_mode=mode[:, None], + conditioning_scale=scale, + timestep=timestep, + guidance=guidance, + pooled_projections=pooled_projections, + encoder_hidden_states=encoder_hidden_states, + txt_ids=txt_ids, + img_ids=img_ids, + joint_attention_kwargs=joint_attention_kwargs, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + control_block_samples = block_samples + control_single_block_samples = single_block_samples + else: + if block_samples is not None and control_block_samples is not None: + control_block_samples = [ + control_block_sample + block_sample + for control_block_sample, block_sample in zip(control_block_samples, block_samples) + ] + if single_block_samples is not None and control_single_block_samples is not None: + control_single_block_samples = [ + control_single_block_sample + block_sample + for control_single_block_sample, block_sample in zip( + control_single_block_samples, single_block_samples + ) + ] + + return control_block_samples, control_single_block_samples diff --git a/src/diffusers/models/controlnet_hunyuan.py 
b/src/diffusers/models/controlnets/controlnet_hunyuan.py similarity index 98% rename from src/diffusers/models/controlnet_hunyuan.py rename to src/diffusers/models/controlnets/controlnet_hunyuan.py index 4277d81d1cb9..f2aa34d2d056 100644 --- a/src/diffusers/models/controlnet_hunyuan.py +++ b/src/diffusers/models/controlnets/controlnet_hunyuan.py @@ -17,17 +17,17 @@ import torch from torch import nn -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import logging -from .attention_processor import AttentionProcessor -from .controlnet import BaseOutput, Tuple, zero_module -from .embeddings import ( +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ..attention_processor import AttentionProcessor +from ..embeddings import ( HunyuanCombinedTimestepTextSizeStyleEmbedding, PatchEmbed, PixArtAlphaTextProjection, ) -from .modeling_utils import ModelMixin -from .transformers.hunyuan_transformer_2d import HunyuanDiTBlock +from ..modeling_utils import ModelMixin +from ..transformers.hunyuan_transformer_2d import HunyuanDiTBlock +from .controlnet import BaseOutput, Tuple, zero_module logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py new file mode 100644 index 000000000000..911d65e03d88 --- /dev/null +++ b/src/diffusers/models/controlnets/controlnet_sd3.py @@ -0,0 +1,422 @@ +# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
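
A brief note on a pattern that recurs throughout these ControlNet variants: the projection layers that feed residuals back into the base model (`controlnet_blocks`, `controlnet_x_embedder`, `pos_embed_input`, and so on) are wrapped in `zero_module`, so a freshly initialized ControlNet adds exactly zero to the hidden states and training starts from the unmodified base model. Below is a minimal sketch of that behaviour; it only mirrors the `zero_module` helper defined earlier in this patch, and the layer sizes are illustrative.

```py
import torch
import torch.nn as nn


def zero_module(module: nn.Module) -> nn.Module:
    # Same idea as the helper above: zero every parameter so the layer starts as a no-op.
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


proj = zero_module(nn.Linear(64, 64))
cond = torch.randn(2, 16, 64)

# With zeroed weights and bias, the ControlNet branch contributes nothing at first,
# so `hidden_states + proj(cond)` leaves the base model's activations unchanged.
assert torch.equal(proj(cond), torch.zeros(2, 16, 64))
```
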
+ + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ..attention import JointTransformerBlock +from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 +from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from .controlnet import BaseOutput, zero_module + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class SD3ControlNetOutput(BaseOutput): + controlnet_block_samples: Tuple[torch.Tensor] + + +class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int = 128, + patch_size: int = 2, + in_channels: int = 16, + num_layers: int = 18, + attention_head_dim: int = 64, + num_attention_heads: int = 18, + joint_attention_dim: int = 4096, + caption_projection_dim: int = 1152, + pooled_projection_dim: int = 2048, + out_channels: int = 16, + pos_embed_max_size: int = 96, + extra_conditioning_channels: int = 0, + ): + super().__init__() + default_out_channels = in_channels + self.out_channels = out_channels if out_channels is not None else default_out_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=self.inner_dim, + pos_embed_max_size=pos_embed_max_size, + ) + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim + ) + self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim) + + # `attention_head_dim` is doubled to account for the mixing. + # It needs to crafted when we get the actual checkpoints. + self.transformer_blocks = nn.ModuleList( + [ + JointTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + context_pre_only=False, + ) + for i in range(num_layers) + ] + ) + + # controlnet_blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(len(self.transformer_blocks)): + controlnet_block = nn.Linear(self.inner_dim, self.inner_dim) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks.append(controlnet_block) + pos_embed_input = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels + extra_conditioning_channels, + embed_dim=self.inner_dim, + pos_embed_type=None, + ) + self.pos_embed_input = zero_module(pos_embed_input) + + self.gradient_checkpointing = False + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. 
If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. 
+ + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedJointAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @classmethod + def from_transformer( + cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True + ): + config = transformer.config + config["num_layers"] = num_layers or config.num_layers + config["extra_conditioning_channels"] = num_extra_conditioning_channels + controlnet = cls(**config) + + if load_weights_from_transformer: + controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) + controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()) + controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()) + controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) + + controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input) + + return controlnet + + def forward( + self, + hidden_states: torch.FloatTensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + encoder_hidden_states: torch.FloatTensor = None, + pooled_projections: torch.FloatTensor = None, + timestep: torch.LongTensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`SD3Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + controlnet_cond (`torch.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. 
+ + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. + temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + # add + hidden_states = hidden_states + self.pos_embed_input(controlnet_cond) + + block_res_samples = () + + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) + + block_res_samples = block_res_samples + (hidden_states,) + + controlnet_block_res_samples = () + for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): + block_res_sample = controlnet_block(block_res_sample) + controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) + + # 6. scaling + controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples] + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (controlnet_block_res_samples,) + + return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples) + + +class SD3MultiControlNetModel(ModelMixin): + r""" + `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet + + This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be + compatible with `SD3ControlNetModel`. + + Args: + controlnets (`List[SD3ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `SD3ControlNetModel` as a list. 
+ """ + + def __init__(self, controlnets): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + controlnet_cond: List[torch.tensor], + conditioning_scale: List[float], + pooled_projections: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + timestep: torch.LongTensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[SD3ControlNetOutput, Tuple]: + for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): + block_samples = controlnet( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + pooled_projections=pooled_projections, + controlnet_cond=image, + conditioning_scale=scale, + joint_attention_kwargs=joint_attention_kwargs, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + control_block_samples = block_samples + else: + control_block_samples = [ + control_block_sample + block_sample + for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0]) + ] + control_block_samples = (tuple(control_block_samples),) + + return control_block_samples diff --git a/src/diffusers/models/controlnets/controlnet_sparsectrl.py b/src/diffusers/models/controlnets/controlnet_sparsectrl.py new file mode 100644 index 000000000000..fd599c10b2d7 --- /dev/null +++ b/src/diffusers/models/controlnets/controlnet_sparsectrl.py @@ -0,0 +1,788 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils import BaseOutput, logging +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..unets.unet_2d_blocks import UNetMidBlock2DCrossAttn +from ..unets.unet_2d_condition import UNet2DConditionModel +from ..unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class SparseControlNetOutput(BaseOutput): + """ + The output of [`SparseControlNetModel`]. + + Args: + down_block_res_samples (`tuple[torch.Tensor]`): + A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's downsampling activations. 
+ mid_down_block_re_sample (`torch.Tensor`): + The activation of the middle block (the lowest sample resolution). Each tensor should be of shape + `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. + Output can be used to condition the original UNet's middle block activation. + """ + + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class SparseControlNetConditioningEmbedding(nn.Module): + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning: torch.Tensor) -> torch.Tensor: + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + return embedding + + +class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): + """ + A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion + Models](https://arxiv.org/abs/2311.16933). + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + conditioning_channels (`int`, defaults to 4): + The number of input channels in the controlnet conditional embedding module. If + `concat_condition_embedding` is True, the value provided here is incremented by 1. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + cross_attention_dim (`int`, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. 
Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer layers to use in each layer in the middle block. + attention_head_dim (`int` or `Tuple[int]`, defaults to 8): + The dimension of the attention heads. + num_attention_heads (`int` or `Tuple[int]`, *optional*): + The number of heads to use for multi-head attention. + use_linear_projection (`bool`, defaults to `False`): + upcast_attention (`bool`, defaults to `False`): + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + global_pool_conditions (`bool`, defaults to `False`): + TODO(Patrick) - unused parameter + controlnet_conditioning_channel_order (`str`, defaults to `rgb`): + motion_max_seq_length (`int`, defaults to `32`): + The maximum sequence length to use in the motion module. + motion_num_attention_heads (`int` or `Tuple[int]`, defaults to `8`): + The number of heads to use in each attention layer of the motion module. + concat_conditioning_mask (`bool`, defaults to `True`): + use_simplified_condition_embedding (`bool`, defaults to `True`): + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 4, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlockMotion", + "CrossAttnDownBlockMotion", + "CrossAttnDownBlockMotion", + "DownBlockMotion", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 768, + transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, + transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None, + temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, + attention_head_dim: Union[int, Tuple[int, ...]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, + use_linear_projection: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + controlnet_conditioning_channel_order: str = "rgb", + motion_max_seq_length: int = 32, + motion_num_attention_heads: int = 8, + concat_conditioning_mask: bool = True, + use_simplified_condition_embedding: bool = True, + ): + super().__init__() + self.use_simplified_condition_embedding = use_simplified_condition_embedding + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. 
The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types) + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + if concat_conditioning_mask: + conditioning_channels = conditioning_channels + 1 + + self.concat_conditioning_mask = concat_conditioning_mask + + # control net conditioning embedding + if use_simplified_condition_embedding: + self.controlnet_cond_embedding = zero_module( + nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + ) + else: + self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(motion_num_attention_heads, int): + motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + 
self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "CrossAttnDownBlockMotion": + down_block = CrossAttnDownBlockMotion( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + dropout=0, + num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + resnet_pre_norm=True, + num_attention_heads=num_attention_heads[i], + cross_attention_dim=cross_attention_dim[i], + add_downsample=not is_final_block, + dual_cross_attention=False, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + temporal_num_attention_heads=motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], + temporal_double_self_attention=False, + ) + elif down_block_type == "DownBlockMotion": + down_block = DownBlockMotion( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + dropout=0, + num_layers=layers_per_block, + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + resnet_pre_norm=True, + add_downsample=not is_final_block, + temporal_num_attention_heads=motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], + temporal_double_self_attention=False, + ) + else: + raise ValueError( + "Invalid `block_type` encountered. 
Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`" + ) + + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channels = block_out_channels[-1] + + controlnet_block = nn.Conv2d(mid_block_channels, mid_block_channels, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + if transformer_layers_per_mid_block is None: + transformer_layers_per_mid_block = ( + transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1 + ) + + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=mid_block_channels, + temb_channels=time_embed_dim, + dropout=0, + num_layers=1, + transformer_layers_per_block=transformer_layers_per_mid_block, + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + resnet_pre_norm=True, + num_attention_heads=num_attention_heads[-1], + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + dual_cross_attention=False, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type="default", + ) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + conditioning_channels: int = 3, + ) -> "SparseControlNetModel": + r""" + Instantiate a [`SparseControlNetModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model weights to copy to the [`SparseControlNetModel`]. All configuration options are also + copied where applicable. + """ + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + down_block_types = unet.config.down_block_types + + for i in range(len(down_block_types)): + if "CrossAttn" in down_block_types[i]: + down_block_types[i] = "CrossAttnDownBlockMotion" + elif "Down" in down_block_types[i]: + down_block_types[i] = "DownBlockMotion" + else: + raise ValueError("Invalid `block_type` encountered. 
Must be a cross-attention or down block") + + controlnet = cls( + in_channels=unet.config.in_channels, + conditioning_channels=conditioning_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + attention_head_dim=unet.config.attention_head_dim, + num_attention_heads=unet.config.num_attention_heads, + use_linear_projection=unet.config.use_linear_projection, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + ) + + if load_weights_from_unet: + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict(), strict=False) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict(), strict=False) + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict(), strict=False) + controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False) + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False) + + return controlnet + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. 
Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. 
+ # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, UNetMidBlock2DCrossAttn)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + conditioning_mask: Optional[torch.Tensor] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[SparseControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: + """ + The [`SparseControlNetModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + guess_mode (`bool`, defaults to `False`): + In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if + you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + Returns: + [`~models.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + returned where the first element is the sample tensor. 
+ """ + sample_batch_size, sample_channels, sample_num_frames, sample_height, sample_width = sample.shape + sample = torch.zeros_like(sample) + + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + elif channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + emb = emb.repeat_interleave(sample_num_frames, dim=0) + + # 2. pre-process + batch_size, channels, num_frames, height, width = sample.shape + + sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + sample = self.conv_in(sample) + + batch_frames, channels, height, width = sample.shape + sample = sample[:, None].reshape(sample_batch_size, sample_num_frames, channels, height, width) + + if self.concat_conditioning_mask: + controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1) + + batch_size, channels, num_frames, height, width = controlnet_cond.shape + controlnet_cond = controlnet_cond.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_frames, channels, height, width + ) + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + batch_frames, channels, height, width = controlnet_cond.shape + controlnet_cond = controlnet_cond[:, None].reshape(batch_size, num_frames, channels, height, width) + + sample = sample + controlnet_cond + + batch_size, num_frames, channels, height, width = sample.shape + sample = sample.reshape(sample_batch_size * sample_num_frames, channels, height, width) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + # 4. 
mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = self.mid_block(sample, emb) + + # 5. Control net blocks + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + scales = scales * conditioning_scale + down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + else: + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return SparseControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + +# Copied from diffusers.models.controlnets.controlnet.zero_module +def zero_module(module: nn.Module) -> nn.Module: + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py similarity index 99% rename from src/diffusers/models/controlnet_xs.py rename to src/diffusers/models/controlnets/controlnet_xs.py index f676a70f060a..06e0eda3c3b0 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnets/controlnet_xs.py @@ -19,10 +19,10 @@ import torch.utils.checkpoint from torch import Tensor, nn -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, is_torch_version, logging -from ..utils.torch_utils import apply_freeu -from .attention_processor import ( +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput, is_torch_version, logging +from ...utils.torch_utils import apply_freeu +from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention, @@ -31,10 +31,9 @@ AttnProcessor, FusedAttnProcessor2_0, ) -from .controlnet import ControlNetConditioningEmbedding -from .embeddings import TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin -from .unets.unet_2d_blocks import ( +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..unets.unet_2d_blocks import ( CrossAttnDownBlock2D, CrossAttnUpBlock2D, Downsample2D, @@ -43,7 +42,8 @@ UNetMidBlock2DCrossAttn, Upsample2D, ) -from .unets.unet_2d_condition import UNet2DConditionModel +from ..unets.unet_2d_condition import 
UNet2DConditionModel +from .controlnet import ControlNetConditioningEmbedding logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -1062,7 +1062,8 @@ def forward( added_cond_kwargs (`dict`): Additional conditions for the Stable Diffusion XL UNet. return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain + tuple. apply_control (`bool`, defaults to `True`): If `False`, the input is run only through the base model. diff --git a/src/diffusers/models/controlnets/multicontrolnet.py b/src/diffusers/models/controlnets/multicontrolnet.py new file mode 100644 index 000000000000..46c3d1681cc1 --- /dev/null +++ b/src/diffusers/models/controlnets/multicontrolnet.py @@ -0,0 +1,183 @@ +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn + +from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput +from ...models.modeling_utils import ModelMixin +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MultiControlNetModel(ModelMixin): + r""" + Multiple `ControlNetModel` wrapper class for Multi-ControlNet + + This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be + compatible with `ControlNetModel`. + + Args: + controlnets (`List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `ControlNetModel` as a list. + """ + + def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: List[torch.tensor], + conditioning_scale: List[float], + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple]: + for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): + down_samples, mid_sample = controlnet( + sample=sample, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=image, + conditioning_scale=scale, + class_labels=class_labels, + timestep_cond=timestep_cond, + attention_mask=attention_mask, + added_cond_kwargs=added_cond_kwargs, + cross_attention_kwargs=cross_attention_kwargs, + guess_mode=guess_mode, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + ] + mid_block_res_sample += mid_sample + + return down_block_res_samples, mid_block_res_sample + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + save_function: Callable = None, + safe_serialization: bool = True, + variant: Optional[str] = None, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded 
using the + `[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + variant (`str`, *optional*): + If specified, weights are saved in the format pytorch_model..bin. + """ + for idx, controlnet in enumerate(self.nets): + suffix = "" if idx == 0 else f"_{idx}" + controlnet.save_pretrained( + save_directory + suffix, + is_main_process=is_main_process, + save_function=save_function, + safe_serialization=safe_serialization, + variant=variant, + ) + + @classmethod + def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_path (`os.PathLike`): + A path to a *directory* containing model weights saved using + [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g., + `./my_model_directory/controlnet`. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each + GPU and the available CPU RAM if unset. 
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading by not initializing the weights and only loading the pre-trained weights. This + also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the + model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, + setting this argument to `True` will raise an error. + variant (`str`, *optional*): + If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is + ignored when using `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the + `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from + `safetensors` weights. If set to `False`, loading will *not* use `safetensors`. + """ + idx = 0 + controlnets = [] + + # load controlnet and append to list until no controlnet directory exists anymore + # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained` + # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ... + model_path_to_load = pretrained_model_path + while os.path.isdir(model_path_to_load): + controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs) + controlnets.append(controlnet) + + idx += 1 + model_path_to_load = pretrained_model_path + f"_{idx}" + + logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.") + + if len(controlnets) == 0: + raise ValueError( + f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}." 
+ ) + + return cls(controlnets) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py index 8b037cdc34fb..6dde7d6686ee 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py @@ -24,7 +24,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel -from ...models.controlnet_sparsectrl import SparseControlNetModel +from ...models.controlnets.controlnet_sparsectrl import SparseControlNetModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unets.unet_motion_model import MotionAdapter from ...schedulers import KarrasDiffusionSchedulers diff --git a/src/diffusers/pipelines/controlnet/multicontrolnet.py b/src/diffusers/pipelines/controlnet/multicontrolnet.py index e3c5ec6eed03..33790c10e064 100644 --- a/src/diffusers/pipelines/controlnet/multicontrolnet.py +++ b/src/diffusers/pipelines/controlnet/multicontrolnet.py @@ -1,183 +1,12 @@ -import os -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import torch -from torch import nn - -from ...models.controlnet import ControlNetModel, ControlNetOutput -from ...models.modeling_utils import ModelMixin -from ...utils import logging +from ...models.controlnets.multicontrolnet import MultiControlNetModel +from ...utils import deprecate, logging logger = logging.get_logger(__name__) -class MultiControlNetModel(ModelMixin): - r""" - Multiple `ControlNetModel` wrapper class for Multi-ControlNet - - This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be - compatible with `ControlNetModel`. - - Args: - controlnets (`List[ControlNetModel]`): - Provides additional conditioning to the unet during the denoising process. You must set multiple - `ControlNetModel` as a list. 
- """ - - def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]): - super().__init__() - self.nets = nn.ModuleList(controlnets) - - def forward( - self, - sample: torch.Tensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - controlnet_cond: List[torch.tensor], - conditioning_scale: List[float], - class_labels: Optional[torch.Tensor] = None, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - guess_mode: bool = False, - return_dict: bool = True, - ) -> Union[ControlNetOutput, Tuple]: - for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): - down_samples, mid_sample = controlnet( - sample=sample, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=image, - conditioning_scale=scale, - class_labels=class_labels, - timestep_cond=timestep_cond, - attention_mask=attention_mask, - added_cond_kwargs=added_cond_kwargs, - cross_attention_kwargs=cross_attention_kwargs, - guess_mode=guess_mode, - return_dict=return_dict, - ) - - # merge samples - if i == 0: - down_block_res_samples, mid_block_res_sample = down_samples, mid_sample - else: - down_block_res_samples = [ - samples_prev + samples_curr - for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) - ] - mid_block_res_sample += mid_sample - - return down_block_res_samples, mid_block_res_sample - - def save_pretrained( - self, - save_directory: Union[str, os.PathLike], - is_main_process: bool = True, - save_function: Callable = None, - safe_serialization: bool = True, - variant: Optional[str] = None, - ): - """ - Save a model and its configuration file to a directory, so that it can be re-loaded using the - `[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful when in distributed training like - TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on - the main process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful on distributed training like TPUs when one - need to replace `torch.save` by another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). - variant (`str`, *optional*): - If specified, weights are saved in the format pytorch_model..bin. - """ - for idx, controlnet in enumerate(self.nets): - suffix = "" if idx == 0 else f"_{idx}" - controlnet.save_pretrained( - save_directory + suffix, - is_main_process=is_main_process, - save_function=save_function, - safe_serialization=safe_serialization, - variant=variant, - ) - - @classmethod - def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs): - r""" - Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models. - - The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). 
To train - the model, you should first set it back in training mode with `model.train()`. - - The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come - pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning - task. - - The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those - weights are discarded. - - Parameters: - pretrained_model_path (`os.PathLike`): - A path to a *directory* containing model weights saved using - [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g., - `./my_model_directory/controlnet`. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype - will be automatically derived from the model's weights. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. - device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): - A map that specifies where each submodule should go. It doesn't need to be refined to each - parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the - same device. - - To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For - more information about each option see [designing a device - map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). - max_memory (`Dict`, *optional*): - A dictionary device identifier to maximum memory. Will default to the maximum memory available for each - GPU and the available CPU RAM if unset. - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading by not initializing the weights and only loading the pre-trained weights. This - also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the - model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, - setting this argument to `True` will raise an error. - variant (`str`, *optional*): - If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is - ignored when using `from_flax`. - use_safetensors (`bool`, *optional*, defaults to `None`): - If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the - `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from - `safetensors` weights. If set to `False`, loading will *not* use `safetensors`. - """ - idx = 0 - controlnets = [] - - # load controlnet and append to list until no controlnet directory exists anymore - # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained` - # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ... 
- model_path_to_load = pretrained_model_path - while os.path.isdir(model_path_to_load): - controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs) - controlnets.append(controlnet) - - idx += 1 - model_path_to_load = pretrained_model_path + f"_{idx}" - - logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.") - - if len(controlnets) == 0: - raise ValueError( - f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}." - ) - - return cls(controlnets) +class MultiControlNetModel(MultiControlNetModel): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `MultiControlNetModel` from `diffusers.pipelines.controlnet.multicontrolnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel`, instead." + deprecate("MultiControlNetModel", "0.34", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py index 9f674d2d7897..a589821c1f98 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py @@ -26,7 +26,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin from ...models.autoencoders import AutoencoderKL -from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel +from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel from ...models.transformers import SD3Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py index f362c8f3d0c1..437bb9f2f182 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py @@ -26,7 +26,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin from ...models.autoencoders import AutoencoderKL -from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel +from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel from ...models.transformers import SD3Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 9f33e26013d5..771150b085d5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -27,7 +27,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL -from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.controlnets.controlnet_flux import 
FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 810c970ab715..04582b71d780 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -13,7 +13,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL -from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 1f5f83561f1c..947e97e272f8 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -14,7 +14,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL -from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 295a94c1d2e4..12f31aec678b 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -31,7 +31,7 @@ from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import IPAdapterMixin from diffusers.models.attention_processor import AttnProcessor -from diffusers.models.controlnet_xs import UNetControlNetXSModel +from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet from diffusers.models.unets.unet_motion_model import UNetMotionModel From 5588725e8e7be497839432e5328c596169385f16 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 7 Nov 2024 03:33:39 +0100 Subject: [PATCH 04/11] [Flux] reduce explicit device transfers and typecasting in flux. (#9817) reduce explicit device transfers and typecasting in flux. 
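[Editor's note, not part of the patch: a minimal, hypothetical sketch of the pattern this commit applies. The tensor shape, variable names, and device selection below are illustrative assumptions; only the `torch.zeros(..., device=..., dtype=...)` versus `.to()` change is taken from the diff itself.]

```python
import torch

# Assumed target placement for illustration only.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Before: allocate a float32 tensor on the CPU, then copy and cast it.
ids_old = torch.zeros(64, 64, 3).to(device=device, dtype=dtype)

# After (the pattern used in this patch): allocate once, directly on the
# target device and in the target dtype, avoiding the intermediate tensor
# and the explicit host-to-device transfer.
ids_new = torch.zeros(64, 64, 3, device=device, dtype=dtype)

assert ids_old.device == ids_new.device and ids_old.dtype == ids_new.dtype
```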
--- src/diffusers/pipelines/flux/pipeline_flux.py | 6 +++--- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 4 ++-- .../flux/pipeline_flux_controlnet_image_to_image.py | 6 +++--- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 6 +++--- src/diffusers/pipelines/flux/pipeline_flux_img2img.py | 6 +++--- src/diffusers/pipelines/flux/pipeline_flux_inpaint.py | 6 +++--- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 040d935f1b88..ab4e0fc4d255 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -371,7 +371,7 @@ def encode_prompt( unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -427,7 +427,7 @@ def check_inputs( @staticmethod def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] @@ -437,7 +437,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height * latent_image_id_width, latent_image_id_channels ) - return latent_image_ids.to(device=device, dtype=dtype) + return latent_image_ids @staticmethod def _pack_latents(latents, batch_size, num_channels_latents, height, width): diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 771150b085d5..9965ffe42bea 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -452,7 +452,7 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] @@ -462,7 +462,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height * latent_image_id_width, latent_image_id_channels ) - return latent_image_ids.to(device=device, dtype=dtype) + return latent_image_ids @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 04582b71d780..937422e1b60d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -407,7 +407,7 @@ def encode_prompt( unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - 
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -495,7 +495,7 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] @@ -505,7 +505,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height * latent_image_id_width, latent_image_id_channels ) - return latent_image_ids.to(device=device, dtype=dtype) + return latent_image_ids @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 947e97e272f8..83cc59c0b1f7 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -417,7 +417,7 @@ def encode_prompt( unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -522,7 +522,7 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] @@ -532,7 +532,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height * latent_image_id_width, latent_image_id_channels ) - return latent_image_ids.to(device=device, dtype=dtype) + return latent_image_ids @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index 47f9f268ee9d..aa1a3e7fc3a4 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -391,7 +391,7 @@ def encode_prompt( unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -479,7 +479,7 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, 
width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] @@ -489,7 +489,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height * latent_image_id_width, latent_image_id_channels ) - return latent_image_ids.to(device=device, dtype=dtype) + return latent_image_ids @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index 766f9864839e..97824258b28f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -395,7 +395,7 @@ def encode_prompt( unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -500,7 +500,7 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] @@ -510,7 +510,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height * latent_image_id_width, latent_image_id_channels ) - return latent_image_ids.to(device=device, dtype=dtype) + return latent_image_ids @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents From 1b392544c758e45cc7097cc35309cb8cc11798e4 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 8 Nov 2024 17:49:00 +0530 Subject: [PATCH 05/11] Improve downloads of sharded variants (#9869) * update * update * update * update --------- Co-authored-by: Sayak Paul --- .../pipelines/pipeline_loading_utils.py | 29 +++- tests/pipelines/test_pipeline_utils.py | 131 +++++++++++++++++- 2 files changed, 155 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 5eba1952e608..0a7a222ec007 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -198,10 +198,31 @@ def convert_to_variant(filename): variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}" return variant_filename - for f in non_variant_filenames: - variant_filename = convert_to_variant(f) - if variant_filename not in usable_filenames: - usable_filenames.add(f) + def find_component(filename): + if not len(filename.split("/")) == 2: + return + component = filename.split("/")[0] + return component + + def has_sharded_variant(component, variant, variant_filenames): + # If component exists check for sharded variant index filename + # If component doesn't exist check main dir for sharded 
variant index filename + component = component + "/" if component else "" + variant_index_re = re.compile( + rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$" + ) + return any(f for f in variant_filenames if variant_index_re.match(f) is not None) + + for filename in non_variant_filenames: + if convert_to_variant(filename) in variant_filenames: + continue + + component = find_component(filename) + # If a sharded variant exists skip adding to allowed patterns + if has_sharded_variant(component, variant, variant_filenames): + continue + + usable_filenames.add(filename) return usable_filenames, variant_filenames diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index bb3bdc273cc4..acf7d9d8401b 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -18,7 +18,7 @@ StableDiffusionPipeline, UNet2DConditionModel, ) -from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible +from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings from diffusers.utils.testing_utils import torch_device @@ -210,6 +210,135 @@ def test_diffusers_is_compatible_no_components_only_variants(self): self.assertFalse(is_safetensors_compatible(filenames)) +class VariantCompatibleSiblingsTest(unittest.TestCase): + def test_only_non_variants_downloaded(self): + variant = "fp16" + filenames = [ + f"vae/diffusion_pytorch_model.{variant}.safetensors", + "vae/diffusion_pytorch_model.safetensors", + f"text_encoder/model.{variant}.safetensors", + "text_encoder/model.safetensors", + f"unet/diffusion_pytorch_model.{variant}.safetensors", + "unet/diffusion_pytorch_model.safetensors", + ] + + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) + assert all(variant not in f for f in model_filenames) + + def test_only_variants_downloaded(self): + variant = "fp16" + filenames = [ + f"vae/diffusion_pytorch_model.{variant}.safetensors", + "vae/diffusion_pytorch_model.safetensors", + f"text_encoder/model.{variant}.safetensors", + "text_encoder/model.safetensors", + f"unet/diffusion_pytorch_model.{variant}.safetensors", + "unet/diffusion_pytorch_model.safetensors", + ] + + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f for f in model_filenames) + + def test_mixed_variants_downloaded(self): + variant = "fp16" + non_variant_file = "text_encoder/model.safetensors" + filenames = [ + f"vae/diffusion_pytorch_model.{variant}.safetensors", + "vae/diffusion_pytorch_model.safetensors", + "text_encoder/model.safetensors", + f"unet/diffusion_pytorch_model.{variant}.safetensors", + "unet/diffusion_pytorch_model.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames) + + def test_non_variants_in_main_dir_downloaded(self): + variant = "fp16" + filenames = [ + f"diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + "model.safetensors", + f"model.{variant}.safetensors", + f"diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) + assert all(variant not in f for f in model_filenames) + + def 
test_variants_in_main_dir_downloaded(self): + variant = "fp16" + filenames = [ + f"diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + "model.safetensors", + f"model.{variant}.safetensors", + f"diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f for f in model_filenames) + + def test_mixed_variants_in_main_dir_downloaded(self): + variant = "fp16" + non_variant_file = "model.safetensors" + filenames = [ + f"diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + "model.safetensors", + f"diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames) + + def test_sharded_non_variants_downloaded(self): + variant = "fp16" + filenames = [ + f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json", + "unet/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", + f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", + f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) + assert all(variant not in f for f in model_filenames) + + def test_sharded_variants_downloaded(self): + variant = "fp16" + filenames = [ + f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json", + "unet/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", + f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", + f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f for f in model_filenames) + + def test_sharded_mixed_variants_downloaded(self): + variant = "fp16" + allowed_non_variant = "unet" + filenames = [ + f"vae/diffusion_pytorch_model.safetensors.index.{variant}.json", + "vae/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", + f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", + f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", + "vae/diffusion_pytorch_model-00001-of-00003.safetensors", + "vae/diffusion_pytorch_model-00002-of-00003.safetensors", + "vae/diffusion_pytorch_model-00003-of-00003.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) + + class ProgressBarTests(unittest.TestCase): def get_dummy_components_image_generation(self): cross_attention_dim = 8 From 
0be52c07d6b9b49245b616f9738e52bcf58cd9fe Mon Sep 17 00:00:00 2001 From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com> Date: Sat, 9 Nov 2024 00:02:32 +0530 Subject: [PATCH 06/11] [fix] Replaced shutil.copy with shutil.copyfile (#9885) fix shutil.copy --- src/diffusers/utils/dynamic_modules_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index f0cf953924ad..50d9bbaac57c 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -325,7 +325,7 @@ def get_cached_module_file( # We always copy local files (we could hash the file to see if there was a change, and give them the name of # that hash, to only copy when there is a modification but it seems overkill for now). # The only reason we do the copy is to avoid putting too many folders in sys.path. - shutil.copy(resolved_module_file, submodule_path / module_file) + shutil.copyfile(resolved_module_file, submodule_path / module_file) for module_needed in modules_needed: if len(module_needed.split(".")) == 2: module_needed = "/".join(module_needed.split(".")) @@ -333,7 +333,7 @@ def get_cached_module_file( if not os.path.exists(submodule_path / module_folder): os.makedirs(submodule_path / module_folder) module_needed = f"{module_needed}.py" - shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed) + shutil.copyfile(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed) else: # Get the commit hash # TODO: we will get this info in the etag soon, so retrieve it from there and not here. @@ -350,7 +350,7 @@ def get_cached_module_file( module_folder = module_file.split("/")[0] if not os.path.exists(submodule_path / module_folder): os.makedirs(submodule_path / module_folder) - shutil.copy(resolved_module_file, submodule_path / module_file) + shutil.copyfile(resolved_module_file, submodule_path / module_file) # Make sure we also have every file with relative for module_needed in modules_needed: From 5b972fbd6a6c50cf1afdf1ba34c34d84fc67861c Mon Sep 17 00:00:00 2001 From: Michael Tkachuk <61463055+MikeTkachuk@users.noreply.github.com> Date: Fri, 8 Nov 2024 14:03:26 -0500 Subject: [PATCH 07/11] Enabling gradient checkpointing in eval() mode (#9878) * refactored --- examples/community/matryoshka.py | 8 +++--- .../pixart/controlnet_pixart_alpha.py | 2 +- .../autoencoders/autoencoder_kl_allegro.py | 4 +-- .../autoencoders/autoencoder_kl_cogvideox.py | 10 +++---- .../autoencoders/autoencoder_kl_mochi.py | 10 +++---- .../autoencoder_kl_temporal_decoder.py | 2 +- src/diffusers/models/autoencoders/vae.py | 10 +++---- .../models/controlnets/controlnet_flux.py | 4 +-- .../models/controlnets/controlnet_sd3.py | 2 +- .../models/controlnets/controlnet_xs.py | 6 ++--- .../transformers/auraflow_transformer_2d.py | 4 +-- .../transformers/cogvideox_transformer_3d.py | 2 +- .../models/transformers/dit_transformer_2d.py | 2 +- .../transformers/latte_transformer_3d.py | 4 +-- .../transformers/pixart_transformer_2d.py | 2 +- .../transformers/stable_audio_transformer.py | 2 +- .../models/transformers/transformer_2d.py | 2 +- .../transformers/transformer_allegro.py | 2 +- .../transformers/transformer_cogview3plus.py | 2 +- .../models/transformers/transformer_flux.py | 4 +-- .../models/transformers/transformer_mochi.py | 2 +- .../models/transformers/transformer_sd3.py | 2 +- 
.../transformers/transformer_temporal.py | 2 +- src/diffusers/models/unets/unet_2d_blocks.py | 26 +++++++++---------- src/diffusers/models/unets/unet_3d_blocks.py | 10 +++---- .../models/unets/unet_motion_model.py | 10 +++---- .../models/unets/unet_stable_cascade.py | 4 +-- src/diffusers/models/unets/uvit_2d.py | 2 +- .../pipelines/audioldm2/modeling_audioldm2.py | 6 ++--- .../blip_diffusion/modeling_blip2.py | 2 +- .../versatile_diffusion/modeling_text_unet.py | 10 +++---- .../pipelines/kolors/text_encoder.py | 4 +-- .../pipeline_latent_diffusion.py | 2 +- .../wuerstchen/modeling_wuerstchen_prior.py | 2 +- 34 files changed, 84 insertions(+), 84 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 7ac0ab542910..0c85ad118752 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -868,7 +868,7 @@ def forward( blocks = list(zip(self.resnets, self.attentions)) for i, (resnet, attn) in enumerate(blocks): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -1029,7 +1029,7 @@ def forward( hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -1191,7 +1191,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -1364,7 +1364,7 @@ def forward( # Blocks for block in self.transformer_blocks: - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/examples/research_projects/pixart/controlnet_pixart_alpha.py b/examples/research_projects/pixart/controlnet_pixart_alpha.py index b7f5a427e52e..f825719a1364 100644 --- a/examples/research_projects/pixart/controlnet_pixart_alpha.py +++ b/examples/research_projects/pixart/controlnet_pixart_alpha.py @@ -215,7 +215,7 @@ def forward( # 2. 
Blocks for block_index, block in enumerate(self.transformer.transformer_blocks): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: # rc todo: for training and gradient checkpointing print("Gradient checkpointing is not supported for the controlnet transformer model, yet.") exit(1) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index 922fd15c08fb..b62ed67ade29 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -506,7 +506,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: sample = self.temp_conv_in(sample) sample = sample + residual - if self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): @@ -646,7 +646,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - if self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 8575c7658605..d9ee15062daf 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -420,7 +420,7 @@ def forward( for i, resnet in enumerate(self.resnets): conv_cache_key = f"resnet_{i}" - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def create_forward(*inputs): @@ -522,7 +522,7 @@ def forward( for i, resnet in enumerate(self.resnets): conv_cache_key = f"resnet_{i}" - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def create_forward(*inputs): @@ -636,7 +636,7 @@ def forward( for i, resnet in enumerate(self.resnets): conv_cache_key = f"resnet_{i}" - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def create_forward(*inputs): @@ -773,7 +773,7 @@ def forward( hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): @@ -939,7 +939,7 @@ def forward( hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py index 57e8b8f647ba..0eabf3a26d7c 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py @@ -206,7 +206,7 @@ def forward( for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)): conv_cache_key = f"resnet_{i}" - if self.training and self.gradient_checkpointing: + if 
torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def create_forward(*inputs): @@ -311,7 +311,7 @@ def forward( for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)): conv_cache_key = f"resnet_{i}" - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def create_forward(*inputs): @@ -392,7 +392,7 @@ def forward( for i, resnet in enumerate(self.resnets): conv_cache_key = f"resnet_{i}" - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def create_forward(*inputs): @@ -529,7 +529,7 @@ def forward( hidden_states = self.proj_in(hidden_states) hidden_states = hidden_states.permute(0, 4, 1, 2, 3) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def create_forward(*inputs): @@ -646,7 +646,7 @@ def forward( hidden_states = self.conv_in(hidden_states) # 1. Mid - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def create_forward(*inputs): diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py index 55449644ed03..4e3902ae6dbe 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py @@ -95,7 +95,7 @@ def forward( sample = self.conv_in(sample) upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py index bb80ce8605ba..2f3f4f2fc35c 100644 --- a/src/diffusers/models/autoencoders/vae.py +++ b/src/diffusers/models/autoencoders/vae.py @@ -142,7 +142,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: sample = self.conv_in(sample) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): @@ -291,7 +291,7 @@ def forward( sample = self.conv_in(sample) upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): @@ -544,7 +544,7 @@ def forward( sample = self.conv_in(sample) upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): @@ -876,7 +876,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: r"""The forward method of the `EncoderTiny` class.""" - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): @@ -962,7 +962,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Clamp. 
x = torch.tanh(x / 3) * 3 - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index e6a3eceed9b4..76a97847ef9a 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -329,7 +329,7 @@ def forward( block_samples = () for index_block, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -363,7 +363,7 @@ def custom_forward(*inputs): single_block_samples = () for index_block, block in enumerate(self.single_transformer_blocks): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py index 911d65e03d88..209aad93244e 100644 --- a/src/diffusers/models/controlnets/controlnet_sd3.py +++ b/src/diffusers/models/controlnets/controlnet_sd3.py @@ -324,7 +324,7 @@ def forward( block_res_samples = () for block in self.transformer_blocks: - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/src/diffusers/models/controlnets/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py index 06e0eda3c3b0..11ad676ec92b 100644 --- a/src/diffusers/models/controlnets/controlnet_xs.py +++ b/src/diffusers/models/controlnets/controlnet_xs.py @@ -1466,7 +1466,7 @@ def custom_forward(*inputs): h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) # apply base subblock - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} h_base = torch.utils.checkpoint.checkpoint( create_custom_forward(b_res), @@ -1489,7 +1489,7 @@ def custom_forward(*inputs): # apply ctrl subblock if apply_control: - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} h_ctrl = torch.utils.checkpoint.checkpoint( create_custom_forward(c_res), @@ -1898,7 +1898,7 @@ def maybe_apply_freeu_to_subblock(hidden_states, res_h_base): hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base) hidden_states = torch.cat([hidden_states, res_h_base], dim=1) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index ad64df0c0790..b3f29e6b6224 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ 
b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -466,7 +466,7 @@ def forward( # MMDiT blocks. for index_block, block in enumerate(self.joint_transformer_blocks): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -497,7 +497,7 @@ def custom_forward(*inputs): combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 821da6d032d5..01c54ef090bd 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -452,7 +452,7 @@ def forward( # 3. Transformer blocks for i, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py index 9f8957737dbc..f787c5279499 100644 --- a/src/diffusers/models/transformers/dit_transformer_2d.py +++ b/src/diffusers/models/transformers/dit_transformer_2d.py @@ -184,7 +184,7 @@ def forward( # 2. Blocks for block in self.transformer_blocks: - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 71d19216e5ff..7e2b1273687d 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -238,7 +238,7 @@ def forward( for i, (spatial_block, temp_block) in enumerate( zip(self.transformer_blocks, self.temporal_transformer_blocks) ): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = torch.utils.checkpoint.checkpoint( spatial_block, hidden_states, @@ -271,7 +271,7 @@ def forward( if i == 0 and num_frame > 1: hidden_states = hidden_states + self.temp_pos_embed - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = torch.utils.checkpoint.checkpoint( temp_block, hidden_states, diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index 1e5cd5794517..7f145edf16fb 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -386,7 +386,7 @@ def forward( # 2. 
Blocks for block in self.transformer_blocks: - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/src/diffusers/models/transformers/stable_audio_transformer.py b/src/diffusers/models/transformers/stable_audio_transformer.py index e3462b51a412..d687dbabf317 100644 --- a/src/diffusers/models/transformers/stable_audio_transformer.py +++ b/src/diffusers/models/transformers/stable_audio_transformer.py @@ -414,7 +414,7 @@ def forward( attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1) for block in self.transformer_blocks: - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index c7c19e4582c6..e208a1c10ed4 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -415,7 +415,7 @@ def forward( # 2. Blocks for block in self.transformer_blocks: - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index f756399a378a..fe9c7290b063 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -371,7 +371,7 @@ def forward( # 3. Transformer blocks for i, block in enumerate(self.transformer_blocks): # TODO(aryan): Implement gradient checkpointing - if self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py index 962cbbff7c1b..94d852f6df4b 100644 --- a/src/diffusers/models/transformers/transformer_cogview3plus.py +++ b/src/diffusers/models/transformers/transformer_cogview3plus.py @@ -341,7 +341,7 @@ def forward( hidden_states = hidden_states[:, text_seq_length:] for index_block, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index f078cace0f3e..0ad3be866019 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -480,7 +480,7 @@ def forward( image_rotary_emb = self.pos_embed(ids) for index_block, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -525,7 +525,7 @@ def custom_forward(*inputs): hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): - if self.training and self.gradient_checkpointing: + 
if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 7f4ad2b328fa..8ac8b5dababa 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -350,7 +350,7 @@ def forward( ) for i, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index b28350b8ed9c..f39a102c7256 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -317,7 +317,7 @@ def forward( encoder_hidden_states = self.context_embedder(encoder_hidden_states) for index_block, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py index c0c5467050dd..6ca42b9745fd 100644 --- a/src/diffusers/models/transformers/transformer_temporal.py +++ b/src/diffusers/models/transformers/transformer_temporal.py @@ -340,7 +340,7 @@ def forward( # 2. Blocks for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = torch.utils.checkpoint.checkpoint( block, hidden_states, diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index 93a0a82cdcff..b9d186ac1aa6 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -859,7 +859,7 @@ def forward( hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -1257,7 +1257,7 @@ def forward( blocks = list(zip(self.resnets, self.attentions)) for i, (resnet, attn) in enumerate(blocks): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -1371,7 +1371,7 @@ def forward( output_states = () for resnet in self.resnets: - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): @@ -1859,7 +1859,7 @@ def forward( output_states = () for resnet in self.resnets: - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): @@ -2011,7 +2011,7 @@ def forward( mask = attention_mask for resnet, attn in zip(self.resnets, self.attentions): - if self.training and 
self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -2106,7 +2106,7 @@ def forward( output_states = () for resnet in self.resnets: - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): @@ -2215,7 +2215,7 @@ def forward( output_states = () for resnet, attn in zip(self.resnets, self.attentions): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -2520,7 +2520,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -2653,7 +2653,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): @@ -3183,7 +3183,7 @@ def forward( res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): @@ -3341,7 +3341,7 @@ def forward( res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -3444,7 +3444,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) for resnet in self.resnets: - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): @@ -3572,7 +3572,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) for resnet, attn in zip(self.resnets, self.attentions): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index 8b472a89e13d..9c9fd7555899 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -1078,7 +1078,7 @@ def forward( ) for attn, resnet in zip(self.attentions, self.resnets[1:]): - if self.training and self.gradient_checkpointing: # TODO + if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -1168,7 +1168,7 @@ def forward( ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: output_states = () for resnet in self.resnets: - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def 
custom_forward(*inputs): @@ -1281,7 +1281,7 @@ def forward( blocks = list(zip(self.resnets, self.attentions)) for resnet, attn in blocks: - if self.training and self.gradient_checkpointing: # TODO + if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -1383,7 +1383,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): @@ -1493,7 +1493,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: # TODO + if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 6125feba5899..ddc3e41c340d 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -323,7 +323,7 @@ def forward( blocks = zip(self.resnets, self.motion_modules) for resnet, motion_module in blocks: - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): @@ -513,7 +513,7 @@ def forward( blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) for i, (resnet, attn, motion_module) in enumerate(blocks): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -732,7 +732,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -895,7 +895,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): @@ -1079,7 +1079,7 @@ def forward( return_dict=False, )[0] - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py index 7deea9a714d4..238e6b411356 100644 --- a/src/diffusers/models/unets/unet_stable_cascade.py +++ b/src/diffusers/models/unets/unet_stable_cascade.py @@ -455,7 +455,7 @@ def _down_encode(self, x, r_embed, clip): level_outputs = [] block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): @@ -504,7 +504,7 @@ def _up_decode(self, level_outputs, r_embed, clip): x = level_outputs[0] block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() 
and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/diffusers/models/unets/uvit_2d.py b/src/diffusers/models/unets/uvit_2d.py index 8a379bf5f9c3..2f0b3eb19508 100644 --- a/src/diffusers/models/unets/uvit_2d.py +++ b/src/diffusers/models/unets/uvit_2d.py @@ -181,7 +181,7 @@ def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds hidden_states = self.project_to_hidden(hidden_states) for layer in self.transformer_layers: - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def layer_(*args): return checkpoint(layer, *args) diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py index 2af3078f7412..63d3957ae17d 100644 --- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py +++ b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py @@ -1112,7 +1112,7 @@ def forward( ) for i in range(num_layers): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -1290,7 +1290,7 @@ def forward( ) for i in range(len(self.resnets[1:])): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -1464,7 +1464,7 @@ def forward( res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py index 1be4761a9987..0d78b987ce77 100644 --- a/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py +++ b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py @@ -167,7 +167,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if getattr(self.config, "gradient_checkpointing", False) and torch.is_grad_enabled(): if use_cache: logger.warning( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
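Editorial note on the hunks above and below: they all apply the same one-line change, gating gradient checkpointing on `torch.is_grad_enabled()` instead of `self.training`. Activation checkpointing then engages whenever autograd is actually recording (for example when a frozen backbone is kept in `eval()` mode while LoRA parameters are trained) and is skipped under `torch.no_grad()`. A minimal sketch of the new gating, using a hypothetical `ToyBlock` module that is not part of this patch:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    # Hypothetical stand-in for the transformer/UNet blocks touched by this patch.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)
        self.gradient_checkpointing = True

    def forward(self, x):
        # Old gate: `self.training and self.gradient_checkpointing`
        # New gate: checkpoint whenever autograd is recording, regardless of train/eval mode.
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            return checkpoint(self.proj, x, use_reentrant=False)
        return self.proj(x)


block = ToyBlock().eval()  # e.g. a frozen backbone kept in eval() during LoRA training
x = torch.randn(2, 16, requires_grad=True)
block(x).sum().backward()  # checkpointing runs: grads are enabled even though training=False

with torch.no_grad():
    block(torch.randn(2, 16))  # plain inference: checkpointing is skipped
```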
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 3937e87f63c9..107a5a45bfa2 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -1595,7 +1595,7 @@ def forward( output_states = () for resnet in self.resnets: - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): @@ -1732,7 +1732,7 @@ def forward( blocks = list(zip(self.resnets, self.attentions)) for i, (resnet, attn) in enumerate(blocks): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -1874,7 +1874,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): @@ -2033,7 +2033,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -2352,7 +2352,7 @@ def forward( hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/src/diffusers/pipelines/kolors/text_encoder.py b/src/diffusers/pipelines/kolors/text_encoder.py index 6fb6f18a907a..5eb8d4c43d02 100644 --- a/src/diffusers/pipelines/kolors/text_encoder.py +++ b/src/diffusers/pipelines/kolors/text_encoder.py @@ -590,7 +590,7 @@ def forward( if not kv_caches: kv_caches = [None for _ in range(self.num_layers)] presents = () if use_cache else None - if self.gradient_checkpointing and self.training: + if torch.is_grad_enabled() and self.gradient_checkpointing: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
@@ -604,7 +604,7 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) layer = self._get_layer(index) - if self.gradient_checkpointing and self.training: + if torch.is_grad_enabled() and self.gradient_checkpointing: layer_ret = torch.utils.checkpoint.checkpoint( layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache ) diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index f6f3531a8835..cd63637b6c2f 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -675,7 +675,7 @@ def forward( for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index edb0c1ec45de..f90fc82a98ad 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -158,7 +158,7 @@ def forward(self, x, r, c): c_embed = self.cond_mapper(c) r_embed = self.gen_r_embedding(r) - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): From 9cc96a64f11303fc3174929d1cd4ad78609418b1 Mon Sep 17 00:00:00 2001 From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com> Date: Sat, 9 Nov 2024 04:39:24 +0530 Subject: [PATCH 08/11] [FIX] Fix TypeError in DreamBooth SDXL when use_dora is False (#9879) * fix use_dora * fix style and quality * fix use_dora with peft version --------- Co-authored-by: Sayak Paul --- .../dreambooth/train_dreambooth_lora_sdxl.py | 36 +++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 6e621b3caee3..9cd321f6d055 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -67,6 +67,7 @@ convert_state_dict_to_diffusers, convert_state_dict_to_kohya, convert_unet_state_dict_to_peft, + is_peft_version, is_wandb_available, ) from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card @@ -1183,26 +1184,33 @@ def main(args): text_encoder_one.gradient_checkpointing_enable() text_encoder_two.gradient_checkpointing_enable() + def get_lora_config(rank, use_dora, target_modules): + base_config = { + "r": rank, + "lora_alpha": rank, + "init_lora_weights": "gaussian", + "target_modules": target_modules, + } + if use_dora: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." 
+ ) + else: + base_config["use_dora"] = True + + return LoraConfig(**base_config) + # now we will add new LoRA weights to the attention layers - unet_lora_config = LoraConfig( - r=args.rank, - use_dora=args.use_dora, - lora_alpha=args.rank, - init_lora_weights="gaussian", - target_modules=["to_k", "to_q", "to_v", "to_out.0"], - ) + unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + unet_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=unet_target_modules) unet.add_adapter(unet_lora_config) # The text encoder comes from 🤗 transformers, so we cannot directly modify it. # So, instead, we monkey-patch the forward calls of its attention-blocks. if args.train_text_encoder: - text_lora_config = LoraConfig( - r=args.rank, - use_dora=args.use_dora, - lora_alpha=args.rank, - init_lora_weights="gaussian", - target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], - ) + text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] + text_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=text_target_modules) text_encoder_one.add_adapter(text_lora_config) text_encoder_two.add_adapter(text_lora_config) From d720b2132e74a16cd44f98947e667e4a4442adc5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 8 Nov 2024 19:31:43 -0400 Subject: [PATCH 09/11] [Advanced LoRA v1.5] fix: gradient unscaling problem (#7018) fix: gradient unscaling problem Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> --- .../train_dreambooth_lora_sd15_advanced.py | 36 +++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index afe30680567d..5b78501f9b49 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -39,7 +39,7 @@ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder from packaging import version -from peft import LoraConfig +from peft import LoraConfig, set_peft_model_state_dict from peft.utils import get_peft_model_state_dict from PIL import Image from PIL.ImageOps import exif_transpose @@ -59,12 +59,13 @@ ) from diffusers.loaders import StableDiffusionLoraLoaderMixin from diffusers.optimization import get_scheduler -from diffusers.training_utils import compute_snr +from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr from diffusers.utils import ( check_min_version, convert_all_state_dict_to_peft, convert_state_dict_to_diffusers, convert_state_dict_to_kohya, + convert_unet_state_dict_to_peft, is_wandb_available, ) from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card @@ -1319,6 +1320,37 @@ def load_model_hook(models, input_dir): else: raise ValueError(f"unexpected save model: {model.__class__}") + lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir) + + unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) + incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = 
getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + if args.train_text_encoder: + # Do we need to call `scale_lora_layers()` here? + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) + + _set_state_dict_into_text_encoder( + lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_one_ + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [unet_] + if args.train_text_encoder: + models.extend([text_encoder_one_]) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) StableDiffusionLoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) From 8d6dc2be5dfe9f54e455d9ca7a6acbd9181fba7b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 8 Nov 2024 19:35:38 -0400 Subject: [PATCH 10/11] Revert "[Flux] reduce explicit device transfers and typecasting in flux." (#9896) Revert "[Flux] reduce explicit device transfers and typecasting in flux. (#9817)" This reverts commit 5588725e8e7be497839432e5328c596169385f16. --- src/diffusers/pipelines/flux/pipeline_flux.py | 6 +++--- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 4 ++-- .../flux/pipeline_flux_controlnet_image_to_image.py | 6 +++--- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 6 +++--- src/diffusers/pipelines/flux/pipeline_flux_img2img.py | 6 +++--- src/diffusers/pipelines/flux/pipeline_flux_inpaint.py | 6 +++--- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index ab4e0fc4d255..040d935f1b88 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -371,7 +371,7 @@ def encode_prompt( unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device) + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -427,7 +427,7 @@ def check_inputs( @staticmethod def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) + latent_image_ids = torch.zeros(height, width, 3) latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] @@ -437,7 +437,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height * latent_image_id_width, latent_image_id_channels ) - return latent_image_ids + return latent_image_ids.to(device=device, dtype=dtype) @staticmethod def _pack_latents(latents, batch_size, num_channels_latents, height, width): diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 9965ffe42bea..771150b085d5 100644 --- 
a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -452,7 +452,7 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) + latent_image_ids = torch.zeros(height, width, 3) latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] @@ -462,7 +462,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height * latent_image_id_width, latent_image_id_channels ) - return latent_image_ids + return latent_image_ids.to(device=device, dtype=dtype) @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 937422e1b60d..04582b71d780 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -407,7 +407,7 @@ def encode_prompt( unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device) + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -495,7 +495,7 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) + latent_image_ids = torch.zeros(height, width, 3) latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] @@ -505,7 +505,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height * latent_image_id_width, latent_image_id_channels ) - return latent_image_ids + return latent_image_ids.to(device=device, dtype=dtype) @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 83cc59c0b1f7..947e97e272f8 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -417,7 +417,7 @@ def encode_prompt( unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device) + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -522,7 +522,7 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - 
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) + latent_image_ids = torch.zeros(height, width, 3) latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] @@ -532,7 +532,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height * latent_image_id_width, latent_image_id_channels ) - return latent_image_ids + return latent_image_ids.to(device=device, dtype=dtype) @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index aa1a3e7fc3a4..47f9f268ee9d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -391,7 +391,7 @@ def encode_prompt( unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device) + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -479,7 +479,7 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) + latent_image_ids = torch.zeros(height, width, 3) latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] @@ -489,7 +489,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height * latent_image_id_width, latent_image_id_channels ) - return latent_image_ids + return latent_image_ids.to(device=device, dtype=dtype) @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index 97824258b28f..766f9864839e 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -395,7 +395,7 @@ def encode_prompt( unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device) + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -500,7 +500,7 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) + latent_image_ids = torch.zeros(height, width, 3) latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] @@ -510,7 +510,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height * latent_image_id_width, latent_image_id_channels ) - 
return latent_image_ids + return latent_image_ids.to(device=device, dtype=dtype) @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents From dac623b59f52c58383a39207d5147aa34e0047cd Mon Sep 17 00:00:00 2001 From: Eliseu Silva Date: Fri, 8 Nov 2024 22:40:51 -0300 Subject: [PATCH 11/11] Feature IP Adapter Xformers Attention Processor (#9881) * Feature IP Adapter Xformers Attention Processor: this fix error loading incorrect attention processor when setting Xformers attn after load ip adapter scale, issues: #8863 #8872 --- src/diffusers/loaders/ip_adapter.py | 14 +- src/diffusers/loaders/unet.py | 13 +- src/diffusers/models/attention_processor.py | 262 +++++++++++++++++++- 3 files changed, 278 insertions(+), 11 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 1006dab9e4b9..49b46c4fc615 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -33,16 +33,14 @@ if is_transformers_available(): - from transformers import ( - CLIPImageProcessor, - CLIPVisionModelWithProjection, - ) + from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from ..models.attention_processor import ( AttnProcessor, AttnProcessor2_0, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, + IPAdapterXFormersAttnProcessor, ) logger = logging.get_logger(__name__) @@ -284,7 +282,9 @@ def set_ip_adapter_scale(self, scale): scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0) for attn_name, attn_processor in unet.attn_processors.items(): - if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): + if isinstance( + attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor) + ): if len(scale_configs) != len(attn_processor.scale): raise ValueError( f"Cannot assign {len(scale_configs)} scale_configs to " @@ -342,7 +342,9 @@ def unload_ip_adapter(self): ) attn_procs[name] = ( attn_processor_class - if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)) + if isinstance( + value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor) + ) else value.__class__() ) self.unet.set_attn_processor(attn_procs) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 2fa7732a6a3b..b37b681ae8fe 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -765,6 +765,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F from ..models.attention_processor import ( IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, + IPAdapterXFormersAttnProcessor, ) if low_cpu_mem_usage: @@ -804,11 +805,15 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F if cross_attention_dim is None or "motion_modules" in name: attn_processor_class = self.attn_processors[name].__class__ attn_procs[name] = attn_processor_class() - else: - attn_processor_class = ( - IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor - ) + if "XFormers" in str(self.attn_processors[name].__class__): + attn_processor_class = IPAdapterXFormersAttnProcessor + else: + attn_processor_class = ( + IPAdapterAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else IPAdapterAttnProcessor + ) num_image_text_embeds = [] for state_dict in state_dicts: if "proj.weight" in state_dict["image_proj"]: diff --git a/src/diffusers/models/attention_processor.py 
b/src/diffusers/models/attention_processor.py index da01b7a1edcd..772aae7fcd2f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -318,7 +318,10 @@ def set_use_memory_efficient_attention_xformers( XFormersAttnAddedKVProcessor, ), ) - + is_ip_adapter = hasattr(self, "processor") and isinstance( + self.processor, + (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor), + ) if use_memory_efficient_attention_xformers: if is_added_kv_processor and is_custom_diffusion: raise NotImplementedError( @@ -368,6 +371,19 @@ def set_use_memory_efficient_attention_xformers( "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." ) processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) + elif is_ip_adapter: + processor = IPAdapterXFormersAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + num_tokens=self.processor.num_tokens, + scale=self.processor.scale, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_ip"): + processor.to( + device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype + ) else: processor = XFormersAttnProcessor(attention_op=attention_op) else: @@ -386,6 +402,18 @@ def set_use_memory_efficient_attention_xformers( processor.load_state_dict(self.processor.state_dict()) if hasattr(self.processor, "to_k_custom_diffusion"): processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_ip_adapter: + processor = IPAdapterAttnProcessor2_0( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + num_tokens=self.processor.num_tokens, + scale=self.processor.scale, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_ip"): + processor.to( + device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype + ) else: # set attention processor # We use the AttnProcessor2_0 by default when torch 2.x is used which uses @@ -4542,6 +4570,238 @@ def __call__( return hidden_states +class IPAdapterXFormersAttnProcessor(torch.nn.Module): + r""" + Attention processor for IP-Adapter using xFormers. + + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`): + The context length of the image features. + scale (`float` or `List[float]`, defaults to 1.0): + the weight scale of image prompt. + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best + operator. 
+ """ + + def __init__( + self, + hidden_size, + cross_attention_dim=None, + num_tokens=(4,), + scale=1.0, + attention_op: Optional[Callable] = None, + ): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.attention_op = attention_op + + if not isinstance(num_tokens, (tuple, list)): + num_tokens = [num_tokens] + self.num_tokens = num_tokens + + if not isinstance(scale, list): + scale = [scale] * len(num_tokens) + if len(scale) != len(num_tokens): + raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") + self.scale = scale + + self.to_k_ip = nn.ModuleList( + [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))] + ) + self.to_v_ip = nn.ModuleList( + [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))] + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ip_adapter_masks: Optional[torch.FloatTensor] = None, + ): + residual = hidden_states + + # separate ip_hidden_states from encoder_hidden_states + if encoder_hidden_states is not None: + if isinstance(encoder_hidden_states, tuple): + encoder_hidden_states, ip_hidden_states = encoder_hidden_states + else: + deprecation_message = ( + "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release." + " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning." + ) + deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) + end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + [encoder_hidden_states[:, end_pos:, :]], + ) + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. 
+ _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + if ip_hidden_states: + if ip_adapter_masks is not None: + if not isinstance(ip_adapter_masks, List): + # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] + ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) + if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): + raise ValueError( + f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " + f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " + f"({len(ip_hidden_states)})" + ) + else: + for index, (mask, scale, ip_state) in enumerate( + zip(ip_adapter_masks, self.scale, ip_hidden_states) + ): + if mask is None: + continue + if not isinstance(mask, torch.Tensor) or mask.ndim != 4: + raise ValueError( + "Each element of the ip_adapter_masks array should be a tensor with shape " + "[1, num_images_for_ip_adapter, height, width]." 
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) + if mask.shape[1] != ip_state.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of ip images ({ip_state.shape[1]}) at index {index}" + ) + if isinstance(scale, list) and not len(scale) == mask.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of scales ({len(scale)}) at index {index}" + ) + else: + ip_adapter_masks = [None] * len(self.scale) + + # for ip-adapter + for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks + ): + skip = False + if isinstance(scale, list): + if all(s == 0 for s in scale): + skip = True + elif scale == 0: + skip = True + if not skip: + if mask is not None: + mask = mask.to(torch.float16) + if not isinstance(scale, list): + scale = [scale] * mask.shape[1] + + current_num_images = mask.shape[1] + for i in range(current_num_images): + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) + + ip_key = attn.head_to_batch_dim(ip_key).contiguous() + ip_value = attn.head_to_batch_dim(ip_value).contiguous() + + _current_ip_hidden_states = xformers.ops.memory_efficient_attention( + query, ip_key, ip_value, op=self.attention_op + ) + _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype) + _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states) + + mask_downsample = IPAdapterMaskProcessor.downsample( + mask[:, i, :, :], + batch_size, + _current_ip_hidden_states.shape[1], + _current_ip_hidden_states.shape[2], + ) + + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) + hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) + else: + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key).contiguous() + ip_value = attn.head_to_batch_dim(ip_value).contiguous() + + current_ip_hidden_states = xformers.ops.memory_efficient_attention( + query, ip_key, ip_value, op=self.attention_op + ) + current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) + current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) + + hidden_states = hidden_states + scale * current_ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class PAGIdentitySelfAttnProcessor2_0: r""" Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).