Make k=stride=2 ('avg2') pooling default for coatnet/maxvit. Add weight links. Rename 'combined' partition to 'parallel'.
rwightman committed Aug 24, 2022
1 parent 837c682 commit b2e8426
Showing 1 changed file with 27 additions and 27 deletions.
timm/models/maxxvit.py (54 changes: 27 additions & 27 deletions)
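Note on the new 'avg2' default before the diff: in the downsample module touched below, 'avg' pools with an overlapping 3x3 window at stride 2, while 'avg2' pools with kernel_size == stride == 2. A minimal standalone sketch (not part of the diff) showing that both halve the spatial resolution:

import torch
import torch.nn as nn

x = torch.randn(1, 64, 56, 56)
# 'avg': overlapping 3x3 window, stride 2, as used in the diff below
avg = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
# 'avg2': the new default, kernel_size == stride == 2, no overlap or padding
avg2 = nn.AvgPool2d(2)
print(avg(x).shape)   # torch.Size([1, 64, 28, 28])
print(avg2(x).shape)  # torch.Size([1, 64, 28, 28])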
@@ -74,26 +74,26 @@ def _cfg(url='', **kwargs):
# Fiddling with configs / defaults / still pretraining
'coatnet_pico_rw_224': _cfg(url=''),
'coatnet_nano_rw_224': _cfg(
-url='',
+url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth',
crop_pct=0.9),
'coatnet_0_rw_224': _cfg(
-url=''),
+url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'),
'coatnet_1_rw_224': _cfg(
-url=''
+url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
),
'coatnet_2_rw_224': _cfg(url=''),

# Highly experimental configs
'coatnet_bn_0_rw_224': _cfg(
-url='',
+url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
crop_pct=0.95),
'coatnet_rmlp_nano_rw_224': _cfg(
-url='',
+url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
crop_pct=0.9),
'coatnet_rmlp_0_rw_224': _cfg(url=''),
'coatnet_rmlp_1_rw_224': _cfg(
-url=''),
+url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
'coatnet_nano_cc_224': _cfg(url=''),
'coatnext_nano_rw_224': _cfg(url=''),

@@ -107,10 +107,12 @@ def _cfg(url='', **kwargs):

# Experimental configs
'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-'maxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+'maxvit_nano_rw_256': _cfg(
+url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-3e790ce3.pth',
+input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_tiny_rw_224': _cfg(url=''),
'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-'maxvit_tiny_cm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),

# Trying to be like the MaxViT paper configs
@@ -131,7 +133,7 @@ class MaxxVitTransformerCfg:
attn_bias: bool = True
attn_drop: float = 0.
proj_drop: float = 0.
-pool_type: str = 'avg'
+pool_type: str = 'avg2'
rel_pos_type: str = 'bias'
rel_pos_dim: int = 512 # for relative position types w/ MLP
window_size: Tuple[int, int] = (7, 7)
@@ -153,7 +155,7 @@ class MaxxVitConvCfg:
pre_norm_act: bool = False # activation after pre-norm
output_bias: bool = True # bias for shortcut + final 1x1 projection conv
stride_mode: str = 'dw' # stride done via one of 'pool', '1x1', 'dw'
-pool_type: str = 'avg'
+pool_type: str = 'avg2'
downsample_pool_type: str = 'avg2'
attn_early: bool = False # apply attn between conv2 and norm2, instead of after norm2
attn_layer: str = 'se'
@@ -241,7 +243,7 @@ def _rw_coat_cfg(

def _rw_max_cfg(
stride_mode='dw',
-pool_type='avg',
+pool_type='avg2',
conv_output_bias=False,
conv_attn_ratio=1 / 16,
conv_norm_layer='',
@@ -325,7 +327,6 @@ def _next_cfg(
depths=(2, 3, 5, 2),
stem_width=(32, 64),
**_rw_max_cfg( # using newer max defaults here
-pool_type='avg2',
conv_output_bias=True,
conv_attn_ratio=0.25,
),
@@ -336,7 +337,6 @@ def _next_cfg(
stem_width=(32, 64),
**_rw_max_cfg( # using newer max defaults here
stride_mode='pool',
-pool_type='avg2',
conv_output_bias=True,
conv_attn_ratio=0.25,
),
@@ -384,7 +384,6 @@ def _next_cfg(
depths=(3, 4, 6, 3),
stem_width=(32, 64),
**_rw_max_cfg(
-pool_type='avg2',
conv_output_bias=True,
conv_attn_ratio=0.25,
rel_pos_type='mlp',
@@ -487,10 +486,10 @@ def _next_cfg(
stem_width=(32, 64),
**_rw_max_cfg(window_size=8),
),
-maxvit_tiny_cm_256=MaxxVitCfg(
+maxvit_tiny_pm_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
-block_type=('CM',) * 4,
+block_type=('PM',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(window_size=8),
),
@@ -663,13 +662,15 @@ def __init__(
bias: bool = True,
):
super().__init__()
-assert pool_type in ('max', 'avg', 'avg2')
+assert pool_type in ('max', 'max2', 'avg', 'avg2')
if pool_type == 'max':
self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+elif pool_type == 'max2':
+self.pool = nn.MaxPool2d(2) # kernel_size == stride == 2
elif pool_type == 'avg':
self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
else:
-self.pool = nn.AvgPool2d(2)
+self.pool = nn.AvgPool2d(2) # kernel_size == stride == 2

if dim != dim_out:
self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias)
@@ -1073,7 +1074,7 @@ def forward(self, x):
return x


-class CombinedPartitionAttention(nn.Module):
+class ParallelPartitionAttention(nn.Module):
""" Experimental. Grid and Block partition + single FFN
NxC tensor layout.
"""
@@ -1286,7 +1287,7 @@ def forward(self, x):
return x


-class CombinedMaxxVitBlock(nn.Module):
+class ParallelMaxxVitBlock(nn.Module):
"""
"""

@@ -1309,7 +1310,7 @@ def __init__(
self.conv = nn.Sequential(*convs)
else:
self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
-self.attn = CombinedPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
+self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)

def init_weights(self, scheme=''):
named_apply(partial(_init_transformer, scheme=scheme), self.attn)
@@ -1343,7 +1344,7 @@ def __init__(
blocks = []
for i, t in enumerate(block_types):
block_stride = stride if i == 0 else 1
-assert t in ('C', 'T', 'M', 'CM')
+assert t in ('C', 'T', 'M', 'PM')
if t == 'C':
conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
blocks += [conv_cls(
@@ -1372,8 +1373,8 @@ def __init__(
transformer_cfg=transformer_cfg,
drop_path=drop_path[i],
)]
-elif t == 'CM':
-blocks += [CombinedMaxxVitBlock(
+elif t == 'PM':
+blocks += [ParallelMaxxVitBlock(
in_chs,
out_chs,
stride=block_stride,
@@ -1415,7 +1416,6 @@ def __init__(
self.norm1 = norm_act_layer(out_chs[0])
self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1)

-@torch.jit.ignore
def init_weights(self, scheme=''):
named_apply(partial(_init_conv, scheme=scheme), self)

@@ -1659,8 +1659,8 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs):


@register_model
-def maxvit_tiny_cm_256(pretrained=False, **kwargs):
-return _create_maxxvit('maxvit_tiny_cm_256', pretrained=pretrained, **kwargs)
+def maxvit_tiny_pm_256(pretrained=False, **kwargs):
+return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)


@register_model
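Usage sketch (assumes a timm build that includes this commit): the renamed 'PM' model is created like any other registered model, and configs that gained weight URLs above can now load pretrained weights.

import timm

model = timm.create_model('maxvit_tiny_pm_256')  # renamed from maxvit_tiny_cm_256
coatnet = timm.create_model('coatnet_nano_rw_224', pretrained=True)  # downloads coatnet_nano_rw_224_sw-f53093b4.pth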
