Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce MutableTangent #626

Merged
merged 36 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
92f6a03
rename files
oxinabox Aug 4, 2023
f9b5a24
move functionality up to StructuralTangent
oxinabox Aug 4, 2023
35aff30
Formatting
oxinabox Aug 4, 2023
c7932f1
WIP mutable Tangent (squash me)
oxinabox Aug 21, 2023
4b50fd2
wip
oxinabox Sep 15, 2023
87ceddf
First pass at something that maybe works
oxinabox Sep 15, 2023
e75a364
accept int index
oxinabox Sep 18, 2023
b566581
add == and hash for MutableTangent
oxinabox Sep 26, 2023
f8900c4
add and test zero_tangent
oxinabox Sep 26, 2023
bcb5587
export StructuralTangent
oxinabox Sep 28, 2023
63c450b
Style
oxinabox Oct 2, 2023
53e8f0d
handle unassigned a bit more
oxinabox Oct 4, 2023
17064c2
add some more test cases to zero_tangent
oxinabox Oct 4, 2023
5a19913
style
oxinabox Oct 4, 2023
dac92bd
Handle Structs with undef fields
oxinabox Oct 6, 2023
98a7e39
overhaul zero_tangent and MutableTangent for type stability
oxinabox Dec 22, 2023
7d88acd
set MutableTangent setproperty! on index
oxinabox Dec 27, 2023
b3562c6
formatting
oxinabox Dec 27, 2023
dd3f1ab
handle abstract fields right in mutable tangents outside of zero tangent
oxinabox Dec 28, 2023
db45626
formatting
oxinabox Dec 28, 2023
ad29971
Add docs for forward mutation support
oxinabox Dec 28, 2023
8471f39
use ismutabletype from Compat
oxinabox Dec 29, 2023
9b6d6e5
wrap structural tangent tests in a common testset
oxinabox Dec 29, 2023
f5efd7d
Support types that have no tangent space in zero_tangent
oxinabox Dec 29, 2023
8a54fae
define zero_tangent for Tangent
oxinabox Jan 16, 2024
b67686d
Add structural zero tangent code for higher order
oxinabox Jan 17, 2024
b3a4d57
Formatting
oxinabox Jan 17, 2024
f481d05
overload show for mutable tangent
oxinabox Jan 17, 2024
da8c204
formatting
oxinabox Jan 19, 2024
26138a9
move show code to `Common` area
oxinabox Jan 23, 2024
e886589
docs more consistent
oxinabox Jan 23, 2024
d3380bc
Update src/tangent_types/structural_tangent.jl
oxinabox Jan 23, 2024
501857d
Update test/tangent_types/structural_tangent.jl
oxinabox Jan 23, 2024
2d61f41
Add broken tests for Aliasing and Cyclic references
oxinabox Jan 25, 2024
95e63d0
improve docs
oxinabox Jan 25, 2024
73b7508
stronger statement about aliasing
oxinabox Jan 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,17 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[extensions]
ChainRulesCoreSparseArraysExt = "SparseArrays"

[compat]
BenchmarkTools = "0.5"
Compat = "2, 3, 4"
Compat = "3.40, 4"
FiniteDifferences = "0.10"
OffsetArrays = "1"
StaticArrays = "0.11, 0.12, 1"
julia = "1.6"

[extensions]
ChainRulesCoreSparseArraysExt = "SparseArrays"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Expand All @@ -31,3 +28,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "BenchmarkTools", "FiniteDifferences", "OffsetArrays", "SparseArrays", "StaticArrays"]

