diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp
index c5dc8904488..f9b79ed6d52 100644
--- a/aten/src/ATen/CheckpointTensorImpl.cpp
+++ b/aten/src/ATen/CheckpointTensorImpl.cpp
@@ -473,18 +473,16 @@ void CheckpointTensorCell::fill(const Tensor& t) {
       defined = true;
       is_undefined_tensor = !t.defined();
       key_set_ = t.key_set();
-      if (t.requires_grad()) {
-        key_set_ = key_set_.add(DispatchKey::Autograd);
-      }
       dtype_ = t.dtype();
       optional_device_ = t.optional_device();
+      original_requires_grad = t.requires_grad();
     }
   }
 }
 
 intrusive_ptr<TensorImpl> CheckpointTensorImpl::shallow_copy_and_detach(const VariableVersion& version_counter,
                                                                         bool allow_tensor_metadata_change) const {
-  auto ret = intrusive_ptr<CheckpointTensorImpl>::make(ref);
+  auto ret = intrusive_ptr<CheckpointTensorImpl>::make(DETACH {}, *this);
   if (use_log_) {
     DTRLogCopy(ret->counter_name(), counter_name());
   }
@@ -493,7 +491,7 @@ intrusive_ptr<TensorImpl> CheckpointTensorImpl::shallow_copy_and_detach(const Va
 
 void CheckpointTensorImpl::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) {
   STATS.track("CheckpointTensorCell::shallow_copy_from");
-  TORCH_CHECK(impl->key_set().has(DispatchKey::CheckpointTensorId));
+  TORCH_CHECK(impl->key_set().has(DispatchKey::Checkpoint));
   auto* cpti = dynamic_cast<CheckpointTensorImpl*>(impl.get());
   TORCH_CHECK(cpti != nullptr);
   ref->value = cpti->ref->value;
@@ -645,7 +643,11 @@ Tensors CheckpointTensorImpl::make(const std::string& name,
     }
   }
-  return tensors;
+  Tensors return_tensors;
+  for (const auto& tensor : tensors) {
+    return_tensors.push_back(get_cpti(tensor)->ref->value->value->is_undefined_tensor ? Tensor() : tensor);
+  }
+  return return_tensors;
 }
 
 // TODO: check that mutated value does not have alias.
diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h
index 7eec97e8e75..3d2b9d31687 100644
--- a/aten/src/ATen/CheckpointTensorImpl.h
+++ b/aten/src/ATen/CheckpointTensorImpl.h
@@ -24,7 +24,7 @@
 #define likely(x) __builtin_expect(!!(x), 1)
 #define unlikely(x) __builtin_expect(!!(x), 0)
 
-#define TORCH_CHECK(a, ...) // profile mode
+//#define TORCH_CHECK(a, ...) // profile mode
 
 // System Description:
 // Every Tensor is managed by a CheckpointTensor,
@@ -63,7 +63,9 @@
 // and output all the new value of the aliases, then update the Ref.
 // Of course, the cleaner way is to not support this.
 // Shame on those who use this feature.
-
+// An operator might return an undefined tensor.
+// Since undefined tensors have special semantics, we don't want to wrap them.
+// When that is the case, we unwrap the tensor and return a normal undefined tensor to the user.
 // Memory Safety:
 // The objects here will have lots of backedges.
 // In order to collect memory when computation is completed,
@@ -299,6 +301,7 @@ struct CheckpointTensorCell : intrusive_ptr_target {
     TORCH_CHECK(defined);
     return optional_device_;
   }
+  bool original_requires_grad;
   // A Tensor is evictable iff it's AliasPool is evictable.
   // A evictable tensor must have Rematerializer.
   intrusive_ptr<AliasPool> pool;
@@ -327,7 +330,7 @@ struct CheckpointTensorCell : intrusive_ptr_target {
       remat->remat();
     }
     TORCH_CHECK(t);
-    TORCH_CHECK(! t->key_set().has(DispatchKey::CheckpointTensorId));
+    TORCH_CHECK(! t->key_set().has(DispatchKey::Checkpoint));
     pool->last_used_time = std::chrono::system_clock::now();
     return *t;
   }
@@ -368,11 +371,14 @@ struct External : intrusive_ptr_target {
 };
 
 inline DispatchKeySet convert_key_set(const DispatchKeySet& t) {
-  CHECK(!t.has(DispatchKey::Checkpoint));
+  TORCH_CHECK(!t.has(DispatchKey::Checkpoint));
+  TORCH_CHECK(!t.has(DispatchKey::Autograd));
   auto ret = t.add(DispatchKey::Checkpoint);
   return ret;
 }
 
+struct DETACH { };
+
 struct CheckpointTensorImpl : TensorImpl {
   int id = gen_counter();
   static int counter;
@@ -392,10 +398,15 @@ struct CheckpointTensorImpl : TensorImpl {
                ref->value->value->dtype(),
                ref->value->value->optional_device()),
     ref(ref) {
-    if (key_set().has(DispatchKey::Autograd)) {
-      set_requires_grad(true);
+    if (ref->value->value->original_requires_grad) {
+      set_requires_grad(true);
+    }
     }
-  }
+
+  explicit CheckpointTensorImpl(DETACH, const CheckpointTensorImpl& cpti) :
+    TensorImpl(convert_key_set(cpti.ref->value->value->key_set()),
+               cpti.ref->value->value->dtype(),
+               cpti.ref->value->value->optional_device()), ref(cpti.ref) { }
 
   explicit CheckpointTensorImpl(const intrusive_ptr<External>& e) :
     CheckpointTensorImpl(Ref<intrusive_ptr<External>>::make(e)) { }
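
For readers unfamiliar with the tag-dispatch trick used in the patch: an empty struct (DETACH) passed as the first argument selects a dedicated "detach" constructor that shares the underlying state but does not call set_requires_grad(true), which is how shallow_copy_and_detach now builds its result via intrusive_ptr<CheckpointTensorImpl>::make(DETACH {}, *this). Below is a minimal standalone sketch of that pattern, not the DTR sources; Handle and its cell member are hypothetical stand-ins, and only the empty DETACH tag mirrors the actual patch.

// Minimal standalone sketch (assumed names: Handle, cell; only the empty
// DETACH tag mirrors the actual patch). Demonstrates tag dispatch: the tag
// argument picks the constructor overload, no runtime flag is needed.
#include <iostream>
#include <memory>

struct DETACH { };  // empty tag type, used purely for overload resolution

struct Handle {
  std::shared_ptr<int> cell;     // shared underlying state (stand-in for ref)
  bool requires_grad = false;

  // Regular constructor: propagates the gradient flag.
  Handle(std::shared_ptr<int> c, bool grad)
      : cell(std::move(c)), requires_grad(grad) { }

  // Detach constructor: shares the same cell but drops the gradient flag,
  // analogous to CheckpointTensorImpl(DETACH, const CheckpointTensorImpl&).
  Handle(DETACH, const Handle& other)
      : cell(other.cell), requires_grad(false) { }
};

int main() {
  auto cell = std::make_shared<int>(42);
  Handle a(cell, /*grad=*/true);
  Handle detached(DETACH {}, a);  // same cell as `a`, but requires_grad == false
  std::cout << *detached.cell << " " << detached.requires_grad << "\n";  // prints "42 0"
  return 0;
}

The same reasoning applies in the patch: a detached copy must share ref with the original, but must not mark itself as requiring grad, so the DETACH overload bypasses the original_requires_grad check performed by the regular constructor.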