Skip to content

Commit

Permalink
Use and_condition_over_domain to predicate the reducing definition in…
Browse files Browse the repository at this point in the history
… rfactor
  • Loading branch information
alexreinking committed Nov 24, 2024
1 parent 7ffdb66 commit 3fb172f
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,8 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {

// Preserved update definition
{
auto [preserved_rdom, _] = project_rdom(preserved_rdims, definition);
auto [preserved_rdom, preserved_map] = project_rdom(preserved_rdims, definition);
preserved_rdom.set_predicate(simplify(substitute(preserved_map, preserved_rdom.predicate())));

// Replace the current definition with calls to the intermediate func.
vector<Expr> dim_exprs = copy_convert<Expr>(dim_vars);
Expand Down Expand Up @@ -957,8 +958,13 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
}
}

Scope<Interval> intm_rdom;
for (const auto &[var, min, extent] : intm.update(0).definition.schedule().rvars()) {
intm_rdom.push(var, Interval{min, min + extent - 1});
}

definition.args() = dim_exprs;
definition.predicate() = const_true(); // TODO: replace with strongest postcondition of the intermediate predicate with the eliminated rvars havoc'd
definition.predicate() = !and_condition_over_domain(simplify(!preserved_rdom.predicate()), intm_rdom);
definition.schedule().dims() = std::move(reducing_dims);
definition.schedule().rvars() = preserved_rdom.domain();
definition.values() = substitute(replacements, prover_result.pattern.ops);
Expand Down

0 comments on commit 3fb172f

Please sign in to comment.