forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CPUApplyUtils.h
343 lines (309 loc) · 10 KB
/
CPUApplyUtils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
#pragma once
#include <ATen/CollapseDims.h>
#include <ATen/Parallel.h>
#include <ATen/TensorUtils.h>
#include <c10/util/irange.h>
#include <cstring>
#include <limits>
#include <utility>
namespace at {
/*
* The basic strategy for apply is as follows:
*
* 1. Starting with the outermost index, loop until we reach a dimension where
* the data is no longer contiguous, i.e. the stride at that dimension is not
* equal to the size of the tensor defined by the outer dimensions. Let's call
* this outer (contiguous) tensor A. Note that if the Tensor is contiguous, then
* A is equal to the entire Tensor. Let's call the inner tensor B.
*
* 2. We loop through the indices in B, starting at its outermost dimension. For
* example, if B is a 2x2 matrix, then we do:
*
* B[0][0]
* B[0][1]
* B[1][0]
* B[1][1]
*
* We set the offset into the underlying storage as (storageOffset + stride_B *
* index_B), i.e. basically we compute the offset into the storage as we would
* normally for a Tensor. But because we are guaranteed the subsequent data is
* contiguous in memory, we can simply loop for sizeof(A) iterations and perform
* the operation, without having to follow the order described by the strides of
* A.
*
* 3. As an optimization, we merge dimensions of A that are contiguous in
* memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor,
* then the first two dimensions can be merged for the purposes of APPLY,
* reducing the number of nested loops.
*/
// Returns `tensor_.permute(order)`, where `order` lists the dimensions of
// `tensor_` from the largest stride to the smallest.
//
// The relative order of dimensions with equal strides is unspecified because
// std::sort is not a stable sort.
inline Tensor sort_strides(Tensor& tensor_) {
  const IntArrayRef strides = tensor_.strides();
  const auto by_descending_stride = [&strides](int64_t lhs, int64_t rhs) {
    return strides[lhs] > strides[rhs];
  };

  // Start from the identity permutation 0, 1, ..., ndim - 1.
  std::vector<int64_t> order;
  order.reserve(tensor_.ndimension());
  for (const auto d : c10::irange(tensor_.ndimension())) {
    order.push_back(d);
  }

  std::sort(order.begin(), order.end(), by_descending_stride);
  return tensor_.permute(order);
}
// Iteration state over a strided tensor, stored in fixed-size arrays of N
// slots (no heap allocation for the bookkeeping).
//
//   data_    - pointer to the element the iterator currently refers to
//   dim_     - number of dimensions reported by collapse_dims()
//   counter_ - current index within each dimension
//   sizes_   - extent of each dimension (rewritten in place by collapse_dims)
//   strides_ - element stride of each dimension (rewritten by collapse_dims)
//
// The type is move-only: copy construction and copy assignment are deleted.
template <typename T, int N>
struct strided_tensor_iter_fixed {
 public:
  T* data_ = nullptr;
  int64_t dim_ = 0;
  int64_t counter_[N] = {0};
  int64_t sizes_[N] = {0};
  int64_t strides_[N] = {0};

  strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
  void operator=(strided_tensor_iter_fixed const& x) = delete;
  strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default;

  // Captures the data pointer, sizes and strides of `tensor`, then collapses
  // the dimensions. `sort_strides` is accepted but ignored.
  //
  // NOTE(review): the copies below write tensor.dim() entries into N-slot
  // arrays, so the caller must guarantee tensor.dim() <= N. The
  // CPU_tensor_apply* entry points in this file only pick this iterator when
  // the largest dimensionality is <= 8 and instantiate it with N == 8.
  strided_tensor_iter_fixed(Tensor& tensor, bool sort_strides = false)
      : data_(tensor.data_ptr<T>()) {
    static_cast<void>(sort_strides); // Suppress unused variable warning
    std::memset(counter_, 0, sizeof(counter_));

    const auto ndim = tensor.dim();
    if (ndim > 0) {
      const auto nbytes = ndim * sizeof(int64_t);
      std::memcpy(sizes_, tensor.sizes().data(), nbytes);
      std::memcpy(strides_, tensor.strides().data(), nbytes);
    }

    const auto collapsed = collapse_dims(sizes_, strides_, tensor.ndimension());
    dim_ = std::get<1>(collapsed);
  }
};
// Iteration state over a strided tensor of arbitrary dimensionality, stored in
// heap-allocated vectors. It exposes the same members as
// strided_tensor_iter_fixed, so the apply helpers accept either type.
//
//   data_    - pointer to the element the iterator currently refers to
//   dim_     - number of dimensions reported by collapse_dims()
//   counter_ - current index within each dimension
//   sizes_   - extent of each dimension (rewritten in place by collapse_dims)
//   strides_ - element stride of each dimension (rewritten by collapse_dims)
//
// The type is move-only: copy construction and copy assignment are deleted.
template <typename T>
struct strided_tensor_iter {
 public:
  T* data_ = nullptr;
  int64_t dim_;
  std::vector<int64_t> counter_;
  std::vector<int64_t> sizes_;
  std::vector<int64_t> strides_;

