forked from torch/cunn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TemporalMaxPooling.cu
240 lines (195 loc) · 8.4 KB
/
TemporalMaxPooling.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
#include "utils.h"
#define TEMPORAL_MAX_POOLING_THREADS 1024
// Forward kernel of temporal max pooling.
//
// Launch layout (set up by cunn_TemporalMaxPooling_updateOutput):
//   blockIdx.x                                               -> batch element
//   threadIdx.x + blockIdx.y * TEMPORAL_MAX_POOLING_THREADS  -> output frame
// Each thread owns one output frame and loops over all input_n features.
// The indexing assumes contiguous [batch][time][feature] storage.
//
//   input   : input_w  x input_n values per batch element
//   output  : output_w x input_n maxima per batch element
//   indices : output_w x input_n, position (0..kW-1) of each maximum inside
//             its window, stored as float
// Window of output frame f starts at input frame f * dW and spans kW frames.
__global__ void cunn_TemporalMaxPooling_updateOutputKernel(float *input, float *output, float *indices, int input_w, int input_n, int output_w, int kW, int dW) {
  const int frame = threadIdx.x + blockIdx.y * TEMPORAL_MAX_POOLING_THREADS;
  if (frame >= output_w) {
    return;  // padding threads of the last warp / last y-block
  }

  // 64-bit offsets so large batch * time * feature products cannot wrap.
  const long long batch = blockIdx.x;
  const float *input_data = input + batch * input_w * input_n
                                  + (long long)frame * dW * input_n;
  float *output_data = output + batch * output_w * input_n + (long long)frame * input_n;
  float *indices_data = indices + batch * output_w * input_n + (long long)frame * input_n;

  for (int feat = 0; feat < input_n; ++feat) {
    // Reset both the running maximum AND its position for every feature.
    // Previously the position was only initialised once, so a window in
    // which nothing beats -FLT_MAX (all -inf, or NaN) inherited the index
    // of the previous feature and backward routed gradient to the wrong frame.
    float max_value = -FLT_MAX;
    int max_index = 0;
    for (int k = 0; k < kW; ++k) {
      const float value = input_data[k * input_n + feat];
      if (max_value < value) {
        max_value = value;
        max_index = k;
      }
    }
    output_data[feat] = max_value;
    indices_data[feat] = (float)max_index;
  }
}
// Backward kernel for non-overlapping windows (the host launches it only
// when kW <= dW). Each input frame then belongs to at most one window, so a
// plain read-modify-write cannot race with another thread.
//
// Launch layout matches the forward kernel: blockIdx.x selects the batch
// element, threadIdx.x + blockIdx.y * TEMPORAL_MAX_POOLING_THREADS selects
// the output frame. For each feature, the gradient of that output frame is
// added to the input position recorded in `indices` during forward.
// kW is part of the shared kernel signature but is not needed here.
__global__ void cunn_TemporalMaxPooling_updateGradInputKernel(float *gradInput, float *gradOutput, float *indices, int input_w, int input_n, int output_w, int kW, int dW) {
  unsigned int frame = threadIdx.x + blockIdx.y * TEMPORAL_MAX_POOLING_THREADS;
  if (!(frame < output_w)) {
    return;
  }

  // First element of this frame's input window in gradInput.
  float *windowGrad = gradInput + blockIdx.x * input_w * input_n + frame * input_n * dW;
  // gradOutput and indices share the same [batch][output frame][feature] layout.
  unsigned int rowOffset = blockIdx.x * output_w * input_n + frame * input_n;
  float *gradRow = gradOutput + rowOffset;
  float *argmaxRow = indices + rowOffset;

  for (int f = 0; f < input_n; ++f) {
    int k = (int)argmaxRow[f];  // offset of the max inside the window
    windowGrad[k * input_n + f] += gradRow[f];
  }
}
// Backward kernel for overlapping windows (the host launches it when
// kW > dW). Neighbouring output frames can select the same input position,
// so accumulation into gradInput uses atomicAdd.
//
// Launch layout matches the forward kernel: blockIdx.x selects the batch
// element, threadIdx.x + blockIdx.y * TEMPORAL_MAX_POOLING_THREADS selects
// the output frame. kW is part of the shared signature but unused here.
__global__ void cunn_TemporalMaxPooling_updateGradInputKernelAtomic(float *gradInput, float *gradOutput, float *indices, int input_w, int input_n, int output_w, int kW, int dW) {
  unsigned int outFrame = blockIdx.y * TEMPORAL_MAX_POOLING_THREADS + threadIdx.x;
  if (outFrame < output_w) {
    // Window start in gradInput, and the row of this frame in gradOutput/indices.
    float *windowGrad = gradInput + blockIdx.x * input_w * input_n + outFrame * input_n * dW;
    float *gradRow = gradOutput + blockIdx.x * output_w * input_n + outFrame * input_n;
    float *argmaxRow = indices + blockIdx.x * output_w * input_n + outFrame * input_n;

    int f = 0;
    while (f < input_n) {
      int offsetInWindow = (int)argmaxRow[f];
      atomicAdd(&windowGrad[offsetInWindow * input_n + f], gradRow[f]);
      ++f;
    }
  }
}
/*
 * Lua binding: TemporalMaxPooling_updateOutput(self, input)
 *
 * Reads self.kW, self.dW, self.output and self.indices. Accepts a 2D
 * (time x feature) or 3D (batch x time x feature) CUDA tensor. Resizes
 * output/indices to (output_w x input_n) or (batch x output_w x input_n),
 * with output_w = (input_w - kW) / dW + 1. Then launches the forward kernel
 * on the current THC stream.
 */
