ReverseDiff doesn't like DynamicPPL.ReshapeTransform #698

Closed
penelopeysm opened this issue Oct 25, 2024 · 0 comments · Fixed by #699
penelopeysm commented Oct 25, 2024

#555 introduced DynamicPPL.ReshapeTransform, which is very nice, but what seems to be a bug in ReverseDiff.jl causes it to fail when ReshapeTransform is composed with a broadcasted function.

I reported the upstream bug at JuliaDiff/ReverseDiff.jl#265. In the context of DynamicPPL, this occurs when we have something like the following:

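using Distributions: InverseGamma
using DynamicPPL: invlink_transform, ReshapeTransform
using ReverseDiff

# Inverse link transform for InverseGamma (the broadcasted function)
f(x) = invlink_transform(InverseGamma(2, 3))(x)
# Reshape the singleton array [1.0] to size ()
g(x) = ReshapeTransform(())(x)
h = f ∘ g
ReverseDiff.gradient(h, [1.0])  # fails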

I suspect we can change the implementation of ReshapeTransform to work around this. I don't know all the shapes that ReshapeTransform handles, or whether different input/output shapes would trigger different ReverseDiff errors. However, I dug into a couple of the failing tests in Turing.jl, and both appear to stem from ReshapeTransform being given singleton arrays (e.g. [1.0] above). The error message in all the other failing tests is the same, although I didn't verify that they also stem from singleton arrays. So I think we could special-case singleton arrays to keep ReverseDiff on our side.
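
To make the special-casing idea concrete, here is a minimal sketch in plain Julia. `SketchReshape` is a hypothetical stand-in, not the actual DynamicPPL.ReshapeTransform, and this is not necessarily what the eventual fix does. It assumes that the problematic case is a length-1 vector reshaped to size `()`, and returns the single element as a scalar instead of a zero-dimensional array. I haven't checked whether this actually avoids the ReverseDiff error.

```julia
# Hypothetical stand-in for a reshape transform; not the DynamicPPL type.
struct SketchReshape{S<:Tuple}
    output_size::S
end

# Special case: output size () with a singleton input.
# Return the element itself rather than a zero-dimensional array.
function (t::SketchReshape{Tuple{}})(x::AbstractVector)
    length(x) == 1 || throw(DimensionMismatch("expected a length-1 vector"))
    return x[1]
end

# General case: fall back to an ordinary reshape.
(t::SketchReshape)(x::AbstractVector) = reshape(x, t.output_size)

scalar = SketchReshape(())([1.0])                       # a Float64, not an Array{Float64, 0}
matrix = SketchReshape((2, 2))([1.0, 2.0, 3.0, 4.0])    # a 2x2 Matrix{Float64}
```

One thing this changes is the return type in the singleton case (a scalar rather than a zero-dimensional array), so any inverse transform would need to wrap the scalar back into a length-1 vector.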
