diff --git a/examples/post_training_quantization/onnx/mobilenet_v2/main.py b/examples/post_training_quantization/onnx/mobilenet_v2/main.py index 3e4687b8081..53b37da7755 100755 --- a/examples/post_training_quantization/onnx/mobilenet_v2/main.py +++ b/examples/post_training_quantization/onnx/mobilenet_v2/main.py @@ -104,6 +104,7 @@ def run_benchmark(path_to_model: str, shape: Optional[List[int]] = None, verbose # >> output_names = [output.name for output in sess.get_outputs()] # >> for data_item in val_loader: # >> sess.run(output_names, input_feed=transform_fn(data_item)) + input_name = model.graph.input[0].name diff --git a/examples/post_training_quantization/openvino/mozilla-deepspeech/accuracy_checker.json b/examples/post_training_quantization/openvino/mozilla-deepspeech/accuracy_checker.json new file mode 100644 index 00000000000..cdf14378be3 --- /dev/null +++ b/examples/post_training_quantization/openvino/mozilla-deepspeech/accuracy_checker.json @@ -0,0 +1,147 @@ +{ + "compression": { + "algorithms": [ + { + "name": "DefaultQuantization", + "params": { + "preset": "performance", + "stat_subset_size": 3 + } + } + ], + "dump_intermediate_model": true + }, + "engine": { + "datasets": [ + { + "metrics": [ + { + "type": "wer" + } + ], + "name": "LibriSpeech_test_clean_wav", + "data_source": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/librispeech/test/LibriSpeech/test-clean.wav", + + "annotation_conversion": { + "converter": "librispeech", + "data_dir": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/librispeech/test/LibriSpeech/test-clean.wav" + }, + "preprocessing": [ + { + "int16mode": true, + "type": "audio_normalization" + }, + { + "duration": "512 samples", + "overlap": "192 samples", + "type": "clip_audio" + }, + { + "base": 512, + "type": "hanning_window" + }, + { + "fftbase": 512, + "magnitude_squared": true, + "skip_channels": true, + "type": "audio_spectrogram" + }, + { + "base": 257, + "filterbank_channel_count": 40, + "lower_frequency_limit": 20, + "sample_rate": 16000, + "type": "audio_triangle_filtering", + "upper_frequency_limit": 4000 + }, + { + "filterbank_channel_count": 40, + "numceps": 26, + "type": "audio_dct" + }, + { + "context": 9, + "numceps": 26, + "type": "clip_cepstrum" + }, + { + "step": 16, + "type": "pack_cepstrum" + } + ], + "reader": "wav_reader" + } + ], + "launchers": [ + { + "adapter": { + "beam_size": 32, + "lm_alpha": 0.75, + "lm_beta": 1.05, + "lm_file": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/model_attributes/mozilla-deepspeech-0.6.1/lm.binary", + "lm_oov_score": -1000, + "lm_vocabulary_length": 4463723, + "lm_vocabulary_offset": 941235601, + "logarithmic_prob": false, + "probability_out": "logits", + "type": "ctc_beam_search_decoder_with_lm" + }, + "framework": "dlsdk", + "inputs": [ + { + "layout": "NHWC", + "name": "input_node", + "type": "INPUT" + }, + { + "name": "previous_state_c", + "type": "LSTM_INPUT", + "value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/BlockLSTM/TensorIterator.2" + }, + { + "name": "previous_state_h", + "type": "LSTM_INPUT", + "value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/BlockLSTM/TensorIterator.1" + } + ] + }, + { + "adapter": { + "beam_size": 32, + "lm_alpha": 0.75, + "lm_beta": 1.05, + "lm_file": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/model_attributes/mozilla-deepspeech-0.6.1/lm.binary", + "lm_oov_score": -1000, + "lm_vocabulary_length": 4463723, + "lm_vocabulary_offset": 941235601, + "logarithmic_prob": false, + "probability_out": "logits", + "type": "ctc_beam_search_decoder_with_lm" + }, + "framework": "openvino", + "inputs": [ + { + "layout": "NHWC", + "name": "input_node", + "type": "INPUT" + }, + { + "name": "previous_state_c", + "type": "LSTM_INPUT", + "value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd:0" + }, + { + "name": "previous_state_h", + "type": "LSTM_INPUT", + "value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd_1:0" + } + ] + } + ] + }, + "model": { + "model": "/mnt/omz/cv_bench_cache/ww18_weekly_23.0.0-10862-40bf400b189-API2.0/mozilla-deepspeech-0.6.1/tf/tf_frozen/FP16/1/dldt/mozilla-deepspeech-0.6.1.xml", + "model_name": "mozilla-deepspeech-0.6.1", + "weights": "/mnt/omz/cv_bench_cache/ww18_weekly_23.0.0-10862-40bf400b189-API2.0/mozilla-deepspeech-0.6.1/tf/tf_frozen/FP16/1/dldt/mozilla-deepspeech-0.6.1.bin" + } +} diff --git a/examples/post_training_quantization/openvino/mozilla-deepspeech/main.py b/examples/post_training_quantization/openvino/mozilla-deepspeech/main.py new file mode 100644 index 00000000000..627518098b8 --- /dev/null +++ b/examples/post_training_quantization/openvino/mozilla-deepspeech/main.py @@ -0,0 +1,69 @@ +import os +import subprocess + +import openvino.runtime as ov +from openvino.tools.accuracy_checker.evaluators.quantization_model_evaluator import create_model_evaluator +from openvino.tools.pot.configs.config import Config + +import nncf + +model_name = "mozilla-deepspeech-0.6.1" +cache_dir = os.path.dirname(__file__) +dataset_config = os.path.join(cache_dir, "accuracy_checker.json") + +command = f"omz_downloader --name {model_name} --cache_dir {cache_dir}" +cmd_output = subprocess.call(command, shell=True) # nosec + +model_dir = os.path.join(cache_dir, model_name) +if not os.path.exists(model_dir): + command = f"omz_converter --name {model_name} -o {os.path.join(cache_dir, model_name)}" + cmd_output = subprocess.call(command, shell=True) # nosec + +xml_path = os.path.join(model_dir, f"public/{model_name}/FP16/{model_name}.xml") +ov_model = ov.Core().read_model(xml_path) + +config = Config.read_config(dataset_config) +config.configure_params() +accuracy_checker_config = config.engine + +model_evaluator = create_model_evaluator(accuracy_checker_config) +model_evaluator.load_network([{"model": ov_model}]) +model_evaluator.select_dataset("") + + +def sequence_transform_fn(data_item): + """ + Quantization transform function. Extracts and preprocesses sequential inputs data from dataloader + for quantization, returns iterable on preprocessed elements of feeded data item. + + :param data_item: Data item produced by DataLoader during iteration + :return: Iterable object on preprocessed elements of feeded data item. + """ + return data_item + + +def get_custom_forward(model, callback): + def custom_forward(data_item): + def iter_through_sequence(): + _, batch_annotation, batch_input, _ = data_item + filled_inputs, _, _ = model_evaluator._get_batch_input(batch_input, batch_annotation) + for filled_input in filled_inputs: + input_data = {} + for name, value in filled_input.items(): + input_data[model_evaluator.launcher.input_to_tensor_name[name]] = value + yield input_data + + model_outputs = None + for model_inputs in iter_through_sequence(): + state_inputs = model_evaluator.launcher._fill_lstm_inputs(model_outputs) + model_inputs.update(state_inputs) + model_outputs = model(model_inputs) + callback(model_outputs) + + return custom_forward + + +dataset = nncf.CustomInferenceDataset(model_evaluator.dataset, sequence_transform_fn, get_custom_forward) + + +quantized_model = nncf.quantize(ov_model, dataset, subset_size=3) diff --git a/examples/post_training_quantization/openvino/tiny_gpt2/__init__.py b/examples/post_training_quantization/openvino/tiny_gpt2/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/post_training_quantization/openvino/tiny_gpt2/main.py b/examples/post_training_quantization/openvino/tiny_gpt2/main.py new file mode 100644 index 00000000000..c38b2e21f73 --- /dev/null +++ b/examples/post_training_quantization/openvino/tiny_gpt2/main.py @@ -0,0 +1,64 @@ +from optimum.intel.openvino import OVModelForCausalLM +from transformers import AutoTokenizer + +import nncf + +GENERATION_LENGTH = 20 + + +def fix_ov_model_names_duplicates(ov_model): + names = set() + for op in ov_model.get_ops(): + friendly_name = op.get_friendly_name() + while True: + if friendly_name not in names: + break + friendly_name += "_" + names.add(friendly_name) + op.set_friendly_name(friendly_name) + + +model_id = "hf-internal-testing/tiny-random-gpt2" +# model_id = "hf-internal-testing/tiny-random-GPTNeoModel" +# model_id = "hf-internal-testing/tiny-random-GPTNeoXForCausalLM" + +tokenizer = AutoTokenizer.from_pretrained(model_id) +tokens = tokenizer("This is a sample input", return_tensors="pt") + +model_with_pkv = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True) + + +def set_ov_model_in_hf_model(hf_model, ov_model): + hf_model.model = ov_model + hf_model.request = ov_model.create_infer_request() + + +def get_custom_forward(ov_model, callback_fn): + hf_model = model_with_pkv + set_ov_model_in_hf_model(hf_model, ov_model) + + def _callback_fn(info): + outputs = {k: v for k, v in zip(info["infer_request"].model_outputs, info["infer_request"].outputs)} + callback_fn(outputs) + + hf_model.request.set_callback(_callback_fn, {"infer_request": hf_model.request}) + + def custom_forward(dataitem): + hf_model.generate(**dataitem, min_length=GENERATION_LENGTH, max_length=GENERATION_LENGTH, num_beams=1) + + return custom_forward + + +def transform_fn(data_item): + return data_item + + +dataset = nncf.CustomInferenceDataset([tokens] * 10, transform_fn, get_custom_forward) + + +# Fix ov model duplicated names: +fix_ov_model_names_duplicates(model_with_pkv.model) +quantized_model = quantized_model = nncf.quantize(model_with_pkv.model, dataset, subset_size=3) + +model_with_pkv.model = quantized_model +model_with_pkv.request = None diff --git a/nncf/__init__.py b/nncf/__init__.py index 65a47864066..31f8bd6683c 100644 --- a/nncf/__init__.py +++ b/nncf/__init__.py @@ -14,6 +14,7 @@ from nncf.common.logging.logger import disable_logging from nncf.common.logging.logger import set_log_level from nncf.config import NNCFConfig +from nncf.data import CustomInferenceDataset from nncf.data import Dataset from nncf.parameters import DropType from nncf.parameters import ModelType diff --git a/nncf/common/factory.py b/nncf/common/factory.py index 063b9734f3a..2df04b34272 100644 --- a/nncf/common/factory.py +++ b/nncf/common/factory.py @@ -35,7 +35,7 @@ def create(model: TModel) -> NNCFGraph: from nncf.onnx.graph.nncf_graph_builder import GraphConverter return GraphConverter.create_nncf_graph(model) - if model_backend == BackendType.OPENVINO: + if model_backend in [BackendType.OPENVINO]: from nncf.openvino.graph.nncf_graph_builder import GraphConverter return GraphConverter.create_nncf_graph(model) diff --git a/nncf/common/tensor_statistics/aggregator.py b/nncf/common/tensor_statistics/aggregator.py index 9afc180aeda..41bd54c0442 100644 --- a/nncf/common/tensor_statistics/aggregator.py +++ b/nncf/common/tensor_statistics/aggregator.py @@ -10,9 +10,11 @@ # limitations under the License. from abc import ABC from abc import abstractmethod +from collections import defaultdict from itertools import islice from typing import Any, Dict, TypeVar +import numpy as np from tqdm import tqdm from nncf.common.factory import EngineFactory @@ -20,6 +22,7 @@ from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.tensor import NNCFTensor from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer +from nncf.data.dataset import CustomInferenceDataset from nncf.data.dataset import Dataset TensorType = TypeVar("TensorType") @@ -31,10 +34,13 @@ class StatisticsAggregator(ABC): Base class for statistics collection. """ + STACK_AXIS = 0 + def __init__(self, dataset: Dataset): self.dataset = dataset self.stat_subset_size = None self.statistic_points = StatisticPointsContainer() + self._is_custom_inference = isinstance(dataset, CustomInferenceDataset) def collect_statistics(self, model: TModel) -> None: """ @@ -46,19 +52,37 @@ def collect_statistics(self, model: TModel) -> None: model_transformer = ModelTransformerFactory.create(model) merged_statistics = self._get_merged_statistic_points(self.statistic_points, model) + if self._is_custom_inference: + merged_statistics = self._adapt_collectors(merged_statistics, self.STACK_AXIS) + transformation_layout = self._get_transformation_layout_extra_outputs(merged_statistics) model_with_outputs = model_transformer.transform(transformation_layout) engine = EngineFactory.create(model_with_outputs) + if self._is_custom_inference: + sequence_container = defaultdict(list) + custom_forward = self.dataset.get_custom_forward( + engine.compiled_model, self._get_callback(model, sequence_container) + ) for input_data in tqdm( islice(self.dataset.get_inference_data(), self.stat_subset_size), total=self.stat_subset_size, desc="Statistics collection", ): - outputs = engine.infer(input_data) - processed_outputs = self._process_outputs(outputs) + if self._is_custom_inference: + custom_forward(input_data) + processed_outputs = {} + for friendly_name, values in sequence_container.items(): + processed_outputs[friendly_name] = self._get_tensor_processor().stack(values, axis=self.STACK_AXIS) + else: + processed_outputs = engine.infer(input_data) + processed_outputs = self._process_outputs(processed_outputs) self._register_statistics(processed_outputs, merged_statistics) + @staticmethod + def _get_callback(model, sequence_container: StatisticPointsContainer): + pass + def register_statistic_points(self, statistic_points: StatisticPointsContainer) -> None: """ Register statistic points for statistics collection and recalculates the maximum number samples @@ -115,6 +139,10 @@ def _get_merged_statistic_points( :return: Merged statistic points container bounded with given statistic point container. """ + @staticmethod + def _adapt_collectors(statistic_points: StatisticPointsContainer, stack_axis: int): + return statistic_points + @staticmethod @abstractmethod def _process_outputs(outputs: Any) -> Dict[str, NNCFTensor]: @@ -124,3 +152,8 @@ def _process_outputs(outputs: Any) -> Dict[str, NNCFTensor]: :param outputs: raw model outputs :return: processed model outputs in Dict[str, NNCFTensor] format """ + + @staticmethod + @abstractmethod + def _get_tensor_processor(): + pass diff --git a/nncf/common/tensor_statistics/statistics.py b/nncf/common/tensor_statistics/statistics.py index 194c69ba56e..78964f14285 100644 --- a/nncf/common/tensor_statistics/statistics.py +++ b/nncf/common/tensor_statistics/statistics.py @@ -17,7 +17,7 @@ TensorType = TypeVar("TensorType") -class TensorStatistic(ABC): +class TensorStatisticBase(ABC): """Base class that stores statistic data""" @staticmethod @@ -30,7 +30,7 @@ def __eq__(self, other): pass -class MinMaxTensorStatistic(TensorStatistic): +class MinMaxTensorStatistic(TensorStatisticBase): MIN_STAT = "min_values" MAX_STAT = "max_values" @@ -42,7 +42,7 @@ def __eq__(self, other: "MinMaxTensorStatistic") -> bool: return self.tensor_eq(self.min_values, other.min_values) and self.tensor_eq(self.max_values, other.max_values) -class MeanTensorStatistic(TensorStatistic): +class MeanTensorStatistic(TensorStatisticBase): MEAN_STAT = "mean_values" SHAPE_STAT = "shape" @@ -62,7 +62,7 @@ def __eq__(self, other: "MeanTensorStatistic") -> bool: return self.tensor_eq(self.mean_values, other.mean_values) and self.tensor_eq(self.shape, other.shape) -class MedianMADTensorStatistic(TensorStatistic): +class MedianMADTensorStatistic(TensorStatisticBase): def __init__(self, median_values, mad_values): self.median_values = median_values self.mad_values = mad_values @@ -73,7 +73,7 @@ def __eq__(self, other: "MedianMADTensorStatistic") -> bool: ) -class PercentileTensorStatistic(TensorStatistic): +class PercentileTensorStatistic(TensorStatisticBase): def __init__(self, percentile_vs_values_dict): self.percentile_vs_values_dict = percentile_vs_values_dict @@ -86,7 +86,7 @@ def __eq__(self, other: "PercentileTensorStatistic", rtol=1e-9) -> bool: return True -class BatchTensorStatistic(TensorStatistic): +class BatchTensorStatistic(TensorStatisticBase): VALUES_STATS = "values" """ diff --git a/nncf/data/__init__.py b/nncf/data/__init__.py index 97fb8e39dfd..94ef9ee282b 100644 --- a/nncf/data/__init__.py +++ b/nncf/data/__init__.py @@ -9,4 +9,5 @@ # See the License for the specific language governing permissions and # limitations under the License. +from nncf.data.dataset import CustomInferenceDataset from nncf.data.dataset import Dataset diff --git a/nncf/data/dataset.py b/nncf/data/dataset.py index 6bf1322c2ec..7ed84942af2 100644 --- a/nncf/data/dataset.py +++ b/nncf/data/dataset.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Generic, Iterable, List, Optional, TypeVar +from typing import Any, Callable, Generic, Iterable, List, Optional, TypeVar from nncf.common.utils.api_marker import api @@ -115,3 +115,10 @@ def _get_iterator_for_iter( if idx == indices[pos]: pos = pos + 1 yield transform_func(data_item) + + +@api(canonical_alias="nncf.RecurentDataset") +class CustomInferenceDataset(Dataset): + def __init__(self, data_source: Iterable, transform_fn, get_custom_forward_fn): + super().__init__(data_source, transform_fn) + self.get_custom_forward = get_custom_forward_fn diff --git a/nncf/experimental/common/tensor_statistics/collectors.py b/nncf/experimental/common/tensor_statistics/collectors.py index e175d550b99..98d708a8891 100644 --- a/nncf/experimental/common/tensor_statistics/collectors.py +++ b/nncf/experimental/common/tensor_statistics/collectors.py @@ -11,6 +11,7 @@ from abc import ABC from abc import abstractmethod +from abc import abstractstaticmethod from collections import defaultdict from collections import deque from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union @@ -19,28 +20,15 @@ from nncf.common.tensor_statistics.collectors import NNCFCollectorTensorProcessor from nncf.common.tensor_statistics.collectors import NNCFTensor from nncf.common.tensor_statistics.collectors import ReductionShape -from nncf.common.tensor_statistics.statistics import TensorStatistic from nncf.quantization.advanced_parameters import AggregatorType InplaceInsertionFNType = TypeVar("InplaceInsertionFNType") -class TensorReducerBase(ABC): - """ - Tensor reducer is a callable object that reduces tensors according to - the specified rule. Could handle tensors inplace or out of place. - """ - - def __init__(self, reduction_shape: Optional[ReductionShape] = None, inplace: bool = False): - """ - :param reduction_shape: Reduction shape for reduction calculation. Equal to list(range(len(input.shape))) - if empty. - :param: Wheather should be calculated inplace or out of place. - - """ - self._reduction_shape = reduction_shape - self._tensor_processor: NNCFCollectorTensorProcessor = self._get_processor() +class TensorStatisticBase(ABC): + def __init__(self, inplace: bool): self._inplace = inplace + self._tensor_processor = self._get_processor() @property def inplace(self): @@ -54,19 +42,10 @@ def output_port_id(self) -> int: def name(self): return self.__class__.__name__ + str(self.__hash__()) - @staticmethod - @abstractmethod + @abstractstaticmethod def _get_processor() -> NNCFCollectorTensorProcessor: pass - @abstractmethod - def _reduce_out_of_place(self, x: List[TensorType]) -> List[TensorType]: - """ - Specifies the reduction rule in terms of NNCFCollectorTensorProcessor. - - :param x: Tensor to register. - """ - @abstractmethod def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: """ @@ -86,15 +65,74 @@ def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]: :return: Inplace operation builder if possible else None. """ - def __call__(self, x: List[NNCFTensor]): + @abstractstaticmethod + def __call__(self, x: List[NNCFTensor], adapter) -> Any: + pass + + def __hash__(self) -> int: + return hash((self.__class__.__name__, self.inplace)) + + def __eq__(self, __o: object) -> bool: + return isinstance(__o, self.__class__) and self._inplace == __o.inplace + + +class TensorReducerBase(TensorStatisticBase, ABC): + """ + Tensor reducer is a callable object that reduces tensors according to + the specified rule. Could handle tensors inplace or out of place. + """ + + def __init__(self, reduction_shape: Optional[ReductionShape] = None, inplace: bool = False, keepdims: bool = True): + """ + :param reduction_shape: Reduction shape for reduction calculation. Equal to list(range(len(input.shape))) + if empty. + :param inplace: Wheather should be calculated inplace or out of place. + """ + super().__init__(inplace) + self._reduction_shape = reduction_shape + self._keepdims = keepdims + + @abstractstaticmethod + def _reduce_out_of_place(x: List[NNCFTensor], **kwargs) -> List[NNCFTensor]: + """ + Specifies the reduction rule in terms of NNCFCollectorTensorProcessor. + + :param x: Tensor to register. + """ + + def _reduce_default(self, x: List[NNCFTensor]) -> List[NNCFTensor]: if self.inplace: return x - - return self._reduce_out_of_place(x) + kwargs = self._get_kwargs() + kwargs["reduction_shape"] = self._get_reduction_shape(x[0]) + return self._reduce_out_of_place(x, **kwargs) + + def _reduce_sequential(self, x: List[NNCFTensor], stack_axis: Optional[int] = None) -> List[NNCFTensor]: + kwargs = self._get_kwargs() + if not self.inplace: + kwargs["reduction_shape"] = self._get_reduction_shape(x[0]) + x = self._reduce_out_of_place(x, **kwargs) + kwargs.update({"reduction_shape": stack_axis, "keepdims": False}) + return self._reduce_out_of_place(x, **kwargs) + + def __call__(self, x: List[NNCFTensor], adapter) -> Any: + # TODO: remove isinstance, use something else + # like registry + if isinstance(adapter, SequentialTensorCollectorAdapter): + return self._reduce_sequential(x, adapter.stack_axis) + return self._reduce_default(x) + + def _get_kwargs(self): + return { + "reduction_shape": self._reduction_shape, + "tensor_processor": self._tensor_processor, + "keepdims": self._keepdims, + } def __eq__(self, __o: object) -> bool: return ( isinstance(__o, self.__class__) + and self._keepdims == __o._keepdims and self._reduction_shape == __o._reduction_shape and self._inplace == __o.inplace ) @@ -108,6 +146,55 @@ def _get_reduction_shape(self, tensor: NNCFTensor) -> Union[int, Tuple[int, ...] return tuple(range(len(tensor.shape))) +class TensorStatisticsSequence(TensorStatisticBase): + def __init__(self, *args): + if any(statistic.inplace for statistic in args[1:]): + raise RuntimeError(f"Only first statistic of sequential tensor statistic could not be inplace.") + self._statistics = args + + @property + def inplace(self): + return self._statistics[0].inplace + + @property + def output_port_id(self) -> int: + return self._statistics[0].output_port_id + + @property + def name(self): + name = "" + for i, statistic in enumerate(self._statistics): + name += f"{i}_{statistic.name}" + return name + + def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: + return self._statistics[0].get_output_names(target_node_name, port_id) + + def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]: + return self._statistics[0].get_inplace_fn() + + def __call__(self, x: List[NNCFTensor], adapter): + # Ignore feeded adapter + # Could be configurable + adapter = DefaultTensorCollectorAdapter() + if not self._statistics[0].inplace: + x = self._statistics[0](x, adapter) + + for statistic in self._statistics[1:]: + x = statistic(x, adapter) + return x + + def __eq__(self, __o: object) -> bool: + return ( + isinstance(__o, self.__class__) + and len(self._statistics) == len(__o._statistics) + and all(self_r == o_r for self_r, o_r in zip(self._statistics, __o._statistics)) + ) + + def __hash__(self) -> int: + return hash(tuple(hash(statistic) for statistic in self._statistics)) + + class TensorAggregatorBase: """ Tensor aggregator is designed to recieve (register) calculated statistics and @@ -131,14 +218,14 @@ def __init__(self, tensor_processor: NNCFCollectorTensorProcessor, num_samples: def num_samples(self) -> int: return self._num_samples - def register_reduced_input(self, x: TensorType): + def register_reduced_input(self, x: NNCFTensor, adapter): if self._num_samples is not None and self._collected_samples >= self._num_samples: return - self._register_reduced_input_impl(x) + self._register_reduced_input_impl(x, adapter) self._collected_samples += 1 @abstractmethod - def _register_reduced_input_impl(self, x: TensorType) -> None: + def _register_reduced_input_impl(self, x: TensorType, adapter) -> None: """ Registers incoming tensor in tensor aggregator. @@ -164,6 +251,20 @@ def __hash__(self) -> int: return hash(self.__class__.__name__) +class BaseTensorCollectorAdapter: + pass + + +class DefaultTensorCollectorAdapter(BaseTensorCollectorAdapter): + pass + + +class SequentialTensorCollectorAdapter(DefaultTensorCollectorAdapter): + def __init__(self, stack_axis: int) -> None: + super().__init__() + self.stack_axis = stack_axis + + class TensorCollector: """ Calculates statistics at given tensors according to registered statistic branches. @@ -174,12 +275,13 @@ class TensorCollector: a dict could be collected by `get_statistics` call. """ - def __init__(self, statistic_container: Optional[TensorStatistic] = None) -> None: + def __init__(self, statistic_container: Optional[TensorStatisticBase] = None) -> None: self._reducers: Set[TensorReducerBase] = set() self._aggregators: Dict[Tuple[int, int], TensorAggregatorBase] = {} self._stat_container_kwargs_map: Dict[str, Tuple[int, int]] = {} self._stat_container = statistic_container self._enabled = True + self._adapter = DefaultTensorCollectorAdapter() @property def num_samples(self) -> Optional[int]: @@ -203,6 +305,9 @@ def reducers(self): def aggregators(self): return self._aggregators.copy() + def set_adapter(self, adapter: "BaseTensorCollectorAdapter"): + self._adapter = adapter + def enable(self): self._enabled = True @@ -270,14 +375,14 @@ def register_inputs(self, inputs: Dict[int, List[NNCFTensor]]) -> None: input_ = inputs[reducer_hash] if any([tensor.is_empty() for tensor in input_]): continue - reduced_inputs[reducer_hash] = reducer(input_) + reduced_inputs[reducer_hash] = reducer(input_, self._adapter) for ( (reducer_hash, reducer_port_id, _), aggregator, ) in self._aggregators.items(): if reducer_hash in reduced_inputs: - aggregator.register_reduced_input(reduced_inputs[reducer_hash][reducer_port_id]) + aggregator.register_reduced_input(reduced_inputs[reducer_hash][reducer_port_id], self._adapter) def _aggregate(self) -> None: result = {} @@ -289,7 +394,7 @@ def _aggregate(self) -> None: result[key] = val return result - def get_statistics(self) -> Union[TensorStatistic, Dict[str, Any]]: + def get_statistics(self) -> Union[TensorStatisticBase, Dict[str, Any]]: """ Returns aggregated values in format of a TensorStatistic instance or a dict. @@ -395,7 +500,7 @@ def __init__(self, tensor_collectors: List[TensorCollector]) -> None: ##################################################Reducers################################################## -class NoopReducer(TensorReducerBase): +class NoopStatistic(TensorStatisticBase): def __init__(self): super().__init__(inplace=False) @@ -406,36 +511,53 @@ def _get_processor() -> NNCFCollectorTensorProcessor: def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]: return None - def _reduce_out_of_place(self, x: List[TensorType]) -> List[TensorType]: + def __call__(self, x: List[NNCFTensor], adapter) -> Any: return x class MinReducer(TensorReducerBase): - def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: - x = x[0] - reduction_shape = self._get_reduction_shape(x) - return [self._tensor_processor.reduce_min(x, reduction_shape, keepdims=True)] + @staticmethod + def _reduce_out_of_place( + x: List[NNCFTensor], + tensor_processor: NNCFCollectorTensorProcessor, + reduction_shape: Tuple[int, ...], + keepdims: bool, + ) -> List[NNCFTensor]: + return [tensor_processor.reduce_min(x[0], reduction_shape, keepdims=keepdims)] class MaxReducer(TensorReducerBase): - def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: - x = x[0] - reduction_shape = self._get_reduction_shape(x) - return [self._tensor_processor.reduce_max(x, reduction_shape, keepdims=True)] + @staticmethod + def _reduce_out_of_place( + x: List[NNCFTensor], + tensor_processor: NNCFCollectorTensorProcessor, + reduction_shape: Tuple[int, ...], + keepdims: bool, + ) -> List[NNCFTensor]: + return [tensor_processor.reduce_max(x[0], reduction_shape, keepdims=keepdims)] class AbsMaxReducer(TensorReducerBase): - def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: - x = self._tensor_processor.abs(x[0]) - reduction_shape = self._get_reduction_shape(x) - return [self._tensor_processor.reduce_max(x, reduction_shape, keepdims=True)] + @staticmethod + def _reduce_out_of_place( + x: List[NNCFTensor], + tensor_processor: NNCFCollectorTensorProcessor, + reduction_shape: Tuple[int, ...], + keepdims: bool, + ) -> List[NNCFTensor]: + x = tensor_processor.abs(x[0]) + return [tensor_processor.reduce_max(x, reduction_shape, keepdims=keepdims)] class MeanReducer(TensorReducerBase): - def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: - x = x[0] - reduction_shape = self._get_reduction_shape(x) - return [self._tensor_processor.mean(x, reduction_shape, keepdims=True)] + @staticmethod + def _reduce_out_of_place( + x: List[NNCFTensor], + tensor_processor: NNCFCollectorTensorProcessor, + reduction_shape: Tuple[int, ...], + keepdims: bool, + ) -> List[NNCFTensor]: + return [tensor_processor.mean(x[0], reduction_shape, keepdims=keepdims)] class QuantileReducerBase(TensorReducerBase): @@ -444,22 +566,41 @@ def __init__( reduction_shape: Optional[ReductionShape] = None, quantile: Union[float, List[float]] = [0.01, 0.99], inplace: bool = False, + keepdims: bool = True, ): - super().__init__(reduction_shape, False) + super().__init__(reduction_shape, False, keepdims) self._quantile = quantile + def _get_kwargs(self): + kwargs = super()._get_kwargs() + kwargs["quantile"] = self._quantile + return kwargs + def __eq__(self, __o: object) -> bool: return super().__eq__(__o) and self._quantile == __o._quantile def __hash__(self) -> int: - return hash((self.__class__.__name__, self.inplace, self._reduction_shape, tuple(self._quantile))) + return hash( + ( + self.__class__.__name__, + self.inplace, + self._reduction_shape, + self._keepdims, + tuple(self._quantile), + ) + ) class QuantileReducer(QuantileReducerBase): - def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: - x = x[0] - reduction_shape = self._get_reduction_shape(x) - return self._tensor_processor.quantile(x, self._quantile, reduction_shape, keepdims=True) + @staticmethod + def _reduce_out_of_place( + x: List[NNCFTensor], + tensor_processor: NNCFCollectorTensorProcessor, + reduction_shape: Tuple[int, ...], + quantile: Union[float, List[float]], + keepdims: bool, + ) -> List[NNCFTensor]: + return tensor_processor.quantile(x[0], quantile, reduction_shape, keepdims=keepdims) class AbsQuantileReducer(QuantileReducerBase): @@ -471,26 +612,39 @@ def __init__( ): super().__init__(reduction_shape, quantile, False) - def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: - x = self._tensor_processor.abs(x[0]) - reduction_shape = self._get_reduction_shape(x) - return self._tensor_processor.quantile(x, [self._quantile], reduction_shape, keepdims=True) - - -class BatchMeanReducer(TensorReducerBase): + @staticmethod + def _reduce_out_of_place( + x: List[NNCFTensor], + tensor_processor: NNCFCollectorTensorProcessor, + reduction_shape: Tuple[int, ...], + quantile: Union[float, List[float]], + keepdims: bool, + ) -> List[NNCFTensor]: + x = tensor_processor.abs(x[0]) + return tensor_processor.quantile(x, [quantile], reduction_shape, keepdims=keepdims) + + +class BatchMeanStatistic(TensorStatisticBase): def __init__(self, inplace: bool = False): - super().__init__(None, inplace) + super().__init__(inplace) - def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: + def __call__(self, x: List[NNCFTensor], adapter) -> List[NNCFTensor]: return [self._tensor_processor.batch_mean(x[0])] -class MeanPerChReducer(TensorReducerBase): +class MeanPerChReducer(TensorStatisticBase): def __init__(self, channel_dim: int = 1, inplace: bool = False): - super().__init__(channel_dim, inplace) + super().__init__(inplace) + self._channel_dim = channel_dim + + def __call__(self, x: List[NNCFTensor], adapter) -> List[NNCFTensor]: + return [self._tensor_processor.mean_per_channel(x[0], self._channel_dim)] - def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: - return [self._tensor_processor.mean_per_channel(x[0], self._reduction_shape)] + def __hash__(self) -> int: + return hash((self.__class__.__name__, self._inplace, self._channel_dim)) + + def __eq__(self, __o: object) -> bool: + return super().__eq__(__o) and self._channel_dim == __o._channel_dim ##################################################Aggregators################################################## @@ -500,7 +654,7 @@ class NoopAggregator(TensorAggregatorBase): def __init__(self, num_samples: Optional[int]): super().__init__(None, num_samples) - def _register_reduced_input_impl(self, x: TensorType) -> None: + def _register_reduced_input_impl(self, x: TensorType, adapter) -> None: self._container.append(x.tensor) def aggregate(self): @@ -508,18 +662,23 @@ def aggregate(self): class ShapeAggregator(TensorAggregatorBase): - def __init__(self): - super().__init__(None, 1) + def __init__(self, tensor_processor: NNCFCollectorTensorProcessor): + super().__init__(tensor_processor, 1) - def _register_reduced_input_impl(self, x: TensorType) -> None: - self._container = x + def _register_reduced_input_impl(self, x: TensorType, adapter) -> None: + if isinstance(adapter, SequentialTensorCollectorAdapter): + self._container = self._tensor_processor.unstack(x)[0] + elif isinstance(adapter, DefaultTensorCollectorAdapter): + self._container = x + else: + raise RuntimeError() def aggregate(self): return self._container.shape class MinAggregator(TensorAggregatorBase): - def _register_reduced_input_impl(self, x: TensorType) -> None: + def _register_reduced_input_impl(self, x: TensorType, adapter) -> None: if not self._container: self._container = x else: @@ -530,7 +689,7 @@ def aggregate(self): class MaxAggregator(TensorAggregatorBase): - def _register_reduced_input_impl(self, x: TensorType) -> None: + def _register_reduced_input_impl(self, x: TensorType, adapter) -> None: if not self._container: self._container = x else: @@ -549,7 +708,7 @@ def __init__( self._container = deque(maxlen=window_size) self._use_per_sample_stats = use_per_sample_stats - def _register_reduced_input_impl(self, x: TensorType) -> None: + def _register_reduced_input_impl(self, x: TensorType, adapter) -> None: if self._use_per_sample_stats: self._container.extend(self._tensor_processor.unstack(x)) else: diff --git a/nncf/onnx/statistics/aggregator.py b/nncf/onnx/statistics/aggregator.py index e48267ac136..3253dd3ec2b 100644 --- a/nncf/onnx/statistics/aggregator.py +++ b/nncf/onnx/statistics/aggregator.py @@ -79,3 +79,8 @@ def _get_merged_statistic_points( @staticmethod def _process_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, ONNXNNCFTensor]: return {n: ONNXNNCFTensor(v) for n, v in outputs.items()} + + def _get_tensor_processor(): + from nncf.onnx.statistics.collectors import ONNXNNCFCollectorTensorProcessor + + return ONNXNNCFCollectorTensorProcessor diff --git a/nncf/openvino/statistics/aggregator.py b/nncf/openvino/statistics/aggregator.py index 60c7bab3e80..f31aead486e 100644 --- a/nncf/openvino/statistics/aggregator.py +++ b/nncf/openvino/statistics/aggregator.py @@ -21,10 +21,12 @@ from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.experimental.common.tensor_statistics.collectors import MergedTensorCollector +from nncf.experimental.common.tensor_statistics.collectors import SequentialTensorCollectorAdapter from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.openvino.graph.nncf_graph_builder import GraphConverter from nncf.openvino.graph.transformations.commands import OVInplaceFnInsertionCommand from nncf.openvino.graph.transformations.commands import OVOutputInsertionCommand +from nncf.openvino.statistics.collectors import OVNNCFCollectorTensorProcessor from nncf.openvino.tensor import OVNNCFTensor @@ -72,6 +74,14 @@ def _get_transformation_layout_extra_outputs( return transformation_layout + @staticmethod + # TODO(dlyakhov) Move this to common part + def _adapt_collectors(statistic_points: StatisticPointsContainer, stack_axis: int): + for _, _, tensor_collector in statistic_points.get_tensor_collectors(): + tensor_collector.set_adapter(SequentialTensorCollectorAdapter(stack_axis)) + + return statistic_points + @staticmethod # TODO(dlyakhov) Move this to common part def _get_merged_statistic_points( @@ -109,6 +119,24 @@ def _get_merged_statistic_points( merged_statistic_points.add_statistic_point(stat_point) return merged_statistic_points + @staticmethod + def _get_callback(model, sequence_container): + original_model_outputs_names = {op.node.friendly_name for op in model.outputs} + + def complition_callback(outputs): + for op, value in outputs.items(): + if op.node.friendly_name in original_model_outputs_names: + continue + if not isinstance(value, np.ndarray): + value = value.data + sequence_container[op.node.friendly_name].append(OVNNCFTensor(value)) + + return complition_callback + @staticmethod def _process_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, OVNNCFTensor]: return {n: OVNNCFTensor(v) for n, v in outputs.items()} + + @staticmethod + def _get_tensor_processor(): + return OVNNCFCollectorTensorProcessor diff --git a/nncf/openvino/statistics/collectors.py b/nncf/openvino/statistics/collectors.py index 99c48ea483e..fee5c7ea8fe 100644 --- a/nncf/openvino/statistics/collectors.py +++ b/nncf/openvino/statistics/collectors.py @@ -18,7 +18,7 @@ from nncf.common.tensor_statistics.collectors import NNCFCollectorTensorProcessor from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer from nncf.experimental.common.tensor_statistics.collectors import AbsQuantileReducer -from nncf.experimental.common.tensor_statistics.collectors import BatchMeanReducer +from nncf.experimental.common.tensor_statistics.collectors import BatchMeanStatistic from nncf.experimental.common.tensor_statistics.collectors import InplaceInsertionFNType from nncf.experimental.common.tensor_statistics.collectors import MaxReducer from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator @@ -26,7 +26,7 @@ from nncf.experimental.common.tensor_statistics.collectors import MeanReducer from nncf.experimental.common.tensor_statistics.collectors import MinReducer from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator -from nncf.experimental.common.tensor_statistics.collectors import NoopReducer +from nncf.experimental.common.tensor_statistics.collectors import NoopStatistic from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator from nncf.experimental.common.tensor_statistics.collectors import TensorCollector @@ -147,7 +147,7 @@ def quantile( return [OVNNCFTensor(x) for x in result] -class OVNoopReducer(NoopReducer): +class OVNoopStatistic(NoopStatistic): def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: return [get_result_node_name(target_node_name, port_id)] @@ -196,7 +196,7 @@ def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace) -class OVBatchMeanReducer(BatchMeanReducer): +class OVBatchMeanReducer(BatchMeanStatistic): def _get_processor(self): return OVNNCFCollectorTensorProcessor @@ -248,7 +248,7 @@ def get_mean_stat_collector(num_samples, channel_axis, window_size=None, inplace reducer = OVBatchMeanReducer(inplace) else: reducer = OVMeanPerChanelReducer(channel_axis, inplace) - noop_reducer = OVNoopReducer() + noop_reducer = OVNoopStatistic() kwargs = { "tensor_processor": OVNNCFCollectorTensorProcessor, @@ -257,7 +257,7 @@ def get_mean_stat_collector(num_samples, channel_axis, window_size=None, inplace "window_size": window_size, } aggregate_mean = MeanAggregator(**kwargs) - aggregate_shape = ShapeAggregator() + aggregate_shape = ShapeAggregator(tensor_processor=OVNNCFCollectorTensorProcessor) collector = TensorCollector(OVMeanTensorStatistic) collector.register_statistic_branch(OVMeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean) diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index 953fda6d579..e4d8001c62d 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -212,7 +212,7 @@ def _set_backend_entity(self, model: TModel) -> None: from nncf.quantization.algorithms.min_max.onnx_backend import ONNXMinMaxAlgoBackend self._backend_entity = ONNXMinMaxAlgoBackend() - elif model_backend == BackendType.OPENVINO: + elif model_backend in [BackendType.OPENVINO]: from nncf.quantization.algorithms.min_max.openvino_backend import OVMinMaxAlgoBackend self._backend_entity = OVMinMaxAlgoBackend() diff --git a/nncf/quantization/algorithms/post_training/algorithm.py b/nncf/quantization/algorithms/post_training/algorithm.py index 60ab4b02edf..efddabdf117 100644 --- a/nncf/quantization/algorithms/post_training/algorithm.py +++ b/nncf/quantization/algorithms/post_training/algorithm.py @@ -156,7 +156,7 @@ def _create_statistics_aggregator(self, dataset: Dataset, backend: BackendType) from nncf.onnx.statistics.aggregator import ONNXStatisticsAggregator return ONNXStatisticsAggregator(dataset) - if backend == BackendType.OPENVINO: + if backend in [BackendType.OPENVINO]: from nncf.openvino.statistics.aggregator import OVStatisticsAggregator return OVStatisticsAggregator(dataset) diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index 44d8ba84bfe..b09519254a9 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -72,7 +72,7 @@ def quantize( :rtype: TModel """ backend = get_backend(model) - if backend == BackendType.OPENVINO: + if backend in [BackendType.OPENVINO]: from nncf.openvino.quantization.quantize_model import quantize_impl return quantize_impl( diff --git a/nncf/tensorflow/tensor_statistics/statistics.py b/nncf/tensorflow/tensor_statistics/statistics.py index 724c71019e4..f26fed7e1d9 100644 --- a/nncf/tensorflow/tensor_statistics/statistics.py +++ b/nncf/tensorflow/tensor_statistics/statistics.py @@ -14,7 +14,7 @@ from nncf.common.tensor_statistics.statistics import MedianMADTensorStatistic from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic from nncf.common.tensor_statistics.statistics import PercentileTensorStatistic -from nncf.common.tensor_statistics.statistics import TensorStatistic +from nncf.common.tensor_statistics.statistics import TensorStatisticBase class TFMinMaxTensorStatistic(MinMaxTensorStatistic): @@ -35,7 +35,7 @@ def tensor_eq(tensor1: tf.Tensor, tensor2: tf.Tensor, rtol=1e-6) -> bool: return bool(tf.experimental.numpy.allclose(tensor1, tensor2, rtol=rtol)) -def tf_convert_stat_to_min_max_tensor_stat(statistic: TensorStatistic) -> TFMinMaxTensorStatistic: +def tf_convert_stat_to_min_max_tensor_stat(statistic: TensorStatisticBase) -> TFMinMaxTensorStatistic: if isinstance(statistic, TFMinMaxTensorStatistic): return statistic if isinstance(statistic, TFMedianMADTensorStatistic): diff --git a/nncf/torch/quantization/algo.py b/nncf/torch/quantization/algo.py index 64b1966503e..f96f9ee4854 100644 --- a/nncf/torch/quantization/algo.py +++ b/nncf/torch/quantization/algo.py @@ -141,7 +141,7 @@ from nncf.torch.tensor_statistics.algo import TensorStatisticsCollectionBuilder from nncf.torch.tensor_statistics.collectors import ReductionShape from nncf.torch.tensor_statistics.statistics import MinMaxTensorStatistic -from nncf.torch.tensor_statistics.statistics import TensorStatistic +from nncf.torch.tensor_statistics.statistics import TensorStatisticBase from nncf.torch.tensor_statistics.statistics import pt_convert_stat_to_min_max_tensor_stat from nncf.torch.utils import get_model_device from nncf.torch.utils import get_model_dtype @@ -583,7 +583,7 @@ def _parse_precision_init_params(self, initializer_config: Dict) -> Tuple[str, B def _get_minmax_values_for_quantizer_locations( self, quantizer_setup: SingleConfigQuantizerSetup, - tensor_statistics: Dict[PTTargetPoint, Dict[ReductionShape, TensorStatistic]], + tensor_statistics: Dict[PTTargetPoint, Dict[ReductionShape, TensorStatisticBase]], target_model_graph: PTNNCFGraph, ) -> Dict[QuantizationPointId, MinMaxTensorStatistic]: retval = {} @@ -663,7 +663,7 @@ def _get_transformation_layout(self, target_model: NNCFNetwork) -> PTTransformat @staticmethod def get_statistics_for_quantizer_setup( target_model: NNCFNetwork, quantizer_setup: QuantizerSetupBase, range_init_params: PTRangeInitParams - ) -> Dict[PTTargetPoint, Dict[ReductionShape, TensorStatistic]]: + ) -> Dict[PTTargetPoint, Dict[ReductionShape, TensorStatisticBase]]: if range_init_params is None: return {} observation_points_vs_collectors_dict = ( @@ -687,7 +687,7 @@ def get_statistics_for_quantizer_setup( def _get_statistics_for_final_range_init( self, target_model: NNCFNetwork, quantizer_setup: QuantizerSetupBase, range_init_params: PTRangeInitParams - ) -> Dict[PTTargetPoint, Dict[ReductionShape, TensorStatistic]]: + ) -> Dict[PTTargetPoint, Dict[ReductionShape, TensorStatisticBase]]: return self.get_statistics_for_quantizer_setup(target_model, quantizer_setup, range_init_params) def _get_single_config_quantizer_setup(self, target_model) -> SingleConfigQuantizerSetup: @@ -1670,7 +1670,7 @@ def __init__( self, quantizer_setup: MultiConfigQuantizerSetup, initial_quantizer_setup: SingleConfigQuantizerSetup, - tensor_stats_for_all_setup_variations: Dict[PTTargetPoint, Dict[ReductionShape, TensorStatistic]], + tensor_stats_for_all_setup_variations: Dict[PTTargetPoint, Dict[ReductionShape, TensorStatisticBase]], hw_config: HWConfig = None, ): should_init = bool(tensor_stats_for_all_setup_variations) @@ -1689,7 +1689,7 @@ def _get_single_config_quantizer_setup(self, target_model) -> SingleConfigQuanti def _get_statistics_for_final_range_init( self, target_model: NNCFNetwork, quantizer_setup: QuantizerSetupBase, range_init_params: PTRangeInitParams - ) -> Dict[PTTargetPoint, Dict[ReductionShape, TensorStatistic]]: + ) -> Dict[PTTargetPoint, Dict[ReductionShape, TensorStatisticBase]]: return self._tensor_stats def _build_controller(self, model: NNCFNetwork) -> "ExperimentalQuantizationController": @@ -1745,7 +1745,7 @@ def __init__( quantizer_setup: MultiConfigQuantizerSetup, initial_quantizer_setup: SingleConfigQuantizerSetup, setup_to_module_id_translation_dict: Dict[QuantizationPointId, QuantizerId], - tensor_stats: Dict[PTTargetPoint, Dict[ReductionShape, TensorStatistic]], + tensor_stats: Dict[PTTargetPoint, Dict[ReductionShape, TensorStatisticBase]], build_time_metric_info: QuantizationShareBuildTimeInfo, should_setup_adjust_pad_ops=False, hw_config: HWConfig = None, diff --git a/nncf/torch/statistics/aggregator.py b/nncf/torch/statistics/aggregator.py index 4c57699cd3c..a3eb21cb0d4 100644 --- a/nncf/torch/statistics/aggregator.py +++ b/nncf/torch/statistics/aggregator.py @@ -67,3 +67,8 @@ def _get_merged_statistic_points( @staticmethod def _process_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, PTNNCFTensor]: return outputs + + def _get_tensor_processor(): + from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor + + return PTNNCFCollectorTensorProcessor diff --git a/nncf/torch/tensor_statistics/statistics.py b/nncf/torch/tensor_statistics/statistics.py index 325e1b27e81..aa37377ca35 100644 --- a/nncf/torch/tensor_statistics/statistics.py +++ b/nncf/torch/tensor_statistics/statistics.py @@ -14,7 +14,7 @@ from nncf.common.tensor_statistics.statistics import MedianMADTensorStatistic from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic from nncf.common.tensor_statistics.statistics import PercentileTensorStatistic -from nncf.common.tensor_statistics.statistics import TensorStatistic +from nncf.common.tensor_statistics.statistics import TensorStatisticBase class PTMinMaxTensorStatistic(MinMaxTensorStatistic): @@ -35,7 +35,7 @@ def tensor_eq(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol=1e-6) -> bool: return bool(torch.allclose(tensor1, tensor2, rtol=rtol)) -def pt_convert_stat_to_min_max_tensor_stat(statistic: TensorStatistic) -> PTMinMaxTensorStatistic: +def pt_convert_stat_to_min_max_tensor_stat(statistic: TensorStatisticBase) -> PTMinMaxTensorStatistic: if isinstance(statistic, PTMinMaxTensorStatistic): return statistic if isinstance(statistic, PTMedianMADTensorStatistic): diff --git a/tests/tensorflow/tensor_statistics/test_tensor_statistics.py b/tests/tensorflow/tensor_statistics/test_tensor_statistics.py index e99487c69d9..913a988460f 100644 --- a/tests/tensorflow/tensor_statistics/test_tensor_statistics.py +++ b/tests/tensorflow/tensor_statistics/test_tensor_statistics.py @@ -19,7 +19,7 @@ from nncf.common.tensor_statistics.collectors import ReductionShape from nncf.common.tensor_statistics.collectors import StatisticsNotCollectedError from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase -from nncf.common.tensor_statistics.statistics import TensorStatistic +from nncf.common.tensor_statistics.statistics import TensorStatisticBase from nncf.tensorflow.tensor import TFNNCFTensor from nncf.tensorflow.tensor_statistics.collectors import TFMeanMinMaxStatisticCollector from nncf.tensorflow.tensor_statistics.collectors import TFMeanPercentileStatisticCollector @@ -101,7 +101,7 @@ class TestCollectedStatistics: def test_collected_statistics_with_shape_convert( self, collector: Type[TensorStatisticCollectorBase], - reduction_shapes_vs_ref_statistic: Dict[Tuple[ReductionShape, ReductionShape], TensorStatistic], + reduction_shapes_vs_ref_statistic: Dict[Tuple[ReductionShape, ReductionShape], TensorStatisticBase], ): for reduction_shape in reduction_shapes_vs_ref_statistic.keys(): collector_obj = collector(use_abs_max=True, reduction_shape=reduction_shape) @@ -179,7 +179,7 @@ def test_collected_statistics_with_shape_convert( def test_collected_statistics( self, collector: Type[TensorStatisticCollectorBase], - reduction_shapes_vs_ref_statistic: Dict[ReductionShape, TensorStatistic], + reduction_shapes_vs_ref_statistic: Dict[ReductionShape, TensorStatisticBase], ): for reduction_shape in reduction_shapes_vs_ref_statistic.keys(): collector_obj = collector(reduction_shape=reduction_shape) diff --git a/tests/torch/tensor_statistics/test_tensor_statistics.py b/tests/torch/tensor_statistics/test_tensor_statistics.py index 5c5cccc0220..cc5fc7463e1 100644 --- a/tests/torch/tensor_statistics/test_tensor_statistics.py +++ b/tests/torch/tensor_statistics/test_tensor_statistics.py @@ -19,7 +19,7 @@ from nncf.common.tensor_statistics.collectors import ReductionShape from nncf.common.tensor_statistics.collectors import StatisticsNotCollectedError from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase -from nncf.common.tensor_statistics.statistics import TensorStatistic +from nncf.common.tensor_statistics.statistics import TensorStatisticBase from nncf.torch.tensor import PTNNCFTensor from nncf.torch.tensor_statistics.collectors import PTMeanMinMaxStatisticCollector from nncf.torch.tensor_statistics.collectors import PTMeanPercentileStatisticCollector @@ -109,7 +109,7 @@ class TestCollectedStatistics: def test_collected_statistics_with_shape_convert( self, collector: Type[TensorStatisticCollectorBase], - reduction_shapes_vs_ref_statistic: Dict[Tuple[ReductionShape, ReductionShape], TensorStatistic], + reduction_shapes_vs_ref_statistic: Dict[Tuple[ReductionShape, ReductionShape], TensorStatisticBase], ): for shapes in reduction_shapes_vs_ref_statistic.keys(): output_shape, reduction_shape = shapes @@ -189,7 +189,7 @@ def test_collected_statistics_with_shape_convert( def test_collected_statistics( self, collector: Type[TensorStatisticCollectorBase], - reduction_shapes_vs_ref_statistic: Dict[ReductionShape, TensorStatistic], + reduction_shapes_vs_ref_statistic: Dict[ReductionShape, TensorStatisticBase], ): for shapes in reduction_shapes_vs_ref_statistic.keys(): reduction_shape = shapes