Commit fc653c1

daniil-lyakhov committed Jan 25, 2024
1 parent 017661c commit fc653c1

Showing 4 changed files with 26 additions and 13 deletions.
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/smooth_quant/algorithm.py
@@ -243,7 +243,7 @@ def _get_statistics_for_node(
         statistics_for_node = []
         for tensor_collector in statistic_points.get_algo_statistics_for_node(
             node_name,
-            self._backend_entity.get_filter_fn_for_statistics(act_port),
+            self._backend_entity.get_filter_fn_for_statistics(act_port, self._algorithm_key),
             self._algorithm_key,
         ):
             statistic = tensor_collector.get_statistics()[STATISTIC_BRANCH_KEY]
4 changes: 3 additions & 1 deletion nncf/quantization/algorithms/smooth_quant/backend.py
@@ -207,9 +207,11 @@ def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:

     @staticmethod
     @abstractmethod
-    def get_filter_fn_for_statistics(activation_port_id: int) -> Callable[[StatisticPoint], bool]:
+    def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) -> Callable[[StatisticPoint], bool]:
         """
         Returns backend-specific callable to filter statistic containers according to its statistic point.

+        :param activation_port_id: Activation port id for the statistic collection target node.
+        :param algorithm_key: Current algorithm key.
         :return: Backend-specific callable to filter statistic containers according to its statistic point.
         """
12 changes: 9 additions & 3 deletions nncf/quantization/algorithms/smooth_quant/openvino_backend.py
@@ -38,6 +38,8 @@
 from nncf.openvino.statistics.collectors import OVNNCFCollectorTensorProcessor
 from nncf.quantization.algorithms.smooth_quant.backend import SmoothQuantAlgoBackend

+OV_PRE_LAYER_TARGET_TYPE = TargetType.PRE_LAYER_OPERATION
+

 class OVSmoothQuantAlgoBackend(SmoothQuantAlgoBackend):
     @property
@@ -54,7 +56,7 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]:

     @staticmethod
     def pre_layer_target_type() -> TargetType:
-        return TargetType.PRE_LAYER_OPERATION
+        return OV_PRE_LAYER_TARGET_TYPE

     @staticmethod
     def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint:
@@ -155,8 +157,12 @@ def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
         return len(nncf_graph.get_next_nodes(weight_node)) > 1

     @staticmethod
-    def get_filter_fn_for_statistics(activation_port_id: int) -> Callable[[StatisticPoint], bool]:
+    def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) -> Callable[[StatisticPoint], bool]:
         def filter_func(point: StatisticPoint) -> bool:
-            return point.target_point.port_id == activation_port_id
+            return (
+                algorithm_key in point.algorithm_to_tensor_collectors
+                and point.target_point.type == OV_PRE_LAYER_TARGET_TYPE
+                and point.target_point.port_id == activation_port_id
+            )

         return filter_func
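For illustration, a minimal self-contained sketch of what the stricter filter above accepts and rejects. The SimpleNamespace objects and the "SmoothQuant"/"MinMax" keys are stand-ins for NNCF's StatisticPoint, OVTargetPoint, and real algorithm keys; only the attribute names and the boolean logic mirror filter_func.

# Toy stand-ins for StatisticPoint / OVTargetPoint; attribute names mirror filter_func above.
from types import SimpleNamespace

def make_filter(activation_port_id, algorithm_key, pre_layer_type):
    def filter_func(point):
        return (
            algorithm_key in point.algorithm_to_tensor_collectors
            and point.target_point.type == pre_layer_type
            and point.target_point.port_id == activation_port_id
        )
    return filter_func

pre_layer = "PRE_LAYER_OPERATION"  # stand-in for TargetType.PRE_LAYER_OPERATION
own_point = SimpleNamespace(
    algorithm_to_tensor_collectors={"SmoothQuant": []},
    target_point=SimpleNamespace(type=pre_layer, port_id=0),
)
foreign_point = SimpleNamespace(
    algorithm_to_tensor_collectors={"MinMax": []},  # collected by another algorithm
    target_point=SimpleNamespace(type=pre_layer, port_id=0),
)

filter_fn = make_filter(activation_port_id=0, algorithm_key="SmoothQuant", pre_layer_type=pre_layer)
assert filter_fn(own_point) is True        # same algorithm key, pre-layer target, matching port
assert filter_fn(foreign_point) is False   # rejected: statistics belong to another algorithm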
21 changes: 13 additions & 8 deletions nncf/quantization/algorithms/smooth_quant/torch_backend.py
@@ -46,6 +46,9 @@ def forward(self, x):
         return torch.mul(x, self._scale_value)


+PT_PRE_LAYER_TARGET_TYPE = TargetType.OPERATOR_PRE_HOOK
+
+
 class PTSmoothQuantAlgoBackend(SmoothQuantAlgoBackend):
     @property
     def convolution_metatypes(self) -> List[OperatorMetatype]:
@@ -65,22 +68,22 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]:

     @staticmethod
     def pre_layer_target_type() -> TargetType:
-        return TargetType.OPERATOR_PRE_HOOK
+        return PT_PRE_LAYER_TARGET_TYPE

     @staticmethod
     def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint:
         return PTTargetPoint(target_type, target_node_name, input_port_id=port_id)

     @staticmethod
     def is_node_with_weights(node: NNCFNode) -> bool:
-        # Metatypes of matmuls and convolutions guarantee
+        # Metatypes of linears and convolutions guarantee
         # all nodes with the metatypes have weights, we can skip
         # this check by returning True.
         return True

     @staticmethod
     def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
-        # Metatypes of matmuls and convolutions guarantee
+        # Metatypes of linears and convolutions guarantee
         # all nodes with the metatypes have 0 activation port id
         return 0

@@ -127,9 +130,7 @@ def scale_insertion_command(
         input_port_id = 0
         target_points = []
         for node in nodes:
-            target_points.append(
-                PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, node.node_name, input_port_id=input_port_id)
-            )
+            target_points.append(PTTargetPoint(PT_PRE_LAYER_TARGET_TYPE, node.node_name, input_port_id=input_port_id))

         return PTSharedFnInsertionCommand(target_points, SQMultiply(scale_value), scale_node_name)

@@ -150,8 +151,12 @@ def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
         return node.is_shared()

     @staticmethod
-    def get_filter_fn_for_statistics(activation_port_id: int) -> Callable[[StatisticPoint], bool]:
+    def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) -> Callable[[StatisticPoint], bool]:
         def filter_func(point: StatisticPoint) -> bool:
-            return point.target_point.input_port_id == activation_port_id
+            return (
+                algorithm_key in point.algorithm_to_tensor_collectors
+                and point.target_point.type == PT_PRE_LAYER_TARGET_TYPE
+                and point.target_point.input_port_id == activation_port_id
+            )

         return filter_func
