
Commit

add ability to length extrapolate to greater local attention window sizes at inference with xpos paper, needs to be turned on with use_xpos = True

lucidrains committed Feb 28, 2023
1 parent 1311aa3 commit 35ad623
Showing 5 changed files with 60 additions and 14 deletions.
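For orientation, here is a minimal usage sketch of the feature this commit adds, based only on the argument names visible in the diffs below; the tensor shapes and default values are assumptions, not part of the commit.

```python
# Hypothetical usage sketch: enable xpos at construction, then request a larger
# local attention window at inference time. Shapes and defaults are assumptions.
import torch
from local_attention import LocalAttention

attn = LocalAttention(
    dim = 64,            # per-head dimension, used to build the rotary / xpos embeddings
    window_size = 512,   # local attention window used during training
    causal = True,
    use_xpos = True      # the new flag introduced by this commit
)

q = torch.randn(2, 8, 2048, 64)
k = torch.randn(2, 8, 2048, 64)
v = torch.randn(2, 8, 2048, 64)

out = attn(q, k, v)                      # attend with the trained window size
out = attn(q, k, v, window_size = 1024)  # extrapolate to a wider window at inference
```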
8 changes: 8 additions & 0 deletions README.md
@@ -140,3 +140,11 @@ $ python train.py
primaryClass = {cs.CL}
}
```

```bibtex
@inproceedings{Sun2022ALT,
title = {A Length-Extrapolatable Transformer},
author = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
year = {2022}
}
```
17 changes: 11 additions & 6 deletions local_attention/local_attention.py
@@ -62,7 +62,8 @@ def __init__(
dim = None,
autopad = False,
exact_windowsize = False,
scale = None
scale = None,
use_xpos = False
):
super().__init__()
look_forward = default(look_forward, 0 if causal else 1)
@@ -86,24 +87,28 @@ def __init__(
# relative positions

self.rel_pos = None
self.use_xpos = use_xpos

if exists(rel_pos_emb_config) or exists(dim): # backwards compatible with old `rel_pos_emb_config` deprecated argument
if exists(rel_pos_emb_config):
dim = rel_pos_emb_config[0]
self.rel_pos = SinusoidalEmbeddings(dim)
self.rel_pos = SinusoidalEmbeddings(dim, use_xpos = use_xpos, scale_base = window_size // 2)

def forward(self, q, k, v, mask = None, input_mask = None):
def forward(self, q, k, v, mask = None, input_mask = None, window_size = None):
mask = default(mask, input_mask)

shape, autopad, pad_value, window_size, causal, look_backward, look_forward, shared_qk = q.shape, self.autopad, -1, self.window_size, self.causal, self.look_backward, self.look_forward, self.shared_qk
assert not (exists(window_size) and not self.use_xpos), 'cannot perform window size extrapolation if xpos is not turned on'

shape, autopad, pad_value, window_size, causal, look_backward, look_forward, shared_qk = q.shape, self.autopad, -1, default(window_size, self.window_size), self.causal, self.look_backward, self.look_forward, self.shared_qk

# https://github.com/arogozhnikov/einops/blob/master/docs/4-pack-and-unpack.ipynb
(q, packed_shape), (k, _), (v, _) = map(lambda t: pack([t], '* n d'), (q, k, v))

# rotary embeddings

if exists(self.rel_pos):
pos_emb = self.rel_pos(q)
q, k = apply_rotary_pos_emb(q, k, pos_emb)
pos_emb, scale = self.rel_pos(q)
q, k = apply_rotary_pos_emb(q, k, pos_emb, scale = scale)

# auto padding

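Continuing the sketch above: the new assert means a forward-time `window_size` override is only honored when xpos is enabled; without it the call should fail loudly (assumed behavior, mirroring the guard in the hunk above).

```python
# Without use_xpos, overriding the window size at forward time trips the new assert.
attn_no_xpos = LocalAttention(dim = 64, window_size = 512, causal = True)  # use_xpos defaults to False

try:
    attn_no_xpos(q, k, v, window_size = 1024)
except AssertionError as err:
    print(err)  # cannot perform window size extrapolation if xpos is not turned on
```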
42 changes: 36 additions & 6 deletions local_attention/rotary.py
@@ -3,23 +3,53 @@

from einops import rearrange

def exists(val):
return val is not None

class SinusoidalEmbeddings(nn.Module):
def __init__(self, dim):
def __init__(
self,
dim,
scale_base = None,
use_xpos = False
):
super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)

# xpos related

self.use_xpos = use_xpos
self.scale_base = scale_base

assert not (use_xpos and not exists(scale_base)), 'scale base must be defined if using xpos'

scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
self.register_buffer('scale', scale, persistent = False)

def forward(self, x):
n = x.shape[-2]
t = torch.arange(n, device = x.device).type_as(self.inv_freq)
seq_len, device = x.shape[-2], x.device

t = torch.arange(seq_len, device = x.device).type_as(self.inv_freq)
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
return torch.cat((freqs, freqs), dim=-1)
freqs = torch.cat((freqs, freqs), dim = -1)

if not self.use_xpos:
return freqs, torch.ones(1, device = device)

power = (t - (seq_len // 2)) / self.scale_base
scale = self.scale ** rearrange(power, 'n -> n 1')
scale = torch.cat((scale, scale), dim = -1)

return freqs, scale

def rotate_half(x):
x = rearrange(x, 'b ... (r d) -> b (...) r d', r = 2)
x1, x2 = x.unbind(dim = -2)
return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(q, k, freqs):
q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
def apply_rotary_pos_emb(q, k, freqs, scale = 1):
inv_scale = scale ** -1
q = (q * freqs.cos() * scale) + (rotate_half(q) * freqs.sin() * scale)
k = (k * freqs.cos() * inv_scale) + (rotate_half(k) * freqs.sin() * inv_scale)
return q, k
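The reason queries are multiplied by `scale` while keys are multiplied by its inverse: in the q·k dot product the two per-position factors combine into `base ** ((i - j) / scale_base)`, a term that depends only on the relative offset, which is what lets the embedding extrapolate. A small standalone check of that identity, reusing the definitions from the hunk above (the concrete numbers are arbitrary):

```python
import torch
from einops import rearrange

dim, seq_len, scale_base = 64, 16, 8

# same construction as SinusoidalEmbeddings above
base = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)    # (dim // 2,)
power = (torch.arange(seq_len) - seq_len // 2) / scale_base   # (seq_len,)
scale = base ** rearrange(power, 'n -> n 1')                  # (seq_len, dim // 2)

i, j = 10, 3
# query at position i is scaled by scale[i], key at position j by scale[j] ** -1,
# so the logit picks up scale[i] / scale[j] = base ** ((i - j) / scale_base)
assert torch.allclose(scale[i] * scale[j] ** -1, base ** ((i - j) / scale_base))
```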
5 changes: 4 additions & 1 deletion local_attention/transformer.py
@@ -47,6 +47,7 @@ def __init__(
prenorm = False,
qk_rmsnorm = False,
qk_scale = 8,
use_xpos = False,
**kwargs
):
super().__init__()
@@ -70,6 +71,7 @@ def __init__(
autopad = True,
scale = (qk_scale if qk_rmsnorm else None),
exact_windowsize = True,
use_xpos = use_xpos,
**kwargs
)

@@ -128,6 +130,7 @@ def __init__(
attn_dropout = 0.,
ff_dropout = 0.,
ignore_index = -1,
use_xpos = False,
**kwargs
):
super().__init__()
@@ -139,7 +142,7 @@

for _ in range(depth):
self.layers.append(nn.ModuleList([
LocalMHA(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = causal, window_size = local_attn_window_size, prenorm = True, **kwargs),
LocalMHA(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = causal, window_size = local_attn_window_size, use_xpos = use_xpos, prenorm = True, **kwargs),
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
]))

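At the wrapper level the flag is simply threaded through to every LocalMHA layer, so it becomes a single constructor argument. A sketch of how that might look; `num_tokens` and `max_seq_len` are assumed names that do not appear in this diff.

```python
import torch
from local_attention import LocalTransformer

model = LocalTransformer(
    num_tokens = 256,               # assumed: vocabulary size
    max_seq_len = 8192,             # assumed: maximum sequence length
    dim = 512,
    depth = 6,
    causal = True,
    local_attn_window_size = 256,
    use_xpos = True                 # passed down to every LocalMHA layer
)

tokens = torch.randint(0, 256, (1, 8192))
logits = model(tokens)
```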
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'local-attention',
packages = find_packages(),
version = '1.6.0',
version = '1.7.0',
license='MIT',
description = 'Local attention, window with lookback, for language modeling',
long_description_content_type = 'text/markdown',