Skip to content

Commit

Permalink
reduction_axes -> channel_axis for bc/fbc
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Sep 26, 2023
1 parent ef99ed4 commit b3fb6b2
Show file tree
Hide file tree
Showing 10 changed files with 18 additions and 25 deletions.
4 changes: 2 additions & 2 deletions nncf/experimental/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from abc import abstractmethod
from collections import defaultdict
from collections import deque
from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union

from nncf.common.tensor import TensorType
from nncf.common.tensor_statistics.collectors import NNCFCollectorTensorProcessor
Expand Down Expand Up @@ -398,7 +398,7 @@ def get_tensor_collector_inputs(
return target_inputs

@staticmethod
def _build_statistic_container(statistic_container_cls: TensorStatistic, kwargs: Dict[Any, Any]):
def _build_statistic_container(statistic_container_cls: Type[TensorStatistic], kwargs: Dict[Any, Any]):
if issubclass(statistic_container_cls, MinMaxTensorStatistic):
return statistic_container_cls(
min_values=kwargs[MinMaxTensorStatistic.MIN_STAT], max_values=kwargs[MinMaxTensorStatistic.MAX_STAT]
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/bias_correction/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin
TargetType.POST_LAYER_OPERATION, node_name, port_id=OUTPUT_PORT_OF_NODE
)
stat_collector = self._backend_entity.mean_statistic_collector(
reduction_axes=channel_axis, num_samples=self.subset_size, inplace=self.inplace_statistics
channel_axis=channel_axis, num_samples=self.subset_size, inplace=self.inplace_statistics
)
statistic_container.add_statistic_point(
StatisticPoint(
Expand Down
5 changes: 2 additions & 3 deletions nncf/quantization/algorithms/bias_correction/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
from nncf.common.utils.registry import Registry

Expand Down Expand Up @@ -87,15 +86,15 @@ def output_insertion_command(nncf_graph: NNCFGraph, target_point: TargetPoint) -
@staticmethod
@abstractmethod
def mean_statistic_collector(
reduction_axes: ReductionAxes,
channel_axis: int,
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
) -> TensorStatisticCollectorBase:
"""
Returns backend-specific mean statistic collector.
:param reduction_axes: Channel axis for the statistics aggregation.
:param channel_axis: Channel axis for the statistics aggregation.
:param inplace: Whether to calculate statistic inplace or not.
:param num_samples: Maximum number of samples to collect.
:param window_size: The maximum size of the samples queue.
Expand Down
5 changes: 2 additions & 3 deletions nncf/quantization/algorithms/bias_correction/onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.common.utils.backend import BackendType
from nncf.onnx.graph.model_utils import remove_fq_from_inputs
from nncf.onnx.graph.node_utils import get_bias_value
Expand Down Expand Up @@ -77,12 +76,12 @@ def output_insertion_command(nncf_graph: NNCFGraph, target_point: ONNXTargetPoin

@staticmethod
def mean_statistic_collector(
reduction_axes: ReductionAxes,
channel_axis: int,
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
) -> ONNXMeanStatisticCollector:
return ONNXMeanStatisticCollector(reduction_axes, num_samples, window_size)
return ONNXMeanStatisticCollector(channel_axis, num_samples, window_size)

@staticmethod
def raw_statistic_collector(inplace: bool, num_samples: int = None) -> ONNXMeanStatisticCollector:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.common.utils.backend import BackendType
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.openvino.graph.metatypes.groups import FAKE_QUANTIZE_OPERATIONS
Expand Down Expand Up @@ -65,12 +64,12 @@ def output_insertion_command(nncf_graph: NNCFGraph, target_point: OVTargetPoint)

@staticmethod
def mean_statistic_collector(
reduction_axes: ReductionAxes,
channel_axis: int,
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
) -> TensorCollector:
return get_mean_stat_collector(num_samples, reduction_axes, window_size, inplace)
return get_mean_stat_collector(num_samples, channel_axis, window_size, inplace)

@staticmethod
def raw_statistic_collector(inplace: bool, num_samples: int = None) -> TensorCollector:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def _add_statistic_point(self, container: StatisticPointsContainer, point: Targe
:param axis: Channel axis for the statistics calculation.
"""
stat_collector = self._backend_entity.mean_statistic_collector(
reduction_axes=axis, num_samples=self.subset_size, inplace=self.inplace_statistics
channel_axis=axis, num_samples=self.subset_size, inplace=self.inplace_statistics
)
container.add_statistic_point(
StatisticPoint(target_point=point, tensor_collector=stat_collector, algorithm=self._algorithm_key)
Expand Down
5 changes: 2 additions & 3 deletions nncf/quantization/algorithms/fast_bias_correction/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
from nncf.common.utils.registry import Registry

Expand Down Expand Up @@ -79,15 +78,15 @@ def model_extraction_command(inputs: List[str], outputs: List[str]) -> Transform
@staticmethod
@abstractmethod
def mean_statistic_collector(
reduction_axes: ReductionAxes,
channel_axis: int,
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
) -> TensorStatisticCollectorBase:
"""
Returns backend-specific mean statistic collector.
:param reduction_axes: Channel axes for the statistics aggregation.
:param channel_axis: Channel axes for the statistics aggregation.
:param inplace: Whether to calculate statistic inplace or not.
:param num_samples: Maximum number of samples to collect.
:param window_size: The maximum size of the samples queue.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.common.utils.backend import BackendType
from nncf.onnx.graph.node_utils import get_bias_value
from nncf.onnx.graph.node_utils import is_any_weight_quantized
Expand Down Expand Up @@ -64,12 +63,12 @@ def model_extraction_command(inputs: List[str], outputs: List[str]) -> ONNXModel

@staticmethod
def mean_statistic_collector(
reduction_axes: ReductionAxes,
channel_axis: int,
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
) -> ONNXMeanStatisticCollector:
return ONNXMeanStatisticCollector(reduction_axes, num_samples, window_size)
return ONNXMeanStatisticCollector(channel_axis, num_samples, window_size)

@staticmethod
def get_sub_input_output_names(subgraph: onnx.ModelProto) -> Tuple[str, str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.common.utils.backend import BackendType
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.openvino.graph.metatypes.groups import FAKE_QUANTIZE_OPERATIONS
Expand Down Expand Up @@ -56,12 +55,12 @@ def model_extraction_command(inputs: List[str], outputs: List[str]) -> OVModelEx

@staticmethod
def mean_statistic_collector(
reduction_axes: ReductionAxes,
channel_axis: int,
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
) -> TensorCollector:
return get_mean_stat_collector(num_samples, reduction_axes, window_size, inplace)
return get_mean_stat_collector(num_samples, channel_axis, window_size, inplace)

@staticmethod
def get_sub_input_output_names(subgraph: ov.Model) -> Tuple[str, str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from nncf.common.graph import NNCFNode
from nncf.common.graph.definitions import NNCFGraphNodeType
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.common.utils.backend import BackendType
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.quantization.algorithms.fast_bias_correction.backend import ALGO_BACKENDS
Expand Down Expand Up @@ -68,12 +67,12 @@ def model_extraction_command(inputs: List[str], outputs: List[str]) -> PTModelEx

@staticmethod
def mean_statistic_collector(
reduction_axes: ReductionAxes,
channel_axis: int,
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
) -> TensorCollector:
return get_mean_statisitic_collector(num_samples, reduction_axes, window_size)
return get_mean_statisitic_collector(num_samples, channel_axis, window_size)

@staticmethod
def get_sub_input_output_names(subgraph: NNCFNetwork) -> Tuple[str, str]:
Expand Down

0 comments on commit b3fb6b2

Please sign in to comment.