Skip to content

Commit

Permalink
Merge pull request #263 from JuliaDiff/ox/eras2
Browse files Browse the repository at this point in the history
Eras mode
  • Loading branch information
oxinabox authored Apr 9, 2024
2 parents a68e3f3 + 4b1f94a commit 80660ad
Show file tree
Hide file tree
Showing 11 changed files with 367 additions and 160 deletions.
16 changes: 9 additions & 7 deletions src/codegen/forward_demand.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ function forward_diff!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpre
end
function forward_diff!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpretationState,
val, order::Int;
custom_diff!, diff_cache)
custom_diff!, diff_cache, eras_mode)
return ChainRulesCore.zero_tangent(val)
end
function forward_diff!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpretationState,
arg::Argument, order::Int;
custom_diff!, diff_cache)
recurse(x) = forward_diff!(ir, interp, irsv, x; custom_diff!, diff_cache)
custom_diff!, diff_cache, eras_mode)
recurse(x) = forward_diff!(ir, interp, irsv, x; custom_diff!, diff_cache, eras_mode)
val = custom_diff!(ir, SSAValue(0), arg, recurse)
if val !== nothing
return val
Expand All @@ -56,9 +56,9 @@ end

function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpretationState,
ssa::SSAValue, inst::Core.Compiler.Instruction, order::Int;
custom_diff!, diff_cache)
custom_diff!, diff_cache, eras_mode)
stmt = inst[:inst]
recurse(x) = forward_diff!(ir, interp, irsv, x, order; custom_diff!, diff_cache)
recurse(x) = forward_diff!(ir, interp, irsv, x, order; custom_diff!, diff_cache, eras_mode)
if (val = custom_diff!(ir, ssa, stmt, recurse)) !== nothing
return val
elseif isa(stmt, PiNode)
Expand Down Expand Up @@ -212,8 +212,10 @@ Internal method which generates the code for forward mode diffentiation
decides if the custom `transform!` should be applied to a `stmt` or not
Default: `false` for all statements
- `transform!(ir::IRCode, ssa::SSAValue, order::Int)` mutates `ir` to do a custom tranformation.
- `eras_mode`: determines if to error if not all derivatives are taylor
"""
function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
eras_mode = false,
visit_custom! = (@nospecialize args...)->false,
transform! = (@nospecialize args...)->error())
# Step 1: For each SSAValue in the IR, keep track of the differentiation order needed
Expand Down Expand Up @@ -286,12 +288,12 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
newargs = map(stmt.args[2:end]) do @nospecialize arg
maparg(arg, SSAValue(ssa), order)
end
replace_call!(ir, SSAValue(ssa), Expr(:call, ∂☆{order}(), newargs...))
replace_call!(ir, SSAValue(ssa), Expr(:call, ∂☆{order, eras_mode}(), newargs...))
elseif isexpr(stmt, :call) || isexpr(stmt, :new)
newargs = map(stmt.args) do @nospecialize arg
maparg(arg, SSAValue(ssa), order)
end
f = isexpr(stmt, :call) ? ∂☆{order}() : ∂☆new{order}()
f = isexpr(stmt, :call) ? ∂☆{order, eras_mode}() : ∂☆new{order}()
replace_call!(ir, SSAValue(ssa), Expr(:call, f, newargs...))
elseif isa(stmt, PiNode)
# TODO: New PiNode that discriminates based on primal?
Expand Down
13 changes: 13 additions & 0 deletions src/extra_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,16 @@ function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setproperty!), obj::M
= setproperty!(ȯbj, field, ẋ)
return y, ẏ
end

# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/607
Base.:(==)(x::Number, ::ZeroTangent) = iszero(x)
Base.:(==)(::ZeroTangent, x::Number) = iszero(x)
Base.hash(x::ZeroTangent, h::UInt64) = hash(0, h)

# should this be in ChainRules/ChainRulesCore?
# Avoid making nested backings, a Tangent is already a valid Tangent for a Tangent,
# or a valid second order Tangent for the primal
function ChainRulesCore.frule((_, ẋ), T::Type{<:Tangent}, x)
::Tangent
return T(x), ẋ
end
21 changes: 18 additions & 3 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,28 @@ const ∂⃖¹ = ∂⃖{1}()
(::Type{∂⃖})(args...) = ∂⃖¹(args...)

"""
∂☆{N}
∂☆{N,E}
∂☆{N} is the forward-mode AD functor of order `N`. A call
∂☆{N} is the forward-mode AD functor of order `N` (An integer). A call
`(::∂☆{N})(f, args...)` evaluating a function `f: A -> B` is lifted to its
pushforward on the N-th order tangent bundle `f⋆: Tⁿ A -> Tⁿ B`.
!!!advanced "Eras Mode"
E (a bool, default false) is for Eras mode. In Eras mode, we are Taylor or bust.
Normally if a particular derivative can not be represented as a `TaylorBundle`
we fall back and represent it as a `ExplictTangentBundle`.
However, in Eras mode we error if it can't be represented as a TaylorBundle.
In general, this is not wanted since it often will break nested AD.
But in the cases it doesn't its really fast, since it means we can rewrite nested AD
as Taylor-mode AD (plus its more type stable).
To be safe in Eras mode, it is sufficient, but not necessary, to be doing nested AD with
respect to the same variable. It also works in other cases where (likely by problem construction)
ADing with respect to a second variable happens to result in something that can be represented
with a `TaylorBundle` also. (You need your different partials to happen to be exactly equal).
"""
struct ∂☆{N}; end
struct ∂☆{N, E}; end
∂☆{N}() where N = ∂☆{N,false}() # default to not using Era mode
const ∂☆¹ = ∂☆{1}()

"""
Expand Down
12 changes: 7 additions & 5 deletions src/stage1/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@ using Base.Broadcast
using Base.Broadcast: broadcasted, Broadcasted

# Forward mode broadcast rule
struct FwdBroadcast{N, T<:AbstractTangentBundle{N}}
struct FwdBroadcast{N, E, T<:AbstractTangentBundle{N}}
f::T
end
(f::FwdBroadcast{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...)
FwdBroadcast{E}(f::T) where {N, E, T<:AbstractTangentBundle{N}} = FwdBroadcast{N,E,T}(f)

(f::FwdBroadcast{N,E})(args::AbstractTangentBundle{N}...) where {N,E} = ∂☆{N,E}()(f.f, args...)

n_getfield(∂ₙ::∂☆{N}, b::ATB{N}, x::Union{Symbol, Int}) where {N} = ∂ₙ(ZeroBundle{N}(getfield), b, ZeroBundle{N}(x))

function (∂ₙ::∂☆{N})(zc::AbstractZeroBundle{N, typeof(copy)},
bc::ATB{N, <:Broadcasted}) where {N}
function (∂ₙ::∂☆{N,E})(zc::AbstractZeroBundle{N, typeof(copy)},
bc::ATB{N, <:Broadcasted}) where {N,E}
bc = ∂ₙ(ZeroBundle{N}(Broadcast.flatten), bc)
args = n_getfield(∂ₙ, bc, :args)
r = copy(Broadcasted(
FwdMap(n_getfield(∂ₙ, bc, :f)),
FwdMap{E}(n_getfield(∂ₙ, bc, :f)),
ntuple(length(primal(args))) do i
val = n_getfield(∂ₙ, args, i)
if ndims(primal(val)) == 0
Expand Down
Loading

0 comments on commit 80660ad

Please sign in to comment.