Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk authored Nov 23, 2023
1 parent d2ab53b commit cbcc0f3
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 49 deletions.
4 changes: 4 additions & 0 deletions src/analysis/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize
frule_call = CC.abstract_call_gf_by_type(interp′,
ChainRulesCore.frule, frule_arginfo, frule_si, frule_atype, sv, #=max_methods=#-1)
if frule_call.rt !== Const(nothing)
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(primal_call.rt, primal_call.exct, primal_call.effects, FRuleCallInfo(primal_call.info, frule_call))
else
return CallMeta(primal_call.rt, primal_call.effects, FRuleCallInfo(primal_call.info, frule_call))
end
else
CC.add_mt_backedge!(sv, frule_mt, frule_atype)
end
Expand Down
7 changes: 0 additions & 7 deletions src/stage1/compiler_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,6 @@ end
Base.lastindex(x::Core.Compiler.InstructionStream) =
Core.Compiler.length(x)

# Solves an error after https://github.com/JuliaLang/julia/pull/46961
# as does https://github.com/FluxML/IRTools.jl/pull/101
if isdefined(Core.Compiler, :CallInfo)
Base.convert(::Type{Core.Compiler.CallInfo}, ::Nothing) = Core.Compiler.NoCallInfo()
end


"""
find_end_of_phi_block(ir::IRCode, start_search_idx::Int)
Expand Down
8 changes: 4 additions & 4 deletions src/stage1/recurse.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Core.IR
using Core.Compiler:
Argument, BasicBlock, CFG, CodeInfo, GotoIfNot, GotoNode, IRCode, IncrementalCompact,
Instruction, MethodInstance, NewInstruction, NewvarNode, OldSSAValue, PhiNode,
ReturnNode, SSAValue, SlotNumber, StmtRange,
BasicBlock, CallInfo, CFG, IRCode, IncrementalCompact, Instruction, NewInstruction,
NoCallInfo, OldSSAValue, StmtRange,
bbidxiter, cfg_delete_edge!, cfg_insert_edge!, compute_basic_blocks, complete,
construct_domtree, construct_ssa!, domsort_ssa!, finish, insert_node!,
insert_node_here!, effect_free_and_nothrow, non_dce_finish!, quoted, retrieve_code_info,
Expand Down Expand Up @@ -266,7 +266,7 @@ function optic_transform!(ci, mi, nargs, N)

meta = Expr[]
ir = IRCode(Core.Compiler.InstructionStream(code, Any[],
Any[nothing for i = 1:length(code)],
CallInfo[NoCallInfo() for i = 1:length(code)],
ci.codelocs, UInt8[0 for i = 1:length(code)]), cfg, Core.LineInfoNode[ci.linetable...],
Any[Any for i = 1:2], meta, sptypes(sparams))

Expand Down
171 changes: 133 additions & 38 deletions src/stage2/abstractinterpret.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using .CC: Const, isconstType, argtypes_to_type, tuple_tfunc, Const,
getfield_tfunc, _methods_by_ftype, VarTable, nfields_tfunc,
ArgInfo, singleton_type, CallMeta, MethodMatchInfo, specialize_method,
PartialOpaque, UnionSplitApplyCallInfo, typeof_tfunc, apply_type_tfunc, instanceof_tfunc,
StmtInfo
StmtInfo, NoCallInfo
using Core: PartialStruct
using Base.Meta

Expand Down Expand Up @@ -41,7 +41,11 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
else
rt2 = obtype
end
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt2, call.exct, call.effects, RecurseInfo(call.info))
else
return CallMeta(rt2, call.effects, RecurseInfo(call.info))
end
end

# Check if there is a rrule for this function
Expand All @@ -56,7 +60,12 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
end
call = abstract_call_gf_by_type(lower_level(interp), ChainRules.rrule, ArgInfo(nothing, rrule_argtypes), rrule_atype, sv, -1)
if call.rt != Const(nothing)
return CallMeta(getfield_tfunc(call.rt, Const(1)), call.effects, RRuleInfo(call.rt, call.info))
newrt = getfield_tfunc(call.rt, Const(1))
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(newrt, call.exct, call.effects, RRuleInfo(call.rt, call.info))
else
return CallMeta(newrt, call.exct, call.effects, RRuleInfo(call.rt, call.info))
end
end
end
end
Expand All @@ -74,26 +83,39 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
return ret
end

function abstract_accum(interp::AbstractInterpreter, args::Vector{Any}, sv::InferenceState)
args = filter(x->!(widenconst(x) <: Union{ZeroTangent, NoTangent}), args)
function abstract_accum(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState)
argtypes = filter(@nospecialize(x)->!(widenconst(x) <: Union{ZeroTangent, NoTangent}), argtypes)

if length(args) == 0
return CallMeta(ZeroTangent, Effects(), nothing)
if length(argtypes) == 0
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(ZeroTangent, Any, Effects(), NoCallInfo())
else
return CallMeta(ZeroTangent, Effects(), NoCallInfo())
end
end

if length(args) == 1
return CallMeta(args[1], Effects(), nothing)
if length(argtypes) == 1
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(argtypes[1], Any, Effects(), NoCallInfo())
else
return CallMeta(argtypes[1], Effects(), NoCallInfo())
end
end

rtype = reduce(tmerge, args)
rtype = reduce(tmerge, argtypes)
if widenconst(rtype) <: Tuple
targs = Any[]
for i = 1:nfields_tfunc(rtype).val
push!(targs, abstract_accum(interp, Any[getfield_tfunc(arg, Const(i)) for arg in args], sv).rt)
push!(targs, abstract_accum(interp, Any[getfield_tfunc(arg, Const(i)) for arg in argtypes], sv).rt)
end
rt = tuple_tfunc(targs)
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), NoCallInfo())
else
return CallMeta(rt, Effects(), NoCallInfo())
end
return CallMeta(tuple_tfunc(targs), nothing)
end
call = abstract_call(change_level(interp, 0), nothing, Any[typeof(accum), args...],
call = abstract_call(change_level(interp, 0), nothing, Any[typeof(accum), argtypes...],
sv::InferenceState)
return call
end
Expand Down Expand Up @@ -249,7 +271,12 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
ft = argextype(inst.args[1], primal, primal.sptypes)
f = singleton_type(ft)
if isa(f, Core.Builtin)
call = CallMeta(backwards_tfunc(f, primal, inst, Δ), nothing)
rt = backwards_tfunc(f, primal, inst, Δ)
@static if VERSION v"1.11.0-DEV.945"
call = CallMeta(rt, Any, Effects(), NoCallInfo())
else
call = CallMeta(rt, Effects(), NoCallInfo())
end
else
bail!(inst)
continue
Expand All @@ -265,7 +292,12 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
arg = getfield_tfunc(Δ, Const(1))
call = abstract_call(interp, nothing, Any[clos, arg], sv)
# No derivative wrt the functor
call = CallMeta(tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...]), ReifyInfo(call.info))
rt = tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...])
@static if VERSION v"1.11.0-DEV.945"
call = CallMeta(rt, Any, Effects(), ReifyInfo(call.info))
else
call = CallMeta(rt, Effects(), ReifyInfo(call.info))
end
else
(level, close) = derive_closure_type(call_info)
call = abstract_call(change_level(interp, level), ArgInfo(nothing, Any[close, Δ]), sv)
Expand All @@ -274,13 +306,23 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp

if isa(info, UnionSplitApplyCallInfo)
argts = Any[argextype(inst.args[i], primal, primal.sptypes) for i = 4:length(inst.args)]
call = CallMeta(repackage_apply_rt(info, call.rt, argts),
UnionSplitApplyCallInfo([ApplyCallInfo(call.info)]))
rt = repackage_apply_rt(info, call.rt, argts)
newinfo = UnionSplitApplyCallInfo([ApplyCallInfo(call.info)])
@static if VERSION v"1.11.0-DEV.945"
call = CallMeta(rt, Any, Effects(), newinfo)
else
call = CallMeta(rt, Effects(), newinfo)
end
end

if isa(call_info, ReifyInfo)
new_rt = tuple_tfunc(Any[derive_closure_type(call.info)[2]; call.rt])
call = CallMeta(new_rt, RecurseInfo(call.info))
newinfo = RecurseInfo(call.info)
@static if VERSION v"1.11.0-DEV.945"
call = CallMeta(new_rt, Any, Effects(), newinfo)
else
call = CallMeta(new_rt, Effects(), newinfo)
end
end

if call.rt === Union{}
Expand Down Expand Up @@ -312,15 +354,23 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
accum_call = abstract_accum(interp, this_arg_typs, sv)
if accum_call.rt == Union{}
@show accum_call.rt
return CallMeta(Union{}, false)
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(Union{}, Any, Effects(), NoCallInfo())
else
return CallMeta(Union{}, Effects(), NoCallInfo())
end
end
push!(arg_accums, accum_call)
tup_push!(tup_elemns, accum_call.rt)
end
end

rt = tuple_tfunc(Any[tup_elemns...])
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), CompClosInfo(cc, ssa_infos))
else
return CallMeta(rt, Effects(), CompClosInfo(cc, ssa_infos))
end
end

function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospecialize(cc_Δ), sv::InferenceState)
Expand Down Expand Up @@ -389,7 +439,11 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe

if isa(inst, ReturnNode)
rt = accum_arg(inst.val)
return CallMeta(rt, CompClosInfo(cc, ssa_infos))
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), CompClosInfo(cc, ssa_infos))
else
return CallMeta(rt, Effects(), CompClosInfo(cc, ssa_infos))
end
end

args = Any[]
Expand Down Expand Up @@ -451,7 +505,12 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe
arg = getfield_tfunc(Δ, Const(2))
call = abstract_call(interp, nothing, Any[clos, arg], sv)
# No derivative wrt the functor
call = CallMeta(tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...]), ReifyInfo(call.info))
newrt = tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...])
@static if VERSION v"1.11.0-DEV.945"
call = CallMeta(newrt, Any, Effects(), ReifyInfo(call.info))
else
call = CallMeta(newrt, Effects(), ReifyInfo(call.info))
end
#error()
else
(level, clos) = derive_closure_type(call_info)
Expand All @@ -461,11 +520,20 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe

if isa(call_info, ReifyInfo)
new_rt = tuple_tfunc(Any[call.rt; derive_closure_type(call.info)[2]])
call = CallMeta(new_rt, RecurseInfo())
@static if VERSION v"1.11.0-DEV.945"
call = CallMeta(new_rt, Any, Effects(), RecurseInfo())
else
call = CallMeta(new_rt, Effects(), RecurseInfo())
end
end

if isa(info, UnionSplitApplyCallInfo)
call = CallMeta(call.rt, UnionSplitApplyCallInfo([ApplyCallInfo(call.info)]))
newinfo = UnionSplitApplyCallInfo([ApplyCallInfo(call.info)])
@static if VERSION v"1.11.0-DEV.945"
call = CallMeta(call.rt, call.exct, Effects(), newinfo)
else
call = CallMeta(call.rt, Effects(), newinfo)
end
end

accums[i] = call.rt
Expand All @@ -485,13 +553,16 @@ function infer_comp_closure(interp::ADInterpreter, cc::AbstractCompClosure, @nos
end

function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecialize(Δ), sv::InferenceState)
@show ("enter", pc)

if pc.seq == 1
call = abstract_call(change_level(interp, pc.order), nothing, Any[pc.dual, Δ], sv)
rt = call.rt
@show (pc, Δ, rt)
return CallMeta(call.rt, PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(call.rt, call.exct, Effects(), newinfo)
else
return CallMeta(call.rt, Effects(), newinfo)
end
elseif pc.seq == 2
ni = change_level(interp, pc.order)
mi′ = specialize_method(pc.info_below.results.matches[1], true)
Expand All @@ -500,8 +571,12 @@ function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecializ
call = infer_comp_closure(ni, cc, Δ, sv)
rt = getfield_tfunc(call.rt, Const(2))
@show (pc, Δ, rt)
return CallMeta(rt,
PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried))
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), newinfo)
else
return CallMeta(rt, Effects(), newinfo)
end
elseif pc.seq == 3
ni = change_level(interp, pc.order)
mi′ = specialize_method(pc.info_carried.info.results.matches[1], true)
Expand All @@ -511,41 +586,62 @@ function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecializ
Any[clos, tuple_tfunc(Any[Δ, pc.dual])], sv)
rt = tuple_tfunc(Any[tuple_type_fields(call.rt)[2:end]...])
@show (pc, Δ, rt)
return CallMeta(rt,
PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), newinfo)
else
return CallMeta(rt, Effects(), newinfo)
end
elseif mod(pc.seq, 4) == 0
info = pc.info_below
clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos)

# Add back gradient w.r.t. rrule
Δ = tuple_tfunc(Any[NoTangent, tuple_type_fields(Δ)...])
call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, Δ], sv)
rt = getfield_tfunc(call.rt, Const(1))
@show (pc, Δ, rt)
return CallMeta(rt, PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(2)), call.info, pc.info_carried)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(2)), call.info, pc.info_carried))
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), newinfo)
else
return CallMeta(rt, Effects(), newinfo)
end
elseif mod(pc.seq, 4) == 1
info = pc.info_carried
clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos)
call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, tuple_tfunc(Any[pc.dual, Δ])], sv)
rt = call.rt
@show (pc, Δ, rt)
return CallMeta(call.rt, PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), newinfo)
else
return CallMeta(rt, Effects(), newinfo)
end
elseif mod(pc.seq, 4) == 2
info = pc.info_below
clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos)
call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, Δ], sv)
rt = getfield_tfunc(call.rt, Const(2))
@show (pc, Δ, rt)
return CallMeta(rt,
PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried))
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), newinfo)
else
return CallMeta(rt, Effects(), newinfo)
end
elseif mod(pc.seq, 4) == 3
info = pc.info_carried
clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos)
call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, tuple_tfunc(Any[Δ, pc.dual])], sv)
rt = tuple_tfunc(Any[tuple_type_fields(call.rt)[2:end]...])
@show (pc, Δ, rt)
return CallMeta(rt,
PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), newinfo)
else
return CallMeta(rt, Effects(), newinfo)
end
end
error()
end
Expand All @@ -556,8 +652,7 @@ function CC.abstract_call_opaque_closure(interp::ADInterpreter,
if isa(closure.source, AbstractCompClosure)
(;argtypes) = arginfo
if length(argtypes) !== 2
error()
return CallMeta(Union{}, false)
error("bad argtypes")
end
return infer_comp_closure(interp, closure.source, argtypes[2], sv)
elseif isa(closure.source, PrimClosure)
Expand Down

0 comments on commit cbcc0f3

Please sign in to comment.