  strided_tensor_iter(strided_tensor_iter const&) = delete;
  void operator=(strided_tensor_iter const& x) = delete;
  strided_tensor_iter(strided_tensor_iter&&) = default;

  // Captures the data pointer, sizes and strides of `tensor`, then collapses
  // the dimensions.
  strided_tensor_iter(Tensor& tensor)
      : data_(tensor.data_ptr<T>()),
        dim_(tensor.ndimension()),
        counter_(dim_, 0),
        sizes_(tensor.sizes().vec()),
        strides_(tensor.strides().vec()) {
    if (dim_ == 0) {
      // A 0-dim tensor would leave all three vectors empty, yet their data()
      // is handed to collapse_dims() and the apply helpers index
      // counter_/sizes_/strides_ at [dim_ - 1]. Give each vector one slot
      // describing a single element so those accesses stay in bounds.
      // NOTE(review): this presumes collapse_dims() (ATen/CollapseDims.h)
      // reports a 0-dim input as one dimension of extent 1 written to slot 0
      // - confirm against that header. The extra slot is harmless otherwise.
      counter_.assign(1, 0);
      sizes_.assign(1, 1);
      strides_.assign(1, 1);
    }
    dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_));
  }
};
// Returns true when every tensor in `tensors` has the same number of elements.
// An empty list is treated as trivially consistent.
inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
  if (tensors.size() == 0) {
    return true;
  }

  // The first tensor sets the reference count; all others must match it.
  const int64_t expected_numel = tensors[0].numel();
  for (const auto idx : c10::irange(1, tensors.size())) {
    const bool matches = tensors[idx].numel() == expected_numel;
    if (!matches) {
      return false;
    }
  }
  return true;
}
// Builds the error message used when the tensors passed to an apply function
// do not all have the same number of elements. The message lists the sizes of
// every tensor followed by the element count of every tensor.
//
// For an empty list a short fixed message is returned instead.
inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
  // `tensors.size() - 1` is unsigned arithmetic: for an empty list it would
  // wrap around and the code below would read far out of bounds. Handle that
  // case before computing the index of the last tensor.
  if (tensors.size() == 0) {
    return "inconsistent tensor size, expected at least one tensor";
  }
  const size_t last = tensors.size() - 1;

  std::ostringstream oss;
  oss << "inconsistent tensor size, expected ";
  for (size_t i = 0; i < last; i++) {
    oss << tensors[i].sizes() << ", ";
  }
  oss << "and " << tensors[last].sizes()
      << " to have the same number of elements, but got ";
  for (size_t i = 0; i < last; i++) {
    oss << tensors[i].numel() << ", ";
  }
  oss << "and " << tensors[last].numel() << " elements respectively";
  return oss.str();
}
// Validates the operands of a CPU_tensor_apply* call and tells the caller
// whether there is anything to iterate over.
//
// Checks that every tensor passes checkDeviceType(..., kCPU) and
// checkLayout(..., kStrided), and reports an error through AT_ERROR when the
// tensors do not all have the same number of elements.
//
// Returns false if any tensor has zero elements, true otherwise.
inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
  checkDeviceType("CPU_tensor_apply", tensors, kCPU);
  checkLayout("CPU_tensor_apply", tensors, kStrided);

  const bool numel_consistent = _all_equal_numel(tensors);
  if (!numel_consistent) {
    AT_ERROR(_all_equal_numel_error(tensors));
  }

  // An empty tensor has no elements
  bool has_elements = true;
  for (auto& t : tensors) {
    if (t.numel() == 0) {
      has_elements = false;
      break;
    }
  }
  return has_elements;
}
// Returns the largest ndimension() among `tensors`, or 0 for an empty list.
inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
  int64_t widest = 0;
  for (auto& t : tensors) {
    const int64_t ndim = t.ndimension();
    if (widest < ndim) {
      widest = ndim;
    }
  }
  return widest;
}
// Recursion terminator: no iterators left to advance.
inline void iterate(int64_t /*size*/) {}

// Advances every iterator by `size` elements along its last dimension
// (index dim_ - 1): the counter for that dimension grows by `size` and the
// data pointer moves by `size` times that dimension's stride. No carry into
// other dimensions is performed here; see iterate_overflow().
template <typename Arg, typename... Args>
inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
  const auto innermost = iter.dim_ - 1;
  iter.counter_[innermost] += size;
  iter.data_ += size * iter.strides_[innermost];
  iterate(size, iter_tail...);
}
// Recursion terminator: with no iterators left, nothing vetoes continuing.
inline bool iterate_continue() {
  return true;
}

