Skip to content

Commit

Permalink
Merge pull request #72 from kroma-network/chore/add-math-base-comments
Browse files Browse the repository at this point in the history
chore(math): add comments to math base files
  • Loading branch information
Merlyn authored Oct 5, 2023
2 parents 14276f7 + 9af55f4 commit 517b24a
Show file tree
Hide file tree
Showing 16 changed files with 167 additions and 48 deletions.
2 changes: 1 addition & 1 deletion benchmark/msm/msm_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ bool MSMConfig::Parse(int argc, char** argv,
parser.AddFlag<base::Flag<std::vector<uint64_t>>>(&degrees_)
.set_short_name("-n")
.set_required()
.set_help("Specify the exponent 'n' where the number of points to test is 2^n.");
.set_help("Specify the exponent 'n' where the number of points to test is 2ⁿ.");
// clang-format on
parser.AddFlag<base::BoolFlag>(&check_results_)
.set_long_name("--check_results")
Expand Down
4 changes: 2 additions & 2 deletions tachyon/base/bits.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ ALWAYS_INLINE constexpr
: bits;
}

// Returns the integer i such as 2^i <= n < 2^(i+1).
// Returns the integer i such as 2ⁱ <= n < 2ⁱ⁺¹.
//
// There is a common `BitLength` function, which returns the number of bits
// required to represent a value. Rather than implement that function,
Expand All @@ -116,7 +116,7 @@ ALWAYS_INLINE constexpr
// TODO(pkasting): When C++20 is available, replace with std::bit_xxx().
constexpr int Log2Floor(uint32_t n) { return 31 - CountLeadingZeroBits(n); }

