prepare to bring in recurrence, using the RMT design
lucidrains committed Nov 25, 2024
1 parent 84219e3 commit 0049a90
Showing 2 changed files with 28 additions and 1 deletion.
11 changes: 11 additions & 0 deletions README.md
@@ -101,4 +101,15 @@ That's it
}
```

```bibtex
@article{Bulatov2022RecurrentMT,
title = {Recurrent Memory Transformer},
author = {Aydar Bulatov and Yuri Kuratov and Mikhail S. Burtsev},
journal = {ArXiv},
year = {2022},
volume = {abs/2207.06881},
url = {https://api.semanticscholar.org/CorpusID:250526424}
}
```

[*dear alice*](https://www.youtube.com/watch?v=z-Ng5ZvrDm4)
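
For context, the Recurrent Memory Transformer cited above handles long sequences by splitting them into segments: a small set of learned memory tokens is concatenated to each segment, the segment is processed by the transformer, and the transformed memory tokens are carried into the next segment as its incoming memory. Below is a minimal sketch of that segment-level recurrence, assuming a generic `(batch, seq, dim) -> (batch, seq, dim)` transformer and hypothetical names (`RMTStyleWrapper`, `num_memory_tokens`); the paper additionally distinguishes read and write memories for causal models, which is omitted here. This is an illustration of the cited design, not code from this repository.

```python
import torch
from torch import nn

class RMTStyleWrapper(nn.Module):
    # sketch of the RMT segment recurrence (Bulatov et al., 2022),
    # not the pi-zero-pytorch implementation

    def __init__(self, transformer: nn.Module, dim: int, num_memory_tokens: int = 16):
        super().__init__()
        self.transformer = transformer  # any (batch, seq, dim) -> (batch, seq, dim) module
        self.num_memory_tokens = num_memory_tokens
        self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim) * 0.02)

    def forward(self, segments: list[torch.Tensor]):
        batch = segments[0].shape[0]
        memories = self.memory_tokens.unsqueeze(0).expand(batch, -1, -1)

        outputs = []
        for segment in segments:
            # append the running memory tokens to the segment, attend over both,
            # then split the updated memories back off for the next segment
            x = torch.cat((segment, memories), dim = 1)
            x = self.transformer(x)
            outputs.append(x[:, :-self.num_memory_tokens])
            memories = x[:, -self.num_memory_tokens:]

        return outputs, memories
```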
18 changes: 17 additions & 1 deletion pi_zero_pytorch/pi_zero.py
@@ -201,6 +201,7 @@ def __init__(
heads = 8,
dropout = 0.,
softclamp_value = 50.,
recurrent_memory_params = False,
learned_value_action_residual_mix = False,
rotary_emb: RotaryEmbedding | None = None
):
@@ -215,9 +216,13 @@ def __init__(

self.rmsnorm = nn.RMSNorm(dim)

# state parameters

self.to_qkv = LinearNoBias(dim, 3 * dim_inner)
self.to_out = LinearNoBias(dim_inner, dim)

# action parameters

self.to_actions_qkvg = LinearNoBias(dim, 4 * dim_inner)

self.to_action_value_residual_mix = nn.Sequential(
@@ -230,6 +235,16 @@ def __init__(

self.softclamp_value = softclamp_value

# maybe recurrent memory parameters

self.accepts_recurrent_memories = recurrent_memory_params

if not recurrent_memory_params:
return

self.to_memories_qkv = LinearNoBias(dim, 3 * dim_inner)
self.to_memories_out = LinearNoBias(dim_inner, dim)

def forward_actions_with_cached_state(
self,
actions,
@@ -542,6 +557,7 @@ def __init__(
flow_loss_weight = 1.,
immiscible_flow = False, # https://arxiv.org/abs/2406.12303
reward_tokens_dropout_prob = 0.,
recurrent_memories = False,
odeint_kwargs: dict = dict(
atol = 1e-5,
rtol = 1e-5,
@@ -611,7 +627,7 @@ def __init__(
is_first_block = i == 0

layers.append(ModuleList([
Attention(dim = dim, dim_head = dim_head, heads = heads, learned_value_action_residual_mix = not is_first_block, **attn_kwargs),
Attention(dim = dim, dim_head = dim_head, heads = heads, recurrent_memory_params = recurrent_memories, learned_value_action_residual_mix = not is_first_block, **attn_kwargs),
SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs),
SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs)
]))
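
With this commit, `Attention` accepts `recurrent_memory_params = True`, which allocates a separate qkv projection and output projection for memory tokens, and the top-level model threads a `recurrent_memories` flag into every attention layer; the forward pass for memories is not wired up yet. Below is a hedged, standalone sketch of how those projections could be used once the recurrence lands, with assumed dimensions and an assumed joint-attention step (the actual wiring with the state and action streams is not part of this commit):

```python
import torch
from torch import nn
import torch.nn.functional as F

dim, heads, dim_head = 512, 8, 64
dim_inner = heads * dim_head

# mirrors the new parameters added in this commit
to_memories_qkv = nn.Linear(dim, 3 * dim_inner, bias = False)
to_memories_out = nn.Linear(dim_inner, dim, bias = False)

memories = torch.randn(2, 16, dim)  # (batch, num memory tokens, dim)

# project memory tokens to queries / keys / values
mq, mk, mv = to_memories_qkv(memories).chunk(3, dim = -1)

# split attention heads: (b, n, h * dh) -> (b, h, n, dh)
split_heads = lambda t: t.reshape(t.shape[0], t.shape[1], heads, dim_head).transpose(1, 2)
mq, mk, mv = map(split_heads, (mq, mk, mv))

# in the eventual design these would presumably be concatenated with the state
# and action queries / keys / values for one joint attention; here the memories
# simply attend among themselves as a placeholder
out = F.scaled_dot_product_attention(mq, mk, mv)

# merge heads and project back into the memory stream for the next segment
out = out.transpose(1, 2).reshape(out.shape[0], -1, dim_inner)
next_memories = to_memories_out(out)
```

Enabling it at the model level would then just be a matter of passing `recurrent_memories = True` to the top-level constructor, exactly as threaded through in the `layers.append` change above.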
