forked from torch/cunn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SpatialAveragePooling.cu
329 lines (273 loc) · 10.7 KB
/
SpatialAveragePooling.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
#include "utils.h"
#define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit
/*
* Description:
* this function avg-pools an input 3D tensor along dimensions 1 and 2
* 3D input, 3D output
*/
/*
 * Average-pools one plane per blockIdx.x.
 *
 * Layout expected by the indexing below:
 *   - input  holds planes of input_h x input_w floats packed back to back
 *   - output holds planes of output_h x output_w floats packed back to back
 *   - blockIdx.x selects the plane; threads stride over output columns with
 *     blockDim.x and over output rows with blockDim.y*gridDim.y, so any
 *     block/grid shape in y covers the whole plane.
 *
 * Parameters:
 *   input_n            unused by the kernel (kept for interface compatibility)
 *   input_h, input_w   size of one input plane
 *   kH, kW             pooling window size
 *   dH, dW             pooling stride
 *
 * No shared memory is used.
 */
__global__ void subsample(float *input, float *output,
                          int input_n, int input_h, int input_w,
                          int kH, int kW, int dH, int dW)
{
  // output size (no padding)
  const int output_w = (input_w - kW) / dW + 1;
  const int output_h = (input_h - kH) / dH + 1;

  // Select input/output plane. The plane index is widened to 64 bits before
  // multiplying: plane * width * height exceeds INT_MAX for large batches and
  // would otherwise wrap and address the wrong memory.
  const long long plane = blockIdx.x;
  output += plane * output_w * output_h;
  input  += plane * input_w * input_h;

  // per-thread strides over the output plane
  const int col_stride = blockDim.x;
  const int row_stride = blockDim.y * gridDim.y;

  // the divisor is the same for every window
  const float window_area = float(kW*kH);

  for (int yy = blockDim.y*blockIdx.y + threadIdx.y; yy < output_h; yy += row_stride) {
    for (int xx = threadIdx.x; xx < output_w; xx += col_stride) {
      // top-left corner of the pooling window for output pixel (yy, xx)
      const float *window = input + yy*dH*input_w + xx*dW;
      float sum = 0;
      for (int ky = 0; ky < kH; ky++) {
        for (int kx = 0; kx < kW; kx++)
          sum += window[kx];
        window += input_w; // next input line
      }
      output[yy*output_w + xx] = sum / window_area;
    }
  }
}
/*
 * Lua entry point: forward pass of SpatialAveragePooling.
 *
 * Lua stack:
 *   1  the module table; fields kW, kH, dW, dH (ints) and output
 *      (torch.CudaTensor) are read from it
 *   2  input, a 3D (plane x rows x cols) or 4D (batch x plane x rows x cols)
 *      torch.CudaTensor
 *
 * Resizes the module's output tensor and fills it by launching the subsample
 * kernel on the current THC stream. Raises a Lua error on invalid arguments
 * or when the launch reports a CUDA error.
 */
static int cunn_SpatialAveragePooling_updateOutput(lua_State *L)
{
  THCState *state = getCutorchState(L);
  THCudaTensor *input = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
  int kW = luaT_getfieldcheckint(L, 1, "kW");
  int kH = luaT_getfieldcheckint(L, 1, "kH");
  int dW = luaT_getfieldcheckint(L, 1, "dW");
  int dH = luaT_getfieldcheckint(L, 1, "dH");
  THCudaTensor *output = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
  THAssert(THCudaTensor_checkGPU(state, 2, input, output));

  luaL_argcheck(L, input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch) tensor expected");
  // The output size below divides by dW/dH; a zero stride would be an
  // integer division by zero on the host, so reject it as a Lua error.
  luaL_argcheck(L, kW > 0 && kH > 0 && dW > 0 && dH > 0, 1, "kernel size and stride must be positive");

  // A 3D input is handled as a batch of one; only the dimension offset differs.
  const bool batched = (input->nDimension == 4);
  const int first = batched ? 1 : 0;
  long nbatch = batched ? input->size[0] : 1;
  long nInputPlane = input->size[first];
  long nInputRows = input->size[first + 1];
  long nInputCols = input->size[first + 2];

  luaL_argcheck(L, nInputCols >= kW && nInputRows >= kH, 2, "input image smaller than kernel size");

  long nOutputCols = (nInputCols - kW) / dW + 1;
  long nOutputRows = (nInputRows - kH) / dH + 1;

  // The kernel indexes raw memory plane by plane, so it needs a contiguous
  // input. All argument checks are done before this point so that a Lua
  // error cannot skip the matching THCudaTensor_free below.
  input = THCudaTensor_newContiguous(state, input);
  float *input_data = THCudaTensor_data(state, input);

  if (batched) {
    THCudaTensor_resize4d(state, output, nbatch, nInputPlane, nOutputRows, nOutputCols);
  } else {
    THCudaTensor_resize3d(state, output, nInputPlane, nOutputRows, nOutputCols);
  }
  float *output_data = THCudaTensor_data(state, output);

  // cuda blocks & threads: one block column per plane, extra blocks in y
  // only when there are few planes
  int yblocks = (int)(16L / nInputPlane);
  yblocks = yblocks < 1 ? 1 : yblocks;
  dim3 blocks(nInputPlane*nbatch,yblocks);
  dim3 threads(32,8);

  // run subsample kernel
  subsample <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (input_data, output_data,
                                                                         nInputPlane, nInputRows, nInputCols, kH, kW, dH, dW);

  // clean
  THCudaTensor_free(state, input);

  // check for errors
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in SpatialAveragePooling.updateOutput: %s\n", cudaGetErrorString(err));
    THError("aborting");
  }
  return 1;
}
/*
* Description:
* this function computes the gradInput from gradOutput
*/
/*
 * Scatters each gradOutput value, divided by the window area, into the
 * kH x kW window of gradInput it was pooled from, using plain (non-atomic)
 * read-modify-write.
 *
 * This is only race-free when no two windows touch the same gradInput
 * element; the host wrapper in this file launches it when kH == dH and
 * kW == dW. gradInput is accumulated into, not overwritten, so the caller
 * must have initialised it.
 *
 * Layout: blockIdx.x selects the plane; threads stride over output columns
 * with blockDim.x and over output rows with blockDim.y*gridDim.y.
 * input_n is unused (kept for interface compatibility).
 */
