Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Oct 23, 2024
1 parent 40dfa8f commit 17f9e9d
Show file tree
Hide file tree
Showing 23 changed files with 16,454 additions and 16,380 deletions.
31 changes: 28 additions & 3 deletions nncf/experimental/torch/fx/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,14 @@ def _apply_model_extraction(
def remap_fn(node: torch.fx.Node):
return value_remap.get(node) # noqa F821

visited_outputs_names = []
for node in model.graph.nodes:
if node.name not in visited or node.op == "output":
if node.name not in visited:
continue
if node.op == "output":
visited_outputs_names.append(node.name)
continue
value_remap[node] = extracted_graph.node_copy(node, remap_fn)
del value_remap

for input_name in transformation.input_node_names:
node_with_input = get_graph_node_by_name(extracted_graph, input_name)
Expand All @@ -146,7 +149,29 @@ def remap_fn(node: torch.fx.Node):
args[0] = graph_input
node_with_input.args = tuple(args)

nodes_with_output = [get_graph_node_by_name(extracted_graph, name) for name in transformation.output_node_names]
# Merge new output with the original output in case
# the original output is requested in the extracted graph.
nodes_with_output = []
for name in transformation.output_node_names:
nodes_with_output.append(
name if name in visited_outputs_names else get_graph_node_by_name(extracted_graph, name)
)

for idx, node in enumerate(nodes_with_output):
if isinstance(node, torch.fx.Node):
continue
output_node = get_graph_node_by_name(model.graph, node)
args = output_node.args[0]
if isinstance(args, torch.fx.Node):
args = value_remap[args]
else:
args = [value_remap[n] for n in args]
# Unpack target output args in case
# only one arg is presented.
if len(args) == 1:
args = args[0]
nodes_with_output[idx] = args

last_node = list(extracted_graph.nodes)[-1]
with extracted_graph.inserting_after(last_node):
graph_output_name = "output"
Expand Down
15 changes: 7 additions & 8 deletions nncf/experimental/torch/fx/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,10 @@ def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
:return: True if the node has a bias, False otherwise.
"""
# Assumes that all biases were unfused
if node.metatype in FX_OPERATORS_WITH_BIAS_METATYPES:
next_nodes = nncf_graph.get_next_nodes(node)
if len(next_nodes) != 1:
return False
return next_nodes[0].metatype in (om.PTAddMetatype,)
if node.metatype not in FX_OPERATORS_WITH_BIAS_METATYPES or len(nncf_graph.get_input_edges(node)) != 3:
return False
const_node = nncf_graph.get_input_edge_by_port_id(node, 2).from_node
return const_node.metatype is om.PTConstNoopMetatype


def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphModule) -> Tensor:
Expand All @@ -82,7 +81,7 @@ def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphM
:param model: Target GraphModule.
:return: Bias value of the given node.
"""
bias_node = nncf_graph.get_next_nodes(node)[0]
bias_node = nncf_graph.get_input_edge_by_port_id(node, 2).from_node
# TODO(dlyakhov): make a node_name_vs_node map to speed up the process
graph_bias_node = get_graph_node_by_name(model.graph, bias_node.node_name)
return Tensor(get_tensor_constant_from_node(graph_bias_node.all_input_nodes[1], model))
graph_bias_const = get_graph_node_by_name(model.graph, bias_node.node_name)
return Tensor(get_tensor_constant_from_node(graph_bias_const, model))
182 changes: 1 addition & 181 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,33 +160,6 @@ def leaf_module_insertion_transformation(model: torch.fx.GraphModule):
return leaf_module_insertion_transformation


def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor, input_port_id: int) -> TransformationFNType:
"""
Return transformation which updates constant of the given node with bias to the given value.
:param node: Node with bias which requires bias constant update.
:param value: New value to use as the bias constant.
:param input_port_id: Input port id to get constant node from.
:return: Transformation which updates constant of the given node with bias to the given value.
"""

def bias_update_transformation(model: torch.fx.GraphModule):
graph = model.graph
target_node_name = node.node_name
graph_node = get_graph_node_by_name(graph, target_node_name)
add_nodes = []
for user in graph_node.users:
if _is_add(user):
add_nodes.append(user)
if len(add_nodes) != 1:
raise nncf.InternalError(f"Node {graph_node.name} has {len(add_nodes)} outputs with adds, 1 expected")

bias_node = add_nodes[0]
constant_update_fn(model, bias_node, value, input_port_id=input_port_id)

return bias_update_transformation


def shared_constants_unification_transformation(model: torch.fx.GraphModule):
"""
checks FX graph for shared constants and eliminates redundant
Expand Down Expand Up @@ -789,8 +762,6 @@ def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
# with the target graph BatchNorm operations
# are being fused
fuse_conv_bn(model)
separate_conv_and_bias(model)
separate_linear_and_bias(model)
shared_constants_unification_transformation(model)


Expand All @@ -800,8 +771,7 @@ def revert_quantization_transformations(model: torch.fx.GraphModule) -> None:
:param model: Model to revert transformations from.
"""
merge_conv_and_bias(model)
merge_linear_and_bias(model)
pass


