diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py
index 0d2f5f74d325..6538522756f9 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -62,6 +62,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
             can be fine-tuned / trained to a lower range without loosing too much precision in which case
             `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+        mid_block_add_attention (`bool`, *optional*, defaults to `True`):
+            If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to `False`, the
+            mid_block will only have resnet blocks.
     """

     _supports_gradient_checkpointing = True
@@ -87,6 +90,7 @@ def __init__(
         force_upcast: float = True,
         use_quant_conv: bool = True,
         use_post_quant_conv: bool = True,
+        mid_block_add_attention: bool = True,
     ):
         super().__init__()

@@ -100,6 +104,7 @@ def __init__(
             act_fn=act_fn,
             norm_num_groups=norm_num_groups,
             double_z=True,
+            mid_block_add_attention=mid_block_add_attention,
         )

         # pass init params to Decoder
@@ -111,6 +116,7 @@ def __init__(
             layers_per_block=layers_per_block,
             norm_num_groups=norm_num_groups,
             act_fn=act_fn,
+            mid_block_add_attention=mid_block_add_attention,
         )

         self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
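
A minimal usage sketch of the new flag, assuming a toy configuration (the channel sizes and block types below are illustrative choices, not the SD/SD-XL defaults):

from diffusers import AutoencoderKL

# Toy VAE whose Encoder/Decoder mid blocks are built without attention:
# with mid_block_add_attention=False they contain only resnet blocks.
vae = AutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
    up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
    block_out_channels=(32, 64),
    latent_channels=4,
    mid_block_add_attention=False,  # new flag introduced by this diff
)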