forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
UpSampleLinear1d.cu
278 lines (238 loc) · 8.58 KB
/
UpSampleLinear1d.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
// Adapted from interp.cpp from Caffe util by Pauline Luc
// Originally developed by George Papandreou
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/native/cuda/UpSample.cuh>
#include <THC/THCAtomics.cuh>
namespace at {
namespace native {
namespace {
// Forward 1-D linear interpolation kernel.
//
// Launch layout: 1-D grid, one thread per output column. Threads with
// index >= n exit immediately; every other thread handles column
// w2 = index % width2 and loops over every (batch, channel) pair, storing
// odata[b][c][w2]. The host launcher in this file passes n = output width.
//
// n             number of output columns to process.
// rwidth        source/destination scale (from area_pixel_compute_scale).
// align_corners forwarded to area_pixel_compute_source_index.
// idata         input,  indexed as (batchsize, channels, width1).
// odata         output, indexed as (batchsize, channels, width2).
template <typename scalar_t, typename accscalar_t>
#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(1024)
#endif
__global__ void upsample_linear1d_out_frame(
    const int n,
    const accscalar_t rwidth,
    const bool align_corners,
    const PackedTensorAccessor64<scalar_t, 3> idata,
    PackedTensorAccessor64<scalar_t, 3> odata) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;

  const int batchsize = idata.size(0);
  const int channels = idata.size(1);
  const int width1 = idata.size(2);
  const int width2 = odata.size(2);

  if (index < n) {
    const int w2 = index % width2;

    // Special case: identical widths, so copy straight through.
    if (width1 == width2) {
      const int w1 = w2;
      // `b` (not `n`) so the loop index does not shadow the parameter `n`.
      for (int b = 0; b < batchsize; ++b) {
        for (int c = 0; c < channels; ++c) {
          const scalar_t val = idata[b][c][w1];
          odata[b][c][w2] = val;
        }
      }
      return;
    }

    // Fractional source coordinate, split into a left tap w1 and a weight.
    const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
        rwidth, w2, align_corners, /*cubic=*/false);
    const int w1 = w1r;
    // Offset to the right-hand tap; 0 at the last input column so the
    // right-hand read stays in bounds (both taps then read w1).
    const int w1p = (w1 < width1 - 1) ? 1 : 0;
    const accscalar_t w1lambda = w1r - w1;
    const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;

    for (int b = 0; b < batchsize; ++b) {
      for (int c = 0; c < channels; ++c) {
        // Blend in accscalar_t, then narrow once on store.
        const accscalar_t val =
            w0lambda * idata[b][c][w1] + w1lambda * idata[b][c][w1 + w1p];
        odata[b][c][w2] = static_cast<scalar_t>(val);
      }
    }
  }
}
// Backward (adjoint) operation 1 <- 2 (accumulates)
//
// Scatters the output gradient odata, indexed as (batchsize, channels,
// width2), back onto the input gradient idata, indexed as (batchsize,
// channels, width1).
//
// Launch layout matches the forward kernel: 1-D grid, one thread per output
// column w2, with threads where index >= n exiting immediately.
//
// When width1 != width2, each thread adds its two weighted contributions
// with gpuAtomicAdd, because threads for different w2 can compute the same
// w1. The kernel only adds in that path, so idata must already be
// zero-filled; the host launcher in this file calls grad_input.zero_()
// before the launch.
//
// When width1 == width2, the gradient is copied through with plain stores,
// since each thread writes only its own column.
template <typename scalar_t, typename accscalar_t>
#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(1024)
#endif
__global__ void upsample_linear1d_out_frame_backward(
    const int n,
    const accscalar_t rwidth,
    const bool align_corners,
    PackedTensorAccessor64<scalar_t, 3> idata,
    const PackedTensorAccessor64<scalar_t, 3> odata) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;

  const int batchsize = idata.size(0);
  const int channels = idata.size(1);
  const int width1 = idata.size(2);
  const int width2 = odata.size(2);

