forked from torch/nn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Convert.lua
245 lines (224 loc) · 7.73 KB
/
Convert.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
------------------------------------------------------------------------
--[ nn.Convert ]--
-- Module to convert between different data formats
-- nn.Convert('bchw', 'bf') or nn.Convert('chw', 'f')
-- Automatically converts the input to the same type as self.output
-- The simplest use is automatic input type conversion: nn.Convert()
------------------------------------------------------------------------
local _ = require 'moses'
-- Declare nn.Convert as a subclass of nn.Container. `parent` is the
-- nn.Container class table, used below to call inherited methods
-- (parent.__init, parent.type).
local Convert, parent = torch.class("nn.Convert", "nn.Container")
-- Constructor.
-- inputShape / outputShape are axis-layout strings such as 'bchw' or 'bf'
-- (one character per axis, 'b' marking the batch axis). When 'b' is
-- missing it is prepended, so 'chw' becomes 'bchw'. With no arguments both
-- shapes default to 'b*', which only performs input type conversion.
-- Sets:
--   inputShape / outputShape       : the (possibly 'b'-prefixed) shapes
--   inputBatchDim / outputBatchDim : position of 'b' in each shape
--   nInputDim / nOutputDim         : number of axes in batch mode (-1 for 'b*')
--   transposition                  : true when every input axis character
--                                    appears in an output shape of equal length
function Convert:__init(inputShape, outputShape)
   if outputShape and not inputShape then
      error"Expecting non-nil arg 1 when arg 2 is provided"
   end
   inputShape = inputShape or 'b*'
   outputShape = outputShape or inputShape
   -- plain (non-pattern) searches: shape strings are user supplied
   self.inputShape = inputShape:find('b', 1, true) and inputShape or ('b'..inputShape)
   self.outputShape = outputShape:find('b', 1, true) and outputShape or ('b'..outputShape)
   self.inputBatchDim = self.inputShape:find('b', 1, true)
   self.outputBatchDim = self.outputShape:find('b', 1, true)
   if self.inputShape == 'b*' or self.outputShape == 'b*' then
      assert(self.inputShape == 'b*' and self.outputShape == 'b*', 'Both or neither shapes must be b*')
      self.nInputDim = -1
      self.nOutputDim = -1
      self.transposition = true
   else
      -- number of dims in batch mode
      self.nInputDim = #self.inputShape
      self.nOutputDim = #self.outputShape
      -- is the outputShape just a transposition of the inputShape?
      if self.nInputDim == self.nOutputDim then
         self.transposition = true
         for i=1,self.nInputDim do
            -- plain find so characters such as '%' or '.' are matched
            -- literally instead of being interpreted as pattern syntax
            if not self.outputShape:find(self.inputShape:sub(i,i), 1, true) then
               self.transposition = false
               break
            end
         end
      end
   end
   parent.__init(self)
end
-- Post-initialization: lazily constructs self.converter, the child module
-- that performs the actual layout change, on the first updateOutput call.
-- A pure axis permutation is delegated to Convert:transpose. Any other
-- conversion is looked up as a method named after outputShape (for
-- example Convert:bf). The converter is cast to self.output's tensor type
-- and installed as the container's only child module.
function Convert:buildConverter(input)
   local converter
   if not self.transposition then
      local builder = self[self.outputShape]
      if torch.type(builder) ~= 'function' then
         error(string.format("Unrecognized conversion of shape %s to %s", self.inputShape, self.outputShape))
      end
      converter = builder(self, input)
   else
      converter = self:transpose(self.outputShape)
   end
   self.converter = converter
   assert(torch.isTensor(self.output), "Expecting Tensor output")
   self.converter:type(torch.type(self.output))
   self.modules[1] = self.converter
end
-- Forward pass.
--  1. If input's tensor type differs from self.output's, copy it into the
--     cached buffer self._input (created with self.output.new()).
--  2. If input has fewer dims than nInputDim, view it through self.__input
--     with a singleton batch axis inserted at inputBatchDim, and record
--     non-batch mode in self.batchMode.
--  3. Build the converter on first use and run it. In non-batch mode the
--     singleton batch axis is removed again through the self.__output view.
function Convert:updateOutput(input)
   assert(torch.isTensor(input), "expecting Tensor")
   if not torch.isTypeOf(input, torch.type(self.output)) then
      -- different input type: copy into a buffer of the output's type
      self._input = self._input or self.output.new()
      input = self._input:resize(input:size()):copy(input)
   end

   self.batchMode = true
   if input:dim() < self.nInputDim then
      -- single example: add a size-1 batch axis
      local sizes = input:size():totable()
      table.insert(sizes, self.inputBatchDim, 1)
      self.__input = self.__input or input.new()
      input = self.__input:set(input):resize(table.unpack(sizes))
      self.batchMode = false
   end

   if not self.converter then
      self:buildConverter(input)
   end

   local converted = self.converter:updateOutput(input)
   self.output = converted
   if not self.batchMode then
      -- strip the size-1 batch axis that was added above
      local sizes = converted:size():totable()
      table.remove(sizes, self.outputBatchDim)
      self.__output = self.__output or converted.new()
      self.output = self.__output:set(converted):resize(table.unpack(sizes))
   end
   return self.output
