Skip to content

Commit

Permalink
Constant filtering and shape of removal passes are updated
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Nov 16, 2023
1 parent 6c4d598 commit 63e668d
Show file tree
Hide file tree
Showing 19 changed files with 334 additions and 83 deletions.
1 change: 1 addition & 0 deletions nncf/common/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class OperatorMetatype:
hw_config_names: List[str] = []
output_channel_axis: Optional[int] = None
ignored_input_ports: List[int] = []
input_edges_num_expected = None

@classmethod
def get_all_aliases(cls) -> List[str]:
Expand Down
2 changes: 2 additions & 0 deletions nncf/common/utils/dot_file_rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def read_dot_graph(path: pathlib.Path) -> nx.DiGraph:
def _maybe_escape_colons_in_attrs(data: Dict):
for attr_name in data:
attr_val = data[attr_name]
if not isinstance(attr_val, str):
continue
if RESERVED_CHAR in attr_val and not (attr_val[0] == '"' or attr_val[-1] == '"'):
data[attr_name] = '"' + data[attr_name] + '"' # escaped colons are allowed

Expand Down
14 changes: 6 additions & 8 deletions nncf/onnx/graph/metatypes/onnx_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,23 +266,21 @@ class ONNXConcatMetatype(ONNXOpMetatype):
class ONNXBatchNormMetatype(ONNXOpMetatype):
name = "BatchNormalizationOp"
op_names = ["BatchNormalization"]
input_edges_num_expected = 2
input_edges_num_expected = 5


@ONNX_OPERATION_METATYPES.register()
class ONNXResizeMetatype(ONNXOpMetatype):
name = "ResizeOp"
op_names = ["Resize"]
hw_config_names = [HWConfigOpName.INTERPOLATE]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
class ONNXCenterCropPadMetatype(ONNXOpMetatype):
name = "CenterCropPadOp"
op_names = ["CenterCropPad"]
hw_config_names = [HWConfigOpName.CROP]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
Expand Down Expand Up @@ -335,7 +333,6 @@ class ONNXSplitMetatype(ONNXOpMetatype):
name = "SplitOp"
op_names = ["Split"]
hw_config_names = [HWConfigOpName.SPLIT]
input_edges_num_expected = 1


@ONNX_OPERATION_METATYPES.register()
Expand Down Expand Up @@ -367,7 +364,6 @@ class ONNXNotMetatype(ONNXOpMetatype):
name = "NotOp"
op_names = ["Not"]
hw_config_names = [HWConfigOpName.LOGICALNOT]
input_edges_num_expected = 1


@ONNX_OPERATION_METATYPES.register()
Expand Down Expand Up @@ -407,7 +403,6 @@ class ONNXFloorMetatype(ONNXOpMetatype):
name = "FloorOp"
op_names = ["Floor"]
hw_config_names = [HWConfigOpName.FLOORMOD]
input_edges_num_expected = 1


@ONNX_OPERATION_METATYPES.register()
Expand All @@ -423,7 +418,6 @@ class ONNXSqrtMetatype(ONNXOpMetatype):
name = "SqrtOp"
op_names = ["Sqrt"]
hw_config_names = [HWConfigOpName.POWER]
input_edges_num_expected = 1


@ONNX_OPERATION_METATYPES.register()
Expand Down Expand Up @@ -457,7 +451,6 @@ class ONNXLogMetatype(ONNXOpMetatype):
class ONNXAbsMetatype(ONNXOpMetatype):
name = "AbsOp"
op_names = ["Abs"]
input_edges_num_expected = 1


@ONNX_OPERATION_METATYPES.register()
Expand All @@ -482,25 +475,29 @@ class ONNXScatterNDMetatype(ONNXOpMetatype):
class ONNXRoiAlignMetatype(ONNXOpMetatype):
name = "RoiAlignOp"
op_names = ["RoiAlign"]
input_edges_num_expected = 3


@ONNX_OPERATION_METATYPES.register()
class ONNXGatherMetatype(ONNXOpMetatype):
name = "GatherOp"
op_names = ["Gather"]
subtypes = [ONNXEmbeddingMetatype]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
class ONNXGatherNDMetatype(ONNXOpMetatype):
name = "GatherNDOp"
op_names = ["GatherND"]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
class ONNXGatherElementsMetatype(ONNXOpMetatype):
name = "GatherElementsOp"
op_names = ["GatherElements"]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
Expand All @@ -521,6 +518,7 @@ class ONNXSqueezeMetatype(ONNXOpMetatype):
class ONNXNonMaxSuppressionMetatype(ONNXOpMetatype):
name = "NonMaxSuppressionOp"
op_names = ["NonMaxSuppression"]
# input_edges_num_expected = from 2 to 5


@ONNX_OPERATION_METATYPES.register()
Expand Down
5 changes: 4 additions & 1 deletion nncf/onnx/graph/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
self.weight_attrs = weight_attrs if weight_attrs is not None else {}
self.bias_attrs = bias_attrs if bias_attrs is not None else {}
self.node_attrs = node_attrs if node_attrs is not None else {}
self.layer_attributes = layer_attributes
self._layer_attributes = layer_attributes

def has_weight(self) -> bool:
return bool(self.weight_attrs)
Expand All @@ -78,6 +78,9 @@ def has_bias(self) -> bool:
def has_node_attrs(self) -> bool:
return bool(self.node_attrs)

def get_backend_agnostic_attributes(self) -> BaseLayerAttributes:
return self._layer_attributes


def get_constant_weight_port_ids(metatype: ONNXOpMetatype) -> List[int]:
"""
Expand Down
Loading

0 comments on commit 63e668d

Please sign in to comment.