Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return primal value in pushforward! and pullback! #17

Merged
merged 10 commits
Feb 28, 2024
12 changes: 6 additions & 6 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,30 +33,30 @@ SUITE = BenchmarkGroup()
for n in n_values
for backend in forward_backends
SUITE["forward"]["scalar_to_scalar"][n][string(backend)] = @benchmarkable begin
pushforward!(dy, $backend, scalar_to_scalar, x, dx)
value_and_pushforward!(dy, $backend, scalar_to_scalar, x, dx)
end setup = (x = 1.0; dx = 1.0; dy = 0.0) evals = 1
if backend != EnzymeForwardBackend() # type instability?
SUITE["forward"]["scalar_to_vector"][n][string(backend)] = @benchmarkable begin
pushforward!(dy, $backend, Fix2(scalar_to_vector, $n), x, dx)
value_and_pushforward!(dy, $backend, Fix2(scalar_to_vector, $n), x, dx)
end setup = (x = 1.0; dx = 1.0; dy = zeros($n)) evals = 1
end
SUITE["forward"]["vector_to_vector"][n][string(backend)] = @benchmarkable begin
pushforward!(dy, $backend, vector_to_vector, x, dx)
value_and_pushforward!(dy, $backend, vector_to_vector, x, dx)
end setup = (x = randn($n); dx = randn($n); dy = zeros($n)) evals = 1
end

for backend in reverse_backends
if backend != ReverseDiffBackend()
SUITE["reverse"]["scalar_to_scalar"][n][string(backend)] = @benchmarkable begin
pullback!(dx, $backend, scalar_to_scalar, x, dy)
value_and_pullback!(dx, $backend, scalar_to_scalar, x, dy)
end setup = (x = 1.0; dy = 1.0; dx = 0.0) evals = 1
end
SUITE["reverse"]["vector_to_scalar"][n][string(backend)] = @benchmarkable begin
pullback!(dx, $backend, vector_to_scalar, x, dy)
value_and_pullback!(dx, $backend, vector_to_scalar, x, dy)
end setup = (x = randn($n); dy = 1.0; dx = zeros($n)) evals = 1
if backend != EnzymeReverseBackend()
SUITE["reverse"]["vector_to_vector"][n][string(backend)] = @benchmarkable begin
pullback!(dx, $backend, vector_to_vector, x, dy)
value_and_pullback!(dx, $backend, vector_to_vector, x, dy)
end setup = (x = randn($n); dy = randn($n); dx = zeros($n)) evals = 1
end
end
Expand Down
24 changes: 12 additions & 12 deletions ext/DifferentiationInterfaceChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,40 @@ using LinearAlgebra
ruleconfig(backend::ChainRulesForwardBackend) = backend.ruleconfig
ruleconfig(backend::ChainRulesReverseBackend) = backend.ruleconfig

function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
_dy::Y, backend::ChainRulesForwardBackend, f, x::X, dx::X
) where {X,Y<:Number}
rc = ruleconfig(backend)
_, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x)
return new_dy
y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x)
return y, new_dy
end

function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
dy::Y, backend::ChainRulesForwardBackend, f, x::X, dx::X
) where {X,Y<:AbstractArray}
rc = ruleconfig(backend)
_, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x)
y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x)
dy .= new_dy
return dy
return y, dy
end

function DifferentiationInterface.pullback!(
function DifferentiationInterface.value_and_pullback!(
_dx::X, backend::ChainRulesReverseBackend, f, x::X, dy::Y
) where {X<:Number,Y}
rc = ruleconfig(backend)
_, pullback = rrule_via_ad(rc, f, x)
y, pullback = rrule_via_ad(rc, f, x)
_, new_dx = pullback(dy)
return new_dx
return y, new_dx
end

function DifferentiationInterface.pullback!(
function DifferentiationInterface.value_and_pullback!(
dx::X, backend::ChainRulesReverseBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y}
rc = ruleconfig(backend)
_, pullback = rrule_via_ad(rc, f, x)
y, pullback = rrule_via_ad(rc, f, x)
_, new_dx = pullback(dy)
dx .= new_dx
return dx
return y, dx
end

end
24 changes: 14 additions & 10 deletions ext/DifferentiationInterfaceEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,43 +9,47 @@ using Enzyme
"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
_dy::Y, ::EnzymeForwardBackend, f, x::X, dx::X
) where {X,Y<:Real}
return only(autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, dx)))
y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx))
return y, new_dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::EnzymeForwardBackend, f, x::X, dx::X
) where {X,Y<:AbstractArray}
dy .= only(autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, dx)))
return dy
y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx))
dy .= new_dy
return y, dy
end

## Reverse-mode

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pullback!(
function DifferentiationInterface.value_and_pullback!(
_dx::X, ::EnzymeReverseBackend, f, x::X, dy::Y
) where {X<:Number,Y<:Union{Real,Nothing}}
return only(first(autodiff(Reverse, f, Active, Active(x)))) * dy
dydx, y = autodiff(ReverseWithPrimal, f, Active, Active(x))
new_dx = dy * only(dydx)
return y, new_dx
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pullback!(
function DifferentiationInterface.value_and_pullback!(
dx::X, ::EnzymeReverseBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y<:Union{Real,Nothing}}
dx .= zero(eltype(dx))
autodiff(Reverse, f, Active, Duplicated(x, dx))
_, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx))
dx .*= dy
return dx
return y, dx
end

