Skip to content

Commit

Permalink
Merge pull request #181 from NeuralCoder3/ho_codegen
Browse files Browse the repository at this point in the history
Higher order codegen
  • Loading branch information
leissa authored Mar 13, 2023
2 parents 53e51ca + 2acef27 commit 1374e0c
Show file tree
Hide file tree
Showing 92 changed files with 2,671 additions and 290 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
*.s
.DS_Store
.cache
build*
build*/
vgcore.*
7 changes: 6 additions & 1 deletion dialects/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,14 @@ add_thorin_dialect(clos
clos/pass/rw/clos2sjlj.h
clos/pass/rw/clos_conv_prep.cpp
clos/pass/rw/clos_conv_prep.h
clos/pass/rw/phase_wrapper.h
clos/phase/clos_conv.cpp
clos/phase/clos_conv.h
clos/phase/lower_typed_clos.cpp
clos/phase/lower_typed_clos.h
mem/passes/fp/copy_prop.cpp
mem/passes/rw/reshape.cpp
mem/phases/rw/add_mem.cpp
DEPENDS
mem
affine
Expand Down Expand Up @@ -137,6 +140,8 @@ add_thorin_dialect(mem
mem/passes/rw/alloc2malloc.h
mem/passes/rw/remem_elim.cpp
mem/passes/rw/remem_elim.h
mem/passes/rw/reshape.cpp
mem/passes/rw/reshape.h
mem/phases/rw/add_mem.cpp
mem/phases/rw/add_mem.h
DEPENDS
Expand All @@ -150,7 +155,6 @@ add_thorin_dialect(opt
SOURCES
opt/opt.cpp
opt/opt.h
opt/normalizers.cpp
DEPENDS
compile
mem
Expand All @@ -164,6 +168,7 @@ add_thorin_dialect(refly
refly/refly.cpp
refly/passes/remove_perm.h
refly/passes/remove_perm.cpp
refly/passes/debug_dump.h
refly/normalizers.cpp
DEPENDS
compile
Expand Down
10 changes: 5 additions & 5 deletions dialects/affine/passes/lower_for.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,16 @@ const Def* LowerFor::rewrite(const Def* def) {

// reduce the body to remove the cn parameter
auto nom_body = body->as_nom<Lam>();
auto new_body = nom_body->stub(w, w.cn(w.sigma()))->set(body->dbg());
new_body->set(nom_body->reduce(w.tuple({iter, acc, yield_lam})));
auto new_body = nom_body->stub(w, w.cn(acc->type()))->set(body->dbg());
new_body->set(nom_body->reduce(w.tuple({iter, new_body->var(), yield_lam})));

// break
auto if_else_cn = w.cn(w.sigma());
auto if_else_cn = w.cn(acc->type());
auto if_else = w.nom_lam(if_else_cn);
if_else->app(false, brk, acc);
if_else->app(false, brk, if_else->var());

auto cmp = w.call(core::icmp::ul, Defs{iter, end});
for_lam->branch(false, cmp, new_body, if_else, w.tuple());
for_lam->branch(false, cmp, new_body, if_else, acc);
}

DefArray for_args{for_ax->num_args() - 2, [&](size_t i) { return for_ax->arg(i); }};
Expand Down
30 changes: 6 additions & 24 deletions dialects/clos/clos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
#include "dialects/clos/pass/rw/branch_clos_elim.h"
#include "dialects/clos/pass/rw/clos2sjlj.h"
#include "dialects/clos/pass/rw/clos_conv_prep.h"
#include "dialects/clos/phase/clos_conv.h"
#include "dialects/clos/phase/lower_typed_clos.h"
#include "dialects/clos/pass/rw/phase_wrapper.h"
#include "dialects/mem/mem.h"
#include "dialects/mem/passes/fp/copy_prop.h"
#include "dialects/mem/passes/rw/reshape.h"
#include "dialects/mem/phases/rw/add_mem.h"
#include "dialects/refly/passes/debug_dump.h"

namespace thorin::clos {

Expand Down Expand Up @@ -131,26 +133,6 @@ Ref ctype(World& w, Defs doms, Ref env_type) {
[&](auto i) { return clos_insert_env(i, env_type, [&](auto j) { return doms[j]; }); }));
}

/*
* Pass Wrappers
*/

class ClosConvWrapper : public RWPass<ClosConvWrapper, Lam> {
public:
ClosConvWrapper(PassMan& man)
: RWPass(man, "clos_conv") {}

void prepare() override { ClosConv(world()).run(); }
};

class LowerTypedClosWrapper : public RWPass<LowerTypedClosWrapper, Lam> {
public:
LowerTypedClosWrapper(PassMan& man)
: RWPass(man, "lower_typed_clos") {}

void prepare() override { LowerTypedClos(world()).run(); }
};

} // namespace thorin::clos

using namespace thorin;
Expand All @@ -159,11 +141,11 @@ extern "C" THORIN_EXPORT DialectInfo thorin_get_dialect_info() {
return {"clos",
[](Passes& passes) {
register_pass<clos::clos_conv_prep_pass, clos::ClosConvPrep>(passes, nullptr);
register_pass<clos::clos_conv_pass, clos::ClosConvWrapper>(passes);
register_pass<clos::clos_conv_pass, ClosConvWrapper>(passes);
register_pass<clos::branch_clos_pass, clos::BranchClosElim>(passes);
register_pass<clos::lower_typed_clos_prep_pass, clos::LowerTypedClosPrep>(passes);
register_pass<clos::clos2sjlj_pass, clos::Clos2SJLJ>(passes);
register_pass<clos::lower_typed_clos_pass, clos::LowerTypedClosWrapper>(passes);
register_pass<clos::lower_typed_clos_pass, LowerTypedClosWrapper>(passes);
// TODO:; remove after ho_codegen merge
passes[flags_t(Axiom::Base<clos::eta_red_bool_pass>)] = [&](World&, PipelineBuilder& builder, Ref app) {
auto bb = app->as<App>()->arg();
Expand Down
6 changes: 3 additions & 3 deletions dialects/clos/clos.thorin
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@
%compile.combined_phase
(%compile.phase_list
(%compile.single_pass_phase nullptr)
// optimization_phase
// (%compile.single_pass_phase (%mem.reshape_pass %mem.reshape_flat))
// (%compile.single_pass_phase %mem.add_mem_pass)
optimization_phase
(%compile.single_pass_phase (%mem.reshape_pass %mem.reshape_flat))
(%compile.single_pass_phase %mem.add_mem_pass)
(%compile.single_pass_phase %clos.clos_conv_prep_pass)
(%compile.single_pass_phase (%compile.eta_exp_pass nullptr))
(%compile.single_pass_phase %clos.clos_conv_pass)
Expand Down
9 changes: 4 additions & 5 deletions dialects/clos/pass/rw/clos_conv_prep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,10 @@ void ClosConvPrep::enter() {
}
}
}
if (auto body = curr_nom()->body()->isa<App>();
!wrapper_.contains(curr_nom()) && body && body->callee_type()->is_cn())
ignore_ = false;
else
ignore_ = true;

auto body = curr_nom()->body()->isa<App>();
// Skip if the nominal is already wrapped or the body is undefined/no continuation.
ignore_ = !(body && body->callee_type()->is_cn()) || wrapper_.contains(curr_nom());
}

const App* ClosConvPrep::rewrite_arg(const App* app) {
Expand Down
27 changes: 27 additions & 0 deletions dialects/clos/pass/rw/phase_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#include <thorin/config.h>
#include <thorin/pass/pass.h>

#include "thorin/dialects.h"

#include "dialects/clos/phase/clos_conv.h"
#include "dialects/clos/phase/lower_typed_clos.h"

using namespace thorin;

class ClosConvWrapper : public RWPass<ClosConvWrapper, Lam> {
public:
ClosConvWrapper(PassMan& man)
: RWPass(man, "clos_conv") {}

void prepare() override { clos::ClosConv(world()).run(); }
};

class LowerTypedClosWrapper : public RWPass<LowerTypedClosWrapper, Lam> {
public:
LowerTypedClosWrapper(PassMan& man)
: RWPass(man, "lower_typed_clos") {}

void prepare() override { clos::LowerTypedClos(world()).run(); }
};
40 changes: 40 additions & 0 deletions dialects/compile/compile.thorin
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,34 @@
.Pi PassList: *, %compile.Pass = PassList;
.Pi CombinedPhase: *, %compile.Phase = CombinedPhase;
///
/// (This is a forward declaration for opt.thorin.)
///
.ax %compile.Dialect: *;
///
/// ## Expressions
///
/// ### registered dialects
///
/// We expect the name in the tag before the `_` to be the name of the dialect (as given in `DialectInfo.plugin_name`).
/// (This is a forward declaration for opt.thorin.)
///
.ax %compile.core_dialect : %compile.Dialect;
.ax %compile.mem_dialect : %compile.Dialect;
.ax %compile.demo_dialect : %compile.Dialect;
.ax %compile.affine_dialect : %compile.Dialect;
.ax %compile.autodiff_dialect: %compile.Dialect;
.ax %compile.clos_dialect : %compile.Dialect;
.ax %compile.direct_dialect : %compile.Dialect;
.ax %compile.refly_dialect : %compile.Dialect;
///
/// ### %opt.is_loaded
///
/// Indicates whether a dialect is loaded.
/// The normalizer will statically evaluate this expression to a constant boolean.
/// TODO: find correct point (not at parsing but before compilation)
///
// .ax %opt.is_loaded: %opt.Dialect -> .Bool;
///
/// ### %compile.pipe
///
/// Given n phases, returns the representation of a pipeline.
Expand Down Expand Up @@ -119,4 +145,18 @@
.lam .extern _fallback_compile [] -> Pipeline = {
default_core_pipeline
};
///
/// ### Dependent Passes and Phases
///
.let empty_pass = %compile.nullptr_pass;
.let empty_phase = %compile.passes_to_phase 0 ();
.ax %compile.dialect_select: Π [T:*] -> %compile.Dialect -> T -> T -> T;
.let dialect_phase = %compile.dialect_select %compile.Phase;
.let dialect_pass = %compile.dialect_select %compile.Pass;
.lam dialect_cond_phase ![dialect: %compile.Dialect,phase: %compile.Phase] -> %compile.Phase = {
dialect_phase dialect phase empty_phase
};
.lam dialect_cond_pass ![dialect: %compile.Dialect,pass: %compile.Pass] -> %compile.Pass = {
dialect_pass dialect pass empty_pass
};

2 changes: 1 addition & 1 deletion dialects/compile/passes/internal_cleanup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace thorin::compile {

void InternalCleanup::enter() {
Lam* lam = curr_nom();
if (lam->sym()->starts_with("internal_")) {
if (lam->sym()->starts_with(prefix_)) {
lam->make_internal();
world().DLOG("internalized {}", lam);
}
Expand Down
8 changes: 6 additions & 2 deletions dialects/compile/passes/internal_cleanup.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@ namespace thorin::compile {

class InternalCleanup : public RWPass<InternalCleanup, Lam> {
public:
InternalCleanup(PassMan& man)
: RWPass(man, "internal_cleanup") {}
InternalCleanup(PassMan& man, const char* prefix = "internal_")
: RWPass(man, "internal_cleanup")
, prefix_(prefix) {}

void enter() override;

private:
const char* prefix_;
};

} // namespace thorin::compile
49 changes: 28 additions & 21 deletions dialects/core/be/ll/ll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,17 +191,9 @@ std::string Emitter::convert(const Def* type) {
}

std::string Emitter::convert_ret_pi(const Pi* pi) {
switch (pi->num_doms()) {
case 0: return "void";
case 1:
if (match<mem::M>(pi->dom())) return "void";
return convert(pi->dom());
case 2:
if (match<mem::M>(pi->dom(0))) return convert(pi->dom(1));
if (match<mem::M>(pi->dom(1))) return convert(pi->dom(0));
[[fallthrough]];
default: return convert(pi->dom());
}
auto dom = mem::strip_mem_ty(pi->dom());
if (dom == world().sigma()) { return "void"; }
return convert(dom);
}

/*
Expand Down Expand Up @@ -314,6 +306,7 @@ void Emitter::emit_epilogue(Lam* lam) {
}
} else if (auto ex = app->callee()->isa<Extract>(); ex && app->callee_type()->is_basicblock()) {
emit_unsafe(app->arg());

auto c = emit(ex->index());
if (ex->tuple()->num_projs() == 2) {
auto [f, t] = ex->tuple()->projs<2>([this](auto def) { return emit(def); });
Expand Down Expand Up @@ -448,12 +441,14 @@ std::string Emitter::emit_bb(BB& bb, const Def* def) {

std::string prev = "undef";
auto t = convert(tuple->type());
for (size_t i = 0, n = tuple->num_projs(); i != n; ++i) {
auto e = tuple->proj(n, i);
if (auto v_elem = emit_unsafe(e); !v_elem.empty()) {
auto t_elem = convert(e->type());
auto namei = name + "." + std::to_string(i);
prev = bb.assign(namei, "insertvalue {} {}, {} {}, {}", t, prev, t_elem, v_elem, i);
for (size_t src = 0, dst = 0, n = tuple->num_projs(); src != n; ++src) {
auto e = tuple->proj(n, src);
if (auto elem = emit_unsafe(e); !elem.empty()) {
auto elem_t = convert(e->type());
// TODO: check dst vs src
auto namei = name + "." + std::to_string(dst);
prev = bb.assign(namei, "insertvalue {} {}, {} {}, {}", t, prev, elem_t, elem, dst);
dst++;
}
}
return prev;
Expand Down Expand Up @@ -736,10 +731,19 @@ std::string Emitter::emit_bb(BB& bb, const Def* def) {
declare("i8* @malloc(i64)");

emit_unsafe(malloc->arg(0));
auto v_size = emit(malloc->arg(1));
auto t_ptr = convert(force<mem::Ptr>(def->proj(1)->type()));
bb.assign(name + ".i8", "call i8* @malloc(i64 {})", v_size);
return bb.assign(name, "bitcast i8* {} to {}", name + ".i8", t_ptr);
auto size = emit(malloc->arg(1));
auto ptr_t = convert(force<mem::Ptr>(def->proj(1)->type()));
bb.assign(name + ".i8", "call i8* @malloc(i64 {})", size);
return bb.assign(name, "bitcast i8* {} to {}", name + ".i8", ptr_t);
} else if (auto free = match<mem::free>(def)) {
declare("void @free(i8*)");
emit_unsafe(free->arg(0));
auto ptr = emit(free->arg(1));
auto ptr_t = convert(force<mem::Ptr>(free->arg(1)->type()));

bb.assign(name + ".i8", "bitcast {} {} to i8*", ptr_t, ptr);
bb.tail("call void @free(i8* {})", name + ".i8");
return {};
} else if (auto mslot = match<mem::mslot>(def)) {
emit_unsafe(mslot->arg(0));
// TODO array with size
Expand Down Expand Up @@ -944,6 +948,9 @@ std::string Emitter::emit_bb(BB& bb, const Def* def) {

return bb.assign(name, "{} {} {} to {}", op, t_src, v_src, t_dst);
}
auto& world = def->world();
world.DLOG("unhandled def: {} : {}", def, def->type());
def->dump();

def->dump(1);
err("unhandled def in LLVM backend: {}", def);
Expand Down
Loading

0 comments on commit 1374e0c

Please sign in to comment.