From 1023c124493def9a122c1de792c76d14efe73691 Mon Sep 17 00:00:00 2001
From: Jack Kelly
Date: Tue, 27 Jul 2021 16:51:36 +0100
Subject: [PATCH 01/11] Added lots more comments!

---
 perceiver_pytorch/perceiver_pytorch.py | 78 +++++++++++++++++++++-----
 1 file changed, 65 insertions(+), 13 deletions(-)

diff --git a/perceiver_pytorch/perceiver_pytorch.py b/perceiver_pytorch/perceiver_pytorch.py
index e876b37..46d9f93 100644
--- a/perceiver_pytorch/perceiver_pytorch.py
+++ b/perceiver_pytorch/perceiver_pytorch.py
@@ -1,5 +1,6 @@
 from math import pi, log
 from functools import wraps
+from typing import Optional
 
 import torch
 from torch import nn, einsum
@@ -37,14 +38,22 @@ def cached_fn(*args, _cache=True, **kwargs):
 
 
 def fourier_encode(x, max_freq, num_bands=4, base=2):
+    """Concatenate Fourier position features onto x.
+
+    Args:
+        x: Input data.
+        max_freq: Maximum frequency.
+        num_bands: Number of frequency bands to concatenate.
+        base: Base of the logarithm function.
+    """
     x = x.unsqueeze(-1)
     device, dtype, orig_x = x.device, x.dtype, x
 
     scales = torch.logspace(
-        0.0,
-        log(max_freq / 2) / log(base),
-        num_bands,
-        base=base,
+        start=0.0,
+        end=log(max_freq / 2) / log(base),
+        steps=num_bands,  # Size of the 'scales' tensor.
+        base=base,  # Base of the log function.
         device=device,
         dtype=dtype,
     )
@@ -80,13 +89,25 @@ def forward(self, x, **kwargs):
 
 
 class GEGLU(nn.Module):
+    """Gaussian Error Gated Linear Unit.
+
+    See Shazeer 2020: https://arxiv.org/abs/2002.05202
+    """
     def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)
 
 
 class FeedForward(nn.Module):
-    def __init__(self, dim, mult=4, dropout=0.0):
+    """Feed forward neural net with GEGLU activation."""
+
+    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0):
+        """
+        Args:
+          dim: Input & Output size.
+          mult: The inner dimension of the FF net will be dim * mult.
+          dropout: Proportion to dropout after the GEGLU.
+        """
         super().__init__()
         self.net = nn.Sequential(
             nn.Linear(dim, dim * mult * 2),
@@ -101,8 +122,21 @@ def forward(self, x):
 
 class Attention(nn.Module):
     def __init__(
-        self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0
-    ):
+        self,
+        query_dim: int,
+        context_dim: Optional[int] = None,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0):
+        """
+
+        :param query_dim: Size of the queries.
+        :param context_dim: Size of the 'context' (the 'byte array' in the paper).
+            If None, will default to the query_dim.
+        :param heads: Number of attention heads.
+        :param dim_head: Number of dimensions per head.
+        :param dropout: Proportion to dropout (in the final linear layer).
+        """
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -110,22 +144,36 @@ def __init__(
         self.scale = dim_head ** -0.5
         self.heads = heads
 
+        # Network to generate queries ('q').
         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+
+        # Network to generate keys and values ('k' and 'v').
+        # Uses inner_dim * 2 out_features because the output is
+        # split in two in the forward() function.
         self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
 
         self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
-        )
+            nn.Linear(inner_dim, query_dim),
+            nn.Dropout(dropout))
 
     def forward(self, x, context=None, mask=None, pos_emb=None):
+        """
+        Args:
+            x: The 'latent array' in the Perceiver paper.
+            context: The 'byte array' in the Perceiver paper (the input data).
+        """
         h = self.heads
 
-        q = self.to_q(x)
+        q = self.to_q(x)  # Generate query.
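+        # q has shape (batch, num_query_positions, inner_dim),
+        # where inner_dim = heads * dim_head.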
         context = default(context, x)
         k, v = self.to_kv(context).chunk(2, dim=-1)
 
+        # Rearrange the query, key and value tensors.
+        # b = batch size; n =
+        # h = number of heads; d = number of dims per head.
         q, k, v = map(
-            lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)
+            lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h),
+            (q, k, v)
         )
 
         if exists(pos_emb):
@@ -182,7 +230,8 @@ def __init__(
             depth: Depth of net.
             max_freq: Maximum frequency, hyperparameter depending on how
                 fine the data is.
-            freq_base:
+            freq_base: Base of the logarithm function for Fourier position
+                encoding.
             input_channels: Number of channels for each token of the input.
             input_axis: Number of axes for input data (2 for images, 3 for video)
             num_latents: Number of latents, or induced set points, or centroids.
@@ -216,6 +265,7 @@ def __init__(
         )
         input_dim = fourier_channels + input_channels
 
+        # Randomly initialise the 'latent array'.
         self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
 
         def get_cross_attn():
@@ -314,9 +364,11 @@ def forward(self, data, mask=None):
             data = torch.cat((data, enc_pos), dim=-1)
 
         # Concat to channels of data and flatten axis.
-        data = rearrange(data, "b ... d -> b (...) d")
+        # b = batch size; d = last dimension of data.
+        data = rearrange(data, "b ... d -> b (...) d", b=b)
 
         # x is the 'latent array' in the paper.
+        # b = batch size; n = number of latents; d = latent dimensions.
         x = repeat(self.latents, "n d -> b n d", b=b)
 
         # Rotary embeddings for latents, if specified.

From 3b633d668cec0e09419af7774d920cc67106862b Mon Sep 17 00:00:00 2001
From: Jack Kelly
Date: Tue, 27 Jul 2021 17:34:48 +0100
Subject: [PATCH 02/11] First draft of sequential 'Perceiver RNN' #1

---
 perceiver_pytorch/perceiver_pytorch.py | 63 +++++++++++++++++++++-----
 1 file changed, 51 insertions(+), 12 deletions(-)

diff --git a/perceiver_pytorch/perceiver_pytorch.py b/perceiver_pytorch/perceiver_pytorch.py
index 46d9f93..9cecb21 100644
--- a/perceiver_pytorch/perceiver_pytorch.py
+++ b/perceiver_pytorch/perceiver_pytorch.py
@@ -220,7 +220,8 @@ def __init__(
         weight_tie_layers=False,
         fourier_encode_data=True,
         self_per_cross_attn=1,
-        self_attn_rel_pos=True
+        self_attn_rel_pos=True,
+        sequential=False
     ):
         """The shape of the final attention mechanism will be:
         depth * (cross attention -> self_per_cross_attn * self attention)
@@ -250,6 +251,11 @@ def __init__(
                 if you are fourier encoding the data yourself.
             self_per_cross_attn: Number of self attention blocks per cross attn.
             self_attn_rel_pos:
+            sequential: If True, use the Perceiver like a recurrent neural
+                network: Each cross attention gets a different timestep of
+                the input 'byte array' or 'context', and the output of forward()
+                contains the latent array for every timestep.
+                depth must be set to the sequence length.
         """
         super().__init__()
         self.input_axis = input_axis
@@ -257,6 +263,11 @@ def __init__(
         self.num_freq_bands = num_freq_bands
         self.freq_base = freq_base
 
+        if sequential and not weight_tie_layers:
+            raise Warning(
+                '`sequential` is True but `weight_tie_layers` is False. '
+                'Are you sure you want different weights for each timestep?')
+
         self.fourier_encode_data = fourier_encode_data
         fourier_channels = (
             (input_axis * ((num_freq_bands * 2) + 1))
@@ -330,16 +341,29 @@ def get_latent_ff():
                     ]
                 ))
 
-        self.to_logits = nn.Sequential(
-            nn.LayerNorm(latent_dim),
-            nn.Linear(latent_dim, num_classes))
+        if not sequential:
+            self.to_logits = nn.Sequential(
+                nn.LayerNorm(latent_dim),
+                nn.Linear(latent_dim, num_classes))
 
         self.sinu_emb = None
         if self_attn_rel_pos:
             self.sinu_emb = SinusoidalEmbeddings(latent_dim_head)
 
     def forward(self, data, mask=None):
-        b, *axis, _ = data.shape
+        """
+        Args:
+            data: If sequential is True, then data must be of shape:
+                (batch size, sequence length, *axes) where axes would be width
+                and height for images.
+        """
+        if self.sequential:
+            b, seq_length, *axis, _ = data.shape
+            assert (
+                seq_length == self.depth
+            ), 'The 2nd dim of `data` must be equal to the network `depth`.'
+        else:
+            b, *axis, _ = data.shape
         device = data.device
 
         assert (
@@ -363,9 +387,15 @@ def forward(self, data, mask=None):
             data = torch.cat((data, enc_pos), dim=-1)
 
-        # Concat to channels of data and flatten axis.
-        # b = batch size; d = last dimension of data.
-        data = rearrange(data, "b ... d -> b (...) d", b=b)
+        # Concat to channels of data and flatten axes.
+        # b = batch size; d = last dimension of data
+        if self.sequential:
+            data = rearrange(
+                data,
+                "b s ... d -> b s (...) d",
+                b=b, s=seq_length)
+        else:
+            data = rearrange(data, "b ... d -> b (...) d", b=b)
 
         # x is the 'latent array' in the paper.
         # b = batch size; n = number of latents; d = latent dimensions.
         x = repeat(self.latents, "n d -> b n d", b=b)
@@ -375,13 +405,22 @@ def forward(self, data, mask=None):
         pos_emb = self.sinu_emb(x) if exists(self.sinu_emb) else None
 
         # Layers.
-        for cross_attn, cross_ff, self_attns in self.layers:
-            x = cross_attn(x, context=data, mask=mask) + x
+        output_per_timestep = []
+        for i, (cross_attn, cross_ff, self_attns) in enumerate(self.layers):
+            if self.sequential:
+                data_for_step = data[:, i]
+            else:
+                data_for_step = data
+            x = cross_attn(x, context=data_for_step, mask=mask) + x
             x = cross_ff(x) + x
 
             for self_attn, self_ff in self_attns:
                 x = self_attn(x, pos_emb=pos_emb) + x
                 x = self_ff(x) + x
+            output_per_timestep.append(x)
 
-        x = x.mean(dim=-2)
-        return self.to_logits(x)
+        if self.sequential:
+            return torch.stack(output_per_timestep, dim=1)
+        else:
+            x = x.mean(dim=-2)
+            return self.to_logits(x)

From 0ba6bc2461f29464ed2b48c6e39528edf58db243 Mon Sep 17 00:00:00 2001
From: peterdudfield
Date: Mon, 13 Sep 2021 15:20:33 +0100
Subject: [PATCH 03/11] fix

---
 perceiver_pytorch/perceiver_pytorch.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/perceiver_pytorch/perceiver_pytorch.py b/perceiver_pytorch/perceiver_pytorch.py
index 7230f01..9878de9 100644
--- a/perceiver_pytorch/perceiver_pytorch.py
+++ b/perceiver_pytorch/perceiver_pytorch.py
@@ -94,6 +94,7 @@ def __init__(
 
         # Randomly initialise the 'latent array'.
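+        # Shape (num_latents, latent_dim); broadcast across the batch in forward().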
         self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
+        self.sequential = sequential
 
         def get_cross_attn():
             return PreNorm(

From bb7d820487d1c34514377804b7ba3247cf87eccd Mon Sep 17 00:00:00 2001
From: peterdudfield
Date: Mon, 13 Sep 2021 15:25:54 +0100
Subject: [PATCH 04/11] copy comments

---
 perceiver_pytorch/layers.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/perceiver_pytorch/layers.py b/perceiver_pytorch/layers.py
index 3cb4b43..904fa69 100644
--- a/perceiver_pytorch/layers.py
+++ b/perceiver_pytorch/layers.py
@@ -53,6 +53,10 @@ def forward(self, x, **kwargs):
 
 
 class GEGLU(nn.Module):
+    """
+    Gaussian Error Gated Linear Unit.
+    See Shazeer 2020: https://arxiv.org/abs/2002.05202
+    """
     def forward(self, x):
         x, gates = x.chunk(2, dim=-1)
         return x * F.gelu(gates)
@@ -126,6 +130,9 @@ def forward(self, x, context=None, mask=None, pos_emb=None):
         context = default(context, x)
         k, v = self.to_kv(context).chunk(2, dim=-1)
 
+        # Rearrange the query, key and value tensors.
+        # b = batch size; n =
+        # h = number of heads; d = number of dims per head.
         q, k, v = map(
             lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)
         )

From bb8e3726f4492d6a18612893434031fa89671637 Mon Sep 17 00:00:00 2001
From: peterdudfield
Date: Mon, 13 Sep 2021 15:35:56 +0100
Subject: [PATCH 05/11] remove first draft of rnn model

---
 perceiver_pytorch/perceiver_pytorch.py | 37 ++++++--------------------
 1 file changed, 8 insertions(+), 29 deletions(-)

diff --git a/perceiver_pytorch/perceiver_pytorch.py b/perceiver_pytorch/perceiver_pytorch.py
index 9878de9..35d0abd 100644
--- a/perceiver_pytorch/perceiver_pytorch.py
+++ b/perceiver_pytorch/perceiver_pytorch.py
@@ -32,7 +32,6 @@ def __init__(
         sine_only: bool = False,
         self_per_cross_attn=1,
         self_attn_rel_pos=True,
-        sequential=False
     ):
         """
         Perceiver: https://arxiv.org/abs/2103.03206
@@ -66,11 +65,6 @@ def __init__(
             sine_only: Use only sine encoding in fourier encoding, compared to using sine and cos
             self_per_cross_attn: Number of self attention blocks per cross attn.
             self_attn_rel_pos:
-            sequential: If True, use the Perceiver like a recurrent neural
-                network: Each cross attention gets a different timestep of
-                the input 'byte array' or 'context', and the output of forward()
-                contains the latent array for every timestep.
-                depth must be set to the sequence length.
         """
         super().__init__()
         self.input_axis = input_axis
@@ -78,11 +72,6 @@ def __init__(
         self.num_freq_bands = num_freq_bands
         self.freq_base = freq_base
 
-        if sequential and not weight_tie_layers:
-            raise Warning(
-                '`sequential` is True but `weight_tie_layers` is False. '
-                'Are you sure you want different weights for each timestep?')
-
         self.fourier_encode_data = fourier_encode_data
         fourier_channels = (
             (input_axis * ((num_freq_bands * 2) + 1))
@@ -94,7 +83,6 @@ def __init__(
 
         # Randomly initialise the 'latent array'.
         self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
-        self.sequential = sequential
 
         def get_cross_attn():
             return PreNorm(
@@ -158,10 +146,9 @@ def get_latent_ff():
                     ]
                 ))
 
-        if not sequential:
-            self.to_logits = nn.Sequential(
-                nn.LayerNorm(latent_dim),
-                nn.Linear(latent_dim, num_classes))
+        self.to_logits = nn.Sequential(
+            nn.LayerNorm(latent_dim),
+            nn.Linear(latent_dim, num_classes))
 
         self.sinu_emb = None
         if self_attn_rel_pos:
@@ -174,13 +161,8 @@ def forward(self, data, mask=None):
                 (batch size, sequence length, *axes) where axes would be width
                 and height for images.
""" - if self.sequential: - b, seq_length, *axis, _ = data.shape - assert ( - seq_length == self.depth - ), 'The 2nd dim of `data` must be equal to the network `depth`.' - else: - b, *axis, _ = data.shape + + b, *axis, _ = data.shape device = data.device assert ( @@ -205,7 +187,7 @@ def forward(self, data, mask=None): data = rearrange( data, "b, s ... d -> b, s (...) d", - b=b, s=seq_length) + b=b) else: data = rearrange(data, "b ... d -> b (...) d", b=b) @@ -231,8 +213,5 @@ def forward(self, data, mask=None): x = self_ff(x) + x output_per_timestep.append(x) - if self.sequential: - return torch.stack(output_per_timestep, dim=1) - else: - x = x.mean(dim=-2) - return self.to_logits(x) + x = x.mean(dim=-2) + return self.to_logits(x) From 7910bb88593bc56d95bd1f19e49806fc617bda47 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Mon, 13 Sep 2021 15:37:47 +0100 Subject: [PATCH 06/11] remove extra code --- perceiver_pytorch/perceiver_pytorch.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/perceiver_pytorch/perceiver_pytorch.py b/perceiver_pytorch/perceiver_pytorch.py index 35d0abd..0f0ced2 100644 --- a/perceiver_pytorch/perceiver_pytorch.py +++ b/perceiver_pytorch/perceiver_pytorch.py @@ -199,19 +199,8 @@ def forward(self, data, mask=None): pos_emb = self.sinu_emb(x) if exists(self.sinu_emb) else None # Layers. - output_per_timestep = [] - for i, (cross_attn, cross_ff, self_attns) in enumerate(self.layers): - if self.sequential: - data_for_step = data[:, i] - else: - data_for_step = data - x = cross_attn(x, context=data_for_step, mask=mask) + x - x = cross_ff(x) + x - - for self_attn, self_ff in self_attns: - x = self_attn(x, pos_emb=pos_emb) + x - x = self_ff(x) + x - output_per_timestep.append(x) + for cross_attn, cross_ff, self_attns in self.layers: + x = cross_attn(x, context=data, mask=mask) + x x = x.mean(dim=-2) return self.to_logits(x) From 8c572c1bdcb922fcde56321f416dd345414e0d3d Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Mon, 13 Sep 2021 15:40:11 +0100 Subject: [PATCH 07/11] add code bk in --- perceiver_pytorch/perceiver_pytorch.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/perceiver_pytorch/perceiver_pytorch.py b/perceiver_pytorch/perceiver_pytorch.py index 0f0ced2..f6713a4 100644 --- a/perceiver_pytorch/perceiver_pytorch.py +++ b/perceiver_pytorch/perceiver_pytorch.py @@ -183,13 +183,7 @@ def forward(self, data, mask=None): # Concat to channels of data and flatten axes. # b = batch size; d = last dimension of data - if self.sequential: - data = rearrange( - data, - "b, s ... d -> b, s (...) d", - b=b) - else: - data = rearrange(data, "b ... d -> b (...) d", b=b) + data = rearrange(data, "b ... d -> b (...) d", b=b) # x is the 'latent array' in the paper. # b = batch size; n = number of latents; d = latent dimensions. @@ -201,6 +195,11 @@ def forward(self, data, mask=None): # Layers. 
         for cross_attn, cross_ff, self_attns in self.layers:
             x = cross_attn(x, context=data, mask=mask) + x
+            x = cross_ff(x) + x
+
+            for self_attn, self_ff in self_attns:
+                x = self_attn(x, pos_emb=pos_emb) + x
+                x = self_ff(x) + x
 
         x = x.mean(dim=-2)
         return self.to_logits(x)

From 548d6b0745343973cca89a45e16392e6090023af Mon Sep 17 00:00:00 2001
From: peterdudfield
Date: Mon, 13 Sep 2021 16:50:43 +0100
Subject: [PATCH 08/11] PR comment

---
 perceiver_pytorch/layers.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/perceiver_pytorch/layers.py b/perceiver_pytorch/layers.py
index 904fa69..4969539 100644
--- a/perceiver_pytorch/layers.py
+++ b/perceiver_pytorch/layers.py
@@ -68,9 +68,9 @@ class FeedForward(nn.Module):
     def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0):
         """
         Args:
-          dim: Input & Output size.
-          mult: The inner dimension of the FF net will be dim * mult.
-          dropout: Proportion to dropout after the GEGLU.
+            dim: Input & Output size.
+            mult: The inner dimension of the FF net will be dim * mult.
+            dropout: Proportion to dropout after the GEGLU.
         """
         super().__init__()
         self.net = nn.Sequential(
@@ -89,14 +89,15 @@ def __init__(
         self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0
     ):
         """
-
-        :param query_dim: Size of the queries.
-        :param context_dim: Size of the 'context' (the 'byte array' in the paper).
-            If None, will default to the query_dim.
-        :param heads: Number of attention heads.
-        :param dim_head: Number of dimensions per head.
-        :param dropout: Proportion to dropout (in the final linear layer).
+        Args:
+            query_dim: Size of the queries.
+            context_dim: Size of the 'context' (the 'byte array' in the paper).
+            If None, will default to the query_dim.
+            heads: Number of attention heads.
+            dim_head: Number of dimensions per head.
+            dropout: Proportion to dropout (in the final linear layer).
         """
+
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

From d8cbf8ac80f3d796862637d5d9148916898bd5e2 Mon Sep 17 00:00:00 2001
From: Peter Dudfield <34686298+peterdudfield@users.noreply.github.com>
Date: Mon, 13 Sep 2021 20:59:05 +0100
Subject: [PATCH 09/11] Update perceiver_pytorch/layers.py

Co-authored-by: Jack Kelly
---
 perceiver_pytorch/layers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/perceiver_pytorch/layers.py b/perceiver_pytorch/layers.py
index 4969539..40e7aa5 100644
--- a/perceiver_pytorch/layers.py
+++ b/perceiver_pytorch/layers.py
@@ -92,7 +92,7 @@ def __init__(
         Args:
             query_dim: Size of the queries.
             context_dim: Size of the 'context' (the 'byte array' in the paper).
-            If None, will default to the query_dim.
+                If None, will default to the query_dim.
             heads: Number of attention heads.
             dim_head: Number of dimensions per head.
             dropout: Proportion to dropout (in the final linear layer).

From dc031e38c94060e5ad602234bb15e50f3010d024 Mon Sep 17 00:00:00 2001
From: peterdudfield
Date: Mon, 13 Sep 2021 21:01:59 +0100
Subject: [PATCH 10/11] blacks, and PR comment

---
 perceiver_pytorch/perceiver_pytorch.py | 50 ++++++++++++--------------
 1 file changed, 23 insertions(+), 27 deletions(-)

diff --git a/perceiver_pytorch/perceiver_pytorch.py b/perceiver_pytorch/perceiver_pytorch.py
index f6713a4..a35a8d1 100644
--- a/perceiver_pytorch/perceiver_pytorch.py
+++ b/perceiver_pytorch/perceiver_pytorch.py
@@ -8,6 +8,7 @@
 
 # main class
 
+
 class Perceiver(nn.Module):
     def __init__(
         self,
@@ -45,7 +46,6 @@ def __init__(
                 fine the data is.
             freq_base: Base of the logarithm function for Fourier position
freq_base: Base of the logarithm function for Fourier position encoding. - input_channels: Number of channels for each token of the input. input_axis: Number of axes for input data (2 for images, 3 for video) num_latents: Number of latents, or induced set points, or centroids. @@ -73,11 +73,7 @@ def __init__( self.freq_base = freq_base self.fourier_encode_data = fourier_encode_data - fourier_channels = ( - (input_axis * ((num_freq_bands * 2) + 1)) - if fourier_encode_data - else 0 - ) + fourier_channels = (input_axis * ((num_freq_bands * 2) + 1)) if fourier_encode_data else 0 self.sine_only = sine_only input_dim = fourier_channels + input_channels @@ -94,12 +90,11 @@ def get_cross_attn(): dim_head=cross_dim_head, dropout=attn_dropout, ), - context_dim=input_dim) + context_dim=input_dim, + ) def get_cross_ff(): - return PreNorm( - latent_dim, - FeedForward(latent_dim, dropout=ff_dropout)) + return PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout)) def get_latent_attn(): return PreNorm( @@ -109,17 +104,16 @@ def get_latent_attn(): heads=latent_heads, dim_head=latent_dim_head, dropout=attn_dropout, - )) + ), + ) def get_latent_ff(): - return PreNorm( - latent_dim, - FeedForward(latent_dim, dropout=ff_dropout)) + return PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout)) # Cache all the above functions. get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map( - cache_fn, - (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff)) + cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff) + ) self.layers = nn.ModuleList([]) for i in range(depth): @@ -135,7 +129,8 @@ def get_latent_ff(): get_latent_attn(**cache_args), get_latent_ff(**cache_args), ] - )) + ) + ) self.layers.append( nn.ModuleList( @@ -144,11 +139,10 @@ def get_latent_ff(): get_cross_ff(**cache_args), self_attns, ] - )) + ) + ) - self.to_logits = nn.Sequential( - nn.LayerNorm(latent_dim), - nn.Linear(latent_dim, num_classes)) + self.to_logits = nn.Sequential(nn.LayerNorm(latent_dim), nn.Linear(latent_dim, num_classes)) self.sinu_emb = None if self_attn_rel_pos: @@ -172,12 +166,14 @@ def forward(self, data, mask=None): if self.fourier_encode_data: # Calculate Fourier encoded positions in the range of [-1, 1], # for all axes. - enc_pos = encode_position(b, - axis, - self.max_freq, - self.num_freq_bands, - self.freq_base, - sine_only=self.sine_only).type_as(data) + enc_pos = encode_position( + b, + axis, + self.max_freq, + self.num_freq_bands, + self.freq_base, + sine_only=self.sine_only, + ).type_as(data) data = torch.cat((data, enc_pos), dim=-1) From 3378e016151f16fad1d8759576aed6dc9b081eb7 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Mon, 13 Sep 2021 21:03:04 +0100 Subject: [PATCH 11/11] add TODO --- perceiver_pytorch/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/perceiver_pytorch/layers.py b/perceiver_pytorch/layers.py index 40e7aa5..7229d20 100644 --- a/perceiver_pytorch/layers.py +++ b/perceiver_pytorch/layers.py @@ -132,7 +132,7 @@ def forward(self, x, context=None, mask=None, pos_emb=None): k, v = self.to_kv(context).chunk(2, dim=-1) # Rearrange the query, key and value tensors. - # b = batch size; n = + # b = batch size; n = TODO (PD-2021-09-13) # h = number of heads; d = number of dims per head. q, k, v = map( lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)