diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index 71eaa69..e00b648 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -23,6 +23,7 @@ function SpatialConvolution:__init(nInputPlane, nOutputPlane, 'nOutputPlane should be divisible by nGroups') self.weight = torch.Tensor(nOutputPlane, nInputPlane/self.groups, kH, kW) self.gradWeight = torch.Tensor(nOutputPlane, nInputPlane/self.groups, kH, kW) + self.iSize = torch.LongStorage(4):fill(0) self:reset() -- should nil for serialization, the reset will still work self.reset = nil @@ -99,11 +100,10 @@ function SpatialConvolution:createIODescriptors(input) batch = false end assert(input:dim() == 4 and input:isContiguous()); - self.iSize = self.iSize or torch.LongStorage(4):fill(0) if not self.iDesc or not self.oDesc or input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2] or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then - self.iSize = input:size() + self.iSize:copy(input:size()) assert(self.nInputPlane == input:size(2), 'input has to contain: ' .. self.nInputPlane