Skip to content

Commit

Permalink
Merge pull request lucidrains#15 from openclimatefix/rnn-comments-only
Browse files Browse the repository at this point in the history
Rnn comments only
  • Loading branch information
peterdudfield authored Sep 13, 2021
2 parents f9bf15f + 3378e01 commit 4540966
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 33 deletions.
39 changes: 38 additions & 1 deletion perceiver_pytorch/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,25 @@ def forward(self, x, **kwargs):


class GEGLU(nn.Module):
"""
Gaussian Error Gated Linear Unit.
See Shazer 2020: https://arxiv.org/abs/2002.05202
"""
def forward(self, x):
x, gates = x.chunk(2, dim=-1)
return x * F.gelu(gates)


class FeedForward(nn.Module):
def __init__(self, dim, mult=4, dropout=0.0):
"""Feed forward neural net with GEGLU activation."""

def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0):
"""
Args:
dim: Input & Output size.
mult: The inner dimension of the FF net will be dim * mult.
dropout: Proportion to dropout after the GEGLU.
"""
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * mult * 2),
Expand All @@ -76,6 +88,16 @@ class Attention(nn.Module):
def __init__(
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0
):
"""
Args:
query_dim: Size of the queries.
context_dim: Size of the 'context' (the 'byte array' in the paper).
If None, will default to the query_dim.
heads: Number of attention heads.
dim_head: Number of dimensions per head.
dropout: Proportion to dropout (in the final linear layer).
"""

super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
Expand All @@ -91,12 +113,27 @@ def __init__(
)

def forward(self, x, context=None, mask=None, pos_emb=None):
"""
Args:
x: The 'latent array' in the Perceiver paper.
context: The 'byte array' in the Perceiver paper (the input data).
mask:
pos_emb:
Returns:
"""

h = self.heads

q = self.to_q(x)
context = default(context, x)
k, v = self.to_kv(context).chunk(2, dim=-1)

# Rearrange the query, key and value tensors.
# b = batch size; n = TODO (PD-2021-09-13)
# h = number of heads; d = number of dims per head.
q, k, v = map(
lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)
)
Expand Down
74 changes: 42 additions & 32 deletions perceiver_pytorch/perceiver_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import torch
from torch import nn

from einops import rearrange, repeat
from torch import nn

from perceiver_pytorch.layers import exists, cache_fn, PreNorm, FeedForward, Attention
from perceiver_pytorch.rotary import SinusoidalEmbeddings
from perceiver_pytorch.utils import encode_position

# main class


class Perceiver(nn.Module):
def __init__(
self,
Expand All @@ -32,7 +32,7 @@ def __init__(
fourier_encode_data=True,
sine_only: bool = False,
self_per_cross_attn=1,
self_attn_rel_pos=True
self_attn_rel_pos=True,
):
"""
Perceiver: https://arxiv.org/abs/2103.03206
Expand All @@ -44,7 +44,8 @@ def __init__(
depth: Depth of net.
max_freq: Maximum frequency, hyperparameter depending on how
fine the data is.
freq_base: Base for the frequency
freq_base: Base of the logarithm function for Fourier position
encoding.
input_channels: Number of channels for each token of the input.
input_axis: Number of axes for input data (2 for images, 3 for video)
num_latents: Number of latents, or induced set points, or centroids.
Expand All @@ -63,6 +64,7 @@ def __init__(
if you are fourier encoding the data yourself.
sine_only: Use only sine encoding in fourier encoding, compared to using sine and cos
self_per_cross_attn: Number of self attention blocks per cross attn.
self_attn_rel_pos:
"""
super().__init__()
self.input_axis = input_axis
Expand All @@ -71,14 +73,11 @@ def __init__(
self.freq_base = freq_base

self.fourier_encode_data = fourier_encode_data
fourier_channels = (
(input_axis * ((num_freq_bands * 2) + 1))
if fourier_encode_data
else 0
)
fourier_channels = (input_axis * ((num_freq_bands * 2) + 1)) if fourier_encode_data else 0
self.sine_only = sine_only
input_dim = fourier_channels + input_channels

# Randomly initialise the 'latent array'.
self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))

def get_cross_attn():
Expand All @@ -91,12 +90,11 @@ def get_cross_attn():
dim_head=cross_dim_head,
dropout=attn_dropout,
),
context_dim=input_dim)
context_dim=input_dim,
)

def get_cross_ff():
return PreNorm(
latent_dim,
FeedForward(latent_dim, dropout=ff_dropout))
return PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))

def get_latent_attn():
return PreNorm(
Expand All @@ -106,17 +104,16 @@ def get_latent_attn():
heads=latent_heads,
dim_head=latent_dim_head,
dropout=attn_dropout,
))
),
)

def get_latent_ff():
return PreNorm(
latent_dim,
FeedForward(latent_dim, dropout=ff_dropout))
return PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))

# Cache all the above functions.
get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(
cache_fn,
(get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff))
cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff)
)

self.layers = nn.ModuleList([])
for i in range(depth):
Expand All @@ -132,7 +129,8 @@ def get_latent_ff():
get_latent_attn(**cache_args),
get_latent_ff(**cache_args),
]
))
)
)

self.layers.append(
nn.ModuleList(
Expand All @@ -141,38 +139,50 @@ def get_latent_ff():
get_cross_ff(**cache_args),
self_attns,
]
))
)
)

self.to_logits = nn.Sequential(
nn.LayerNorm(latent_dim),
nn.Linear(latent_dim, num_classes))
self.to_logits = nn.Sequential(nn.LayerNorm(latent_dim), nn.Linear(latent_dim, num_classes))

self.sinu_emb = None
if self_attn_rel_pos:
self.sinu_emb = SinusoidalEmbeddings(latent_dim_head)

def forward(self, data, mask=None):
"""
Args:
data: If sequential is True, then data must be of shape:
(batch size, sequence length, *axes) where axes would be width
and height for images.
"""

b, *axis, _ = data.shape
device = data.device

assert (
len(axis) == self.input_axis
), f"Input data must have {self.input_axis} axes, not {len(axis)}!"

if self.fourier_encode_data:
# Calculate Fourier encoded positions in the range of [-1, 1],
# for all axes.
enc_pos = encode_position(b,
axis,
self.max_freq,
self.num_freq_bands,
self.freq_base,
sine_only=self.sine_only).type_as(data)
enc_pos = encode_position(
b,
axis,
self.max_freq,
self.num_freq_bands,
self.freq_base,
sine_only=self.sine_only,
).type_as(data)

data = torch.cat((data, enc_pos), dim=-1)

# Concat to channels of data and flatten axis.
data = rearrange(data, "b ... d -> b (...) d")
# Concat to channels of data and flatten axes.
# b = batch size; d = last dimension of data
data = rearrange(data, "b ... d -> b (...) d", b=b)

# x is the 'latent array' in the paper.
# b = batch size; n = number of latents; d = latent dimensions.
x = repeat(self.latents, "n d -> b n d", b=b)

# Rotary embeddings for latents, if specified.
Expand Down

0 comments on commit 4540966

Please sign in to comment.