static int cunn_TemporalMaxPooling_updateOutput(lua_State *L)
{
  THCState *state = getCutorchState(L);
  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
  int kW = luaT_getfieldcheckint(L, 1, "kW");
  int dW = luaT_getfieldcheckint(L, 1, "dW");
  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
  THCudaTensor *indices = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "indices", "torch.CudaTensor");

  int dimT = 0; // Temporal dimension
  int dimF = 1; // Feature dimension
  int batch = 1;

  THAssert(THCudaTensor_checkGPU(state, 3, input, output, indices));
  luaL_argcheck(L, input->nDimension == 2 || input->nDimension == 3, 2, "2D or 3D(batch mode) tensor expected");
  // Reject non-positive module parameters: dW is a divisor below (dW == 0
  // would be an integer division by zero) and kW is the window length.
  luaL_argcheck(L, kW > 0, 1, "kernel size kW must be positive");
  luaL_argcheck(L, dW > 0, 1, "stride dW must be positive");

  if (input->nDimension == 3)
  {
    dimT = 1;
    dimF = 2;
    batch = input->size[0];
  }
  luaL_argcheck(L, input->size[dimT] >= kW, 2, "input sequence smaller than kernel size");

  // All argument checks are done before this allocation, so it cannot leak
  // through an argcheck failure.
  input = THCudaTensor_newContiguous(state, input);

  int input_w = input->size[dimT];
  int input_n = input->size[dimF];
  int output_w = (input_w - kW) / dW + 1;

  if (input->nDimension == 2)
  {
    THCudaTensor_resize2d(state, output, output_w, input_n);
    THCudaTensor_resize2d(state, indices, output_w, input_n);
  }
  else
  {
    THCudaTensor_resize3d(state, output, batch, output_w, input_n);
    THCudaTensor_resize3d(state, indices, batch, output_w, input_n);
  }

  // Grid: x = batch element. Output frames are covered by threads (rounded
  // up to whole warps, capped at TEMPORAL_MAX_POOLING_THREADS) and, beyond
  // that cap, by additional blocks along y.
  dim3 blocks(batch);
  int nthreads = ((output_w + 31) / 32) * 32;
  if (nthreads > TEMPORAL_MAX_POOLING_THREADS) {
    blocks.y = (nthreads + TEMPORAL_MAX_POOLING_THREADS - 1) / TEMPORAL_MAX_POOLING_THREADS;
    nthreads = TEMPORAL_MAX_POOLING_THREADS;
  }
  dim3 threads(nthreads);

  cunn_TemporalMaxPooling_updateOutputKernel <<< blocks, threads, 0, THCState_getCurrentStream(state) >>>(
      THCudaTensor_data(state, input), THCudaTensor_data(state, output),
      THCudaTensor_data(state, indices), input_w, input_n, output_w, kW, dW);

  // Launch errors (e.g. an invalid grid configuration) are only reported via
  // cudaGetLastError(). Read the status first and free the contiguous copy
  // before the check, so the copy is released even if the check raises.
  cudaError_t launchErr = cudaGetLastError();
  THCudaTensor_free(state, input);
  THCudaCheck(launchErr);
  return 1;
}
/*
 * Lua binding: TemporalMaxPooling_updateGradInput(self, input, gradOutput)
 *
 * Reads self.kW, self.dW, self.gradInput and self.indices (the latter filled
 * by updateOutput). Resizes gradInput like input, zeroes it and scatters
 * gradOutput back to the argmax positions recorded in indices.
 * - When windows cannot overlap (kW <= dW), a plain-add kernel is used.
 * - Otherwise the atomicAdd kernel is used.
 */
