Skip to content

Commit

Permalink
complete non-inference path for recurrent memories
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 9, 2024
1 parent 97ed762 commit 8ad66fd
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion pi_zero_pytorch/pi_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,13 @@ def forward(

output = self.to_out(mout), self.to_actions_out(aout)

if self.accept_memories:
mem_out, unpack_memories = pack_with_inverse((mem_read_out, mem_write_out), 'b h * d')
mem_out = self.merge_heads(mem_out)
mem_out = self.to_mem_out(mem_out)

output = (*output, unpack_memories(mem_out, 'b * d'))

if not return_keys_values:
return output

Expand Down Expand Up @@ -1169,7 +1176,7 @@ def forward(

action_tokens = attn_ada_rmsnorm(action_tokens, time_cond)

(state_attn_out, actions_attn_out), (state_keys, state_values, action_keys, action_values) = attn(
(state_attn_out, actions_attn_out, *maybe_mem_out), (state_keys, state_values, action_keys, action_values) = attn(
state_tokens,
action_tokens,
rotary_emb = rotary_emb,
Expand All @@ -1187,6 +1194,12 @@ def forward(
state_tokens = state_tokens + state_attn_out
action_tokens = action_tokens + attn_ada_layerscale(actions_attn_out, time_cond)

if self.has_recurrent_memories:
(read_mem_attn_out, write_mem_attn_out), = maybe_mem_out
read_mem, write_mem = memory_tokens

memory_tokens = (read_mem + read_mem_attn_out, write_mem + write_mem_attn_out)

state_tokens = state_ff(state_tokens) + state_tokens

action_tokens = ff_ada_rmsnorm(action_tokens, time_cond)
Expand Down

0 comments on commit 8ad66fd

Please sign in to comment.