diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml
index 2375b2a..a5a828c 100644
--- a/.github/workflows/python-app.yml
+++ b/.github/workflows/python-app.yml
@@ -1,33 +1,32 @@
 # This workflow will install Python dependencies, run tests and lint with a single version of Python
 # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

-name: Python application
+name: Python Tests

-on: [push, pull_request]
+on: [push]

 jobs:
   build:
-
     runs-on: ubuntu-latest

     steps:
-    - uses: actions/checkout@v2
-    - name: Set up Python 3.9
-      uses: actions/setup-python@v2
-      with:
-        python-version: 3.9
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        pip install flake8 pytest pytest-cov
-        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
-        pip install -e .
-    - name: Lint with flake8
-      run: |
-        # stop the build if there are Python syntax errors or undefined names
-        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
-    - name: Test with pytest
-      run: |
-        pytest -s --cov=perceiver_pytorch
+      - uses: actions/checkout@v2
+      - name: Set up Python 3.9
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.9
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install flake8 pytest pytest-cov
+          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+          pip install -e .
+      - name: Lint with flake8
+        run: |
+          # stop the build if there are Python syntax errors or undefined names
+          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+      - name: Test with pytest
+        run: |
+          pytest -s --cov=perceiver_pytorch
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..d7fbade
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,27 @@
+default_language_version:
+  python: python3.9
+
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v3.4.0
+    hooks:
+      # list of supported hooks: https://pre-commit.com/hooks.html
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+      - id: debug-statements
+      - id: detect-private-key
+
+  # python code formatting
+  - repo: https://github.com/psf/black
+    rev: 20.8b1
+    hooks:
+      - id: black
+        args: [--line-length, "100"]
+
+  # yaml formatting
+  - repo: https://github.com/pre-commit/mirrors-prettier
+    rev: v2.3.0
+    hooks:
+      - id: prettier
+        types: [yaml]
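As an aside, the black hook above pins formatting to a 100-character line length. A rough illustration of the same setting through black's Python API (a hypothetical snippet, not part of this PR; the API may vary between black versions, and the hook itself runs through pre-commit, not this code):

```python
# Hypothetical illustration of the --line-length 100 setting via black's Python API.
import black

mode = black.Mode(line_length=100)  # mirrors args: [--line-length, "100"] in the hook
print(black.format_str("def f( a,b ):\n    return {  'a':a,'b':b}\n", mode=mode))
```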
diff --git a/perceiver_pytorch/queries.py b/perceiver_pytorch/queries.py
new file mode 100644
index 0000000..429a917
--- /dev/null
+++ b/perceiver_pytorch/queries.py
@@ -0,0 +1,112 @@
+import torch
+from torch.distributions import uniform
+from typing import List, Union, Tuple
+from perceiver_pytorch.utils import encode_position
+from math import prod
+import einops
+import logging
+
+_LOG = logging.getLogger("perceiver.queries")
+_LOG.setLevel(logging.WARN)
+
+
+class LearnableQuery(torch.nn.Module):
+    """
+    Module that constructs a learnable query of query_shape for the Perceiver
+    """
+
+    def __init__(
+        self,
+        channel_dim: int,
+        query_shape: Union[Tuple[int], List[int]],
+        conv_layer: str = "3d",
+        max_frequency: float = 16.0,
+        num_frequency_bands: int = 64,
+        frequency_base: float = 2.0,
+        sine_only: bool = False,
+    ):
+        """
+        Learnable query with some built-in randomness to help with ensembling
+
+        Args:
+            channel_dim: Channel dimension for the output of the network
+            query_shape: The final shape of the query, generally the (T, H, W) of the output
+        """
+        super(LearnableQuery, self).__init__()
+        self.query_shape = query_shape
+        # The Fourier features are fixed, so compute them once and append them to every output
+        self.fourier_features = encode_position(
+            1,  # Batch size of 1 here; it is expanded to the input's batch size in forward()
+            axis=query_shape,
+            max_frequency=max_frequency,
+            frequency_base=frequency_base,
+            num_frequency_bands=num_frequency_bands,
+            sine_only=sine_only,
+        )
+        self.channel_dim = channel_dim
+        if (
+            conv_layer == "3d" and len(self.query_shape) == 3
+        ):  # A 3D conv only works if the query has a time dimension, i.e. (T, H, W)
+            conv = torch.nn.Conv3d
+        elif conv_layer == "2d":
+            conv = torch.nn.Conv2d
+        else:
+            raise ValueError(f"Value for 'conv_layer' is {conv_layer}, which is not one of '3d', '2d'")
+        self.conv_layer = conv_layer
+        self.layer = conv(
+            in_channels=channel_dim, out_channels=channel_dim, kernel_size=3, padding=1
+        )
+        # Linear layer to mix the channels (note: not currently applied in forward())
+        self.fc = torch.nn.Linear(self.channel_dim, self.channel_dim)
+        self.distribution = uniform.Uniform(low=torch.Tensor([0.0]), high=torch.Tensor([1.0]))
+
+    def output_shape(self) -> Tuple[int, int]:
+        """
+        Gives the output shape from the query, useful for setting the correct
+        query_dim in the Perceiver
+
+        Returns:
+            The shape of the resulting query, excluding the batch size
+        """
+
+        # The channels are the channel_dim plus the Fourier feature channels
+        channels = self.fourier_features.shape[-1] + self.channel_dim
+        return prod(self.query_shape), channels
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Samples the uniform distribution and creates the query by passing the
+        sample through the model and appending Fourier features
+
+        Args:
+            x: The input tensor to the model, only used to infer the batch size
+
+        Returns:
+            Torch tensor used to query the output of the PerceiverIO model
+        """
+        _LOG.debug(f"Batch: {x.shape[0]} Query: {self.query_shape} Dim: {self.channel_dim}")
+        z = self.distribution.sample((x.shape[0], self.channel_dim, *self.query_shape)).type_as(
+            x
+        )  # [B, C, T, H, W, 1] or [B, C, H, W, 1]
+        z = torch.squeeze(z, dim=-1)  # Drop the trailing singleton dim added by Uniform.sample
+        _LOG.debug(f"Z: {z.shape}")
+        # Run the 3D or 2D conv (spatial size is preserved), then append Fourier features and flatten
+        if self.conv_layer == "2d" and len(self.query_shape) == 3:
+            # A 2D conv on a (T, H, W) query is applied frame by frame along the time dimension
+            outs = []
+            for i in range(z.shape[2]):
+                outs.append(self.layer(z[:, :, i, :, :]))
+            query = torch.stack(outs, dim=2)
+        else:
+            query = self.layer(z)
+        # Add the Fourier features
+        ff = einops.repeat(
+            self.fourier_features, "b ... -> (repeat b) ...", repeat=x.shape[0]
+        )  # Match the input batch size
+        # Move channels to the last dimension
+        query = einops.rearrange(query, "b c ... -> b ... c")
+        query = torch.cat([query, ff], dim=-1)
+        # Flatten the spatial/temporal axes into a single query axis
+        query = einops.rearrange(query, "b ... d -> b (...) d")
+        _LOG.debug(f"Final Query Shape: {query.shape}")
+        return query
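For orientation, here is a minimal usage sketch of `LearnableQuery` wired into `PerceiverIO`, mirroring the new tests in `tests/test_queries.py` below; the channel counts assume the default `sine_only=False` encoding:

```python
# Minimal sketch, mirroring tests/test_queries.py below (not part of the PR itself).
import torch
from perceiver_pytorch.queries import LearnableQuery
from perceiver_pytorch.perceiver_io import PerceiverIO

query_creator = LearnableQuery(
    channel_dim=32,
    query_shape=(6, 16, 16),  # (T, H, W) of the desired output
    conv_layer="3d",
    num_frequency_bands=32,
)
x = torch.randn((2, 6, 12, 16, 16))  # only the batch size of x matters here
queries = query_creator(x)  # [2, 6*16*16, 32 + 3*(2*32 + 1)] = [2, 1536, 227]

# queries_dim must match the channel count reported by output_shape()
model = PerceiverIO(depth=2, dim=100, queries_dim=query_creator.output_shape()[-1])
latents = torch.randn((2, 256, 100))
out = model(latents, queries=queries)  # [2, 1536, 227]
```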
d") + _LOG.debug(f"Final Query Shape: {query.shape}") + return query diff --git a/setup.py b/setup.py index 43f9fe8..3f77d32 100644 --- a/setup.py +++ b/setup.py @@ -1,14 +1,15 @@ from setuptools import setup, find_packages from pathlib import Path + this_directory = Path(__file__).parent -install_requires = (this_directory / 'requirements.txt').read_text().splitlines() +install_requires = (this_directory / "requirements.txt").read_text().splitlines() long_description = (this_directory / "README.md").read_text() setup( name="perceiver-model", packages=find_packages(), - version="0.7.1", + version="0.7.2", license="MIT", description="Multimodal Perceiver - Pytorch", author="Jacob Bieker, Jack Kelly, Peter Dudfield", @@ -22,7 +23,7 @@ "attention mechanism", ], long_description=long_description, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", install_requires=install_requires, classifiers=[ "Development Status :: 4 - Beta", diff --git a/tests/test_queries.py b/tests/test_queries.py new file mode 100644 index 0000000..5926536 --- /dev/null +++ b/tests/test_queries.py @@ -0,0 +1,56 @@ +import pytest +import torch +from perceiver_pytorch.queries import LearnableQuery +from perceiver_pytorch.perceiver_io import PerceiverIO +import einops + + +@pytest.mark.parametrize("layer_shape", ["2d", "3d"]) +def test_learnable_query(layer_shape): + query_creator = LearnableQuery( + channel_dim=32, + query_shape=(6, 16, 16), + conv_layer=layer_shape, + max_frequency=64.0, + frequency_base=2.0, + num_frequency_bands=128, + sine_only=False, + ) + x = torch.randn((4, 6, 12, 16, 16)) + out = query_creator(x) + # Output is flattened, so should be [B, T*H*W, C] + # Channels is from channel_dim + 3*(num_frequency_bands * 2 + 1) + # 32 + 3*(257) = 771 + 32 = 803 + assert out.shape == (4, 16 * 16 * 6, 803) + + +@pytest.mark.parametrize("layer_shape", ["2d", "3d"]) +def test_learnable_query_qpplication(layer_shape): + output_shape = (6, 16, 16) + query_creator = LearnableQuery( + channel_dim=32, + query_shape=output_shape, + conv_layer=layer_shape, + max_frequency=64.0, + frequency_base=2.0, + num_frequency_bands=32, + sine_only=False, + ) + with torch.no_grad(): + query_creator.eval() + x = torch.randn((2, 6, 12, 16, 16)) + out = query_creator(x) + + model = PerceiverIO(depth=2, dim=100, queries_dim=query_creator.output_shape()[-1]) + model.eval() + model_input = torch.randn((2, 256, 100)) + model_out = model(model_input, queries=out) + # Reshape back to correct shape + model_out = einops.rearrange( + model_out, + "b (t h w) c -> b t c h w", + t=output_shape[0], + h=output_shape[1], + w=output_shape[2], + ) + assert model_out.shape == (2, 6, 227, 16, 16)