Skip to content

Commit

Permalink
Combine eq step constraints into single constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Aug 23, 2024
1 parent f16ba08 commit 42f2f6f
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 49 deletions.
3 changes: 2 additions & 1 deletion crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub trait EvalAtRow {
+ Debug
+ Zero
+ Neg<Output = Self::F>
+ AddAssign<Self::F>
+ AddAssign
+ AddAssign<BaseField>
+ Add<Self::F, Output = Self::F>
+ Sub<Self::F, Output = Self::F>
Expand All @@ -52,6 +52,7 @@ pub trait EvalAtRow {
+ Zero
+ From<Self::F>
+ Neg<Output = Self::EF>
+ AddAssign
+ Add<SecureField, Output = Self::EF>
+ Sub<SecureField, Output = Self::EF>
+ Mul<SecureField, Output = Self::EF>
Expand Down
7 changes: 7 additions & 0 deletions crates/prover/src/core/poly/circle/secure_poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ impl<B: FieldOps<BaseField>, EvalOrder> SecureEvaluation<B, EvalOrder> {
_eval_order: PhantomData,
}
}

pub fn into_coordinate_evals(
self,
) -> [CircleEvaluation<B, BaseField, EvalOrder>; SECURE_EXTENSION_DEGREE] {
let Self { domain, values, .. } = self;
values.columns.map(|c| CircleEvaluation::new(domain, c))
}
}

