Skip to content

Commit

Permalink
Add float8 operators.
Browse files Browse the repository at this point in the history
These are generally useful, and are necessary for basic numeric type
registration in Python.  Most operators go through float32.

PiperOrigin-RevId: 486225906
  • Loading branch information
cantonios authored and tensorflower-gardener committed Nov 4, 2022
1 parent c503bc5 commit 5e418c3
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 7 deletions.
78 changes: 71 additions & 7 deletions tensorflow/tsl/platform/float8.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,22 @@ class float8_base {
constexpr uint8_t rep() const { return rep_; }

constexpr Derived operator-() const {
return Derived(static_cast<uint8_t>(rep_ ^ 0x80), ConstructFromRepTag{});
return Derived(static_cast<uint8_t>(rep() ^ 0x80), ConstructFromRepTag{});
}

constexpr bool operator==(const Derived& other) const {
if (Eigen::numext::isnan(derived())) {
return false;
} else if ((rep() & 0x7F) == 0) {
return (other.rep() & 0x7F) == 0;
}
return rep() == other.rep();
}

constexpr bool operator!=(const Derived& other) const {
return !(derived() == other);
}

constexpr const Derived& derived() const {
return *static_cast<const Derived*>(this);
}
Expand All @@ -71,6 +77,70 @@ class float8_base {
template <typename To, bool kSaturate = false, bool kTruncate = false>
static EIGEN_DEVICE_FUNC To ConvertTo(const Derived& from);

// Operators via float32.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived
operator+(const Derived& other) const {
return Derived{float{derived()} + float{other}};
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived
operator-(const Derived& other) const {
return Derived{float{derived()} - float{other}};
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived
operator*(const Derived& other) const {
return Derived{float{derived()} * float{other}};
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived
operator/(const Derived& other) const {
return Derived{float{derived()} / float{other}};
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(
const Derived& other) const {
return float{derived()} < float{other};
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(
const Derived& other) const {
return float{derived()} <= float{other};
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const Derived& other) {
return float{derived()} > float{other};
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const Derived& other) {
return float{derived()} >= float{other};
}

// Compound assignment.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived& operator+=(
const Derived& other) {
derived() = derived() + other;
return derived();
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived& operator-=(
const Derived& other) {
derived() = derived() - other;
return derived();
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived& operator*=(
const Derived& other) {
derived() = derived() * other;
return derived();
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived& operator/=(
const Derived& other) {
derived() = derived() / other;
return derived();
}

private:
uint8_t rep_;
};
Expand Down Expand Up @@ -101,9 +171,6 @@ class float8_e4m3 : public float8_base<float8_e4m3> {
explicit operator Eigen::half() const {
return ConvertTo<Eigen::half>(*this);
}

using Base::operator==;
using Base::operator-;
};

class float8_e5m2 : public float8_base<float8_e5m2> {
Expand Down Expand Up @@ -132,9 +199,6 @@ class float8_e5m2 : public float8_base<float8_e5m2> {
explicit operator Eigen::half() const {
return ConvertTo<Eigen::half>(*this);
}

using Base::operator==;
using Base::operator-;
};

// Structures for use in specializing std::numeric_limits.
Expand Down
39 changes: 39 additions & 0 deletions tensorflow/tsl/platform/float8_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -342,5 +342,44 @@ TEST(Float8Test, Half_To_Float8E5m2) {
0xBF);
}

using ::testing::Eq;
using ::testing::IsTrue;
MATCHER_P(EqOrIsNaN, other, "") {
if (Eigen::numext::isnan(other)) {
return ExplainMatchResult(IsTrue(), Eigen::numext::isnan(arg),
result_listener);
}
return ExplainMatchResult(Eq(other), arg, result_listener);
}

TYPED_TEST(Float8Test, CallTheOperator) {
using Float8 = TypeParam;

for (int i = 0x00; i <= 0xFF; ++i) {
Float8 a = Float8::FromRep(i);
for (int j = 0x00; j <= 0xFF; ++j) {
Float8 b = Float8::FromRep(j);

EXPECT_THAT(a + b, EqOrIsNaN(Float8{float{a} + float{b}}));
EXPECT_THAT(a - b, EqOrIsNaN(Float8{float{a} - float{b}}));
EXPECT_THAT(a * b, EqOrIsNaN(Float8{float{a} * float{b}}));
EXPECT_THAT(a / b, EqOrIsNaN(Float8{float{a} / float{b}}));

Float8 c;
EXPECT_THAT((c = a, c += b), EqOrIsNaN(Float8{float{a} + float{b}}));
EXPECT_THAT((c = a, c -= b), EqOrIsNaN(Float8{float{a} - float{b}}));
EXPECT_THAT((c = a, c *= b), EqOrIsNaN(Float8{float{a} * float{b}}));
EXPECT_THAT((c = a, c /= b), EqOrIsNaN(Float8{float{a} / float{b}}));

EXPECT_EQ(a == b, float{a} == float{b}) << float{a} << " vs " << float{b};
EXPECT_EQ(a != b, float{a} != float{b});
EXPECT_EQ(a < b, float{a} < float{b});
EXPECT_EQ(a <= b, float{a} <= float{b});
EXPECT_EQ(a > b, float{a} > float{b});
EXPECT_EQ(a >= b, float{a} >= float{b});
}
}
}

} // namespace
} // namespace tsl

0 comments on commit 5e418c3

Please sign in to comment.