From 92f6a039daf4a2def6278a1cb2178875ee2764bc Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 4 Aug 2023 04:11:29 -0400 Subject: [PATCH 01/36] rename files --- src/ChainRulesCore.jl | 2 +- src/tangent_types/{tangent.jl => structural_tangent.jl} | 0 test/runtests.jl | 2 +- test/tangent_types/{tangent.jl => structural_tangent.jl} | 0 4 files changed, 2 insertions(+), 2 deletions(-) rename src/tangent_types/{tangent.jl => structural_tangent.jl} (100%) rename test/tangent_types/{tangent.jl => structural_tangent.jl} (100%) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 94e8242b1..f943c50fa 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -21,7 +21,7 @@ include("debug_mode.jl") include("tangent_types/abstract_tangent.jl") include("tangent_types/abstract_zero.jl") include("tangent_types/thunks.jl") -include("tangent_types/tangent.jl") +include("tangent_types/structural_tangent.jl") include("tangent_types/notimplemented.jl") include("tangent_arithmetic.jl") diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/structural_tangent.jl similarity index 100% rename from src/tangent_types/tangent.jl rename to src/tangent_types/structural_tangent.jl diff --git a/test/runtests.jl b/test/runtests.jl index 6a4684d03..a3b0971a5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,7 +11,7 @@ using Test @testset "differentials" begin include("tangent_types/abstract_zero.jl") include("tangent_types/thunks.jl") - include("tangent_types/tangent.jl") + include("tangent_types/structural_tangent.jl") include("tangent_types/notimplemented.jl") end diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/structural_tangent.jl similarity index 100% rename from test/tangent_types/tangent.jl rename to test/tangent_types/structural_tangent.jl From f9b5a2491f403d21e3a4006c22c497570fcef789 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 4 Aug 2023 06:35:40 -0400 Subject: [PATCH 02/36] move functionality up to StructuralTangent 
--- src/tangent_arithmetic.jl | 22 +- src/tangent_types/structural_tangent.jl | 369 +++++++++++++----------- 2 files changed, 211 insertions(+), 180 deletions(-) diff --git a/src/tangent_arithmetic.jl b/src/tangent_arithmetic.jl index 439f0ac8f..18ae7b3ad 100644 --- a/src/tangent_arithmetic.jl +++ b/src/tangent_arithmetic.jl @@ -20,7 +20,7 @@ Base.:+(x::NotImplemented, ::NotImplemented) = x Base.:*(x::NotImplemented, ::NotImplemented) = x LinearAlgebra.dot(x::NotImplemented, ::NotImplemented) = x # `NotImplemented` always "wins" + -for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :Tangent, :Any) +for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :StructuralTangent, :Any) @eval Base.:+(x::NotImplemented, ::$T) = x @eval Base.:+(::$T, x::NotImplemented) = x end @@ -33,7 +33,7 @@ for T in (:ZeroTangent, :NoTangent) @eval LinearAlgebra.dot(::$T, ::NotImplemented) = $T() end # `NotImplemented` "wins" * and dot for other types -for T in (:AbstractThunk, :Tangent, :Any) +for T in (:AbstractThunk, :StructuralTangent, :Any) @eval Base.:*(x::NotImplemented, ::$T) = x @eval Base.:*(::$T, x::NotImplemented) = x @eval LinearAlgebra.dot(x::NotImplemented, ::$T) = x @@ -55,7 +55,7 @@ Base.:-(::NoTangent, ::NoTangent) = NoTangent() Base.:-(::NoTangent) = NoTangent() Base.:*(::NoTangent, ::NoTangent) = NoTangent() LinearAlgebra.dot(::NoTangent, ::NoTangent) = NoTangent() -for T in (:AbstractThunk, :Tangent, :Any) +for T in (:AbstractThunk, :StructuralTangent, :Any) @eval Base.:+(::NoTangent, b::$T) = b @eval Base.:+(a::$T, ::NoTangent) = a @eval Base.:-(::NoTangent, b::$T) = -b @@ -95,7 +95,7 @@ Base.:-(::ZeroTangent, ::ZeroTangent) = ZeroTangent() Base.:-(::ZeroTangent) = ZeroTangent() Base.:*(::ZeroTangent, ::ZeroTangent) = ZeroTangent() LinearAlgebra.dot(::ZeroTangent, ::ZeroTangent) = ZeroTangent() -for T in (:AbstractThunk, :Tangent, :Any) +for T in (:AbstractThunk, :StructuralTangent, :Any) @eval Base.:+(::ZeroTangent, b::$T) = b @eval Base.:+(a::$T, ::ZeroTangent) = a 
@eval Base.:-(::ZeroTangent, b::$T) = -b @@ -126,11 +126,11 @@ for T in (:Tangent, :Any) @eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b) end -function Base.:+(a::Tangent{P}, b::Tangent{P}) where {P} +function Base.:+(a::StructuralTangent{P}, b::StructuralTangent{P}) where {P} data = elementwise_add(backing(a), backing(b)) - return Tangent{P,typeof(data)}(data) + return StructuralTangent{P}(data) end -function Base.:+(a::P, d::Tangent{P}) where {P} +function Base.:+(a::P, d::StructuralTangent{P}) where {P} net_backing = elementwise_add(backing(a), backing(d)) if debug_mode() try @@ -143,14 +143,14 @@ function Base.:+(a::P, d::Tangent{P}) where {P} end end Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d)) -Base.:+(a::Tangent{P}, b::P) where {P} = b + a +Base.:+(a::StructuralTangent{P}, b::P) where {P} = b + a -Base.:-(tangent::Tangent{P}) where {P} = map(-, tangent) +Base.:-(tangent::StructuralTangent{P}) where {P} = map(-, tangent) # We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful # In general one doesn't have to represent multiplications of 2 tangents # Only of a tangent and a scaling factor (generally `Real`) for T in (:Number,) - @eval Base.:*(s::$T, tangent::Tangent) = map(x -> s * x, tangent) - @eval Base.:*(tangent::Tangent, s::$T) = map(x -> x * s, tangent) + @eval Base.:*(s::$T, tangent::StructuralTangent) = map(x -> s * x, tangent) + @eval Base.:*(tangent::StructuralTangent, s::$T) = map(x -> x * s, tangent) end diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 6af968c53..a8fd4ac1b 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -1,5 +1,201 @@ """ - Tangent{P, T} <: AbstractTangent + StructuralTangent{P} <: AbstractTangent + +Representing the type of the tangent of a `struct` `P` (or a `Tuple`/`NamedTuple`). +as an object with mirroring fields. 
+""" +abstract type StructuralTangent{P} <: AbstractTangent end + +function StructuralTangent{P}(nt::NamedTuple) where P + return Tangent{P, typeof(nt)}(nt) +end + +StructuralTangent{P}(tup::Tuple) where P = Tangent{P, typeof(tup)}(tup) +StructuralTangent{P}(dict::Dict) where P = Tangent{P}(dict) + + +Base.keys(tangent::StructuralTangent) = keys(backing(tangent)) +Base.propertynames(tangent::StructuralTangent) = propertynames(backing(tangent)) + +Base.haskey(tangent::StructuralTangent, key) = haskey(backing(tangent), key) +if isdefined(Base, :hasproperty) + Base.hasproperty(tangent::StructuralTangent, key::Symbol) = hasproperty(backing(tangent), key) +end + +Base.iszero(t::StructuralTangent) = all(iszero, backing(t)) + +function Base.map(f, tangent::StructuralTangent{P}) where {P} + L = propertynames(backing(tangent)) + vals = map(f, Tuple(backing(tangent))) + named_vals = NamedTuple{L,typeof(vals)}(vals) + return if tangent isa Tangent + Tangent{P, typeof(named_vals)}(named_vals) + else + # Handle MutableTangent + end +end + + +""" + backing(x) + +Accesses the backing field of a `Tangent`, +or destructures any other struct type into a `NamedTuple`. +Identity function on `Tuple`s and `NamedTuple`s. + +This is an internal function used to simplify operations between `Tangent`s and the +primal types. +""" +backing(x::Tuple) = x +backing(x::NamedTuple) = x +backing(x::Dict) = x +backing(x::StructuralTangent) = getfield(x, :backing) + +# For generic structs +function backing(x::T)::NamedTuple where {T} + # note: all computation outside the if @generated happens at runtime. + # so the first 4 lines of the branchs look the same, but can not be moved out. + # see https://github.com/JuliaLang/julia/issues/34283 + if @generated + !isstructtype(T) && + throw(DomainError(T, "backing can only be used on struct types")) + nfields = fieldcount(T) + names = fieldnames(T) + types = fieldtypes(T) + + vals = Expr(:tuple, ntuple(ii -> :(getfield(x, $ii)), nfields)...) 
+ return :(NamedTuple{$names,Tuple{$(types...)}}($vals)) + else + !isstructtype(T) && + throw(DomainError(T, "backing can only be used on struct types")) + nfields = fieldcount(T) + names = fieldnames(T) + types = fieldtypes(T) + + vals = ntuple(ii -> getfield(x, ii), nfields) + return NamedTuple{names,Tuple{types...}}(vals) + end +end + + +""" + _zeroed_backing(P) + +Returns a NamedTuple with same fields as `P`, and all values `ZeroTangent()`. +""" +@generated function _zeroed_backing(::Type{P}) where {P} + nil_base = ntuple(fieldcount(P)) do i + (fieldname(P, i), ZeroTangent()) + end + return (; nil_base...) +end + +""" + construct(::Type{T}, fields::[NamedTuple|Tuple]) + +Constructs an object of type `T`, with the given fields. +Fields must be correct in name and type, and `T` must have a default constructor. + +This internally is called to construct structs of the primal type `T`, +after an operation such as the addition of a primal to a tangent + +It should be overloaded, if `T` does not have a default constructor, +or if `T` needs to maintain some invarients between its fields. +""" +function construct(::Type{T}, fields::NamedTuple{L}) where {T,L} + # Tested and verified that that this avoids a ton of allocations + if length(L) !== fieldcount(T) + # if length is equal but names differ then we will catch that below anyway. + throw(ArgumentError("Unmatched fields. Type: $(fieldnames(T)), NamedTuple: $L")) + end + + if @generated + vals = (:(getproperty(fields, $(QuoteNode(fname)))) for fname in fieldnames(T)) + return :(T($(vals...))) + else + return T((getproperty(fields, fname) for fname in fieldnames(T))...) 
+ end +end + +construct(::Type{T}, fields::T) where {T<:NamedTuple} = fields +construct(::Type{T}, fields::T) where {T<:Tuple} = fields + +elementwise_add(a::Tuple, b::Tuple) = map(+, a, b) + +function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn} + # Rule of Tangent addition: any fields not present are implict hard Zeros + + # Base on the `merge(:;NamedTuple, ::NamedTuple)` code from Base. + # https://github.com/JuliaLang/julia/blob/592748adb25301a45bd6edef3ac0a93eed069852/base/namedtuple.jl#L220-L231 + if @generated + names = Base.merge_names(an, bn) + + vals = map(names) do field + a_field = :(getproperty(a, $(QuoteNode(field)))) + b_field = :(getproperty(b, $(QuoteNode(field)))) + value_expr = if Base.sym_in(field, an) + if Base.sym_in(field, bn) + # in both + :($a_field + $b_field) + else + # only in `an` + a_field + end + else # must be in `b` only + b_field + end + Expr(:kw, field, value_expr) + end + return Expr(:tuple, Expr(:parameters, vals...)) + else + names = Base.merge_names(an, bn) + vals = map(names) do field + value = if Base.sym_in(field, an) + a_field = getproperty(a, field) + if Base.sym_in(field, bn) + # in both + b_field = getproperty(b, field) + a_field + b_field + else + # only in `an` + a_field + end + else # must be in `b` only + getproperty(b, field) + end + field => value + end + return (; vals...) 
+ end +end + +elementwise_add(a::Dict, b::Dict) = merge(+, a, b) + +struct PrimalAdditionFailedException{P} <: Exception + primal::P + tangent + original::Exception +end + +function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} + println(io, "Could not construct $P after addition.") + println(io, "This probably means no default constructor is defined.") + println(io, "Either define a default constructor") + printstyled(io, "$P(", join(propertynames(err.tangent), ", "), ")"; color=:blue) + println(io, "\nor overload") + printstyled( + io, "ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.tangent)))"; color=:blue + ) + println(io, "\nor overload") + printstyled(io, "Base.:+(::$P, ::$(typeof(err.tangent)))"; color=:blue) + println(io, "\nOriginal Exception:") + printstyled(io, err.original; color=:yellow) + return println(io) +end + + +""" + Tangent{P, T} <: StructuralTangent{P} <: AbstractTangent This type represents the tangent for a `struct`/`NamedTuple`, or `Tuple`. `P` is the the corresponding primal type that this is a tangent for. @@ -21,7 +217,7 @@ Any fields not explictly present in the `Tangent` are treated as being set to `Z To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref) function is provided. 
""" -struct Tangent{P,T} <: AbstractTangent +struct Tangent{P,T} <: StructuralTangent{P} # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict # (but potentially a different one, as it doesn't contain tangents) backing::T @@ -62,6 +258,7 @@ function _backing_error(P, G, E) return throw(ArgumentError(msg)) end + function Base.:(==)(a::Tangent{P,T}, b::Tangent{P,T}) where {P,T} return backing(a) == backing(b) end @@ -98,7 +295,7 @@ end Base.iszero(::Tangent{<:,NamedTuple{}}) = true Base.iszero(::Tangent{<:,Tuple{}}) = true -Base.iszero(t::Tangent) = all(iszero, backing(t)) + Base.first(tangent::Tangent{P,T}) where {P,T<:Union{Tuple,NamedTuple}} = first(backing(canonicalize(tangent))) Base.last(tangent::Tangent{P,T}) where {P,T<:Union{Tuple,NamedTuple}} = last(backing(canonicalize(tangent))) @@ -134,13 +331,6 @@ function Base.getproperty(tangent::Tangent{P,T}, idx::Symbol) where {P,T<:NamedT return unthunk(getfield(backing(tangent), idx)) end -Base.keys(tangent::Tangent) = keys(backing(tangent)) -Base.propertynames(tangent::Tangent) = propertynames(backing(tangent)) - -Base.haskey(tangent::Tangent, key) = haskey(backing(tangent), key) -if isdefined(Base, :hasproperty) - Base.hasproperty(tangent::Tangent, key::Symbol) = hasproperty(backing(tangent), key) -end Base.iterate(tangent::Tangent, args...) = iterate(backing(tangent), args...) 
Base.length(tangent::Tangent) = length(backing(tangent)) @@ -159,57 +349,13 @@ function Base.map(f, tangent::Tangent{P,<:Tuple}) where {P} vals::Tuple = map(f, backing(tangent)) return Tangent{P,typeof(vals)}(vals) end -function Base.map(f, tangent::Tangent{P,<:NamedTuple{L}}) where {P,L} - vals = map(f, Tuple(backing(tangent))) - named_vals = NamedTuple{L,typeof(vals)}(vals) - return Tangent{P,typeof(named_vals)}(named_vals) -end function Base.map(f, tangent::Tangent{P,<:Dict}) where {P<:Dict} return Tangent{P}(Dict(k => f(v) for (k, v) in backing(tangent))) end Base.conj(tangent::Tangent) = map(conj, tangent) -""" - backing(x) - -Accesses the backing field of a `Tangent`, -or destructures any other struct type into a `NamedTuple`. -Identity function on `Tuple`s and `NamedTuple`s. - -This is an internal function used to simplify operations between `Tangent`s and the -primal types. -""" -backing(x::Tuple) = x -backing(x::NamedTuple) = x -backing(x::Dict) = x -backing(x::Tangent) = getfield(x, :backing) - -# For generic structs -function backing(x::T)::NamedTuple where {T} - # note: all computation outside the if @generated happens at runtime. - # so the first 4 lines of the branchs look the same, but can not be moved out. - # see https://github.com/JuliaLang/julia/issues/34283 - if @generated - !isstructtype(T) && - throw(DomainError(T, "backing can only be used on struct types")) - nfields = fieldcount(T) - names = fieldnames(T) - types = fieldtypes(T) - - vals = Expr(:tuple, ntuple(ii -> :(getfield(x, $ii)), nfields)...) 
- return :(NamedTuple{$names,Tuple{$(types...)}}($vals)) - else - !isstructtype(T) && - throw(DomainError(T, "backing can only be used on struct types")) - nfields = fieldcount(T) - names = fieldnames(T) - types = fieldtypes(T) - vals = ntuple(ii -> getfield(x, ii), nfields) - return NamedTuple{names,Tuple{types...}}(vals) - end -end """ canonicalize(tangent::Tangent{P}) -> Tangent{P} @@ -243,118 +389,3 @@ canonicalize(tangent::Tangent{<:Any,<:AbstractDict}) = tangent canonicalize(tangent::Tangent{Any,<:NamedTuple{L}}) where {L} = tangent canonicalize(tangent::Tangent{Any,<:Tuple}) = tangent canonicalize(tangent::Tangent{Any,<:AbstractDict}) = tangent - -""" - _zeroed_backing(P) - -Returns a NamedTuple with same fields as `P`, and all values `ZeroTangent()`. -""" -@generated function _zeroed_backing(::Type{P}) where {P} - nil_base = ntuple(fieldcount(P)) do i - (fieldname(P, i), ZeroTangent()) - end - return (; nil_base...) -end - -""" - construct(::Type{T}, fields::[NamedTuple|Tuple]) - -Constructs an object of type `T`, with the given fields. -Fields must be correct in name and type, and `T` must have a default constructor. - -This internally is called to construct structs of the primal type `T`, -after an operation such as the addition of a primal to a tangent - -It should be overloaded, if `T` does not have a default constructor, -or if `T` needs to maintain some invarients between its fields. -""" -function construct(::Type{T}, fields::NamedTuple{L}) where {T,L} - # Tested and verified that that this avoids a ton of allocations - if length(L) !== fieldcount(T) - # if length is equal but names differ then we will catch that below anyway. - throw(ArgumentError("Unmatched fields. Type: $(fieldnames(T)), NamedTuple: $L")) - end - - if @generated - vals = (:(getproperty(fields, $(QuoteNode(fname)))) for fname in fieldnames(T)) - return :(T($(vals...))) - else - return T((getproperty(fields, fname) for fname in fieldnames(T))...) 
- end -end - -construct(::Type{T}, fields::T) where {T<:NamedTuple} = fields -construct(::Type{T}, fields::T) where {T<:Tuple} = fields - -elementwise_add(a::Tuple, b::Tuple) = map(+, a, b) - -function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn} - # Rule of Tangent addition: any fields not present are implict hard Zeros - - # Base on the `merge(:;NamedTuple, ::NamedTuple)` code from Base. - # https://github.com/JuliaLang/julia/blob/592748adb25301a45bd6edef3ac0a93eed069852/base/namedtuple.jl#L220-L231 - if @generated - names = Base.merge_names(an, bn) - - vals = map(names) do field - a_field = :(getproperty(a, $(QuoteNode(field)))) - b_field = :(getproperty(b, $(QuoteNode(field)))) - value_expr = if Base.sym_in(field, an) - if Base.sym_in(field, bn) - # in both - :($a_field + $b_field) - else - # only in `an` - a_field - end - else # must be in `b` only - b_field - end - Expr(:kw, field, value_expr) - end - return Expr(:tuple, Expr(:parameters, vals...)) - else - names = Base.merge_names(an, bn) - vals = map(names) do field - value = if Base.sym_in(field, an) - a_field = getproperty(a, field) - if Base.sym_in(field, bn) - # in both - b_field = getproperty(b, field) - a_field + b_field - else - # only in `an` - a_field - end - else # must be in `b` only - getproperty(b, field) - end - field => value - end - return (; vals...) 
- end -end - -elementwise_add(a::Dict, b::Dict) = merge(+, a, b) - -struct PrimalAdditionFailedException{P} <: Exception - primal::P - tangent::Tangent{P} - original::Exception -end - -function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} - println(io, "Could not construct $P after addition.") - println(io, "This probably means no default constructor is defined.") - println(io, "Either define a default constructor") - printstyled(io, "$P(", join(propertynames(err.tangent), ", "), ")"; color=:blue) - println(io, "\nor overload") - printstyled( - io, "ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.tangent)))"; color=:blue - ) - println(io, "\nor overload") - printstyled(io, "Base.:+(::$P, ::$(typeof(err.tangent)))"; color=:blue) - println(io, "\nOriginal Exception:") - printstyled(io, err.original; color=:yellow) - return println(io) -end From 35aff309aeb252cbaa9eb0aecd022caa043ba362 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 4 Aug 2023 07:00:51 -0400 Subject: [PATCH 03/36] Formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/tangent_types/structural_tangent.jl | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index a8fd4ac1b..9ea665735 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -6,20 +6,21 @@ as an object with mirroring fields. 
""" abstract type StructuralTangent{P} <: AbstractTangent end -function StructuralTangent{P}(nt::NamedTuple) where P - return Tangent{P, typeof(nt)}(nt) +function StructuralTangent{P}(nt::NamedTuple) where {P} + return Tangent{P,typeof(nt)}(nt) end -StructuralTangent{P}(tup::Tuple) where P = Tangent{P, typeof(tup)}(tup) -StructuralTangent{P}(dict::Dict) where P = Tangent{P}(dict) - +StructuralTangent{P}(tup::Tuple) where {P} = Tangent{P,typeof(tup)}(tup) +StructuralTangent{P}(dict::Dict) where {P} = Tangent{P}(dict) Base.keys(tangent::StructuralTangent) = keys(backing(tangent)) Base.propertynames(tangent::StructuralTangent) = propertynames(backing(tangent)) Base.haskey(tangent::StructuralTangent, key) = haskey(backing(tangent), key) if isdefined(Base, :hasproperty) - Base.hasproperty(tangent::StructuralTangent, key::Symbol) = hasproperty(backing(tangent), key) + function Base.hasproperty(tangent::StructuralTangent, key::Symbol) + return hasproperty(backing(tangent), key) + end end Base.iszero(t::StructuralTangent) = all(iszero, backing(t)) @@ -29,13 +30,12 @@ function Base.map(f, tangent::StructuralTangent{P}) where {P} vals = map(f, Tuple(backing(tangent))) named_vals = NamedTuple{L,typeof(vals)}(vals) return if tangent isa Tangent - Tangent{P, typeof(named_vals)}(named_vals) + Tangent{P,typeof(named_vals)}(named_vals) else # Handle MutableTangent end end - """ backing(x) @@ -77,7 +77,6 @@ function backing(x::T)::NamedTuple where {T} end end - """ _zeroed_backing(P) @@ -193,7 +192,6 @@ function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} return println(io) end - """ Tangent{P, T} <: StructuralTangent{P} <: AbstractTangent @@ -258,7 +256,6 @@ function _backing_error(P, G, E) return throw(ArgumentError(msg)) end - function Base.:(==)(a::Tangent{P,T}, b::Tangent{P,T}) where {P,T} return backing(a) == backing(b) end From c7932f1f4af5e1917ef498ec1d3bbc37246fd167 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 21 Aug 2023 17:46:22 
+0800 Subject: [PATCH 04/36] WIP mutable Tangent (squash me) --- src/tangent_types/structural_tangent.jl | 49 ++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 9ea665735..9410e50af 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -3,15 +3,31 @@ Representing the type of the tangent of a `struct` `P` (or a `Tuple`/`NamedTuple`). as an object with mirroring fields. + +!!!!!! warning Exprimental + The `StructuralTangent` constructor returns a `MutableTangent` for mutable structs. + `MutableTangent` is an experimental feature. + Thus use of `StructuralTangent` (rather than `Tangent` directly) is also experimental. + While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore. + """ abstract type StructuralTangent{P} <: AbstractTangent end function StructuralTangent{P}(nt::NamedTuple) where {P} - return Tangent{P,typeof(nt)}(nt) + if ismutabletype(P) + return MutableTangent{P}(nt) + else + return Tangent{P,typeof(nt)}(nt) + end end -StructuralTangent{P}(tup::Tuple) where {P} = Tangent{P,typeof(tup)}(tup) -StructuralTangent{P}(dict::Dict) where {P} = Tangent{P}(dict) +ismutabletype(::Type{P}) where P = ismutable(P) +ismutabletype(::Type{String}) = false +ismutabletype(::Type{Symbol}) = false + + +StructuralTangent{P}(tup::Tuple) where P = Tangent{P,typeof(tup)}(tup) +StructuralTangent{P}(dict::Dict) where P = Tangent{P}(dict) Base.keys(tangent::StructuralTangent) = keys(backing(tangent)) Base.propertynames(tangent::StructuralTangent) = propertynames(backing(tangent)) @@ -29,10 +45,10 @@ function Base.map(f, tangent::StructuralTangent{P}) where {P} L = propertynames(backing(tangent)) vals = map(f, Tuple(backing(tangent))) named_vals = NamedTuple{L,typeof(vals)}(vals) - return if tangent isa Tangent - Tangent{P,typeof(named_vals)}(named_vals) + 
return if tangent isa MutableTangent + MutableTangent{P}(named_vals) else - # Handle MutableTangent + Tangent{P,typeof(named_vals)}(named_vals) end end @@ -386,3 +402,24 @@ canonicalize(tangent::Tangent{<:Any,<:AbstractDict}) = tangent canonicalize(tangent::Tangent{Any,<:NamedTuple{L}}) where {L} = tangent canonicalize(tangent::Tangent{Any,<:Tuple}) = tangent canonicalize(tangent::Tangent{Any,<:AbstractDict}) = tangent + + +""" + MutableTangent{P}(fields) <: StructuralTangent{P} <: AbstractTangent + +This type represents the tangent to a mutable struct. +It itself is also mutable. + +!!!!!! warning Exprimental + MutableTangent is an experimental feature. + While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore. + Exactly how it should be used (e.g. is it forward-mode only?) +""" +mutable struct MutableTangent{P} + # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict + # (but potentially a different one, as it doesn't contain tangents) + backing::NamedTuple +end + +Base.getproperty(tangent::MutableTangent, idx::Symbol) = unthunk(getfield(backing(tangent), idx)) +Base.setproperty! \ No newline at end of file From 4b50fd2f385b7e7f0741f912f401cf424cd96844 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 15 Sep 2023 17:18:21 +0800 Subject: [PATCH 05/36] wip --- src/tangent_types/structural_tangent.jl | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 9410e50af..7df7008db 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -14,16 +14,14 @@ as an object with mirroring fields. 
abstract type StructuralTangent{P} <: AbstractTangent end function StructuralTangent{P}(nt::NamedTuple) where {P} - if ismutabletype(P) + if has_mutable_tangent(P) return MutableTangent{P}(nt) else return Tangent{P,typeof(nt)}(nt) end end -ismutabletype(::Type{P}) where P = ismutable(P) -ismutabletype(::Type{String}) = false -ismutabletype(::Type{Symbol}) = false +has_mutable_tangent(::Type{P}) where P = ismutabletype(P) && (!isabstracttype(P) && fieldcount(T) > 0) StructuralTangent{P}(tup::Tuple) where P = Tangent{P,typeof(tup)}(tup) @@ -410,16 +408,22 @@ canonicalize(tangent::Tangent{Any,<:AbstractDict}) = tangent This type represents the tangent to a mutable struct. It itself is also mutable. -!!!!!! warning Exprimental +!!! warning Exprimental MutableTangent is an experimental feature. While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore. Exactly how it should be used (e.g. is it forward-mode only?) + +!!! warning Do not directly mess with the tangent backing data + It is relatively straight forward for a forwards-mode AD to work correctly in the presence of mutation and aliasing of primal values. + However, this requires that the tangent is aliased in turn and conversely that it is copied when the primal is). + If you seperately alias the backing data, etc by using the internal `ChainRulesCore.backing` function you can break this. """ mutable struct MutableTangent{P} - # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict - # (but potentially a different one, as it doesn't contain tangents) - backing::NamedTuple + #TODO: we may want to absolutely lock the type of this down + backing::NamedTuple end Base.getproperty(tangent::MutableTangent, idx::Symbol) = unthunk(getfield(backing(tangent), idx)) -Base.setproperty! 
\ No newline at end of file +function Base.setproperty!(tangent::MutableTangent, name::Symbol, x) + new_backing = Base.setindex(backing(tangent), x, name) +end \ No newline at end of file From 87ceddf1af98f64ea4a3a165545dd0c9bb76f283 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 15 Sep 2023 20:23:27 +0800 Subject: [PATCH 06/36] First pass at something that maybe works --- src/ChainRulesCore.jl | 2 +- src/tangent_types/structural_tangent.jl | 7 +++++-- test/tangent_types/structural_tangent.jl | 18 ++++++++++++++++++ 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index f943c50fa..4b86570cd 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -14,7 +14,7 @@ export ProjectTo, canonicalize, unthunk # tangent operations export add!!, is_inplaceable_destination # gradient accumulation operations export ignore_derivatives, @ignore_derivatives # tangents -export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk +export Tangent, MutableTangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk include("debug_mode.jl") diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 7df7008db..18e40fea9 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -418,12 +418,15 @@ It itself is also mutable. However, this requires that the tangent is aliased in turn and conversely that it is copied when the primal is). If you seperately alias the backing data, etc by using the internal `ChainRulesCore.backing` function you can break this. """ -mutable struct MutableTangent{P} +mutable struct MutableTangent{P} <: StructuralTangent{P} #TODO: we may want to absolutely lock the type of this down backing::NamedTuple end -Base.getproperty(tangent::MutableTangent, idx::Symbol) = unthunk(getfield(backing(tangent), idx)) +MutableTangent{P}(;kwargs...) 
where P = MutableTangent{P}(NamedTuple(kwargs)) +Base.getproperty(tangent::MutableTangent, idx::Symbol) = getfield(backing(tangent), idx) function Base.setproperty!(tangent::MutableTangent, name::Symbol, x) new_backing = Base.setindex(backing(tangent), x, name) + setfield!(tangent, :backing, new_backing) + return x end \ No newline at end of file diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index b0cb5577e..671d2fad0 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -425,3 +425,21 @@ end @test contains(sprint(show, tang), sprint(show, tang.x)) # inner piece appears whole end end + +@testset "MutableTangent" begin + mutable struct MDemo + x::Float64 + end + function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setfield!), obj::MDemo, field, x) + y = setfield!(obj, field, x) + ẏ = setproperty!(ȯbj, field, ẋ) + return y, ẏ + end + + obj = MDemo(99.0) + ∂obj = MutableTangent{MDemo}(;x=1.5) + frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0) + + @test ∂obj.x == 10.0 + @test obj.x == 95.0 +end \ No newline at end of file From e75a364102a3661259ea0fed69a2797b34177773 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 18 Sep 2023 13:38:22 +0800 Subject: [PATCH 07/36] accept int index --- src/tangent_types/structural_tangent.jl | 13 ++++++++++++- test/tangent_types/structural_tangent.jl | 6 +++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 18e40fea9..f26c39874 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -424,9 +424,20 @@ mutable struct MutableTangent{P} <: StructuralTangent{P} end MutableTangent{P}(;kwargs...) 
where P = MutableTangent{P}(NamedTuple(kwargs)) + Base.getproperty(tangent::MutableTangent, idx::Symbol) = getfield(backing(tangent), idx) +Base.getproperty(tangent::MutableTangent, idx::Int) = getfield(backing(tangent), idx) # break ambig + function Base.setproperty!(tangent::MutableTangent, name::Symbol, x) new_backing = Base.setindex(backing(tangent), x, name) setfield!(tangent, :backing, new_backing) return x -end \ No newline at end of file +end + +function Base.setproperty!(tangent::MutableTangent, idx::Int, x) + # needed due to https://github.com/JuliaLang/julia/issues/43155 + name = idx2sym(backing(tangent), idx) + return setproperty!(tangent, name, x) +end + +idx2sym(::NamedTuple{names}, idx) where names = names[idx] \ No newline at end of file diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index 671d2fad0..03e4db45e 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -439,7 +439,11 @@ end obj = MDemo(99.0) ∂obj = MutableTangent{MDemo}(;x=1.5) frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0) - @test ∂obj.x == 10.0 @test obj.x == 95.0 + + frule((NoTangent(), ∂obj, NoTangent(), 20.0), setfield!, obj, 1, 96.0) + @test ∂obj.x == 20.0 + @test getproperty(∂obj, 1) == 20.0 + @test obj.x == 96.0 end \ No newline at end of file From b56658169fffc11c1382aee58faee29064b9ceed Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 26 Sep 2023 20:20:56 +0800 Subject: [PATCH 08/36] add == and hash for MutableTangent --- src/tangent_types/structural_tangent.jl | 11 ++++++-- test/tangent_types/structural_tangent.jl | 34 +++++++++++++++++------- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index f26c39874..dd7a53ec6 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -5,6 +5,7 @@ Representing the type 
of the tangent of a `struct` `P` (or a `Tuple`/`NamedTuple as an object with mirroring fields. !!!!!! warning Exprimental + `StructuralTangent` is an experimental feature, and is part of the mutation support featureset. The `StructuralTangent` constructor returns a `MutableTangent` for mutable structs. `MutableTangent` is an experimental feature. Thus use of `StructuralTangent` (rather than `Tangent` directly) is also experimental. @@ -409,7 +410,7 @@ This type represents the tangent to a mutable struct. It itself is also mutable. !!! warning Exprimental - MutableTangent is an experimental feature. + MutableTangent is an experimental feature, and is part of the mutation support featureset. While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore. Exactly how it should be used (e.g. is it forward-mode only?) @@ -440,4 +441,10 @@ function Base.setproperty!(tangent::MutableTangent, idx::Int, x) return setproperty!(tangent, name, x) end -idx2sym(::NamedTuple{names}, idx) where names = names[idx] \ No newline at end of file +idx2sym(::NamedTuple{names}, idx) where names = names[idx] + +Base.hash(tangent::MutableTangent, h::UInt64) = hash(backing(tangent), h) +function Base.:(==)(t1::MutableTangent{T1}, t2::MutableTangent{T2}) where {T1, T2} + typeintersect(T1, T2) == Union{} && return false + backing(t1)==backing(t2) +end \ No newline at end of file diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index 03e4db45e..5736b2f1d 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -436,14 +436,28 @@ end return y, ẏ end - obj = MDemo(99.0) - ∂obj = MutableTangent{MDemo}(;x=1.5) - frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0) - @test ∂obj.x == 10.0 - @test obj.x == 95.0 - - frule((NoTangent(), ∂obj, NoTangent(), 20.0), setfield!, obj, 1, 96.0) - @test ∂obj.x == 20.0 - @test getproperty(∂obj, 1) == 
20.0 - @test obj.x == 96.0 + @testset "usecase" begin + obj = MDemo(99.0) + ∂obj = MutableTangent{MDemo}(;x=1.5) + frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0) + @test ∂obj.x == 10.0 + @test obj.x == 95.0 + + frule((NoTangent(), ∂obj, NoTangent(), 20.0), setfield!, obj, 1, 96.0) + @test ∂obj.x == 20.0 + @test getproperty(∂obj, 1) == 20.0 + @test obj.x == 96.0 + end + + @testset "== and hash" begin + @test MutableTangent{Any}(x=1.0) == MutableTangent{MDemo}(x=1.0) + @test MutableTangent{MDemo}(x=1.0) == MutableTangent{Any}(x=1.0) + @test MutableTangent{Any}(x=2.0) != MutableTangent{MDemo}(x=1.0) + @test MutableTangent{MDemo}(x=1.0) != MutableTangent{Any}(x=2.0) + + nt = (;x=1.0) + @test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(x=1.0) + + @test hash(MutableTangent{Any}(x=1.0)) == hash(MutableTangent{MDemo}(x=1.0)) + end end \ No newline at end of file From f8900c47e56a2337d87ab9f7f333074b86d6c822 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 26 Sep 2023 21:44:01 +0800 Subject: [PATCH 09/36] add and test zero_tangent --- src/ChainRulesCore.jl | 2 +- src/tangent_types/abstract_zero.jl | 29 +++++++++++++++++++++++++ src/tangent_types/structural_tangent.jl | 2 +- test/tangent_types/abstract_zero.jl | 14 ++++++++++++ 4 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 4b86570cd..51db59b64 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -10,7 +10,7 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod export frule_via_ad, rrule_via_ad # definition helper macros export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented -export ProjectTo, canonicalize, unthunk # tangent operations +export ProjectTo, canonicalize, unthunk, zero_tangent # tangent operations export add!!, is_inplaceable_destination # gradient accumulation operations export ignore_derivatives, @ignore_derivatives # tangents diff --git 
a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 77c455c04..dd9068b74 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -91,3 +91,32 @@ arguments. ``` """ struct NoTangent <: AbstractZero end + +""" + zero_tangent(primal) + +This returns an appropriate zero tangent suitable for accumulating tangents of the primal. +For mutable composite types this is a structural [`MutableTangent`](@ref). +For `Array`s, it is applied recursively for each element. +For immutable types, this is simply [`ZeroTangent()`](@ref) as accumulation is default out-of-place for contexts where mutation does not apply. +(Where mutation is not to be supported even for mutable types, then [`ZeroTangent()`](@ref) should be used for everything) + +!!! warning Experimental + `zero_tangent` is an experimental feature, and is part of the mutation support featureset. + While this notice remains it may have changes in behaviour, and interface in any _minor_ version of ChainRulesCore. + Exactly how it should be used (e.g. is it forward-mode only?) +""" +function zero_tangent end +zero_tangent(::AbstractString) = ZeroTangent() +# zero_tangent(::Number) = zero(x) # TODO: do we want this? +zero_tangent(primal::Array{<:Number}) = zero(primal) # TODO: do we want this?
+zero_tangent(primal::Array) = map(zero_tangent, primal) +@generated function zero_tangent(primal) + has_mutable_tangent(primal) || return ZeroTangent() # note this takes care of tuples + zfield_exprs = map(fieldnames(primal)) do fname + fval = Expr(:call, zero_tangent, Expr(:call, getfield, :primal, QuoteNode(fname))) + Expr(:kw, fname, fval) + end + backing_expr = Expr(:tuple, Expr(:parameters, zfield_exprs...)) + return :($MutableTangent{$primal}($backing_expr)) +end \ No newline at end of file diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index dd7a53ec6..864029572 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -22,7 +22,7 @@ function StructuralTangent{P}(nt::NamedTuple) where {P} end end -has_mutable_tangent(::Type{P}) where P = ismutabletype(P) && (!isabstracttype(P) && fieldcount(T) > 0) +has_mutable_tangent(::Type{P}) where P = ismutabletype(P) && (!isabstracttype(P) && fieldcount(P) > 0) StructuralTangent{P}(tup::Tuple) where P = Tangent{P,typeof(tup)}(tup) diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 028d942ea..be2c74241 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -160,3 +160,17 @@ @test isempty(detect_ambiguities(M)) end end + +@testset "zero_tangent" begin + mutable struct MutDemo + x::Float64 + end + @test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo} + @test iszero(zero_tangent(MutDemo(1.5))) + + @test zero_tangent((;a=1)) isa ZeroTangent + + @test zero_tangent([1.0, 2.0]) == [0.0, 0.0] + @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] +end + From bcb558766cf1228c9c602417b91e537f148e7e51 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 28 Sep 2023 17:31:37 +0800 Subject: [PATCH 10/36] export StructuralTangent --- src/ChainRulesCore.jl | 2 +- src/tangent_types/structural_tangent.jl | 3 +-- 2 files changed, 2 insertions(+), 3 
deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 51db59b64..2a2f93c64 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -14,7 +14,7 @@ export ProjectTo, canonicalize, unthunk, zero_tangent # tangent operations export add!!, is_inplaceable_destination # gradient accumulation operations export ignore_derivatives, @ignore_derivatives # tangents -export Tangent, MutableTangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk +export StructuralTangent, Tangent, MutableTangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk include("debug_mode.jl") diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 864029572..25aec08b5 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -10,7 +10,6 @@ as an object with mirroring fields. `MutableTangent` is an experimental feature. Thus use of `StructuralTangent` (rather than `Tangent` directly) is also experimental. While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore. 
- """ abstract type StructuralTangent{P} <: AbstractTangent end @@ -447,4 +446,4 @@ Base.hash(tangent::MutableTangent, h::UInt64) = hash(backing(tangent), h) function Base.:(==)(t1::MutableTangent{T1}, t2::MutableTangent{T2}) where {T1, T2} typeintersect(T1, T2) == Union{} && return false backing(t1)==backing(t2) -end \ No newline at end of file +end From 63c450bbf7a3175086f29142db5c07bcdffde26b Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 2 Oct 2023 12:02:19 +0800 Subject: [PATCH 11/36] Style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/tangent_types/abstract_zero.jl | 3 +-- test/tangent_types/structural_tangent.jl | 20 +++++++++++--------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index be2c74241..8193ac28f 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -168,9 +168,8 @@ end @test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo} @test iszero(zero_tangent(MutDemo(1.5))) - @test zero_tangent((;a=1)) isa ZeroTangent + @test zero_tangent((; a=1)) isa ZeroTangent @test zero_tangent([1.0, 2.0]) == [0.0, 0.0] @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] end - diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index 5736b2f1d..f4f753f47 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -430,7 +430,9 @@ end mutable struct MDemo x::Float64 end - function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setfield!), obj::MDemo, field, x) + function ChainRulesCore.frule( + (_, ȯbj, _, ẋ), ::typeof(setfield!), obj::MDemo, field, x + ) y = setfield!(obj, field, x) ẏ = setproperty!(ȯbj, field, ẋ) return y, ẏ @@ -438,7 +440,7 @@ end @testset "usecase" begin obj = MDemo(99.0) - ∂obj = MutableTangent{MDemo}(;x=1.5) + ∂obj = MutableTangent{MDemo}(; x=1.5) 
frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0) @test ∂obj.x == 10.0 @test obj.x == 95.0 @@ -450,14 +452,14 @@ end end @testset "== and hash" begin - @test MutableTangent{Any}(x=1.0) == MutableTangent{MDemo}(x=1.0) - @test MutableTangent{MDemo}(x=1.0) == MutableTangent{Any}(x=1.0) - @test MutableTangent{Any}(x=2.0) != MutableTangent{MDemo}(x=1.0) - @test MutableTangent{MDemo}(x=1.0) != MutableTangent{Any}(x=2.0) + @test MutableTangent{Any}(; x=1.0) == MutableTangent{MDemo}(; x=1.0) + @test MutableTangent{MDemo}(; x=1.0) == MutableTangent{Any}(; x=1.0) + @test MutableTangent{Any}(; x=2.0) != MutableTangent{MDemo}(; x=1.0) + @test MutableTangent{MDemo}(; x=1.0) != MutableTangent{Any}(; x=2.0) - nt = (;x=1.0) - @test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(x=1.0) + nt = (; x=1.0) + @test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(; x=1.0) - @test hash(MutableTangent{Any}(x=1.0)) == hash(MutableTangent{MDemo}(x=1.0)) + @test hash(MutableTangent{Any}(; x=1.0)) == hash(MutableTangent{MDemo}(; x=1.0)) end end \ No newline at end of file From 53e8f0dc48f85d2ed6fdfb863f20ca923175a956 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 4 Oct 2023 15:45:22 +0800 Subject: [PATCH 12/36] handle unassigned a bit more --- src/tangent_types/abstract_zero.jl | 28 +++++++++++++++++++++++----- test/tangent_types/abstract_zero.jl | 25 +++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index dd9068b74..e167d6375 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -107,10 +107,9 @@ For immutable types, this is simply [`ZeroTangent()`](@ref) as accumulation is d Exactly how it should be used (e.g. is it forward-mode only?) """ function zero_tangent end -zero_tangent(::AbstractString) = ZeroTangent() -# zero_tangent(::Number) = zero(x) # TODO: do we want this? 
-zero_tangent(primal::Array{<:Number}) = zero(primal) # TODO: do we want this? -zero_tangent(primal::Array) = map(zero_tangent, primal) + +zero_tangent(x::Number) = zero(x) + @generated function zero_tangent(primal) has_mutable_tangent(primal) || return ZeroTangent() # note this takes care of tuples zfield_exprs = map(fieldnames(primal)) do fname @@ -119,4 +118,23 @@ zero_tangent(primal::Array) = map(zero_tangent, primal) end backing_expr = Expr(:tuple, Expr(:parameters, zfield_exprs...)) return :($MutableTangent{$primal}($backing_expr)) -end \ No newline at end of file +end + +function zero_tangent(x::Array{P, N}) where {P, N} + (isbitstype(P) || all(i->isassigned(x,i), eachindex(x))) && return map(zero_tangent, x) + + # Now we need to handle nonfully assigned arrays + # see discussion at https://github.com/JuliaDiff/ChainRulesCore.jl/pull/626#discussion_r1345235265 + y = Array{guess_zero_tangent_type(P), N}(undef, size(x)...) + @inbounds for n in eachindex(y) + if isassigned(x, n) + y[n] = zero_tangent(x[n]) + end + end + return y +end + +guess_zero_tangent_type(::Type{T}) where {T<:Number} = T +guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} = Array{guess_zero_tangent_type(T), N} +guess_zero_tangent_type(::Any) = Any # if we had a general way to handle determining tangent type # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/634 + # TODO: we might be able to do better than this. even without. 
\ No newline at end of file diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 8193ac28f..7060575f1 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -172,4 +172,29 @@ end @test zero_tangent([1.0, 2.0]) == [0.0, 0.0] @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] + + @testset "undef elements" begin + x = Vector{Vector{Float64}}(undef, 3) + x[2] = [1.0,2.0] + dx = zero_tangent(x) + @test dx isa Vector{Vector{Float64}} + @test length(dx) == 3 + @test !isassigned(dx, 1) + @test dx[2] == [0.0, 0.0] + @test !isassigned(dx, 3) + + + a = Vector{MutDemo}(undef, 3) + a[2] = MutDemo(1.5) + da = zero_tangent(a) + @test !isassigned(da, 1) + @test iszero(da[2]) + @test !isassigned(da, 3) + + + db = zero_tangent(Vector{MutDemo}(undef, 3)) + @test all(ii->!isassigned(db,ii), eachindex(db)) + @test length(db)==3 + @test db isa Vector + end end From 17064c2ec055b96ec965d1375ede0e6e29b12c53 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 4 Oct 2023 15:58:59 +0800 Subject: [PATCH 13/36] add some more test cases to zero_tangent --- test/tangent_types/abstract_zero.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 7060575f1..1f6d857d9 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -165,10 +165,17 @@ end mutable struct MutDemo x::Float64 end + struct Demo + x::Float64 + end @test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo} @test iszero(zero_tangent(MutDemo(1.5))) @test zero_tangent((; a=1)) isa ZeroTangent + @test zero_tangent(Demo(1.2)) isa ZeroTangent + + @test zero_tangent(1) === 0 + @test zero_tangent(1.0) === 0.0 @test zero_tangent([1.0, 2.0]) == [0.0, 0.0] @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] From 5a19913f874e951882de870902a34ec084cfa623 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 4 Oct 2023 
23:32:59 +0800 Subject: [PATCH 14/36] style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/tangent_types/abstract_zero.jl | 15 +++++++++------ test/tangent_types/abstract_zero.jl | 10 ++++------ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index e167d6375..f79526cf9 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -120,12 +120,13 @@ zero_tangent(x::Number) = zero(x) return :($MutableTangent{$primal}($backing_expr)) end -function zero_tangent(x::Array{P, N}) where {P, N} - (isbitstype(P) || all(i->isassigned(x,i), eachindex(x))) && return map(zero_tangent, x) - +function zero_tangent(x::Array{P,N}) where {P,N} + (isbitstype(P) || all(i -> isassigned(x, i), eachindex(x))) && + return map(zero_tangent, x) + # Now we need to handle nonfully assigned arrays # see discussion at https://github.com/JuliaDiff/ChainRulesCore.jl/pull/626#discussion_r1345235265 - y = Array{guess_zero_tangent_type(P), N}(undef, size(x)...) + y = Array{guess_zero_tangent_type(P),N}(undef, size(x)...) @inbounds for n in eachindex(y) if isassigned(x, n) y[n] = zero_tangent(x[n]) @@ -135,6 +136,8 @@ function zero_tangent(x::Array{P, N}) where {P, N} end guess_zero_tangent_type(::Type{T}) where {T<:Number} = T -guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} = Array{guess_zero_tangent_type(T), N} +function guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} + return Array{guess_zero_tangent_type(T),N} +end guess_zero_tangent_type(::Any) = Any # if we had a general way to handle determining tangent type # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/634 - # TODO: we might be able to do better than this. even without. \ No newline at end of file +# TODO: we might be able to do better than this. even without. 
\ No newline at end of file diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 1f6d857d9..eb9b757a1 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -182,7 +182,7 @@ end @testset "undef elements" begin x = Vector{Vector{Float64}}(undef, 3) - x[2] = [1.0,2.0] + x[2] = [1.0, 2.0] dx = zero_tangent(x) @test dx isa Vector{Vector{Float64}} @test length(dx) == 3 @@ -190,7 +190,6 @@ end @test dx[2] == [0.0, 0.0] @test !isassigned(dx, 3) - a = Vector{MutDemo}(undef, 3) a[2] = MutDemo(1.5) da = zero_tangent(a) @@ -198,10 +197,9 @@ end @test iszero(da[2]) @test !isassigned(da, 3) - db = zero_tangent(Vector{MutDemo}(undef, 3)) - @test all(ii->!isassigned(db,ii), eachindex(db)) - @test length(db)==3 + @test all(ii -> !isassigned(db, ii), eachindex(db)) + @test length(db) == 3 @test db isa Vector - end + end end From dac92bd209ce2da67c0fc137123b0ad64e5b67c5 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 6 Oct 2023 23:31:15 +0800 Subject: [PATCH 15/36] Handle Structs with undef fields --- src/tangent_types/abstract_zero.jl | 6 +++++- test/tangent_types/abstract_zero.jl | 18 +++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index f79526cf9..7f3cd0944 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -113,7 +113,11 @@ zero_tangent(x::Number) = zero(x) @generated function zero_tangent(primal) has_mutable_tangent(primal) || return ZeroTangent() # note this takes care of tuples zfield_exprs = map(fieldnames(primal)) do fname - fval = Expr(:call, zero_tangent, Expr(:call, getfield, :primal, QuoteNode(fname))) + fval = if isdefined(primal, fname) + Expr(:call, zero_tangent, Expr(:call, getfield, :primal, QuoteNode(fname))) + else + ZeroTangent() + end Expr(:kw, fname, fval) end backing_expr = Expr(:tuple, Expr(:parameters, zfield_exprs...)) diff 
--git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index eb9b757a1..1e1b2f28a 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -180,7 +180,7 @@ end @test zero_tangent([1.0, 2.0]) == [0.0, 0.0] @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] - @testset "undef elements" begin + @testset "undef elements Vector" begin x = Vector{Vector{Float64}}(undef, 3) x[2] = [1.0, 2.0] dx = zero_tangent(x) @@ -202,4 +202,20 @@ end @test length(db) == 3 @test db isa Vector end + + @testset "undef fields struct" begin + dx = zero_tangent(Core.Box()) + @test dx.contents isa ZeroTangent + @test (dx.contents = 2.0) == 2.0 # should be assignable + + mutable struct MyPartiallyDefinedStruct + intro::Float64 + contents::Number + MyPartiallyDefinedStruct(x) = new(x) + end + dy = zero_tangent(MyPartiallyDefinedStruct(1.5)) + @test iszero(dy.intro) + @test iszero(dy.contents) + @test (dy.contents = 2.0) == 2.0 # should be assignable + end end From 98a7e39bd8a6ea3a89fd7fa927710ab6444a34ef Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 22 Dec 2023 19:20:48 +0800 Subject: [PATCH 16/36] overhaul zero_tangent and MutableTangent for type stability --- src/tangent_types/abstract_zero.jl | 45 ++++-- src/tangent_types/structural_tangent.jl | 174 ++++++++++++++---------- test/tangent_types/abstract_zero.jl | 57 ++++++-- 3 files changed, 177 insertions(+), 99 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 7f3cd0944..94a7bc084 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -111,22 +111,40 @@ function zero_tangent end zero_tangent(x::Number) = zero(x) @generated function zero_tangent(primal) - has_mutable_tangent(primal) || return ZeroTangent() # note this takes care of tuples zfield_exprs = map(fieldnames(primal)) do fname - fval = if isdefined(primal, fname) - Expr(:call, zero_tangent, Expr(:call, getfield, 
:primal, QuoteNode(fname))) - else - ZeroTangent() - end + fval = :( + if isdefined(primal, $(QuoteNode(fname))) + zero_tangent(getfield(primal, $(QuoteNode(fname)))) + else + # This is going to be potentially bad, but that's what they get for not giving us a primal + # This will never be mutated in place, rather it will always be replaced with an actual value first + ZeroTangent() + end + ) Expr(:kw, fname, fval) end - backing_expr = Expr(:tuple, Expr(:parameters, zfield_exprs...)) - return :($MutableTangent{$primal}($backing_expr)) + + return if has_mutable_tangent(primal) + any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype + # If it is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent + fdef = :(!isdefined(primal, $(QuoteNode(fname))) || !isconcretetype($ftype)) + Expr(:kw, fname, fdef) + end + :($MutableTangent{$primal}( + $(Expr(:tuple, Expr(:parameters, any_mask...))), + $(Expr(:tuple, Expr(:parameters, zfield_exprs...))) + )) + else + :($Tangent{$primal}($(Expr(:parameters, zfield_exprs...)))) + end end +zero_tangent(primal::Tuple) = Tangent{typeof(primal)}(map(zero_tangent, primal)...)
+ function zero_tangent(x::Array{P,N}) where {P,N} - (isbitstype(P) || all(i -> isassigned(x, i), eachindex(x))) && + if (isbitstype(P) || all(i -> isassigned(x, i), eachindex(x))) return map(zero_tangent, x) + end # Now we need to handle nonfully assigned arrays # see discussion at https://github.com/JuliaDiff/ChainRulesCore.jl/pull/626#discussion_r1345235265 @@ -139,9 +157,8 @@ function zero_tangent(x::Array{P,N}) where {P,N} return y end +# Sad heuristic methods we need because of unassigned values guess_zero_tangent_type(::Type{T}) where {T<:Number} = T -function guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} - return Array{guess_zero_tangent_type(T),N} -end -guess_zero_tangent_type(::Any) = Any # if we had a general way to handle determining tangent type # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/634 -# TODO: we might be able to do better than this. even without. \ No newline at end of file +guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T))) +guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} = return Array{guess_zero_tangent_type(T),N} +guess_zero_tangent_type(T::Type)= Any \ No newline at end of file diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 25aec08b5..192d58f8c 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -13,6 +13,90 @@ as an object with mirroring fields. """ abstract type StructuralTangent{P} <: AbstractTangent end +""" + Tangent{P, T} <: StructuralTangent{P} <: AbstractTangent + +This type represents the tangent for a `struct`/`NamedTuple`, or `Tuple`. +`P` is the corresponding primal type that this is a tangent for. + +`Tangent{P}` should have fields (technically properties), that match to a subset of the +fields of the primal type; and each should be a tangent type matching to the primal +type of that field.
+Fields of the P that are not present in the Tangent are treated as `Zero`. + +`T` is an implementation detail representing the backing data structure. +For Tuple it will be a Tuple, and for everything else it will be a `NamedTuple`. +It should not be passed in by user. + +For `Tangent`s of `Tuple`s, `iterate` and `getindex` are overloaded to behave similarly +to a tuple. +For `Tangent`s of `struct`s, `getproperty` is overloaded to allow for accessing values +via `tangent.fieldname`. +Any fields not explicitly present in the `Tangent` are treated as being set to `ZeroTangent()`. +To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref) +function is provided. +""" +struct Tangent{P,T} <: StructuralTangent{P} + # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict + # (but potentially a different one, as it doesn't contain tangents) + backing::T + + function Tangent{P,T}(backing) where {P,T} + if P <: Tuple + T <: Tuple || _backing_error(P, T, Tuple) + elseif P <: AbstractDict + T <: AbstractDict || _backing_error(P, T, AbstractDict) + elseif P === Any # can be anything + else # Any other struct (including NamedTuple) + T <: NamedTuple || _backing_error(P, T, NamedTuple) + end + return new(backing) + end +end + +""" + MutableTangent{P}(fields) <: StructuralTangent{P} <: AbstractTangent + +This type represents the tangent to a mutable struct. +It itself is also mutable. + +!!! warning Experimental + MutableTangent is an experimental feature, and is part of the mutation support featureset. + While this notice remains it may have changes in behaviour, and interface in any _minor_ version of ChainRulesCore. + Exactly how it should be used (e.g. is it forward-mode only?) + +!!! warning Do not directly mess with the tangent backing data + It is relatively straightforward for a forwards-mode AD to work correctly in the presence of mutation and aliasing of primal values.
+ However, this requires that the tangent is aliased in turn (and conversely that it is copied when the primal is). + If you separately alias the backing data, etc by using the internal `ChainRulesCore.backing` function you can break this. +""" +struct MutableTangent{P,F} <: StructuralTangent{P} + backing::F + + function MutableTangent{P}(fieldvals) where P + backing = map(Ref, fieldvals) + return new{P, typeof(backing)}(backing) + end + function MutableTangent{P}( + any_mask::NamedTuple{names, <:NTuple{<:Any, Bool}}, fvals::NamedTuple{names} + ) where {names, P} + + backing = map(any_mask, fvals) do isany, fval + ref = if isany + Ref{Any} + else + Ref + end + return ref(fval) + end + return new{P, typeof(backing)}(backing) + end +end + +#################################################################### +# StructuralTangent Common + + function StructuralTangent{P}(nt::NamedTuple) where {P} if has_mutable_tangent(P) return MutableTangent{P}(nt) @@ -21,6 +105,7 @@ function StructuralTangent{P}(nt::NamedTuple) where {P} end end + has_mutable_tangent(::Type{P}) where P = ismutabletype(P) && (!isabstracttype(P) && fieldcount(P) > 0) @@ -40,6 +125,9 @@ end Base.iszero(t::StructuralTangent) = all(iszero, backing(t)) function Base.map(f, tangent::StructuralTangent{P}) where {P} + #TODO: is it even useful to support this on MutableTangents? + #TODO: we implicitly assume only linear `f` are called and that it is safe to ignore noncanonical Zeros + # This feels like a fair assumption since all normal operations on tangents are linear L = propertynames(backing(tangent)) vals = map(f, Tuple(backing(tangent))) named_vals = NamedTuple{L,typeof(vals)}(vals) @@ -63,7 +151,8 @@ primal types.
backing(x::Tuple) = x backing(x::NamedTuple) = x backing(x::Dict) = x -backing(x::StructuralTangent) = getfield(x, :backing) +backing(x::Tangent) = getfield(x, :backing) +backing(x::MutableTangent) = map(getindex, getfield(x, :backing)) # For generic structs function backing(x::T)::NamedTuple where {T} @@ -206,46 +295,8 @@ function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} return println(io) end -""" - Tangent{P, T} <: StructuralTangent{P} <: AbstractTangent - -This type represents the tangent for a `struct`/`NamedTuple`, or `Tuple`. -`P` is the the corresponding primal type that this is a tangent for. - -`Tangent{P}` should have fields (technically properties), that match to a subset of the -fields of the primal type; and each should be a tangent type matching to the primal -type of that field. -Fields of the P that are not present in the Tangent are treated as `Zero`. - -`T` is an implementation detail representing the backing data structure. -For Tuple it will be a Tuple, and for everything else it will be a `NamedTuple`. -It should not be passed in by user. - -For `Tangent`s of `Tuple`s, `iterate` and `getindex` are overloaded to behave similarly -to for a tuple. -For `Tangent`s of `struct`s, `getproperty` is overloaded to allow for accessing values -via `tangent.fieldname`. -Any fields not explictly present in the `Tangent` are treated as being set to `ZeroTangent()`. -To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref) -function is provided. 
-""" -struct Tangent{P,T} <: StructuralTangent{P} - # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict - # (but potentially a different one, as it doesn't contain tangents) - backing::T - - function Tangent{P,T}(backing) where {P,T} - if P <: Tuple - T <: Tuple || _backing_error(P, T, Tuple) - elseif P <: AbstractDict - T <: AbstractDict || _backing_error(P, T, AbstractDict) - elseif P === Any # can be anything - else # Any other struct (including NamedTuple) - T <: NamedTuple || _backing_error(P, T, NamedTuple) - end - return new(backing) - end -end +####################################### +# immutable Tangent function Tangent{P}(; kwargs...) where {P} backing = (; kwargs...) # construct as NamedTuple @@ -401,46 +452,19 @@ canonicalize(tangent::Tangent{Any,<:NamedTuple{L}}) where {L} = tangent canonicalize(tangent::Tangent{Any,<:Tuple}) = tangent canonicalize(tangent::Tangent{Any,<:AbstractDict}) = tangent - -""" - MutableTangent{P}(fields) <: StructuralTangent{P} <: AbstractTangent - -This type represents the tangent to a mutable struct. -It itself is also mutable. - -!!! warning Exprimental - MutableTangent is an experimental feature, and is part of the mutation support featureset. - While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore. - Exactly how it should be used (e.g. is it forward-mode only?) - -!!! warning Do not directly mess with the tangent backing data - It is relatively straight forward for a forwards-mode AD to work correctly in the presence of mutation and aliasing of primal values. - However, this requires that the tangent is aliased in turn and conversely that it is copied when the primal is). - If you seperately alias the backing data, etc by using the internal `ChainRulesCore.backing` function you can break this. 
-""" -mutable struct MutableTangent{P} <: StructuralTangent{P} - #TODO: we may want to absolutely lock the type of this down - backing::NamedTuple -end +################################################### +# MutableTangent MutableTangent{P}(;kwargs...) where P = MutableTangent{P}(NamedTuple(kwargs)) -Base.getproperty(tangent::MutableTangent, idx::Symbol) = getfield(backing(tangent), idx) -Base.getproperty(tangent::MutableTangent, idx::Int) = getfield(backing(tangent), idx) # break ambig +ref_backing(t::MutableTangent) = getfield(t, :backing) -function Base.setproperty!(tangent::MutableTangent, name::Symbol, x) - new_backing = Base.setindex(backing(tangent), x, name) - setfield!(tangent, :backing, new_backing) - return x -end +Base.getproperty(tangent::MutableTangent, idx::Symbol) = getfield(ref_backing(tangent), idx)[] +Base.getproperty(tangent::MutableTangent, idx::Int) = getfield(ref_backing(tangent), idx)[] # break ambig -function Base.setproperty!(tangent::MutableTangent, idx::Int, x) - # needed due to https://github.com/JuliaLang/julia/issues/43155 - name = idx2sym(backing(tangent), idx) - return setproperty!(tangent, name, x) -end +Base.setproperty!(tangent::MutableTangent, name::Symbol, x) = getproperty(ref_backing(tangent), name)[] = x +Base.setproperty!(tangent::MutableTangent, idx::Int, x) = getproperty(ref_backing(tangent), idx)[] = x # break ambig -idx2sym(::NamedTuple{names}, idx) where names = names[idx] Base.hash(tangent::MutableTangent, h::UInt64) = hash(backing(tangent), h) function Base.:(==)(t1::MutableTangent{T1}, t2::MutableTangent{T2}) where {T1, T2} diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 1e1b2f28a..5e987d613 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -162,6 +162,8 @@ end @testset "zero_tangent" begin + @test zero_tangent(1) === 0 + @test zero_tangent(1.0) === 0.0 mutable struct MutDemo x::Float64 end @@ -171,34 +173,34 @@ end @test 
zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo} @test iszero(zero_tangent(MutDemo(1.5))) - @test zero_tangent((; a=1)) isa ZeroTangent - @test zero_tangent(Demo(1.2)) isa ZeroTangent - - @test zero_tangent(1) === 0 - @test zero_tangent(1.0) === 0.0 + @test zero_tangent((; a=1)) isa Tangent{typeof((;a=1))} + @test zero_tangent(Demo(1.2)) isa Tangent{Demo} + @test zero_tangent(Demo(1.2)).x === 0.0 @test zero_tangent([1.0, 2.0]) == [0.0, 0.0] @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] + @test zero_tangent((1.0, 2.0)) == Tangent{Tuple{Float64,Float64}}(0.0, 0.0) + @testset "undef elements Vector" begin x = Vector{Vector{Float64}}(undef, 3) x[2] = [1.0, 2.0] dx = zero_tangent(x) @test dx isa Vector{Vector{Float64}} @test length(dx) == 3 - @test !isassigned(dx, 1) + @test !isassigned(dx, 1) # We may reconsider this later @test dx[2] == [0.0, 0.0] - @test !isassigned(dx, 3) + @test !isassigned(dx, 3) # We may reconsider this later a = Vector{MutDemo}(undef, 3) a[2] = MutDemo(1.5) da = zero_tangent(a) - @test !isassigned(da, 1) + @test !isassigned(da, 1) # We may reconsider this later @test iszero(da[2]) - @test !isassigned(da, 3) + @test !isassigned(da, 3) # We may reconsider this later db = zero_tangent(Vector{MutDemo}(undef, 3)) - @test all(ii -> !isassigned(db, ii), eachindex(db)) + @test all(ii -> !isassigned(db, ii), eachindex(db)) # We may reconsider this later @test length(db) == 3 @test db isa Vector end @@ -217,5 +219,40 @@ end @test iszero(dy.intro) @test iszero(dy.contents) @test (dy.contents = 2.0) == 2.0 # should be assignable + + mutable struct MyPartiallyDefinedStructWithAnys + intro::Float64 + contents::Any + MyPartiallyDefinedStructWithAnys(x) = new(x) + end + dy = zero_tangent(MyPartiallyDefinedStructWithAnys(1.5)) + @test iszero(dy.intro) + @test iszero(dy.contents) + @test dy.contents === ZeroTangent() # we just don't know anything about this data + @test (dy.contents = 2.0) == 2.0 # should be assignable + @test (dy.contents 
= [2.0, 4.0]) == [2.0, 4.0] # should be assignable to different values + + mutable struct MyStructWithNonConcreteFields + x::Any + y::Union{Float64, Vector{Float64}} + z::AbstractVector + end + d = zero_tangent(MyStructWithNonConcreteFields(1.0, 2.0, [3.0])) + @test iszero(d.x) + d.x = Tangent{Base.RefValue{Float64}}(x=1.5) + @test d.x == Tangent{Base.RefValue{Float64}}(x=1.5) #should be assignable + d.x=2.4 + @test d.x == 2.4 #should be assignable + @test iszero(d.y) + d.y=2.4 + @test d.y == 2.4 #should be assignable + d.y=[2.4] + @test d.y == [2.4] #should be assignable + @test iszero(d.z) + d.z = [1.0, 2.0] + @test d.z == [1.0, 2.0] + d.z = @view [2.0,3.0,4.0][1:2] + @test d.z == [2.0, 3.0] + @test d.z isa SubArray end end From 7d88acd4a3c744faf87fa4d8396c881eefe145e3 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 27 Dec 2023 11:37:48 +0800 Subject: [PATCH 17/36] set MutableTangent setproperty! on index --- src/tangent_types/structural_tangent.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 192d58f8c..f68eefb76 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -462,8 +462,8 @@ ref_backing(t::MutableTangent) = getfield(t, :backing) Base.getproperty(tangent::MutableTangent, idx::Symbol) = getfield(ref_backing(tangent), idx)[] Base.getproperty(tangent::MutableTangent, idx::Int) = getfield(ref_backing(tangent), idx)[] # break ambig -Base.setproperty!(tangent::MutableTangent, name::Symbol, x) = getproperty(ref_backing(tangent), name)[] = x -Base.setproperty!(tangent::MutableTangent, idx::Int, x) = getproperty(ref_backing(tangent), idx)[] = x # break ambig +Base.setproperty!(tangent::MutableTangent, name::Symbol, x) = getfield(ref_backing(tangent), name)[] = x +Base.setproperty!(tangent::MutableTangent, idx::Int, x) = getfield(ref_backing(tangent), idx)[] = x # break ambig 
Base.hash(tangent::MutableTangent, h::UInt64) = hash(backing(tangent), h) From b3562c6603da408851276c7a148941c77960180d Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 27 Dec 2023 11:43:13 +0800 Subject: [PATCH 18/36] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/tangent_types/structural_tangent.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index f68eefb76..7262d4da7 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -462,9 +462,12 @@ ref_backing(t::MutableTangent) = getfield(t, :backing) Base.getproperty(tangent::MutableTangent, idx::Symbol) = getfield(ref_backing(tangent), idx)[] Base.getproperty(tangent::MutableTangent, idx::Int) = getfield(ref_backing(tangent), idx)[] # break ambig -Base.setproperty!(tangent::MutableTangent, name::Symbol, x) = getfield(ref_backing(tangent), name)[] = x -Base.setproperty!(tangent::MutableTangent, idx::Int, x) = getfield(ref_backing(tangent), idx)[] = x # break ambig - +function Base.setproperty!(tangent::MutableTangent, name::Symbol, x) + return getfield(ref_backing(tangent), name)[] = x +end +function Base.setproperty!(tangent::MutableTangent, idx::Int, x) + return getfield(ref_backing(tangent), idx)[] = x +end # break ambig Base.hash(tangent::MutableTangent, h::UInt64) = hash(backing(tangent), h) function Base.:(==)(t1::MutableTangent{T1}, t2::MutableTangent{T2}) where {T1, T2} From dd3f1ab17ac174f8e537e64ffb2ed786e0275648 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 28 Dec 2023 16:09:15 +0800 Subject: [PATCH 19/36] handle abstract fields right in mutable tangents outside of zero tangent --- src/tangent_types/structural_tangent.jl | 10 +++--- test/tangent_types/structural_tangent.jl | 41 +++++++++++++++++++++--- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git 
a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 7262d4da7..c7ae1b1b5 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -73,10 +73,6 @@ It itself is also mutable. struct MutableTangent{P,F} <: StructuralTangent{P} backing::F - function MutableTangent{P}(fieldvals) where P - backing = map(Ref, fieldvals) - return new{P, typeof(backing)}(backing) - end function MutableTangent{P}( any_mask::NamedTuple{names, <:NTuple{<:Any, Bool}}, fvals::NamedTuple{names} ) where {names, P} @@ -91,8 +87,14 @@ struct MutableTangent{P,F} <: StructuralTangent{P} end return new{P, typeof(backing)}(backing) end + + function MutableTangent{P}(fvals) where P + any_mask = NamedTuple{fieldnames(P)}((!isconcretetype).(fieldtypes(P))) + return MutableTangent{P}(any_mask, fvals) + end end + #################################################################### # StructuralTangent Common diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index f4f753f47..8ab5a6bc6 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -4,6 +4,11 @@ struct Foo y::Float64 end +mutable struct MFoo + x::Float64 + y +end + # For testing Primal + Tangent performance struct Bar x::Float64 @@ -452,14 +457,40 @@ end end @testset "== and hash" begin - @test MutableTangent{Any}(; x=1.0) == MutableTangent{MDemo}(; x=1.0) - @test MutableTangent{MDemo}(; x=1.0) == MutableTangent{Any}(; x=1.0) - @test MutableTangent{Any}(; x=2.0) != MutableTangent{MDemo}(; x=1.0) - @test MutableTangent{MDemo}(; x=1.0) != MutableTangent{Any}(; x=2.0) + @test MutableTangent{MDemo}(; x=1f0) == MutableTangent{MDemo}(; x=1.0) + @test MutableTangent{MDemo}(; x=1.0) == MutableTangent{MDemo}(; x=1f0) + @test MutableTangent{MDemo}(; x=2.0) != MutableTangent{MDemo}(; x=1.0) + @test MutableTangent{MDemo}(; x=1.0) != MutableTangent{MDemo}(; x=2.0) nt = (; x=1.0) @test 
MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(; x=1.0) - @test hash(MutableTangent{Any}(; x=1.0)) == hash(MutableTangent{MDemo}(; x=1.0)) + @test hash(MutableTangent{MDemo}(; x=1f0)) == hash(MutableTangent{MDemo}(; x=1.0)) + end + + @testset "Mutation" begin + v = MutableTangent{MFoo}(x=1.5, y=2.4) + v.x = 1.6 + @test v == MutableTangent{MFoo}(x=1.6, y=2.4) + v.y = [1.0, 2.0] # change type, because primal can change type + @test v == MutableTangent{MFoo}(x=1.6, y=[1.0, 2.0]) + end +end + +@testset "map" begin + @testset "Tangent" begin + ∂foo = Tangent{Foo}(x=1.5, y=2.4) + @test map(v->2*v, ∂foo) == Tangent{Foo}(x=3.0, y=4.8) + + ∂foo = Tangent{Foo}(x=1.5) + @test map(v->2*v, ∂foo) == Tangent{Foo}(x=3.0) + end + @testset "MutableTangent" begin + ∂foo = MutableTangent{MFoo}(x=1.5, y=2.4) + ∂foo2 = map(v->2*v, ∂foo) + @test ∂foo2 == MutableTangent{MFoo}(x=3.0, y=4.8) + # Check can still be mutated to new typ + ∂foo2.y=[1.0, 2.0] + @test ∂foo2 == MutableTangent{MFoo}(x=3.0, y=[1.0, 2.0]) end end \ No newline at end of file From db45626b0d6dab704f3ed598f5dd829fbc8d6132 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 28 Dec 2023 16:56:52 +0800 Subject: [PATCH 20/36] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/tangent_types/abstract_zero.jl | 11 +++++---- test/tangent_types/abstract_zero.jl | 17 +++++++------- test/tangent_types/structural_tangent.jl | 30 ++++++++++++------------ 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 94a7bc084..a0f8d1f82 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -123,7 +123,6 @@ zero_tangent(x::Number) = zero(x) ) Expr(:kw, fname, fval) end - return if has_mutable_tangent(primal) any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype # If it is is unassigned, or if it doesn't have a concrete type, let it 
take any value for its tangent @@ -132,11 +131,11 @@ zero_tangent(x::Number) = zero(x) end :($MutableTangent{$primal}( $(Expr(:tuple, Expr(:parameters, any_mask...))), - $(Expr(:tuple, Expr(:parameters, zfield_exprs...))) + $(Expr(:tuple, Expr(:parameters, zfield_exprs...))), )) else :($Tangent{$primal}($(Expr(:parameters, zfield_exprs...)))) - end + end end zero_tangent(primal::Tuple) = Tangent{typeof(primal)}(map(zero_tangent, primal)...) @@ -160,5 +159,7 @@ end # Sad heauristic methods we need because of unassigned values guess_zero_tangent_type(::Type{T}) where {T<:Number} = T guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T))) -guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} = return Array{guess_zero_tangent_type(T),N} -guess_zero_tangent_type(T::Type)= Any \ No newline at end of file +function guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} + return Array{guess_zero_tangent_type(T),N} +end +guess_zero_tangent_type(T::Type) = Any \ No newline at end of file diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 5e987d613..81114511a 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -173,7 +173,7 @@ end @test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo} @test iszero(zero_tangent(MutDemo(1.5))) - @test zero_tangent((; a=1)) isa Tangent{typeof((;a=1))} + @test zero_tangent((; a=1)) isa Tangent{typeof((; a = 1))} @test zero_tangent(Demo(1.2)) isa Tangent{Demo} @test zero_tangent(Demo(1.2)).x === 0.0 @@ -181,7 +181,6 @@ end @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] @test zero_tangent((1.0, 2.0)) == Tangent{Tuple{Float64,Float64}}(0.0, 0.0) - @testset "undef elements Vector" begin x = Vector{Vector{Float64}}(undef, 3) x[2] = [1.0, 2.0] @@ -234,24 +233,24 @@ end mutable struct MyStructWithNonConcreteFields x::Any - y::Union{Float64, Vector{Float64}} + y::Union{Float64,Vector{Float64}} z::AbstractVector end d = 
zero_tangent(MyStructWithNonConcreteFields(1.0, 2.0, [3.0])) @test iszero(d.x) - d.x = Tangent{Base.RefValue{Float64}}(x=1.5) - @test d.x == Tangent{Base.RefValue{Float64}}(x=1.5) #should be assignable - d.x=2.4 + d.x = Tangent{Base.RefValue{Float64}}(; x=1.5) + @test d.x == Tangent{Base.RefValue{Float64}}(; x=1.5) #should be assignable + d.x = 2.4 @test d.x == 2.4 #should be assignable @test iszero(d.y) - d.y=2.4 + d.y = 2.4 @test d.y == 2.4 #should be assignable - d.y=[2.4] + d.y = [2.4] @test d.y == [2.4] #should be assignable @test iszero(d.z) d.z = [1.0, 2.0] @test d.z == [1.0, 2.0] - d.z = @view [2.0,3.0,4.0][1:2] + d.z = @view [2.0, 3.0, 4.0][1:2] @test d.z == [2.0, 3.0] @test d.z isa SubArray end diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index 8ab5a6bc6..0982f97c0 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -457,40 +457,40 @@ end end @testset "== and hash" begin - @test MutableTangent{MDemo}(; x=1f0) == MutableTangent{MDemo}(; x=1.0) - @test MutableTangent{MDemo}(; x=1.0) == MutableTangent{MDemo}(; x=1f0) + @test MutableTangent{MDemo}(; x=1.0f0) == MutableTangent{MDemo}(; x=1.0) + @test MutableTangent{MDemo}(; x=1.0) == MutableTangent{MDemo}(; x=1.0f0) @test MutableTangent{MDemo}(; x=2.0) != MutableTangent{MDemo}(; x=1.0) @test MutableTangent{MDemo}(; x=1.0) != MutableTangent{MDemo}(; x=2.0) nt = (; x=1.0) @test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(; x=1.0) - @test hash(MutableTangent{MDemo}(; x=1f0)) == hash(MutableTangent{MDemo}(; x=1.0)) + @test hash(MutableTangent{MDemo}(; x=1.0f0)) == hash(MutableTangent{MDemo}(; x=1.0)) end @testset "Mutation" begin - v = MutableTangent{MFoo}(x=1.5, y=2.4) + v = MutableTangent{MFoo}(; x=1.5, y=2.4) v.x = 1.6 - @test v == MutableTangent{MFoo}(x=1.6, y=2.4) + @test v == MutableTangent{MFoo}(; x=1.6, y=2.4) v.y = [1.0, 2.0] # change type, because primal can change type - @test v == 
MutableTangent{MFoo}(x=1.6, y=[1.0, 2.0]) + @test v == MutableTangent{MFoo}(; x=1.6, y=[1.0, 2.0]) end end @testset "map" begin @testset "Tangent" begin - ∂foo = Tangent{Foo}(x=1.5, y=2.4) - @test map(v->2*v, ∂foo) == Tangent{Foo}(x=3.0, y=4.8) + ∂foo = Tangent{Foo}(; x=1.5, y=2.4) + @test map(v -> 2 * v, ∂foo) == Tangent{Foo}(; x=3.0, y=4.8) - ∂foo = Tangent{Foo}(x=1.5) - @test map(v->2*v, ∂foo) == Tangent{Foo}(x=3.0) + ∂foo = Tangent{Foo}(; x=1.5) + @test map(v -> 2 * v, ∂foo) == Tangent{Foo}(; x=3.0) end @testset "MutableTangent" begin - ∂foo = MutableTangent{MFoo}(x=1.5, y=2.4) - ∂foo2 = map(v->2*v, ∂foo) - @test ∂foo2 == MutableTangent{MFoo}(x=3.0, y=4.8) + ∂foo = MutableTangent{MFoo}(; x=1.5, y=2.4) + ∂foo2 = map(v -> 2 * v, ∂foo) + @test ∂foo2 == MutableTangent{MFoo}(; x=3.0, y=4.8) # Check can still be mutated to new typ - ∂foo2.y=[1.0, 2.0] - @test ∂foo2 == MutableTangent{MFoo}(x=3.0, y=[1.0, 2.0]) + ∂foo2.y = [1.0, 2.0] + @test ∂foo2 == MutableTangent{MFoo}(; x=3.0, y=[1.0, 2.0]) end end \ No newline at end of file From ad299716065b18af366355e52666a4734cd45e34 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 28 Dec 2023 17:55:36 +0800 Subject: [PATCH 21/36] Add docs for forward mutation support --- docs/make.jl | 1 + docs/src/api.md | 2 +- .../superpowers/mutation_support.md | 73 +++++++++++++++++++ 3 files changed, 75 insertions(+), 1 deletion(-) create mode 100644 docs/src/rule_author/superpowers/mutation_support.md diff --git a/docs/make.jl b/docs/make.jl index ad86a84ae..1666665fe 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -61,6 +61,7 @@ makedocs(; "`@opt_out`" => "rule_author/superpowers/opt_out.md", "`RuleConfig`" => "rule_author/superpowers/ruleconfig.md", "Gradient accumulation" => "rule_author/superpowers/gradient_accumulation.md", + "Mutation Support (experimental)" => "rule_author/superpowers/mutation_support.md", ], "Converting ZygoteRules.@adjoint to rrules" => "rule_author/converting_zygoterules.md", "Tips for making your package 
work with AD" => "rule_author/tips_for_packages.md", diff --git a/docs/src/api.md b/docs/src/api.md index 5648058e0..57b7bf2ad 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -20,7 +20,7 @@ Modules = [ChainRulesCore] Pages = [ "tangent_types/abstract_zero.jl", "tangent_types/one.jl", - "tangent_types/tangent.jl", + "tangent_types/structural_tangent.jl", "tangent_types/thunks.jl", "tangent_types/abstract_tangent.jl", "tangent_types/notimplemented.jl", diff --git a/docs/src/rule_author/superpowers/mutation_support.md b/docs/src/rule_author/superpowers/mutation_support.md new file mode 100644 index 000000000..55629166a --- /dev/null +++ b/docs/src/rule_author/superpowers/mutation_support.md @@ -0,0 +1,73 @@ +# Mutation Support + +ChainRulesCore.jl offers experimental support for mutation, targeting use in forward mode AD. +(Mutation support in reverse mode AD is more complicated and will likely require more changes to the interface.) + +!!! warning "Experimental" + This page documents an experimental feature. + Expect breaking changes in minor versions while this remains. + It is not suitable for general use unless you are prepared to modify how you are using it each minor release. + It is thus suggested that, if you are using it, you use _tilde_ bounds on supported minor versions. + + +## `MutableTangent` +The [`MutableTangent`](@ref) type is designed to be a partner to the [`Tangent`](@ref) type, with specific support for being mutated in place. +It is required to be a structural tangent, having one tangent for each field of the primal object. + +Technically, not all `mutable struct`s need to use `MutableTangent` to represent their tangents. +Just like not all `struct`s need to use `Tangent`s. +Common exceptions are natural tangent types, such as those for arrays. +However, if one is setting up to use a custom tangent type for this it is sufficiently off the beaten path that we cannot provide much guidance.
+ +## `zero_tangent` + +The [`zero_tangent`](@ref) function gives you a zero (i.e. additive identity) for any primal value. +The [`ZeroTangent`](@ref) type also does this. +The difference is that [`zero_tangent`](@ref) is (where possible) a full structural tangent mirroring the structure of the primal. +For mutation support this is important, since it means that there is mutable memory available in the tangent to be mutated when the primal changes. +To support this you thus need to make sure your zeros are created in various places with [`zero_tangent`](@ref) rather than [`ZeroTangent`](@ref). + +It is also useful for reasons of type stability, since it is always a structural tangent. +For this reason AD system implementors might choose to use this to create the tangent for all literal values they encounter, mutable or not. + +## Writing a frule for a mutating function +It is relatively straightforward to write a frule for a mutating function. +There are a few key points to follow: + - There must be a mutable tangent input for every mutated primal input + - When the primal value is changed, the corresponding change must be made to its tangent partner + - When a value is returned, return its partnered tangent. + + +### Example +For example, consider the primal function that: +1. takes two `Ref`s +2. doubles the first one in place +3. overwrites the second one's value with the literal 5.0 +4.
returns the first one + + +```julia +function foo!(a::Base.RefValue, b::Base.RefValue) + a[] *= 2 + b[] = 5.0 + return a +end +``` + +The frule for this would be: +```julia +function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(foo!), a::Base.RefValue, b::Base.RefValue) + @assert ȧ isa MutableTangent{typeof(a)} + @assert ḃ isa MutableTangent{typeof(b)} + + a[] *= 2 + ȧ.x *= 2 # `.x` is the field that lives behind RefValues + + b[]=5.0 + ḃ.x = zero_tangent(5.0) # or since we know that the zero for a Float64 is zero could write `ḃ.x = 0.0` + + return a, ȧ +end +``` + +Then assuming the AD system does its part to make sure you are indeed given mutable values to mutate (i.e. those `@assert`ions are true) then all is well and this rule will make mutation correct. \ No newline at end of file From 8471f39e6eed6097768c0e52ad4623f4fc6697cc Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 29 Dec 2023 12:40:07 +0800 Subject: [PATCH 22/36] use ismutabletype from Compat --- Project.toml | 14 +++++++------- src/ChainRulesCore.jl | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 195cccac5..63fd50344 100644 --- a/Project.toml +++ b/Project.toml @@ -7,20 +7,17 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -[weakdeps] -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[extensions] -ChainRulesCoreSparseArraysExt = "SparseArrays" - [compat] BenchmarkTools = "0.5" -Compat = "2, 3, 4" +Compat = "3.40, 4" FiniteDifferences = "0.10" OffsetArrays = "1" StaticArrays = "0.11, 0.12, 1" julia = "1.6" +[extensions] +ChainRulesCoreSparseArraysExt = "SparseArrays" + [extras] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" @@ -31,3 +28,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Test", "BenchmarkTools", "FiniteDifferences",
"OffsetArrays", "SparseArrays", "StaticArrays"] + +[weakdeps] +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 2a2f93c64..bda392497 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -2,7 +2,7 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! using Base.Meta using LinearAlgebra -using Compat: hasfield, hasproperty +using Compat: hasfield, hasproperty, ismutabletype export frule, rrule # core function # rule configurations From 9b6d6e52bb8b4c7b7a00742cf8512b7c8e869e4d Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 29 Dec 2023 12:41:11 +0800 Subject: [PATCH 23/36] wrap structural tangent tests in a common testset --- test/tangent_types/structural_tangent.jl | 799 ++++++++++++----------- 1 file changed, 400 insertions(+), 399 deletions(-) diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index 0982f97c0..16d702c14 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -22,475 +22,476 @@ struct StructWithInvariant StructWithInvariant(x) = new(x, 2x) end +@testset "StructuralTangent" begin + @testset "Tangent" begin + @testset "empty types" begin + @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{},Tuple{}} + end -@testset "Tangent" begin - @testset "empty types" begin - @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{},Tuple{}} - end - - @testset "constructor" begin - t = (1.0, 2.0) - nt = (x=1, y=2.0) - d = Dict(:x => 1.0, :y => 2.0) - vals = [1, 2] - - @test_throws ArgumentError Tangent{typeof(t),typeof(nt)}(nt) - @test_throws ArgumentError Tangent{typeof(t),typeof(d)}(d) - - @test_throws ArgumentError Tangent{typeof(d),typeof(nt)}(nt) - @test_throws ArgumentError Tangent{typeof(d),typeof(t)}(t) + @testset "constructor" begin + t = (1.0, 2.0) + nt = (x=1, y=2.0) + d = Dict(:x => 1.0, :y => 2.0) + vals = [1, 2] - 
@test_throws ArgumentError Tangent{typeof(nt),typeof(vals)}(vals) - @test_throws ArgumentError Tangent{typeof(nt),typeof(d)}(d) - @test_throws ArgumentError Tangent{typeof(nt),typeof(t)}(t) + @test_throws ArgumentError Tangent{typeof(t),typeof(nt)}(nt) + @test_throws ArgumentError Tangent{typeof(t),typeof(d)}(d) - @test_throws ArgumentError Tangent{Foo,typeof(d)}(d) - @test_throws ArgumentError Tangent{Foo,typeof(t)}(t) - end + @test_throws ArgumentError Tangent{typeof(d),typeof(nt)}(nt) + @test_throws ArgumentError Tangent{typeof(d),typeof(t)}(t) - @testset "==" begin - @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; x=0.1, y=2.5) - @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; y=2.5, x=0.1) - @test Tangent{Foo}(; y=2.5, x=ZeroTangent()) == Tangent{Foo}(; y=2.5) + @test_throws ArgumentError Tangent{typeof(nt),typeof(vals)}(vals) + @test_throws ArgumentError Tangent{typeof(nt),typeof(d)}(d) + @test_throws ArgumentError Tangent{typeof(nt),typeof(t)}(t) - @test Tangent{Tuple{Float64}}(2.0) == Tangent{Tuple{Float64}}(2.0) - @test Tangent{Dict}(Dict(4 => 3)) == Tangent{Dict}(Dict(4 => 3)) + @test_throws ArgumentError Tangent{Foo,typeof(d)}(d) + @test_throws ArgumentError Tangent{Foo,typeof(t)}(t) + end - tup = (1.0, 2.0) - @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2 * 1.0)) - @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, 2) + @testset "==" begin + @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; x=0.1, y=2.5) + @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; y=2.5, x=0.1) + @test Tangent{Foo}(; y=2.5, x=ZeroTangent()) == Tangent{Foo}(; y=2.5) - @test Tangent{Foo}(; y=2.0) == Tangent{Foo}(; x=ZeroTangent(), y=Float32(2.0)) - end + @test Tangent{Tuple{Float64}}(2.0) == Tangent{Tuple{Float64}}(2.0) + @test Tangent{Dict}(Dict(4 => 3)) == Tangent{Dict}(Dict(4 => 3)) - @testset "hash" begin - @test hash(Tangent{Foo}(; x=0.1, y=2.5)) == hash(Tangent{Foo}(; y=2.5, x=0.1)) - @test hash(Tangent{Foo}(; y=2.5, 
x=ZeroTangent())) == hash(Tangent{Foo}(; y=2.5)) - end + tup = (1.0, 2.0) + @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2 * 1.0)) + @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, 2) - @testset "indexing, iterating, and properties" begin - @test keys(Tangent{Foo}(; x=2.5)) == (:x,) - @test propertynames(Tangent{Foo}(; x=2.5)) == (:x,) - @test haskey(Tangent{Foo}(; x=2.5), :x) == true - if isdefined(Base, :hasproperty) - @test hasproperty(Tangent{Foo}(; x=2.5), :y) == false + @test Tangent{Foo}(; y=2.0) == Tangent{Foo}(; x=ZeroTangent(), y=Float32(2.0)) end - @test Tangent{Foo}(; x=2.5).x == 2.5 - - tang1 = Tangent{Tuple{Float64}}(2.0) - @test keys(tang1) == Base.OneTo(1) - @test propertynames(Tangent{Tuple{Float64}}(2.0)) == (1,) - @test getindex(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 - @test getindex(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 - @test getproperty(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 - @test getproperty(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 - @test NoTangent() === @inferred Base.tail(tang1) - @test NoTangent() === @inferred Base.tail(Tangent{Tuple{}}()) - - tang3 = Tangent{Tuple{Float64, String, Vector{Float64}}}(1.0, NoTangent(), @thunk [3.0] .+ 4) - @test @inferred(first(tang3)) === tang3[1] === 1.0 - @test @inferred(last(tang3)) isa Thunk - @test unthunk(last(tang3)) == [7.0] - @test Tuple(@inferred Base.tail(tang3))[1] === NoTangent() - @test Tuple(Base.tail(tang3))[end] isa Thunk - - NT = NamedTuple{(:a, :b),Tuple{Float64,Float64}} - @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 - @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() - @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() - @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 - - @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 - @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() - @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 
1) == ZeroTangent() - @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 - - @test first(Tangent{NT}(; a=(@thunk 2.0^2))) isa Thunk - @test unthunk(first(Tangent{NT}(; a=(@thunk 2.0^2)))) == 4.0 - @test last(Tangent{NT}(; a=(@thunk 2.0^2))) isa ZeroTangent - - ntang1 = @inferred Base.tail(Tangent{NT}(; b=(@thunk 2.0^2))) - @test ntang1 isa Tangent{<:NamedTuple{(:b,)}} - @test NoTangent() === @inferred Base.tail(ntang1) - - # TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516 - # if VERSION >= v"1.8-" - # @test haskey(Tangent{Tuple{Float64}}(2.0), 1) == true - # else - # @test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true - # end - @test_broken hasproperty(Tangent{Tuple{Float64}}(2.0), 2) == false - - @test length(Tangent{Foo}(; x=2.5)) == 1 - @test length(Tangent{Tuple{Float64}}(2.0)) == 1 - - @test eltype(Tangent{Foo}(; x=2.5)) == Float64 - @test eltype(Tangent{Tuple{Float64}}(2.0)) == Float64 - - # Testing iterate via collect - @test collect(Tangent{Foo}(; x=2.5)) == [2.5] - @test collect(Tangent{Tuple{Float64}}(2.0)) == [2.0] - - # Test indexed_iterate - ctup = Tangent{Tuple{Float64,Int64}}(2.0, 3) - _unpack2tuple = function (tangent) - a, b = tangent - return (a, b) + + @testset "hash" begin + @test hash(Tangent{Foo}(; x=0.1, y=2.5)) == hash(Tangent{Foo}(; y=2.5, x=0.1)) + @test hash(Tangent{Foo}(; y=2.5, x=ZeroTangent())) == hash(Tangent{Foo}(; y=2.5)) end - @inferred _unpack2tuple(ctup) - @test _unpack2tuple(ctup) === (2.0, 3) - - # Test getproperty is inferrable - _unpacknamedtuple = tangent -> (tangent.x, tangent.y) - if VERSION ≥ v"1.2" - @inferred _unpacknamedtuple(Tangent{Foo}(; x=2, y=3.0)) - @inferred _unpacknamedtuple(Tangent{Foo}(; y=3.0)) + + @testset "indexing, iterating, and properties" begin + @test keys(Tangent{Foo}(; x=2.5)) == (:x,) + @test propertynames(Tangent{Foo}(; x=2.5)) == (:x,) + @test haskey(Tangent{Foo}(; x=2.5), :x) == true + if isdefined(Base, :hasproperty) + @test 
hasproperty(Tangent{Foo}(; x=2.5), :y) == false + end + @test Tangent{Foo}(; x=2.5).x == 2.5 + + tang1 = Tangent{Tuple{Float64}}(2.0) + @test keys(tang1) == Base.OneTo(1) + @test propertynames(Tangent{Tuple{Float64}}(2.0)) == (1,) + @test getindex(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 + @test getindex(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 + @test getproperty(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 + @test getproperty(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 + @test NoTangent() === @inferred Base.tail(tang1) + @test NoTangent() === @inferred Base.tail(Tangent{Tuple{}}()) + + tang3 = Tangent{Tuple{Float64, String, Vector{Float64}}}(1.0, NoTangent(), @thunk [3.0] .+ 4) + @test @inferred(first(tang3)) === tang3[1] === 1.0 + @test @inferred(last(tang3)) isa Thunk + @test unthunk(last(tang3)) == [7.0] + @test Tuple(@inferred Base.tail(tang3))[1] === NoTangent() + @test Tuple(Base.tail(tang3))[end] isa Thunk + + NT = NamedTuple{(:a, :b),Tuple{Float64,Float64}} + @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 + @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() + @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() + @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 + + @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 + @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() + @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() + @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 + + @test first(Tangent{NT}(; a=(@thunk 2.0^2))) isa Thunk + @test unthunk(first(Tangent{NT}(; a=(@thunk 2.0^2)))) == 4.0 + @test last(Tangent{NT}(; a=(@thunk 2.0^2))) isa ZeroTangent + + ntang1 = @inferred Base.tail(Tangent{NT}(; b=(@thunk 2.0^2))) + @test ntang1 isa Tangent{<:NamedTuple{(:b,)}} + @test NoTangent() === @inferred Base.tail(ntang1) + + # TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516 + # if VERSION >= v"1.8-" + # @test 
haskey(Tangent{Tuple{Float64}}(2.0), 1) == true + # else + # @test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true + # end + @test_broken hasproperty(Tangent{Tuple{Float64}}(2.0), 2) == false + + @test length(Tangent{Foo}(; x=2.5)) == 1 + @test length(Tangent{Tuple{Float64}}(2.0)) == 1 + + @test eltype(Tangent{Foo}(; x=2.5)) == Float64 + @test eltype(Tangent{Tuple{Float64}}(2.0)) == Float64 + + # Testing iterate via collect + @test collect(Tangent{Foo}(; x=2.5)) == [2.5] + @test collect(Tangent{Tuple{Float64}}(2.0)) == [2.0] + + # Test indexed_iterate + ctup = Tangent{Tuple{Float64,Int64}}(2.0, 3) + _unpack2tuple = function (tangent) + a, b = tangent + return (a, b) + end + @inferred _unpack2tuple(ctup) + @test _unpack2tuple(ctup) === (2.0, 3) + + # Test getproperty is inferrable + _unpacknamedtuple = tangent -> (tangent.x, tangent.y) + if VERSION ≥ v"1.2" + @inferred _unpacknamedtuple(Tangent{Foo}(; x=2, y=3.0)) + @inferred _unpacknamedtuple(Tangent{Foo}(; y=3.0)) + end end - end - @testset "reverse" begin - c = Tangent{Tuple{Int,Int,String}}(1, 2, "something") - cr = Tangent{Tuple{String,Int,Int}}("something", 2, 1) - @test reverse(c) === cr + @testset "reverse" begin + c = Tangent{Tuple{Int,Int,String}}(1, 2, "something") + cr = Tangent{Tuple{String,Int,Int}}("something", 2, 1) + @test reverse(c) === cr - if VERSION < v"1.9-" - # can't reverse a named tuple or a dict - @test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0)) + if VERSION < v"1.9-" + # can't reverse a named tuple or a dict + @test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0)) - d = Dict(:x => 1, :y => 2.0) - cdict = Tangent{typeof(d),typeof(d)}(d) - @test_throws MethodError reverse(Tangent{Foo}()) - else - # These now work but do we care? + d = Dict(:x => 1, :y => 2.0) + cdict = Tangent{typeof(d),typeof(d)}(d) + @test_throws MethodError reverse(Tangent{Foo}()) + else + # These now work but do we care? 
+ end end - end - @testset "unset properties" begin - @test Tangent{Foo}(; x=1.4).y === ZeroTangent() - end + @testset "unset properties" begin + @test Tangent{Foo}(; x=1.4).y === ZeroTangent() + end - @testset "conj" begin - @test conj(Tangent{Foo}(; x=2.0 + 3.0im)) == Tangent{Foo}(; x=2.0 - 3.0im) - @test ==( - conj(Tangent{Tuple{Float64}}(2.0 + 3.0im)), Tangent{Tuple{Float64}}(2.0 - 3.0im) - ) - @test ==( - conj(Tangent{Dict}(Dict(4 => 2.0 + 3.0im))), - Tangent{Dict}(Dict(4 => 2.0 + -3.0im)), - ) - end + @testset "conj" begin + @test conj(Tangent{Foo}(; x=2.0 + 3.0im)) == Tangent{Foo}(; x=2.0 - 3.0im) + @test ==( + conj(Tangent{Tuple{Float64}}(2.0 + 3.0im)), Tangent{Tuple{Float64}}(2.0 - 3.0im) + ) + @test ==( + conj(Tangent{Dict}(Dict(4 => 2.0 + 3.0im))), + Tangent{Dict}(Dict(4 => 2.0 + -3.0im)), + ) + end - @testset "canonicalize" begin - # Testing iterate via collect - @test ==(canonicalize(Tangent{Tuple{Float64}}(2.0)), Tangent{Tuple{Float64}}(2.0)) + @testset "canonicalize" begin + # Testing iterate via collect + @test ==(canonicalize(Tangent{Tuple{Float64}}(2.0)), Tangent{Tuple{Float64}}(2.0)) - @test ==(canonicalize(Tangent{Dict}(Dict(4 => 3))), Tangent{Dict}(Dict(4 => 3))) + @test ==(canonicalize(Tangent{Dict}(Dict(4 => 3))), Tangent{Dict}(Dict(4 => 3))) - # For structure it needs to match order and ZeroTangent() fill to match primal - CFoo = Tangent{Foo} - @test canonicalize(CFoo(; x=2.5, y=10)) == CFoo(; x=2.5, y=10) - @test canonicalize(CFoo(; y=10, x=2.5)) == CFoo(; x=2.5, y=10) - @test canonicalize(CFoo(; y=10)) == CFoo(; x=ZeroTangent(), y=10) + # For structure it needs to match order and ZeroTangent() fill to match primal + CFoo = Tangent{Foo} + @test canonicalize(CFoo(; x=2.5, y=10)) == CFoo(; x=2.5, y=10) + @test canonicalize(CFoo(; y=10, x=2.5)) == CFoo(; x=2.5, y=10) + @test canonicalize(CFoo(; y=10)) == CFoo(; x=ZeroTangent(), y=10) - @test_throws ArgumentError canonicalize(CFoo(; q=99.0, x=2.5)) + @test_throws ArgumentError 
canonicalize(CFoo(; q=99.0, x=2.5)) - @testset "unspecified primal type" begin - c1 = Tangent{Any}(; a=1, b=2) - c2 = Tangent{Any}(1, 2) - c3 = Tangent{Any}(Dict(4 => 3)) + @testset "unspecified primal type" begin + c1 = Tangent{Any}(; a=1, b=2) + c2 = Tangent{Any}(1, 2) + c3 = Tangent{Any}(Dict(4 => 3)) - @test c1 == canonicalize(c1) - @test c2 == canonicalize(c2) - @test c3 == canonicalize(c3) + @test c1 == canonicalize(c1) + @test c2 == canonicalize(c2) + @test c3 == canonicalize(c3) + end end - end - @testset "+ with other composites" begin - @testset "Structs" begin - CFoo = Tangent{Foo} - @test CFoo(; x=1.5) + CFoo(; x=2.5) == CFoo(; x=4.0) - @test CFoo(; y=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=2.5) - @test CFoo(; y=1.5, x=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=4.0) - end + @testset "+ with other composites" begin + @testset "Structs" begin + CFoo = Tangent{Foo} + @test CFoo(; x=1.5) + CFoo(; x=2.5) == CFoo(; x=4.0) + @test CFoo(; y=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=2.5) + @test CFoo(; y=1.5, x=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=4.0) + end - @testset "Tuples" begin - @test ==( - typeof(Tangent{Tuple{}}() + Tangent{Tuple{}}()), Tangent{Tuple{},Tuple{}} - ) - @test ( - Tangent{Tuple{Float64,Float64}}(1.0, 2.0) + - Tangent{Tuple{Float64,Float64}}(1.0, 1.0) - ) == Tangent{Tuple{Float64,Float64}}(2.0, 3.0) - end + @testset "Tuples" begin + @test ==( + typeof(Tangent{Tuple{}}() + Tangent{Tuple{}}()), Tangent{Tuple{},Tuple{}} + ) + @test ( + Tangent{Tuple{Float64,Float64}}(1.0, 2.0) + + Tangent{Tuple{Float64,Float64}}(1.0, 1.0) + ) == Tangent{Tuple{Float64,Float64}}(2.0, 3.0) + end - @testset "NamedTuples" begin - make_tangent(nt::NamedTuple) = Tangent{typeof(nt)}(; nt...) - t1 = make_tangent((; a=1.5, b=0.0)) - t2 = make_tangent((; a=0.0, b=2.5)) - t_sum = make_tangent((a=1.5, b=2.5)) - @test t1 + t2 == t_sum - end + @testset "NamedTuples" begin + make_tangent(nt::NamedTuple) = Tangent{typeof(nt)}(; nt...) 
+ t1 = make_tangent((; a=1.5, b=0.0)) + t2 = make_tangent((; a=0.0, b=2.5)) + t_sum = make_tangent((a=1.5, b=2.5)) + @test t1 + t2 == t_sum + end - @testset "Dicts" begin - d1 = Tangent{Dict}(Dict(4 => 3.0, 3 => 2.0)) - d2 = Tangent{Dict}(Dict(4 => 3.0, 2 => 2.0)) - d_sum = Tangent{Dict}(Dict(4 => 3.0 + 3.0, 3 => 2.0, 2 => 2.0)) - @test d1 + d2 == d_sum + @testset "Dicts" begin + d1 = Tangent{Dict}(Dict(4 => 3.0, 3 => 2.0)) + d2 = Tangent{Dict}(Dict(4 => 3.0, 2 => 2.0)) + d_sum = Tangent{Dict}(Dict(4 => 3.0 + 3.0, 3 => 2.0, 2 => 2.0)) + @test d1 + d2 == d_sum + end + + @testset "Fields of type NotImplemented" begin + CFoo = Tangent{Foo} + a = CFoo(; x=1.5) + b = CFoo(; x=@not_implemented("")) + for (x, y) in ((a, b), (b, a), (b, b)) + z = x + y + @test z isa CFoo + @test z.x isa ChainRulesCore.NotImplemented + end + + a = Tangent{Tuple}(1.5) + b = Tangent{Tuple}(@not_implemented("")) + for (x, y) in ((a, b), (b, a), (b, b)) + z = x + y + @test z isa Tangent{Tuple} + @test first(z) isa ChainRulesCore.NotImplemented + end + + a = Tangent{NamedTuple{(:x,)}}(; x=1.5) + b = Tangent{NamedTuple{(:x,)}}(; x=@not_implemented("")) + for (x, y) in ((a, b), (b, a), (b, b)) + z = x + y + @test z isa Tangent{NamedTuple{(:x,)}} + @test z.x isa ChainRulesCore.NotImplemented + end + + a = Tangent{Dict}(Dict(:x => 1.5)) + b = Tangent{Dict}(Dict(:x => @not_implemented(""))) + for (x, y) in ((a, b), (b, a), (b, b)) + z = x + y + @test z isa Tangent{Dict} + @test z[:x] isa ChainRulesCore.NotImplemented + end + end end - @testset "Fields of type NotImplemented" begin - CFoo = Tangent{Foo} - a = CFoo(; x=1.5) - b = CFoo(; x=@not_implemented("")) - for (x, y) in ((a, b), (b, a), (b, b)) - z = x + y - @test z isa CFoo - @test z.x isa ChainRulesCore.NotImplemented + @testset "+ with Primals" begin + @testset "Structs" begin + @test Foo(3.5, 1.5) + Tangent{Foo}(; x=2.5) == Foo(6.0, 1.5) + @test Tangent{Foo}(; x=2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) + @test (@ballocated Bar(0.5) + 
Tangent{Bar}(; x=0.5)) == 0 end - a = Tangent{Tuple}(1.5) - b = Tangent{Tuple}(@not_implemented("")) - for (x, y) in ((a, b), (b, a), (b, b)) - z = x + y - @test z isa Tangent{Tuple} - @test first(z) isa ChainRulesCore.NotImplemented + @testset "Tuples" begin + @test Tangent{Tuple{}}() + () == () + @test ((1.0, 2.0) + Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) == (2.0, 3.0) + @test (Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) + (1.0, 2.0) == (2.0, 3.0) end - a = Tangent{NamedTuple{(:x,)}}(; x=1.5) - b = Tangent{NamedTuple{(:x,)}}(; x=@not_implemented("")) - for (x, y) in ((a, b), (b, a), (b, b)) - z = x + y - @test z isa Tangent{NamedTuple{(:x,)}} - @test z.x isa ChainRulesCore.NotImplemented + @testset "NamedTuple" begin + ntx = (; a=1.5) + @test Tangent{typeof(ntx)}(; ntx...) + ntx == (; a=3.0) + + nty = (; a=1.5, b=0.5) + @test Tangent{typeof(nty)}(; nty...) + nty == (; a=3.0, b=1.0) end - a = Tangent{Dict}(Dict(:x => 1.5)) - b = Tangent{Dict}(Dict(:x => @not_implemented(""))) - for (x, y) in ((a, b), (b, a), (b, b)) - z = x + y - @test z isa Tangent{Dict} - @test z[:x] isa ChainRulesCore.NotImplemented + @testset "Dicts" begin + d_primal = Dict(4 => 3.0, 3 => 2.0) + d_tangent = Tangent{typeof(d_primal)}(Dict(4 => 5.0)) + @test d_primal + d_tangent == Dict(4 => 3.0 + 5.0, 3 => 2.0) end end - end - @testset "+ with Primals" begin - @testset "Structs" begin - @test Foo(3.5, 1.5) + Tangent{Foo}(; x=2.5) == Foo(6.0, 1.5) - @test Tangent{Foo}(; x=2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) - @test (@ballocated Bar(0.5) + Tangent{Bar}(; x=0.5)) == 0 - end + @testset "+ with Primals, with inner constructor" begin + value = StructWithInvariant(10.0) + diff = Tangent{StructWithInvariant}(; x=2.0, x2=6.0) - @testset "Tuples" begin - @test Tangent{Tuple{}}() + () == () - @test ((1.0, 2.0) + Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) == (2.0, 3.0) - @test (Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) + (1.0, 2.0) == (2.0, 3.0) - end + @testset "with and without debug mode" begin + 
@assert ChainRulesCore.debug_mode() == false + @test_throws MethodError (value + diff) + @test_throws MethodError (diff + value) - @testset "NamedTuple" begin - ntx = (; a=1.5) - @test Tangent{typeof(ntx)}(; ntx...) + ntx == (; a=3.0) + ChainRulesCore.debug_mode() = true # enable debug mode + @test_throws ChainRulesCore.PrimalAdditionFailedException (value + diff) + @test_throws ChainRulesCore.PrimalAdditionFailedException (diff + value) + ChainRulesCore.debug_mode() = false # disable it again + end - nty = (; a=1.5, b=0.5) - @test Tangent{typeof(nty)}(; nty...) + nty == (; a=3.0, b=1.0) + # Now we define constuction for ChainRulesCore.jl's purposes: + # It is going to determine the root quanity of the invarient + function ChainRulesCore.construct(::Type{StructWithInvariant}, nt::NamedTuple) + x = (nt.x + nt.x2 / 2) / 2 + return StructWithInvariant(x) + end + @test value + diff == StructWithInvariant(12.5) + @test diff + value == StructWithInvariant(12.5) end - @testset "Dicts" begin - d_primal = Dict(4 => 3.0, 3 => 2.0) - d_tangent = Tangent{typeof(d_primal)}(Dict(4 => 5.0)) - @test d_primal + d_tangent == Dict(4 => 3.0 + 5.0, 3 => 2.0) + @testset "differential arithmetic" begin + c = Tangent{Foo}(; y=1.5, x=2.5) + + @test NoTangent() * c == NoTangent() + @test c * NoTangent() == NoTangent() + @test dot(NoTangent(), c) == NoTangent() + @test dot(c, NoTangent()) == NoTangent() + @test norm(Tangent{Foo}(; y=c.y, x=NoTangent())) == c.y + @test norm(NoTangent(), Inf) == 0 + + @test ZeroTangent() * c == ZeroTangent() + @test c * ZeroTangent() == ZeroTangent() + @test dot(ZeroTangent(), c) == ZeroTangent() + @test dot(c, ZeroTangent()) == ZeroTangent() + @test norm(ZeroTangent()) == 0 + @test norm(ZeroTangent(), 0.4) == 0 + + @test true * c === c + @test c * true === c + + t = @thunk 2 + @test t * c == 2 * c + @test c * t == c * 2 end - end - @testset "+ with Primals, with inner constructor" begin - value = StructWithInvariant(10.0) - diff = 
Tangent{StructWithInvariant}(; x=2.0, x2=6.0) + @testset "-Tangent" begin + t = Tangent{Foo}(; x=1.0, y=-2.0) + @test -t == Tangent{Foo}(; x=-1.0, y=2.0) + @test -1.0 * t == -t + end - @testset "with and without debug mode" begin - @assert ChainRulesCore.debug_mode() == false - @test_throws MethodError (value + diff) - @test_throws MethodError (diff + value) + @testset "scaling" begin + @test ( + 2 * Tangent{Foo}(; y=1.5, x=2.5) == + Tangent{Foo}(; y=3.0, x=5.0) == + Tangent{Foo}(; y=1.5, x=2.5) * 2 + ) + @test ( + 2 * Tangent{Tuple{Float64,Float64}}(2.0, 4.0) == + Tangent{Tuple{Float64,Float64}}(4.0, 8.0) == + Tangent{Tuple{Float64,Float64}}(2.0, 4.0) * 2 + ) + d = Tangent{Dict}(Dict(4 => 3.0)) + two_d = Tangent{Dict}(Dict(4 => 2 * 3.0)) + @test 2 * d == two_d == d * 2 - ChainRulesCore.debug_mode() = true # enable debug mode - @test_throws ChainRulesCore.PrimalAdditionFailedException (value + diff) - @test_throws ChainRulesCore.PrimalAdditionFailedException (diff + value) - ChainRulesCore.debug_mode() = false # disable it again + @test_throws MethodError [1, 2] * Tangent{Foo}(; y=1.5, x=2.5) + @test_throws MethodError [1, 2] * d + @test_throws MethodError Tangent{Foo}(; y=1.5, x=2.5) * @thunk [1 2; 3 4] end - # Now we define constuction for ChainRulesCore.jl's purposes: - # It is going to determine the root quanity of the invarient - function ChainRulesCore.construct(::Type{StructWithInvariant}, nt::NamedTuple) - x = (nt.x + nt.x2 / 2) / 2 - return StructWithInvariant(x) + @testset "iszero" begin + @test iszero(Tangent{Foo}()) + @test iszero(Tangent{Tuple{}}()) + @test iszero(Tangent{Foo}(; x=ZeroTangent())) + @test iszero(Tangent{Foo}(; y=0.0)) + @test iszero(Tangent{Foo}(; x=Tangent{Tuple{}}(), y=0.0)) + + @test !iszero(Tangent{Foo}(; y=3.0)) end - @test value + diff == StructWithInvariant(12.5) - @test diff + value == StructWithInvariant(12.5) - end - @testset "differential arithmetic" begin - c = Tangent{Foo}(; y=1.5, x=2.5) - - @test NoTangent() * c == 
NoTangent() - @test c * NoTangent() == NoTangent() - @test dot(NoTangent(), c) == NoTangent() - @test dot(c, NoTangent()) == NoTangent() - @test norm(Tangent{Foo}(; y=c.y, x=NoTangent())) == c.y - @test norm(NoTangent(), Inf) == 0 - - @test ZeroTangent() * c == ZeroTangent() - @test c * ZeroTangent() == ZeroTangent() - @test dot(ZeroTangent(), c) == ZeroTangent() - @test dot(c, ZeroTangent()) == ZeroTangent() - @test norm(ZeroTangent()) == 0 - @test norm(ZeroTangent(), 0.4) == 0 - - @test true * c === c - @test c * true === c - - t = @thunk 2 - @test t * c == 2 * c - @test c * t == c * 2 - end + @testset "show" begin + @test repr(Tangent{Foo}(; x=1)) == "Tangent{Foo}(x = 1,)" + # check for exact regex match not occurence( `^...$`) + # and allowing optional whitespace (`\s?`) + @test occursin( + r"^Tangent{Tuple{Int64,\s?Int64}}\(1,\s?2\)$", + repr(Tangent{Tuple{Int64,Int64}}(1, 2)), + ) - @testset "-Tangent" begin - t = Tangent{Foo}(; x=1.0, y=-2.0) - @test -t == Tangent{Foo}(; x=-1.0, y=2.0) - @test -1.0 * t == -t - end + @test repr(Tangent{Foo}()) == "Tangent{Foo}()" + end - @testset "scaling" begin - @test ( - 2 * Tangent{Foo}(; y=1.5, x=2.5) == - Tangent{Foo}(; y=3.0, x=5.0) == - Tangent{Foo}(; y=1.5, x=2.5) * 2 - ) - @test ( - 2 * Tangent{Tuple{Float64,Float64}}(2.0, 4.0) == - Tangent{Tuple{Float64,Float64}}(4.0, 8.0) == - Tangent{Tuple{Float64,Float64}}(2.0, 4.0) * 2 - ) - d = Tangent{Dict}(Dict(4 => 3.0)) - two_d = Tangent{Dict}(Dict(4 => 2 * 3.0)) - @test 2 * d == two_d == d * 2 + @testset "internals" begin + @testset "Can't do backing on primative type" begin + @test_throws Exception ChainRulesCore.backing(1.4) + end - @test_throws MethodError [1, 2] * Tangent{Foo}(; y=1.5, x=2.5) - @test_throws MethodError [1, 2] * d - @test_throws MethodError Tangent{Foo}(; y=1.5, x=2.5) * @thunk [1 2; 3 4] - end + @testset "Internals don't allocate a ton" begin + bk = (; x=1.0, y=2.0) + VERSION >= v"1.5" && + @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 
32 - @testset "iszero" begin - @test iszero(Tangent{Foo}()) - @test iszero(Tangent{Tuple{}}()) - @test iszero(Tangent{Foo}(; x=ZeroTangent())) - @test iszero(Tangent{Foo}(; y=0.0)) - @test iszero(Tangent{Foo}(; x=Tangent{Tuple{}}(), y=0.0)) + # weaker version of the above (which should pass on all versions) + @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 48 + @test (@ballocated ChainRulesCore.elementwise_add($bk, $bk)) <= 48 + end + end - @test !iszero(Tangent{Foo}(; y=3.0)) + @testset "non-same-typed differential arithmetic" begin + nt = (; a=1, b=2.0) + c = Tangent{typeof(nt)}(; a=NoTangent(), b=0.1) + @test nt + c == (; a=1, b=2.1) + end + + @testset "printing" begin + t5 = Tuple(rand(3)) + nt3 = (x=t5, y=t5, z=nothing) + tang = ProjectTo(nt3)(nt3) # moderately complicated Tangent + @test contains(sprint(show, tang), "...}(x = Tangent") # gets shortened + @test contains(sprint(show, tang), sprint(show, tang.x)) # inner piece appears whole + end end - @testset "show" begin - @test repr(Tangent{Foo}(; x=1)) == "Tangent{Foo}(x = 1,)" - # check for exact regex match not occurence( `^...$`) - # and allowing optional whitespace (`\s?`) - @test occursin( - r"^Tangent{Tuple{Int64,\s?Int64}}\(1,\s?2\)$", - repr(Tangent{Tuple{Int64,Int64}}(1, 2)), + @testset "MutableTangent" begin + mutable struct MDemo + x::Float64 + end + function ChainRulesCore.frule( + (_, ȯbj, _, ẋ), ::typeof(setfield!), obj::MDemo, field, x ) - - @test repr(Tangent{Foo}()) == "Tangent{Foo}()" - end - - @testset "internals" begin - @testset "Can't do backing on primative type" begin - @test_throws Exception ChainRulesCore.backing(1.4) + y = setfield!(obj, field, x) + ẏ = setproperty!(ȯbj, field, ẋ) + return y, ẏ end - @testset "Internals don't allocate a ton" begin - bk = (; x=1.0, y=2.0) - VERSION >= v"1.5" && - @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 - - # weaker version of the above (which should pass on all versions) - @test 
(@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 48 - @test (@ballocated ChainRulesCore.elementwise_add($bk, $bk)) <= 48 + @testset "usecase" begin + obj = MDemo(99.0) + ∂obj = MutableTangent{MDemo}(; x=1.5) + frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0) + @test ∂obj.x == 10.0 + @test obj.x == 95.0 + + frule((NoTangent(), ∂obj, NoTangent(), 20.0), setfield!, obj, 1, 96.0) + @test ∂obj.x == 20.0 + @test getproperty(∂obj, 1) == 20.0 + @test obj.x == 96.0 end - end - @testset "non-same-typed differential arithmetic" begin - nt = (; a=1, b=2.0) - c = Tangent{typeof(nt)}(; a=NoTangent(), b=0.1) - @test nt + c == (; a=1, b=2.1) - end - - @testset "printing" begin - t5 = Tuple(rand(3)) - nt3 = (x=t5, y=t5, z=nothing) - tang = ProjectTo(nt3)(nt3) # moderately complicated Tangent - @test contains(sprint(show, tang), "...}(x = Tangent") # gets shortened - @test contains(sprint(show, tang), sprint(show, tang.x)) # inner piece appears whole - end -end + @testset "== and hash" begin + @test MutableTangent{MDemo}(; x=1.0f0) == MutableTangent{MDemo}(; x=1.0) + @test MutableTangent{MDemo}(; x=1.0) == MutableTangent{MDemo}(; x=1.0f0) + @test MutableTangent{MDemo}(; x=2.0) != MutableTangent{MDemo}(; x=1.0) + @test MutableTangent{MDemo}(; x=1.0) != MutableTangent{MDemo}(; x=2.0) -@testset "MutableTangent" begin - mutable struct MDemo - x::Float64 - end - function ChainRulesCore.frule( - (_, ȯbj, _, ẋ), ::typeof(setfield!), obj::MDemo, field, x - ) - y = setfield!(obj, field, x) - ẏ = setproperty!(ȯbj, field, ẋ) - return y, ẏ - end + nt = (; x=1.0) + @test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(; x=1.0) - @testset "usecase" begin - obj = MDemo(99.0) - ∂obj = MutableTangent{MDemo}(; x=1.5) - frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0) - @test ∂obj.x == 10.0 - @test obj.x == 95.0 - - frule((NoTangent(), ∂obj, NoTangent(), 20.0), setfield!, obj, 1, 96.0) - @test ∂obj.x == 20.0 - @test getproperty(∂obj, 1) == 
20.0 - @test obj.x == 96.0 - end - - @testset "== and hash" begin - @test MutableTangent{MDemo}(; x=1.0f0) == MutableTangent{MDemo}(; x=1.0) - @test MutableTangent{MDemo}(; x=1.0) == MutableTangent{MDemo}(; x=1.0f0) - @test MutableTangent{MDemo}(; x=2.0) != MutableTangent{MDemo}(; x=1.0) - @test MutableTangent{MDemo}(; x=1.0) != MutableTangent{MDemo}(; x=2.0) - - nt = (; x=1.0) - @test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(; x=1.0) - - @test hash(MutableTangent{MDemo}(; x=1.0f0)) == hash(MutableTangent{MDemo}(; x=1.0)) - end + @test hash(MutableTangent{MDemo}(; x=1.0f0)) == hash(MutableTangent{MDemo}(; x=1.0)) + end - @testset "Mutation" begin - v = MutableTangent{MFoo}(; x=1.5, y=2.4) - v.x = 1.6 - @test v == MutableTangent{MFoo}(; x=1.6, y=2.4) - v.y = [1.0, 2.0] # change type, because primal can change type - @test v == MutableTangent{MFoo}(; x=1.6, y=[1.0, 2.0]) + @testset "Mutation" begin + v = MutableTangent{MFoo}(; x=1.5, y=2.4) + v.x = 1.6 + @test v == MutableTangent{MFoo}(; x=1.6, y=2.4) + v.y = [1.0, 2.0] # change type, because primal can change type + @test v == MutableTangent{MFoo}(; x=1.6, y=[1.0, 2.0]) + end end -end -@testset "map" begin - @testset "Tangent" begin - ∂foo = Tangent{Foo}(; x=1.5, y=2.4) - @test map(v -> 2 * v, ∂foo) == Tangent{Foo}(; x=3.0, y=4.8) + @testset "map" begin + @testset "Tangent" begin + ∂foo = Tangent{Foo}(; x=1.5, y=2.4) + @test map(v -> 2 * v, ∂foo) == Tangent{Foo}(; x=3.0, y=4.8) - ∂foo = Tangent{Foo}(; x=1.5) - @test map(v -> 2 * v, ∂foo) == Tangent{Foo}(; x=3.0) - end - @testset "MutableTangent" begin - ∂foo = MutableTangent{MFoo}(; x=1.5, y=2.4) - ∂foo2 = map(v -> 2 * v, ∂foo) - @test ∂foo2 == MutableTangent{MFoo}(; x=3.0, y=4.8) - # Check can still be mutated to new typ - ∂foo2.y = [1.0, 2.0] - @test ∂foo2 == MutableTangent{MFoo}(; x=3.0, y=[1.0, 2.0]) + ∂foo = Tangent{Foo}(; x=1.5) + @test map(v -> 2 * v, ∂foo) == Tangent{Foo}(; x=3.0) + end + @testset "MutableTangent" begin + ∂foo = 
MutableTangent{MFoo}(; x=1.5, y=2.4) + ∂foo2 = map(v -> 2 * v, ∂foo) + @test ∂foo2 == MutableTangent{MFoo}(; x=3.0, y=4.8) + # Check can still be mutated to new typ + ∂foo2.y = [1.0, 2.0] + @test ∂foo2 == MutableTangent{MFoo}(; x=3.0, y=[1.0, 2.0]) + end end end \ No newline at end of file From f5efd7dbcb46970a7b0ffe1366834c8584191ee9 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 29 Dec 2023 15:22:20 +0800 Subject: [PATCH 24/36] Support types that have no tangent space in zero_tangent --- src/tangent_types/abstract_zero.jl | 10 +++++-- test/tangent_types/abstract_zero.jl | 45 +++++++++++++++++++---------- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index a0f8d1f82..8e31ae492 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -96,10 +96,11 @@ struct NoTangent <: AbstractZero end zero_tangent(primal) This returns an appropriate zero tangent suitable for accumulating tangents of the primal. -For mutable composites types this is a structural []`MutableTangent`](@ref) +For mutable composites types this is a structural [`MutableTangent`](@ref) For `Array`s, it is applied recursively for each element. -For immutable types, this is simply [`ZeroTangent()`](@ref) as accumulation is default out-of-place for contexts where mutation does not apply. -(Where mutation is not to be supported even for mutable types, then [`ZeroTangent()`](@ref) should be used for everything) +For other types, in particular immutable types, we do not make promises beyond that it will be `iszero` +and suitable for accumulating against. +In general though, it is more likely to produce a structural tangent. !!! warning Exprimental `zero_tangent`is an experimental feature, and is part of the mutation support featureset. 
@@ -110,7 +111,10 @@ function zero_tangent end zero_tangent(x::Number) = zero(x) +zero_tangent(::Type) = NoTangent() + @generated function zero_tangent(primal) + fieldcount(primal) == 0 && return NoTangent() # no tangent space at all, no need for structural zero. zfield_exprs = map(fieldnames(primal)) do fname fval = :( if isdefined(primal, $(QuoteNode(fname))) diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 81114511a..960e88d99 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -162,25 +162,38 @@ end @testset "zero_tangent" begin - @test zero_tangent(1) === 0 - @test zero_tangent(1.0) === 0.0 - mutable struct MutDemo - x::Float64 - end - struct Demo - x::Float64 - end - @test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo} - @test iszero(zero_tangent(MutDemo(1.5))) + @testset "basics" begin + @test zero_tangent(1) === 0 + @test zero_tangent(1.0) === 0.0 + mutable struct MutDemo + x::Float64 + end + struct Demo + x::Float64 + end + @test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo} + @test iszero(zero_tangent(MutDemo(1.5))) - @test zero_tangent((; a=1)) isa Tangent{typeof((; a = 1))} - @test zero_tangent(Demo(1.2)) isa Tangent{Demo} - @test zero_tangent(Demo(1.2)).x === 0.0 + @test zero_tangent((; a=1)) isa Tangent{typeof((; a = 1))} + @test zero_tangent(Demo(1.2)) isa Tangent{Demo} + @test zero_tangent(Demo(1.2)).x === 0.0 - @test zero_tangent([1.0, 2.0]) == [0.0, 0.0] - @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] + @test zero_tangent([1.0, 2.0]) == [0.0, 0.0] + @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] + + @test zero_tangent((1.0, 2.0)) == Tangent{Tuple{Float64,Float64}}(0.0, 0.0) + end + + @testset "Weird types" begin + @test iszero(zero_tangent(typeof(Int))) # primative type + @test iszero(zero_tangent(typeof(Base.RefValue))) # struct + @test iszero(zero_tangent(Vector)) # UnionAll + @test 
iszero(zero_tangent(Union{Int, Float64})) # Union + @test iszero(zero_tangent(:abc)) + @test iszero(zero_tangent("abc")) + @test iszero(zero_tangent(sin)) + end - @test zero_tangent((1.0, 2.0)) == Tangent{Tuple{Float64,Float64}}(0.0, 0.0) @testset "undef elements Vector" begin x = Vector{Vector{Float64}}(undef, 3) x[2] = [1.0, 2.0] From 8a54fae8e9cf179ab224bf744cff4a34c0966ef3 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 16 Jan 2024 13:59:27 +0800 Subject: [PATCH 25/36] define zero_tangent for Tangent --- src/ChainRulesCore.jl | 2 +- src/tangent_types/abstract_zero.jl | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index bda392497..286f71db2 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -19,9 +19,9 @@ export StructuralTangent, Tangent, MutableTangent, NoTangent, InplaceableThunk, include("debug_mode.jl") include("tangent_types/abstract_tangent.jl") +include("tangent_types/structural_tangent.jl") include("tangent_types/abstract_zero.jl") include("tangent_types/thunks.jl") -include("tangent_types/structural_tangent.jl") include("tangent_types/notimplemented.jl") include("tangent_arithmetic.jl") diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 8e31ae492..86ed92523 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -113,6 +113,9 @@ zero_tangent(x::Number) = zero(x) zero_tangent(::Type) = NoTangent() +zero_tangent(x::Tangent) = ZeroTangent() +# TODO: zero_tangent(x::MutableTangent) + @generated function zero_tangent(primal) fieldcount(primal) == 0 && return NoTangent() # no tangent space at all, no need for structural zero. 
zfield_exprs = map(fieldnames(primal)) do fname From b67686da84adc2ffc040c20be68a6e3f64de7136 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 17 Jan 2024 11:43:59 +0800 Subject: [PATCH 26/36] Add structural zero tangent code for higher order --- src/tangent_types/abstract_zero.jl | 11 +++++++++-- test/tangent_types/abstract_zero.jl | 8 ++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 86ed92523..61fc05b6f 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -113,8 +113,15 @@ zero_tangent(x::Number) = zero(x) zero_tangent(::Type) = NoTangent() -zero_tangent(x::Tangent) = ZeroTangent() -# TODO: zero_tangent(x::MutableTangent) +function zero_tangent(x::MutableTangent{P}) where P + zb = backing(zero_tangent(backing(x))) + return MutableTangent{P}(zb) +end + +function zero_tangent(x::Tangent{P}) where P + zb = backing(zero_tangent(backing(x))) + return Tangent{P, typeof(zb)}(zb) +end @generated function zero_tangent(primal) fieldcount(primal) == 0 && return NoTangent() # no tangent space at all, no need for structural zero. diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 960e88d99..a4df83ebf 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -182,6 +182,14 @@ end @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] @test zero_tangent((1.0, 2.0)) == Tangent{Tuple{Float64,Float64}}(0.0, 0.0) + + # Higher order + # StructuralTangents are valid tangents for themselves (just like Numbers) + # and indeed we prefer that, otherwise higher order structural tangents are kinda + # nightmarishly complex types. 
+ @test zero_tangent(zero_tangent(Demo(1.5))) == zero_tangent(Demo(1.5)) + @test zero_tangent(zero_tangent((1.5, 2.5))) == Tangent{Tuple{Float64, Float64}}(0.0, 0.0) + @test zero_tangent(zero_tangent(MutDemo(1.5))) == zero_tangent(MutDemo(1.5)) end @testset "Weird types" begin From b3a4d57f5213395546b2fdde691e6ad153df7b9b Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 17 Jan 2024 12:05:02 +0800 Subject: [PATCH 27/36] Formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/tangent_types/abstract_zero.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 61fc05b6f..d4f17d852 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -113,7 +113,7 @@ zero_tangent(x::Number) = zero(x) zero_tangent(::Type) = NoTangent() -function zero_tangent(x::MutableTangent{P}) where P +function zero_tangent(x::MutableTangent{P}) where {P} zb = backing(zero_tangent(backing(x))) return MutableTangent{P}(zb) end From f481d05467aa9fe32f60ea14af1ec7096fb3b15d Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 17 Jan 2024 13:32:57 +0800 Subject: [PATCH 28/36] overload show for mutable tangent --- src/tangent_types/structural_tangent.jl | 5 ++++- test/tangent_types/structural_tangent.jl | 5 +++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index c7ae1b1b5..7730a6215 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -334,7 +334,10 @@ Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P,Q} = false Base.hash(a::Tangent, h::UInt) = Base.hash(backing(canonicalize(a)), h) -function Base.show(io::IO, tangent::Tangent{P}) where {P} +function Base.show(io::IO, tangent::StructuralTangent{P}) where {P} + if tangent isa MutableTangent + print(io, "Mutable") + end 
print(io, "Tangent{") str = sprint(show, P, context = io) i = findfirst('{', str) diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index 16d702c14..b902e89e4 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -398,6 +398,11 @@ end ) @test repr(Tangent{Foo}()) == "Tangent{Foo}()" + + @test ==( + repr(MutableTangent{MFoo}((;x=1.5, y=[1.0, 2.0]))), + "MutableTangent{MFoo}(x = 1.5, y = [1.0, 2.0])" + ) end @testset "internals" begin From da8c20434f38605471b0e33ccf2e9d77cc196d02 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 19 Jan 2024 12:09:12 +0800 Subject: [PATCH 29/36] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/tangent_types/abstract_zero.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index d4f17d852..15aa00d42 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -118,9 +118,9 @@ function zero_tangent(x::MutableTangent{P}) where {P} return MutableTangent{P}(zb) end -function zero_tangent(x::Tangent{P}) where P +function zero_tangent(x::Tangent{P}) where {P} zb = backing(zero_tangent(backing(x))) - return Tangent{P, typeof(zb)}(zb) + return Tangent{P,typeof(zb)}(zb) end @generated function zero_tangent(primal) From 26138a9a1d1b8bc895f7012171e32b4148e0bed5 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 23 Jan 2024 17:08:56 +0800 Subject: [PATCH 30/36] move show code to `Common` area --- src/tangent_types/structural_tangent.jl | 53 +++++++++++++------------ 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 7730a6215..a469d9f1b 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -140,6 +140,33 @@ function 
Base.map(f, tangent::StructuralTangent{P}) where {P} end end + +function Base.show(io::IO, tangent::StructuralTangent{P}) where {P} + if tangent isa MutableTangent + print(io, "Mutable") + end + print(io, "Tangent{") + str = sprint(show, P, context = io) + i = findfirst('{', str) + if isnothing(i) + print(io, str) + else # for Tangent{T{A,B,C}}(stuff), print {A,B,C} in grey, and trim this part if longer than a line: + print(io, str[1:prevind(str, i)]) + if length(str) < 80 + printstyled(io, str[i:end], color=:light_black) + else + printstyled(io, str[i:prevind(str, 80)], "...", color=:light_black) + end + end + print(io, "}") + if isempty(backing(tangent)) + print(io, "()") # so it doesn't show `NamedTuple()` + else + # allow Tuple or NamedTuple `show` to do the rendering of brackets etc + show(io, backing(tangent)) + end +end + """ backing(x) @@ -334,32 +361,6 @@ Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P,Q} = false Base.hash(a::Tangent, h::UInt) = Base.hash(backing(canonicalize(a)), h) -function Base.show(io::IO, tangent::StructuralTangent{P}) where {P} - if tangent isa MutableTangent - print(io, "Mutable") - end - print(io, "Tangent{") - str = sprint(show, P, context = io) - i = findfirst('{', str) - if isnothing(i) - print(io, str) - else # for Tangent{T{A,B,C}}(stuff), print {A,B,C} in grey, and trim this part if longer than a line: - print(io, str[1:prevind(str, i)]) - if length(str) < 80 - printstyled(io, str[i:end], color=:light_black) - else - printstyled(io, str[i:prevind(str, 80)], "...", color=:light_black) - end - end - print(io, "}") - if isempty(backing(tangent)) - print(io, "()") # so it doesn't show `NamedTuple()` - else - # allow Tuple or NamedTuple `show` to do the rendering of brackets etc - show(io, backing(tangent)) - end -end - Base.iszero(::Tangent{<:,NamedTuple{}}) = true Base.iszero(::Tangent{<:,Tuple{}}) = true From e8865897c7cd8b9dad20866bce8d773105b8c951 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 23 Jan 2024 17:22:36 
+0800 Subject: [PATCH 31/36] docs more consistent --- .../superpowers/mutation_support.md | 23 +++++++++++++------ src/tangent_types/abstract_zero.jl | 3 ++- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/docs/src/rule_author/superpowers/mutation_support.md b/docs/src/rule_author/superpowers/mutation_support.md index 55629166a..a4dec8ab8 100644 --- a/docs/src/rule_author/superpowers/mutation_support.md +++ b/docs/src/rule_author/superpowers/mutation_support.md @@ -1,6 +1,6 @@ # Mutation Support -ChainRulesCore.jl offers experimental support for mutation, targetting use in forward mode AD. +ChainRulesCore.jl offers experimental support for mutation, targeting use in forward mode AD. (Mutation support in reverse mode AD is more complicated and will likely require more changes to the interface) !!! warning "Experimental" @@ -17,18 +17,23 @@ It is required to be a structural tangent, having one tangent for each field of Technically, not all `mutable struct`s need to use `MutableTangent` to represent their tangents. Just like not all `struct`s need to use `Tangent`s. Common examples away from this are natural tangent types like for arrays. -However, if one is setting up to use a custom tangent type for this it is surficiently off the beated path that we can not provide much guidance. +However, if one is setting up to use a custom tangent type for this it is sufficiently off the beaten path that we can not provide much guidance. ## `zero_tangent` The [`zero_tangent`](@ref) function functions to give you a zero (i.e. additive identity) for any primal value. The [`ZeroTangent`](@ref) type also does this. -The difference is that [`zero_tangent`](@ref) is (where possible) a full structural tangent mirroring the structure of the primal. +The difference is that [`zero_tangent`](@ref) is in general a full structural tangent mirroring the structure of the primal.
+To be technical, the promise of [`zero_tangent`](@ref) is that it will be a value that supports mutation. +However, in practice[^1] this is achieved through a structural tangent. For mutation support this is important, since it means that there is mutable memory available in the tangent to be mutated when the primal changes. To support this you thus need to make sure your zeros are created in various places with [`zero_tangent`](@ref) rather than []`ZeroTangent`](@ref). -It is also useful for reasons of type stability, since it is always a structural tangent. -For this reason AD system implementors might chose to use this to create the tangent for all literal values they encounter, mutable or not. + + +It is also useful for reasons of type stability, since it forces a consistent type (generally a structural tangent) for any given primal type. +For this reason AD system implementors might choose to use this to create the tangent for all literal values they encounter, mutable or not, +and to process the output of `frule`s to convert [`ZeroTangent`](@ref) into corresponding [`zero_tangent`](@ref)s. ## Writing a frule for a mutating function It is relatively straight forward to write a frule for a mutating function. @@ -41,7 +46,7 @@ There are a few key points to follow: ### Example For example, consider the primal function with: 1. takes two `Ref`s -2. doubles the first one inplace +2. doubles the first one in place 3. overwrites the second one's value with the literal 5.0 4. returns the first one @@ -70,4 +75,8 @@ function ChainRulesCore.frule((ȧ, ḃ), ::typeof(foo!), a::Base.RefValue, b::B end ``` -Then assuming the AD system does its part to makes sure you are indeed given mutable values to mutate (i.e. those `@assert`ions are true) then all is well and this rule will make mutation correct. \ No newline at end of file +Then assuming the AD system does its part to make sure you are indeed given mutable values to mutate (i.e.
those `@assert`ions are true) then all is well and this rule will make mutation correct. + +[^1]: + Further, it is hard to achieve this promise of allowing mutation to be supported without returning a structural tangent. + Except in the special case where the struct is not mutable and has no nested fields that are mutable. \ No newline at end of file diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 15aa00d42..f921db29d 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -100,7 +100,8 @@ For mutable composites types this is a structural [`MutableTangent`](@ref) For `Array`s, it is applied recursively for each element. For other types, in particular immutable types, we do not make promises beyond that it will be `iszero` and suitable for accumulating against. -In general though, it is more likely to produce a structural tangent. +For types without a tangent space (e.g. singleton structs) this returns `NoTangent()`. +In general, it is more likely to produce a structural tangent. !!! warning Exprimental `zero_tangent`is an experimental feature, and is part of the mutation support featureset.
From d3380bca2656fce41cd2cb9a37c4a2f8b7ae0aea Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 23 Jan 2024 17:31:58 +0800 Subject: [PATCH 32/36] Update src/tangent_types/structural_tangent.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/tangent_types/structural_tangent.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index a469d9f1b..04d93800f 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -94,7 +94,6 @@ struct MutableTangent{P,F} <: StructuralTangent{P} end end - #################################################################### # StructuralTangent Common From 501857dd7d50085d2aaf0e7dfe1287230b5255f5 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 23 Jan 2024 17:32:32 +0800 Subject: [PATCH 33/36] Update test/tangent_types/structural_tangent.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/tangent_types/structural_tangent.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index b902e89e4..c177b05f4 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -400,8 +400,8 @@ end @test repr(Tangent{Foo}()) == "Tangent{Foo}()" @test ==( - repr(MutableTangent{MFoo}((;x=1.5, y=[1.0, 2.0]))), - "MutableTangent{MFoo}(x = 1.5, y = [1.0, 2.0])" + repr(MutableTangent{MFoo}((; x=1.5, y=[1.0, 2.0]))), + "MutableTangent{MFoo}(x = 1.5, y = [1.0, 2.0])", ) end From 2d61f416049ae5c79c79ecdfb430360f757818ae Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 25 Jan 2024 16:37:11 +0800 Subject: [PATCH 34/36] Add broken tests for Aliasing and Cyclic references --- test/tangent_types/abstract_zero.jl | 48 +++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff 
--git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index a4df83ebf..245d9a29d 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -275,4 +275,52 @@ end @test d.z == [2.0, 3.0] @test d.z isa SubArray end + + + @testset "aliasing" begin + a = Base.RefValue(1.5) + b = (a, 1.0, a) + db = zero_tangent(b) + @test iszero(db) + @test_broken db[1] === db[3] + @test db[2] == 0.0 + + x = [1.5] + y = [x, [1.0], x] + dy = zero_tangent(y) + @test iszero(dy) + @test_broken dy[1] === dy[3] + @test dy[2] == [0.0] + end + + @testset "cyclic references" begin + mutable struct Link + data::Float64 + next::Link + Link(data) = new(data) + end + + lk = Link(1.5) + lk.next = lk + + @test_broken d = zero_tangent(lk) + @test_broken d.data == 0.0 + @test_broken d.next === d + + struct CarryingArray + x::Vector + end + ca = CarryingArray(Any[1.5]) + push!(ca.x, ca) + @test_broken d_ca = zero_tangent(ca) + @test_broken d_ca[1] == 0.0 + @test_broken d_ca[2] === _ca + + # Idea: check if typeof(xs) <: eltype(xs), if so need to cache it before computing + xs = Any[1.5] + push!(xs, xs) + @test_broken d_xs = zero_tangent(xs) + @test_broken d_xs[1] == 0.0 + @test_broken d_xs[2] == d_xs + end end From 95e63d0dde918ac2d2662d0b574295ab71d51387 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 25 Jan 2024 17:01:21 +0800 Subject: [PATCH 35/36] improve docs --- docs/src/rule_author/superpowers/mutation_support.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/rule_author/superpowers/mutation_support.md b/docs/src/rule_author/superpowers/mutation_support.md index a4dec8ab8..b7a3b69ec 100644 --- a/docs/src/rule_author/superpowers/mutation_support.md +++ b/docs/src/rule_author/superpowers/mutation_support.md @@ -41,7 +41,7 @@ There are a few key points to follow: - There must be a mutable tangent input for every mutated primal input - When the primal value is changed, the corresponding change must 
be made to its tangent partner - When a value is returned, return its partnered tangent. - + - If two primals alias, then their tangents must also alias. ### Example For example, consider the primal function with: @@ -61,14 +61,14 @@ end The frule for this would be: ```julia -function ChainRulesCore.frule((ȧ, ḃ), ::typeof(foo!), a::Base.RefValue, b::Base.RefValue) +function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(foo!), a::Base.RefValue, b::Base.RefValue) @assert ȧ isa MutableTangent{typeof(a)} @assert ḃ isa MutableTangent{typeof(b)} a[] *= 2 ȧ.x *= 2 # `.x` is the field that lives behind RefValues - b[]=5.0 + b[] = 5.0 ḃ.x = zero_tangent(5.0) # or since we know that the zero for a Float64 is zero could write `ḃ.x = 0.0` return a, ȧ From 73b7508b4b4df59e3e0754386805f1cf78d8dcd3 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 25 Jan 2024 18:51:21 +0800 Subject: [PATCH 36/36] stronger statement about aliasing --- docs/src/rule_author/superpowers/mutation_support.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/rule_author/superpowers/mutation_support.md b/docs/src/rule_author/superpowers/mutation_support.md index b7a3b69ec..497e11575 100644 --- a/docs/src/rule_author/superpowers/mutation_support.md +++ b/docs/src/rule_author/superpowers/mutation_support.md @@ -41,7 +41,7 @@ There are a few key points to follow: - There must be a mutable tangent input for every mutated primal input - When the primal value is changed, the corresponding change must be made to its tangent partner - When a value is returned, return its partnered tangent. - - If two primals alias, then their tangents must also alias. + - If (and only if) primal values alias, then their tangents must also alias. ### Example For example, consider the primal function with: