Skip to content

Commit

Permalink
Adjust forward stage2 to Core.Compiler changes
Browse files Browse the repository at this point in the history
Only what is necessary for Cedar right now. Ordinary stage 2 reverse
mode will need similar changes at a later point.
  • Loading branch information
Keno committed Oct 11, 2024
1 parent d0b3e3e commit 1fb7102
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 6 deletions.
68 changes: 66 additions & 2 deletions src/analysis/forward.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,71 @@
using Core.Compiler: StmtInfo, ArgInfo, CallMeta, AbsIntState

if VERSION >= v"1.12.0-DEV.1268"

using Core.Compiler: Future

function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::Future{CallMeta})
if f === ChainRulesCore.frule
# TODO: Currently, we don't have any termination analysis for the non-stratified
# forward analysis, so bail out here.
return primal_call
end

nargs = length(arginfo.argtypes)-1
frule_preargtypes = Any[Const(ChainRulesCore.frule), Tuple{Nothing,Vararg{Any,nargs}}]
frule_argtypes = append!(frule_preargtypes, arginfo.argtypes)
local frule_atype::Any = CC.argtypes_to_type(frule_argtypes)

local frule_call::Future{CallMeta}
local result::Future{CallMeta} = Future{CallMeta}()
function make_progress(_, sv)
if isa(primal_call[].info, UnionSplitApplyCallInfo)
result[] = primal_call[]
return true
end

ready = false
if !@isdefined(frule_call)
# Here we simply check for the frule existance - we don't want to do a full
# inference with specialized argtypes and everything since the problem is
# likely sparse and we only need to do a full inference on a few calls.
# Thus, here we pick `Any` for the tangent types rather than trying to
# discover what they are. frules should be written in such a way that
# whether or not they return `nothing`, only depends on the non-tangent arguments
frule_arginfo = ArgInfo(nothing, frule_argtypes)
frule_si = StmtInfo(true)
# turn off frule analysis in the frule to avoid cycling
interp′ = disable_forward(interp)
frule_call = CC.abstract_call_gf_by_type(interp′,
ChainRulesCore.frule, frule_arginfo, frule_si, frule_atype, sv, #=max_methods=#-1)::Future
isready(frule_call) || return false
end

frc = frule_call[]
pc = primal_call[]

if frc.rt !== Const(nothing)
result[] = CallMeta(pc.rt, pc.exct, pc.effects, FRuleCallInfo(pc.info, frc))
else
result[] = pc
CC.add_mt_backedge!(sv, frule_mt, frule_atype)
end

return true
end
(!isready(primal_call) || !make_progress(interp, sv)) && push!(sv.tasks, make_progress)
return result
end

else

function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::CallMeta)
if f === ChainRulesCore.frule
# TODO: Currently, we don't have any termination analysis for the non-stratified
# forward analysis, so bail out here.
return nothing
return primal_call
end

nargs = length(arginfo.argtypes)-1
Expand Down Expand Up @@ -35,7 +95,11 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize
CC.add_mt_backedge!(sv, frule_mt, frule_atype)
end

return nothing
return primal_call
end



end

const frule_mt = methods(ChainRulesCore.frule).mt
6 changes: 6 additions & 0 deletions src/stage1/compiler_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ if VERSION < v"1.11.0-DEV.258"
Base.getindex(ir::IRCode, ssa::SSAValue) = CC.getindex(ir, ssa)
end

if isdefined(CC, :Future)
Base.isready(future::CC.Future) = CC.isready(future)
Base.getindex(future::CC.Future) = CC.getindex(future)
Base.setindex!(future::CC.Future, value) = CC.setindex!(future, value)
end

Base.copy(ir::IRCode) = CC.copy(ir)

CC.NewInstruction(@nospecialize node) =
Expand Down
4 changes: 4 additions & 0 deletions src/stage1/recurse_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int, E)
ci.ssaflags = UInt8[0 for i=1:length(new_code)]
ci.method_for_inference_limit_heuristics = meth
ci.edges = MethodInstance[mi]
if hasfield(CodeInfo, :nargs)
ci.nargs = 2
ci.isva = true
end

return ci
end
Expand Down
5 changes: 1 addition & 4 deletions src/stage2/abstractinterpret.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,7 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, atype::Any, sv::InferenceState, max_methods::Int)

if interp.forward
r = fwd_abstract_call_gf_by_type(interp, f, arginfo, si, sv, ret)
if r !== nothing
return r
end
return fwd_abstract_call_gf_by_type(interp, f, arginfo, si, sv, ret)
end

return ret
Expand Down
3 changes: 3 additions & 0 deletions src/stage2/lattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ end
CC.nsplit_impl(info::FRuleCallInfo) = CC.nsplit(info.info)
CC.getsplit_impl(info::FRuleCallInfo, idx::Int) = CC.getsplit(info.info, idx)
CC.getresult_impl(info::FRuleCallInfo, idx::Int) = CC.getresult(info.info, idx)
if isdefined(CC, :add_uncovered_edges_impl)
CC.add_uncovered_edges_impl(edges::Vector{Any}, info::FRuleCallInfo, @nospecialize(atype)) = CC.add_uncovered_edges!(edges, info.info, atype)
end

function Base.show(io::IO, info::FRuleCallInfo)
print(io, "FRuleCallInfo(", typeof(info.info), ", ", typeof(info.frule_call.info), ")")
Expand Down

0 comments on commit 1fb7102

Please sign in to comment.