From ee0dcb7916ea9fab352fefde78a90d2395f3a542 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 29 Dec 2023 13:50:18 +0800 Subject: [PATCH] Add setfield frule formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Project.toml | 2 +- src/rulesets/Base/base.jl | 8 ++++++++ test/rulesets/Base/base.jl | 9 +++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d05e9a9c3..fa48b0ef6 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [compat] Adapt = "3.4.0, 4" -ChainRulesCore = "1.15.3" +ChainRulesCore = "1.20" ChainRulesTestUtils = "1.5" Compat = "3.46, 4.2" Distributed = "1" diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index bf97318c7..6c66d19ee 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -26,6 +26,14 @@ function rrule(::typeof(one), x) return (one(x), one_pullback) end + +function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setfield!), obj, field, x) + ȯbj::MutableTangent + y = setfield!(obj, field, x) + ẏ = setproperty!(ȯbj, field, ẋ) + return y, ẏ +end + # `adjoint` frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz') diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 80e418b82..5bda3072e 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -18,6 +18,15 @@ end end end + + @testset "setfield!" begin + mutable struct MDemo + x::Float64 + end + + test_frule(setfield!, MDemo(3.5) ⊢ MutableTangent{MDemo}(; x=2.0), :x, 5.0) + test_frule(setfield!, MDemo(3.5) ⊢ MutableTangent{MDemo}(; x=2.0), 1, 5.0) + end @testset "Trig" begin @testset "Basics" for x = (Float64(π)-0.01, Complex(π, π/2))