From 7701b88b6d565293857e58a06fa9da33262d94ec Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Mon, 2 Dec 2024 14:10:00 +0100 Subject: [PATCH] Reuse MinMax algo instead of copy-paste --- .../algorithms/post_training/algorithm.py | 2 +- .../algorithms/post_training/pipeline.py | 2 +- .../{quantizer.py => base_quantizer.py} | 0 .../algorithms/quantizer/fx_quantizer.py | 2 +- .../algorithms/range_estimator/backend.py | 154 ------ .../range_estimator/range_estimator.py | 452 +----------------- .../range_estimator/torch_fx_backend.py | 221 --------- .../algorithms/min_max/algorithm.py | 40 +- tests/common/quantization/test_minmax.py | 21 +- 9 files changed, 72 insertions(+), 822 deletions(-) rename nncf/experimental/common/quantization/algorithms/quantizer/{quantizer.py => base_quantizer.py} (100%) delete mode 100644 nncf/experimental/common/quantization/algorithms/range_estimator/backend.py delete mode 100644 nncf/experimental/common/quantization/algorithms/range_estimator/torch_fx_backend.py diff --git a/nncf/experimental/common/quantization/algorithms/post_training/algorithm.py b/nncf/experimental/common/quantization/algorithms/post_training/algorithm.py index 17558fbb1db..efbc5bb7449 100644 --- a/nncf/experimental/common/quantization/algorithms/post_training/algorithm.py +++ b/nncf/experimental/common/quantization/algorithms/post_training/algorithm.py @@ -17,7 +17,7 @@ from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType from nncf.experimental.common.quantization.algorithms.post_training.pipeline import experimental_create_ptq_pipeline -from nncf.experimental.common.quantization.algorithms.quantizer.quantizer import NNCFQuantizer +from nncf.experimental.common.quantization.algorithms.quantizer.base_quantizer import NNCFQuantizer from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters 
from nncf.quantization.advanced_parameters import RangeEstimatorParameters diff --git a/nncf/experimental/common/quantization/algorithms/post_training/pipeline.py b/nncf/experimental/common/quantization/algorithms/post_training/pipeline.py index de496eba2b3..dab7b2be856 100644 --- a/nncf/experimental/common/quantization/algorithms/post_training/pipeline.py +++ b/nncf/experimental/common/quantization/algorithms/post_training/pipeline.py @@ -11,7 +11,7 @@ from typing import Optional, TypeVar -from nncf.experimental.common.quantization.algorithms.quantizer.quantizer import NNCFQuantizer +from nncf.experimental.common.quantization.algorithms.quantizer.base_quantizer import NNCFQuantizer from nncf.experimental.common.quantization.algorithms.range_estimator.range_estimator import MinMaxRangeEstimator from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters diff --git a/nncf/experimental/common/quantization/algorithms/quantizer/quantizer.py b/nncf/experimental/common/quantization/algorithms/quantizer/base_quantizer.py similarity index 100% rename from nncf/experimental/common/quantization/algorithms/quantizer/quantizer.py rename to nncf/experimental/common/quantization/algorithms/quantizer/base_quantizer.py diff --git a/nncf/experimental/common/quantization/algorithms/quantizer/fx_quantizer.py b/nncf/experimental/common/quantization/algorithms/quantizer/fx_quantizer.py index 7842e7475f0..db0ae167132 100644 --- a/nncf/experimental/common/quantization/algorithms/quantizer/fx_quantizer.py +++ b/nncf/experimental/common/quantization/algorithms/quantizer/fx_quantizer.py @@ -29,7 +29,7 @@ from nncf.common.quantization.quantizer_setup import WeightQuantizationInsertionPoint from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode from nncf.common.quantization.structs import QuantizerConfig -from 
nncf.experimental.common.quantization.algorithms.quantizer.quantizer import NNCFQuantizer +from nncf.experimental.common.quantization.algorithms.quantizer.base_quantizer import NNCFQuantizer EdgeOrNode = Union[Tuple[torch.fx.Node, torch.fx.Node]] diff --git a/nncf/experimental/common/quantization/algorithms/range_estimator/backend.py b/nncf/experimental/common/quantization/algorithms/range_estimator/backend.py deleted file mode 100644 index dbd11f3f6b7..00000000000 --- a/nncf/experimental/common/quantization/algorithms/range_estimator/backend.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from abc import ABC -from abc import abstractmethod -from typing import List, Optional, Set, Tuple, TypeVar - -from nncf.common.graph.graph import NNCFGraph -from nncf.common.graph.graph import NNCFNode -from nncf.common.graph.transformations.commands import TargetPoint -from nncf.common.graph.transformations.commands import TargetType -from nncf.common.graph.transformations.commands import TransformationCommand -from nncf.common.quantization.structs import QuantizerConfig -from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase -from nncf.quantization.fake_quantize import FakeQuantizeParameters -from nncf.quantization.range_estimator import RangeEstimatorParameters - -TModel = TypeVar("TModel") - - -class RangeEstimatorAlgoBackend(ABC): - @staticmethod - @abstractmethod - def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> TargetPoint: - """ - Returns backend-specific target point. - - :param target_type: Type of the location that should be modified. - :param target_node_name: Name of the located node. - :param port_id: Port ID of the tensor for the statistics distribution. - :return: Backend-specific TargetPoint. - """ - - @staticmethod - @abstractmethod - def create_quantizer_insertion_command( - nncf_graph: NNCFGraph, - target_point: TargetPoint, - quantizer_config: QuantizerConfig, - parameters: FakeQuantizeParameters, - ) -> TransformationCommand: - """ - Returns backend-specific quantizer insertion command. - - :param nncf_graph: NNCFGraph to get input/output shapes for the target point. - :param target_point: Target location for the quantizer insertion. - :param quantizer_config: QuantizerConfig instance for the current layer. - :param parameters: FakeQuantizeParameters to calculate activation quantization parameters. - :return: Backend-specific TransformationCommand for the quantizer insertion operation. 
- """ - - @staticmethod - @abstractmethod - def create_unified_scales_quantizers_insertion_commands( - nncf_graph: NNCFGraph, - target_points: List[TargetPoint], - quantizer_config: QuantizerConfig, - parameters: FakeQuantizeParameters, - ) -> List[TransformationCommand]: - """ - Returns backend-specific unified scales quantizers insertion commands. - - :param nncf_graph: NNCFGraph to get input/output shapes for the target point. - :param target_points: List of target locations for the quantizers insertion. - :param quantizer_config: QuantizerConfig instance for the current layer. - :param parameters: FakeQuantizeParameters to calculate activation quantization parameters. - :return: List of backend-specific TransformationCommands - for the quantizers with unified scales insertion operations. - """ - - @staticmethod - @abstractmethod - def get_target_point_shape(nncf_graph: NNCFGraph, node: NNCFNode, target_point: TargetPoint) -> Tuple[int, ...]: - """ - Returns shape of a target point tensor. - - :param nncf_graph: NNCFGraph instance. - :param node: NNCFNode. - :param target_point: Target point of which tensor shape is seeked. - :return: Shape of target point tensor. - """ - - @staticmethod - @abstractmethod - def get_weight_quantization_axes(node: NNCFNode, target_point: TargetPoint, ndims: int) -> Tuple[int, ...]: - """ - Returns axes for per-channel quantization of weights of the node placed on a input port_id. - - :param node: Quantized node with the weight. - :param target_point: Corresponding target point. - :param ndims: Number of dimensions of weight. - :return: Axes for per-channel quantization of weights. 
- """ - - @staticmethod - @abstractmethod - def get_statistic_collector( - range_estimator_params: RangeEstimatorParameters, - use_abs_max: bool, - reduction_axes: Optional[Tuple[int, ...]], - aggregation_axes: Optional[Tuple[int, ...]], - inplace: bool, - num_samples: Optional[int] = None, - ) -> TensorStatisticCollectorBase: - """ - Returns backend-specific statistic collector. - - :param range_estimator_params: Parameters that specify estimators types. - :param use_abs_max: Wheather reduce absolute values of input tensors or not. - :param reduction_axes: Axes for reducer. - :param aggregation_axes: Axes for aggregator. - :param inplace: Whether to calculate statistic inplace or not. - :param num_samples: Maximum number of samples to collect. - :return: Backend-specific TensorStatisticCollectorBase for the statistics calculation. - """ - - @staticmethod - @abstractmethod - def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[Optional[int]]: - """ - Returns node's input port indices with weight tensors. - - :param node: NNCFNode to find its weight input port indices. - :param graph: NNCFGraph instance. - :return: Weights input port indices. - """ - - @staticmethod - def get_weight_name(nncf_graph: NNCFGraph, target_point: TargetPoint) -> str: - """ - Returns node's weight name corresponding to port ID. - - :param nncf_graph: NNCFGraph instance. - :param target_point: The TargetPoint instance that contains layer's information. - :return: Weight name. - """ - - @staticmethod - def should_quantize_weight(weight_name: str, quantized_weight_names: Set[str]) -> bool: - """ - Return True if weight should be quantized. - - :param weight_name: Weight name. - :param quantized_weight_names: Set containing already quantized weight names. - :return: A boolean value specifying whether a weight should be quantized. 
- """ diff --git a/nncf/experimental/common/quantization/algorithms/range_estimator/range_estimator.py b/nncf/experimental/common/quantization/algorithms/range_estimator/range_estimator.py index 5431703cdb1..1b5ad8c5692 100644 --- a/nncf/experimental/common/quantization/algorithms/range_estimator/range_estimator.py +++ b/nncf/experimental/common/quantization/algorithms/range_estimator/range_estimator.py @@ -9,38 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections -import dataclasses -from copy import deepcopy -from typing import List, Optional, OrderedDict, Tuple, TypeVar +from typing import List, Optional, TypeVar -import nncf -import nncf.tensor.functions as fns from nncf import Dataset -from nncf.common.factory import ModelTransformerFactory from nncf.common.graph.graph import NNCFGraph -from nncf.common.graph.transformations.commands import TargetPoint -from nncf.common.graph.transformations.commands import TargetType -from nncf.common.graph.transformations.layout import TransformationLayout -from nncf.common.logging import nncf_logger -from nncf.common.quantization.initialization.range import RangeInitCollectorParams -from nncf.common.quantization.quantizer_setup import SingleConfigQuantizationPoint -from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup -from nncf.common.quantization.structs import QuantizerConfig -from nncf.common.quantization.structs import QuantizerGroup -from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase -from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType -from nncf.common.utils.backend import get_backend -from nncf.experimental.common.quantization.algorithms.quantizer.quantizer import NNCFQuantizer -from 
nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic -from nncf.quantization.advanced_parameters import changes_asdict +from nncf.experimental.common.quantization.algorithms.quantizer.base_quantizer import NNCFQuantizer from nncf.quantization.algorithms.algorithm import Algorithm -from nncf.quantization.fake_quantize import calculate_quantizer_parameters -from nncf.quantization.fake_quantize import get_quantizer_narrow_range +from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization from nncf.quantization.range_estimator import RangeEstimatorParameters -from nncf.quantization.range_estimator import RangeEstimatorParametersSet TModel = TypeVar("TModel") @@ -69,315 +47,18 @@ def __init__( for weights. """ self._quantizer = quantizer - self._subset_size = subset_size - self._inplace_statistics = inplace_statistics - self._batchwise_statistics = batchwise_statistics - self._activations_range_estimator_params = activations_range_estimator_params - self._weights_range_estimator_params = weights_range_estimator_params - - self._range_estimator_params = { - QuantizerGroup.WEIGHTS: self._weights_range_estimator_params, - QuantizerGroup.ACTIVATIONS: self._activations_range_estimator_params, - } - # Calculates global quantizer constraints - self._reset_cache() - self._algorithm_key = f"MMQ_{hash(self)}" - - def _reset_cache(self) -> None: - """ - Marks cache by noninitialized values. Needs to be called when the new quantizer setup is needed. - """ - self._quantization_target_points_to_qconfig: OrderedDict[TargetPoint, QuantizerConfig] = None - self._unified_scale_groups = None - - def _init_cache(self) -> None: - """ - Initializes cache. 
- """ - self._quantization_target_points_to_qconfig: OrderedDict[TargetPoint, QuantizerConfig] = ( - collections.OrderedDict() + self._min_max_algo = MinMaxQuantization( + subset_size=subset_size, + inplace_statistics=inplace_statistics, + batchwise_statistics=batchwise_statistics, + activations_range_estimator_params=activations_range_estimator_params, + weights_range_estimator_params=weights_range_estimator_params, ) - self._unified_scale_groups = [] @property def available_backends(self) -> List[BackendType]: return [BackendType.TORCH_FX] - def _set_backend_entity(self, model: TModel) -> None: - """ - Creates a helper class with a backed-specific logic of the algorithm - - :param model: backend-specific input model - """ - model_backend = get_backend(model) - if model_backend == BackendType.TORCH_FX: - from nncf.experimental.common.quantization.algorithms.range_estimator.torch_fx_backend import ( - FXRangeEstimatorAlgoBackend, - ) - - self._backend_entity = FXRangeEstimatorAlgoBackend() - else: - raise nncf.UnsupportedBackendError( - "Cannot return backend-specific entity because {} is not supported!".format(model_backend.value) - ) - - def _get_range_estimator_parameters( - self, target_point: TargetPoint, quantizer_config: QuantizerConfig - ) -> RangeEstimatorParameters: - """ - Returns range estimator parameters. - - :param target_point: Quantizer target point. - :param quantizer_config: Quantizer config. - :return: Range estimator parameters. 
- """ - quantizer_group = QuantizerGroup.ACTIVATIONS - if target_point.is_weight_target_point(): - quantizer_group = QuantizerGroup.WEIGHTS - - if quantizer_group == QuantizerGroup.WEIGHTS or ( - quantizer_group == QuantizerGroup.ACTIVATIONS and quantizer_config.per_channel - ): - params = RangeEstimatorParametersSet.MINMAX - else: - params = RangeEstimatorParametersSet.MEAN_MINMAX - - user_params = self._range_estimator_params[quantizer_group] - if user_params is None: - return deepcopy(params) - - min_changes = changes_asdict(user_params.min) - min_statistic_collector = dataclasses.replace(params.min, **min_changes) - - max_changes = changes_asdict(user_params.max) - max_statistic_collector = dataclasses.replace(params.max, **max_changes) - - return RangeEstimatorParameters(min_statistic_collector, max_statistic_collector) - - def _get_stat_collector( - self, - graph: NNCFGraph, - target_point: TargetPoint, - qconfig: QuantizerConfig, - batchwise_statistics: bool, - ) -> TensorStatisticCollectorBase: - """ - Creates and returns a statistic collector based on the quantizer's configuration. - - :param graph: NNCFGraph instance. - :param target_point: Target point indicates where statistics should be collected. - :param qconfig: Configuration of a quantizer layer, - defining the configuration of created statistic collector. - :param batchwise_statistics: Determines whether quantizer statistics should be calculated - for each item of the batch or for the entire batch. - :return: Statistic Collector. 
- """ - is_weight = target_point.is_weight_target_point() - node = graph.get_node_by_name(target_point.target_node_name) - shape = self._backend_entity.get_target_point_shape(graph, node, target_point) - range_estimator_params = self._get_range_estimator_parameters(target_point, qconfig) - - channel_axes = () - if qconfig.per_channel: - channel_axes = ( - self._backend_entity.get_weight_quantization_axes(node, target_point, len(shape)) if is_weight else (1,) - ) - - # Weight statistics is constant, so only one collection is enough. - num_samples = self._subset_size if not is_weight else 1 - - batchwise_statistics = batchwise_statistics and not is_weight - - collector_params = RangeInitCollectorParams( - is_weights=is_weight, scheme=qconfig.mode, per_channel=qconfig.per_channel - ) - reduction_axes, aggregation_axes = None, None - if shape is not None: - reduction_axes, aggregation_axes = collector_params.get_reduction_aggregation_axes( - shape, channel_axes, batchwise_statistics - ) - - return self._backend_entity.get_statistic_collector( - range_estimator_params, - collector_params.use_abs_max, - reduction_axes, - aggregation_axes, - self._inplace_statistics, - num_samples=num_samples, - ) - - def _add_weight_quantization_target_point( - self, quantization_point: SingleConfigQuantizationPoint, nncf_graph: NNCFGraph - ) -> None: - """ - Adds weight quantization target point to the set of existing points. - - :param quantization_point: SingleConfigQuantizationPoint for the needed layer. - :param nncf_graph: The built NNCFGraph of the model. 
- """ - weight_quantization_target_points = self._get_weight_quantization_target_points(quantization_point, nncf_graph) - for weight_quantization_target_point in weight_quantization_target_points: - self._quantization_target_points_to_qconfig[weight_quantization_target_point] = quantization_point.qconfig - - def _add_activation_quantization_target_point( - self, quantization_point: SingleConfigQuantizationPoint, nncf_graph: NNCFGraph - ) -> None: - """ - Adds activation quantization target point to the set of existing points. - - :param quantization_point: SingleConfigQuantizationPoint for the needed layer. - :param nncf_graph: NNCFGraph instance for working with the graph and nodes. - """ - activation_quantization_target_point = self._get_activation_quantization_target_point( - quantization_point, nncf_graph - ) - self._quantization_target_points_to_qconfig[activation_quantization_target_point] = quantization_point.qconfig - - def _get_weight_quantization_target_points( - self, quantization_point: SingleConfigQuantizationPoint, nncf_graph: NNCFGraph - ) -> List[SingleConfigQuantizationPoint]: - """ - Returns weight quantization target points to the set of existing points. - - :param quantization_point: SingleConfigQuantizationPoint for the needed layer. - :param nncf_graph: NNCFGraph instance for working with the graph and nodes. - :return: List of SingleConfigQuantizationPoints for the needed layer. 
- """ - weight_quantization_target_points = [] - node_name = quantization_point.insertion_point.target_node_name - node = nncf_graph.get_node_by_name(node_name) - weights_port_ids = self._backend_entity.get_weight_tensor_port_ids(node, nncf_graph) - for port_id in weights_port_ids: - weight_quantization_target_points.append( - self._backend_entity.target_point(TargetType.OPERATION_WITH_WEIGHTS, node_name, port_id) - ) - return weight_quantization_target_points - - def _get_activation_quantization_target_point( - self, quantization_point: SingleConfigQuantizationPoint, nncf_graph: NNCFGraph - ) -> SingleConfigQuantizationPoint: - """ - Returns activation quantization target point to the set of existing points. - - :param quantization_point: SingleConfigQuantizationPoint for the needed layer. - :param nncf_graph: NNCFGraph instance for working with the graph and nodes. - :return: SingleConfigQuantizationPoint for the needed layer. - """ - node_name = quantization_point.insertion_point.target_node_name - # If Quantization of node's input - if quantization_point.insertion_point.input_port_id is not None: - input_port_id = quantization_point.insertion_point.input_port_id - activation_quantization_target_point = self._backend_entity.target_point( - TargetType.PRE_LAYER_OPERATION, node_name, input_port_id - ) - # If quantization of node's output or Model Input node - else: - # NOTE: Assumes that the operation has output edges only from one output port because - # we haven't encountered a model with operations that have multiple output edges with different - # output port IDs. Currently, such models are not supported. Usually, `output_port_id = 0` is used. - # However, there are operations, such as LSTMSequence, where the `output_port_id` changes from case - # to case. Therefore, the code below is required to dynamically determine the `output_port_id` where - # the quantize operation should be inserted." 
- node = nncf_graph.get_node_by_name(node_name) - unique_output_port_ids = set(e.output_port_id for e in nncf_graph.get_output_edges(node)) - if len(unique_output_port_ids) > 1: - nncf_logger.warning( - f"Cannot determine the output_port_id for the operation: {node_name}, " - "output_port_id = 0 will be used." - ) - output_port_id = 0 - else: - output_port_id = next(iter(unique_output_port_ids)) - - activation_quantization_target_point = self._backend_entity.target_point( - TargetType.POST_LAYER_OPERATION, node_name, output_port_id - ) - return activation_quantization_target_point - - def _find_quantization_target_points( - self, model: TModel, nncf_graph: NNCFGraph - ) -> Tuple[OrderedDict[TargetPoint, QuantizerConfig], List[List[TargetPoint]]]: - """ - Initializes a cache, finds quantization target points and them puts in the cache. - - :param model: Backend-specific model, for which Quantization Target Points are being seek. - :param nncf_graph: NNCFGraph instance. - :return: Mapping of quantization target points with associated quantization configuration, - along with target points for scale unification. 
- """ - quantizer_setup = self._quantizer.get_quantization_setup(model, nncf_graph) - self._unified_scale_groups = self._collect_unified_groups(quantizer_setup, nncf_graph) - quantization_points = list(quantizer_setup.quantization_points.values()) - quantization_points = self._topological_sort_quantization_points(quantization_points, nncf_graph) - for quantization_point in quantization_points: - if quantization_point.is_weight_quantization_point(): - self._add_weight_quantization_target_point(quantization_point, nncf_graph) - elif quantization_point.is_activation_quantization_point(): - self._add_activation_quantization_target_point(quantization_point, nncf_graph) - else: - raise nncf.InternalError("Incorrect quantization point") - return self._quantization_target_points_to_qconfig, self._unified_scale_groups - - def _get_quantization_target_points( - self, model: TModel, nncf_graph: NNCFGraph - ) -> Tuple[OrderedDict[TargetPoint, QuantizerConfig], List[List[TargetPoint]]]: - """ - Returns Quantization Target Points. - Returns a cache with target points if exists. Otherwise, initiates a procedure of finding them. - - :param model: Backend-specific model, for which Quantization Target Points are being seek. - :param nncf_graph: NNCFGraph instance. - :return: Mapping of quantization target points with associated quantization configuration, - along with target points for scale unification. - """ - if self._quantization_target_points_to_qconfig is not None: - return self._quantization_target_points_to_qconfig, self._unified_scale_groups - self._init_cache() - return self._find_quantization_target_points(model, nncf_graph) - - def _collect_unified_groups( - self, quantizer_setup: SingleConfigQuantizerSetup, nncf_graph: NNCFGraph - ) -> List[List[TargetPoint]]: - """ - Collects the group of quantizers for unification. - - :param quantizer_setup: SingleConfigQuantizerSetup instance. - :param nncf_graph: NNCFGraph instance. 
- :return: List with the groups of the TargetPoints. - """ - unified_scale_groups = [] - for quantizer_ids in quantizer_setup.unified_scale_groups.values(): - unified_scale_group = [] - for quantizer_id in quantizer_ids: - quantization_point = quantizer_setup.quantization_points[quantizer_id] - - # Only activation quantizers can be unified - if quantization_point.is_activation_quantization_point(): - activation_target_point = self._get_activation_quantization_target_point( - quantization_point, nncf_graph - ) - unified_scale_group.append(activation_target_point) - else: - weight_target_points = self._get_weight_quantization_target_points(quantization_point, nncf_graph) - for weight_target_point in weight_target_points: - unified_scale_group.append(weight_target_point) - unified_scale_groups.append(unified_scale_group) - return unified_scale_groups - - def _topological_sort_quantization_points( - self, quantization_points: List[SingleConfigQuantizationPoint], nncf_graph: NNCFGraph - ) -> List[SingleConfigQuantizationPoint]: - """ - Sorts quantization_points based on the topological order of nodes obtained form nncf_graph. - - :param quantization_points: Quantization points. - :param nncf_graph: Instance of NNCFgraph used to get topological sort. - :return: Sorted quantization_points. 
- """ - node_names_to_pos = {node.node_name: i for i, node in enumerate(nncf_graph.topological_sort())} - quantization_points.sort(key=lambda point: node_names_to_pos[point.insertion_point.target_node_name]) - return quantization_points - def apply( self, model: TModel, @@ -385,111 +66,16 @@ def apply( statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, ) -> TModel: - transformation_layout = TransformationLayout() - model_transformer = ModelTransformerFactory.create(model) - quantization_target_points, unified_scale_groups = self._get_quantization_target_points(model, graph) - weight_layer_names = set() - - def filter_func(point: StatisticPoint) -> bool: - return ( - self._algorithm_key in point.algorithm_to_tensor_collectors - and point.target_point == quantization_target_point - ) - - unified_ops_list = set() - for unified_scale_group in unified_scale_groups: - group_statistics = [] - for quantization_target_point in unified_scale_group: - target_node_name = quantization_target_point.target_node_name - for tensor_collector in statistic_points.get_algo_statistics_for_node( - target_node_name, filter_func, self._algorithm_key - ): - statistics = tensor_collector.get_statistics() - if statistics.min_values is None or statistics.max_values is None: - raise nncf.InternalError(f"Statistics were not collected for the node {target_node_name}") - group_statistics.append(statistics) - - unified_values = self._unify_statistics(group_statistics) - qconfigs = [quantization_target_points[qtp] for qtp in unified_scale_group] - if any(qconfigs[0] != qconfig for qconfig in qconfigs[1:]): - raise nncf.InternalError(f"QConfigs for unified scale group {unified_scale_group} are not equal") - qconfig = qconfigs[0] - q_group = QuantizerGroup.ACTIVATIONS - narrow_range = get_quantizer_narrow_range(qconfig, q_group) - parameters = calculate_quantizer_parameters(unified_values, qconfig, q_group, narrow_range) - commands = 
self._backend_entity.create_unified_scales_quantizers_insertion_commands( - graph, unified_scale_group, qconfig, parameters + if self._min_max_algo._quantization_target_points_to_qconfig is None: + raise RuntimeError( + "Static points are not available." + " Please call `get_statistic_points` before calling the `apply` method." ) - for command in commands: - transformation_layout.register(command) - unified_ops_list.update(unified_scale_group) - - for quantization_target_point, qconfig in quantization_target_points.items(): - if quantization_target_point in unified_ops_list: - continue - target_node_name = quantization_target_point.target_node_name - for tensor_collector in statistic_points.get_algo_statistics_for_node( - target_node_name, filter_func, self._algorithm_key - ): - if quantization_target_point.is_weight_target_point(): - weights_name = self._backend_entity.get_weight_name(graph, quantization_target_point) - if not self._backend_entity.should_quantize_weight(weights_name, weight_layer_names): - continue - weight_layer_names.add(weights_name) - quant_group = QuantizerGroup.WEIGHTS - else: - quant_group = QuantizerGroup.ACTIVATIONS - - half_range = False - narrow_range = get_quantizer_narrow_range(qconfig, quant_group) - statistics = tensor_collector.get_statistics() - if statistics.min_values is None or statistics.max_values is None: - raise nncf.InternalError(f"Statistics were not collected for the node {target_node_name}") - parameters = calculate_quantizer_parameters(statistics, qconfig, quant_group, narrow_range, half_range) - command = self._backend_entity.create_quantizer_insertion_command( - graph, quantization_target_point, qconfig, parameters - ) - transformation_layout.register(command) - if not transformation_layout.transformations: - nncf_logger.info("The model has no operations to apply quantization.") - quantized_model = model_transformer.transform(transformation_layout) - return quantized_model + return 
self._min_max_algo.apply(model=model, graph=graph, statistic_points=statistic_points) def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: - self._set_backend_entity(model) - self._reset_cache() - quantization_target_points, _ = self._get_quantization_target_points(model, graph) - output = StatisticPointsContainer() - for quantization_target_point, qconfig in quantization_target_points.items(): - nncf_logger.debug( - f"Adding target point {quantization_target_point.target_node_name}" - f" with type {quantization_target_point.type} for statistics collection" - ) - stat_collector = self._get_stat_collector( - graph, quantization_target_point, qconfig, self._batchwise_statistics - ) - output.add_statistic_point( - StatisticPoint( - target_point=quantization_target_point, - tensor_collector=stat_collector, - algorithm=self._algorithm_key, - ) - ) - return output - - @staticmethod - def _unify_statistics(statistics: List[MinMaxTensorStatistic]) -> MinMaxTensorStatistic: - """ - Returns backend-specific unified statistics. - - :param statistics: List of MinMaxTensorStatistic instances. - :return: Unified MinMaxTensorStatistic value. 
- """ - - max_values, min_values = [], [] - for statistic in statistics: - max_values.append(statistic.max_values.flatten()) - min_values.append(statistic.min_values.flatten()) - max_values = fns.max(fns.stack(max_values), axis=0) - min_values = fns.min(fns.stack(min_values), axis=0) - return MinMaxTensorStatistic(min_values=min_values, max_values=max_values) + quantizer_setup = self._quantizer.get_quantization_setup(model, graph) + self._min_max_algo._set_backend_entity(model) + self._min_max_algo._init_cache() + self._min_max_algo.fill_quantization_target_points(quantizer_setup, graph) + return self._min_max_algo.get_cached_statistic_points(model, graph) diff --git a/nncf/experimental/common/quantization/algorithms/range_estimator/torch_fx_backend.py b/nncf/experimental/common/quantization/algorithms/range_estimator/torch_fx_backend.py deleted file mode 100644 index 0e30e70ae57..00000000000 --- a/nncf/experimental/common/quantization/algorithms/range_estimator/torch_fx_backend.py +++ /dev/null @@ -1,221 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import List, Optional, Set, Tuple - -import torch -from torch.quantization.fake_quantize import FakeQuantize - -import nncf -from nncf.common.graph.graph import NNCFGraph -from nncf.common.graph.graph import NNCFNode -from nncf.common.graph.transformations.commands import TargetType -from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode -from nncf.common.quantization.structs import QuantizerConfig -from nncf.experimental.common.quantization.algorithms.range_estimator.backend import RangeEstimatorAlgoBackend -from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP -from nncf.experimental.common.tensor_statistics.collectors import TensorCollector -from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic -from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand -from nncf.experimental.torch.fx.model_utils import get_target_point -from nncf.experimental.torch.fx.transformations import qdq_insertion_transformation_builder -from nncf.quantization.advanced_parameters import StatisticsType -from nncf.quantization.fake_quantize import FakeQuantizeParameters -from nncf.quantization.range_estimator import AggregatorType -from nncf.quantization.range_estimator import RangeEstimatorParameters -from nncf.torch.graph.graph import PTNNCFGraph -from nncf.torch.graph.graph import PTTargetPoint -from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand -from nncf.torch.model_graph_manager import get_weight_tensor_port_ids -from nncf.torch.quantization.layers import QUANTIZATION_MODULES -from nncf.torch.quantization.layers import AsymmetricQuantizer -from nncf.torch.quantization.layers import BaseQuantizer -from nncf.torch.quantization.layers import PTQuantizerSpec -from nncf.torch.quantization.layers import get_scale_shape -from nncf.torch.quantization.strip import convert_to_torch_fakequantizer -from nncf.torch.tensor_statistics.collectors 
import PT_REDUCERS_MAP - - -class FXRangeEstimatorAlgoBackend(RangeEstimatorAlgoBackend): - @staticmethod - def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: - return get_target_point(target_type, target_node_name, port_id) - - @staticmethod - def get_target_point_shape(nncf_graph: PTNNCFGraph, node: NNCFNode, target_point: PTTargetPoint) -> Tuple[int, ...]: - return nncf_graph.get_input_shape_for_insertion_point(target_point) - - @staticmethod - def get_weight_quantization_axes(node: NNCFNode, target_point: PTTargetPoint, ndims: int) -> Tuple[int]: - # TODO(dlyakhov): support transpose conv and other cases - return (0,) - - @staticmethod - def get_statistic_collector( - range_estimator_params: RangeEstimatorParameters, - use_abs_max: bool, - reduction_axes: Optional[Tuple[int, ...]], - aggregation_axes: Optional[Tuple[int, ...]], - inplace: bool, - num_samples: Optional[int] = None, - ) -> TensorCollector: - collector = TensorCollector(MinMaxTensorStatistic) - for params, container_key in zip( - [range_estimator_params.min, range_estimator_params.max], - [MinMaxTensorStatistic.MIN_STAT, MinMaxTensorStatistic.MAX_STAT], - ): - if params.statistics_type not in PT_REDUCERS_MAP: - raise nncf.InternalError( - f"Statistic type: {params.statistics_type} is not supported for Torch PTQ backend yet." - ) - - if params.aggregator_type not in AGGREGATORS_MAP: - raise nncf.InternalError( - f"Aggregator type: {params.aggregator_type} is not supported for Torch PTQ backend yet." 
- ) - - statistic_type = params.statistics_type - if statistic_type in [StatisticsType.QUANTILE, StatisticsType.ABS_QUANTILE]: - # TODO(dlyakhov): merge two quantile aggregators in one - if container_key == MinMaxTensorStatistic.MIN_STAT: - quantile = params.quantile_outlier_prob - else: - quantile = 1 - params.quantile_outlier_prob - reducer = PT_REDUCERS_MAP[statistic_type](reduction_axes=reduction_axes, quantile=[quantile]) - else: - if use_abs_max and statistic_type == StatisticsType.MAX: - statistic_type = StatisticsType.ABS_MAX - reducer = PT_REDUCERS_MAP[statistic_type](reduction_axes=reduction_axes) - - kwargs = { - "num_samples": num_samples, - "aggregation_axes": aggregation_axes, - } - if params.aggregator_type in [AggregatorType.MEAN_NO_OUTLIERS, AggregatorType.MEDIAN_NO_OUTLIERS]: - kwargs.update({"quantile": params.quantile_outlier_prob}) - aggregator = AGGREGATORS_MAP[params.aggregator_type](**kwargs) - - collector.register_statistic_branch(container_key, reducer, aggregator) - return collector - - @staticmethod - def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[Optional[int]]: - return get_weight_tensor_port_ids(node, graph) - - @staticmethod - def get_weight_name(nncf_graph: NNCFGraph, target_point: PTTargetPoint) -> str: - weighted_node = nncf_graph.get_node_by_name(target_point.target_node_name) - weight_edge = nncf_graph.get_input_edge_by_port_id(weighted_node, target_point.input_port_id) - weight = weight_edge.from_node - return weight.node_name - - @staticmethod - def should_quantize_weight(weight_name: str, quantized_weight_names: Set[str]) -> bool: - # If the nodes share one weight tensor, we should have only one quantizer on that - return weight_name not in quantized_weight_names - - @staticmethod - def _get_input_scale_shape( - nncf_graph: NNCFGraph, target_point: PTTargetPoint, per_channel: bool - ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]: - is_weights = target_point.is_weight_target_point() - if 
is_weights: - # TODO(dlyakhov): support transpose conv/ make channel_idx common - channel_idx = 0 - else: - channel_idx = 1 # channel dim for activations - - input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point) - scale_shape = tuple( - get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx) - ) - - return input_shape, scale_shape, channel_idx - - @staticmethod - def _create_quantizer( - quantizer_config: QuantizerConfig, - scale_shape: Tuple, - parameters: FakeQuantizeParameters, - target_type: TargetType, - ) -> FakeQuantize: - mode = quantizer_config.mode - quantizer_cls = QUANTIZATION_MODULES.get(mode) - narrow_range = target_type == TargetType.OPERATION_WITH_WEIGHTS and mode == QuantizationMode.SYMMETRIC - quantizer_spec = PTQuantizerSpec.from_config( - quantizer_config, - narrow_range=narrow_range, - scale_shape=scale_shape, - half_range=False, - logarithm_scale=False, - is_quantized_on_export=False, - compression_lr_multiplier=None, - ) - quantizer = quantizer_cls(quantizer_spec) - - # Fill it with minmax - # TODO(dlyakhov) Prevent creation of intermediate objects like nncf quantizer. - FXRangeEstimatorAlgoBackend._fill_quantizer_parameters(quantizer, parameters, quantizer_spec.scale_shape) - # Convert to the torch fake quantizer - torch_fq = convert_to_torch_fakequantizer(quantizer) - return torch_fq - - @staticmethod - def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters, scale_shape) -> None: - if isinstance(quantizer, AsymmetricQuantizer): - quantizer.input_low = torch.nn.Parameter(parameters.input_low.data.reshape(scale_shape)) - input_range = parameters.input_high - parameters.input_low - # Subtract eps from the input_range to make quantizer parameters equal to - # original parameters on the forward call. 
- quantizer.input_range = torch.nn.Parameter((input_range.data - quantizer.eps).reshape(scale_shape)) - else: - quantizer.signed = bool(torch.any(parameters.input_low.data < 0)) - # Subtract eps from the scale to make quantizer parameters equal to - # original parameters on the forward call. - quantizer.scale = torch.nn.Parameter((parameters.input_high.data - quantizer.eps).reshape(scale_shape)) - - @staticmethod - def create_quantizer_insertion_command( - nncf_graph: NNCFGraph, - target_point: PTTargetPoint, - quantizer_config: QuantizerConfig, - parameters: FakeQuantizeParameters, - ) -> FXApplyTransformationCommand: - _, scale_shape, _ = FXRangeEstimatorAlgoBackend._get_input_scale_shape( - nncf_graph, target_point, quantizer_config.per_channel - ) - - quantizer = FXRangeEstimatorAlgoBackend._create_quantizer( - quantizer_config, scale_shape, parameters, target_point.target_type - ) - transformation = qdq_insertion_transformation_builder(quantizer, [target_point]) - return FXApplyTransformationCommand(transformation) - - @staticmethod - def create_unified_scales_quantizers_insertion_commands( - nncf_graph: NNCFGraph, - target_points: List[PTTargetPoint], - quantizer_config: QuantizerConfig, - parameters: FakeQuantizeParameters, - ) -> List[PTSharedFnInsertionCommand]: - _, scale_shape, _ = FXRangeEstimatorAlgoBackend._get_input_scale_shape( - nncf_graph, target_points[0], quantizer_config.per_channel - ) - - quantizer = FXRangeEstimatorAlgoBackend._create_quantizer( - quantizer_config, scale_shape, parameters, target_points[0].target_type - ) - - transformations = [] - for tp in target_points: - transformation = qdq_insertion_transformation_builder(quantizer, [tp]) - transformations.append(FXApplyTransformationCommand(transformation)) - return transformations diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index dea9211b734..b6728b292bf 100644 --- 
a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -723,9 +723,7 @@ def _get_activation_quantization_target_point( ) return activation_quantization_target_point - def _find_quantization_target_points( - self, model: TModel, nncf_graph: NNCFGraph - ) -> Tuple[OrderedDict[TargetPoint, QuantizerConfig], List[List[TargetPoint]]]: + def find_quantization_setup(self, model: TModel, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup: """ Initializes a cache, finds quantization target points and them puts in the cache. @@ -753,6 +751,19 @@ def _find_quantization_target_points( quantizer_setup = self._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns) self._apply_model_type_pass(self._model_type, quantizer_setup, nncf_graph) self._apply_device_pass(self._target_device, quantizer_setup, inference_nncf_graph) + return quantizer_setup + + def fill_quantization_target_points( + self, quantizer_setup: SingleConfigQuantizerSetup, nncf_graph: NNCFGraph + ) -> Tuple[OrderedDict[TargetPoint, QuantizerConfig], List[List[TargetPoint]]]: + """ + Initializes a cache, finds quantization target points and puts them in the cache. + + :param quantizer_setup: Quantizer setup, for which Quantization Target Points are being sought. + :param nncf_graph: NNCFGraph instance. + :return: Mapping of quantization target points with associated quantization configuration, + along with target points for scale unification. 
+ """ self._unified_scale_groups = self._collect_unified_groups(quantizer_setup, nncf_graph) quantization_points = list(quantizer_setup.quantization_points.values()) quantization_points = self._topological_sort_quantization_points(quantization_points, nncf_graph) @@ -780,7 +791,8 @@ def _get_quantization_target_points( if self._quantization_target_points_to_qconfig is not None: return self._quantization_target_points_to_qconfig, self._unified_scale_groups self._init_cache() - return self._find_quantization_target_points(model, nncf_graph) + quantizer_setup = self.find_quantization_setup(model, nncf_graph) + return self.fill_quantization_target_points(quantizer_setup, nncf_graph) def _collect_unified_groups( self, quantizer_setup: SingleConfigQuantizerSetup, nncf_graph: NNCFGraph @@ -989,10 +1001,30 @@ def filter_func(point: StatisticPoint) -> bool: quantized_model = model_transformer.transform(transformation_layout) return quantized_model + def get_cached_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: + """ + Builds a statistic point container from the already cached mapping of target points to qconfigs. + + :param model: Model instance. + :param graph: NNCFGraph instance corresponding to the passed model. + :return: Filled statistic point container. 
+ """ + if self._quantization_target_points_to_qconfig is None: + raise RuntimeError("get_cached_statistic_points is called before statistic caching.") + self._set_backend_entity(model) + return self._get_statistic_point_container(self._quantization_target_points_to_qconfig, graph) + def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: self._set_backend_entity(model) self._reset_cache() quantization_target_points, _ = self._get_quantization_target_points(model, graph) + return self._get_statistic_point_container(quantization_target_points, graph) + + def _get_statistic_point_container( + self, + quantization_target_points: Tuple[OrderedDict[TargetPoint, QuantizerConfig], List[List[TargetPoint]]], + graph: NNCFGraph, + ) -> StatisticPointsContainer: output = StatisticPointsContainer() for quantization_target_point, qconfig in quantization_target_points.items(): nncf_logger.debug( diff --git a/tests/common/quantization/test_minmax.py b/tests/common/quantization/test_minmax.py index 2a63d3c7ac8..2f780ea652f 100644 --- a/tests/common/quantization/test_minmax.py +++ b/tests/common/quantization/test_minmax.py @@ -215,14 +215,20 @@ def test_min_max_caching(): Checks that the _get_quantization_target_points(...) of MinMaxQuantization called once utilizing the cache. Checks that after _reset_cache() it called one more time. """ - called = 0 + find_called = 0 + fill_called = 0 - def foo(self, *args): + def find_qsetup_mock(self, *args): + nonlocal find_called + find_called += 1 + return None + + def fill_qsetup_mock(self, *args): """ Mocked _find_quantization_target_points. 
""" - nonlocal called - called += 1 + nonlocal fill_called + fill_called += 1 # Set up cache self._quantization_target_points_to_qconfig = collections.OrderedDict() self._unified_scale_groups = [] @@ -230,11 +236,12 @@ def foo(self, *args): run_nums = 2 algo = MinMaxQuantization() - algo._find_quantization_target_points = types.MethodType(foo, algo) + algo.find_quantization_setup = types.MethodType(find_qsetup_mock, algo) + algo.fill_quantization_target_points = types.MethodType(fill_qsetup_mock, algo) for _ in range(run_nums): algo._get_quantization_target_points(None, None) - assert called == 1 + assert find_called == fill_called == 1 algo._reset_cache() for _ in range(run_nums): algo._get_quantization_target_points(None, None) - assert called == 2 + assert find_called == fill_called == 2