__global__ void subgradinput(float *gradInput, float *gradOutput,
                             int input_n, int input_h, int input_w,
                             int kH, int kW, int dH, int dW)
{
  // output size (no padding)
  const int output_w = (input_w - kW) / dW + 1;
  const int output_h = (input_h - kH) / dH + 1;

  // Select input/output plane. The plane index is widened to 64 bits before
  // multiplying so the offset cannot wrap for large batches.
  const long long plane = blockIdx.x;
  gradOutput += plane * output_w * output_h;
  gradInput  += plane * input_w * input_h;

  // per-thread strides over the output plane
  const int col_stride = blockDim.x;
  const int row_stride = blockDim.y * gridDim.y;

  const float window_area = float(kW*kH);

  // compute gradInput
  for (int yy = blockDim.y*blockIdx.y + threadIdx.y; yy < output_h; yy += row_stride) {
    for (int xx = threadIdx.x; xx < output_w; xx += col_stride) {
      float *window = gradInput + yy*dH*input_w + xx*dW;
      // every element of the window receives the same share
      const float share = gradOutput[yy*output_w + xx] / window_area;
      for (int ky = 0; ky < kH; ky++) {
        for (int kx = 0; kx < kW; kx++)
          window[kx] += share;
        window += input_w; // next gradInput line
      }
    }
  }
}
/*
* Description:
* this function computes the gradInput from gradOutput
* using atomic accumulation. This is the path taken whenever
* kH != dH or kW != dW, where pooling windows may overlap and
* several threads may add into the same gradInput element.
*/
/*
 * Same scatter as subgradinput, but every accumulation into gradInput goes
 * through atomicAdd, so it stays correct when several pooling windows cover
 * the same gradInput element. gradInput is accumulated into, not overwritten,
 * so the caller must have initialised it.
 *
 * Layout: blockIdx.x selects the plane; threads stride over output columns
 * with blockDim.x and over output rows with blockDim.y*gridDim.y.
 * input_n is unused (kept for interface compatibility).
 */
__global__ void subgradinputAtomic(float *gradInput, float *gradOutput,
                                   int input_n, int input_h, int input_w,
                                   int kH, int kW, int dH, int dW)
{
  // output size (no padding)
  const int output_w = (input_w - kW) / dW + 1;
  const int output_h = (input_h - kH) / dH + 1;

  // Select input/output plane. The plane index is widened to 64 bits before
  // multiplying so the offset cannot wrap for large batches.
  const long long plane = blockIdx.x;
  gradOutput += plane * output_w * output_h;
  gradInput  += plane * input_w * input_h;

  // per-thread strides over the output plane
  const int col_stride = blockDim.x;
  const int row_stride = blockDim.y * gridDim.y;

  const float window_area = float(kW*kH);

  // compute gradInput
  for (int yy = blockDim.y*blockIdx.y + threadIdx.y; yy < output_h; yy += row_stride) {
    for (int xx = threadIdx.x; xx < output_w; xx += col_stride) {
      float *window = gradInput + yy*dH*input_w + xx*dW;
      // every element of the window receives the same share
      const float share = gradOutput[yy*output_w + xx] / window_area;
      for (int ky = 0; ky < kH; ky++) {
        for (int kx = 0; kx < kW; kx++) {
          atomicAdd(&(window[kx]), share);
        }
        window += input_w; // next gradInput line
      }
    }
  }
}
/*
 * Lua entry point: backward pass of SpatialAveragePooling.
 *
 * Lua stack:
 *   1  the module table; fields kW, kH, dW, dH (ints) and gradInput
 *      (torch.CudaTensor) are read from it
 *   2  input, a 3D or 4D (batch) torch.CudaTensor; only its sizes are used
 *   3  gradOutput, a torch.CudaTensor
 *
 * Resizes gradInput like input, zeroes it, and accumulates the gradient with
 * subgradinput (kH == dH and kW == dW) or subgradinputAtomic (otherwise) on
 * the current THC stream. Raises a Lua error on invalid arguments or when
 * the launch reports a CUDA error.
 */
