Skip to content

Commit

Permalink
Adjust to compiler excision (#297)
Browse files Browse the repository at this point in the history
  • Loading branch information
Keno authored Nov 13, 2024
1 parent 06cb7dd commit 64fa61a
Show file tree
Hide file tree
Showing 16 changed files with 95 additions and 86 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
name = "Diffractor"
uuid = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
authors = ["Keno Fischer <[email protected]> and contributors"]
version = "0.2.10"
authors = ["Keno Fischer <[email protected]> and contributors"]

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Compiler = "807dbc54-b67e-4c79-8afb-eafe4df6f2e1"
Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Expand All @@ -20,6 +21,7 @@ AbstractDifferentiation = "0.5, 0.6"
ChainRules = "1.44.6"
ChainRulesCore = "1.20"
Combinatorics = "1"
Compiler = "0.0.1"
Cthulhu = "2.10.1"
OffsetArrays = "1"
PrecompileTools = "1"
Expand Down
5 changes: 5 additions & 0 deletions src/Diffractor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@ export ∂⃖, gradient
using StructArrays
using PrecompileTools

if VERSION v"1.12.0-DEV.1581"
import Compiler
const CC = Compiler
else
const CC = Core.Compiler
end
using Core.IR

@static if VERSION v"1.11.0-DEV.1498"
Expand Down
4 changes: 2 additions & 2 deletions src/analysis/forward.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using Core.Compiler: StmtInfo, ArgInfo, CallMeta, AbsIntState
using .CC: StmtInfo, ArgInfo, CallMeta, AbsIntState

if VERSION >= v"1.12.0-DEV.1268"

using Core.Compiler: Future
using .CC: Future

function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::Future{CallMeta})
Expand Down
12 changes: 6 additions & 6 deletions src/codegen/forward_demand.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Core.Compiler: IRInterpretationState, construct_postdomtree, PiNode,
using .CC: IRInterpretationState, construct_postdomtree, PiNode,
is_known_call, argextype, postdominates, userefs, PhiCNode, UpsilonNode

#=
Expand Down Expand Up @@ -55,7 +55,7 @@ function forward_diff!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpre
end

function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpretationState,
ssa::SSAValue, inst::Core.Compiler.Instruction, order::Int;
ssa::SSAValue, inst::CC.Instruction, order::Int;
custom_diff!, diff_cache, eras_mode)
stmt = inst[:inst]
recurse(x) = forward_diff!(ir, interp, irsv, x, order; custom_diff!, diff_cache, eras_mode)
Expand Down Expand Up @@ -94,12 +94,12 @@ function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::I
Δbacking = insert_node!(ir, ssa, NewInstruction(Δtpl, tup_typ))
newT = argextype(stmt.args[1], ir)
@assert isa(newT, Const)
tup_typ_typ = Core.Compiler.typeof_tfunc(tup_typ)
tup_typ_typ = CC.typeof_tfunc(tup_typ)
if !(newT.val <: Tuple)
tup_typ_typ = Core.Compiler.apply_type_tfunc(Const(NamedTuple{fieldnames(newT.val)}), tup_typ_typ)
tup_typ_typ = CC.apply_type_tfunc(Const(NamedTuple{fieldnames(newT.val)}), tup_typ_typ)
Δbacking = insert_node!(ir, ssa, NewInstruction(Expr(:splatnew, widenconst(tup_typ), Δbacking), tup_typ_typ.val))
end
tangentT = Core.Compiler.apply_type_tfunc(Const(ChainRulesCore.Tangent), newT, tup_typ_typ).val
tangentT = CC.apply_type_tfunc(Const(ChainRulesCore.Tangent), newT, tup_typ_typ).val
Δtangent = insert_node!(ir, ssa, NewInstruction(Expr(:new, tangentT, Δbacking), tangentT))
return Δtangent
else # general frule handling
Expand All @@ -124,7 +124,7 @@ function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::I

