Dispatch etuple #21

Merged: 2 commits into pythological:main from dispatch-etuple on Sep 6, 2022
Conversation

@rlouf (Contributor) commented Sep 5, 2022

As a result of #19 / #20, we can customize the way expression tuples are evaluated by subclassing `ExpressionTuple`. We would also like to get an instance of these subclasses when etuplizing the corresponding evaluated operators. To do so while keeping a consistent API, this PR allows the `etuple` function to be dispatched, and lets `etuplize` choose which registered version to use.

Since this work was motivated by aesara-devs/aesara#1036, here is an example of how this works in practice:

import aesara.tensor as at
from aesara.graph.op import Op

from etuples.core import etuple, ExpressionTuple
from etuples.dispatch import etuplize, etuplize_fn


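# Evaluate Ops by calling their `make_node` method; the repr/str prefixes make
# the subclass visible in the printed output below.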
class OpExpressionTuple(ExpressionTuple):

    def _eval_apply_fn(self, op):
        return op.make_node

    def __repr__(self):
        return "Op" + super().__repr__()

    def __str__(self):
        return "o" + super().__str__()

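# Dispatch `etuple` on the type of its first argument: Op instances build OpExpressionTuples.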
@etuple.register(Op, [object])
def etuple_Op(*args, **kwargs):
    return OpExpressionTuple(args, **kwargs)

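# Tell `etuplize` which `etuple` implementation to use when the operator is an Op.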
@etuplize_fn.register(Op)
def etuplize_fn_Op(op):
    return etuple_Op

print(etuple(at.random.normal, 0, 1))
# oe(normal_rv{0, (0, 0), floatX, False}, 0, 1)
print(etuple(at.add, 1, 1))
# oe(Elemwise{add,no_inplace}, 1, 1)
print(etuple(lambda x: x, 1))
# e(<function <lambda> at 0x7f77b894fd90>, 1)

norm_etz = etuplize(at.random.normal(0, 1))
norm_et = etuple(at.random.normal, *norm_etz[1:])
print(norm_et.evaled_obj)
# normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F772600FE60>), TensorConstant{[]}, TensorConstant{11}, TensorConstant{0}, TensorConstant{1})

norm_etz._evaled_obj = norm_etz.null
print(norm_etz.evaled_obj)
# normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F772600FE60>), TensorConstant{[]}, TensorConstant{11}, TensorConstant{0}, TensorConstant{1})

We cannot avoid dispatching `etuple` if we want to be able to unify `etuple(at.random.normal, ...)` with the result of `etuplize(at.random.normal(x, y))`. We could, however, avoid the (imho awkward) `etuplize_fn` by allowing `etuplize_step` to be dispatched instead (why is `etuplize` dispatched, by the way?).
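To illustrate that point, here is a minimal sketch of the unification this enables; it assumes the `unification` package's `unify` and `var`, reuses the `OpExpressionTuple` registrations above, and assumes the `at.add` Op compares equal in both expressions:

from unification import unify, var

x_lv, y_lv = var(), var()

# Both sides are OpExpressionTuples thanks to the dispatched `etuple`,
# so the pattern can unify with the etuplized graph.
pattern = etuple(at.add, x_lv, y_lv)
target = etuplize(at.add(1, 1))

print(unify(pattern, target))
# Something like: {~_1: TensorConstant{1}, ~_2: TensorConstant{1}}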

The `etuple` function returns an instance of the `ExpressionTuple`
class. If we wish to return an instance of a subclass of
`ExpressionTuple`, we currently need to implement a new function. This
PR instead makes `etuple` dispatchable, which keeps the API uniform.
The first argument passed to `etuple` determines which of the registered
functions is called. That first argument is generally an
`ExpressionTuple`, the etuplized version of the operator. However, when
using subclasses of `ExpressionTuple` to customize evaluation, we want
`etuple` to dispatch on the current operator rather than on its
etuplized version.
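As a quick check of that dispatch behavior (a sketch reusing the example from the PR description; the assertions reflect the printed reprs shown there):

# An Op first argument selects the registered Op branch...
assert isinstance(etuple(at.add, 1, 1), OpExpressionTuple)
# ...while any other first argument falls back to a plain ExpressionTuple.
assert not isinstance(etuple(lambda x: x, 1), OpExpressionTuple)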
@rlouf rlouf marked this pull request as ready for review September 5, 2022 16:01
@brandonwillard brandonwillard merged commit 6010add into pythological:main Sep 6, 2022
@rlouf rlouf deleted the dispatch-etuple branch September 7, 2022 16:21