Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust to compiler excision #297

Merged
merged 1 commit into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
name = "Diffractor"
uuid = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
authors = ["Keno Fischer <[email protected]> and contributors"]
version = "0.2.10"
authors = ["Keno Fischer <[email protected]> and contributors"]

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Compiler = "807dbc54-b67e-4c79-8afb-eafe4df6f2e1"
Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Expand All @@ -20,6 +21,7 @@ AbstractDifferentiation = "0.5, 0.6"
ChainRules = "1.44.6"
ChainRulesCore = "1.20"
Combinatorics = "1"
Compiler = "0.0.1"
Cthulhu = "2.10.1"
OffsetArrays = "1"
PrecompileTools = "1"
Expand Down
5 changes: 5 additions & 0 deletions src/Diffractor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@ export ∂⃖, gradient
using StructArrays
using PrecompileTools

if VERSION ≥ v"1.12.0-DEV.1581"
import Compiler
const CC = Compiler
else
const CC = Core.Compiler
end
using Core.IR

@static if VERSION ≥ v"1.11.0-DEV.1498"
Expand Down
4 changes: 2 additions & 2 deletions src/analysis/forward.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using Core.Compiler: StmtInfo, ArgInfo, CallMeta, AbsIntState
using .CC: StmtInfo, ArgInfo, CallMeta, AbsIntState

if VERSION >= v"1.12.0-DEV.1268"

using Core.Compiler: Future
using .CC: Future

function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::Future{CallMeta})
Expand Down
12 changes: 6 additions & 6 deletions src/codegen/forward_demand.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Core.Compiler: IRInterpretationState, construct_postdomtree, PiNode,
using .CC: IRInterpretationState, construct_postdomtree, PiNode,
is_known_call, argextype, postdominates, userefs, PhiCNode, UpsilonNode

#=
Expand Down Expand Up @@ -55,7 +55,7 @@ function forward_diff!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpre
end

function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpretationState,
ssa::SSAValue, inst::Core.Compiler.Instruction, order::Int;
ssa::SSAValue, inst::CC.Instruction, order::Int;
custom_diff!, diff_cache, eras_mode)
stmt = inst[:inst]
recurse(x) = forward_diff!(ir, interp, irsv, x, order; custom_diff!, diff_cache, eras_mode)
Expand Down Expand Up @@ -94,12 +94,12 @@ function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::I
Δbacking = insert_node!(ir, ssa, NewInstruction(Δtpl, tup_typ))
newT = argextype(stmt.args[1], ir)
@assert isa(newT, Const)
tup_typ_typ = Core.Compiler.typeof_tfunc(tup_typ)
tup_typ_typ = CC.typeof_tfunc(tup_typ)
if !(newT.val <: Tuple)
tup_typ_typ = Core.Compiler.apply_type_tfunc(Const(NamedTuple{fieldnames(newT.val)}), tup_typ_typ)
tup_typ_typ = CC.apply_type_tfunc(Const(NamedTuple{fieldnames(newT.val)}), tup_typ_typ)
Δbacking = insert_node!(ir, ssa, NewInstruction(Expr(:splatnew, widenconst(tup_typ), Δbacking), tup_typ_typ.val))
end
tangentT = Core.Compiler.apply_type_tfunc(Const(ChainRulesCore.Tangent), newT, tup_typ_typ).val
tangentT = CC.apply_type_tfunc(Const(ChainRulesCore.Tangent), newT, tup_typ_typ).val
Δtangent = insert_node!(ir, ssa, NewInstruction(Expr(:new, tangentT, Δbacking), tangentT))
return Δtangent
else # general frule handling
Expand All @@ -124,7 +124,7 @@ function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::I

