Skip to content

Commit

Permalink
Unify Variables that are default outputs with ExpressionTuples
Browse files Browse the repository at this point in the history
The difficulty comes from the fact that the `Variable` is etuplized as
`ExpressionTuple(nth, default_output, oExpressionTuple(...))` but we
cannot expect the caller to use `nth` here since this `default_output`
mechanism is hidden. We expand the latter before unifying.
  • Loading branch information
rlouf committed Oct 5, 2022
1 parent 33d9511 commit 6f67809
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
21 changes: 18 additions & 3 deletions aesara/graph/rewriting/unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,27 @@ def _unify_Constant_Constant(u, v, s):


def _unify_Variable_ExpressionTuple(u, v, s):
# `Constant`s are "atomic"
"""Unify a `Variable` with an `ExpressionTuple`.
If the `Variable`'s owner only has one output we can etuplize the `Variable`
and unify both expression tuple.
If the owner has multiple outputs, but the `Op`'s `default_output` is not
`None` we unify the etuplized version of the `Variable` with an expanded
expression tuple that account for the variable selection. We only do this
for nodes with a default output (we otherwise expect the caller to use
the `nth` operator in the expression tuple).
"""
if not u.owner:
yield False
return

yield _unify(etuplize(u, shallow=True), v, s)
if u.owner.nout == 1:
yield _unify(etuplize(u, shallow=True), v, s)
elif u.owner.nout == 2 and u.owner.op.default_output is not None:
u_et = etuplize(u)
v_et = etuple(nth, u_et[1], v)
yield _unify(u_et, v_et, s)


_unify.add(
Expand Down
27 changes: 27 additions & 0 deletions tests/graph/rewriting/test_unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,33 @@ def test_unify_Variable():
assert s[b_lv] == y_at


def test_unify_default_output_Variable():
"""Make sure that we can unify with the default output of an Apply node."""

class MyDefaultOutputOp(Op):
default_output = 1

def make_node(self, *inputs):
outputs = [MyType()(), MyType()()]
return Apply(self, list(inputs), outputs)

def perform(self, node, inputs, outputs):
outputs[0] = np.array(np.array(inputs[0]))
outputs[1] = np.array(np.array(inputs[1]))

x_at = at.vector("x")
y_at = at.vector("y")
op1_np = MyDefaultOutputOp()
q_at = op1_np(x_at, y_at)

x_lv, y_lv = var("x"), var("y")
q_et = etuple(op1_np, x_lv, y_lv)

s = unify(q_et, q_at)
assert s[x_lv] == x_at
assert s[y_lv] == y_at


def test_unify_Op():
# These `Op`s expand into `ExpressionTuple`s
op1 = CustomOp(1)
Expand Down

0 comments on commit 6f67809

Please sign in to comment.