Skip to content

Commit

Permalink
adjustments to the latest master (#284)
Browse files Browse the repository at this point in the history
Also fixes a bunch of tests.

---------

Co-authored-by: Keno Fischer <[email protected]>
  • Loading branch information
aviatesk and Keno authored Mar 29, 2024
1 parent a444b7f commit 374f92b
Show file tree
Hide file tree
Showing 15 changed files with 311 additions and 201 deletions.
8 changes: 4 additions & 4 deletions src/Diffractor.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
module Diffractor

export ∂⃖, gradient

using StructArrays
using PrecompileTools

export ∂⃖, gradient

const CC = Core.Compiler
using Core.IR

@static if VERSION v"1.11.0-DEV.1498"
import .CC: get_inference_world
Expand Down Expand Up @@ -33,7 +34,6 @@ end
include("stage2/tfuncs.jl")
include("stage2/forward.jl")

include("codegen/forward.jl")
include("analysis/forward.jl")
include("codegen/forward_demand.jl")
include("codegen/reverse.jl")
Expand All @@ -48,4 +48,4 @@ end
include("AbstractDifferentiation.jl")
end

end
end # module Diffractor
117 changes: 0 additions & 117 deletions src/codegen/forward.jl

This file was deleted.

51 changes: 37 additions & 14 deletions src/codegen/reverse.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
# Codegen shared by both stage1 and stage2

function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, cis, revs...)
function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, ci, revs...)
if interp !== nothing
cis.inferred = true
ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
typ, Union{}, cis.rettype, @__MODULE__, cis, lno.line, lno.file, meth_nargs, isva, ()).source
return Expr(:new_opaque_closure, typ, Union{}, Any,
ocm, revs...)
@static if VERSION v"1.12.0-DEV.15"
rettype = Any # ci.rettype # TODO revisit
else
ci.inferred = true
rettype = ci.rettype
end
@static if VERSION v"1.12.0-DEV.15"
ocm = Core.OpaqueClosure(ci; rettype, nargs=meth_nargs, isva, sig=typ).source
else
ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
typ, Union{}, rettype, @__MODULE__, ci, lno.line, lno.file, meth_nargs, isva, ()).source
end
return Expr(:new_opaque_closure, typ, Union{}, Any, ocm, revs...)
else
oc_nargs = Int64(meth_nargs)
Expr(:new_opaque_closure, typ, Union{}, Any,
Expr(:opaque_closure_method, name, oc_nargs, isva, lno, cis), revs...)
Expr(:opaque_closure_method, name, oc_nargs, isva, lno, ci), revs...)
end
end

Expand Down Expand Up @@ -107,8 +115,12 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
opaque_ci.slotnames = [Symbol("#oc#"), ci.slotnames...]
opaque_ci.slotflags = UInt8[0, ci.slotflags...]
end
opaque_ci.linetable = Core.LineInfoNode[ci.linetable[1]]
opaque_ci.inferred = false
@static if VERSION v"1.12.0-DEV.173"
opaque_ci.debuginfo = ci.debuginfo
else
opaque_ci.linetable = Core.LineInfoNode[ci.linetable[1]]
opaque_ci.inferred = false
end
opaque_ci
end

Expand Down Expand Up @@ -393,12 +405,17 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
code = opaque_ci.code = expand_switch(code, bb_ranges, slot_map)
end

opaque_ci.codelocs = Int32[0 for i=1:length(code)]
@static if VERSION v"1.12.0-DEV.173"
debuginfo = Core.Compiler.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code))
debuginfo.def = :var"N/A"
opaque_ci.debuginfo = Core.DebugInfo(debuginfo, length(code))
else
opaque_ci.codelocs = Int32[0 for i=1:length(code)]
end
opaque_ci.ssavaluetypes = length(code)
opaque_ci.ssaflags = UInt8[0 for i=1:length(code)]
opaque_ci.ssaflags = SSAFlagType[zero(SSAFlagType) for i=1:length(code)]
end


for nc = 2:2:n_closures
fwds = Any[nothing for i = 1:length(ir.stmts)]

Expand Down Expand Up @@ -475,9 +492,15 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
end
end

opaque_ci.codelocs = Int32[0 for i=1:length(code)]
@static if VERSION v"1.12.0-DEV.173"
debuginfo = Core.Compiler.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code))
debuginfo.def = :var"N/A"
opaque_ci.debuginfo = Core.DebugInfo(debuginfo, length(code))
else
opaque_ci.codelocs = Int32[0 for i=1:length(code)]
end
opaque_ci.ssavaluetypes = length(code)
opaque_ci.ssaflags = UInt8[0 for i=1:length(code)]
opaque_ci.ssaflags = SSAFlagType[zero(SSAFlagType) for i=1:length(code)]
end

# TODO: This is absolutely aweful, but the best we can do given the data structures we have
Expand Down
6 changes: 3 additions & 3 deletions src/higher_fwd_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@ using Base.Iterators

function njet(::Val{N}, ::typeof(sin), x₀) where {N}
(s, c) = sincos(x₀)
Jet(x₀, s, tuple(take(cycle((c, -s, -c, s)), N)...))
Jet(convert(typeof(s), x₀), s, tuple(take(cycle((c, -s, -c, s)), N)...))
end

function njet(::Val{N}, ::typeof(cos), x₀) where {N}
(s, c) = sincos(x₀)
Jet(x₀, s, tuple(take(cycle((-s, -c, s, c)), N)...))
Jet(convert(typeof(s), x₀), s, tuple(take(cycle((-s, -c, s, c)), N)...))
end

function njet(::Val{N}, ::typeof(exp), x₀) where {N}
exped = exp(x₀)
Jet(x₀, exped, tuple(take(repeated(exped), N)...))
Jet(convert(typeof(exped), x₀), exped, tuple(take(repeated(exped), N)...))
end

