From 97ed762d6bf0010617d91f28cf1bc11d3170810e Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Mon, 9 Dec 2024 13:21:06 -0800
Subject: [PATCH] complete join attention of memories at park

---
 pi_zero_pytorch/pi_zero.py | 26 ++++++++++++++++++++++----
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py
index a1329c2..7f478e2 100644
--- a/pi_zero_pytorch/pi_zero.py
+++ b/pi_zero_pytorch/pi_zero.py
@@ -223,7 +223,7 @@ def __init__(
 
         self.accept_memories = accept_memories
         self.mem_rmsnorm = nn.RMSNorm(dim) if accept_memories else None
-        self.to_mem_qkv = LinearNoBias(dim, 3.* dim_inner) if accept_memories else None
+        self.to_mem_qkv = LinearNoBias(dim, 3 * dim_inner) if accept_memories else None
         self.to_mem_out = LinearNoBias(dim_inner, dim) if accept_memories else None
 
         # action parameters
@@ -391,11 +391,17 @@ def forward(
         assert not (self.accept_memories ^ exists(memories))
 
         if exists(memories):
-            memories = self.mem_rmsnorm(memories)
             memories, unpack_memories = pack_with_inverse(memories, 'b * d')
+            memories = self.mem_rmsnorm(memories)
 
             mqkv = self.to_mem_qkv(memories)
             mqkv_read, mqkv_write = unpack_memories(mqkv, 'b * d')
+            mqr, mkr, mvr, mqw, mkw, mvw = tuple(self.split_heads(t) for t in (*mqkv_read.chunk(3, dim = -1), *mqkv_write.chunk(3, dim = -1)))
+
+            k = torch.cat((mkr, k, mkw), dim = -2)
+            v = torch.cat((mvr, v, mvw), dim = -2)
+            q, attn_output_unpack_memories = pack_with_inverse((mqr, q, mqw), 'b h * d')
+
         # rotary embedding
 
         if exists(rotary_emb):
@@ -436,6 +442,11 @@ def forward(
 
         out = out * gates
 
+        # split out memories
+
+        if self.accept_memories:
+            mem_read_out, out, mem_write_out = attn_output_unpack_memories(out)
+
         # merge attention heads
 
         out = self.merge_heads(out)
@@ -651,7 +662,7 @@ def __init__(
             is_first_block = i == 0
 
             layers.append(ModuleList([
-                Attention(dim = dim, dim_head = dim_head, heads = heads, learned_value_action_residual_mix = not is_first_block, **attn_kwargs),
+                Attention(dim = dim, dim_head = dim_head, heads = heads, accept_memories = self.has_recurrent_memories, learned_value_action_residual_mix = not is_first_block, **attn_kwargs),
                 SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs),
                 SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, rmsnorm = False, **ff_kwargs),
                 SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs) if self.has_recurrent_memories else None
@@ -1013,6 +1024,8 @@ def forward(
 
         memory_tokens = (past_recurrent_memory_tokens, write_memory_tokens)
 
+        mem_length = past_recurrent_memory_tokens.shape[-2] + write_memory_tokens.shape[-2]
+
         # pack into [action registers] [internal + joint states] [actions]
 
         action_tokens, inverse_pack_action_registers = pack_with_inverse([
@@ -1098,6 +1111,10 @@ def forward(
 
             mask = F.pad(language_mask, (state_length - command_length, action_with_registers_length), value = True) # assume fixed number of images for now, but address variable length modality states later
 
+        # memory
+
+        mask = F.pad(mask, (past_recurrent_memory_tokens.shape[-2], write_memory_tokens.shape[-2]), value = True)
+
         # rotary embeddings
 
         seq = mask.float().cumsum(dim = -1)
@@ -1159,7 +1176,8 @@ def forward(
                 flex_attn_fn = flex_attn_fn,
                 actions_value_residual = actions_value_residual,
                 mask = mask,
-                return_keys_values = True
+                return_keys_values = True,
+                memories = memory_tokens
             )
 
             state_cached_keys_values.append((state_keys, state_values))
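
The hunks above route recurrent memories through the joint attention: read-memory keys/values are prepended and write-memory keys/values appended to the sequence, and the memory queries are packed alongside the sequence queries so their attention outputs can be split back out afterwards. Below is a minimal single-head sketch of that pattern with illustrative names (not the library's exact code); it mirrors the separate sequence/memory projections but omits the RMSNorm, head splitting, rotary embeddings, masking and output gating of the real module.

    # minimal single-head sketch of joining read / write memories into one attention pass
    # (illustrative names, not the library's exact code)

    import torch
    import torch.nn.functional as F
    from torch import nn

    def join_memory_attention(seq, read_mem, write_mem, to_qkv, to_mem_qkv):
        # project the sequence and the memories with separate linear layers, as in the patch
        q, k, v = to_qkv(seq).chunk(3, dim = -1)
        mqr, mkr, mvr = to_mem_qkv(read_mem).chunk(3, dim = -1)
        mqw, mkw, mvw = to_mem_qkv(write_mem).chunk(3, dim = -1)

        # memories flank the sequence keys / values: [read memories] [sequence] [write memories]
        k = torch.cat((mkr, k, mkw), dim = -2)
        v = torch.cat((mvr, v, mvw), dim = -2)

        # memories also query, so a single joint attention covers all three segments
        q = torch.cat((mqr, q, mqw), dim = -2)

        out = F.scaled_dot_product_attention(q, k, v)

        # split the attention output back into its three segments
        return out.split((read_mem.shape[-2], seq.shape[-2], write_mem.shape[-2]), dim = -2)

    # usage

    dim = 64
    to_qkv = nn.Linear(dim, dim * 3, bias = False)
    to_mem_qkv = nn.Linear(dim, dim * 3, bias = False)

    seq = torch.randn(2, 10, dim)
    read_mem = torch.randn(2, 4, dim)
    write_mem = torch.randn(2, 4, dim)

    mem_read_out, seq_out, mem_write_out = join_memory_attention(seq, read_mem, write_mem, to_qkv, to_mem_qkv)

    assert seq_out.shape == seq.shape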