add relative positional encoding for each window, but do not advertise
lucidrains committed Jul 5, 2020
1 parent 4f59e4a commit 91735c2
Showing 2 changed files with 36 additions and 2 deletions.
36 changes: 35 additions & 1 deletion local_attention/local_attention.py
@@ -13,6 +13,9 @@
def default(value, d):
    return d if value is None else value

def to(t):
    return {'device': t.device, 'dtype': t.dtype}

def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max

@@ -36,10 +39,32 @@ def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):
    tensors = [padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)]
    return torch.cat(tensors, dim=dim)

# Shaw's relative positional encoding per window
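# shift() realigns the logits computed against the shared relative-position weights:
# each query row is padded, flattened and re-sliced so that its columns line up with
# the absolute key positions of its window (the usual skewing trick for relative attention)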

def shift(x):
    *_, i, j = x.shape
    zero_pad = torch.zeros((*_, i, i), **to(x))
    x = torch.cat([x, zero_pad], -1)
    l = i + j - 1
    x = x.view(*_, -1)
    zero_pad = torch.zeros(*_, -x.size(-1) % l, **to(x))
    shifted = torch.cat([x, zero_pad], -1).view(*_, -1, l)
    return shifted[..., :i, i - 1:]

class RelativePositionalEmbedding(nn.Module):
    def __init__(self, dim, heads, length):
        super().__init__()
        self.scale = dim ** -0.5
        self.weights = nn.Parameter(torch.zeros(length, heads, dim))

    def forward(self, q):
        emb = torch.einsum('bhnid,jhd->bhnij', q, self.weights.type(q.dtype)) * self.scale
        return shift(emb)

# main class

class LocalAttention(nn.Module):
    def __init__(self, window_size, causal = False, look_backward = 1, look_forward = None, dropout = 0., shared_qk = False):
    def __init__(self, window_size, causal = False, look_backward = 1, look_forward = None, dropout = 0., shared_qk = False, rel_pos_emb_config = None):
        super().__init__()
        self.look_forward = default(look_forward, 0 if causal else 1)
        assert not (causal and self.look_forward > 0), 'you cannot look forward if causal'
@@ -52,6 +77,11 @@ def __init__(self, window_size, causal = False, look_backward = 1, look_forward

        self.shared_qk = shared_qk

        self.rel_pos = None
        if rel_pos_emb_config is not None:
            dim_head, heads = rel_pos_emb_config
            self.heads = heads
            self.rel_pos = RelativePositionalEmbedding(dim_head, heads, window_size * 2)

    def forward(self, q, k, v, input_mask = None):
        shape = q.shape

@@ -82,6 +112,10 @@ def forward(self, q, k, v, input_mask = None):

        dots = torch.einsum('bhie,bhje->bhij', bq, bk) * (e ** -0.5)

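        # split the flat (batch * heads) dimension back out so the per-head relative
        # position weights can be applied, then reshape the resulting bias to match dots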
        if self.rel_pos is not None:
            rel_attn = self.rel_pos(bq.view(-1, self.heads, *bq.shape[1:])).reshape_as(dots)
            dots = dots + rel_attn

        mask_value = max_neg_value(dots)

        if shared_qk:
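A minimal usage sketch for the new option, based only on the diff above and not part of the commit: passing rel_pos_emb_config = (dim_head, heads) builds the per-window relative positional bias, and the view(-1, self.heads, ...) call in forward suggests that q, k and v arrive with the head dimension folded into the batch. All sizes below are illustrative assumptions.

import torch
from local_attention import LocalAttention

# illustrative sizes, chosen so that seq_len divides evenly into windows
batch, heads, dim_head, seq_len = 2, 8, 64, 1024

attn = LocalAttention(
    window_size = 128,
    causal = True,
    dropout = 0.1,
    rel_pos_emb_config = (dim_head, heads)  # enables the per-window relative positional bias
)

# heads folded into the batch dimension, as the reshape in forward implies
q = torch.randn(batch * heads, seq_len, dim_head)
k = torch.randn(batch * heads, seq_len, dim_head)
v = torch.randn(batch * heads, seq_len, dim_head)

out = attn(q, k, v)  # same shape as q

The embedding table is sized window_size * 2, which matches the keys seen under the default look_backward = 1: each query window attends to itself plus one window behind it.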
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'local-attention',
  packages = find_packages(),
  version = '1.0.0',
  version = '1.0.1',
  license='MIT',
  description = 'Local windowed attention, for language modeling',
  author = 'Phil Wang',
