Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add polish to Bundle Types #225

Merged
merged 2 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/stage1/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ partial(x::TangentBundle, i) = partial(getfield(x, :tangent), i)
partial(x::ExplicitTangent, i) = getfield(getfield(x, :partials), i)
partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i)
partial(x::UniformTangent, i) = getfield(x, :val)
partial(x::ProductTangent, i) = ProductTangent(map(x->partial(x, i), getfield(x, :factors)))
partial(x::AbstractZero, i) = x


Expand Down
85 changes: 57 additions & 28 deletions src/tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,8 @@
i::Int
end

function Base.getindex(a::AbstractTangentBundle, b::TaylorTangentIndex)
error("$(typeof(a)) is not taylor-like. Taylor indexing is ambiguous")
end

abstract type AbstractTangentSpace; end
Base.:(==)(x::AbstractTangentSpace, y::AbstractTangentSpace) = ==(promote(x, y)...)

"""
struct ExplicitTangent{P}
Expand All @@ -89,13 +86,23 @@
struct ExplicitTangent{P <: Tuple} <: AbstractTangentSpace
partials::P
end
Base.:(==)(a::ExplicitTangent, b::ExplicitTangent) = a.partials == b.partials
Base.hash(tt::ExplicitTangent, h::UInt64) = hash(tt.partials, h)

Check warning on line 90 in src/tangent.jl

View check run for this annotation

Codecov / codecov/patch

src/tangent.jl#L90

Added line #L90 was not covered by tests

Base.getindex(tangent::ExplicitTangent, b::CanonicalTangentIndex) = tangent.partials[b.i]
function Base.getindex(tangent::ExplicitTangent, b::TaylorTangentIndex)
if lastindex(tangent.partials) == exp2(b.i) - 1
return tangent.partials[end]
end
# TODO: should we also allow other indexes if all the partials at that level are equal up regardless of order?
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the comment wasn't in there before.
But I don't actually see why we don't do this.

We just need to check all things at the indexes j that have count_ones(j) = b.i are equal. And if so return that thing.
if not give error about ambig

throw(DomainError(b, "$(typeof(tangent)) is not taylor-like. Taylor indexing is ambiguous"))
end


@eval struct TaylorTangent{C <: Tuple} <: AbstractTangentSpace
coeffs::C
TaylorTangent(coeffs) = $(Expr(:new, :(TaylorTangent{typeof(coeffs)}), :coeffs))
end
Base.:(==)(a::TaylorTangent, b::TaylorTangent) = a.coeffs == b.coeffs
Base.hash(tt::TaylorTangent, h::UInt64) = hash(tt.coeffs, h)

"""
struct TaylorTangent{C}
Expand All @@ -122,15 +129,13 @@
"""
TaylorTangent

"""
struct ProductTangent{T <: Tuple{Vararg{AbstractTangentSpace}}}
Base.:(==)(a::TaylorTangent, b::TaylorTangent) = a.coeffs == b.coeffs
Base.hash(tt::TaylorTangent, h::UInt64) = hash(tt.coeffs, h)

Check warning on line 133 in src/tangent.jl

View check run for this annotation

Codecov / codecov/patch

src/tangent.jl#L133

Added line #L133 was not covered by tests


Base.getindex(tangent::TaylorTangent, tti::TaylorTangentIndex) = tangent.coeffs[tti.i]
Base.getindex(tangent::TaylorTangent, tti::CanonicalTangentIndex) = tangent.coeffs[count_ones(tti.i)]

Represents the product space of the given representations of the
tangent space.
"""
struct ProductTangent{T <: Tuple} <: AbstractTangentSpace
factors::T
end

"""
struct UniformTangent
Expand All @@ -141,6 +146,28 @@
struct UniformTangent{U} <: AbstractTangentSpace
val::U
end
Base.hash(t::UniformTangent, h::UInt64) = hash(t.val, h)

Check warning on line 149 in src/tangent.jl

View check run for this annotation

Codecov / codecov/patch

src/tangent.jl#L149

Added line #L149 was not covered by tests
Base.:(==)(t1::UniformTangent, t2::UniformTangent) = t1.val == t2.val

Base.getindex(tangent::UniformTangent, ::Any) = tangent.val

# Conversion and promotion
Base.promote_rule(et::Type{<:ExplicitTangent}, ::Type{<:AbstractTangentSpace}) = et
Base.promote_rule(tt::Type{<:TaylorTangent}, ::Type{<:AbstractTangentSpace}) = tt
Base.promote_rule(et::Type{<:ExplicitTangent}, ::Type{<:TaylorTangent}) = et
Base.promote_rule(::Type{<:TaylorTangent}, et::Type{<:ExplicitTangent}) = et