# Now do proper type inference with the known arguments
interp′ = disable_forward(interp)
new_frame = Core.Compiler.typeinf_frame(interp′, new_match.method, new_match.spec_types, new_match.sparams, #=run_optimizer=#true)
new_frame = CC.typeinf_frame(interp′, new_match.method, new_match.spec_types, new_match.sparams, #=run_optimizer=#true)

# Create :invoke expression for the newly inferred frule
frule_mi = CC.EscapeAnalysis.analyze_match(new_match, length(args)+2)
Expand Down
8 changes: 4 additions & 4 deletions src/codegen/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
end

@static if VERSION ≥ v"1.12.0-DEV.173"
debuginfo = Core.Compiler.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code))
debuginfo = CC.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code))
debuginfo.def = :var"N/A"
opaque_ci.debuginfo = Core.DebugInfo(debuginfo, length(code))
else
Expand Down Expand Up @@ -501,7 +501,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
end

@static if VERSION ≥ v"1.12.0-DEV.173"
debuginfo = Core.Compiler.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code))
debuginfo = CC.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code))
debuginfo.def = :var"N/A"
opaque_ci.debuginfo = Core.DebugInfo(debuginfo, length(code))
else
Expand Down Expand Up @@ -533,7 +533,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
end

if interp !== nothing
new_argtypes = Any[Const(∂⃖recurse), tuple_tfunc(Core.Compiler.optimizer_lattice(interp), ir.argtypes[1:nfixedargs])]
new_argtypes = Any[Const(∂⃖recurse), tuple_tfunc(CC.optimizer_lattice(interp), ir.argtypes[1:nfixedargs])]
empty!(ir.argtypes)
append!(ir.argtypes, new_argtypes)
end
Expand Down Expand Up @@ -672,7 +672,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
ir = complete(compact)
#@show ir
ir = compact!(ir)
Core.Compiler.verify_ir(ir, true, true)
CC.verify_ir(ir, true, true)

return ir
end
34 changes: 15 additions & 19 deletions src/debugutils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Core.Compiler: AbstractInterpreter, CodeInstance, MethodInstance, WorldView, NativeInterpreter
using .CC: AbstractInterpreter, CodeInstance, MethodInstance, WorldView, NativeInterpreter
using InteractiveUtils

function infer_function(interp, tt)
Expand All @@ -13,23 +13,23 @@ function infer_function(interp, tt)
mtypes, msp, m = mthds[1]

# Grab the appropriate method instance for these types
mi = Core.Compiler.specialize_method(m, mtypes, msp)
mi = CC.specialize_method(m, mtypes, msp)

# Construct InferenceResult to hold the result,
result = Core.Compiler.InferenceResult(mi)
result = CC.InferenceResult(mi)

