From 42f2f6fb4dd80bbca432c7422eff46cad00f2dd8 Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Wed, 21 Aug 2024 15:05:57 -0400 Subject: [PATCH] Combine eq step constraints into single constraint --- crates/prover/src/constraint_framework/mod.rs | 3 +- .../src/core/poly/circle/secure_poly.rs | 7 + .../src/examples/xor/gkr_lookups/mle_eval.rs | 123 +++++++++++------- 3 files changed, 84 insertions(+), 49 deletions(-) diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 1d893d6ef..f0d6ca9be 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -34,7 +34,7 @@ pub trait EvalAtRow { + Debug + Zero + Neg - + AddAssign + + AddAssign + AddAssign + Add + Sub @@ -52,6 +52,7 @@ pub trait EvalAtRow { + Zero + From + Neg + + AddAssign + Add + Sub + Mul diff --git a/crates/prover/src/core/poly/circle/secure_poly.rs b/crates/prover/src/core/poly/circle/secure_poly.rs index 8482e2971..a503bd2c6 100644 --- a/crates/prover/src/core/poly/circle/secure_poly.rs +++ b/crates/prover/src/core/poly/circle/secure_poly.rs @@ -73,6 +73,13 @@ impl, EvalOrder> SecureEvaluation { _eval_order: PhantomData, } } + + pub fn into_coordinate_evals( + self, + ) -> [CircleEvaluation; SECURE_EXTENSION_DEGREE] { + let Self { domain, values, .. } = self; + values.columns.map(|c| CircleEvaluation::new(domain, c)) + } } impl, EvalOrder> Deref for SecureEvaluation { diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs index ec8537595..09a4be7df 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -4,8 +4,14 @@ use std::array; use num_traits::{One, Zero}; use crate::constraint_framework::EvalAtRow; +use crate::core::backend::simd::SimdBackend; use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SecureColumnByCoords; +use crate::core::fields::FieldExpOps; use crate::core::lookups::utils::eq; +use crate::core::poly::circle::{CanonicCoset, SecureEvaluation}; +use crate::core::poly::BitReversedOrder; +use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; /// Evaluates constraints that guarantee an MLE evaluates to a claim at a given point. /// @@ -13,14 +19,13 @@ use crate::core::lookups::utils::eq; /// MLE in the multilinear Lagrange basis. `mle_claim_shift` should equal `claim / 2^N_VARIABLES`. pub fn eval_mle_eval_constraints( mle_interaction: usize, - selector_interaction: usize, + const_interaction: usize, eval: &mut E, mle_coeffs_col_eval: E::EF, mle_eval_point: MleEvalPoint, mle_claim_shift: SecureField, ) { - let eq_col_eval = - eval_eq_constraints(mle_interaction, selector_interaction, eval, mle_eval_point); + let eq_col_eval = eval_eq_constraints(mle_interaction, const_interaction, eval, mle_eval_point); let terms_col_eval = mle_coeffs_col_eval * eq_col_eval; eval_prefix_sum_constraints(mle_interaction, eval, terms_col_eval, mle_claim_shift) } @@ -34,7 +39,7 @@ pub struct MleEvalPoint { // Index `i` stores `eq(({1}^|i|, 0), p[0..i+1]) / eq(({0}^|i|, 1), p[0..i+1])`. eq_carry_quotients: [SecureField; N_VARIABLES], // Point `p`. - _p: [SecureField; N_VARIABLES], + p: [SecureField; N_VARIABLES], } impl MleEvalPoint { @@ -53,7 +58,7 @@ impl MleEvalPoint { denom_assignment[i] = one; eq(&numer_assignment, &p[..i + 1]) / eq(&denom_assignment, &p[..i + 1]) }), - _p: p, + p, } } } @@ -68,12 +73,12 @@ impl MleEvalPoint { /// See (Section 5.1). fn eval_eq_constraints( eq_interaction: usize, - selector_interaction: usize, + const_interaction: usize, eval: &mut E, mle_eval_point: MleEvalPoint, ) -> E::EF { let [curr, next_next] = eval.next_extension_interaction_mask(eq_interaction, [0, 2]); - let [is_first, is_second] = eval.next_interaction_mask(selector_interaction, [0, -1]); + let [is_first, is_second] = eval.next_interaction_mask(const_interaction, [0, -1]); // Check the initial value on half_coset0 and final value on half_coset1. // Combining these constraints is safe because `is_first` and `is_second` are never @@ -82,20 +87,9 @@ fn eval_eq_constraints( let half_coset1_final_check = (curr - mle_eval_point.eq_1_p) * is_second; eval.add_constraint(half_coset0_initial_check + half_coset1_final_check); - // Check all variables except the last (last variable is handled by the constraint above). - #[allow(clippy::needless_range_loop)] - for variable_i in 0..N_VARIABLES.saturating_sub(1) { - let half_coset0_next = next_next; - let half_coset1_prev = next_next; - let [half_coset0_step, half_coset1_step] = - eval.next_interaction_mask(selector_interaction, [0, -1]); - let carry_quotient = mle_eval_point.eq_carry_quotients[variable_i]; - // Safe to combine these constraints as `is_step.half_coset0` and `is_step.half_coset1` - // are never non-zero at the same time on the trace. - let half_coset0_check = (curr - half_coset0_next * carry_quotient) * half_coset0_step; - let half_coset1_check = (curr * carry_quotient - half_coset1_prev) * half_coset1_step; - eval.add_constraint(half_coset0_check + half_coset1_check); - } + // Check all the steps. + let [carry_quotient] = eval.next_extension_interaction_mask(const_interaction, [0]); + eval.add_constraint(curr - next_next * carry_quotient); curr } @@ -114,6 +108,50 @@ fn eval_prefix_sum_constraints( eval.add_constraint(curr - prev - row_diff + cumulative_sum_shift); } +/// Returns succinct Eq carry quotients column. +/// +/// Given column `c(P)` defined on a [`CircleDomain`] `D = +-C`, and an MLE eval point +/// `(r0, r1, ...)` let `c(D[b0, b1, ...]) = eq((b0, b1, ...), (r0, r1, ...))`. This function +/// returns column `q(P)` such that all `c(C[i]) = c(C[i + 1]) * q(C[i])` and +/// `c(-C[i]) = c(-C[i + 1]) * q(-C[i])`. +/// +/// [`CircleDomain`]: crate::core::poly::circle::CircleDomain +pub fn gen_carry_quotient_trace( + eval_point: &MleEvalPoint, +) -> SecureEvaluation { + let last_variable = *eval_point.p.last().unwrap(); + let zero = SecureField::zero(); + let one = SecureField::one(); + + let mut half_coset0_carry_quotients = eval_point.eq_carry_quotients; + *half_coset0_carry_quotients.last_mut().unwrap() *= + eq(&[one], &[last_variable]) / eq(&[zero], &[last_variable]); + let half_coset1_carry_quotients = half_coset0_carry_quotients.map(|v| v.inverse()); + + let log_size = N_VARIABLES as u32; + let size = 1 << log_size; + let half_coset_size = size / 2; + let mut col = SecureColumnByCoords::::zeros(size); + + // TODO(andrew): Optimize. + for i in 0..half_coset_size { + let half_coset0_index = coset_index_to_circle_domain_index(i * 2, log_size); + let half_coset1_index = coset_index_to_circle_domain_index(i * 2 + 1, log_size); + let half_coset0_index_bit_rev = bit_reverse_index(half_coset0_index, log_size); + let half_coset1_index_bit_rev = bit_reverse_index(half_coset1_index, log_size); + + let n_trailing_ones = i.trailing_ones() as usize; + let half_coset0_carry_quotient = half_coset0_carry_quotients[n_trailing_ones]; + let half_coset1_carry_quotient = half_coset1_carry_quotients[n_trailing_ones]; + + col.set(half_coset0_index_bit_rev, half_coset0_carry_quotient); + col.set(half_coset1_index_bit_rev, half_coset1_carry_quotient); + } + + let domain = CanonicCoset::new(log_size).circle_domain(); + SecureEvaluation::new(domain, col) +} + #[cfg(test)] mod tests { use std::array; @@ -125,9 +163,10 @@ mod tests { use rand::{Rng, SeedableRng}; use super::{ - eval_eq_constraints, eval_mle_eval_constraints, eval_prefix_sum_constraints, MleEvalPoint, + eval_eq_constraints, eval_mle_eval_constraints, eval_prefix_sum_constraints, + gen_carry_quotient_trace, MleEvalPoint, }; - use crate::constraint_framework::constant_columns::{gen_is_first, gen_is_step_with_offset}; + use crate::constraint_framework::constant_columns::gen_is_first; use crate::constraint_framework::{assert_constraints, EvalAtRow}; use crate::core::backend::simd::column::SecureColumn; use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; @@ -159,7 +198,7 @@ mod tests { let base_trace = gen_base_trace(&mle, &eval_point); let claim = mle.eval_at_point(&eval_point); let claim_shift = claim / BaseField::from(size); - let constants_trace = gen_constants_trace(N_VARIABLES); + let constants_trace = gen_constants_trace(&mle_eval_point); let traces = TreeVec::new(vec![base_trace, constants_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); let trace_domain = CanonicCoset::new(log_size); @@ -184,12 +223,12 @@ mod tests { let mut rng = SmallRng::seed_from_u64(0); let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); + let mle_eval_point = MleEvalPoint::new(eval_point); let base_trace = gen_base_trace(&mle, &eval_point); - let constants_trace = gen_constants_trace(N_VARIABLES); + let constants_trace = gen_constants_trace(&mle_eval_point); let traces = TreeVec::new(vec![base_trace, constants_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); let trace_domain = CanonicCoset::new(eval_point.len() as u32); - let mle_eval_point = MleEvalPoint::new(eval_point); assert_constraints(&trace_polys, trace_domain, |mut eval| { let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); @@ -203,12 +242,12 @@ mod tests { let mut rng = SmallRng::seed_from_u64(0); let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); + let mle_eval_point = MleEvalPoint::new(eval_point); let base_trace = gen_base_trace(&mle, &eval_point); - let constants_trace = gen_constants_trace(N_VARIABLES); + let constants_trace = gen_constants_trace(&mle_eval_point); let traces = TreeVec::new(vec![base_trace, constants_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); let trace_domain = CanonicCoset::new(eval_point.len() as u32); - let mle_eval_point = MleEvalPoint::new(eval_point); assert_constraints(&trace_polys, trace_domain, |mut eval| { let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); @@ -222,12 +261,12 @@ mod tests { let mut rng = SmallRng::seed_from_u64(0); let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); + let mle_eval_point = MleEvalPoint::::new(eval_point); let base_trace = gen_base_trace(&mle, &eval_point); - let constants_trace = gen_constants_trace(N_VARIABLES); + let constants_trace = gen_constants_trace(&mle_eval_point); let traces = TreeVec::new(vec![base_trace, constants_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); - let trace_domain = CanonicCoset::new(eval_point.len() as u32); - let mle_eval_point = MleEvalPoint::new(eval_point); + let trace_domain = CanonicCoset::new(N_VARIABLES as u32); assert_constraints(&trace_polys, trace_domain, |mut eval| { let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); @@ -258,7 +297,7 @@ mod tests { /// /// ```text /// ------------------------------------------------------------------------------------- - /// | MLE coeffs | eq evals (basis) | MLE terms (prefix sum) | + /// | MLE coeffs | EqEvals (basis) | MLE terms (prefix sum) | /// ------------------------------------------------------------------------------------- /// | c0 | c1 | c2 | c3 | c4 | c5 | c6 | c7 | c9 | c9 | c10 | c11 | /// ------------------------------------------------------------------------------------- @@ -343,25 +382,13 @@ mod tests { } } - fn gen_constants_trace( - n_variables: usize, + fn gen_constants_trace( + eval_point: &MleEvalPoint, ) -> Vec> { - let log_size = n_variables as u32; + let log_size = N_VARIABLES as u32; let mut constants_trace = Vec::new(); constants_trace.push(gen_is_first(log_size)); - - // TODO(andrew): Note the last selector column is not needed. The column for `is_first` - // with an offset for each half coset midpoint can be used instead. - for variable_i in 1..n_variables as u32 { - let half_coset_log_step = variable_i; - let half_coset_offset = (1 << (half_coset_log_step - 1)) - 1; - - let log_step = half_coset_log_step + 1; - let offset = half_coset_offset * 2; - - constants_trace.push(gen_is_step_with_offset(log_size, log_step, offset)) - } - + constants_trace.extend(gen_carry_quotient_trace(eval_point).into_coordinate_evals()); constants_trace } }