Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

type check lambdas when last op (body) is set #217

Merged
merged 6 commits into from
May 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions dialects/matrix/passes/lower_matrix_mediumlevel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ Ref LowerMatrixMediumLevel::rewrite_(Ref def) {
iterator[idx] = world.call<core::bitcast>(world.type_idx(dim_nat_def), iter);
auto [new_mem, new_mat] = new_acc->projs<2>();
acc = {new_mem, new_mat};
current_mut->set(dim_nat_def, for_call);
current_mut->set(true, for_call);
current_mut = body;
}

Expand Down Expand Up @@ -292,6 +292,7 @@ Ref LowerMatrixMediumLevel::rewrite_(Ref def) {
acc = {current_mem, element_acc};
cont = write_back;

// TODO this is copy&paste code from above
for (auto idx : in_indices) {
char for_name[32];
sprintf(for_name, "forIn_%lu", idx);
Expand All @@ -306,7 +307,7 @@ Ref LowerMatrixMediumLevel::rewrite_(Ref def) {
iterator[idx] = world.call<core::bitcast>(world.type_idx(dim_nat_def), iter);
auto [new_mem, new_element] = new_acc->projs<2>();
acc = {new_mem, new_element};
current_mut->set(dim_nat_def, for_call);
current_mut->set(true, for_call);
current_mut = body;
}

Expand Down
2 changes: 1 addition & 1 deletion docs/langref.md
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ What is more, since they are bound by a *let declaration*, they have the exact s

The following expressions for applying `f` are also equivalent:
```
f .Nat ((23, 42),.cn res: .Nat = use(res))
f .Nat ((23, 42), .cn res: .Nat = use(res))
.ret res = f .Nat $ (23, 42); use(res)
```

Expand Down
18 changes: 11 additions & 7 deletions lit/fun.thorin
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

.lam Ptr(T: *) -> * = %mem.Ptr (T, 0);

.fun foo(mem: %mem.M, x: %core.I32)@(%core.icmp.e (x, 23:%core.I32)) -> [%mem.M, %core.I32] = return (mem, %core.wrap.add 0 (x, 1:%core.I32));
.fun foo(mem: %mem.M, x: %core.I32)@(%core.icmp.e (x, 23:%core.I32)) -> [%mem.M, %core.I32] =
return (mem, %core.wrap.add 0 (x, 1:%core.I32));

.fun .extern main(mem: %mem.M, argc: %core.I32, argv: Ptr (Ptr %core.I8)) -> [%mem.M, %core.I32] = {
.fun .extern main(mem: %mem.M, argc: %core.I32, argv: Ptr (Ptr %core.I8)) -> [%mem.M, %core.I32] =
.ret (`mem, x) = foo $ (mem, 23:%core.I32);
.ret (`mem, y) = foo $ (mem, 23:%core.I32);
return (mem, %core.wrap.add 0 (x, y))
};
return (mem, %core.wrap.add 0 (x, y));

.lam f1(T: *)((x y: T), return: T -> ⊥) -> ⊥ = return x;
.con f2(T: *)((x y: T), return: .Cn T) = return x;
Expand All @@ -25,8 +25,12 @@
.let F2 =.Cn[T:*][T, T][.Cn T];
.let F3 =.Fn[T:*][T, T] -> T;


.let _ = .ret res = f1 .Nat $ (23, 42); res;
.let _ = f1 .Nat ((23, 42), .cn res: .Nat = res);
.fun bar(cond: .Bool) -> .Nat =
.con t() =
.ret res = f1 .Nat $ (23, 42);
return res;
.con f() =
f1 .Nat ((23, 42), .cn res: .Nat = return res);
(f, t)#cond ();

// CHECK-DAG: return{{.*}}48
46 changes: 36 additions & 10 deletions thorin/check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ bool Checker::equiv(Ref r1, Ref r2) {
}

assert(!i1 && !i2);
if (d1->gid() > d2->gid()) std::swap(d1, d2); // normalize
// normalize: Lit to right; then sort by gid
if ((d1->isa<Lit>() && !d2->isa<Lit>()) || (d1->gid() > d2->gid())) std::swap(d1, d2);

if (auto [it, ins] = equiv_.emplace(std::pair(d1, d2), Equiv::Unknown); !ins) {
switch (it->second) {
Expand All @@ -93,8 +94,19 @@ bool Checker::equiv_internal(Ref d1, Ref d2) {
if (!equiv(d1->type(), d2->type())) return false;
if (d1->isa<Top>() || d2->isa<Top>()) return equiv(d1->type(), d2->type());

struct Pop {
~Pop() {
if (vars) vars->pop_back();
}

Vars* vars = nullptr;
} pop;

if (auto n1 = d1->isa_mut()) {
if (auto n2 = d2->isa_mut()) vars_.emplace_back(n1, n2);
if (auto n2 = d2->isa_mut()) {
vars_.emplace_back(n1, n2);
pop.vars = &vars_; // make sure vars_ is popped again
}
}

if (d1->isa<Sigma, Arr>()) {
Expand All @@ -107,12 +119,28 @@ bool Checker::equiv_internal(Ref d1, Ref d2) {
}
}

if (auto umax = d1->isa<UMax>(); umax && umax->has_dep(Dep::Infer)) {
if (auto l = d2->isa<Lit>()) {
for (auto op : umax->ops())
if (auto infer = op->isa_mut<Infer>(); infer && !infer->is_set()) infer->set(l);
}
d1 = umax->rebuild(world(), umax->type(), umax->ops());
}

if (d1->node() != d2->node() || d1->flags() != d2->flags() || d1->num_ops() != d2->num_ops()) return false;

if (auto var = d1->isa<Var>()) { // vars are equal if they appeared under the same binder
for (auto [n1, n2] : vars_)
if (var->mut() == n1) return d2->as<Var>()->mut() == n2;
// TODO what if Var is free?
if (auto var1 = d1->isa<Var>()) { // vars are equal if they appeared under the same binder
auto var2 = d2->as<Var>();
bool bound1 = false, bound2 = false;
for (auto [n1, n2] : vars_) {
if (var1->mut() == n1) {
bound1 = true;
return d2->as<Var>()->mut() == n2;
}
assert(var1->mut() != n2);
if (var2->mut() == n1 || var2->mut() == n2) bound2 = true;
}
if (!bound1 && !bound2) return true; // both var1 and var2 are free
return false;
}

Expand Down Expand Up @@ -185,19 +213,17 @@ void Sigma::check() {

void Lam::check() {
auto& w = world();
return; // TODO
if (!w.checker().equiv(filter()->type(), w.type_bool()))
error(filter(), "filter '{}' of lambda is of type '{}' but must be of type '.Bool'", filter(),
filter()->type());
if (!w.checker().equiv(body()->type(), codom()))
error(body(), "body '{}' of lambda is of type '{}' but its codomain is of type '{}'", body(), body()->type(),
codom());
error(body(), "body '{}' of lambda is of type \n'{}' but its codomain is of type \n'{}'", body(),
body()->type(), codom());
}

void Pi::check() {
auto& w = world();
auto t = infer(dom(), codom());

if (!w.checker().equiv(t, type()))
error(type(), "declared sort '{}' of function type does not match inferred one '{}'", type(), t);
}
Expand Down
3 changes: 2 additions & 1 deletion thorin/check.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ class Checker {

World* world_;
DefDefMap<Equiv> equiv_;
std::deque<std::pair<Def*, Def*>> vars_;
using Vars = std::deque<std::pair<Def*, Def*>>;
Vars vars_;
};

} // namespace thorin