Skip to content

Commit

Permalink
add missing TaylorBundle constructor, thus fixing map
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Oct 10, 2023
1 parent 17c2264 commit 960e74f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
13 changes: 9 additions & 4 deletions src/tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,20 +206,25 @@ end

const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}}


function TaylorBundle{N, B, P}(primal::B, coeffs::P) where {N, B, P}
check_taylor_invariants(coeffs, primal, N)
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
end
function TaylorBundle{N, B}(primal::B, coeffs) where {N, B}
check_taylor_invariants(coeffs, primal, N)
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
end
function TaylorBundle{N}(primal, coeffs) where {N}
check_taylor_invariants(coeffs, primal, N)
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
end

function check_taylor_invariants(coeffs, primal, N)
@assert length(coeffs) == N

end
@ChainRulesCore.non_differentiable check_taylor_invariants(coeffs, primal, N)

function TaylorBundle{N}(primal, coeffs) where {N}
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
end

function Base.show(io::IO, x::TaylorBundle{1})
print(io, x.primal)
Expand Down
14 changes: 11 additions & 3 deletions test/forward.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module forward_tests
#module forward_tests
using Diffractor
using Diffractor: TaylorBundle, ZeroBundle
using Diffractor: TaylorBundle, ZeroBundle, ∂☆
using ChainRules
using ChainRulesCore
using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
Expand Down Expand Up @@ -61,7 +61,7 @@ end
end

# Special case if there is no derivative information at all:
@test (Diffractor.∂☆{1}())(ZeroBundle{1}(foo), ZeroBundle{1}(2.0), ZeroBundle{1}(3.0)) == ZeroBundle{1}(5.0)
@test ∂☆{1}()(ZeroBundle{1}(foo), ZeroBundle{1}(2.0), ZeroBundle{1}(3.0)) == ZeroBundle{1}(5.0)
@test frule_calls[] == 0
@test primal_calls[] == 1
end
Expand All @@ -88,6 +88,14 @@ end
end


@testset "map" begin
@test ==(
∂☆{1}()(ZeroBundle{1}(xs->(map(x->2*x, xs))), TaylorBundle{1}([1.0, 2.0], ([10.0, 100.0],))),
TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))
)
end


@testset "structs" begin
struct IDemo
x::Float64
Expand Down

0 comments on commit 960e74f

Please sign in to comment.