forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
function.cpp
80 lines (70 loc) · 2.27 KB
/
function.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <torch/csrc/jit/function.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/script/error_report.h>
namespace torch {
namespace jit {
namespace {
// Builds a fallback FunctionSchema for `function` from its graph.
//  - One argument for each of the first function.num_inputs() graph inputs.
//    It is named after the input's debug name when it has one, otherwise
//    "argument_<index>".
//  - One unnamed return for each graph output.
// Every argument/return type is passed through unshapedType() (presumably
// dropping shape refinements — confirm against its definition).
// The overload name of the resulting schema is empty.
FunctionSchema defaultSchemaFor(const Function& function) {
  Graph& graph = *function.graph();
  const size_t input_count = function.num_inputs();

  std::vector<Argument> arguments;
  for (size_t idx = 0; idx < input_count; ++idx) {
    // Bounds-checked lookup; the same Value* also supplies the type below.
    const Value* input = graph.inputs().at(idx);
    std::string arg_name;
    if (input->hasDebugName()) {
      arg_name = input->debugNameBase();
    } else {
      arg_name = "argument_" + c10::to_string(idx);
    }
    arguments.emplace_back(std::move(arg_name), unshapedType(input->type()));
  }

  std::vector<Argument> results;
  for (const Value* output : graph.outputs()) {
    results.emplace_back("", unshapedType(output->type()));
  }

  return {function.name(), "", std::move(arguments), std::move(results)};
}
} // namespace
// Internal sentinel exception thrown by placeholderCreator() and caught in
// Function::ensure_defined(), where it is turned into a script::ErrorReport
// that describes the recursive call. (A struct's base is public by default.)
struct RecursiveMethodCallError : std::exception {};
// Stand-in creator that Function::ensure_defined() installs while the real
// creator is running. Reaching it means the function's definition was
// requested again before the first request finished, so it unconditionally
// throws RecursiveMethodCallError.
void placeholderCreator(Function& /*unused*/) {
  throw RecursiveMethodCallError{};
}
// Runs the function on `stack` by forwarding it to the executor returned by
// get_executor(). Inputs are read from and results written to the same stack.
void Function::run(Stack& stack) {
  auto&& executor = get_executor();
  executor.run(stack);
}

// Rvalue overload so callers can pass a temporary stack. A named rvalue
// reference is an lvalue, so this dispatches to run(Stack&).
void Function::run(Stack&& stack) {
  Stack& in_place = stack;
  run(in_place);
}
// Calls the function with positional arguments `stack` and keyword arguments
// `kwargs`.
//
// The inputs are first checked and normalized against getSchema(), which
// mutates `stack` in place. The function is then run on that stack, and the
// first value left on it is returned.
//
// NOTE(review): this assumes run() leaves at least one value on the stack
// (ensure_defined() calls check_single_output(), which suggests exactly one).
// Confirm that get_executor() triggers that check before relying on it.
IValue Function::operator()(
    std::vector<IValue> stack,
    const Kwargs& kwargs) {
  getSchema().checkAndNormalizeInputs(stack, kwargs);
  run(stack);
  // `stack` is a by-value local destroyed on return. Move the result out
  // instead of copying the IValue.
  return std::move(stack.front());
}
void Function::ensure_defined() {
try {
if (function_creator_) {
auto creator = function_creator_;
function_creator_ = placeholderCreator;
creator(*this);
function_creator_ = nullptr;
}
} catch (RecursiveMethodCallError&) {
throw script::ErrorReport() // TODO: once lower_first_class methods is
// removed re-establish callsite info for
// debugging
<< " method '" << name() << "' is called recursively. "
<< "Recursive calls are not supported";
}
check_single_output();
}
// Returns the schema held in schema_.
//
// If schema_ is empty, a default schema is derived from the graph with
// defaultSchemaFor() and cached in schema_, which is assigned from this
// const method and must therefore be mutable. Later calls return the cached
// value.
//
// NOTE(review): no locking is visible here. Concurrent first calls on the
// same Function would race on schema_.
const FunctionSchema& Function::getSchema() const {
  if (schema_ != nullptr) {
    return *schema_;
  }
  FunctionSchema derived = defaultSchemaFor(*this);
  schema_ = make_unique<FunctionSchema>(std::move(derived));
  return *schema_;
}
void preoptimizeGraph(std::shared_ptr<Graph>& graph) {
// TODO: Invoke cleanup passes before and after inlining to reduce amount of
// code we're copying.
Inline(*graph);
}
} // namespace jit
} // namespace torch