Skip to content

Commit

Permalink
Improve error on inserted phi scev (#2185)
Browse files Browse the repository at this point in the history
* Improve error on inserted phi scev

* fix

* more

* fix

* fix

* fix

* fix

* fix

* fix
  • Loading branch information
wsmoses authored Dec 4, 2024
1 parent 06367e4 commit 57b718b
Showing 1 changed file with 36 additions and 13 deletions.
49 changes: 36 additions & 13 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7137,7 +7137,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
{
SCEVExpander OrigExp(
*OrigSE, ctx->getParent()->getParent()->getDataLayout(),
"enzyme");
"enzyme", /*PreserveLCSSA = */ false);

OrigExp.setInsertPoint(
isOriginal(l1.header)->getTerminator());
Expand All @@ -7160,22 +7160,45 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
return OrigDT->dominates(A, B);
});
for (auto a : InsertedInstructions) {
assert(!isa<PHINode>(a));
auto uw = cast<Instruction>(
if (isa<PHINode>(a)) {
std::string str;
raw_string_ostream ss(str);
ss << "oldFunc: " << *oldFunc << "\n";
ss << "newFunc: " << *newFunc << "\n";
ss << "li: " << *li << "\n";
ss << "start0: " << *start0 << "\n";
ss << "Inserted a phi node (" << *a
<< ") during unwrap of SCEV: " << *ar1->getStart()
<< "\n";
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(li),
ErrorType::InternalError, nullptr,
nullptr, nullptr);
} else {
EmitFailure("InsertedPHISCEV", li->getDebugLoc(), li,
ss.str());
}
}
auto uwV =
unwrapM(a, v, available, UnwrapMode::AttemptSingleUnwrap,
/*scope*/ nullptr, /*cache*/ false));
assert(uw->getType() == a->getType());
/*scope*/ nullptr, /*cache*/ false);
auto uw = dyn_cast<Instruction>(uwV);
assert(uwV->getType() == a->getType());
#ifndef NDEBUG
for (size_t i = 0; i < uw->getNumOperands(); i++) {
auto op = uw->getOperand(i);
if (auto arg = dyn_cast<Argument>(op))
assert(arg->getParent() == newFunc);
else if (auto inst = dyn_cast<Instruction>(op))
assert(inst->getParent()->getParent() == newFunc);
if (uw) {
for (size_t i = 0; i < uw->getNumOperands(); i++) {
auto op = uw->getOperand(i);
if (auto arg = dyn_cast<Argument>(op))
assert(arg->getParent() == newFunc);
else if (auto inst = dyn_cast<Instruction>(op))
assert(inst->getParent()->getParent() == newFunc);
}
assert(uw->getParent()->getParent() == newFunc);
}
#endif
available[a] = uw;
unwrappedLoads.erase(cast<Instruction>(uw));
available[a] = uwV;
if (uw)
unwrappedLoads.erase(uw);
}

start =
Expand Down

0 comments on commit 57b718b

Please sign in to comment.