def _is_linear(n: torch.fx.Node) -> bool:
Expand All @@ -823,153 +793,3 @@ def _is_conv(n: torch.fx.Node):
torch.ops.aten.conv2d.default,
torch.ops.aten.conv_transpose2d.input,
)


def _is_add(n: torch.fx.Node):
"""
Return whether the node refers to an aten add op.
"""
return n.op == "call_function" and n.target in (
torch.ops.aten.add_.Tensor,
torch.ops.aten.add.Tensor,
)


def separate_linear_and_bias(model: torch.fx.GraphModule):
"""
Separates one joined linear+bias node to two nodes: conv and bias.
Needed as nncf does not expect joined conv
:param model: Target model.
"""
add_node_target = torch.ops.aten.add.Tensor
for n in model.graph.nodes:
if not _is_linear(n):
continue
# This check also makes sure to ignore linear nodes which might already
# have quantization applied to the weights.
if len(n.args) < 3 or n.args[2] is None or n.args[1].op != "get_attr":
continue
linear_node = n
linear_bias_node = linear_node.args[2]
while linear_bias_node.op != "get_attr":
# Assume zero argument is on a path to the constant
linear_bias_node = linear_bias_node.args[0]
linear_bias_value = get_tensor_constant_from_node(linear_bias_node, model)
args = list(n.args)
args[2] = None
linear_node.args = tuple(args)
with model.graph.inserting_after(linear_node):
new_linear_bias_node = create_getattr_from_value(
model,
model.graph,
linear_bias_node.name + "_",
linear_bias_value,
)
with model.graph.inserting_after(new_linear_bias_node):
add_node = model.graph.create_node(
"call_function", add_node_target, (linear_node, new_linear_bias_node), {}
)
for user in list(linear_node.users):
if user is add_node:
continue
user.replace_input_with(linear_node, add_node)
if "val" in linear_node.meta:
add_node.meta["val"] = linear_node.meta["val"]
model.graph.eliminate_dead_code()
model.recompile()


def separate_conv_and_bias(model: torch.fx.GraphModule):
"""
Separates one joined conv+bias node to two nodes: conv and bias.
Needed as nncf does not expect joined conv
:param model: Target model.
"""
add_node_target = torch.ops.aten.add_.Tensor
for n in model.graph.nodes:
if not _is_conv(n):
continue
# This check also makes sure to ignore convolution nodes which might
# already have quantization applied to the weights.
if len(n.args) < 3 or n.args[2] is None or n.args[1].op != "get_attr":
continue
conv_node = n
dims = len(get_tensor_constant_from_node(conv_node.args[1], model).shape)
conv_bias_node = conv_node.args[2]
conv_bias_value = get_tensor_constant_from_node(conv_bias_node, model)
args = list(n.args)
args[2] = None
conv_node.args = tuple(args)
with model.graph.inserting_after(conv_node):
new_conv_bias_node = create_getattr_from_value(
model, model.graph, conv_bias_node.name + "_", conv_bias_value.reshape((1, -1) + (1,) * (dims - 2))
)
with model.graph.inserting_after(new_conv_bias_node):
add_node = model.graph.create_node("call_function", add_node_target, (conv_node, new_conv_bias_node), {})
for user in list(conv_node.users):
if user is add_node:
continue
user.replace_input_with(conv_node, add_node)

if "val" in conv_node.meta:
add_node.meta["val"] = conv_node.meta["val"]
model.graph.eliminate_dead_code()
model.recompile()


def merge_conv_and_bias(model: torch.fx.GraphModule):
"""
Merges two separate conv and bias nodes to a one node: conv+bias.
Needed as nncf does not expect joined conv
:param model: Target model.
"""
_merge_node_and_bias(model, _is_conv)


def merge_linear_and_bias(model: torch.fx.GraphModule):
"""
Merges two separate linear and bias nodes to a one node: linear+bias.
:param model: Target model.
"""
_merge_node_and_bias(model, _is_linear)


def _merge_node_and_bias(model: torch.fx.GraphModule, is_target_node: Callable[[torch.fx.Node], bool]):
"""
Merges two separate node and bias node to a one node: node+bias.
Check which node should be merged by the given `is_target_node` predicate.
:param model: Target model.
:param is_target_node: Predicate to specify nodes which should be merged with the bias
"""
add_node_targets = (torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor)
for n in model.graph.nodes:
if not is_target_node(n):
continue
if len(n.args) > 2 and n.args[2] is not None:
continue
bias_node = next(iter(n.users))
if len(n.users) > 1 or bias_node.target not in add_node_targets:
continue
conv_node = n
const_node = None
for node in bias_node.all_input_nodes:
if node is not conv_node:
const_node = node
break
assert const_node is not None
bias_value = get_tensor_constant_from_node(const_node, model).squeeze()
with model.graph.inserting_before(conv_node):
new_bias_node = create_getattr_from_value(model, model.graph, const_node.name + "_", bias_value)
args = list(conv_node.args)
args[2] = new_bias_node
conv_node.args = tuple(args)
for user in list(bias_node.users):
user.replace_input_with(bias_node, conv_node)

