forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_model_dump.py
240 lines (197 loc) · 7.75 KB
/
test_model_dump.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]
import functools
import io
import os
import sys
import tempfile
import unittest
import urllib
import urllib.parse  # `import urllib` alone does not load the parse submodule.

import torch
import torch.utils.model_dump
import torch.utils.mobile_optimizer
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
from torch.testing._internal.common_quantized import supported_qengines
class SimpleModel(torch.nn.Module):
    """Small MLP: Linear(16 -> 64) + ReLU followed by Linear(64 -> 8) + ReLU."""

    def __init__(self):
        super().__init__()
        # Children are created in this exact order; the names are referenced
        # elsewhere by string, so keep them stable.
        self.layer1 = torch.nn.Linear(in_features=16, out_features=64)
        self.relu1 = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(in_features=64, out_features=8)
        self.relu2 = torch.nn.ReLU()

    def forward(self, features):
        """Apply layer1 -> relu1 -> layer2 -> relu2 to ``features``."""
        hidden = self.relu1(self.layer1(features))
        return self.relu2(self.layer2(hidden))
class QuantModel(torch.nn.Module):
    """A ``SimpleModel`` bracketed by ``QuantStub`` / ``DeQuantStub``.

    The stubs give ``torch.ao.quantization``'s prepare/convert flow explicit
    entry and exit points around the float core.
    """

    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.core = SimpleModel()

    def forward(self, x):
        """Run ``x`` through quant -> core -> dequant."""
        return self.dequant(self.core(self.quant(x)))
class ModelWithLists(torch.nn.Module):
    """Module whose state is held in plain Python lists of tensors.

    ``rt`` contains only tensors, while ``ot`` also contains ``None``, so
    its entries must be treated as possibly-None.
    """

    def __init__(self):
        super().__init__()
        self.rt = [torch.zeros(1)]
        # Mixed list: one tensor followed by a None entry.
        self.ot = [torch.zeros(1), None]

    def forward(self, arg):
        """Return ``arg + rt[0]``, plus ``ot[0]`` when that entry is not None."""
        total = arg + self.rt[0]
        # Bind to a local first so the None check can narrow its type
        # when the module is scripted.
        bonus = self.ot[0]
        if bonus is not None:
            total = total + bonus
        return total
def webdriver_test(testfunc):
    """Decorate a test method so it runs once per Selenium browser.

    The wrapped method is called as ``testfunc(self, wd, *args, **kwds)``,
    where ``wd`` is a freshly started WebDriver (Firefox, then Chrome), each
    inside its own ``subTest``. The instance must provide
    ``needs_resources()``. The test is skipped unless the ``RUN_WEBDRIVER``
    environment variable equals ``"1"``.
    """
    @functools.wraps(testfunc)
    def wrapper(self, *args, **kwds):
        self.needs_resources()
        if os.environ.get("RUN_WEBDRIVER") != "1":
            self.skipTest("Webdriver not requested")
        # Imported lazily so selenium is only required when webdriver tests
        # were explicitly requested.
        from selenium import webdriver

        for driver in [
            "Firefox",
            "Chrome",
        ]:
            with self.subTest(driver=driver):
                wd = getattr(webdriver, driver)()
                try:
                    testfunc(self, wd, *args, **kwds)
                finally:
                    # Always shut the browser down, even when the test fails
                    # (subTest would otherwise swallow the error and move on,
                    # leaking this driver). quit() ends the whole session;
                    # close() only closes the current window.
                    wd.quit()

    return wrapper
