
allow for internal states + joint state to attend bidirectionally
lucidrains committed Nov 24, 2024
1 parent 944aab5 commit 3fada24
Showing 2 changed files with 55 additions and 10 deletions.
63 changes: 54 additions & 9 deletions pi_zero_pytorch/pi_zero.py
@@ -2,8 +2,8 @@

from random import random

from beartype.typing import Callable
from beartype import beartype
from beartype.typing import Callable

from functools import partial

@@ -48,6 +48,12 @@
# w - image width
# f - image frames

# token layout for transformer
# vision and language tokens use an autoregressive causal mask; actions, internal states + joint state attend bidirectionally amongst their own tokens, but remain autoregressive with respect to the other tokens

# [state token groups] [action token groups]
# [external state] [visual tokens] [language tokens] [joint state + internal state] [maybe reward / condition token] [action registers] [actions]

# constants

LinearNoBias = partial(nn.Linear, bias = False)
@@ -61,7 +67,15 @@
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
flex_attention = torch.compile(flex_attention)

def create_pizero_attn_mask(prefix_causal_length, mask: Bool['b n']):
def create_pizero_attn_mask(
prefix_causal_length,
mask: Bool['b n'],
internal_state_offset_and_len: tuple[int, int] | None = None
):

state_offset, state_len = default(internal_state_offset_and_len, (0, 0))
state_left, state_right = state_offset, state_offset + state_len

# the pi-zero attention uses a triangular causal mask, but with bidirectional attention for the actions at the very right hand side

def mask_fn(batch_index, head_index, query_index, key_index):
Expand All @@ -73,7 +87,12 @@ def mask_fn(batch_index, head_index, query_index, key_index):
query_index >= prefix_causal_length
)

return (key_mask and causal_mask) or bidirectional_action_mask
bidirectional_internal_state_mask = (
state_left <= key_index and key_index < state_right and
state_left <= query_index and query_index < state_right
)

return (key_mask and causal_mask) or bidirectional_action_mask or bidirectional_internal_state_mask

return mask_fn
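
For intuition, a minimal standalone sketch (not part of the commit) that materializes the same attention pattern on a toy sequence: a causal prefix, a bidirectional window over the joint + internal state span, and bidirectional actions at the end. All sizes below are made up, and the key padding mask is omitted for brevity.

# illustrative sketch only - toy sizes, padding mask skipped
import torch

prefix_len, state_offset, state_len, action_len = 6, 3, 2, 3
total_len = prefix_len + action_len

q = torch.arange(total_len)[:, None]   # query positions
k = torch.arange(total_len)[None, :]   # key positions

causal = q >= k
bidirectional_actions = (q >= prefix_len) & (k >= prefix_len)
bidirectional_internal_state = (
    (q >= state_offset) & (q < state_offset + state_len) &
    (k >= state_offset) & (k < state_offset + state_len)
)

allowed = causal | bidirectional_actions | bidirectional_internal_state
print(allowed.int())   # 1 where a query position may attend to a key position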

@@ -280,6 +299,7 @@ def forward(
actions,
rotary_emb = None,
mask: Bool['b n'] | None = None,
internal_state_offset_and_len: tuple[int, int] | None = None,
actions_value_residual: Tensor | None = None,
return_keys_values = False,
flex_attn_fn: Callable | None = None
@@ -328,6 +348,11 @@ def forward(

causal_mask[..., seq_len:, seq_len:] = False # actions have bidirectional attention, lining up with Transfusion paper

if exists(internal_state_offset_and_len):
offset, length = internal_state_offset_and_len
state_slice = slice(offset, offset + length)
causal_mask[..., state_slice, state_slice] = False

sim = sim.masked_fill(causal_mask, max_neg_value(sim))

attn = sim.softmax(dim = -1)
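
The eager (non-flex) path above expresses the same pattern with a boolean mask of disallowed positions: a triangular causal mask whose action block and internal-state slice are opened up before the masked_fill. A rough equivalent with made-up sizes, for illustration only:

# illustrative sketch only - True marks positions to be masked out
import torch

prefix_len, action_len = 6, 3
offset, length = 3, 2
total_len = prefix_len + action_len

causal_mask = torch.triu(torch.ones(total_len, total_len), diagonal = 1).bool()

causal_mask[prefix_len:, prefix_len:] = False     # actions: bidirectional amongst themselves
state_slice = slice(offset, offset + length)
causal_mask[state_slice, state_slice] = False     # joint + internal state: bidirectional amongst themselves

sim = torch.randn(total_len, total_len)
sim = sim.masked_fill(causal_mask, torch.finfo(sim.dtype).min)
attn = sim.softmax(dim = -1)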
@@ -754,7 +779,7 @@ def forward(
self,
images: Float['b nv d'] | Float['b c h w'] | Float['b c f h w'], # vision
token_ids: Int['b nt'], # language
joint_state: Float['b djs'], # joint state
joint_state: Float['b djs'], # joint state
actions: Float['b na da'] | None = None, # action
times: Float['b'] = None,
reward_tokens: Float['b d'] | None = None,
@@ -804,6 +829,10 @@ def forward(

action_tokens, inverse_pack_action_registers = pack_with_inverse([action_register_tokens, action_tokens], 'b * d')

action_with_registers_length = action_tokens.shape[-2]

internal_state_offset_and_len = None

if not inferencing:
# language

@@ -865,6 +894,21 @@ def forward(
else:
external_state_tokens = visual_tokens.new_empty((batch, 0, self.dim))

# allow joint and internal states to have bidirectional attention

internal_state_len = joint_state_tokens.shape[-2] + internal_state_tokens.shape[-2]

internal_state_offset = (
external_state_tokens.shape[-2] +
visual_tokens.shape[-2] +
language_tokens.shape[-2]
)

internal_state_offset_and_len = (
internal_state_offset,
internal_state_len
)

# concat visual rep with language

state_tokens, inverse_packed_states = pack_with_inverse([
@@ -876,7 +920,6 @@ def forward(
reward_tokens
], 'b * d')


# take care of masking for variable length states, starting with the language tokens

# which then leads to proper rotary embeddings
@@ -885,8 +928,6 @@ def forward(

language_mask = token_ids != self.lm_pad_id

action_with_registers_length = action_tokens.shape[-2]

if inferencing:
state_length = cached_state_keys_values[0][0].shape[-2]
else:
@@ -911,7 +952,11 @@ def forward(
seq_len = prefix_length + action_tokens.shape[-2]

block_mask = create_block_mask(
create_pizero_attn_mask(prefix_length, mask = mask),
create_pizero_attn_mask(
prefix_length,
mask = mask,
internal_state_offset_and_len = internal_state_offset_and_len
),
Q_LEN = seq_len,
KV_LEN = seq_len,
device = state_tokens.device
@@ -945,7 +990,7 @@ def forward(

action_tokens = attn_ada_rmsnorm(action_tokens, time_cond)

(state_attn_out, actions_attn_out), (state_keys, state_values, action_keys, action_values) = attn(state_tokens, action_tokens, rotary_emb = rotary_emb, flex_attn_fn = flex_attn_fn, actions_value_residual = actions_value_residual, mask = mask, return_keys_values = True)
(state_attn_out, actions_attn_out), (state_keys, state_values, action_keys, action_values) = attn(state_tokens, action_tokens, rotary_emb = rotary_emb, flex_attn_fn = flex_attn_fn, actions_value_residual = actions_value_residual, mask = mask, internal_state_offset_and_len = internal_state_offset_and_len, return_keys_values = True)

state_cached_keys_values.append((state_keys, state_values))

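The offset bookkeeping in the forward pass simply records where the joint + internal state tokens sit inside the packed prefix, following the token layout comment at the top of the file. A hypothetical example with illustrative token counts (none of these numbers come from the library):

# hypothetical token-group lengths for
# [external state] [visual] [language] [joint + internal state] ...
num_external_state = 2
num_visual = 16
num_language = 8
num_joint_state = 1
num_internal_state = 4

internal_state_len = num_joint_state + num_internal_state                 # 5
internal_state_offset = num_external_state + num_visual + num_language    # 26

internal_state_offset_and_len = (internal_state_offset, internal_state_len)
# -> (26, 5): positions 26 through 30 of the prefix attend bidirectionally amongst themselves
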
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "pi-zero-pytorch"
version = "0.0.30"
version = "0.0.31"
description = "π0 in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
