Skip to content

Commit

Permalink
TensorCollectorAdapter
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed May 16, 2023
1 parent 0b65576 commit 54732d0
Show file tree
Hide file tree
Showing 13 changed files with 60 additions and 66 deletions.
97 changes: 53 additions & 44 deletions nncf/experimental/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,41 +136,6 @@ def __hash__(self) -> int:
return hash((self.__class__.__name__, self.inplace, self._init_reduction_shape, self._keepdims))


class BaseReducerAdapter(ABC):
@abstractmethod
def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase):
pass


class DefaultReducerAdapter(BaseReducerAdapter):
def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase):
if reducer.inplace:
return x

kwargs = reducer.get_kwargs()
reduction_shape_key = "reduction_shape"
if reduction_shape_key in kwargs and kwargs[reduction_shape_key] is None:
kwargs[reduction_shape_key] = tuple(range(len(x[0].shape)))

return reducer.reduce_out_of_place(x, **kwargs)


class SequentialReducerAdapter(DefaultReducerAdapter):
def __init__(self, stack_axis: int) -> None:
super().__init__()
self._stack_axis = stack_axis

def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase):
per_element_result = super().reduce(x, reducer)
kwargs = reducer.get_kwargs()
new_params = {"reduction_shape": self._stack_axis, "keepdims": False}
if not all(k in kwargs for k in new_params):
return per_element_result

kwargs.update(new_params)
return reducer.reduce_out_of_place(per_element_result, **kwargs)


class TensorReducersSequence(TensorReducerInterface):
def __init__(self, *args):
if any(reducer.inplace for reducer in args[1:]):
Expand Down Expand Up @@ -273,6 +238,54 @@ def __hash__(self) -> int:
return hash(self.__class__.__name__)


class BaseTensorCollectorAdapter(ABC):
@abstractmethod
def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase):
pass

@abstractmethod
def register_reduced_input(self, x: NNCFTensor, aggregator: TensorAggregatorBase):
pass


class DefaultTensorCollectorAdapter(BaseTensorCollectorAdapter):
def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase):
if reducer.inplace:
return x

kwargs = reducer.get_kwargs()
reduction_shape_key = "reduction_shape"
if reduction_shape_key in kwargs and kwargs[reduction_shape_key] is None:
kwargs[reduction_shape_key] = tuple(range(len(x[0].shape)))

return reducer.reduce_out_of_place(x, **kwargs)

def register_reduced_input(self, x: NNCFTensor, aggregator: TensorAggregatorBase):
aggregator.register_reduced_input(x)


class SequentialTensorCollectorAdapter(DefaultTensorCollectorAdapter):
def __init__(self, stack_axis: int, tensor_processor: NNCFCollectorTensorProcessor) -> None:
super().__init__()
self._stack_axis = stack_axis
self._tensor_processor = tensor_processor

def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase):
per_element_result = super().reduce(x, reducer)
kwargs = reducer.get_kwargs()
new_params = {"reduction_shape": self._stack_axis, "keepdims": False}
if not all(k in kwargs for k in new_params):
return per_element_result

kwargs.update(new_params)
return reducer.reduce_out_of_place(per_element_result, **kwargs)

def register_reduced_input(self, x: NNCFTensor, aggregator: TensorAggregatorBase):
if isinstance(aggregator, ShapeAggregator):
x = self._tensor_processor.unstack(x, axis=0)[0]
super().register_reduced_input(x, aggregator)