class TestModelDump(TestCase):
    """Tests for ``torch.utils.model_dump``.

    Most tests only check that model info can be extracted without error.
    Tests decorated with ``webdriver_test`` load the generated HTML page in a
    real browser and inspect the rendered output.
    """

    def needs_resources(self):
        """Skip the current test when ``importlib.resources`` is unavailable."""
        if sys.version_info < (3, 7):
            self.skipTest("importlib.resources was new in 3.7")

    def test_inline_skeleton(self):
        self.needs_resources()
        skel = torch.utils.model_dump.get_inline_skeleton()
        # The inline skeleton must not reference external resources.
        # unittest assertions (unlike bare ``assert``) still run under -O.
        self.assertNotIn("unpkg.org", skel)
        self.assertNotIn("src=", skel)

    def do_dump_model(self, model, extra_files=None):
        """Serialize ``model`` in memory and check that model info is produced.

        Args:
            model: A scripted or traced module accepted by ``torch.jit.save``.
            extra_files: Optional mapping forwarded to ``torch.jit.save`` as
                ``_extra_files``.
        """
        # Just check that we're able to run successfully.
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        info = torch.utils.model_dump.get_model_info(buf)
        self.assertIsNotNone(info)

    def open_html_model(self, wd, model, extra_files=None):
        """Render ``model`` as a self-contained HTML page and load it in ``wd``."""
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        page = torch.utils.model_dump.get_info_and_burn_skeleton(buf)
        # Load the page through a data: URL, so no server or temp file is needed.
        wd.get("data:text/html;charset=utf-8," + urllib.parse.quote(page))

    def open_section_and_get_body(self, wd, name):
        """Expand the section titled ``name`` if needed and return its body element."""
        # find_element(By..., ...) works on Selenium 3 and 4; the
        # find_element_by_* helpers were removed in Selenium 4.3.
        from selenium.webdriver.common.by import By

        container = wd.find_element(By.XPATH, f"//div[@data-hider-title='{name}']")
        caret = container.find_element(By.CLASS_NAME, "caret")
        if container.get_attribute("data-shown") != "true":
            caret.click()
        content = container.find_element(By.TAG_NAME, "div")
        return content

    def test_scripted_model(self):
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model)

    def test_traced_model(self):
        model = torch.jit.trace(SimpleModel(), torch.zeros(2, 16))
        self.do_dump_model(model)

    def test_main(self):
        self.needs_resources()
        if IS_WINDOWS:
            # I was getting tempfile errors in CI. Just skip it.
            self.skipTest("Disabled on Windows.")

        with tempfile.NamedTemporaryFile() as tf:
            torch.jit.save(torch.jit.script(SimpleModel()), tf)
            # main() reads the model back through tf.name, so push any bytes
            # still sitting in the file object's buffer out to disk first.
            tf.flush()

            stdout = io.StringIO()
            torch.utils.model_dump.main(
                [
                    None,
                    "--style=json",
                    tf.name,
                ],
                stdout=stdout)
            self.assertRegex(stdout.getvalue(), r'\A{.*SimpleModel')

            stdout = io.StringIO()
            torch.utils.model_dump.main(
                [
                    None,
                    "--style=html",
                    tf.name,
                ],
                stdout=stdout)
            self.assertRegex(
                stdout.getvalue().replace("\n", " "),
                r'\A<!DOCTYPE.*SimpleModel.*componentDidMount')

    def get_quant_model(self):
        """Return a fused, calibrated and converted ``QuantModel`` (qnnpack qconfig)."""
        fmodel = QuantModel().eval()
        fmodel = torch.ao.quantization.fuse_modules(fmodel, [
            ["core.layer1", "core.relu1"],
            ["core.layer2", "core.relu2"],
        ])
        fmodel.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        prepped = torch.ao.quantization.prepare(fmodel)
        # Feed one random batch through the prepared model before converting.
        prepped(torch.randn(2, 16))
        qmodel = torch.ao.quantization.convert(prepped)
        return qmodel

    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_quantized_model(self):
        qmodel = self.get_quant_model()
        self.do_dump_model(torch.jit.script(qmodel))

    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_optimized_quantized_model(self):
        qmodel = self.get_quant_model()
        smodel = torch.jit.trace(qmodel, torch.zeros(2, 16))
        omodel = torch.utils.mobile_optimizer.optimize_for_mobile(smodel)
        self.do_dump_model(omodel)

    def test_model_with_lists(self):
        model = torch.jit.script(ModelWithLists())
        self.do_dump_model(model)

    def test_invalid_json(self):
        # A truncated JSON extra file must not make the dump fail.
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model, extra_files={"foo.json": "{"})

    @webdriver_test
    def test_memory_computation(self, wd):
        from selenium.webdriver.common.by import By

        def check_memory(model, expected):
            """Load ``model`` in the browser and check the first Tensor Memory row."""
            self.open_html_model(wd, model)
            memory_table = self.open_section_and_get_body(wd, "Tensor Memory")
            # The leading "." keeps the XPath relative to the section body; a
            # bare "//" would search the whole document from its root.
            device = memory_table.find_element(By.XPATH, ".//table/tbody/tr[1]/td[1]").text
            self.assertEqual("cpu", device)
            memory_usage_str = memory_table.find_element(By.XPATH, ".//table/tbody/tr[1]/td[2]").text
            self.assertEqual(expected, int(memory_usage_str))

        simple_model_memory = (
            # First layer, including bias.
            64 * (16 + 1) +
            # Second layer, including bias.
            8 * (64 + 1)
            # 32-bit float
        ) * 4

        check_memory(torch.jit.script(SimpleModel()), simple_model_memory)

        # The same SimpleModel instance appears twice in this model.
        # The tensors will be shared, so ensure no double-counting.
        a_simple_model = SimpleModel()
        check_memory(
            torch.jit.script(
                torch.nn.Sequential(a_simple_model, a_simple_model)),
            simple_model_memory)

        # The freezing process will move the weight and bias
        # from data to constants. Ensure they are still counted.
        check_memory(
            torch.jit.freeze(torch.jit.script(SimpleModel()).eval()),
            simple_model_memory)

        # Make sure we can handle a model with both constants and data tensors.
        class ComposedModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w1 = torch.zeros(1, 2)
                self.w2 = torch.ones(2, 2)

            def forward(self, arg):
                return arg * self.w2 + self.w1

        check_memory(
            torch.jit.freeze(
                torch.jit.script(ComposedModule()).eval(),
                preserved_attrs=["w1"]),
            # w1 has 2 elements and w2 has 4, each 4 bytes.
            4 * (2 + 4))
if __name__ == "__main__":
    # Hand control to PyTorch's shared test runner when executed directly.
    run_tests()