# Now do proper type inference with the known arguments
interp′ = disable_forward(interp)
new_frame = Core.Compiler.typeinf_frame(interp′, new_match.method, new_match.spec_types, new_match.sparams, #=run_optimizer=#true)
new_frame = CC.typeinf_frame(interp′, new_match.method, new_match.spec_types, new_match.sparams, #=run_optimizer=#true)

# Create :invoke expression for the newly inferred frule
frule_mi = CC.EscapeAnalysis.analyze_match(new_match, length(args)+2)
Expand Down
8 changes: 4 additions & 4 deletions src/codegen/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
end

@static if VERSION v"1.12.0-DEV.173"
debuginfo = Core.Compiler.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code))
debuginfo = CC.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code))
debuginfo.def = :var"N/A"
opaque_ci.debuginfo = Core.DebugInfo(debuginfo, length(code))
else
Expand Down Expand Up @@ -501,7 +501,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
end

@static if VERSION v"1.12.0-DEV.173"
debuginfo = Core.Compiler.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code))
debuginfo = CC.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code))
debuginfo.def = :var"N/A"
opaque_ci.debuginfo = Core.DebugInfo(debuginfo, length(code))
else
Expand Down Expand Up @@ -533,7 +533,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
end

if interp !== nothing
new_argtypes = Any[Const(∂⃖recurse), tuple_tfunc(Core.Compiler.optimizer_lattice(interp), ir.argtypes[1:nfixedargs])]
new_argtypes = Any[Const(∂⃖recurse), tuple_tfunc(CC.optimizer_lattice(interp), ir.argtypes[1:nfixedargs])]
empty!(ir.argtypes)
append!(ir.argtypes, new_argtypes)
end
Expand Down Expand Up @@ -672,7 +672,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
ir = complete(compact)
#@show ir
ir = compact!(ir)
Core.Compiler.verify_ir(ir, true, true)
CC.verify_ir(ir, true, true)

return ir
end
34 changes: 15 additions & 19 deletions src/debugutils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Core.Compiler: AbstractInterpreter, CodeInstance, MethodInstance, WorldView, NativeInterpreter
using .CC: AbstractInterpreter, CodeInstance, MethodInstance, WorldView, NativeInterpreter
using InteractiveUtils

function infer_function(interp, tt)
Expand All @@ -13,23 +13,23 @@ function infer_function(interp, tt)
mtypes, msp, m = mthds[1]

# Grab the appropriate method instance for these types
mi = Core.Compiler.specialize_method(m, mtypes, msp)
mi = CC.specialize_method(m, mtypes, msp)

# Construct InferenceResult to hold the result,
result = Core.Compiler.InferenceResult(mi)
result = CC.InferenceResult(mi)

