From 0049a90b562f6df0ae22cd7487de97d30250fcd0 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Mon, 25 Nov 2024 08:52:56 -0800
Subject: [PATCH] prepare to bring in recurrence, using the RMT design

---
 README.md                  | 11 +++++++++++
 pi_zero_pytorch/pi_zero.py | 18 +++++++++++++++++-
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 7c58153..bf6ee75 100644
--- a/README.md
+++ b/README.md
@@ -101,4 +101,15 @@ That's it
 }
 ```
 
+```bibtex
+@article{Bulatov2022RecurrentMT,
+    title   = {Recurrent Memory Transformer},
+    author  = {Aydar Bulatov and Yuri Kuratov and Mikhail S. Burtsev},
+    journal = {ArXiv},
+    year    = {2022},
+    volume  = {abs/2207.06881},
+    url     = {https://api.semanticscholar.org/CorpusID:250526424}
+}
+```
+
 [*dear alice*](https://www.youtube.com/watch?v=z-Ng5ZvrDm4)
diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py
index ce23c95..cab5a43 100644
--- a/pi_zero_pytorch/pi_zero.py
+++ b/pi_zero_pytorch/pi_zero.py
@@ -201,6 +201,7 @@ def __init__(
         heads = 8,
         dropout = 0.,
         softclamp_value = 50.,
+        recurrent_memory_params = False,
         learned_value_action_residual_mix = False,
         rotary_emb: RotaryEmbedding | None = None
     ):
@@ -215,9 +216,13 @@ def __init__(
 
         self.rmsnorm = nn.RMSNorm(dim)
 
+        # state parameters
+
         self.to_qkv = LinearNoBias(dim, 3 * dim_inner)
         self.to_out = LinearNoBias(dim_inner, dim)
 
+        # action parameters
+
         self.to_actions_qkvg = LinearNoBias(dim, 4 * dim_inner)
 
         self.to_action_value_residual_mix = nn.Sequential(
@@ -230,6 +235,16 @@ def __init__(
 
         self.softclamp_value = softclamp_value
 
+        # maybe recurrent memory parameters
+
+        self.accepts_recurrent_memories = recurrent_memory_params
+
+        if not recurrent_memory_params:
+            return
+
+        self.to_memories_qkv = LinearNoBias(dim, 3 * dim_inner)
+        self.to_memories_out = LinearNoBias(dim_inner, dim)
+
     def forward_actions_with_cached_state(
         self,
         actions,
@@ -542,6 +557,7 @@ def __init__(
         flow_loss_weight = 1.,
         immiscible_flow = False, # https://arxiv.org/abs/2406.12303
         reward_tokens_dropout_prob = 0.,
+        recurrent_memories = False,
         odeint_kwargs: dict = dict(
             atol = 1e-5,
             rtol = 1e-5,
@@ -611,7 +627,7 @@ def __init__(
             is_first_block = i == 0
 
             layers.append(ModuleList([
-                Attention(dim = dim, dim_head = dim_head, heads = heads, learned_value_action_residual_mix = not is_first_block, **attn_kwargs),
+                Attention(dim = dim, dim_head = dim_head, heads = heads, recurrent_memory_params = recurrent_memories, learned_value_action_residual_mix = not is_first_block, **attn_kwargs),
                 SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs),
                 SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs)
             ]))
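
Note: this patch only introduces the recurrent memory *parameters* (`to_memories_qkv`, `to_memories_out`); the forward logic is not part of the commit. Below is a minimal, self-contained sketch of how those projections could be used in the Recurrent Memory Transformer style, where memory tokens are carried across segments and attend alongside the main sequence. The module name, the `memories` argument, and the way memories are concatenated and split back out are assumptions for illustration, not the repository's implementation.

```python
# hypothetical sketch of RMT-style recurrent memory attention
# assumes the memory tokens get their own qkv / output projections,
# mirroring the parameters added in this patch

import torch
from torch import nn
from einops import rearrange

class RecurrentMemoryAttentionSketch(nn.Module):
    def __init__(self, dim, dim_head = 64, heads = 8):
        super().__init__()
        dim_inner = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.rmsnorm = nn.RMSNorm(dim)

        # state projections (as in the existing Attention block)
        self.to_qkv = nn.Linear(dim, 3 * dim_inner, bias = False)
        self.to_out = nn.Linear(dim_inner, dim, bias = False)

        # recurrent memory projections (mirrors the parameters added in this patch)
        self.to_memories_qkv = nn.Linear(dim, 3 * dim_inner, bias = False)
        self.to_memories_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(self, tokens, memories):
        # tokens:   (batch, seq, dim)
        # memories: (batch, num_mem, dim) - carried across segments, RMT style
        tokens, memories = self.rmsnorm(tokens), self.rmsnorm(memories)

        num_mem = memories.shape[1]

        # project state tokens and memory tokens with their own parameters
        qkv = self.to_qkv(tokens)
        mem_qkv = self.to_memories_qkv(memories)

        # memory tokens attend together with the sequence: [memories, tokens]
        qkv = torch.cat((mem_qkv, qkv), dim = 1)

        q, k, v = rearrange(qkv, 'b n (three h d) -> three b h n d', three = 3, h = self.heads)

        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')

        # split off memory positions and project each stream with its own output weights
        mem_out, out = out[:, :num_mem], out[:, num_mem:]

        return self.to_out(out), self.to_memories_out(mem_out)

# usage: updated memories are passed to the next segment

attn = RecurrentMemoryAttentionSketch(dim = 512)
tokens = torch.randn(2, 128, 512)
memories = torch.randn(2, 16, 512)

tokens_out, next_memories = attn(tokens, memories)
```

Keeping separate projections for the memory stream (rather than reusing `to_qkv` / `to_out`) follows the same pattern the existing block uses for its separate state and action parameters, and the early `return` when `recurrent_memory_params` is off keeps the extra weights out of the module entirely.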