Skip to content

Commit

Permalink
Simplify tests, get correct complexity for ForwardDiff (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Feb 29, 2024
1 parent 31d1be9 commit c65fb27
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 142 deletions.
31 changes: 7 additions & 24 deletions ext/DifferentiationInterfaceChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,46 +1,29 @@
module DifferentiationInterfaceChainRulesCoreExt

using ChainRulesCore
using ChainRulesCore: NoTangent, frule_via_ad, rrule_via_ad
using DifferentiationInterface
using LinearAlgebra

ruleconfig(backend::ChainRulesForwardBackend) = backend.ruleconfig
ruleconfig(backend::ChainRulesReverseBackend) = backend.ruleconfig

function DifferentiationInterface.value_and_pushforward!(
_dy::Y, backend::ChainRulesForwardBackend, f, x::X, dx::X
) where {X,Y<:Number}
rc = ruleconfig(backend)
y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x)
return y, new_dy
end
update!(_old::Number, new::Number) = new
update!(old, new) = old .= new

function DifferentiationInterface.value_and_pushforward!(
dy::Y, backend::ChainRulesForwardBackend, f, x::X, dx::X
) where {X,Y<:AbstractArray}
) where {X,Y}
rc = ruleconfig(backend)
y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x)
dy .= new_dy
return y, dy
end

function DifferentiationInterface.value_and_pullback!(
_dx::X, backend::ChainRulesReverseBackend, f, x::X, dy::Y
) where {X<:Number,Y}
rc = ruleconfig(backend)
y, pullback = rrule_via_ad(rc, f, x)
_, new_dx = pullback(dy)
return y, new_dx
return y, update!(dy, new_dy)
end

function DifferentiationInterface.value_and_pullback!(
dx::X, backend::ChainRulesReverseBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y}
) where {X,Y}
rc = ruleconfig(backend)
y, pullback = rrule_via_ad(rc, f, x)
_, new_dx = pullback(dy)
dx .= new_dx
return y, dx
return y, update!(dx, new_dx)
end

end
19 changes: 3 additions & 16 deletions ext/DifferentiationInterfaceEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,17 @@
module DifferentiationInterfaceEnzymeExt

using DifferentiationInterface
using DocStringExtensions
using Enzyme
using Enzyme: Forward, ReverseWithPrimal, Active, Duplicated, autodiff

