From 5b3001a0cbdeaa39467ecc4eac63ffb8a8b90a7b Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Tue, 20 Aug 2024 23:36:04 -0400 Subject: [PATCH 1/3] Create MLE eval component --- .../src/constraint_framework/component.rs | 35 +- crates/prover/src/core/air/accumulation.rs | 1 + .../examples/xor/gkr_lookups/accumulation.rs | 11 +- .../src/examples/xor/gkr_lookups/mle_eval.rs | 587 +++++++++++++++--- 4 files changed, 534 insertions(+), 100 deletions(-) diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index c0d8319fa..4c7f75768 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -31,7 +31,10 @@ pub struct TraceLocationAllocator { } impl TraceLocationAllocator { - fn next_for_structure(&mut self, structure: &TreeVec>) -> TreeVec { + pub fn next_for_structure( + &mut self, + structure: &TreeVec>, + ) -> TreeVec { if structure.len() > self.next_tree_offsets.len() { self.next_tree_offsets.resize(structure.len(), 0); } @@ -72,14 +75,18 @@ pub struct FrameworkComponent { } impl FrameworkComponent { - pub fn new(provider: &mut TraceLocationAllocator, eval: E) -> Self { + pub fn new(location_allocator: &mut TraceLocationAllocator, eval: E) -> Self { let eval_tree_structure = eval.evaluate(InfoEvaluator::default()).mask_offsets; - let trace_locations = provider.next_for_structure(&eval_tree_structure); + let trace_locations = location_allocator.next_for_structure(&eval_tree_structure); Self { eval, trace_locations, } } + + pub fn trace_locations(&self) -> &[TreeSubspan] { + &self.trace_locations + } } impl Component for FrameworkComponent { @@ -92,26 +99,20 @@ impl Component for FrameworkComponent { } fn trace_log_degree_bounds(&self) -> TreeVec> { - TreeVec::new( - self.eval - .evaluate(InfoEvaluator::default()) - .mask_offsets - .iter() - .map(|tree_masks| vec![self.eval.log_size(); tree_masks.len()]) - .collect(), - ) + let InfoEvaluator { mask_offsets, .. } = self.eval.evaluate(InfoEvaluator::default()); + mask_offsets.map(|tree_offsets| vec![self.eval.log_size(); tree_offsets.len()]) } fn mask_points( &self, point: CirclePoint, ) -> TreeVec>>> { - let info = self.eval.evaluate(InfoEvaluator::default()); let trace_step = CanonicCoset::new(self.eval.log_size()).step(); - info.mask_offsets.map_cols(|col_mask| { - col_mask + let InfoEvaluator { mask_offsets, .. } = self.eval.evaluate(InfoEvaluator::default()); + mask_offsets.map_cols(|col_offsets| { + col_offsets .iter() - .map(|off| point + trace_step.mul_signed(*off).into_ef()) + .map(|offset| point + trace_step.mul_signed(*offset).into_ef()) .collect() }) } @@ -136,6 +137,10 @@ impl ComponentProver for FrameworkComponent { trace: &Trace<'_, SimdBackend>, evaluation_accumulator: &mut DomainEvaluationAccumulator, ) { + if self.n_constraints() == 0 { + return; + } + let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain(); let trace_domain = CanonicCoset::new(self.eval.log_size()); diff --git a/crates/prover/src/core/air/accumulation.rs b/crates/prover/src/core/air/accumulation.rs index 8fcf57549..2e7010ae6 100644 --- a/crates/prover/src/core/air/accumulation.rs +++ b/crates/prover/src/core/air/accumulation.rs @@ -18,6 +18,7 @@ use crate::core::utils::generate_secure_powers; /// Accumulates N evaluations of u_i(P0) at a single point. /// Computes f(P0), the combined polynomial at that point. /// For n accumulated evaluations, the i'th evaluation is multiplied by alpha^(N-1-i). +#[derive(Debug, Clone)] pub struct PointEvaluationAccumulator { random_coeff: SecureField, accumulation: SecureField, diff --git a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs b/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs index 53ae956e6..986289572 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs @@ -18,13 +18,13 @@ pub const MIN_LOG_BLOWUP_FACTOR: u32 = 1; /// IOP for multilinear eval at point. pub const MAX_MLE_N_VARIABLES: u32 = M31_CIRCLE_LOG_ORDER - MIN_LOG_BLOWUP_FACTOR; -/// Accumulates [`Mle`]s grouped by their number of variables. +/// Collection of [`Mle`]s grouped by their number of variables. pub struct MleCollection { mles_by_n_variables: Vec>>>, } impl MleCollection { - /// Appends an [`Mle`] to the collection. + /// Appends an [`Mle`] to the back of the collection. pub fn push(&mut self, mle: impl Into>) { let mle = mle.into(); let mles = self.mles_by_n_variables[mle.n_variables()].get_or_insert(Vec::new()); @@ -35,6 +35,7 @@ impl MleCollection { impl MleCollection { /// Performs a random linear combination of all MLEs, grouped by their number of variables. /// + /// For `n` accumulated MLEs in a group, the `i`th MLE is multiplied by `alpha^(n-1-i)`. /// MLEs are returned in ascending order by number of variables. pub fn random_linear_combine_by_n_variables( self, @@ -53,13 +54,13 @@ impl MleCollection { /// Panics if `mles` is empty or all MLEs don't have the same number of variables. fn mle_random_linear_combination( mles: Vec>, - alpha: SecureField, + random_coeff: SecureField, ) -> Mle { assert!(!mles.is_empty()); let n_variables = mles[0].n_variables(); assert!(mles.iter().all(|mle| mle.n_variables() == n_variables)); - let alpha_powers = generate_secure_powers(alpha, mles.len()).into_iter().rev(); - let mut mle_and_coeff = zip(mles, alpha_powers); + let coeff_powers = generate_secure_powers(random_coeff, mles.len()); + let mut mle_and_coeff = zip(mles, coeff_powers.into_iter().rev()); // The last value can initialize the accumulator. let (mle, coeff) = mle_and_coeff.next_back().unwrap(); diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs index 72dc133a5..0026baeac 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -2,16 +2,23 @@ // TODO(andrew): Remove in downstream PR. #![allow(dead_code)] -use std::array; use std::iter::zip; use itertools::{chain, zip_eq, Itertools}; use num_traits::{One, Zero}; - -use crate::constraint_framework::EvalAtRow; -use crate::core::backend::simd::column::SecureColumn; +use tracing::{span, Level}; + +use crate::constraint_framework::constant_columns::gen_is_first; +use crate::constraint_framework::{ + EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, TraceLocationAllocator, +}; +use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; +use crate::core::air::{Component, ComponentProver, Trace}; +use crate::core::backend::simd::column::{SecureColumn, VeryPackedSecureColumnByCoords}; +use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS}; use crate::core::backend::simd::SimdBackend; use crate::core::backend::{Col, Column}; use crate::core::circle::{CirclePoint, Coset}; @@ -23,20 +30,284 @@ use crate::core::fields::{Field, FieldExpOps}; use crate::core::lookups::gkr_prover::GkrOps; use crate::core::lookups::mle::Mle; use crate::core::lookups::utils::eq; -use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureEvaluation}; +use crate::core::pcs::{TreeSubspan, TreeVec}; +use crate::core::poly::circle::{ + CanonicCoset, CircleEvaluation, SecureCirclePoly, SecureEvaluation, +}; +use crate::core::poly::twiddles::TwiddleTree; use crate::core::poly::BitReversedOrder; -use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; +use crate::core::utils::{self, bit_reverse_index, coset_index_to_circle_domain_index}; +use crate::core::ColumnVec; + +/// Component that carries out a univariate IOP for multilinear eval at point. +/// +/// See (Section 5.1). +#[allow(dead_code)] +pub struct MleEvalProverComponent<'twiddles, 'oracle, O: MleCoeffColumnOracle> { + /// Polynomials encoding the multilinear Lagrange basis coefficients of the MLE. + mle_coeff_column_poly: SecureCirclePoly, + /// Oracle for the polynomial encoding the multilinear Lagrange basis coefficients of the MLE. + /// + /// The oracle values should match `mle_coeff_column_poly` for any given evaluation point. The + /// polynomial is only stored directly to speed up constraint evaluation. The oracle is stored + /// to perform consistency checks with `mle_coeff_column_poly`. + mle_coeff_column_oracle: &'oracle O, + /// Multilinear evaluation point. + mle_eval_point: MleEvalPoint, + /// Equals `mle_claim / 2^mle_n_variables`. + mle_claim_shift: SecureField, + /// Commitment tree index for the trace. + interaction: usize, + /// Location in the trace for the this component. + trace_locations: TreeVec, + /// Precomputed twiddles tree. + twiddles: &'twiddles TwiddleTree, +} + +impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> MleEvalProverComponent<'twiddles, 'oracle, O> { + /// Generates prover component that carries out univariate IOP for MLE eval at point. + /// + /// # Panics + /// + /// Panics if the eval point has a coordinate that is zero or one. This is a completeness bug. + pub fn generate( + location_allocator: &mut TraceLocationAllocator, + mle_coeff_column_oracle: &'oracle O, + mle_eval_point: &[SecureField], + mle: Mle, + mle_claim: SecureField, + twiddles: &'twiddles TwiddleTree, + interaction: usize, + ) -> Self { + #[cfg(test)] + assert_eq!(mle_claim, mle.eval_at_point(mle_eval_point)); + let n_variables = mle.n_variables(); + let mle_claim_shift = mle_claim / BaseField::from(1 << n_variables); + + let domain = CanonicCoset::new(n_variables as u32).circle_domain(); + let values = mle.into_evals().into_secure_column_by_coords(); + let mle_trace = SecureEvaluation::::new(domain, values); + let mle_coeff_column_poly = mle_trace.interpolate_with_twiddles(twiddles); + + let trace_structure = mle_eval_info(interaction, n_variables).mask_offsets; + let trace_locations = location_allocator.next_for_structure(&trace_structure); + + Self { + mle_coeff_column_poly, + mle_coeff_column_oracle, + mle_eval_point: MleEvalPoint::new(mle_eval_point), + mle_claim_shift, + interaction, + trace_locations, + twiddles, + } + } + + /// Size of this components trace columns. + pub fn log_size(&self) -> u32 { + self.mle_eval_point.n_variables() as u32 + } + + pub fn eval_info(&self) -> InfoEvaluator { + let n_variables = self.mle_eval_point.n_variables(); + mle_eval_info(self.interaction, n_variables) + } +} + +impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> Component + for MleEvalProverComponent<'twiddles, 'oracle, O> +{ + fn n_constraints(&self) -> usize { + self.eval_info().n_constraints + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size() + 1 + } + + fn trace_log_degree_bounds(&self) -> TreeVec> { + let log_size = self.log_size(); + let InfoEvaluator { mask_offsets, .. } = self.eval_info(); + mask_offsets.map(|tree_offsets| vec![log_size; tree_offsets.len()]) + } + + fn mask_points( + &self, + point: CirclePoint, + ) -> TreeVec>>> { + let trace_step = CanonicCoset::new(self.log_size()).step(); + let InfoEvaluator { mask_offsets, .. } = self.eval_info(); + mask_offsets.map_cols(|col_offsets| { + col_offsets + .iter() + .map(|offset| point + trace_step.mul_signed(*offset).into_ef()) + .collect() + }) + } + + fn evaluate_constraint_quotients_at_point( + &self, + point: CirclePoint, + mask: &TreeVec>>, + accumulator: &mut PointEvaluationAccumulator, + ) { + // Consistency check the MLE coeffs column polynomial and oracle. + let mle_coeff_col_eval = self.mle_coeff_column_poly.eval_at_point(point); + let oracle_mle_coeff_col_eval = self.mle_coeff_column_oracle.evaluate_at_point(point, mask); + assert_eq!(mle_coeff_col_eval, oracle_mle_coeff_col_eval); + + let component_mask = mask.sub_tree(&self.trace_locations); + let trace_coset = CanonicCoset::new(self.log_size()).coset; + let vanish_on_trace_eval_inv = coset_vanishing(trace_coset, point).inverse(); + let mut eval = PointEvaluator::new(component_mask, accumulator, vanish_on_trace_eval_inv); + + let carry_quotients_col_eval = eval_carry_quotient_col(&self.mle_eval_point, point); + let is_first = eval_is_first(trace_coset, point); + let is_second = eval_is_first(trace_coset, point - trace_coset.step.into_ef()); + + // TODO(andrew): Consider evaluating `is_first` and `is_second` inside + // `eval_mle_eval_constraints` once constant column approach updated. + eval_mle_eval_constraints( + self.interaction, + &mut eval, + mle_coeff_col_eval, + &self.mle_eval_point, + self.mle_claim_shift, + carry_quotients_col_eval, + is_first, + is_second, + ) + } +} + +impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> ComponentProver + for MleEvalProverComponent<'twiddles, 'oracle, O> +{ + fn evaluate_constraint_quotients_on_domain( + &self, + trace: &Trace<'_, SimdBackend>, + accumulator: &mut DomainEvaluationAccumulator, + ) { + let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain(); + let trace_domain = CanonicCoset::new(self.log_size()); + + let mut component_trace = trace.evals.sub_tree(&self.trace_locations).map_cols(|c| *c); + + // Build auxiliary trace. + let span = span!(Level::INFO, "Extension").entered(); + let mle_coeffs_column_lde = self + .mle_coeff_column_poly + .evaluate_with_twiddles(eval_domain, self.twiddles) + .into_coordinate_evals(); + let carry_quotients_column_lde = gen_carry_quotient_col(&self.mle_eval_point.p) + .interpolate_with_twiddles(self.twiddles) + .evaluate_with_twiddles(eval_domain, self.twiddles) + .into_coordinate_evals(); + let is_first_lde = gen_is_first::(self.log_size()) + .interpolate_with_twiddles(self.twiddles) + .evaluate_with_twiddles(eval_domain, self.twiddles); + let aux_interaction = component_trace.len(); + let aux_trace = chain![ + &mle_coeffs_column_lde, + &carry_quotients_column_lde, + [&is_first_lde] + ] + .collect(); + component_trace.push(aux_trace); + span.exit(); + + // Denom inverses. + let log_expand = eval_domain.log_size() - trace_domain.log_size(); + let mut denom_inv = (0..1 << log_expand) + .map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse()) + .collect_vec(); + utils::bit_reverse(&mut denom_inv); + + // Accumulator. + let [mut acc] = accumulator.columns([(eval_domain.log_size(), self.n_constraints())]); + acc.random_coeff_powers.reverse(); + let acc_col = unsafe { VeryPackedSecureColumnByCoords::transform_under_mut(acc.col) }; + + let _span = span!(Level::INFO, "Constraint pointwise eval").entered(); + let n_very_packed_rows = + 1 << (eval_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS); + for vec_row in 0..n_very_packed_rows { + // Evaluate constrains at row. + let mut eval = SimdDomainEvaluator::new( + &component_trace, + vec_row, + &acc.random_coeff_powers, + trace_domain.log_size(), + eval_domain.log_size(), + ); + let [mle_coeffs_col_eval] = eval.next_extension_interaction_mask(aux_interaction, [0]); + let [carry_quotients_col_eval] = + eval.next_extension_interaction_mask(aux_interaction, [0]); + let [is_first, is_second] = eval.next_interaction_mask(aux_interaction, [0, -1]); + eval_mle_eval_constraints( + self.interaction, + &mut eval, + mle_coeffs_col_eval, + &self.mle_eval_point, + self.mle_claim_shift, + carry_quotients_col_eval, + is_first, + is_second, + ); + + // Finalize row. + let row_res = eval.row_res; + let denom_inv = VeryPackedBaseField::broadcast( + denom_inv + [vec_row >> (trace_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)], + ); + unsafe { acc_col.set_packed(vec_row, acc_col.packed_at(vec_row) + row_res * denom_inv) } + } + } +} + +fn mle_eval_info(interaction: usize, n_variables: usize) -> InfoEvaluator { + let mut eval = InfoEvaluator::default(); + let mle_eval_point = MleEvalPoint::new(&vec![SecureField::from(2); n_variables]); + let mle_claim_shift = SecureField::zero(); + let mle_coeffs_col_eval = SecureField::zero(); + let carry_quotients_col_eval = SecureField::zero(); + let is_first = BaseField::zero(); + let is_second = BaseField::zero(); + eval_mle_eval_constraints( + interaction, + &mut eval, + mle_coeffs_col_eval, + &mle_eval_point, + mle_claim_shift, + carry_quotients_col_eval, + is_first, + is_second, + ); + eval +} + +/// Univariate polynomial oracle that encodes multilinear Lagrange basis coefficients of a MLE. +/// +/// The column should encode the MLE coefficients ordered on a circle domain. +pub trait MleCoeffColumnOracle { + fn evaluate_at_point( + &self, + point: CirclePoint, + mask: &TreeVec>>, + ) -> SecureField; +} /// Evaluates constraints that guarantee an MLE evaluates to a claim at a given point. /// /// `mle_coeffs_col_eval` should be the evaluation of the column containing the coefficients of the /// MLE in the multilinear Lagrange basis. `mle_claim_shift` should equal `claim / 2^N_VARIABLES`. #[allow(clippy::too_many_arguments)] -pub fn eval_mle_eval_constraints( +pub fn eval_mle_eval_constraints( interaction: usize, eval: &mut E, mle_coeffs_col_eval: E::EF, - mle_eval_point: MleEvalPoint, + mle_eval_point: &MleEvalPoint, mle_claim_shift: SecureField, carry_quotients_col_eval: E::EF, is_first: E::F, @@ -54,37 +325,49 @@ pub fn eval_mle_eval_constraints( eval_prefix_sum_constraints(interaction, eval, terms_col_eval, mle_claim_shift) } -#[derive(Debug, Clone, Copy)] -pub struct MleEvalPoint { +#[derive(Debug, Clone)] +pub struct MleEvalPoint { // Equals `eq({0}^|p|, p)`. eq_0_p: SecureField, // Equals `eq({1}^|p|, p)`. eq_1_p: SecureField, // Index `i` stores `eq(({1}^|i|, 0), p[0..i+1]) / eq(({0}^|i|, 1), p[0..i+1])`. - eq_carry_quotients: [SecureField; N_VARIABLES], + eq_carry_quotients: Vec, // Point `p`. - p: [SecureField; N_VARIABLES], + p: Vec, } -impl MleEvalPoint { +impl MleEvalPoint { /// Creates new metadata from point `p`. - pub fn new(p: [SecureField; N_VARIABLES]) -> Self { + /// + /// # Panics + /// + /// Panics if the point is empty or has a coordinate that is zero or one. + pub fn new(p: &[SecureField]) -> Self { + assert!(!p.is_empty()); + let n_variables = p.len(); let zero = SecureField::zero(); let one = SecureField::one(); Self { - eq_0_p: eq(&[zero; N_VARIABLES], &p), - eq_1_p: eq(&[one; N_VARIABLES], &p), - eq_carry_quotients: array::from_fn(|i| { - let mut numer_assignment = vec![one; i + 1]; - numer_assignment[i] = zero; - let mut denom_assignment = vec![zero; i + 1]; - denom_assignment[i] = one; - eq(&numer_assignment, &p[..i + 1]) / eq(&denom_assignment, &p[..i + 1]) - }), - p, + eq_0_p: eq(&vec![zero; n_variables], p), + eq_1_p: eq(&vec![one; n_variables], p), + eq_carry_quotients: (0..n_variables) + .map(|i| { + let mut numer_assignment = vec![one; i + 1]; + numer_assignment[i] = zero; + let mut denom_assignment = vec![zero; i + 1]; + denom_assignment[i] = one; + eq(&numer_assignment, &p[..i + 1]) / eq(&denom_assignment, &p[..i + 1]) + }) + .collect(), + p: p.to_vec(), } } + + pub fn n_variables(&self) -> usize { + self.p.len() + } } /// Evaluates EqEvals constraints on a column. @@ -95,10 +378,10 @@ impl MleEvalPoint { /// evaluates constraints that guarantee: `c(D[b0, b1, ...]) = eq((b0, b1, ...), (r0, r1, ...))`. /// /// See (Section 5.1). -fn eval_eq_constraints( +fn eval_eq_constraints( eq_interaction: usize, eval: &mut E, - mle_eval_point: MleEvalPoint, + mle_eval_point: &MleEvalPoint, carry_quotients_col_eval: E::EF, is_first: E::F, is_second: E::F, @@ -179,14 +462,15 @@ pub fn build_trace( /// `c(-C[i]) = c(-C[i + 1]) * q(-C[i])`. /// /// [`CircleDomain`]: crate::core::poly::circle::CircleDomain -fn gen_carry_quotient_col( - eval_point: &[SecureField; N_VARIABLES], +fn gen_carry_quotient_col( + eval_point: &[SecureField], ) -> SecureEvaluation { - let mle_eval_point = MleEvalPoint::new(*eval_point); + assert!(!eval_point.is_empty()); + let mle_eval_point = MleEvalPoint::new(eval_point); let (half_coset0_carry_quotients, half_coset1_carry_quotients) = gen_half_coset_carry_quotients(&mle_eval_point); - let log_size = N_VARIABLES as u32; + let log_size = mle_eval_point.n_variables() as u32; let size = 1 << log_size; let half_coset_size = size / 2; let mut col = SecureColumnByCoords::::zeros(size); @@ -216,11 +500,9 @@ fn gen_carry_quotient_col( // TODO(andrew): Optimize further. Inline `eval_step_selector` and get runtime down to // O(N_VARIABLES) vs current O(N_VARIABLES^2). Can also use vanishing evals to compute // half_coset0_last half_coset1_first. -fn eval_carry_quotient_col( - eval_point: &MleEvalPoint, - p: CirclePoint, -) -> SecureField { - let log_size = N_VARIABLES as u32; +fn eval_carry_quotient_col(eval_point: &MleEvalPoint, p: CirclePoint) -> SecureField { + let n_variables = eval_point.n_variables(); + let log_size = n_variables as u32; let coset = CanonicCoset::new(log_size).coset(); let (half_coset0_carry_quotients, half_coset1_carry_quotients) = @@ -228,7 +510,7 @@ fn eval_carry_quotient_col( let mut eval = SecureField::zero(); - for variable_i in 0..N_VARIABLES.saturating_sub(1) { + for variable_i in 0..n_variables.saturating_sub(1) { let log_step = variable_i as u32 + 2; let offset = (1 << (log_step - 1)) - 2; let half_coset0_selector = eval_step_selector_with_offset(coset, offset, log_step, p); @@ -265,7 +547,7 @@ fn eval_step_selector(coset: Coset, log_step: u32, p: CirclePoint) return SecureField::one(); } - // Rotate the coset to have points on the `x` axis. + // Rotate the coset so its first point is the identity element. let p = p - coset.initial.into_ef(); let mut vanish_at_log_step = (0..coset.log_size) .scan(p, |p, _| { @@ -292,14 +574,17 @@ fn eval_is_first(coset: Coset, p: CirclePoint) -> SecureField { } /// Output of the form: `(half_coset0_carry_quotients, half_coset1_carry_quotients)`. -fn gen_half_coset_carry_quotients( - eval_point: &MleEvalPoint, -) -> ([SecureField; N_VARIABLES], [SecureField; N_VARIABLES]) { +fn gen_half_coset_carry_quotients( + eval_point: &MleEvalPoint, +) -> (Vec, Vec) { let last_variable = *eval_point.p.last().unwrap(); - let mut half_coset0_carry_quotients = eval_point.eq_carry_quotients; + let mut half_coset0_carry_quotients = eval_point.eq_carry_quotients.clone(); *half_coset0_carry_quotients.last_mut().unwrap() *= eq(&[SecureField::one()], &[last_variable]) / eq(&[SecureField::zero()], &[last_variable]); - let half_coset1_carry_quotients = half_coset0_carry_quotients.map(|v| v.inverse()); + let half_coset1_carry_quotients = half_coset0_carry_quotients + .iter() + .map(|v| v.inverse()) + .collect(); (half_coset0_carry_quotients, half_coset1_carry_quotients) } @@ -321,6 +606,7 @@ mod tests { use std::iter::{repeat, zip}; use itertools::{chain, Itertools}; + use mle_coeff_column::{MleCoeffColumnComponent, MleCoeffColumnEval}; use num_traits::One; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; @@ -330,23 +616,91 @@ mod tests { eval_prefix_sum_constraints, gen_carry_quotient_col, MleEvalPoint, }; use crate::constraint_framework::constant_columns::{gen_is_first, gen_is_step_with_offset}; - use crate::constraint_framework::{assert_constraints, EvalAtRow}; + use crate::constraint_framework::{assert_constraints, EvalAtRow, TraceLocationAllocator}; + use crate::core::air::{Component, ComponentProver, Components}; use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; use crate::core::backend::simd::qm31::PackedSecureField; use crate::core::backend::simd::SimdBackend; + use crate::core::channel::Blake2sChannel; use crate::core::circle::SECURE_FIELD_CIRCLE_GEN; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumnByCoords; use crate::core::lookups::mle::Mle; - use crate::core::pcs::TreeVec; - use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps, SecureEvaluation}; + use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec}; + use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; + use crate::core::prover::{prove, verify, VerificationError}; use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; + use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; + use crate::examples::xor::gkr_lookups::accumulation::MIN_LOG_BLOWUP_FACTOR; use crate::examples::xor::gkr_lookups::mle_eval::{ - build_trace, eval_step_selector_with_offset, + build_trace, eval_step_selector_with_offset, MleEvalProverComponent, }; + #[test] + fn mle_eval_prover_component() -> Result<(), VerificationError> { + const N_VARIABLES: usize = 8; + const COEFFS_COL_TRACE: usize = 0; + const MLE_EVAL_TRACE: usize = 1; + const LOG_EXPAND: u32 = 1; + // Create the test MLE. + let mut rng = SmallRng::seed_from_u64(0); + let log_size = N_VARIABLES as u32; + let size = 1 << log_size; + let mle_coeffs = (0..size).map(|_| rng.gen::()).collect(); + let mle = Mle::::new(mle_coeffs); + let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); + let claim = mle.eval_at_point(&eval_point); + // Setup protocol. + let twiddles = SimdBackend::precompute_twiddles( + CanonicCoset::new(log_size + LOG_EXPAND + MIN_LOG_BLOWUP_FACTOR) + .circle_domain() + .half_coset, + ); + let config = PcsConfig::default(); + let commitment_scheme = + &mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); + let channel = &mut Blake2sChannel::default(); + // Build trace. + // 1. MLE coeffs trace. + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(mle_coeff_column::build_trace(&mle)); + tree_builder.commit(channel); + // 2. MLE eval trace (eq evals + prefix sum). + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(build_trace(&mle, &eval_point, claim)); + tree_builder.commit(channel); + // Create components. + let trace_location_allocator = &mut TraceLocationAllocator::default(); + let mle_coeffs_col_component = MleCoeffColumnComponent::new( + trace_location_allocator, + MleCoeffColumnEval::new(COEFFS_COL_TRACE, mle.n_variables()), + ); + let mle_eval_component = MleEvalProverComponent::generate( + trace_location_allocator, + &mle_coeffs_col_component, + &eval_point, + mle, + claim, + &twiddles, + MLE_EVAL_TRACE, + ); + let components: &[&dyn ComponentProver] = + &[&mle_coeffs_col_component, &mle_eval_component]; + // Generate proof. + let proof = prove(components, channel, commitment_scheme).unwrap(); + + // Verify. + let components = Components(components.iter().map(|&c| c as &dyn Component).collect()); + let log_sizes = components.column_log_sizes(); + let channel = &mut Blake2sChannel::default(); + let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); + commitment_scheme.commit(proof.commitments[0], &log_sizes[0], channel); + commitment_scheme.commit(proof.commitments[1], &log_sizes[1], channel); + verify(&components.0, channel, commitment_scheme, proof) + } + #[test] fn test_mle_eval_constraints_with_log_size_5() { const N_VARIABLES: usize = 5; @@ -360,9 +714,9 @@ mod tests { let mle = Mle::::new(mle_coeffs); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); let claim = mle.eval_at_point(&eval_point); - let mle_eval_point = MleEvalPoint::new(eval_point); + let mle_eval_point = MleEvalPoint::new(&eval_point); let mle_eval_trace = build_trace(&mle, &eval_point, claim); - let mle_coeffs_col_trace = build_mle_coeffs_trace(mle); + let mle_coeffs_col_trace = mle_coeff_column::build_trace(&mle); let claim_shift = claim / BaseField::from(size); let carry_quotients_col = gen_carry_quotient_col(&eval_point).into_coordinate_evals(); let is_first_col = [gen_is_first(log_size)]; @@ -379,7 +733,7 @@ mod tests { MLE_EVAL_TRACE, &mut eval, mle_coeff_col_eval, - mle_eval_point, + &mle_eval_point, claim_shift, carry_quotients_col_eval, is_first_eval, @@ -397,14 +751,14 @@ mod tests { let mut rng = SmallRng::seed_from_u64(0); let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); - let mle_eval_point = MleEvalPoint::new(eval_point); + let mle_eval_point = MleEvalPoint::new(&eval_point); let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point)); let carry_quotients_col = gen_carry_quotient_col(&eval_point).into_coordinate_evals(); let is_first_col = [gen_is_first(N_VARIABLES as u32)]; let aux_trace = chain![carry_quotients_col, is_first_col].collect(); let traces = TreeVec::new(vec![trace, aux_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); - let trace_domain = CanonicCoset::new(eval_point.len() as u32); + let trace_domain = CanonicCoset::new(N_VARIABLES as u32); assert_constraints(&trace_polys, trace_domain, |mut eval| { let [carry_quotients_col_eval] = eval.next_extension_interaction_mask(AUX_TRACE, [0]); @@ -412,7 +766,7 @@ mod tests { eval_eq_constraints( EQ_EVAL_TRACE, &mut eval, - mle_eval_point, + &mle_eval_point, carry_quotients_col_eval, is_first, is_second, @@ -428,14 +782,14 @@ mod tests { let mut rng = SmallRng::seed_from_u64(0); let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); - let mle_eval_point = MleEvalPoint::new(eval_point); + let mle_eval_point = MleEvalPoint::new(&eval_point); let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point)); let carry_quotients_col = gen_carry_quotient_col(&eval_point).into_coordinate_evals(); let is_first_col = [gen_is_first(N_VARIABLES as u32)]; let aux_trace = chain![carry_quotients_col, is_first_col].collect(); let traces = TreeVec::new(vec![trace, aux_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); - let trace_domain = CanonicCoset::new(eval_point.len() as u32); + let trace_domain = CanonicCoset::new(N_VARIABLES as u32); assert_constraints(&trace_polys, trace_domain, |mut eval| { let [carry_quotients_col_eval] = eval.next_extension_interaction_mask(AUX_TRACE, [0]); @@ -443,7 +797,7 @@ mod tests { eval_eq_constraints( EQ_EVAL_TRACE, &mut eval, - mle_eval_point, + &mle_eval_point, carry_quotients_col_eval, is_first, is_second, @@ -459,14 +813,14 @@ mod tests { let mut rng = SmallRng::seed_from_u64(0); let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); - let mle_eval_point = MleEvalPoint::new(eval_point); + let mle_eval_point = MleEvalPoint::new(&eval_point); let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point)); let carry_quotients_col = gen_carry_quotient_col(&eval_point).into_coordinate_evals(); let is_first_col = [gen_is_first(N_VARIABLES as u32)]; let aux_trace = chain![carry_quotients_col, is_first_col].collect(); let traces = TreeVec::new(vec![trace, aux_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); - let trace_domain = CanonicCoset::new(eval_point.len() as u32); + let trace_domain = CanonicCoset::new(N_VARIABLES as u32); assert_constraints(&trace_polys, trace_domain, |mut eval| { let [carry_quotients_col_eval] = eval.next_extension_interaction_mask(AUX_TRACE, [0]); @@ -474,7 +828,7 @@ mod tests { eval_eq_constraints( EQ_EVAL_TRACE, &mut eval, - mle_eval_point, + &mle_eval_point, carry_quotients_col_eval, is_first, is_second, @@ -519,7 +873,7 @@ mod tests { const N_VARIABLES: usize = 5; let mut rng = SmallRng::seed_from_u64(0); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); - let mle_eval_point = MleEvalPoint::new(eval_point); + let mle_eval_point = MleEvalPoint::new(&eval_point); let col_eval = gen_carry_quotient_col(&eval_point); let twiddles = SimdBackend::precompute_twiddles(col_eval.domain.half_coset); let col_poly = col_eval.interpolate_with_twiddles(&twiddles); @@ -570,26 +924,99 @@ mod tests { .collect() } - /// Generates a trace. - /// - /// Trace structure: - /// - /// ```text - /// ----------------------------- - /// | MLE coeffs col | - /// ----------------------------- - /// | c0 | c1 | c2 | c3 | - /// ----------------------------- - /// ``` - fn build_mle_coeffs_trace( - mle: Mle, - ) -> Vec> { - let log_size = mle.n_variables() as u32; - let trace_domain = CanonicCoset::new(log_size).circle_domain(); - let mle_coeffs_col_by_coords = mle.into_evals().into_secure_column_by_coords(); - SecureEvaluation::new(trace_domain, mle_coeffs_col_by_coords) - .into_coordinate_evals() - .into_iter() - .collect() + mod mle_coeff_column { + use num_traits::One; + + use crate::constraint_framework::{ + EvalAtRow, FrameworkComponent, FrameworkEval, PointEvaluator, + }; + use crate::core::air::accumulation::PointEvaluationAccumulator; + use crate::core::backend::simd::SimdBackend; + use crate::core::circle::CirclePoint; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::lookups::mle::Mle; + use crate::core::pcs::TreeVec; + use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureEvaluation}; + use crate::core::poly::BitReversedOrder; + use crate::core::ColumnVec; + use crate::examples::xor::gkr_lookups::mle_eval::MleCoeffColumnOracle; + + pub type MleCoeffColumnComponent = FrameworkComponent; + + pub struct MleCoeffColumnEval { + interaction: usize, + n_variables: usize, + } + + impl MleCoeffColumnEval { + pub fn new(interaction: usize, n_variables: usize) -> Self { + Self { + interaction, + n_variables, + } + } + } + + impl FrameworkEval for MleCoeffColumnEval { + fn log_size(&self) -> u32 { + self.n_variables as u32 + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size() + } + + fn evaluate(&self, mut eval: E) -> E { + let _ = eval_mle_coeff_col(self.interaction, &mut eval); + eval + } + } + + impl MleCoeffColumnOracle for MleCoeffColumnComponent { + fn evaluate_at_point( + &self, + _point: CirclePoint, + mask: &TreeVec>>, + ) -> SecureField { + // Create dummy point evaluator just to extract the value we need from the mask + let mut accumulator = PointEvaluationAccumulator::new(SecureField::one()); + let mut eval = PointEvaluator::new( + mask.sub_tree(self.trace_locations()), + &mut accumulator, + SecureField::one(), + ); + + eval_mle_coeff_col(self.interaction, &mut eval) + } + } + + fn eval_mle_coeff_col(interaction: usize, eval: &mut E) -> E::EF { + let [mle_coeff_col_eval] = eval.next_extension_interaction_mask(interaction, [0]); + mle_coeff_col_eval + } + + /// Generates a trace. + /// + /// Trace structure: + /// + /// ```text + /// ----------------------------- + /// | MLE coeffs col | + /// ----------------------------- + /// | c0 | c1 | c2 | c3 | + /// ----------------------------- + /// ``` + pub fn build_trace( + mle: &Mle, + ) -> Vec> { + let log_size = mle.n_variables() as u32; + let trace_domain = CanonicCoset::new(log_size).circle_domain(); + let mle_coeffs_col_by_coords = mle.clone().into_evals().into_secure_column_by_coords(); + SecureEvaluation::new(trace_domain, mle_coeffs_col_by_coords) + .into_coordinate_evals() + .into_iter() + .collect() + } } } From a969689129afec32cca298eef228f0c7a4d94d0f Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Sat, 24 Aug 2024 20:17:27 -0400 Subject: [PATCH 2/3] Create MLE eval verifier component --- .../src/examples/xor/gkr_lookups/mle_eval.rs | 198 +++++++++++++++++- 1 file changed, 191 insertions(+), 7 deletions(-) diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs index 0026baeac..9d05ce7aa 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -39,7 +39,7 @@ use crate::core::poly::BitReversedOrder; use crate::core::utils::{self, bit_reverse_index, coset_index_to_circle_domain_index}; use crate::core::ColumnVec; -/// Component that carries out a univariate IOP for multilinear eval at point. +/// Prover component that carries out a univariate IOP for multilinear eval at point. /// /// See (Section 5.1). #[allow(dead_code)] @@ -103,7 +103,7 @@ impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> MleEvalProverComponent<'twiddl } } - /// Size of this components trace columns. + /// Size of this component's trace columns. pub fn log_size(&self) -> u32 { self.mle_eval_point.n_variables() as u32 } @@ -266,6 +266,115 @@ impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> ComponentProver } } +/// Verifier component that carries out a univariate IOP for multilinear eval at point. +/// +/// See (Section 5.1). +pub struct MleEvalVerifierComponent<'oracle, O: MleCoeffColumnOracle> { + /// Oracle for the polynomial encoding the multilinear Lagrange basis coefficients of the MLE. + mle_coeff_column_oracle: &'oracle O, + /// Multilinear evaluation point. + mle_eval_point: MleEvalPoint, + /// Equals `mle_claim / 2^mle_n_variables`. + mle_claim_shift: SecureField, + /// Commitment tree index for the trace. + interaction: usize, + /// Location in the trace for the this component. + trace_location: TreeVec, +} + +impl<'oracle, O: MleCoeffColumnOracle> MleEvalVerifierComponent<'oracle, O> { + pub fn new( + location_allocator: &mut TraceLocationAllocator, + mle_coeff_column_oracle: &'oracle O, + eval_point: &[SecureField], + claim: SecureField, + interaction: usize, + ) -> Self { + let mle_eval_point = MleEvalPoint::new(eval_point); + let n_variables = mle_eval_point.n_variables(); + let mle_claim_shift = claim / BaseField::from(1 << n_variables); + + let trace_structure = mle_eval_info(interaction, n_variables).mask_offsets; + let trace_location = location_allocator.next_for_structure(&trace_structure); + + Self { + mle_coeff_column_oracle, + mle_eval_point, + mle_claim_shift, + interaction, + trace_location, + } + } + + /// Size of this component's trace columns. + pub fn log_size(&self) -> u32 { + self.mle_eval_point.n_variables() as u32 + } + + pub fn eval_info(&self) -> InfoEvaluator { + let n_variables = self.mle_eval_point.n_variables(); + mle_eval_info(self.interaction, n_variables) + } +} + +impl<'oracle, O: MleCoeffColumnOracle> Component for MleEvalVerifierComponent<'oracle, O> { + fn n_constraints(&self) -> usize { + self.eval_info().n_constraints + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size() + 1 + } + + fn trace_log_degree_bounds(&self) -> TreeVec> { + let log_size = self.log_size(); + let InfoEvaluator { mask_offsets, .. } = self.eval_info(); + mask_offsets.map(|tree_offsets| vec![log_size; tree_offsets.len()]) + } + + fn mask_points( + &self, + point: CirclePoint, + ) -> TreeVec>>> { + let trace_step = CanonicCoset::new(self.log_size()).step(); + let InfoEvaluator { mask_offsets, .. } = self.eval_info(); + mask_offsets.map_cols(|col_offsets| { + col_offsets + .iter() + .map(|offset| point + trace_step.mul_signed(*offset).into_ef()) + .collect() + }) + } + + fn evaluate_constraint_quotients_at_point( + &self, + point: CirclePoint, + mask: &TreeVec>>, + accumulator: &mut PointEvaluationAccumulator, + ) { + let component_mask = mask.sub_tree(&self.trace_location); + let trace_coset = CanonicCoset::new(self.log_size()).coset; + let vanish_on_trace_eval_inv = coset_vanishing(trace_coset, point).inverse(); + let mut eval = PointEvaluator::new(component_mask, accumulator, vanish_on_trace_eval_inv); + + let mle_coeff_col_eval = self.mle_coeff_column_oracle.evaluate_at_point(point, mask); + let carry_quotients_col_eval = eval_carry_quotient_col(&self.mle_eval_point, point); + let is_first = eval_is_first(trace_coset, point); + let is_second = eval_is_first(trace_coset, point - trace_coset.step.into_ef()); + + eval_mle_eval_constraints( + self.interaction, + &mut eval, + mle_coeff_col_eval, + &self.mle_eval_point, + self.mle_claim_shift, + carry_quotients_col_eval, + is_first, + is_second, + ) + } +} + fn mle_eval_info(interaction: usize, n_variables: usize) -> InfoEvaluator { let mut eval = InfoEvaluator::default(); let mle_eval_point = MleEvalPoint::new(&vec![SecureField::from(2); n_variables]); @@ -612,8 +721,9 @@ mod tests { use rand::{Rng, SeedableRng}; use super::{ - eval_carry_quotient_col, eval_eq_constraints, eval_mle_eval_constraints, - eval_prefix_sum_constraints, gen_carry_quotient_col, MleEvalPoint, + build_trace, eval_carry_quotient_col, eval_eq_constraints, eval_mle_eval_constraints, + eval_prefix_sum_constraints, gen_carry_quotient_col, MleEvalPoint, MleEvalProverComponent, + MleEvalVerifierComponent, }; use crate::constraint_framework::constant_columns::{gen_is_first, gen_is_step_with_offset}; use crate::constraint_framework::{assert_constraints, EvalAtRow, TraceLocationAllocator}; @@ -634,9 +744,7 @@ mod tests { use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; use crate::examples::xor::gkr_lookups::accumulation::MIN_LOG_BLOWUP_FACTOR; - use crate::examples::xor::gkr_lookups::mle_eval::{ - build_trace, eval_step_selector_with_offset, MleEvalProverComponent, - }; + use crate::examples::xor::gkr_lookups::mle_eval::eval_step_selector_with_offset; #[test] fn mle_eval_prover_component() -> Result<(), VerificationError> { @@ -701,6 +809,82 @@ mod tests { verify(&components.0, channel, commitment_scheme, proof) } + #[test] + fn mle_eval_verifier_component() -> Result<(), VerificationError> { + const N_VARIABLES: usize = 8; + const COEFFS_COL_TRACE: usize = 0; + const MLE_EVAL_TRACE: usize = 1; + const CONST_TRACE: usize = 2; + const LOG_EXPAND: u32 = 1; + // Create the test MLE. + let mut rng = SmallRng::seed_from_u64(0); + let log_size = N_VARIABLES as u32; + let size = 1 << log_size; + let mle_coeffs = (0..size).map(|_| rng.gen::()).collect(); + let mle = Mle::::new(mle_coeffs); + let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); + let claim = mle.eval_at_point(&eval_point); + // Setup protocol. + let twiddles = SimdBackend::precompute_twiddles( + CanonicCoset::new(log_size + LOG_EXPAND + MIN_LOG_BLOWUP_FACTOR) + .circle_domain() + .half_coset, + ); + let config = PcsConfig::default(); + let commitment_scheme = + &mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); + let channel = &mut Blake2sChannel::default(); + // Build trace. + // 1. MLE coeffs trace. + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(mle_coeff_column::build_trace(&mle)); + tree_builder.commit(channel); + // 2. MLE eval trace (eq evals + prefix sum). + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(build_trace(&mle, &eval_point, claim)); + tree_builder.commit(channel); + // Create components. + let trace_location_allocator = &mut TraceLocationAllocator::default(); + let mle_coeffs_col_component = MleCoeffColumnComponent::new( + trace_location_allocator, + MleCoeffColumnEval::new(COEFFS_COL_TRACE, mle.n_variables()), + ); + let mle_eval_component = MleEvalProverComponent::generate( + trace_location_allocator, + &mle_coeffs_col_component, + &eval_point, + mle, + claim, + &twiddles, + MLE_EVAL_TRACE, + ); + let components: &[&dyn ComponentProver] = + &[&mle_coeffs_col_component, &mle_eval_component]; + // Generate proof. + let proof = prove(components, channel, commitment_scheme).unwrap(); + + // Verify. + let trace_location_allocator = &mut TraceLocationAllocator::default(); + let mle_coeffs_col_component = MleCoeffColumnComponent::new( + trace_location_allocator, + MleCoeffColumnEval::new(COEFFS_COL_TRACE, N_VARIABLES), + ); + let mle_eval_component = MleEvalVerifierComponent::new( + trace_location_allocator, + &mle_coeffs_col_component, + &eval_point, + claim, + MLE_EVAL_TRACE, + ); + let components = Components(vec![&mle_coeffs_col_component, &mle_eval_component]); + let log_sizes = components.column_log_sizes(); + let channel = &mut Blake2sChannel::default(); + let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); + commitment_scheme.commit(proof.commitments[0], &log_sizes[0], channel); + commitment_scheme.commit(proof.commitments[1], &log_sizes[1], channel); + verify(&components.0, channel, commitment_scheme, proof) + } + #[test] fn test_mle_eval_constraints_with_log_size_5() { const N_VARIABLES: usize = 5; From e56d31ec2e698a9096c6adb1c8ebeec181bdc448 Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Sun, 25 Aug 2024 13:03:34 -0400 Subject: [PATCH 3/3] Create blake component that uses GKR for lookups --- .../src/constraint_framework/component.rs | 14 +- .../prover/src/constraint_framework/logup.rs | 2 +- crates/prover/src/constraint_framework/mod.rs | 4 + .../prover/src/constraint_framework/point.rs | 53 ++ .../src/core/backend/cpu/lookups/gkr.rs | 4 + .../src/core/backend/simd/lookups/gkr.rs | 4 + crates/prover/src/core/lookups/gkr_prover.rs | 13 +- .../prover/src/core/lookups/gkr_verifier.rs | 109 +++- crates/prover/src/examples/blake/air.rs | 31 +- crates/prover/src/examples/blake/mod.rs | 52 +- crates/prover/src/examples/blake/round/gen.rs | 22 +- crates/prover/src/examples/blake/round/mod.rs | 5 +- .../examples/blake/scheduler/constraints.rs | 2 +- .../src/examples/blake/scheduler/gen.rs | 8 +- .../src/examples/blake/scheduler/mod.rs | 3 +- .../src/examples/blake/xor_table/gen.rs | 2 +- .../src/examples/blake/xor_table/mod.rs | 6 +- crates/prover/src/examples/blake_gkr/air.rs | 526 ++++++++++++++++++ .../gkr_lookups/accumulation.rs | 34 +- .../gkr_lookups/mle_eval.rs | 41 +- .../src/examples/blake_gkr/gkr_lookups/mod.rs | 54 ++ crates/prover/src/examples/blake_gkr/mod.rs | 5 + crates/prover/src/examples/blake_gkr/round.rs | 337 +++++++++++ .../src/examples/blake_gkr/scheduler.rs | 209 +++++++ .../src/examples/blake_gkr/xor_table.rs | 328 +++++++++++ crates/prover/src/examples/mod.rs | 2 +- .../src/examples/xor/gkr_lookups/mod.rs | 2 - crates/prover/src/examples/xor/mod.rs | 1 - 28 files changed, 1789 insertions(+), 84 deletions(-) create mode 100644 crates/prover/src/examples/blake_gkr/air.rs rename crates/prover/src/examples/{xor => blake_gkr}/gkr_lookups/accumulation.rs (84%) rename crates/prover/src/examples/{xor => blake_gkr}/gkr_lookups/mle_eval.rs (96%) create mode 100644 crates/prover/src/examples/blake_gkr/gkr_lookups/mod.rs create mode 100644 crates/prover/src/examples/blake_gkr/mod.rs create mode 100644 crates/prover/src/examples/blake_gkr/round.rs create mode 100644 crates/prover/src/examples/blake_gkr/scheduler.rs create mode 100644 crates/prover/src/examples/blake_gkr/xor_table.rs delete mode 100644 crates/prover/src/examples/xor/gkr_lookups/mod.rs delete mode 100644 crates/prover/src/examples/xor/mod.rs diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index 4c7f75768..3066b8e66 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -5,7 +5,7 @@ use std::ops::Deref; use itertools::Itertools; use tracing::{span, Level}; -use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator}; +use super::{EvalAtRow, EvalAtRowWithMle, InfoEvaluator, PointEvaluator, SimdDomainEvaluator}; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::{Component, ComponentProver, Trace}; use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords; @@ -57,6 +57,18 @@ impl TraceLocationAllocator { } } +/// A component defined solely in means of the constraints framework. +/// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for +/// the SIMD backend. +/// Note that the constraint framework only support components with columns of the same size. +pub trait FrameworkEvalWithMle { + fn log_size(&self) -> u32; + + fn max_constraint_log_degree_bound(&self) -> u32; + + fn evaluate(&self, eval: E) -> E; +} + /// A component defined solely in means of the constraints framework. /// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for /// the SIMD backend. diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index a608d89b0..ef30033b6 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -107,7 +107,7 @@ impl LookupElements { } pub fn combine(&self, values: &[F]) -> EF where - EF: Copy + Zero + From + From + Mul + Sub, + EF: Copy + Zero + From + From + Mul + Sub, { zip_eq(values, self.alpha_powers).fold(EF::zero(), |acc, (&value, power)| { acc + EF::from(power) * value diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 87069d344..db7a1d3e8 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -95,3 +95,7 @@ pub trait EvalAtRow { /// Combines 4 base field values into a single extension field value. fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF; } + +trait EvalAtRowWithMle: EvalAtRow { + fn add_mle_coeff_col_eval(&mut self, eval: Self::EF); +} diff --git a/crates/prover/src/constraint_framework/point.rs b/crates/prover/src/constraint_framework/point.rs index 6c6f72f81..a91b8ff7b 100644 --- a/crates/prover/src/constraint_framework/point.rs +++ b/crates/prover/src/constraint_framework/point.rs @@ -55,3 +55,56 @@ impl<'a> EvalAtRow for PointEvaluator<'a> { SecureField::from_partial_evals(values) } } + +// /// Evaluates expressions at a point out of domain. +// pub struct MleCoeffColEvalAccumulator<'a> { +// pub mask: TreeVec>>, +// pub evaluation_accumulator: &'a mut PointEvaluationAccumulator, +// pub col_index: Vec, +// pub denom_inverse: SecureField, +// } +// impl<'a> MleCoeffColEvalAccumulator<'a> { +// pub fn new( +// mask: TreeVec>>, +// evaluation_accumulator: &'a mut PointEvaluationAccumulator, +// denom_inverse: SecureField, +// ) -> Self { +// let col_index = vec![0; mask.len()]; +// Self { +// mask, +// evaluation_accumulator, +// col_index, +// denom_inverse, +// } +// } +// } +// impl<'a> EvalAtRow for MleCoeffColEvalAccumulator<'a> { +// type F = SecureField; +// type EF = SecureField; + +// fn next_interaction_mask( +// &mut self, +// interaction: usize, +// _offsets: [isize; N], +// ) -> [Self::F; N] { +// let col_index = self.col_index[interaction]; +// self.col_index[interaction] += 1; +// let mask = self.mask[interaction][col_index].clone(); +// assert_eq!(mask.len(), N); +// mask.try_into().unwrap() +// } +// fn add_constraint(&mut self, constraint: G) +// where +// Self::EF: Mul, +// { +// } +// fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { +// SecureField::from_partial_evals(values) +// } +// } + +// impl<'a> EvalAtRowWithMle for MleCoeffColEvalAccumulator<'a> { +// fn add_mle_coeff_col_eval(&mut self, eval: Self::EF) { +// todo!() +// } +// } diff --git a/crates/prover/src/core/backend/cpu/lookups/gkr.rs b/crates/prover/src/core/backend/cpu/lookups/gkr.rs index ae9ab6b65..cd3f5937e 100644 --- a/crates/prover/src/core/backend/cpu/lookups/gkr.rs +++ b/crates/prover/src/core/backend/cpu/lookups/gkr.rs @@ -326,6 +326,7 @@ mod tests { let GkrArtifact { ood_point: r, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?; @@ -354,6 +355,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; @@ -391,6 +393,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; @@ -427,6 +430,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; diff --git a/crates/prover/src/core/backend/simd/lookups/gkr.rs b/crates/prover/src/core/backend/simd/lookups/gkr.rs index 017948dee..7ee7b268e 100644 --- a/crates/prover/src/core/backend/simd/lookups/gkr.rs +++ b/crates/prover/src/core/backend/simd/lookups/gkr.rs @@ -559,6 +559,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?; @@ -590,6 +591,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; @@ -629,6 +631,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; @@ -666,6 +669,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; diff --git a/crates/prover/src/core/lookups/gkr_prover.rs b/crates/prover/src/core/lookups/gkr_prover.rs index 6e6ed2586..d3d792ccb 100644 --- a/crates/prover/src/core/lookups/gkr_prover.rs +++ b/crates/prover/src/core/lookups/gkr_prover.rs @@ -8,7 +8,7 @@ use itertools::Itertools; use num_traits::{One, Zero}; use thiserror::Error; -use super::gkr_verifier::{GkrArtifact, GkrBatchProof, GkrMask}; +use super::gkr_verifier::{Gate, GkrArtifact, GkrBatchProof, GkrMask}; use super::mle::{Mle, MleOps}; use super::sumcheck::MultivariatePolyOracle; use super::utils::{eq, random_linear_combination, UnivariatePoly}; @@ -409,6 +409,16 @@ pub fn prove_batch( .collect_vec(); let n_layers = *n_layers_by_instance.iter().max().unwrap(); + let gate_by_instance = input_layer_by_instance + .iter() + .map(|l| match l { + Layer::GrandProduct(_) => Gate::GrandProduct, + Layer::LogUpGeneric { .. } + | Layer::LogUpMultiplicities { .. } + | Layer::LogUpSingles { .. } => Gate::LogUp, + }) + .collect(); + // Evaluate all instance circuits and collect the layer values. let mut layers_by_instance = input_layer_by_instance .into_iter() @@ -502,6 +512,7 @@ pub fn prove_batch( let artifact = GkrArtifact { ood_point, + gate_by_instance, claims_to_verify_by_instance, n_variables_by_instance: n_layers_by_instance, }; diff --git a/crates/prover/src/core/lookups/gkr_verifier.rs b/crates/prover/src/core/lookups/gkr_verifier.rs index b65ceb162..4d17c15dc 100644 --- a/crates/prover/src/core/lookups/gkr_verifier.rs +++ b/crates/prover/src/core/lookups/gkr_verifier.rs @@ -143,6 +143,7 @@ pub fn partially_verify_batch( Ok(GkrArtifact { ood_point, + gate_by_instance, claims_to_verify_by_instance, n_variables_by_instance: (0..n_instances).map(instance_n_layers).collect(), }) @@ -162,12 +163,114 @@ pub struct GkrBatchProof { pub struct GkrArtifact { /// Out-of-domain (OOD) point for evaluating columns in the input layer. pub ood_point: Vec, + /// The gate of each instance. + pub gate_by_instance: Vec, /// The claimed evaluation at `ood_point` for each column in the input layer of each instance. pub claims_to_verify_by_instance: Vec>, /// The number of variables that interpolate the input layer of each instance. pub n_variables_by_instance: Vec, } +impl GkrArtifact { + pub fn ood_point(&self, instance_n_variables: usize) -> &[SecureField] { + &self.ood_point[self.ood_point.len() - instance_n_variables..] + } +} + +pub struct LookupArtifactInstanceIter<'proof, 'artifact> { + instance: usize, + gkr_proof: &'proof GkrBatchProof, + gkr_artifact: &'artifact GkrArtifact, +} + +impl<'proof, 'artifact> LookupArtifactInstanceIter<'proof, 'artifact> { + pub fn new(gkr_proof: &'proof GkrBatchProof, gkr_artifact: &'artifact GkrArtifact) -> Self { + Self { + instance: 0, + gkr_proof, + gkr_artifact, + } + } +} + +impl<'proof, 'artifact> Iterator for LookupArtifactInstanceIter<'proof, 'artifact> { + type Item = LookupArtifactInstance; + + fn next(&mut self) -> Option { + if self.instance >= self.gkr_proof.output_claims_by_instance.len() { + return None; + } + + let instance = self.instance; + let input_n_variables = self.gkr_artifact.n_variables_by_instance[instance]; + let eval_point = self.gkr_artifact.ood_point(input_n_variables).to_vec(); + let output_claim = &*self.gkr_proof.output_claims_by_instance[instance]; + let input_claims = &*self.gkr_artifact.claims_to_verify_by_instance[instance]; + let gate = self.gkr_artifact.gate_by_instance[instance]; + + let res = Some(match gate { + Gate::LogUp => { + let [numerator, denominator] = output_claim.try_into().unwrap(); + let claimed_sum = Fraction::new(numerator, denominator); + let [input_numerators_claim, input_denominators_claim] = + input_claims.try_into().unwrap(); + + LookupArtifactInstance::LogUp(LogUpArtifactInstance { + eval_point, + input_n_variables, + input_numerators_claim, + input_denominators_claim, + claimed_sum, + }) + } + Gate::GrandProduct => { + let [claimed_product] = output_claim.try_into().unwrap(); + let [input_claim] = input_claims.try_into().unwrap(); + + LookupArtifactInstance::GrandProduct(GrandProductArtifactInstance { + eval_point, + input_n_variables, + input_claim, + claimed_product, + }) + } + }); + + self.instance += 1; + res + } +} + +// TODO: Consider making the GKR artifact just a Vec. +pub enum LookupArtifactInstance { + GrandProduct(GrandProductArtifactInstance), + LogUp(LogUpArtifactInstance), +} + +pub struct GrandProductArtifactInstance { + /// GKR input layer eval point. + pub eval_point: Vec, + /// Number of variables the MLE in the GKR input layer had. + pub input_n_variables: usize, + /// Claimed input MLE evaluation at `eval_point`. + pub input_claim: SecureField, + /// Output claim from the circuit. + pub claimed_product: SecureField, +} + +pub struct LogUpArtifactInstance { + /// GKR input layer eval point. + pub eval_point: Vec, + /// Number of variables the MLEs in the GKR input layer had. + pub input_n_variables: usize, + /// Claimed input numerators MLE evaluation at `eval_point`. + pub input_numerators_claim: SecureField, + /// Claimed input denominators MLE evaluation at `eval_point`. + pub input_denominators_claim: SecureField, + /// Output claim from the circuit. + pub claimed_sum: Fraction, +} + /// Defines how a circuit operates locally on two input rows to produce a single output row. /// This local 2-to-1 constraint is what gives the whole circuit its "binary tree" structure. /// @@ -176,7 +279,7 @@ pub struct GkrArtifact { /// circuit) GKR prover implementations. /// /// [Thaler13]: https://eprint.iacr.org/2013/351.pdf -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Gate { LogUp, GrandProduct, @@ -305,11 +408,13 @@ mod tests { let GkrArtifact { ood_point, + gate_by_instance, claims_to_verify_by_instance, n_variables_by_instance, } = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?; assert_eq!(n_variables_by_instance, [LOG_N, LOG_N]); + assert_eq!(gate_by_instance, [Gate::GrandProduct, Gate::GrandProduct]); assert_eq!(proof.output_claims_by_instance.len(), 2); assert_eq!(claims_to_verify_by_instance.len(), 2); assert_eq!(proof.output_claims_by_instance[0], &[product0]); @@ -338,11 +443,13 @@ mod tests { let GkrArtifact { ood_point, + gate_by_instance, claims_to_verify_by_instance, n_variables_by_instance, } = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?; assert_eq!(n_variables_by_instance, [LOG_N0, LOG_N1]); + assert_eq!(gate_by_instance, [Gate::GrandProduct, Gate::GrandProduct]); assert_eq!(proof.output_claims_by_instance.len(), 2); assert_eq!(claims_to_verify_by_instance.len(), 2); assert_eq!(proof.output_claims_by_instance[0], &[product0]); diff --git a/crates/prover/src/examples/blake/air.rs b/crates/prover/src/examples/blake/air.rs index ca583abe3..a351220d1 100644 --- a/crates/prover/src/examples/blake/air.rs +++ b/crates/prover/src/examples/blake/air.rs @@ -60,9 +60,9 @@ impl BlakeStatement0 { } pub struct AllElements { - blake_elements: BlakeElements, - round_elements: RoundElements, - xor_elements: BlakeXorElements, + pub blake_elements: BlakeElements, + pub round_elements: RoundElements, + pub xor_elements: BlakeXorElements, } impl AllElements { pub fn draw(channel: &mut impl Channel) -> Self { @@ -222,7 +222,7 @@ where { assert!(log_size >= LOG_N_LANES); assert_eq!( - ROUND_LOG_SPLIT.map(|x| (1 << x)).into_iter().sum::() as usize, + ROUND_LOG_SPLIT.map(|x| 1 << x).iter().sum::(), N_ROUNDS ); @@ -239,7 +239,7 @@ where span.exit(); // Prepare inputs. - let blake_inputs = (0..(1 << (log_size - LOG_N_LANES))) + let blake_inputs = (0..1 << (log_size - LOG_N_LANES)) .map(|i| { let v = [u32x16::from_array(std::array::from_fn(|j| (i + 2 * j) as u32)); 16]; let m = [u32x16::from_array(std::array::from_fn(|j| (i + 2 * j + 1) as u32)); 16]; @@ -281,18 +281,15 @@ where // Trace commitment. let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals( - chain![ - scheduler_trace, - round_traces.into_iter().flatten(), - xor_trace12, - xor_trace9, - xor_trace8, - xor_trace7, - xor_trace4, - ] - .collect_vec(), - ); + tree_builder.extend_evals(chain![ + scheduler_trace, + round_traces.into_iter().flatten(), + xor_trace12, + xor_trace9, + xor_trace8, + xor_trace7, + xor_trace4, + ]); tree_builder.commit(channel); span.exit(); diff --git a/crates/prover/src/examples/blake/mod.rs b/crates/prover/src/examples/blake/mod.rs index 6fbe6d81b..577d7acf5 100644 --- a/crates/prover/src/examples/blake/mod.rs +++ b/crates/prover/src/examples/blake/mod.rs @@ -12,28 +12,28 @@ use crate::core::channel::Channel; use crate::core::fields::m31::BaseField; use crate::core::fields::FieldExpOps; -mod air; -mod round; -mod scheduler; -mod xor_table; +pub mod air; +pub mod round; +pub mod scheduler; +pub mod xor_table; -const STATE_SIZE: usize = 16; -const MESSAGE_SIZE: usize = 16; -const N_FELTS_IN_U32: usize = 2; -const N_ROUND_INPUT_FELTS: usize = (STATE_SIZE + STATE_SIZE + MESSAGE_SIZE) * N_FELTS_IN_U32; +pub const STATE_SIZE: usize = 16; +pub const MESSAGE_SIZE: usize = 16; +pub const N_FELTS_IN_U32: usize = 2; +pub const N_ROUND_INPUT_FELTS: usize = (STATE_SIZE + STATE_SIZE + MESSAGE_SIZE) * N_FELTS_IN_U32; // Parameters for Blake2s. Change these for blake3. -const N_ROUNDS: usize = 10; +pub const N_ROUNDS: usize = 10; /// A splitting N_ROUNDS into several powers of 2. -const ROUND_LOG_SPLIT: [u32; 2] = [3, 1]; +pub const ROUND_LOG_SPLIT: [u32; 2] = [3, 1]; #[derive(Default)] -struct XorAccums { - xor12: XorAccumulator<12, 4>, - xor9: XorAccumulator<9, 2>, - xor8: XorAccumulator<8, 2>, - xor7: XorAccumulator<7, 2>, - xor4: XorAccumulator<4, 0>, +pub struct XorAccums { + pub xor12: XorAccumulator<12, 4>, + pub xor9: XorAccumulator<9, 2>, + pub xor8: XorAccumulator<8, 2>, + pub xor7: XorAccumulator<7, 2>, + pub xor4: XorAccumulator<4, 0>, } impl XorAccums { fn add_input(&mut self, w: u32, a: u32x16, b: u32x16) { @@ -50,11 +50,11 @@ impl XorAccums { #[derive(Clone)] pub struct BlakeXorElements { - xor12: XorElements, - xor9: XorElements, - xor8: XorElements, - xor7: XorElements, - xor4: XorElements, + pub xor12: XorElements, + pub xor9: XorElements, + pub xor8: XorElements, + pub xor7: XorElements, + pub xor4: XorElements, } impl BlakeXorElements { fn draw(channel: &mut impl Channel) -> Self { @@ -75,7 +75,7 @@ impl BlakeXorElements { xor4: XorElements::dummy(), } } - fn get(&self, w: u32) -> &XorElements { + pub fn get(&self, w: u32) -> &XorElements { match w { 12 => &self.xor12, 9 => &self.xor9, @@ -89,7 +89,7 @@ impl BlakeXorElements { /// Utility for representing a u32 as two field elements, for constraint evaluation. #[derive(Clone, Copy, Debug)] -struct Fu32 +pub struct Fu32 where F: FieldExpOps + Copy @@ -99,8 +99,8 @@ where + Sub + Mul, { - l: F, - h: F, + pub l: F, + pub h: F, } impl Fu32 where @@ -112,7 +112,7 @@ where + Sub + Mul, { - fn to_felts(self) -> [F; 2] { + pub fn to_felts(self) -> [F; 2] { [self.l, self.h] } } diff --git a/crates/prover/src/examples/blake/round/gen.rs b/crates/prover/src/examples/blake/round/gen.rs index 7adbe6fd5..a5f458b03 100644 --- a/crates/prover/src/examples/blake/round/gen.rs +++ b/crates/prover/src/examples/blake/round/gen.rs @@ -23,19 +23,19 @@ use crate::examples::blake::{to_felts, XorAccums, N_ROUND_INPUT_FELTS, STATE_SIZ pub struct BlakeRoundLookupData { /// A vector of (w, [a_col, b_col, c_col]) for each xor lookup. /// w is the xor width. c_col is the xor col of a_col and b_col. - xor_lookups: Vec<(u32, [BaseColumn; 3])>, + pub xor_lookups: Vec<(u32, [BaseColumn; 3])>, /// A column of round lookup values (v_in, v_out, m). - round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS], + pub round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS], } pub struct TraceGenerator { - log_size: u32, - trace: Vec, - xor_lookups: Vec<(u32, [BaseColumn; 3])>, - round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS], + pub log_size: u32, + pub trace: Vec, + pub xor_lookups: Vec<(u32, [BaseColumn; 3])>, + pub round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS], } impl TraceGenerator { - fn new(log_size: u32) -> Self { + pub fn new(log_size: u32) -> Self { assert!(log_size >= LOG_N_LANES); let trace = (0..blake_round_info().mask_offsets[0].len()) .map(|_| unsafe { Col::::uninitialized(1 << log_size) }) @@ -50,7 +50,7 @@ impl TraceGenerator { } } - fn gen_row(&mut self, vec_row: usize) -> TraceGeneratorRow<'_> { + pub fn gen_row(&mut self, vec_row: usize) -> TraceGeneratorRow<'_> { TraceGeneratorRow { gen: self, col_index: 0, @@ -61,7 +61,7 @@ impl TraceGenerator { } /// Trace generator for the constraints defined at [`super::constraints::BlakeRoundEval`] -struct TraceGeneratorRow<'a> { +pub struct TraceGeneratorRow<'a> { gen: &'a mut TraceGenerator, col_index: usize, vec_row: usize, @@ -79,7 +79,7 @@ impl<'a> TraceGeneratorRow<'a> { self.append_felt(val >> 16); } - fn generate(&mut self, mut v: [u32x16; 16], m: [u32x16; 16]) { + pub fn generate(&mut self, mut v: [u32x16; 16], m: [u32x16; 16]) { let input_v = v; v.iter().for_each(|s| { self.append_u32(*s); @@ -215,7 +215,7 @@ pub fn generate_trace( let _span = span!(Level::INFO, "Round Generation").entered(); let mut generator = TraceGenerator::new(log_size); - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + for vec_row in 0..1 << (log_size - LOG_N_LANES) { let mut row_gen = generator.gen_row(vec_row); let BlakeRoundInput { v, m } = inputs.get(vec_row).copied().unwrap_or_default(); row_gen.generate(v, m); diff --git a/crates/prover/src/examples/blake/round/mod.rs b/crates/prover/src/examples/blake/round/mod.rs index cf8311339..49d798f4b 100644 --- a/crates/prover/src/examples/blake/round/mod.rs +++ b/crates/prover/src/examples/blake/round/mod.rs @@ -1,7 +1,10 @@ mod constraints; mod gen; -pub use gen::{generate_interaction_trace, generate_trace, BlakeRoundInput}; +pub use gen::{ + generate_interaction_trace, generate_trace, BlakeRoundInput, BlakeRoundLookupData, + TraceGenerator, TraceGeneratorRow, +}; use num_traits::Zero; use super::{BlakeXorElements, N_ROUND_INPUT_FELTS}; diff --git a/crates/prover/src/examples/blake/scheduler/constraints.rs b/crates/prover/src/examples/blake/scheduler/constraints.rs index ee9a1c654..6b45f08d4 100644 --- a/crates/prover/src/examples/blake/scheduler/constraints.rs +++ b/crates/prover/src/examples/blake/scheduler/constraints.rs @@ -61,7 +61,7 @@ pub fn eval_blake_scheduler_constraints( logup.finalize(eval); } -fn eval_next_u32(eval: &mut E) -> Fu32 { +pub fn eval_next_u32(eval: &mut E) -> Fu32 { let l = eval.next_trace_mask(); let h = eval.next_trace_mask(); Fu32 { l, h } diff --git a/crates/prover/src/examples/blake/scheduler/gen.rs b/crates/prover/src/examples/blake/scheduler/gen.rs index cd6a99b2f..ae3569ed0 100644 --- a/crates/prover/src/examples/blake/scheduler/gen.rs +++ b/crates/prover/src/examples/blake/scheduler/gen.rs @@ -58,7 +58,7 @@ pub fn gen_trace( .map(|_| unsafe { BaseColumn::uninitialized(1 << log_size) }) .collect_vec(); - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + for vec_row in 0..1 << (log_size - LOG_N_LANES) { let mut col_index = 0; let mut write_u32_array = |x: [u32x16; STATE_SIZE], col_index: &mut usize| { @@ -125,11 +125,11 @@ pub fn gen_interaction_trace( let mut logup_gen = LogupTraceGenerator::new(log_size); - for [l0, l1] in lookup_data.round_lookups.array_chunks::<2>() { + for [l0, l1] in lookup_data.round_lookups.array_chunks() { let mut col_gen = logup_gen.new_col(); #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + for vec_row in 0..1 << (log_size - LOG_N_LANES) { let p0: PackedSecureField = round_lookup_elements.combine(&l0.each_ref().map(|l| l.data[vec_row])); let p1: PackedSecureField = @@ -145,7 +145,7 @@ pub fn gen_interaction_trace( // with the entire blake lookup. let mut col_gen = logup_gen.new_col(); #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + for vec_row in 0..1 << (log_size - LOG_N_LANES) { let p_blake: PackedSecureField = blake_lookup_elements.combine( &lookup_data .blake_lookups diff --git a/crates/prover/src/examples/blake/scheduler/mod.rs b/crates/prover/src/examples/blake/scheduler/mod.rs index e8a8c32f3..3035626c8 100644 --- a/crates/prover/src/examples/blake/scheduler/mod.rs +++ b/crates/prover/src/examples/blake/scheduler/mod.rs @@ -2,7 +2,8 @@ mod constraints; mod gen; use constraints::eval_blake_scheduler_constraints; -pub use gen::{gen_interaction_trace, gen_trace, BlakeInput}; +pub use constraints::eval_next_u32; +pub use gen::{gen_interaction_trace, gen_trace, BlakeInput, BlakeSchedulerLookupData}; use num_traits::Zero; use super::round::RoundElements; diff --git a/crates/prover/src/examples/blake/xor_table/gen.rs b/crates/prover/src/examples/blake/xor_table/gen.rs index 195a6ca46..46309e640 100644 --- a/crates/prover/src/examples/blake/xor_table/gen.rs +++ b/crates/prover/src/examples/blake/xor_table/gen.rs @@ -74,7 +74,7 @@ pub fn generate_interaction_trace( // Each column has 2^(2*LIMB_BITS) rows, packed in N_LANES. #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (column_bits::() - LOG_N_LANES)) { + for vec_row in 0..1 << (column_bits::() - LOG_N_LANES) { // vec_row is LIMB_BITS of al and LIMB_BITS - LOG_N_LANES of bl. // Extract al, blh from vec_row. let al = vec_row >> (limb_bits - LOG_N_LANES); diff --git a/crates/prover/src/examples/blake/xor_table/mod.rs b/crates/prover/src/examples/blake/xor_table/mod.rs index 877a65114..b796e702e 100644 --- a/crates/prover/src/examples/blake/xor_table/mod.rs +++ b/crates/prover/src/examples/blake/xor_table/mod.rs @@ -17,7 +17,9 @@ use std::simd::u32x16; use itertools::Itertools; use num_traits::Zero; -pub use r#gen::{generate_constant_trace, generate_interaction_trace, generate_trace}; +pub use r#gen::{ + generate_constant_trace, generate_interaction_trace, generate_trace, XorTableLookupData, +}; use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator}; @@ -37,7 +39,7 @@ pub fn trace_sizes() -> TreeVec()) } -const fn limb_bits() -> u32 { +pub const fn limb_bits() -> u32 { ELEM_BITS - EXPAND_BITS } pub const fn column_bits() -> u32 { diff --git a/crates/prover/src/examples/blake_gkr/air.rs b/crates/prover/src/examples/blake_gkr/air.rs new file mode 100644 index 000000000..79a94b6a4 --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/air.rs @@ -0,0 +1,526 @@ +use std::array; +use std::simd::u32x16; + +use itertools::{chain, multiunzip, Itertools}; +use tracing::{span, Level}; + +use super::gkr_lookups::MleCoeffColumnOracleAccumulator; +use super::round::{BlakeRoundComponent, BlakeRoundEval}; +use super::scheduler::BlakeSchedulerComponent; +use super::xor_table::{XorLookupArtifacts, XorTableComponent, XorTableEval}; +use crate::constraint_framework::{FrameworkEval, TraceLocationAllocator}; +use crate::core::air::ComponentProver; +use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::BackendForChannel; +use crate::core::channel::{Channel, MerkleChannel}; +use crate::core::lookups::gkr_prover::prove_batch; +use crate::core::lookups::gkr_verifier::{ + GkrBatchProof, LookupArtifactInstance, LookupArtifactInstanceIter, +}; +use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig}; +use crate::core::poly::circle::{CanonicCoset, PolyOps}; +use crate::core::prover::{prove, StarkProof, VerificationError}; +use crate::core::vcs::ops::MerkleHasher; +use crate::examples::blake::air::AllElements; +use crate::examples::blake::scheduler::{self as air_scheduler, BlakeInput}; +use crate::examples::blake::{ + round as air_round, xor_table as air_xor_table, XorAccums, N_ROUNDS, ROUND_LOG_SPLIT, +}; +use crate::examples::blake_gkr::gkr_lookups::accumulation::{MleClaimAccumulator, MleCollection}; +use crate::examples::blake_gkr::gkr_lookups::mle_eval::{self, MleEvalProverComponent}; +use crate::examples::blake_gkr::round::RoundLookupArtifact; +use crate::examples::blake_gkr::scheduler::{BlakeSchedulerEval, SchedulerLookupArtifact}; +use crate::examples::blake_gkr::{round, scheduler, xor_table}; + +pub struct BlakeClaim { + log_size: u32, +} + +impl BlakeClaim { + fn mix_into(&self, channel: &mut impl Channel) { + // TODO(spapini): Do this better. + channel.mix_u64(self.log_size as u64); + } +} + +pub struct BlakeProof { + pub claim: BlakeClaim, + pub gkr_proof: GkrBatchProof, + pub stark_proof: StarkProof, +} + +pub struct BlakeLookupArtifacts { + scheduler: SchedulerLookupArtifact, + /// `|ROUND_LOG_SPLIT|` many round artifacts. + rounds: Vec, + xor: XorLookupArtifacts, +} + +impl BlakeLookupArtifacts { + pub fn new_from_iter(mut iter: impl Iterator) -> Self { + Self { + scheduler: SchedulerLookupArtifact::new_from_iter(&mut iter), + rounds: ROUND_LOG_SPLIT + .iter() + .map(|_| RoundLookupArtifact::new_from_iter(&mut iter)) + .collect(), + xor: XorLookupArtifacts::new_from_iter(&mut iter), + } + } + + pub fn verify_succinct_mle_claims( + &self, + lookup_elements: &AllElements, + ) -> Result<(), InvalidClaimError> { + let Self { + scheduler, + rounds, + xor, + } = self; + scheduler.verify_succinct_mle_claims()?; + for round in rounds { + round.verify_succinct_mle_claims()?; + } + xor.verify_succinct_mle_claims(&lookup_elements.xor_elements)?; + Ok(()) + } + + pub fn accumulate_mle_eval_iop_claims(&self, acc: &mut MleClaimAccumulator) { + let Self { + scheduler, + rounds, + xor, + } = self; + scheduler.accumulate_mle_eval_iop_claims(acc); + rounds + .iter() + .for_each(|round| round.accumulate_mle_eval_iop_claims(acc)); + xor.accumulate_mle_eval_iop_claims(acc); + } +} + +#[derive(Debug)] +pub struct InvalidClaimError; + +pub struct BlakeComponents { + scheduler_component: BlakeSchedulerComponent, + round_components: Vec, + xor12: XorTableComponent<12, 4>, + xor9: XorTableComponent<9, 2>, + xor8: XorTableComponent<8, 2>, + xor7: XorTableComponent<7, 2>, + xor4: XorTableComponent<4, 0>, +} + +impl BlakeComponents { + pub fn new( + trace_location_allocator: &mut TraceLocationAllocator, + claim: &BlakeClaim, + all_elements: &AllElements, + ) -> Self { + Self { + scheduler_component: BlakeSchedulerComponent::new( + trace_location_allocator, + BlakeSchedulerEval { + log_size: claim.log_size, + blake_lookup_elements: all_elements.blake_elements.clone(), + round_lookup_elements: all_elements.round_elements.clone(), + }, + ), + round_components: ROUND_LOG_SPLIT + .iter() + .map(|l| { + BlakeRoundComponent::new( + trace_location_allocator, + BlakeRoundEval { + log_size: claim.log_size + l, + xor_lookup_elements: all_elements.xor_elements.clone(), + round_lookup_elements: all_elements.round_elements.clone(), + }, + ) + }) + .collect(), + xor12: XorTableComponent::new( + trace_location_allocator, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor12.clone(), + }, + ), + xor9: XorTableComponent::new( + trace_location_allocator, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor12.clone(), + }, + ), + xor8: XorTableComponent::new( + trace_location_allocator, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor12.clone(), + }, + ), + xor7: XorTableComponent::new( + trace_location_allocator, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor12.clone(), + }, + ), + xor4: XorTableComponent::new( + trace_location_allocator, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor12.clone(), + }, + ), + } + } + + pub fn accumulate_mle_coeff_col_oracles<'this: 'acc, 'acc>( + &'this self, + acc_by_n_vars: &mut [Option>], + ) { + let Self { + scheduler_component, + round_components, + xor12, + xor9, + xor8, + xor7, + xor4, + } = self; + acc_by_n_vars[scheduler_component.log_size as usize] + .as_mut() + .unwrap() + .accumulate(scheduler_component); + for round_component in round_components { + acc_by_n_vars[round_component.log_size as usize] + .as_mut() + .unwrap() + .accumulate(round_component) + } + acc_by_n_vars[xor12.log_size() as usize] + .as_mut() + .unwrap() + .accumulate(xor12); + acc_by_n_vars[xor9.log_size() as usize] + .as_mut() + .unwrap() + .accumulate(xor9); + acc_by_n_vars[xor8.log_size() as usize] + .as_mut() + .unwrap() + .accumulate(xor8); + acc_by_n_vars[xor7.log_size() as usize] + .as_mut() + .unwrap() + .accumulate(xor7); + acc_by_n_vars[xor4.log_size() as usize] + .as_mut() + .unwrap() + .accumulate(xor4); + } + + fn component_provers(&self) -> Vec<&dyn ComponentProver> { + chain![ + [&self.scheduler_component as &dyn ComponentProver], + self.round_components + .iter() + .map(|c| c as &dyn ComponentProver), + [ + &self.xor12 as &dyn ComponentProver, + &self.xor9 as &dyn ComponentProver, + &self.xor8 as &dyn ComponentProver, + &self.xor7 as &dyn ComponentProver, + &self.xor4 as &dyn ComponentProver, + ] + ] + .collect() + } +} + +pub fn prove_blake(log_size: u32, config: PcsConfig) -> BlakeProof +where + SimdBackend: BackendForChannel, +{ + assert!(log_size >= LOG_N_LANES); + assert_eq!( + ROUND_LOG_SPLIT.map(|x| 1 << x).iter().sum::(), + N_ROUNDS + ); + + // Precompute twiddles. + let span = span!(Level::INFO, "Precompute twiddles").entered(); + const XOR_TABLE_MAX_LOG_SIZE: u32 = 16; + let max_log_size = + (log_size + *ROUND_LOG_SPLIT.iter().max().unwrap()).max(XOR_TABLE_MAX_LOG_SIZE); + let twiddles = SimdBackend::precompute_twiddles( + CanonicCoset::new(max_log_size + 1 + config.fri_config.log_blowup_factor) + .circle_domain() + .half_coset, + ); + span.exit(); + + // Prepare inputs. + let blake_inputs = (0..1 << (log_size - LOG_N_LANES)) + .map(|i| { + let v = [u32x16::from_array(array::from_fn(|j| (i + 2 * j) as u32)); 16]; + let m = [u32x16::from_array(array::from_fn(|j| (i + 2 * j + 1) as u32)); 16]; + BlakeInput { v, m } + }) + .collect_vec(); + + // Setup protocol. + let channel = &mut MC::C::default(); + let commitment_scheme = &mut CommitmentSchemeProver::new(config, &twiddles); + + let span = span!(Level::INFO, "Trace").entered(); + + // Scheduler. + let (scheduler_trace, scheduler_lookup_data, round_inputs) = + air_scheduler::gen_trace(log_size, &blake_inputs); + + // Rounds. + let mut xor_accums = XorAccums::default(); + let mut rest = &round_inputs[..]; + // Split round inputs to components, according to [ROUND_LOG_SPLIT]. + let (round_traces, round_lookup_datas): (Vec<_>, Vec<_>) = + multiunzip(ROUND_LOG_SPLIT.map(|l| { + let (cur_inputs, r) = rest.split_at(1 << (log_size - LOG_N_LANES + l)); + rest = r; + air_round::generate_trace(log_size + l, cur_inputs, &mut xor_accums) + })); + + // Xor tables. + let (xor_trace12, xor_lookup_data12) = air_xor_table::generate_trace(xor_accums.xor12); + let (xor_trace9, xor_lookup_data9) = air_xor_table::generate_trace(xor_accums.xor9); + let (xor_trace8, xor_lookup_data8) = air_xor_table::generate_trace(xor_accums.xor8); + let (xor_trace7, xor_lookup_data7) = air_xor_table::generate_trace(xor_accums.xor7); + let (xor_trace4, xor_lookup_data4) = air_xor_table::generate_trace(xor_accums.xor4); + + // Claim. + let claim = BlakeClaim { log_size }; + claim.mix_into(channel); + + // Trace commitment. + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(chain![ + scheduler_trace, + round_traces.into_iter().flatten(), + xor_trace12, + xor_trace9, + xor_trace8, + xor_trace7, + xor_trace4, + ]); + tree_builder.commit(channel); + span.exit(); + + // Draw lookup element. + let all_elements = AllElements::draw(channel); + + // Interaction trace. + let span = span!(Level::INFO, "Interaction").entered(); + let mut lookup_input_layers = Vec::new(); + let mut mle_eval_at_point_collection = MleCollection::default(); + + lookup_input_layers.extend(scheduler::generate_lookup_instances( + log_size, + scheduler_lookup_data, + &all_elements.round_elements, + &all_elements.blake_elements, + &mut mle_eval_at_point_collection, + )); + + ROUND_LOG_SPLIT + .iter() + .zip(round_lookup_datas) + .for_each(|(l, lookup_data)| { + lookup_input_layers.extend(round::generate_lookup_instances( + log_size + l, + lookup_data, + &all_elements.xor_elements, + &all_elements.round_elements, + &mut mle_eval_at_point_collection, + )); + }); + + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data12, + &all_elements.xor_elements.xor12, + &mut mle_eval_at_point_collection, + )); + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data9, + &all_elements.xor_elements.xor9, + &mut mle_eval_at_point_collection, + )); + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data8, + &all_elements.xor_elements.xor8, + &mut mle_eval_at_point_collection, + )); + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data7, + &all_elements.xor_elements.xor7, + &mut mle_eval_at_point_collection, + )); + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data4, + &all_elements.xor_elements.xor4, + &mut mle_eval_at_point_collection, + )); + + let gkr_span = span!(Level::INFO, "GKR proof").entered(); + let (gkr_proof, gkr_artifact) = prove_batch(channel, lookup_input_layers); + gkr_span.exit(); + let mle_acc_coeff = channel.draw_felt(); + let mles = mle_eval_at_point_collection.random_linear_combine_by_n_variables(mle_acc_coeff); + + // TODO(andrew): Consider unifying new_from_iter, verify_succinct_mle_claims, + // accumulate_mle_eval_iop_claims. + let mut lookup_instances_iter = LookupArtifactInstanceIter::new(&gkr_proof, &gkr_artifact); + let blake_lookup_artifacts = BlakeLookupArtifacts::new_from_iter(&mut lookup_instances_iter); + assert!(lookup_instances_iter.next().is_none()); + blake_lookup_artifacts + .verify_succinct_mle_claims(&all_elements) + .unwrap(); + let mut mle_eval_iop_acc = MleClaimAccumulator::new(mle_acc_coeff); + blake_lookup_artifacts.accumulate_mle_eval_iop_claims(&mut mle_eval_iop_acc); + let mut mle_claim_by_n_variables = mle_eval_iop_acc.finalize(); + + let max_mle_n_variables = mles.iter().map(|mle| mle.n_variables()).max().unwrap(); + let mut mle_coeff_col_acc_by_n_variables = vec![None; max_mle_n_variables + 1]; + + for mle in &mles { + let n_variables = mle.n_variables(); + mle_coeff_col_acc_by_n_variables[n_variables] = + Some(MleCoeffColumnOracleAccumulator::new(mle_acc_coeff)); + } + + let trace_location_allocator = &mut TraceLocationAllocator::default(); + let blake_components = BlakeComponents::new(trace_location_allocator, &claim, &all_elements); + blake_components.accumulate_mle_coeff_col_oracles(&mut mle_coeff_col_acc_by_n_variables); + + let mut tree_builder = commitment_scheme.tree_builder(); + let mle_eval_prover_components = mles + .into_iter() + .map(|mle| { + let n_vars = mle.n_variables(); + let coeff_column_oracle = mle_coeff_col_acc_by_n_variables[n_vars].as_ref().unwrap(); + let claim = mle_claim_by_n_variables[n_vars].take().unwrap(); + let eval_point = gkr_artifact.ood_point(n_vars); + + tree_builder.extend_evals(mle_eval::build_trace(&mle, eval_point, claim)); + + // Sanity check the claims. + #[cfg(test)] + debug_assert_eq!(claim, mle.eval_at_point(eval_point)); + + MleEvalProverComponent::generate( + trace_location_allocator, + coeff_column_oracle, + eval_point, + mle, + claim, + &twiddles, + 1, + ) + }) + .collect_vec(); + tree_builder.commit(channel); + span.exit(); + + let components = chain![ + blake_components.component_provers(), + mle_eval_prover_components + .iter() + .map(|c| c as &dyn ComponentProver) + ] + .collect_vec(); + + let stark_proof = prove(&components, channel, commitment_scheme).unwrap(); + + BlakeProof { + claim, + gkr_proof, + stark_proof, + } +} + +#[allow(unused)] +pub fn verify_blake( + BlakeProof { + claim, + gkr_proof, + stark_proof, + }: BlakeProof, + config: PcsConfig, +) -> Result<(), VerificationError> { + let channel = &mut MC::C::default(); + let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); + + // let log_sizes = stmt0.log_sizes(); + + // // Trace. + // stmt0.mix_into(channel); + // commitment_scheme.commit(stark_proof.commitments[0], &log_sizes[0], channel); + + // // Draw interaction elements. + // let all_elements = AllElements::draw(channel); + + // // Interaction trace. + // stmt1.mix_into(channel); + // commitment_scheme.commit(stark_proof.commitments[1], &log_sizes[1], channel); + + // // Constant trace. + // commitment_scheme.commit(stark_proof.commitments[2], &log_sizes[2], channel); + + // let components = BlakeComponents::new(&stmt0, &all_elements, &stmt1); + + // // Check that all sums are correct. + // let total_sum = stmt1.scheduler_claimed_sum + // + stmt1.round_claimed_sums.iter().sum::() + // + stmt1.xor12_claimed_sum + // + stmt1.xor9_claimed_sum + // + stmt1.xor8_claimed_sum + // + stmt1.xor7_claimed_sum + // + stmt1.xor4_claimed_sum; + + // // TODO(spapini): Add inputs to sum, and constraint them. + // assert_eq!(total_sum, SecureField::zero()); + + // verify( + // &components.components(), + // channel, + // commitment_scheme, + // stark_proof, + // ) + + todo!() +} + +#[cfg(test)] +mod tests { + use std::env; + + use crate::core::pcs::PcsConfig; + use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; + use crate::examples::blake_gkr::air::prove_blake; + + // Note: this test is slow. Only run in release. + #[cfg_attr(not(feature = "slow-tests"), ignore)] + #[test_log::test] + fn test_simd_blake_gkr_prove() { + // Get from environment variable: + let log_n_instances = env::var("LOG_N_INSTANCES") + .unwrap_or_else(|_| "6".to_string()) + .parse::() + .unwrap(); + let config = PcsConfig::default(); + + // Prove. + let _proof = prove_blake::(log_n_instances, config); + + // Verify. + // verify_blake::(proof, config).unwrap(); + } +} diff --git a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs b/crates/prover/src/examples/blake_gkr/gkr_lookups/accumulation.rs similarity index 84% rename from crates/prover/src/examples/xor/gkr_lookups/accumulation.rs rename to crates/prover/src/examples/blake_gkr/gkr_lookups/accumulation.rs index 986289572..d3c49aed9 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs +++ b/crates/prover/src/examples/blake_gkr/gkr_lookups/accumulation.rs @@ -2,7 +2,7 @@ use std::iter::zip; use std::ops::{AddAssign, Mul}; use educe::Educe; -use num_traits::One; +use num_traits::{One, Zero}; use crate::core::backend::simd::SimdBackend; use crate::core::backend::Backend; @@ -16,7 +16,7 @@ pub const MIN_LOG_BLOWUP_FACTOR: u32 = 1; /// Max number of variables for multilinear polynomials that get compiled into a univariate /// IOP for multilinear eval at point. -pub const MAX_MLE_N_VARIABLES: u32 = M31_CIRCLE_LOG_ORDER - MIN_LOG_BLOWUP_FACTOR; +pub const MAX_MLE_N_VARIABLES: usize = (M31_CIRCLE_LOG_ORDER - MIN_LOG_BLOWUP_FACTOR) as usize; /// Collection of [`Mle`]s grouped by their number of variables. pub struct MleCollection { @@ -90,7 +90,7 @@ pub fn combine + Copy, F: Copy>( impl Default for MleCollection { fn default() -> Self { Self { - mles_by_n_variables: vec![None; MAX_MLE_N_VARIABLES as usize + 1], + mles_by_n_variables: vec![None; MAX_MLE_N_VARIABLES + 1], } } } @@ -133,6 +133,32 @@ impl DynMle { } } +/// Accumulates claims of multilinear polynomials, grouped by their number of variables. +// TODO(andrew): Consider group by eval point to make sure everything done correctly. +pub struct MleClaimAccumulator { + acc_coeff: SecureField, + acc_by_n_variables: Vec>, +} + +impl MleClaimAccumulator { + pub fn new(acc_coeff: SecureField) -> Self { + Self { + acc_coeff, + acc_by_n_variables: vec![None; MAX_MLE_N_VARIABLES + 1], + } + } + + pub fn accumulate(&mut self, n_variables: usize, evaluation: SecureField) { + let acc = self.acc_by_n_variables[n_variables].get_or_insert_with(SecureField::zero); + *acc = *acc * self.acc_coeff + evaluation; + } + + /// Returns a mapping of number of variables to claim accumulation. + pub fn finalize(self) -> Vec> { + self.acc_by_n_variables + } +} + #[cfg(test)] mod tests { use std::iter::repeat; @@ -144,7 +170,7 @@ mod tests { use crate::core::fields::qm31::SecureField; use crate::core::fields::Field; use crate::core::lookups::mle::{Mle, MleOps}; - use crate::examples::xor::gkr_lookups::accumulation::MleCollection; + use crate::examples::blake_gkr::gkr_lookups::accumulation::MleCollection; #[test] fn random_linear_combine_by_n_variables() { diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/blake_gkr/gkr_lookups/mle_eval.rs similarity index 96% rename from crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs rename to crates/prover/src/examples/blake_gkr/gkr_lookups/mle_eval.rs index 9d05ce7aa..9a6c97c27 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/blake_gkr/gkr_lookups/mle_eval.rs @@ -1,7 +1,4 @@ //! Multilinear extension (MLE) eval at point constraints. -// TODO(andrew): Remove in downstream PR. -#![allow(dead_code)] - use std::iter::zip; use itertools::{chain, zip_eq, Itertools}; @@ -154,7 +151,6 @@ impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> Component // Consistency check the MLE coeffs column polynomial and oracle. let mle_coeff_col_eval = self.mle_coeff_column_poly.eval_at_point(point); let oracle_mle_coeff_col_eval = self.mle_coeff_column_oracle.evaluate_at_point(point, mask); - assert_eq!(mle_coeff_col_eval, oracle_mle_coeff_col_eval); let component_mask = mask.sub_tree(&self.trace_locations); let trace_coset = CanonicCoset::new(self.log_size()).coset; @@ -650,6 +646,36 @@ fn eval_step_selector_with_offset( eval_step_selector(coset, log_step, p - offset_step.into_ef()) } +// /// Returns `log(|coset|)` evaluations where the `i`th evaluation is of a polynomial that's `1` +// /// every `2^i` coset points and `0` elsewhere on coset. +// fn eval_step_selectors_by_log_step(coset: Coset, p: CirclePoint) -> Vec +// { let res = vec![SecureField::one()]; + +// if log_step == 0 { +// return SecureField::one(); +// } + +// // Rotate the coset to have points on the `x` axis. +// let p = p - coset.initial.into_ef(); +// let mut vanish_at_log_step = (0..coset.log_size) +// .scan(p, |p, _| { +// let res = *p; +// *p = p.double(); +// Some(res.y) +// }) +// .collect_vec(); +// vanish_at_log_step.reverse(); +// let mut vanish_at_log_step_inv = vec![SecureField::zero(); vanish_at_log_step.len()]; +// SecureField::batch_inverse(&vanish_at_log_step, &mut vanish_at_log_step_inv); + +// let norm = BaseField::from(2).inverse(); + +// let half_coset_selector_dbl = (vanish_at_log_step[0] * vanish_at_log_step_inv[1]).square(); +// let vanish_substep_inv_sum = vanish_at_log_step_inv[1..].iter().sum::(); +// (half_coset_selector_dbl + vanish_at_log_step[0] * vanish_substep_inv_sum.double()) +// / BaseField::from(1 << (log_step + 1)) +// } + /// Evaluates a polynomial that's `1` every `2^log_step` coset points and `0` elsewhere on coset. fn eval_step_selector(coset: Coset, log_step: u32, p: CirclePoint) -> SecureField { if log_step == 0 { @@ -743,8 +769,8 @@ mod tests { use crate::core::prover::{prove, verify, VerificationError}; use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; - use crate::examples::xor::gkr_lookups::accumulation::MIN_LOG_BLOWUP_FACTOR; - use crate::examples::xor::gkr_lookups::mle_eval::eval_step_selector_with_offset; + use crate::examples::blake_gkr::gkr_lookups::accumulation::MIN_LOG_BLOWUP_FACTOR; + use crate::examples::blake_gkr::gkr_lookups::mle_eval::eval_step_selector_with_offset; #[test] fn mle_eval_prover_component() -> Result<(), VerificationError> { @@ -814,7 +840,6 @@ mod tests { const N_VARIABLES: usize = 8; const COEFFS_COL_TRACE: usize = 0; const MLE_EVAL_TRACE: usize = 1; - const CONST_TRACE: usize = 2; const LOG_EXPAND: u32 = 1; // Create the test MLE. let mut rng = SmallRng::seed_from_u64(0); @@ -1124,7 +1149,7 @@ mod tests { use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureEvaluation}; use crate::core::poly::BitReversedOrder; use crate::core::ColumnVec; - use crate::examples::xor::gkr_lookups::mle_eval::MleCoeffColumnOracle; + use crate::examples::blake_gkr::gkr_lookups::mle_eval::MleCoeffColumnOracle; pub type MleCoeffColumnComponent = FrameworkComponent; diff --git a/crates/prover/src/examples/blake_gkr/gkr_lookups/mod.rs b/crates/prover/src/examples/blake_gkr/gkr_lookups/mod.rs new file mode 100644 index 000000000..ab5319db5 --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/gkr_lookups/mod.rs @@ -0,0 +1,54 @@ +use mle_eval::MleCoeffColumnOracle; + +use crate::core::air::accumulation::PointEvaluationAccumulator; +use crate::core::circle::CirclePoint; +use crate::core::fields::qm31::SecureField; +use crate::core::pcs::TreeVec; +use crate::core::ColumnVec; + +pub mod accumulation; +pub mod mle_eval; + +// TODO(andrew): Try come up with less verbose name. +pub trait AccumulatedMleCoeffColumnOracle { + fn accumulate_at_point( + &self, + point: CirclePoint, + mask: &TreeVec>>, + acc: &mut PointEvaluationAccumulator, + ); +} + +// TODO(andrew): Try come up with less verbose name. +#[derive(Clone)] +pub struct MleCoeffColumnOracleAccumulator<'a> { + acc_coeff: SecureField, + oracles: Vec<&'a dyn AccumulatedMleCoeffColumnOracle>, +} + +impl<'a> MleCoeffColumnOracleAccumulator<'a> { + pub fn new(acc_coeff: SecureField) -> Self { + Self { + acc_coeff, + oracles: Vec::new(), + } + } + + pub fn accumulate<'b: 'a>(&mut self, oracle: &'b dyn AccumulatedMleCoeffColumnOracle) { + self.oracles.push(oracle) + } +} + +impl<'a> MleCoeffColumnOracle for MleCoeffColumnOracleAccumulator<'a> { + fn evaluate_at_point( + &self, + point: CirclePoint, + mask: &TreeVec>>, + ) -> SecureField { + let mut acc = PointEvaluationAccumulator::new(self.acc_coeff); + for oracle in &self.oracles { + oracle.accumulate_at_point(point, mask, &mut acc); + } + acc.finalize() + } +} diff --git a/crates/prover/src/examples/blake_gkr/mod.rs b/crates/prover/src/examples/blake_gkr/mod.rs new file mode 100644 index 000000000..a1f767c6b --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/mod.rs @@ -0,0 +1,5 @@ +pub mod air; +pub mod gkr_lookups; +pub mod round; +pub mod scheduler; +pub mod xor_table; diff --git a/crates/prover/src/examples/blake_gkr/round.rs b/crates/prover/src/examples/blake_gkr/round.rs new file mode 100644 index 000000000..820b02e24 --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/round.rs @@ -0,0 +1,337 @@ +use std::array; + +use itertools::{chain, Itertools}; +use num_traits::One; +use tracing::{span, Level}; + +use super::air::InvalidClaimError; +use super::gkr_lookups::accumulation::{MleClaimAccumulator, MleCollection}; +use super::gkr_lookups::AccumulatedMleCoeffColumnOracle; +use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, PointEvaluator}; +use crate::core::air::accumulation::PointEvaluationAccumulator; +use crate::core::backend::simd::column::SecureColumn; +use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::circle::CirclePoint; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::gkr_prover::Layer; +use crate::core::lookups::gkr_verifier::{LogUpArtifactInstance, LookupArtifactInstance}; +use crate::core::lookups::mle::Mle; +use crate::core::pcs::TreeVec; +use crate::core::ColumnVec; +use crate::examples::blake::round::{BlakeRoundLookupData, RoundElements, TraceGenerator}; +use crate::examples::blake::{BlakeXorElements, Fu32, STATE_SIZE}; + +pub type BlakeRoundComponent = FrameworkComponent; + +pub struct BlakeRoundEval { + pub log_size: u32, + pub xor_lookup_elements: BlakeXorElements, + pub round_lookup_elements: RoundElements, +} + +impl FrameworkEval for BlakeRoundEval { + fn log_size(&self) -> u32 { + self.log_size + } + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size + 1 + } + fn evaluate(&self, eval: E) -> E { + const MLE_COEFF_COL_EVAL: bool = false; + let blake_eval = BlakeRoundConstraintEval:: { + eval, + xor_lookup_elements: &self.xor_lookup_elements, + round_lookup_elements: &self.round_lookup_elements, + mle_coeff_col_evals: None, + }; + blake_eval.eval() + } +} + +impl AccumulatedMleCoeffColumnOracle for BlakeRoundComponent { + fn accumulate_at_point( + &self, + _point: CirclePoint, + mask: &TreeVec>>, + acc: &mut PointEvaluationAccumulator, + ) { + // Create dummy point evaluator just to extract the value we need from the mask + let mut _accumulator = PointEvaluationAccumulator::new(SecureField::one()); + let eval = PointEvaluator::new( + mask.sub_tree(self.trace_locations()), + &mut _accumulator, + SecureField::one(), + ); + + let mut mle_coef_col_evals = Vec::new(); + + const MLE_COEFF_COL_EVAL: bool = true; + BlakeRoundConstraintEval::<_, MLE_COEFF_COL_EVAL> { + eval, + xor_lookup_elements: &self.xor_lookup_elements, + round_lookup_elements: &self.round_lookup_elements, + mle_coeff_col_evals: Some(&mut mle_coef_col_evals), + } + .eval(); + + for eval in mle_coef_col_evals { + acc.accumulate(eval) + } + } +} + +const INV16: BaseField = BaseField::from_u32_unchecked(1 << 15); +const TWO: BaseField = BaseField::from_u32_unchecked(2); + +pub struct BlakeRoundConstraintEval<'a, E: EvalAtRow, const MLE_COEFF_COL_EVAL: bool> { + pub eval: E, + pub xor_lookup_elements: &'a BlakeXorElements, + pub round_lookup_elements: &'a RoundElements, + pub mle_coeff_col_evals: Option<&'a mut Vec>, +} +impl<'a, E: EvalAtRow, const MLE_COEFF_COL_EVAL: bool> + BlakeRoundConstraintEval<'a, E, MLE_COEFF_COL_EVAL> +{ + pub fn eval(mut self) -> E { + let mut v: [Fu32; STATE_SIZE] = array::from_fn(|_| self.next_u32()); + let input_v = v; + let m: [Fu32; STATE_SIZE] = array::from_fn(|_| self.next_u32()); + + self.g(v.get_many_mut([0, 4, 8, 12]).unwrap(), m[0], m[1]); + self.g(v.get_many_mut([1, 5, 9, 13]).unwrap(), m[2], m[3]); + self.g(v.get_many_mut([2, 6, 10, 14]).unwrap(), m[4], m[5]); + self.g(v.get_many_mut([3, 7, 11, 15]).unwrap(), m[6], m[7]); + self.g(v.get_many_mut([0, 5, 10, 15]).unwrap(), m[8], m[9]); + self.g(v.get_many_mut([1, 6, 11, 12]).unwrap(), m[10], m[11]); + self.g(v.get_many_mut([2, 7, 8, 13]).unwrap(), m[12], m[13]); + self.g(v.get_many_mut([3, 4, 9, 14]).unwrap(), m[14], m[15]); + + if MLE_COEFF_COL_EVAL { + self.mle_coeff_col_evals.as_mut().unwrap().push( + self.round_lookup_elements.combine( + &chain![ + input_v.iter().copied().flat_map(Fu32::to_felts), + v.iter().copied().flat_map(Fu32::to_felts), + m.iter().copied().flat_map(Fu32::to_felts) + ] + .collect_vec(), + ), + ); + } + + self.eval + } + fn next_u32(&mut self) -> Fu32 { + let l = self.eval.next_trace_mask(); + let h = self.eval.next_trace_mask(); + Fu32 { l, h } + } + fn g(&mut self, v: [&mut Fu32; 4], m0: Fu32, m1: Fu32) { + let [a, b, c, d] = v; + + *a = self.add3_u32_unchecked(*a, *b, m0); + *d = self.xor_rotr16_u32(*a, *d); + *c = self.add2_u32_unchecked(*c, *d); + *b = self.xor_rotr_u32(*b, *c, 12); + *a = self.add3_u32_unchecked(*a, *b, m1); + *d = self.xor_rotr_u32(*a, *d, 8); + *c = self.add2_u32_unchecked(*c, *d); + *b = self.xor_rotr_u32(*b, *c, 7); + } + + /// Adds two u32s, returning the sum. + /// Assumes a, b are properly range checked. + /// The caller is responsible for checking: + /// res.{l,h} not in [2^16, 2^17) or in [-2^16,0) + fn add2_u32_unchecked(&mut self, a: Fu32, b: Fu32) -> Fu32 { + let sl = self.eval.next_trace_mask(); + let sh = self.eval.next_trace_mask(); + + let carry_l = (a.l + b.l - sl) * E::F::from(INV16); + self.eval.add_constraint(carry_l * carry_l - carry_l); + + let carry_h = (a.h + b.h + carry_l - sh) * E::F::from(INV16); + self.eval.add_constraint(carry_h * carry_h - carry_h); + + Fu32 { l: sl, h: sh } + } + + /// Adds three u32s, returning the sum. + /// Assumes a, b, c are properly range checked. + /// Caller is responsible for checking: + /// res.{l,h} not in [2^16, 3*2^16) or in [-2^17,0) + fn add3_u32_unchecked(&mut self, a: Fu32, b: Fu32, c: Fu32) -> Fu32 { + let sl = self.eval.next_trace_mask(); + let sh = self.eval.next_trace_mask(); + + let carry_l = (a.l + b.l + c.l - sl) * E::F::from(INV16); + self.eval + .add_constraint(carry_l * (carry_l - E::F::one()) * (carry_l - E::F::from(TWO))); + + let carry_h = (a.h + b.h + c.h + carry_l - sh) * E::F::from(INV16); + self.eval + .add_constraint(carry_h * (carry_h - E::F::one()) * (carry_h - E::F::from(TWO))); + + Fu32 { l: sl, h: sh } + } + + /// Splits a felt at r. + /// Caller is responsible for checking that the ranges of h * 2^r and l don't overlap. + fn split_unchecked(&mut self, a: E::F, r: u32) -> (E::F, E::F) { + let h = self.eval.next_trace_mask(); + let l = a - h * E::F::from(BaseField::from_u32_unchecked(1 << r)); + (l, h) + } + + /// Checks that a, b are in range, and computes their xor rotated right by `r` bits. + /// Guarantees that all elements are in range. + fn xor_rotr_u32(&mut self, a: Fu32, b: Fu32, r: u32) -> Fu32 { + let (all, alh) = self.split_unchecked(a.l, r); + let (ahl, ahh) = self.split_unchecked(a.h, r); + let (bll, blh) = self.split_unchecked(b.l, r); + let (bhl, bhh) = self.split_unchecked(b.h, r); + + // These also guarantee that all elements are in range. + let [xorll, xorhl] = self.xor2(r, [all, ahl], [bll, bhl]); + let [xorlh, xorhh] = self.xor2(16 - r, [alh, ahh], [blh, bhh]); + + Fu32 { + l: xorhl * E::F::from(BaseField::from_u32_unchecked(1 << (16 - r))) + xorlh, + h: xorll * E::F::from(BaseField::from_u32_unchecked(1 << (16 - r))) + xorhh, + } + } + + /// Checks that a, b are in range, and computes their xor rotated right by 16 bits. + /// Guarantees that all elements are in range. + fn xor_rotr16_u32(&mut self, a: Fu32, b: Fu32) -> Fu32 { + let (all, alh) = self.split_unchecked(a.l, 8); + let (ahl, ahh) = self.split_unchecked(a.h, 8); + let (bll, blh) = self.split_unchecked(b.l, 8); + let (bhl, bhh) = self.split_unchecked(b.h, 8); + + // These also guarantee that all elements are in range. + let [xorll, xorhl] = self.xor2(8, [all, ahl], [bll, bhl]); + let [xorlh, xorhh] = self.xor2(8, [alh, ahh], [blh, bhh]); + + Fu32 { + l: xorhh * E::F::from(BaseField::from_u32_unchecked(1 << 8)) + xorhl, + h: xorlh * E::F::from(BaseField::from_u32_unchecked(1 << 8)) + xorll, + } + } + + /// Checks that a, b are in [0, 2^w) and computes their xor. + fn xor2(&mut self, w: u32, a: [E::F; 2], b: [E::F; 2]) -> [E::F; 2] { + // TODO: Separate lookups by w. + let c = [self.eval.next_trace_mask(), self.eval.next_trace_mask()]; + + if MLE_COEFF_COL_EVAL { + let lookup_elements = self.xor_lookup_elements.get(w); + let comb0 = lookup_elements.combine::(&[a[0], b[0], c[0]]); + let comb1 = lookup_elements.combine::(&[a[1], b[1], c[1]]); + self.mle_coeff_col_evals + .as_mut() + .unwrap() + .extend([comb0, comb1]) + } + + c + } +} + +pub struct RoundLookupArtifact { + pub round: LogUpArtifactInstance, + pub xors: Vec, +} + +impl RoundLookupArtifact { + pub fn new_from_iter(mut iter: impl Iterator) -> Self { + let xors = (0..n_xor_lookups()) + .map(|_| match iter.next() { + Some(LookupArtifactInstance::LogUp(artifact)) => artifact, + _ => panic!(), + }) + .collect(); + + let round = match iter.next() { + Some(LookupArtifactInstance::LogUp(artifact)) => artifact, + _ => panic!(), + }; + + Self { round, xors } + } + + pub fn accumulate_mle_eval_iop_claims(&self, acc: &mut MleClaimAccumulator) { + let Self { round, xors } = self; + + for xor in xors { + acc.accumulate(xor.input_n_variables, xor.input_denominators_claim); + } + + acc.accumulate(round.input_n_variables, round.input_denominators_claim); + } + + pub fn verify_succinct_mle_claims(&self) -> Result<(), InvalidClaimError> { + let Self { round, xors } = self; + + for xor_artifact in xors { + if !xor_artifact.input_numerators_claim.is_one() { + return Err(InvalidClaimError); + } + } + + if !round.input_numerators_claim.is_one() { + return Err(InvalidClaimError); + } + + Ok(()) + } +} + +/// Returns an ordered list of all XOR lookup types the round component uses. +fn n_xor_lookups() -> usize { + // Create a dummy trace to extract the structural xor lookup information. + let mut trace_generator = TraceGenerator::new(LOG_N_LANES); + let mut row = trace_generator.gen_row(0); + row.generate(Default::default(), Default::default()); + trace_generator.xor_lookups.len() +} + +pub fn generate_lookup_instances( + log_size: u32, + lookup_data: BlakeRoundLookupData, + xor_lookup_elements: &BlakeXorElements, + round_lookup_elements: &RoundElements, + collection_for_univariate_iop: &mut MleCollection, +) -> Vec> { + let _span = span!(Level::INFO, "Generate round interaction trace").entered(); + let size = 1 << log_size; + let mut round_lookup_layers = Vec::new(); + + for (w, l) in &lookup_data.xor_lookups { + let lookup_elements = xor_lookup_elements.get(*w); + let mut denominators = Mle::::new(SecureColumn::zeros(size)); + for vec_row in 0..1 << (log_size - LOG_N_LANES) { + let denom = lookup_elements.combine(&l.each_ref().map(|l| l.data[vec_row])); + denominators.data[vec_row] = denom; + } + collection_for_univariate_iop.push(denominators.clone()); + round_lookup_layers.push(Layer::LogUpSingles { denominators }); + } + + // Blake round lookup. + let mut round_denominators = Mle::::new(SecureColumn::zeros(size)); + for vec_row in 0..1 << (log_size - LOG_N_LANES) { + let denom = round_lookup_elements + .combine(&lookup_data.round_lookup.each_ref().map(|l| l.data[vec_row])); + round_denominators.data[vec_row] = denom; + } + collection_for_univariate_iop.push(round_denominators.clone()); + round_lookup_layers.push(Layer::LogUpSingles { + denominators: round_denominators, + }); + + round_lookup_layers +} diff --git a/crates/prover/src/examples/blake_gkr/scheduler.rs b/crates/prover/src/examples/blake_gkr/scheduler.rs new file mode 100644 index 000000000..b1dd50dd3 --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/scheduler.rs @@ -0,0 +1,209 @@ +use std::array; + +use itertools::{chain, Itertools}; +use num_traits::{One, Zero}; +use tracing::{span, Level}; + +use super::air::InvalidClaimError; +use super::gkr_lookups::accumulation::{MleClaimAccumulator, MleCollection}; +use super::gkr_lookups::AccumulatedMleCoeffColumnOracle; +use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, PointEvaluator}; +use crate::core::air::accumulation::PointEvaluationAccumulator; +use crate::core::backend::simd::blake2s::SIGMA; +use crate::core::backend::simd::column::{BaseColumn, SecureColumn}; +use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::circle::CirclePoint; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::gkr_prover::Layer; +use crate::core::lookups::gkr_verifier::{LogUpArtifactInstance, LookupArtifactInstance}; +use crate::core::lookups::mle::Mle; +use crate::core::pcs::TreeVec; +use crate::core::ColumnVec; +use crate::examples::blake::round::RoundElements; +use crate::examples::blake::scheduler::{eval_next_u32, BlakeElements, BlakeSchedulerLookupData}; +use crate::examples::blake::{Fu32, N_ROUNDS, STATE_SIZE}; + +pub type BlakeSchedulerComponent = FrameworkComponent; + +pub struct BlakeSchedulerEval { + pub log_size: u32, + pub blake_lookup_elements: BlakeElements, + pub round_lookup_elements: RoundElements, +} + +impl FrameworkEval for BlakeSchedulerEval { + fn log_size(&self) -> u32 { + self.log_size + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size + } + + fn evaluate(&self, mut eval: E) -> E { + let _ = SchedulerEvals::new(&mut eval); + eval + } +} + +impl AccumulatedMleCoeffColumnOracle for BlakeSchedulerComponent { + fn accumulate_at_point( + &self, + _point: CirclePoint, + mask: &TreeVec>>, + acc: &mut PointEvaluationAccumulator, + ) { + // Create dummy point evaluator just to extract the value we need from the mask + let mut _accumulator = PointEvaluationAccumulator::new(SecureField::one()); + let mut eval = PointEvaluator::new( + mask.sub_tree(self.trace_locations()), + &mut _accumulator, + SecureField::one(), + ); + + let SchedulerEvals { messages, states } = SchedulerEvals::new(&mut eval); + + // Schedule. + for i in 0..N_ROUNDS { + let input_state = &states[i]; + let output_state = &states[i + 1]; + let round_messages = SIGMA[i].map(|j| messages[j as usize]); + // Use triplet in round lookup. + let lookup_values = &chain![ + input_state.iter().copied().flat_map(Fu32::to_felts), + output_state.iter().copied().flat_map(Fu32::to_felts), + round_messages.iter().copied().flat_map(Fu32::to_felts) + ] + .collect_vec(); + let denoms_mle_coeff_col_eval = self.round_lookup_elements.combine(lookup_values); + acc.accumulate(denoms_mle_coeff_col_eval); + } + + let input_state = &states[0]; + let output_state = &states[N_ROUNDS]; + let lookup_values = &chain![ + input_state.iter().copied().flat_map(Fu32::to_felts), + output_state.iter().copied().flat_map(Fu32::to_felts), + messages.iter().copied().flat_map(Fu32::to_felts) + ] + .collect_vec(); + let denoms_mle_coeff_col_eval = self.blake_lookup_elements.combine(lookup_values); + acc.accumulate(denoms_mle_coeff_col_eval); + } +} + +struct SchedulerEvals { + messages: [Fu32; STATE_SIZE], + states: [[Fu32; STATE_SIZE]; N_ROUNDS + 1], +} + +impl SchedulerEvals { + fn new(eval: &mut E) -> Self { + Self { + messages: array::from_fn(|_| eval_next_u32(eval)), + states: array::from_fn(|_| array::from_fn(|_| eval_next_u32(eval))), + } + } +} + +pub struct SchedulerLookupArtifact { + scheduler: LogUpArtifactInstance, + rounds: [LogUpArtifactInstance; N_ROUNDS], +} + +impl SchedulerLookupArtifact { + pub fn new_from_iter(mut iter: impl Iterator) -> Self { + let rounds = array::from_fn(|_| match iter.next() { + Some(LookupArtifactInstance::LogUp(artifact)) => artifact, + _ => panic!(), + }); + + let scheduler = match iter.next() { + Some(LookupArtifactInstance::LogUp(artifact)) => artifact, + _ => panic!(), + }; + + Self { scheduler, rounds } + } + + pub fn verify_succinct_mle_claims(&self) -> Result<(), InvalidClaimError> { + let Self { scheduler, rounds } = self; + + // TODO(andrew): Consider checking the n_variables is correct. + // if !self.scheduler.input_numerators_claim.is_one() { + if !scheduler.input_numerators_claim.is_zero() { + return Err(InvalidClaimError); + } + + for round in rounds { + if !round.input_numerators_claim.is_one() { + return Err(InvalidClaimError); + } + } + + Ok(()) + } + + pub fn accumulate_mle_eval_iop_claims(&self, acc: &mut MleClaimAccumulator) { + let Self { scheduler, rounds } = self; + + for round in rounds { + acc.accumulate(round.input_n_variables, round.input_denominators_claim); + } + + // TODO: Note `n_variables` is not verified. Probably fine since if the prover gives wrong + // info they'll be caught. Can panic though if the n_variables is too high. Consider + // checking the number of GKR layers in the verifier is less than + // LOG_CIRCLE_ORDER-LOG_BLOWUP-LOG_EXPAND. + acc.accumulate( + scheduler.input_n_variables, + scheduler.input_denominators_claim, + ); + } +} + +pub fn generate_lookup_instances( + log_size: u32, + lookup_data: BlakeSchedulerLookupData, + round_lookup_elements: &RoundElements, + blake_lookup_elements: &BlakeElements, + collection_for_univariate_iop: &mut MleCollection, +) -> Vec> { + let _span = span!(Level::INFO, "Generate scheduler interaction trace").entered(); + let size = 1 << log_size; + let mut round_lookup_layers = Vec::new(); + + for l0 in &lookup_data.round_lookups { + let mut denominators = Mle::::new(SecureColumn::zeros(size)); + for vec_row in 0..1 << (log_size - LOG_N_LANES) { + let denom = round_lookup_elements.combine(&l0.each_ref().map(|l| l.data[vec_row])); + denominators.data[vec_row] = denom; + } + collection_for_univariate_iop.push(denominators.clone()); + round_lookup_layers.push(Layer::LogUpSingles { denominators }) + } + + // Blake hash lookup. + let blake_numers = Mle::::new(BaseColumn::zeros(size)); + let mut blake_denoms = Mle::::new(SecureColumn::zeros(size)); + for vec_row in 0..1 << (log_size - LOG_N_LANES) { + let blake_denom: PackedSecureField = blake_lookup_elements.combine( + &lookup_data + .blake_lookups + .each_ref() + .map(|l| l.data[vec_row]), + ); + blake_denoms.data[vec_row] = blake_denom; + } + collection_for_univariate_iop.push(blake_denoms.clone()); + round_lookup_layers.push(Layer::LogUpMultiplicities { + numerators: blake_numers, + denominators: blake_denoms, + }); + + round_lookup_layers +} diff --git a/crates/prover/src/examples/blake_gkr/xor_table.rs b/crates/prover/src/examples/blake_gkr/xor_table.rs new file mode 100644 index 000000000..18b5fe31e --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/xor_table.rs @@ -0,0 +1,328 @@ +use std::array; +use std::iter::zip; +use std::simd::u32x16; + +use itertools::Itertools; +use num_traits::{One, Zero}; +use tracing::{span, Level}; + +use super::air::InvalidClaimError; +use super::gkr_lookups::accumulation::{MleClaimAccumulator, MleCollection}; +use super::gkr_lookups::AccumulatedMleCoeffColumnOracle; +use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, PointEvaluator}; +use crate::core::air::accumulation::PointEvaluationAccumulator; +use crate::core::backend::simd::column::SecureColumn; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::circle::CirclePoint; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::Field; +use crate::core::lookups::gkr_prover::Layer; +use crate::core::lookups::gkr_verifier::{LogUpArtifactInstance, LookupArtifactInstance}; +use crate::core::lookups::mle::Mle; +use crate::core::pcs::TreeVec; +use crate::core::ColumnVec; +use crate::examples::blake::xor_table::{column_bits, limb_bits, XorElements, XorTableLookupData}; +use crate::examples::blake::BlakeXorElements; + +/// Component that evaluates the xor table. +pub type XorTableComponent = + FrameworkComponent>; + +/// Evaluates the xor table. +pub struct XorTableEval { + pub lookup_elements: XorElements, +} + +impl FrameworkEval + for XorTableEval +{ + fn log_size(&self) -> u32 { + column_bits::() + } + fn max_constraint_log_degree_bound(&self) -> u32 { + column_bits::() + 1 + } + fn evaluate(&self, mut eval: E) -> E { + let _ = eval_xor_table_multiplicity_cols::(&mut eval); + eval + } +} + +impl AccumulatedMleCoeffColumnOracle + for XorTableComponent +{ + fn accumulate_at_point( + &self, + _point: CirclePoint, + mask: &TreeVec>>, + acc: &mut PointEvaluationAccumulator, + ) { + // Create dummy point evaluator just to extract the value we need from the mask + let mut _accumulator = PointEvaluationAccumulator::new(SecureField::one()); + let mut eval = PointEvaluator::new( + mask.sub_tree(self.trace_locations()), + &mut _accumulator, + SecureField::one(), + ); + + for eval in eval_xor_table_multiplicity_cols::<_, ELEM_BITS, EXPAND_BITS>(&mut eval) { + acc.accumulate(eval) + } + } +} + +fn eval_xor_table_multiplicity_cols( + eval: &mut E, +) -> Vec { + (0..1 << (2 * EXPAND_BITS)) + .map(|_| eval.next_trace_mask()) + .collect() +} + +pub struct XorLookupArtifacts { + xor12: XorLookupArtifact<12, 4>, + xor9: XorLookupArtifact<9, 2>, + xor8: XorLookupArtifact<8, 2>, + xor7: XorLookupArtifact<7, 2>, + xor4: XorLookupArtifact<4, 0>, +} + +impl XorLookupArtifacts { + pub fn new_from_iter(mut iter: impl Iterator) -> Self { + Self { + xor12: XorLookupArtifact::new_from_iter(&mut iter), + xor9: XorLookupArtifact::new_from_iter(&mut iter), + xor8: XorLookupArtifact::new_from_iter(&mut iter), + xor7: XorLookupArtifact::new_from_iter(&mut iter), + xor4: XorLookupArtifact::new_from_iter(&mut iter), + } + } + + pub fn verify_succinct_mle_claims( + &self, + lookup_elements: &BlakeXorElements, + ) -> Result<(), InvalidClaimError> { + let Self { + xor12, + xor9, + xor8, + xor7, + xor4, + } = self; + + xor12.verify_succinct_mle_claims(lookup_elements.get(12))?; + xor9.verify_succinct_mle_claims(lookup_elements.get(9))?; + xor8.verify_succinct_mle_claims(lookup_elements.get(8))?; + xor7.verify_succinct_mle_claims(lookup_elements.get(7))?; + xor4.verify_succinct_mle_claims(lookup_elements.get(4))?; + + Ok(()) + } + + pub fn accumulate_mle_eval_iop_claims(&self, acc: &mut MleClaimAccumulator) { + let Self { + xor12, + xor9, + xor8, + xor7, + xor4, + } = self; + xor12.accumulate_mle_eval_iop_claims(acc); + xor9.accumulate_mle_eval_iop_claims(acc); + xor8.accumulate_mle_eval_iop_claims(acc); + xor7.accumulate_mle_eval_iop_claims(acc); + xor4.accumulate_mle_eval_iop_claims(acc); + } +} + +pub struct XorLookupArtifact { + /// `2^(2*EXPAND_BITS)` many LogUp instances. + artifacts: Vec, +} + +impl XorLookupArtifact { + pub fn new_from_iter(mut iter: impl Iterator) -> Self { + Self { + artifacts: (0..1 << (2 * EXPAND_BITS)) + .map(|_| match iter.next() { + // TODO: check input MLEs have expected number of variables. + Some(LookupArtifactInstance::LogUp(artifact)) => artifact, + _ => panic!(), + }) + .collect(), + } + } + + fn verify_succinct_mle_claims( + &self, + lookup_elements: &XorElements, + ) -> Result<(), InvalidClaimError> { + for (i, artifact) in self.artifacts.iter().enumerate() { + let eval_point = &artifact.eval_point; + let denoms_claim = artifact.input_denominators_claim; + let denoms_eval = eval_logup_denominators_mle::( + i, + lookup_elements, + eval_point, + ) + .unwrap(); + + if denoms_claim != denoms_eval { + return Err(InvalidClaimError); + } + } + + Ok(()) + } + + fn accumulate_mle_eval_iop_claims(&self, acc: &mut MleClaimAccumulator) { + let Self { artifacts } = self; + for artifact in artifacts { + acc.accumulate(artifact.input_n_variables, artifact.input_numerators_claim); + } + } +} + +pub fn generate_lookup_instances( + lookup_data: XorTableLookupData, + lookup_elements: &XorElements, + collection_for_univariate_iop: &mut MleCollection, +) -> Vec> { + let _span = span!(Level::INFO, "Xor interaction trace").entered(); + let mut xor_lookup_layers = Vec::new(); + + // There are 2^(2*EXPAND_BITS) columns, for each combination of ah, bh. + for (column_index, mults) in lookup_data.xor_accum.mults.iter().enumerate() { + let numerators = Mle::::new(mults.clone()); + collection_for_univariate_iop.push(numerators.clone()); + let denominators = + gen_logup_denominators_mle::(column_index, lookup_elements); + xor_lookup_layers.push(Layer::LogUpMultiplicities { + numerators, + denominators, + }); + } + + xor_lookup_layers +} + +/// Returns an MLE representing the LogUp denominator terms for the xor table. +fn gen_logup_denominators_mle( + column_index: usize, + lookup_elements: &XorElements, +) -> Mle { + let offsets_vec = u32x16::from_array(array::from_fn(|i| i as u32)); + let column_bits = column_bits::(); + let column_size = 1 << column_bits; + let mut denominators = Mle::::new(SecureColumn::zeros(column_size)); + + // Extract ah, bh from column index. + let ah = column_index as u32 >> EXPAND_BITS; + let bh = column_index as u32 & ((1 << EXPAND_BITS) - 1); + + // Each column has 2^(2*LIMB_BITS) rows, packed in N_LANES. + for vec_row in 0..1 << (column_bits - LOG_N_LANES) { + let limb_bits = limb_bits::(); + + // vec_row is LIMB_BITS of al and LIMB_BITS - LOG_N_LANES of bl. + // Extract al, blh from vec_row. + let al = vec_row >> (limb_bits - LOG_N_LANES); + let blh = vec_row & ((1 << (limb_bits - LOG_N_LANES)) - 1); + + // Construct the 3 vectors a, b, c. + let a = u32x16::splat((ah << limb_bits) | al); + // bll is just the consecutive numbers 0..N_LANES-1. + let b = u32x16::splat((bh << limb_bits) | (blh << LOG_N_LANES)) | offsets_vec; + let c = a ^ b; + + let denominator = lookup_elements + .combine(&[a, b, c].map(|v| unsafe { PackedBaseField::from_simd_unchecked(v) })); + denominators.data[vec_row as usize] = denominator; + } + + denominators +} + +/// Evaluates the succinct MLE representing the LogUp denominator terms for the xor table. +/// +/// Evaluates the MLE returned by [`gen_logup_denominators_mle`]. +fn eval_logup_denominators_mle( + column_index: usize, + lookup_elements: &XorElements, + eval_point: &[SecureField], +) -> Result { + assert!(column_index < 1 << (2 * EXPAND_BITS)); + let limb_bits = limb_bits::() as usize; + if eval_point.len() != limb_bits * 2 { + return Err(InvalidEvalPoint); + } + + let (al_assignment, bl_assignment) = eval_point.split_at(limb_bits); + let cl_assignment = &zip(al_assignment, bl_assignment) + // Note `a ^ b = a + b - 2 * a * b` for all `a, b` in `{0, 1}`. + .map(|(&li, &ri)| li + ri - (li * ri).double()) + .collect_vec(); + + let al = pack_little_endian_bits(al_assignment); + let bl = pack_little_endian_bits(bl_assignment); + let cl = pack_little_endian_bits(cl_assignment); + + // Extract ah, bh from column index. + let ah = column_index >> EXPAND_BITS; + let bh = column_index & ((1 << EXPAND_BITS) - 1); + let ch = ah ^ bh; + + let a = al + BaseField::from(ah << limb_bits); + let b = bl + BaseField::from(bh << limb_bits); + let c = cl + BaseField::from(ch << limb_bits); + + Ok(lookup_elements.combine(&[a, b, c])) +} + +fn pack_little_endian_bits(bits: &[SecureField]) -> SecureField { + bits.iter() + .fold(SecureField::zero(), |acc, &bit| acc.double() + bit) +} + +/// Eval point is invalid. +#[derive(Debug)] +struct InvalidEvalPoint; + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + use crate::core::channel::Channel; + use crate::core::test_utils::test_channel; + use crate::examples::blake::xor_table::XorElements; + use crate::examples::blake_gkr::xor_table::{ + eval_logup_denominators_mle, gen_logup_denominators_mle, + }; + + #[test] + fn eval_logup_denominators_mle_works() { + const ELEM_BITS: u32 = 8; + const EXPAND_BITS: u32 = 2; + let column_index = 0b1011; + assert!((0..1 << (2 * EXPAND_BITS)).contains(&column_index)); + let channel = &mut test_channel(); + let lookup_elements = XorElements::draw(channel); + let denominators_mle = + gen_logup_denominators_mle::(column_index, &lookup_elements); + let eval_point = (0..denominators_mle.n_variables()) + .map(|_| channel.draw_felt()) + .collect_vec(); + + let eval = eval_logup_denominators_mle::( + column_index, + &lookup_elements, + &eval_point, + ) + .unwrap(); + + assert_eq!(eval, denominators_mle.eval_at_point(&eval_point)); + } +} diff --git a/crates/prover/src/examples/mod.rs b/crates/prover/src/examples/mod.rs index 330662de9..0ad6f301a 100644 --- a/crates/prover/src/examples/mod.rs +++ b/crates/prover/src/examples/mod.rs @@ -1,5 +1,5 @@ pub mod blake; +pub mod blake_gkr; pub mod plonk; pub mod poseidon; pub mod wide_fibonacci; -pub mod xor; diff --git a/crates/prover/src/examples/xor/gkr_lookups/mod.rs b/crates/prover/src/examples/xor/gkr_lookups/mod.rs deleted file mode 100644 index 6ee603eb0..000000000 --- a/crates/prover/src/examples/xor/gkr_lookups/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod accumulation; -pub mod mle_eval; diff --git a/crates/prover/src/examples/xor/mod.rs b/crates/prover/src/examples/xor/mod.rs deleted file mode 100644 index 34e702a9b..000000000 --- a/crates/prover/src/examples/xor/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod gkr_lookups;