From 0de05027d2ba4e4d01e79774642c3a6c396e8b6a Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Wed, 6 Sep 2023 18:40:34 +0200 Subject: [PATCH] WIP --- .../common/tensor_statistics/collectors.py | 17 +- nncf/openvino/statistics/collectors.py | 10 +- .../fast_bias_correction/torch_backend.py | 7 +- .../algorithms/min_max/torch_backend.py | 124 ++++---- nncf/torch/quantization/init_range.py | 3 - nncf/torch/statistics/aggregator.py | 10 +- nncf/torch/tensor_statistics/collectors.py | 288 +++++++++++------- tests/common/test_statistics_aggregator.py | 3 + .../test_templates/test_channel_alignment.py | 2 +- .../test_templates/test_quantizer_config.py | 4 +- .../test_templates/test_smooth_quant.py | 2 +- tests/torch/ptq/test_ptq_params.py | 22 +- tests/torch/ptq/test_quantizer_config.py | 20 +- tests/torch/quantization/test_range_init.py | 19 +- .../test_tensor_statistics.py | 119 ++------ tests/torch/test_statistics_aggregator.py | 2 +- 16 files changed, 341 insertions(+), 311 deletions(-) diff --git a/nncf/experimental/common/tensor_statistics/collectors.py b/nncf/experimental/common/tensor_statistics/collectors.py index cf76f6b532a..5a51691669d 100644 --- a/nncf/experimental/common/tensor_statistics/collectors.py +++ b/nncf/experimental/common/tensor_statistics/collectors.py @@ -32,7 +32,7 @@ class TensorReducerBase(ABC): the specified rule. Could handle tensors inplace or out of place. """ - def __init__(self, reduction_shape: Optional[ReductionShape] = None, inplace: bool = False, keepdims: bool = True): + def __init__(self, reduction_axes: Optional[ReductionShape] = None, inplace: bool = False, keepdims: bool = True): """ :param reduction_shape: Reduction shape for reduction calculation. Equal to list(range(len(input.shape))) if empty. @@ -40,7 +40,7 @@ def __init__(self, reduction_shape: Optional[ReductionShape] = None, inplace: bo :param keepdims: Should the axes which are reduced are left in the result as dimensions with size one or not. """ - self._reduction_shape = reduction_shape + self._reduction_shape = reduction_axes self._tensor_processor: NNCFCollectorTensorProcessor = self._get_processor() self._inplace = inplace self._keepdims = keepdims @@ -469,12 +469,12 @@ def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: class QuantileReducerBase(TensorReducerBase): def __init__( self, - reduction_shape: Optional[ReductionShape] = None, + reduction_axes: Optional[ReductionShape] = None, quantile: Optional[Union[float, Tuple[float]]] = None, inplace: bool = False, keepdims: bool = True, ): - super().__init__(reduction_shape, False, keepdims) + super().__init__(reduction_axes, False, keepdims) self._quantile = (0.01, 0.99) if quantile is None else quantile def __eq__(self, __o: object) -> bool: @@ -494,11 +494,11 @@ def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: class AbsQuantileReducer(QuantileReducerBase): def __init__( self, - reduction_shape: Optional[ReductionShape] = None, + reduction_axes: Optional[ReductionShape] = None, quantile: Union[float, List[float]] = 0.99, inplace: bool = False, ): - super().__init__(reduction_shape, quantile, False) + super().__init__(reduction_axes, quantile, False) def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: x = self._tensor_processor.abs(x[0]) @@ -624,8 +624,9 @@ def __init__( window_size=None, quantile: float = 0.01, ): + assert len(aggregation_axes) == 1 super().__init__( - tensor_processor, aggregation_axes=aggregation_axes, keepdims=keepdims, num_samples=num_samples + tensor_processor, aggregation_axes=aggregation_axes[0], keepdims=keepdims, num_samples=num_samples ) self._window_size = window_size self._container = deque(maxlen=window_size) @@ -707,7 +708,7 @@ def _aggregate_impl(self) -> Any: return retval -class PostAggregateAggregatorHook(TensorAggregatorBase, ABC): +class PostAggregateHook(TensorAggregatorBase, ABC): def __init__(self, aggregator: TensorAggregatorBase, post_aggregation_hook): super().__init__(None) self._aggregator = aggregator diff --git a/nncf/openvino/statistics/collectors.py b/nncf/openvino/statistics/collectors.py index 61ef776fab7..4704a12a0fb 100644 --- a/nncf/openvino/statistics/collectors.py +++ b/nncf/openvino/statistics/collectors.py @@ -157,7 +157,7 @@ def _get_processor(self): return OVNNCFCollectorTensorProcessor def get_inplace_fn(self): - return get_inplace_min_op(self.name, self._reduction_shape) + return get_inplace_min_op(self.name, self._reduction_axes) def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace) @@ -168,7 +168,7 @@ def _get_processor(self): return OVNNCFCollectorTensorProcessor def get_inplace_fn(self): - return get_inplace_max_op(self.name, self._reduction_shape, False) + return get_inplace_max_op(self.name, self._reduction_axes, False) def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace) @@ -179,7 +179,7 @@ def _get_processor(self): return OVNNCFCollectorTensorProcessor def get_inplace_fn(self): - return get_inplace_max_op(self.name, self._reduction_shape, True) + return get_inplace_max_op(self.name, self._reduction_axes, True) def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace) @@ -190,7 +190,7 @@ def _get_processor(self): return OVNNCFCollectorTensorProcessor def get_inplace_fn(self): - return get_inplace_mean_op(self.name, self._reduction_shape) + return get_inplace_mean_op(self.name, self._reduction_axes) def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace) @@ -212,7 +212,7 @@ def _get_processor(self): return OVNNCFCollectorTensorProcessor def get_inplace_fn(self): - return get_inplace_mean_per_ch(self.name, self._reduction_shape) + return get_inplace_mean_per_ch(self.name, self._reduction_axes) def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace) diff --git a/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py b/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py index b7316724db0..173acaf05e7 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py @@ -20,6 +20,7 @@ from nncf.common.graph.transformations.commands import TargetType from nncf.common.tensor_statistics.collectors import ReductionShape from nncf.common.utils.backend import BackendType +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.quantization.algorithms.fast_bias_correction.backend import ALGO_BACKENDS from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend from nncf.torch.graph.transformations.command_creation import create_bias_correction_command @@ -32,8 +33,8 @@ from nncf.torch.model_analyzer import is_quantized_weights from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.tensor import PTNNCFTensor -from nncf.torch.tensor_statistics.collectors import PTMeanStatisticCollector from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor +from nncf.torch.tensor_statistics.collectors import get_mean_stat_collector @ALGO_BACKENDS.register(BackendType.TORCH) @@ -71,8 +72,8 @@ def mean_statistic_collector( inplace: bool, num_samples: Optional[int] = None, window_size: Optional[int] = None, - ) -> PTMeanStatisticCollector: - return PTMeanStatisticCollector(reduction_shape, num_samples, window_size) + ) -> TensorCollector: + return get_mean_stat_collector(num_samples, reduction_shape, window_size) @staticmethod def get_sub_input_output_names(subgraph: NNCFNetwork) -> Tuple[str, str]: diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py index 82c858b2f28..c7d06ee0e00 100644 --- a/nncf/quantization/algorithms/min_max/torch_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_backend.py @@ -28,6 +28,9 @@ from nncf.common.quantization.structs import QuantizationMode from nncf.common.quantization.structs import QuantizerConfig from nncf.common.utils.backend import BackendType +from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP +from nncf.experimental.common.tensor_statistics.collectors import PostAggregateHook +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.parameters import ModelType from nncf.parameters import TargetDevice from nncf.quantization.advanced_parameters import AggregatorType @@ -49,8 +52,9 @@ from nncf.torch.quantization.layers import BaseQuantizer from nncf.torch.quantization.layers import PTQuantizerSpec from nncf.torch.quantization.layers import get_scale_shape -from nncf.torch.tensor_statistics.collectors import PTMeanMinMaxStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMinMaxStatisticCollector +from nncf.torch.tensor import PTNNCFTensor +from nncf.torch.tensor_statistics.collectors import PT_REDUCERS_MAP +from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor from nncf.torch.tensor_statistics.statistics import PTMinMaxTensorStatistic @@ -155,32 +159,61 @@ def get_statistic_collector( quantizer_config: QuantizerConfig, inplace: bool, num_samples: int = None, - ) -> Union[PTMinMaxStatisticCollector, PTMeanMinMaxStatisticCollector]: - if ( - range_estimator_params.min.statistics_type == StatisticsType.MIN - and range_estimator_params.min.aggregator_type == AggregatorType.MIN - and range_estimator_params.max.statistics_type == StatisticsType.MAX - and range_estimator_params.max.aggregator_type == AggregatorType.MAX + ) -> TensorCollector: + collector_params = PTMinMaxAlgoBackend._default_collector_params(nncf_graph, target_point, quantizer_config) + collector_kwargs = collector_params.convert_statistic_params(per_sample_stats=False) + + collector = TensorCollector(PTMinMaxTensorStatistic) + for params, container_key in zip( + [range_estimator_params.min, range_estimator_params.max], + [PTMinMaxTensorStatistic.MIN_STAT, PTMinMaxTensorStatistic.MAX_STAT], ): - collector_name = "min_max" - - elif ( - range_estimator_params.min.statistics_type == StatisticsType.MIN - and range_estimator_params.min.aggregator_type == AggregatorType.MEAN - and range_estimator_params.max.statistics_type == StatisticsType.MAX - and range_estimator_params.max.aggregator_type == AggregatorType.MEAN - ): - collector_name = "mean_min_max" - - else: - raise RuntimeError( - "The following range estimator parameters are not supported by PyTorch backend by now: " - f"{str(range_estimator_params)}" - ) - - return PTMinMaxAlgoBackend._statistic_collector_builder( - collector_name, nncf_graph, target_point, quantizer_config, num_samples - ) + if not params.statistics_type in PT_REDUCERS_MAP: + raise RuntimeError( + f"Statistic type: {params.statistics_type} is not supported for Torch PTQ backend yet." + ) + + if not params.aggregator_type in AGGREGATORS_MAP: + raise RuntimeError( + f"Aggregator type: {params.aggregator_type} is not supported for Torch PTQ backend yet." + ) + + kwargs = { + "reduction_axes": collector_kwargs["reducers_axes"], + "keepdims": collector_kwargs["reducers_keepdims"], + } + if params.statistics_type in [StatisticsType.QUANTILE, StatisticsType.ABS_QUANTILE]: + if container_key == PTMinMaxTensorStatistic.MIN_STAT: + quantile = params.quantile_outlier_prob + else: + quantile = 1 - params.quantile_outlier_prob + kwargs.update({"quantile": [quantile]}) + # TODO(dlyakhov): merge two quantile aggregators in one + + statistic_type = params.statistics_type + if collector_params.use_abs_max and statistic_type == StatisticsType.MAX: + statistic_type = StatisticsType.ABS_MAX + reducer = PT_REDUCERS_MAP[statistic_type](**kwargs) + + kwargs = { + "aggregation_axes": collector_kwargs["aggregators_axes"], + "keepdims": collector_kwargs["aggregators_keepdims"], + "num_samples": num_samples, + "tensor_processor": PTNNCFCollectorTensorProcessor, + } + aggregator = AGGREGATORS_MAP[params.aggregator_type](**kwargs) + + if collector_kwargs["squeeze_dims"] is not None: + + def post_aggregation_hook(aggregated_value): + return PTNNCFCollectorTensorProcessor.squeeze( + PTNNCFTensor(aggregated_value), dim=collector_kwargs["squeeze_dims"] + ).tensor + + aggregator = PostAggregateHook(aggregator=aggregator, post_aggregation_hook=post_aggregation_hook) + + collector.register_statistic_branch(container_key, reducer, aggregator) + return collector @staticmethod def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]: @@ -223,37 +256,18 @@ def _get_input_scale_shape( return input_shape, scale_shape, channel_idx @staticmethod - def _default_collector_params_and_scale_shape( + def _default_collector_params( nncf_graph: NNCFGraph, target_point: PTTargetPoint, quantizer_config: QuantizerConfig - ) -> Tuple[PTRangeInitCollectorParams, Tuple[int, ...]]: - input_shape, scale_shape, channel_idx = PTMinMaxAlgoBackend._get_input_scale_shape( - nncf_graph, target_point, quantizer_config - ) - return ( - PTRangeInitCollectorParams( - is_weights=target_point.is_weight_target_point(), - mode=quantizer_config.mode, - per_channel=quantizer_config.per_channel, - input_shape=input_shape, - channel_idx=channel_idx, - ), - scale_shape, - ) - - @staticmethod - def _statistic_collector_builder( - collector_name: str, - nncf_graph: NNCFGraph, - target_point: PTTargetPoint, - quantizer_config: QuantizerConfig, - num_samples: int = None, - ) -> PTMeanMinMaxStatisticCollector: - collector_params, scale_shape = PTMinMaxAlgoBackend._default_collector_params_and_scale_shape( + ) -> PTRangeInitCollectorParams: + input_shape, _, channel_idx = PTMinMaxAlgoBackend._get_input_scale_shape( nncf_graph, target_point, quantizer_config ) - init_config = RangeInitConfig(collector_name, num_samples) - return StatCollectorGenerator.generate_stat_collector_for_range_init_config( - init_config, scale_shape, collector_params, num_samples + return PTRangeInitCollectorParams( + is_weights=target_point.is_weight_target_point(), + mode=quantizer_config.mode, + per_channel=quantizer_config.per_channel, + input_shape=input_shape, + channel_idx=channel_idx, ) @staticmethod diff --git a/nncf/torch/quantization/init_range.py b/nncf/torch/quantization/init_range.py index 0d1a5de2ab4..36428502eaa 100644 --- a/nncf/torch/quantization/init_range.py +++ b/nncf/torch/quantization/init_range.py @@ -39,9 +39,6 @@ from nncf.torch.quantization.translator import PTTargetPointTranslator from nncf.torch.tensor import PTNNCFTensor from nncf.torch.tensor_statistics.algo import TensorStatisticObservationPoint -from nncf.torch.tensor_statistics.collectors import PTMeanPercentileStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMedianMADStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTPercentileStatisticCollector from nncf.torch.tensor_statistics.collectors import get_mean_percentile_statistic_collector from nncf.torch.tensor_statistics.collectors import get_median_mad_statistic_collector from nncf.torch.tensor_statistics.collectors import get_min_max_statistic_collector diff --git a/nncf/torch/statistics/aggregator.py b/nncf/torch/statistics/aggregator.py index 6c2c48256c6..84cb2b63a73 100644 --- a/nncf/torch/statistics/aggregator.py +++ b/nncf/torch/statistics/aggregator.py @@ -41,6 +41,14 @@ def _get_transformation_layout_extra_outputs( ) -> TransformationLayout: transformation_layout = TransformationLayout() transformation_commands = [] + + def register_inputs_fn(fn): + def register_inputs(input_: torch.Tensor): + fn(PTNNCFTensor(input_)) + return input_ + + return register_inputs + for _statistic_points in statistic_points.values(): for _statistic_point in _statistic_points: for collectors in _statistic_point.algorithm_to_tensor_collectors.values(): @@ -48,7 +56,7 @@ def _get_transformation_layout_extra_outputs( transformation_commands.append( PTInsertionCommand( _statistic_point.target_point, - collector.register_input, + register_inputs_fn(collector.register_unnamed_inputs), TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION, ) ) diff --git a/nncf/torch/tensor_statistics/collectors.py b/nncf/torch/tensor_statistics/collectors.py index 4baedc594a5..f82877da854 100644 --- a/nncf/torch/tensor_statistics/collectors.py +++ b/nncf/torch/tensor_statistics/collectors.py @@ -42,10 +42,12 @@ from nncf.experimental.common.tensor_statistics.collectors import MinAggregator from nncf.experimental.common.tensor_statistics.collectors import MinReducer from nncf.experimental.common.tensor_statistics.collectors import NoopReducer -from nncf.experimental.common.tensor_statistics.collectors import PostAggregateAggregatorHook +from nncf.experimental.common.tensor_statistics.collectors import PostAggregateHook from nncf.experimental.common.tensor_statistics.collectors import PrecentileAggregator from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer +from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.quantization.advanced_parameters import StatisticsType from nncf.torch.dynamic_graph.context import no_nncf_trace from nncf.torch.tensor import PTNNCFTensor from nncf.torch.tensor_statistics.reduction import expand_like @@ -93,9 +95,15 @@ def median(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCF return PTNNCFTensor(torch.tensor(np.median(x.tensor.detach().cpu().numpy(), axis=axis, keepdims=keepdims))) return PTNNCFTensor(x.tensor.median(dim=axis, keepdim=keepdims).values) - @staticmethod - def masked_mean(x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False) -> NNCFTensor: - raise NotImplementedError() + @classmethod + def masked_mean(cls, x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False) -> NNCFTensor: + if mask is None: + return cls.mean(x, axis=axis, keepdims=keepdims) + masked_x = np.ma.array(x.tensor.detach().cpu().numpy(), mask=mask.tensor) + result = np.ma.mean(masked_x, axis=axis, keepdims=False) + if len(result) == 1: + return PTNNCFTensor(torch.tensor(result)) + return PTNNCFTensor(torch.tensor(result.data)) @classmethod def masked_median( @@ -170,9 +178,23 @@ def precentile( @classmethod def no_outliers_map( - cls, x: NNCFTensor, fn: Callable[[NNCFTensor, Optional[int]], Any], axis: int = 0, alpha: float = 0.01 + cls, + x: NNCFTensor, + fn: Callable[[NNCFTensor, int, NNCFTensor], Any], + axis: int = 0, + alpha: float = 0.01, + keepdims: bool = False, ): - raise NotImplementedError() + if len(x.shape) == 1: + return fn(x, axis=None, mask=None, keepdims=keepdims) + + x = x.tensor + if axis: + x = torch.moveaxis(x, [axis] if isinstance(axis, int) else axis, 0) + + low_values, high_values = cls.quantile(x, [alpha, 1 - alpha], 0) + outliers_mask = np.logical_or(x < low_values, high_values < x) + return fn(x, axis=0, mask=PTNNCFTensor(outliers_mask), keepdims=keepdims) @classmethod def masked_map(cls, x: NNCFTensor, fn: MaskedReduceFN, filter_fn) -> NNCFTensor: @@ -243,7 +265,7 @@ def maybe_add_squeeze(aggregator, squeeze_dims): def post_aggregation_hook(aggregated_value): return PTNNCFCollectorTensorProcessor.squeeze(PTNNCFTensor(aggregated_value), dim=squeeze_dims).tensor - return PostAggregateAggregatorHook(aggregator=aggregator, post_aggregation_hook=post_aggregation_hook) + return PostAggregateHook(aggregator=aggregator, post_aggregation_hook=post_aggregation_hook) def get_min_max_statistic_collector( @@ -276,27 +298,6 @@ def get_min_max_statistic_collector( return tensor_collector -class PTMinMaxStatisticCollector(MinMaxStatisticCollector): - def __init__( - self, use_abs_max: bool, reduction_shape: ReductionShape, output_shape: ReductionShape, num_samples: int = None - ): - super().__init__(use_abs_max, reduction_shape, num_samples) - self._output_shape = output_shape - - @staticmethod - def _get_processor() -> NNCFCollectorTensorProcessor: - return PTNNCFCollectorTensorProcessor() - - def _register_input(self, x: torch.Tensor): - with no_nncf_trace(): - self._register_input_common(PTNNCFTensor(x)) - - def _get_statistics(self) -> PTMinMaxTensorStatistic: - min_values = self._min_values.tensor.view(self._output_shape) - max_values = self._max_values.tensor.view(self._output_shape) - return PTMinMaxTensorStatistic(min_values, max_values) - - def get_mixed_min_max_statistic_collector( reducers_axes, reducers_keepdims: bool, @@ -334,6 +335,136 @@ def get_mixed_min_max_statistic_collector( return tensor_collector +def get_median_mad_statistic_collector( + reducers_axes, + reducers_keepdims: bool, + aggregators_axes, + aggregators_keepdims, + num_samples: int, + squeeze_dims, + window_size: int = None, +): + return _get_collection_without_reduction( + MedianAbsoluteDeviationAggregator, + PTMedianMADTensorStatistic, + reducers_axes=reducers_axes, + reducers_keepdims=reducers_keepdims, + aggregators_axes=aggregators_axes, + aggregators_keepdims=aggregators_keepdims, + num_samples=num_samples, + squeeze_dims=squeeze_dims, + window_size=window_size, + ) + + +def get_precentile_tensor_collector( + percentiles_to_collect, + reducers_axes, + reducers_keepdims: bool, + aggregators_axes, + aggregators_keepdims, + num_samples: int, + squeeze_dims, + window_size: int = None, +): + return _get_collection_without_reduction( + partial(PrecentileAggregator, percentiles_to_collect=percentiles_to_collect), + PTPercentileTensorStatistic, + reducers_axes=reducers_axes, + reducers_keepdims=reducers_keepdims, + aggregators_axes=aggregators_axes, + aggregators_keepdims=aggregators_keepdims, + num_samples=num_samples, + squeeze_dims=squeeze_dims, + window_size=window_size, + ) + + +def _get_collection_without_reduction( + aggregator_cls, + statistic_class, + reducers_axes, + reducers_keepdims: bool, + aggregators_axes, + aggregators_keepdims, + num_samples: int, + squeeze_dims, + window_size: int = None, +): + tensor_collector = TensorCollector(statistic_class) + reducer = PTNoopReducer() + aggregation_axes = list(set(list(aggregators_axes) + [dim + 1 for dim in reducers_axes])) + aggregator = aggregator_cls( + PTNNCFCollectorTensorProcessor, + aggregation_axes=aggregation_axes, + window_size=window_size, + num_samples=num_samples, + keepdims=True, + ) + dims_to_squeeze = [0] if squeeze_dims else [] + dims_to_squeeze += [axis + 1 for axis in reducers_axes] if not reducers_keepdims else [] + dims_to_squeeze += aggregators_axes if not aggregators_keepdims else [] + if dims_to_squeeze: + + def post_aggregation_hook(aggregated_value): + retval = {} + for key, value in aggregated_value.items(): + retval[key] = PTNNCFCollectorTensorProcessor.squeeze(PTNNCFTensor(value), dim=dims_to_squeeze).tensor + return retval + + aggregator = PostAggregateHook(aggregator=aggregator, post_aggregation_hook=post_aggregation_hook) + + tensor_collector.register_statistic_branch( + PTMedianMADTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY, reducer, aggregator + ) + return tensor_collector + + +def get_mean_stat_collector(num_samples, channel_axis, window_size=None): + if channel_axis == 0: + reducer = PTBatchMeanReducer() + else: + reducer = PTMeanPerChReducer(channel_axis) + noop_reducer = PTNoopReducer() + + kwargs = { + "tensor_processor": PTNNCFCollectorTensorProcessor, + "num_samples": num_samples, + "window_size": window_size, + } + aggregate_mean = MeanAggregator(**kwargs) + aggregate_shape = ShapeAggregator() + + collector = TensorCollector(PTMeanTensorStatistic) + collector.register_statistic_branch(PTMeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean) + collector.register_statistic_branch(PTMeanTensorStatistic.SHAPE_STAT, noop_reducer, aggregate_shape) + return collector + + +#################################################################################################### + + +class PTMinMaxStatisticCollector(MinMaxStatisticCollector): + def __init__( + self, use_abs_max: bool, reduction_shape: ReductionShape, output_shape: ReductionShape, num_samples: int = None + ): + super().__init__(use_abs_max, reduction_shape, num_samples) + self._output_shape = output_shape + + @staticmethod + def _get_processor() -> NNCFCollectorTensorProcessor: + return PTNNCFCollectorTensorProcessor() + + def _register_input(self, x: torch.Tensor): + with no_nncf_trace(): + self._register_input_common(PTNNCFTensor(x)) + + def _get_statistics(self) -> PTMinMaxTensorStatistic: + min_values = self._min_values.tensor.view(self._output_shape) + max_values = self._max_values.tensor.view(self._output_shape) + return PTMinMaxTensorStatistic(min_values, max_values) + + class PTMixedMinMaxStatisticCollector(MixedMinMaxStatisticCollector): def __init__( self, @@ -398,68 +529,6 @@ def _get_statistics(self) -> PTMinMaxTensorStatistic: return PTMinMaxTensorStatistic(min_values, max_values) -def _get_collection_without_reduction( - aggregator_cls, - statistic_class, - reducers_axes, - reducers_keepdims: bool, - aggregators_axes, - aggregators_keepdims, - num_samples: int, - squeeze_dims, - window_size: int = None, -): - tensor_collector = TensorCollector(statistic_class) - reducer = PTNoopReducer() - aggregation_axes = list(set(list(aggregators_axes) + [dim + 1 for dim in reducers_axes])) - aggregator = aggregator_cls( - PTNNCFCollectorTensorProcessor, - aggregation_axes=aggregation_axes, - window_size=window_size, - num_samples=num_samples, - keepdims=True, - ) - dims_to_squeeze = [0] if squeeze_dims else [] - dims_to_squeeze += [axis + 1 for axis in reducers_axes] if not reducers_keepdims else [] - dims_to_squeeze += aggregators_axes if not aggregators_keepdims else [] - if dims_to_squeeze: - - def post_aggregation_hook(aggregated_value): - retval = {} - for key, value in aggregated_value.items(): - retval[key] = PTNNCFCollectorTensorProcessor.squeeze(PTNNCFTensor(value), dim=dims_to_squeeze).tensor - return retval - - aggregator = PostAggregateAggregatorHook(aggregator=aggregator, post_aggregation_hook=post_aggregation_hook) - - tensor_collector.register_statistic_branch( - PTMedianMADTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY, reducer, aggregator - ) - return tensor_collector - - -def get_median_mad_statistic_collector( - reducers_axes, - reducers_keepdims: bool, - aggregators_axes, - aggregators_keepdims, - num_samples: int, - squeeze_dims, - window_size: int = None, -): - return _get_collection_without_reduction( - MedianAbsoluteDeviationAggregator, - PTMedianMADTensorStatistic, - reducers_axes=reducers_axes, - reducers_keepdims=reducers_keepdims, - aggregators_axes=aggregators_axes, - aggregators_keepdims=aggregators_keepdims, - num_samples=num_samples, - squeeze_dims=squeeze_dims, - window_size=window_size, - ) - - class PTMedianMADStatisticCollector(MedianMADStatisticCollector): def _register_input(self, x: torch.Tensor): with no_nncf_trace(): @@ -476,29 +545,6 @@ def _get_statistics(self) -> PTMedianMADTensorStatistic: return PTMedianMADTensorStatistic(median_tensor, mad_tensor) -def get_precentile_tensor_collector( - percentiles_to_collect, - reducers_axes, - reducers_keepdims: bool, - aggregators_axes, - aggregators_keepdims, - num_samples: int, - squeeze_dims, - window_size: int = None, -): - return _get_collection_without_reduction( - partial(PrecentileAggregator, percentiles_to_collect=percentiles_to_collect), - PTPercentileTensorStatistic, - reducers_axes=reducers_axes, - reducers_keepdims=reducers_keepdims, - aggregators_axes=aggregators_axes, - aggregators_keepdims=aggregators_keepdims, - num_samples=num_samples, - squeeze_dims=squeeze_dims, - window_size=window_size, - ) - - class PTPercentileStatisticCollector(PercentileStatisticCollector): def _register_input(self, x: torch.Tensor): with no_nncf_trace(): @@ -524,9 +570,7 @@ def get_mean_percentile_statistic_collector( ): tensor_collector = TensorCollector(PTPercentileTensorStatistic) quantiles_to_collect = np.true_divide(percentiles_to_collect, 100) - reducer = PTQuantileReducer( - reduction_shape=reducers_axes, quantile=quantiles_to_collect, keepdims=reducers_keepdims - ) + reducer = PTQuantileReducer(reduction_axes=reducers_axes, quantile=quantiles_to_collect, keepdims=reducers_keepdims) for output_port_id, p in enumerate(percentiles_to_collect): aggregator = MeanAggregator( PTNNCFCollectorTensorProcessor, @@ -569,3 +613,15 @@ def _register_input(self, x: torch.Tensor): def _get_statistics(self) -> PTMeanTensorStatistic: return PTMeanTensorStatistic(self._mean_aggregate().tensor, self._shape()) + + +#################################################################################################### + +PT_REDUCERS_MAP = { + StatisticsType.MIN: PTMinReducer, + StatisticsType.MAX: PTMaxReducer, + StatisticsType.ABS_MAX: PTAbsMaxReducer, + StatisticsType.MEAN: PTMeanReducer, + StatisticsType.QUANTILE: PTQuantileReducer, + StatisticsType.ABS_QUANTILE: PTAbsQuantileReducer, +} diff --git a/tests/common/test_statistics_aggregator.py b/tests/common/test_statistics_aggregator.py index cd0545a4580..5988ff9e114 100644 --- a/tests/common/test_statistics_aggregator.py +++ b/tests/common/test_statistics_aggregator.py @@ -428,6 +428,9 @@ def filter_func(point): shape = (3, 1, 1, 1) ref_min_val, ref_max_val = map(lambda x: np.reshape(x, shape), (ref_min_val, ref_max_val)) + if not np.allclose(stat.min_values, ref_min_val): + breakpoint() + stat = tensor_collector.get_statistics() assert np.allclose(stat.min_values, ref_min_val) assert np.allclose(stat.max_values, ref_max_val) if isinstance(ref_min_val, np.ndarray): diff --git a/tests/post_training/test_templates/test_channel_alignment.py b/tests/post_training/test_templates/test_channel_alignment.py index 2b063e7d90f..41b792ca1e8 100644 --- a/tests/post_training/test_templates/test_channel_alignment.py +++ b/tests/post_training/test_templates/test_channel_alignment.py @@ -492,7 +492,7 @@ def test_statistic_collectors(self, inplace_ref, q_ref): assert len(statistic_collector.reducers) == 1 reducer = statistic_collector.reducers.pop() assert isinstance(reducer, QuantileReducer) - assert reducer._reduction_shape == reduction_shape_ref + assert reducer._reduction_axes == reduction_shape_ref assert np.allclose(reducer._quantile, (q_ref, 1 - q_ref)) assert len(statistic_collector.aggregators) == 2 diff --git a/tests/post_training/test_templates/test_quantizer_config.py b/tests/post_training/test_templates/test_quantizer_config.py index e614138d0a9..afd2accc285 100644 --- a/tests/post_training/test_templates/test_quantizer_config.py +++ b/tests/post_training/test_templates/test_quantizer_config.py @@ -278,8 +278,8 @@ def test_get_stat_collector( for reducer in reducers: if q_config_per_channel: - assert reducer._reduction_shape == params.ref_per_ch_reduction_shape + assert reducer._reduction_axes == params.ref_per_ch_reduction_shape else: - assert reducer._reduction_shape == params.ref_per_tensor_reduction_shape + assert reducer._reduction_axes == params.ref_per_tensor_reduction_shape assert tensor_collector.num_samples == num_samples diff --git a/tests/post_training/test_templates/test_smooth_quant.py b/tests/post_training/test_templates/test_smooth_quant.py index 42fe17e01b0..4292e228657 100644 --- a/tests/post_training/test_templates/test_smooth_quant.py +++ b/tests/post_training/test_templates/test_smooth_quant.py @@ -145,7 +145,7 @@ def test_get_abs_max_channel_collector(self): for reducer in backend_tensor_collector.reducers: assert isinstance(reducer, AbsMaxReducer) assert reducer.inplace == inplace_type - assert reducer._reduction_shape == reduction_shape + assert reducer._reduction_axes == reduction_shape @pytest.mark.parametrize( "model_cls, references", diff --git a/tests/torch/ptq/test_ptq_params.py b/tests/torch/ptq/test_ptq_params.py index c174ec8b322..35ddfe3128e 100644 --- a/tests/torch/ptq/test_ptq_params.py +++ b/tests/torch/ptq/test_ptq_params.py @@ -18,6 +18,10 @@ from nncf.common.graph.transformations.commands import TargetType from nncf.common.quantization.structs import QuantizationPreset from nncf.common.utils.backend import BackendType +from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator +from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator +from nncf.experimental.common.tensor_statistics.collectors import MinAggregator +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.parameters import ModelType from nncf.parameters import TargetDevice from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters @@ -33,8 +37,6 @@ from nncf.torch.graph.operator_metatypes import PTModuleLinearMetatype from nncf.torch.graph.operator_metatypes import PTSoftmaxMetatype from nncf.torch.quantization.quantize_model import _create_nncf_config -from nncf.torch.tensor_statistics.collectors import PTMeanMinMaxStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMinMaxStatisticCollector from tests.common.quantization.metatypes import Conv2dTestMetatype from tests.common.quantization.metatypes import LinearTestMetatype from tests.common.quantization.metatypes import SoftmaxTestMetatype @@ -104,11 +106,17 @@ class TestPTQParams(TemplateTestPTQParams): def get_algo_backend(self): return PTMinMaxAlgoBackend() - def check_is_min_max_statistic_collector(self, tensor_collector): - assert isinstance(tensor_collector, PTMinMaxStatisticCollector) - - def check_is_mean_min_max_statistic_collector(self, tensor_collector): - assert isinstance(tensor_collector, PTMeanMinMaxStatisticCollector) + def check_is_min_max_statistic_collector(self, tensor_collector: TensorCollector): + aggrs = [aggr.__class__ for aggr in tensor_collector.aggregators.values()] + assert len(aggrs) == 2 + assert MinAggregator in aggrs + assert MaxAggregator in aggrs + + def check_is_mean_min_max_statistic_collector(self, tensor_collector: TensorCollector): + aggrs = [aggr.__class__ for aggr in tensor_collector.aggregators.values()] + assert len(aggrs) == 2 + assert MeanAggregator in aggrs + assert aggrs[0].__class__ == aggrs[1].__class__ def check_quantize_outputs_fq_num(self, quantize_outputs, act_num_q, weight_num_q): if quantize_outputs: diff --git a/tests/torch/ptq/test_quantizer_config.py b/tests/torch/ptq/test_quantizer_config.py index 41cab6438b5..152503c802b 100644 --- a/tests/torch/ptq/test_quantizer_config.py +++ b/tests/torch/ptq/test_quantizer_config.py @@ -12,9 +12,11 @@ import pytest from nncf.common.graph.transformations.commands import TargetType +from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator +from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator +from nncf.experimental.common.tensor_statistics.collectors import MinAggregator +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend -from nncf.torch.tensor_statistics.collectors import PTMeanMinMaxStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMinMaxStatisticCollector from tests.post_training.test_templates.models import NNCFGraphToTest from tests.post_training.test_templates.models import NNCFGraphToTestDepthwiseConv from tests.post_training.test_templates.models import NNCFGraphToTestSumAggregation @@ -30,11 +32,17 @@ class TestQuantizerConfig(TemplateTestQuantizerConfig): def get_algo_backend(self): return PTMinMaxAlgoBackend() - def check_is_min_max_statistic_collector(self, tensor_collector): - assert isinstance(tensor_collector, PTMinMaxStatisticCollector) + def check_is_min_max_statistic_collector(self, tensor_collector: TensorCollector): + aggrs = [aggr.__class__ for aggr in tensor_collector.aggregators.values()] + assert len(aggrs) == 2 + assert MinAggregator in aggrs + assert MaxAggregator in aggrs - def check_is_mean_min_max_statistic_collector(self, tensor_collector): - assert isinstance(tensor_collector, PTMeanMinMaxStatisticCollector) + def check_is_mean_min_max_statistic_collector(self, tensor_collector: TensorCollector): + aggrs = [aggr.__class__ for aggr in tensor_collector.aggregators.values()] + assert len(aggrs) == 2 + assert MeanAggregator in aggrs + assert aggrs[0].__class__ == aggrs[1].__class__ @pytest.fixture( params=[ diff --git a/tests/torch/quantization/test_range_init.py b/tests/torch/quantization/test_range_init.py index 000c78fb3eb..f3be4ef2c0d 100644 --- a/tests/torch/quantization/test_range_init.py +++ b/tests/torch/quantization/test_range_init.py @@ -34,6 +34,7 @@ from nncf.common.quantization.structs import QuantizerGroup from nncf.config import NNCFConfig from nncf.config.structures import QuantizationRangeInitArgs +from nncf.experimental.common.tensor_statistics.collectors import PostAggregateHook from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.torch import utils from nncf.torch.checkpoint_loading import load_state @@ -49,9 +50,6 @@ from nncf.torch.quantization.layers import PTQuantizerSpec from nncf.torch.quantization.layers import SymmetricQuantizer from nncf.torch.tensor import PTNNCFTensor -from nncf.torch.tensor_statistics.collectors import PTMeanMinMaxStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMedianMADStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMinMaxStatisticCollector from nncf.torch.tensor_statistics.statistics import pt_convert_stat_to_min_max_tensor_stat from nncf.torch.utils import get_all_modules_by_type from nncf.torch.utils import safe_thread_call @@ -716,11 +714,18 @@ def test_per_layer_range_init_collectors_are_called_the_required_number_of_times ("three_sigma", range_threesigma_init_create_spy), ]: assert spy.call_count == range_init_call_count_test_struct.expected_call_count_initializer_create[stat_type] + collected_samples = 0 for tensor_collector in spy.return_values_list: - assert ( - tensor_collector.num_samples - == range_init_call_count_test_struct.expected_call_count_register_input[stat_type] - ) + cur_values = set() + for aggr in tensor_collector.aggregators.values(): + if isinstance(aggr, PostAggregateHook): + cur_values.add(aggr._aggregator._collected_samples) + else: + cur_values.add(aggr._collected_samples) + assert len(cur_values) == 1 + collected_samples += cur_values.pop() + + assert collected_samples == range_init_call_count_test_struct.expected_call_count_register_input[stat_type] QUANTIZER_RANGE_INITIALIZERS = [ diff --git a/tests/torch/tensor_statistics/test_tensor_statistics.py b/tests/torch/tensor_statistics/test_tensor_statistics.py index 8655f08a925..7981bd96461 100644 --- a/tests/torch/tensor_statistics/test_tensor_statistics.py +++ b/tests/torch/tensor_statistics/test_tensor_statistics.py @@ -15,19 +15,11 @@ import pytest import torch -from nncf.common.tensor_statistics.collectors import OfflineTensorStatisticCollector from nncf.common.tensor_statistics.collectors import ReductionShape -from nncf.common.tensor_statistics.collectors import StatisticsNotCollectedError from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase from nncf.common.tensor_statistics.statistics import TensorStatistic from nncf.torch.tensor import PTNNCFTensor -from nncf.torch.tensor_statistics.collectors import PTMeanMinMaxStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMeanPercentileStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMedianMADStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMinMaxStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMixedMinMaxStatisticCollector from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor -from nncf.torch.tensor_statistics.collectors import PTPercentileStatisticCollector from nncf.torch.tensor_statistics.collectors import get_mean_percentile_statistic_collector from nncf.torch.tensor_statistics.collectors import get_median_mad_statistic_collector from nncf.torch.tensor_statistics.collectors import get_min_max_statistic_collector @@ -83,7 +75,6 @@ class TestCollectedStatistics: ( partial( get_mixed_min_max_statistic_collector, - use_per_sample_stats=False, use_means_of_mins=True, use_means_of_maxs=True, ), @@ -108,7 +99,6 @@ class TestCollectedStatistics: ( partial( get_mixed_min_max_statistic_collector, - use_per_sample_stats=False, use_means_of_mins=False, use_means_of_maxs=True, ), @@ -138,9 +128,15 @@ def test_collected_statistics_with_shape_convert( reduction_shapes_vs_ref_statistic: Dict[Tuple[ReductionShape, ReductionShape], TensorStatistic], ): for shapes in reduction_shapes_vs_ref_statistic.keys(): - output_shape, reduction_shape = shapes + output_shape, reducer_axes = shapes collector_obj = collector( - use_abs_max=True, reduction_shape=reduction_shape, keepdims=len(output_shape) > 1, num_samples=None + use_abs_max=True, + reducers_axes=reducer_axes, + reducers_keepdims=len(output_shape) > 1, + aggregators_axes=(0,), + aggregators_keepdims=False, + squeeze_dims=None, + num_samples=None, ) for input_ in TestCollectedStatistics.REF_INPUTS: collector_obj.register_unnamed_inputs(PTNNCFTensor(input_)) @@ -239,103 +235,36 @@ def test_collected_statistics_with_shape_convert( # ), }, ), - ][:], + ], ) def test_collected_statistics( self, collector: Type[TensorStatisticCollectorBase], reduction_shapes_vs_ref_statistic: Dict[ReductionShape, TensorStatistic], ): - for shapes in reduction_shapes_vs_ref_statistic.keys(): - reduction_shape = shapes - - channel_dim = None + for reduction_shape in reduction_shapes_vs_ref_statistic.keys(): if len(reduction_shape) > 1: - channel_dim = [dim for dim, val in enumerate(reduction_shape) if val == 1][0] + reducer_axes = ([dim for dim, val in enumerate(reduction_shape) if val == 1][0],) + aggregator_keep_dims = False + else: + reducer_axes = (0, 1) + aggregator_keep_dims = True - collector_obj = collector() + collector_obj = collector( + reducers_axes=reducer_axes, + reducers_keepdims=len(reduction_shape) > 1, + aggregators_axes=(0,), + aggregators_keepdims=aggregator_keep_dims, + num_samples=None, + squeeze_dims=None, + ) for input_ in TestCollectedStatistics.REF_INPUTS: if hasattr(collector_obj, "register_unnamed_inputs"): collector_obj.register_unnamed_inputs(PTNNCFTensor(input_)) else: collector_obj.register_inputs(input_) test_stats = collector_obj.get_statistics() - assert reduction_shapes_vs_ref_statistic[shapes] == test_stats - - COLLECTORS = [ - partial(PTMinMaxStatisticCollector, use_abs_max=False, output_shape=(1,)), - partial( - PTMixedMinMaxStatisticCollector, - use_per_sample_stats=False, - use_abs_max=False, - use_means_of_mins=False, - use_means_of_maxs=False, - output_shape=(1,), - ), - partial(PTMeanMinMaxStatisticCollector, use_per_sample_stats=False, use_abs_max=False, output_shape=(1,)), - PTMedianMADStatisticCollector, - partial(PTPercentileStatisticCollector, percentiles_to_collect=[10.0]), - partial(PTMeanPercentileStatisticCollector, percentiles_to_collect=[10.0]), - ] - - @pytest.fixture(params=COLLECTORS) - def collector_for_interface_test(self, request): - collector_type = request.param - return collector_type(reduction_shape=(1,)) - - def test_collected_samples(self, collector_for_interface_test: TensorStatisticCollectorBase): - for input_ in TestCollectedStatistics.REF_INPUTS: - collector_for_interface_test.register_inputs(input_) - assert collector_for_interface_test.collected_samples() == len(TestCollectedStatistics.REF_INPUTS) - - def test_reset(self, collector_for_interface_test: TensorStatisticCollectorBase): - for input_ in TestCollectedStatistics.REF_INPUTS: - collector_for_interface_test.register_inputs(input_) - collector_for_interface_test.reset() - assert collector_for_interface_test.collected_samples() == 0 - with pytest.raises(StatisticsNotCollectedError): - collector_for_interface_test.get_statistics() - - def test_enable_disable(self, collector_for_interface_test: TensorStatisticCollectorBase): - for input_ in TestCollectedStatistics.REF_INPUTS: - collector_for_interface_test.register_inputs(input_) - - collector_for_interface_test.disable() - for input_ in TestCollectedStatistics.REF_INPUTS: - collector_for_interface_test.register_inputs(input_) - assert collector_for_interface_test.collected_samples() == len(TestCollectedStatistics.REF_INPUTS) - - collector_for_interface_test.enable() - for input_ in TestCollectedStatistics.REF_INPUTS: - collector_for_interface_test.register_inputs(input_) - assert collector_for_interface_test.collected_samples() == 2 * len(TestCollectedStatistics.REF_INPUTS) - - OFFLINE_COLLECTORS = [ - partial( - PTMixedMinMaxStatisticCollector, - use_per_sample_stats=False, - use_abs_max=False, - use_means_of_mins=False, - use_means_of_maxs=False, - output_shape=(1,), - ), - partial(PTMeanMinMaxStatisticCollector, use_per_sample_stats=False, use_abs_max=False, output_shape=(1,)), - PTMedianMADStatisticCollector, - partial(PTPercentileStatisticCollector, percentiles_to_collect=[10.0]), - partial(PTMeanPercentileStatisticCollector, percentiles_to_collect=[10.0]), - ] - - REF_NUM_SAMPLES = 3 - - @pytest.fixture(params=OFFLINE_COLLECTORS) - def collector_for_num_samples_test(self, request): - collector_type = request.param - return collector_type(reduction_shape=(1,), num_samples=TestCollectedStatistics.REF_NUM_SAMPLES) - - def test_num_samples(self, collector_for_num_samples_test: OfflineTensorStatisticCollector): - for input_ in TestCollectedStatistics.REF_INPUTS * 10: - collector_for_num_samples_test.register_inputs(input_) - assert collector_for_num_samples_test.collected_samples() == TestCollectedStatistics.REF_NUM_SAMPLES + assert reduction_shapes_vs_ref_statistic[reduction_shape] == test_stats class TestCollectorTensorProcessor: diff --git a/tests/torch/test_statistics_aggregator.py b/tests/torch/test_statistics_aggregator.py index 4a4b8f48914..60bfb99015a 100644 --- a/tests/torch/test_statistics_aggregator.py +++ b/tests/torch/test_statistics_aggregator.py @@ -62,7 +62,7 @@ def get_backend_model(self, dataset_samples): @pytest.fixture def is_backend_support_custom_estimators(self) -> bool: - return False + return True @pytest.fixture(scope="session") def test_params(self):