Make AbstractComputation nest-able #1114

Open
jvmncs opened this issue Jun 6, 2022 · 0 comments

jvmncs commented Jun 6, 2022

Functions written with the eDSL should be callable from within other computations, regardless of whether they've been wrapped with the pm.computation decorator. For example,

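import pymoose as pm

alice = pm.host_placement(name="alice")  # assumed placement definition, not shown in the original

@pm.computation
def plus1(x: pm.Argument(alice, dtype=pm.float64)):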
  with alice:
    one = pm.constant(1, dtype=pm.float64)
    return pm.add(x, one)

@pm.computation
def alice_add():
  with alice:
    x = pm.constant(3, dtype=pm.float64)
    x_plus_one = plus1(x)
  return x_plus_one

if __name__ == "__main__":
  [...]
  runtime.set_default()
  val = alice_add()  # <-- will fail during tracing

When alice_add is called, the current behavior is as follows:

  • Inside a runtime context, alice_add.__call__ invokes trace(alice_add).
  • trace(alice_add) then invokes plus1.__call__. For this call to succeed, plus1 needs to return an Expression that can be used to trace the rest of alice_add.
  • However, since the default runtime context is not None, plus1 is executed against the default runtime's evaluate_computation with arguments of type Expression.
  • The Rust runtime bindings try to interpret these Expression Python objects as Moose Values, which fails with a TypeError because they are not concrete values.
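
To make the failure mode above concrete, here is a toy model of the call chain in plain Python. It does not use PyMoose: `Expression`, `AbstractComputation`, `ToyRuntime`, `trace`, and `get_current_runtime` are hypothetical stand-ins that only mirror the names in this issue, and the type check in `evaluate_computation` is an assumed simplification of what the Rust bindings do.

```python
import inspect

# Toy stand-ins only: none of this is the real PyMoose implementation.
_current_runtime = None


def get_current_runtime():
    return _current_runtime


class Expression:
    """Symbolic node produced while tracing."""

    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = tuple(inputs)


def trace(comp):
    # Call the wrapped function on one placeholder Expression per parameter.
    params = inspect.signature(comp.func).parameters
    placeholders = [Expression("input") for _ in params]
    return comp.func(*placeholders)


class ToyRuntime:
    def set_default(self):
        global _current_runtime
        _current_runtime = self

    def evaluate_computation(self, graph, args):
        # Assumed simplification: only concrete values are accepted.
        if any(isinstance(arg, Expression) for arg in args):
            raise TypeError("cannot interpret Expression as a concrete value")
        return ("evaluated", graph.op)


class AbstractComputation:
    def __init__(self, func):
        self.func = func

    def __call__(self, *args):
        # Current behavior: always trace, then hand off to the runtime.
        graph = trace(self)
        return get_current_runtime().evaluate_computation(graph, args)


@AbstractComputation
def plus1(x):
    return Expression("add", [x, Expression("constant")])


@AbstractComputation
def alice_add():
    x = Expression("constant")
    return plus1(x)  # nested call reaches the runtime with an Expression


ToyRuntime().set_default()
try:
    alice_add()
    error_message = None
except TypeError as err:
    error_message = str(err)
```

In this toy, the nested `plus1(x)` call is what raises: it runs while `alice_add` is being traced, but still sees the default runtime.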

One workaround for the user is to drop the pm.computation decorator from plus1, so that it returns an Expression regardless of the surrounding runtime context. But this makes it hard to use "standard library" computations that are already wrapped as AbstractComputation (which would likely be the common case).

I think the simplest solution here would be to do the following:

  • Inside pm.trace, temporarily unset the default runtime context, so that get_current_runtime returns None.
  • If AbstractComputation.__call__ is invoked without a runtime context (i.e. get_current_runtime returns None), invoke AbstractComputation.func.__call__. This invocation maps Expressions to Expressions, so tracing can proceed normally.
  • If AbstractComputation.__call__ is invoked inside a runtime context, invoke get_current_runtime().evaluate_computation(...) with the computation as usual.
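
A minimal sketch of the three steps above, again as a plain-Python toy rather than a patch against PyMoose. Every class and helper here (`no_runtime`, `set_default_runtime`, `ToyRuntime`, and the simplified `Expression` and `AbstractComputation`) is a hypothetical stand-in; only the control flow is meant to carry over.

```python
import contextlib
import inspect

# Toy stand-ins only: none of this is the real PyMoose implementation.
_current_runtime = None


def get_current_runtime():
    return _current_runtime


def set_default_runtime(runtime):
    global _current_runtime
    _current_runtime = runtime


@contextlib.contextmanager
def no_runtime():
    """Temporarily unset the default runtime, restoring it on exit."""
    global _current_runtime
    saved, _current_runtime = _current_runtime, None
    try:
        yield
    finally:
        _current_runtime = saved


class Expression:
    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = tuple(inputs)


def trace(comp):
    params = inspect.signature(comp.func).parameters
    placeholders = [Expression("input") for _ in params]
    # Step 1: tracing runs with get_current_runtime() returning None.
    with no_runtime():
        return comp.func(*placeholders)


class ToyRuntime:
    def evaluate_computation(self, comp, args):
        if any(isinstance(arg, Expression) for arg in args):
            raise TypeError("cannot interpret Expression as a concrete value")
        return trace(comp)  # a real runtime would execute the traced graph


class AbstractComputation:
    def __init__(self, func):
        self.func = func

    def __call__(self, *args):
        runtime = get_current_runtime()
        if runtime is None:
            # Step 2: no runtime, so map Expressions to Expressions.
            return self.func(*args)
        # Step 3: inside a runtime context, evaluate as usual.
        return runtime.evaluate_computation(self, args)


@AbstractComputation
def plus1(x):
    return Expression("add", [x, Expression("constant")])


@AbstractComputation
def alice_add():
    x = Expression("constant")
    return plus1(x)  # now inlined into the trace of alice_add


runtime = ToyRuntime()
set_default_runtime(runtime)
result = alice_add()
```

Here the toy runtime simply returns the traced graph, which makes it easy to see that `plus1` was inlined as an `add` node and that the default runtime is restored once tracing finishes.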

Some other options:

  • Allow for nesting runtime contexts and create a new "dummy" Runtime class whose evaluate_computation simply forwards to AbstractComputation.func.__call__
  • Something "moose-ier", e.g. accommodate Expression conversion in Moose bindings and in this case execute symbolically, i.e. run computation against a SymbolicSession instead of against the AsyncSession in AsyncTestRuntime
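
For comparison, the first alternative (nested runtime contexts plus a forwarding "dummy" runtime) could look roughly like the toy below. The runtime stack, `TracingRuntime`, and `ConcreteRuntime` are hypothetical names invented for this sketch, not existing PyMoose classes. The symbolic-session alternative is not sketched, since it depends on binding internals not described here.

```python
import inspect

# Toy stand-ins only: none of this is the real PyMoose implementation.
_runtime_stack = []


def get_current_runtime():
    return _runtime_stack[-1] if _runtime_stack else None


class Expression:
    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = tuple(inputs)


class Runtime:
    """Base class: entering a runtime pushes it onto a stack, so contexts nest."""

    def __enter__(self):
        _runtime_stack.append(self)
        return self

    def __exit__(self, *exc_info):
        _runtime_stack.pop()


class TracingRuntime(Runtime):
    """Dummy runtime that forwards straight to the wrapped function."""

    def evaluate_computation(self, comp, args):
        return comp.func(*args)


class ConcreteRuntime(Runtime):
    def evaluate_computation(self, comp, args):
        if any(isinstance(arg, Expression) for arg in args):
            raise TypeError("cannot interpret Expression as a concrete value")
        return trace(comp)  # a real runtime would execute the traced graph


def trace(comp):
    params = inspect.signature(comp.func).parameters
    placeholders = [Expression("input") for _ in params]
    # Nested context: computations called while tracing see the dummy runtime.
    with TracingRuntime():
        return comp.func(*placeholders)


class AbstractComputation:
    def __init__(self, func):
        self.func = func

    def __call__(self, *args):
        # No special casing: always dispatch to the innermost runtime.
        return get_current_runtime().evaluate_computation(self, args)


@AbstractComputation
def plus1(x):
    return Expression("add", [x, Expression("constant")])


@AbstractComputation
def alice_add():
    x = Expression("constant")
    return plus1(x)


with ConcreteRuntime():
    result = alice_add()
```

Compared with unsetting the default runtime, this keeps AbstractComputation.__call__ free of branching, at the cost of introducing a runtime stack and an extra Runtime subclass.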
@jvmncs jvmncs added this to the Release milestone Jun 6, 2022
@jvmncs jvmncs changed the title Make AbstractComputation callable outside of a runtime context Make AbstractComputation nest-able Jun 6, 2022
@mortendahl mortendahl removed this from the Release milestone Jun 24, 2022