Skip to content

Commit

Permalink
[Community Pipeline] Add some feature for regional prompting pipeline (
Browse files Browse the repository at this point in the history
…huggingface#9874)

* [Fix] fix bugs of  regional_prompting pipeline

* [Feat] add base prompt feature

* [Fix] fix __init__ pipeline error

* [Fix] delete unused args

* [Fix] improve string handling

* [Docs] docs to use_base in regional_prompting

* make style

---------

Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
cjkangme and sayakpaul authored Nov 28, 2024
1 parent e44fc75 commit 69c83d6
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 16 deletions.
15 changes: 15 additions & 0 deletions examples/community/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3379,6 +3379,20 @@ best quality, 3persons in garden, a boy blue shirt BREAK
best quality, 3persons in garden, an old man red suit
```

### Use base prompt

You can use a base prompt to apply the prompt to all areas. You can set a base prompt by adding `ADDBASE` at the end. Base prompts can also be combined with common prompts, but the base prompt must be specified first.

```
2d animation style ADDBASE
masterpiece, high quality ADDCOMM
(blue sky)++ BREAK
green hair twintail BREAK
book shelf BREAK
messy desk BREAK
orange++ dress and sofa
```

### Negative prompt

Negative prompts are equally effective across all regions, but it is possible to set region-specific prompts for negative prompts as well. The number of BREAKs must be the same as the number of prompts. If the number of prompts does not match, the negative prompts will be used without being divided into regions.
Expand Down Expand Up @@ -3409,6 +3423,7 @@ pipe(prompt=prompt, rp_args=rp_args)
### Optional Parameters

- `save_mask`: In `Prompt` mode, choose whether to output the generated mask along with the image. The default is `False`.
- `base_ratio`: Used with `ADDBASE`. Sets the ratio of the base prompt; if base ratio is set to 0.2, then resulting images will consist of `20%*BASE_PROMPT + 80%*REGION_PROMPT`

The Pipeline supports `compel` syntax. Input prompts using the `compel` structure will be automatically applied and processed.

Expand Down
79 changes: 63 additions & 16 deletions examples/community/regional_prompting_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@

import torch
import torchvision.transforms.functional as FF
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import USE_PEFT_BACKEND


try:
from compel import Compel
except ImportError:
Compel = None

KBASE = "ADDBASE"
KCOMM = "ADDCOMM"
KBRK = "BREAK"

Expand All @@ -34,6 +34,11 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
Optional
rp_args["save_mask"]: True/False (save masks in prompt mode)
rp_args["power"]: int (power for attention maps in prompt mode)
rp_args["base_ratio"]:
float (Sets the ratio of the base prompt)
ex) 0.2 (20%*BASE_PROMPT + 80%*REGION_PROMPT)
[Use base prompt](https://github.com/hako-mikan/sd-webui-regional-prompter?tab=readme-ov-file#use-base-prompt)
Pipeline for text-to-image generation using Stable Diffusion.
Expand Down Expand Up @@ -70,6 +75,7 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__(
Expand All @@ -80,6 +86,7 @@ def __init__(
scheduler,
safety_checker,
feature_extractor,
image_encoder,
requires_safety_checker,
)
self.register_modules(
Expand All @@ -90,6 +97,7 @@ def __init__(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)

@torch.no_grad()
Expand All @@ -110,17 +118,40 @@ def __call__(
rp_args: Dict[str, str] = None,
):
active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
use_base = KBASE in prompt[0] if isinstance(prompt, list) else KBASE in prompt
if negative_prompt is None:
negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)

device = self._execution_device
regions = 0

self.base_ratio = float(rp_args["base_ratio"]) if "base_ratio" in rp_args else 0.0
self.power = int(rp_args["power"]) if "power" in rp_args else 1

prompts = prompt if isinstance(prompt, list) else [prompt]
n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt]
n_prompts = negative_prompt if isinstance(prompt, list) else [negative_prompt]
self.batch = batch = num_images_per_prompt * len(prompts)

if use_base:
bases = prompts.copy()
n_bases = n_prompts.copy()

for i, prompt in enumerate(prompts):
parts = prompt.split(KBASE)
if len(parts) == 2:
bases[i], prompts[i] = parts
elif len(parts) > 2:
raise ValueError(f"Multiple instances of {KBASE} found in prompt: {prompt}")
for i, prompt in enumerate(n_prompts):
n_parts = prompt.split(KBASE)
if len(n_parts) == 2:
n_bases[i], n_prompts[i] = n_parts
elif len(n_parts) > 2:
raise ValueError(f"Multiple instances of {KBASE} found in negative prompt: {prompt}")

all_bases_cn, _ = promptsmaker(bases, num_images_per_prompt)
all_n_bases_cn, _ = promptsmaker(n_bases, num_images_per_prompt)

all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)

Expand All @@ -137,8 +168,16 @@ def getcompelembs(prps):

conds = getcompelembs(all_prompts_cn)
unconds = getcompelembs(all_n_prompts_cn)
embs = getcompelembs(prompts)
n_embs = getcompelembs(n_prompts)
base_embs = getcompelembs(all_bases_cn) if use_base else None
base_n_embs = getcompelembs(all_n_bases_cn) if use_base else None
# When using base, it seems more reasonable to use base prompts as prompt_embeddings rather than regional prompts
embs = getcompelembs(prompts) if not use_base else base_embs
n_embs = getcompelembs(n_prompts) if not use_base else base_n_embs

if use_base and self.base_ratio > 0:
conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds

prompt = negative_prompt = None
else:
conds = self.encode_prompt(prompts, device, 1, True)[0]
Expand All @@ -147,6 +186,18 @@ def getcompelembs(prps):
if equal
else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
)

if use_base and self.base_ratio > 0:
base_embs = self.encode_prompt(bases, device, 1, True)[0]
base_n_embs = (
self.encode_prompt(n_bases, device, 1, True)[0]
if equal
else self.encode_prompt(all_n_bases_cn, device, 1, True)[0]
)

conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds

embs = n_embs = None

if not active:
Expand Down Expand Up @@ -225,8 +276,6 @@ def forward(

residual = hidden_states

args = () if USE_PEFT_BACKEND else (scale,)

if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

Expand All @@ -247,16 +296,15 @@ def forward(
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

args = () if USE_PEFT_BACKEND else (scale,)
query = attn.to_q(hidden_states, *args)
query = attn.to_q(hidden_states)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
Expand All @@ -283,7 +331,7 @@ def forward(
hidden_states = hidden_states.to(query.dtype)

# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

Expand Down Expand Up @@ -410,9 +458,9 @@ def promptsmaker(prompts, batch):
add = ""
if KCOMM in prompt:
add, prompt = prompt.split(KCOMM)
add = add + " "
prompts = prompt.split(KBRK)
out_p.append([add + p for p in prompts])
add = add.strip() + " "
prompts = [p.strip() for p in prompt.split(KBRK)]
out_p.append([add + p for i, p in enumerate(prompts)])
out = [None] * batch * len(out_p[0]) * len(out_p)
for p, prs in enumerate(out_p): # inputs prompts
for r, pr in enumerate(prs): # prompts for regions
Expand Down Expand Up @@ -449,7 +497,6 @@ def startend(cells, array):
add = []
startend(add, inratios[1:])
icells.append(add)

return ocells, icells, sum(len(cell) for cell in icells)


Expand Down

0 comments on commit 69c83d6

Please sign in to comment.