Skip to content

Commit

Permalink
define construstinstance over UniformBundles so differentiating map o…
Browse files Browse the repository at this point in the history
…ver closure wrt closed variable works
  • Loading branch information
oxinabox committed Oct 10, 2023
1 parent 960e74f commit 084212a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ function ChainRulesCore.rrule(::typeof(unbundle), atb::AbstractTangentBundle)
unbundle(atb), Δ->throw(Δ)
end

function StructArrays.createinstance(T::Type{<:ZeroBundle}, args...)
function StructArrays.createinstance(T::Type{<:UniformBundle}, args...)
T(args[1], args[2])
end

Expand Down
10 changes: 10 additions & 0 deletions test/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ end
∂☆{1}()(ZeroBundle{1}(xs->(map(x->2*x, xs))), TaylorBundle{1}([1.0, 2.0], ([10.0, 100.0],))),
TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))
)


# map over all closure, wrt the closed variable
mulby(x) = y->x*y
🐇 = ∂☆{1}()(
ZeroBundle{1}(x->(map(mulby(x), [2.0, 4.0]))),
TaylorBundle{1}(2.0, (10.0,))
)
@test 🐇 == TaylorBundle{1}([4.0, 8.0], ([20.0, 40.0],))

end


Expand Down

0 comments on commit 084212a

Please sign in to comment.