Skip to content

Commit

Permalink
Fix Xtensa code for as_const API change (#8502)
Browse files Browse the repository at this point in the history
  • Loading branch information
steven-johnson authored Dec 9, 2024
1 parent 25ebbb3 commit e995c89
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 15 deletions.
20 changes: 9 additions & 11 deletions src/CodeGen_Xtensa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,8 @@ void CodeGen_Xtensa::visit(const IntImm *op) {
}
}
void CodeGen_Xtensa::visit(const Mul *op) {
int bits;
if (is_const_power_of_two_integer(op->b, &bits)) {
print_expr(Call::make(op->type, Call::shift_left, {op->a, Expr(bits)}, Call::PureIntrinsic));
if (auto bits = is_const_power_of_two_integer(op->b)) {
print_expr(Call::make(op->type, Call::shift_left, {op->a, Expr(*bits)}, Call::PureIntrinsic));
} else {
if (is_native_xtensa_vector<int16_t>(op->type)) {
string sa = print_expr(op->a);
Expand Down Expand Up @@ -462,9 +461,8 @@ string CodeGen_Xtensa::print_xtensa_call(const Call *op) {
}

void CodeGen_Xtensa::visit(const Div *op) {
int bits;
if (is_const_power_of_two_integer(op->b, &bits)) {
print_expr(Call::make(op->type, Call::shift_right, {op->a, Expr(bits)}, Call::PureIntrinsic));
if (auto bits = is_const_power_of_two_integer(op->b)) {
print_expr(Call::make(op->type, Call::shift_right, {op->a, Expr(*bits)}, Call::PureIntrinsic));
} else if (is_native_xtensa_vector<float16_t>(op->type)) {
ostringstream rhs;
rhs << "IVP_DIVNXF16(" << print_expr(op->a) << ", " << print_expr(op->b) << ")";
Expand Down Expand Up @@ -495,11 +493,11 @@ void CodeGen_Xtensa::visit(const Div *op) {
}

void CodeGen_Xtensa::visit(const Mod *op) {
int bits;
if (is_native_vector_type(op->type) && is_const_power_of_two_integer(op->b, &bits)) {
std::optional<int> bits;
if (is_native_vector_type(op->type) && (bits = is_const_power_of_two_integer(op->b))) {
print_expr(op->a &
Broadcast::make(
Cast::make(op->type.with_lanes(1), Expr((1 << bits) - 1)), op->type.lanes()));
Cast::make(op->type.with_lanes(1), Expr((1 << *bits) - 1)), op->type.lanes()));
} else if (is_native_xtensa_vector<int32_t>(op->type)) {
string sa = print_expr(op->a);
string sb = print_expr(op->b);
Expand Down Expand Up @@ -1069,7 +1067,7 @@ void CodeGen_Xtensa::visit(const Call *op) {
if (op->is_intrinsic(Call::shift_left)) {
internal_assert(op->args.size() == 2);
string a0 = print_expr(op->args[0]);
const int64_t *bits = as_const_int(op->args[1]);
auto bits = as_const_int(op->args[1]);
if (is_native_xtensa_vector<uint8_t>(op->type) && bits) {
rhs << "IVP_SLLI2NX8U(" << a0 << ", " << std::to_string(*bits) << ")";
} else if (is_native_xtensa_vector<int8_t>(op->type) && bits) {
Expand Down Expand Up @@ -1115,7 +1113,7 @@ void CodeGen_Xtensa::visit(const Call *op) {
} else if (op->is_intrinsic(Call::shift_right)) {
internal_assert(op->args.size() == 2);
string a0 = print_expr(op->args[0]);
const int64_t *bits = as_const_int(op->args[1]);
auto bits = as_const_int(op->args[1]);
if (is_native_xtensa_vector<uint8_t>(op->type) && bits) {
rhs << "IVP_SRLI2NX8U(" << a0 << ", " << std::to_string(*bits) << ")";
} else if (is_native_xtensa_vector<int8_t>(op->type) && bits) {
Expand Down
7 changes: 3 additions & 4 deletions src/XtensaOptimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,8 @@ bool process_match_flags(vector<Expr> &matches, int flags) {
// This flag is mainly to capture shifts. When the operand of a div or
// mul is a power of 2, we can use a shift instead.
if (flags & (Pattern::ExactLog2Op1 << (i - Pattern::BeginExactLog2Op))) {
int pow;
if (is_const_power_of_two_integer(matches[i], &pow)) {
matches[i] = cast(matches[i].type().with_lanes(1), pow);
if (auto pow = is_const_power_of_two_integer(matches[i])) {
matches[i] = cast(matches[i].type().with_lanes(1), *pow);
} else {
return false;
}
Expand Down Expand Up @@ -1010,7 +1009,7 @@ class MatchXtensaPatterns : public IRGraphMutator {
}
} else if (op->is_intrinsic(Call::widening_shift_left)) {
// Replace widening left shift with multiplication.
const uint64_t *c = as_const_uint(op->args[1]);
auto c = as_const_uint(op->args[1]);
if (c && op->args[1].type().can_represent((uint64_t)1 << *c)) {
if (op->args[0].type().is_int() && (*c < (uint64_t)op->args[0].type().bits() - 1)) {
return mutate(widening_mul(op->args[0], bc(IntImm::make(op->args[1].type().with_code(halide_type_int).with_lanes(1), (int64_t)1 << *c), op->args[1].type().lanes())));
Expand Down

0 comments on commit e995c89

Please sign in to comment.