forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
BatchedTensorImpl.h
160 lines (133 loc) · 5.42 KB
/
BatchedTensorImpl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#pragma once
#include <bitset>
#include <ATen/ArrayRef.h>
#include <ATen/SmallVector.h>
#include <ATen/Tensor.h>
namespace at {
// Number of bits in the per-dimension bitset built by createBatchDimBitset
// below, i.e. the largest physical dimension index (exclusive) that the vmap
// helpers in this header can represent.
// We assume this in a few other places in the codebase,
// but there isn't a centralized definition.
constexpr int64_t kVmapMaxTensorDims = 64;
// The valid vmap levels range from [0, 64). This effectively means that we
// support a maximum of 64 nested vmaps.
// (Also the size of the std::bitset built by createVmapLevelsBitset below.)
constexpr int64_t kVmapNumLevels = 64;
// Store this number of elements of BatchDims on the stack. Most people will
// probably use <= 5 nested vmaps, but adjust this number as necessary.
// (Used as the inline capacity of the BatchDims SmallVector alias below.)
constexpr int64_t kBatchDimsStackSize = 5;
// a BatchDim represents a "private" dimension on a Tensor created inside of
// vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension
// is being vmap'ed over and the `level` being an identifier for which vmap
// said dimension was created inside. The `dim` corresponds to a "physical
// dim" - it is a dimension index on the underlying physical tensor that is
// being vmapped over.
struct BatchDim {
  // Note the argument order is (level, dim) while the members are stored as
  // (dim_, level_); the initializer list follows declaration order.
  BatchDim(int64_t level, int64_t dim) : dim_{dim}, level_{level} {}

  // Physical dimension index on the underlying tensor that is vmapped over.
  auto dim() const -> int64_t { return dim_; }

  // Identifier of the vmap inside which this batch dimension was created.
  auto level() const -> int64_t { return level_; }

 private:
  int64_t dim_;
  int64_t level_;
};
// Owning list of BatchDim; up to kBatchDimsStackSize elements are stored
// inline (on the stack) before the SmallVector needs to allocate.
using BatchDims = SmallVector<BatchDim, kBatchDimsStackSize>;
// Non-owning view over a contiguous sequence of BatchDim (e.g. a BatchDims);
// it must not outlive the storage it refers to.
using BatchDimsRef = ArrayRef<BatchDim>;
// A BatchedTensorImpl holds an underlying Tensor and a list of BatchDims.
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
//
// The batch dimensions are treated as being "private"; they are not
// user-visible. For example, in the following Tensor,
//   bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
// dimensions 0 and 1 are batch dimensions.
//
// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
// dim 0, which is equivalent to dim 2 in the underlying ones(2, 3, 5, 7)
// tensor (public dims 0 and 1 map to physical dims 2 and 3).
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
  // Wraps `value`, treating the physical dimensions named in `bdims` as
  // private batch dimensions. See NOTE: [BatchedTensorImpl levels invariant]
  // for the required ordering of `bdims`.
  explicit BatchedTensorImpl(Tensor value, BatchDims bdims);

  // Returns a reference to BatchDims that represent which dimensions of this
  // tensor are private. The returned view refers to storage owned by this
  // object and must not outlive it.
  BatchDimsRef bdims() const {
    return bdims_;
  }

  // BatchedTensorImpl wraps a Tensor; returns the underlying (physical) tensor.
  const Tensor& value() const {
    return value_;
  }

  // Given a public dimension index, return the dimension index in the
  // underlying value() tensor. For example, if we have
  //   bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=2)])
  //   bt.actualDim(0) -> 1
  //   bt.actualDim(1) -> 3
  //   bt.actualDim(2) -> Error
  // NOTE(review): `wrap_dim` presumably enables wrapping of negative public
  // dims (e.g. -1 -> last public dim) before mapping; confirm in the .cpp.
  int64_t actualDim(int64_t dim, bool wrap_dim = true) const;

