Skip to content

Commit

Permalink
[Test] add Identity op case to test_remove_identity_ops
Browse files Browse the repository at this point in the history
  • Loading branch information
maltanar committed Sep 12, 2024
1 parent 8bad7e7 commit 71ee780
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions tests/transformation/test_remove_identity_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,30 @@ def insert_identity_op(model, op, as_first_node, approx):
val = np.asarray([zero_val], dtype=np.float32)
elif op in ["Mul", "Div"]:
val = np.asarray([one_val], dtype=np.float32)
elif op in ["Identity"]:
val = None
else:
return

graph = model.graph
if val is None:
inplist = ["inp" if as_first_node else "div_out"]
else:
model.set_initializer("value", val)
inplist = ["inp" if as_first_node else "div_out", "value"]
identity_node = helper.make_node(op, inplist, ["ident_out"])
if as_first_node:
identity_node = helper.make_node(op, ["inp", "value"], ["ident_out"])
graph.node.insert(0, identity_node)
graph.node[1].input[0] = "ident_out"
else:
identity_node = helper.make_node(op, ["div_out", "value"], ["ident_out"])
graph.node.insert(3, identity_node)
graph.node[-1].input[0] = "ident_out"
model.set_initializer("value", val)

return model


# identity operations to be inserted
@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div"])
@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div", "Identity"])
@pytest.mark.parametrize("approx", [False, True])
@pytest.mark.parametrize("as_first_node", [False, True])
def test_remove_identity_ops(op, as_first_node, approx):
Expand Down

0 comments on commit 71ee780

Please sign in to comment.