diff --git a/dialects/matrix/passes/lower_matrix_mediumlevel.cpp b/dialects/matrix/passes/lower_matrix_mediumlevel.cpp index 66e4fa7f63..32bd930fef 100644 --- a/dialects/matrix/passes/lower_matrix_mediumlevel.cpp +++ b/dialects/matrix/passes/lower_matrix_mediumlevel.cpp @@ -250,7 +250,7 @@ Ref LowerMatrixMediumLevel::rewrite_(Ref def) { iterator[idx] = world.call(world.type_idx(dim_nat_def), iter); auto [new_mem, new_mat] = new_acc->projs<2>(); acc = {new_mem, new_mat}; - current_mut->set(dim_nat_def, for_call); + current_mut->set(false, for_call); // TODO correct filter? current_mut = body; } @@ -292,6 +292,7 @@ Ref LowerMatrixMediumLevel::rewrite_(Ref def) { acc = {current_mem, element_acc}; cont = write_back; + // TODO this is copy&paste code from above for (auto idx : in_indices) { char for_name[32]; sprintf(for_name, "forIn_%lu", idx); @@ -306,7 +307,7 @@ Ref LowerMatrixMediumLevel::rewrite_(Ref def) { iterator[idx] = world.call(world.type_idx(dim_nat_def), iter); auto [new_mem, new_element] = new_acc->projs<2>(); acc = {new_mem, new_element}; - current_mut->set(dim_nat_def, for_call); + current_mut->set(false, for_call); // TODO current_mut = body; } diff --git a/thorin/check.cpp b/thorin/check.cpp index c6e241e4e6..d7a878bd91 100644 --- a/thorin/check.cpp +++ b/thorin/check.cpp @@ -93,8 +93,19 @@ bool Checker::equiv_internal(Ref d1, Ref d2) { if (!equiv(d1->type(), d2->type())) return false; if (d1->isa() || d2->isa()) return equiv(d1->type(), d2->type()); + struct Pop { + ~Pop() { + if (vars) vars->pop_back(); + } + + Vars* vars = nullptr; + } pop; + if (auto n1 = d1->isa_mut()) { - if (auto n2 = d2->isa_mut()) vars_.emplace_back(n1, n2); + if (auto n2 = d2->isa_mut()) { + vars_.emplace_back(n1, n2); + pop.vars = &vars_; // make sure vars_ is popped again + } } if (d1->isa()) { @@ -109,10 +120,18 @@ bool Checker::equiv_internal(Ref d1, Ref d2) { if (d1->node() != d2->node() || d1->flags() != d2->flags() || d1->num_ops() != d2->num_ops()) return false; - if (auto var = d1->isa()) { // vars are equal if they appeared under the same binder - for (auto [n1, n2] : vars_) - if (var->mut() == n1) return d2->as()->mut() == n2; - // TODO what if Var is free? + if (auto var1 = d1->isa()) { // vars are equal if they appeared under the same binder + auto var2 = d2->as(); + bool bound1 = false, bound2 = false; + for (auto [n1, n2] : vars_) { + if (var1->mut() == n1) { + bound1 = true; + return d2->as()->mut() == n2; + } + assert(var1->mut() != n2); + if (var2->mut() == n1 || var2->mut() == n2) bound2 = true; + } + if (!bound1 && !bound2) return true; // both var1 and var2 are free return false; } @@ -185,19 +204,17 @@ void Sigma::check() { void Lam::check() { auto& w = world(); - return; // TODO if (!w.checker().equiv(filter()->type(), w.type_bool())) error(filter(), "filter '{}' of lambda is of type '{}' but must be of type '.Bool'", filter(), filter()->type()); if (!w.checker().equiv(body()->type(), codom())) - error(body(), "body '{}' of lambda is of type '{}' but its codomain is of type '{}'", body(), body()->type(), - codom()); + error(body(), "body '{}' of lambda is of type \n'{}' but its codomain is of type \n'{}'", body(), + body()->type(), codom()); } void Pi::check() { auto& w = world(); auto t = infer(dom(), codom()); - if (!w.checker().equiv(t, type())) error(type(), "declared sort '{}' of function type does not match inferred one '{}'", type(), t); } diff --git a/thorin/check.h b/thorin/check.h index f0de50b799..c7b7bf3134 100644 --- a/thorin/check.h +++ b/thorin/check.h @@ -71,7 +71,8 @@ class Checker { World* world_; DefDefMap equiv_; - std::deque> vars_; + using Vars = std::deque>; + Vars vars_; }; } // namespace thorin