diff --git a/src/qonnx/custom_op/general/multithreshold.py b/src/qonnx/custom_op/general/multithreshold.py index 708840e..bcf0731 100644 --- a/src/qonnx/custom_op/general/multithreshold.py +++ b/src/qonnx/custom_op/general/multithreshold.py @@ -92,9 +92,7 @@ def get_nodeattr_types(self): "out_dtype": ("s", True, ""), "out_scale": ("f", False, 1.0), "out_bias": ("f", False, 0.0), - # fmt: off - "data_layout": ("s", False, "NCHW") - # fmt: on + "data_layout": ("s", False, ""), } def make_shape_compatible_op(self, model): @@ -129,6 +127,13 @@ def execute_node(self, context, graph): # accepted by the multithreshold function above, i.e, the channel # dimension is along the axis with index 1. data_layout = self.get_nodeattr("data_layout") + # If there is no layout annotation, guess based on rank of the + # tensor + if not data_layout and len(v.shape) < 5: + # Maps tensor rank to layout annotation + rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"} + # Lookup the layout required by this input shape + data_layout = rank_to_layout[len(v.shape)] # Lookup the index of the channel dimension in the data layout # Note: Assumes there is at most one "C" which denotes the channel # dimension