Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge changes #191

Merged
merged 10 commits into from
Nov 28, 2024
9 changes: 6 additions & 3 deletions docs/source/en/api/pipelines/cogvideox.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,17 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).

There are three official CogVideoX checkpoints for text-to-video and video-to-video.

| checkpoints | recommended inference dtype |
|---|---|
|:---:|:---:|
| [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b) | torch.float16 |
| [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b) | torch.bfloat16 |
| [`THUDM/CogVideoX1.5-5b`](https://huggingface.co/THUDM/CogVideoX1.5-5b) | torch.bfloat16 |

There are two official CogVideoX checkpoints available for image-to-video.

| checkpoints | recommended inference dtype |
|---|---|
|:---:|:---:|
| [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V) | torch.bfloat16 |
| [`THUDM/CogVideoX-1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-1.5-5b-I2V) | torch.bfloat16 |

Expand All @@ -48,8 +50,9 @@ For the CogVideoX 1.5 series:
- Both T2V and I2V models support generation with 81 and 161 frames and work best at this value. Exporting videos at 16 FPS is recommended.

There are two official CogVideoX checkpoints that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team).

| checkpoints | recommended inference dtype |
|---|---|
|:---:|:---:|
| [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose) | torch.bfloat16 |
| [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose) | torch.bfloat16 |

Expand Down
12 changes: 12 additions & 0 deletions docs/source/en/api/pipelines/flux.md
Original file line number Diff line number Diff line change
Expand Up @@ -333,3 +333,15 @@ image.save("flux-fp8-dev.png")
[[autodoc]] FluxControlImg2ImgPipeline
- all
- __call__

## FluxPriorReduxPipeline

[[autodoc]] FluxPriorReduxPipeline
- all
- __call__

## FluxFillPipeline

[[autodoc]] FluxFillPipeline
- all
- __call__
128 changes: 114 additions & 14 deletions examples/community/README.md

Large diffs are not rendered by default.

79 changes: 63 additions & 16 deletions examples/community/regional_prompting_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@

import torch
import torchvision.transforms.functional as FF
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import USE_PEFT_BACKEND


try:
from compel import Compel
except ImportError:
Compel = None

KBASE = "ADDBASE"
KCOMM = "ADDCOMM"
KBRK = "BREAK"

Expand All @@ -34,6 +34,11 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):

Optional
rp_args["save_mask"]: True/False (save masks in prompt mode)
rp_args["power"]: int (power for attention maps in prompt mode)
rp_args["base_ratio"]:
float (Sets the ratio of the base prompt)
ex) 0.2 (20%*BASE_PROMPT + 80%*REGION_PROMPT)
[Use base prompt](https://github.com/hako-mikan/sd-webui-regional-prompter?tab=readme-ov-file#use-base-prompt)

Pipeline for text-to-image generation using Stable Diffusion.

Expand Down Expand Up @@ -70,6 +75,7 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__(
Expand All @@ -80,6 +86,7 @@ def __init__(
scheduler,
safety_checker,
feature_extractor,
image_encoder,
requires_safety_checker,
)
self.register_modules(
Expand All @@ -90,6 +97,7 @@ def __init__(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)

@torch.no_grad()
Expand All @@ -110,17 +118,40 @@ def __call__(
rp_args: Dict[str, str] = None,
):
active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
use_base = KBASE in prompt[0] if isinstance(prompt, list) else KBASE in prompt
if negative_prompt is None:
negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)

device = self._execution_device
regions = 0

self.base_ratio = float(rp_args["base_ratio"]) if "base_ratio" in rp_args else 0.0
self.power = int(rp_args["power"]) if "power" in rp_args else 1

prompts = prompt if isinstance(prompt, list) else [prompt]
n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt]
n_prompts = negative_prompt if isinstance(prompt, list) else [negative_prompt]
self.batch = batch = num_images_per_prompt * len(prompts)

if use_base:
bases = prompts.copy()
n_bases = n_prompts.copy()

for i, prompt in enumerate(prompts):
parts = prompt.split(KBASE)
if len(parts) == 2:
bases[i], prompts[i] = parts
elif len(parts) > 2:
raise ValueError(f"Multiple instances of {KBASE} found in prompt: {prompt}")
for i, prompt in enumerate(n_prompts):
n_parts = prompt.split(KBASE)
if len(n_parts) == 2:
n_bases[i], n_prompts[i] = n_parts
elif len(n_parts) > 2:
raise ValueError(f"Multiple instances of {KBASE} found in negative prompt: {prompt}")

all_bases_cn, _ = promptsmaker(bases, num_images_per_prompt)
all_n_bases_cn, _ = promptsmaker(n_bases, num_images_per_prompt)

all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)

Expand All @@ -137,8 +168,16 @@ def getcompelembs(prps):

conds = getcompelembs(all_prompts_cn)
unconds = getcompelembs(all_n_prompts_cn)
embs = getcompelembs(prompts)
n_embs = getcompelembs(n_prompts)
base_embs = getcompelembs(all_bases_cn) if use_base else None
base_n_embs = getcompelembs(all_n_bases_cn) if use_base else None
# When using base, it seems more reasonable to use base prompts as prompt_embeddings rather than regional prompts
embs = getcompelembs(prompts) if not use_base else base_embs
n_embs = getcompelembs(n_prompts) if not use_base else base_n_embs

if use_base and self.base_ratio > 0:
conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds

prompt = negative_prompt = None
else:
conds = self.encode_prompt(prompts, device, 1, True)[0]
Expand All @@ -147,6 +186,18 @@ def getcompelembs(prps):
if equal
else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
)

if use_base and self.base_ratio > 0:
base_embs = self.encode_prompt(bases, device, 1, True)[0]
base_n_embs = (
self.encode_prompt(n_bases, device, 1, True)[0]
if equal
else self.encode_prompt(all_n_bases_cn, device, 1, True)[0]
)

conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds

embs = n_embs = None

if not active:
Expand Down Expand Up @@ -225,8 +276,6 @@ def forward(

residual = hidden_states

args = () if USE_PEFT_BACKEND else (scale,)

if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

Expand All @@ -247,16 +296,15 @@ def forward(
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

args = () if USE_PEFT_BACKEND else (scale,)
query = attn.to_q(hidden_states, *args)
query = attn.to_q(hidden_states)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
Expand All @@ -283,7 +331,7 @@ def forward(
hidden_states = hidden_states.to(query.dtype)

# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

Expand Down Expand Up @@ -410,9 +458,9 @@ def promptsmaker(prompts, batch):
add = ""
if KCOMM in prompt:
add, prompt = prompt.split(KCOMM)
add = add + " "
prompts = prompt.split(KBRK)
out_p.append([add + p for p in prompts])
add = add.strip() + " "
prompts = [p.strip() for p in prompt.split(KBRK)]
out_p.append([add + p for i, p in enumerate(prompts)])
out = [None] * batch * len(out_p[0]) * len(out_p)
for p, prs in enumerate(out_p): # inputs prompts
for r, pr in enumerate(prs): # prompts for regions
Expand Down Expand Up @@ -449,7 +497,6 @@ def startend(cells, array):
add = []
startend(add, inratios[1:])
icells.append(add)

return ocells, icells, sum(len(cell) for cell in icells)


Expand Down
Loading
Loading