Skip to content

Commit

Permalink
TensorCollectorAdapter
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed May 16, 2023
1 parent 0b65576 commit 239474d
Show file tree
Hide file tree
Showing 19 changed files with 82 additions and 187 deletions.
211 changes: 73 additions & 138 deletions nncf/experimental/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,52 +27,7 @@
InplaceInsertionFNType = TypeVar("InplaceInsertionFNType")


class TensorReducerInterface(ABC):
@abstractproperty
def inplace(self):
pass

@abstractproperty
def output_port_id(self) -> int:
pass

@abstractproperty
def name(self):
pass

@abstractmethod
def get_output_names(self, target_node_name: str, port_id: int) -> List[str]:
"""
Returns target output names from target model that is
modified for statistic collection.
:param target_node_name: Target node name for reducer.
:param port_id: Target port id for target node name for reducer.
:return: Target output names for reducer.
"""

@abstractmethod
def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]:
"""
Returns correspondent inplace operation builder if inplace operations are available in backend.
:return: Inplace operation builder if possible else None.
"""

@abstractmethod
def __call__(self, x: List[NNCFTensor]):
pass

@abstractmethod
def __eq__(self, __o: object) -> bool:
pass

@abstractmethod
def __hash__(self) -> int:
pass


class TensorReducerBase(TensorReducerInterface, ABC):
class TensorReducerBase(ABC):
"""
Tensor reducer is a callable object that reduces tensors according to
the specified rule. Could handle tensors inplace or out of place.
Expand Down Expand Up @@ -106,6 +61,25 @@ def name(self):
def _get_processor() -> NNCFCollectorTensorProcessor:
pass

@abstractmethod
def get_output_names(self, target_node_name: str, port_id: int) -> List[str]:
"""
Returns target output names from target model that is
modified for statistic collection.
:param target_node_name: Target node name for reducer.
:param port_id: Target port id for target node name for reducer.
:return: Target output names for reducer.
"""

@abstractmethod
def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]:
"""
Returns correspondent inplace operation builder if inplace operations are available in backend.
:return: Inplace operation builder if possible else None.
"""

@abstractstaticmethod
def reduce_out_of_place(x: List[TensorType], **kwargs) -> List[TensorType]:
"""
Expand All @@ -121,9 +95,6 @@ def get_kwargs(self):
"keepdims": self._keepdims,
}

def __call__(self, x: List[NNCFTensor]):
return self._adapter.reduce(x, self)

def __eq__(self, __o: object) -> bool:
return (
isinstance(__o, self.__class__)
Expand All @@ -136,86 +107,6 @@ def __hash__(self) -> int:
return hash((self.__class__.__name__, self.inplace, self._init_reduction_shape, self._keepdims))


class BaseReducerAdapter(ABC):
@abstractmethod
def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase):
pass


class DefaultReducerAdapter(BaseReducerAdapter):
def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase):
if reducer.inplace:
return x

kwargs = reducer.get_kwargs()
reduction_shape_key = "reduction_shape"
if reduction_shape_key in kwargs and kwargs[reduction_shape_key] is None:
kwargs[reduction_shape_key] = tuple(range(len(x[0].shape)))

return reducer.reduce_out_of_place(x, **kwargs)


class SequentialReducerAdapter(DefaultReducerAdapter):
def __init__(self, stack_axis: int) -> None:
super().__init__()
self._stack_axis = stack_axis

def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase):
per_element_result = super().reduce(x, reducer)
kwargs = reducer.get_kwargs()
new_params = {"reduction_shape": self._stack_axis, "keepdims": False}
if not all(k in kwargs for k in new_params):
return per_element_result

kwargs.update(new_params)
return reducer.reduce_out_of_place(per_element_result, **kwargs)


class TensorReducersSequence(TensorReducerInterface):
def __init__(self, *args):
if any(reducer.inplace for reducer in args[1:]):
raise RuntimeError(f"Only first reducer of sequential tensor reducer could not be inplace.")
self._reducers = args

@property
def inplace(self):
return self._reducers[0].inplace

@property
def output_port_id(self) -> int:
return self._reducers[0].output_port_id

@property
def name(self):
name = ""
for i, reducer in enumerate(self._reducers):
name += f"{i}_{reducer.name}"
return name

def get_output_names(self, target_node_name: str, port_id: int) -> List[str]:
return self._reducers[0].get_output_names(target_node_name, port_id)

def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]:
return self._reducers[0].get_inplace_fn()

def __call__(self, x: List[NNCFTensor]):
if not self._reducers[0].inplace:
x = self._reducers[0](x)

for reducer in self._reducers[1:]:
x = reducer(x)
return x

