Skip to content

Commit

Permalink
Add build_trace functions for MLE eval component
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Aug 26, 2024
1 parent 8b554c7 commit 1093cee
Showing 1 changed file with 146 additions and 99 deletions.
245 changes: 146 additions & 99 deletions crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,28 @@
#![allow(dead_code)]

use std::array;
use std::iter::zip;

use itertools::{chain, zip_eq};
use num_traits::{One, Zero};

use crate::constraint_framework::constant_columns::gen_is_first;
use crate::constraint_framework::EvalAtRow;
use crate::core::backend::simd::column::SecureColumn;
use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum;
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Col, Column};
use crate::core::circle::{CirclePoint, Coset};
use crate::core::constraints::{coset_vanishing, point_vanishing};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::gkr_prover::GkrOps;
use crate::core::lookups::mle::Mle;
use crate::core::lookups::utils::eq;
use crate::core::poly::circle::{CanonicCoset, SecureEvaluation};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};

Expand Down Expand Up @@ -121,6 +130,64 @@ fn eval_prefix_sum_constraints<E: EvalAtRow>(
eval.add_constraint(curr - prev - row_diff + cumulative_sum_shift);
}

/// Generates a trace.
///
/// Trace structure:
///
/// ```text
/// ---------------------------------------------------------
/// | EqEvals (basis) | MLE terms (prefix sum) |
/// ---------------------------------------------------------
/// | c0 | c1 | c2 | c3 | c4 | c5 | c6 | c7 |
/// ---------------------------------------------------------
/// ```
pub fn build_trace(
mle: &Mle<SimdBackend, SecureField>,
eval_point: &[SecureField],
claim: SecureField,
) -> Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
let eq_evals = SimdBackend::gen_eq_evals(eval_point, SecureField::one()).into_evals();
let mle_terms = hadamard_product(mle, &eq_evals);

let eq_evals_cols = eq_evals.into_secure_column_by_coords().columns;
let mle_terms_cols = mle_terms.into_secure_column_by_coords().columns;

#[cfg(test)]
assert_eq!(claim, mle.eval_at_point(eval_point));
let shift = claim / BaseField::from(mle.len());
let packed_shift_coords = PackedSecureField::broadcast(shift).into_packed_m31s();
let mut shifted_mle_terms_cols = mle_terms_cols.clone();
zip(&mut shifted_mle_terms_cols, packed_shift_coords)
.for_each(|(col, shift_coord)| col.data.iter_mut().for_each(|v| *v -= shift_coord));
let shifted_prefix_sum_cols = shifted_mle_terms_cols.map(inclusive_prefix_sum);

let log_trace_domain_size = mle.n_variables() as u32;
let trace_domain = CanonicCoset::new(log_trace_domain_size).circle_domain();

chain![eq_evals_cols, shifted_prefix_sum_cols]
.map(|c| CircleEvaluation::new(trace_domain, c))
.collect()
}

/// Generates a trace.
///
/// Trace structure:
/// 1. Is first selector column (see [gen_is_first]).
/// 2. Eq carry quotients column (see [gen_carry_quotient_trace]).
///
/// ```text
/// ------------------------------------------------
/// | is first selector | eq carry quotients |
/// ------------------------------------------------
/// | c0 | c1 | c2 | c3 | c4 |
/// ------------------------------------------------
/// ```
pub fn build_constant_trace<const N_VARIABLES: usize>(
) -> Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
let log_size = N_VARIABLES as u32;
vec![gen_is_first(log_size)]
}

