diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index dcb9528cb1a0..1f29622bdf20 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -514,7 +514,7 @@ def get_1d_rotary_pos_embed( linear_factor=1.0, ntk_factor=1.0, repeat_interleave_real=True, - freqs_dtype=torch.float32, # torch.float32 (hunyuan, stable audio), torch.float64 (flux) + freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux) ): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. @@ -551,15 +551,18 @@ def get_1d_rotary_pos_embed( t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] freqs = torch.outer(t, freqs) # type: ignore # [S, D/2] if use_real and repeat_interleave_real: + # flux, hunyuan-dit, cogvideox freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] return freqs_cos, freqs_sin elif use_real: + # stable audio freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] return freqs_cos, freqs_sin else: - freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float() # complex64 # [S, D/2] + # lumina + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] return freqs_cis @@ -590,11 +593,11 @@ def apply_rotary_emb( cos, sin = cos.to(x.device), sin.to(x.device) if use_real_unbind_dim == -1: - # Use for example in Lumina + # Used for flux, cogvideox, hunyuan-dit x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) elif use_real_unbind_dim == -2: - # Use for example in Stable Audio + # Used for Stable Audio x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] x_rotated = torch.cat([-x_imag, x_real], dim=-1) else: @@ -604,6 +607,7 @@ def apply_rotary_emb( return out else: + # used for lumina x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) freqs_cis = freqs_cis.unsqueeze(2) x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)