Skip to content

Commit

Permalink
Merge pull request #199 from kroma-network/refac/create-zero-or-rando…
Browse files Browse the repository at this point in the history
…m-poly-from-domain

refac: create zero or random poly from domain
  • Loading branch information
fakedev9999 authored Dec 13, 2023
2 parents 97034d6 + d53b1de commit 02d6341
Show file tree
Hide file tree
Showing 22 changed files with 106 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ class UnivariatePolynomialCommitmentScheme
using Evals = math::UnivariateEvaluations<Field, kMaxDegree>;
using Domain = math::UnivariateEvaluationDomain<Field, kMaxDegree>;

size_t D() const {
const Derived* derived = static_cast<const Derived*>(this);
return derived->N() - 1;
}

// Commit to |poly| and populates |result| with the commitment.
// Return false if the degree of |poly| exceeds |kMaxDegree|.
[[nodiscard]] bool Commit(const Poly& poly, Commitment* result) const {
Expand Down
10 changes: 10 additions & 0 deletions tachyon/math/polynomials/univariate/univariate_evaluation_domain.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,16 @@ class UnivariateEvaluationDomain : public EvaluationDomain<F, MaxDegree> {
return t;
}

template <typename T>
constexpr T Empty() const {
return T::UnsafeZero(size_ - 1);
}

template <typename T>
constexpr T Random() const {
return T::Random(size_ - 1);
}

// Compute a FFT.
[[nodiscard]] constexpr virtual Evals FFT(const DensePoly& poly) const = 0;

Expand Down
18 changes: 9 additions & 9 deletions tachyon/math/polynomials/univariate/univariate_polynomial.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ class UnivariatePolynomial final
return UnivariatePolynomial(Coefficients::Zero());
}

// NOTE(chokobole): This doesn't call |RemoveHighDegreeZeros()| internally.
// So when the returned evaluations is called with |IsZero()|, it returns
// false. So please use it carefully!
constexpr static UnivariatePolynomial UnsafeZero(size_t degree) {
UnivariatePolynomial ret;
ret.coefficients_ = Coefficients::UnsafeZero(degree);
return ret;
}

constexpr static UnivariatePolynomial One() {
return UnivariatePolynomial(Coefficients::One());
}
Expand Down Expand Up @@ -269,15 +278,6 @@ class UnivariatePolynomial final
friend class Radix2EvaluationDomain<Field, kMaxDegree>;
friend class MixedRadixEvaluationDomain<Field, kMaxDegree>;

// NOTE(chokobole): This doesn't call |RemoveHighDegreeZeros()| internally.
// So when the returned evaluations is called with |IsZero()|, it returns
// false. So please use it carefully!
constexpr static UnivariatePolynomial UnsafeZero(size_t degree) {
UnivariatePolynomial ret;
ret.coefficients_ = Coefficients::UnsafeZero(degree);
return ret;
}

Coefficients coefficients_;
};

Expand Down
2 changes: 2 additions & 0 deletions tachyon/zk/base/commitments/shplonk_extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class SHPlonkExtension

size_t N() const { return shplonk_.N(); }

size_t D() const { return N() - 1; }

[[nodiscard]] bool DoUnsafeSetup(size_t size) {
return shplonk_.DoUnsafeSetup(size);
}
Expand Down
3 changes: 2 additions & 1 deletion tachyon/zk/base/entities/prover_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ class ProverTest : public Halo2ProverTest {};
} // namespace

