Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle unreachable code in bounds inference #7866

Merged
merged 4 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2190,7 +2190,7 @@ class BoxesTouched : public IRGraphVisitor {
// Map variable name to all other vars which values depend on that variable.
map<VarInstance, set<VarInstance>> children;

bool in_producer{false};
bool in_producer{false}, in_unreachable{false};
map<std::string, Expr> buffer_lets;

using IRGraphVisitor::visit;
Expand Down Expand Up @@ -2810,16 +2810,23 @@ class BoxesTouched : public IRGraphVisitor {

// Fork the boxes touched and go down each path
map<string, Box> then_boxes, else_boxes;
bool then_unreachable = false, else_unreachable = false;
then_boxes.swap(boxes);
std::swap(then_unreachable, in_unreachable);
op->then_case.accept(this);
then_boxes.swap(boxes);
std::swap(then_unreachable, in_unreachable);

if (op->else_case.defined()) {
else_boxes.swap(boxes);
std::swap(else_unreachable, in_unreachable);
op->else_case.accept(this);
else_boxes.swap(boxes);
std::swap(else_unreachable, in_unreachable);
}

in_unreachable = then_unreachable && else_unreachable;

// Make sure all the then boxes have an entry on the else
// side so that the merge doesn't skip them.
for (pair<const string, Box> &i : then_boxes) {
Expand All @@ -2832,13 +2839,22 @@ class BoxesTouched : public IRGraphVisitor {
Box &then_box = then_boxes[i.first];
Box &orig_box = boxes[i.first];

if (then_box.maybe_unused()) {
if (else_unreachable) {
// Don't incorporate the condition into
// then.used. boxes_touched assumes that asserts pass, so if
// the else case contains an assert(false), conservatively
// assume the then case will unconditionally run. This
// provides more useful bounds for bounds queries on
// pipelines that use specialize_fail.
} else if (then_box.maybe_unused()) {
then_box.used = then_box.used && op->condition;
} else {
then_box.used = op->condition;
}

if (else_box.maybe_unused()) {
if (then_unreachable) {
// Conservatively assume the else case will run.
} else if (else_box.maybe_unused()) {
else_box.used = else_box.used && !op->condition;
} else {
else_box.used = !op->condition;
Expand All @@ -2850,6 +2866,13 @@ class BoxesTouched : public IRGraphVisitor {
}
}

void visit(const AssertStmt *op) override {
if (is_const_zero(op->condition)) {
in_unreachable = true;
}
IRVisitor::visit(op);
}

void visit(const For *op) override {
TRACK_BOXES_TOUCHED;
TRACK_BOXES_TOUCHED_INFO("var:", op->name);
Expand Down
22 changes: 12 additions & 10 deletions src/Bounds.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,12 @@ Box box_intersection(const Box &a, const Box &b);
/** Test if box a provably contains box b */
bool box_contains(const Box &a, const Box &b);

/** Compute rectangular domains large enough to cover all the 'Call's
* to each function that occurs within a given statement or
* expression. This is useful for figuring out what regions of things
* to evaluate. */
/** Compute rectangular domains large enough to cover all the 'Call's to each
* function that occurs within a given statement or expression. This is useful
* for figuring out what regions of things to evaluate. Respects control flow
* (e.g. encodes if statement conditions), but assumes all encountered asserts
* pass. If it encounters an assert(false) in one if branch, assumes the
* opposite if branch runs unconditionally. */
// @{
std::map<std::string, Box> boxes_required(const Expr &e,
const Scope<Interval> &scope = Scope<Interval>::empty_scope(),
Expand All @@ -118,9 +120,9 @@ std::map<std::string, Box> boxes_required(Stmt s,
const FuncValueBounds &func_bounds = empty_func_value_bounds());
// @}

/** Compute rectangular domains large enough to cover all the
* 'Provides's to each function that occurs within a given statement
* or expression. */
/** Compute rectangular domains large enough to cover all the 'Provides's to
* each function that occurs within a given statement or expression. Handles
* asserts in the same way as boxes_required. */
// @{
std::map<std::string, Box> boxes_provided(const Expr &e,
const Scope<Interval> &scope = Scope<Interval>::empty_scope(),
Expand All @@ -130,9 +132,9 @@ std::map<std::string, Box> boxes_provided(Stmt s,
const FuncValueBounds &func_bounds = empty_func_value_bounds());
// @}

/** Compute rectangular domains large enough to cover all the 'Call's
* and 'Provides's to each function that occurs within a given
* statement or expression. */
/** Compute rectangular domains large enough to cover all the 'Call's and
* 'Provides's to each function that occurs within a given statement or
* expression. Handles asserts in the same way as boxes_required. */
// @{
std::map<std::string, Box> boxes_touched(const Expr &e,
const Scope<Interval> &scope = Scope<Interval>::empty_scope(),
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ tests(GROUPS correctness
bounds_of_multiply.cpp
bounds_of_split.cpp
bounds_query.cpp
bounds_query_respects_specialize_fail.cpp
buffer_t.cpp
c_function.cpp
callable.cpp
Expand Down
37 changes: 37 additions & 0 deletions test/correctness/bounds_query_respects_specialize_fail.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

#include "Halide.h"

using namespace Halide;
using namespace Halide::Internal;
using namespace Halide::ConciseCasts;

int main(int argc, char **argv) {

ImageParam im(UInt(8), 1);
Func f;
Var x;

f(x) = im(x);

im.dim(0).set_stride(Expr());
f.specialize(im.dim(0).stride() == 1);
f.specialize(im.dim(0).stride() == 2);
f.specialize_fail("unreachable");

Callable c = f.compile_to_callable({im});

Halide::Runtime::Buffer<uint8_t> in_buf(nullptr, {halide_dimension_t{0, 0, 0}});
Halide::Runtime::Buffer<uint8_t> out_buf({32});

c(in_buf, out_buf);

if (in_buf.dim(0).stride() != 1 ||
in_buf.dim(0).extent() != 32) {
printf("Unexpected bounds query result. stride = %d, extent = %d\n",
in_buf.dim(0).stride(),
in_buf.dim(0).extent());
return 1;
}

return 0;
}