Skip to content

Commit

Permalink
any text without media should not be updated at all during masked cro…
Browse files Browse the repository at this point in the history
…ss attention
  • Loading branch information
lucidrains committed Jun 8, 2022
1 parent 0e1a915 commit 26ca8db
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
6 changes: 6 additions & 0 deletions flamingo_pytorch/flamingo_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,12 @@ def forward(
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn = sim.softmax(dim = -1)

if exists(media_locations) and self.only_attend_immediate_media:
# any text without a preceding media needs to have attention zeroed out
text_without_media_mask = text_time == 0
text_without_media_mask = rearrange(text_without_media_mask, 'b i -> b 1 i 1')
attn.masked_fill(text_without_media_mask, 0.)

out = einsum('... i j, ... j d -> ... i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'flamingo-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.16',
version = '0.0.17',
license='MIT',
description = 'Flamingo - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 26ca8db

Please sign in to comment.