Skip to content

Commit

Permalink
allow for turning off fourier positions for experimental perceiver (i…
Browse files Browse the repository at this point in the history
…sab-like) for @nleroy917
  • Loading branch information
lucidrains committed Aug 22, 2023
1 parent d6e3cda commit c3d505a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
29 changes: 18 additions & 11 deletions perceiver_pytorch/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def forward(self, x, mask = None):
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

q *= self.scale
q = q * self.scale
q, k = q.softmax(dim = -1), k.softmax(dim = -2)

if exists(mask):
Expand Down Expand Up @@ -64,14 +64,19 @@ def __init__(
num_classes = 1000,
attn_dropout = 0.,
ff_dropout = 0.,
weight_tie_layers = False
weight_tie_layers = False,
fourier_encode_data = True
):
super().__init__()
self.input_axis = input_axis
self.max_freq = max_freq
self.num_freq_bands = num_freq_bands
self.fourier_encode_data = fourier_encode_data

input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels
input_dim = input_channels

if fourier_encode_data:
input_dim += input_axis * ((num_freq_bands * 2) + 1) + input_channels

self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))

Expand Down Expand Up @@ -113,17 +118,19 @@ def forward(self, data, mask = None):
b, *axis, _, device = *data.shape, data.device
assert len(axis) == self.input_axis, 'input data must have the right number of axis'

# calculate fourier encoded positions in the range of [-1, 1], for all axis
if self.fourier_encode_data:
# calculate fourier encoded positions in the range of [-1, 1], for all axis

axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
enc_pos = repeat(enc_pos, '... -> b ...', b = b)

axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
enc_pos = repeat(enc_pos, '... -> b ...', b = b)
# concat to channels of data and flatten axis

# concat to channels of data and flatten axis
data = torch.cat((data, enc_pos), dim = -1)

data = torch.cat((data, enc_pos), dim = -1)
data = rearrange(data, 'b ... d -> b (...) d')

data = self.data_proj(data)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'perceiver-pytorch',
packages = find_packages(),
version = '0.8.7',
version = '0.8.8',
license='MIT',
description = 'Perceiver - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down

0 comments on commit c3d505a

Please sign in to comment.