Adjust forward stage2 to Core.Compiler changes #295

Merged on Oct 17, 2024 (1 commit)

Collaborator

@Keno Keno commented Oct 2, 2024

Only what is necessary for Cedar right now. Ordinary stage 2 reverse mode will need similar changes at a later point.

@Keno Keno requested a review from aviatesk October 2, 2024 22:54

local frule_call::Future{CallMeta}
local result::Future{Union{CallMeta, Nothing}} = Future{Union{CallMeta, Nothing}}()
function make_progress(_, sv)

Collaborator Author

@vtjnash, please confirm that this is the intended way to use this.

You seem to have closure-captured interp instead of using the argument? The interp struct is commonly quite large, so that can increase memory usage quite a bit.

Collaborator Author

The argument is the wrong interp

Yes, this way of defining a make_progress seems fine. There isn't really one right answer about how to code this, so Base itself already uses probably 3 or 4 different patterns, depending on which one left the original code's control flow least distorted. I hadn't used the @isdefined trick, but it is essentially equivalent to the nextstate pattern I'd used for manual stackless state machine conversion.
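
For readers unfamiliar with the @isdefined trick mentioned here, the following is a minimal sketch of the idea, not the code from this PR. `ToyFuture` and `schedule_two_stage` are hypothetical stand-ins invented for illustration (they are not `Core.Compiler.Future` or any Diffractor function), and the callback takes no arguments to keep the example small. A local that is declared but left unassigned records whether the first stage has already run, so the same closure can be invoked again after being rescheduled.

```julia
# Toy stand-in for a compiler future (hypothetical; NOT Core.Compiler.Future).
mutable struct ToyFuture{T}
    value::Union{Some{T}, Nothing}
    ToyFuture{T}() where {T} = new{T}(nothing)
end
Base.isready(f::ToyFuture) = f.value !== nothing
Base.getindex(f::ToyFuture) = something(f.value)
Base.setindex!(f::ToyFuture{T}, v) where {T} = (f.value = Some{T}(convert(T, v)); f)

# Returns the result future plus a resumable callback.
function schedule_two_stage(start_inner)
    result = ToyFuture{Int}()
    local inner::ToyFuture{Int}        # deliberately unassigned until stage 1 runs
    function make_progress()
        if !@isdefined(inner)
            inner = start_inner()      # stage 1: start the sub-computation once
        end
        isready(inner) || return false # not done yet: ask to be rescheduled
        result[] = inner[] + 1         # stage 2: finish using the sub-result
        return true
    end
    return result, make_progress
end
```

Calling `make_progress` repeatedly resumes where the previous call left off; an explicit `nextstate` counter would encode the same information in an integer instead of in the definedness of a local.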

I don't think capturing the interp will give you the behavior you want. I think you might need to mutate sv instead of re-using sv with a different interp.

Collaborator Author

I'm not re-using sv with a different interp. sv here is an IRInterpretationState, which doesn't have an interp argument, so when the callback later gets scheduled, there's just some random interp in there.

sv claims to be an AbsIntState here? For IRInterpretationState, it currently passes the interp here that was originally used to construct the IRInterpretationState, since everything is on the stack there and doesn't handle recursion.

I think the behavior here is also probably fine, but note that no other callback will be using the right interp, since none of the others expect the interp to differ from the one used to allocate the state object.

Collaborator Author

> and doesn't handle recursion

Doesn't handle recursion in Base ;).

Fair, I know almost nothing about this code, so I am reviewing without really knowing how this integrates. The current implementation in Base would potentially break here: https://github.com/JuliaLang/julia/blob/be401635fe02b28ce994e2e3cae0733d101f8927/base/compiler/ssair/irinterp.jl#L154
since it does not track whether the return type changed, so it cannot reschedule this instruction if it becomes part of a cycle (though I believe it should detect that and hit an @assert if it happens).

Keno added a commit to CedarEDA/DAECompiler.jl that referenced this pull request Oct 3, 2024
Keno added a commit to CedarEDA/DAECompiler.jl that referenced this pull request Oct 3, 2024
@Keno Keno force-pushed the kf/compileradjust branch from 43a914c to ef83e29 Compare October 3, 2024 01:12
Member

@aviatesk aviatesk left a comment

LGTM.

nargs = length(arginfo.argtypes)-1
frule_preargtypes = Any[Const(ChainRulesCore.frule), Tuple{Nothing,Vararg{Any,nargs}}]
frule_argtypes = append!(frule_preargtypes, arginfo.argtypes)
frule_atype = CC.argtypes_to_type(frule_argtypes)

Be careful not to closure-capture any types: your performance may suffer quite badly, yet remain just fast enough that you won't notice (e.g. the sysimage still built when I missed one of these cases, it just took several times longer).

Collaborator Author

So put it in a Ref{Any}?

Yeah, or equivalently a Core.Box. I'd found several places where we had a MethodMatch object that was needed anyway, so capturing that also happened to work sometimes.
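
A minimal sketch of the Ref{Any} suggestion, under stated assumptions: `make_direct` and `make_boxed` are hypothetical names, not functions from Diffractor or Base. The idea is that the closure captures an untyped container rather than the type itself, so the closure's own type stays the same whichever type is stored. How costly a direct capture is depends on the Julia version and the surrounding code, and is not measured here.

```julia
# Captures the type value itself; the closure's type may then depend on
# which type was captured (version-dependent, not asserted in this sketch).
make_direct(T) = () -> T

# Captures a Ref{Any} holding the type; the closure only ever sees the Ref.
function make_boxed(T)
    slot = Ref{Any}(T)
    return () -> slot[]
end

# Both variants return the captured type when called.
direct_value = make_direct(Int)()
boxed_value = make_boxed(Int)()
```

The trade-off is that reads through `slot[]` are inferred as `Any`, which is usually acceptable in code that is already handling types as ordinary runtime values.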

local result::Future{Union{CallMeta, Nothing}} = Future{Union{CallMeta, Nothing}}()
function make_progress(_, sv)
    if isa(primal_call[].info, UnionSplitApplyCallInfo)
        result[] = nothing

The return type appears to be wrong here. The intended behavior appears to be returning result[] = primal_call[] in this case; compare:

r = fwd_abstract_call_gf_by_type(interp, f, arginfo, si, sv, ret)
if r !== nothing
    return r
end

Collaborator Author

That's the function that calls this. Potentially it should be refactored to just do that, but I wanted to make only the refactoring change here.

The return type is required to be Future{CallMeta} though, or the caller's caller will be unhappy

Collaborator Author

The call-site there isn't updated yet. This function is called directly from DAECompiler and I adjusted the call-site there to work with Future{Union{Nothing, CallMeta}}

It is fine if it still branches on r !== nothing there; it will just be dead code now, as it appears you must handle that case here now, instead of being able to handle it there.
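
The two placements of the fallback discussed in this thread can be sketched as follows. This is only an illustration with hypothetical names (`ToyMeta`, `lookup_rule`, `call_with_caller_fallback`, `resolve_rule`); it uses plain values rather than futures and does not reproduce Diffractor's or Core.Compiler's actual signatures.

```julia
# Hypothetical stand-in for a call result; NOT Core.Compiler.CallMeta.
struct ToyMeta
    rt::Any
end

# Hypothetical rule lookup: returns nothing when no forward rule applies.
lookup_rule(f) = f === sin ? ToyMeta(Float64) : nothing

# Variant 1: the lookup may return nothing and the caller falls back.
function call_with_caller_fallback(f, primal::ToyMeta)
    r = lookup_rule(f)
    r !== nothing && return r
    return primal
end

# Variant 2: the fallback is resolved inside the callee, so its result is
# always a ToyMeta and a caller's `r !== nothing` branch is always taken.
resolve_rule(f, primal::ToyMeta)::ToyMeta = something(lookup_rule(f), primal)
```

With the second variant the declared result type never includes `Nothing`, which is the property the reviewer asks for when the result must be a Future{CallMeta}.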

@Keno Keno force-pushed the kf/compileradjust branch from ef83e29 to 1fb7102 Compare October 11, 2024 10:37
Keno added a commit to CedarEDA/DAECompiler.jl that referenced this pull request Oct 17, 2024
* Adjust to stackless compiler changes

Depends on:
- JuliaDiff/Diffractor.jl#295
- JuliaLang/julia#55972

* More compiler adjust
@Keno Keno merged commit 1cbde03 into main Oct 17, 2024
3 of 8 checks passed
@aviatesk aviatesk deleted the kf/compileradjust branch October 17, 2024 10:16
3 participants