-
Notifications
You must be signed in to change notification settings - Fork 12.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Add Vector-dialect interleave-to-shuffle pattern, enable in VectorToSPIRV #91800
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
2c8e80a
to
b55f155
Compare
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Benoit Jacob (bjacob) ChangesContext: iree-org/iree#17346 Full diff: https://github.com/llvm/llvm-project/pull/91800.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index f6371f39c3944..bc3c16d40520e 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -306,6 +306,20 @@ def ApplyLowerInterleavePatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyInterleaveToShufflePatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.interleave_to_shuffle",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that 1D vector interleave operations should be rewritten as
+ vector shuffle operations.
+
+ This is motivated by some current codegen backends not handling vector
+ interleave operations.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.rewrite_narrow_types",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 350d2777cadf5..8fd9904fabc0e 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -273,6 +273,9 @@ void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
int64_t targetRank = 1,
PatternBenefit benefit = 1);
+void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 868a3521e7a0f..c2dd37f481466 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -828,6 +829,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
// than the generic one that extracts all elements.
patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
PatternBenefit(2));
+
+ // Need this until vector.interleave is handled.
+ vector::populateVectorInterleaveToShufflePatterns(patterns);
}
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 885644864c0f7..61fd6bd972e3a 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -164,6 +164,11 @@ void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
vector::populateVectorInterleaveLoweringPatterns(patterns);
}
+void transform::ApplyInterleaveToShufflePatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateVectorInterleaveToShufflePatterns(patterns);
+}
+
void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorNarrowTypeRewritePatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 3a456076f8fba..35557e05bb45e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
#define DEBUG_TYPE "vector-interleave-lowering"
@@ -77,9 +78,33 @@ class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
int64_t targetRank = 1;
};
+class InterleaveToShuffle : public OpRewritePattern<vector::InterleaveOp> {
+public:
+ InterleaveToShuffle(MLIRContext *context, PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit){};
+
+ LogicalResult matchAndRewrite(vector::InterleaveOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType sourceType = op.getSourceVectorType();
+ if (sourceType.getRank() != 1) {
+ return failure();
+ }
+ rewriter.replaceOpWithNewOp<ShuffleOp>(
+ op, op.getLhs(), op.getRhs(),
+ llvm::map_to_vector(llvm::seq<int64_t>(2 * sourceType.getNumElements()),
+ [](int64_t i) { return i / 2; }));
+ return success();
+ }
+};
+
} // namespace
void mlir::vector::populateVectorInterleaveLoweringPatterns(
RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
}
+
+void mlir::vector::populateVectorInterleaveToShufflePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<InterleaveToShuffle>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
new file mode 100644
index 0000000000000..0b039ba78289c
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: @vector_interleave_to_shuffle
+func.func @vector_interleave_to_shuffle(%a: vector<7xi16>, %b: vector<7xi16>) -> vector<14xi16>
+{
+ %0 = vector.interleave %a, %b : vector<7xi16>
+ return %0 : vector<14xi16>
+}
+// CHECK: vector.shuffle %arg0, %arg1 [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6] : vector<7xi16>, vector<7xi16>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.interleave_to_shuffle
+ } : !transform.any_op
+ transform.yield
+ }
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Could you prefix the PR subject line with [mlir][vector]
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG, thanks! (I've left a couple of small suggestions)
Context: iree-org/iree#17346.
Note, for SVE we will need to use vector.interleave
rather than vector.shuffle
:) Just to make folks aware, I appreciate that the IREE issue is only affecting SPIR-V.
@@ -77,9 +78,33 @@ class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> { | |||
int64_t targetRank = 1; | |||
}; | |||
|
|||
class InterleaveToShuffle : public OpRewritePattern<vector::InterleaveOp> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note, vector.shuffle
doesn't support scalable vectors:
You can just disable this pattern when sourceType.isScalable()
returns true
. It would be nice to have a test to verity that as well.
Also, would you mind adding a comment with IR BEFORE
and AFTER
(similar to what's been added to UnrollInterleaveOp
)? I know that right now this is a small file and the pattern is self-explanatory, but these files tend to grow a lot at which point comments can be really helpful.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note,
vector.shuffle
doesn't support scalable vectors:* [[vector][mlir] Restrict vector.shuffle to fixed-width vectors #88733](https://github.com/llvm/llvm-project/pull/88733)
You can just disable this pattern when
sourceType.isScalable()
returnstrue
. It would be nice to have a test to verity that as well.
Thanks, done!
Also, would you mind adding a comment with IR
BEFORE
andAFTER
(similar to what's been added toUnrollInterleaveOp
)? I know that right now this is a small file and the pattern is self-explanatory, but these files tend to grow a lot at which point comments can be really helpful.
Good idea. Done!
b55f155
to
991ec83
Compare
991ec83
to
2366844
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Blocking to double-check the semantics
Shoot, buildbot failure, SPIRV linker error. Probably missing dep. Reverting. |
…rn, enable in VectorToSPIRV" (#92006) Reverts #91800 Reason: https://lab.llvm.org/buildbot/#/builders/268/builds/13935
…le in VectorToSPIRV (#92012) This is the second attempt at merging #91800, which bounced due to a linker error apparently caused by an undeclared dependency. `MLIRVectorToSPIRV` needed to depend on `MLIRVectorTransforms`. In fact that was a preexisting issue already flagged by the tool in https://discourse.llvm.org/t/ninja-can-now-check-for-missing-cmake-dependencies-on-generated-files/74344. Context: iree-org/iree#17346. Test IREE integrate showing it's fixing the problem it's intended to fix, i.e. it allows IREE to drop its local revert of #89131: iree-org/iree#17359 This is added to VectorToSPIRV because SPIRV doesn't currently handle `vector.interleave` (see motivating context above). This is limited to 1D, non-scalable vectors.
Context: iree-org/iree#17346.
Test IREE integrate showing it's fixing the problem it's intended to fix, i.e. it allows IREE to drop its local revert of #89131:
iree-org/iree#17359
This is added to VectorToSPIRV because SPIRV doesn't currently handle
vector.interleave
(see motivating context above).This is limited to 1D, non-scalable vectors.