num_partials(::Type{TaylorTangent{P}}) where P = fieldcount(P)
num_partials(::Type{ExplicitTangent{P}}) where P = fieldcount(P)
Base.eltype(::Type{TaylorTangent{P}}) where P = eltype(P)
Base.eltype(::Type{ExplicitTangent{P}}) where P = eltype(P)
function Base.convert(::Type{T}, ut::UniformTangent) where {T<:Union{TaylorTangent, ExplicitTangent}}
# can't just use T to construct as the inner constructor doesn't accept type params. So get T_wrapper
T_wrapper = T<:TaylorTangent ? TaylorTangent : ExplicitTangent
T_wrapper(ntuple(_->convert(eltype(T), ut.val), num_partials(T)))
end
Base.convert(T::Type{<:ExplicitTangent}, tt::TaylorTangent) = ExplicitTangent(ntuple(i->tt[CanonicalTangentIndex(i)], num_partials(T)))
#TODO: Should we define the reverse: Explict->Taylor for the cases where that is actually defined?

function _TangentBundle end

Expand All @@ -154,15 +181,17 @@
struct TangentBundle{N, B, P}

Represents a tangent bundle as an explicit primal together
with some representation of (potentially a product of) the tangent space.
with some representation of the tangent space.
"""
TangentBundle

TangentBundle{N}(primal::B, tangent::P) where {N, B, P<:AbstractTangentSpace} =
_TangentBundle(Val{N}(), primal, tangent)

Base.hash(tb::TangentBundle, h::UInt64) = hash(tb.primal, h)
Base.:(==)(a::TangentBundle, b::TangentBundle) = (a.primal == b.primal) && (a.tangent == b.tangent)
Base.:(==)(a::TangentBundle, b::TangentBundle) = false # different orders

Check warning on line 192 in src/tangent.jl

View check run for this annotation

Codecov / codecov/patch

src/tangent.jl#L192

Added line #L192 was not covered by tests
Base.:(==)(a::TangentBundle{N}, b::TangentBundle{N}) where {N} = (a.primal == b.primal) && (a.tangent == b.tangent)
Base.getindex(tbun::TangentBundle, x) = getindex(tbun.tangent, x)

const ExplicitTangentBundle{N, B, P} = TangentBundle{N, B, ExplicitTangent{P}}

Expand Down Expand Up @@ -197,12 +226,7 @@
length(x.partials) >= 7 && print(io, " + ", x.partials[7], " ∂₁ ∂₂ ∂₃")
end

function Base.getindex(a::ExplicitTangentBundle{N}, b::TaylorTangentIndex) where {N}
if b.i === N
return a.tangent.partials[end]
end
error("$(typeof(a)) is not taylor-like. Taylor indexing is ambiguous")
end


const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}}

Expand Down Expand Up @@ -233,11 +257,6 @@
print(io, x.coeffs[1], " ∂₁")
end

Base.getindex(tb::TaylorBundle, tti::TaylorTangentIndex) = tb.tangent.coeffs[tti.i]
function Base.getindex(tb::TaylorBundle, tti::CanonicalTangentIndex)
tb.tangent.coeffs[count_ones(tti.i)]
end

"for a TaylorTangent{N, <:Tuple} this breaks it up unto 1 TaylorTangent{N} for each element of the primal tuple"
function destructure(r::TaylorBundle{N, B}) where {N, B<:Tuple}
return ntuple(fieldcount(B)) do field_ii
Expand Down Expand Up @@ -307,8 +326,18 @@
print(io, ")")
end

# Conversion and promotion
function Base.promote_rule(::Type{TangentBundle{N, B, P1}}, ::Type{TangentBundle{N, B, P2}}) where {N,B,P1,P2}
return TangentBundle{N, B, promote_type(P1, P2)}
end

function Base.convert(::Type{T}, tbun::TangentBundle{N, B}) where {N, B, P, T<:TangentBundle{N,B,P}}
the_primal = convert(B, primal(tbun))
the_partials = convert(P, tbun.tangent)
return _TangentBundle(Val{N}(), the_primal, the_partials)
end

Base.getindex(u::UniformBundle, ::TaylorTangentIndex) = u.tangent.val
# StructureArrays helpers

expand_singleton_to_array(asize, a::AbstractZero) = fill(a, asize...)
expand_singleton_to_array(asize, a::AbstractArray) = a
Expand Down
70 changes: 65 additions & 5 deletions test/tangent.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module tagent
module tangent
using Diffractor
using Diffractor: AbstractZeroBundle, ZeroBundle, DNEBundle
using Diffractor: TaylorBundle, TaylorTangentIndex
using Diffractor: AbstractZeroBundle, ZeroBundle, DNEBundle, TaylorBundle, ExplicitTangentBundle
using Diffractor:TaylorTangentIndex, CanonicalTangentIndex
using Diffractor: ExplicitTangent, TaylorTangent, truncate
using ChainRulesCore
using Test
Expand Down Expand Up @@ -46,11 +46,71 @@ end
end
end

@testset "== and hash" begin
@testset "getindex" begin
tt = TaylorBundle{2}(1.5, (1.0, 2.0))
@test tt[TaylorTangentIndex(1)] == 1.0
@test tt[TaylorTangentIndex(2)] == 2.0
@test tt[CanonicalTangentIndex(1)] == 1.0
@test tt[CanonicalTangentIndex(2)] == 1.0
@test tt[CanonicalTangentIndex(3)] == 2.0

et = ExplicitTangentBundle{2}(1.5, (1.0, 2.0, 3.0))
@test_throws DomainError et[TaylorTangentIndex(1)] == 1.0
@test et[TaylorTangentIndex(2)] == 3.0
@test et[CanonicalTangentIndex(1)] == 1.0
@test et[CanonicalTangentIndex(2)] == 2.0
@test et[CanonicalTangentIndex(3)] == 3.0

zb = ZeroBundle{2}(1.5)
@test zb[TaylorTangentIndex(1)] == ZeroTangent()
@test zb[TaylorTangentIndex(2)] == ZeroTangent()
@test zb[CanonicalTangentIndex(1)] == ZeroTangent()
@test zb[CanonicalTangentIndex(2)] == ZeroTangent()
@test zb[CanonicalTangentIndex(3)] == ZeroTangent()
end

@testset "promote" begin
@test promote_type(
typeof(ExplicitTangentBundle{1}([2.0, 4.0], ([20.0, 200.0],))),
typeof(TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],)))
) <: ExplicitTangentBundle{1, Vector{Float64}}

@test promote_type(TaylorBundle{1, Float64, Tuple{Float64}}, ZeroBundle{1, Float64}) <: TaylorBundle{1, Float64, Tuple{Float64}}
@test promote_type(ExplicitTangentBundle{1, Float64, Tuple{Float64}}, ZeroBundle{1, Float64}) <: ExplicitTangentBundle{1, Float64, Tuple{Float64}}
end
@testset "convert" begin
@test convert(TaylorBundle{1, Float64, Tuple{Float64}}, ZeroBundle{1}(1.4)) == TaylorBundle{1}(1.4, (0.0,))
@test convert(ExplicitTangentBundle{1, Float64, Tuple{Float64}}, ZeroBundle{1}(1.4)) == ExplicitTangentBundle{1}(1.4, (0.0,))

@test convert(
typeof(ExplicitTangentBundle{1}([2.0, 4.0], ([20.0, 200.0],))),
TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))
) == ExplicitTangentBundle{1}([2.0, 4.0], ([20.0, 200.0],))

@test convert(
typeof(ExplicitTangentBundle{2}(1.5, (10.0, 10.0, 20.0,))),
TaylorBundle{2}(1.5, (10.0, 20.0))
) === ExplicitTangentBundle{2}(1.5, (10.0, 10.0, 20.0,))
end
@testset "==" begin
@test TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],)) == TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))
@test hash(TaylorBundle{1}(0.0, (0.0,))) == hash(0)
@test TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],)) == ExplicitTangentBundle{1}([2.0, 4.0], ([20.0, 200.0],))

@test ZeroBundle{3}(1.5) == ZeroBundle{3}(1.5)
@test ZeroBundle{3}(1.5) == TaylorBundle{3}(1.5, (0.0, 0.0, 0.0))
@test ZeroBundle{3}(1.5) == ExplicitTangentBundle{3}(1.5, (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))
end

@testset "hash" begin
@test hash(TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))) == hash(TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],)))
@test hash(TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))) == hash(ExplicitTangentBundle{1}([2.0, 4.0], ([20.0, 200.0],)))

@test hash(ZeroBundle{3}(1.5)) == hash(ZeroBundle{3}(1.5))
@test hash(ZeroBundle{3}(1.5)) == hash(TaylorBundle{3}(1.5, (0.0, 0.0, 0.0)))
@test hash(ZeroBundle{3}(1.5)) == hash(ExplicitTangentBundle{3}(1.5, (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)))
end


@testset "truncate" begin
tt = TaylorTangent((1.0,2.0,3.0,4.0,5.0,6.0,7.0))
@test truncate(tt, Val(2)) == TaylorTangent((1.0,2.0))
Expand Down
Loading