Skip to content

Commit

Permalink
Partition splits lists while validating to drop clean-up step.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexreinking committed Dec 11, 2024
1 parent b0e9b29 commit 87e5242
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 44 deletions.
58 changes: 15 additions & 43 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ pair<ReductionDomain, SubstitutionMap> project_rdom(const vector<Dim> &dims, con

} // namespace

void Stage::rfactor_validate_args(const std::vector<std::pair<RVar, Var>> &preserved, const AssociativeOp &prover_result) {
pair<vector<Split>, vector<Split>> Stage::rfactor_validate_args(const std::vector<std::pair<RVar, Var>> &preserved, const AssociativeOp &prover_result) {
const vector<Dim> &dims = definition.schedule().dims();

user_assert(prover_result.associative())
Expand Down Expand Up @@ -760,6 +760,7 @@ void Stage::rfactor_validate_args(const std::vector<std::pair<RVar, Var>> &prese
}

// Check that no Vars were fused into RVars
vector<Split> var_splits, rvar_splits;
Scope<> rdims;
for (const ReductionVariable &rv : definition.schedule().rvars()) {
rdims.push(rv.var);
Expand All @@ -771,6 +772,9 @@ void Stage::rfactor_validate_args(const std::vector<std::pair<RVar, Var>> &prese
rdims.pop(split.old_var);
rdims.push(split.outer);
rdims.push(split.inner);
rvar_splits.emplace_back(split);
} else {
var_splits.emplace_back(split);
}
break;
case Split::FuseVars:
Expand All @@ -784,17 +788,24 @@ void Stage::rfactor_validate_args(const std::vector<std::pair<RVar, Var>> &prese
rdims.pop(split.outer);
rdims.pop(split.inner);
rdims.push(split.old_var);
rvar_splits.emplace_back(split);
} else {
var_splits.emplace_back(split);
}
break;
case Split::PurifyRVar:
case Split::RenameVar:
if (rdims.contains(split.old_var)) {
rdims.pop(split.old_var);
rdims.push(split.outer);
rvar_splits.emplace_back(split);
} else {
var_splits.emplace_back(split);
}
break;
}
}
return std::make_pair(std::move(var_splits), std::move(rvar_splits));
}

Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
Expand All @@ -806,7 +817,7 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
// its identity for each value in the definition if it is a Tuple
const auto &prover_result = prove_associativity(function.name(), definition.args(), definition.values());

rfactor_validate_args(preserved, prover_result);
const auto &[var_splits, rvar_splits] = rfactor_validate_args(preserved, prover_result);

const vector<Expr> dim_vars_exprs = [&] {
vector<Expr> result;
Expand Down Expand Up @@ -928,6 +939,7 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
intm.function().update(0).schedule() = definition.schedule().get_copy();
intm.function().update(0).schedule().dims() = std::move(intm_dims);
intm.function().update(0).schedule().rvars() = intermediate_rdom.domain();
intm.function().update(0).schedule().splits() = var_splits;
}

// Preserved update definition
Expand Down Expand Up @@ -979,47 +991,7 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
definition.predicate() = preserved_rdom.predicate();
definition.schedule().dims() = std::move(reducing_dims);
definition.schedule().rvars() = preserved_rdom.domain();
}

// Clean up the splits lists
for (Stage st : {*this, intm.update(0)}) {
Scope<> dims;
for (const Var &v : st.dim_vars) {
dims.push(v.name());
}
for (const ReductionVariable &rv : st.definition.schedule().rvars()) {
dims.push(rv.var);
}
vector<Split> new_splits;
for (const Split &split : st.definition.schedule().splits()) {
switch (split.split_type) {
case Split::SplitVar:
if (dims.contains(split.old_var)) {
dims.pop(split.old_var);
dims.push(split.outer);
dims.push(split.inner);
new_splits.push_back(split);
}
break;
case Split::FuseVars:
if (dims.contains(split.outer) && dims.contains(split.inner)) {
dims.pop(split.outer);
dims.pop(split.inner);
dims.push(split.old_var);
new_splits.push_back(split);
}
break;
case Split::PurifyRVar:
case Split::RenameVar:
if (dims.contains(split.old_var)) {
dims.pop(split.old_var);
dims.push(split.outer);
new_splits.push_back(split);
}
break;
}
}
st.definition.schedule().splits().swap(new_splits);
definition.schedule().splits() = var_splits;
}

return intm;
Expand Down
3 changes: 2 additions & 1 deletion src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ class Stage {

Stage &compute_with(LoopLevel loop_level, const std::map<std::string, LoopAlignStrategy> &align);

void rfactor_validate_args(const std::vector<std::pair<RVar, Var>> &preserved, const Internal::AssociativeOp &prover_result);
std::pair<std::vector<Internal::Split>, std::vector<Internal::Split>>
rfactor_validate_args(const std::vector<std::pair<RVar, Var>> &preserved, const Internal::AssociativeOp &prover_result);

public:
Stage(Internal::Function f, Internal::Definition d, size_t stage_index)
Expand Down

0 comments on commit 87e5242

Please sign in to comment.