diff --git a/tests/transformation/test_remove_identity_ops.py b/tests/transformation/test_remove_identity_ops.py index ed34ffe..d9e92c7 100644 --- a/tests/transformation/test_remove_identity_ops.py +++ b/tests/transformation/test_remove_identity_ops.py @@ -51,25 +51,30 @@ def insert_identity_op(model, op, as_first_node, approx): val = np.asarray([zero_val], dtype=np.float32) elif op in ["Mul", "Div"]: val = np.asarray([one_val], dtype=np.float32) + elif op in ["Identity"]: + val = None else: return graph = model.graph + if val is None: + inplist = ["inp" if as_first_node else "div_out"] + else: + model.set_initializer("value", val) + inplist = ["inp" if as_first_node else "div_out", "value"] + identity_node = helper.make_node(op, inplist, ["ident_out"]) if as_first_node: - identity_node = helper.make_node(op, ["inp", "value"], ["ident_out"]) graph.node.insert(0, identity_node) graph.node[1].input[0] = "ident_out" else: - identity_node = helper.make_node(op, ["div_out", "value"], ["ident_out"]) graph.node.insert(3, identity_node) graph.node[-1].input[0] = "ident_out" - model.set_initializer("value", val) return model # identity operations to be inserted -@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div"]) +@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div", "Identity"]) @pytest.mark.parametrize("approx", [False, True]) @pytest.mark.parametrize("as_first_node", [False, True]) def test_remove_identity_ops(op, as_first_node, approx):