Skip to content

Commit

Permalink
Fix memtransfer of anything (#2152)
Browse files Browse the repository at this point in the history
* Fix memtransfer of anything

* fix

* fmt

* Fix

---------

Co-authored-by: Paul Berg <[email protected]>
  • Loading branch information
wsmoses and Pangoraw authored Nov 5, 2024
1 parent f1f4d8e commit 4025abd
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 33 deletions.
3 changes: 2 additions & 1 deletion enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -3317,7 +3317,8 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
} else {
auto &DL = gutils->newFunc->getParent()->getDataLayout();
auto vd = TR.query(orig_dst).Data0().ShiftIndices(DL, 0, size, 0);
vd |= TR.query(orig_src).Data0().ShiftIndices(DL, 0, size, 0);
vd |= TR.query(orig_src).Data0().PurgeAnything().ShiftIndices(DL, 0, size,
0);
for (size_t i = 0; i < MTI.getNumOperands(); i++)
if (MTI.getOperand(i) == orig_dst)
if (MTI.getAttributes().hasParamAttr(i, "enzyme_type")) {
Expand Down
42 changes: 22 additions & 20 deletions enzyme/Enzyme/DifferentialUseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI,
std::map<Node, Node> parent;
bfs(G, Recomputes, parent);

std::deque<Value *> todo;
SetVector<Value *> todo;

// Print all edges that are from a reachable vertex to
// non-reachable vertex in the original graph
Expand All @@ -885,20 +885,36 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI,
assert(pair.first.V == N.V);
MinReq.insert(N.V);
if (Orig.find(Node(N.V, true)) != Orig.end()) {
todo.push_back(N.V);
todo.insert(N.V);
}
}
}
}

// When ambiguous, push to cache the last value in a computation chain
// This should be considered in a cost for the max flow
while (todo.size()) {
auto V = todo.front();
todo.pop_front();
todo.remove(V);
auto found = Orig.find(Node(V, true));
assert(found != Orig.end());
const auto &mp = found->second;

assert(MinReq.count(V));

// Fix up non-cacheable calls to use their operand(s) instead
if (hasNoCache(V)) {
assert(!Required.count(V));
MinReq.remove(V);
for (auto &pair : Orig) {
if (pair.second.count(Node(V, false))) {
MinReq.insert(pair.first.V);
todo.insert(pair.first.V);
}
}
continue;
}

// When ambiguous, push to cache the last value in a computation chain
// This should be considered in a cost for the max flow
if (mp.size() == 1 && !Required.count(V)) {
bool potentiallyRecursive =
isa<PHINode>((*mp.begin()).V) &&
Expand Down Expand Up @@ -967,21 +983,7 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI,
auto nnode = (*mp.begin()).V;
MinReq.insert(nnode);
if (Orig.find(Node(nnode, true)) != Orig.end())
todo.push_back(nnode);
}
}
}

// Fix up non-cacheable calls to use their operand(s) instead
for (auto V : Intermediates) {
if (!hasNoCache(V))
continue;
if (!MinReq.count(V))
continue;
MinReq.remove(V);
for (auto &pair : Orig) {
if (pair.second.count(Node(V, false))) {
MinReq.insert(pair.first.V);
todo.insert(nnode);
}
}
}
Expand Down
1 change: 0 additions & 1 deletion enzyme/Enzyme/EnzymeLogic.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@

extern "C" {
extern llvm::cl::opt<bool> EnzymePrint;
extern llvm::cl::opt<bool> EnzymeJuliaAddrLoad;
}

constexpr char EnzymeFPRTPrefix[] = "__enzyme_fprt_";
Expand Down
9 changes: 9 additions & 0 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8177,6 +8177,9 @@ void GradientUtils::computeMinCache() {
notForAnalysis);
if (oneneed) {
knownRecomputeHeuristic[&I] = false;

CountTrackedPointers T(I.getType());
assert(!T.derived);
} else
Recomputes.insert(&I);
}
Expand Down Expand Up @@ -8267,8 +8270,14 @@ void GradientUtils::computeMinCache() {
assert(legalRecompute(V, Available2, nullptr));
}
if (!NeedGraph.count(V)) {
assert(!MinReq.count(V));
unnecessaryIntermediates.insert(cast<Instruction>(V));
}

