Skip to content

Commit

Permalink
More complex specialfunctions (#2190)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Dec 6, 2024
1 parent 31efe08 commit 07e1826
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
31 changes: 19 additions & 12 deletions enzyme/Enzyme/InstructionDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@ class ForwardFromSummedReverseInternal<int unused_> {
}
def ForwardFromSummedReverse : ForwardFromSummedReverseInternal<0>;

class Operation<bit usesPrimal_, bit usesShadow_, bit usesCustom_=0> {
bit usesPrimal = usesPrimal_;
bit usesShadow = usesShadow_;
bit usesCustom = usesCustom_;
}
class ConstantFP<string val> : Operation</*primal*/0, /*shadow*/0> {
string value = val;
}


class Attribute<string name_> {
string name = name_;
Expand Down Expand Up @@ -46,11 +55,6 @@ class InstPattern<dag patternToMatch, string funcName, int minVer_, int maxVer_,
dag ArgDuals = forwardOps;
}

class Operation<bit usesPrimal_, bit usesShadow_, bit usesCustom_=0> {
bit usesPrimal = usesPrimal_;
bit usesShadow = usesShadow_;
bit usesCustom = usesCustom_;
}
class Inst<string mnemonic> : Operation</*primal*/1, /*shadow*/0> {
string name = mnemonic;
}
Expand Down Expand Up @@ -175,6 +179,12 @@ def CFNeg : SubRoutine<(Op (Op $re, $im):$z),
(FNeg $im)
)>;

def Complex : SubRoutine<(Op $x),
(ArrayRet
$x,
(ConstantFP<"0"> $x)
)>;

def Conj : SubRoutine<(Op (Op $re, $im):$z),
(ArrayRet
$re,
Expand Down Expand Up @@ -272,9 +282,6 @@ def MantissaMaskOfReturnForFrexp : GlobalExpr</*primal*/0, /*shadow*/0, [{
ConstantInt::get(ity, eval);
}]>;

class ConstantFP<string val> : Operation</*primal*/0, /*shadow*/0> {
string value = val;
}
def Zero : Operation</*primal*/0, /*shadow*/0> {
}
class ConstantCFP<string rval, string ival> : Operation</*primal*/0, /*shadow*/0> {
Expand Down Expand Up @@ -671,14 +678,14 @@ def : CallPattern<(Op $n, $x),
[ReadNone, NoUnwind]
>;

def : CallPattern<(Op $n, $x),
def : CallPattern<(Op $n, $z),
["cmplx_jn","cmplx_yn"],
[
(InactiveArg),
(AssertingInactiveArg),
// Reverse mode needs to return the conjugate
(CFMul (DiffeRet), (Conj (CFMul (ConstantCFP<"0.5", "0"> $x), (CFSub (Call<(SameFunc), [ReadNone,NoUnwind]> (FSub $n, (ConstantFP<"1"> $n)), $x), (Call<(SameFunc), [ReadNone,NoUnwind]> (FAdd $n, (ConstantFP<"1"> $n)), $x)))))
(CFMul (DiffeRet), (Conj (CFMul (ConstantCFP<"0.5", "0"> $z), (CFSub (Call<(SameFunc), [ReadNone,NoUnwind]> (FSub $n, (ConstantFP<"1"> $n)), $z), (Call<(SameFunc), [ReadNone,NoUnwind]> (FAdd $n, (ConstantFP<"1"> $n)), $z)))))
],
(CFMul (Shadow $x), (CFMul (ConstantCFP<"0.5", "0"> $x), (CFSub (Call<(SameFunc), [ReadNone,NoUnwind]> (FSub $n, (ConstantFP<"1"> $n)), $x), (Call<(SameFunc), [ReadNone,NoUnwind]> (FAdd $n, (ConstantFP<"1"> $n)), $x)))),
(CFMul (Shadow $z), (CFMul (ConstantCFP<"0.5", "0"> $z), (CFSub (Call<(SameFunc), [ReadNone,NoUnwind]> (FSub $n, (ConstantFP<"1"> $n)), $z), (Call<(SameFunc), [ReadNone,NoUnwind]> (FAdd $n, (ConstantFP<"1"> $n)), $z)))),
[ReadNone, NoUnwind]
>;

Expand Down
5 changes: 4 additions & 1 deletion enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,10 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,
<< rvalue->getValue()
<< "\"), (llvm::Constant*)ConstantFP::get(AT->getElementType(), \""
<< ivalue->getValue() << "\")});\n";
os << curIndent << INDENT << "} else assert(0 && \"unhandled cfp\");\n";
os << curIndent << INDENT << "} else {\n";
os << curIndent << INDENT << " llvm::errs() << *ty << \"\\n\";\n";
os << curIndent << INDENT << " assert(0 && \"unhandled cfp\");\n";
os << curIndent << INDENT << "}\n";
os << curIndent << INDENT << "ret;\n";
os << curIndent << "})\n";
return false;
Expand Down

0 comments on commit 07e1826

Please sign in to comment.