Skip to content

Commit

Permalink
Test eval of variables created by multiple output and default output …
Browse files Browse the repository at this point in the history
…`Op`s
  • Loading branch information
rlouf committed Sep 29, 2022
1 parent 5efbbde commit 1be3080
Showing 1 changed file with 60 additions and 20 deletions.
80 changes: 60 additions & 20 deletions tests/graph/rewriting/test_unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_cons():
assert cdr_res == [atype_at.dtype, atype_at.shape]


def test_etuples():
def test_etuples_Op():
x_at = at.vector("x")
y_at = at.vector("y")

Expand All @@ -116,25 +116,39 @@ def test_etuples():
w_at = etuple(at.add, x_at, y_at)
assert isinstance(w_at, OpExpressionTuple)

w_at._evaled_obj = w_at.null
res = w_at.evaled_obj
assert res.owner.op == at.add
assert res.owner.inputs == [x_at, y_at]


def test_etuples_atomic_Op():
x_at = at.vector("x")
y_at = at.vector("y")

z_at = etuple(x_at, y_at)

# This `Op` doesn't expand into an `etuple` (i.e. it's "atomic")
op1_np = CustomOpNoProps(1)

res = apply(op1_np, z_at)
assert res.owner.op == op1_np

q_at = op1_np(x_at, y_at)
res = etuplize(q_at)
assert isinstance(res, ExpressionTuple)
assert isinstance(res[0], CustomOpNoProps)
assert res[0] == op1_np
q_et = etuplize(q_at)
assert isinstance(q_et, ExpressionTuple)
assert isinstance(q_et[0], CustomOpNoProps)
assert q_et[0] == op1_np

q_et._evaled_obj = q_et.null
res = q_et.evaled_obj
assert isinstance(res.owner.op, CustomOpNoProps)

with pytest.raises(TypeError):
etuplize(op1_np)


def test_etuples_multioutput_Op():
class MyMultiOutOp(Op):
def make_node(self, *inputs):
outputs = [MyType()(), MyType()()]
Expand All @@ -151,25 +165,51 @@ def perform(self, node, inputs, outputs):
assert res[0].owner.op == op1_np
assert res[1].owner.op == op1_np

mu_at = at.scalar("mu")
sigma_at = at.scalar("sigma")
# If we etuplize one of the outputs we should recover this output when
# evaluating
q_at, _ = op1_np(x_at)
q_et = etuplize(q_at)
q_et._evaled_obj = q_et.null
res = q_et.evaled_obj
assert res == q_at

w_rv = at.random.normal(mu_at, sigma_at)
w_at = etuplize(w_rv)
assert isinstance(w_at, ExpressionTuple)
assert w_at[1] == 1
assert isinstance(w_at[2], OpExpressionTuple)
# If the caller etuplizes the output list, it should recover the
# list when evaluating.
q_at = op1_np(x_at)
q_et = etuplize(q_at)
q_et._evaled_obj = q_et.null
res = q_et.evaled_obj

z_at = etuple(at.random.normal, mu_at, sigma_at)
assert isinstance(z_at, OpExpressionTuple)

z_at = etuple(at.random.normal, *w_at[2][1:])
assert isinstance(z_at, OpExpressionTuple)
def test_etuples_default_output_op():
class MyDefaultOutputOp(Op):
default_output = 1

res = z_at.evaled_obj
assert len(res) == 2
assert res[1].owner.op == at.random.normal
assert res[1].owner.inputs[-2:] == [mu_at, sigma_at]
def make_node(self, *inputs):
outputs = [MyType()(), MyType()()]
return Apply(self, list(inputs), outputs)

def perform(self, node, inputs, outputs):
outputs[0] = np.array(np.array(inputs[0]))
outputs[1] = np.array(np.array(inputs[1]))

x_at = at.vector("x")
y_at = at.vector("y")
op1_np = MyDefaultOutputOp()
res = apply(op1_np, etuple(x_at, y_at))
assert res.owner.op == op1_np
assert res.owner.inputs[0] == x_at
assert res.owner.inputs[1] == y_at

# We should recover the default output when evaluting its etuplized
# counterpart.
q_at = op1_np(x_at, y_at)
q_et = etuplize(q_at)
q_et._evaled_obj = q_et.null
res = q_et.evaled_obj
assert isinstance(res.owner.op, MyDefaultOutputOp)
assert res.owner.inputs[0] == x_at
assert res.owner.inputs[1] == y_at


def test_unify_Variable():
Expand Down

0 comments on commit 1be3080

Please sign in to comment.