/// Returns succinct Eq carry quotients column.
///
/// Given column `c(P)` defined on a [`CircleDomain`] `D = +-C`, and an MLE eval point
Expand All @@ -130,10 +197,11 @@ fn eval_prefix_sum_constraints<E: EvalAtRow>(
///
/// [`CircleDomain`]: crate::core::poly::circle::CircleDomain
fn gen_carry_quotient_col<const N_VARIABLES: usize>(
eval_point: &MleEvalPoint<N_VARIABLES>,
eval_point: &[SecureField; N_VARIABLES],
) -> SecureEvaluation<SimdBackend, BitReversedOrder> {
let mle_eval_point = MleEvalPoint::new(*eval_point);
let (half_coset0_carry_quotients, half_coset1_carry_quotients) =
gen_half_coset_carry_quotients(eval_point);
gen_half_coset_carry_quotients(&mle_eval_point);

let log_size = N_VARIABLES as u32;
let size = 1 << log_size;
Expand Down Expand Up @@ -236,12 +304,24 @@ fn gen_half_coset_carry_quotients<const N_VARIABLES: usize>(
(half_coset0_carry_quotients, half_coset1_carry_quotients)
}

/// Returns the element-wise product of `a` and `b`.
fn hadamard_product(
a: &Col<SimdBackend, SecureField>,
b: &Col<SimdBackend, SecureField>,
) -> Col<SimdBackend, SecureField> {
assert_eq!(a.len(), b.len());
SecureColumn {
data: zip_eq(&a.data, &b.data).map(|(&a, &b)| a * b).collect(),
length: a.len(),
}
}

#[cfg(test)]
mod tests {
use std::array;
use std::iter::{repeat, zip};

use itertools::{chain, zip_eq, Itertools};
use itertools::{chain, Itertools};
use num_traits::One;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
Expand All @@ -250,50 +330,56 @@ mod tests {
eval_carry_quotient_col, eval_eq_constraints, eval_mle_eval_constraints,
eval_prefix_sum_constraints, gen_carry_quotient_col, MleEvalPoint,
};
use crate::constraint_framework::constant_columns::{gen_is_first, gen_is_step_with_offset};
use crate::constraint_framework::constant_columns::gen_is_step_with_offset;
use crate::constraint_framework::{assert_constraints, EvalAtRow};
use crate::core::backend::simd::column::SecureColumn;
use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum;
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Col, Column};
use crate::core::circle::SECURE_FIELD_CIRCLE_GEN;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::lookups::gkr_prover::GkrOps;
use crate::core::lookups::mle::Mle;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order};
use crate::examples::xor::gkr_lookups::mle_eval::eval_step_selector_with_offset;
use crate::examples::xor::gkr_lookups::mle_eval::{
build_constant_trace, build_trace, eval_step_selector_with_offset,
};

#[test]
fn test_mle_eval_constraints_with_log_size_5() {
const N_VARIABLES: usize = 5;
const EVAL_TRACE: usize = 0;
const CARRY_QUOTIENTS_TRACE: usize = 1;
const CONST_TRACE: usize = 2;
const COEFFS_COL_TRACE: usize = 0;
const EVAL_TRACE: usize = 1;
const CARRY_QUOTIENTS_TRACE: usize = 2;
const CONST_TRACE: usize = 3;
let mut rng = SmallRng::seed_from_u64(0);
let log_size = N_VARIABLES as u32;
let size = 1 << log_size;
let mle = Mle::new((0..size).map(|_| rng.gen::<SecureField>()).collect());
let mle_coeffs = (0..size).map(|_| rng.gen::<SecureField>()).collect();
let mle = Mle::<SimdBackend, SecureField>::new(mle_coeffs);
let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen());
let mle_eval_point = MleEvalPoint::new(eval_point);
let base_trace = gen_base_trace(&mle, &eval_point);
let claim = mle.eval_at_point(&eval_point);
let mle_eval_point = MleEvalPoint::new(eval_point);
let mle_eval_trace = build_trace(&mle, &eval_point, claim);
let mle_coeffs_col_trace = build_mle_coeffs_trace(mle);
let claim_shift = claim / BaseField::from(size);
let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point)
.into_coordinate_evals()
.to_vec();
let constants_trace = gen_constants_trace::<N_VARIABLES>();
let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]);
let carry_quotients_col = gen_carry_quotient_col(&eval_point);
let carry_quotients_trace = carry_quotients_col.into_coordinate_evals().to_vec();
let constants_trace = build_constant_trace::<N_VARIABLES>();
let traces = TreeVec::new(vec![
mle_coeffs_col_trace,
mle_eval_trace,
carry_quotients_trace,
constants_trace,
]);
let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect());
let trace_domain = CanonicCoset::new(log_size);