end
-- Backward pass: computes self.gradInput for `input` given `gradOutput`.
-- Mirrors updateOutput:
--  * the converter is run on the same tensor it saw in the forward pass
--    (the type-converted buffer self._input and/or the batch-dim view
--    self.__input);
--  * in non-batch mode gradOutput is viewed with the batch axis restored,
--    and the resulting gradient is viewed back to input's own size;
--  * when updateOutput had to convert input's tensor type, the gradient is
--    copied into self._gradInput, a tensor of input's own type and size.
function Convert:updateGradInput(input, gradOutput)
   local originalInput = input
   -- Same test updateOutput uses to decide whether input was type-converted.
   -- Recomputed here so a stale self._input left by an earlier call with a
   -- different input type is never reused.
   local converted = not torch.isTypeOf(originalInput, torch.type(self.output))
   if converted then
      input = self._input
   end
   if not self.batchMode then
      input = self.__input
      self.__gradOutput = self.__gradOutput or gradOutput.new()
      self.__gradOutput:set(gradOutput):resize(self.converter.output:size())
      gradOutput = self.__gradOutput
   end
   local gradInput = self.converter:updateGradInput(input, gradOutput)
   if not self.batchMode then
      self.__gradInput = self.__gradInput or gradInput.new()
      self.__gradInput:set(gradInput):resize(originalInput:size())
      gradInput = self.__gradInput
   end
   if converted then
      -- Allocate from the caller's input (not the output-typed buffer) so the
      -- gradient really is converted back. Reallocate if the input type changed.
      if torch.type(self._gradInput) ~= torch.type(originalInput) then
         self._gradInput = originalInput.new()
      end
      -- use the caller's size: in non-batch mode the buffers carry an extra axis
      self._gradInput:resize(originalInput:size()):copy(gradInput)
      self.gradInput = self._gradInput
   else
      self.gradInput = gradInput
   end
   return self.gradInput
end
-- Maps the caller's (input, gradOutput) to the tensors the converter saw in
-- the forward pass:
--  * the type-converted buffer self._input when input's type differs from
--    self.output's;
--  * in non-batch mode, the batch-dim view self.__input and a view of
--    gradOutput reshaped to the converter's output size.
-- Shared by accGradParameters and accUpdateGradParameters.
local function converterArgs(self, input, gradOutput)
   if not torch.isTypeOf(input, torch.type(self.output)) then
      input = self._input
   end
   if not self.batchMode then
      input = self.__input
      self.__gradOutput = self.__gradOutput or gradOutput.new()
      self.__gradOutput:set(gradOutput):resize(self.converter.output:size())
      gradOutput = self.__gradOutput
   end
   return input, gradOutput
end
-- Accumulates parameter gradients of the converter (scaled by `scale`).
function Convert:accGradParameters(input, gradOutput, scale)
   input, gradOutput = converterArgs(self, input, gradOutput)
   self.converter:accGradParameters(input, gradOutput, scale)
end
-- Forwards accUpdateGradParameters (learning rate `lr`) to the converter.
function Convert:accUpdateGradParameters(input, gradOutput, lr)
   input, gradOutput = converterArgs(self, input, gradOutput)
   self.converter:accUpdateGradParameters(input, gradOutput, lr)
end
-- Builds a converter to the 'bf' (batch x feature) layout.
--   'b'      -> nn.Reshape(1): each scalar becomes a one-feature row.
--   2 axes   -> a transpose moving 'b' to axis 1 when needed, otherwise
--               nn.Identity.
--   >2 axes  -> the optional transpose followed by nn.Reshape, flattening
--               all non-batch axes. The feature count is taken from the
--               first example of `input`.
-- An input shape of 'bt' is rejected.
function Convert:bf(input)
   local bPos = self:findAxis('b', self.inputShape)
   local nDim = #self.inputShape
   if self.inputShape == 'bt' then
      error"Conversion of shape bt to bf not supported: open an issue on github"
   end
   if nDim == 1 then
      return nn.Reshape(1)
   end

   local transpose
   if bPos ~= 1 then
      transpose = nn.Transpose({1, bPos})
   end
   if nDim <= 2 then
      return transpose or nn.Identity()
   end

   local exampleSize = input:select(self:findAxis('b'), 1):nElement()
   local reshape = nn.Reshape(exampleSize)
   if not transpose then
      return reshape
   end
   local pipeline = nn.Sequential()
   pipeline:add(transpose)
   pipeline:add(reshape)
   return pipeline
