From c1e6a32ae46594c6ba8cb1d4690f70755389aacb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=A5=87=E5=8B=8B?= <46553287+wangqixun@users.noreply.github.com> Date: Sun, 25 Aug 2024 18:24:21 +0800 Subject: [PATCH 1/3] [Flux] Support Union ControlNet (#9175) * refactor --------- Co-authored-by: haofanwang --- docs/source/en/_toctree.yml | 4 + docs/source/en/api/models/controlnet_flux.md | 45 ++++++ .../en/api/pipelines/controlnet_flux.md | 48 ++++++ src/diffusers/__init__.py | 1 + src/diffusers/models/__init__.py | 4 +- src/diffusers/models/controlnet_flux.py | 137 +++++++++++++++++- .../flux/pipeline_flux_controlnet.py | 75 +++++++++- src/diffusers/utils/dummy_pt_objects.py | 15 ++ 8 files changed, 320 insertions(+), 9 deletions(-) create mode 100644 docs/source/en/api/models/controlnet_flux.md create mode 100644 docs/source/en/api/pipelines/controlnet_flux.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 31eb5e44a76e..445b538dab9e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -226,6 +226,8 @@ - sections: - local: api/models/controlnet title: ControlNetModel + - local: api/models/controlnet_flux + title: FluxControlNetModel - local: api/models/controlnet_hunyuandit title: HunyuanDiT2DControlNetModel - local: api/models/controlnet_sd3 @@ -320,6 +322,8 @@ title: Consistency Models - local: api/pipelines/controlnet title: ControlNet + - local: api/pipelines/controlnet_flux + title: ControlNet with Flux.1 - local: api/pipelines/controlnet_hunyuandit title: ControlNet with Hunyuan-DiT - local: api/pipelines/controlnet_sd3 diff --git a/docs/source/en/api/models/controlnet_flux.md b/docs/source/en/api/models/controlnet_flux.md new file mode 100644 index 000000000000..422d066d95ff --- /dev/null +++ b/docs/source/en/api/models/controlnet_flux.md @@ -0,0 +1,45 @@ + + +# FluxControlNetModel + +FluxControlNetModel is an implementation of ControlNet for Flux.1. + +The ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection. + +The abstract from the paper is: + +*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.* + +## Loading from the original format + +By default the [`FluxControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`]. + +```py +from diffusers import FluxControlNetPipeline +from diffusers.models import FluxControlNetModel, FluxMultiControlNetModel + +controlnet = FluxControlNetModel.from_pretrained("InstantX/FLUX.1-dev-Controlnet-Canny") +pipe = FluxControlNetPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", controlnet=controlnet) + +controlnet = FluxControlNetModel.from_pretrained("InstantX/FLUX.1-dev-Controlnet-Canny") +controlnet = FluxMultiControlNetModel([controlnet]) +pipe = FluxControlNetPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", controlnet=controlnet) +``` + +## FluxControlNetModel + +[[autodoc]] FluxControlNetModel + +## FluxControlNetOutput + +[[autodoc]] models.controlnet_flux.FluxControlNetOutput \ No newline at end of file diff --git a/docs/source/en/api/pipelines/controlnet_flux.md b/docs/source/en/api/pipelines/controlnet_flux.md new file mode 100644 index 000000000000..f63885b4d42c --- /dev/null +++ b/docs/source/en/api/pipelines/controlnet_flux.md @@ -0,0 +1,48 @@ + + +# ControlNet with Flux.1 + +FluxControlNetPipeline is an implementation of ControlNet for Flux.1. + +ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala. + +With a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process. + +The abstract from the paper is: + +*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.* + +This controlnet code is implemented by [The InstantX Team](https://huggingface.co/InstantX). You can find pre-trained checkpoints for Flux-ControlNet in the table below: + + +| ControlNet type | Developer | Link | +| -------- | ---------- | ---- | +| Canny | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny) | +| Depth | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Depth) | +| Union | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union) | + + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. + + + +## FluxControlNetPipeline +[[autodoc]] FluxControlNetPipeline + - all + - __call__ + + +## FluxPipelineOutput +[[autodoc]] pipelines.flux.pipeline_output.FluxPipelineOutput \ No newline at end of file diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 650542c124d5..4589edb7d6b3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -554,6 +554,7 @@ ControlNetXSAdapter, DiTTransformer2DModel, FluxControlNetModel, + FluxMultiControlNetModel, FluxTransformer2DModel, HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 4230c1a4887b..f0dd7248c117 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -35,7 +35,7 @@ _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.vq_model"] = ["VQModel"] _import_structure["controlnet"] = ["ControlNetModel"] - _import_structure["controlnet_flux"] = ["FluxControlNetModel"] + _import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] _import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"] _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"] _import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"] @@ -88,7 +88,7 @@ VQModel, ) from .controlnet import ControlNetModel - from .controlnet_flux import FluxControlNetModel + from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel from .controlnet_sparsectrl import SparseControlNetModel diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py index b29930f81ea2..036e5654a98e 100644 --- a/src/diffusers/models/controlnet_flux.py +++ b/src/diffusers/models/controlnet_flux.py @@ -54,6 +54,7 @@ def __init__( pooled_projection_dim: int = 768, guidance_embeds: bool = False, axes_dims_rope: List[int] = [16, 56, 56], + num_mode: int = None, ): super().__init__() self.out_channels = in_channels @@ -101,6 +102,10 @@ def __init__( for _ in range(len(self.single_transformer_blocks)): self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) + self.union = num_mode is not None + if self.union: + self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim) + self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim)) self.gradient_checkpointing = False @@ -173,8 +178,8 @@ def _set_gradient_checkpointing(self, module, value=False): def from_transformer( cls, transformer, - num_layers=4, - num_single_layers=10, + num_layers: int = 4, + num_single_layers: int = 10, attention_head_dim: int = 128, num_attention_heads: int = 24, load_weights_from_transformer=True, @@ -205,6 +210,7 @@ def forward( self, hidden_states: torch.Tensor, controlnet_cond: torch.Tensor, + controlnet_mode: torch.Tensor = None, conditioning_scale: float = 1.0, encoder_hidden_states: torch.Tensor = None, pooled_projections: torch.Tensor = None, @@ -221,6 +227,12 @@ def forward( Args: hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input `hidden_states`. + controlnet_cond (`torch.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + controlnet_mode (`torch.Tensor`): + The mode tensor of shape `(batch_size, 1)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected @@ -272,6 +284,15 @@ def forward( ) encoder_hidden_states = self.context_embedder(encoder_hidden_states) + if self.union: + # union mode + if controlnet_mode is None: + raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union") + # union mode emb + controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode) + encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1) + txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0) + if txt_ids.ndim == 3: logger.warning( "Passing `txt_ids` 3d torch.Tensor is deprecated." @@ -367,7 +388,6 @@ def custom_forward(*inputs): controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples] controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples] - # controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples controlnet_single_block_samples = ( None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples @@ -384,3 +404,114 @@ def custom_forward(*inputs): controlnet_block_samples=controlnet_block_samples, controlnet_single_block_samples=controlnet_single_block_samples, ) + + +class FluxMultiControlNetModel(ModelMixin): + r""" + `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel + + This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be + compatible with `FluxControlNetModel`. + + Args: + controlnets (`List[FluxControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `FluxControlNetModel` as a list. + """ + + def __init__(self, controlnets): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + controlnet_cond: List[torch.tensor], + controlnet_mode: List[torch.tensor], + conditioning_scale: List[float], + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[FluxControlNetOutput, Tuple]: + # ControlNet-Union with multiple conditions + # only load one ControlNet for saving memories + if len(self.nets) == 1 and self.nets[0].union: + controlnet = self.nets[0] + + for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)): + block_samples, single_block_samples = controlnet( + hidden_states=hidden_states, + controlnet_cond=image, + controlnet_mode=mode[:, None], + conditioning_scale=scale, + timestep=timestep, + guidance=guidance, + pooled_projections=pooled_projections, + encoder_hidden_states=encoder_hidden_states, + txt_ids=txt_ids, + img_ids=img_ids, + joint_attention_kwargs=joint_attention_kwargs, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + control_block_samples = block_samples + control_single_block_samples = single_block_samples + else: + control_block_samples = [ + control_block_sample + block_sample + for control_block_sample, block_sample in zip(control_block_samples, block_samples) + ] + + control_single_block_samples = [ + control_single_block_sample + block_sample + for control_single_block_sample, block_sample in zip( + control_single_block_samples, single_block_samples + ) + ] + + # Regular Multi-ControlNets + # load all ControlNets into memories + else: + for i, (image, mode, scale, controlnet) in enumerate( + zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets) + ): + block_samples, single_block_samples = controlnet( + hidden_states=hidden_states, + controlnet_cond=image, + controlnet_mode=mode[:, None], + conditioning_scale=scale, + timestep=timestep, + guidance=guidance, + pooled_projections=pooled_projections, + encoder_hidden_states=encoder_hidden_states, + txt_ids=txt_ids, + img_ids=img_ids, + joint_attention_kwargs=joint_attention_kwargs, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + control_block_samples = block_samples + control_single_block_samples = single_block_samples + else: + control_block_samples = [ + control_block_sample + block_sample + for control_block_sample, block_sample in zip(control_block_samples, block_samples) + ] + + control_single_block_samples = [ + control_single_block_sample + block_sample + for control_single_block_sample, block_sample in zip( + control_single_block_samples, single_block_samples + ) + ] + + return control_block_samples, control_single_block_samples diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 18c59414c302..cb573f3b19b5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -27,7 +27,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin from ...models.autoencoders import AutoencoderKL -from ...models.controlnet_flux import FluxControlNetModel +from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( @@ -61,7 +61,7 @@ >>> from diffusers import FluxControlNetPipeline >>> from diffusers import FluxControlNetModel - >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny-alpha" + >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny" >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) >>> pipe = FluxControlNetPipeline.from_pretrained( ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16 @@ -195,7 +195,9 @@ def __init__( text_encoder_2: T5EncoderModel, tokenizer_2: T5TokenizerFast, transformer: FluxTransformer2DModel, - controlnet: FluxControlNetModel, + controlnet: Union[ + FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel + ], ): super().__init__() @@ -571,6 +573,7 @@ def __call__( timesteps: List[int] = None, guidance_scale: float = 7.0, control_image: PipelineImageInput = None, + control_mode: Optional[Union[int, List[int]]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -611,6 +614,20 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + control_mode (`int` or `List[int]`,, *optional*, defaults to None): + The control mode when applying ControlNet-Union. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -730,6 +747,55 @@ def __call__( width_control_image, ) + # set control mode + if control_mode is not None: + control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + + elif isinstance(self.controlnet, FluxMultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=dtype, + ) + height, width = control_image_.shape[-2:] + + # vae encode + control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + control_images.append(control_image_) + + control_image = control_images + + # set control mode + control_mode_ = [] + if isinstance(control_mode, list): + for cmode in control_mode: + if cmode is None: + control_mode_.append(-1) + else: + control_mode_.append(cmode) + control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 latents, latent_image_ids = self.prepare_latents( @@ -785,6 +851,7 @@ def __call__( controlnet_block_samples, controlnet_single_block_samples = self.controlnet( hidden_states=latents, controlnet_cond=control_image, + controlnet_mode=control_mode, conditioning_scale=controlnet_conditioning_scale, timestep=timestep / 1000, guidance=guidance, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 0827dea44edf..1ab946ce7257 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -197,6 +197,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class FluxMultiControlNetModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class FluxTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] From 1ca0a75567da1ca5a97681310c1b57e9f527a84a Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sun, 25 Aug 2024 11:57:12 -1000 Subject: [PATCH 2/3] refactor 3d rope for cogvideox (#9269) * refactor 3d rope * repeat -> expand --- src/diffusers/models/embeddings.py | 86 ++++++++----------- .../pipelines/cogvideo/pipeline_cogvideox.py | 1 - 2 files changed, 35 insertions(+), 52 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index d1366654c448..dcb9528cb1a0 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -391,15 +391,16 @@ def get_3d_rotary_pos_embed( The size of the temporal dimension. theta (`float`): Scaling factor for frequency computation. - use_real (`bool`): - If True, return real part and imaginary part separately. Otherwise, return complex numbers. Returns: `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. """ + if use_real is not True: + raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed") start, stop = crops_coords - grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) - grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) + grid_size_h, grid_size_w = grid_size + grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) # Compute dimensions for each axis @@ -408,54 +409,37 @@ def get_3d_rotary_pos_embed( dim_w = embed_dim // 8 * 3 # Temporal frequencies - freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t)) - grid_t = torch.from_numpy(grid_t).float() - freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t) - freqs_t = freqs_t.repeat_interleave(2, dim=-1) - + freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True) # Spatial frequencies for height and width - freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h)) - freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w)) - grid_h = torch.from_numpy(grid_h).float() - grid_w = torch.from_numpy(grid_w).float() - freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h) - freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w) - freqs_h = freqs_h.repeat_interleave(2, dim=-1) - freqs_w = freqs_w.repeat_interleave(2, dim=-1) - - # Broadcast and concatenate tensors along specified dimension - def broadcast(tensors, dim=-1): - num_tensors = len(tensors) - shape_lens = {len(t.shape) for t in tensors} - assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" - shape_len = list(shape_lens)[0] - dim = (dim + shape_len) if dim < 0 else dim - dims = list(zip(*(list(t.shape) for t in tensors))) - expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] - assert all( - [*(len(set(t[1])) <= 2 for t in expandable_dims)] - ), "invalid dimensions for broadcastable concatenation" - max_dims = [(t[0], max(t[1])) for t in expandable_dims] - expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims] - expanded_dims.insert(dim, (dim, dims[dim])) - expandable_shapes = list(zip(*(t[1] for t in expanded_dims))) - tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)] - return torch.cat(tensors, dim=dim) - - freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) - - t, h, w, d = freqs.shape - freqs = freqs.view(t * h * w, d) - - # Generate sine and cosine components - sin = freqs.sin() - cos = freqs.cos() - - if use_real: - return cos, sin - else: - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) - return freqs_cis + freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True) + freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True) + + # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor + def combine_time_height_width(freqs_t, freqs_h, freqs_w): + freqs_t = freqs_t[:, None, None, :].expand( + -1, grid_size_h, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_w, dim_t + freqs_h = freqs_h[None, :, None, :].expand( + temporal_size, -1, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_h + freqs_w = freqs_w[None, None, :, :].expand( + temporal_size, grid_size_h, -1, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_w + + freqs = torch.cat( + [freqs_t, freqs_h, freqs_w], dim=-1 + ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w) + freqs = freqs.view( + temporal_size * grid_size_h * grid_size_w, -1 + ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w) + return freqs + + t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t + h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h + w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w + cos = combine_time_height_width(t_cos, h_cos, w_cos) + sin = combine_time_height_width(t_sin, h_sin, w_sin) + return cos, sin def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index e100c1f11e20..11f491e49532 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -463,7 +463,6 @@ def _prepare_rotary_positional_embeddings( crops_coords=grid_crops_coords, grid_size=(grid_height, grid_width), temporal_size=num_frames, - use_real=True, ) freqs_cos = freqs_cos.to(device=device) From c977966502b70f4758c83ee5a855b48398042b03 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 26 Aug 2024 14:59:58 +0300 Subject: [PATCH 3/3] [Dreambooth flux] bug fix for dreambooth script (align with dreambooth lora) (#9257) * fix shape * fix prompt encoding * style * fix device * add comment --- examples/dreambooth/train_dreambooth_flux.py | 130 ++++++++++--------- 1 file changed, 72 insertions(+), 58 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index ece12e289e0c..da571cc46c57 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -842,7 +842,7 @@ def __getitem__(self, index): return example -def tokenize_prompt(tokenizer, prompt, max_sequence_length=512): +def tokenize_prompt(tokenizer, prompt, max_sequence_length): text_inputs = tokenizer( prompt, padding="max_length", @@ -863,20 +863,26 @@ def _encode_prompt_with_t5( prompt=None, num_images_per_prompt=1, device=None, + text_input_ids=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + prompt_embeds = text_encoder(text_input_ids.to(device))[0] dtype = text_encoder.dtype @@ -896,22 +902,28 @@ def _encode_prompt_with_clip( tokenizer, prompt: str, device=None, + text_input_ids=None, num_images_per_prompt: int = 1, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=77, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") - text_input_ids = text_inputs.input_ids prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) # Use pooled output of CLIPTextModel @@ -932,17 +944,19 @@ def encode_prompt( max_sequence_length, device=None, num_images_per_prompt: int = 1, + text_input_ids_list=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) dtype = text_encoders[0].dtype - + device = device if device is not None else text_encoders[1].device pooled_prompt_embeds = _encode_prompt_with_clip( text_encoder=text_encoders[0], tokenizer=tokenizers[0], prompt=prompt, - device=device if device is not None else text_encoders[0].device, + device=device, num_images_per_prompt=num_images_per_prompt, + text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, ) prompt_embeds = _encode_prompt_with_t5( @@ -951,7 +965,8 @@ def encode_prompt( max_sequence_length=max_sequence_length, prompt=prompt, num_images_per_prompt=num_images_per_prompt, - device=device if device is not None else text_encoders[1].device, + device=device, + text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, ) text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) @@ -1499,7 +1514,25 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) else: tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77) - tokens_two = tokenize_prompt(tokenizer_two, prompts, max_sequence_length=512) + tokens_two = tokenize_prompt( + tokenizer_two, prompts, max_sequence_length=args.max_sequence_length + ) + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two], + tokenizers=[None, None], + text_input_ids_list=[tokens_one, tokens_two], + max_sequence_length=args.max_sequence_length, + prompt=prompts, + ) + else: + if args.train_text_encoder: + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two], + tokenizers=[None, None], + text_input_ids_list=[tokens_one, tokens_two], + max_sequence_length=args.max_sequence_length, + prompt=args.instance_prompt, + ) # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() @@ -1553,41 +1586,22 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): guidance = None # Predict the noise residual - if not args.train_text_encoder: - model_pred = transformer( - hidden_states=packed_noisy_model_input, - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) - timestep=timesteps / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - return_dict=False, - )[0] - else: - prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( - text_encoders=[text_encoder_one, text_encoder_two], - tokenizers=None, - prompt=None, - text_input_ids_list=[tokens_one, tokens_two], - ) - model_pred = transformer( - hidden_states=packed_noisy_model_input, - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) - timestep=timesteps / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - return_dict=False, - )[0] - + model_pred = transformer( + hidden_states=packed_noisy_model_input, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timesteps / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + # upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042 model_pred = FluxPipeline._unpack_latents( model_pred, - height=int(model_input.shape[2]), - width=int(model_input.shape[3]), + height=int(model_input.shape[2] * vae_scale_factor / 2), + width=int(model_input.shape[3] * vae_scale_factor / 2), vae_scale_factor=vae_scale_factor, )