class TensorCollector:
"""
Calculates statistics at given tensors according to registered statistic branches.
Expand All @@ -289,7 +302,7 @@ def __init__(self, statistic_container: Optional[TensorStatistic] = None) -> Non
self._stat_container_kwargs_map: Dict[str, Tuple[int, int]] = {}
self._stat_container = statistic_container
self._enabled = True
self._adapter = DefaultReducerAdapter()
self._adapter = DefaultTensorCollectorAdapter()

@property
def num_samples(self) -> Optional[int]:
Expand All @@ -313,7 +326,7 @@ def reducers(self):
def aggregators(self):
return self._aggregators.copy()

def set_adapter(self, adapter: "BaseReducerAdapter"):
def set_adapter(self, adapter: "BaseTensorCollectorAdapter"):
self._adapter = adapter

def enable(self):
Expand Down Expand Up @@ -390,7 +403,7 @@ def register_inputs(self, inputs: Dict[int, List[NNCFTensor]]) -> None:
aggregator,
) in self._aggregators.items():
if reducer_hash in reduced_inputs:
aggregator.register_reduced_input(reduced_inputs[reducer_hash][reducer_port_id])
self._adapter.register_reduced_input(reduced_inputs[reducer_hash][reducer_port_id], aggregator)

def _aggregate(self) -> None:
result = {}
Expand Down Expand Up @@ -683,18 +696,14 @@ def aggregate(self):


class ShapeAggregator(TensorAggregatorBase):
def __init__(self, slice_=None):
def __init__(self):
super().__init__(None, 1)
self._slice = slice_

def _register_reduced_input_impl(self, x: TensorType) -> None:
self._container = x

def aggregate(self):
shape = self._container.shape
if self._slice is not None:
return shape[self._slice]
return shape
return self._container.shape


class MinAggregator(TensorAggregatorBase):
Expand Down
2 changes: 0 additions & 2 deletions nncf/openvino/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,6 @@ def native_quantize_impl(
"""
Implementation of the `quantize()` method for the OpenVINO backend via the OpenVINO Runtime API.
"""
if isinstance(calibration_dataset, RecurentDataset):
model_type = ModelType.SEQUENTIAL

quantization_algorithm = PostTrainingQuantization(
preset=preset,
Expand Down
5 changes: 3 additions & 2 deletions nncf/openvino/statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.experimental.common.tensor_statistics.collectors import MergedTensorCollector
from nncf.experimental.common.tensor_statistics.collectors import SequentialReducerAdapter
from nncf.experimental.common.tensor_statistics.collectors import SequentialTensorCollectorAdapter
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.openvino.graph.nncf_graph_builder import GraphConverter
from nncf.openvino.graph.transformations.commands import OVInplaceFnInsertionCommand
from nncf.openvino.graph.transformations.commands import OVOutputInsertionCommand
from nncf.openvino.statistics.collectors import OVNNCFCollectorTensorProcessor
from nncf.openvino.tensor import OVNNCFTensor


Expand Down Expand Up @@ -77,7 +78,7 @@ def _get_transformation_layout_extra_outputs(
# TODO(dlyakhov) Move this to common part
def _adapt_collectors(statistic_points: StatisticPointsContainer, stack_axis: int):
for _, _, tensor_collector in statistic_points.get_tensor_collectors():
tensor_collector.set_adapter(SequentialReducerAdapter(stack_axis))
tensor_collector.set_adapter(SequentialTensorCollectorAdapter(stack_axis, OVNNCFCollectorTensorProcessor))

return statistic_points

Expand Down
4 changes: 2 additions & 2 deletions nncf/openvino/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def get_output_names(self, target_node_name: str, port_id: int) -> List[str]:
return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace)


def get_mean_stat_collector(num_samples, channel_axis, window_size=None, inplace=True, model_type=None):
def get_mean_stat_collector(num_samples, channel_axis, window_size=None, inplace=True):
# TODO(dlyakhov): use inplace OVBatchMeanReducer and OVMeanPerChanelReducer
# after migration on openvino-dev=2023.0
inplace = False
Expand All @@ -260,7 +260,7 @@ def get_mean_stat_collector(num_samples, channel_axis, window_size=None, inplace
"window_size": window_size,
}
aggregate_mean = MeanAggregator(**kwargs)
aggregate_shape = ShapeAggregator(slice(1, None) if model_type == ModelType.SEQUENTIAL else None)
aggregate_shape = ShapeAggregator()

collector = TensorCollector(OVMeanTensorStatistic)
collector.register_statistic_branch(OVMeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
Expand Down
3 changes: 0 additions & 3 deletions nncf/quantization/algorithms/bias_correction/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def __init__(
apply_for_all_nodes: bool = False,
inplace_statistics: bool = True,
backend_params: Optional[Dict[str, Any]] = None,
model_type: ModelType = None,
):
"""
:param subset_size: Size of a subset for the statistics collection,
Expand All @@ -93,7 +92,6 @@ def __init__(
self.apply_for_all_nodes = apply_for_all_nodes
self.inplace_statistics = inplace_statistics
self.backend_params = backend_params
self.model_type = model_type

self.nncf_graph = None
self._backend_entity = None
Expand Down Expand Up @@ -483,7 +481,6 @@ def get_statistic_points(self, model: TModel) -> StatisticPointsContainer:
reduction_shape=channel_axis,
num_samples=self.subset_size,
inplace=self.inplace_statistics,
model_type=self.model_type,
)
statistic_container.add_statistic_point(
StatisticPoint(target_point=statistic_point, tensor_collector=stat_collector, algorithm=BiasCorrection)
Expand Down
1 change: 0 additions & 1 deletion nncf/quantization/algorithms/bias_correction/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def mean_statistic_collector(
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
model_type: Optional[ModelType] = None,
) -> TensorStatisticCollectorBase:
"""
Returns backend-specific mean statistic collector.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def mean_statistic_collector(
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
model_type: Optional[ModelType] = None,
) -> ONNXMeanStatisticCollector:
return ONNXMeanStatisticCollector(reduction_shape, num_samples, window_size)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,8 @@ def mean_statistic_collector(
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
model_type: Optional[ModelType] = None,
) -> TensorCollector:
return get_mean_stat_collector(num_samples, reduction_shape, window_size, inplace, model_type)
return get_mean_stat_collector(num_samples, reduction_shape, window_size, inplace)

@staticmethod
def batch_statistic_collector(inplace: bool, num_samples: int = None) -> TensorCollector:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def __init__(
apply_for_all_nodes: bool = False,
inplace_statistics: bool = True,
backend_params: Optional[Dict[str, Any]] = None,
model_type: ModelType = None,
):
"""
:param subset_size: Size of a subset for the statistics collection,
Expand All @@ -84,7 +83,6 @@ def __init__(
self.apply_for_all_nodes = apply_for_all_nodes
self.inplace_statistics = inplace_statistics
self.backend_params = backend_params
self.model_type = model_type
self.nncf_graph = None
self._backend_entity = None

Expand Down Expand Up @@ -258,7 +256,6 @@ def _add_statistic_point(self, container: StatisticPointsContainer, point: Targe
reduction_shape=axis,
num_samples=self.subset_size,
inplace=self.inplace_statistics,
model_type=self.model_type,
)
container.add_statistic_point(
StatisticPoint(target_point=point, tensor_collector=stat_collector, algorithm=FastBiasCorrection)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def mean_statistic_collector(
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
model_type: Optional[ModelType] = None,
) -> TensorStatisticCollectorBase:
"""
Returns backend-specific mean statistic collector.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def mean_statistic_collector(
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
model_type: Optional[ModelType] = None,
) -> ONNXMeanStatisticCollector:
return ONNXMeanStatisticCollector(reduction_shape, num_samples, window_size)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,8 @@ def mean_statistic_collector(
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
model_type: Optional[ModelType] = None,
) -> TensorCollector:
return get_mean_stat_collector(num_samples, reduction_shape, window_size, inplace, model_type)
return get_mean_stat_collector(num_samples, reduction_shape, window_size, inplace)

@staticmethod
def get_sub_input_output_names(subgraph: ov.Model) -> Tuple[str, str]:
Expand Down
2 changes: 0 additions & 2 deletions nncf/quantization/algorithms/post_training/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def __init__(
apply_for_all_nodes=bias_correction_params.apply_for_all_nodes,
inplace_statistics=advanced_parameters.inplace_statistics,
backend_params=advanced_parameters.backend_params,
model_type=model_type,
)
else:
threshold = BIAS_CORRECTION_THRESHOLD
Expand All @@ -127,7 +126,6 @@ def __init__(
apply_for_all_nodes=bias_correction_params.apply_for_all_nodes,
inplace_statistics=advanced_parameters.inplace_statistics,
backend_params=advanced_parameters.backend_params,
model_type=model_type,
)

self.algorithms.append(bias_correction)
Expand Down

0 comments on commit 54732d0

Please sign in to comment.