forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Activation.cu
468 lines (412 loc) · 17.6 KB
/
Activation.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
#define _USE_MATH_DEFINES
#include <ATen/native/Activation.h>
#include <math.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <c10/cuda/CUDAMathCompat.h>
namespace at { namespace native {
// -----------------------------------
// prelu forward
// -----------------------------------
// PReLU forward with one weight shared by every element:
//   result = input > 0 ? input : (*weight_data) * input
// weight_data is dereferenced inside the lambda handed to gpu_kernel, so it is
// expected to point to device memory (the caller in this file passes
// weight.data_ptr<scalar_t>() of a CUDA tensor).
template <typename scalar_t>
void prelu_cuda_kernel_share_weights(
    const Tensor& input,
    Tensor& result,
    const scalar_t* weight_data) {
  at::TensorIterator it;
  it.add_output(result);
  it.add_input(input);
  it.build();
  at::native::gpu_kernel(it, [weight_data] GPU_LAMBDA (scalar_t x) -> scalar_t {
    if (x > 0) {
      return x;
    }
    return *weight_data * x;
  });
}
// PReLU forward with one weight per channel:
//   result[i] = input[i] > 0 ? input[i] : weight[channel(i)] * input[i]
// where channel(i) = (i % input_stride0) / input_stride1. With the strides of a
// contiguous (N, C, ...) tensor this recovers the index along dim 1.
//
// Launch: 1-D grid of 1-D blocks, no shared memory. The body is a grid-stride
// loop, so every element is processed for ANY grid size. That includes a grid
// smaller than ceil(input_numel / blockDim.x), e.g. if the launcher caps the
// number of blocks.
template <typename scalar_t>
__global__ void prelu_cuda_kernel_multi_weights(
    scalar_t* result_data,
    const scalar_t* input_data,
    const scalar_t* weight_data,
    int64_t input_stride0,
    int64_t input_stride1,
    int64_t input_numel) {
  // Widen BEFORE multiplying: blockIdx.x * blockDim.x is evaluated in 32-bit
  // unsigned arithmetic and wraps for tensors with more than 2^32 elements.
  const int64_t grid_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t linearId = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       linearId < input_numel;
       linearId += grid_stride) {
    // multiply values at each channel with weight[channel_index]
    const int64_t channel = (linearId % input_stride0) / input_stride1;
    const scalar_t input_data_val = input_data[linearId];
    result_data[linearId] = (input_data_val > 0) ? input_data_val : weight_data[channel] * input_data_val;
  }
}
// PReLU forward on CUDA: result = input > 0 ? input : weight * input.
// The mode is chosen by weight.numel():
//   * 1       -> a single weight shared by every element;
//   * size(1) -> one weight per channel, the channel being dim 1 of input.
// input and weight are made contiguous first. The result is allocated with
// LEGACY_CONTIGUOUS_MEMORY_FORMAT.
// Throws (TORCH_CHECK) if either tensor is not on CUDA, if input is 0-dim in
// per-channel mode, or if weight.numel() does not match the channel count.
Tensor prelu_cuda(const Tensor& self, const Tensor& weight_) {
  TORCH_CHECK(self.is_cuda());
  TORCH_CHECK(weight_.is_cuda());
  auto input = self.contiguous();
  auto weight = weight_.contiguous();
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(weight.is_contiguous());
  int64_t weight_num = weight.numel();
  Tensor result = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto strides = input.strides();
  // case1: shared weight for all channels
  if (weight_num == 1) {
    AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, input.scalar_type(), "prelu_cuda", [&] {
      prelu_cuda_kernel_share_weights<scalar_t>(
          input,
          result,
          weight.data_ptr<scalar_t>());
    });
  }
  else { // case2: multiple weights, one for each channel
    int64_t input_ndim = input.dim();
    TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor.");
    int64_t channel_size = 1; // channel_size default to 1
    int64_t input_stride0 = 1, input_stride1 = 1;
    if (input_ndim > 1) {
      channel_size = input.size(1); // channel is the 2nd dim of input
      input_stride0 = strides[0];
      input_stride1 = strides[1];
    }
    TORCH_CHECK(channel_size == weight_num,
      "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
      " and channel size = ", channel_size, ".");
    int64_t input_numel = input.numel();
    // Nothing to compute for an empty input. Launching anyway would use
    // block = min(applyBlock, 0) = 0 threads. CUDA rejects that as an invalid
    // configuration, and the error would then surface at some unrelated later
    // CUDA call.
    if (input_numel == 0) {
      return result;
    }
    // config to run cuda kernel
    const dim3 block = dim3(std::min(static_cast<int64_t>(cuda::getApplyBlock().x), input_numel));
    dim3 grid;
    int curDevice = -1;
    AT_CUDA_CHECK(cudaGetDevice(&curDevice));
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
    TORCH_CHECK(cuda::getApplyGrid(input_numel, grid, curDevice), "prelu: input too large or too many dimensions");
    AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, input.scalar_type(), "prelu_cuda", [&] {
      prelu_cuda_kernel_multi_weights<scalar_t>
      <<<grid, block, 0, stream>>>(
          result.data_ptr<scalar_t>(),
          input.data_ptr<scalar_t>(),
          weight.data_ptr<scalar_t>(),
          input_stride0,
          input_stride1,
          input_numel);
    });
    // Kernel launches do not return an error code; surface bad launches here.
    AT_CUDA_CHECK(cudaGetLastError());
  }
  return result;
}
// -----------------------------------
// prelu backward
// -----------------------------------
// PReLU backward with one weight shared by every element. Applied element-wise
// over (input, grad_out, input_grad, weight_grad_collector):
//   input_grad            = x > 0 ? dy : (*weight_data) * dy
//   weight_grad_collector = x > 0 ? 0  : x * dy
// The collector holds per-element contributions to the weight gradient. The
// caller in this file reduces it with .sum().
template <typename scalar_t>
void prelu_cuda_backward_kernel_share_weights(
    const Tensor& input,
    const Tensor& grad_out,
    Tensor& input_grad,
    Tensor& weight_grad_collector,
    const scalar_t* weight_data) {
  at::cuda::CUDA_tensor_apply4<scalar_t, scalar_t, scalar_t, scalar_t>(
      input, grad_out, input_grad, weight_grad_collector,
      [=] __device__ (
          const scalar_t& x,
          const scalar_t& dy,
          scalar_t& dx,
          scalar_t& dw_part) {
        if (x > 0) {
          dx = dy;
          dw_part = scalar_t(0);
        } else {
          dx = *weight_data * dy;
          dw_part = x * dy;
        }
      });
}
// PReLU backward with one weight per channel. For each i < input_numel:
//   input_grad_data[i]       = x > 0 ? dy : weight[channel(i)] * dy
//   weight_grad_collector[i] = x > 0 ? 0  : x * dy
// where x = input_data[i], dy = grad_out_data[i], and
// channel(i) = (i % input_stride0) / input_stride1. With the strides of a
// contiguous (N, C, ...) tensor this recovers the index along dim 1.
// All element arrays are indexed with the same linear index, so they must each
// hold at least input_numel elements.
//
// Launch: 1-D grid of 1-D blocks, no shared memory. The body is a grid-stride
// loop, so every element is processed for ANY grid size.
template <typename scalar_t>
__global__ void prelu_cuda_backward_kernel_multi_weights(
    const scalar_t* input_data,
    const scalar_t* weight_data,
    const scalar_t* grad_out_data,
    scalar_t* input_grad_data,
    scalar_t* weight_grad_collector,
    int64_t input_stride0,
    int64_t input_stride1,
    int64_t input_numel) {
  // Widen BEFORE multiplying: blockIdx.x * blockDim.x is evaluated in 32-bit
  // unsigned arithmetic and wraps for tensors with more than 2^32 elements.
  const int64_t grid_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t linearId = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       linearId < input_numel;
       linearId += grid_stride) {
    const int64_t channel = (linearId % input_stride0) / input_stride1;
    const scalar_t input_data_val = input_data[linearId];
    const scalar_t grad_out_data_val = grad_out_data[linearId];
    input_grad_data[linearId] = (input_data_val > 0) ? grad_out_data_val : weight_data[channel] * grad_out_data_val;
    weight_grad_collector[linearId] = (input_data_val > 0) ? scalar_t(0) : input_data_val * grad_out_data_val;
  }
}
// PReLU backward on CUDA. Returns (input_grad, weight_grad):
//   input_grad  = input > 0 ? grad_out : weight * grad_out
//   weight_grad = sum of (input > 0 ? 0 : input * grad_out), reduced
//                 * over all elements when weight.numel() == 1 (shared weight), or
//                 * over every dim except dim 1 (the channel) otherwise.
// All tensors are made contiguous first. Throws (TORCH_CHECK) if any tensor is
// not on CUDA, if grad_out's shape differs from input's, if input is 0-dim in
// per-channel mode, or if weight.numel() does not match the channel count.
std::tuple<Tensor, Tensor> prelu_backward_cuda(const Tensor& grad_out_, const Tensor& self, const Tensor& weight_) {
  TORCH_CHECK(grad_out_.is_cuda());
  TORCH_CHECK(self.is_cuda());
  TORCH_CHECK(weight_.is_cuda());
  auto input = self.contiguous();
  auto grad_out = grad_out_.contiguous();
  auto weight = weight_.contiguous();
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(weight.is_contiguous());
  TORCH_CHECK(grad_out.is_contiguous());
  // grad_out is paired element-wise with input. The per-channel kernel reads
  // grad_out_data[i] for every i < input.numel(), so a smaller grad_out would be
  // read out of bounds.
  TORCH_CHECK(grad_out.sizes().equals(input.sizes()),
    "prelu_backward_cuda: expected grad_out to have the same shape as input, but got grad_out of shape ",
    grad_out.sizes(), " and input of shape ", input.sizes());
  int64_t weight_num = weight.numel();
  int64_t input_ndim = input.dim();
  auto strides = input.strides();
  Tensor input_grad = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor weight_grad = at::empty_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor weight_grad_collector = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  // case1: shared parameter for all channels
  if (weight_num == 1) {
    AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, input.scalar_type(), "prelu_backward_cuda", [&] {
      prelu_cuda_backward_kernel_share_weights<scalar_t>(
          input,
          grad_out,
          input_grad,
          weight_grad_collector,
          weight.data_ptr<scalar_t>());
    });
    weight_grad.fill_(weight_grad_collector.sum());
  }
  else { // case2: multiple parameters, one for each channel
    TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor.");
    int64_t channel_size = 1; // channel_size default to 1
    int64_t input_stride0 = 1, input_stride1 = 1;
    if (input_ndim > 1) {
      channel_size = input.size(1); // channel is the 2nd dim of input
      input_stride0 = strides[0];
      input_stride1 = strides[1];
    }
    TORCH_CHECK(channel_size == weight_num,
      "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
      " and channel size = ", channel_size, ".");
    int64_t input_numel = input.numel();
    // Skip the launch for an empty input. It would otherwise use a 0-thread
    // block (min(applyBlock, 0)), an invalid configuration that nothing checked.
    // The reduction below still yields the correctly shaped weight_grad.
    if (input_numel > 0) {
      // config to run cuda kernel
      const dim3 block = dim3(std::min(static_cast<int64_t>(cuda::getApplyBlock().x), input_numel));
      dim3 grid;
      int curDevice = -1;
      AT_CUDA_CHECK(cudaGetDevice(&curDevice));
      cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
      TORCH_CHECK(cuda::getApplyGrid(input_numel, grid, curDevice), "prelu_backward_cuda: input too large or too many dimensions");
      AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, input.scalar_type(), "prelu_backward_cuda", [&] {
        prelu_cuda_backward_kernel_multi_weights<scalar_t>
        <<<grid, block, 0, stream>>>(
            input.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            grad_out.data_ptr<scalar_t>(),
            input_grad.data_ptr<scalar_t>(),
            weight_grad_collector.data_ptr<scalar_t>(),
            input_stride0,
            input_stride1,
            input_numel);
      });
      // Kernel launches do not return an error code; surface bad launches here.
      AT_CUDA_CHECK(cudaGetLastError());
    }
    // update weight_grad: reduce over every dim except the channel dim (1).
    std::vector<int64_t> reduce_dims;
    reduce_dims.push_back(0);
    for (int64_t i = 2; i < input_ndim; i++) {
      reduce_dims.push_back(i);
    }
    weight_grad = weight_grad_collector.sum(reduce_dims);
  }
  return std::tuple<Tensor, Tensor>{input_grad, weight_grad};
}
// -----------------------------------
// hardshrink
// -----------------------------------
// hardshrink: elements inside the closed band [-lambd, lambd] become 0, all
// others pass through unchanged. The band test is written as ">= && <=" so a NaN
// input (which fails both comparisons) is passed through rather than zeroed.
void hardshrink_kernel(TensorIterator& iter, Scalar value) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardshrink_cuda", [&]() {
    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "hardshrink_cuda", [&] {
      const scalar_t lambd = value.to<scalar_t>();
      gpu_kernel(iter, [lambd] GPU_LAMBDA(scalar_t x) -> scalar_t {
        const bool in_band = (x >= -lambd) && (x <= lambd);
        if (in_band) {
          return scalar_t(0);
        }
        return x;
      });
    });
  });
}
// softshrink: shift values toward zero by lambd, with the band [-lambd, lambd]
// collapsing to 0:
//   x > lambd  -> x - lambd
//   x < -lambd -> x + lambd
//   otherwise  -> 0
void softshrink_kernel(TensorIterator& iter, Scalar value) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "softshrink_cuda", [&]() {
    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "softshrink_cuda", [&] {
      const scalar_t lambd = value.to<scalar_t>();
      gpu_kernel(iter, [lambd] GPU_LAMBDA(scalar_t x) -> scalar_t {
        if (x > lambd) {
          return x - lambd;
        }
        if (x < -lambd) {
          return x + lambd;
        }
        return scalar_t(0);
      });
    });
  });
}
// Shared backward for hard/soft shrink: the gradient is zeroed where self_val
// lies in [-lambd, lambd] and passed through elsewhere (including for NaN
// self_val, which fails both comparisons).
void shrink_backward_kernel(TensorIterator& iter, Scalar value) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "shrink_backward_cuda", [&]() {
    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "shrink_backward_cuda", [&] {
      const scalar_t lambd = value.to<scalar_t>();
      gpu_kernel(iter, [lambd] GPU_LAMBDA(scalar_t grad_val, scalar_t self_val) -> scalar_t {
        const bool in_band = (self_val >= -lambd) && (self_val <= lambd);
        if (in_band) {
          return scalar_t(0);
        }
        return grad_val;
      });
    });
  });
}
// hardtanh backward: returns 0 where the second operand is at or beyond either
// clamp bound, otherwise the first operand. The operands are presumably
// (grad_output, self); confirm at the stub's caller. Dispatches float, double
// and Half (no BFloat16).
void hardtanh_backward_kernel(TensorIterator& iter, Scalar min, Scalar max) {
  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, iter.dtype(), "hardtanh_backward_cuda", [&]() {
    const scalar_t lo = min.to<scalar_t>();
    const scalar_t hi = max.to<scalar_t>();
    gpu_kernel(iter, [lo, hi] GPU_LAMBDA(scalar_t g, scalar_t x) -> scalar_t {
      const bool clamped = (x <= lo) || (x >= hi);
      if (clamped) {
        return scalar_t(0);
      }
      return g;
    });
  });
}
// softplus: y = log1p(exp(beta * x)) / beta. Once beta * x exceeds `threshold`
// the input is returned unchanged (the linear regime).
void softplus_kernel(TensorIterator& iter, Scalar beta_, Scalar threshold_) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "softplus_cuda", [&]() {
    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "softplus_cuda", [&] {
      const scalar_t beta = beta_.to<scalar_t>();
      const scalar_t threshold = threshold_.to<scalar_t>();
      gpu_kernel(iter, [beta, threshold] GPU_LAMBDA(scalar_t x) -> scalar_t {
        const scalar_t scaled = x * beta;
        if (scaled > threshold) {
          return x;
        }
        return static_cast<scalar_t>(::log1p(std::exp(scaled))) / beta;
      });
    });
  });
}
// softplus backward, with z = exp(beta * v):
//   beta * v > threshold -> grad (linear regime)
//   otherwise            -> grad * (z - 1) / z
// v is the second operand. If v is the softplus forward output (presumably;
// confirm at the caller), then (z - 1) / z equals sigmoid(beta * x).
void softplus_backward_kernel(TensorIterator& iter, Scalar beta_, Scalar threshold_) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "softplus_backward_cuda", [&]() {
    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "softplus_backward_cuda", [&] {
      const scalar_t beta = beta_.to<scalar_t>();
      const scalar_t threshold = threshold_.to<scalar_t>();
      gpu_kernel(iter, [beta, threshold] GPU_LAMBDA(scalar_t grad, scalar_t v) -> scalar_t {
        const scalar_t scaled = v * beta;
        if (scaled > threshold) {
          return grad;
        }
        const scalar_t z = std::exp(scaled);
        return grad * (z - scalar_t(1.)) / z;
      });
    });
  });
}
// Element-wise select over a two-input iterator:
//   out = x <= threshold ? value : other
// The launch goes through gpu_kernel_with_scalars rather than gpu_kernel.
template <typename scalar_t>
void threshold_kernel_impl(TensorIterator& iter, scalar_t threshold, scalar_t value) {
  gpu_kernel_with_scalars(iter, [threshold, value] GPU_LAMBDA(scalar_t x, scalar_t other) -> scalar_t {
    if (x <= threshold) {
      return value;
    }
    return other;
  });
}
// Dtype dispatch for threshold. Covers all standard types plus Half and
// BFloat16, with BFloat16 skipped outside ROCm. Converts both Scalars to the
// iterator's dtype before calling threshold_kernel_impl.
static void threshold_kernel(TensorIterator& iter, Scalar threshold, Scalar value) {
  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "threshold_cuda", [&] {
    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "threshold_cuda", [&] {
      const scalar_t thresh = threshold.to<scalar_t>();
      const scalar_t fill = value.to<scalar_t>();
      threshold_kernel_impl<scalar_t>(iter, thresh, fill);
    });
  });
}
// ELU forward:
//   x > 0     -> scale * x
//   otherwise -> alpha * scale * (exp(input_scale * x) - 1)
void elu_kernel(TensorIterator& iter, Scalar alpha, Scalar scale, Scalar input_scale) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "elu_cuda", [&]() {
    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "elu_cuda", [&] {
      const scalar_t poscoef = scale.to<scalar_t>();
      const scalar_t negcoef = alpha.to<scalar_t>() * poscoef;
      const scalar_t negiptcoef = input_scale.to<scalar_t>();
      gpu_kernel(iter, [negcoef, poscoef, negiptcoef] GPU_LAMBDA(scalar_t x) -> scalar_t {
        if (x > scalar_t(0)) {
          return x * poscoef;
        }
        return (static_cast<scalar_t>(std::exp(x * negiptcoef)) - scalar_t(1.)) * negcoef;
      });
    });
  });
}
// ELU backward, written in terms of the second operand y:
//   y <= 0    -> grad * input_scale * (y + alpha * scale)
//   otherwise -> grad * scale
// This is the ELU derivative when y is the forward output (presumably what
// callers pass; confirm at the stub's caller).
void elu_backward_kernel(TensorIterator& iter, Scalar alpha, Scalar scale, Scalar input_scale) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "elu_backward_cuda", [&]() {
    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "elu_backward_cuda", [&] {
      const scalar_t poscoef = scale.to<scalar_t>();
      const scalar_t negcoef = alpha.to<scalar_t>() * poscoef;
      const scalar_t negiptcoef = input_scale.to<scalar_t>();
      gpu_kernel(iter, [negcoef, poscoef, negiptcoef] GPU_LAMBDA(scalar_t grad, scalar_t y) -> scalar_t {
        if (y <= scalar_t(0)) {
          return grad * negiptcoef * (y + negcoef);
        }
        return grad * poscoef;
      });
    });
  });
}
namespace {
// Exact GELU forward: y = x * normcdf(x), where normcdf is the standard normal
// CDF. The math runs in the accumulate type T_ACC and the result is converted
// back to scalar_t on return.
void GeluCUDAKernelImpl(TensorIterator& it) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() {
    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "GeluCUDAKernelImpl", [&] {
      using T_ACC = acc_type<scalar_t, true>;
      gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
        const T_ACC x_acc = static_cast<T_ACC>(x);
        return x_acc * c10::cuda::compat::normcdf(x_acc);
      });
    });
  });
}
// Exact GELU backward:
//   dx = dy * (cdf(x) + x * pdf(x))
// with pdf(x) = exp(-x^2 / 2) / sqrt(2 * pi). Everything is computed in the
// accumulate type T_ACC.
void GeluBackwardCUDAKernelImpl(TensorIterator& it) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
      it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() {
    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "GeluBackwardCUDAKernelImpl", [&] {
      using T_ACC = acc_type<scalar_t, true>;
      gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
        // (2 / sqrt(pi)) * (1 / sqrt(2)) * 0.5 == 1 / sqrt(2 * pi)
        constexpr T_ACC kBeta = M_2_SQRTPI * M_SQRT1_2 * T_ACC(0.5);
        const T_ACC x_acc = static_cast<T_ACC>(x);
        const T_ACC cdf = c10::cuda::compat::normcdf(x_acc);
        const T_ACC gauss = c10::cuda::compat::exp(T_ACC(-0.5) * x_acc * x_acc);
        const T_ACC pdf = gauss * kBeta;
        return static_cast<T_ACC>(dy) * (cdf + x_acc * pdf);
      });
    });
  });
}
// Leaky ReLU forward: positive inputs pass through, all others (including NaN)
// are multiplied by negval.
void leaky_relu_kernel(TensorIterator& iter, Scalar negval_) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "leaky_relu_cuda", [&]() {
    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "leaky_relu_cuda", [&] {
      const scalar_t slope = negval_.to<scalar_t>();
      gpu_kernel(iter, [slope] GPU_LAMBDA(scalar_t x) -> scalar_t {
        if (x > scalar_t(0)) {
          return x;
        }
        return x * slope;
      });
    });
  });
}
// Leaky ReLU backward. The first operand selects the branch and the second is
// returned unchanged (first > 0) or scaled by negval (otherwise). The operands
// are presumably (self, grad_output); confirm at the stub's caller.
void leaky_relu_backward_kernel(TensorIterator& iter, Scalar negval_) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "leaky_relu_backward_cuda", [&]() {
    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "leaky_relu_backward_cuda", [&] {
      const scalar_t slope = negval_.to<scalar_t>();
      gpu_kernel(iter, [slope] GPU_LAMBDA(scalar_t x, scalar_t g) -> scalar_t {
        if (x > scalar_t(0)) {
          return g;
        }
        return g * slope;
      });
    });
  });
}
} // namespace
// GELU forward entry point. Allocates an output like `self` (legacy contiguous
// memory format) and runs GeluCUDAKernelImpl over a unary TensorIterator.
Tensor gelu_cuda(const Tensor& self) {
  Tensor out = at::native::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto iter = TensorIterator::unary_op(out, self);
  GeluCUDAKernelImpl(iter);
  return out;
}
// GELU backward entry point. Allocates dX like `self` (legacy contiguous memory
// format) and runs GeluBackwardCUDAKernelImpl over (grad, self).
Tensor gelu_backward_cuda(const Tensor& grad, const Tensor& self) {
  Tensor grad_input = at::native::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto iter = TensorIterator::binary_op(grad_input, grad, self);
  GeluBackwardCUDAKernelImpl(iter);
  return grad_input;
}
// Shared worker for the threshold entry points below:
//   result = self <= threshold ? value : other
// `other` is `self` for threshold() and `grad` for threshold_backward().
// When opt_result is empty, an undefined Tensor is handed to TensorIterator as
// the output (presumably so that it allocates one). The iterator's output is
// returned either way.
static Tensor threshold_out_cuda(
    optional<Tensor> opt_result,
    const Tensor& self,
    Scalar threshold,
    Scalar value,
    const Tensor& other) {
  Tensor result;
  if (opt_result.has_value()) {
    result = *opt_result;
  }
  auto iter = TensorIterator::binary_op(result, self, other);
  threshold_kernel(iter, threshold, value);
  return iter.output();
}
// Out-of-place threshold: self <= threshold ? value : self, into a new tensor.
Tensor threshold_cuda(const Tensor& self, Scalar threshold, Scalar value) {
  optional<Tensor> no_output;
  return threshold_out_cuda(no_output, self, threshold, value, self);
}
// In-place threshold: writes self <= threshold ? value : self back into self.
Tensor& threshold__cuda(Tensor& self, Scalar threshold, Scalar value) {
  optional<Tensor> destination(self);
  threshold_out_cuda(destination, self, threshold, value, self);
  return self;
}
// Out variant of threshold: writes self <= threshold ? value : self into
// `result` and returns it.
Tensor& threshold_out_cuda(Tensor& result, const Tensor& self, Scalar threshold, Scalar value) {
  optional<Tensor> destination(result);
  threshold_out_cuda(destination, self, threshold, value, self);
  return result;
}
// Backward of threshold: grad is zeroed where self <= threshold and passed
// through elsewhere. The result is a new tensor.
Tensor threshold_backward_cuda(const Tensor& grad, const Tensor& self, Scalar threshold) {
  const Scalar zero = 0;
  return threshold_out_cuda(nullopt, self, threshold, zero, grad);
}
// Bind the CUDA kernels defined in this file to their dispatch stubs.
// Each registration is independent; they are grouped by activation family.

// hardtanh
REGISTER_DISPATCH(hardtanh_backward_stub, &hardtanh_backward_kernel);

// hard/soft shrink (shared backward)
REGISTER_DISPATCH(hardshrink_stub, &hardshrink_kernel);
REGISTER_DISPATCH(softshrink_stub, &softshrink_kernel);
REGISTER_DISPATCH(shrink_backward_stub, &shrink_backward_kernel);

// elu
REGISTER_DISPATCH(elu_stub, &elu_kernel);
REGISTER_DISPATCH(elu_backward_stub, &elu_backward_kernel);

// leaky_relu
REGISTER_DISPATCH(leaky_relu_stub, &leaky_relu_kernel);
REGISTER_DISPATCH(leaky_relu_backward_stub, &leaky_relu_backward_kernel);

// softplus
REGISTER_DISPATCH(softplus_stub, &softplus_kernel);
REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel);
}} // namespace at::native