diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index acca4b148c58..9647d92754dc 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -228,6 +228,8 @@ title: UNet3DConditionModel - local: api/models/unet-motion title: UNetMotionModel + - local: api/models/uvit2d + title: UViT2DModel - local: api/models/vq title: VQModel - local: api/models/autoencoderkl diff --git a/docs/source/en/api/loaders/single_file.md b/docs/source/en/api/loaders/single_file.md index 52e44606455b..62dbc21067c5 100644 --- a/docs/source/en/api/loaders/single_file.md +++ b/docs/source/en/api/loaders/single_file.md @@ -30,8 +30,8 @@ To learn more about how to load single file weights, see the [Load different Sta ## FromOriginalVAEMixin -[[autodoc]] loaders.single_file.FromOriginalVAEMixin +[[autodoc]] loaders.autoencoder.FromOriginalVAEMixin ## FromOriginalControlnetMixin -[[autodoc]] loaders.single_file.FromOriginalControlnetMixin \ No newline at end of file +[[autodoc]] loaders.controlnet.FromOriginalControlNetMixin \ No newline at end of file diff --git a/docs/source/en/api/models/unet-motion.md b/docs/source/en/api/models/unet-motion.md index cbc8c30ff64f..af967924dfb3 100644 --- a/docs/source/en/api/models/unet-motion.md +++ b/docs/source/en/api/models/unet-motion.md @@ -22,4 +22,4 @@ The abstract from the paper is: [[autodoc]] UNetMotionModel ## UNet3DConditionOutput -[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput +[[autodoc]] models.unets.unet_3d_condition.UNet3DConditionOutput diff --git a/docs/source/en/api/models/unet.md b/docs/source/en/api/models/unet.md index 66508b469a60..7e6324952b28 100644 --- a/docs/source/en/api/models/unet.md +++ b/docs/source/en/api/models/unet.md @@ -22,4 +22,4 @@ The abstract from the paper is: [[autodoc]] UNet1DModel ## UNet1DOutput -[[autodoc]] models.unet_1d.UNet1DOutput +[[autodoc]] models.unets.unet_1d.UNet1DOutput diff --git a/docs/source/en/api/models/unet2d-cond.md b/docs/source/en/api/models/unet2d-cond.md index ea385ff92426..ec9dbae8f25e 100644 --- a/docs/source/en/api/models/unet2d-cond.md +++ b/docs/source/en/api/models/unet2d-cond.md @@ -22,10 +22,10 @@ The abstract from the paper is: [[autodoc]] UNet2DConditionModel ## UNet2DConditionOutput -[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput +[[autodoc]] models.unets.unet_2d_condition.UNet2DConditionOutput ## FlaxUNet2DConditionModel -[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionModel +[[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionModel ## FlaxUNet2DConditionOutput -[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput +[[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput diff --git a/docs/source/en/api/models/unet2d.md b/docs/source/en/api/models/unet2d.md index 7669d4a5d75a..d317d14ce744 100644 --- a/docs/source/en/api/models/unet2d.md +++ b/docs/source/en/api/models/unet2d.md @@ -22,4 +22,4 @@ The abstract from the paper is: [[autodoc]] UNet2DModel ## UNet2DOutput -[[autodoc]] models.unet_2d.UNet2DOutput +[[autodoc]] models.unets.unet_2d.UNet2DOutput diff --git a/docs/source/en/api/models/unet3d-cond.md b/docs/source/en/api/models/unet3d-cond.md index 4eea0a6d1cd2..1dc01234dabe 100644 --- a/docs/source/en/api/models/unet3d-cond.md +++ b/docs/source/en/api/models/unet3d-cond.md @@ -22,4 +22,4 @@ The abstract from the paper is: [[autodoc]] UNet3DConditionModel ## UNet3DConditionOutput -[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput +[[autodoc]] models.unets.unet_3d_condition.UNet3DConditionOutput diff --git a/docs/source/en/api/models/uvit2d.md b/docs/source/en/api/models/uvit2d.md new file mode 100644 index 000000000000..abea0fdc38c3 --- /dev/null +++ b/docs/source/en/api/models/uvit2d.md @@ -0,0 +1,39 @@ + + +# UVit2DModel + +The [U-ViT](https://hf.co/papers/2301.11093) model is a vision transformer (ViT) based UNet. This model incorporates elements from ViT (considers all inputs such as time, conditions and noisy image patches as tokens) and a UNet (long skip connections between the shallow and deep layers). The skip connection is important for predicting pixel-level features. An additional 3x3 convolutional block is applied prior to the final output to improve image quality. + +The abstract from the paper is: + +*Currently, applying diffusion models in pixel space of high resolution images is difficult. Instead, existing approaches focus on diffusion in lower dimensional spaces (latent diffusion), or have multiple super-resolution levels of generation referred to as cascades. The downside is that these approaches add additional complexity to the diffusion framework. This paper aims to improve denoising diffusion for high resolution images while keeping the model as simple as possible. The paper is centered around the research question: How can one train a standard denoising diffusion models on high resolution images, and still obtain performance comparable to these alternate approaches? The four main findings are: 1) the noise schedule should be adjusted for high resolution images, 2) It is sufficient to scale only a particular part of the architecture, 3) dropout should be added at specific locations in the architecture, and 4) downsampling is an effective strategy to avoid high resolution feature maps. Combining these simple yet effective techniques, we achieve state-of-the-art on image generation among diffusion models without sampling modifiers on ImageNet.* + +## UVit2DModel + +[[autodoc]] UVit2DModel + +## UVit2DConvEmbed + +[[autodoc]] models.unets.uvit_2d.UVit2DConvEmbed + +## UVitBlock + +[[autodoc]] models.unets.uvit_2d.UVitBlock + +## ConvNextBlock + +[[autodoc]] models.unets.uvit_2d.ConvNextBlock + +## ConvMlmLayer + +[[autodoc]] models.unets.uvit_2d.ConvMlmLayer diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md index 4e1670df7717..e00446857fc5 100644 --- a/docs/source/en/api/pipelines/animatediff.md +++ b/docs/source/en/api/pipelines/animatediff.md @@ -25,6 +25,7 @@ The abstract of the paper is the following: | Pipeline | Tasks | Demo |---|---|:---:| | [AnimateDiffPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff.py) | *Text-to-Video Generation with AnimateDiff* | +| [AnimateDiffVideoToVideoPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py) | *Video-to-Video Generation with AnimateDiff* | ## Available checkpoints @@ -32,6 +33,8 @@ Motion Adapter checkpoints can be found under [guoyww](https://huggingface.co/gu ## Usage example +### AnimateDiffPipeline + AnimateDiff works with a MotionAdapter checkpoint and a Stable Diffusion model checkpoint. The MotionAdapter is a collection of Motion Modules that are responsible for adding coherent motion across image frames. These modules are applied after the Resnet and Attention blocks in Stable Diffusion UNet. The following example demonstrates how to use a *MotionAdapter* checkpoint with Diffusers for inference based on StableDiffusion-1.4/1.5. @@ -98,6 +101,114 @@ AnimateDiff tends to work better with finetuned Stable Diffusion models. If you +### AnimateDiffVideoToVideoPipeline + +AnimateDiff can also be used to generate visually similar videos or enable style/character/background or other edits starting from an initial video, allowing you to seamlessly explore creative possibilities. + +```python +import imageio +import requests +import torch +from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter +from diffusers.utils import export_to_gif +from io import BytesIO +from PIL import Image + +# Load the motion adapter +adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16) +# load SD 1.5 based finetuned model +model_id = "SG161222/Realistic_Vision_V5.1_noVAE" +pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16) +scheduler = DDIMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + clip_sample=False, + timestep_spacing="linspace", + beta_schedule="linear", + steps_offset=1, +) +pipe.scheduler = scheduler + +# enable memory savings +pipe.enable_vae_slicing() +pipe.enable_model_cpu_offload() + +# helper function to load videos +def load_video(file_path: str): + images = [] + + if file_path.startswith(('http://', 'https://')): + # If the file_path is a URL + response = requests.get(file_path) + response.raise_for_status() + content = BytesIO(response.content) + vid = imageio.get_reader(content) + else: + # Assuming it's a local file path + vid = imageio.get_reader(file_path) + + for frame in vid: + pil_image = Image.fromarray(frame) + images.append(pil_image) + + return images + +video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif") + +output = pipe( + video = video, + prompt="panda playing a guitar, on a boat, in the ocean, high quality", + negative_prompt="bad quality, worse quality", + guidance_scale=7.5, + num_inference_steps=25, + strength=0.5, + generator=torch.Generator("cpu").manual_seed(42), +) +frames = output.frames[0] +export_to_gif(frames, "animation.gif") +``` + +Here are some sample outputs: + + + + + + + + + + + + + + +
Source VideoOutput Video
+ raccoon playing a guitar +
+ racoon playing a guitar +
+ panda playing a guitar +
+ panda playing a guitar +
+ closeup of margot robbie, fireworks in the background, high quality +
+ closeup of margot robbie, fireworks in the background, high quality +
+ closeup of tony stark, robert downey jr, fireworks +
+ closeup of tony stark, robert downey jr, fireworks +
+ ## Using Motion LoRAs Motion LoRAs are a collection of LoRAs that work with the `guoyww/animatediff-motion-adapter-v1-5-2` checkpoint. These LoRAs are responsible for adding specific types of motion to the animations. diff --git a/docs/source/ja/_toctree.yml b/docs/source/ja/_toctree.yml index 7609e42b150c..000809baf65e 100644 --- a/docs/source/ja/_toctree.yml +++ b/docs/source/ja/_toctree.yml @@ -11,4 +11,6 @@ - sections: - local: tutorials/tutorial_overview title: 概要 + - local: tutorials/autopipeline + title: AutoPipeline title: チュートリアル \ No newline at end of file diff --git a/docs/source/ja/tutorials/autopipeline.md b/docs/source/ja/tutorials/autopipeline.md new file mode 100644 index 000000000000..d8d861cbdd96 --- /dev/null +++ b/docs/source/ja/tutorials/autopipeline.md @@ -0,0 +1,168 @@ + + +# AutoPipeline + +Diffusersは様々なタスクをこなすことができ、テキストから画像、画像から画像、画像の修復など、複数のタスクに対して同じように事前学習された重みを再利用することができます。しかし、ライブラリや拡散モデルに慣れていない場合、どのタスクにどのパイプラインを使えばいいのかがわかりにくいかもしれません。例えば、 [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) チェックポイントをテキストから画像に変換するために使用している場合、それぞれ[`StableDiffusionImg2ImgPipeline`]クラスと[`StableDiffusionInpaintPipeline`]クラスでチェックポイントをロードすることで、画像から画像や画像の修復にも使えることを知らない可能性もあります。 + +`AutoPipeline` クラスは、🤗 Diffusers の様々なパイプラインをよりシンプルするために設計されています。この汎用的でタスク重視のパイプラインによってタスクそのものに集中することができます。`AutoPipeline` は、使用するべき正しいパイプラインクラスを自動的に検出するため、特定のパイプラインクラス名を知らなくても、タスクのチェックポイントを簡単にロードできます。 + + + +どのタスクがサポートされているかは、[AutoPipeline](../api/pipelines/auto_pipeline) のリファレンスをご覧ください。現在、text-to-image、image-to-image、inpaintingをサポートしています。 + + + +このチュートリアルでは、`AutoPipeline` を使用して、事前に学習された重みが与えられたときに、特定のタスクを読み込むためのパイプラインクラスを自動的に推測する方法を示します。 + +## タスクに合わせてAutoPipeline を選択する +まずはチェックポイントを選ぶことから始めましょう。例えば、 [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) チェックポイントでテキストから画像への変換したいなら、[`AutoPipelineForText2Image`]を使います: + +```py +from diffusers import AutoPipelineForText2Image +import torch + +pipeline = AutoPipelineForText2Image.from_pretrained( + "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True +).to("cuda") +prompt = "peasant and dragon combat, wood cutting style, viking era, bevel with rune" + +image = pipeline(prompt, num_inference_steps=25).images[0] +image +``` + +
+ generated image of peasant fighting dragon in wood cutting style +
+ +[`AutoPipelineForText2Image`] を具体的に見ていきましょう: + +1. [`model_index.json`](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json) ファイルから `"stable-diffusion"` クラスを自動的に検出します。 +2. `"stable-diffusion"` のクラス名に基づいて、テキストから画像へ変換する [`StableDiffusionPipeline`] を読み込みます。 + +同様に、画像から画像へ変換する場合、[`AutoPipelineForImage2Image`] は `model_index.json` ファイルから `"stable-diffusion"` チェックポイントを検出し、対応する [`StableDiffusionImg2ImgPipeline`] を読み込みます。また、入力画像にノイズの量やバリエーションの追加を決めるための強さなど、パイプラインクラスに固有の追加引数を渡すこともできます: + +```py +from diffusers import AutoPipelineForImage2Image +import torch +import requests +from PIL import Image +from io import BytesIO + +pipeline = AutoPipelineForImage2Image.from_pretrained( + "runwayml/stable-diffusion-v1-5", + torch_dtype=torch.float16, + use_safetensors=True, +).to("cuda") +prompt = "a portrait of a dog wearing a pearl earring" + +url = "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0f/1665_Girl_with_a_Pearl_Earring.jpg/800px-1665_Girl_with_a_Pearl_Earring.jpg" + +response = requests.get(url) +image = Image.open(BytesIO(response.content)).convert("RGB") +image.thumbnail((768, 768)) + +image = pipeline(prompt, image, num_inference_steps=200, strength=0.75, guidance_scale=10.5).images[0] +image +``` + +
+ generated image of a vermeer portrait of a dog wearing a pearl earring +
+ +また、画像の修復を行いたい場合は、 [`AutoPipelineForInpainting`] が、同様にベースとなる[`StableDiffusionInpaintPipeline`]クラスを読み込みます: + +```py +from diffusers import AutoPipelineForInpainting +from diffusers.utils import load_image +import torch + +pipeline = AutoPipelineForInpainting.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True +).to("cuda") + +img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" +mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + +init_image = load_image(img_url).convert("RGB") +mask_image = load_image(mask_url).convert("RGB") + +prompt = "A majestic tiger sitting on a bench" +image = pipeline(prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80).images[0] +image +``` + +
+ generated image of a tiger sitting on a bench +
+ +サポートされていないチェックポイントを読み込もうとすると、エラーになります: + +```py +from diffusers import AutoPipelineForImage2Image +import torch + +pipeline = AutoPipelineForImage2Image.from_pretrained( + "openai/shap-e-img2img", torch_dtype=torch.float16, use_safetensors=True +) +"ValueError: AutoPipeline can't find a pipeline linked to ShapEImg2ImgPipeline for None" +``` + +## 複数のパイプラインを使用する + +いくつかのワークフローや多くのパイプラインを読み込む場合、不要なメモリを使ってしまう再読み込みをするよりも、チェックポイントから同じコンポーネントを再利用する方がメモリ効率が良いです。たとえば、テキストから画像への変換にチェックポイントを使い、画像から画像への変換にまたチェックポイントを使いたい場合、[from_pipe()](https://huggingface.co/docs/diffusers/v0.25.1/en/api/pipelines/auto_pipeline#diffusers.AutoPipelineForImage2Image.from_pipe) メソッドを使用します。このメソッドは、以前読み込まれたパイプラインのコンポーネントを使うことで追加のメモリを消費することなく、新しいパイプラインを作成します。 + +[from_pipe()](https://huggingface.co/docs/diffusers/v0.25.1/en/api/pipelines/auto_pipeline#diffusers.AutoPipelineForImage2Image.from_pipe) メソッドは、元のパイプラインクラスを検出し、実行したいタスクに対応する新しいパイプラインクラスにマッピングします。例えば、テキストから画像への`"stable-diffusion"` クラスのパイプラインを読み込む場合: + +```py +from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image +import torch + +pipeline_text2img = AutoPipelineForText2Image.from_pretrained( + "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True +) +print(type(pipeline_text2img)) +"" +``` + +そして、[from_pipe()] (https://huggingface.co/docs/diffusers/v0.25.1/en/api/pipelines/auto_pipeline#diffusers.AutoPipelineForImage2Image.from_pipe)は、もとの`"stable-diffusion"` パイプラインのクラスである [`StableDiffusionImg2ImgPipeline`] にマップします: + +```py +pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img) +print(type(pipeline_img2img)) +"" +``` +元のパイプラインにオプションとして引数(セーフティチェッカーの無効化など)を渡した場合、この引数も新しいパイプラインに渡されます: + +```py +from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image +import torch + +pipeline_text2img = AutoPipelineForText2Image.from_pretrained( + "runwayml/stable-diffusion-v1-5", + torch_dtype=torch.float16, + use_safetensors=True, + requires_safety_checker=False, +).to("cuda") + +pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img) +print(pipeline_img2img.config.requires_safety_checker) +"False" +``` + +新しいパイプラインの動作を変更したい場合は、元のパイプラインの引数や設定を上書きすることができます。例えば、セーフティチェッカーをオンに戻し、`strength` 引数を追加します: + +```py +pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img, requires_safety_checker=True, strength=0.3) +print(pipeline_img2img.config.requires_safety_checker) +"True" +``` diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py new file mode 100644 index 000000000000..385144b133a6 --- /dev/null +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -0,0 +1,1956 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import gc +import hashlib +import itertools +import logging +import math +import os +import re +import shutil +import warnings +from pathlib import Path +from typing import List, Optional + +import numpy as np +import torch +import torch.nn.functional as F + +# imports of the TokenEmbeddingsHandler class +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from packaging import version +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from safetensors.torch import load_file, save_file +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DPMSolverMultistepScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.loaders import LoraLoaderMixin +from diffusers.optimization import get_scheduler +from diffusers.training_utils import compute_snr +from diffusers.utils import ( + check_min_version, + convert_all_state_dict_to_peft, + convert_state_dict_to_diffusers, + convert_state_dict_to_kohya, + is_wandb_available, +) +from diffusers.utils.import_utils import is_xformers_available + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.25.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model=str, + train_text_encoder=False, + train_text_encoder_ti=False, + token_abstraction_dict=None, + instance_prompt=str, + validation_prompt=str, + repo_folder=None, + vae_path=None, +): + img_str = "widget:\n" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f""" + - text: '{validation_prompt if validation_prompt else ' ' }' + output: + url: + "image_{i}.png" + """ + if not images: + img_str += f""" + - text: '{instance_prompt}' + """ + embeddings_filename = f"{repo_folder}_emb" + instance_prompt_webui = re.sub(r"", "", re.sub(r"", embeddings_filename, instance_prompt, count=1)) + ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"", instance_prompt)) + if instance_prompt_webui != embeddings_filename: + instance_prompt_sentence = f"For example, `{instance_prompt_webui}`" + else: + instance_prompt_sentence = "" + trigger_str = f"You should use {instance_prompt} to trigger the image generation." + diffusers_imports_pivotal = "" + diffusers_example_pivotal = "" + webui_example_pivotal = "" + if train_text_encoder_ti: + trigger_str = ( + "To trigger image generation of trained concept(or concepts) replace each concept identifier " + "in you prompt with the new inserted tokens:\n" + ) + diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download +from safetensors.torch import load_file + """ + diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model") +state_dict = load_file(embedding_path) +pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) +pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2) + """ + webui_example_pivotal = f"""- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**. + - Place it on it on your `embeddings` folder + - Use it by adding `{embeddings_filename}` to your prompt. {instance_prompt_sentence} + (you need both the LoRA and the embeddings as they were trained together for this LoRA) + """ + if token_abstraction_dict: + for key, value in token_abstraction_dict.items(): + tokens = "".join(value) + trigger_str += f""" +to trigger concept `{key}` → use `{tokens}` in your prompt \n +""" + + yaml = f"""--- +tags: +- stable-diffusion +- stable-diffusion-diffusers +- text-to-image +- diffusers +- lora +- template:sd-lora +{img_str} +base_model: {base_model} +instance_prompt: {instance_prompt} +license: openrail++ +--- +""" + + model_card = f""" +# SD1.5 LoRA DreamBooth - {repo_id} + + + +## Model description + +### These are {repo_id} LoRA adaption weights for {base_model}. + +## Download model + +### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke + +- **LoRA**: download **[`{repo_folder}.safetensors` here 💾](/{repo_id}/blob/main/{repo_folder}.safetensors)**. + - Place it on your `models/Lora` folder. + - On AUTOMATIC1111, load the LoRA by adding `` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/). +{webui_example_pivotal} + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import AutoPipelineForText2Image +import torch +{diffusers_imports_pivotal} +pipeline = AutoPipelineForText2Image.from_pretrained('runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16).to('cuda') +pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') +{diffusers_example_pivotal} +image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## Trigger words + +{trigger_str} + +## Details +All [Files & versions](/{repo_id}/tree/main). + +The weights were trained using [🧨 diffusers Advanced Dreambooth Training Script](https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py). + +LoRA for the text encoder was enabled. {train_text_encoder}. + +Pivotal tuning was enabled: {train_text_encoder_ti}. + +Special VAE used for training: {vae_path}. + +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand.To load the custom captions, the training set directory needs to follow the structure of a " + "datasets ImageFolder, containing both the images and the corresponding caption for each image. see: " + "https://huggingface.co/docs/datasets/image_dataset for more information" + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset. In some cases, a dataset may have more than one configuration (for example " + "if it contains different subsets of data within, and you only wish to load a specific subset - in that case specify the desired configuration using --dataset_config_name. Leave as " + "None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help="A path to local folder containing the training data of instance images. Specify this arg instead of " + "--dataset_name if you wish to train using a local folder without custom captions. If you wish to train with custom captions please specify " + "--dataset_name instead.", + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--token_abstraction", + type=str, + default="TOK", + help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, " + "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma seperated string - e.g. " + "'TOK,TOK2,TOK3' etc.", + ) + + parser.add_argument( + "--num_new_tokens_per_abstraction", + type=int, + default=2, + help="number of new tokens inserted to the tokenizers per token_abstraction identifier when " + "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new " + "tokens - ", + ) + + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="lora-dreambooth-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=1024, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + + parser.add_argument( + "--train_text_encoder_ti", + action="store_true", + help=("Whether to use textual inversion"), + ) + + parser.add_argument( + "--train_text_encoder_ti_frac", + type=float, + default=0.5, + help=("The percentage of epochs to perform textual inversion"), + ) + + parser.add_argument( + "--train_text_encoder_frac", + type=float, + default=1.0, + help=("The percentage of epochs to perform text encoder tuning"), + ) + + parser.add_argument( + "--optimizer", + type=str, + default="adamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + + if args.train_text_encoder and args.train_text_encoder_ti: + raise ValueError( + "Specify only one of `--train_text_encoder` or `--train_text_encoder_ti. " + "For full LoRA text encoder training check --train_text_encoder, for textual " + "inversion training check `--train_text_encoder_ti`" + ) + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +# Taken from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py +class TokenEmbeddingsHandler: + def __init__(self, text_encoders, tokenizers): + self.text_encoders = text_encoders + self.tokenizers = tokenizers + + self.train_ids: Optional[torch.Tensor] = None + self.inserting_toks: Optional[List[str]] = None + self.embeddings_settings = {} + + def initialize_new_tokens(self, inserting_toks: List[str]): + idx = 0 + for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): + assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings." + assert all( + isinstance(tok, str) for tok in inserting_toks + ), "All elements in inserting_toks should be strings." + + self.inserting_toks = inserting_toks + special_tokens_dict = {"additional_special_tokens": self.inserting_toks} + tokenizer.add_special_tokens(special_tokens_dict) + text_encoder.resize_token_embeddings(len(tokenizer)) + + self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks) + + # random initialization of new tokens + std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std() + + print(f"{idx} text encodedr's std_token_embedding: {std_token_embedding}") + + text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = ( + torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size) + .to(device=self.device) + .to(dtype=self.dtype) + * std_token_embedding + ) + self.embeddings_settings[ + f"original_embeddings_{idx}" + ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone() + self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding + + inu = torch.ones((len(tokenizer),), dtype=torch.bool) + inu[self.train_ids] = False + + self.embeddings_settings[f"index_no_updates_{idx}"] = inu + + print(self.embeddings_settings[f"index_no_updates_{idx}"].shape) + + idx += 1 + + def save_embeddings(self, file_path: str): + assert self.train_ids is not None, "Initialize new tokens before saving embeddings." + tensors = {} + # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 + idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"} + for idx, text_encoder in enumerate(self.text_encoders): + assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len( + self.tokenizers[0] + ), "Tokenizers should be the same." + new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] + + # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for + # text_encoder 1) to keep compatible with the ecosystem. + # Note: When loading with diffusers, any name can work - simply specify in inference + tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings + # tensors[f"text_encoders_{idx}"] = new_token_embeddings + + save_file(tensors, file_path) + + @property + def dtype(self): + return self.text_encoders[0].dtype + + @property + def device(self): + return self.text_encoders[0].device + + @torch.no_grad() + def retract_embeddings(self): + for idx, text_encoder in enumerate(self.text_encoders): + index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"] + text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = ( + self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates] + .to(device=text_encoder.device) + .to(dtype=text_encoder.dtype) + ) + + # for the parts that were updated, we need to normalize them + # to have the same std as before + std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"] + + index_updates = ~index_no_updates + new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] + off_ratio = std_token_embedding / new_embeddings.std() + + new_embeddings = new_embeddings * (off_ratio**0.1) + text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + class_prompt, + dataset_name, + dataset_config_name, + cache_dir, + image_column, + caption_column, + train_text_encoder_ti, + class_data_root=None, + class_num=None, + token_abstraction_dict=None, # token mapping for textual inversion + size=1024, + repeats=1, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + self.token_abstraction_dict = token_abstraction_dict + self.train_text_encoder_ti = train_text_encoder_ti + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + dataset_name, + dataset_config_name, + cache_dir=cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = self.instance_images[index % self.num_instance_images] + instance_image = exif_transpose(instance_image) + + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + if self.train_text_encoder_ti: + # replace instances of --token_abstraction in caption with the new tokens: "" etc. + for token_abs, token_replacement in self.token_abstraction_dict.items(): + caption = caption.replace(token_abs, "".join(token_replacement)) + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # costum prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt + + return example + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def tokenize_prompt(tokenizer, prompt, add_special_tokens=False): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + add_special_tokens=add_special_tokens, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + return text_input_ids + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): + for i, text_encoder in enumerate(text_encoders): + if tokenizers is not None: + tokenizer = tokenizers[i] + text_input_ids = tokenize_prompt(tokenizer, prompt) + else: + assert text_input_ids_list is not None + text_input_ids = text_input_ids_list[i] + + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + return prompt_embeds[0] + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + revision=args.revision, + variant=args.variant, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + model_id = args.hub_model_id or Path(args.output_dir).name + repo_id = None + if args.push_to_hub: + repo_id = create_repo(repo_id=model_id, exist_ok=True, token=args.hub_token).repo_id + + # Load the tokenizers + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + variant=args.variant, + use_fast=False, + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + ) + vae_scaling_factor = vae.config.scaling_factor + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + + if args.train_text_encoder_ti: + # we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK, + # TOK2" -> ["TOK", "TOK2"] etc. + token_abstraction_list = "".join(args.token_abstraction.split()).split(",") + logger.info(f"list of token identifiers: {token_abstraction_list}") + + token_abstraction_dict = {} + token_idx = 0 + for i, token in enumerate(token_abstraction_list): + token_abstraction_dict[token] = [ + f"" for j in range(args.num_new_tokens_per_abstraction) + ] + token_idx += args.num_new_tokens_per_abstraction - 1 + + # replace instances of --token_abstraction in --instance_prompt with the new tokens: "" etc. + for token_abs, token_replacement in token_abstraction_dict.items(): + args.instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement)) + if args.with_prior_preservation: + args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement)) + + # initialize the new tokens for textual inversion + embedding_handler = TokenEmbeddingsHandler([text_encoder_one], [tokenizer_one]) + inserting_toks = [] + for new_tok in token_abstraction_dict.values(): + inserting_toks.extend(new_tok) + embedding_handler.initialize_new_tokens(inserting_toks=inserting_toks) + + # We only train the additional adapter LoRA layers + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + unet.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + + # The VAE is always in float32 to avoid NaN losses. + vae.to(accelerator.device, dtype=torch.float32) + + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, " + "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder_one.gradient_checkpointing_enable() + + # now we will add new LoRA weights to the attention layers + unet_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + unet.add_adapter(unet_lora_config) + + # The text encoder comes from 🤗 transformers, so we cannot directly modify it. + # So, instead, we monkey-patch the forward calls of its attention-blocks. + if args.train_text_encoder: + text_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + ) + text_encoder_one.add_adapter(text_lora_config) + + # if we use textual inversion, we freeze all parameters except for the token embeddings + # in text encoder + elif args.train_text_encoder_ti: + text_lora_parameters_one = [] + for name, param in text_encoder_one.named_parameters(): + if "token_embedding" in name: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + param = param.to(dtype=torch.float32) + param.requires_grad = True + text_lora_parameters_one.append(param) + else: + param.requires_grad = False + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [unet] + if args.train_text_encoder: + models.extend([text_encoder_one]) + for model in models: + for param in model.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + # there are only two options here. Either are just the unet attn processor layers + # or there are the unet and text encoder atten layers + unet_lora_layers_to_save = None + text_encoder_one_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(accelerator.unwrap_model(unet))): + unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + if args.train_text_encoder: + text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + StableDiffusionPipeline.save_lora_weights( + output_dir, + unet_lora_layers=unet_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, + ) + if args.train_text_encoder_ti: + embedding_handler.save_embeddings(f"{output_dir}/{args.output_dir}_emb.safetensors") + + def load_model_hook(models, input_dir): + unet_ = None + text_encoder_one_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(accelerator.unwrap_model(unet))): + unet_ = model + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + text_encoder_one_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) + LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) + + text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k} + LoraLoaderMixin.load_lora_into_text_encoder( + text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_ + ) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters())) + + if args.train_text_encoder: + text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + + # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training + freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti) + + # Optimization parameters + unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate} + if not freeze_text_encoder: + # different learning rate for text encoder and unet + text_lora_parameters_one_with_lr = { + "params": text_lora_parameters_one, + "weight_decay": args.adam_weight_decay_text_encoder + if args.adam_weight_decay_text_encoder + else args.adam_weight_decay, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + params_to_optimize = [ + unet_lora_parameters_with_lr, + text_lora_parameters_one_with_lr, + ] + else: + params_to_optimize = [unet_lora_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warn( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warn( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warn( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + if args.train_text_encoder and args.text_encoder_lr: + logger.warn( + f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:" + f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " + f"When using prodigy only learning_rate is used as the initial learning rate." + ) + # changes the learning rate of text_encoder_parameters_one to be + # --learning_rate + params_to_optimize[1]["lr"] = args.learning_rate + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_prompt=args.class_prompt, + dataset_name=args.dataset_name, + dataset_config_name=args.dataset_config_name, + cache_dir=args.cache_dir, + image_column=args.image_column, + train_text_encoder_ti=args.train_text_encoder_ti, + caption_column=args.caption_column, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None, + class_num=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + if not args.train_text_encoder: + tokenizers = [tokenizer_one] + text_encoders = [text_encoder_one] + + def compute_text_embeddings(prompt, text_encoders, tokenizers): + with torch.no_grad(): + prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt) + prompt_embeds = prompt_embeds.to(accelerator.device) + return prompt_embeds + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if freeze_text_encoder and not train_dataset.custom_instance_prompts: + instance_prompt_hidden_states = compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers) + + # Handle class prompt for prior-preservation. + if args.with_prior_preservation: + if freeze_text_encoder: + class_prompt_hidden_states = compute_text_embeddings(args.class_prompt, text_encoders, tokenizers) + + # Clear the memory here + if freeze_text_encoder and not train_dataset.custom_instance_prompts: + del tokenizers, text_encoders + gc.collect() + torch.cuda.empty_cache() + + # if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion + add_special_tokens = True if args.train_text_encoder_ti else False + + if not train_dataset.custom_instance_prompts: + if freeze_text_encoder: + prompt_embeds = instance_prompt_hidden_states + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + + # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the + # batch prompts on all training steps + else: + tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, add_special_tokens) + if args.with_prior_preservation: + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, add_special_tokens) + tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) + + if args.train_text_encoder_ti and args.validation_prompt: + # replace instances of --token_abstraction in validation prompt with the new tokens: "" etc. + for token_abs, token_replacement in train_dataset.token_abstraction_dict.items(): + args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement)) + print("validation prompt:", args.validation_prompt) + + if args.cache_latents: + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=torch.float32 + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + + if args.validation_prompt is None: + del vae + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if not freeze_text_encoder: + unet, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder_one, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth-lora-sd-15", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + if args.train_text_encoder: + num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs) + elif args.train_text_encoder_ti: # args.train_text_encoder_ti + num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs) + + for epoch in range(first_epoch, args.num_train_epochs): + # if performing any kind of optimization of text_encoder params + if args.train_text_encoder or args.train_text_encoder_ti: + if epoch == num_train_epochs_text_encoder: + print("PIVOT HALFWAY", epoch) + # stopping optimization of text_encoder params + # re setting the optimizer to optimize only on unet params + optimizer.param_groups[1]["lr"] = 0.0 + + else: + # still optimizng the text encoder + text_encoder_one.train() + # set top parameter requires_grad = True for gradient checkpointing works + if args.train_text_encoder: + text_encoder_one.text_model.embeddings.requires_grad_(True) + + unet.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + prompts = batch["prompts"] + # encode batch prompts when custom prompts are provided for each image - + if train_dataset.custom_instance_prompts: + if freeze_text_encoder: + prompt_embeds = compute_text_embeddings(prompts, text_encoders, tokenizers) + + else: + tokens_one = tokenize_prompt(tokenizer_one, prompts, add_special_tokens) + + if args.cache_latents: + model_input = latents_cache[step].sample() + else: + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.sample() + + model_input = model_input * vae_scaling_factor + if args.pretrained_vae_model_name_or_path is None: + model_input = model_input.to(weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + + # Calculate the elements to repeat depending on the use of prior-preservation and custom captions. + if not train_dataset.custom_instance_prompts: + elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz + + else: + elems_to_repeat_text_embeds = 1 + + # Predict the noise residual + if freeze_text_encoder: + prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) + model_pred = unet(noisy_model_input, timesteps, prompt_embeds_input).sample + else: + prompt_embeds = encode_prompt( + text_encoders=[text_encoder_one], + tokenizers=None, + prompt=None, + text_input_ids_list=[tokens_one], + ) + prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) + model_pred = unet(noisy_model_input, timesteps, prompt_embeds_input).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(model_input, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + + if args.with_prior_preservation: + # if we're using prior preservation, we calc snr for instance loss only - + # and hence only need timesteps corresponding to instance images + snr_timesteps, _ = torch.chunk(timesteps, 2, dim=0) + else: + snr_timesteps = timesteps + + snr = compute_snr(noise_scheduler, snr_timesteps) + base_weight = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(snr_timesteps)], dim=1).min(dim=1)[0] / snr + ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet_lora_parameters, text_lora_parameters_one) + if (args.train_text_encoder or args.train_text_encoder_ti) + else unet_lora_parameters + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # every step, we reset the embeddings to the original embeddings. + if args.train_text_encoder_ti: + for idx, text_encoder in enumerate(text_encoders): + embedding_handler.retract_embeddings() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + if freeze_text_encoder: + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + variant=args.variant, + ) + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + text_encoder=accelerator.unwrap_model(text_encoder_one), + unet=accelerator.unwrap_model(unet), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config( + pipeline.scheduler.config, **scheduler_args + ) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + pipeline_args = {"prompt": args.validation_prompt} + + with torch.cuda.amp.autocast(): + images = [ + pipeline(**pipeline_args, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + unet = unet.to(torch.float32) + unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) + + if args.train_text_encoder: + text_encoder_one = accelerator.unwrap_model(text_encoder_one) + text_encoder_lora_layers = convert_state_dict_to_diffusers( + get_peft_model_state_dict(text_encoder_one.to(torch.float32)) + ) + else: + text_encoder_lora_layers = None + + StableDiffusionPipeline.save_lora_weights( + save_directory=args.output_dir, + unet_lora_layers=unet_lora_layers, + text_encoder_lora_layers=text_encoder_lora_layers, + ) + images = [] + if args.validation_prompt and args.num_validation_images > 0: + # Final inference + # Load previous pipeline + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + pipeline = pipeline.to(accelerator.device) + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + images = [ + pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + if args.train_text_encoder_ti: + embedding_handler.save_embeddings( + f"{args.output_dir}/{args.output_dir}_emb.safetensors", + ) + + # Conver to WebUI format + lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors") + peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict) + kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict) + save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors") + + save_model_card( + model_id if not args.push_to_hub else repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + train_text_encoder_ti=args.train_text_encoder_ti, + token_abstraction_dict=train_dataset.token_abstraction_dict, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + vae_path=args.pretrained_vae_model_name_or_path, + ) + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/community/README.md b/examples/community/README.md index c8e45b504fae..6204c0359380 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -55,12 +55,13 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap | LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) | | AnimateDiff ControlNet Pipeline | Combines AnimateDiff with precise motion control using ControlNets | [AnimateDiff ControlNet Pipeline](#animatediff-controlnet-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1SKboYeGjEQmQPWoFC0aLYpBlYdHXkvAu?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) and [Edoardo Botta](https://github.com/EdoardoBotta) | | DemoFusion Pipeline | Implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973) | [DemoFusion Pipeline](#DemoFusion) | - | [Ruoyi Du](https://github.com/RuoyiDu) | +| Instaflow Pipeline | Implementation of [InstaFlow! One-Step Stable Diffusion with Rectified Flow](https://arxiv.org/abs/2309.06380) | [Instaflow Pipeline](#instaflow-pipeline) | - | [Ayush Mangal](https://github.com/ayushtues) | | Null-Text Inversion Pipeline | Implement [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/abs/2211.09794) as a pipeline. | [Null-Text Inversion](https://github.com/google/prompt-to-prompt/) | - | [Junsheng Luan](https://github.com/Junsheng121) | | Rerender A Video Pipeline | Implementation of [[SIGGRAPH Asia 2023] Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation](https://arxiv.org/abs/2306.07954) | [Rerender A Video Pipeline](#Rerender_A_Video) | - | [Yifan Zhou](https://github.com/SingleZombie) | | StyleAligned Pipeline | Implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133) | [StyleAligned Pipeline](#stylealigned-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/15X2E0jFPTajUIjS0FzX50OaHsCbP2lQ0/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) | | AnimateDiff Image-To-Video Pipeline | Experimental Image-To-Video support for AnimateDiff (open to improvements) | [AnimateDiff Image To Video Pipeline](#animatediff-image-to-video-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/1TvzCDPHhfFtdcJZe4RLloAwyoLKuttWK/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) | - -| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) | + IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) | + InstantID Pipeline | Stable Diffusion XL Pipeline that supports InstantID | [InstantID Pipeline](#instantid-pipeline) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/InstantX/InstantID) | [Haofan Wang](https://github.com/haofanwang) | To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly. ```py @@ -3228,6 +3229,43 @@ output_image.save("./output.png") ``` +### Instaflow Pipeline +InstaFlow is an ultra-fast, one-step image generator that achieves image quality close to Stable Diffusion, significantly reducing the demand of computational resources. This efficiency is made possible through a recent [Rectified Flow](https://github.com/gnobitab/RectifiedFlow) technique, which trains probability flows with straight trajectories, hence inherently requiring only a single step for fast inference. + +```python +from diffusers import DiffusionPipeline +import torch + + +pipe = DiffusionPipeline.from_pretrained("XCLIU/instaflow_0_9B_from_sd_1_5", torch_dtype=torch.float16, custom_pipeline="instaflow_one_step") +pipe.to("cuda") ### if GPU is not available, comment this line +prompt = "A hyper-realistic photo of a cute cat." + +images = pipe(prompt=prompt, + num_inference_steps=1, + guidance_scale=0.0).images +images[0].save("./image.png") +``` +![image1](https://huggingface.co/datasets/ayushtues/instaflow_images/resolve/main/instaflow_cat.png) + +You can also combine it with LORA out of the box, like https://huggingface.co/artificialguybr/logo-redmond-1-5v-logo-lora-for-liberteredmond-sd-1-5, to unlock cool use cases in single step! + +```python +from diffusers import DiffusionPipeline +import torch + + +pipe = DiffusionPipeline.from_pretrained("XCLIU/instaflow_0_9B_from_sd_1_5", torch_dtype=torch.float16, custom_pipeline="instaflow_one_step") +pipe.to("cuda") ### if GPU is not available, comment this line +pipe.load_lora_weights("artificialguybr/logo-redmond-1-5v-logo-lora-for-liberteredmond-sd-1-5") +prompt = "logo, A logo for a fitness app, dynamic running figure, energetic colors (red, orange) ),LogoRedAF ," +images = pipe(prompt=prompt, + num_inference_steps=1, + guidance_scale=0.0).images +images[0].save("./image.png") +``` +![image0](https://huggingface.co/datasets/ayushtues/instaflow_images/resolve/main/instaflow_logo.png) + ### Null-Text Inversion pipeline This pipeline provides null-text inversion for editing real images. It enables null-text optimization, and DDIM reconstruction via w, w/o null-text optimization. No prompt-to-prompt code is implemented as there is a Prompt2PromptPipeline. @@ -3267,6 +3305,7 @@ pipeline(prompt, uncond, inverted_latent, guidance_scale=7.5, num_inference_step ``` ### Rerender_A_Video +``` This is the Diffusers implementation of zero-shot video-to-video translation pipeline [Rerender_A_Video](https://github.com/williamyang1991/Rerender_A_Video) (without Ebsynth postprocessing). To run the code, please install gmflow. Then modify the path in `examples/community/rerender_a_video.py`: ```py @@ -3494,3 +3533,73 @@ images = pipeline( for i in range(num_images): images[i].save(f"c{i}.png") ``` + +### InstantID Pipeline + +InstantID is a new state-of-the-art tuning-free method to achieve ID-Preserving generation with only single image, supporting various downstream tasks. For any usgae question, please refer to the [official implementation](https://github.com/InstantID/InstantID). + +```py +# !pip install opencv-python transformers accelerate insightface +import diffusers +from diffusers.utils import load_image +from diffusers.models import ControlNetModel + +import cv2 +import torch +import numpy as np +from PIL import Image + +from insightface.app import FaceAnalysis +from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps + +# prepare 'antelopev2' under ./models +# https://github.com/deepinsight/insightface/issues/1896#issuecomment-1023867304 +app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) +app.prepare(ctx_id=0, det_size=(640, 640)) + +# prepare models under ./checkpoints +# https://huggingface.co/InstantX/InstantID +from huggingface_hub import hf_hub_download +hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints") +hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints") +hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints") + +face_adapter = f'./checkpoints/ip-adapter.bin' +controlnet_path = f'./checkpoints/ControlNetModel' + +# load IdentityNet +controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) + +base_model = 'wangqixun/YamerMIX_v8' +pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( + base_model, + controlnet=controlnet, + torch_dtype=torch.float16 +) +pipe.cuda() + +# load adapter +pipe.load_ip_adapter_instantid(face_adapter) + +# load an image +face_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png") + +# prepare face emb +face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR)) +face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1] # only use the maximum face +face_emb = face_info['embedding'] +face_kps = draw_kps(face_image, face_info['kps']) + +# prompt +prompt = "film noir style, ink sketch|vector, male man, highly detailed, sharp focus, ultra sharpness, monochrome, high contrast, dramatic shadows, 1940s style, mysterious, cinematic" +negative_prompt = "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, vibrant, colorful" + +# generate image +pipe.set_ip_adapter_scale(0.8) +image = pipe( + prompt, + image_embeds=face_emb, + image=face_kps, + controlnet_conditioning_scale=0.8, +).images[0] +``` diff --git a/examples/community/instaflow_one_step.py b/examples/community/instaflow_one_step.py new file mode 100644 index 000000000000..5fe285c5675f --- /dev/null +++ b/examples/community/instaflow_one_step.py @@ -0,0 +1,707 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + deprecate, + logging, +) +from diffusers.utils.torch_utils import randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class InstaFlowPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin): + r""" + Pipeline for text-to-image generation using Rectified Flow and Euler discretization. + This customized pipeline is based on StableDiffusionPipeline from the official Diffusers library (0.21.4) + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def merge_dW_to_unet(pipe, dW_dict, alpha=1.0): + _tmp_sd = pipe.unet.state_dict() + for key in dW_dict.keys(): + _tmp_sd[key] += dW_dict[key] * alpha + pipe.unet.load_state_dict(_tmp_sd, strict=False) + return pipe + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + timesteps = [(1.0 - i / num_inference_steps) * 1000.0 for i in range(num_inference_steps)] + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + dt = 1.0 / num_inference_steps + + # 7. Denoising loop of Euler discretization from t = 0 to t = 1 + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + vec_t = torch.ones((latent_model_input.shape[0],), device=latents.device) * t + + v_pred = self.unet(latent_model_input, vec_t, encoder_hidden_states=prompt_embeds).sample + + # perform guidance + if do_classifier_free_guidance: + v_pred_neg, v_pred_text = v_pred.chunk(2) + v_pred = v_pred_neg + guidance_scale * (v_pred_text - v_pred_neg) + + latents = latents + dt * v_pred + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/community/pipeline_animatediff_controlnet.py b/examples/community/pipeline_animatediff_controlnet.py index cf0c66bb50d0..b700a6c86b93 100644 --- a/examples/community/pipeline_animatediff_controlnet.py +++ b/examples/community/pipeline_animatediff_controlnet.py @@ -26,7 +26,7 @@ from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UNetMotionModel from diffusers.models.lora import adjust_lora_scale_text_encoder -from diffusers.models.unet_motion_model import MotionAdapter +from diffusers.models.unets.unet_motion_model import MotionAdapter from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers import ( diff --git a/examples/community/pipeline_stable_diffusion_xl_instantid.py b/examples/community/pipeline_stable_diffusion_xl_instantid.py new file mode 100644 index 000000000000..bd03e3159dc3 --- /dev/null +++ b/examples/community/pipeline_stable_diffusion_xl_instantid.py @@ -0,0 +1,1058 @@ +# Copyright 2024 The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import cv2 +import numpy as np +import PIL.Image +import torch +import torch.nn as nn + +from diffusers import StableDiffusionXLControlNetPipeline +from diffusers.image_processor import PipelineImageInput +from diffusers.models import ControlNetModel +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput +from diffusers.utils import ( + deprecate, + logging, + replace_example_docstring, +) +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version + + +try: + import xformers + import xformers.ops + + xformers_available = True +except Exception: + xformers_available = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + ): + super().__init__() + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + latents = self.latents.repeat(x.size(0), 1, 1) + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + +class AttnProcessor(nn.Module): + r""" + Default processor for performing attention-related computations. + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAttnProcessor(nn.Module): + r""" + Attention processor for IP-Adapater. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + if xformers_available: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + else: + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + if xformers_available: + ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None) + else: + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + return hidden_states + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate insightface + >>> import diffusers + >>> from diffusers.utils import load_image + >>> from diffusers.models import ControlNetModel + + >>> import cv2 + >>> import torch + >>> import numpy as np + >>> from PIL import Image + + >>> from insightface.app import FaceAnalysis + >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps + + >>> # download 'antelopev2' under ./models + >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + >>> app.prepare(ctx_id=0, det_size=(640, 640)) + + >>> # download models under ./checkpoints + >>> face_adapter = f'./checkpoints/ip-adapter.bin' + >>> controlnet_path = f'./checkpoints/ControlNetModel' + + >>> # load IdentityNet + >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) + + >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.cuda() + + >>> # load adapter + >>> pipe.load_ip_adapter_instantid(face_adapter) + + >>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality" + >>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured" + + >>> # load an image + >>> image = load_image("your-example.jpg") + + >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1] + >>> face_emb = face_info['embedding'] + >>> face_kps = draw_kps(face_image, face_info['kps']) + + >>> pipe.set_ip_adapter_scale(0.8) + + >>> # generate image + >>> image = pipe( + ... prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8 + ... ).images[0] + ``` +""" + + +def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]): + stickwidth = 4 + limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) + kps = np.array(kps) + + w, h = image_pil.size + out_img = np.zeros([h, w, 3]) + + for i in range(len(limbSeq)): + index = limbSeq[i] + color = color_list[index[0]] + + x = kps[index][:, 0] + y = kps[index][:, 1] + length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) + polygon = cv2.ellipse2Poly( + (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1 + ) + out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) + out_img = (out_img * 0.6).astype(np.uint8) + + for idx_kp, kp in enumerate(kps): + color = color_list[idx_kp] + x, y = kp + out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + + out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8)) + return out_img_pil + + +class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline): + def cuda(self, dtype=torch.float16, use_xformers=False): + self.to("cuda", dtype) + + if hasattr(self, "image_proj_model"): + self.image_proj_model.to(self.unet.device).to(self.unet.dtype) + + if use_xformers: + if is_xformers_available(): + import xformers + from packaging import version + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + self.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5): + self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens) + self.set_ip_adapter(model_ckpt, num_tokens, scale) + + def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16): + image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=num_tokens, + embedding_dim=image_emb_dim, + output_dim=self.unet.config.cross_attention_dim, + ff_mult=4, + ) + + image_proj_model.eval() + + self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype) + state_dict = torch.load(model_ckpt, map_location="cpu") + if "image_proj" in state_dict: + state_dict = state_dict["image_proj"] + self.image_proj_model.load_state_dict(state_dict) + + self.image_proj_model_in_features = image_emb_dim + + def set_ip_adapter(self, model_ckpt, num_tokens, scale): + unet = self.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype) + else: + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=scale, + num_tokens=num_tokens, + ).to(unet.device, dtype=unet.dtype) + unet.set_attn_processor(attn_procs) + + state_dict = torch.load(model_ckpt, map_location="cpu") + ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values()) + if "ip_adapter" in state_dict: + state_dict = state_dict["ip_adapter"] + ip_layers.load_state_dict(state_dict) + + def set_ip_adapter_scale(self, scale): + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + for attn_processor in unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifier_free_guidance): + if isinstance(prompt_image_emb, torch.Tensor): + prompt_image_emb = prompt_image_emb.clone().detach() + else: + prompt_image_emb = torch.tensor(prompt_image_emb) + + prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype) + prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features]) + + if do_classifier_free_guidance: + prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0) + else: + prompt_image_emb = torch.cat([prompt_image_emb], dim=0) + + prompt_image_emb = self.image_proj_model(prompt_image_emb) + return prompt_image_emb + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + image_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be + accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in + `init`, images must be passed as a list such that each element of the list can be correctly batched for + input to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + image_embeds (`torch.FloatTensor`, *optional*): + Pre-generated image embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned containing the output images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode image prompt + prompt_image_emb = self._encode_prompt_image_emb( + image_embeds, device, self.unet.dtype, self.do_classifier_free_guidance + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=prompt_image_emb, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=encoder_hidden_states, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/examples/community/stable_diffusion_controlnet_reference.py b/examples/community/stable_diffusion_controlnet_reference.py index 358fc1c6dc67..16f7f589b70b 100644 --- a/examples/community/stable_diffusion_controlnet_reference.py +++ b/examples/community/stable_diffusion_controlnet_reference.py @@ -8,7 +8,7 @@ from diffusers import StableDiffusionControlNetPipeline from diffusers.models import ControlNetModel from diffusers.models.attention import BasicTransformerBlock -from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D +from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import logging diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py index 505470574a0b..88a7febae650 100644 --- a/examples/community/stable_diffusion_reference.py +++ b/examples/community/stable_diffusion_reference.py @@ -7,7 +7,7 @@ from diffusers import StableDiffusionPipeline from diffusers.models.attention import BasicTransformerBlock -from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D +from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg from diffusers.utils import PIL_INTERPOLATION, logging diff --git a/examples/community/stable_diffusion_xl_reference.py b/examples/community/stable_diffusion_xl_reference.py index 5d2b1c771128..fbfb6bdd6160 100644 --- a/examples/community/stable_diffusion_xl_reference.py +++ b/examples/community/stable_diffusion_xl_reference.py @@ -8,7 +8,7 @@ from diffusers import StableDiffusionXLPipeline from diffusers.models.attention import BasicTransformerBlock -from diffusers.models.unet_2d_blocks import ( +from diffusers.models.unets.unet_2d_blocks import ( CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, diff --git a/examples/research_projects/controlnetxs/controlnetxs.py b/examples/research_projects/controlnetxs/controlnetxs.py index 20c8d0fdf0f1..027a853764f8 100644 --- a/examples/research_projects/controlnetxs/controlnetxs.py +++ b/examples/research_projects/controlnetxs/controlnetxs.py @@ -26,7 +26,7 @@ from diffusers.models.autoencoders import AutoencoderKL from diffusers.models.lora import LoRACompatibleConv from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.unet_2d_blocks import ( +from diffusers.models.unets.unet_2d_blocks import ( CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, @@ -36,7 +36,7 @@ UpBlock2D, Upsample2D, ) -from diffusers.models.unet_2d_condition import UNet2DConditionModel +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.utils import BaseOutput, logging diff --git a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py b/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py index be888d7e1145..ed45b3bb5a1b 100644 --- a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py +++ b/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py @@ -1041,11 +1041,6 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) - # manually for max memory savings - if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: - self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast diff --git a/scripts/convert_amused.py b/scripts/convert_amused.py index fdddbef7cd65..21be29dfdb99 100644 --- a/scripts/convert_amused.py +++ b/scripts/convert_amused.py @@ -10,7 +10,7 @@ from diffusers import VQModel from diffusers.models.attention_processor import AttnProcessor -from diffusers.models.uvit_2d import UVit2DModel +from diffusers.models.unets.uvit_2d import UVit2DModel from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline from diffusers.schedulers import AmusedScheduler diff --git a/scripts/convert_consistency_decoder.py b/scripts/convert_consistency_decoder.py index 3319f4c4665e..0cb5fc50dd60 100644 --- a/scripts/convert_consistency_decoder.py +++ b/scripts/convert_consistency_decoder.py @@ -14,7 +14,7 @@ from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel from diffusers.models.autoencoders.vae import Encoder from diffusers.models.embeddings import TimestepEmbedding -from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D +from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D args = ArgumentParser() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 5527c0116b14..931ccc8a4a9d 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -208,6 +208,7 @@ "AmusedInpaintPipeline", "AmusedPipeline", "AnimateDiffPipeline", + "AnimateDiffVideoToVideoPipeline", "AudioLDM2Pipeline", "AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel", @@ -382,7 +383,7 @@ else: _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] - _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] + _import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) _import_structure["schedulers"].extend( @@ -569,6 +570,7 @@ AmusedInpaintPipeline, AmusedPipeline, AnimateDiffPipeline, + AnimateDiffVideoToVideoPipeline, AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel, @@ -711,7 +713,7 @@ else: from .models.controlnet_flax import FlaxControlNetModel from .models.modeling_flax_utils import FlaxModelMixin - from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel + from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel from .models.vae_flax import FlaxAutoencoderKL from .pipelines import FlaxDiffusionPipeline from .schedulers import ( diff --git a/src/diffusers/experimental/rl/value_guided_sampling.py b/src/diffusers/experimental/rl/value_guided_sampling.py index af5ee2102163..da78f3b55605 100644 --- a/src/diffusers/experimental/rl/value_guided_sampling.py +++ b/src/diffusers/experimental/rl/value_guided_sampling.py @@ -16,7 +16,7 @@ import torch import tqdm -from ...models.unet_1d import UNet1DModel +from ...models.unets.unet_1d import UNet1DModel from ...pipelines import DiffusionPipeline from ...utils.dummy_pt_objects import DDPMScheduler from ...utils.torch_utils import randn_tensor diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index d7855206a287..4da047435d8e 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -54,12 +54,13 @@ def text_encoder_attn_modules(text_encoder): _import_structure = {} if is_torch_available(): - _import_structure["single_file"] = ["FromOriginalControlnetMixin", "FromOriginalVAEMixin"] + _import_structure["autoencoder"] = ["FromOriginalVAEMixin"] + + _import_structure["controlnet"] = ["FromOriginalControlNetMixin"] _import_structure["unet"] = ["UNet2DConditionLoadersMixin"] _import_structure["utils"] = ["AttnProcsLayers"] - if is_transformers_available(): - _import_structure["single_file"].extend(["FromSingleFileMixin"]) + _import_structure["single_file"] = ["FromSingleFileMixin"] _import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["ip_adapter"] = ["IPAdapterMixin"] @@ -69,7 +70,8 @@ def text_encoder_attn_modules(text_encoder): if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): - from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin + from .autoencoder import FromOriginalVAEMixin + from .controlnet import FromOriginalControlNetMixin from .unet import UNet2DConditionLoadersMixin from .utils import AttnProcsLayers diff --git a/src/diffusers/loaders/autoencoder.py b/src/diffusers/loaders/autoencoder.py new file mode 100644 index 000000000000..6e65bd1c0070 --- /dev/null +++ b/src/diffusers/loaders/autoencoder.py @@ -0,0 +1,126 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from huggingface_hub.utils import validate_hf_hub_args + +from .single_file_utils import ( + create_diffusers_vae_model_from_ldm, + fetch_ldm_config_and_checkpoint, +) + + +class FromOriginalVAEMixin: + """ + Load pretrained AutoencoderKL weights saved in the `.ckpt` or `.safetensors` format into a [`AutoencoderKL`]. + """ + + @classmethod + @validate_hf_hub_args + def from_single_file(cls, pretrained_model_link_or_path, **kwargs): + r""" + Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or + `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + - A link to the `.ckpt` file (for example + `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. + - A path to a *file* containing all pipeline weights. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to True, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + image_size (`int`, *optional*, defaults to 512): + The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable + Diffusion v2 base model. Use 768 for Stable Diffusion v2. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (for example the pipeline components of the + specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` + method. See example below for more information. + + + + Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading + a VAE from SDXL or a Stable Diffusion v2 model or higher. + + + + Examples: + + ```py + from diffusers import AutoencoderKL + + url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file + model = AutoencoderKL.from_single_file(url) + ``` + """ + + original_config_file = kwargs.pop("original_config_file", None) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + cache_dir = kwargs.pop("cache_dir", None) + local_files_only = kwargs.pop("local_files_only", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + use_safetensors = kwargs.pop("use_safetensors", True) + + class_name = cls.__name__ + original_config, checkpoint = fetch_ldm_config_and_checkpoint( + pretrained_model_link_or_path=pretrained_model_link_or_path, + class_name=class_name, + original_config_file=original_config_file, + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + use_safetensors=use_safetensors, + cache_dir=cache_dir, + ) + + image_size = kwargs.pop("image_size", None) + component = create_diffusers_vae_model_from_ldm(class_name, original_config, checkpoint, image_size=image_size) + vae = component["vae"] + if torch_dtype is not None: + vae = vae.to(torch_dtype) + + return vae diff --git a/src/diffusers/loaders/controlnet.py b/src/diffusers/loaders/controlnet.py new file mode 100644 index 000000000000..527a77109aae --- /dev/null +++ b/src/diffusers/loaders/controlnet.py @@ -0,0 +1,127 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from huggingface_hub.utils import validate_hf_hub_args + +from .single_file_utils import ( + create_diffusers_controlnet_model_from_ldm, + fetch_ldm_config_and_checkpoint, +) + + +class FromOriginalControlNetMixin: + """ + Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`]. + """ + + @classmethod + @validate_hf_hub_args + def from_single_file(cls, pretrained_model_link_or_path, **kwargs): + r""" + Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or + `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + - A link to the `.ckpt` file (for example + `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. + - A path to a *file* containing all pipeline weights. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to True, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + image_size (`int`, *optional*, defaults to 512): + The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable + Diffusion v2 base model. Use 768 for Stable Diffusion v2. + upcast_attention (`bool`, *optional*, defaults to `None`): + Whether the attention computation should always be upcasted. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (for example the pipeline components of the + specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` + method. See example below for more information. + + Examples: + + ```py + from diffusers import StableDiffusionControlNetPipeline, ControlNetModel + + url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path + model = ControlNetModel.from_single_file(url) + + url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path + pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet) + ``` + """ + original_config_file = kwargs.pop("original_config_file", None) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + cache_dir = kwargs.pop("cache_dir", None) + local_files_only = kwargs.pop("local_files_only", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + use_safetensors = kwargs.pop("use_safetensors", True) + + class_name = cls.__name__ + original_config, checkpoint = fetch_ldm_config_and_checkpoint( + pretrained_model_link_or_path=pretrained_model_link_or_path, + class_name=class_name, + original_config_file=original_config_file, + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + use_safetensors=use_safetensors, + cache_dir=cache_dir, + ) + + upcast_attention = kwargs.pop("upcast_attention", False) + image_size = kwargs.pop("image_size", None) + + component = create_diffusers_controlnet_model_from_ldm( + class_name, original_config, checkpoint, upcast_attention=upcast_attention, image_size=image_size + ) + controlnet = component["controlnet"] + if torch_dtype is not None: + controlnet = controlnet.to(torch_dtype) + + return controlnet diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index 4086b1a2a8e8..034271aaba33 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -11,26 +11,125 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext -from io import BytesIO -from pathlib import Path - -import requests -import torch -import yaml -from huggingface_hub import hf_hub_download + from huggingface_hub.utils import validate_hf_hub_args -from ..utils import deprecate, is_accelerate_available, is_transformers_available, logging +from ..utils import is_transformers_available, logging +from .single_file_utils import ( + create_diffusers_unet_model_from_ldm, + create_diffusers_vae_model_from_ldm, + create_scheduler_from_ldm, + create_text_encoders_and_tokenizers_from_ldm, + fetch_ldm_config_and_checkpoint, + infer_model_type, +) + + +logger = logging.get_logger(__name__) +# Pipelines that support the SDXL Refiner checkpoint +REFINER_PIPELINES = [ + "StableDiffusionXLImg2ImgPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLControlNetImg2ImgPipeline", +] if is_transformers_available(): - pass + from transformers import AutoFeatureExtractor + + +def build_sub_model_components( + pipeline_components, + pipeline_class_name, + component_name, + original_config, + checkpoint, + local_files_only=False, + load_safety_checker=False, + model_type=None, + image_size=None, + **kwargs, +): + if component_name in pipeline_components: + return {} + + if component_name == "unet": + num_in_channels = kwargs.pop("num_in_channels", None) + unet_components = create_diffusers_unet_model_from_ldm( + pipeline_class_name, original_config, checkpoint, num_in_channels=num_in_channels, image_size=image_size + ) + return unet_components -if is_accelerate_available(): - from accelerate import init_empty_weights + if component_name == "vae": + vae_components = create_diffusers_vae_model_from_ldm( + pipeline_class_name, original_config, checkpoint, image_size + ) + return vae_components -logger = logging.get_logger(__name__) + if component_name == "scheduler": + scheduler_type = kwargs.get("scheduler_type", "ddim") + prediction_type = kwargs.get("prediction_type", None) + + scheduler_components = create_scheduler_from_ldm( + pipeline_class_name, + original_config, + checkpoint, + scheduler_type=scheduler_type, + prediction_type=prediction_type, + model_type=model_type, + ) + + return scheduler_components + + if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]: + text_encoder_components = create_text_encoders_and_tokenizers_from_ldm( + original_config, + checkpoint, + model_type=model_type, + local_files_only=local_files_only, + ) + return text_encoder_components + + if component_name == "safety_checker": + if load_safety_checker: + from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + safety_checker = StableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + ) + else: + safety_checker = None + return {"safety_checker": safety_checker} + + if component_name == "feature_extractor": + if load_safety_checker: + feature_extractor = AutoFeatureExtractor.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + ) + else: + feature_extractor = None + return {"feature_extractor": feature_extractor} + + return + + +def set_additional_components( + pipeline_class_name, + original_config, + model_type=None, +): + components = {} + if pipeline_class_name in REFINER_PIPELINES: + model_type = infer_model_type(original_config, model_type=model_type) + is_refiner = model_type == "SDXL-Refiner" + components.update( + { + "requires_aesthetics_score": is_refiner, + "force_zeros_for_empty_prompt": False if is_refiner else True, + } + ) + + return components class FromSingleFileMixin: @@ -38,12 +137,6 @@ class FromSingleFileMixin: Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`]. """ - @classmethod - def from_ckpt(cls, *args, **kwargs): - deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead." - deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False) - return cls.from_single_file(*args, **kwargs) - @classmethod @validate_hf_hub_args def from_single_file(cls, pretrained_model_link_or_path, **kwargs): @@ -58,8 +151,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. - A path to a *file* containing all pipeline weights. torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the - dtype is automatically derived from the model's weights. + Override the default `torch.dtype` and load the model with another dtype. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. @@ -85,42 +177,6 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): If set to `None`, the safetensors weights are downloaded if they're available **and** if the safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors weights. If set to `False`, safetensors weights are not loaded. - extract_ema (`bool`, *optional*, defaults to `False`): - Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield - higher quality images for inference. Non-EMA weights are usually better for continuing finetuning. - upcast_attention (`bool`, *optional*, defaults to `None`): - Whether the attention computation should always be upcasted. - image_size (`int`, *optional*, defaults to 512): - The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable - Diffusion v2 base model. Use 768 for Stable Diffusion v2. - prediction_type (`str`, *optional*): - The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and - the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2. - num_in_channels (`int`, *optional*, defaults to `None`): - The number of input channels. If `None`, it is automatically inferred. - scheduler_type (`str`, *optional*, defaults to `"pndm"`): - Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", - "ddim"]`. - load_safety_checker (`bool`, *optional*, defaults to `True`): - Whether to load the safety checker or not. - text_encoder ([`~transformers.CLIPTextModel`], *optional*, defaults to `None`): - An instance of `CLIPTextModel` to use, specifically the - [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. If this - parameter is `None`, the function loads a new instance of `CLIPTextModel` by itself if needed. - vae (`AutoencoderKL`, *optional*, defaults to `None`): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If - this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. - tokenizer ([`~transformers.CLIPTokenizer`], *optional*, defaults to `None`): - An instance of `CLIPTokenizer` to use. If this parameter is `None`, the function loads a new instance - of `CLIPTokenizer` by itself if needed. - original_config_file (`str`): - Path to `.yaml` config file corresponding to the original architecture. If `None`, will be - automatically inferred by looking for a key that only exists in SD2.0 models. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to overwrite load and saveable variables (for example the pipeline components of the - specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` - method. See example below for more information. - Examples: ```py @@ -143,484 +199,80 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): >>> pipeline.to("cuda") ``` """ - # import here to avoid circular dependency - from ..pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt - original_config_file = kwargs.pop("original_config_file", None) - config_files = kwargs.pop("config_files", None) - cache_dir = kwargs.pop("cache_dir", None) resume_download = kwargs.pop("resume_download", False) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) token = kwargs.pop("token", None) + cache_dir = kwargs.pop("cache_dir", None) + local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) - extract_ema = kwargs.pop("extract_ema", False) - image_size = kwargs.pop("image_size", None) - scheduler_type = kwargs.pop("scheduler_type", "pndm") - num_in_channels = kwargs.pop("num_in_channels", None) - upcast_attention = kwargs.pop("upcast_attention", None) - load_safety_checker = kwargs.pop("load_safety_checker", True) - prediction_type = kwargs.pop("prediction_type", None) - text_encoder = kwargs.pop("text_encoder", None) - text_encoder_2 = kwargs.pop("text_encoder_2", None) - vae = kwargs.pop("vae", None) - controlnet = kwargs.pop("controlnet", None) - adapter = kwargs.pop("adapter", None) - tokenizer = kwargs.pop("tokenizer", None) - tokenizer_2 = kwargs.pop("tokenizer_2", None) - torch_dtype = kwargs.pop("torch_dtype", None) + use_safetensors = kwargs.pop("use_safetensors", True) - use_safetensors = kwargs.pop("use_safetensors", None) - - pipeline_name = cls.__name__ - file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] - from_safetensors = file_extension == "safetensors" - - if from_safetensors and use_safetensors is False: - raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") - - # TODO: For now we only support stable diffusion - stable_unclip = None - model_type = None - - if pipeline_name in [ - "StableDiffusionControlNetPipeline", - "StableDiffusionControlNetImg2ImgPipeline", - "StableDiffusionControlNetInpaintPipeline", - ]: - from ..models.controlnet import ControlNetModel - from ..pipelines.controlnet.multicontrolnet import MultiControlNetModel - - # list/tuple or a single instance of ControlNetModel or MultiControlNetModel - if not ( - isinstance(controlnet, (ControlNetModel, MultiControlNetModel)) - or isinstance(controlnet, (list, tuple)) - and isinstance(controlnet[0], ControlNetModel) - ): - raise ValueError("ControlNet needs to be passed if loading from ControlNet pipeline.") - elif "StableDiffusion" in pipeline_name: - # Model type will be inferred from the checkpoint. - pass - elif pipeline_name == "StableUnCLIPPipeline": - model_type = "FrozenOpenCLIPEmbedder" - stable_unclip = "txt2img" - elif pipeline_name == "StableUnCLIPImg2ImgPipeline": - model_type = "FrozenOpenCLIPEmbedder" - stable_unclip = "img2img" - elif pipeline_name == "PaintByExamplePipeline": - model_type = "PaintByExample" - elif pipeline_name == "LDMTextToImagePipeline": - model_type = "LDMTextToImage" - else: - raise ValueError(f"Unhandled pipeline class: {pipeline_name}") - - # remove huggingface url - has_valid_url_prefix = False - valid_url_prefixes = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"] - for prefix in valid_url_prefixes: - if pretrained_model_link_or_path.startswith(prefix): - pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :] - has_valid_url_prefix = True - - # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained - ckpt_path = Path(pretrained_model_link_or_path) - if not ckpt_path.is_file(): - if not has_valid_url_prefix: - raise ValueError( - f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(valid_url_prefixes)}" - ) - - # get repo_id and (potentially nested) file path of ckpt in repo - repo_id = "/".join(ckpt_path.parts[:2]) - file_path = "/".join(ckpt_path.parts[2:]) - - if file_path.startswith("blob/"): - file_path = file_path[len("blob/") :] - - if file_path.startswith("main/"): - file_path = file_path[len("main/") :] - - pretrained_model_link_or_path = hf_hub_download( - repo_id, - filename=file_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - force_download=force_download, - ) + class_name = cls.__name__ - pipe = download_from_original_stable_diffusion_ckpt( - pretrained_model_link_or_path, - pipeline_class=cls, - model_type=model_type, - stable_unclip=stable_unclip, - controlnet=controlnet, - adapter=adapter, - from_safetensors=from_safetensors, - extract_ema=extract_ema, - image_size=image_size, - scheduler_type=scheduler_type, - num_in_channels=num_in_channels, - upcast_attention=upcast_attention, - load_safety_checker=load_safety_checker, - prediction_type=prediction_type, - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - vae=vae, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, + original_config, checkpoint = fetch_ldm_config_and_checkpoint( + pretrained_model_link_or_path=pretrained_model_link_or_path, + class_name=class_name, original_config_file=original_config_file, - config_files=config_files, + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, local_files_only=local_files_only, + use_safetensors=use_safetensors, + cache_dir=cache_dir, ) - if torch_dtype is not None: - pipe.to(dtype=torch_dtype) - - return pipe - - -class FromOriginalVAEMixin: - """ - Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into an [`AutoencoderKL`]. - """ - - @classmethod - @validate_hf_hub_args - def from_single_file(cls, pretrained_model_link_or_path, **kwargs): - r""" - Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or - `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default. - - Parameters: - pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - A link to the `.ckpt` file (for example - `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. - - A path to a *file* containing all pipeline weights. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the - dtype is automatically derived from the model's weights. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to resume downloading the model weights and configuration files. If set to `False`, any - incompletely downloaded files are deleted. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to True, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - image_size (`int`, *optional*, defaults to 512): - The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable - Diffusion v2 base model. Use 768 for Stable Diffusion v2. - use_safetensors (`bool`, *optional*, defaults to `None`): - If set to `None`, the safetensors weights are downloaded if they're available **and** if the - safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors - weights. If set to `False`, safetensors weights are not loaded. - upcast_attention (`bool`, *optional*, defaults to `None`): - Whether the attention computation should always be upcasted. - scaling_factor (`float`, *optional*, defaults to 0.18215): - The component-wise standard deviation of the trained latent space computed using the first batch of the - training set. This is used to scale the latent space to have unit variance when training the diffusion - model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the - diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z - = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution - Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to overwrite load and saveable variables (for example the pipeline components of the - specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` - method. See example below for more information. - - - - Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading - a VAE from SDXL or a Stable Diffusion v2 model or higher. - - - - Examples: - - ```py - from diffusers import AutoencoderKL - - url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file - model = AutoencoderKL.from_single_file(url) - ``` - """ - from ..models import AutoencoderKL + from ..pipelines.pipeline_utils import _get_pipeline_class - # import here to avoid circular dependency - from ..pipelines.stable_diffusion.convert_from_ckpt import ( - convert_ldm_vae_checkpoint, - create_vae_diffusers_config, + pipeline_class = _get_pipeline_class( + cls, + config=None, + cache_dir=cache_dir, ) - config_file = kwargs.pop("config_file", None) - cache_dir = kwargs.pop("cache_dir", None) - resume_download = kwargs.pop("resume_download", False) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - image_size = kwargs.pop("image_size", None) - scaling_factor = kwargs.pop("scaling_factor", None) - kwargs.pop("upcast_attention", None) - - torch_dtype = kwargs.pop("torch_dtype", None) - - use_safetensors = kwargs.pop("use_safetensors", None) - - file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] - from_safetensors = file_extension == "safetensors" - - if from_safetensors and use_safetensors is False: - raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") - - # remove huggingface url - for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]: - if pretrained_model_link_or_path.startswith(prefix): - pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :] - - # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained - ckpt_path = Path(pretrained_model_link_or_path) - if not ckpt_path.is_file(): - # get repo_id and (potentially nested) file path of ckpt in repo - repo_id = "/".join(ckpt_path.parts[:2]) - file_path = "/".join(ckpt_path.parts[2:]) - - if file_path.startswith("blob/"): - file_path = file_path[len("blob/") :] - - if file_path.startswith("main/"): - file_path = file_path[len("main/") :] - - pretrained_model_link_or_path = hf_hub_download( - repo_id, - filename=file_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - force_download=force_download, - ) - - if from_safetensors: - from safetensors import safe_open - - checkpoint = {} - with safe_open(pretrained_model_link_or_path, framework="pt", device="cpu") as f: - for key in f.keys(): - checkpoint[key] = f.get_tensor(key) - else: - checkpoint = torch.load(pretrained_model_link_or_path, map_location="cpu") - - if "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - if config_file is None: - config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" - config_file = BytesIO(requests.get(config_url).content) - - original_config = yaml.safe_load(config_file) - - # default to sd-v1-5 - image_size = image_size or 512 - - vae_config = create_vae_diffusers_config(original_config, image_size=image_size) - converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) - - if scaling_factor is None: - if ( - "model" in original_config - and "params" in original_config["model"] - and "scale_factor" in original_config["model"]["params"] - ): - vae_scaling_factor = original_config["model"]["params"]["scale_factor"] - else: - vae_scaling_factor = 0.18215 # default SD scaling factor - - vae_config["scaling_factor"] = vae_scaling_factor - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - vae = AutoencoderKL(**vae_config) - - if is_accelerate_available(): - from ..models.modeling_utils import load_model_dict_into_meta - - load_model_dict_into_meta(vae, converted_vae_checkpoint, device="cpu") - else: - vae.load_state_dict(converted_vae_checkpoint) - - if torch_dtype is not None: - vae.to(dtype=torch_dtype) - - return vae - - -class FromOriginalControlnetMixin: - """ - Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`]. - """ - - @classmethod - @validate_hf_hub_args - def from_single_file(cls, pretrained_model_link_or_path, **kwargs): - r""" - Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or - `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default. - - Parameters: - pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - A link to the `.ckpt` file (for example - `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. - - A path to a *file* containing all pipeline weights. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the - dtype is automatically derived from the model's weights. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to resume downloading the model weights and configuration files. If set to `False`, any - incompletely downloaded files are deleted. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to True, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - use_safetensors (`bool`, *optional*, defaults to `None`): - If set to `None`, the safetensors weights are downloaded if they're available **and** if the - safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors - weights. If set to `False`, safetensors weights are not loaded. - image_size (`int`, *optional*, defaults to 512): - The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable - Diffusion v2 base model. Use 768 for Stable Diffusion v2. - upcast_attention (`bool`, *optional*, defaults to `None`): - Whether the attention computation should always be upcasted. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to overwrite load and saveable variables (for example the pipeline components of the - specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` - method. See example below for more information. - - Examples: - - ```py - from diffusers import StableDiffusionControlNetPipeline, ControlNetModel - - url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path - model = ControlNetModel.from_single_file(url) - - url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path - pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet) - ``` - """ - # import here to avoid circular dependency - from ..pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt + expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} - config_file = kwargs.pop("config_file", None) - cache_dir = kwargs.pop("cache_dir", None) - resume_download = kwargs.pop("resume_download", False) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - num_in_channels = kwargs.pop("num_in_channels", None) - use_linear_projection = kwargs.pop("use_linear_projection", None) - revision = kwargs.pop("revision", None) - extract_ema = kwargs.pop("extract_ema", False) + model_type = kwargs.pop("model_type", None) image_size = kwargs.pop("image_size", None) - upcast_attention = kwargs.pop("upcast_attention", None) + load_safety_checker = (kwargs.pop("load_safety_checker", False)) or ( + passed_class_obj.get("safety_checker", None) is not None + ) - torch_dtype = kwargs.pop("torch_dtype", None) + init_kwargs = {} + for name in expected_modules: + if name in passed_class_obj: + init_kwargs[name] = passed_class_obj[name] + else: + components = build_sub_model_components( + init_kwargs, + class_name, + name, + original_config, + checkpoint, + model_type=model_type, + image_size=image_size, + load_safety_checker=load_safety_checker, + local_files_only=local_files_only, + **kwargs, + ) + if not components: + continue + init_kwargs.update(components) - use_safetensors = kwargs.pop("use_safetensors", None) - - file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] - from_safetensors = file_extension == "safetensors" - - if from_safetensors and use_safetensors is False: - raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") - - # remove huggingface url - for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]: - if pretrained_model_link_or_path.startswith(prefix): - pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :] - - # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained - ckpt_path = Path(pretrained_model_link_or_path) - if not ckpt_path.is_file(): - # get repo_id and (potentially nested) file path of ckpt in repo - repo_id = "/".join(ckpt_path.parts[:2]) - file_path = "/".join(ckpt_path.parts[2:]) - - if file_path.startswith("blob/"): - file_path = file_path[len("blob/") :] - - if file_path.startswith("main/"): - file_path = file_path[len("main/") :] - - pretrained_model_link_or_path = hf_hub_download( - repo_id, - filename=file_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - force_download=force_download, - ) + additional_components = set_additional_components(class_name, original_config, model_type=model_type) + if additional_components: + init_kwargs.update(additional_components) - if config_file is None: - config_url = "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml" - config_file = BytesIO(requests.get(config_url).content) - - image_size = image_size or 512 - - controlnet = download_controlnet_from_original_ckpt( - pretrained_model_link_or_path, - original_config_file=config_file, - image_size=image_size, - extract_ema=extract_ema, - num_in_channels=num_in_channels, - upcast_attention=upcast_attention, - from_safetensors=from_safetensors, - use_linear_projection=use_linear_projection, - ) + init_kwargs.update(passed_pipe_kwargs) + pipe = pipeline_class(**init_kwargs) if torch_dtype is not None: - controlnet.to(dtype=torch_dtype) + pipe.to(dtype=torch_dtype) - return controlnet + return pipe diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py new file mode 100644 index 000000000000..fb1ad14fd3e2 --- /dev/null +++ b/src/diffusers/loaders/single_file_utils.py @@ -0,0 +1,1387 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Conversion script for the Stable Diffusion checkpoints.""" + +import os +import re +from contextlib import nullcontext +from io import BytesIO +from urllib.parse import urlparse + +import requests +import yaml + +from ..models.modeling_utils import load_state_dict +from ..schedulers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ..utils import is_accelerate_available, is_transformers_available, logging +from ..utils.hub_utils import _get_model_file + + +if is_transformers_available(): + from transformers import ( + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + ) + +if is_accelerate_available(): + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +CONFIG_URLS = { + "v1": "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml", + "v2": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml", + "xl": "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml", + "xl_refiner": "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml", + "upscale": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml", + "controlnet": "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml", +} + +CHECKPOINT_KEY_NAMES = { + "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", + "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias", + "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias", +} + +SCHEDULER_DEFAULT_CONFIG = { + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "beta_end": 0.012, + "interpolation_type": "linear", + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + "skip_prk_steps": True, + "steps_offset": 1, + "timestep_spacing": "leading", +} + +DIFFUSERS_TO_LDM_MAPPING = { + "unet": { + "layers": { + "time_embedding.linear_1.weight": "time_embed.0.weight", + "time_embedding.linear_1.bias": "time_embed.0.bias", + "time_embedding.linear_2.weight": "time_embed.2.weight", + "time_embedding.linear_2.bias": "time_embed.2.bias", + "conv_in.weight": "input_blocks.0.0.weight", + "conv_in.bias": "input_blocks.0.0.bias", + "conv_norm_out.weight": "out.0.weight", + "conv_norm_out.bias": "out.0.bias", + "conv_out.weight": "out.2.weight", + "conv_out.bias": "out.2.bias", + }, + "class_embed_type": { + "class_embedding.linear_1.weight": "label_emb.0.0.weight", + "class_embedding.linear_1.bias": "label_emb.0.0.bias", + "class_embedding.linear_2.weight": "label_emb.0.2.weight", + "class_embedding.linear_2.bias": "label_emb.0.2.bias", + }, + "addition_embed_type": { + "add_embedding.linear_1.weight": "label_emb.0.0.weight", + "add_embedding.linear_1.bias": "label_emb.0.0.bias", + "add_embedding.linear_2.weight": "label_emb.0.2.weight", + "add_embedding.linear_2.bias": "label_emb.0.2.bias", + }, + }, + "controlnet": { + "layers": { + "time_embedding.linear_1.weight": "time_embed.0.weight", + "time_embedding.linear_1.bias": "time_embed.0.bias", + "time_embedding.linear_2.weight": "time_embed.2.weight", + "time_embedding.linear_2.bias": "time_embed.2.bias", + "conv_in.weight": "input_blocks.0.0.weight", + "conv_in.bias": "input_blocks.0.0.bias", + "controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight", + "controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias", + "controlnet_cond_embedding.conv_out.weight": "input_hint_block.14.weight", + "controlnet_cond_embedding.conv_out.bias": "input_hint_block.14.bias", + }, + "class_embed_type": { + "class_embedding.linear_1.weight": "label_emb.0.0.weight", + "class_embedding.linear_1.bias": "label_emb.0.0.bias", + "class_embedding.linear_2.weight": "label_emb.0.2.weight", + "class_embedding.linear_2.bias": "label_emb.0.2.bias", + }, + "addition_embed_type": { + "add_embedding.linear_1.weight": "label_emb.0.0.weight", + "add_embedding.linear_1.bias": "label_emb.0.0.bias", + "add_embedding.linear_2.weight": "label_emb.0.2.weight", + "add_embedding.linear_2.bias": "label_emb.0.2.bias", + }, + }, + "vae": { + "encoder.conv_in.weight": "encoder.conv_in.weight", + "encoder.conv_in.bias": "encoder.conv_in.bias", + "encoder.conv_out.weight": "encoder.conv_out.weight", + "encoder.conv_out.bias": "encoder.conv_out.bias", + "encoder.conv_norm_out.weight": "encoder.norm_out.weight", + "encoder.conv_norm_out.bias": "encoder.norm_out.bias", + "decoder.conv_in.weight": "decoder.conv_in.weight", + "decoder.conv_in.bias": "decoder.conv_in.bias", + "decoder.conv_out.weight": "decoder.conv_out.weight", + "decoder.conv_out.bias": "decoder.conv_out.bias", + "decoder.conv_norm_out.weight": "decoder.norm_out.weight", + "decoder.conv_norm_out.bias": "decoder.norm_out.bias", + "quant_conv.weight": "quant_conv.weight", + "quant_conv.bias": "quant_conv.bias", + "post_quant_conv.weight": "post_quant_conv.weight", + "post_quant_conv.bias": "post_quant_conv.bias", + }, + "openclip": { + "layers": { + "text_model.embeddings.position_embedding.weight": "positional_embedding", + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.final_layer_norm.weight": "ln_final.weight", + "text_model.final_layer_norm.bias": "ln_final.bias", + "text_projection.weight": "text_projection", + }, + "transformer": { + "text_model.encoder.layers.": "resblocks.", + "layer_norm1": "ln_1", + "layer_norm2": "ln_2", + ".fc1.": ".c_fc.", + ".fc2.": ".c_proj.", + ".self_attn": ".attn", + "transformer.text_model.final_layer_norm.": "ln_final.", + "transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "transformer.text_model.embeddings.position_embedding.weight": "positional_embedding", + }, + }, +} + +LDM_VAE_KEY = "first_stage_model." +LDM_UNET_KEY = "model.diffusion_model." +LDM_CONTROLNET_KEY = "control_model." +LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."] +LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024 + +SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [ + "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias", + "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight", + "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.23.ln_1.bias", + "cond_stage_model.model.transformer.resblocks.23.ln_1.weight", + "cond_stage_model.model.transformer.resblocks.23.ln_2.bias", + "cond_stage_model.model.transformer.resblocks.23.ln_2.weight", + "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias", + "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight", + "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias", + "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight", + "cond_stage_model.model.text_projection", +] + + +VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"] + + +def _extract_repo_id_and_weights_name(pretrained_model_name_or_path): + pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)" + weights_name = None + repo_id = (None,) + for prefix in VALID_URL_PREFIXES: + pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "") + match = re.match(pattern, pretrained_model_name_or_path) + if not match: + return repo_id, weights_name + + repo_id = f"{match.group(1)}/{match.group(2)}" + weights_name = match.group(3) + + return repo_id, weights_name + + +def fetch_ldm_config_and_checkpoint( + pretrained_model_link_or_path, + class_name, + original_config_file=None, + resume_download=False, + force_download=False, + proxies=None, + token=None, + cache_dir=None, + local_files_only=None, + revision=None, + use_safetensors=True, +): + file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] + from_safetensors = file_extension == "safetensors" + + if from_safetensors and use_safetensors is False: + raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") + + if os.path.isfile(pretrained_model_link_or_path): + checkpoint = load_state_dict(pretrained_model_link_or_path) + + else: + repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path) + checkpoint_path = _get_model_file( + repo_id, + weights_name=weights_name, + force_download=force_download, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + ) + checkpoint = load_state_dict(checkpoint_path) + + # some checkpoints contain the model state dict under a "state_dict" key + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + original_config = fetch_original_config(class_name, checkpoint, original_config_file) + + return original_config, checkpoint + + +def infer_original_config_file(class_name, checkpoint): + if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024: + config_url = CONFIG_URLS["v2"] + + elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint: + config_url = CONFIG_URLS["xl"] + + elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint: + config_url = CONFIG_URLS["xl_refiner"] + + elif class_name == "StableDiffusionUpscalePipeline": + config_url = CONFIG_URLS["upscale"] + + elif class_name == "ControlNetModel": + config_url = CONFIG_URLS["controlnet"] + + else: + config_url = CONFIG_URLS["v1"] + + original_config_file = BytesIO(requests.get(config_url).content) + + return original_config_file + + +def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=None): + def is_valid_url(url): + result = urlparse(url) + if result.scheme and result.netloc: + return True + + return False + + if original_config_file is None: + original_config_file = infer_original_config_file(pipeline_class_name, checkpoint) + + elif os.path.isfile(original_config_file): + with open(original_config_file, "r") as fp: + original_config_file = fp.read() + + elif is_valid_url(original_config_file): + original_config_file = BytesIO(requests.get(original_config_file).content) + + else: + raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.") + + original_config = yaml.safe_load(original_config_file) + + return original_config + + +def infer_model_type(original_config, model_type=None): + if model_type is not None: + return model_type + + has_cond_stage_config = ( + "cond_stage_config" in original_config["model"]["params"] + and original_config["model"]["params"]["cond_stage_config"] is not None + ) + has_network_config = ( + "network_config" in original_config["model"]["params"] + and original_config["model"]["params"]["network_config"] is not None + ) + + if has_cond_stage_config: + model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1] + + elif has_network_config: + context_dim = original_config["model"]["params"]["network_config"]["params"]["context_dim"] + if context_dim == 2048: + model_type = "SDXL" + else: + model_type = "SDXL-Refiner" + else: + raise ValueError("Unable to infer model type from config") + + logger.debug(f"No `model_type` given, `model_type` inferred as: {model_type}") + + return model_type + + +def get_default_scheduler_config(): + return SCHEDULER_DEFAULT_CONFIG + + +def set_image_size(pipeline_class_name, original_config, checkpoint, image_size=None, model_type=None): + if image_size: + return image_size + + global_step = checkpoint["global_step"] if "global_step" in checkpoint else None + model_type = infer_model_type(original_config, model_type) + + if pipeline_class_name == "StableDiffusionUpscalePipeline": + image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"] + return image_size + + elif model_type in ["SDXL", "SDXL-Refiner"]: + image_size = 1024 + return image_size + + elif ( + "parameterization" in original_config["model"]["params"] + and original_config["model"]["params"]["parameterization"] == "v" + ): + # NOTE: For stable diffusion 2 base one has to pass `image_size==512` + # as it relies on a brittle global step parameter here + image_size = 512 if global_step == 875000 else 768 + return image_size + + else: + image_size = 512 + return image_size + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config(original_config, image_size: int): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + if ( + "unet_config" in original_config["model"]["params"] + and original_config["model"]["params"]["unet_config"] is not None + ): + unet_params = original_config["model"]["params"]["unet_config"]["params"] + else: + unet_params = original_config["model"]["params"]["network_config"]["params"] + + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + if unet_params["transformer_depth"] is not None: + transformer_layers_per_block = ( + unet_params["transformer_depth"] + if isinstance(unet_params["transformer_depth"], int) + else list(unet_params["transformer_depth"]) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) + + head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] + head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params["context_dim"] is not None: + context_dim = ( + unet_params["context_dim"] + if isinstance(unet_params["context_dim"], int) + else unet_params["context_dim"][0] + ) + + if "num_classes" in unet_params: + if unet_params["num_classes"] == "sequential": + if context_dim in [2048, 1280]: + # SDXL + addition_embed_type = "text_time" + addition_time_embed_dim = 256 + else: + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params["adm_in_channels"] + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params["in_channels"], + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params["num_res_blocks"], + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if "disable_self_attentions" in unet_params: + config["only_cross_attention"] = unet_params["disable_self_attentions"] + + if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int): + config["num_class_embeds"] = unet_params["num_classes"] + + config["out_channels"] = unet_params["out_channels"] + config["up_block_types"] = tuple(up_block_types) + + return config + + +def create_controlnet_diffusers_config(original_config, image_size: int): + unet_params = original_config["model"]["params"]["control_stage_config"]["params"] + diffusers_unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + + controlnet_config = { + "conditioning_channels": unet_params["hint_channels"], + "in_channels": diffusers_unet_config["in_channels"], + "down_block_types": diffusers_unet_config["down_block_types"], + "block_out_channels": diffusers_unet_config["block_out_channels"], + "layers_per_block": diffusers_unet_config["layers_per_block"], + "cross_attention_dim": diffusers_unet_config["cross_attention_dim"], + "attention_head_dim": diffusers_unet_config["attention_head_dim"], + "use_linear_projection": diffusers_unet_config["use_linear_projection"], + "class_embed_type": diffusers_unet_config["class_embed_type"], + "addition_embed_type": diffusers_unet_config["addition_embed_type"], + "addition_time_embed_dim": diffusers_unet_config["addition_time_embed_dim"], + "projection_class_embeddings_input_dim": diffusers_unet_config["projection_class_embeddings_input_dim"], + "transformer_layers_per_block": diffusers_unet_config["transformer_layers_per_block"], + } + + return controlnet_config + + +def create_vae_diffusers_config(original_config, image_size, scaling_factor=0.18125): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = { + "sample_size": image_size, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], + "scaling_factor": scaling_factor, + } + + return config + + +def update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping=None): + for ldm_key in ldm_keys: + diffusers_key = ( + ldm_key.replace("in_layers.0", "norm1") + .replace("in_layers.2", "conv1") + .replace("out_layers.0", "norm2") + .replace("out_layers.3", "conv2") + .replace("emb_layers.1", "time_emb_proj") + .replace("skip_connection", "conv_shortcut") + ) + if mapping: + diffusers_key = diffusers_key.replace(mapping["old"], mapping["new"]) + new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key) + + +def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping): + for ldm_key in ldm_keys: + diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]) + new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key) + + +def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + unet_key = LDM_UNET_KEY + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning("Checkpoint has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + logger.warning( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"] + for diffusers_key, ldm_key in ldm_unet_keys.items(): + if ldm_key not in unet_state_dict: + continue + new_checkpoint[diffusers_key] = unet_state_dict[ldm_key] + + if ("class_embed_type" in config) and (config["class_embed_type"] in ["timestep", "projection"]): + class_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["class_embed_type"] + for diffusers_key, ldm_key in class_embed_keys.items(): + new_checkpoint[diffusers_key] = unet_state_dict[ldm_key] + + if ("addition_embed_type" in config) and (config["addition_embed_type"] == "text_time"): + addition_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["addition_embed_type"] + for diffusers_key, ldm_key in addition_embed_keys.items(): + new_checkpoint[diffusers_key] = unet_state_dict[ldm_key] + + # Relevant to StableDiffusionUpscalePipeline + if "num_class_embeds" in config: + if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict): + new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + # Down blocks + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + update_unet_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + unet_state_dict, + {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}, + ) + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + if attentions: + update_unet_attention_ldm_to_diffusers( + attentions, + new_checkpoint, + unet_state_dict, + {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}, + ) + + # Mid blocks + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + update_unet_resnet_ldm_to_diffusers( + resnet_0, new_checkpoint, unet_state_dict, mapping={"old": "middle_block.0", "new": "mid_block.resnets.0"} + ) + update_unet_resnet_ldm_to_diffusers( + resnet_1, new_checkpoint, unet_state_dict, mapping={"old": "middle_block.2", "new": "mid_block.resnets.1"} + ) + update_unet_attention_ldm_to_diffusers( + attentions, new_checkpoint, unet_state_dict, mapping={"old": "middle_block.1", "new": "mid_block.attentions.0"} + ) + + # Up Blocks + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + + resnets = [ + key for key in output_blocks[i] if f"output_blocks.{i}.0" in key and f"output_blocks.{i}.0.op" not in key + ] + update_unet_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + unet_state_dict, + {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}, + ) + + attentions = [ + key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and f"output_blocks.{i}.1.conv" not in key + ] + if attentions: + update_unet_attention_ldm_to_diffusers( + attentions, + new_checkpoint, + unet_state_dict, + {"old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}"}, + ) + + if f"output_blocks.{i}.1.conv.weight" in unet_state_dict: + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.1.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.1.conv.bias" + ] + if f"output_blocks.{i}.2.conv.weight" in unet_state_dict: + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.2.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.2.conv.bias" + ] + + return new_checkpoint + + +def convert_controlnet_checkpoint( + checkpoint, + config, +): + # Some controlnet ckpt files are distributed independently from the rest of the + # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + controlnet_state_dict = checkpoint + + else: + controlnet_state_dict = {} + keys = list(checkpoint.keys()) + controlnet_key = LDM_CONTROLNET_KEY + for key in keys: + if key.startswith(controlnet_key): + controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"] + for diffusers_key, ldm_key in ldm_controlnet_keys.items(): + if ldm_key not in controlnet_state_dict: + continue + new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key] + + # Retrieves the keys for the input blocks only + num_input_blocks = len( + {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer} + ) + input_blocks = { + layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Down blocks + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + update_unet_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + controlnet_state_dict, + {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}, + ) + + if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + if attentions: + update_unet_attention_ldm_to_diffusers( + attentions, + new_checkpoint, + controlnet_state_dict, + {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}, + ) + + # controlnet down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.pop(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.pop(f"zero_convs.{i}.0.bias") + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len( + {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "middle_block" in layer} + ) + middle_blocks = { + layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + if middle_blocks: + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + update_unet_resnet_ldm_to_diffusers( + resnet_0, + new_checkpoint, + controlnet_state_dict, + mapping={"old": "middle_block.0", "new": "mid_block.resnets.0"}, + ) + update_unet_resnet_ldm_to_diffusers( + resnet_1, + new_checkpoint, + controlnet_state_dict, + mapping={"old": "middle_block.2", "new": "mid_block.resnets.1"}, + ) + update_unet_attention_ldm_to_diffusers( + attentions, + new_checkpoint, + controlnet_state_dict, + mapping={"old": "middle_block.1", "new": "mid_block.attentions.0"}, + ) + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.pop("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.pop("middle_block_out.0.bias") + + # controlnet cond embedding blocks + cond_embedding_blocks = { + ".".join(layer.split(".")[:2]) + for layer in controlnet_state_dict + if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer) + } + num_cond_embedding_blocks = len(cond_embedding_blocks) + + for idx in range(1, num_cond_embedding_blocks + 1): + diffusers_idx = idx - 1 + cond_block_id = 2 * idx + + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.pop( + f"input_hint_block.{cond_block_id}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.pop( + f"input_hint_block.{cond_block_id}.bias" + ) + + return new_checkpoint + + +def create_diffusers_controlnet_model_from_ldm( + pipeline_class_name, original_config, checkpoint, upcast_attention=False, image_size=None +): + # import here to avoid circular imports + from ..models import ControlNetModel + + image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size) + + diffusers_config = create_controlnet_diffusers_config(original_config, image_size=image_size) + diffusers_config["upcast_attention"] = upcast_attention + + diffusers_format_controlnet_checkpoint = convert_controlnet_checkpoint(checkpoint, diffusers_config) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + controlnet = ControlNetModel(**diffusers_config) + + if is_accelerate_available(): + for param_name, param in diffusers_format_controlnet_checkpoint.items(): + set_module_tensor_to_device(controlnet, param_name, "cpu", value=param) + else: + controlnet.load_state_dict(diffusers_format_controlnet_checkpoint) + + return {"controlnet": controlnet} + + +def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): + for ldm_key in keys: + diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut") + new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key) + + +def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): + for ldm_key in keys: + diffusers_key = ( + ldm_key.replace(mapping["old"], mapping["new"]) + .replace("norm.weight", "group_norm.weight") + .replace("norm.bias", "group_norm.bias") + .replace("q.weight", "to_q.weight") + .replace("q.bias", "to_q.bias") + .replace("k.weight", "to_k.weight") + .replace("k.bias", "to_k.bias") + .replace("v.weight", "to_v.weight") + .replace("v.bias", "to_v.bias") + .replace("proj_out.weight", "to_out.0.weight") + .replace("proj_out.bias", "to_out.0.bias") + ) + new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key) + + # proj_attn.weight has to be converted from conv 1D to linear + shape = new_checkpoint[diffusers_key].shape + + if len(shape) == 3: + new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0] + elif len(shape) == 4: + new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0] + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = LDM_VAE_KEY if any(k.startswith(LDM_VAE_KEY) for k in keys) else "" + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + vae_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["vae"] + for diffusers_key, ldm_key in vae_diffusers_ldm_map.items(): + if ldm_key not in vae_state_dict: + continue + new_checkpoint[diffusers_key] = vae_state_dict[ldm_key] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len(config["down_block_types"]) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}, + ) + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}, + ) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + update_vae_attentions_ldm_to_diffusers( + mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"} + ) + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len(config["up_block_types"]) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}, + ) + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}, + ) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + update_vae_attentions_ldm_to_diffusers( + mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"} + ) + conv_attn_to_linear(new_checkpoint) + + return new_checkpoint + + +def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_files_only=False): + try: + config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'." + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModel(config) + + keys = list(checkpoint.keys()) + text_model_dict = {} + + remove_prefixes = LDM_CLIP_PREFIX_TO_REMOVE + + for key in keys: + for prefix in remove_prefixes: + if key.startswith(prefix): + diffusers_key = key.replace(prefix, "") + text_model_dict[diffusers_key] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): + text_model_dict.pop("text_model.embeddings.position_ids", None) + + text_model.load_state_dict(text_model_dict) + + return text_model + + +def create_text_encoder_from_open_clip_checkpoint( + config_name, + checkpoint, + prefix="cond_stage_model.model.", + has_projection=False, + local_files_only=False, + **config_kwargs, +): + try: + config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'." + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config) + + text_model_dict = {} + text_proj_key = prefix + "text_projection" + text_proj_dim = ( + int(checkpoint[text_proj_key].shape[0]) if text_proj_key in checkpoint else LDM_OPEN_CLIP_TEXT_PROJECTION_DIM + ) + text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") + + keys = list(checkpoint.keys()) + keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE + + openclip_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["layers"] + for diffusers_key, ldm_key in openclip_diffusers_ldm_map.items(): + ldm_key = prefix + ldm_key + if ldm_key not in checkpoint: + continue + if ldm_key in keys_to_ignore: + continue + if ldm_key.endswith("text_projection"): + text_model_dict[diffusers_key] = checkpoint[ldm_key].T.contiguous() + else: + text_model_dict[diffusers_key] = checkpoint[ldm_key] + + for key in keys: + if key in keys_to_ignore: + continue + + if not key.startswith(prefix + "transformer."): + continue + + diffusers_key = key.replace(prefix + "transformer.", "") + transformer_diffusers_to_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["transformer"] + for new_key, old_key in transformer_diffusers_to_ldm_map.items(): + diffusers_key = ( + diffusers_key.replace(old_key, new_key).replace(".in_proj_weight", "").replace(".in_proj_bias", "") + ) + + if key.endswith(".in_proj_weight"): + weight_value = checkpoint[key] + + text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :] + text_model_dict[diffusers_key + ".k_proj.weight"] = weight_value[text_proj_dim : text_proj_dim * 2, :] + text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2 :, :] + + elif key.endswith(".in_proj_bias"): + weight_value = checkpoint[key] + text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim] + text_model_dict[diffusers_key + ".k_proj.bias"] = weight_value[text_proj_dim : text_proj_dim * 2] + text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :] + + else: + text_model_dict[diffusers_key] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + + else: + if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): + text_model_dict.pop("text_model.embeddings.position_ids", None) + + text_model.load_state_dict(text_model_dict) + + return text_model + + +def create_diffusers_unet_model_from_ldm( + pipeline_class_name, + original_config, + checkpoint, + num_in_channels=None, + upcast_attention=False, + extract_ema=False, + image_size=None, +): + from ..models import UNet2DConditionModel + + if num_in_channels is None: + if pipeline_class_name in [ + "StableDiffusionInpaintPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLControlNetInpaintPipeline", + ]: + num_in_channels = 9 + + elif pipeline_class_name == "StableDiffusionUpscalePipeline": + num_in_channels = 7 + + else: + num_in_channels = 4 + + image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size) + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet_config["in_channels"] = num_in_channels + unet_config["upcast_attention"] = upcast_attention + + diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema) + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + unet = UNet2DConditionModel(**unet_config) + + if is_accelerate_available(): + for param_name, param in diffusers_format_unet_checkpoint.items(): + set_module_tensor_to_device(unet, param_name, "cpu", value=param) + else: + unet.load_state_dict(diffusers_format_unet_checkpoint) + + return {"unet": unet} + + +def create_diffusers_vae_model_from_ldm( + pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=0.18125 +): + # import here to avoid circular imports + from ..models import AutoencoderKL + + image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size) + + vae_config = create_vae_diffusers_config(original_config, image_size=image_size, scaling_factor=scaling_factor) + diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + ctx = init_empty_weights if is_accelerate_available() else nullcontext + + with ctx(): + vae = AutoencoderKL(**vae_config) + + if is_accelerate_available(): + for param_name, param in diffusers_format_vae_checkpoint.items(): + set_module_tensor_to_device(vae, param_name, "cpu", value=param) + else: + vae.load_state_dict(diffusers_format_vae_checkpoint) + + return {"vae": vae} + + +def create_text_encoders_and_tokenizers_from_ldm( + original_config, + checkpoint, + model_type=None, + local_files_only=False, +): + model_type = infer_model_type(original_config, model_type=model_type) + + if model_type == "FrozenOpenCLIPEmbedder": + config_name = "stabilityai/stable-diffusion-2" + config_kwargs = {"subfolder": "text_encoder"} + + try: + text_encoder = create_text_encoder_from_open_clip_checkpoint( + config_name, checkpoint, local_files_only=local_files_only, **config_kwargs + ) + tokenizer = CLIPTokenizer.from_pretrained( + config_name, subfolder="tokenizer", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder in the following path: '{config_name}'." + ) + else: + return {"text_encoder": text_encoder, "tokenizer": tokenizer} + + elif model_type == "FrozenCLIPEmbedder": + try: + config_name = "openai/clip-vit-large-patch14" + text_encoder = create_text_encoder_from_ldm_clip_checkpoint( + config_name, checkpoint, local_files_only=local_files_only + ) + tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only) + + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: '{config_name}'." + ) + else: + return {"text_encoder": text_encoder, "tokenizer": tokenizer} + + elif model_type == "SDXL-Refiner": + config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config_kwargs = {"projection_dim": 1280} + prefix = "conditioner.embedders.0.model." + + try: + tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only) + text_encoder_2 = create_text_encoder_from_open_clip_checkpoint( + config_name, + checkpoint, + prefix=prefix, + has_projection=True, + local_files_only=local_files_only, + **config_kwargs, + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'." + ) + + else: + return { + "text_encoder": None, + "tokenizer": None, + "tokenizer_2": tokenizer_2, + "text_encoder_2": text_encoder_2, + } + + elif model_type == "SDXL": + try: + config_name = "openai/clip-vit-large-patch14" + tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only) + text_encoder = create_text_encoder_from_ldm_clip_checkpoint( + config_name, checkpoint, local_files_only=local_files_only + ) + + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder and tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + + try: + config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config_kwargs = {"projection_dim": 1280} + prefix = "conditioner.embedders.1.model." + tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only) + text_encoder_2 = create_text_encoder_from_open_clip_checkpoint( + config_name, + checkpoint, + prefix=prefix, + has_projection=True, + local_files_only=local_files_only, + **config_kwargs, + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'." + ) + + return { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "tokenizer_2": tokenizer_2, + "text_encoder_2": text_encoder_2, + } + + return + + +def create_scheduler_from_ldm( + pipeline_class_name, + original_config, + checkpoint, + prediction_type=None, + scheduler_type="ddim", + model_type=None, +): + scheduler_config = get_default_scheduler_config() + model_type = infer_model_type(original_config, model_type=model_type) + + global_step = checkpoint["global_step"] if "global_step" in checkpoint else None + + num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", None) or 1000 + scheduler_config["num_train_timesteps"] = num_train_timesteps + + if ( + "parameterization" in original_config["model"]["params"] + and original_config["model"]["params"]["parameterization"] == "v" + ): + if prediction_type is None: + # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` + # as it relies on a brittle global step parameter here + prediction_type = "epsilon" if global_step == 875000 else "v_prediction" + + else: + prediction_type = prediction_type or "epsilon" + + scheduler_config["prediction_type"] = prediction_type + + if model_type in ["SDXL", "SDXL-Refiner"]: + scheduler_type = "euler" + + else: + beta_start = original_config["model"]["params"].get("linear_start", 0.02) + beta_end = original_config["model"]["params"].get("linear_end", 0.085) + scheduler_config["beta_start"] = beta_start + scheduler_config["beta_end"] = beta_end + scheduler_config["beta_schedule"] = "scaled_linear" + scheduler_config["clip_sample"] = False + scheduler_config["set_alpha_to_one"] = False + + if scheduler_type == "pndm": + scheduler_config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(scheduler_config) + + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler_config) + + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler_config) + + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler_config) + + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config) + + elif scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config) + + elif scheduler_type == "ddim": + scheduler = DDIMScheduler.from_config(scheduler_config) + + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + if pipeline_class_name == "StableDiffusionUpscalePipeline": + scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler") + low_res_scheduler = DDPMScheduler.from_pretrained( + "stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler" + ) + + return { + "scheduler": scheduler, + "low_res_scheduler": low_res_scheduler, + } + + return {"scheduler": scheduler} diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 36dbe14c5053..02c94ddbf1de 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -39,19 +39,19 @@ _import_structure["t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformer_2d"] = ["Transformer2DModel"] _import_structure["transformer_temporal"] = ["TransformerTemporalModel"] - _import_structure["unet_1d"] = ["UNet1DModel"] - _import_structure["unet_2d"] = ["UNet2DModel"] - _import_structure["unet_2d_condition"] = ["UNet2DConditionModel"] - _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"] - _import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"] - _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] - _import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"] - _import_structure["uvit_2d"] = ["UVit2DModel"] + _import_structure["unets.unet_1d"] = ["UNet1DModel"] + _import_structure["unets.unet_2d"] = ["UNet2DModel"] + _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] + _import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"] + _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"] + _import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] + _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"] + _import_structure["unets.uvit_2d"] = ["UVit2DModel"] _import_structure["vq_model"] = ["VQModel"] if is_flax_available(): _import_structure["controlnet_flax"] = ["FlaxControlNetModel"] - _import_structure["unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] + _import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] _import_structure["vae_flax"] = ["FlaxAutoencoderKL"] @@ -73,19 +73,22 @@ from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel from .transformer_temporal import TransformerTemporalModel - from .unet_1d import UNet1DModel - from .unet_2d import UNet2DModel - from .unet_2d_condition import UNet2DConditionModel - from .unet_3d_condition import UNet3DConditionModel - from .unet_kandinsky3 import Kandinsky3UNet - from .unet_motion_model import MotionAdapter, UNetMotionModel - from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel - from .uvit_2d import UVit2DModel + from .unets import ( + Kandinsky3UNet, + MotionAdapter, + UNet1DModel, + UNet2DConditionModel, + UNet2DModel, + UNet3DConditionModel, + UNetMotionModel, + UNetSpatioTemporalConditionModel, + UVit2DModel, + ) from .vq_model import VQModel if is_flax_available(): from .controlnet_flax import FlaxControlNetModel - from .unet_2d_condition_flax import FlaxUNet2DConditionModel + from .unets import FlaxUNet2DConditionModel from .vae_flax import FlaxAutoencoderKL else: diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 10a3ae58de9f..a0b23b896d13 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -157,7 +157,7 @@ def disable_slicing(self): self.use_slicing = False @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -181,7 +181,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -216,7 +216,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. @@ -448,7 +448,7 @@ def forward( return DecoderOutput(sample=dec) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections def fuse_qkv_projections(self): """ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, @@ -472,7 +472,7 @@ def fuse_qkv_projections(self): if isinstance(module, Attention): module.fuse_projections(fuse=True) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): """Disables the fused QKV projection if enabled. diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py index dbafb4571d4a..68d5a31e43c7 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py @@ -17,13 +17,12 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalVAEMixin from ...utils import is_torch_version from ...utils.accelerate_utils import apply_forward_hook from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from ..unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder +from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder @@ -162,7 +161,7 @@ def custom_forward(*inputs): return sample -class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin): +class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. @@ -242,7 +241,7 @@ def _set_gradient_checkpointing(self, module, value=False): module.gradient_checkpointing = value @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -266,7 +265,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. diff --git a/src/diffusers/models/autoencoders/consistency_decoder_vae.py b/src/diffusers/models/autoencoders/consistency_decoder_vae.py index ca670fec4b28..0013521f4cbb 100644 --- a/src/diffusers/models/autoencoders/consistency_decoder_vae.py +++ b/src/diffusers/models/autoencoders/consistency_decoder_vae.py @@ -31,7 +31,7 @@ AttnProcessor, ) from ..modeling_utils import ModelMixin -from ..unet_2d import UNet2DModel +from ..unets.unet_2d import UNet2DModel from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder @@ -187,7 +187,7 @@ def disable_slicing(self): self.use_slicing = False @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -211,7 +211,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -246,7 +246,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py index 3f1643bc50ef..3c56f15117ba 100644 --- a/src/diffusers/models/autoencoders/vae.py +++ b/src/diffusers/models/autoencoders/vae.py @@ -22,7 +22,7 @@ from ...utils.torch_utils import randn_tensor from ..activations import get_activation from ..attention_processor import SpatialNorm -from ..unet_2d_blocks import ( +from ..unets.unet_2d_blocks import ( AutoencoderTinyBlock, UNetMidBlock2D, get_down_block, diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 1102f4f9d36d..60ddc999e96f 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -19,7 +19,7 @@ from torch.nn import functional as F from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import FromOriginalControlnetMixin +from ..loaders import FromOriginalControlNetMixin from ..utils import BaseOutput, logging from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -30,8 +30,14 @@ ) from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin -from .unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2D, UNetMidBlock2DCrossAttn, get_down_block -from .unet_2d_condition import UNet2DConditionModel +from .unets.unet_2d_blocks import ( + CrossAttnDownBlock2D, + DownBlock2D, + UNetMidBlock2D, + UNetMidBlock2DCrossAttn, + get_down_block, +) +from .unets.unet_2d_condition import UNet2DConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -102,7 +108,7 @@ def forward(self, conditioning): return embedding -class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin): """ A ControlNet model. @@ -509,7 +515,7 @@ def from_unet( return controlnet @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -533,7 +539,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -568,7 +574,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. @@ -584,7 +590,7 @@ def set_default_attn_processor(self): self.set_attn_processor(processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: r""" Enable sliced attention computation. diff --git a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnet_flax.py index 34aaac549f8c..1a140cfb94d3 100644 --- a/src/diffusers/models/controlnet_flax.py +++ b/src/diffusers/models/controlnet_flax.py @@ -23,7 +23,7 @@ from ..utils import BaseOutput from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps from .modeling_flax_utils import FlaxModelMixin -from .unet_2d_blocks_flax import ( +from .unets.unet_2d_blocks_flax import ( FlaxCrossAttnDownBlock2D, FlaxDownBlock2D, FlaxUNetMidBlock2DCrossAttn, @@ -329,14 +329,14 @@ def __call__( controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a + Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a plain tuple. train (`bool`, *optional*, defaults to `False`): Use deterministic functions and disable dropout when not training. Returns: - [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a + [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: + [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ channel_order = self.controlnet_conditioning_channel_order diff --git a/src/diffusers/models/dual_transformer_2d.py b/src/diffusers/models/dual_transformer_2d.py index 02568298409c..21b135c2eb86 100644 --- a/src/diffusers/models/dual_transformer_2d.py +++ b/src/diffusers/models/dual_transformer_2d.py @@ -120,7 +120,7 @@ def forward( `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 445c3ca71caf..90a700d9443f 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -32,6 +32,7 @@ from ..utils import ( CONFIG_NAME, FLAX_WEIGHTS_NAME, + SAFETENSORS_FILE_EXTENSION, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, _add_variant, @@ -102,10 +103,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[ Reads a checkpoint file, returning properly formatted errors if they arise. """ try: - if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant): - return torch.load(checkpoint_file, map_location="cpu") - else: + file_extension = os.path.basename(checkpoint_file).split(".")[-1] + if file_extension == SAFETENSORS_FILE_EXTENSION: return safetensors.torch.load_file(checkpoint_file, device="cpu") + else: + return torch.load(checkpoint_file, map_location="cpu") except Exception as e: try: with open(checkpoint_file) as f: diff --git a/src/diffusers/models/prior_transformer.py b/src/diffusers/models/prior_transformer.py index 6b52ea344d41..081d66991faf 100644 --- a/src/diffusers/models/prior_transformer.py +++ b/src/diffusers/models/prior_transformer.py @@ -167,7 +167,7 @@ def __init__( self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim)) @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -191,7 +191,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -226,7 +226,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 128395cc161a..3b219b4f0b37 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -286,7 +286,7 @@ def forward( If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format above. This bias will be added to the cross-attention scores. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index 26e899a9b908..a18671776baf 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -149,7 +149,7 @@ def forward( `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index 5bb5b0818245..06ff51b17d0d 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -12,244 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass -from typing import Optional, Tuple, Union +from ..utils import deprecate +from .unets.unet_1d import UNet1DModel, UNet1DOutput -import torch -import torch.nn as nn -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput -from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin -from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block +class UNet1DOutput(UNet1DOutput): + deprecation_message = "Importing `UNet1DOutput` from `diffusers.models.unet_1d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d import UNet1DOutput`, instead." + deprecate("UNet1DOutput", "0.29", deprecation_message) -@dataclass -class UNet1DOutput(BaseOutput): - """ - The output of [`UNet1DModel`]. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`): - The hidden states output from the last layer of the model. - """ - - sample: torch.FloatTensor - - -class UNet1DModel(ModelMixin, ConfigMixin): - r""" - A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented - for all models (such as downloading or saving). - - Parameters: - sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime. - in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. - extra_in_channels (`int`, *optional*, defaults to 0): - Number of additional channels to be added to the input of the first down block. Useful for cases where the - input data has more channels than what the model was initially designed for. - time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. - freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding. - flip_sin_to_cos (`bool`, *optional*, defaults to `False`): - Whether to flip sin to cos for Fourier time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`): - Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`): - Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`): - Tuple of block output channels. - mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet. - out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet. - act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks. - norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization. - layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block. - downsample_each_block (`int`, *optional*, defaults to `False`): - Experimental feature for using a UNet without upsampling. - """ - - @register_to_config - def __init__( - self, - sample_size: int = 65536, - sample_rate: Optional[int] = None, - in_channels: int = 2, - out_channels: int = 2, - extra_in_channels: int = 0, - time_embedding_type: str = "fourier", - flip_sin_to_cos: bool = True, - use_timestep_embedding: bool = False, - freq_shift: float = 0.0, - down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), - up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), - mid_block_type: Tuple[str] = "UNetMidBlock1D", - out_block_type: str = None, - block_out_channels: Tuple[int] = (32, 32, 64), - act_fn: str = None, - norm_num_groups: int = 8, - layers_per_block: int = 1, - downsample_each_block: bool = False, - ): - super().__init__() - self.sample_size = sample_size - - # time - if time_embedding_type == "fourier": - self.time_proj = GaussianFourierProjection( - embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos - ) - timestep_input_dim = 2 * block_out_channels[0] - elif time_embedding_type == "positional": - self.time_proj = Timesteps( - block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift - ) - timestep_input_dim = block_out_channels[0] - - if use_timestep_embedding: - time_embed_dim = block_out_channels[0] * 4 - self.time_mlp = TimestepEmbedding( - in_channels=timestep_input_dim, - time_embed_dim=time_embed_dim, - act_fn=act_fn, - out_dim=block_out_channels[0], - ) - - self.down_blocks = nn.ModuleList([]) - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - self.out_block = None - - # down - output_channel = in_channels - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - - if i == 0: - input_channel += extra_in_channels - - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=block_out_channels[0], - add_downsample=not is_final_block or downsample_each_block, - ) - self.down_blocks.append(down_block) - - # mid - self.mid_block = get_mid_block( - mid_block_type, - in_channels=block_out_channels[-1], - mid_channels=block_out_channels[-1], - out_channels=block_out_channels[-1], - embed_dim=block_out_channels[0], - num_layers=layers_per_block, - add_downsample=downsample_each_block, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - if out_block_type is None: - final_upsample_channels = out_channels - else: - final_upsample_channels = block_out_channels[0] - - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = ( - reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels - ) - - is_final_block = i == len(block_out_channels) - 1 - - up_block = get_up_block( - up_block_type, - num_layers=layers_per_block, - in_channels=prev_output_channel, - out_channels=output_channel, - temb_channels=block_out_channels[0], - add_upsample=not is_final_block, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) - self.out_block = get_out_block( - out_block_type=out_block_type, - num_groups_out=num_groups_out, - embed_dim=block_out_channels[0], - out_channels=out_channels, - act_fn=act_fn, - fc_dim=block_out_channels[-1] // 4, - ) - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - return_dict: bool = True, - ) -> Union[UNet1DOutput, Tuple]: - r""" - The [`UNet1DModel`] forward method. - - Args: - sample (`torch.FloatTensor`): - The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`. - timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple. - - Returns: - [`~models.unet_1d.UNet1DOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is - returned where the first element is the sample tensor. - """ - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - timestep_embed = self.time_proj(timesteps) - if self.config.use_timestep_embedding: - timestep_embed = self.time_mlp(timestep_embed) - else: - timestep_embed = timestep_embed[..., None] - timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) - timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:])) - - # 2. down - down_block_res_samples = () - for downsample_block in self.down_blocks: - sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed) - down_block_res_samples += res_samples - - # 3. mid - if self.mid_block: - sample = self.mid_block(sample, timestep_embed) - - # 4. up - for i, upsample_block in enumerate(self.up_blocks): - res_samples = down_block_res_samples[-1:] - down_block_res_samples = down_block_res_samples[:-1] - sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed) - - # 5. post-process - if self.out_block: - sample = self.out_block(sample, timestep_embed) - - if not return_dict: - return (sample,) - - return UNet1DOutput(sample=sample) +class UNet1DModel(UNet1DModel): + deprecation_message = "Importing `UNet1DModel` from `diffusers.models.unet_1d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d import UNet1DModel`, instead." + deprecate("UNet1DModel", "0.29", deprecation_message) diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index 74a2f1681ead..772d7f6cfbe4 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -11,616 +11,112 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import Optional, Tuple, Union - -import torch -import torch.nn.functional as F -from torch import nn - -from .activations import get_activation -from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims - - -class DownResnetBlock1D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: Optional[int] = None, - num_layers: int = 1, - conv_shortcut: bool = False, - temb_channels: int = 32, - groups: int = 32, - groups_out: Optional[int] = None, - non_linearity: Optional[str] = None, - time_embedding_norm: str = "default", - output_scale_factor: float = 1.0, - add_downsample: bool = True, - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - self.time_embedding_norm = time_embedding_norm - self.add_downsample = add_downsample - self.output_scale_factor = output_scale_factor - - if groups_out is None: - groups_out = groups - - # there will always be at least one resnet - resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)] - - for _ in range(num_layers): - resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) - - self.resnets = nn.ModuleList(resnets) - - if non_linearity is None: - self.nonlinearity = None - else: - self.nonlinearity = get_activation(non_linearity) - - self.downsample = None - if add_downsample: - self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) - - def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: - output_states = () - - hidden_states = self.resnets[0](hidden_states, temb) - for resnet in self.resnets[1:]: - hidden_states = resnet(hidden_states, temb) - - output_states += (hidden_states,) - - if self.nonlinearity is not None: - hidden_states = self.nonlinearity(hidden_states) - - if self.downsample is not None: - hidden_states = self.downsample(hidden_states) - - return hidden_states, output_states - - -class UpResnetBlock1D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: Optional[int] = None, - num_layers: int = 1, - temb_channels: int = 32, - groups: int = 32, - groups_out: Optional[int] = None, - non_linearity: Optional[str] = None, - time_embedding_norm: str = "default", - output_scale_factor: float = 1.0, - add_upsample: bool = True, - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.time_embedding_norm = time_embedding_norm - self.add_upsample = add_upsample - self.output_scale_factor = output_scale_factor - - if groups_out is None: - groups_out = groups - - # there will always be at least one resnet - resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)] - - for _ in range(num_layers): - resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) - - self.resnets = nn.ModuleList(resnets) - - if non_linearity is None: - self.nonlinearity = None - else: - self.nonlinearity = get_activation(non_linearity) - - self.upsample = None - if add_upsample: - self.upsample = Upsample1D(out_channels, use_conv_transpose=True) - - def forward( - self, - hidden_states: torch.FloatTensor, - res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None, - temb: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: - if res_hidden_states_tuple is not None: - res_hidden_states = res_hidden_states_tuple[-1] - hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) - - hidden_states = self.resnets[0](hidden_states, temb) - for resnet in self.resnets[1:]: - hidden_states = resnet(hidden_states, temb) - - if self.nonlinearity is not None: - hidden_states = self.nonlinearity(hidden_states) - - if self.upsample is not None: - hidden_states = self.upsample(hidden_states) - - return hidden_states - - -class ValueFunctionMidBlock1D(nn.Module): - def __init__(self, in_channels: int, out_channels: int, embed_dim: int): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.embed_dim = embed_dim - - self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim) - self.down1 = Downsample1D(out_channels // 2, use_conv=True) - self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim) - self.down2 = Downsample1D(out_channels // 4, use_conv=True) - - def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: - x = self.res1(x, temb) - x = self.down1(x) - x = self.res2(x, temb) - x = self.down2(x) - return x - - -class MidResTemporalBlock1D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - embed_dim: int, - num_layers: int = 1, - add_downsample: bool = False, - add_upsample: bool = False, - non_linearity: Optional[str] = None, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.add_downsample = add_downsample - - # there will always be at least one resnet - resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)] - - for _ in range(num_layers): - resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim)) - - self.resnets = nn.ModuleList(resnets) - - if non_linearity is None: - self.nonlinearity = None - else: - self.nonlinearity = get_activation(non_linearity) - - self.upsample = None - if add_upsample: - self.upsample = Downsample1D(out_channels, use_conv=True) - - self.downsample = None - if add_downsample: - self.downsample = Downsample1D(out_channels, use_conv=True) - - if self.upsample and self.downsample: - raise ValueError("Block cannot downsample and upsample") - - def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor: - hidden_states = self.resnets[0](hidden_states, temb) - for resnet in self.resnets[1:]: - hidden_states = resnet(hidden_states, temb) - - if self.upsample: - hidden_states = self.upsample(hidden_states) - if self.downsample: - self.downsample = self.downsample(hidden_states) - - return hidden_states - - -class OutConv1DBlock(nn.Module): - def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str): - super().__init__() - self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) - self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) - self.final_conv1d_act = get_activation(act_fn) - self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) - - def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: - hidden_states = self.final_conv1d_1(hidden_states) - hidden_states = rearrange_dims(hidden_states) - hidden_states = self.final_conv1d_gn(hidden_states) - hidden_states = rearrange_dims(hidden_states) - hidden_states = self.final_conv1d_act(hidden_states) - hidden_states = self.final_conv1d_2(hidden_states) - return hidden_states - - -class OutValueFunctionBlock(nn.Module): - def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"): - super().__init__() - self.final_block = nn.ModuleList( - [ - nn.Linear(fc_dim + embed_dim, fc_dim // 2), - get_activation(act_fn), - nn.Linear(fc_dim // 2, 1), - ] - ) - - def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor: - hidden_states = hidden_states.view(hidden_states.shape[0], -1) - hidden_states = torch.cat((hidden_states, temb), dim=-1) - for layer in self.final_block: - hidden_states = layer(hidden_states) - - return hidden_states - - -_kernels = { - "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], - "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875], - "lanczos3": [ - 0.003689131001010537, - 0.015056144446134567, - -0.03399861603975296, - -0.066637322306633, - 0.13550527393817902, - 0.44638532400131226, - 0.44638532400131226, - 0.13550527393817902, - -0.066637322306633, - -0.03399861603975296, - 0.015056144446134567, - 0.003689131001010537, - ], -} - - -class Downsample1d(nn.Module): - def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"): - super().__init__() - self.pad_mode = pad_mode - kernel_1d = torch.tensor(_kernels[kernel]) - self.pad = kernel_1d.shape[0] // 2 - 1 - self.register_buffer("kernel", kernel_1d) - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) - weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) - indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) - kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) - weight[indices, indices] = kernel - return F.conv1d(hidden_states, weight, stride=2) - - -class Upsample1d(nn.Module): - def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"): - super().__init__() - self.pad_mode = pad_mode - kernel_1d = torch.tensor(_kernels[kernel]) * 2 - self.pad = kernel_1d.shape[0] // 2 - 1 - self.register_buffer("kernel", kernel_1d) - - def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: - hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) - weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) - indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) - kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) - weight[indices, indices] = kernel - return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1) - - -class SelfAttention1d(nn.Module): - def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0): - super().__init__() - self.channels = in_channels - self.group_norm = nn.GroupNorm(1, num_channels=in_channels) - self.num_heads = n_head - - self.query = nn.Linear(self.channels, self.channels) - self.key = nn.Linear(self.channels, self.channels) - self.value = nn.Linear(self.channels, self.channels) - - self.proj_attn = nn.Linear(self.channels, self.channels, bias=True) - - self.dropout = nn.Dropout(dropout_rate, inplace=True) - - def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: - new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) - # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) - new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) - return new_projection - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - residual = hidden_states - batch, channel_dim, seq = hidden_states.shape - - hidden_states = self.group_norm(hidden_states) - hidden_states = hidden_states.transpose(1, 2) - - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - query_states = self.transpose_for_scores(query_proj) - key_states = self.transpose_for_scores(key_proj) - value_states = self.transpose_for_scores(value_proj) - - scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1])) - - attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) - attention_probs = torch.softmax(attention_scores, dim=-1) - - # compute attention output - hidden_states = torch.matmul(attention_probs, value_states) - - hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() - new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) - hidden_states = hidden_states.view(new_hidden_states_shape) - - # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) - hidden_states = hidden_states.transpose(1, 2) - hidden_states = self.dropout(hidden_states) - - output = hidden_states + residual - - return output - - -class ResConvBlock(nn.Module): - def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False): - super().__init__() - self.is_last = is_last - self.has_conv_skip = in_channels != out_channels - - if self.has_conv_skip: - self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False) - - self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2) - self.group_norm_1 = nn.GroupNorm(1, mid_channels) - self.gelu_1 = nn.GELU() - self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2) - - if not self.is_last: - self.group_norm_2 = nn.GroupNorm(1, out_channels) - self.gelu_2 = nn.GELU() - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states - hidden_states = self.conv_1(hidden_states) - hidden_states = self.group_norm_1(hidden_states) - hidden_states = self.gelu_1(hidden_states) - hidden_states = self.conv_2(hidden_states) +from ..utils import deprecate +from .unets.unet_1d_blocks import ( + AttnDownBlock1D, + AttnUpBlock1D, + DownBlock1D, + DownBlock1DNoSkip, + DownResnetBlock1D, + Downsample1d, + MidResTemporalBlock1D, + OutConv1DBlock, + OutValueFunctionBlock, + ResConvBlock, + SelfAttention1d, + UNetMidBlock1D, + UpBlock1D, + UpBlock1DNoSkip, + UpResnetBlock1D, + Upsample1d, + ValueFunctionMidBlock1D, +) - if not self.is_last: - hidden_states = self.group_norm_2(hidden_states) - hidden_states = self.gelu_2(hidden_states) - output = hidden_states + residual - return output +class DownResnetBlock1D(DownResnetBlock1D): + deprecation_message = "Importing `DownResnetBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownResnetBlock1D`, instead." + deprecate("DownResnetBlock1D", "0.29", deprecation_message) -class UNetMidBlock1D(nn.Module): - def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None): - super().__init__() +class UpResnetBlock1D(UpResnetBlock1D): + deprecation_message = "Importing `UpResnetBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpResnetBlock1D`, instead." + deprecate("UpResnetBlock1D", "0.29", deprecation_message) - out_channels = in_channels if out_channels is None else out_channels - # there is always at least one resnet - self.down = Downsample1d("cubic") - resnets = [ - ResConvBlock(in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels), - ] - attentions = [ - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(out_channels, out_channels // 32), - ] - self.up = Upsample1d(kernel="cubic") +class ValueFunctionMidBlock1D(ValueFunctionMidBlock1D): + deprecation_message = "Importing `ValueFunctionMidBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import ValueFunctionMidBlock1D`, instead." + deprecate("ValueFunctionMidBlock1D", "0.29", deprecation_message) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: - hidden_states = self.down(hidden_states) - for attn, resnet in zip(self.attentions, self.resnets): - hidden_states = resnet(hidden_states) - hidden_states = attn(hidden_states) +class OutConv1DBlock(OutConv1DBlock): + deprecation_message = "Importing `OutConv1DBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import OutConv1DBlock`, instead." + deprecate("OutConv1DBlock", "0.29", deprecation_message) - hidden_states = self.up(hidden_states) - return hidden_states +class OutValueFunctionBlock(OutValueFunctionBlock): + deprecation_message = "Importing `OutValueFunctionBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import OutValueFunctionBlock`, instead." + deprecate("OutValueFunctionBlock", "0.29", deprecation_message) -class AttnDownBlock1D(nn.Module): - def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): - super().__init__() - mid_channels = out_channels if mid_channels is None else mid_channels - - self.down = Downsample1d("cubic") - resnets = [ - ResConvBlock(in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels), - ] - attentions = [ - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(out_channels, out_channels // 32), - ] +class Downsample1d(Downsample1d): + deprecation_message = "Importing `Downsample1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import Downsample1d`, instead." + deprecate("Downsample1d", "0.29", deprecation_message) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: - hidden_states = self.down(hidden_states) +class Upsample1d(Upsample1d): + deprecation_message = "Importing `Upsample1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import Upsample1d`, instead." + deprecate("Upsample1d", "0.29", deprecation_message) - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states) - hidden_states = attn(hidden_states) - - return hidden_states, (hidden_states,) - - -class DownBlock1D(nn.Module): - def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): - super().__init__() - mid_channels = out_channels if mid_channels is None else mid_channels - - self.down = Downsample1d("cubic") - resnets = [ - ResConvBlock(in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels), - ] - - self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: - hidden_states = self.down(hidden_states) - - for resnet in self.resnets: - hidden_states = resnet(hidden_states) - - return hidden_states, (hidden_states,) - - -class DownBlock1DNoSkip(nn.Module): - def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): - super().__init__() - mid_channels = out_channels if mid_channels is None else mid_channels - - resnets = [ - ResConvBlock(in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels), - ] - - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: - hidden_states = torch.cat([hidden_states, temb], dim=1) - for resnet in self.resnets: - hidden_states = resnet(hidden_states) - - return hidden_states, (hidden_states,) - - -class AttnUpBlock1D(nn.Module): - def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): - super().__init__() - mid_channels = out_channels if mid_channels is None else mid_channels +class SelfAttention1d(SelfAttention1d): + deprecation_message = "Importing `SelfAttention1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import SelfAttention1d`, instead." + deprecate("SelfAttention1d", "0.29", deprecation_message) - resnets = [ - ResConvBlock(2 * in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels), - ] - attentions = [ - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(out_channels, out_channels // 32), - ] - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.up = Upsample1d(kernel="cubic") +class ResConvBlock(ResConvBlock): + deprecation_message = "Importing `ResConvBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import ResConvBlock`, instead." + deprecate("ResConvBlock", "0.29", deprecation_message) - def forward( - self, - hidden_states: torch.FloatTensor, - res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], - temb: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: - res_hidden_states = res_hidden_states_tuple[-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states) - hidden_states = attn(hidden_states) +class UNetMidBlock1D(UNetMidBlock1D): + deprecation_message = "Importing `UNetMidBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UNetMidBlock1D`, instead." + deprecate("UNetMidBlock1D", "0.29", deprecation_message) - hidden_states = self.up(hidden_states) - return hidden_states +class AttnDownBlock1D(AttnDownBlock1D): + deprecation_message = "Importing `AttnDownBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import AttnDownBlock1D`, instead." + deprecate("AttnDownBlock1D", "0.29", deprecation_message) -class UpBlock1D(nn.Module): - def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): - super().__init__() - mid_channels = in_channels if mid_channels is None else mid_channels +class DownBlock1D(DownBlock1D): + deprecation_message = "Importing `DownBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownBlock1D`, instead." + deprecate("DownBlock1D", "0.29", deprecation_message) - resnets = [ - ResConvBlock(2 * in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels), - ] - self.resnets = nn.ModuleList(resnets) - self.up = Upsample1d(kernel="cubic") +class DownBlock1DNoSkip(DownBlock1DNoSkip): + deprecation_message = "Importing `DownBlock1DNoSkip` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownBlock1DNoSkip`, instead." + deprecate("DownBlock1DNoSkip", "0.29", deprecation_message) - def forward( - self, - hidden_states: torch.FloatTensor, - res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], - temb: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: - res_hidden_states = res_hidden_states_tuple[-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - for resnet in self.resnets: - hidden_states = resnet(hidden_states) +class AttnUpBlock1D(AttnUpBlock1D): + deprecation_message = "Importing `AttnUpBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import AttnUpBlock1D`, instead." + deprecate("AttnUpBlock1D", "0.29", deprecation_message) - hidden_states = self.up(hidden_states) - return hidden_states +class UpBlock1D(UpBlock1D): + deprecation_message = "Importing `UpBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpBlock1D`, instead." + deprecate("UpBlock1D", "0.29", deprecation_message) -class UpBlock1DNoSkip(nn.Module): - def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): - super().__init__() - mid_channels = in_channels if mid_channels is None else mid_channels +class UpBlock1DNoSkip(UpBlock1DNoSkip): + deprecation_message = "Importing `UpBlock1DNoSkip` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpBlock1DNoSkip`, instead." + deprecate("UpBlock1DNoSkip", "0.29", deprecation_message) - resnets = [ - ResConvBlock(2 * in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True), - ] - self.resnets = nn.ModuleList(resnets) - - def forward( - self, - hidden_states: torch.FloatTensor, - res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], - temb: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: - res_hidden_states = res_hidden_states_tuple[-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - for resnet in self.resnets: - hidden_states = resnet(hidden_states) - - return hidden_states - - -DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip] -MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D] -OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock] -UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip] +class MidResTemporalBlock1D(MidResTemporalBlock1D): + deprecation_message = "Importing `MidResTemporalBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import MidResTemporalBlock1D`, instead." + deprecate("MidResTemporalBlock1D", "0.29", deprecation_message) def get_down_block( @@ -630,42 +126,38 @@ def get_down_block( out_channels: int, temb_channels: int, add_downsample: bool, -) -> DownBlockType: - if down_block_type == "DownResnetBlock1D": - return DownResnetBlock1D( - in_channels=in_channels, - num_layers=num_layers, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - ) - elif down_block_type == "DownBlock1D": - return DownBlock1D(out_channels=out_channels, in_channels=in_channels) - elif down_block_type == "AttnDownBlock1D": - return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels) - elif down_block_type == "DownBlock1DNoSkip": - return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels) - raise ValueError(f"{down_block_type} does not exist.") +): + deprecation_message = "Importing `get_down_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_down_block`, instead." + deprecate("get_down_block", "0.29", deprecation_message) + + from .unets.unet_1d_blocks import get_down_block + + return get_down_block( + down_block_type=down_block_type, + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + ) def get_up_block( up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool -) -> UpBlockType: - if up_block_type == "UpResnetBlock1D": - return UpResnetBlock1D( - in_channels=in_channels, - num_layers=num_layers, - out_channels=out_channels, - temb_channels=temb_channels, - add_upsample=add_upsample, - ) - elif up_block_type == "UpBlock1D": - return UpBlock1D(in_channels=in_channels, out_channels=out_channels) - elif up_block_type == "AttnUpBlock1D": - return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels) - elif up_block_type == "UpBlock1DNoSkip": - return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels) - raise ValueError(f"{up_block_type} does not exist.") +): + deprecation_message = "Importing `get_up_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_up_block`, instead." + deprecate("get_up_block", "0.29", deprecation_message) + + from .unets.unet_1d_blocks import get_up_block + + return get_up_block( + up_block_type=up_block_type, + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + ) def get_mid_block( @@ -676,27 +168,36 @@ def get_mid_block( out_channels: int, embed_dim: int, add_downsample: bool, -) -> MidBlockType: - if mid_block_type == "MidResTemporalBlock1D": - return MidResTemporalBlock1D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - embed_dim=embed_dim, - add_downsample=add_downsample, - ) - elif mid_block_type == "ValueFunctionMidBlock1D": - return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim) - elif mid_block_type == "UNetMidBlock1D": - return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels) - raise ValueError(f"{mid_block_type} does not exist.") +): + deprecation_message = "Importing `get_mid_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_mid_block`, instead." + deprecate("get_mid_block", "0.29", deprecation_message) + + from .unets.unet_1d_blocks import get_mid_block + + return get_mid_block( + mid_block_type=mid_block_type, + num_layers=num_layers, + in_channels=in_channels, + mid_channels=mid_channels, + out_channels=out_channels, + embed_dim=embed_dim, + add_downsample=add_downsample, + ) def get_out_block( *, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int -) -> Optional[OutBlockType]: - if out_block_type == "OutConv1DBlock": - return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) - elif out_block_type == "ValueFunction": - return OutValueFunctionBlock(fc_dim, embed_dim, act_fn) - return None +): + deprecation_message = "Importing `get_out_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_out_block`, instead." + deprecate("get_out_block", "0.29", deprecation_message) + + from .unets.unet_1d_blocks import get_out_block + + return get_out_block( + out_block_type=out_block_type, + num_groups_out=num_groups_out, + embed_dim=embed_dim, + out_channels=out_channels, + act_fn=act_fn, + fc_dim=fc_dim, + ) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 0531d8aae783..006bf4721856 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -11,336 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass -from typing import Optional, Tuple, Union -import torch -import torch.nn as nn -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput -from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin -from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block +from ..utils import deprecate +from .unets.unet_2d import UNet2DModel, UNet2DOutput -@dataclass -class UNet2DOutput(BaseOutput): - """ - The output of [`UNet2DModel`]. +class UNet2DOutput(UNet2DOutput): + deprecation_message = "Importing `UNet2DOutput` from `diffusers.models.unet_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d import UNet2DOutput`, instead." + deprecate("UNet2DOutput", "0.29", deprecation_message) - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - The hidden states output from the last layer of the model. - """ - sample: torch.FloatTensor - - -class UNet2DModel(ModelMixin, ConfigMixin): - r""" - A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented - for all models (such as downloading or saving). - - Parameters: - sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): - Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) - - 1)`. - in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. - center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. - freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding. - flip_sin_to_cos (`bool`, *optional*, defaults to `True`): - Whether to flip sin to cos for Fourier time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): - Tuple of downsample block types. - mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`): - Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`. - up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): - Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`): - Tuple of block output channels. - layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. - mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. - downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. - downsample_type (`str`, *optional*, defaults to `conv`): - The downsample type for downsampling layers. Choose between "conv" and "resnet" - upsample_type (`str`, *optional*, defaults to `conv`): - The upsample type for upsampling layers. Choose between "conv" and "resnet" - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. - norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization. - attn_norm_num_groups (`int`, *optional*, defaults to `None`): - If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the - given number of groups. If left as `None`, the group norm layer will only be created if - `resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups. - norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization. - resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config - for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. - class_embed_type (`str`, *optional*, defaults to `None`): - The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, - `"timestep"`, or `"identity"`. - num_class_embeds (`int`, *optional*, defaults to `None`): - Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class - conditioning with `class_embed_type` equal to `None`. - """ - - @register_to_config - def __init__( - self, - sample_size: Optional[Union[int, Tuple[int, int]]] = None, - in_channels: int = 3, - out_channels: int = 3, - center_input_sample: bool = False, - time_embedding_type: str = "positional", - freq_shift: int = 0, - flip_sin_to_cos: bool = True, - down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), - up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), - block_out_channels: Tuple[int] = (224, 448, 672, 896), - layers_per_block: int = 2, - mid_block_scale_factor: float = 1, - downsample_padding: int = 1, - downsample_type: str = "conv", - upsample_type: str = "conv", - dropout: float = 0.0, - act_fn: str = "silu", - attention_head_dim: Optional[int] = 8, - norm_num_groups: int = 32, - attn_norm_num_groups: Optional[int] = None, - norm_eps: float = 1e-5, - resnet_time_scale_shift: str = "default", - add_attention: bool = True, - class_embed_type: Optional[str] = None, - num_class_embeds: Optional[int] = None, - num_train_timesteps: Optional[int] = None, - ): - super().__init__() - - self.sample_size = sample_size - time_embed_dim = block_out_channels[0] * 4 - - # Check inputs - if len(down_block_types) != len(up_block_types): - raise ValueError( - f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." - ) - - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - # input - self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) - - # time - if time_embedding_type == "fourier": - self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16) - timestep_input_dim = 2 * block_out_channels[0] - elif time_embedding_type == "positional": - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - elif time_embedding_type == "learned": - self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0]) - timestep_input_dim = block_out_channels[0] - - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - - # class embedding - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - else: - self.class_embedding = None - - self.down_blocks = nn.ModuleList([]) - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - downsample_type=downsample_type, - dropout=dropout, - ) - self.down_blocks.append(down_block) - - # mid - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - dropout=dropout, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1], - resnet_groups=norm_num_groups, - attn_groups=attn_norm_num_groups, - add_attention=add_attention, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = get_up_block( - up_block_type, - num_layers=layers_per_block + 1, - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, - add_upsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, - resnet_time_scale_shift=resnet_time_scale_shift, - upsample_type=upsample_type, - dropout=dropout, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - class_labels: Optional[torch.Tensor] = None, - return_dict: bool = True, - ) -> Union[UNet2DOutput, Tuple]: - r""" - The [`UNet2DModel`] forward method. - - Args: - sample (`torch.FloatTensor`): - The noisy input tensor with the following shape `(batch, channel, height, width)`. - timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. - class_labels (`torch.FloatTensor`, *optional*, defaults to `None`): - Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. - - Returns: - [`~models.unet_2d.UNet2DOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is - returned where the first element is the sample tensor. - """ - # 0. center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) - - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) - emb = self.time_embedding(t_emb) - - if self.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when doing class conditioning") - - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) - emb = emb + class_emb - elif self.class_embedding is None and class_labels is not None: - raise ValueError("class_embedding needs to be initialized in order to use class conditioning") - - # 2. pre-process - skip_sample = sample - sample = self.conv_in(sample) - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "skip_conv"): - sample, res_samples, skip_sample = downsample_block( - hidden_states=sample, temb=emb, skip_sample=skip_sample - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - # 4. mid - sample = self.mid_block(sample, emb) - - # 5. up - skip_sample = None - for upsample_block in self.up_blocks: - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - - if hasattr(upsample_block, "skip_conv"): - sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample) - else: - sample = upsample_block(sample, res_samples, emb) - - # 6. post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if skip_sample is not None: - sample += skip_sample - - if self.config.time_embedding_type == "fourier": - timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:])))) - sample = sample / timesteps - - if not return_dict: - return (sample,) - - return UNet2DOutput(sample=sample) +class UNet2DModel(UNet2DModel): + deprecation_message = "Importing `UNet2DModel` from `diffusers.models.unet_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d import UNet2DModel`, instead." + deprecate("UNet2DModel", "0.29", deprecation_message) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 553f6aaa990c..497eabfc607b 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -11,33 +11,38 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn - -from ..utils import is_torch_version, logging -from ..utils.torch_utils import apply_freeu -from .activations import get_activation -from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 -from .dual_transformer_2d import DualTransformer2DModel -from .normalization import AdaGroupNorm -from .resnet import ( - Downsample2D, - FirDownsample2D, - FirUpsample2D, - KDownsample2D, - KUpsample2D, - ResnetBlock2D, - ResnetBlockCondNorm2D, - Upsample2D, -) -from .transformer_2d import Transformer2DModel - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name +from typing import Optional + +from ..utils import deprecate +from .unets.unet_2d_blocks import ( + AttnDownBlock2D, + AttnDownEncoderBlock2D, + AttnSkipDownBlock2D, + AttnSkipUpBlock2D, + AttnUpBlock2D, + AttnUpDecoderBlock2D, + AutoencoderTinyBlock, + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + DownBlock2D, + KAttentionBlock, + KCrossAttnDownBlock2D, + KCrossAttnUpBlock2D, + KDownBlock2D, + KUpBlock2D, + ResnetDownsampleBlock2D, + ResnetUpsampleBlock2D, + SimpleCrossAttnDownBlock2D, + SimpleCrossAttnUpBlock2D, + SkipDownBlock2D, + SkipUpBlock2D, + UNetMidBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + UpBlock2D, + UpDecoderBlock2D, +) def get_down_block( @@ -67,186 +72,38 @@ def get_down_block( downsample_type: Optional[str] = None, dropout: float = 0.0, ): - # If attn head dim is not defined, we default it to the number of heads - if attention_head_dim is None: - logger.warn( - f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." - ) - attention_head_dim = num_attention_heads - - down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type - if down_block_type == "DownBlock2D": - return DownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - dropout=dropout, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "ResnetDownsampleBlock2D": - return ResnetDownsampleBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - dropout=dropout, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - skip_time_act=resnet_skip_time_act, - output_scale_factor=resnet_out_scale_factor, - ) - elif down_block_type == "AttnDownBlock2D": - if add_downsample is False: - downsample_type = None - else: - downsample_type = downsample_type or "conv" # default to 'conv' - return AttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - dropout=dropout, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - attention_head_dim=attention_head_dim, - resnet_time_scale_shift=resnet_time_scale_shift, - downsample_type=downsample_type, - ) - elif down_block_type == "CrossAttnDownBlock2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") - return CrossAttnDownBlock2D( - num_layers=num_layers, - transformer_layers_per_block=transformer_layers_per_block, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - dropout=dropout, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - attention_type=attention_type, - ) - elif down_block_type == "SimpleCrossAttnDownBlock2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") - return SimpleCrossAttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - dropout=dropout, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - attention_head_dim=attention_head_dim, - resnet_time_scale_shift=resnet_time_scale_shift, - skip_time_act=resnet_skip_time_act, - output_scale_factor=resnet_out_scale_factor, - only_cross_attention=only_cross_attention, - cross_attention_norm=cross_attention_norm, - ) - elif down_block_type == "SkipDownBlock2D": - return SkipDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - dropout=dropout, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "AttnSkipDownBlock2D": - return AttnSkipDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - dropout=dropout, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - attention_head_dim=attention_head_dim, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "DownEncoderBlock2D": - return DownEncoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - dropout=dropout, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "AttnDownEncoderBlock2D": - return AttnDownEncoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - dropout=dropout, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - attention_head_dim=attention_head_dim, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "KDownBlock2D": - return KDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - dropout=dropout, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - ) - elif down_block_type == "KCrossAttnDownBlock2D": - return KCrossAttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - dropout=dropout, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - cross_attention_dim=cross_attention_dim, - attention_head_dim=attention_head_dim, - add_self_attention=True if not add_downsample else False, - ) - raise ValueError(f"{down_block_type} does not exist.") + deprecation_message = "Importing `get_down_block` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import get_down_block`, instead." + deprecate("get_down_block", "0.29", deprecation_message) + + from .unets.unet_2d_blocks import get_down_block + + return get_down_block( + down_block_type=down_block_type, + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + transformer_layers_per_block=transformer_layers_per_block, + num_attention_heads=num_attention_heads, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim, + downsample_type=downsample_type, + dropout=dropout, + ) def get_mid_block( @@ -351,3316 +208,168 @@ def get_up_block( attention_head_dim: Optional[int] = None, upsample_type: Optional[str] = None, dropout: float = 0.0, -) -> nn.Module: - # If attn head dim is not defined, we default it to the number of heads - if attention_head_dim is None: - logger.warn( - f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." - ) - attention_head_dim = num_attention_heads +): + deprecation_message = "Importing `get_up_block` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import get_up_block`, instead." + deprecate("get_up_block", "0.29", deprecation_message) - up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type - if up_block_type == "UpBlock2D": - return UpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - resolution_idx=resolution_idx, - dropout=dropout, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "ResnetUpsampleBlock2D": - return ResnetUpsampleBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - resolution_idx=resolution_idx, - dropout=dropout, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - skip_time_act=resnet_skip_time_act, - output_scale_factor=resnet_out_scale_factor, - ) - elif up_block_type == "CrossAttnUpBlock2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") - return CrossAttnUpBlock2D( - num_layers=num_layers, - transformer_layers_per_block=transformer_layers_per_block, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - resolution_idx=resolution_idx, - dropout=dropout, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - attention_type=attention_type, - ) - elif up_block_type == "SimpleCrossAttnUpBlock2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") - return SimpleCrossAttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - resolution_idx=resolution_idx, - dropout=dropout, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - attention_head_dim=attention_head_dim, - resnet_time_scale_shift=resnet_time_scale_shift, - skip_time_act=resnet_skip_time_act, - output_scale_factor=resnet_out_scale_factor, - only_cross_attention=only_cross_attention, - cross_attention_norm=cross_attention_norm, - ) - elif up_block_type == "AttnUpBlock2D": - if add_upsample is False: - upsample_type = None - else: - upsample_type = upsample_type or "conv" # default to 'conv' - - return AttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - resolution_idx=resolution_idx, - dropout=dropout, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - attention_head_dim=attention_head_dim, - resnet_time_scale_shift=resnet_time_scale_shift, - upsample_type=upsample_type, - ) - elif up_block_type == "SkipUpBlock2D": - return SkipUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - resolution_idx=resolution_idx, - dropout=dropout, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "AttnSkipUpBlock2D": - return AttnSkipUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - resolution_idx=resolution_idx, - dropout=dropout, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - attention_head_dim=attention_head_dim, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "UpDecoderBlock2D": - return UpDecoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - resolution_idx=resolution_idx, - dropout=dropout, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - temb_channels=temb_channels, - ) - elif up_block_type == "AttnUpDecoderBlock2D": - return AttnUpDecoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - resolution_idx=resolution_idx, - dropout=dropout, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - attention_head_dim=attention_head_dim, - resnet_time_scale_shift=resnet_time_scale_shift, - temb_channels=temb_channels, - ) - elif up_block_type == "KUpBlock2D": - return KUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - resolution_idx=resolution_idx, - dropout=dropout, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - ) - elif up_block_type == "KCrossAttnUpBlock2D": - return KCrossAttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - resolution_idx=resolution_idx, - dropout=dropout, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - cross_attention_dim=cross_attention_dim, - attention_head_dim=attention_head_dim, - ) + from .unets.unet_2d_blocks import get_up_block - raise ValueError(f"{up_block_type} does not exist.") - - -class AutoencoderTinyBlock(nn.Module): - """ - Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU - blocks. - - Args: - in_channels (`int`): The number of input channels. - out_channels (`int`): The number of output channels. - act_fn (`str`): - ` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`. - - Returns: - `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to - `out_channels`. - """ - - def __init__(self, in_channels: int, out_channels: int, act_fn: str): - super().__init__() - act_fn = get_activation(act_fn) - self.conv = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), - act_fn, - nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), - act_fn, - nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), - ) - self.skip = ( - nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) - if in_channels != out_channels - else nn.Identity() - ) - self.fuse = nn.ReLU() - - def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: - return self.fuse(self.conv(x) + self.skip(x)) - - -class UNetMidBlock2D(nn.Module): - """ - A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks. - - Args: - in_channels (`int`): The number of input channels. - temb_channels (`int`): The number of temporal embedding channels. - dropout (`float`, *optional*, defaults to 0.0): The dropout rate. - num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. - resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. - resnet_time_scale_shift (`str`, *optional*, defaults to `default`): - The type of normalization to apply to the time embeddings. This can help to improve the performance of the - model on tasks with long-range temporal dependencies. - resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks. - resnet_groups (`int`, *optional*, defaults to 32): - The number of groups to use in the group normalization layers of the resnet blocks. - attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks. - resnet_pre_norm (`bool`, *optional*, defaults to `True`): - Whether to use pre-normalization for the resnet blocks. - add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks. - attention_head_dim (`int`, *optional*, defaults to 1): - Dimension of a single attention head. The number of attention heads is determined based on this value and - the number of input channels. - output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. - - Returns: - `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, - in_channels, height, width)`. - - """ - - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", # default, spatial - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - attn_groups: Optional[int] = None, - resnet_pre_norm: bool = True, - add_attention: bool = True, - attention_head_dim: int = 1, - output_scale_factor: float = 1.0, - ): - super().__init__() - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - self.add_attention = add_attention - - if attn_groups is None: - attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None - - # there is always at least one resnet - if resnet_time_scale_shift == "spatial": - resnets = [ - ResnetBlockCondNorm2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm="spatial", - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - ) - ] - else: - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - if attention_head_dim is None: - logger.warn( - f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." - ) - attention_head_dim = in_channels - - for _ in range(num_layers): - if self.add_attention: - attentions.append( - Attention( - in_channels, - heads=in_channels // attention_head_dim, - dim_head=attention_head_dim, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=attn_groups, - spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, - residual_connection=True, - bias=True, - upcast_softmax=True, - _from_deprecated_attn_block=True, - ) - ) - else: - attentions.append(None) - - if resnet_time_scale_shift == "spatial": - resnets.append( - ResnetBlockCondNorm2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm="spatial", - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - ) - ) - else: - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - if attn is not None: - hidden_states = attn(hidden_states, temb=temb) - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - -class UNetMidBlock2DCrossAttn(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads: int = 1, - output_scale_factor: float = 1.0, - cross_attention_dim: int = 1280, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - upcast_attention: bool = False, - attention_type: str = "default", - ): - super().__init__() - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - - # support for variable transformer layers per block - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * num_layers - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for i in range(num_layers): - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.FloatTensor, - temb: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - else: - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = resnet(hidden_states, temb, scale=lora_scale) - - return hidden_states - - -class UNetMidBlock2DSimpleCrossAttn(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attention_head_dim: int = 1, - output_scale_factor: float = 1.0, - cross_attention_dim: int = 1280, - skip_time_act: bool = False, - only_cross_attention: bool = False, - cross_attention_norm: Optional[str] = None, - ): - super().__init__() - - self.has_cross_attention = True - - self.attention_head_dim = attention_head_dim - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - - self.num_heads = in_channels // self.attention_head_dim - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - skip_time_act=skip_time_act, - ) - ] - attentions = [] - - for _ in range(num_layers): - processor = ( - AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() - ) - - attentions.append( - Attention( - query_dim=in_channels, - cross_attention_dim=in_channels, - heads=self.num_heads, - dim_head=self.attention_head_dim, - added_kv_proj_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - bias=True, - upcast_softmax=True, - only_cross_attention=only_cross_attention, - cross_attention_norm=cross_attention_norm, - processor=processor, - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - skip_time_act=skip_time_act, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward( - self, - hidden_states: torch.FloatTensor, - temb: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - lora_scale = cross_attention_kwargs.get("scale", 1.0) - - if attention_mask is None: - # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. - mask = None if encoder_hidden_states is None else encoder_attention_mask - else: - # when attention_mask is defined: we don't even check for encoder_attention_mask. - # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. - # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. - # then we can simplify this whole if/else block to: - # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask - mask = attention_mask - - hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - # attn - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=mask, - **cross_attention_kwargs, - ) - - # resnet - hidden_states = resnet(hidden_states, temb, scale=lora_scale) - - return hidden_states - - -class AttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attention_head_dim: int = 1, - output_scale_factor: float = 1.0, - downsample_padding: int = 1, - downsample_type: str = "conv", - ): - super().__init__() - resnets = [] - attentions = [] - self.downsample_type = downsample_type - - if attention_head_dim is None: - logger.warn( - f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." - ) - attention_head_dim = out_channels - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - Attention( - out_channels, - heads=out_channels // attention_head_dim, - dim_head=attention_head_dim, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - residual_connection=True, - bias=True, - upcast_softmax=True, - _from_deprecated_attn_block=True, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if downsample_type == "conv": - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - elif downsample_type == "resnet": - self.downsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - down=True, - ) - ] - ) - else: - self.downsamplers = None - - def forward( - self, - hidden_states: torch.FloatTensor, - temb: Optional[torch.FloatTensor] = None, - upsample_size: Optional[int] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - - lora_scale = cross_attention_kwargs.get("scale", 1.0) - - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - cross_attention_kwargs.update({"scale": lora_scale}) - hidden_states = resnet(hidden_states, temb, scale=lora_scale) - hidden_states = attn(hidden_states, **cross_attention_kwargs) - output_states = output_states + (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - if self.downsample_type == "resnet": - hidden_states = downsampler(hidden_states, temb=temb, scale=lora_scale) - else: - hidden_states = downsampler(hidden_states, scale=lora_scale) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class CrossAttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads: int = 1, - cross_attention_dim: int = 1280, - output_scale_factor: float = 1.0, - downsample_padding: int = 1, - add_downsample: bool = True, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - attention_type: str = "default", - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * num_layers - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.FloatTensor, - temb: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - additional_residuals: Optional[torch.FloatTensor] = None, - ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: - output_states = () - - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - - blocks = list(zip(self.resnets, self.attentions)) - - for i, (resnet, attn) in enumerate(blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - - # apply additional residuals to the output of the last pair of resnet and attention blocks - if i == len(blocks) - 1 and additional_residuals is not None: - hidden_states = hidden_states + additional_residuals - - output_states = output_states + (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale=lora_scale) - - output_states = output_states + (hidden_states,) - - return hidden_states, output_states - - -class DownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_downsample: bool = True, - downsample_padding: int = 1, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 - ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: - output_states = () - - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb, scale=scale) - - output_states = output_states + (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale=scale) - - output_states = output_states + (hidden_states,) - - return hidden_states, output_states - - -class DownEncoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_downsample: bool = True, - downsample_padding: int = 1, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - if resnet_time_scale_shift == "spatial": - resnets.append( - ResnetBlockCondNorm2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm="spatial", - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - ) - ) - else: - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=None, scale=scale) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale) - - return hidden_states - - -class AttnDownEncoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attention_head_dim: int = 1, - output_scale_factor: float = 1.0, - add_downsample: bool = True, - downsample_padding: int = 1, - ): - super().__init__() - resnets = [] - attentions = [] - - if attention_head_dim is None: - logger.warn( - f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." - ) - attention_head_dim = out_channels - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - if resnet_time_scale_shift == "spatial": - resnets.append( - ResnetBlockCondNorm2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm="spatial", - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - ) - ) - else: - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - Attention( - out_channels, - heads=out_channels // attention_head_dim, - dim_head=attention_head_dim, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - residual_connection=True, - bias=True, - upcast_softmax=True, - _from_deprecated_attn_block=True, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=None, scale=scale) - cross_attention_kwargs = {"scale": scale} - hidden_states = attn(hidden_states, **cross_attention_kwargs) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale) - - return hidden_states - - -class AttnSkipDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - attention_head_dim: int = 1, - output_scale_factor: float = np.sqrt(2.0), - add_downsample: bool = True, - ): - super().__init__() - self.attentions = nn.ModuleList([]) - self.resnets = nn.ModuleList([]) - - if attention_head_dim is None: - logger.warn( - f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." - ) - attention_head_dim = out_channels - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - self.resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(in_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.attentions.append( - Attention( - out_channels, - heads=out_channels // attention_head_dim, - dim_head=attention_head_dim, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=32, - residual_connection=True, - bias=True, - upcast_softmax=True, - _from_deprecated_attn_block=True, - ) - ) - - if add_downsample: - self.resnet_down = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - down=True, - kernel="fir", - ) - self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) - self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) - else: - self.resnet_down = None - self.downsamplers = None - self.skip_conv = None - - def forward( - self, - hidden_states: torch.FloatTensor, - temb: Optional[torch.FloatTensor] = None, - skip_sample: Optional[torch.FloatTensor] = None, - scale: float = 1.0, - ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb, scale=scale) - cross_attention_kwargs = {"scale": scale} - hidden_states = attn(hidden_states, **cross_attention_kwargs) - output_states += (hidden_states,) - - if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb, scale=scale) - for downsampler in self.downsamplers: - skip_sample = downsampler(skip_sample) - - hidden_states = self.skip_conv(skip_sample) + hidden_states - - output_states += (hidden_states,) - - return hidden_states, output_states, skip_sample - - -class SkipDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - output_scale_factor: float = np.sqrt(2.0), - add_downsample: bool = True, - downsample_padding: int = 1, - ): - super().__init__() - self.resnets = nn.ModuleList([]) - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - self.resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(in_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - if add_downsample: - self.resnet_down = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - down=True, - kernel="fir", - ) - self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) - self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) - else: - self.resnet_down = None - self.downsamplers = None - self.skip_conv = None - - def forward( - self, - hidden_states: torch.FloatTensor, - temb: Optional[torch.FloatTensor] = None, - skip_sample: Optional[torch.FloatTensor] = None, - scale: float = 1.0, - ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: - output_states = () - - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb, scale) - output_states += (hidden_states,) - - if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb, scale) - for downsampler in self.downsamplers: - skip_sample = downsampler(skip_sample) - - hidden_states = self.skip_conv(skip_sample) + hidden_states - - output_states += (hidden_states,) - - return hidden_states, output_states, skip_sample - - -class ResnetDownsampleBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_downsample: bool = True, - skip_time_act: bool = False, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - skip_time_act=skip_time_act, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - skip_time_act=skip_time_act, - down=True, - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 - ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: - output_states = () - - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb, scale) - - output_states = output_states + (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, temb, scale) - - output_states = output_states + (hidden_states,) - - return hidden_states, output_states - - -class SimpleCrossAttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attention_head_dim: int = 1, - cross_attention_dim: int = 1280, - output_scale_factor: float = 1.0, - add_downsample: bool = True, - skip_time_act: bool = False, - only_cross_attention: bool = False, - cross_attention_norm: Optional[str] = None, - ): - super().__init__() - - self.has_cross_attention = True - - resnets = [] - attentions = [] - - self.attention_head_dim = attention_head_dim - self.num_heads = out_channels // self.attention_head_dim - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - skip_time_act=skip_time_act, - ) - ) - - processor = ( - AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() - ) - - attentions.append( - Attention( - query_dim=out_channels, - cross_attention_dim=out_channels, - heads=self.num_heads, - dim_head=attention_head_dim, - added_kv_proj_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - bias=True, - upcast_softmax=True, - only_cross_attention=only_cross_attention, - cross_attention_norm=cross_attention_norm, - processor=processor, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - skip_time_act=skip_time_act, - down=True, - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.FloatTensor, - temb: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: - output_states = () - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - - lora_scale = cross_attention_kwargs.get("scale", 1.0) - - if attention_mask is None: - # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. - mask = None if encoder_hidden_states is None else encoder_attention_mask - else: - # when attention_mask is defined: we don't even check for encoder_attention_mask. - # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. - # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. - # then we can simplify this whole if/else block to: - # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask - mask = attention_mask - - for resnet, attn in zip(self.resnets, self.attentions): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=mask, - **cross_attention_kwargs, - ) - else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) - - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=mask, - **cross_attention_kwargs, - ) - - output_states = output_states + (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, temb, scale=lora_scale) - - output_states = output_states + (hidden_states,) - - return hidden_states, output_states - - -class KDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 4, - resnet_eps: float = 1e-5, - resnet_act_fn: str = "gelu", - resnet_group_size: int = 32, - add_downsample: bool = False, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - groups = in_channels // resnet_group_size - groups_out = out_channels // resnet_group_size - - resnets.append( - ResnetBlockCondNorm2D( - in_channels=in_channels, - out_channels=out_channels, - dropout=dropout, - temb_channels=temb_channels, - groups=groups, - groups_out=groups_out, - eps=resnet_eps, - non_linearity=resnet_act_fn, - time_embedding_norm="ada_group", - conv_shortcut_bias=False, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - # YiYi's comments- might be able to use FirDownsample2D, look into details later - self.downsamplers = nn.ModuleList([KDownsample2D()]) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 - ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: - output_states = () - - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb, scale) - - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - return hidden_states, output_states - - -class KCrossAttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - cross_attention_dim: int, - dropout: float = 0.0, - num_layers: int = 4, - resnet_group_size: int = 32, - add_downsample: bool = True, - attention_head_dim: int = 64, - add_self_attention: bool = False, - resnet_eps: float = 1e-5, - resnet_act_fn: str = "gelu", - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - groups = in_channels // resnet_group_size - groups_out = out_channels // resnet_group_size - - resnets.append( - ResnetBlockCondNorm2D( - in_channels=in_channels, - out_channels=out_channels, - dropout=dropout, - temb_channels=temb_channels, - groups=groups, - groups_out=groups_out, - eps=resnet_eps, - non_linearity=resnet_act_fn, - time_embedding_norm="ada_group", - conv_shortcut_bias=False, - ) - ) - attentions.append( - KAttentionBlock( - out_channels, - out_channels // attention_head_dim, - attention_head_dim, - cross_attention_dim=cross_attention_dim, - temb_channels=temb_channels, - attention_bias=True, - add_self_attention=add_self_attention, - cross_attention_norm="layer_norm", - group_size=resnet_group_size, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.attentions = nn.ModuleList(attentions) - - if add_downsample: - self.downsamplers = nn.ModuleList([KDownsample2D()]) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.FloatTensor, - temb: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: - output_states = () - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - - for resnet, attn in zip(self.resnets, self.attentions): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - emb=temb, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - ) - else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - emb=temb, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - ) - - if self.downsamplers is None: - output_states += (None,) - else: - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - return hidden_states, output_states - - -class AttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - resolution_idx: int = None, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attention_head_dim: int = 1, - output_scale_factor: float = 1.0, - upsample_type: str = "conv", - ): - super().__init__() - resnets = [] - attentions = [] - - self.upsample_type = upsample_type - - if attention_head_dim is None: - logger.warn( - f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." - ) - attention_head_dim = out_channels - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - Attention( - out_channels, - heads=out_channels // attention_head_dim, - dim_head=attention_head_dim, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - residual_connection=True, - bias=True, - upcast_softmax=True, - _from_deprecated_attn_block=True, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if upsample_type == "conv": - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - elif upsample_type == "resnet": - self.upsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - up=True, - ) - ] - ) - else: - self.upsamplers = None - - self.resolution_idx = resolution_idx - - def forward( - self, - hidden_states: torch.FloatTensor, - res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], - temb: Optional[torch.FloatTensor] = None, - upsample_size: Optional[int] = None, - scale: float = 1.0, - ) -> torch.FloatTensor: - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb, scale=scale) - cross_attention_kwargs = {"scale": scale} - hidden_states = attn(hidden_states, **cross_attention_kwargs) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - if self.upsample_type == "resnet": - hidden_states = upsampler(hidden_states, temb=temb, scale=scale) - else: - hidden_states = upsampler(hidden_states, scale=scale) - - return hidden_states - - -class CrossAttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - resolution_idx: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads: int = 1, - cross_attention_dim: int = 1280, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - attention_type: str = "default", - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * num_layers - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - self.resolution_idx = resolution_idx - - def forward( - self, - hidden_states: torch.FloatTensor, - res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], - temb: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - upsample_size: Optional[int] = None, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - is_freeu_enabled = ( - getattr(self, "s1", None) - and getattr(self, "s2", None) - and getattr(self, "b1", None) - and getattr(self, "b2", None) - ) + return get_up_block( + up_block_type=up_block_type, + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resolution_idx=resolution_idx, + transformer_layers_per_block=transformer_layers_per_block, + num_attention_heads=num_attention_heads, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim, + upsample_type=upsample_type, + dropout=dropout, + ) - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - - # FreeU: Only operate on the first two stages - if is_freeu_enabled: - hidden_states, res_hidden_states = apply_freeu( - self.resolution_idx, - hidden_states, - res_hidden_states, - s1=self.s1, - s2=self.s2, - b1=self.b1, - b2=self.b2, - ) - - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) - - return hidden_states - - -class UpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - resolution_idx: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - self.resolution_idx = resolution_idx - - def forward( - self, - hidden_states: torch.FloatTensor, - res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], - temb: Optional[torch.FloatTensor] = None, - upsample_size: Optional[int] = None, - scale: float = 1.0, - ) -> torch.FloatTensor: - is_freeu_enabled = ( - getattr(self, "s1", None) - and getattr(self, "s2", None) - and getattr(self, "b1", None) - and getattr(self, "b2", None) - ) - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - - # FreeU: Only operate on the first two stages - if is_freeu_enabled: - hidden_states, res_hidden_states = apply_freeu( - self.resolution_idx, - hidden_states, - res_hidden_states, - s1=self.s1, - s2=self.s2, - b1=self.b1, - b2=self.b2, - ) - - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb, scale=scale) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size, scale=scale) - - return hidden_states - - -class UpDecoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - resolution_idx: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", # default, spatial - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - temb_channels: Optional[int] = None, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - if resnet_time_scale_shift == "spatial": - resnets.append( - ResnetBlockCondNorm2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm="spatial", - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - ) - ) - else: - resnets.append( - ResnetBlock2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.resolution_idx = resolution_idx - - def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 - ) -> torch.FloatTensor: - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=temb, scale=scale) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class AttnUpDecoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - resolution_idx: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attention_head_dim: int = 1, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - temb_channels: Optional[int] = None, - ): - super().__init__() - resnets = [] - attentions = [] - - if attention_head_dim is None: - logger.warn( - f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." - ) - attention_head_dim = out_channels - - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - if resnet_time_scale_shift == "spatial": - resnets.append( - ResnetBlockCondNorm2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm="spatial", - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - ) - ) - else: - resnets.append( - ResnetBlock2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - attentions.append( - Attention( - out_channels, - heads=out_channels // attention_head_dim, - dim_head=attention_head_dim, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None, - spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, - residual_connection=True, - bias=True, - upcast_softmax=True, - _from_deprecated_attn_block=True, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.resolution_idx = resolution_idx - - def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 - ) -> torch.FloatTensor: - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=temb, scale=scale) - cross_attention_kwargs = {"scale": scale} - hidden_states = attn(hidden_states, temb=temb, **cross_attention_kwargs) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, scale=scale) - - return hidden_states - - -class AttnSkipUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - resolution_idx: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - attention_head_dim: int = 1, - output_scale_factor: float = np.sqrt(2.0), - add_upsample: bool = True, - ): - super().__init__() - self.attentions = nn.ModuleList([]) - self.resnets = nn.ModuleList([]) - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - self.resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(resnet_in_channels + res_skip_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - if attention_head_dim is None: - logger.warn( - f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." - ) - attention_head_dim = out_channels - - self.attentions.append( - Attention( - out_channels, - heads=out_channels // attention_head_dim, - dim_head=attention_head_dim, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=32, - residual_connection=True, - bias=True, - upcast_softmax=True, - _from_deprecated_attn_block=True, - ) - ) +class AutoencoderTinyBlock(AutoencoderTinyBlock): + deprecation_message = "Importing `AutoencoderTinyBlock` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AutoencoderTinyBlock`, instead." + deprecate("AutoencoderTinyBlock", "0.29", deprecation_message) - self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) - if add_upsample: - self.resnet_up = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - up=True, - kernel="fir", - ) - self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - self.skip_norm = torch.nn.GroupNorm( - num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True - ) - self.act = nn.SiLU() - else: - self.resnet_up = None - self.skip_conv = None - self.skip_norm = None - self.act = None - - self.resolution_idx = resolution_idx - - def forward( - self, - hidden_states: torch.FloatTensor, - res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], - temb: Optional[torch.FloatTensor] = None, - skip_sample=None, - scale: float = 1.0, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb, scale=scale) - - cross_attention_kwargs = {"scale": scale} - hidden_states = self.attentions[0](hidden_states, **cross_attention_kwargs) - - if skip_sample is not None: - skip_sample = self.upsampler(skip_sample) - else: - skip_sample = 0 - - if self.resnet_up is not None: - skip_sample_states = self.skip_norm(hidden_states) - skip_sample_states = self.act(skip_sample_states) - skip_sample_states = self.skip_conv(skip_sample_states) - - skip_sample = skip_sample + skip_sample_states - - hidden_states = self.resnet_up(hidden_states, temb, scale=scale) - - return hidden_states, skip_sample - - -class SkipUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - resolution_idx: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - output_scale_factor: float = np.sqrt(2.0), - add_upsample: bool = True, - upsample_padding: int = 1, - ): - super().__init__() - self.resnets = nn.ModuleList([]) - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - self.resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min((resnet_in_channels + res_skip_channels) // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) - if add_upsample: - self.resnet_up = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - up=True, - kernel="fir", - ) - self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - self.skip_norm = torch.nn.GroupNorm( - num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True - ) - self.act = nn.SiLU() - else: - self.resnet_up = None - self.skip_conv = None - self.skip_norm = None - self.act = None - - self.resolution_idx = resolution_idx - - def forward( - self, - hidden_states: torch.FloatTensor, - res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], - temb: Optional[torch.FloatTensor] = None, - skip_sample=None, - scale: float = 1.0, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb, scale=scale) - - if skip_sample is not None: - skip_sample = self.upsampler(skip_sample) - else: - skip_sample = 0 - - if self.resnet_up is not None: - skip_sample_states = self.skip_norm(hidden_states) - skip_sample_states = self.act(skip_sample_states) - skip_sample_states = self.skip_conv(skip_sample_states) - - skip_sample = skip_sample + skip_sample_states - - hidden_states = self.resnet_up(hidden_states, temb, scale=scale) - - return hidden_states, skip_sample - - -class ResnetUpsampleBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - resolution_idx: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - skip_time_act: bool = False, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - skip_time_act=skip_time_act, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - skip_time_act=skip_time_act, - up=True, - ) - ] - ) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - self.resolution_idx = resolution_idx - - def forward( - self, - hidden_states: torch.FloatTensor, - res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], - temb: Optional[torch.FloatTensor] = None, - upsample_size: Optional[int] = None, - scale: float = 1.0, - ) -> torch.FloatTensor: - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb, scale=scale) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, temb, scale=scale) - - return hidden_states - - -class SimpleCrossAttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - resolution_idx: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attention_head_dim: int = 1, - cross_attention_dim: int = 1280, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - skip_time_act: bool = False, - only_cross_attention: bool = False, - cross_attention_norm: Optional[str] = None, - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.attention_head_dim = attention_head_dim - - self.num_heads = out_channels // self.attention_head_dim - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - skip_time_act=skip_time_act, - ) - ) - - processor = ( - AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() - ) - - attentions.append( - Attention( - query_dim=out_channels, - cross_attention_dim=out_channels, - heads=self.num_heads, - dim_head=self.attention_head_dim, - added_kv_proj_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - bias=True, - upcast_softmax=True, - only_cross_attention=only_cross_attention, - cross_attention_norm=cross_attention_norm, - processor=processor, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - skip_time_act=skip_time_act, - up=True, - ) - ] - ) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - self.resolution_idx = resolution_idx - - def forward( - self, - hidden_states: torch.FloatTensor, - res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], - temb: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - upsample_size: Optional[int] = None, - attention_mask: Optional[torch.FloatTensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - - lora_scale = cross_attention_kwargs.get("scale", 1.0) - if attention_mask is None: - # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. - mask = None if encoder_hidden_states is None else encoder_attention_mask - else: - # when attention_mask is defined: we don't even check for encoder_attention_mask. - # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. - # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. - # then we can simplify this whole if/else block to: - # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask - mask = attention_mask - - for resnet, attn in zip(self.resnets, self.attentions): - # resnet - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=mask, - **cross_attention_kwargs, - ) - else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) - - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=mask, - **cross_attention_kwargs, - ) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, temb, scale=lora_scale) - - return hidden_states - - -class KUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - resolution_idx: int, - dropout: float = 0.0, - num_layers: int = 5, - resnet_eps: float = 1e-5, - resnet_act_fn: str = "gelu", - resnet_group_size: Optional[int] = 32, - add_upsample: bool = True, - ): - super().__init__() - resnets = [] - k_in_channels = 2 * out_channels - k_out_channels = in_channels - num_layers = num_layers - 1 - - for i in range(num_layers): - in_channels = k_in_channels if i == 0 else out_channels - groups = in_channels // resnet_group_size - groups_out = out_channels // resnet_group_size - - resnets.append( - ResnetBlockCondNorm2D( - in_channels=in_channels, - out_channels=k_out_channels if (i == num_layers - 1) else out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=groups, - groups_out=groups_out, - dropout=dropout, - non_linearity=resnet_act_fn, - time_embedding_norm="ada_group", - conv_shortcut_bias=False, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([KUpsample2D()]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - self.resolution_idx = resolution_idx - - def forward( - self, - hidden_states: torch.FloatTensor, - res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], - temb: Optional[torch.FloatTensor] = None, - upsample_size: Optional[int] = None, - scale: float = 1.0, - ) -> torch.FloatTensor: - res_hidden_states_tuple = res_hidden_states_tuple[-1] - if res_hidden_states_tuple is not None: - hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) - - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb, scale=scale) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class KCrossAttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - resolution_idx: int, - dropout: float = 0.0, - num_layers: int = 4, - resnet_eps: float = 1e-5, - resnet_act_fn: str = "gelu", - resnet_group_size: int = 32, - attention_head_dim: int = 1, # attention dim_head - cross_attention_dim: int = 768, - add_upsample: bool = True, - upcast_attention: bool = False, - ): - super().__init__() - resnets = [] - attentions = [] - - is_first_block = in_channels == out_channels == temb_channels - is_middle_block = in_channels != out_channels - add_self_attention = True if is_first_block else False - - self.has_cross_attention = True - self.attention_head_dim = attention_head_dim - - # in_channels, and out_channels for the block (k-unet) - k_in_channels = out_channels if is_first_block else 2 * out_channels - k_out_channels = in_channels - - num_layers = num_layers - 1 - - for i in range(num_layers): - in_channels = k_in_channels if i == 0 else out_channels - groups = in_channels // resnet_group_size - groups_out = out_channels // resnet_group_size - - if is_middle_block and (i == num_layers - 1): - conv_2d_out_channels = k_out_channels - else: - conv_2d_out_channels = None - - resnets.append( - ResnetBlockCondNorm2D( - in_channels=in_channels, - out_channels=out_channels, - conv_2d_out_channels=conv_2d_out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=groups, - groups_out=groups_out, - dropout=dropout, - non_linearity=resnet_act_fn, - time_embedding_norm="ada_group", - conv_shortcut_bias=False, - ) - ) - attentions.append( - KAttentionBlock( - k_out_channels if (i == num_layers - 1) else out_channels, - k_out_channels // attention_head_dim - if (i == num_layers - 1) - else out_channels // attention_head_dim, - attention_head_dim, - cross_attention_dim=cross_attention_dim, - temb_channels=temb_channels, - attention_bias=True, - add_self_attention=add_self_attention, - cross_attention_norm="layer_norm", - upcast_attention=upcast_attention, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.attentions = nn.ModuleList(attentions) - - if add_upsample: - self.upsamplers = nn.ModuleList([KUpsample2D()]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - self.resolution_idx = resolution_idx - - def forward( - self, - hidden_states: torch.FloatTensor, - res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], - temb: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - upsample_size: Optional[int] = None, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: - res_hidden_states_tuple = res_hidden_states_tuple[-1] - if res_hidden_states_tuple is not None: - hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) - - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - for resnet, attn in zip(self.resnets, self.attentions): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - emb=temb, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - ) - else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - emb=temb, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - ) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states +class UNetMidBlock2D(UNetMidBlock2D): + deprecation_message = "Importing `UNetMidBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D`, instead." + deprecate("UNetMidBlock2D", "0.29", deprecation_message) -# can potentially later be renamed to `No-feed-forward` attention -class KAttentionBlock(nn.Module): - r""" - A basic Transformer block. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - attention_bias (`bool`, *optional*, defaults to `False`): - Configure if the attention layers should contain a bias parameter. - upcast_attention (`bool`, *optional*, defaults to `False`): - Set to `True` to upcast the attention computation to `float32`. - temb_channels (`int`, *optional*, defaults to 768): - The number of channels in the token embedding. - add_self_attention (`bool`, *optional*, defaults to `False`): - Set to `True` to add self-attention to the block. - cross_attention_norm (`str`, *optional*, defaults to `None`): - The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. - group_size (`int`, *optional*, defaults to 32): - The number of groups to separate the channels into for group normalization. - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout: float = 0.0, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - upcast_attention: bool = False, - temb_channels: int = 768, # for ada_group_norm - add_self_attention: bool = False, - cross_attention_norm: Optional[str] = None, - group_size: int = 32, - ): - super().__init__() - self.add_self_attention = add_self_attention - - # 1. Self-Attn - if add_self_attention: - self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=None, - cross_attention_norm=None, - ) - - # 2. Cross-Attn - self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - cross_attention_norm=cross_attention_norm, - ) - def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor: - return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1) - - def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor: - return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight) - - def forward( - self, - hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - # TODO: mark emb as non-optional (self.norm2 requires it). - # requires assessing impact of change to positional param interface. - emb: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - - # 1. Self-Attention - if self.add_self_attention: - norm_hidden_states = self.norm1(hidden_states, emb) - - height, weight = norm_hidden_states.shape[2:] - norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) - - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=None, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - attn_output = self._to_4d(attn_output, height, weight) - - hidden_states = attn_output + hidden_states - - # 2. Cross-Attention/None - norm_hidden_states = self.norm2(hidden_states, emb) - - height, weight = norm_hidden_states.shape[2:] - norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask, - **cross_attention_kwargs, - ) - attn_output = self._to_4d(attn_output, height, weight) +class UNetMidBlock2DCrossAttn(UNetMidBlock2DCrossAttn): + deprecation_message = "Importing `UNetMidBlock2DCrossAttn` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DCrossAttn`, instead." + deprecate("UNetMidBlock2DCrossAttn", "0.29", deprecation_message) + + +class UNetMidBlock2DSimpleCrossAttn(UNetMidBlock2DSimpleCrossAttn): + deprecation_message = "Importing `UNetMidBlock2DSimpleCrossAttn` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DSimpleCrossAttn`, instead." + deprecate("UNetMidBlock2DSimpleCrossAttn", "0.29", deprecation_message) + + +class AttnDownBlock2D(AttnDownBlock2D): + deprecation_message = "Importing `AttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnDownBlock2D`, instead." + deprecate("AttnDownBlock2D", "0.29", deprecation_message) + + +class CrossAttnDownBlock2D(CrossAttnDownBlock2D): + deprecation_message = "Importing `AttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D`, instead." + deprecate("CrossAttnDownBlock2D", "0.29", deprecation_message) + + +class DownBlock2D(DownBlock2D): + deprecation_message = "Importing `DownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import DownBlock2D`, instead." + deprecate("DownBlock2D", "0.29", deprecation_message) + + +class AttnDownEncoderBlock2D(AttnDownEncoderBlock2D): + deprecation_message = "Importing `AttnDownEncoderBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnDownEncoderBlock2D`, instead." + deprecate("AttnDownEncoderBlock2D", "0.29", deprecation_message) + + +class AttnSkipDownBlock2D(AttnSkipDownBlock2D): + deprecation_message = "Importing `AttnSkipDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnSkipDownBlock2D`, instead." + deprecate("AttnSkipDownBlock2D", "0.29", deprecation_message) + - hidden_states = attn_output + hidden_states +class SkipDownBlock2D(SkipDownBlock2D): + deprecation_message = "Importing `SkipDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import SkipDownBlock2D`, instead." + deprecate("SkipDownBlock2D", "0.29", deprecation_message) - return hidden_states + +class ResnetDownsampleBlock2D(ResnetDownsampleBlock2D): + deprecation_message = "Importing `ResnetDownsampleBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D`, instead." + deprecate("ResnetDownsampleBlock2D", "0.29", deprecation_message) + + +class SimpleCrossAttnDownBlock2D(SimpleCrossAttnDownBlock2D): + deprecation_message = "Importing `SimpleCrossAttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import SimpleCrossAttnDownBlock2D`, instead." + deprecate("SimpleCrossAttnDownBlock2D", "0.29", deprecation_message) + + +class KDownBlock2D(KDownBlock2D): + deprecation_message = "Importing `KDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KDownBlock2D`, instead." + deprecate("KDownBlock2D", "0.29", deprecation_message) + + +class KCrossAttnDownBlock2D(KCrossAttnDownBlock2D): + deprecation_message = "Importing `KCrossAttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KCrossAttnDownBlock2D`, instead." + deprecate("KCrossAttnDownBlock2D", "0.29", deprecation_message) + + +class AttnUpBlock2D(AttnUpBlock2D): + deprecation_message = "Importing `AttnUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnUpBlock2D`, instead." + deprecate("AttnUpBlock2D", "0.29", deprecation_message) + + +class CrossAttnUpBlock2D(CrossAttnUpBlock2D): + deprecation_message = "Importing `CrossAttnUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import CrossAttnUpBlock2D`, instead." + deprecate("CrossAttnUpBlock2D", "0.29", deprecation_message) + + +class UpBlock2D(UpBlock2D): + deprecation_message = "Importing `UpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UpBlock2D`, instead." + deprecate("UpBlock2D", "0.29", deprecation_message) + + +class UpDecoderBlock2D(UpDecoderBlock2D): + deprecation_message = "Importing `UpDecoderBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UpDecoderBlock2D`, instead." + deprecate("UpDecoderBlock2D", "0.29", deprecation_message) + + +class AttnUpDecoderBlock2D(AttnUpDecoderBlock2D): + deprecation_message = "Importing `AttnUpDecoderBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnUpDecoderBlock2D`, instead." + deprecate("AttnUpDecoderBlock2D", "0.29", deprecation_message) + + +class AttnSkipUpBlock2D(AttnSkipUpBlock2D): + deprecation_message = "Importing `AttnSkipUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnSkipUpBlock2D`, instead." + deprecate("AttnSkipUpBlock2D", "0.29", deprecation_message) + + +class SkipUpBlock2D(SkipUpBlock2D): + deprecation_message = "Importing `SkipUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import SkipUpBlock2D`, instead." + deprecate("SkipUpBlock2D", "0.29", deprecation_message) + + +class ResnetUpsampleBlock2D(ResnetUpsampleBlock2D): + deprecation_message = "Importing `ResnetUpsampleBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import ResnetUpsampleBlock2D`, instead." + deprecate("ResnetUpsampleBlock2D", "0.29", deprecation_message) + + +class SimpleCrossAttnUpBlock2D(SimpleCrossAttnUpBlock2D): + deprecation_message = "Importing `SimpleCrossAttnUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import SimpleCrossAttnUpBlock2D`, instead." + deprecate("SimpleCrossAttnUpBlock2D", "0.29", deprecation_message) + + +class KUpBlock2D(KUpBlock2D): + deprecation_message = "Importing `KUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KUpBlock2D`, instead." + deprecate("KUpBlock2D", "0.29", deprecation_message) + + +class KCrossAttnUpBlock2D(KCrossAttnUpBlock2D): + deprecation_message = "Importing `KCrossAttnUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KCrossAttnUpBlock2D`, instead." + deprecate("KCrossAttnUpBlock2D", "0.29", deprecation_message) + + +# can potentially later be renamed to `No-feed-forward` attention +class KAttentionBlock(KAttentionBlock): + deprecation_message = "Importing `KAttentionBlock` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KAttentionBlock`, instead." + deprecate("KAttentionBlock", "0.29", deprecation_message) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 5e16fdd93992..cc619dd17c4c 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -11,1302 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from ..utils import deprecate +from .unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput -import torch -import torch.nn as nn -import torch.utils.checkpoint -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin -from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers -from .activations import get_activation -from .attention_processor import ( - ADDED_KV_ATTENTION_PROCESSORS, - CROSS_ATTENTION_PROCESSORS, - Attention, - AttentionProcessor, - AttnAddedKVProcessor, - AttnProcessor, -) -from .embeddings import ( - GaussianFourierProjection, - GLIGENTextBoundingboxProjection, - ImageHintTimeEmbedding, - ImageProjection, - ImageTimeEmbedding, - TextImageProjection, - TextImageTimeEmbedding, - TextTimeEmbedding, - TimestepEmbedding, - Timesteps, -) -from .modeling_utils import ModelMixin -from .unet_2d_blocks import ( - get_down_block, - get_mid_block, - get_up_block, -) +class UNet2DConditionOutput(UNet2DConditionOutput): + deprecation_message = "Importing `UNet2DConditionOutput` from `diffusers.models.unet_2d_condition` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput`, instead." + deprecate("UNet2DConditionOutput", "0.29", deprecation_message) -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class UNet2DConditionOutput(BaseOutput): - """ - The output of [`UNet2DConditionModel`]. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. - """ - - sample: torch.FloatTensor = None - - -class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): - r""" - A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample - shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented - for all models (such as downloading or saving). - - Parameters: - sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): - Height and width of input/output sample. - in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. - center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - flip_sin_to_cos (`bool`, *optional*, defaults to `False`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. - mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): - Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or - `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): - The tuple of upsample blocks to use. - only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): - Whether to include self-attention in the basic transformer blocks, see - [`~models.attention.BasicTransformerBlock`]. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. - mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. - If `None`, normalization and activation layers is skipped in post-processing. - norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. - cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): - The dimension of the cross attention features. - transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for - [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], - [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. - reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling - blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for - [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], - [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. - encoder_hid_dim (`int`, *optional*, defaults to None): - If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` - dimension to `cross_attention_dim`. - encoder_hid_dim_type (`str`, *optional*, defaults to `None`): - If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text - embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. - attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. - num_attention_heads (`int`, *optional*): - The number of attention heads. If not defined, defaults to `attention_head_dim` - resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config - for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. - class_embed_type (`str`, *optional*, defaults to `None`): - The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, - `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. - addition_embed_type (`str`, *optional*, defaults to `None`): - Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or - "text". "text" will use the `TextTimeEmbedding` layer. - addition_time_embed_dim: (`int`, *optional*, defaults to `None`): - Dimension for the timestep embeddings. - num_class_embeds (`int`, *optional*, defaults to `None`): - Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing - class conditioning with `class_embed_type` equal to `None`. - time_embedding_type (`str`, *optional*, defaults to `positional`): - The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. - time_embedding_dim (`int`, *optional*, defaults to `None`): - An optional override for the dimension of the projected time embedding. - time_embedding_act_fn (`str`, *optional*, defaults to `None`): - Optional activation function to use only once on the time embeddings before they are passed to the rest of - the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. - timestep_post_act (`str`, *optional*, defaults to `None`): - The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. - time_cond_proj_dim (`int`, *optional*, defaults to `None`): - The dimension of `cond_proj` layer in the timestep embedding. - conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, - *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, - *optional*): The dimension of the `class_labels` input when - `class_embed_type="projection"`. Required when `class_embed_type="projection"`. - class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time - embeddings with the class embeddings. - mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): - Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If - `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the - `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` - otherwise. - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: Optional[int] = None, - in_channels: int = 4, - out_channels: int = 4, - center_input_sample: bool = False, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", - up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: Union[int, Tuple[int]] = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - dropout: float = 0.0, - act_fn: str = "silu", - norm_num_groups: Optional[int] = 32, - norm_eps: float = 1e-5, - cross_attention_dim: Union[int, Tuple[int]] = 1280, - transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, - reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, - encoder_hid_dim: Optional[int] = None, - encoder_hid_dim_type: Optional[str] = None, - attention_head_dim: Union[int, Tuple[int]] = 8, - num_attention_heads: Optional[Union[int, Tuple[int]]] = None, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - class_embed_type: Optional[str] = None, - addition_embed_type: Optional[str] = None, - addition_time_embed_dim: Optional[int] = None, - num_class_embeds: Optional[int] = None, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - resnet_skip_time_act: bool = False, - resnet_out_scale_factor: int = 1.0, - time_embedding_type: str = "positional", - time_embedding_dim: Optional[int] = None, - time_embedding_act_fn: Optional[str] = None, - timestep_post_act: Optional[str] = None, - time_cond_proj_dim: Optional[int] = None, - conv_in_kernel: int = 3, - conv_out_kernel: int = 3, - projection_class_embeddings_input_dim: Optional[int] = None, - attention_type: str = "default", - class_embeddings_concat: bool = False, - mid_block_only_cross_attention: Optional[bool] = None, - cross_attention_norm: Optional[str] = None, - addition_embed_type_num_heads=64, - ): - super().__init__() - - self.sample_size = sample_size - - if num_attention_heads is not None: - raise ValueError( - "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." - ) - - # If `num_attention_heads` is not defined (which is the case for most models) - # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. - # The reason for this behavior is to correct for incorrectly named variables that were introduced - # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 - # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking - # which is why we correct for the naming here. - num_attention_heads = num_attention_heads or attention_head_dim - - # Check inputs - self._check_config( - down_block_types=down_block_types, - up_block_types=up_block_types, - only_cross_attention=only_cross_attention, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - cross_attention_dim=cross_attention_dim, - transformer_layers_per_block=transformer_layers_per_block, - reverse_transformer_layers_per_block=reverse_transformer_layers_per_block, - attention_head_dim=attention_head_dim, - num_attention_heads=num_attention_heads, - ) - - # input - conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = nn.Conv2d( - in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding - ) - - # time - time_embed_dim, timestep_input_dim = self._set_time_proj( - time_embedding_type, - block_out_channels=block_out_channels, - flip_sin_to_cos=flip_sin_to_cos, - freq_shift=freq_shift, - time_embedding_dim=time_embedding_dim, - ) - - self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, - post_act_fn=timestep_post_act, - cond_proj_dim=time_cond_proj_dim, - ) - - self._set_encoder_hid_proj( - encoder_hid_dim_type, - cross_attention_dim=cross_attention_dim, - encoder_hid_dim=encoder_hid_dim, - ) - - # class embedding - self._set_class_embedding( - class_embed_type, - act_fn=act_fn, - num_class_embeds=num_class_embeds, - projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, - time_embed_dim=time_embed_dim, - timestep_input_dim=timestep_input_dim, - ) - - self._set_add_embedding( - addition_embed_type, - addition_embed_type_num_heads=addition_embed_type_num_heads, - addition_time_embed_dim=addition_time_embed_dim, - cross_attention_dim=cross_attention_dim, - encoder_hid_dim=encoder_hid_dim, - flip_sin_to_cos=flip_sin_to_cos, - freq_shift=freq_shift, - projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, - time_embed_dim=time_embed_dim, - ) - - if time_embedding_act_fn is None: - self.time_embed_act = None - else: - self.time_embed_act = get_activation(time_embedding_act_fn) - - self.down_blocks = nn.ModuleList([]) - self.up_blocks = nn.ModuleList([]) - - # set or unroll configs - if isinstance(only_cross_attention, bool): - if mid_block_only_cross_attention is None: - mid_block_only_cross_attention = only_cross_attention - - only_cross_attention = [only_cross_attention] * len(down_block_types) - - if mid_block_only_cross_attention is None: - mid_block_only_cross_attention = False - - if isinstance(num_attention_heads, int): - num_attention_heads = (num_attention_heads,) * len(down_block_types) - - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) - - if isinstance(cross_attention_dim, int): - cross_attention_dim = (cross_attention_dim,) * len(down_block_types) - - if isinstance(layers_per_block, int): - layers_per_block = [layers_per_block] * len(down_block_types) - - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) - - if class_embeddings_concat: - # The time embeddings are concatenated with the class embeddings. The dimension of the - # time embeddings passed to the down, middle, and up blocks is twice the dimension of the - # regular time embeddings - blocks_time_embed_dim = time_embed_dim * 2 - else: - blocks_time_embed_dim = time_embed_dim - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block[i], - transformer_layers_per_block=transformer_layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - temb_channels=blocks_time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim[i], - num_attention_heads=num_attention_heads[i], - downsample_padding=downsample_padding, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - attention_type=attention_type, - resnet_skip_time_act=resnet_skip_time_act, - resnet_out_scale_factor=resnet_out_scale_factor, - cross_attention_norm=cross_attention_norm, - attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, - dropout=dropout, - ) - self.down_blocks.append(down_block) - - # mid - self.mid_block = get_mid_block( - mid_block_type, - temb_channels=blocks_time_embed_dim, - in_channels=block_out_channels[-1], - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - output_scale_factor=mid_block_scale_factor, - transformer_layers_per_block=transformer_layers_per_block[-1], - num_attention_heads=num_attention_heads[-1], - cross_attention_dim=cross_attention_dim[-1], - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - mid_block_only_cross_attention=mid_block_only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - attention_type=attention_type, - resnet_skip_time_act=resnet_skip_time_act, - cross_attention_norm=cross_attention_norm, - attention_head_dim=attention_head_dim[-1], - dropout=dropout, - ) - - # count how many layers upsample the images - self.num_upsamplers = 0 - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - reversed_num_attention_heads = list(reversed(num_attention_heads)) - reversed_layers_per_block = list(reversed(layers_per_block)) - reversed_cross_attention_dim = list(reversed(cross_attention_dim)) - reversed_transformer_layers_per_block = ( - list(reversed(transformer_layers_per_block)) - if reverse_transformer_layers_per_block is None - else reverse_transformer_layers_per_block - ) - only_cross_attention = list(reversed(only_cross_attention)) - - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - is_final_block = i == len(block_out_channels) - 1 - - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - - # add upsample block for all BUT final layer - if not is_final_block: - add_upsample = True - self.num_upsamplers += 1 - else: - add_upsample = False - - up_block = get_up_block( - up_block_type, - num_layers=reversed_layers_per_block[i] + 1, - transformer_layers_per_block=reversed_transformer_layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=blocks_time_embed_dim, - add_upsample=add_upsample, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resolution_idx=i, - resnet_groups=norm_num_groups, - cross_attention_dim=reversed_cross_attention_dim[i], - num_attention_heads=reversed_num_attention_heads[i], - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - attention_type=attention_type, - resnet_skip_time_act=resnet_skip_time_act, - resnet_out_scale_factor=resnet_out_scale_factor, - cross_attention_norm=cross_attention_norm, - attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, - dropout=dropout, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - if norm_num_groups is not None: - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps - ) - self.conv_act = get_activation(act_fn) - else: - self.conv_norm_out = None - self.conv_act = None - - conv_out_padding = (conv_out_kernel - 1) // 2 - self.conv_out = nn.Conv2d( - block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding - ) - - self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) - - def _check_config( - self, - down_block_types: Tuple[str], - up_block_types: Tuple[str], - only_cross_attention: Union[bool, Tuple[bool]], - block_out_channels: Tuple[int], - layers_per_block: [int, Tuple[int]], - cross_attention_dim: Union[int, Tuple[int]], - transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]], - reverse_transformer_layers_per_block: bool, - attention_head_dim: int, - num_attention_heads: Optional[Union[int, Tuple[int]]], - ): - if len(down_block_types) != len(up_block_types): - raise ValueError( - f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." - ) - - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." - ) - - if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." - ) - if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: - for layer_number_per_block in transformer_layers_per_block: - if isinstance(layer_number_per_block, list): - raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") - - def _set_time_proj( - self, - time_embedding_type: str, - block_out_channels: int, - flip_sin_to_cos: bool, - freq_shift: float, - time_embedding_dim: int, - ) -> Tuple[int, int]: - if time_embedding_type == "fourier": - time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 - if time_embed_dim % 2 != 0: - raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") - self.time_proj = GaussianFourierProjection( - time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos - ) - timestep_input_dim = time_embed_dim - elif time_embedding_type == "positional": - time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 - - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - else: - raise ValueError( - f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." - ) - - return time_embed_dim, timestep_input_dim - - def _set_encoder_hid_proj( - self, - encoder_hid_dim_type: Optional[str], - cross_attention_dim: Union[int, Tuple[int]], - encoder_hid_dim: Optional[int], - ): - if encoder_hid_dim_type is None and encoder_hid_dim is not None: - encoder_hid_dim_type = "text_proj" - self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) - logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") - - if encoder_hid_dim is None and encoder_hid_dim_type is not None: - raise ValueError( - f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." - ) - - if encoder_hid_dim_type == "text_proj": - self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) - elif encoder_hid_dim_type == "text_image_proj": - # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` - self.encoder_hid_proj = TextImageProjection( - text_embed_dim=encoder_hid_dim, - image_embed_dim=cross_attention_dim, - cross_attention_dim=cross_attention_dim, - ) - elif encoder_hid_dim_type == "image_proj": - # Kandinsky 2.2 - self.encoder_hid_proj = ImageProjection( - image_embed_dim=encoder_hid_dim, - cross_attention_dim=cross_attention_dim, - ) - elif encoder_hid_dim_type is not None: - raise ValueError( - f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." - ) - else: - self.encoder_hid_proj = None - - def _set_class_embedding( - self, - class_embed_type: Optional[str], - act_fn: str, - num_class_embeds: Optional[int], - projection_class_embeddings_input_dim: Optional[int], - time_embed_dim: int, - timestep_input_dim: int, - ): - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - elif class_embed_type == "projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" - ) - # The projection `class_embed_type` is the same as the timestep `class_embed_type` except - # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings - # 2. it projects from an arbitrary input dimension. - # - # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. - # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. - # As a result, `TimestepEmbedding` can be passed arbitrary vectors. - self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - elif class_embed_type == "simple_projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" - ) - self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) - else: - self.class_embedding = None - - def _set_add_embedding( - self, - addition_embed_type: str, - addition_embed_type_num_heads: int, - addition_time_embed_dim: Optional[int], - flip_sin_to_cos: bool, - freq_shift: float, - cross_attention_dim: Optional[int], - encoder_hid_dim: Optional[int], - projection_class_embeddings_input_dim: Optional[int], - time_embed_dim: int, - ): - if addition_embed_type == "text": - if encoder_hid_dim is not None: - text_time_embedding_from_dim = encoder_hid_dim - else: - text_time_embedding_from_dim = cross_attention_dim - - self.add_embedding = TextTimeEmbedding( - text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads - ) - elif addition_embed_type == "text_image": - # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` - self.add_embedding = TextImageTimeEmbedding( - text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim - ) - elif addition_embed_type == "text_time": - self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) - self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - elif addition_embed_type == "image": - # Kandinsky 2.2 - self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) - elif addition_embed_type == "image_hint": - # Kandinsky 2.2 ControlNet - self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) - elif addition_embed_type is not None: - raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") - - def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): - if attention_type in ["gated", "gated-text-image"]: - positive_len = 768 - if isinstance(cross_attention_dim, int): - positive_len = cross_attention_dim - elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): - positive_len = cross_attention_dim[0] - - feature_type = "text-only" if attention_type == "gated" else "text-image" - self.position_net = GLIGENTextBoundingboxProjection( - positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type - ) - - @property - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. - """ - if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnAddedKVProcessor() - elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnProcessor() - else: - raise ValueError( - f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" - ) - - self.set_attn_processor(processor) - - def set_attention_slice(self, slice_size): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module splits the input tensor in slices to compute attention in - several steps. This is useful for saving some memory in exchange for a small decrease in speed. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If - `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - - for child in module.children(): - fn_recursive_retrieve_sliceable_dims(child) - - # retrieve number of attention layers - for module in self.children(): - fn_recursive_retrieve_sliceable_dims(module) - - num_sliceable_layers = len(sliceable_head_dims) - - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = [dim // 2 for dim in sliceable_head_dims] - elif slice_size == "max": - # make smallest slice possible - slice_size = num_sliceable_layers * [1] - - slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size - - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" - f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." - ) - - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - # Recursively walk through all the children. - # Any children which exposes the set_attention_slice method - # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - def enable_freeu(self, s1, s2, b1, b2): - r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. - - The suffixes after the scaling factors represent the stage blocks where they are being applied. - - Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that - are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. - - Args: - s1 (`float`): - Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to - mitigate the "oversmoothing effect" in the enhanced denoising process. - s2 (`float`): - Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to - mitigate the "oversmoothing effect" in the enhanced denoising process. - b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. - b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. - """ - for i, upsample_block in enumerate(self.up_blocks): - setattr(upsample_block, "s1", s1) - setattr(upsample_block, "s2", s2) - setattr(upsample_block, "b1", b1) - setattr(upsample_block, "b2", b2) - - def disable_freeu(self): - """Disables the FreeU mechanism.""" - freeu_keys = {"s1", "s2", "b1", "b2"} - for i, upsample_block in enumerate(self.up_blocks): - for k in freeu_keys: - if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: - setattr(upsample_block, k, None) - - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, - key, value) are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. - - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) - - def unload_lora(self): - """Unloads LoRA weights.""" - deprecate( - "unload_lora", - "0.28.0", - "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().", - ) - for module in self.modules(): - if hasattr(module, "set_lora_layer"): - module.set_lora_layer(None) - - def get_time_embed( - self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int] - ) -> Optional[torch.Tensor]: - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - # `Timesteps` does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=sample.dtype) - return t_emb - - def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]: - class_emb = None - if self.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - - # `Timesteps` does not contain any weights and will always return f32 tensors - # there might be better ways to encapsulate this. - class_labels = class_labels.to(dtype=sample.dtype) - - class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) - return class_emb - - def get_aug_embed( - self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict - ) -> Optional[torch.Tensor]: - aug_emb = None - if self.config.addition_embed_type == "text": - aug_emb = self.add_embedding(encoder_hidden_states) - elif self.config.addition_embed_type == "text_image": - # Kandinsky 2.1 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" - ) - - image_embs = added_cond_kwargs.get("image_embeds") - text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) - aug_emb = self.add_embedding(text_embs, image_embs) - elif self.config.addition_embed_type == "text_time": - # SDXL - style - if "text_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" - ) - text_embeds = added_cond_kwargs.get("text_embeds") - if "time_ids" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" - ) - time_ids = added_cond_kwargs.get("time_ids") - time_embeds = self.add_time_proj(time_ids.flatten()) - time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) - add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) - add_embeds = add_embeds.to(emb.dtype) - aug_emb = self.add_embedding(add_embeds) - elif self.config.addition_embed_type == "image": - # Kandinsky 2.2 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" - ) - image_embs = added_cond_kwargs.get("image_embeds") - aug_emb = self.add_embedding(image_embs) - elif self.config.addition_embed_type == "image_hint": - # Kandinsky 2.2 - style - if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" - ) - image_embs = added_cond_kwargs.get("image_embeds") - hint = added_cond_kwargs.get("hint") - aug_emb = self.add_embedding(image_embs, hint) - return aug_emb - - def process_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor, added_cond_kwargs) -> torch.Tensor: - if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": - encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": - # Kadinsky 2.1 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" - ) - - image_embeds = added_cond_kwargs.get("image_embeds") - encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": - # Kandinsky 2.2 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" - ) - image_embeds = added_cond_kwargs.get("image_embeds") - encoder_hidden_states = self.encoder_hid_proj(image_embeds) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" - ) - image_embeds = added_cond_kwargs.get("image_embeds") - image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) - encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) - return encoder_hidden_states - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - class_labels: Optional[torch.Tensor] = None, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, - down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, - mid_block_additional_residual: Optional[torch.Tensor] = None, - down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - return_dict: bool = True, - ) -> Union[UNet2DConditionOutput, Tuple]: - r""" - The [`UNet2DConditionModel`] forward method. - - Args: - sample (`torch.FloatTensor`): - The noisy input tensor with the following shape `(batch, channel, height, width)`. - timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. - encoder_hidden_states (`torch.FloatTensor`): - The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. - class_labels (`torch.Tensor`, *optional*, defaults to `None`): - Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. - timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): - Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed - through the `self.time_embedding` layer to obtain the timestep embeddings. - attention_mask (`torch.Tensor`, *optional*, defaults to `None`): - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask - is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large - negative values to the attention scores corresponding to "discard" tokens. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - added_cond_kwargs: (`dict`, *optional*): - A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that - are passed along to the UNet blocks. - down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): - A tuple of tensors that if specified are added to the residuals of down unet blocks. - mid_block_additional_residual: (`torch.Tensor`, *optional*): - A tensor that if specified is added to the residual of the middle unet block. - encoder_attention_mask (`torch.Tensor`): - A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If - `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, - which adds large negative values to the attention scores corresponding to "discard" tokens. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. - added_cond_kwargs: (`dict`, *optional*): - A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that - are passed along to the UNet blocks. - down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): - additional residuals to be added to UNet long skip connections from down blocks to up blocks for - example from ControlNet side model(s) - mid_block_additional_residual (`torch.Tensor`, *optional*): - additional residual to be added to UNet mid block output, for example from ControlNet side model - down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): - additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) - - Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise - a `tuple` is returned where the first element is the sample tensor. - """ - # By default samples have to be AT least a multiple of the overall upsampling factor. - # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). - # However, the upsampling interpolation output size can be forced to fit any upsampling size - # on the fly if necessary. - default_overall_up_factor = 2**self.num_upsamplers - - # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` - forward_upsample_size = False - upsample_size = None - - for dim in sample.shape[-2:]: - if dim % default_overall_up_factor != 0: - # Forward upsample size to force interpolation output size. - forward_upsample_size = True - break - - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension - # expects mask of shape: - # [batch, key_tokens] - # adds singleton query_tokens dimension: - # [batch, 1, key_tokens] - # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: - # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) - # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) - if attention_mask is not None: - # assume that mask is expressed as: - # (1 = keep, 0 = discard) - # convert mask into a bias that can be added to attention scores: - # (keep = +0, discard = -10000.0) - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None: - encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - - # 0. center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - - # 1. time - t_emb = self.get_time_embed(sample=sample, timestep=timestep) - emb = self.time_embedding(t_emb, timestep_cond) - aug_emb = None - - class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) - if class_emb is not None: - if self.config.class_embeddings_concat: - emb = torch.cat([emb, class_emb], dim=-1) - else: - emb = emb + class_emb - - aug_emb = self.get_aug_embed( - emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs - ) - if self.config.addition_embed_type == "image_hint": - aug_emb, hint = aug_emb - sample = torch.cat([sample, hint], dim=1) - emb = emb + aug_emb if aug_emb is not None else emb - - if self.time_embed_act is not None: - emb = self.time_embed_act(emb) - - encoder_hidden_states = self.process_encoder_hidden_states( - encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs - ) - - # 2. pre-process - sample = self.conv_in(sample) - - # 2.5 GLIGEN position net - if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: - cross_attention_kwargs = cross_attention_kwargs.copy() - gligen_args = cross_attention_kwargs.pop("gligen") - cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} - - # 3. down - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - - is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None - # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets - is_adapter = down_intrablock_additional_residuals is not None - # maintain backward compatibility for legacy usage, where - # T2I-Adapter and ControlNet both use down_block_additional_residuals arg - # but can only use one or the other - if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: - deprecate( - "T2I should not use down_block_additional_residuals", - "1.3.0", - "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ - and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ - for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", - standard_warn=False, - ) - down_intrablock_additional_residuals = down_block_additional_residuals - is_adapter = True - - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - # For t2i-adapter CrossAttnDownBlock2D - additional_residuals = {} - if is_adapter and len(down_intrablock_additional_residuals) > 0: - additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) - - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - **additional_residuals, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) - if is_adapter and len(down_intrablock_additional_residuals) > 0: - sample += down_intrablock_additional_residuals.pop(0) - - down_block_res_samples += res_samples - - if is_controlnet: - new_down_block_res_samples = () - - for down_block_res_sample, down_block_additional_residual in zip( - down_block_res_samples, down_block_additional_residuals - ): - down_block_res_sample = down_block_res_sample + down_block_additional_residual - new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) - - down_block_res_samples = new_down_block_res_samples - - # 4. mid - if self.mid_block is not None: - if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - ) - else: - sample = self.mid_block(sample, emb) - - # To support T2I-Adapter-XL - if ( - is_adapter - and len(down_intrablock_additional_residuals) > 0 - and sample.shape == down_intrablock_additional_residuals[0].shape - ): - sample += down_intrablock_additional_residuals.pop(0) - - if is_controlnet: - sample = sample + mid_block_additional_residual - - # 5. up - for i, upsample_block in enumerate(self.up_blocks): - is_final_block = i == len(self.up_blocks) - 1 - - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - - # if we have not reached the final block and need to forward the - # upsample size, we do it here - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - - if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - upsample_size=upsample_size, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - ) - else: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - upsample_size=upsample_size, - scale=lora_scale, - ) - - # 6. post-process - if self.conv_norm_out: - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (sample,) - - return UNet2DConditionOutput(sample=sample) +class UNet2DConditionModel(UNet2DConditionModel): + deprecation_message = "Importing `UNet2DConditionModel` from `diffusers.models.unet_2d_condition` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel`, instead." + deprecate("UNet2DConditionModel", "0.29", deprecation_message) diff --git a/src/diffusers/models/unets/__init__.py b/src/diffusers/models/unets/__init__.py new file mode 100644 index 000000000000..5b1418a608f0 --- /dev/null +++ b/src/diffusers/models/unets/__init__.py @@ -0,0 +1,16 @@ +from ...utils import is_flax_available, is_torch_available + + +if is_torch_available(): + from .unet_1d import UNet1DModel + from .unet_2d import UNet2DModel + from .unet_2d_condition import UNet2DConditionModel + from .unet_3d_condition import UNet3DConditionModel + from .unet_kandinsky3 import Kandinsky3UNet + from .unet_motion_model import MotionAdapter, UNetMotionModel + from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel + from .uvit_2d import UVit2DModel + + +if is_flax_available(): + from .unet_2d_condition_flax import FlaxUNet2DConditionModel diff --git a/src/diffusers/models/unets/unet_1d.py b/src/diffusers/models/unets/unet_1d.py new file mode 100644 index 000000000000..131f05f735cd --- /dev/null +++ b/src/diffusers/models/unets/unet_1d.py @@ -0,0 +1,255 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block + + +@dataclass +class UNet1DOutput(BaseOutput): + """ + The output of [`UNet1DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`): + The hidden states output from the last layer of the model. + """ + + sample: torch.FloatTensor + + +class UNet1DModel(ModelMixin, ConfigMixin): + r""" + A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime. + in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. + extra_in_channels (`int`, *optional*, defaults to 0): + Number of additional channels to be added to the input of the first down block. Useful for cases where the + input data has more channels than what the model was initially designed for. + time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. + freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip sin to cos for Fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`): + Tuple of block output channels. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet. + out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet. + act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks. + norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization. + layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block. + downsample_each_block (`int`, *optional*, defaults to `False`): + Experimental feature for using a UNet without upsampling. + """ + + @register_to_config + def __init__( + self, + sample_size: int = 65536, + sample_rate: Optional[int] = None, + in_channels: int = 2, + out_channels: int = 2, + extra_in_channels: int = 0, + time_embedding_type: str = "fourier", + flip_sin_to_cos: bool = True, + use_timestep_embedding: bool = False, + freq_shift: float = 0.0, + down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), + up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), + mid_block_type: Tuple[str] = "UNetMidBlock1D", + out_block_type: str = None, + block_out_channels: Tuple[int] = (32, 32, 64), + act_fn: str = None, + norm_num_groups: int = 8, + layers_per_block: int = 1, + downsample_each_block: bool = False, + ): + super().__init__() + self.sample_size = sample_size + + # time + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection( + embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift + ) + timestep_input_dim = block_out_channels[0] + + if use_timestep_embedding: + time_embed_dim = block_out_channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=timestep_input_dim, + time_embed_dim=time_embed_dim, + act_fn=act_fn, + out_dim=block_out_channels[0], + ) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + self.out_block = None + + # down + output_channel = in_channels + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + + if i == 0: + input_channel += extra_in_channels + + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=block_out_channels[0], + add_downsample=not is_final_block or downsample_each_block, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = get_mid_block( + mid_block_type, + in_channels=block_out_channels[-1], + mid_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + embed_dim=block_out_channels[0], + num_layers=layers_per_block, + add_downsample=downsample_each_block, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + if out_block_type is None: + final_upsample_channels = out_channels + else: + final_upsample_channels = block_out_channels[0] + + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = ( + reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels + ) + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block, + in_channels=prev_output_channel, + out_channels=output_channel, + temb_channels=block_out_channels[0], + add_upsample=not is_final_block, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) + self.out_block = get_out_block( + out_block_type=out_block_type, + num_groups_out=num_groups_out, + embed_dim=block_out_channels[0], + out_channels=out_channels, + act_fn=act_fn, + fc_dim=block_out_channels[-1] // 4, + ) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + return_dict: bool = True, + ) -> Union[UNet1DOutput, Tuple]: + r""" + The [`UNet1DModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_1d.UNet1DOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is + returned where the first element is the sample tensor. + """ + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + timestep_embed = self.time_proj(timesteps) + if self.config.use_timestep_embedding: + timestep_embed = self.time_mlp(timestep_embed) + else: + timestep_embed = timestep_embed[..., None] + timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) + timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:])) + + # 2. down + down_block_res_samples = () + for downsample_block in self.down_blocks: + sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed) + down_block_res_samples += res_samples + + # 3. mid + if self.mid_block: + sample = self.mid_block(sample, timestep_embed) + + # 4. up + for i, upsample_block in enumerate(self.up_blocks): + res_samples = down_block_res_samples[-1:] + down_block_res_samples = down_block_res_samples[:-1] + sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed) + + # 5. post-process + if self.out_block: + sample = self.out_block(sample, timestep_embed) + + if not return_dict: + return (sample,) + + return UNet1DOutput(sample=sample) diff --git a/src/diffusers/models/unets/unet_1d_blocks.py b/src/diffusers/models/unets/unet_1d_blocks.py new file mode 100644 index 000000000000..3e128bf727c0 --- /dev/null +++ b/src/diffusers/models/unets/unet_1d_blocks.py @@ -0,0 +1,702 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ..activations import get_activation +from ..resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims + + +class DownResnetBlock1D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + conv_shortcut: bool = False, + temb_channels: int = 32, + groups: int = 32, + groups_out: Optional[int] = None, + non_linearity: Optional[str] = None, + time_embedding_norm: str = "default", + output_scale_factor: float = 1.0, + add_downsample: bool = True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.add_downsample = add_downsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity is None: + self.nonlinearity = None + else: + self.nonlinearity = get_activation(non_linearity) + + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) + + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + output_states = () + + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + + if self.downsample is not None: + hidden_states = self.downsample(hidden_states) + + return hidden_states, output_states + + +class UpResnetBlock1D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + temb_channels: int = 32, + groups: int = 32, + groups_out: Optional[int] = None, + non_linearity: Optional[str] = None, + time_embedding_norm: str = "default", + output_scale_factor: float = 1.0, + add_upsample: bool = True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.time_embedding_norm = time_embedding_norm + self.add_upsample = add_upsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity is None: + self.nonlinearity = None + else: + self.nonlinearity = get_activation(non_linearity) + + self.upsample = None + if add_upsample: + self.upsample = Upsample1D(out_channels, use_conv_transpose=True) + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None, + temb: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + if res_hidden_states_tuple is not None: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) + + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + hidden_states = self.upsample(hidden_states) + + return hidden_states + + +class ValueFunctionMidBlock1D(nn.Module): + def __init__(self, in_channels: int, out_channels: int, embed_dim: int): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + + self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim) + self.down1 = Downsample1D(out_channels // 2, use_conv=True) + self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim) + self.down2 = Downsample1D(out_channels // 4, use_conv=True) + + def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + x = self.res1(x, temb) + x = self.down1(x) + x = self.res2(x, temb) + x = self.down2(x) + return x + + +class MidResTemporalBlock1D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + embed_dim: int, + num_layers: int = 1, + add_downsample: bool = False, + add_upsample: bool = False, + non_linearity: Optional[str] = None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_downsample = add_downsample + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity is None: + self.nonlinearity = None + else: + self.nonlinearity = get_activation(non_linearity) + + self.upsample = None + if add_upsample: + self.upsample = Downsample1D(out_channels, use_conv=True) + + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True) + + if self.upsample and self.downsample: + raise ValueError("Block cannot downsample and upsample") + + def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor: + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + if self.upsample: + hidden_states = self.upsample(hidden_states) + if self.downsample: + self.downsample = self.downsample(hidden_states) + + return hidden_states + + +class OutConv1DBlock(nn.Module): + def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str): + super().__init__() + self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) + self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) + self.final_conv1d_act = get_activation(act_fn) + self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) + + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + hidden_states = self.final_conv1d_1(hidden_states) + hidden_states = rearrange_dims(hidden_states) + hidden_states = self.final_conv1d_gn(hidden_states) + hidden_states = rearrange_dims(hidden_states) + hidden_states = self.final_conv1d_act(hidden_states) + hidden_states = self.final_conv1d_2(hidden_states) + return hidden_states + + +class OutValueFunctionBlock(nn.Module): + def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"): + super().__init__() + self.final_block = nn.ModuleList( + [ + nn.Linear(fc_dim + embed_dim, fc_dim // 2), + get_activation(act_fn), + nn.Linear(fc_dim // 2, 1), + ] + ) + + def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor: + hidden_states = hidden_states.view(hidden_states.shape[0], -1) + hidden_states = torch.cat((hidden_states, temb), dim=-1) + for layer in self.final_block: + hidden_states = layer(hidden_states) + + return hidden_states + + +_kernels = { + "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], + "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875], + "lanczos3": [ + 0.003689131001010537, + 0.015056144446134567, + -0.03399861603975296, + -0.066637322306633, + 0.13550527393817902, + 0.44638532400131226, + 0.44638532400131226, + 0.13550527393817902, + -0.066637322306633, + -0.03399861603975296, + 0.015056144446134567, + 0.003689131001010537, + ], +} + + +class Downsample1d(nn.Module): + def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer("kernel", kernel_1d) + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) + weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) + indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) + kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) + weight[indices, indices] = kernel + return F.conv1d(hidden_states, weight, stride=2) + + +class Upsample1d(nn.Module): + def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) * 2 + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer("kernel", kernel_1d) + + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) + weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) + indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) + kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) + weight[indices, indices] = kernel + return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1) + + +class SelfAttention1d(nn.Module): + def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0): + super().__init__() + self.channels = in_channels + self.group_norm = nn.GroupNorm(1, num_channels=in_channels) + self.num_heads = n_head + + self.query = nn.Linear(self.channels, self.channels) + self.key = nn.Linear(self.channels, self.channels) + self.value = nn.Linear(self.channels, self.channels) + + self.proj_attn = nn.Linear(self.channels, self.channels, bias=True) + + self.dropout = nn.Dropout(dropout_rate, inplace=True) + + def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: + new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) + new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) + return new_projection + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + residual = hidden_states + batch, channel_dim, seq = hidden_states.shape + + hidden_states = self.group_norm(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + + scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1])) + + attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) + attention_probs = torch.softmax(attention_scores, dim=-1) + + # compute attention output + hidden_states = torch.matmul(attention_probs, value_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) + hidden_states = hidden_states.view(new_hidden_states_shape) + + # compute next hidden_states + hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.dropout(hidden_states) + + output = hidden_states + residual + + return output + + +class ResConvBlock(nn.Module): + def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False): + super().__init__() + self.is_last = is_last + self.has_conv_skip = in_channels != out_channels + + if self.has_conv_skip: + self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False) + + self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2) + self.group_norm_1 = nn.GroupNorm(1, mid_channels) + self.gelu_1 = nn.GELU() + self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2) + + if not self.is_last: + self.group_norm_2 = nn.GroupNorm(1, out_channels) + self.gelu_2 = nn.GELU() + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states + + hidden_states = self.conv_1(hidden_states) + hidden_states = self.group_norm_1(hidden_states) + hidden_states = self.gelu_1(hidden_states) + hidden_states = self.conv_2(hidden_states) + + if not self.is_last: + hidden_states = self.group_norm_2(hidden_states) + hidden_states = self.gelu_2(hidden_states) + + output = hidden_states + residual + return output + + +class UNetMidBlock1D(nn.Module): + def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None): + super().__init__() + + out_channels = in_channels if out_channels is None else out_channels + + # there is always at least one resnet + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + self.up = Upsample1d(kernel="cubic") + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + hidden_states = self.down(hidden_states) + for attn, resnet in zip(self.attentions, self.resnets): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class AttnDownBlock1D(nn.Module): + def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + hidden_states = self.down(hidden_states) + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + return hidden_states, (hidden_states,) + + +class DownBlock1D(nn.Module): + def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + hidden_states = self.down(hidden_states) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states, (hidden_states,) + + +class DownBlock1DNoSkip(nn.Module): + def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + hidden_states = torch.cat([hidden_states, temb], dim=1) + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states, (hidden_states,) + + +class AttnUpBlock1D(nn.Module): + def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.up = Upsample1d(kernel="cubic") + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class UpBlock1D(nn.Module): + def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + self.up = Upsample1d(kernel="cubic") + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class UpBlock1DNoSkip(nn.Module): + def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True), + ] + + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states + + +DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip] +MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D] +OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock] +UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip] + + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, +) -> DownBlockType: + if down_block_type == "DownResnetBlock1D": + return DownResnetBlock1D( + in_channels=in_channels, + num_layers=num_layers, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + ) + elif down_block_type == "DownBlock1D": + return DownBlock1D(out_channels=out_channels, in_channels=in_channels) + elif down_block_type == "AttnDownBlock1D": + return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels) + elif down_block_type == "DownBlock1DNoSkip": + return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool +) -> UpBlockType: + if up_block_type == "UpResnetBlock1D": + return UpResnetBlock1D( + in_channels=in_channels, + num_layers=num_layers, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + ) + elif up_block_type == "UpBlock1D": + return UpBlock1D(in_channels=in_channels, out_channels=out_channels) + elif up_block_type == "AttnUpBlock1D": + return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels) + elif up_block_type == "UpBlock1DNoSkip": + return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels) + raise ValueError(f"{up_block_type} does not exist.") + + +def get_mid_block( + mid_block_type: str, + num_layers: int, + in_channels: int, + mid_channels: int, + out_channels: int, + embed_dim: int, + add_downsample: bool, +) -> MidBlockType: + if mid_block_type == "MidResTemporalBlock1D": + return MidResTemporalBlock1D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + embed_dim=embed_dim, + add_downsample=add_downsample, + ) + elif mid_block_type == "ValueFunctionMidBlock1D": + return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim) + elif mid_block_type == "UNetMidBlock1D": + return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels) + raise ValueError(f"{mid_block_type} does not exist.") + + +def get_out_block( + *, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int +) -> Optional[OutBlockType]: + if out_block_type == "OutConv1DBlock": + return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) + elif out_block_type == "ValueFunction": + return OutValueFunctionBlock(fc_dim, embed_dim, act_fn) + return None diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py new file mode 100644 index 000000000000..0a4ede51a7fd --- /dev/null +++ b/src/diffusers/models/unets/unet_2d.py @@ -0,0 +1,346 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block + + +@dataclass +class UNet2DOutput(BaseOutput): + """ + The output of [`UNet2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output from the last layer of the model. + """ + + sample: torch.FloatTensor + + +class UNet2DModel(ModelMixin, ConfigMixin): + r""" + A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) - + 1)`. + in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. + freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip sin to cos for Fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): + Tuple of downsample block types. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`): + Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`. + up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`): + Tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. + mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. + downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. + downsample_type (`str`, *optional*, defaults to `conv`): + The downsample type for downsampling layers. Choose between "conv" and "resnet" + upsample_type (`str`, *optional*, defaults to `conv`): + The upsample type for upsampling layers. Choose between "conv" and "resnet" + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. + norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization. + attn_norm_num_groups (`int`, *optional*, defaults to `None`): + If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the + given number of groups. If left as `None`, the group norm layer will only be created if + `resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups. + norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization. + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, or `"identity"`. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class + conditioning with `class_embed_type` equal to `None`. + """ + + @register_to_config + def __init__( + self, + sample_size: Optional[Union[int, Tuple[int, int]]] = None, + in_channels: int = 3, + out_channels: int = 3, + center_input_sample: bool = False, + time_embedding_type: str = "positional", + freq_shift: int = 0, + flip_sin_to_cos: bool = True, + down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), + up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), + block_out_channels: Tuple[int] = (224, 448, 672, 896), + layers_per_block: int = 2, + mid_block_scale_factor: float = 1, + downsample_padding: int = 1, + downsample_type: str = "conv", + upsample_type: str = "conv", + dropout: float = 0.0, + act_fn: str = "silu", + attention_head_dim: Optional[int] = 8, + norm_num_groups: int = 32, + attn_norm_num_groups: Optional[int] = None, + norm_eps: float = 1e-5, + resnet_time_scale_shift: str = "default", + add_attention: bool = True, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + num_train_timesteps: Optional[int] = None, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + # input + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + elif time_embedding_type == "learned": + self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0]) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1], + resnet_groups=norm_num_groups, + attn_groups=attn_norm_num_groups, + add_attention=add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, + dropout=dropout, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + class_labels: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DOutput, Tuple]: + r""" + The [`UNet2DModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + class_labels (`torch.FloatTensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d.UNet2DOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is + returned where the first element is the sample tensor. + """ + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when doing class conditioning") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + elif self.class_embedding is None and class_labels is not None: + raise ValueError("class_embedding needs to be initialized in order to use class conditioning") + + # 2. pre-process + skip_sample = sample + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "skip_conv"): + sample, res_samples, skip_sample = downsample_block( + hidden_states=sample, temb=emb, skip_sample=skip_sample + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb) + + # 5. up + skip_sample = None + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if hasattr(upsample_block, "skip_conv"): + sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample) + else: + sample = upsample_block(sample, res_samples, emb) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if skip_sample is not None: + sample += skip_sample + + if self.config.time_embedding_type == "fourier": + timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:])))) + sample = sample / timesteps + + if not return_dict: + return (sample,) + + return UNet2DOutput(sample=sample) diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py new file mode 100644 index 000000000000..d933691d89d3 --- /dev/null +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -0,0 +1,3591 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from ...utils import is_torch_version, logging +from ...utils.torch_utils import apply_freeu +from ..activations import get_activation +from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 +from ..dual_transformer_2d import DualTransformer2DModel +from ..normalization import AdaGroupNorm +from ..resnet import ( + Downsample2D, + FirDownsample2D, + FirUpsample2D, + KDownsample2D, + KUpsample2D, + ResnetBlock2D, + ResnetBlockCondNorm2D, + Upsample2D, +) +from ..transformer_2d import Transformer2DModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + downsample_padding: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + downsample_type: Optional[str] = None, + dropout: float = 0.0, +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "ResnetDownsampleBlock2D": + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif down_block_type == "AttnDownBlock2D": + if add_downsample is False: + downsample_type = None + else: + downsample_type = downsample_type or "conv" # default to 'conv' + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + ) + elif down_block_type == "SimpleCrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnDownEncoderBlock2D": + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "KDownBlock2D": + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif down_block_type == "KCrossAttnDownBlock2D": + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + add_self_attention=True if not add_downsample else False, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + resolution_idx: Optional[int] = None, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + upsample_type: Optional[str] = None, + dropout: float = 0.0, +) -> nn.Module: + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "ResnetUpsampleBlock2D": + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + ) + elif up_block_type == "SimpleCrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif up_block_type == "AttnUpBlock2D": + if add_upsample is False: + upsample_type = None + else: + upsample_type = upsample_type or "conv" # default to 'conv' + + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "AttnUpDecoderBlock2D": + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "KUpBlock2D": + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "KCrossAttnUpBlock2D": + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +class AutoencoderTinyBlock(nn.Module): + """ + Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU + blocks. + + Args: + in_channels (`int`): The number of input channels. + out_channels (`int`): The number of output channels. + act_fn (`str`): + ` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`. + + Returns: + `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to + `out_channels`. + """ + + def __init__(self, in_channels: int, out_channels: int, act_fn: str): + super().__init__() + act_fn = get_activation(act_fn) + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + act_fn, + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + act_fn, + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + ) + self.skip = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + if in_channels != out_channels + else nn.Identity() + ) + self.fuse = nn.ReLU() + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + return self.fuse(self.conv(x) + self.skip(x)) + + +class UNetMidBlock2D(nn.Module): + """ + A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks. + + Args: + in_channels (`int`): The number of input channels. + temb_channels (`int`): The number of temporal embedding channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_time_scale_shift (`str`, *optional*, defaults to `default`): + The type of normalization to apply to the time embeddings. This can help to improve the performance of the + model on tasks with long-range temporal dependencies. + resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks. + resnet_pre_norm (`bool`, *optional*, defaults to `True`): + Whether to use pre-normalization for the resnet blocks. + add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks. + attention_head_dim (`int`, *optional*, defaults to 1): + Dimension of a single attention head. The number of attention heads is determined based on this value and + the number of input channels. + output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. + + """ + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: Optional[int] = None, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + if attn_groups is None: + attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None + + # there is always at least one resnet + if resnet_time_scale_shift == "spatial": + resnets = [ + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ] + else: + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=attn_groups, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for i in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + + return hidden_states + + +class UNetMidBlock2DSimpleCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, + ): + super().__init__() + + self.has_cross_attention = True + + self.attention_head_dim = attention_head_dim + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.num_heads = in_channels // self.attention_head_dim + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ] + attentions = [] + + for _ in range(num_layers): + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + lora_scale = cross_attention_kwargs.get("scale", 1.0) + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + # attn + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + # resnet + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + + return hidden_states + + +class AttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + downsample_type: str = "conv", + ): + super().__init__() + resnets = [] + attentions = [] + self.downsample_type = downsample_type + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if downsample_type == "conv": + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + elif downsample_type == "resnet": + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + lora_scale = cross_attention_kwargs.get("scale", 1.0) + + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + cross_attention_kwargs.update({"scale": lora_scale}) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn(hidden_states, **cross_attention_kwargs) + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + if self.downsample_type == "resnet": + hidden_states = downsampler(hidden_states, temb=temb, scale=lora_scale) + else: + hidden_states = downsampler(hidden_states, scale=lora_scale) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + additional_residuals: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=lora_scale) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=scale) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class DownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None, scale=scale) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale) + + return hidden_states + + +class AttnDownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + attentions = [] + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None, scale=scale) + cross_attention_kwargs = {"scale": scale} + hidden_states = attn(hidden_states, **cross_attention_kwargs) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale) + + return hidden_states + + +class AttnSkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = np.sqrt(2.0), + add_downsample: bool = True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + skip_sample: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb, scale=scale) + cross_attention_kwargs = {"scale": scale} + hidden_states = attn(hidden_states, **cross_attention_kwargs) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb, scale=scale) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class SkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor: float = np.sqrt(2.0), + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + skip_sample: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb, scale) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb, scale) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class ResnetDownsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + skip_time_act: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb, scale) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class SimpleCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, + ): + super().__init__() + + self.has_cross_attention = True + + resnets = [] + attentions = [] + + self.attention_head_dim = attention_head_dim + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + lora_scale = cross_attention_kwargs.get("scale", 1.0) + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb, scale=lora_scale) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class KDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int = 32, + add_downsample: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + # YiYi's comments- might be able to use FirDownsample2D, look into details later + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + +class KCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + cross_attention_dim: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_group_size: int = 32, + add_downsample: bool = True, + attention_head_dim: int = 64, + add_self_attention: bool = False, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + out_channels, + out_channels // attention_head_dim, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm="layer_norm", + group_size=resnet_group_size, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if self.downsamplers is None: + output_states += (None,) + else: + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + +class AttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: int = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + upsample_type: str = "conv", + ): + super().__init__() + resnets = [] + attentions = [] + + self.upsample_type = upsample_type + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if upsample_type == "conv": + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + elif upsample_type == "resnet": + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb, scale=scale) + cross_attention_kwargs = {"scale": scale} + hidden_states = attn(hidden_states, **cross_attention_kwargs) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + if self.upsample_type == "resnet": + hidden_states = upsampler(hidden_states, temb=temb, scale=scale) + else: + hidden_states = upsampler(hidden_states, scale=scale) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=scale) + + return hidden_states + + +class UpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + ) -> torch.FloatTensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=temb, scale=scale) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnUpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, + ): + super().__init__() + resnets = [] + attentions = [] + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + ) -> torch.FloatTensor: + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=temb, scale=scale) + cross_attention_kwargs = {"scale": scale} + hidden_states = attn(hidden_states, temb=temb, **cross_attention_kwargs) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, scale=scale) + + return hidden_states + + +class AttnSkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = np.sqrt(2.0), + add_upsample: bool = True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(resnet_in_channels + res_skip_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + self.attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + skip_sample=None, + scale: float = 1.0, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb, scale=scale) + + cross_attention_kwargs = {"scale": scale} + hidden_states = self.attentions[0](hidden_states, **cross_attention_kwargs) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb, scale=scale) + + return hidden_states, skip_sample + + +class SkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor: float = np.sqrt(2.0), + add_upsample: bool = True, + upsample_padding: int = 1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min((resnet_in_channels + res_skip_channels) // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + skip_sample=None, + scale: float = 1.0, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb, scale=scale) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb, scale=scale) + + return hidden_states, skip_sample + + +class ResnetUpsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + skip_time_act: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb, scale=scale) + + return hidden_states + + +class SimpleCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + lora_scale = cross_attention_kwargs.get("scale", 1.0) + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + for resnet, attn in zip(self.resnets, self.attentions): + # resnet + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb, scale=lora_scale) + + return hidden_states + + +class KUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + resolution_idx: int, + dropout: float = 0.0, + num_layers: int = 5, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: Optional[int] = 32, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + k_in_channels = 2 * out_channels + k_out_channels = in_channels + num_layers = num_layers - 1 + + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=k_out_channels if (i == num_layers - 1) else out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class KCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + resolution_idx: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int = 32, + attention_head_dim: int = 1, # attention dim_head + cross_attention_dim: int = 768, + add_upsample: bool = True, + upcast_attention: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + + is_first_block = in_channels == out_channels == temb_channels + is_middle_block = in_channels != out_channels + add_self_attention = True if is_first_block else False + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + # in_channels, and out_channels for the block (k-unet) + k_in_channels = out_channels if is_first_block else 2 * out_channels + k_out_channels = in_channels + + num_layers = num_layers - 1 + + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + if is_middle_block and (i == num_layers - 1): + conv_2d_out_channels = k_out_channels + else: + conv_2d_out_channels = None + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + conv_2d_out_channels=conv_2d_out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + k_out_channels if (i == num_layers - 1) else out_channels, + k_out_channels // attention_head_dim + if (i == num_layers - 1) + else out_channels // attention_head_dim, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm="layer_norm", + upcast_attention=upcast_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +# can potentially later be renamed to `No-feed-forward` attention +class KAttentionBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + attention_bias (`bool`, *optional*, defaults to `False`): + Configure if the attention layers should contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to upcast the attention computation to `float32`. + temb_channels (`int`, *optional*, defaults to 768): + The number of channels in the token embedding. + add_self_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to add self-attention to the block. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + group_size (`int`, *optional*, defaults to 32): + The number of groups to separate the channels into for group normalization. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + upcast_attention: bool = False, + temb_channels: int = 768, # for ada_group_norm + add_self_attention: bool = False, + cross_attention_norm: Optional[str] = None, + group_size: int = 32, + ): + super().__init__() + self.add_self_attention = add_self_attention + + # 1. Self-Attn + if add_self_attention: + self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + cross_attention_norm=None, + ) + + # 2. Cross-Attn + self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + cross_attention_norm=cross_attention_norm, + ) + + def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor: + return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1) + + def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor: + return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + # TODO: mark emb as non-optional (self.norm2 requires it). + # requires assessing impact of change to positional param interface. + emb: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + # 1. Self-Attention + if self.add_self_attention: + norm_hidden_states = self.norm1(hidden_states, emb) + + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention/None + norm_hidden_states = self.norm2(hidden_states, emb) + + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + + hidden_states = attn_output + hidden_states + + return hidden_states diff --git a/src/diffusers/models/unet_2d_blocks_flax.py b/src/diffusers/models/unets/unet_2d_blocks_flax.py similarity index 99% rename from src/diffusers/models/unet_2d_blocks_flax.py rename to src/diffusers/models/unets/unet_2d_blocks_flax.py index 8cf2f8eb24b4..447efcd8c138 100644 --- a/src/diffusers/models/unet_2d_blocks_flax.py +++ b/src/diffusers/models/unets/unet_2d_blocks_flax.py @@ -15,8 +15,8 @@ import flax.linen as nn import jax.numpy as jnp -from .attention_flax import FlaxTransformer2DModel -from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D +from ..attention_flax import FlaxTransformer2DModel +from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D class FlaxCrossAttnDownBlock2D(nn.Module): diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py new file mode 100644 index 000000000000..87297b5b5d0b --- /dev/null +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -0,0 +1,1218 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ..activations import get_activation +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from ..embeddings import ( + GaussianFourierProjection, + GLIGENTextBoundingboxProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from ..modeling_utils import ModelMixin +from .unet_2d_blocks import ( + UNetMidBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None + + +class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, + *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, + *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type == "UNetMidBlock2D": + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): + positive_len = cross_attention_dim[0] + + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = GLIGENTextBoundingboxProjection( + positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def enable_freeu(self, s1, s2, b1, b2): + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def unload_lora(self): + """Unloads LoRA weights.""" + deprecate( + "unload_lora", + "0.28.0", + "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().", + ) + for module in self.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added to UNet long skip connections from down blocks to up blocks for + example from ControlNet side model(s) + mid_block_additional_residual (`torch.Tensor`, *optional*): + additional residual to be added to UNet mid block output, for example from ControlNet side model + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) + encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) + + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=lora_scale, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unets/unet_2d_condition_flax.py similarity index 97% rename from src/diffusers/models/unet_2d_condition_flax.py rename to src/diffusers/models/unets/unet_2d_condition_flax.py index 13f53e16e7ac..0c17777f1a51 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unets/unet_2d_condition_flax.py @@ -19,10 +19,10 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict -from ..configuration_utils import ConfigMixin, flax_register_to_config -from ..utils import BaseOutput -from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps -from .modeling_flax_utils import FlaxModelMixin +from ...configuration_utils import ConfigMixin, flax_register_to_config +from ...utils import BaseOutput +from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps +from ..modeling_flax_utils import FlaxModelMixin from .unet_2d_blocks_flax import ( FlaxCrossAttnDownBlock2D, FlaxCrossAttnUpBlock2D, @@ -342,14 +342,14 @@ def __call__( mid_block_additional_residual: (`torch.Tensor`, *optional*): A tensor that if specified is added to the residual of the middle unet block. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a + Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a plain tuple. train (`bool`, *optional*, defaults to `False`): Use deterministic functions and disable dropout when not training. Returns: - [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. + [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: + [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ # 1. time diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py similarity index 99% rename from src/diffusers/models/unet_3d_blocks.py rename to src/diffusers/models/unets/unet_3d_blocks.py index e9c505c347b0..6c20b1175349 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -17,19 +17,19 @@ import torch from torch import nn -from ..utils import is_torch_version -from ..utils.torch_utils import apply_freeu -from .attention import Attention -from .dual_transformer_2d import DualTransformer2DModel -from .resnet import ( +from ...utils import is_torch_version +from ...utils.torch_utils import apply_freeu +from ..attention import Attention +from ..dual_transformer_2d import DualTransformer2DModel +from ..resnet import ( Downsample2D, ResnetBlock2D, SpatioTemporalResBlock, TemporalConvLayer, Upsample2D, ) -from .transformer_2d import Transformer2DModel -from .transformer_temporal import ( +from ..transformer_2d import Transformer2DModel +from ..transformer_temporal import ( TransformerSpatioTemporalModel, TransformerTemporalModel, ) diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py similarity index 96% rename from src/diffusers/models/unet_3d_condition.py rename to src/diffusers/models/unets/unet_3d_condition.py index fc8695e064b5..b29e2c270ba9 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -20,20 +20,20 @@ import torch.nn as nn import torch.utils.checkpoint -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import UNet2DConditionLoadersMixin -from ..utils import BaseOutput, deprecate, logging -from .activations import get_activation -from .attention_processor import ( +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...utils import BaseOutput, deprecate, logging +from ..activations import get_activation +from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, ) -from .embeddings import TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin -from .transformer_temporal import TransformerTemporalModel +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..transformer_temporal import TransformerTemporalModel from .unet_3d_blocks import ( CrossAttnDownBlock3D, CrossAttnUpBlock3D, @@ -284,7 +284,7 @@ def __init__( ) @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -308,7 +308,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: r""" Enable sliced attention computation. @@ -374,7 +374,7 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -449,7 +449,7 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int for module in self.children(): fn_recursive_feed_forward(module, None, 0) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. @@ -469,7 +469,7 @@ def _set_gradient_checkpointing(self, module, value: bool = False) -> None: if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): module.gradient_checkpointing = value - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu def enable_freeu(self, s1, s2, b1, b2): r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. @@ -494,7 +494,7 @@ def enable_freeu(self, s1, s2, b1, b2): setattr(upsample_block, "b1", b1) setattr(upsample_block, "b2", b2) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu def disable_freeu(self): """Disables the FreeU mechanism.""" freeu_keys = {"s1", "s2", "b1", "b2"} @@ -503,7 +503,7 @@ def disable_freeu(self): if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: setattr(upsample_block, k, None) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unload_lora + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unload_lora def unload_lora(self): """Unloads LoRA weights.""" deprecate( diff --git a/src/diffusers/models/unet_kandinsky3.py b/src/diffusers/models/unets/unet_kandinsky3.py similarity index 98% rename from src/diffusers/models/unet_kandinsky3.py rename to src/diffusers/models/unets/unet_kandinsky3.py index eef3287e5d99..b52aace419f0 100644 --- a/src/diffusers/models/unet_kandinsky3.py +++ b/src/diffusers/models/unets/unet_kandinsky3.py @@ -19,11 +19,11 @@ import torch.utils.checkpoint from torch import nn -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, logging -from .attention_processor import Attention, AttentionProcessor, AttnProcessor -from .embeddings import TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput, logging +from ..attention_processor import Attention, AttentionProcessor, AttnProcessor +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py similarity index 97% rename from src/diffusers/models/unet_motion_model.py rename to src/diffusers/models/unets/unet_motion_model.py index b5f0302b4a43..9654ae508215 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -17,19 +17,19 @@ import torch.nn as nn import torch.utils.checkpoint -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import UNet2DConditionLoadersMixin -from ..utils import logging -from .attention_processor import ( +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...utils import logging +from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, ) -from .embeddings import TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin -from .transformer_temporal import TransformerTemporalModel +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..transformer_temporal import TransformerTemporalModel from .unet_2d_blocks import UNetMidBlock2DCrossAttn from .unet_2d_condition import UNet2DConditionModel from .unet_3d_blocks import ( @@ -524,7 +524,7 @@ def save_motion_modules( ) @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -548,7 +548,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -583,7 +583,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: """ Sets the attention processor to use [feed forward @@ -613,7 +613,7 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int for module in self.children(): fn_recursive_feed_forward(module, chunk_size, dim) - # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking def disable_forward_chunking(self) -> None: def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): @@ -625,7 +625,7 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int for module in self.children(): fn_recursive_feed_forward(module, None, 0) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self) -> None: """ Disables custom attention processors and sets the default attention implementation. @@ -645,7 +645,7 @@ def _set_gradient_checkpointing(self, module, value: bool = False) -> None: if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)): module.gradient_checkpointing = value - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None: r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. @@ -670,7 +670,7 @@ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None: setattr(upsample_block, "b1", b1) setattr(upsample_block, "b2", b2) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu def disable_freeu(self) -> None: """Disables the FreeU mechanism.""" freeu_keys = {"s1", "s2", "b1", "b2"} diff --git a/src/diffusers/models/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py similarity index 97% rename from src/diffusers/models/unet_spatio_temporal_condition.py rename to src/diffusers/models/unets/unet_spatio_temporal_condition.py index 8d0d3e61d879..39a8009d5af9 100644 --- a/src/diffusers/models/unet_spatio_temporal_condition.py +++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py @@ -4,12 +4,12 @@ import torch import torch.nn as nn -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import UNet2DConditionLoadersMixin -from ..utils import BaseOutput, logging -from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor -from .embeddings import TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...utils import BaseOutput, logging +from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block @@ -323,7 +323,7 @@ def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value - # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: """ Sets the attention processor to use [feed forward diff --git a/src/diffusers/models/uvit_2d.py b/src/diffusers/models/unets/uvit_2d.py similarity index 95% rename from src/diffusers/models/uvit_2d.py rename to src/diffusers/models/unets/uvit_2d.py index c0e224562cf2..492c41e4cad4 100644 --- a/src/diffusers/models/uvit_2d.py +++ b/src/diffusers/models/unets/uvit_2d.py @@ -20,20 +20,20 @@ from torch import nn from torch.utils.checkpoint import checkpoint -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import PeftAdapterMixin -from .attention import BasicTransformerBlock, SkipFFTransformerBlock -from .attention_processor import ( +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ..attention import BasicTransformerBlock, SkipFFTransformerBlock +from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, ) -from .embeddings import TimestepEmbedding, get_timestep_embedding -from .modeling_utils import ModelMixin -from .normalization import GlobalResponseNorm, RMSNorm -from .resnet import Downsample2D, Upsample2D +from ..embeddings import TimestepEmbedding, get_timestep_embedding +from ..modeling_utils import ModelMixin +from ..normalization import GlobalResponseNorm, RMSNorm +from ..resnet import Downsample2D, Upsample2D class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): @@ -213,7 +213,7 @@ def layer_(*args): return logits @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -237,7 +237,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -272,7 +272,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 13163f96e059..9d11f78ceee2 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -109,7 +109,10 @@ ] ) _import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"] - _import_structure["animatediff"] = ["AnimateDiffPipeline"] + _import_structure["animatediff"] = [ + "AnimateDiffPipeline", + "AnimateDiffVideoToVideoPipeline", + ] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", @@ -341,7 +344,7 @@ from ..utils.dummy_torch_and_transformers_objects import * else: from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline - from .animatediff import AnimateDiffPipeline + from .animatediff import AnimateDiffPipeline, AnimateDiffVideoToVideoPipeline from .audioldm import AudioLDMPipeline from .audioldm2 import ( AudioLDM2Pipeline, diff --git a/src/diffusers/pipelines/animatediff/__init__.py b/src/diffusers/pipelines/animatediff/__init__.py index 503352fec865..35b99a76fd21 100644 --- a/src/diffusers/pipelines/animatediff/__init__.py +++ b/src/diffusers/pipelines/animatediff/__init__.py @@ -11,7 +11,7 @@ _dummy_objects = {} -_import_structure = {} +_import_structure = {"pipeline_output": ["AnimateDiffPipelineOutput"]} try: if not (is_transformers_available() and is_torch_available()): @@ -21,7 +21,8 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline", "AnimateDiffPipelineOutput"] + _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"] + _import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -31,7 +32,9 @@ from ...utils.dummy_torch_and_transformers_objects import * else: - from .pipeline_animatediff import AnimateDiffPipeline, AnimateDiffPipelineOutput + from .pipeline_animatediff import AnimateDiffPipeline + from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline + from .pipeline_output import AnimateDiffPipelineOutput else: import sys diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 0fb4637dab7f..ee1062ee81ff 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -14,7 +14,6 @@ import inspect import math -from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -26,7 +25,7 @@ from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder -from ...models.unet_motion_model import MotionAdapter +from ...models.unets.unet_motion_model import MotionAdapter from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, @@ -37,7 +36,6 @@ ) from ...utils import ( USE_PEFT_BACKEND, - BaseOutput, deprecate, logging, replace_example_docstring, @@ -46,6 +44,7 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import AnimateDiffPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -67,10 +66,7 @@ """ -def tensor2vid(video: torch.Tensor, processor, output_type="np"): - # Based on: - # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 - +def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): batch_size, channels, num_frames, height, width = video.shape outputs = [] for batch_idx in range(batch_size): @@ -79,6 +75,15 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): outputs.append(batch_output) + if output_type == "np": + outputs = np.stack(outputs) + + elif output_type == "pt": + outputs = torch.stack(outputs) + + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]") + return outputs @@ -147,11 +152,6 @@ def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> tor return x_mixed -@dataclass -class AnimateDiffPipelineOutput(BaseOutput): - frames: Union[torch.Tensor, np.ndarray] - - class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin): r""" Pipeline for text-to-video generation. @@ -805,11 +805,7 @@ def _retrieve_video_frames(self, latents, output_type, return_dict): return AnimateDiffPipelineOutput(frames=latents) video_tensor = self.decode_latents(latents) - - if output_type == "pt": - video = video_tensor - else: - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py new file mode 100644 index 000000000000..3d01009cbac7 --- /dev/null +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -0,0 +1,969 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import AnimateDiffPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import imageio + >>> import requests + >>> import torch + >>> from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter + >>> from diffusers.utils import export_to_gif + >>> from io import BytesIO + >>> from PIL import Image + + >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16) + >>> pipe = AnimateDiffVideoToVideoPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter).to("cuda") + >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace") + + >>> def load_video(file_path: str): + ... images = [] + ... + ... if file_path.startswith(('http://', 'https://')): + ... # If the file_path is a URL + ... response = requests.get(file_path) + ... response.raise_for_status() + ... content = BytesIO(response.content) + ... vid = imageio.get_reader(content) + ... else: + ... # Assuming it's a local file path + ... vid = imageio.get_reader(file_path) + ... + ... for frame in vid: + ... pil_image = Image.fromarray(frame) + ... images.append(pil_image) + ... + ... return images + + >>> video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif") + >>> output = pipe(video=video, prompt="panda playing a guitar, on a boat, in the ocean, high quality", strength=0.5) + >>> frames = output.frames[0] + >>> export_to_gif(frames, "animation.gif") + ``` +""" + + +# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid +def tensor2vid(video: torch.Tensor, processor, output_type="np"): + batch_size, channels, num_frames, height, width = video.shape + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = processor.postprocess(batch_vid, output_type) + + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + + elif output_type == "pt": + outputs = torch.stack(outputs) + + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]") + + return outputs + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin): + r""" + Pipeline for video-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + motion_adapter: MotionAdapter, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = ( + image[None, :] + .reshape( + ( + batch_size, + num_frames, + -1, + ) + + image.shape[2:] + ) + .permute(0, 2, 1, 3, 4) + ) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + height, + width, + video=None, + latents=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` should be provided") + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents( + self, + video, + height, + width, + num_channels_latents, + batch_size, + timestep, + dtype, + device, + generator, + latents=None, + ): + # video must be a list of list of images + # the outer list denotes having multiple videos as input, whereas inner list means the frames of the video + # as a list of images + if not isinstance(video[0], list): + video = [video] + if latents is None: + video = torch.cat( + [self.image_processor.preprocess(vid, height=height, width=width).unsqueeze(0) for vid in video], dim=0 + ) + video = video.to(device=device, dtype=dtype) + num_frames = video.shape[1] + else: + num_frames = latents.shape[2] + + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + video = video.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + retrieve_latents(self.vae.encode(video[i]), generator=generator[i]).unsqueeze(0) + for i in range(batch_size) + ] + else: + init_latents = [ + retrieve_latents(self.vae.encode(vid), generator=generator).unsqueeze(0) for vid in video + ] + + init_latents = torch.cat(init_latents, dim=0) + + # restore vae to original dtype + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + error_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Please make sure to update your script to pass as many initial images as text prompts" + ) + raise ValueError(error_message) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4) + else: + if shape != latents.shape: + # [B, C, F, H, W] + raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}") + latents = latents.to(device, dtype=dtype) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + def __call__( + self, + video: List[List[PipelineImageInput]] = None, + prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 7.5, + strength: float = 0.8, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + r""" + The call function to the pipeline for generation. + + Args: + video (`List[PipelineImageInput]`): + The input video to condition the generation on. Must be a list of images/frames of the video. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + strength (`float`, *optional*, defaults to 0.8): + Higher strength leads to more differences between original video and generated video. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or + `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`AnimateDiffPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. + + Examples: + + Returns: + [`AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + strength=strength, + height=height, + width=width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + video=video, + latents=latents, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_videos_per_prompt, output_hidden_state + ) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + video=video, + height=height, + width=width, + num_channels_latents=num_channels_latents, + batch_size=batch_size * num_videos_per_prompt, + timestep=latent_timestep, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + progress_bar.update() + + if output_type == "latent": + return AnimateDiffPipelineOutput(frames=latents) + + # 9. Post-processing + video_tensor = self.decode_latents(latents) + + if output_type == "pt": + video = video_tensor + else: + video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + + # 10. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/animatediff/pipeline_output.py b/src/diffusers/pipelines/animatediff/pipeline_output.py new file mode 100644 index 000000000000..79a399fbc3a6 --- /dev/null +++ b/src/diffusers/pipelines/animatediff/pipeline_output.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image +import torch + +from ...utils import BaseOutput + + +@dataclass +class AnimateDiffPipelineOutput(BaseOutput): + r""" + Output class for AnimateDiff pipelines. + + Args: + frames (`List[List[PIL.Image.Image]]` or `torch.Tensor` or `np.ndarray`): + List of PIL Images of length `batch_size` or torch.Tensor or np.ndarray of shape + `(batch_size, num_frames, height, width, num_channels)`. + """ + + frames: Union[List[List[PIL.Image.Image]], torch.Tensor, np.ndarray] diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py index d39b2c99ddd0..147dd7a58e7b 100644 --- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py +++ b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py @@ -36,8 +36,8 @@ from ...models.modeling_utils import ModelMixin from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D from ...models.transformer_2d import Transformer2DModel -from ...models.unet_2d_blocks import DownBlock2D, UpBlock2D -from ...models.unet_2d_condition import UNet2DConditionOutput +from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D +from ...models.unets.unet_2d_condition import UNet2DConditionOutput from ...utils import BaseOutput, is_torch_version, logging @@ -513,7 +513,7 @@ def __init__( ) @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -537,7 +537,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -572,7 +572,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. @@ -588,7 +588,7 @@ def set_default_attn_processor(self): self.set_attn_processor(processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. @@ -654,7 +654,7 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value @@ -687,7 +687,7 @@ def forward( `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. @@ -700,8 +700,8 @@ def forward( which adds large negative values to the attention scores corresponding to "discard" tokens. Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 02e515c0ff55..78793c2866f4 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1404,11 +1404,6 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) - # manually for max memory savings - if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: - self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 379f4a036f69..20884a15da4d 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -33,7 +33,7 @@ ) from ....models.resnet import ResnetBlockCondNorm2D from ....models.transformer_2d import Transformer2DModel -from ....models.unet_2d_condition import UNet2DConditionOutput +from ....models.unets.unet_2d_condition import UNet2DConditionOutput from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ....utils.torch_utils import apply_freeu @@ -268,6 +268,7 @@ def forward( return objs +# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat class UNetFlatConditionModel(ModelMixin, ConfigMixin): r""" A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample @@ -1095,7 +1096,7 @@ def forward( `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. @@ -1111,8 +1112,8 @@ def forward( additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. @@ -1785,7 +1786,7 @@ def custom_forward(*inputs): return hidden_states, output_states -# Copied from diffusers.models.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +# Copied from diffusers.models.unets.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim class UpBlockFlat(nn.Module): def __init__( self, @@ -1896,7 +1897,7 @@ def custom_forward(*inputs): return hidden_states -# Copied from diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +# Copied from diffusers.models.unets.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim class CrossAttnUpBlockFlat(nn.Module): def __init__( self, @@ -2070,7 +2071,7 @@ def custom_forward(*inputs): return hidden_states -# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat +# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat class UNetMidBlockFlat(nn.Module): """ A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks. @@ -2226,7 +2227,7 @@ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTe return hidden_states -# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat +# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat class UNetMidBlockFlatCrossAttn(nn.Module): def __init__( self, @@ -2373,7 +2374,7 @@ def custom_forward(*inputs): return hidden_states -# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat +# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat class UNetMidBlockFlatSimpleCrossAttn(nn.Module): def __init__( self, diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index de5dea679ee9..4d0bc8b13a92 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -351,7 +351,7 @@ def get_class_obj_and_candidates( def _get_pipeline_class( class_obj, - config, + config=None, load_connected_pipeline=False, custom_pipeline=None, repo_id=None, @@ -389,7 +389,12 @@ def _get_pipeline_class( return class_obj diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) - class_name = config["_class_name"] + class_name = class_name or config["_class_name"] + if not class_name: + raise ValueError( + "The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`." + ) + class_name = class_name[4:] if class_name.startswith("Flax") else class_name pipeline_cls = getattr(diffusers_module, class_name) diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index 56f72691303d..9bfced85955b 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -40,10 +40,8 @@ def _append_dims(x, target_dims): return x[(...,) + (None,) * dims_to_append] -def tensor2vid(video: torch.Tensor, processor, output_type="np"): - # Based on: - # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 - +# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid +def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): batch_size, channels, num_frames, height, width = video.shape outputs = [] for batch_idx in range(batch_size): @@ -53,7 +51,13 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): outputs.append(batch_output) if output_type == "np": - return np.stack(outputs) + outputs = np.stack(outputs) + + elif output_type == "pt": + outputs = torch.stack(outputs) + + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]") return outputs diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index ab5286a5e5b4..6e5db85c9e66 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -19,6 +19,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel from ...models.lora import adjust_lora_scale_text_encoder @@ -58,22 +59,26 @@ """ -def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: - # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 - # reshape to ncfhw - mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) - std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) - # unnormalize back to [0,1] - video = video.mul_(std).add_(mean) - video.clamp_(0, 1) - # prepare the final outputs - i, c, f, h, w = video.shape - images = video.permute(2, 3, 0, 4, 1).reshape( - f, h, i * w, c - ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) - images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) - images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c - return images +# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid +def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): + batch_size, channels, num_frames, height, width = video.shape + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = processor.postprocess(batch_vid, output_type) + + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + + elif output_type == "pt": + outputs = torch.stack(outputs) + + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]") + + return outputs class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): @@ -122,6 +127,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): @@ -717,11 +723,7 @@ def __call__( return TextToVideoSDPipelineOutput(frames=latents) video_tensor = self.decode_latents(latents) - - if output_type == "pt": - video = video_tensor - else: - video = tensor2vid(video_tensor) + video = tensor2vid(video_tensor, self.image_processor, output_type) # Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index b19ccee660e2..c781e490caae 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -20,6 +20,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel from ...models.lora import adjust_lora_scale_text_encoder @@ -93,22 +94,26 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: - # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 - # reshape to ncfhw - mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) - std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) - # unnormalize back to [0,1] - video = video.mul_(std).add_(mean) - video.clamp_(0, 1) - # prepare the final outputs - i, c, f, h, w = video.shape - images = video.permute(2, 3, 0, 4, 1).reshape( - f, h, i * w, c - ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) - images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) - images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c - return images +# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid +def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): + batch_size, channels, num_frames, height, width = video.shape + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = processor.postprocess(batch_vid, output_type) + + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + + elif output_type == "pt": + outputs = torch.stack(outputs) + + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]") + + return outputs def preprocess_video(video): @@ -198,6 +203,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): @@ -812,12 +818,11 @@ def __call__( if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") - video_tensor = self.decode_latents(latents) + if output_type == "latent": + return TextToVideoSDPipelineOutput(frames=latents) - if output_type == "pt": - video = video_tensor - else: - video = tensor2vid(video_tensor) + video_tensor = self.decode_latents(latents) + video = tensor2vid(video_tensor, self.image_processor, output_type) # Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py index 6e97e0279350..561d8344e746 100644 --- a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py +++ b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py @@ -752,7 +752,7 @@ def forward( cross_attention_kwargs (*optional*): Keyword arguments to supply to the cross attention layers, if used. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. hidden_states_is_embedding (`bool`, *optional*, defaults to `False`): Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index 8b494fa32476..c752cba606a4 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -66,7 +66,7 @@ def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dro self.set_default_attn_processor() @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -90,7 +90,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -125,7 +125,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 468476e0c748..667f1fe5e2fd 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -28,6 +28,7 @@ MIN_PEFT_VERSION, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, + SAFETENSORS_FILE_EXTENSION, SAFETENSORS_WEIGHTS_NAME, USE_PEFT_BACKEND, WEIGHTS_NAME, diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 8850da073e95..a397e8cf86d3 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -31,6 +31,7 @@ FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" ONNX_WEIGHTS_NAME = "model.onnx" SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" +SAFETENSORS_FILE_EXTENSION = "safetensors" ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co") DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 2eb9599658d9..e0d5c77d0e8c 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -92,6 +92,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AudioLDM2Pipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index d762f015a7bc..6f311b957abf 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -244,15 +244,15 @@ def _get_model_file( pretrained_model_name_or_path: Union[str, Path], *, weights_name: str, - subfolder: Optional[str], - cache_dir: Optional[str], - force_download: bool, - proxies: Optional[Dict], - resume_download: bool, - local_files_only: bool, - token: Optional[str], - user_agent: Union[Dict, str, None], - revision: Optional[str], + subfolder: Optional[str] = None, + cache_dir: Optional[str] = None, + force_download: bool = False, + proxies: Optional[Dict] = None, + resume_download: bool = False, + local_files_only: bool = False, + token: Optional[str] = None, + user_agent: Optional[Union[Dict, str]] = None, + revision: Optional[str] = None, commit_hash: Optional[str] = None, ): pretrained_model_name_or_path = str(pretrained_model_name_or_path) diff --git a/tests/models/test_unet_2d_blocks.py b/tests/models/test_unet_2d_blocks.py index d714b9384860..ef77df8abdfb 100644 --- a/tests/models/test_unet_2d_blocks.py +++ b/tests/models/test_unet_2d_blocks.py @@ -14,7 +14,7 @@ # limitations under the License. import unittest -from diffusers.models.unet_2d_blocks import * # noqa F403 +from diffusers.models.unets.unet_2d_blocks import * # noqa F403 from diffusers.utils.testing_utils import torch_device from .test_unet_blocks_common import UNetBlockTesterMixin diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 44cb730a9501..80a8fd19f5a0 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -262,7 +262,7 @@ def test_free_init(self): sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum() max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max() self.assertGreater( - sum_enabled, 1e2, "Enabling of FreeInit should lead to results different from the default pipeline results" + sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results" ) self.assertLess( max_diff_disabled, diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py new file mode 100644 index 000000000000..3226bdb3ca6e --- /dev/null +++ b/tests/pipelines/animatediff/test_animatediff_video2video.py @@ -0,0 +1,269 @@ +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +import diffusers +from diffusers import ( + AnimateDiffVideoToVideoPipeline, + AutoencoderKL, + DDIMScheduler, + MotionAdapter, + UNet2DConditionModel, + UNetMotionModel, +) +from diffusers.utils import is_xformers_available, logging +from diffusers.utils.testing_utils import torch_device + +from ..pipeline_params import TEXT_TO_IMAGE_PARAMS, VIDEO_TO_VIDEO_BATCH_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +def to_np(tensor): + if isinstance(tensor, torch.Tensor): + tensor = tensor.detach().cpu().numpy() + + return tensor + + +class AnimateDiffVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = AnimateDiffVideoToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS + batch_params = VIDEO_TO_VIDEO_BATCH_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + norm_num_groups=2, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="linear", + clip_sample=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + motion_adapter = MotionAdapter( + block_out_channels=(32, 64), + motion_layers_per_block=2, + motion_norm_num_groups=2, + motion_num_attention_heads=4, + ) + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "motion_adapter": motion_adapter, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "feature_extractor": None, + "image_encoder": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + video_height = 32 + video_width = 32 + video_num_frames = 2 + video = [Image.new("RGB", (video_width, video_height))] * video_num_frames + + inputs = { + "video": video, + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 7.5, + "output_type": "pt", + } + return inputs + + def test_motion_unet_loading(self): + components = self.get_dummy_components() + pipe = AnimateDiffVideoToVideoPipeline(**components) + + assert isinstance(pipe.unet, UNetMotionModel) + + @unittest.skip("Attention slicing is not enabled in this pipeline") + def test_attention_slicing_forward_pass(self): + pass + + def test_inference_batch_single_identical( + self, + batch_size=2, + expected_max_diff=1e-4, + additional_params_copy_to_batched_inputs=["num_inference_steps"], + ): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for components in pipe.components.values(): + if hasattr(components, "set_default_attn_processor"): + components.set_default_attn_processor() + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + inputs = self.get_dummy_inputs(torch_device) + # Reset generator in case it is has been used in self.get_dummy_inputs + inputs["generator"] = self.get_generator(0) + + logger = logging.get_logger(pipe.__module__) + logger.setLevel(level=diffusers.logging.FATAL) + + # batchify inputs + batched_inputs = {} + batched_inputs.update(inputs) + + for name in self.batch_params: + if name not in inputs: + continue + + value = inputs[name] + if name == "prompt": + len_prompt = len(value) + batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)] + batched_inputs[name][-1] = 100 * "very long" + + else: + batched_inputs[name] = batch_size * [value] + + if "generator" in inputs: + batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)] + + if "batch_size" in inputs: + batched_inputs["batch_size"] = batch_size + + for arg in additional_params_copy_to_batched_inputs: + batched_inputs[arg] = inputs[arg] + + output = pipe(**inputs) + output_batch = pipe(**batched_inputs) + + assert output_batch[0].shape[0] == batch_size + + max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max() + assert max_diff < expected_max_diff + + @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") + def test_to_device(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + pipe.to("cpu") + # pipeline creates a new motion UNet under the hood. So we need to check the device from pipe.components + model_devices = [ + component.device.type for component in pipe.components.values() if hasattr(component, "device") + ] + self.assertTrue(all(device == "cpu" for device in model_devices)) + + output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] + self.assertTrue(np.isnan(output_cpu).sum() == 0) + + pipe.to("cuda") + model_devices = [ + component.device.type for component in pipe.components.values() if hasattr(component, "device") + ] + self.assertTrue(all(device == "cuda" for device in model_devices)) + + output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0] + self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) + + def test_to_dtype(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + # pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components + model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) + + pipe.to(torch_dtype=torch.float16) + model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) + + def test_prompt_embeds(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + inputs.pop("prompt") + inputs["prompt_embeds"] = torch.randn((1, 4, 32), device=torch_device) + pipe(**inputs) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output_without_offload = pipe(**inputs).frames[0] + output_without_offload = ( + output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload + ) + + pipe.enable_xformers_memory_efficient_attention() + inputs = self.get_dummy_inputs(torch_device) + output_with_offload = pipe(**inputs).frames[0] + output_with_offload = ( + output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload + ) + + max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max() + self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results") diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index c034a9b68bd8..05f3ade5089f 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -37,6 +37,7 @@ enable_full_determinism, load_image, load_numpy, + numpy_cosine_similarity_distance, require_python39_or_higher, require_torch_2, require_torch_gpu, @@ -1022,39 +1023,49 @@ def test_v11_shuffle_global_pool_conditions(self): def test_load_local(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny") - pipe_1 = StableDiffusionControlNetPipeline.from_pretrained( + pipe = StableDiffusionControlNetPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) + pipe.unet.set_default_attn_processor() + pipe.enable_model_cpu_offload() controlnet = ControlNetModel.from_single_file( "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" ) - pipe_2 = StableDiffusionControlNetPipeline.from_single_file( + pipe_sf = StableDiffusionControlNetPipeline.from_single_file( "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors", safety_checker=None, controlnet=controlnet, + scheduler_type="pndm", ) - pipes = [pipe_1, pipe_2] - images = [] - - for pipe in pipes: - pipe.enable_model_cpu_offload() - pipe.set_progress_bar_config(disable=None) + pipe_sf.unet.set_default_attn_processor() + pipe_sf.enable_model_cpu_offload() - generator = torch.Generator(device="cpu").manual_seed(0) - prompt = "bird" - image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" - ) + control_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ).resize((512, 512)) + prompt = "bird" - output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3) - images.append(output.images[0]) + generator = torch.Generator(device="cpu").manual_seed(0) + output = pipe( + prompt, + image=control_image, + generator=generator, + output_type="np", + num_inference_steps=3, + ).images[0] - del pipe - gc.collect() - torch.cuda.empty_cache() + generator = torch.Generator(device="cpu").manual_seed(0) + output_sf = pipe_sf( + prompt, + image=control_image, + generator=generator, + output_type="np", + num_inference_steps=3, + ).images[0] - assert np.abs(images[0] - images[1]).max() < 1e-3 + max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten()) + assert max_diff < 1e-3 @slow diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py index b4b67e6476f6..939eb34ef0c6 100644 --- a/tests/pipelines/controlnet/test_controlnet_img2img.py +++ b/tests/pipelines/controlnet/test_controlnet_img2img.py @@ -39,6 +39,7 @@ enable_full_determinism, floats_tensor, load_numpy, + numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device, @@ -421,46 +422,53 @@ def test_canny(self): def test_load_local(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny") - pipe_1 = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( + pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) + pipe.unet.set_default_attn_processor() + pipe.enable_model_cpu_offload() controlnet = ControlNetModel.from_single_file( "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" ) - pipe_2 = StableDiffusionControlNetImg2ImgPipeline.from_single_file( + pipe_sf = StableDiffusionControlNetImg2ImgPipeline.from_single_file( "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors", safety_checker=None, controlnet=controlnet, + scheduler_type="pndm", ) + pipe_sf.unet.set_default_attn_processor() + pipe_sf.enable_model_cpu_offload() + control_image = load_image( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" ).resize((512, 512)) image = load_image( "https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png" ).resize((512, 512)) + prompt = "bird" - pipes = [pipe_1, pipe_2] - images = [] - for pipe in pipes: - pipe.enable_model_cpu_offload() - pipe.set_progress_bar_config(disable=None) - - generator = torch.Generator(device="cpu").manual_seed(0) - prompt = "bird" - output = pipe( - prompt, - image=image, - control_image=control_image, - strength=0.9, - generator=generator, - output_type="np", - num_inference_steps=3, - ) - images.append(output.images[0]) + generator = torch.Generator(device="cpu").manual_seed(0) + output = pipe( + prompt, + image=image, + control_image=control_image, + strength=0.9, + generator=generator, + output_type="np", + num_inference_steps=3, + ).images[0] - del pipe - gc.collect() - torch.cuda.empty_cache() + generator = torch.Generator(device="cpu").manual_seed(0) + output_sf = pipe_sf( + prompt, + image=image, + control_image=control_image, + strength=0.9, + generator=generator, + output_type="np", + num_inference_steps=3, + ).images[0] - assert np.abs(images[0] - images[1]).max() < 1e-3 + max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten()) + assert max_diff < 1e-3 diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index 7c3371c197d4..7db336df9448 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -569,6 +569,7 @@ def test_load_local(self): "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors", safety_checker=None, controlnet=controlnet, + scheduler_type="pndm", ) control_image = load_image( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" @@ -605,4 +606,5 @@ def test_load_local(self): gc.collect() torch.cuda.empty_cache() - assert np.abs(images[0] - images[1]).max() < 1e-3 + max_diff = numpy_cosine_similarity_distance(images[0].flatten(), images[1].flatten()) + assert max_diff < 1e-3 diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index ba129e763c22..7ccf724361d3 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -28,10 +28,17 @@ StableDiffusionXLControlNetPipeline, UNet2DConditionModel, ) -from diffusers.models.unet_2d_blocks import UNetMidBlock2D +from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device +from diffusers.utils.testing_utils import ( + enable_full_determinism, + load_image, + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, + torch_device, +) from diffusers.utils.torch_utils import randn_tensor from ..pipeline_params import ( @@ -819,6 +826,41 @@ def test_depth(self): expected_image = np.array([0.4399, 0.5112, 0.5478, 0.4314, 0.472, 0.4823, 0.4647, 0.4957, 0.4853]) assert np.allclose(original_image, expected_image, atol=1e-04) + def test_download_ckpt_diff_format_is_same(self): + controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16) + single_file_url = ( + "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors" + ) + pipe_single_file = StableDiffusionXLControlNetPipeline.from_single_file( + single_file_url, controlnet=controlnet, torch_dtype=torch.float16 + ) + pipe_single_file.unet.set_default_attn_processor() + pipe_single_file.enable_model_cpu_offload() + pipe_single_file.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "Stormtrooper's lecture" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png" + ) + single_file_images = pipe_single_file( + prompt, image=image, generator=generator, output_type="np", num_inference_steps=2 + ).images + + generator = torch.Generator(device="cpu").manual_seed(0) + pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 + ) + pipe.unet.set_default_attn_processor() + pipe.enable_model_cpu_offload() + images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=2).images + + assert images[0].shape == (512, 512, 3) + assert single_file_images[0].shape == (512, 512, 3) + + max_diff = numpy_cosine_similarity_distance(images[0].flatten(), single_file_images[0].flatten()) + assert max_diff < 5e-2 + class StableDiffusionSSD1BControlNetPipelineFastTests(StableDiffusionXLControlNetPipelineFastTests): def test_controlnet_sdxl_guess(self): diff --git a/tests/pipelines/pipeline_params.py b/tests/pipelines/pipeline_params.py index f5be787656c7..4e2c4dcdd9cb 100644 --- a/tests/pipelines/pipeline_params.py +++ b/tests/pipelines/pipeline_params.py @@ -125,3 +125,5 @@ TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"]) TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"]) + +VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"]) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 8854b482dec7..bfdcec09f211 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -836,7 +836,10 @@ def test_stable_diffusion_lms(self): def test_stable_diffusion_dpm(self): sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None) - sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config( + sd_pipe.scheduler.config, + final_sigmas_type="sigma_min", + ) sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) @@ -1243,9 +1246,12 @@ def test_download_from_hub(self): assert image_out.shape == (512, 512, 3) def test_download_local(self): - filename = hf_hub_download("runwayml/stable-diffusion-v1-5", filename="v1-5-pruned-emaonly.ckpt") + ckpt_filename = hf_hub_download("runwayml/stable-diffusion-v1-5", filename="v1-5-pruned-emaonly.ckpt") + config_filename = hf_hub_download("runwayml/stable-diffusion-v1-5", filename="v1-inference.yaml") - pipe = StableDiffusionPipeline.from_single_file(filename, torch_dtype=torch.float16) + pipe = StableDiffusionPipeline.from_single_file( + ckpt_filename, config_files={"v1": config_filename}, torch_dtype=torch.float16 + ) pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) pipe.to("cuda") @@ -1256,13 +1262,13 @@ def test_download_local(self): def test_download_ckpt_diff_format_is_same(self): ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt" - pipe = StableDiffusionPipeline.from_single_file(ckpt_path) - pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) - pipe.unet.set_attn_processor(AttnProcessor()) - pipe.to("cuda") + sf_pipe = StableDiffusionPipeline.from_single_file(ckpt_path) + sf_pipe.scheduler = DDIMScheduler.from_config(sf_pipe.scheduler.config) + sf_pipe.unet.set_attn_processor(AttnProcessor()) + sf_pipe.to("cuda") generator = torch.Generator(device="cpu").manual_seed(0) - image_ckpt = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0] + image_single_file = sf_pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0] pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) @@ -1272,7 +1278,7 @@ def test_download_ckpt_diff_format_is_same(self): generator = torch.Generator(device="cpu").manual_seed(0) image = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0] - max_diff = numpy_cosine_similarity_distance(image.flatten(), image_ckpt.flatten()) + max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten()) assert max_diff < 1e-3 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index fe664b21e271..7ec6964b0688 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -43,6 +43,7 @@ load_image, load_numpy, nightly, + numpy_cosine_similarity_distance, require_python39_or_higher, require_torch_2, require_torch_gpu, @@ -771,7 +772,9 @@ def test_download_ckpt_diff_format_is_same(self): inputs["num_inference_steps"] = 5 image = pipe(**inputs).images[0] - assert np.max(np.abs(image - image_ckpt)) < 5e-4 + max_diff = numpy_cosine_similarity_distance(image.flatten(), image_ckpt.flatten()) + + assert max_diff < 1e-4 @slow diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py index 81e85efe953c..b486bd69d9d3 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -627,7 +627,9 @@ def test_stable_diffusion_euler(self): def test_stable_diffusion_dpm(self): sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base").to(torch_device) - sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config( + sd_pipe.scheduler.config, final_sigmas_type="sigma_min" + ) sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_inputs(torch_device) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py index 09034789c61c..0305fa52426a 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py @@ -323,7 +323,9 @@ def test_stable_diffusion_v_pred_dpm(self): TODO: update this test after making DPM compatible with V-prediction! """ scheduler = DPMSolverMultistepScheduler.from_pretrained( - "stabilityai/stable-diffusion-2", subfolder="scheduler" + "stabilityai/stable-diffusion-2", + subfolder="scheduler", + final_sigmas_type="sigma_min", ) sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", scheduler=scheduler) sd_pipe = sd_pipe.to(torch_device) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 80bff3663a98..d5ad5ee0d72b 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -14,6 +14,7 @@ # limitations under the License. import copy +import gc import tempfile import unittest @@ -1024,6 +1025,11 @@ def callback_on_step_end(pipe, i, t, callback_kwargs): @slow class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + def test_stable_diffusion_lcm(self): torch.manual_seed(0) unet = UNet2DConditionModel.from_pretrained( @@ -1049,3 +1055,30 @@ def test_stable_diffusion_lcm(self): max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten()) assert max_diff < 1e-2 + + def test_download_ckpt_diff_format_is_same(self): + ckpt_path = ( + "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors" + ) + + pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path, torch_dtype=torch.float16) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.unet.set_default_attn_processor() + pipe.enable_model_cpu_offload() + + generator = torch.Generator(device="cpu").manual_seed(0) + image_ckpt = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0] + + pipe = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.unet.set_default_attn_processor() + pipe.enable_model_cpu_offload() + + generator = torch.Generator(device="cpu").manual_seed(0) + image = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0] + + max_diff = numpy_cosine_similarity_distance(image.flatten(), image_ckpt.flatten()) + + assert max_diff < 6e-3 diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py index 2296e092e7bd..1513525df3b5 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py @@ -699,3 +699,40 @@ def test_canny_lora(self): image_slice = images[0, -3:, -3:, -1].flatten() expected_slice = np.array([0.4284, 0.4337, 0.4319, 0.4255, 0.4329, 0.4280, 0.4338, 0.4420, 0.4226]) assert numpy_cosine_similarity_distance(image_slice, expected_slice) < 1e-4 + + def test_download_ckpt_diff_format_is_same(self): + ckpt_path = ( + "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors" + ) + adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16) + prompt = "toy" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png" + ) + pipe_single_file = StableDiffusionXLAdapterPipeline.from_single_file( + ckpt_path, + adapter=adapter, + torch_dtype=torch.float16, + ) + pipe_single_file.enable_model_cpu_offload() + pipe_single_file.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + images_single_file = pipe_single_file( + prompt, image=image, generator=generator, output_type="np", num_inference_steps=3 + ).images + + generator = torch.Generator(device="cpu").manual_seed(0) + pipe = StableDiffusionXLAdapterPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + adapter=adapter, + torch_dtype=torch.float16, + ) + pipe.enable_model_cpu_offload() + images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images + + assert images_single_file[0].shape == (768, 512, 3) + assert images[0].shape == (768, 512, 3) + + max_diff = numpy_cosine_similarity_distance(images[0].flatten(), images_single_file[0].flatten()) + assert max_diff < 5e-3 diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index 0a7d4d0de4ca..e505630cf6e1 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gc import random import unittest @@ -31,15 +32,19 @@ from diffusers import ( AutoencoderKL, AutoencoderTiny, + DDIMScheduler, EulerDiscreteScheduler, LCMScheduler, StableDiffusionXLImg2ImgPipeline, UNet2DConditionModel, ) +from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, floats_tensor, + numpy_cosine_similarity_distance, require_torch_gpu, + slow, torch_device, ) @@ -763,3 +768,44 @@ def test_inference_batch_single_identical(self): def test_save_load_optional_components(self): self._test_save_load_optional_components() + + +@slow +class StableDiffusionXLImg2ImgIntegrationTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_download_ckpt_diff_format_is_same(self): + ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors" + init_image = load_image( + "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main" + "/stable_diffusion_img2img/sketch-mountains-input.png" + ) + + pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16 + ) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.unet.set_default_attn_processor() + pipe.enable_model_cpu_offload() + + generator = torch.Generator(device="cpu").manual_seed(0) + image = pipe( + prompt="mountains", image=init_image, num_inference_steps=5, generator=generator, output_type="np" + ).images[0] + + pipe_single_file = StableDiffusionXLImg2ImgPipeline.from_single_file(ckpt_path, torch_dtype=torch.float16) + pipe_single_file.scheduler = DDIMScheduler.from_config(pipe_single_file.scheduler.config) + pipe_single_file.unet.set_default_attn_processor() + pipe_single_file.enable_model_cpu_offload() + + generator = torch.Generator(device="cpu").manual_seed(0) + image_single_file = pipe_single_file( + prompt="mountains", image=init_image, num_inference_steps=5, generator=generator, output_type="np" + ).images[0] + + max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten()) + + assert max_diff < 5e-2 diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py index e9f435239c92..2f48dc5c318a 100644 --- a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py +++ b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py @@ -29,6 +29,7 @@ from diffusers.utils.testing_utils import ( enable_full_determinism, load_numpy, + numpy_cosine_similarity_distance, require_torch_gpu, skip_mps, slow, @@ -141,10 +142,11 @@ def test_text_to_video_default_case(self): inputs = self.get_dummy_inputs(device) inputs["output_type"] = "np" frames = sd_pipe(**inputs).frames - image_slice = frames[0][-3:, -3:, -1] - assert frames[0].shape == (32, 32, 3) - expected_slice = np.array([192.0, 44.0, 157.0, 140.0, 108.0, 104.0, 123.0, 144.0, 129.0]) + image_slice = frames[0][0][-3:, -3:, -1] + + assert frames[0][0].shape == (32, 32, 3) + expected_slice = np.array([0.7537, 0.1752, 0.6157, 0.5508, 0.4240, 0.4110, 0.4838, 0.5648, 0.5094]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -183,7 +185,7 @@ def test_progress_bar(self): class TextToVideoSDPipelineSlowTests(unittest.TestCase): def test_two_step_model(self): expected_video = load_numpy( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/video_2step.npy" + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/video_2step.npy" ) pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b") @@ -192,10 +194,8 @@ def test_two_step_model(self): prompt = "Spiderman is surfing" generator = torch.Generator(device="cpu").manual_seed(0) - video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames - video = video_frames.cpu().numpy() - - assert np.abs(expected_video - video).mean() < 5e-2 + video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").frames + assert numpy_cosine_similarity_distance(expected_video.flatten(), video_frames.flatten()) < 1e-4 def test_two_step_model_with_freeu(self): expected_video = [] @@ -207,10 +207,9 @@ def test_two_step_model_with_freeu(self): generator = torch.Generator(device="cpu").manual_seed(0) pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4) - video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames - video = video_frames.cpu().numpy() - video = video[0, 0, -3:, -3:, -1].flatten() + video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").frames + video = video_frames[0, 0, -3:, -3:, -1].flatten() - expected_video = [-0.3102, -0.2477, -0.1772, -0.648, -0.6176, -0.5484, -0.0217, -0.056, -0.0177] + expected_video = [0.3643, 0.3455, 0.3831, 0.3923, 0.2978, 0.3247, 0.3278, 0.3201, 0.3475] assert np.abs(expected_video - video).mean() < 5e-2 diff --git a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py b/tests/pipelines/text_to_video_synthesis/test_video_to_video.py index 1785eb967f16..07d48eba5574 100644 --- a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py +++ b/tests/pipelines/text_to_video_synthesis/test_video_to_video.py @@ -157,10 +157,10 @@ def test_text_to_video_default_case(self): inputs = self.get_dummy_inputs(device) inputs["output_type"] = "np" frames = sd_pipe(**inputs).frames - image_slice = frames[0][-3:, -3:, -1] + image_slice = frames[0][0][-3:, -3:, -1] - assert frames[0].shape == (32, 32, 3) - expected_slice = np.array([162.0, 136.0, 132.0, 140.0, 139.0, 137.0, 169.0, 134.0, 132.0]) + assert frames[0][0].shape == (32, 32, 3) + expected_slice = np.array([0.6391, 0.5350, 0.5202, 0.5521, 0.5453, 0.5393, 0.6652, 0.5270, 0.5185]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -214,9 +214,11 @@ def test_two_step_model(self): prompt = "Spiderman is surfing" - video_frames = pipe(prompt, video=video, generator=generator, num_inference_steps=3, output_type="pt").frames - - expected_array = np.array([-0.9770508, -0.8027344, -0.62646484, -0.8334961, -0.7573242]) - output_array = video_frames.cpu().numpy()[0, 0, 0, 0, -5:] + generator = torch.Generator(device="cpu").manual_seed(0) + video_frames = pipe(prompt, video=video, generator=generator, num_inference_steps=3, output_type="np").frames - assert numpy_cosine_similarity_distance(expected_array, output_array) < 1e-2 + expected_array = np.array( + [0.17114258, 0.13720703, 0.08886719, 0.14819336, 0.1730957, 0.24584961, 0.22021484, 0.35180664, 0.2607422] + ) + output_array = video_frames[0, 0, :3, :3, 0].flatten() + assert numpy_cosine_similarity_distance(expected_array, output_array) < 1e-3