From dd4459ad795230ee6d70dc270484bb9b523db9d4 Mon Sep 17 00:00:00 2001
From: YiYi Xu
Date: Tue, 9 Jan 2024 21:38:05 -1000
Subject: [PATCH] [Refactor] splitting `ResnetBlock2D` into multiple blocks (#6166)

---------

Co-authored-by: yiyixuxu
Co-authored-by: Patrick von Platen
Co-authored-by: Sayak Paul
---
 src/diffusers/models/resnet.py                 | 207 ++++++++++++--
 src/diffusers/models/unet_2d_blocks.py         | 268 ++++++++++++------
 .../versatile_diffusion/modeling_text_unet.py  |  85 ++++--
 3 files changed, 418 insertions(+), 142 deletions(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index bbfb71ca3fbf..3b8718f39940 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -42,6 +42,156 @@
 )


+class ResnetBlockCondNorm2D(nn.Module):
+    r"""
+    A Resnet block that uses a normalization layer that incorporates conditioning information.
+
+    Parameters:
+        in_channels (`int`): The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `None`):
+            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
+        conv_shortcut (`bool`, *optional*, defaults to `False`):
+            Whether to use a convolutional shortcut. Stored on the module but not otherwise used by this block.
+        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
+        temb_channels (`int`, *optional*, defaults to `512`): The number of channels in the timestep embedding.
+        groups (`int`, *optional*, defaults to `32`): The number of groups to use for the first normalization layer.
+        groups_out (`int`, *optional*, defaults to `None`):
+            The number of groups to use for the second normalization layer. If set to None, same as `groups`.
+        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
+        non_linearity (`str`, *optional*, defaults to `"swish"`): The activation function to use.
+        time_embedding_norm (`str`, *optional*, defaults to `"ada_group"`):
+            The normalization layer for the time embedding `temb`. Currently only supports "ada_group" or "spatial".
+        output_scale_factor (`float`, *optional*, defaults to `1.0`): The scale factor to use for the output.
+        use_in_shortcut (`bool`, *optional*, defaults to `None`):
+            If `True`, add a 1x1 `nn.Conv2d` layer for the skip connection. If `None`, inferred from the input and
+            output channel counts.
+        up (`bool`, *optional*, defaults to `False`): If `True`, add an upsample layer.
+        down (`bool`, *optional*, defaults to `False`): If `True`, add a downsample layer.
+        conv_shortcut_bias (`bool`, *optional*, defaults to `True`): If `True`, adds a learnable bias to the
+            `conv_shortcut` output.
+        conv_2d_out_channels (`int`, *optional*, defaults to `None`): The number of channels in the output.
+            If None, same as `out_channels`.
+    """
+
+    def __init__(
+        self,
+        *,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        conv_shortcut: bool = False,
+        dropout: float = 0.0,
+        temb_channels: int = 512,
+        groups: int = 32,
+        groups_out: Optional[int] = None,
+        eps: float = 1e-6,
+        non_linearity: str = "swish",
+        time_embedding_norm: str = "ada_group",  # ada_group, spatial
+        output_scale_factor: float = 1.0,
+        use_in_shortcut: Optional[bool] = None,
+        up: bool = False,
+        down: bool = False,
+        conv_shortcut_bias: bool = True,
+        conv_2d_out_channels: Optional[int] = None,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+        self.up = up
+        self.down = down
+        self.output_scale_factor = output_scale_factor
+        self.time_embedding_norm = time_embedding_norm
+
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+
+        if groups_out is None:
+            groups_out = groups
+
+        if self.time_embedding_norm == "ada_group":
+            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
+        elif self.time_embedding_norm == "spatial":
+            self.norm1 = SpatialNorm(in_channels, temb_channels)
+        else:
+            raise ValueError(f"unsupported time_embedding_norm: {self.time_embedding_norm}")
+
+        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+        if self.time_embedding_norm == "ada_group":
+            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
+        elif self.time_embedding_norm == "spatial":
+            self.norm2 = SpatialNorm(out_channels, temb_channels)
+        else:
+            raise ValueError(f"unsupported time_embedding_norm: {self.time_embedding_norm}")
+
+        self.dropout = torch.nn.Dropout(dropout)
+
+        conv_2d_out_channels = conv_2d_out_channels or out_channels
+        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+
+        self.nonlinearity = get_activation(non_linearity)
+
+        self.upsample = self.downsample = None
+        if self.up:
+            self.upsample = Upsample2D(in_channels, use_conv=False)
+        elif self.down:
+            self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
+
+        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
+
+        self.conv_shortcut = None
+        if self.use_in_shortcut:
+            self.conv_shortcut = conv_cls(
+                in_channels,
+                conv_2d_out_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=conv_shortcut_bias,
+            )
+
+    def forward(
+        self,
+        input_tensor: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
+        hidden_states = input_tensor
+
+        hidden_states = self.norm1(hidden_states, temb)
+
+        hidden_states = self.nonlinearity(hidden_states)
+
+        if self.upsample is not None:
+            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+            if hidden_states.shape[0] >= 64:
+                input_tensor = input_tensor.contiguous()
+                hidden_states = hidden_states.contiguous()
+            input_tensor = self.upsample(input_tensor, scale=scale)
+            hidden_states = self.upsample(hidden_states, scale=scale)
+
+        elif self.downsample is not None:
+            input_tensor = self.downsample(input_tensor, scale=scale)
+            hidden_states = self.downsample(hidden_states, scale=scale)
+
+        hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
+
+        hidden_states = self.norm2(hidden_states, temb)
+
+        hidden_states = self.nonlinearity(hidden_states)
+
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
+
+        if self.conv_shortcut is not None:
+            input_tensor = (
+                self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
+            )
+
+        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+        return output_tensor
+
+
 class ResnetBlock2D(nn.Module):
     r"""
     A Resnet block.
@@ -58,8 +208,8 @@ class ResnetBlock2D(nn.Module):
         eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
         non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
         time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
-            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
-            "ada_group" for a stronger conditioning with scale and shift.
+            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift"
+            for stronger conditioning with scale and shift.
         kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
             [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
         output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
@@ -87,7 +237,7 @@ def __init__(
         eps: float = 1e-6,
         non_linearity: str = "swish",
         skip_time_act: bool = False,
-        time_embedding_norm: str = "default",  # default, scale_shift, ada_group, spatial
+        time_embedding_norm: str = "default",  # default, scale_shift
         kernel: Optional[torch.FloatTensor] = None,
         output_scale_factor: float = 1.0,
         use_in_shortcut: Optional[bool] = None,
@@ -97,7 +247,15 @@ def __init__(
         conv_2d_out_channels: Optional[int] = None,
     ):
         super().__init__()
-        self.pre_norm = pre_norm
+        if time_embedding_norm == "ada_group":
+            raise ValueError(
+                "This class cannot be used with `time_embedding_norm==ada_group`; please use `ResnetBlockCondNorm2D` instead"
+            )
+        if time_embedding_norm == "spatial":
+            raise ValueError(
+                "This class cannot be used with `time_embedding_norm==spatial`; please use `ResnetBlockCondNorm2D` instead"
+            )
+        self.pre_norm = True
         self.in_channels = in_channels
         out_channels = in_channels if out_channels is None else out_channels
@@ -115,12 +273,7 @@ def __init__(
         if groups_out is None:
             groups_out = groups

-        if self.time_embedding_norm == "ada_group":
-            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
-        elif self.time_embedding_norm == "spatial":
-            self.norm1 = SpatialNorm(in_channels, temb_channels)
-        else:
-            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

         self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

@@ -129,19 +282,12 @@ def __init__(
                 self.time_emb_proj = linear_cls(temb_channels, out_channels)
             elif self.time_embedding_norm == "scale_shift":
                 self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
-            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
-                self.time_emb_proj = None
             else:
                 raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
         else:
             self.time_emb_proj = None

-        if self.time_embedding_norm == "ada_group":
-            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
-        elif self.time_embedding_norm == "spatial":
-            self.norm2 = SpatialNorm(out_channels, temb_channels)
-        else:
-            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

         self.dropout = torch.nn.Dropout(dropout)
         conv_2d_out_channels = conv_2d_out_channels or out_channels
@@ -188,11 +334,7 @@ def forward(
     ) -> torch.FloatTensor:
         hidden_states = input_tensor

-        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
-            hidden_states = self.norm1(hidden_states, temb)
-        else:
-            hidden_states = self.norm1(hidden_states)
-
+        hidden_states = self.norm1(hidden_states)
         hidden_states = self.nonlinearity(hidden_states)

         if self.upsample is not None:
@@ -233,17 +375,20 @@ def forward(
                 else self.time_emb_proj(temb)[:, :, None, None]
             )

-        if temb is not None and self.time_embedding_norm == "default":
-            hidden_states = hidden_states + temb
-
-        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
-            hidden_states = self.norm2(hidden_states, temb)
-        else:
+        if self.time_embedding_norm == "default":
+            if temb is not None:
+                hidden_states = hidden_states + temb
             hidden_states = self.norm2(hidden_states)
-
-        if temb is not None and self.time_embedding_norm == "scale_shift":
+        elif self.time_embedding_norm == "scale_shift":
"scale_shift": + if temb is None: + raise ValueError( + f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}" + ) scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = self.norm2(hidden_states) hidden_states = hidden_states * (1 + scale) + shift + else: + hidden_states = self.norm2(hidden_states) hidden_states = self.nonlinearity(hidden_states) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index e404cef224ff..470a021165ac 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -24,7 +24,16 @@ from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from .dual_transformer_2d import DualTransformer2DModel from .normalization import AdaGroupNorm -from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D +from .resnet import ( + Downsample2D, + FirDownsample2D, + FirUpsample2D, + KDownsample2D, + KUpsample2D, + ResnetBlock2D, + ResnetBlockCondNorm2D, + Upsample2D, +) from .transformer_2d import Transformer2DModel @@ -557,20 +566,35 @@ def __init__( attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] + if resnet_time_scale_shift == "spatial": + resnets = [ + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ] + else: + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] attentions = [] if attention_head_dim is None: @@ -599,20 +623,35 @@ def __init__( else: attentions.append(None) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) ) - ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) @@ 
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels

-            resnets.append(
-                ResnetBlock2D(
-                    in_channels=in_channels,
-                    out_channels=out_channels,
-                    temb_channels=None,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
+            if resnet_time_scale_shift == "spatial":
+                resnets.append(
+                    ResnetBlockCondNorm2D(
+                        in_channels=in_channels,
+                        out_channels=out_channels,
+                        temb_channels=None,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm="spatial",
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                    )
+                )
+            else:
+                resnets.append(
+                    ResnetBlock2D(
+                        in_channels=in_channels,
+                        out_channels=out_channels,
+                        temb_channels=None,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm=resnet_time_scale_shift,
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                        pre_norm=resnet_pre_norm,
+                    )
                 )
-            )

         self.resnets = nn.ModuleList(resnets)

@@ -1358,20 +1412,35 @@ def __init__(
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels

-            resnets.append(
-                ResnetBlock2D(
-                    in_channels=in_channels,
-                    out_channels=out_channels,
-                    temb_channels=None,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
+            if resnet_time_scale_shift == "spatial":
+                resnets.append(
+                    ResnetBlockCondNorm2D(
+                        in_channels=in_channels,
+                        out_channels=out_channels,
+                        temb_channels=None,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm="spatial",
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                    )
+                )
+            else:
+                resnets.append(
+                    ResnetBlock2D(
+                        in_channels=in_channels,
+                        out_channels=out_channels,
+                        temb_channels=None,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm=resnet_time_scale_shift,
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                        pre_norm=resnet_pre_norm,
+                    )
                 )
-            )
         attentions.append(
             Attention(
                 out_channels,
@@ -1889,7 +1958,7 @@ def __init__(
             groups_out = out_channels // resnet_group_size

             resnets.append(
-                ResnetBlock2D(
+                ResnetBlockCondNorm2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     dropout=dropout,
@@ -1975,7 +2044,7 @@ def __init__(
             groups_out = out_channels // resnet_group_size

             resnets.append(
-                ResnetBlock2D(
+                ResnetBlockCondNorm2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     dropout=dropout,
@@ -2500,20 +2569,35 @@ def __init__(
         for i in range(num_layers):
             input_channels = in_channels if i == 0 else out_channels

-            resnets.append(
-                ResnetBlock2D(
-                    in_channels=input_channels,
-                    out_channels=out_channels,
-                    temb_channels=temb_channels,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
+            if resnet_time_scale_shift == "spatial":
+                resnets.append(
+                    ResnetBlockCondNorm2D(
+                        in_channels=input_channels,
+                        out_channels=out_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm="spatial",
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                    )
+                )
+            else:
+                resnets.append(
+                    ResnetBlock2D(
+                        in_channels=input_channels,
+                        out_channels=out_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm=resnet_time_scale_shift,
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                        pre_norm=resnet_pre_norm,
+                    )
                 )
-            )

         self.resnets = nn.ModuleList(resnets)

@@ -2568,20 +2652,36 @@ def __init__(
         for i in range(num_layers):
             input_channels = in_channels if i == 0 else out_channels

-            resnets.append(
-                ResnetBlock2D(
-                    in_channels=input_channels,
-                    out_channels=out_channels,
-                    temb_channels=temb_channels,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
+            if resnet_time_scale_shift == "spatial":
+                resnets.append(
+                    ResnetBlockCondNorm2D(
+                        in_channels=input_channels,
+                        out_channels=out_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm="spatial",
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                    )
                 )
-            )
+            else:
+                resnets.append(
+                    ResnetBlock2D(
+                        in_channels=input_channels,
+                        out_channels=out_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm=resnet_time_scale_shift,
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                        pre_norm=resnet_pre_norm,
+                    )
+                )
+
         attentions.append(
             Attention(
                 out_channels,
@@ -3159,7 +3259,7 @@ def __init__(
             groups_out = out_channels // resnet_group_size

             resnets.append(
-                ResnetBlock2D(
+                ResnetBlockCondNorm2D(
                     in_channels=in_channels,
                     out_channels=k_out_channels if (i == num_layers - 1) else out_channels,
                     temb_channels=temb_channels,
@@ -3267,7 +3367,7 @@ def __init__(
                 conv_2d_out_channels = None

             resnets.append(
-                ResnetBlock2D(
+                ResnetBlockCondNorm2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     conv_2d_out_channels=conv_2d_out_channels,
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
index 86e3cfefaf4f..504dfa40bac8 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -31,6 +31,7 @@
     TimestepEmbedding,
     Timesteps,
 )
+from ....models.resnet import ResnetBlockCondNorm2D
 from ....models.transformer_2d import Transformer2DModel
 from ....models.unet_2d_condition import UNet2DConditionOutput
 from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -2126,20 +2127,35 @@ def __init__(
         attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None

         # there is always at least one resnet
-        resnets = [
-            ResnetBlockFlat(
-                in_channels=in_channels,
-                out_channels=in_channels,
-                temb_channels=temb_channels,
-                eps=resnet_eps,
-                groups=resnet_groups,
-                dropout=dropout,
-                time_embedding_norm=resnet_time_scale_shift,
-                non_linearity=resnet_act_fn,
-                output_scale_factor=output_scale_factor,
-                pre_norm=resnet_pre_norm,
-            )
-        ]
+        if resnet_time_scale_shift == "spatial":
+            resnets = [
+                ResnetBlockCondNorm2D(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm="spatial",
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                )
+            ]
+        else:
+            resnets = [
+                ResnetBlockFlat(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            ]
         attentions = []

         if attention_head_dim is None:
@@ -2168,20 +2184,35 @@
             else:
                 attentions.append(None)

-            resnets.append(
-                ResnetBlockFlat(
-                    in_channels=in_channels,
-                    out_channels=in_channels,
-                    temb_channels=temb_channels,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
+            if resnet_time_scale_shift == "spatial":
+                resnets.append(
+                    ResnetBlockCondNorm2D(
+                        in_channels=in_channels,
+                        out_channels=in_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm="spatial",
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                    )
+                )
+            else:
+                resnets.append(
+                    ResnetBlockFlat(
+                        in_channels=in_channels,
+                        out_channels=in_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm=resnet_time_scale_shift,
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                        pre_norm=resnet_pre_norm,
+                    )
                 )
-            )

         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
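
For readers who want to see the split in action, here is a minimal usage sketch; it is not part of the patch, and the batch size, channel counts, and embedding width are illustrative assumptions rather than values from the diff.

import torch

from diffusers.models.resnet import ResnetBlock2D, ResnetBlockCondNorm2D

# Illustrative shapes: batch of 2, 64 channels, 32x32 feature maps,
# and a 512-dimensional conditioning embedding.
sample = torch.randn(2, 64, 32, 32)
temb = torch.randn(2, 512)

# Conditional normalization now lives in ResnetBlockCondNorm2D: with
# "ada_group", `temb` is consumed by the AdaGroupNorm layers instead of
# being projected and added to the hidden states. ("spatial" expects a
# spatial, 4D conditioning tensor rather than a flat embedding.)
cond_block = ResnetBlockCondNorm2D(
    in_channels=64,
    out_channels=64,
    temb_channels=512,
    time_embedding_norm="ada_group",
)
out = cond_block(sample, temb)
print(out.shape)  # torch.Size([2, 64, 32, 32])

# ResnetBlock2D keeps only the plain GroupNorm paths ("default" and
# "scale_shift"); after this patch, passing "ada_group" or "spatial"
# raises a ValueError pointing at ResnetBlockCondNorm2D.
plain_block = ResnetBlock2D(
    in_channels=64,
    out_channels=64,
    temb_channels=512,
    time_embedding_norm="scale_shift",
)
out = plain_block(sample, temb)
print(out.shape)  # torch.Size([2, 64, 32, 32])

The payoff of the split is that the norm-type dispatch happens once, at construction time: `ResnetBlock2D.forward` no longer special-cases the conditional norms, and each class carries only the layers its own path needs.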