Skip to content

Commit

Permalink
New scheduling directive to disallow partitioning.
Browse files Browse the repository at this point in the history
  • Loading branch information
mcourteaux committed Oct 6, 2023
1 parent 120e5fd commit aeb3d37
Show file tree
Hide file tree
Showing 33 changed files with 101 additions and 68 deletions.
2 changes: 1 addition & 1 deletion src/AsyncProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class NoOpCollapsingMutator : public IRMutator {
if (is_no_op(body)) {
return body;
} else {
return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body, op->allow_partitioning);
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3725,7 +3725,8 @@ void bounds_test() {
{Add::make(Call::make(in, input_site_1),
Call::make(in, input_site_2))},
output_site,
const_true()));
const_true()),
true);

map<string, Box> r;
r = boxes_required(loop);
Expand Down
4 changes: 2 additions & 2 deletions src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,7 @@ class BoundsInference : public IRMutator {
}
}

return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body, op->allow_partitioning);
}

Scope<> let_vars_in_scope;
Expand Down Expand Up @@ -1389,7 +1389,7 @@ Stmt bounds_inference(Stmt s,
s = Block::make(Evaluate::make(marker), s);

// Add a synthetic outermost loop to act as 'root'.
s = For::make("<outermost>", 0, 1, ForType::Serial, DeviceAPI::None, s);
s = For::make("<outermost>", 0, 1, ForType::Serial, DeviceAPI::None, s, false);

s = BoundsInference(funcs, fused_func_groups, fused_pairs_in_groups,
outputs, func_bounds, target)
Expand Down
2 changes: 1 addition & 1 deletion src/CanonicalizeGPUVars.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class CanonicalizeGPUVars : public IRMutator {
body.same_as(op->body)) {
return op;
} else {
return For::make(name, min, extent, op->for_type, op->device_api, body);
return For::make(name, min, extent, op->for_type, op->device_api, body, op->allow_partitioning);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/CodeGen_Hexagon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ class InjectHVXLocks : public IRMutator {
body = acquire_hvx_context(body, target);
body = substitute("uses_hvx", true, body);
Stmt new_for = For::make(op->name, op->min, op->extent, op->for_type,
op->device_api, body);
op->device_api, body, op->allow_partitioning);
Stmt prolog =
IfThenElse::make(uses_hvx_var, call_halide_qurt_hvx_unlock());
Stmt epilog =
Expand Down Expand Up @@ -407,7 +407,7 @@ class InjectHVXLocks : public IRMutator {
// halide_qurt_unlock
// }
s = For::make(op->name, op->min, op->extent, op->for_type,
op->device_api, body);
op->device_api, body, op->allow_partitioning);
}

uses_hvx = old_uses_hvx;
Expand Down
18 changes: 18 additions & 0 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1924,6 +1924,24 @@ Stage &Stage::gpu_tile(const VarOrRVar &x, const VarOrRVar &y, const VarOrRVar &
return gpu_tile(x, y, z, x, y, z, tx, ty, tz, x_size, y_size, z_size, tail, device_api);
}

// Scheduling directive: clear the allow_partitioning flag on every dim of
// this stage's schedule whose variable matches `var`.
//
// The schedule is marked as touched up front. If no dim matches `var`, a
// user error is raised that names the stage and the variable and appends the
// stage's argument list. Returns *this so further scheduling calls can be
// chained.
Stage &Stage::disallow_partitioning(const VarOrRVar &var) {
    definition.schedule().touched() = true;

    // Visit every dim rather than stopping at the first hit, so that all
    // dims accepted by var_name_match get the flag cleared.
    bool matched_any = false;
    for (Dim &d : definition.schedule().dims()) {
        if (!var_name_match(d.var, var.name())) {
            continue;
        }
        d.allow_partitioning = false;
        matched_any = true;
    }

    user_assert(matched_any)
        << "In schedule for " << name()
        << ", could not find var " << var.name()
        << " to mark as disallow partitioning.\n"
        << dump_argument_list();

    return *this;
}

Stage &Stage::hexagon(const VarOrRVar &x) {
set_dim_device_api(x, DeviceAPI::Hexagon);
return *this;
Expand Down
1 change: 1 addition & 0 deletions src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ class Stage {
const std::vector<Expr> &factors,
TailStrategy tail = TailStrategy::Auto);
Stage &reorder(const std::vector<VarOrRVar> &vars);
Stage &disallow_partitioning(const VarOrRVar &var);

template<typename... Args>
HALIDE_NO_USER_CODE_INLINE typename std::enable_if<Internal::all_are_convertible<VarOrRVar, Args...>::value, Stage &>::type
Expand Down
20 changes: 10 additions & 10 deletions src/FuseGPUThreadLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class NormalizeDimensionality : public IRMutator {
}
while (max_depth < block_size.threads_dimensions()) {
string name = thread_names[max_depth];
s = For::make("." + name, 0, 1, ForType::GPUThread, device_api, s);
s = For::make("." + name, 0, 1, ForType::GPUThread, device_api, s, false);
max_depth++;
}
return s;
Expand Down Expand Up @@ -398,7 +398,7 @@ class ExtractSharedAndHeapAllocations : public IRMutator {
Expr v = Variable::make(Int(32), loop_name);
host_side_preamble = substitute(op->name, v, host_side_preamble);
host_side_preamble = For::make(loop_name, new_min, new_extent,
ForType::Serial, DeviceAPI::None, host_side_preamble);
ForType::Serial, DeviceAPI::None, host_side_preamble, op->allow_partitioning);
if (old_preamble.defined()) {
host_side_preamble = Block::make(old_preamble, host_side_preamble);
}
Expand All @@ -407,7 +407,7 @@ class ExtractSharedAndHeapAllocations : public IRMutator {
}

return For::make(op->name, new_min, new_extent,
op->for_type, op->device_api, body);
op->for_type, op->device_api, body, op->allow_partitioning);
}

