
fix flex attention masking
lucidrains committed Nov 19, 2024
1 parent d5256ff commit 346bb05
Showing 1 changed file with 8 additions and 4 deletions.
pi_zero_pytorch/pi_zero.py
@@ -62,12 +62,16 @@ def create_pizero_attn_mask(prefix_causal_length, mask: Bool['b n']):
     # the pi-zero attention is a triangular causal mask, but bidirectional attention for the actions at the very right hand side
 
     def mask_fn(batch_index, head_index, query_index, key_index):
-        return (
-            mask[batch_index, key_index] and # variable length states
-            query_index >= key_index and # causal
-            key_index >= prefix_causal_length # bidirectional
+        key_mask = mask[batch_index, key_index] # variable length states
+        causal_mask = query_index >= key_index # causal
+
+        bidirectional_action_mask = ( # bidirectional action mask
+            key_index >= prefix_causal_length and
+            query_index >= prefix_causal_length
         )
+
+        return (key_mask and causal_mask) or bidirectional_action_mask
 
     return mask_fn
 
 def softclamp_score_mod(value):
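
Why the change: the old mask AND'ed all three conditions, so only keys at or beyond prefix_causal_length could ever be attended to, and queries in the causal prefix ended up with every key masked out. The new mask OR's in a separate bidirectional block that applies only when both the query and the key fall inside the action suffix, leaving the prefix causal and still respecting the variable-length state mask.

Below is a minimal sketch of how a mask function like this plugs into PyTorch's flex attention (a recent torch, 2.5+, is assumed; the batch size, head count, lengths, and the all-True state mask are illustrative placeholders, not values from the repository). Note that flex attention mask functions are traced with tensor indices, so the sketch uses elementwise & and | rather than Python's and/or:

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    batch, heads, seq_len, dim_head = 2, 8, 256, 64
    prefix_causal_length = 192 # hypothetical: first 192 tokens form the causal prefix, the rest are action tokens

    # variable length state mask, True where a token is valid (all valid in this toy example)
    mask = torch.ones(batch, seq_len, dtype = torch.bool)

    def mask_fn(batch_index, head_index, query_index, key_index):
        key_mask = mask[batch_index, key_index] # variable length states
        causal_mask = query_index >= key_index # causal

        bidirectional_action_mask = ( # bidirectional action mask
            (key_index >= prefix_causal_length) &
            (query_index >= prefix_causal_length)
        )

        return (key_mask & causal_mask) | bidirectional_action_mask

    # B must be set since mask_fn indexes by batch; H = None broadcasts over heads
    block_mask = create_block_mask(mask_fn, B = batch, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = 'cpu')

    q, k, v = (torch.randn(batch, heads, seq_len, dim_head) for _ in range(3))

    out = flex_attention(q, k, v, block_mask = block_mask) # (2, 8, 256, 64)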
