From 91735c2c378e6364ad7317f6ebaa937eb2076c34 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 5 Jul 2020 13:00:57 -0700
Subject: [PATCH] add relative positional encoding for each window, but do not advertise

---
 local_attention/local_attention.py | 36 +++++++++++++++++++++++++++++-
 setup.py                           |  2 +-
 2 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/local_attention/local_attention.py b/local_attention/local_attention.py
index 2a20dff..5a0e4fa 100644
--- a/local_attention/local_attention.py
+++ b/local_attention/local_attention.py
@@ -13,6 +13,9 @@
 def default(value, d):
     return d if value is None else value
 
+def to(t):
+    return {'device': t.device, 'dtype': t.dtype}
+
 def max_neg_value(tensor):
     return -torch.finfo(tensor.dtype).max
 
@@ -36,10 +39,32 @@ def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):
     tensors = [padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)]
     return torch.cat(tensors, dim=dim)
 
+# Shaw's relative positional encoding per window
+
+def shift(x):
+    *_, i, j = x.shape
+    zero_pad = torch.zeros((*_, i, i), **to(x))
+    x = torch.cat([x, zero_pad], -1)
+    l = i + j - 1
+    x = x.view(*_, -1)
+    zero_pad = torch.zeros(*_, -x.size(-1) % l, **to(x))
+    shifted = torch.cat([x, zero_pad], -1).view(*_, -1, l)
+    return shifted[..., :i, i - 1:]
+
+class RelativePositionalEmbedding(nn.Module):
+    def __init__(self, dim, heads, length):
+        super().__init__()
+        self.scale = dim ** -0.5
+        self.weights = nn.Parameter(torch.zeros(length, heads, dim))
+
+    def forward(self, q):
+        emb = torch.einsum('bhnid,jhd->bhnij', q, self.weights.type(q.dtype)) * self.scale
+        return shift(emb)
+
 # main class
 
 class LocalAttention(nn.Module):
-    def __init__(self, window_size, causal = False, look_backward = 1, look_forward = None, dropout = 0., shared_qk = False):
+    def __init__(self, window_size, causal = False, look_backward = 1, look_forward = None, dropout = 0., shared_qk = False, rel_pos_emb_config = None):
         super().__init__()
         self.look_forward = default(look_forward, 0 if causal else 1)
         assert not (causal and self.look_forward > 0), 'you cannot look forward if causal'
@@ -52,6 +77,11 @@ def __init__(self, window_size, causal = False, look_backward = 1, look_forward
 
         self.shared_qk = shared_qk
 
+        if rel_pos_emb_config is not None:
+            dim_head, heads = rel_pos_emb_config
+            self.heads = heads
+            self.rel_pos = RelativePositionalEmbedding(dim_head, heads, window_size * 2)
+
     def forward(self, q, k, v, input_mask = None):
         shape = q.shape
 
@@ -82,6 +112,10 @@ def forward(self, q, k, v, input_mask = None):
 
         dots = torch.einsum('bhie,bhje->bhij', bq, bk) * (e ** -0.5)
 
+        if self.rel_pos is not None:
+            rel_attn = self.rel_pos(bq.view(-1, self.heads, *bq.shape[1:])).reshape_as(dots)
+            dots = dots + rel_attn
+
         mask_value = max_neg_value(dots)
 
         if shared_qk:
diff --git a/setup.py b/setup.py
index 8b9ec2e..a34fec9 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'local-attention',
   packages = find_packages(),
-  version = '1.0.0',
+  version = '1.0.1',
   license='MIT',
   description = 'Local windowed attention, for language modeling',
   author = 'Phil Wang',
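
Usage sketch (not part of the patch above): a minimal, hedged example of how the new
rel_pos_emb_config argument might be exercised. It assumes the package exports
LocalAttention at the top level, and that q, k and v are passed with the head dimension
folded into the batch dimension, i.e. shaped (batch * heads, seq_len, dim_head), which is
what the bq.view(-1, self.heads, ...) call in the hunk above implies. The sketch always
passes rel_pos_emb_config, since the patch only assigns self.rel_pos in that branch.
The concrete sizes (batch, heads, seq_len, dim_head, window_size) are illustrative, not
part of the library API.

    import torch
    from local_attention import LocalAttention

    batch, heads, seq_len, dim_head = 2, 8, 1024, 64
    window_size = 128  # seq_len is assumed to be a multiple of window_size

    attn = LocalAttention(
        window_size = window_size,
        causal = True,
        dropout = 0.1,
        rel_pos_emb_config = (dim_head, heads)  # (dim_head, heads) tuple enables the per-window relative positional embedding
    )

    # queries / keys / values with heads merged into the batch dimension (assumed convention)
    q = torch.randn(batch * heads, seq_len, dim_head)
    k = torch.randn(batch * heads, seq_len, dim_head)
    v = torch.randn(batch * heads, seq_len, dim_head)

    out = attn(q, k, v)  # same shape as q: (batch * heads, seq_len, dim_head)

The embedding table created in __init__ spans window_size * 2 relative positions, which
lines up with the windowed key length under the causal defaults (look_backward = 1,
look_forward = 0); other look_backward / look_forward settings are not covered by this
sketch.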