Stmt visit(const Block *op) override {
Expand Down Expand Up @@ -1093,7 +1093,7 @@ class ExtractRegisterAllocations : public IRMutator {
allocations.swap(old);
}

return For::make(op->name, mutate(op->min), mutate(op->extent), op->for_type, op->device_api, body);
return For::make(op->name, mutate(op->min), mutate(op->extent), op->for_type, op->device_api, body, op->allow_partitioning);
}
}

Expand Down Expand Up @@ -1254,7 +1254,7 @@ class InjectThreadBarriers : public IRMutator {
body = Block::make(body, make_barrier(0));
}
return For::make(op->name, op->min, op->extent,
op->for_type, op->device_api, body);
op->for_type, op->device_api, body, op->allow_partitioning);
} else {
return IRMutator::visit(op);
}
Expand Down Expand Up @@ -1405,14 +1405,14 @@ class FuseGPUThreadLoopsSingleKernel : public IRMutator {
string thread_id = "." + thread_names[0];
// Add back in any register-level allocations
body = register_allocs.rewrap(body, thread_id);
body = For::make(thread_id, 0, block_size_x, innermost_loop_type, op->device_api, body);
body = For::make(thread_id, 0, block_size_x, innermost_loop_type, op->device_api, body, op->allow_partitioning);

// Rewrap the whole thing in other loops over threads
for (int i = 1; i < block_size.threads_dimensions(); i++) {
thread_id = "." + thread_names[i];
body = register_allocs.rewrap(body, thread_id);
body = For::make("." + thread_names[i], 0, block_size.num_threads(i),
ForType::GPUThread, op->device_api, body);
ForType::GPUThread, op->device_api, body, op->allow_partitioning);
}
thread_id.clear();
body = register_allocs.rewrap(body, thread_id);
Expand All @@ -1428,7 +1428,7 @@ class FuseGPUThreadLoopsSingleKernel : public IRMutator {
if (body.same_as(op->body)) {
return op;
} else {
return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body, op->allow_partitioning);
}
} else {
return IRMutator::visit(op);
Expand Down Expand Up @@ -1497,7 +1497,7 @@ class ZeroGPULoopMins : public IRMutator {
internal_assert(op);
Expr adjusted = Variable::make(Int(32), op->name) + op->min;
Stmt body = substitute(op->name, adjusted, op->body);
stmt = For::make(op->name, 0, op->extent, op->for_type, op->device_api, body);
stmt = For::make(op->name, 0, op->extent, op->for_type, op->device_api, body, op->allow_partitioning);
}
return stmt;
}
Expand Down Expand Up @@ -1580,7 +1580,7 @@ class AddConditionToALoop : public IRMutator {
}

return For::make(op->name, op->min, op->extent, op->for_type, op->device_api,
IfThenElse::make(condition, op->body, Stmt()));
IfThenElse::make(condition, op->body, Stmt()), op->allow_partitioning);
}

public:
Expand Down
2 changes: 1 addition & 1 deletion src/HexagonOffload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ class InjectHexagonRpc : public IRMutator {
body = LetStmt::make(loop->name, loop->min, loop->body);
} else {
body = For::make(loop->name, loop->min, loop->extent, loop->for_type,
DeviceAPI::None, loop->body);
DeviceAPI::None, loop->body, loop->allow_partitioning);
}

// Build a closure for the device code.
Expand Down
3 changes: 2 additions & 1 deletion src/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ Stmt ProducerConsumer::make_consume(const std::string &name, Stmt body) {
return ProducerConsumer::make(name, false, std::move(body));
}

