Skip to content

Commit

Permalink
fix(compiler): Convert scf.for to scf.parallel only if parallel attri…
Browse files Browse the repository at this point in the history
…bute is true

The pattern converting `scf.for` operations to `scf.parallel`
operations from `lib/Transforms/ForLoopToParallel.cpp` contains an
assertion that ensures that the source operation does not have any
iteration arguments in order to keep the conversion as simple as
possible.

However, if the attribute `parallel` of the source operation is
`false`, the operation is replaced with an identical clone and the
conversion could be treated as a no-op.

This change modifies the pattern, such that it simply fails if
`parallel` is `false`, making the check for the absence of iteration
arguments unnecessary and avoiding unnecessary bailouts by the
compiler.
  • Loading branch information
andidr committed Jul 29, 2024
1 parent 0379f9a commit 3cd3dff
Showing 1 changed file with 16 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,39 +23,26 @@ class ForOpPattern : public mlir::OpRewritePattern<mlir::scf::ForOp> {
matchAndRewrite(mlir::scf::ForOp forOp,
mlir::PatternRewriter &rewriter) const override {
auto attr = forOp->getAttrOfType<mlir::BoolAttr>("parallel");
if (attr == nullptr) {

if (!attr || !attr.getValue()) {
return mlir::failure();
}

assert(forOp.getRegionIterArgs().size() == 0 &&
"unexpecting iter args when loops are bufferized");
if (attr.getValue()) {
rewriter.replaceOpWithNewOp<mlir::scf::ParallelOp>(
forOp, mlir::ValueRange{forOp.getLowerBound()},
mlir::ValueRange{forOp.getUpperBound()}, forOp.getStep(),
std::nullopt,
[&](mlir::OpBuilder &builder, mlir::Location location,
mlir::ValueRange indVar, mlir::ValueRange iterArgs) {
mlir::IRMapping map;
map.map(forOp.getInductionVar(), indVar.front());
for (auto &op : forOp.getRegion().front()) {
auto newOp = builder.clone(op, map);
map.map(op.getResults(), newOp->getResults());
}
});
} else {
rewriter.replaceOpWithNewOp<mlir::scf::ForOp>(
forOp, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(),
std::nullopt,
[&](mlir::OpBuilder &builder, mlir::Location location,
mlir::Value indVar, mlir::ValueRange iterArgs) {
mlir::IRMapping map;
map.map(forOp.getInductionVar(), indVar);
for (auto &op : forOp.getRegion().front()) {
auto newOp = builder.clone(op, map);
map.map(op.getResults(), newOp->getResults());
}
});
}

rewriter.replaceOpWithNewOp<mlir::scf::ParallelOp>(
forOp, mlir::ValueRange{forOp.getLowerBound()},
mlir::ValueRange{forOp.getUpperBound()}, forOp.getStep(), std::nullopt,
[&](mlir::OpBuilder &builder, mlir::Location location,
mlir::ValueRange indVar, mlir::ValueRange iterArgs) {
mlir::IRMapping map;
map.map(forOp.getInductionVar(), indVar.front());
for (auto &op : forOp.getRegion().front()) {
auto newOp = builder.clone(op, map);
map.map(op.getResults(), newOp->getResults());
}
});

return mlir::success();
}
Expand Down

0 comments on commit 3cd3dff

Please sign in to comment.