Merge pull request #7 from lucidrains/rotary
replace shaws with rotary embeddings
lucidrains authored Apr 14, 2021
2 parents 4eea9b9 + 98343cf commit a643dc1
Showing 4 changed files with 59 additions and 34 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -23,6 +23,7 @@ k = torch.randn(8, 2048, 64)
 v = torch.randn(8, 2048, 64)
 
 attn = LocalAttention(
+    dim = 64,                # dimension of each head (you need to pass this in for relative positional encoding)
     window_size = 512,       # window size. 512 is optimal, but 256 or 128 yields good enough results
     causal = True,           # auto-regressive or not
     look_backward = 1,       # each window looks at the window before
@@ -45,6 +46,7 @@ qk = torch.randn(8, 2048, 64)
 v = torch.randn(8, 2048, 64)
 
 attn = LocalAttention(
+    dim = 64,
     window_size = 512,
     shared_qk = True,
     causal = True
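For reference, a runnable version of the updated README usage might look like the following. This is a minimal sketch: the tensor shapes and arguments come from the README excerpt above, while the `from local_attention import LocalAttention` import path is assumed to be the package's public entry point.

import torch
from local_attention import LocalAttention

q = torch.randn(8, 2048, 64)
k = torch.randn(8, 2048, 64)
v = torch.randn(8, 2048, 64)

attn = LocalAttention(
    dim = 64,            # dimension of each head, now required for the rotary positional encoding
    window_size = 512,
    causal = True,
    look_backward = 1
)

out = attn(q, k, v)      # shape (8, 2048, 64), same as v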
60 changes: 27 additions & 33 deletions local_attention/local_attention.py
@@ -5,14 +5,19 @@
 from operator import mul
 from functools import reduce
 
+from local_attention.rotary import SinusoidalEmbeddings, apply_rotary_pos_emb
+
 # constant
 
 TOKEN_SELF_ATTN_VALUE = -5e4 # carefully set for half precision to work
 
 # helper functions
 
+def exists(val):
+    return val is not None
+
 def default(value, d):
-    return d if value is None else value
+    return d if not exists(value) else value
 
 def to(t):
     return {'device': t.device, 'dtype': t.dtype}
@@ -49,32 +54,22 @@ def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):
     tensors = [padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)]
     return torch.cat(tensors, dim=dim)
 
-# Shaw's relative positional encoding per window
-
-def shift(x):
-    *_, i, j = x.shape
-    zero_pad = torch.zeros((*_, i, i), **to(x))
-    x = torch.cat([x, zero_pad], -1)
-    l = i + j - 1
-    x = x.view(*_, -1)
-    zero_pad = torch.zeros(*_, -x.size(-1) % l, **to(x))
-    shifted = torch.cat([x, zero_pad], -1).view(*_, -1, l)
-    return shifted[..., :i, i - 1:]
-
-class RelativePositionalEmbedding(nn.Module):
-    def __init__(self, dim, heads, length):
-        super().__init__()
-        self.scale = dim ** -0.5
-        self.weights = nn.Parameter(torch.zeros(length, heads, dim))
-
-    def forward(self, q):
-        emb = torch.einsum('bhnid,jhd->bhnij', q, self.weights.type(q.dtype)) * self.scale
-        return shift(emb)
-
 # main class
 
 class LocalAttention(nn.Module):
-    def __init__(self, window_size, causal = False, look_backward = 1, look_forward = None, dropout = 0., shared_qk = False, rel_pos_emb_config = None, autopad = False, exact_windowsize = False):
+    def __init__(
+        self,
+        window_size,
+        causal = False,
+        look_backward = 1,
+        look_forward = None,
+        dropout = 0.,
+        shared_qk = False,
+        rel_pos_emb_config = None,
+        dim = None,
+        autopad = False,
+        exact_windowsize = False
+    ):
         super().__init__()
         look_forward = default(look_forward, 0 if causal else 1)
         assert not (causal and look_forward > 0), 'you cannot look forward if causal'
@@ -91,18 +86,21 @@ def __init__(self, window_size, causal = False, look_backward = 1, look_forward
         self.shared_qk = shared_qk
 
         self.rel_pos = None
-        if rel_pos_emb_config is not None:
-            dim_head, heads = rel_pos_emb_config
-            rel_pos_length = window_size * (1 + look_forward + look_backward)
-            self.heads = heads
-            self.rel_pos = RelativePositionalEmbedding(dim_head, heads, rel_pos_length)
+        if exists(rel_pos_emb_config) or exists(dim): # backwards compatible with old `rel_pos_emb_config` deprecated argument
+            if exists(rel_pos_emb_config):
+                dim = rel_pos_emb_config[0]
+            self.rel_pos = SinusoidalEmbeddings(dim)
 
     def forward(self, q, k, v, input_mask = None):
         shape = q.shape
 
         merge_into_batch = lambda t: t.reshape(-1, *t.shape[-2:])
         q, k, v = map(merge_into_batch, (q, k, v))
 
+        if exists(self.rel_pos):
+            pos_emb = self.rel_pos(q)
+            q, k = apply_rotary_pos_emb(q, k, pos_emb)
+
         if self.autopad:
             orig_t = q.shape[1]
             q, k, v = map(lambda t: pad_to_multiple(t, self.window_size, dim = -2), (q, k, v))
@@ -131,10 +129,6 @@ def forward(self, q, k, v, input_mask = None):
 
         dots = torch.einsum('bhie,bhje->bhij', bq, bk) * (e ** -0.5)
 
-        if self.rel_pos is not None:
-            rel_attn = self.rel_pos(bq.view(-1, self.heads, *bq.shape[1:])).reshape_as(dots)
-            dots = dots + rel_attn
-
         mask_value = max_neg_value(dots)
 
         if shared_qk:
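The constructor change above keeps the deprecated `rel_pos_emb_config` argument working: when it is passed, only its first element is used as `dim`, and both paths end up building `SinusoidalEmbeddings`. A small sketch of the two equivalent instantiations (argument values are illustrative only):

from local_attention import LocalAttention

new_style = LocalAttention(dim = 64, window_size = 512, causal = True)
old_style = LocalAttention(rel_pos_emb_config = (64, 8), window_size = 512, causal = True)  # deprecated path; the heads value (8) is now ignored

# both hold a SinusoidalEmbeddings module built with dim = 64
assert type(new_style.rel_pos) is type(old_style.rel_pos)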
29 changes: 29 additions & 0 deletions local_attention/rotary.py
@@ -0,0 +1,29 @@
import torch
from torch import nn, einsum
from einops import rearrange, repeat

class SinusoidalEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        n = x.shape[-2]
        t = torch.arange(n, device = x.device).type_as(self.inv_freq)
        sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
        return emb[None, :, :]

def rotate_every_two(x):
    x = rearrange(x, '... (d j) -> ... d j', j = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d j -> ... (d j)')

def apply_rotary_pos_emb(q, k, sinu_pos):
    sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j = 2)
    sin, cos = sinu_pos.unbind(dim = -2)
    sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j = 2), (sin, cos))
    q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
    return q, k
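Taken on its own, the new module can be exercised outside of LocalAttention. A brief sketch of the code path used in forward above (the shapes are illustrative; einops must be installed):

import torch
from local_attention.rotary import SinusoidalEmbeddings, apply_rotary_pos_emb

dim_head = 64
q = torch.randn(2, 128, dim_head)    # (batch * heads, seq, dim_head), as produced by merge_into_batch
k = torch.randn(2, 128, dim_head)

pos_emb = SinusoidalEmbeddings(dim_head)(q)        # (1, 128, 64): sin half followed by cos half
q_rot, k_rot = apply_rotary_pos_emb(q, k, pos_emb)

# each consecutive pair of feature dims is rotated by a position-dependent angle,
# so norms are preserved and only the q·k dot products change
assert torch.allclose(q_rot.norm(dim = -1), q.norm(dim = -1), atol = 1e-4)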
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'local-attention',
   packages = find_packages(),
-  version = '1.2.2',
+  version = '1.4.0',
   license='MIT',
   description = 'Local windowed attention, for language modeling',
   author = 'Phil Wang',
