diff --git a/src/tangent.jl b/src/tangent.jl index a46dd031..8d30e4f3 100644 --- a/src/tangent.jl +++ b/src/tangent.jl @@ -94,6 +94,8 @@ end coeffs::C TaylorTangent(coeffs) = $(Expr(:new, :(TaylorTangent{typeof(coeffs)}), :coeffs)) end +Base.:(==)(a::TaylorTangent, b::TaylorTangent) = a.coeffs == b.coeffs +Base.hash(tt::TaylorTangent, h::UInt64) = hash(tt.coeffs, h) """ struct TaylorTangent{C} @@ -159,6 +161,9 @@ TangentBundle TangentBundle{N}(primal::B, tangent::P) where {N, B, P<:AbstractTangentSpace} = _TangentBundle(Val{N}(), primal, tangent) +Base.hash(tb::TangentBundle, h::UInt64) = hash(tb.primal, h) +Base.:(==)(a::TangentBundle, b::TangentBundle) = (a.primal == b.primal) && (a.tangent == b.tangent) + const ExplicitTangentBundle{N, B, P} = TangentBundle{N, B, ExplicitTangent{P}} check_tangent_invariant(lp, N) = @assert lp == 2^N - 1 @@ -201,20 +206,25 @@ end const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}} + +function TaylorBundle{N, B, P}(primal::B, coeffs::P) where {N, B, P} + check_taylor_invariants(coeffs, primal, N) + _TangentBundle(Val{N}(), primal, TaylorTangent(coeffs)) +end function TaylorBundle{N, B}(primal::B, coeffs) where {N, B} check_taylor_invariants(coeffs, primal, N) _TangentBundle(Val{N}(), primal, TaylorTangent(coeffs)) end +function TaylorBundle{N}(primal, coeffs) where {N} + check_taylor_invariants(coeffs, primal, N) + _TangentBundle(Val{N}(), primal, TaylorTangent(coeffs)) +end function check_taylor_invariants(coeffs, primal, N) @assert length(coeffs) == N - end @ChainRulesCore.non_differentiable check_taylor_invariants(coeffs, primal, N) -function TaylorBundle{N}(primal, coeffs) where {N} - _TangentBundle(Val{N}(), primal, TaylorTangent(coeffs)) -end function Base.show(io::IO, x::TaylorBundle{1}) print(io, x.primal) @@ -350,7 +360,7 @@ function ChainRulesCore.rrule(::typeof(unbundle), atb::AbstractTangentBundle) unbundle(atb), Δ->throw(Δ) end -function StructArrays.createinstance(T::Type{<:ZeroBundle}, args...) +function StructArrays.createinstance(T::Type{<:UniformBundle}, args...) T(args[1], args[2]) end diff --git a/test/forward.jl b/test/forward.jl index 53b65059..1e4b7142 100644 --- a/test/forward.jl +++ b/test/forward.jl @@ -1,6 +1,6 @@ module forward_tests using Diffractor -using Diffractor: TaylorBundle, ZeroBundle +using Diffractor: TaylorBundle, ZeroBundle, ∂☆ using ChainRules using ChainRulesCore using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad @@ -61,7 +61,7 @@ end end # Special case if there is no derivative information at all: - @test (Diffractor.∂☆{1}())(ZeroBundle{1}(foo), ZeroBundle{1}(2.0), ZeroBundle{1}(3.0)) == ZeroBundle{1}(5.0) + @test ∂☆{1}()(ZeroBundle{1}(foo), ZeroBundle{1}(2.0), ZeroBundle{1}(3.0)) == ZeroBundle{1}(5.0) @test frule_calls[] == 0 @test primal_calls[] == 1 end @@ -88,6 +88,24 @@ end end +@testset "map" begin + @test ==( + ∂☆{1}()(ZeroBundle{1}(xs->(map(x->2*x, xs))), TaylorBundle{1}([1.0, 2.0], ([10.0, 100.0],))), + TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],)) + ) + + + # map over all closure, wrt the closed variable + mulby(x) = y->x*y + 🐇 = ∂☆{1}()( + ZeroBundle{1}(x->(map(mulby(x), [2.0, 4.0]))), + TaylorBundle{1}(2.0, (10.0,)) + ) + @test 🐇 == TaylorBundle{1}([4.0, 8.0], ([20.0, 40.0],)) + +end + + @testset "structs" begin struct IDemo x::Float64 @@ -166,4 +184,4 @@ end ) end -end +end # module diff --git a/test/tangent.jl b/test/tangent.jl index 01a54607..95b4e22d 100644 --- a/test/tangent.jl +++ b/test/tangent.jl @@ -46,6 +46,11 @@ end end end +@testset "== and hash" begin + @test TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],)) == TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],)) + @test hash(TaylorBundle{1}(0.0, (0.0,))) == hash(0) +end + @testset "truncate" begin tt = TaylorTangent((1.0,2.0,3.0,4.0,5.0,6.0,7.0)) @test truncate(tt, Val(2)) == TaylorTangent((1.0,2.0))