From ee0dcb7916ea9fab352fefde78a90d2395f3a542 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 29 Dec 2023 13:50:18 +0800 Subject: [PATCH 1/4] Add setfield frule formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Project.toml | 2 +- src/rulesets/Base/base.jl | 8 ++++++++ test/rulesets/Base/base.jl | 9 +++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d05e9a9c3..fa48b0ef6 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [compat] Adapt = "3.4.0, 4" -ChainRulesCore = "1.15.3" +ChainRulesCore = "1.20" ChainRulesTestUtils = "1.5" Compat = "3.46, 4.2" Distributed = "1" diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index bf97318c7..6c66d19ee 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -26,6 +26,14 @@ function rrule(::typeof(one), x) return (one(x), one_pullback) end + +function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setfield!), obj, field, x) + ȯbj::MutableTangent + y = setfield!(obj, field, x) + ẏ = setproperty!(ȯbj, field, ẋ) + return y, ẏ +end + # `adjoint` frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz') diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 80e418b82..5bda3072e 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -18,6 +18,15 @@ end end end + + @testset "setfield!" begin + mutable struct MDemo + x::Float64 + end + + test_frule(setfield!, MDemo(3.5) ⊢ MutableTangent{MDemo}(; x=2.0), :x, 5.0) + test_frule(setfield!, MDemo(3.5) ⊢ MutableTangent{MDemo}(; x=2.0), 1, 5.0) + end @testset "Trig" begin @testset "Basics" for x = (Float64(π)-0.01, Complex(π, π/2)) From 96db615d76f343b48ebb74ba16d926d419d36bf5 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 16 Jan 2024 15:46:23 +0800 Subject: [PATCH 2/4] handle NoTangent in copyto! --- src/rulesets/Base/array.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index b94bddd3e..83f61bfca 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -86,7 +86,14 @@ _instantiate_zeros(ẋs::AbstractArray{<:AbstractArray}, xs) = ẋs ##### function frule((_, ẏ, ẋ), ::typeof(copyto!), y::AbstractArray, x) - return copyto!(y, x), copyto!(ẏ, ẋ) + if ẏ isa AbstractZero + # it's allowed to have an imutable zero tangent for ẏ as long as ẋ is zero + # TODO should this be handled here or in the AD? + @assert iszero(ẋ) + else + copyto!(ẏ, ẋ) + end + return copyto!(y, x), ẏ end function frule((_, ẏ, _, ẋ), ::typeof(copyto!), y::AbstractArray, i::Integer, x, js::Integer...) From 0ff449a9496e7c8533693a964b094485e0d39a1c Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 16 Jan 2024 16:55:47 +0800 Subject: [PATCH 3/4] move struct defn in tests to top level --- test/rulesets/Base/base.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 5bda3072e..18108f371 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -1,3 +1,7 @@ +mutable struct MDemo + x::Float64 +end + @testset "base.jl" begin @testset "zero/one" begin for f in [zero, one] @@ -20,10 +24,6 @@ end @testset "setfield!" begin - mutable struct MDemo - x::Float64 - end - test_frule(setfield!, MDemo(3.5) ⊢ MutableTangent{MDemo}(; x=2.0), :x, 5.0) test_frule(setfield!, MDemo(3.5) ⊢ MutableTangent{MDemo}(; x=2.0), 1, 5.0) end From 810c6335a3abeaebee90eafa4180e9d875e52288 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 1 Feb 2024 15:31:31 +0800 Subject: [PATCH 4/4] remove resolved TODO comment --- src/rulesets/Base/array.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 83f61bfca..cd5f518b1 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -88,7 +88,6 @@ _instantiate_zeros(ẋs::AbstractArray{<:AbstractArray}, xs) = ẋs function frule((_, ẏ, ẋ), ::typeof(copyto!), y::AbstractArray, x) if ẏ isa AbstractZero # it's allowed to have an imutable zero tangent for ẏ as long as ẋ is zero - # TODO should this be handled here or in the AD? @assert iszero(ẋ) else copyto!(ẏ, ẋ)