LAA: don't version single-iteration loops #96927

Closed
artagnon wants to merge 2 commits

Conversation

artagnon
Contributor

@artagnon artagnon commented Jun 27, 2024

Single-iteration loops are a special case and are likely to be unprofitable to version.

Fixes #96656.

-- 8< --
Based on #97075.

@artagnon artagnon requested review from nikic, fhahn and preames June 27, 2024 16:31
Contributor

@nikic nikic left a comment

> Speculating the stride currently inserts a Stride == 1 predicate, which is equivalent to asserting that the loop executes at least once.

Uh, how are these equivalent (or even related at all)?

> However, when the backedge-taken-count is known-non-negative, speculating the stride unnecessarily versions the loop. Avoid this.

The BE count is an unsigned quantity, so I don't get how "known-non-negative" can be meaningful for it.
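
For illustration, the objection can be made concrete with plain integers. The sketch below is not LLVM code and both function names are hypothetical. It assumes a 64-bit count and shows that an unsigned backedge-taken count is never negative as an unsigned number, so "known non-negative" only says something when the same bits are read as a signed value, where it means the top bit is clear.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical helper, not an LLVM API: true when the bit pattern of an
// unsigned backedge-taken count would also be >= 0 if read as int64_t.
bool isNonNegativeWhenReadAsSigned(uint64_t BTC) {
  return BTC <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
}

// Trip count is backedge-taken count plus one. The assert guards the one
// input for which the unsigned addition would wrap around to zero.
uint64_t tripCountFromBTC(uint64_t BTC) {
  assert(BTC != std::numeric_limits<uint64_t>::max() && "BTC + 1 would wrap");
  return BTC + 1;
}
```

A count of 0 (a single-iteration loop) passes the signed reading trivially, while the all-ones bit pattern fails it even though it is a perfectly valid unsigned value.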

@artagnon
Contributor Author

artagnon commented Jun 27, 2024

> > Speculating the stride currently inserts a Stride == 1 predicate, which is equivalent to asserting that the loop executes at least once.
>
> Uh, how are these equivalent (or even related at all)?

We currently speculate on the stride and insert Stride == 1 when Stride < TC. My understanding is that this is to guard against loops with unknown TC that never execute. Inserting Stride == 1 is a way to version the loop, so that one version executes with this predicate.

> > However, when the backedge-taken-count is known-non-negative, speculating the stride unnecessarily versions the loop. Avoid this.
>
> The BE count is an unsigned quantity, so I don't get how "known-non-negative" can be meaningful for it.

Wait, what happens for TC = 0 or unknown? Isn't the BE SCEV an expression that's not known-non-negative?
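
As a rough aid to this exchange, the check that the diff in this PR removes can be modeled with toy intervals. The pre-patch code skips versioning only when `Stride - MaxBTC` is known to be positive, that is, when Stride >= TripCount. The sketch below is a simplified model under stated assumptions, not LLVM code: `Interval` and both function names are hypothetical, and closed integer intervals stand in for whatever ScalarEvolution can prove.

```cpp
#include <cassert>
#include <cstdint>

// Toy closed interval [Lo, Hi] standing in for what the analysis knows
// about a value. Hypothetical type, not an LLVM class.
struct Interval {
  int64_t Lo;
  int64_t Hi;
};

// True when Stride - BTC > 0 holds for every pair of values, i.e. when
// Stride >= TripCount is proven (TripCount == BTC + 1). Assumes small toy
// values, so the subtraction cannot overflow.
bool strideKnownAtLeastTripCount(Interval Stride, Interval BTC) {
  assert(Stride.Lo <= Stride.Hi && BTC.Lo <= BTC.Hi && "malformed interval");
  return Stride.Lo - BTC.Hi > 0;
}

// Pre-patch decision as modeled here: speculate Stride == 1 unless the
// comparison above is proven.
bool wouldSpeculateUnitStride(Interval Stride, Interval BTC) {
  return !strideKnownAtLeastTripCount(Stride, BTC);
}
```

In this model a loop whose backedge-taken count is exactly 0 but whose stride is unconstrained, for example `wouldSpeculateUnitStride({-100, 100}, {0, 0})`, is still versioned, which resembles the single-iteration case reported in #96656.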

@fhahn
Contributor

fhahn commented Jun 27, 2024

hmm that's interesting, I was looking at the same code a few weeks ago. I'll try to refresh my memory tomorrow

@artagnon artagnon changed the title LAA: don't speculate stride when BTC >= 0 LAA: don't speculate stride when loop is known to execute Jun 28, 2024
The current stride versioning code in collectStridedAccess is quite
fragile, and has implicit effects. Make it more robust by making it
clear that Stride - 1 == BTC is a special case, and operate on
ConstantRanges directly. Query the exact backedge-taken count instead of
the symbolic maximum of it. This patch has the side effect of making it
possible to directly return the SCEVUnknown under a cast in
getStrideFromPointer, eliminating a second cast-stripping in
collectStridedAccess. It also has the side-effect of a positive test
update in symbolic-stride.
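
The conditions described in this commit message can be sketched with toy ranges. This is a minimal model under stated assumptions, not the patch itself: `SignedRange` and all function names are hypothetical, the ranges are closed and cannot wrap, and the range comparison is modeled as arithmetic interval subtraction, which is an assumption about the intent rather than a model of LLVM's `ConstantRange` class or any of its methods. The symbolic `Stride - 1 == BTC` comparison works on SCEV expressions and is left out because ranges cannot express it.

```cpp
#include <cassert>
#include <cstdint>

// Toy closed signed interval [Lo, Hi]. Hypothetical type; unlike LLVM's
// ConstantRange it cannot wrap.
struct SignedRange {
  int64_t Lo;
  int64_t Hi;
};

// Every value in the range is strictly greater than zero.
bool isAllPositive(SignedRange R) { return R.Lo > 0; }

// Interval containing every possible A - B. Assumes small toy values, so
// the subtractions cannot overflow.
SignedRange subtract(SignedRange A, SignedRange B) {
  return {A.Lo - B.Hi, A.Hi - B.Lo};
}

// The backedge is known to be taken exactly zero times: one iteration.
bool isSingleIterationLoop(SignedRange BTC) {
  return BTC.Lo == 0 && BTC.Hi == 0;
}

// Modeled decision: do not add the Stride == 1 predicate for
// single-iteration loops, or when Stride - BTC is positive for all values.
bool skipStrideVersioning(SignedRange Stride, SignedRange BTC) {
  assert(Stride.Lo <= Stride.Hi && BTC.Lo <= BTC.Hi && "malformed range");
  return isSingleIterationLoop(BTC) || isAllPositive(subtract(Stride, BTC));
}
```

With an unconstrained stride and a backedge-taken count of exactly 0, `skipStrideVersioning({-100, 100}, {0, 0})` returns true in this model, so the single-iteration loop is left unversioned.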
@artagnon
Contributor Author

Okay, so I've investigated this problem thoroughly (producing #97075 as a side-effect), and concluded that this entire stride-versioning thing done by LAA is a rough heuristic that can't change, since too many callers depend on it. To fix the regression, I propose that we patch LoopVersioning directly, and not touch LAA. Any thoughts?

@fhahn
Contributor

fhahn commented Jun 28, 2024

> Okay, so I've investigated this problem thoroughly (producing #97075 as a side-effect), and concluded that this entire stride-versioning thing done by LAA is a rough heuristic that can't change, since too many callers depend on it. To fix the regression, I propose that we patch LoopVersioning directly, and not touch LAA. Any thoughts?

Yeah, versioning loops with a single iteration is unlikely to be profitable. It would be interesting to know who is calling LoopVersioning on such loops.

@artagnon artagnon changed the title LAA: don't speculate stride when loop is known to execute LAA: don't version single-iteration loops Jun 28, 2024
@artagnon artagnon marked this pull request as ready for review June 28, 2024 17:00
@llvmbot
Member

llvmbot commented Jun 28, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-ir

Author: Ramkumar Ramachandra (artagnon)

Changes

Single-iteration loops are a special case and are likely to be unprofitable to version.

Fixes #96656.


Full diff: https://github.com/llvm/llvm-project/pull/96927.diff

7 Files Affected:

  • (modified) llvm/include/llvm/IR/ConstantRange.h (+3)
  • (modified) llvm/lib/Analysis/LoopAccessAnalysis.cpp (+30-36)
  • (modified) llvm/lib/IR/ConstantRange.cpp (+5)
  • (modified) llvm/test/Analysis/LoopAccessAnalysis/pr96656.ll (-4)
  • (modified) llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll (+4-11)
  • (modified) llvm/test/Transforms/LoopVectorize/version-stride-with-integer-casts.ll (+16-9)
  • (modified) llvm/test/Transforms/LoopVersioning/pr96656.ll (+4-22)
diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
index 7b94b9c6c6d11..86d0a6b35d748 100644
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -277,6 +277,9 @@ class [[nodiscard]] ConstantRange {
   /// Return true if all values in this range are non-negative.
   bool isAllNonNegative() const;
 
+  /// Return true if all values in this range are positive.
+  bool isAllPositive() const;
+
   /// Return the largest unsigned value contained in the ConstantRange.
   APInt getUnsignedMax() const;
 
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 38bf6d8160aa9..9c21fb1c28eb4 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -35,6 +35,7 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugLoc.h"
@@ -61,6 +62,8 @@
 #include <cassert>
 #include <cstdint>
 #include <iterator>
+#include <optional>
+#include <sys/types.h>
 #include <utility>
 #include <variant>
 #include <vector>
@@ -2914,7 +2917,7 @@ static const SCEV *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *L
 
   if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(V))
     if (isa<SCEVUnknown>(C->getOperand()))
-      return V;
+      return C->getOperand();
 
   return nullptr;
 }
@@ -2930,7 +2933,8 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
   // computation of an interesting IV - but we chose not to as we
   // don't have a cost model here, and broadening the scope exposes
   // far too many unprofitable cases.
-  const SCEV *StrideExpr = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
+  ScalarEvolution *SE = PSE->getSE();
+  const SCEV *StrideExpr = getStrideFromPointer(Ptr, SE, TheLoop);
   if (!StrideExpr)
     return;
 
@@ -2943,10 +2947,6 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
     return;
   }
 
-  // Avoid adding the "Stride == 1" predicate when we know that
-  // Stride >= Trip-Count. Such a predicate will effectively optimize a single
-  // or zero iteration loop, as Trip-Count <= Stride == 1.
-  //
   // TODO: We are currently not making a very informed decision on when it is
   // beneficial to apply stride versioning. It might make more sense that the
   // users of this analysis (such as the vectorizer) will trigger it, based on
@@ -2956,40 +2956,34 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
   // of various possible stride specializations, considering the alternatives
   // of using gather/scatters (if available).
 
-  const SCEV *MaxBTC = PSE->getSymbolicMaxBackedgeTakenCount();
-
-  // Match the types so we can compare the stride and the MaxBTC.
-  // The Stride can be positive/negative, so we sign extend Stride;
-  // The backedgeTakenCount is non-negative, so we zero extend MaxBTC.
-  const DataLayout &DL = TheLoop->getHeader()->getDataLayout();
-  uint64_t StrideTypeSizeBits = DL.getTypeSizeInBits(StrideExpr->getType());
-  uint64_t BETypeSizeBits = DL.getTypeSizeInBits(MaxBTC->getType());
-  const SCEV *CastedStride = StrideExpr;
-  const SCEV *CastedBECount = MaxBTC;
-  ScalarEvolution *SE = PSE->getSE();
-  if (BETypeSizeBits >= StrideTypeSizeBits)
-    CastedStride = SE->getNoopOrSignExtend(StrideExpr, MaxBTC->getType());
-  else
-    CastedBECount = SE->getZeroExtendExpr(MaxBTC, StrideExpr->getType());
-  const SCEV *StrideMinusBETaken = SE->getMinusSCEV(CastedStride, CastedBECount);
-  // Since TripCount == BackEdgeTakenCount + 1, checking:
-  // "Stride >= TripCount" is equivalent to checking:
-  // Stride - MaxBTC> 0
-  if (SE->isKnownPositive(StrideMinusBETaken)) {
-    LLVM_DEBUG(
-        dbgs() << "LAA: Stride>=TripCount; No point in versioning as the "
-                  "Stride==1 predicate will imply that the loop executes "
-                  "at most once.\n");
+  // Get two signed ranges and compare them, after adjusting for bitwidth. BTC
+  // range could extend into -1.
+  const SCEV *BTC = PSE->getBackedgeTakenCount();
+  ConstantRange BTCRange = SE->getSignedRange(BTC);
+  ConstantRange StrideRange =
+      SE->getSignedRange(StrideExpr).sextOrTrunc(BTCRange.getBitWidth());
+
+  // Stride is zero-extended to compare with BTC.
+  const SCEV *CastedStride =
+      SE->getTruncateOrZeroExtend(StrideExpr, BTC->getType());
+  const SCEV *StrideMinusOne =
+      SE->getMinusSCEV(CastedStride, SE->getOne(CastedStride->getType()));
+
+  bool IsSingleIterationLoop =
+      BTCRange.isSingleElement() && BTCRange.getSingleElement()->isZero();
+
+  // Stride - 1 exactly equal to BTC is a special case for which the loop should
+  // not be versioned. Single-iteration loops are likely unprofitable to
+  // version. Otherwise, the loop should not be versioned if the range
+  // difference is all positive.
+  if (StrideMinusOne == BTC || IsSingleIterationLoop ||
+      StrideRange.difference(BTCRange).isAllPositive()) {
+    LLVM_DEBUG(dbgs() << "LAA: Not versioning with Stride==1 predicate.\n");
     return;
   }
   LLVM_DEBUG(dbgs() << "LAA: Found a strided access that we can version.\n");
 
-  // Strip back off the integer cast, and check that our result is a
-  // SCEVUnknown as we expect.
-  const SCEV *StrideBase = StrideExpr;
-  if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(StrideBase))
-    StrideBase = C->getOperand();
-  SymbolicStrides[Ptr] = cast<SCEVUnknown>(StrideBase);
+  SymbolicStrides[Ptr] = cast<SCEVUnknown>(StrideExpr);
 }
 
 LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 19041704a40be..b942894d34467 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -440,6 +440,11 @@ bool ConstantRange::isAllNonNegative() const {
   return !isSignWrappedSet() && Lower.isNonNegative();
 }
 
+bool ConstantRange::isAllPositive() const {
+  // Empty and full set are automatically treated correctly.
+  return !isSignWrappedSet() && Lower.isStrictlyPositive();
+}
+
 APInt ConstantRange::getUnsignedMax() const {
   if (isFullSet() || isUpperWrapped())
     return APInt::getMaxValue(getBitWidth());
diff --git a/llvm/test/Analysis/LoopAccessAnalysis/pr96656.ll b/llvm/test/Analysis/LoopAccessAnalysis/pr96656.ll
index 5b9833553fa02..cc7aca6f4f8f6 100644
--- a/llvm/test/Analysis/LoopAccessAnalysis/pr96656.ll
+++ b/llvm/test/Analysis/LoopAccessAnalysis/pr96656.ll
@@ -11,12 +11,8 @@ define void @false.equal.predicate(ptr %arg, ptr %arg1, i1 %arg2) {
 ; CHECK-EMPTY:
 ; CHECK-NEXT:      Non vectorizable stores to invariant address were not found in loop.
 ; CHECK-NEXT:      SCEV assumptions:
-; CHECK-NEXT:      Equal predicate: %load == 1
 ; CHECK-EMPTY:
 ; CHECK-NEXT:      Expressions re-written:
-; CHECK-NEXT:      [PSE] %gep10 = getelementptr double, ptr %gep8, i64 %mul:
-; CHECK-NEXT:        {(8 + %arg1),+,(8 * (sext i32 %load to i64))<nsw>}<%loop.body>
-; CHECK-NEXT:        --> {(8 + %arg1),+,8}<%loop.body>
 ;
 entry:
   %load = load i32, ptr %arg, align 4
diff --git a/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll b/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll
index 7c1b11e22aef2..e9aeac7ac2bc5 100644
--- a/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll
+++ b/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll
@@ -170,23 +170,16 @@ define void @single_stride_castexpr_multiuse(i32 %offset, ptr %src, ptr %dst, i1
 ; CHECK-NEXT:          %gep.src = getelementptr inbounds i32, ptr %src, i64 %iv.3
 ; CHECK-NEXT:      Grouped accesses:
 ; CHECK-NEXT:        Group [[GRP3]]:
-; CHECK-NEXT:          (Low: ((4 * %iv.1) + %dst) High: (804 + (4 * %iv.1) + (-4 * (zext i32 %offset to i64))<nsw> + %dst))
-; CHECK-NEXT:            Member: {((4 * %iv.1) + %dst),+,4}<%inner.loop>
+; CHECK-NEXT:          (Low: (((4 * %iv.1) + %dst) umin ((4 * %iv.1) + (4 * (sext i32 %offset to i64) * (200 + (-1 * (zext i32 %offset to i64))<nsw>)<nsw>) + %dst)) High: (4 + (((4 * %iv.1) + %dst) umax ((4 * %iv.1) + (4 * (sext i32 %offset to i64) * (200 + (-1 * (zext i32 %offset to i64))<nsw>)<nsw>) + %dst))))
+; CHECK-NEXT:            Member: {((4 * %iv.1) + %dst),+,(4 * (sext i32 %offset to i64))<nsw>}<%inner.loop>
 ; CHECK-NEXT:        Group [[GRP4]]:
-; CHECK-NEXT:          (Low: (4 + %src) High: (808 + (-4 * (zext i32 %offset to i64))<nsw> + %src))
-; CHECK-NEXT:            Member: {(4 + %src),+,4}<%inner.loop>
+; CHECK-NEXT:          (Low: ((4 * (zext i32 %offset to i64))<nuw><nsw> + %src) High: (804 + %src))
+; CHECK-NEXT:            Member: {((4 * (zext i32 %offset to i64))<nuw><nsw> + %src),+,4}<%inner.loop>
 ; CHECK-EMPTY:
 ; CHECK-NEXT:      Non vectorizable stores to invariant address were not found in loop.
 ; CHECK-NEXT:      SCEV assumptions:
-; CHECK-NEXT:      Equal predicate: %offset == 1
 ; CHECK-EMPTY:
 ; CHECK-NEXT:      Expressions re-written:
-; CHECK-NEXT:      [PSE] %gep.src = getelementptr inbounds i32, ptr %src, i64 %iv.3:
-; CHECK-NEXT:        {((4 * (zext i32 %offset to i64))<nuw><nsw> + %src),+,4}<%inner.loop>
-; CHECK-NEXT:        --> {(4 + %src),+,4}<%inner.loop>
-; CHECK-NEXT:      [PSE] %gep.dst = getelementptr i32, ptr %dst, i64 %iv.2:
-; CHECK-NEXT:        {((4 * %iv.1) + %dst),+,(4 * (sext i32 %offset to i64))<nsw>}<%inner.loop>
-; CHECK-NEXT:        --> {((4 * %iv.1) + %dst),+,4}<%inner.loop>
 ; CHECK-NEXT:    outer.header:
 ; CHECK-NEXT:      Report: loop is not the innermost loop
 ; CHECK-NEXT:      Dependences:
diff --git a/llvm/test/Transforms/LoopVectorize/version-stride-with-integer-casts.ll b/llvm/test/Transforms/LoopVectorize/version-stride-with-integer-casts.ll
index 45596169da3cc..f50b0e8be2cb3 100644
--- a/llvm/test/Transforms/LoopVectorize/version-stride-with-integer-casts.ll
+++ b/llvm/test/Transforms/LoopVectorize/version-stride-with-integer-casts.ll
@@ -490,10 +490,7 @@ define void @sext_of_i1_stride(i1 %g, ptr %dst) mustprogress {
 ; CHECK-NEXT:    [[TMP1:%.*]] = udiv i64 [[TMP0]], [[G_64]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = add nuw nsw i64 [[TMP1]], 1
 ; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP2]], 4
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_SCEVCHECK:%.*]]
-; CHECK:       vector.scevcheck:
-; CHECK-NEXT:    [[IDENT_CHECK:%.*]] = icmp ne i1 [[G]], true
-; CHECK-NEXT:    br i1 [[IDENT_CHECK]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
 ; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP2]], 4
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP2]], [[N_MOD_VF]]
@@ -504,17 +501,27 @@ define void @sext_of_i1_stride(i1 %g, ptr %dst) mustprogress {
 ; CHECK-NEXT:    [[OFFSET_IDX:%.*]] = mul i64 [[INDEX]], [[G_64]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 0, [[G_64]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = add i64 [[OFFSET_IDX]], [[TMP3]]
+; CHECK-NEXT:    [[TMP11:%.*]] = mul i64 1, [[G_64]]
+; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[OFFSET_IDX]], [[TMP11]]
+; CHECK-NEXT:    [[TMP7:%.*]] = mul i64 2, [[G_64]]
+; CHECK-NEXT:    [[TMP8:%.*]] = add i64 [[OFFSET_IDX]], [[TMP7]]
+; CHECK-NEXT:    [[TMP9:%.*]] = mul i64 3, [[G_64]]
+; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[OFFSET_IDX]], [[TMP9]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i16, ptr [[DST]], i64 [[TMP4]]
-; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds i16, ptr [[TMP5]], i32 0
-; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i16, ptr [[TMP6]], i32 -3
-; CHECK-NEXT:    store <4 x i16> <i16 -1, i16 -1, i16 -1, i16 -1>, ptr [[TMP7]], align 2
+; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr inbounds i16, ptr [[DST]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr inbounds i16, ptr [[DST]], i64 [[TMP8]]
+; CHECK-NEXT:    [[TMP14:%.*]] = getelementptr inbounds i16, ptr [[DST]], i64 [[TMP10]]
+; CHECK-NEXT:    store i16 [[G_16]], ptr [[TMP5]], align 2
+; CHECK-NEXT:    store i16 [[G_16]], ptr [[TMP12]], align 2
+; CHECK-NEXT:    store i16 [[G_16]], ptr [[TMP13]], align 2
+; CHECK-NEXT:    store i16 [[G_16]], ptr [[TMP14]], align 2
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
 ; CHECK-NEXT:    br i1 true, label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP14:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[TMP2]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[IND_END]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ], [ 0, [[VECTOR_SCEVCHECK]] ]
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[IND_END]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
 ; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], [[LOOP]] ]
@@ -560,5 +567,5 @@ exit:
 ; CHECK: [[LOOP12]] = distinct !{[[LOOP12]], [[META1]], [[META2]]}
 ; CHECK: [[LOOP13]] = distinct !{[[LOOP13]], [[META1]]}
 ; CHECK: [[LOOP14]] = distinct !{[[LOOP14]], [[META1]], [[META2]]}
-; CHECK: [[LOOP15]] = distinct !{[[LOOP15]], [[META1]]}
+; CHECK: [[LOOP15]] = distinct !{[[LOOP15]], [[META2]], [[META1]]}
 ;.
diff --git a/llvm/test/Transforms/LoopVersioning/pr96656.ll b/llvm/test/Transforms/LoopVersioning/pr96656.ll
index 0264fe40a9430..2ef8ccbb8f9d1 100644
--- a/llvm/test/Transforms/LoopVersioning/pr96656.ll
+++ b/llvm/test/Transforms/LoopVersioning/pr96656.ll
@@ -6,44 +6,26 @@ define void @lver.check.unnecessary(ptr %arg, ptr %arg1, i1 %arg2) {
 ; CHECK-SAME: ptr [[ARG:%.*]], ptr [[ARG1:%.*]], i1 [[ARG2:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
 ; CHECK-NEXT:    [[LOAD:%.*]] = load i32, ptr [[ARG]], align 4
-; CHECK-NEXT:    br i1 [[ARG2]], label %[[NOLOOP_EXIT:.*]], label %[[LOOP_BODY_LVER_CHECK:.*]]
-; CHECK:       [[LOOP_BODY_LVER_CHECK]]:
+; CHECK-NEXT:    br i1 [[ARG2]], label %[[NOLOOP_EXIT:.*]], label %[[LOOP_PH:.*]]
+; CHECK:       [[LOOP_PH]]:
 ; CHECK-NEXT:    [[SEXT7:%.*]] = sext i32 [[LOAD]] to i64
 ; CHECK-NEXT:    [[GEP8:%.*]] = getelementptr i8, ptr [[ARG1]], i64 8
-; CHECK-NEXT:    [[IDENT_CHECK:%.*]] = icmp ne i32 [[LOAD]], 1
-; CHECK-NEXT:    br i1 [[IDENT_CHECK]], label %[[LOOP_BODY_PH_LVER_ORIG:.*]], label %[[LOOP_BODY_PH:.*]]
-; CHECK:       [[LOOP_BODY_PH_LVER_ORIG]]:
-; CHECK-NEXT:    br label %[[LOOP_BODY_LVER_ORIG:.*]]
-; CHECK:       [[LOOP_BODY_LVER_ORIG]]:
-; CHECK-NEXT:    [[PHI_LVER_ORIG:%.*]] = phi i64 [ 0, %[[LOOP_BODY_PH_LVER_ORIG]] ], [ [[ADD_LVER_ORIG:%.*]], %[[LOOP_BODY_LVER_ORIG]] ]
-; CHECK-NEXT:    [[MUL_LVER_ORIG:%.*]] = mul i64 [[PHI_LVER_ORIG]], [[SEXT7]]
-; CHECK-NEXT:    [[GEP10_LVER_ORIG:%.*]] = getelementptr double, ptr [[GEP8]], i64 [[MUL_LVER_ORIG]]
-; CHECK-NEXT:    [[LOAD11_LVER_ORIG:%.*]] = load double, ptr [[GEP10_LVER_ORIG]], align 8
-; CHECK-NEXT:    store double [[LOAD11_LVER_ORIG]], ptr [[ARG1]], align 8
-; CHECK-NEXT:    [[ADD_LVER_ORIG]] = add i64 [[PHI_LVER_ORIG]], 1
-; CHECK-NEXT:    [[ICMP_LVER_ORIG:%.*]] = icmp eq i64 [[PHI_LVER_ORIG]], 0
-; CHECK-NEXT:    br i1 [[ICMP_LVER_ORIG]], label %[[LOOP_EXIT_LOOPEXIT:.*]], label %[[LOOP_BODY_LVER_ORIG]]
-; CHECK:       [[LOOP_BODY_PH]]:
 ; CHECK-NEXT:    br label %[[LOOP_BODY:.*]]
 ; CHECK:       [[LOOP_BODY]]:
-; CHECK-NEXT:    [[PHI:%.*]] = phi i64 [ 0, %[[LOOP_BODY_PH]] ], [ [[ADD:%.*]], %[[LOOP_BODY]] ]
+; CHECK-NEXT:    [[PHI:%.*]] = phi i64 [ 0, %[[LOOP_PH]] ], [ [[ADD:%.*]], %[[LOOP_BODY]] ]
 ; CHECK-NEXT:    [[MUL:%.*]] = mul i64 [[PHI]], [[SEXT7]]
 ; CHECK-NEXT:    [[GEP10:%.*]] = getelementptr double, ptr [[GEP8]], i64 [[MUL]]
 ; CHECK-NEXT:    [[LOAD11:%.*]] = load double, ptr [[GEP10]], align 8
 ; CHECK-NEXT:    store double [[LOAD11]], ptr [[ARG1]], align 8
 ; CHECK-NEXT:    [[ADD]] = add i64 [[PHI]], 1
 ; CHECK-NEXT:    [[ICMP:%.*]] = icmp eq i64 [[PHI]], 0
-; CHECK-NEXT:    br i1 [[ICMP]], label %[[LOOP_EXIT_LOOPEXIT1:.*]], label %[[LOOP_BODY]]
+; CHECK-NEXT:    br i1 [[ICMP]], label %[[LOOP_EXIT:.*]], label %[[LOOP_BODY]]
 ; CHECK:       [[NOLOOP_EXIT]]:
 ; CHECK-NEXT:    [[SEXT:%.*]] = sext i32 [[LOAD]] to i64
 ; CHECK-NEXT:    [[GEP:%.*]] = getelementptr double, ptr [[ARG1]], i64 [[SEXT]]
 ; CHECK-NEXT:    [[LOAD5:%.*]] = load double, ptr [[GEP]], align 8
 ; CHECK-NEXT:    store double [[LOAD5]], ptr [[ARG]], align 8
 ; CHECK-NEXT:    ret void
-; CHECK:       [[LOOP_EXIT_LOOPEXIT]]:
-; CHECK-NEXT:    br label %[[LOOP_EXIT:.*]]
-; CHECK:       [[LOOP_EXIT_LOOPEXIT1]]:
-; CHECK-NEXT:    br label %[[LOOP_EXIT]]
 ; CHECK:       [[LOOP_EXIT]]:
 ; CHECK-NEXT:    ret void
 ;

@artagnon
Contributor Author

Thanks for the guidance. I realized that this was the case, but wasn't sure about adding a special-case for single-iteration loops.

@fhahn
Contributor

fhahn commented Jun 28, 2024

> Thanks for the guidance. I realized that this was the case, but wasn't sure about adding a special-case for single-iteration loops.

Yeah, as I mentioned earlier, it would also be good to understand how we get to the point that someone asks loop-versioning to version such a loop. It might need to be fixed in the passes that request versioning, based on the benefit those passes get from versioning. Depending on the pass, other small trip counts may also not be profitable, but ideally that would be decided by the pass that requests versioning, since more information is available there.
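
A caller-side guard of the kind suggested here might look roughly like the sketch below. It is purely illustrative: `shouldRequestVersioning` and its parameters are hypothetical names, not an existing LLVM API, and the profitability threshold is assumed to be chosen by each pass.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical helper, not an LLVM API. A pass would call this before
// asking for a versioned loop. KnownTripCount is empty when the trip count
// could not be computed; MinProfitableTripCount is the pass's own threshold.
bool shouldRequestVersioning(std::optional<uint64_t> KnownTripCount,
                             uint64_t MinProfitableTripCount) {
  assert(MinProfitableTripCount > 0 && "threshold must be at least 1");
  // Unknown trip count: leave the decision to the pass's other heuristics.
  if (!KnownTripCount)
    return true;
  // Known trip count: version only when the loop runs often enough.
  return *KnownTripCount >= MinProfitableTripCount;
}
```

With a threshold of 2, a loop known to run once is rejected, while a loop with an unknown trip count is still passed on for the pass to judge by other means.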

@artagnon
Contributor Author

Right, got it. I won't re-purpose this PR further, and will open a fresh PR fixing the callers of LoopVersioning.

@artagnon artagnon closed this Jun 28, 2024
@artagnon artagnon deleted the laa-96656 branch June 28, 2024 17:16
Successfully merging this pull request may close these issues.

LoopLoadElim: calling LoopVersioning with single-iteration loop
4 participants