
Add StableDiffusion3PAGImg2Img Pipeline + Fix SD3 Unconditional PAG #9932

Open: wants to merge 22 commits into main

Conversation

@painebenjamin (Contributor)

What does this PR do?

This PR does two things:

  1. Adds StableDiffusion3PAGImg2ImgPipeline, along with tests and documentation.
  2. Fixes a bug with SD3 + PAG in general that caused unconditional PAG generation to fail (PAGJointAttnProcessor2_0 raised a TypeError when it received attention_mask as a keyword argument).

Who can review?

Anyone who is interested, but particularly @yiyixuxu and @asomoza

Images

All images were generated with SD3-Medium-Diffusers.

Test output (PAG+CFG): [image: test_pag_cfg]

Test output (PAG only): [image: test_pag_uncond]

Docstring example output (PAG+CFG): [image: output]

@jeongiin (Contributor) left a comment:
Hi! I found this work interesting while reading it and noticed what seemed to be a typo, so I removed it.

@yiyixuxu (Collaborator):

cc @rootonchair if you want to give a review!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@rootonchair (Contributor) left a comment:

I have just finished looking through the pipeline code. Just a small displacement. It seems like the style is failing. Could you run 'make style' and 'make quality' to fix it?

Good work. I will proceed reviewing the unittest later

@painebenjamin (Contributor, Author):

> I have just finished looking through the pipeline code. Just a small displacement. It seems like the style is failing. Could you run 'make style' and 'make quality' to fix it?
>
> Good work. I will proceed reviewing the unittest later

Thank you @rootonchair, all set on the style/quality fix!

@@ -1171,6 +1171,7 @@ def __call__(
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,

A Collaborator commented on this diff:
why do we add this here? it is not used, no?

@painebenjamin (Contributor, Author) replied on Nov 21, 2024:

That's the second part of what I wrote above: when using SD3+PAG and forgoing CFG (e.g. calling a PAG pipeline with guidance_scale=0), PAGJointAttnProcessor2_0 is used instead of PAGCFGJointAttnProcessor2_0, and the following error is produced:

StableDiffusion3PAGImg2ImgPipelineIntegrationTests.test_pag_uncond 
__________________________________________________

self = <tests.pipelines.pag.test_pag_sd3_img2img.StableDiffusion3PAGImg2ImgPipelineIntegrationTests testMethod=test_pag_uncond>

    def test_pag_uncond(self):
        pipeline = AutoPipelineForImage2Image.from_pretrained(
            self.repo_id, enable_pag=True, torch_dtype=torch.float16, pag_applied_layers=["blocks.(4|17)"]
        )
        pipeline.enable_model_cpu_offload()
        pipeline.set_progress_bar_config(disable=None)
    
        inputs = self.get_inputs(torch_device, guidance_scale=0.0, pag_scale=1.8)
>       image = pipeline(**inputs).images

tests/pipelines/pag/test_pag_sd3_img2img.py:261: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/utils/_contextlib.py:116: in decorate_context
    return func(*args, **kwargs)
src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py:975: in __call__
    noise_pred = self.transformer(
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/nn/modules/module.py:1553: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/nn/modules/module.py:1562: in _call_impl
    return forward_call(*args, **kwargs)
../miniconda3/envs/taproot/lib/python3.10/site-packages/accelerate/hooks.py:170: in new_forward
    output = module._old_forward(*args, **kwargs)
src/diffusers/models/transformers/transformer_sd3.py:346: in forward
    encoder_hidden_states, hidden_states = block(
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/nn/modules/module.py:1553: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/nn/modules/module.py:1562: in _call_impl
    return forward_call(*args, **kwargs)
src/diffusers/models/attention.py:208: in forward
    attn_output, context_attn_output = self.attn(
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/nn/modules/module.py:1553: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/nn/modules/module.py:1562: in _call_impl
    return forward_call(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = Attention(
  (to_q): Linear(in_features=1536, out_features=1536, bias=True)
  (to_k): Linear(in_features=1536, out_fea...ue)
    (1): Dropout(p=0.0, inplace=False)
  )
  (to_add_out): Linear(in_features=1536, out_features=1536, bias=True)
)
hidden_states = tensor([[[-0.0430, -3.7031,  0.2078,  ...,  0.3115,  0.0703,  0.0383],
         [ 0.0179, -2.4727,  0.1594,  ..., -0.0...,
         [-0.0490, -0.3691,  0.2568,  ..., -1.0303, -0.0298,  0.5527]]],
       device='cuda:0', dtype=torch.float16)
encoder_hidden_states = tensor([[[-0.0503,  0.0515, -0.0623,  ..., -0.0044, -0.0186, -0.0752],
         [ 0.1860, -0.2595,  0.0835,  ...,  0.1...,
         [ 0.6958, -0.4875, -0.1246,  ...,  0.2664, -0.1700,  0.0030]]],
       device='cuda:0', dtype=torch.float16)
attention_mask = None, cross_attention_kwargs = {}, unused_kwargs = []

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        r"""
        The forward method of the `Attention` class.
    
        Args:
            hidden_states (`torch.Tensor`):
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):
                The attention mask to use. If `None`, no mask is applied.
            **cross_attention_kwargs:
                Additional keyword arguments to pass along to the cross attention.
    
        Returns:
            `torch.Tensor`: The output of the attention layer.
        """
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
    
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        quiet_attn_parameters = {"ip_adapter_masks"}
        unused_kwargs = [
            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
        ]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
    
>       return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
E       TypeError: PAGJointAttnProcessor2_0.__call__() got an unexpected keyword argument 'attention_mask'

src/diffusers/models/attention_processor.py:530: TypeError
======================================================================= short test summary info =======================================================================
FAILED tests/pipelines/pag/test_pag_sd3_img2img.py::StableDiffusion3PAGImg2ImgPipelineIntegrationTests::test_pag_uncond - TypeError: PAGJointAttnProcessor2_0.__call__() got an unexpected keyword argument 'attention_mask'

An alternative to adding this particular keyword argument would be to catch all other keyword arguments with **kwargs, which there is precedent for in other attention processors, but I generally default to being more restrictive and not less. For whatever it's worth, PAGCFGJointAttnProcessor2_0 does both of those things; it captures attention_mask and does nothing with it, and also has *args and **kwargs.

If there is any particular way that you think is the most in-line with the rest of the codebase, I'll be happy to adjust.
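The failure mode can be reproduced outside diffusers with a minimal sketch. The class and function names below are illustrative stand-ins, not the diffusers implementations: the dispatcher filters extra cross_attention_kwargs against the processor's signature (as Attention.forward does in the traceback above), but passes attention_mask unconditionally, so a processor whose __call__ does not declare that parameter raises a TypeError.

```python
import inspect

# Hypothetical stand-ins for the attention processors (names illustrative):
# one declares `attention_mask` in its signature, the other does not.
class ProcessorWithoutMask:
    def __call__(self, attn, hidden_states, encoder_hidden_states=None):
        return hidden_states

class ProcessorWithMask:
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        return hidden_states

def attention_forward(processor, hidden_states, **cross_attention_kwargs):
    # Mirrors the dispatch shown in the traceback: extra kwargs are filtered
    # against the processor's signature, but attention_mask is passed
    # unconditionally, so a processor that does not declare it will fail.
    attn_parameters = set(inspect.signature(processor.__call__).parameters.keys())
    cross_attention_kwargs = {
        k: v for k, v in cross_attention_kwargs.items() if k in attn_parameters
    }
    return processor(
        None,  # stands in for the Attention module; unused in this sketch
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        **cross_attention_kwargs,
    )

# The processor without the parameter reproduces the TypeError above;
# adding `attention_mask` to the signature (this PR's approach) resolves it.
try:
    attention_forward(ProcessorWithoutMask(), "h")
    raised_type_error = False
except TypeError:
    raised_type_error = True
```

Because the filtering only applies to **cross_attention_kwargs and not to the explicitly forwarded attention_mask, accepting (and ignoring) the argument in the processor signature is the narrowest fix, consistent with the author's preference for being restrictive rather than adding a catch-all **kwargs.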

@rootonchair (Contributor):

FAILED tests/pipelines/pag/test_pag_sd3_img2img.py::StableDiffusion3PAGImg2ImgPipelineFastTests::test_pag_inference - AssertionError: 0.0720449086320496 not less than or equal to 0.001

@painebenjamin could you update the expected values?

@painebenjamin (Contributor, Author) replied on Nov 22, 2024:

> FAILED tests/pipelines/pag/test_pag_sd3_img2img.py::StableDiffusion3PAGImg2ImgPipelineFastTests::test_pag_inference - AssertionError: 0.0720449086320496 not less than or equal to 0.001
>
> @painebenjamin could you update the expected values?

All set! Something was off in my environment, but restarting with a fresh conda env got my machine reproducing the same values as the failed run. I also re-checked the slow tests and the CFG one was off in the fresh environment so I fixed that one too, and re-ran make style and make quality.
