diff --git a/README.md b/README.md
index 3a83b47..b988b33 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,7 @@ k = torch.randn(8, 2048, 64)
 v = torch.randn(8, 2048, 64)
 
 attn = LocalAttention(
+    dim = 64,                # dimension of each head (you need to pass this in for relative positional encoding)
     window_size = 512,       # window size. 512 is optimal, but 256 or 128 yields good enough results
     causal = True,           # auto-regressive or not
     look_backward = 1,       # each window looks at the window before
@@ -45,6 +46,7 @@ qk = torch.randn(8, 2048, 64)
 v = torch.randn(8, 2048, 64)
 
 attn = LocalAttention(
+    dim = 64,
     window_size = 512,
     shared_qk = True,
     causal = True
diff --git a/local_attention/local_attention.py b/local_attention/local_attention.py
index d9c39d8..b71ae47 100644
--- a/local_attention/local_attention.py
+++ b/local_attention/local_attention.py
@@ -5,14 +5,19 @@
 from operator import mul
 from functools import reduce
 
+from local_attention.rotary import SinusoidalEmbeddings, apply_rotary_pos_emb
+
 # constant
 
 TOKEN_SELF_ATTN_VALUE = -5e4 # carefully set for half precision to work
 
 # helper functions
 
+def exists(val):
+    return val is not None
+
 def default(value, d):
-    return d if value is None else value
+    return d if not exists(value) else value
 
 def to(t):
     return {'device': t.device, 'dtype': t.dtype}
@@ -49,32 +54,22 @@ def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):
     tensors = [padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)]
     return torch.cat(tensors, dim=dim)
 
-# Shaw's relative positional encoding per window
-
-def shift(x):
-    *_, i, j = x.shape
-    zero_pad = torch.zeros((*_, i, i), **to(x))
-    x = torch.cat([x, zero_pad], -1)
-    l = i + j - 1
-    x = x.view(*_, -1)
-    zero_pad = torch.zeros(*_, -x.size(-1) % l, **to(x))
-    shifted = torch.cat([x, zero_pad], -1).view(*_, -1, l)
-    return shifted[..., :i, i - 1:]
-
-class RelativePositionalEmbedding(nn.Module):
-    def __init__(self, dim, heads, length):
-        super().__init__()
-        self.scale = dim ** -0.5
-        self.weights = nn.Parameter(torch.zeros(length, heads, dim))
-
-    def forward(self, q):
-        emb = torch.einsum('bhnid,jhd->bhnij', q, self.weights.type(q.dtype)) * self.scale
-        return shift(emb)
-
 # main class
 
 class LocalAttention(nn.Module):
-    def __init__(self, window_size, causal = False, look_backward = 1, look_forward = None, dropout = 0., shared_qk = False, rel_pos_emb_config = None, autopad = False, exact_windowsize = False):
+    def __init__(
+        self,
+        window_size,
+        causal = False,
+        look_backward = 1,
+        look_forward = None,
+        dropout = 0.,
+        shared_qk = False,
+        rel_pos_emb_config = None,
+        dim = None,
+        autopad = False,
+        exact_windowsize = False
+    ):
         super().__init__()
         look_forward = default(look_forward, 0 if causal else 1)
         assert not (causal and look_forward > 0), 'you cannot look forward if causal'
@@ -91,11 +86,10 @@ def __init__(self, window_size, causal = False, look_backward = 1, look_forward
         self.shared_qk = shared_qk
 
         self.rel_pos = None
-        if rel_pos_emb_config is not None:
-            dim_head, heads = rel_pos_emb_config
-            rel_pos_length = window_size * (1 + look_forward + look_backward)
-            self.heads = heads
-            self.rel_pos = RelativePositionalEmbedding(dim_head, heads, rel_pos_length)
+        if exists(rel_pos_emb_config) or exists(dim):  # backwards compatible with old `rel_pos_emb_config` deprecated argument
+            if exists(rel_pos_emb_config):
+                dim = rel_pos_emb_config[0]
+            self.rel_pos = SinusoidalEmbeddings(dim)
 
     def forward(self, q, k, v, input_mask = None):
         shape = q.shape
@@ -103,6 +97,10 @@ def forward(self, q, k, v, input_mask = None):
         merge_into_batch = lambda t: t.reshape(-1, *t.shape[-2:])
         q, k, v = map(merge_into_batch, (q, k, v))
 
+        if exists(self.rel_pos):
+            pos_emb = self.rel_pos(q)
+            q, k = apply_rotary_pos_emb(q, k, pos_emb)
+
         if self.autopad:
             orig_t = q.shape[1]
             q, k, v = map(lambda t: pad_to_multiple(t, self.window_size, dim = -2), (q, k, v))
@@ -131,10 +129,6 @@ def forward(self, q, k, v, input_mask = None):
 
         dots = torch.einsum('bhie,bhje->bhij', bq, bk) * (e ** -0.5)
 
-        if self.rel_pos is not None:
-            rel_attn = self.rel_pos(bq.view(-1, self.heads, *bq.shape[1:])).reshape_as(dots)
-            dots = dots + rel_attn
-
         mask_value = max_neg_value(dots)
 
         if shared_qk:
diff --git a/local_attention/rotary.py b/local_attention/rotary.py
new file mode 100644
index 0000000..0e8ea8f
--- /dev/null
+++ b/local_attention/rotary.py
@@ -0,0 +1,29 @@
+import torch
+from torch import nn, einsum
+from einops import rearrange, repeat
+
+class SinusoidalEmbeddings(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer('inv_freq', inv_freq)
+
+    def forward(self, x):
+        n = x.shape[-2]
+        t = torch.arange(n, device = x.device).type_as(self.inv_freq)
+        sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
+        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
+        return emb[None, :, :]
+
+def rotate_every_two(x):
+    x = rearrange(x, '... (d j) -> ... d j', j = 2)
+    x1, x2 = x.unbind(dim = -1)
+    x = torch.stack((-x2, x1), dim = -1)
+    return rearrange(x, '... d j -> ... (d j)')
+
+def apply_rotary_pos_emb(q, k, sinu_pos):
+    sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j = 2)
+    sin, cos = sinu_pos.unbind(dim = -2)
+    sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j = 2), (sin, cos))
+    q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
+    return q, k
diff --git a/setup.py b/setup.py
index 21c244f..1c64f07 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'local-attention',
   packages = find_packages(),
-  version = '1.2.2',
+  version = '1.4.0',
   license='MIT',
   description = 'Local windowed attention, for language modeling',
   author = 'Phil Wang',
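
Review note (not part of the diff): a minimal sketch of how the new rotary helpers in local_attention/rotary.py compose, using illustrative tensor sizes and assuming the (batch * heads, seq, dim_head) layout that LocalAttention.forward produces after merge_into_batch.

# sanity-check sketch for the new rotary path (illustrative shapes only)
import torch
from local_attention.rotary import SinusoidalEmbeddings, apply_rotary_pos_emb

dim_head = 64                           # same value as the new `dim` argument of LocalAttention
rel_pos = SinusoidalEmbeddings(dim_head)

q = torch.randn(8, 2048, dim_head)      # (batch * heads, seq, dim_head)
k = torch.randn(8, 2048, dim_head)

pos_emb = rel_pos(q)                    # (1, 2048, dim_head) table of sin/cos frequencies
q, k = apply_rotary_pos_emb(q, k, pos_emb)

# shapes are preserved; positions are now encoded by rotating every pair of channels
assert q.shape == (8, 2048, dim_head) and k.shape == (8, 2048, dim_head)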