jeval(j, x) = j(x)
Expand Down
17 changes: 13 additions & 4 deletions src/jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ function Base.show(io::IO, j::Jet)
end

function domain_check(j::Jet, x)
if j.a !== x
if j.a !== convert(typeof(j.a), x)
throw(DomainError("Evaluation is only valid at a"))
end
end
Expand Down Expand Up @@ -153,11 +153,17 @@ function ChainRulesCore.rrule(j::Jet, x)
end

function ChainRulesCore.rrule(::typeof(map), ::typeof(*), a, b)
map(*, a, b), Δ->(NoTangent(), NoTangent(), map(*, Δ, b), map(*, a, Δ))
map(*, a, b), Δ->let Δ=unthunk(Δ)
isa(Δ, NoTangent) && return (NoTangent(), NoTangent(), NoTangent(), NoTangent())
(NoTangent(), NoTangent(), map(*, Δ, b), map(*, a, Δ))
end
end

ChainRulesCore.rrule(::typeof(map), ::typeof(integrate), js::Array{<:Jet}) =
map(integrate, js), Δ->(NoTangent(), NoTangent(), map(deriv, Δ))
map(integrate, js), Δ->let Δ=unthunk(Δ)
isa(Δ, NoTangent) && return (NoTangent(), NoTangent(), NoTangent())
(NoTangent(), NoTangent(), map(deriv, Δ))
end

struct derivBack
js
Expand All @@ -177,7 +183,10 @@ end

function ChainRulesCore.rrule(::typeof(mapev), js::Array{<:Jet}, xs::AbstractArray)
mapev(js, xs), let djs=map(deriv, js)
Δ->(NoTangent(), NoTangent(), map(*, unthunk(Δ), mapev(djs, xs)))
function (Δ)
isa(Δ, NoTangent) && return (NoTangent(), NoTangent(), NoTangent())
(NoTangent(), NoTangent(), map(*, unthunk(Δ), mapev(djs, xs)))
end
end
end

Expand Down
32 changes: 17 additions & 15 deletions src/stage1/compiler_utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Utilities that should probably go into Core.Compiler
using Core.Compiler: IRCode, CFG, BasicBlock, BBIdxIter
# Utilities that should probably go into CC
using .CC: IRCode, CFG, BasicBlock, BBIdxIter

function Base.push!(cfg::CFG, bb::BasicBlock)
@assert cfg.blocks[end].stmts.stop+1 == bb.stmts.start
Expand All @@ -8,38 +8,40 @@ function Base.push!(cfg::CFG, bb::BasicBlock)
end

if VERSION < v"1.11.0-DEV.258"
Base.getindex(ir::IRCode, ssa::SSAValue) = Core.Compiler.getindex(ir, ssa)
Base.getindex(ir::IRCode, ssa::SSAValue) = CC.getindex(ir, ssa)
end

Base.copy(ir::IRCode) = Core.Compiler.copy(ir)
Base.copy(ir::IRCode) = CC.copy(ir)

Core.Compiler.NewInstruction(@nospecialize node) =
CC.NewInstruction(@nospecialize node) =
NewInstruction(node, Any, CC.NoCallInfo(), nothing, CC.IR_FLAG_REFINED)

Base.setproperty!(x::Core.Compiler.Instruction, f::Symbol, v) =
Core.Compiler.setindex!(x, v, f)
Base.setproperty!(x::CC.Instruction, f::Symbol, v) = CC.setindex!(x, v, f)

Base.getproperty(x::Core.Compiler.Instruction, f::Symbol) =
Core.Compiler.getindex(x, f)
Base.getproperty(x::CC.Instruction, f::Symbol) = CC.getindex(x, f)

function Base.setindex!(ir::IRCode, ni::NewInstruction, i::Int)
stmt = ir.stmts[i]
stmt.inst = ni.stmt
stmt.type = ni.type
stmt.flag = something(ni.flag, 0) # fixes 1.9?
stmt.line = something(ni.line, 0)
@static if VERSION v"1.12.0-DEV.173"
stmt.line = something(ni.line, CC.NoLineUpdate)
else
stmt.line = something(ni.line, 0)
end
return ni
end

function Base.push!(ir::IRCode, ni::NewInstruction)
# TODO: This should be a check in insert_node!
@assert length(ir.new_nodes.stmts) == 0
@static if isdefined(Core.Compiler, :add!)
@static if isdefined(CC, :add!)
# Julia 1.7 & 1.8
ir[Core.Compiler.add!(ir.stmts)] = ni
ir[CC.add!(ir.stmts)] = ni
else
# Re-named in https://github.com/JuliaLang/julia/pull/47051
ir[Core.Compiler.add_new_idx!(ir.stmts)] = ni
ir[CC.add_new_idx!(ir.stmts)] = ni
end
ir
end
Expand All @@ -54,8 +56,8 @@ function Base.iterate(it::Iterators.Reverse{BBIdxIter},
return (bb, idx - 1), (bb, idx - 1)
end

Base.lastindex(x::Core.Compiler.InstructionStream) =
Core.Compiler.length(x)
Base.lastindex(x::CC.InstructionStream) =
CC.length(x)

"""
find_end_of_phi_block(ir::IRCode, start_search_idx::Int)
Expand Down
2 changes: 1 addition & 1 deletion src/stage1/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ function _frule(::NTuple{<:Any, AbstractZero}, f, primal_args...)
end

function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
bundles = map(bundle, partials, args)
bundles = map(bundle, args, partials)
result = ∂☆internal{1}()(bundles...)
primal(result), first_partial(result)
end
Expand Down
Loading

0 comments on commit 374f92b

Please sign in to comment.