Commit: TensorCollectorAdapter

daniil-lyakhov committed May 16, 2023
1 parent 0b65576 commit 43eb6bc
Showing 23 changed files with 126 additions and 218 deletions.
32 changes: 30 additions & 2 deletions nncf/common/tensor_statistics/aggregator.py
@@ -10,6 +10,7 @@
# limitations under the License.
from abc import ABC
from abc import abstractmethod
from collections import defaultdict
from itertools import islice
from typing import Any, Dict, TypeVar

@@ -56,14 +57,36 @@ def collect_statistics(self, model: TModel) -> None:
transformation_layout = self._get_transformation_layout_extra_outputs(merged_statistics)
model_with_outputs = model_transformer.transform(transformation_layout)
engine = EngineFactory.create(model_with_outputs)
infer_fn = self._infer_sequential if self._is_sequential else self._infer

for input_data in tqdm(
islice(self.dataset.get_inference_data(), self.stat_subset_size), total=self.stat_subset_size
):
- outputs = engine.infer(input_data)
- processed_outputs = self._process_outputs(outputs)
+ processed_outputs = infer_fn(engine, input_data)
self._register_statistics(processed_outputs, merged_statistics)

def _infer(self, engine, input_data):
outputs = engine.infer(input_data)
return self._process_outputs(outputs)

def _infer_sequential(self, engine, sequence):
model_output = None
model_outputs = defaultdict(list)
for token in sequence.get_tokens_iter():
filled_inputs = sequence.fill_inputs(token, model_output)
model_output = engine.infer(filled_inputs)
processed_output = self._process_outputs(model_output)
for output_name, output_value in processed_output.items():
model_outputs[output_name].append(output_value)

# Stack model outputs and return them
stacked_outputs = {}
tensor_processor = self._get_tensor_processor()
for output_name, output_values in model_outputs.items():
stacked_outputs[output_name] = tensor_processor.stack(output_values, axis=self.STACK_AXIS)

return stacked_outputs
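
A minimal sketch of the sequential loop above, runnable with plain numpy: DummyEngine and DummySequence are hypothetical stand-ins (only get_tokens_iter and fill_inputs are taken from this diff), and STACK_AXIS is assumed to be 0.

from collections import defaultdict

import numpy as np

class DummyEngine:
    """Stands in for an inference engine; returns one named output per step."""

    def infer(self, inputs):
        return {"logits": inputs["token"] * 2.0}

class DummySequence:
    """Stands in for a token sequence exposing the API used above."""

    def get_tokens_iter(self):
        return iter([np.ones(3), 2 * np.ones(3), 3 * np.ones(3)])

    def fill_inputs(self, token, prev_output):
        # Feed the previous step's output back in, as a stateful model would.
        state = None if prev_output is None else prev_output["logits"]
        return {"token": token, "state": state}

engine, sequence = DummyEngine(), DummySequence()
model_output, model_outputs = None, defaultdict(list)
for token in sequence.get_tokens_iter():
    model_output = engine.infer(sequence.fill_inputs(token, model_output))
    for name, value in model_output.items():
        model_outputs[name].append(value)

# Stack the per-step outputs along a new leading axis (assuming STACK_AXIS == 0).
stacked = {name: np.stack(values, axis=0) for name, values in model_outputs.items()}
print(stacked["logits"].shape)  # (3, 3)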

def register_statistic_points(self, statistic_points: StatisticPointsContainer) -> None:
"""
Register statistic points for statistics collection and recalculates the maximum number of samples
@@ -133,3 +156,8 @@ def _process_outputs(outputs: Any) -> Dict[str, NNCFTensor]:
:param outputs: raw model outputs
:return: processed model outputs in Dict[str, NNCFTensor] format
"""

@staticmethod
@abstractmethod
def _get_tensor_processor():
pass
213 changes: 73 additions & 140 deletions nncf/experimental/common/tensor_statistics/collectors.py
@@ -27,52 +27,7 @@
InplaceInsertionFNType = TypeVar("InplaceInsertionFNType")


class TensorReducerInterface(ABC):
@abstractproperty
def inplace(self):
pass

@abstractproperty
def output_port_id(self) -> int:
pass

@abstractproperty
def name(self):
pass

@abstractmethod
def get_output_names(self, target_node_name: str, port_id: int) -> List[str]:
"""
Returns target output names from the target model that is
modified for statistic collection.
:param target_node_name: Target node name for reducer.
:param port_id: Target port id for target node name for reducer.
:return: Target output names for reducer.
"""

@abstractmethod
def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]:
"""
Returns the corresponding inplace operation builder if inplace operations are available in the backend.
:return: Inplace operation builder if possible else None.
"""

@abstractmethod
def __call__(self, x: List[NNCFTensor]):
pass

@abstractmethod
def __eq__(self, __o: object) -> bool:
pass

@abstractmethod
def __hash__(self) -> int:
pass


- class TensorReducerBase(TensorReducerInterface, ABC):
+ class TensorReducerBase(ABC):
"""
Tensor reducer is a callable object that reduces tensors according to
the specified rule. It can handle tensors inplace or out of place.
@@ -106,6 +61,25 @@ def name(self):
def _get_processor() -> NNCFCollectorTensorProcessor:
pass

@abstractmethod
def get_output_names(self, target_node_name: str, port_id: int) -> List[str]:
"""
Returns target output names from the target model that is
modified for statistic collection.
:param target_node_name: Target node name for reducer.
:param port_id: Target port id for target node name for reducer.
:return: Target output names for reducer.
"""

@abstractmethod
def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]:
"""
Returns the corresponding inplace operation builder if inplace operations are available in the backend.
:return: Inplace operation builder if possible else None.
"""

@abstractstaticmethod
def reduce_out_of_place(x: List[TensorType], **kwargs) -> List[TensorType]:
"""
@@ -121,9 +95,6 @@ def get_kwargs(self):
"keepdims": self._keepdims,
}

- def __call__(self, x: List[NNCFTensor]):
- return self._adapter.reduce(x, self)

def __eq__(self, __o: object) -> bool:
return (
isinstance(__o, self.__class__)
@@ -136,87 +107,6 @@ def __hash__(self) -> int:
return hash((self.__class__.__name__, self.inplace, self._init_reduction_shape, self._keepdims))
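
These __eq__/__hash__ overrides matter because TensorCollector keys its statistic branches by reducer hash (see register_inputs below), so identically configured reducers share a single reduction. A minimal sketch; StubReducer is illustrative, not NNCF code:

class StubReducer:
    """Illustrative reducer carrying only the fields that define its identity."""

    def __init__(self, reduction_shape=None, keepdims=True):
        self._reduction_shape = reduction_shape
        self._keepdims = keepdims

    def __eq__(self, other):
        return (
            isinstance(other, self.__class__)
            and self._reduction_shape == other._reduction_shape
            and self._keepdims == other._keepdims
        )

    def __hash__(self):
        return hash((self.__class__.__name__, self._reduction_shape, self._keepdims))

branches = {StubReducer((0,)): "branch_a", StubReducer((0,)): "branch_b"}
print(len(branches))  # 1: equal reducers collapse into a single branch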


class BaseReducerAdapter(ABC):
@abstractmethod
def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase):
pass


class DefaultReducerAdapter(BaseReducerAdapter):
def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase):
if reducer.inplace:
return x

kwargs = reducer.get_kwargs()
reduction_shape_key = "reduction_shape"
if reduction_shape_key in kwargs and kwargs[reduction_shape_key] is None:
kwargs[reduction_shape_key] = tuple(range(len(x[0].shape)))

return reducer.reduce_out_of_place(x, **kwargs)


class SequentialReducerAdapter(DefaultReducerAdapter):
def __init__(self, stack_axis: int) -> None:
super().__init__()
self._stack_axis = stack_axis

def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase):
per_element_result = super().reduce(x, reducer)
kwargs = reducer.get_kwargs()
new_params = {"reduction_shape": self._stack_axis, "keepdims": False}
if not all(k in kwargs for k in new_params):
return per_element_result

kwargs.update(new_params)
return reducer.reduce_out_of_place(per_element_result, **kwargs)


class TensorReducersSequence(TensorReducerInterface):
def __init__(self, *args):
if any(reducer.inplace for reducer in args[1:]):
raise RuntimeError("Only the first reducer in a sequence of tensor reducers can be inplace.")
self._reducers = args

@property
def inplace(self):
return self._reducers[0].inplace

@property
def output_port_id(self) -> int:
return self._reducers[0].output_port_id

@property
def name(self):
name = ""
for i, reducer in enumerate(self._reducers):
name += f"{i}_{reducer.name}"
return name

def get_output_names(self, target_node_name: str, port_id: int) -> List[str]:
return self._reducers[0].get_output_names(target_node_name, port_id)

def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]:
return self._reducers[0].get_inplace_fn()

def __call__(self, x: List[NNCFTensor]):
if not self._reducers[0].inplace:
x = self._reducers[0](x)

for reducer in self._reducers[1:]:
x = reducer(x)
return x

def __eq__(self, __o: object) -> bool:
return (
isinstance(__o, self.__class__)
and len(self._reducers) == len(__o._reducers)
and all(self_r == o_r for self_r, o_r in zip(self._reducers, __o._reducers))
)

def __hash__(self) -> int:
return hash(tuple(hash(reducer) for reducer in self._reducers))


class TensorAggregatorBase:
"""
Tensor aggregator is designed to receive (register) calculated statistics and
@@ -273,6 +163,54 @@ def __hash__(self) -> int:
return hash(self.__class__.__name__)


class BaseTensorCollectorAdapter(ABC):
@abstractmethod
def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase):
pass

@abstractmethod
def register_reduced_input(self, x: NNCFTensor, aggregator: TensorAggregatorBase):
pass


class DefaultTensorCollectorAdapter(BaseTensorCollectorAdapter):
def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase):
if reducer.inplace:
return x

kwargs = reducer.get_kwargs()
reduction_shape_key = "reduction_shape"
if reduction_shape_key in kwargs and kwargs[reduction_shape_key] is None:
kwargs[reduction_shape_key] = tuple(range(len(x[0].shape)))

return reducer.reduce_out_of_place(x, **kwargs)

def register_reduced_input(self, x: NNCFTensor, aggregator: TensorAggregatorBase):
aggregator.register_reduced_input(x)
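
To make the None handling in reduce above concrete, a runnable sketch under the assumption that reduce_out_of_place accepts the reduction axes via reduction_shape; MeanReducerStub and adapter_reduce are illustrative, not NNCF code:

import numpy as np

class MeanReducerStub:
    """Illustrative out-of-place reducer exposing the kwargs used above."""

    inplace = False

    def get_kwargs(self):
        return {"reduction_shape": None, "keepdims": False}

    @staticmethod
    def reduce_out_of_place(x, reduction_shape, keepdims):
        return [np.mean(t, axis=reduction_shape, keepdims=keepdims) for t in x]

def adapter_reduce(x, reducer):
    # Mirrors DefaultTensorCollectorAdapter.reduce: a reduction_shape of None
    # expands to every axis of the first input tensor.
    if reducer.inplace:
        return x
    kwargs = reducer.get_kwargs()
    if kwargs.get("reduction_shape") is None:
        kwargs["reduction_shape"] = tuple(range(len(x[0].shape)))
    return reducer.reduce_out_of_place(x, **kwargs)

print(adapter_reduce([np.arange(12.0).reshape(3, 4)], MeanReducerStub()))  # mean over all axes: 5.5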


class SequentialTensorCollectorAdapter(DefaultTensorCollectorAdapter):
def __init__(self, stack_axis: int, tensor_processor: NNCFCollectorTensorProcessor) -> None:
super().__init__()
self._stack_axis = stack_axis
self._tensor_processor = tensor_processor

def reduce(self, x: List[NNCFTensor], reducer: TensorReducerBase):
per_element_result = super().reduce(x, reducer)
kwargs = reducer.get_kwargs()
new_params = {"reduction_shape": self._stack_axis, "keepdims": False}
if not all(k in kwargs for k in new_params):
return per_element_result

kwargs.update(new_params)
return reducer.reduce_out_of_place(per_element_result, **kwargs)

def register_reduced_input(self, x: NNCFTensor, aggregator: TensorAggregatorBase):
if isinstance(aggregator, ShapeAggregator):
x = self._tensor_processor.unstack(x, axis=0)[0]
super().register_reduced_input(x, aggregator)
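
A numpy-only sketch of the two-stage reduction this adapter performs: the reducer's usual per-element pass first, then one more pass that collapses the stack axis with keepdims=False (axes and values here are illustrative). The ShapeAggregator branch above instead unstacks the input, so the recorded shape describes a single sequence element.

import numpy as np

# Three per-step tensors stacked along axis 0, the assumed stack axis.
stacked = np.stack([np.full((2, 4), v) for v in (1.0, 2.0, 3.0)], axis=0)  # (3, 2, 4)

# Stage 1: the reducer's usual out-of-place reduction, e.g. a max over axis 2.
per_element = stacked.max(axis=2)  # (3, 2)

# Stage 2: the adapter's extra pass collapses the stack axis itself.
final = per_element.max(axis=0)  # (2,)
print(final)  # [3. 3.]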


class TensorCollector:
"""
Calculates statistics for given tensors according to registered statistic branches.
@@ -289,7 +227,7 @@ def __init__(self, statistic_container: Optional[TensorStatistic] = None) -> None:
self._stat_container_kwargs_map: Dict[str, Tuple[int, int]] = {}
self._stat_container = statistic_container
self._enabled = True
- self._adapter = DefaultReducerAdapter()
+ self._adapter = DefaultTensorCollectorAdapter()

@property
def num_samples(self) -> Optional[int]:
@@ -313,7 +251,7 @@ def reducers(self):
def aggregators(self):
return self._aggregators.copy()

- def set_adapter(self, adapter: "BaseReducerAdapter"):
+ def set_adapter(self, adapter: "BaseTensorCollectorAdapter"):
self._adapter = adapter

def enable(self):
@@ -390,7 +328,7 @@ def register_inputs(self, inputs: Dict[int, List[NNCFTensor]]) -> None:
aggregator,
) in self._aggregators.items():
if reducer_hash in reduced_inputs:
- aggregator.register_reduced_input(reduced_inputs[reducer_hash][reducer_port_id])
+ self._adapter.register_reduced_input(reduced_inputs[reducer_hash][reducer_port_id], aggregator)

def _aggregate(self) -> None:
result = {}
@@ -604,7 +542,6 @@ def __hash__(self) -> int:


class QuantileReducer(QuantileReducerBase):
- @staticmethod
@staticmethod
def reduce_out_of_place(
x: List[NNCFTensor],
@@ -683,18 +620,14 @@ def aggregate(self):


class ShapeAggregator(TensorAggregatorBase):
- def __init__(self, slice_=None):
+ def __init__(self):
super().__init__(None, 1)
- self._slice = slice_

def _register_reduced_input_impl(self, x: TensorType) -> None:
self._container = x

def aggregate(self):
- shape = self._container.shape
- if self._slice is not None:
- return shape[self._slice]
- return shape
+ return self._container.shape
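
For reference, how this interacts with SequentialTensorCollectorAdapter above: stacked sequential inputs are unstacked before registration, so the aggregated shape describes one sequence element rather than the whole stack (a numpy illustration, not NNCF code):

import numpy as np

stacked = np.zeros((5, 1, 8))  # 5 sequence steps stacked along axis 0
element = list(stacked)[0]     # what unstack(x, axis=0)[0] would yield
print(stacked.shape, element.shape)  # (5, 1, 8) (1, 8)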


class MinAggregator(TensorAggregatorBase):
5 changes: 5 additions & 0 deletions nncf/onnx/statistics/aggregator.py
@@ -79,3 +79,8 @@ def _get_merged_statistic_points(
@staticmethod
def _process_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, ONNXNNCFTensor]:
return {n: ONNXNNCFTensor(v) for n, v in outputs.items()}

@staticmethod
def _get_tensor_processor():
from nncf.onnx.statistics.collectors import ONNXNNCFCollectorTensorProcessor

return ONNXNNCFCollectorTensorProcessor
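
The processor returned here has to supply the stack/unstack primitives used by the sequential code paths above; a minimal stand-in sketching that assumed surface (DummyTensorProcessor is illustrative, not the ONNX class):

import numpy as np

class DummyTensorProcessor:
    """Illustrative subset of a collector tensor processor's interface."""

    @staticmethod
    def stack(tensors, axis=0):
        return np.stack(tensors, axis=axis)

    @staticmethod
    def unstack(tensor, axis=0):
        return [np.squeeze(t, axis) for t in np.split(tensor, tensor.shape[axis], axis)]

x = DummyTensorProcessor.stack([np.zeros(2), np.ones(2)])
print(x.shape, DummyTensorProcessor.unstack(x)[0].shape)  # (2, 2) (2,)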