  if (index < n) {
    const int w2 = index % width2;

    // Special case: identical widths, so the gradient passes straight through.
    if (width1 == width2) {
      const int w1 = w2;
      // `b` (not `n`) so the loop index does not shadow the parameter `n`.
      for (int b = 0; b < batchsize; ++b) {
        for (int c = 0; c < channels; ++c) {
          const scalar_t val = odata[b][c][w1];
          idata[b][c][w2] = val;
        }
      }
      return;
    }

    // Same source-coordinate / weight computation as the forward kernel.
    const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
        rwidth, w2, align_corners, /*cubic=*/false);
    const int w1 = w1r;
    // Offset to the right-hand tap; 0 at the last input column, so both
    // contributions land on w1.
    const int w1p = (w1 < width1 - 1) ? 1 : 0;
    const accscalar_t w1lambda = w1r - w1;
    const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;

    for (int b = 0; b < batchsize; ++b) {
      for (int c = 0; c < channels; ++c) {
        const scalar_t d2val = odata[b][c][w2];
        gpuAtomicAdd(&idata[b][c][w1], static_cast<scalar_t>(w0lambda * d2val));
        gpuAtomicAdd(
            &idata[b][c][w1 + w1p], static_cast<scalar_t>(w1lambda * d2val));
      }
    }
  }
}
// Host launcher for the forward pass.
//
// Validates the arguments, resizes `output` to
// (input.size(0), input.size(1), output_size[0]), and launches
// upsample_linear1d_out_frame on the current CUDA stream with one thread per
// output column.
//
// output        destination tensor; resized in place.
// input         source tensor, read as (nbatch, channels, input_width).
// output_size   must contain exactly one element: the output width.
// align_corners / scales
//               forwarded to area_pixel_compute_scale.
static void upsample_linear1d_out_cuda_template(
    Tensor& output,
    const Tensor& input,
    IntArrayRef output_size,
    bool align_corners,
    c10::optional<double> scales) {
  TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2};
  checkAllSameGPU("upsample_linear1d_out_cuda", {input_arg, output_arg});

  TORCH_CHECK(
      output_size.size() == 1,
      "It is expected output_size equals to 1, but got size ",
      output_size.size());

  int output_width = output_size[0];

  int nbatch = input.size(0);
  int channels = input.size(1);
  int input_width = input.size(2);

  upsample_1d_shape_check(
      input, Tensor(), nbatch, channels, input_width, output_width);

  output.resize_({input.size(0), input.size(1), output_width});
  // No output.zero_() here. The kernel stores to every odata[b][c][w2] for
  // all b < nbatch, c < channels and w2 < output_width, so pre-zeroing would
  // only add a redundant full pass over the output buffer. (The backward
  // launcher is different: it accumulates with atomics and must zero first.)

  AT_ASSERT(input_width > 0 && output_width > 0);

  // One thread per output column.
  const int num_kernels = output_width;
  const int num_threads =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "upsample_linear1d_out_frame", [&] {
        // Interpolation weights are computed in the accumulation type
        // (e.g. float for half inputs).
        using accscalar_t = at::acc_type<scalar_t, true>;

        auto idata = input.packed_accessor64<scalar_t, 3>();
        auto odata = output.packed_accessor64<scalar_t, 3>();

        const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
            input_width, output_width, align_corners, scales);

        upsample_linear1d_out_frame<scalar_t, accscalar_t>
            <<<cuda::ATenCeilDiv(num_kernels, num_threads),
               num_threads,
               0,
               stream>>>(num_kernels, rwidth, align_corners, idata, odata);
      });

  // Surface launch-configuration errors from the kernel above.
  AT_CUDA_CHECK(cudaGetLastError());
}
// Host launcher for the backward pass.
//
// Validates the arguments, resizes `grad_input` to input_size
// (batch, channels, width), zero-fills it, and launches
// upsample_linear1d_out_frame_backward on the current CUDA stream with one
// thread per output column.
//
// grad_input    destination gradient; resized and zeroed in place.
// grad_output_  incoming gradient; a contiguous copy is used by the kernel.
// output_size   must contain exactly one element: the output width.
// input_size    must contain exactly three elements: (batch, channels, width).
// align_corners / scales
//               forwarded to area_pixel_compute_scale.
static void upsample_linear1d_backward_out_cuda_template(
    Tensor& grad_input,
    const Tensor& grad_output_,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners,
    c10::optional<double> scales) {
  TensorArg grad_output_arg{grad_output_, "grad_output_", 1};
  TensorArg grad_input_arg{grad_input, "grad_input", 2};
  checkAllSameGPU(
      "upsample_linear1d_backward_out_cuda", {grad_output_arg, grad_input_arg});

  // Argument-count checks, in the same order as before so the same error
  // wins when both are wrong.
  TORCH_CHECK(
      output_size.size() == 1,
      "It is expected output_size equals to 1, but got size ",
      output_size.size());
  TORCH_CHECK(
      input_size.size() == 3,
      "It is expected input_size equals to 3, but got size ",
      input_size.size());

  const int out_w = output_size[0];
  const int batch = input_size[0];
  const int chans = input_size[1];
  const int in_w = input_size[2];

  upsample_1d_shape_check(Tensor(), grad_output_, batch, chans, in_w, out_w);

  Tensor grad_output = grad_output_.contiguous();

  // The kernel accumulates with atomics, so the destination starts at zero.
  grad_input.resize_({batch, chans, in_w});
  grad_input.zero_();

  // One thread per output column.
  const int total_columns = out_w;
  const int block_threads =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_output.scalar_type(), "upsample_linear1d_out_frame_backward", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        auto grad_in_acc = grad_input.packed_accessor64<scalar_t, 3>();
        auto grad_out_acc = grad_output.packed_accessor64<scalar_t, 3>();

        const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
            in_w, out_w, align_corners, scales);

        const int grid_blocks = cuda::ATenCeilDiv(total_columns, block_threads);
        upsample_linear1d_out_frame_backward<scalar_t, accscalar_t>
            <<<grid_blocks, block_threads, 0, stream>>>(
                total_columns, rwidth, align_corners, grad_in_acc, grad_out_acc);
      });

  // Surface launch-configuration errors from the kernel above.
  AT_CUDA_CHECK(cudaGetLastError());
}
} // namespace
// out= variant of the forward op: writes the interpolated result into
// `output` (resized by the launcher) and returns that same tensor.
Tensor& upsample_linear1d_out_cuda(
    Tensor& output,
    const Tensor& input,
    IntArrayRef output_size,
    bool align_corners,
    c10::optional<double> scales) {
  upsample_linear1d_out_cuda_template(
      output,
      input,
      output_size,
      align_corners,
      scales);
  return output;
}
// Functional forward op: returns a new tensor holding the interpolated input.
//
// The result starts as an empty (0-element) tensor with input's options and
// is sized by the launcher's resize_. Previously it was empty_like(input),
// which allocated an input-sized buffer that was immediately resized. Since
// resize_ never shrinks storage, a downsampled result kept an oversized
// allocation alive, and an upsampled one paid for a throw-away allocation.
// resize_ produces a contiguous layout either way.
Tensor upsample_linear1d_cuda(
    const Tensor& input,
    IntArrayRef output_size,
    bool align_corners,
    c10::optional<double> scales) {
  Tensor output = at::empty({0}, input.options());
  upsample_linear1d_out_cuda_template(
      output, input, output_size, align_corners, scales);
  return output;
}
// out= variant of the backward op: computes the input gradient into
// `grad_input` (resized and zeroed by the launcher) and returns it.
Tensor& upsample_linear1d_backward_out_cuda(
    Tensor& grad_input,
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners,
    c10::optional<double> scales) {
  upsample_linear1d_backward_out_cuda_template(
      grad_input,
      grad_output,
      output_size,
      input_size,
      align_corners,
      scales);
  return grad_input;
}
// Functional backward op: returns a new input-shaped gradient tensor.
//
// The result starts as an empty (0-element) tensor with grad_output's
// options; the launcher resizes it to input_size and zero-fills it.
// Previously it was empty_like(grad_output), which allocated an output-sized
// buffer only to resize it to the input shape. Because resize_ never shrinks
// storage, that wasted memory whenever output was larger than input.
Tensor upsample_linear1d_backward_cuda(
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners,
    c10::optional<double> scales) {
  Tensor grad_input = at::empty({0}, grad_output.options());
  upsample_linear1d_backward_out_cuda_template(
      grad_input, grad_output, output_size, input_size, align_corners, scales);
  return grad_input;
}
} // namespace native
} // namespace at