Handle unreachable blocks in the adjoint CFG #1465
This seems good to me, based on the new test, but I am not that familiar with this bit of the code.
I had to add the
julia> using Zygote
julia> function f(x, cond)
           if cond
               return x
           else
               return 2x
           end
           return 3x
       end
f (generic function with 1 method)
julia> back = Zygote.Pullback{Tuple{typeof(f), Float64, Bool}, Any}(([Returns([1,2,3])], [Returns([1,2,3])], 0x03))
∂(f)
julia> back(1.)
ERROR: "unreachable"
Stacktrace:
 [1] f
   @ ./REPL[2]:7 [inlined]
 [2] (::Zygote.Pullback{Tuple{typeof(f), Float64, Bool}, Any})(Δ::Float64)
   @ Zygote ~/Projects/Zygote.jl/src/compiler/interface2.jl:0
 [3] top-level scope
   @ REPL[5]:1
My understanding is that this may incur an allocation for the string "unreachable" in some methods. Is this reliably optimized out in practice? If not, it might make sense to define a custom immutable error type. That type could also track more information such as file and line number to help with troubleshooting.
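A minimal sketch of what such an error type could look like. `UnreachableBlockError` and its fields are hypothetical names chosen for illustration and are not part of Zygote; whether throwing it actually avoids the allocation has not been measured here.

```julia
# Hypothetical error type for illustration; Zygote does not define it.
struct UnreachableBlockError <: Exception
    file::Symbol  # source file of the primal block (assumed to be known at codegen time)
    line::Int     # line number of the primal block
end

# Print a readable message when the error is displayed.
function Base.showerror(io::IO, err::UnreachableBlockError)
    print(io, "UnreachableBlockError: pullback entered a block that is unreachable in the primal (",
          err.file, ":", err.line, ")")
end

# Usage: throw it where the string "unreachable" would otherwise be thrown.
err = try
    throw(UnreachableBlockError(Symbol("REPL[2]"), 7))
catch exc
    exc
end
```

Carrying the file and line in the error keeps the troubleshooting information next to the failure instead of relying on the stack trace alone.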
The string
julia> using Zygote
julia> f(x) = @inbounds return x;
julia> invalid_pull = Zygote.Pullback{Tuple{typeof(f), Float64}, Any}(0x02,);
julia> m = only(methods(Zygote.adjointcfg));
root = m.roots[findfirst(==("unreachable"), m.roots)];
julia> unreachable_string = try invalid_pull(1.); catch exc; exc end
"unreachable"
julia> unreachable_string === root
true
One problem is that it may taint the effects of the pullback compared to the case without unreachable blocks (which we could have if we removed unreachable blocks from the primal):
julia> Core.Compiler.infer_effects(invalid_pull, (Float64,))
(!c,!e,!n,!t,!s,!m,+i)′
julia> g(x) = return x;
julia> _, pull = Zygote.pullback(g, 1.);
Core.Compiler.infer_effects(pull.back, (Float64,))
(+c,+e,+n,+t,+s,+m,+i)
But any sufficiently complex pullbacks (involving
A quick sanity check before I merge: does this pass tests locally on nightly for you?
The added testset does, but the entire test suite fails with CUDA and a missing
Could you try temporarily commenting out the CUDA test imports and the test block in lines 8 to 15 in 49a1184?
I had to cherry-pick #1462 but there is one
The error described in #1118 and #1380 stopped happening in Julia 1.10-beta (maybe because of JuliaLang/julia#50943), but the gradient is now wrong.
This PR adds unreachable branches at the end of blocks in the adjoint when those blocks are unreachable in the primal. This fixes the issue in both 1.9 and 1.10 because it avoids implicit branches.
I also thought of removing unreachable blocks altogether, since some operations are invalid for these blocks in IRTools (see dominators for example), but that seemed like a lot more work than this fix.
Fixes #1380
Fixes #1118
Note
It depends on FluxML/IRTools.jl#115.