Skip to content

Commit

Permalink
replace Upsample layer with interpolate function (pytorch#424)
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored and soumith committed Oct 17, 2018
1 parent 323079f commit 502e45d
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions fast_neural_style/neural_style/transformer_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,14 @@ class UpsampleConvLayer(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
super(UpsampleConvLayer, self).__init__()
self.upsample = upsample
if upsample:
self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample)
reflection_padding = kernel_size // 2
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

def forward(self, x):
x_in = x
if self.upsample:
x_in = self.upsample_layer(x_in)
x_in = torch.nn.functional.interpolate(x_in, mode='nearest', scale_factor=self.upsample)
out = self.reflection_pad(x_in)
out = self.conv2d(out)
return out

0 comments on commit 502e45d

Please sign in to comment.