static int cunn_SpatialAveragePooling_updateGradInput(lua_State *L)
{
  THCState *state = getCutorchState(L);
  THCudaTensor *input = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
  THCudaTensor *gradOutput = (THCudaTensor *)luaT_checkudata(L, 3, "torch.CudaTensor");
  int kW = luaT_getfieldcheckint(L, 1, "kW");
  int kH = luaT_getfieldcheckint(L, 1, "kH");
  int dW = luaT_getfieldcheckint(L, 1, "dW");
  int dH = luaT_getfieldcheckint(L, 1, "dH");
  THCudaTensor *gradInput = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
  THAssert(THCudaTensor_checkGPU(state, 3, input, gradInput, gradOutput));

  // Same rank check as updateOutput; without it any non-3D input would be
  // treated as 4D and size[3] read out of range.
  luaL_argcheck(L, input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch) tensor expected");
  // The kernels divide by dW/dH and by kW*kH.
  luaL_argcheck(L, kW > 0 && kH > 0 && dW > 0 && dH > 0, 1, "kernel size and stride must be positive");

  // A 3D input is handled as a batch of one; only the dimension offset differs.
  const bool batched = (input->nDimension == 4);
  const int first = batched ? 1 : 0;
  long nbatch = batched ? input->size[0] : 1;
  long nInputPlane = input->size[first];
  long nInputRows = input->size[first + 1];
  long nInputCols = input->size[first + 2];

  // The kernels index gradOutput as densely packed planes, so a strided
  // gradOutput must be made contiguous first. All argument checks are done
  // before this point so that a Lua error cannot skip the matching free.
  gradOutput = THCudaTensor_newContiguous(state, gradOutput);
  float *gradOutput_data = THCudaTensor_data(state, gradOutput);

  THCudaTensor_resizeAs(state, gradInput, input);
  THCudaTensor_zero(state, gradInput);
  float *gradInput_data = THCudaTensor_data(state, gradInput);

  // cuda blocks & threads: one block column per plane, extra blocks in y
  // only when there are few planes
  int yblocks = (int)(16L / nInputPlane);
  yblocks = yblocks < 1 ? 1 : yblocks;
  dim3 blocks(nInputPlane*nbatch,yblocks);
  dim3 threads(32,8);

  // run updateGradInput kernel
  if (kH == dH && kW == dW) {
    // windows tile the input without overlap: plain accumulation is enough
    subgradinput <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (gradInput_data, gradOutput_data,
                                                                              nInputPlane, nInputRows, nInputCols,
                                                                              kH, kW, dH, dW);
  } else {
    subgradinputAtomic <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (gradInput_data, gradOutput_data,
                                                                                    nInputPlane, nInputRows, nInputCols,
                                                                                    kH, kW, dH, dW);
  }

  // clean
  THCudaTensor_free(state, gradOutput);

  // check for errors
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in SpatialAveragePooling.updateGradInput: %s\n", cudaGetErrorString(err));
    THError("aborting");
  }
  return 1;
}
// Name -> C function table handed to luaT_registeratname() by
// cunn_SpatialAveragePooling_init(). The {NULL, NULL} entry terminates the list.
static const struct luaL_Reg cunn_SpatialAveragePooling__ [] = {
{"SpatialAveragePooling_updateOutput", cunn_SpatialAveragePooling_updateOutput},
{"SpatialAveragePooling_updateGradInput", cunn_SpatialAveragePooling_updateGradInput},
{NULL, NULL}
};
/*
 * Registers the SpatialAveragePooling entry points with the Lua state.
 * The functions in cunn_SpatialAveragePooling__ are registered under the
 * name "nn" while the "torch.CudaTensor" metatable is on the stack
 * (presumably into that metatable's "nn" field; confirm against luaT).
 */
void cunn_SpatialAveragePooling_init(lua_State *L)
{
  luaT_pushmetatable(L, "torch.CudaTensor");
  luaT_registeratname(L, cunn_SpatialAveragePooling__, "nn");

  // drop the one stack slot pushed above
  lua_pop(L, 1);
}
#undef CUDA_MAX_THREADS