Skip to content

Commit

Permalink
Adjust forward stage2 to Core.Compiler changes (#295)
Browse files Browse the repository at this point in the history
Only what is necessary for Cedar right now. Ordinary stage 2 reverse
mode will need similar changes at a later point.
  • Loading branch information
Keno authored Oct 17, 2024
1 parent 778af00 commit 1cbde03
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 6 deletions.
68 changes: 66 additions & 2 deletions src/analysis/forward.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,71 @@
using Core.Compiler: StmtInfo, ArgInfo, CallMeta, AbsIntState

if VERSION >= v"1.12.0-DEV.1268"

using Core.Compiler: Future

function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::Future{CallMeta})
if f === ChainRulesCore.frule
# TODO: Currently, we don't have any termination analysis for the non-stratified
# forward analysis, so bail out here.
return primal_call
end

nargs = length(arginfo.argtypes)-1
frule_preargtypes = Any[Const(ChainRulesCore.frule), Tuple{Nothing,Vararg{Any,nargs}}]
frule_argtypes = append!(frule_preargtypes, arginfo.argtypes)
local frule_atype::Any = CC.argtypes_to_type(frule_argtypes)

local frule_call::Future{CallMeta}
local result::Future{CallMeta} = Future{CallMeta}()
function make_progress(_, sv)
if isa(primal_call[].info, UnionSplitApplyCallInfo)
result[] = primal_call[]
return true
end

ready = false
if !@isdefined(frule_call)
# Here we simply check for the frule existance - we don't want to do a full
# inference with specialized argtypes and everything since the problem is
# likely sparse and we only need to do a full inference on a few calls.
# Thus, here we pick `Any` for the tangent types rather than trying to
# discover what they are. frules should be written in such a way that
# whether or not they return `nothing`, only depends on the non-tangent arguments
frule_arginfo = ArgInfo(nothing, frule_argtypes)
frule_si = StmtInfo(true)
# turn off frule analysis in the frule to avoid cycling
interp′ = disable_forward(interp)
frule_call = CC.abstract_call_gf_by_type(interp′,
ChainRulesCore.frule, frule_arginfo, frule_si, frule_atype, sv, #=max_methods=#-1)::Future
isready(frule_call) || return false
end

frc = frule_call[]
pc = primal_call[]

if frc.rt !== Const(nothing)
result[] = CallMeta(pc.rt, pc.exct, pc.effects, FRuleCallInfo(pc.info, frc))
else
result[] = pc
CC.add_mt_backedge!(sv, frule_mt, frule_atype)
end

return true
end
(!isready(primal_call) || !make_progress(interp, sv)) && push!(sv.tasks, make_progress)
return result
end

else

function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::CallMeta)
if f === ChainRulesCore.frule
# TODO: Currently, we don't have any termination analysis for the non-stratified
# forward analysis, so bail out here.
return nothing
return primal_call
end

nargs = length(arginfo.argtypes)-1
Expand Down Expand Up @@ -35,7 +95,11 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize
CC.add_mt_backedge!(sv, frule_mt, frule_atype)
end

return nothing
return primal_call
end



end

const frule_mt = methods(ChainRulesCore.frule).mt
6 changes: 6 additions & 0 deletions src/stage1/compiler_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ if VERSION < v"1.11.0-DEV.258"
Base.getindex(ir::IRCode, ssa::SSAValue) = CC.getindex(ir, ssa)
end

if isdefined(CC, :Future)
Base.isready(future::CC.Future) = CC.isready(future)
Base.getindex(future::CC.Future) = CC.getindex(future)
Base.setindex!(future::CC.Future, value) = CC.setindex!(future, value)
end

Base.copy(ir::IRCode) = CC.copy(ir)

CC.NewInstruction(@nospecialize node) =
Expand Down
4 changes: 4 additions & 0 deletions src/stage1/recurse_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int, E)
ci.ssaflags = UInt8[0 for i=1:length(new_code)]
ci.method_for_inference_limit_heuristics = meth
ci.edges = MethodInstance[mi]
if hasfield(CodeInfo, :nargs)
ci.nargs = 2
ci.isva = true
end

if isdefined(Base, :__has_internal_change) && Base.__has_internal_change(v"1.12-alpha", :codeinfonargs)
ci.nargs = 2
Expand Down
5 changes: 1 addition & 4 deletions src/stage2/abstractinterpret.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,7 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, atype::Any, sv::InferenceState, max_methods::Int)

if interp.forward
r = fwd_abstract_call_gf_by_type(interp, f, arginfo, si, sv, ret)
if r !== nothing
return r
end
return fwd_abstract_call_gf_by_type(interp, f, arginfo, si, sv, ret)
end

return ret
Expand Down
3 changes: 3 additions & 0 deletions src/stage2/lattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ end
CC.nsplit_impl(info::FRuleCallInfo) = CC.nsplit(info.info)
CC.getsplit_impl(info::FRuleCallInfo, idx::Int) = CC.getsplit(info.info, idx)
CC.getresult_impl(info::FRuleCallInfo, idx::Int) = CC.getresult(info.info, idx)
if isdefined(CC, :add_uncovered_edges_impl)
CC.add_uncovered_edges_impl(edges::Vector{Any}, info::FRuleCallInfo, @nospecialize(atype)) = CC.add_uncovered_edges!(edges, info.info, atype)
end

function Base.show(io::IO, info::FRuleCallInfo)
print(io, "FRuleCallInfo(", typeof(info.info), ", ", typeof(info.frule_call.info), ")")
Expand Down

0 comments on commit 1cbde03

Please sign in to comment.