diff --git a/Project.toml b/Project.toml index 9b15e368..02d21d33 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,14 @@ name = "Diffractor" uuid = "9f5e2b26-1114-432f-b630-d3fe2085c51c" -authors = ["Keno Fischer and contributors"] version = "0.2.10" +authors = ["Keno Fischer and contributors"] [deps] AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +Compiler = "807dbc54-b67e-4c79-8afb-eafe4df6f2e1" Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" @@ -20,6 +21,7 @@ AbstractDifferentiation = "0.5, 0.6" ChainRules = "1.44.6" ChainRulesCore = "1.20" Combinatorics = "1" +Compiler = "0.0.1" Cthulhu = "2.10.1" OffsetArrays = "1" PrecompileTools = "1" diff --git a/src/Diffractor.jl b/src/Diffractor.jl index 5cb37459..7fb7205c 100644 --- a/src/Diffractor.jl +++ b/src/Diffractor.jl @@ -5,7 +5,12 @@ export ∂⃖, gradient using StructArrays using PrecompileTools +if VERSION ≥ v"1.12.0-DEV.1581" +import Compiler +const CC = Compiler +else const CC = Core.Compiler +end using Core.IR @static if VERSION ≥ v"1.11.0-DEV.1498" diff --git a/src/analysis/forward.jl b/src/analysis/forward.jl index 66feaadd..2afe827c 100644 --- a/src/analysis/forward.jl +++ b/src/analysis/forward.jl @@ -1,8 +1,8 @@ -using Core.Compiler: StmtInfo, ArgInfo, CallMeta, AbsIntState +using .CC: StmtInfo, ArgInfo, CallMeta, AbsIntState if VERSION >= v"1.12.0-DEV.1268" -using Core.Compiler: Future +using .CC: Future function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::Future{CallMeta}) diff --git a/src/codegen/forward_demand.jl b/src/codegen/forward_demand.jl index 77710d9d..9681c591 100644 --- a/src/codegen/forward_demand.jl +++ b/src/codegen/forward_demand.jl @@ -1,4 +1,4 @@ -using Core.Compiler: IRInterpretationState, construct_postdomtree, PiNode, +using .CC: IRInterpretationState, construct_postdomtree, PiNode, is_known_call, argextype, postdominates, userefs, PhiCNode, UpsilonNode #= @@ -55,7 +55,7 @@ function forward_diff!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpre end function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpretationState, - ssa::SSAValue, inst::Core.Compiler.Instruction, order::Int; + ssa::SSAValue, inst::CC.Instruction, order::Int; custom_diff!, diff_cache, eras_mode) stmt = inst[:inst] recurse(x) = forward_diff!(ir, interp, irsv, x, order; custom_diff!, diff_cache, eras_mode) @@ -94,12 +94,12 @@ function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::I Δbacking = insert_node!(ir, ssa, NewInstruction(Δtpl, tup_typ)) newT = argextype(stmt.args[1], ir) @assert isa(newT, Const) - tup_typ_typ = Core.Compiler.typeof_tfunc(tup_typ) + tup_typ_typ = CC.typeof_tfunc(tup_typ) if !(newT.val <: Tuple) - tup_typ_typ = Core.Compiler.apply_type_tfunc(Const(NamedTuple{fieldnames(newT.val)}), tup_typ_typ) + tup_typ_typ = CC.apply_type_tfunc(Const(NamedTuple{fieldnames(newT.val)}), tup_typ_typ) Δbacking = insert_node!(ir, ssa, NewInstruction(Expr(:splatnew, widenconst(tup_typ), Δbacking), tup_typ_typ.val)) end - tangentT = Core.Compiler.apply_type_tfunc(Const(ChainRulesCore.Tangent), newT, tup_typ_typ).val + tangentT = CC.apply_type_tfunc(Const(ChainRulesCore.Tangent), newT, tup_typ_typ).val Δtangent = insert_node!(ir, ssa, NewInstruction(Expr(:new, tangentT, Δbacking), tangentT)) return Δtangent else # general frule handling @@ -124,7 +124,7 @@ function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::I # Now do proper type inference with the known arguments interp′ = disable_forward(interp) - new_frame = Core.Compiler.typeinf_frame(interp′, new_match.method, new_match.spec_types, new_match.sparams, #=run_optimizer=#true) + new_frame = CC.typeinf_frame(interp′, new_match.method, new_match.spec_types, new_match.sparams, #=run_optimizer=#true) # Create :invoke expression for the newly inferred frule frule_mi = CC.EscapeAnalysis.analyze_match(new_match, length(args)+2) diff --git a/src/codegen/reverse.jl b/src/codegen/reverse.jl index d00d5802..276405c0 100644 --- a/src/codegen/reverse.jl +++ b/src/codegen/reverse.jl @@ -414,7 +414,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I end @static if VERSION ≥ v"1.12.0-DEV.173" - debuginfo = Core.Compiler.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code)) + debuginfo = CC.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code)) debuginfo.def = :var"N/A" opaque_ci.debuginfo = Core.DebugInfo(debuginfo, length(code)) else @@ -501,7 +501,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I end @static if VERSION ≥ v"1.12.0-DEV.173" - debuginfo = Core.Compiler.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code)) + debuginfo = CC.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code)) debuginfo.def = :var"N/A" opaque_ci.debuginfo = Core.DebugInfo(debuginfo, length(code)) else @@ -533,7 +533,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I end if interp !== nothing - new_argtypes = Any[Const(∂⃖recurse), tuple_tfunc(Core.Compiler.optimizer_lattice(interp), ir.argtypes[1:nfixedargs])] + new_argtypes = Any[Const(∂⃖recurse), tuple_tfunc(CC.optimizer_lattice(interp), ir.argtypes[1:nfixedargs])] empty!(ir.argtypes) append!(ir.argtypes, new_argtypes) end @@ -672,7 +672,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I ir = complete(compact) #@show ir ir = compact!(ir) - Core.Compiler.verify_ir(ir, true, true) + CC.verify_ir(ir, true, true) return ir end diff --git a/src/debugutils.jl b/src/debugutils.jl index cd0ca769..74b2ad74 100644 --- a/src/debugutils.jl +++ b/src/debugutils.jl @@ -1,4 +1,4 @@ -using Core.Compiler: AbstractInterpreter, CodeInstance, MethodInstance, WorldView, NativeInterpreter +using .CC: AbstractInterpreter, CodeInstance, MethodInstance, WorldView, NativeInterpreter using InteractiveUtils function infer_function(interp, tt) @@ -13,23 +13,23 @@ function infer_function(interp, tt) mtypes, msp, m = mthds[1] # Grab the appropriate method instance for these types - mi = Core.Compiler.specialize_method(m, mtypes, msp) + mi = CC.specialize_method(m, mtypes, msp) # Construct InferenceResult to hold the result, - result = Core.Compiler.InferenceResult(mi) + result = CC.InferenceResult(mi) # Create an InferenceState to begin inference, give it a world that is always newest - frame = Core.Compiler.InferenceState(result, #=cached=# true, interp) + frame = CC.InferenceState(result, #=cached=# true, interp) # Run type inference on this frame. Because the interpreter is embedded # within this InferenceResult, we don't need to pass the interpreter in. - Core.Compiler.typeinf(interp, frame) + CC.typeinf(interp, frame) # Give the result back return (mi, result) end -struct ExtractingInterpreter <: Core.Compiler.AbstractInterpreter +struct ExtractingInterpreter <: CC.AbstractInterpreter code::Dict{MethodInstance, CodeInstance} native_interpreter::NativeInterpreter msgs::Vector{Tuple{MethodInstance, Int, String}} @@ -43,7 +43,7 @@ ExtractingInterpreter(;optimize=false) = ExtractingInterpreter( optimize ) -import Core.Compiler: InferenceParams, OptimizationParams, #=get_inference_world,=# +import .CC: InferenceParams, OptimizationParams, #=get_inference_world,=# get_inference_cache, code_cache, WorldView, lock_mi_inference, unlock_mi_inference, InferenceState InferenceParams(ei::ExtractingInterpreter) = InferenceParams(ei.native_interpreter) @@ -56,18 +56,14 @@ lock_mi_inference(ei::ExtractingInterpreter, mi::MethodInstance) = nothing unlock_mi_inference(ei::ExtractingInterpreter, mi::MethodInstance) = nothing code_cache(ei::ExtractingInterpreter) = ei.code -Core.Compiler.get(a::Dict, b, c) = Base.get(a,b,c) -Core.Compiler.get(a::WorldView{<:Dict}, b, c) = Base.get(a.cache,b,c) -Core.Compiler.haskey(a::Dict, b) = Base.haskey(a, b) -Core.Compiler.haskey(a::WorldView{<:Dict}, b) = - Core.Compiler.haskey(a.cache, b) -Core.Compiler.setindex!(a::Dict, b, c) = setindex!(a, b, c) -Core.Compiler.may_optimize(ei::ExtractingInterpreter) = ei.optimize -Core.Compiler.may_compress(ei::ExtractingInterpreter) = false -Core.Compiler.may_discard_trees(ei::ExtractingInterpreter) = false - -function Core.Compiler.add_remark!(ei::ExtractingInterpreter, sv::InferenceState, msg) - @show msg +CC.get(a::WorldView{<:Dict}, b, c) = Base.get(a.cache,b,c) +CC.haskey(a::WorldView{<:Dict}, b) = + CC.haskey(a.cache, b) +CC.may_optimize(ei::ExtractingInterpreter) = ei.optimize +CC.may_compress(ei::ExtractingInterpreter) = false +CC.may_discard_trees(ei::ExtractingInterpreter) = false + +function CC.add_remark!(ei::ExtractingInterpreter, sv::InferenceState, msg) push!(ei.msgs, (sv.linfo, sv.currpc, msg)) end diff --git a/src/extra_rules.jl b/src/extra_rules.jl index 0ff826f6..5dab4dd1 100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -255,7 +255,7 @@ end @ChainRules.non_differentiable Base.:(|)(a::Integer, b::Integer) @ChainRules.non_differentiable Base.throw(err) -@ChainRules.non_differentiable Core.Compiler.return_type(args...) +@ChainRules.non_differentiable CC.return_type(args...) ChainRulesCore.canonicalize(::NoTangent) = NoTangent() # Disable thunking at higher order (TODO: These should go into ChainRulesCore) @@ -294,7 +294,7 @@ Base.:(==)(::ZeroTangent, x::Number) = iszero(x) Base.hash(x::ZeroTangent, h::UInt64) = hash(0, h) # should this be in ChainRules/ChainRulesCore? -# Avoid making nested backings, a Tangent is already a valid Tangent for a Tangent, +# Avoid making nested backings, a Tangent is already a valid Tangent for a Tangent, # or a valid second order Tangent for the primal function ChainRulesCore.frule((_, ẋ), T::Type{<:Tangent}, x) ẋ::Tangent diff --git a/src/stage1/compiler_utils.jl b/src/stage1/compiler_utils.jl index 4f80051e..93fb30c1 100644 --- a/src/stage1/compiler_utils.jl +++ b/src/stage1/compiler_utils.jl @@ -1,5 +1,5 @@ # Utilities that should probably go into CC -using .CC: IRCode, CFG, BasicBlock, BBIdxIter +using .Compiler: IRCode, CFG, BasicBlock, BBIdxIter function Base.push!(cfg::CFG, bb::BasicBlock) @assert cfg.blocks[end].stmts.stop+1 == bb.stmts.start @@ -11,13 +11,37 @@ if VERSION < v"1.11.0-DEV.258" Base.getindex(ir::IRCode, ssa::SSAValue) = CC.getindex(ir, ssa) end -if isdefined(CC, :Future) - Base.isready(future::CC.Future) = CC.isready(future) - Base.getindex(future::CC.Future) = CC.getindex(future) - Base.setindex!(future::CC.Future, value) = CC.setindex!(future, value) -end +if VERSION < v"1.12.0-DEV.1268" + if isdefined(CC, :Future) + Base.isready(future::CC.Future) = CC.isready(future) + Base.getindex(future::CC.Future) = CC.getindex(future) + Base.setindex!(future::CC.Future, value) = CC.setindex!(future, value) + end -Base.copy(ir::IRCode) = CC.copy(ir) + Base.iterate(c::IncrementalCompact, args...) = CC.iterate(c, args...) + Base.iterate(p::CC.Pair, args...) = CC.iterate(p, args...) + Base.iterate(urs::CC.UseRefIterator, args...) = CC.iterate(urs, args...) + Base.iterate(x::CC.BBIdxIter, args...) = CC.iterate(x, args...) + Base.getindex(urs::CC.UseRefIterator, args...) = CC.getindex(urs, args...) + Base.getindex(urs::CC.UseRef, args...) = CC.getindex(urs, args...) + Base.getindex(c::CC.IncrementalCompact, args...) = CC.getindex(c, args...) + Base.setindex!(c::CC.IncrementalCompact, args...) = CC.setindex!(c, args...) + Base.setindex!(urs::CC.UseRef, args...) = CC.setindex!(urs, args...) + + Base.copy(ir::IRCode) = CC.copy(ir) + + CC.BasicBlock(x::UnitRange) = + BasicBlock(StmtRange(first(x), last(x))) + CC.BasicBlock(x::UnitRange, preds::Vector{Int}, succs::Vector{Int}) = + BasicBlock(StmtRange(first(x), last(x)), preds, succs) + Base.length(c::CC.NewNodeStream) = CC.length(c) + Base.setindex!(i::Instruction, args...) = CC.setindex!(i, args...) + Base.size(x::CC.UnitRange) = CC.size(x) + + CC.get(a::Dict, b, c) = Base.get(a,b,c) + CC.haskey(a::Dict, b) = Base.haskey(a, b) + CC.setindex!(a::Dict, b, c) = setindex!(a, b, c) +end CC.NewInstruction(@nospecialize node) = NewInstruction(node, Any, CC.NoCallInfo(), nothing, CC.IR_FLAG_REFINED) diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index 07f5d59b..a1b861fa 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -30,8 +30,8 @@ function perform_optic_transform(world::UInt, source::LineNumberNode, end match = only(mthds)::Core.MethodMatch - mi = Core.Compiler.specialize_method(match) - ci = Core.Compiler.retrieve_code_info(mi, world) + mi = CC.specialize_method(match) + ci = CC.retrieve_code_info(mi, world) if ci === nothing # Failed to retrieve source - likely a generated function that errors. # To aid the user in debugging, run the original call in the forward pass and if that diff --git a/src/stage1/hacks.jl b/src/stage1/hacks.jl index 18fba64c..9425672f 100644 --- a/src/stage1/hacks.jl +++ b/src/stage1/hacks.jl @@ -1,18 +1,10 @@ # Updated copy of the same code in Base, but with bugs fixed -using Core.Compiler: +using .CC: NewSSAValue, OldSSAValue, StmtRange, BasicBlock, count_added_node!, add_pending! # Re-named in https://github.com/JuliaLang/julia/pull/47051 -const add! = Core.Compiler.add_inst! - -Base.length(c::Core.Compiler.NewNodeStream) = Core.Compiler.length(c) -Base.setindex!(i::Instruction, args...) = Core.Compiler.setindex!(i, args...) -Core.Compiler.BasicBlock(x::UnitRange) = - BasicBlock(StmtRange(first(x), last(x))) -Core.Compiler.BasicBlock(x::UnitRange, preds::Vector{Int}, succs::Vector{Int}) = - BasicBlock(StmtRange(first(x), last(x)), preds, succs) -Base.size(x::Core.Compiler.UnitRange) = Core.Compiler.size(x) +const add! = CC.add_inst! function my_insert_node!(compact::IncrementalCompact, before, inst::NewInstruction, attach_after::Bool=false) @assert inst.effect_free_computed diff --git a/src/stage1/recurse.jl b/src/stage1/recurse.jl index a1240418..1e09e881 100644 --- a/src/stage1/recurse.jl +++ b/src/stage1/recurse.jl @@ -1,5 +1,5 @@ using Core.IR -using Core.Compiler: +using .CC: BasicBlock, CFG, IRCode, IncrementalCompact, Instruction, NewInstruction, NoCallInfo, StmtRange, bbidxiter, cfg_delete_edge!, cfg_insert_edge!, compute_basic_blocks, complete, construct_domtree, construct_ssa!, domsort_ssa!, finish, insert_node!, @@ -7,9 +7,9 @@ using Core.Compiler: scan_slot_def_use, userefs, SimpleInferenceLattice if VERSION < v"1.11.0-DEV.1351" - using Core.Compiler: effect_free_and_nothrow as removable_if_unused + using .CC: effect_free_and_nothrow as removable_if_unused else - using Core.Compiler: removable_if_unused + using .CC: removable_if_unused end using Base.Meta @@ -93,7 +93,7 @@ function expand_switch(code::Vector{Any}, bb_ranges::Vector{UnitRange{Int}}, slo # Now rewrite branch targets back to statement indexing for i = 1:length(new_code) stmt = new_code[i] - stmt = Core.Compiler.renumber_ssa!(stmt, renumber) + stmt = CC.renumber_ssa!(stmt, renumber) stmt = new_to_regular(stmt) if isa(stmt, GotoNode) stmt = GotoNode(renumber[first(bb_ranges[stmt.label])].id) @@ -239,19 +239,9 @@ function split_critical_edges!(ir) return ir′ end -Base.iterate(c::IncrementalCompact, args...) = Core.Compiler.iterate(c, args...) -Base.iterate(p::Core.Compiler.Pair, args...) = Core.Compiler.iterate(p, args...) -Base.iterate(urs::Core.Compiler.UseRefIterator, args...) = Core.Compiler.iterate(urs, args...) -Base.iterate(x::Core.Compiler.BBIdxIter, args...) = Core.Compiler.iterate(x, args...) -Base.getindex(urs::Core.Compiler.UseRefIterator, args...) = Core.Compiler.getindex(urs, args...) -Base.getindex(urs::Core.Compiler.UseRef, args...) = Core.Compiler.getindex(urs, args...) -Base.getindex(c::Core.Compiler.IncrementalCompact, args...) = Core.Compiler.getindex(c, args...) -Base.setindex!(c::Core.Compiler.IncrementalCompact, args...) = Core.Compiler.setindex!(c, args...) -Base.setindex!(urs::Core.Compiler.UseRef, args...) = Core.Compiler.setindex!(urs, args...) - -import Core.Compiler: VarState +import .CC: VarState function sptypes(sparams) - VarState[Core.Compiler.VarState.(sparams, false)...] + VarState[CC.VarState.(sparams, false)...] end function optic_transform(ci::CodeInfo, args...) @@ -277,12 +267,12 @@ function optic_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int) argtypes = Any[Any for i = 1:2] meta = Expr[] @static if VERSION ≥ v"1.12.0-DEV.173" - debuginfo = Core.Compiler.DebugInfoStream(mi, ci.debuginfo, length(code)) - stmts = Core.Compiler.InstructionStream(code, type, info, debuginfo.codelocs, flag) + debuginfo = CC.DebugInfoStream(mi, ci.debuginfo, length(code)) + stmts = CC.InstructionStream(code, type, info, debuginfo.codelocs, flag) ir = IRCode(stmts, cfg, debuginfo, argtypes, meta, sptypes(sparams)) else linetable = Core.LineInfoNode[ci.linetable...] - stmts = Core.Compiler.InstructionStream(code, type, info, ci.codelocs, flag) + stmts = CC.InstructionStream(code, type, info, ci.codelocs, flag) ir = IRCode(stmts, cfg, linetable, argtypes, meta, sptypes(sparams)) end @@ -307,7 +297,7 @@ function optic_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int) ir = diffract_ir!(ir, ci, meth, sparams, nargs, N) - Core.Compiler.replace_code_newstyle!(ci, ir) + CC.replace_code_newstyle!(ci, ir) ci.ssavaluetypes = length(ci.code) ci.ssaflags = SSAFlagType[zero(SSAFlagType) for i=1:length(ci.code)] diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index b73280da..954a98eb 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -81,7 +81,7 @@ end function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int, E) new_code = Any[] @static if VERSION ≥ v"1.12.0-DEV.173" - debuginfo = Core.Compiler.DebugInfoStream(mi, ci.debuginfo, length(ci.code)) + debuginfo = CC.DebugInfoStream(mi, ci.debuginfo, length(ci.code)) new_codelocs = Int32[] else new_codelocs = Any[] @@ -243,8 +243,8 @@ function perform_fwd_transform(world::UInt, source::LineNumberNode, end match = only(mthds)::Core.MethodMatch - mi = Core.Compiler.specialize_method(match) - ci = Core.Compiler.retrieve_code_info(mi, world) + mi = CC.specialize_method(match) + ci = CC.retrieve_code_info(mi, world) return fwd_transform(ci, mi, length(args)-1, N, E) end diff --git a/src/stage1/termination.jl b/src/stage1/termination.jl index ac870f1c..4714a721 100644 --- a/src/stage1/termination.jl +++ b/src/stage1/termination.jl @@ -14,7 +14,7 @@ first(methods(Diffractor.∂⃖recurse{1}())).recursion_relation = function(meth # TODO: What if method2 is itself a generated function. return method2.recursion_relation(method2, nothing, wrapped_parent_sig, wrapped_new_sig) end - return Core.Compiler.type_more_complex(new_sig, parent_sig, Core.svec(parent_sig), 1, 3, length(method1.sig.parameters)+1) + return CC.type_more_complex(new_sig, parent_sig, Core.svec(parent_sig), 1, 3, length(method1.sig.parameters)+1) end first(methods(PrimeDerivativeBack(sin))).recursion_relation = function(method1, method2, parent_sig, new_sig) @@ -27,7 +27,7 @@ first(methods(PrimeDerivativeBack(sin))).recursion_relation = function(method1, end wrapped_parent_sig = Tuple{parent_sig.parameters[2:end]...} wrapped_new_sig = Tuple{parent_sig.parameters[2:end]...} - return Core.Compiler.type_more_complex(new_sig, parent_sig, Core.svec(parent_sig), 1, 3, length(method1.sig.parameters)+1) + return CC.type_more_complex(new_sig, parent_sig, Core.svec(parent_sig), 1, 3, length(method1.sig.parameters)+1) end which(Tuple{∂⃖{N}, T, Vararg{Any}} where {T,N}).recursion_relation = function(_, _, parent_sig, new_sig) diff --git a/src/stage2/forward.jl b/src/stage2/forward.jl index 2e825602..dd0bfdb1 100644 --- a/src/stage2/forward.jl +++ b/src/stage2/forward.jl @@ -22,7 +22,7 @@ end function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1; eras_mode = false) interp = ADInterpreter(; forward=true, backward=false) match = Base._which(tt) - frame = Core.Compiler.typeinf_frame(interp, match.method, match.spec_types, match.sparams, #=run_optimizer=#true) + frame = CC.typeinf_frame(interp, match.method, match.spec_types, match.sparams, #=run_optimizer=#true) mi = frame.linfo src = CC.copy(interp.unopt[0][mi].src) diff --git a/src/stage2/lattice.jl b/src/stage2/lattice.jl index 34f193ba..0381ff26 100644 --- a/src/stage2/lattice.jl +++ b/src/stage2/lattice.jl @@ -1,5 +1,5 @@ -using Core.Compiler: CallInfo, CallMeta -import Core.Compiler: widenconst +using .CC: CallInfo, CallMeta +import .CC: widenconst struct CompClosure; opaque; end # TODO: Is this a YAKC? (::CompClosure)(x) = error("Hello") diff --git a/test/forward_diff_no_inf.jl b/test/forward_diff_no_inf.jl index 4ff4b0f3..a3f62b82 100644 --- a/test/forward_diff_no_inf.jl +++ b/test/forward_diff_no_inf.jl @@ -1,9 +1,9 @@ module forward_diff_no_inf using Core: SSAValue - const CC = Core.Compiler using Diffractor, Test + const CC = Diffractor.CC ##################### Helpers: @@ -30,7 +30,7 @@ module forward_diff_no_inf # For testing purposes we are going to refine everything else ir[SSAValue(i)][:flag] |= CC.IR_FLAG_REFINED end - + method_info = CC.MethodInfo(#=propagate_inbounds=#true, nothing) min_world = world = (interp).world max_world = Diffractor.get_world_counter() @@ -61,7 +61,7 @@ module forward_diff_no_inf if predicate(inst) return SSAValue(ii) end - catch + catch # ignore errors so predicate can be simple end end @@ -69,7 +69,7 @@ module forward_diff_no_inf end ############################### Actual tests: - + @testset "Constructors in forward_diff_no_inf!" begin struct Bar148 v @@ -142,7 +142,7 @@ module forward_diff_no_inf @test isfully_inferred(ir) # passes with and without eras mode add_ssa = findfirst_ssa(x->x.args[1].name==:+, ir) - Diffractor.forward_diff_no_inf!(ir, [add_ssa] .=> 1; transform! = identity_transform!, eras_mode) + Diffractor.forward_diff_no_inf!(ir, [add_ssa] .=> 1; transform! = identity_transform!, eras_mode) ir = CC.compact!(ir) infer_ir!(ir) CC.verify_ir(ir)