Merge branch 'dl/fx/constant_folding' into dl/fx/tmp
daniil-lyakhov committed Oct 23, 2024
2 parents 17f9e9d + 300232b commit 260440b
Showing 26 changed files with 16,855 additions and 16,322 deletions.
273 changes: 273 additions & 0 deletions nncf/experimental/torch/fx/constant_folding.py
@@ -0,0 +1,273 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.utils._pytree as pytree

aten = torch.ops.aten


def _replace_node_with_constant(
    gm: torch.fx.GraphModule,
    node: torch.fx.Node,
    constant: torch.Tensor,
    name: Optional[str] = None,
) -> None:
    g = gm.graph

    if name:
        qualname = name
    else:
        if not hasattr(gm, "_frozen_param_count"):
            gm._frozen_param_count = 0  # type: ignore[assignment]
        i = gm._frozen_param_count

        while True:
            qualname = f"_frozen_param{i}"
            if not hasattr(gm, qualname):
                break
            i += 1

        gm._frozen_param_count = i + 1

    with g.inserting_before(node):
        new_input_node = g.create_node("get_attr", qualname, (), {})
        node.replace_all_uses_with(new_input_node)
        new_input_node.meta.update(node.meta)
        g.erase_node(node)

        # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
        gm.register_buffer(qualname, constant)
        setattr(gm, qualname, constant)


def _is_const_source(node: torch.fx.Node, lifted_constants: Optional[Dict[str, Any]]) -> bool:
    return node.op == "get_attr" or (
        node.op == "placeholder" and lifted_constants is not None and node.name in lifted_constants
    )


class _ConstantFolder(torch.fx.Interpreter):
    def __init__(
        self,
        gm: torch.fx.GraphModule,
        skip_constructors: bool = False,
        lifted_constants: Optional[Dict[str, torch.Tensor]] = None,
        skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
    ) -> None:
        super().__init__(gm)
        self.node_replacements: Dict[torch.fx.Node, Any] = {}
        self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
        self.unknown_value = object()
        self.skip_constructors: bool = skip_constructors

        # overwrite this to deallocate env values if their only remaining use
        # is the output
        self.user_to_last_uses = self.node_to_last_non_output_use()
        self.lifted_constants = lifted_constants

    def _support_dynamic_shape(self) -> bool:
        # ConstantFolder does not support dynamic shapes yet
        return False

    def _deduce_value(self, node: torch.fx.Node) -> Any:
        return super().run_node(node)

    def is_impure(self, node: torch.fx.node.Node) -> bool:
        def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
            return (
                node.target == torch.ops.prims.convert_element_type.default  # type: ignore[return-value]
                and isinstance(node.args[0], torch.fx.Node)
                and "val" in node.args[0].meta
                and node.args[0].meta["val"].dtype == torch.int8  # type: ignore[union-attr]
                and node.args[1] == torch.bfloat16
            )

        if (
            is_woq_int8_pattern(node)
            or (
                node.target == torch.ops.aten.permute.default
                and len(node.users) == 1
                and is_woq_int8_pattern(next(iter(node.users)))
            )
        ) and _is_const_source(
            node.args[0], self.lifted_constants  # type: ignore[arg-type]
        ):
            # Case 1: int8_weight -> dq -> bf16_weight
            # Case 2: int8_weight -> permute -> dq -> bf16_weight
            return True

        quant_registered = getattr(torch.ops.quantized_decomposed, "dequantize_per_channel", None) is not None
        if quant_registered and node.target in [
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
        ]:
            # For the pattern fp32_weight -> q -> dq, we only fold fp32_weight -> q
            # into int8_weight and leave dq in the graph to be fused
            return True
        return False

    def node_to_last_non_output_use(self) -> Dict[torch.fx.Node, List[torch.fx.Node]]:
        last_non_output_use = collections.defaultdict(list)
        seen_uses = set()
        output_node = next(iter(reversed(self.module.graph.nodes)))

        for node in reversed(self.module.graph.nodes):
            if node.target == "output":
                continue

            def add_use(inp: torch.fx.Node) -> None:
                if inp in seen_uses:
                    return

                seen_uses.add(inp)
                last_non_output_use[node].append(inp)

            # In-place is fine since we don't mutate
            pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs))

            # if this node is only used in the output, we want to gc it right away
            if len(node.users) == 1 and output_node in node.users:
                last_non_output_use[node].append(node)

        return last_non_output_use

    def run_node(self, node: torch.fx.Node) -> Any:
        if node.target == "output":
            # Because we remove nodes from env on their last non-output use,
            # re-define them now or we'll get an error in the interpreter
            def set_env(arg: torch.fx.Node) -> None:
                self.env[arg] = self.unknown_value

            # In-place is fine since we don't mutate
            pytree.tree_map_only_(torch.fx.Node, set_env, node.args)
            return super().run_node(node)

        args, kwargs = self.fetch_args_kwargs_from_env(node)
        flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)

        # We need to do this weird thing because in cases where flattened_inputs
        # contains a ScriptObject, equality checking results in a type error if
        # the types are different.
        if any(
            type(self.unknown_value) is type(input_) and self.unknown_value == input_ for input_ in flattened_inputs
        ):
            return self.unknown_value

        # TODO - fix errors with this
        if node.op == "call_function" and node.target == aten._efficientzerotensor.default:
            return self.unknown_value

        # TODO - constant folding triton kernel returns the inputs -- fix this
        if node.op == "call_function" and node.name == "triton_kernel_wrapper_functional_proxy":
            return self.unknown_value

        # skip constructors, since inductor generates optimal code for them already
        # and turning into tensor would result in an additional global memory read
        # TODO - more complicated strategy
        if (
            self.skip_constructors
            and not _is_const_source(node, self.lifted_constants)
            and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
        ):
            return self.unknown_value

        # All mutations should either be removed or on inputs which we did not make constant
        if isinstance(node.target, torch._ops.OpOverload) and torch.Tag.nondeterministic_seeded in node.target.tags:
            return self.unknown_value

        out = self._deduce_value(node)
        if out == self.unknown_value:
            return self.unknown_value

        if not _is_const_source(node, self.lifted_constants) and isinstance(out, torch.Tensor):
            if out.device.type == "meta":
                return out

            if not self.insertable_tensor_check(out):
                return out

            if self.is_impure(node):
                return self.unknown_value

            self.add_node_replacement(node, out)

            flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)

            for n in flattened_node_inps:
                if not isinstance(n, torch.fx.Node):
                    continue

                self.replaced_uses[n] += 1

            for to_delete in self.user_to_last_uses.get(node, []):
                if self.replaced_uses[to_delete] == len(to_delete.users):
                    self.node_replacements.pop(to_delete, None)

        return out

    def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
        return True

    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
        self.node_replacements[node] = tensor

    def run(self) -> Any:  # type: ignore[override]
        env: Dict[torch.fx.Node, Any] = {}
        self.insert_placerholder_values(env)
        return super().run(initial_env=env)

    def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
        for n in self.module.graph.find_nodes(op="placeholder"):
            if self.lifted_constants is not None and n.name in self.lifted_constants:
                env[n] = self.lifted_constants[n.name]
            else:
                env[n] = self.unknown_value  # type: ignore[assignment]


def constant_fold(
    gm: torch.fx.GraphModule,
    constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> None:
    """
    Calculates the values of constant subgraphs and replaces them with constant nodes in place.

    :param gm: Given graph module.
    :param constraint_fn: Constraint function which takes a node and returns
        whether the node should be constant folded.
    """
    with torch.utils._python_dispatch._disable_current_modes():
        cf = _ConstantFolder(gm, skip_constructors=True)
        cf.run()

        for node, constant in cf.node_replacements.items():
            if constraint_fn is not None and not constraint_fn(node):
                continue
            _replace_node_with_constant(gm, node, constant)

        erased_params = []
        for node in gm.graph.find_nodes(op="get_attr"):
            if len(node.users) == 0:
                if hasattr(gm, node.target):
                    delattr(gm, node.target)
                erased_params.append(node)

        for node in erased_params:
            gm.graph.erase_node(node)

        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()
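For reference, a minimal usage sketch of the new constant_fold helper on an exported module — the toy model, input shape, and the torch.export.export(...).module() entry point are illustrative assumptions, not part of this commit:

import torch

from nncf.experimental.torch.fx.constant_folding import constant_fold


class ToyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3, 3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # `self.weight * 2` depends only on constants, so the folder can
        # precompute it and replace the subgraph with a single get_attr node.
        return x @ (self.weight * 2)


gm = torch.export.export(ToyModel(), (torch.ones(1, 3),)).module()
constant_fold(gm)  # replaces `weight * 2` with a frozen buffer, in place
print(gm.code)

Placeholders are seeded with an unknown-value sentinel, so anything that depends on graph inputs is left untouched.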
26 changes: 22 additions & 4 deletions nncf/experimental/torch/fx/transformations.py
@@ -10,7 +10,6 @@
# limitations under the License.

from copy import copy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
@@ -23,6 +22,7 @@
import nncf.torch
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.experimental.torch.fx.constant_folding import constant_fold
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
from nncf.torch.graph.transformations.commands import PTTargetPoint
@@ -642,13 +642,17 @@ def _compress_qdq_constant_transformation(model: torch.fx.GraphModule, matches)
    for match in matches:
        mul_node = match.replacements[0]
        sub_node = match.replacements[1]
        weight_node, scale_node, zp_node, axis = None, None, None, None
        nodes_map = {node.name: match.nodes_map[node] for node in match.nodes_map}
        get_const = partial(get_tensor_constant_from_node, model=model)

        def get_const(arg: Optional[Union[torch.fx.Node, float, int]]):
            if isinstance(arg, torch.fx.Node):
                return get_tensor_constant_from_node(arg, model)
            return arg

        weight_node = get_const(nodes_map["weight"])
        scale_node = get_const(nodes_map["scale"])
        zp_node = get_const(nodes_map["zero_point"])
        axis = nodes_map["axis"]
        axis = get_const(nodes_map.get("axis"))
        port_id = 0
        if axis is not None:
            result = torch.ops.quantized_decomposed.quantize_per_channel.default(
@@ -761,10 +765,24 @@ def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
    # to make it easier for algorithms to work
    # with the target graph BatchNorm operations
    # are being fused
    fold_constant_except_qdq(model)
    fuse_conv_bn(model)
    shared_constants_unification_transformation(model)


def fold_constant_except_qdq(model: torch.fx.GraphModule):
    """
    Performs constant folding while avoiding quantize-dequantize patterns.

    :param model: Model to perform constant folding on.
    """

    def constraint_fn(node: torch.fx.Node):
        return node.op != "call_function" or node.target not in QUANTIZE_NODE_TARGETS + DEQUANTIZE_NODE_TARGETS

    constant_fold(model, constraint_fn=constraint_fn)


def revert_quantization_transformations(model: torch.fx.GraphModule) -> None:
"""
Reverts quantization transformations from the model.
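constant_fold's constraint_fn gives callers a veto over individual nodes; fold_constant_except_qdq above uses it to keep quantize/dequantize ops intact. A minimal sketch of a custom constraint — the "keep_in_graph" meta key is a hypothetical marker, not an NNCF convention:

import torch

from nncf.experimental.torch.fx.constant_folding import constant_fold


def keep_marked_nodes(node: torch.fx.Node) -> bool:
    # True -> the node may be folded into a constant;
    # False -> it stays in the graph even if its value is constant.
    return not node.meta.get("keep_in_graph", False)  # hypothetical marker key


class ScaledAdd(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("scale", torch.full((3,), 2.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * (self.scale + 1)  # `scale + 1` is a foldable constant subgraph


gm = torch.export.export(ScaledAdd(), (torch.ones(3),)).module()
constant_fold(gm, constraint_fn=keep_marked_nodes)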
3 changes: 3 additions & 0 deletions tests/post_training/pipelines/base.py
@@ -418,6 +418,9 @@ def save_compressed_model(self) -> None:
            ov.serialize(ov_model, self.path_compressed_ir)
        elif self.backend in OV_BACKENDS:
            self.path_compressed_ir = self.output_model_dir / "model.xml"
            from openvino._offline_transformations import apply_moc_transformations

            apply_moc_transformations(self.compressed_model, cf=True)
            ov.serialize(self.compressed_model, str(self.path_compressed_ir))

    def get_num_compressed(self) -> None:
16 changes: 11 additions & 5 deletions tests/post_training/pipelines/image_classification_torchvision.py
@@ -37,6 +37,7 @@ def _export_graph_module(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch
class VisionModelParams:
    weights: models.WeightsEnum
    export_fn: Callable[[torch.nn.Module, Tuple[Any, ...]], torch.fx.GraphModule]
    export_torch_before_ov_convert: bool = False


class ImageClassificationTorchvision(ImageClassificationBase):
@@ -47,8 +48,12 @@ class ImageClassificationTorchvision(ImageClassificationBase):
        models.mobilenet_v3_small: VisionModelParams(
            models.MobileNet_V3_Small_Weights.DEFAULT, _capture_pre_autograd_module
        ),
        models.vit_b_16: VisionModelParams(models.ViT_B_16_Weights.DEFAULT, _export_graph_module),
        models.swin_v2_s: VisionModelParams(models.Swin_V2_S_Weights.DEFAULT, _export_graph_module),
        models.vit_b_16: VisionModelParams(
            models.ViT_B_16_Weights.DEFAULT, _export_graph_module, export_torch_before_ov_convert=True
        ),
        models.swin_v2_s: VisionModelParams(
            models.Swin_V2_S_Weights.DEFAULT, _export_graph_module, export_torch_before_ov_convert=True
        ),
    }

    def __init__(self, *args, **kwargs):
@@ -92,9 +97,10 @@ def prepare_model(self) -> None:

        elif self.backend in [BackendType.OV, BackendType.FP32]:
            with torch.no_grad():
                with disable_patching():
                    m = torch.export.export(model, args=(self.dummy_tensor,))
                self.model = ov.convert_model(m, example_input=self.dummy_tensor, input=self.input_size)
                if self.model_params.export_torch_before_ov_convert:
                    with disable_patching():
                        model = torch.export.export(model, (self.dummy_tensor,))
                self.model = ov.convert_model(model, example_input=self.dummy_tensor, input=self.input_size)
            self.input_name = list(inp.get_any_name() for inp in self.model.inputs)[0]

        self._dump_model_fp32()
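The new export_torch_before_ov_convert flag makes the pipeline run torch.export.export before handing ViT and Swin models to ov.convert_model. Outside the pipeline, a rough standalone equivalent of that path — model choice, input shape, and output path are illustrative, and the pipeline's disable_patching context is omitted:

import openvino as ov
import torch
from torchvision import models

vit = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT).eval()
dummy = torch.ones((1, 3, 224, 224))

with torch.no_grad():
    exported = torch.export.export(vit, (dummy,))  # export to a torch.fx graph first
ov_model = ov.convert_model(exported, example_input=dummy, input=(1, 3, 224, 224))
ov.save_model(ov_model, "vit_b_16.xml")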
