From 1be30802b30abe9e0ce4365cc8768280110a4b25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 29 Sep 2022 15:11:56 +0200 Subject: [PATCH] Test eval of variables created by multiple output and default output `Op`s --- tests/graph/rewriting/test_unify.py | 80 +++++++++++++++++++++-------- 1 file changed, 60 insertions(+), 20 deletions(-) diff --git a/tests/graph/rewriting/test_unify.py b/tests/graph/rewriting/test_unify.py index 5debd7ec30..2d5dce0e29 100644 --- a/tests/graph/rewriting/test_unify.py +++ b/tests/graph/rewriting/test_unify.py @@ -102,7 +102,7 @@ def test_cons(): assert cdr_res == [atype_at.dtype, atype_at.shape] -def test_etuples(): +def test_etuples_Op(): x_at = at.vector("x") y_at = at.vector("y") @@ -116,10 +116,18 @@ def test_etuples(): w_at = etuple(at.add, x_at, y_at) assert isinstance(w_at, OpExpressionTuple) + w_at._evaled_obj = w_at.null res = w_at.evaled_obj assert res.owner.op == at.add assert res.owner.inputs == [x_at, y_at] + +def test_etuples_atomic_Op(): + x_at = at.vector("x") + y_at = at.vector("y") + + z_at = etuple(x_at, y_at) + # This `Op` doesn't expand into an `etuple` (i.e. it's "atomic") op1_np = CustomOpNoProps(1) @@ -127,14 +135,20 @@ def test_etuples(): assert res.owner.op == op1_np q_at = op1_np(x_at, y_at) - res = etuplize(q_at) - assert isinstance(res, ExpressionTuple) - assert isinstance(res[0], CustomOpNoProps) - assert res[0] == op1_np + q_et = etuplize(q_at) + assert isinstance(q_et, ExpressionTuple) + assert isinstance(q_et[0], CustomOpNoProps) + assert q_et[0] == op1_np + + q_et._evaled_obj = q_et.null + res = q_et.evaled_obj + assert isinstance(res.owner.op, CustomOpNoProps) with pytest.raises(TypeError): etuplize(op1_np) + +def test_etuples_multioutput_Op(): class MyMultiOutOp(Op): def make_node(self, *inputs): outputs = [MyType()(), MyType()()] @@ -151,25 +165,51 @@ def perform(self, node, inputs, outputs): assert res[0].owner.op == op1_np assert res[1].owner.op == op1_np - mu_at = at.scalar("mu") - sigma_at = at.scalar("sigma") + # If we etuplize one of the outputs we should recover this output when + # evaluating + q_at, _ = op1_np(x_at) + q_et = etuplize(q_at) + q_et._evaled_obj = q_et.null + res = q_et.evaled_obj + assert res == q_at - w_rv = at.random.normal(mu_at, sigma_at) - w_at = etuplize(w_rv) - assert isinstance(w_at, ExpressionTuple) - assert w_at[1] == 1 - assert isinstance(w_at[2], OpExpressionTuple) + # If the caller etuplizes the output list, it should recover the + # list when evaluating. + q_at = op1_np(x_at) + q_et = etuplize(q_at) + q_et._evaled_obj = q_et.null + res = q_et.evaled_obj - z_at = etuple(at.random.normal, mu_at, sigma_at) - assert isinstance(z_at, OpExpressionTuple) - z_at = etuple(at.random.normal, *w_at[2][1:]) - assert isinstance(z_at, OpExpressionTuple) +def test_etuples_default_output_op(): + class MyDefaultOutputOp(Op): + default_output = 1 - res = z_at.evaled_obj - assert len(res) == 2 - assert res[1].owner.op == at.random.normal - assert res[1].owner.inputs[-2:] == [mu_at, sigma_at] + def make_node(self, *inputs): + outputs = [MyType()(), MyType()()] + return Apply(self, list(inputs), outputs) + + def perform(self, node, inputs, outputs): + outputs[0] = np.array(np.array(inputs[0])) + outputs[1] = np.array(np.array(inputs[1])) + + x_at = at.vector("x") + y_at = at.vector("y") + op1_np = MyDefaultOutputOp() + res = apply(op1_np, etuple(x_at, y_at)) + assert res.owner.op == op1_np + assert res.owner.inputs[0] == x_at + assert res.owner.inputs[1] == y_at + + # We should recover the default output when evaluting its etuplized + # counterpart. + q_at = op1_np(x_at, y_at) + q_et = etuplize(q_at) + q_et._evaled_obj = q_et.null + res = q_et.evaled_obj + assert isinstance(res.owner.op, MyDefaultOutputOp) + assert res.owner.inputs[0] == x_at + assert res.owner.inputs[1] == y_at def test_unify_Variable():