if (NeedGraph.count(V) && MinReq.count(V)) {
CountTrackedPointers T(V->getType());
assert(!T.derived);
}
}
}
}
Expand Down
5 changes: 0 additions & 5 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,6 @@ const llvm::StringMap<llvm::Intrinsic::ID> LIBM_FUNCTIONS = {
{"__nv_drcp_rn", Intrinsic::not_intrinsic},
{"__nv_drcp_ru", Intrinsic::not_intrinsic},
{"__nv_drcp_rz", Intrinsic::not_intrinsic},
{"__nv_isnand", Intrinsic::not_intrinsic},
{"__nv_isnanf", Intrinsic::not_intrinsic},
{"__nv_isinfd", Intrinsic::not_intrinsic},
{"__nv_isinff", Intrinsic::not_intrinsic},
{"__nv_acos", Intrinsic::not_intrinsic},
{"asin", Intrinsic::not_intrinsic},
{"__nv_asin", Intrinsic::not_intrinsic},
{"atan", Intrinsic::not_intrinsic},
Expand Down
4 changes: 3 additions & 1 deletion enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ extern const llvm::StringMap<llvm::Intrinsic::ID> LIBM_FUNCTIONS;

static inline bool isMemFreeLibMFunction(llvm::StringRef str,
llvm::Intrinsic::ID *ID = nullptr) {
llvm::StringRef ogstr = str;
if (startsWith(str, "__") && endsWith(str, "_finite")) {
str = str.substr(2, str.size() - 2 - 7);
} else if (startsWith(str, "__fd_") && endsWith(str, "_1")) {
Expand All @@ -72,7 +73,8 @@ static inline bool isMemFreeLibMFunction(llvm::StringRef str,
*ID = LIBM_FUNCTIONS.find(str.str())->second;
return true;
}
if (endsWith(str, "f") || endsWith(str, "l")) {
if (endsWith(str, "f") || endsWith(str, "l") ||
(startsWith(ogstr, "__nv_") && endsWith(str, "d"))) {
if (LIBM_FUNCTIONS.find(str.substr(0, str.size() - 1).str()) !=
LIBM_FUNCTIONS.end()) {
if (ID)
Expand Down
10 changes: 10 additions & 0 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ extern llvm::cl::opt<bool> EnzymePrintPerf;
extern llvm::cl::opt<bool> EnzymeStrongZero;
extern llvm::cl::opt<bool> EnzymeBlasCopy;
extern llvm::cl::opt<bool> EnzymeLapackCopy;
extern llvm::cl::opt<bool> EnzymeJuliaAddrLoad;
extern LLVMValueRef (*CustomErrorHandler)(const char *, LLVMValueRef, ErrorType,
const void *, LLVMValueRef,
LLVMBuilderRef);
Expand Down Expand Up @@ -1184,6 +1185,15 @@ static inline bool hasNoCache(llvm::Value *op) {
if (auto I = dyn_cast<Instruction>(op))
if (hasMetadata(I, "enzyme_nocache"))
return true;

if (EnzymeJuliaAddrLoad) {
if (auto PT = dyn_cast<PointerType>(op->getType())) {
if (PT->getAddressSpace() == 11 || PT->getAddressSpace() == 13) {
if (isa<CastInst>(op) || isa<GetElementPtrInst>(op))
return true;
}
}
}
return false;
}

Expand Down
4 changes: 1 addition & 3 deletions enzyme/test/Enzyme/ForwardMode/memcpyanyflt.ll
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ attributes #0 = { argmemonly nounwind }

; CHECK: define internal void @fwddiffememcpy_ptr(i8* nocapture %dst, i8* nocapture %"dst'", i8* nocapture readonly %src, i8* nocapture %"src'", i64 %num)
; CHECK-NEXT: entry:
; CHECK-NEXT: tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %"dst'", i8* bitcast ({ i64, double }* @_j_const2 to i8*), i64 8, i1 false)
; CHECK-NEXT: %0 = getelementptr inbounds i8, i8* %"dst'", i64 8
; CHECK-NEXT: tail call void @llvm.memset.p0i8.i64(i8* align 1 %0, i8 0, i64 8, i1 true)
; CHECK-NEXT: tail call void @llvm.memset.p0i8.i64(i8* align 1 %"dst'", i8 0, i64 16, i1 true)
; CHECK-NEXT: tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %dst, i8* bitcast ({ i64, double }* @_j_const2 to i8*), i64 16, i1 false)
; CHECK-NEXT: ret void
; CHECK-NEXT: }
6 changes: 4 additions & 2 deletions enzyme/test/Enzyme/ReverseMode/nullcp.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,adce,loop(loop-deletion),correlated-propagation,%simplifycfg)" -S | FileCheck %s
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -enzyme-runtime-error -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,adce,loop(loop-deletion),correlated-propagation,%simplifycfg)" -enzyme-runtime-error -S | FileCheck %s

source_filename = "nullcp.c"
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
Expand Down Expand Up @@ -130,6 +130,8 @@ attributes #6 = { nounwind }
; CHECK-NEXT: %0 = bitcast double* %dst to i8*
; CHECK-NEXT: %1 = bitcast double* %src to i8*
; CHECK-NEXT: %mul = shl i64 %n, 3
; CHECK-NEXT: %2 = call i32 @puts(i8* getelementptr inbounds ([126 x i8], [126 x i8]* @.str.1, i32 0, i32 0))
; CHECK-NEXT: call void @exit(i32 1)
; CHECK-NEXT: tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %"'ipc", i8* align 8 %1, i64 %mul, i1 false)
; CHECK-NEXT: tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %0, i8* align 8 %1, i64 %mul, i1 false)
; CHECK-NEXT: br label %invertif.end
Expand Down

0 comments on commit 4025abd

Please sign in to comment.