forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 2
/
ForeachUtils.h
217 lines (179 loc) · 7.15 KB
/
ForeachUtils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
#pragma once
#include <ATen/ATen.h>
#include <cstddef>
namespace at {
namespace native {
namespace {
// Validates a single tensor list for the foreach API.
// Raises (via TORCH_CHECK) when the list is empty or when any tensor's dtype
// differs from the dtype of the first tensor in the list.
void check_foreach_api_restrictions(TensorList tensors) {
  TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");

  // The first tensor defines the dtype every other tensor is compared against.
  const auto reference_dtype = tensors[0].dtype();
  for (auto it = tensors.begin(); it != tensors.end(); ++it) {
    TORCH_CHECK(it->dtype() == reference_dtype, "All tensors in the tensor list must have the same dtype.");
  }
}
// Validates a tensor list together with a list of scalars.
// Runs the single-list checks first, then requires the scalar list to have
// exactly as many elements as the tensor list. Raises via TORCH_CHECK.
void check_foreach_api_restrictions(TensorList tensors, ArrayRef<double> scalars) {
  check_foreach_api_restrictions(tensors);

  const auto tensor_count = tensors.size();
  const auto scalar_count = scalars.size();
  TORCH_CHECK(tensor_count == scalar_count, "Tensor list must have same number of elements as scalar list.");
}
// Validates a pair of tensor lists for the foreach API.
// Raises (via TORCH_CHECK) when:
//  - either list is empty,
//  - the lists have a different number of tensors,
//  - any tensor in either list has a dtype other than that of tensors1[0],
//  - corresponding tensors have different sizes.
void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2) {
  TORCH_CHECK(tensors1.size() > 0, "Tensor list must have at least one tensor.");
  TORCH_CHECK(tensors2.size() > 0, "Tensor list must have at least one tensor.");
  TORCH_CHECK(tensors1.size() == tensors2.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors2.size());

  const auto expected_dtype = tensors1[0].dtype();

  // Index with size_t so the comparison against size() is unsigned-vs-unsigned
  // and the index cannot be narrower than the list length.
  for (size_t i = 0; i < tensors1.size(); ++i) {
    TORCH_CHECK(tensors1[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
    TORCH_CHECK(tensors2[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
    TORCH_CHECK(tensors1[i].sizes() == tensors2[i].sizes(), "Corresponding tensors in lists must have the same size, got ", tensors1[i].sizes(), " and ", tensors2[i].sizes());
  }
}
// Validates three tensor lists for the foreach API.
// Raises (via TORCH_CHECK) when:
//  - any list is empty,
//  - tensors2 or tensors3 differ in length from tensors1,
//  - any tensor in any of the three lists has a dtype other than that of
//    tensors1[0],
//  - tensors at the same index have different sizes.
void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, TensorList tensors3) {
  TORCH_CHECK(tensors1.size() > 0, "Tensor list must have at least one tensor.");
  TORCH_CHECK(tensors2.size() > 0, "Tensor list must have at least one tensor.");
  TORCH_CHECK(tensors3.size() > 0, "Tensor list must have at least one tensor.");
  TORCH_CHECK(tensors1.size() == tensors2.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors2.size());
  TORCH_CHECK(tensors1.size() == tensors3.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors3.size());

  const auto expected_dtype = tensors1[0].dtype();

  // Index with size_t so the comparison against size() is unsigned-vs-unsigned.
  for (size_t i = 0; i < tensors1.size(); ++i) {
    TORCH_CHECK(tensors1[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
    TORCH_CHECK(tensors2[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
    // The third list is held to the same dtype rule as the first two;
    // previously its dtype went unchecked.
    TORCH_CHECK(tensors3[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
    TORCH_CHECK(tensors1[i].sizes() == tensors2[i].sizes(), "Corresponding tensors in lists must have the same size, got ", tensors1[i].sizes(), " and ", tensors2[i].sizes());
    TORCH_CHECK(tensors1[i].sizes() == tensors3[i].sizes(), "Corresponding tensors in lists must have the same size, got ", tensors1[i].sizes(), " and ", tensors3[i].sizes());
  }
}
// Validates three tensor lists together with a list of scalars.
// Runs the three-list checks first, then requires the scalar list to have
// exactly as many elements as tensors1. Raises via TORCH_CHECK.
void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<double> scalars) {
  check_foreach_api_restrictions(tensors1, tensors2, tensors3);

  const bool counts_match = (tensors1.size() == scalars.size());
  TORCH_CHECK(counts_match, "Tensor list must have same number of elements as scalar list, got ", tensors1.size(), " and ", scalars.size());
}
// To go via 'fast' path, several conditions must be satisfied
// - All tensors must be on the same device
// - All tensors must have strided layout
// - All tensors must be non-overlapping and dense
// - Resulting tensor must have the same dtype as the input one
// Returns true only if every tensor in `tensors`:
//  - is on `expected_device`,
//  - has strided layout,
//  - is non-overlapping and dense,
//  - has the same strides as the first tensor in `tensors`.
// An empty list has nothing to violate these conditions and yields true.
bool has_same_attributes(Device expected_device, TensorList tensors) {
  // Guard the tensors[0] access below against an empty list.
  if (tensors.size() == 0) {
    return true;
  }

  auto expected_strides = tensors[0].strides();
  for (const auto& t : tensors) {
    if (t.device() != expected_device) {
      return false;
    }

    if (t.layout() != at::kStrided) {
      return false;
    }

    if (!t.is_non_overlapping_and_dense()) {
      return false;
    }

    if (t.strides() != expected_strides) {
      return false;
    }
  }

  return true;
}
// Reports whether `scalar` can be combined with `tensor` without falling into
// one of the scalar/tensor combinations listed below.
// NOTE: the return value reads inverted relative to the name: the function
// returns FALSE for the listed combinations and TRUE otherwise. Callers in
// this file reject the fast route when it returns false.
bool will_promote_tensor(const Tensor& tensor, Scalar scalar) {
  // complex scalar + integral or boolean tensor will result in complex tensor
  // float scalar + integral or boolean tensor will result in float tensor
  const bool scalar_is_complex_or_float = scalar.isComplex() || scalar.isFloatingPoint();
  if (scalar_is_complex_or_float &&
      at::isIntegralType(tensor.scalar_type(), /*includeBool*/ true)) {
    return false;
  }

  // integral scalar + boolean tensor will result in integral tensor
  const bool integral_scalar_on_bool_tensor =
      scalar.isIntegral(/*includeBool*/ false) && tensor.dtype() == at::kBool;
  return !integral_scalar_on_bool_tensor;
}
// Decides whether a single tensor list qualifies for the fast route.
// Always returns false when __HIP_PLATFORM_HCC__ is defined. Otherwise returns
// true only if each tensor individually passes has_same_attributes() against
// the device of the first tensor.
// Reads tensors[0] unconditionally, so the list must be non-empty.
bool can_use_fast_route(TensorList tensors) {
#ifdef __HIP_PLATFORM_HCC__
  return false;
#else
  auto expected_device = tensors[0].device();

  // Bind by const reference to avoid copying each Tensor on every iteration.
  for (const auto& t : tensors) {
    if (!has_same_attributes(expected_device, {t})) {
      return false;
    }
  }

  return true;
#endif
}
// Decides whether a tensor list combined with a single scalar qualifies for
// the fast route.
// Always returns false when __HIP_PLATFORM_HCC__ is defined. Otherwise returns
// true only if each tensor passes has_same_attributes() against the device of
// the first tensor and will_promote_tensor(t, scalar) returns true for it.
// Reads tensors[0] unconditionally, so the list must be non-empty.
bool can_use_fast_route(TensorList tensors, Scalar scalar) {
#ifdef __HIP_PLATFORM_HCC__
  return false;
#else
  auto expected_device = tensors[0].device();

  // Bind by const reference to avoid copying each Tensor on every iteration.
  for (const auto& t : tensors) {
    if (!has_same_attributes(expected_device, {t})) {
      return false;
    }

    if (!will_promote_tensor(t, scalar)) {
      return false;
    }
  }

  return true;
#endif
}
// Scalar-list overload of the fast-route check.
// The scalars are not inspected here; the decision is delegated entirely to
// the tensor-list-only overload.
bool can_use_fast_route(TensorList tensors, ArrayRef<double> scalars) {
  const bool tensors_qualify = can_use_fast_route(tensors);
  return tensors_qualify;
}
// Decides whether a pair of tensor lists qualifies for the fast route.
// Always returns false when __HIP_PLATFORM_HCC__ is defined. Otherwise returns
// true only if, for every index i, the pair {tensors1[i], tensors2[i]} passes
// has_same_attributes() against the device of tensors1[0].
// Indexes tensors2 by tensors1's length and reads tensors1[0] unconditionally,
// so the lists must be non-empty and of equal length.
bool can_use_fast_route(TensorList tensors1, TensorList tensors2) {
#ifdef __HIP_PLATFORM_HCC__
  return false;
#else
  auto expected_device = tensors1[0].device();

  // size_t matches the unsigned type returned by size(), avoiding a
  // signed/unsigned comparison.
  for (size_t i = 0; i < tensors1.size(); ++i) {
    if (!has_same_attributes(expected_device, {tensors1[i], tensors2[i]})) {
      return false;
    }
  }

  return true;
#endif
}
// Decides whether a pair of tensor lists combined with a single scalar
// qualifies for the fast route.
// Always returns false when __HIP_PLATFORM_HCC__ is defined. Otherwise returns
// true only if, for every index i, the pair {tensors1[i], tensors2[i]} passes
// has_same_attributes() against the device of tensors1[0] and
// will_promote_tensor(tensors1[i], scalar) returns true.
// Indexes tensors2 by tensors1's length and reads tensors1[0] unconditionally,
// so the lists must be non-empty and of equal length.
bool can_use_fast_route(TensorList tensors1, TensorList tensors2, Scalar scalar) {
#ifdef __HIP_PLATFORM_HCC__
  return false;
#else
  auto expected_device = tensors1[0].device();

  // size_t matches the unsigned type returned by size(), avoiding a
  // signed/unsigned comparison.
  for (size_t i = 0; i < tensors1.size(); ++i) {
    if (!has_same_attributes(expected_device, {tensors1[i], tensors2[i]})) {
      return false;
    }

    if (!will_promote_tensor(tensors1[i], scalar)) {
      return false;
    }
  }

  return true;
#endif
}
// Decides whether three tensor lists qualify for the fast route.
// Always returns false when __HIP_PLATFORM_HCC__ is defined. Otherwise returns
// true only if, for every index i, the triple
// {tensors1[i], tensors2[i], tensors3[i]} passes has_same_attributes() against
// the device of tensors1[0].
// Indexes tensors2/tensors3 by tensors1's length and reads tensors1[0]
// unconditionally, so the lists must be non-empty and of equal length.
bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3) {
#ifdef __HIP_PLATFORM_HCC__
  return false;
#else
  auto expected_device = tensors1[0].device();

  // size_t matches the unsigned type returned by size(), avoiding a
  // signed/unsigned comparison.
  for (size_t i = 0; i < tensors1.size(); ++i) {
    if (!has_same_attributes(expected_device, {tensors1[i], tensors2[i], tensors3[i]})) {
      return false;
    }
  }

  return true;
#endif
}
// Decides whether three tensor lists combined with a single scalar qualify
// for the fast route.
// Always returns false when __HIP_PLATFORM_HCC__ is defined. Otherwise returns
// true only if, for every index i, the triple
// {tensors1[i], tensors2[i], tensors3[i]} passes has_same_attributes() against
// the device of tensors1[0] and will_promote_tensor(tensors1[i], scalar)
// returns true.
// Indexes tensors2/tensors3 by tensors1's length and reads tensors1[0]
// unconditionally, so the lists must be non-empty and of equal length.
bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3, Scalar scalar) {
#ifdef __HIP_PLATFORM_HCC__
  return false;
#else
  auto expected_device = tensors1[0].device();

  // size_t matches the unsigned type returned by size(), avoiding a
  // signed/unsigned comparison.
  for (size_t i = 0; i < tensors1.size(); ++i) {
    if (!has_same_attributes(expected_device, {tensors1[i], tensors2[i], tensors3[i]})) {
      return false;
    }

    if (!will_promote_tensor(tensors1[i], scalar)) {
      return false;
    }
  }

  return true;
#endif
}
// Scalar-list overload of the three-list fast-route check.
// The scalars are not inspected here; the decision is delegated entirely to
// the three-tensor-list overload.
bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<double> scalars) {
  const bool lists_qualify = can_use_fast_route(tensors1, tensors2, tensors3);
  return lists_qualify;
}
}
}} // at::native