Use callable for default_update
#1197
It seems like you're talking about making … Basically, anything that (re)constructs a graph can loosely be called "cloning", which means the callable would only encode the construction of the graph. However, I don't think the challenge lies there, since "proper" cloning handles all of this and more just fine, when applied correctly. Likewise, this callable approach would need to determine the callables' inputs and orchestrate their use, and that's where one would likely end up effectively reproducing all the same logic as proper cloning. I also don't understand the connection between those links and the use of a callable and/or …
This only applies to one type of update graph, so it's not a viable solution to the general …
Note that #896 was not just about fixing the …
The default_update is set by RandomStream so that users don't have to manually specify update expressions for RandomState/Generators in Aesara functions. However, this can be problematic when manipulating graphs, because it creates a hidden dependency that is not revealed until the final function is actually compiled.
In that example, the shared variable used by the cloned x has the wrong update expression, causing the bit generator to advance only 2 steps every call instead of 4. It also produces a wasteful graph that takes 2 unused random draws in the background just to update the shared variable. There are other caveats when manipulating graphs with RandomVariables, but this one seems unnecessary.
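To make the hidden dependency concrete, here is a minimal pure-Python toy model of the mechanism described above. The classes SharedVariable, Variable and Node and the helpers make_draw, clone_draw and collect_updates are all hypothetical stand-ins, not Aesara's API. The "stream" records the RNG update on the shared variable itself, so a cloned draw still picks up an update expression that belongs to the original node.

```python
# Toy model of a stateful default_update stored on the shared RNG variable.
# All classes and helpers here are hypothetical, not Aesara's API.


class SharedVariable:
    def __init__(self, name):
        self.name = name
        self.default_update = None  # set when a draw is created


class Variable:
    def __init__(self, name, owner=None):
        self.name = name
        self.owner = owner  # the Node that produced this variable


class Node:
    def __init__(self, op_name, inputs):
        self.op_name = op_name
        self.inputs = list(inputs)
        # outputs[0]: next RNG state, outputs[1]: the random draw
        self.outputs = [
            Variable(f"{op_name}.next_rng", self),
            Variable(f"{op_name}.draw", self),
        ]


def make_draw(rng, op_name):
    """Like a stream: build a draw and store its RNG update on the shared variable."""
    node = Node(op_name, [rng])
    rng.default_update = node.outputs[0]
    return node.outputs[1]


def clone_draw(draw, new_name):
    """Rebuild the node behind `draw`; the clone shares the same RNG input."""
    return Node(new_name, draw.owner.inputs).outputs[1]


def collect_updates(outputs):
    """Like compilation: read default updates from the shared inputs of the outputs' nodes."""
    updates = {}
    for out in outputs:
        for inp in out.owner.inputs:
            if isinstance(inp, SharedVariable) and inp.default_update is not None:
                updates[inp] = inp.default_update
    return updates


rng = SharedVariable("rng")
x = make_draw(rng, "x")
x_clone = clone_draw(x, "x_clone")
updates = collect_updates([x_clone])
print(updates[rng].owner.op_name)  # prints "x": the update belongs to the original node
```

Because the update is read from the shared variable, a compiled function of x_clone would have to evaluate the original node just to advance the RNG, which mirrors the wasted draws described above.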
One option would be to replace the stateful default_update with a callable that returns the correct update expression for a given node. This is the approach taken in PyMC, for example (https://github.com/pymc-devs/pymc/blob/4c92adf9720f6578e51b9ef21e33c29871c67a83/pymc/aesaraf.py#L928-L933), and it is also similar to the approach taken by Aeppl transforms to avoid new "surprising dependencies" at runtime (https://github.com/aesara-devs/aeppl/blob/b4304fa60979625450b79fe4cfb88426bc3d73e8/aeppl/transforms.py#L426-L429).

Instead of setting

x.owner.inputs[0].default_update = x.owner.outputs[0]

RandomStream would set

x.owner.default_update = lambda node: {node.inputs[0]: node.outputs[0]}

The internal machinery would need to be updated to look for default updates on nodes instead of on SharedVariables, but otherwise I don't think it would be too much trouble.

An altogether different solution was attempted in #896.
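A similar toy model, again with hypothetical classes rather than Aesara's internals, sketches the proposed alternative: each node carries the callable from the proposal, and updates are gathered by calling it on the node being compiled. Because the callable derives the update from whatever node it is given, a clone can simply reuse it and gets an update that refers to its own outputs.

```python
# Toy model of the proposal: the node carries a callable that derives its
# updates from the node it is called with. Hypothetical classes, not Aesara's API.


class SharedVariable:
    def __init__(self, name):
        self.name = name


class Variable:
    def __init__(self, name, owner=None):
        self.name = name
        self.owner = owner


class Node:
    def __init__(self, op_name, inputs, default_update=None):
        self.op_name = op_name
        self.inputs = list(inputs)
        # outputs[0]: next RNG state, outputs[1]: the random draw
        self.outputs = [
            Variable(f"{op_name}.next_rng", self),
            Variable(f"{op_name}.draw", self),
        ]
        # Callable mapping a node to {shared_variable: update_expression}
        self.default_update = default_update


def make_draw(rng, op_name):
    """Like a stream: attach the update rule to the node, not to the shared variable."""
    draw_node = Node(
        op_name,
        [rng],
        default_update=lambda node: {node.inputs[0]: node.outputs[0]},
    )
    return draw_node.outputs[1]


def clone_draw(draw, new_name):
    """Rebuild the node behind `draw`, reusing its update callable."""
    old = draw.owner
    return Node(new_name, old.inputs, default_update=old.default_update).outputs[1]


def collect_updates(outputs):
    """Like compilation: call each output node's callable on that node.

    Only the outputs' own nodes are inspected here; a real graph walk would
    visit all ancestors.
    """
    updates = {}
    for out in outputs:
        node = out.owner
        if node.default_update is not None:
            updates.update(node.default_update(node))
    return updates


rng = SharedVariable("rng")
x = make_draw(rng, "x")
x_clone = clone_draw(x, "x_clone")
updates = collect_updates([x_clone])
print(updates[rng].owner.op_name)  # prints "x_clone": the update follows the clone
```

This sketch only looks at the owners of the requested outputs; a real implementation would also have to walk all ancestor nodes and decide what to do when several nodes produce updates for the same shared variable.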