From a69f47c59aa05abb1818668e6f60492cabf4ea40 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 4 Nov 2024 14:59:32 -0500 Subject: [PATCH] fix --- enzyme/Enzyme/DifferentialUseAnalysis.cpp | 41 ++++++++++--------- enzyme/Enzyme/GradientUtils.cpp | 9 ++++ .../test/Enzyme/ForwardMode/memcpyanyflt.ll | 4 +- enzyme/test/Enzyme/ReverseMode/nullcp.ll | 6 ++- 4 files changed, 36 insertions(+), 24 deletions(-) diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.cpp b/enzyme/Enzyme/DifferentialUseAnalysis.cpp index 6a4573b736a3..8cb47284e505 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.cpp +++ b/enzyme/Enzyme/DifferentialUseAnalysis.cpp @@ -873,7 +873,7 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, std::map parent; bfs(G, Recomputes, parent); - std::deque todo; + SetVector todo; // Print all edges that are from a reachable vertex to // non-reachable vertex in the original graph @@ -885,20 +885,36 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, assert(pair.first.V == N.V); MinReq.insert(N.V); if (Orig.find(Node(N.V, true)) != Orig.end()) { - todo.push_back(N.V); + todo.insert(N.V); } } } } - // When ambiguous, push to cache the last value in a computation chain - // This should be considered in a cost for the max flow while (todo.size()) { auto V = todo.front(); - todo.pop_front(); + todo.remove(V); auto found = Orig.find(Node(V, true)); assert(found != Orig.end()); const auto &mp = found->second; + + assert(MinReq.count(V)); + + // Fix up non-cacheable calls to use their operand(s) instead + if (hasNoCache(V)) { + assert(!Required.count(V)); + MinReq.remove(V); + for (auto &pair : Orig) { + if (pair.second.count(Node(V, false))) { + MinReq.insert(pair.first.V); + todo.insert(pair.first.V); + } + } + continue; + } + + // When ambiguous, push to cache the last value in a computation chain + // This should be considered in a cost for the max flow if (mp.size() == 1 && !Required.count(V)) { bool potentiallyRecursive = isa((*mp.begin()).V) && @@ -967,24 +983,11 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, auto nnode = (*mp.begin()).V; MinReq.insert(nnode); if (Orig.find(Node(nnode, true)) != Orig.end()) - todo.push_back(nnode); + todo.insert(nnode); } } } - // Fix up non-cacheable calls to use their operand(s) instead - for (auto V : Intermediates) { - if (!hasNoCache(V)) - continue; - if (!MinReq.count(V)) - continue; - MinReq.remove(V); - for (auto &pair : Orig) { - if (pair.second.count(Node(V, false))) { - MinReq.insert(pair.first.V); - } - } - } // Fix up non-repeatable writing calls that chain within rematerialized // allocations. We could iterate from the keys of the valuemap, but that would diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index b8cb66ef9d4d..1b2ed12781cd 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -8177,6 +8177,9 @@ void GradientUtils::computeMinCache() { notForAnalysis); if (oneneed) { knownRecomputeHeuristic[&I] = false; + + CountTrackedPointers T(I.getType()); + assert(!T.derived); } else Recomputes.insert(&I); } @@ -8267,8 +8270,14 @@ void GradientUtils::computeMinCache() { assert(legalRecompute(V, Available2, nullptr)); } if (!NeedGraph.count(V)) { + assert(!MinReq.count(V)); unnecessaryIntermediates.insert(cast(V)); } + + if (NeedGraph.count(V) && MinReq.count(V)) { + CountTrackedPointers T(V->getType()); + assert(!T.derived); + } } } } diff --git a/enzyme/test/Enzyme/ForwardMode/memcpyanyflt.ll b/enzyme/test/Enzyme/ForwardMode/memcpyanyflt.ll index 5a6a197be793..8e748e9c5907 100644 --- a/enzyme/test/Enzyme/ForwardMode/memcpyanyflt.ll +++ b/enzyme/test/Enzyme/ForwardMode/memcpyanyflt.ll @@ -26,9 +26,7 @@ attributes #0 = { argmemonly nounwind } ; CHECK: define internal void @fwddiffememcpy_ptr(i8* nocapture %dst, i8* nocapture %"dst'", i8* nocapture readonly %src, i8* nocapture %"src'", i64 %num) ; CHECK-NEXT: entry: -; CHECK-NEXT: tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %"dst'", i8* bitcast ({ i64, double }* @_j_const2 to i8*), i64 8, i1 false) -; CHECK-NEXT: %0 = getelementptr inbounds i8, i8* %"dst'", i64 8 -; CHECK-NEXT: tail call void @llvm.memset.p0i8.i64(i8* align 1 %0, i8 0, i64 8, i1 true) +; CHECK-NEXT: tail call void @llvm.memset.p0i8.i64(i8* align 1 %"dst'", i8 0, i64 16, i1 true) ; CHECK-NEXT: tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %dst, i8* bitcast ({ i64, double }* @_j_const2 to i8*), i64 16, i1 false) ; CHECK-NEXT: ret void ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/nullcp.ll b/enzyme/test/Enzyme/ReverseMode/nullcp.ll index 9fee889e9a25..1e222534ebb1 100644 --- a/enzyme/test/Enzyme/ReverseMode/nullcp.ll +++ b/enzyme/test/Enzyme/ReverseMode/nullcp.ll @@ -1,5 +1,5 @@ -; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S | FileCheck %s; fi -; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,adce,loop(loop-deletion),correlated-propagation,%simplifycfg)" -S | FileCheck %s +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -enzyme-runtime-error -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,adce,loop(loop-deletion),correlated-propagation,%simplifycfg)" -enzyme-runtime-error -S | FileCheck %s source_filename = "nullcp.c" target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" @@ -130,6 +130,8 @@ attributes #6 = { nounwind } ; CHECK-NEXT: %0 = bitcast double* %dst to i8* ; CHECK-NEXT: %1 = bitcast double* %src to i8* ; CHECK-NEXT: %mul = shl i64 %n, 3 +; CHECK-NEXT: %2 = call i32 @puts(i8* getelementptr inbounds ([126 x i8], [126 x i8]* @.str.1, i32 0, i32 0)) +; CHECK-NEXT: call void @exit(i32 1) ; CHECK-NEXT: tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %"'ipc", i8* align 8 %1, i64 %mul, i1 false) ; CHECK-NEXT: tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %0, i8* align 8 %1, i64 %mul, i1 false) ; CHECK-NEXT: br label %invertif.end