diff --git a/nncf/experimental/common/tensor_statistics/collectors.py b/nncf/experimental/common/tensor_statistics/collectors.py
index 82c5385e174..cf76f6b532a 100644
--- a/nncf/experimental/common/tensor_statistics/collectors.py
+++ b/nncf/experimental/common/tensor_statistics/collectors.py
@@ -565,30 +565,34 @@ def __init__(

 class OnlineAggregatorBase(OnlineOfflineAggregatorBase, ABC):
     def _online_register_reduced_input_impl(self, x: TensorType, fn) -> None:
-        if len(self._aggregation_axes) > 1:
-            raise NotImplemented()
-        unstacked_tensors = self._tensor_processor.unstack(x)
+        online_aggregation_axes = [dim - 1 for dim in self._aggregation_axes if dim != 0]
+        if online_aggregation_axes:
+            reduced = fn(x, axis=online_aggregation_axes, keepdims=self._keepdims)
+        else:
+            reduced = x
+        if 0 in self._aggregation_axes:
             if self._container:
-                unstacked_tensors.append(self._container)
-            self._container = fn(*unstacked_tensors)
+                reduced = fn(self._tensor_processor.stack([reduced, self._container]), axis=0, keepdims=False)
+            self._container = reduced
         else:
-            if not self._container:
-                self._container = x
-            else:
-                self._container = fn(x, self._container)
+            self._container.append(reduced)

     def _aggregate_impl(self):
-        return self._container.tensor
+        if 0 in self._aggregation_axes:
+            if self._keepdims:
+                return self._tensor_processor.stack([self._container]).tensor
+            return self._container.tensor
+        return self._tensor_processor.stack(self._container).tensor


 class MinAggregator(OnlineAggregatorBase):
     def _register_reduced_input_impl(self, x: TensorType) -> None:
-        return self._online_register_reduced_input_impl(x, self._tensor_processor.min)
+        return self._online_register_reduced_input_impl(x, self._tensor_processor.reduce_min)


 class MaxAggregator(OnlineAggregatorBase):
     def _register_reduced_input_impl(self, x: TensorType) -> None:
-        return self._online_register_reduced_input_impl(x, self._tensor_processor.max)
+        return self._online_register_reduced_input_impl(x, self._tensor_processor.reduce_max)


 class OfflineAggregatorBase(OnlineOfflineAggregatorBase, ABC):
@@ -629,7 +633,9 @@ def __init__(

     def _offline_aggregation_impl(self, fn) -> List[NNCFTensor]:
         stacked_val = self._tensor_processor.stack(self._container)
-        result = self._tensor_processor.no_outliers_map(stacked_val, fn, axis=self._axis, alpha=self._quantile)
+        result = self._tensor_processor.no_outliers_map(
+            stacked_val, fn, axis=self._aggregation_axes, alpha=self._quantile
+        )
         return result.tensor

     def __eq__(self, __o: object) -> bool:
@@ -701,6 +707,20 @@ def _aggregate_impl(self) -> Any:
         return retval


+class PostAggregateAggregatorHook(TensorAggregatorBase, ABC):
+    def __init__(self, aggregator: TensorAggregatorBase, post_aggregation_hook):
+        super().__init__(None)
+        self._aggregator = aggregator
+        self._post_aggregation_hook = post_aggregation_hook
+
+    def _register_reduced_input_impl(self, x: TensorType) -> None:
+        return self._aggregator.register_reduced_input(x)
+
+    def _aggregate_impl(self) -> Any:
+        retval = self._aggregator.aggregate()
+        return self._post_aggregation_hook(retval)
+
+
 AGGREGATORS_MAP = {
     AggregatorType.MIN: MinAggregator,
     AggregatorType.MAX: MaxAggregator,
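For orientation, here is a minimal NumPy sketch (not NNCF code; `fn`, the shapes, and the list-based container are illustrative assumptions) of the scheme the reworked `OnlineAggregatorBase` implements above: axis 0 of `aggregation_axes` means "fold each new value into a running result online", while the remaining axes are reduced inside every registered tensor, shifted down by one because axis 0 indexes the virtual stack of collected values.

```python
import numpy as np

def online_register(container, x, fn, aggregation_axes, keepdims=False):
    """Sketch of _online_register_reduced_input_impl with a plain list as container."""
    # Reduce inside the registered tensor; axis 0 refers to the stacking axis,
    # so the per-tensor axes are shifted down by one.
    online_axes = tuple(dim - 1 for dim in aggregation_axes if dim != 0)
    reduced = fn(x, axis=online_axes, keepdims=keepdims) if online_axes else x
    if 0 in aggregation_axes:
        # Online mode: keep a single running value, folding each new input in.
        if container:
            reduced = fn(np.stack([reduced, container[0]]), axis=0)
        container[:] = [reduced]
    else:
        # No aggregation across steps: remember every reduced value.
        container.append(reduced)

container = []
for _ in range(3):
    online_register(container, np.random.rand(4, 8), np.min, aggregation_axes=(0, 2))
running_min = container[0]  # shape (4,): min over steps (axis 0) and the last data axis
```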
RuntimeError("Unknown range init type: {}".format(init_config.init_type)) - keepdims = collector_params._per_channel use_per_sample_stats = collector_params.use_per_sample_stats(init_config.init_type == "mixed_min_max") - reduction_axes = collector_params.convert_reduction_axes(use_per_sample_stats) collector_kwargs = collector_params.convert_statistic_params(use_per_sample_stats) if init_config.init_type == "min_max": @@ -238,15 +236,16 @@ def generate_stat_collector_for_range_init_config( num_samples=num_samples, **collector_kwargs, ) - return PTPercentileStatisticCollector([min_percentile, max_percentile], scale_shape, num_samples) + # return PTPercentileStatisticCollector([min_percentile, max_percentile], scale_shape, num_samples) if init_config.init_type == "mean_percentile": min_percentile = init_config.init_type_specific_params.get("min_percentile", 0.1) max_percentile = init_config.init_type_specific_params.get("max_percentile", 99.9) return get_mean_percentile_statistic_collector( percentiles_to_collect=[min_percentile, max_percentile], + num_samples=num_samples, **collector_kwargs, ) - return PTMeanPercentileStatisticCollector([min_percentile, max_percentile], scale_shape, num_samples) + # return PTMeanPercentileStatisticCollector([min_percentile, max_percentile], scale_shape, num_samples) raise ValueError("Range init type not handled!") @classmethod diff --git a/nncf/torch/tensor_statistics/collectors.py b/nncf/torch/tensor_statistics/collectors.py index 29474764db5..4baedc594a5 100644 --- a/nncf/torch/tensor_statistics/collectors.py +++ b/nncf/torch/tensor_statistics/collectors.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Any, Callable, Deque, List, Optional, Tuple, Union import numpy as np @@ -41,6 +42,7 @@ from nncf.experimental.common.tensor_statistics.collectors import MinAggregator from nncf.experimental.common.tensor_statistics.collectors import MinReducer from nncf.experimental.common.tensor_statistics.collectors import NoopReducer +from nncf.experimental.common.tensor_statistics.collectors import PostAggregateAggregatorHook from nncf.experimental.common.tensor_statistics.collectors import PrecentileAggregator from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer from nncf.experimental.common.tensor_statistics.collectors import TensorCollector @@ -234,6 +236,16 @@ class PTMeanPerChReducer(PTReducerMixIn, MeanPerChReducer): pass +def maybe_add_squeeze(aggregator, squeeze_dims): + if not squeeze_dims: + return aggregator + + def post_aggregation_hook(aggregated_value): + return PTNNCFCollectorTensorProcessor.squeeze(PTNNCFTensor(aggregated_value), dim=squeeze_dims).tensor + + return PostAggregateAggregatorHook(aggregator=aggregator, post_aggregation_hook=post_aggregation_hook) + + def get_min_max_statistic_collector( use_abs_max, reducers_axes, @@ -253,11 +265,13 @@ def get_min_max_statistic_collector( } min_reducer = PTMinReducer(reducers_axes, keepdims=reducers_keepdims) min_aggregator = MinAggregator(**aggregator_kwargs) + min_aggregator = maybe_add_squeeze(min_aggregator, squeeze_dims) tensor_collector.register_statistic_branch(PTMinMaxTensorStatistic.MIN_STAT, min_reducer, min_aggregator) max_reducer_cls = PTAbsMaxReducer if use_abs_max else PTMaxReducer max_reducer = max_reducer_cls(reducers_axes, keepdims=reducers_keepdims) max_aggregator = MaxAggregator(**aggregator_kwargs) + max_aggregator = 
diff --git a/nncf/torch/tensor_statistics/collectors.py b/nncf/torch/tensor_statistics/collectors.py
index 29474764db5..4baedc594a5 100644
--- a/nncf/torch/tensor_statistics/collectors.py
+++ b/nncf/torch/tensor_statistics/collectors.py
@@ -9,6 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from functools import partial
 from typing import Any, Callable, Deque, List, Optional, Tuple, Union

 import numpy as np
@@ -41,6 +42,7 @@
 from nncf.experimental.common.tensor_statistics.collectors import MinAggregator
 from nncf.experimental.common.tensor_statistics.collectors import MinReducer
 from nncf.experimental.common.tensor_statistics.collectors import NoopReducer
+from nncf.experimental.common.tensor_statistics.collectors import PostAggregateAggregatorHook
 from nncf.experimental.common.tensor_statistics.collectors import PrecentileAggregator
 from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
@@ -234,6 +236,16 @@ class PTMeanPerChReducer(PTReducerMixIn, MeanPerChReducer):
     pass


+def maybe_add_squeeze(aggregator, squeeze_dims):
+    if not squeeze_dims:
+        return aggregator
+
+    def post_aggregation_hook(aggregated_value):
+        return PTNNCFCollectorTensorProcessor.squeeze(PTNNCFTensor(aggregated_value), dim=squeeze_dims).tensor
+
+    return PostAggregateAggregatorHook(aggregator=aggregator, post_aggregation_hook=post_aggregation_hook)
+
+
 def get_min_max_statistic_collector(
     use_abs_max,
     reducers_axes,
@@ -253,11 +265,13 @@
     }
     min_reducer = PTMinReducer(reducers_axes, keepdims=reducers_keepdims)
     min_aggregator = MinAggregator(**aggregator_kwargs)
+    min_aggregator = maybe_add_squeeze(min_aggregator, squeeze_dims)
     tensor_collector.register_statistic_branch(PTMinMaxTensorStatistic.MIN_STAT, min_reducer, min_aggregator)

     max_reducer_cls = PTAbsMaxReducer if use_abs_max else PTMaxReducer
     max_reducer = max_reducer_cls(reducers_axes, keepdims=reducers_keepdims)
     max_aggregator = MaxAggregator(**aggregator_kwargs)
+    max_aggregator = maybe_add_squeeze(max_aggregator, squeeze_dims)
     tensor_collector.register_statistic_branch(PTMinMaxTensorStatistic.MAX_STAT, max_reducer, max_aggregator)
     return tensor_collector

@@ -296,7 +310,6 @@ def get_mixed_min_max_statistic_collector(
     window_size: int = None,
 ):
     tensor_collector = TensorCollector(PTMinMaxTensorStatistic)
-
     min_reducer = PTMinReducer(reducers_axes, keepdims=reducers_keepdims)

     kwargs = {
@@ -308,13 +321,16 @@
     }
     min_aggregator_cls = MeanAggregator if use_means_of_mins else MinAggregator
     min_aggregator = min_aggregator_cls(**kwargs)
+    min_aggregator = maybe_add_squeeze(min_aggregator, squeeze_dims)
     tensor_collector.register_statistic_branch(PTMinMaxTensorStatistic.MIN_STAT, min_reducer, min_aggregator)

     max_reducer_cls = PTAbsMaxReducer if use_abs_max else PTMaxReducer
     max_reducer = max_reducer_cls(reducers_axes, keepdims=reducers_keepdims)
-    max_aggregator_cls = MeanAggregator if use_means_of_maxs else MinAggregator
+    max_aggregator_cls = MeanAggregator if use_means_of_maxs else MaxAggregator
     max_aggregator = max_aggregator_cls(**kwargs)
+    max_aggregator = maybe_add_squeeze(max_aggregator, squeeze_dims)
     tensor_collector.register_statistic_branch(PTMinMaxTensorStatistic.MAX_STAT, max_reducer, max_aggregator)
+
     return tensor_collector

@@ -382,7 +398,9 @@ def _get_statistics(self) -> PTMinMaxTensorStatistic:
         return PTMinMaxTensorStatistic(min_values, max_values)


-def get_median_mad_statistic_collector(
+def _get_collection_without_reduction(
+    aggregator_cls,
+    statistic_class,
     reducers_axes,
     reducers_keepdims: bool,
     aggregators_axes,
@@ -391,21 +409,57 @@
     squeeze_dims,
     window_size: int = None,
 ):
-    tensor_collector = TensorCollector(PTMedianMADTensorStatistic)
+    tensor_collector = TensorCollector(statistic_class)
     reducer = PTNoopReducer()
-    aggregator = MedianAbsoluteDeviationAggregator(
+    aggregation_axes = list(set(list(aggregators_axes) + [dim + 1 for dim in reducers_axes]))
+    aggregator = aggregator_cls(
         PTNNCFCollectorTensorProcessor,
-        aggregation_axes=aggregators_axes + [dim + 1 for dim in reducers_axes],
+        aggregation_axes=aggregation_axes,
         window_size=window_size,
         num_samples=num_samples,
-        keepdims=aggregators_keepdims,
+        keepdims=True,
     )
+    dims_to_squeeze = [0] if squeeze_dims else []
+    dims_to_squeeze += [axis + 1 for axis in reducers_axes] if not reducers_keepdims else []
+    dims_to_squeeze += aggregators_axes if not aggregators_keepdims else []
+    if dims_to_squeeze:
+
+        def post_aggregation_hook(aggregated_value):
+            retval = {}
+            for key, value in aggregated_value.items():
+                retval[key] = PTNNCFCollectorTensorProcessor.squeeze(PTNNCFTensor(value), dim=dims_to_squeeze).tensor
+            return retval
+
+        aggregator = PostAggregateAggregatorHook(aggregator=aggregator, post_aggregation_hook=post_aggregation_hook)
+
     tensor_collector.register_statistic_branch(
         PTMedianMADTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY, reducer, aggregator
     )
     return tensor_collector


+def get_median_mad_statistic_collector(
+    reducers_axes,
+    reducers_keepdims: bool,
+    aggregators_axes,
+    aggregators_keepdims,
+    num_samples: int,
+    squeeze_dims,
+    window_size: int = None,
+):
+    return _get_collection_without_reduction(
+        MedianAbsoluteDeviationAggregator,
+        PTMedianMADTensorStatistic,
+        reducers_axes=reducers_axes,
+        reducers_keepdims=reducers_keepdims,
+        aggregators_axes=aggregators_axes,
+        aggregators_keepdims=aggregators_keepdims,
+        num_samples=num_samples,
+        squeeze_dims=squeeze_dims,
+        window_size=window_size,
+    )
+
+
 class PTMedianMADStatisticCollector(MedianMADStatisticCollector):
     def _register_input(self, x: torch.Tensor):
         with no_nncf_trace():
@@ -432,20 +486,17 @@ def get_precentile_tensor_collector(
     squeeze_dims,
     window_size: int = None,
 ):
-    tensor_collector = TensorCollector(PTPercentileTensorStatistic)
-    reducer = PTNoopReducer()
-    aggregator = PrecentileAggregator(
-        PTNNCFCollectorTensorProcessor,
-        percentiles_to_collect=percentiles_to_collect,
-        reduction_shape=aggregators_axes + [dim + 1 for dim in aggregators_axes],
-        window_size=window_size,
+    return _get_collection_without_reduction(
+        partial(PrecentileAggregator, percentiles_to_collect=percentiles_to_collect),
+        PTPercentileTensorStatistic,
+        reducers_axes=reducers_axes,
+        reducers_keepdims=reducers_keepdims,
+        aggregators_axes=aggregators_axes,
+        aggregators_keepdims=aggregators_keepdims,
         num_samples=num_samples,
-        keepdims=reducers_keepdims,
-    )
-    tensor_collector.register_statistic_branch(
-        PTPercentileTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY, reducer, aggregator
+        squeeze_dims=squeeze_dims,
+        window_size=window_size,
     )
-    return tensor_collector


 class PTPercentileStatisticCollector(PercentileStatisticCollector):
@@ -462,19 +513,29 @@ def _get_statistics(self) -> PTPercentileTensorStatistic:


 def get_mean_percentile_statistic_collector(
-    percentiles_to_collect: List[float],
-    reduction_shape: Optional[ReductionShape] = None,
-    keepdims: bool = True,
-    num_samples: int = None,
+    percentiles_to_collect,
+    reducers_axes,
+    reducers_keepdims: bool,
+    aggregators_axes,
+    aggregators_keepdims,
+    num_samples: int,
+    squeeze_dims,
     window_size: int = None,
 ):
     tensor_collector = TensorCollector(PTPercentileTensorStatistic)
     quantiles_to_collect = np.true_divide(percentiles_to_collect, 100)
-    reducer = PTQuantileReducer(reduction_shape, quantiles_to_collect, keepdims=keepdims)
+    reducer = PTQuantileReducer(
+        reduction_shape=reducers_axes, quantile=quantiles_to_collect, keepdims=reducers_keepdims
+    )
     for output_port_id, p in enumerate(percentiles_to_collect):
         aggregator = MeanAggregator(
-            PTNNCFCollectorTensorProcessor, use_per_sample_stats=False, num_samples=num_samples, window_size=window_size
+            PTNNCFCollectorTensorProcessor,
+            aggregation_axes=aggregators_axes,
+            keepdims=aggregators_keepdims,
+            num_samples=num_samples,
+            window_size=window_size,
         )
+        aggregator = maybe_add_squeeze(aggregator, squeeze_dims)
         tensor_collector.register_statistic_branch(
             (PTPercentileTensorStatistic.PRECENTILE_VS_VALUE_DICT, p), reducer, aggregator, output_port_id
         )
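A condensed sketch of the squeeze bookkeeping inside `_get_collection_without_reduction`, again in plain NumPy (the axis choices are illustrative assumptions): aggregation always runs with `keepdims=True`, and the post-aggregation hook then drops the stacking axis plus whatever axes the caller asked not to keep.

```python
import numpy as np

def squeeze_plan(reducers_axes, reducers_keepdims, aggregators_axes, aggregators_keepdims, squeeze_dims):
    # Axis 0 is the axis along which collected tensors are stacked, so the
    # reducer axes reappear shifted by +1 in the aggregated result.
    dims = [0] if squeeze_dims else []
    if not reducers_keepdims:
        dims += [axis + 1 for axis in reducers_axes]
    if not aggregators_keepdims:
        dims += list(aggregators_axes)
    return dims

stacked = np.random.rand(5, 3, 16)                           # 5 collected steps of shape (3, 16)
aggregated = np.median(stacked, axis=(0, 2), keepdims=True)  # aggregation_axes = {0} | {1 + 1}
dims = squeeze_plan([1], False, [0], False, squeeze_dims=False)  # -> [2, 0]
out = np.squeeze(aggregated, axis=tuple(dims))               # shape (3,)
```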
diff --git a/tests/torch/quantization/test_range_init.py b/tests/torch/quantization/test_range_init.py
index 455df062384..fbeabb40215 100644
--- a/tests/torch/quantization/test_range_init.py
+++ b/tests/torch/quantization/test_range_init.py
@@ -531,7 +531,8 @@ def init_idfn(val):
             ("mean_min_max", 9999, 0, 9999),
             ("threesigma", 16119.5, -6119.5, 22239),
             ("percentile", 6789, 3210, 3578),
+            ("mean_percentile", 6789, 9.9990, 9979.0020),
         ]
     ),
     ids=init_idfn,
 )
@@ -722,7 +723,14 @@ def test_per_layer_range_init_collectors_are_called_the_required_number_of_times(
 )


-QUANTIZER_RANGE_INITIALIZERS = ["min_max", "threesigma", "mean_min_max", "percentile", "mixed_min_max"]
+QUANTIZER_RANGE_INITIALIZERS = [
+    "min_max",
+    "threesigma",
+    "mean_min_max",
+    "percentile",
+    "mixed_min_max",
+    "mean_percentile",
+]


 class QuantizeRangeInitScaleShapeTestStruct:
diff --git a/tests/torch/tensor_statistics/test_tensor_statistics.py b/tests/torch/tensor_statistics/test_tensor_statistics.py
index dfb974d43f6..8655f08a925 100644
--- a/tests/torch/tensor_statistics/test_tensor_statistics.py
+++ b/tests/torch/tensor_statistics/test_tensor_statistics.py
@@ -252,8 +252,8 @@ def test_collected_statistics(
         channel_dim = None
         if len(reduction_shape) > 1:
             channel_dim = [dim for dim, val in enumerate(reduction_shape) if val == 1][0]
-        collector_obj = collector(keepdims=len(reduction_shape) > 1, channel_dim=channel_dim)
-        # collector_obj = collector(reduction_shape=channel_dim, keepdims=len(reduction_shape) > 1)
+
+        collector_obj = collector()
         for input_ in TestCollectedStatistics.REF_INPUTS:
             if hasattr(collector_obj, "register_unnamed_inputs"):
                 collector_obj.register_unnamed_inputs(PTNNCFTensor(input_))
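Finally, a hedged usage sketch of the refactored collector entry points, in the spirit of the updated tests (import paths and the argument values are assumptions, not taken from this diff):

```python
import torch

# Assumed import locations; adjust to the package layout of your NNCF checkout.
from nncf.torch.tensor import PTNNCFTensor
from nncf.torch.tensor_statistics.collectors import get_min_max_statistic_collector

collector = get_min_max_statistic_collector(
    use_abs_max=False,
    reducers_axes=(1,),        # reduce over the per-tensor axis 1
    reducers_keepdims=False,
    aggregators_axes=(0,),     # aggregate over collected steps
    aggregators_keepdims=False,
    squeeze_dims=None,         # no PostAggregateAggregatorHook is wrapped in
    num_samples=None,
)
for _ in range(3):
    collector.register_unnamed_inputs(PTNNCFTensor(torch.rand(4, 16)))
statistic = collector.get_statistics()  # a PTMinMaxTensorStatistic
```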