From d4d10d37b8d89ed25f519dc8df002e3fdb3ac7ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 4 Jul 2022 17:42:11 +0200 Subject: [PATCH] Wrap Ops during etuplization Evaluating etuplized objects fails for some `RandomVariable` ops whose `__call__` function does not defer to `make_node`. In this commit we wrap `RandomVariable` ops during etuplization with a class that always defers `__call__` to `make_node`. We also add a dispatch rule for `etuplize` so it also wraps `RandomVariable` with the same class. --- aesara/graph/unify.py | 45 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/aesara/graph/unify.py b/aesara/graph/unify.py index ffa7502a18..1ac5fe30f5 100644 --- a/aesara/graph/unify.py +++ b/aesara/graph/unify.py @@ -11,6 +11,7 @@ """ from collections.abc import Mapping +from dataclasses import dataclass from numbers import Number from typing import Dict, Optional, Tuple, Union @@ -25,6 +26,7 @@ from aesara.graph.basic import Constant, Variable from aesara.graph.op import Op from aesara.graph.type import Type +from aesara.tensor.random.op import RandomVariable def eval_if_etuple(x): @@ -72,9 +74,50 @@ def __repr__(self): return f"ConstrainedVar({repr(self.constraint)}, {self.token})" +@dataclass +class MakeNodeOp: + """Wrapper around Ops. + + Some RandomVariable Ops's `__call__` method does not defer to + `make_node`, and `eval_if_etuple` fails on their etuplized version. To + circumvent this limitation we wrap Ops with this object that always + defers `__call__` to `make_node`. + + """ + + op: Op + + def __call__(self, *args): + return self.op.make_node(*args) + + +def car_MakeNodeOp(x): + return type(x) + + +_car.add((MakeNodeOp,), car_MakeNodeOp) + + +def cdr_MakeNodeOp(x): + x_e = etuple(_car(x), x.op, evaled_obj=x) + return x_e[1:] + + +_cdr.add((MakeNodeOp,), cdr_MakeNodeOp) + + +@etuplize.register(RandomVariable) +def etuplize_random(*args, **kwargs): + """Wrap RandomVariable Ops with a MakeNodeOp object.""" + return etuple(MakeNodeOp, etuplize.funcs[(object,)](*args, **kwargs)) + + def car_Variable(x): if x.owner: - return x.owner.op + if issubclass(type(x.owner.op), RandomVariable): + return MakeNodeOp(x.owner.op) + else: + return x.owner.op else: raise ConsError("Not a cons pair.")