forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
inline_autodiff_subgraphs.cpp
88 lines (72 loc) · 2.58 KB
/
inline_autodiff_subgraphs.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/update_differentiable_graph_requires_grad.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
namespace torch::jit {
// Reports whether `node` is safe to execute under plain Autograd, i.e. outside
// of an autodiff subgraph. aten and prim nodes are guaranteed to work with
// Autograd, except for the fusion-group / guard / type-check kinds listed
// below. Any other node (e.g. a user-defined op) is not necessarily
// Autograd-aware and is rejected. A node that owns nested blocks qualifies
// only if every node inside every one of those blocks also qualifies.
bool canRunWithAutograd(Node* node) {
  const auto kind = node->kind();

  // Kinds that are aten/prim by namespace but are still not Autograd-safe.
  const bool is_excluded_kind = kind == prim::FusionGroup ||
      kind == prim::CudaFusionGroup || kind == prim::TypeCheck ||
      kind == prim::TensorExprGroup || kind == prim::CudaFusionGuard ||
      kind == prim::oneDNNFusionGroup || kind == prim::oneDNNFusionGuard;
  if (is_excluded_kind || !(kind.is_aten() || kind.is_prim())) {
    return false;
  }

  // The node itself is fine; now every node in its nested blocks must be too.
  for (Block* nested : node->blocks()) {
    for (Node* inner : nested->nodes()) {
      if (!canRunWithAutograd(inner)) {
        return false;
      }
    }
  }
  return true;
}
namespace {
void InlineAutodiffSubgraphs(Block* block, size_t threshold);
size_t blockSize(Block* block) {
size_t num = 0;
for (Node* n : block->nodes()) {
for (Block* b : n->blocks()) {
num += blockSize(b);
}
num++;
}
return num;
}
graph_node_list::iterator scanNode(Node* node, size_t threshold) {
auto next_node = ++node->iterator();
for (Block* block : node->blocks()) {
InlineAutodiffSubgraphs(block, threshold);
}
if (node->kind() != prim::DifferentiableGraph) {
return next_node;
}
auto subgraph = node->g(attr::Subgraph);
size_t subgraph_size = blockSize(subgraph->block());
if (subgraph_size >= threshold) {
return next_node;
}
if (!std::all_of(
subgraph->nodes().begin(),
subgraph->nodes().end(),
canRunWithAutograd)) {
return next_node;
}
// now that we inline the graph, we are no longer detaching input tensors,
// so the profiles will have outdated requires_grad=False.
// conservatively update them to maybe requiring grad, bc we might create
// autodiff graphs when the tensors maybe require grad
UpdateDifferentiableGraphRequiresGrad(subgraph, std::nullopt);
SubgraphUtils::unmergeSubgraph(node);
return next_node;
}
void InlineAutodiffSubgraphs(Block* block, size_t threshold) {
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
it = scanNode(*it, threshold);
}
}
} // anonymous namespace
void InlineAutodiffSubgraphs(std::shared_ptr<Graph>& graph, size_t threshold) {
InlineAutodiffSubgraphs(graph->block(), threshold);
EliminateDeadCode(graph);
}
} // namespace torch::jit