Skip to content

Commit

Permalink
Use a consistent idiom for visit_let (#8540)
Browse files Browse the repository at this point in the history
visit_let in the codebase uses a wide variety of template names,
argument names, and ways of getting the body type. This just picks one
and uses it consistently. No functional changes.
  • Loading branch information
abadams authored Dec 27, 2024
1 parent c2d5ea3 commit 097aee9
Show file tree
Hide file tree
Showing 14 changed files with 120 additions and 120 deletions.
18 changes: 9 additions & 9 deletions src/BoundSmallAllocations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,40 +17,40 @@ class BoundSmallAllocations : public IRMutator {
// Track constant bounds
Scope<Interval> scope;

template<typename T, typename Body>
Body visit_let(const T *op) {
template<typename LetOrLetStmt>
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
// Visit an entire chain of lets in a single method to conserve stack space.
struct Frame {
const T *op;
const LetOrLetStmt *op;
ScopedBinding<Interval> binding;
Frame(const T *op, Scope<Interval> &scope)
Frame(const LetOrLetStmt *op, Scope<Interval> &scope)
: op(op),
binding(scope, op->name, find_constant_bounds(op->value, scope)) {
}
};
std::vector<Frame> frames;
Body result;
decltype(op->body) result;

do {
result = op->body;
frames.emplace_back(op, scope);
} while ((op = result.template as<T>()));
} while ((op = result.template as<LetOrLetStmt>()));

result = mutate(result);

for (const auto &frame : reverse_view(frames)) {
result = T::make(frame.op->name, frame.op->value, result);
result = LetOrLetStmt::make(frame.op->name, frame.op->value, result);
}

return result;
}

Stmt visit(const LetStmt *op) override {
return visit_let<LetStmt, Stmt>(op);
return visit_let(op);
}

Expr visit(const Let *op) override {
return visit_let<Let, Expr>(op);
return visit_let(op);
}

bool in_thread_loop = false;
Expand Down
18 changes: 9 additions & 9 deletions src/ClampUnsafeAccesses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ struct ClampUnsafeAccesses : IRMutator {
}

Expr visit(const Let *let) override {
return visit_let<Let, Expr>(let);
return visit_let(let);
}

Stmt visit(const LetStmt *let) override {
return visit_let<LetStmt, Stmt>(let);
return visit_let(let);
}

Expr visit(const Variable *var) override {
Expand Down Expand Up @@ -80,15 +80,15 @@ struct ClampUnsafeAccesses : IRMutator {
}

private:
template<typename L, typename Body>
Body visit_let(const L *let) {
ScopedBinding<bool> binding(let_var_inside_indexing, let->name, false);
Body body = mutate(let->body);
template<typename LetOrLetStmt>
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
ScopedBinding<bool> binding(let_var_inside_indexing, op->name, false);
auto body = mutate(op->body);

ScopedValue<bool> s(is_inside_indexing, is_inside_indexing || let_var_inside_indexing.get(let->name));
Expr value = mutate(let->value);
ScopedValue<bool> s(is_inside_indexing, is_inside_indexing || let_var_inside_indexing.get(op->name));
Expr value = mutate(op->value);

return L::make(let->name, std::move(value), std::move(body));
return LetOrLetStmt::make(op->name, std::move(value), std::move(body));
}

bool bounds_smaller_than_type(const Interval &bounds, Type type) {
Expand Down
28 changes: 14 additions & 14 deletions src/Deinterleave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,44 +465,44 @@ class Interleaver : public IRMutator {
return Shuffle::make_interleave(exprs);
}

template<typename T, typename Body>
Body visit_lets(const T *op) {
template<typename LetOrLetStmt>
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
// Visit an entire chain of lets in a single method to conserve stack space.
struct Frame {
const T *op;
const LetOrLetStmt *op;
Expr new_value;
ScopedBinding<> binding;
Frame(const T *op, Expr v, Scope<void> &scope)
Frame(const LetOrLetStmt *op, Expr v, Scope<void> &scope)
: op(op),
new_value(std::move(v)),
binding(new_value.type().is_vector(), scope, op->name) {
}
};
std::vector<Frame> frames;
Body result;
decltype(op->body) result;

do {
result = op->body;
frames.emplace_back(op, mutate(op->value), vector_lets);
} while ((op = result.template as<T>()));
} while ((op = result.template as<LetOrLetStmt>()));

result = mutate(result);

for (const auto &frame : reverse_view(frames)) {
Expr value = std::move(frame.new_value);

result = T::make(frame.op->name, value, result);
result = LetOrLetStmt::make(frame.op->name, value, result);

// For vector lets, we may additionally need a let defining the even and odd lanes only
if (value.type().is_vector()) {
if (value.type().lanes() % 2 == 0) {
result = T::make(frame.op->name + ".even_lanes", extract_even_lanes(value, vector_lets), result);
result = T::make(frame.op->name + ".odd_lanes", extract_odd_lanes(value, vector_lets), result);
result = LetOrLetStmt::make(frame.op->name + ".even_lanes", extract_even_lanes(value, vector_lets), result);
result = LetOrLetStmt::make(frame.op->name + ".odd_lanes", extract_odd_lanes(value, vector_lets), result);
}
if (value.type().lanes() % 3 == 0) {
result = T::make(frame.op->name + ".lanes_0_of_3", extract_mod3_lanes(value, 0, vector_lets), result);
result = T::make(frame.op->name + ".lanes_1_of_3", extract_mod3_lanes(value, 1, vector_lets), result);
result = T::make(frame.op->name + ".lanes_2_of_3", extract_mod3_lanes(value, 2, vector_lets), result);
result = LetOrLetStmt::make(frame.op->name + ".lanes_0_of_3", extract_mod3_lanes(value, 0, vector_lets), result);
result = LetOrLetStmt::make(frame.op->name + ".lanes_1_of_3", extract_mod3_lanes(value, 1, vector_lets), result);
result = LetOrLetStmt::make(frame.op->name + ".lanes_2_of_3", extract_mod3_lanes(value, 2, vector_lets), result);
}
}
}
Expand All @@ -511,11 +511,11 @@ class Interleaver : public IRMutator {
}

Expr visit(const Let *op) override {
return visit_lets<Let, Expr>(op);
return visit_let(op);
}

Stmt visit(const LetStmt *op) override {
return visit_lets<LetStmt, Stmt>(op);
return visit_let(op);
}

Expr visit(const Ramp *op) override {
Expand Down
10 changes: 5 additions & 5 deletions src/EliminateBoolVectors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,8 @@ class EliminateBoolVectors : public IRMutator {
return expr;
}

template<typename NodeType, typename LetType>
NodeType visit_let(const LetType *op) {
template<typename LetOrLetStmt>
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
Expr value = mutate(op->value);

// We changed the type of the let, we need to replace the
Expand All @@ -305,17 +305,17 @@ class EliminateBoolVectors : public IRMutator {
}

if (!value.same_as(op->value) || !body.same_as(op->body)) {
return LetType::make(op->name, value, body);
return LetOrLetStmt::make(op->name, value, body);
} else {
return op;
}
}

Expr visit(const Let *op) override {
return visit_let<Expr>(op);
return visit_let(op);
}
Stmt visit(const LetStmt *op) override {
return visit_let<Stmt>(op);
return visit_let(op);
}
};

Expand Down
10 changes: 5 additions & 5 deletions src/FuseGPUThreadLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1157,9 +1157,9 @@ class ExtractRegisterAllocations : public IRMutator {
op->param, mutate(op->predicate), op->alignment);
}

template<typename ExprOrStmt, typename LetOrLetStmt>
ExprOrStmt visit_let(const LetOrLetStmt *op) {
ExprOrStmt body = op->body;
template<typename LetOrLetStmt>
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
auto body = op->body;

body = mutate(op->body);
Expr value = mutate(op->value);
Expand All @@ -1178,11 +1178,11 @@ class ExtractRegisterAllocations : public IRMutator {
}

Expr visit(const Let *op) override {
return visit_let<Expr>(op);
return visit_let(op);
}

Stmt visit(const LetStmt *op) override {
return visit_let<Stmt>(op);
return visit_let(op);
}

Scope<int> register_allocations;
Expand Down
42 changes: 21 additions & 21 deletions src/HexagonOptimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1088,20 +1088,20 @@ class OptimizePatterns : public IRMutator {
}
}

template<typename NodeType, typename T>
NodeType visit_let(const T *op) {
template<typename LetOrLetStmt>
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
bounds.push(op->name, bounds_of_expr_in_scope(op->value, bounds));
NodeType node = IRMutator::visit(op);
auto node = IRMutator::visit(op);
bounds.pop(op->name);
return node;
}

Expr visit(const Let *op) override {
return visit_let<Expr>(op);
return visit_let(op);
}

Stmt visit(const LetStmt *op) override {
return visit_let<Stmt>(op);
return visit_let(op);
}

Expr visit(const Div *op) override {
Expand Down Expand Up @@ -1599,12 +1599,12 @@ class EliminateInterleaves : public IRMutator {
}
}

template<typename NodeType, typename LetType>
NodeType visit_let(const LetType *op) {
template<typename LetOrLetStmt>
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {

Expr value = mutate(op->value);
string deinterleaved_name;
NodeType body;
decltype(op->body) body;
// Other code in this mutator needs to be able to tell the
// difference between a Let that yields a deinterleave, and a
// let that has a removable deinterleave. Lets that can
Expand Down Expand Up @@ -1632,10 +1632,10 @@ class EliminateInterleaves : public IRMutator {
return op;
} else if (body.same_as(op->body)) {
// If the body didn't change, we must not have used the deinterleaved value.
return LetType::make(op->name, value, body);
return LetOrLetStmt::make(op->name, value, body);
} else {
// We need to rewrap the body with new lets.
NodeType result = body;
auto result = body;
bool deinterleaved_used = stmt_or_expr_uses_var(result, deinterleaved_name);
bool interleaved_used = stmt_or_expr_uses_var(result, op->name);
if (deinterleaved_used && interleaved_used) {
Expand All @@ -1653,14 +1653,14 @@ class EliminateInterleaves : public IRMutator {
interleaved = native_interleave(interleaved);
}

result = LetType::make(op->name, interleaved, result);
return LetType::make(deinterleaved_name, deinterleaved, result);
result = LetOrLetStmt::make(op->name, interleaved, result);
return LetOrLetStmt::make(deinterleaved_name, deinterleaved, result);
} else if (deinterleaved_used) {
// Only the deinterleaved value is used, we can eliminate the interleave.
return LetType::make(deinterleaved_name, remove_interleave(value), result);
return LetOrLetStmt::make(deinterleaved_name, remove_interleave(value), result);
} else if (interleaved_used) {
// Only the original value is used, regenerate the let.
return LetType::make(op->name, value, result);
return LetOrLetStmt::make(op->name, value, result);
} else {
// The let must have been dead.
internal_assert(!stmt_or_expr_uses_var(op->body, op->name))
Expand All @@ -1671,7 +1671,7 @@ class EliminateInterleaves : public IRMutator {
}

Expr visit(const Let *op) override {
Expr expr = visit_let<Expr>(op);
Expr expr = visit_let(op);

// Lift interleaves out of Let expression bodies.
const Let *let = expr.as<Let>();
Expand All @@ -1682,7 +1682,7 @@ class EliminateInterleaves : public IRMutator {
}

Stmt visit(const LetStmt *op) override {
return visit_let<Stmt>(op);
return visit_let(op);
}

Expr visit(const Cast *op) override {
Expand Down Expand Up @@ -2047,25 +2047,25 @@ class ScatterGatherGenerator : public IRMutator {
return IRMutator::visit(op);
}

template<typename NodeType, typename T>
NodeType visit_let(const T *op) {
template<typename LetOrLetStmt>
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
// We only care about vector lets.
if (op->value.type().is_vector()) {
bounds.push(op->name, bounds_of_expr_in_scope(op->value, bounds));
}
NodeType node = IRMutator::visit(op);
auto node = IRMutator::visit(op);
if (op->value.type().is_vector()) {
bounds.pop(op->name);
}
return node;
}

Expr visit(const Let *op) override {
return visit_let<Expr>(op);
return visit_let(op);
}

Stmt visit(const LetStmt *op) override {
return visit_let<Stmt>(op);
return visit_let(op);
}

Stmt visit(const Allocate *op) override {
Expand Down
Loading

0 comments on commit 097aee9

Please sign in to comment.