From 07e1826f2bc1090c3d8bc5bd30b9e2797e65a6dc Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 5 Dec 2024 23:22:48 -0600 Subject: [PATCH] More complex specialfunctions (#2190) --- enzyme/Enzyme/InstructionDerivatives.td | 31 ++++++++++++-------- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 5 +++- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index 2973c4b84cd..c778e63aa25 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -4,6 +4,15 @@ class ForwardFromSummedReverseInternal { } def ForwardFromSummedReverse : ForwardFromSummedReverseInternal<0>; +class Operation { + bit usesPrimal = usesPrimal_; + bit usesShadow = usesShadow_; + bit usesCustom = usesCustom_; +} +class ConstantFP : Operation { + string value = val; +} + class Attribute { string name = name_; @@ -46,11 +55,6 @@ class InstPattern { - bit usesPrimal = usesPrimal_; - bit usesShadow = usesShadow_; - bit usesCustom = usesCustom_; -} class Inst : Operation { string name = mnemonic; } @@ -175,6 +179,12 @@ def CFNeg : SubRoutine<(Op (Op $re, $im):$z), (FNeg $im) )>; +def Complex : SubRoutine<(Op $x), + (ArrayRet + $x, + (ConstantFP<"0"> $x) + )>; + def Conj : SubRoutine<(Op (Op $re, $im):$z), (ArrayRet $re, @@ -272,9 +282,6 @@ def MantissaMaskOfReturnForFrexp : GlobalExpr; -class ConstantFP : Operation { - string value = val; -} def Zero : Operation { } class ConstantCFP : Operation { @@ -671,14 +678,14 @@ def : CallPattern<(Op $n, $x), [ReadNone, NoUnwind] >; -def : CallPattern<(Op $n, $x), +def : CallPattern<(Op $n, $z), ["cmplx_jn","cmplx_yn"], [ - (InactiveArg), + (AssertingInactiveArg), // Reverse mode needs to return the conjugate - (CFMul (DiffeRet), (Conj (CFMul (ConstantCFP<"0.5", "0"> $x), (CFSub (Call<(SameFunc), [ReadNone,NoUnwind]> (FSub $n, (ConstantFP<"1"> $n)), $x), (Call<(SameFunc), [ReadNone,NoUnwind]> (FAdd $n, (ConstantFP<"1"> $n)), $x))))) + (CFMul (DiffeRet), (Conj (CFMul (ConstantCFP<"0.5", "0"> $z), (CFSub (Call<(SameFunc), [ReadNone,NoUnwind]> (FSub $n, (ConstantFP<"1"> $n)), $z), (Call<(SameFunc), [ReadNone,NoUnwind]> (FAdd $n, (ConstantFP<"1"> $n)), $z))))) ], - (CFMul (Shadow $x), (CFMul (ConstantCFP<"0.5", "0"> $x), (CFSub (Call<(SameFunc), [ReadNone,NoUnwind]> (FSub $n, (ConstantFP<"1"> $n)), $x), (Call<(SameFunc), [ReadNone,NoUnwind]> (FAdd $n, (ConstantFP<"1"> $n)), $x)))), + (CFMul (Shadow $z), (CFMul (ConstantCFP<"0.5", "0"> $z), (CFSub (Call<(SameFunc), [ReadNone,NoUnwind]> (FSub $n, (ConstantFP<"1"> $n)), $z), (Call<(SameFunc), [ReadNone,NoUnwind]> (FAdd $n, (ConstantFP<"1"> $n)), $z)))), [ReadNone, NoUnwind] >; diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 2da86c18e5e..900c5c813cd 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -695,7 +695,10 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, << rvalue->getValue() << "\"), (llvm::Constant*)ConstantFP::get(AT->getElementType(), \"" << ivalue->getValue() << "\")});\n"; - os << curIndent << INDENT << "} else assert(0 && \"unhandled cfp\");\n"; + os << curIndent << INDENT << "} else {\n"; + os << curIndent << INDENT << " llvm::errs() << *ty << \"\\n\";\n"; + os << curIndent << INDENT << " assert(0 && \"unhandled cfp\");\n"; + os << curIndent << INDENT << "}\n"; os << curIndent << INDENT << "ret;\n"; os << curIndent << "})\n"; return false;