Skip to content

Commit

Permalink
perf(math): add dedicated pow 3, 5, 7 operations for Packed(Baby|Koal…
Browse files Browse the repository at this point in the history
  • Loading branch information
chokobole committed Oct 31, 2024
1 parent aaf5896 commit 2e0f7f0
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 18 deletions.
45 changes: 34 additions & 11 deletions tachyon/math/base/semigroups.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,24 @@
std::declval<T>().Name##InPlace()))> \
: std::true_type {}

// Declares the detection trait |SupportsExp##Pow<T>| (e.g. SupportsExp3).
// The trait is std::true_type when |std::declval<T>().Exp##Pow()| is a
// well-formed expression, i.e. T exposes a dedicated member for that fixed
// power, and std::false_type otherwise.
//
// NOTE: The expansion intentionally does not end with a semicolon so that the
// invocation site supplies it (`SUPPORTS_DEDICATED_EXP_OPERATOR(3);`), which
// avoids a stray empty declaration at namespace scope.
#define SUPPORTS_DEDICATED_EXP_OPERATOR(Pow)                               \
  template <typename T, typename = void>                                   \
  struct SupportsExp##Pow : std::false_type {};                            \
                                                                           \
  template <typename T>                                                    \
  struct SupportsExp##Pow<T, decltype(void(std::declval<T>().Exp##Pow()))> \
      : std::true_type {}

namespace tachyon::math {
namespace internal {

// Detection traits generated by the SUPPORTS_* macros for the operation named
// in each argument.
SUPPORTS_BINARY_OPERATOR(Mul);
SUPPORTS_UNARY_OPERATOR(SquareImpl);
SUPPORTS_BINARY_OPERATOR(Add);
SUPPORTS_UNARY_OPERATOR(DoubleImpl);
// SupportsExp3 / SupportsExp5 / SupportsExp7: true when T provides a dedicated
// Exp3() / Exp5() / Exp7() member. MultiplicativeSemigroup consults these when
// dispatching on a compile-time power of 3, 5 or 7.
SUPPORTS_DEDICATED_EXP_OPERATOR(3);
SUPPORTS_DEDICATED_EXP_OPERATOR(5);
SUPPORTS_DEDICATED_EXP_OPERATOR(7);

// Primary template of the SupportsSize trait: evaluates to std::false_type for
// any T that is not matched by a more specialized partial specialization.
template <typename T, typename = void>
struct SupportsSize : std::false_type {};
Expand Down Expand Up @@ -157,24 +168,36 @@ class MultiplicativeSemigroup {
return g;
else if constexpr (Power == 2)
return Square();
else if constexpr (Power == 3)
return Square() * g;
else if constexpr (Power == 4)
else if constexpr (Power == 3) {
if constexpr (internal::SupportsExp3<G>::value) {
return g.Exp3();
} else {
return Square() * g;
}
} else if constexpr (Power == 4) {
return Square().Square();
else if constexpr (Power == 5) {
MulResult g4 = Square();
g4.SquareInPlace();
return g4 * g;
} else if constexpr (Power == 5) {
if constexpr (internal::SupportsExp5<G>::value) {
return g.Exp5();
} else {
MulResult g4 = Square();
g4.SquareInPlace();
return g4 * g;
}
} else if constexpr (Power == 6) {
MulResult g2 = Square();
MulResult g4 = g2;
g4.SquareInPlace();
return g4 * g2;
} else if constexpr (Power == 7) {
MulResult g2 = Square();
MulResult g4 = g2;
g4.SquareInPlace();
return g4 * g2 * g;
if constexpr (internal::SupportsExp7<G>::value) {
return g.Exp7();
} else {
MulResult g2 = Square();
MulResult g4 = g2;
g4.SquareInPlace();
return g4 * g2 * g;
}
} else {
return DoPow(BigInt<1>(Power));
}
Expand Down
5 changes: 4 additions & 1 deletion tachyon/math/finite_fields/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ tachyon_cc_library(
# Header-only target for packed_prime_field32_avx2.h. The callback dependency
# backs the header's use of base::RepeatingCallback.
tachyon_cc_library(
    name = "packed_prime_field32_avx2",
    hdrs = ["packed_prime_field32_avx2.h"],
    deps = [
        "//tachyon/base:compiler_specific",
        "//tachyon/base/functional:callback",
    ],
)

tachyon_cc_library(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,16 @@ PackedBabyBearAVX2 PackedBabyBearAVX2::Mul(
return FromVector(math::Mul(ToVector(*this), ToVector(other)));
}

// Returns the third power of this packed value by forwarding the underlying
// vector, together with kP and kInv, to math::Exp3.
PackedBabyBearAVX2 PackedBabyBearAVX2::Exp3() const {
  const auto cubed = math::Exp3(ToVector(*this), kP, kInv);
  return FromVector(cubed);
}

// Returns the fifth power of this packed value by forwarding the underlying
// vector, together with kP and kInv, to math::Exp5.
PackedBabyBearAVX2 PackedBabyBearAVX2::Exp5() const {
  const auto fifth_power = math::Exp5(ToVector(*this), kP, kInv);
  return FromVector(fifth_power);
}

// Returns the seventh power of this packed value by forwarding the underlying
// vector, together with kP and kInv, to math::Exp7.
PackedBabyBearAVX2 PackedBabyBearAVX2::Exp7() const {
  const auto seventh_power = math::Exp7(ToVector(*this), kP, kInv);
  return FromVector(seventh_power);
}

} // namespace tachyon::math
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ class TACHYON_EXPORT PackedBabyBearAVX2 final

// MultiplicativeSemigroup methods
PackedBabyBearAVX2 Mul(const PackedBabyBearAVX2& other) const;

// Dedicated fixed-power operations (x³, x⁵, x⁷). Each definition forwards to
// math::Exp3 / math::Exp5 / math::Exp7 respectively.
// NOTE(review): presumably picked up through internal::SupportsExp{3,5,7} by
// the generic fixed-power dispatch — confirm against the base class.
PackedBabyBearAVX2 Exp3() const;
PackedBabyBearAVX2 Exp5() const;
PackedBabyBearAVX2 Exp7() const;
};

} // namespace tachyon::math
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,16 @@ PackedKoalaBearAVX2 PackedKoalaBearAVX2::Mul(
return FromVector(math::Mul(ToVector(*this), ToVector(other)));
}

// Returns the third power of this packed value by forwarding the underlying
// vector, together with kP and kInv, to math::Exp3.
PackedKoalaBearAVX2 PackedKoalaBearAVX2::Exp3() const {
  const auto cubed = math::Exp3(ToVector(*this), kP, kInv);
  return FromVector(cubed);
}

// Returns the fifth power of this packed value by forwarding the underlying
// vector, together with kP and kInv, to math::Exp5.
PackedKoalaBearAVX2 PackedKoalaBearAVX2::Exp5() const {
  const auto fifth_power = math::Exp5(ToVector(*this), kP, kInv);
  return FromVector(fifth_power);
}

// Returns the seventh power of this packed value by forwarding the underlying
// vector, together with kP and kInv, to math::Exp7.
PackedKoalaBearAVX2 PackedKoalaBearAVX2::Exp7() const {
  const auto seventh_power = math::Exp7(ToVector(*this), kP, kInv);
  return FromVector(seventh_power);
}

} // namespace tachyon::math
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ class TACHYON_EXPORT PackedKoalaBearAVX2 final

// MultiplicativeSemigroup methods
PackedKoalaBearAVX2 Mul(const PackedKoalaBearAVX2& other) const;

// Dedicated fixed-power operations (x³, x⁵, x⁷). Each definition forwards to
// math::Exp3 / math::Exp5 / math::Exp7 respectively.
// NOTE(review): presumably picked up through internal::SupportsExp{3,5,7} by
// the generic fixed-power dispatch — confirm against the base class.
PackedKoalaBearAVX2 Exp3() const;
PackedKoalaBearAVX2 Exp5() const;
PackedKoalaBearAVX2 Exp7() const;
};

} // namespace tachyon::math
Expand Down
138 changes: 132 additions & 6 deletions tachyon/math/finite_fields/packed_prime_field32_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <immintrin.h>

#include "tachyon/base/compiler_specific.h"
#include "tachyon/base/functional/callback.h"

namespace tachyon::math {

Expand Down Expand Up @@ -171,11 +172,59 @@ ALWAYS_INLINE __m256i NegateMod32(__m256i val, __m256i p) {
//
// [1] Modern Computer Arithmetic, Richard Brent and Paul Zimmermann,
// Cambridge University Press, 2010, algorithm 2.7.
ALWAYS_INLINE __m256i MontyD(__m256i lhs, __m256i rhs, __m256i p, __m256i inv) {
__m256i prod = _mm256_mul_epu32(lhs, rhs);
__m256i q = _mm256_mul_epu32(prod, inv);

// We provide 2 variants of Montgomery reduction depending on if the inputs are
// unsigned or signed. The unsigned variant follows steps 1 and 2 in the above
// protocol to produce D in (-P, ..., P). For the signed variant we assume -PB/2
// < C < PB/2 and let Q := μ C mod B be the unique representative in [-B/2, ...,
// B/2 - 1]. The division in step 2 is clearly still exact and |C - Q P| <= |C|
// + |Q||P| < PB so D still lies in (-P, ..., P).

// Perform a partial Montgomery reduction on each 64 bit element.
// Input must lie in {0, ..., 2³²P}.
// The output will lie in {-P, ..., P} and be stored in the upper 32 bits.
//
// The quotient q is formed with unsigned 32x32 -> 64 bit multiplies of the low
// halves of each 64 bit element (|input| * |inv|, then q * |p|), and q * p is
// subtracted from |input|.
ALWAYS_INLINE __m256i PartialMontyRedUnsignedToSigned(__m256i input, __m256i p,
                                                      __m256i inv) {
  __m256i q = _mm256_mul_epu32(input, inv);
  __m256i q_p = _mm256_mul_epu32(q, p);
  // By construction, the bottom 32 bits of input and q_p are equal.
  // Thus |_mm256_sub_epi32| and |_mm256_sub_epi64| should act identically.
  // However for some reason, the compiler gets confused if we use
  // |_mm256_sub_epi64| and outputs a load of nonsense, see:
  // https://godbolt.org/z/3W8M7Tv84.
  return _mm256_sub_epi32(input, q_p);
}
// Signed flavour of the partial Montgomery reduction, applied to each 64 bit
// element.
// Precondition: input lies in {-2³¹P, ..., 2³¹P}.
// Postcondition: the result lies in {-P, ..., P} and occupies the upper 32
// bits of each element.
ALWAYS_INLINE __m256i PartialMontyRedSignedToSigned(__m256i input, __m256i p,
                                                    __m256i inv) {
  const __m256i quotient = _mm256_mul_epi32(input, inv);
  const __m256i quotient_times_p = _mm256_mul_epi32(quotient, p);
  // For the signed case the compiler emits essentially the same code for
  // |_mm256_sub_epi32| and |_mm256_sub_epi64|; |_mm256_sub_epi32| is kept so
  // that both reductions read the same.
  return _mm256_sub_epi32(input, quotient_times_p);
}

// Multiply the field elements in the even index entries.
// |lhs[2i]|, |rhs[2i]| must be unsigned 32-bit integers such that
// |lhs[2i]| * |rhs[2i]| lies in {0, ..., 2³²P}.
// The output will lie in {-P, ..., P} and be stored in |output[2i + 1]|.
ALWAYS_INLINE __m256i MontyMul(__m256i lhs, __m256i rhs, __m256i p,
                               __m256i inv) {
  __m256i prod = _mm256_mul_epu32(lhs, rhs);
  // The product is unsigned and may be as large as 2³²P, so it has to go
  // through the unsigned reduction: the signed one only accepts inputs in
  // {-2³¹P, ..., 2³¹P} and would reinterpret the quotient as a signed value.
  return PartialMontyRedUnsignedToSigned(prod, p, inv);
}

// Signed Montgomery multiplication of the even index entries.
// Precondition: |lhs[2i]| and |rhs[2i]| are signed 32-bit integers whose
// product lies in {-2³¹P, ..., 2³¹P}.
// Postcondition: the result lies in {-P, ..., P} and is stored in
// |output[2i + 1]|.
ALWAYS_INLINE __m256i MontyMulSigned(__m256i lhs, __m256i rhs, __m256i p,
                                     __m256i inv) {
  // Signed 32x32 -> 64 bit product, fed straight into the signed reduction.
  return PartialMontyRedSignedToSigned(_mm256_mul_epi32(lhs, rhs), p, inv);
}

ALWAYS_INLINE __m256i movehdup_epi32(__m256i x) {
Expand Down Expand Up @@ -210,8 +259,8 @@ ALWAYS_INLINE __m256i MontMulMod32(__m256i lhs, __m256i rhs, __m256i p,
__m256i lhs_odd = movehdup_epi32(lhs);
__m256i rhs_odd = movehdup_epi32(rhs);

__m256i d_evn = MontyD(lhs_evn, rhs_evn, p, inv);
__m256i d_odd = MontyD(lhs_odd, rhs_odd, p, inv);
__m256i d_evn = MontyMul(lhs_evn, rhs_evn, p, inv);
__m256i d_odd = MontyMul(lhs_odd, rhs_odd, p, inv);

__m256i d_evn_hi = movehdup_epi32(d_evn);
__m256i t = _mm256_blend_epi32(d_evn_hi, d_odd, 0b10101010);
Expand All @@ -220,6 +269,83 @@ ALWAYS_INLINE __m256i MontMulMod32(__m256i lhs, __m256i rhs, __m256i p,
return _mm256_min_epu32(t, u);
}

// Square the field elements in the even index entries.
// Inputs must be signed 32-bit integers.
// Outputs will be a signed integer in (-P, ..., P) copied into both the even
// and odd indices.
ALWAYS_INLINE __m256i ShiftedSquare(__m256i input, __m256i p, __m256i inv) {
  // No restriction on the size of |input[i]²| is needed: 2³⁰ < P and
  // |i32| <= 2³¹, so 0 <= |input[i]²| <= 2⁶² < 2³²P.
  __m256i square = _mm256_mul_epi32(input, input);
  // The square is non-negative and bounded by 2³²P, which is exactly the input
  // range of the unsigned reduction. The signed reduction only accepts
  // {-2³¹P, ..., 2³¹P}, which 2⁶² can exceed.
  __m256i square_red = PartialMontyRedUnsignedToSigned(square, p, inv);
  // Copy the reduced value from the odd (upper) index into the even index.
  return movehdup_epi32(square_red);
}

// Runs |callback| once on the even-index entries of |input| and once on the
// odd-index entries, then merges both results into one canonical vector.
//
// Contract for |callback|:
//   - it may only depend on the 32 bit entries in the even indices;
//   - its result must lie in (-P, ..., P) and be stored in the odd indices
//     (the even indices of its result are never read).
// |input| must satisfy whatever |callback| itself requires.
//
// NOTE(chokobole): This is to suppress the error below.
// clang-format off
// error: ignoring attributes on template argument '__m256i(__m256i, __m256i, __m256i)' [-Werror=ignored-attributes]
// clang-format on
// NOTE(review): base::RepeatingCallback is presumably type-erased, which may
// prevent |callback| from being inlined here — confirm.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-attributes"
ALWAYS_INLINE __m256i ApplyFuncToEvenOdd(
    __m256i input, __m256i p, __m256i inv,
    base::RepeatingCallback<__m256i(__m256i, __m256i, __m256i)> callback) {
  // The even entries are already in place; the odd entries are copied down
  // into the even positions so |callback| can see them.
  const __m256i odd_as_even = movehdup_epi32(input);
  const __m256i from_even = callback.Run(input, p, inv);
  const __m256i from_odd = callback.Run(odd_as_even, p, inv);
  // Interleave: even positions take the even-pass result (moved down from its
  // odd slot), odd positions take the odd-pass result.
  const __m256i merged =
      _mm256_blend_epi32(movehdup_epi32(from_even), from_odd, 0b10101010);
  // Map (-P, ..., P) to a non-negative representative: the unsigned minimum
  // of x and x + P.
  const __m256i merged_plus_p = _mm256_add_epi32(merged, p);
  return _mm256_min_epu32(merged, merged_plus_p);
}
#pragma GCC diagnostic pop

// Computes the third power of the field elements held in the even index
// entries.
// Precondition: inputs are signed 32-bit integers in [-P, ..., P].
// Postcondition: each result is a signed integer in (-P, ..., P), stored in
// the odd indices.
ALWAYS_INLINE __m256i DoExp3(__m256i input, __m256i p, __m256i inv) {
  // x³ = x² * x.
  return MontyMulSigned(ShiftedSquare(input, p, inv), input, p, inv);
}

ALWAYS_INLINE __m256i Exp3(__m256i input, __m256i p, __m256i inv) {
  // Cube every entry by running DoExp3 over the even and the odd entries.
  const __m256i cubed = ApplyFuncToEvenOdd(input, p, inv, &DoExp3);
  return cubed;
}

// Computes the fifth power of the field elements held in the even index
// entries.
// Precondition: inputs are signed 32-bit integers in [-P, ..., P].
// Postcondition: each result is a signed integer in (-P, ..., P), stored in
// the odd indices.
ALWAYS_INLINE __m256i DoExp5(__m256i input, __m256i p, __m256i inv) {
  const __m256i x2 = ShiftedSquare(input, p, inv);
  const __m256i x4 = ShiftedSquare(x2, p, inv);
  // x⁵ = x⁴ * x.
  return MontyMulSigned(x4, input, p, inv);
}

ALWAYS_INLINE __m256i Exp5(__m256i input, __m256i p, __m256i inv) {
  // Raise every entry to the fifth power by running DoExp5 over the even and
  // the odd entries.
  const __m256i fifth_power = ApplyFuncToEvenOdd(input, p, inv, &DoExp5);
  return fifth_power;
}

// Computes the seventh power of the field elements held in the even index
// entries.
// Precondition: inputs lie in [-P, ..., P].
// Postcondition: each result lies in (-P, ..., P) and is stored in the odd
// indices.
ALWAYS_INLINE __m256i DoExp7(__m256i input, __m256i p, __m256i inv) {
  const __m256i x2 = ShiftedSquare(input, p, inv);
  const __m256i x4 = ShiftedSquare(x2, p, inv);
  // x³ comes back in the odd indices; copy it into the even indices so it can
  // serve as a multiplication operand.
  const __m256i x3 = movehdup_epi32(MontyMulSigned(x2, input, p, inv));
  // x⁷ = x⁴ * x³.
  return MontyMulSigned(x4, x3, p, inv);
}

ALWAYS_INLINE __m256i Exp7(__m256i input, __m256i p, __m256i inv) {
  // Raise every entry to the seventh power by running DoExp7 over the even and
  // the odd entries.
  const __m256i seventh_power = ApplyFuncToEvenOdd(input, p, inv, &DoExp7);
  return seventh_power;
}

} // namespace tachyon::math

#endif // TACHYON_MATH_FINITE_FIELDS_PACKED_PRIME_FIELD32_AVX2_H_

0 comments on commit 2e0f7f0

Please sign in to comment.