From 20d8738291b50c98faa916af4a9f33bc17773d05 Mon Sep 17 00:00:00 2001 From: Danial Klimkin Date: Wed, 2 Oct 2024 21:05:42 +0200 Subject: [PATCH] Update MustExitScalarEvolution.cpp for upstream LLVM changes (#2098) * Update MustExitScalarEvolution.cpp for upstream LLVM changes Fixes issue #2097 * Fix formatting as prescribed. --- enzyme/Enzyme/MustExitScalarEvolution.cpp | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/enzyme/Enzyme/MustExitScalarEvolution.cpp b/enzyme/Enzyme/MustExitScalarEvolution.cpp index 2ba7319c131f..dff36303f9ca 100644 --- a/enzyme/Enzyme/MustExitScalarEvolution.cpp +++ b/enzyme/Enzyme/MustExitScalarEvolution.cpp @@ -27,6 +27,7 @@ #include "MustExitScalarEvolution.h" #include "FunctionUtils.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolution.h" @@ -229,12 +230,15 @@ MustExitScalarEvolution::computeExitLimitFromCondImpl( !isa(BECount)) MaxBECount = getConstant(getUnsignedRangeMax(BECount)); -#if LLVM_VERSION_MAJOR >= 16 +#if LLVM_VERSION_MAJOR >= 20 + return ExitLimit(BECount, MaxBECount, MaxBECount, false, + {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)}); +#elif LLVM_VERSION_MAJOR >= 16 return ExitLimit(BECount, MaxBECount, MaxBECount, false, {&EL0.Predicates, &EL1.Predicates}); #else - return ExitLimit(BECount, MaxBECount, false, - {&EL0.Predicates, &EL1.Predicates}); + return ExitLimit(BECount, MaxBECount, false, + {&EL0.Predicates, &EL1.Predicates}); #endif } if (BO->getOpcode() == Instruction::Or) { @@ -273,8 +277,13 @@ MustExitScalarEvolution::computeExitLimitFromCondImpl( if (EL0.ExactNotTaken == EL1.ExactNotTaken) BECount = EL0.ExactNotTaken; } +#if LLVM_VERSION_MAJOR >= 20 + return ExitLimit(BECount, MaxBECount, MaxBECount, false, + {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)}); +#else return ExitLimit(BECount, MaxBECount, MaxBECount, false, {&EL0.Predicates, &EL1.Predicates}); +#endif #else if (EL0.MaxNotTaken == getCouldNotCompute()) MaxBECount = EL1.MaxNotTaken; @@ -741,7 +750,11 @@ static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, ScalarEvolution::ExitLimit MustExitScalarEvolution::howManyLessThans( const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, bool ControlsExit, bool AllowPredicates) { +#if LLVM_VERSION_MAJOR >= 20 + SmallVector Predicates; +#else SmallPtrSet Predicates; +#endif const SCEVAddRecExpr *IV = dyn_cast(LHS); bool PredicatedIV = false;