Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector] Add Vector-dialect interleave-to-shuffle pattern, enable in VectorToSPIRV #91800

Merged
merged 3 commits into from
May 13, 2024

Conversation

bjacob
Copy link
Contributor

@bjacob bjacob commented May 10, 2024

Context: iree-org/iree#17346.

Test IREE integrate showing it's fixing the problem it's intended to fix, i.e. it allows IREE to drop its local revert of #89131:

iree-org/iree#17359

This is added to VectorToSPIRV because SPIRV doesn't currently handle vector.interleave (see motivating context above).

This is limited to 1D, non-scalable vectors.

Copy link

github-actions bot commented May 10, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@llvmbot
Copy link
Member

llvmbot commented May 13, 2024

@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Benoit Jacob (bjacob)

Changes

Context: iree-org/iree#17346


Full diff: https://github.com/llvm/llvm-project/pull/91800.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td (+14)
  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+3)
  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+4)
  • (modified) mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (+5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp (+25)
  • (added) mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir (+21)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index f6371f39c3944..bc3c16d40520e 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -306,6 +306,20 @@ def ApplyLowerInterleavePatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyInterleaveToShufflePatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.interleave_to_shuffle",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that 1D vector interleave operations should be rewritten as
+    vector shuffle operations.
+
+    This is motivated by some current codegen backends not handling vector
+    interleave operations.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.rewrite_narrow_types",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 350d2777cadf5..8fd9904fabc0e 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -273,6 +273,9 @@ void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
                                               int64_t targetRank = 1,
                                               PatternBenefit benefit = 1);
 
+void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns,
+                                               PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 868a3521e7a0f..c2dd37f481466 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -828,6 +829,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
   // than the generic one that extracts all elements.
   patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
                                            PatternBenefit(2));
+
+  // Need this until vector.interleave is handled.
+  vector::populateVectorInterleaveToShufflePatterns(patterns);
 }
 
 void mlir::populateVectorReductionToSPIRVDotProductPatterns(
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 885644864c0f7..61fd6bd972e3a 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -164,6 +164,11 @@ void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
   vector::populateVectorInterleaveLoweringPatterns(patterns);
 }
 
+void transform::ApplyInterleaveToShufflePatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateVectorInterleaveToShufflePatterns(patterns);
+}
+
 void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   populateVectorNarrowTypeRewritePatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 3a456076f8fba..35557e05bb45e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
 
 #define DEBUG_TYPE "vector-interleave-lowering"
 
@@ -77,9 +78,33 @@ class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
   int64_t targetRank = 1;
 };
 
+class InterleaveToShuffle : public OpRewritePattern<vector::InterleaveOp> {
+public:
+  InterleaveToShuffle(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit){};
+
+  LogicalResult matchAndRewrite(vector::InterleaveOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = op.getSourceVectorType();
+    if (sourceType.getRank() != 1) {
+      return failure();
+    }
+    rewriter.replaceOpWithNewOp<ShuffleOp>(
+        op, op.getLhs(), op.getRhs(),
+        llvm::map_to_vector(llvm::seq<int64_t>(2 * sourceType.getNumElements()),
+                            [](int64_t i) { return i / 2; }));
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateVectorInterleaveLoweringPatterns(
     RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
   patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
 }
+
+void mlir::vector::populateVectorInterleaveToShufflePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<InterleaveToShuffle>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
new file mode 100644
index 0000000000000..0b039ba78289c
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: @vector_interleave_to_shuffle
+func.func @vector_interleave_to_shuffle(%a: vector<7xi16>, %b: vector<7xi16>) -> vector<14xi16>
+{
+  %0 = vector.interleave %a, %b : vector<7xi16>
+  return %0 : vector<14xi16>
+}
+// CHECK: vector.shuffle %arg0, %arg1 [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6] : vector<7xi16>, vector<7xi16>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+
+    transform.apply_patterns to %f {
+      transform.apply_patterns.vector.interleave_to_shuffle
+    } : !transform.any_op
+    transform.yield
+  }
+}

Copy link
Contributor

@qedawkins qedawkins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Could you prefix the PR subject line with [mlir][vector]?

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG, thanks! (I've left a couple of small suggestions)

Context: iree-org/iree#17346.

Note, for SVE we will need to use vector.interleave rather than vector.shuffle :) Just to make folks aware, I appreciate that the IREE issue is only affecting SPIR-V.

@@ -77,9 +78,33 @@ class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
int64_t targetRank = 1;
};

class InterleaveToShuffle : public OpRewritePattern<vector::InterleaveOp> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note, vector.shuffle doesn't support scalable vectors:

You can just disable this pattern when sourceType.isScalable() returns true. It would be nice to have a test to verity that as well.

Also, would you mind adding a comment with IR BEFORE and AFTER (similar to what's been added to UnrollInterleaveOp)? I know that right now this is a small file and the pattern is self-explanatory, but these files tend to grow a lot at which point comments can be really helpful.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note, vector.shuffle doesn't support scalable vectors:

* [[vector][mlir] Restrict vector.shuffle to fixed-width vectors #88733](https://github.com/llvm/llvm-project/pull/88733)

You can just disable this pattern when sourceType.isScalable() returns true. It would be nice to have a test to verity that as well.

Thanks, done!

Also, would you mind adding a comment with IR BEFORE and AFTER (similar to what's been added to UnrollInterleaveOp)? I know that right now this is a small file and the pattern is self-explanatory, but these files tend to grow a lot at which point comments can be really helpful.

Good idea. Done!

@bjacob bjacob changed the title Add Vector-dialect interleave-to-shuffle pattern [mlir][vectorAdd Vector-dialect interleave-to-shuffle pattern May 13, 2024
@bjacob bjacob changed the title [mlir][vectorAdd Vector-dialect interleave-to-shuffle pattern [mlir][vector] Add Vector-dialect interleave-to-shuffle pattern, enable in VectorToSPIRV May 13, 2024
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Blocking to double-check the semantics

mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir Outdated Show resolved Hide resolved
@bjacob bjacob requested a review from kuhar May 13, 2024 15:02
@bjacob bjacob merged commit cf40c93 into llvm:main May 13, 2024
3 of 4 checks passed
@bjacob
Copy link
Contributor Author

bjacob commented May 13, 2024

Shoot, buildbot failure, SPIRV linker error. Probably missing dep. Reverting.

bjacob added a commit that referenced this pull request May 13, 2024
…le in VectorToSPIRV (#92012)

This is the second attempt at merging #91800, which bounced due to a
linker error apparently caused by an undeclared dependency.
`MLIRVectorToSPIRV` needed to depend on `MLIRVectorTransforms`. In fact
that was a preexisting issue already flagged by the tool in
https://discourse.llvm.org/t/ninja-can-now-check-for-missing-cmake-dependencies-on-generated-files/74344.

Context: iree-org/iree#17346.

Test IREE integrate showing it's fixing the problem it's intended to
fix, i.e. it allows IREE to drop its local revert of
#89131:

iree-org/iree#17359

This is added to VectorToSPIRV because SPIRV doesn't currently handle
`vector.interleave` (see motivating context above).

This is limited to 1D, non-scalable vectors.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants