add ability to configure xpos scale base
lucidrains committed Mar 2, 2023
1 parent 35ad623 commit c74a5bf
Showing 3 changed files with 13 additions and 4 deletions.
10 changes: 8 additions & 2 deletions local_attention/local_attention.py
@@ -63,7 +63,8 @@ def __init__(
         autopad = False,
         exact_windowsize = False,
         scale = None,
-        use_xpos = False
+        use_xpos = False,
+        xpos_scale_base = None
     ):
         super().__init__()
         look_forward = default(look_forward, 0 if causal else 1)
@@ -92,7 +93,12 @@ def __init__(
         if exists(rel_pos_emb_config) or exists(dim): # backwards compatible with old `rel_pos_emb_config` deprecated argument
             if exists(rel_pos_emb_config):
                 dim = rel_pos_emb_config[0]
-            self.rel_pos = SinusoidalEmbeddings(dim, use_xpos = use_xpos, scale_base = window_size // 2)
+
+            self.rel_pos = SinusoidalEmbeddings(
+                dim,
+                use_xpos = use_xpos,
+                scale_base = default(xpos_scale_base, window_size // 2)
+            )
 
     def forward(self, q, k, v, mask = None, input_mask = None, window_size = None):
         mask = default(mask, input_mask)
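A minimal usage sketch of the new option, assuming the README-style LocalAttention interface: the tensor shapes and the dim keyword are illustrative, while use_xpos and xpos_scale_base (which falls back to window_size // 2 when left as None) come from the hunk above.

import torch
from local_attention import LocalAttention

# illustrative shapes: (batch, heads, seq_len, dim_head)
q = torch.randn(2, 8, 2048, 64)
k = torch.randn(2, 8, 2048, 64)
v = torch.randn(2, 8, 2048, 64)

attn = LocalAttention(
    dim = 64,               # head dimension for the sinusoidal / xpos embeddings (assumed)
    window_size = 512,
    causal = True,
    use_xpos = True,
    xpos_scale_base = 512   # new in this commit; None falls back to window_size // 2
)

out = attn(q, k, v)         # (2, 8, 2048, 64)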
5 changes: 4 additions & 1 deletion local_attention/transformer.py
@@ -48,6 +48,7 @@ def __init__(
         qk_rmsnorm = False,
         qk_scale = 8,
         use_xpos = False,
+        xpos_scale_base = None,
         **kwargs
     ):
         super().__init__()
@@ -72,6 +73,7 @@ def __init__(
             scale = (qk_scale if qk_rmsnorm else None),
             exact_windowsize = True,
             use_xpos = use_xpos,
+            xpos_scale_base = xpos_scale_base,
             **kwargs
         )
 
@@ -131,6 +133,7 @@ def __init__(
         ff_dropout = 0.,
         ignore_index = -1,
         use_xpos = False,
+        xpos_scale_base = None,
         **kwargs
     ):
         super().__init__()
@@ -142,7 +145,7 @@
 
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                LocalMHA(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = causal, window_size = local_attn_window_size, use_xpos = use_xpos, prenorm = True, **kwargs),
+                LocalMHA(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = causal, window_size = local_attn_window_size, use_xpos = use_xpos, xpos_scale_base = xpos_scale_base, prenorm = True, **kwargs),
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
             ]))
 
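The same knob threads through the transformer wrapper in this file. A sketch, assuming the wrapper is the exported LocalTransformer and that num_tokens / max_seq_len are among its other constructor arguments (they are not shown in this diff):

import torch
from local_attention import LocalTransformer   # assumed export name

model = LocalTransformer(
    num_tokens = 256,             # assumed vocabulary-size argument
    dim = 512,
    depth = 6,
    max_seq_len = 8192,           # assumed max-length argument
    causal = True,
    local_attn_window_size = 256,
    use_xpos = True,
    xpos_scale_base = 128         # forwarded to each LocalMHA layer per the hunk above
)

x = torch.randint(0, 256, (1, 8192))
logits = model(x)                 # (1, 8192, 256)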
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'local-attention',
   packages = find_packages(),
-  version = '1.7.0',
+  version = '1.7.1',
   license='MIT',
   description = 'Local attention, window with lookback, for language modeling',
   long_description_content_type = 'text/markdown',