assert_constraints(&trace_polys, trace_domain, |mut eval| {
let [mle_coeff_col_eval] = eval.next_extension_interaction_mask(EVAL_TRACE, [0]);
let [mle_coeff_col_eval] = eval.next_extension_interaction_mask(COEFFS_COL_TRACE, [0]);
let [carry_quotients_col_eval] =
eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]);
eval_mle_eval_constraints(
Expand All @@ -319,17 +405,15 @@ mod tests {
let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect());
let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen());
let mle_eval_point = MleEvalPoint::new(eval_point);
let base_trace = gen_base_trace(&mle, &eval_point);
let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point)
.into_coordinate_evals()
.to_vec();
let constants_trace = gen_constants_trace::<N_VARIABLES>();
let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]);
let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point));
let carry_quotients_col = gen_carry_quotient_col(&eval_point);
let carry_quotients_trace = carry_quotients_col.into_coordinate_evals().to_vec();
let constants_trace = build_constant_trace::<N_VARIABLES>();
let traces = TreeVec::new(vec![trace, carry_quotients_trace, constants_trace]);
let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect());
let trace_domain = CanonicCoset::new(eval_point.len() as u32);

assert_constraints(&trace_polys, trace_domain, |mut eval| {
let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]);
let [carry_quotients_col_eval] =
eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]);
eval_eq_constraints(
Expand All @@ -352,17 +436,15 @@ mod tests {
let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect());
let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen());
let mle_eval_point = MleEvalPoint::new(eval_point);
let base_trace = gen_base_trace(&mle, &eval_point);
let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point)
.into_coordinate_evals()
.to_vec();
let constants_trace = gen_constants_trace::<N_VARIABLES>();
let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]);
let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point));
let carry_quotients_col = gen_carry_quotient_col(&eval_point);
let carry_quotients_trace = carry_quotients_col.into_coordinate_evals().to_vec();
let constants_trace = build_constant_trace::<N_VARIABLES>();
let traces = TreeVec::new(vec![trace, carry_quotients_trace, constants_trace]);
let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect());
let trace_domain = CanonicCoset::new(eval_point.len() as u32);

assert_constraints(&trace_polys, trace_domain, |mut eval| {
let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]);
let [carry_quotients_col_eval] =
eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]);
eval_eq_constraints(
Expand All @@ -385,17 +467,15 @@ mod tests {
let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect());
let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen());
let mle_eval_point = MleEvalPoint::new(eval_point);
let base_trace = gen_base_trace(&mle, &eval_point);
let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point)
.into_coordinate_evals()
.to_vec();
let constants_trace = gen_constants_trace::<N_VARIABLES>();
let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]);
let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point));
let carry_quotients_col = gen_carry_quotient_col(&eval_point);
let carry_quotients_trace = carry_quotients_col.into_coordinate_evals().to_vec();
let constants_trace = build_constant_trace::<N_VARIABLES>();
let traces = TreeVec::new(vec![trace, carry_quotients_trace, constants_trace]);
let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect());
let trace_domain = CanonicCoset::new(eval_point.len() as u32);