[weakdeps]
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ makedocs(;
"`@opt_out`" => "rule_author/superpowers/opt_out.md",
"`RuleConfig`" => "rule_author/superpowers/ruleconfig.md",
"Gradient accumulation" => "rule_author/superpowers/gradient_accumulation.md",
"Mutation Support (experimental)" => "rule_author/superpowers/mutation_support.md",
],
"Converting ZygoteRules.@adjoint to rrules" => "rule_author/converting_zygoterules.md",
"Tips for making your package work with AD" => "rule_author/tips_for_packages.md",
Expand Down
2 changes: 1 addition & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Modules = [ChainRulesCore]
Pages = [
"tangent_types/abstract_zero.jl",
"tangent_types/one.jl",
"tangent_types/tangent.jl",
"tangent_types/structural_tangent.jl",
"tangent_types/thunks.jl",
"tangent_types/abstract_tangent.jl",
"tangent_types/notimplemented.jl",
Expand Down
82 changes: 82 additions & 0 deletions docs/src/rule_author/superpowers/mutation_support.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Mutation Support

ChainRulesCore.jl offers experimental support for mutation, targeting use in forward mode AD.
(Mutation support in reverse mode AD is more complicated and will likely require more changes to the interface)

!!! warning "Experimental"
This page documents an experimental feature.
Expect breaking changes in minor versions while this remains.
It is not suitable for general use unless you are prepared to modify how you are using it each minor release.
It is thus suggested that if you are using it to use _tilde_ bounds on supported minor versions.


## `MutableTangent`
The [`MutableTangent`](@ref) type is designed to be a partner to the [`Tangent`](@ref) type, with specific support for being mutated in place.
It is required to be a structural tangent, having one tangent for each field of the primal object.

Technically, not all `mutable struct`s need to use `MutableTangent` to represent their tangents.
Just like not all `struct`s need to use `Tangent`s.
Common examples away from this are natural tangent types like for arrays.
However, if one is setting up to use a custom tangent type for this it is sufficiently off the beaten path that we can not provide much guidance.

## `zero_tangent`

The [`zero_tangent`](@ref) function functions to give you a zero (i.e. additive identity) for any primal value.
The [`ZeroTangent`](@ref) type also does this.
The difference is that [`zero_tangent`](@ref) is in general full structural tangent mirroring the structure of the primal.
To be technical the promise of [`zero_tangent`](@ref) is that it will be a value that supports mutation.
However, in practice[^1] this is achieved through in a structural tangent
For mutation support this is important, since it means that there is mutable memory available in the tangent to be mutated when the primal changes.
To support this you thus need to make sure your zeros are created in various places with [`zero_tangent`](@ref) rather than []`ZeroTangent`](@ref).



It is also useful for reasons of type stability, since it forces a consistent type (generally a structural tangent) for any given primal type.
For this reason AD system implementors might chose to use this to create the tangent for all literal values they encounter, mutable or not,
and to process the output of `frule`s to convert [`ZeroTangent`](@ref) into corresponding [`zero_tangent`](@ref)s.

## Writing a frule for a mutating function
willtebbutt marked this conversation as resolved.
Show resolved Hide resolved
It is relatively straight forward to write a frule for a mutating function.
There are a few key points to follow:
- There must be a mutable tangent input for every mutated primal input
- When the primal value is changed, the corresponding change must be made to its tangent partner
- When a value is returned, return its partnered tangent.
- If (and only if) primal values alias, then their tangents must also alias.

### Example
For example, consider the primal function with:
1. takes two `Ref`s
2. doubles the first one in place
3. overwrites the second one's value with the literal 5.0
4. returns the first one


```julia
function foo!(a::Base.RefValue, b::Base.RefValue)
a[] *= 2
b[] = 5.0
return a
end
```

The frule for this would be:
```julia
function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(foo!), a::Base.RefValue, b::Base.RefValue)
@assert ȧ isa MutableTangent{typeof(a)}
@assert ḃ isa MutableTangent{typeof(b)}

a[] *= 2
ȧ.x *= 2 # `.x` is the field that lives behind RefValues

b[] = 5.0
ḃ.x = zero_tangent(5.0) # or since we know that the zero for a Float64 is zero could write `ḃ.x = 0.0`

return a, ȧ
end
```

Then assuming the AD system does its part to makes sure you are indeed given mutable values to mutate (i.e. those `@assert`ions are true) then all is well and this rule will make mutation correct.

[^1]:
Further, it is hard to achieve this promise of allowing mutation to be supported without returning a structural tangent.
Except in the special case of where the struct is not mutable and has no nested fields that are mutable.
8 changes: 4 additions & 4 deletions src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,26 @@ module ChainRulesCore
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
using Base.Meta
using LinearAlgebra
using Compat: hasfield, hasproperty
using Compat: hasfield, hasproperty, ismutabletype

