Commit

add stochastic depth to resnet

bonlime committed Apr 8, 2020
1 parent 0d923f0 commit d8d7b4b
Showing 6 changed files with 60 additions and 8 deletions.
14 changes: 14 additions & 0 deletions pytorch_tools/models/resnet.py
@@ -70,6 +70,9 @@ class ResNet(nn.Module):
Flag to overwrite forward pass to return 5 tensors with different resolutions. Defaults to False.
drop_rate (float):
Dropout probability before classifier, for training. Defaults to 0.0.
+ drop_connect_rate (float):
+ Drop rate for StochasticDepth. Randomly zeroes the residual branch of each block for a subset of samples; used as regularization during training.
+ Keep prob is linearly decreased from 1 to 1 - drop_connect_rate over the blocks. Ref: https://arxiv.org/abs/1603.09382
global_pool (str):
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'. Defaults to 'avg'.
init_bn0 (bool):
@@ -95,6 +98,7 @@ def __init__(
antialias=False,
encoder=False,
drop_rate=0.0,
+ drop_connect_rate=0.0,
global_pool="avg",
init_bn0=True,
):
@@ -108,6 +112,9 @@ def __init__(
self.block = block
self.expansion = block.expansion
self.norm_act = norm_act
+ self.block_idx = 0
+ self.num_blocks = sum(layers)
+ self.drop_connect_rate = drop_connect_rate
super(ResNet, self).__init__()

if deep_stem:
@@ -185,6 +192,7 @@ def _make_layer(
norm_layer=norm_layer,
norm_act=norm_act,
antialias=antialias,
+ keep_prob=self.keep_prob,
)
]

@@ -201,6 +209,7 @@ def _make_layer(
norm_layer=norm_layer,
norm_act=norm_act,
antialias=antialias,
+ keep_prob=self.keep_prob,
)
)
return nn.Sequential(*layers)
@@ -266,6 +275,11 @@ def load_state_dict(self, state_dict, **kwargs):
state_dict[k.replace("layer0.", "")] = state_dict.pop(k)
super().load_state_dict(state_dict, **kwargs)

+ @property
+ def keep_prob(self):
+ keep_prob = 1 - self.drop_connect_rate * self.block_idx / self.num_blocks
+ self.block_idx += 1
+ return keep_prob

# fmt: off
CFGS = {
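The keep_prob property above implements the linear decay described in the docstring: each block built by _make_layer reads it once, and every read advances block_idx, so deeper blocks get a lower survival probability. A standalone sketch of the resulting schedule, with assumed values (drop_connect_rate=0.2 and a ResNet50-style layers=[3, 4, 6, 3], neither taken from this diff):

# Illustration only: the linear stochastic-depth schedule computed by keep_prob.
# drop_connect_rate and layers are assumed example values.
drop_connect_rate = 0.2
layers = [3, 4, 6, 3]      # blocks per stage, ResNet50-style
num_blocks = sum(layers)   # 16

for block_idx in range(num_blocks):
    keep_prob = 1 - drop_connect_rate * block_idx / num_blocks
    print(f"block {block_idx:2d}: keep_prob = {keep_prob:.4f}")
# block  0: keep_prob = 1.0000  (never dropped; the block gets nn.Identity)
# block 15: keep_prob = 0.8125  (the deepest block is dropped most often)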
6 changes: 6 additions & 0 deletions pytorch_tools/models/tresnet.py
@@ -51,6 +51,8 @@ class TResNet(ResNet):
Flag to overwrite forward pass to return 5 tensors with different resolutions. Defaults to False.
drop_rate (float):
Dropout probability before classifier, for training. Defaults to 0.0.
+ drop_connect_rate (float):
+ Drop rate for StochasticDepth. Randomly zeroes the residual branch of each block for a subset of samples; used as regularization during training. Ref: https://arxiv.org/abs/1603.09382
"""

def __init__(
@@ -65,6 +67,7 @@ def __init__(
norm_act="leaky_relu",
encoder=False,
drop_rate=0.0,
+ drop_connect_rate=0.0,
):
nn.Module.__init__(self)
stem_width = int(64 * width_factor)
@@ -74,6 +77,9 @@ def __init__(
self.groups = 1 # not really used but needed inside _make_layer
self.base_width = 64 # used inside _make_layer
self.norm_act = norm_act
+ self.block_idx = 0
+ self.num_blocks = sum(layers)
+ self.drop_connect_rate = drop_connect_rate

# in the paper they use conv1x1 but in code conv3x3 (which seems better)
self.conv1 = nn.Sequential(SpaceToDepth(), conv3x3(in_channels * 16, stem_width))
16 changes: 10 additions & 6 deletions pytorch_tools/modules/residual.py
@@ -151,6 +151,7 @@ def __init__(
norm_layer=ABN,
norm_act="relu",
antialias=False,
+ keep_prob=1,
):
super(BasicBlock, self).__init__()
antialias = antialias and stride == 2
@@ -167,6 +168,7 @@ def __init__(
self.downsample = downsample
self.blurpool = BlurPool(channels=planes) if antialias else nn.Identity()
self.antialias = antialias
+ self.drop_connect = DropConnect(keep_prob) if keep_prob < 1 else nn.Identity()

def forward(self, x):
residual = x
@@ -180,11 +182,11 @@ def forward(self, x):
if self.antialias:
out = self.blurpool(out)
out = self.conv2(out)
- # avoid 2 inplace ops by chaining into one long op. Neede for inplaceabn
+ # avoid 2 inplace ops by chaining into one long op. Needed for inplaceabn
if self.se_module is not None:
- out = self.se_module(self.bn2(out)) + residual
+ out = self.drop_connect(self.se_module(self.bn2(out))) + residual
else:
- out = self.bn2(out) + residual
+ out = self.drop_connect(self.bn2(out)) + residual
return self.final_act(out)


@@ -204,6 +206,7 @@ def __init__(
norm_layer=ABN,
norm_act="relu",
antialias=False,
+ keep_prob=1, # for drop connect
):
super(Bottleneck, self).__init__()
antialias = antialias and stride == 2
@@ -222,6 +225,7 @@ def __init__(
self.downsample = downsample
self.blurpool = BlurPool(channels=width) if antialias else nn.Identity()
self.antialias = antialias
+ self.drop_connect = DropConnect(keep_prob) if keep_prob < 1 else nn.Identity()

def forward(self, x):
residual = x
@@ -241,9 +245,9 @@ def forward(self, x):
out = self.conv3(out)
# avoid 2 inplace ops by chaining into one long op
if self.se_module is not None:
- out = self.se_module(self.bn3(out)) + residual
+ out = self.drop_connect(self.se_module(self.bn3(out))) + residual
else:
- out = self.bn3(out) + residual
+ out = self.drop_connect(self.bn3(out)) + residual
return self.final_act(out)

# TResnet models use slightly modified versions of BasicBlock and Bottleneck
@@ -292,5 +296,5 @@ def forward(self, x):

out = self.conv3(out)
# avoid 2 inplace ops by chaining into one long op
- out = self.bn3(out) + residual
+ out = self.drop_connect(self.bn3(out)) + residual
return self.final_act(out)
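
The DropConnect module used by these blocks lives in pytorch_tools.modules and is not part of this diff. As a rough sketch only, stochastic depth of this kind is typically one Bernoulli draw per sample that zeroes the entire residual branch, with survivors rescaled by keep_prob so the expected activation matches eval behaviour. The class name and constructor signature below follow the usage above; the body is an assumption, not the repository's implementation:

import torch
from torch import nn

class DropConnect(nn.Module):
    # Per-sample stochastic depth on a residual branch (sketch only).
    def __init__(self, keep_prob):
        super().__init__()
        self.keep_prob = keep_prob

    def forward(self, x):
        if not self.training:
            return x  # no-op at eval time
        # one Bernoulli draw per sample, broadcast over all remaining dims
        mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
        mask = (torch.rand(mask_shape, device=x.device) < self.keep_prob).to(x.dtype)
        # rescale survivors so the expected output matches eval behaviour
        return x / self.keep_prob * mask

Note that the blocks guard construction with keep_prob < 1, so the first block of a model (keep_prob == 1) and any model built with drop_connect_rate=0 get a free nn.Identity instead.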
2 changes: 1 addition & 1 deletion pytorch_tools/segmentation_models/unet.py
@@ -38,7 +38,7 @@ def __init__(
self.layer3 = UnetDecoderBlock(in_channels[2], out_channels[2], **bn_params)
self.layer4 = UnetDecoderBlock(in_channels[3], out_channels[3], **bn_params)
self.layer5 = UnetDecoderBlock(in_channels[4], out_channels[4], **bn_params)
- self.dropout = nn.Dropout2d(drop_rate, inplace=True)
+ self.dropout = nn.Dropout2d(drop_rate, inplace=False) # inplace=True raises a backprop error
self.final_conv = conv1x1(out_channels[4], final_channels)

initialize(self)
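The inplace=False change sidesteps an autograd constraint: when the op that produced the dropout's input needs its own output for the backward pass, overwriting that tensor in place invalidates the saved value and backward raises. A minimal reproduction of the failure mode (illustrative, not code from this repo):

import torch
from torch import nn

x = torch.randn(4, requires_grad=True)
y = x.exp()                       # exp's backward reuses its output y
nn.Dropout(0.5, inplace=True)(y)  # overwrites y in place
y.sum().backward()                # RuntimeError: a variable needed for gradient
                                  # computation has been modified by an inplace op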
5 changes: 5 additions & 0 deletions tests/models/test_models.py
@@ -73,6 +73,11 @@ def test_dilation(arch, output_stride):
W, H = INP.shape[-2:]
assert res.shape[-2:] == (W // output_stride, H // output_stride)

+ @pytest.mark.parametrize("arch", TEST_MODEL_NAMES)
+ def test_drop_connect(arch):
+ m = models.__dict__[arch](drop_connect_rate=0.2)
+ _test_forward(m)

NUM_PARAMS = {
"tresnetm": 31389032,
"tresnetl": 55989256,
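For reference, enabling the new option from user code is just a constructor argument. A hypothetical call, using resnet34 (a model name the weight tests below rely on) and assuming constructors are exposed on pytorch_tools.models, as the models.__dict__[arch] lookup in the tests suggests:

from pytorch_tools import models

model = models.resnet34(drop_connect_rate=0.2)
model.train()  # in the sketch above, DropConnect is active only in training mode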
25 changes: 24 additions & 1 deletion tests/models/test_weights.py
@@ -1,8 +1,9 @@
## test that imagenet pretrained weights are valid and can correctly classify the cat and dog images

+ import torch
+ import pytest
import numpy as np
from PIL import Image
- import pytest

from pytorch_tools.utils.preprocessing import get_preprocessing_fn
from pytorch_tools.utils.visualization import tensor_from_rgb_image
@@ -53,3 +54,25 @@ def test_imagenet_pretrain(arch):
im = im.view(1, *im.shape).float()
pred_cls = m(im).argmax()
assert pred_cls == im_cls

+ # test that output mean for fixed input is the same
+ MODEL_NAMES2 = [
+ "resnet34",
+ "se_resnet50",
+ "efficientnet_b0",
+ ]

+ MODEL_MEAN = {
+ "resnet34": 7.6799e-06,
+ "se_resnet50": -2.6095e-06,
+ "efficientnet_b0": 0.0070,
+ }

+ @pytest.mark.parametrize("arch", MODEL_NAMES2)
+ def test_output_mean(arch):
+ m = models.__dict__[arch](pretrained="imagenet")
+ m.eval()
+ inp = torch.ones(1, 3, 256, 256)
+ with torch.no_grad():
+ out = m(inp).mean().numpy()
+ assert np.allclose(out, MODEL_MEAN[arch], rtol=1e-4, atol=1e-4)
