diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 13a5f7a8..45d99531 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -2,7 +2,6 @@ partial(x::TangentBundle, i) = partial(getfield(x, :tangent), i) partial(x::ExplicitTangent, i) = getfield(getfield(x, :partials), i) partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i) partial(x::UniformTangent, i) = getfield(x, :val) -partial(x::ProductTangent, i) = ProductTangent(map(x->partial(x, i), getfield(x, :factors))) partial(x::AbstractZero, i) = x diff --git a/src/tangent.jl b/src/tangent.jl index 8d30e4f3..025b958b 100644 --- a/src/tangent.jl +++ b/src/tangent.jl @@ -74,11 +74,8 @@ struct TaylorTangentIndex <: TangentIndex i::Int end -function Base.getindex(a::AbstractTangentBundle, b::TaylorTangentIndex) - error("$(typeof(a)) is not taylor-like. Taylor indexing is ambiguous") -end - abstract type AbstractTangentSpace; end +Base.:(==)(x::AbstractTangentSpace, y::AbstractTangentSpace) = ==(promote(x, y)...) """ struct ExplicitTangent{P} @@ -89,13 +86,23 @@ represented by a vector of `2^N-1` partials. struct ExplicitTangent{P <: Tuple} <: AbstractTangentSpace partials::P end +Base.:(==)(a::ExplicitTangent, b::ExplicitTangent) = a.partials == b.partials +Base.hash(tt::ExplicitTangent, h::UInt64) = hash(tt.partials, h) + +Base.getindex(tangent::ExplicitTangent, b::CanonicalTangentIndex) = tangent.partials[b.i] +function Base.getindex(tangent::ExplicitTangent, b::TaylorTangentIndex) + if lastindex(tangent.partials) == exp2(b.i) - 1 + return tangent.partials[end] + end + # TODO: should we also allow other indexes if all the partials at that level are equal up regardless of order? + throw(DomainError(b, "$(typeof(tangent)) is not taylor-like. Taylor indexing is ambiguous")) +end + @eval struct TaylorTangent{C <: Tuple} <: AbstractTangentSpace coeffs::C TaylorTangent(coeffs) = $(Expr(:new, :(TaylorTangent{typeof(coeffs)}), :coeffs)) end -Base.:(==)(a::TaylorTangent, b::TaylorTangent) = a.coeffs == b.coeffs -Base.hash(tt::TaylorTangent, h::UInt64) = hash(tt.coeffs, h) """ struct TaylorTangent{C} @@ -122,15 +129,13 @@ by analogy with the (truncated) Taylor series """ TaylorTangent -""" - struct ProductTangent{T <: Tuple{Vararg{AbstractTangentSpace}}} +Base.:(==)(a::TaylorTangent, b::TaylorTangent) = a.coeffs == b.coeffs +Base.hash(tt::TaylorTangent, h::UInt64) = hash(tt.coeffs, h) + + +Base.getindex(tangent::TaylorTangent, tti::TaylorTangentIndex) = tangent.coeffs[tti.i] +Base.getindex(tangent::TaylorTangent, tti::CanonicalTangentIndex) = tangent.coeffs[count_ones(tti.i)] -Represents the product space of the given representations of the -tangent space. -""" -struct ProductTangent{T <: Tuple} <: AbstractTangentSpace - factors::T -end """ struct UniformTangent @@ -141,6 +146,28 @@ useful for representing singleton values. struct UniformTangent{U} <: AbstractTangentSpace val::U end +Base.hash(t::UniformTangent, h::UInt64) = hash(t.val, h) +Base.:(==)(t1::UniformTangent, t2::UniformTangent) = t1.val == t2.val + +Base.getindex(tangent::UniformTangent, ::Any) = tangent.val + +# Conversion and promotion +Base.promote_rule(et::Type{<:ExplicitTangent}, ::Type{<:AbstractTangentSpace}) = et +Base.promote_rule(tt::Type{<:TaylorTangent}, ::Type{<:AbstractTangentSpace}) = tt +Base.promote_rule(et::Type{<:ExplicitTangent}, ::Type{<:TaylorTangent}) = et +Base.promote_rule(::Type{<:TaylorTangent}, et::Type{<:ExplicitTangent}) = et + +num_partials(::Type{TaylorTangent{P}}) where P = fieldcount(P) +num_partials(::Type{ExplicitTangent{P}}) where P = fieldcount(P) +Base.eltype(::Type{TaylorTangent{P}}) where P = eltype(P) +Base.eltype(::Type{ExplicitTangent{P}}) where P = eltype(P) +function Base.convert(::Type{T}, ut::UniformTangent) where {T<:Union{TaylorTangent, ExplicitTangent}} + # can't just use T to construct as the inner constructor doesn't accept type params. So get T_wrapper + T_wrapper = T<:TaylorTangent ? TaylorTangent : ExplicitTangent + T_wrapper(ntuple(_->convert(eltype(T), ut.val), num_partials(T))) +end +Base.convert(T::Type{<:ExplicitTangent}, tt::TaylorTangent) = ExplicitTangent(ntuple(i->tt[CanonicalTangentIndex(i)], num_partials(T))) +#TODO: Should we define the reverse: Explict->Taylor for the cases where that is actually defined? function _TangentBundle end @@ -154,7 +181,7 @@ end struct TangentBundle{N, B, P} Represents a tangent bundle as an explicit primal together -with some representation of (potentially a product of) the tangent space. +with some representation of the tangent space. """ TangentBundle @@ -162,7 +189,9 @@ TangentBundle{N}(primal::B, tangent::P) where {N, B, P<:AbstractTangentSpace} = _TangentBundle(Val{N}(), primal, tangent) Base.hash(tb::TangentBundle, h::UInt64) = hash(tb.primal, h) -Base.:(==)(a::TangentBundle, b::TangentBundle) = (a.primal == b.primal) && (a.tangent == b.tangent) +Base.:(==)(a::TangentBundle, b::TangentBundle) = false # different orders +Base.:(==)(a::TangentBundle{N}, b::TangentBundle{N}) where {N} = (a.primal == b.primal) && (a.tangent == b.tangent) +Base.getindex(tbun::TangentBundle, x) = getindex(tbun.tangent, x) const ExplicitTangentBundle{N, B, P} = TangentBundle{N, B, ExplicitTangent{P}} @@ -197,12 +226,7 @@ function Base.show(io::IO, x::ExplicitTangentBundle) length(x.partials) >= 7 && print(io, " + ", x.partials[7], " ∂₁ ∂₂ ∂₃") end -function Base.getindex(a::ExplicitTangentBundle{N}, b::TaylorTangentIndex) where {N} - if b.i === N - return a.tangent.partials[end] - end - error("$(typeof(a)) is not taylor-like. Taylor indexing is ambiguous") -end + const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}} @@ -233,11 +257,6 @@ function Base.show(io::IO, x::TaylorBundle{1}) print(io, x.coeffs[1], " ∂₁") end -Base.getindex(tb::TaylorBundle, tti::TaylorTangentIndex) = tb.tangent.coeffs[tti.i] -function Base.getindex(tb::TaylorBundle, tti::CanonicalTangentIndex) - tb.tangent.coeffs[count_ones(tti.i)] -end - "for a TaylorTangent{N, <:Tuple} this breaks it up unto 1 TaylorTangent{N} for each element of the primal tuple" function destructure(r::TaylorBundle{N, B}) where {N, B<:Tuple} return ntuple(fieldcount(B)) do field_ii @@ -307,8 +326,18 @@ function Base.show(io::IO, t::AbstractZeroBundle{N}) where N print(io, ")") end +# Conversion and promotion +function Base.promote_rule(::Type{TangentBundle{N, B, P1}}, ::Type{TangentBundle{N, B, P2}}) where {N,B,P1,P2} + return TangentBundle{N, B, promote_type(P1, P2)} +end + +function Base.convert(::Type{T}, tbun::TangentBundle{N, B}) where {N, B, P, T<:TangentBundle{N,B,P}} + the_primal = convert(B, primal(tbun)) + the_partials = convert(P, tbun.tangent) + return _TangentBundle(Val{N}(), the_primal, the_partials) +end -Base.getindex(u::UniformBundle, ::TaylorTangentIndex) = u.tangent.val +# StructureArrays helpers expand_singleton_to_array(asize, a::AbstractZero) = fill(a, asize...) expand_singleton_to_array(asize, a::AbstractArray) = a diff --git a/test/tangent.jl b/test/tangent.jl index 95b4e22d..baef0ed6 100644 --- a/test/tangent.jl +++ b/test/tangent.jl @@ -1,7 +1,7 @@ -module tagent +module tangent using Diffractor -using Diffractor: AbstractZeroBundle, ZeroBundle, DNEBundle -using Diffractor: TaylorBundle, TaylorTangentIndex +using Diffractor: AbstractZeroBundle, ZeroBundle, DNEBundle, TaylorBundle, ExplicitTangentBundle +using Diffractor:TaylorTangentIndex, CanonicalTangentIndex using Diffractor: ExplicitTangent, TaylorTangent, truncate using ChainRulesCore using Test @@ -46,11 +46,71 @@ end end end -@testset "== and hash" begin +@testset "getindex" begin + tt = TaylorBundle{2}(1.5, (1.0, 2.0)) + @test tt[TaylorTangentIndex(1)] == 1.0 + @test tt[TaylorTangentIndex(2)] == 2.0 + @test tt[CanonicalTangentIndex(1)] == 1.0 + @test tt[CanonicalTangentIndex(2)] == 1.0 + @test tt[CanonicalTangentIndex(3)] == 2.0 + + et = ExplicitTangentBundle{2}(1.5, (1.0, 2.0, 3.0)) + @test_throws DomainError et[TaylorTangentIndex(1)] == 1.0 + @test et[TaylorTangentIndex(2)] == 3.0 + @test et[CanonicalTangentIndex(1)] == 1.0 + @test et[CanonicalTangentIndex(2)] == 2.0 + @test et[CanonicalTangentIndex(3)] == 3.0 + + zb = ZeroBundle{2}(1.5) + @test zb[TaylorTangentIndex(1)] == ZeroTangent() + @test zb[TaylorTangentIndex(2)] == ZeroTangent() + @test zb[CanonicalTangentIndex(1)] == ZeroTangent() + @test zb[CanonicalTangentIndex(2)] == ZeroTangent() + @test zb[CanonicalTangentIndex(3)] == ZeroTangent() +end + +@testset "promote" begin + @test promote_type( + typeof(ExplicitTangentBundle{1}([2.0, 4.0], ([20.0, 200.0],))), + typeof(TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))) + ) <: ExplicitTangentBundle{1, Vector{Float64}} + + @test promote_type(TaylorBundle{1, Float64, Tuple{Float64}}, ZeroBundle{1, Float64}) <: TaylorBundle{1, Float64, Tuple{Float64}} + @test promote_type(ExplicitTangentBundle{1, Float64, Tuple{Float64}}, ZeroBundle{1, Float64}) <: ExplicitTangentBundle{1, Float64, Tuple{Float64}} +end +@testset "convert" begin + @test convert(TaylorBundle{1, Float64, Tuple{Float64}}, ZeroBundle{1}(1.4)) == TaylorBundle{1}(1.4, (0.0,)) + @test convert(ExplicitTangentBundle{1, Float64, Tuple{Float64}}, ZeroBundle{1}(1.4)) == ExplicitTangentBundle{1}(1.4, (0.0,)) + + @test convert( + typeof(ExplicitTangentBundle{1}([2.0, 4.0], ([20.0, 200.0],))), + TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],)) + ) == ExplicitTangentBundle{1}([2.0, 4.0], ([20.0, 200.0],)) + + @test convert( + typeof(ExplicitTangentBundle{2}(1.5, (10.0, 10.0, 20.0,))), + TaylorBundle{2}(1.5, (10.0, 20.0)) + ) === ExplicitTangentBundle{2}(1.5, (10.0, 10.0, 20.0,)) +end +@testset "==" begin @test TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],)) == TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],)) - @test hash(TaylorBundle{1}(0.0, (0.0,))) == hash(0) + @test TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],)) == ExplicitTangentBundle{1}([2.0, 4.0], ([20.0, 200.0],)) + + @test ZeroBundle{3}(1.5) == ZeroBundle{3}(1.5) + @test ZeroBundle{3}(1.5) == TaylorBundle{3}(1.5, (0.0, 0.0, 0.0)) + @test ZeroBundle{3}(1.5) == ExplicitTangentBundle{3}(1.5, (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)) end +@testset "hash" begin + @test hash(TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))) == hash(TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))) + @test hash(TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))) == hash(ExplicitTangentBundle{1}([2.0, 4.0], ([20.0, 200.0],))) + + @test hash(ZeroBundle{3}(1.5)) == hash(ZeroBundle{3}(1.5)) + @test hash(ZeroBundle{3}(1.5)) == hash(TaylorBundle{3}(1.5, (0.0, 0.0, 0.0))) + @test hash(ZeroBundle{3}(1.5)) == hash(ExplicitTangentBundle{3}(1.5, (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))) +end + + @testset "truncate" begin tt = TaylorTangent((1.0,2.0,3.0,4.0,5.0,6.0,7.0)) @test truncate(tt, Val(2)) == TaylorTangent((1.0,2.0))