@grad_from_chainrules macro fails when using multi-output functions #221

Open
ThummeTo opened this issue Mar 16, 2023 · 2 comments

@ThummeTo

Dear team,

First of all, thanks for developing this nice package :-)

I think there is a bug in the macro @grad_from_chainrules when it is applied to multi-output functions (for example, a function that returns a tuple of two vectors). Note that the current GitHub tests do not cover gradient/Jacobian computation: they only evaluate the rrules directly, and never build a gradient or Jacobian with ReverseDiff through the corresponding rrule. For single-output functions, however, the macro works fine together with ReverseDiff.gradient.

See the following MWE:

using ForwardDiff, Zygote, ReverseDiff, ChainRulesCore

# SINGLE OUTPUT FUNCTION 

f(x) = sum(4x .+ 1)

function ChainRulesCore.rrule(::typeof(f), x)
    r = f(x)
    function back(d)
        return ChainRulesCore.NoTangent(), fill(3, size(x))
    end
    return r, back
end

ReverseDiff.@grad_from_chainrules f(x::AbstractVector{<:ReverseDiff.TrackedReal})

seed = rand(3)

# Everything OK: ForwardDiff computes the correct derivatives (no frule defined),
# ReverseDiff and Zygote use the new rrule, as expected
ForwardDiff.gradient(f, seed)
Zygote.gradient(f, seed)[1]
ReverseDiff.gradient(f, seed)

# MULTI OUTPUT FUNCTION 

f_multi(x, y) = (4x .+ 1, 3x .+ 1 .+ y)

function ChainRulesCore.rrule(::typeof(f_multi), x, y)
    r = f_multi(x, y)
    function back(d)
        y1, y2 = d
        return ChainRulesCore.NoTangent(), fill(2 , size(x)), fill(3 , size(y))
    end
    return r, back
end

ReverseDiff.@grad_from_chainrules f_multi(x::AbstractVector{<:ReverseDiff.TrackedReal}, y::AbstractVector{<:Real})

# ForwardDiff computes the correct derivatives (no frule defined),
# Zygote uses the new rrule as expected, ReverseDiff fails!
ForwardDiff.jacobian(x -> f_multi(x, ones(3))[1], seed)
Zygote.jacobian(x -> f_multi(x, ones(3))[1], seed)[1]
ReverseDiff.jacobian(x -> f_multi(x, ones(3))[1], seed) # this errors!

Tested on Julia 1.8.5 with all packages up to date.

Thanks in advance & best regards!

@ThummeTo
Author

Forgot to post the error message:

ERROR: MethodError: no method matching track(::Tuple{Vector{Float64}, Vector{Float64}}, ::Vector{ReverseDiff.AbstractInstruction})
Closest candidates are:
  track(::AbstractArray, ::Vector{ReverseDiff.AbstractInstruction}) at ...\ReverseDiff.jl\src\tracked.jl:469
  track(::Real, ::Vector{ReverseDiff.AbstractInstruction}) at ...\ReverseDiff.jl\src\tracked.jl:467
  track(::typeof(vcat), ::Union{Number, AbstractVecOrMat}...) at ...\ReverseDiff.jl\src\macros.jl:190       
  ...
Stacktrace:
 [1] track(#unused#::typeof(f_multi), x::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, y::Vector{Float64})
   @ Main ...\ReverseDiff.jl\src\macros.jl:329
 [2] f_multi(x::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, y::Vector{Float64})
   @ Main ...\ReverseDiff.jl\src\macros.jl:324
 [3] (::var"#17#18")(x::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}})
   @ Main ...\MWE_multi_reversediff.jl:44
 [4] ReverseDiff.JacobianTape(f::var"#17#18", input::Vector{Float64}, cfg::ReverseDiff.JacobianConfig{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, Nothing})
   @ ReverseDiff ...\ReverseDiff.jl\src\api\tape.jl:229
 [5] jacobian(f::Function, input::Vector{Float64}, cfg::ReverseDiff.JacobianConfig{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, Nothing}) (repeats 2 times)
   @ ReverseDiff ...\src\api\jacobians.jl:23
 [6] top-level scope
   @ ...\MWE_multi_reversediff.jl:44

@cortner

cortner commented Jul 9, 2023

+1. I've run into the same problem, and my MWE is almost identical to the one above.

For me this is a huge problem, because I am hoping to use ReverseDiff over Zygote to get second derivatives. When you implement a pullback of a pullback, you typically have multiple outputs to take care of.

If anybody can suggest how to fix this or work around it, I'd be very grateful.

CC @tjjarvinen
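
One possible workaround, sketched here under assumptions and not confirmed by the maintainers: the stack trace above shows that the generated code calls `track` on the primal output, and `track` only has methods for `Real` and `AbstractArray`. Returning a single stacked vector instead of a tuple should therefore avoid the failing call. The name `f_stacked` is hypothetical and not part of any package, and, unlike the deliberately simple pullbacks in the MWE, the pullback below returns the true derivatives of `f_multi`.

```julia
using ReverseDiff, ChainRulesCore

# Multi-output function from the MWE above
f_multi(x, y) = (4x .+ 1, 3x .+ 1 .+ y)

# Hypothetical single-output wrapper: concatenates both outputs into one vector
f_stacked(x, y) = vcat(f_multi(x, y)...)

function ChainRulesCore.rrule(::typeof(f_stacked), x, y)
    r = f_stacked(x, y)
    n = length(x)
    function back(d)
        d1 = d[1:n]       # cotangent of the first output,  4x .+ 1
        d2 = d[n+1:end]   # cotangent of the second output, 3x .+ 1 .+ y
        return ChainRulesCore.NoTangent(), 4 .* d1 .+ 3 .* d2, d2
    end
    return r, back
end

ReverseDiff.@grad_from_chainrules f_stacked(x::AbstractVector{<:ReverseDiff.TrackedReal}, y::AbstractVector{<:Real})

seed = rand(3)

# Jacobian of the stacked output with respect to x
J = ReverseDiff.jacobian(x -> f_stacked(x, ones(3)), seed)

# Split the rows back into the blocks belonging to each original output
J1 = J[1:3, :]
J2 = J[4:end, :]
```

The row slicing is done on the finished Jacobian rather than on the tracked output, so the sketch does not rely on how ReverseDiff handles indexing of tracked arrays. The cost is that the caller has to know the output lengths in order to split the result.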

sethaxen added a commit to TuringLang/Bijectors.jl that referenced this issue May 27, 2024
`@grad_from_chainrules` can't handle multi-output functions, see JuliaDiff/ReverseDiff.jl#221. In this case it can AD through the primal just fine.
torfjelde added a commit to TuringLang/Bijectors.jl that referenced this issue Jun 5, 2024
* Rename VecCholeskyBijector to VecCorrCholeskyBijector

* Compute corr logdetjac during transform

* Enforce one-based indexing

* Add with_logabsdet_jacobian for correlation transforms

* Add rrule for non-mutating ADs

* Update ChainRules to use manual rrule

* Update Tracker to use manual rrule

* Remove rrule for ReverseDiff

`@grad_from_chainrules` can't handle multi-output functions, see JuliaDiff/ReverseDiff.jl#221. In this case it can AD through the primal just fine.

* Add module

* Make CorrBijector more numerically stable

Also use consistent notation with inverse transform

* Increment patch number

* Revert "Rename VecCholeskyBijector to VecCorrCholeskyBijector"

This reverts commit bd6ff3d.

* Update src/bijectors/corr.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* Apply suggestions from code review

* Work around issues with Tracker

* import `stack` from Compat.jl (#314)

* import `stack` in tests too

* disable certain tests for ProductBijector on Julia versions with older
`eachslice` impls

---------

Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>