  // We have to override this because we opted into CustomStrides
  IntArrayRef strides_custom() const override;
  // Override a bunch of methods inherited from TensorImpl to return error
  // messages.
  bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;
#ifdef DEBUG
  bool has_storage() const override;
#endif

 private:
  // see NOTE: [BatchedTensorImpl levels invariant]
  void checkInvariants() const;
  const char* tensorimpl_type_name() const override;

  Tensor value_;

  // Note: [BatchedTensorImpl levels invariant]
  // There is an invariant that the BatchDims must be stored in increasing
  // `level` order. That is, for i < j, bdims_[i].level must be less than
  // bdims_[j].level.
  BatchDims bdims_;
};
// Returns true iff the TensorImpl behind `tensor` carries the Batched
// dispatch key. By the convention of this header, such a Tensor is a
// "BatchedTensor", i.e. one backed by a BatchedTensorImpl.
inline bool isBatchedTensor(const Tensor& tensor) {
  const auto* impl = tensor.unsafeGetTensorImpl();
  return impl->key_set().has(DispatchKey::Batched);
}
// Returns the BatchedTensorImpl backing `tensor` without checking that it is
// one. It is unsafe to call this on a Tensor that is not backed by a
// BatchedTensorImpl (the static_cast would then be invalid). Please use
// `maybeGetBatchedImpl` whenever possible.
//
// `tensor` is taken by const reference. Taking it by value would copy the
// (refcounted) Tensor handle on every call, although only the raw impl
// pointer is needed. The returned pointer is non-owning and stays valid only
// while some Tensor keeps the impl alive.
inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
  return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
}
// Returns the BatchedTensorImpl backing `tensor`, or nullptr if `tensor` is
// not a BatchedTensor (as decided by isBatchedTensor). This is the checked
// alternative to unsafeGetBatchedImpl. The returned pointer is non-owning.
//
// `tensor` is taken by const reference to avoid copying the refcounted
// Tensor handle on every call.
inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
  if (!isBatchedTensor(tensor)) {
    return nullptr;
  }
  return unsafeGetBatchedImpl(tensor);
}
// Builds a mask over physical dimensions: bit i is set exactly when some
// BatchDim in `bdims` has dim() == i, i.e. physical dim i is a batch dim.
// std::bitset::set throws std::out_of_range for an index that is not below
// kVmapMaxTensorDims. A negative dim() converts to a huge size_t and also
// throws.
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(
    BatchDimsRef bdims) {
  std::bitset<kVmapMaxTensorDims> batch_dim_mask;
  for (size_t idx = 0; idx < bdims.size(); ++idx) {
    batch_dim_mask.set(bdims[idx].dim());
  }
  return batch_dim_mask;
}
// Builds a mask over vmap levels: bit l is set exactly when some BatchDim in
// `bdims` has level() == l. As with createBatchDimBitset, std::bitset::set
// throws std::out_of_range if a level is not below kVmapNumLevels.
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
  std::bitset<kVmapNumLevels> level_mask;
  for (auto it = bdims.begin(); it != bdims.end(); ++it) {
    level_mask.set(it->level());
  }
  return level_mask;
}
// Streams a BatchDim as "(lvl=<level>, dim=<dim>)", e.g. "(lvl=1, dim=0)",
// and returns `out` to allow chaining.
inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
  return out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
}
// Use this to construct a BatchedTensor from a regular Tensor.
// `bdims` names the physical dims of `tensor` that become private batch dims.
// NOTE(review): presumably `bdims` must satisfy the BatchedTensorImpl levels
// invariant (increasing `level` order); confirm in the .cpp.
TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);
// Adds a batch dim to `tensor`, returning a BatchedTensor.
// `level` identifies the vmap and `dim` the physical dimension index, as in
// BatchDim. NOTE(review): the handling of an already-batched `tensor` is not
// visible from this header; confirm in the .cpp.
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
// Checks if an inplace operation on self and other is "vmap compatible".
// See NOTE: [vmap-incompatible in-place operations] for the definition of this.
TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
} // namespace at