impl<B: FieldOps<BaseField>, EvalOrder> Deref for SecureEvaluation<B, EvalOrder> {
Expand Down
123 changes: 75 additions & 48 deletions crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,28 @@ use std::array;
use num_traits::{One, Zero};

use crate::constraint_framework::EvalAtRow;
use crate::core::backend::simd::SimdBackend;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::eq;
use crate::core::poly::circle::{CanonicCoset, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};

/// Evaluates constraints that guarantee an MLE evaluates to a claim at a given point.
///
/// `mle_coeffs_col_eval` should be the evaluation of the column containing the coefficients of the
/// MLE in the multilinear Lagrange basis. `mle_claim_shift` should equal `claim / 2^N_VARIABLES`.
pub fn eval_mle_eval_constraints<E: EvalAtRow, const N_VARIABLES: usize>(
mle_interaction: usize,
selector_interaction: usize,
const_interaction: usize,
eval: &mut E,
mle_coeffs_col_eval: E::EF,
mle_eval_point: MleEvalPoint<N_VARIABLES>,
mle_claim_shift: SecureField,
) {
let eq_col_eval =
eval_eq_constraints(mle_interaction, selector_interaction, eval, mle_eval_point);
let eq_col_eval = eval_eq_constraints(mle_interaction, const_interaction, eval, mle_eval_point);
let terms_col_eval = mle_coeffs_col_eval * eq_col_eval;
eval_prefix_sum_constraints(mle_interaction, eval, terms_col_eval, mle_claim_shift)
}
Expand All @@ -34,7 +39,7 @@ pub struct MleEvalPoint<const N_VARIABLES: usize> {
// Index `i` stores `eq(({1}^|i|, 0), p[0..i+1]) / eq(({0}^|i|, 1), p[0..i+1])`.
eq_carry_quotients: [SecureField; N_VARIABLES],
// Point `p`.
_p: [SecureField; N_VARIABLES],
p: [SecureField; N_VARIABLES],
}

impl<const N_VARIABLES: usize> MleEvalPoint<N_VARIABLES> {
Expand All @@ -53,7 +58,7 @@ impl<const N_VARIABLES: usize> MleEvalPoint<N_VARIABLES> {
denom_assignment[i] = one;
eq(&numer_assignment, &p[..i + 1]) / eq(&denom_assignment, &p[..i + 1])
}),
_p: p,
p,
}
}
}
Expand All @@ -68,12 +73,12 @@ impl<const N_VARIABLES: usize> MleEvalPoint<N_VARIABLES> {
/// See <https://eprint.iacr.org/2023/1284.pdf> (Section 5.1).
fn eval_eq_constraints<E: EvalAtRow, const N_VARIABLES: usize>(
eq_interaction: usize,
selector_interaction: usize,
const_interaction: usize,
eval: &mut E,
mle_eval_point: MleEvalPoint<N_VARIABLES>,
) -> E::EF {
let [curr, next_next] = eval.next_extension_interaction_mask(eq_interaction, [0, 2]);
let [is_first, is_second] = eval.next_interaction_mask(selector_interaction, [0, -1]);
let [is_first, is_second] = eval.next_interaction_mask(const_interaction, [0, -1]);

// Check the initial value on half_coset0 and final value on half_coset1.
// Combining these constraints is safe because `is_first` and `is_second` are never
Expand All @@ -82,20 +87,9 @@ fn eval_eq_constraints<E: EvalAtRow, const N_VARIABLES: usize>(
let half_coset1_final_check = (curr - mle_eval_point.eq_1_p) * is_second;
eval.add_constraint(half_coset0_initial_check + half_coset1_final_check);

// Check all variables except the last (last variable is handled by the constraint above).
#[allow(clippy::needless_range_loop)]
for variable_i in 0..N_VARIABLES.saturating_sub(1) {
let half_coset0_next = next_next;
let half_coset1_prev = next_next;
let [half_coset0_step, half_coset1_step] =
eval.next_interaction_mask(selector_interaction, [0, -1]);
let carry_quotient = mle_eval_point.eq_carry_quotients[variable_i];
// Safe to combine these constraints as `is_step.half_coset0` and `is_step.half_coset1`
// are never non-zero at the same time on the trace.
let half_coset0_check = (curr - half_coset0_next * carry_quotient) * half_coset0_step;
let half_coset1_check = (curr * carry_quotient - half_coset1_prev) * half_coset1_step;
eval.add_constraint(half_coset0_check + half_coset1_check);
}
// Check all the steps.
let [carry_quotient] = eval.next_extension_interaction_mask(const_interaction, [0]);
eval.add_constraint(curr - next_next * carry_quotient);

curr
}
Expand All @@ -114,6 +108,50 @@ fn eval_prefix_sum_constraints<E: EvalAtRow>(
eval.add_constraint(curr - prev - row_diff + cumulative_sum_shift);
}

/// Returns succinct Eq carry quotients column.
///
/// Given column `c(P)` defined on a [`CircleDomain`] `D = +-C`, and an MLE eval point
/// `(r0, r1, ...)` let `c(D[b0, b1, ...]) = eq((b0, b1, ...), (r0, r1, ...))`. This function
/// returns column `q(P)` such that all `c(C[i]) = c(C[i + 1]) * q(C[i])` and
/// `c(-C[i]) = c(-C[i + 1]) * q(-C[i])`.
///
/// [`CircleDomain`]: crate::core::poly::circle::CircleDomain
pub fn gen_carry_quotient_trace<const N_VARIABLES: usize>(
eval_point: &MleEvalPoint<N_VARIABLES>,
) -> SecureEvaluation<SimdBackend, BitReversedOrder> {
let last_variable = *eval_point.p.last().unwrap();
let zero = SecureField::zero();
let one = SecureField::one();

let mut half_coset0_carry_quotients = eval_point.eq_carry_quotients;
*half_coset0_carry_quotients.last_mut().unwrap() *=
eq(&[one], &[last_variable]) / eq(&[zero], &[last_variable]);
let half_coset1_carry_quotients = half_coset0_carry_quotients.map(|v| v.inverse());

let log_size = N_VARIABLES as u32;
let size = 1 << log_size;
let half_coset_size = size / 2;
let mut col = SecureColumnByCoords::<SimdBackend>::zeros(size);

// TODO(andrew): Optimize.
for i in 0..half_coset_size {
let half_coset0_index = coset_index_to_circle_domain_index(i * 2, log_size);
let half_coset1_index = coset_index_to_circle_domain_index(i * 2 + 1, log_size);
let half_coset0_index_bit_rev = bit_reverse_index(half_coset0_index, log_size);
let half_coset1_index_bit_rev = bit_reverse_index(half_coset1_index, log_size);

let n_trailing_ones = i.trailing_ones() as usize;
let half_coset0_carry_quotient = half_coset0_carry_quotients[n_trailing_ones];
let half_coset1_carry_quotient = half_coset1_carry_quotients[n_trailing_ones];

col.set(half_coset0_index_bit_rev, half_coset0_carry_quotient);
col.set(half_coset1_index_bit_rev, half_coset1_carry_quotient);
}

let domain = CanonicCoset::new(log_size).circle_domain();
SecureEvaluation::new(domain, col)
}

#[cfg(test)]
mod tests {
use std::array;
Expand All @@ -125,9 +163,10 @@ mod tests {
use rand::{Rng, SeedableRng};

use super::{
eval_eq_constraints, eval_mle_eval_constraints, eval_prefix_sum_constraints, MleEvalPoint,
eval_eq_constraints, eval_mle_eval_constraints, eval_prefix_sum_constraints,
gen_carry_quotient_trace, MleEvalPoint,
};
use crate::constraint_framework::constant_columns::{gen_is_first, gen_is_step_with_offset};
use crate::constraint_framework::constant_columns::gen_is_first;
use crate::constraint_framework::{assert_constraints, EvalAtRow};
use crate::core::backend::simd::column::SecureColumn;
use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum;
Expand Down Expand Up @@ -159,7 +198,7 @@ mod tests {
let base_trace = gen_base_trace(&mle, &eval_point);
let claim = mle.eval_at_point(&eval_point);
let claim_shift = claim / BaseField::from(size);
let constants_trace = gen_constants_trace(N_VARIABLES);
let constants_trace = gen_constants_trace(&mle_eval_point);
let traces = TreeVec::new(vec![base_trace, constants_trace]);
let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect());
let trace_domain = CanonicCoset::new(log_size);
Expand All @@ -184,12 +223,12 @@ mod tests {
let mut rng = SmallRng::seed_from_u64(0);
let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect());
let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen());
let mle_eval_point = MleEvalPoint::new(eval_point);
let base_trace = gen_base_trace(&mle, &eval_point);
let constants_trace = gen_constants_trace(N_VARIABLES);
let constants_trace = gen_constants_trace(&mle_eval_point);
let traces = TreeVec::new(vec![base_trace, constants_trace]);
let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect());
let trace_domain = CanonicCoset::new(eval_point.len() as u32);
let mle_eval_point = MleEvalPoint::new(eval_point);

assert_constraints(&trace_polys, trace_domain, |mut eval| {
let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]);
Expand All @@ -203,12 +242,12 @@ mod tests {
let mut rng = SmallRng::seed_from_u64(0);
let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect());
let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen());
let mle_eval_point = MleEvalPoint::new(eval_point);
let base_trace = gen_base_trace(&mle, &eval_point);
let constants_trace = gen_constants_trace(N_VARIABLES);
let constants_trace = gen_constants_trace(&mle_eval_point);
let traces = TreeVec::new(vec![base_trace, constants_trace]);
let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect());
let trace_domain = CanonicCoset::new(eval_point.len() as u32);
let mle_eval_point = MleEvalPoint::new(eval_point);