## Forward-mode
## Forward mode

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.value_and_pushforward!(
_dy::Y, ::EnzymeForwardBackend, f, x::X, dx::X
) where {X,Y<:Real}
y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx))
return y, new_dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::EnzymeForwardBackend, f, x::X, dx::X
) where {X,Y<:AbstractArray}
Expand All @@ -27,11 +20,8 @@ function DifferentiationInterface.value_and_pushforward!(
return y, dy
end

## Reverse-mode
## Reverse mode

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.value_and_pullback!(
_dx::X, ::EnzymeReverseBackend, f, x::X, dy::Y
) where {X<:Number,Y<:Union{Real,Nothing}}
Expand All @@ -40,9 +30,6 @@ function DifferentiationInterface.value_and_pullback!(
return y, new_dx
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.value_and_pullback!(
dx::X, ::EnzymeReverseBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y<:Union{Real,Nothing}}
Expand Down
29 changes: 10 additions & 19 deletions ext/DifferentiationInterfaceFiniteDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
module DifferentiationInterfaceFiniteDiffExt

using DifferentiationInterface
using DocStringExtensions
using FiniteDiff
using LinearAlgebra
using FiniteDiff:
finite_difference_derivative,
finite_difference_gradient,
finite_difference_gradient!,
finite_difference_jacobian
using LinearAlgebra: dot, mul!

const DEFAULT_FDTYPE = Val{:central}

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
) where {X<:Number,Y<:Number}
y = f(x)
der = FiniteDiff.finite_difference_derivative(
der = finite_difference_derivative(
f,
x,
DEFAULT_FDTYPE, # fdtype
Expand All @@ -25,14 +25,11 @@ function DifferentiationInterface.value_and_pushforward!(
return y, new_dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
) where {X<:Number,Y<:AbstractArray}
y = f(x)
FiniteDiff.finite_difference_gradient!(
finite_difference_gradient!(
dy,
f,
x,
Expand All @@ -45,14 +42,11 @@ function DifferentiationInterface.value_and_pushforward!(
return y, dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
) where {X<:AbstractArray,Y<:Number}
y = f(x)
g = FiniteDiff.finite_difference_gradient(
g = finite_difference_gradient(
f,
x,
DEFAULT_FDTYPE, # fdtype
Expand All @@ -64,14 +58,11 @@ function DifferentiationInterface.value_and_pushforward!(
return y, new_dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
) where {X<:AbstractArray,Y<:AbstractArray}
y = f(x)
J = FiniteDiff.finite_difference_jacobian(
J = finite_difference_jacobian(
f,
x,
DEFAULT_FDTYPE, # fdtype
Expand Down
45 changes: 14 additions & 31 deletions ext/DifferentiationInterfaceForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,68 +1,51 @@
module DifferentiationInterfaceForwardDiffExt

using DifferentiationInterface
using DiffResults
using DocStringExtensions
using ForwardDiff
using DiffResults: DiffResults
using ForwardDiff: Dual, Tag, value, extract_derivative, extract_derivative!
using LinearAlgebra
using LinearAlgebra: mul!

function extract_value(::Type{T}, ydual) where {T}
return value.(T, ydual)
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.value_and_pushforward!(
_dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
) where {X<:Real,Y<:Real}
T = typeof(Tag(f, X))
xdual = Dual{T}(x, dx)
ydual = f(xdual)
y = extract_value(T, ydual)
y = value(T, ydual)
new_dy = extract_derivative(T, ydual)
return y, new_dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
) where {X<:Real,Y<:AbstractArray}
T = typeof(Tag(f, X))
xdual = Dual{T}(x, dx)
ydual = f(xdual)
y = extract_value(T, ydual)
y = value.(T, ydual)
dy = extract_derivative!(T, dy, ydual)
return y, dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.value_and_pushforward!(
_dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
) where {X<:AbstractArray,Y<:Real}
res = DiffResults.GradientResult(x)
ForwardDiff.gradient!(res, f, x)
y = DiffResults.value(res)
new_dy = dot(DiffResults.gradient(res), dx)
T = typeof(Tag(f, X)) # TODO: unsure
xdual = Dual{T}.(x, dx) # TODO: allocation
ydual = f(xdual)
y = value(T, ydual)
new_dy = extract_derivative(T, ydual)
return y, new_dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
) where {X<:AbstractArray,Y<:AbstractArray}
res = DiffResults.JacobianResult(x)
ForwardDiff.jacobian!(res, f, x) # TODO: replace with duals, n times too slow
y = DiffResults.value(res)
J = DiffResults.jacobian(res)
mul!(dy, J, dx)
T = typeof(Tag(f, X)) # TODO: unsure
xdual = Dual{T}.(x, dx) # TODO: allocation
ydual = f(xdual)
y = value.(T, ydual)
dy = extract_derivative!(T, dy, ydual)
return y, dy
end

Expand Down
21 changes: 7 additions & 14 deletions ext/DifferentiationInterfaceReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,25 @@
module DifferentiationInterfaceReverseDiffExt

using DifferentiationInterface
using DiffResults
using DocStringExtensions
using ReverseDiff
using LinearAlgebra
using DiffResults: DiffResults
using ReverseDiff: gradient!, jacobian!
using LinearAlgebra: mul!

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.value_and_pullback!(
dx::X, ::ReverseDiffBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y<:Real}
res = DiffResults.GradientResult(x)
ReverseDiff.gradient!(res, f, x)
res = DiffResults.DiffResult(zero(Y), dx)
res = gradient!(res, f, x)
y = DiffResults.value(res)
dx .= dy .* DiffResults.gradient(res)
return y, dx
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.value_and_pullback!(
dx::X, ::ReverseDiffBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y<:AbstractArray}
res = DiffResults.JacobianResult(x)
ReverseDiff.jacobian!(res, f, x)
res = DiffResults.DiffResult(similar(dy), similar(dy, length(dy), length(x)))
res = jacobian!(res, f, x)
y = DiffResults.value(res)
J = DiffResults.jacobian(res)
mul!(dx, transpose(J), dy)
Expand Down
3 changes: 3 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Loading

0 comments on commit c65fb27

Please sign in to comment.