From 6f30c84ff7772663b83b9e4c1806390098b63599 Mon Sep 17 00:00:00 2001
From: Jakob Wagner
Date: Wed, 24 Jul 2024 15:10:58 +0200
Subject: [PATCH 1/4] change application of uniform attention masks to an explicit implementation
---
 src/continuiti/networks/attention.py         | 24 +++++++------
 .../networks/multi_head_attention.py         | 34 ++++++++++---------
 .../networks/scaled_dot_product_attention.py | 26 +++++++++++---
 tests/networks/test_multi_head.py            | 12 ++++---
 tests/networks/test_scaled_dot.py            | 32 ++++++++++++-----
 5 files changed, 84 insertions(+), 44 deletions(-)

diff --git a/src/continuiti/networks/attention.py b/src/continuiti/networks/attention.py
index 0fa6aea2..b09337d5 100644
--- a/src/continuiti/networks/attention.py
+++ b/src/continuiti/networks/attention.py
@@ -7,15 +7,17 @@
 from abc import abstractmethod
 import torch.nn as nn
 import torch
+from typing import Optional


-class Attention(nn.Module):
-    """Base class for various attention implementations.
+class UniformMaskAttention(nn.Module):
+    """Base class for various attention implementations with uniform masking.

-    Attention assigns different parts of an input varying importance without set
-    kernels. The importance of different components is designated using "soft"
-    weights. These weights are assigned according to specific algorithms (e.g.
+    Attention assigns different parts of an input varying importance without set kernels. The importance of different
+    components is designated using "soft" weights. These weights are assigned according to specific algorithms (e.g.
     scaled-dot-product attention).
+    Uniform masking refers to the characteristic that all queries use the same mask. This restriction allows the query
+    dimension to be removed from the mask. All queries have access to the same key/value pairs.
     """

     def __init__(self):
@@ -27,18 +29,18 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attn_mask: torch.Tensor = None,
+        attn_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Calculates the attention scores.

         Args:
-            query: query tensor; shape (batch_size, target_seq_length, hidden_dim)
-            key: key tensor; shape (batch_size, source_seq_length, hidden_dim)
-            value: value tensor; shape (batch_size, source_seq_length, hidden_dim)
+            query: query tensor; shape (batch_size, target_seq_length, hidden_dim).
+            key: key tensor; shape (batch_size, source_seq_length, hidden_dim).
+            value: value tensor; shape (batch_size, source_seq_length, hidden_dim).
             attn_mask: tensor indicating which values are used to calculate the output;
-                shape (batch_size, target_seq_length, source_seq_length)
+                shape (batch_size, source_seq_length).

         Returns:
             tensor containing the outputs of the attention implementation;
-                shape (batch_size, target_seq_length, hidden_dim)
+                shape (batch_size, target_seq_length, hidden_dim).
         """
diff --git a/src/continuiti/networks/multi_head_attention.py b/src/continuiti/networks/multi_head_attention.py
index e6eacb6f..8200cee2 100644
--- a/src/continuiti/networks/multi_head_attention.py
+++ b/src/continuiti/networks/multi_head_attention.py
@@ -6,19 +6,21 @@
 import torch
 import torch.nn as nn
+from typing import Optional

-from .attention import Attention
+from .attention import UniformMaskAttention
 from .scaled_dot_product_attention import ScaledDotProductAttention


-class MultiHeadAttention(Attention):
-    r"""Multi-Head Attention module.
+class MultiHeadAttention(UniformMaskAttention):
+    r"""Multi-Head Attention module with uniform mask.
     Module as described in the paper [Attention is All you Need](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
-    with optional bias for the projections. This implementation allows to use
-    attention implementations other than the standard scaled dot product
-    attention implemented by the MultiheadAttention PyTorch module.
+    with optional bias for the projections. This implementation supports attention implementations other than the
+    standard scaled dot product attention implemented by the MultiheadAttention PyTorch module. Additionally, this
+    implementation assumes that the attention mask is applied uniformly for every batch (the mask for each key-value
+    pair is the same for all queries of one sample).

     $$MultiHead(Q,K,V)=Concat(head_1,\dots,head_n)W^O + b^O$$

@@ -39,7 +41,7 @@ def __init__(
         self,
         hidden_dim: int,
         n_heads: int,
-        attention: Attention = None,
+        attention: Optional[UniformMaskAttention] = None,
         dropout_p: float = 0,
         bias: bool = True,
     ):
@@ -70,7 +72,7 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attn_mask: torch.Tensor = None,
+        attn_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         r"""Compute the attention scores.

@@ -78,7 +80,7 @@ def forward(
             query: Query tensor of shape (batch_size, target_sequence_length, hidden_dim).
             key: Key tensor of shape (batch_size, source_sequence_length, hidden_dim).
             value: Value tensor of shape (batch_size, source_sequence_length, hidden_dim).
-            attn_mask: Attention mask of shape (batch_size, target_sequence_length, source_sequence_length).
+            attn_mask: Attention mask of shape (batch_size, source_sequence_length).

         Returns:
             Attention scores of shape (batch_size, target_sequence_length, hidden_dim).
@@ -98,7 +100,6 @@ def forward(
         batch_size = query.size(0)
         src_len = key.size(1)
-        tgt_len = query.size(1)

         # project values
         query = self.query_project(query)
@@ -112,18 +113,19 @@ def forward(
         # reshape attention mask to match heads
         if attn_mask is not None:
+            assert (
+                attn_mask.ndim == 2
+            ), f"Expected exactly 2 dimensions in the mask, but found {attn_mask.ndim}."
             assert (
                 attn_mask.size(0) == batch_size
             ), "Attention mask batch size does not match input tensors."
             assert (
-                attn_mask.size(1) == tgt_len
-            ), "First dimension of the attention mask needs to match target length."
-            assert (
-                attn_mask.size(2) == src_len
+                attn_mask.size(1) == src_len
             ), "Second dimension of the attention mask needs to match source length."
-            attn_mask = attn_mask.unsqueeze(1)  # mask for a single head
-            attn_mask = attn_mask.repeat(1, self.n_heads, 1, 1)  # mask for every head
+            attn_mask = attn_mask.unsqueeze(
+                1
+            )  # apply attention mask uniformly to all heads

         # perform attention
         attn_out = self.attention(
diff --git a/src/continuiti/networks/scaled_dot_product_attention.py b/src/continuiti/networks/scaled_dot_product_attention.py
index 752fb765..813b77e3 100644
--- a/src/continuiti/networks/scaled_dot_product_attention.py
+++ b/src/continuiti/networks/scaled_dot_product_attention.py
@@ -4,13 +4,14 @@
 Scaled dot product attention module.
 """
 import torch
+from typing import Optional

-from .attention import Attention
+from .attention import UniformMaskAttention
 from torch.nn.functional import scaled_dot_product_attention


-class ScaledDotProductAttention(Attention):
-    """Scaled dot product attention module.
+class ScaledDotProductAttention(UniformMaskAttention):
+    """Scaled dot product attention module with uniform mask.
This module is a wrapper for the torch implementation of the scaled dot product attention mechanism as described in the paper "Attention Is All You @@ -29,9 +30,26 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attn_mask: torch.Tensor = None, + attn_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: + """Calculate attention scores. + + Args: + query: query tensor; shape (batch_size, target_seq_length, hidden_dim). + key: key tensor; shape (batch_size, source_seq_length, hidden_dim). + value: value tensor; shape (batch_size, source_seq_length, hidden_dim). + attn_mask: tensor indicating which values are used to calculate the output; shape + (batch_size, source_seq_length). Defaults to None. + + Returns: + tensor containing the outputs of the attention implementation; shape + (batch_size, target_seq_length, hidden_dim). + """ dropout_p = self.dropout_p if self.training else 0.0 + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(1) + return scaled_dot_product_attention( query=query, key=key, diff --git a/tests/networks/test_multi_head.py b/tests/networks/test_multi_head.py index 7a87e58d..7c602f9a 100644 --- a/tests/networks/test_multi_head.py +++ b/tests/networks/test_multi_head.py @@ -81,7 +81,7 @@ def test_gradient_flow(self, some_multi_head_attn): def test_equal_to_torch(self, random_qkv): q, k, v = random_qkv - mask = torch.rand(q.size(0), q.size(1), k.size(1)) < 0.2 + mask = torch.rand(q.size(0), k.size(1)) < 0.2 heads = 2 embedding_dim = q.size(-1) @@ -122,7 +122,9 @@ def test_equal_to_torch(self, random_qkv): out = attn(q, k, v, attn_mask=mask) # torch applies masks differently to scaled-dot-product and multi-head attention (inversed). - gt_mask = torch.repeat_interleave(mask, heads, 0).logical_not() + gt_mask = torch.repeat_interleave( + mask.unsqueeze(1).repeat(1, q.size(1), 1), heads, 0 + ).logical_not() ground_truth, _ = gt_attn(q, k, v, need_weights=False, attn_mask=gt_mask) assert torch.allclose( @@ -133,7 +135,7 @@ def test_full_mask_identical_to_none(self, random_qkv): heads = 2 q, k, v = random_qkv - mask = torch.ones(q.size(0), q.size(1), k.size(1)) + mask = torch.ones(q.size(0), k.size(1), dtype=torch.bool) attn = MultiHeadAttention( hidden_dim=q.size(-1), @@ -155,8 +157,8 @@ def test_mask_all_but_one(self, random_qkv): v.requires_grad = True # Masks out the last kvs - mask = torch.ones(q.size(0), q.size(1), k.size(1), dtype=torch.bool) - mask[:, :, -1] = 0 + mask = torch.ones(q.size(0), k.size(1), dtype=torch.bool) + mask[:, -1] = 0 attn = MultiHeadAttention( hidden_dim=q.size(-1), diff --git a/tests/networks/test_scaled_dot.py b/tests/networks/test_scaled_dot.py index bb8d1286..0fc8857e 100644 --- a/tests/networks/test_scaled_dot.py +++ b/tests/networks/test_scaled_dot.py @@ -4,19 +4,35 @@ from continuiti.networks import ScaledDotProductAttention -def test_forward_correct(): +class TestScaledDotProductAttention: batch_size = 3 query_size = 5 key_val_size = 7 hidden_dim = 11 - query = torch.rand(batch_size, query_size, hidden_dim) - key = torch.rand(batch_size, key_val_size, hidden_dim) - value = torch.rand(batch_size, key_val_size, hidden_dim) + def test_forward_correct(self): + query = torch.rand(self.batch_size, self.query_size, self.hidden_dim) + key = torch.rand(self.batch_size, self.key_val_size, self.hidden_dim) + value = torch.rand(self.batch_size, self.key_val_size, self.hidden_dim) - attn = ScaledDotProductAttention() + attn = ScaledDotProductAttention() - out = attn(query, key, value) - gt_out = 
scaled_dot_product_attention(query, key, value) + out = attn(query, key, value) + gt_out = scaled_dot_product_attention(query, key, value) - assert torch.allclose(out, gt_out) + assert torch.allclose(out, gt_out) + + def test_masked_correct(self): + query = torch.rand(self.batch_size, self.query_size, self.hidden_dim) + key = torch.rand(self.batch_size, self.key_val_size, self.hidden_dim) + value = torch.rand(self.batch_size, self.key_val_size, self.hidden_dim) + mask = torch.rand(self.batch_size, self.key_val_size) >= 0.2 + + attn = ScaledDotProductAttention() + + out = attn(query, key, value, mask) + + gt_mask = mask.unsqueeze(1).repeat(1, query.size(1), 1) + out_gt = scaled_dot_product_attention(query, key, value, gt_mask) + + assert torch.allclose(out, out_gt) From 087f82e60425e9ad75f9249a8102583df510d6b2 Mon Sep 17 00:00:00 2001 From: Jakob Wagner Date: Wed, 24 Jul 2024 15:18:26 +0200 Subject: [PATCH 2/4] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d3e8c5c..e48ac5b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## 0.2.0 - Add `Attention` base class, `MultiHeadAttention`, and `ScaledDotProductAttention` classes. +- Remove `Attention` base class, add `UniformMaskAttention` base class. ## 0.1.0 From c6ce2d3cb4ae700310f64d143dee2d28f45d7dee Mon Sep 17 00:00:00 2001 From: Jakob Wagner Date: Wed, 24 Jul 2024 15:40:16 +0200 Subject: [PATCH 3/4] add heterogeneous normalized attention --- src/continuiti/networks/__init__.py | 2 + .../heterogeneous_normalized_attention.py | 86 ++++++++++++++++++ ...test_heterogeneous_normalized_attention.py | 89 +++++++++++++++++++ 3 files changed, 177 insertions(+) create mode 100644 src/continuiti/networks/heterogeneous_normalized_attention.py create mode 100644 tests/networks/test_heterogeneous_normalized_attention.py diff --git a/src/continuiti/networks/__init__.py b/src/continuiti/networks/__init__.py index 49fce95c..ee0756ef 100644 --- a/src/continuiti/networks/__init__.py +++ b/src/continuiti/networks/__init__.py @@ -8,10 +8,12 @@ from .deep_residual_network import DeepResidualNetwork from .multi_head_attention import MultiHeadAttention from .scaled_dot_product_attention import ScaledDotProductAttention +from .heterogeneous_normalized_attention import HeterogeneousNormalizedAttention __all__ = [ "FullyConnected", "DeepResidualNetwork", "MultiHeadAttention", "ScaledDotProductAttention", + "HeterogeneousNormalizedAttention", ] diff --git a/src/continuiti/networks/heterogeneous_normalized_attention.py b/src/continuiti/networks/heterogeneous_normalized_attention.py new file mode 100644 index 00000000..9b3ce37c --- /dev/null +++ b/src/continuiti/networks/heterogeneous_normalized_attention.py @@ -0,0 +1,86 @@ +""" +`continuiti.networks.heterogeneous_normalized_attention` + +Heterogeneous normalized attention block introduced by Hao et al. (https://proceedings.mlr.press/v202/hao23c). +""" + +import torch +import torch.nn as nn +from torch.nn.functional import softmax +from typing import Optional + +from .attention import UniformMaskAttention + + +class HeterogeneousNormalizedAttention(UniformMaskAttention): + r"""Heterogeneous normalized attention with uniform masks. + + Computes the normalization coefficient alpha for attention mechanisms, as proposed by Hao et al. in "GNOT: A + General Neural Operator Transformer for Operator Learning" (https://proceedings.mlr.press/v202/hao23c). 
The
+    attention score is calculated by normalizing the keys and queries,
+    $$\tilde{q}_{i,j} = \frac{\exp(q_{i,j})}{\sum_j\exp(q_{i,j})}$$,
+    $$\tilde{k}_{i,j} = \frac{\exp(k_{i,j})}{\sum_j\exp(k_{i,j})}$$, and then calculating the attention without
+    softmax using $$z_t=\sum_i \frac{\tilde{q}_t \cdot \tilde{k}_i}{\sum_j \tilde{q}_t \cdot \tilde{k}_j}\cdot v_i$$.
+    The computational cost of this is O((M+N)n_e^2) (M = number of keys/values, N = number of queries,
+    n_e = embedding dimension), which is linear with respect to the sequence length.
+
+    Args:
+        tau: Temperature parameter controlling the sharpness of the softmax operation.
+        dropout_p: Dropout probability.
+    """
+
+    def __init__(self, tau: float = 1.0, dropout_p: float = 0.0):
+        super().__init__()
+        self.tau = tau
+        self.dropout = nn.Dropout(p=dropout_p)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        r"""Forward pass.
+
+        Args:
+            query: Tensor of shape (batch_size, ..., d_q, embedding_dim).
+            key: Tensor of shape (batch_size, ..., d_kv, embedding_dim).
+            value: Tensor of shape (batch_size, ..., d_kv, embedding_dim).
+            attn_mask: Attention mask of shape (batch_size, ..., d_kv). A boolean mask where a True indicates that
+                a value should be taken into consideration in the calculations.
+
+        Returns:
+            Attention output of shape (batch_size, ..., d_q, embedding_dim).
+        """
+        assert (
+            query.ndim == key.ndim
+        ), "Number of dimensions in queries and keys should match."
+        assert (
+            query.ndim == value.ndim
+        ), "Number of dimensions in queries and values should match."
+
+        if attn_mask is not None:
+            attn_mask = attn_mask.unsqueeze(-1)
+            key = torch.masked_fill(key, attn_mask.logical_not(), float("-inf"))
+            value = torch.masked_fill(value, attn_mask.logical_not(), 0.0)
+
+        q_tilde = softmax(query, dim=-1)
+        if attn_mask is not None:
+            q_tilde = torch.nan_to_num(
+                q_tilde, nan=0.0
+            )  # masking might ignore queries entirely, resulting in nan in softmax
+
+        k_tilde = softmax(key / self.tau, dim=-1)
+        if attn_mask is not None:
+            k_tilde = torch.nan_to_num(
+                k_tilde, nan=0.0
+            )  # masking might ignore keys entirely, resulting in nan in softmax
+
+        alpha = torch.matmul(q_tilde, k_tilde.transpose(-1, -2))
+        alpha = torch.sum(alpha, dim=-1, keepdim=True)
+        if attn_mask is not None:
+            alpha[alpha == 0.0] = 1.0  # numerical stability
+
+        mat = k_tilde * value
+        mat = self.dropout(mat)
+        mat = torch.sum(mat, dim=-2, keepdim=True)
+
+        return q_tilde * mat / alpha
diff --git a/tests/networks/test_heterogeneous_normalized_attention.py b/tests/networks/test_heterogeneous_normalized_attention.py
new file mode 100644
index 00000000..a8f394c8
--- /dev/null
+++ b/tests/networks/test_heterogeneous_normalized_attention.py
@@ -0,0 +1,89 @@
+import pytest
+import torch
+from torch.nn.functional import scaled_dot_product_attention
+
+from continuiti.networks import HeterogeneousNormalizedAttention
+
+
+@pytest.fixture(scope="module")
+def random_query_key_value_pair():
+    batch_size = 3
+    query_size = 5
+    key_val_size = 7
+    hidden_dim = 11
+
+    query = torch.rand(batch_size, query_size, hidden_dim)
+    key = torch.rand(batch_size, key_val_size, hidden_dim)
+    value = torch.rand(batch_size, key_val_size, hidden_dim)
+
+    return query, key, value
+
+
+class TestHeterogeneousNormalized:
+    def test_can_initialize(self):
+        _ = HeterogeneousNormalizedAttention()
+        assert True
+
+    def test_shape_correct(self, random_query_key_value_pair):
+        query, key, value = random_query_key_value_pair
+
+        attn = HeterogeneousNormalizedAttention()
+
+        out = attn(query, key, value)
+        gt_out = scaled_dot_product_attention(query, key, value)
+
+        assert out.shape == gt_out.shape
+
+    def test_gradient_flow(self, random_query_key_value_pair):
+        query, key, value = random_query_key_value_pair
+        query.requires_grad = True
+        key.requires_grad = True
+        value.requires_grad = True
+
+        attn = HeterogeneousNormalizedAttention()
+
+        out = attn(query, key, value)
+
+        out.sum().backward()
+
+        assert query.grad is not None, "Gradients not flowing to query"
+        assert key.grad is not None, "Gradients not flowing to key"
+        assert value.grad is not None, "Gradients not flowing to value"
+
+    def test_zero_input(self, random_query_key_value_pair):
+        query, key, value = random_query_key_value_pair
+        attn = HeterogeneousNormalizedAttention()
+        out = attn(query, key, torch.zeros(value.shape))
+        assert torch.allclose(torch.zeros(out.shape), out)
+
+    def test_mask_forward(self, random_query_key_value_pair):
+        query, key, value = random_query_key_value_pair
+        attn = HeterogeneousNormalizedAttention()
+
+        # masks in the operator setting should always be block tensors with the upper left block of the last two
+        # dimensions being True. The dimensions of the True block correspond to the number of sensors and evaluations.
+        mask = torch.rand(query.size(0), key.size(1)) >= 0.2
+
+        out = attn(query, key, value, mask)
+
+        assert isinstance(out, torch.Tensor)
+
+    def test_mask_correct(self, random_query_key_value_pair):
+        query, key, value = random_query_key_value_pair
+        attn = HeterogeneousNormalizedAttention()
+
+        out_gt = attn(query, key, value)
+
+        key_rand = torch.rand(key.shape)
+        key_masked = torch.cat([key, key_rand], dim=1)
+
+        value_rand = torch.rand(value.shape)
+        value_masked = torch.cat([value, value_rand], dim=1)
+
+        true_mask = torch.ones(value.size(0), value.size(1), dtype=torch.bool)
+        attn_mask = torch.cat([true_mask, ~true_mask], dim=1)
+
+        out_masked = attn(query, key_masked, value_masked, attn_mask)
+
+        assert torch.allclose(out_gt, out_masked)

From 929ab979d1c00bb01a8462ca89fa2e435bbe3e53 Mon Sep 17 00:00:00 2001
From: Jakob Wagner
Date: Wed, 24 Jul 2024 15:41:32 +0200
Subject: [PATCH 4/4] update changelog
---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e48ac5b7..8bd52b29 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,7 @@
 - Add `Attention` base class, `MultiHeadAttention`, and `ScaledDotProductAttention` classes.
 - Remove `Attention` base class, add `UniformMaskAttention` base class.
+- Add `HeterogeneousNormalizedAttention` class.

 ## 0.1.0
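A quick usage sketch of the uniform-mask attention classes touched by this patch series (not part of the patches; the tensor sizes and the variable names `mha`/`hna` are illustrative, and it assumes the patched `continuiti.networks` package is importable):

```python
import torch

from continuiti.networks import (
    HeterogeneousNormalizedAttention,
    MultiHeadAttention,
    ScaledDotProductAttention,
)

batch_size, tgt_len, src_len, hidden_dim = 3, 5, 7, 8

query = torch.rand(batch_size, tgt_len, hidden_dim)
key = torch.rand(batch_size, src_len, hidden_dim)
value = torch.rand(batch_size, src_len, hidden_dim)

# Uniform mask: one boolean row per sample, shared by all queries of that
# sample; True marks the key/value pairs that take part in the attention.
attn_mask = torch.rand(batch_size, src_len) >= 0.2
attn_mask[:, 0] = True  # keep at least one key/value pair per sample unmasked

mha = MultiHeadAttention(
    hidden_dim=hidden_dim,
    n_heads=2,
    attention=ScaledDotProductAttention(),
)
out_mha = mha(query, key, value, attn_mask=attn_mask)
assert out_mha.shape == (batch_size, tgt_len, hidden_dim)

hna = HeterogeneousNormalizedAttention(tau=1.0)
out_hna = hna(query, key, value, attn_mask=attn_mask)
assert out_hna.shape == (batch_size, tgt_len, hidden_dim)
```

Because the mask carries no query dimension, the same boolean row is broadcast to every query (and every head in the multi-head case), which is exactly the uniform-masking restriction described in the `UniformMaskAttention` docstring.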