forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TensorFactories.cu
440 lines (386 loc) · 16 KB
/
TensorFactories.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
#include <ATen/ATen.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/native/cuda/Resize.cuh>
#include <c10/util/Exception.h>
#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>
#include <thrust/device_ptr.h>
#include <thrust/sort.h>
#include <thrust/execution_policy.h>
#include <thrust/sequence.h>
#include <algorithm>
#include <cstddef>
#include <cmath>
namespace at {
namespace native {
// Square-matrix overload: forwards to the (n, m) overload with a negative
// column count, which that overload interprets as "use n columns".
Tensor& eye_out_cuda(Tensor& result, int64_t n) {
  constexpr int64_t kColumnsSameAsRows = -1;
  return at::native::eye_out_cuda(result, n, kColumnsSameAsRows);
}
// Resizes `result` to n x m, zero-fills it and writes ones on the main
// diagonal. A negative `m` means "same as n". Raises if `n` is negative.
// Returns `result`.
Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) {
  TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n);

  // Negative column count selects a square matrix.
  const int64_t num_cols = (m < 0) ? n : m;

  result.resize_({n, num_cols});
  result.zero_();

  // Moving one step down and one step right lands on the next diagonal
  // element, so the diagonal is a 1-D strided view over the same storage.
  const int64_t diag_length = std::min<int64_t>(n, num_cols);
  const int64_t diag_step = result.stride(0) + result.stride(1);
  result.as_strided({diag_length}, {diag_step}).fill_(1);

