Skip to content

Commit

Permalink
[Util][NFC] OptimizeIntArithmetic: reduce calls to eraseState (#19130)
Browse files Browse the repository at this point in the history
This pass is causing long compilation times for llama3 405b (even when
cherry-picking llvm/llvm-project#115399). The
majority of the time is spent in this one pass. The compilation times
improve when calling `eraseState` only when ops are deleted. This is
similar to the upstream listeners in `UnsignedWhenEquivalent.cpp` and
`IntRangeOptimizations.cpp`. It appears this function loops over all
`LatticeAnchors` on each invocation to find the one to delete, causing
it to be slow. My (nonrigorous) experiment showed a decrease from 18 min
to 3 min compile time. My main concern here would be this affecting
correctness, as I don't know if this has unaccounted for side effects.

Signed-off-by: Ian Wood <[email protected]>
  • Loading branch information
IanWood1 authored Nov 14, 2024
1 parent d497571 commit 81dd4e6
Showing 1 changed file with 5 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-util-optimize-arithmetic"
#define DEBUG_TYPE "iree-util-optimize-int-arithmetic"
using llvm::dbgs;

using namespace mlir::dataflow;
Expand Down Expand Up @@ -289,43 +289,7 @@ class DataFlowListener : public RewriterBase::Listener {
void notifyOperationErased(Operation *op) override {
s.eraseState(s.getProgramPointAfter(op));
for (Value res : op->getResults())
flushValue(res);
}
void notifyOperationModified(Operation *op) override {
for (Value res : op->getResults())
flushValue(res);
}
void notifyOperationReplaced(Operation *op, Operation *replacement) override {
for (Value res : op->getResults())
flushValue(res);
}

void notifyOperationReplaced(Operation *op, ValueRange replacement) override {
for (Value res : op->getResults())
flushValue(res);
}

void flushValue(Value value) {
SmallVector<Value> worklist;
SmallVector<Value> process;
worklist.push_back(value);

while (!worklist.empty()) {
process.clear();
process.swap(worklist);
for (Value childValue : process) {
auto *state = s.lookupState<IntegerValueRangeLattice>(childValue);
if (!state) {
continue;
}
s.eraseState(childValue);
for (auto user : childValue.getUsers()) {
for (Value result : user->getResults()) {
worklist.push_back(result);
}
}
}
}
s.eraseState(res);
}

DataFlowSolver &s;
Expand Down Expand Up @@ -386,11 +350,14 @@ class OptimizeIntArithmeticPass

FrozenRewritePatternSet frozenPatterns(std::move(patterns));
for (int i = 0;; ++i) {
LLVM_DEBUG(dbgs() << " * Starting iteration: " << i << "\n");
if (failed(solver.initializeAndRun(op))) {
emitError(op->getLoc()) << "failed to perform int range analysis";
return signalPassFailure();
}

LLVM_DEBUG(
dbgs() << " * Finished Running Solver -- Applying Patterns\n");
bool changed = false;
if (failed(applyPatternsAndFoldGreedily(op, frozenPatterns, config,
&changed))) {
Expand Down

0 comments on commit 81dd4e6

Please sign in to comment.