# Create an InferenceState to begin inference, give it a world that is always newest
frame = Core.Compiler.InferenceState(result, #=cached=# true, interp)
frame = CC.InferenceState(result, #=cached=# true, interp)

# Run type inference on this frame. Because the interpreter is embedded
# within this InferenceResult, we don't need to pass the interpreter in.
Core.Compiler.typeinf(interp, frame)
CC.typeinf(interp, frame)

# Give the result back
return (mi, result)
end

struct ExtractingInterpreter <: Core.Compiler.AbstractInterpreter
struct ExtractingInterpreter <: CC.AbstractInterpreter
code::Dict{MethodInstance, CodeInstance}
native_interpreter::NativeInterpreter
msgs::Vector{Tuple{MethodInstance, Int, String}}
Expand All @@ -43,7 +43,7 @@ ExtractingInterpreter(;optimize=false) = ExtractingInterpreter(
optimize
)

import Core.Compiler: InferenceParams, OptimizationParams, #=get_inference_world,=#
import .CC: InferenceParams, OptimizationParams, #=get_inference_world,=#
get_inference_cache, code_cache,
WorldView, lock_mi_inference, unlock_mi_inference, InferenceState
InferenceParams(ei::ExtractingInterpreter) = InferenceParams(ei.native_interpreter)
Expand All @@ -56,18 +56,14 @@ lock_mi_inference(ei::ExtractingInterpreter, mi::MethodInstance) = nothing
unlock_mi_inference(ei::ExtractingInterpreter, mi::MethodInstance) = nothing

code_cache(ei::ExtractingInterpreter) = ei.code
Core.Compiler.get(a::Dict, b, c) = Base.get(a,b,c)
Core.Compiler.get(a::WorldView{<:Dict}, b, c) = Base.get(a.cache,b,c)
Core.Compiler.haskey(a::Dict, b) = Base.haskey(a, b)
Core.Compiler.haskey(a::WorldView{<:Dict}, b) =
Core.Compiler.haskey(a.cache, b)
Core.Compiler.setindex!(a::Dict, b, c) = setindex!(a, b, c)
Core.Compiler.may_optimize(ei::ExtractingInterpreter) = ei.optimize
Core.Compiler.may_compress(ei::ExtractingInterpreter) = false
Core.Compiler.may_discard_trees(ei::ExtractingInterpreter) = false

function Core.Compiler.add_remark!(ei::ExtractingInterpreter, sv::InferenceState, msg)
@show msg
CC.get(a::WorldView{<:Dict}, b, c) = Base.get(a.cache,b,c)
CC.haskey(a::WorldView{<:Dict}, b) =
CC.haskey(a.cache, b)
CC.may_optimize(ei::ExtractingInterpreter) = ei.optimize
CC.may_compress(ei::ExtractingInterpreter) = false
CC.may_discard_trees(ei::ExtractingInterpreter) = false

function CC.add_remark!(ei::ExtractingInterpreter, sv::InferenceState, msg)
push!(ei.msgs, (sv.linfo, sv.currpc, msg))
end

Expand Down
4 changes: 2 additions & 2 deletions src/extra_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ end

@ChainRules.non_differentiable Base.:(|)(a::Integer, b::Integer)
@ChainRules.non_differentiable Base.throw(err)
@ChainRules.non_differentiable Core.Compiler.return_type(args...)
@ChainRules.non_differentiable CC.return_type(args...)
ChainRulesCore.canonicalize(::NoTangent) = NoTangent()

# Disable thunking at higher order (TODO: These should go into ChainRulesCore)
Expand Down Expand Up @@ -294,7 +294,7 @@ Base.:(==)(::ZeroTangent, x::Number) = iszero(x)
Base.hash(x::ZeroTangent, h::UInt64) = hash(0, h)

# should this be in ChainRules/ChainRulesCore?
# Avoid making nested backings, a Tangent is already a valid Tangent for a Tangent,
# Avoid making nested backings, a Tangent is already a valid Tangent for a Tangent,
# or a valid second order Tangent for the primal
function ChainRulesCore.frule((_, ẋ), T::Type{<:Tangent}, x)
::Tangent
Expand Down
38 changes: 31 additions & 7 deletions src/stage1/compiler_utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Utilities that should probably go into CC
using .CC: IRCode, CFG, BasicBlock, BBIdxIter
using .Compiler: IRCode, CFG, BasicBlock, BBIdxIter

function Base.push!(cfg::CFG, bb::BasicBlock)
@assert cfg.blocks[end].stmts.stop+1 == bb.stmts.start
Expand All @@ -11,13 +11,37 @@ if VERSION < v"1.11.0-DEV.258"
Base.getindex(ir::IRCode, ssa::SSAValue) = CC.getindex(ir, ssa)
end

if isdefined(CC, :Future)
Base.isready(future::CC.Future) = CC.isready(future)
Base.getindex(future::CC.Future) = CC.getindex(future)
Base.setindex!(future::CC.Future, value) = CC.setindex!(future, value)
end
if VERSION < v"1.12.0-DEV.1268"
if isdefined(CC, :Future)
Base.isready(future::CC.Future) = CC.isready(future)
Base.getindex(future::CC.Future) = CC.getindex(future)
Base.setindex!(future::CC.Future, value) = CC.setindex!(future, value)
end

Base.copy(ir::IRCode) = CC.copy(ir)
Base.iterate(c::IncrementalCompact, args...) = CC.iterate(c, args...)
Base.iterate(p::CC.Pair, args...) = CC.iterate(p, args...)
Base.iterate(urs::CC.UseRefIterator, args...) = CC.iterate(urs, args...)
Base.iterate(x::CC.BBIdxIter, args...) = CC.iterate(x, args...)
Base.getindex(urs::CC.UseRefIterator, args...) = CC.getindex(urs, args...)
Base.getindex(urs::CC.UseRef, args...) = CC.getindex(urs, args...)
Base.getindex(c::CC.IncrementalCompact, args...) = CC.getindex(c, args...)
Base.setindex!(c::CC.IncrementalCompact, args...) = CC.setindex!(c, args...)
Base.setindex!(urs::CC.UseRef, args...) = CC.setindex!(urs, args...)

Base.copy(ir::IRCode) = CC.copy(ir)

CC.BasicBlock(x::UnitRange) =
BasicBlock(StmtRange(first(x), last(x)))
CC.BasicBlock(x::UnitRange, preds::Vector{Int}, succs::Vector{Int}) =
BasicBlock(StmtRange(first(x), last(x)), preds, succs)
Base.length(c::CC.NewNodeStream) = CC.length(c)
Base.setindex!(i::Instruction, args...) = CC.setindex!(i, args...)
Base.size(x::CC.UnitRange) = CC.size(x)

CC.get(a::Dict, b, c) = Base.get(a,b,c)
CC.haskey(a::Dict, b) = Base.haskey(a, b)
CC.setindex!(a::Dict, b, c) = setindex!(a, b, c)
end

CC.NewInstruction(@nospecialize node) =
NewInstruction(node, Any, CC.NoCallInfo(), nothing, CC.IR_FLAG_REFINED)
Expand Down
4 changes: 2 additions & 2 deletions src/stage1/generated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ function perform_optic_transform(world::UInt, source::LineNumberNode,
end
match = only(mthds)::Core.MethodMatch

mi = Core.Compiler.specialize_method(match)
ci = Core.Compiler.retrieve_code_info(mi, world)
mi = CC.specialize_method(match)
ci = CC.retrieve_code_info(mi, world)
if ci === nothing
# Failed to retrieve source - likely a generated function that errors.
# To aid the user in debugging, run the original call in the forward pass and if that
Expand Down
12 changes: 2 additions & 10 deletions src/stage1/hacks.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
# Updated copy of the same code in Base, but with bugs fixed
using Core.Compiler:
using .CC:
NewSSAValue, OldSSAValue, StmtRange, BasicBlock,
count_added_node!, add_pending!

# Re-named in https://github.com/JuliaLang/julia/pull/47051
const add! = Core.Compiler.add_inst!

Base.length(c::Core.Compiler.NewNodeStream) = Core.Compiler.length(c)
Base.setindex!(i::Instruction, args...) = Core.Compiler.setindex!(i, args...)
Core.Compiler.BasicBlock(x::UnitRange) =
BasicBlock(StmtRange(first(x), last(x)))
Core.Compiler.BasicBlock(x::UnitRange, preds::Vector{Int}, succs::Vector{Int}) =
BasicBlock(StmtRange(first(x), last(x)), preds, succs)
Base.size(x::Core.Compiler.UnitRange) = Core.Compiler.size(x)
const add! = CC.add_inst!

function my_insert_node!(compact::IncrementalCompact, before, inst::NewInstruction, attach_after::Bool=false)
@assert inst.effect_free_computed
Expand Down
30 changes: 10 additions & 20 deletions src/stage1/recurse.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
using Core.IR
using Core.Compiler:
using .CC:
BasicBlock, CFG, IRCode, IncrementalCompact, Instruction, NewInstruction, NoCallInfo, StmtRange,
bbidxiter, cfg_delete_edge!, cfg_insert_edge!, compute_basic_blocks, complete,
construct_domtree, construct_ssa!, domsort_ssa!, finish, insert_node!,
insert_node_here!, non_dce_finish!, quoted, retrieve_code_info,
scan_slot_def_use, userefs, SimpleInferenceLattice

if VERSION < v"1.11.0-DEV.1351"
using Core.Compiler: effect_free_and_nothrow as removable_if_unused
using .CC: effect_free_and_nothrow as removable_if_unused
else
using Core.Compiler: removable_if_unused
using .CC: removable_if_unused
end

using Base.Meta
Expand Down Expand Up @@ -93,7 +93,7 @@ function expand_switch(code::Vector{Any}, bb_ranges::Vector{UnitRange{Int}}, slo
# Now rewrite branch targets back to statement indexing
for i = 1:length(new_code)
stmt = new_code[i]
stmt = Core.Compiler.renumber_ssa!(stmt, renumber)
stmt = CC.renumber_ssa!(stmt, renumber)
stmt = new_to_regular(stmt)
if isa(stmt, GotoNode)
stmt = GotoNode(renumber[first(bb_ranges[stmt.label])].id)
Expand Down Expand Up @@ -239,19 +239,9 @@ function split_critical_edges!(ir)
return ir′
end

Base.iterate(c::IncrementalCompact, args...) = Core.Compiler.iterate(c, args...)
Base.iterate(p::Core.Compiler.Pair, args...) = Core.Compiler.iterate(p, args...)
Base.iterate(urs::Core.Compiler.UseRefIterator, args...) = Core.Compiler.iterate(urs, args...)
Base.iterate(x::Core.Compiler.BBIdxIter, args...) = Core.Compiler.iterate(x, args...)
Base.getindex(urs::Core.Compiler.UseRefIterator, args...) = Core.Compiler.getindex(urs, args...)
Base.getindex(urs::Core.Compiler.UseRef, args...) = Core.Compiler.getindex(urs, args...)
Base.getindex(c::Core.Compiler.IncrementalCompact, args...) = Core.Compiler.getindex(c, args...)
Base.setindex!(c::Core.Compiler.IncrementalCompact, args...) = Core.Compiler.setindex!(c, args...)
Base.setindex!(urs::Core.Compiler.UseRef, args...) = Core.Compiler.setindex!(urs, args...)

import Core.Compiler: VarState
import .CC: VarState
function sptypes(sparams)
VarState[Core.Compiler.VarState.(sparams, false)...]
VarState[CC.VarState.(sparams, false)...]
end

function optic_transform(ci::CodeInfo, args...)
Expand All @@ -277,12 +267,12 @@ function optic_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int)
argtypes = Any[Any for i = 1:2]
meta = Expr[]
@static if VERSION v"1.12.0-DEV.173"
debuginfo = Core.Compiler.DebugInfoStream(mi, ci.debuginfo, length(code))
stmts = Core.Compiler.InstructionStream(code, type, info, debuginfo.codelocs, flag)
debuginfo = CC.DebugInfoStream(mi, ci.debuginfo, length(code))
stmts = CC.InstructionStream(code, type, info, debuginfo.codelocs, flag)
ir = IRCode(stmts, cfg, debuginfo, argtypes, meta, sptypes(sparams))
else
linetable = Core.LineInfoNode[ci.linetable...]
stmts = Core.Compiler.InstructionStream(code, type, info, ci.codelocs, flag)
stmts = CC.InstructionStream(code, type, info, ci.codelocs, flag)
ir = IRCode(stmts, cfg, linetable, argtypes, meta, sptypes(sparams))
end

Expand All @@ -307,7 +297,7 @@ function optic_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int)

ir = diffract_ir!(ir, ci, meth, sparams, nargs, N)

Core.Compiler.replace_code_newstyle!(ci, ir)
CC.replace_code_newstyle!(ci, ir)

ci.ssavaluetypes = length(ci.code)
ci.ssaflags = SSAFlagType[zero(SSAFlagType) for i=1:length(ci.code)]
Expand Down
6 changes: 3 additions & 3 deletions src/stage1/recurse_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ end
function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int, E)
new_code = Any[]
@static if VERSION v"1.12.0-DEV.173"
debuginfo = Core.Compiler.DebugInfoStream(mi, ci.debuginfo, length(ci.code))
debuginfo = CC.DebugInfoStream(mi, ci.debuginfo, length(ci.code))
new_codelocs = Int32[]
else
new_codelocs = Any[]
Expand Down Expand Up @@ -243,8 +243,8 @@ function perform_fwd_transform(world::UInt, source::LineNumberNode,
end
match = only(mthds)::Core.MethodMatch

mi = Core.Compiler.specialize_method(match)
ci = Core.Compiler.retrieve_code_info(mi, world)
mi = CC.specialize_method(match)
ci = CC.retrieve_code_info(mi, world)

return fwd_transform(ci, mi, length(args)-1, N, E)
end
Expand Down
Loading

0 comments on commit 64fa61a

Please sign in to comment.