end # module
59 changes: 46 additions & 13 deletions ext/DifferentiationInterfaceFiniteDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,47 +5,80 @@ using DocStringExtensions
using FiniteDiff
using LinearAlgebra

const DEFAULT_FDTYPE = Val{:central}

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
) where {X<:Number,Y<:Number}
new_dy = FiniteDiff.finite_difference_derivative(f, x) * dx
return new_dy
y = f(x)
der = FiniteDiff.finite_difference_derivative(
f,
x,
Val{:central}, # fdtype
eltype(dy), # returntype
y, # fx
)
new_dy = der * dx
return y, new_dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
) where {X<:Number,Y<:AbstractArray}
new_dy = FiniteDiff.finite_difference_derivative(f, x)
dy .= new_dy .* dx
return dy
y = f(x)
FiniteDiff.finite_difference_gradient!(
dy,
f,
x,
DEFAULT_FDTYPE, # fdtype
eltype(dy), # returntype
Val{false}, # inplace
y, # fx
)
dy .*= dx
return y, dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
) where {X<:AbstractArray,Y<:Number}
g = FiniteDiff.finite_difference_gradient(f, x)
y = f(x)
g = FiniteDiff.finite_difference_gradient(
f,
x,
DEFAULT_FDTYPE, # fdtype
eltype(dy), # returntype
Val{false}, # inplace
y, # fx
)
new_dy = dot(g, dx)
return new_dy
return y, new_dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
) where {X<:AbstractArray,Y<:AbstractArray}
J = FiniteDiff.finite_difference_jacobian(f, x)
y = f(x)
J = FiniteDiff.finite_difference_jacobian(
f,
x,
DEFAULT_FDTYPE, # fdtype
eltype(dy), # returntype
)
mul!(dy, J, dx)
return dy
return y, dy
end

end # module
51 changes: 34 additions & 17 deletions ext/DifferentiationInterfaceForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,66 @@ module DifferentiationInterfaceForwardDiffExt
using DifferentiationInterface
using DocStringExtensions
using ForwardDiff
using ForwardDiff: Dual, Tag, value, extract_derivative, extract_derivative!
using LinearAlgebra

function extract_value(::Type{T}, ydual) where {T}
return value.(T, ydual)
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
function DifferentiationInterface.value_and_pushforward!(
_dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
) where {X<:Real,Y<:Real}
new_dy = ForwardDiff.derivative(f, x) * dx
return new_dy
T = typeof(Tag(f, X))
xdual = Dual{T}(x, dx)
ydual = f(xdual)
y = extract_value(T, ydual)
new_dy = extract_derivative(T, ydual)
return y, new_dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
) where {X<:Real,Y<:AbstractArray}
ForwardDiff.derivative!(dy, f, x)
dy .*= dx
return dy
T = typeof(Tag(f, X))
xdual = Dual{T}(x, dx)
ydual = f(xdual)
y = extract_value(T, ydual)
dy = extract_derivative!(T, dy, ydual)
return y, dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
function DifferentiationInterface.value_and_pushforward!(
_dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
) where {X<:AbstractArray,Y<:Real}
g = ForwardDiff.gradient(f, x) # TODO: replace with duals, n times too slow
new_dy = dot(g, dx)
return new_dy
res = DiffResults.GradientResult(x)
ForwardDiff.gradient!(res, f, x)
y = DiffResults.value(res)
new_dy = dot(DiffResults.gradient(res), dx)
return y, new_dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
) where {X<:AbstractArray,Y<:AbstractArray}
J = ForwardDiff.jacobian(f, x) # TODO: replace with duals, n times too slow
res = DiffResults.JacobianResult(x)
ForwardDiff.jacobian!(res, f, x) # TODO: replace with duals, n times too slow
y = DiffResults.value(res)
J = DiffResults.jacobian(res)
mul!(dy, J, dx)
return dy
return y, dy
end

end
end # module
21 changes: 13 additions & 8 deletions ext/DifferentiationInterfaceReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,28 @@ using LinearAlgebra
"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pullback!(
function DifferentiationInterface.value_and_pullback!(
dx::X, ::ReverseDiffBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y<:Real}
ReverseDiff.gradient!(dx, f, x)
dx .*= dy
return dx
res = DiffResults.GradientResult(x)
ReverseDiff.gradient!(res, f, x)
y = DiffResults.value(res)
dx .= dy * DiffResults.gradient(res)
return y, dx
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pullback!(
function DifferentiationInterface.value_and_pullback!(
dx::X, ::ReverseDiffBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y<:AbstractArray}
J = ReverseDiff.jacobian(f, x)
res = DiffResults.JacobianResult(x)
ReverseDiff.jacobian!(res, f, x)
y = DiffResults.value(res)
J = DiffResults.jacobian(res)
mul!(dx, transpose(J), dy)
return dx
return y, dx
end

end
end # module
Loading
Loading