diff --git a/convert.lua b/convert.lua index a64dfe2..2c5ef1a 100644 --- a/convert.lua +++ b/convert.lua @@ -15,6 +15,21 @@ local layer_list = { 'VolumetricAveragePooling', } +local layer_isize = { + {'SpatialConvolution', 4}, + {'SpatialCrossMapLRN', 4}, + {'SpatialFullConvolution', 4}, + {'SpatialMaxPooling', 4}, + {'SpatialAveragePooling', 4}, + {'SpatialDivisiveNormalization', 4}, + {'SoftMax', 4}, + {'LogSoftMax', 4}, + {'TemporalConvolution', 4}, + {'VolumetricConvolution', 5}, + {'VolumetricMaxPooling', 5}, + {'VolumetricAveragePooling', 5}, +} + -- similar to nn.Module.apply -- goes over a net and recursively replaces modules -- using callback function @@ -56,12 +71,27 @@ function cudnn.convert(net, dst) y = convert('SpatialCrossMapLRN') else for i,v in ipairs(layer_list) do - if torch.typename(x) == src_prefix..v then + if t == src_prefix..v then y = convert(v) end end end - return y == 0 and x or y + if y == 0 then y = x end + + -- hacky code to initialize iSize + if src == nn then + t = torch.typename(y) + if t == 'cudnn.SpatialBatchNormalization' or t == 'cudnn.VolumetricBatchNormalization' then + y.iSize = torch.LongStorage(y.nDim):fill(0) + else + for i,v in ipairs(layer_isize) do + if t == dst_prefix..v[1] then + y.iSize = torch.LongStorage(v[2]):fill(0) + end + end + end + end + return y end) end