export frule, rrule # core function
# rule configurations
export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode
export frule_via_ad, rrule_via_ad
# definition helper macros
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
export ProjectTo, canonicalize, unthunk # tangent operations
export ProjectTo, canonicalize, unthunk, zero_tangent # tangent operations
export add!!, is_inplaceable_destination # gradient accumulation operations
export ignore_derivatives, @ignore_derivatives
# tangents
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
export StructuralTangent, Tangent, MutableTangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
oxinabox marked this conversation as resolved.
Show resolved Hide resolved

include("debug_mode.jl")

include("tangent_types/abstract_tangent.jl")
include("tangent_types/structural_tangent.jl")
include("tangent_types/abstract_zero.jl")
include("tangent_types/thunks.jl")
include("tangent_types/tangent.jl")
include("tangent_types/notimplemented.jl")

include("tangent_arithmetic.jl")
Expand Down
22 changes: 11 additions & 11 deletions src/tangent_arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Base.:+(x::NotImplemented, ::NotImplemented) = x
Base.:*(x::NotImplemented, ::NotImplemented) = x
LinearAlgebra.dot(x::NotImplemented, ::NotImplemented) = x
# `NotImplemented` always "wins" +
for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :Tangent, :Any)
for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :StructuralTangent, :Any)
@eval Base.:+(x::NotImplemented, ::$T) = x
@eval Base.:+(::$T, x::NotImplemented) = x
end
Expand All @@ -33,7 +33,7 @@ for T in (:ZeroTangent, :NoTangent)
@eval LinearAlgebra.dot(::$T, ::NotImplemented) = $T()
end
# `NotImplemented` "wins" * and dot for other types
for T in (:AbstractThunk, :Tangent, :Any)
for T in (:AbstractThunk, :StructuralTangent, :Any)
@eval Base.:*(x::NotImplemented, ::$T) = x
@eval Base.:*(::$T, x::NotImplemented) = x
@eval LinearAlgebra.dot(x::NotImplemented, ::$T) = x
Expand All @@ -55,7 +55,7 @@ Base.:-(::NoTangent, ::NoTangent) = NoTangent()
Base.:-(::NoTangent) = NoTangent()
Base.:*(::NoTangent, ::NoTangent) = NoTangent()
LinearAlgebra.dot(::NoTangent, ::NoTangent) = NoTangent()
for T in (:AbstractThunk, :Tangent, :Any)
for T in (:AbstractThunk, :StructuralTangent, :Any)
@eval Base.:+(::NoTangent, b::$T) = b
@eval Base.:+(a::$T, ::NoTangent) = a
@eval Base.:-(::NoTangent, b::$T) = -b
Expand Down Expand Up @@ -95,7 +95,7 @@ Base.:-(::ZeroTangent, ::ZeroTangent) = ZeroTangent()
Base.:-(::ZeroTangent) = ZeroTangent()
Base.:*(::ZeroTangent, ::ZeroTangent) = ZeroTangent()
LinearAlgebra.dot(::ZeroTangent, ::ZeroTangent) = ZeroTangent()
for T in (:AbstractThunk, :Tangent, :Any)
for T in (:AbstractThunk, :StructuralTangent, :Any)
@eval Base.:+(::ZeroTangent, b::$T) = b
@eval Base.:+(a::$T, ::ZeroTangent) = a
@eval Base.:-(::ZeroTangent, b::$T) = -b
Expand Down Expand Up @@ -126,11 +126,11 @@ for T in (:Tangent, :Any)
@eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b)
end

function Base.:+(a::Tangent{P}, b::Tangent{P}) where {P}
function Base.:+(a::StructuralTangent{P}, b::StructuralTangent{P}) where {P}
data = elementwise_add(backing(a), backing(b))
return Tangent{P,typeof(data)}(data)
return StructuralTangent{P}(data)
end
function Base.:+(a::P, d::Tangent{P}) where {P}
function Base.:+(a::P, d::StructuralTangent{P}) where {P}
net_backing = elementwise_add(backing(a), backing(d))
if debug_mode()
try
Expand All @@ -143,14 +143,14 @@ function Base.:+(a::P, d::Tangent{P}) where {P}
end
end
Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d))
Base.:+(a::Tangent{P}, b::P) where {P} = b + a
Base.:+(a::StructuralTangent{P}, b::P) where {P} = b + a

