diff --git a/.vscode/settings.json b/.vscode/settings.json index 3c989f58d..7a261015a 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -112,6 +112,8 @@ "cfenv": "cpp", "csignal": "cpp", "__functional_base_03": "cpp", - "__memory": "cpp" + "__memory": "cpp", + "__bits": "cpp", + "__availability": "cpp" } } diff --git a/rir/src/compiler/compiler.cpp b/rir/src/compiler/compiler.cpp index 8d5bd04f1..c4fee994b 100644 --- a/rir/src/compiler/compiler.cpp +++ b/rir/src/compiler/compiler.cpp @@ -239,6 +239,43 @@ void Compiler::compileClosure(Closure* closure, rir::Function* optFunction, bool MEASURE_COMPILER_PERF = getenv("PIR_MEASURE_COMPILER") ? true : false; +static void resetVersionsRefCountState(Module* m) { + + m->eachPirClosure([&](Closure* c) { + c->eachVersion([&](ClosureVersion* v) { + v->staticCallRefCount = 0; + v->isClone = false; + }); + }); + + m->eachPirClosure([&](Closure* c) { + c->eachVersion([&](ClosureVersion* v) { + auto check = [&](Instruction* i) { + if (auto call = StaticCall::Cast(i)) { + call->updateVersionRefCount(); + } else if (auto call = CallInstruction::CastCall(i)) { + if (auto cls = call->tryGetCls()) { + + if (auto dispatchedVersion = call->tryDispatch(cls)) { + dispatchedVersion->staticCallRefCount++; + } + } + } + }; + + Visitor::run(v->entry, check); + v->eachPromise([&](Promise* p) { Visitor::run(p->entry, check); }); + }); + }); + + m->eachPirClosure([&](Closure* c) { + c->eachVersion([&](ClosureVersion* v) { + if (v->staticCallRefCount == 0) + v->staticCallRefCount = 1; + }); + }); +} + static void findUnreachable(Module* m, Log& log, const std::string& where) { std::unordered_map> reachable; bool changed = true; @@ -283,6 +320,7 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) { i->printRecursive(msg, 2); log.warn(msg.str()); } + found(call->tryDispatch()); found(call->tryOptimisticDispatch()); found(call->hint); @@ -310,6 +348,7 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) { m->eachPirClosure([&](Closure* c) { const auto& reachableVersions = reachable[c]; c->eachVersion([&](ClosureVersion* v) { + // assert(c->getVersion(v->context()) == v); if (!reachableVersions.count(v->context())) { toErase.push_back({v->owner(), v->context()}); log.close(v); @@ -320,6 +359,8 @@ static void findUnreachable(Module* m, Log& log, const std::string& where) { for (auto e : toErase) e.first->erase(e.second); + + resetVersionsRefCountState(m); }; void Compiler::optimizeClosureVersion(ClosureVersion* v) { diff --git a/rir/src/compiler/native/lower_function_llvm.cpp b/rir/src/compiler/native/lower_function_llvm.cpp index 622779a91..aa5ca17b6 100644 --- a/rir/src/compiler/native/lower_function_llvm.cpp +++ b/rir/src/compiler/native/lower_function_llvm.cpp @@ -3387,6 +3387,9 @@ void LowerFunctionLLVM::compile() { if (calli->isReordered()) callId = pushArgReordering(calli->getArgOrderOrig()); + if (!target) + assert(target && "target is null!"); + if (!target->owner()->hasOriginClosure()) { setVal( i, withCallFrame(args, [&]() -> llvm::Value* { diff --git a/rir/src/compiler/opt/eager_calls.cpp b/rir/src/compiler/opt/eager_calls.cpp index 63e2d8051..bee4727ef 100644 --- a/rir/src/compiler/opt/eager_calls.cpp +++ b/rir/src/compiler/opt/eager_calls.cpp @@ -15,6 +15,46 @@ namespace rir { namespace pir { +static ClosureVersion* +cloneOrReplaceVersion(Closure* target, ClosureVersion* version, + StaticCall* call, Context assumptions, + const std::function& + updateVersionWithNewAssumptions) { + if (version != call->lastSeen) { + version->staticCallRefCount++; + } + + ClosureVersion* newVersion; + + call->lastSeen = nullptr; + if (version->isClone || version->staticCallRefCount > 1) { + + newVersion = target->cloneWithAssumptions( + version, assumptions, [&](ClosureVersion* newCls) { + updateVersionWithNewAssumptions(newCls); + }); + + if (newVersion != version) { + newVersion->isClone = true; + + call->lastSeen = newVersion; + version->staticCallRefCount--; + } + + } else { + + newVersion = target->replaceWithAssumptions( + version, assumptions, updateVersionWithNewAssumptions); + + call->lastSeen = newVersion; + if (newVersion->staticCallRefCount == 0) { + newVersion->staticCallRefCount = 1; + } + } + + return newVersion; +} + bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code, AbstractLog& log, size_t) const { AvailableCheckpoints checkpoint(cls, code, log); @@ -287,8 +327,9 @@ bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code, // version. Maybe we should limit this at some point, to avoid // version explosion. if (availableAssumptions.isImproving(version)) { - auto newVersion = target->cloneWithAssumptions( - version, availableAssumptions, + + + auto updateVersionWithNewAssumptions = [&](ClosureVersion* newCls) { Visitor::run(newCls->entry, [&](Instruction* i) { if (auto f = Force::Cast(i)) { @@ -308,7 +349,12 @@ bool EagerCalls::apply(Compiler& cmp, ClosureVersion* cls, Code* code, } } }); - }); + }; + + auto newVersion = cloneOrReplaceVersion( + target, version, call, availableAssumptions, + updateVersionWithNewAssumptions); + call->hint = newVersion; assert(call->tryDispatch() == newVersion); ip = next; diff --git a/rir/src/compiler/opt/match_call_args.cpp b/rir/src/compiler/opt/match_call_args.cpp index 561d4a67c..f70075be3 100644 --- a/rir/src/compiler/opt/match_call_args.cpp +++ b/rir/src/compiler/opt/match_call_args.cpp @@ -239,9 +239,11 @@ bool MatchCallArgs::apply(Compiler& cmp, ClosureVersion* cls, Code* code, if (staticallyArgmatched && target) { anyChange = true; + if (auto c = call) { assert(!usemethodTarget); auto cls = c->cls()->followCastsAndForce(); + auto nc = new StaticCall( c->env(), target, asmpt, matchedArgs, argOrderOrig, c->frameStateOrTs(), c->srcIdx, cls); @@ -249,6 +251,7 @@ bool MatchCallArgs::apply(Compiler& cmp, ClosureVersion* cls, Code* code, } else if (auto c = namedCall) { assert(!usemethodTarget); auto cls = c->cls()->followCastsAndForce(); + auto nc = new StaticCall( c->env(), target, asmpt, matchedArgs, argOrderOrig, c->frameStateOrTs(), c->srcIdx, cls); diff --git a/rir/src/compiler/pir/closure.cpp b/rir/src/compiler/pir/closure.cpp index cb9e3df6f..988531403 100644 --- a/rir/src/compiler/pir/closure.cpp +++ b/rir/src/compiler/pir/closure.cpp @@ -54,6 +54,25 @@ ClosureVersion* Closure::cloneWithAssumptions(ClosureVersion* version, return copy; } +ClosureVersion* Closure::replaceWithAssumptions(ClosureVersion* version, + const Context& asmpt, + const MaybeClsVersion& change) { + assert(versions.count(version->context()) > 0); + + auto newCtx = version->context() | asmpt; + if (versions.count(newCtx)) { + return versions.at(newCtx); + } + + versions.erase(version->context()); + + version->setContext(newCtx); + + versions[newCtx] = version; + change(version); + return version; +} + ClosureVersion* Closure::findCompatibleVersion(const Context& ctx) const { // ordered by number of assumptions for (auto& candidate : versions) { diff --git a/rir/src/compiler/pir/closure.h b/rir/src/compiler/pir/closure.h index e07df5061..52cc56a40 100644 --- a/rir/src/compiler/pir/closure.h +++ b/rir/src/compiler/pir/closure.h @@ -51,6 +51,7 @@ class Closure { Context userContext_; public: + bool matchesUserContext(Context c) const { return c.smaller(this->userContext_); } @@ -90,6 +91,10 @@ class Closure { const Context& asmpt, const MaybeClsVersion& change); + ClosureVersion* replaceWithAssumptions(ClosureVersion* cls, + const Context& asmpt, + const MaybeClsVersion& change); + typedef std::function ClosureVersionIterator; void eachVersion(ClosureVersionIterator it) const; diff --git a/rir/src/compiler/pir/closure_version.h b/rir/src/compiler/pir/closure_version.h index 7421dc226..5be9049b9 100644 --- a/rir/src/compiler/pir/closure_version.h +++ b/rir/src/compiler/pir/closure_version.h @@ -42,11 +42,12 @@ class ClosureVersion : public Code { const bool root; rir::Function* optFunction; + size_t staticCallRefCount = 0; private: Closure* owner_; std::vector promises_; - const Context optimizationContext_; + Context optimizationContext_; std::string name_; std::string nameSuffix_; @@ -59,9 +60,14 @@ class ClosureVersion : public Code { friend class Closure; public: + bool isClone = false; + ClosureVersion* clone(const Context& newContext); const Context& context() const { return optimizationContext_; } + void setContext(const Context& newContext) { + optimizationContext_ = newContext; + } Properties properties; diff --git a/rir/src/compiler/pir/instruction.cpp b/rir/src/compiler/pir/instruction.cpp index 8735fd107..8c7eb8a06 100644 --- a/rir/src/compiler/pir/instruction.cpp +++ b/rir/src/compiler/pir/instruction.cpp @@ -1255,6 +1255,7 @@ StaticCall::StaticCall(Value* callerEnv, ClosureVersion* clsVersion, Context givenContext, const std::vector& args, const ArglistOrder::CallArglistOrder& argOrderOrig, Value* fs, unsigned srcIdx, Value* runtimeClosure) + : VarLenInstructionWithEnvSlot(PirType::val(), callerEnv, srcIdx), cls_(clsVersion->owner()), argOrderOrig(argOrderOrig), givenContext(givenContext) { @@ -1276,7 +1277,27 @@ StaticCall::StaticCall(Value* callerEnv, ClosureVersion* clsVersion, } } + auto dispatched = tryDispatch(); assert(tryDispatch() == clsVersion); + + // Update version-refCount fields + dispatched->staticCallRefCount++; + lastSeen = dispatched; +} + +void StaticCall::updateVersionRefCount() { + lastSeen = nullptr; + auto target = tryDispatch(); + if (target) { + lastSeen = target; + target->staticCallRefCount++; + } +} + +Instruction* StaticCall::clone() const { + auto sc = StaticCall::Cast(InstructionImplementation::clone()); + sc->updateVersionRefCount(); + return sc; } PirType StaticCall::inferType(const GetType& getType) const { diff --git a/rir/src/compiler/pir/instruction.h b/rir/src/compiler/pir/instruction.h index 63739a3c0..ba098b2eb 100644 --- a/rir/src/compiler/pir/instruction.h +++ b/rir/src/compiler/pir/instruction.h @@ -62,6 +62,7 @@ namespace pir { class BB; class Closure; + class Phi; struct InstrArg { @@ -2268,6 +2269,7 @@ class VLIE(StaticCall, Effects::Any()), public CallInstruction { const ArglistOrder::CallArglistOrder& argOrderOrig, Value* fs, unsigned srcIdx, Value* runtimeClosure = Tombstone::closure()); + ClosureVersion* lastSeen = nullptr; Context givenContext; ClosureVersion* hint = nullptr; @@ -2279,6 +2281,9 @@ class VLIE(StaticCall, Effects::Any()), public CallInstruction { size_t nCallArgs() const override { return nargs() - 3; } + Instruction* clone() const override; + void updateVersionRefCount(); + void eachNamedCallArg(const NamedArgumentValueIterator& it) const override { for (size_t i = 0; i < nCallArgs(); ++i) it(R_NilValue, arg(i + 2).val()); diff --git a/rir/src/compiler/rir2pir/rir2pir.cpp b/rir/src/compiler/rir2pir/rir2pir.cpp index 49309d7c1..daba908f1 100644 --- a/rir/src/compiler/rir2pir/rir2pir.cpp +++ b/rir/src/compiler/rir2pir/rir2pir.cpp @@ -928,13 +928,14 @@ bool Rir2Pir::compileBC(const BC& bc, Opcode* pos, Opcode* nextPos, assert(!inlining()); auto fs = insert.registerFrameState(srcCode, nextPos, stack, inPromise()); + auto cl = insert( new StaticCall(insert.env, f, given, matchedArgs, std::move(argOrderOrig), fs, ast, f->owner()->closureEnv() == Env::notClosed() ? guardedCallee : Tombstone::closure())); - cl->effects.set(Effect::DependsOnAssume); + push(cl); auto innerc = MkCls::Cast(guardedCallee->followCastsAndForce());