Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move some large stack frames off recursive paths. #8507

Merged
merged 3 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/ConstantBounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ ConstantInterval bounds_helper(const Expr &e,
ScopedBinding bind(scope, op->name, recurse(op->value));
return recurse(op->body);
} else if (const Call *op = e.as<Call>()) {
ConstantInterval result;
if (op->is_intrinsic(Call::abs)) {
return abs(recurse(op->args[0]));
} else if (op->is_intrinsic(Call::absd)) {
Expand Down
20 changes: 16 additions & 4 deletions src/FindIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,8 +516,12 @@ class FindIntrinsics : public IRMutator {
return IRMutator::visit(op);
}

Expr value = mutate(op->value);
return visit_cast(op, mutate(op->value));
}

// Isolated in its own function to keep the (large) stack frame off the
// recursive path.
HALIDE_NEVER_INLINE Expr visit_cast(const Cast *op, Expr &&value) {
// This mutator can generate redundant casts. We can't use the simplifier because it
// undoes some of the intrinsic lowering here, and it causes some problems due to
// factoring (instead of distributing) constants.
Expand Down Expand Up @@ -550,6 +554,7 @@ class FindIntrinsics : public IRMutator {
auto is_x_same_uint = op->type.is_uint() && is_uint(x, bits);
auto is_x_same_int_or_uint = is_x_same_int || is_x_same_uint;
auto x_y_same_sign = (is_int(x) && is_int(y)) || (is_uint(x) && is_uint(y));

if (
// Saturating patterns
rewrite(max(min(widening_add(x, y), upper), lower),
Expand All @@ -566,11 +571,11 @@ class FindIntrinsics : public IRMutator {

rewrite(min(widening_add(x, y), upper),
saturating_add(x, y),
op->type.is_uint() && is_x_same_uint) ||
is_x_same_uint) ||

rewrite(max(widening_sub(x, y), lower),
saturating_sub(x, y),
op->type.is_uint() && is_x_same_uint) ||
is_x_same_uint) ||

// Saturating narrow patterns.
rewrite(max(min(x, upper), lower),
Expand Down Expand Up @@ -721,10 +726,17 @@ class FindIntrinsics : public IRMutator {
op = mutated.as<Call>();
if (!op) {
return mutated;
} else {
return visit_call(op);
}
}

// Isolated in its own function to keep the (large) stack frame off the
// recursive path. The Call node has already been mutated by the base class
// visitor.
HALIDE_NEVER_INLINE Expr visit_call(const Call *op) {
auto rewrite = IRMatcher::rewriter(op, op->type);
if (rewrite(intrin(Call::abs, widening_sub(x, y)), cast(op->type, intrin(Call::absd, x, y))) ||
if (rewrite(abs(widening_sub(x, y)), cast(op->type, absd(x, y))) ||
false) {
return rewrite.result;
}
Expand Down
512 changes: 264 additions & 248 deletions src/HexagonOptimize.cpp

Large diffs are not rendered by default.

146 changes: 89 additions & 57 deletions src/IRMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -1325,15 +1325,30 @@ constexpr int const_min(int a, int b) {
return a < b ? a : b;
}

template<typename... Args>
template<Call::IntrinsicOp intrin>
struct OptionalIntrinType {
bool check(const Type &) const {
return true;
}
};

template<>
struct OptionalIntrinType<Call::saturating_cast> {
halide_type_t type;
bool check(const Type &t) const {
return t == Type(type);
}
};

template<Call::IntrinsicOp intrin, typename... Args>
struct Intrin {
struct pattern_tag {};
Call::IntrinsicOp intrin;
std::tuple<Args...> args;
// The type of the output of the intrinsic node.
// Only necessary in cases where it can't be inferred
// from the input types (e.g. saturating_cast).
Type optional_type_hint;

OptionalIntrinType<intrin> optional_type_hint;

static constexpr uint32_t binds = bitwise_or_reduce((bindings<Args>::mask)...);

Expand Down Expand Up @@ -1362,7 +1377,7 @@ struct Intrin {
}
const Call &c = (const Call &)e;
return (c.is_intrinsic(intrin) &&
((optional_type_hint == Type()) || optional_type_hint == e.type) &&
optional_type_hint.check(e.type) &&
match_args<0, bound>(0, c, state));
}

Expand Down Expand Up @@ -1394,8 +1409,8 @@ struct Intrin {
return likely_if_innermost(std::move(arg0));
} else if (intrin == Call::abs) {
return abs(std::move(arg0));
} else if (intrin == Call::saturating_cast) {
return saturating_cast(optional_type_hint, std::move(arg0));
} else if constexpr (intrin == Call::saturating_cast) {
return saturating_cast(optional_type_hint.type, std::move(arg0));
}

Expr arg1 = std::get<const_min(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
Expand Down Expand Up @@ -1489,98 +1504,113 @@ struct Intrin {
}

HALIDE_ALWAYS_INLINE
Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
: intrin(intrin), args(args...) {
Intrin(Args... args) noexcept
: args(args...) {
}
};

template<typename... Args>
std::ostream &operator<<(std::ostream &s, const Intrin<Args...> &op) {
s << op.intrin << "(";
template<Call::IntrinsicOp intrin, typename... Args>
std::ostream &operator<<(std::ostream &s, const Intrin<intrin, Args...> &op) {
s << intrin << "(";
op.print_args(s);
s << ")";
return s;
}

template<typename... Args>
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin<decltype(pattern_arg(args))...> {
return {intrinsic_op, pattern_arg(args)...};
}

template<typename A, typename B>
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::widen_right_add, pattern_arg(a), pattern_arg(b)};
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin<Call::widen_right_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::widen_right_mul, pattern_arg(a), pattern_arg(b)};
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin<Call::widen_right_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::widen_right_sub, pattern_arg(a), pattern_arg(b)};
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin<Call::widen_right_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}

template<typename A, typename B>
auto widening_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::widening_add, pattern_arg(a), pattern_arg(b)};
auto widening_add(A &&a, B &&b) noexcept -> Intrin<Call::widening_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto widening_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::widening_sub, pattern_arg(a), pattern_arg(b)};
auto widening_sub(A &&a, B &&b) noexcept -> Intrin<Call::widening_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto widening_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::widening_mul, pattern_arg(a), pattern_arg(b)};
auto widening_mul(A &&a, B &&b) noexcept -> Intrin<Call::widening_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto saturating_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::saturating_add, pattern_arg(a), pattern_arg(b)};
auto saturating_add(A &&a, B &&b) noexcept -> Intrin<Call::saturating_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::saturating_sub, pattern_arg(a), pattern_arg(b)};
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<Call::saturating_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A>
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin<decltype(pattern_arg(a))> {
Intrin<decltype(pattern_arg(a))> p = {Call::saturating_cast, pattern_arg(a)};
p.optional_type_hint = t;
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin<Call::saturating_cast, decltype(pattern_arg(a))> {
Intrin<Call::saturating_cast, decltype(pattern_arg(a))> p = {pattern_arg(a)};
p.optional_type_hint.type = t;
return p;
}
template<typename A, typename B>
auto halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::halving_add, pattern_arg(a), pattern_arg(b)};
auto halving_add(A &&a, B &&b) noexcept -> Intrin<Call::halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto halving_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::halving_sub, pattern_arg(a), pattern_arg(b)};
auto halving_sub(A &&a, B &&b) noexcept -> Intrin<Call::halving_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::rounding_halving_add, pattern_arg(a), pattern_arg(b)};
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<Call::rounding_halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::shift_left, pattern_arg(a), pattern_arg(b)};
auto shift_left(A &&a, B &&b) noexcept -> Intrin<Call::shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::shift_right, pattern_arg(a), pattern_arg(b)};
auto shift_right(A &&a, B &&b) noexcept -> Intrin<Call::shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::rounding_shift_left, pattern_arg(a), pattern_arg(b)};
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<Call::rounding_shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {Call::rounding_shift_right, pattern_arg(a), pattern_arg(b)};
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<Call::rounding_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B, typename C>
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
return {Call::mul_shift_right, pattern_arg(a), pattern_arg(b), pattern_arg(c)};
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<Call::mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
}
template<typename A, typename B, typename C>
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
return {Call::rounding_mul_shift_right, pattern_arg(a), pattern_arg(b), pattern_arg(c)};
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<Call::rounding_mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
}