model.graph.eliminate_dead_code()
model.recompile()
10 changes: 8 additions & 2 deletions nncf/quantization/algorithms/bias_correction/torch_fx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from nncf.experimental.torch.fx.node_utils import get_bias_value
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
from nncf.experimental.torch.fx.node_utils import is_node_with_bias
from nncf.experimental.torch.fx.transformations import bias_update_transformation_builder
from nncf.experimental.torch.fx.transformations import constant_update_transformation_builder
from nncf.experimental.torch.fx.transformations import output_insertion_transformation_builder
from nncf.quantization.algorithms.bias_correction.backend import BiasCorrectionAlgoBackend
from nncf.tensor import Tensor
Expand All @@ -45,7 +45,9 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) -
def create_bias_correction_command(
node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph
) -> FXApplyTransformationCommand:
return FXApplyTransformationCommand(bias_update_transformation_builder(node, bias_value.data, input_port_id=1))
return FXApplyTransformationCommand(
constant_update_transformation_builder(node, bias_value.data, input_port_id=2)
)

@staticmethod
def model_extraction_command(
Expand Down Expand Up @@ -90,6 +92,10 @@ def get_input_name(model: torch.fx.GraphModule, node_name: str, input_port_id: i
@staticmethod
def get_output_name(model: torch.fx.GraphModule, node_name: str, output_port_id: int) -> int:
graph_node = get_graph_node_by_name(model.graph, node_name)
if graph_node.op == "output":
# Original node output is kept as the first
# output tensor, thus returns 0.
return 0
nodes = list(graph_node.users)
while nodes:
node = nodes.pop()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from nncf.experimental.torch.fx.model_utils import get_target_point
from nncf.experimental.torch.fx.node_utils import get_bias_value
from nncf.experimental.torch.fx.node_utils import is_node_with_bias
from nncf.experimental.torch.fx.transformations import bias_update_transformation_builder
from nncf.experimental.torch.fx.transformations import constant_update_transformation_builder
from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend
from nncf.tensor import Tensor
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
Expand All @@ -41,7 +41,9 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) -
def create_bias_correction_command(
node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph
) -> FXApplyTransformationCommand:
return FXApplyTransformationCommand(bias_update_transformation_builder(node, bias_value.data, input_port_id=1))
return FXApplyTransformationCommand(
constant_update_transformation_builder(node, bias_value.data, input_port_id=2)
)

@staticmethod
def model_extraction_command(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
strict digraph {
"0 _param_constant0" [id=0, type=get_attr];
"1 add" [id=1, type=add];
"2 conv2d_input" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 output" [id=4, type=output];
"0 _param_constant0" -> "1 add" [label="(1, 1, 1, 1)", style=solid];
"1 add" -> "3 conv2d" [label="(1, 1, 1, 1)", style=solid];
"2 conv2d_input" -> "3 conv2d" [label=None, style=solid];
"3 conv2d" -> "4 output" [label="(1, 1, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
strict digraph {
"0 _param_constant0" [id=0, type=get_attr];
"1 add" [id=1, type=add];
"2 conv2d_input" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 output" [id=4, type=output];
"0 _param_constant0" -> "1 add" [label="(1, 1, 1, 1)", style=solid];
"1 add" -> "3 conv2d" [label="(1, 1, 1, 1)", style=solid];
"2 conv2d_input" -> "3 conv2d" [label=None, style=solid];
"3 conv2d" -> "4 output" [label="(1, 1, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
strict digraph {
"0 _param_constant0" [id=0, type=get_attr];
"1 _param_constant1" [id=1, type=get_attr];
"2 conv2d_input" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 add" [id=4, type=add];
"5 output" [id=5, type=output];
"0 _param_constant0" -> "3 conv2d" [label="(1, 1, 1, 1)", style=solid];
"1 _param_constant1" -> "3 conv2d" [label="(1,)", style=solid];
"2 conv2d_input" -> "3 conv2d" [label=None, style=solid];
"3 conv2d" -> "4 add" [label="(1, 1, 3, 3)", style=solid];
"3 conv2d" -> "5 output" [label="(1, 1, 3, 3)", style=solid];
"4 add" -> "5 output" [label="(1, 1, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
strict digraph {
"0 _param_constant0" [id=0, type=get_attr];
"1 _param_constant1" [id=1, type=get_attr];
"2 conv2d_input" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 add" [id=4, type=add];
"5 output" [id=5, type=output];
"0 _param_constant0" -> "3 conv2d" [label="(1, 1, 1, 1)", style=solid];
"1 _param_constant1" -> "3 conv2d" [label="(1,)", style=solid];
"2 conv2d_input" -> "3 conv2d" [label=None, style=solid];
"3 conv2d" -> "4 add" [label="(1, 1, 3, 3)", style=solid];
"3 conv2d" -> "5 output" [label="(1, 1, 3, 3)", style=solid];
"4 add" -> "5 output" [label="(1, 1, 3, 3)", style=solid];
}
Loading

0 comments on commit 17f9e9d

Please sign in to comment.