  return result;
}
// Allocates an uninitialized CUDA tensor of shape `size`.
//
// `options` must target a CUDA device and must not request pinned memory.
// `optional_memory_format` selects the stride layout and falls back to
// MemoryFormat::Contiguous when not provided.
Tensor empty_cuda(IntArrayRef size, const TensorOptions& options, c10::optional<MemoryFormat> optional_memory_format) {
  AT_ASSERT(options.device().type() == at::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(impl::variable_excluded_from_dispatch());
  TORCH_CHECK(!options.pinned_memory(), "Only dense CPU tensors can be pinned");
  check_size_nonnegative(size);

  auto* cuda_allocator = at::cuda::getCUDADeviceAllocator();
  int64_t element_count = prod_intlist(size);
  auto element_type = options.dtype();
  auto storage_bytes = element_count * element_type.itemsize();

  // Resizable storage backed by the CUDA device allocator.
  auto storage = c10::make_intrusive<StorageImpl>(
      element_type,
      element_count,
      cuda_allocator->allocate(storage_bytes),
      cuda_allocator,
      /*resizeable=*/true);

  auto output = detail::make_tensor<TensorImpl>(storage, DispatchKey::CUDATensorId);
  auto* output_impl = output.unsafeGetTensorImpl();

  // A freshly made TensorImpl already has size [0]; only set the sizes when
  // the requested shape differs from that default.
  const bool is_default_shape = size.size() == 1 && size[0] == 0;
  if (!is_default_shape) {
    output_impl->set_sizes_contiguous(size);
  }

  output_impl->empty_tensor_restride(
      optional_memory_format.value_or(MemoryFormat::Contiguous));
  return output;
}
// Creates an uninitialized CUDA tensor with explicit sizes and strides by
// starting from a zero-element tensor and resizing it in place.
Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, const TensorOptions& options) {
  Tensor strided = at::native::empty_cuda({0}, options);
  auto* strided_impl = strided.unsafeGetTensorImpl();
  at::native::resize_impl_cuda_(strided_impl, size, stride);
  return strided;
}
// Fills `result` with a random ordering of the sequence 0, 1, ..., n-1.
//
// result:    output tensor; resized to {n}. May be non-contiguous, in which
//            case the permutation is built in a temporary and copied in.
// n:         number of elements; must be non-negative (raises otherwise).
// generator: random number generator forwarded to `random_` / `randperm_out`.
//
// For n < 30000 the permutation is computed on the CPU and copied over.
// Otherwise random keys are drawn on the device and the sequence is reordered
// by sorting it against those keys with thrust on the current CUDA stream.
// Returns `result`.
Tensor& randperm_out_cuda(Tensor& result, int64_t n, Generator* generator) {
  // Note the trailing space in the message so the value is not glued to "got".
  TORCH_CHECK(n >= 0, "n must be non-negative, got ", n);
  check_supported_max_int_with_precision(n, result);

  result.resize_({n});

  if (n < 30000) {  // For small inputs, we offload it to CPU instead.
    auto result_cpu = at::empty({n}, result.options().device(kCPU));
    randperm_out(result_cpu, n, generator);
    return result.copy_(result_cpu);
  }

#if 0
  // This if condition should never be true because if n >= 30000 and the tensor has a Half type,
  // check_supported_max_int_with_precision should have reported an error. This snippet is commented out but left here
  // for the sake of clarity, because Half in thrust is spotty, and we do not want future change unaware of this.
  if (result.scalar_type() == at::ScalarType::Half) {  // Half in thrust is spotty.  Avoid.
    auto result_float = at::empty({n}, initialTensorOptions().device(Device(DeviceType::CUDA)));
    return result.copy_(randperm_out_cuda(result_float, n, generator));
  }
#endif

  // Generate random values for the keys array
  AT_DISPATCH_ALL_TYPES(
    result.scalar_type(), "randperm_out_cuda", [&] {
      auto keys = at::empty(result.sizes(), result.options()).random_(generator);
      auto keys_data = thrust::device_ptr<scalar_t>(keys.data_ptr<scalar_t>());

      // shuffled_data points to the underlying data of the output tensor if
      // the tensor is contiguous; otherwise it points to a new tensor.
      const bool write_in_place = result.is_contiguous();
      Tensor shuffled;
      thrust::device_ptr<scalar_t> shuffled_data;
      if (write_in_place) {
        shuffled_data = thrust::device_ptr<scalar_t>(result.data_ptr<scalar_t>());
      } else {
        shuffled = at::empty({n}, result.options());
        shuffled_data = thrust::device_ptr<scalar_t>(shuffled.data_ptr<scalar_t>());
      }

      auto state = globalContext().getTHCState();
      THCThrustAllocator thrustAlloc(state);
      auto policy = thrust::cuda::par(thrustAlloc).on(at::cuda::getCurrentCUDAStream());

      // Start from 0, 1, ..., n-1 ...
      thrust::sequence(policy, shuffled_data, shuffled_data + n);

      // ... and use the sorted order of keys to rearrange the result array
      thrust::sort_by_key(policy, keys_data, keys_data + n, shuffled_data);

      if (!write_in_place) {
        result.copy_(shuffled);
      }
    }
  );

  return result;
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
namespace {
// To find the largest integer that does not exceed the square root of an
// int64_t variable, we could use a loop that tests one bit at a time, which
// takes up to 31 iterations. This would give the exact result, but it is
// relatively slow and is overkill for most cases, where double's precision
// suffices.
//
// If we directly use sqrt to calculate the root, the conversion from int64_t
// to double would lose 11 bits precision.
//
// The following solution uses sqrt directly for most cases, and would only
// special handle it if there is indeed precision loss.
// Returns floor((-b + sign * sqrt(b*b - cX4)) / 2), i.e. the floor of one root
// of row^2 + b*row + cX4/4 = 0, as a 64-bit integer.
//
// b:    linear coefficient of the quadratic.
// cX4:  four times the constant coefficient (so b*b - cX4 is the discriminant).
// x:    linear index of the target element; only used by the binary-search
//       fallback below, where it is compared as 2x.
// sign: selects the root: the tril caller passes 1, the triu caller passes -1.
//
// The fast path evaluates the formula in double precision. If squaring the
// double square root does not reproduce the integer discriminant, the result
// is refined with an integer binary search around the estimate.
__device__
inline int64_t resolve_root_int(
int64_t b, int64_t cX4, int64_t x, int32_t sign) {
// discriminant of the quadratic, computed exactly in integer arithmetic
int64_t bXb_cX4 = b*b - cX4;
// potential precision loss could occur here when casting int64_t (63 bits
// precision) to double (52 bits precision)
double sr = ::sqrt((double)bXb_cX4);
// first estimate of the row, rounded toward negative infinity
int64_t res = ::__double2ll_rd((-b + sign * sr)/2);
// have to cast double to int64_t, otherwise it would only compare up to the
// precision of a double variable, ignoring the precision loss
if (bXb_cX4 != (int64_t) (sr * sr)) {
// handle precision loss by using binary search
int64_t llsr = ::__double2ll_rd(sr);
// Use the following math to reduce search space.
// Suppose z is the accurate result of sqrt(bXb_cX4) without precision loss
// let d = abs(bXb_cX4 - llsr * llsr), then we have:
// z = sqrt(bXb_cX4) <= sqrt(llsr * llsr + d) <= llsr + sqrt(d)
// z = sqrt(bXb_cX4) >= sqrt(llsr * llsr - d) >= llsr - sqrt(d)
// Hence, it is sufficient to search range [llsr - sqrt(d), llsr + sqrt(d)).
// And the true value of row would also be within the range
// [res - sqrt(d), res + sqrt(d) + 1)
// as the denominator would only reduce the precision penalty.
int64_t diff =
::__double2ll_ru(::sqrt(::fabs((double)(bXb_cX4 - llsr * llsr))));
// l never exceeds (could equal to) the target row index
auto l = res > diff ? res - diff : 0;
// r is always larger than the target row index
auto r = res + diff + 1;
// binary search for the correct answer
x <<= 1; // the loop always compares with 2x, so do it once here
// invariant: the answer lies in the half-open interval [l, r)
while (l + 1 < r) {
auto m = (l + r) >> 1;
// for tril:
// b = 2f - 1, sign = 1, hence (2f + m - 1) * m / 2
// for triu:
// b = -2f - 1, sign = -1, hence (2f - m + 1) * m / 2
if (sign * (b + m) * m > x) {
// rows 0..m-1 already hold more than x elements, so m is too large
r = m;
} else {
l = m;
}
}
res = l;
}
return res;
}
// f: the number of elements in the first row of the trapezoid.
// x: the index of the target coordinates ordered by row and then column.
//
// View the tril as a top trapezoid stacked on a bottom rectangle. Assume x
// corresponds to the coordinate (row, col) in the trapezoid, where the row and
// the col both start from 0, then we have:
//
// (f + f + row - 1) * row / 2 <= x [1]
// (f + f + row) * (row + 1) / 2 > x [2]
//
// Therefore, row is the maximum integer satisfying the following inequality:
//
// (row + 2f - 1)row <= 2x
// row^2 + (2f-1)row - 2x <= 0. [3]
//
// Based on inequality [3], we have the following coefficients for formula of
// root:
// a = 1
// b = 2f - 1
// c = -2x
// There are two roots, and we should use the largest integer that does not
// exceed the root on the right. Intuitively, it is because:
// i) the valid solution range of row is between two roots, as it is <= 0;
// ii) as we count in more rows, the total # of elements should always
// increase, hence so does the left-hand side row^2 + (2f-1)row - 2x.
// Therefore, the valid range of row lies in between the nadir point and
// the larger root on the right.
// Full proof can be derived from inequality [2]. So, we calculate the result
// coordinate as:
//
// row = floor((-b + sqrt(b^2 - 4c)) / 2)
// col = x - (f + f + row - 1) * row / 2
// Maps linear index `x` inside the tril trapezoid whose first row holds `f`
// elements to its (row, col) coordinate, written to the output references.
__device__
inline void get_coordinate_in_tril_trapezoid(
    int64_t f, int64_t x, int64_t & row, int64_t & col) {
  // Every formula below works with 2f, so compute it once.
  const int64_t two_f = f << 1;
  // Coefficients of row^2 + (2f - 1) * row - 2x <= 0; 4 * c = 4 * (-2x) = -8x.
  const int64_t linear_coeff = two_f - 1;
  const int64_t const_coeff_x4 = -(x << 3);
  row = resolve_root_int(linear_coeff, const_coeff_x4, x, 1);
  // Number of elements stored in the rows above `row`.
  const int64_t elements_above = (two_f + row - 1) * row >> 1;
  col = x - elements_above;
}
// f: the number of elements in the first row of the bottom trapezoid.
// x: the index of the target coordinates ordered by row and then column.
//
// View the triu as a top rectangle stacked on a bottom trapezoid, where the
// trapezoid is upside down. Assume x corresponds to the coordinate (row, col)
// in the bottom trapezoid, where the row and the col start from 0, then we
// have:
//
// (f + f - row + 1) * row / 2 <= x [1]
// (f + f - row) * (row + 1) / 2 > x [2]
//
// Therefore, row is the maximum integer satisfying the following inequality:
//
// (-row + 2f + 1)row <= 2x
// row^2 - (2f+1)row + 2x >= 0. [3]
//
// Based on inequality [3], we have the following coefficients for formula of
// root:
// a = 1
// b = -1 - 2f
// c = 2x
// There are two roots, and we should use the largest integer that does not
// exceed the root on the left. Intuitively, it is because:
// i) the valid solution range of row is outside of the two roots, as it is
// >= 0;
// ii) as we count in more rows, the total # of elements should always
// increase, hence so does the left-hand side row^2 - (2f+1)row + 2x.
// Therefore, the valid range of row lies to the left of the smaller root
// on the left.
// Full proof can be derived from inequality [2]. So, we calculate the result
// coordinate as:
//
// row = floor((-b - sqrt(b^2 - 4c)) / 2)
// col = x - (f + f - row + 1) * row / 2
// Maps linear index `x` inside the (upside-down) triu trapezoid whose first
// row holds `f` elements to its (row, col) coordinate, written to the output
// references.
__device__
inline void get_coordinate_in_triu_trapezoid(
    int64_t f, int64_t x, int64_t & row, int64_t & col) {
  // Every formula below works with 2f, so compute it once.
  const int64_t two_f = f << 1;
  // Coefficients of row^2 - (2f + 1) * row + 2x >= 0; 4 * c = 4 * (2x) = 8x.
  const int64_t linear_coeff = -1 - two_f;
  const int64_t const_coeff_x4 = x << 3;
  row = resolve_root_int(linear_coeff, const_coeff_x4, x, -1);
  // Number of elements stored in the trapezoid rows above `row`.
  const int64_t elements_above = (two_f - row + 1) * row >> 1;
  // Each trapezoid row starts one column further right than the previous one,
  // hence the extra `+ row`.
  col = x - elements_above + row;
}
} // namespace
// Writes the coordinates of the lower-triangular elements into `tensor`.
//
// Launch layout: 1-D grid of 1-D blocks, one thread per output coordinate
// pair; threads whose index is >= tril_size do nothing.
//
// tensor:         output buffer holding 2 * tril_size elements; row indices
//                 go to [0, tril_size), column indices to
//                 [tril_size, 2 * tril_size).
// row_offset:     value added to every computed row index.
// m_first_row:    number of elements in the first row of the top trapezoid.
// col:            number of columns of the matrix.
// trapezoid_size: number of elements in the top trapezoid; indices at or
//                 beyond it belong to the bottom rectangle.
// tril_size:      total number of coordinate pairs to write.
template <typename scalar_t>
__global__
#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(512)
#endif
void tril_indices_kernel(scalar_t * tensor,
                         int64_t row_offset,
                         int64_t m_first_row,
                         int64_t col,
                         int64_t trapezoid_size,
                         int64_t tril_size) {
  // blockIdx.x and blockDim.x are 32-bit unsigned values; widen before
  // multiplying so the product cannot wrap around for very large launches.
  const int64_t linear_index =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  if (linear_index < tril_size) {
    int64_t r, c;
    if (linear_index < trapezoid_size) {
      // the coordinate is within the top trapezoid
      get_coordinate_in_tril_trapezoid(m_first_row, linear_index, r, c);
    } else {
      // the coordinate falls in the bottom rectangle
      auto surplus = linear_index - trapezoid_size;
      // add the height of trapezoid: m_last_row (col) - m_first_row + 1
      r = surplus / col + col - m_first_row + 1;
      c = surplus % col;
    }
    r += row_offset;

    tensor[linear_index] = r;
    tensor[linear_index + tril_size] = c;
  }
}
// Some large test cases for the fallback binary search path are disabled by
// default to speed up CI tests and to avoid OOM errors. When modifying the
// implementation, please enable them in test/test_cuda.py and make sure they
// pass on your local server.
// Returns a 2 x tril_size CUDA tensor with the coordinates of the
// lower-triangular part of a row x col matrix: row indices in the first row
// of the result, column indices in the second.
//
// row, col: matrix dimensions (validated by check_args).
// offset:   diagonal offset forwarded to get_tril_size.
// options:  tensor options for the result (validated by check_args).
//
// Raises if the launch grid cannot be determined or if CUDA reports an error
// after the kernel launch.
Tensor tril_indices_cuda(
    int64_t row, int64_t col, int64_t offset, const TensorOptions& options) {
  check_args(row, col, options);

  auto tril_size = get_tril_size(row, col, offset);
  auto tensor = empty_cuda({2, tril_size}, options);

  if (tril_size > 0) {
    // # of tril elements in the first row of the top trapezoid
    auto m_first_row = offset > 0 ?
      std::min<int64_t>(col, 1 + offset) : // upper bounded by col
      row + offset > 0; // either 0 or 1
    auto trapezoid_row_offset = std::max<int64_t>(0, -offset);
    auto rectangle_row_offset = trapezoid_row_offset + col - m_first_row + 1;

    // size of the bottom rectangle (rows that are completely filled)
    int64_t rectangle_size = 0;
    if (rectangle_row_offset < row) {
      rectangle_size = (row - rectangle_row_offset) * col;
    }

    dim3 dim_block = cuda::getApplyBlock();
    dim3 dim_grid;
    // using tril_size instead of tensor.numel(), as each thread takes care of
    // two elements in the tensor.
    TORCH_CHECK(
      cuda::getApplyGrid(tril_size, dim_grid, tensor.get_device()),
      "unable to get dim grid");

    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, tensor.scalar_type(), "tril_indices_cuda", [&] {
      tril_indices_kernel<<<
          dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
        tensor.data_ptr<scalar_t>(),
        trapezoid_row_offset,
        m_first_row,
        col,
        tril_size - rectangle_size,
        tril_size);
    });

    // A kernel launch does not return a status itself; query it explicitly so
    // an invalid launch configuration is reported here instead of surfacing
    // at some unrelated later CUDA call.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(
      launch_status == cudaSuccess,
      "CUDA error after launching tril_indices_kernel: ",
      cudaGetErrorString(launch_status));
  }

  return tensor;
}
// Writes the coordinates of the upper-triangular elements into `tensor`.
//
// Launch layout: 1-D grid of 1-D blocks, one thread per output coordinate
// pair; threads whose index is >= triu_size do nothing.
//
// tensor:         output buffer holding 2 * triu_size elements; row indices
//                 go to [0, triu_size), column indices to
//                 [triu_size, 2 * triu_size).
// col_offset:     value added to every computed column index.
// m_first_row:    number of elements in the first row of the bottom trapezoid.
// col:            number of columns of the matrix.
// rectangle_size: number of elements in the top rectangle; indices at or
//                 beyond it belong to the bottom trapezoid.
// triu_size:      total number of coordinate pairs to write.
template <typename scalar_t>
__global__
void triu_indices_kernel(scalar_t * tensor,
                         int64_t col_offset,
                         int64_t m_first_row,
                         int64_t col,
                         int64_t rectangle_size,
                         int64_t triu_size) {
  // blockIdx.x and blockDim.x are 32-bit unsigned values; widen before
  // multiplying so the product cannot wrap around for very large launches.
  const int64_t linear_index =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  if (linear_index < triu_size) {
    int64_t r, c;
    if (linear_index < rectangle_size) {
      // the coordinate is within the top rectangle
      r = linear_index / col;
      c = linear_index % col;
    } else {
      // the coordinate falls in the bottom trapezoid
      get_coordinate_in_triu_trapezoid(
        m_first_row, linear_index - rectangle_size, r, c);
      // skip over the rows occupied by the top rectangle
      r += rectangle_size / col;
    }

    c += col_offset;
    tensor[linear_index] = r;
    tensor[linear_index + triu_size] = c;
  }
}
// Some large test cases for the fallback binary search path are disabled by
// default to speed up CI tests and to avoid OOM errors. When modifying the
// implementation, please enable them in test/test_cuda.py and make sure they
// pass on your local server.
// Returns a 2 x triu_size CUDA tensor with the coordinates of the
// upper-triangular part of a row x col matrix: row indices in the first row
// of the result, column indices in the second.
//
// row, col: matrix dimensions (validated by check_args).
// offset:   diagonal offset; the triu size is computed as the complement of
//           the tril size for offset - 1.
// options:  tensor options for the result (validated by check_args).
//
// Raises if the launch grid cannot be determined or if CUDA reports an error
// after the kernel launch.
Tensor triu_indices_cuda(
    int64_t row, int64_t col, int64_t offset, const TensorOptions& options) {
  check_args(row, col, options);

  auto triu_size = row * col - get_tril_size(row, col, offset - 1);
  auto tensor = empty_cuda({2, triu_size}, options);

  if (triu_size > 0) {
    // # of triu elements in the first row
    auto m_first_row = offset > 0 ?
      std::max<int64_t>(col - offset, 0) : // lower bounded by 0
      col;

    // size of the top rectangle
    int64_t rectangle_size = 0;
    if (offset < 0) {
      rectangle_size = std::min<int64_t>(row, -offset) * col;
    }

    dim3 dim_block = cuda::getApplyBlock();
    dim3 dim_grid;

    // using triu_size instead of tensor.numel(), as each thread takes care of
    // two elements in the tensor.
    TORCH_CHECK(
      cuda::getApplyGrid(triu_size, dim_grid, tensor.get_device()),
      "unable to get dim grid");

    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, tensor.scalar_type(), "triu_indices_cuda", [&] {
      triu_indices_kernel<<<
          dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
        tensor.data_ptr<scalar_t>(),
        std::max<int64_t>(0, offset),
        m_first_row,
        col,
        rectangle_size,
        triu_size);
    });

    // A kernel launch does not return a status itself; query it explicitly so
    // an invalid launch configuration is reported here instead of surfacing
    // at some unrelated later CUDA call.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(
      launch_status == cudaSuccess,
      "CUDA error after launching triu_indices_kernel: ",
      cudaGetErrorString(launch_status));
  }

  return tensor;
}
}} // namespace at::native