Skip to content

Commit

Permalink
Allow etuplization of RandomVariables
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Sep 7, 2022
1 parent e40c827 commit 53931d3
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 4 deletions.
41 changes: 39 additions & 2 deletions aesara/graph/rewriting/unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,21 @@

from collections.abc import Mapping
from numbers import Number
from typing import Dict, Optional, Tuple, Union
from typing import Callable, Dict, Optional, Tuple, Union

import numpy as np
from cons.core import ConsError, _car, _cdr
from etuples import apply, etuple, etuplize
from etuples.core import ExpressionTuple
from etuples.core import ExpressionTuple, etuple
from etuples.dispatch import etuplize_fn
from unification.core import _unify, assoc
from unification.utils import transitive_get as walk
from unification.variable import Var, isvar, var

from aesara.graph.basic import Constant, Variable
from aesara.graph.op import Op
from aesara.graph.type import Type
from aesara.tensor.random.op import RandomVariable


def eval_if_etuple(x):
Expand Down Expand Up @@ -72,6 +74,41 @@ def __repr__(self):
return f"{type(self).__name__}({repr(self.constraint)}, {self.token})"


class RVExpressionTuple(ExpressionTuple):
r"""Etuple form for `RandomVariables`s.
Some `RandomVariable.__call__` signatures do not match their
`RandomVariable.make_node` signatures, causing `ExpressionTuple.eval_if_etuple` to
fail. To circumvent this limitation we subclass `ExpressionTuple`, and
overload the `_eval_apply` method to use `Op.make_node` instead of
`Op.__call__`.
"""

def _eval_apply_fn(self, op: RandomVariable) -> Callable:
def eval_fn(*inputs, **kwargs):
node = op.make_node(*inputs, **kwargs)
return node.outputs[1]

return eval_fn

def __repr__(self):
return "RV" + super().__repr__()

def __str__(self):
return "rv" + super().__repr__()


@etuple.register(RandomVariable, [object])
def etuple_RandomVariable(*args, **kwargs) -> RVExpressionTuple:
return RVExpressionTuple(args, **kwargs)


@etuplize_fn.register(RandomVariable)
def etuplize_fn_RandomVariable(_: RandomVariable):
return etuple_RandomVariable


def car_Variable(x):
if x.owner:
return x.owner.op
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies:
- numpy>=1.17.0
- scipy>=0.14
- filelock
- etuples
- etuples>=0.3.8
- logical-unification
- miniKanren
- cons
Expand Down
27 changes: 26 additions & 1 deletion tests/graph/rewriting/test_unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
import aesara.tensor as at
from aesara.graph.basic import Apply, Constant, equal_computations
from aesara.graph.op import Op
from aesara.graph.rewriting.unify import ConstrainedVar, convert_strs_to_vars
from aesara.graph.rewriting.unify import (
ConstrainedVar,
RVExpressionTuple,
convert_strs_to_vars,
)
from aesara.tensor.type import TensorType
from tests.graph.utils import MyType

Expand Down Expand Up @@ -110,6 +114,7 @@ def test_etuples():
assert res.owner.inputs == [x_at, y_at]

w_at = etuple(at.add, x_at, y_at)
assert isinstance(w_at, ExpressionTuple)

res = w_at.evaled_obj
assert res.owner.op == at.add
Expand All @@ -123,6 +128,8 @@ def test_etuples():

q_at = op1_np(x_at, y_at)
res = etuplize(q_at)
assert isinstance(res, ExpressionTuple)
assert isinstance(res[0], CustomOpNoProps)
assert res[0] == op1_np

with pytest.raises(TypeError):
Expand All @@ -144,6 +151,24 @@ def perform(self, node, inputs, outputs):
assert res[0].owner.op == op1_np
assert res[1].owner.op == op1_np

mu_at = at.scalar("mu")
sigma_at = at.scalar("sigma")

w_rv = at.random.normal(mu_at, sigma_at)
w_at = etuplize(w_rv)
assert isinstance(w_at, RVExpressionTuple)
assert isinstance(w_at[0], ExpressionTuple)

z_at = etuple(at.random.normal, mu_at, sigma_at)
assert isinstance(z_at, RVExpressionTuple)

z_at = etuple(at.random.normal, *w_at[1:])
assert isinstance(z_at, RVExpressionTuple)

res = z_at.evaled_obj
assert res.owner.op == at.random.normal
assert res.owner.inputs[-2:] == [mu_at, sigma_at]


def test_unify_Variable():
x_at = at.vector("x")
Expand Down

0 comments on commit 53931d3

Please sign in to comment.