Base.:-(tangent::Tangent{P}) where {P} = map(-, tangent)
Base.:-(tangent::StructuralTangent{P}) where {P} = map(-, tangent)

# We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful
# In general one doesn't have to represent multiplications of 2 tangents
# Only of a tangent and a scaling factor (generally `Real`)
for T in (:Number,)
@eval Base.:*(s::$T, tangent::Tangent) = map(x -> s * x, tangent)
@eval Base.:*(tangent::Tangent, s::$T) = map(x -> x * s, tangent)
@eval Base.:*(s::$T, tangent::StructuralTangent) = map(x -> s * x, tangent)
@eval Base.:*(tangent::StructuralTangent, s::$T) = map(x -> x * s, tangent)
end
87 changes: 87 additions & 0 deletions src/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,90 @@ arguments.
```
"""
struct NoTangent <: AbstractZero end

"""
zero_tangent(primal)

This returns an appropriate zero tangent suitable for accumulating tangents of the primal.
For mutable composites types this is a structural [`MutableTangent`](@ref)
For `Array`s, it is applied recursively for each element.
For other types, in particular immutable types, we do not make promises beyond that it will be `iszero`
and suitable for accumulating against.
For types without a tangent space (e.g. singleton structs) this returns `NoTangent()`.
In general, it is more likely to produce a structural tangent.

!!! warning Exprimental
`zero_tangent`is an experimental feature, and is part of the mutation support featureset.
While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore.
Exactly how it should be used (e.g. is it forward-mode only?)
"""
function zero_tangent end

zero_tangent(x::Number) = zero(x)

zero_tangent(::Type) = NoTangent()

function zero_tangent(x::MutableTangent{P}) where {P}
zb = backing(zero_tangent(backing(x)))
return MutableTangent{P}(zb)
end

function zero_tangent(x::Tangent{P}) where {P}
zb = backing(zero_tangent(backing(x)))
return Tangent{P,typeof(zb)}(zb)
end

@generated function zero_tangent(primal)
fieldcount(primal) == 0 && return NoTangent() # no tangent space at all, no need for structural zero.
zfield_exprs = map(fieldnames(primal)) do fname
fval = :(
if isdefined(primal, $(QuoteNode(fname)))
zero_tangent(getfield(primal, $(QuoteNode(fname))))
else
# This is going to be potentially bad, but that's what they get for not giving us a primal
# This will never me mutated inplace, rather it will alway be replaced with an actual value first
ZeroTangent()
end
)
Expr(:kw, fname, fval)
end
return if has_mutable_tangent(primal)
any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype
# If it is is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent
fdef = :(!isdefined(primal, $(QuoteNode(fname))) || !isconcretetype($ftype))
Expr(:kw, fname, fdef)
end
:($MutableTangent{$primal}(
$(Expr(:tuple, Expr(:parameters, any_mask...))),
$(Expr(:tuple, Expr(:parameters, zfield_exprs...))),
))
else
:($Tangent{$primal}($(Expr(:parameters, zfield_exprs...))))
end
end

zero_tangent(primal::Tuple) = Tangent{typeof(primal)}(map(zero_tangent, primal)...)

function zero_tangent(x::Array{P,N}) where {P,N}
if (isbitstype(P) || all(i -> isassigned(x, i), eachindex(x)))
return map(zero_tangent, x)
end

# Now we need to handle nonfully assigned arrays
# see discussion at https://github.com/JuliaDiff/ChainRulesCore.jl/pull/626#discussion_r1345235265
y = Array{guess_zero_tangent_type(P),N}(undef, size(x)...)
@inbounds for n in eachindex(y)
if isassigned(x, n)
y[n] = zero_tangent(x[n])
end
end
return y
end

# Sad heauristic methods we need because of unassigned values
guess_zero_tangent_type(::Type{T}) where {T<:Number} = T
guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T)))
function guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N}
return Array{guess_zero_tangent_type(T),N}
end
guess_zero_tangent_type(T::Type) = Any
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading