Dl/fx/experimental quantization conformance #33

Open · wants to merge 15 commits into base: develop
[TorchFX] Constant folding check
daniil-lyakhov committed Oct 21, 2024
commit 20b640cfb012129ac68c12052ab9cd25e151e6f2
266 changes: 266 additions & 0 deletions nncf/experimental/torch/fx/constant_folding.py
@@ -0,0 +1,266 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.utils._pytree as pytree

aten = torch.ops.aten


def replace_node_with_constant(
    gm: torch.fx.GraphModule,
    node: torch.fx.Node,
    constant: torch.Tensor,
    name: Optional[str] = None,
) -> None:
    g = gm.graph

    if name:
        qualname = name
    else:
        if not hasattr(gm, "_frozen_param_count"):
            gm._frozen_param_count = 0  # type: ignore[assignment]
        i = gm._frozen_param_count

        while True:
            qualname = f"_frozen_param{i}"
            if not hasattr(gm, qualname):
                break
            i += 1

        gm._frozen_param_count = i + 1

    with g.inserting_before(node):
        new_input_node = g.create_node("get_attr", qualname, (), {})
        node.replace_all_uses_with(new_input_node)
        new_input_node.meta.update(node.meta)
        g.erase_node(node)

    # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
    gm.register_buffer(qualname, constant)
    setattr(gm, qualname, constant)


def is_const_source(node: torch.fx.Node, lifted_constants: Optional[Dict[str, Any]]) -> bool:
    return node.op == "get_attr" or (
        node.op == "placeholder" and lifted_constants is not None and node.name in lifted_constants
    )


class ConstantFolder(torch.fx.Interpreter):
    def __init__(
        self,
        gm: torch.fx.GraphModule,
        skip_constructors: bool = False,
        lifted_constants: Optional[Dict[str, torch.Tensor]] = None,
        skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
    ) -> None:
        super().__init__(gm)
        self.node_replacements: Dict[torch.fx.Node, Any] = {}
        self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
        self.unknown_value = object()
        self.skip_constructors: bool = skip_constructors

        # overwrite this to deallocate env values if their only remaining use
        # is the output
        self.user_to_last_uses = self.node_to_last_non_output_use()
        self.lifted_constants = lifted_constants

    def _support_dynamic_shape(self) -> bool:
        # ConstantFolder does not support dynamic shapes yet
        return False

    def _deduce_value(self, node: torch.fx.Node) -> Any:
        return super().run_node(node)

    def is_impure(self, node: torch.fx.node.Node) -> bool:
        def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
            return (
                node.target == torch.ops.prims.convert_element_type.default  # type: ignore[return-value]
                and isinstance(node.args[0], torch.fx.Node)
                and "val" in node.args[0].meta
                and node.args[0].meta["val"].dtype == torch.int8  # type: ignore[union-attr]
                and node.args[1] == torch.bfloat16
            )

        if (
            is_woq_int8_pattern(node)
            or (
                node.target == torch.ops.aten.permute.default
                and len(node.users) == 1
                and is_woq_int8_pattern(next(iter(node.users)))
            )
        ) and is_const_source(
            node.args[0], self.lifted_constants  # type: ignore[arg-type]
        ):
            # Case 1: int8_weight -> dq -> bf16_weight
            # Case 2: int8_weight -> permute -> dq -> bf16_weight
            return True

        quant_registered = getattr(torch.ops.quantized_decomposed, "dequantize_per_channel", None) is not None
        if quant_registered and node.target in [
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
        ]:
            # For the pattern fp32_weight -> q -> dq, we only fold
            # fp32_weight -> q into an int8_weight and leave dq in the graph
            # to be fused
            return True
        return False

    def node_to_last_non_output_use(self) -> Dict[torch.fx.Node, List[torch.fx.Node]]:
        last_non_output_use = collections.defaultdict(list)
        seen_uses = set()
        output_node = next(iter(reversed(self.module.graph.nodes)))

        for node in reversed(self.module.graph.nodes):
            if node.target == "output":
                continue

            def add_use(inp: torch.fx.Node) -> None:
                if inp in seen_uses:
                    return

                seen_uses.add(inp)
                last_non_output_use[node].append(inp)

            # In-place is fine since we don't mutate
            pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs))

            # if this node is only used in output, we want to gc it right away
            if len(node.users) == 1 and output_node in node.users:
                last_non_output_use[node].append(node)

        return last_non_output_use

    def run_node(self, node: torch.fx.Node) -> Any:
        if node.target == "output":
            # because we remove nodes from env on last non output use,
            # redefine them now or we'll get an error in the interpreter
            def set_env(arg: torch.fx.Node) -> None:
                self.env[arg] = self.unknown_value

            # In-place is fine since we don't mutate
            pytree.tree_map_only_(torch.fx.Node, set_env, node.args)
            return super().run_node(node)

        args, kwargs = self.fetch_args_kwargs_from_env(node)
        flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)

        # We need to do this weird thing because in cases where flattened_inputs
        # contains a ScriptObject, equality checking results in a type error if
        # the types are different.
        if any(
            type(self.unknown_value) is type(input_) and self.unknown_value == input_ for input_ in flattened_inputs
        ):
            return self.unknown_value

        # TODO - fix errors with this
        if node.op == "call_function" and node.target == aten._efficientzerotensor.default:
            return self.unknown_value

        # TODO - constant folding a triton kernel returns its inputs -- fix this
        if node.op == "call_function" and node.name == "triton_kernel_wrapper_functional_proxy":
            return self.unknown_value

        # skip constructors, since inductor generates optimal code for them already
        # and turning them into tensors would result in an additional global memory read
        # TODO - more complicated strategy
        if (
            self.skip_constructors
            and not is_const_source(node, self.lifted_constants)
            and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
        ):
            return self.unknown_value

        # All mutations should either be removed or on inputs which we did not make constant
        if isinstance(node.target, torch._ops.OpOverload) and torch.Tag.nondeterministic_seeded in node.target.tags:
            return self.unknown_value

        out = self._deduce_value(node)
        if out == self.unknown_value:
            return self.unknown_value

        if not is_const_source(node, self.lifted_constants) and isinstance(out, torch.Tensor):
            if out.device.type == "meta":
                return out

            if not self.insertable_tensor_check(out):
                return out

            if self.is_impure(node):
                return self.unknown_value

            self.add_node_replacement(node, out)

            flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)

            for n in flattened_node_inps:
                if not isinstance(n, torch.fx.Node):
                    continue

                self.replaced_uses[n] += 1

            for to_delete in self.user_to_last_uses.get(node, []):
                if self.replaced_uses[to_delete] == len(to_delete.users):
                    self.node_replacements.pop(to_delete, None)

        return out

    def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
        return True

    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
        self.node_replacements[node] = tensor

    def run(self) -> Any:  # type: ignore[override]
        env: Dict[torch.fx.Node, Any] = {}
        self.insert_placeholder_values(env)
        return super().run(initial_env=env)

    def insert_placeholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
        for n in self.module.graph.find_nodes(op="placeholder"):
            if self.lifted_constants is not None and n.name in self.lifted_constants:
                env[n] = self.lifted_constants[n.name]
            else:
                env[n] = self.unknown_value  # type: ignore[assignment]


def constant_fold(
    gm: torch.fx.GraphModule,
    constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> None:
    with torch.utils._python_dispatch._disable_current_modes():
        cf = ConstantFolder(gm, skip_constructors=True)
        cf.run()

        for node, constant in cf.node_replacements.items():
            if constraint_fn is not None and not constraint_fn(node):
                continue
            replace_node_with_constant(gm, node, constant)

        erased_params = []
        for node in gm.graph.find_nodes(op="get_attr"):
            if len(node.users) == 0:
                if hasattr(gm, node.target):
                    delattr(gm, node.target)
                erased_params.append(node)

        for node in erased_params:
            gm.graph.erase_node(node)

        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()
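
A minimal usage sketch for reviewers (not part of the diff): it folds the weight-only subexpression of a toy traced module into a single frozen get_attr constant. ToyModel and keep_dequantize are illustrative names, and the graph.find_nodes calls in the folder assume a recent PyTorch.

import torch

from nncf.experimental.torch.fx.constant_folding import constant_fold


class ToyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # `self.weight.t() * 2.0` depends only on constants and is foldable;
        # the matmul with the runtime input `x` is not
        return x @ (self.weight.t() * 2.0)


gm = torch.fx.symbolic_trace(ToyModel())


def keep_dequantize(node: torch.fx.Node) -> bool:
    # illustrative constraint: refuse to fold the results of dequantize nodes
    return "dequantize" not in str(node.target)


constant_fold(gm, constraint_fn=keep_dequantize)
print(gm.graph)  # the folded subexpression is now a single _frozen_param get_attr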
@@ -19,6 +19,7 @@
from torch._export import capture_pre_autograd_graph
from torchvision import models

from nncf.experimental.torch.fx.constant_folding import constant_fold
from nncf.torch import disable_patching
from tests.post_training.pipelines.base import PT_BACKENDS
from tests.post_training.pipelines.base import BackendType
@@ -74,6 +75,7 @@ def prepare_model(self) -> None:
            with torch.no_grad():
                with disable_patching():
                    self.model = self.model_params.export_fn(model, (self.dummy_tensor,))
                    constant_fold(self.model)

        elif self.backend in PT_BACKENDS:
            self.model = model
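
Side note on the dequantize handling in is_impure above: the snippet below is an illustrative sketch of the fp32_weight -> q -> dq pattern it splits. The torch.ao.quantization.fx._decomposed import is an assumption about how the quantized_decomposed ops get registered; adjust for your PyTorch version.

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  # registers quantized_decomposed ops

w_fp32 = torch.randn(8, 8)
scale, zero_point = 0.02, 0

# foldable half of the pattern: `fp32_weight -> q` collapses to a frozen int8 weight
w_int8 = torch.ops.quantized_decomposed.quantize_per_tensor(
    w_fp32, scale, zero_point, -128, 127, torch.int8
)

# non-foldable half: is_impure() reports dequantize nodes as impure, so `dq`
# survives constant folding and stays in the graph for backend fusion
w_dq = torch.ops.quantized_decomposed.dequantize_per_tensor(
    w_int8, scale, zero_point, -128, 127, torch.int8
)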