// Returns the integer i such as 2^(i-1) < n <= 2^i.
// Returns the integer i such as 2ⁱ⁻¹ < n <= 2ⁱ.
constexpr int Log2Ceiling(uint32_t n) {
// When n == 0, we want the function to return -1.
// When n == 0, (n - 1) will underflow to 0xFFFFFFFF, which is
Expand Down
2 changes: 1 addition & 1 deletion tachyon/base/memory/aligned_memory_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ TEST(AlignedMemoryTest, IsAligned) {

// Walk back down all lower powers of two checking alignment.
for (int j = i - 1; j >= 0; --j) {
// n is aligned on all powers of two less than or equal to 2^i.
// n is aligned on all powers of two less than or equal to 2ⁱ.
EXPECT_TRUE(IsAligned(n, n >> j))
<< "Expected " << n << " to be " << (n >> j) << " aligned";

Expand Down
2 changes: 1 addition & 1 deletion tachyon/base/numerics/math_constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ constexpr float kRadToDegFloat = 57.295779513082320876798f;
constexpr double kSqrtHalfDouble = 0.70710678118654752440;
constexpr float kSqrtHalfFloat = 0.70710678118654752440f;

// The mean acceleration due to gravity on Earth in m/s^2.
// The mean acceleration due to gravity on Earth in m/s².
constexpr double kMeanGravityDouble = 9.80665;
constexpr float kMeanGravityFloat = 9.80665f;

Expand Down
24 changes: 12 additions & 12 deletions tachyon/base/ranges/algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@ constexpr bool equal(Range1&& range1,
// `last1 - first1` applications of the corresponding predicate and projections
// if `ranges::equal(first1, last1, first2, last2, pred, proj, proj)` would
// return true;
// otherwise, at worst `O(N^2)`, where `N` has the value `last1 - first1`.
// otherwise, at worst `O(N²)`, where `N` has the value `last1 - first1`.
//
// Reference:
// https://wg21.link/alg.is.permutation#:~:text=ranges::is_permutation(I1
Expand Down Expand Up @@ -1038,7 +1038,7 @@ constexpr bool is_permutation(ForwardIterator1 first1,
// `size(range1) != size(range2)`. Otherwise, exactly `size(range1)`
// applications of the corresponding predicate and projections if
// `ranges::equal(range1, range2, pred, proj, proj)` would return true;
// otherwise, at worst `O(N^2)`, where `N` has the value `size(range1)`.
// otherwise, at worst `O(N²)`, where `N` has the value `size(range1)`.
//
// Reference:
// https://wg21.link/alg.is.permutation#:~:text=ranges::is_permutation(R1
Expand Down Expand Up @@ -2726,7 +2726,7 @@ constexpr auto sort(Range&& range, Comp comp = {}, Proj proj = {}) {
// Returns: `last`.
//
// Complexity: Let `N` be `last - first`. If enough extra memory is available,
// `N log (N)` comparisons. Otherwise, at most `N log^2 (N)` comparisons. In
// `N log (N)` comparisons. Otherwise, at most `N log² (N)` comparisons. In
// either case, twice as many projections as the number of comparisons.
//
// Remarks: Stable.
Expand All @@ -2753,7 +2753,7 @@ constexpr auto stable_sort(RandomAccessIterator first,
// Returns: `end(rang)`.
//
// Complexity: Let `N` be `size(range)`. If enough extra memory is available,
// `N log (N)` comparisons. Otherwise, at most `N log^2 (N)` comparisons. In
// `N log (N)` comparisons. Otherwise, at most `N log² (N)` comparisons. In
// either case, twice as many projections as the number of comparisons.
//
// Remarks: Stable.
Expand Down Expand Up @@ -3097,7 +3097,7 @@ constexpr auto nth_element(Range&& range,
// for every iterator `j` in the range `[first, i)`,
// `bool(invoke(comp, invoke(proj, *j), value))` is true.
//
// Complexity: At most `log_2(last - first) + O(1)` comparisons and projections.
// Complexity: At most `log₂(last - first) + O(1)` comparisons and projections.
//
// Reference: https://wg21.link/lower.bound#:~:text=ranges::lower_bound(I
template <typename ForwardIterator,
Expand Down Expand Up @@ -3125,7 +3125,7 @@ constexpr auto lower_bound(ForwardIterator first,
// `[begin(range), end(range)]` such that for every iterator `j` in the range
// `[begin(range), i)`, `bool(invoke(comp, invoke(proj, *j), value))` is true.
//
// Complexity: At most `log_2(size(range)) + O(1)` comparisons and projections.
// Complexity: At most `log₂(size(range)) + O(1)` comparisons and projections.
//
// Reference: https://wg21.link/lower.bound#:~:text=ranges::lower_bound(R
template <typename Range,
Expand All @@ -3151,7 +3151,7 @@ constexpr auto lower_bound(Range&& range,
// for every iterator `j` in the range `[first, i)`,
// `!bool(invoke(comp, value, invoke(proj, *j)))` is true.
//
// Complexity: At most `log_2(last - first) + O(1)` comparisons and projections.
// Complexity: At most `log₂(last - first) + O(1)` comparisons and projections.
//
// Reference: https://wg21.link/upper.bound#:~:text=ranges::upper_bound(I
template <typename ForwardIterator,
Expand Down Expand Up @@ -3179,7 +3179,7 @@ constexpr auto upper_bound(ForwardIterator first,
// `[begin(range), end(range)]` such that for every iterator `j` in the range
// `[begin(range), i)`, `!bool(invoke(comp, value, invoke(proj, *j)))` is true.
//
// Complexity: At most `log_2(size(range)) + O(1)` comparisons and projections.
// Complexity: At most `log₂(size(range)) + O(1)` comparisons and projections.
//
// Reference: https://wg21.link/upper.bound#:~:text=ranges::upper_bound(R
template <typename Range,
Expand All @@ -3205,7 +3205,7 @@ constexpr auto upper_bound(Range&& range,
// Returns: `{ranges::lower_bound(first, last, value, comp, proj),
// ranges::upper_bound(first, last, value, comp, proj)}`.
//
// Complexity: At most 2 ∗ log_2(last - first) + O(1) comparisons and
// Complexity: At most 2 ∗ log₂(last - first) + O(1) comparisons and
// projections.
//
// Reference: https://wg21.link/equal.range#:~:text=ranges::equal_range(I
Expand Down Expand Up @@ -3233,7 +3233,7 @@ constexpr auto equal_range(ForwardIterator first,
// Returns: `{ranges::lower_bound(range, value, comp, proj),
// ranges::upper_bound(range, value, comp, proj)}`.
//
// Complexity: At most 2 ∗ log_2(size(range)) + O(1) comparisons and
// Complexity: At most 2 ∗ log₂(size(range)) + O(1) comparisons and
// projections.
//
// Reference: https://wg21.link/equal.range#:~:text=ranges::equal_range(R
Expand Down Expand Up @@ -3261,7 +3261,7 @@ constexpr auto equal_range(Range&& range,
// `[first, last)`, `!bool(invoke(comp, invoke(proj, *i), value)) &&
// !bool(invoke(comp, value, invoke(proj, *i)))` is true.
//
// Complexity: At most `log_2(last - first) + O(1)` comparisons and projections.
// Complexity: At most `log₂(last - first) + O(1)` comparisons and projections.
//
// Reference: https://wg21.link/binary.search#:~:text=ranges::binary_search(I
template <typename ForwardIterator,
Expand All @@ -3287,7 +3287,7 @@ constexpr auto binary_search(ForwardIterator first,
// `!bool(invoke(comp, invoke(proj, *i), value)) &&
// !bool(invoke(comp, value, invoke(proj, *i)))` is true.
//
// Complexity: At most `log_2(size(range)) + O(1)` comparisons and projections.
// Complexity: At most `log₂(size(range)) + O(1)` comparisons and projections.
//
// Reference: https://wg21.link/binary.search#:~:text=ranges::binary_search(R
template <typename Range,
Expand Down
2 changes: 1 addition & 1 deletion tachyon/base/time/time_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ TEST(TimeDelta, MaxConversions) {

static_assert(
Microseconds(max_d).is_max(),
"Make sure that 2^63 correctly gets clamped to `max` (crbug.com/612601)");
"Make sure that 2⁶³ correctly gets clamped to `max` (crbug.com/612601)");

static_assert(Milliseconds(std::numeric_limits<double>::infinity()).is_max());

Expand Down
4 changes: 2 additions & 2 deletions tachyon/crypto/hashes/sponge/poseidon/poseidon.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ struct PoseidonSponge

void ApplySBox(bool is_full_round) {
if (is_full_round) {
// Full rounds apply the S-Box (x^alpha) to every element of |state|.
// Full rounds apply the S-Box (xᵅ) to every element of |state|.
for (F& elem : state.elements) {
elem = elem.Pow(math::BigInt<1>(config.alpha));
}
} else {
// Partial rounds apply the S-Box (x^alpha) to just the first element of
// Partial rounds apply the S-Box (xᵅ) to just the first element of
// |state|.
state[0] = state[0].Pow(math::BigInt<1>(config.alpha));
}
Expand Down
58 changes: 58 additions & 0 deletions tachyon/math/base/big_int.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ constexpr size_t LimbsAlignment(size_t x) {

} // namespace internal

// BigInt is a fixed size array of uint64_t, capable of holding up to |N| limbs,
// designed to support a wide range of big integer arithmetic operations.
template <size_t N>
struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
uint64_t limbs[N] = {
Expand Down Expand Up @@ -78,6 +80,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {

constexpr static BigInt One() { return BigInt(1); }

// Returns the maximum representable value for BigInt.
constexpr static BigInt Max() {
BigInt ret;
for (uint64_t& limb : ret.limbs) {
Expand All @@ -86,6 +89,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
return ret;
}

// Generate a random BigInt between [0, |max|).
constexpr static BigInt Random(const BigInt& max = Max()) {
BigInt ret;
for (size_t i = 0; i < N; ++i) {
Expand All @@ -98,18 +102,22 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
return ret;
}

// Convert a decimal string to a BigInt.
constexpr static BigInt FromDecString(std::string_view str) {
BigInt ret;
CHECK(internal::StringToLimbs(str, ret.limbs, N));
return ret;
}

// Convert a hexadecimal string to a BigInt.
constexpr static BigInt FromHexString(std::string_view str) {
BigInt ret;
CHECK(internal::HexStringToLimbs(str, ret.limbs, N));
return ret;
}

// Constructs a BigInt value from a given array of bits in little-endian
// order.
template <size_t BitNums = kBitNums>
constexpr static BigInt FromBitsLE(const std::bitset<BitNums>& bits) {
static_assert(BitNums <= kBitNums);
Expand All @@ -136,6 +144,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
return ret;
}

// Constructs a BigInt value from a given array of bits in big-endian order.
template <size_t BitNums = kBitNums>
constexpr static BigInt FromBitsBE(const std::bitset<BitNums>& bits) {
static_assert(BitNums <= kBitNums);
Expand All @@ -162,6 +171,8 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
return ret;
}

// Constructs a BigInt value from a given array of bytes in little-endian
// order.
constexpr static BigInt FromBytesLE(const std::vector<uint8_t>& bytes) {
BigInt ret;
size_t byte_idx = 0;
Expand All @@ -184,6 +195,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
return ret;
}

// Constructs a BigInt value from a given array of bytes in big-endian order.
constexpr static BigInt FromBytesBE(const std::vector<uint8_t>& bytes) {
BigInt ret;
size_t byte_idx = 0;
Expand Down Expand Up @@ -218,6 +230,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
return FromMontgomery(value, modulus, inverse);
}

// Extend the current |N| size BigInt to a larger |N2| size.
template <size_t N2>
constexpr BigInt<N2> Extend() const {
static_assert(N2 > N);
Expand All @@ -228,6 +241,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
return ret;
}

// Shrink the current |N| size BigInt to a smaller |N2| size.
template <size_t N2>
constexpr BigInt<N2> Shrink() const {
static_assert(N2 < N);
Expand All @@ -238,6 +252,10 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
return ret;
}

// Clamp the BigInt value with respect to a modulus.
// If the value is larger than or equal to the modulus, then the modulus is
// subtracted from the value. The function considers a spare bit in the
// modulus based on the template parameter.
template <bool ModulusHasSpareBit>
constexpr static void Clamp(const BigInt& modulus, BigInt* value,
[[maybe_unused]] bool carry = false) {
Expand Down Expand Up @@ -283,20 +301,26 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
constexpr bool IsEven() const { return limbs[kSmallestLimbIdx] % 2 == 0; }
constexpr bool IsOdd() const { return limbs[kSmallestLimbIdx] % 2 == 1; }

// Return the largest (most significant) limb of the BigInt.
constexpr uint64_t& biggest_limb() { return limbs[kBiggestLimbIdx]; }
constexpr const uint64_t& biggest_limb() const {
return limbs[kBiggestLimbIdx];
}

// Return the smallest (least significant) limb of the BigInt.
constexpr uint64_t& smallest_limb() { return limbs[kSmallestLimbIdx]; }
constexpr const uint64_t& smallest_limb() const {
return limbs[kSmallestLimbIdx];
}

// Extracts a specified number of bits starting from a given bit offset and
// returns them as a uint64_t.
constexpr uint64_t ExtractBits64(size_t bit_offset, size_t bit_count) const {
return ExtractBits<uint64_t>(bit_offset, bit_count);
}

// Extracts a specified number of bits starting from a given bit offset and
// returns them as a uint32_t.
constexpr uint32_t ExtractBits32(size_t bit_offset, size_t bit_count) const {
return ExtractBits<uint32_t>(bit_offset, bit_count);
}
Expand Down Expand Up @@ -627,6 +651,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
return internal::LimbsToHexString(limbs, N);
}

// Converts the BigInt to a bit array in little-endian.
template <size_t BitNums = kBitNums>
std::bitset<BitNums> ToBitsLE() const {
std::bitset<BitNums> ret;
Expand All @@ -641,6 +666,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
return ret;
}

// Converts the BigInt to a bit array in big-endian.
template <size_t BitNums = kBitNums>
std::bitset<BitNums> ToBitsBE() const {
std::bitset<BitNums> ret;
Expand All @@ -655,6 +681,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
return ret;
}

// Converts the BigInt to a byte array in little-endian.
std::vector<uint8_t> ToBytesLE() const {
std::vector<uint8_t> ret;
ret.reserve(kByteNums);
Expand All @@ -667,6 +694,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
return ret;
}

// Converts the BigInt to a byte array in big-endian.
std::vector<uint8_t> ToBytesBE() const {
std::vector<uint8_t> ret;
ret.reserve(kByteNums);
Expand Down Expand Up @@ -781,6 +809,14 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
return ret & mask;
}

// Montgomery arithmetic is a technique that allows modular arithmetic to be
// done more efficiently, by avoiding the need for explicit divisions.
// See https://en.wikipedia.org/wiki/Montgomery_modular_multiplication

// Converts a BigInt value from the Montgomery domain back to the standard
// domain. |FromMontgomery()| performs the Montgomery reduction algorithm to
// transform a value from the Montgomery domain back to its standard
// representation.
template <typename T>
constexpr static BigInt FromMontgomery(const BigInt<N>& value,
const BigInt<N>& modulus, T inverse) {
Expand Down Expand Up @@ -810,6 +846,28 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
return r;
}

// Performs Montgomery reduction on a doubled-sized BigInt, and populates
// |out| with the result.

// Inputs:
// - r: A BigInt representing a value (typically A x B) in Montgomery form.
// - modulus: The modulus M against which we're performing arithmetic.
// - inverse: The multiplicative inverse of the radix w.r.t. the modulus.

// Operation:
// 1. For each limb of r:
// - Compute a tmp = r(current limb) * inverse.
// This value aids in eliminating the lowest limb of r when multiplied by
// the modulus.
// - Incrementally add tmp * (modulus to r), effectively canceling out its
// current lowest limb.
//
// 2. After iterating over all limbs, the higher half of r is the
// Montgomery-reduced result of the original operation (like A x B). This
// result remains in the Montgomery domain.
//
// 3. Apply a final correction (if necessary) to ensure the result is less
// than |modulus|.
template <bool ModulusHasSpareBit, typename T>
constexpr static void MontgomeryReduce(BigInt<2 * N>& r,
const BigInt& modulus, T inverse,
Expand Down
9 changes: 9 additions & 0 deletions tachyon/math/base/field.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,18 @@

namespace tachyon::math {

// Field is any set of elements that satisfies the field axioms for both
// addition and multiplication and is commutative division algebra
// Simply put, a field is a ring in which multiplicative commutativity exists,
// and every non-zero element has a multiplicative inverse.
// See https://mathworld.wolfram.com/Field.html

// The Field supports SumOfProducts, inheriting the properties of both
// AdditiveGroup and MultiplicativeGroup.
template <typename F>
class Field : public AdditiveGroup<F>, public MultiplicativeGroup<F> {
public:
// Sum of products: a₁ * b₁ + a₂ * b₂ + ... + aₙ * bₙ
template <
typename InputIterator,
std::enable_if_t<std::is_same_v<F, base::iter_value_t<InputIterator>>>* =
Expand Down
Loading

0 comments on commit 517b24a

Please sign in to comment.