diff --git a/README.md b/README.md index b4c08f1..6dbd412 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,6 @@ model = Perceiver( input_axis = 2, # number of axis for input data (2 for images, 3 for video) num_freq_bands = 6, # number of freq bands, with original value (2 * K + 1) max_freq = 10., # maximum frequency, hyperparameter depending on how fine the data is - freq_base = 2, depth = 6, # depth of net. The shape of the final attention mechanism will be: # depth * (cross attention -> self_per_cross_attn * self attention) num_latents = 256, # number of latents, or induced set points, or centroids. different papers giving it different names diff --git a/perceiver_pytorch/perceiver_pytorch.py b/perceiver_pytorch/perceiver_pytorch.py index d6809b6..799b81f 100644 --- a/perceiver_pytorch/perceiver_pytorch.py +++ b/perceiver_pytorch/perceiver_pytorch.py @@ -29,11 +29,11 @@ def cached_fn(*args, _cache = True, **kwargs): return cache return cached_fn -def fourier_encode(x, max_freq, num_bands = 4, base = 2): +def fourier_encode(x, max_freq, num_bands = 4): x = x.unsqueeze(-1) device, dtype, orig_x = x.device, x.dtype, x - scales = torch.logspace(0., log(max_freq / 2) / log(base), num_bands, base = base, device = device, dtype = dtype) + scales = torch.logspace(0., log(max_freq / 2), num_bands, device = device, dtype = dtype) scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)] x = x * scales * pi @@ -128,7 +128,6 @@ def __init__( num_freq_bands, depth, max_freq, - freq_base = 2, input_channels = 3, input_axis = 2, num_latents = 512, @@ -177,7 +176,6 @@ def __init__( self.input_axis = input_axis self.max_freq = max_freq self.num_freq_bands = num_freq_bands - self.freq_base = freq_base self.fourier_encode_data = fourier_encode_data fourier_channels = (input_axis * ((num_freq_bands * 2) + 1)) if fourier_encode_data else 0 @@ -231,7 +229,7 @@ def forward( axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis)) pos = torch.stack(torch.meshgrid(*axis_pos), dim = -1) - enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands, base = self.freq_base) + enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands) enc_pos = rearrange(enc_pos, '... n d -> ... (n d)') enc_pos = repeat(enc_pos, '... -> b ...', b = b)