static int cunn_TemporalMaxPooling_updateGradInput(lua_State *L) {
  THCState *state = getCutorchState(L);
  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
  THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
  int kW = luaT_getfieldcheckint(L, 1, "kW");
  int dW = luaT_getfieldcheckint(L, 1, "dW");
  THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
  THCudaTensor *indices = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "indices", "torch.CudaTensor");

  int dimT = 0; // Temporal dimension
  int dimF = 1; // Feature dimension
  int batch = 1;

  THAssert(THCudaTensor_checkGPU(state, 4, input, gradOutput, gradInput, indices));
  luaL_argcheck(L, input->nDimension == 2 || input->nDimension == 3, 2, "2D or 3D(batch mode) tensor expected");
  // dW is a divisor below; dW == 0 would be an integer division by zero.
  luaL_argcheck(L, kW > 0, 1, "kernel size kW must be positive");
  luaL_argcheck(L, dW > 0, 1, "stride dW must be positive");

  if (input->nDimension == 3)
  {
    dimT = 1;
    dimF = 2;
    batch = input->size[0];
  }
  luaL_argcheck(L, input->size[dimT] >= kW, 2, "input sequence smaller than kernel size");

  int input_w = input->size[dimT];
  int input_n = input->size[dimF];
  int output_w = (input_w - kW) / dW + 1;

  // The kernels index gradOutput and indices as [batch][output_w][input_n]
  // without bounds checks. A shape mismatch (e.g. backward called with a
  // different input than forward) would read/write out of bounds on the device.
  luaL_argcheck(L, gradOutput->nDimension == input->nDimension
                   && (input->nDimension == 2 || gradOutput->size[0] == batch)
                   && gradOutput->size[dimT] == output_w
                   && gradOutput->size[dimF] == input_n,
                3, "gradOutput size does not match the pooled output size");
  luaL_argcheck(L, indices->nDimension == input->nDimension
                   && (input->nDimension == 2 || indices->size[0] == batch)
                   && indices->size[dimT] == output_w
                   && indices->size[dimF] == input_n,
                1, "indices size does not match input; run updateOutput on this input first");

  THCudaTensor_resizeAs(state, gradInput, input);
  THCudaTensor_zero(state, gradInput);

  // All argument checks are done before this allocation, so it cannot leak
  // through an argcheck failure.
  gradOutput = THCudaTensor_newContiguous(state, gradOutput);

  // Same grid layout as the forward pass: x = batch element, threads plus
  // y-blocks cover the output frames.
  dim3 blocks(batch);
  int nthreads = ((output_w + 31) / 32) * 32;
  if (nthreads > TEMPORAL_MAX_POOLING_THREADS) {
    blocks.y = (nthreads + TEMPORAL_MAX_POOLING_THREADS - 1) / TEMPORAL_MAX_POOLING_THREADS;
    nthreads = TEMPORAL_MAX_POOLING_THREADS;
  }
  dim3 threads(nthreads);

  float *gradInput_data = THCudaTensor_data(state, gradInput);
  float *gradOutput_data = THCudaTensor_data(state, gradOutput);
  float *indices_data = THCudaTensor_data(state, indices);

  if (kW <= dW) {
    // Windows do not overlap: each gradInput element has at most one writer.
    cunn_TemporalMaxPooling_updateGradInputKernel <<< blocks, threads, 0, THCState_getCurrentStream(state) >>>(
        gradInput_data, gradOutput_data, indices_data, input_w, input_n, output_w, kW, dW);
  } else {
    // Overlapping windows may pick the same input element: accumulate atomically.
    cunn_TemporalMaxPooling_updateGradInputKernelAtomic <<< blocks, threads, 0, THCState_getCurrentStream(state) >>>(
        gradInput_data, gradOutput_data, indices_data, input_w, input_n, output_w, kW, dW);
  }

  // Surface launch errors. Free the contiguous copy before the check so it is
  // released even if the check raises.
  cudaError_t launchErr = cudaGetLastError();
  THCudaTensor_free(state, gradOutput);
  THCudaCheck(launchErr);
  return 1;
}
/* Lua-visible entry points of this module; the {NULL, NULL} entry terminates
 * the list. */
static const struct luaL_Reg cunn_TemporalMaxPooling_methods[] = {
  {"TemporalMaxPooling_updateOutput",    cunn_TemporalMaxPooling_updateOutput},
  {"TemporalMaxPooling_updateGradInput", cunn_TemporalMaxPooling_updateGradInput},
  {NULL, NULL}
};

/*
 * Registers the functions above under the "nn" name of the
 * "torch.CudaTensor" metatable. The metatable is pushed for the
 * registration and popped again, leaving the Lua stack balanced.
 */
void cunn_TemporalMaxPooling_init(lua_State *L)
{
  luaT_pushmetatable(L, "torch.CudaTensor");
  luaT_registeratname(L, cunn_TemporalMaxPooling_methods, "nn");
  lua_pop(L, 1);
}