diff --git a/src/core/air/evaluation.rs b/src/core/air/evaluation.rs index c82e21cc2..91239d007 100644 --- a/src/core/air/evaluation.rs +++ b/src/core/air/evaluation.rs @@ -6,14 +6,50 @@ use core::slice; use super::{Component, ComponentTrace, ComponentVisitor}; +use crate::core::backend::cpu::CPUCircleEvaluation; use crate::core::backend::{Backend, CPUBackend}; use crate::core::circle::CirclePoint; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; -use crate::core::fields::{Col, Field}; -use crate::core::poly::circle::{CircleDomain, CircleEvaluation, CirclePoly}; +use crate::core::fields::{Col, Column, ExtensionOf, Field}; +use crate::core::poly::circle::{CircleDomain, CirclePoly}; +use crate::core::utils::IteratorMutExt; use crate::core::{ColumnVec, ComponentVec}; +// TODO(spapini): find a better place for this +pub struct SecureColumn { + pub cols: [Col; >::EXTENSION_DEGREE], +} + +impl SecureColumn { + fn at(&self, index: usize) -> SecureField { + SecureField::from_m31_array(std::array::from_fn(|i| self.cols[i][index])) + } + + fn set(&mut self, index: usize, value: SecureField) { + self.cols + .iter_mut() + .map(|c| &mut c[index]) + .assign(value.to_m31_array()); + } +} + +impl SecureColumn { + pub fn zeros(len: usize) -> Self { + Self { + cols: std::array::from_fn(|_| Col::::zeros(len)), + } + } + + pub fn len(&self) -> usize { + self.cols[0].len() + } + + pub fn is_empty(&self) -> bool { + self.cols[0].is_empty() + } +} + /// Accumulates evaluations of u_i(P0) at a single point. /// Computes f(P0), the combined polynomial at that point. pub struct PointEvaluationAccumulator { @@ -41,6 +77,7 @@ impl PointEvaluationAccumulator { /// Accumulates u_i(P0), a polynomial evaluation at a P0. pub fn accumulate(&mut self, log_size: u32, evaluation: SecureField) { + assert!(log_size > 0 && log_size < self.sub_accumulations.len() as u32); let sub_accumulation = &mut self.sub_accumulations[log_size as usize]; *sub_accumulation = *sub_accumulation * self.random_coeff + evaluation; @@ -69,7 +106,7 @@ pub struct DomainEvaluationAccumulator { /// Accumulated evaluations for each log_size. /// Each `sub_accumulation` holds `sum_{i=0}^{n-1} evaluation_i * alpha^(n-1-i)`, /// where `n` is the number of accumulated evaluations for this log_size. - sub_accumulations: Vec>, + sub_accumulations: Vec>, /// Number of accumulated evaluations for each log_size. n_cols_per_size: Vec, } @@ -83,7 +120,7 @@ impl DomainEvaluationAccumulator { Self { random_coeff, sub_accumulations: (0..(max_log_size + 1)) - .map(|n| Col::::from_iter(vec![SecureField::default(); 1 << n])) + .map(|n| SecureColumn::zeros(1 << n)) .collect(), n_cols_per_size: vec![0; max_log_size + 1], } @@ -99,6 +136,7 @@ impl DomainEvaluationAccumulator { n_cols_per_size: [(u32, usize); N], ) -> [ColumnAccumulator<'_, B>; N] { n_cols_per_size.iter().for_each(|(log_size, n_col)| { + assert!(*log_size > 0 && *log_size < self.sub_accumulations.len() as u32); self.n_cols_per_size[*log_size as usize] += n_col; }); self.sub_accumulations @@ -119,49 +157,50 @@ impl DomainEvaluationAccumulator { impl DomainEvaluationAccumulator { /// Computes f(P) as coefficients. pub fn finalize(self) -> CirclePoly { - let mut res_coeffs = vec![SecureField::default(); 1 << self.log_size()]; + let mut res_coeffs = SecureColumn::::zeros(1 << self.log_size()); let res_log_size = self.log_size(); - for (coeffs, n_cols) in self + let res_size = 1 << res_log_size; + + for ((log_size, values), n_cols) in self .sub_accumulations .into_iter() .enumerate() - .map(|(log_size, values)| { - if log_size == 0 { - return values; - } - CircleEvaluation::::new( - CircleDomain::constraint_evaluation_domain(log_size as u32), - values, - ) - .interpolate() - .extend(res_log_size) - .coeffs - }) .zip(self.n_cols_per_size.iter()) + .skip(1) { + let coeffs = SecureColumn { + cols: values.cols.map(|c| { + CPUCircleEvaluation::new( + CircleDomain::constraint_evaluation_domain(log_size as u32), + c, + ) + .interpolate() + .extend(res_log_size) + .coeffs + }), + }; // Add poly.coeffs into coeffs, elementwise, inplace. let multiplier = self.random_coeff.pow(*n_cols as u128); - res_coeffs - .iter_mut() - .zip(coeffs.iter()) - .for_each(|(res_coeff, current_coeff)| { - *res_coeff = *res_coeff * multiplier + *current_coeff - }); + for i in 0..res_size { + let res_coeff = res_coeffs.at(i) * multiplier + coeffs.at(i); + res_coeffs.set(i, res_coeff); + } } - CirclePoly::new(res_coeffs) + // TODO(spapini): Return multiple polys instead. + CirclePoly::new((0..res_size).map(|i| res_coeffs.at(i)).collect::>()) } } /// An domain accumulator for polynomials of a single size. pub struct ColumnAccumulator<'a, B: Backend> { random_coeff: SecureField, - col: &'a mut Col, + col: &'a mut SecureColumn, } impl<'a> ColumnAccumulator<'a, CPUBackend> { pub fn accumulate(&mut self, index: usize, evaluation: BaseField) { - let accum = &mut self.col[index]; - *accum = *accum * self.random_coeff + evaluation; + let val = self.col.at(index) * self.random_coeff + evaluation; + self.col.set(index, val); } } diff --git a/src/core/backend/cpu/mod.rs b/src/core/backend/cpu/mod.rs index d01669f1d..b340a6140 100644 --- a/src/core/backend/cpu/mod.rs +++ b/src/core/backend/cpu/mod.rs @@ -23,7 +23,7 @@ impl FieldOps for CPUBackend { } } -impl Column for Vec { +impl Column for Vec { fn zeros(len: usize) -> Self { vec![F::zero(); len] } diff --git a/src/core/fields/cm31.rs b/src/core/fields/cm31.rs index 176291034..e3620e5b5 100644 --- a/src/core/fields/cm31.rs +++ b/src/core/fields/cm31.rs @@ -13,7 +13,7 @@ pub const P2: u64 = 4611686014132420609; // (2 ** 31 - 1) ** 2 /// Equivalent to M31\[x\] over (x^2 + 1) as the irreducible polynomial. /// Represented as (a, b) of a + bi. #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct CM31(M31, M31); +pub struct CM31(pub M31, pub M31); impl_field!(CM31, P2); impl_extension_field!(CM31, M31); diff --git a/src/core/fields/mod.rs b/src/core/fields/mod.rs index 417f63065..fe6649d57 100644 --- a/src/core/fields/mod.rs +++ b/src/core/fields/mod.rs @@ -13,8 +13,11 @@ pub trait FieldOps { type Column: Column; fn bit_reverse_column(column: Self::Column) -> Self::Column; } + pub type Col = >::Column; -pub trait Column: Clone + Debug + Index + FromIterator { + +// TODO(spapini): Consider removing the generic parameter and only support BaseField. +pub trait Column: Clone + Debug + Index + FromIterator { fn zeros(len: usize) -> Self; fn to_vec(&self) -> Vec; fn len(&self) -> usize; @@ -96,9 +99,13 @@ pub trait ComplexConjugate { fn complex_conjugate(&self) -> Self; } -pub trait ExtensionOf: Field + From + NumOps + NumAssignOps {} +pub trait ExtensionOf: Field + From + NumOps + NumAssignOps { + const EXTENSION_DEGREE: usize; +} -impl ExtensionOf for F {} +impl ExtensionOf for F { + const EXTENSION_DEGREE: usize = 1; +} #[macro_export] macro_rules! impl_field { @@ -182,7 +189,10 @@ macro_rules! impl_extension_field { ($field_name: ty, $extended_field_name: ty) => { use $crate::core::fields::ExtensionOf; - impl ExtensionOf for $field_name {} + impl ExtensionOf for $field_name { + const EXTENSION_DEGREE: usize = + <$extended_field_name as ExtensionOf>::EXTENSION_DEGREE * 2; + } impl Add for $field_name { type Output = Self; diff --git a/src/core/fields/qm31.rs b/src/core/fields/qm31.rs index c7e58eff1..1deda61b1 100644 --- a/src/core/fields/qm31.rs +++ b/src/core/fields/qm31.rs @@ -37,6 +37,10 @@ impl QM31 { pub fn from_m31_array(array: [M31; 4]) -> Self { Self::from_m31(array[0], array[1], array[2], array[3]) } + + pub fn to_m31_array(self) -> [M31; 4] { + [self.0 .0, self.0 .1, self.1 .0, self.1 .1] + } } impl Display for QM31 { diff --git a/src/core/utils.rs b/src/core/utils.rs index 369a018d6..2839df5c1 100644 --- a/src/core/utils.rs +++ b/src/core/utils.rs @@ -1,3 +1,14 @@ +pub trait IteratorMutExt<'a, T: 'a>: Iterator { + fn assign(self, other: impl IntoIterator) + where + Self: Sized, + { + self.zip(other).for_each(|(a, b)| *a = b); + } +} + +impl<'a, T: 'a, I: Iterator> IteratorMutExt<'a, T> for I {} + pub(crate) fn bit_reverse_index(i: usize, log_size: u32) -> usize { i.reverse_bits() >> (usize::BITS - log_size) }