Feature: Heterogeneous Normalized Attention #153

Draft · wants to merge 4 commits into main
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -3,6 +3,8 @@
## 0.2.0

- Add `Attention` base class, `MultiHeadAttention`, and `ScaledDotProductAttention` classes.
- Remove `Attention` base class, add `UniformMaskAttention` base class.
- Add `HeterogeneousNormalizedAttention` class.

## 0.1.0

2 changes: 2 additions & 0 deletions src/continuiti/networks/__init__.py
@@ -8,10 +8,12 @@
from .deep_residual_network import DeepResidualNetwork
from .multi_head_attention import MultiHeadAttention
from .scaled_dot_product_attention import ScaledDotProductAttention
from .heterogeneous_normalized_attention import HeterogeneousNormalizedAttention

__all__ = [
"FullyConnected",
"DeepResidualNetwork",
"MultiHeadAttention",
"ScaledDotProductAttention",
"HeterogeneousNormalizedAttention",
]
24 changes: 13 additions & 11 deletions src/continuiti/networks/attention.py
@@ -7,15 +7,17 @@
from abc import abstractmethod
import torch.nn as nn
import torch
from typing import Optional


class Attention(nn.Module):
"""Base class for various attention implementations.
class UniformMaskAttention(nn.Module):
"""Base class for various attention implementations with uniform masking.

Attention assigns different parts of an input varying importance without set
kernels. The importance of different components is designated using "soft"
weights. These weights are assigned according to specific algorithms (e.g.
Attention assigns different parts of an input varying importance without set kernels. The importance of different
components is designated using "soft" weights. These weights are assigned according to specific algorithms (e.g.
scaled-dot-product attention).
Uniform masking refers to the characteristic that all queries use the same mask. This restriction allows the query
dimension to be removed from the mask. All queries have access to the same key/value pairs.
"""

def __init__(self):
@@ -27,18 +29,18 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor = None,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Calculates the attention scores.

Args:
query: query tensor; shape (batch_size, target_seq_length, hidden_dim)
key: key tensor; shape (batch_size, source_seq_length, hidden_dim)
value: value tensor; shape (batch_size, source_seq_length, hidden_dim)
query: query tensor; shape (batch_size, target_seq_length, hidden_dim).
key: key tensor; shape (batch_size, source_seq_length, hidden_dim).
value: value tensor; shape (batch_size, source_seq_length, hidden_dim).
attn_mask: tensor indicating which values are used to calculate the output;
shape (batch_size, target_seq_length, source_seq_length)
shape (batch_size, source_seq_length).

Returns:
tensor containing the outputs of the attention implementation;
shape (batch_size, target_seq_length, hidden_dim)
shape (batch_size, target_seq_length, hidden_dim).
"""
86 changes: 86 additions & 0 deletions src/continuiti/networks/heterogeneous_normalized_attention.py
@@ -0,0 +1,86 @@
"""
`continuiti.networks.heterogeneous_normalized_attention`

Heterogeneous normalized attention block introduced by Hao et al. (https://proceedings.mlr.press/v202/hao23c).
"""

import torch
import torch.nn as nn
from torch.nn.functional import softmax
from typing import Optional

from .attention import UniformMaskAttention


class HeterogeneousNormalizedAttention(UniformMaskAttention):
r"""Heterogeneous normalized attention with uniform masks.

Computes the normalization coefficient alpha for attention mechanisms, as proposed by Hao et al. in "GNOT: A
General Neural Operator Transformer for Operator Learning" (https://proceedings.mlr.press/v202/hao23c). The
attention score is calculated by normalizing the keys and queries
$$\tilde{q}_{i,j} = \frac{\exp(q_{i,j})}{\sum_j \exp(q_{i,j})}, \qquad \tilde{k}_{i,j} = \frac{\exp(k_{i,j})}{\sum_j \exp(k_{i,j})},$$
and then calculating the attention without softmax using
$$z_t = \sum_i \frac{\tilde{q}_t \cdot \tilde{k}_i}{\sum_j \tilde{q}_t \cdot \tilde{k}_j} \cdot v_i.$$
Comment on lines +21 to +23
Collaborator
Math does not render well in docs

The computational cost for this is O((M+N)n_e^2) (M = number of keys/values, N = number of queries, n_e = embedding dimension),
which is linear with respect to the sequence length.

Args:
tau: Temperature parameter controlling the sharpness of the softmax operation.
dropout_p: Dropout probability applied to the product of the normalized keys and the values.
"""

def __init__(self, tau: float = 1.0, dropout_p: float = 0.0):
super().__init__()
self.tau = tau
self.dropout = nn.Dropout(p=dropout_p)

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""Forward pass.
Args:
query: Tensor of shape (batch_size, ..., d_q, embedding_dim).
key: Tensor of shape (batch_size, ..., d_kv, embedding_dim).
value: Tensor of shape (batch_size, ..., d_kv, embedding_dim).
attn_mask: Attention mask of shape (batch_size, ..., d_kv). A boolean mask where a True indicates that
a value should be taken into consideration in the calculations.
Returns:
Attention output of shape (batch_size, ..., d_q, embedding_dim).
"""
assert (
query.ndim == key.ndim
), "Number of dimensions in queries and keys should match."
assert (
query.ndim == value.ndim
), "Number of dimensions in queries and values should match."

if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(-1)
key = torch.masked_fill(key, attn_mask.logical_not(), float("-inf"))
value = torch.masked_fill(value, attn_mask.logical_not(), 0.0)

q_tilde = softmax(query, dim=-1)
if attn_mask is not None:
q_tilde = torch.nan_to_num(
q_tilde, nan=0.0
) # masking might ignore queries entirely resulting in nan in softmax

k_tilde = softmax(key / self.tau, dim=-1)
if attn_mask is not None:
k_tilde = torch.nan_to_num(
k_tilde, nan=0.0
) # masking might ignore keys entirely resulting in nan in softmax

alpha = torch.matmul(q_tilde, k_tilde.transpose(-1, -2))
alpha = torch.sum(alpha, dim=-1, keepdim=True)
if attn_mask is not None:
alpha[alpha == 0.0] = 1.0 # numerical stability

mat = k_tilde * value
mat = self.dropout(mat)
mat = torch.sum(mat, dim=-2, keepdim=True)

return q_tilde * mat / alpha
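A minimal usage sketch for the new class (not part of the diff; the import path assumes the `__init__.py` change above, and the tensor shapes follow the docstring):

import torch
from continuiti.networks import HeterogeneousNormalizedAttention

attn = HeterogeneousNormalizedAttention(tau=1.0, dropout_p=0.0)

query = torch.rand(3, 5, 11)  # (batch_size, d_q, embedding_dim)
key = torch.rand(3, 7, 11)    # (batch_size, d_kv, embedding_dim)
value = torch.rand(3, 7, 11)  # (batch_size, d_kv, embedding_dim)
attn_mask = torch.ones(3, 7, dtype=torch.bool)  # uniform mask over key/value pairs

out = attn(query, key, value, attn_mask)
print(out.shape)  # torch.Size([3, 5, 11])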
34 changes: 18 additions & 16 deletions src/continuiti/networks/multi_head_attention.py
@@ -6,19 +6,21 @@

import torch
import torch.nn as nn
from typing import Optional

from .attention import Attention
from .attention import UniformMaskAttention
from .scaled_dot_product_attention import ScaledDotProductAttention


class MultiHeadAttention(Attention):
r"""Multi-Head Attention module.
class MultiHeadAttention(UniformMaskAttention):
r"""Multi-Head Attention module with uniform mask.

Module as described in the paper [Attention is All you
Need](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
with optional bias for the projections. This implementation allows to use
attention implementations other than the standard scaled dot product
attention implemented by the MultiheadAttention PyTorch module.
with optional bias for the projections. This implementation allows the use of attention implementations other than
the standard scaled dot product attention implemented by the MultiheadAttention PyTorch module. Additionally, this
implementation assumes that the attention mask is applied uniformly for every batch (the mask for every key-value
pair matches for all queries of one sample).

$$MultiHead(Q,K,V)=Concat(head_1,\dots,head_n)W^O + b^O$$

Expand All @@ -39,7 +41,7 @@ def __init__(
self,
hidden_dim: int,
n_heads: int,
attention: Attention = None,
attention: Optional[UniformMaskAttention] = None,
dropout_p: float = 0,
bias: bool = True,
):
@@ -70,15 +72,15 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor = None,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""Compute the attention scores.

Args:
query: Query tensor of shape (batch_size, target_sequence_length, hidden_dim).
key: Key tensor of shape (batch_size, source_sequence_length, hidden_dim).
value: Value tensor of shape (batch_size, source_sequence_length, hidden_dim).
attn_mask: Attention mask of shape (batch_size, target_sequence_length, source_sequence_length).
attn_mask: Attention mask of shape (batch_size, source_sequence_length).

Returns:
Attention scores of shape (batch_size, target_sequence_length, hidden_dim).
@@ -98,7 +100,6 @@ def forward(

batch_size = query.size(0)
src_len = key.size(1)
tgt_len = query.size(1)

# project values
query = self.query_project(query)
@@ -112,18 +113,19 @@

# reshape attention mask to match heads
if attn_mask is not None:
assert (
attn_mask.ndim == 2
), "Expects exatly 2 dimensions in the mask, but found {attn_mask.ndim}."
assert (
attn_mask.size(0) == batch_size
), "Attention mask batch size does not match input tensors."
assert (
attn_mask.size(1) == tgt_len
), "First dimension of the attention mask needs to match target length."
assert (
attn_mask.size(2) == src_len
attn_mask.size(1) == src_len
), "Second dimension of the attention mask needs to match source length."

attn_mask = attn_mask.unsqueeze(1) # mask for a single head
attn_mask = attn_mask.repeat(1, self.n_heads, 1, 1) # mask for every head
attn_mask = attn_mask.unsqueeze(
1
) # apply attention mask uniformly to all heads

# perform attention
attn_out = self.attention(
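The head handling above relies on broadcasting rather than repeating the mask per head. A small sketch (not part of the diff; hypothetical sizes) of how a uniform `(batch_size, source_seq_length)` mask reaches a per-head score tensor:

import torch

batch_size, n_heads, tgt_len, src_len = 2, 4, 5, 7
uniform_mask = torch.rand(batch_size, src_len) > 0.3  # True = usable key/value pair

# Singleton head and query dimensions let the same mask broadcast to every head and every query position.
broadcast_mask = uniform_mask[:, None, None, :]  # (batch_size, 1, 1, src_len)
scores = torch.rand(batch_size, n_heads, tgt_len, src_len)
masked_scores = scores.masked_fill(~broadcast_mask, float("-inf"))
print(masked_scores.shape)  # torch.Size([2, 4, 5, 7])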
26 changes: 22 additions & 4 deletions src/continuiti/networks/scaled_dot_product_attention.py
@@ -4,13 +4,14 @@
Scaled dot product attention module.
"""
import torch
from typing import Optional

from .attention import Attention
from .attention import UniformMaskAttention
from torch.nn.functional import scaled_dot_product_attention


class ScaledDotProductAttention(Attention):
"""Scaled dot product attention module.
class ScaledDotProductAttention(UniformMaskAttention):
"""Scaled dot product attention module with uniform mask.

This module is a wrapper for the torch implementation of the scaled dot
product attention mechanism as described in the paper "Attention Is All You
@@ -29,9 +30,26 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor = None,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Calculate attention scores.

Args:
query: query tensor; shape (batch_size, target_seq_length, hidden_dim).
key: key tensor; shape (batch_size, source_seq_length, hidden_dim).
value: value tensor; shape (batch_size, source_seq_length, hidden_dim).
attn_mask: tensor indicating which values are used to calculate the output; shape
(batch_size, source_seq_length). Defaults to None.

Returns:
tensor containing the outputs of the attention implementation; shape
(batch_size, target_seq_length, hidden_dim).
"""
dropout_p = self.dropout_p if self.training else 0.0

if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(1)

return scaled_dot_product_attention(
query=query,
key=key,
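As a rough usage sketch (not part of the diff), the unsqueeze above matches how `torch.nn.functional.scaled_dot_product_attention` broadcasts a boolean mask over the query dimension:

import torch
from torch.nn.functional import scaled_dot_product_attention

batch_size, tgt_len, src_len, hidden_dim = 2, 4, 6, 8
query = torch.rand(batch_size, tgt_len, hidden_dim)
key = torch.rand(batch_size, src_len, hidden_dim)
value = torch.rand(batch_size, src_len, hidden_dim)

uniform_mask = torch.ones(batch_size, src_len, dtype=torch.bool)  # True = attend to this key/value pair
out = scaled_dot_product_attention(
    query, key, value, attn_mask=uniform_mask.unsqueeze(1)  # (batch_size, 1, src_len) broadcasts over queries
)
print(out.shape)  # torch.Size([2, 4, 8])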
89 changes: 89 additions & 0 deletions tests/networks/test_heterogeneous_normalized_attention.py
@@ -0,0 +1,89 @@
import pytest
import torch
from torch.nn.functional import scaled_dot_product_attention

from continuiti.networks import HeterogeneousNormalizedAttention


@pytest.fixture(scope="module")
def random_query_key_value_pair():
batch_size = 3
query_size = 5
key_val_size = 7
hidden_dim = 11

query = torch.rand(batch_size, query_size, hidden_dim)
key = torch.rand(batch_size, key_val_size, hidden_dim)
value = torch.rand(batch_size, key_val_size, hidden_dim)

return query, key, value


class TestHeterogeneousNormalized:
def test_can_initialize(self):
_ = HeterogeneousNormalizedAttention()
assert True

def test_shape_correct(self, random_query_key_value_pair):
query, key, value = random_query_key_value_pair

attn = HeterogeneousNormalizedAttention()

out = attn(query, key, value)
gt_out = scaled_dot_product_attention(query, key, value)

assert out.shape == gt_out.shape

def test_gradient_flow(self, random_query_key_value_pair):
query, key, value = random_query_key_value_pair
query.requires_grad = True
key.requires_grad = True
value.requires_grad = True

attn = HeterogeneousNormalizedAttention()

out = attn(query, key, value)

out.sum().backward()

assert query.grad is not None, "Gradients not flowing to query"
assert key.grad is not None, "Gradients not flowing to key"
assert value.grad is not None, "Gradients not flowing to value"

def test_zero_input(self, random_query_key_value_pair):
query, key, value = random_query_key_value_pair
attn = HeterogeneousNormalizedAttention()
out = attn(query, key, torch.zeros(value.shape))
assert torch.allclose(torch.zeros(out.shape), out)

def test_mask_forward(self, random_query_key_value_pair):
query, key, value = random_query_key_value_pair
attn = HeterogeneousNormalizedAttention()

# masks in the operator setting should always be block tensors with the upper left block of the last two
# dimensions being True. The dimensions of the True block correspond to the numbers of sensors and evaluations.
mask = torch.rand(query.size(0), key.size(1)) >= 0.2

out = attn(query, key, value, mask)

assert isinstance(out, torch.Tensor)

def test_mask_correct(self, random_query_key_value_pair):
query, key, value = random_query_key_value_pair
attn = HeterogeneousNormalizedAttention()

out_gt = attn(query, key, value)

key_rand = torch.rand(key.shape)
key_masked = torch.cat([key, key_rand], dim=1)

value_rand = torch.rand(value.shape)
value_masked = torch.cat([value, value_rand], dim=1)

true_mask = torch.ones(value.size(0), value.size(1), dtype=torch.bool)
attn_mask = torch.cat([true_mask, ~true_mask], dim=1)

out_masked = attn(query, key_masked, value_masked, attn_mask)

assert torch.allclose(out_gt, out_masked)