From 87e52428a9e2fd459e57d345c0f34bd08491b05c Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Wed, 11 Dec 2024 14:26:10 -0500 Subject: [PATCH] Partition splits lists while validating to drop clean-up step. --- src/Func.cpp | 58 ++++++++++++++-------------------------------------- src/Func.h | 3 ++- 2 files changed, 17 insertions(+), 44 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index 556ca19bab0f..5d33ab5a0d6d 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -713,7 +713,7 @@ pair project_rdom(const vector &dims, con } // namespace -void Stage::rfactor_validate_args(const std::vector> &preserved, const AssociativeOp &prover_result) { +pair, vector> Stage::rfactor_validate_args(const std::vector> &preserved, const AssociativeOp &prover_result) { const vector &dims = definition.schedule().dims(); user_assert(prover_result.associative()) @@ -760,6 +760,7 @@ void Stage::rfactor_validate_args(const std::vector> &prese } // Check that no Vars were fused into RVars + vector var_splits, rvar_splits; Scope<> rdims; for (const ReductionVariable &rv : definition.schedule().rvars()) { rdims.push(rv.var); @@ -771,6 +772,9 @@ void Stage::rfactor_validate_args(const std::vector> &prese rdims.pop(split.old_var); rdims.push(split.outer); rdims.push(split.inner); + rvar_splits.emplace_back(split); + } else { + var_splits.emplace_back(split); } break; case Split::FuseVars: @@ -784,6 +788,9 @@ void Stage::rfactor_validate_args(const std::vector> &prese rdims.pop(split.outer); rdims.pop(split.inner); rdims.push(split.old_var); + rvar_splits.emplace_back(split); + } else { + var_splits.emplace_back(split); } break; case Split::PurifyRVar: @@ -791,10 +798,14 @@ void Stage::rfactor_validate_args(const std::vector> &prese if (rdims.contains(split.old_var)) { rdims.pop(split.old_var); rdims.push(split.outer); + rvar_splits.emplace_back(split); + } else { + var_splits.emplace_back(split); } break; } } + return std::make_pair(std::move(var_splits), std::move(rvar_splits)); } Func Stage::rfactor(const vector> &preserved) { @@ -806,7 +817,7 @@ Func Stage::rfactor(const vector> &preserved) { // its identity for each value in the definition if it is a Tuple const auto &prover_result = prove_associativity(function.name(), definition.args(), definition.values()); - rfactor_validate_args(preserved, prover_result); + const auto &[var_splits, rvar_splits] = rfactor_validate_args(preserved, prover_result); const vector dim_vars_exprs = [&] { vector result; @@ -928,6 +939,7 @@ Func Stage::rfactor(const vector> &preserved) { intm.function().update(0).schedule() = definition.schedule().get_copy(); intm.function().update(0).schedule().dims() = std::move(intm_dims); intm.function().update(0).schedule().rvars() = intermediate_rdom.domain(); + intm.function().update(0).schedule().splits() = var_splits; } // Preserved update definition @@ -979,47 +991,7 @@ Func Stage::rfactor(const vector> &preserved) { definition.predicate() = preserved_rdom.predicate(); definition.schedule().dims() = std::move(reducing_dims); definition.schedule().rvars() = preserved_rdom.domain(); - } - - // Clean up the splits lists - for (Stage st : {*this, intm.update(0)}) { - Scope<> dims; - for (const Var &v : st.dim_vars) { - dims.push(v.name()); - } - for (const ReductionVariable &rv : st.definition.schedule().rvars()) { - dims.push(rv.var); - } - vector new_splits; - for (const Split &split : st.definition.schedule().splits()) { - switch (split.split_type) { - case Split::SplitVar: - if (dims.contains(split.old_var)) { - dims.pop(split.old_var); - dims.push(split.outer); - dims.push(split.inner); - new_splits.push_back(split); - } - break; - case Split::FuseVars: - if (dims.contains(split.outer) && dims.contains(split.inner)) { - dims.pop(split.outer); - dims.pop(split.inner); - dims.push(split.old_var); - new_splits.push_back(split); - } - break; - case Split::PurifyRVar: - case Split::RenameVar: - if (dims.contains(split.old_var)) { - dims.pop(split.old_var); - dims.push(split.outer); - new_splits.push_back(split); - } - break; - } - } - st.definition.schedule().splits().swap(new_splits); + definition.schedule().splits() = var_splits; } return intm; diff --git a/src/Func.h b/src/Func.h index d502d2618db0..5a34cd803500 100644 --- a/src/Func.h +++ b/src/Func.h @@ -90,7 +90,8 @@ class Stage { Stage &compute_with(LoopLevel loop_level, const std::map &align); - void rfactor_validate_args(const std::vector> &preserved, const Internal::AssociativeOp &prover_result); + std::pair, std::vector> + rfactor_validate_args(const std::vector> &preserved, const Internal::AssociativeOp &prover_result); public: Stage(Internal::Function f, Internal::Definition d, size_t stage_index)