diff --git a/src/DifferentiationTest/DifferentiationTest.jl b/src/DifferentiationTest/DifferentiationTest.jl index 1fdc45393..928811b33 100644 --- a/src/DifferentiationTest/DifferentiationTest.jl +++ b/src/DifferentiationTest/DifferentiationTest.jl @@ -16,10 +16,12 @@ using ..DifferentiationInterface import ..DifferentiationInterface as DI using ..DifferentiationInterface: AutoTaped, + inner, mode, mysimilar, myzero, myzero!!, + outer, supports_mutation, supports_pushforward, supports_pullback diff --git a/src/DifferentiationTest/printing.jl b/src/DifferentiationTest/printing.jl index cbe65ff0a..db174181d 100644 --- a/src/DifferentiationTest/printing.jl +++ b/src/DifferentiationTest/printing.jl @@ -34,3 +34,7 @@ function backend_string(backend::AbstractADType) error("Unknown mode") end end + +function backend_string(backend::SecondOrder) + return "$(backend_string(outer(backend))) / $(backend_string(inner(backend)))" +end diff --git a/src/hvp.jl b/src/hvp.jl index fafe2ad94..213cc7056 100644 --- a/src/hvp.jl +++ b/src/hvp.jl @@ -13,12 +13,6 @@ By order of preference: """ hvp(f, backend, x, v, [extras]) -> p """ -function hvp( - f::F, backend::AbstractADType, x::Number, v, extras=prepare_hvp(f, backend, x) -) where {F} - return v * second_derivative(f, backend, x, extras) -end - function hvp( f::F, backend::AbstractADType, x, v, extras=prepare_hvp(f, backend, x) ) where {F} @@ -27,7 +21,13 @@ function hvp( return hvp(f, new_backend, x, v, new_extras) end -function hvp(f::F, backend::SecondOrder, x, v, extras=prepare_hvp(backend, f, x)) where {F} +function hvp( + f::F, backend::SecondOrder, x::Number, v::Number, extras=prepare_hvp(f, backend, x) +) where {F} + return v * second_derivative(f, backend, x, extras) +end + +function hvp(f::F, backend::SecondOrder, x, v, extras=prepare_hvp(f, backend, x)) where {F} return hvp_aux(f, backend, x, v, extras, hvp_mode(backend)) end diff --git a/test/runtests.jl b/test/runtests.jl index 2f4d006a4..41bf8588f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -98,4 +98,8 @@ using Zygote: Zygote @time include("zygote.jl") end end + + @testset verbose = true "Second order" begin + include("second_order.jl") + end end; diff --git a/test/second_order.jl b/test/second_order.jl new file mode 100644 index 000000000..232f009f1 --- /dev/null +++ b/test/second_order.jl @@ -0,0 +1,31 @@ +using ADTypes +using DifferentiationInterface +using DifferentiationInterface.DifferentiationTest +using DifferentiationInterface.DifferentiationTest: backend_string + +using FiniteDiff: FiniteDiff +using ForwardDiff: ForwardDiff +using Enzyme: Enzyme +using Zygote: Zygote + +using JET: JET +using Test + +SECOND_ORDER_BACKENDS = Dict( + "forward/forward" => [ + SecondOrder(AutoEnzyme(Enzyme.Forward), AutoForwardDiff()), + SecondOrder(AutoForwardDiff(), AutoEnzyme(Enzyme.Forward)), + ], + "forward/reverse" => [SecondOrder(AutoForwardDiff(), AutoZygote())], + "reverse/forward" => [], +) + +@testset verbose = true "Cross backends" begin + @testset verbose = true "$second_order_mode" for (second_order_mode, backends) in + pairs(SECOND_ORDER_BACKENDS) + @info "Testing $second_order_mode..." + @time @testset "$(backend_string(backend))" for backend in backends + test_operators(backend; first_order=false, type_stability=false) + end + end +end; diff --git a/test/zero.jl b/test/zero.jl index 39881e341..f05f5a814 100644 --- a/test/zero.jl +++ b/test/zero.jl @@ -34,9 +34,8 @@ test_operators( test_operators( [AutoZeroForward(), AutoZeroReverse()]; - allocating=false, correctness=false, - type_stability=true, + type_stability=false, allocations=true, );