#555 introduced DynamicPPL.ReshapeTransform, which is very nice, but there seems to be a bug in ReverseDiff.jl that causes it to fail when ReshapeTransform is composed with a broadcasted function.
I reported the upstream bug at JuliaDiff/ReverseDiff.jl#265. In the context of DynamicPPL, this occurs when we have something like the following:
using DynamicPPL: invlink_transform, ReshapeTransform
using Distributions: InverseGamma
using ReverseDiff
f(x) = invlink_transform(InverseGamma(2, 3))(x)
g(x) = ReshapeTransform(())(x)
h = f ∘ g
ReverseDiff.gradient(h, [1.0])
I suspect we should be able to change the implementation of ReshapeTransform to circumvent this, though. I don't actually know all the possible shapes that ReshapeTransform handles, or whether different input/output shapes would give different ReverseDiff errors. However, I dug into a couple of the failing tests in Turing.jl, and both of them stem from ReshapeTransform being given singleton arrays (e.g. [1.0] above). Furthermore, the error message observed in all the other failing tests is the same (although I didn't verify that they ultimately stem from singleton arrays). So I think we could special-case this behaviour to keep ReverseDiff on our side.
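To make the special-casing idea concrete, here is a minimal sketch of what it could look like. This is not DynamicPPL's actual implementation; the struct name MyReshapeTransform and its fields are hypothetical, and the only assumption is that returning the scalar directly for the singleton-to-scalar case avoids the `reshape` call that ReverseDiff trips over:

```julia
# Hypothetical sketch of a reshape transform that special-cases
# singleton arrays, so ReverseDiff never sees `reshape` for the
# length-1 input case.
struct MyReshapeTransform{S<:Tuple}
    output_size::S
end

function (t::MyReshapeTransform)(x::AbstractArray)
    if t.output_size === () && length(x) == 1
        # Return the single element directly rather than a
        # zero-dimensional reshaped array.
        return only(x)
    end
    # Fall back to a plain reshape for all other shapes.
    return reshape(x, t.output_size)
end
```

With this, `MyReshapeTransform(())([1.0])` yields the plain scalar `1.0`, while non-singleton inputs still go through `reshape` unchanged.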