From 3fde6e5ad0daaa7e33d9cf08844a350ebec0428a Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Tue, 20 Aug 2024 23:36:04 -0400 Subject: [PATCH] Create MLE eval component --- .../src/constraint_framework/component.rs | 31 +- crates/prover/src/core/air/accumulation.rs | 1 + .../examples/xor/gkr_lookups/accumulation.rs | 13 +- .../src/examples/xor/gkr_lookups/mle_eval.rs | 605 +++++++++++++++--- 4 files changed, 534 insertions(+), 116 deletions(-) diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index 9950d5ca2..7902a49cb 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -31,7 +31,10 @@ pub struct TreeColumnSpanProvider { } impl TreeColumnSpanProvider { - fn next_for_structure(&mut self, structure: &TreeVec>) -> Vec { + pub fn next_for_structure( + &mut self, + structure: &TreeVec>, + ) -> Vec { structure .iter() .enumerate() @@ -82,6 +85,10 @@ impl FrameworkComponent { trace_locations, } } + + pub fn trace_locations(&self) -> &[TreeColumnSpan] { + &self.trace_locations + } } impl Component for FrameworkComponent { @@ -94,26 +101,20 @@ impl Component for FrameworkComponent { } fn trace_log_degree_bounds(&self) -> TreeVec> { - TreeVec::new( - self.eval - .evaluate(InfoEvaluator::default()) - .mask_offsets - .iter() - .map(|tree_masks| vec![self.eval.log_size(); tree_masks.len()]) - .collect(), - ) + let InfoEvaluator { mask_offsets, .. } = self.eval.evaluate(InfoEvaluator::default()); + mask_offsets.map(|tree_offsets| vec![self.eval.log_size(); tree_offsets.len()]) } fn mask_points( &self, point: CirclePoint, ) -> TreeVec>>> { - let info = self.eval.evaluate(InfoEvaluator::default()); let trace_step = CanonicCoset::new(self.eval.log_size()).step(); - info.mask_offsets.map_cols(|col_mask| { - col_mask + let InfoEvaluator { mask_offsets, .. } = self.eval.evaluate(InfoEvaluator::default()); + mask_offsets.map_cols(|col_offsets| { + col_offsets .iter() - .map(|off| point + trace_step.mul_signed(*off).into_ef()) + .map(|offset| point + trace_step.mul_signed(*offset).into_ef()) .collect() }) } @@ -138,6 +139,10 @@ impl ComponentProver for FrameworkComponent { trace: &Trace<'_, SimdBackend>, evaluation_accumulator: &mut DomainEvaluationAccumulator, ) { + if self.n_constraints() == 0 { + return; + } + let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain(); let trace_domain = CanonicCoset::new(self.eval.log_size()); diff --git a/crates/prover/src/core/air/accumulation.rs b/crates/prover/src/core/air/accumulation.rs index 8fcf57549..5ed55ac56 100644 --- a/crates/prover/src/core/air/accumulation.rs +++ b/crates/prover/src/core/air/accumulation.rs @@ -18,6 +18,7 @@ use crate::core::utils::generate_secure_powers; /// Accumulates N evaluations of u_i(P0) at a single point. /// Computes f(P0), the combined polynomial at that point. /// For n accumulated evaluations, the i'th evaluation is multiplied by alpha^(N-1-i). +#[derive(Debug, Clone, Copy)] pub struct PointEvaluationAccumulator { random_coeff: SecureField, accumulation: SecureField, diff --git a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs b/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs index 53ae956e6..a63b62503 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs @@ -18,13 +18,13 @@ pub const MIN_LOG_BLOWUP_FACTOR: u32 = 1; /// IOP for multilinear eval at point. pub const MAX_MLE_N_VARIABLES: u32 = M31_CIRCLE_LOG_ORDER - MIN_LOG_BLOWUP_FACTOR; -/// Accumulates [`Mle`]s grouped by their number of variables. +/// Collection of [`Mle`]s grouped by their number of variables. pub struct MleCollection { mles_by_n_variables: Vec>>>, } impl MleCollection { - /// Appends an [`Mle`] to the collection. + /// Appends an [`Mle`] to the back of the collection. pub fn push(&mut self, mle: impl Into>) { let mle = mle.into(); let mles = self.mles_by_n_variables[mle.n_variables()].get_or_insert(Vec::new()); @@ -35,6 +35,7 @@ impl MleCollection { impl MleCollection { /// Performs a random linear combination of all MLEs, grouped by their number of variables. /// + /// For `n` accumulated MLEs in a group, the `i`th MLE is multiplied by `alpha^(n-1-i)`. /// MLEs are returned in ascending order by number of variables. pub fn random_linear_combine_by_n_variables( self, @@ -53,13 +54,15 @@ impl MleCollection { /// Panics if `mles` is empty or all MLEs don't have the same number of variables. fn mle_random_linear_combination( mles: Vec>, - alpha: SecureField, + random_coeff: SecureField, ) -> Mle { assert!(!mles.is_empty()); let n_variables = mles[0].n_variables(); assert!(mles.iter().all(|mle| mle.n_variables() == n_variables)); - let alpha_powers = generate_secure_powers(alpha, mles.len()).into_iter().rev(); - let mut mle_and_coeff = zip(mles, alpha_powers); + let coeff_powers = generate_secure_powers(random_coeff, mles.len()) + .into_iter() + .rev(); + let mut mle_and_coeff = zip(mles, coeff_powers); // The last value can initialize the accumulator. let (mle, coeff) = mle_and_coeff.next_back().unwrap(); diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs index 9872c17dd..03a40857f 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -2,17 +2,23 @@ // TODO(andrew): Remove in downstream PR. #![allow(dead_code)] -use std::array; use std::iter::zip; -use itertools::{chain, zip_eq}; +use itertools::{chain, zip_eq, Itertools}; use num_traits::{One, Zero}; +use tracing::{span, Level}; use crate::constraint_framework::constant_columns::gen_is_first; -use crate::constraint_framework::EvalAtRow; -use crate::core::backend::simd::column::SecureColumn; +use crate::constraint_framework::{ + EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, TreeColumnSpanProvider, +}; +use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; +use crate::core::air::{Component, ComponentProver, Trace}; +use crate::core::backend::simd::column::{SecureColumn, VeryPackedSecureColumnByCoords}; +use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS}; use crate::core::backend::simd::SimdBackend; use crate::core::backend::{Col, Column}; use crate::core::circle::{CirclePoint, Coset}; @@ -24,20 +30,272 @@ use crate::core::fields::FieldExpOps; use crate::core::lookups::gkr_prover::GkrOps; use crate::core::lookups::mle::Mle; use crate::core::lookups::utils::eq; -use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureEvaluation}; +use crate::core::pcs::{TreeColumnSpan, TreeVec}; +use crate::core::poly::circle::{ + CanonicCoset, CircleEvaluation, SecureCirclePoly, SecureEvaluation, +}; +use crate::core::poly::twiddles::TwiddleTree; use crate::core::poly::BitReversedOrder; -use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; +use crate::core::utils::{self, bit_reverse_index, coset_index_to_circle_domain_index}; +use crate::core::ColumnVec; + +/// Component that carries out a univariate IOP for multilinear eval at point. +/// +/// See (Section 5.1). +#[allow(dead_code)] +pub struct MleEvalProverComponent<'twiddles, 'oracle, O: MleCoeffColumnOracle> { + /// Polynomials encoding the multilinear Lagrange basis coefficients of the MLE. + mle_coeff_column_poly: SecureCirclePoly, + /// Oracle for the polynomial encoding the multilinear Lagrange basis coefficients of the MLE. + /// + /// The oracle values should match `mle_coeff_column_poly` for any given evaluation point. The + /// polynomial is only stored directly to speed up constraint evaluation. The oracle is stored + /// to perform consistency checks with `mle_coeff_column_poly`. + mle_coeff_column_oracle: &'oracle O, + /// Multilinear evaluation point. + mle_eval_point: MleEvalPoint, + /// Equals `mle_claim / 2^mle_n_variables`. + mle_claim_shift: SecureField, + /// Commitment tree index for the trace. + evals_interaction: usize, + /// Commitment tree index for the constants trace. + const_interaction: usize, + /// Location in the trace for the this component. + trace_locations: Vec, + /// Precomputed twiddles tree. + twiddles: &'twiddles TwiddleTree, +} + +impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> MleEvalProverComponent<'twiddles, 'oracle, O> { + // TODO(andrew): Some eval points may affect completeness. Document. + #[allow(clippy::too_many_arguments)] + pub fn generate( + provider: &mut TreeColumnSpanProvider, + mle_coeff_column_oracle: &'oracle O, + mle_eval_point: &[SecureField], + mle: Mle, + mle_claim: SecureField, + twiddles: &'twiddles TwiddleTree, + evals_interaction: usize, + const_interaction: usize, + ) -> Self { + #[cfg(test)] + assert_eq!(mle_claim, mle.eval_at_point(mle_eval_point)); + let n_variables = mle.n_variables(); + let mle_claim_shift = mle_claim / BaseField::from(1 << n_variables); + + let domain = CanonicCoset::new(n_variables as u32).circle_domain(); + let values = mle.into_evals().into_secure_column_by_coords(); + let mle_trace = SecureEvaluation::::new(domain, values); + let mle_coeff_column_poly = mle_trace.interpolate_with_twiddles(twiddles); + + let trace_structure = + mle_eval_info(evals_interaction, const_interaction, n_variables).mask_offsets; + let trace_locations = provider.next_for_structure(&trace_structure); + + Self { + mle_coeff_column_poly, + mle_coeff_column_oracle, + mle_eval_point: MleEvalPoint::new(mle_eval_point), + mle_claim_shift, + evals_interaction, + const_interaction, + trace_locations, + twiddles, + } + } + + /// Size of this components trace columns. + pub fn log_size(&self) -> u32 { + self.mle_eval_point.n_variables() as u32 + } + + pub fn eval_info(&self) -> InfoEvaluator { + let n_variables = self.mle_eval_point.n_variables(); + mle_eval_info(self.evals_interaction, self.const_interaction, n_variables) + } +} + +impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> Component + for MleEvalProverComponent<'twiddles, 'oracle, O> +{ + fn n_constraints(&self) -> usize { + self.eval_info().n_constraints + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size() + 1 + } + + fn trace_log_degree_bounds(&self) -> TreeVec> { + let log_size = self.log_size(); + let InfoEvaluator { mask_offsets, .. } = self.eval_info(); + mask_offsets.map(|tree_offsets| vec![log_size; tree_offsets.len()]) + } + + fn mask_points( + &self, + point: CirclePoint, + ) -> TreeVec>>> { + let trace_step = CanonicCoset::new(self.log_size()).step(); + let InfoEvaluator { mask_offsets, .. } = self.eval_info(); + mask_offsets.map_cols(|col_offsets| { + col_offsets + .iter() + .map(|offset| point + trace_step.mul_signed(*offset).into_ef()) + .collect() + }) + } + + fn evaluate_constraint_quotients_at_point( + &self, + point: CirclePoint, + mask: &TreeVec>>, + accumulator: &mut PointEvaluationAccumulator, + ) { + // Consistency check the MLE coeffs column polynomial and oracle. + let mle_coeff_col_eval = self.mle_coeff_column_poly.eval_at_point(point); + let oracle_mle_coeff_col_eval = self.mle_coeff_column_oracle.evaluate_at_point(point, mask); + assert_eq!(mle_coeff_col_eval, oracle_mle_coeff_col_eval); + + let component_mask = mask.sub_tree(&self.trace_locations); + let trace_coset = CanonicCoset::new(self.log_size()).coset; + let vanish_on_trace_eval_inv = coset_vanishing(trace_coset, point).inverse(); + let mut eval = PointEvaluator::new(component_mask, accumulator, vanish_on_trace_eval_inv); + + let carry_quotients_col_eval = eval_carry_quotient_col(&self.mle_eval_point, point); + + eval_mle_eval_constraints( + self.evals_interaction, + self.const_interaction, + &mut eval, + mle_coeff_col_eval, + &self.mle_eval_point, + self.mle_claim_shift, + carry_quotients_col_eval, + ) + } +} + +impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> ComponentProver + for MleEvalProverComponent<'twiddles, 'oracle, O> +{ + fn evaluate_constraint_quotients_on_domain( + &self, + trace: &Trace<'_, SimdBackend>, + accumulator: &mut DomainEvaluationAccumulator, + ) { + let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain(); + let trace_domain = CanonicCoset::new(self.log_size()); + + let component_trace = trace.evals.sub_tree(&self.trace_locations).map_cols(|c| *c); + + // Extend MLE coeffs column. + let span = span!(Level::INFO, "Extension").entered(); + let mle_coeffs_column_lde = VeryPackedSecureColumnByCoords::from( + self.mle_coeff_column_poly + .evaluate_with_twiddles(eval_domain, self.twiddles) + .values, + ); + let carry_quotients_column_mle = VeryPackedSecureColumnByCoords::from( + gen_carry_quotient_col(&self.mle_eval_point.p) + .interpolate_with_twiddles(self.twiddles) + .evaluate_with_twiddles(eval_domain, self.twiddles) + .values, + ); + span.exit(); + + // Denom inverses. + let log_expand = eval_domain.log_size() - trace_domain.log_size(); + let mut denom_inv = (0..1 << log_expand) + .map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse()) + .collect_vec(); + utils::bit_reverse(&mut denom_inv); + + // Accumulator. + let [mut acc] = accumulator.columns([(eval_domain.log_size(), self.n_constraints())]); + acc.random_coeff_powers.reverse(); + let acc_col = unsafe { VeryPackedSecureColumnByCoords::transform_under_mut(acc.col) }; + + let _span = span!(Level::INFO, "Constraint pointwise eval").entered(); + let n_very_packed_rows = + 1 << (eval_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS); + for vec_row in 0..n_very_packed_rows { + // Evaluate constrains at row. + let mut eval = SimdDomainEvaluator::new( + &component_trace, + vec_row, + &acc.random_coeff_powers, + trace_domain.log_size(), + eval_domain.log_size(), + ); + + let mle_coeffs_col_eval = unsafe { mle_coeffs_column_lde.packed_at(vec_row) }; + let carry_quotients_col_eval = unsafe { carry_quotients_column_mle.packed_at(vec_row) }; + eval_mle_eval_constraints( + self.evals_interaction, + self.const_interaction, + &mut eval, + mle_coeffs_col_eval, + &self.mle_eval_point, + self.mle_claim_shift, + carry_quotients_col_eval, + ); + + // Finalize row. + let row_res = eval.row_res; + let denom_inv = VeryPackedBaseField::broadcast( + denom_inv + [vec_row >> (trace_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)], + ); + unsafe { acc_col.set_packed(vec_row, acc_col.packed_at(vec_row) + row_res * denom_inv) } + } + } +} + +fn mle_eval_info( + mle_interaction: usize, + selector_interaction: usize, + n_variables: usize, +) -> InfoEvaluator { + let mut eval = InfoEvaluator::default(); + let mle_eval_point = MleEvalPoint::new(&vec![SecureField::from(2); n_variables]); + let mle_claim_shift = SecureField::zero(); + let mle_coeffs_col_eval = SecureField::zero(); + let carry_quotients_col_eval = SecureField::zero(); + eval_mle_eval_constraints( + mle_interaction, + selector_interaction, + &mut eval, + mle_coeffs_col_eval, + &mle_eval_point, + mle_claim_shift, + carry_quotients_col_eval, + ); + eval +} + +/// Univariate polynomial oracle that encodes multilinear Lagrange basis coefficients of a MLE. +/// +/// The column should encode the MLE coefficients ordered on a circle domain. +pub trait MleCoeffColumnOracle { + fn evaluate_at_point( + &self, + point: CirclePoint, + mask: &TreeVec>>, + ) -> SecureField; +} /// Evaluates constraints that guarantee an MLE evaluates to a claim at a given point. /// /// `mle_coeffs_col_eval` should be the evaluation of the column containing the coefficients of the /// MLE in the multilinear Lagrange basis. `mle_claim_shift` should equal `claim / 2^N_VARIABLES`. -pub fn eval_mle_eval_constraints( +pub fn eval_mle_eval_constraints( mle_interaction: usize, const_interaction: usize, eval: &mut E, mle_coeffs_col_eval: E::EF, - mle_eval_point: MleEvalPoint, + mle_eval_point: &MleEvalPoint, mle_claim_shift: SecureField, carry_quotients_col_eval: E::EF, ) { @@ -52,37 +310,49 @@ pub fn eval_mle_eval_constraints( eval_prefix_sum_constraints(mle_interaction, eval, terms_col_eval, mle_claim_shift) } -#[derive(Debug, Clone, Copy)] -pub struct MleEvalPoint { +#[derive(Debug, Clone)] +pub struct MleEvalPoint { // Equals `eq({0}^|p|, p)`. eq_0_p: SecureField, // Equals `eq({1}^|p|, p)`. eq_1_p: SecureField, // Index `i` stores `eq(({1}^|i|, 0), p[0..i+1]) / eq(({0}^|i|, 1), p[0..i+1])`. - eq_carry_quotients: [SecureField; N_VARIABLES], + eq_carry_quotients: Vec, // Point `p`. - p: [SecureField; N_VARIABLES], + p: Vec, } -impl MleEvalPoint { +impl MleEvalPoint { /// Creates new metadata from point `p`. - pub fn new(p: [SecureField; N_VARIABLES]) -> Self { + /// + /// # Panics + /// + /// Panics if the point is empty. + pub fn new(p: &[SecureField]) -> Self { + assert!(!p.is_empty()); + let n_variables = p.len(); let zero = SecureField::zero(); let one = SecureField::one(); Self { - eq_0_p: eq(&[zero; N_VARIABLES], &p), - eq_1_p: eq(&[one; N_VARIABLES], &p), - eq_carry_quotients: array::from_fn(|i| { - let mut numer_assignment = vec![one; i + 1]; - numer_assignment[i] = zero; - let mut denom_assignment = vec![zero; i + 1]; - denom_assignment[i] = one; - eq(&numer_assignment, &p[..i + 1]) / eq(&denom_assignment, &p[..i + 1]) - }), - p, + eq_0_p: eq(&vec![zero; n_variables], p), + eq_1_p: eq(&vec![one; n_variables], p), + eq_carry_quotients: (0..n_variables) + .map(|i| { + let mut numer_assignment = vec![one; i + 1]; + numer_assignment[i] = zero; + let mut denom_assignment = vec![zero; i + 1]; + denom_assignment[i] = one; + eq(&numer_assignment, &p[..i + 1]) / eq(&denom_assignment, &p[..i + 1]) + }) + .collect(), + p: p.to_vec(), } } + + pub fn n_variables(&self) -> usize { + self.p.len() + } } /// Evaluates EqEvals constraints on a column. @@ -93,11 +363,11 @@ impl MleEvalPoint { /// evaluates constraints that guarantee: `c(D[b0, b1, ...]) = eq((b0, b1, ...), (r0, r1, ...))`. /// /// See (Section 5.1). -fn eval_eq_constraints( +fn eval_eq_constraints( eq_interaction: usize, const_interaction: usize, eval: &mut E, - mle_eval_point: MleEvalPoint, + mle_eval_point: &MleEvalPoint, carry_quotients_col_eval: E::EF, ) -> E::EF { let [curr, next_next] = eval.next_extension_interaction_mask(eq_interaction, [0, 2]); @@ -169,22 +439,11 @@ pub fn build_trace( .collect() } -/// Generates a trace. -/// -/// Trace structure: -/// 1. Is first selector column (see [gen_is_first]). -/// 2. Eq carry quotients column (see [gen_carry_quotient_trace]). -/// -/// ```text -/// ------------------------------------------------ -/// | is first selector | eq carry quotients | -/// ------------------------------------------------ -/// | c0 | c1 | c2 | c3 | c4 | -/// ------------------------------------------------ -/// ``` -pub fn build_constant_trace( +/// Returns a trace with a single "is first" selector column. +pub fn build_constant_trace( + mle_n_variables: usize, ) -> Vec> { - let log_size = N_VARIABLES as u32; + let log_size = mle_n_variables as u32; vec![gen_is_first(log_size)] } @@ -196,14 +455,15 @@ pub fn build_constant_trace( /// `c(-C[i]) = c(-C[i + 1]) * q(-C[i])`. /// /// [`CircleDomain`]: crate::core::poly::circle::CircleDomain -fn gen_carry_quotient_col( - eval_point: &[SecureField; N_VARIABLES], +fn gen_carry_quotient_col( + eval_point: &[SecureField], ) -> SecureEvaluation { - let mle_eval_point = MleEvalPoint::new(*eval_point); + assert!(!eval_point.is_empty()); + let mle_eval_point = MleEvalPoint::new(eval_point); let (half_coset0_carry_quotients, half_coset1_carry_quotients) = gen_half_coset_carry_quotients(&mle_eval_point); - let log_size = N_VARIABLES as u32; + let log_size = mle_eval_point.n_variables() as u32; let size = 1 << log_size; let half_coset_size = size / 2; let mut col = SecureColumnByCoords::::zeros(size); @@ -230,11 +490,9 @@ fn gen_carry_quotient_col( /// Evaluates the succinct Eq carry quotients column. /// /// See [`gen_carry_quotient_col`]. -fn eval_carry_quotient_col( - eval_point: &MleEvalPoint, - p: CirclePoint, -) -> SecureField { - let log_size = N_VARIABLES as u32; +fn eval_carry_quotient_col(eval_point: &MleEvalPoint, p: CirclePoint) -> SecureField { + let n_variables = eval_point.n_variables(); + let log_size = n_variables as u32; let coset = CanonicCoset::new(log_size).coset(); let (half_coset0_carry_quotients, half_coset1_carry_quotients) = @@ -242,7 +500,7 @@ fn eval_carry_quotient_col( let mut eval = SecureField::zero(); - for variable_i in 0..N_VARIABLES.saturating_sub(1) { + for variable_i in 0..n_variables.saturating_sub(1) { let log_step = variable_i as u32 + 2; let offset = (1 << (log_step - 1)) - 2; let half_coset0_selector = eval_step_selector_with_offset(coset, offset, log_step, p); @@ -293,14 +551,17 @@ fn eval_is_first(coset: Coset, p: CirclePoint) -> SecureField { } /// Output of the form: `(half_coset0_carry_quotients, half_coset1_carry_quotients)`. -fn gen_half_coset_carry_quotients( - eval_point: &MleEvalPoint, -) -> ([SecureField; N_VARIABLES], [SecureField; N_VARIABLES]) { +fn gen_half_coset_carry_quotients( + eval_point: &MleEvalPoint, +) -> (Vec, Vec) { let last_variable = *eval_point.p.last().unwrap(); - let mut half_coset0_carry_quotients = eval_point.eq_carry_quotients; + let mut half_coset0_carry_quotients = eval_point.eq_carry_quotients.clone(); *half_coset0_carry_quotients.last_mut().unwrap() *= eq(&[SecureField::one()], &[last_variable]) / eq(&[SecureField::zero()], &[last_variable]); - let half_coset1_carry_quotients = half_coset0_carry_quotients.map(|v| v.inverse()); + let half_coset1_carry_quotients = half_coset0_carry_quotients + .iter() + .map(|v| v.inverse()) + .collect(); (half_coset0_carry_quotients, half_coset1_carry_quotients) } @@ -322,6 +583,7 @@ mod tests { use std::iter::{repeat, zip}; use itertools::{chain, Itertools}; + use mle_coeff_column::{MleCoeffColumnComponent, MleCoeffColumnEval}; use num_traits::One; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; @@ -331,23 +593,97 @@ mod tests { eval_prefix_sum_constraints, gen_carry_quotient_col, MleEvalPoint, }; use crate::constraint_framework::constant_columns::gen_is_step_with_offset; - use crate::constraint_framework::{assert_constraints, EvalAtRow}; + use crate::constraint_framework::{assert_constraints, EvalAtRow, TreeColumnSpanProvider}; + use crate::core::air::{Component, ComponentProver, Components}; use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; use crate::core::backend::simd::qm31::PackedSecureField; use crate::core::backend::simd::SimdBackend; + use crate::core::channel::Blake2sChannel; use crate::core::circle::SECURE_FIELD_CIRCLE_GEN; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumnByCoords; use crate::core::lookups::mle::Mle; - use crate::core::pcs::TreeVec; - use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps, SecureEvaluation}; + use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec}; + use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; + use crate::core::prover::{prove, verify, VerificationError}; use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; + use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; + use crate::examples::xor::gkr_lookups::accumulation::MIN_LOG_BLOWUP_FACTOR; use crate::examples::xor::gkr_lookups::mle_eval::{ - build_constant_trace, build_trace, eval_step_selector_with_offset, + build_constant_trace, build_trace, eval_step_selector_with_offset, MleEvalProverComponent, }; + #[test] + fn mle_eval_prover_component() -> Result<(), VerificationError> { + const N_VARIABLES: usize = 8; + const COEFFS_COL_TRACE: usize = 0; + const EVAL_TRACE: usize = 1; + const CONST_TRACE: usize = 2; + const LOG_EXPAND: u32 = 1; + // Create the test MLE. + let mut rng = SmallRng::seed_from_u64(0); + let log_size = N_VARIABLES as u32; + let size = 1 << log_size; + let mle_coeffs = (0..size).map(|_| rng.gen::()).collect(); + let mle = Mle::::new(mle_coeffs); + let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); + let claim = mle.eval_at_point(&eval_point); + // Setup protocol. + let twiddles = SimdBackend::precompute_twiddles( + CanonicCoset::new(log_size + LOG_EXPAND + MIN_LOG_BLOWUP_FACTOR) + .circle_domain() + .half_coset, + ); + let config = PcsConfig::default(); + let commitment_scheme = &mut CommitmentSchemeProver::new(config, &twiddles); + let channel = &mut Blake2sChannel::default(); + // Build trace. + // 1. MLE coeffs trace. + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(mle_coeff_column::build_trace(&mle)); + tree_builder.commit(channel); + // 2. MLE eval trace (eq evals + prefix sum). + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(build_trace(&mle, &eval_point, claim)); + tree_builder.commit(channel); + // 3. Constants trace. + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(build_constant_trace(N_VARIABLES)); + tree_builder.commit(channel); + // Create components. + let provider = &mut TreeColumnSpanProvider::default(); + let mle_coeffs_col_component = MleCoeffColumnComponent::new( + provider, + MleCoeffColumnEval::new(COEFFS_COL_TRACE, mle.n_variables()), + ); + let mle_eval_component = MleEvalProverComponent::generate( + provider, + &mle_coeffs_col_component, + &eval_point, + mle, + claim, + &twiddles, + EVAL_TRACE, + CONST_TRACE, + ); + let components: &[&dyn ComponentProver] = + &[&mle_coeffs_col_component, &mle_eval_component]; + // Generate proof. + let proof = prove(components, channel, commitment_scheme).unwrap(); + + // Verify. + let components = Components(components.iter().map(|&c| c as &dyn Component).collect()); + let log_sizes = components.column_log_sizes(); + let channel = &mut Blake2sChannel::default(); + let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); + commitment_scheme.commit(proof.commitments[0], &log_sizes[0], channel); + commitment_scheme.commit(proof.commitments[1], &log_sizes[1], channel); + commitment_scheme.commit(proof.commitments[2], &log_sizes[2], channel); + verify(&components.0, channel, commitment_scheme, proof) + } + #[test] fn test_mle_eval_constraints_with_log_size_5() { const N_VARIABLES: usize = 5; @@ -362,13 +698,13 @@ mod tests { let mle = Mle::::new(mle_coeffs); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); let claim = mle.eval_at_point(&eval_point); - let mle_eval_point = MleEvalPoint::new(eval_point); + let mle_eval_point = MleEvalPoint::new(&eval_point); let mle_eval_trace = build_trace(&mle, &eval_point, claim); - let mle_coeffs_col_trace = build_mle_coeffs_trace(mle); + let mle_coeffs_col_trace = mle_coeff_column::build_trace(&mle); let claim_shift = claim / BaseField::from(size); let carry_quotients_col = gen_carry_quotient_col(&eval_point); let carry_quotients_trace = carry_quotients_col.into_coordinate_evals().to_vec(); - let constants_trace = build_constant_trace::(); + let constants_trace = build_constant_trace(N_VARIABLES); let traces = TreeVec::new(vec![ mle_coeffs_col_trace, mle_eval_trace, @@ -387,7 +723,7 @@ mod tests { CONST_TRACE, &mut eval, mle_coeff_col_eval, - mle_eval_point, + &mle_eval_point, claim_shift, carry_quotients_col_eval, ) @@ -404,14 +740,14 @@ mod tests { let mut rng = SmallRng::seed_from_u64(0); let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); - let mle_eval_point = MleEvalPoint::new(eval_point); + let mle_eval_point = MleEvalPoint::new(&eval_point); let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point)); let carry_quotients_col = gen_carry_quotient_col(&eval_point); let carry_quotients_trace = carry_quotients_col.into_coordinate_evals().to_vec(); - let constants_trace = build_constant_trace::(); + let constants_trace = build_constant_trace(N_VARIABLES); let traces = TreeVec::new(vec![trace, carry_quotients_trace, constants_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); - let trace_domain = CanonicCoset::new(eval_point.len() as u32); + let trace_domain = CanonicCoset::new(N_VARIABLES as u32); assert_constraints(&trace_polys, trace_domain, |mut eval| { let [carry_quotients_col_eval] = @@ -420,7 +756,7 @@ mod tests { EVAL_TRACE, CONST_TRACE, &mut eval, - mle_eval_point, + &mle_eval_point, carry_quotients_col_eval, ); }); @@ -435,14 +771,14 @@ mod tests { let mut rng = SmallRng::seed_from_u64(0); let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); - let mle_eval_point = MleEvalPoint::new(eval_point); + let mle_eval_point = MleEvalPoint::new(&eval_point); let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point)); let carry_quotients_col = gen_carry_quotient_col(&eval_point); let carry_quotients_trace = carry_quotients_col.into_coordinate_evals().to_vec(); - let constants_trace = build_constant_trace::(); + let constants_trace = build_constant_trace(N_VARIABLES); let traces = TreeVec::new(vec![trace, carry_quotients_trace, constants_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); - let trace_domain = CanonicCoset::new(eval_point.len() as u32); + let trace_domain = CanonicCoset::new(N_VARIABLES as u32); assert_constraints(&trace_polys, trace_domain, |mut eval| { let [carry_quotients_col_eval] = @@ -451,7 +787,7 @@ mod tests { EVAL_TRACE, CONST_TRACE, &mut eval, - mle_eval_point, + &mle_eval_point, carry_quotients_col_eval, ); }); @@ -466,14 +802,14 @@ mod tests { let mut rng = SmallRng::seed_from_u64(0); let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); - let mle_eval_point = MleEvalPoint::new(eval_point); + let mle_eval_point = MleEvalPoint::new(&eval_point); let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point)); let carry_quotients_col = gen_carry_quotient_col(&eval_point); let carry_quotients_trace = carry_quotients_col.into_coordinate_evals().to_vec(); - let constants_trace = build_constant_trace::(); + let constants_trace = build_constant_trace(N_VARIABLES); let traces = TreeVec::new(vec![trace, carry_quotients_trace, constants_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); - let trace_domain = CanonicCoset::new(eval_point.len() as u32); + let trace_domain = CanonicCoset::new(N_VARIABLES as u32); assert_constraints(&trace_polys, trace_domain, |mut eval| { let [carry_quotients_col_eval] = @@ -482,7 +818,7 @@ mod tests { EVAL_TRACE, CONST_TRACE, &mut eval, - mle_eval_point, + &mle_eval_point, carry_quotients_col_eval, ); }); @@ -525,7 +861,7 @@ mod tests { const N_VARIABLES: usize = 5; let mut rng = SmallRng::seed_from_u64(0); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); - let mle_eval_point = MleEvalPoint::new(eval_point); + let mle_eval_point = MleEvalPoint::new(&eval_point); let col_eval = gen_carry_quotient_col(&eval_point); let twiddles = SimdBackend::precompute_twiddles(col_eval.domain.half_coset); let col_poly = col_eval.interpolate_with_twiddles(&twiddles); @@ -576,26 +912,99 @@ mod tests { .collect() } - /// Generates a trace. - /// - /// Trace structure: - /// - /// ```text - /// ----------------------------- - /// | MLE coeffs col | - /// ----------------------------- - /// | c0 | c1 | c2 | c3 | - /// ----------------------------- - /// ``` - fn build_mle_coeffs_trace( - mle: Mle, - ) -> Vec> { - let log_size = mle.n_variables() as u32; - let trace_domain = CanonicCoset::new(log_size).circle_domain(); - let mle_coeffs_col_by_coords = mle.into_evals().into_secure_column_by_coords(); - SecureEvaluation::new(trace_domain, mle_coeffs_col_by_coords) - .into_coordinate_evals() - .into_iter() - .collect() + mod mle_coeff_column { + use num_traits::One; + + use crate::constraint_framework::{ + EvalAtRow, FrameworkComponent, FrameworkEval, PointEvaluator, + }; + use crate::core::air::accumulation::PointEvaluationAccumulator; + use crate::core::backend::simd::SimdBackend; + use crate::core::circle::CirclePoint; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::lookups::mle::Mle; + use crate::core::pcs::TreeVec; + use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureEvaluation}; + use crate::core::poly::BitReversedOrder; + use crate::core::ColumnVec; + use crate::examples::xor::gkr_lookups::mle_eval::MleCoeffColumnOracle; + + pub type MleCoeffColumnComponent = FrameworkComponent; + + pub struct MleCoeffColumnEval { + interaction: usize, + n_variables: usize, + } + + impl MleCoeffColumnEval { + pub fn new(interaction: usize, n_variables: usize) -> Self { + Self { + interaction, + n_variables, + } + } + } + + impl FrameworkEval for MleCoeffColumnEval { + fn log_size(&self) -> u32 { + self.n_variables as u32 + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size() + } + + fn evaluate(&self, mut eval: E) -> E { + let _ = eval_mle_coeff_col(self.interaction, &mut eval); + eval + } + } + + impl MleCoeffColumnOracle for MleCoeffColumnComponent { + fn evaluate_at_point( + &self, + _point: CirclePoint, + mask: &TreeVec>>, + ) -> SecureField { + // Create dummy point evaluator just to extract the value we need from the mask + let mut accumulator = PointEvaluationAccumulator::new(SecureField::one()); + let mut eval = PointEvaluator::new( + mask.sub_tree(self.trace_locations()), + &mut accumulator, + SecureField::one(), + ); + + eval_mle_coeff_col(self.interaction, &mut eval) + } + } + + fn eval_mle_coeff_col(interaction: usize, eval: &mut E) -> E::EF { + let [mle_coeff_col_eval] = eval.next_extension_interaction_mask(interaction, [0]); + mle_coeff_col_eval + } + + /// Generates a trace. + /// + /// Trace structure: + /// + /// ```text + /// ----------------------------- + /// | MLE coeffs col | + /// ----------------------------- + /// | c0 | c1 | c2 | c3 | + /// ----------------------------- + /// ``` + pub fn build_trace( + mle: &Mle, + ) -> Vec> { + let log_size = mle.n_variables() as u32; + let trace_domain = CanonicCoset::new(log_size).circle_domain(); + let mle_coeffs_col_by_coords = mle.clone().into_evals().into_secure_column_by_coords(); + SecureEvaluation::new(trace_domain, mle_coeffs_col_by_coords) + .into_coordinate_evals() + .into_iter() + .collect() + } } }