Skip to content

Commit

Permalink
address #10
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 5, 2022
1 parent 8d8a23e commit 0ea204d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
25 changes: 24 additions & 1 deletion coca_pytorch/coca_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,17 @@ def __init__(self, fn):
def forward(self, x, *args, **kwargs):
return self.fn(x, *args, **kwargs) + x

# to latents


class EmbedToLatents(nn.Module):
    """Project embeddings into a latent space and normalize them.

    The module applies a single bias-free linear layer followed by
    ``F.normalize`` over the last dimension, so every output vector has
    unit norm along that dimension.
    """

    def __init__(self, dim: int, dim_latents: int) -> None:
        """Create the projection.

        Args:
            dim: Size of the last dimension of the incoming embeddings.
            dim_latents: Size of the last dimension of the produced latents.
        """
        super().__init__()
        # No bias: the output is normalized right after the projection.
        self.to_latents = nn.Linear(dim, dim_latents, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the normalized latent projection of ``x``.

        Args:
            x: Tensor whose last dimension has size ``dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``dim_latents``, normalized along that
            last dimension.
        """
        latents = self.to_latents(x)
        return F.normalize(latents, dim=-1)

# rotary positional embedding
# https://arxiv.org/abs/2104.09864
Expand Down Expand Up @@ -282,6 +293,7 @@ def __init__(
num_tokens,
unimodal_depth,
multimodal_depth,
dim_latents = None,
image_dim = None,
num_img_queries=256,
dim_head=64,
Expand Down Expand Up @@ -316,6 +328,12 @@ def __init__(
self.img_attn_pool_norm = LayerNorm(dim)
self.text_cls_norm = LayerNorm(dim)

# to latents

dim_latents = default(dim_latents, dim)
self.img_to_latents = EmbedToLatents(dim, dim_latents)
self.text_to_latents = EmbedToLatents(dim, dim_latents)

# contrastive learning temperature

self.temperature = nn.Parameter(torch.Tensor([1.]))
Expand Down Expand Up @@ -440,9 +458,14 @@ def forward(
caption_loss = ce(logits, labels, ignore_index=self.pad_id)
caption_loss = caption_loss * self.caption_loss_weight

# embedding to latents

text_latents = self.text_to_latents(text_embeds)
image_latents = self.img_to_latents(image_embeds)

# calculate contrastive loss

sim = einsum('i d, j d -> i j', text_embeds, image_embeds)
sim = einsum('i d, j d -> i j', text_latents, image_latents)
sim = sim * self.temperature.exp()
contrastive_labels = torch.arange(batch, device=device)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'CoCa-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.6',
version = '0.0.7',
license='MIT',
description = 'CoCa, Contrastive Captioners are Image-Text Foundation Models - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 0ea204d

Please sign in to comment.