Skip to content

Commit

Permalink
Merge pull request lucidrains#19 from openclimatefix/jacob/learnable-…
Browse files Browse the repository at this point in the history
…queries

Add Learnable Query
  • Loading branch information
jacobbieker authored Sep 28, 2021
2 parents 4540966 + 6b11e0c commit 31631d2
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 26 deletions.
45 changes: 22 additions & 23 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
@@ -1,33 +1,32 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Python application
name: Python Tests

on: [push, pull_request]
on: [push]

jobs:
build:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
uses: actions/setup-python@v2
with:
python-version: 3.9
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 pytest pytest-cov
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install -e .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest -s --cov=perceiver_pytorch
- uses: actions/checkout@v2
- name: Set up Python 3.9
uses: actions/setup-python@v2
with:
python-version: 3.9
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 pytest pytest-cov
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install -e .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest -s --cov=perceiver_pytorch
27 changes: 27 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
default_language_version:
python: python3.9

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.4.0
hooks:
# list of supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: debug-statements
- id: detect-private-key

# python code formatting
- repo: https://github.com/psf/black
rev: 20.8b1
hooks:
- id: black
args: [--line-length, "100"]

# yaml formatting
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v2.3.0
hooks:
- id: prettier
types: [yaml]
112 changes: 112 additions & 0 deletions perceiver_pytorch/queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import torch
from torch.distributions import uniform
from typing import List, Union, Tuple
from perceiver_pytorch.utils import encode_position
from math import prod
import einops
import logging

_LOG = logging.getLogger("perceiver.queries")
_LOG.setLevel(logging.WARN)


class LearnableQuery(torch.nn.Module):
"""
Module that constructs a learnable query of query_shape for the Perceiver
"""

def __init__(
self,
channel_dim: int,
query_shape: Union[Tuple[int], List[int]],
conv_layer: str = "3d",
max_frequency: float = 16.0,
num_frequency_bands: int = 64,
frequency_base: float = 2.0,
sine_only: bool = False,
):
"""
Learnable Query with some inbuilt randomness to help with ensembling
Args:
channel_dim: Channel dimension for the output of the network
query_shape: The final shape of the query, generally, the (T, H, W) of the output
"""
super(LearnableQuery, self).__init__()
self.query_shape = query_shape
# Need to get Fourier Features once and then just append to the output
self.fourier_features = encode_position(
1, # Batch size, 1 for this as it will be adapted in forward
axis=query_shape,
max_frequency=max_frequency,
frequency_base=frequency_base,
num_frequency_bands=num_frequency_bands,
sine_only=sine_only,
)
self.channel_dim = channel_dim
if (
conv_layer == "3d" and len(self.query_shape) == 3
): # If Query shape is for an image, then 3D conv won't work
conv = torch.nn.Conv3d
elif conv_layer == "2d":
conv = torch.nn.Conv2d
else:
raise ValueError(f"Value for 'layer' is {conv_layer} which is not one of '3d', '2d'")
self.conv_layer = conv_layer
self.layer = conv(
in_channels=channel_dim, out_channels=channel_dim, kernel_size=3, padding=1
)
# Linear layer to compress channels down to query_dim size?
self.fc = torch.nn.Linear(self.channel_dim, self.channel_dim)
self.distribution = uniform.Uniform(low=torch.Tensor([0.0]), high=torch.Tensor([1.0]))

def output_shape(self) -> Tuple[int, int]:
"""
Gives the output shape from the query, useful for setting the correct
query_dim in the Perceiver
Returns:
The shape of the resulting query, excluding the batch size
"""

