From 239474da960d691e33b481f5b47aa0783ff556b6 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Tue, 16 May 2023 13:45:13 +0200 Subject: [PATCH] TensorCollectorAdapter --- .../common/tensor_statistics/collectors.py | 211 ++++++------------ nncf/openvino/quantization/quantize_model.py | 4 - nncf/openvino/statistics/aggregator.py | 5 +- nncf/openvino/statistics/collectors.py | 5 +- nncf/parameters.py | 1 - .../algorithms/bias_correction/algorithm.py | 9 +- .../algorithms/bias_correction/backend.py | 2 - .../bias_correction/onnx_backend.py | 2 - .../bias_correction/openvino_backend.py | 4 +- .../fast_bias_correction/algorithm.py | 8 +- .../fast_bias_correction/backend.py | 2 - .../fast_bias_correction/onnx_backend.py | 2 - .../fast_bias_correction/openvino_backend.py | 4 +- .../algorithms/min_max/algorithm.py | 1 - .../algorithms/min_max/backend.py | 1 - .../algorithms/min_max/onnx_backend.py | 1 - .../algorithms/min_max/openvino_backend.py | 4 - .../algorithms/min_max/torch_backend.py | 1 - .../algorithms/post_training/algorithm.py | 2 - 19 files changed, 82 insertions(+), 187 deletions(-) diff --git a/nncf/experimental/common/tensor_statistics/collectors.py b/nncf/experimental/common/tensor_statistics/collectors.py index dcb83949bc6..1b9ef5db7ef 100644 --- a/nncf/experimental/common/tensor_statistics/collectors.py +++ b/nncf/experimental/common/tensor_statistics/collectors.py @@ -27,52 +27,7 @@ InplaceInsertionFNType = TypeVar("InplaceInsertionFNType") -class TensorReducerInterface(ABC): - @abstractproperty - def inplace(self): - pass - - @abstractproperty - def output_port_id(self) -> int: - pass - - @abstractproperty - def name(self): - pass - - @abstractmethod - def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: - """ - Returns target output names from target model that is - modified for statistic collection. - - :param target_node_name: Target node name for reducer. - :param port_id: Target port id for target node name for reducer. - :return: Target output names for reducer. - """ - - @abstractmethod - def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]: - """ - Returns correspondent inplace operation builder if inplace operations are available in backend. - - :return: Inplace operation builder if possible else None. - """ - - @abstractmethod - def __call__(self, x: List[NNCFTensor]): - pass - - @abstractmethod - def __eq__(self, __o: object) -> bool: - pass - - @abstractmethod - def __hash__(self) -> int: - pass - - -class TensorReducerBase(TensorReducerInterface, ABC): +class TensorReducerBase(ABC): """ Tensor reducer is a callable object that reduces tensors according to the specified rule. Could handle tensors inplace or out of place. @@ -106,6 +61,25 @@ def name(self): def _get_processor() -> NNCFCollectorTensorProcessor: pass + @abstractmethod + def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: + """ + Returns target output names from target model that is + modified for statistic collection. + + :param target_node_name: Target node name for reducer. + :param port_id: Target port id for target node name for reducer. + :return: Target output names for reducer. + """ + + @abstractmethod + def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]: + """ + Returns correspondent inplace operation builder if inplace operations are available in backend. + + :return: Inplace operation builder if possible else None. + """ + @abstractstaticmethod def reduce_out_of_place(x: List[TensorType], **kwargs) -> List[TensorType]: """ @@ -121,9 +95,6 @@ def get_kwargs(self): "keepdims": self._keepdims, } - def __call__(self, x: List[NNCFTensor]): - return self._adapter.reduce(x, self) - def __eq__(self, __o: object) -> bool: return ( isinstance(__o, self.__class__) @@ -136,86 +107,6 @@ def __hash__(self) -> int: return hash((self.__class__.__name__, self.inplace, self._init_reduction_shape, self._keepdims)) -class BaseReducerAdapter(ABC): - @abstractmethod - def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase): - pass - - -class DefaultReducerAdapter(BaseReducerAdapter): - def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase): - if reducer.inplace: - return x - - kwargs = reducer.get_kwargs() - reduction_shape_key = "reduction_shape" - if reduction_shape_key in kwargs and kwargs[reduction_shape_key] is None: - kwargs[reduction_shape_key] = tuple(range(len(x[0].shape))) - - return reducer.reduce_out_of_place(x, **kwargs) - - -class SequentialReducerAdapter(DefaultReducerAdapter): - def __init__(self, stack_axis: int) -> None: - super().__init__() - self._stack_axis = stack_axis - - def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase): - per_element_result = super().reduce(x, reducer) - kwargs = reducer.get_kwargs() - new_params = {"reduction_shape": self._stack_axis, "keepdims": False} - if not all(k in kwargs for k in new_params): - return per_element_result - - kwargs.update(new_params) - return reducer.reduce_out_of_place(per_element_result, **kwargs) - - -class TensorReducersSequence(TensorReducerInterface): - def __init__(self, *args): - if any(reducer.inplace for reducer in args[1:]): - raise RuntimeError(f"Only first reducer of sequential tensor reducer could not be inplace.") - self._reducers = args - - @property - def inplace(self): - return self._reducers[0].inplace - - @property - def output_port_id(self) -> int: - return self._reducers[0].output_port_id - - @property - def name(self): - name = "" - for i, reducer in enumerate(self._reducers): - name += f"{i}_{reducer.name}" - return name - - def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: - return self._reducers[0].get_output_names(target_node_name, port_id) - - def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]: - return self._reducers[0].get_inplace_fn() - - def __call__(self, x: List[NNCFTensor]): - if not self._reducers[0].inplace: - x = self._reducers[0](x) - - for reducer in self._reducers[1:]: - x = reducer(x) - return x - - def __eq__(self, __o: object) -> bool: - return ( - isinstance(__o, self.__class__) - and len(self._reducers) == len(__o.reducers) - and all(self_r == o_r for self_r, o_r in zip(self._reducers, __o.reducers)) - ) - - def __hash__(self) -> int: - return hash(tuple(hash(reducer) for reducer in self._reducers)) - class TensorAggregatorBase: """ @@ -273,6 +164,54 @@ def __hash__(self) -> int: return hash(self.__class__.__name__) +class BaseTensorCollectorAdapter(ABC): + @abstractmethod + def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase): + pass + + @abstractmethod + def register_reduced_input(self, x: NNCFTensor, aggregator: TensorAggregatorBase): + pass + + +class DefaultTensorCollectorAdapter(BaseTensorCollectorAdapter): + def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase): + if reducer.inplace: + return x + + kwargs = reducer.get_kwargs() + reduction_shape_key = "reduction_shape" + if reduction_shape_key in kwargs and kwargs[reduction_shape_key] is None: + kwargs[reduction_shape_key] = tuple(range(len(x[0].shape))) + + return reducer.reduce_out_of_place(x, **kwargs) + + def register_reduced_input(self, x: NNCFTensor, aggregator: TensorAggregatorBase): + aggregator.register_reduced_input(x) + + +class SequentialTensorCollectorAdapter(DefaultTensorCollectorAdapter): + def __init__(self, stack_axis: int, tensor_processor: NNCFCollectorTensorProcessor) -> None: + super().__init__() + self._stack_axis = stack_axis + self._tensor_processor = tensor_processor + + def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase): + per_element_result = super().reduce(x, reducer) + kwargs = reducer.get_kwargs() + new_params = {"reduction_shape": self._stack_axis, "keepdims": False} + if not all(k in kwargs for k in new_params): + return per_element_result + + kwargs.update(new_params) + return reducer.reduce_out_of_place(per_element_result, **kwargs) + + def register_reduced_input(self, x: NNCFTensor, aggregator: TensorAggregatorBase): + if isinstance(aggregator, ShapeAggregator): + x = self._tensor_processor.unstack(x, axis=0)[0] + super().register_reduced_input(x, aggregator) + + class TensorCollector: """ Calculates statistics at given tensors according to registered statistic branches. @@ -289,7 +228,7 @@ def __init__(self, statistic_container: Optional[TensorStatistic] = None) -> Non self._stat_container_kwargs_map: Dict[str, Tuple[int, int]] = {} self._stat_container = statistic_container self._enabled = True - self._adapter = DefaultReducerAdapter() + self._adapter = DefaultTensorCollectorAdapter() @property def num_samples(self) -> Optional[int]: @@ -313,7 +252,7 @@ def reducers(self): def aggregators(self): return self._aggregators.copy() - def set_adapter(self, adapter: "BaseReducerAdapter"): + def set_adapter(self, adapter: "BaseTensorCollectorAdapter"): self._adapter = adapter def enable(self): @@ -390,7 +329,7 @@ def register_inputs(self, inputs: Dict[int, List[NNCFTensor]]) -> None: aggregator, ) in self._aggregators.items(): if reducer_hash in reduced_inputs: - aggregator.register_reduced_input(reduced_inputs[reducer_hash][reducer_port_id]) + self._adapter.register_reduced_input(reduced_inputs[reducer_hash][reducer_port_id], aggregator) def _aggregate(self) -> None: result = {} @@ -683,18 +622,14 @@ def aggregate(self): class ShapeAggregator(TensorAggregatorBase): - def __init__(self, slice_=None): + def __init__(self): super().__init__(None, 1) - self._slice = slice_ def _register_reduced_input_impl(self, x: TensorType) -> None: self._container = x def aggregate(self): - shape = self._container.shape - if self._slice is not None: - return shape[self._slice] - return shape + return self._container.shape class MinAggregator(TensorAggregatorBase): diff --git a/nncf/openvino/quantization/quantize_model.py b/nncf/openvino/quantization/quantize_model.py index 61ba10d07a4..aa344e95f4d 100644 --- a/nncf/openvino/quantization/quantize_model.py +++ b/nncf/openvino/quantization/quantize_model.py @@ -21,7 +21,6 @@ from nncf.common.utils.backend import get_backend from nncf.common.utils.timer import timer from nncf.data import Dataset -from nncf.data import RecurentDataset from nncf.openvino.quantization.backend_parameters import BackendParameters from nncf.openvino.quantization.backend_parameters import is_weight_compression_needed from nncf.parameters import DropType @@ -103,9 +102,6 @@ def native_quantize_impl( """ Implementation of the `quantize()` method for the OpenVINO backend via the OpenVINO Runtime API. """ - if isinstance(calibration_dataset, RecurentDataset): - model_type = ModelType.SEQUENTIAL - quantization_algorithm = PostTrainingQuantization( preset=preset, target_device=target_device, diff --git a/nncf/openvino/statistics/aggregator.py b/nncf/openvino/statistics/aggregator.py index a8d15e9f7b9..44a364e47a2 100644 --- a/nncf/openvino/statistics/aggregator.py +++ b/nncf/openvino/statistics/aggregator.py @@ -21,11 +21,12 @@ from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.experimental.common.tensor_statistics.collectors import MergedTensorCollector -from nncf.experimental.common.tensor_statistics.collectors import SequentialReducerAdapter +from nncf.experimental.common.tensor_statistics.collectors import SequentialTensorCollectorAdapter from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.openvino.graph.nncf_graph_builder import GraphConverter from nncf.openvino.graph.transformations.commands import OVInplaceFnInsertionCommand from nncf.openvino.graph.transformations.commands import OVOutputInsertionCommand +from nncf.openvino.statistics.collectors import OVNNCFCollectorTensorProcessor from nncf.openvino.tensor import OVNNCFTensor @@ -77,7 +78,7 @@ def _get_transformation_layout_extra_outputs( # TODO(dlyakhov) Move this to common part def _adapt_collectors(statistic_points: StatisticPointsContainer, stack_axis: int): for _, _, tensor_collector in statistic_points.get_tensor_collectors(): - tensor_collector.set_adapter(SequentialReducerAdapter(stack_axis)) + tensor_collector.set_adapter(SequentialTensorCollectorAdapter(stack_axis, OVNNCFCollectorTensorProcessor)) return statistic_points diff --git a/nncf/openvino/statistics/collectors.py b/nncf/openvino/statistics/collectors.py index 5f8254b6a46..3898f0436d3 100644 --- a/nncf/openvino/statistics/collectors.py +++ b/nncf/openvino/statistics/collectors.py @@ -40,7 +40,6 @@ from nncf.openvino.statistics.statistics import OVBatchTensorStatistic from nncf.openvino.statistics.statistics import OVMeanTensorStatistic from nncf.openvino.tensor import OVNNCFTensor -from nncf.parameters import ModelType from nncf.quantization.advanced_parameters import StatisticsType @@ -243,7 +242,7 @@ def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace) -def get_mean_stat_collector(num_samples, channel_axis, window_size=None, inplace=True, model_type=None): +def get_mean_stat_collector(num_samples, channel_axis, window_size=None, inplace=True): # TODO(dlyakhov): use inplace OVBatchMeanReducer and OVMeanPerChanelReducer # after migration on openvino-dev=2023.0 inplace = False @@ -260,7 +259,7 @@ def get_mean_stat_collector(num_samples, channel_axis, window_size=None, inplace "window_size": window_size, } aggregate_mean = MeanAggregator(**kwargs) - aggregate_shape = ShapeAggregator(slice(1, None) if model_type == ModelType.SEQUENTIAL else None) + aggregate_shape = ShapeAggregator() collector = TensorCollector(OVMeanTensorStatistic) collector.register_statistic_branch(OVMeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean) diff --git a/nncf/parameters.py b/nncf/parameters.py index 4c1dc2b61b2..28ae264834e 100644 --- a/nncf/parameters.py +++ b/nncf/parameters.py @@ -40,7 +40,6 @@ class ModelType(Enum): """ TRANSFORMER = "transformer" - SEQUENTIAL = "sequential" @api(canonical_alias="nncf.DropType") diff --git a/nncf/quantization/algorithms/bias_correction/algorithm.py b/nncf/quantization/algorithms/bias_correction/algorithm.py index b042db80c77..d66175e3fce 100644 --- a/nncf/quantization/algorithms/bias_correction/algorithm.py +++ b/nncf/quantization/algorithms/bias_correction/algorithm.py @@ -29,7 +29,6 @@ from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import copy_model from nncf.common.utils.backend import get_backend -from nncf.parameters import ModelType from nncf.quantization.algorithms.algorithm import Algorithm from nncf.quantization.algorithms.bias_correction.backend import ALGO_BACKENDS @@ -69,7 +68,6 @@ def __init__( apply_for_all_nodes: bool = False, inplace_statistics: bool = True, backend_params: Optional[Dict[str, Any]] = None, - model_type: ModelType = None, ): """ :param subset_size: Size of a subset for the statistics collection, @@ -93,8 +91,6 @@ def __init__( self.apply_for_all_nodes = apply_for_all_nodes self.inplace_statistics = inplace_statistics self.backend_params = backend_params - self.model_type = model_type - self.nncf_graph = None self._backend_entity = None self._collected_stat_inputs = set() @@ -480,10 +476,7 @@ def get_statistic_points(self, model: TModel) -> StatisticPointsContainer: TargetType.POST_LAYER_OPERATION, node_name, output_port_id ) stat_collector = self._backend_entity.mean_statistic_collector( - reduction_shape=channel_axis, - num_samples=self.subset_size, - inplace=self.inplace_statistics, - model_type=self.model_type, + reduction_shape=channel_axis, num_samples=self.subset_size, inplace=self.inplace_statistics ) statistic_container.add_statistic_point( StatisticPoint(target_point=statistic_point, tensor_collector=stat_collector, algorithm=BiasCorrection) diff --git a/nncf/quantization/algorithms/bias_correction/backend.py b/nncf/quantization/algorithms/bias_correction/backend.py index 9c36e94675c..c189bd2bf20 100644 --- a/nncf/quantization/algorithms/bias_correction/backend.py +++ b/nncf/quantization/algorithms/bias_correction/backend.py @@ -24,7 +24,6 @@ from nncf.common.tensor_statistics.collectors import ReductionShape from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase from nncf.common.utils.registry import Registry -from nncf.parameters import ModelType TModel = TypeVar("TModel") OutputType = TypeVar("OutputType") @@ -109,7 +108,6 @@ def mean_statistic_collector( inplace: bool, num_samples: Optional[int] = None, window_size: Optional[int] = None, - model_type: Optional[ModelType] = None, ) -> TensorStatisticCollectorBase: """ Returns backend-specific mean statistic collector. diff --git a/nncf/quantization/algorithms/bias_correction/onnx_backend.py b/nncf/quantization/algorithms/bias_correction/onnx_backend.py index 1cdbf16ef24..e4a92fc43f4 100644 --- a/nncf/quantization/algorithms/bias_correction/onnx_backend.py +++ b/nncf/quantization/algorithms/bias_correction/onnx_backend.py @@ -35,7 +35,6 @@ from nncf.onnx.statistics.collectors import ONNXMeanStatisticCollector from nncf.onnx.statistics.collectors import ONNXNNCFCollectorTensorProcessor from nncf.onnx.tensor import ONNXNNCFTensor -from nncf.parameters import ModelType from nncf.quantization.algorithms.bias_correction.backend import ALGO_BACKENDS from nncf.quantization.algorithms.bias_correction.backend import BiasCorrectionAlgoBackend @@ -83,7 +82,6 @@ def mean_statistic_collector( inplace: bool, num_samples: Optional[int] = None, window_size: Optional[int] = None, - model_type: Optional[ModelType] = None, ) -> ONNXMeanStatisticCollector: return ONNXMeanStatisticCollector(reduction_shape, num_samples, window_size) diff --git a/nncf/quantization/algorithms/bias_correction/openvino_backend.py b/nncf/quantization/algorithms/bias_correction/openvino_backend.py index c88aac12835..1a7d10b887b 100644 --- a/nncf/quantization/algorithms/bias_correction/openvino_backend.py +++ b/nncf/quantization/algorithms/bias_correction/openvino_backend.py @@ -34,7 +34,6 @@ from nncf.openvino.statistics.collectors import get_mean_batch_stat_collector from nncf.openvino.statistics.collectors import get_mean_stat_collector from nncf.openvino.tensor import OVNNCFTensor -from nncf.parameters import ModelType from nncf.quantization.algorithms.bias_correction.backend import ALGO_BACKENDS from nncf.quantization.algorithms.bias_correction.backend import BiasCorrectionAlgoBackend @@ -78,9 +77,8 @@ def mean_statistic_collector( inplace: bool, num_samples: Optional[int] = None, window_size: Optional[int] = None, - model_type: Optional[ModelType] = None, ) -> TensorCollector: - return get_mean_stat_collector(num_samples, reduction_shape, window_size, inplace, model_type) + return get_mean_stat_collector(num_samples, reduction_shape, window_size, inplace) @staticmethod def batch_statistic_collector(inplace: bool, num_samples: int = None) -> TensorCollector: diff --git a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py index 0492df2dd98..3320c1bfe0a 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py +++ b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py @@ -27,7 +27,6 @@ from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend -from nncf.parameters import ModelType from nncf.quantization.algorithms.algorithm import Algorithm from nncf.quantization.algorithms.fast_bias_correction.backend import ALGO_BACKENDS @@ -60,7 +59,6 @@ def __init__( apply_for_all_nodes: bool = False, inplace_statistics: bool = True, backend_params: Optional[Dict[str, Any]] = None, - model_type: ModelType = None, ): """ :param subset_size: Size of a subset for the statistics collection, @@ -84,7 +82,6 @@ def __init__( self.apply_for_all_nodes = apply_for_all_nodes self.inplace_statistics = inplace_statistics self.backend_params = backend_params - self.model_type = model_type self.nncf_graph = None self._backend_entity = None @@ -255,10 +252,7 @@ def _add_statistic_point(self, container: StatisticPointsContainer, point: Targe :param axis: Channel axis for the statistics calculation. """ stat_collector = self._backend_entity.mean_statistic_collector( - reduction_shape=axis, - num_samples=self.subset_size, - inplace=self.inplace_statistics, - model_type=self.model_type, + reduction_shape=axis, num_samples=self.subset_size, inplace=self.inplace_statistics ) container.add_statistic_point( StatisticPoint(target_point=point, tensor_collector=stat_collector, algorithm=FastBiasCorrection) diff --git a/nncf/quantization/algorithms/fast_bias_correction/backend.py b/nncf/quantization/algorithms/fast_bias_correction/backend.py index 86748218b75..136c6f4f800 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/backend.py @@ -24,7 +24,6 @@ from nncf.common.tensor_statistics.collectors import ReductionShape from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase from nncf.common.utils.registry import Registry -from nncf.parameters import ModelType TModel = TypeVar("TModel") OutputType = TypeVar("OutputType") @@ -88,7 +87,6 @@ def mean_statistic_collector( inplace: bool, num_samples: Optional[int] = None, window_size: Optional[int] = None, - model_type: Optional[ModelType] = None, ) -> TensorStatisticCollectorBase: """ Returns backend-specific mean statistic collector. diff --git a/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py b/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py index 47d1daee8c4..71ac7468d4f 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py @@ -31,7 +31,6 @@ from nncf.onnx.statistics.collectors import ONNXMeanStatisticCollector from nncf.onnx.statistics.collectors import ONNXNNCFCollectorTensorProcessor from nncf.onnx.tensor import ONNXNNCFTensor -from nncf.parameters import ModelType from nncf.quantization.algorithms.fast_bias_correction.backend import ALGO_BACKENDS from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend @@ -66,7 +65,6 @@ def mean_statistic_collector( inplace: bool, num_samples: Optional[int] = None, window_size: Optional[int] = None, - model_type: Optional[ModelType] = None, ) -> ONNXMeanStatisticCollector: return ONNXMeanStatisticCollector(reduction_shape, num_samples, window_size) diff --git a/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py b/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py index c8be93ff4cc..1057b7cbc17 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py @@ -32,7 +32,6 @@ from nncf.openvino.statistics.collectors import OVNNCFCollectorTensorProcessor from nncf.openvino.statistics.collectors import get_mean_stat_collector from nncf.openvino.tensor import OVNNCFTensor -from nncf.parameters import ModelType from nncf.quantization.algorithms.fast_bias_correction.backend import ALGO_BACKENDS from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend @@ -67,9 +66,8 @@ def mean_statistic_collector( inplace: bool, num_samples: Optional[int] = None, window_size: Optional[int] = None, - model_type: Optional[ModelType] = None, ) -> TensorCollector: - return get_mean_stat_collector(num_samples, reduction_shape, window_size, inplace, model_type) + return get_mean_stat_collector(num_samples, reduction_shape, window_size, inplace) @staticmethod def get_sub_input_output_names(subgraph: ov.Model) -> Tuple[str, str]: diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index a3da13fe905..953fda6d579 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -275,7 +275,6 @@ def _get_stat_collector( target_point, quantizer_config, inplace=self._inplace_statistics, - model_type=self._model_type, num_samples=self._subset_size, ) diff --git a/nncf/quantization/algorithms/min_max/backend.py b/nncf/quantization/algorithms/min_max/backend.py index 2c063c644db..07cd9990b62 100644 --- a/nncf/quantization/algorithms/min_max/backend.py +++ b/nncf/quantization/algorithms/min_max/backend.py @@ -157,7 +157,6 @@ def get_statistic_collector( target_point: TargetPoint, quantizer_config: QuantizerConfig, inplace: bool, - model_type: ModelType, num_samples: int = None, ) -> TensorStatisticCollectorBase: """ diff --git a/nncf/quantization/algorithms/min_max/onnx_backend.py b/nncf/quantization/algorithms/min_max/onnx_backend.py index f908115d3fd..c04b8cd4038 100644 --- a/nncf/quantization/algorithms/min_max/onnx_backend.py +++ b/nncf/quantization/algorithms/min_max/onnx_backend.py @@ -178,7 +178,6 @@ def get_statistic_collector( target_point: ONNXTargetPoint, quantizer_config: QuantizerConfig, inplace: bool, - model_type: ModelType, num_samples: int = None, ) -> Union[ONNXMinMaxStatisticCollector, ONNXMeanMinMaxStatisticCollector]: reduction_shape, use_abs_max = ONNXMinMaxAlgoBackend._get_reduction_shape_and_use_abs_max( diff --git a/nncf/quantization/algorithms/min_max/openvino_backend.py b/nncf/quantization/algorithms/min_max/openvino_backend.py index 360b5f902c0..50973520947 100644 --- a/nncf/quantization/algorithms/min_max/openvino_backend.py +++ b/nncf/quantization/algorithms/min_max/openvino_backend.py @@ -24,8 +24,6 @@ from nncf.common.utils.backend import BackendType from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP from nncf.experimental.common.tensor_statistics.collectors import TensorCollector -from nncf.experimental.common.tensor_statistics.collectors import TensorReducersSequence -from nncf.openvino.engine import SEQUENTIAL_SAMPLE_STACK_AXIS from nncf.openvino.graph.metatypes.openvino_metatypes import GENERAL_WEIGHT_LAYER_METATYPES from nncf.openvino.graph.metatypes.openvino_metatypes import OVAddMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionBackpropDataMetatype @@ -181,7 +179,6 @@ def get_statistic_collector( target_point: OVTargetPoint, quantizer_config: QuantizerConfig, inplace: bool, - model_type: ModelType, num_samples: int = None, ) -> TensorCollector: reduction_shape, use_abs_max = OVMinMaxAlgoBackend._get_reduction_shape_and_use_abs_max( @@ -215,7 +212,6 @@ def get_statistic_collector( statistic_type = params.statistics_type if use_abs_max and statistic_type == StatisticsType.MAX: statistic_type = StatisticsType.ABS_MAX - reducer = OV_REDUCERS_MAP[statistic_type](**kwargs) kwargs = {"num_samples": _num_samples, "tensor_processor": OVNNCFCollectorTensorProcessor} diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py index 0875941e2ce..7315574fe96 100644 --- a/nncf/quantization/algorithms/min_max/torch_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_backend.py @@ -154,7 +154,6 @@ def get_statistic_collector( target_point: PTTargetPoint, quantizer_config: QuantizerConfig, inplace: bool, - model_type: ModelType, num_samples: int = None, ) -> Union[PTMinMaxStatisticCollector, PTMeanMinMaxStatisticCollector]: if ( diff --git a/nncf/quantization/algorithms/post_training/algorithm.py b/nncf/quantization/algorithms/post_training/algorithm.py index 0d1e9c625f0..60ab4b02edf 100644 --- a/nncf/quantization/algorithms/post_training/algorithm.py +++ b/nncf/quantization/algorithms/post_training/algorithm.py @@ -114,7 +114,6 @@ def __init__( apply_for_all_nodes=bias_correction_params.apply_for_all_nodes, inplace_statistics=advanced_parameters.inplace_statistics, backend_params=advanced_parameters.backend_params, - model_type=model_type, ) else: threshold = BIAS_CORRECTION_THRESHOLD @@ -127,7 +126,6 @@ def __init__( apply_for_all_nodes=bias_correction_params.apply_for_all_nodes, inplace_statistics=advanced_parameters.inplace_statistics, backend_params=advanced_parameters.backend_params, - model_type=model_type, ) self.algorithms.append(bias_correction)