diff --git a/scripts/convert_sd3_controlnet_to_diffusers.py b/scripts/convert_sd3_controlnet_to_diffusers.py
new file mode 100644
index 000000000000..171f40a7aa06
--- /dev/null
+++ b/scripts/convert_sd3_controlnet_to_diffusers.py
@@ -0,0 +1,185 @@
+"""
+A script to convert Stable Diffusion 3.5 ControlNet checkpoints to the Diffusers format.
+
+Example:
+    Convert an SD3.5 ControlNet checkpoint to the Diffusers format using a local file:
+    ```bash
+    python scripts/convert_sd3_controlnet_to_diffusers.py \
+        --checkpoint_path "path/to/local/sd3.5_large_controlnet_canny.safetensors" \
+        --output_path "output/sd35-controlnet-canny" \
+        --dtype "fp16"  # optional, defaults to fp32
+    ```
+
+    Or download and convert from a HuggingFace repository:
+    ```bash
+    python scripts/convert_sd3_controlnet_to_diffusers.py \
+        --original_state_dict_repo_id "stabilityai/stable-diffusion-3.5-controlnets" \
+        --filename "sd3.5_large_controlnet_canny.safetensors" \
+        --output_path "/raid/yiyi/sd35-controlnet-canny-diffusers" \
+        --dtype "fp32"  # optional, defaults to fp32
+    ```
+
+Note:
+    The script supports the following ControlNet types from SD3.5:
+    - Canny edge detection
+    - Depth estimation
+    - Blur detection
+
+    The checkpoint files can be downloaded from:
+    https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets
+"""
+
+import argparse
+
+import safetensors.torch
+import torch
+from huggingface_hub import hf_hub_download
+
+from diffusers import SD3ControlNetModel
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to local checkpoint file")
+parser.add_argument(
+    "--original_state_dict_repo_id", type=str, default=None, help="HuggingFace repo ID containing the checkpoint"
+)
+parser.add_argument("--filename", type=str, default=None, help="Filename of the checkpoint in the HF repo")
+parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
+parser.add_argument(
+    "--dtype", type=str, default="fp32", help="Data type for the converted model (fp16, bf16, or fp32)"
+)
+
+args = parser.parse_args()
+
+
+def load_original_checkpoint(args):
+    if args.original_state_dict_repo_id is not None:
+        if args.filename is None:
+            raise ValueError("When using `original_state_dict_repo_id`, `filename` must also be specified")
+        print(f"Downloading checkpoint from {args.original_state_dict_repo_id}/{args.filename}")
+        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
+    elif args.checkpoint_path is not None:
+        print(f"Loading checkpoint from local path: {args.checkpoint_path}")
+        ckpt_path = args.checkpoint_path
+    else:
+        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
+
+    original_state_dict = safetensors.torch.load_file(ckpt_path)
+    return original_state_dict
+
+
+def convert_sd3_controlnet_checkpoint_to_diffusers(original_state_dict):
+    converted_state_dict = {}
+
+    # Direct mappings for controlnet blocks
+    for i in range(19):  # 19 controlnet blocks
+        converted_state_dict[f"controlnet_blocks.{i}.weight"] = original_state_dict[f"controlnet_blocks.{i}.weight"]
+        converted_state_dict[f"controlnet_blocks.{i}.bias"] = original_state_dict[f"controlnet_blocks.{i}.bias"]
+
+    # Positional embeddings
+    converted_state_dict["pos_embed_input.proj.weight"] = original_state_dict["pos_embed_input.proj.weight"]
+    converted_state_dict["pos_embed_input.proj.bias"] = original_state_dict["pos_embed_input.proj.bias"]
+
+    # Time and text embeddings
+    time_text_mappings = {
+        "time_text_embed.timestep_embedder.linear_1.weight": "time_text_embed.timestep_embedder.linear_1.weight",
+        "time_text_embed.timestep_embedder.linear_1.bias": "time_text_embed.timestep_embedder.linear_1.bias",
+        "time_text_embed.timestep_embedder.linear_2.weight": "time_text_embed.timestep_embedder.linear_2.weight",
+        "time_text_embed.timestep_embedder.linear_2.bias": "time_text_embed.timestep_embedder.linear_2.bias",
+        "time_text_embed.text_embedder.linear_1.weight": "time_text_embed.text_embedder.linear_1.weight",
+        "time_text_embed.text_embedder.linear_1.bias": "time_text_embed.text_embedder.linear_1.bias",
+        "time_text_embed.text_embedder.linear_2.weight": "time_text_embed.text_embedder.linear_2.weight",
+        "time_text_embed.text_embedder.linear_2.bias": "time_text_embed.text_embedder.linear_2.bias",
+    }
+
+    for new_key, old_key in time_text_mappings.items():
+        if old_key in original_state_dict:
+            converted_state_dict[new_key] = original_state_dict[old_key]
+
+    # Transformer blocks
+    for i in range(19):
+        # Split QKV into separate Q, K, V
+        qkv_weight = original_state_dict[f"transformer_blocks.{i}.attn.qkv.weight"]
+        qkv_bias = original_state_dict[f"transformer_blocks.{i}.attn.qkv.bias"]
+        q, k, v = torch.chunk(qkv_weight, 3, dim=0)
+        q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)
+
+        block_mappings = {
+            f"transformer_blocks.{i}.attn.to_q.weight": q,
+            f"transformer_blocks.{i}.attn.to_q.bias": q_bias,
+            f"transformer_blocks.{i}.attn.to_k.weight": k,
+            f"transformer_blocks.{i}.attn.to_k.bias": k_bias,
+            f"transformer_blocks.{i}.attn.to_v.weight": v,
+            f"transformer_blocks.{i}.attn.to_v.bias": v_bias,
+            # Output projections
+            f"transformer_blocks.{i}.attn.to_out.0.weight": original_state_dict[
+                f"transformer_blocks.{i}.attn.proj.weight"
+            ],
+            f"transformer_blocks.{i}.attn.to_out.0.bias": original_state_dict[
+                f"transformer_blocks.{i}.attn.proj.bias"
+            ],
+            # Feed forward
+            f"transformer_blocks.{i}.ff.net.0.proj.weight": original_state_dict[
+                f"transformer_blocks.{i}.mlp.fc1.weight"
+            ],
+            f"transformer_blocks.{i}.ff.net.0.proj.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc1.bias"],
+            f"transformer_blocks.{i}.ff.net.2.weight": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.weight"],
+            f"transformer_blocks.{i}.ff.net.2.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.bias"],
+            # Norms
+            f"transformer_blocks.{i}.norm1.linear.weight": original_state_dict[
+                f"transformer_blocks.{i}.adaLN_modulation.1.weight"
+            ],
+            f"transformer_blocks.{i}.norm1.linear.bias": original_state_dict[
+                f"transformer_blocks.{i}.adaLN_modulation.1.bias"
+            ],
+        }
+        converted_state_dict.update(block_mappings)
+
+    return converted_state_dict
+
+
+def main(args):
+    original_ckpt = load_original_checkpoint(args)
+    original_dtype = next(iter(original_ckpt.values())).dtype
+
+    # Map the requested dtype string to a torch dtype (the argument defaults to fp32)
+    if args.dtype == "fp16":
+        dtype = torch.float16
+    elif args.dtype == "bf16":
+        dtype = torch.bfloat16
+    elif args.dtype == "fp32":
+        dtype = torch.float32
+    else:
+        raise ValueError(f"Unsupported dtype: {args.dtype}. Must be one of: fp16, bf16, fp32")
+
+    if dtype != original_dtype:
+        print(
+            f"Converting checkpoint from {original_dtype} to {dtype}. This can lead to unexpected results; proceed with caution."
+        )
+
+    converted_controlnet_state_dict = convert_sd3_controlnet_checkpoint_to_diffusers(original_ckpt)
+
+    controlnet = SD3ControlNetModel(
+        patch_size=2,
+        in_channels=16,
+        num_layers=19,
+        attention_head_dim=64,
+        num_attention_heads=38,
+        joint_attention_dim=None,
+        caption_projection_dim=2048,
+        pooled_projection_dim=2048,
+        out_channels=16,
+        pos_embed_max_size=None,
+        pos_embed_type=None,
+        use_pos_embed=False,
+        force_zeros_for_pooled_projection=False,
+    )
+
+    controlnet.load_state_dict(converted_controlnet_state_dict, strict=True)
+
+    print(f"Saving SD3 ControlNet in Diffusers format in {args.output_path}.")
+    controlnet.to(dtype).save_pretrained(args.output_path)
+
+
+if __name__ == "__main__":
+    main(args)
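For orientation, here is a minimal sketch of how a checkpoint converted by the script above could be used for inference. It relies only on the public `SD3ControlNetModel` and `StableDiffusion3ControlNetPipeline` APIs; the local output directory, the `stabilityai/stable-diffusion-3.5-large` base repo, the control image path, and the sampling settings are illustrative assumptions, not part of this diff.

```python
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image

# Load the converted ControlNet from the example output_path used above.
controlnet = SD3ControlNetModel.from_pretrained("output/sd35-controlnet-canny", torch_dtype=torch.float16)

# Base model repo id is an assumption; any SD3.5 checkpoint compatible with the controlnet should work.
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

control_image = load_image("path/to/canny_edges.png")  # placeholder: a pre-computed Canny edge map
image = pipe(
    prompt="a photo of a cat",
    control_image=control_image,
    controlnet_conditioning_scale=1.0,
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("out.png")
```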
diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py
index 118e8630ec8e..2a5fcf35498e 100644
--- a/src/diffusers/models/controlnets/controlnet_sd3.py
+++ b/src/diffusers/models/controlnets/controlnet_sd3.py
@@ -27,6 +27,7 @@
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
+from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
 from .controlnet import BaseOutput, zero_module
 
 
@@ -58,40 +59,60 @@ def __init__(
         extra_conditioning_channels: int = 0,
         dual_attention_layers: Tuple[int, ...] = (),
         qk_norm: Optional[str] = None,
+        pos_embed_type: Optional[str] = "sincos",
+        use_pos_embed: bool = True,
+        force_zeros_for_pooled_projection: bool = True,
     ):
         super().__init__()
         default_out_channels = in_channels
         self.out_channels = out_channels if out_channels is not None else default_out_channels
         self.inner_dim = num_attention_heads * attention_head_dim
 
-        self.pos_embed = PatchEmbed(
-            height=sample_size,
-            width=sample_size,
-            patch_size=patch_size,
-            in_channels=in_channels,
-            embed_dim=self.inner_dim,
-            pos_embed_max_size=pos_embed_max_size,
-        )
+        if use_pos_embed:
+            self.pos_embed = PatchEmbed(
+                height=sample_size,
+                width=sample_size,
+                patch_size=patch_size,
+                in_channels=in_channels,
+                embed_dim=self.inner_dim,
+                pos_embed_max_size=pos_embed_max_size,
+                pos_embed_type=pos_embed_type,
+            )
+        else:
+            self.pos_embed = None
 
         self.time_text_embed = CombinedTimestepTextProjEmbeddings(
             embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )
-        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
-
-        # `attention_head_dim` is doubled to account for the mixing.
-        # It needs to crafted when we get the actual checkpoints.
-        self.transformer_blocks = nn.ModuleList(
-            [
-                JointTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
-                    context_pre_only=False,
-                    qk_norm=qk_norm,
-                    use_dual_attention=True if i in dual_attention_layers else False,
-                )
-                for i in range(num_layers)
-            ]
-        )
+        if joint_attention_dim is not None:
+            self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
+
+            # `attention_head_dim` is doubled to account for the mixing.
+            # It needs to be crafted when we get the actual checkpoints.
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    JointTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                        context_pre_only=False,
+                        qk_norm=qk_norm,
+                        use_dual_attention=True if i in dual_attention_layers else False,
+                    )
+                    for i in range(num_layers)
+                ]
+            )
+        else:
+            self.context_embedder = None
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    SD3SingleTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                    )
+                    for _ in range(num_layers)
+                ]
+            )
 
         # controlnet_blocks
         self.controlnet_blocks = nn.ModuleList([])
@@ -318,9 +339,27 @@ def forward(
                     "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                 )
 
-        hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+        if self.pos_embed is not None and hidden_states.ndim != 4:
+            raise ValueError("hidden_states must be 4D when pos_embed is used")
+
+        # The SD3.5 8b controlnet does not have a `pos_embed`;
+        # it uses the `pos_embed` from the transformer to process the input before passing it to the controlnet.
+        elif self.pos_embed is None and hidden_states.ndim != 3:
+            raise ValueError("hidden_states must be 3D when pos_embed is not used")
+
+        if self.context_embedder is not None and encoder_hidden_states is None:
+            raise ValueError("encoder_hidden_states must be provided when context_embedder is used")
+        # The SD3.5 8b controlnet does not have a `context_embedder`, so it does not use `encoder_hidden_states`.
+        elif self.context_embedder is None and encoder_hidden_states is not None:
+            raise ValueError("encoder_hidden_states should not be provided when context_embedder is not used")
+
+        if self.pos_embed is not None:
+            hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+
         temb = self.time_text_embed(timestep, pooled_projections)
-        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        if self.context_embedder is not None:
+            encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
         # add
         hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
 
@@ -349,9 +388,13 @@ def custom_forward(*inputs):
                 )
 
             else:
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
-                )
+                if self.context_embedder is not None:
+                    encoder_hidden_states, hidden_states = block(
+                        hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    )
+                else:
+                    # The SD3.5 8b controlnet uses a single transformer block, which does not take `encoder_hidden_states`.
+                    hidden_states = block(hidden_states, temb)
 
             block_res_samples = block_res_samples + (hidden_states,)
 
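As a rough sanity check of the two input contracts enforced above, the sketch below instantiates a tiny, randomly initialized `SD3ControlNetModel` in the SD3.5-8b-style configuration (no internal `pos_embed` or `context_embedder`), so it must receive pre-patch-embedded 3D tokens and no `encoder_hidden_states`. Every config value and tensor shape here is made up for illustration and does not correspond to any released checkpoint.

```python
import torch
from diffusers import SD3ControlNetModel

# Tiny 8b-style config: joint_attention_dim=None selects SD3SingleTransformerBlock,
# use_pos_embed=False means the caller must pass already patch-embedded 3D tokens.
controlnet = SD3ControlNetModel(
    sample_size=8,
    patch_size=2,
    in_channels=4,
    num_layers=1,
    attention_head_dim=8,
    num_attention_heads=2,  # inner_dim = 2 * 8 = 16
    joint_attention_dim=None,
    caption_projection_dim=16,
    pooled_projection_dim=16,
    out_channels=4,
    pos_embed_max_size=None,
    pos_embed_type=None,
    use_pos_embed=False,
    force_zeros_for_pooled_projection=False,
)

tokens = torch.randn(1, 16, 16)         # 3D: (batch, num_patches, inner_dim), pre-embedded latents
cond_latents = torch.randn(1, 4, 8, 8)  # 4D control latents, embedded by the controlnet's pos_embed_input
out = controlnet(
    hidden_states=tokens,
    controlnet_cond=cond_latents,
    encoder_hidden_states=None,         # must stay None: there is no context_embedder
    pooled_projections=torch.randn(1, 16),
    timestep=torch.tensor([999]),
    return_dict=False,
)
print([sample.shape for sample in out[0]])  # one residual per transformer block
```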
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index 7777d7c42d94..a1ce9a2412c5 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -18,14 +18,21 @@
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...models.attention import JointTransformerBlock
-from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
+from ...models.attention import FeedForward, JointTransformerBlock
+from ...models.attention_processor import (
+    Attention,
+    AttentionProcessor,
+    FusedJointAttnProcessor2_0,
+    JointAttnProcessor2_0,
+)
 from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous
+from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 
@@ -33,6 +40,72 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+@maybe_allow_in_graph
+class SD3SingleTransformerBlock(nn.Module):
+    r"""
+    A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
+
+    Reference: https://arxiv.org/abs/2403.03206
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+    ):
+        super().__init__()
+
+        self.norm1 = AdaLayerNormZero(dim)
+
+        if hasattr(F, "scaled_dot_product_attention"):
+            processor = JointAttnProcessor2_0()
+        else:
+            raise ValueError(
+                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
+            )
+
+        self.attn = Attention(
+            query_dim=dim,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=dim,
+            bias=True,
+            processor=processor,
+            eps=1e-6,
+        )
+
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+        # Attention.
+        attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=None,
+        )
+
+        # Process attention outputs for the `hidden_states`.
+        attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = hidden_states + attn_output
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        ff_output = self.ff(norm_hidden_states)
+        ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = hidden_states + ff_output
+
+        return hidden_states
+
+
 class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     """
     The Transformer model introduced in Stable Diffusion 3.
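A quick shape check for the new `SD3SingleTransformerBlock` added above; the dimensions are arbitrary illustrative values, chosen so that `dim = num_attention_heads * attention_head_dim`, mirroring how the controlnet constructs the block.

```python
import torch
from diffusers.models.transformers.transformer_sd3 import SD3SingleTransformerBlock

# dim matches num_attention_heads * attention_head_dim, as in SD3ControlNetModel's inner_dim.
block = SD3SingleTransformerBlock(dim=64, num_attention_heads=4, attention_head_dim=16)

hidden_states = torch.randn(2, 128, 64)  # (batch, tokens, dim): already patch-embedded latents
temb = torch.randn(2, 64)                # timestep/pooled embedding with the same dim

out = block(hidden_states, temb)
print(out.shape)  # torch.Size([2, 128, 64])
```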
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index a589821c1f98..b92dafffc715 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -858,6 +858,12 @@ def __call__(
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
 
+        controlnet_config = (
+            self.controlnet.config
+            if isinstance(self.controlnet, SD3ControlNetModel)
+            else self.controlnet.nets[0].config
+        )
+
         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
@@ -932,6 +938,11 @@ def __call__(
             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
 
         # 3. Prepare control image
+        if controlnet_config.force_zeros_for_pooled_projection:
+            # InstantX SD3 controlnet does not apply the shift factor
+            vae_shift_factor = 0
+        else:
+            vae_shift_factor = self.vae.config.shift_factor
         if isinstance(self.controlnet, SD3ControlNetModel):
             control_image = self.prepare_image(
                 image=control_image,
@@ -947,8 +958,7 @@ def __call__(
             height, width = control_image.shape[-2:]
 
             control_image = self.vae.encode(control_image).latent_dist.sample()
-            control_image = control_image * self.vae.config.scaling_factor
-
+            control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor
         elif isinstance(self.controlnet, SD3MultiControlNetModel):
             control_images = []
 
@@ -966,7 +976,7 @@ def __call__(
                 )
 
                 control_image_ = self.vae.encode(control_image_).latent_dist.sample()
-                control_image_ = control_image_ * self.vae.config.scaling_factor
+                control_image_ = (control_image_ - vae_shift_factor) * self.vae.config.scaling_factor
 
                 control_images.append(control_image_)
 
@@ -974,11 +984,6 @@ def __call__(
         else:
             assert False
 
-        if controlnet_pooled_projections is None:
-            controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
-        else:
-            controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
-
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -1006,6 +1011,18 @@ def __call__(
                 ]
                 controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps)
 
+        if controlnet_config.force_zeros_for_pooled_projection:
+            # InstantX SD3 controlnet uses zero pooled projections
+            controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
+        else:
+            controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
+
+        if controlnet_config.joint_attention_dim is not None:
+            controlnet_encoder_hidden_states = prompt_embeds
+        else:
+            # The official SD3.5 8b controlnet does not use `encoder_hidden_states`
+            controlnet_encoder_hidden_states = None
+
         # 7. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -1025,11 +1042,17 @@ def __call__(
                         controlnet_cond_scale = controlnet_cond_scale[0]
                     cond_scale = controlnet_cond_scale * controlnet_keep[i]
 
+                if controlnet_config.use_pos_embed is False:
+                    # SD3.5 (official) 8b controlnet expects patch-embedded input
+                    controlnet_model_input = self.transformer.pos_embed(latent_model_input)
+                else:
+                    controlnet_model_input = latent_model_input
+
                 # controlnet(s) inference
                 control_block_samples = self.controlnet(
-                    hidden_states=latent_model_input,
+                    hidden_states=controlnet_model_input,
                     timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states=controlnet_encoder_hidden_states,
                     pooled_projections=controlnet_pooled_projections,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     controlnet_cond=control_image,