diff --git a/ext/DifferentiationInterfaceChainRulesCoreExt.jl b/ext/DifferentiationInterfaceChainRulesCoreExt.jl index 591f80046..2e2faee7c 100644 --- a/ext/DifferentiationInterfaceChainRulesCoreExt.jl +++ b/ext/DifferentiationInterfaceChainRulesCoreExt.jl @@ -1,46 +1,29 @@ module DifferentiationInterfaceChainRulesCoreExt -using ChainRulesCore +using ChainRulesCore: NoTangent, frule_via_ad, rrule_via_ad using DifferentiationInterface -using LinearAlgebra ruleconfig(backend::ChainRulesForwardBackend) = backend.ruleconfig ruleconfig(backend::ChainRulesReverseBackend) = backend.ruleconfig -function DifferentiationInterface.value_and_pushforward!( - _dy::Y, backend::ChainRulesForwardBackend, f, x::X, dx::X -) where {X,Y<:Number} - rc = ruleconfig(backend) - y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x) - return y, new_dy -end +update!(_old::Number, new::Number) = new +update!(old, new) = old .= new function DifferentiationInterface.value_and_pushforward!( dy::Y, backend::ChainRulesForwardBackend, f, x::X, dx::X -) where {X,Y<:AbstractArray} +) where {X,Y} rc = ruleconfig(backend) y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x) - dy .= new_dy - return y, dy -end - -function DifferentiationInterface.value_and_pullback!( - _dx::X, backend::ChainRulesReverseBackend, f, x::X, dy::Y -) where {X<:Number,Y} - rc = ruleconfig(backend) - y, pullback = rrule_via_ad(rc, f, x) - _, new_dx = pullback(dy) - return y, new_dx + return y, update!(dy, new_dy) end function DifferentiationInterface.value_and_pullback!( dx::X, backend::ChainRulesReverseBackend, f, x::X, dy::Y -) where {X<:AbstractArray,Y} +) where {X,Y} rc = ruleconfig(backend) y, pullback = rrule_via_ad(rc, f, x) _, new_dx = pullback(dy) - dx .= new_dx - return y, dx + return y, update!(dx, new_dx) end end diff --git a/ext/DifferentiationInterfaceEnzymeExt.jl b/ext/DifferentiationInterfaceEnzymeExt.jl index 6e2c6dd96..6fdd1b3b0 100644 --- a/ext/DifferentiationInterfaceEnzymeExt.jl +++ b/ext/DifferentiationInterfaceEnzymeExt.jl @@ -1,14 +1,10 @@ module DifferentiationInterfaceEnzymeExt using DifferentiationInterface -using DocStringExtensions -using Enzyme +using Enzyme: Forward, ReverseWithPrimal, Active, Duplicated, autodiff -## Forward-mode +## Forward mode -""" -$(TYPEDSIGNATURES) -""" function DifferentiationInterface.value_and_pushforward!( _dy::Y, ::EnzymeForwardBackend, f, x::X, dx::X ) where {X,Y<:Real} @@ -16,9 +12,6 @@ function DifferentiationInterface.value_and_pushforward!( return y, new_dy end -""" -$(TYPEDSIGNATURES) -""" function DifferentiationInterface.value_and_pushforward!( dy::Y, ::EnzymeForwardBackend, f, x::X, dx::X ) where {X,Y<:AbstractArray} @@ -27,11 +20,8 @@ function DifferentiationInterface.value_and_pushforward!( return y, dy end -## Reverse-mode +## Reverse mode -""" -$(TYPEDSIGNATURES) -""" function DifferentiationInterface.value_and_pullback!( _dx::X, ::EnzymeReverseBackend, f, x::X, dy::Y ) where {X<:Number,Y<:Union{Real,Nothing}} @@ -40,9 +30,6 @@ function DifferentiationInterface.value_and_pullback!( return y, new_dx end -""" -$(TYPEDSIGNATURES) -""" function DifferentiationInterface.value_and_pullback!( dx::X, ::EnzymeReverseBackend, f, x::X, dy::Y ) where {X<:AbstractArray,Y<:Union{Real,Nothing}} diff --git a/ext/DifferentiationInterfaceFiniteDiffExt.jl b/ext/DifferentiationInterfaceFiniteDiffExt.jl index 83ea3e911..fb58ffd5f 100644 --- a/ext/DifferentiationInterfaceFiniteDiffExt.jl +++ b/ext/DifferentiationInterfaceFiniteDiffExt.jl @@ -1,20 +1,20 @@ module DifferentiationInterfaceFiniteDiffExt using DifferentiationInterface -using DocStringExtensions -using FiniteDiff -using LinearAlgebra +using FiniteDiff: + finite_difference_derivative, + finite_difference_gradient, + finite_difference_gradient!, + finite_difference_jacobian +using LinearAlgebra: dot, mul! const DEFAULT_FDTYPE = Val{:central} -""" -$(TYPEDSIGNATURES) -""" function DifferentiationInterface.value_and_pushforward!( dy::Y, ::FiniteDiffBackend, f, x::X, dx::X ) where {X<:Number,Y<:Number} y = f(x) - der = FiniteDiff.finite_difference_derivative( + der = finite_difference_derivative( f, x, DEFAULT_FDTYPE, # fdtype @@ -25,14 +25,11 @@ function DifferentiationInterface.value_and_pushforward!( return y, new_dy end -""" -$(TYPEDSIGNATURES) -""" function DifferentiationInterface.value_and_pushforward!( dy::Y, ::FiniteDiffBackend, f, x::X, dx::X ) where {X<:Number,Y<:AbstractArray} y = f(x) - FiniteDiff.finite_difference_gradient!( + finite_difference_gradient!( dy, f, x, @@ -45,14 +42,11 @@ function DifferentiationInterface.value_and_pushforward!( return y, dy end -""" -$(TYPEDSIGNATURES) -""" function DifferentiationInterface.value_and_pushforward!( dy::Y, ::FiniteDiffBackend, f, x::X, dx::X ) where {X<:AbstractArray,Y<:Number} y = f(x) - g = FiniteDiff.finite_difference_gradient( + g = finite_difference_gradient( f, x, DEFAULT_FDTYPE, # fdtype @@ -64,14 +58,11 @@ function DifferentiationInterface.value_and_pushforward!( return y, new_dy end -""" -$(TYPEDSIGNATURES) -""" function DifferentiationInterface.value_and_pushforward!( dy::Y, ::FiniteDiffBackend, f, x::X, dx::X ) where {X<:AbstractArray,Y<:AbstractArray} y = f(x) - J = FiniteDiff.finite_difference_jacobian( + J = finite_difference_jacobian( f, x, DEFAULT_FDTYPE, # fdtype diff --git a/ext/DifferentiationInterfaceForwardDiffExt.jl b/ext/DifferentiationInterfaceForwardDiffExt.jl index 8ee9876d1..a9917b20a 100644 --- a/ext/DifferentiationInterfaceForwardDiffExt.jl +++ b/ext/DifferentiationInterfaceForwardDiffExt.jl @@ -1,68 +1,51 @@ module DifferentiationInterfaceForwardDiffExt using DifferentiationInterface -using DiffResults -using DocStringExtensions -using ForwardDiff +using DiffResults: DiffResults using ForwardDiff: Dual, Tag, value, extract_derivative, extract_derivative! -using LinearAlgebra +using LinearAlgebra: mul! -function extract_value(::Type{T}, ydual) where {T} - return value.(T, ydual) -end - -""" -$(TYPEDSIGNATURES) -""" function DifferentiationInterface.value_and_pushforward!( _dy::Y, ::ForwardDiffBackend, f, x::X, dx::X ) where {X<:Real,Y<:Real} T = typeof(Tag(f, X)) xdual = Dual{T}(x, dx) ydual = f(xdual) - y = extract_value(T, ydual) + y = value(T, ydual) new_dy = extract_derivative(T, ydual) return y, new_dy end -""" -$(TYPEDSIGNATURES) -""" function DifferentiationInterface.value_and_pushforward!( dy::Y, ::ForwardDiffBackend, f, x::X, dx::X ) where {X<:Real,Y<:AbstractArray} T = typeof(Tag(f, X)) xdual = Dual{T}(x, dx) ydual = f(xdual) - y = extract_value(T, ydual) + y = value.(T, ydual) dy = extract_derivative!(T, dy, ydual) return y, dy end -""" -$(TYPEDSIGNATURES) -""" function DifferentiationInterface.value_and_pushforward!( _dy::Y, ::ForwardDiffBackend, f, x::X, dx::X ) where {X<:AbstractArray,Y<:Real} - res = DiffResults.GradientResult(x) - ForwardDiff.gradient!(res, f, x) - y = DiffResults.value(res) - new_dy = dot(DiffResults.gradient(res), dx) + T = typeof(Tag(f, X)) # TODO: unsure + xdual = Dual{T}.(x, dx) # TODO: allocation + ydual = f(xdual) + y = value(T, ydual) + new_dy = extract_derivative(T, ydual) return y, new_dy end -""" -$(TYPEDSIGNATURES) -""" function DifferentiationInterface.value_and_pushforward!( dy::Y, ::ForwardDiffBackend, f, x::X, dx::X ) where {X<:AbstractArray,Y<:AbstractArray} - res = DiffResults.JacobianResult(x) - ForwardDiff.jacobian!(res, f, x) # TODO: replace with duals, n times too slow - y = DiffResults.value(res) - J = DiffResults.jacobian(res) - mul!(dy, J, dx) + T = typeof(Tag(f, X)) # TODO: unsure + xdual = Dual{T}.(x, dx) # TODO: allocation + ydual = f(xdual) + y = value.(T, ydual) + dy = extract_derivative!(T, dy, ydual) return y, dy end diff --git a/ext/DifferentiationInterfaceReverseDiffExt.jl b/ext/DifferentiationInterfaceReverseDiffExt.jl index 80846434c..d9d0b4c1a 100644 --- a/ext/DifferentiationInterfaceReverseDiffExt.jl +++ b/ext/DifferentiationInterfaceReverseDiffExt.jl @@ -1,32 +1,25 @@ module DifferentiationInterfaceReverseDiffExt using DifferentiationInterface -using DiffResults -using DocStringExtensions -using ReverseDiff -using LinearAlgebra +using DiffResults: DiffResults +using ReverseDiff: gradient!, jacobian! +using LinearAlgebra: mul! -""" -$(TYPEDSIGNATURES) -""" function DifferentiationInterface.value_and_pullback!( dx::X, ::ReverseDiffBackend, f, x::X, dy::Y ) where {X<:AbstractArray,Y<:Real} - res = DiffResults.GradientResult(x) - ReverseDiff.gradient!(res, f, x) + res = DiffResults.DiffResult(zero(Y), dx) + res = gradient!(res, f, x) y = DiffResults.value(res) dx .= dy .* DiffResults.gradient(res) return y, dx end -""" -$(TYPEDSIGNATURES) -""" function DifferentiationInterface.value_and_pullback!( dx::X, ::ReverseDiffBackend, f, x::X, dy::Y ) where {X<:AbstractArray,Y<:AbstractArray} - res = DiffResults.JacobianResult(x) - ReverseDiff.jacobian!(res, f, x) + res = DiffResults.DiffResult(similar(dy), similar(dy, length(dy), length(x))) + res = jacobian!(res, f, x) y = DiffResults.value(res) J = DiffResults.jacobian(res) mul!(dx, transpose(J), dy) diff --git a/test/Project.toml b/test/Project.toml index a09df32e7..29703ebaf 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,9 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/utils.jl b/test/utils.jl index 2e2402bfd..892ab2f9c 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,11 +1,15 @@ using DifferentiationInterface using DifferentiationInterface: AbstractReverseBackend, AbstractForwardBackend +using ForwardDiff: ForwardDiff +using LinearAlgebra using JET +using Random: AbstractRNG, randn! +using StableRNGs using Test ## Test scenarios -@kwdef struct Scenario{F,X,Y} +struct Scenario{F,X,Y} "function" f::F "argument" @@ -14,63 +18,90 @@ using Test y::Y "pushforward seed" dx::X - "pushforward result" - dy_true::Y "pullback seed" dy::Y "pullback result" dx_true::X + "pushforward result" + dy_true::Y +end + +## Constructors + +function Scenario(rng::AbstractRNG, f, x) + y = f(x) + return Scenario(rng, f, x, y) +end + +function Scenario(rng::AbstractRNG, f::F, x::X, y::Y) where {F,X<:Number,Y<:Number} + dx = randn(rng, X) + dy = randn(rng, Y) + der = ForwardDiff.derivative(f, x) + dx_true = der * dy + dy_true = der * dx + return Scenario(f, x, y, dx, dy, dx_true, dy_true) +end + +function Scenario(rng::AbstractRNG, f::F, x::X, y::Y) where {F,X<:Number,Y<:AbstractArray} + dx = randn(rng, X) + dy = similar(y) + randn!(rng, dy) + der_array = ForwardDiff.derivative(f, x) + dx_true = dot(der_array, dy) + dy_true = der_array .* dx + return Scenario(f, x, y, dx, dy, dx_true, dy_true) end +function Scenario(rng::AbstractRNG, f::F, x::X, y::Y) where {F,X<:AbstractArray,Y<:Number} + dx = similar(x) + randn!(rng, dx) + dy = randn(rng, Y) + grad = ForwardDiff.gradient(f, x) + dx_true = grad .* dy + dy_true = dot(grad, dx) + return Scenario(f, x, y, dx, dy, dx_true, dy_true) +end + +function Scenario( + rng::AbstractRNG, f::F, x::X, y::Y +) where {F,X<:AbstractArray,Y<:AbstractArray} + dx = similar(x) + randn!(rng, dx) + dy = similar(y) + randn!(rng, dy) + jac = ForwardDiff.jacobian(f, x) + dx_true = transpose(jac) * dy + dy_true = jac * dx + return Scenario(f, x, y, dx, dy, dx_true, dy_true) +end + +## Access + get_input_type(::Scenario{F,X}) where {F,X} = X get_output_type(::Scenario{F,X,Y}) where {F,X,Y} = Y +## Seed + +rng = StableRNG(63) + ## Scalar input, scalar output -scenario1 = Scenario(; - f=(x::Real -> exp(2x)), - x=1.0, - y=exp(2), - dx=5.0, - dy_true=2exp(2) * 5, - dy=5.0, - dx_true=2exp(2) * 5, -) +scenario1 = Scenario(rng, (x::Real -> sin(2x)), 1.0) ## Scalar input, vector output -scenario2 = Scenario(; - f=(x::Real -> [exp(2x), exp(3x)]), - x=1.0, - y=[exp(2), exp(3)], - dx=5.0, - dy_true=[2exp(2), 3exp(3)] .* 5, - dy=[0.0, 5.0], - dx_true=3exp(3) * 5, -) +scenario2 = Scenario(rng, (x::Real -> [sin(2x), cos(3x)]), 1.0) ## Vector input, scalar output -scenario3 = Scenario(; - f=(x::AbstractVector -> exp(2x[1]) + exp(3x[2])), - x=[1.0, 2.0], - y=exp(2) + exp(6), - dx=[0.0, 5.0], - dy_true=3exp(6) * 5, - dy=5.0, - dx_true=[2exp(2), 3exp(6)] .* 5, -) +scenario3 = Scenario(rng, (x::AbstractVector -> sin(2x[1]) + cos(3x[2])), [1.0, 2.0]) ## Vector input, vector output -scenario4 = Scenario(; - f=(x::AbstractVector -> [exp(2x[1]), exp(3x[2])]), - x=[1.0, 2.0], - y=[exp(2), exp(6)], - dx=[0.0, 5.0], - dy_true=[0.0, 3exp(6)] .* 5, - dy=[0.0, 5.0], - dx_true=[0.0, 3exp(6)] .* 5, +scenario4 = Scenario( + rng, + (x::AbstractVector -> [sin(2x[1]), cos(3x[2]), tan(2x[1]) + tan(3x[2])]), + [1.0, 2.0], ) ## All