diff --git a/local_attention/local_attention.py b/local_attention/local_attention.py
index afc8964..d9c39d8 100644
--- a/local_attention/local_attention.py
+++ b/local_attention/local_attention.py
@@ -147,7 +147,7 @@ def forward(self, q, k, v, input_mask = None):
 
             if self.exact_windowsize:
                 max_causal_window_size = (self.window_size * self.look_backward)
-                mask = mask & (bq_t[:, :, :, None] > (bq_k[:, :, None, :] + max_causal_window_size))
+                mask = mask | (bq_t[:, :, :, None] > (bq_k[:, :, None, :] + max_causal_window_size))
 
             dots.masked_fill_(mask, mask_value)
             del mask
diff --git a/setup.py b/setup.py
index d2d3903..21c244f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'local-attention',
   packages = find_packages(),
-  version = '1.2.1',
+  version = '1.2.2',
   license='MIT',
   description = 'Local windowed attention, for language modeling',
   author = 'Phil Wang',
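
For context on the one-character fix above: in this code path, True entries of `mask` are the positions removed by `masked_fill_`, so the exact-window condition (keys older than `window_size * look_backward`) must be OR'd into the causal mask, not AND'd with it. The two conditions are mutually exclusive (a key cannot be both in the future and too far in the past), so intersecting them produced an all-False mask and silently disabled masking whenever `exact_windowsize` was set. A minimal standalone sketch of that behavior, using illustrative toy positions (`q_pos`, `k_pos`) rather than the library's bucketed tensors:

import torch

# Toy settings; assumed for illustration only.
window_size   = 4
look_backward = 1
max_causal_window_size = window_size * look_backward

seq_len = 8
q_pos = torch.arange(seq_len)[:, None]   # query positions, column vector
k_pos = torch.arange(seq_len)[None, :]   # key positions, row vector

causal_mask  = k_pos > q_pos                              # True = key lies in the future
too_far_back = q_pos > (k_pos + max_causal_window_size)   # True = key is outside the exact window

buggy = causal_mask & too_far_back   # intersection of mutually exclusive conditions
fixed = causal_mask | too_far_back   # union: future keys AND out-of-window keys are masked

print(buggy.sum().item())  # 0 -- the old '&' left nothing masked
print(fixed.sum().item())  # > 0 -- '|' masks both future and stale keys

This mirrors the intent of the patched line; variable names and shapes here are simplified assumptions, not the library's actual `bq_t` / `bq_k` bucketed layout.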