From 0e8855c0943cbd5975da41c8f74e107ef24ccf93 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Fri, 17 Sep 2021 12:10:14 +0100 Subject: [PATCH 01/18] Add basic LearnableQuery module --- perceiver_pytorch/queries.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 perceiver_pytorch/queries.py diff --git a/perceiver_pytorch/queries.py b/perceiver_pytorch/queries.py new file mode 100644 index 0000000..1848104 --- /dev/null +++ b/perceiver_pytorch/queries.py @@ -0,0 +1,30 @@ +import torch +from torch.distributions import uniform +from typing import List, Union, Tuple +from math import prod + + +class LearnableQuery(torch.nn.Module): + """ + Module that constructs a learnable query of query_shape for the Perceiver + """ + def __init__(self, query_dim: int, query_shape: Union[Tuple[int], List[int]]): + """ + Learnable Query with some inbuilt randomness to help with ensembling + + Args: + query_dim: Query dimension + query_shape: The final shape of the query, generally, the (T, H, W) of the output + """ + super().__init__() + self.query_shape = prod(query_shape) # Flatten the shape + self.query_dim = query_dim + self.learnable_query = torch.nn.Linear(self.query_dim, self.query_dim) + self.distribution = uniform.Uniform(low=torch.Tensor([0.0]), high=torch.Tensor([1.0])) + + def forward(self, x: torch.Tensor): + z = self.distribution.sample( + (x.shape[0], self.query_future_size, self.query_dim) + ).type_as(x) + queries = self.learnable_query(z) + return queries From 4f6d6483c18abfefff69d817805ac81a95fe7ec7 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Fri, 17 Sep 2021 12:18:27 +0100 Subject: [PATCH 02/18] Add TODO --- perceiver_pytorch/queries.py | 1 + 1 file changed, 1 insertion(+) diff --git a/perceiver_pytorch/queries.py b/perceiver_pytorch/queries.py index 1848104..d5733a4 100644 --- a/perceiver_pytorch/queries.py +++ b/perceiver_pytorch/queries.py @@ -27,4 +27,5 @@ def forward(self, x: torch.Tensor): (x.shape[0], self.query_future_size, self.query_dim) ).type_as(x) queries = self.learnable_query(z) + # TODO: Add Fourier Features return queries From e2781cf826fcb7e75246b7f95f9bcaaf538fcd8e Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Fri, 17 Sep 2021 12:19:00 +0100 Subject: [PATCH 03/18] Add pre-commit config --- .pre-commit-config.yaml | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..d7fbade --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,27 @@ +default_language_version: + python: python3.9 + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.4.0 + hooks: + # list of supported hooks: https://pre-commit.com/hooks.html + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: debug-statements + - id: detect-private-key + + # python code formatting + - repo: https://github.com/psf/black + rev: 20.8b1 + hooks: + - id: black + args: [--line-length, "100"] + + # yaml formatting + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v2.3.0 + hooks: + - id: prettier + types: [yaml] From 323115d3ffc1fd73dff38838d3711898458e3244 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Fri, 17 Sep 2021 13:15:08 +0100 Subject: [PATCH 04/18] Add creating Fourier Features --- perceiver_pytorch/queries.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/perceiver_pytorch/queries.py b/perceiver_pytorch/queries.py index d5733a4..5cde79c 100644 --- a/perceiver_pytorch/queries.py +++ b/perceiver_pytorch/queries.py @@ -2,13 +2,23 @@ from torch.distributions import uniform from typing import List, Union, Tuple from math import prod +from perceiver_pytorch.utils import encode_position class LearnableQuery(torch.nn.Module): """ Module that constructs a learnable query of query_shape for the Perceiver """ - def __init__(self, query_dim: int, query_shape: Union[Tuple[int], List[int]]): + + def __init__( + self, + query_dim: int, + query_shape: Union[Tuple[int], List[int]], + max_frequency: float = 16.0, + num_frequency_bands: int = 64, + frequency_base: float = 2.0, + sine_only: bool = False, + ): """ Learnable Query with some inbuilt randomness to help with ensembling @@ -17,15 +27,25 @@ def __init__(self, query_dim: int, query_shape: Union[Tuple[int], List[int]]): query_shape: The final shape of the query, generally, the (T, H, W) of the output """ super().__init__() - self.query_shape = prod(query_shape) # Flatten the shape + self.query_shape = prod(query_shape) # Flatten the shape + # Need to get Fourier Features once and then just append to the output + self.fourier_features = encode_position( + 1, + axis=query_shape, + max_frequency=max_frequency, + frequency_base=frequency_base, + num_frequency_bands=num_frequency_bands, + sine_only=sine_only, + ) self.query_dim = query_dim self.learnable_query = torch.nn.Linear(self.query_dim, self.query_dim) self.distribution = uniform.Uniform(low=torch.Tensor([0.0]), high=torch.Tensor([1.0])) def forward(self, x: torch.Tensor): - z = self.distribution.sample( - (x.shape[0], self.query_future_size, self.query_dim) - ).type_as(x) + z = self.distribution.sample((x.shape[0], self.query_future_size, self.query_dim)).type_as( + x + ) + # Do 3D or 2D CNN to keep same spatial size, concat, then linearize queries = self.learnable_query(z) - # TODO: Add Fourier Features + # TODO Add Fourier Features return queries From f6ce7b4c5b8c4fe03dfa73c918dcccda27df3ce5 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Fri, 17 Sep 2021 16:54:20 +0100 Subject: [PATCH 05/18] Add test, conv layer encoding Keep the latent space as the 4D tensor and then linearize at the end, the same as what the Perceiver models do --- perceiver_pytorch/queries.py | 46 ++++++++++++++++++++++++++++++------ tests/test_queries.py | 20 ++++++++++++++++ 2 files changed, 59 insertions(+), 7 deletions(-) create mode 100644 tests/test_queries.py diff --git a/perceiver_pytorch/queries.py b/perceiver_pytorch/queries.py index 5cde79c..244f178 100644 --- a/perceiver_pytorch/queries.py +++ b/perceiver_pytorch/queries.py @@ -3,6 +3,7 @@ from typing import List, Union, Tuple from math import prod from perceiver_pytorch.utils import encode_position +import einops class LearnableQuery(torch.nn.Module): @@ -14,6 +15,7 @@ def __init__( self, query_dim: int, query_shape: Union[Tuple[int], List[int]], + conv_layer: str = "3d", max_frequency: float = 16.0, num_frequency_bands: int = 64, frequency_base: float = 2.0, @@ -27,7 +29,7 @@ def __init__( query_shape: The final shape of the query, generally, the (T, H, W) of the output """ super().__init__() - self.query_shape = prod(query_shape) # Flatten the shape + self.query_shape = query_shape # Need to get Fourier Features once and then just append to the output self.fourier_features = encode_position( 1, @@ -37,15 +39,45 @@ def __init__( num_frequency_bands=num_frequency_bands, sine_only=sine_only, ) + print(self.fourier_features.shape) self.query_dim = query_dim - self.learnable_query = torch.nn.Linear(self.query_dim, self.query_dim) + if conv_layer == "3d": + conv = torch.nn.Conv3d + elif conv_layer == "2d": + conv = torch.nn.Conv2d + else: + raise ValueError(f"Value for 'layer' is {layer} which is not one of '3d', '2d'") + self.conv_layer = conv_layer + self.layer = conv(in_channels=query_dim, out_channels=query_dim, kernel_size=3, padding=1) + # Linear layer to compress channels down to query_dim size? + self.fc = torch.nn.Linear(self.query_dim, self.query_dim) self.distribution = uniform.Uniform(low=torch.Tensor([0.0]), high=torch.Tensor([1.0])) def forward(self, x: torch.Tensor): - z = self.distribution.sample((x.shape[0], self.query_future_size, self.query_dim)).type_as( + print(f"Batch: {x.shape[0]} Query: {self.query_shape} Dim: {self.query_dim}") + z = self.distribution.sample((x.shape[0], self.query_dim, *self.query_shape)).type_as( x - ) + ) # [B, Query, T, H, W, 1] or [B, Query, H, W, 1] + z = torch.squeeze(z, dim=-1) # Extra 1 for some reason + print(f"Z: {z.shape}") # Do 3D or 2D CNN to keep same spatial size, concat, then linearize - queries = self.learnable_query(z) - # TODO Add Fourier Features - return queries + if self.conv_layer == "2d": + # Iterate through time dimension + outs = [] + for i in range(x.shape[1]): + outs.append(self.layer(z[:, i, :, :, :])) + query = torch.stack(outs, dim=1) + else: + query = self.layer(z) + # Add Fourier Features + ff = einops.repeat( + self.fourier_features, "b ... -> (repeat b) ...", repeat=x.shape[0] + ) # Match batches + print(ff.shape) + query = torch.cat([query, ff], dim=-1) + print(query.shape) + # concat to channels of data and flatten axis + query = einops.rearrange(query, "b ... d -> b (...) d") + # Need query to end with query_dim channels + + return query diff --git a/tests/test_queries.py b/tests/test_queries.py new file mode 100644 index 0000000..1cf40ad --- /dev/null +++ b/tests/test_queries.py @@ -0,0 +1,20 @@ +import pytest +import torch +from perceiver_pytorch.queries import LearnableQuery + + +@pytest.mark.parametrize("layer_shape", ["2d", "3d"]) +def test_learnable_query(layer_shape): + query_creator = LearnableQuery( + query_dim=32, + query_shape=(24, 128, 128), + conv_layer=layer_shape, + max_frequency=64.0, + frequency_base=2.0, + num_frequency_bands=128, + sine_only=False, + ) + x = torch.randn((16, 24, 12, 128, 128)) + out = query_creator(x) + + pass From 213f87ce1cf99d2836b4daf91e4de6d249fa6db6 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 22 Sep 2021 09:56:30 +0100 Subject: [PATCH 06/18] Fix ordering of channels Change prints to Debug statements --- perceiver_pytorch/queries.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/perceiver_pytorch/queries.py b/perceiver_pytorch/queries.py index 244f178..661b503 100644 --- a/perceiver_pytorch/queries.py +++ b/perceiver_pytorch/queries.py @@ -1,9 +1,12 @@ import torch from torch.distributions import uniform from typing import List, Union, Tuple -from math import prod from perceiver_pytorch.utils import encode_position import einops +import logging + +_LOG = logging.getLogger("perceiver.queries") +_LOG.setLevel(logging.WARN) class LearnableQuery(torch.nn.Module): @@ -39,14 +42,14 @@ def __init__( num_frequency_bands=num_frequency_bands, sine_only=sine_only, ) - print(self.fourier_features.shape) + _LOG.debug(self.fourier_features.shape) self.query_dim = query_dim if conv_layer == "3d": conv = torch.nn.Conv3d elif conv_layer == "2d": conv = torch.nn.Conv2d else: - raise ValueError(f"Value for 'layer' is {layer} which is not one of '3d', '2d'") + raise ValueError(f"Value for 'layer' is {conv_layer} which is not one of '3d', '2d'") self.conv_layer = conv_layer self.layer = conv(in_channels=query_dim, out_channels=query_dim, kernel_size=3, padding=1) # Linear layer to compress channels down to query_dim size? @@ -54,28 +57,31 @@ def __init__( self.distribution = uniform.Uniform(low=torch.Tensor([0.0]), high=torch.Tensor([1.0])) def forward(self, x: torch.Tensor): - print(f"Batch: {x.shape[0]} Query: {self.query_shape} Dim: {self.query_dim}") + _LOG.debug(f"Batch: {x.shape[0]} Query: {self.query_shape} Dim: {self.query_dim}") z = self.distribution.sample((x.shape[0], self.query_dim, *self.query_shape)).type_as( x ) # [B, Query, T, H, W, 1] or [B, Query, H, W, 1] z = torch.squeeze(z, dim=-1) # Extra 1 for some reason - print(f"Z: {z.shape}") + _LOG.debug(f"Z: {z.shape}") # Do 3D or 2D CNN to keep same spatial size, concat, then linearize if self.conv_layer == "2d": # Iterate through time dimension outs = [] for i in range(x.shape[1]): - outs.append(self.layer(z[:, i, :, :, :])) - query = torch.stack(outs, dim=1) + outs.append(self.layer(z[:, :, i, :, :])) + query = torch.stack(outs, dim=2) else: query = self.layer(z) # Add Fourier Features ff = einops.repeat( self.fourier_features, "b ... -> (repeat b) ...", repeat=x.shape[0] ) # Match batches - print(ff.shape) + # Move channels to correct location + query = einops.rearrange(query, "b c ... -> b ... c") + _LOG.debug(f"Fourier: {ff.shape}") + _LOG.debug(f"Query: {query.shape}") query = torch.cat([query, ff], dim=-1) - print(query.shape) + _LOG.debug(query.shape) # concat to channels of data and flatten axis query = einops.rearrange(query, "b ... d -> b (...) d") # Need query to end with query_dim channels From 012ac31863ad94f10986d7d9bf9c842266483b54 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 22 Sep 2021 11:14:00 +0100 Subject: [PATCH 07/18] Add getting output shape --- perceiver_pytorch/queries.py | 45 ++++++++++++++++++++++++++++-------- tests/test_queries.py | 2 +- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/perceiver_pytorch/queries.py b/perceiver_pytorch/queries.py index 661b503..985c825 100644 --- a/perceiver_pytorch/queries.py +++ b/perceiver_pytorch/queries.py @@ -2,6 +2,7 @@ from torch.distributions import uniform from typing import List, Union, Tuple from perceiver_pytorch.utils import encode_position +from math import prod import einops import logging @@ -16,7 +17,7 @@ class LearnableQuery(torch.nn.Module): def __init__( self, - query_dim: int, + channel_dim: int, query_shape: Union[Tuple[int], List[int]], conv_layer: str = "3d", max_frequency: float = 16.0, @@ -28,10 +29,10 @@ def __init__( Learnable Query with some inbuilt randomness to help with ensembling Args: - query_dim: Query dimension + channel_dim: Channel dimension for the output of the network query_shape: The final shape of the query, generally, the (T, H, W) of the output """ - super().__init__() + super(LearnableQuery, self).__init__() self.query_shape = query_shape # Need to get Fourier Features once and then just append to the output self.fourier_features = encode_position( @@ -43,7 +44,7 @@ def __init__( sine_only=sine_only, ) _LOG.debug(self.fourier_features.shape) - self.query_dim = query_dim + self.channel_dim = channel_dim if conv_layer == "3d": conv = torch.nn.Conv3d elif conv_layer == "2d": @@ -51,14 +52,39 @@ def __init__( else: raise ValueError(f"Value for 'layer' is {conv_layer} which is not one of '3d', '2d'") self.conv_layer = conv_layer - self.layer = conv(in_channels=query_dim, out_channels=query_dim, kernel_size=3, padding=1) + self.layer = conv( + in_channels=channel_dim, out_channels=channel_dim, kernel_size=3, padding=1 + ) # Linear layer to compress channels down to query_dim size? - self.fc = torch.nn.Linear(self.query_dim, self.query_dim) + self.fc = torch.nn.Linear(self.channel_dim, self.channel_dim) self.distribution = uniform.Uniform(low=torch.Tensor([0.0]), high=torch.Tensor([1.0])) - def forward(self, x: torch.Tensor): - _LOG.debug(f"Batch: {x.shape[0]} Query: {self.query_shape} Dim: {self.query_dim}") - z = self.distribution.sample((x.shape[0], self.query_dim, *self.query_shape)).type_as( + def output_shape(self) -> Tuple[int, int]: + """ + Gives the output shape from the query, useful for setting the correct + query_dim in the Perceiver + + Returns: + The shape of the resulting query, excluding the batch size + """ + + # The shape is the query_dim + Fourier Feature channels + channels = self.fourier_features.shape[-1] + self.channel_dim + return prod(self.query_shape), channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Samples the uniform distribution and creates the query by passing the + sample through the model and appending Fourier features + + Args: + x: The input tensor to the model, used to batch the batch size + + Returns: + Torch tensor used to query the output of the PerceiverIO model + """ + _LOG.debug(f"Batch: {x.shape[0]} Query: {self.query_shape} Dim: {self.channel_dim}") + z = self.distribution.sample((x.shape[0], self.channel_dim, *self.query_shape)).type_as( x ) # [B, Query, T, H, W, 1] or [B, Query, H, W, 1] z = torch.squeeze(z, dim=-1) # Extra 1 for some reason @@ -84,6 +110,5 @@ def forward(self, x: torch.Tensor): _LOG.debug(query.shape) # concat to channels of data and flatten axis query = einops.rearrange(query, "b ... d -> b (...) d") - # Need query to end with query_dim channels return query diff --git a/tests/test_queries.py b/tests/test_queries.py index 1cf40ad..0f7a6c9 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -6,7 +6,7 @@ @pytest.mark.parametrize("layer_shape", ["2d", "3d"]) def test_learnable_query(layer_shape): query_creator = LearnableQuery( - query_dim=32, + channel_dim=32, query_shape=(24, 128, 128), conv_layer=layer_shape, max_frequency=64.0, From a4fb689296f22db750a061a17a13050ba7453c88 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 22 Sep 2021 11:28:59 +0100 Subject: [PATCH 08/18] Add test to ensure query works with PerceiverIO --- tests/test_queries.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/tests/test_queries.py b/tests/test_queries.py index 0f7a6c9..2b84505 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -1,6 +1,8 @@ import pytest import torch from perceiver_pytorch.queries import LearnableQuery +from perceiver_pytorch.perceiver_io import PerceiverIO +import einops @pytest.mark.parametrize("layer_shape", ["2d", "3d"]) @@ -16,5 +18,36 @@ def test_learnable_query(layer_shape): ) x = torch.randn((16, 24, 12, 128, 128)) out = query_creator(x) + assert out.shape == (16, 393216, 803) - pass + +@pytest.mark.parametrize("layer_shape", ["2d", "3d"]) +def test_learnable_query_qpplication(layer_shape): + output_shape = (24, 128, 128) + query_creator = LearnableQuery( + channel_dim=32, + query_shape=output_shape, + conv_layer=layer_shape, + max_frequency=64.0, + frequency_base=2.0, + num_frequency_bands=128, + sine_only=False, + ) + with torch.no_grad(): + query_creator.eval() + x = torch.randn((4, 24, 12, 128, 128)) + out = query_creator(x) + + model = PerceiverIO(depth=6, dim=100, queries_dim=query_creator.output_shape()[-1]) + model.eval() + model_input = torch.randn((4, 256, 100)) + model_out = model(model_input, queries=out) + # Reshape back to correct shape + model_out = einops.rearrange( + model_out, + "b (t h w) c -> b t c h w", + t=output_shape[0], + h=output_shape[1], + w=output_shape[2], + ) + assert model_out.shape == (4, 24, 803, 128, 128) From d4329bc50bbd0a125c6d07a0b59c06e6c8feab2d Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 22 Sep 2021 11:32:22 +0100 Subject: [PATCH 09/18] Reduce image size for tests --- tests/test_queries.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_queries.py b/tests/test_queries.py index 2b84505..2efc6bb 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -23,7 +23,7 @@ def test_learnable_query(layer_shape): @pytest.mark.parametrize("layer_shape", ["2d", "3d"]) def test_learnable_query_qpplication(layer_shape): - output_shape = (24, 128, 128) + output_shape = (24, 32, 32) query_creator = LearnableQuery( channel_dim=32, query_shape=output_shape, @@ -35,7 +35,7 @@ def test_learnable_query_qpplication(layer_shape): ) with torch.no_grad(): query_creator.eval() - x = torch.randn((4, 24, 12, 128, 128)) + x = torch.randn((4, 24, 12, 32, 32)) out = query_creator(x) model = PerceiverIO(depth=6, dim=100, queries_dim=query_creator.output_shape()[-1]) @@ -50,4 +50,4 @@ def test_learnable_query_qpplication(layer_shape): h=output_shape[1], w=output_shape[2], ) - assert model_out.shape == (4, 24, 803, 128, 128) + assert model_out.shape == (4, 24, 803, 32, 32) From 72815bd948bb500cb3e9814083b2fa576759d72e Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 22 Sep 2021 11:35:19 +0100 Subject: [PATCH 10/18] Reduce depth and size more --- tests/test_queries.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_queries.py b/tests/test_queries.py index 2efc6bb..306a481 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -23,7 +23,7 @@ def test_learnable_query(layer_shape): @pytest.mark.parametrize("layer_shape", ["2d", "3d"]) def test_learnable_query_qpplication(layer_shape): - output_shape = (24, 32, 32) + output_shape = (6, 32, 32) query_creator = LearnableQuery( channel_dim=32, query_shape=output_shape, @@ -35,10 +35,10 @@ def test_learnable_query_qpplication(layer_shape): ) with torch.no_grad(): query_creator.eval() - x = torch.randn((4, 24, 12, 32, 32)) + x = torch.randn((4, 6, 12, 32, 32)) out = query_creator(x) - model = PerceiverIO(depth=6, dim=100, queries_dim=query_creator.output_shape()[-1]) + model = PerceiverIO(depth=2, dim=100, queries_dim=query_creator.output_shape()[-1]) model.eval() model_input = torch.randn((4, 256, 100)) model_out = model(model_input, queries=out) @@ -50,4 +50,4 @@ def test_learnable_query_qpplication(layer_shape): h=output_shape[1], w=output_shape[2], ) - assert model_out.shape == (4, 24, 803, 32, 32) + assert model_out.shape == (4, 6, 803, 32, 32) From 1188bae501c28eb203598100d83d300a9419fc90 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 22 Sep 2021 11:38:21 +0100 Subject: [PATCH 11/18] Reduce number of Fourier Features --- tests/test_queries.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_queries.py b/tests/test_queries.py index 306a481..50b4367 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -23,24 +23,24 @@ def test_learnable_query(layer_shape): @pytest.mark.parametrize("layer_shape", ["2d", "3d"]) def test_learnable_query_qpplication(layer_shape): - output_shape = (6, 32, 32) + output_shape = (6, 16, 16) query_creator = LearnableQuery( channel_dim=32, query_shape=output_shape, conv_layer=layer_shape, max_frequency=64.0, frequency_base=2.0, - num_frequency_bands=128, + num_frequency_bands=32, sine_only=False, ) with torch.no_grad(): query_creator.eval() - x = torch.randn((4, 6, 12, 32, 32)) + x = torch.randn((2, 6, 12, 16, 16)) out = query_creator(x) model = PerceiverIO(depth=2, dim=100, queries_dim=query_creator.output_shape()[-1]) model.eval() - model_input = torch.randn((4, 256, 100)) + model_input = torch.randn((2, 256, 100)) model_out = model(model_input, queries=out) # Reshape back to correct shape model_out = einops.rearrange( @@ -50,4 +50,4 @@ def test_learnable_query_qpplication(layer_shape): h=output_shape[1], w=output_shape[2], ) - assert model_out.shape == (4, 6, 803, 32, 32) + assert model_out.shape == (2, 6, 227, 16, 16) From ff42bf31cd9f2a2afcd98091eed74001cb0e758b Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 22 Sep 2021 11:39:10 +0100 Subject: [PATCH 12/18] Rename check --- .github/workflows/python-app.yml | 45 ++++++++++++++++---------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 2375b2a..a5a828c 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -1,33 +1,32 @@ # This workflow will install Python dependencies, run tests and lint with a single version of Python # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions -name: Python application +name: Python Tests -on: [push, pull_request] +on: [push] jobs: build: - runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - name: Set up Python 3.9 - uses: actions/setup-python@v2 - with: - python-version: 3.9 - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install flake8 pytest pytest-cov - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - pip install -e . - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - pytest -s --cov=perceiver_pytorch + - uses: actions/checkout@v2 + - name: Set up Python 3.9 + uses: actions/setup-python@v2 + with: + python-version: 3.9 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install flake8 pytest pytest-cov + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + pip install -e . + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + pytest -s --cov=perceiver_pytorch From 071914f885fb911f607a8a326bc5863143bd8466 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 22 Sep 2021 11:42:23 +0100 Subject: [PATCH 13/18] Reduce size --- tests/test_queries.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_queries.py b/tests/test_queries.py index 50b4367..dd09bc2 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -9,14 +9,14 @@ def test_learnable_query(layer_shape): query_creator = LearnableQuery( channel_dim=32, - query_shape=(24, 128, 128), + query_shape=(6, 16, 16), conv_layer=layer_shape, max_frequency=64.0, frequency_base=2.0, num_frequency_bands=128, sine_only=False, ) - x = torch.randn((16, 24, 12, 128, 128)) + x = torch.randn((16, 6, 12, 16, 16)) out = query_creator(x) assert out.shape == (16, 393216, 803) From 320ff1482c6479a9a10d3d513207fcab3403fb84 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 22 Sep 2021 11:43:14 +0100 Subject: [PATCH 14/18] Reduce size --- tests/test_queries.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_queries.py b/tests/test_queries.py index dd09bc2..b2aec48 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -16,9 +16,9 @@ def test_learnable_query(layer_shape): num_frequency_bands=128, sine_only=False, ) - x = torch.randn((16, 6, 12, 16, 16)) + x = torch.randn((4, 6, 12, 16, 16)) out = query_creator(x) - assert out.shape == (16, 393216, 803) + assert out.shape == (4, 1536, 803) @pytest.mark.parametrize("layer_shape", ["2d", "3d"]) From 497e54bdb8ad0445688e7b9f29535a936969ed74 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 22 Sep 2021 15:14:15 +0100 Subject: [PATCH 15/18] Address PR comments --- perceiver_pytorch/queries.py | 8 ++------ tests/test_queries.py | 3 ++- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/perceiver_pytorch/queries.py b/perceiver_pytorch/queries.py index 985c825..fd0d153 100644 --- a/perceiver_pytorch/queries.py +++ b/perceiver_pytorch/queries.py @@ -36,14 +36,13 @@ def __init__( self.query_shape = query_shape # Need to get Fourier Features once and then just append to the output self.fourier_features = encode_position( - 1, + 1, # Batch size, 1 for this as it will be adapted in forward axis=query_shape, max_frequency=max_frequency, frequency_base=frequency_base, num_frequency_bands=num_frequency_bands, sine_only=sine_only, ) - _LOG.debug(self.fourier_features.shape) self.channel_dim = channel_dim if conv_layer == "3d": conv = torch.nn.Conv3d @@ -104,11 +103,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) # Match batches # Move channels to correct location query = einops.rearrange(query, "b c ... -> b ... c") - _LOG.debug(f"Fourier: {ff.shape}") - _LOG.debug(f"Query: {query.shape}") query = torch.cat([query, ff], dim=-1) - _LOG.debug(query.shape) # concat to channels of data and flatten axis query = einops.rearrange(query, "b ... d -> b (...) d") - + _LOG.debug(f"Final Query Shape: {query.shape}") return query diff --git a/tests/test_queries.py b/tests/test_queries.py index b2aec48..87a2957 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -18,7 +18,8 @@ def test_learnable_query(layer_shape): ) x = torch.randn((4, 6, 12, 16, 16)) out = query_creator(x) - assert out.shape == (4, 1536, 803) + # Output is flattened, so should be [B, T*H*W, C] + assert out.shape == (4, 16 * 16 * 6, 803) @pytest.mark.parametrize("layer_shape", ["2d", "3d"]) From 818c8c5b4162af3c68d53cc0b385e09c977610c8 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 22 Sep 2021 15:33:51 +0100 Subject: [PATCH 16/18] Add another comment --- tests/test_queries.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_queries.py b/tests/test_queries.py index 87a2957..5926536 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -19,6 +19,8 @@ def test_learnable_query(layer_shape): x = torch.randn((4, 6, 12, 16, 16)) out = query_creator(x) # Output is flattened, so should be [B, T*H*W, C] + # Channels is from channel_dim + 3*(num_frequency_bands * 2 + 1) + # 32 + 3*(257) = 771 + 32 = 803 assert out.shape == (4, 16 * 16 * 6, 803) From f0909b269f658251f493b346c30fc012ed34aa33 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Tue, 28 Sep 2021 15:43:28 +0100 Subject: [PATCH 17/18] Add support for single timestep query shape --- perceiver_pytorch/queries.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/perceiver_pytorch/queries.py b/perceiver_pytorch/queries.py index fd0d153..429a917 100644 --- a/perceiver_pytorch/queries.py +++ b/perceiver_pytorch/queries.py @@ -44,7 +44,9 @@ def __init__( sine_only=sine_only, ) self.channel_dim = channel_dim - if conv_layer == "3d": + if ( + conv_layer == "3d" and len(self.query_shape) == 3 + ): # If Query shape is for an image, then 3D conv won't work conv = torch.nn.Conv3d elif conv_layer == "2d": conv = torch.nn.Conv2d @@ -89,7 +91,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: z = torch.squeeze(z, dim=-1) # Extra 1 for some reason _LOG.debug(f"Z: {z.shape}") # Do 3D or 2D CNN to keep same spatial size, concat, then linearize - if self.conv_layer == "2d": + if self.conv_layer == "2d" and len(self.query_shape) == 3: # Iterate through time dimension outs = [] for i in range(x.shape[1]): From 6b11e0c323cbb550003c88433408d0f3e1ac7ab6 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Tue, 28 Sep 2021 15:46:36 +0100 Subject: [PATCH 18/18] Update version --- setup.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 43f9fe8..3f77d32 100644 --- a/setup.py +++ b/setup.py @@ -1,14 +1,15 @@ from setuptools import setup, find_packages from pathlib import Path + this_directory = Path(__file__).parent -install_requires = (this_directory / 'requirements.txt').read_text().splitlines() +install_requires = (this_directory / "requirements.txt").read_text().splitlines() long_description = (this_directory / "README.md").read_text() setup( name="perceiver-model", packages=find_packages(), - version="0.7.1", + version="0.7.2", license="MIT", description="Multimodal Perceiver - Pytorch", author="Jacob Bieker, Jack Kelly, Peter Dudfield", @@ -22,7 +23,7 @@ "attention mechanism", ], long_description=long_description, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", install_requires=install_requires, classifiers=[ "Development Status :: 4 - Beta",