assert_constraints(&trace_polys, trace_domain, |mut eval| {
let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]);
Expand All @@ -222,12 +261,12 @@ mod tests {
let mut rng = SmallRng::seed_from_u64(0);
let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect());
let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen());
let mle_eval_point = MleEvalPoint::<N_VARIABLES>::new(eval_point);
let base_trace = gen_base_trace(&mle, &eval_point);
let constants_trace = gen_constants_trace(N_VARIABLES);
let constants_trace = gen_constants_trace(&mle_eval_point);
let traces = TreeVec::new(vec![base_trace, constants_trace]);
let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect());
let trace_domain = CanonicCoset::new(eval_point.len() as u32);
let mle_eval_point = MleEvalPoint::new(eval_point);
let trace_domain = CanonicCoset::new(N_VARIABLES as u32);

assert_constraints(&trace_polys, trace_domain, |mut eval| {
let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]);
Expand Down Expand Up @@ -258,7 +297,7 @@ mod tests {
///
/// ```text
/// -------------------------------------------------------------------------------------
/// | MLE coeffs | eq evals (basis) | MLE terms (prefix sum) |
/// | MLE coeffs | EqEvals (basis) | MLE terms (prefix sum) |
/// -------------------------------------------------------------------------------------
/// | c0 | c1 | c2 | c3 | c4 | c5 | c6 | c7 | c9 | c9 | c10 | c11 |
/// -------------------------------------------------------------------------------------
Expand Down Expand Up @@ -343,25 +382,13 @@ mod tests {
}
}

fn gen_constants_trace(
n_variables: usize,
fn gen_constants_trace<const N_VARIABLES: usize>(
eval_point: &MleEvalPoint<N_VARIABLES>,
) -> Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
let log_size = n_variables as u32;
let log_size = N_VARIABLES as u32;
let mut constants_trace = Vec::new();
constants_trace.push(gen_is_first(log_size));

// TODO(andrew): Note the last selector column is not needed. The column for `is_first`
// with an offset for each half coset midpoint can be used instead.
for variable_i in 1..n_variables as u32 {
let half_coset_log_step = variable_i;
let half_coset_offset = (1 << (half_coset_log_step - 1)) - 1;

let log_step = half_coset_log_step + 1;
let offset = half_coset_offset * 2;

constants_trace.push(gen_is_step_with_offset(log_size, log_step, offset))
}

constants_trace.extend(gen_carry_quotient_trace(eval_point).into_coordinate_evals());
constants_trace
}
}

0 comments on commit 42f2f6f

Please sign in to comment.