Skip to content

Commit

Permalink
[MultiThreshold] Replace default data_layout by fallback in execute_node
Browse files Browse the repository at this point in the history
Note: Only covers data layouts for tensors with less than five axes
  • Loading branch information
iksnagreb committed Oct 24, 2024
1 parent 9d73e16 commit c0b4534
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/qonnx/custom_op/general/multithreshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ def get_nodeattr_types(self):
"out_dtype": ("s", True, ""),
"out_scale": ("f", False, 1.0),
"out_bias": ("f", False, 0.0),
# fmt: off
"data_layout": ("s", False, "NCHW")
# fmt: on
"data_layout": ("s", False, ""),
}

def make_shape_compatible_op(self, model):
Expand Down Expand Up @@ -129,6 +127,13 @@ def execute_node(self, context, graph):
# accepted by the multithreshold function above, i.e, the channel
# dimension is along the axis with index 1.
data_layout = self.get_nodeattr("data_layout")
# If there is no layout annotation, guess based on rank of the
# tensor
if not data_layout and len(v.shape) < 5:
# Maps tensor rank to layout annotation
rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"}
# Lookup the layout required by this input shape
data_layout = rank_to_layout[len(v.shape)]
# Lookup the index of the channel dimension in the data layout
# Note: Assumes there is at most one "C" which denotes the channel
# dimension
Expand Down

0 comments on commit c0b4534

Please sign in to comment.