Stmt For::make(const std::string &name, Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Stmt body) {
Stmt For::make(const std::string &name, Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Stmt body, bool allow_partitioning) {
internal_assert(min.defined()) << "For of undefined\n";
internal_assert(extent.defined()) << "For of undefined\n";
internal_assert(min.type() == Int(32)) << "For with non-integer min\n";
Expand All @@ -354,6 +354,7 @@ Stmt For::make(const std::string &name, Expr min, Expr extent, ForType for_type,
node->min = std::move(min);
node->extent = std::move(extent);
node->for_type = for_type;
node->allow_partitioning = allow_partitioning;
node->device_api = device_api;
node->body = std::move(body);
return node;
Expand Down
3 changes: 2 additions & 1 deletion src/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -798,8 +798,9 @@ struct For : public StmtNode<For> {
ForType for_type;
DeviceAPI device_api;
Stmt body;
bool allow_partitioning;

static Stmt make(const std::string &name, Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Stmt body);
static Stmt make(const std::string &name, Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Stmt body, bool allow_partitioning);

bool is_unordered_parallel() const {
return Halide::Internal::is_unordered_parallel(for_type);
Expand Down
2 changes: 1 addition & 1 deletion src/IRMutator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ Stmt IRMutator::visit(const For *op) {
return op;
}
return For::make(op->name, std::move(min), std::move(extent),
op->for_type, op->device_api, std::move(body));
op->for_type, op->device_api, std::move(body), op->allow_partitioning);
}

Stmt IRMutator::visit(const Store *op) {
Expand Down
4 changes: 2 additions & 2 deletions src/IRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,12 @@ void IRPrinter::test() {
internal_assert(expr_source.str() == "((x + 3)*((y/2) + 17))");

Stmt store = Store::make("buf", (x * 17) / (x - 3), y - 1, Parameter(), const_true(), ModulusRemainder());
Stmt for_loop = For::make("x", -2, y + 2, ForType::Parallel, DeviceAPI::Host, store);
Stmt for_loop = For::make("x", -2, y + 2, ForType::Parallel, DeviceAPI::Host, store, true);
vector<Expr> args(1);
args[0] = x % 3;
Expr call = Call::make(i32, "buf", args, Call::Extern);
Stmt store2 = Store::make("out", call + 1, x, Parameter(), const_true(), ModulusRemainder(3, 5));
Stmt for_loop2 = For::make("x", 0, y, ForType::Vectorized, DeviceAPI::Host, store2);
Stmt for_loop2 = For::make("x", 0, y, ForType::Vectorized, DeviceAPI::Host, store2, true);

Stmt producer = ProducerConsumer::make_produce("buf", for_loop);
Stmt consumer = ProducerConsumer::make_consume("buf", for_loop2);
Expand Down
6 changes: 3 additions & 3 deletions src/LICM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ class LICM : public IRMutator {
internal_assert(loop);

new_stmt = For::make(loop->name, loop->min, loop->extent,
loop->for_type, loop->device_api, mutate(loop->body));
loop->for_type, loop->device_api, mutate(loop->body), loop->allow_partitioning);

// Wrap lets for the lifted invariants
for (size_t i = 0; i < exprs.size(); i++) {
Expand Down Expand Up @@ -564,15 +564,15 @@ class HoistIfStatements : public IRMutator {
is_pure(i->condition) &&
!expr_uses_var(i->condition, op->name)) {
Stmt s = For::make(op->name, op->min, op->extent,
op->for_type, op->device_api, i->then_case);
op->for_type, op->device_api, i->then_case, op->allow_partitioning);
return IfThenElse::make(i->condition, s);
}
}
if (body.same_as(op->body)) {
return op;
} else {
return For::make(op->name, op->min, op->extent,
op->for_type, op->device_api, body);
op->for_type, op->device_api, body, op->allow_partitioning);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/LoopCarry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ class LoopCarry : public IRMutator {
if (body.same_as(op->body)) {
stmt = op;
} else {
stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body, op->allow_partitioning);
}

// Inject the scratch buffer allocations.
Expand Down
3 changes: 2 additions & 1 deletion src/LowerParallelTasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ struct LowerParallelTasks : public IRMutator {
Variable::make(Int(32), loop_extent_name),
ForType::Serial,
DeviceAPI::None,
t.body);
t.body,
true);
} else {
internal_assert(is_const_one(t.extent));
}
Expand Down
4 changes: 2 additions & 2 deletions src/LowerWarpShuffles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ class LowerWarpShuffles : public IRMutator {
allocations.clear();

return For::make(op->name, op->min, warp_size,
op->for_type, op->device_api, body);
op->for_type, op->device_api, body, op->allow_partitioning);
} else {
return IRMutator::visit(op);
}
Expand Down Expand Up @@ -731,7 +731,7 @@ class HoistWarpShufflesFromSingleIfStmt : public IRMutator {
} else {
debug(3) << "Successfully hoisted shuffle out of for loop\n";
}
return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body, op->allow_partitioning);
}

