Skip to content

Commit

Permalink
Hook up Eras to forward code gen
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Feb 12, 2024
1 parent 4bd8b46 commit c1e4047
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/codegen/forward.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
function fwd_transform(ci, args...)
function fwd_transform(ci, mi, nargs, N, E)
newci = copy(ci)
fwd_transform!(newci, args...)
fwd_transform!(newci, mi, nargs, N, E)
return newci
end

Expand Down
1 change: 1 addition & 0 deletions src/stage1/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ end
#E eras mode, this controls if we should Error if it isn't Taylor. This should be a Bool
struct ∂☆internal{N, E}; end
struct ∂☆recurse{N, E}; end
∂☆recurse{N}() where N = ∂☆recurse{N,false}
struct ∂☆shuffle{N}; end

function shuffle_base(r)
Expand Down
4 changes: 2 additions & 2 deletions src/stage1/recurse_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ function ∂☆builtin((f_bundle, args...))
end

function perform_fwd_transform(world::UInt, source::LineNumberNode,
@nospecialize(ff::Type{∂☆recurse{N}}), @nospecialize(args)) where {N}
@nospecialize(ff::Type{∂☆recurse{N,E}}), @nospecialize(args)) where {N,E}
if all(x->x <: ZeroBundle, args)
return generate_lambda_ex(world, source,
Core.svec(:ff, :args), Core.svec(), :(∂☆passthrough(args)))
Expand All @@ -96,7 +96,7 @@ function perform_fwd_transform(world::UInt, source::LineNumberNode,
mi = Core.Compiler.specialize_method(match)
ci = Core.Compiler.retrieve_code_info(mi, world)

return fwd_transform(ci, mi, length(args)-1, N)
return fwd_transform(ci, mi, length(args)-1, N, E)
end

@eval function (ff::∂☆recurse)(args...)
Expand Down
5 changes: 4 additions & 1 deletion src/stage1/termination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ which(Tuple{∂⃖{N}, ∂⃖{1}, Vararg{Any}} where {N}).recursion_relation = f
isa(Base.unwrap_unionall(new_sig.parameters[1].parameters[1]), Int)
end

for (;method) in Base._methods_by_ftype(Tuple{Diffractor.∂☆recurse{N}, Vararg{Any}} where {N}, nothing, -1, get_world_counter())
for (;method) in [
Base._methods_by_ftype(Tuple{Diffractor.∂☆recurse{N}, Vararg{Any}} where {N}, nothing, -1, get_world_counter());
Base._methods_by_ftype(Tuple{Diffractor.∂☆recurse{N, E}, Vararg{Any}} where {N, E}, nothing, -1, get_world_counter());
]
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
# Recursion from a higher to a lower order is always allowed
parent_order = parent_sig.parameters[1].parameters[1]
Expand Down

0 comments on commit c1e4047

Please sign in to comment.