Merge changes #183

Merged: 20 commits, merged on Oct 12, 2024

Changes from all commits
99f6082
[sd3] make sure height and size are divisible by `16` (#9573)
yiyixuxu Oct 3, 2024
3159e60
fix xlabs FLUX lora conversion typo (#9581)
Clement-Lelievre Oct 7, 2024
31010ec
[Chore] add a note on the versions in Flux LoRA integration tests (#9…
sayakpaul Oct 7, 2024
2cb383f
fix vae dtype when accelerate config using --mixed_precision="fp16" (…
xduzhangjiayu Oct 7, 2024
a80f689
refac: docstrings in import_utils.py (#9583)
yijun-lee Oct 7, 2024
1287822
Fix for use_safetensors parameters, allow use of parameter on loading…
elismasilva Oct 7, 2024
63a5c87
Update distributed_inference.md to include `transformer.device_map` (…
sayakpaul Oct 8, 2024
66eef9a
fix: CogVideox train dataset _preprocess_data crop video (#9574)
glide-the Oct 8, 2024
02eeb8e
[LoRA] Handle DoRA better (#9547)
sayakpaul Oct 8, 2024
86bd991
Fixed noise_pred_text referenced before assignment. (#9537)
LagPixelLOL Oct 8, 2024
acd6d2c
Fix the bug that `joint_attention_kwargs` is not passed to the FLUX's…
HorizonWind2004 Oct 8, 2024
ec9e526
refac/pipeline_output (#9582)
yijun-lee Oct 9, 2024
31058cd
[LoRA] allow loras to be loaded with low_cpu_mem_usage. (#9510)
sayakpaul Oct 9, 2024
af28ae2
add PAG support for SD Img2Img (#9463)
SahilCarterr Oct 9, 2024
07bd2fa
make controlnet support interrupt (#9620)
pureexe Oct 9, 2024
e16fd93
[LoRA] fix dora test to catch the warning properly. (#9627)
sayakpaul Oct 10, 2024
38a3e4d
flux controlnet control_guidance_start and control_guidance_end imple…
ighoshsubho Oct 10, 2024
164ec9f
fix IsADirectoryError when running the training code for sd3_dreamboo…
alaister123 Oct 11, 2024
3033f08
Add Differential Diffusion to Kolors (#9423)
saqlain2204 Oct 11, 2024
0f8fb75
FluxMultiControlNetModel (#9647)
hlky Oct 11, 2024
5 changes: 5 additions & 0 deletions docs/source/en/api/pipelines/pag.md
@@ -53,6 +53,11 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
- all
- __call__

## StableDiffusionPAGImg2ImgPipeline
[[autodoc]] StableDiffusionPAGImg2ImgPipeline
- all
- __call__

## StableDiffusionControlNetPAGPipeline
[[autodoc]] StableDiffusionControlNetPAGPipeline

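For context on the `StableDiffusionPAGImg2ImgPipeline` entry added above (PAG support for SD Img2Img, #9463), here is a minimal usage sketch. It assumes the new pipeline is reachable through `AutoPipelineForImage2Image` with `enable_pag=True`, mirroring the existing PAG pipelines; the checkpoint name, input image path, and `pag_scale` value are illustrative.

```python
import torch
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image

# Assumption: PAG img2img is enabled the same way as the other PAG pipelines,
# by passing enable_pag=True to the auto pipeline.
pipeline = AutoPipelineForImage2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative SD 1.5 checkpoint
    enable_pag=True,
    torch_dtype=torch.float16,
).to("cuda")

init_image = load_image("path/to/init_image.png")  # placeholder input image

image = pipeline(
    prompt="a fantasy landscape, concept art",
    image=init_image,
    strength=0.75,
    pag_scale=3.0,  # strength of perturbed-attention guidance
).images[0]
```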
2 changes: 1 addition & 1 deletion docs/source/en/training/distributed_inference.md
@@ -177,7 +177,7 @@ transformer = FluxTransformer2DModel.from_pretrained(
```

> [!TIP]
> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models.
> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models. You can also try `print(transformer.hf_device_map)` to see how the transformer model is sharded across devices.

Add the transformer model to the pipeline for denoising, but set the other model-level components like the text encoders and VAE to `None` because you don't need them yet.

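The updated tip above sits in the Flux distributed-inference guide, next to the instruction to pass the sharded transformer into the pipeline with the other components set to `None`. A condensed sketch of that pattern is shown below; the checkpoint name, `device_map="auto"` setting, and dtype are illustrative and follow the guide rather than this diff.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

# Shard only the transformer across the available GPUs.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
print(transformer.hf_device_map)  # per-module placement of the sharded transformer

# Reuse the sharded transformer for denoising; the text encoders and VAE are
# set to None because they are not needed at this stage.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    vae=None,
    torch_dtype=torch.bfloat16,
)
```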
6 changes: 6 additions & 0 deletions docs/source/en/tutorials/using_peft_for_inference.md
@@ -75,6 +75,12 @@ image

![pixel-art](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_12_1.png)

<Tip>

By default, if the most up-to-date versions of PEFT and Transformers are detected, `low_cpu_mem_usage` is set to `True` to speed up the loading time of LoRA checkpoints.

</Tip>
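
As a concrete illustration of the tip, the sketch below passes `low_cpu_mem_usage` explicitly when loading a LoRA, which is the behavior #9510 turns on by default with recent PEFT and Transformers versions. The base checkpoint, LoRA path, weight file name, and adapter name are placeholders, and it is assumed that `load_lora_weights` accepts the flag directly.

```python
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Explicitly opt in to the faster loading path; with up-to-date PEFT and
# Transformers installed this is already the default.
pipeline.load_lora_weights(
    "path/to/pixel-art-lora",                        # placeholder repo id or local directory
    weight_name="pytorch_lora_weights.safetensors",  # placeholder weight file name
    adapter_name="pixel",
    low_cpu_mem_usage=True,
)
```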

## Merge adapters

You can also merge different adapter checkpoints for inference to blend their styles together.
1 change: 1 addition & 0 deletions examples/cogvideo/README.md
@@ -180,6 +180,7 @@ Note that setting the `<ID_TOKEN>` is not necessary. From some limited experimen

> [!TIP]
> You can pass `--use_8bit_adam` to reduce the memory requirements of training.
> You can pass `--video_reshape_mode` to control how input videos are cropped, with supported options: ['center', 'random', 'none']. See [this](https://gist.github.com/glide-the/7658dbfd5f555be0a1a687a4139dba40) notebook for examples.

> [!IMPORTANT]
> The following settings have been tested at the time of adding CogVideoX LoRA training support:
91 changes: 77 additions & 14 deletions examples/cogvideo/train_cogvideox_lora.py
@@ -21,20 +21,24 @@
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torchvision.transforms as TT
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from tqdm.auto import tqdm
from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer

import diffusers
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.optimization import get_scheduler
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
@@ -214,6 +218,12 @@ def get_args():
default=720,
help="All input videos are resized to this width.",
)
parser.add_argument(
"--video_reshape_mode",
type=str,
default="center",
help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
)
parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
parser.add_argument(
"--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
@@ -413,6 +423,7 @@ def __init__(
video_column: str = "video",
height: int = 480,
width: int = 720,
video_reshape_mode: str = "center",
fps: int = 8,
max_num_frames: int = 49,
skip_frames_start: int = 0,
@@ -429,6 +440,7 @@ def __init__(
self.video_column = video_column
self.height = height
self.width = width
self.video_reshape_mode = video_reshape_mode
self.fps = fps
self.max_num_frames = max_num_frames
self.skip_frames_start = skip_frames_start
@@ -532,6 +544,38 @@ def _load_dataset_from_local_path(self):

return instance_prompts, instance_videos

def _resize_for_rectangle_crop(self, arr):
image_size = self.height, self.width
reshape_mode = self.video_reshape_mode
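# Resize so the target rectangle fits inside the frame while preserving aspect ratio; the excess is cropped below.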
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
arr = resize(
arr,
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
interpolation=InterpolationMode.BICUBIC,
)
else:
arr = resize(
arr,
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
interpolation=InterpolationMode.BICUBIC,
)

h, w = arr.shape[2], arr.shape[3]
arr = arr.squeeze(0)

delta_h = h - image_size[0]
delta_w = w - image_size[1]

if reshape_mode == "random" or reshape_mode == "none":
top = np.random.randint(0, delta_h + 1)
left = np.random.randint(0, delta_w + 1)
elif reshape_mode == "center":
top, left = delta_h // 2, delta_w // 2
else:
raise NotImplementedError
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
return arr

def _preprocess_data(self):
try:
import decord
@@ -542,15 +586,14 @@ def _preprocess_data(self):

decord.bridge.set_bridge("torch")

videos = []
train_transforms = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
]
progress_dataset_bar = tqdm(
range(0, len(self.instance_video_paths)),
desc="Loading progress resize and crop videos",
)
videos = []

for filename in self.instance_video_paths:
video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
video_reader = decord.VideoReader(uri=filename.as_posix())
video_num_frames = len(video_reader)

start_frame = min(self.skip_frames_start, video_num_frames)
@@ -576,10 +619,16 @@
assert (selected_num_frames - 1) % 4 == 0

# Training transforms
frames = frames.float()
frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
videos.append(frames.permute(0, 3, 1, 2).contiguous()) # [F, C, H, W]
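# Map uint8 pixel values to [-1, 1] and reorder to [F, C, H, W] before the aspect-ratio resize and crop.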
frames = (frames - 127.5) / 127.5
frames = frames.permute(0, 3, 1, 2) # [F, C, H, W]
progress_dataset_bar.set_description(
f"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
)
frames = self._resize_for_rectangle_crop(frames)
videos.append(frames.contiguous()) # [F, C, H, W]
progress_dataset_bar.update(1)

progress_dataset_bar.close()
return videos


@@ -694,8 +743,13 @@ def log_validation(

videos = []
for _ in range(args.num_validation_videos):
video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
videos.append(video)
pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
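# Stack the per-frame tensors, then convert them to PIL images so the experiment trackers can log the video.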
pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])

image_np = VaeImageProcessor.pt_to_numpy(pt_images)
image_pil = VaeImageProcessor.numpy_to_pil(image_np)

videos.append(image_pil)

for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
@@ -1171,6 +1225,7 @@ def load_model_hook(models, input_dir):
video_column=args.video_column,
height=args.height,
width=args.width,
video_reshape_mode=args.video_reshape_mode,
fps=args.fps,
max_num_frames=args.max_num_frames,
skip_frames_start=args.skip_frames_start,
@@ -1179,13 +1234,21 @@
id_token=args.id_token,
)

def encode_video(video):
def encode_video(video, bar):
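# Advance the shared progress bar, then encode one clip into VAE latents on the accelerator device.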
bar.update(1)
video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
latent_dist = vae.encode(video).latent_dist
return latent_dist

train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
progress_encode_bar = tqdm(
range(0, len(train_dataset.instance_videos)),
desc="Loading Encode videos",
)
train_dataset.instance_videos = [
encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
]
progress_encode_bar.close()

def collate_fn(examples):
videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
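To make the new `--video_reshape_mode` preprocessing easier to follow, here is a small self-contained sketch of the same resize-then-crop logic applied to a dummy clip. It mirrors `_resize_for_rectangle_crop` above but is not part of the PR; the tensor shapes are illustrative.

```python
import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import crop, resize


def resize_for_rectangle_crop(frames, height, width, mode="center"):
    # frames: [F, C, H, W]. Resize so the target rectangle fits inside the
    # frame while preserving aspect ratio, then crop away the excess.
    if frames.shape[3] / frames.shape[2] > width / height:
        frames = resize(
            frames,
            size=[height, int(frames.shape[3] * height / frames.shape[2])],
            interpolation=InterpolationMode.BICUBIC,
        )
    else:
        frames = resize(
            frames,
            size=[int(frames.shape[2] * width / frames.shape[3]), width],
            interpolation=InterpolationMode.BICUBIC,
        )

    delta_h = frames.shape[2] - height
    delta_w = frames.shape[3] - width
    if mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:
        # "random" (and, as in the PR, "none") samples the crop offset per video.
        top = torch.randint(0, delta_h + 1, (1,)).item()
        left = torch.randint(0, delta_w + 1, (1,)).item()
    return crop(frames, top, left, height, width)


clip = torch.rand(49, 3, 512, 910)  # dummy clip: 49 frames, arbitrary aspect ratio
print(resize_for_rectangle_crop(clip, 480, 720, mode="center").shape)
# -> torch.Size([49, 3, 480, 720])
```

With `mode="center"` the crop is deterministic, which is what the training script defaults to; `"random"` picks a new offset for each video.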