From dccf39f01ed5d22d3435e612121b27f2820b0f66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?0x=E5=90=8D=E7=84=A1=E3=81=97?= Date: Tue, 15 Oct 2024 17:18:13 +0530 Subject: [PATCH 01/12] Dreambooth lora flux bug 3dtensor to 2dtensor (#9653) * fixed issue #9350, Tensor is deprecated * ran make style --- examples/dreambooth/train_dreambooth_lora_flux.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index fcc11386abcf..11cba745cc4a 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -985,7 +985,6 @@ def encode_prompt( text_input_ids_list=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) dtype = text_encoders[0].dtype pooled_prompt_embeds = _encode_prompt_with_clip( @@ -1007,8 +1006,7 @@ def encode_prompt( text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, ) - text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids From 92d2baf643b6198c2df08d9e908637ea235d84d1 Mon Sep 17 00:00:00 2001 From: Charchit Sharma Date: Tue, 15 Oct 2024 17:20:33 +0530 Subject: [PATCH 02/12] refactor image_processor.py file (#9608) * refactor image_processor file * changes as requested * +1 edits * quality fix * indent issue --------- Co-authored-by: Aryan Co-authored-by: YiYi Xu --- src/diffusers/image_processor.py | 309 +++++++++++++++++++++++++------ 1 file changed, 256 insertions(+), 53 deletions(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index d58bd9e3e375..0fffe67b0bdb 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -38,16 +38,44 @@ PipelineDepthInput = PipelineImageInput -def is_valid_image(image): +def is_valid_image(image) -> bool: + r""" + Checks if the input is a valid image. + + A valid image can be: + - A `PIL.Image.Image`. + - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image). + + Args: + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image to validate. It can be a PIL image, a NumPy array, or a torch tensor. + + Returns: + `bool`: + `True` if the input is a valid image, `False` otherwise. + """ return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3) def is_valid_image_imagelist(images): - # check if the image input is one of the supported formats for image and image list: - # it can be either one of below 3 - # (1) a 4d pytorch tensor or numpy array, - # (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor - # (3) a list of valid image + r""" + Checks if the input is a valid image or list of images. + + The input can be one of the following formats: + - A 4D tensor or numpy array (batch of images). + - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or + `torch.Tensor`. + - A list of valid images. + + Args: + images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`): + The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid + images. 
+ + Returns: + `bool`: + `True` if the input is valid, `False` otherwise. + """ if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4: return True elif is_valid_image(images): @@ -103,8 +131,16 @@ def __init__( @staticmethod def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]: - """ + r""" Convert a numpy image or a batch of images to a PIL image. + + Args: + images (`np.ndarray`): + The image array to convert to PIL format. + + Returns: + `List[PIL.Image.Image]`: + A list of PIL images. """ if images.ndim == 3: images = images[None, ...] @@ -119,8 +155,16 @@ def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]: @staticmethod def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray: - """ + r""" Convert a PIL image or a list of PIL images to NumPy arrays. + + Args: + images (`PIL.Image.Image` or `List[PIL.Image.Image]`): + The PIL image or list of images to convert to NumPy format. + + Returns: + `np.ndarray`: + A NumPy array representation of the images. """ if not isinstance(images, list): images = [images] @@ -131,8 +175,16 @@ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.nd @staticmethod def numpy_to_pt(images: np.ndarray) -> torch.Tensor: - """ + r""" Convert a NumPy image to a PyTorch tensor. + + Args: + images (`np.ndarray`): + The NumPy image array to convert to PyTorch format. + + Returns: + `torch.Tensor`: + A PyTorch tensor representation of the images. """ if images.ndim == 3: images = images[..., None] @@ -142,30 +194,62 @@ def numpy_to_pt(images: np.ndarray) -> torch.Tensor: @staticmethod def pt_to_numpy(images: torch.Tensor) -> np.ndarray: - """ + r""" Convert a PyTorch tensor to a NumPy image. + + Args: + images (`torch.Tensor`): + The PyTorch tensor to convert to NumPy format. + + Returns: + `np.ndarray`: + A NumPy array representation of the images. """ images = images.cpu().permute(0, 2, 3, 1).float().numpy() return images @staticmethod def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: - """ + r""" Normalize an image array to [-1,1]. + + Args: + images (`np.ndarray` or `torch.Tensor`): + The image array to normalize. + + Returns: + `np.ndarray` or `torch.Tensor`: + The normalized image array. """ return 2.0 * images - 1.0 @staticmethod def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: - """ + r""" Denormalize an image array to [0,1]. + + Args: + images (`np.ndarray` or `torch.Tensor`): + The image array to denormalize. + + Returns: + `np.ndarray` or `torch.Tensor`: + The denormalized image array. """ return (images / 2 + 0.5).clamp(0, 1) @staticmethod def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image: - """ + r""" Converts a PIL image to RGB format. + + Args: + image (`PIL.Image.Image`): + The PIL image to convert to RGB. + + Returns: + `PIL.Image.Image`: + The RGB-converted PIL image. """ image = image.convert("RGB") @@ -173,8 +257,16 @@ def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image: @staticmethod def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image: - """ - Converts a PIL image to grayscale format. + r""" + Converts a given PIL image to grayscale. + + Args: + image (`PIL.Image.Image`): + The input image to convert. + + Returns: + `PIL.Image.Image`: + The image converted to grayscale. 
""" image = image.convert("L") @@ -182,8 +274,16 @@ def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image: @staticmethod def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image: - """ + r""" Applies Gaussian blur to an image. + + Args: + image (`PIL.Image.Image`): + The PIL image to convert to grayscale. + + Returns: + `PIL.Image.Image`: + The grayscale-converted PIL image. """ image = image.filter(ImageFilter.GaussianBlur(blur_factor)) @@ -191,7 +291,7 @@ def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image: @staticmethod def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0): - """ + r""" Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128. @@ -285,14 +385,21 @@ def _resize_and_fill( width: int, height: int, ) -> PIL.Image.Image: - """ + r""" Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image. Args: - image: The image to resize. - width: The width to resize the image to. - height: The height to resize the image to. + image (`PIL.Image.Image`): + The image to resize and fill. + width (`int`): + The width to resize the image to. + height (`int`): + The height to resize the image to. + + Returns: + `PIL.Image.Image`: + The resized and filled image. """ ratio = width / height @@ -330,14 +437,21 @@ def _resize_and_crop( width: int, height: int, ) -> PIL.Image.Image: - """ + r""" Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess. Args: - image: The image to resize. - width: The width to resize the image to. - height: The height to resize the image to. + image (`PIL.Image.Image`): + The image to resize and crop. + width (`int`): + The width to resize the image to. + height (`int`): + The height to resize the image to. + + Returns: + `PIL.Image.Image`: + The resized and cropped image. """ ratio = width / height src_ratio = image.width / image.height @@ -429,19 +543,23 @@ def get_default_height_width( height: Optional[int] = None, width: Optional[int] = None, ) -> Tuple[int, int]: - """ - This function return the height and width that are downscaled to the next integer multiple of - `vae_scale_factor`. + r""" + Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`. Args: - image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`): - The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have - shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should - have shape `[batch, channel, height, width]`. - height (`int`, *optional*, defaults to `None`): - The height in preprocessed image. If `None`, will use the height of `image` input. - width (`int`, *optional*`, defaults to `None`): - The width in preprocessed. If `None`, will use the width of the `image` input. + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it + should have shape `[batch, height, width]` or `[batch, height, width, channels]`. 
If it is a PyTorch + tensor, it should have shape `[batch, channels, height, width]`. + height (`Optional[int]`, *optional*, defaults to `None`): + The height of the preprocessed image. If `None`, the height of the `image` input will be used. + width (`Optional[int]`, *optional*, defaults to `None`): + The width of the preprocessed image. If `None`, the width of the `image` input will be used. + + Returns: + `Tuple[int, int]`: + A tuple containing the height and width, both resized to the nearest integer multiple of + `vae_scale_factor`. """ if height is None: @@ -478,13 +596,13 @@ def preprocess( Preprocess the image input. Args: - image (`pipeline_image_input`): + image (`PipelineImageInput`): The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of supported formats. - height (`int`, *optional*, defaults to `None`): + height (`int`, *optional*): The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height. - width (`int`, *optional*`, defaults to `None`): + width (`int`, *optional*): The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width. resize_mode (`str`, *optional*, defaults to `default`): The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within @@ -496,6 +614,10 @@ def preprocess( supported for PIL image input. crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`): The crop coordinates for each image in the batch. If `None`, will not crop the image. + + Returns: + `torch.Tensor`: + The preprocessed image. """ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor) @@ -655,8 +777,22 @@ def apply_overlay( image: PIL.Image.Image, crop_coords: Optional[Tuple[int, int, int, int]] = None, ) -> PIL.Image.Image: - """ - overlay the inpaint output to the original image + r""" + Applies an overlay of the mask and the inpainted image on the original image. + + Args: + mask (`PIL.Image.Image`): + The mask image that highlights regions to overlay. + init_image (`PIL.Image.Image`): + The original image to which the overlay is applied. + image (`PIL.Image.Image`): + The image to overlay onto the original. + crop_coords (`Tuple[int, int, int, int]`, *optional*): + Coordinates to crop the image. If provided, the image will be cropped accordingly. + + Returns: + `PIL.Image.Image`: + The final image with the overlay applied. """ width, height = image.width, image.height @@ -713,8 +849,16 @@ def __init__( @staticmethod def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]: - """ - Convert a NumPy image or a batch of images to a PIL image. + r""" + Convert a NumPy image or a batch of images to a list of PIL images. + + Args: + images (`np.ndarray`): + The input NumPy array of images, which can be a single image or a batch. + + Returns: + `List[PIL.Image.Image]`: + A list of PIL images converted from the input NumPy array. """ if images.ndim == 3: images = images[None, ...] @@ -729,8 +873,16 @@ def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]: @staticmethod def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray: - """ + r""" Convert a PIL image or a list of PIL images to NumPy arrays. + + Args: + images (`Union[List[PIL.Image.Image], PIL.Image.Image]`): + The input image or list of images to be converted. + + Returns: + `np.ndarray`: + A NumPy array of the converted images. 
""" if not isinstance(images, list): images = [images] @@ -741,18 +893,30 @@ def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> @staticmethod def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: - """ - Args: - image: RGB-like depth image + r""" + Convert an RGB-like depth image to a depth map. - Returns: depth map + Args: + image (`Union[np.ndarray, torch.Tensor]`): + The RGB-like depth image to convert. + Returns: + `Union[np.ndarray, torch.Tensor]`: + The corresponding depth map. """ return image[:, :, 1] * 2**8 + image[:, :, 2] def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]: - """ - Convert a NumPy depth image or a batch of images to a PIL image. + r""" + Convert a NumPy depth image or a batch of images to a list of PIL images. + + Args: + images (`np.ndarray`): + The input NumPy array of depth images, which can be a single image or a batch. + + Returns: + `List[PIL.Image.Image]`: + A list of PIL images converted from the input NumPy depth images. """ if images.ndim == 3: images = images[None, ...] @@ -833,8 +997,24 @@ def preprocess( width: Optional[int] = None, target_res: Optional[int] = None, ) -> torch.Tensor: - """ - Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors. + r""" + Preprocess the image input. Accepted formats are PIL images, NumPy arrays, or PyTorch tensors. + + Args: + rgb (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`): + The RGB input image, which can be a single image or a batch. + depth (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`): + The depth input image, which can be a single image or a batch. + height (`Optional[int]`, *optional*, defaults to `None`): + The desired height of the processed image. If `None`, defaults to the height of the input image. + width (`Optional[int]`, *optional*, defaults to `None`): + The desired width of the processed image. If `None`, defaults to the width of the input image. + target_res (`Optional[int]`, *optional*, defaults to `None`): + Target resolution for resizing the images. If specified, overrides height and width. + + Returns: + `Tuple[torch.Tensor, torch.Tensor]`: + A tuple containing the processed RGB and depth images as PyTorch tensors. """ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor) @@ -1072,7 +1252,17 @@ def __init__( @staticmethod def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]: - """Returns binned height and width.""" + r""" + Returns the binned height and width based on the aspect ratio. + + Args: + height (`int`): The height of the image. + width (`int`): The width of the image. + ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width). + + Returns: + `Tuple[int, int]`: The closest binned height and width. + """ ar = float(height / width) closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar)) default_hw = ratios[closest_ratio] @@ -1080,6 +1270,19 @@ def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[in @staticmethod def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor: + r""" + Resizes and crops a tensor of images to the specified dimensions. + + Args: + samples (`torch.Tensor`): + A tensor of shape (N, C, H, W) where N is the batch size, C is the number of channels, H is the height, + and W is the width. + new_width (`int`): The desired width of the output images. 
+ new_height (`int`): The desired height of the output images. + + Returns: + `torch.Tensor`: A tensor containing the resized and cropped images. + """ orig_height, orig_width = samples.shape[2], samples.shape[3] # Check if resizing is needed From 355bb641e384e5c01600a9ec55b5603dc824ce9a Mon Sep 17 00:00:00 2001 From: Jiwook Han <33192762+mreraser@users.noreply.github.com> Date: Tue, 15 Oct 2024 22:17:52 +0900 Subject: [PATCH 03/12] [doc] Fix some docstrings in `src/diffusers/training_utils.py` (#9606) * refac: docstrings in training_utils.py * fix: manual edits * run make style * add docstring at cast_training_params --------- Co-authored-by: Sayak Paul --- src/diffusers/training_utils.py | 40 +++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 57bd9074870c..11a4e1cc8069 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -36,8 +36,9 @@ def set_seed(seed: int): """ - Args: Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + + Args: seed (`int`): The seed to set. """ random.seed(seed) @@ -194,6 +195,13 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]: def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32): + """ + Casts the training parameters of the model to the specified data type. + + Args: + model: The PyTorch model whose parameters will be cast. + dtype: The data type to which the model parameters will be cast. + """ if not isinstance(model, list): model = [model] for m in model: @@ -225,7 +233,8 @@ def _set_state_dict_into_text_encoder( def compute_density_for_timestep_sampling( weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None ): - """Compute the density for sampling the timesteps when doing SD3 training. + """ + Compute the density for sampling the timesteps when doing SD3 training. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. @@ -244,7 +253,8 @@ def compute_density_for_timestep_sampling( def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): - """Computes loss weighting scheme for SD3 training. + """ + Computes loss weighting scheme for SD3 training. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. @@ -261,7 +271,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def free_memory(): - """Runs garbage collection. Then clears the cache of the available accelerator.""" + """ + Runs garbage collection. Then clears the cache of the available accelerator. + """ gc.collect() if torch.cuda.is_available(): @@ -494,7 +506,8 @@ def pin_memory(self) -> None: self.shadow_params = [p.pin_memory() for p in self.shadow_params] def to(self, device=None, dtype=None, non_blocking=False) -> None: - r"""Move internal buffers of the ExponentialMovingAverage to `device`. + r""" + Move internal buffers of the ExponentialMovingAverage to `device`. Args: device: like `device` argument to `torch.Tensor.to` @@ -528,23 +541,25 @@ def state_dict(self) -> dict: def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: r""" + Saves the current parameters for restoring later. + Args: - Save the current parameters for restoring later. - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. 
+ parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored. """ self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: r""" - Args: - Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without: - affecting the original optimization process. Store the parameters before the `copy_to()` method. After + Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters + without: affecting the original optimization process. Store the parameters before the `copy_to()` method. After validation (or model saving), use this to restore the former parameters. + + Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. """ + if self.temp_stored_params is None: raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`") if self.foreach: @@ -560,9 +575,10 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: def load_state_dict(self, state_dict: dict) -> None: r""" - Args: Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the ema state dict. + + Args: state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`. """ From fff4be8e23b8b44e756cf54bc4f17e8bd637bbc5 Mon Sep 17 00:00:00 2001 From: wony617 <49024958+Jwaminju@users.noreply.github.com> Date: Tue, 15 Oct 2024 22:20:12 +0900 Subject: [PATCH 04/12] [docs] refactoring docstrings in `community/hd_painter.py` (#9593) * [docs] refactoring docstrings in community/hd_painter.py * Update examples/community/hd_painter.py Co-authored-by: Aryan * make style --------- Co-authored-by: Aryan Co-authored-by: Aryan --- examples/community/hd_painter.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/examples/community/hd_painter.py b/examples/community/hd_painter.py index df41be9ef7b1..91ebe076104a 100644 --- a/examples/community/hd_painter.py +++ b/examples/community/hd_painter.py @@ -898,13 +898,16 @@ class GaussianSmoothing(nn.Module): Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed seperately for each channel in the input using a depthwise convolution. - Arguments: - channels (int, sequence): Number of channels of the input tensors. Output will - have this number of channels as well. - kernel_size (int, sequence): Size of the gaussian kernel. - sigma (float, sequence): Standard deviation of the gaussian kernel. - dim (int, optional): The number of dimensions of the data. - Default value is 2 (spatial). + + Args: + channels (`int` or `sequence`): + Number of channels of the input tensors. The output will have this number of channels as well. + kernel_size (`int` or `sequence`): + Size of the Gaussian kernel. + sigma (`float` or `sequence`): + Standard deviation of the Gaussian kernel. + dim (`int`, *optional*, defaults to `2`): + The number of dimensions of the data. Default is 2 (spatial dimensions). """ def __init__(self, channels, kernel_size, sigma, dim=2): @@ -944,10 +947,14 @@ def __init__(self, channels, kernel_size, sigma, dim=2): def forward(self, input): """ Apply gaussian filter to input. - Arguments: - input (torch.Tensor): Input to apply gaussian filter on. 
+ + Args: + input (`torch.Tensor` of shape `(N, C, H, W)`): + Input to apply Gaussian filter on. + Returns: - filtered (torch.Tensor): Filtered output. + `torch.Tensor`: + The filtered output tensor with the same shape as the input. """ return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups, padding="same") From a3e8d3f7deed140f57a28d82dd0b5d965bd0fb09 Mon Sep 17 00:00:00 2001 From: wony617 <49024958+Jwaminju@users.noreply.github.com> Date: Tue, 15 Oct 2024 22:45:14 +0900 Subject: [PATCH 05/12] [docs] refactoring docstrings in `models/embeddings_flax.py` (#9592) * [docs] refactoring docstrings in `models/embeddings_flax.py` * Update src/diffusers/models/embeddings_flax.py * make style --------- Co-authored-by: Aryan --- src/diffusers/models/embeddings_flax.py | 32 ++++++++++++++++++------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index 8e343be0d3b7..92b5a6c35883 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -29,11 +29,21 @@ def get_sinusoidal_embeddings( """Returns the positional encoding (same as Tensor2Tensor). Args: - timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - embedding_dim: The number of output channels. - min_timescale: The smallest time unit (should probably be 0.0). - max_timescale: The largest time unit. + timesteps (`jnp.ndarray` of shape `(N,)`): + A 1-D array of N indices, one per batch element. These may be fractional. + embedding_dim (`int`): + The number of output channels. + freq_shift (`float`, *optional*, defaults to `1`): + Shift applied to the frequency scaling of the embeddings. + min_timescale (`float`, *optional*, defaults to `1`): + The smallest time unit used in the sinusoidal calculation (should probably be 0.0). + max_timescale (`float`, *optional*, defaults to `1.0e4`): + The largest time unit used in the sinusoidal calculation. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the order of sinusoidal components to cosine first. + scale (`float`, *optional*, defaults to `1.0`): + A scaling factor applied to the positional embeddings. + Returns: a Tensor of timing signals [N, num_channels] """ @@ -61,9 +71,9 @@ class FlaxTimestepEmbedding(nn.Module): Args: time_embed_dim (`int`, *optional*, defaults to `32`): - Time step embedding dimension - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` + Time step embedding dimension. + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + The data type for the embedding parameters. """ time_embed_dim: int = 32 @@ -83,7 +93,11 @@ class FlaxTimesteps(nn.Module): Args: dim (`int`, *optional*, defaults to `32`): - Time step embedding dimension + Time step embedding dimension. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sinusoidal function from sine to cosine. + freq_shift (`float`, *optional*, defaults to `1`): + Frequency shift applied to the sinusoidal embeddings. 
""" dim: int = 32 From d40da7b68a3116bb18ee9725b8dc78d39c502473 Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Wed, 16 Oct 2024 02:27:39 +0900 Subject: [PATCH 06/12] Fix some documentation in ./src/diffusers/models/adapter.py (#9591) * Fix some documentation in ./src/diffusers/models/adapter.py * Update src/diffusers/models/adapter.py * Update src/diffusers/models/adapter.py * Update src/diffusers/models/adapter.py * Update src/diffusers/models/adapter.py * Update src/diffusers/models/adapter.py * Update src/diffusers/models/adapter.py * Update src/diffusers/models/adapter.py * Update src/diffusers/models/adapter.py * Update src/diffusers/models/adapter.py * Update src/diffusers/models/adapter.py * Update src/diffusers/models/adapter.py * Update src/diffusers/models/adapter.py * Update src/diffusers/models/adapter.py * Update src/diffusers/models/adapter.py * Update src/diffusers/models/adapter.py * Update src/diffusers/models/adapter.py * Update src/diffusers/models/adapter.py * Update src/diffusers/models/adapter.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update src/diffusers/models/adapter.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update src/diffusers/models/adapter.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update src/diffusers/models/adapter.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update src/diffusers/models/adapter.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update src/diffusers/models/adapter.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update src/diffusers/models/adapter.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * run make style * make style & fix * make style : 0.1.5 version ruff * revert changes to examples --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Aryan --- src/diffusers/models/adapter.py | 106 ++++++++++++++++---------------- 1 file changed, 53 insertions(+), 53 deletions(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 0f4b2ec03371..677a991f055e 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -30,10 +30,10 @@ class MultiAdapter(ModelMixin): MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to user-assigned weighting. - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) + This model inherits from [`ModelMixin`]. Check the superclass documentation for common methods such as downloading + or saving. - Parameters: + Args: adapters (`List[T2IAdapter]`, *optional*, defaults to None): A list of `T2IAdapter` model instances. """ @@ -77,11 +77,13 @@ def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = Non r""" Args: xs (`torch.Tensor`): - (batch, channel, height, width) input images for multiple adapter models concated along dimension 1, - `channel` should equal to `num_adapter` * "number of channel of image". + A tensor of shape (batch, channel, height, width) representing input images for multiple adapter + models, concatenated along dimension 1(channel dimension). The `channel` dimension should be equal to + `num_adapter` * number of channel per image. 
+ adapter_weights (`List[float]`, *optional*, defaults to None): - List of floats representing the weight which will be multiply to each adapter's output before adding - them together. + A list of floats representing the weights which will be multiplied by each adapter's output before + summing them together. If `None`, equal weights will be used for all adapters. """ if adapter_weights is None: adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter) @@ -109,24 +111,24 @@ def save_pretrained( variant: Optional[str] = None, ): """ - Save a model and its configuration file to a directory, so that it can be re-loaded using the + Save a model and its configuration file to a specified directory, allowing it to be re-loaded with the `[`~models.adapter.MultiAdapter.from_pretrained`]` class method. - Arguments: + Args: save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful when in distributed training like - TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on - the main process to avoid race conditions. + The directory where the model will be saved. If the directory does not exist, it will be created. + is_main_process (`bool`, optional, defaults=True): + Indicates whether current process is the main process or not. Useful for distributed training (e.g., + TPUs) and need to call this function on all processes. In this case, set `is_main_process=True` only + for the main process to avoid race conditions. save_function (`Callable`): - The function to use to save the state dictionary. Useful on distributed training like TPUs when one - need to replace `torch.save` by another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + Function used to save the state dictionary. Useful for distributed training (e.g., TPUs) to replace + `torch.save` with another method. Can also be configured using`DIFFUSERS_SAVE_MODE` environment + variable. + safe_serialization (`bool`, optional, defaults=True): + If `True`, save the model using `safetensors`. If `False`, save the model with `pickle`. variant (`str`, *optional*): - If specified, weights are saved in the format pytorch_model..bin. + If specified, weights are saved in the format `pytorch_model..bin`. """ idx = 0 model_path_to_save = save_directory @@ -145,19 +147,17 @@ def save_pretrained( @classmethod def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs): r""" - Instantiate a pretrained MultiAdapter model from multiple pre-trained adapter models. + Instantiate a pretrained `MultiAdapter` model from multiple pre-trained adapter models. The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train - the model, you should first set it back in training mode with `model.train()`. + the model, set it back to training mode using `model.train()`. - The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come - pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning - task. 
+ Warnings: + *Weights from XXX not initialized from pretrained model* means that the weights of XXX are not pretrained + with the rest of the model. It is up to you to train those weights with a downstream fine-tuning. *Weights + from XXX not used in YYY* means that the layer XXX is not used by YYY, so those weights are discarded. - The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those - weights are discarded. - - Parameters: + Args: pretrained_model_path (`os.PathLike`): A path to a *directory* containing model weights saved using [`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`. @@ -175,20 +175,20 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike] more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). max_memory (`Dict`, *optional*): - A dictionary device identifier to maximum memory. Will default to the maximum memory available for each - GPU and the available CPU RAM if unset. + A dictionary mapping device identifiers to their maximum memory. Default to the maximum memory + available for each GPU and the available CPU RAM if unset. low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): Speed up model loading by not initializing the weights and only loading the pre-trained weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, setting this argument to `True` will raise an error. variant (`str`, *optional*): - If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is - ignored when using `from_flax`. + If specified, load weights from a `variant` file (*e.g.* pytorch_model..bin). `variant` will + be ignored when using `from_flax`. use_safetensors (`bool`, *optional*, defaults to `None`): - If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the - `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from - `safetensors` weights. If set to `False`, loading will *not* use `safetensors`. + If `None`, the `safetensors` weights will be downloaded if available **and** if`safetensors` library is + installed. If `True`, the model will be forcibly loaded from`safetensors` weights. If `False`, + `safetensors` is not used. """ idx = 0 adapters = [] @@ -223,22 +223,22 @@ class T2IAdapter(ModelMixin, ConfigMixin): and [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235). - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) + This model inherits from [`ModelMixin`]. Check the superclass documentation for the common methods, such as + downloading or saving. - Parameters: - in_channels (`int`, *optional*, defaults to 3): - Number of channels of Aapter's input(*control image*). Set this parameter to 1 if you're using gray scale - image as *control image*. + Args: + in_channels (`int`, *optional*, defaults to `3`): + The number of channels in the adapter's input (*control image*). Set it to 1 if you're using a gray scale + image. 
channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The number of channel of each downsample block's output hidden state. The `len(block_out_channels)` will - also determine the number of downsample blocks in the Adapter. - num_res_blocks (`int`, *optional*, defaults to 2): + The number of channels in each downsample block's output hidden state. The `len(block_out_channels)` + determines the number of downsample blocks in the adapter. + num_res_blocks (`int`, *optional*, defaults to `2`): Number of ResNet blocks in each downsample block. - downscale_factor (`int`, *optional*, defaults to 8): + downscale_factor (`int`, *optional*, defaults to `8`): A factor that determines the total downscale factor of the Adapter. adapter_type (`str`, *optional*, defaults to `full_adapter`): - The type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`. + Adapter type (`full_adapter` or `full_adapter_xl` or `light_adapter`) to use. """ @register_to_config @@ -393,7 +393,7 @@ class AdapterBlock(nn.Module): An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and `FullAdapterXL` models. - Parameters: + Args: in_channels (`int`): Number of channels of AdapterBlock's input. out_channels (`int`): @@ -401,7 +401,7 @@ class AdapterBlock(nn.Module): num_res_blocks (`int`): Number of ResNet blocks in the AdapterBlock. down (`bool`, *optional*, defaults to `False`): - Whether to perform downsampling on AdapterBlock's input. + If `True`, perform downsampling on AdapterBlock's input. """ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False): @@ -440,7 +440,7 @@ class AdapterResnetBlock(nn.Module): r""" An `AdapterResnetBlock` is a helper model that implements a ResNet-like block. - Parameters: + Args: channels (`int`): Number of channels of AdapterResnetBlock's input and output. """ @@ -518,7 +518,7 @@ class LightAdapterBlock(nn.Module): A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the `LightAdapter` model. - Parameters: + Args: in_channels (`int`): Number of channels of LightAdapterBlock's input. out_channels (`int`): @@ -526,7 +526,7 @@ class LightAdapterBlock(nn.Module): num_res_blocks (`int`): Number of LightAdapterResnetBlocks in the LightAdapterBlock. down (`bool`, *optional*, defaults to `False`): - Whether to perform downsampling on LightAdapterBlock's input. + If `True`, perform downsampling on LightAdapterBlock's input. """ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False): @@ -561,7 +561,7 @@ class LightAdapterResnetBlock(nn.Module): A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different architecture than `AdapterResnetBlock`. - Parameters: + Args: channels (`int`): Number of channels of LightAdapterResnetBlock's input and output. 
""" From 2ffbb88f1c0448f0bcfbf1db34d8ccd7b771c8dd Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Oct 2024 02:07:07 +0530 Subject: [PATCH 07/12] [training] CogVideoX-I2V LoRA (#9482) * update * update * update * update * update * add coauthor Co-Authored-By: yuan-shenghai <963658029@qq.com> * add coauthor Co-Authored-By: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> * update Co-Authored-By: yuan-shenghai <963658029@qq.com> * update --------- Co-authored-by: yuan-shenghai <963658029@qq.com> Co-authored-by: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> --- examples/cogvideo/README.md | 13 +- .../train_cogvideox_image_to_video_lora.py | 1621 +++++++++++++++++ examples/cogvideo/train_cogvideox_lora.py | 16 +- .../pipeline_cogvideox_image2video.py | 16 +- 4 files changed, 1656 insertions(+), 10 deletions(-) create mode 100644 examples/cogvideo/train_cogvideox_image_to_video_lora.py diff --git a/examples/cogvideo/README.md b/examples/cogvideo/README.md index 8e57c2b19188..02887faeaa74 100644 --- a/examples/cogvideo/README.md +++ b/examples/cogvideo/README.md @@ -10,6 +10,11 @@ In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-de At the moment, LoRA finetuning has only been tested for [CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b). +> [!NOTE] +> The scripts for CogVideoX come with limited support and may not be fully compatible with different training techniques. They are not feature-rich either and simply serve as minimal examples of finetuning to take inspiration from and improve. +> +> A repository containing memory-optimized finetuning scripts with support for multiple resolutions, dataset preparation, captioning, etc. is available [here](https://github.com/a-r-r-o-w/cogvideox-factory), which will be maintained jointly by the CogVideoX and Diffusers team. + ## Data Preparation The training scripts accepts data in two formats. @@ -132,6 +137,8 @@ Assuming you are training on 50 videos of a similar concept, we have found 1500- - 1500 steps on 50 videos would correspond to `30` training epochs - 4000 steps on 100 videos would correspond to `40` training epochs +The following bash script launches training for text-to-video lora. + ```bash #!/bin/bash @@ -172,6 +179,8 @@ accelerate launch --gpu_ids $GPU_IDS examples/cogvideo/train_cogvideox_lora.py \ --report_to wandb ``` +For launching image-to-video finetuning instead, run the `train_cogvideox_image_to_video_lora.py` file instead. Additionally, you will have to pass `--validation_images` as paths to initial images corresponding to `--validation_prompts` for I2V validation to work. + To better track our training experiments, we're using the following flags in the command above: * `--report_to wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`. * `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected. @@ -197,8 +206,6 @@ Note that setting the `` is not necessary. From some limited experimen > > Note that our testing is not exhaustive due to limited time for exploration. Our recommendation would be to play around with the different knobs and dials to find the best settings for your data. - - ## Inference Once you have trained a lora model, the inference can be done simply loading the lora weights into the `CogVideoXPipeline`. 
@@ -227,3 +234,5 @@ prompt = ( frames = pipe(prompt, guidance_scale=6, use_dynamic_cfg=True).frames[0] export_to_video(frames, "output.mp4", fps=8) ``` + +If you've trained a LoRA for `CogVideoXImageToVideoPipeline` instead, everything in the above example remains the same except you must also pass an image as initial condition for generation. diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py new file mode 100644 index 000000000000..0fdca2850784 --- /dev/null +++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py @@ -0,0 +1,1621 @@ +# Copyright 2024 The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import math +import os +import random +import shutil +from datetime import timedelta +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import torch +import transformers +from accelerate import Accelerator, DistributedType +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer + +import diffusers +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDPMScheduler, + CogVideoXImageToVideoPipeline, + CogVideoXTransformer3DModel, +) +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from diffusers.optimization import get_scheduler +from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid +from diffusers.training_utils import cast_training_params, free_memory +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + export_to_video, + is_wandb_available, + load_image, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
+check_min_version("0.31.0.dev0") + +logger = get_logger(__name__) + + +def get_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.") + + # Model information + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + # Dataset information + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_root", + type=str, + default=None, + help=("A folder containing the training data."), + ) + parser.add_argument( + "--video_column", + type=str, + default="video", + help="The column of the dataset containing videos. Or, the name of the file in `--instance_data_root` folder containing the line-separated path to video data.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--instance_data_root` folder containing the line-separated instance prompts.", + ) + parser.add_argument( + "--id_token", type=str, default=None, help="Identifier token appended to the start of each prompt if provided." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + + # Validation + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", + ) + parser.add_argument( + "--validation_images", + type=str, + default=None, + help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.", + ) + parser.add_argument( + "--validation_prompt_separator", + type=str, + default=":::", + help="String that separates multiple validation prompts", + ) + parser.add_argument( + "--num_validation_videos", + type=int, + default=1, + help="Number of videos that should be generated during validation per `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run validation every X epochs. 
Validation consists of running the prompt `args.validation_prompt` multiple times: `args.num_validation_videos`." + ), + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=6, + help="The guidance scale to use while sampling validation videos.", + ) + parser.add_argument( + "--use_dynamic_cfg", + action="store_true", + default=False, + help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.", + ) + + # Training information + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--rank", + type=int, + default=128, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--lora_alpha", + type=float, + default=128, + help=("The scaling factor to scale LoRA weight update. The actual scaling factor is `lora_alpha / rank`"), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="cogvideox-i2v-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--height", + type=int, + default=480, + help="All input videos are resized to this height.", + ) + parser.add_argument( + "--width", + type=int, + default=720, + help="All input videos are resized to this width.", + ) + parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.") + parser.add_argument( + "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames." + ) + parser.add_argument( + "--skip_frames_start", + type=int, + default=0, + help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.", + ) + parser.add_argument( + "--skip_frames_end", + type=int, + default=0, + help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.", + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip videos horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. 
Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--enable_slicing", + action="store_true", + default=False, + help="Whether or not to use VAE slicing for saving memory.", + ) + parser.add_argument( + "--enable_tiling", + action="store_true", + default=False, + help="Whether or not to use VAE tiling for saving memory.", + ) + parser.add_argument( + "--noised_image_dropout", + type=float, + default=0.05, + help="Image condition dropout probability.", + ) + + # Optimizer + parser.add_argument( + "--optimizer", + type=lambda s: s.lower(), + default="adam", + choices=["adam", "adamw", "prodigy"], + help=("The optimizer type to use."), + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.95, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. 
If set to None, uses the value of square root of beta2.", + ) + parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--prodigy_use_bias_correction", action="store_true", help="Turn on Adam's bias correction.") + parser.add_argument( + "--prodigy_safeguard_warmup", + action="store_true", + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.", + ) + + # Other information + parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help="Directory where logs are stored.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default=None, + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 
+ ), + ) + parser.add_argument("--nccl_timeout", type=int, default=600, help="NCCL backend timeout in seconds.") + + return parser.parse_args() + + +class VideoDataset(Dataset): + def __init__( + self, + instance_data_root: Optional[str] = None, + dataset_name: Optional[str] = None, + dataset_config_name: Optional[str] = None, + caption_column: str = "text", + video_column: str = "video", + height: int = 480, + width: int = 720, + fps: int = 8, + max_num_frames: int = 49, + skip_frames_start: int = 0, + skip_frames_end: int = 0, + cache_dir: Optional[str] = None, + id_token: Optional[str] = None, + ) -> None: + super().__init__() + + self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None + self.dataset_name = dataset_name + self.dataset_config_name = dataset_config_name + self.caption_column = caption_column + self.video_column = video_column + self.height = height + self.width = width + self.fps = fps + self.max_num_frames = max_num_frames + self.skip_frames_start = skip_frames_start + self.skip_frames_end = skip_frames_end + self.cache_dir = cache_dir + self.id_token = id_token or "" + + if dataset_name is not None: + self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub() + else: + self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path() + + self.instance_prompts = [self.id_token + prompt for prompt in self.instance_prompts] + + self.num_instance_videos = len(self.instance_video_paths) + if self.num_instance_videos != len(self.instance_prompts): + raise ValueError( + f"Expected length of instance prompts and videos to be the same but found {len(self.instance_prompts)=} and {len(self.instance_video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset." + ) + + self.instance_videos = self._preprocess_data() + + def __len__(self): + return self.num_instance_videos + + def __getitem__(self, index): + return { + "instance_prompt": self.instance_prompts[index], + "instance_video": self.instance_videos[index], + } + + def _load_dataset_from_hub(self): + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_root instead." + ) + + # Downloading and loading a dataset from the hub. See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + self.dataset_name, + self.dataset_config_name, + cache_dir=self.cache_dir, + ) + column_names = dataset["train"].column_names + + if self.video_column is None: + video_column = column_names[0] + logger.info(f"`video_column` defaulting to {video_column}") + else: + video_column = self.video_column + if video_column not in column_names: + raise ValueError( + f"`--video_column` value '{video_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if self.caption_column is None: + caption_column = column_names[1] + logger.info(f"`caption_column` defaulting to {caption_column}") + else: + caption_column = self.caption_column + if self.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{self.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}"
+                )
+
+        instance_prompts = dataset["train"][caption_column]
+        instance_videos = [Path(self.instance_data_root, filepath) for filepath in dataset["train"][video_column]]
+
+        return instance_prompts, instance_videos
+
+    def _load_dataset_from_local_path(self):
+        if not self.instance_data_root.exists():
+            raise ValueError("Instance videos root folder does not exist")
+
+        prompt_path = self.instance_data_root.joinpath(self.caption_column)
+        video_path = self.instance_data_root.joinpath(self.video_column)
+
+        if not prompt_path.exists() or not prompt_path.is_file():
+            raise ValueError(
+                "Expected `--caption_column` to be a path to a file in `--instance_data_root` containing line-separated text prompts."
+            )
+        if not video_path.exists() or not video_path.is_file():
+            raise ValueError(
+                "Expected `--video_column` to be a path to a file in `--instance_data_root` containing line-separated paths to video data in the same directory."
+            )
+
+        with open(prompt_path, "r", encoding="utf-8") as file:
+            instance_prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
+        with open(video_path, "r", encoding="utf-8") as file:
+            instance_videos = [
+                self.instance_data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0
+            ]
+
+        if any(not path.is_file() for path in instance_videos):
+            raise ValueError(
+                "Expected `--video_column` to be a path to a file in `--instance_data_root` containing line-separated paths to video data but found at least one path that is not a valid file."
+            )
+
+        return instance_prompts, instance_videos
+
+    def _preprocess_data(self):
+        try:
+            import decord
+        except ImportError:
+            raise ImportError(
+                "The `decord` package is required for loading the video dataset. 
Install with `pip install decord`" + ) + + decord.bridge.set_bridge("torch") + + videos = [] + train_transforms = transforms.Compose( + [ + transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0), + ] + ) + + for filename in self.instance_video_paths: + video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height) + video_num_frames = len(video_reader) + + start_frame = min(self.skip_frames_start, video_num_frames) + end_frame = max(0, video_num_frames - self.skip_frames_end) + if end_frame <= start_frame: + frames = video_reader.get_batch([start_frame]) + elif end_frame - start_frame <= self.max_num_frames: + frames = video_reader.get_batch(list(range(start_frame, end_frame))) + else: + indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames)) + frames = video_reader.get_batch(indices) + + # Ensure that we don't go over the limit + frames = frames[: self.max_num_frames] + selected_num_frames = frames.shape[0] + + # Choose first (4k + 1) frames as this is how many is required by the VAE + remainder = (3 + (selected_num_frames % 4)) % 4 + if remainder != 0: + frames = frames[:-remainder] + selected_num_frames = frames.shape[0] + + assert (selected_num_frames - 1) % 4 == 0 + + # Training transforms + frames = frames.float() + frames = torch.stack([train_transforms(frame) for frame in frames], dim=0) + videos.append(frames.permute(0, 3, 1, 2).contiguous()) # [F, C, H, W] + + return videos + + +def save_model_card( + repo_id: str, + videos=None, + base_model: str = None, + validation_prompt=None, + repo_folder=None, + fps=8, +): + widget_dict = [] + if videos is not None: + for i, video in enumerate(videos): + video_path = f"final_video_{i}.mp4" + export_to_video(video, os.path.join(repo_folder, video_path, fps=fps)) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": video_path}}, + ) + + model_description = f""" +# CogVideoX LoRA - {repo_id} + + + +## Model description + +These are {repo_id} LoRA weights for {base_model}. + +The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_image_to_video_lora.py). + +Was LoRA for the text encoder enabled? No. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +import torch +from diffusers import CogVideoXImageToVideoPipeline +from diffusers.utils import load_image, export_to_video + +pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda") +pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name=["cogvideox-i2v-lora"]) + +# The LoRA adapter weights are determined by what was used for training. +# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64. +# It can be made lower or higher from what was used in training to decrease or amplify the effect +# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows. 
+pipe.set_adapters(["cogvideox-i2v-lora"], [32 / 64]) + +image = load_image("/path/to/image") +video = pipe(image=image, "{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0] +export_to_video(video, "output.mp4", fps=8) +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b-I2V/blob/main/LICENSE). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=validation_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "image-to-video", + "diffusers-training", + "diffusers", + "lora", + "cogvideox", + "cogvideox-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipe, + args, + accelerator, + pipeline_args, + epoch, + is_final_validation: bool = False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." + ) + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipe.scheduler.config: + variance_type = pipe.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args) + pipe = pipe.to(accelerator.device) + # pipe.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + + videos = [] + for _ in range(args.num_validation_videos): + video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] + videos.append(video) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "wandb": + video_filenames = [] + for i, video in enumerate(videos): + prompt = ( + pipeline_args["prompt"][:25] + .replace(" ", "_") + .replace(" ", "_") + .replace("'", "_") + .replace('"', "_") + .replace("/", "_") + ) + filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") + export_to_video(video, filename, fps=8) + video_filenames.append(filename) + + tracker.log( + { + phase_name: [ + wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}") + for i, filename in enumerate(video_filenames) + ] + } + ) + + del pipe + free_memory() + + return videos + + +def _get_t5_prompt_embeds( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + 
else: + if text_input_ids is None: + raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + +def encode_prompt( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = _get_t5_prompt_embeds( + tokenizer, + text_encoder, + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + text_input_ids=text_input_ids, + ) + return prompt_embeds + + +def compute_prompt_embeddings( + tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False +): + if requires_grad: + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompt, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + else: + with torch.no_grad(): + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompt, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + return prompt_embeds + + +def prepare_rotary_positional_embeddings( + height: int, + width: int, + num_frames: int, + vae_scale_factor_spatial: int = 8, + patch_size: int = 2, + attention_head_dim: int = 64, + device: Optional[torch.device] = None, + base_height: int = 480, + base_width: int = 720, +) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (vae_scale_factor_spatial * patch_size) + grid_width = width // (vae_scale_factor_spatial * patch_size) + base_size_width = base_width // (vae_scale_factor_spatial * patch_size) + base_size_height = base_height // (vae_scale_factor_spatial * patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + +def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False): + # Use DeepSpeed optimzer + if use_deepspeed: + from accelerate.utils import DummyOptim + + return DummyOptim( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + + # Optimizer creation + supported_optimizers = ["adam", "adamw", "prodigy"] + if args.optimizer not in supported_optimizers: + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. 
Defaulting to AdamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]: + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + if args.optimizer.lower() == "adamw": + optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + elif args.optimizer.lower() == "adam": + optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + elif args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + return optimizer + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + init_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout)) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[ddp_kwargs, init_kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Prepare models and scheduler + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + ) + + text_encoder = T5EncoderModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + + # CogVideoX-2b weights are stored in float16 + # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16 + load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16 + transformer = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=load_dtype, + revision=args.revision, + variant=args.variant, + ) + + vae = AutoencoderKLCogVideoX.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) + + scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + if args.enable_slicing: + vae.enable_slicing() + if args.enable_tiling: + vae.enable_tiling() + + # We only train the additional adapter LoRA layers + text_encoder.requires_grad_(False) + transformer.requires_grad_(False) + vae.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.state.deepspeed_plugin: + # DeepSpeed is handling precision, use what's in the DeepSpeed config + if ( + "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] + ): + weight_dtype = torch.float16 + if ( + "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] + ): + weight_dtype = torch.float16 + else: + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." 
+ ) + + text_encoder.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + # now we will add new LoRA weights to the attention layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + init_lora_weights=True, + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + transformer.add_adapter(transformer_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(unwrap_model(transformer))): + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + CogVideoXImageToVideoPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(transformer))): + transformer_ = model + else: + raise ValueError(f"Unexpected save model: {model.__class__}") + + lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer_]) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. 
+ if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer], dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + params_to_optimize = [transformer_parameters_with_lr] + + use_deepspeed_optimizer = ( + accelerator.state.deepspeed_plugin is not None + and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + use_deepspeed_scheduler = ( + accelerator.state.deepspeed_plugin is not None + and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + + optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer) + + # Dataset and DataLoader + train_dataset = VideoDataset( + instance_data_root=args.instance_data_root, + dataset_name=args.dataset_name, + dataset_config_name=args.dataset_config_name, + caption_column=args.caption_column, + video_column=args.video_column, + height=args.height, + width=args.width, + fps=args.fps, + max_num_frames=args.max_num_frames, + skip_frames_start=args.skip_frames_start, + skip_frames_end=args.skip_frames_end, + cache_dir=args.cache_dir, + id_token=args.id_token, + ) + + def encode_video(video): + video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0) + video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + image = video[:, :, :1].clone() + + latent_dist = vae.encode(video).latent_dist + + image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=image.device) + image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=image.dtype) + noisy_image = image + torch.randn_like(image) * image_noise_sigma[:, None, None, None, None] + image_latent_dist = vae.encode(noisy_image).latent_dist + + return latent_dist, image_latent_dist + + train_dataset.instance_prompts = [ + compute_prompt_embeddings( + tokenizer, + text_encoder, + [prompt], + transformer.config.max_text_seq_length, + accelerator.device, + weight_dtype, + requires_grad=False, + ) + for prompt in train_dataset.instance_prompts + ] + train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos] + + def collate_fn(examples): + videos = [] + images = [] + for example in examples: + latent_dist, image_latent_dist = example["instance_video"] + + video_latents = latent_dist.sample() * vae.config.scaling_factor + image_latents = image_latent_dist.sample() * vae.config.scaling_factor + video_latents = video_latents.permute(0, 2, 1, 3, 4) + image_latents = image_latents.permute(0, 2, 1, 3, 4) + + padding_shape = (video_latents.shape[0], video_latents.shape[1] - 1, *video_latents.shape[2:]) + latent_padding = image_latents.new_zeros(padding_shape) + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + if random.random() < args.noised_image_dropout: + image_latents = torch.zeros_like(image_latents) + + videos.append(video_latents) + images.append(image_latents) + + videos = torch.cat(videos) + images = torch.cat(images) + videos = videos.to(memory_format=torch.contiguous_format).float() + images = images.to(memory_format=torch.contiguous_format).float() + + prompts = [example["instance_prompt"] for example in examples] + prompts = torch.cat(prompts) + + return { + "videos": (videos, images), + "prompts": prompts, + } + + train_dataloader = DataLoader( + train_dataset, + batch_size=args.train_batch_size, + 
shuffle=True, + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + if use_deepspeed_scheduler: + from accelerate.utils import DummyScheduler + + lr_scheduler = DummyScheduler( + name=args.lr_scheduler, + optimizer=optimizer, + total_num_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + ) + else: + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = args.tracker_name or "cogvideox-i2v-lora" + accelerator.init_trackers(tracker_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"]) + + logger.info("***** Running training *****") + logger.info(f" Num trainable parameters = {num_trainable_parameters}") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if not args.resume_from_checkpoint: + initial_global_step = 0 + else: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
+ ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1) + + # For DeepSpeed training + model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + + with accelerator.accumulate(models_to_accumulate): + video_latents, image_latents = batch["videos"] + prompt_embeds = batch["prompts"] + + video_latents = video_latents.to(dtype=weight_dtype) # [B, F, C, H, W] + image_latents = image_latents.to(dtype=weight_dtype) # [B, F, C, H, W] + + batch_size, num_frames, num_channels, height, width = video_latents.shape + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, scheduler.config.num_train_timesteps, (batch_size,), device=video_latents.device + ) + timesteps = timesteps.long() + + # Sample noise that will be added to the latents + noise = torch.randn_like(video_latents) + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_video_latents = scheduler.add_noise(video_latents, noise, timesteps) + noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2) + + # Prepare rotary embeds + image_rotary_emb = ( + prepare_rotary_positional_embeddings( + height=args.height, + width=args.width, + num_frames=num_frames, + vae_scale_factor_spatial=vae_scale_factor_spatial, + patch_size=model_config.patch_size, + attention_head_dim=model_config.attention_head_dim, + device=accelerator.device, + ) + if model_config.use_rotary_positional_embeddings + else None + ) + + # Predict the noise residual + model_output = transformer( + hidden_states=noisy_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps) + + alphas_cumprod = scheduler.alphas_cumprod[timesteps] + weights = 1 / (1 - alphas_cumprod) + while len(weights.shape) < len(model_pred.shape): + weights = weights.unsqueeze(-1) + + target = video_latents + + loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1) + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients: + params_to_clip = transformer.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + if accelerator.state.deepspeed_plugin is None: + optimizer.step() + optimizer.zero_grad() + + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over 
the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: + # Create pipeline + pipe = CogVideoXImageToVideoPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(transformer), + scheduler=scheduler, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + validation_images = args.validation_images.split(args.validation_prompt_separator) + + for validation_image, validation_prompt in zip(validation_images, validation_prompts): + pipeline_args = { + "image": load_image(validation_image), + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + } + + validation_outputs = log_validation( + pipe=pipe, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + ) + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + transformer = unwrap_model(transformer) + dtype = ( + torch.float16 + if args.mixed_precision == "fp16" + else torch.bfloat16 + if args.mixed_precision == "bf16" + else torch.float32 + ) + transformer = transformer.to(dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + CogVideoXImageToVideoPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + ) + + # Cleanup trained models to save memory + del transformer + free_memory() + + # Final test inference + pipe = CogVideoXImageToVideoPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + + # Load LoRA weights + lora_scaling = args.lora_alpha / args.rank + pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-i2v-lora") + pipe.set_adapters(["cogvideox-i2v-lora"], [lora_scaling]) + + # Run inference + validation_outputs = [] + if 
args.validation_prompt and args.num_validation_videos > 0: + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + validation_images = args.validation_images.split(args.validation_prompt_separator) + + for validation_image, validation_prompt in zip(validation_images, validation_prompts): + pipeline_args = { + "image": load_image(validation_image), + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + } + + video = log_validation( + pipe=pipe, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + is_final_validation=True, + ) + validation_outputs.extend(video) + + if args.push_to_hub: + validation_prompt = args.validation_prompt or "" + validation_prompt = validation_prompt.split(args.validation_prompt_separator)[0] + save_model_card( + repo_id, + videos=validation_outputs, + base_model=args.pretrained_model_name_or_path, + validation_prompt=validation_prompt, + repo_folder=args.output_dir, + fps=args.fps, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index 2fc05bf692bb..ece2228147e2 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ b/examples/cogvideo/train_cogvideox_lora.py @@ -25,7 +25,7 @@ import torch import torchvision.transforms as TT import transformers -from accelerate import Accelerator +from accelerate import Accelerator, DistributedType from accelerate.logging import get_logger from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder @@ -922,7 +922,7 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False): ) args.optimizer = "adamw" - if args.use_8bit_adam and not (args.optimizer.lower() not in ["adam", "adamw"]): + if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]: logger.warning( f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. 
Optimizer was " f"set to {args.optimizer.lower()}" @@ -1211,7 +1211,7 @@ def load_model_hook(models, input_dir): ) use_deepspeed_scheduler = ( accelerator.state.deepspeed_plugin is not None - and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config + and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config ) optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer) @@ -1255,6 +1255,7 @@ def collate_fn(examples): prompts = [example["instance_prompt"] for example in examples] videos = torch.cat(videos) + videos = videos.permute(0, 2, 1, 3, 4) videos = videos.to(memory_format=torch.contiguous_format).float() return { @@ -1376,7 +1377,7 @@ def collate_fn(examples): models_to_accumulate = [transformer] with accelerator.accumulate(models_to_accumulate): - model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype) # [B, F, C, H, W] + model_input = batch["videos"].to(dtype=weight_dtype) # [B, F, C, H, W] prompts = batch["prompts"] # encode prompts @@ -1455,7 +1456,7 @@ def collate_fn(examples): progress_bar.update(1) global_step += 1 - if accelerator.is_main_process: + if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: @@ -1495,7 +1496,6 @@ def collate_fn(examples): args.pretrained_model_name_or_path, transformer=unwrap_model(transformer), text_encoder=unwrap_model(text_encoder), - vae=unwrap_model(vae), scheduler=scheduler, revision=args.revision, variant=args.variant, @@ -1539,6 +1539,10 @@ def collate_fn(examples): transformer_lora_layers=transformer_lora_layers, ) + # Cleanup trained models to save memory + del transformer + free_memory() + # Final test inference pipe = CogVideoXPipeline.from_pretrained( args.pretrained_model_name_or_path, diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index afc11bce00d5..975e8ed27db8 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -15,7 +15,7 @@ import inspect import math -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import PIL import torch @@ -23,6 +23,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput +from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed from ...pipelines.pipeline_utils import DiffusionPipeline @@ -152,7 +153,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class CogVideoXImageToVideoPipeline(DiffusionPipeline): +class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): r""" Pipeline for image-to-video generation using CogVideoX. 
@@ -546,6 +547,10 @@ def guidance_scale(self): def num_timesteps(self): return self._num_timesteps + @property + def attention_kwargs(self): + return self._attention_kwargs + @property def interrupt(self): return self._interrupt @@ -572,6 +577,7 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: str = "pil", return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -635,6 +641,10 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, @@ -679,6 +689,7 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs self._interrupt = False # 2. Default call parameters @@ -768,6 +779,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep=timestep, image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, return_dict=False, )[0] noise_pred = noise_pred.float() From 3e9a28a8a19686b7b66701f7b93d3358d682a5ae Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 15 Oct 2024 12:10:45 -1000 Subject: [PATCH 08/12] [authored by @Anghellia) Add support of Xlabs Controlnets #9638 (#9687) * Add support of Xlabs Controlnets --------- Co-authored-by: Anzhella Pankratova --- src/diffusers/models/controlnet_flux.py | 22 ++++++- .../models/transformers/transformer_flux.py | 9 ++- .../flux/pipeline_flux_controlnet.py | 63 ++++++++++--------- 3 files changed, 62 insertions(+), 32 deletions(-) diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py index 88ad49d2b776..961e30155a3d 100644 --- a/src/diffusers/models/controlnet_flux.py +++ b/src/diffusers/models/controlnet_flux.py @@ -23,7 +23,7 @@ from ..models.attention_processor import AttentionProcessor from ..models.modeling_utils import ModelMixin from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from .controlnet import BaseOutput, zero_module +from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed from .modeling_outputs import Transformer2DModelOutput from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock @@ -55,6 +55,7 @@ def __init__( guidance_embeds: bool = False, axes_dims_rope: List[int] = [16, 56, 56], num_mode: int = None, + conditioning_embedding_channels: int = None, ): super().__init__() self.out_channels = in_channels @@ -106,7 +107,14 @@ def __init__( if self.union: self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim) - self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, 
self.inner_dim)) + if conditioning_embedding_channels is not None: + self.input_hint_block = ControlNetConditioningEmbedding( + conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16) + ) + self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim) + else: + self.input_hint_block = None + self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim)) self.gradient_checkpointing = False @@ -269,6 +277,16 @@ def forward( ) hidden_states = self.x_embedder(hidden_states) + if self.input_hint_block is not None: + controlnet_cond = self.input_hint_block(controlnet_cond) + batch_size, channels, height_pw, width_pw = controlnet_cond.shape + height = height_pw // self.config.patch_size + width = width_pw // self.config.patch_size + controlnet_cond = controlnet_cond.reshape( + batch_size, channels, height, self.config.patch_size, width, self.config.patch_size + ) + controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5) + controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1) # add hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 6238ab8044bb..5d39a1bb5391 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -402,6 +402,7 @@ def forward( controlnet_block_samples=None, controlnet_single_block_samples=None, return_dict: bool = True, + controlnet_blocks_repeat: bool = False, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ The [`FluxTransformer2DModel`] forward method. @@ -508,7 +509,13 @@ def custom_forward(*inputs): if controlnet_block_samples is not None: interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) interval_control = int(np.ceil(interval_control)) - hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + # For Xlabs ControlNet. 
+ if controlnet_blocks_repeat: + hidden_states = ( + hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + ) + else: + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 8770c231809f..f4018a82ad69 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -754,19 +754,22 @@ def __call__( ) height, width = control_image.shape[-2:] - # vae encode - control_image = self.vae.encode(control_image).latent_dist.sample() - control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - # pack - height_control_image, width_control_image = control_image.shape[2:] - control_image = self._pack_latents( - control_image, - batch_size * num_images_per_prompt, - num_channels_latents, - height_control_image, - width_control_image, - ) + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True + if self.controlnet.input_hint_block is None: + # vae encode + control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) # Here we ensure that `control_mode` has the same length as the control_image. 
if control_mode is not None: @@ -777,8 +780,9 @@ def __call__( elif isinstance(self.controlnet, FluxMultiControlNetModel): control_images = [] - - for control_image_ in control_image: + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True + for i, control_image_ in enumerate(control_image): control_image_ = self.prepare_image( image=control_image_, width=width, @@ -790,20 +794,20 @@ def __call__( ) height, width = control_image_.shape[-2:] - # vae encode - control_image_ = self.vae.encode(control_image_).latent_dist.sample() - control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - # pack - height_control_image, width_control_image = control_image_.shape[2:] - control_image_ = self._pack_latents( - control_image_, - batch_size * num_images_per_prompt, - num_channels_latents, - height_control_image, - width_control_image, - ) - + if self.controlnet.nets[0].input_hint_block is None: + # vae encode + control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) control_images.append(control_image_) control_image = control_images @@ -927,6 +931,7 @@ def __call__( img_ids=latent_image_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, + controlnet_blocks_repeat=controlnet_blocks_repeat, )[0] # compute the previous noisy sample x_t -> x_t-1 From 0d935df67da3e76f7e4bdbba374aefc8c573b838 Mon Sep 17 00:00:00 2001 From: glide-the Date: Wed, 16 Oct 2024 08:41:56 +0800 Subject: [PATCH 09/12] Docs: CogVideoX (#9578) * CogVideoX docs --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: YiYi Xu --- docs/source/en/_toctree.yml | 4 + docs/source/en/training/cogvideox.md | 291 ++++++++++++++++++ docs/source/en/using-diffusers/cogvideox.md | 120 ++++++++ .../source/en/using-diffusers/text-img2vid.md | 53 ++++ 4 files changed, 468 insertions(+) create mode 100644 docs/source/en/training/cogvideox.md create mode 100644 docs/source/en/using-diffusers/cogvideox.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 22613cb343ff..e218e9878599 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -75,6 +75,8 @@ title: Outpainting title: Advanced inference - sections: + - local: using-diffusers/cogvideox + title: CogVideoX - local: using-diffusers/sdxl title: Stable Diffusion XL - local: using-diffusers/sdxl_turbo @@ -129,6 +131,8 @@ title: T2I-Adapters - local: training/instructpix2pix title: InstructPix2Pix + - local: training/cogvideox + title: CogVideoX title: Models - isExpanded: false sections: diff --git a/docs/source/en/training/cogvideox.md b/docs/source/en/training/cogvideox.md new file mode 100644 index 000000000000..657e58bfd5eb --- /dev/null +++ b/docs/source/en/training/cogvideox.md @@ -0,0 +1,291 @@ + +# CogVideoX + +CogVideoX is a text-to-video generation model focused on creating more coherent videos aligned with a prompt. It achieves this using several methods. 
+
+- a 3D variational autoencoder that compresses videos spatially and temporally, improving the compression rate and video reconstruction quality.
+
+- an expert transformer block to help align text and video, and a 3D full attention module for capturing and creating spatially and temporally accurate videos.
+
+Evaluation across video-instruction dimensions shows that CogVideoX performs well on consistent theme, dynamic information, consistent background, object information, smooth motion, color, scene, appearance style, and temporal style, but it struggles with human action, spatial relationships, and multiple objects.
+
+Finetuning with Diffusers can help compensate for these weaknesses.
+
+## Data Preparation
+
+The training script accepts data in two formats.
+
+The first format, sketched below, is suited for small-scale training, and the second is a CSV format that is more appropriate for streaming data in large-scale training. In the future, Diffusers will support the `
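+
+As a rough illustration of the first (folder-based) format, the training script's `VideoDataset` reads `--instance_data_root` for two line-separated text files: one with captions (`--caption_column`) and one with relative video paths (`--video_column`). The minimal sketch below builds such a layout; the folder name, file names, and captions are placeholders, not values required by the script.
+
+```py
+from pathlib import Path
+
+# Hypothetical layout: pass these names via --instance_data_root, --caption_column and --video_column.
+root = Path("my-dataset")
+(root / "videos").mkdir(parents=True, exist_ok=True)
+
+captions = [
+    "A black and white dog runs through shallow water.",
+    "A red car drives along a winding coastal road.",
+]
+video_files = ["videos/00000.mp4", "videos/00001.mp4"]  # copy your clips to these relative paths
+
+# One caption per line and one video path per line, paired by line number.
+(root / "prompts.txt").write_text("\n".join(captions), encoding="utf-8")
+(root / "videos.txt").write_text("\n".join(video_files), encoding="utf-8")
+```
+
+The dataset raises an error if the two files contain a different number of non-empty lines or if any listed video path does not resolve to an existing file, so keep the captions and paths aligned line by line.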