diff --git a/Project.toml b/Project.toml index d05e9a9c3..fa48b0ef6 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [compat] Adapt = "3.4.0, 4" -ChainRulesCore = "1.15.3" +ChainRulesCore = "1.20" ChainRulesTestUtils = "1.5" Compat = "3.46, 4.2" Distributed = "1" diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index b94bddd3e..cd5f518b1 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -86,7 +86,13 @@ _instantiate_zeros(ẋs::AbstractArray{<:AbstractArray}, xs) = ẋs ##### function frule((_, ẏ, ẋ), ::typeof(copyto!), y::AbstractArray, x) - return copyto!(y, x), copyto!(ẏ, ẋ) + if ẏ isa AbstractZero + # it's allowed to have an imutable zero tangent for ẏ as long as ẋ is zero + @assert iszero(ẋ) + else + copyto!(ẏ, ẋ) + end + return copyto!(y, x), ẏ end function frule((_, ẏ, _, ẋ), ::typeof(copyto!), y::AbstractArray, i::Integer, x, js::Integer...) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index bf97318c7..6c66d19ee 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -26,6 +26,14 @@ function rrule(::typeof(one), x) return (one(x), one_pullback) end + +function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setfield!), obj, field, x) + ȯbj::MutableTangent + y = setfield!(obj, field, x) + ẏ = setproperty!(ȯbj, field, ẋ) + return y, ẏ +end + # `adjoint` frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz') diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 80e418b82..18108f371 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -1,3 +1,7 @@ +mutable struct MDemo + x::Float64 +end + @testset "base.jl" begin @testset "zero/one" begin for f in [zero, one] @@ -18,6 +22,11 @@ end end end + + @testset "setfield!" begin + test_frule(setfield!, MDemo(3.5) ⊢ MutableTangent{MDemo}(; x=2.0), :x, 5.0) + test_frule(setfield!, MDemo(3.5) ⊢ MutableTangent{MDemo}(; x=2.0), 1, 5.0) + end @testset "Trig" begin @testset "Basics" for x = (Float64(π)-0.01, Complex(π, π/2))