TEST_F(ProverTest, CommitEvalsWithBlind) {
const Domain* domain = prover_->domain();
// setting random polynomial
Evals evals = Evals::Random(prover_->pcs().N() - 1);
Evals evals = domain->Random<Evals>();

// setting struct to get output
BlindedPolynomial<Poly> out;
Expand Down
10 changes: 5 additions & 5 deletions tachyon/zk/lookup/compress_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@

namespace tachyon::zk {

template <typename Evals, typename F = typename Evals::Field>
template <typename Domain, typename Evals, typename F>
bool CompressExpressions(
const Domain* domain,
const std::vector<std::unique_ptr<Expression<F>>>& expressions,
size_t domain_size, const F& theta,
const SimpleEvaluator<Evals>& evaluator_tpl, Evals* out) {
Evals compressed_value = Evals::UnsafeZero(domain_size - 1);
Evals values = Evals::UnsafeZero(domain_size - 1);
const F& theta, const SimpleEvaluator<Evals>& evaluator_tpl, Evals* out) {
Evals compressed_value = domain->template Empty<Evals>();
Evals values = domain->template Empty<Evals>();

for (size_t expr_idx = 0; expr_idx < expressions.size(); ++expr_idx) {
base::Parallelize(
Expand Down
3 changes: 2 additions & 1 deletion tachyon/zk/lookup/compress_expression_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ TEST_F(CompressExpressionTest, CompressExpressions) {
}

Evals out;
ASSERT_TRUE(CompressExpressions(expressions, n, theta_, evaluator_, &out));
ASSERT_TRUE(CompressExpressions(prover_->domain(), expressions, theta_,
evaluator_, &out));
EXPECT_EQ(out, Evals(std::move(expected)));
}

Expand Down
8 changes: 4 additions & 4 deletions tachyon/zk/lookup/lookup_argument_runner_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ LookupPermuted<Poly, Evals> LookupArgumentRunner<Poly, Evals>::PermuteArgument(
const SimpleEvaluator<Evals>& evaluator_tpl) {
// A_compressed(X) = θᵐ⁻¹A₀(X) + θᵐ⁻²A₁(X) + ... + θAₘ₋₂(X) + Aₘ₋₁(X)
Evals compressed_input_expression;
CHECK(CompressExpressions(argument.input_expressions(),
prover->domain()->size(), theta, evaluator_tpl,
CHECK(CompressExpressions(prover->domain(), argument.input_expressions(),
theta, evaluator_tpl,
&compressed_input_expression));

// S_compressed(X) = θᵐ⁻¹S₀(X) + θᵐ⁻²S₁(X) + ... + θSₘ₋₂(X) + Sₘ₋₁(X)
Evals compressed_table_expression;
CHECK(CompressExpressions(argument.table_expressions(),
prover->domain()->size(), theta, evaluator_tpl,
CHECK(CompressExpressions(prover->domain(), argument.table_expressions(),
theta, evaluator_tpl,
&compressed_table_expression));

// Permute compressed (InputExpression, TableExpression) pair.
Expand Down
8 changes: 4 additions & 4 deletions tachyon/zk/lookup/lookup_argument_runner_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ TEST_F(LookupArgumentRunnerTest, ComputePermutationProduct) {
}

Evals compressed_input_expression;
ASSERT_TRUE(CompressExpressions(input_expressions, n, theta_, evaluator_,
&compressed_input_expression));
ASSERT_TRUE(CompressExpressions(prover_->domain(), input_expressions, theta_,
evaluator_, &compressed_input_expression));
Evals compressed_table_expression;
ASSERT_TRUE(CompressExpressions(table_expressions, n, theta_, evaluator_,
&compressed_table_expression));
ASSERT_TRUE(CompressExpressions(prover_->domain(), table_expressions, theta_,
evaluator_, &compressed_table_expression));

LookupPair<Evals> compressed_evals_pair(
std::move(compressed_input_expression),
Expand Down
6 changes: 3 additions & 3 deletions tachyon/zk/plonk/circuit/examples/simple_circuit_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,13 @@ TEST_F(SimpleCircuitTest, Synthesize) {
size_t n = 16;
CHECK(prover_->pcs().UnsafeSetup(n, F(2)));
prover_->set_domain(Domain::Create(n));
const Domain* domain = prover_->domain();

ConstraintSystem<F> constraint_system;
FieldConfig<math::bn254::Fr> config =
SimpleCircuit<math::bn254::Fr>::Configure(constraint_system);
Assembly<PCS> assembly =
VerifyingKey<PCS>::CreateAssembly(prover_->pcs(), constraint_system);
VerifyingKey<PCS>::CreateAssembly(domain, constraint_system);

F constant(7);
F a(2);
Expand All @@ -148,9 +149,8 @@ TEST_F(SimpleCircuitTest, Synthesize) {
SimpleCircuit<math::bn254::Fr>::FloorPlanner::Synthesize(
&assembly, circuit, std::move(config), constraint_system.constants());

EXPECT_EQ(assembly.k(), 4);
std::vector<RationalEvals> expected_fixed_columns;
RationalEvals evals = RationalEvals::UnsafeZero(n - 1);
RationalEvals evals = domain->Empty<RationalEvals>();
*evals[0] = math::RationalField<F>(constant);
expected_fixed_columns.push_back(std::move(evals));
EXPECT_EQ(assembly.fixed_columns(), expected_fixed_columns);
Expand Down
7 changes: 2 additions & 5 deletions tachyon/zk/plonk/keys/assembly.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,15 @@ class Assembly : public Assignment<typename PCSTy::Field> {
using AssignCallback = typename Assignment<F>::AssignCallback;

Assembly() = default;
Assembly(uint32_t k, std::vector<RationalEvals>&& fixed_columns,
Assembly(std::vector<RationalEvals>&& fixed_columns,
PermutationAssembly<PCSTy>&& permutation,
std::vector<std::vector<bool>>&& selectors,
base::Range<size_t> usable_rows)
: k_(k),
fixed_columns_(std::move(fixed_columns)),
: fixed_columns_(std::move(fixed_columns)),
permutation_(std::move(permutation)),
selectors_(std::move(selectors)),
usable_rows_(usable_rows) {}

uint32_t k() const { return k_; }
const std::vector<RationalEvals>& fixed_columns() const {
return fixed_columns_;
}
Expand Down Expand Up @@ -80,7 +78,6 @@ class Assembly : public Assignment<typename PCSTy::Field> {
}

private:
uint32_t k_ = 0;
std::vector<RationalEvals> fixed_columns_;
PermutationAssembly<PCSTy> permutation_;
std::vector<std::vector<bool>> selectors_;
Expand Down
15 changes: 8 additions & 7 deletions tachyon/zk/plonk/keys/key.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,23 @@ template <typename PCSTy>
class Key {
public:
using F = typename PCSTy::Field;
using Domain = typename PCSTy::Domain;
using Evals = typename PCSTy::Evals;

static Assembly<PCSTy> CreateAssembly(
const PCSTy& pcs, const ConstraintSystem<F>& constraint_system) {
const Domain* domain, const ConstraintSystem<F>& constraint_system) {
using RationalEvals = typename Assembly<PCSTy>::RationalEvals;
size_t n = domain->size();
return {
static_cast<uint32_t>(pcs.K()),
base::CreateVector(constraint_system.num_fixed_columns(),
RationalEvals::UnsafeZero(pcs.N() - 1)),
PermutationAssembly<PCSTy>(constraint_system.permutation(), pcs.N()),
domain->template Empty<RationalEvals>()),
PermutationAssembly<PCSTy>(constraint_system.permutation(), n),
base::CreateVector(constraint_system.num_selectors(),
base::CreateVector(pcs.N(), false)),
base::CreateVector(n, false)),
// NOTE(chokobole): Considering that this is called from a verifier,
// then you can't load this number through |prover->GetUsableRows()|.
base::Range<size_t>::Until(
pcs.N() - (constraint_system.ComputeBlindingFactors() + 1))};
n - (constraint_system.ComputeBlindingFactors() + 1))};
}

protected:
Expand Down Expand Up @@ -67,7 +68,7 @@ class Key {
entity->set_extended_domain(
ExtendedDomain::Create(size_t{1} << extended_k));

result->assembly = CreateAssembly(pcs, constraint_system);
result->assembly = CreateAssembly(entity->domain(), constraint_system);
Assembly<PCSTy>& assembly = result->assembly;
FloorPlanner::Synthesize(&assembly, circuit, std::move(config),
constraint_system.constants());
Expand Down
6 changes: 2 additions & 4 deletions tachyon/zk/plonk/keys/proving_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ class ProvingKey : public Key<PCSTy> {
permutations = std::move(vk_load_result->permutations);
} else {
permutations =
pre_load_result.assembly.permutation().GeneratePermutations(
prover->domain());
pre_load_result.assembly.permutation().GeneratePermutations(domain);
}

permutation_proving_key_ =
Expand All @@ -104,8 +103,7 @@ class ProvingKey : public Key<PCSTy> {
// | 5 | 0 |
// | 6 | 0 |
// | 7 | 0 |
const PCSTy& pcs = prover->pcs();
Evals evals = Evals::UnsafeZero(pcs.N() - 1);
Evals evals = domain->template Empty<Evals>();
*evals[0] = F::One();
l_first_ = domain->IFFT(evals);
*evals[0] = F::Zero();
Expand Down
20 changes: 11 additions & 9 deletions tachyon/zk/plonk/permutation/permutation_argument_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ class PermutationArgumentTest : public Halo2ProverTest {
void SetUp() override {
Halo2ProverTest::SetUp();

const size_t n = prover_->pcs().N();
const size_t d = n - 1;
const Domain* domain = prover_->domain();

Evals cycled_column = Evals::Random(d);
fixed_columns_ = {cycled_column, Evals::Random(d), Evals::Random(d)};
advice_columns_ = {Evals::Random(d), cycled_column, Evals::Random(d)};
instance_columns_ = {cycled_column, Evals::Random(d), Evals::Random(d)};
Evals cycled_column = domain->Random<Evals>();
fixed_columns_ = {cycled_column, domain->Random<Evals>(),
domain->Random<Evals>()};
advice_columns_ = {domain->Random<Evals>(), cycled_column,
domain->Random<Evals>()};
instance_columns_ = {cycled_column, domain->Random<Evals>(),
domain->Random<Evals>()};

table_ = Table<Evals>(absl::MakeConstSpan(fixed_columns_),
absl::MakeConstSpan(advice_columns_),
Expand All @@ -40,8 +42,8 @@ class PermutationArgumentTest : public Halo2ProverTest {
};
argument_ = PermutationArgument(column_keys_);

unpermuted_table_ = UnpermutedTable<Evals>::Construct(column_keys_.size(),
n, prover_->domain());
unpermuted_table_ = UnpermutedTable<Evals>::Construct(
column_keys_.size(), prover_->pcs().N(), prover_->domain());
}

protected:
Expand Down Expand Up @@ -92,7 +94,7 @@ TEST_F(PermutationArgumentTest, Commit) {
F gamma = F::Random();

PermutationArgumentRunner<Poly, Evals>::CommitArgument(
prover_.get(), argument_, table_, prover_->pcs().N(), pk, beta, gamma);
prover_.get(), argument_, table_, n, pk, beta, gamma);
}

} // namespace tachyon::zk
3 changes: 2 additions & 1 deletion tachyon/zk/plonk/permutation/permutation_assembly.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,13 @@ class PermutationAssembly {
// permutations. Note that the permutation polynomials are in evaluation
// form.
std::vector<Evals> GeneratePermutations(const Domain* domain) const {
CHECK_EQ(domain->size(), rows_);
UnpermutedTable<Evals> unpermuted_table =
UnpermutedTable<Evals>::Construct(columns_.size(), rows_, domain);

// Init evaluation formed polynomials with all-zero coefficients.
std::vector<Evals> permutations =
base::CreateVector(columns_.size(), Evals::UnsafeZero(rows_ - 1));
base::CreateVector(columns_.size(), domain->template Empty<Evals>());

// Assign |unpermuted_table| to |permutations|.
base::Parallelize(permutations, [&unpermuted_table, this](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class PermutationProvingKeyTest : public Halo2ProverTest {
} // namespace

TEST_F(PermutationProvingKeyTest, Copyable) {
ProvingKey expected({Evals::Random(prover_->pcs().N() - 1)},
{Poly::Random(5)});
const Domain* domain = prover_->domain();
ProvingKey expected({domain->Random<Evals>()}, {domain->Random<Poly>()});
ProvingKey value;

base::VectorBuffer write_buf;
Expand Down
16 changes: 9 additions & 7 deletions tachyon/zk/plonk/permutation/permutation_table_store_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ class PermutationTableStoreTest : public Halo2ProverTest {
void SetUp() override {
Halo2ProverTest::SetUp();

size_t n = prover_->pcs().N();
size_t d = n - 1;
fixed_columns_ = {Evals::Random(d), Evals::Random(d), Evals::Random(d)};
advice_columns_ = {Evals::Random(d), Evals::Random(d), Evals::Random(d)};
instance_columns_ = {Evals::Random(d), Evals::Random(d), Evals::Random(d)};
const Domain* domain = prover_->domain();
fixed_columns_ =
base::CreateVector(3, [domain]() { return domain->Random<Evals>(); });
advice_columns_ =
base::CreateVector(3, [domain]() { return domain->Random<Evals>(); });
instance_columns_ =
base::CreateVector(3, [domain]() { return domain->Random<Evals>(); });

table_ = Table<Evals>(absl::MakeConstSpan(fixed_columns_),
absl::MakeConstSpan(advice_columns_),
Expand All @@ -32,8 +34,8 @@ class PermutationTableStoreTest : public Halo2ProverTest {
AdviceColumnKey(1), FixedColumnKey(1), FixedColumnKey(2),
AdviceColumnKey(2), InstanceColumnKey(1), InstanceColumnKey(2)};

unpermuted_table_ = UnpermutedTable<Evals>::Construct(column_keys_.size(),
n, prover_->domain());
unpermuted_table_ = UnpermutedTable<Evals>::Construct(
column_keys_.size(), prover_->pcs().N(), prover_->domain());
for (const Evals& column : unpermuted_table_.table()) {
permutations_.push_back(column);
}
Expand Down
26 changes: 13 additions & 13 deletions tachyon/zk/plonk/prover/argument_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,26 @@ namespace {

class ArgumentTest : public Halo2ProverTest {
public:
using F = typename PCS::Field;
using Poly = typename PCS::Poly;
using Evals = typename PCS::Evals;

void InitColumns() {
size_t d = prover_->pcs().N() - 1;
const Domain* domain = prover_->domain();
num_circuits_ = 2;
expected_fixed_columns_ =
base::CreateVector(1, [d]() { return Evals::Random(d); });
base::CreateVector(1, [domain]() { return domain->Random<Evals>(); });
expected_fixed_polys_ =
base::CreateVector(1, [d]() { return Poly::Random(d); });
expected_advice_columns_vec_ = base::CreateVector(num_circuits_, [d]() {
return base::CreateVector(2, [d]() { return Evals::Random(d); });
});
base::CreateVector(1, [domain]() { return domain->Random<Poly>(); });
expected_advice_columns_vec_ =
base::CreateVector(num_circuits_, [domain]() {
return base::CreateVector(
2, [domain]() { return domain->Random<Evals>(); });
});
expected_advice_blinds_vec_ = base::CreateVector(num_circuits_, []() {
return base::CreateVector(2, []() { return F::Random(); });
});
expected_instance_columns_vec_ = base::CreateVector(num_circuits_, [d]() {
return base::CreateVector(1, [d]() { return Evals::Random(d); });
});
expected_instance_columns_vec_ =
base::CreateVector(num_circuits_, [domain]() {
return base::CreateVector(
1, [domain]() { return domain->Random<Evals>(); });
});
expected_challenges_ = {F::Random()};
}

Expand Down
Loading

0 comments on commit 02d6341

Please sign in to comment.