Skip to content

Commit

Permalink
Merge pull request #224 from JuliaDiff/ox/mapfix
Browse files Browse the repository at this point in the history
Fix map in forwards mode
  • Loading branch information
oxinabox authored Oct 11, 2023
2 parents b3e4ee0 + fb1dd92 commit e7c8abd
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 8 deletions.
20 changes: 15 additions & 5 deletions src/tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ end
coeffs::C
TaylorTangent(coeffs) = $(Expr(:new, :(TaylorTangent{typeof(coeffs)}), :coeffs))
end
Base.:(==)(a::TaylorTangent, b::TaylorTangent) = a.coeffs == b.coeffs
Base.hash(tt::TaylorTangent, h::UInt64) = hash(tt.coeffs, h)

"""
struct TaylorTangent{C}
Expand Down Expand Up @@ -159,6 +161,9 @@ TangentBundle
TangentBundle{N}(primal::B, tangent::P) where {N, B, P<:AbstractTangentSpace} =
_TangentBundle(Val{N}(), primal, tangent)

Base.hash(tb::TangentBundle, h::UInt64) = hash(tb.primal, h)
Base.:(==)(a::TangentBundle, b::TangentBundle) = (a.primal == b.primal) && (a.tangent == b.tangent)

const ExplicitTangentBundle{N, B, P} = TangentBundle{N, B, ExplicitTangent{P}}

check_tangent_invariant(lp, N) = @assert lp == 2^N - 1
Expand Down Expand Up @@ -201,20 +206,25 @@ end

const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}}


function TaylorBundle{N, B, P}(primal::B, coeffs::P) where {N, B, P}
check_taylor_invariants(coeffs, primal, N)
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
end
function TaylorBundle{N, B}(primal::B, coeffs) where {N, B}
check_taylor_invariants(coeffs, primal, N)
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
end
function TaylorBundle{N}(primal, coeffs) where {N}
check_taylor_invariants(coeffs, primal, N)
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
end

function check_taylor_invariants(coeffs, primal, N)
@assert length(coeffs) == N

end
@ChainRulesCore.non_differentiable check_taylor_invariants(coeffs, primal, N)

function TaylorBundle{N}(primal, coeffs) where {N}
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
end

function Base.show(io::IO, x::TaylorBundle{1})
print(io, x.primal)
Expand Down Expand Up @@ -350,7 +360,7 @@ function ChainRulesCore.rrule(::typeof(unbundle), atb::AbstractTangentBundle)
unbundle(atb), Δ->throw(Δ)
end

function StructArrays.createinstance(T::Type{<:ZeroBundle}, args...)
function StructArrays.createinstance(T::Type{<:UniformBundle}, args...)
T(args[1], args[2])
end

Expand Down
24 changes: 21 additions & 3 deletions test/forward.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module forward_tests
using Diffractor
using Diffractor: TaylorBundle, ZeroBundle
using Diffractor: TaylorBundle, ZeroBundle, ∂☆
using ChainRules
using ChainRulesCore
using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
Expand Down Expand Up @@ -61,7 +61,7 @@ end
end

# Special case if there is no derivative information at all:
@test (Diffractor.∂☆{1}())(ZeroBundle{1}(foo), ZeroBundle{1}(2.0), ZeroBundle{1}(3.0)) == ZeroBundle{1}(5.0)
@test ∂☆{1}()(ZeroBundle{1}(foo), ZeroBundle{1}(2.0), ZeroBundle{1}(3.0)) == ZeroBundle{1}(5.0)
@test frule_calls[] == 0
@test primal_calls[] == 1
end
Expand All @@ -88,6 +88,24 @@ end
end


@testset "map" begin
@test ==(
∂☆{1}()(ZeroBundle{1}(xs->(map(x->2*x, xs))), TaylorBundle{1}([1.0, 2.0], ([10.0, 100.0],))),
TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))
)


# map over all closure, wrt the closed variable
mulby(x) = y->x*y
🐇 = ∂☆{1}()(
ZeroBundle{1}(x->(map(mulby(x), [2.0, 4.0]))),
TaylorBundle{1}(2.0, (10.0,))
)
@test 🐇 == TaylorBundle{1}([4.0, 8.0], ([20.0, 40.0],))

end


@testset "structs" begin
struct IDemo
x::Float64
Expand Down Expand Up @@ -166,4 +184,4 @@ end
)
end

end
end # module
5 changes: 5 additions & 0 deletions test/tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ end
end
end

@testset "== and hash" begin
@test TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],)) == TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))
@test hash(TaylorBundle{1}(0.0, (0.0,))) == hash(0)
end

@testset "truncate" begin
tt = TaylorTangent((1.0,2.0,3.0,4.0,5.0,6.0,7.0))
@test truncate(tt, Val(2)) == TaylorTangent((1.0,2.0))
Expand Down

0 comments on commit e7c8abd

Please sign in to comment.