setup separate parameters for memories in joint attention module
lucidrains committed Dec 9, 2024
1 parent d344532 commit 4b08c06
Showing 1 changed file with 35 additions and 14 deletions.
49 changes: 35 additions & 14 deletions pi_zero_pytorch/pi_zero.py
@@ -198,7 +198,7 @@ def __init__(
heads = 8,
dropout = 0.,
softclamp_value = 50.,
num_recurrent_memory_tokens = 0,
accept_memories = False,
learned_value_action_residual_mix = False,
rotary_emb: RotaryEmbedding | None = None
):
@@ -218,6 +218,14 @@ def __init__(
self.to_qkv = LinearNoBias(dim, 3 * dim_inner)
self.to_out = LinearNoBias(dim_inner, dim)

# maybe memory parameters

self.accept_memories = accept_memories

self.mem_rmsnorm = nn.RMSNorm(dim) if accept_memories else None
self.to_mem_qkv = LinearNoBias(dim, 3 * dim_inner) if accept_memories else None
self.to_mem_out = LinearNoBias(dim_inner, dim) if accept_memories else None

# action parameters

self.to_actions_qkvg = LinearNoBias(dim, 4 * dim_inner)
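
A minimal self-contained sketch of the pattern introduced above: a dedicated RMSNorm plus fused qkv / output projections for memory tokens, created only when `accept_memories` is set. It uses plain `nn.Linear(..., bias = False)` in place of the repository's `LinearNoBias` helper and made-up dimensions; `nn.RMSNorm` assumes torch >= 2.4.

import torch
from torch import nn

class MemoryProjections(nn.Module):
    # sketch of the pattern: memory tokens get their own norm + fused qkv projection,
    # with the parameters only instantiated when accept_memories is set
    def __init__(self, dim, dim_inner, accept_memories = False):
        super().__init__()
        self.accept_memories = accept_memories
        self.mem_rmsnorm = nn.RMSNorm(dim) if accept_memories else None
        self.to_mem_qkv = nn.Linear(dim, 3 * dim_inner, bias = False) if accept_memories else None
        self.to_mem_out = nn.Linear(dim_inner, dim, bias = False) if accept_memories else None

    def forward(self, memories):
        assert self.accept_memories, 'module was not configured to accept memories'
        normed = self.mem_rmsnorm(memories)
        mq, mk, mv = self.to_mem_qkv(normed).chunk(3, dim = -1)
        return mq, mk, mv

proj = MemoryProjections(dim = 512, dim_inner = 512, accept_memories = True)
mem_tokens = torch.randn(2, 4, 512)      # (batch, num memory tokens, dim)
mq, mk, mv = proj(mem_tokens)            # each (2, 4, 512)
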
@@ -232,22 +240,11 @@ def __init__(

self.softclamp_value = softclamp_value

# maybe recurrent memory parameters

has_recurrent_memories = num_recurrent_memory_tokens > 0
self.accepts_recurrent_memories = has_recurrent_memories
self.num_mem = num_recurrent_memory_tokens

if not has_recurrent_memories:
return

self.to_memories_qkv = LinearNoBias(dim, 3 * dim_inner)
self.to_memories_out = LinearNoBias(dim_inner, dim)

def forward_actions_with_cached_state(
self,
actions,
cached_state_keys_values: tuple[Tensor, Tensor],
memories: tuple[Tensor, Tensor] | None = None,
rotary_emb = None,
mask: Bool['b n'] | None = None,
actions_value_residual: Tensor | None = None,
@@ -265,8 +262,19 @@ def forward_actions_with_cached_state(
q = aq
mk, mv = cached_state_keys_values

# concat cache key / values with action key / values

k, v = tuple(torch.cat(tensors, dim = -2) for tensors in zip((mk, mv), (ak, av)))

# handle read, write memories

assert not (self.accept_memories ^ exists(memories))

if exists(memories):
_, write_memories = memories
write_memories = self.mem_rmsnorm(write_memories)
mqkv_write = self.to_mem_qkv(write_memories)

if exists(rotary_emb):
q = apply_rotary_emb(rotary_emb, q, freqs_seq_dim = -2)
k = apply_rotary_emb(rotary_emb, k)
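
In the cached-state path above, the cached multimodal keys / values are concatenated with the freshly projected action keys / values along the sequence dimension, and the XOR assert enforces that `memories` are passed exactly when the module was built with `accept_memories = True`. A small sketch of the concatenation step, with made-up shapes (the head layout at this point is not shown in the diff, so plain `(batch, seq, dim)` tensors are used here):

import torch

mk = torch.randn(1, 32, 512)    # cached multimodal keys
mv = torch.randn(1, 32, 512)    # cached multimodal values
ak = torch.randn(1, 16, 512)    # action keys
av = torch.randn(1, 16, 512)    # action values

# same pattern as the diff: cache first, then actions, concatenated over the sequence dim
k, v = tuple(torch.cat(tensors, dim = -2) for tensors in zip((mk, mv), (ak, av)))
print(k.shape, v.shape)    # torch.Size([1, 48, 512]) torch.Size([1, 48, 512])
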
@@ -354,6 +362,7 @@ def forward(
multimodal_seq,
actions,
rotary_emb = None,
memories: tuple[Tensor, Tensor] | None = None,
mask: Bool['b n'] | None = None,
actions_value_residual: Tensor | None = None,
return_keys_values = False,
@@ -377,6 +386,18 @@ def forward(

q, k, v = tuple(torch.cat(tensors, dim = -2) for tensors in zip((mq, mk, mv), (aq, ak, av)))

# handle read, write memories

assert not (self.accept_memories ^ exists(memories))

if exists(memories):
memories, unpack_memories = pack_with_inverse(memories, 'b * d')
memories = self.mem_rmsnorm(memories)
mqkv = self.to_mem_qkv(memories)
mqkv_read, mqkv_write = unpack_memories(mqkv, 'b * d')

# rotary embedding

if exists(rotary_emb):
q = apply_rotary_emb(rotary_emb, q)
k = apply_rotary_emb(rotary_emb, k)
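
In the joint `forward` above, the read and write memories are packed along the sequence dimension so a single fused `to_mem_qkv` projection handles both, then split apart again. The sketch below reproduces that step with einops' `pack` / `unpack` (an assumption standing in for the repository's `pack_with_inverse` helper); shapes and dimensions are illustrative.

import torch
from torch import nn
from einops import pack, unpack

dim, dim_inner = 512, 512
to_mem_qkv = nn.Linear(dim, 3 * dim_inner, bias = False)

read_mem = torch.randn(2, 4, dim)     # memories read by this layer
write_mem = torch.randn(2, 4, dim)    # memories written by this layer

# pack both memory sets into one sequence, remembering how to split them back
packed, packed_shapes = pack([read_mem, write_mem], 'b * d')      # (2, 8, dim)
mqkv = to_mem_qkv(packed)                                         # (2, 8, 3 * dim_inner)
mqkv_read, mqkv_write = unpack(mqkv, packed_shapes, 'b * d')      # (2, 4, 3 * dim_inner) each

Sharing one projection over the packed read and write memories matches the diff above, where `to_mem_qkv` is applied once to the packed tensor before `unpack_memories` separates the two streams.
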
@@ -630,7 +651,7 @@ def __init__(
is_first_block = i == 0

layers.append(ModuleList([
Attention(dim = dim, dim_head = dim_head, heads = heads, num_recurrent_memory_tokens = num_recurrent_memory_tokens, learned_value_action_residual_mix = not is_first_block, **attn_kwargs),
Attention(dim = dim, dim_head = dim_head, heads = heads, learned_value_action_residual_mix = not is_first_block, **attn_kwargs),
SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs),
SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, rmsnorm = False, **ff_kwargs),
SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs) if self.has_recurrent_memories else None
