diff --git a/nncf/common/graph/transformations/commands.py b/nncf/common/graph/transformations/commands.py
index b6910d1d49f..c40183dd94e 100644
--- a/nncf/common/graph/transformations/commands.py
+++ b/nncf/common/graph/transformations/commands.py
@@ -35,7 +35,6 @@ class TransformationPriority(IntEnum):
     FP32_TENSOR_STATISTICS_OBSERVATION = 1
     PRUNING_PRIORITY = 2
     SPARSIFICATION_PRIORITY = 3
-    OP_INSERTION_PRIORITY = 4
     QUANTIZATION_PRIORITY = 11


diff --git a/nncf/experimental/tensor/functions/__init__.py b/nncf/experimental/tensor/functions/__init__.py
index 02c57a393e3..ba6bb83bc88 100644
--- a/nncf/experimental/tensor/functions/__init__.py
+++ b/nncf/experimental/tensor/functions/__init__.py
@@ -33,6 +33,8 @@
 from nncf.experimental.tensor.functions.numeric import moveaxis as moveaxis
 from nncf.experimental.tensor.functions.numeric import multiply as multiply
 from nncf.experimental.tensor.functions.numeric import ones_like as ones_like
+from nncf.experimental.tensor.functions.numeric import power as power
+from nncf.experimental.tensor.functions.numeric import quantile as quantile
 from nncf.experimental.tensor.functions.numeric import reshape as reshape
 from nncf.experimental.tensor.functions.numeric import round as round
 from nncf.experimental.tensor.functions.numeric import squeeze as squeeze
diff --git a/nncf/experimental/tensor/functions/numeric.py b/nncf/experimental/tensor/functions/numeric.py
index 80425352134..cdd00f49b91 100644
--- a/nncf/experimental/tensor/functions/numeric.py
+++ b/nncf/experimental/tensor/functions/numeric.py
@@ -383,8 +383,16 @@ def round(a: Tensor, decimals=0) -> Tensor:

 @functools.singledispatch
 @tensor_guard
-def power(a: Tensor, pwr: float) -> Tensor:
-    return Tensor(power(a.data, pwr))
+def power(a: Tensor, exponent: float) -> Tensor:
+    """
+    Takes the power of each element of the input with the given exponent and
+    returns a tensor with the result.
+
+    :param a: Input data.
+    :param exponent: Exponent value.
+    :return: The result of raising each element of the input to the given exponent.
+    """
+    return Tensor(power(a.data, exponent))


 @functools.singledispatch
@@ -392,13 +400,22 @@ def power(a: Tensor, pwr: float) -> Tensor:
 def quantile(
     a: Tensor,
     q: Union[float, List[float]],
-    axis: Union[int, List[int]] = None,
+    axis: Union[int, Tuple[int]] = None,
     keepdims: Optional[bool] = None,
 ) -> Union[float, Tensor]:
-    retval = quantile(a.data, q, axis, keepdims)
+    """
+    Compute the quantile(s) of the data along the specified axis.

-    if isinstance(retval, float):
-        return retval
+    :param a: Given tensor.
+    :param q: Quantile or sequence of quantiles to compute, which must be between
+        0 and 1 inclusive.
+    :param axis: Axis or axes along which the quantiles are computed.
+    :param keepdims: If True, the axes which are reduced are left in the result
+        as dimensions with size one.
+    :return: A tensor with quantiles; the first axis of the result corresponds
+        to the quantiles, the second axis corresponds to the quantile values.
+ """ + retval = quantile(a.data, q, axis, keepdims) return Tensor(retval) diff --git a/nncf/experimental/tensor/functions/numpy_numeric.py b/nncf/experimental/tensor/functions/numpy_numeric.py index 82039085f47..1c8b2728293 100644 --- a/nncf/experimental/tensor/functions/numpy_numeric.py +++ b/nncf/experimental/tensor/functions/numpy_numeric.py @@ -186,26 +186,24 @@ def _( return np.clip(a, a_min=min_val, a_max=max_val) -@register_numpy_types(numeric.eps) -def _(a: Union[np.ndarray, np.generic], dtype: TensorDataType) -> float: - return np.finfo(DTYPE_MAP[dtype]).eps - - @register_numpy_types(numeric.power) -def _(a: Union[np.ndarray, np.generic], pwr: float) -> Union[np.ndarray, np.generic]: - return np.power(a, pwr) +def _(a: Union[np.ndarray, np.generic], exponent: float) -> Union[np.ndarray, np.generic]: + return np.power(a, exponent) @register_numpy_types(numeric.quantile) def _( a: Union[np.ndarray, np.generic], q: Union[float, List[float]], - axis: Union[int, List[int]] = None, + axis: Union[int, Tuple[int]] = None, keepdims: Optional[bool] = None, ) -> Union[float, Union[np.ndarray, np.generic]]: if keepdims is None: keepdims = np._NoValue - return np.quantile(a, q=q, axis=axis, keepdims=keepdims) + ret_val = np.quantile(a, q=q, axis=axis, keepdims=keepdims) + if isinstance(ret_val, np.ndarray): + return ret_val + return np.array(ret_val) @register_numpy_types(numeric.size) diff --git a/nncf/experimental/tensor/functions/torch_numeric.py b/nncf/experimental/tensor/functions/torch_numeric.py index 2732cb03e03..ab612b18031 100644 --- a/nncf/experimental/tensor/functions/torch_numeric.py +++ b/nncf/experimental/tensor/functions/torch_numeric.py @@ -197,32 +197,31 @@ def _(a: torch.Tensor, min_val: float, max_val: Optional[float] = None) -> torch return torch.clip(a, min=min_val, max=max_val) -@numeric.eps.register(torch.Tensor) -def _(a: torch.Tensor, dtype: TensorDataType) -> float: - return torch.finfo(DTYPE_MAP[dtype]).eps - - @numeric.power.register(torch.Tensor) -def _(a: torch.Tensor, pwr: float) -> torch.Tensor: - return torch.pow(a, exponent=pwr) +def _(a: torch.Tensor, exponent: float) -> torch.Tensor: + return torch.pow(a, exponent=exponent) @numeric.quantile.register(torch.Tensor) def _( a: torch.Tensor, q: Union[float, List[float]], - axis: Union[int, List[int]] = None, + axis: Union[int, Tuple[int]] = None, keepdims: Optional[bool] = None, ) -> Union[float, torch.Tensor]: + device = a.device # See https://github.com/pytorch/pytorch/issues/61582 # https://github.com/pytorch/pytorch/issues/64947 - device = a.device - if keepdims is None: - keepdims = np._NoValue - np_result = np.quantile(a.detach().cpu().numpy(), q=q, axis=axis, keepdims=keepdims) - if isinstance(np_result, np.ndarray): - return torch.tensor(np_result).type(a.dtype).to(device) - return np_result + if len(a) <= 16_000_000 and isinstance(axis, int): + result = torch.quantile( + a, + torch.tensor(q, dtype=a.dtype, device=a.device), + axis, + keepdims, + ) + else: + result = torch.tensor(np.quantile(a.detach().cpu().numpy(), q=q, axis=axis, keepdims=keepdims)) + return result.type(a.dtype).to(device) @numeric.size.register(torch.Tensor) diff --git a/nncf/quantization/algorithms/smooth_quant/algorithm.py b/nncf/quantization/algorithms/smooth_quant/algorithm.py index 73467910c24..77e16e2ba03 100644 --- a/nncf/quantization/algorithms/smooth_quant/algorithm.py +++ b/nncf/quantization/algorithms/smooth_quant/algorithm.py @@ -128,7 +128,14 @@ def apply( if any(val.data is None for val in activations_value): 
                     empty_statistic = True
                     break
-                assert len(activations_value) == 1
+                if len(activations_value) != 1:
+                    raise RuntimeError(
+                        (
+                            "More than one statistic is collected for one node during "
+                            f"Smooth Quant algorithm: {node_to_smooth.node_name}"
+                        )
+                    )
+
                 activations_value = self._clip_statistics(activations_value)

                 weight_value = self._backend_entity.get_weight_value(node_to_smooth, model)
@@ -194,7 +201,7 @@ def _calculate_scale_and_ratio(
         a_min = fns.quantile(scales, quantile, keepdims=False)
         a_max = 1e2

-        scales = fns.clip(scales, min_val=a_min, max_val=a_max)
+        scales = fns.clip(scales, a_min=a_min, a_max=a_max)
         ratio = scales.min() / (scales.max() + eps)
         return scales, ratio

@@ -253,6 +260,7 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin
         for node_data in nodes_to_smooth_data:
             node_to_smooth = node_data["node_to_smooth"]
             target_point = self._backend_entity.target_point(
+                target_type=self._backend_entity.pre_layer_target_type(),
                 target_node_name=node_to_smooth.node_name,
                 port_id=node_data["input_act_port"],
             )
@@ -305,7 +313,7 @@ def _get_nodes_to_smooth_data(self, nncf_graph: NNCFGraph, node_metatypes: List[
             nodes_to_smooth_data.append(
                 {
                     "node_to_smooth": node_with_weight,
-                    "input_act_port": self._backend_entity.get_activations_port_id(node_with_weight, nncf_graph),
+                    "input_act_port": activation_port_id,
                 }
             )
         return nodes_to_smooth_data
@@ -435,4 +443,4 @@ def _clip_statistics(statistics: List[Tensor]) -> Tensor:

         statistics = fns.stack(statistics)
         squeezed = fns.squeeze(statistics)
-        return fns.clip(squeezed, min_val=a_min, max_val=None)
+        return fns.clip(squeezed, a_min=a_min, a_max=None)
diff --git a/nncf/quantization/algorithms/smooth_quant/backend.py b/nncf/quantization/algorithms/smooth_quant/backend.py
index 05aae4b0ac3..57440e1f371 100644
--- a/nncf/quantization/algorithms/smooth_quant/backend.py
+++ b/nncf/quantization/algorithms/smooth_quant/backend.py
@@ -11,13 +11,15 @@

 from abc import ABC
 from abc import abstractmethod
-from typing import List, Tuple, TypeVar
+from typing import Callable, List, Tuple, TypeVar

 from nncf.common.graph import NNCFGraph
 from nncf.common.graph import NNCFNode
 from nncf.common.graph.operator_metatypes import OperatorMetatype
 from nncf.common.graph.transformations.commands import TargetPoint
+from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.graph.transformations.commands import TransformationCommand
+from nncf.common.tensor_statistics.statistic_point import StatisticPoint
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.experimental.tensor import Tensor

@@ -55,10 +57,20 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]:

     @staticmethod
     @abstractmethod
-    def target_point(target_node_name: str, port_id: int) -> TargetPoint:
+    def pre_layer_target_type() -> TargetType:
+        """
+        Returns the backend-specific pre-layer target type.
+
+        :return: Backend-specific pre-layer target type.
+        """
+
+    @staticmethod
+    @abstractmethod
+    def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> TargetPoint:
         """
         Returns backend-specific target point.

+        :param target_type: Type of the location that should be modified.
         :param target_node_name: Name of the located node.
         :param port_id: Port ID of the tensor for the statistics distribution.
         :return: Backend-specific TargetPoint.
@@ -184,10 +196,20 @@ def get_weight_channel_axis(node: NNCFNode) -> int: @staticmethod @abstractmethod - def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph): - pass + def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: + """ + Returns true if given node shares constant with a different node. + + :param node: NNCFNode instance. + :param nncf_graph: NNCFGraph instance. + :return: Whether the given node is shares weights with a different node or not. + """ @staticmethod @abstractmethod - def get_filter_fn_for_statistics(activation_port_id: int): - pass + def get_filter_fn_for_statistics(activation_port_id: int) -> Callable[[StatisticPoint], bool]: + """ + Returns backend-specific callable to filter statistic containers according to its statistic point. + + :param activation_port_id: Activation port id for the statistic collection target node. + """ diff --git a/nncf/quantization/algorithms/smooth_quant/openvino_backend.py b/nncf/quantization/algorithms/smooth_quant/openvino_backend.py index efee08ed253..9d1d7504073 100644 --- a/nncf/quantization/algorithms/smooth_quant/openvino_backend.py +++ b/nncf/quantization/algorithms/smooth_quant/openvino_backend.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple +from typing import Callable, List, Tuple import numpy as np import openvino.runtime as ov @@ -52,8 +52,12 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]: return QUANTIZE_AGNOSTIC_OPERATIONS @staticmethod - def target_point(target_node_name: str, port_id: int) -> OVTargetPoint: - return OVTargetPoint(TargetType.PRE_LAYER_OPERATION, target_node_name, port_id) + def pre_layer_target_type() -> TargetType: + return TargetType.PRE_LAYER_OPERATION + + @staticmethod + def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint: + return OVTargetPoint(target_type, target_node_name, port_id) @staticmethod def is_node_with_weights(node: NNCFNode) -> bool: @@ -92,15 +96,11 @@ def get_abs_max_channel_collector( @staticmethod def get_weight_value(node_with_weight: NNCFNode, model: ov.Model) -> Tensor: - port_id = OVSmoothQuantAlgoBackend._get_weight_tensor_port_id(node_with_weight) + port_id = OVSmoothQuantAlgoBackend.get_weight_tensor_port_id(node_with_weight) return Tensor(get_weight_value(node_with_weight, model, port_id)) @staticmethod def get_weight_tensor_port_id(node: NNCFNode) -> int: - return OVSmoothQuantAlgoBackend._get_weight_tensor_port_id(node) - - @staticmethod - def _get_weight_tensor_port_id(node: NNCFNode) -> int: const_ids = node.layer_attributes.get_const_port_ids() if len(const_ids) != 1: raise RuntimeError(f"Found more than 1 port for {node.node_name} node") @@ -108,7 +108,7 @@ def _get_weight_tensor_port_id(node: NNCFNode) -> int: @staticmethod def weight_update_command(node_with_weight: NNCFNode, weight_value: np.ndarray) -> OVWeightUpdateCommand: - weight_port_id = OVSmoothQuantAlgoBackend._get_weight_tensor_port_id(node_with_weight) + weight_port_id = OVSmoothQuantAlgoBackend.get_weight_tensor_port_id(node_with_weight) return OVCommandCreator.create_command_to_update_weight(node_with_weight, weight_value, weight_port_id) @staticmethod @@ -154,13 +154,13 @@ def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int: return -2 + port_id if transpose else -1 - port_id @staticmethod - def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph): + def 
is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: weight_port_id = OVSmoothQuantAlgoBackend._get_weight_port_id(node) weight_node = nncf_graph.get_input_edges(node)[weight_port_id].from_node return len(nncf_graph.get_next_nodes(weight_node)) > 1 @staticmethod - def get_filter_fn_for_statistics(activation_port_id: int): + def get_filter_fn_for_statistics(activation_port_id: int) -> Callable[[StatisticPoint], bool]: def filter_func(point: StatisticPoint) -> bool: return point.target_point.port_id == activation_port_id diff --git a/nncf/quantization/algorithms/smooth_quant/torch_backend.py b/nncf/quantization/algorithms/smooth_quant/torch_backend.py index 7927461378a..a8315890771 100644 --- a/nncf/quantization/algorithms/smooth_quant/torch_backend.py +++ b/nncf/quantization/algorithms/smooth_quant/torch_backend.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple +from typing import Callable, List, Tuple import numpy as np @@ -63,8 +63,12 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]: return DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT[QuantizationTrait.QUANTIZATION_AGNOSTIC] @staticmethod - def target_point(target_node_name: str, port_id: int) -> PTTargetPoint: - return PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node_name, input_port_id=port_id) + def pre_layer_target_type() -> TargetType: + return TargetType.OPERATOR_PRE_HOOK + + @staticmethod + def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: + return PTTargetPoint(target_type, target_node_name, input_port_id=port_id) @staticmethod def is_node_with_weights(node: NNCFNode) -> bool: @@ -92,7 +96,7 @@ def get_abs_max_channel_collector( def get_weight_value(node_with_weight: NNCFNode, model: NNCFNetwork) -> Tensor: node_module = model.nncf.get_containing_module(node_with_weight.node_name) if node_module.weight is None: - return None + raise RuntimeError(f"{node_module} module has no .weight attribute.") return Tensor(node_module.weight.data) @staticmethod @@ -130,11 +134,11 @@ def get_weight_channel_axis(node: NNCFNode) -> int: return 1 @staticmethod - def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph): + def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: return node.is_shared() @staticmethod - def get_filter_fn_for_statistics(activation_port_id: int): + def get_filter_fn_for_statistics(activation_port_id: int) -> Callable[[StatisticPoint], bool]: def filter_func(point: StatisticPoint) -> bool: return True diff --git a/nncf/torch/dynamic_graph/context.py b/nncf/torch/dynamic_graph/context.py index 2b20e5ec3f4..53ed1c74667 100644 --- a/nncf/torch/dynamic_graph/context.py +++ b/nncf/torch/dynamic_graph/context.py @@ -99,7 +99,7 @@ def __init__(self): self._pre_hooks: Dict[PreHookId, List[Callable]] = defaultdict(OrderedDict) self._num_nested_hooks = 0 self.reused_parameters = [] - self._hooks_counter = 0 + self._hooks_counter = -1 self._threading = CopySafeThreadingVars() diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index d58e289379b..4d6ce1c5a0d 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from enum import Enum from enum import IntEnum -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar +from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypeVar import torch 
from torch import nn @@ -368,25 +368,6 @@ def reset_nncf_modules(self): module = self.get_module_by_scope(some_scope) module.reset() - def get_shallow_copy(self) -> "NNCFNetwork": - from nncf.torch.utils import load_module_state - from nncf.torch.utils import save_module_state - - saved_state = save_module_state(self._model_ref) - new_interface = NNCFNetworkInterface( - self._model_ref, - self._input_infos, - self._user_dummy_forward_fn, - self._wrap_inputs_fn, - self._scopes_without_shape_matching, - self._ignored_scopes, - self._target_scopes, - wrap_outputs_fn=self._wrap_outputs_fn, - ) - self._model_ref._nncf = new_interface - load_module_state(self._model_ref, saved_state) - return self._model_ref - def get_clean_shallow_copy(self) -> "NNCFNetwork": # WARNING: Will reset pre- and post-ops of the underlying model. Use save_nncf_module_additions # and load_nncf_module_additions to preserve these, or temporary_clean_view(). @@ -418,9 +399,6 @@ def get_modules_in_nncf_modules_by_type(self, class_names: List[str]) -> Dict[Sc retval[nncf_module_scope + relative_scope] = target_module return retval - def update_model_ref(self, model: torch.nn.Module) -> None: - object.__setattr__(self, "__model_ref", model) - def temporary_insert_at_point(self, point: PTInsertionPoint, fn_list: List[Callable]): hook_addresses = self.insert_at_point(point, fn_list) self._temprorary_hooks_adresses.append(hook_addresses) @@ -849,14 +827,6 @@ def get_reused_parameters(self): return ret -class TemporaryOp: - def __init__(self, op: Callable) -> None: - self._op = op - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - return self._op(*args, **kwargs) - - class NNCFNetworkMeta(type): """ Metaclass for the NNCFNetwork mixin. Has magic methods defined so that the original model object could be diff --git a/nncf/torch/quantization/external_quantizer.py b/nncf/torch/quantization/external_quantizer.py index ad065392470..68097eada08 100644 --- a/nncf/torch/quantization/external_quantizer.py +++ b/nncf/torch/quantization/external_quantizer.py @@ -14,7 +14,6 @@ from nncf.torch.quantization.debug_interface import QuantizationDebugInterface EXTERNAL_QUANTIZERS_STORAGE_NAME = "external_quantizers" -EXTERNAL_OP_STORAGE_NAME = "external_op" EXTERNAL_QUANTIZERS_STORAGE_PREFIX = "_nncf." + EXTERNAL_QUANTIZERS_STORAGE_NAME diff --git a/nncf/torch/statistics/aggregator.py b/nncf/torch/statistics/aggregator.py index fa49cf3cca3..65beac840be 100644 --- a/nncf/torch/statistics/aggregator.py +++ b/nncf/torch/statistics/aggregator.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from copy import deepcopy from typing import Dict import numpy as np @@ -27,41 +26,9 @@ from nncf.torch.tensor_statistics.algo import create_register_input_hook -class ModelView: - def __init__(self, model: NNCFNetwork): - self.model = model - self.nncf_module_additions = self.model.nncf.save_nncf_module_additions() - - def __enter__(self): - # Model ref removed to prevent copying - self.model.nncf.update_model_ref(None) - - # nncf_replaced_models removed to prevent copying - replaced_modules = self.model.nncf._nncf_replaced_modules - self.model.nncf._nncf_replaced_modules = None - - self.nncf_interface = deepcopy(self.model.nncf) - - # Model ref is recovering - self.model.nncf.update_model_ref(self.model) - self.nncf_interface.update_model_ref(self.model) - - # nncf_replaced_models is recovering - self.model.nncf._nncf_replaced_modules = replaced_modules - self.nncf_interface._nncf_replaced_modules = replaced_modules - return self.model - - def __exit__(self, exc_type, exc_val, exc_tb): - self.model._nncf = self.nncf_interface - self.model.nncf.reset_nncf_modules() - self.model.nncf.load_nncf_module_additions(self.nncf_module_additions) - - class PTStatisticsAggregator(StatisticsAggregator): def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None: with torch.no_grad(): - # with ModelView(model) as intermediate_model: - # super().collect_statistics(intermediate_model, graph) super().collect_statistics(model, graph) model.nncf.remove_temporary_ops() diff --git a/nncf/torch/tensor.py b/nncf/torch/tensor.py index 908e482f889..edf234734dc 100644 --- a/nncf/torch/tensor.py +++ b/nncf/torch/tensor.py @@ -9,12 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union - import torch from nncf.common.tensor import NNCFTensor -from nncf.torch.return_types import maybe_unwrap_from_torch_return_type class PTNNCFTensor(NNCFTensor): @@ -22,13 +19,11 @@ class PTNNCFTensor(NNCFTensor): A realisation of torch tensors wrapper for common NNCF algorithms. """ - def __init__(self, tensor: Union[torch.tensor, "PTNNCFTensor", tuple]): + def __init__(self, tensor: torch.tensor): # In case somebody attempts to wrap # tensor twice if isinstance(tensor, self.__class__): tensor = tensor.tensor - else: - tensor = maybe_unwrap_from_torch_return_type(tensor) super().__init__(tensor) diff --git a/tests/torch/helpers.py b/tests/torch/helpers.py index 49fead9fb8b..84a16b91585 100644 --- a/tests/torch/helpers.py +++ b/tests/torch/helpers.py @@ -15,7 +15,7 @@ from collections import defaultdict from copy import deepcopy from pathlib import Path -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, TypeVar, Union import numpy as np import onnx @@ -35,6 +35,7 @@ from nncf.torch.compression_method_api import PTCompressionAlgorithmController from nncf.torch.dynamic_graph.context import PreHookId from nncf.torch.dynamic_graph.io_handling import FillerInputInfo +from nncf.torch.dynamic_graph.operation_address import OperationAddress from nncf.torch.dynamic_graph.scope import Scope from nncf.torch.initialization import PTInitializingDataLoader from nncf.torch.initialization import register_default_init_args @@ -503,6 +504,9 @@ def load_exported_onnx_version( return model_proto +HookType = TypeVar("HookType") + + class HookChecker: """ Class to check pre/post hooks and pre ops are placed correctly. 
@@ -535,7 +539,11 @@ def _convert_to_op_address(self, target_type: TargetType, target_node_name: str, address = address_map[target_node_name] if target_type == TargetType.OPERATOR_PRE_HOOK: address = PreHookId(address, input_port_id) - elif target_type == TargetType.OPERATION_WITH_WEIGHTS: + elif target_type in [ + TargetType.OPERATION_WITH_WEIGHTS, + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ]: address = getattr(self._target_model, self._nncf_module_attr_name) return address @@ -544,6 +552,15 @@ def check_with_reference(self): Check hooks in the target model and reference hooks are matching. """ self._check_weight_update_hooks(self._ref_hooks[TargetType.OPERATION_WITH_WEIGHTS]) + + target_module = getattr(self._target_model, self._nncf_module_attr_name) + if target_module in self._ref_hooks[TargetType.PRE_LAYER_OPERATION]: + hooks = target_module.pre_ops + self._check_pre_post_op_hooks(hooks, self._ref_hooks[TargetType.PRE_LAYER_OPERATION][target_module]) + if target_module in self._ref_hooks[TargetType.POST_LAYER_OPERATION]: + hooks = target_module.post_ops + self._check_pre_post_op_hooks(hooks, self._ref_hooks[TargetType.POST_LAYER_OPERATION][target_module]) + hooks = self._target_model.nncf._compressed_context._pre_hooks self._check_pre_post_hooks(hooks, self._ref_hooks[TargetType.OPERATOR_PRE_HOOK]) hooks = self._target_model.nncf._compressed_context._post_hooks @@ -556,7 +573,7 @@ def clear(self): self._ref_hooks.clear() @staticmethod - def _check_weight_update_hooks(ref_hooks): + def _check_weight_update_hooks(ref_hooks: Dict[torch.nn.Module, List[HookType]]): for target_module, ref_hooks_per_module in ref_hooks.items(): assert len(target_module.pre_ops) == len(ref_hooks_per_module) for actual_op, ref_op in zip(target_module.pre_ops.values(), ref_hooks_per_module): @@ -564,7 +581,15 @@ def _check_weight_update_hooks(ref_hooks): assert actual_op.op is ref_op @staticmethod - def _check_pre_post_hooks(hooks, ref_hooks): + def _check_pre_post_op_hooks(hooks: List[torch.ModuleDict], ref_hooks: List[HookType]): + assert len(hooks) == len(ref_hooks) + for actual_hook, ref_hook in zip(hooks.values(), ref_hooks): + assert actual_hook is ref_hook + + @staticmethod + def _check_pre_post_hooks( + hooks: Dict[OperationAddress, Dict[Any, HookType]], ref_hooks: Dict[OperationAddress, List[HookType]] + ): assert len(hooks) == len(ref_hooks) for op_address, ref_hooks in ref_hooks.items(): actual_hooks = hooks[op_address].values() diff --git a/tests/torch/ptq/test_smooth_quant.py b/tests/torch/ptq/test_smooth_quant.py index 2c9da2c031e..07e7a041242 100644 --- a/tests/torch/ptq/test_smooth_quant.py +++ b/tests/torch/ptq/test_smooth_quant.py @@ -27,12 +27,9 @@ from tests.post_training.test_templates.test_smooth_quant import TemplateTestSQAlgorithm PT_LINEAR_MODEL_SQ_MAP = { - ("Linear1",): "LinearMultiShapeModel/split_0_1_0/nncf_smooth_quant" - "[LinearMultiShapeModel/NNCFLinear[linear_2]/linear_0]", - ("Linear2",): "LinearMultiShapeModel/split_0_0_0/nncf_smooth_quant" - "[LinearMultiShapeModel/NNCFLinear[linear_1]/linear_0]", - ("Linear3", "Linear4"): "LinearMultiShapeModel/add_0_0_0/nncf_smooth_quant" - "[LinearMultiShapeModel/NNCFLinear[linear_3]/linear_0;LinearMultiShapeModel/NNCFLinear[linear_4]/linear_0]", + ("Linear1",): "LinearMultiShapeModel/split_0_1_0/nncf_smooth_quant", + ("Linear2",): "LinearMultiShapeModel/split_0_0_0/nncf_smooth_quant", + ("Linear3", "Linear4"): "LinearMultiShapeModel/add_0_0_0/nncf_smooth_quant", } PT_LINEAR_MODEL_MM_MAP = { diff --git 
a/tests/torch/test_model_transformer.py b/tests/torch/test_model_transformer.py index 96185d64f76..3cf13ebc841 100644 --- a/tests/torch/test_model_transformer.py +++ b/tests/torch/test_model_transformer.py @@ -50,6 +50,7 @@ from nncf.torch.graph.operator_metatypes import PTReshapeMetatype from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTInsertionTemporaryCommand from nncf.torch.graph.transformations.commands import PTModelExtractionWithFusedBiasCommand from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand @@ -149,8 +150,11 @@ def setup(self): point_for_relu_inputs, ] + @pytest.mark.parametrize( + "insert_method_name,check_tmp_ops", [("insert_at_point", False), ("temporary_insert_at_point", True)] + ) @pytest.mark.parametrize("target_point", available_points) - def test_single_insertions(self, setup, target_point: PTTargetPoint): + def test_single_insertions(self, setup, target_point: PTTargetPoint, insert_method_name: str, check_tmp_ops: bool): insertion_point = PTInsertionPoint( target_point.target_type, OperationAddress.from_str(target_point.target_node_name), @@ -161,7 +165,8 @@ def test_single_insertions(self, setup, target_point: PTTargetPoint): else: hook = BaseOp(lambda x: x) - self.compressed_model.nncf.insert_at_point(insertion_point, [hook]) + insert_at_point_method = getattr(self.compressed_model.nncf, insert_method_name) + insert_at_point_method(insertion_point, [hook]) if insertion_point.insertion_type == PTInsertionType.OPERATOR_PRE_HOOK: ctx = self.compressed_model.nncf.get_tracing_context() @@ -178,6 +183,9 @@ def test_single_insertions(self, setup, target_point: PTTargetPoint): module = self.compressed_model.nncf.get_module_by_scope(insertion_point.module_scope) assert module.post_ops["0"] is hook + if check_tmp_ops: + assert len(self.compressed_model.nncf._temprorary_hooks_adresses) == 1 + priority_types = ["same", "different"] insertion_types = TargetType priority_test_cases = list(itertools.product(priority_types, insertion_types)) @@ -187,8 +195,9 @@ def check_order(iterable1: List, iterable2: List, ordering: List): for idx, order in enumerate(ordering): assert iterable1[idx] is iterable2[order] + @pytest.mark.parametrize("command_cls", [PTInsertionCommand, PTInsertionTemporaryCommand]) @pytest.mark.parametrize("case", priority_test_cases, ids=[x[1].name + "-" + x[0] for x in priority_test_cases]) - def test_priority(self, case, setup): + def test_priority(self, case, command_cls, setup): priority_type = case[0] insertion_type = case[1] @@ -218,14 +227,14 @@ def test_priority(self, case, setup): if priority_type == "same": # Same-priority commands will be executed in registration order - command1 = PTInsertionCommand(point, hook1, TransformationPriority.DEFAULT_PRIORITY) - command2 = PTInsertionCommand(point, hook2, TransformationPriority.DEFAULT_PRIORITY) - command3 = PTInsertionCommand(point, hook3, TransformationPriority.DEFAULT_PRIORITY) + command1 = command_cls(point, hook1, TransformationPriority.DEFAULT_PRIORITY) + command2 = command_cls(point, hook2, TransformationPriority.DEFAULT_PRIORITY) + command3 = command_cls(point, hook3, TransformationPriority.DEFAULT_PRIORITY) else: # Prioritized commands will be executed in ascending priority order - command1 = PTInsertionCommand(point, hook1, 
TransformationPriority.SPARSIFICATION_PRIORITY) - command2 = PTInsertionCommand(point, hook2, TransformationPriority.QUANTIZATION_PRIORITY) - command3 = PTInsertionCommand(point, hook3, TransformationPriority.DEFAULT_PRIORITY) + command1 = command_cls(point, hook1, TransformationPriority.SPARSIFICATION_PRIORITY) + command2 = command_cls(point, hook2, TransformationPriority.QUANTIZATION_PRIORITY) + command3 = command_cls(point, hook3, TransformationPriority.DEFAULT_PRIORITY) layout = PTTransformationLayout() layout.register(command1) @@ -245,10 +254,12 @@ def test_priority(self, case, setup): pre_hook_id = PreHookId( OperationAddress.from_str(point.target_node_name), input_port_id=point.input_port_id ) - self.check_order(ctx._pre_hooks[pre_hook_id], hook_list, order) + actual_pre_hooks = list(ctx._pre_hooks[pre_hook_id].values()) + self.check_order(actual_pre_hooks, hook_list, order) if insertion_type == TargetType.OPERATOR_POST_HOOK: ctx = self.compressed_model.nncf.get_tracing_context() - self.check_order(ctx._post_hooks[OperationAddress.from_str(point.target_node_name)], hook_list, order) + actual_post_hooks = list(ctx._post_hooks[OperationAddress.from_str(point.target_node_name)].values()) + self.check_order(actual_post_hooks, hook_list, order) if insertion_type == TargetType.OPERATION_WITH_WEIGHTS: module = self.compressed_model.nncf.get_containing_module(point.target_node_name) diff --git a/tests/torch/test_nncf_network.py b/tests/torch/test_nncf_network.py index 10e78307b59..4776b3fdf99 100644 --- a/tests/torch/test_nncf_network.py +++ b/tests/torch/test_nncf_network.py @@ -48,6 +48,7 @@ from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_NAME from tests.torch.composite.test_sparsity_quantization import get_basic_sparsity_plus_quantization_config from tests.torch.helpers import BasicConvTestModel +from tests.torch.helpers import HookChecker from tests.torch.helpers import TwoConvTestModel from tests.torch.helpers import check_correct_nncf_modules_replacement from tests.torch.helpers import create_compressed_model_and_algo_for_test @@ -925,3 +926,55 @@ def test_insert_hook_after_parameter(): assert hook.forward_calls_counter == 1 assert torch.sum(result.nonzero()) > 0 assert torch.sum(result_with_hook.nonzero()) == 0 + + +@pytest.mark.parametrize( + "target_type, target_node_name, input_port_id", + [ + (TargetType.OPERATOR_PRE_HOOK, "/nncf_model_output_0", 0), + (TargetType.OPERATOR_POST_HOOK, "/nncf_model_input_0", 0), + (TargetType.PRE_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0), + (TargetType.POST_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0), + ], +) +def test_temporary_insert_at_point(target_type, target_node_name, input_port_id): + class Hook(torch.nn.Module): + def forward(self, x): + return x + + model = SimplestModel() + example_input = torch.ones(SimplestModel.INPUT_SIZE) + input_info = ExampleInputInfo.from_example_input(example_input) + nncf_model = NNCFNetwork(model, input_info) + + node_name_vs_address = nncf_model.nncf.get_node_to_op_address_mapping() + ip = PTInsertionPoint(target_type, node_name_vs_address[target_node_name], input_port_id=input_port_id) + + checker = HookChecker(nncf_model, "conv") + + def _check(ref_hooks_): + checker.clear() + checker.add_ref(ref_hooks_, target_type, target_node_name, input_port_id) + checker.check_with_reference() + + permanent_hook = Hook() + # Make temporary hook a ref to the permanent hook + # to check tmp hooks are not removed by their id() + temporary_hook = 
permanent_hook + nncf_model.nncf.insert_at_point(ip, [permanent_hook]) + ref_hooks = [permanent_hook] + _check(ref_hooks) + + for _ in range(2): + temporary_hook = Hook() + nncf_model.nncf.temporary_insert_at_point(ip, [temporary_hook]) + ref_hooks.append(temporary_hook) + _check(ref_hooks) + + nncf_model.nncf.insert_at_point(ip, [permanent_hook]) + ref_hooks.append(permanent_hook) + _check(ref_hooks) + + nncf_model.nncf.remove_temporary_ops() + del ref_hooks[-2] + _check(ref_hooks) diff --git a/tests/torch/test_statistics_aggregator.py b/tests/torch/test_statistics_aggregator.py index addb82d1ff7..582dfb0fdae 100644 --- a/tests/torch/test_statistics_aggregator.py +++ b/tests/torch/test_statistics_aggregator.py @@ -21,13 +21,14 @@ from nncf.common.factory import NNCFGraphFactory from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout -from nncf.common.quantization.structs import QuantizationMode +from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode from nncf.common.quantization.structs import QuantizerConfig from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase from nncf.quantization.algorithms.fast_bias_correction.torch_backend import PTFastBiasCorrectionAlgoBackend from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend from nncf.quantization.range_estimator import RangeEstimatorParametersSet +from nncf.torch.dynamic_graph.patch_pytorch import register_operator from nncf.torch.graph.graph import PTTargetPoint from nncf.torch.model_transformer import PTInsertionCommand from nncf.torch.statistics.aggregator import PTStatisticsAggregator @@ -133,7 +134,7 @@ def test_same_collectors_different_attrs_dont_merge(self, statistics_type, test_ pass @pytest.mark.parametrize( - "test_parameters, ", + "test_parameters", ( MinMaxTestParameters( RangeEstimatorParametersSet.MINMAX, @@ -185,21 +186,17 @@ def test_successive_statistics_aggregation( def fn(x): return x * 2 - layout = TransformationLayout() target_point = self.get_target_point(test_parameters.target_type) - command = PTInsertionCommand(target_point, fn) - layout.register(command) - model_transformer = factory.ModelTransformerFactory.create(model) - model = model_transformer.transform(layout) - model.nncf.rebuild_graph() + model = self.__add_fn_to_model(model, target_point, fn) ### Check hook inserted correctly - self.__check_hooks(test_parameters, model, target_point, fn) + self.__check_successive_hooks(test_parameters, model, target_point, fn) ### Register and collect statistics after inserted operations - tensor_collector = self.__collect_statistics_get_collector( + statistic_points = self.__get_statistic_points( test_parameters, model, quantizer_config, dataset_samples, inplace_statistics ) + tensor_collector = self.__collect_statistics_get_collector(statistic_points, model, dataset_samples) ### Check values are changed because of the inserted operation self.__check_collector( test_parameters, @@ -208,42 +205,145 @@ def fn(x): ) ### Check the inserted operation is inside the model - self.__check_hooks(test_parameters, model, target_point, fn) + self.__check_successive_hooks(test_parameters, model, target_point, fn) - def __collect_statistics_get_collector( - self, test_parameters: MinMaxTestParameters, model, quantizer_config, dataset_samples, inplace_statistics + 
@pytest.mark.parametrize( + "test_parameters, nested_target_node_name", + ( + ( + MinMaxTestParameters( + RangeEstimatorParametersSet.MINMAX, + TargetType.OPERATOR_PRE_HOOK, + QuantizationMode.SYMMETRIC, + False, + 512, + -512, + ), + "PTIdentityConvModel/fn_0", + ), + ( + MinMaxTestParameters( + RangeEstimatorParametersSet.MINMAX, + TargetType.OPERATION_WITH_WEIGHTS, + QuantizationMode.SYMMETRIC, + False, + 512, + -512, + ), + "PTIdentityConvModel/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/fn_0", + ), + ( + MinMaxTestParameters( + RangeEstimatorParametersSet.MINMAX, + TargetType.OPERATOR_POST_HOOK, + QuantizationMode.SYMMETRIC, + False, + 512, + -512, + ), + "PTIdentityConvModel/fn_0", + ), + ), + ) + @pytest.mark.parametrize("nested_target_type", [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATOR_POST_HOOK]) + def test_nested_statistics_aggregation( + self, + test_parameters: MinMaxTestParameters, + nested_target_type: TargetType, + nested_target_node_name, + dataset_samples, + is_stat_in_shape_of_scale, + inplace_statistics, + is_backend_support_custom_estimators, ): + model = self.get_backend_model(dataset_samples) + quantizer_config = QuantizerConfig( + mode=test_parameters.quantization_mode, per_channel=test_parameters.per_channel + ) + + is_standard_estimator = test_parameters.range_estimator_params in [ + RangeEstimatorParametersSet.MINMAX, + RangeEstimatorParametersSet.MEAN_MINMAX, + ] + if not is_standard_estimator and not is_backend_support_custom_estimators: + pytest.skip("Custom estimators are not supported for this backend yet") + + ### Register operations before statistic collection + @register_operator() + def fn(x): + return x * 2 + + target_point = self.get_target_point(test_parameters.target_type) + model = self.__add_fn_to_model(model, target_point, fn) + nested_target_point = PTMinMaxAlgoBackend.target_point(nested_target_type, nested_target_node_name, 0) + model = self.__add_fn_to_model(model, nested_target_point, fn) + + ### Check hook inserted correctly + self.__check_nested_hooks(test_parameters, model, target_point, nested_target_type, nested_target_node_name, fn) + + ### Register and collect statistics after inserted operations + statistic_points = self.__get_statistic_points( + test_parameters, + model, + quantizer_config, + dataset_samples, + inplace_statistics, + ) + tensor_collector = self.__collect_statistics_get_collector(statistic_points, model, dataset_samples) + ### Check values are changed because of the inserted operation + self.__check_collector( + test_parameters, + tensor_collector, + is_stat_in_shape_of_scale, + ) + + ### Check the inserted operation is inside the model + self.__check_nested_hooks(test_parameters, model, target_point, nested_target_type, nested_target_node_name, fn) + + @staticmethod + def __add_fn_to_model(model, target_point, fn): + layout = TransformationLayout() + command = PTInsertionCommand(target_point, fn) + layout.register(command) + model_transformer = factory.ModelTransformerFactory.create(model) + model = model_transformer.transform(layout) + model.nncf.rebuild_graph() + return model + + @classmethod + def __get_statistic_points( + cls, test_parameters: MinMaxTestParameters, model, quantizer_config, dataset_samples, inplace_statistics + ) -> StatisticPointsContainer: statistics_points = StatisticPointsContainer() - for target_point in [test_parameters.target_type]: - target_point = self.get_target_point(target_point) - algorithm_name = "TestAlgo" - statistic_point = self.create_statistics_point( + for 
target_type in [test_parameters.target_type]: + target_point = cls.get_target_point(target_type) + statistic_point = cls.create_statistics_point( model, quantizer_config, target_point, len(dataset_samples), - algorithm_name, + "TEST_ALGO", inplace_statistics, test_parameters.range_estimator_params, ) statistics_points.add_statistic_point(statistic_point) + return statistics_points + def __collect_statistics_get_collector( + self, + statistics_points: StatisticPointsContainer, + model, + dataset_samples, + ): dataset = self.get_dataset(dataset_samples) statistics_aggregator = self.get_statistics_aggregator(dataset) statistics_aggregator.register_statistic_points(statistics_points) graph = NNCFGraphFactory.create(model) statistics_aggregator.collect_statistics(model, graph) - def filter_func(point): - return ( - algorithm_name in point.algorithm_to_tensor_collectors and point.target_point.type == target_point.type - ) - - tensor_collectors = list( - statistics_points.get_algo_statistics_for_node(target_point.target_node_name, filter_func, algorithm_name) - ) + tensor_collectors = list(statistics_points.get_tensor_collectors()) assert len(tensor_collectors) == 1 - return tensor_collectors[0] + return tensor_collectors[0][2] @staticmethod def __check_collector(test_parameters, tensor_collector, stat_in_shape_of_scale): @@ -268,7 +368,20 @@ def __check_collector(test_parameters, tensor_collector, stat_in_shape_of_scale) assert stat.max_values.shape == ref_shape @staticmethod - def __check_hooks(test_parameters, model, target_point, fn): + def __check_successive_hooks(test_parameters, model, target_point, fn): + checker = HookChecker(model, "conv") + checker.add_ref( + ref_hooks=[fn], + target_type=test_parameters.target_type, + target_node_name=target_point.target_node_name, + input_port_id=0, + ) + checker.check_with_reference() + + @staticmethod + def __check_nested_hooks( + test_parameters, model, target_point, nested_target_type: TargetType, nested_target_node_name: str, fn + ): checker = HookChecker(model, "conv") checker.add_ref( ref_hooks=[fn], @@ -276,4 +389,10 @@ def __check_hooks(test_parameters, model, target_point, fn): target_node_name=target_point.target_node_name, input_port_id=0, ) + checker.add_ref( + ref_hooks=[fn], + target_type=nested_target_type, + target_node_name=nested_target_node_name, + input_port_id=0, + ) checker.check_with_reference()
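
A minimal usage sketch of the experimental tensor functions touched by this patch. It assumes the functions module is importable as `fns` (the alias used in smooth_quant/algorithm.py above); the array values are illustrative only. The behavioural change to note is that `quantile` now always hands back a `Tensor`, even when the reduction produces a scalar, which keeps the numpy and torch dispatch targets consistent for callers such as `_calculate_scale_and_ratio`.

    import numpy as np

    from nncf.experimental.tensor import Tensor
    from nncf.experimental.tensor import functions as fns  # assumed alias, as used in algorithm.py

    a = Tensor(np.array([[1.0, 2.0], [3.0, 4.0]]))

    # Element-wise power: dispatches to np.power (or torch.pow) via singledispatch.
    squared = fns.power(a, 2.0)

    # Per-column quantiles: the first axis of the result indexes the requested quantiles.
    quartiles = fns.quantile(a, q=[0.25, 0.75], axis=0)

    # Full reduction: np.quantile returns a numpy scalar here, which the numpy backend
    # now wraps into a 0-d array so the caller always receives a Tensor.
    overall_median = fns.quantile(a, q=0.5)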