# Create an InferenceState to begin inference, give it a world that is always newest
frame = Core.Compiler.InferenceState(result, #=cached=# true, interp)
frame = CC.InferenceState(result, #=cached=# true, interp)

# Run type inference on this frame. Because the interpreter is embedded
# within this InferenceResult, we don't need to pass the interpreter in.
Core.Compiler.typeinf(interp, frame)
CC.typeinf(interp, frame)

# Give the result back
return (mi, result)
end

struct ExtractingInterpreter <: Core.Compiler.AbstractInterpreter
struct ExtractingInterpreter <: CC.AbstractInterpreter
code::Dict{MethodInstance, CodeInstance}
native_interpreter::NativeInterpreter
msgs::Vector{Tuple{MethodInstance, Int, String}}
Expand All @@ -43,7 +43,7 @@ ExtractingInterpreter(;optimize=false) = ExtractingInterpreter(
optimize
)

import Core.Compiler: InferenceParams, OptimizationParams, #=get_inference_world,=#
import .CC: InferenceParams, OptimizationParams, #=get_inference_world,=#
get_inference_cache, code_cache,
WorldView, lock_mi_inference, unlock_mi_inference, InferenceState
InferenceParams(ei::ExtractingInterpreter) = InferenceParams(ei.native_interpreter)
Expand All @@ -56,18 +56,14 @@ lock_mi_inference(ei::ExtractingInterpreter, mi::MethodInstance) = nothing
unlock_mi_inference(ei::ExtractingInterpreter, mi::MethodInstance) = nothing

code_cache(ei::ExtractingInterpreter) = ei.code
Core.Compiler.get(a::Dict, b, c) = Base.get(a,b,c)
Core.Compiler.get(a::WorldView{<:Dict}, b, c) = Base.get(a.cache,b,c)
Core.Compiler.haskey(a::Dict, b) = Base.haskey(a, b)
Core.Compiler.haskey(a::WorldView{<:Dict}, b) =
Core.Compiler.haskey(a.cache, b)
Core.Compiler.setindex!(a::Dict, b, c) = setindex!(a, b, c)
Core.Compiler.may_optimize(ei::ExtractingInterpreter) = ei.optimize
Core.Compiler.may_compress(ei::ExtractingInterpreter) = false
Core.Compiler.may_discard_trees(ei::ExtractingInterpreter) = false

function Core.Compiler.add_remark!(ei::ExtractingInterpreter, sv::InferenceState, msg)
@show msg
CC.get(a::WorldView{<:Dict}, b, c) = Base.get(a.cache,b,c)
CC.haskey(a::WorldView{<:Dict}, b) =
CC.haskey(a.cache, b)
CC.may_optimize(ei::ExtractingInterpreter) = ei.optimize
CC.may_compress(ei::ExtractingInterpreter) = false
CC.may_discard_trees(ei::ExtractingInterpreter) = false

function CC.add_remark!(ei::ExtractingInterpreter, sv::InferenceState, msg)
push!(ei.msgs, (sv.linfo, sv.currpc, msg))
end

Expand Down
4 changes: 2 additions & 2 deletions src/extra_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ end

@ChainRules.non_differentiable Base.:(|)(a::Integer, b::Integer)
@ChainRules.non_differentiable Base.throw(err)
@ChainRules.non_differentiable Core.Compiler.return_type(args...)
@ChainRules.non_differentiable CC.return_type(args...)
ChainRulesCore.canonicalize(::NoTangent) = NoTangent()

# Disable thunking at higher order (TODO: These should go into ChainRulesCore)
Expand Down Expand Up @@ -294,7 +294,7 @@ Base.:(==)(::ZeroTangent, x::Number) = iszero(x)
Base.hash(x::ZeroTangent, h::UInt64) = hash(0, h)

# should this be in ChainRules/ChainRulesCore?
# Avoid making nested backings, a Tangent is already a valid Tangent for a Tangent,
# Avoid making nested backings, a Tangent is already a valid Tangent for a Tangent,
# or a valid second order Tangent for the primal
function ChainRulesCore.frule((_, ẋ), T::Type{<:Tangent}, x)
ẋ::Tangent
Expand Down
38 changes: 31 additions & 7 deletions src/stage1/compiler_utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Utilities that should probably go into CC
using .CC: IRCode, CFG, BasicBlock, BBIdxIter
using .Compiler: IRCode, CFG, BasicBlock, BBIdxIter

function Base.push!(cfg::CFG, bb::BasicBlock)
@assert cfg.blocks[end].stmts.stop+1 == bb.stmts.start
Expand All @@ -11,13 +11,37 @@ if VERSION < v"1.11.0-DEV.258"
Base.getindex(ir::IRCode, ssa::SSAValue) = CC.getindex(ir, ssa)
end

if isdefined(CC, :Future)
Base.isready(future::CC.Future) = CC.isready(future)
Base.getindex(future::CC.Future) = CC.getindex(future)
Base.setindex!(future::CC.Future, value) = CC.setindex!(future, value)
end
if VERSION < v"1.12.0-DEV.1268"
if isdefined(CC, :Future)
Base.isready(future::CC.Future) = CC.isready(future)
Base.getindex(future::CC.Future) = CC.getindex(future)
Base.setindex!(future::CC.Future, value) = CC.setindex!(future, value)
end

Base.copy(ir::IRCode) = CC.copy(ir)
Base.iterate(c::IncrementalCompact, args...) = CC.iterate(c, args...)
Base.iterate(p::CC.Pair, args...) = CC.iterate(p, args...)
Base.iterate(urs::CC.UseRefIterator, args...) = CC.iterate(urs, args...)
Base.iterate(x::CC.BBIdxIter, args...) = CC.iterate(x, args...)
Base.getindex(urs::CC.UseRefIterator, args...) = CC.getindex(urs, args...)
Base.getindex(urs::CC.UseRef, args...) = CC.getindex(urs, args...)
Base.getindex(c::CC.IncrementalCompact, args...) = CC.getindex(c, args...)
Base.setindex!(c::CC.IncrementalCompact, args...) = CC.setindex!(c, args...)
Base.setindex!(urs::CC.UseRef, args...) = CC.setindex!(urs, args...)

Base.copy(ir::IRCode) = CC.copy(ir)

CC.BasicBlock(x::UnitRange) =
BasicBlock(StmtRange(first(x), last(x)))
CC.BasicBlock(x::UnitRange, preds::Vector{Int}, succs::Vector{Int}) =
BasicBlock(StmtRange(first(x), last(x)), preds, succs)
Base.length(c::CC.NewNodeStream) = CC.length(c)
Base.setindex!(i::Instruction, args...) = CC.setindex!(i, args...)
Base.size(x::CC.UnitRange) = CC.size(x)

CC.get(a::Dict, b, c) = Base.get(a,b,c)
CC.haskey(a::Dict, b) = Base.haskey(a, b)
CC.setindex!(a::Dict, b, c) = setindex!(a, b, c)
end

CC.NewInstruction(@nospecialize node) =
NewInstruction(node, Any, CC.NoCallInfo(), nothing, CC.IR_FLAG_REFINED)
Expand Down
4 changes: 2 additions & 2 deletions src/stage1/generated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ function perform_optic_transform(world::UInt, source::LineNumberNode,
end
match = only(mthds)::Core.MethodMatch

mi = Core.Compiler.specialize_method(match)
ci = Core.Compiler.retrieve_code_info(mi, world)
mi = CC.specialize_method(match)
ci = CC.retrieve_code_info(mi, world)
if ci === nothing
# Failed to retrieve source - likely a generated function that errors.
# To aid the user in debugging, run the original call in the forward pass and if that
Expand Down
12 changes: 2 additions & 10 deletions src/stage1/hacks.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
# Updated copy of the same code in Base, but with bugs fixed
using Core.Compiler:
using .CC:
NewSSAValue, OldSSAValue, StmtRange, BasicBlock,
count_added_node!, add_pending!

# Re-named in https://github.com/JuliaLang/julia/pull/47051
const add! = Core.Compiler.add_inst!

Base.length(c::Core.Compiler.NewNodeStream) = Core.Compiler.length(c)
Base.setindex!(i::Instruction, args...) = Core.Compiler.setindex!(i, args...)
Core.Compiler.BasicBlock(x::UnitRange) =
BasicBlock(StmtRange(first(x), last(x)))
Core.Compiler.BasicBlock(x::UnitRange, preds::Vector{Int}, succs::Vector{Int}) =
BasicBlock(StmtRange(first(x), last(x)), preds, succs)
Base.size(x::Core.Compiler.UnitRange) = Core.Compiler.size(x)
const add! = CC.add_inst!

function my_insert_node!(compact::IncrementalCompact, before, inst::NewInstruction, attach_after::Bool=false)
@assert inst.effect_free_computed
Expand Down
30 changes: 10 additions & 20 deletions src/stage1/recurse.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
using Core.IR
using Core.Compiler:
using .CC:
BasicBlock, CFG, IRCode, IncrementalCompact, Instruction, NewInstruction, NoCallInfo, StmtRange,
bbidxiter, cfg_delete_edge!, cfg_insert_edge!, compute_basic_blocks, complete,
construct_domtree, construct_ssa!, domsort_ssa!, finish, insert_node!,
insert_node_here!, non_dce_finish!, quoted, retrieve_code_info,
scan_slot_def_use, userefs, SimpleInferenceLattice

if VERSION < v"1.11.0-DEV.1351"
using Core.Compiler: effect_free_and_nothrow as removable_if_unused
using .CC: effect_free_and_nothrow as removable_if_unused
else
using Core.Compiler: removable_if_unused
using .CC: removable_if_unused
end

using Base.Meta
Expand Down Expand Up @@ -93,7 +93,7 @@ function expand_switch(code::Vector{Any}, bb_ranges::Vector{UnitRange{Int}}, slo
# Now rewrite branch targets back to statement indexing
for i = 1:length(new_code)
stmt = new_code[i]
stmt = Core.Compiler.renumber_ssa!(stmt, renumber)
stmt = CC.renumber_ssa!(stmt, renumber)
stmt = new_to_regular(stmt)
if isa(stmt, GotoNode)
stmt = GotoNode(renumber[first(bb_ranges[stmt.label])].id)
Expand Down Expand Up @@ -239,19 +239,9 @@ function split_critical_edges!(ir)
return ir′
end

Base.iterate(c::IncrementalCompact, args...) = Core.Compiler.iterate(c, args...)
Base.iterate(p::Core.Compiler.Pair, args...) = Core.Compiler.iterate(p, args...)
Base.iterate(urs::Core.Compiler.UseRefIterator, args...) = Core.Compiler.iterate(urs, args...)
Base.iterate(x::Core.Compiler.BBIdxIter, args...) = Core.Compiler.iterate(x, args...)
Base.getindex(urs::Core.Compiler.UseRefIterator, args...) = Core.Compiler.getindex(urs, args...)
Base.getindex(urs::Core.Compiler.UseRef, args...) = Core.Compiler.getindex(urs, args...)
Base.getindex(c::Core.Compiler.IncrementalCompact, args...) = Core.Compiler.getindex(c, args...)
Base.setindex!(c::Core.Compiler.IncrementalCompact, args...) = Core.Compiler.setindex!(c, args...)
Base.setindex!(urs::Core.Compiler.UseRef, args...) = Core.Compiler.setindex!(urs, args...)

import Core.Compiler: VarState
import .CC: VarState
function sptypes(sparams)
VarState[Core.Compiler.VarState.(sparams, false)...]
VarState[CC.VarState.(sparams, false)...]
end

function optic_transform(ci::CodeInfo, args...)
Expand All @@ -277,12 +267,12 @@ function optic_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int)
argtypes = Any[Any for i = 1:2]
meta = Expr[]
@static if VERSION ≥ v"1.12.0-DEV.173"
debuginfo = Core.Compiler.DebugInfoStream(mi, ci.debuginfo, length(code))
stmts = Core.Compiler.InstructionStream(code, type, info, debuginfo.codelocs, flag)
debuginfo = CC.DebugInfoStream(mi, ci.debuginfo, length(code))
stmts = CC.InstructionStream(code, type, info, debuginfo.codelocs, flag)
ir = IRCode(stmts, cfg, debuginfo, argtypes, meta, sptypes(sparams))
else
linetable = Core.LineInfoNode[ci.linetable...]
stmts = Core.Compiler.InstructionStream(code, type, info, ci.codelocs, flag)
stmts = CC.InstructionStream(code, type, info, ci.codelocs, flag)
ir = IRCode(stmts, cfg, linetable, argtypes, meta, sptypes(sparams))
end

Expand All @@ -307,7 +297,7 @@ function optic_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int)

ir = diffract_ir!(ir, ci, meth, sparams, nargs, N)

Core.Compiler.replace_code_newstyle!(ci, ir)
CC.replace_code_newstyle!(ci, ir)

ci.ssavaluetypes = length(ci.code)
ci.ssaflags = SSAFlagType[zero(SSAFlagType) for i=1:length(ci.code)]
Expand Down
6 changes: 3 additions & 3 deletions src/stage1/recurse_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ end
function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int, E)
new_code = Any[]
@static if VERSION ≥ v"1.12.0-DEV.173"
debuginfo = Core.Compiler.DebugInfoStream(mi, ci.debuginfo, length(ci.code))
debuginfo = CC.DebugInfoStream(mi, ci.debuginfo, length(ci.code))
new_codelocs = Int32[]
else
new_codelocs = Any[]
Expand Down Expand Up @@ -243,8 +243,8 @@ function perform_fwd_transform(world::UInt, source::LineNumberNode,
end
match = only(mthds)::Core.MethodMatch

mi = Core.Compiler.specialize_method(match)
ci = Core.Compiler.retrieve_code_info(mi, world)
mi = CC.specialize_method(match)
ci = CC.retrieve_code_info(mi, world)

return fwd_transform(ci, mi, length(args)-1, N, E)
end
Expand Down
Loading
Loading