From 5cf1bd26019a438f7d8b0aba84ae78d2da914e3d Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 5 Aug 2024 21:10:18 -0400 Subject: [PATCH] Don't remove unreachable reverse (#2032) --- enzyme/Enzyme/EnzymeLogic.cpp | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 155e8bc4632..feb39de42c7 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -4319,13 +4319,15 @@ Function *EnzymeLogic::CreatePrimalAndGradient( if (guaranteedUnreachable.find(&oBB) != guaranteedUnreachable.end()) { auto newBB = cast(gutils->getNewFromOriginal(&oBB)); SmallVector toRemove; - if (auto II = dyn_cast(oBB.getTerminator())) { - toRemove.push_back( - cast(gutils->getNewFromOriginal(II->getNormalDest()))); - } else { - for (auto next : successors(&oBB)) { - auto sucBB = cast(gutils->getNewFromOriginal(next)); - toRemove.push_back(sucBB); + if (key.mode != DerivativeMode::ReverseModeCombined) { + if (auto II = dyn_cast(oBB.getTerminator())) { + toRemove.push_back(cast( + gutils->getNewFromOriginal(II->getNormalDest()))); + } else { + for (auto next : successors(&oBB)) { + auto sucBB = cast(gutils->getNewFromOriginal(next)); + toRemove.push_back(sucBB); + } } } @@ -4354,11 +4356,13 @@ Function *EnzymeLogic::CreatePrimalAndGradient( /*check*/ key.mode == DerivativeMode::ReverseModeCombined); } - if (newBB->getTerminator()) - gutils->erase(newBB->getTerminator()); - IRBuilder<> builder(newBB); - builder.CreateUnreachable(); + if (key.mode != DerivativeMode::ReverseModeCombined) { + if (newBB->getTerminator()) + gutils->erase(newBB->getTerminator()); + IRBuilder<> builder(newBB); + builder.CreateUnreachable(); + } continue; }