Stmt visit(const Store *op) override {
Expand Down
16 changes: 10 additions & 6 deletions src/PartitionLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,10 @@ class PartitionLoops : public IRMutator {
bool in_gpu_loop = false;

Stmt visit(const For *op) override {
if (!op->allow_partitioning) {
return IRMutator::visit(op);
}

Stmt body = op->body;

ScopedValue<bool> old_in_gpu_loop(in_gpu_loop, in_gpu_loop ||
Expand Down Expand Up @@ -706,16 +710,16 @@ class PartitionLoops : public IRMutator {
// Bust simple serial for loops up into three.
if (op->for_type == ForType::Serial && !op->body.as<Acquire>()) {
stmt = For::make(op->name, min_steady, max_steady - min_steady,
op->for_type, op->device_api, simpler_body);
op->for_type, op->device_api, simpler_body, true);

if (make_prologue) {
prologue = For::make(op->name, op->min, min_steady - op->min,
op->for_type, op->device_api, prologue);
op->for_type, op->device_api, prologue, true);
stmt = Block::make(prologue, stmt);
}
if (make_epilogue) {
epilogue = For::make(op->name, max_steady, op->min + op->extent - max_steady,
op->for_type, op->device_api, epilogue);
op->for_type, op->device_api, epilogue, true);
stmt = Block::make(stmt, epilogue);
}
} else {
Expand Down Expand Up @@ -743,7 +747,7 @@ class PartitionLoops : public IRMutator {
stmt = IfThenElse::make(loop_var < min_steady, prologue, stmt);
}
}
stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, stmt);
stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, stmt, true);
}

if (make_epilogue) {
Expand Down Expand Up @@ -866,7 +870,7 @@ class RenormalizeGPULoops : public IRMutator {
internal_assert(!expr_uses_var(f->min, op->name) &&
!expr_uses_var(f->extent, op->name));
Stmt inner = LetStmt::make(op->name, op->value, f->body);
inner = For::make(f->name, f->min, f->extent, f->for_type, f->device_api, inner);
inner = For::make(f->name, f->min, f->extent, f->for_type, f->device_api, inner, f->allow_partitioning);
return mutate(inner);
} else if (a && in_gpu_loop && !in_thread_loop) {
internal_assert(a->extents.size() == 1);
Expand Down Expand Up @@ -944,7 +948,7 @@ class RenormalizeGPULoops : public IRMutator {
for_a->min.same_as(for_b->min) &&
for_a->extent.same_as(for_b->extent)) {
Stmt inner = IfThenElse::make(op->condition, for_a->body, for_b->body);
inner = For::make(for_a->name, for_a->min, for_a->extent, for_a->for_type, for_a->device_api, inner);
inner = For::make(for_a->name, for_a->min, for_a->extent, for_a->for_type, for_a->device_api, inner, for_a->allow_partitioning);
return mutate(inner);
} else {
internal_error << "Unexpected construct inside if statement: " << Stmt(op) << "\n";
Expand Down
6 changes: 3 additions & 3 deletions src/Prefetch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ class InjectPlaceholderPrefetch : public IRMutator {

Stmt stmt;
if (!body.same_as(op->body)) {
stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, std::move(body));
stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, std::move(body), op->allow_partitioning);
} else {
stmt = op;
}
Expand Down Expand Up @@ -303,7 +303,7 @@ class ReducePrefetchDimension : public IRMutator {
stmt = Evaluate::make(Call::make(prefetch->type, Call::prefetch, args, Call::Intrinsic));
for (size_t i = 0; i < index_names.size(); ++i) {
stmt = For::make(index_names[i], 0, prefetch->args[(i + max_dim) * 2 + 2],
ForType::Serial, DeviceAPI::None, stmt);
ForType::Serial, DeviceAPI::None, stmt, true);
}
debug(5) << "\nReduce prefetch to " << max_dim << " dim:\n"
<< "Before:\n"
Expand Down Expand Up @@ -374,7 +374,7 @@ class SplitPrefetch : public IRMutator {
stmt = Evaluate::make(Call::make(prefetch->type, Call::prefetch, args, Call::Intrinsic));
for (size_t i = 0; i < index_names.size(); ++i) {
stmt = For::make(index_names[i], 0, extents[i],
ForType::Serial, DeviceAPI::None, stmt);
ForType::Serial, DeviceAPI::None, stmt, true);
}
debug(5) << "\nSplit prefetch to max of " << max_byte_size << " bytes:\n"
<< "Before:\n"
Expand Down
Loading

0 comments on commit aeb3d37

Please sign in to comment.