diff --git a/.github/workflows/pr_test_fetcher.yml b/.github/workflows/pr_test_fetcher.yml
index 7eb208505e75..6c5d0f1c248b 100644
--- a/.github/workflows/pr_test_fetcher.yml
+++ b/.github/workflows/pr_test_fetcher.yml
@@ -35,14 +35,15 @@ jobs:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
- fetch-depth: 2
+ fetch-depth: 0
- name: Install dependencies
run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y
- python -m pip install -e .
+ python -m pip install -e .[quality,test]
- name: Environment
run: |
python utils/print_env.py
+ echo $(git --version)
- name: Fetch Tests
run: |
python utils/tests_fetcher.py | tee test_preparation.txt
@@ -110,7 +111,7 @@ jobs:
continue-on-error: true
run: |
cat reports/${{ matrix.modules }}_tests_cpu_stats.txt
- cat reports/${{ matrix.modules }}_tests_cpu/failures_short.txt
+ cat reports/${{ matrix.modules }}_tests_cpu_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 7cb7f1a0ced2..d24b049d3b39 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -355,7 +355,7 @@ You will need basic `git` proficiency to be able to contribute to
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
Git](https://git-scm.com/book/en/v2) is a very good reference.
-Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/main/setup.py#L244)):
+Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/main/setup.py#L265)):
1. Fork the [repository](https://github.com/huggingface/diffusers) by
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
diff --git a/Makefile b/Makefile
index 1b81f551d36d..70bfced8c7b4 100644
--- a/Makefile
+++ b/Makefile
@@ -41,7 +41,7 @@ repo-consistency:
quality:
ruff check $(check_dirs) setup.py
- ruff format --check $(check_dirs) setup.py
+ ruff format --check $(check_dirs) setup.py
python utils/check_doc_toc.py
# Format source code automatically and check is there are any problems left that need manual fixing
diff --git a/docs/source/en/api/pipelines/value_guided_sampling.md b/docs/source/en/api/pipelines/value_guided_sampling.md
index 01b7717f49f8..3c7e4977a68a 100644
--- a/docs/source/en/api/pipelines/value_guided_sampling.md
+++ b/docs/source/en/api/pipelines/value_guided_sampling.md
@@ -24,7 +24,7 @@ The abstract from the paper is:
*Model-based reinforcement learning methods often use learning only for the purpose of estimating an approximate dynamics model, offloading the rest of the decision-making work to classical trajectory optimizers. While conceptually simple, this combination has a number of empirical shortcomings, suggesting that learned models may not be well-suited to standard trajectory optimization. In this paper, we consider what it would look like to fold as much of the trajectory optimization pipeline as possible into the modeling problem, such that sampling from the model and planning with it become nearly identical. The core of our technical approach lies in a diffusion probabilistic model that plans by iteratively denoising trajectories. We show how classifier-guided sampling and image inpainting can be reinterpreted as coherent planning strategies, explore the unusual and useful properties of diffusion-based planning methods, and demonstrate the effectiveness of our framework in control settings that emphasize long-horizon decision-making and test-time flexibility.*
-You can find additional information about the model on the [project page](https://diffusion-planning.github.io/), the [original codebase](https://github.com/jannerm/diffuser), or try it out in a demo [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb).
+You can find additional information about the model on the [project page](https://diffusion-planning.github.io/), the [original codebase](https://github.com/jannerm/diffuser), or try it out in a demo [notebook](https://colab.research.google.com/drive/1rXm8CX4ZdN5qivjJ2lhwhkOmt_m0CvU0#scrollTo=6HXJvhyqcITc&uniqifier=1).
The script to run the model is available [here](https://github.com/huggingface/diffusers/tree/main/examples/reinforcement_learning).
diff --git a/docs/source/en/conceptual/contribution.md b/docs/source/en/conceptual/contribution.md
index dc942a24c42e..d2b45cac7362 100644
--- a/docs/source/en/conceptual/contribution.md
+++ b/docs/source/en/conceptual/contribution.md
@@ -297,17 +297,37 @@ if you don't know yet what specific component you would like to add:
- [Model or pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22)
- [Scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)
-Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](philosophy) a read to better understand the design of any of the three components. Please be aware that
-we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy
-as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please
-open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design
-pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us.
+Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](philosophy) a read to better understand the design of any of the three components. Please be aware that we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy
+as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us.
-Please make sure to add links to the original codebase/paper to the PR and ideally also ping the
-original author directly on the PR so that they can follow the progress and potentially help with questions.
+Please make sure to add links to the original codebase/paper to the PR and ideally also ping the original author directly on the PR so that they can follow the progress and potentially help with questions.
If you are unsure or stuck in the PR, don't hesitate to leave a message to ask for a first review or help.
+#### Copied from mechanism
+
+A unique and important feature to understand when adding any pipeline, model or scheduler code is the `# Copied from` mechanism. You'll see this all over the Diffusers codebase, and the reason we use it is to keep the codebase easy to understand and maintain. Marking code with the `# Copied from` mechanism forces the marked code to be identical to the code it was copied from. This makes it easy to update and propagate changes across many files whenever you run `make fix-copies`.
+
+For example, in the code example below, [`~diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is the original code and `AltDiffusionPipelineOutput` uses the `# Copied from` mechanism to copy it. The only difference is changing the class prefix from `Stable` to `Alt`.
+
+```py
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_output.StableDiffusionPipelineOutput with Stable->Alt
+class AltDiffusionPipelineOutput(BaseOutput):
+ """
+ Output class for Alt Diffusion pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
+ num_channels)`.
+ nsfw_content_detected (`List[bool]`)
+ List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
+ `None` if safety checking could not be performed.
+ """
+```
+
+To learn more, read this section of the [~Don't~ Repeat Yourself*](https://huggingface.co/blog/transformers-design-philosophy#4-machine-learning-models-are-static) blog post.
+
## How to write a good issue
**The better your issue is written, the higher the chances that it will be quickly resolved.**
diff --git a/docs/source/en/using-diffusers/kandinsky.md b/docs/source/en/using-diffusers/kandinsky.md
index 05be2e1ee289..0fbec32a5296 100644
--- a/docs/source/en/using-diffusers/kandinsky.md
+++ b/docs/source/en/using-diffusers/kandinsky.md
@@ -20,6 +20,8 @@ The Kandinsky models are a series of multilingual text-to-image generation model
[Kandinsky 2.2](../api/pipelines/kandinsky_v22) improves on the previous model by replacing the image encoder of the image prior model with a larger CLIP-ViT-G model to improve quality. The image prior model was also retrained on images with different resolutions and aspect ratios to generate higher-resolution images and different image sizes.
+[Kandinsky 3](../api/pipelines/kandinsky3) simplifies the architecture and shifts away from the two-stage generation process involving the prior model and diffusion model. Instead, Kandinsky 3 uses [Flan-UL2](https://huggingface.co/google/flan-ul2) to encode text, a UNet with [BigGan-deep](https://hf.co/papers/1809.11096) blocks, and [Sber-MoVQGAN](https://github.com/ai-forever/MoVQGAN) to decode the latents into images. Text understanding and generated image quality are primarily achieved by using a larger text encoder and UNet.
+
This guide will show you how to use the Kandinsky models for text-to-image, image-to-image, inpainting, interpolation, and more.
Before you begin, make sure you have the following libraries installed:
@@ -33,6 +35,10 @@ Before you begin, make sure you have the following libraries installed:
Kandinsky 2.1 and 2.2 usage is very similar! The only difference is Kandinsky 2.2 doesn't accept `prompt` as an input when decoding the latents. Instead, Kandinsky 2.2 only accepts `image_embeds` during decoding.
+
+
+Kandinsky 3 has a more concise architecture and it doesn't require a prior model. This means it's usage is identical to other diffusion models like [Stable Diffusion XL](sdxl).
+
## Text-to-image
@@ -91,6 +97,23 @@ image
+
+
+
+Kandinsky 3 doesn't require a prior model so you can directly load the [`Kandinsky3Pipeline`] and pass a prompt to generate an image:
+
+```py
+from diffusers import Kandinsky3Pipeline
+import torch
+
+pipeline = Kandinsky3Pipeline.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
+pipeline.enable_model_cpu_offload()
+
+prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
+image = pipeline(prompt).images[0]
+image
+```
+
@@ -161,6 +184,20 @@ prior_pipeline = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kan
pipeline = KandinskyV22Img2ImgPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
```
+
+
+
+Kandinsky 3 doesn't require a prior model so you can directly load the image-to-image pipeline:
+
+```py
+from diffusers import Kandinsky3Img2ImgPipeline
+from diffusers.utils import load_image
+import torch
+
+pipeline = Kandinsky3Img2ImgPipeline.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
+pipeline.enable_model_cpu_offload()
+```
+
@@ -218,6 +255,14 @@ make_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], r
+
+
+
+```py
+image = pipeline(prompt, negative_prompt=negative_prompt, image=image, strength=0.75, num_inference_steps=25).images[0]
+image
+```
+
diff --git a/docs/source/en/using-diffusers/svd.md b/docs/source/en/using-diffusers/svd.md
index 4fdb2608aa76..7fd29284cbd0 100644
--- a/docs/source/en/using-diffusers/svd.md
+++ b/docs/source/en/using-diffusers/svd.md
@@ -53,8 +53,9 @@ frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```
-
-
+
+
+
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index f032634a11f0..df5477d0d643 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -54,7 +54,7 @@
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
+from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr, unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
@@ -62,16 +62,51 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
+# TODO: This function should be removed once training scripts are rewritten in PEFT
+def text_encoder_lora_state_dict(text_encoder):
+ state_dict = {}
+
+ def text_encoder_attn_modules(text_encoder):
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection
+
+ attn_modules = []
+
+ if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+ for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+ name = f"text_model.encoder.layers.{i}.self_attn"
+ mod = layer.self_attn
+ attn_modules.append((name, mod))
+
+ return attn_modules
+
+ for name, module in text_encoder_attn_modules(text_encoder):
+ for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+ return state_dict
+
+
def save_model_card(
repo_id: str,
images=None,
base_model=str,
train_text_encoder=False,
+ train_text_encoder_ti=False,
+ token_abstraction_dict=None,
instance_prompt=str,
validation_prompt=str,
repo_folder=None,
@@ -83,10 +118,33 @@ def save_model_card(
img_str += f"""
- text: '{validation_prompt if validation_prompt else ' ' }'
output:
- url: >-
+ url:
"image_{i}.png"
"""
+ trigger_str = f"You should use {instance_prompt} to trigger the image generation."
+ diffusers_imports_pivotal = ""
+ diffusers_example_pivotal = ""
+ if train_text_encoder_ti:
+ trigger_str = (
+ "To trigger image generation of trained concept(or concepts) replace each concept identifier "
+ "in you prompt with the new inserted tokens:\n"
+ )
+ diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+ """
+ diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id="{repo_id}", filename="embeddings.safetensors", repo_type="model")
+state_dict = load_file(embedding_path)
+pipeline.load_textual_inversion(state_dict["clip_l"], token=["", ""], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
+pipeline.load_textual_inversion(state_dict["clip_g"], token=["", ""], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
+ """
+ if token_abstraction_dict:
+ for key, value in token_abstraction_dict.items():
+ tokens = "".join(value)
+ trigger_str += f"""
+to trigger concept `{key}` โ use `{tokens}` in your prompt \n
+"""
+
yaml = f"""
---
tags:
@@ -96,9 +154,7 @@ def save_model_card(
- diffusers
- lora
- template:sd-lora
-widget:
{img_str}
----
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
@@ -112,16 +168,35 @@ def save_model_card(
## Model description
-These are {repo_id} LoRA adaption weights for {base_model}.
+### These are {repo_id} LoRA adaption weights for {base_model}.
+
The weights were trained using [DreamBooth](https://dreambooth.github.io/).
+
LoRA for the text encoder was enabled: {train_text_encoder}.
+
+Pivotal tuning was enabled: {train_text_encoder_ti}.
+
Special VAE used for training: {vae_path}.
## Trigger words
-You should use {instance_prompt} to trigger the image generation.
+{trigger_str}
-## Download model
+## Use it with the [๐งจ diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+{diffusers_imports_pivotal}
+pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16).to('cuda')
+pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
+{diffusers_example_pivotal}
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## Download model (use it with UIs such as AUTO1111, Comfy, SD.Next, Invoke)
Weights for this model are available in Safetensors format.
@@ -174,6 +249,12 @@ def parse_args(input_args=None):
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
parser.add_argument(
"--dataset_name",
type=str,
@@ -181,20 +262,26 @@ def parse_args(input_args=None):
help=(
"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
- " or to a folder containing files that ๐ค Datasets can understand."
+ " or to a folder containing files that ๐ค Datasets can understand.To load the custom captions, the training set directory needs to follow the structure of a "
+ "datasets ImageFolder, containing both the images and the corresponding caption for each image. see: "
+ "https://huggingface.co/docs/datasets/image_dataset for more information"
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
- help="The config of the Dataset, leave as None if there's only one config.",
+ help="The config of the Dataset. In some cases, a dataset may have more than one configuration (for example "
+ "if it contains different subsets of data within, and you only wish to load a specific subset - in that case specify the desired configuration using --dataset_config_name. Leave as "
+ "None if there's only one config.",
)
parser.add_argument(
"--instance_data_dir",
type=str,
default=None,
- help=("A folder containing the training data. "),
+ help="A path to local folder containing the training data of instance images. Specify this arg instead of "
+ "--dataset_name if you wish to train using a local folder without custom captions. If you wish to train with custom captions please specify "
+ "--dataset_name instead.",
)
parser.add_argument(
@@ -237,15 +324,18 @@ def parse_args(input_args=None):
)
parser.add_argument(
"--token_abstraction",
+ type=str,
default="TOK",
help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
- "captions - e.g. TOK",
+ "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma seperated string - e.g. "
+ "'TOK,TOK2,TOK3' etc.",
)
parser.add_argument(
"--num_new_tokens_per_abstraction",
+ type=int,
default=2,
- help="number of new tokens inserted to the tokenizers per token_abstraction value when "
+ help="number of new tokens inserted to the tokenizers per token_abstraction identifier when "
"--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
"tokens - ",
)
@@ -455,7 +545,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--train_text_encoder_frac",
type=float,
- default=0.5,
+ default=1.0,
help=("The percentage of epochs to perform text encoder tuning"),
)
@@ -488,7 +578,7 @@ def parse_args(input_args=None):
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
parser.add_argument(
- "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ "--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder"
)
parser.add_argument(
@@ -596,17 +686,6 @@ def parse_args(input_args=None):
"inversion training check `--train_text_encoder_ti`"
)
- if args.train_text_encoder_ti:
- if isinstance(args.token_abstraction, str):
- args.token_abstraction = [args.token_abstraction]
- elif isinstance(args.token_abstraction, List):
- args.token_abstraction = args.token_abstraction
- else:
- raise ValueError(
- f"Unsupported type for --args.token_abstraction: {type(args.token_abstraction)}. "
- f"Supported types are: str (for a single instance identifier) or List[str] (for multiple concepts)"
- )
-
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
@@ -679,12 +758,19 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
def save_embeddings(self, file_path: str):
assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
tensors = {}
+ # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
+ idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
for idx, text_encoder in enumerate(self.text_encoders):
assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
self.tokenizers[0]
), "Tokenizers should be the same."
new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
- tensors[f"text_encoders_{idx}"] = new_token_embeddings
+
+ # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
+ # text_encoder 1) to keep compatible with the ecosystem.
+ # Note: When loading with diffusers, any name can work - simply specify in inference
+ tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings
+ # tensors[f"text_encoders_{idx}"] = new_token_embeddings
save_file(tensors, file_path)
@@ -696,19 +782,6 @@ def dtype(self):
def device(self):
return self.text_encoders[0].device
- # def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder):
- # # Assuming new tokens are of the format
- # self.inserting_toks = [f"" for i in range(loaded_embeddings.shape[0])]
- # special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
- # tokenizer.add_special_tokens(special_tokens_dict)
- # text_encoder.resize_token_embeddings(len(tokenizer))
- #
- # self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
- # assert self.train_ids is not None, "New tokens could not be converted to IDs."
- # text_encoder.text_model.embeddings.token_embedding.weight.data[
- # self.train_ids
- # ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype)
-
@torch.no_grad()
def retract_embeddings(self):
for idx, text_encoder in enumerate(self.text_encoders):
@@ -730,15 +803,6 @@ def retract_embeddings(self):
new_embeddings = new_embeddings * (off_ratio**0.1)
text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
- # def load_embeddings(self, file_path: str):
- # with safe_open(file_path, framework="pt", device=self.device.type) as f:
- # for idx in range(len(self.text_encoders)):
- # text_encoder = self.text_encoders[idx]
- # tokenizer = self.tokenizers[idx]
- #
- # loaded_embeddings = f.get_tensor(f"text_encoders_{idx}")
- # self._load_embeddings(loaded_embeddings, tokenizer, text_encoder)
-
class DreamBoothDataset(Dataset):
"""
@@ -751,6 +815,12 @@ def __init__(
instance_data_root,
instance_prompt,
class_prompt,
+ dataset_name,
+ dataset_config_name,
+ cache_dir,
+ image_column,
+ caption_column,
+ train_text_encoder_ti,
class_data_root=None,
class_num=None,
token_abstraction_dict=None, # token mapping for textual inversion
@@ -765,10 +835,10 @@ def __init__(
self.custom_instance_prompts = None
self.class_prompt = class_prompt
self.token_abstraction_dict = token_abstraction_dict
-
+ self.train_text_encoder_ti = train_text_encoder_ti
# if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
# we load the training data using load_dataset
- if args.dataset_name is not None:
+ if dataset_name is not None:
try:
from datasets import load_dataset
except ImportError:
@@ -781,26 +851,25 @@ def __init__(
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
dataset = load_dataset(
- args.dataset_name,
- args.dataset_config_name,
- cache_dir=args.cache_dir,
+ dataset_name,
+ dataset_config_name,
+ cache_dir=cache_dir,
)
# Preprocessing the datasets.
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
- if args.image_column is None:
+ if image_column is None:
image_column = column_names[0]
logger.info(f"image column defaulting to {image_column}")
else:
- image_column = args.image_column
if image_column not in column_names:
raise ValueError(
- f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ f"`--image_column` value '{image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
instance_images = dataset["train"][image_column]
- if args.caption_column is None:
+ if caption_column is None:
logger.info(
"No caption column provided, defaulting to instance_prompt for all images. If your dataset "
"contains captions/prompts for the images, make sure to specify the "
@@ -808,11 +877,11 @@ def __init__(
)
self.custom_instance_prompts = None
else:
- if args.caption_column not in column_names:
+ if caption_column not in column_names:
raise ValueError(
- f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ f"`--caption_column` value '{caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
- custom_instance_prompts = dataset["train"][args.caption_column]
+ custom_instance_prompts = dataset["train"][caption_column]
# create final list of captions according to --repeats
self.custom_instance_prompts = []
for caption in custom_instance_prompts:
@@ -867,7 +936,7 @@ def __getitem__(self, index):
if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images]
if caption:
- if args.train_text_encoder_ti:
+ if self.train_text_encoder_ti:
# replace instances of --token_abstraction in caption with the new tokens: "" etc.
for token_abs, token_replacement in self.token_abstraction_dict.items():
caption = caption.replace(token_abs, "".join(token_replacement))
@@ -1021,6 +1090,7 @@ def main(args):
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
revision=args.revision,
+ variant=args.variant,
)
pipeline.set_progress_bar_config(disable=True)
@@ -1052,17 +1122,25 @@ def main(args):
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
+ model_id = args.hub_model_id or Path(args.output_dir).name
+ repo_id = None
if args.push_to_hub:
- repo_id = create_repo(
- repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
- ).repo_id
+ repo_id = create_repo(repo_id=model_id, exist_ok=True, token=args.hub_token).repo_id
# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ variant=args.variant,
+ use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ variant=args.variant,
+ use_fast=False,
)
# import correct text encoder classes
@@ -1076,10 +1154,10 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_one = text_encoder_cls_one.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
)
vae_path = (
args.pretrained_model_name_or_path
@@ -1087,16 +1165,24 @@ def main(args):
else args.pretrained_vae_model_name_or_path
)
vae = AutoencoderKL.from_pretrained(
- vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
)
unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
if args.train_text_encoder_ti:
+ # we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK,
+ # TOK2" -> ["TOK", "TOK2"] etc.
+ token_abstraction_list = "".join(args.token_abstraction.split()).split(",")
+ logger.info(f"list of token identifiers: {token_abstraction_list}")
+
token_abstraction_dict = {}
token_idx = 0
- for i, token in enumerate(args.token_abstraction):
+ for i, token in enumerate(token_abstraction_list):
token_abstraction_dict[token] = [
f"" for j in range(args.num_new_tokens_per_abstraction)
]
@@ -1216,6 +1302,8 @@ def main(args):
text_lora_parameters_one = []
for name, param in text_encoder_one.named_parameters():
if "token_embedding" in name:
+ # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+ param = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_one.append(param)
else:
@@ -1223,6 +1311,8 @@ def main(args):
text_lora_parameters_two = []
for name, param in text_encoder_two.named_parameters():
if "token_embedding" in name:
+ # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+ param = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_two.append(param)
else:
@@ -1309,12 +1399,16 @@ def load_model_hook(models, input_dir):
# different learning rate for text encoder and unet
text_lora_parameters_one_with_lr = {
"params": text_lora_parameters_one,
- "weight_decay": args.adam_weight_decay_text_encoder,
+ "weight_decay": args.adam_weight_decay_text_encoder
+ if args.adam_weight_decay_text_encoder
+ else args.adam_weight_decay,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
text_lora_parameters_two_with_lr = {
"params": text_lora_parameters_two,
- "weight_decay": args.adam_weight_decay_text_encoder,
+ "weight_decay": args.adam_weight_decay_text_encoder
+ if args.adam_weight_decay_text_encoder
+ else args.adam_weight_decay,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
params_to_optimize = [
@@ -1399,6 +1493,12 @@ def load_model_hook(models, input_dir):
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_prompt=args.class_prompt,
+ dataset_name=args.dataset_name,
+ dataset_config_name=args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ image_column=args.image_column,
+ train_text_encoder_ti=args.train_text_encoder_ti,
+ caption_column=args.caption_column,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
class_num=args.num_class_images,
@@ -1494,6 +1594,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
+ if args.train_text_encoder_ti and args.validation_prompt:
+ # replace instances of --token_abstraction in validation prompt with the new tokens: "" etc.
+ for token_abs, token_replacement in train_dataset.token_abstraction_dict.items():
+ args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
+ print("validation prompt:", args.validation_prompt)
+
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1593,27 +1699,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
if epoch == num_train_epochs_text_encoder:
print("PIVOT HALFWAY", epoch)
# stopping optimization of text_encoder params
- params_to_optimize = params_to_optimize[:1]
- # reinitializing the optimizer to optimize only on unet params
- if args.optimizer.lower() == "prodigy":
- optimizer = optimizer_class(
- params_to_optimize,
- lr=args.learning_rate,
- betas=(args.adam_beta1, args.adam_beta2),
- beta3=args.prodigy_beta3,
- weight_decay=args.adam_weight_decay,
- eps=args.adam_epsilon,
- decouple=args.prodigy_decouple,
- use_bias_correction=args.prodigy_use_bias_correction,
- safeguard_warmup=args.prodigy_safeguard_warmup,
- )
- else: # AdamW or 8-bit-AdamW
- optimizer = optimizer_class(
- params_to_optimize,
- betas=(args.adam_beta1, args.adam_beta2),
- weight_decay=args.adam_weight_decay,
- eps=args.adam_epsilon,
- )
+ # re setting the optimizer to optimize only on unet params
+ optimizer.param_groups[1]["lr"] = 0.0
+ optimizer.param_groups[2]["lr"] = 0.0
+
else:
# still optimizng the text encoder
text_encoder_one.train()
@@ -1628,7 +1717,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
prompts = batch["prompts"]
- print(prompts)
+ # print(prompts)
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
if freeze_text_encoder:
@@ -1801,12 +1890,18 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
f" {args.validation_prompt}."
)
# create pipeline
- if not args.train_text_encoder:
+ if freeze_text_encoder:
text_encoder_one = text_encoder_cls_one.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=args.revision,
+ variant=args.variant,
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder_2",
+ revision=args.revision,
+ variant=args.variant,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
@@ -1815,6 +1910,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
unet=accelerator.unwrap_model(unet),
revision=args.revision,
+ variant=args.variant,
torch_dtype=weight_dtype,
)
@@ -1892,10 +1988,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
+ variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
- args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
@@ -1938,21 +2039,23 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
}
)
- if args.push_to_hub:
- if args.train_text_encoder_ti:
- embedding_handler.save_embeddings(
- f"{args.output_dir}/embeddings.safetensors",
- )
- save_model_card(
- repo_id,
- images=images,
- base_model=args.pretrained_model_name_or_path,
- train_text_encoder=args.train_text_encoder,
- instance_prompt=args.instance_prompt,
- validation_prompt=args.validation_prompt,
- repo_folder=args.output_dir,
- vae_path=args.pretrained_vae_model_name_or_path,
+ if args.train_text_encoder_ti:
+ embedding_handler.save_embeddings(
+ f"{args.output_dir}/embeddings.safetensors",
)
+ save_model_card(
+ model_id if not args.push_to_hub else repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ train_text_encoder_ti=args.train_text_encoder_ti,
+ token_abstraction_dict=train_dataset.token_abstraction_dict,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+ if args.push_to_hub:
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
diff --git a/examples/community/README.md b/examples/community/README.md
index aee6ffee09c7..98780edeccf7 100755
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -48,8 +48,9 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
+| Regional Prompting Pipeline | Assign multiple prompts for different regions | [Regional Prompting Pipeline](#regional-prompting-pipeline) | - | [hako-mikan](https://github.com/hako-mikan) |
| LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) |
-|
+| DemoFusion Pipeline | Implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973) | [DemoFusion Pipeline](#DemoFusion) | - | [Ruoyi Du](https://github.com/RuoyiDu) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
```py
@@ -77,6 +78,7 @@ from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(
"longlian/lmd_plus",
custom_pipeline="llm_grounded_diffusion",
+ custom_revision="main",
variant="fp16", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
@@ -2524,6 +2526,181 @@ images[0].save("controlnet_and_adapter_inpaint.png")
```
+### Regional Prompting Pipeline
+This pipeline is a port of the [Regional Prompter extension](https://github.com/hako-mikan/sd-webui-regional-prompter) for [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) to diffusers.
+This code implements a pipeline for the Stable Diffusion model, enabling the division of the canvas into multiple regions, with different prompts applicable to each region. Users can specify regions in two ways: using `Cols` and `Rows` modes for grid-like divisions, or the `Prompt` mode for regions calculated based on prompts.
+
+![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline1.png)
+
+### Usage
+### Sample Code
+```
+from from examples.community.regional_prompting_stable_diffusion import RegionalPromptingStableDiffusionPipeline
+pipe = RegionalPromptingStableDiffusionPipeline.from_single_file(model_path, vae=vae)
+
+rp_args = {
+ "mode":"rows",
+ "div": "1;1;1"
+}
+
+prompt ="""
+green hair twintail BREAK
+red blouse BREAK
+blue skirt
+"""
+
+images = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ guidance_scale=7.5,
+ height = 768,
+ width = 512,
+ num_inference_steps =20,
+ num_images_per_prompt = 1,
+ rp_args = rp_args
+ ).images
+
+time = time.strftime(r"%Y%m%d%H%M%S")
+i = 1
+for image in images:
+ i += 1
+ fileName = f'img-{time}-{i+1}.png'
+ image.save(fileName)
+```
+### Cols, Rows mode
+In the Cols, Rows mode, you can split the screen vertically and horizontally and assign prompts to each region. The split ratio can be specified by 'div', and you can set the division ratio like '3;3;2' or '0.1;0.5'. Furthermore, as will be described later, you can also subdivide the split Cols, Rows to specify more complex regions.
+
+In this image, the image is divided into three parts, and a separate prompt is applied to each. The prompts are divided by 'BREAK', and each is applied to the respective region.
+![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline2.png)
+```
+green hair twintail BREAK
+red blouse BREAK
+blue skirt
+```
+
+### 2-Dimentional division
+The prompt consists of instructions separated by the term `BREAK` and is assigned to different regions of a two-dimensional space. The image is initially split in the main splitting direction, which in this case is rows, due to the presence of a single semicolon`;`, dividing the space into an upper and a lower section. Additional sub-splitting is then applied, indicated by commas. The upper row is split into ratios of `2:1:1`, while the lower row is split into a ratio of `4:6`. Rows themselves are split in a `1:2` ratio. According to the reference image, the blue sky is designated as the first region, green hair as the second, the bookshelf as the third, and so on, in a sequence based on their position from the top left. The terrarium is placed on the desk in the fourth region, and the orange dress and sofa are in the fifth region, conforming to their respective splits.
+```
+rp_args = {
+ "mode":"rows",
+ "div": "1,2,1,1;2,4,6"
+}
+
+prompt ="""
+blue sky BREAK
+green hair BREAK
+book shelf BREAK
+terrarium on desk BREAK
+orange dress and sofa
+"""
+```
+![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline4.png)
+
+### Prompt Mode
+There are limitations to methods of specifying regions in advance. This is because specifying regions can be a hindrance when designating complex shapes or dynamic compositions. In the region specified by the prompt, the regions is determined after the image generation has begun. This allows us to accommodate compositions and complex regions.
+For further infomagen, see [here](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/main/prompt_en.md).
+### syntax
+```
+baseprompt target1 target2 BREAK
+effect1, target1 BREAK
+effect2 ,target2
+```
+
+First, write the base prompt. In the base prompt, write the words (target1, target2) for which you want to create a mask. Next, separate them with BREAK. Next, write the prompt corresponding to target1. Then enter a comma and write target1. The order of the targets in the base prompt and the order of the BREAK-separated targets can be back to back.
+
+```
+target2 baseprompt target1 BREAK
+effect1, target1 BREAK
+effect2 ,target2
+```
+is also effective.
+
+### Sample
+In this example, masks are calculated for shirt, tie, skirt, and color prompts are specified only for those regions.
+```
+rp_args = {
+ "mode":"prompt-ex",
+ "save_mask":True,
+ "th": "0.4,0.6,0.6",
+}
+
+prompt ="""
+a girl in street with shirt, tie, skirt BREAK
+red, shirt BREAK
+green, tie BREAK
+blue , skirt
+"""
+```
+![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline3.png)
+### threshold
+The threshold used to determine the mask created by the prompt. This can be set as many times as there are masks, as the range varies widely depending on the target prompt. If multiple regions are used, enter them separated by commas. For example, hair tends to be ambiguous and requires a small value, while face tends to be large and requires a small value. These should be ordered by BREAK.
+
+```
+a lady ,hair, face BREAK
+red, hair BREAK
+tanned ,face
+```
+`threshold : 0.4,0.6`
+If only one input is given for multiple regions, they are all assumed to be the same value.
+
+### Prompt and Prompt-EX
+The difference is that in Prompt, duplicate regions are added, whereas in Prompt-EX, duplicate regions are overwritten sequentially. Since they are processed in order, setting a TARGET with a large regions first makes it easier for the effect of small regions to remain unmuffled.
+
+### Accuracy
+In the case of a 512 x 512 image, Attention mode reduces the size of the region to about 8 x 8 pixels deep in the U-Net, so that small regions get mixed up; Latent mode calculates 64*64, so that the region is exact.
+```
+girl hair twintail frills,ribbons, dress, face BREAK
+girl, ,face
+```
+
+### Mask
+When an image is generated, the generated mask is displayed. It is generated at the same size as the image, but is actually used at a much smaller size.
+
+
+### Use common prompt
+You can attach the prompt up to ADDCOMM to all prompts by separating it first with ADDCOMM. This is useful when you want to include elements common to all regions. For example, when generating pictures of three people with different appearances, it's necessary to include the instruction of 'three people' in all regions. It's also useful when inserting quality tags and other things."For example, if you write as follows:
+```
+best quality, 3persons in garden, ADDCOMM
+a girl white dress BREAK
+a boy blue shirt BREAK
+an old man red suit
+```
+If common is enabled, this prompt is converted to the following:
+```
+best quality, 3persons in garden, a girl white dress BREAK
+best quality, 3persons in garden, a boy blue shirt BREAK
+best quality, 3persons in garden, an old man red suit
+```
+### Negative prompt
+Negative prompts are equally effective across all regions, but it is possible to set region-specific prompts for negative prompts as well. The number of BREAKs must be the same as the number of prompts. If the number of prompts does not match, the negative prompts will be used without being divided into regions.
+
+### Parameters
+To activate Regional Prompter, it is necessary to enter settings in rp_args. The items that can be set are as follows. rp_args is a dictionary type.
+
+### Input Parameters
+Parameters are specified through the `rp_arg`(dictionary type).
+
+```
+rp_args = {
+ "mode":"rows",
+ "div": "1;1;1"
+}
+
+pipe(prompt =prompt, rp_args = rp_args)
+```
+
+
+
+### Required Parameters
+- `mode`: Specifies the method for defining regions. Choose from `Cols`, `Rows`, `Prompt` or `Prompt-Ex`. This parameter is case-insensitive.
+- `divide`: Used in `Cols` and `Rows` modes. Details on how to specify this are provided under the respective `Cols` and `Rows` sections.
+- `th`: Used in `Prompt` mode. The method of specification is detailed under the `Prompt` section.
+
+### Optional Parameters
+- `save_mask`: In `Prompt` mode, choose whether to output the generated mask along with the image. The default is `False`.
+
+The Pipeline supports `compel` syntax. Input prompts using the `compel` structure will be automatically applied and processed.
+
## Diffusion Posterior Sampling Pipeline
* Reference paper
```
@@ -2665,3 +2842,86 @@ images[0].save("controlnet_and_adapter_inpaint.png")
* ![dps_mea](https://github.com/tongdaxu/Images/assets/22267548/ff6a33d6-26f0-42aa-88ce-f8a76ba45a13)
* Reconstructed image:
* ![dps_generated_image](https://github.com/tongdaxu/Images/assets/22267548/b74f084d-93f4-4845-83d8-44c0fa758a5f)
+
+### DemoFusion
+This pipeline is the official implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973).
+The original repo can be found at [repo](https://github.com/PRIS-CV/DemoFusion).
+- `view_batch_size` (`int`, defaults to 16):
+ The batch size for multiple denoising paths. Typically, a larger batch size can result in higher efficiency but comes with increased GPU memory requirements.
+
+- `stride` (`int`, defaults to 64):
+ The stride of moving local patches. A smaller stride is better for alleviating seam issues, but it also introduces additional computational overhead and inference time.
+
+- `cosine_scale_1` (`float`, defaults to 3):
+ Control the strength of skip-residual. For specific impacts, please refer to Appendix C in the DemoFusion paper.
+
+- `cosine_scale_2` (`float`, defaults to 1):
+ Control the strength of dilated sampling. For specific impacts, please refer to Appendix C in the DemoFusion paper.
+
+- `cosine_scale_3` (`float`, defaults to 1):
+ Control the strength of the Gaussian filter. For specific impacts, please refer to Appendix C in the DemoFusion paper.
+
+- `sigma` (`float`, defaults to 1):
+ The standard value of the Gaussian filter. Larger sigma promotes the global guidance of dilated sampling, but has the potential of over-smoothing.
+
+- `multi_decoder` (`bool`, defaults to True):
+ Determine whether to use a tiled decoder. Generally, when the resolution exceeds 3072x3072, a tiled decoder becomes necessary.
+
+- `show_image` (`bool`, defaults to False):
+ Determine whether to show intermediate results during generation.
+```
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ custom_pipeline="pipeline_demofusion_sdxl",
+ custom_revision="main",
+ torch_dtype=torch.float16,
+)
+pipe = pipe.to("cuda")
+
+prompt = "Envision a portrait of an elderly woman, her face a canvas of time, framed by a headscarf with muted tones of rust and cream. Her eyes, blue like faded denim. Her attire, simple yet dignified."
+negative_prompt = "blurry, ugly, duplicate, poorly drawn, deformed, mosaic"
+
+images = pipe(
+ prompt,
+ negative_prompt=negative_prompt,
+ height=3072,
+ width=3072,
+ view_batch_size=16,
+ stride=64,
+ num_inference_steps=50,
+ guidance_scale=7.5,
+ cosine_scale_1=3,
+ cosine_scale_2=1,
+ cosine_scale_3=1,
+ sigma=0.8,
+ multi_decoder=True,
+ show_image=True
+)
+```
+You can display and save the generated images as:
+```
+def image_grid(imgs, save_path=None):
+
+ w = 0
+ for i, img in enumerate(imgs):
+ h_, w_ = imgs[i].size
+ w += w_
+ h = h_
+ grid = Image.new('RGB', size=(w, h))
+ grid_w, grid_h = grid.size
+
+ w = 0
+ for i, img in enumerate(imgs):
+ h_, w_ = imgs[i].size
+ grid.paste(img, box=(w, h - h_))
+ if save_path != None:
+ img.save(save_path + "/img_{}.jpg".format((i + 1) * 1024))
+ w += w_
+
+ return grid
+
+image_grid(images, save_path="./outputs/")
+```
+ ![output_example](https://github.com/PRIS-CV/DemoFusion/blob/main/output_example.png)
diff --git a/examples/community/llm_grounded_diffusion.py b/examples/community/llm_grounded_diffusion.py
index d47c99bb2990..14f4deabcea7 100644
--- a/examples/community/llm_grounded_diffusion.py
+++ b/examples/community/llm_grounded_diffusion.py
@@ -16,6 +16,7 @@
import ast
import gc
+import inspect
import math
import warnings
from collections.abc import Iterable
@@ -23,16 +24,29 @@
import torch
import torch.nn.functional as F
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention import Attention, GatedSelfAttentionDense
from diffusers.models.attention_processor import AttnProcessor2_0
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import logging, replace_example_docstring
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
EXAMPLE_DOC_STRING = """
@@ -44,6 +58,7 @@
>>> pipe = DiffusionPipeline.from_pretrained(
... "longlian/lmd_plus",
... custom_pipeline="llm_grounded_diffusion",
+ ... custom_revision="main",
... variant="fp16", torch_dtype=torch.float16
... )
>>> pipe.enable_model_cpu_offload()
@@ -96,7 +111,12 @@
# All keys in Stable Diffusion models: [('down', 0, 0, 0), ('down', 0, 1, 0), ('down', 1, 0, 0), ('down', 1, 1, 0), ('down', 2, 0, 0), ('down', 2, 1, 0), ('mid', 0, 0, 0), ('up', 1, 0, 0), ('up', 1, 1, 0), ('up', 1, 2, 0), ('up', 2, 0, 0), ('up', 2, 1, 0), ('up', 2, 2, 0), ('up', 3, 0, 0), ('up', 3, 1, 0), ('up', 3, 2, 0)]
# Note that the first up block is `UpBlock2D` rather than `CrossAttnUpBlock2D` and does not have attention. The last index is always 0 in our case since we have one `BasicTransformerBlock` in each `Transformer2DModel`.
-DEFAULT_GUIDANCE_ATTN_KEYS = [("mid", 0, 0, 0), ("up", 1, 0, 0), ("up", 1, 1, 0), ("up", 1, 2, 0)]
+DEFAULT_GUIDANCE_ATTN_KEYS = [
+ ("mid", 0, 0, 0),
+ ("up", 1, 0, 0),
+ ("up", 1, 1, 0),
+ ("up", 1, 2, 0),
+]
def convert_attn_keys(key):
@@ -126,7 +146,15 @@ def scale_proportion(obj_box, H, W):
# Adapted from the parent class `AttnProcessor2_0`
class AttnProcessorWithHook(AttnProcessor2_0):
- def __init__(self, attn_processor_key, hidden_size, cross_attention_dim, hook=None, fast_attn=True, enabled=True):
+ def __init__(
+ self,
+ attn_processor_key,
+ hidden_size,
+ cross_attention_dim,
+ hook=None,
+ fast_attn=True,
+ enabled=True,
+ ):
super().__init__()
self.attn_processor_key = attn_processor_key
self.hidden_size = hidden_size
@@ -165,15 +193,16 @@ def __call__(
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
- query = attn.to_q(hidden_states, scale=scale)
+ args = () if USE_PEFT_BACKEND else (scale,)
+ query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- key = attn.to_k(encoder_hidden_states, scale=scale)
- value = attn.to_v(encoder_hidden_states, scale=scale)
+ key = attn.to_k(encoder_hidden_states, *args)
+ value = attn.to_v(encoder_hidden_states, *args)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
@@ -186,7 +215,13 @@ def __call__(
if self.hook is not None and self.enabled:
# Call the hook with query, key, value, and attention maps
- self.hook(self.attn_processor_key, query_batch_dim, key_batch_dim, value_batch_dim, attention_probs)
+ self.hook(
+ self.attn_processor_key,
+ query_batch_dim,
+ key_batch_dim,
+ value_batch_dim,
+ attention_probs,
+ )
if self.fast_attn:
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
@@ -202,7 +237,12 @@ def __call__(
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
@@ -211,7 +251,7 @@ def __call__(
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
- hidden_states = attn.to_out[0](hidden_states, scale=scale)
+ hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -226,7 +266,9 @@ def __call__(
return hidden_states
-class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
+class LLMGroundedDiffusionPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
+):
r"""
Pipeline for layout-grounded text-to-image generation using LLM-grounded Diffusion (LMD+): https://arxiv.org/pdf/2305.13655.pdf.
@@ -257,6 +299,11 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
Whether a safety checker is needed for this pipeline.
"""
+ model_cpu_offload_seq = "text_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
objects_text = "Objects: "
bg_prompt_text = "Background prompt: "
bg_prompt_text_no_trailing_space = bg_prompt_text.rstrip()
@@ -272,12 +319,91 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
- super().__init__(
- vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+ # This is copied from StableDiffusionPipeline, with hook initizations for LMD+.
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+ # Initialize the attention hooks for LLM-grounded Diffusion
self.register_attn_hooks(unet)
self._saved_attn = None
@@ -464,7 +590,14 @@ def get_token_map(self, prompt, padding="do_not_pad", verbose=False):
return token_map
- def get_phrase_indices(self, prompt, phrases, token_map=None, add_suffix_if_not_found=False, verbose=False):
+ def get_phrase_indices(
+ self,
+ prompt,
+ phrases,
+ token_map=None,
+ add_suffix_if_not_found=False,
+ verbose=False,
+ ):
for obj in phrases:
# Suffix the prompt with object name for attention guidance if object is not in the prompt, using "|" to separate the prompt and the suffix
if obj not in prompt:
@@ -485,7 +618,14 @@ def get_phrase_indices(self, prompt, phrases, token_map=None, add_suffix_if_not_
phrase_token_map_str = " ".join(phrase_token_map)
if verbose:
- logger.info("Full str:", token_map_str, "Substr:", phrase_token_map_str, "Phrase:", phrases)
+ logger.info(
+ "Full str:",
+ token_map_str,
+ "Substr:",
+ phrase_token_map_str,
+ "Phrase:",
+ phrases,
+ )
# Count the number of token before substr
# The substring comes with a trailing space that needs to be removed by minus one in the index.
@@ -552,7 +692,15 @@ def add_ca_loss_per_attn_map_to_loss(
return loss
- def compute_ca_loss(self, saved_attn, bboxes, phrase_indices, guidance_attn_keys, verbose=False, **kwargs):
+ def compute_ca_loss(
+ self,
+ saved_attn,
+ bboxes,
+ phrase_indices,
+ guidance_attn_keys,
+ verbose=False,
+ **kwargs,
+ ):
"""
The `saved_attn` is supposed to be passed to `save_attn_to_dict` in `cross_attention_kwargs` prior to computing ths loss.
`AttnProcessor` will put attention maps into the `save_attn_to_dict`.
@@ -605,6 +753,7 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -662,6 +811,7 @@ def __call__(
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -724,9 +874,10 @@ def __call__(
phrase_indices = []
prompt_parsed = []
for prompt_item in prompt:
- phrase_indices_parsed_item, prompt_parsed_item = self.get_phrase_indices(
- prompt_item, add_suffix_if_not_found=True
- )
+ (
+ phrase_indices_parsed_item,
+ prompt_parsed_item,
+ ) = self.get_phrase_indices(prompt_item, add_suffix_if_not_found=True)
phrase_indices.append(phrase_indices_parsed_item)
prompt_parsed.append(prompt_parsed_item)
prompt = prompt_parsed
@@ -759,6 +910,11 @@ def __call__(
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ if ip_adapter_image is not None:
+ image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+ if self.do_classifier_free_guidance:
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
@@ -801,7 +957,10 @@ def __call__(
if n_objs:
cond_boxes[:n_objs] = torch.tensor(boxes)
text_embeddings = torch.zeros(
- max_objs, self.unet.config.cross_attention_dim, device=device, dtype=self.text_encoder.dtype
+ max_objs,
+ self.unet.config.cross_attention_dim,
+ device=device,
+ dtype=self.text_encoder.dtype,
)
if n_objs:
text_embeddings[:n_objs] = _text_embeddings
@@ -833,6 +992,9 @@ def __call__(
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ # 6.1 Add image embeds for IP-Adapter
+ added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
loss_attn = torch.tensor(10000.0)
# 7. Denoising loop
@@ -869,6 +1031,7 @@ def __call__(
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
).sample
# perform guidance
@@ -1013,3 +1176,438 @@ def latent_lmd_guidance(
self.enable_attn_hook(enabled=False)
return latents, loss
+
+ # Below are methods copied from StableDiffusionPipeline
+ # The design choice of not inheriting from StableDiffusionPipeline is discussed here: https://github.com/huggingface/diffusers/pull/5993#issuecomment-1834258517
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ **kwargs,
+ ):
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+ prompt_embeds_tuple = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ **kwargs,
+ )
+
+ # concatenate for backwards comp
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: procecss multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: procecss multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (ฮท) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to ฮท in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ timesteps (`torch.Tensor`):
+ generate embedding vectors at these timesteps
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_rescale
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
diff --git a/examples/community/pipeline_demofusion_sdxl.py b/examples/community/pipeline_demofusion_sdxl.py
new file mode 100644
index 000000000000..98508b7ff89c
--- /dev/null
+++ b/examples/community/pipeline_demofusion_sdxl.py
@@ -0,0 +1,1414 @@
+import inspect
+import os
+import random
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import matplotlib.pyplot as plt
+import torch
+import torch.nn.functional as F
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ LoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ is_invisible_watermark_available,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_invisible_watermark_available():
+ from diffusers.pipelines.stable_diffusion_xl.watermark import (
+ StableDiffusionXLWatermarker,
+ )
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionXLPipeline
+
+ >>> pipe = StableDiffusionXLPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
+ x_coord = torch.arange(kernel_size)
+ gaussian_1d = torch.exp(-((x_coord - (kernel_size - 1) / 2) ** 2) / (2 * sigma**2))
+ gaussian_1d = gaussian_1d / gaussian_1d.sum()
+ gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
+ kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)
+
+ return kernel
+
+
+def gaussian_filter(latents, kernel_size=3, sigma=1.0):
+ channels = latents.shape[1]
+ kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
+ blurred_latents = F.conv2d(latents, kernel, padding=kernel_size // 2, groups=channels)
+
+ return blurred_latents
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class DemoFusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+ watermarker will be used.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.default_sample_size = self.unet.config.sample_size
+
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ # textual inversion: procecss multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (ฮท) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to ฮท in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ num_images_per_prompt=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ # DemoFusion specific checks
+ if max(height, width) % 1024 != 0:
+ raise ValueError(
+ f"the larger one of `height` and `width` has to be divisible by 1024 but are {height} and {width}."
+ )
+
+ if num_images_per_prompt != 1:
+ warnings.warn("num_images_per_prompt != 1 is not supported by DemoFusion and will be ignored.")
+ num_images_per_prompt = 1
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
+ def get_views(self, height, width, window_size=128, stride=64, random_jitter=False):
+ height //= self.vae_scale_factor
+ width //= self.vae_scale_factor
+ num_blocks_height = int((height - window_size) / stride - 1e-6) + 2 if height > window_size else 1
+ num_blocks_width = int((width - window_size) / stride - 1e-6) + 2 if width > window_size else 1
+ total_num_blocks = int(num_blocks_height * num_blocks_width)
+ views = []
+ for i in range(total_num_blocks):
+ h_start = int((i // num_blocks_width) * stride)
+ h_end = h_start + window_size
+ w_start = int((i % num_blocks_width) * stride)
+ w_end = w_start + window_size
+
+ if h_end > height:
+ h_start = int(h_start + height - h_end)
+ h_end = int(height)
+ if w_end > width:
+ w_start = int(w_start + width - w_end)
+ w_end = int(width)
+ if h_start < 0:
+ h_end = int(h_end - h_start)
+ h_start = 0
+ if w_start < 0:
+ w_end = int(w_end - w_start)
+ w_start = 0
+
+ if random_jitter:
+ jitter_range = (window_size - stride) // 4
+ w_jitter = 0
+ h_jitter = 0
+ if (w_start != 0) and (w_end != width):
+ w_jitter = random.randint(-jitter_range, jitter_range)
+ elif (w_start == 0) and (w_end != width):
+ w_jitter = random.randint(-jitter_range, 0)
+ elif (w_start != 0) and (w_end == width):
+ w_jitter = random.randint(0, jitter_range)
+ if (h_start != 0) and (h_end != height):
+ h_jitter = random.randint(-jitter_range, jitter_range)
+ elif (h_start == 0) and (h_end != height):
+ h_jitter = random.randint(-jitter_range, 0)
+ elif (h_start != 0) and (h_end == height):
+ h_jitter = random.randint(0, jitter_range)
+ h_start += h_jitter + jitter_range
+ h_end += h_jitter + jitter_range
+ w_start += w_jitter + jitter_range
+ w_end += w_jitter + jitter_range
+
+ views.append((h_start, h_end, w_start, w_end))
+ return views
+
+ def tiled_decode(self, latents, current_height, current_width):
+ core_size = self.unet.config.sample_size // 4
+ core_stride = core_size
+ pad_size = self.unet.config.sample_size // 4 * 3
+ decoder_view_batch_size = 1
+
+ views = self.get_views(current_height, current_width, stride=core_stride, window_size=core_size)
+ views_batch = [views[i : i + decoder_view_batch_size] for i in range(0, len(views), decoder_view_batch_size)]
+ latents_ = F.pad(latents, (pad_size, pad_size, pad_size, pad_size), "constant", 0)
+ image = torch.zeros(latents.size(0), 3, current_height, current_width).to(latents.device)
+ count = torch.zeros_like(image).to(latents.device)
+ # get the latents corresponding to the current view coordinates
+ with self.progress_bar(total=len(views_batch)) as progress_bar:
+ for j, batch_view in enumerate(views_batch):
+ len(batch_view)
+ latents_for_view = torch.cat(
+ [
+ latents_[:, :, h_start : h_end + pad_size * 2, w_start : w_end + pad_size * 2]
+ for h_start, h_end, w_start, w_end in batch_view
+ ]
+ )
+ image_patch = self.vae.decode(latents_for_view / self.vae.config.scaling_factor, return_dict=False)[0]
+ h_start, h_end, w_start, w_end = views[j]
+ h_start, h_end, w_start, w_end = (
+ h_start * self.vae_scale_factor,
+ h_end * self.vae_scale_factor,
+ w_start * self.vae_scale_factor,
+ w_end * self.vae_scale_factor,
+ )
+ p_h_start, p_h_end, p_w_start, p_w_end = (
+ pad_size * self.vae_scale_factor,
+ image_patch.size(2) - pad_size * self.vae_scale_factor,
+ pad_size * self.vae_scale_factor,
+ image_patch.size(3) - pad_size * self.vae_scale_factor,
+ )
+ image[:, :, h_start:h_end, w_start:w_end] += image_patch[:, :, p_h_start:p_h_end, p_w_start:p_w_end]
+ count[:, :, h_start:h_end, w_start:w_end] += 1
+ progress_bar.update()
+ image = image / count
+
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = False,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ ################### DemoFusion specific parameters ####################
+ view_batch_size: int = 16,
+ multi_decoder: bool = True,
+ stride: Optional[int] = 64,
+ cosine_scale_1: Optional[float] = 3.0,
+ cosine_scale_2: Optional[float] = 1.0,
+ cosine_scale_3: Optional[float] = 1.0,
+ sigma: Optional[float] = 0.8,
+ show_image: bool = False,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (ฮท) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `ฯ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should be as same
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ ################### DemoFusion specific parameters ####################
+ view_batch_size (`int`, defaults to 16):
+ The batch size for multiple denoising paths. Typically, a larger batch size can result in higher
+ efficiency but comes with increased GPU memory requirements.
+ multi_decoder (`bool`, defaults to True):
+ Determine whether to use a tiled decoder. Generally, when the resolution exceeds 3072x3072,
+ a tiled decoder becomes necessary.
+ stride (`int`, defaults to 64):
+ The stride of moving local patches. A smaller stride is better for alleviating seam issues,
+ but it also introduces additional computational overhead and inference time.
+ cosine_scale_1 (`float`, defaults to 3):
+ Control the strength of skip-residual. For specific impacts, please refer to Appendix C
+ in the DemoFusion paper.
+ cosine_scale_2 (`float`, defaults to 1):
+ Control the strength of dilated sampling. For specific impacts, please refer to Appendix C
+ in the DemoFusion paper.
+ cosine_scale_3 (`float`, defaults to 1):
+ Control the strength of the gaussion filter. For specific impacts, please refer to Appendix C
+ in the DemoFusion paper.
+ sigma (`float`, defaults to 1):
+ The standerd value of the gaussian filter.
+ show_image (`bool`, defaults to False):
+ Determine whether to show intermediate results during generation.
+
+ Examples:
+
+ Returns:
+ a `list` with the generated images at each phase.
+ """
+
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ x1_size = self.default_sample_size * self.vae_scale_factor
+
+ height_scale = height / x1_size
+ width_scale = width / x1_size
+ scale_num = int(max(height_scale, width_scale))
+ aspect_ratio = min(height_scale, width_scale) / max(height_scale, width_scale)
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ num_images_per_prompt,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height // scale_num,
+ width // scale_num,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids = self._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+ )
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 7.1 Apply denoising_end
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ output_images = []
+
+ ############################################################### Phase 1 #################################################################
+
+ print("### Phase 1 Denoising ###")
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ latents_for_view = latents
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = latents.repeat_interleave(2, dim=0) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ anchor_mean = latents.mean()
+ anchor_std = latents.std()
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ print("### Phase 1 Decoding ###")
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+ if show_image:
+ plt.figure(figsize=(10, 10))
+ plt.imshow(image[0])
+ plt.axis("off") # Turn off axis numbers and ticks
+ plt.show()
+ output_images.append(image[0])
+
+ ####################################################### Phase 2+ #####################################################
+
+ for current_scale_num in range(2, scale_num + 1):
+ print("### Phase {} Denoising ###".format(current_scale_num))
+ current_height = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num
+ current_width = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num
+ if height > width:
+ current_width = int(current_width * aspect_ratio)
+ else:
+ current_height = int(current_height * aspect_ratio)
+
+ latents = F.interpolate(
+ latents,
+ size=(int(current_height / self.vae_scale_factor), int(current_width / self.vae_scale_factor)),
+ mode="bicubic",
+ )
+
+ noise_latents = []
+ noise = torch.randn_like(latents)
+ for timestep in timesteps:
+ noise_latent = self.scheduler.add_noise(latents, noise, timestep.unsqueeze(0))
+ noise_latents.append(noise_latent)
+ latents = noise_latents[0]
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ count = torch.zeros_like(latents)
+ value = torch.zeros_like(latents)
+ cosine_factor = (
+ 0.5
+ * (
+ 1
+ + torch.cos(
+ torch.pi
+ * (self.scheduler.config.num_train_timesteps - t)
+ / self.scheduler.config.num_train_timesteps
+ )
+ ).cpu()
+ )
+
+ c1 = cosine_factor**cosine_scale_1
+ latents = latents * (1 - c1) + noise_latents[i] * c1
+
+ ############################################# MultiDiffusion #############################################
+
+ views = self.get_views(
+ current_height,
+ current_width,
+ stride=stride,
+ window_size=self.unet.config.sample_size,
+ random_jitter=True,
+ )
+ views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
+
+ jitter_range = (self.unet.config.sample_size - stride) // 4
+ latents_ = F.pad(latents, (jitter_range, jitter_range, jitter_range, jitter_range), "constant", 0)
+
+ count_local = torch.zeros_like(latents_)
+ value_local = torch.zeros_like(latents_)
+
+ for j, batch_view in enumerate(views_batch):
+ vb_size = len(batch_view)
+
+ # get the latents corresponding to the current view coordinates
+ latents_for_view = torch.cat(
+ [
+ latents_[:, :, h_start:h_end, w_start:w_end]
+ for h_start, h_end, w_start, w_end in batch_view
+ ]
+ )
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = latents_for_view
+ latent_model_input = (
+ latent_model_input.repeat_interleave(2, dim=0)
+ if do_classifier_free_guidance
+ else latent_model_input
+ )
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)
+ add_text_embeds_input = torch.cat([add_text_embeds] * vb_size)
+ add_time_ids_input = []
+ for h_start, h_end, w_start, w_end in batch_view:
+ add_time_ids_ = add_time_ids.clone()
+ add_time_ids_[:, 2] = h_start * self.vae_scale_factor
+ add_time_ids_[:, 3] = w_start * self.vae_scale_factor
+ add_time_ids_input.append(add_time_ids_)
+ add_time_ids_input = torch.cat(add_time_ids_input)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds_input, "time_ids": add_time_ids_input}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds_input,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(
+ noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ self.scheduler._init_step_index(t)
+ latents_denoised_batch = self.scheduler.step(
+ noise_pred, t, latents_for_view, **extra_step_kwargs, return_dict=False
+ )[0]
+
+ # extract value from batch
+ for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
+ latents_denoised_batch.chunk(vb_size), batch_view
+ ):
+ value_local[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
+ count_local[:, :, h_start:h_end, w_start:w_end] += 1
+
+ value_local = value_local[
+ :,
+ :,
+ jitter_range : jitter_range + current_height // self.vae_scale_factor,
+ jitter_range : jitter_range + current_width // self.vae_scale_factor,
+ ]
+ count_local = count_local[
+ :,
+ :,
+ jitter_range : jitter_range + current_height // self.vae_scale_factor,
+ jitter_range : jitter_range + current_width // self.vae_scale_factor,
+ ]
+
+ c2 = cosine_factor**cosine_scale_2
+
+ value += value_local / count_local * (1 - c2)
+ count += torch.ones_like(value_local) * (1 - c2)
+
+ ############################################# Dilated Sampling #############################################
+
+ views = [[h, w] for h in range(current_scale_num) for w in range(current_scale_num)]
+ views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
+
+ h_pad = (current_scale_num - (latents.size(2) % current_scale_num)) % current_scale_num
+ w_pad = (current_scale_num - (latents.size(3) % current_scale_num)) % current_scale_num
+ latents_ = F.pad(latents, (w_pad, 0, h_pad, 0), "constant", 0)
+
+ count_global = torch.zeros_like(latents_)
+ value_global = torch.zeros_like(latents_)
+
+ c3 = 0.99 * cosine_factor**cosine_scale_3 + 1e-2
+ std_, mean_ = latents_.std(), latents_.mean()
+ latents_gaussian = gaussian_filter(
+ latents_, kernel_size=(2 * current_scale_num - 1), sigma=sigma * c3
+ )
+ latents_gaussian = (
+ latents_gaussian - latents_gaussian.mean()
+ ) / latents_gaussian.std() * std_ + mean_
+
+ for j, batch_view in enumerate(views_batch):
+ latents_for_view = torch.cat(
+ [latents_[:, :, h::current_scale_num, w::current_scale_num] for h, w in batch_view]
+ )
+ latents_for_view_gaussian = torch.cat(
+ [latents_gaussian[:, :, h::current_scale_num, w::current_scale_num] for h, w in batch_view]
+ )
+
+ vb_size = latents_for_view.size(0)
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = latents_for_view_gaussian
+ latent_model_input = (
+ latent_model_input.repeat_interleave(2, dim=0)
+ if do_classifier_free_guidance
+ else latent_model_input
+ )
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)
+ add_text_embeds_input = torch.cat([add_text_embeds] * vb_size)
+ add_time_ids_input = torch.cat([add_time_ids] * vb_size)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds_input, "time_ids": add_time_ids_input}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds_input,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(
+ noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ self.scheduler._init_step_index(t)
+ latents_denoised_batch = self.scheduler.step(
+ noise_pred, t, latents_for_view, **extra_step_kwargs, return_dict=False
+ )[0]
+
+ # extract value from batch
+ for latents_view_denoised, (h, w) in zip(latents_denoised_batch.chunk(vb_size), batch_view):
+ value_global[:, :, h::current_scale_num, w::current_scale_num] += latents_view_denoised
+ count_global[:, :, h::current_scale_num, w::current_scale_num] += 1
+
+ c2 = cosine_factor**cosine_scale_2
+
+ value_global = value_global[:, :, h_pad:, w_pad:]
+
+ value += value_global * c2
+ count += torch.ones_like(value_global) * c2
+
+ ###########################################################
+
+ latents = torch.where(count > 0, value / count, value)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ #########################################################################################################################################
+
+ latents = (latents - latents.mean()) / latents.std() * anchor_std + anchor_mean
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ print("### Phase {} Decoding ###".format(current_scale_num))
+ if multi_decoder:
+ image = self.tiled_decode(latents, current_height, current_width)
+ else:
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ image = self.image_processor.postprocess(image, output_type=output_type)
+ if show_image:
+ plt.figure(figsize=(10, 10))
+ plt.imshow(image[0])
+ plt.axis("off") # Turn off axis numbers and ticks
+ plt.show()
+ output_images.append(image[0])
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ return output_images
+
+ # Overrride to properly handle the loading and unloading of the additional text encoder.
+ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+ # We could have accessed the unet config from `lora_state_dict()` too. We pass
+ # it here explicitly to be able to tell that it's coming from an SDXL
+ # pipeline.
+
+ # Remove any existing hooks.
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
+ else:
+ raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
+
+ is_model_cpu_offload = False
+ is_sequential_cpu_offload = False
+ recursive = False
+ for _, component in self.components.items():
+ if isinstance(component, torch.nn.Module):
+ if hasattr(component, "_hf_hook"):
+ is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+ is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+ logger.info(
+ "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+ )
+ recursive = is_sequential_cpu_offload
+ remove_hook_from_module(component, recurse=recursive)
+ state_dict, network_alphas = self.lora_state_dict(
+ pretrained_model_name_or_path_or_dict,
+ unet_config=self.unet.config,
+ **kwargs,
+ )
+ self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+ if len(text_encoder_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder,
+ prefix="text_encoder",
+ lora_scale=self.lora_scale,
+ )
+
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+ if len(text_encoder_2_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_2_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder_2,
+ prefix="text_encoder_2",
+ lora_scale=self.lora_scale,
+ )
+
+ # Offload back.
+ if is_model_cpu_offload:
+ self.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ self.enable_sequential_cpu_offload()
+
+ @classmethod
+ def save_lora_weights(
+ self,
+ save_directory: Union[str, os.PathLike],
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ state_dict = {}
+
+ def pack_weights(layers, prefix):
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+ return layers_state_dict
+
+ if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
+ raise ValueError(
+ "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
+ )
+
+ if unet_lora_layers:
+ state_dict.update(pack_weights(unet_lora_layers, "unet"))
+
+ if text_encoder_lora_layers and text_encoder_2_lora_layers:
+ state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+ state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+ self.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ def _remove_text_encoder_monkey_patch(self):
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
diff --git a/examples/community/regional_prompting_stable_diffusion.py b/examples/community/regional_prompting_stable_diffusion.py
new file mode 100644
index 000000000000..525e75bc68b9
--- /dev/null
+++ b/examples/community/regional_prompting_stable_diffusion.py
@@ -0,0 +1,589 @@
+import math
+from typing import Dict, Optional
+
+import torch
+import torchvision.transforms.functional as FF
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import StableDiffusionPipeline
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import USE_PEFT_BACKEND
+
+
+try:
+ from compel import Compel
+except ImportError:
+ Compel = None
+
+KCOMM = "ADDCOMM"
+KBRK = "BREAK"
+
+
+class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
+ r"""
+ Args for Regional Prompting Pipeline:
+ rp_args:dict
+ Required
+ rp_args["mode"]: cols, rows, prompt, prompt-ex
+ for cols, rows mode
+ rp_args["div"]: ex) 1;1;1(Divide into 3 regions)
+ for prompt, prompt-ex mode
+ rp_args["th"]: ex) 0.5,0.5,0.6 (threshold for prompt mode)
+
+ Optional
+ rp_args["save_mask"]: True/False (save masks in prompt mode)
+
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__(
+ vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+ )
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: str,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: str = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ rp_args: Dict[str, str] = None,
+ ):
+ active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt # noqa: E721
+ if negative_prompt is None:
+ negative_prompt = "" if type(prompt) == str else [""] * len(prompt) # noqa: E721
+
+ device = self._execution_device
+ regions = 0
+
+ self.power = int(rp_args["power"]) if "power" in rp_args else 1
+
+ prompts = prompt if type(prompt) == list else [prompt] # noqa: E721
+ n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt] # noqa: E721
+ self.batch = batch = num_images_per_prompt * len(prompts)
+ all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
+ all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
+
+ cn = len(all_prompts_cn) == len(all_n_prompts_cn)
+
+ if Compel:
+ compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)
+
+ def getcompelembs(prps):
+ embl = []
+ for prp in prps:
+ embl.append(compel.build_conditioning_tensor(prp))
+ return torch.cat(embl)
+
+ conds = getcompelembs(all_prompts_cn)
+ unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts)
+ embs = getcompelembs(prompts)
+ n_embs = getcompelembs(n_prompts)
+ prompt = negative_prompt = None
+ else:
+ conds = self.encode_prompt(prompts, device, 1, True)[0]
+ unconds = (
+ self.encode_prompt(n_prompts, device, 1, True)[0]
+ if cn
+ else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
+ )
+ embs = n_embs = None
+
+ if not active:
+ pcallback = None
+ mode = None
+ else:
+ if any(x in rp_args["mode"].upper() for x in ["COL", "ROW"]):
+ mode = "COL" if "COL" in rp_args["mode"].upper() else "ROW"
+ ocells, icells, regions = make_cells(rp_args["div"])
+
+ elif "PRO" in rp_args["mode"].upper():
+ regions = len(all_prompts_p[0])
+ mode = "PROMPT"
+ reset_attnmaps(self)
+ self.ex = "EX" in rp_args["mode"].upper()
+ self.target_tokens = target_tokens = tokendealer(self, all_prompts_p)
+ thresholds = [float(x) for x in rp_args["th"].split(",")]
+
+ orig_hw = (height, width)
+ revers = True
+
+ def pcallback(s_self, step: int, timestep: int, latents: torch.FloatTensor, selfs=None):
+ if "PRO" in mode: # in Prompt mode, make masks from sum of attension maps
+ self.step = step
+
+ if len(self.attnmaps_sizes) > 3:
+ self.history[step] = self.attnmaps.copy()
+ for hw in self.attnmaps_sizes:
+ allmasks = []
+ basemasks = [None] * batch
+ for tt, th in zip(target_tokens, thresholds):
+ for b in range(batch):
+ key = f"{tt}-{b}"
+ _, mask, _ = makepmask(self, self.attnmaps[key], hw[0], hw[1], th, step)
+ mask = mask.unsqueeze(0).unsqueeze(-1)
+ if self.ex:
+ allmasks[b::batch] = [x - mask for x in allmasks[b::batch]]
+ allmasks[b::batch] = [torch.where(x > 0, 1, 0) for x in allmasks[b::batch]]
+ allmasks.append(mask)
+ basemasks[b] = mask if basemasks[b] is None else basemasks[b] + mask
+ basemasks = [1 - mask for mask in basemasks]
+ basemasks = [torch.where(x > 0, 1, 0) for x in basemasks]
+ allmasks = basemasks + allmasks
+
+ self.attnmasks[hw] = torch.cat(allmasks)
+ self.maskready = True
+ return latents
+
+ def hook_forward(module):
+ # diffusers==0.23.2
+ def forward(
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ scale: float = 1.0,
+ ) -> torch.Tensor:
+ attn = module
+ xshape = hidden_states.shape
+ self.hw = (h, w) = split_dims(xshape[1], *orig_hw)
+
+ if revers:
+ nx, px = hidden_states.chunk(2)
+ else:
+ px, nx = hidden_states.chunk(2)
+
+ if cn:
+ hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)], 0)
+ encoder_hidden_states = torch.cat([conds] + [unconds])
+ else:
+ hidden_states = torch.cat([px for i in range(regions)] + [nx], 0)
+ encoder_hidden_states = torch.cat([conds] + [unconds])
+
+ residual = hidden_states
+
+ args = () if USE_PEFT_BACKEND else (scale,)
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ args = () if USE_PEFT_BACKEND else (scale,)
+ query = attn.to_q(hidden_states, *args)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states, *args)
+ value = attn.to_v(encoder_hidden_states, *args)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = scaled_dot_product_attention(
+ self,
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ getattn="PRO" in mode,
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states, *args)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ #### Regional Prompting Col/Row mode
+ if any(x in mode for x in ["COL", "ROW"]):
+ reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])
+ center = reshaped.shape[0] // 2
+ px = reshaped[0:center] if cn else reshaped[0:-batch]
+ nx = reshaped[center:] if cn else reshaped[-batch:]
+ outs = [px, nx] if cn else [px]
+ for out in outs:
+ c = 0
+ for i, ocell in enumerate(ocells):
+ for icell in icells[i]:
+ if "ROW" in mode:
+ out[
+ 0:batch,
+ int(h * ocell[0]) : int(h * ocell[1]),
+ int(w * icell[0]) : int(w * icell[1]),
+ :,
+ ] = out[
+ c * batch : (c + 1) * batch,
+ int(h * ocell[0]) : int(h * ocell[1]),
+ int(w * icell[0]) : int(w * icell[1]),
+ :,
+ ]
+ else:
+ out[
+ 0:batch,
+ int(h * icell[0]) : int(h * icell[1]),
+ int(w * ocell[0]) : int(w * ocell[1]),
+ :,
+ ] = out[
+ c * batch : (c + 1) * batch,
+ int(h * icell[0]) : int(h * icell[1]),
+ int(w * ocell[0]) : int(w * ocell[1]),
+ :,
+ ]
+ c += 1
+ px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
+ hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
+ hidden_states = hidden_states.reshape(xshape)
+
+ #### Regional Prompting Prompt mode
+ elif "PRO" in mode:
+ center = reshaped.shape[0] // 2
+ px = reshaped[0:center] if cn else reshaped[0:-batch]
+ nx = reshaped[center:] if cn else reshaped[-batch:]
+
+ if (h, w) in self.attnmasks and self.maskready:
+
+ def mask(input):
+ out = torch.multiply(input, self.attnmasks[(h, w)])
+ for b in range(batch):
+ for r in range(1, regions):
+ out[b] = out[b] + out[r * batch + b]
+ return out
+
+ px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx)
+ px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
+ hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
+ return hidden_states
+
+ return forward
+
+ def hook_forwards(root_module: torch.nn.Module):
+ for name, module in root_module.named_modules():
+ if "attn2" in name and module.__class__.__name__ == "Attention":
+ module.forward = hook_forward(module)
+
+ hook_forwards(self.unet)
+
+ output = StableDiffusionPipeline(**self.components)(
+ prompt=prompt,
+ prompt_embeds=embs,
+ negative_prompt=negative_prompt,
+ negative_prompt_embeds=n_embs,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback_on_step_end=pcallback,
+ )
+
+ if "save_mask" in rp_args:
+ save_mask = rp_args["save_mask"]
+ else:
+ save_mask = False
+
+ if mode == "PROMPT" and save_mask:
+ saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions)
+
+ return output
+
+
+### Make prompt list for each regions
+def promptsmaker(prompts, batch):
+ out_p = []
+ plen = len(prompts)
+ for prompt in prompts:
+ add = ""
+ if KCOMM in prompt:
+ add, prompt = prompt.split(KCOMM)
+ add = add + " "
+ prompts = prompt.split(KBRK)
+ out_p.append([add + p for p in prompts])
+ out = [None] * batch * len(out_p[0]) * len(out_p)
+ for p, prs in enumerate(out_p): # inputs prompts
+ for r, pr in enumerate(prs): # prompts for regions
+ start = (p + r * plen) * batch
+ out[start : start + batch] = [pr] * batch # P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1...
+ return out, out_p
+
+
+### make regions from ratios
+### ";" makes outercells, "," makes inner cells
+def make_cells(ratios):
+ if ";" not in ratios and "," in ratios:
+ ratios = ratios.replace(",", ";")
+ ratios = ratios.split(";")
+ ratios = [inratios.split(",") for inratios in ratios]
+
+ icells = []
+ ocells = []
+
+ def startend(cells, array):
+ current_start = 0
+ array = [float(x) for x in array]
+ for value in array:
+ end = current_start + (value / sum(array))
+ cells.append([current_start, end])
+ current_start = end
+
+ startend(ocells, [r[0] for r in ratios])
+
+ for inratios in ratios:
+ if 2 > len(inratios):
+ icells.append([[0, 1]])
+ else:
+ add = []
+ startend(add, inratios[1:])
+ icells.append(add)
+
+ return ocells, icells, sum(len(cell) for cell in icells)
+
+
+def make_emblist(self, prompts):
+ with torch.no_grad():
+ tokens = self.tokenizer(
+ prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
+ ).input_ids.to(self.device)
+ embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
+ return embs
+
+
+def split_dims(xs, height, width):
+ xs = xs
+
+ def repeat_div(x, y):
+ while y > 0:
+ x = math.ceil(x / 2)
+ y = y - 1
+ return x
+
+ scale = math.ceil(math.log2(math.sqrt(height * width / xs)))
+ dsh = repeat_div(height, scale)
+ dsw = repeat_div(width, scale)
+ return dsh, dsw
+
+
+##### for prompt mode
+def get_attn_maps(self, attn):
+ height, width = self.hw
+ target_tokens = self.target_tokens
+ if (height, width) not in self.attnmaps_sizes:
+ self.attnmaps_sizes.append((height, width))
+
+ for b in range(self.batch):
+ for t in target_tokens:
+ power = self.power
+ add = attn[b, :, :, t[0] : t[0] + len(t)] ** (power) * (self.attnmaps_sizes.index((height, width)) + 1)
+ add = torch.sum(add, dim=2)
+ key = f"{t}-{b}"
+ if key not in self.attnmaps:
+ self.attnmaps[key] = add
+ else:
+ if self.attnmaps[key].shape[1] != add.shape[1]:
+ add = add.view(8, height, width)
+ add = FF.resize(add, self.attnmaps_sizes[0], antialias=None)
+ add = add.reshape_as(self.attnmaps[key])
+
+ self.attnmaps[key] = self.attnmaps[key] + add
+
+
+def reset_attnmaps(self): # init parameters in every batch
+ self.step = 0
+ self.attnmaps = {} # maked from attention maps
+ self.attnmaps_sizes = [] # height,width set of u-net blocks
+ self.attnmasks = {} # maked from attnmaps for regions
+ self.maskready = False
+ self.history = {}
+
+
+def saveattnmaps(self, output, h, w, th, step, regions):
+ masks = []
+ for i, mask in enumerate(self.history[step].values()):
+ img, _, mask = makepmask(self, mask, h, w, th[i % len(th)], step)
+ if self.ex:
+ masks = [x - mask for x in masks]
+ masks.append(mask)
+ if len(masks) == regions - 1:
+ output.images.extend([FF.to_pil_image(mask) for mask in masks])
+ masks = []
+ else:
+ output.images.append(img)
+
+
+def makepmask(
+ self, mask, h, w, th, step
+): # make masks from attention cache return [for preview, for attention, for Latent]
+ th = th - step * 0.005
+ if 0.05 >= th:
+ th = 0.05
+ mask = torch.mean(mask, dim=0)
+ mask = mask / mask.max().item()
+ mask = torch.where(mask > th, 1, 0)
+ mask = mask.float()
+ mask = mask.view(1, *self.attnmaps_sizes[0])
+ img = FF.to_pil_image(mask)
+ img = img.resize((w, h))
+ mask = FF.resize(mask, (h, w), interpolation=FF.InterpolationMode.NEAREST, antialias=None)
+ lmask = mask
+ mask = mask.reshape(h * w)
+ mask = torch.where(mask > 0.1, 1, 0)
+ return img, mask, lmask
+
+
+def tokendealer(self, all_prompts):
+ for prompts in all_prompts:
+ targets = [p.split(",")[-1] for p in prompts[1:]]
+ tt = []
+
+ for target in targets:
+ ptokens = (
+ self.tokenizer(
+ prompts,
+ max_length=self.tokenizer.model_max_length,
+ padding=True,
+ truncation=True,
+ return_tensors="pt",
+ ).input_ids
+ )[0]
+ ttokens = (
+ self.tokenizer(
+ target,
+ max_length=self.tokenizer.model_max_length,
+ padding=True,
+ truncation=True,
+ return_tensors="pt",
+ ).input_ids
+ )[0]
+
+ tlist = []
+
+ for t in range(ttokens.shape[0] - 2):
+ for p in range(ptokens.shape[0]):
+ if ttokens[t + 1] == ptokens[p]:
+ tlist.append(p)
+ if tlist != []:
+ tt.append(tlist)
+
+ return tt
+
+
+def scaled_dot_product_attention(
+ self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False
+) -> torch.Tensor:
+ # Efficient implementation equivalent to the following:
+ L, S = query.size(-2), key.size(-2)
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+ attn_bias = torch.zeros(L, S, dtype=query.dtype, device=self.device)
+ if is_causal:
+ assert attn_mask is None
+ temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+ attn_bias.to(query.dtype)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
+ else:
+ attn_bias += attn_mask
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
+ attn_weight += attn_bias
+ attn_weight = torch.softmax(attn_weight, dim=-1)
+ if getattn:
+ get_attn_maps(self, attn_weight)
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+ return attn_weight @ value
diff --git a/examples/community/stable_diffusion_tensorrt_img2img.py b/examples/community/stable_diffusion_tensorrt_img2img.py
index 041cf3a12dbd..507177791f5e 100755
--- a/examples/community/stable_diffusion_tensorrt_img2img.py
+++ b/examples/community/stable_diffusion_tensorrt_img2img.py
@@ -41,7 +41,7 @@
save_engine,
)
from polygraphy.backend.trt import util as trt_util
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
@@ -709,6 +709,7 @@ def __init__(
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
+ image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
stages=["clip", "unet", "vae", "vae_encoder"],
image_height: int = 512,
@@ -724,7 +725,15 @@ def __init__(
timing_cache: str = "timing_cache",
):
super().__init__(
- vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ requires_safety_checker=requires_safety_checker,
)
self.vae.forward = self.vae.decode
diff --git a/examples/community/stable_diffusion_tensorrt_inpaint.py b/examples/community/stable_diffusion_tensorrt_inpaint.py
index 71fa1b0a5f11..b4e16c76159c 100755
--- a/examples/community/stable_diffusion_tensorrt_inpaint.py
+++ b/examples/community/stable_diffusion_tensorrt_inpaint.py
@@ -41,7 +41,7 @@
save_engine,
)
from polygraphy.backend.trt import util as trt_util
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
@@ -710,6 +710,7 @@ def __init__(
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
+ image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
stages=["clip", "unet", "vae", "vae_encoder"],
image_height: int = 512,
@@ -725,7 +726,15 @@ def __init__(
timing_cache: str = "timing_cache",
):
super().__init__(
- vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ requires_safety_checker=requires_safety_checker,
)
self.vae.forward = self.vae.decode
diff --git a/examples/community/stable_diffusion_tensorrt_txt2img.py b/examples/community/stable_diffusion_tensorrt_txt2img.py
index b51f3176b958..c38261463384 100755
--- a/examples/community/stable_diffusion_tensorrt_txt2img.py
+++ b/examples/community/stable_diffusion_tensorrt_txt2img.py
@@ -40,7 +40,7 @@
save_engine,
)
from polygraphy.backend.trt import util as trt_util
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
@@ -624,6 +624,7 @@ def __init__(
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
+ image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
stages=["clip", "unet", "vae"],
image_height: int = 768,
@@ -639,7 +640,15 @@ def __init__(
timing_cache: str = "timing_cache",
):
super().__init__(
- vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ requires_safety_checker=requires_safety_checker,
)
self.vae.forward = self.vae.decode
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
index 6fa8d2c57832..00fd1910a657 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -71,7 +71,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.18.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
index 25faedf714b9..f63333696861 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
@@ -72,7 +72,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.18.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
index 4c4ad984fc31..d1eda7776223 100644
--- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
@@ -70,7 +70,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.18.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
index 920950d0f6e6..884b2755942a 100644
--- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -71,7 +71,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.18.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py
index 8dee7c33eac6..2ba6565abf81 100644
--- a/examples/controlnet/train_controlnet.py
+++ b/examples/controlnet/train_controlnet.py
@@ -56,7 +56,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py
index b658f689358d..b3c09325fc4d 100644
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -59,7 +59,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py
index 41a29c3945ab..70d9c52eefea 100644
--- a/examples/controlnet/train_controlnet_sdxl.py
+++ b/examples/controlnet/train_controlnet_sdxl.py
@@ -58,7 +58,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py
index c619a46dd99d..623ce3704af6 100644
--- a/examples/custom_diffusion/train_custom_diffusion.py
+++ b/examples/custom_diffusion/train_custom_diffusion.py
@@ -62,7 +62,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index c3f19efbdc38..c37ee87d8899 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -61,7 +61,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py
index 680c9dffdfcb..6244fa1c1a9c 100644
--- a/examples/dreambooth/train_dreambooth_flax.py
+++ b/examples/dreambooth/train_dreambooth_flax.py
@@ -35,7 +35,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 3ba775b543d8..b2905a05cef3 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -65,7 +65,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index bbe8dab731e9..98b5f3bc5694 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -58,7 +58,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py
index 2766e4c99086..ae0f216b97cd 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py
@@ -52,7 +52,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
index 9b57b5eb08f9..e764aa63bd4d 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -55,7 +55,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
index bc0a64b42e4b..a4017c85e1b5 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
@@ -52,7 +52,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
index 2f968aa8b8b3..90cf540c6425 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
@@ -46,7 +46,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
index 317e4178c04c..b64986ecf5ae 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
@@ -46,7 +46,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
index 0e6d06074012..a6855abcee75 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
@@ -51,7 +51,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
index f8e58bdb80fa..b1c554a7a2ed 100644
--- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py
+++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
@@ -58,7 +58,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 9a5482054939..c5371c95469e 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -53,7 +53,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py
index aad29d1f565c..c0d2f639883c 100644
--- a/examples/text_to_image/train_text_to_image_flax.py
+++ b/examples/text_to_image/train_text_to_image_flax.py
@@ -33,7 +33,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 7d731c994bdd..c030c59693c3 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -48,7 +48,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index b69a85e4f463..f1306ea5b9de 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -58,7 +58,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index ee15e6f7def6..5c024f4080ae 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -57,7 +57,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 7fea4fdb6440..50bcc992064d 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -79,7 +79,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py
index 5de1a8d7c325..bce70223eb3b 100644
--- a/examples/textual_inversion/textual_inversion_flax.py
+++ b/examples/textual_inversion/textual_inversion_flax.py
@@ -56,7 +56,7 @@
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 6e552c9b3dde..a721c605fdea 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -29,7 +29,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
index bca018d8df23..1e67f05abe7a 100644
--- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
@@ -50,7 +50,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
index 62450679f201..1ae5092ad10c 100644
--- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
+++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
@@ -51,7 +51,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
+check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/setup.py b/setup.py
index 1c645d769a5c..ddb2afd64c51 100644
--- a/setup.py
+++ b/setup.py
@@ -118,9 +118,10 @@
"pytest-timeout",
"pytest-xdist",
"python>=3.8.0",
- "ruff>=0.1.5,<=0.2",
+ "ruff==0.1.5",
"safetensors>=0.3.1",
"sentencepiece>=0.1.91,!=0.1.92",
+ "GitPython<3.1.19",
"scipy",
"onnx",
"regex!=2019.12.17",
@@ -206,6 +207,7 @@ def run(self):
extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2")
extras["test"] = deps_list(
"compel",
+ "GitPython",
"datasets",
"Jinja2",
"invisible-watermark",
@@ -249,13 +251,13 @@ def run(self):
setup(
name="diffusers",
- version="0.24.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+ version="0.25.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords="deep learning diffusion jax pytorch stable diffusion audioldm",
- license="Apache",
- author="The HuggingFace team",
+ license="Apache 2.0 License",
+ author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/diffusers/graphs/contributors)",
author_email="patrick@huggingface.co",
url="https://github.com/huggingface/diffusers",
package_dir={"": "src"},
@@ -279,24 +281,3 @@ def run(self):
+ [f"Programming Language :: Python :: 3.{i}" for i in range(8, version_range_max)],
cmdclass={"deps_table_update": DepsTableUpdateCommand},
)
-
-
-# Release checklist
-# 1. Change the version in __init__.py and setup.py.
-# 2. Commit these changes with the message: "Release: Release"
-# 3. Add a tag in git to mark the release: "git tag RELEASE -m 'Adds tag RELEASE for PyPI'"
-# Push the tag to git: git push --tags origin main
-# 4. Run the following commands in the top-level directory:
-# python setup.py bdist_wheel
-# python setup.py sdist
-# 5. Upload the package to the PyPI test server first:
-# twine upload dist/* -r pypitest
-# twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
-# 6. Check that you can install it in a virtualenv by running:
-# pip install -i https://testpypi.python.org/pypi diffusers
-# diffusers env
-# diffusers test
-# 7. Upload the final version to the actual PyPI:
-# twine upload dist/* -r pypi
-# 8. Add release notes to the tag in GitHub once everything is looking hunky-dory.
-# 9. Update the version in __init__.py, setup.py to the new version "-dev" and push to main.
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 574082c30362..c262ea42a6c5 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.24.0.dev0"
+__version__ = "0.25.0.dev0"
from typing import TYPE_CHECKING
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index 7ec2e2cf6d5c..7891984b0c5d 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -30,9 +30,10 @@
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.8.0",
- "ruff": "ruff>=0.1.5,<=0.2",
+ "ruff": "ruff==0.1.5",
"safetensors": "safetensors>=0.3.1",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
+ "GitPython": "GitPython<3.1.19",
"scipy": "scipy",
"onnx": "onnx",
"regex": "regex!=2019.12.17",
diff --git a/src/diffusers/experimental/rl/value_guided_sampling.py b/src/diffusers/experimental/rl/value_guided_sampling.py
index dfb27587d7d5..f46d3ac98b17 100644
--- a/src/diffusers/experimental/rl/value_guided_sampling.py
+++ b/src/diffusers/experimental/rl/value_guided_sampling.py
@@ -113,7 +113,7 @@ def run_diffusion(self, x, conditions, n_guide_steps, scale):
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
# TODO: verify deprecation of this kwarg
- x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
+ x = self.scheduler.step(prev_x, i, x)["prev_sample"]
# apply conditions to the trajectory (set the initial state)
x = self.reset_x0(x, conditions, self.action_dim)
diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index 3eb7569967a7..dde717959f8e 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -391,6 +391,10 @@ def load_lora_into_unet(
# their prefixes.
keys = list(state_dict.keys())
+ if all(key.startswith("unet.unet") for key in keys):
+ deprecation_message = "Keys starting with 'unet.unet' are deprecated."
+ deprecate("unet.unet keys", "0.27", deprecation_message)
+
if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
# Load the layers corresponding to UNet.
logger.info(f"Loading {cls.unet_name}.")
@@ -407,8 +411,9 @@ def load_lora_into_unet(
else:
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
# contain the module names of the `unet` as its keys WITHOUT any prefix.
- warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
- logger.warn(warn_message)
+ if not USE_PEFT_BACKEND:
+ warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
+ logger.warn(warn_message)
if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
@@ -675,8 +680,7 @@ def _remove_text_encoder_monkey_patch(self):
@classmethod
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
- if version.parse(__version__) > version.parse("0.23"):
- deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.25", LORA_DEPRECATION_MESSAGE)
+ deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.27", LORA_DEPRECATION_MESSAGE)
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
@@ -704,8 +708,7 @@ def _modify_text_encoder(
r"""
Monkey-patches the forward passes of attention modules of the text encoder.
"""
- if version.parse(__version__) > version.parse("0.23"):
- deprecate("_modify_text_encoder", "0.25", LORA_DEPRECATION_MESSAGE)
+ deprecate("_modify_text_encoder", "0.27", LORA_DEPRECATION_MESSAGE)
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
@@ -802,29 +805,21 @@ def save_lora_weights(
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
- # Create a flat dictionary.
state_dict = {}
- # Populate the dictionary.
- if unet_lora_layers is not None:
- weights = (
- unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
- )
+ def pack_weights(layers, prefix):
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+ return layers_state_dict
- unet_lora_state_dict = {f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()}
- state_dict.update(unet_lora_state_dict)
+ if not (unet_lora_layers or text_encoder_lora_layers):
+ raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
- if text_encoder_lora_layers is not None:
- weights = (
- text_encoder_lora_layers.state_dict()
- if isinstance(text_encoder_lora_layers, torch.nn.Module)
- else text_encoder_lora_layers
- )
+ if unet_lora_layers:
+ state_dict.update(pack_weights(unet_lora_layers, "unet"))
- text_encoder_lora_state_dict = {
- f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
- }
- state_dict.update(text_encoder_lora_state_dict)
+ if text_encoder_lora_layers:
+ state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
# Save the model
cls.write_lora_layers(
@@ -948,8 +943,7 @@ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
module.merge()
else:
- if version.parse(__version__) > version.parse("0.23"):
- deprecate("fuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
+ deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
for _, attn_module in text_encoder_attn_modules(text_encoder):
@@ -1006,8 +1000,7 @@ def unfuse_text_encoder_lora(text_encoder):
module.unmerge()
else:
- if version.parse(__version__) > version.parse("0.23"):
- deprecate("unfuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
+ deprecate("unfuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
def unfuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py
index 8c63c4cf59a5..bf100e7f2c81 100644
--- a/src/diffusers/loaders/single_file.py
+++ b/src/diffusers/loaders/single_file.py
@@ -282,7 +282,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
)
if torch_dtype is not None:
- pipe.to(torch_dtype=torch_dtype)
+ pipe.to(dtype=torch_dtype)
return pipe
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 6c805672c9cd..9d559a4b4af8 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
-from collections import defaultdict
+from collections import OrderedDict, defaultdict
from contextlib import nullcontext
from typing import Callable, Dict, List, Optional, Union
@@ -21,7 +21,7 @@
import torch.nn.functional as F
from torch import nn
-from ..models.embeddings import ImageProjection
+from ..models.embeddings import ImageProjection, Resampler
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
DIFFUSERS_CACHE,
@@ -672,6 +672,17 @@ def _load_ip_adapter_weights(self, state_dict):
IPAdapterAttnProcessor2_0,
)
+ if "proj.weight" in state_dict["image_proj"]:
+ # IP-Adapter
+ num_image_text_embeds = 4
+ else:
+ # IP-Adapter Plus
+ num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
+
+ # Set encoder_hid_proj after loading ip_adapter weights,
+ # because `Resampler` also has `attn_processors`.
+ self.encoder_hid_proj = None
+
# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
key_id = 1
@@ -695,7 +706,10 @@ def _load_ip_adapter_weights(self, state_dict):
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
attn_procs[name] = attn_processor_class(
- hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ scale=1.0,
+ num_tokens=num_image_text_embeds,
).to(dtype=self.dtype, device=self.device)
value_dict = {}
@@ -708,26 +722,76 @@ def _load_ip_adapter_weights(self, state_dict):
self.set_attn_processor(attn_procs)
# create image projection layers.
- clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
- cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
+ if "proj.weight" in state_dict["image_proj"]:
+ # IP-Adapter
+ clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
+ cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
+
+ image_projection = ImageProjection(
+ cross_attention_dim=cross_attention_dim,
+ image_embed_dim=clip_embeddings_dim,
+ num_image_text_embeds=num_image_text_embeds,
+ )
+ image_projection.to(dtype=self.dtype, device=self.device)
- image_projection = ImageProjection(
- cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
- )
- image_projection.to(dtype=self.dtype, device=self.device)
-
- # load image projection layer weights
- image_proj_state_dict = {}
- image_proj_state_dict.update(
- {
- "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
- "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
- "norm.weight": state_dict["image_proj"]["norm.weight"],
- "norm.bias": state_dict["image_proj"]["norm.bias"],
- }
- )
+ # load image projection layer weights
+ image_proj_state_dict = {}
+ image_proj_state_dict.update(
+ {
+ "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
+ "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
+ "norm.weight": state_dict["image_proj"]["norm.weight"],
+ "norm.bias": state_dict["image_proj"]["norm.bias"],
+ }
+ )
+
+ image_projection.load_state_dict(image_proj_state_dict)
+
+ else:
+ # IP-Adapter Plus
+ embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1]
+ output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0]
+ hidden_dims = state_dict["image_proj"]["latents"].shape[2]
+ heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64
+
+ image_projection = Resampler(
+ embed_dims=embed_dims,
+ output_dims=output_dims,
+ hidden_dims=hidden_dims,
+ heads=heads,
+ num_queries=num_image_text_embeds,
+ )
+
+ image_proj_state_dict = state_dict["image_proj"]
+
+ new_sd = OrderedDict()
+ for k, v in image_proj_state_dict.items():
+ if "0.to" in k:
+ k = k.replace("0.to", "2.to")
+ elif "1.0.weight" in k:
+ k = k.replace("1.0.weight", "3.0.weight")
+ elif "1.0.bias" in k:
+ k = k.replace("1.0.bias", "3.0.bias")
+ elif "1.1.weight" in k:
+ k = k.replace("1.1.weight", "3.1.net.0.proj.weight")
+ elif "1.3.weight" in k:
+ k = k.replace("1.3.weight", "3.1.net.2.weight")
+
+ if "norm1" in k:
+ new_sd[k.replace("0.norm1", "0")] = v
+ elif "norm2" in k:
+ new_sd[k.replace("0.norm2", "1")] = v
+ elif "to_kv" in k:
+ v_chunk = v.chunk(2, dim=0)
+ new_sd[k.replace("to_kv", "to_k")] = v_chunk[0]
+ new_sd[k.replace("to_kv", "to_v")] = v_chunk[1]
+ elif "to_out" in k:
+ new_sd[k.replace("to_out", "to_out.0")] = v
+ else:
+ new_sd[k] = v
- image_projection.load_state_dict(image_proj_state_dict)
+ image_projection.load_state_dict(new_sd)
+ del image_proj_state_dict
self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
self.config.encoder_hid_dim_type = "ip_image_proj"
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 839045001bb0..49ee3ee6af6b 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -34,6 +34,7 @@
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["modeling_utils"] = ["ModelMixin"]
+ _import_structure["embeddings"] = ["ImageProjection"]
_import_structure["prior_transformer"] = ["PriorTransformer"]
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformer_2d"] = ["Transformer2DModel"]
@@ -42,7 +43,7 @@
_import_structure["unet_2d"] = ["UNet2DModel"]
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
- _import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
+ _import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["vq_model"] = ["VQModel"]
@@ -63,6 +64,7 @@
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .controlnet import ControlNetModel
from .dual_transformer_2d import DualTransformer2DModel
+ from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .prior_transformer import PriorTransformer
from .t5_film_transformer import T5FilmDecoder
@@ -72,7 +74,7 @@
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
- from .unet_kandi3 import Kandinsky3UNet
+ from .unet_kandinsky3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .vq_model import VQModel
diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py
index 8b75162ba597..47570eca8443 100644
--- a/src/diffusers/models/activations.py
+++ b/src/diffusers/models/activations.py
@@ -55,11 +55,12 @@ class GELU(nn.Module):
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
- def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
super().__init__()
- self.proj = nn.Linear(dim_in, dim_out)
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
self.approximate = approximate
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
@@ -81,13 +82,14 @@ class GEGLU(nn.Module):
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
- def __init__(self, dim_in: int, dim_out: int):
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
super().__init__()
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
- self.proj = linear_cls(dim_in, dim_out * 2)
+ self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
@@ -109,11 +111,12 @@ class ApproximateGELU(nn.Module):
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
- def __init__(self, dim_in: int, dim_out: int):
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
super().__init__()
- self.proj = nn.Linear(dim_in, dim_out)
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index f02b5e249eee..08faaaf3e5bf 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -501,6 +501,7 @@ class FeedForward(nn.Module):
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(
@@ -511,6 +512,7 @@ def __init__(
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
+ bias: bool = True,
):
super().__init__()
inner_dim = int(dim * mult)
@@ -518,13 +520,13 @@ def __init__(
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
if activation_fn == "gelu":
- act_fn = GELU(dim, inner_dim)
+ act_fn = GELU(dim, inner_dim, bias=bias)
if activation_fn == "gelu-approximate":
- act_fn = GELU(dim, inner_dim, approximate="tanh")
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
elif activation_fn == "geglu":
- act_fn = GEGLU(dim, inner_dim)
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
elif activation_fn == "geglu-approximate":
- act_fn = ApproximateGELU(dim, inner_dim)
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
self.net = nn.ModuleList([])
# project in
@@ -532,7 +534,7 @@ def __init__(
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
- self.net.append(linear_cls(inner_dim, dim_out))
+ self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 21eb3a32dc09..40a335527ace 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -16,7 +16,7 @@
import torch
import torch.nn.functional as F
-from torch import einsum, nn
+from torch import nn
from ..utils import USE_PEFT_BACKEND, deprecate, logging
from ..utils.import_utils import is_xformers_available
@@ -109,15 +109,17 @@ def __init__(
residual_connection: bool = False,
_from_deprecated_attn_block: bool = False,
processor: Optional["AttnProcessor"] = None,
+ out_dim: int = None,
):
super().__init__()
- self.inner_dim = dim_head * heads
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
+ self.out_dim = out_dim if out_dim is not None else query_dim
# we make use of this private variable to know whether this class is loaded
# with an deprecated state dict so that we can convert it on the fly
@@ -126,7 +128,7 @@ def __init__(
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
- self.heads = heads
+ self.heads = out_dim // dim_head if out_dim is not None else heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
@@ -193,7 +195,7 @@ def __init__(
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
self.to_out = nn.ModuleList([])
- self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
+ self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
# set attention processor
@@ -2219,44 +2221,6 @@ def __call__(
return hidden_states
-# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe
-# this way torch.compile and co. will work as well
-class Kandi3AttnProcessor:
- r"""
- Default kandinsky3 proccesor for performing attention-related computations.
- """
-
- @staticmethod
- def _reshape(hid_states, h):
- b, n, f = hid_states.shape
- d = f // h
- return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3)
-
- def __call__(
- self,
- attn,
- x,
- context,
- context_mask=None,
- ):
- query = self._reshape(attn.to_q(x), h=attn.num_heads)
- key = self._reshape(attn.to_k(context), h=attn.num_heads)
- value = self._reshape(attn.to_v(context), h=attn.num_heads)
-
- attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key)
-
- if context_mask is not None:
- max_neg_value = -torch.finfo(attention_matrix.dtype).max
- context_mask = context_mask.unsqueeze(1).unsqueeze(1)
- attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
- attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)
-
- out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
- out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)
- out = attn.to_out[0](out)
- return out
-
-
LORA_ATTENTION_PROCESSORS = (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
@@ -2282,7 +2246,6 @@ def __call__(
LoRAXFormersAttnProcessor,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
- Kandi3AttnProcessor,
)
AttentionProcessor = Union[
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index a377ae267411..bdd2930d20f9 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -20,6 +20,7 @@
from ..utils import USE_PEFT_BACKEND
from .activations import get_activation
+from .attention_processor import Attention
from .lora import LoRACompatibleLinear
@@ -790,3 +791,91 @@ def forward(self, caption, force_drop_ids=None):
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
+
+
+class Resampler(nn.Module):
+ """Resampler of IP-Adapter Plus.
+
+ Args:
+ ----
+ embed_dims (int): The feature dimension. Defaults to 768.
+ output_dims (int): The number of output channels, that is the same
+ number of the channels in the
+ `unet.config.cross_attention_dim`. Defaults to 1024.
+ hidden_dims (int): The number of hidden channels. Defaults to 1280.
+ depth (int): The number of blocks. Defaults to 8.
+ dim_head (int): The number of head channels. Defaults to 64.
+ heads (int): Parallel attention heads. Defaults to 16.
+ num_queries (int): The number of queries. Defaults to 8.
+ ffn_ratio (float): The expansion ratio of feedforward network hidden
+ layer channels. Defaults to 4.
+ """
+
+ def __init__(
+ self,
+ embed_dims: int = 768,
+ output_dims: int = 1024,
+ hidden_dims: int = 1280,
+ depth: int = 4,
+ dim_head: int = 64,
+ heads: int = 16,
+ num_queries: int = 8,
+ ffn_ratio: float = 4,
+ ) -> None:
+ super().__init__()
+ from .attention import FeedForward # Lazy import to avoid circular import
+
+ self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)
+
+ self.proj_in = nn.Linear(embed_dims, hidden_dims)
+
+ self.proj_out = nn.Linear(hidden_dims, output_dims)
+ self.norm_out = nn.LayerNorm(output_dims)
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ nn.LayerNorm(hidden_dims),
+ nn.LayerNorm(hidden_dims),
+ Attention(
+ query_dim=hidden_dims,
+ dim_head=dim_head,
+ heads=heads,
+ out_bias=False,
+ ),
+ nn.Sequential(
+ nn.LayerNorm(hidden_dims),
+ FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
+ ),
+ ]
+ )
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass.
+
+ Args:
+ ----
+ x (torch.Tensor): Input Tensor.
+
+ Returns:
+ -------
+ torch.Tensor: Output Tensor.
+ """
+ latents = self.latents.repeat(x.size(0), 1, 1)
+
+ x = self.proj_in(x)
+
+ for ln0, ln1, attn, ff in self.layers:
+ residual = latents
+
+ encoder_hidden_states = ln0(x)
+ latents = ln1(latents)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
+ latents = attn(latents, encoder_hidden_states) + residual
+ latents = ff(latents) + latents
+
+ latents = self.proj_out(latents)
+ return self.norm_out(latents)
diff --git a/src/diffusers/models/unet_kandi3.py b/src/diffusers/models/unet_kandinsky3.py
similarity index 69%
rename from src/diffusers/models/unet_kandi3.py
rename to src/diffusers/models/unet_kandinsky3.py
index 42e25a942f7d..eef3287e5d99 100644
--- a/src/diffusers/models/unet_kandi3.py
+++ b/src/diffusers/models/unet_kandinsky3.py
@@ -1,16 +1,28 @@
-import math
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass
from typing import Dict, Tuple, Union
import torch
-import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
-from .attention_processor import AttentionProcessor, Kandi3AttnProcessor
-from .embeddings import TimestepEmbedding
+from .attention_processor import Attention, AttentionProcessor, AttnProcessor
+from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
@@ -22,36 +34,6 @@ class Kandinsky3UNetOutput(BaseOutput):
sample: torch.FloatTensor = None
-# TODO(Yiyi): This class needs to be removed
-def set_default_item(condition, item_1, item_2=None):
- if condition:
- return item_1
- else:
- return item_2
-
-
-# TODO(Yiyi): This class needs to be removed
-def set_default_layer(condition, layer_1, args_1=[], kwargs_1={}, layer_2=torch.nn.Identity, args_2=[], kwargs_2={}):
- if condition:
- return layer_1(*args_1, **kwargs_1)
- else:
- return layer_2(*args_2, **kwargs_2)
-
-
-# TODO(Yiyi): This class should be removed and be replaced by Timesteps
-class SinusoidalPosEmb(nn.Module):
- def __init__(self, dim):
- super().__init__()
- self.dim = dim
-
- def forward(self, x, type_tensor=None):
- half_dim = self.dim // 2
- emb = math.log(10000) / (half_dim - 1)
- emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
- emb = x[:, None] * emb[None, :]
- return torch.cat((emb.sin(), emb.cos()), dim=-1)
-
-
class Kandinsky3EncoderProj(nn.Module):
def __init__(self, encoder_hid_dim, cross_attention_dim):
super().__init__()
@@ -87,9 +69,7 @@ def __init__(
out_channels = in_channels
init_channels = block_out_channels[0] // 2
- # TODO(Yiyi): Should be replaced with Timesteps class -> make sure that results are the same
- # self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1)
- self.time_proj = SinusoidalPosEmb(init_channels)
+ self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1)
self.time_embedding = TimestepEmbedding(
init_channels,
@@ -106,7 +86,7 @@ def __init__(
hidden_dims = [init_channels] + list(block_out_channels)
in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
- text_dims = [set_default_item(is_exist, cross_attention_dim) for is_exist in add_cross_attention]
+ text_dims = [cross_attention_dim if is_exist else None for is_exist in add_cross_attention]
num_blocks = len(block_out_channels) * [layers_per_block]
layer_params = [num_blocks, text_dims, add_self_attention]
rev_layer_params = map(reversed, layer_params)
@@ -118,7 +98,7 @@ def __init__(
zip(in_out_dims, *layer_params)
):
down_sample = level != (self.num_levels - 1)
- cat_dims.append(set_default_item(level != (self.num_levels - 1), out_dim, 0))
+ cat_dims.append(out_dim if level != (self.num_levels - 1) else 0)
self.down_blocks.append(
Kandinsky3DownSampleBlock(
in_dim,
@@ -223,18 +203,16 @@ def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
- self.set_attn_processor(Kandi3AttnProcessor())
+ self.set_attn_processor(AttnProcessor())
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
- # TODO(Yiyi): Clean up the following variables - these names should not be used
- # but instead only the ones that we pass to forward
- x = sample
- context_mask = encoder_attention_mask
- context = encoder_hidden_states
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
if not torch.is_tensor(timestep):
dtype = torch.float32 if isinstance(timestep, float) else torch.int32
@@ -244,33 +222,33 @@ def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attentio
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = timestep.expand(sample.shape[0])
- time_embed_input = self.time_proj(timestep).to(x.dtype)
+ time_embed_input = self.time_proj(timestep).to(sample.dtype)
time_embed = self.time_embedding(time_embed_input)
- context = self.encoder_hid_proj(context)
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
- if context is not None:
- time_embed = self.add_time_condition(time_embed, context, context_mask)
+ if encoder_hidden_states is not None:
+ time_embed = self.add_time_condition(time_embed, encoder_hidden_states, encoder_attention_mask)
hidden_states = []
- x = self.conv_in(x)
+ sample = self.conv_in(sample)
for level, down_sample in enumerate(self.down_blocks):
- x = down_sample(x, time_embed, context, context_mask)
+ sample = down_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
if level != self.num_levels - 1:
- hidden_states.append(x)
+ hidden_states.append(sample)
for level, up_sample in enumerate(self.up_blocks):
if level != 0:
- x = torch.cat([x, hidden_states.pop()], dim=1)
- x = up_sample(x, time_embed, context, context_mask)
+ sample = torch.cat([sample, hidden_states.pop()], dim=1)
+ sample = up_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
- x = self.conv_norm_out(x)
- x = self.conv_act_out(x)
- x = self.conv_out(x)
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act_out(sample)
+ sample = self.conv_out(sample)
if not return_dict:
- return (x,)
- return Kandinsky3UNetOutput(sample=x)
+ return (sample,)
+ return Kandinsky3UNetOutput(sample=sample)
class Kandinsky3UpSampleBlock(nn.Module):
@@ -290,7 +268,7 @@ def __init__(
self_attention=True,
):
super().__init__()
- up_resolutions = [[None, set_default_item(up_sample, True), None, None]] + [[None] * 4] * (num_blocks - 1)
+ up_resolutions = [[None, True if up_sample else None, None, None]] + [[None] * 4] * (num_blocks - 1)
hidden_channels = (
[(in_channels + cat_dim, in_channels)]
+ [(in_channels, in_channels)] * (num_blocks - 2)
@@ -303,27 +281,27 @@ def __init__(
self.self_attention = self_attention
self.context_dim = context_dim
- attentions.append(
- set_default_layer(
- self_attention,
- Kandinsky3AttentionBlock,
- (out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
- layer_2=nn.Identity,
+ if self_attention:
+ attentions.append(
+ Kandinsky3AttentionBlock(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
)
- )
+ else:
+ attentions.append(nn.Identity())
for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
resnets_in.append(
Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution)
)
- attentions.append(
- set_default_layer(
- context_dim is not None,
- Kandinsky3AttentionBlock,
- (in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
- layer_2=nn.Identity,
+
+ if context_dim is not None:
+ attentions.append(
+ Kandinsky3AttentionBlock(
+ in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
+ )
)
- )
+ else:
+ attentions.append(nn.Identity())
+
resnets_out.append(
Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
)
@@ -367,29 +345,29 @@ def __init__(
self.self_attention = self_attention
self.context_dim = context_dim
- attentions.append(
- set_default_layer(
- self_attention,
- Kandinsky3AttentionBlock,
- (in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
- layer_2=nn.Identity,
+ if self_attention:
+ attentions.append(
+ Kandinsky3AttentionBlock(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
)
- )
+ else:
+ attentions.append(nn.Identity())
- up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, set_default_item(down_sample, False), None]]
+ up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, False if down_sample else None, None]]
hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1)
for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
resnets_in.append(
Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
)
- attentions.append(
- set_default_layer(
- context_dim is not None,
- Kandinsky3AttentionBlock,
- (out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
- layer_2=nn.Identity,
+
+ if context_dim is not None:
+ attentions.append(
+ Kandinsky3AttentionBlock(
+ out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
+ )
)
- )
+ else:
+ attentions.append(nn.Identity())
+
resnets_out.append(
Kandinsky3ResNetBlock(
out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution
@@ -431,68 +409,23 @@ def forward(self, x, context):
return x
-# TODO(Yiyi): This class should ideally not even exist, it slows everything needlessly down. I'm pretty
-# sure we can delete it and instead just pass an attention_mask
-class Attention(nn.Module):
- def __init__(self, in_channels, out_channels, context_dim, head_dim=64):
- super().__init__()
- assert out_channels % head_dim == 0
- self.num_heads = out_channels // head_dim
- self.scale = head_dim**-0.5
-
- # to_q
- self.to_q = nn.Linear(in_channels, out_channels, bias=False)
- # to_k
- self.to_k = nn.Linear(context_dim, out_channels, bias=False)
- # to_v
- self.to_v = nn.Linear(context_dim, out_channels, bias=False)
- processor = Kandi3AttnProcessor()
- self.set_processor(processor)
- # to_out
- self.to_out = nn.ModuleList([])
- self.to_out.append(nn.Linear(out_channels, out_channels, bias=False))
-
- def set_processor(self, processor: "AttnProcessor"): # noqa: F821
- # if current processor is in `self._modules` and if passed `processor` is not, we need to
- # pop `processor` from `self._modules`
- if (
- hasattr(self, "processor")
- and isinstance(self.processor, torch.nn.Module)
- and not isinstance(processor, torch.nn.Module)
- ):
- logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
- self._modules.pop("processor")
-
- self.processor = processor
-
- def forward(self, x, context, context_mask=None, image_mask=None):
- return self.processor(
- self,
- x,
- context=context,
- context_mask=context_mask,
- )
-
-
class Kandinsky3Block(nn.Module):
def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
super().__init__()
self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
self.activation = nn.SiLU()
- self.up_sample = set_default_layer(
- up_resolution is not None and up_resolution,
- nn.ConvTranspose2d,
- (in_channels, in_channels),
- {"kernel_size": 2, "stride": 2},
- )
+ if up_resolution is not None and up_resolution:
+ self.up_sample = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
+ else:
+ self.up_sample = nn.Identity()
+
padding = int(kernel_size > 1)
self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
- self.down_sample = set_default_layer(
- up_resolution is not None and not up_resolution,
- nn.Conv2d,
- (out_channels, out_channels),
- {"kernel_size": 2, "stride": 2},
- )
+
+ if up_resolution is not None and not up_resolution:
+ self.down_sample = nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
+ else:
+ self.down_sample = nn.Identity()
def forward(self, x, time_embed):
x = self.group_norm(x, time_embed)
@@ -521,14 +454,18 @@ def __init__(
)
]
)
- self.shortcut_up_sample = set_default_layer(
- True in up_resolutions, nn.ConvTranspose2d, (in_channels, in_channels), {"kernel_size": 2, "stride": 2}
+ self.shortcut_up_sample = (
+ nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
+ if True in up_resolutions
+ else nn.Identity()
)
- self.shortcut_projection = set_default_layer(
- in_channels != out_channels, nn.Conv2d, (in_channels, out_channels), {"kernel_size": 1}
+ self.shortcut_projection = (
+ nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()
)
- self.shortcut_down_sample = set_default_layer(
- False in up_resolutions, nn.Conv2d, (out_channels, out_channels), {"kernel_size": 2, "stride": 2}
+ self.shortcut_down_sample = (
+ nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
+ if False in up_resolutions
+ else nn.Identity()
)
def forward(self, x, time_embed):
@@ -546,9 +483,16 @@ def forward(self, x, time_embed):
class Kandinsky3AttentionPooling(nn.Module):
def __init__(self, num_channels, context_dim, head_dim=64):
super().__init__()
- self.attention = Attention(context_dim, num_channels, context_dim, head_dim)
+ self.attention = Attention(
+ context_dim,
+ context_dim,
+ dim_head=head_dim,
+ out_dim=num_channels,
+ out_bias=False,
+ )
def forward(self, x, context, context_mask=None):
+ context_mask = context_mask.to(dtype=context.dtype)
context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask)
return x + context.squeeze(1)
@@ -557,7 +501,13 @@ class Kandinsky3AttentionBlock(nn.Module):
def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
super().__init__()
self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
- self.attention = Attention(num_channels, num_channels, context_dim or num_channels, head_dim)
+ self.attention = Attention(
+ num_channels,
+ context_dim or num_channels,
+ dim_head=head_dim,
+ out_dim=num_channels,
+ out_bias=False,
+ )
hidden_channels = expansion_ratio * num_channels
self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
@@ -572,14 +522,10 @@ def forward(self, x, time_embed, context=None, context_mask=None, image_mask=Non
out = self.in_norm(x, time_embed)
out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1)
context = context if context is not None else out
+ if context_mask is not None:
+ context_mask = context_mask.to(dtype=context.dtype)
- if image_mask is not None:
- mask_height, mask_width = image_mask.shape[-2:]
- kernel_size = (mask_height // height, mask_width // width)
- image_mask = F.max_pool2d(image_mask, kernel_size, kernel_size)
- image_mask = image_mask.reshape(image_mask.shape[0], -1)
-
- out = self.attention(out, context, context_mask, image_mask)
+ out = self.attention(out, context, context_mask)
out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width)
x = x + out
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
index b5c7aee4b4de..2121e9b81509 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -22,7 +22,7 @@
from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -494,18 +494,29 @@ def encode_prompt(
return prompt_embeds, negative_prompt_embeds
- def encode_image(self, image, device, num_images_per_prompt):
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
+ return image_embeds, uncond_image_embeds
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -875,7 +886,10 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ )
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
index 4272fa124755..401e6aef82b1 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -24,7 +24,7 @@
from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -505,18 +505,29 @@ def encode_prompt(
return prompt_embeds, negative_prompt_embeds
- def encode_image(self, image, device, num_images_per_prompt):
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
+ return image_embeds, uncond_image_embeds
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -919,7 +930,10 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ )
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
index 28dc220545dc..32a08a0264bc 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -22,7 +22,7 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unet_motion_model import MotionAdapter
from ...schedulers import (
@@ -320,18 +320,29 @@ def encode_prompt(
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt):
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
+ return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
@@ -651,7 +662,10 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt)
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
+ )
if do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index 1e19678b221d..bf6ef2125446 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -24,7 +24,7 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -479,18 +479,29 @@ def encode_prompt(
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt):
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
+ return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
@@ -1067,7 +1078,10 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ )
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 72c2250dd5ac..71e237ce4e02 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -25,7 +25,7 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -597,18 +597,29 @@ def encode_prompt(
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt):
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
+ return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
@@ -1284,7 +1295,10 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ )
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 4696781dce0c..8c8399809228 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -37,7 +37,7 @@
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
-from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
@@ -489,18 +489,29 @@ def encode_prompt(
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt):
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
+ return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
@@ -1169,7 +1180,10 @@ def __call__(
# 3.2 Encode ip_adapter_image
if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ )
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
diff --git a/src/diffusers/pipelines/kandinsky3/__init__.py b/src/diffusers/pipelines/kandinsky3/__init__.py
index 4da3a83c0448..e8a3063141b5 100644
--- a/src/diffusers/pipelines/kandinsky3/__init__.py
+++ b/src/diffusers/pipelines/kandinsky3/__init__.py
@@ -21,8 +21,8 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
- _import_structure["kandinsky3_pipeline"] = ["Kandinsky3Pipeline"]
- _import_structure["kandinsky3img2img_pipeline"] = ["Kandinsky3Img2ImgPipeline"]
+ _import_structure["pipeline_kandinsky3"] = ["Kandinsky3Pipeline"]
+ _import_structure["pipeline_kandinsky3_img2img"] = ["Kandinsky3Img2ImgPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -33,8 +33,8 @@
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
- from .kandinsky3_pipeline import Kandinsky3Pipeline
- from .kandinsky3img2img_pipeline import Kandinsky3Img2ImgPipeline
+ from .pipeline_kandinsky3 import Kandinsky3Pipeline
+ from .pipeline_kandinsky3_img2img import Kandinsky3Img2ImgPipeline
else:
import sys
diff --git a/tests/convert_kandinsky3_unet.py b/src/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py
similarity index 100%
rename from tests/convert_kandinsky3_unet.py
rename to src/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py
diff --git a/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
similarity index 70%
rename from src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py
rename to src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
index f116fb7894f0..fcf7ddcb9966 100644
--- a/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py
+++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
@@ -1,4 +1,4 @@
-from typing import Callable, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union
import torch
from transformers import T5EncoderModel, T5Tokenizer
@@ -7,8 +7,10 @@
from ...models import Kandinsky3UNet, VQModel
from ...schedulers import DDPMScheduler
from ...utils import (
+ deprecate,
is_accelerate_available,
logging,
+ replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -16,6 +18,23 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import AutoPipelineForText2Image
+ >>> import torch
+
+ >>> pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
+
+ >>> generator = torch.Generator(device="cpu").manual_seed(0)
+ >>> image = pipe(prompt, num_inference_steps=25, generator=generator).images[0]
+ ```
+
+"""
+
def downscale_height_and_width(height, width, scale_factor=8):
new_height = height // scale_factor**2
@@ -29,6 +48,13 @@ def downscale_height_and_width(height, width, scale_factor=8):
class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
model_cpu_offload_seq = "text_encoder->unet->movq"
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "negative_attention_mask",
+ "attention_mask",
+ ]
def __init__(
self,
@@ -50,7 +76,7 @@ def remove_all_hooks(self):
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
- for model in [self.text_encoder, self.unet]:
+ for model in [self.text_encoder, self.unet, self.movq]:
if model is not None:
remove_hook_from_module(model, recurse=True)
@@ -77,12 +103,14 @@ def encode_prompt(
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
_cut_context=False,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ negative_attention_mask: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
- prompt (`str` or `List[str]`, *optional*):
+ prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
@@ -101,6 +129,10 @@ def encode_prompt(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+ attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated attention mask. Must provide if passing `prompt_embeds` directly.
+ negative_attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly.
"""
if prompt is not None and negative_prompt is not None:
if type(prompt) is not type(negative_prompt):
@@ -228,14 +260,21 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ attention_mask=None,
+ negative_attention_mask=None,
):
- if (callback_steps is None) or (
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
- ):
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
@@ -262,8 +301,42 @@ def check_inputs(
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
+ if negative_prompt_embeds is not None and negative_attention_mask is None:
+ raise ValueError("Please provide `negative_attention_mask` along with `negative_prompt_embeds`")
+
+ if negative_prompt_embeds is not None and negative_attention_mask is not None:
+ if negative_prompt_embeds.shape[:2] != negative_attention_mask.shape:
+ raise ValueError(
+ "`negative_prompt_embeds` and `negative_attention_mask` must have the same batch_size and token length when passed directly, but"
+ f" got: `negative_prompt_embeds` {negative_prompt_embeds.shape[:2]} != `negative_attention_mask`"
+ f" {negative_attention_mask.shape}."
+ )
+
+ if prompt_embeds is not None and attention_mask is None:
+ raise ValueError("Please provide `attention_mask` along with `prompt_embeds`")
+
+ if prompt_embeds is not None and attention_mask is not None:
+ if prompt_embeds.shape[:2] != attention_mask.shape:
+ raise ValueError(
+ "`prompt_embeds` and `attention_mask` must have the same batch_size and token length when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape[:2]} != `attention_mask`"
+ f" {attention_mask.shape}."
+ )
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
@torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
@@ -276,11 +349,14 @@ def __call__(
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ negative_attention_mask: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
- callback_steps: int = 1,
latents=None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
):
"""
Function invoked when calling the pipeline for generation.
@@ -289,7 +365,7 @@ def __call__(
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
- num_inference_steps (`int`, *optional*, defaults to 50):
+ num_inference_steps (`int`, *optional*, defaults to 25):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
@@ -324,6 +400,10 @@ def __call__(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+ attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated attention mask. Must provide if passing `prompt_embeds` directly.
+ negative_attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -343,12 +423,53 @@ def __call__(
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`
+
"""
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
cut_context = True
device = self._execution_device
# 1. Check inputs. Raise error if not correct
- self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+ self.check_inputs(
+ prompt,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ attention_mask,
+ negative_attention_mask,
+ )
+
+ self._guidance_scale = guidance_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -357,24 +478,21 @@ def __call__(
else:
batch_size = prompt_embeds.shape[0]
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- do_classifier_free_guidance = guidance_scale > 1.0
-
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt(
prompt,
- do_classifier_free_guidance,
+ self.do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt,
device=device,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
_cut_context=cut_context,
+ attention_mask=attention_mask,
+ negative_attention_mask=negative_attention_mask,
)
- if do_classifier_free_guidance:
+ if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool()
# 4. Prepare timesteps
@@ -397,11 +515,11 @@ def __call__(
self.text_encoder_offload_hook.offload()
# 7. Denoising loop
- # TODO(Yiyi): Correct the following line and use correctly
- # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# predict the noise residual
noise_pred = self.unet(
@@ -412,7 +530,7 @@ def __call__(
return_dict=False,
)[0]
- if do_classifier_free_guidance:
+ if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond
@@ -425,26 +543,45 @@ def __call__(
latents,
generator=generator,
).prev_sample
- progress_bar.update()
- if callback is not None and i % callback_steps == 0:
- step_idx = i // getattr(self.scheduler, "order", 1)
- callback(step_idx, t, latents)
- # post-processing
- image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ attention_mask = callback_outputs.pop("attention_mask", attention_mask)
+ negative_attention_mask = callback_outputs.pop("negative_attention_mask", negative_attention_mask)
- if output_type not in ["pt", "np", "pil"]:
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # post-processing
+ if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
- f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}"
+ f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
)
- if output_type in ["np", "pil"]:
- image = image * 0.5 + 0.5
- image = image.clamp(0, 1)
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ if not output_type == "latent":
+ image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+
+ if output_type in ["np", "pil"]:
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+ else:
+ image = latents
- if output_type == "pil":
- image = self.numpy_to_pil(image)
+ self.maybe_free_model_hooks()
if not return_dict:
return (image,)
diff --git a/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
similarity index 59%
rename from src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py
rename to src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
index b043110cf1d7..7f4164a04d1e 100644
--- a/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py
+++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
@@ -1,5 +1,5 @@
import inspect
-from typing import Callable, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union
import numpy as np
import PIL
@@ -11,8 +11,10 @@
from ...models import Kandinsky3UNet, VQModel
from ...schedulers import DDPMScheduler
from ...utils import (
+ deprecate,
is_accelerate_available,
logging,
+ replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -20,6 +22,24 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import AutoPipelineForImage2Image
+ >>> from diffusers.utils import load_image
+ >>> import torch
+
+ >>> pipe = AutoPipelineForImage2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "A painting of the inside of a subway train with tiny raccoons."
+ >>> image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png")
+
+ >>> generator = torch.Generator(device="cpu").manual_seed(0)
+ >>> image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0]
+ ```
+"""
+
def downscale_height_and_width(height, width, scale_factor=8):
new_height = height // scale_factor**2
@@ -40,7 +60,14 @@ def prepare_image(pil_image):
class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
- model_cpu_offload_seq = "text_encoder->unet->movq"
+ model_cpu_offload_seq = "text_encoder->movq->unet->movq"
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "negative_attention_mask",
+ "attention_mask",
+ ]
def __init__(
self,
@@ -99,6 +126,8 @@ def encode_prompt(
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
_cut_context=False,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ negative_attention_mask: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -123,6 +152,10 @@ def encode_prompt(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+ attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated attention mask. Must provide if passing `prompt_embeds` directly.
+ negative_attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly.
"""
if prompt is not None and negative_prompt is not None:
if type(prompt) is not type(negative_prompt):
@@ -299,15 +332,23 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ attention_mask=None,
+ negative_attention_mask=None,
):
- if (callback_steps is None) or (
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
- ):
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -334,7 +375,42 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)
+ if negative_prompt_embeds is not None and negative_attention_mask is None:
+ raise ValueError("Please provide `negative_attention_mask` along with `negative_prompt_embeds`")
+
+ if negative_prompt_embeds is not None and negative_attention_mask is not None:
+ if negative_prompt_embeds.shape[:2] != negative_attention_mask.shape:
+ raise ValueError(
+ "`negative_prompt_embeds` and `negative_attention_mask` must have the same batch_size and token length when passed directly, but"
+ f" got: `negative_prompt_embeds` {negative_prompt_embeds.shape[:2]} != `negative_attention_mask`"
+ f" {negative_attention_mask.shape}."
+ )
+
+ if prompt_embeds is not None and attention_mask is None:
+ raise ValueError("Please provide `attention_mask` along with `prompt_embeds`")
+
+ if prompt_embeds is not None and attention_mask is not None:
+ if prompt_embeds.shape[:2] != attention_mask.shape:
+ raise ValueError(
+ "`prompt_embeds` and `attention_mask` must have the same batch_size and token length when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape[:2]} != `attention_mask`"
+ f" {attention_mask.shape}."
+ )
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
@torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
@@ -347,15 +423,117 @@ def __call__(
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ negative_attention_mask: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
- callback_steps: int = 1,
- latents=None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ strength (`float`, *optional*, defaults to 0.8):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 3.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated attention mask. Must provide if passing `prompt_embeds` directly.
+ negative_attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`
+
+ """
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
cut_context = True
# 1. Check inputs. Raise error if not correct
- self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+ self.check_inputs(
+ prompt,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ attention_mask,
+ negative_attention_mask,
+ )
+
+ self._guidance_scale = guidance_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -366,24 +544,21 @@ def __call__(
device = self._execution_device
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- do_classifier_free_guidance = guidance_scale > 1.0
-
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt(
prompt,
- do_classifier_free_guidance,
+ self.do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt,
device=device,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
_cut_context=cut_context,
+ attention_mask=attention_mask,
+ negative_attention_mask=negative_attention_mask,
)
- if do_classifier_free_guidance:
+ if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool()
if not isinstance(image, list):
@@ -409,11 +584,11 @@ def __call__(
self.text_encoder_offload_hook.offload()
# 7. Denoising loop
- # TODO(Yiyi): Correct the following line and use correctly
- # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# predict the noise residual
noise_pred = self.unet(
@@ -422,7 +597,7 @@ def __call__(
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=attention_mask,
)[0]
- if do_classifier_free_guidance:
+ if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond
@@ -434,25 +609,44 @@ def __call__(
latents,
generator=generator,
).prev_sample
- progress_bar.update()
- if callback is not None and i % callback_steps == 0:
- step_idx = i // getattr(self.scheduler, "order", 1)
- callback(step_idx, t, latents)
- # post-processing
- image = self.movq.decode(latents, force_not_quantize=True)["sample"]
- if output_type not in ["pt", "np", "pil"]:
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ attention_mask = callback_outputs.pop("attention_mask", attention_mask)
+ negative_attention_mask = callback_outputs.pop("negative_attention_mask", negative_attention_mask)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # post-processing
+ if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
- f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}"
+ f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
)
+ if not output_type == "latent":
+ image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+
+ if output_type in ["np", "pil"]:
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
- if output_type in ["np", "pil"]:
- image = image * 0.5 + 0.5
- image = image.clamp(0, 1)
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+ else:
+ image = latents
- if output_type == "pil":
- image = self.numpy_to_pil(image)
+ self.maybe_free_model_hooks()
if not return_dict:
return (image,)
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index b84344fab85e..7d889f3afa4c 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -758,10 +758,10 @@ def to(self, *args, **kwargs):
torch_dtype = kwargs.pop("torch_dtype", None)
if torch_dtype is not None:
- deprecate("torch_dtype", "0.25.0", "")
+ deprecate("torch_dtype", "0.27.0", "")
torch_device = kwargs.pop("torch_device", None)
if torch_device is not None:
- deprecate("torch_device", "0.25.0", "")
+ deprecate("torch_device", "0.27.0", "")
dtype_kwarg = kwargs.pop("dtype", None)
device_kwarg = kwargs.pop("device", None)
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
index 32e36aaddc53..090b66915dd0 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -134,6 +134,51 @@
}
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
class PixArtAlphaPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using PixArt-Alpha.
@@ -783,8 +828,7 @@ def __call__(
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare timesteps
- self.scheduler.set_timesteps(num_inference_steps, device=device)
- timesteps = self.scheduler.timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py
index 5706298a281a..d81831082e2f 100644
--- a/src/diffusers/pipelines/stable_diffusion/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -34,7 +34,6 @@
_import_structure["pipeline_stable_diffusion"] = ["StableDiffusionPipeline"]
_import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
_import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"]
- _import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"]
_import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"]
_import_structure["pipeline_stable_diffusion_img2img"] = ["StableDiffusionImg2ImgPipeline"]
_import_structure["pipeline_stable_diffusion_inpaint"] = ["StableDiffusionInpaintPipeline"]
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 35466f008f54..6960ba6c4516 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -446,7 +446,7 @@ def convert_ldm_unet_checkpoint(
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
# Relevant to StableDiffusionUpscalePipeline
- if "num_class_embeds" in config:
+ if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
@@ -1480,9 +1480,12 @@ def download_from_original_stable_diffusion_ckpt(
config_name = "stabilityai/stable-diffusion-2"
config_kwargs = {"subfolder": "text_encoder"}
- text_model = convert_open_clip_checkpoint(
- checkpoint, config_name, local_files_only=local_files_only, **config_kwargs
- )
+ if text_encoder is None:
+ text_model = convert_open_clip_checkpoint(
+ checkpoint, config_name, local_files_only=local_files_only, **config_kwargs
+ )
+ else:
+ text_model = text_encoder
try:
tokenizer = CLIPTokenizer.from_pretrained(
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 57875a413f92..f949990c7ab4 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -26,7 +26,7 @@
from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -494,18 +494,29 @@ def encode_prompt(
return prompt_embeds, negative_prompt_embeds
- def encode_image(self, image, device, num_images_per_prompt):
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
+ return image_embeds, uncond_image_embeds
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -876,7 +887,10 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ )
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index e3a1a0ed3660..c80178152a6e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -24,7 +24,7 @@
from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -503,18 +503,29 @@ def encode_prompt(
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt):
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
+ return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
@@ -923,7 +934,10 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ )
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 4ba2872b1d30..62139381a8e2 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -24,7 +24,7 @@
from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel
+from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
@@ -575,18 +575,29 @@ def encode_prompt(
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt):
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
+ return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
@@ -1104,7 +1115,10 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ )
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 40c981a46d48..12d52aa076d4 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -31,7 +31,7 @@
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
@@ -524,18 +524,29 @@ def encode_prompt(
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt):
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
+ return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
@@ -1087,7 +1098,10 @@ def __call__(
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ )
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
image_embeds = image_embeds.to(device)
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 436d816e5eb3..729924ec2e20 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -32,7 +32,7 @@
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
@@ -741,18 +741,29 @@ def prepare_latents(
return latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt):
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
+ return image_embeds, uncond_image_embeds
def _get_add_time_ids(
self,
@@ -1259,7 +1270,10 @@ def denoising_value_valid(dnv):
add_time_ids = add_time_ids.to(device)
if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ )
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
image_embeds = image_embeds.to(device)
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index f54b680dfd7c..7195b5f2521a 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -33,7 +33,7 @@
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
@@ -462,18 +462,29 @@ def disable_vae_tiling(self):
self.vae.disable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt):
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
+ return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
@@ -1568,7 +1579,10 @@ def denoising_value_valid(dnv):
add_time_ids = add_time_ids.to(device)
if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ )
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
image_embeds = image_embeds.to(device)
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index a940cec5e46a..0a2f1ca17cb0 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -50,6 +50,9 @@ def get_down_block(
resnet_eps,
resnet_act_fn,
num_attention_heads,
+ transformer_layers_per_block,
+ attention_type,
+ attention_head_dim,
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
@@ -113,6 +116,10 @@ def get_up_block(
resnet_eps,
resnet_act_fn,
num_attention_heads,
+ transformer_layers_per_block,
+ resolution_idx,
+ attention_type,
+ attention_head_dim,
resnet_groups=None,
cross_attention_dim=None,
dual_cross_attention=False,
diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py
index 6aa994676577..7cf6a9b33b37 100644
--- a/src/diffusers/schedulers/scheduling_deis_multistep.py
+++ b/src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -162,6 +162,7 @@ def __init__(
self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index 4b638547b38a..beab985e3350 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -189,6 +189,7 @@ def __init__(
self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
index e762c0ec8bba..61d6810ce286 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
@@ -184,6 +184,7 @@ def __init__(
self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index 2c0be3b842cc..0f1175472f3e 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -172,6 +172,7 @@ def __init__(
self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index d778f37ec059..1d58ab5259ef 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -175,6 +175,7 @@ def __init__(
self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index fdc83237f9f9..992ae7d1b194 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -67,7 +67,7 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
current_lora_layer_sd = lora_layer.state_dict()
for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items():
# The matrix name can either be "down" or "up".
- lora_state_dict[f"unet.{name}.lora.{lora_layer_matrix_name}"] = lora_param
+ lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param
return lora_state_dict
diff --git a/src/diffusers/utils/logging.py b/src/diffusers/utils/logging.py
index 6050f314c008..7945db333cab 100644
--- a/src/diffusers/utils/logging.py
+++ b/src/diffusers/utils/logging.py
@@ -213,7 +213,7 @@ def remove_handler(handler: logging.Handler) -> None:
_configure_library_root_logger()
- assert handler is not None and handler not in _get_library_root_logger().handlers
+ assert handler is not None and handler in _get_library_root_logger().handlers
_get_library_root_logger().removeHandler(handler)
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index 2998f7dc429e..14b89c3cd3b9 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -17,7 +17,7 @@
from distutils.util import strtobool
from io import BytesIO, StringIO
from pathlib import Path
-from typing import List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
@@ -58,6 +58,17 @@
if is_torch_available():
import torch
+ # Set a backend environment variable for any extra module import required for a custom accelerator
+ if "DIFFUSERS_TEST_BACKEND" in os.environ:
+ backend = os.environ["DIFFUSERS_TEST_BACKEND"]
+ try:
+ _ = importlib.import_module(backend)
+ except ModuleNotFoundError as e:
+ raise ModuleNotFoundError(
+ f"Failed to import `DIFFUSERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module \
+ to enable a specified backend.):\n{e}"
+ ) from e
+
if "DIFFUSERS_TEST_DEVICE" in os.environ:
torch_device = os.environ["DIFFUSERS_TEST_DEVICE"]
try:
@@ -210,6 +221,36 @@ def require_torch_gpu(test_case):
)
+# These decorators are for accelerator-specific behaviours that are not GPU-specific
+def require_torch_accelerator(test_case):
+ """Decorator marking a test that requires an accelerator backend and PyTorch."""
+ return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")(
+ test_case
+ )
+
+
+def require_torch_accelerator_with_fp16(test_case):
+ """Decorator marking a test that requires an accelerator with support for the FP16 data type."""
+ return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
+ test_case
+ )
+
+
+def require_torch_accelerator_with_fp64(test_case):
+ """Decorator marking a test that requires an accelerator with support for the FP64 data type."""
+ return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")(
+ test_case
+ )
+
+
+def require_torch_accelerator_with_training(test_case):
+ """Decorator marking a test that requires an accelerator with support for training."""
+ return unittest.skipUnless(
+ is_torch_available() and backend_supports_training(torch_device),
+ "test requires accelerator with training support",
+ )(test_case)
+
+
def skip_mps(test_case):
"""Decorator marking a test to skip if torch_device is 'mps'"""
return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
@@ -766,3 +807,139 @@ def disable_full_determinism():
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
torch.use_deterministic_algorithms(False)
+
+
+# Utils for custom and alternative accelerator devices
+def _is_torch_fp16_available(device):
+ if not is_torch_available():
+ return False
+
+ import torch
+
+ device = torch.device(device)
+
+ try:
+ x = torch.zeros((2, 2), dtype=torch.float16).to(device)
+ _ = x @ x
+ except Exception as e:
+ if device.type == "cuda":
+ raise ValueError(
+ f"You have passed a device of type 'cuda' which should work with 'fp16', but 'cuda' does not seem to be correctly installed on your machine: {e}"
+ )
+
+ return False
+
+
+def _is_torch_fp64_available(device):
+ if not is_torch_available():
+ return False
+
+ import torch
+
+ try:
+ x = torch.zeros((2, 2), dtype=torch.float64).to(device)
+ _ = x @ x
+ except Exception as e:
+ if device.type == "cuda":
+ raise ValueError(
+ f"You have passed a device of type 'cuda' which should work with 'fp64', but 'cuda' does not seem to be correctly installed on your machine: {e}"
+ )
+
+ return False
+
+
+# Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
+if is_torch_available():
+ # Behaviour flags
+ BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True}
+
+ # Function definitions
+ BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None}
+ BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0}
+ BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
+
+
+# This dispatches a defined function according to the accelerator from the function definitions.
+def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
+ if device not in dispatch_table:
+ return dispatch_table["default"](*args, **kwargs)
+
+ fn = dispatch_table[device]
+
+ # Some device agnostic functions return values. Need to guard against 'None' instead at
+ # user level
+ if fn is None:
+ return None
+
+ return fn(*args, **kwargs)
+
+
+# These are callables which automatically dispatch the function specific to the accelerator
+def backend_manual_seed(device: str, seed: int):
+ return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
+
+
+def backend_empty_cache(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
+
+
+def backend_device_count(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+
+
+# These are callables which return boolean behaviour flags and can be used to specify some
+# device agnostic alternative where the feature is unsupported.
+def backend_supports_training(device: str):
+ if not is_torch_available():
+ return False
+
+ if device not in BACKEND_SUPPORTS_TRAINING:
+ device = "default"
+
+ return BACKEND_SUPPORTS_TRAINING[device]
+
+
+# Guard for when Torch is not available
+if is_torch_available():
+ # Update device function dict mapping
+ def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str):
+ try:
+ # Try to import the function directly
+ spec_fn = getattr(device_spec_module, attribute_name)
+ device_fn_dict[torch_device] = spec_fn
+ except AttributeError as e:
+ # If the function doesn't exist, and there is no default, throw an error
+ if "default" not in device_fn_dict:
+ raise AttributeError(
+ f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
+ ) from e
+
+ if "DIFFUSERS_TEST_DEVICE_SPEC" in os.environ:
+ device_spec_path = os.environ["DIFFUSERS_TEST_DEVICE_SPEC"]
+ if not Path(device_spec_path).is_file():
+ raise ValueError(f"Specified path to device specification file is not found. Received {device_spec_path}")
+
+ try:
+ import_name = device_spec_path[: device_spec_path.index(".py")]
+ except ValueError as e:
+ raise ValueError(f"Provided device spec file is not a Python file! Received {device_spec_path}") from e
+
+ device_spec_module = importlib.import_module(import_name)
+
+ try:
+ device_name = device_spec_module.DEVICE_NAME
+ except AttributeError:
+ raise AttributeError("Device spec file did not contain `DEVICE_NAME`")
+
+ if "DIFFUSERS_TEST_DEVICE" in os.environ and torch_device != device_name:
+ msg = f"Mismatch between environment variable `DIFFUSERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
+ msg += "Either unset `DIFFUSERS_TEST_DEVICE` or ensure it matches device spec name."
+ raise ValueError(msg)
+
+ torch_device = device_name
+
+ # Add one entry here for each `BACKEND_*` dictionary.
+ update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
+ update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
+ update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
+ update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")
diff --git a/tests/models/test_layers_utils.py b/tests/models/test_layers_utils.py
index 9d45d810f653..c6e3e19d4cc3 100644
--- a/tests/models/test_layers_utils.py
+++ b/tests/models/test_layers_utils.py
@@ -25,7 +25,11 @@
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModel
-from diffusers.utils.testing_utils import torch_device
+from diffusers.utils.testing_utils import (
+ backend_manual_seed,
+ require_torch_accelerator_with_fp64,
+ torch_device,
+)
class EmbeddingsTests(unittest.TestCase):
@@ -315,8 +319,7 @@ def test_restnet_with_kernel_sde_vp(self):
class Transformer2DModelTests(unittest.TestCase):
def test_spatial_transformer_default(self):
torch.manual_seed(0)
- if torch.cuda.is_available():
- torch.cuda.manual_seed_all(0)
+ backend_manual_seed(torch_device, 0)
sample = torch.randn(1, 32, 64, 64).to(torch_device)
spatial_transformer_block = Transformer2DModel(
@@ -339,8 +342,7 @@ def test_spatial_transformer_default(self):
def test_spatial_transformer_cross_attention_dim(self):
torch.manual_seed(0)
- if torch.cuda.is_available():
- torch.cuda.manual_seed_all(0)
+ backend_manual_seed(torch_device, 0)
sample = torch.randn(1, 64, 64, 64).to(torch_device)
spatial_transformer_block = Transformer2DModel(
@@ -363,8 +365,7 @@ def test_spatial_transformer_cross_attention_dim(self):
def test_spatial_transformer_timestep(self):
torch.manual_seed(0)
- if torch.cuda.is_available():
- torch.cuda.manual_seed_all(0)
+ backend_manual_seed(torch_device, 0)
num_embeds_ada_norm = 5
@@ -401,8 +402,7 @@ def test_spatial_transformer_timestep(self):
def test_spatial_transformer_dropout(self):
torch.manual_seed(0)
- if torch.cuda.is_available():
- torch.cuda.manual_seed_all(0)
+ backend_manual_seed(torch_device, 0)
sample = torch.randn(1, 32, 64, 64).to(torch_device)
spatial_transformer_block = (
@@ -427,11 +427,10 @@ def test_spatial_transformer_dropout(self):
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
- @unittest.skipIf(torch_device == "mps", "MPS does not support float64")
+ @require_torch_accelerator_with_fp64
def test_spatial_transformer_discrete(self):
torch.manual_seed(0)
- if torch.cuda.is_available():
- torch.cuda.manual_seed_all(0)
+ backend_manual_seed(torch_device, 0)
num_embed = 5
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 961147839461..5ea0d910f3a3 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -35,6 +35,7 @@
CaptureLogger,
require_python39_or_higher,
require_torch_2,
+ require_torch_accelerator_with_training,
require_torch_gpu,
run_test_in_subprocess,
torch_device,
@@ -536,7 +537,7 @@ def test_model_from_pretrained(self):
self.assertEqual(output_1.shape, output_2.shape)
- @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
+ @require_torch_accelerator_with_training
def test_training(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -553,7 +554,7 @@ def test_training(self):
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
- @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
+ @require_torch_accelerator_with_training
def test_ema_training(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -624,7 +625,7 @@ def recursive_check(tuple_object, dict_object):
recursive_check(outputs_tuple, outputs_dict)
- @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
+ @require_torch_accelerator_with_training
def test_enable_disable_gradient_checkpointing(self):
if not self.model_class._supports_gradient_checkpointing:
return # Skip test if model does not support gradient checkpointing
diff --git a/tests/models/test_models_prior.py b/tests/models/test_models_prior.py
index 9b02de463ecd..ca27b6ff057f 100644
--- a/tests/models/test_models_prior.py
+++ b/tests/models/test_models_prior.py
@@ -21,7 +21,14 @@
from parameterized import parameterized
from diffusers import PriorTransformer
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, slow, torch_all_close, torch_device
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ floats_tensor,
+ slow,
+ torch_all_close,
+ torch_device,
+)
from .test_modeling_common import ModelTesterMixin
@@ -157,7 +164,7 @@ def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache()
@parameterized.expand(
[
diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py
index 5803e5bfda2a..aad496416508 100644
--- a/tests/models/test_models_unet_1d.py
+++ b/tests/models/test_models_unet_1d.py
@@ -18,7 +18,12 @@
import torch
from diffusers import UNet1DModel
-from diffusers.utils.testing_utils import floats_tensor, slow, torch_device
+from diffusers.utils.testing_utils import (
+ backend_manual_seed,
+ floats_tensor,
+ slow,
+ torch_device,
+)
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
@@ -103,8 +108,7 @@ def test_from_pretrained_hub(self):
def test_output_pretrained(self):
model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
torch.manual_seed(0)
- if torch.cuda.is_available():
- torch.cuda.manual_seed_all(0)
+ backend_manual_seed(torch_device, 0)
num_features = model.config.in_channels
seq_len = 16
@@ -244,8 +248,7 @@ def test_output_pretrained(self):
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
)
torch.manual_seed(0)
- if torch.cuda.is_available():
- torch.cuda.manual_seed_all(0)
+ backend_manual_seed(torch_device, 0)
num_features = value_function.config.in_channels
seq_len = 14
diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py
index 4fd991b3fc46..2be343e9d627 100644
--- a/tests/models/test_models_unet_2d.py
+++ b/tests/models/test_models_unet_2d.py
@@ -24,6 +24,7 @@
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
+ require_torch_accelerator,
slow,
torch_all_close,
torch_device,
@@ -153,7 +154,7 @@ def test_from_pretrained_hub(self):
assert image is not None, "Make sure output is not None"
- @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
+ @require_torch_accelerator
def test_from_pretrained_accelerate(self):
model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model.to(torch_device)
@@ -161,7 +162,7 @@ def test_from_pretrained_accelerate(self):
assert image is not None, "Make sure output is not None"
- @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
+ @require_torch_accelerator
def test_from_pretrained_accelerate_wont_change_results(self):
# by defautl model loading will use accelerate as `low_cpu_mem_usage=True`
model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py
index 06bf2685560d..80f59734b5ce 100644
--- a/tests/models/test_models_unet_2d_condition.py
+++ b/tests/models/test_models_unet_2d_condition.py
@@ -18,6 +18,7 @@
import os
import tempfile
import unittest
+from collections import OrderedDict
import torch
from parameterized import parameterized
@@ -25,14 +26,19 @@
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
-from diffusers.models.embeddings import ImageProjection
+from diffusers.models.embeddings import ImageProjection, Resampler
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_hf_numpy,
+ require_torch_accelerator,
+ require_torch_accelerator_with_fp16,
+ require_torch_accelerator_with_training,
require_torch_gpu,
+ skip_mps,
slow,
torch_all_close,
torch_device,
@@ -97,6 +103,85 @@ def create_ip_adapter_state_dict(model):
return ip_state_dict
+def create_ip_adapter_plus_state_dict(model):
+ # "ip_adapter" (cross-attention weights)
+ ip_cross_attn_state_dict = {}
+ key_id = 1
+
+ for name in model.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = model.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(model.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = model.config.block_out_channels[block_id]
+ if cross_attention_dim is not None:
+ sd = IPAdapterAttnProcessor(
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
+ ).state_dict()
+ ip_cross_attn_state_dict.update(
+ {
+ f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"],
+ f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"],
+ }
+ )
+
+ key_id += 2
+
+ # "image_proj" (ImageProjection layer weights)
+ cross_attention_dim = model.config["cross_attention_dim"]
+ image_projection = Resampler(
+ embed_dims=cross_attention_dim, output_dims=cross_attention_dim, dim_head=32, heads=2, num_queries=4
+ )
+
+ ip_image_projection_state_dict = OrderedDict()
+ for k, v in image_projection.state_dict().items():
+ if "2.to" in k:
+ k = k.replace("2.to", "0.to")
+ elif "3.0.weight" in k:
+ k = k.replace("3.0.weight", "1.0.weight")
+ elif "3.0.bias" in k:
+ k = k.replace("3.0.bias", "1.0.bias")
+ elif "3.0.weight" in k:
+ k = k.replace("3.0.weight", "1.0.weight")
+ elif "3.1.net.0.proj.weight" in k:
+ k = k.replace("3.1.net.0.proj.weight", "1.1.weight")
+ elif "3.net.2.weight" in k:
+ k = k.replace("3.net.2.weight", "1.3.weight")
+ elif "layers.0.0" in k:
+ k = k.replace("layers.0.0", "layers.0.0.norm1")
+ elif "layers.0.1" in k:
+ k = k.replace("layers.0.1", "layers.0.0.norm2")
+ elif "layers.1.0" in k:
+ k = k.replace("layers.1.0", "layers.1.0.norm1")
+ elif "layers.1.1" in k:
+ k = k.replace("layers.1.1", "layers.1.0.norm2")
+ elif "layers.2.0" in k:
+ k = k.replace("layers.2.0", "layers.2.0.norm1")
+ elif "layers.2.1" in k:
+ k = k.replace("layers.2.1", "layers.2.0.norm2")
+
+ if "norm_cross" in k:
+ ip_image_projection_state_dict[k.replace("norm_cross", "norm1")] = v
+ elif "layer_norm" in k:
+ ip_image_projection_state_dict[k.replace("layer_norm", "norm2")] = v
+ elif "to_k" in k:
+ ip_image_projection_state_dict[k.replace("to_k", "to_kv")] = torch.cat([v, v], dim=0)
+ elif "to_v" in k:
+ continue
+ elif "to_out.0" in k:
+ ip_image_projection_state_dict[k.replace("to_out.0", "to_out")] = v
+ else:
+ ip_image_projection_state_dict[k] = v
+
+ ip_state_dict = {}
+ ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
+ return ip_state_dict
+
+
def create_custom_diffusion_layers(model, mock_weights: bool = True):
train_kv = True
train_q_out = True
@@ -200,7 +285,7 @@ def test_xformers_enable_works(self):
== "XFormersAttnProcessor"
), "xformers is not enabled"
- @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
+ @require_torch_accelerator_with_training
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -724,6 +809,56 @@ def test_ip_adapter(self):
assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4)
+ def test_ip_adapter_plus(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ init_dict["attention_head_dim"] = (8, 16)
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+
+ # forward pass without ip-adapter
+ with torch.no_grad():
+ sample1 = model(**inputs_dict).sample
+
+ # update inputs_dict for ip-adapter
+ batch_size = inputs_dict["encoder_hidden_states"].shape[0]
+ image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device)
+ inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
+
+ # make ip_adapter_1 and ip_adapter_2
+ ip_adapter_1 = create_ip_adapter_plus_state_dict(model)
+
+ image_proj_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items()}
+ cross_attn_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items()}
+ ip_adapter_2 = {}
+ ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2})
+
+ # forward pass ip_adapter_1
+ model._load_ip_adapter_weights(ip_adapter_1)
+ assert model.config.encoder_hid_dim_type == "ip_image_proj"
+ assert model.encoder_hid_proj is not None
+ assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in (
+ "IPAdapterAttnProcessor",
+ "IPAdapterAttnProcessor2_0",
+ )
+ with torch.no_grad():
+ sample2 = model(**inputs_dict).sample
+
+ # forward pass with ip_adapter_2
+ model._load_ip_adapter_weights(ip_adapter_2)
+ with torch.no_grad():
+ sample3 = model(**inputs_dict).sample
+
+ # forward pass with ip_adapter_1 again
+ model._load_ip_adapter_weights(ip_adapter_1)
+ with torch.no_grad():
+ sample4 = model(**inputs_dict).sample
+
+ assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4)
+ assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4)
+ assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4)
+
@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
@@ -734,7 +869,7 @@ def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache()
def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
@@ -752,6 +887,7 @@ def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
return model
+ @require_torch_gpu
def test_set_attention_slice_auto(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
@@ -771,6 +907,7 @@ def test_set_attention_slice_auto(self):
assert mem_bytes < 5 * 10**9
+ @require_torch_gpu
def test_set_attention_slice_max(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
@@ -790,6 +927,7 @@ def test_set_attention_slice_max(self):
assert mem_bytes < 5 * 10**9
+ @require_torch_gpu
def test_set_attention_slice_int(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
@@ -809,6 +947,7 @@ def test_set_attention_slice_int(self):
assert mem_bytes < 5 * 10**9
+ @require_torch_gpu
def test_set_attention_slice_list(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
@@ -845,7 +984,7 @@ def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
# fmt: on
]
)
- @require_torch_gpu
+ @require_torch_accelerator_with_fp16
def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4")
latents = self.get_latents(seed)
@@ -873,7 +1012,7 @@ def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
# fmt: on
]
)
- @require_torch_gpu
+ @require_torch_accelerator_with_fp16
def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
latents = self.get_latents(seed, fp16=True)
@@ -901,7 +1040,8 @@ def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice):
# fmt: on
]
)
- @require_torch_gpu
+ @require_torch_accelerator
+ @skip_mps
def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5")
latents = self.get_latents(seed)
@@ -929,7 +1069,7 @@ def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
# fmt: on
]
)
- @require_torch_gpu
+ @require_torch_accelerator_with_fp16
def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5", fp16=True)
latents = self.get_latents(seed, fp16=True)
@@ -957,7 +1097,8 @@ def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice):
# fmt: on
]
)
- @require_torch_gpu
+ @require_torch_accelerator
+ @skip_mps
def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting")
latents = self.get_latents(seed, shape=(4, 9, 64, 64))
@@ -985,7 +1126,7 @@ def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
# fmt: on
]
)
- @require_torch_gpu
+ @require_torch_accelerator_with_fp16
def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting", fp16=True)
latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True)
@@ -1013,7 +1154,7 @@ def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice):
# fmt: on
]
)
- @require_torch_gpu
+ @require_torch_accelerator_with_fp16
def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py
index aa755e387b61..df34a48da3aa 100644
--- a/tests/models/test_models_vae.py
+++ b/tests/models/test_models_vae.py
@@ -31,10 +31,15 @@
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.loading_utils import load_image
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_hf_numpy,
+ require_torch_accelerator,
+ require_torch_accelerator_with_fp16,
+ require_torch_accelerator_with_training,
require_torch_gpu,
+ skip_mps,
slow,
torch_all_close,
torch_device,
@@ -157,7 +162,7 @@ def test_forward_signature(self):
def test_training(self):
pass
- @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
+ @require_torch_accelerator_with_training
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -213,10 +218,12 @@ def test_output_pretrained(self):
model = model.to(torch_device)
model.eval()
- if torch_device == "mps":
- generator = torch.manual_seed(0)
+ # Keep generator on CPU for non-CUDA devices to compare outputs with CPU result tensors
+ generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
+ if torch_device != "mps":
+ generator = torch.Generator(device=generator_device).manual_seed(0)
else:
- generator = torch.Generator(device=torch_device).manual_seed(0)
+ generator = torch.manual_seed(0)
image = torch.randn(
1,
@@ -247,7 +254,7 @@ def test_output_pretrained(self):
-9.8644e-03,
]
)
- elif torch_device == "cpu":
+ elif generator_device == "cpu":
expected_output_slice = torch.tensor(
[
-0.1352,
@@ -478,7 +485,7 @@ def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache()
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
@@ -558,7 +565,7 @@ def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache()
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
@@ -580,9 +587,10 @@ def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False)
return model
def get_generator(self, seed=0):
- if torch_device == "mps":
- return torch.manual_seed(seed)
- return torch.Generator(device=torch_device).manual_seed(seed)
+ generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
+ if torch_device != "mps":
+ return torch.Generator(device=generator_device).manual_seed(seed)
+ return torch.manual_seed(seed)
@parameterized.expand(
[
@@ -623,7 +631,7 @@ def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
# fmt: on
]
)
- @require_torch_gpu
+ @require_torch_accelerator_with_fp16
def test_stable_diffusion_fp16(self, seed, expected_slice):
model = self.get_sd_vae_model(fp16=True)
image = self.get_sd_image(seed, fp16=True)
@@ -677,7 +685,8 @@ def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps):
# fmt: on
]
)
- @require_torch_gpu
+ @require_torch_accelerator
+ @skip_mps
def test_stable_diffusion_decode(self, seed, expected_slice):
model = self.get_sd_vae_model()
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
@@ -700,7 +709,7 @@ def test_stable_diffusion_decode(self, seed, expected_slice):
# fmt: on
]
)
- @require_torch_gpu
+ @require_torch_accelerator_with_fp16
def test_stable_diffusion_decode_fp16(self, seed, expected_slice):
model = self.get_sd_vae_model(fp16=True)
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
@@ -811,7 +820,7 @@ def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache()
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
@@ -832,9 +841,10 @@ def get_sd_vae_model(self, model_id="cross-attention/asymmetric-autoencoder-kl-x
return model
def get_generator(self, seed=0):
- if torch_device == "mps":
- return torch.manual_seed(seed)
- return torch.Generator(device=torch_device).manual_seed(seed)
+ generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
+ if torch_device != "mps":
+ return torch.Generator(device=generator_device).manual_seed(seed)
+ return torch.manual_seed(seed)
@parameterized.expand(
[
@@ -905,7 +915,8 @@ def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps):
# fmt: on
]
)
- @require_torch_gpu
+ @require_torch_accelerator
+ @skip_mps
def test_stable_diffusion_decode(self, seed, expected_slice):
model = self.get_sd_vae_model()
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
diff --git a/tests/models/test_models_vq.py b/tests/models/test_models_vq.py
index c7b9363b5d5f..a5a9288d6462 100644
--- a/tests/models/test_models_vq.py
+++ b/tests/models/test_models_vq.py
@@ -18,7 +18,12 @@
import torch
from diffusers import VQModel
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
+from diffusers.utils.testing_utils import (
+ backend_manual_seed,
+ enable_full_determinism,
+ floats_tensor,
+ torch_device,
+)
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
@@ -80,8 +85,7 @@ def test_output_pretrained(self):
model.to(torch_device).eval()
torch.manual_seed(0)
- if torch.cuda.is_available():
- torch.cuda.manual_seed_all(0)
+ backend_manual_seed(torch_device, 0)
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device)
diff --git a/tests/models/test_unet_blocks_common.py b/tests/models/test_unet_blocks_common.py
index 4c399fdb74fa..9d1ddc2457e3 100644
--- a/tests/models/test_unet_blocks_common.py
+++ b/tests/models/test_unet_blocks_common.py
@@ -12,12 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import unittest
from typing import Tuple
import torch
-from diffusers.utils.testing_utils import floats_tensor, require_torch, torch_all_close, torch_device
+from diffusers.utils.testing_utils import (
+ floats_tensor,
+ require_torch,
+ require_torch_accelerator_with_training,
+ torch_all_close,
+ torch_device,
+)
from diffusers.utils.torch_utils import randn_tensor
@@ -104,7 +109,7 @@ def test_output(self, expected_slice):
expected_slice = torch.tensor(expected_slice).to(torch_device)
assert torch_all_close(output_slice.flatten(), expected_slice, atol=5e-3)
- @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
+ @require_torch_accelerator_with_training
def test_training(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.block_class(**init_dict)
diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
index 57eb49013c1f..7c6349ce2600 100644
--- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
+++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
@@ -116,7 +116,17 @@ def test_text_to_image(self):
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()
- expected_slice = np.array([0.8047, 0.8774, 0.9248, 0.9155, 0.9814, 1.0, 0.9678, 1.0, 1.0])
+ expected_slice = np.array([0.8110, 0.8843, 0.9326, 0.9224, 0.9878, 1.0, 0.9736, 1.0, 1.0])
+
+ assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+ pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
+
+ inputs = self.get_dummy_inputs()
+ images = pipeline(**inputs).images
+ image_slice = images[0, :3, :3, -1].flatten()
+
+ expected_slice = np.array([0.3013, 0.2615, 0.2202, 0.2722, 0.2510, 0.2023, 0.2498, 0.2415, 0.2139])
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
@@ -132,7 +142,17 @@ def test_image_to_image(self):
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()
- expected_slice = np.array([0.2307, 0.2341, 0.2305, 0.24, 0.2268, 0.25, 0.2322, 0.2588, 0.2935])
+ expected_slice = np.array([0.2253, 0.2251, 0.2219, 0.2312, 0.2236, 0.2434, 0.2275, 0.2575, 0.2805])
+
+ assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+ pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
+
+ inputs = self.get_dummy_inputs(for_image_to_image=True)
+ images = pipeline(**inputs).images
+ image_slice = images[0, :3, :3, -1].flatten()
+
+ expected_slice = np.array([0.3550, 0.2600, 0.2520, 0.2412, 0.1870, 0.3831, 0.1453, 0.1880, 0.5371])
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
@@ -148,7 +168,17 @@ def test_inpainting(self):
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()
- expected_slice = np.array([0.2705, 0.2395, 0.2209, 0.2312, 0.2102, 0.2104, 0.2178, 0.2065, 0.1997])
+ expected_slice = np.array([0.2700, 0.2388, 0.2202, 0.2304, 0.2095, 0.2097, 0.2173, 0.2058, 0.1987])
+
+ assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+ pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
+
+ inputs = self.get_dummy_inputs(for_inpainting=True)
+ images = pipeline(**inputs).images
+ image_slice = images[0, :3, :3, -1].flatten()
+
+ expected_slice = np.array([0.2744, 0.2410, 0.2202, 0.2334, 0.2090, 0.2053, 0.2175, 0.2033, 0.1934])
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
@@ -173,7 +203,30 @@ def test_text_to_image_sdxl(self):
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()
- expected_slice = np.array([0.0968, 0.0959, 0.0852, 0.0912, 0.0948, 0.093, 0.0893, 0.0932, 0.0923])
+ expected_slice = np.array([0.0965, 0.0956, 0.0849, 0.0908, 0.0944, 0.0927, 0.0888, 0.0929, 0.0920])
+
+ assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+ image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
+
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ torch_dtype=self.dtype,
+ )
+ pipeline.to(torch_device)
+ pipeline.load_ip_adapter(
+ "h94/IP-Adapter",
+ subfolder="sdxl_models",
+ weight_name="ip-adapter-plus_sdxl_vit-h.bin",
+ )
+
+ inputs = self.get_dummy_inputs()
+ images = pipeline(**inputs).images
+ image_slice = images[0, :3, :3, -1].flatten()
+
+ expected_slice = np.array([0.0592, 0.0573, 0.0459, 0.0542, 0.0559, 0.0523, 0.0500, 0.0540, 0.0501])
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
@@ -194,7 +247,31 @@ def test_image_to_image_sdxl(self):
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()
- expected_slice = np.array([0.0653, 0.0704, 0.0725, 0.0741, 0.0702, 0.0647, 0.0782, 0.0799, 0.0752])
+ expected_slice = np.array([0.0652, 0.0698, 0.0723, 0.0744, 0.0699, 0.0636, 0.0784, 0.0803, 0.0742])
+
+ assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+ image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
+ feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
+
+ pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ torch_dtype=self.dtype,
+ )
+ pipeline.to(torch_device)
+ pipeline.load_ip_adapter(
+ "h94/IP-Adapter",
+ subfolder="sdxl_models",
+ weight_name="ip-adapter-plus_sdxl_vit-h.bin",
+ )
+
+ inputs = self.get_dummy_inputs(for_image_to_image=True)
+ images = pipeline(**inputs).images
+ image_slice = images[0, :3, :3, -1].flatten()
+
+ expected_slice = np.array([0.0708, 0.0701, 0.0735, 0.0760, 0.0739, 0.0679, 0.0756, 0.0824, 0.0837])
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
@@ -216,6 +293,31 @@ def test_inpainting_sdxl(self):
image_slice = images[0, :3, :3, -1].flatten()
image_slice.tolist()
- expected_slice = np.array([0.1418, 0.1493, 0.1428, 0.146, 0.1491, 0.1501, 0.1473, 0.1501, 0.1516])
+ expected_slice = np.array([0.1420, 0.1495, 0.1430, 0.1462, 0.1493, 0.1502, 0.1474, 0.1502, 0.1517])
+
+ assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+ image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
+ feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
+
+ pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ torch_dtype=self.dtype,
+ )
+ pipeline.to(torch_device)
+ pipeline.load_ip_adapter(
+ "h94/IP-Adapter",
+ subfolder="sdxl_models",
+ weight_name="ip-adapter-plus_sdxl_vit-h.bin",
+ )
+
+ inputs = self.get_dummy_inputs(for_inpainting=True)
+ images = pipeline(**inputs).images
+ image_slice = images[0, :3, :3, -1].flatten()
+ image_slice.tolist()
+
+ expected_slice = np.array([0.1398, 0.1476, 0.1407, 0.1442, 0.1470, 0.1480, 0.1449, 0.1481, 0.1494])
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
diff --git a/tests/pipelines/kandinsky3/test_kandinsky3.py b/tests/pipelines/kandinsky3/test_kandinsky3.py
index 65297a36b157..c163fe3102c4 100644
--- a/tests/pipelines/kandinsky3/test_kandinsky3.py
+++ b/tests/pipelines/kandinsky3/test_kandinsky3.py
@@ -165,10 +165,6 @@ def test_float16_inference(self):
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=1e-2)
- def test_model_cpu_offload_forward_pass(self):
- # TODO(Yiyi) - this test should work, skipped for time reasons for now
- pass
-
@slow
@require_torch_gpu
diff --git a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
new file mode 100644
index 000000000000..581251a81639
--- /dev/null
+++ b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
@@ -0,0 +1,225 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import random
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoPipelineForImage2Image,
+ Kandinsky3Img2ImgPipeline,
+ Kandinsky3UNet,
+ VQModel,
+)
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ load_image,
+ require_torch_gpu,
+ slow,
+)
+
+from ..pipeline_params import (
+ IMAGE_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
+ TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
+ TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
+ TEXT_TO_IMAGE_IMAGE_PARAMS,
+)
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = Kandinsky3Img2ImgPipeline
+ params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
+ batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
+ image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
+ test_xformers_attention = False
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "num_images_per_prompt",
+ "generator",
+ "output_type",
+ "return_dict",
+ ]
+ )
+
+ @property
+ def dummy_movq_kwargs(self):
+ return {
+ "block_out_channels": [32, 64],
+ "down_block_types": ["DownEncoderBlock2D", "AttnDownEncoderBlock2D"],
+ "in_channels": 3,
+ "latent_channels": 4,
+ "layers_per_block": 1,
+ "norm_num_groups": 8,
+ "norm_type": "spatial",
+ "num_vq_embeddings": 12,
+ "out_channels": 3,
+ "up_block_types": [
+ "AttnUpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ ],
+ "vq_embed_dim": 4,
+ }
+
+ @property
+ def dummy_movq(self):
+ torch.manual_seed(0)
+ model = VQModel(**self.dummy_movq_kwargs)
+ return model
+
+ def get_dummy_components(self, time_cond_proj_dim=None):
+ torch.manual_seed(0)
+ unet = Kandinsky3UNet(
+ in_channels=4,
+ time_embedding_dim=4,
+ groups=2,
+ attention_head_dim=4,
+ layers_per_block=3,
+ block_out_channels=(32, 64),
+ cross_attention_dim=4,
+ encoder_hid_dim=32,
+ )
+ scheduler = DDPMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ steps_offset=1,
+ beta_schedule="squaredcos_cap_v2",
+ clip_sample=True,
+ thresholding=False,
+ )
+ torch.manual_seed(0)
+ movq = self.dummy_movq
+ torch.manual_seed(0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "unet": unet,
+ "scheduler": scheduler,
+ "movq": movq,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ # create init_image
+ image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
+ image = image.cpu().permute(0, 2, 3, 1)[0]
+ init_image = Image.fromarray(np.uint8(image)).convert("RGB")
+
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "image": init_image,
+ "generator": generator,
+ "strength": 0.75,
+ "num_inference_steps": 10,
+ "guidance_scale": 6.0,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_kandinsky3_img2img(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+
+ pipe.set_progress_bar_config(disable=None)
+
+ output = pipe(**self.get_dummy_inputs(device))
+ image = output.images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+
+ expected_slice = np.array(
+ [0.576259, 0.6132097, 0.41703486, 0.603196, 0.62062526, 0.4655338, 0.5434324, 0.5660727, 0.65433365]
+ )
+
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+
+ def test_float16_inference(self):
+ super().test_float16_inference(expected_max_diff=1e-1)
+
+ def test_inference_batch_single_identical(self):
+ super().test_inference_batch_single_identical(expected_max_diff=1e-2)
+
+
+@slow
+@require_torch_gpu
+class Kandinsky3Img2ImgPipelineIntegrationTests(unittest.TestCase):
+ def tearDown(self):
+ # clean up the VRAM after each test
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_kandinskyV3_img2img(self):
+ pipe = AutoPipelineForImage2Image.from_pretrained(
+ "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
+ )
+ pipe.enable_model_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.Generator(device="cpu").manual_seed(0)
+
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png"
+ )
+ w, h = 512, 512
+ image = image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
+ prompt = "A painting of the inside of a subway train with tiny raccoons."
+
+ image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0]
+
+ assert image.size == (512, 512)
+
+ expected_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/i2i.png"
+ )
+
+ image_processor = VaeImageProcessor()
+
+ image_np = image_processor.pil_to_numpy(image)
+ expected_image_np = image_processor.pil_to_numpy(expected_image)
+
+ self.assertTrue(np.allclose(image_np, expected_image_np, atol=5e-2))
diff --git a/tests/pipelines/pixart/test_pixart.py b/tests/pipelines/pixart/test_pixart.py
index eced49e04261..361bacc298e9 100644
--- a/tests/pipelines/pixart/test_pixart.py
+++ b/tests/pipelines/pixart/test_pixart.py
@@ -329,40 +329,6 @@ def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
- def test_pixart_1024_fast(self):
- generator = torch.manual_seed(0)
-
- pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_1024, torch_dtype=torch.float16)
- pipe.enable_model_cpu_offload()
-
- prompt = self.prompt
-
- image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images
-
- image_slice = image[0, -3:, -3:, -1]
-
- expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
-
- max_diff = np.abs(image_slice.flatten() - expected_slice).max()
- self.assertLessEqual(max_diff, 1e-3)
-
- def test_pixart_512_fast(self):
- generator = torch.manual_seed(0)
-
- pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_512, torch_dtype=torch.float16)
- pipe.enable_model_cpu_offload()
-
- prompt = self.prompt
-
- image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images
-
- image_slice = image[0, -3:, -3:, -1]
-
- expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
-
- max_diff = np.abs(image_slice.flatten() - expected_slice).max()
- self.assertLessEqual(max_diff, 1e-3)
-
def test_pixart_1024(self):
generator = torch.manual_seed(0)
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
index ed295f792f99..7459d5a6b617 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
@@ -34,11 +34,14 @@
)
from diffusers.utils.testing_utils import (
CaptureLogger,
+ backend_empty_cache,
enable_full_determinism,
load_numpy,
nightly,
numpy_cosine_similarity_distance,
+ require_torch_accelerator,
require_torch_gpu,
+ skip_mps,
slow,
torch_device,
)
@@ -128,10 +131,12 @@ def get_dummy_components(self):
return components
def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
+ generator_device = "cpu" if not device.startswith("cuda") else "cuda"
+ if not str(device).startswith("mps"):
+ generator = torch.Generator(device=generator_device).manual_seed(seed)
else:
- generator = torch.Generator(device=device).manual_seed(seed)
+ generator = torch.manual_seed(seed)
+
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
@@ -299,15 +304,21 @@ def test_inference_batch_single_identical(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
+@skip_mps
class StableDiffusion2PipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
- generator = torch.Generator(device=generator_device).manual_seed(seed)
+ _generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda"
+ if not str(device).startswith("mps"):
+ generator = torch.Generator(device=_generator_device).manual_seed(seed)
+ else:
+ generator = torch.manual_seed(seed)
+
latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
inputs = {
@@ -361,6 +372,7 @@ def test_stable_diffusion_k_lms(self):
expected_slice = np.array([0.10440, 0.13115, 0.11100, 0.10141, 0.11440, 0.07215, 0.11332, 0.09693, 0.10006])
assert np.abs(image_slice - expected_slice).max() < 3e-3
+ @require_torch_gpu
def test_stable_diffusion_attention_slicing(self):
torch.cuda.reset_peak_memory_stats()
pipe = StableDiffusionPipeline.from_pretrained(
@@ -432,6 +444,7 @@ def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
assert callback_fn.has_been_called
assert number_of_steps == inputs["num_inference_steps"]
+ @require_torch_gpu
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
@@ -452,6 +465,7 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
# make sure that less than 2.8 GB is allocated
assert mem_bytes < 2.8 * 10**9
+ @require_torch_gpu
def test_stable_diffusion_pipeline_with_model_offloading(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
@@ -511,15 +525,21 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
+@skip_mps
class StableDiffusion2PipelineNightlyTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
- generator = torch.Generator(device=generator_device).manual_seed(seed)
+ _generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda"
+ if not str(device).startswith("mps"):
+ generator = torch.Generator(device=_generator_device).manual_seed(seed)
+ else:
+ generator = torch.manual_seed(seed)
+
latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
inputs = {
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index e11175921184..cac5ee442ae6 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -377,6 +377,10 @@ def test_save_load_local(self, expected_max_difference=5e-4):
with CaptureLogger(logger) as cap_logger:
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+
for name in pipe_loaded.components.keys():
if name not in pipe_loaded._optional_components:
assert name in str(cap_logger)
diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py
index 365310f415a2..dfa8f90837c5 100644
--- a/utils/tests_fetcher.py
+++ b/utils/tests_fetcher.py
@@ -66,12 +66,16 @@
PATH_TO_DIFFUSERS = PATH_TO_REPO / "src/diffusers"
PATH_TO_TESTS = PATH_TO_REPO / "tests"
-# List here the pipelines to always test.
+# Ignore fixtures in tests folder
+# Ignore lora since they are always tested
+MODULES_TO_IGNORE = ["fixtures", "lora"]
+
IMPORTANT_PIPELINES = [
"controlnet",
"stable_diffusion",
"stable_diffusion_2",
"stable_diffusion_xl",
+ "stable_video_diffusion",
"deepfloyd_if",
"kandinsky",
"kandinsky2_2",
@@ -79,10 +83,6 @@
"wuerstchen",
]
-# Ignore fixtures in tests folder
-# Ignore lora since they are always tested
-MODULES_TO_IGNORE = ["fixtures", "lora"]
-
@contextmanager
def checkout_commit(repo: Repo, commit_id: str):
@@ -289,10 +289,13 @@ def get_modified_python_files(diff_with_last_commit: bool = False) -> List[str]:
repo = Repo(PATH_TO_REPO)
if not diff_with_last_commit:
- print(f"main is at {repo.refs.main.commit}")
+ # Need to fetch refs for main using remotes when running with github actions.
+ upstream_main = repo.remotes.origin.refs.main
+
+ print(f"main is at {upstream_main.commit}")
print(f"Current head is at {repo.head.commit}")
- branching_commits = repo.merge_base(repo.refs.main, repo.head)
+ branching_commits = repo.merge_base(upstream_main, repo.head)
for commit in branching_commits:
print(f"Branching commit: {commit}")
return get_diff(repo, repo.head.commit, branching_commits)
@@ -415,10 +418,11 @@ def get_doctest_files(diff_with_last_commit: bool = False) -> List[str]:
test_files_to_run = [] # noqa
if not diff_with_last_commit:
- print(f"main is at {repo.refs.main.commit}")
+ upstream_main = repo.remotes.origin.refs.main
+ print(f"main is at {upstream_main.commit}")
print(f"Current head is at {repo.head.commit}")
- branching_commits = repo.merge_base(repo.refs.main, repo.head)
+ branching_commits = repo.merge_base(upstream_main, repo.head)
for commit in branching_commits:
print(f"Branching commit: {commit}")
test_files_to_run = get_diff_for_doctesting(repo, repo.head.commit, branching_commits)
@@ -432,7 +436,7 @@ def get_doctest_files(diff_with_last_commit: bool = False) -> List[str]:
all_test_files_to_run = get_all_doctest_files()
# Add to the test files to run any removed entry from "utils/not_doctested.txt".
- new_test_files = get_new_doctest_files(repo, repo.head.commit, repo.refs.main.commit)
+ new_test_files = get_new_doctest_files(repo, repo.head.commit, upstream_main.commit)
test_files_to_run = list(set(test_files_to_run + new_test_files))
# Do not run slow doctest tests on CircleCI
@@ -774,9 +778,7 @@ def create_reverse_dependency_map() -> Dict[str, List[str]]:
return reverse_map
-def create_module_to_test_map(
- reverse_map: Dict[str, List[str]] = None, filter_models: bool = False
-) -> Dict[str, List[str]]:
+def create_module_to_test_map(reverse_map: Dict[str, List[str]] = None) -> Dict[str, List[str]]:
"""
Extract the tests from the reverse_dependency_map and potentially filters the model tests.
@@ -784,8 +786,8 @@ def create_module_to_test_map(
reverse_map (`Dict[str, List[str]]`, *optional*):
The reverse dependency map as created by `create_reverse_dependency_map`. Will default to the result of
that function if not provided.
- filter_models (`bool`, *optional*, defaults to `False`):
- Whether or not to filter model tests to only include core models if a file impacts a lot of models.
+ filter_pipelines (`bool`, *optional*, defaults to `False`):
+ Whether or not to filter pipeline tests to only include core pipelines if a file impacts a lot of models.
Returns:
`Dict[str, List[str]]`: A dictionary that maps each file to the tests to execute if that file was modified.
@@ -804,21 +806,7 @@ def is_test(fname):
# Build the test map
test_map = {module: [f for f in deps if is_test(f)] for module, deps in reverse_map.items()}
- if not filter_models:
- return test_map
-
- # Now we deal with the filtering if `filter_models` is True.
- num_model_tests = len(list(PATH_TO_TESTS.glob("models/*")))
-
- def has_many_models(tests):
- # We filter to core models when a given file impacts more than half the model tests.
- model_tests = {Path(t).parts[2] for t in tests if t.startswith("tests/models/")}
- return len(model_tests) > num_model_tests // 2
-
- def filter_tests(tests):
- return [t for t in tests if not t.startswith("tests/models/") or Path(t).parts[2] in IMPORTANT_PIPELINES]
-
- return {module: (filter_tests(tests) if has_many_models(tests) else tests) for module, tests in test_map.items()}
+ return test_map
def check_imports_all_exist():
@@ -844,7 +832,39 @@ def _print_list(l) -> str:
return "\n".join([f"- {f}" for f in l])
-def create_json_map(test_files_to_run: List[str], json_output_file: str):
+def update_test_map_with_core_pipelines(json_output_file: str):
+ print(f"\n### ADD CORE PIPELINE TESTS ###\n{_print_list(IMPORTANT_PIPELINES)}")
+ with open(json_output_file, "rb") as fp:
+ test_map = json.load(fp)
+
+ # Add core pipelines as their own test group
+ test_map["core_pipelines"] = " ".join(
+ sorted([str(PATH_TO_TESTS / f"pipelines/{pipe}") for pipe in IMPORTANT_PIPELINES])
+ )
+
+ # If there are no existing pipeline tests save the map
+ if "pipelines" not in test_map:
+ with open(json_output_file, "w", encoding="UTF-8") as fp:
+ json.dump(test_map, fp, ensure_ascii=False)
+
+ pipeline_tests = test_map.pop("pipelines")
+ pipeline_tests = pipeline_tests.split(" ")
+
+ # Remove core pipeline tests from the fetched pipeline tests
+ updated_pipeline_tests = []
+ for pipe in pipeline_tests:
+ if pipe == "tests/pipelines" or Path(pipe).parts[2] in IMPORTANT_PIPELINES:
+ continue
+ updated_pipeline_tests.append(pipe)
+
+ if len(updated_pipeline_tests) > 0:
+ test_map["pipelines"] = " ".join(sorted(updated_pipeline_tests))
+
+ with open(json_output_file, "w", encoding="UTF-8") as fp:
+ json.dump(test_map, fp, ensure_ascii=False)
+
+
+def create_json_map(test_files_to_run: List[str], json_output_file: Optional[str] = None):
"""
Creates a map from a list of tests to run to easily split them by category, when running parallelism of slow tests.
@@ -881,6 +901,7 @@ def create_json_map(test_files_to_run: List[str], json_output_file: str):
# sort the keys & values
keys = sorted(test_map.keys())
test_map = {k: " ".join(sorted(test_map[k])) for k in keys}
+
with open(json_output_file, "w", encoding="UTF-8") as fp:
json.dump(test_map, fp, ensure_ascii=False)
@@ -888,7 +909,6 @@ def create_json_map(test_files_to_run: List[str], json_output_file: str):
def infer_tests_to_run(
output_file: str,
diff_with_last_commit: bool = False,
- filter_models: bool = True,
json_output_file: Optional[str] = None,
):
"""
@@ -929,8 +949,9 @@ def infer_tests_to_run(
# Grab the corresponding test files:
if any(x in modified_files for x in ["setup.py"]):
test_files_to_run = ["tests", "examples"]
+
# in order to trigger pipeline tests even if no code change at all
- elif "tests/utils/tiny_model_summary.json" in modified_files:
+ if "tests/utils/tiny_model_summary.json" in modified_files:
test_files_to_run = ["tests"]
any(f.split(os.path.sep)[0] == "utils" for f in modified_files)
else:
@@ -939,7 +960,7 @@ def infer_tests_to_run(
f for f in modified_files if f.startswith("tests") and f.split(os.path.sep)[-1].startswith("test")
]
# Then we grab the corresponding test files.
- test_map = create_module_to_test_map(reverse_map=reverse_map, filter_models=filter_models)
+ test_map = create_module_to_test_map(reverse_map=reverse_map)
for f in modified_files:
if f in test_map:
test_files_to_run.extend(test_map[f])
@@ -1064,8 +1085,6 @@ def parse_commit_message(commit_message: str) -> Dict[str, bool]:
args = parser.parse_args()
if args.print_dependencies_of is not None:
print_tree_deps_of(args.print_dependencies_of)
- elif args.filter_tests:
- filter_tests(args.output_file, ["pipelines", "repo_utils"])
else:
repo = Repo(PATH_TO_REPO)
commit_message = repo.head.commit.message
@@ -1089,9 +1108,10 @@ def parse_commit_message(commit_message: str) -> Dict[str, bool]:
args.output_file,
diff_with_last_commit=diff_with_last_commit,
json_output_file=args.json_output_file,
- filter_models=not commit_flags["no_filter"],
)
filter_tests(args.output_file, ["repo_utils"])
+ update_test_map_with_core_pipelines(json_output_file=args.json_output_file)
+
except Exception as e:
print(f"\nError when trying to grab the relevant tests: {e}\n\nRunning all tests.")
commit_flags["test_all"] = True
@@ -1105,3 +1125,4 @@ def parse_commit_message(commit_message: str) -> Dict[str, bool]:
test_files_to_run = get_all_tests()
create_json_map(test_files_to_run, args.json_output_file)
+ update_test_map_with_core_pipelines(json_output_file=args.json_output_file)