diff --git a/nncf/quantization/algorithms/smooth_quant/algorithm.py b/nncf/quantization/algorithms/smooth_quant/algorithm.py index d8c16bb817d..f0fb907d244 100644 --- a/nncf/quantization/algorithms/smooth_quant/algorithm.py +++ b/nncf/quantization/algorithms/smooth_quant/algorithm.py @@ -243,7 +243,7 @@ def _get_statistics_for_node( statistics_for_node = [] for tensor_collector in statistic_points.get_algo_statistics_for_node( node_name, - self._backend_entity.get_filter_fn_for_statistics(act_port), + self._backend_entity.get_filter_fn_for_statistics(act_port, self._algorithm_key), self._algorithm_key, ): statistic = tensor_collector.get_statistics()[STATISTIC_BRANCH_KEY] diff --git a/nncf/quantization/algorithms/smooth_quant/backend.py b/nncf/quantization/algorithms/smooth_quant/backend.py index 57440e1f371..e7ffec9c098 100644 --- a/nncf/quantization/algorithms/smooth_quant/backend.py +++ b/nncf/quantization/algorithms/smooth_quant/backend.py @@ -207,9 +207,11 @@ def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: @staticmethod @abstractmethod - def get_filter_fn_for_statistics(activation_port_id: int) -> Callable[[StatisticPoint], bool]: + def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) -> Callable[[StatisticPoint], bool]: """ Returns backend-specific callable to filter statistic containers according to its statistic point. :param activation_port_id: Activation port id for the statistic collection target node. + :param algorithm_key: Current algorithm key. + :return: Backend-specific callable to filter statistic containers according to its statistic point. """ diff --git a/nncf/quantization/algorithms/smooth_quant/openvino_backend.py b/nncf/quantization/algorithms/smooth_quant/openvino_backend.py index a02d080a149..6ba53666a9b 100644 --- a/nncf/quantization/algorithms/smooth_quant/openvino_backend.py +++ b/nncf/quantization/algorithms/smooth_quant/openvino_backend.py @@ -38,6 +38,8 @@ from nncf.openvino.statistics.collectors import OVNNCFCollectorTensorProcessor from nncf.quantization.algorithms.smooth_quant.backend import SmoothQuantAlgoBackend +OV_PRE_LAYER_TARGET_TYPE = TargetType.PRE_LAYER_OPERATION + class OVSmoothQuantAlgoBackend(SmoothQuantAlgoBackend): @property @@ -54,7 +56,7 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]: @staticmethod def pre_layer_target_type() -> TargetType: - return TargetType.PRE_LAYER_OPERATION + return OV_PRE_LAYER_TARGET_TYPE @staticmethod def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint: @@ -155,8 +157,12 @@ def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: return len(nncf_graph.get_next_nodes(weight_node)) > 1 @staticmethod - def get_filter_fn_for_statistics(activation_port_id: int) -> Callable[[StatisticPoint], bool]: + def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) -> Callable[[StatisticPoint], bool]: def filter_func(point: StatisticPoint) -> bool: - return point.target_point.port_id == activation_port_id + return ( + algorithm_key in point.algorithm_to_tensor_collectors + and point.target_point.type == OV_PRE_LAYER_TARGET_TYPE + and point.target_point.port_id == activation_port_id + ) return filter_func diff --git a/nncf/quantization/algorithms/smooth_quant/torch_backend.py b/nncf/quantization/algorithms/smooth_quant/torch_backend.py index bb4dc706d71..0d56a16cdda 100644 --- a/nncf/quantization/algorithms/smooth_quant/torch_backend.py +++ b/nncf/quantization/algorithms/smooth_quant/torch_backend.py @@ -46,6 +46,9 @@ def forward(self, x): return torch.mul(x, self._scale_value) +PT_PRE_LAYER_TARGET_TYPE = TargetType.OPERATOR_PRE_HOOK + + class PTSmoothQuantAlgoBackend(SmoothQuantAlgoBackend): @property def convolution_metatypes(self) -> List[OperatorMetatype]: @@ -65,7 +68,7 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]: @staticmethod def pre_layer_target_type() -> TargetType: - return TargetType.OPERATOR_PRE_HOOK + return PT_PRE_LAYER_TARGET_TYPE @staticmethod def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: @@ -73,14 +76,14 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) - @staticmethod def is_node_with_weights(node: NNCFNode) -> bool: - # Metatypes of matmuls and convolutions guarantee + # Metatypes of linears and convolutions guarantee # all nodes with the metatypes have weights, we can skip # this check by returning True. return True @staticmethod def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int: - # Metatypes of matmuls and convolutions guarantee + # Metatypes of linears and convolutions guarantee # all nodes with the metatypes have 0 activation port id return 0 @@ -127,9 +130,7 @@ def scale_insertion_command( input_port_id = 0 target_points = [] for node in nodes: - target_points.append( - PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, node.node_name, input_port_id=input_port_id) - ) + target_points.append(PTTargetPoint(PT_PRE_LAYER_TARGET_TYPE, node.node_name, input_port_id=input_port_id)) return PTSharedFnInsertionCommand(target_points, SQMultiply(scale_value), scale_node_name) @@ -150,8 +151,12 @@ def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: return node.is_shared() @staticmethod - def get_filter_fn_for_statistics(activation_port_id: int) -> Callable[[StatisticPoint], bool]: + def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) -> Callable[[StatisticPoint], bool]: def filter_func(point: StatisticPoint) -> bool: - return point.target_point.input_port_id == activation_port_id + return ( + algorithm_key in point.algorithm_to_tensor_collectors + and point.target_point.type == PT_PRE_LAYER_TARGET_TYPE + and point.target_point.input_port_id == activation_port_id + ) return filter_func