forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
PythonFallbackKernel.cpp
49 lines (44 loc) · 1.83 KB
/
PythonFallbackKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/PythonModeTLS.h>
namespace {

// Boxed fallback kernel for the Python dispatch key.
//
// Picks a PyInterpreter and forwards the whole boxed call (`op` plus the
// arguments on `stack`) to it:
//   1. If Python Mode TLS state is set, that state's interpreter is used and
//      the state itself is passed along to dispatch().
//   2. Otherwise the operator's arguments (the last
//      `op.schema().arguments().size()` entries of `stack`) are scanned in
//      order for the first Tensor whose TensorImpl has a PyInterpreter.
//      Tensors stored directly inside list arguments are considered too. This
//      covers `Tensor[]` and also lists whose element type is not exactly
//      Tensor (e.g. `Tensor?[]`, whose elements may be None).
//
// Reaching this kernel with neither Python Mode active nor any Tensor argument
// carrying a PyInterpreter is treated as an internal error.
void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  // If Python Mode is active, use its PyInterpreter for dispatch.
  const auto& maybe_python_mode_state = at::impl::PythonModeTLS::get_state();
  if (maybe_python_mode_state) {
    maybe_python_mode_state->pyinterpreter()->dispatch(op, stack, maybe_python_mode_state);
    return;
  }

  // Dispatches via the PyInterpreter attached to `ivalue`'s TensorImpl, if
  // `ivalue` is a Tensor and it has one. Returns true iff dispatch happened.
  // dispatch() may rewrite `stack`, so after a true result the caller must
  // not touch any reference into the old arguments; every call site below
  // returns immediately.
  const auto try_dispatch_on = [&](const c10::IValue& ivalue) -> bool {
    if (!ivalue.isTensor()) {
      return false;
    }
    auto* interpreter = ivalue.unsafeToTensorImpl()->pyobj_interpreter();
    if (!interpreter) {
      return false;
    }
    interpreter->dispatch(op, stack, nullptr);
    return true;
  };

  // Otherwise, find a PyInterpreter on a Tensor argument.
  const auto num_arguments = op.schema().arguments().size();
  // It is safe to dispatch on the very first Tensor with a pyobj_interpreter
  // without checking the interpreters of any of the arguments, because when
  // we actually run dispatch(), we will take out PyObjects in the context
  // of that interpreter, and this will ensure that everyone is on the same
  // interpreter.
  for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
    if (try_dispatch_on(ivalue)) {
      return;
    }
    if (ivalue.isList()) {
      // Inspect every list, not only those reported by isTensorList(): lists
      // such as Tensor?[] are not Tensor lists by element type, but can still
      // hold the only Python-backed tensors. Non-Tensor elements (None, ints,
      // ...) are skipped by try_dispatch_on.
      // NB: use toListRef as it doesn't induce refcount bumps (toTensorListRef
      // is not a thing)
      for (const auto& element : ivalue.toListRef()) {
        if (try_dispatch_on(element)) {
          return;
        }
      }
    }
  }
  TORCH_INTERNAL_ASSERT(false, "Hit Python dispatch key but no arguments had PyInterpreter (no tensor args?)");
}

} // anonymous namespace
// Registers pythonFallback as the boxed fallback kernel for the Python
// dispatch key. Fallback registrations use the `_` (wildcard) namespace rather
// than a specific operator namespace.
TORCH_LIBRARY_IMPL(_, Python, lib) {
  lib.fallback(
      torch::CppFunction::makeFromBoxedFunction<&pythonFallback>());
}