Common interface for rule definition #644
@antoine-levitt feel free to add comments here. @willtebbutt I'd love to hear your thoughts.
I think we ought to be able to do a reasonable job of translating between Mooncake tangents / fdata / rdata and Enzyme's Duplicated / Active system, because both are quite precisely specified. There are some differences, for example if the primal value is a … Regarding ChainRules, we'd have to make some choices, because the conversion is fundamentally ambiguous (ChainRules permits you to represent the tangent of anything with anything). That being said, I have had to attempt this for the ChainRules integration in Mooncake (see here), and it's currently rather unsatisfying and incomplete. All of this is to say that I doubt a truly universal translator is possible, but we ought to be able to identify a set of types for which the conversion is possible, and provide reasonably informative error messages if someone attempts something we don't know how to handle.
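A minimal Julia sketch of the "explicit set of supported types plus informative errors" idea, using multiple dispatch. Everything here is hypothetical: translate_tangent, the Point struct, and the two toy conventions (a NamedTuple of field tangents on one side, a shadow value of the primal's own type on the other) are illustrative stand-ins, only loosely inspired by structural tangents versus shadow-style duplicates. The sketch does not use or reproduce the real Mooncake, Enzyme, or ChainRules types.

```julia
# Minimal sketch of a *partial* tangent translator: an explicit set of supported
# cases plus an informative error for everything else.
# All names are hypothetical; this is NOT the Mooncake, Enzyme, or ChainRules API.

"""
    translate_tangent(primal, tangent)

Convert a structural tangent (a `Float64` for a `Float64`, an array for an array,
a `NamedTuple` of field tangents for a struct) into a "shadow" value that has the
same type as `primal`.
"""
function translate_tangent end

# Scalars: both toy conventions represent the tangent of a Float64 as a Float64.
translate_tangent(::Float64, t::Float64) = t

# Dense Float64 arrays: elementwise tangents; copy so the shadow owns its memory.
function translate_tangent(primal::Array{Float64,N}, t::Array{Float64,N}) where {N}
    size(primal) == size(t) ||
        throw(DimensionMismatch("tangent size $(size(t)) does not match primal size $(size(primal))"))
    return copy(t)
end

# Structs: rebuild a value of the primal's type, translating field by field.
function translate_tangent(primal::P, t::NamedTuple) where {P}
    Base.isstructtype(P) ||
        throw(ArgumentError("a NamedTuple tangent needs a struct primal, got $(P)"))
    fnames = fieldnames(P)
    keys(t) == fnames ||
        throw(ArgumentError("tangent fields $(keys(t)) do not match fields $(fnames) of $(P)"))
    # Assumes P has the default constructor that takes every field in order.
    return P((translate_tangent(getfield(primal, n), t[n]) for n in fnames)...)
end

# Fallback: refuse to guess, and name the unsupported combination of types.
translate_tangent(primal, t) = throw(ArgumentError(
    "no known translation for a tangent of type $(typeof(t)) " *
    "paired with a primal of type $(typeof(primal))"))

# Usage with a small struct:
struct Point
    x::Float64
    y::Vector{Float64}
end

p = Point(1.0, [2.0, 3.0])
shadow = translate_tangent(p, (x = 0.5, y = [1.0, -1.0]))  # a Point
# translate_tangent(1, 1) throws an ArgumentError: Int is not in the supported set.
```

Each method is one case we know how to handle; the untyped fallback turns every other combination into an error that names both types, instead of silently picking one of several possible representations.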
ReverseDiff and ForwardDiff are likely the most widely used autograd backends for Turing.jl, so we would like to keep supporting them even after Mooncake / Enzyme become more stable. Meanwhile, we will likely deprecate Zygote / Tracker in favour of Mooncake. Unfortunately, supporting ReverseDiff and ForwardDiff means we must maintain and add many extra rules (see, e.g., DistributionsAD.jl, Bijectors.jl, and DynamicPPL.jl). I'd like to see an option to use Mooncake / Enzyme to define rules for ReverseDiff / ForwardDiff straightforwardly in the near future. EDIT: I just noticed it is already possible to define ForwardDiff rules via Enzyme / Mooncake using … EDIT 2: it is often okay to manually replace …
Yeah, from a user's point of view it doesn't really matter which rule system we use, but we'd like to use just one. Unfortunately, that looks quite tricky. In my application I want to differentiate with respect to things that are hidden in structs, e.g. …
You can just use Enzyme.@import_rrule / @import_frule, and it will import whatever ChainRules rule is defined for that function. So your "universal_rrule" macro would, for the moment, be something like …
That said, I really don't think this should go here; it probably belongs in ChainRules or something similar.
I have added this page to the docs in order to explain rule definition in more detail: https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterface/stable/faq/differentiability/
It would be nice to have a universal translator between:
At the moment, DI.DifferentiateWith is a partial answer, but it requires replacing f with DifferentiateWith(f, other_backend).