From 1ca0a75567da1ca5a97681310c1b57e9f527a84a Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sun, 25 Aug 2024 11:57:12 -1000 Subject: [PATCH] refactor 3d rope for cogvideox (#9269) * refactor 3d rope * repeat -> expand --- src/diffusers/models/embeddings.py | 86 ++++++++----------- .../pipelines/cogvideo/pipeline_cogvideox.py | 1 - 2 files changed, 35 insertions(+), 52 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index d1366654c448..dcb9528cb1a0 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -391,15 +391,16 @@ def get_3d_rotary_pos_embed( The size of the temporal dimension. theta (`float`): Scaling factor for frequency computation. - use_real (`bool`): - If True, return real part and imaginary part separately. Otherwise, return complex numbers. Returns: `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. """ + if use_real is not True: + raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed") start, stop = crops_coords - grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) - grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) + grid_size_h, grid_size_w = grid_size + grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) # Compute dimensions for each axis @@ -408,54 +409,37 @@ def get_3d_rotary_pos_embed( dim_w = embed_dim // 8 * 3 # Temporal frequencies - freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t)) - grid_t = torch.from_numpy(grid_t).float() - freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t) - freqs_t = freqs_t.repeat_interleave(2, dim=-1) - + freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True) # Spatial frequencies for height and width - freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h)) - freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w)) - grid_h = torch.from_numpy(grid_h).float() - grid_w = torch.from_numpy(grid_w).float() - freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h) - freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w) - freqs_h = freqs_h.repeat_interleave(2, dim=-1) - freqs_w = freqs_w.repeat_interleave(2, dim=-1) - - # Broadcast and concatenate tensors along specified dimension - def broadcast(tensors, dim=-1): - num_tensors = len(tensors) - shape_lens = {len(t.shape) for t in tensors} - assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" - shape_len = list(shape_lens)[0] - dim = (dim + shape_len) if dim < 0 else dim - dims = list(zip(*(list(t.shape) for t in tensors))) - expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] - assert all( - [*(len(set(t[1])) <= 2 for t in expandable_dims)] - ), "invalid dimensions for broadcastable concatenation" - max_dims = [(t[0], max(t[1])) for t in expandable_dims] - expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims] - expanded_dims.insert(dim, (dim, dims[dim])) - expandable_shapes = list(zip(*(t[1] for t in expanded_dims))) - tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)] - return torch.cat(tensors, dim=dim) - - freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) - - t, h, w, d = freqs.shape - freqs = freqs.view(t * h * w, d) - - # Generate sine and cosine components - sin = freqs.sin() - cos = freqs.cos() - - if use_real: - return cos, sin - else: - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) - return freqs_cis + freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True) + freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True) + + # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor + def combine_time_height_width(freqs_t, freqs_h, freqs_w): + freqs_t = freqs_t[:, None, None, :].expand( + -1, grid_size_h, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_w, dim_t + freqs_h = freqs_h[None, :, None, :].expand( + temporal_size, -1, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_h + freqs_w = freqs_w[None, None, :, :].expand( + temporal_size, grid_size_h, -1, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_w + + freqs = torch.cat( + [freqs_t, freqs_h, freqs_w], dim=-1 + ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w) + freqs = freqs.view( + temporal_size * grid_size_h * grid_size_w, -1 + ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w) + return freqs + + t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t + h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h + w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w + cos = combine_time_height_width(t_cos, h_cos, w_cos) + sin = combine_time_height_width(t_sin, h_sin, w_sin) + return cos, sin def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index e100c1f11e20..11f491e49532 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -463,7 +463,6 @@ def _prepare_rotary_positional_embeddings( crops_coords=grid_crops_coords, grid_size=(grid_height, grid_width), temporal_size=num_frames, - use_real=True, ) freqs_cos = freqs_cos.to(device=device)