diff --git a/src/analysis/forward.jl b/src/analysis/forward.jl index 3e609e95..10351583 100644 --- a/src/analysis/forward.jl +++ b/src/analysis/forward.jl @@ -1,5 +1,65 @@ using Core.Compiler: StmtInfo, ArgInfo, CallMeta, AbsIntState +if VERSION >= v"1.12.0-DEV.1268" + +using Core.Compiler: Future + +function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), + arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::Future{CallMeta}) + if f === ChainRulesCore.frule + # TODO: Currently, we don't have any termination analysis for the non-stratified + # forward analysis, so bail out here. + return Future{Union{CallMeta, Nothing}}(nothing) + end + + nargs = length(arginfo.argtypes)-1 + frule_preargtypes = Any[Const(ChainRulesCore.frule), Tuple{Nothing,Vararg{Any,nargs}}] + frule_argtypes = append!(frule_preargtypes, arginfo.argtypes) + frule_atype = CC.argtypes_to_type(frule_argtypes) + + local frule_call::Future{CallMeta} + local result::Future{Union{CallMeta, Nothing}} = Future{Union{CallMeta, Nothing}}() + function make_progress(_, sv) + if isa(primal_call[].info, UnionSplitApplyCallInfo) + result[] = nothing + return true + end + + ready = false + if !@isdefined(frule_call) + # Here we simply check for the frule existance - we don't want to do a full + # inference with specialized argtypes and everything since the problem is + # likely sparse and we only need to do a full inference on a few calls. + # Thus, here we pick `Any` for the tangent types rather than trying to + # discover what they are. frules should be written in such a way that + # whether or not they return `nothing`, only depends on the non-tangent arguments + frule_arginfo = ArgInfo(nothing, frule_argtypes) + frule_si = StmtInfo(true) + # turn off frule analysis in the frule to avoid cycling + interp′ = disable_forward(interp) + frule_call = CC.abstract_call_gf_by_type(interp′, + ChainRulesCore.frule, frule_arginfo, frule_si, frule_atype, sv, #=max_methods=#-1)::Future + isready(frule_call) || return false + end + + frc = frule_call[] + pc = primal_call[] + + if frc.rt !== Const(nothing) + result[] = CallMeta(pc.rt, pc.exct, pc.effects, FRuleCallInfo(pc.info, frc)) + else + result[] = nothing + CC.add_mt_backedge!(sv, frule_mt, frule_atype) + end + + return true + end + (!isready(primal_call) || !make_progress(interp, sv)) && push!(sv.tasks, make_progress) + return result +end + +else + function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::CallMeta) if f === ChainRulesCore.frule @@ -38,4 +98,8 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize return nothing end + + +end + const frule_mt = methods(ChainRulesCore.frule).mt diff --git a/src/stage1/compiler_utils.jl b/src/stage1/compiler_utils.jl index 52c8de3f..4f80051e 100644 --- a/src/stage1/compiler_utils.jl +++ b/src/stage1/compiler_utils.jl @@ -11,6 +11,12 @@ if VERSION < v"1.11.0-DEV.258" Base.getindex(ir::IRCode, ssa::SSAValue) = CC.getindex(ir, ssa) end +if isdefined(CC, :Future) + Base.isready(future::CC.Future) = CC.isready(future) + Base.getindex(future::CC.Future) = CC.getindex(future) + Base.setindex!(future::CC.Future, value) = CC.setindex!(future, value) +end + Base.copy(ir::IRCode) = CC.copy(ir) CC.NewInstruction(@nospecialize node) = diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index efb31611..ff724ac5 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -209,6 +209,10 @@ function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int, E) ci.ssaflags = UInt8[0 for i=1:length(new_code)] ci.method_for_inference_limit_heuristics = meth ci.edges = MethodInstance[mi] + if hasfield(CodeInstance, :nargs) + ci.nargs = 2 + ci.isva = true + end return ci end diff --git a/src/stage2/lattice.jl b/src/stage2/lattice.jl index 663b5ffe..ade95ed4 100644 --- a/src/stage2/lattice.jl +++ b/src/stage2/lattice.jl @@ -71,6 +71,9 @@ end CC.nsplit_impl(info::FRuleCallInfo) = CC.nsplit(info.info) CC.getsplit_impl(info::FRuleCallInfo, idx::Int) = CC.getsplit(info.info, idx) CC.getresult_impl(info::FRuleCallInfo, idx::Int) = CC.getresult(info.info, idx) +if isdefined(CC, :add_uncovered_edges_impl) + CC.add_uncovered_edges_impl(edges::Vector{Any}, info::FRuleCallInfo, @nospecialize(atype)) = CC.add_uncovered_edges!(edges, info.info, atype) +end function Base.show(io::IO, info::FRuleCallInfo) print(io, "FRuleCallInfo(", typeof(info.info), ", ", typeof(info.frule_call.info), ")")