// Returns true while every iterator still has elements left in its last
// dimension, i.e. its counter there is below that dimension's size. Stops at
// the first iterator that is exhausted.
template <typename Arg, typename... Args>
inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
  const auto innermost = iter.dim_ - 1;
  if (!(iter.counter_[innermost] < iter.sizes_[innermost])) {
    return false;
  }
  return iterate_continue(iter_tail...);
}
// Recursion terminator: no iterator imposes a limit.
inline int64_t max_iterate_size() {
  return std::numeric_limits<int64_t>::max();
}

// Returns the smallest number of elements remaining in the last dimension
// (size minus counter) across all given iterators.
template <typename Arg, typename... Args>
inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
  const auto innermost = iter.dim_ - 1;
  const auto remaining_here =
      iter.sizes_[innermost] - iter.counter_[innermost];
  const auto remaining_rest = max_iterate_size(iter_tail...);
  return std::min(remaining_here, remaining_rest);
}
// Recursion terminator: no iterators left to fix up.
inline void iterate_overflow() {}

// Carries counters that ran off the end of a dimension into the next outer
// dimension, for every iterator.
//
// Nothing happens for an iterator unless its last dimension is exactly
// exhausted (counter == size). In that case dimensions are visited from the
// last down to index 1; each one whose counter equals its size is reset to 0,
// the counter of the dimension before it is incremented, and the data pointer
// is rewound by the full extent of the reset dimension and advanced by one
// stride of the outer one. Dimension 0 is never reset.
template <typename Arg, typename... Args>
inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
  const int64_t innermost = iter.dim_ - 1;
  if (iter.counter_[innermost] == iter.sizes_[innermost]) {
    for (int64_t d = innermost; d > 0; --d) {
      if (iter.counter_[d] != iter.sizes_[d]) {
        continue;
      }
      const int64_t outer = d - 1;
      iter.counter_[d] = 0;
      ++iter.counter_[outer];
      iter.data_ = iter.data_ - (iter.sizes_[d] * iter.strides_[d]) +
          iter.strides_[outer];
    }
  }
  iterate_overflow(iter_tail...);
}
// Recursion terminator: no iterators left to move.
inline void forward(int64_t /*offset*/) {}

// Moves every iterator ahead by `offset` elements, where `offset` is a linear
// index over the iterator's dimensions with the last dimension varying
// fastest. The offset is decomposed from the last dimension to the first
// (remainder = step within that dimension, quotient = carried outwards); each
// step is added to the dimension's counter and, scaled by its stride, to the
// data pointer.
template <typename Arg, typename... Args>
inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
  int64_t remaining = offset;
  for (int64_t d = iter.dim_ - 1; d >= 0; --d) {
    const int64_t step = remaining % iter.sizes_[d];
    remaining = remaining / iter.sizes_[d];
    iter.data_ += step * iter.strides_[d];
    iter.counter_[d] += step;
  }
  forward(offset, iter_tail...);
}
// Recursion terminator: an empty set of iterators has dimensionality 0.
inline int64_t max_dim() {
  return 0;
}

// Returns the largest dim_ among the given iterators.
template <typename Arg, typename... Args>
inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
  const int64_t widest_of_rest = max_dim(iter_tail...);
  return std::max(iter.dim_, widest_of_rest);
}
// Overload taking no arguments; does nothing.
inline void apply_op() {}