# The shape is the query_dim + Fourier Feature channels
channels = self.fourier_features.shape[-1] + self.channel_dim
return prod(self.query_shape), channels

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Samples the uniform distribution and creates the query by passing the
sample through the model and appending Fourier features
Args:
x: The input tensor to the model, used to batch the batch size
Returns:
Torch tensor used to query the output of the PerceiverIO model
"""
_LOG.debug(f"Batch: {x.shape[0]} Query: {self.query_shape} Dim: {self.channel_dim}")
z = self.distribution.sample((x.shape[0], self.channel_dim, *self.query_shape)).type_as(
x
) # [B, Query, T, H, W, 1] or [B, Query, H, W, 1]
z = torch.squeeze(z, dim=-1) # Extra 1 for some reason
_LOG.debug(f"Z: {z.shape}")
# Do 3D or 2D CNN to keep same spatial size, concat, then linearize
if self.conv_layer == "2d" and len(self.query_shape) == 3:
# Iterate through time dimension
outs = []
for i in range(x.shape[1]):
outs.append(self.layer(z[:, :, i, :, :]))
query = torch.stack(outs, dim=2)
else:
query = self.layer(z)
# Add Fourier Features
ff = einops.repeat(
self.fourier_features, "b ... -> (repeat b) ...", repeat=x.shape[0]
) # Match batches
# Move channels to correct location
query = einops.rearrange(query, "b c ... -> b ... c")
query = torch.cat([query, ff], dim=-1)
# concat to channels of data and flatten axis
query = einops.rearrange(query, "b ... d -> b (...) d")
_LOG.debug(f"Final Query Shape: {query.shape}")
return query
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from setuptools import setup, find_packages
from pathlib import Path

this_directory = Path(__file__).parent
install_requires = (this_directory / 'requirements.txt').read_text().splitlines()
install_requires = (this_directory / "requirements.txt").read_text().splitlines()
long_description = (this_directory / "README.md").read_text()


setup(
name="perceiver-model",
packages=find_packages(),
version="0.7.1",
version="0.7.2",
license="MIT",
description="Multimodal Perceiver - Pytorch",
author="Jacob Bieker, Jack Kelly, Peter Dudfield",
Expand All @@ -22,7 +23,7 @@
"attention mechanism",
],
long_description=long_description,
long_description_content_type='text/markdown',
long_description_content_type="text/markdown",
install_requires=install_requires,
classifiers=[
"Development Status :: 4 - Beta",
Expand Down
56 changes: 56 additions & 0 deletions tests/test_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest
import torch
from perceiver_pytorch.queries import LearnableQuery
from perceiver_pytorch.perceiver_io import PerceiverIO
import einops


@pytest.mark.parametrize("layer_shape", ["2d", "3d"])
def test_learnable_query(layer_shape):
query_creator = LearnableQuery(
channel_dim=32,
query_shape=(6, 16, 16),
conv_layer=layer_shape,
max_frequency=64.0,
frequency_base=2.0,
num_frequency_bands=128,
sine_only=False,
)
x = torch.randn((4, 6, 12, 16, 16))
out = query_creator(x)
# Output is flattened, so should be [B, T*H*W, C]
# Channels is from channel_dim + 3*(num_frequency_bands * 2 + 1)
# 32 + 3*(257) = 771 + 32 = 803
assert out.shape == (4, 16 * 16 * 6, 803)


@pytest.mark.parametrize("layer_shape", ["2d", "3d"])
def test_learnable_query_qpplication(layer_shape):
output_shape = (6, 16, 16)
query_creator = LearnableQuery(
channel_dim=32,
query_shape=output_shape,
conv_layer=layer_shape,
max_frequency=64.0,
frequency_base=2.0,
num_frequency_bands=32,
sine_only=False,
)
with torch.no_grad():
query_creator.eval()
x = torch.randn((2, 6, 12, 16, 16))
out = query_creator(x)

model = PerceiverIO(depth=2, dim=100, queries_dim=query_creator.output_shape()[-1])
model.eval()
model_input = torch.randn((2, 256, 100))
model_out = model(model_input, queries=out)
# Reshape back to correct shape
model_out = einops.rearrange(
model_out,
"b (t h w) c -> b t c h w",
t=output_shape[0],
h=output_shape[1],
w=output_shape[2],
)
assert model_out.shape == (2, 6, 227, 16, 16)

0 comments on commit 31631d2

Please sign in to comment.