
Commit

Merge branch 'main' of https://github.com/huggingface/diffusers into main
Skquark committed Feb 26, 2024
2 parents 47fe103 + d603ccb commit 12a68f7
Showing 21 changed files with 1,035 additions and 264 deletions.
6 changes: 0 additions & 6 deletions docs/source/en/api/attnprocessor.md
@@ -41,12 +41,6 @@ An attention processor is a class for applying different types of attention mech
## FusedAttnProcessor2_0
[[autodoc]] models.attention_processor.FusedAttnProcessor2_0

## LoRAAttnProcessor
[[autodoc]] models.attention_processor.LoRAAttnProcessor

## LoRAAttnProcessor2_0
[[autodoc]] models.attention_processor.LoRAAttnProcessor2_0

## LoRAAttnAddedKVProcessor
[[autodoc]] models.attention_processor.LoRAAttnAddedKVProcessor

6 changes: 6 additions & 0 deletions docs/source/en/optimization/fp16.md
@@ -66,3 +66,9 @@ image = pipe(prompt).images[0]
Don't use [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast) in any of the pipelines as it can lead to black images and is always slower than pure float16 precision.

</Tip>
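
To restate the pattern this guide recommends, here is a minimal sketch of loading the weights in float16 up front rather than wrapping the call in `torch.autocast` (the checkpoint and prompt are illustrative):

```py
import torch
from diffusers import DiffusionPipeline

# Load the weights in float16 up front; no autocast wrapper around the call
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative checkpoint
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```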

## Distilled model

You could also use a distilled Stable Diffusion model and autoencoder to speed up inference. During distillation, many of the UNet's residual and attention blocks are shed to reduce the model size. The distilled model is faster and uses less memory while generating images of comparable quality to the full Stable Diffusion model.

Learn more about it in the [Distilled Stable Diffusion inference](../using-diffusers/distilled_sd) guide!
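
As a rough sketch of what that looks like in practice, assuming the `nota-ai/bk-sdm-small` distilled checkpoint and the `madebyollin/taesd` tiny autoencoder as example choices:

```py
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline

# A distilled Stable Diffusion checkpoint with a smaller UNet (example model)
distilled = StableDiffusionPipeline.from_pretrained(
    "nota-ai/bk-sdm-small", torch_dtype=torch.float16
).to("cuda")

# Optionally swap in a tiny distilled autoencoder to speed up decoding (example model)
distilled.vae = AutoencoderTiny.from_pretrained(
    "madebyollin/taesd", torch_dtype=torch.float16
).to("cuda")

image = distilled("a golden retriever in a field of sunflowers").images[0]
```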
3 changes: 3 additions & 0 deletions docs/source/en/optimization/torch2.0.md
@@ -75,6 +75,9 @@ Compilation requires some time to complete, so it is best suited for situations

For more information and different options about `torch.compile`, refer to the [`torch_compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) tutorial.
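
As a minimal sketch of that pattern (checkpoint and prompt are illustrative; the first call is slow while compilation runs):

```py
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")

# channels_last memory format tends to pair well with the compiled UNet
pipe.unet.to(memory_format=torch.channels_last)
# "reduce-overhead" trades extra memory for lower per-call overhead
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```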

> [!TIP]
> Learn more about other ways PyTorch 2.0 can help optimize your model in the [Accelerate inference of text-to-image diffusion models](../tutorials/fast_diffusion) tutorial.

## Benchmark

We conducted a comprehensive benchmark with PyTorch 2.0's efficient attention implementation and `torch.compile` across different GPUs and batch sizes for five of our most used pipelines. The code is benchmarked on 🤗 Diffusers v0.17.0.dev0 to optimize `torch.compile` usage (see [here](https://github.com/huggingface/diffusers/pull/3313) for more details).
60 changes: 37 additions & 23 deletions docs/source/en/training/lora.md
@@ -113,36 +113,50 @@ The dataset preprocessing code and training loop are found in the [`main()`](htt

As with the script parameters, a walkthrough of the training script is provided in the [Text-to-image](text2image#training-script) training guide. Instead, this guide focuses on the LoRA-relevant parts of the script.

The script begins by adding the [new LoRA weights](https://github.com/huggingface/diffusers/blob/dd9a5caf61f04d11c0fa9f3947b69ab0010c9a0f/examples/text_to_image/train_text_to_image_lora.py#L447) to the attention layers. This involves correctly configuring the weight size for each block in the UNet. You'll see the `rank` parameter is used to create the [`~models.attention_processor.LoRAAttnProcessor`]:
<hfoptions id="lora">
<hfoption id="UNet">

Diffusers uses [`~peft.LoraConfig`] from the [PEFT](https://hf.co/docs/peft) library to set up the parameters of the LoRA adapter such as the rank, alpha, and which modules to insert the LoRA weights into. The adapter is added to the UNet, and only the LoRA layers are filtered for optimization in `lora_layers`.

```py
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)

unet.add_adapter(unet_lora_config)
lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
```

</hfoption>
<hfoption id="text encoder">

Diffusers also supports finetuning the text encoder with LoRA from the [PEFT](https://hf.co/docs/peft) library when necessary, such as when finetuning Stable Diffusion XL (SDXL). The [`~peft.LoraConfig`] is used to configure the parameters of the LoRA adapter, which is then added to the text encoder, and only the LoRA layers are filtered for training.

```py
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]

lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=args.rank,
)

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)

text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
```

The [optimizer](https://github.com/huggingface/diffusers/blob/dd9a5caf61f04d11c0fa9f3947b69ab0010c9a0f/examples/text_to_image/train_text_to_image_lora.py#L519) is initialized with the `lora_layers` because these are the only weights that'll be optimized:
</hfoption>
</hfoptions>

The [optimizer](https://github.com/huggingface/diffusers/blob/e4b8f173b97731686e290b2eb98e7f5df2b1b322/examples/text_to_image/train_text_to_image_lora.py#L529) is initialized with the `lora_layers` because these are the only weights that'll be optimized:

```py
optimizer = optimizer_cls(
lora_layers.parameters(),
lora_layers,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
6 changes: 6 additions & 0 deletions docs/source/en/using-diffusers/stable_diffusion_jax_how_to.md
@@ -217,3 +217,9 @@ Check your image dimensions to see if they're correct:
images.shape
# (8, 1, 512, 512, 3)
```
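
From here, a small follow-up sketch (assuming the `images` array and the Flax `pipeline` object created earlier in this guide) merges the device and batch dimensions and converts the result to PIL:

```py
import numpy as np

# `images` and `pipeline` come from the pmapped generation earlier in this guide
# (devices, batch, height, width, channels) -> (devices * batch, height, width, channels)
images = np.asarray(images).reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])

# Convert to PIL images for saving or display
pil_images = pipeline.numpy_to_pil(images)
pil_images[0].save("generated.png")
```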

## Resources

To learn more about how JAX works with Stable Diffusion, you may be interested in reading:

* [Accelerating Stable Diffusion XL Inference with JAX on Cloud TPU v5e](https://hf.co/blog/sdxl_jax)
7 changes: 5 additions & 2 deletions examples/community/README.md
@@ -3561,14 +3561,17 @@ pipe.disable_style_aligned()

This pipeline adds experimental support for the image-to-video task using AnimateDiff. Refer to [this](https://github.com/huggingface/diffusers/pull/6328) PR for more examples and results.

This pipeline relies on a "hack" discovered by the community that allows the generation of videos given an input image with AnimateDiff. It works by creating a copy of the image `num_frames` times and progressively adding more noise to the image based on the strength and latent interpolation method.

```py
import torch
from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
from diffusers.utils import export_to_gif, load_image

model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
pipe = DiffusionPipeline.from_pretrained(model_id, motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", beta_schedule="linear", steps_offset=1)

image = load_image("snail.png")
output = pipe(
4 changes: 2 additions & 2 deletions examples/community/lpw_stable_diffusion_xl.py
@@ -1766,15 +1766,15 @@ def __call__(

# 4. Prepare timesteps
def denoising_value_valid(dnv):
return isinstance(self.denoising_end, float) and 0 < dnv < 1
return isinstance(dnv, float) and 0 < dnv < 1

timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
if image is not None:
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps,
strength,
device,
denoising_start=self.denoising_start if denoising_value_valid else None,
denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
)

# check that number of inference steps is not < 1 - as this doesn't make sense
53 changes: 48 additions & 5 deletions examples/community/pipeline_animatediff_controlnet.py
@@ -24,7 +24,7 @@

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UNetMotionModel
from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.models.unets.unet_motion_model import MotionAdapter
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
@@ -382,6 +382,41 @@ def encode_image(self, image, device, num_images_per_prompt):
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
):
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]

if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)

image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)

if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)

image_embeds.append(single_image_embeds)
else:
image_embeds = ip_adapter_image_embeds
return image_embeds

# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
@@ -767,6 +802,7 @@ def __call__(
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[PipelineImageInput] = None,
conditioning_frames: Optional[List[PipelineImageInput]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -821,6 +857,9 @@ def __call__(
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
conditioning_frames (`List[PipelineImageInput]`, *optional*):
The ControlNet input condition to provide guidance to the `unet` for generation. If multiple ControlNets
are specified, images must be passed as a list such that each element of the list can be correctly
@@ -965,9 +1004,9 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

if ip_adapter_image is not None:
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
)

if isinstance(controlnet, ControlNetModel):
conditioning_frames = self.prepare_image(
@@ -1023,7 +1062,11 @@ def __call__(
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

# 7. Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
added_cond_kwargs = (
{"image_embeds": image_embeds}
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
else None
)

# 7.1 Create tensor stating which controlnets to keep
controlnet_keep = []