def __eq__(self, __o: object) -> bool:
return (
isinstance(__o, self.__class__)
and len(self._reducers) == len(__o.reducers)
and all(self_r == o_r for self_r, o_r in zip(self._reducers, __o.reducers))
)

def __hash__(self) -> int:
return hash(tuple(hash(reducer) for reducer in self._reducers))


class TensorAggregatorBase:
"""
Expand Down Expand Up @@ -273,6 +164,54 @@ def __hash__(self) -> int:
return hash(self.__class__.__name__)


class BaseTensorCollectorAdapter(ABC):
@abstractmethod
def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase):
pass

@abstractmethod
def register_reduced_input(self, x: NNCFTensor, aggregator: TensorAggregatorBase):
pass


class DefaultTensorCollectorAdapter(BaseTensorCollectorAdapter):
def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase):
if reducer.inplace:
return x

kwargs = reducer.get_kwargs()
reduction_shape_key = "reduction_shape"
if reduction_shape_key in kwargs and kwargs[reduction_shape_key] is None:
kwargs[reduction_shape_key] = tuple(range(len(x[0].shape)))

return reducer.reduce_out_of_place(x, **kwargs)

def register_reduced_input(self, x: NNCFTensor, aggregator: TensorAggregatorBase):
aggregator.register_reduced_input(x)


class SequentialTensorCollectorAdapter(DefaultTensorCollectorAdapter):
def __init__(self, stack_axis: int, tensor_processor: NNCFCollectorTensorProcessor) -> None:
super().__init__()
self._stack_axis = stack_axis
self._tensor_processor = tensor_processor

def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase):
per_element_result = super().reduce(x, reducer)
kwargs = reducer.get_kwargs()
new_params = {"reduction_shape": self._stack_axis, "keepdims": False}
if not all(k in kwargs for k in new_params):
return per_element_result

kwargs.update(new_params)
return reducer.reduce_out_of_place(per_element_result, **kwargs)

def register_reduced_input(self, x: NNCFTensor, aggregator: TensorAggregatorBase):
if isinstance(aggregator, ShapeAggregator):
x = self._tensor_processor.unstack(x, axis=0)[0]
super().register_reduced_input(x, aggregator)


class TensorCollector:
"""
Calculates statistics at given tensors according to registered statistic branches.
Expand All @@ -289,7 +228,7 @@ def __init__(self, statistic_container: Optional[TensorStatistic] = None) -> Non
self._stat_container_kwargs_map: Dict[str, Tuple[int, int]] = {}
self._stat_container = statistic_container
self._enabled = True
self._adapter = DefaultReducerAdapter()
self._adapter = DefaultTensorCollectorAdapter()

@property
def num_samples(self) -> Optional[int]:
Expand All @@ -313,7 +252,7 @@ def reducers(self):
def aggregators(self):
return self._aggregators.copy()

def set_adapter(self, adapter: "BaseReducerAdapter"):
def set_adapter(self, adapter: "BaseTensorCollectorAdapter"):
self._adapter = adapter

def enable(self):
Expand Down Expand Up @@ -390,7 +329,7 @@ def register_inputs(self, inputs: Dict[int, List[NNCFTensor]]) -> None:
aggregator,
) in self._aggregators.items():
if reducer_hash in reduced_inputs:
aggregator.register_reduced_input(reduced_inputs[reducer_hash][reducer_port_id])
self._adapter.register_reduced_input(reduced_inputs[reducer_hash][reducer_port_id], aggregator)

def _aggregate(self) -> None:
result = {}
Expand Down Expand Up @@ -683,18 +622,14 @@ def aggregate(self):


class ShapeAggregator(TensorAggregatorBase):
def __init__(self, slice_=None):
def __init__(self):
super().__init__(None, 1)
self._slice = slice_

def _register_reduced_input_impl(self, x: TensorType) -> None:
self._container = x

def aggregate(self):
shape = self._container.shape
if self._slice is not None:
return shape[self._slice]
return shape
return self._container.shape


