Skip to content

Commit

Permalink
Initialize iSize in convert.lua
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanasdf committed Apr 8, 2016
1 parent 9803eb5 commit 6c51cae
Showing 1 changed file with 32 additions and 2 deletions.
34 changes: 32 additions & 2 deletions convert.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@ local layer_list = {
'VolumetricAveragePooling',
}

local layer_isize = {
{'SpatialConvolution', 4},
{'SpatialCrossMapLRN', 4},
{'SpatialFullConvolution', 4},
{'SpatialMaxPooling', 4},
{'SpatialAveragePooling', 4},
{'SpatialDivisiveNormalization', 4},
{'SoftMax', 4},
{'LogSoftMax', 4},
{'TemporalConvolution', 4},
{'VolumetricConvolution', 5},
{'VolumetricMaxPooling', 5},
{'VolumetricAveragePooling', 5},
}

-- similar to nn.Module.apply
-- goes over a net and recursively replaces modules
-- using callback function
Expand Down Expand Up @@ -56,12 +71,27 @@ function cudnn.convert(net, dst)
y = convert('SpatialCrossMapLRN')
else
for i,v in ipairs(layer_list) do
if torch.typename(x) == src_prefix..v then
if t == src_prefix..v then
y = convert(v)
end
end
end
return y == 0 and x or y
if y == 0 then y = x end

-- hacky code to initialize iSize
if src == nn then
t = torch.typename(y)
if t == 'cudnn.SpatialBatchNormalization' or t == 'cudnn.VolumetricBatchNormalization' then
y.iSize = torch.LongStorage(y.nDim):fill(0)
else
for i,v in ipairs(layer_isize) do
if t == dst_prefix..v[1] then
y.iSize = torch.LongStorage(v[2]):fill(0)
end
end
end
end
return y
end)
end

0 comments on commit 6c51cae

Please sign in to comment.