From e7bdfb7fcb7700474cf8277ed5110b8b34d9e5a3 Mon Sep 17 00:00:00 2001
From: Linda Guiga
Date: Tue, 10 Sep 2024 18:08:41 +0200
Subject: [PATCH] Change the way arities are computed

---
 plonky2/src/fri/mod.rs                   |  7 ++---
 plonky2/src/fri/proof.rs                 | 36 +++++++++---------
 plonky2/src/fri/prover.rs                | 11 ++------
 plonky2/src/fri/recursive_verifier.rs    | 26 ++++++-----------
 plonky2/src/fri/reduction_strategies.rs  | 13 +++++++--
 plonky2/src/gates/coset_interpolation.rs |  2 +-
 6 files changed, 37 insertions(+), 58 deletions(-)

diff --git a/plonky2/src/fri/mod.rs b/plonky2/src/fri/mod.rs
index 252c6d5e11..725efe00e8 100644
--- a/plonky2/src/fri/mod.rs
+++ b/plonky2/src/fri/mod.rs
@@ -51,6 +51,7 @@ impl FriConfig {
             self.rate_bits,
             self.cap_height,
             self.num_query_rounds,
+            hiding,
         );
         FriParams {
             config: self.clone(),
@@ -87,11 +88,7 @@ pub struct FriParams {
 
 impl FriParams {
     pub fn total_arities(&self) -> usize {
-        if self.hiding {
-            (1 as usize) + self.reduction_arity_bits.iter().sum::<usize>()
-        } else {
-            self.reduction_arity_bits.iter().sum()
-        }
+        self.reduction_arity_bits.iter().sum::<usize>()
     }
 
     pub(crate) fn max_arity_bits(&self) -> Option<usize> {
diff --git a/plonky2/src/fri/proof.rs b/plonky2/src/fri/proof.rs
index db8d3ca025..f2b57d45ae 100644
--- a/plonky2/src/fri/proof.rs
+++ b/plonky2/src/fri/proof.rs
@@ -144,14 +144,8 @@ impl<F: RichField + Extendable<D>, H: Hasher<F>, const D: usize> FriProof<F, H,
-        let reduction_arity_bits = if params.hiding {
-            let mut tmp = vec![1];
-            tmp.extend(&params.reduction_arity_bits);
-            tmp
-        } else {
-            params.reduction_arity_bits.clone()
-        };
-        let num_reductions = reduction_arity_bits.len();
+
+        let num_reductions = params.reduction_arity_bits.len();
@@ ... @@ impl<F: RichField + Extendable<D>, H: Hasher<F>, const D: usize> FriProof<F, H,
-                let index_within_coset = index & ((1 << reduction_arity_bits[i]) - 1);
-                index >>= reduction_arity_bits[i];
+                let index_within_coset = index & ((1 << params.reduction_arity_bits[i]) - 1);
+                index >>= params.reduction_arity_bits[i];
                 steps_indices[i].push(index);
                 let mut evals = query_step.evals;
                 // Remove the element that can be inferred.
@@ -221,7 +215,7 @@ impl<F: RichField + Extendable<D>, H: Hasher<F>, const D: usize> FriProof<F, H,
-                index >>= reduction_arity_bits[j];
+                index >>= params.reduction_arity_bits[j];
                 let query_step = FriQueryStep {
                     evals: steps_evals[j][i].clone(),
                     merkle_proof: steps_proofs[j][i].clone(),
                 };
@@ -262,14 +256,8 @@ impl<F: RichField + Extendable<D>, H: Hasher<F>, const D: usize> CompressedFriPr
         } = &challenges.fri_challenges;
         let mut fri_inferred_elements = fri_inferred_elements.0.into_iter();
         let cap_height = params.config.cap_height;
-        let reduction_arity_bits = if params.hiding {
-            let mut tmp = vec![1];
-            tmp.extend(&params.reduction_arity_bits);
-            tmp
-        } else {
-            params.reduction_arity_bits.clone()
-        };
-        let num_reductions = reduction_arity_bits.len();
+
+        let num_reductions = params.reduction_arity_bits.len();
         let num_initial_trees = query_round_proofs
             .initial_trees_proofs
             .values()
@@ -286,7 +274,8 @@ impl<F: RichField + Extendable<D>, H: Hasher<F>, const D: usize> CompressedFriPr
         let mut steps_evals = vec![vec![]; num_reductions];
         let mut steps_proofs = vec![vec![]; num_reductions];
         let height = params.degree_bits + params.config.rate_bits;
-        let heights = reduction_arity_bits
+        let heights = params
+            .reduction_arity_bits
             .iter()
             .scan(height, |acc, &bits| {
                 *acc -= bits;
@@ -295,7 +284,8 @@ impl<F: RichField + Extendable<D>, H: Hasher<F>, const D: usize> CompressedFriPr
             .collect::<Vec<_>>();
 
         // Holds the `evals` vectors that have already been reconstructed at each reduction depth.
-        let mut evals_by_depth = vec![HashMap::<usize, Vec<_>>::new(); reduction_arity_bits.len()];
+        let mut evals_by_depth =
+            vec![HashMap::<usize, Vec<_>>::new(); params.reduction_arity_bits.len()];
         for &(mut index) in indices {
             let initial_trees_proof = query_round_proofs.initial_trees_proofs[&index].clone();
             for (i, (leaves_data, proof)) in
@@ -306,8 +296,8 @@ impl<F: RichField + Extendable<D>, H: Hasher<F>, const D: usize> CompressedFriPr
                 initial_trees_proofs[i].push(proof);
             }
             for i in 0..num_reductions {
-                let index_within_coset = index & ((1 << reduction_arity_bits[i]) - 1);
-                index >>= reduction_arity_bits[i];
+                let index_within_coset = index & ((1 << params.reduction_arity_bits[i]) - 1);
+                index >>= params.reduction_arity_bits[i];
                 let FriQueryStep {
                     mut evals,
                     merkle_proof,
diff --git a/plonky2/src/fri/prover.rs b/plonky2/src/fri/prover.rs
index 2241364d64..21eb569b86 100644
--- a/plonky2/src/fri/prover.rs
+++ b/plonky2/src/fri/prover.rs
@@ -77,18 +77,11 @@ fn fri_committed_trees<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>,
     challenger: &mut Challenger<F, C::Hasher>,
     fri_params: &FriParams,
 ) -> FriCommitedTrees<F, C, D> {
-    let arities = if fri_params.hiding {
-        let mut tmp = vec![1];
-        tmp.extend(&fri_params.reduction_arity_bits);
-        tmp
-    } else {
-        fri_params.reduction_arity_bits.clone()
-    };
-    let mut trees = Vec::with_capacity(arities.len());
+    let mut trees = Vec::with_capacity(fri_params.reduction_arity_bits.len());
 
     let mut shift = F::MULTIPLICATIVE_GROUP_GENERATOR;
 
-    for arity_bits in &arities {
+    for arity_bits in &fri_params.reduction_arity_bits {
         let arity = 1 << arity_bits;
 
         reverse_index_bits_in_place(&mut values.values);
diff --git a/plonky2/src/fri/recursive_verifier.rs b/plonky2/src/fri/recursive_verifier.rs
index 8a620d15a5..e11273c93e 100644
--- a/plonky2/src/fri/recursive_verifier.rs
+++ b/plonky2/src/fri/recursive_verifier.rs
@@ -2,6 +2,7 @@
 use alloc::{format, vec::Vec};
 
 use itertools::Itertools;
+use plonky2_field::interpolation;
 use plonky2_field::types::Field;
 
 use crate::field::extension::Extendable;
@@ -53,6 +54,10 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
             arity_bits,
             self.config.max_quotient_degree_factor,
         );
+        println!(
+            "arity bits {}, interpolation gate {:?}",
+            arity_bits, interpolation_gate
+        );
         self.interpolate_coset(interpolation_gate, coset_start, &evals, beta)
     }
 
@@ -341,19 +346,11 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
             )
         );
 
-        let reduction_arity_bits = if params.hiding {
-            let mut tmp = vec![1];
-            tmp.extend(&params.reduction_arity_bits);
-            tmp
-        } else {
-            params.reduction_arity_bits.clone()
-        };
-
         let cap_index = self.le_sum(
             x_index_bits[x_index_bits.len() + params.hiding as usize - params.config.cap_height..]
                 .iter(),
         );
-        for (i, &arity_bits) in reduction_arity_bits.iter().enumerate() {
+        for (i, &arity_bits) in params.reduction_arity_bits.iter().enumerate() {
             let evals = &round_proof.steps[i].evals;
 
             // Split x_index into the index of the coset x is in, and the index of x within that coset.
@@ -472,20 +469,13 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
         let initial_trees_proof =
             self.add_virtual_fri_initial_trees_proof(num_leaves_per_oracle, merkle_proof_len);
 
-        let reduction_arity_bits = if params.hiding {
-            let mut tmp = vec![1];
-            tmp.extend(&params.reduction_arity_bits);
-            tmp
-        } else {
-            params.reduction_arity_bits.clone()
-        };
-        let mut steps = Vec::with_capacity(reduction_arity_bits.len());
+        let mut steps = Vec::with_capacity(params.reduction_arity_bits.len());
         merkle_proof_len = if params.hiding {
             merkle_proof_len + 1
         } else {
             merkle_proof_len
         };
-        for &arity_bits in &reduction_arity_bits {
+        for &arity_bits in &params.reduction_arity_bits {
             assert!(merkle_proof_len >= arity_bits);
             merkle_proof_len -= arity_bits;
             steps.push(self.add_virtual_fri_query_step(arity_bits, merkle_proof_len));
diff --git a/plonky2/src/fri/reduction_strategies.rs b/plonky2/src/fri/reduction_strategies.rs
index e7f5d799ff..864ab79250 100644
--- a/plonky2/src/fri/reduction_strategies.rs
+++ b/plonky2/src/fri/reduction_strategies.rs
@@ -33,11 +33,20 @@ impl FriReductionStrategy {
         rate_bits: usize,
         cap_height: usize,
         num_queries: usize,
+        hiding: bool,
     ) -> Vec<usize> {
         match self {
-            FriReductionStrategy::Fixed(reduction_arity_bits) => reduction_arity_bits.to_vec(),
+            FriReductionStrategy::Fixed(reduction_arity_bits) => {
+                if hiding {
+                    let mut tmp = vec![1];
+                    tmp.extend(reduction_arity_bits);
+                    tmp
+                } else {
+                    reduction_arity_bits.to_vec()
+                }
+            }
             &FriReductionStrategy::ConstantArityBits(arity_bits, final_poly_bits) => {
-                let mut result = Vec::new();
+                let mut result = if hiding { vec![1] } else { Vec::new() };
                 while degree_bits > final_poly_bits
                     && degree_bits + rate_bits - arity_bits >= cap_height
                 {
diff --git a/plonky2/src/gates/coset_interpolation.rs b/plonky2/src/gates/coset_interpolation.rs
index 0b19a1733c..aba008d342 100644
--- a/plonky2/src/gates/coset_interpolation.rs
+++ b/plonky2/src/gates/coset_interpolation.rs
@@ -68,7 +68,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CosetInterpolationGate<F, D>
         Self::with_max_degree(subgroup_bits, 1 << subgroup_bits)
     }
 
-    pub(crate) fn with_max_degree(subgroup_bits: usize, max_degree: usize) -> Self {
+    pub fn with_max_degree(subgroup_bits: usize, max_degree: usize) -> Self {
         assert!(max_degree > 1, "need at least quadratic constraints");
 
         let n_points = 1 << subgroup_bits;
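
The net effect on the arity computation: the extra 1-bit (arity-2) reduction round used when
hiding is enabled is now prepended once, inside FriReductionStrategy::reduction_arity_bits
(which gains a `hiding` parameter), so the prover, the recursive verifier, and proof
(de)compression iterate `params.reduction_arity_bits` directly instead of rebuilding the list
with a leading `vec![1]` at each call site. A minimal standalone sketch of that computation
follows; the helper name is hypothetical and not part of the plonky2 API.

    // Hypothetical standalone sketch of the new arity computation (not the plonky2 API).
    fn reduction_arity_bits_with_hiding(fixed: &[usize], hiding: bool) -> Vec<usize> {
        if hiding {
            // One extra 1-bit (arity-2) reduction round is prepended to absorb the blinding layer.
            let mut bits = vec![1];
            bits.extend_from_slice(fixed);
            bits
        } else {
            fixed.to_vec()
        }
    }

    fn main() {
        // Downstream code can simply iterate the returned arities; no `if hiding`
        // special case is needed at the use sites.
        assert_eq!(reduction_arity_bits_with_hiding(&[4, 4, 3], true), vec![1, 4, 4, 3]);
        assert_eq!(reduction_arity_bits_with_hiding(&[4, 4, 3], false), vec![4, 4, 3]);
    }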