forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
NamedTensorUtils.h
165 lines (131 loc) · 5.61 KB
/
NamedTensorUtils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#pragma once
#include <ATen/NamedTensor.h>
#include <ATen/TensorNames.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/DimVector.h>
#include <algorithm>
#include <functional>
namespace at {
// Alias for a sequence of Dimname values: a SmallVector whose second template
// argument is kDimVectorStaticSize.
using NameVector = SmallVector<Dimname, kDimVectorStaticSize>;
// Returns true if at least one tensor in `tensors` reports has_names().
// An empty list yields false, since std::any_of is false on an empty range.
// std::any_of is declared in <algorithm>, which this header must include
// directly rather than rely on a transitive include.
inline bool has_names(TensorList tensors) {
  return std::any_of(
      tensors.begin(), tensors.end(), [](const Tensor& t) { return t.has_names(); });
}
// Converts dim to a positional index. Errors if `dim` cannot be used to
// refer to any dimension of tensor.
// Returns the position as an int64_t.
TORCH_API int64_t dimname_to_position(const Tensor& tensor, Dimname dim);
// List counterpart of dimname_to_position: takes a DimnameList and returns a
// vector of int64_t positions.
// NOTE(review): presumably one position per entry of `dims`, with the same
// error behavior as dimname_to_position — confirm against the implementation.
TORCH_API std::vector<int64_t> dimnames_to_positions(const Tensor& tensor, DimnameList dims);
// Unifies two DimnameList to produce a third. This is useful for implementing
// the named inference rule for binary broadcasting operations like add.
//
// There are three main constraints:
// 1) Check matching: Names must match positionally from the right.
// 2) Check misaligned: If a name `n` is in `names`, then it must appear at
// the same index from the right in other.
// 3) The output names are obtained by unifying the names individually from the right.
//
// `action` defaults to "broadcast".
// NOTE(review): presumably `action` names the operation in error messages —
// confirm against the implementation.
TORCH_API std::vector<Dimname>
unify_from_right(DimnameList names, DimnameList other, const char* action = "broadcast");
// Unconditionally fails a TORCH_CHECK to report that an op was given a dimname
// where it only supports a dimension index. `op_name` is placed at the front
// of the message. Declared [[noreturn]]: callers can rely on control never
// coming back from this function.
[[noreturn]] inline void reportNYIDimnameOverload(const char* op_name) {
  TORCH_CHECK(
      false,
      op_name,
      ": You passed a dimname (string) to this op in place of a dimension "
      "index but it does not yet support this behavior. Please pass a dimension "
      "index to work around this.");
}
// [NOTE] Writing name inference rules
//
// Operators that support named tensors are either composed of operations that
// support named tensors or implement some name inference rule. An op that
// implements its own name inference rule generally looks like the following:
//
// Tensor op(...) {
// perform_shape_checks(...);
// # (1)
// auto maybe_outnames = compute_outnames(...);
// auto result = [&]() {
// NoNamesGuard guard;
// return op_impl(...);
// }();
// # (2)
// propagate_names_if_nonempty(result, maybe_outnames);
// return result;
// }
//
// Each op has (1) a compute outnames step and (2) a propagate names step.
//
// compute_outnames is responsible for checking that input names match and
// determining what the output names should be. It returns either:
// - {} (if the input tensors are all unnamed)
// - non-empty outnames.
//
// propagate_names_if_nonempty propagates the outnames if they exist to the result
// tensors.
//
// The {} case is an optimization; if the user does not use named tensors they
// pay no perf cost for it.
namespace namedinference {
// Propagates `maybe_names` to `result` if `maybe_names` is not empty.
// `maybe_names` can be empty; see [NOTE] Writing name inference rules
// If `maybe_names` is not empty, `maybe_names.size()` should equal `result.dim()`.
// When in doubt, use this overload instead of the others.
// `validate_names` defaults to false; its effect is defined by the
// implementation, which is not visible in this header.
// NOTE(review): the returned const Tensor& is presumably `result` itself —
// confirm against the implementation.
TORCH_API const Tensor& propagate_names_if_nonempty(
const Tensor& result,
DimnameList maybe_names,
bool validate_names = false);
// Propagates `names` to `result`. Only use this if we are certain that there are
// names to propagate (that names is not empty).
// `validate_names` defaults to false; its effect is defined by the
// implementation, which is not visible in this header.
// NOTE(review): the returned const Tensor& is presumably `result` itself —
// confirm against the implementation.
TORCH_API const Tensor& propagate_names(
const Tensor& result,
DimnameList names,
bool validate_names = false);
// Propagates all names from src to result.
// Unlike the DimnameList overloads, this one returns void.
TORCH_API void propagate_names(const Tensor& result, const Tensor& src);
// Propagates all names except for those at the excluded_idxs.
// Names flow from `src` to `result`.
// NOTE(review): `excluded_idxs` presumably indexes dimensions of `src` —
// confirm against the implementation.
TORCH_API void propagate_names_except(const Tensor& result, const Tensor& src, IntArrayRef excluded_idxs);
// Used for reduction ops that have a `keepdim` arg.
// Takes the reduction's `result`, its input `src`, the indices in
// `excluded_idxs`, and the op's `keepdim` flag.
// NOTE(review): presumably `excluded_idxs` are the reduced dimensions and
// `keepdim` decides whether their names are dropped — confirm against the
// implementation.
TORCH_API void propagate_names_for_reduction(const Tensor& result, const Tensor& src, IntArrayRef excluded_idxs, bool keepdim);
// Name propagation for expand.
// NOTE(review): presumably carries the names of `self` onto `result`; how
// newly added leading dimensions are named is not visible here — confirm
// against the implementation.
TORCH_API void propagate_names_for_expand(const Tensor& result, const Tensor& self);
// Compute-outnames step (see [NOTE] Writing name inference rules) for cat:
// takes the list of input tensors and returns the output names.
// NOTE(review): presumably returns {} when no input is named — confirm
// against the implementation.
TORCH_API std::vector<Dimname> compute_cat_outnames(TensorList tensors);
// Compute-outnames step (see [NOTE] Writing name inference rules) for a
// binary broadcasting op over `self` and `other`.
// NOTE(review): presumably built on unify_from_right — confirm against the
// implementation.
TORCH_API std::vector<Dimname> compute_broadcast_outnames(
const Tensor& self,
const Tensor& other);
// Returns output names for `tensor` with respect to `reference_tensor`.
// NOTE(review): presumably the names after broadcasting `tensor` to the shape
// of `reference_tensor`, with `op_name` used in error messages — confirm
// against the implementation.
TORCH_API std::vector<Dimname> broadcast_to_outnames(
const Tensor& tensor,
const Tensor& reference_tensor,
const char* op_name);
// Compute-outnames step (see [NOTE] Writing name inference rules) for matmul
// of `self` and `other`.
TORCH_API std::vector<Dimname> compute_matmul_outnames(const Tensor& self, const Tensor& other);
// Compute-outnames step (see [NOTE] Writing name inference rules) for cdist
// of `self` and `other`.
TORCH_API std::vector<Dimname> compute_cdist_outnames(const Tensor& self, const Tensor& other);
// Compute-outnames step (see [NOTE] Writing name inference rules) for bmm of
// `self` and `other`.
// NOTE(review): `result` is taken by non-const reference, unlike the inputs;
// whether it is modified is not visible in this header — confirm against the
// implementation.
TORCH_API std::vector<Dimname> compute_bmm_outnames(
Tensor& result,
const Tensor& self,
const Tensor& other);
// Compute-outnames step (see [NOTE] Writing name inference rules) for squeeze
// applied to `tensor`.
TORCH_API std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor);
// Compute-outnames step (see [NOTE] Writing name inference rules) for diagonal
// taken over dimensions `dim1` and `dim2` of `tensor`.
// NOTE(review): this is the only declaration in this namespace without
// TORCH_API — confirm whether leaving it unexported is intentional.
std::vector<Dimname> compute_diagonal_outnames(
const Tensor& tensor,
int64_t dim1,
int64_t dim2);
// TensorImpl* overloads for Legacy TH/THC code. Use these sparingly.
// They mirror the Tensor-based propagate_names_if_nonempty / propagate_names
// declarations above, taking and returning raw TensorImpl pointers instead.
// NOTE(review): the returned TensorImpl* is presumably `result` — confirm
// against the implementation.
TORCH_API TensorImpl* propagate_names_if_nonempty(
TensorImpl* result,
DimnameList maybe_names,
bool validate_names = false);
TORCH_API TensorImpl* propagate_names(
TensorImpl* result,
DimnameList names,
bool validate_names = false);
// `src` is a non-const pointer; the commented-out `const` marks it as
// conceptually read-only.
TORCH_API void propagate_names(TensorImpl* result, /*const */TensorImpl* src);
// result = m1 @ m2 + bias
// Returns a vector of Dimname for that expression; despite the
// `propagate_` prefix, no result tensor is passed in.
TORCH_API std::vector<Dimname> propagate_names_for_addmm(
const Tensor& m1,
const Tensor& m2,
const Tensor& bias);
// Returns a vector of Dimname computed from `mat`, `vec` and `bias`; no
// result tensor is passed in.
// NOTE(review): by analogy with propagate_names_for_addmm this is presumably
// result = mat @ vec + bias — confirm against the implementation.
TORCH_API std::vector<Dimname> propagate_names_for_addmv(
const Tensor& mat,
const Tensor& vec,
const Tensor& bias);
// Name check for dot on two TensorImpl operands; returns nothing.
// NOTE(review): presumably errors when the names of `vec1` and `vec2` are
// incompatible — confirm against the implementation.
TORCH_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2);
// Compute-outnames step (see [NOTE] Writing name inference rules) for baddbmm
// over `self`, `other` and `bias`.
// NOTE(review): `result` is taken by non-const reference, unlike the inputs;
// whether it is modified is not visible in this header — confirm against the
// implementation.
TORCH_API std::vector<Dimname> compute_baddbmm_outnames(
Tensor& result,
const Tensor& self,
const Tensor& other,
const Tensor& bias);
// Returns whether the names of `self` and `other` are equal.
// NOTE(review): how unnamed tensors compare is not visible in this header —
// confirm against the implementation.
TORCH_API bool are_names_equal(TensorImpl* self, TensorImpl* other);
} // namespace namedinference
} // namespace at