Skip to content

Commit

Permalink
Wrap Ops during etuplization
Browse files Browse the repository at this point in the history
Evaluating etuplized objects fails for some `RandomVariable` ops whose
`__call__` function does not defer to `make_node`.

In this commit we wrap `RandomVariable` ops during etuplization with a
class that always defers `__call__` to `make_node`. We also add a
dispatch rule for `etuplize` so it also wraps `RandomVariable` with the
same class.
  • Loading branch information
rlouf committed Jul 5, 2022
1 parent 77bb152 commit d4d10d3
Showing 1 changed file with 44 additions and 1 deletion.
45 changes: 44 additions & 1 deletion aesara/graph/unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"""

from collections.abc import Mapping
from dataclasses import dataclass
from numbers import Number
from typing import Dict, Optional, Tuple, Union

Expand All @@ -25,6 +26,7 @@
from aesara.graph.basic import Constant, Variable
from aesara.graph.op import Op
from aesara.graph.type import Type
from aesara.tensor.random.op import RandomVariable


def eval_if_etuple(x):
Expand Down Expand Up @@ -72,9 +74,50 @@ def __repr__(self):
return f"ConstrainedVar({repr(self.constraint)}, {self.token})"


@dataclass
class MakeNodeOp:
"""Wrapper around Ops.
Some RandomVariable Ops's `__call__` method does not defer to
`make_node`, and `eval_if_etuple` fails on their etuplized version. To
circumvent this limitation we wrap Ops with this object that always
defers `__call__` to `make_node`.
"""

op: Op

def __call__(self, *args):
return self.op.make_node(*args)


def car_MakeNodeOp(x):
return type(x)


_car.add((MakeNodeOp,), car_MakeNodeOp)


def cdr_MakeNodeOp(x):
x_e = etuple(_car(x), x.op, evaled_obj=x)
return x_e[1:]


_cdr.add((MakeNodeOp,), cdr_MakeNodeOp)


@etuplize.register(RandomVariable)
def etuplize_random(*args, **kwargs):
"""Wrap RandomVariable Ops with a MakeNodeOp object."""
return etuple(MakeNodeOp, etuplize.funcs[(object,)](*args, **kwargs))


def car_Variable(x):
if x.owner:
return x.owner.op
if issubclass(type(x.owner.op), RandomVariable):
return MakeNodeOp(x.owner.op)
else:
return x.owner.op
else:
raise ConsError("Not a cons pair.")

Expand Down

0 comments on commit d4d10d3

Please sign in to comment.