// Calls `op` once per element, passing the current element of every iterator
// (each iterator's data pointer, dereferenced) in the order given.
//
//   numel  - number of elements to visit
//   offset - linear element offset to skip to before visiting (applied only
//            when > 0)
//   op     - callable invoked as op(*iters.data_...)
//   iters  - iterator objects, taken by value; they are advanced in place
//
// When numel is 1 and every iterator reports 0 dimensions, op is invoked
// exactly once and no iteration state is touched.
template <typename Op, typename... Args>
inline void apply_op(
    int64_t numel,
    int64_t offset,
    const Op& op,
    Args... iters) {
  // For 0-dim tensors
  const bool single_scalar = numel == 1 && max_dim(iters...) == 0;
  if (single_scalar) {
    op(*iters.data_...);
    return;
  }

  if (offset > 0) {
    forward(offset, iters...);
  }

  // Splitting this into chunks helps the compiler create faster assembly
  int64_t visited = 0;
  while (visited < numel) {
    // Inner chunk: run along the last dimension until some iterator reaches
    // the end of it (or all requested elements have been visited).
    while (iterate_continue(iters...) && visited < numel) {
      op(*iters.data_...);
      iterate(1, iters...);
      ++visited;
    }
    // Carry exhausted last dimensions into the outer dimensions.
    iterate_overflow(iters...);
  }
}
/*
Apply a pointwise operator to a sequence of tensors.
The calling convention for op is a function/functor that takes one argument
per given tensor. Each argument is the current element of the corresponding
tensor, passed as a dereferenced element pointer (i.e. by reference, not as a
pointer). For example, to compute a = b * c, op would be of the form:
[](scalar& a_val, const scalar& b_val, const scalar& c_val) { a_val =
b_val * c_val; };
*/
// Applies `op` element-wise over two tensors. `scalar1` / `scalar2` are the
// element types used to obtain the data pointers of `tensor1` / `tensor2`.
//
// Returns without calling `op` when _apply_preamble() reports there is nothing
// to iterate over. Tensors with at most 8 dimensions use the fixed-size
// iterator; anything larger falls back to the vector-backed iterator.
template <typename scalar1, typename scalar2, typename Op>
inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
  const bool has_work = _apply_preamble({tensor1, tensor2});
  if (!has_work) {
    return;
  }

  const int64_t numel = tensor1.numel();
  const bool fits_fixed_iter = _max_dim_tensors({tensor1, tensor2}) <= 8;
  if (!fits_fixed_iter) {
    apply_op(
        numel,
        0,
        op,
        strided_tensor_iter<scalar1>(tensor1),
        strided_tensor_iter<scalar2>(tensor2));
    return;
  }
  apply_op(
      numel,
      0,
      op,
      strided_tensor_iter_fixed<scalar1, 8>(tensor1),
      strided_tensor_iter_fixed<scalar2, 8>(tensor2));
}
// Applies `op` element-wise over three tensors. `scalar1` .. `scalar3` are the
// element types used to obtain the data pointers of `tensor1` .. `tensor3`.
//
// Returns without calling `op` when _apply_preamble() reports there is nothing
// to iterate over. Tensors with at most 8 dimensions use the fixed-size
// iterator; anything larger falls back to the vector-backed iterator.
template <typename scalar1, typename scalar2, typename scalar3, typename Op>
inline void CPU_tensor_apply3(
    Tensor tensor1,
    Tensor tensor2,
    Tensor tensor3,
    const Op op) {
  const bool has_work = _apply_preamble({tensor1, tensor2, tensor3});
  if (!has_work) {
    return;
  }

  const int64_t numel = tensor1.numel();
  const bool fits_fixed_iter =
      _max_dim_tensors({tensor1, tensor2, tensor3}) <= 8;
  if (!fits_fixed_iter) {
    apply_op(
        numel,
        0,
        op,
        strided_tensor_iter<scalar1>(tensor1),
        strided_tensor_iter<scalar2>(tensor2),
        strided_tensor_iter<scalar3>(tensor3));
    return;
  }
  apply_op(
      numel,
      0,
      op,
      strided_tensor_iter_fixed<scalar1, 8>(tensor1),
      strided_tensor_iter_fixed<scalar2, 8>(tensor2),
      strided_tensor_iter_fixed<scalar3, 8>(tensor3));
}
// Applies `op` element-wise over four tensors. `scalar1` .. `scalar4` are the
// element types used to obtain the data pointers of `tensor1` .. `tensor4`.
//
// Returns without calling `op` when _apply_preamble() reports there is nothing
// to iterate over. Tensors with at most 8 dimensions use the fixed-size
// iterator; anything larger falls back to the vector-backed iterator.
template <
    typename scalar1,
    typename scalar2,
    typename scalar3,
    typename scalar4,
    typename Op>
inline void CPU_tensor_apply4(
    Tensor tensor1,
    Tensor tensor2,
    Tensor tensor3,
    Tensor tensor4,
    const Op op) {
  const bool has_work = _apply_preamble({tensor1, tensor2, tensor3, tensor4});
  if (!has_work) {
    return;
  }

  const int64_t numel = tensor1.numel();
  const bool fits_fixed_iter =
      _max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8;
  if (!fits_fixed_iter) {
    apply_op(
        numel,
        0,
        op,
        strided_tensor_iter<scalar1>(tensor1),
        strided_tensor_iter<scalar2>(tensor2),
        strided_tensor_iter<scalar3>(tensor3),
        strided_tensor_iter<scalar4>(tensor4));
    return;
  }
  apply_op(
      numel,
      0,
      op,
      strided_tensor_iter_fixed<scalar1, 8>(tensor1),
      strided_tensor_iter_fixed<scalar2, 8>(tensor2),
      strided_tensor_iter_fixed<scalar3, 8>(tensor3),
      strided_tensor_iter_fixed<scalar4, 8>(tensor4));
}
} // namespace at