Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Nov 4, 2024
1 parent d13fca4 commit a69f47c
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 24 deletions.
41 changes: 22 additions & 19 deletions enzyme/Enzyme/DifferentialUseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI,
std::map<Node, Node> parent;
bfs(G, Recomputes, parent);

std::deque<Value *> todo;
SetVector<Value *> todo;

// Print all edges that are from a reachable vertex to
// non-reachable vertex in the original graph
Expand All @@ -885,20 +885,36 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI,
assert(pair.first.V == N.V);
MinReq.insert(N.V);
if (Orig.find(Node(N.V, true)) != Orig.end()) {
todo.push_back(N.V);
todo.insert(N.V);
}
}
}
}

// When ambiguous, push to cache the last value in a computation chain
// This should be considered in a cost for the max flow
while (todo.size()) {
auto V = todo.front();
todo.pop_front();
todo.remove(V);
auto found = Orig.find(Node(V, true));
assert(found != Orig.end());
const auto &mp = found->second;

assert(MinReq.count(V));

// Fix up non-cacheable calls to use their operand(s) instead
if (hasNoCache(V)) {
assert(!Required.count(V));
MinReq.remove(V);
for (auto &pair : Orig) {
if (pair.second.count(Node(V, false))) {
MinReq.insert(pair.first.V);
todo.insert(pair.first.V);
}
}
continue;
}

// When ambiguous, push to cache the last value in a computation chain
// This should be considered in a cost for the max flow
if (mp.size() == 1 && !Required.count(V)) {
bool potentiallyRecursive =
isa<PHINode>((*mp.begin()).V) &&
Expand Down Expand Up @@ -967,24 +983,11 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI,
auto nnode = (*mp.begin()).V;
MinReq.insert(nnode);
if (Orig.find(Node(nnode, true)) != Orig.end())
todo.push_back(nnode);
todo.insert(nnode);
}
}
}

// Fix up non-cacheable calls to use their operand(s) instead
for (auto V : Intermediates) {
if (!hasNoCache(V))
continue;
if (!MinReq.count(V))
continue;
MinReq.remove(V);
for (auto &pair : Orig) {
if (pair.second.count(Node(V, false))) {
MinReq.insert(pair.first.V);
}
}
}

// Fix up non-repeatable writing calls that chain within rematerialized
// allocations. We could iterate from the keys of the valuemap, but that would
Expand Down
9 changes: 9 additions & 0 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8177,6 +8177,9 @@ void GradientUtils::computeMinCache() {
notForAnalysis);
if (oneneed) {
knownRecomputeHeuristic[&I] = false;

CountTrackedPointers T(I.getType());
assert(!T.derived);
} else
Recomputes.insert(&I);
}
Expand Down Expand Up @@ -8267,8 +8270,14 @@ void GradientUtils::computeMinCache() {
assert(legalRecompute(V, Available2, nullptr));
}
if (!NeedGraph.count(V)) {
assert(!MinReq.count(V));
unnecessaryIntermediates.insert(cast<Instruction>(V));
}

if (NeedGraph.count(V) && MinReq.count(V)) {
CountTrackedPointers T(V->getType());
assert(!T.derived);
}
}
}
}
Expand Down
4 changes: 1 addition & 3 deletions enzyme/test/Enzyme/ForwardMode/memcpyanyflt.ll
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ attributes #0 = { argmemonly nounwind }

; CHECK: define internal void @fwddiffememcpy_ptr(i8* nocapture %dst, i8* nocapture %"dst'", i8* nocapture readonly %src, i8* nocapture %"src'", i64 %num)
; CHECK-NEXT: entry:
; CHECK-NEXT: tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %"dst'", i8* bitcast ({ i64, double }* @_j_const2 to i8*), i64 8, i1 false)
; CHECK-NEXT: %0 = getelementptr inbounds i8, i8* %"dst'", i64 8
; CHECK-NEXT: tail call void @llvm.memset.p0i8.i64(i8* align 1 %0, i8 0, i64 8, i1 true)
; CHECK-NEXT: tail call void @llvm.memset.p0i8.i64(i8* align 1 %"dst'", i8 0, i64 16, i1 true)
; CHECK-NEXT: tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %dst, i8* bitcast ({ i64, double }* @_j_const2 to i8*), i64 16, i1 false)
; CHECK-NEXT: ret void
; CHECK-NEXT: }
6 changes: 4 additions & 2 deletions enzyme/test/Enzyme/ReverseMode/nullcp.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,adce,loop(loop-deletion),correlated-propagation,%simplifycfg)" -S | FileCheck %s
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -enzyme-runtime-error -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,adce,loop(loop-deletion),correlated-propagation,%simplifycfg)" -enzyme-runtime-error -S | FileCheck %s

source_filename = "nullcp.c"
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
Expand Down Expand Up @@ -130,6 +130,8 @@ attributes #6 = { nounwind }
; CHECK-NEXT: %0 = bitcast double* %dst to i8*
; CHECK-NEXT: %1 = bitcast double* %src to i8*
; CHECK-NEXT: %mul = shl i64 %n, 3
; CHECK-NEXT: %2 = call i32 @puts(i8* getelementptr inbounds ([126 x i8], [126 x i8]* @.str.1, i32 0, i32 0))
; CHECK-NEXT: call void @exit(i32 1)
; CHECK-NEXT: tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %"'ipc", i8* align 8 %1, i64 %mul, i1 false)
; CHECK-NEXT: tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %0, i8* align 8 %1, i64 %mul, i1 false)
; CHECK-NEXT: br label %invertif.end
Expand Down

0 comments on commit a69f47c

Please sign in to comment.