template<typename A>
auto abs(A &&a) noexcept -> Intrin<Call::abs, decltype(pattern_arg(a))> {
return {pattern_arg(a)};
}

template<typename A, typename B>
auto absd(A &&a, B &&b) noexcept -> Intrin<Call::absd, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
return {pattern_arg(a), pattern_arg(b)};
}

template<typename A>
auto likely(A &&a) noexcept -> Intrin<Call::likely, decltype(pattern_arg(a))> {
return {pattern_arg(a)};
}

template<typename A>
auto likely_if_innermost(A &&a) noexcept -> Intrin<Call::likely_if_innermost, decltype(pattern_arg(a))> {
return {pattern_arg(a)};
}

template<typename A>
Expand Down Expand Up @@ -2425,7 +2455,8 @@ template<typename A>
struct IsInt {
struct pattern_tag {};
A a;
int bits, lanes;
uint8_t bits;
uint16_t lanes;

constexpr static uint32_t binds = bindings<A>::mask;

Expand All @@ -2448,7 +2479,7 @@ struct IsInt {
};

template<typename A>
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits = 0, int lanes = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
HALIDE_ALWAYS_INLINE auto is_int(A &&a, uint8_t bits = 0, uint16_t lanes = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
assert_is_lvalue_if_expr<A>();
return {pattern_arg(a), bits, lanes};
}
Expand All @@ -2470,7 +2501,8 @@ template<typename A>
struct IsUInt {
struct pattern_tag {};
A a;
int bits, lanes;
uint8_t bits;
uint16_t lanes;

constexpr static uint32_t binds = bindings<A>::mask;

Expand All @@ -2493,7 +2525,7 @@ struct IsUInt {
};

template<typename A>
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits = 0, int lanes = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, uint8_t bits = 0, uint16_t lanes = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
assert_is_lvalue_if_expr<A>();
return {pattern_arg(a), bits, lanes};
}
Expand Down
8 changes: 4 additions & 4 deletions src/Simplify_Max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ Expr Simplify::visit(const Max *op, ExprInfo *info) {
rewrite(max(select(x, w, max(z, y)), z), max(select(x, w, y), z)) ||
rewrite(max(select(x, w, max(z, y)), y), max(select(x, w, z), y)) ||

rewrite(max(intrin(Call::likely, x), x), b) ||
rewrite(max(x, intrin(Call::likely, x)), a) ||
rewrite(max(intrin(Call::likely_if_innermost, x), x), b) ||
rewrite(max(x, intrin(Call::likely_if_innermost, x)), a) ||
rewrite(max(likely(x), x), b) ||
rewrite(max(x, likely(x)), a) ||
rewrite(max(likely_if_innermost(x), x), b) ||
rewrite(max(x, likely_if_innermost(x)), a) ||

(no_overflow(op->type) &&
(rewrite(max(ramp(x, y, lanes), broadcast(z, lanes)), a,
Expand Down
8 changes: 4 additions & 4 deletions src/Simplify_Min.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ Expr Simplify::visit(const Min *op, ExprInfo *info) {
rewrite(min(select(x, w, min(z, y)), y), min(select(x, w, z), y)) ||
rewrite(min(select(x, w, min(z, y)), z), min(select(x, w, y), z)) ||

rewrite(min(intrin(Call::likely, x), x), b) ||
rewrite(min(x, intrin(Call::likely, x)), a) ||
rewrite(min(intrin(Call::likely_if_innermost, x), x), b) ||
rewrite(min(x, intrin(Call::likely_if_innermost, x)), a) ||
rewrite(min(likely(x), x), b) ||
rewrite(min(x, likely(x)), a) ||
rewrite(min(likely_if_innermost(x), x), b) ||
rewrite(min(x, likely_if_innermost(x)), a) ||

(no_overflow(op->type) &&
(rewrite(min(ramp(x, y, lanes), broadcast(z, lanes)), a,
Expand Down
4 changes: 2 additions & 2 deletions src/Simplify_Not.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ Expr Simplify::visit(const Not *op, ExprInfo *info) {
}

if (rewrite(!broadcast(x, c0), broadcast(!x, c0)) ||
rewrite(!intrin(Call::likely, x), intrin(Call::likely, !x)) ||
rewrite(!intrin(Call::likely_if_innermost, x), intrin(Call::likely_if_innermost, !x)) ||
rewrite(!likely(x), likely(!x)) ||
rewrite(!likely_if_innermost(x), likely_if_innermost(!x)) ||
rewrite(!(!x && y), x || !y) ||
rewrite(!(!x || y), x && !y) ||
rewrite(!(x && !y), !x || y) ||
Expand Down
16 changes: 8 additions & 8 deletions src/Simplify_Select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ Expr Simplify::visit(const Select *op, ExprInfo *info) {

// clang-format off
if (EVAL_IN_LAMBDA
(rewrite(select(IRMatcher::intrin(Call::likely, true), x, y), x) ||
rewrite(select(IRMatcher::intrin(Call::likely, false), x, y), y) ||
rewrite(select(IRMatcher::intrin(Call::likely_if_innermost, true), x, y), x) ||
rewrite(select(IRMatcher::intrin(Call::likely_if_innermost, false), x, y), y) ||
(rewrite(select(IRMatcher::likely(true), x, y), x) ||
rewrite(select(IRMatcher::likely(false), x, y), y) ||
rewrite(select(IRMatcher::likely_if_innermost(true), x, y), x) ||
rewrite(select(IRMatcher::likely_if_innermost(false), x, y), y) ||
rewrite(select(1, x, y), x) ||
rewrite(select(0, x, y), y) ||
rewrite(select(x, y, y), y) ||
rewrite(select(x, intrin(Call::likely, y), y), false_value) ||
rewrite(select(x, y, intrin(Call::likely, y)), true_value) ||
rewrite(select(x, intrin(Call::likely_if_innermost, y), y), false_value) ||
rewrite(select(x, y, intrin(Call::likely_if_innermost, y)), true_value) ||
rewrite(select(x, likely(y), y), false_value) ||
rewrite(select(x, y, likely(y)), true_value) ||
rewrite(select(x, likely_if_innermost(y), y), false_value) ||
rewrite(select(x, y, likely_if_innermost(y)), true_value) ||
false)) {
return rewrite.result;
}
Expand Down