end
-- Builds a converter to the 'b' layout (each example is a scalar, so the
-- batch is a vector).
--   'bt' / 'tb' -> nn.Select keeping only index 1 along the 't' axis.
--   'bf' / 'fb' -> nn.Select dropping the 'f' axis. The 'f' axis of
--                  `input` must have size 1, otherwise an error is raised.
-- Any other input shape raises an error.
function Convert:b(input)
   local shape = self.inputShape
   if shape == 'bt' or shape == 'tb' then
      -- select first set of classes
      return nn.Select(self:findAxis('t'), 1)
   elseif shape == 'bf' or shape == 'fb' then
      local f_pos = self:findAxis('f')
      if input:size(f_pos) > 1 then
         error("Cannot convert shape "..self.inputShape.." to b when feature > 1")
      end
      return nn.Select(f_pos, 1)
   end
   error("Cannot convert shape "..self.inputShape.." to shape b")
end
-- Builds a converter that leaves the data unchanged: a new nn.Identity.
-- NOTE(review): __init always puts a 'b' in outputShape, so the
-- self[outputShape] lookup in buildConverter cannot select this method.
-- Confirm whether any external code still calls it.
function Convert:default()
   local passthrough = nn.Identity()
   return passthrough
end
-- Builds a converter to the 'bt' (batch x target, multi-class) layout.
-- Only a 'b' input is supported: nn.Reshape(1) turns each scalar into a
-- single-target row. Any other input shape raises an error.
function Convert:bt()
   if self.inputShape ~= 'b' then
      error("cannot convert shape '"..self.inputShape.."' to bt")
   end
   return nn.Reshape(1)
end
-- Generic axis transposition: builds a converter that permutes the axes of
-- self.inputShape into the order given by newShape.
-- newShape is walked left to right. Whenever the wanted axis is not already
-- at position `dst`, it is swapped into place and the swap is recorded. The
-- recorded swaps are passed to nn.Transpose in the order they were made.
-- Returns nn.Identity when newShape equals inputShape.
function Convert:transpose(newShape)
   if newShape == self.inputShape then
      return nn.Identity()
   end
   -- current axis order, one character per entry
   local axes = {}
   for c in self.inputShape:gmatch('.') do
      axes[#axes + 1] = c
   end
   local swaps = {}
   for dst = 1, #newShape do
      local src = _.indexOf(axes, newShape:sub(dst, dst))
      if src ~= dst then
         axes[dst], axes[src] = axes[src], axes[dst]
         swaps[#swaps + 1] = {src, dst}
      end
   end
   return nn.Transpose(table.unpack(swaps))
end
-- Returns the 1-based position of the axis character `axis_char` in
-- `shape` (defaults to self.inputShape).
-- When the axis is absent, raises an error attributed to the caller, unless
-- `silent` is truthy, in which case nil is returned.
function Convert:findAxis(axis_char, shape, silent)
   shape = shape or self.inputShape
   -- plain search: treat axis_char literally, not as a Lua pattern
   local axis_pos = shape:find(axis_char, 1, true)
   if (not silent) and (not axis_pos) then
      error("Provided shape '"..shape.."' has no axis '"..axis_char.."'", 2)
   end
   return axis_pos
end
-- Releases the cached conversion buffers (type-conversion and batch-dim
-- views). It then delegates to nn.Container's clearState, which, per nn's
-- Container implementation, resets self.output / self.gradInput and calls
-- clearState on the child converter.
-- Returns self (as other nn modules do), so `torch.save(f, m:clearState())`
-- works.
function Convert:clearState()
   self._input = nil
   self._gradInput = nil
   self.__input = nil
   self.__output = nil
   self.__gradInput = nil
   self.__gradOutput = nil
   return parent.clearState(self)
end
-- Casts the module to tensor type `type`.
-- Cached buffers are dropped first through clearState. updateOutput and
-- updateGradInput recreate them on demand.
function Convert:type(type)
   local castParent = parent.type
   self:clearState()
   return castParent(self, type)
end