Skip to content

Commit

Permalink
Variables are the nth output of an Apply node
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Sep 12, 2022
1 parent 36ec81c commit 3d6969c
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 6 deletions.
1 change: 0 additions & 1 deletion aesara/graph/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.rewriting.unify import OpExpressionTuple

from aesara.graph.utils import AssocList, InconsistencyError
from aesara.misc.ordered_set import OrderedSet
from aesara.utils import flatten
Expand Down
28 changes: 24 additions & 4 deletions aesara/graph/rewriting/unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from unification.utils import transitive_get as walk
from unification.variable import Var, isvar, var

from aesara.graph.basic import Constant, Variable
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import Op
from aesara.graph.type import Type

Expand Down Expand Up @@ -89,7 +89,7 @@ def _eval_apply_fn(self, op: Op) -> Callable:

def eval_op(*inputs, **kwargs):
node = op.make_node(*inputs, **kwargs)
return node.outputs
return node

return eval_op

Expand All @@ -110,9 +110,14 @@ def etuplize_fn_Op(_: Op):
return etuple_Op


def nth(idx, node):
return node.outputs[idx]


# A Variable is the `nth` output of an Apply node
def car_Variable(x):
if x.owner:
return x.owner.op
return nth
else:
raise ConsError("Not a cons pair.")

Expand All @@ -122,7 +127,7 @@ def car_Variable(x):

def cdr_Variable(x):
if x.owner:
x_e = etuple(_car(x), *x.owner.inputs, evaled_obj=x)
x_e = etuple(_car(x), x.index, x.owner, evaled_obj=x)
else:
raise ConsError("Not a cons pair.")

Expand All @@ -132,6 +137,21 @@ def cdr_Variable(x):
_cdr.add((Variable,), cdr_Variable)


def car_Apply(x):
return x.op


_car.add((Apply,), car_Apply)


def cdr_Apply(x):
x_e = etuple(_car(x), *x.inputs, evaled_obj=x.outputs)
return x_e[1:]


_cdr.add((Apply,), cdr_Apply)


def car_Op(x):
if hasattr(x, "__props__"):
return type(x)
Expand Down
2 changes: 1 addition & 1 deletion tests/graph/rewriting/test_unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_etuples():
assert res.owner.inputs == [x_at, y_at]

w_at = etuple(at.add, x_at, y_at)
assert isinstance(w_at, ExpressionTuple)
assert isinstance(w_at, OpExpressionTuple)

res = w_at.evaled_obj
assert len(res) == 1
Expand Down

0 comments on commit 3d6969c

Please sign in to comment.