forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
debug_info.cpp
228 lines (211 loc) · 8.94 KB
/
debug_info.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/mobile/debug_info.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
#include <torch/csrc/jit/serialization/source_range_serialization.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <c10/util/string_view.h>
namespace torch {
namespace jit {
namespace {
// Builds the message returned when one or more debug handles have no entry
// in the debug table. `debug_handles_string` is embedded verbatim.
C10_ALWAYS_INLINE std::string debugHandlesNotFoundMessage(
    const std::string& debug_handles_string) {
  std::string message = "Debug info for handle(s): ";
  message += debug_handles_string;
  message += ", was not found.";
  return message;
}
std::pair<std::vector<StackEntry>, std::string> getStackTraceWithModuleHierarchy(
const DebugInfoTuple& source_callstack,
const std::string& caller_name) {
std::vector<StackEntry> entries;
const SourceRange& range =
std::get<kDebugInfoTupleSourceRangeIndex>(source_callstack);
InlinedCallStackPtr callstack_ptr =
std::get<kDebugInfoTupleInlinedCSIndex>(source_callstack);
std::string prev_function_name = caller_name;
std::string module_info;
if (!callstack_ptr) {
// If not cs then top level node
entries.emplace_back(StackEntry{prev_function_name, range});
return {std::move(entries), std::move(module_info)};
} else {
while (callstack_ptr) {
const auto& opt_module_instance_info = callstack_ptr->module_instance();
if (opt_module_instance_info.has_value()) {
const auto& module_instance_info = opt_module_instance_info.value();
// Sometimes (e.g., in lowered backends) we augment instance name with
// type name instead of losing type name. In those cases instance_name
// includes both instance name and type name. See
// callstack_debug_info_serialization.cpp
if (module_instance_info.class_type()) {
module_info.append(".").append(
utils::get_module_info(module_instance_info));
} else {
module_info.append(".").append(module_instance_info.instance_name());
}
} else {
module_info.append(".UNKNOWN_INSTANCE(UNKNOWN_TYPE)");
}
// Now add source range info to stack
entries.emplace_back(
StackEntry{prev_function_name, callstack_ptr->source_range()});
prev_function_name = callstack_ptr->function_name();
// Function name appended here
// It is renamed to prev_function_name because for StackEntry
// it will be appended in the next iteration. This is the format
// in which format_stack_trace expects function names.
module_info.append("::").append(prev_function_name);
if (callstack_ptr->callee()) {
callstack_ptr = callstack_ptr->callee().value();
} else {
callstack_ptr = c10::intrusive_ptr<InlinedCallStack>();
}
}
entries.emplace_back(StackEntry{prev_function_name, range});
return {std::move(entries), std::move(module_info)};
}
}
// This function construct stacktrace with module hierarchy
// Module hierarchy will contain information about where in the
// module hierarchy this source is. For example if conv2d op
// exist in hierarcy A->B->C->Conv2d with type annotations of
// A -> TopM, B->MyModule, C->SomeModule, then module hierarchy
// will be TopM(A).MyModule(B).SomeModule(C).Conv2d(conv)
// Source level stack information will be from model source code.
std::pair<std::string, std::string> getStackTraceWithModuleHierarchy(
const std::vector<DebugInfoTuple>& source_callstacks,
const std::string& root_scope_string,
const std::string& top_module_type_name) {
std::vector<StackEntry> stack_entries;
std::string module_info =
root_scope_string + "(" + top_module_type_name + ")";
std::string caller_fn_name = "<unknown>";
module_info.append("::").append(caller_fn_name);
for (const auto& debug_info : source_callstacks) {
auto debug_info_pair =
getStackTraceWithModuleHierarchy(debug_info, caller_fn_name);
auto entries = std::move(debug_info_pair.first);
stack_entries.insert(stack_entries.end(), entries.begin(), entries.end());
module_info.append(debug_info_pair.second);
}
// Only last entry in the callstack will have a node name of interest.
// Rest are likely CallMethod/CallFunction nodes
auto last_entry = source_callstacks.back();
const std::string& node_name =
std::get<kDebugInfoTupleNodeNameIndex>(last_entry);
module_info.append(".").append(node_name);
std::ostringstream ss;
ss << "Module hierarchy:" << module_info << "\n";
format_stack_trace(ss, stack_entries);
return {ss.str(), std::move(module_info)};
}
} // namespace
// Builds the debug table from a model archive.
//
// Pass 1 collects a debug-handle -> SourceRange map from every record whose
// name ends with ".debug_pkl".
// Pass 2 runs only if the archive contains "callstack_debug_map.pkl". It
// unpickles that record into callstack_ptr_map_ with
// CallStackDebugInfoUnpickler, passing the source range map from pass 1 and
// `cu`.
MobileDebugTable::MobileDebugTable(
    std::unique_ptr<caffe2::serialize::PyTorchStreamReader>& reader,
    const std::shared_ptr<CompilationUnit>& cu) {
  ska::flat_hash_map<int64_t, SourceRange> source_range_map;
  const std::vector<std::string>& record_names = reader->getAllRecords();
  const c10::string_view suffix(".debug_pkl");
  for (const auto& record_name : record_names) {
    if (!c10::string_view(record_name).ends_with(suffix)) {
      continue;
    }
    at::DataPtr debug_data;
    size_t debug_size{0};
    std::tie(debug_data, debug_size) = reader->getRecord(record_name);
    // Each .debug_pkl record unpickles to a tuple of per-range tuples.
    auto ivalues =
        std::move(*jit::unpickle(
                       reinterpret_cast<const char*>(debug_data.get()),
                       debug_size,
                       nullptr,
                       {},
                       c10::parseType)
                       .toTuple())
            .elements();
    SourceRangeDeserializer deserializer;
    for (auto& val : ivalues) {
      auto tup_elems = std::move(*std::move(val).toTuple()).elements();
      // For BC we decode only tuples with 3 elements, assuming they contain
      // (byte_offset, debug_handle (= source range tag), source range).
      // Tuples of any other size are skipped.
      if (tup_elems.size() == 3) {
        int64_t debug_handle = tup_elems[kSourceRangeTagIndex].toInt();
        auto source_range =
            deserializer.deserialize(tup_elems[kSourceRangeIndex]);
        source_range_map.emplace(debug_handle, std::move(source_range));
      }
    }
  }
  // Use one name for both the existence check and the read, so the two
  // cannot drift apart.
  const std::string callstack_debug_file("callstack_debug_map.pkl");
  if (reader->hasRecord(callstack_debug_file)) {
    at::DataPtr callstack_data;
    size_t callstack_data_size{0};
    std::tie(callstack_data, callstack_data_size) =
        reader->getRecord(callstack_debug_file);
    CallStackDebugInfoUnpickler unpickler;
    callstack_ptr_map_ = unpickler.unpickle(
        std::move(callstack_data), callstack_data_size, source_range_map, cu);
  }
}
// Returns the module hierarchy string for a single debug handle. If the
// handle has no entry in callstack_ptr_map_, returns the "not found" message
// instead.
std::string MobileDebugTable::getModuleHierarchyInfo(
    const int64_t debug_handle,
    const std::string& top_module_type_name) const {
  const auto entry = callstack_ptr_map_.find(debug_handle);
  if (entry != callstack_ptr_map_.end()) {
    auto trace_and_hierarchy = getStackTraceWithModuleHierarchy(
        {entry->second}, "top", top_module_type_name);
    return std::move(trace_and_hierarchy.second);
  }
  return debugHandlesNotFoundMessage(std::to_string(debug_handle));
}
// Returns the module hierarchy part of getSourceDebugModuleHierarchyInfo for
// a sequence of debug handles.
std::string MobileDebugTable::getModuleHierarchyInfo(
    const std::vector<int64_t>& debug_handles,
    const std::string& top_module_type_name) const {
  auto trace_and_hierarchy =
      getSourceDebugModuleHierarchyInfo(debug_handles, top_module_type_name);
  return std::move(trace_and_hierarchy.second);
}
// Returns the formatted source stack trace (prefixed with the module
// hierarchy line) for a single debug handle. If the handle has no entry in
// callstack_ptr_map_, returns the "not found" message instead.
std::string MobileDebugTable::getSourceDebugString(
    const int64_t debug_handle,
    const std::string& top_module_type_name) const {
  const auto entry = callstack_ptr_map_.find(debug_handle);
  if (entry != callstack_ptr_map_.end()) {
    auto trace_and_hierarchy = getStackTraceWithModuleHierarchy(
        {entry->second}, "top", top_module_type_name);
    return std::move(trace_and_hierarchy.first);
  }
  return debugHandlesNotFoundMessage(std::to_string(debug_handle));
}
// Returns the formatted source stack trace part of
// getSourceDebugModuleHierarchyInfo for a sequence of debug handles.
std::string MobileDebugTable::getSourceDebugString(
    const std::vector<int64_t>& debug_handles,
    const std::string& top_module_type_name) const {
  auto trace_and_hierarchy =
      getSourceDebugModuleHierarchyInfo(debug_handles, top_module_type_name);
  return std::move(trace_and_hierarchy.first);
}
std::pair<std::string, std::string> MobileDebugTable::
getSourceDebugModuleHierarchyInfo(
const std::vector<int64_t>& debug_handles,
const std::string& top_module_type_name) const {
std::vector<DebugInfoTuple> debug_infos;
bool debug_handle_not_found{false};
for (auto it = debug_handles.rbegin(); it != debug_handles.rend(); ++it) {
auto debug_handle = *it;
const auto cs_it = callstack_ptr_map_.find(debug_handle);
if (cs_it == callstack_ptr_map_.end()) {
debug_handle_not_found = true;
break;
}
debug_infos.emplace_back(cs_it->second);
}
if (debug_handle_not_found) {
std::string debug_handles_string = "debug_handles:{";
for (const auto debug_handle : debug_handles) {
debug_handles_string += std::to_string(debug_handle);
}
debug_handles_string += "}";
debug_handles_string = debugHandlesNotFoundMessage(debug_handles_string);
return {debug_handles_string, debug_handles_string};
}
return (getStackTraceWithModuleHierarchy(
debug_infos, "top", top_module_type_name));
}
} // namespace jit
} // namespace torch