Commit 9dd3f85

WIP

daniil-lyakhov committed Aug 31, 2023
1 parent 262ffe9 commit 9dd3f85
Showing 5 changed files with 133 additions and 45 deletions.
46 changes: 33 additions & 13 deletions nncf/experimental/common/tensor_statistics/collectors.py
@@ -565,30 +565,34 @@ def __init__(

 class OnlineAggregatorBase(OnlineOfflineAggregatorBase, ABC):
     def _online_register_reduced_input_impl(self, x: TensorType, fn) -> None:
-        if len(self._aggregation_axes) > 1:
-            raise NotImplemented()
-        unstacked_tensors = self._tensor_processor.unstack(x)
+        online_aggregation_axes = [dim - 1 for dim in self._aggregation_axes if dim != 0]
+        if online_aggregation_axes:
+            reduced = fn(x, axis=online_aggregation_axes, keepdims=self._keepdims)
+        else:
+            reduced = x
         if 0 in self._aggregation_axes:
             if self._container:
-                unstacked_tensors.append(self._container)
-            self._container = fn(*unstacked_tensors)
+                reduced = fn(self._tensor_processor.stack([reduced, self._container]), axis=0, keepdims=False)
+            self._container = reduced
         else:
-            if not self._container:
-                self._container = x
-            else:
-                self._container = fn(x, self._container)
+            self._container.append(reduced)

     def _aggregate_impl(self):
-        return self._container.tensor
+        if 0 in self._aggregation_axes:
+            if self._keepdims:
+                return self._tensor_processor.stack([self._container]).tensor
+            return self._container.tensor
+        return self._tensor_processor.stack(self._container).tensor
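For orientation, the online pattern above reduces each incoming tensor immediately and folds it into a single running value, so nothing but the current aggregate is stored. A minimal NumPy sketch of the same idea (plain arrays stand in for NNCFTensor; names are illustrative):

import numpy as np

class RunningMin:
    """Keeps an elementwise running minimum over axis 0 (the sample axis)."""

    def __init__(self):
        self._state = None  # aggregated value so far

    def register(self, x: np.ndarray) -> None:
        if self._state is None:
            self._state = x
        else:
            # Same trick as the diff: stack the new input with the running
            # state and reduce over the stacking axis.
            self._state = np.stack([x, self._state]).min(axis=0)

    def aggregate(self) -> np.ndarray:
        return self._state

agg = RunningMin()
agg.register(np.array([3.0, 1.0]))
agg.register(np.array([2.0, 5.0]))
print(agg.aggregate())  # [2. 1.]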


 class MinAggregator(OnlineAggregatorBase):
     def _register_reduced_input_impl(self, x: TensorType) -> None:
-        return self._online_register_reduced_input_impl(x, self._tensor_processor.min)
+        return self._online_register_reduced_input_impl(x, self._tensor_processor.reduce_min)


 class MaxAggregator(OnlineAggregatorBase):
     def _register_reduced_input_impl(self, x: TensorType) -> None:
-        return self._online_register_reduced_input_impl(x, self._tensor_processor.max)
+        return self._online_register_reduced_input_impl(x, self._tensor_processor.reduce_max)


 class OfflineAggregatorBase(OnlineOfflineAggregatorBase, ABC):
@@ -629,7 +633,9 @@ def __init__(

     def _offline_aggregation_impl(self, fn) -> List[NNCFTensor]:
         stacked_val = self._tensor_processor.stack(self._container)
-        result = self._tensor_processor.no_outliers_map(stacked_val, fn, axis=self._axis, alpha=self._quantile)
+        result = self._tensor_processor.no_outliers_map(
+            stacked_val, fn, axis=self._aggregation_axes, alpha=self._quantile
+        )
         return result.tensor

     def __eq__(self, __o: object) -> bool:
@@ -701,6 +707,20 @@ def _aggregate_impl(self) -> Any:
         return retval


+class PostAggregateAggregatorHook(TensorAggregatorBase, ABC):
+    def __init__(self, aggregator: TensorAggregatorBase, post_aggregation_hook):
+        super().__init__(None)
+        self._aggregator = aggregator
+        self._post_aggregation_hook = post_aggregation_hook
+
+    def _register_reduced_input_impl(self, x: TensorType) -> None:
+        return self._aggregator.register_reduced_input(x)
+
+    def _aggregate_impl(self) -> Any:
+        retval = self._aggregator.aggregate()
+        return self._post_aggregation_hook(retval)
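The new class is a plain decorator: inputs are forwarded untouched to the wrapped aggregator, and the hook runs exactly once, on the final aggregate. A self-contained sketch of that contract, with a hypothetical stub in place of a real NNCF aggregator (the _impl methods are called directly here to sidestep base-class bookkeeping):

class _StubAggregator:
    """Hypothetical stand-in that averages scalars, for illustration only."""

    def __init__(self):
        self._values = []

    def register_reduced_input(self, x):
        self._values.append(x)

    def aggregate(self):
        return sum(self._values) / len(self._values)

hook = PostAggregateAggregatorHook(
    aggregator=_StubAggregator(),
    post_aggregation_hook=lambda v: round(v, 2),  # any post-processing callable
)
hook._register_reduced_input_impl(1.0)
hook._register_reduced_input_impl(2.5)
print(hook._aggregate_impl())  # 1.75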


 AGGREGATORS_MAP = {
     AggregatorType.MIN: MinAggregator,
     AggregatorType.MAX: MaxAggregator,
7 changes: 3 additions & 4 deletions nncf/torch/quantization/init_range.py
@@ -187,9 +187,7 @@ def generate_stat_collector_for_range_init_config(
         if init_config.init_type not in RANGE_INIT_TYPES_VS_DESCRIPTIONS:
             raise RuntimeError("Unknown range init type: {}".format(init_config.init_type))

-        keepdims = collector_params._per_channel
         use_per_sample_stats = collector_params.use_per_sample_stats(init_config.init_type == "mixed_min_max")
-        reduction_axes = collector_params.convert_reduction_axes(use_per_sample_stats)
         collector_kwargs = collector_params.convert_statistic_params(use_per_sample_stats)

if init_config.init_type == "min_max":
Expand Down Expand Up @@ -238,15 +236,16 @@ def generate_stat_collector_for_range_init_config(
                 num_samples=num_samples,
+                **collector_kwargs,
             )
-            return PTPercentileStatisticCollector([min_percentile, max_percentile], scale_shape, num_samples)
+            # return PTPercentileStatisticCollector([min_percentile, max_percentile], scale_shape, num_samples)
         if init_config.init_type == "mean_percentile":
             min_percentile = init_config.init_type_specific_params.get("min_percentile", 0.1)
             max_percentile = init_config.init_type_specific_params.get("max_percentile", 99.9)
             return get_mean_percentile_statistic_collector(
                 percentiles_to_collect=[min_percentile, max_percentile],
                 num_samples=num_samples,
                 **collector_kwargs,
             )
-            return PTMeanPercentileStatisticCollector([min_percentile, max_percentile], scale_shape, num_samples)
+            # return PTMeanPercentileStatisticCollector([min_percentile, max_percentile], scale_shape, num_samples)
         raise ValueError("Range init type not handled!")

     @classmethod
109 changes: 85 additions & 24 deletions nncf/torch/tensor_statistics/collectors.py
@@ -9,6 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from functools import partial
 from typing import Any, Callable, Deque, List, Optional, Tuple, Union

 import numpy as np
@@ -41,6 +42,7 @@
 from nncf.experimental.common.tensor_statistics.collectors import MinAggregator
 from nncf.experimental.common.tensor_statistics.collectors import MinReducer
 from nncf.experimental.common.tensor_statistics.collectors import NoopReducer
+from nncf.experimental.common.tensor_statistics.collectors import PostAggregateAggregatorHook
 from nncf.experimental.common.tensor_statistics.collectors import PrecentileAggregator
 from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
@@ -234,6 +236,16 @@ class PTMeanPerChReducer(PTReducerMixIn, MeanPerChReducer):
     pass


+def maybe_add_squeeze(aggregator, squeeze_dims):
+    if not squeeze_dims:
+        return aggregator
+
+    def post_aggregation_hook(aggregated_value):
+        return PTNNCFCollectorTensorProcessor.squeeze(PTNNCFTensor(aggregated_value), dim=squeeze_dims).tensor
+
+    return PostAggregateAggregatorHook(aggregator=aggregator, post_aggregation_hook=post_aggregation_hook)
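In short, maybe_add_squeeze leaves the aggregator untouched when there is nothing to squeeze and otherwise wraps it in the hook defined above. A hedged usage sketch; the MinAggregator constructor arguments are assumptions inferred from calls elsewhere in this diff:

# Illustrative only: the aggregator signature is assumed, not documented.
base = MinAggregator(PTNNCFCollectorTensorProcessor, aggregation_axes=(0,))
assert maybe_add_squeeze(base, squeeze_dims=[]) is base  # no-op when nothing to squeeze
wrapped = maybe_add_squeeze(base, squeeze_dims=[0])      # a PostAggregateAggregatorHook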


 def get_min_max_statistic_collector(
     use_abs_max,
     reducers_axes,
@@ -253,11 +265,13 @@ def get_min_max_statistic_collector(
     }
     min_reducer = PTMinReducer(reducers_axes, keepdims=reducers_keepdims)
     min_aggregator = MinAggregator(**aggregator_kwargs)
+    min_aggregator = maybe_add_squeeze(min_aggregator, squeeze_dims)
     tensor_collector.register_statistic_branch(PTMinMaxTensorStatistic.MIN_STAT, min_reducer, min_aggregator)

     max_reducer_cls = PTAbsMaxReducer if use_abs_max else PTMaxReducer
     max_reducer = max_reducer_cls(reducers_axes, keepdims=reducers_keepdims)
     max_aggregator = MaxAggregator(**aggregator_kwargs)
+    max_aggregator = maybe_add_squeeze(max_aggregator, squeeze_dims)
     tensor_collector.register_statistic_branch(PTMinMaxTensorStatistic.MAX_STAT, max_reducer, max_aggregator)
     return tensor_collector
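Each statistic branch pairs a reducer, applied to every input, with an aggregator, applied across inputs. A toy stand-alone version of the min/max collector wiring (ToyCollector and its API are hypothetical, not the NNCF classes):

import numpy as np

class ToyCollector:
    def __init__(self):
        self._branches = {}  # name -> (reduce_fn, collected, aggregate_fn)

    def register_branch(self, name, reduce_fn, aggregate_fn):
        self._branches[name] = (reduce_fn, [], aggregate_fn)

    def register_input(self, x):
        # Every branch reduces each input as it arrives.
        for reduce_fn, collected, _ in self._branches.values():
            collected.append(reduce_fn(x))

    def aggregate(self):
        # Each branch aggregates across everything it collected.
        return {name: agg(collected) for name, (_, collected, agg) in self._branches.items()}

collector = ToyCollector()
collector.register_branch("min_values", lambda x: x.min(axis=1), lambda c: np.stack(c).min(axis=0))
collector.register_branch("max_values", lambda x: x.max(axis=1), lambda c: np.stack(c).max(axis=0))
for batch in (np.random.rand(4, 8), np.random.rand(4, 8)):
    collector.register_input(batch)
print(collector.aggregate())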

@@ -296,7 +310,6 @@ def get_mixed_min_max_statistic_collector(
     window_size: int = None,
 ):
     tensor_collector = TensorCollector(PTMinMaxTensorStatistic)
-
     min_reducer = PTMinReducer(reducers_axes, keepdims=reducers_keepdims)

     kwargs = {
@@ -308,13 +321,16 @@
     }
     min_aggregator_cls = MeanAggregator if use_means_of_mins else MinAggregator
     min_aggregator = min_aggregator_cls(**kwargs)
+    min_aggregator = maybe_add_squeeze(min_aggregator, squeeze_dims)
     tensor_collector.register_statistic_branch(PTMinMaxTensorStatistic.MIN_STAT, min_reducer, min_aggregator)

     max_reducer_cls = PTAbsMaxReducer if use_abs_max else PTMaxReducer
     max_reducer = max_reducer_cls(reducers_axes, keepdims=reducers_keepdims)
     max_aggregator_cls = MeanAggregator if use_means_of_maxs else MaxAggregator
     max_aggregator = max_aggregator_cls(**kwargs)
+    max_aggregator = maybe_add_squeeze(max_aggregator, squeeze_dims)
     tensor_collector.register_statistic_branch(PTMinMaxTensorStatistic.MAX_STAT, max_reducer, max_aggregator)
+
     return tensor_collector


@@ -382,7 +398,9 @@ def _get_statistics(self) -> PTMinMaxTensorStatistic:
         return PTMinMaxTensorStatistic(min_values, max_values)


-def get_median_mad_statistic_collector(
+def _get_collection_without_reduction(
+    aggregator_cls,
+    statistic_class,
     reducers_axes,
     reducers_keepdims: bool,
     aggregators_axes,
@@ -391,21 +409,57 @@
     squeeze_dims,
     window_size: int = None,
 ):
-    tensor_collector = TensorCollector(PTMedianMADTensorStatistic)
+    tensor_collector = TensorCollector(statistic_class)
     reducer = PTNoopReducer()
-    aggregator = MedianAbsoluteDeviationAggregator(
+    aggregation_axes = list(set(list(aggregators_axes) + [dim + 1 for dim in reducers_axes]))
+    aggregator = aggregator_cls(
         PTNNCFCollectorTensorProcessor,
-        aggregation_axes=aggregators_axes + [dim + 1 for dim in reducers_axes],
+        aggregation_axes=aggregation_axes,
         window_size=window_size,
         num_samples=num_samples,
-        keepdims=aggregators_keepdims,
+        keepdims=True,
     )
+    dims_to_squeeze = [0] if squeeze_dims else []
+    dims_to_squeeze += [axis + 1 for axis in reducers_axes] if not reducers_keepdims else []
+    dims_to_squeeze += aggregators_axes if not aggregators_keepdims else []
+    if dims_to_squeeze:
+
+        def post_aggregation_hook(aggregated_value):
+            retval = {}
+            for key, value in aggregated_value.items():
+                retval[key] = PTNNCFCollectorTensorProcessor.squeeze(PTNNCFTensor(value), dim=dims_to_squeeze).tensor
+            return retval
+
+        aggregator = PostAggregateAggregatorHook(aggregator=aggregator, post_aggregation_hook=post_aggregation_hook)
+
     tensor_collector.register_statistic_branch(
         PTMedianMADTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY, reducer, aggregator
     )
     return tensor_collector
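The dim + 1 offsets above exist because stacking the collected samples prepends a new sample axis, shifting every per-sample dimension by one. A small NumPy illustration of that bookkeeping (shapes invented for the example):

import numpy as np

# Three collected "samples", each of shape (2, 4). Reducer axes refer to
# per-sample dims; after stacking, per-sample dim d becomes d + 1.
samples = [np.random.rand(2, 4) for _ in range(3)]
stacked = np.stack(samples)  # shape (3, 2, 4); axis 0 is the new sample axis
reducers_axes = (1,)         # per-sample axis 1
aggregation_axes = [0] + [d + 1 for d in reducers_axes]  # -> [0, 2]
out = stacked.mean(axis=tuple(aggregation_axes), keepdims=True)
print(out.shape)             # (1, 2, 1)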


+def get_median_mad_statistic_collector(
+    reducers_axes,
+    reducers_keepdims: bool,
+    aggregators_axes,
+    aggregators_keepdims,
+    num_samples: int,
+    squeeze_dims,
+    window_size: int = None,
+):
+    return _get_collection_without_reduction(
+        MedianAbsoluteDeviationAggregator,
+        PTMedianMADTensorStatistic,
+        reducers_axes=reducers_axes,
+        reducers_keepdims=reducers_keepdims,
+        aggregators_axes=aggregators_axes,
+        aggregators_keepdims=aggregators_keepdims,
+        num_samples=num_samples,
+        squeeze_dims=squeeze_dims,
+        window_size=window_size,
+    )
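get_precentile_tensor_collector below reuses the same helper by currying the percentile list with functools.partial so that PrecentileAggregator fits the generic aggregator_cls call. The idea in isolation (all names here are illustrative):

from functools import partial

def make(aggregator_cls, **kwargs):
    # The helper only knows the common aggregator arguments...
    return aggregator_cls(**kwargs)

class PercentileAgg:
    def __init__(self, percentiles, num_samples=None):
        self.percentiles, self.num_samples = percentiles, num_samples

# ...so the extra argument is bound up front before the helper is called.
agg = make(partial(PercentileAgg, percentiles=[0.1, 99.9]), num_samples=8)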


 class PTMedianMADStatisticCollector(MedianMADStatisticCollector):
     def _register_input(self, x: torch.Tensor):
         with no_nncf_trace():
@@ -432,20 +486,17 @@ def get_precentile_tensor_collector(
     squeeze_dims,
     window_size: int = None,
 ):
-    tensor_collector = TensorCollector(PTPercentileTensorStatistic)
-    reducer = PTNoopReducer()
-    aggregator = PrecentileAggregator(
-        PTNNCFCollectorTensorProcessor,
-        percentiles_to_collect=percentiles_to_collect,
-        reduction_shape=aggregators_axes + [dim + 1 for dim in aggregators_axes],
-        window_size=window_size,
-        num_samples=num_samples,
-        keepdims=reducers_keepdims,
-    )
-    tensor_collector.register_statistic_branch(
-        PTPercentileTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY, reducer, aggregator
-    )
-    return tensor_collector
+    return _get_collection_without_reduction(
+        partial(PrecentileAggregator, percentiles_to_collect=percentiles_to_collect),
+        PTPercentileTensorStatistic,
+        reducers_axes=reducers_axes,
+        reducers_keepdims=reducers_keepdims,
+        aggregators_axes=aggregators_axes,
+        aggregators_keepdims=aggregators_keepdims,
+        num_samples=num_samples,
+        squeeze_dims=squeeze_dims,
+        window_size=window_size,
+    )


 class PTPercentileStatisticCollector(PercentileStatisticCollector):
@@ -462,19 +513,29 @@ def _get_statistics(self) -> PTPercentileTensorStatistic:


 def get_mean_percentile_statistic_collector(
-    percentiles_to_collect: List[float],
-    reduction_shape: Optional[ReductionShape] = None,
-    keepdims: bool = True,
-    num_samples: int = None,
+    percentiles_to_collect,
+    reducers_axes,
+    reducers_keepdims: bool,
+    aggregators_axes,
+    aggregators_keepdims,
+    num_samples: int,
+    squeeze_dims,
     window_size: int = None,
 ):
     tensor_collector = TensorCollector(PTPercentileTensorStatistic)
     quantiles_to_collect = np.true_divide(percentiles_to_collect, 100)
-    reducer = PTQuantileReducer(reduction_shape, quantiles_to_collect, keepdims=keepdims)
+    reducer = PTQuantileReducer(
+        reduction_shape=reducers_axes, quantile=quantiles_to_collect, keepdims=reducers_keepdims
+    )
     for output_port_id, p in enumerate(percentiles_to_collect):
         aggregator = MeanAggregator(
-            PTNNCFCollectorTensorProcessor, use_per_sample_stats=False, num_samples=num_samples, window_size=window_size
+            PTNNCFCollectorTensorProcessor,
+            aggregation_axes=aggregators_axes,
+            keepdims=aggregators_keepdims,
+            num_samples=num_samples,
+            window_size=window_size,
         )
+        aggregator = maybe_add_squeeze(aggregator, squeeze_dims)
         tensor_collector.register_statistic_branch(
             (PTPercentileTensorStatistic.PRECENTILE_VS_VALUE_DICT, p), reducer, aggregator, output_port_id
         )
12 changes: 10 additions & 2 deletions tests/torch/quantization/test_range_init.py
@@ -531,7 +531,8 @@ def init_idfn(val):
("mean_min_max", 9999, 0, 9999),
("threesigma", 16119.5, -6119.5, 22239),
("percentile", 6789, 3210, 3578),
]
("mean_percentile", 6789, 9.9990, 9979.0020),
][-1:]
),
ids=init_idfn,
)
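Note: the trailing [-1:] slice above temporarily narrows the parametrization to the newly added "mean_percentile" case while this WIP is debugged; the same pattern is applied to QUANTIZER_RANGE_INITIALIZERS below, and dropping the slices restores the full sweep.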
@@ -722,7 +723,14 @@ def test_per_layer_range_init_collectors_are_called_the_required_number_of_times(
 )


-QUANTIZER_RANGE_INITIALIZERS = ["min_max", "threesigma", "mean_min_max", "percentile", "mixed_min_max"]
+QUANTIZER_RANGE_INITIALIZERS = [
+    "min_max",
+    "threesigma",
+    "mean_min_max",
+    "percentile",
+    "mixed_min_max",
+    "mean_percentile",
+][-1:]


class QuantizeRangeInitScaleShapeTestStruct:
4 changes: 2 additions & 2 deletions tests/torch/tensor_statistics/test_tensor_statistics.py
@@ -252,8 +252,8 @@ def test_collected_statistics(
         channel_dim = None
         if len(reduction_shape) > 1:
             channel_dim = [dim for dim, val in enumerate(reduction_shape) if val == 1][0]
-        collector_obj = collector(keepdims=len(reduction_shape) > 1, channel_dim=channel_dim)
-        # collector_obj = collector(reduction_shape=channel_dim, keepdims=len(reduction_shape) > 1)
+
+        collector_obj = collector()
         for input_ in TestCollectedStatistics.REF_INPUTS:
             if hasattr(collector_obj, "register_unnamed_inputs"):
                 collector_obj.register_unnamed_inputs(PTNNCFTensor(input_))
