Skip to content

Commit

Permalink
Add more conversion types for float8.
Browse files Browse the repository at this point in the history
These are needed for numpy dtype registration, allowing conversion
directly from float8 to integer and boolean types.  By default any
arithmetic type will go through float32.

PiperOrigin-RevId: 486242589
  • Loading branch information
cantonios authored and tensorflower-gardener committed Nov 4, 2022
1 parent e8cbb19 commit a7ff2f6
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tensorflow/tsl/platform/float8.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,21 @@ class float8_e4m3 : public float8_base<float8_e4m3> {

public:
float8_e4m3() : Base() {}

template <typename T,
typename EnableIf = std::enable_if<std::is_arithmetic<T>::value> >
explicit float8_e4m3(T f) : float8_e4m3(ConvertFrom(static_cast<float>(f))) {}
explicit float8_e4m3(double f64) : float8_e4m3(ConvertFrom(f64)) {}
explicit float8_e4m3(float f32) : float8_e4m3(ConvertFrom(f32)) {}
explicit float8_e4m3(Eigen::bfloat16 bf16) : float8_e4m3(ConvertFrom(bf16)) {}
explicit float8_e4m3(Eigen::half f16) : float8_e4m3(ConvertFrom(f16)) {}
explicit float8_e4m3(const float8_e5m2& f8) : float8_e4m3(ConvertFrom(f8)) {}

template <typename T,
typename EnableIf = std::enable_if<std::is_arithmetic<T>::value> >
explicit operator T() const {
return static_cast<T>(static_cast<float>(*this));
}
explicit operator double() const { return ConvertTo<double>(*this); }
explicit operator float() const { return ConvertTo<float>(*this); }
explicit operator Eigen::bfloat16() const {
Expand All @@ -171,6 +180,7 @@ class float8_e4m3 : public float8_base<float8_e4m3> {
explicit operator Eigen::half() const {
return ConvertTo<Eigen::half>(*this);
}
explicit operator bool() const { return (rep() & 0x7F) != 0; }
};

class float8_e5m2 : public float8_base<float8_e5m2> {
Expand All @@ -185,12 +195,21 @@ class float8_e5m2 : public float8_base<float8_e5m2> {

public:
float8_e5m2() : Base() {}

template <typename T,
typename EnableIf = std::enable_if<std::is_arithmetic<T>::value> >
explicit float8_e5m2(T f) : float8_e5m2(ConvertFrom(static_cast<float>(f))) {}
explicit float8_e5m2(double f64) : float8_e5m2(ConvertFrom(f64)) {}
explicit float8_e5m2(float f32) : float8_e5m2(ConvertFrom(f32)) {}
explicit float8_e5m2(Eigen::bfloat16 bf16) : float8_e5m2(ConvertFrom(bf16)) {}
explicit float8_e5m2(Eigen::half f16) : float8_e5m2(ConvertFrom(f16)) {}
explicit float8_e5m2(float8_e4m3 f8) : float8_e5m2(ConvertFrom(f8)) {}

template <typename T,
typename EnableIf = std::enable_if<std::is_arithmetic<T>::value> >
explicit operator T() const {
return static_cast<T>(static_cast<float>(*this));
}
explicit operator double() const { return ConvertTo<double>(*this); }
explicit operator float() const { return ConvertTo<float>(*this); }
explicit operator Eigen::bfloat16() const {
Expand All @@ -199,6 +218,7 @@ class float8_e5m2 : public float8_base<float8_e5m2> {
explicit operator Eigen::half() const {
return ConvertTo<Eigen::half>(*this);
}
explicit operator bool() const { return (rep() & 0x7F) != 0; }
};

// Structures for use in specializing std::numeric_limits.
Expand Down
38 changes: 38 additions & 0 deletions tensorflow/tsl/platform/float8_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <cmath>
#include <limits>
#include <string>
#include <utility>

#include "tensorflow/tsl/platform/test.h"

Expand Down Expand Up @@ -381,5 +382,42 @@ TYPED_TEST(Float8Test, CallTheOperator) {
}
}

// Helper utility for prettier test names.
struct Float8CastTestParamNames {
template <typename TypeParam>
static std::string GetName(int idx) {
using first_type = typename TypeParam::first_type;
using second_type = typename TypeParam::second_type;
return absl::StrCat(::testing::internal::GetTypeName<first_type>(), "_",
::testing::internal::GetTypeName<second_type>());
}
};

using Float8CastTypePairs = ::testing::Types<
std::pair<float8_e4m3, long double>, std::pair<float8_e4m3, float>,
std::pair<float8_e4m3, Eigen::bfloat16>,
std::pair<float8_e4m3, Eigen::half>, std::pair<float8_e4m3, bool>,
std::pair<float8_e4m3, int32_t>, std::pair<float8_e4m3, int64_t>,
std::pair<float8_e5m2, long double>, std::pair<float8_e5m2, float>,
std::pair<float8_e5m2, Eigen::bfloat16>,
std::pair<float8_e5m2, Eigen::half>, std::pair<float8_e5m2, bool>,
std::pair<float8_e5m2, int32_t>, std::pair<float8_e5m2, int64_t> >;

template <typename CastPair>
class Float8CastTest : public ::testing::Test {};
TYPED_TEST_SUITE(Float8CastTest, Float8CastTypePairs, Float8CastTestParamNames);

TYPED_TEST(Float8CastTest, CastThroughFloat) {
using Float8 = typename TypeParam::first_type;
using DestType = typename TypeParam::second_type;

for (int i = 0x00; i <= 0xFF; ++i) {
Float8 f8 = Float8::FromRep(i);
DestType dest = static_cast<DestType>(f8);
DestType expected = static_cast<DestType>(static_cast<float>(f8));
EXPECT_THAT(dest, EqOrIsNaN(expected));
}
}

} // namespace
} // namespace tsl

0 comments on commit a7ff2f6

Please sign in to comment.