diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 02f063b6016e..c306e1eb99e7 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -253,6 +253,8 @@
title: PriorTransformer
- local: api/models/controlnet
title: ControlNetModel
+ - local: api/models/controlnet_sd3
+ title: SD3ControlNetModel
title: Models
- isExpanded: false
sections:
@@ -276,6 +278,8 @@
title: Consistency Models
- local: api/pipelines/controlnet
title: ControlNet
+ - local: api/pipelines/controlnet_sd3
+ title: ControlNet with Stable Diffusion 3
- local: api/pipelines/controlnet_sdxl
title: ControlNet with Stable Diffusion XL
- local: api/pipelines/controlnetxs
diff --git a/docs/source/en/api/models/controlnet_sd3.md b/docs/source/en/api/models/controlnet_sd3.md
new file mode 100644
index 000000000000..59db64546fa2
--- /dev/null
+++ b/docs/source/en/api/models/controlnet_sd3.md
@@ -0,0 +1,42 @@
+
+
+# SD3ControlNetModel
+
+SD3ControlNetModel is an implementation of ControlNet for Stable Diffusion 3.
+
+The ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection.
+
+The abstract from the paper is:
+
+*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
+
+## Loading from the original format
+
+By default the [`SD3ControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`].
+
+```py
+from diffusers import StableDiffusion3ControlNetPipeline
+from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel
+
+controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny")
+pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet)
+```
+
+## SD3ControlNetModel
+
+[[autodoc]] SD3ControlNetModel
+
+## SD3ControlNetOutput
+
+[[autodoc]] models.controlnet_sd3.SD3ControlNetOutput
+
diff --git a/docs/source/en/api/pipelines/controlnet_sd3.md b/docs/source/en/api/pipelines/controlnet_sd3.md
new file mode 100644
index 000000000000..31dd21f1dd36
--- /dev/null
+++ b/docs/source/en/api/pipelines/controlnet_sd3.md
@@ -0,0 +1,39 @@
+
+
+# ControlNet with Stable Diffusion 3
+
+StableDiffusion3ControlNetPipeline is an implementation of ControlNet for Stable Diffusion 3.
+
+ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
+
+With a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
+
+The abstract from the paper is:
+
+*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
+
+This code is implemented by [The InstantX Team](https://huggingface.co/InstantX). You can find pre-trained checkpoints for SD3-ControlNet on [The InstantX Team](https://huggingface.co/InstantX) Hub profile.
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## StableDiffusion3ControlNetPipeline
+[[autodoc]] StableDiffusion3ControlNetPipeline
+ - all
+ - __call__
+
+## StableDiffusion3PipelineOutput
+[[autodoc]] pipelines.stable_diffusion_3.pipeline_output.StableDiffusion3PipelineOutput
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index b667b5cea7d0..cad0ca544026 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -91,6 +91,8 @@
"MultiAdapter",
"PixArtTransformer2DModel",
"PriorTransformer",
+ "SD3ControlNetModel",
+ "SD3MultiControlNetModel",
"SD3Transformer2DModel",
"StableCascadeUNet",
"T2IAdapter",
@@ -278,6 +280,7 @@
"StableCascadeCombinedPipeline",
"StableCascadeDecoderPipeline",
"StableCascadePriorPipeline",
+ "StableDiffusion3ControlNetPipeline",
"StableDiffusion3Img2ImgPipeline",
"StableDiffusion3Pipeline",
"StableDiffusionAdapterPipeline",
@@ -501,6 +504,8 @@
MultiAdapter,
PixArtTransformer2DModel,
PriorTransformer,
+ SD3ControlNetModel,
+ SD3MultiControlNetModel,
SD3Transformer2DModel,
T2IAdapter,
T5FilmDecoder,
@@ -666,6 +671,7 @@
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
+ StableDiffusion3ControlNetPipeline,
StableDiffusion3Img2ImgPipeline,
StableDiffusion3Pipeline,
StableDiffusionAdapterPipeline,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index b28fc537d99d..e863c8010a3d 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -33,6 +33,7 @@
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["controlnet"] = ["ControlNetModel"]
+ _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
@@ -74,6 +75,7 @@
VQModel,
)
from .controlnet import ControlNetModel
+ from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py
new file mode 100644
index 000000000000..d32b662b463b
--- /dev/null
+++ b/src/diffusers/models/controlnet_sd3.py
@@ -0,0 +1,418 @@
+# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ..models.attention import JointTransformerBlock
+from ..models.attention_processor import Attention, AttentionProcessor
+from ..models.modeling_utils import ModelMixin
+from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from .controlnet import BaseOutput, zero_module
+from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
+from .transformers.transformer_2d import Transformer2DModelOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class SD3ControlNetOutput(BaseOutput):
+ controlnet_block_samples: Tuple[torch.Tensor]
+
+
+class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: int = 128,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ num_layers: int = 18,
+ attention_head_dim: int = 64,
+ num_attention_heads: int = 18,
+ joint_attention_dim: int = 4096,
+ caption_projection_dim: int = 1152,
+ pooled_projection_dim: int = 2048,
+ out_channels: int = 16,
+ pos_embed_max_size: int = 96,
+ ):
+ super().__init__()
+ default_out_channels = in_channels
+ self.out_channels = out_channels if out_channels is not None else default_out_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ self.pos_embed = PatchEmbed(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=self.inner_dim,
+ pos_embed_max_size=pos_embed_max_size,
+ )
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
+ )
+ self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
+
+ # `attention_head_dim` is doubled to account for the mixing.
+ # It needs to crafted when we get the actual checkpoints.
+ self.transformer_blocks = nn.ModuleList(
+ [
+ JointTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=self.inner_dim,
+ context_pre_only=False,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ # controlnet_blocks
+ self.controlnet_blocks = nn.ModuleList([])
+ for _ in range(len(self.transformer_blocks)):
+ controlnet_block = nn.Linear(self.inner_dim, self.inner_dim)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_blocks.append(controlnet_block)
+ pos_embed_input = PatchEmbed(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=self.inner_dim,
+ pos_embed_type=None,
+ )
+ self.pos_embed_input = zero_module(pos_embed_input)
+
+ self.gradient_checkpointing = False
+
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+ """
+ Sets the attention processor to use [feed forward
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+ Parameters:
+ chunk_size (`int`, *optional*):
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+ over each tensor of dim=`dim`.
+ dim (`int`, *optional*, defaults to `0`):
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+ or dim=1 (sequence length).
+ """
+ if dim not in [0, 1]:
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+ # By default chunk size is 1
+ chunk_size = chunk_size or 1
+
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, chunk_size, dim)
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ @classmethod
+ def from_transformer(cls, transformer, num_layers=None, load_weights_from_transformer=True):
+ config = transformer.config
+ config["num_layers"] = num_layers or config.num_layers
+ controlnet = cls(**config)
+
+ if load_weights_from_transformer:
+ controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=False)
+ controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict(), strict=False)
+ controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict(), strict=False)
+ controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict())
+
+ controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)
+
+ return controlnet
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ controlnet_cond: torch.Tensor,
+ conditioning_scale: float = 1.0,
+ encoder_hidden_states: torch.FloatTensor = None,
+ pooled_projections: torch.FloatTensor = None,
+ timestep: torch.LongTensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ """
+ The [`SD3Transformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ Input `hidden_states`.
+ controlnet_cond (`torch.Tensor`):
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ height, width = hidden_states.shape[-2:]
+
+ hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
+ temb = self.time_text_embed(timestep, pooled_projections)
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ # add
+ hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
+
+ block_res_samples = ()
+
+ for block in self.transformer_blocks:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+ )
+
+ block_res_samples = block_res_samples + (hidden_states,)
+
+ controlnet_block_res_samples = ()
+ for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
+ block_res_sample = controlnet_block(block_res_sample)
+ controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
+
+ # 6. scaling
+ controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (controlnet_block_res_samples,)
+
+ return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
+
+
+class SD3MultiControlNetModel(ModelMixin):
+ r"""
+ `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet
+
+ This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be
+ compatible with `SD3ControlNetModel`.
+
+ Args:
+ controlnets (`List[SD3ControlNetModel]`):
+ Provides additional conditioning to the unet during the denoising process. You must set multiple
+ `SD3ControlNetModel` as a list.
+ """
+
+ def __init__(self, controlnets):
+ super().__init__()
+ self.nets = nn.ModuleList(controlnets)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ controlnet_cond: List[torch.tensor],
+ conditioning_scale: List[float],
+ pooled_projections: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ timestep: torch.LongTensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[SD3ControlNetOutput, Tuple]:
+ for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
+ block_samples = controlnet(
+ hidden_states=hidden_states,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ pooled_projections=pooled_projections,
+ controlnet_cond=image,
+ conditioning_scale=scale,
+ joint_attention_kwargs=joint_attention_kwargs,
+ return_dict=return_dict,
+ )
+
+ # merge samples
+ if i == 0:
+ control_block_samples = block_samples
+ else:
+ control_block_samples = [
+ control_block_sample + block_sample
+ for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
+ ]
+ control_block_samples = (tuple(control_block_samples),)
+
+ return control_block_samples
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index 740b19bb5377..677ed4e28700 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
# limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
@@ -245,6 +245,7 @@ def forward(
encoder_hidden_states: torch.FloatTensor = None,
pooled_projections: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
+ block_controlnet_hidden_states: List = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
@@ -260,6 +261,8 @@ def forward(
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+ A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -293,7 +296,7 @@ def forward(
temb = self.time_text_embed(timestep, pooled_projections)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
- for block in self.transformer_blocks:
+ for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
@@ -319,6 +322,11 @@ def custom_forward(*inputs):
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)
+ # controlnet residual
+ if block_controlnet_hidden_states is not None and block.context_pre_only is False:
+ interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
+ hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
+
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 8b2c8a1b2119..a2041d6ea07b 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -20,6 +20,7 @@
_dummy_objects = {}
_import_structure = {
"controlnet": [],
+ "controlnet_sd3": [],
"controlnet_xs": [],
"deprecated": [],
"latent_diffusion": [],
@@ -142,6 +143,11 @@
"StableDiffusionXLControlNetXSPipeline",
]
)
+ _import_structure["controlnet_sd3"].extend(
+ [
+ "StableDiffusion3ControlNetPipeline",
+ ]
+ )
_import_structure["deepfloyd_if"] = [
"IFImg2ImgPipeline",
"IFImg2ImgSuperResolutionPipeline",
@@ -394,6 +400,9 @@
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
)
+ from .controlnet_sd3 import (
+ StableDiffusion3ControlNetPipeline,
+ )
from .controlnet_xs import (
StableDiffusionControlNetXSPipeline,
StableDiffusionXLControlNetXSPipeline,
diff --git a/src/diffusers/pipelines/controlnet_sd3/__init__.py b/src/diffusers/pipelines/controlnet_sd3/__init__.py
new file mode 100644
index 000000000000..65a3855a0adc
--- /dev/null
+++ b/src/diffusers/pipelines/controlnet_sd3/__init__.py
@@ -0,0 +1,53 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_flax_available,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_stable_diffusion_3_controlnet"] = ["StableDiffusion3ControlNetPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_stable_diffusion_3_controlnet import StableDiffusion3ControlNetPipeline
+
+ try:
+ if not (is_transformers_available() and is_flax_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
new file mode 100644
index 000000000000..188408d4cb2e
--- /dev/null
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -0,0 +1,1062 @@
+# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import (
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...models.autoencoders import AutoencoderKL
+from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
+from ...models.transformers import SD3Transformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusion3ControlNetPipeline
+ >>> from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel
+ >>> from diffusers.utils import load_image
+
+ >>> controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
+
+ >>> pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
+ ... )
+ >>> pipe.to("cuda")
+ >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
+ >>> prompt = "A girl holding a sign that says InstantX"
+ >>> image = pipe(prompt, control_image=control_image, controlnet_conditioning_scale=0.7).images[0]
+ >>> image.save("sd3.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+ r"""
+ Args:
+ transformer ([`SD3Transformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModelWithProjection`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
+ as its dimension.
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ text_encoder_3 ([`T5EncoderModel`]):
+ Frozen text-encoder. Stable Diffusion 3 uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_3 (`T5TokenizerFast`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ controlnet ([`SD3ControlNetModel`] or `List[SD3ControlNetModel]` or [`SD3MultiControlNetModel`]):
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
+ additional conditioning.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
+
+ def __init__(
+ self,
+ transformer: SD3Transformer2DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer_2: CLIPTokenizer,
+ text_encoder_3: T5EncoderModel,
+ tokenizer_3: T5TokenizerFast,
+ controlnet: Union[
+ SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
+ ],
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ text_encoder_3=text_encoder_3,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ tokenizer_3=tokenizer_3,
+ transformer=transformer,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ )
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = (
+ self.transformer.config.sample_size
+ if hasattr(self, "transformer") and self.transformer is not None
+ else 128
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 256,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if self.text_encoder_3 is None:
+ return torch.zeros(
+ (
+ batch_size * num_images_per_prompt,
+ self.tokenizer_max_length,
+ self.transformer.config.joint_attention_dim,
+ ),
+ device=device,
+ dtype=dtype,
+ )
+
+ text_inputs = self.tokenizer_3(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
+
+ dtype = self.text_encoder_3.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ clip_skip: Optional[int] = None,
+ clip_model_index: int = 0,
+ ):
+ device = device or self._execution_device
+
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
+
+ tokenizer = clip_tokenizers[clip_model_index]
+ text_encoder = clip_text_encoders[clip_model_index]
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+ pooled_prompt_embeds = prompt_embeds[0]
+
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds, pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ prompt_3: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ clip_skip: Optional[int] = None,
+ max_sequence_length: int = 256,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ prompt_3 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
+ `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ prompt_3 = prompt_3 or prompt
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
+
+ prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=clip_skip,
+ clip_model_index=0,
+ )
+ prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
+ prompt=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=clip_skip,
+ clip_model_index=1,
+ )
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
+
+ t5_prompt_embed = self._get_t5_prompt_embeds(
+ prompt=prompt_3,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ clip_prompt_embeds = torch.nn.functional.pad(
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
+ )
+
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+ negative_prompt_3 = (
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
+ )
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
+ negative_prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=None,
+ clip_model_index=0,
+ )
+ negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
+ negative_prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=None,
+ clip_model_index=1,
+ )
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
+
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
+ prompt=negative_prompt_3,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
+ negative_clip_prompt_embeds,
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
+ )
+
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
+ negative_pooled_prompt_embeds = torch.cat(
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
+ )
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ prompt_3,
+ height,
+ width,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ negative_prompt_3=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_3 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ return latents
+
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if isinstance(image, torch.Tensor):
+ pass
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ prompt_3: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.0,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ control_image: PipelineImageInput = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ controlnet_pooled_projections: Optional[torch.FloatTensor] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead
+ prompt_3 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
+ will be used instead
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the ControlNet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the ControlNet stops applying.
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+ images must be passed as a list such that each element of the list can be correctly batched for input
+ to a single ControlNet.
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+ the corresponding scale as a list.
+ controlnet_pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
+ Embeddings projected from the embeddings of controlnet input conditions.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used instead
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
+ `text_encoder_3`. If not defined, `negative_prompt` is used instead
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(self.controlnet.nets) if isinstance(self.controlnet, SD3MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ prompt_3,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ dtype = self.transformer.dtype
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_3=prompt_3,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ clip_skip=self.clip_skip,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+ # 3. Prepare control image
+ if isinstance(self.controlnet, SD3ControlNetModel):
+ control_image = self.prepare_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=False,
+ )
+ height, width = control_image.shape[-2:]
+
+ control_image = self.vae.encode(control_image).latent_dist.sample()
+ control_image = control_image * self.vae.config.scaling_factor
+
+ elif isinstance(self.controlnet, SD3MultiControlNetModel):
+ control_images = []
+
+ for control_image_ in control_image:
+ control_image_ = self.prepare_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=False,
+ )
+
+ control_image_ = self.vae.encode(control_image_).latent_dist.sample()
+ control_image_ = control_image_ * self.vae.config.scaling_factor
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+ else:
+ assert False
+
+ if controlnet_pooled_projections is None:
+ controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
+ else:
+ controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Create tensor stating which controlnets to keep
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps)
+
+ # 7. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ # controlnet(s) inference
+ control_block_samples = self.controlnet(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ pooled_projections=controlnet_pooled_projections,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ controlnet_cond=control_image,
+ conditioning_scale=cond_scale,
+ return_dict=False,
+ )[0]
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ pooled_projections=pooled_prompt_embeds,
+ block_controlnet_hidden_states=control_block_samples,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusion3PipelineOutput(images=image)
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 64c642a1d130..007570d81dae 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -242,6 +242,36 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class SD3ControlNetModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class SD3MultiControlNetModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class SD3Transformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 9b3512a82519..b5d17f3fce65 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -902,6 +902,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class StableDiffusion3ControlNetPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class StableDiffusion3Img2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/tests/pipelines/controlnet_sd3/__init__.py b/tests/pipelines/controlnet_sd3/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
new file mode 100644
index 000000000000..824c1de1b9a5
--- /dev/null
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -0,0 +1,348 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc and The InstantX Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ SD3Transformer2DModel,
+ StableDiffusion3ControlNetPipeline,
+)
+from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel
+from diffusers.utils import load_image
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = StableDiffusion3ControlNetPipeline
+ params = frozenset(
+ [
+ "prompt",
+ "height",
+ "width",
+ "guidance_scale",
+ "negative_prompt",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+ )
+ batch_params = frozenset(["prompt", "negative_prompt"])
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = SD3Transformer2DModel(
+ sample_size=32,
+ patch_size=1,
+ in_channels=8,
+ num_layers=4,
+ attention_head_dim=8,
+ num_attention_heads=4,
+ joint_attention_dim=32,
+ caption_projection_dim=32,
+ pooled_projection_dim=64,
+ out_channels=8,
+ )
+
+ torch.manual_seed(0)
+ controlnet = SD3ControlNetModel(
+ sample_size=32,
+ patch_size=1,
+ in_channels=8,
+ num_layers=1,
+ attention_head_dim=8,
+ num_attention_heads=4,
+ joint_attention_dim=32,
+ caption_projection_dim=32,
+ pooled_projection_dim=64,
+ out_channels=8,
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=8,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "text_encoder_3": text_encoder_3,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "tokenizer_3": tokenizer_3,
+ "transformer": transformer,
+ "vae": vae,
+ "controlnet": controlnet,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ control_image = randn_tensor(
+ (1, 3, 32, 32),
+ generator=generator,
+ device=torch.device(device),
+ dtype=torch.float16,
+ )
+
+ controlnet_conditioning_scale = 0.5
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "output_type": "np",
+ "control_image": control_image,
+ "controlnet_conditioning_scale": controlnet_conditioning_scale,
+ }
+
+ return inputs
+
+ def test_controlnet_sd3(self):
+ components = self.get_dummy_components()
+ sd_pipe = StableDiffusion3ControlNetPipeline(**components)
+ sd_pipe = sd_pipe.to(torch_device, dtype=torch.float16)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output = sd_pipe(**inputs)
+ image = output.images
+
+ image_slice = image[0, -3:, -3:, -1]
+ assert image.shape == (1, 32, 32, 3)
+
+ expected_slice = np.array(
+ [0.5761719, 0.71777344, 0.59228516, 0.578125, 0.6020508, 0.39453125, 0.46728516, 0.51708984, 0.58984375]
+ )
+
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+
+
+@slow
+@require_torch_gpu
+class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
+ pipeline_class = StableDiffusion3ControlNetPipeline
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_canny(self):
+ controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
+ pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
+ )
+ pipe.enable_model_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.Generator(device="cpu").manual_seed(0)
+ prompt = "Anime style illustration of a girl wearing a suit. A moon in sky. In the background we see a big rain approaching. text 'InstantX' on image"
+ n_prompt = "NSFW, nude, naked, porn, ugly"
+ control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
+
+ output = pipe(
+ prompt,
+ negative_prompt=n_prompt,
+ control_image=control_image,
+ controlnet_conditioning_scale=0.5,
+ guidance_scale=5.0,
+ num_inference_steps=2,
+ output_type="np",
+ generator=generator,
+ )
+ image = output.images[0]
+
+ assert image.shape == (1024, 1024, 3)
+
+ original_image = image[-3:, -3:, -1].flatten()
+
+ expected_image = np.array(
+ [0.20947266, 0.1574707, 0.19897461, 0.15063477, 0.1418457, 0.17285156, 0.14160156, 0.13989258, 0.30810547]
+ )
+
+ assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
+
+ def test_pose(self):
+ controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Pose", torch_dtype=torch.float16)
+ pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
+ )
+ pipe.enable_model_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.Generator(device="cpu").manual_seed(0)
+ prompt = 'Anime style illustration of a girl wearing a suit. A moon in sky. In the background we see a big rain approaching. text "InstantX" on image'
+ n_prompt = "NSFW, nude, naked, porn, ugly"
+ control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Pose/resolve/main/pose.jpg")
+
+ output = pipe(
+ prompt,
+ negative_prompt=n_prompt,
+ control_image=control_image,
+ controlnet_conditioning_scale=0.5,
+ guidance_scale=5.0,
+ num_inference_steps=2,
+ output_type="np",
+ generator=generator,
+ )
+ image = output.images[0]
+
+ assert image.shape == (1024, 1024, 3)
+
+ original_image = image[-3:, -3:, -1].flatten()
+
+ expected_image = np.array(
+ [0.8671875, 0.86621094, 0.91015625, 0.8491211, 0.87890625, 0.9140625, 0.8300781, 0.8334961, 0.8623047]
+ )
+
+ assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
+
+ def test_tile(self):
+ controlnet = SD3ControlNetModel.from_pretrained("InstantX//SD3-Controlnet-Tile", torch_dtype=torch.float16)
+ pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
+ )
+ pipe.enable_model_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.Generator(device="cpu").manual_seed(0)
+ prompt = 'Anime style illustration of a girl wearing a suit. A moon in sky. In the background we see a big rain approaching. text "InstantX" on image'
+ n_prompt = "NSFW, nude, naked, porn, ugly"
+ control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Tile/resolve/main/tile.jpg")
+
+ output = pipe(
+ prompt,
+ negative_prompt=n_prompt,
+ control_image=control_image,
+ controlnet_conditioning_scale=0.5,
+ guidance_scale=5.0,
+ num_inference_steps=2,
+ output_type="np",
+ generator=generator,
+ )
+ image = output.images[0]
+
+ assert image.shape == (1024, 1024, 3)
+
+ original_image = image[-3:, -3:, -1].flatten()
+
+ expected_image = np.array(
+ [0.6982422, 0.7011719, 0.65771484, 0.6904297, 0.7416992, 0.6904297, 0.6977539, 0.7080078, 0.6386719]
+ )
+
+ assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
+
+ def test_multi_controlnet(self):
+ controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
+ controlnet = SD3MultiControlNetModel([controlnet, controlnet])
+
+ pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
+ )
+ pipe.enable_model_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.Generator(device="cpu").manual_seed(0)
+ prompt = "Anime style illustration of a girl wearing a suit. A moon in sky. In the background we see a big rain approaching. text 'InstantX' on image"
+ n_prompt = "NSFW, nude, naked, porn, ugly"
+ control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
+
+ output = pipe(
+ prompt,
+ negative_prompt=n_prompt,
+ control_image=[control_image, control_image],
+ controlnet_conditioning_scale=[0.25, 0.25],
+ guidance_scale=5.0,
+ num_inference_steps=2,
+ output_type="np",
+ generator=generator,
+ )
+ image = output.images[0]
+
+ assert image.shape == (1024, 1024, 3)
+
+ original_image = image[-3:, -3:, -1].flatten()
+ expected_image = np.array(
+ [0.7451172, 0.7416992, 0.7158203, 0.7792969, 0.7607422, 0.7089844, 0.6855469, 0.71777344, 0.7314453]
+ )
+
+ assert np.abs(original_image.flatten() - expected_image).max() < 1e-2