choose the register tokens route for addressing a pathology in transformers
lucidrains committed Nov 13, 2024
1 parent b181d10 commit 1395440
Showing 3 changed files with 39 additions and 8 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -65,4 +65,13 @@ sampled_actions = model(vision, commands, joint_state, trajectory_length = 32) #
}
```

```bibtex
@inproceedings{Darcet2023VisionTN,
    title   = {Vision Transformers Need Registers},
    author  = {Timothée Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:263134283}
}
```

[*dear alice*](https://www.youtube.com/watch?v=z-Ng5ZvrDm4)
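
The Darcet et al. reference motivates this commit: transformers tend to repurpose a few sequence tokens as high-norm "scratch" slots, which shows up as artifacts in attention maps. Appending a handful of learned register tokens gives the model dedicated scratch space instead. The sketch below distills the idea as this commit applies it to action tokens, with hypothetical dimensions; the actual wiring lives in the `pi_zero.py` diff below.

```python
import torch
from torch import nn
from einops import repeat

dim, num_registers, batch = 512, 4, 2

# learned register tokens, zero-initialized then given a small normal init
registers = nn.Parameter(torch.zeros(num_registers, dim))
nn.init.normal_(registers, std = 0.02)

action_tokens = torch.randn(batch, 32, dim)            # (batch, trajectory length, dim)
reg = repeat(registers, 'n d -> b n d', b = batch)     # broadcast registers over the batch

tokens = torch.cat((reg, action_tokens), dim = 1)      # registers ride along through attention
# ... transformer layers attend over all (num_registers + 32) tokens ...
action_tokens = tokens[:, num_registers:]              # registers are discarded before any prediction or loss
```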
36 changes: 29 additions & 7 deletions pi_zero_pytorch/pi_zero.py
@@ -23,7 +23,6 @@

from pi_zero_pytorch.tensor_typing import Float, Int, Bool


import tqdm

# ein notation
@@ -91,6 +90,16 @@ def softclamp(t, value):

    return (t / value).tanh() * value

def pack_with_inverse(t, pattern):
    packed, packed_shape = pack(t, pattern)

    def inverse(out, inv_pattern = None):
        inv_pattern = default(inv_pattern, pattern)
        out = unpack(out, packed_shape, inv_pattern)
        return out

    return packed, inverse

# losses

def direction_loss(pred, target, dim = -1):
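
`pack_with_inverse` wraps einops' `pack` so the matching `unpack` call doesn't have to thread `packed_shape` around by hand. A quick round trip, reproducing the helper so the example is self-contained (shapes are hypothetical):

```python
import torch
from einops import pack, unpack

def default(v, d):
    return v if v is not None else d

def pack_with_inverse(t, pattern):
    packed, packed_shape = pack(t, pattern)

    def inverse(out, inv_pattern = None):
        inv_pattern = default(inv_pattern, pattern)
        return unpack(out, packed_shape, inv_pattern)

    return packed, inverse

registers = torch.randn(2, 4, 512)    # (batch, register tokens, dim)
actions   = torch.randn(2, 32, 512)   # (batch, action tokens, dim)

packed, inverse = pack_with_inverse([registers, actions], 'b * d')
assert packed.shape == (2, 36, 512)   # concatenated along the wildcard axis

registers, actions = inverse(packed)  # split back using the remembered shapes
assert actions.shape == (2, 32, 512)
```

The returned closure remembers `packed_shape`, so downstream code can pack, run attention, and unpack without carrying shape bookkeeping through every call.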
@@ -354,6 +363,7 @@ def __init__(
        final_norm_softclamp_value = 30.,
        vit: Module | None = None,
        vit_dim = None,
        num_action_register_tokens = 4,
        attn_kwargs: dict = dict(),
        ff_kwargs: dict = dict(),
        lm_loss_weight = 1.,
@@ -389,6 +399,10 @@ def __init__(
        self.to_joint_state_tokens = nn.Linear(dim_joint_state, dim)

        self.dim_action_input = dim_action_input

        self.action_register_tokens = nn.Parameter(torch.zeros(num_action_register_tokens, dim))
        nn.init.normal_(self.action_register_tokens, std = 0.02)

        self.to_action_tokens = nn.Linear(dim_action_input, dim)

        self.to_time_cond = nn.Sequential(
@@ -555,6 +569,10 @@ def forward(
        time_cond = self.to_time_cond(times)
        action_tokens = self.to_action_tokens(actions)

        action_register_tokens = repeat(self.action_register_tokens, '... -> b ...', b = batch)

        action_tokens, inverse_pack_action_registers = pack_with_inverse([action_register_tokens, action_tokens], 'b * d')

        if not inferencing:
            # language

@@ -570,14 +588,14 @@

            if is_multiple_images:
                images = rearrange(images, 'b c f h w -> b f c h w')
                images, images_frames_packed_shape = pack([images], '* c h w')
                images, inverse_pack_image_frames = pack_with_inverse([images], '* c h w')

            with torch.no_grad():
                self.vit.eval()
                visual_tokens = self.vit(images)

            if is_multiple_images:
                visual_tokens, = unpack(visual_tokens, images_frames_packed_shape, '* n d')
                visual_tokens, = inverse_pack_image_frames(visual_tokens, '* n d')
                visual_tokens = rearrange(visual_tokens, 'b f n d -> b (f n) d')

        else:
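
The same helper also covers the multi-frame image path above: frames are folded into the batch dimension so the ViT sees plain images, then unfolded and flattened into one token sequence. A runnable sketch, reusing `pack_with_inverse` from the earlier example, with a stand-in ViT and hypothetical shapes:

```python
import torch
from einops import rearrange

# stand-in for self.vit, just to make the sketch runnable
vit = lambda imgs: torch.randn(imgs.shape[0], 196, 512)

images = torch.randn(2, 3, 5, 224, 224)                    # (batch, channels, frames, h, w)
images = rearrange(images, 'b c f h w -> b f c h w')
images, inverse = pack_with_inverse([images], '* c h w')   # fold frames into batch -> (10, 3, 224, 224)

visual_tokens = vit(images)                                # (10, 196, 512)
visual_tokens, = inverse(visual_tokens, '* n d')           # unfold -> (2, 5, 196, 512)
visual_tokens = rearrange(visual_tokens, 'b f n d -> b (f n) d')   # one sequence per sample -> (2, 980, 512)
```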
@@ -597,7 +615,7 @@

        # concat visual rep with language

        state_tokens, packed_shape = pack([visual_tokens, language_tokens, joint_state_tokens, reward_tokens], 'b * d')
        state_tokens, inverse_packed_states = pack_with_inverse([visual_tokens, language_tokens, joint_state_tokens, reward_tokens], 'b * d')

        # prepare maybe flex attention

@@ -628,14 +646,16 @@

        # prepare rotary embeddings

        action_length = actions.shape[-2]
        action_with_registers_length = action_tokens.shape[-2]

        if inferencing:
            state_length = cached_state_keys_values[0][0].shape[-2]
        else:
            state_length = state_tokens.shape[-2]

        total_seq_length = action_length + state_length
        total_seq_length = action_with_registers_length + state_length

        # rotary embeddings

        seq = torch.arange(total_seq_length, device = self.device)
        rotary_emb = self.rotary_emb(seq)
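
The renaming above matters because the rotary positions must now cover the registers too; `action_tokens` already has them packed in, so its length is registers plus actions. The arithmetic, with hypothetical lengths:

```python
import torch

num_registers, action_length = 4, 32
state_length = 300   # visual + language + joint state + reward tokens (hypothetical)

action_with_registers_length = num_registers + action_length    # 36, matches action_tokens.shape[-2]
total_seq_length = action_with_registers_length + state_length  # 336

seq = torch.arange(total_seq_length)   # positions fed to the rotary embedding
```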
@@ -707,12 +727,14 @@ def forward(
        if not inferencing:
            # unpack and unembed to predictions

            visual_tokens, tokens, *_ = unpack(state_tokens, packed_shape, 'b * d')
            visual_tokens, tokens, *_ = inverse_packed_states(state_tokens, 'b * d')

            # gemma uses a final softclamp before norm

            tokens = self.final_norm_softclamp(tokens)

            action_register_tokens, action_tokens = inverse_pack_action_registers(action_tokens)

            action_tokens = self.final_norm_softclamp(action_tokens)

            # projection
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "pi-zero-pytorch"
version = "0.0.11"
version = "0.0.12"
description = "π0 in Pytorch"
authors = [
    { name = "Phil Wang", email = "[email protected]" }
