Skip to content

Commit

Permalink
Handle unreachable code in bounds inference (#7866)
Browse files Browse the repository at this point in the history
* Handle unreachable code in bounds inference

* Avoid ambiguous constructor

* IRVisitor -> IRGraphVisitor

* Add success print
  • Loading branch information
abadams authored Sep 27, 2023
1 parent 9f96b25 commit 76ac233
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 13 deletions.
29 changes: 26 additions & 3 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2190,7 +2190,7 @@ class BoxesTouched : public IRGraphVisitor {
// Map variable name to all other vars which values depend on that variable.
map<VarInstance, set<VarInstance>> children;

bool in_producer{false};
bool in_producer{false}, in_unreachable{false};
map<std::string, Expr> buffer_lets;

using IRGraphVisitor::visit;
Expand Down Expand Up @@ -2810,16 +2810,23 @@ class BoxesTouched : public IRGraphVisitor {

// Fork the boxes touched and go down each path
map<string, Box> then_boxes, else_boxes;
bool then_unreachable = false, else_unreachable = false;
then_boxes.swap(boxes);
std::swap(then_unreachable, in_unreachable);
op->then_case.accept(this);
then_boxes.swap(boxes);
std::swap(then_unreachable, in_unreachable);

if (op->else_case.defined()) {
else_boxes.swap(boxes);
std::swap(else_unreachable, in_unreachable);
op->else_case.accept(this);
else_boxes.swap(boxes);
std::swap(else_unreachable, in_unreachable);
}

in_unreachable = then_unreachable && else_unreachable;

// Make sure all the then boxes have an entry on the else
// side so that the merge doesn't skip them.
for (pair<const string, Box> &i : then_boxes) {
Expand All @@ -2832,13 +2839,22 @@ class BoxesTouched : public IRGraphVisitor {
Box &then_box = then_boxes[i.first];
Box &orig_box = boxes[i.first];

if (then_box.maybe_unused()) {
if (else_unreachable) {
// Don't incorporate the condition into
// then.used. boxes_touched assumes that asserts pass, so if
// the else case contains an assert(false), conservatively
// assume the then case will unconditionally run. This
// provides more useful bounds for bounds queries on
// pipelines that use specialize_fail.
} else if (then_box.maybe_unused()) {
then_box.used = then_box.used && op->condition;
} else {
then_box.used = op->condition;
}

if (else_box.maybe_unused()) {
if (then_unreachable) {
// Conservatively assume the else case will run.
} else if (else_box.maybe_unused()) {
else_box.used = else_box.used && !op->condition;
} else {
else_box.used = !op->condition;
Expand All @@ -2850,6 +2866,13 @@ class BoxesTouched : public IRGraphVisitor {
}
}

void visit(const AssertStmt *op) override {
if (is_const_zero(op->condition)) {
in_unreachable = true;
}
IRGraphVisitor::visit(op);
}

void visit(const For *op) override {
TRACK_BOXES_TOUCHED;
TRACK_BOXES_TOUCHED_INFO("var:", op->name);
Expand Down
22 changes: 12 additions & 10 deletions src/Bounds.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,12 @@ Box box_intersection(const Box &a, const Box &b);
/** Test if box a provably contains box b */
bool box_contains(const Box &a, const Box &b);

/** Compute rectangular domains large enough to cover all the 'Call's
* to each function that occurs within a given statement or
* expression. This is useful for figuring out what regions of things
* to evaluate. */
/** Compute rectangular domains large enough to cover all the 'Call's to each
* function that occurs within a given statement or expression. This is useful
* for figuring out what regions of things to evaluate. Respects control flow
* (e.g. encodes if statement conditions), but assumes all encountered asserts
* pass. If it encounters an assert(false) in one if branch, assumes the
* opposite if branch runs unconditionally. */
// @{
std::map<std::string, Box> boxes_required(const Expr &e,
const Scope<Interval> &scope = Scope<Interval>::empty_scope(),
Expand All @@ -118,9 +120,9 @@ std::map<std::string, Box> boxes_required(Stmt s,
const FuncValueBounds &func_bounds = empty_func_value_bounds());
// @}

/** Compute rectangular domains large enough to cover all the
* 'Provides's to each function that occurs within a given statement
* or expression. */
/** Compute rectangular domains large enough to cover all the 'Provides's to
* each function that occurs within a given statement or expression. Handles
* asserts in the same way as boxes_required. */
// @{
std::map<std::string, Box> boxes_provided(const Expr &e,
const Scope<Interval> &scope = Scope<Interval>::empty_scope(),
Expand All @@ -130,9 +132,9 @@ std::map<std::string, Box> boxes_provided(Stmt s,
const FuncValueBounds &func_bounds = empty_func_value_bounds());
// @}

/** Compute rectangular domains large enough to cover all the 'Call's
* and 'Provides's to each function that occurs within a given
* statement or expression. */
/** Compute rectangular domains large enough to cover all the 'Call's and
* 'Provides's to each function that occurs within a given statement or
* expression. Handles asserts in the same way as boxes_required. */
// @{
std::map<std::string, Box> boxes_touched(const Expr &e,
const Scope<Interval> &scope = Scope<Interval>::empty_scope(),
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ tests(GROUPS correctness
bounds_of_multiply.cpp
bounds_of_split.cpp
bounds_query.cpp
bounds_query_respects_specialize_fail.cpp
buffer_t.cpp
c_function.cpp
callable.cpp
Expand Down
39 changes: 39 additions & 0 deletions test/correctness/bounds_query_respects_specialize_fail.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

#include "Halide.h"

using namespace Halide;
using namespace Halide::Internal;
using namespace Halide::ConciseCasts;

int main(int argc, char **argv) {

ImageParam im(UInt(8), 1);
Func f;
Var x;

f(x) = im(x);

im.dim(0).set_stride(Expr());
f.specialize(im.dim(0).stride() == 1);
f.specialize(im.dim(0).stride() == 2);
f.specialize_fail("unreachable");

Callable c = f.compile_to_callable({im});

Halide::Runtime::Buffer<uint8_t> in_buf(nullptr, {halide_dimension_t{0, 0, 0}});
Halide::Runtime::Buffer<uint8_t> out_buf(32);

c(in_buf, out_buf);

if (in_buf.dim(0).stride() != 1 ||
in_buf.dim(0).extent() != 32) {
printf("Unexpected bounds query result. stride = %d, extent = %d\n",
in_buf.dim(0).stride(),
in_buf.dim(0).extent());
return 1;
}

printf("Success!\n");

return 0;
}

0 comments on commit 76ac233

Please sign in to comment.