diff --git a/convert.lua b/convert.lua index 638928b..9371e27 100644 --- a/convert.lua +++ b/convert.lua @@ -18,23 +18,10 @@ local layer_list = { 'VolumetricAveragePooling', } --- similar to nn.Module.apply --- goes over a net and recursively replaces modules --- using callback function -local function replace(self, callback) - local out = callback(self) - if self.modules then - for i, module in ipairs(self.modules) do - self.modules[i] = replace(module, callback) - end - end - return out -end - -- goes over a given net and converts all layers to dst backend -- for example: net = cudnn.convert(net, cudnn) function cudnn.convert(net, dst) - return replace(net, function(x) + return net:replace(function(x) local y = 0 local src = dst == nn and cudnn or nn local src_prefix = src == nn and 'nn.' or 'cudnn.'