Use std::optional to clean up some code and prevent use-after-free bugs (#8484)

* Make as_const_* return a std::optional instead of a pointer

To prevent use-after-free bugs

* Also use std::optional for get_md_string and get_md_bool
abadams authored Dec 3, 2024
1 parent e5e2510 commit 435fb23
Showing 49 changed files with 408 additions and 495 deletions.
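Editor's note: the core of the change is that as_const_int, as_const_uint, and as_const_float previously returned a pointer into the underlying IR node, which dangles if the Expr is mutated or reassigned before the pointer is read; returning std::optional copies the constant out. A minimal, self-contained sketch of the hazard, using stand-in types rather than Halide's actual classes:

#include <cstdint>
#include <iostream>
#include <memory>
#include <optional>

// Stand-in for Halide::Expr: holds a ref-counted node that may carry a constant.
struct Expr {
    std::shared_ptr<int64_t> node;
};

// Old style: returns a pointer into the node, which dangles once the Expr
// stops referencing that node.
const int64_t *as_const_int_old(const Expr &e) {
    return e.node.get();
}

// New style: copies the constant out, so the result outlives the Expr.
std::optional<int64_t> as_const_int_new(const Expr &e) {
    if (e.node) {
        return *e.node;
    }
    return std::nullopt;
}

int main() {
    Expr e{std::make_shared<int64_t>(42)};
    auto v = as_const_int_new(e);  // std::optional<int64_t> holding 42
    e.node.reset();                // an old-style pointer would now dangle
    if (v) {
        std::cout << *v << "\n";   // safe: prints 42
    }
    return 0;
}

Because std::optional is truthy and dereferenceable like a pointer, most call sites below are one-line substitutions of auto for const int64_t *.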
2 changes: 1 addition & 1 deletion .clang-tidy
@@ -70,7 +70,7 @@ Checks: >
 bugprone-terminating-continue,
 bugprone-throw-keyword-missing,
 bugprone-too-small-loop-variable,
-bugprone-unchecked-optional-access,
+-bugprone-unchecked-optional-access, # Too many false-positives
 bugprone-undefined-memory-manipulation,
 bugprone-undelegated-constructor,
 bugprone-unhandled-exception-at-new,
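Editor's note: the leading - in a clang-tidy Checks list disables a check; the commit turns off bugprone-unchecked-optional-access, citing false positives. A hypothetical pattern (not from this commit) of the kind that check is known to misdiagnose — the access is guarded, but through a separately computed flag the analysis does not track:

#include <optional>

// *maybe is only reached when has == true, yet the check may still
// flag the dereference as unchecked.
int guarded_use(const std::optional<int> &maybe) {
    const bool has = maybe.has_value();
    return has ? *maybe : 0;
}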
2 changes: 1 addition & 1 deletion apps/onnx/model.cpp
@@ -174,7 +174,7 @@ void prepare_random_input(
 const Tensor &t = pipeline.model->tensors.at(input_name);
 std::vector<int> input_shape;
 for (int i = 0; i < t.shape.size(); ++i) {
-const int64_t *dim = Halide::Internal::as_const_int(t.shape[i]);
+auto dim = Halide::Internal::as_const_int(t.shape[i]);
 if (!dim) {
 // The dimension isn't fixed: use the estimated typical value instead if
 // one was provided.
46 changes: 23 additions & 23 deletions apps/onnx/onnx_converter.cc
@@ -774,8 +774,8 @@ Halide::Func generate_padding_expr(
 Halide::Expr pad_before = pads[i];
 Halide::Expr pad_after = input_shape[i + skip_dims] + pad_before - 1;
 padding_extents.emplace_back(pad_before, pad_after);
-const int64_t *p1 = Halide::Internal::as_const_int(pad_before);
-const int64_t *p2 =
+auto p1 = Halide::Internal::as_const_int(pad_before);
+auto p2 =
     Halide::Internal::as_const_int(pads[rank - skip_dims - i]);
 if (!p1 || *p1 != 0 || !p2 || *p2 != 0) {
 maybe_has_padding = true;
@@ -1089,7 +1089,7 @@ Node convert_conv_node(
 bool supported_shape = true;
 for (int i = 2; i < rank; ++i) {
 const Halide::Expr w_shape_expr = Halide::Internal::simplify(W.shape[i]);
-const int64_t *dim = Halide::Internal::as_const_int(w_shape_expr);
+auto dim = Halide::Internal::as_const_int(w_shape_expr);
 if (!dim || *dim != 3) {
 supported_shape = false;
 break;
@@ -1912,7 +1912,7 @@ Node convert_split_node(
 axis += inputs[0].shape.size();
 }
 Halide::Expr axis_dim = inputs[0].shape.at(axis);
-const int64_t *axis_dim_size = Halide::Internal::as_const_int(axis_dim);
+auto axis_dim_size = Halide::Internal::as_const_int(axis_dim);

 if (user_splits.size() == 0) {
 if (axis_dim_size && (*axis_dim_size % num_outputs != 0)) {
@@ -2041,19 +2041,19 @@ Node convert_slice_node(
 Halide::Internal::simplify(starts_tensor.shape[0]);
 const Halide::Expr ends_shape_expr =
 Halide::Internal::simplify(ends_tensor.shape[0]);
-const int64_t *starts_shape_dim_0 =
+auto starts_shape_dim_0 =
     Halide::Internal::as_const_int(starts_shape_expr);
-const int64_t *ends_shape_dim_0 =
+auto ends_shape_dim_0 =
     Halide::Internal::as_const_int(ends_shape_expr);
-if (starts_shape_dim_0 == nullptr && ends_shape_dim_0 == nullptr) {
+if (!starts_shape_dim_0 && !ends_shape_dim_0) {
 throw std::invalid_argument(
 "Can't statisticaly infer slice dim size for slice node " +
 node.name());
 } else {
 result.requirements.push_back(starts_shape_expr == ends_shape_expr);
 }
 num_slice_dims =
-starts_shape_dim_0 != nullptr ? *starts_shape_dim_0 : *ends_shape_dim_0;
+starts_shape_dim_0 ? *starts_shape_dim_0 : *ends_shape_dim_0;
 if (num_slice_dims != *ends_shape_dim_0) {
 throw std::invalid_argument(
 "Starts and ends input tensor must have the same shape for "
@@ -2074,9 +2074,9 @@
 const Tensor &axes_tensor = inputs[3];
 const Halide::Expr axes_shape_expr =
 Halide::Internal::simplify(axes_tensor.shape[0]);
-const int64_t *axes_shape_dim_0 =
+auto axes_shape_dim_0 =
     Halide::Internal::as_const_int(axes_shape_expr);
-if (axes_shape_dim_0 != nullptr && *axes_shape_dim_0 != num_slice_dims) {
+if (axes_shape_dim_0 && *axes_shape_dim_0 != num_slice_dims) {
 throw std::invalid_argument(
 "Axes tensor must have the same shape as starts and ends for slice "
 "node " +
@@ -2099,9 +2099,9 @@
 const Tensor &steps_tensor = inputs[4];
 const Halide::Expr steps_shape_expr =
 Halide::Internal::simplify(steps_tensor.shape[0]);
-const int64_t *steps_shape_dim_0 =
+auto steps_shape_dim_0 =
     Halide::Internal::as_const_int(steps_shape_expr);
-if (steps_shape_dim_0 != nullptr && *steps_shape_dim_0 != num_slice_dims) {
+if (steps_shape_dim_0 && *steps_shape_dim_0 != num_slice_dims) {
 throw std::invalid_argument(
 "Steps tensor must have the same shape as starts and ends for slice "
 "node " +
@@ -2414,7 +2414,7 @@ Node convert_squeeze_node(
 if (implicit) {
 for (int i = 0; i < rank; ++i) {
 const Halide::Expr dim_expr = Halide::Internal::simplify(input.shape[i]);
-const int64_t *dim = Halide::Internal::as_const_int(dim_expr);
+auto dim = Halide::Internal::as_const_int(dim_expr);
 if (!dim) {
 throw std::invalid_argument(
 "Unknown dimension for input dim " + std::to_string(i) +
@@ -2471,7 +2471,7 @@ Node convert_constant_of_shape(
 Tensor &out = result.outputs[0];
 const Halide::Expr shape_expr =
 Halide::Internal::simplify(inputs[0].shape[0]);
-const int64_t *shape_dim_0 = Halide::Internal::as_const_int(shape_expr);
+auto shape_dim_0 = Halide::Internal::as_const_int(shape_expr);
 if (!shape_dim_0) {
 throw std::invalid_argument(
 "Can't infer rank statically for ConstantOfShape node " + node.name());
@@ -2744,7 +2744,7 @@ Node convert_expand_node(
 const int in_rank = input.shape.size();
 const Halide::Expr shape_expr =
 Halide::Internal::simplify(expand_shape.shape[0]);
-const int64_t *shape_dim_0 = Halide::Internal::as_const_int(shape_expr);
+auto shape_dim_0 = Halide::Internal::as_const_int(shape_expr);
 if (!shape_dim_0) {
 throw std::invalid_argument(
 "Can't infer rank statically for expand node " + node.name());
@@ -3098,7 +3098,7 @@ Node convert_reshape_node(
 }
 const Halide::Expr shape_expr =
 Halide::Internal::simplify(new_shape.shape[0]);
-const int64_t *num_dims = Halide::Internal::as_const_int(shape_expr);
+auto num_dims = Halide::Internal::as_const_int(shape_expr);
 if (!num_dims) {
 throw std::domain_error(
 "Couldn't statically infer the rank of the output of " + node.name());
@@ -3285,7 +3285,7 @@ Node convert_gru_node(
 }

 const Halide::Expr dim_expr = Halide::Internal::simplify(inputs[0].shape[0]);
-const int64_t *dim = Halide::Internal::as_const_int(dim_expr);
+auto dim = Halide::Internal::as_const_int(dim_expr);
 if (!dim) {
 throw std::domain_error("Unknown number of timesteps");
 }
@@ -3683,7 +3683,7 @@ Node convert_rnn_node(
 }

 const Halide::Expr dim_expr = Halide::Internal::simplify(inputs[0].shape[0]);
-const int64_t *dim = Halide::Internal::as_const_int(dim_expr);
+auto dim = Halide::Internal::as_const_int(dim_expr);
 if (!dim) {
 throw std::domain_error("Unknown number of timesteps");
 }
@@ -3925,7 +3925,7 @@ Node convert_lstm_node(
 throw std::domain_error("Invalid rank");
 }
 const Halide::Expr dim_expr = Halide::Internal::simplify(inputs[0].shape[0]);
-const int64_t *dim = Halide::Internal::as_const_int(dim_expr);
+auto dim = Halide::Internal::as_const_int(dim_expr);
 if (!dim) {
 throw std::domain_error("Unknown number of timesteps");
 }
@@ -4722,7 +4722,7 @@ Model convert_model(
 throw std::domain_error("Invalid dimensions for output " + output.name());
 }
 for (int i = 0; i < args.size(); ++i) {
-const int64_t *dim_value = Halide::Internal::as_const_int(dims[i]);
+auto dim_value = Halide::Internal::as_const_int(dims[i]);
 if (dim_value) {
 int dim = static_cast<int>(*dim_value);
 f.set_estimate(args[i], 0, dim);
@@ -4777,7 +4777,7 @@ static int64_t infer_dim_from_inputs(
 replacement.min, replacement.extent, result);
 }
 result = Halide::Internal::simplify(result);
-const int64_t *actual_dim = Halide::Internal::as_const_int(result);
+auto actual_dim = Halide::Internal::as_const_int(result);
 if (!actual_dim) {
 throw std::invalid_argument(
 "Couldn't statically infer one of the dimensions of output " + name);
@@ -4812,7 +4812,7 @@ void compute_output_shapes(
 std::vector<int> &output_shape = (*output_shapes)[name];
 const int rank = t.shape.size();
 for (int i = 0; i < rank; ++i) {
-const int64_t *dim = Halide::Internal::as_const_int(t.shape[i]);
+auto dim = Halide::Internal::as_const_int(t.shape[i]);
 if (!dim) {
 output_shape.push_back(
 infer_dim_from_inputs(t.shape[i], replacements, name));
@@ -4833,7 +4833,7 @@ void extract_expected_input_shapes(
 const Tensor &t = model.tensors.at(input_name);
 std::vector<int> input_shape;
 for (int i = 0; i < t.shape.size(); ++i) {
-const int64_t *dim = Halide::Internal::as_const_int(t.shape[i]);
+auto dim = Halide::Internal::as_const_int(t.shape[i]);
 if (!dim) {
 // The dimension isn't fixed: use the estimated typical value instead if
 // one was provided.
2 changes: 1 addition & 1 deletion src/AlignLoads.cpp
@@ -71,7 +71,7 @@ class AlignLoads : public IRMutator {

 Expr index = mutate(op->index);
 const Ramp *ramp = index.as<Ramp>();
-const int64_t *const_stride = ramp ? as_const_int(ramp->stride) : nullptr;
+auto const_stride = ramp ? as_const_int(ramp->stride) : std::nullopt;
 if (!ramp || !const_stride) {
 // We can't handle indirect loads, or loads with
 // non-constant strides.
2 changes: 1 addition & 1 deletion src/BoundSmallAllocations.cpp
@@ -125,7 +125,7 @@ class BoundSmallAllocations : public IRMutator {
 << "Try storing on the heap or stack instead.";
 }

-const int64_t *size_ptr = bound.defined() ? as_const_int(bound) : nullptr;
+auto size_ptr = bound.defined() ? as_const_int(bound) : std::nullopt;
 int64_t size = size_ptr ? *size_ptr : 0;

 if (size_ptr && size == 0 && !op->new_expr.defined()) {
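Editor's note: one subtlety in the ternaries here and in AlignLoads.cpp above: both arms of ?: must share a common type, so the empty arm changes from nullptr to std::nullopt, which converts implicitly to std::optional<int64_t>. A minimal sketch, with a hypothetical constant_bound helper standing in for the real call:

#include <cstdint>
#include <optional>

// Hypothetical helper mirroring the pattern in BoundSmallAllocations.cpp:
// an empty optional when no constant bound is available. The std::nullopt
// arm converts to std::optional<int64_t>, giving the ternary a common type.
std::optional<int64_t> constant_bound(bool defined, int64_t value) {
    return defined ? std::optional<int64_t>(value) : std::nullopt;
}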
12 changes: 6 additions & 6 deletions src/Bounds.cpp
@@ -355,18 +355,18 @@ class Bounds : public IRVisitor {
 // each other; however, if the bounds can be simplified to
 // constants, they might fit regardless of types.
 a = simplify(a);
-const auto *umin = as_const_uint(a.min);
-const auto *umax = as_const_uint(a.max);
+auto umin = as_const_uint(a.min);
+auto umax = as_const_uint(a.max);
 if (umin && umax && to.can_represent(*umin) && to.can_represent(*umax)) {
 could_overflow = false;
 } else {
-const auto *imin = as_const_int(a.min);
-const auto *imax = as_const_int(a.max);
+auto imin = as_const_int(a.min);
+auto imax = as_const_int(a.max);
 if (imin && imax && to.can_represent(*imin) && to.can_represent(*imax)) {
 could_overflow = false;
 } else {
-const auto *fmin = as_const_float(a.min);
-const auto *fmax = as_const_float(a.max);
+auto fmin = as_const_float(a.min);
+auto fmax = as_const_float(a.max);
 if (fmin && fmax && to.can_represent(*fmin) && to.can_represent(*fmax)) {
 could_overflow = false;
 }
4 changes: 2 additions & 2 deletions src/CodeGen_ARM.cpp
@@ -1178,7 +1178,7 @@ void CodeGen_ARM::visit(const Cast *op) {
 if (expr_match(pattern.pattern, op, matches)) {
 if (pattern.intrin.find("shift_right_narrow") != string::npos) {
 // The shift_right_narrow patterns need the shift to be constant in [1, output_bits].
-const uint64_t *const_b = as_const_uint(matches[1]);
+auto const_b = as_const_uint(matches[1]);
 if (!const_b || *const_b == 0 || (int)*const_b > op->type.bits()) {
 continue;
 }
@@ -2015,7 +2015,7 @@ void CodeGen_ARM::visit(const Call *op) {
 if (expr_match(pattern.pattern, op, matches)) {
 if (pattern.intrin.find("shift_right_narrow") != string::npos) {
 // The shift_right_narrow patterns need the shift to be constant in [1, output_bits].
-const uint64_t *const_b = as_const_uint(matches[1]);
+auto const_b = as_const_uint(matches[1]);
 if (!const_b || *const_b == 0 || (int)*const_b > op->type.bits()) {
 continue;
 }
16 changes: 7 additions & 9 deletions src/CodeGen_C.cpp
@@ -1325,9 +1325,8 @@ void CodeGen_C::visit(const Mul *op) {
 }

 void CodeGen_C::visit(const Div *op) {
-int bits;
-if (is_const_power_of_two_integer(op->b, &bits)) {
-visit_binop(op->type, op->a, make_const(op->a.type(), bits), ">>");
+if (auto bits = is_const_power_of_two_integer(op->b)) {
+visit_binop(op->type, op->a, make_const(op->a.type(), *bits), ">>");
 } else if (op->type.is_int()) {
 print_expr(lower_euclidean_div(op->a, op->b));
 } else {
@@ -1336,9 +1335,8 @@
 }

 void CodeGen_C::visit(const Mod *op) {
-int bits;
-if (is_const_power_of_two_integer(op->b, &bits)) {
-visit_binop(op->type, op->a, make_const(op->a.type(), (1 << bits) - 1), "&");
+if (auto bits = is_const_power_of_two_integer(op->b)) {
+visit_binop(op->type, op->a, make_const(op->a.type(), ((uint64_t)1 << *bits) - 1), "&");
 } else if (op->type.is_int()) {
 print_expr(lower_euclidean_mod(op->a, op->b));
 } else if (op->type.is_float()) {
@@ -1613,7 +1611,7 @@ void CodeGen_C::visit(const Call *op) {
 } else if (op->is_intrinsic(Call::alloca)) {
 internal_assert(op->args.size() == 1);
 internal_assert(op->type.is_handle());
-const int64_t *sz = as_const_int(op->args[0]);
+auto sz = as_const_int(op->args[0]);
 if (op->type == type_of<struct halide_buffer_t *>() &&
 Call::as_intrinsic(op->args[0], {Call::size_of_halide_buffer_t})) {
 stream << get_indent();
@@ -1752,8 +1750,8 @@ void CodeGen_C::visit(const Call *op) {
 internal_assert(op->args.size() == 3);
 std::string struct_instance = print_expr(op->args[0]);
 std::string struct_prototype = print_expr(op->args[1]);
-const int64_t *index = as_const_int(op->args[2]);
-internal_assert(index != nullptr);
+auto index = as_const_int(op->args[2]);
+internal_assert(index);
 rhs << "((decltype(" << struct_prototype << "))"
 << struct_instance << ")->f_" << *index;
 } else if (op->is_intrinsic(Call::get_user_context)) {
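Editor's note: is_const_power_of_two_integer gets the same treatment in a different shape — the old signature returned bool and wrote the exponent through an int * out-parameter; the new one returns the exponent directly. The Mod rewrite above also switches the mask computation to (uint64_t)1 << *bits, so large exponents cannot overflow a 32-bit int shift. A self-contained stand-in (the real function takes a Halide Expr, not an integer):

#include <cstdint>
#include <iostream>
#include <optional>

// Stand-in for the new-style helper: the exponent if v is a power of
// two, std::nullopt otherwise.
std::optional<int> const_power_of_two(uint64_t v) {
    if (v == 0 || (v & (v - 1)) != 0) {
        return std::nullopt;
    }
    int bits = 0;
    while ((v >>= 1) != 0) {
        bits++;
    }
    return bits;
}

int main() {
    // The if-with-initializer pattern the new call sites use:
    if (auto bits = const_power_of_two(8)) {
        std::cout << "x / 8 == x >> " << *bits << "\n";  // prints "x / 8 == x >> 3"
    }
    return 0;
}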
12 changes: 5 additions & 7 deletions src/CodeGen_D3D12Compute_Dev.cpp
@@ -268,10 +268,9 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Min *op) {
 }

 void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Div *op) {
-int bits;
-if (is_const_power_of_two_integer(op->b, &bits)) {
+if (auto bits = is_const_power_of_two_integer(op->b)) {
 ostringstream oss;
-oss << print_expr(op->a) << " >> " << bits;
+oss << print_expr(op->a) << " >> " << *bits;
 print_assignment(op->type, oss.str());
 } else if (op->type.is_int()) {
 print_expr(lower_euclidean_div(op->a, op->b));
@@ -281,10 +280,9 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Div *op) {
 }

 void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Mod *op) {
-int bits;
-if (is_const_power_of_two_integer(op->b, &bits)) {
+if (auto bits = is_const_power_of_two_integer(op->b)) {
 ostringstream oss;
-oss << print_expr(op->a) << " & " << ((1 << bits) - 1);
+oss << print_expr(op->a) << " & " << (((uint64_t)1 << *bits) - 1);
 print_assignment(op->type, oss.str());
 } else if (op->type.is_int()) {
 print_expr(lower_euclidean_mod(op->a, op->b));
@@ -349,7 +347,7 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Call *op) {
 if (op->is_intrinsic(Call::gpu_thread_barrier)) {
 internal_assert(op->args.size() == 1) << "gpu_thread_barrier() intrinsic must specify memory fence type.\n";

-const auto *fence_type_ptr = as_const_int(op->args[0]);
+auto fence_type_ptr = as_const_int(op->args[0]);
 internal_assert(fence_type_ptr) << "gpu_thread_barrier() parameter is not a constant integer.\n";
 auto fence_type = *fence_type_ptr;

4 changes: 2 additions & 2 deletions src/CodeGen_Hexagon.cpp
@@ -1932,8 +1932,8 @@ void CodeGen_Hexagon::visit(const Call *op) {
 return;
 } else if (op->is_intrinsic(Call::dynamic_shuffle)) {
 internal_assert(op->args.size() == 4);
-const int64_t *min_index = as_const_int(op->args[2]);
-const int64_t *max_index = as_const_int(op->args[3]);
+auto min_index = as_const_int(op->args[2]);
+auto max_index = as_const_int(op->args[3]);
 internal_assert(min_index && max_index);
 Value *lut = codegen(op->args[0]);
 Value *idx = codegen(op->args[1]);
(Diffs for the remaining changed files are not shown.)