Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LLVM integrate fixes #2160

Merged
merged 1 commit into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 4 additions & 39 deletions enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1442,20 +1442,9 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {

if (auto LI = dyn_cast<LoadInst>(TmpOrig))
return isConstantValue(TR, LI->getPointerOperand());
if (isa<IntrinsicInst>(TmpOrig) &&
(cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldu_global_i ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldu_global_p ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldu_global_f ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldg_global_i ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldg_global_p ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldg_global_f))
if (isNVLoad(TmpOrig)) {
return isConstantValue(TR, cast<Instruction>(TmpOrig)->getOperand(0));
}

if (TmpOrig == Val)
return false;
Expand Down Expand Up @@ -1547,19 +1536,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
TmpOrig);
}
}
} else if (isa<IntrinsicInst>(TmpOrig) &&
(cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldu_global_i ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldu_global_p ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldu_global_f ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldg_global_i ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldg_global_p ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldg_global_f)) {
} else if (isNVLoad(TmpOrig)) {
auto II = cast<IntrinsicInst>(TmpOrig);
if (directions == UP) {
if (isConstantValue(TR, II->getOperand(0))) {
Expand Down Expand Up @@ -1950,19 +1927,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
isRefSet(AARes)) {
if (EnzymePrintActivity)
llvm::errs() << "potential active load: " << *I << "\n";
if (isa<LoadInst>(I) || (isa<IntrinsicInst>(I) &&
(cast<IntrinsicInst>(I)->getIntrinsicID() ==
Intrinsic::nvvm_ldu_global_i ||
cast<IntrinsicInst>(I)->getIntrinsicID() ==
Intrinsic::nvvm_ldu_global_p ||
cast<IntrinsicInst>(I)->getIntrinsicID() ==
Intrinsic::nvvm_ldu_global_f ||
cast<IntrinsicInst>(I)->getIntrinsicID() ==
Intrinsic::nvvm_ldg_global_i ||
cast<IntrinsicInst>(I)->getIntrinsicID() ==
Intrinsic::nvvm_ldg_global_p ||
cast<IntrinsicInst>(I)->getIntrinsicID() ==
Intrinsic::nvvm_ldg_global_f))) {
if (isa<LoadInst>(I) || isNVLoad(I)) {
// If the ref'ing value is a load check if the loaded value is
// active
if (!Hypothesis->isConstantValue(TR, I)) {
Expand Down
10 changes: 6 additions & 4 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -3725,12 +3725,14 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
Module *M = I.getParent()->getParent()->getParent();

switch (ID) {
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f:
#if LLVM_VERSION_MAJOR < 20
case Intrinsic::nvvm_ldg_global_i:
case Intrinsic::nvvm_ldg_global_p:
case Intrinsic::nvvm_ldg_global_f: {
case Intrinsic::nvvm_ldg_global_f:
#endif
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f: {
auto CI = cast<ConstantInt>(I.getOperand(1));
visitLoadLike(I, /*Align*/ MaybeAlign(CI->getZExtValue()),
/*constantval*/ false);
Expand Down
11 changes: 1 addition & 10 deletions enzyme/Enzyme/DifferentialUseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -949,17 +949,8 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI,
if (isAllocationCall(V, TLI) || isa<AllocaInst>(V)) {
auto next = (*mp.begin()).V;
bool noncapture = false;
if (isa<LoadInst>(next)) {
if (isa<LoadInst>(next) || isNVLoad(next)) {
noncapture = true;
} else if (auto II = dyn_cast<IntrinsicInst>(next)) {
if (II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_i ||
II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_p ||
II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_f ||
II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_i ||
II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_p ||
II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_f ||
II->getIntrinsicID() == Intrinsic::masked_load)
noncapture = true;
} else if (auto CI = dyn_cast<CallInst>(next)) {
bool captures = false;
for (size_t i = 0; i < CI->arg_size(); i++) {
Expand Down
12 changes: 11 additions & 1 deletion enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3360,7 +3360,13 @@ void augmentPassBuilder(llvm::PassBuilder &PB) {
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizePM)));
};

auto loadPass = [prePass](ModulePassManager &MPM, OptimizationLevel Level) {
#if LLVM_VERSION_MAJOR >= 20
auto loadPass = [prePass](ModulePassManager &MPM, OptimizationLevel Level,
ThinOrFullLTOPhase)
#else
auto loadPass = [prePass](ModulePassManager &MPM, OptimizationLevel Level)
#endif
{
MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true));

if (!EnzymeEnable)
Expand Down Expand Up @@ -3643,7 +3649,11 @@ void augmentPassBuilder(llvm::PassBuilder &PB) {
LPM.addPass(LoopDeletionPass());
// FIXME: Add loop interchange.

#if LLVM_VERSION_MAJOR >= 20
loadPass(MPM, Level, ThinOrFullLTOPhase::None);
#else
loadPass(MPM, Level);
#endif
};
PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadLTO);
}
Expand Down
30 changes: 7 additions & 23 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,21 +422,14 @@ struct CacheAnalysis {
continue;
for (auto &inst : B) {
// For each load instruction, determine if it is uncacheable.
if (auto op = dyn_cast<LoadInst>(&inst)) {
can_modref_map[op] = is_load_uncacheable(*op);
}
if (auto II = dyn_cast<IntrinsicInst>(&inst)) {
if (isa<LoadInst>(&inst)) {
can_modref_map[&inst] = is_load_uncacheable(inst);
} else if (isNVLoad(&inst)) {
can_modref_map[&inst] = false;
} else if (auto II = dyn_cast<IntrinsicInst>(&inst)) {
switch (II->getIntrinsicID()) {
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f:
case Intrinsic::nvvm_ldg_global_i:
case Intrinsic::nvvm_ldg_global_p:
case Intrinsic::nvvm_ldg_global_f:
can_modref_map[II] = false;
break;
case Intrinsic::masked_load:
can_modref_map[II] = is_load_uncacheable(*II);
can_modref_map[&inst] = is_load_uncacheable(inst);
break;
default:
break;
Expand Down Expand Up @@ -5364,20 +5357,11 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
llvm::SmallVectorImpl<llvm::Value *> &orig_ops) {
using namespace llvm;

switch (ID) {
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f:
case Intrinsic::nvvm_ldg_global_i:
case Intrinsic::nvvm_ldg_global_p:
case Intrinsic::nvvm_ldg_global_f: {
if (isNVLoad(&I)) {
auto CI = cast<ConstantInt>(I.getOperand(1));
visitLoadLike(I, /*Align*/ MaybeAlign(CI->getZExtValue()));
return true;
}
default:
break;
}

if (ID == Intrinsic::masked_store) {
auto align0 = cast<ConstantInt>(I.getOperand(2))->getZExtValue();
Expand Down
53 changes: 13 additions & 40 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3904,15 +3904,9 @@ bool GradientUtils::legalRecompute(const Value *val,
if (auto li = dyn_cast<Instruction>(val)) {

const IntrinsicInst *II;
if (isa<LoadInst>(li) ||
if (isa<LoadInst>(li) || isNVLoad(li) ||
((II = dyn_cast<IntrinsicInst>(li)) &&
(II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_i ||
II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_p ||
II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_f ||
II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_i ||
II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_p ||
II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_f ||
II->getIntrinsicID() == Intrinsic::masked_load))) {
(II->getIntrinsicID() == Intrinsic::masked_load))) {
// If this is an already unwrapped value, legal to recompute again.
if (unwrappedLoads.find(li) != unwrappedLoads.end())
return legalRecompute(unwrappedLoads.find(li)->second, available,
Expand Down Expand Up @@ -4174,7 +4168,7 @@ bool GradientUtils::shouldRecompute(const Value *val,
}

if (auto op = dyn_cast<IntrinsicInst>(val)) {
if (!op->mayReadOrWriteMemory() || isReadNone(op))
if (!op->mayReadOrWriteMemory() || isReadNone(op) || isNVLoad(op))
return true;
switch (op->getIntrinsicID()) {
case Intrinsic::sin:
Expand All @@ -4186,12 +4180,6 @@ bool GradientUtils::shouldRecompute(const Value *val,
case Intrinsic::sinh:
#endif
case Intrinsic::log:
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f:
case Intrinsic::nvvm_ldg_global_i:
case Intrinsic::nvvm_ldg_global_p:
case Intrinsic::nvvm_ldg_global_f:
return true;
default:
return false;
Expand Down Expand Up @@ -6109,12 +6097,14 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
switch (II->getIntrinsicID()) {
default:
goto end;
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f:
#if LLVM_VERSION_MAJOR < 20
case Intrinsic::nvvm_ldg_global_i:
case Intrinsic::nvvm_ldg_global_p:
case Intrinsic::nvvm_ldg_global_f: {
case Intrinsic::nvvm_ldg_global_f:
#endif
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f: {
return applyChainRule(
II->getType(), bb,
[&](Value *ptr) {
Expand Down Expand Up @@ -6388,19 +6378,8 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
bool reduceRegister = false;

if (EnzymeRegisterReduce) {
if (auto II = dyn_cast<IntrinsicInst>(inst)) {
switch (II->getIntrinsicID()) {
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f:
case Intrinsic::nvvm_ldg_global_i:
case Intrinsic::nvvm_ldg_global_p:
case Intrinsic::nvvm_ldg_global_f:
reduceRegister = true;
break;
default:
break;
}
if (isNVLoad(inst)) {
reduceRegister = true;
}
if (auto LI = dyn_cast<LoadInst>(inst)) {
auto Arch =
Expand Down Expand Up @@ -9526,17 +9505,11 @@ bool GradientUtils::needsCacheWholeAllocation(
continue;
seen.insert(pair);
// Loads are always fine
if (isa<LoadInst>(cur))
if (isa<LoadInst>(cur) || isNVLoad(cur))
continue;

if (auto II = dyn_cast<IntrinsicInst>(cur))
if (II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_i ||
II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_p ||
II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_f ||
II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_i ||
II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_p ||
II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_f ||
II->getIntrinsicID() == Intrinsic::masked_load)
if (II->getIntrinsicID() == Intrinsic::masked_load)
continue;

bool returnedSameValue = false;
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ set(LIBS
MLIREnzymeTransforms
MLIREnzyme
MLIROptLib
MLIRFuncInlinerExtension
)
add_llvm_executable(enzymemlir-opt enzymemlir-opt.cpp)

Expand Down
6 changes: 3 additions & 3 deletions enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ struct AddToOpToIndexAndLoadPass
cacheBuilder.setInsertionPoint(terminator);

// Is it a fine assumption that all indexing maps are the same?
for (int i = 0; i < map[0].getNumDims(); i++) {
for (size_t i = 0; i < map[0].getNumDims(); i++) {
indices.push_back(cacheBuilder.create<linalg::IndexOp>(loc, i));
}

SmallVector<Value> rets;
for (int i = 0; i < retargs.size(); i++) {
for (size_t i = 0; i < retargs.size(); i++) {
// auto load = cacheBuilder.create<AffineLoadOp>(loc, inputs[i], map[i],
// indices); auto store = cacheBuilder.create<AffineStoreOp>(loc, load,
// inputs[i], map[i], indices);
Expand All @@ -95,7 +95,7 @@ struct AddToOpToIndexAndLoadPass
mapAppliedIndices);
}

for (int i = 0; i < retargs.size(); i++) {
for (size_t i = 0; i < retargs.size(); i++) {
SmallVector<Value> mapAppliedIndices =
applyAffineMap(map[num_ins + i], indices, cacheBuilder, loc);
auto load = cacheBuilder.create<memref::LoadOp>(loc, outs[i],
Expand Down
10 changes: 6 additions & 4 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3898,12 +3898,14 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) {
return;
}

case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f:
#if LLVM_VERSION_MAJOR < 20
case Intrinsic::nvvm_ldg_global_i:
case Intrinsic::nvvm_ldg_global_p:
case Intrinsic::nvvm_ldg_global_f: {
case Intrinsic::nvvm_ldg_global_f:
#endif
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f: {
auto &DL = I.getParent()->getParent()->getParent()->getDataLayout();
auto LoadSize = (DL.getTypeSizeInBits(I.getType()) + 7) / 8;

Expand Down
20 changes: 20 additions & 0 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3790,3 +3790,23 @@ void dumpBlock(llvm::BasicBlock *blk) { llvm::errs() << *blk << "\n"; }
void dumpType(llvm::Type *ty) { llvm::errs() << *ty << "\n"; }

void dumpTypeResults(TypeResults &TR) { TR.dump(); }

bool isNVLoad(const llvm::Value *V) {
auto II = dyn_cast<IntrinsicInst>(V);
if (!II)
return false;
switch (II->getIntrinsicID()) {
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f:
#if LLVM_VERSION_MAJOR < 20
case Intrinsic::nvvm_ldg_global_i:
case Intrinsic::nvvm_ldg_global_p:
case Intrinsic::nvvm_ldg_global_f:
#endif
return true;
default:
return false;
}
return false;
}
3 changes: 3 additions & 0 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2095,4 +2095,7 @@ llvm::CallInst *createIntrinsicCall(llvm::IRBuilderBase &B,
llvm::ArrayRef<llvm::Value *> Args,
llvm::Instruction *FMFSource = nullptr,
const llvm::Twine &Name = "");

bool isNVLoad(const llvm::Value *V);

#endif // ENZYME_UTILS_H
Loading