From 8cc4e988502d750aaa431e3ee3f00f78f6a1f186 Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Tue, 20 Aug 2024 23:01:15 -0400 Subject: [PATCH] Add MleCollection for accumulating MLEs --- crates/prover/src/core/backend/simd/column.rs | 6 + crates/prover/src/core/fields/qm31.rs | 18 ++ crates/prover/src/core/lookups/mle.rs | 8 +- .../examples/xor/gkr_lookups/accumulation.rs | 186 ++++++++++++++++++ .../src/examples/xor/gkr_lookups/mod.rs | 1 + 5 files changed, 218 insertions(+), 1 deletion(-) create mode 100644 crates/prover/src/examples/xor/gkr_lookups/accumulation.rs diff --git a/crates/prover/src/core/backend/simd/column.rs b/crates/prover/src/core/backend/simd/column.rs index db98d2669..64869405a 100644 --- a/crates/prover/src/core/backend/simd/column.rs +++ b/crates/prover/src/core/backend/simd/column.rs @@ -66,6 +66,12 @@ impl BaseColumn { .map(BaseColumnMutSlice) .collect_vec() } + + pub fn into_secure_column(self) -> SecureColumn { + let length = self.len(); + let data = self.data.into_iter().map(PackedSecureField::from).collect(); + SecureColumn { data, length } + } } impl Column for BaseColumn { diff --git a/crates/prover/src/core/fields/qm31.rs b/crates/prover/src/core/fields/qm31.rs index 4467aeb41..6da19a3c0 100644 --- a/crates/prover/src/core/fields/qm31.rs +++ b/crates/prover/src/core/fields/qm31.rs @@ -85,6 +85,24 @@ impl Mul for QM31 { } } +impl From for QM31 { + fn from(value: usize) -> Self { + M31::from(value).into() + } +} + +impl From for QM31 { + fn from(value: u32) -> Self { + M31::from(value).into() + } +} + +impl From for QM31 { + fn from(value: i32) -> Self { + M31::from(value).into() + } +} + impl TryInto for QM31 { type Error = (); diff --git a/crates/prover/src/core/lookups/mle.rs b/crates/prover/src/core/lookups/mle.rs index 7ac7f9eb3..7449f40f2 100644 --- a/crates/prover/src/core/lookups/mle.rs +++ b/crates/prover/src/core/lookups/mle.rs @@ -1,4 +1,4 @@ -use std::ops::Deref; +use std::ops::{Deref, DerefMut}; use educe::Educe; @@ -58,6 +58,12 @@ impl, F: Field> Deref for Mle { } } +impl, F: Field> DerefMut for Mle { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.evals + } +} + #[cfg(test)] mod test { use super::{Mle, MleOps}; diff --git a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs b/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs new file mode 100644 index 000000000..53ae956e6 --- /dev/null +++ b/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs @@ -0,0 +1,186 @@ +use std::iter::zip; +use std::ops::{AddAssign, Mul}; + +use educe::Educe; +use num_traits::One; + +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Backend; +use crate::core::circle::M31_CIRCLE_LOG_ORDER; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::mle::Mle; +use crate::core::utils::generate_secure_powers; + +pub const MIN_LOG_BLOWUP_FACTOR: u32 = 1; + +/// Max number of variables for multilinear polynomials that get compiled into a univariate +/// IOP for multilinear eval at point. +pub const MAX_MLE_N_VARIABLES: u32 = M31_CIRCLE_LOG_ORDER - MIN_LOG_BLOWUP_FACTOR; + +/// Accumulates [`Mle`]s grouped by their number of variables. +pub struct MleCollection { + mles_by_n_variables: Vec>>>, +} + +impl MleCollection { + /// Appends an [`Mle`] to the collection. + pub fn push(&mut self, mle: impl Into>) { + let mle = mle.into(); + let mles = self.mles_by_n_variables[mle.n_variables()].get_or_insert(Vec::new()); + mles.push(mle); + } +} + +impl MleCollection { + /// Performs a random linear combination of all MLEs, grouped by their number of variables. + /// + /// MLEs are returned in ascending order by number of variables. + pub fn random_linear_combine_by_n_variables( + self, + alpha: SecureField, + ) -> Vec> { + self.mles_by_n_variables + .into_iter() + .flatten() + .map(|mles| mle_random_linear_combination(mles, alpha)) + .collect() + } +} + +/// # Panics +/// +/// Panics if `mles` is empty or all MLEs don't have the same number of variables. +fn mle_random_linear_combination( + mles: Vec>, + alpha: SecureField, +) -> Mle { + assert!(!mles.is_empty()); + let n_variables = mles[0].n_variables(); + assert!(mles.iter().all(|mle| mle.n_variables() == n_variables)); + let alpha_powers = generate_secure_powers(alpha, mles.len()).into_iter().rev(); + let mut mle_and_coeff = zip(mles, alpha_powers); + + // The last value can initialize the accumulator. + let (mle, coeff) = mle_and_coeff.next_back().unwrap(); + assert!(coeff.is_one()); + let mut acc_mle = mle.into_secure_mle(); + + for (mle, coeff) in mle_and_coeff { + match mle { + DynMle::Base(mle) => combine(&mut acc_mle.data, &mle.data, coeff.into()), + DynMle::Secure(mle) => combine(&mut acc_mle.data, &mle.data, coeff.into()), + } + } + + acc_mle +} + +/// Computes all `acc[i] += alpha * v[i]`. +pub fn combine + Copy, F: Copy>( + acc: &mut [EF], + v: &[F], + alpha: EF, +) { + assert_eq!(acc.len(), v.len()); + zip(acc, v).for_each(|(acc, &v)| *acc += alpha * v); +} + +impl Default for MleCollection { + fn default() -> Self { + Self { + mles_by_n_variables: vec![None; MAX_MLE_N_VARIABLES as usize + 1], + } + } +} + +/// Dynamic dispatch for [`Mle`] types. +#[derive(Educe)] +#[educe(Debug, Clone)] +pub enum DynMle { + Base(Mle), + Secure(Mle), +} + +impl DynMle { + fn n_variables(&self) -> usize { + match self { + DynMle::Base(mle) => mle.n_variables(), + DynMle::Secure(mle) => mle.n_variables(), + } + } +} + +impl From> for DynMle { + fn from(mle: Mle) -> Self { + DynMle::Secure(mle) + } +} + +impl From> for DynMle { + fn from(mle: Mle) -> Self { + DynMle::Base(mle) + } +} + +impl DynMle { + fn into_secure_mle(self) -> Mle { + match self { + Self::Base(mle) => Mle::new(mle.into_evals().into_secure_column()), + Self::Secure(mle) => mle, + } + } +} + +#[cfg(test)] +mod tests { + use std::iter::repeat; + + use num_traits::Zero; + + use crate::core::backend::simd::SimdBackend; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::fields::Field; + use crate::core::lookups::mle::{Mle, MleOps}; + use crate::examples::xor::gkr_lookups::accumulation::MleCollection; + + #[test] + fn random_linear_combine_by_n_variables() { + const SMALL_N_VARS: usize = 4; + const LARGE_N_VARS: usize = 6; + let alpha = SecureField::from(10); + let mut mle_collection = MleCollection::::default(); + mle_collection.push(const_mle(SMALL_N_VARS, BaseField::from(1))); + mle_collection.push(const_mle(SMALL_N_VARS, SecureField::from(2))); + mle_collection.push(const_mle(LARGE_N_VARS, BaseField::from(3))); + mle_collection.push(const_mle(LARGE_N_VARS, SecureField::from(4))); + mle_collection.push(const_mle(LARGE_N_VARS, SecureField::from(5))); + let small_eval_point = [SecureField::zero(); SMALL_N_VARS]; + let large_eval_point = [SecureField::zero(); LARGE_N_VARS]; + + let [small_mle, large_mle] = mle_collection + .random_linear_combine_by_n_variables(alpha) + .try_into() + .unwrap(); + + assert_eq!(small_mle.n_variables(), SMALL_N_VARS); + assert_eq!(large_mle.n_variables(), LARGE_N_VARS); + assert_eq!( + small_mle.eval_at_point(&small_eval_point), + SecureField::from(1) * alpha + SecureField::from(2) + ); + assert_eq!( + large_mle.eval_at_point(&large_eval_point), + (SecureField::from(3) * alpha + SecureField::from(4)) * alpha + SecureField::from(5) + ); + } + + fn const_mle(n_variables: usize, v: F) -> Mle + where + B: MleOps, + F: Field, + { + Mle::new(repeat(v).take(1 << n_variables).collect()) + } +} diff --git a/crates/prover/src/examples/xor/gkr_lookups/mod.rs b/crates/prover/src/examples/xor/gkr_lookups/mod.rs index 8c570c3ba..6ee603eb0 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mod.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mod.rs @@ -1 +1,2 @@ +pub mod accumulation; pub mod mle_eval;