
Commit

Merge pull request #2113 from gau-nernst/tinyclip
Add TinyCLIP
rwightman authored Mar 20, 2024
2 parents 111fad1 + 256cf19 commit 492947d
Showing 1 changed file with 65 additions and 4 deletions: timm/models/vision_transformer.py
@@ -964,11 +964,13 @@ def _convert_openai_clip(
             v = v.unsqueeze(0)
             if v.shape[1] != model.pos_embed.shape[1]:
                 # To resize pos embedding when using model at different size from pretrained weights
-                v = resize_pos_embed(
+                num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) \
+                    else getattr(model, 'num_prefix_tokens', 1)
+                v = resample_abs_pos_embed(
                     v,
-                    model.pos_embed,
-                    0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
-                    model.patch_embed.grid_size
+                    new_size=model.patch_embed.grid_size,
+                    num_prefix_tokens=num_prefix_tokens,
+                    verbose=True,
                 )
         out_dict[k] = v
     return out_dict
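The old positional-argument call to resize_pos_embed is swapped for resample_abs_pos_embed, which takes the target grid and the number of prefix (class) tokens as keywords. As a rough illustration, the sketch below resamples a CLIP-style positional table with one class token to a larger grid; the sizes are made up for the example.

import torch
from timm.layers import resample_abs_pos_embed

# Hypothetical sizes: a 14x14 pretrained grid resampled to 16x16, embed dim 256.
old_grid, new_grid, dim = (14, 14), (16, 16), 256
pos_embed = torch.randn(1, 1 + old_grid[0] * old_grid[1], dim)  # 1 class token + 196 patch tokens

resized = resample_abs_pos_embed(
    pos_embed,
    new_size=new_grid,       # e.g. model.patch_embed.grid_size
    num_prefix_tokens=1,     # 0 when the model sets no_embed_class
    verbose=True,            # logs the old -> new resize, as in the filter above
)
print(resized.shape)  # torch.Size([1, 257, 256])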
@@ -1735,6 +1737,27 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         input_size=(3, 384, 384),
         num_classes=0),

+    'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        license='mit',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+    'vit_medium_patch32_clip_224.tinyclip_laion400m': _cfg(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        license='mit',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+    'vit_medium_patch16_clip_224.tinyclip_yfcc15m': _cfg(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        license='mit',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+    'vit_betwixt_patch32_clip_224.tinyclip_laion400m': _cfg(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        license='mit',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+
     'vit_medium_patch16_reg4_256': _cfg(
         input_size=(3, 256, 256)),
     'vit_medium_patch16_reg4_gap_256': _cfg(
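Each new pretrained tag points at an open_clip-format checkpoint on the Hugging Face Hub, with CLIP normalization constants and a 512-dim projection serving as the head. A hedged usage sketch, assuming the TinyCLIP weights are published under the timm/ organization as these cfgs indicate:

import torch
import timm

# pretrained=True pulls open_clip_pytorch_model.bin from the Hub; the visual.*
# keys are remapped to timm names by the checkpoint filter shown above.
model = timm.create_model('vit_xsmall_patch16_clip_224.tinyclip_yfcc15m', pretrained=True)
model.eval()

print(model.pretrained_cfg['mean'], model.pretrained_cfg['std'])  # OPENAI_CLIP_MEAN / OPENAI_CLIP_STD

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 512]) -- the CLIP projection acts as the classifier head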
@@ -2073,6 +2096,44 @@ def vit_giant_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model


+@register_model
+def vit_xsmall_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    # TinyCLIP 8M
+    model_args = dict(embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_xsmall_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_medium_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    # TinyCLIP 40M
+    model_args = dict(
+        patch_size=32, embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_medium_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_medium_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    # TinyCLIP 39M
+    model_args = dict(embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_medium_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_betwixt_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    # TinyCLIP 61M
+    model_args = dict(
+        patch_size=32, embed_dim=640, depth=12, num_heads=10, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_betwixt_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_base_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-B/32 CLIP image tower @ 224x224
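For completeness, a quick sanity sketch that instantiates the four newly registered towers without pretrained weights and reports their parameter counts; the 8M/40M/39M/61M figures in the comments follow TinyCLIP's naming of its image encoders.

import timm

for name in (
    'vit_xsmall_patch16_clip_224',
    'vit_medium_patch32_clip_224',
    'vit_medium_patch16_clip_224',
    'vit_betwixt_patch32_clip_224',
):
    # num_classes=0 drops the CLIP projection head and leaves a plain feature extractor
    m = timm.create_model(name, num_classes=0)
    n_params = sum(p.numel() for p in m.parameters())
    print(f'{name}: {n_params / 1e6:.1f}M params')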
