From d22654cfdb2f7092ea95f4bdf7c2c0caa49ceb83 Mon Sep 17 00:00:00 2001 From: "Wonyong Kim(Ryan Kim)" Date: Fri, 22 Sep 2023 21:38:35 +0900 Subject: [PATCH] feat: implement poseidon Implement Poseidon, ZK friendly Sponge construction hash. It consists of three steps: - Absorbing - Permutation - Squeezing See: https://eprint.iacr.org/2019/458.pdf --- tachyon/crypto/hashes/sponge/BUILD.bazel | 13 + .../crypto/hashes/sponge/poseidon/BUILD.bazel | 12 + .../crypto/hashes/sponge/poseidon/poseidon.h | 266 ++++++++++++++++++ .../sponge/poseidon/poseidon_unittest.cc | 43 +++ tachyon/crypto/hashes/sponge/sponge.h | 172 +++++++++++ tachyon/math/base/groups.h | 14 +- tachyon/math/base/semigroups.h | 10 +- tachyon/math/finite_fields/prime_field_base.h | 2 + 8 files changed, 528 insertions(+), 4 deletions(-) create mode 100644 tachyon/crypto/hashes/sponge/BUILD.bazel create mode 100644 tachyon/crypto/hashes/sponge/poseidon/poseidon.h create mode 100644 tachyon/crypto/hashes/sponge/poseidon/poseidon_unittest.cc create mode 100644 tachyon/crypto/hashes/sponge/sponge.h diff --git a/tachyon/crypto/hashes/sponge/BUILD.bazel b/tachyon/crypto/hashes/sponge/BUILD.bazel new file mode 100644 index 0000000000..963d710f69 --- /dev/null +++ b/tachyon/crypto/hashes/sponge/BUILD.bazel @@ -0,0 +1,13 @@ +load("//bazel:tachyon_cc.bzl", "tachyon_cc_library") + +package(default_visibility = ["//visibility:public"]) + +tachyon_cc_library( + name = "sponge", + hdrs = ["sponge.h"], + deps = [ + "//tachyon/base:logging", + "//tachyon/base/containers:container_util", + "//tachyon/math/finite_fields:prime_field_traits", + ], +) diff --git a/tachyon/crypto/hashes/sponge/poseidon/BUILD.bazel b/tachyon/crypto/hashes/sponge/poseidon/BUILD.bazel index 16970fb14e..5195c4fa28 100644 --- a/tachyon/crypto/hashes/sponge/poseidon/BUILD.bazel +++ b/tachyon/crypto/hashes/sponge/poseidon/BUILD.bazel @@ -2,6 +2,16 @@ load("//bazel:tachyon_cc.bzl", "tachyon_cc_library", "tachyon_cc_test") package(default_visibility = ["//visibility:public"]) +tachyon_cc_library( + name = "poseidon", + hdrs = ["poseidon.h"], + deps = [ + ":poseidon_config", + "//tachyon/crypto/hashes:prime_field_serializable", + "//tachyon/crypto/hashes/sponge", + ], +) + tachyon_cc_library( name = "poseidon_config", hdrs = ["poseidon_config.h"], @@ -28,8 +38,10 @@ tachyon_cc_test( srcs = [ "grain_lfsr_unittest.cc", "poseidon_config_unittest.cc", + "poseidon_unittest.cc", ], deps = [ + ":poseidon", ":poseidon_config", "//tachyon/math/elliptic_curves/bls/bls12_381:fr", ], diff --git a/tachyon/crypto/hashes/sponge/poseidon/poseidon.h b/tachyon/crypto/hashes/sponge/poseidon/poseidon.h new file mode 100644 index 0000000000..14cd14f178 --- /dev/null +++ b/tachyon/crypto/hashes/sponge/poseidon/poseidon.h @@ -0,0 +1,266 @@ +// Copyright 2022 arkworks contributors +// Use of this source code is governed by a MIT/Apache-2.0 style license that +// can be found in the LICENSE-MIT.arkworks and the LICENCE-APACHE.arkworks +// file. + +#ifndef TACHYON_CRYPTO_HASHES_SPONGE_POSEIDON_POSEIDON_H_ +#define TACHYON_CRYPTO_HASHES_SPONGE_POSEIDON_POSEIDON_H_ + +#include "tachyon/base/containers/container_util.h" +#include "tachyon/base/logging.h" +#include "tachyon/crypto/hashes/prime_field_serializable.h" +#include "tachyon/crypto/hashes/sponge/poseidon/poseidon_config.h" +#include "tachyon/crypto/hashes/sponge/sponge.h" + +namespace tachyon::crypto { + +// A duplex sponge based using the Poseidon permutation. +// This implementation of Poseidon is entirely Fractal's implementation in +// [COS20][cos] with small syntax changes. See https://eprint.iacr.org/2019/1076 +template +struct PoseidonSponge + : public FieldBasedCryptographicSponge> { + using F = PrimeFieldTy; + + struct State { + // Current sponge's state (current elements in the permutation block) + math::Vector elements; + + // Current mode (whether its absorbing or squeezing) + DuplexSpongeMode mode = DuplexSpongeMode::Absorbing(); + + State() = default; + explicit State(size_t size) : elements(size) { + for (size_t i = 0; i < size; ++i) { + elements[i] = F::Zero(); + } + } + + size_t size() const { return elements.size(); } + + F& operator[](size_t idx) { return elements[idx]; } + const F& operator[](size_t idx) const { return elements[idx]; } + }; + + // Sponge Config + PoseidonConfig config; + + // Sponge State + State state; + + explicit PoseidonSponge(const PoseidonConfig& config) + : config(config), state(config.rate + config.capacity) {} + PoseidonSponge(const PoseidonConfig& config, const State& state) + : config(config), state(state) {} + PoseidonSponge(const PoseidonConfig& config, State&& state) + : config(config), state(std::move(state)) {} + + void ApplySBox(bool is_full_round) { + if (is_full_round) { + // Full rounds apply the S-Box (x^alpha) to every element of |state|. + for (F& elem : state.elements) { + elem = elem.Pow(math::BigInt<1>(config.alpha)); + } + } else { + // Partial rounds apply the S-Box (x^alpha) to just the first element of + // |state|. + state[0] = state[0].Pow(math::BigInt<1>(config.alpha)); + } + } + + void ApplyARK(Eigen::Index round_number) { + state.elements += config.ark.row(round_number); + } + + void ApplyMDS() { state.elements = config.mds * state.elements; } + + void Permute() { + size_t full_rounds_over_2 = config.full_rounds / 2; + for (size_t i = 0; i < full_rounds_over_2; ++i) { + ApplyARK(i); + ApplySBox(true); + ApplyMDS(); + } + for (size_t i = full_rounds_over_2; + i < full_rounds_over_2 + config.partial_rounds; ++i) { + ApplyARK(i); + ApplySBox(false); + ApplyMDS(); + } + for (size_t i = full_rounds_over_2 + config.partial_rounds; + i < config.partial_rounds + config.full_rounds; ++i) { + ApplyARK(i); + ApplySBox(true); + ApplyMDS(); + } + } + + // Absorbs everything in |elements|, this does not end in an absorbing. + void AbsorbInternal(size_t rate_start_index, const std::vector& elements) { + size_t elements_idx = 0; + while (true) { + size_t remaining_size = elements.size() - elements_idx; + // if we can finish in this call + if (rate_start_index + remaining_size <= config.rate) { + for (size_t i = 0; i < remaining_size; ++i, ++elements_idx) { + state[config.capacity + i + rate_start_index] += + elements[elements_idx]; + } + state.mode.type = DuplexSpongeMode::Type::kAbsorbing; + state.mode.next_index = rate_start_index + remaining_size; + break; + } + // otherwise absorb (|config.rate| - |rate_start_index|) elements + size_t num_elements_absorbed = config.rate - rate_start_index; + for (size_t i = 0; i < num_elements_absorbed; ++i, ++elements_idx) { + state[config.capacity + i + rate_start_index] += elements[elements_idx]; + } + Permute(); + rate_start_index = 0; + } + } + + // Squeeze |output| many elements. This does not end in a squeezing. + void SqueezeInternal(size_t rate_start_index, std::vector* output) { + size_t output_size = output->size(); + size_t output_idx = 0; + while (true) { + size_t output_remaining_size = output_size - output_idx; + // if we can finish in this call + if (rate_start_index + output_remaining_size <= config.rate) { + for (size_t i = 0; i < output_remaining_size; ++i) { + (*output)[output_idx + i] = + state[config.capacity + rate_start_index + i]; + } + state.mode.type = DuplexSpongeMode::Type::kSqueezing; + state.mode.next_index = rate_start_index + output_remaining_size; + return; + } + + // otherwise squeeze (|config.rate| - |rate_start_index|) elements + size_t num_elements_squeezed = config.rate - rate_start_index; + for (size_t i = 0; i < num_elements_squeezed; ++i) { + (*output)[output_idx + i] = + state[config.capacity + rate_start_index + i]; + } + + if (output_remaining_size != config.rate) { + Permute(); + } + output_idx += num_elements_squeezed; + rate_start_index = 0; + } + } + + // CryptographicSponge methods + template + bool Absorb(const T& input) { + std::vector elements; + if (!SerializeToFieldElements(input, &elements)) return false; + + switch (state.mode.type) { + case DuplexSpongeMode::Type::kAbsorbing: { + size_t absorb_index = state.mode.next_index; + if (absorb_index == config.rate) { + Permute(); + absorb_index = 0; + } + AbsorbInternal(absorb_index, elements); + return true; + } + case DuplexSpongeMode::Type::kSqueezing: { + Permute(); + AbsorbInternal(0, elements); + return true; + } + } + NOTREACHED(); + return false; + } + + std::vector SqueezeBytes(size_t num_bytes) { + size_t usable_bytes = (F::kModulusBits - 1) / 8; + + size_t num_elements = (num_bytes + usable_bytes - 1) / usable_bytes; + std::vector src_elements = SqueezeNativeFieldElements(num_elements); + + std::vector bytes; + bytes.reserve(usable_bytes * num_elements); + for (const F& elem : src_elements) { + std::vector elem_bytes = elem.ToBigInt().ToBytesLE(); + bytes.insert(bytes.end(), elem_bytes.begin(), elem_bytes.end()); + } + + bytes.resize(num_bytes); + return bytes; + } + + std::vector SqueezeBits(size_t num_bits) { + size_t usable_bits = F::kModulusBits - 1; + + size_t num_elements = (num_bits + usable_bits - 1) / usable_bits; + std::vector src_elements = SqueezeNativeFieldElements(num_elements); + + std::vector bits; + for (const F& elem : src_elements) { + std::bitset elem_bits = + elem.ToBigInt().template ToBitsLE(); + bits.insert(bits.end(), elem_bits.begin(), elem_bits.end()); + } + bits.resize(num_bits); + return bits; + } + + template + std::vector SqueezeFieldElementsWithSizes( + const std::vector& sizes) { + if constexpr (F::Characteristic() == F2::Characteristic()) { + // native case + return this->SqueezeNativeFieldElementsWithSizes(sizes); + } + return this->template SqueezeFieldElementsWithSizesDefaultImpl(sizes); + } + + template + std::vector SqueezeFieldElements(size_t num_elements) { + if constexpr (std::is_same_v) { + return SqueezeNativeFieldElements(num_elements); + } else { + return SqueezeFieldElementsWithSizes(base::CreateVector( + num_elements, []() { return FieldElementSize::Full(); })); + } + } + + // FieldBasedCryptographicSponge methods + std::vector SqueezeNativeFieldElements(size_t num_elements) { + std::vector ret = + base::CreateVector(num_elements, []() { return F::Zero(); }); + switch (state.mode.type) { + case DuplexSpongeMode::Type::kAbsorbing: { + Permute(); + SqueezeInternal(0, &ret); + return ret; + } + case DuplexSpongeMode::Type::kSqueezing: { + size_t squeeze_index = state.mode.next_index; + if (squeeze_index == config.rate) { + Permute(); + squeeze_index = 0; + } + SqueezeInternal(squeeze_index, &ret); + return ret; + } + } + NOTREACHED(); + return {}; + } +}; + +template +struct CryptographicSpongeTraits> { + using F = PrimeFieldTy; +}; + +} // namespace tachyon::crypto + +#endif // TACHYON_CRYPTO_HASHES_SPONGE_POSEIDON_POSEIDON_H_ diff --git a/tachyon/crypto/hashes/sponge/poseidon/poseidon_unittest.cc b/tachyon/crypto/hashes/sponge/poseidon/poseidon_unittest.cc new file mode 100644 index 0000000000..7e00805909 --- /dev/null +++ b/tachyon/crypto/hashes/sponge/poseidon/poseidon_unittest.cc @@ -0,0 +1,43 @@ +// Copyright 2022 arkworks contributors +// Use of this source code is governed by a MIT/Apache-2.0 style license that +// can be found in the LICENSE-MIT.arkworks and the LICENCE-APACHE.arkworks +// file. + +#include "tachyon/crypto/hashes/sponge/poseidon/poseidon.h" + +#include "gtest/gtest.h" + +#include "tachyon/math/elliptic_curves/bls/bls12_381/fr.h" +#include "tachyon/math/finite_fields/prime_field_forward.h" + +namespace tachyon::crypto { + +namespace { + +class PoseidonTest : public testing::Test { + public: + static void SetUpTestSuite() { math::bls12_381::Fr::Init(); } +}; + +} // namespace + +TEST_F(PoseidonTest, AbsorbSqueeze) { + using Fr = math::bls12_381::Fr; + + PoseidonConfig config = PoseidonConfig::CreateDefault(2, false); + PoseidonSponge sponge(config); + std::vector inputs = {Fr(0), Fr(1), Fr(2)}; + ASSERT_TRUE(sponge.Absorb(inputs)); + std::vector result = sponge.SqueezeNativeFieldElements(3); + std::vector expected = { + Fr::FromDecString("404427934635713040283377530022421867103101638970489622" + "78675457993207843616876"), + Fr::FromDecString("266437446169989800029115314522409928771122402171620296" + "0480903840045233645301"), + Fr::FromDecString("501910788280669236620702282565306929518015040434228440" + "38937334196346054068797"), + }; + EXPECT_EQ(result, expected); +} + +} // namespace tachyon::crypto diff --git a/tachyon/crypto/hashes/sponge/sponge.h b/tachyon/crypto/hashes/sponge/sponge.h new file mode 100644 index 0000000000..238b04606c --- /dev/null +++ b/tachyon/crypto/hashes/sponge/sponge.h @@ -0,0 +1,172 @@ +// Copyright 2022 arkworks contributors +// Use of this source code is governed by a MIT/Apache-2.0 style license that +// can be found in the LICENSE-MIT.arkworks and the LICENCE-APACHE.arkworks +// file. + +#ifndef TACHYON_CRYPTO_HASHES_SPONGE_SPONGE_H_ +#define TACHYON_CRYPTO_HASHES_SPONGE_SPONGE_H_ + +#include +#include + +#include "tachyon/base/containers/container_util.h" +#include "tachyon/base/logging.h" +#include "tachyon/math/finite_fields/prime_field_traits.h" + +namespace tachyon::crypto { + +// Specifying the output field element size. +class TACHYON_EXPORT FieldElementSize { + public: + static FieldElementSize Full() { return {false}; } + static FieldElementSize Truncated(size_t num_bits) { + return {true, num_bits}; + } + + template + size_t NumBits() { + static_assert(math::PrimeFieldTraits::kIsPrimeField, + "NumBits() is only supported for PrimeField"); + if (is_truncated_) { + CHECK_LE(num_bits_, PrimeFieldTy::kModulusBits) + << "num_bits is greater than the bit size of the field."; + return num_bits_; + } + return PrimeFieldTy::kModulusBits - 1; + } + + // Calculate the sum of prime field element sizes in |elements|. + template + static size_t Sum(const std::vector& elements) { + static_assert(math::PrimeFieldTraits::kIsPrimeField, + "Sum() is only supported for PrimeField"); + return (PrimeFieldTy::kModulusBits - 1) * elements.size(); + } + + bool IsFull() const { return !is_truncated_; } + bool IsTruncated() const { return is_truncated_; } + + private: + FieldElementSize(bool is_truncated, size_t num_bits = 0) + : is_truncated_(is_truncated), num_bits_(num_bits) {} + + // If |is_truncated_| is false, sample field elements from the entire field. + // If |is_truncated_| is true, sample field elements from a subset of the + // field, specified by the maximum number of bit. + bool is_truncated_; + size_t num_bits_ = 0; +}; + +template +struct CryptographicSpongeTraits; + +// The interface for a cryptographic sponge. +// A sponge can |Absorb| and later |Squeeze| bytes of field elements. +// The outputs are dependent on previous |Absorb| and |Squeeze| calls. +template +class CryptographicSponge { + public: + using F = typename CryptographicSpongeTraits::F; + + // Squeeze |num_elements| nonnative field elements from the sponge. + std::vector SqueezeFieldElements(size_t num_elements) { + Derived* derived = static_cast(this); + return derived->SqueezeFieldElementsWithSizes(base::CreateVector( + num_elements, []() { return FieldElementSize::Full(); })); + } + + // Creates a new sponge with applied domain separation. + Derived Fork(const absl::Span& domain) const { + const Derived* derived = static_cast(this); + CHECK(derived->Absorb(domain)); + return *derived; + } + + protected: + std::vector SqueezeFieldElementsWithSizesDefaultImpl( + const std::vector& sizes) { + if constexpr (math::PrimeFieldTraits::kIsPrimeField) { + if (sizes.empty()) { + return {}; + } + + size_t total_num_bits = FieldElementSize::Sum(sizes); + + Derived derived = static_cast(this); + std::vector bits = derived->SqueezeBits(total_num_bits); + auto bits_window = bits.begin(); + + std::vector output; + output.reserve(sizes.size()); + for (const FieldElementSize& size : sizes) { + size_t num_bits = size.NumBits(); + + std::bitset field_element_bits; + for (size_t i = 0; i < num_bits; ++i) { + field_element_bits[i] = *(bits_window + i); + } + bits_window += num_bits; + + output.push_back(F::FromBitsLE(field_element_bits)); + } + return output; + } else { + NOTIMPLEMENTED(); + return {}; + } + } +}; + +// The interface for field-based cryptographic sponge. +template +class FieldBasedCryptographicSponge : public CryptographicSponge { + public: + using NativeField = typename CryptographicSpongeTraits::F; + + // Squeeze |sizes.size()| field elements from the sponge. + // where the |i|-th element of the output has |sizes[i]|. + std::vector SqueezeNativeFieldElementsWithSizes( + const std::vector& sizes) { + bool all_full_size = + std::all_of(sizes.begin(), sizes.end(), + [](const FieldElementSize& size) { return size.IsFull(); }); + Derived* derived = static_cast(this); + if (all_full_size) { + return derived->SqueezeNativeFieldElements(sizes.size()); + } else { + return derived->SqueezeFieldElementsWithSizesDefaultImpl(sizes); + } + } +}; + +// The mode structure for duplex sponge. +struct TACHYON_EXPORT DuplexSpongeMode { + enum class Type { + // The sponge is currently absorbing data. + kAbsorbing, + // The sponge is currently squeezing data out. + kSqueezing, + }; + + constexpr static DuplexSpongeMode Absorbing(size_t next_index = 0) { + return {Type::kAbsorbing, next_index}; + } + constexpr static DuplexSpongeMode Squeezing(size_t next_index = 0) { + return {Type::kSqueezing, next_index}; + } + + Type type; + // When |type| is |kAbsorbing|, it is interpreted as next position of the + // state to be XOR-ed when absorbing. + // When |type| is |kSqueezing|, it is interpreted as next position of the + // state to be outputted when squeezing. + size_t next_index; + + private: + constexpr DuplexSpongeMode(Type type, size_t next_index) + : type(type), next_index(next_index) {} +}; + +} // namespace tachyon::crypto + +#endif // TACHYON_CRYPTO_HASHES_SPONGE_SPONGE_H_ diff --git a/tachyon/math/base/groups.h b/tachyon/math/base/groups.h index c12ad83e16..388fd04378 100644 --- a/tachyon/math/base/groups.h +++ b/tachyon/math/base/groups.h @@ -15,7 +15,12 @@ SUPPORTS_BINARY_OPERATOR(Mod); template class MultiplicativeGroup : public MultiplicativeSemigroup { public: - template + template < + typename G2, + std::enable_if_t::value || + internal::SupportsMulInPlace::value || + internal::SupportsDiv::value || + internal::SupportsDivInPlace::value>* = nullptr> constexpr auto operator/(const G2& other) const { if constexpr (internal::SupportsDiv::value) { const G* g = static_cast(this); @@ -54,7 +59,12 @@ class MultiplicativeGroup : public MultiplicativeSemigroup { template class AdditiveGroup : public AdditiveSemigroup { public: - template + template < + typename G2, + std::enable_if_t::value || + internal::SupportsAddInPlace::value || + internal::SupportsSub::value || + internal::SupportsSubInPlace::value>* = nullptr> constexpr auto operator-(const G2& other) const { if constexpr (internal::SupportsSub::value) { const G* g = static_cast(this); diff --git a/tachyon/math/base/semigroups.h b/tachyon/math/base/semigroups.h index 3f270509e2..d088a8d8da 100644 --- a/tachyon/math/base/semigroups.h +++ b/tachyon/math/base/semigroups.h @@ -57,7 +57,10 @@ struct AdditiveSemigroupTraits { template class MultiplicativeSemigroup { public: - template + template < + typename G2, + std::enable_if_t::value || + internal::SupportsMulInPlace::value>* = nullptr> constexpr auto operator*(const G2& other) const { if constexpr (internal::SupportsMul::value) { const G* g = static_cast(this); @@ -134,7 +137,10 @@ class MultiplicativeSemigroup { template class AdditiveSemigroup { public: - template + template < + typename G2, + std::enable_if_t::value || + internal::SupportsAddInPlace::value>* = nullptr> constexpr auto operator+(const G2& other) const { if constexpr (internal::SupportsAdd::value) { const G* g = static_cast(this); diff --git a/tachyon/math/finite_fields/prime_field_base.h b/tachyon/math/finite_fields/prime_field_base.h index 7306c92124..1af1a2307b 100644 --- a/tachyon/math/finite_fields/prime_field_base.h +++ b/tachyon/math/finite_fields/prime_field_base.h @@ -18,6 +18,8 @@ class PrimeFieldBase : public Field { public: using Config = typename PrimeFieldTraits::Config; + static std::string Characteristic() { return Config::kModulus.ToString(); } + // Returns false for either of the following cases: // // When there exists |Config::kLargeSubgroupRootOfUnity|: