Skip to content

Commit

Permalink
remove base parameter from fourier_encode function
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 26, 2021
1 parent b33aced commit de5666b
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ model = Perceiver(
input_axis = 2, # number of axes of the input data (2 for images, 3 for video)
num_freq_bands = 6, # number of freq bands (K); each input axis is encoded into (2 * K + 1) fourier channels
max_freq = 10., # maximum frequency, hyperparameter depending on how fine the data is
freq_base = 2,
depth = 6, # depth of net. The shape of the final attention mechanism will be:
# depth * (cross attention -> self_per_cross_attn * self attention)
num_latents = 256, # number of latents, or induced set points, or centroids (different papers give it different names)
Expand Down
8 changes: 3 additions & 5 deletions perceiver_pytorch/perceiver_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ def cached_fn(*args, _cache = True, **kwargs):
return cache
return cached_fn

def fourier_encode(x, max_freq, num_bands = 4, base = 2):
def fourier_encode(x, max_freq, num_bands = 4):
x = x.unsqueeze(-1)
device, dtype, orig_x = x.device, x.dtype, x

scales = torch.logspace(0., log(max_freq / 2) / log(base), num_bands, base = base, device = device, dtype = dtype)
scales = torch.logspace(0., log(max_freq / 2), num_bands, device = device, dtype = dtype)
scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]

x = x * scales * pi
Expand Down Expand Up @@ -128,7 +128,6 @@ def __init__(
num_freq_bands,
depth,
max_freq,
freq_base = 2,
input_channels = 3,
input_axis = 2,
num_latents = 512,
Expand Down Expand Up @@ -177,7 +176,6 @@ def __init__(
self.input_axis = input_axis
self.max_freq = max_freq
self.num_freq_bands = num_freq_bands
self.freq_base = freq_base

self.fourier_encode_data = fourier_encode_data
fourier_channels = (input_axis * ((num_freq_bands * 2) + 1)) if fourier_encode_data else 0
Expand Down Expand Up @@ -231,7 +229,7 @@ def forward(

axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
pos = torch.stack(torch.meshgrid(*axis_pos), dim = -1)
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands, base = self.freq_base)
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
enc_pos = repeat(enc_pos, '... -> b ...', b = b)

Expand Down

0 comments on commit de5666b

Please sign in to comment.