Skip to content

Commit

Permalink
Never mind — discovered the warts of the technique
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 3, 2024
1 parent 88617f1 commit cf77236
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 19 deletions.
19 changes: 1 addition & 18 deletions pi_zero_pytorch/pi_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ def __init__(
heads = 8,
dropout = 0.,
softclamp_value = 50.,
laser = False,
num_recurrent_memory_tokens = 0,
learned_value_action_residual_mix = False,
rotary_emb: RotaryEmbedding | None = None
Expand All @@ -212,10 +211,6 @@ def __init__(

self.rmsnorm = nn.RMSNorm(dim)

# laser attention

self.laser = laser

# state parameters

self.to_qkv = LinearNoBias(dim, 3 * dim_inner)
Expand Down Expand Up @@ -277,12 +272,6 @@ def forward_actions_with_cached_state(
elif exists(self.rotary_emb):
q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)

# maybe laser

if self.laser:
v_max = v.amax(dim = -2, keepdim = True).detach()
v = (v - v_max).exp()

# attention

if exists(flex_attn_fn):
Expand All @@ -301,11 +290,6 @@ def forward_actions_with_cached_state(

out = einsum(attn, v, 'b h i j, b h j d -> b h i d')

# maybe laser

if self.laser:
out = log(out) + v_max

# gate

out = out * ag.sigmoid()
Expand Down Expand Up @@ -551,7 +535,6 @@ def __init__(
use_flex_attn = False,
ff_expand_factor = 4.,
attn_softclamp_value = 50.,
attn_laser = False,
final_norm_softclamp_value = 30.,
vit: Module | None = None,
vit_dim = None,
Expand Down Expand Up @@ -642,7 +625,7 @@ def __init__(
is_first_block = i == 0

layers.append(ModuleList([
Attention(dim = dim, dim_head = dim_head, heads = heads, num_recurrent_memory_tokens = num_recurrent_memory_tokens, learned_value_action_residual_mix = not is_first_block, laser = attn_laser, **attn_kwargs),
Attention(dim = dim, dim_head = dim_head, heads = heads, num_recurrent_memory_tokens = num_recurrent_memory_tokens, learned_value_action_residual_mix = not is_first_block, **attn_kwargs),
SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs),
SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs)
]))
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "pi-zero-pytorch"
version = "0.0.38"
version = "0.0.39"
description = "π0 in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit cf77236

Please sign in to comment.