diff --git a/Project.toml b/Project.toml index 77f514046..e523f8b8a 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [compat] Adapt = "3.4.0" -ChainRulesCore = "1.15.3" +ChainRulesCore = "1.17" ChainRulesTestUtils = "1.5" Compat = "3.46, 4.2" Distributed = "1" diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 28cc11d19..423b9e7dc 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -7,6 +7,14 @@ @scalar_rule zero(x) ZeroTangent() @scalar_rule transpose(x) true + +function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setfield!), obj, field, x) + ȯbj::MutableTangent + y = setfield!(obj, field, x) + ẏ = setproperty!(ȯbj, field, ẋ) + return y, ẏ +end + # `adjoint` frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz') diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 9a5278747..ef7c1ab7b 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -8,6 +8,15 @@ end end end + + @testset "setfield!" begin + mutable struct MDemo + x::Float64 + end + + test_frule(setfield!, MDemo(3.5)⊢MutableTangent{MDemo}(x=2.0), :x, 5.0) + test_frule(setfield!, MDemo(3.5)⊢MutableTangent{MDemo}(x=2.0), 1, 5.0) + end @testset "Trig" begin @testset "Basics" for x = (Float64(π)-0.01, Complex(π, π/2))