From cf7723695f515f12b0eb6fbd1df1f96a84d67832 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 3 Dec 2024 05:36:50 -0800
Subject: [PATCH] nevermind, discovered the warts of the technique

---
 pi_zero_pytorch/pi_zero.py | 19 +------------------
 pyproject.toml             |  2 +-
 2 files changed, 2 insertions(+), 19 deletions(-)

diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py
index 3444a73..de04cde 100644
--- a/pi_zero_pytorch/pi_zero.py
+++ b/pi_zero_pytorch/pi_zero.py
@@ -196,7 +196,6 @@ def __init__(
         heads = 8,
         dropout = 0.,
         softclamp_value = 50.,
-        laser = False,
         num_recurrent_memory_tokens = 0,
         learned_value_action_residual_mix = False,
         rotary_emb: RotaryEmbedding | None = None
@@ -212,10 +211,6 @@ def __init__(
 
         self.rmsnorm = nn.RMSNorm(dim)
 
-        # laser attention
-
-        self.laser = laser
-
         # state parameters
 
         self.to_qkv = LinearNoBias(dim, 3 * dim_inner)
@@ -277,12 +272,6 @@ def forward_actions_with_cached_state(
         elif exists(self.rotary_emb):
             q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
 
-        # maybe laser
-
-        if self.laser:
-            v_max = v.amax(dim = -2, keepdim = True).detach()
-            v = (v - v_max).exp()
-
         # attention
 
         if exists(flex_attn_fn):
@@ -301,11 +290,6 @@ def forward_actions_with_cached_state(
 
         out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
 
-        # maybe laser
-
-        if self.laser:
-            out = log(out) + v_max
-
         # gate
 
         out = out * ag.sigmoid()
@@ -551,7 +535,6 @@ def __init__(
         use_flex_attn = False,
         ff_expand_factor = 4.,
         attn_softclamp_value = 50.,
-        attn_laser = False,
         final_norm_softclamp_value = 30.,
         vit: Module | None = None,
         vit_dim = None,
@@ -642,7 +625,7 @@ def __init__(
             is_first_block = i == 0
 
             layers.append(ModuleList([
-                Attention(dim = dim, dim_head = dim_head, heads = heads, num_recurrent_memory_tokens = num_recurrent_memory_tokens, learned_value_action_residual_mix = not is_first_block, laser = attn_laser, **attn_kwargs),
+                Attention(dim = dim, dim_head = dim_head, heads = heads, num_recurrent_memory_tokens = num_recurrent_memory_tokens, learned_value_action_residual_mix = not is_first_block, **attn_kwargs),
                 SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs),
                 SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs)
             ]))
diff --git a/pyproject.toml b/pyproject.toml
index 60f2188..b79bc1a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "pi-zero-pytorch"
-version = "0.0.38"
+version = "0.0.39"
 description = "π0 in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
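
For context, a minimal sketch of the "laser" path this patch removes: values are exponentiated (with a detached max subtracted for numerical stability) before the attention-weighted sum, and the result is mapped back to the original space with a log. The standalone helper below, laser_attention, is illustrative only and not part of the repo's API; it assumes plain softmax attention and einops' einsum as used in pi_zero.py.

import torch
from einops import einsum

def laser_attention(q, k, v):
    # q, k, v: (batch, heads, seq, dim_head)
    scale = q.shape[-1] ** -0.5

    # exponentiate values, subtracting a detached per-sequence max for stability
    v_max = v.amax(dim = -2, keepdim = True).detach()
    v_exp = (v - v_max).exp()

    # standard softmax attention, applied to the exponentiated values
    sim = einsum(q, k, 'b h i d, b h j d -> b h i j') * scale
    attn = sim.softmax(dim = -1)
    out = einsum(attn, v_exp, 'b h i j, b h j d -> b h i d')

    # back to the original space: log of the weighted sum, with the max added back
    return out.log() + v_max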