-
Notifications
You must be signed in to change notification settings - Fork 0
/
variable.cpp
688 lines (607 loc) · 27.5 KB
/
variable.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/InferenceMode.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/functions/tensor.h>
#include <torch/csrc/autograd/generated/Functions.h>
#include <torch/csrc/autograd/utils/error_messages.h>
#include <ATen/core/VariableHooksInterface.h>
#include <ATen/ATen.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/FuncTorchTLS.h>
#include <c10/util/Exception.h>
#include <list>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <vector>
#include <typeinfo>
#include <iostream>
namespace torch {
namespace autograd {
/// Builds the autograd metadata attached to a view tensor.
///
/// @param self_impl        TensorImpl of the view; forwarded to AutogradMeta.
/// @param backward_info    Optional view info used for backward-mode AD.
/// @param forward_info     Optional view info used for forward-mode AD.
/// @param shared_view_info When true, a backward view info must be present and
///                         the forward view info must be empty (asserted below).
/// @param creation_meta    Records the context in which the view was created.
DifferentiableViewMeta::DifferentiableViewMeta(
    at::TensorImpl* self_impl,
    c10::optional<ViewInfo> backward_info,
    c10::optional<ViewInfo> forward_info,
    bool shared_view_info,
    CreationMeta creation_meta)
    : AutogradMeta(self_impl),
      backward_info_(std::move(backward_info)),
      forward_info_(std::move(forward_info)),
      shared_view_info_(shared_view_info),
      creation_meta_(creation_meta) {
  is_view_ = true;
  if (backward_info_.has_value()) {
    // A backward view adopts its base's version counter; remember the version
    // observed now so later changes can be detected through attr_version_.
    self_impl->set_version_counter(impl::version_counter(backward_info_.value().base_));
    attr_version_ = self_impl->version_counter().current_version();
  }
  if (shared_view_info_) {
    TORCH_INTERNAL_ASSERT(backward_info_.has_value(), "Shared view info require a backward view info.");
    TORCH_INTERNAL_ASSERT(!forward_info_.has_value(), "Shared view info require forward view info to be empty");
  }
}
/// Chain this view info with the new view op between base and tensor.
///
/// @param base      Tensor the new view op was applied to.
/// @param tensor    Result of the new view op.
/// @param view_func Optional function replaying the new view op on `base`.
/// @return A ViewInfo that keeps this info's root `base_` and carries a view
///         function taking that root as input (empty when neither this info
///         nor the new op provides one).
ViewInfo ViewInfo::chain(const Variable & base, const Variable & tensor,
    std::function<Variable(const Variable&)> view_func) const {
  // Set `view_func` using the root base as input.
  // `view_func` is used to recover views in backward when either as_strided is not supported
  // or the view function changes the metadata which is not recorded by as_strided
  // See Note [View + Inplace update on base tensor] and [View + Inplace update on view tensor]
  // for more details how we use this function in backward.
  if (view_func) {
    if (view_fn_) {
      // Both the current view and its parent have a view function: compose them.
      // Copy the parent view function to gain ownership.
      auto prev_fn = view_fn_;
      // The closure copies the *old* `view_func` before the assignment replaces it.
      view_func = [=](const at::Tensor& root_base) {
        auto temp = prev_fn(root_base);
        return view_func(temp);
      };
    } else if (base.unsafeGetTensorImpl()->support_as_strided()) {
      // The current view has a view function but its parent doesn't:
      // rebuild the parent from the root with as_strided, then apply `view_func`.
      auto size = base.sizes().vec();
      auto stride = base.strides().vec();
      auto storage_offset = base.storage_offset();
      view_func = [=](const at::Tensor& root_base) {
        auto temp = root_base.as_strided(size, stride, storage_offset);
        return view_func(temp);
      };
    } else {
      // When base is a view but doesn't carry a view_fn in DifferentiableViewMeta, it's
      // a view that doesn't support inplace update, e.g. unbind.
      // In this case we should throw an error when inplace update happens in **forward**.
      // One would naturally think the following function will be first called in backward pass.
      // But the first call site is indeed in **forward** pass when we refresh `grad_fn`
      // triggered by inplace update.
      // Search Note [View + Inplace update for view tensor] for the call site.
      view_func = [](const at::Tensor& root_base) {
        // Each literal ends with a space so the sentences do not run together.
        TORCH_CHECK(false, "This view is the output of a function that returns multiple views. "
                           "Such functions do not allow the output views to be modified inplace. "
                           "You should replace the inplace operation by an out-of-place one");
        return root_base;
      };
    }
  } else if (view_fn_) {
    // The current view doesn't have a view function but its parent has one:
    // replay the parent, then recover `tensor`'s geometry with as_strided.
    // Copy the parent view function to gain ownership.
    auto prev_view_fn = view_fn_;
    auto size = tensor.sizes().vec();
    auto stride = tensor.strides().vec();
    auto storage_offset = tensor.storage_offset();
    view_func = [=](const at::Tensor& root_base) {
      auto temp = prev_view_fn(root_base);
      return temp.as_strided(size, stride, storage_offset);
    };
  }
  // `view_func` is a by-value parameter that is not used again; move it out.
  return ViewInfo(base_, std::move(view_func));
}
// File-local AutogradMetaFactory implementation plus the static objects that
// register it.
namespace {
// Tensor returned by undefined_tensor() below; default-constructed and never
// assigned within this block.
at::Tensor singleton_undefined_tensor;
// Factory whose make() returns a fresh, default-constructed AutogradMeta.
struct ConcreteAutogradMetaFactory : public c10::impl::AutogradMetaFactory {
std::unique_ptr<c10::AutogradMetaInterface> make() const override {
return std::make_unique<AutogradMeta>();
}
const at::Tensor& undefined_tensor() const override {
return singleton_undefined_tensor;
}
};
ConcreteAutogradMetaFactory meta_factory;
// Hands the factory's address to AutogradMetaFactoryRegisterer during static
// initialization (presumably installing it for c10 -- confirm in the registerer).
static c10::impl::AutogradMetaFactoryRegisterer meta_factory_registerer(&meta_factory);
}
namespace impl {
/// Returns the AutogradMeta of `self`, first attaching a default-constructed
/// one when the tensor does not have any yet. Fails on an undefined tensor.
AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) {
  TORCH_CHECK(self.defined(), "cannot call materialize_autograd_meta() on undefined tensor");
  auto tensor_impl = self.unsafeGetTensorImpl();
  const bool has_meta = static_cast<bool>(tensor_impl->autograd_meta());
  if (!has_meta) {
    tensor_impl->set_autograd_meta(std::make_unique<AutogradMeta>());
  }
  return get_autograd_meta(self);
}
/// Replaces the gradient edge of `self` with `gradient_edge`.
///
/// When `self` is a backward view, the edge is not set on `self` directly:
/// the function is wrapped in a CopySlices node that becomes the gradient edge
/// of the view's base, and the view's own grad_fn is then refreshed.
///
/// @param self           Variable whose history is rewritten.
/// @param gradient_edge  New edge; its function must be non-null.
void rebase_history(const Variable& self, Edge gradient_edge) {
  TORCH_INTERNAL_ASSERT(gradient_edge.function != nullptr);
  auto diff_view_meta = get_view_autograd_meta(self);
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    // See NOTE [ View + Inplace detection ]
    auto creation_meta = diff_view_meta->get_creation_meta();
    // Do not use handle_view_on_rebase here: check_inplace should have been
    // called before this (presumably raising for any non-DEFAULT creation
    // meta -- confirm against check_inplace), so only DEFAULT can reach us.
    TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::DEFAULT);
    TORCH_INTERNAL_ASSERT(gradient_edge.input_nr == 0);
    TORCH_CHECK(
        gradient_edge.function->num_inputs() == 1,
        "Functions which modify views in-place must return a single Variable");
    // Bind by reference: only `base_` and `view_fn_` are read, no copy needed.
    const auto& view_info = diff_view_meta->get_backward_view();
    diff_view_meta->output_nr_ = gradient_edge.input_nr;
    auto copy_slices = std::make_shared<CopySlices>(
        view_info.base_, at::TensorGeometry(self), view_info.view_fn_, std::move(gradient_edge.function));
    set_gradient_edge(view_info.base_, {std::move(copy_slices), 0});
    self.grad_fn(); // trigger an update to the view's grad_fn
    return;
  }
  set_gradient_edge(self, std::move(gradient_edge));
}
/// Installs a fresh, empty C++ hook list on `self`.
///
/// Any hooks already stored on the tensor are cleared and replaced by a single
/// CppFunctionPreHook over the new list; when `self` has a grad_fn, a second
/// pre-hook bound to `self.output_nr()` is added to that node as well.
void create_cpp_hook(const at::TensorBase& self) {
  auto& list = materialize_autograd_meta(self)->cpp_hooks_list_;
  list = std::make_shared<hooks_list>();
  // Built before the tensor's hooks are touched, matching the original ordering.
  std::unique_ptr<FunctionPreHook> hook_ptr =
      std::make_unique<CppFunctionPreHook>(list, self.output_nr());
  clear_hooks(self);
  add_hook(self, std::make_shared<CppFunctionPreHook>(list, 0));
  const auto& fn = self.grad_fn();
  if (fn) {
    fn->add_pre_hook(std::move(hook_ptr));
  }
}
/// Stores `grad_accumulator` as the weak gradient-accumulator reference of
/// `self`, creating the tensor's AutogradMeta first if necessary.
void set_grad_accumulator(
    const Variable& self,
    std::weak_ptr<Node> grad_accumulator) {
  auto* meta = materialize_autograd_meta(self);
  meta->grad_accumulator_ = std::move(grad_accumulator);
}
/// Returns the currently stored gradient accumulator of `self`, or nullptr
/// when the tensor has no AutogradMeta or the weak reference has expired.
/// Never creates an accumulator.
std::shared_ptr<Node> try_get_grad_accumulator(const Variable& self) {
  auto* meta = get_autograd_meta(self);
  if (!meta) {
    return nullptr;
  }
  return meta->grad_accumulator_.lock();
}
std::shared_ptr<Node> grad_accumulator(const Variable& self) {
auto autograd_meta = get_autograd_meta(self);
if (!autograd_meta) {
return nullptr;
}
if (autograd_meta->grad_fn_) {
throw std::logic_error(
"grad_accumulator() should be only called on leaf Variables");
}
if (!autograd_meta->requires_grad_) {
return nullptr;
}
std::lock_guard<std::mutex> lock(autograd_meta->mutex_);
auto result = autograd_meta->grad_accumulator_.lock();
if (result)
return result;
c10::raw::intrusive_ptr::incref(self.unsafeGetTensorImpl());
auto intrusive_from_this = c10::intrusive_ptr<at::TensorImpl>::reclaim(self.unsafeGetTensorImpl());
result = std::make_shared<AccumulateGrad>(Variable(std::move(intrusive_from_this)));
autograd_meta->grad_accumulator_ = result;
return result;
}
/// Returns the edge through which gradients for `self` flow.
///
/// A Variable with a grad_fn yields (grad_fn, output_nr). Otherwise (as is the
/// case for a leaf) the edge points at input 0 of the gradient accumulator,
/// which accumulates its inputs into the grad property of the variable. These
/// nodes get suppressed in some situations, see "suppress gradient
/// accumulation". Note that only variables which have `requires_grad = True`
/// can have gradient accumulators.
Edge gradient_edge(const Variable& self) {
  const auto& fn = self.grad_fn();
  if (!fn) {
    return Edge(grad_accumulator(self), 0);
  }
  return Edge(fn, self.output_nr());
}
/// Installs `edge` as the gradient edge of `self`: the edge's function becomes
/// grad_fn_ and its input number becomes output_nr_.
void set_gradient_edge(const Variable& self, Edge edge) {
  auto* autograd_meta = materialize_autograd_meta(self);
  autograd_meta->grad_fn_ = std::move(edge.function);
  autograd_meta->output_nr_ = edge.input_nr;

  // For backward views, record the current version so that the grad_fn_ set
  // above is not overwritten by VariableHooks::grad_fn unless it is necessary.
  // This only matters for custom autograd Functions, where several operations
  // can happen on a Tensor before its gradient edge is set on exiting the
  // custom Function.
  auto* view_meta = get_view_autograd_meta(self);
  const bool is_backward_view = view_meta && view_meta->has_bw_view();
  if (!is_backward_view) {
    return;
  }
  view_meta->set_attr_version(self._version());
}
/// Returns the raw grad_fn_ pointer stored in the AutogradMeta of `self`
/// (nullptr when there is no meta). Unlike VariableHooks::grad_fn, the stored
/// value is read as-is without any view handling.
Node* grad_fn_unsafe(const Variable& self) {
  auto* meta = get_autograd_meta(self);
  if (!meta) {
    return nullptr;
  }
  return meta->grad_fn_.get();
}
// Versions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Replaces the version counter of `self`'s TensorImpl with `version_counter`.
/// Fails on an undefined tensor.
void set_version_counter(
    const Variable& self,
    const c10::VariableVersion& version_counter) {
  TORCH_CHECK(self.defined(), "cannot call set_version_counter() on undefined tensor");
  auto* tensor_impl = self.unsafeGetTensorImpl();
  tensor_impl->set_version_counter(version_counter);
}
/// Bumps the version counter of `self`'s TensorImpl. Fails on an undefined tensor.
void bump_version(const Variable& self) {
  TORCH_CHECK(self.defined(), "cannot call bump_version() on undefined tensor");
  auto* tensor_impl = self.unsafeGetTensorImpl();
  tensor_impl->bump_version();
}
/// Returns a reference to the version counter held by `self`'s TensorImpl.
/// Fails on an undefined tensor.
const c10::VariableVersion& version_counter(const Variable& self) {
  TORCH_CHECK(self.defined(), "cannot call version_counter() on undefined tensor");
  auto* tensor_impl = self.unsafeGetTensorImpl();
  return tensor_impl->version_counter();
}
// Hooks
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Appends `hook` to the pre-hooks stored on `self`, creating the tensor's
/// AutogradMeta first if necessary.
void add_hook(const at::TensorBase& self, std::shared_ptr<FunctionPreHook> hook) {
  auto* meta = materialize_autograd_meta(self);
  meta->hooks_.push_back(std::move(hook));
}
// File-local empty hook vector; hooks() returns a reference to it for tensors
// that have no AutogradMeta.
namespace {
std::vector<std::shared_ptr<FunctionPreHook>> empty_singleton;
}
// TODO: Return an ArrayRef instead (and delete the singleton while you're at
// it).
/// Returns the pre-hooks stored on `self`, or a shared empty vector when the
/// tensor has no AutogradMeta.
const std::vector<std::shared_ptr<FunctionPreHook>>& hooks(const Variable& self)
{
  auto* meta = get_autograd_meta(self);
  if (!meta) {
    return empty_singleton;
  }
  return meta->hooks_;
}
/// Removes every pre-hook stored on `self`.
void clear_hooks(const at::TensorBase& self) {
  // This is a little goofy (the meta is materialized just to clear it), but
  // usually this should be a no-op.
  auto* meta = materialize_autograd_meta(self);
  meta->hooks_.clear();
}
/// Stores `name` in the AutogradMeta of `self`, creating the meta if needed.
void set_name(const Variable& self, const std::string& name) {
  auto* meta = materialize_autograd_meta(self);
  meta->name_ = name;
}
// Miscellaneous
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Returns the AutogradMeta attached to `self` without creating one.
/// NB: may return nullptr. Fails on an undefined tensor.
AutogradMeta* get_autograd_meta(const at::TensorBase& self) {
  TORCH_CHECK(self.defined(), "cannot call get_autograd_meta() on undefined tensor");
  auto* tensor_impl = self.unsafeGetTensorImpl();
  return static_cast<AutogradMeta*>(tensor_impl->autograd_meta());
}
/// Returns the AutogradMeta of `self` downcast to DifferentiableViewMeta.
/// NB: returns nullptr if `self` has no meta or its meta is not flagged as a view.
DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase& self) {
  AutogradMeta* meta = get_autograd_meta(self);
  const bool is_view_meta = meta && meta->is_view_;
  if (!is_view_meta) {
    return nullptr;
  }
  return static_cast<DifferentiableViewMeta*>(meta);
}
} // namespace impl
using at::Tensor;
// Autograd's implementation of at::impl::VariableHooksInterface. Every member
// overrides an interface method; the definitions follow later in this file.
struct VariableHooks final : at::impl::VariableHooksInterface {
// Detached shallow copies of a tensor (see the definitions for how the
// version counter and autograd meta are treated in each case).
at::TensorBase tensor_data(const at::TensorBase&) const override;
at::TensorBase variable_data(const at::TensorBase&) const override;
// Gradient function of a tensor; for backward views it may be regenerated.
const std::shared_ptr<torch::autograd::Node>& grad_fn(const at::TensorBase&) const override;
// C++ tensor hooks: _register_hook returns the hook's position, which
// remove_hook accepts as `pos`.
unsigned _register_hook(
const at::TensorBase&, std::function<at::TensorBase(const at::TensorBase&)> hook) const override;
void remove_hook(const at::TensorBase&, unsigned pos) const override;
// View queries.
bool is_view(const at::TensorBase&) const override;
const at::TensorBase& base(const at::TensorBase&) const override;
// Simple accessors backed by the tensor's AutogradMeta.
const std::string& name(const at::TensorBase&) const override;
bool is_leaf(const at::TensorBase&) const override;
int64_t output_nr(const at::TensorBase&) const override;
// Data replacement / access and version query.
void set_data(const at::TensorBase & self, const at::TensorBase & new_data) const override;
at::TensorBase data(const at::TensorBase & self) const override;
int64_t _version(const at::TensorBase & self) const override;
// Gradient retention for non-leaf tensors.
void retain_grad(const at::TensorBase& self) const override;
bool retains_grad(const at::TensorBase& self) const override;
// Entry point that forwards to torch::autograd::backward.
void _backward(const Tensor& self, at::TensorList inputs,
const c10::optional<Tensor>& gradient, c10::optional<bool> keep_graph,
bool create_graph) const override;
void requires_grad_(const at::TensorBase& self, bool _requires_grad) const override;
};
// Single VariableHooks instance; its address is passed to
// VariableHooksRegisterer during static initialization.
VariableHooks variableHooks;
at::impl::VariableHooksRegisterer registerVariableHooks(&variableHooks);
/// Returns a detached shallow copy of `self` created with version counter 0,
/// tensor-metadata changes disallowed, and no autograd meta attached.
/// Fails on an undefined tensor.
at::TensorBase VariableHooks::variable_data(const at::TensorBase& self) const {
  TORCH_CHECK(self.defined(), "cannot call variable_data() on undefined tensor");
  auto* source_impl = self.unsafeGetTensorImpl();
  auto detached_impl = source_impl->shallow_copy_and_detach(
      /*version_counter=*/0,
      /*allow_tensor_metadata_change=*/false);
  detached_impl->set_autograd_meta(nullptr);
  return at::Tensor(detached_impl);
}
/// Returns a detached shallow copy of `self` created with `self`'s own
/// version counter and its current allow_tensor_metadata_change setting.
/// Fails on an undefined tensor.
at::TensorBase VariableHooks::tensor_data(const at::TensorBase& self) const {
  TORCH_CHECK(self.defined(), "cannot call tensor_data() on undefined tensor");
  auto* source_impl = self.unsafeGetTensorImpl();
  auto detached_impl = source_impl->shallow_copy_and_detach(
      /*version_counter=*/source_impl->version_counter(),
      /*allow_tensor_metadata_change=*/source_impl->allow_tensor_metadata_change());
  return at::Tensor(detached_impl);
}
/// A tensor is a leaf when it has no AutogradMeta or its stored grad_fn_ is null.
bool VariableHooks::is_leaf(const at::TensorBase & self) const {
  auto* meta = impl::get_autograd_meta(self);
  if (!meta) {
    return true;
  }
  return meta->grad_fn_ == nullptr;
}
/// Returns the output_nr_ stored in the tensor's AutogradMeta, or 0 when the
/// tensor has no meta.
int64_t VariableHooks::output_nr(const at::TensorBase & self) const {
  auto* meta = impl::get_autograd_meta(self);
  if (!meta) {
    return 0;
  }
  return meta->output_nr_;
}
/// Implements `var.set_data(new_data)`: shallow-copies all non-autograd
/// TensorImpl fields from `new_data` into `self`.
///
/// Fails when the two tensors do not have compatible shallow-copy types.
void VariableHooks::set_data(const at::TensorBase & self_base, const at::TensorBase & new_data_base) const {
  // Re-expose both arguments as at::Tensor through OptionalTensorRef.
  at::OptionalTensorRef self_ref(self_base);
  at::OptionalTensorRef new_data_ref(new_data_base);
  const Tensor& self = *self_ref;
  const Tensor& new_data = *new_data_ref;

  // `var` and `new_data` are required to have compatible tensor type.
  TORCH_CHECK(
      _has_compatible_shallow_copy_type(self, new_data),
      "Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have incompatible tensor type.");

  // Drop the cached gradient accumulator when its metadata is out of date,
  // i.e. when the tensor type or the device differs from the new data.
  if (AutogradMeta* autograd_meta = impl::get_autograd_meta(self)) {
    std::lock_guard<std::mutex> guard(autograd_meta->mutex_);
    if (auto prior_accumulator = autograd_meta->grad_accumulator_.lock()) {
      const auto prior_device = prior_accumulator->input_metadata(0).device();
      const auto new_device = new_data.device();
      const bool type_changed = !new_data.options().type_equal(self.options());
      if (type_changed || prior_device != new_device) {
        autograd_meta->grad_accumulator_.reset();
      }
    }
  }

  // Version counter is not shared when we replace a `Variable`'s tensor data
  // by calling `set_data(...)`. The original version of the `Variable` is
  // always preserved. See NOTE [ Version Counter Sharing ] for details.
  //
  // `var.set_data(new_data)` always ignores `var`'s
  // `allow_tensor_metadata_change_`: users need this API as an escape hatch
  // for changing a tensor's metadata regardless of that flag, and they are
  // responsible for ensuring this is the behavior they want.
  auto* self_impl = self.unsafeGetTensorImpl();
  self_impl->shallow_copy_from(new_data.getIntrusivePtr());
}
/// `data` is an alias for `variable_data()` on the same tensor.
at::TensorBase VariableHooks::data(const at::TensorBase & self) const {
  at::TensorBase detached = self.variable_data();
  return detached;
}
/// Returns the current value of the version counter held by `self`'s TensorImpl.
int64_t VariableHooks::_version(const at::TensorBase & self) const {
  const auto& counter = self.unsafeGetTensorImpl()->version_counter();
  return counter.current_version();
}
/// Makes a non-leaf tensor keep its gradient.
///
/// Registers a hook that stores (or accumulates into) the tensor's grad each
/// time a gradient is produced for it, then marks the tensor as retaining
/// grad. It is a no-op for leaf tensors and for tensors already marked.
/// Fails when `self` does not require grad.
void VariableHooks::retain_grad(const at::TensorBase& self) const {
  TORCH_CHECK(self.requires_grad(), "can't retain_grad on Tensor that has requires_grad=False");

  // temporary hack to improve functorch UX.
  const auto& functorch_tls = at::functorch::functorchTLSAccessor();
  if (functorch_tls) {
    functorch_tls->checkSupportsRetainGrad();
  }

  if (self.is_leaf()) { // no-op for leaves
    return;
  }
  if (impl::get_autograd_meta(self)->retains_grad_) {
    return;
  }

  // The hook holds only a weak reference to the tensor.
  c10::weak_intrusive_ptr<c10::TensorImpl> weak_self(self.getIntrusivePtr());

  auto retain_grad_hook = [weak_self](const at::Tensor& grad) {
    // Lock once and test the result: checking expired() and then calling
    // lock() leaves a window in which the tensor can die, after which lock()
    // yields a null pointer that must not be dereferenced.
    auto var = weak_self.lock();
    if (!var.defined()) {
      return;
    }
    if (!var->grad().defined()) {
      // First gradient: store a private copy.
      if (grad.is_sparse()) {
        var->mutable_grad() = grad.clone();
      } else {
        var->mutable_grad() = grad.clone(at::MemoryFormat::Contiguous);
      }
    } else {
      // Later gradients are accumulated out of place.
      var->mutable_grad() = var->grad() + grad;
    }
  };

  at::OptionalTensorRef(self)->register_hook(retain_grad_hook);
  impl::get_autograd_meta(self)->retains_grad_ = true;
}
/// Returns the retains_grad_ flag from the tensor's AutogradMeta, or false
/// when the tensor has no meta.
bool VariableHooks::retains_grad(const at::TensorBase& self) const {
  auto* meta = impl::get_autograd_meta(self);
  if (!meta) {
    return false;
  }
  return meta->retains_grad_;
}
void VariableHooks::_backward(
const Tensor& self,
at::TensorList inputs,
const c10::optional<Tensor>& gradient,
c10::optional<bool> keep_graph,
bool create_graph) const {
// TODO torch::autograd::backward should take the c10::optional<Tensor> gradient directly
// instead of us having to unwrap it to Tensor _gradient here.
Tensor _gradient = gradient.has_value() ? *gradient : Tensor();
std::vector<torch::autograd::Variable> input_vars(inputs.begin(), inputs.end());
torch::autograd::backward({self}, {_gradient}, keep_graph, create_graph, input_vars);
}
/// Sets the requires_grad flag of `self`.
///
/// @throws std::runtime_error when asked to turn the flag off on a non-leaf tensor.
void VariableHooks::requires_grad_(const at::TensorBase& self, bool _requires_grad) const {
  // is_leaf() is queried first, exactly as before, so its own checks still run.
  const bool is_non_leaf = !self.is_leaf();
  if (is_non_leaf && !_requires_grad) {
    throw std::runtime_error(
        autograd::utils::requires_grad_leaf_error(_requires_grad));
  }
  self.set_requires_grad(_requires_grad);
}
// Backward View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// True only when `self` carries view autograd meta that has a backward view.
bool VariableHooks::is_view(const at::TensorBase& self) const {
  auto* view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  if (!view_meta) {
    return false;
  }
  return view_meta->has_bw_view();
}
/// Returns the base tensor recorded in the backward view info of `self`.
///
/// @throws std::runtime_error when `self` has no view autograd meta; a
///         TORCH_CHECK failure is raised when the meta has no backward view.
const at::TensorBase& VariableHooks::base(const at::TensorBase& self) const {
  auto* view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  if (!view_meta) {
    throw std::runtime_error("Can't get base of non-view Tensor");
  }
  TORCH_CHECK(view_meta->has_bw_view(), "Can't get base of non-backward view Tensor");
  return view_meta->get_backward_view().base_;
}
// File-local empty string; VariableHooks::name returns a reference to it for
// tensors that have no AutogradMeta.
namespace {
std::string singleton_string;
}
/// Returns the name stored in the tensor's AutogradMeta, or a reference to a
/// shared empty string when the tensor has no meta. Fails on an undefined tensor.
const std::string& VariableHooks::name(const at::TensorBase& self) const {
  // The message names this function (it previously said variable_data()).
  TORCH_CHECK(self.defined(), "cannot call name() on undefined tensor");
  auto* meta = torch::autograd::impl::get_autograd_meta(self);
  if (!meta) {
    return singleton_string;
  }
  return meta->name_;
}
// File-local null Node pointer; VariableHooks::grad_fn returns a reference to
// it for tensors that have no AutogradMeta.
namespace {
std::shared_ptr<torch::autograd::Node> singleton_shared_ptr;
}
// Returns the gradient function of `self`.
//
// For a tensor with a backward view, the stored grad_fn_ is returned as-is when
// it is null and the base does not require grad. Otherwise, if the version
// recorded on the view meta differs from the tensor's current version, the
// grad_fn_ is rebuilt (after handle_view_on_rebase has had the chance to raise)
// and the new version is recorded. All of this happens under the view meta's mutex.
//
// For any other tensor, the meta's grad_fn_ is returned, or a shared null
// pointer when the tensor has no AutogradMeta.
const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(const at::TensorBase& self) const {
auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
if (diff_view_meta && diff_view_meta->has_bw_view()) {
// See NOTE [ View + Inplace detection ]
std::lock_guard<std::mutex> lock(diff_view_meta->mutex_);
auto view_info = diff_view_meta->get_backward_view();
// Nothing to regenerate: no grad_fn yet and the base does not require grad.
if (!diff_view_meta->grad_fn_ && !view_info.base_.requires_grad()) {
return diff_view_meta->grad_fn_;
}
auto current_version = self._version();
// A version mismatch means the stored grad_fn_ is stale.
if (diff_view_meta->get_attr_version() != current_version) {
// This is an indirect rebase_history due to another view or the base being modified inplace
handle_view_on_rebase(diff_view_meta, /* indirect */ true);
TORCH_INTERNAL_ASSERT(diff_view_meta->output_nr_ == 0);
// Note [View + Inplace update for view tensor]
// An inplace update happened on Tensor `self` (which is a view).
// For example:
// view_1 = view_op_1(diff_view_meta->base_)
// view_2 = view_op_2(view_1)
// ...
// self = view_op_n(view_n-1)
// self = inplace_op(self)
//
// For CPU/CUDA backends, we employ one AsStridedBackward0 Node to represent the chain of
// view backward ops for efficiency.
//
// However in XLA backend we don't have full support of AsStridedBackward0, we instead run a full
// forward pass with a tensor that requires gradient to get proper grad_fn setup,
// then save it to DifferentiableViewMeta for future use.
// This is fairly cheap for XLA lazy tensor approach (but would be really expensive for CPU/CUDA).
// XLA Tensor only runs through VariableType dispatch and lowers the forward pass to a XLA HLO graph,
// then we take grad_fn and never materialize the tensor content.
// So we only construct the graph but not execute it, which is a fairly cheap operation to do.
//
// See Note [View + Inplace update for base tensor] for what we do to base tensor when
// an in-place operation happens.
//
// TODO: Potentially the following logic can be replaced by special logic in VariableType_x.cpp
// that would provide a way to recreate the grad_fn chain.
if (view_info.has_view_fn()) {
// Replay the recorded view function on the base and adopt the result's grad_fn.
auto view_fn = view_info.view_fn();
auto diff_view = view_fn(view_info.base_);
diff_view_meta->grad_fn_ = diff_view.grad_fn();
} else {
// No view function recorded: build an AsStridedBackward0 node from the
// base's geometry and this view's sizes, strides and storage offset.
auto fn = std::make_shared<torch::autograd::generated::AsStridedBackward0>();
fn->self_geometry = at::TensorGeometry(view_info.base_);
fn->size = self.sizes().vec();
fn->stride = self.strides().vec();
fn->storage_offset = self.storage_offset();
fn->set_next_edges(torch::autograd::collect_next_edges(view_info.base_));
fn->add_input_metadata(
view_info.base_.options(),
self.sizes(), // Note: sizes(), not base_.sizes(), is intentional
self.unsafeGetTensorImpl()->is_python_dispatch());
diff_view_meta->grad_fn_ = std::move(fn);
}
// Record the version the regenerated grad_fn_ corresponds to.
diff_view_meta->set_attr_version(current_version);
}
return diff_view_meta->grad_fn_;
}
// Not a backward view: return the stored grad_fn_, or the shared null pointer.
if (torch::autograd::impl::get_autograd_meta(self)) {
return torch::autograd::impl::get_autograd_meta(self)->grad_fn_;
} else {
return singleton_shared_ptr;
}
}
/// Disables the C++ hook at position `pos` by nulling its slot; the list is
/// not compacted. Fails when there is no hook list or `pos` is out of range.
void VariableHooks::remove_hook(const at::TensorBase& self, unsigned pos) const {
  auto* meta = torch::autograd::impl::materialize_autograd_meta(self);
  auto& hook_list = meta->cpp_hooks_list_;
  TORCH_CHECK(hook_list && pos < hook_list->size() , "Invalid index, no hook at position ", pos);
  // Hook will be ignored
  (*hook_list)[pos] = nullptr;
}
/// Appends `hook` to the C++ hook list of `self`, creating the list on first use.
///
/// @return Position of the new hook in the list (the index remove_hook accepts).
/// Fails when `self` does not require grad.
unsigned VariableHooks::_register_hook(
    const at::TensorBase& self, std::function<at::TensorBase(const at::TensorBase&)> hook) const {
  TORCH_CHECK(self.requires_grad(), "cannot register a hook on a variable that "
                           "doesn't require gradient");
  // NB: materialize_autograd_meta unnecessary due to requires grad check
  auto& list = torch::autograd::impl::get_autograd_meta(self)->cpp_hooks_list_;
  if (!list) {
    // `list` is a reference to the meta's member, so it sees the list created here.
    torch::autograd::impl::create_cpp_hook(self);
  }
  const auto idx = static_cast<unsigned>(list->size());
  // `hook` is a by-value sink parameter: move it instead of copying the std::function.
  list->push_back(std::move(hook));
  return idx;
}
/// Raises an error when an in-place modification reaches a view whose
/// creation context forbids it; does nothing for CreationMeta::DEFAULT.
/// See NOTE [ View + Inplace detection ] for justification of the logic below.
///
/// @param diff_view_meta View metadata of the tensor being checked.
/// @param indirect       True when the modification happened on the base or on
///                       another view rather than on this view itself; this
///                       only changes the wording of the error message.
void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect) {
  auto creation_meta = diff_view_meta->get_creation_meta();
  if (creation_meta == CreationMeta::DEFAULT) {
    return;
  }

  auto grad_fn = diff_view_meta->grad_fn_.get();
  std::string msg;
  std::string modified_obj;
  // Create the header for the error message.
  if (indirect) {
    modified_obj = "its base or another view of its base has been";
  } else {
    modified_obj = "is being";
  }
  if (grad_fn) {
    msg = c10::str("Output ", diff_view_meta->output_nr_, " of ", grad_fn->name(), " is a view and ",
                   modified_obj, " modified inplace.");
  } else if (creation_meta == CreationMeta::INFERENCE_MODE) {
    msg = c10::str("A view was created in inference mode and ", modified_obj, " modified inplace in normal mode.");
  } else {
    msg = c10::str("A view was created in no_grad mode and ", modified_obj, " modified inplace with grad mode enabled.");
  }

  // Append the explanation specific to how the view was created.
  if (creation_meta == CreationMeta::MULTI_OUTPUT_NODE) {
    msg = c10::str(msg, " This view is the output of a function that returns multiple views. Such functions do not"
                        " allow the output views to be modified inplace. You should replace the inplace operation by an"
                        " out-of-place one.");
  } else if (creation_meta == CreationMeta::NO_GRAD_MODE) {
    TORCH_INTERNAL_ASSERT(!grad_fn);
    msg = c10::str(msg, " Given that this use case is ambiguous and error-prone, it is forbidden."
                        " You can clarify your code by moving both the view and the inplace either both"
                        " inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want"
                        " the inplace to be tracked).");
  } else if (creation_meta == CreationMeta::INFERENCE_MODE) {
    TORCH_INTERNAL_ASSERT(!grad_fn);
    msg = c10::str(msg, " Given that this use case is ambiguous and error-prone, it is forbidden."
                        " You can clarify your code by moving both the view and the inplace either both"
                        " inside the inference_mode block (if you don't want the inplace to be tracked) or both outside (if you want"
                        " the inplace to be tracked).");
  } else if (creation_meta == CreationMeta::IN_CUSTOM_FUNCTION) {
    msg = c10::str(msg, " This view was created inside a custom Function (or because an input was returned as-is) and the"
                        " autograd logic to handle view+inplace would override the custom backward associated with the custom"
                        " Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by"
                        " cloning the output of the custom Function.");
  } else {
    TORCH_INTERNAL_ASSERT(false, "Invalid CreationMeta state");
  }

  // Single exit point: every non-DEFAULT creation meta ends in this error.
  TORCH_CHECK(false, msg);
}
}} // namespace torch::autograd