Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed May 31, 2024
1 parent 59f003d commit 5a0d546
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 22 deletions.
38 changes: 36 additions & 2 deletions nncf/experimental/torch_fx/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMe
return node_type, node_metatype

@staticmethod
def _separate_conv_and_bias(model: torch.fx.GraphModule):
def separate_conv_and_bias(model: torch.fx.GraphModule):
"""
Separates one joined conv+bias node to two nodes: conv and bias.
Needed as nncf does not expect joined conv
Expand Down Expand Up @@ -122,6 +122,40 @@ def _separate_conv_and_bias(model: torch.fx.GraphModule):
model.graph.eliminate_dead_code()
model.recompile()

@staticmethod
def merge_conv_and_bias(model: torch.fx.GraphModule) -> None:
    """
    Merges a bias-less conv node followed by an in-place constant add
    into a single conv node that carries the constant as its bias.

    A conv node is rewritten only when all of the following hold:
      * it has no bias yet (third positional argument missing or None);
      * it has exactly one user and that user is `aten.add_.Tensor`;
      * the other operand of the add is a `get_attr` (constant) node.
    Any other pattern is left untouched.

    The graph is modified in place; afterwards dead code is eliminated
    and the module is recompiled.

    :param model: Graph module to transform in place.
    """
    add_node_targets = (torch.ops.aten.add_.Tensor,)
    # Iterate over a snapshot: new get_attr nodes are inserted into the graph below.
    for conv_node in list(model.graph.nodes):
        if not _is_conv(conv_node):
            continue
        conv_args = list(conv_node.args)
        if len(conv_args) < 2:
            # Input/weight are not both positional, so the bias slot
            # (index 2) cannot be filled reliably.
            continue
        if len(conv_args) > 2 and conv_args[2] is not None:
            # Conv already has a bias, nothing to merge.
            continue
        # Check the number of users before peeking at the first one:
        # a conv without users would otherwise raise StopIteration.
        if len(conv_node.users) != 1:
            continue
        bias_node = next(iter(conv_node.users))
        if bias_node.target not in add_node_targets:
            continue
        const_node = next(
            (node for node in bias_node.all_input_nodes if node is not conv_node),
            None,
        )
        # Skip adds whose second operand is not a constant (e.g. a residual
        # connection) instead of failing on them.
        if const_node is None or const_node.op != "get_attr":
            continue
        # NOTE(review): squeeze() presumably turns a broadcast-shaped bias
        # such as (1, C, 1, 1) into (C,) - confirm for single-channel convs.
        bias_value = _get_tensor_constant_from_node(const_node, model).squeeze()
        # The new constant must be defined before the conv that consumes it.
        with model.graph.inserting_before(conv_node):
            new_bias_node = create_getattr_from_value(model, model.graph, const_node.name + "_", bias_value)
        if len(conv_args) > 2:
            conv_args[2] = new_bias_node
        else:
            # Conv was called with (input, weight) only: bias is the next positional argument.
            conv_args.append(new_bias_node)
        conv_node.args = tuple(conv_args)
        # Route every consumer of the add to the conv; the add becomes dead code.
        bias_node.replace_all_uses_with(conv_node)

    model.graph.eliminate_dead_code()
    model.recompile()

@staticmethod
def create_nncf_graph(model: torch.fx.GraphModule) -> NNCFGraph:
"""
Expand All @@ -136,7 +170,7 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> NNCFGraph:
_fuse_conv_bn_(model)
# BN is fused into the conv bias; the joined conv+bias op
# needs to be split for nncf
GraphConverter._separate_conv_and_bias(model)
GraphConverter.separate_conv_and_bias(model)

nncf_graph = PTNNCFGraph()

Expand Down
16 changes: 10 additions & 6 deletions nncf/experimental/torch_fx/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,23 @@ def quantize_impl(
nncf_graph = NNCFGraphFactory.create(copied_model)
quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset)

from nncf.experimental.torch_fx.nncf_graph_builder import GraphConverter

GraphConverter.merge_conv_and_bias(quantized_model)

# Magic. Without this call the compiled model
# is not performant
model = GraphModule(model, model.graph)
quantized_model = GraphModule(quantized_model, quantized_model.graph)

model = _fold_conv_bn_qat(model)
quantized_model = _fold_conv_bn_qat(quantized_model)
pm = PassManager([DuplicateDQPass()])

model = pm(model).graph_module
quantized_model = pm(quantized_model).graph_module
pm = PassManager([PortNodeMetaForQDQ()])
model = pm(model).graph_module
quantized_model = pm(quantized_model).graph_module

model.meta.update(original_graph_meta)
model = _disallow_eval_train(model)
quantized_model.meta.update(original_graph_meta)
quantized_model = _disallow_eval_train(quantized_model)

return quantized_model

Expand Down
25 changes: 11 additions & 14 deletions torch_compile_ex_release.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_exported_model_from_nn_module(module, example_inputs):
return capture_pre_autograd_graph(module, example_inputs)


NNCF_IMPL = True
NNCF_IMPL = False


def get_qsetup(exported_model, example_inputs):
Expand Down Expand Up @@ -79,8 +79,6 @@ def get_qsetup(exported_model, example_inputs):


def quantize(model, example_inputs):
exported_model = get_exported_model_from_nn_module(model, example_inputs)

if NNCF_IMPL:
# Use NNCF here on exported model
# to create a quantized model which is compatible with
Expand All @@ -97,19 +95,18 @@ def quantize(model, example_inputs):
import nncf

calibration_dataset = nncf.Dataset(example_inputs)
exported_model = get_exported_model_from_nn_module(model, example_inputs)
quantized_model = nncf.quantize(exported_model, calibration_dataset)
g = FxGraphDrawer(quantized_model, "resnet18_quantized_native_nncf")
g.get_dot_graph().write_svg("resnet18_quantized_native_nncf.svg")
return quantized_model

else:

g = FxGraphDrawer(exported_model, "resnet18")
g.get_dot_graph().write_svg("resnet18_compiled.svg")
nncf_graph = GraphConverter.create_nncf_graph(exported_model)
del nncf_graph
# g = FxGraphDrawer(exported_model, "resnet18")
# g.get_dot_graph().write_svg("resnet18_compiled.svg")

# MOCK NNCF QUANTIZATION
exported_model = get_exported_model_from_nn_module(model, example_inputs)
qsetup = get_qsetup(exported_model, example_inputs)
exported_model = get_exported_model_from_nn_module(model, example_inputs)
exported_model = insert_qdq_to_model(exported_model, qsetup)
Expand Down Expand Up @@ -166,13 +163,13 @@ def main(model_name, num_iters):

converted_model = quantize(copy.deepcopy(model), example_inputs)

print("original model execution time: ", measure_time(model, example_inputs, num_iters))
# print("original model execution time: ", measure_time(model, example_inputs, num_iters))

native_optimized_model_fp32 = torch.compile(model)
print(
"Torch Inductor FP32 model execution time: ",
measure_time(native_optimized_model_fp32, example_inputs, num_iters),
)
# native_optimized_model_fp32 = torch.compile(model)
# print(
# "Torch Inductor FP32 model execution time: ",
# measure_time(native_optimized_model_fp32, example_inputs, num_iters),
# )

native_optimized_model_int8 = torch.compile(converted_model)
print(
Expand Down

0 comments on commit 5a0d546

Please sign in to comment.