Current Status
The AD part of the package has not undergone any major overhaul since I first implemented it around the start of the project. Back then I relied on Flux/Zygote for three reasons: the package was tailored to Flux models at the time, I was entirely new to AD and Julia, and Zygote let me differentiate through structs, like so:
"""
    ∂ℓ(
        generator::AbstractGradientBasedGenerator,
        ce::AbstractCounterfactualExplanation,
    )

The default method to compute the gradient of the loss function at the current counterfactual state for gradient-based generators.
It assumes that `Zygote.jl` has gradient access.
"""
function ∂ℓ(
    generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation
)
    return Flux.gradient(ce -> ℓ(generator, ce), ce)[1][:counterfactual_state]
end
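To make the `[1][:counterfactual_state]` indexing concrete, here is a minimal sketch of differentiating through a struct with Zygote. `ToyCE` and `toy_loss` are hypothetical stand-ins, not types or functions from the package, and the toy struct is immutable (Zygote represents gradients of mutable structs differently).

```julia
using Zygote

# Hypothetical stand-in for a counterfactual explanation (not the package's type).
struct ToyCE
    counterfactual_state::Vector{Float64}
    target::Vector{Float64}
end

# Hypothetical loss: squared distance between the state and the target.
toy_loss(ce::ToyCE) = sum(abs2, ce.counterfactual_state .- ce.target)

ce = ToyCE([1.0, 2.0], [0.0, 0.0])

# Differentiating w.r.t. the struct yields a NamedTuple with one entry per field.
grads = Zygote.gradient(toy_loss, ce)[1]

# Pick out the gradient w.r.t. the counterfactual state.
# Analytically this is 2 .* (counterfactual_state .- target).
state_grad = grads[:counterfactual_state]
```

The gradient is computed for every field of the struct, even though only one entry is used afterwards.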
Pain Points
The current implementation is less than ideal for various reasons:
Zygote cannot handle nested AD, which is necessary for some counterfactual generators (see Sort out PROBE #376).
Gradients are still taken implicitly, which is not in line with where the broader ecosystem is headed, I believe.
The previous point also makes it difficult to implement forward-over-reverse to solve the nested AD issue.
The AD implementation has never been optimized for performance, so I guess there's a lot of room for improvement here.
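As a rough illustration of the forward-over-reverse idea mentioned above, the sketch below takes the inner gradient explicitly with respect to a plain array using Zygote (reverse mode) and differentiates the result with ForwardDiff (forward mode). The functions `f`, `inner_grad`, `penalty` and `outer_grad` are hypothetical and only stand in for a loss that itself depends on a gradient; this is not how the package currently works.

```julia
using ForwardDiff
using Zygote

# Hypothetical scalar function standing in for a model output.
f(x) = sum(x .^ 3)

# Inner derivative: reverse mode via Zygote, taken explicitly w.r.t. the array x.
inner_grad(x) = Zygote.gradient(f, x)[1]

# Hypothetical loss that depends on the inner gradient (a gradient-norm penalty).
penalty(x) = sum(abs2, inner_grad(x))

# Outer derivative: forward mode via ForwardDiff, applied over the reverse pass.
outer_grad(x) = ForwardDiff.gradient(penalty, x)

x = [1.0, 2.0]
g = outer_grad(x)

# Analytically, penalty(x) = 9 * sum(x .^ 4), so the gradient is 36 .* x .^ 3.
```

Keeping the differentiated argument a plain array, rather than a whole struct, is what makes it straightforward to swap in a different backend for the outer pass.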
That would be amazing, of course, if you could help out, but only if it's not too much trouble for you. I want to look at this soon, but I might look at #495 first, because I think I may need this for a research project I'm currently working on.
I've updated the description a little bit. If you have any pointers, I'd much appreciate it if you could share them here.
To Do