add weight tying across layers feature
lucidrains committed Mar 5, 2021
1 parent 0705deb commit c3e3add
Showing 3 changed files with 30 additions and 7 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -28,7 +28,8 @@ model = Perceiver(
     latent_dim_head = 64,
     num_classes = 1000,          # output number of classes
     attn_dropout = 0.,
-    ff_dropout = 0.
+    ff_dropout = 0.,
+    weight_tie_layers = False    # whether to weight tie layers (optional, as indicated in the diagram)
 )
 
 img = torch.randn(1, 224 * 224)  # 1 imagenet image, pixelized
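
For reference, a minimal usage sketch with the new flag turned on. Only the parameter names visible in this diff are confirmed; the import path and the remaining values follow the README example and are assumptions here:

import torch
from perceiver_pytorch import Perceiver   # import path assumed from the README

model = Perceiver(
    num_fourier_features = 6,    # value assumed; name confirmed in __init__ below
    depth = 6,                   # number of (cross attention, latent transformer) layers
    num_latents = 256,           # value assumed
    latent_dim = 512,            # value assumed
    latent_dim_head = 64,
    num_classes = 1000,          # output number of classes
    attn_dropout = 0.,
    ff_dropout = 0.,
    weight_tie_layers = True     # share one set of weights across all depth layers
)

img = torch.randn(1, 224 * 224)  # 1 imagenet image, pixelized
logits = model(img)              # class logits, shape (1, 1000)
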
32 changes: 27 additions & 5 deletions perceiver_pytorch/perceiver_pytorch.py
@@ -1,3 +1,5 @@
+from functools import wraps
+
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
@@ -12,6 +14,17 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
+def cache_fn(f):
+    cache = None
+    @wraps(f)
+    def cached_fn(*args, **kwargs):
+        nonlocal cache
+        if cache is not None:
+            return cache
+        cache = f(*args, **kwargs)
+        return cache
+    return cached_fn
+
 def fourier_encode(x, num_encodings = 4):
     x = x.unsqueeze(-1)
     device, dtype, orig_x = x.device, x.dtype, x
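
Note on the helper above: once a factory is wrapped with cache_fn, the first call's result is cached and every later call returns that same object (arguments are ignored thereafter). This is what lets the constructor below reuse one set of modules across all layers. A minimal sketch of the behavior, assuming cache_fn remains a module-level function as shown in this hunk:

from torch import nn
from perceiver_pytorch.perceiver_pytorch import cache_fn   # import path assumed

get_ff = lambda: nn.Linear(8, 8)    # toy factory: fresh weights on every call

a, b = get_ff(), get_ff()
assert a is not b                   # untied: two distinct modules

get_ff = cache_fn(get_ff)
c, d = get_ff(), get_ff()
assert c is d                       # tied: the cached module is returned again
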
@@ -102,21 +115,30 @@ def __init__(
         latent_dim_head = 64,
         num_classes = 1000,
         attn_dropout = 0.,
-        ff_dropout = 0.
+        ff_dropout = 0.,
+        weight_tie_layers = False
     ):
         super().__init__()
 
         self.num_fourier_features = num_fourier_features
         input_dim = (num_fourier_features * 2) + 1
         self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
 
+        get_cross_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, dropout = attn_dropout), context_dim = input_dim)
+        get_cross_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
+        get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, dropout = attn_dropout))
+        get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
+
+        if weight_tie_layers:
+            get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff))
+
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(latent_dim, Attention(latent_dim, input_dim, dropout = attn_dropout), context_dim = input_dim),
-                PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout)),
-                PreNorm(latent_dim, Attention(latent_dim, dropout = attn_dropout)),
-                PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
+                get_cross_attn(),
+                get_cross_ff(),
+                get_latent_attn(),
+                get_latent_ff()
             ]))
 
         self.to_logits = nn.Linear(latent_dim, num_classes)
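
A quick way to sanity-check the flag end to end: with tying enabled, every entry of self.layers holds the same four modules, so the parameter count no longer grows with depth (nn.Module.parameters() deduplicates shared weights). A sketch; the constructor values other than weight_tie_layers are placeholders, and any parameters not shown in this diff are assumed to have defaults:

from perceiver_pytorch import Perceiver   # import path assumed from the README

def build(tie):
    return Perceiver(
        num_fourier_features = 6,
        depth = 4,
        num_latents = 32,
        latent_dim = 128,
        weight_tie_layers = tie
    )

tied, untied = build(True), build(False)

assert tied.layers[0][0] is tied.layers[1][0]          # cross attention module shared across layers
assert untied.layers[0][0] is not untied.layers[1][0]  # without the flag, each layer gets its own

count = lambda m: sum(p.numel() for p in m.parameters())
assert count(tied) < count(untied)                     # tied weights do not scale with depth
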
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'perceiver-pytorch',
   packages = find_packages(),
-  version = '0.0.1',
+  version = '0.0.2',
   license='MIT',
   description = 'Perceiver - Pytorch',
   author = 'Phil Wang',
