diff --git a/local_attention/local_attention.py b/local_attention/local_attention.py
index b437c20..53c820f 100644
--- a/local_attention/local_attention.py
+++ b/local_attention/local_attention.py
@@ -63,7 +63,8 @@ def __init__(
         autopad = False,
         exact_windowsize = False,
         scale = None,
-        use_xpos = False
+        use_xpos = False,
+        xpos_scale_base = None
     ):
         super().__init__()
         look_forward = default(look_forward, 0 if causal else 1)
@@ -92,7 +93,12 @@ def __init__(
         if exists(rel_pos_emb_config) or exists(dim):  # backwards compatible with old `rel_pos_emb_config` deprecated argument
             if exists(rel_pos_emb_config):
                 dim = rel_pos_emb_config[0]
-            self.rel_pos = SinusoidalEmbeddings(dim, use_xpos = use_xpos, scale_base = window_size // 2)
+
+            self.rel_pos = SinusoidalEmbeddings(
+                dim,
+                use_xpos = use_xpos,
+                scale_base = default(xpos_scale_base, window_size // 2)
+            )
 
     def forward(self, q, k, v, mask = None, input_mask = None, window_size = None):
         mask = default(mask, input_mask)
diff --git a/local_attention/transformer.py b/local_attention/transformer.py
index a80c26d..677aca1 100644
--- a/local_attention/transformer.py
+++ b/local_attention/transformer.py
@@ -48,6 +48,7 @@ def __init__(
         qk_rmsnorm = False,
         qk_scale = 8,
         use_xpos = False,
+        xpos_scale_base = None,
         **kwargs
     ):
         super().__init__()
@@ -72,6 +73,7 @@ def __init__(
             scale = (qk_scale if qk_rmsnorm else None),
             exact_windowsize = True,
             use_xpos = use_xpos,
+            xpos_scale_base = xpos_scale_base,
             **kwargs
         )
 
@@ -131,6 +133,7 @@ def __init__(
         ff_dropout = 0.,
         ignore_index = -1,
         use_xpos = False,
+        xpos_scale_base = None,
         **kwargs
     ):
         super().__init__()
@@ -142,7 +145,7 @@ def __init__(
 
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                LocalMHA(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = causal, window_size = local_attn_window_size, use_xpos = use_xpos, prenorm = True, **kwargs),
+                LocalMHA(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = causal, window_size = local_attn_window_size, use_xpos = use_xpos, xpos_scale_base = xpos_scale_base, prenorm = True, **kwargs),
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
             ]))
 
diff --git a/setup.py b/setup.py
index 253f460..a8b5057 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'local-attention',
   packages = find_packages(),
-  version = '1.7.0',
+  version = '1.7.1',
   license='MIT',
   description = 'Local attention, window with lookback, for language modeling',
   long_description_content_type = 'text/markdown',
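
Usage note (illustrative, not part of the patch): the new `xpos_scale_base` argument lets callers override the xpos scale base, which this change stops hard-coding to `window_size // 2`; leaving it as `None` keeps the old behavior through `default(xpos_scale_base, window_size // 2)`. A minimal sketch against the public `LocalAttention` API shown above; the tensor shapes and hyperparameter values here are assumptions for illustration only:

    import torch
    from local_attention import LocalAttention

    # passing `dim` (the per-head dimension) enables relative positional
    # embeddings, and `use_xpos = True` switches them to xpos
    attn = LocalAttention(
        dim = 64,
        window_size = 512,
        causal = True,
        use_xpos = True,
        xpos_scale_base = 256  # overrides the former fixed window_size // 2
    )

    q = torch.randn(2, 8, 2048, 64)  # (batch, heads, seq, dim_head)
    k = torch.randn(2, 8, 2048, 64)
    v = torch.randn(2, 8, 2048, 64)

    out = attn(q, k, v)  # (2, 8, 2048, 64)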