diff --git a/dialects/clos/clos.cpp b/dialects/clos/clos.cpp index f72f734cb7..a64789aee7 100644 --- a/dialects/clos/clos.cpp +++ b/dialects/clos/clos.cpp @@ -20,6 +20,23 @@ namespace thorin::clos { +size_t env_idx(Defs defs) { + if (defs.empty()) return 0; + if (match(defs.front()) || match(defs.front()->type())) return 1; + return 0; +} + +Sigma* doms2clos(World& world, Defs doms) { + auto sigma = world.nom_sigma(world.type<1>(), 3, world.dbg("Clos"))->set(0, world.type()); + auto env_t = sigma->var(0_s); + auto new_dom = env_insert(world, doms, env_t); + sigma->set(1, world.cn(new_dom)); + sigma->set(2, env_t); + return sigma; +} + +Sigma* pi2clos(World& world, const Pi* pi) { return doms2clos(world, pi->doms()); } + /* * ClosLit */ @@ -99,7 +116,7 @@ const Sigma* isa_clos_type(const Def* def) { return (pi && pi->is_cn() && pi->num_ops() > 1_u64 && pi->dom(Clos_Env_Param) == var) ? sig : nullptr; } -Sigma* clos_type(const Pi* pi) { return ctype(pi->world(), pi->doms(), nullptr)->as_nom(); } +Sigma* clos_type(const Pi* pi) { return pi2clos(pi->world(), pi); } const Pi* clos_type_to_pi(const Def* ct, const Def* new_env_type) { assert(isa_clos_type(ct)); diff --git a/dialects/clos/clos.h b/dialects/clos/clos.h index 7cb6047f55..40d3e00026 100644 --- a/dialects/clos/clos.h +++ b/dialects/clos/clos.h @@ -48,9 +48,7 @@ class ClosLit { ///@} operator bool() const { return def_ != nullptr; } - operator const Tuple*() { return def_; } - const Tuple* operator->() { assert(def_); return def_; @@ -163,6 +161,22 @@ inline const Def* clos_remove_env(const Def* tup_or_sig) { inline const Def* clos_sub_env(const Def* tup_or_sig, const Def* new_env) { return tup_or_sig->refine(Clos_Env_Param, new_env); } + +size_t env_idx(Defs); + +template +const Def* env_insert(World& world, Defs defs, const Def* env) { + size_t n = defs.size(); + size_t x = env_idx(defs); + + DefArray new_ops(n + 1); + for (size_t i = 0, j = 0; i != n + 1; ++i) new_ops[i] = i == x ? env : defs[j++]; + + return type ? world.sigma(new_ops) : world.tuple(new_ops); +} + +Sigma* doms2clos(World& world, Defs doms); + ///@} } // namespace thorin::clos diff --git a/dialects/clos/phase/clos_conv.cpp b/dialects/clos/phase/clos_conv.cpp index f13c839bac..f0289e6265 100644 --- a/dialects/clos/phase/clos_conv.cpp +++ b/dialects/clos/phase/clos_conv.cpp @@ -99,6 +99,30 @@ DefSet& FreeDefAna::run(Lam* lam) { return node->fvs; } +/* + * Closure Conversion - reimpl + */ + +const Def* ClosConv_::rewrite_structural(const Def* def) { + if (auto pi = def->isa(); pi && pi->is_cn()) { + if (auto ret_pi = pi->ret_pi()) { + size_t n = pi->num_doms(); + DefArray new_doms( + n, [&](auto i) { return i != n - 1 ? rewrite(pi->dom(i)) : world().cn(rewrite(ret_pi->dom())); }); + return doms2clos(world(), new_doms); + } + auto new_dom = rewrite(pi->dom()); + return doms2clos(world(), new_dom->projs()); + } else if (auto a = match(def)) { + switch (a.id()) { + case attr::esc: break; + } + } + return Rewriter::rewrite_structural(def); +} + +const Def* ClosConv_::rewrite_nom(Def* nom) { return Rewriter::rewrite_nom(nom); } + /* * Closure Conversion */ diff --git a/dialects/clos/phase/clos_conv.h b/dialects/clos/phase/clos_conv.h index c536883ab7..d97fcaea63 100644 --- a/dialects/clos/phase/clos_conv.h +++ b/dialects/clos/phase/clos_conv.h @@ -19,7 +19,7 @@ class FreeDefAna { FreeDefAna(World& world) : world_(world) , cur_pass_id(1) - , lam2nodes_(){}; + , lam2nodes_() {} /// FreeDefAna::run will compute free defs (FD) that appear in @p lam%s body. /// Nominal Def%s are only considered free if they are annotated with Clos::freeBB or @@ -73,6 +73,19 @@ class FreeDefAna { DefMap> lam2nodes_; }; +class ClosConv_ : public RWPhase { +public: + ClosConv_(World& world) + : RWPhase(world, "clos_conv") + , fva_(world) {} + + const Def* rewrite_structural(const Def*) override; + const Def* rewrite_nom(Def*) override; + +private: + FreeDefAna fva_; +}; + /// Performs *typed closure conversion*. /// This is based on the [Simply Typed Closure Conversion](https://dl.acm.org/doi/abs/10.1145/237721.237791). /// Closures are represented using tuples: `[Env: *, .Cn [Env, Args..], Env]`. diff --git a/thorin/def.cpp b/thorin/def.cpp index 471c5b001b..55d3df925b 100644 --- a/thorin/def.cpp +++ b/thorin/def.cpp @@ -380,7 +380,7 @@ const Def* Def::proj(nat_t a, nat_t i, const Def* dbg) const { World& w = world(); - if (isa() || isa()) { + if (isa()) { return op(i); } else if (auto arr = isa()) { if (arr->arity()->isa()) return arr->body();