From 1fb71028a86aad3b91bcf6b604506abb4ea9b7a8 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Wed, 5 Jun 2024 03:57:38 +0000 Subject: [PATCH] Adjust forward stage2 to Core.Compiler changes Only what is necessary for Cedar right now. Ordinary stage 2 reverse mode will need similar changes at a later point. --- src/analysis/forward.jl | 68 ++++++++++++++++++++++++++++++++- src/stage1/compiler_utils.jl | 6 +++ src/stage1/recurse_fwd.jl | 4 ++ src/stage2/abstractinterpret.jl | 5 +-- src/stage2/lattice.jl | 3 ++ 5 files changed, 80 insertions(+), 6 deletions(-) diff --git a/src/analysis/forward.jl b/src/analysis/forward.jl index 3e609e95..1078cbaa 100644 --- a/src/analysis/forward.jl +++ b/src/analysis/forward.jl @@ -1,11 +1,71 @@ using Core.Compiler: StmtInfo, ArgInfo, CallMeta, AbsIntState +if VERSION >= v"1.12.0-DEV.1268" + +using Core.Compiler: Future + +function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), + arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::Future{CallMeta}) + if f === ChainRulesCore.frule + # TODO: Currently, we don't have any termination analysis for the non-stratified + # forward analysis, so bail out here. + return primal_call + end + + nargs = length(arginfo.argtypes)-1 + frule_preargtypes = Any[Const(ChainRulesCore.frule), Tuple{Nothing,Vararg{Any,nargs}}] + frule_argtypes = append!(frule_preargtypes, arginfo.argtypes) + local frule_atype::Any = CC.argtypes_to_type(frule_argtypes) + + local frule_call::Future{CallMeta} + local result::Future{CallMeta} = Future{CallMeta}() + function make_progress(_, sv) + if isa(primal_call[].info, UnionSplitApplyCallInfo) + result[] = primal_call[] + return true + end + + ready = false + if !@isdefined(frule_call) + # Here we simply check for the frule existance - we don't want to do a full + # inference with specialized argtypes and everything since the problem is + # likely sparse and we only need to do a full inference on a few calls. + # Thus, here we pick `Any` for the tangent types rather than trying to + # discover what they are. frules should be written in such a way that + # whether or not they return `nothing`, only depends on the non-tangent arguments + frule_arginfo = ArgInfo(nothing, frule_argtypes) + frule_si = StmtInfo(true) + # turn off frule analysis in the frule to avoid cycling + interp′ = disable_forward(interp) + frule_call = CC.abstract_call_gf_by_type(interp′, + ChainRulesCore.frule, frule_arginfo, frule_si, frule_atype, sv, #=max_methods=#-1)::Future + isready(frule_call) || return false + end + + frc = frule_call[] + pc = primal_call[] + + if frc.rt !== Const(nothing) + result[] = CallMeta(pc.rt, pc.exct, pc.effects, FRuleCallInfo(pc.info, frc)) + else + result[] = pc + CC.add_mt_backedge!(sv, frule_mt, frule_atype) + end + + return true + end + (!isready(primal_call) || !make_progress(interp, sv)) && push!(sv.tasks, make_progress) + return result +end + +else + function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::CallMeta) if f === ChainRulesCore.frule # TODO: Currently, we don't have any termination analysis for the non-stratified # forward analysis, so bail out here. - return nothing + return primal_call end nargs = length(arginfo.argtypes)-1 @@ -35,7 +95,11 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize CC.add_mt_backedge!(sv, frule_mt, frule_atype) end - return nothing + return primal_call +end + + + end const frule_mt = methods(ChainRulesCore.frule).mt diff --git a/src/stage1/compiler_utils.jl b/src/stage1/compiler_utils.jl index 52c8de3f..4f80051e 100644 --- a/src/stage1/compiler_utils.jl +++ b/src/stage1/compiler_utils.jl @@ -11,6 +11,12 @@ if VERSION < v"1.11.0-DEV.258" Base.getindex(ir::IRCode, ssa::SSAValue) = CC.getindex(ir, ssa) end +if isdefined(CC, :Future) + Base.isready(future::CC.Future) = CC.isready(future) + Base.getindex(future::CC.Future) = CC.getindex(future) + Base.setindex!(future::CC.Future, value) = CC.setindex!(future, value) +end + Base.copy(ir::IRCode) = CC.copy(ir) CC.NewInstruction(@nospecialize node) = diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index efb31611..7a2b22dd 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -209,6 +209,10 @@ function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int, E) ci.ssaflags = UInt8[0 for i=1:length(new_code)] ci.method_for_inference_limit_heuristics = meth ci.edges = MethodInstance[mi] + if hasfield(CodeInfo, :nargs) + ci.nargs = 2 + ci.isva = true + end return ci end diff --git a/src/stage2/abstractinterpret.jl b/src/stage2/abstractinterpret.jl index bb634d63..537c7308 100644 --- a/src/stage2/abstractinterpret.jl +++ b/src/stage2/abstractinterpret.jl @@ -74,10 +74,7 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f), arginfo::ArgInfo, si::StmtInfo, atype::Any, sv::InferenceState, max_methods::Int) if interp.forward - r = fwd_abstract_call_gf_by_type(interp, f, arginfo, si, sv, ret) - if r !== nothing - return r - end + return fwd_abstract_call_gf_by_type(interp, f, arginfo, si, sv, ret) end return ret diff --git a/src/stage2/lattice.jl b/src/stage2/lattice.jl index 663b5ffe..ade95ed4 100644 --- a/src/stage2/lattice.jl +++ b/src/stage2/lattice.jl @@ -71,6 +71,9 @@ end CC.nsplit_impl(info::FRuleCallInfo) = CC.nsplit(info.info) CC.getsplit_impl(info::FRuleCallInfo, idx::Int) = CC.getsplit(info.info, idx) CC.getresult_impl(info::FRuleCallInfo, idx::Int) = CC.getresult(info.info, idx) +if isdefined(CC, :add_uncovered_edges_impl) + CC.add_uncovered_edges_impl(edges::Vector{Any}, info::FRuleCallInfo, @nospecialize(atype)) = CC.add_uncovered_edges!(edges, info.info, atype) +end function Base.show(io::IO, info::FRuleCallInfo) print(io, "FRuleCallInfo(", typeof(info.info), ", ", typeof(info.frule_call.info), ")")