fix everything and optimize for clarity
lucidrains committed Apr 28, 2022
1 parent acafc19 commit b4bd3f9
Showing 3 changed files with 64 additions and 22 deletions.
README.md (5 changes: 3 additions & 2 deletions)
@@ -18,13 +18,14 @@ from flamingo_pytorch import PerceiverResampler

 perceive = PerceiverResampler(
     dim = 1024,
+    depth = 2,
     dim_head = 64,
     heads = 8,
     num_latents = 64,
 )
 
-medias = torch.randn(1, 2, 256, 1024) # (batch, num medias, sequence length, dimension)
-resampled = perceive(medias) # (1, 2, 64, 1024) - (batch, num medias, num latents, dimension)
+medias = torch.randn(1, 2, 256, 1024) # (batch, time, sequence length, dimension)
+resampled = perceive(medias) # (1, 2, 64, 1024) - (batch, time, num latents, dimension)
 ```
 
 ## Citations
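
For reference, here is a minimal sketch of the updated README usage (assuming the package is installed at this commit, i.e. the 0.0.2 bump in setup.py below); it just adds a shape check around the documented call to show what the new `depth` argument and the relabeled time axis refer to.

```python
import torch
from flamingo_pytorch import PerceiverResampler

# `depth` is now required: the number of (PerceiverAttention, FeedForward) blocks stacked in the resampler
perceive = PerceiverResampler(
    dim = 1024,
    depth = 2,
    dim_head = 64,
    heads = 8,
    num_latents = 64,
)

medias = torch.randn(1, 2, 256, 1024)       # (batch, time, sequence length, dimension)
resampled = perceive(medias)                # latents are repeated per batch / time step, then attend to each media

assert resampled.shape == (1, 2, 64, 1024)  # (batch, time, num latents, dimension)
```
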
flamingo_pytorch/flamingo_pytorch.py (79 changes: 60 additions & 19 deletions)
@@ -5,49 +5,49 @@
 from einops import rearrange, repeat
 from einops_exts import rearrange_many, repeat_many
 
-class PerceiverResampler(nn.Module):
+def FeedForward(dim, mult = 4):
+    inner_dim = int(dim * mult)
+    return nn.Sequential(
+        nn.LayerNorm(dim),
+        nn.Linear(dim, inner_dim, bias = False),
+        nn.GELU(),
+        nn.Linear(inner_dim, dim, bias = False)
+    )
+
+class PerceiverAttention(nn.Module):
     def __init__(
         self,
         *,
         dim,
         dim_head = 64,
-        heads = 8,
-        num_latents = 64
+        heads = 8
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
         self.heads = heads
         inner_dim = dim_head * heads
 
-        self.latents = nn.Parameter(torch.randn(num_latents, dim))
-
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)
 
-    def forward(self, x):
+    def forward(self, x, latents):
         """
         einstein notation
         b - batch
-        m - number of medias
+        t - time
         n - sequence
         d - dimension
         """
-        if x.ndim == 3:
-            x = rearrange(x, 'b n d -> b 1 n d')
-
         b, m, h = *x.shape[:2], self.heads
 
-        q = self.to_q(self.latents)
+        q = self.to_q(latents)
 
         # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to
-        lk, lv = self.to_kv(self.latents).chunk(2, dim = -1)
-
+        lk, lv = self.to_kv(latents).chunk(2, dim = -1)
         k, v = self.to_kv(x).chunk(2, dim = -1)
 
-        q = rearrange(q, 'n (h d) -> h n d', h = h)
-        k, v = rearrange_many((k, v), 'b m n (h d) -> b h m n d', h = h)
-        lk, lv = repeat_many((lk, lv), 'n (h d) -> b h m n d', b = b, m = m, h = h)
+        k, v, lk, lv, q = rearrange_many((k, v, lk, lv, q), 'b t n (h d) -> b h t n d', h = h)
 
         k = torch.cat((k, lk), dim = -2)
         v = torch.cat((v, lv), dim = -2)
@@ -56,11 +56,52 @@ def forward(self, x):

         # attention
 
-        sim = einsum('h i d, b h m j d -> b h m i j', q, k)
+        sim = einsum('... i d, ... j d -> ... i j', q, k)
 
         sim = sim - sim.amax(dim = -1, keepdim = True).detach()
         attn = sim.softmax(dim = -1)
 
         out = einsum('... i j, ... j d -> ... i d', attn, v)
-        out = rearrange(out, 'b h m n d -> b m n (h d)', h = h)
+        out = rearrange(out, 'b h t n d -> b t n (h d)', h = h)
         return self.to_out(out)
+
+class PerceiverResampler(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth,
+        dim_head = 64,
+        heads = 8,
+        num_latents = 64,
+        ff_mult = 4
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.latents = nn.Parameter(torch.randn(num_latents, dim))
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
+                FeedForward(dim = dim, mult = ff_mult)
+            ]))
+
+        self.to_q = nn.Linear(dim, inner_dim, bias = False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
+        self.to_out = nn.Linear(inner_dim, dim)
+
+    def forward(self, x):
+        if x.ndim == 3:
+            x = rearrange(x, 'b n d -> b 1 n d')
+
+        latents = repeat(self.latents, 'n d -> b m n d', b = x.shape[0], m = x.shape[1])
+
+        for attn, ff in self.layers:
+            latents = attn(x, latents) + latents
+            latents = ff(latents) + latents
+
+        return latents
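
As a companion to the refactor above, the sketch below (hypothetical test code, not part of the commit) exercises the new `PerceiverAttention.forward(x, latents)` signature directly: the latents now live in `PerceiverResampler` and are passed in by the caller, and the keys/values the latents attend to are the media tokens concatenated with the latents themselves, as the in-code comment notes.

```python
import torch
from flamingo_pytorch.flamingo_pytorch import PerceiverAttention, FeedForward

attn = PerceiverAttention(dim = 512, dim_head = 64, heads = 8)
ff = FeedForward(dim = 512, mult = 4)

x = torch.randn(2, 3, 128, 512)        # media features: (batch, time, sequence length, dim)
latents = torch.randn(2, 3, 64, 512)   # caller-owned latents: (batch, time, num latents, dim)

# one resampler block written out by hand: attention then feedforward, each with a residual,
# mirroring the loop in PerceiverResampler.forward
latents = attn(x, latents) + latents   # each latent attends over 128 media tokens + 64 latents = 192 keys
latents = ff(latents) + latents

print(latents.shape)   # torch.Size([2, 3, 64, 512])
```

Keeping the latents outside the attention module is what lets `PerceiverResampler` reuse one learned latent bank across all `depth` blocks and broadcast it over the batch and time axes.
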
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'flamingo-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.1',
+  version = '0.0.2',
   license='MIT',
   description = 'Flamingo - Pytorch',
   author = 'Phil Wang',
