Add StableDiffusion3PAGImg2Img Pipeline + Fix SD3 Unconditional PAG #9932
base: main
Conversation
Hi! I found this work interesting while reading it and noticed what seemed to be a typo, so I removed it.
cc @rootonchair if you want to give a review!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I have just finished looking through the pipeline code. Just one small thing: it seems like the style check is failing. Could you run `make style` and `make quality` to fix it?
Good work. I will proceed with reviewing the unit tests later.
Co-authored-by: Vinh H. Pham <[email protected]>
Thank you @rootonchair, all set on the style/quality fix!
```diff
@@ -1171,6 +1171,7 @@ def __call__(
     attn: Attention,
     hidden_states: torch.FloatTensor,
     encoder_hidden_states: torch.FloatTensor = None,
+    attention_mask: Optional[torch.FloatTensor] = None,
```
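For context, a sketch of what the processor signature looks like with this change applied (abbreviated: the parameter list comes from the diff above, the body is elided, and the real implementation lives in `src/diffusers/models/attention_processor.py`):

```python
from typing import Optional

import torch


class PAGJointAttnProcessor2_0:
    # Sketch only; see the actual diff for the full implementation.
    def __call__(
        self,
        attn,  # the diffusers `Attention` module invoking this processor
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,  # newly accepted; intentionally unused
    ) -> torch.FloatTensor:
        ...
```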
why do we add this here? it is not used, no?
That's the second part of what I wrote above: when using SD3+PAG and forgoing CFG (e.g. calling a PAG pipeline with `guidance_scale=0`), `PAGJointAttnProcessor2_0` is used instead of `PAGCFGJointAttnProcessor2_0`, and the following error is produced:
```
StableDiffusion3PAGImg2ImgPipelineIntegrationTests.test_pag_uncond
__________________________________________________
self = <tests.pipelines.pag.test_pag_sd3_img2img.StableDiffusion3PAGImg2ImgPipelineIntegrationTests testMethod=test_pag_uncond>
def test_pag_uncond(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(
self.repo_id, enable_pag=True, torch_dtype=torch.float16, pag_applied_layers=["blocks.(4|17)"]
)
pipeline.enable_model_cpu_offload()
pipeline.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device, guidance_scale=0.0, pag_scale=1.8)
> image = pipeline(**inputs).images
tests/pipelines/pag/test_pag_sd3_img2img.py:261:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/utils/_contextlib.py:116: in decorate_context
return func(*args, **kwargs)
src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py:975: in __call__
noise_pred = self.transformer(
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/nn/modules/module.py:1553: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/nn/modules/module.py:1562: in _call_impl
return forward_call(*args, **kwargs)
../miniconda3/envs/taproot/lib/python3.10/site-packages/accelerate/hooks.py:170: in new_forward
output = module._old_forward(*args, **kwargs)
src/diffusers/models/transformers/transformer_sd3.py:346: in forward
encoder_hidden_states, hidden_states = block(
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/nn/modules/module.py:1553: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/nn/modules/module.py:1562: in _call_impl
return forward_call(*args, **kwargs)
src/diffusers/models/attention.py:208: in forward
attn_output, context_attn_output = self.attn(
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/nn/modules/module.py:1553: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/nn/modules/module.py:1562: in _call_impl
return forward_call(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Attention(
(to_q): Linear(in_features=1536, out_features=1536, bias=True)
(to_k): Linear(in_features=1536, out_fea...ue)
(1): Dropout(p=0.0, inplace=False)
)
(to_add_out): Linear(in_features=1536, out_features=1536, bias=True)
)
hidden_states = tensor([[[-0.0430, -3.7031, 0.2078, ..., 0.3115, 0.0703, 0.0383],
[ 0.0179, -2.4727, 0.1594, ..., -0.0...,
[-0.0490, -0.3691, 0.2568, ..., -1.0303, -0.0298, 0.5527]]],
device='cuda:0', dtype=torch.float16)
encoder_hidden_states = tensor([[[-0.0503, 0.0515, -0.0623, ..., -0.0044, -0.0186, -0.0752],
[ 0.1860, -0.2595, 0.0835, ..., 0.1...,
[ 0.6958, -0.4875, -0.1246, ..., 0.2664, -0.1700, 0.0030]]],
device='cuda:0', dtype=torch.float16)
attention_mask = None, cross_attention_kwargs = {}, unused_kwargs = []
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**cross_attention_kwargs,
) -> torch.Tensor:
r"""
The forward method of the `Attention` class.
Args:
hidden_states (`torch.Tensor`):
The hidden states of the query.
encoder_hidden_states (`torch.Tensor`, *optional*):
The hidden states of the encoder.
attention_mask (`torch.Tensor`, *optional*):
The attention mask to use. If `None`, no mask is applied.
**cross_attention_kwargs:
Additional keyword arguments to pass along to the cross attention.
Returns:
`torch.Tensor`: The output of the attention layer.
"""
# The `Attention` class can call different attention processors / attention functions
# here we simply pass along all tensors to the selected processor class
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks"}
unused_kwargs = [
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
]
if len(unused_kwargs) > 0:
logger.warning(
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
> return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
E TypeError: PAGJointAttnProcessor2_0.__call__() got an unexpected keyword argument 'attention_mask'
src/diffusers/models/attention_processor.py:530: TypeError
======================================================================= short test summary info =======================================================================
FAILED tests/pipelines/pag/test_pag_sd3_img2img.py::StableDiffusion3PAGImg2ImgPipelineIntegrationTests::test_pag_uncond - TypeError: PAGJointAttnProcessor2_0.__call__() got an unexpected keyword argument 'attention_mask'
```
An alternative to adding this particular keyword argument would be to catch all other keyword arguments with `**kwargs`, for which there is precedent in other attention processors, but I generally default to being more restrictive rather than less. For whatever it's worth, `PAGCFGJointAttnProcessor2_0` does both of those things: it accepts `attention_mask` and does nothing with it, and also has `*args` and `**kwargs`. Both options are sketched below.
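To make the trade-off concrete, here is a sketch of the two signature styles (the class names are illustrative, not the actual processors):

```python
from typing import Optional

import torch


class RestrictiveProcessor:
    # The approach taken in this PR: name `attention_mask` explicitly and
    # ignore it, so a genuinely unexpected keyword still raises a TypeError.
    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        ...


class PermissiveProcessor:
    # The alternative with precedent elsewhere (and in
    # PAGCFGJointAttnProcessor2_0): silently swallow any extra arguments.
    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        ...
```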
If there is any particular way that you think is the most in-line with the rest of the codebase, I'll be happy to adjust.
@painebenjamin could you update the expected values?
All set! Something was off in my environment, but restarting with a fresh conda env got my machine reproducing the same values as the failed run. I also re-checked the slow tests; the CFG one was off in the fresh environment as well, so I fixed that one too and re-ran the tests.
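For reference, the "expected values" being updated are pinned image slices; the slow tests assert against them with a pattern roughly like the sketch below (the helper name and tolerance are illustrative, not the actual test code):

```python
import numpy as np


def check_expected_slice(image: np.ndarray, expected_slice: np.ndarray, tol: float = 1e-3) -> None:
    # `image` is the pipeline output as an array of shape (batch, height, width, channels);
    # compare a flattened 3x3 corner patch of the last channel against the pinned values.
    image_slice = image[0, -3:, -3:, -1].flatten()
    max_diff = np.abs(image_slice - expected_slice).max()
    assert max_diff < tol, f"max diff {max_diff} exceeds tolerance {tol}"
```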
What does this PR do?
This PR does two things:

1. Adds `StableDiffusion3PAGImg2ImgPipeline`, tests, and documentation (see the usage sketch below).
2. Fixes SD3 unconditional PAG (`PAGJointAttnProcessor2_0` was failing because it received `attention_mask` as a keyword argument).

Before submitting

- This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.

Who can review?
Anyone who is interested, but particularly @yiyixuxu and @asomoza
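For anyone wanting to try it, a minimal usage sketch of the new pipeline, mirroring the integration test above (the repo id and input image URL are illustrative assumptions):

```python
import torch
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image

# Assumed repo id for SD3-Medium-Diffusers; the PAG options mirror the test above.
pipe = AutoPipelineForImage2Image.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    enable_pag=True,
    pag_applied_layers=["blocks.(4|17)"],
    torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()

init_image = load_image("https://example.com/input.png")  # placeholder URL

# guidance_scale=0.0 forgoes CFG, exercising the PAGJointAttnProcessor2_0
# path this PR fixes; pag_scale controls perturbed-attention guidance.
image = pipe(
    prompt="a photo of an astronaut riding a horse on mars",
    image=init_image,
    guidance_scale=0.0,
    pag_scale=1.8,
).images[0]
```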
Images
All outputs below were generated with SD3-Medium-Diffusers.
Test output - PAG+CFG
Test output - PAG only
Docstring example output (PAG+CFG):