Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for deopts in promises #1177

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rir/src/compiler/analysis/context_stack.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ struct ContextStackState {
if (auto mk = MkEnv::Cast((*i)->env()))
it(mk);
}
size_t context() const { return contextStack.size(); }
size_t numContexts() const { return contextStack.size(); }
AbstractResult merge(const ContextStackState& other) {
assert(contextStack.size() == other.contextStack.size() &&
"stack imbalance");
Expand Down
23 changes: 5 additions & 18 deletions rir/src/compiler/analysis/verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class TheVerifier {
std::unordered_set<BB*> seenPreds;

void operator()() {
Visitor::run(f->entry, [&](BB* bb) { return verify(bb, false); });
Visitor::run(f->entry, [&](BB* bb) { return verify(bb); });
Visitor::run(f->entry, [&](BB* bb) { seenPreds.erase(bb); });
if (!seenPreds.empty()) {
std::cerr << "The following preds are not reachable from entry: ";
Expand Down Expand Up @@ -98,7 +98,7 @@ class TheVerifier {
return doms.at(c);
}

void verify(BB* bb, bool inPromise) {
void verify(BB* bb) {
if (bb->id >= bb->owner->nextBBId) {
std::cout << "BB" << bb->id << " id is bigger than max ("
<< bb->owner->nextBBId << ")\n";
Expand All @@ -121,7 +121,7 @@ class TheVerifier {
}

for (auto i : *bb) {
verify(i, bb, inPromise);
verify(i, bb);
}
/* This check verifies that our graph is in edge-split format.
Currently we do not rely on this property, however we should
Expand Down Expand Up @@ -196,10 +196,10 @@ class TheVerifier {
}

void verify(Promise* p) {
Visitor::run(p->entry, [&](BB* bb) { verify(bb, true); });
Visitor::run(p->entry, [&](BB* bb) { verify(bb); });
}

void verify(Instruction* i, BB* bb, bool inPromise) {
void verify(Instruction* i, BB* bb) {
if (i->bb() != bb) {
std::cerr << "Error: instruction '";
i->print(std::cerr);
Expand Down Expand Up @@ -269,19 +269,6 @@ class TheVerifier {
});
}

if (i->frameState()) {
if (!inPromise) {
auto fs = i->frameState();
while (fs->next())
fs = fs->next();
if (fs->inPromise) {
std::cerr << "Error: instruction '";
i->print(std::cerr);
std::cerr << "' outermost fs inPromis in body code\n";
ok = false;
}
}
}
if (auto assume = Assume::Cast(i)) {
if (IsType::Cast(assume->arg(0).val())) {
if (!assume->reason.pc()) {
Expand Down
40 changes: 29 additions & 11 deletions rir/src/compiler/backend.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "backend.h"
#include "R/BuiltinIds.h"
#include "analysis/context_stack.h"
#include "analysis/dead.h"
#include "bc/CodeStream.h"
#include "bc/CodeVerifier.h"
Expand Down Expand Up @@ -58,7 +59,7 @@ static void approximateNeedsLdVarForUpdate(
}
}
if (auto mk = MkEnv::Cast(ld->env()))
if (mk->stub && mk->arg(mk->indexOf(ld->varName)).val() !=
if (mk->stub && mk->argNamed(ld->varName).val() !=
UnboundValue::instance())
return;

Expand Down Expand Up @@ -106,11 +107,28 @@ static void approximateNeedsLdVarForUpdate(
});
}

static void lower(Module* module, Code* code) {
static void lower(Module* module, ClosureVersion* cls, Code* code,
AbstractLog& log) {
DeadInstructions representAsReal(
code, 1, Effects::Any(),
DeadInstructions::IgnoreUsesThatDontObserveIntVsReal);

// If we take a deopt that's in between a PushContext/PopContext pair but
// whose Checkpoint is not, we have to remove the extra context(s). We emit
// DropContext instructions for this while lowering the Assume to a branch.
ContextStack cs(cls, code, log);
std::unordered_map<Assume*, size_t> nDropContexts;
Visitor::run(code->entry, [&](Instruction* i) {
if (auto a = Assume::Cast(i)) {
auto beforeA = cs.before(a).numContexts();
auto beforeCp = cs.before(a->checkpoint()).numContexts();

assert(nDropContexts.count(a) == 0);
assert(beforeCp <= beforeA);
nDropContexts[a] = beforeA - beforeCp;
}
});

Visitor::runPostChange(code->entry, [&](BB* bb) {
auto it = bb->begin();
while (it != bb->end()) {
Expand Down Expand Up @@ -184,31 +202,31 @@ static void lower(Module* module, Code* code) {
Tombstone::framestate(), 0));
next = it + 2;
}
} else if (auto expect = Assume::Cast(*it)) {
if (expect->triviallyHolds()) {
} else if (auto assume = Assume::Cast(*it)) {
if (assume->triviallyHolds()) {
next = bb->remove(it);
} else {
auto expectation = expect->assumeTrue;
auto expectation = assume->assumeTrue;
std::string debugMessage;
if (Parameter::DEBUG_DEOPTS) {
debugMessage = "DEOPT, assumption ";
{
std::stringstream dump;
if (auto i =
Instruction::Cast(expect->condition())) {
Instruction::Cast(assume->condition())) {
dump << "\n";
i->printRecursive(dump, 4);
dump << "\n";
} else {
expect->condition()->printRef(dump);
assume->condition()->printRef(dump);
}
debugMessage += dump.str();
}
debugMessage += " failed\n";
}
BBTransform::lowerExpect(
module, code, bb, it, expect, expectation,
expect->checkpoint()->bb()->falseBranch(),
BBTransform::lowerAssume(
module, code, bb, it, assume, nDropContexts.at(assume),
expectation, assume->checkpoint()->deoptBranch(),
debugMessage);
// lowerExpect splits the bb from current position. There
// remains nothing to process. Breaking seems more robust
Expand Down Expand Up @@ -324,7 +342,7 @@ rir::Function* Backend::doCompile(ClosureVersion* cls, ClosureLog& log) {
std::function<void(Code*)> lowerAndScanForPromises = [&](Code* c) {
if (promMap.count(c))
return;
lower(module, c);
lower(module, cls, c, log);
toCSSA(module, c);
log.CSSA(c);
#ifdef FULLVERIFIER
Expand Down
2 changes: 1 addition & 1 deletion rir/src/compiler/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ void Compiler::compileClosure(Closure* closure, rir::Function* optFunction,
auto arg = closure->formals().defaultArgs()[idx];
assert(rir::Code::check(arg) && "Default arg not compiled");
auto code = rir::Code::unpack(arg);
auto res = rir2pir.tryCreateArg(code, builder, false);
auto res = rir2pir.tryCreateArg(code, builder);
if (!res) {
failedToCompileDefaultArgs = true;
return;
Expand Down
58 changes: 44 additions & 14 deletions rir/src/compiler/native/builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -828,8 +828,8 @@ static SEXP deoptSentinelContainer = []() {
return store;
}();

void deoptImpl(rir::Code* c, SEXP cls, DeoptMetadata* m, R_bcstack_t* args,
bool leakedEnv, DeoptReason* deoptReason, SEXP deoptTrigger) {
SEXP deopt(rir::Code* c, SEXP cls, DeoptMetadata* m, R_bcstack_t* args,
bool leakedEnv, DeoptReason* deoptReason, SEXP deoptTrigger) {
deoptReason->record(deoptTrigger);

assert(m->numFrames >= 1);
Expand All @@ -854,8 +854,9 @@ void deoptImpl(rir::Code* c, SEXP cls, DeoptMetadata* m, R_bcstack_t* args,
auto le = LazyEnvironment::check(env);
if (deoptless && m->numFrames == 1 && cls != deoptlessRecursion &&
((le && !le->materialized()) ||
(!le && (!leakedEnv || !deoptlessNoLeakedEnvs)))) {
assert(m->frames[0].inPromise == false);
(!le && (!leakedEnv || !deoptlessNoLeakedEnvs))) &&
/* TODO: support deoptless when outermost frame is a promise */
!m->frames[0].inPromise) {

size_t envSize = le ? le->nargs : Rf_length(FRAME(env));
if (envSize <= DeoptContext::MAX_ENV &&
Expand Down Expand Up @@ -932,8 +933,8 @@ void deoptImpl(rir::Code* c, SEXP cls, DeoptMetadata* m, R_bcstack_t* args,

Rf_findcontext(CTXT_BROWSER | CTXT_FUNCTION,
originalCntxt->cloenv, res);
assert(false);
return;
assert(false && "unreachable after deoptless");
return nullptr;
}
}
}
Expand All @@ -945,13 +946,35 @@ void deoptImpl(rir::Code* c, SEXP cls, DeoptMetadata* m, R_bcstack_t* args,
if (f->body() == c)
Pool::patch(idx, deoptSentinelContainer);

CallContext call(ArglistOrder::NOT_REORDERED, c, cls,
/* nargs */ -1, src_pool_at(c->src), args,
(Immediate*)nullptr, env, R_NilValue, Context());
if (cls) {
CallContext call(ArglistOrder::NOT_REORDERED, c, cls,
/* nargs */ -1, src_pool_at(c->src), args,
(Immediate*)nullptr, env, R_NilValue, Context());

deoptFramesWithContext(&call, m, R_NilValue, m->numFrames - 1, stackHeight,
(RCNTXT*)R_GlobalContext);
assert(false);
// Deopt in a function longjumps to its context
deoptFramesWithContext(&call, m, R_NilValue, m->numFrames - 1,
stackHeight, (RCNTXT*)R_GlobalContext);
assert(false && "unreachable after deopt");
return nullptr;
} else {
// Deopt in a promise has nowhere to longjump, so it leaves the result
// on the TOS and returns here, this is immediately returned as the
// result of the promise
deoptFramesWithContext(nullptr, m, R_NilValue, m->numFrames - 1,
stackHeight, (RCNTXT*)R_GlobalContext);
return ostack_pop();
}
}

void deoptImpl(rir::Code* c, SEXP cls, DeoptMetadata* m, R_bcstack_t* args,
bool leakedEnv, DeoptReason* deoptReason, SEXP deoptTrigger) {
deopt(c, cls, m, args, leakedEnv, deoptReason, deoptTrigger);
}

SEXP deoptPromImpl(rir::Code* c, DeoptMetadata* m, R_bcstack_t* args,
bool leakedEnv, DeoptReason* deoptReason,
SEXP deoptTrigger) {
return deopt(c, nullptr, m, args, leakedEnv, deoptReason, deoptTrigger);
}

void recordTypefeedbackImpl(Opcode* pos, rir::Code* code, SEXP value) {
Expand Down Expand Up @@ -2163,8 +2186,7 @@ void initClosureContextImpl(ArglistOrder::CallId callId, rir::Code* c, SEXP ast,
}

static void endClosureContextImpl(RCNTXT* cntxt, SEXP result) {
cntxt->returnValue = result;
Rf_endcontext(cntxt);
endClosureContext(cntxt, result);
}

int ncolsImpl(SEXP v) { return getMatrixDim(v).col; }
Expand Down Expand Up @@ -2437,6 +2459,14 @@ void NativeBuiltins::initializeBuiltins() {
t::DeoptReasonPtr, t::SEXP},
false),
{llvm::Attribute::NoReturn}};
get_(Id::deoptProm) = {
"deoptProm",
(void*)&deoptPromImpl,
llvm::FunctionType::get(t::SEXP,
{t::voidPtr, t::voidPtr, t::stackCellPtr, t::i1,
t::DeoptReasonPtr, t::SEXP},
false),
{}};
get_(Id::assertFail) = {"assertFail",
(void*)&assertFailImpl,
t::void_voidPtr,
Expand Down
1 change: 1 addition & 0 deletions rir/src/compiler/native/builtins.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ struct NativeBuiltins {
length,
recordTypefeedback,
deopt,
deoptProm,
assertFail,
printValue,
extract11,
Expand Down
62 changes: 53 additions & 9 deletions rir/src/compiler/native/lower_function_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2112,6 +2112,7 @@ void LowerFunctionLLVM::compile() {
stack({container(paramCode())});
additionalStackSlots++;
}

{
SmallSet<std::pair<Value*, SEXP>> bindings;
Visitor::run(code->entry, [&](Instruction* i) {
Expand Down Expand Up @@ -2352,6 +2353,27 @@ void LowerFunctionLLVM::compile() {
break;
}

case Tag::DropContext: {
auto globalContextPtrAddr =
convertToPointer(&R_GlobalContext, t::RCNTXT_ptr);
auto globalContextPtr =
builder.CreateLoad(globalContextPtrAddr);
auto callflagAddr =
builder.CreateGEP(globalContextPtr, {c(0), c(1)});
auto callflag = builder.CreateLoad(callflagAddr);
insn_assert(
builder.CreateICmpNE(
builder.CreateAnd(callflag, c(CTXT_FUNCTION)), c(0)),
"Expected R_GlobalContext to be a closure context");
// R_GlobalContext = R_GlobalContext->nextcontext
auto nextcontextAddr =
builder.CreateGEP(globalContextPtr, {c(0), c(0)});
auto nextcontext = builder.CreateLoad(nextcontextAddr);
builder.CreateStore(nextcontext, globalContextPtrAddr);
inPushContext--;
break;
}

case Tag::CastType: {
auto in = i->arg(0).val();
if (!variables_.count(i))
Expand Down Expand Up @@ -3561,15 +3583,37 @@ void LowerFunctionLLVM::compile() {
target->addExtraPoolEntry(store);
}

withCallFrame(args, [&]() {
return call(NativeBuiltins::get(NativeBuiltins::Id::deopt),
{paramCode(), paramClosure(),
convertToPointer(m, t::i8, true), paramArgs(),
c(deopt->escapedEnv, 1),
load(deopt->deoptReason()),
loadSxp(deopt->deoptTrigger())});
});
builder.CreateUnreachable();
// Deopt only returns if the outermost frame is a promise.
// In that case, the result is the returned value and we simply
// return it from here.
if (code->isPromise()) {
auto res = withCallFrame(
args,
[&]() {
return call(NativeBuiltins::get(
NativeBuiltins::Id::deoptProm),
{paramCode(),
convertToPointer(m, t::i8, true),
paramArgs(), c(deopt->escapedEnv, 1),
load(deopt->deoptReason()),
loadSxp(deopt->deoptTrigger())});
},
false);
exitBlocks.push_back(builder.GetInsertBlock());
builder.CreateRet(res);
} else {
withCallFrame(args, [&]() {
return call(
NativeBuiltins::get(NativeBuiltins::Id::deopt),
{paramCode(), paramClosure(),
convertToPointer(m, t::i8, true), paramArgs(),
c(deopt->escapedEnv, 1),
load(deopt->deoptReason()),
loadSxp(deopt->deoptTrigger())});
});
insn_assert(builder.getFalse(), "unreachable after deopt");
builder.CreateUnreachable();
}
break;
}

Expand Down
9 changes: 8 additions & 1 deletion rir/src/compiler/opt/force_dominance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,14 @@ bool ForceDominance::apply(Compiler&, ClosureVersion* cls, Code* code,
Value* eager = mkarg->eagerArg();
f->replaceUsesWith(eager);
next = bb->remove(ip);
} else if (toInline.count(f)) {
} else if (toInline.count(f) &&
// Can't inline a promise with Assumes if the
// dominating Force doesn't have a Framestate
(f->frameState() ||
Visitor::check(mkarg->prom()->entry,
[](Instruction* i) {
return !Assume::Cast(i);
}))) {
anyChange = true;
Promise* prom = mkarg->prom();
BB* split =
Expand Down
3 changes: 2 additions & 1 deletion rir/src/compiler/opt/hoist_instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ bool HoistInstruction::apply(Compiler& cmp, ClosureVersion* cls, Code* code,
while (b->isEmpty())
b = *b->predecessors().begin();

if (cs.after(b->last()).context() > cs.before(i).context()) {
if (cs.after(b->last()).numContexts() >
cs.before(i).numContexts()) {
ip = next;
continue;
}
Expand Down
Loading