diff --git a/enzyme/Enzyme/BlasDerivatives.td b/enzyme/Enzyme/BlasDerivatives.td
index eba8ac19193..1a02e39535b 100644
--- a/enzyme/Enzyme/BlasDerivatives.td
+++ b/enzyme/Enzyme/BlasDerivatives.td
@@ -101,6 +101,7 @@ def is_diag_int : IntMatchers<
 
 def is_left : MagicInst;
 def is_lower : MagicInst;
+def is_zero : MagicInst;
 def is_nonunit : MagicInst;
 
 def First : MagicInst;
@@ -757,7 +758,7 @@ def potrf: CallBlasPattern<(Op $layout, $uplo, $n, $A, $lda, $info),
           (BlasCall<"trsm"> $layout, (uplo_to_side $uplo), $uplo, Char<"T">, Char<"N">, $n, $n, Constant<"1.0">, $A,
             (ld $A, Char<"N">, $lda, $n, $n), use<"ztri">, $n)
           ,
-          (For<"i", 0> (Sub $n, ConstantInt<1>),
+          (For<"i", 0> (ISelect (is_zero $n), ConstantInt<0>, (Sub $n, ConstantInt<1>)),
             (BlasCall<"axpy">
               (Sub (Sub $n, ConstantInt<1>), $i),
               Constant<"1.0">,
diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
index 9c28664a73f..ad2cc72bb11 100644
--- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
+++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
@@ -928,6 +928,24 @@ void rev_call_arg(bool forward, const DagInit *ruleDag,
          << ", cache_" << matName << ", byRef, cublas)}";
       return;
     }
+    // is_zero: emit an expression that loads the integer argument with
+    // load_if_ref, compares it against 0 with CreateICmpEQ, and passes the
+    // comparison result through to_blas_callconv.
+    if (Def->getName() == "is_zero") {
+      if (Dag->getNumArgs() != 1)
+        PrintFatalError(pattern.getLoc(),
+                        "only 1-arg is_zero operands supported");
+      const auto name = Dag->getArgNameStr(0);
+      os << " ({ auto V = load_if_ref(Builder2, intType, arg_" << name
+         << ", byRef);\n";
+      os << " SmallVector vs = {to_blas_callconv(Builder2, "
+            "Builder2.CreateICmpEQ(V, ConstantInt::get(V->getType(), 0)), "
+            "byRef, cublas, julia_decl_type, allocationBuilder, "
+            "\"is_zero\")};\n";
+      os << " vs; })";
+      return;
+    }
+
     if (Def->getName() == "is_left") {
       if (Dag->getNumArgs() != 1)
         PrintFatalError(pattern.getLoc(), "only 1-arg ld operands supported");