diff --git a/flamingo_pytorch/flamingo_palm.py b/flamingo_pytorch/flamingo_palm.py
index 7388ab1..e9fc853 100644
--- a/flamingo_pytorch/flamingo_palm.py
+++ b/flamingo_pytorch/flamingo_palm.py
@@ -209,7 +209,8 @@ def __init__(
         cross_attn_every=3,
         img_encoder=None,
         perceiver_num_latents=64,
-        perceiver_depth=2
+        perceiver_depth=2,
+        only_attend_immediate_media=False
     ):
         super().__init__()
 
@@ -231,7 +232,7 @@ def __init__(
         for ind in range(depth):
             self.layers.append(nn.ModuleList([
                 Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
-                GatedCrossAttentionBlock(dim=dim, dim_head=dim_head, heads=heads) if not (ind % cross_attn_every) else None
+                GatedCrossAttentionBlock(dim=dim, dim_head=dim_head, heads=heads, only_attend_immediate_media=only_attend_immediate_media) if not (ind % cross_attn_every) else None
             ]))
 
         self.to_logits = nn.Sequential(
diff --git a/flamingo_pytorch/flamingo_pytorch.py b/flamingo_pytorch/flamingo_pytorch.py
index d29d09a..d4571c5 100644
--- a/flamingo_pytorch/flamingo_pytorch.py
+++ b/flamingo_pytorch/flamingo_pytorch.py
@@ -119,7 +119,8 @@ def __init__(
         *,
         dim,
         dim_head = 64,
-        heads = 8
+        heads = 8,
+        only_attend_immediate_media = False
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -132,6 +133,10 @@ def __init__(
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)
 
+        # whether for text to only attend to immediate preceding image, or all images
+
+        self.only_attend_immediate_media = only_attend_immediate_media
+
     def forward(
         self,
         x,
@@ -156,7 +161,12 @@ def forward(
         if exists(media_locations):
             text_time = media_locations.cumsum(dim = -1) # at each boolean of True, increment the time counter (relative to media time)
             media_time = torch.arange(t, device = x.device) + 1
-            text_to_media_mask = rearrange(text_time, 'b i -> b 1 i 1') >= repeat(media_time, 'j -> 1 1 1 (j m)', m = m)
+
+            # text time must equal media time if only attending to most immediate image
+            # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
+            mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
+
+            text_to_media_mask = mask_op(rearrange(text_time, 'b i -> b 1 i 1'), repeat(media_time, 'j -> 1 1 1 (j m)', m = m))
             sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
 
         sim = sim - sim.amax(dim = -1, keepdim = True).detach()
@@ -173,10 +183,11 @@ def __init__(
         dim,
         dim_head = 64,
         heads = 8,
-        ff_mult = 4
+        ff_mult = 4,
+        only_attend_immediate_media = False
     ):
         super().__init__()
-        self.attn = MaskedCrossAttention(dim = dim, dim_head = dim_head, heads = heads)
+        self.attn = MaskedCrossAttention(dim = dim, dim_head = dim_head, heads = heads, only_attend_immediate_media = only_attend_immediate_media)
         self.attn_gate = nn.Parameter(torch.tensor([0.]))
 
         self.ff = FeedForward(dim, mult = ff_mult)
diff --git a/setup.py b/setup.py
index f3325eb..eb16f46 100644
--- a/setup.py
+++ b/setup.py
@@ -9,6 +9,7 @@
   author = 'Phil Wang',
   author_email = 'lucidrains@gmail.com',
   url = 'https://github.com/lucidrains/flamingo-pytorch',
+  long_description_content_type = 'text/markdown',
   keywords = [
     'artificial intelligence',
     'deep learning',
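
For reference, a minimal sketch (not part of the diff) of the masking rule the new `mask_op` selection implements. The tensor names mirror `MaskedCrossAttention.forward`, but the toy example drops the latent dimension: in the real module `media_time` is additionally repeated across the `m` perceiver latents per image via `repeat(..., 'j -> 1 1 1 (j m)', m = m)`.

```python
import torch

# six text tokens; positions 0 and 3 are <media> tokens marking where the two images sit
media_locations = torch.tensor([[True, False, False, True, False, False]])

text_time = media_locations.cumsum(dim = -1)   # tensor([[1, 1, 1, 2, 2, 2]])
num_media = int(media_locations.sum())         # 2 images in this toy example
media_time = torch.arange(num_media) + 1       # tensor([1, 2]), 1-indexed as in the diff

# eq -> attend to the immediately preceding image only, ge -> attend to all preceding images
for mask_op in (torch.eq, torch.ge):
    mask = mask_op(text_time[..., :, None], media_time)   # broadcasts to (batch, text, media)
    print(mask_op.__name__)
    print(mask[0].int())
```

With `torch.eq` (`only_attend_immediate_media = True`) the tokens after the second `<media>` token attend only to image 2; with `torch.ge` they attend to both images, which matches the previous `>=` behaviour and remains the default here since the flag defaults to `False`.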