Skip to content

Commit

Permalink
Add bfloat16 (#2033)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Aug 6, 2024
1 parent 5cf1bd2 commit 87eed6b
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 0 deletions.
4 changes: 4 additions & 0 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ ConcreteType eunwrap(CConcreteType CDT, llvm::LLVMContext &ctx) {
return ConcreteType(llvm::Type::getDoubleTy(ctx));
case DT_X86_FP80:
return ConcreteType(llvm::Type::getX86_FP80Ty(ctx));
case DT_BFloat16:
return ConcreteType(llvm::Type::getBFloatTy(ctx));
case DT_Unknown:
return BaseType::Unknown;
}
Expand Down Expand Up @@ -131,6 +133,8 @@ CConcreteType ewrap(const ConcreteType &CT) {
return DT_Double;
if (flt->isX86_FP80Ty())
return DT_X86_FP80;
if (flt->isBFloatTy())
return DT_BFloat16;
} else {
switch (CT.SubTypeEnum) {
case BaseType::Integer:
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/CApi.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ typedef enum {
DT_Double = 5,
DT_Unknown = 6,
DT_X86_FP80 = 7,
DT_BFloat16 = 8,
} CConcreteType;

struct CDataPair {
Expand Down
4 changes: 4 additions & 0 deletions enzyme/Enzyme/InstructionDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ def MantissaMaskOfReturnForFrexp : GlobalExpr</*primal*/0, /*shadow*/0, [{
// x86_fp80 has only 15 exponent bits, but we must also
// retain the most-significant bit of the mantissa as
// there is no implicit leading
} else if (ty->isBFloatTy()) {
tsize = 16;
high = tsize - 1;
low = high - 8;
} else if (ty->isFP128Ty()) {
tsize = 128;
high = tsize - 1;
Expand Down
4 changes: 4 additions & 0 deletions enzyme/Enzyme/TypeAnalysis/ConcreteType.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class ConcreteType {
SubType = llvm::Type::getDoubleTy(C);
} else if (SubName == "fp80") {
SubType = llvm::Type::getX86_FP80Ty(C);
} else if (SubName == "bf16") {
SubType = llvm::Type::getBFloatTy(C);
} else if (SubName == "fp128") {
SubType = llvm::Type::getFP128Ty(C);
} else if (SubName == "ppc128") {
Expand All @@ -104,6 +106,8 @@ class ConcreteType {
Result += "@double";
} else if (SubType->isX86_FP80Ty()) {
Result += "@fp80";
} else if (SubType->isBFloatTy()) {
Result += "@bf16";
} else if (SubType->isFP128Ty()) {
Result += "@fp128";
} else if (SubType->isPPC_FP128Ty()) {
Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,8 @@ static inline std::string tofltstr(Type *T) {
return "double";
case Type::X86_FP80TyID:
return "x87d";
case Type::BFloatTyID:
return "bf16";
case Type::FP128TyID:
return "quad";
case Type::PPC_FP128TyID:
Expand Down
3 changes: 3 additions & 0 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,8 @@ static inline llvm::Type *FloatToIntTy(llvm::Type *T) {
}
if (T->isHalfTy())
return llvm::IntegerType::get(T->getContext(), 16);
if (T->isBFloatTy())
return llvm::IntegerType::get(T->getContext(), 16);
if (T->isFloatTy())
return llvm::IntegerType::get(T->getContext(), 32);
if (T->isDoubleTy())
Expand All @@ -605,6 +607,7 @@ static inline llvm::Type *IntToFloatTy(llvm::Type *T) {
switch (ty->getBitWidth()) {
case 16:
return llvm::Type::getHalfTy(T->getContext());
// return llvm::Type::getBFloat16Ty(T->getContext());
case 32:
return llvm::Type::getFloatTy(T->getContext());
case 64:
Expand Down

0 comments on commit 87eed6b

Please sign in to comment.