assert_constraints(&trace_polys, trace_domain, |mut eval| {
let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]);
let [carry_quotients_col_eval] =
eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]);
eval_eq_constraints(
Expand Down Expand Up @@ -444,8 +524,9 @@ mod tests {
fn eval_carry_quotient_col_works() {
const N_VARIABLES: usize = 5;
let mut rng = SmallRng::seed_from_u64(0);
let mle_eval_point = MleEvalPoint::<N_VARIABLES>::new(array::from_fn(|_| rng.gen()));
let col_eval = gen_carry_quotient_col(&mle_eval_point);
let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen());
let mle_eval_point = MleEvalPoint::new(eval_point);
let col_eval = gen_carry_quotient_col(&eval_point);
let twiddles = SimdBackend::precompute_twiddles(col_eval.domain.half_coset);
let col_poly = col_eval.interpolate_with_twiddles(&twiddles);
let p = SECURE_FIELD_CIRCLE_GEN;
Expand All @@ -455,45 +536,6 @@ mod tests {
assert_eq!(eval, col_poly.eval_at_point(p));
}

/// Generates a trace.
///
/// Trace structure:
///
/// ```text
/// -------------------------------------------------------------------------------------
/// | MLE coeffs | EqEvals (basis) | MLE terms (prefix sum) |
/// -------------------------------------------------------------------------------------
/// | c0 | c1 | c2 | c3 | c4 | c5 | c6 | c7 | c9 | c9 | c10 | c11 |
/// -------------------------------------------------------------------------------------
/// ```
fn gen_base_trace(
mle: &Mle<SimdBackend, SecureField>,
eval_point: &[SecureField],
) -> Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
let mle_coeffs = mle.clone().into_evals();
let eq_evals = SimdBackend::gen_eq_evals(eval_point, SecureField::one()).into_evals();
let mle_terms = hadamard_product(&mle_coeffs, &eq_evals);

let mle_coeff_cols = mle_coeffs.into_secure_column_by_coords().columns;
let eq_evals_cols = eq_evals.into_secure_column_by_coords().columns;
let mle_terms_cols = mle_terms.into_secure_column_by_coords().columns;

let claim = mle.eval_at_point(eval_point);
let shift = claim / BaseField::from(mle.len());
let packed_shifts = PackedSecureField::broadcast(shift).into_packed_m31s();
let mut shifted_mle_terms_cols = mle_terms_cols.clone();
zip(&mut shifted_mle_terms_cols, packed_shifts)
.for_each(|(col, shift)| col.data.iter_mut().for_each(|v| *v -= shift));
let shifted_prefix_sum_cols = shifted_mle_terms_cols.map(inclusive_prefix_sum);

let log_trace_domain_size = mle.n_variables() as u32;
let trace_domain = CanonicCoset::new(log_trace_domain_size).circle_domain();

chain![mle_coeff_cols, eq_evals_cols, shifted_prefix_sum_cols]
.map(|c| CircleEvaluation::new(trace_domain, c))
.collect()
}

/// Generates a trace.
///
/// Trace structure:
Expand Down Expand Up @@ -534,21 +576,26 @@ mod tests {
.collect()
}

/// Returns the element-wise product of `a` and `b`.
fn hadamard_product(
a: &Col<SimdBackend, SecureField>,
b: &Col<SimdBackend, SecureField>,
) -> Col<SimdBackend, SecureField> {
assert_eq!(a.len(), b.len());
SecureColumn {
data: zip_eq(&a.data, &b.data).map(|(&a, &b)| a * b).collect(),
length: a.len(),
}
}

fn gen_constants_trace<const N_VARIABLES: usize>(
/// Generates a trace.
///
/// Trace structure:
///
/// ```text
/// -----------------------------
/// | MLE coeffs col |
/// -----------------------------
/// | c0 | c1 | c2 | c3 |
/// -----------------------------
/// ```
fn build_mle_coeffs_trace(
mle: Mle<SimdBackend, SecureField>,
) -> Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
let log_size = N_VARIABLES as u32;
vec![gen_is_first(log_size)]
let log_size = mle.n_variables() as u32;
let trace_domain = CanonicCoset::new(log_size).circle_domain();
let mle_coeffs_col_by_coords = mle.into_evals().into_secure_column_by_coords();
SecureEvaluation::new(trace_domain, mle_coeffs_col_by_coords)
.into_coordinate_evals()
.into_iter()
.collect()
}
}

0 comments on commit 1093cee

Please sign in to comment.