Skip to content

Commit

Permalink
Simplify anyFloat usage in ActivityAnalysis (#2155)
Browse files Browse the repository at this point in the history
* Simplify anyFloat usage in ActivityAnalysis

* Update DiffeGradientUtils.cpp
  • Loading branch information
wsmoses authored Nov 5, 2024
1 parent 7b27adb commit 9c154c5
Showing 1 changed file with 6 additions and 34 deletions.
40 changes: 6 additions & 34 deletions enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -708,30 +708,6 @@ static inline void propagateArgumentInformation(
}
}

bool isPossibleFloat(const TypeResults &TR, Value *I, const DataLayout &DL) {
bool possibleFloat = false;
if (!I->getType()->isVoidTy()) {
auto Size = (DL.getTypeSizeInBits(I->getType()) + 7) / 8;
auto vd = TR.query(I);
auto ct0 = vd[{-1}];
if (ct0.isPossibleFloat() && ct0 != BaseType::Anything) {
for (unsigned i = 0; i < Size;) {
auto ct = vd[{(int)i}];
if (ct.isPossibleFloat() && ct != BaseType::Anything) {
possibleFloat = true;
break;
}
size_t chunk = 1;
// Implicit pointer
if (ct == BaseType::Pointer)
chunk = DL.getPointerSizeInBits() / 8;
i += chunk;
}
}
}
return possibleFloat;
}

/// Return whether this instruction is known not to propagate adjoints
/// Note that instructions could return an active pointer, but
/// do not propagate adjoints themselves
Expand Down Expand Up @@ -921,8 +897,7 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults const &TR,
}
}
if (noActiveWrite) {
auto &DL = I->getParent()->getParent()->getParent()->getDataLayout();
bool possibleFloat = isPossibleFloat(TR, I, DL);
bool possibleFloat = TR.anyFloat(I);
// Even if returning a pointer, this instruction is considered inactive
// since the instruction doesn't prop gradients. Thus, so long as we don't
// return an object containing a float, this instruction is inactive
Expand Down Expand Up @@ -2438,10 +2413,9 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR,

if (!considerValue) {
if (auto IEI = dyn_cast<InsertElementInst>(inst)) {
auto &DL = IEI->getParent()->getParent()->getParent()->getDataLayout();
if ((!isPossibleFloat(TR, IEI->getOperand(0), DL) ||
if ((!TR.anyFloat(IEI->getOperand(0)) ||
isConstantValue(TR, IEI->getOperand(0))) &&
(!isPossibleFloat(TR, IEI->getOperand(1), DL) ||
(!TR.anyFloat(IEI->getOperand(1)) ||
isConstantValue(TR, IEI->getOperand(1)))) {
if (EnzymePrintActivity)
llvm::errs()
Expand All @@ -2451,10 +2425,9 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR,
}
}
if (auto IEI = dyn_cast<InsertValueInst>(inst)) {
auto &DL = IEI->getParent()->getParent()->getParent()->getDataLayout();
if ((!isPossibleFloat(TR, IEI->getAggregateOperand(), DL) ||
if ((!TR.anyFloat(IEI->getAggregateOperand()) ||
isConstantValue(TR, IEI->getAggregateOperand())) &&
(!isPossibleFloat(TR, IEI->getInsertedValueOperand(), DL) ||
(!TR.anyFloat(IEI->getInsertedValueOperand()) ||
isConstantValue(TR, IEI->getInsertedValueOperand()))) {
if (EnzymePrintActivity)
llvm::errs()
Expand Down Expand Up @@ -2482,9 +2455,8 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR,
}
}
bool legal = true;
auto &DL = PN->getParent()->getParent()->getParent()->getDataLayout();
for (auto V : incoming) {
if (isPossibleFloat(TR, V, DL) && !isConstantValue(TR, V)) {
if (TR.anyFloat(V) && !isConstantValue(TR, V)) {
legal = false;
break;
}
Expand Down

0 comments on commit 9c154c5

Please sign in to comment.