class MinAggregator(TensorAggregatorBase):
Expand Down
4 changes: 0 additions & 4 deletions nncf/openvino/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from nncf.common.utils.backend import get_backend
from nncf.common.utils.timer import timer
from nncf.data import Dataset
from nncf.data import RecurentDataset
from nncf.openvino.quantization.backend_parameters import BackendParameters
from nncf.openvino.quantization.backend_parameters import is_weight_compression_needed
from nncf.parameters import DropType
Expand Down Expand Up @@ -103,9 +102,6 @@ def native_quantize_impl(
"""
Implementation of the `quantize()` method for the OpenVINO backend via the OpenVINO Runtime API.
"""
if isinstance(calibration_dataset, RecurentDataset):
model_type = ModelType.SEQUENTIAL

quantization_algorithm = PostTrainingQuantization(
preset=preset,
target_device=target_device,
Expand Down
5 changes: 3 additions & 2 deletions nncf/openvino/statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.experimental.common.tensor_statistics.collectors import MergedTensorCollector
from nncf.experimental.common.tensor_statistics.collectors import SequentialReducerAdapter
from nncf.experimental.common.tensor_statistics.collectors import SequentialTensorCollectorAdapter
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.openvino.graph.nncf_graph_builder import GraphConverter
from nncf.openvino.graph.transformations.commands import OVInplaceFnInsertionCommand
from nncf.openvino.graph.transformations.commands import OVOutputInsertionCommand
from nncf.openvino.statistics.collectors import OVNNCFCollectorTensorProcessor
from nncf.openvino.tensor import OVNNCFTensor


Expand Down Expand Up @@ -77,7 +78,7 @@ def _get_transformation_layout_extra_outputs(
# TODO(dlyakhov) Move this to common part
def _adapt_collectors(statistic_points: StatisticPointsContainer, stack_axis: int):
for _, _, tensor_collector in statistic_points.get_tensor_collectors():
tensor_collector.set_adapter(SequentialReducerAdapter(stack_axis))
tensor_collector.set_adapter(SequentialTensorCollectorAdapter(stack_axis, OVNNCFCollectorTensorProcessor))

return statistic_points

Expand Down
5 changes: 2 additions & 3 deletions nncf/openvino/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from nncf.openvino.statistics.statistics import OVBatchTensorStatistic
from nncf.openvino.statistics.statistics import OVMeanTensorStatistic
from nncf.openvino.tensor import OVNNCFTensor
from nncf.parameters import ModelType
from nncf.quantization.advanced_parameters import StatisticsType


Expand Down Expand Up @@ -243,7 +242,7 @@ def get_output_names(self, target_node_name: str, port_id: int) -> List[str]:
return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace)


def get_mean_stat_collector(num_samples, channel_axis, window_size=None, inplace=True, model_type=None):
def get_mean_stat_collector(num_samples, channel_axis, window_size=None, inplace=True):
# TODO(dlyakhov): use inplace OVBatchMeanReducer and OVMeanPerChanelReducer
# after migration on openvino-dev=2023.0
inplace = False
Expand All @@ -260,7 +259,7 @@ def get_mean_stat_collector(num_samples, channel_axis, window_size=None, inplace
"window_size": window_size,
}
aggregate_mean = MeanAggregator(**kwargs)
aggregate_shape = ShapeAggregator(slice(1, None) if model_type == ModelType.SEQUENTIAL else None)
aggregate_shape = ShapeAggregator()

collector = TensorCollector(OVMeanTensorStatistic)
collector.register_statistic_branch(OVMeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
Expand Down
1 change: 0 additions & 1 deletion nncf/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ class ModelType(Enum):
"""

TRANSFORMER = "transformer"
SEQUENTIAL = "sequential"


@api(canonical_alias="nncf.DropType")
Expand Down
9 changes: 1 addition & 8 deletions nncf/quantization/algorithms/bias_correction/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import copy_model
from nncf.common.utils.backend import get_backend
from nncf.parameters import ModelType
from nncf.quantization.algorithms.algorithm import Algorithm
from nncf.quantization.algorithms.bias_correction.backend import ALGO_BACKENDS

Expand Down Expand Up @@ -69,7 +68,6 @@ def __init__(
apply_for_all_nodes: bool = False,
inplace_statistics: bool = True,
backend_params: Optional[Dict[str, Any]] = None,
model_type: ModelType = None,
):
"""
:param subset_size: Size of a subset for the statistics collection,
Expand All @@ -93,8 +91,6 @@ def __init__(
self.apply_for_all_nodes = apply_for_all_nodes
self.inplace_statistics = inplace_statistics
self.backend_params = backend_params
self.model_type = model_type

self.nncf_graph = None
self._backend_entity = None
self._collected_stat_inputs = set()
Expand Down Expand Up @@ -480,10 +476,7 @@ def get_statistic_points(self, model: TModel) -> StatisticPointsContainer:
TargetType.POST_LAYER_OPERATION, node_name, output_port_id
)
stat_collector = self._backend_entity.mean_statistic_collector(
reduction_shape=channel_axis,
num_samples=self.subset_size,
inplace=self.inplace_statistics,
model_type=self.model_type,
reduction_shape=channel_axis, num_samples=self.subset_size, inplace=self.inplace_statistics
)
statistic_container.add_statistic_point(
StatisticPoint(target_point=statistic_point, tensor_collector=stat_collector, algorithm=BiasCorrection)
Expand Down
Loading

0 comments on commit 239474d

Please sign in to comment.