makeover with einops
lucidrains committed Dec 8, 2022
1 parent 81cd932 commit 0beeff2
Showing 3 changed files with 60 additions and 56 deletions.
18 changes: 9 additions & 9 deletions README.md
@@ -18,9 +18,9 @@ $ pip install local-attention
 import torch
 from local_attention import LocalAttention
 
-q = torch.randn(8, 2048, 64)
-k = torch.randn(8, 2048, 64)
-v = torch.randn(8, 2048, 64)
+q = torch.randn(2, 8, 2048, 64)
+k = torch.randn(2, 8, 2048, 64)
+v = torch.randn(2, 8, 2048, 64)
 
 attn = LocalAttention(
     dim = 64,                # dimension of each head (you need to pass this in for relative positional encoding)
@@ -32,18 +32,18 @@ attn = LocalAttention(
     exact_windowsize = False # if this is set to true, in the causal setting, each query will see at maximum the number of keys equal to the window size
 )
 
-mask = torch.ones(1, 2048).bool()
-out = attn(q, k, v, input_mask = mask) # (1, 8, 2048, 64)
+mask = torch.ones(2, 2048).bool()
+out = attn(q, k, v, input_mask = mask) # (2, 8, 2048, 64)
 ```
 
-This library also allows for local attention in the setting of shared query/key space. The normalization of the keys, as well as the masking of tokens to itself, will be taken care of.
+This library also allows for local attention in the setting of shared query/key space (Reformer architecture). The normalization of the keys, as well as the masking of tokens to itself, will be taken care of.
 
 ```python
 import torch
 from local_attention import LocalAttention
 
-qk = torch.randn(8, 2048, 64)
-v = torch.randn(8, 2048, 64)
+qk = torch.randn(2, 8, 2048, 64)
+v = torch.randn(2, 8, 2048, 64)
 
 attn = LocalAttention(
     dim = 64,
@@ -52,7 +52,7 @@ attn = LocalAttention(
     causal = True
 )
 
-mask = torch.ones(1, 2048).bool()
+mask = torch.ones(2, 2048).bool()
 out = attn(qk, qk, v, input_mask = mask) # (2, 8, 2048, 64)
 ```

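The jump from 3D to 4D tensors above works because the new forward pass packs all leading dimensions into one before attending, then unpacks them afterward. Below is a minimal, self-contained sketch of that einops pack/unpack idiom; the shapes here are illustrative only, not part of the library:

```python
import torch
from einops import pack, unpack

q = torch.randn(2, 8, 2048, 64)   # (batch, heads, seq, dim_head)

# '*' collapses any number of leading dims into one, recording them in ps
packed, ps = pack([q], '* n d')
print(packed.shape)               # torch.Size([16, 2048, 64])

# ... attention would run here on the flattened batch ...

# unpack restores the original leading dims from ps
out, = unpack(packed, ps, '* n d')
print(out.shape)                  # torch.Size([2, 8, 2048, 64])
```

Because the pattern is `'* n d'`, the same forward pass should also accept plain `(heads, seq, dim)` inputs; only the last two axes are assumed.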
96 changes: 50 additions & 46 deletions local_attention/local_attention.py
@@ -1,16 +1,16 @@
 import math
-from operator import mul
-from functools import reduce
 
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
 
+from einops import rearrange, repeat, pack, unpack
+
 from local_attention.rotary import SinusoidalEmbeddings, apply_rotary_pos_emb
 
 # constant
 
-TOKEN_SELF_ATTN_VALUE = -5e4 # carefully set for half precision to work
+TOKEN_SELF_ATTN_VALUE = -5e4
 
 # helper functions
 
@@ -26,18 +26,10 @@ def to(t):
 def max_neg_value(tensor):
     return -torch.finfo(tensor.dtype).max
 
-def merge_dims(ind_from, ind_to, tensor):
-    shape = list(tensor.shape)
-    arr_slice = slice(ind_from, ind_to + 1)
-    shape[arr_slice] = [reduce(mul, shape[arr_slice])]
-    return tensor.reshape(*shape)
-
-def expand_dim(t, dim, k, unsqueeze=True):
-    if unsqueeze:
-        t = t.unsqueeze(dim)
-    expand_shape = [-1] * len(t.shape)
-    expand_shape[dim] = k
-    return t.expand(*expand_shape)
+def l2norm(tensor):
+    dtype = tensor.dtype
+    normed = F.normalize(tensor, dim = -1)
+    return normed.type(dtype)
 
 def pad_to_multiple(tensor, multiple, dim=-1, value=0):
     seqlen = tensor.shape[dim]
@@ -97,36 +89,37 @@ def __init__(
             self.rel_pos = SinusoidalEmbeddings(dim)
 
     def forward(self, q, k, v, input_mask = None):
-        shape, autopad, pad_value = q.shape, self.autopad, -1
+        shape, autopad, pad_value, window_size, causal, look_backward, look_forward, shared_qk = q.shape, self.autopad, -1, self.window_size, self.causal, self.look_backward, self.look_forward, self.shared_qk
 
-        merge_into_batch = lambda t: t.reshape(-1, *t.shape[-2:])
-        q, k, v = map(merge_into_batch, (q, k, v))
+        # https://github.com/arogozhnikov/einops/blob/master/docs/4-pack-and-unpack.ipynb
+        (q, packed_shape), (k, _), (v, _) = map(lambda t: pack([t], '* n d'), (q, k, v))
+
+        # rotary embeddings
 
         if exists(self.rel_pos):
             pos_emb = self.rel_pos(q)
             q, k = apply_rotary_pos_emb(q, k, pos_emb)
 
+        # auto padding
+
         if autopad:
-            orig_t = q.shape[1]
+            orig_seq_len = q.shape[1]
             q, k, v = map(lambda t: pad_to_multiple(t, self.window_size, dim = -2), (q, k, v))
 
-        window_size, causal, look_backward, look_forward, shared_qk = self.window_size, self.causal, self.look_backward, self.look_forward, self.shared_qk
-
-        b, t, e, device, dtype = *q.shape, q.device, q.dtype
-        scale = e ** -0.5
+        b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype
+        scale = dim_head ** -0.5
 
-        assert (t % window_size) == 0, f'sequence length {t} must be divisible by window size {window_size} for local attention'
+        assert (n % window_size) == 0, f'sequence length {n} must be divisible by window size {window_size} for local attention'

-        windows = t // window_size
+        windows = n // window_size
 
         if shared_qk:
-            k = F.normalize(k, 2, dim=-1).type_as(q)
+            k = l2norm(k)
 
-        ticker = torch.arange(t, device=device, dtype=torch.long)[None, :]
-        b_t = ticker.reshape(1, windows, window_size)
+        seq = torch.arange(n, device = device)
+        b_t = rearrange(seq, '(w n) -> 1 w n', w = windows, n = window_size)
 
-        bucket_fn = lambda t: t.reshape(b, windows, window_size, -1)
-        bq, bk, bv = map(bucket_fn, (q, k, v))
+        bq, bk, bv = map(lambda t: rearrange(t, 'b (w n) d -> b w n d', w = windows), (q, k, v))
 
         look_around_kwargs = dict(
             backward = look_backward,
@@ -140,48 +133,59 @@ def forward(self, q, k, v, input_mask = None):
         bq_t = b_t
         bq_k = look_around(b_t, **look_around_kwargs)
 
-        dots = einsum('b h i e, b h j e -> b h i j', bq, bk) * scale
+        bq_t = rearrange(bq_t, '... i -> ... i 1')
+        bq_k = rearrange(bq_k, '... j -> ... 1 j')
 
-        mask_value = max_neg_value(dots)
+        sim = einsum('b h i e, b h j e -> b h i j', bq, bk) * scale
+
+        mask_value = max_neg_value(sim)
 
         if shared_qk:
-            mask = bq_t[..., :, None] == bq_k[..., None, :]
-            dots = dots.masked_fill(mask, TOKEN_SELF_ATTN_VALUE)
+            mask = bq_t == bq_k
+            sim = sim.masked_fill(mask, TOKEN_SELF_ATTN_VALUE)
             del mask
 
         if causal:
-            mask = bq_t[..., :, None] < bq_k[..., None, :]
+            mask = bq_t < bq_k
 
             if self.exact_windowsize:
                 max_causal_window_size = (self.window_size * self.look_backward)
-                mask = mask | (bq_t[..., :, None] > (bq_k[..., None, :] + max_causal_window_size))
+                mask = mask | (bq_t > (bq_k + max_causal_window_size))
 
-            dots = dots.masked_fill(mask, mask_value)
+            sim = sim.masked_fill(mask, mask_value)
             del mask
 
-        mask = bq_k[:, :, None, :] == pad_value
-        dots = dots.masked_fill(mask, mask_value)
+        # mask out padding value
+
+        mask = bq_k == pad_value
+        sim = sim.masked_fill(mask, mask_value)
         del mask
 
         if exists(input_mask):
             h = b // input_mask.shape[0]
             if autopad:
                 input_mask = pad_to_multiple(input_mask, window_size, dim = -1, value = False)
-            input_mask = input_mask.reshape(-1, windows, window_size)
+
+            input_mask = rearrange(input_mask, '... (w n) -> (...) w n', w = windows, n = window_size)
             mq = mk = input_mask
             mk = look_around(mk, **{**look_around_kwargs, 'pad_value': False})
             mask = mq[..., :, None] & mk[..., None, :]
-            mask = merge_dims(0, 1, expand_dim(mask, 1, h))
-            dots = dots.masked_fill(~mask, mask_value)
+            mask = repeat(mask, 'b ... -> (b h) ...', h = h)
+            sim = sim.masked_fill(~mask, mask_value)
             del mask
 
-        attn = dots.softmax(dim = -1)
+        # attention
+
+        attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)
 
+        # aggregation
+
         out = einsum('b h i j, b h j e -> b h i e', attn, bv)
-        out = out.reshape(-1, t, e)
+        out = rearrange(out, 'b w n d -> b (w n) d')
 
         if autopad:
-            out = out[:, :orig_t, :]
+            out = out[:, :orig_seq_len, :]
 
-        return out.reshape(*shape)
+        out, *_ = unpack(out, packed_shape, '* n d')
+        return out
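As a side-by-side check (not part of the commit), the einops `repeat` call above is a drop-in replacement for the two deleted helpers. Here is a small standalone sketch, with `merge_dims` and `expand_dim` reproduced verbatim from the removed lines so the equivalence can be verified directly; the mask shape and head count are arbitrary:

```python
import torch
from functools import reduce
from operator import mul
from einops import repeat

# the two helpers deleted by this commit, reproduced for comparison
def merge_dims(ind_from, ind_to, tensor):
    shape = list(tensor.shape)
    arr_slice = slice(ind_from, ind_to + 1)
    shape[arr_slice] = [reduce(mul, shape[arr_slice])]
    return tensor.reshape(*shape)

def expand_dim(t, dim, k, unsqueeze=True):
    if unsqueeze:
        t = t.unsqueeze(dim)
    expand_shape = [-1] * len(t.shape)
    expand_shape[dim] = k
    return t.expand(*expand_shape)

mask = torch.randn(2, 4, 5, 5) > 0   # arbitrary boolean mask
h = 8                                # arbitrary head count

old = merge_dims(0, 1, expand_dim(mask, 1, h))   # (16, 4, 5, 5)
new = repeat(mask, 'b ... -> (b h) ...', h = h)  # (16, 4, 5, 5)

assert torch.equal(old, new)  # identical result in one readable line
```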
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'local-attention',
   packages = find_packages(),
-  version = '1.5.1',
+  version = '1.5.2',
   license='MIT',
   description = 'Local attention, window with lookback, for language modeling',
   long_description_content_type = 'text/markdown',
