Merge pull request #183 from huggingface/main
Merge changes
Skquark authored Oct 12, 2024
2 parents c7a05b6 + 0f8fb75 commit aca5451
Showing 45 changed files with 3,482 additions and 97 deletions.
5 changes: 5 additions & 0 deletions docs/source/en/api/pipelines/pag.md
@@ -53,6 +53,11 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
- all
- __call__

## StableDiffusionPAGImg2ImgPipeline
[[autodoc]] StableDiffusionPAGImg2ImgPipeline
- all
- __call__

## StableDiffusionControlNetPAGPipeline
[[autodoc]] StableDiffusionControlNetPAGPipeline

2 changes: 1 addition & 1 deletion docs/source/en/training/distributed_inference.md
@@ -177,7 +177,7 @@ transformer = FluxTransformer2DModel.from_pretrained(
```

> [!TIP]
> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models.
> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models. You can also try `print(transformer.hf_device_map)` to see how the transformer model is sharded across devices.
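
For example, a minimal sketch of inspecting the placement after loading the transformer (the checkpoint name, `device_map` strategy, and the printed mapping are illustrative, not taken from this diff):

```python
# Minimal sketch: load the transformer with automatic device placement and
# inspect where each sub-module ended up.
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed checkpoint
    subfolder="transformer",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Prints a module -> device mapping, e.g. {"": 0} on a single GPU or a
# per-block mapping such as {"transformer_blocks.0": 0, ..., "proj_out": 1}
print(transformer.hf_device_map)
```
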
Add the transformer model to the pipeline for denoising, but set the other model-level components like the text encoders and VAE to `None` because you don't need them yet.

6 changes: 6 additions & 0 deletions docs/source/en/tutorials/using_peft_for_inference.md
@@ -75,6 +75,12 @@ image

![pixel-art](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_12_1.png)

<Tip>

By default, if the most up-to-date versions of PEFT and Transformers are detected, `low_cpu_mem_usage` is set to `True` to speed up the loading time of LoRA checkpoints.

</Tip>
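
As a minimal sketch, assuming the same SDXL base model and pixel-art LoRA used earlier in this tutorial, you can also pass the flag explicitly:

```python
# Minimal sketch: opt in to low_cpu_mem_usage explicitly when loading a LoRA.
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
    "nerijs/pixel-art-xl",
    weight_name="pixel-art-xl.safetensors",
    adapter_name="pixel",
    low_cpu_mem_usage=True,  # skip the extra CPU copy when loading the LoRA state dict
)
```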

## Merge adapters

You can also merge different adapter checkpoints for inference to blend their styles together.
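
Continuing from the sketch above, a hedged example of the idea (the second LoRA repo, adapter weights, and prompt are illustrative assumptions, not taken from this diff):

```python
# Illustrative sketch: load a second LoRA and blend both adapters at inference.
pipeline.load_lora_weights(
    "ostris/super-cereal-sdxl-lora",  # assumed second LoRA checkpoint
    weight_name="cereal-box-sdxl-v1.safetensors",
    adapter_name="cereal",
)
# Activate both adapters and weight their contributions to blend the styles.
pipeline.set_adapters(["pixel", "cereal"], adapter_weights=[0.7, 0.5])
image = pipeline("a hacker with a hoodie, pixel art", num_inference_steps=30).images[0]
```
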
1 change: 1 addition & 0 deletions examples/cogvideo/README.md
@@ -180,6 +180,7 @@ Note that setting the `<ID_TOKEN>` is not necessary. From some limited experimen

> [!TIP]
> You can pass `--use_8bit_adam` to reduce the memory requirements of training.
> You can pass `--video_reshape_mode` to enable the video cropping functionality, with supported options: ['center', 'random', 'none']. See [this](https://gist.github.com/glide-the/7658dbfd5f555be0a1a687a4139dba40) notebook for examples.
> [!IMPORTANT]
> The following settings have been tested at the time of adding CogVideoX LoRA training support:
91 changes: 77 additions & 14 deletions examples/cogvideo/train_cogvideox_lora.py
@@ -21,20 +21,24 @@
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torchvision.transforms as TT
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from tqdm.auto import tqdm
from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer

import diffusers
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.optimization import get_scheduler
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
@@ -214,6 +218,12 @@ def get_args():
default=720,
help="All input videos are resized to this width.",
)
parser.add_argument(
"--video_reshape_mode",
type=str,
default="center",
help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
)
parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
parser.add_argument(
"--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
@@ -413,6 +423,7 @@ def __init__(
video_column: str = "video",
height: int = 480,
width: int = 720,
video_reshape_mode: str = "center",
fps: int = 8,
max_num_frames: int = 49,
skip_frames_start: int = 0,
@@ -429,6 +440,7 @@ def __init__(
self.video_column = video_column
self.height = height
self.width = width
self.video_reshape_mode = video_reshape_mode
self.fps = fps
self.max_num_frames = max_num_frames
self.skip_frames_start = skip_frames_start
@@ -532,6 +544,38 @@ def _load_dataset_from_local_path(self):

return instance_prompts, instance_videos

def _resize_for_rectangle_crop(self, arr):
image_size = self.height, self.width
reshape_mode = self.video_reshape_mode
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
arr = resize(
arr,
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
interpolation=InterpolationMode.BICUBIC,
)
else:
arr = resize(
arr,
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
interpolation=InterpolationMode.BICUBIC,
)

h, w = arr.shape[2], arr.shape[3]
arr = arr.squeeze(0)

delta_h = h - image_size[0]
delta_w = w - image_size[1]

if reshape_mode == "random" or reshape_mode == "none":
top = np.random.randint(0, delta_h + 1)
left = np.random.randint(0, delta_w + 1)
elif reshape_mode == "center":
top, left = delta_h // 2, delta_w // 2
else:
raise NotImplementedError
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
return arr

def _preprocess_data(self):
try:
import decord
@@ -542,15 +586,14 @@ def _preprocess_data(self):

decord.bridge.set_bridge("torch")

videos = []
train_transforms = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
]
progress_dataset_bar = tqdm(
range(0, len(self.instance_video_paths)),
desc="Loading progress resize and crop videos",
)
videos = []

for filename in self.instance_video_paths:
video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
video_reader = decord.VideoReader(uri=filename.as_posix())
video_num_frames = len(video_reader)

start_frame = min(self.skip_frames_start, video_num_frames)
@@ -576,10 +619,16 @@ def _preprocess_data(self):
assert (selected_num_frames - 1) % 4 == 0

# Training transforms
frames = frames.float()
frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
videos.append(frames.permute(0, 3, 1, 2).contiguous()) # [F, C, H, W]
frames = (frames - 127.5) / 127.5
frames = frames.permute(0, 3, 1, 2) # [F, C, H, W]
progress_dataset_bar.set_description(
f"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
)
frames = self._resize_for_rectangle_crop(frames)
videos.append(frames.contiguous()) # [F, C, H, W]
progress_dataset_bar.update(1)

progress_dataset_bar.close()
return videos


@@ -694,8 +743,13 @@ def log_validation(

videos = []
for _ in range(args.num_validation_videos):
video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
videos.append(video)
pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])

image_np = VaeImageProcessor.pt_to_numpy(pt_images)
image_pil = VaeImageProcessor.numpy_to_pil(image_np)

videos.append(image_pil)

for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
@@ -1171,6 +1225,7 @@ def load_model_hook(models, input_dir):
video_column=args.video_column,
height=args.height,
width=args.width,
video_reshape_mode=args.video_reshape_mode,
fps=args.fps,
max_num_frames=args.max_num_frames,
skip_frames_start=args.skip_frames_start,
@@ -1179,13 +1234,21 @@ def load_model_hook(models, input_dir):
id_token=args.id_token,
)

def encode_video(video):
def encode_video(video, bar):
bar.update(1)
video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
latent_dist = vae.encode(video).latent_dist
return latent_dist

train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
progress_encode_bar = tqdm(
range(0, len(train_dataset.instance_videos)),
desc="Loading Encode videos",
)
train_dataset.instance_videos = [
encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
]
progress_encode_bar.close()

def collate_fn(examples):
videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
