Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

unwrap undefined tensor #56

Open
wants to merge 1 commit into
base: 9-9
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions aten/src/ATen/CheckpointTensorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,18 +473,16 @@ void CheckpointTensorCell::fill(const Tensor& t) {
defined = true;
is_undefined_tensor = !t.defined();
key_set_ = t.key_set();
if (t.requires_grad()) {
key_set_ = key_set_.add(DispatchKey::Autograd);
}
dtype_ = t.dtype();
optional_device_ = t.optional_device();
original_requires_grad = t.requires_grad();
}
}
}

intrusive_ptr<TensorImpl> CheckpointTensorImpl::shallow_copy_and_detach(const VariableVersion& version_counter,
bool allow_tensor_metadata_change) const {
auto ret = intrusive_ptr<CheckpointTensorImpl>::make(ref);
auto ret = intrusive_ptr<CheckpointTensorImpl>::make(DETACH {}, *this);
if (use_log_) {
DTRLogCopy(ret->counter_name(), counter_name());
}
Expand All @@ -493,7 +491,7 @@ intrusive_ptr<TensorImpl> CheckpointTensorImpl::shallow_copy_and_detach(const Va

void CheckpointTensorImpl::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) {
STATS.track("CheckpointTensorCell::shallow_copy_from");
TORCH_CHECK(impl->key_set().has(DispatchKey::CheckpointTensorId));
TORCH_CHECK(impl->key_set().has(DispatchKey::Checkpoint));
auto* cpti = dynamic_cast<CheckpointTensorImpl*>(impl.get());
TORCH_CHECK(cpti != nullptr);
ref->value = cpti->ref->value;
Expand Down Expand Up @@ -645,7 +643,11 @@ Tensors CheckpointTensorImpl::make(const std::string& name,
}
}

return tensors;
Tensors return_tensors;
for (const auto& tensor : tensors) {
return_tensors.push_back(get_cpti(tensor)->ref->value->value->is_undefined_tensor ? Tensor() : tensor);
}
return return_tensors;
}

// TODO: check that mutated value does not have alias.
Expand Down
25 changes: 18 additions & 7 deletions aten/src/ATen/CheckpointTensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

#define likely(x) __builtin_expect(!!(x), 1)
#define unlikely(x) __builtin_expect(!!(x), 0)
#define TORCH_CHECK(a, ...) // profile mode
//#define TORCH_CHECK(a, ...) // profile mode

// System Description:
// Every Tensor is managed by a CheckpointTensor,
Expand Down Expand Up @@ -63,7 +63,9 @@
// and output all the new value of the aliases, then update the Ref.
// Of course, the cleaner way is to not support this.
// Shame on those who use this feature.

// An operator might return undefined tensor.
// Since undefined tensor has special semantic, we dont want to wrap it.
// When that is the case we unwrap the tensor and return normal undefined tensor to the user.
// Memory Safety:
// The objects here will have lots of backedges.
// In order to collect memory when computation is completed,
Expand Down Expand Up @@ -299,6 +301,7 @@ struct CheckpointTensorCell : intrusive_ptr_target {
TORCH_CHECK(defined);
return optional_device_;
}
bool original_requires_grad;
// A Tensor is evictable iff it's AliasPool is evictable.
// A evictable tensor must have Rematerializer.
intrusive_ptr<AliasPool> pool;
Expand Down Expand Up @@ -327,7 +330,7 @@ struct CheckpointTensorCell : intrusive_ptr_target {
remat->remat();
}
TORCH_CHECK(t);
TORCH_CHECK(! t->key_set().has(DispatchKey::CheckpointTensorId));
TORCH_CHECK(! t->key_set().has(DispatchKey::Checkpoint));
pool->last_used_time = std::chrono::system_clock::now();
return *t;
}
Expand Down Expand Up @@ -368,11 +371,14 @@ struct External : intrusive_ptr_target {
};

inline DispatchKeySet convert_key_set(const DispatchKeySet& t) {
CHECK(!t.has(DispatchKey::Checkpoint));
TORCH_CHECK(!t.has(DispatchKey::Checkpoint));
TORCH_CHECK(!t.has(DispatchKey::Autograd));
auto ret = t.add(DispatchKey::Checkpoint);
return ret;
}

struct DETACH { };

struct CheckpointTensorImpl : TensorImpl {
int id = gen_counter();
static int counter;
Expand All @@ -392,10 +398,15 @@ struct CheckpointTensorImpl : TensorImpl {
ref->value->value->dtype(),
ref->value->value->optional_device()),
ref(ref) {
if (key_set().has(DispatchKey::Autograd)) {
set_requires_grad(true);
if (ref->value->value->original_requires_grad) {
set_requires_grad(true);
}
}
}

explicit CheckpointTensorImpl(DETACH, const CheckpointTensorImpl& cpti) :
TensorImpl(convert_key_set(cpti.ref->value->value->key_set()),
cpti.ref->value->value->dtype(),
cpti.ref->value->value->optional_device()), ref(cpti.ref) { }

explicit CheckpointTensorImpl(const intrusive_ptr<External>& e) :
CheckpointTensorImpl(Ref<intrusive_ptr<External>>::make(e)) { }
Expand Down