diff --git a/aesara/graph/rewriting/basic.py b/aesara/graph/rewriting/basic.py index 3bc826478f..445540fbae 100644 --- a/aesara/graph/rewriting/basic.py +++ b/aesara/graph/rewriting/basic.py @@ -35,7 +35,6 @@ from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op from aesara.graph.rewriting.unify import OpExpressionTuple - from aesara.graph.utils import AssocList, InconsistencyError from aesara.misc.ordered_set import OrderedSet from aesara.utils import flatten diff --git a/aesara/graph/rewriting/unify.py b/aesara/graph/rewriting/unify.py index 65dea4cfd8..c7393c0c00 100644 --- a/aesara/graph/rewriting/unify.py +++ b/aesara/graph/rewriting/unify.py @@ -23,7 +23,7 @@ from unification.utils import transitive_get as walk from unification.variable import Var, isvar, var -from aesara.graph.basic import Constant, Variable +from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.op import Op from aesara.graph.type import Type @@ -89,7 +89,7 @@ def _eval_apply_fn(self, op: Op) -> Callable: def eval_op(*inputs, **kwargs): node = op.make_node(*inputs, **kwargs) - return node.outputs + return node return eval_op @@ -110,9 +110,14 @@ def etuplize_fn_Op(_: Op): return etuple_Op +def nth(idx, node): + return node.outputs[idx] + + +# A Variable is the `nth` output of an Apply node def car_Variable(x): if x.owner: - return x.owner.op + return nth else: raise ConsError("Not a cons pair.") @@ -122,7 +127,7 @@ def car_Variable(x): def cdr_Variable(x): if x.owner: - x_e = etuple(_car(x), *x.owner.inputs, evaled_obj=x) + x_e = etuple(_car(x), x.index, x.owner, evaled_obj=x) else: raise ConsError("Not a cons pair.") @@ -132,6 +137,21 @@ def cdr_Variable(x): _cdr.add((Variable,), cdr_Variable) +def car_Apply(x): + return x.op + + +_car.add((Apply,), car_Apply) + + +def cdr_Apply(x): + x_e = etuple(_car(x), *x.inputs, evaled_obj=x.outputs) + return x_e[1:] + + +_cdr.add((Apply,), cdr_Apply) + + def car_Op(x): if hasattr(x, "__props__"): return type(x) diff --git a/tests/graph/rewriting/test_unify.py b/tests/graph/rewriting/test_unify.py index f57e55c6ad..f96933ea8d 100644 --- a/tests/graph/rewriting/test_unify.py +++ b/tests/graph/rewriting/test_unify.py @@ -114,7 +114,7 @@ def test_etuples(): assert res.owner.inputs == [x_at, y_at] w_at = etuple(at.add, x_at, y_at) - assert isinstance(w_at, ExpressionTuple) + assert isinstance(w_at, OpExpressionTuple) res = w_at.evaled_obj assert len(res) == 1