vjp and jtvp helper functions (Jacobian vector products) #155

Closed
ChrisRackauckas opened this issue Apr 14, 2019 · 10 comments

@ChrisRackauckas
Member

It is a really nice result that you can use forward mode to generate expressions for v*J and reverse mode to generate expressions for J'*v without explicitly building J. There are a lot of use cases for this, so it would be nice if Zygote had a helper function for calling these.

@MikeInnes
Member

The VJP is the same thing that Zygote calls a "pullback", i.e. the back function in y, back = Zygote.forward(f, x).

You can also get a JVP by differentiating that pullback (since it's linear and just gets transposed) but nested AD is not as robust or well supported yet.
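To spell out why differentiating the pullback yields a JVP, here is a short derivation in standard notation (the symbols $J$, $B$, and $\bar{B}$ are introduced here for illustration and do not appear in the thread). Let $J$ be the Jacobian of $f$ at $x$. The pullback is linear in its argument, so its own Jacobian is $J^{\top}$ everywhere, and pulling back through it transposes once more:

```latex
\begin{aligned}
\text{pullback of } f \text{ at } x: &\quad B(w) = J^{\top} w \\
\text{pullback of } B \text{ at any } w: &\quad \bar{B}(v) = \left(J^{\top}\right)^{\top} v = J v
\end{aligned}
```

This is the mathematical identity only; whether nested differentiation works in practice depends on how well the AD tool supports it, as noted above.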

@ChrisRackauckas
Member Author

I thought vjp was just seeded forward mode and jtvp was just seeded reverse mode?

@MikeInnes
Member

MikeInnes commented Apr 14, 2019

VJP is reverse mode ((v'*J)', or J'*v, if we don't store v transposed), JVP is forward (J*v).

FWIW, VJPs are not so much a trick you can do with reverse mode as what AD fundamentally does. For each function we provide or derive a VJP (or "adjoint", or "pullback") and compose these together. In the (very) special case that the function output is scalar we can just "seed" the sensitivity 1 and get gradients; but Zygote doesn't directly care about this because it's not actually a gradient-calculator, it's a VJP builder.
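As a minimal sketch of the point above, the following uses `Zygote.pullback` (the name used later in this thread; `Zygote.forward` above is the older spelling) to compute a VJP for a vector-valued function, and then shows the scalar special case where seeding the sensitivity with 1 recovers the gradient. The functions `f` and `g` and the inputs are hypothetical examples chosen for illustration.

```julia
using Zygote

# Hypothetical vector-valued function; its Jacobian is [2x[1] 1; 0 3].
f(x) = [x[1]^2 + x[2], 3x[2]]

x = [1.0, 2.0]
w = [1.0, 1.0]                  # sensitivity (the "v" in J' * v)

# Forward pass returns the value and the pullback closure.
y, back = Zygote.pullback(f, x)

# Backward pass: the pullback computes J' * w without building J.
vjp = back(w)[1]

# Scalar special case: seeding the sensitivity with 1 gives the gradient.
g(x) = sum(abs2, x)
_, back_g = Zygote.pullback(g, x)
grad = back_g(1.0)[1]           # matches Zygote.gradient(g, x)[1]
```

Here `back(w)` returns a tuple with one entry per argument of `f`, which is why the result is indexed with `[1]`.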

@ChrisRackauckas
Member Author

Oh, I didn't realize that Zygote.forward(f, x) is supposed to be reverse mode? I assumed that meant it was forward mode.

@MikeInnes
Member

That just refers to the "forward pass" (I wouldn't mind finding a better name but have no better ideas so far). back carries out the backwards pass. You might not have seen but Zygote now has a fair amount of docs going over this stuff.

@MikeInnes
Member

Closing for now as I don't think there's an action item here.

@tomhaber

tomhaber commented Aug 21, 2024

What was the conclusion on the jacobian-vector-product part? I was hoping to achieve this in Zygote.jl without resorting to ForwardDiff.jl.

The following attempt didn't do what I was expecting:

function jvp(f, u, v)
    _, back = pullback(u) do u
        Zygote.forwarddiff(u) do u
            f(u)
        end
    end
    back(v)[1]
end

(Also a little confused as to what it actually means: reverse-mode AD using forward-mode?)

I agree that it would be very nice to have helpers for jvp and vjp.

@ToucheSir
Member

I don't understand why nested differentiation (i.e. second derivatives) came up in the discussion above. These days, I'd recommend using a library like https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface/stable/operators/#Low-level-operators with an AD of your choice.

@tomhaber

I suppose I will just use both Zygote and ForwardDiff; with DifferentiationInterface I'd also need to use two backends, one for forward mode and one for reverse mode. I was just hoping that Zygote would have provided the pushforward method as well, since it already uses ForwardDiff internally.

Seems like with so many AD packages, there's still a need for (n+1) :D

Thanks for answering
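For readers landing here, the two-package approach described above might look like the sketch below. The helper names `jvp` and `vjp` are hypothetical (neither package exports functions with these names), and `f` is an illustrative example. The JVP is taken as a directional derivative with `ForwardDiff.derivative`; the VJP comes from `Zygote.pullback`.

```julia
using ForwardDiff, Zygote
using LinearAlgebra: dot

# Hypothetical helper: J * v as the derivative of t -> f(x + t*v) at t = 0.
jvp(f, x, v) = ForwardDiff.derivative(t -> f(x .+ t .* v), 0.0)

# Hypothetical helper: J' * w from Zygote's pullback closure.
function vjp(f, x, w)
    _, back = Zygote.pullback(f, x)
    return back(w)[1]
end

# Illustrative function; its Jacobian is [2x[1] 1; 0 3].
f(x) = [x[1]^2 + x[2], 3x[2]]

x = [1.0, 2.0]
v = [1.0, 0.0]      # tangent direction in input space
w = [1.0, 0.0]      # sensitivity in output space

Jv  = jvp(f, x, v)  # forward mode
Jtw = vjp(f, x, w)  # reverse mode

# Consistency check: <w, J*v> equals <J'*w, v>.
consistent = dot(w, Jv) ≈ dot(Jtw, v)
```

Neither helper materializes the Jacobian, so the cost of each call is a small multiple of one evaluation of `f`.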

@ToucheSir
Member

Yes, by all means use both ADs directly if that works best for you. There are now ADs which support both forward and reverse mode (e.g. https://github.com/EnzymeAD/Enzyme.jl), but evaluating those tools is beyond the scope of this issue :).

> I was just hoping that Zygote would have provided the pushforward method as well, since it already uses ForwardDiff internally.

Maybe this will help clarify things, but Zygote doesn't provide a pushforward method because it doesn't use ForwardDiff internally**. Zygote.forwarddiff is purely user-facing and does not do what people think it does. Arguably it should not exist because of this confusion.

** except in a single reverse-mode rule. But that usage is completely opaque to the end user. Zygote technically has its own source-to-source forward mode separate from ForwardDiff, but it was never developed enough to be usable and is essentially vestigial at this point.
