From 0cce7afad49f899952bcbf249de9ca9418e46af4 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Thu, 11 May 2023 17:55:35 +0200 Subject: [PATCH] Sequential tensor reducer TensorReducerSequence Reducer adapter inside reducer TensorCollectorAdapter --- .../onnx/mobilenet_v2/main.py | 1 + .../mozilla-deepspeech/accuracy_checker.json | 147 ++++++++++++ .../openvino/mozilla-deepspeech/main.py | 55 +++++ nncf/__init__.py | 1 + nncf/common/tensor_statistics/aggregator.py | 43 +++- nncf/data/__init__.py | 2 + nncf/data/dataset.py | 29 ++- .../common/tensor_statistics/collectors.py | 216 ++++++++++++++---- nncf/onnx/statistics/aggregator.py | 5 + nncf/openvino/statistics/aggregator.py | 14 ++ nncf/torch/statistics/aggregator.py | 5 + 11 files changed, 465 insertions(+), 53 deletions(-) create mode 100644 examples/post_training_quantization/openvino/mozilla-deepspeech/accuracy_checker.json create mode 100644 examples/post_training_quantization/openvino/mozilla-deepspeech/main.py diff --git a/examples/post_training_quantization/onnx/mobilenet_v2/main.py b/examples/post_training_quantization/onnx/mobilenet_v2/main.py index 3e4687b8081..53b37da7755 100755 --- a/examples/post_training_quantization/onnx/mobilenet_v2/main.py +++ b/examples/post_training_quantization/onnx/mobilenet_v2/main.py @@ -104,6 +104,7 @@ def run_benchmark(path_to_model: str, shape: Optional[List[int]] = None, verbose # >> output_names = [output.name for output in sess.get_outputs()] # >> for data_item in val_loader: # >> sess.run(output_names, input_feed=transform_fn(data_item)) + input_name = model.graph.input[0].name diff --git a/examples/post_training_quantization/openvino/mozilla-deepspeech/accuracy_checker.json b/examples/post_training_quantization/openvino/mozilla-deepspeech/accuracy_checker.json new file mode 100644 index 00000000000..cdf14378be3 --- /dev/null +++ b/examples/post_training_quantization/openvino/mozilla-deepspeech/accuracy_checker.json @@ -0,0 +1,147 @@ +{ + "compression": { + "algorithms": [ + { + "name": "DefaultQuantization", + "params": { + "preset": "performance", + "stat_subset_size": 3 + } + } + ], + "dump_intermediate_model": true + }, + "engine": { + "datasets": [ + { + "metrics": [ + { + "type": "wer" + } + ], + "name": "LibriSpeech_test_clean_wav", + "data_source": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/librispeech/test/LibriSpeech/test-clean.wav", + + "annotation_conversion": { + "converter": "librispeech", + "data_dir": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/librispeech/test/LibriSpeech/test-clean.wav" + }, + "preprocessing": [ + { + "int16mode": true, + "type": "audio_normalization" + }, + { + "duration": "512 samples", + "overlap": "192 samples", + "type": "clip_audio" + }, + { + "base": 512, + "type": "hanning_window" + }, + { + "fftbase": 512, + "magnitude_squared": true, + "skip_channels": true, + "type": "audio_spectrogram" + }, + { + "base": 257, + "filterbank_channel_count": 40, + "lower_frequency_limit": 20, + "sample_rate": 16000, + "type": "audio_triangle_filtering", + "upper_frequency_limit": 4000 + }, + { + "filterbank_channel_count": 40, + "numceps": 26, + "type": "audio_dct" + }, + { + "context": 9, + "numceps": 26, + "type": "clip_cepstrum" + }, + { + "step": 16, + "type": "pack_cepstrum" + } + ], + "reader": "wav_reader" + } + ], + "launchers": [ + { + "adapter": { + "beam_size": 32, + "lm_alpha": 0.75, + "lm_beta": 1.05, + "lm_file": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/model_attributes/mozilla-deepspeech-0.6.1/lm.binary", + "lm_oov_score": -1000, + "lm_vocabulary_length": 4463723, + "lm_vocabulary_offset": 941235601, + "logarithmic_prob": false, + "probability_out": "logits", + "type": "ctc_beam_search_decoder_with_lm" + }, + "framework": "dlsdk", + "inputs": [ + { + "layout": "NHWC", + "name": "input_node", + "type": "INPUT" + }, + { + "name": "previous_state_c", + "type": "LSTM_INPUT", + "value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/BlockLSTM/TensorIterator.2" + }, + { + "name": "previous_state_h", + "type": "LSTM_INPUT", + "value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/BlockLSTM/TensorIterator.1" + } + ] + }, + { + "adapter": { + "beam_size": 32, + "lm_alpha": 0.75, + "lm_beta": 1.05, + "lm_file": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/model_attributes/mozilla-deepspeech-0.6.1/lm.binary", + "lm_oov_score": -1000, + "lm_vocabulary_length": 4463723, + "lm_vocabulary_offset": 941235601, + "logarithmic_prob": false, + "probability_out": "logits", + "type": "ctc_beam_search_decoder_with_lm" + }, + "framework": "openvino", + "inputs": [ + { + "layout": "NHWC", + "name": "input_node", + "type": "INPUT" + }, + { + "name": "previous_state_c", + "type": "LSTM_INPUT", + "value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd:0" + }, + { + "name": "previous_state_h", + "type": "LSTM_INPUT", + "value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd_1:0" + } + ] + } + ] + }, + "model": { + "model": "/mnt/omz/cv_bench_cache/ww18_weekly_23.0.0-10862-40bf400b189-API2.0/mozilla-deepspeech-0.6.1/tf/tf_frozen/FP16/1/dldt/mozilla-deepspeech-0.6.1.xml", + "model_name": "mozilla-deepspeech-0.6.1", + "weights": "/mnt/omz/cv_bench_cache/ww18_weekly_23.0.0-10862-40bf400b189-API2.0/mozilla-deepspeech-0.6.1/tf/tf_frozen/FP16/1/dldt/mozilla-deepspeech-0.6.1.bin" + } +} diff --git a/examples/post_training_quantization/openvino/mozilla-deepspeech/main.py b/examples/post_training_quantization/openvino/mozilla-deepspeech/main.py new file mode 100644 index 00000000000..229d7a1229e --- /dev/null +++ b/examples/post_training_quantization/openvino/mozilla-deepspeech/main.py @@ -0,0 +1,55 @@ +import json +import os +import subprocess + +import numpy as np +import openvino.runtime as ov +from openvino.tools.accuracy_checker.evaluators.quantization_model_evaluator import create_model_evaluator +from openvino.tools.pot.configs.config import Config + +import nncf + +model_name = "mozilla-deepspeech-0.6.1" +cache_dir = os.path.dirname(__file__) +dataset_config = os.path.join(cache_dir, "accuracy_checker.json") + +command = f"omz_downloader --name {model_name} --cache_dir {cache_dir}" +cmd_output = subprocess.call(command, shell=True) # nosec + +model_dir = os.path.join(cache_dir, model_name) +if not os.path.exists(model_dir): + command = f"omz_converter --name {model_name} -o {os.path.join(cache_dir, model_name)}" + cmd_output = subprocess.call(command, shell=True) # nosec + +xml_path = os.path.join(model_dir, f"public/{model_name}/FP16/{model_name}.xml") +ov_model = ov.Core().read_model(xml_path) + +config = Config.read_config(dataset_config) +config.configure_params() +accuracy_checker_config = config.engine + +model_evaluator = create_model_evaluator(accuracy_checker_config) +model_evaluator.load_network([{"model": ov_model}]) +model_evaluator.select_dataset("") + + +def get_tokens_from_sequence_func(data_item): + _, batch_annotation, batch_input, _ = data_item + filled_inputs, _, _ = model_evaluator._get_batch_input(batch_input, batch_annotation) + for filled_input in filled_inputs: + input_data = {} + for name, value in filled_input.items(): + input_data[model_evaluator.launcher.input_to_tensor_name[name]] = value + yield input_data + + +def fill_sequential_inputs_fn(model_inputs, model_outputs): + # Combine model inputs with state model outputs + # or fill state model outputs if model_outputs is None + state_inputs = model_evaluator.launcher._fill_lstm_inputs(model_outputs) + model_inputs.update(state_inputs) + return model_inputs + + +dataset = nncf.RecurentDataset(model_evaluator.dataset, get_tokens_from_sequence_func, fill_sequential_inputs_fn) +quantized_model = nncf.quantize(ov_model, dataset, subset_size=3) diff --git a/nncf/__init__.py b/nncf/__init__.py index 65a47864066..2b21702b727 100644 --- a/nncf/__init__.py +++ b/nncf/__init__.py @@ -15,6 +15,7 @@ from nncf.common.logging.logger import set_log_level from nncf.config import NNCFConfig from nncf.data import Dataset +from nncf.data import RecurentDataset from nncf.parameters import DropType from nncf.parameters import ModelType from nncf.parameters import TargetDevice diff --git a/nncf/common/tensor_statistics/aggregator.py b/nncf/common/tensor_statistics/aggregator.py index 9afc180aeda..b9dd193002e 100644 --- a/nncf/common/tensor_statistics/aggregator.py +++ b/nncf/common/tensor_statistics/aggregator.py @@ -10,6 +10,7 @@ # limitations under the License. from abc import ABC from abc import abstractmethod +from collections import defaultdict from itertools import islice from typing import Any, Dict, TypeVar @@ -21,6 +22,7 @@ from nncf.common.tensor import NNCFTensor from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.data.dataset import Dataset +from nncf.data.dataset import RecurentDataset TensorType = TypeVar("TensorType") TModel = TypeVar("TModel") @@ -31,10 +33,13 @@ class StatisticsAggregator(ABC): Base class for statistics collection. """ + STACK_AXIS = 0 + def __init__(self, dataset: Dataset): self.dataset = dataset self.stat_subset_size = None self.statistic_points = StatisticPointsContainer() + self._is_sequential = isinstance(dataset, RecurentDataset) def collect_statistics(self, model: TModel) -> None: """ @@ -46,19 +51,44 @@ def collect_statistics(self, model: TModel) -> None: model_transformer = ModelTransformerFactory.create(model) merged_statistics = self._get_merged_statistic_points(self.statistic_points, model) + if self._is_sequential: + merged_statistics = self._adapt_collectors(merged_statistics, self.STACK_AXIS) + transformation_layout = self._get_transformation_layout_extra_outputs(merged_statistics) model_with_outputs = model_transformer.transform(transformation_layout) engine = EngineFactory.create(model_with_outputs) + infer_fn = self._infer_sequential if self._is_sequential else self._infer for input_data in tqdm( islice(self.dataset.get_inference_data(), self.stat_subset_size), total=self.stat_subset_size, desc="Statistics collection", ): - outputs = engine.infer(input_data) - processed_outputs = self._process_outputs(outputs) + processed_outputs = infer_fn(engine, input_data) self._register_statistics(processed_outputs, merged_statistics) + def _infer(self, engine, input_data): + outputs = engine.infer(input_data) + return self._process_outputs(outputs) + + def _infer_sequential(self, engine, sequence): + model_output = None + model_outputs = defaultdict(list) + for token in sequence.get_tokens_iter(): + filled_inputs = sequence.fill_inputs(token, model_output) + model_output = engine.infer(filled_inputs) + processed_output = self._process_outputs(model_output) + for output_name, output_value in processed_output.items(): + model_outputs[output_name].append(output_value) + + # Stack model outputs and return them + stacked_outputs = {} + tensor_processor = self._get_tensor_processor() + for output_name, output_values in model_outputs.items(): + stacked_outputs[output_name] = tensor_processor.stack(output_values, axis=self.STACK_AXIS) + + return stacked_outputs + def register_statistic_points(self, statistic_points: StatisticPointsContainer) -> None: """ Register statistic points for statistics collection and recalculates the maximum number samples @@ -115,6 +145,10 @@ def _get_merged_statistic_points( :return: Merged statistic points container bounded with given statistic point container. """ + @staticmethod + def _adapt_collectors(statistic_points: StatisticPointsContainer, stack_axis: int): + return statistic_points + @staticmethod @abstractmethod def _process_outputs(outputs: Any) -> Dict[str, NNCFTensor]: @@ -124,3 +158,8 @@ def _process_outputs(outputs: Any) -> Dict[str, NNCFTensor]: :param outputs: raw model outputs :return: processed model outputs in Dict[str, NNCFTensor] format """ + + @staticmethod + @abstractmethod + def _get_tensor_processor(): + pass diff --git a/nncf/data/__init__.py b/nncf/data/__init__.py index 97fb8e39dfd..1f59096721e 100644 --- a/nncf/data/__init__.py +++ b/nncf/data/__init__.py @@ -10,3 +10,5 @@ # limitations under the License. from nncf.data.dataset import Dataset +from nncf.data.dataset import RecurentDataset +from nncf.data.dataset import Sequence diff --git a/nncf/data/dataset.py b/nncf/data/dataset.py index 6bf1322c2ec..a52b4265ed4 100644 --- a/nncf/data/dataset.py +++ b/nncf/data/dataset.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Generic, Iterable, List, Optional, TypeVar +from typing import Any, Callable, Generic, Iterable, List, Optional, TypeVar from nncf.common.utils.api_marker import api @@ -115,3 +115,30 @@ def _get_iterator_for_iter( if idx == indices[pos]: pos = pos + 1 yield transform_func(data_item) + + +@api(canonical_alias="nncf.RecurentDataset") +class RecurentDataset(Dataset): + def __init__(self, data_source: Iterable, get_token_from_sequence_func, fill_sequential_inputs_fn): + def transform_fn_wrapper(data_item): + return Sequence(data_item, get_token_from_sequence_func, fill_sequential_inputs_fn) + + super().__init__(data_source, transform_fn_wrapper) + + +class Sequence: + def __init__( + self, + raw_sequence, + get_tokens_from_sequence_func: Callable[[DataItem], ModelInput], + fill_sequential_inputs_fn: Callable[[DataItem], ModelInput], + ): + self._raw_sequence = raw_sequence + self._get_tokens_from_sequence_func = get_tokens_from_sequence_func + self._fill_sequential_inputs_fn = fill_sequential_inputs_fn + + def get_tokens_iter(self): + return self._get_tokens_from_sequence_func(self._raw_sequence) + + def fill_inputs(self, token, model_outputs): + return self._fill_sequential_inputs_fn(token, model_outputs) diff --git a/nncf/experimental/common/tensor_statistics/collectors.py b/nncf/experimental/common/tensor_statistics/collectors.py index e175d550b99..4d071f3066d 100644 --- a/nncf/experimental/common/tensor_statistics/collectors.py +++ b/nncf/experimental/common/tensor_statistics/collectors.py @@ -11,6 +11,8 @@ from abc import ABC from abc import abstractmethod +from abc import abstractproperty +from abc import abstractstaticmethod from collections import defaultdict from collections import deque from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union @@ -31,16 +33,16 @@ class TensorReducerBase(ABC): the specified rule. Could handle tensors inplace or out of place. """ - def __init__(self, reduction_shape: Optional[ReductionShape] = None, inplace: bool = False): + def __init__(self, reduction_shape: Optional[ReductionShape] = None, inplace: bool = False, keepdims: bool = True): """ :param reduction_shape: Reduction shape for reduction calculation. Equal to list(range(len(input.shape))) if empty. - :param: Wheather should be calculated inplace or out of place. - + :param inplace: Wheather should be calculated inplace or out of place. """ self._reduction_shape = reduction_shape self._tensor_processor: NNCFCollectorTensorProcessor = self._get_processor() self._inplace = inplace + self._keepdims = keepdims @property def inplace(self): @@ -54,19 +56,10 @@ def output_port_id(self) -> int: def name(self): return self.__class__.__name__ + str(self.__hash__()) - @staticmethod - @abstractmethod + @abstractstaticmethod def _get_processor() -> NNCFCollectorTensorProcessor: pass - @abstractmethod - def _reduce_out_of_place(self, x: List[TensorType]) -> List[TensorType]: - """ - Specifies the reduction rule in terms of NNCFCollectorTensorProcessor. - - :param x: Tensor to register. - """ - @abstractmethod def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: """ @@ -86,15 +79,25 @@ def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]: :return: Inplace operation builder if possible else None. """ - def __call__(self, x: List[NNCFTensor]): - if self.inplace: - return x + @abstractstaticmethod + def reduce_out_of_place(x: List[TensorType], **kwargs) -> List[TensorType]: + """ + Specifies the reduction rule in terms of NNCFCollectorTensorProcessor. - return self._reduce_out_of_place(x) + :param x: Tensor to register. + """ + + def get_kwargs(self): + return { + "reduction_shape": self._reduction_shape, + "tensor_processor": self._tensor_processor, + "keepdims": self._keepdims, + } def __eq__(self, __o: object) -> bool: return ( isinstance(__o, self.__class__) + and self._keepdims == __o._keepdims and self._reduction_shape == __o._reduction_shape and self._inplace == __o.inplace ) @@ -164,6 +167,54 @@ def __hash__(self) -> int: return hash(self.__class__.__name__) +class BaseTensorCollectorAdapter(ABC): + @abstractmethod + def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase): + pass + + @abstractmethod + def register_reduced_input(self, x: NNCFTensor, aggregator: TensorAggregatorBase): + pass + + +class DefaultTensorCollectorAdapter(BaseTensorCollectorAdapter): + def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase): + if reducer.inplace: + return x + + kwargs = reducer.get_kwargs() + reduction_shape_key = "reduction_shape" + if reduction_shape_key in kwargs and kwargs[reduction_shape_key] is None: + kwargs[reduction_shape_key] = tuple(range(len(x[0].shape))) + + return reducer.reduce_out_of_place(x, **kwargs) + + def register_reduced_input(self, x: NNCFTensor, aggregator: TensorAggregatorBase): + aggregator.register_reduced_input(x) + + +class SequentialTensorCollectorAdapter(DefaultTensorCollectorAdapter): + def __init__(self, stack_axis: int, tensor_processor: NNCFCollectorTensorProcessor) -> None: + super().__init__() + self._stack_axis = stack_axis + self._tensor_processor = tensor_processor + + def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase): + per_element_result = super().reduce(x, reducer) + kwargs = reducer.get_kwargs() + new_params = {"reduction_shape": self._stack_axis, "keepdims": False} + if not all(k in kwargs for k in new_params): + return per_element_result + + kwargs.update(new_params) + return reducer.reduce_out_of_place(per_element_result, **kwargs) + + def register_reduced_input(self, x: NNCFTensor, aggregator: TensorAggregatorBase): + if isinstance(aggregator, ShapeAggregator): + x = self._tensor_processor.unstack(x, axis=0)[0] + super().register_reduced_input(x, aggregator) + + class TensorCollector: """ Calculates statistics at given tensors according to registered statistic branches. @@ -180,6 +231,7 @@ def __init__(self, statistic_container: Optional[TensorStatistic] = None) -> Non self._stat_container_kwargs_map: Dict[str, Tuple[int, int]] = {} self._stat_container = statistic_container self._enabled = True + self._adapter = DefaultTensorCollectorAdapter() @property def num_samples(self) -> Optional[int]: @@ -203,6 +255,9 @@ def reducers(self): def aggregators(self): return self._aggregators.copy() + def set_adapter(self, adapter: "BaseTensorCollectorAdapter"): + self._adapter = adapter + def enable(self): self._enabled = True @@ -270,14 +325,14 @@ def register_inputs(self, inputs: Dict[int, List[NNCFTensor]]) -> None: input_ = inputs[reducer_hash] if any([tensor.is_empty() for tensor in input_]): continue - reduced_inputs[reducer_hash] = reducer(input_) + reduced_inputs[reducer_hash] = self._adapter.reduce(input_, reducer) for ( (reducer_hash, reducer_port_id, _), aggregator, ) in self._aggregators.items(): if reducer_hash in reduced_inputs: - aggregator.register_reduced_input(reduced_inputs[reducer_hash][reducer_port_id]) + self._adapter.register_reduced_input(reduced_inputs[reducer_hash][reducer_port_id], aggregator) def _aggregate(self) -> None: result = {} @@ -399,6 +454,9 @@ class NoopReducer(TensorReducerBase): def __init__(self): super().__init__(inplace=False) + def get_kwargs(self): + return {} + @staticmethod def _get_processor() -> NNCFCollectorTensorProcessor: return None @@ -406,36 +464,54 @@ def _get_processor() -> NNCFCollectorTensorProcessor: def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]: return None - def _reduce_out_of_place(self, x: List[TensorType]) -> List[TensorType]: + @staticmethod + def reduce_out_of_place(x: List[TensorType]) -> List[TensorType]: return x class MinReducer(TensorReducerBase): - def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: - x = x[0] - reduction_shape = self._get_reduction_shape(x) - return [self._tensor_processor.reduce_min(x, reduction_shape, keepdims=True)] + @staticmethod + def reduce_out_of_place( + x: List[NNCFTensor], + tensor_processor: NNCFCollectorTensorProcessor, + reduction_shape: Tuple[int, ...], + keepdims: bool, + ) -> List[NNCFTensor]: + return [tensor_processor.reduce_min(x[0], reduction_shape, keepdims=keepdims)] class MaxReducer(TensorReducerBase): - def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: - x = x[0] - reduction_shape = self._get_reduction_shape(x) - return [self._tensor_processor.reduce_max(x, reduction_shape, keepdims=True)] + @staticmethod + def reduce_out_of_place( + x: List[NNCFTensor], + tensor_processor: NNCFCollectorTensorProcessor, + reduction_shape: Tuple[int, ...], + keepdims: bool, + ) -> List[NNCFTensor]: + return [tensor_processor.reduce_max(x[0], reduction_shape, keepdims=keepdims)] class AbsMaxReducer(TensorReducerBase): - def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: - x = self._tensor_processor.abs(x[0]) - reduction_shape = self._get_reduction_shape(x) - return [self._tensor_processor.reduce_max(x, reduction_shape, keepdims=True)] + @staticmethod + def reduce_out_of_place( + x: List[NNCFTensor], + tensor_processor: NNCFCollectorTensorProcessor, + reduction_shape: Tuple[int, ...], + keepdims: bool, + ) -> List[NNCFTensor]: + x = tensor_processor.abs(x[0]) + return [tensor_processor.reduce_max(x, reduction_shape, keepdims=keepdims)] class MeanReducer(TensorReducerBase): - def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: - x = x[0] - reduction_shape = self._get_reduction_shape(x) - return [self._tensor_processor.mean(x, reduction_shape, keepdims=True)] + @staticmethod + def reduce_out_of_place( + x: List[NNCFTensor], + tensor_processor: NNCFCollectorTensorProcessor, + reduction_shape: Tuple[int, ...], + keepdims: bool, + ) -> List[NNCFTensor]: + return [tensor_processor.mean(x[0], reduction_shape, keepdims=keepdims)] class QuantileReducerBase(TensorReducerBase): @@ -444,22 +520,41 @@ def __init__( reduction_shape: Optional[ReductionShape] = None, quantile: Union[float, List[float]] = [0.01, 0.99], inplace: bool = False, + keepdims: bool = True, ): - super().__init__(reduction_shape, False) + super().__init__(reduction_shape, False, keepdims) self._quantile = quantile + def get_kwargs(self): + kwargs = super().get_kwargs() + kwargs["quantile"] = self._quantile + return kwargs + def __eq__(self, __o: object) -> bool: return super().__eq__(__o) and self._quantile == __o._quantile def __hash__(self) -> int: - return hash((self.__class__.__name__, self.inplace, self._reduction_shape, tuple(self._quantile))) + return hash( + ( + self.__class__.__name__, + self.inplace, + self._reduction_shape, + self._keepdims, + tuple(self._quantile), + ) + ) class QuantileReducer(QuantileReducerBase): - def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: - x = x[0] - reduction_shape = self._get_reduction_shape(x) - return self._tensor_processor.quantile(x, self._quantile, reduction_shape, keepdims=True) + @staticmethod + def reduce_out_of_place( + x: List[NNCFTensor], + tensor_processor: NNCFCollectorTensorProcessor, + reduction_shape: Tuple[int, ...], + quantile: Union[float, List[float]], + keepdims: bool, + ) -> List[NNCFTensor]: + return tensor_processor.quantile(x[0], quantile, reduction_shape, keepdims=keepdims) class AbsQuantileReducer(QuantileReducerBase): @@ -471,26 +566,47 @@ def __init__( ): super().__init__(reduction_shape, quantile, False) - def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: - x = self._tensor_processor.abs(x[0]) - reduction_shape = self._get_reduction_shape(x) - return self._tensor_processor.quantile(x, [self._quantile], reduction_shape, keepdims=True) + @staticmethod + def reduce_out_of_place( + x: List[NNCFTensor], + tensor_processor: NNCFCollectorTensorProcessor, + reduction_shape: Tuple[int, ...], + quantile: Union[float, List[float]], + keepdims: bool, + ) -> List[NNCFTensor]: + x = tensor_processor.abs(x[0]) + return tensor_processor.quantile(x, [quantile], reduction_shape, keepdims=keepdims) class BatchMeanReducer(TensorReducerBase): def __init__(self, inplace: bool = False): super().__init__(None, inplace) - def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: - return [self._tensor_processor.batch_mean(x[0])] + def get_kwargs(self): + return { + "tensor_processor": self._tensor_processor, + } + + @staticmethod + def reduce_out_of_place(x: List[NNCFTensor], tensor_processor: NNCFCollectorTensorProcessor) -> List[NNCFTensor]: + return [tensor_processor.batch_mean(x[0])] class MeanPerChReducer(TensorReducerBase): def __init__(self, channel_dim: int = 1, inplace: bool = False): super().__init__(channel_dim, inplace) - def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: - return [self._tensor_processor.mean_per_channel(x[0], self._reduction_shape)] + def get_kwargs(self): + return { + "tensor_processor": self._tensor_processor, + "channel_dim": self._reduction_shape, + } + + @staticmethod + def reduce_out_of_place( + x: List[NNCFTensor], tensor_processor: NNCFCollectorTensorProcessor, channel_dim: int + ) -> List[NNCFTensor]: + return [tensor_processor.mean_per_channel(x[0], channel_dim)] ##################################################Aggregators################################################## diff --git a/nncf/onnx/statistics/aggregator.py b/nncf/onnx/statistics/aggregator.py index e48267ac136..3253dd3ec2b 100644 --- a/nncf/onnx/statistics/aggregator.py +++ b/nncf/onnx/statistics/aggregator.py @@ -79,3 +79,8 @@ def _get_merged_statistic_points( @staticmethod def _process_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, ONNXNNCFTensor]: return {n: ONNXNNCFTensor(v) for n, v in outputs.items()} + + def _get_tensor_processor(): + from nncf.onnx.statistics.collectors import ONNXNNCFCollectorTensorProcessor + + return ONNXNNCFCollectorTensorProcessor diff --git a/nncf/openvino/statistics/aggregator.py b/nncf/openvino/statistics/aggregator.py index 60c7bab3e80..d62af7cbdd1 100644 --- a/nncf/openvino/statistics/aggregator.py +++ b/nncf/openvino/statistics/aggregator.py @@ -21,10 +21,12 @@ from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.experimental.common.tensor_statistics.collectors import MergedTensorCollector +from nncf.experimental.common.tensor_statistics.collectors import SequentialTensorCollectorAdapter from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.openvino.graph.nncf_graph_builder import GraphConverter from nncf.openvino.graph.transformations.commands import OVInplaceFnInsertionCommand from nncf.openvino.graph.transformations.commands import OVOutputInsertionCommand +from nncf.openvino.statistics.collectors import OVNNCFCollectorTensorProcessor from nncf.openvino.tensor import OVNNCFTensor @@ -72,6 +74,14 @@ def _get_transformation_layout_extra_outputs( return transformation_layout + @staticmethod + # TODO(dlyakhov) Move this to common part + def _adapt_collectors(statistic_points: StatisticPointsContainer, stack_axis: int): + for _, _, tensor_collector in statistic_points.get_tensor_collectors(): + tensor_collector.set_adapter(SequentialTensorCollectorAdapter(stack_axis, OVNNCFCollectorTensorProcessor)) + + return statistic_points + @staticmethod # TODO(dlyakhov) Move this to common part def _get_merged_statistic_points( @@ -112,3 +122,7 @@ def _get_merged_statistic_points( @staticmethod def _process_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, OVNNCFTensor]: return {n: OVNNCFTensor(v) for n, v in outputs.items()} + + @staticmethod + def _get_tensor_processor(): + return OVNNCFCollectorTensorProcessor diff --git a/nncf/torch/statistics/aggregator.py b/nncf/torch/statistics/aggregator.py index 4c57699cd3c..a3eb21cb0d4 100644 --- a/nncf/torch/statistics/aggregator.py +++ b/nncf/torch/statistics/aggregator.py @@ -67,3 +67,8 @@ def _get_merged_statistic_points( @staticmethod def _process_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, PTNNCFTensor]: return outputs + + def _get_tensor_processor(): + from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor + + return PTNNCFCollectorTensorProcessor