forked from torch/nn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Container.lua
149 lines (133 loc) · 4.54 KB
/
Container.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
-- This is code common to container modules, which are collections of
-- smaller constituent modules like Parallel, Sequential, etc.
local Container, parent = torch.class('nn.Container', 'nn.Module')
function Container:__init(...)
parent.__init(self, ...)
self.modules = {}
end
function Container:add(module)
table.insert(self.modules, module)
return self
end
function Container:get(index)
return self.modules[index]
end
function Container:size()
return #self.modules
end
-- Check if passing arguments through xpcall is supported in this Lua interpreter.
local _, XPCALL_ARGS = xpcall(function(x) return x ~= nil end, function() end, 1)
local TRACEBACK_WARNING = "WARNING: If you see a stack trace below, it doesn't point to the place where this error occurred. Please use only the one above."
-- module argument can be retrieved with moduleIndex, but code is cleaner when
-- it has to be specified anyway.
function Container:rethrowErrors(module, moduleIndex, funcName, ...)
assert(module == self.modules[moduleIndex],
"mismatch between moduleIndex and self.modules in rethrowErrors")
local function handleError(err)
-- This will be executed only in the first container that handles the error.
if not err:find(TRACEBACK_WARNING) then
local traceback = debug.traceback()
-- Remove this handler from the stack
local _, first_line_end = traceback:find('^.-\n')
local _, second_line_end = traceback:find('^.-\n.-\n')
traceback = traceback:sub(1, first_line_end) .. traceback:sub(second_line_end+1)
err = err .. '\n' .. traceback .. '\n\n' .. TRACEBACK_WARNING
else
-- Remove file path
err = err:sub(err:find('\n')+1)
end
local msg = string.format('In %d module of %s:',
moduleIndex, torch.type(self))
-- Preceding newline has to be here, because Lua will prepend a file path.
err = '\n' .. msg .. '\n' .. err
return err
end
-- Lua 5.1 doesn't support passing arguments through xpcall, so they have to
-- be passed via a closure. This incurs some overhead, so it's better not to
-- make it the default.
local ok, ret, noret
if not XPCALL_ARGS then
local args = {...}
local unpack = unpack or table.unpack
ok, ret, noret = xpcall(function()
return module[funcName](module, unpack(args))
end,
handleError)
else
ok, ret, noret = xpcall(module[funcName], handleError, module, ...)
end
assert(noret == nil, "rethrowErrors supports only one return argument")
if not ok then error(ret) end
return ret
end
function Container:applyToModules(func)
for _, module in ipairs(self.modules) do
func(module)
end
end
function Container:zeroGradParameters()
self:applyToModules(function(module) module:zeroGradParameters() end)
end
function Container:updateParameters(learningRate)
self:applyToModules(function(module) module:updateParameters(learningRate) end)
end
function Container:training()
self:applyToModules(function(module) module:training() end)
parent.training(self)
end
function Container:evaluate()
self:applyToModules(function(module) module:evaluate() end)
parent.evaluate(self)
end
function Container:share(mlp, ...)
for i=1,#self.modules do
self.modules[i]:share(mlp.modules[i], ...);
end
return self
end
function Container:reset(stdv)
self:applyToModules(function(module) module:reset(stdv) end)
end
function Container:parameters()
local function tinsert(to, from)
if torch.type(from) == 'table' then
for i=1,#from do
tinsert(to,from[i])
end
else
table.insert(to,from)
end
end
local w = {}
local gw = {}
for i=1,#self.modules do
local mw,mgw = self.modules[i]:parameters()
if mw then
tinsert(w,mw)
tinsert(gw,mgw)
end
end
return w,gw
end
function Container:clearState()
-- don't call set because it might reset referenced tensors
local function clear(f)
if self[f] then
if torch.isTensor(self[f]) then
self[f] = self[f].new()
elseif torch.type(self[f]) == 'table' then
self[f] = {}
else
self[f] = nil
end
end
end
clear('output')
clear('gradInput')
if self.modules then
for i,module in pairs(self.modules) do
module:clearState()
end
end
return self
end