diff --git a/src/core/air/accumulation.rs b/src/core/air/accumulation.rs index 338dda855..117eee9d2 100644 --- a/src/core/air/accumulation.rs +++ b/src/core/air/accumulation.rs @@ -138,6 +138,9 @@ impl DomainEvaluationAccumulator { .zip(self.n_cols_per_size.iter()) .skip(1) { + if *n_cols == 0 { + continue; + } let coeffs = SecureColumn:: { columns: values.columns.map(|c| { CircleEvaluation::::new( @@ -160,8 +163,8 @@ impl DomainEvaluationAccumulator { /// An domain accumulator for polynomials of a single size. pub struct ColumnAccumulator<'a, B: Backend> { - random_coeff_pow: SecureField, - col: &'a mut SecureColumn, + pub random_coeff_pow: SecureField, + pub col: &'a mut SecureColumn, } impl<'a> ColumnAccumulator<'a, CPUBackend> { pub fn accumulate(&mut self, index: usize, evaluation: SecureField) { diff --git a/src/core/backend/avx512/accumulation.rs b/src/core/backend/avx512/accumulation.rs index 2191ccc0d..0754e46aa 100644 --- a/src/core/backend/avx512/accumulation.rs +++ b/src/core/backend/avx512/accumulation.rs @@ -7,8 +7,8 @@ use crate::core::fields::secure_column::SecureColumn; impl AccumulationOps for AVX512Backend { fn accumulate(column: &mut SecureColumn, alpha: SecureField, other: &SecureColumn) { let alpha = PackedQM31::broadcast(alpha); - for i in 0..column.len() { - unsafe { + unsafe { + for i in 0..column.n_packs() { let res_coeff = column.get_packed(i) * alpha + other.get_packed(i); column.set_packed(i, res_coeff); } diff --git a/src/core/backend/avx512/mod.rs b/src/core/backend/avx512/mod.rs index 6573ba7eb..fc9dcbaa3 100644 --- a/src/core/backend/avx512/mod.rs +++ b/src/core/backend/avx512/mod.rs @@ -131,6 +131,10 @@ impl FromIterator for BaseFieldVec { } impl SecureColumn { + pub fn n_packs(&self) -> usize { + self.columns[0].data.len() + } + /// # Safety /// /// Calling this method with an out-of-bounds index is undefined behavior. diff --git a/src/fibonacci/air.rs b/src/examples/fibonacci/air.rs similarity index 100% rename from src/fibonacci/air.rs rename to src/examples/fibonacci/air.rs diff --git a/src/fibonacci/component.rs b/src/examples/fibonacci/component.rs similarity index 100% rename from src/fibonacci/component.rs rename to src/examples/fibonacci/component.rs diff --git a/src/fibonacci/mod.rs b/src/examples/fibonacci/mod.rs similarity index 100% rename from src/fibonacci/mod.rs rename to src/examples/fibonacci/mod.rs diff --git a/src/examples/mod.rs b/src/examples/mod.rs new file mode 100644 index 000000000..cc38b619c --- /dev/null +++ b/src/examples/mod.rs @@ -0,0 +1,2 @@ +pub mod fibonacci; +pub mod wide_fib; diff --git a/src/examples/wide_fib/mod.rs b/src/examples/wide_fib/mod.rs new file mode 100644 index 000000000..dd4440117 --- /dev/null +++ b/src/examples/wide_fib/mod.rs @@ -0,0 +1,170 @@ +use itertools::Itertools; +use num_traits::One; + +use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; +use crate::core::air::{Air, Component, ComponentTrace, Mask}; +use crate::core::backend::avx512::qm31::PackedQM31; +use crate::core::backend::avx512::{AVX512Backend, PackedBaseField, VECS_LOG_SIZE}; +use crate::core::backend::{Col, Column}; +use crate::core::circle::{CirclePoint, Coset}; +use crate::core::constraints::coset_vanishing; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::FieldExpOps; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use crate::core::poly::BitReversedOrder; +use crate::core::utils::bit_reverse_index; +use crate::core::ColumnVec; + +const N_COLS: usize = 1 << 8; + +pub struct WideFibAir { + component: WideFibComponent, +} +impl Air for WideFibAir { + fn components(&self) -> Vec<&dyn Component> { + vec![&self.component] + } +} +pub struct WideFibComponent { + pub log_size: u32, +} + +pub fn gen_trace( + log_size: usize, +) -> ColumnVec> { + assert!(log_size >= VECS_LOG_SIZE); + let mut trace = (0..N_COLS) + .map(|_| Col::::zeros(1 << log_size)) + .collect_vec(); + for vec_index in 0..(1 << (log_size - VECS_LOG_SIZE)) { + let mut a = PackedBaseField::one(); + let mut b = PackedBaseField::one(); + trace[0].data[vec_index] = a; + trace[1].data[vec_index] = b; + trace.iter_mut().take(log_size).skip(2).for_each(|col| { + (a, b) = (b, a.square() + b.square()); + col.data[vec_index] = b; + }); + } + let domain = CanonicCoset::new(log_size as u32).circle_domain(); + trace + .into_iter() + .map(|eval| CircleEvaluation::::new(domain, eval)) + .collect_vec() +} + +impl Component for WideFibComponent { + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size + 1 + } + + fn trace_log_degree_bounds(&self) -> Vec { + vec![self.log_size; N_COLS] + } + + fn evaluate_constraint_quotients_on_domain( + &self, + trace: &ComponentTrace<'_, AVX512Backend>, + evaluation_accumulator: &mut DomainEvaluationAccumulator, + ) { + assert_eq!(trace.columns.len(), N_COLS); + // TODO(spapini): Steal evaluation from commitment. + let eval_domain = CanonicCoset::new(self.log_size + 1).circle_domain(); + let trace_eval = trace + .columns + .iter() + .map(|poly| poly.evaluate(eval_domain)) + .collect_vec(); + let random_coeff = PackedQM31::broadcast(evaluation_accumulator.random_coeff); + let column_coeffs = (0..N_COLS) + .scan(PackedQM31::one(), |state, _| { + let res = *state; + *state *= random_coeff; + Some(res) + }) + .collect_vec(); + + let constraint_log_degree_bound = self.log_size + 1; + let [accum] = evaluation_accumulator.columns([(constraint_log_degree_bound, N_COLS - 2)]); + + for vec_row in 0..(1 << (eval_domain.log_size() - VECS_LOG_SIZE as u32)) { + // Numerator. + let mut row_res = PackedQM31::zero(); + let mut a = trace_eval[0].data[vec_row]; + let mut b = trace_eval[1].data[vec_row]; + #[allow(clippy::needless_range_loop)] + for i in 0..(N_COLS - 2) { + unsafe { + let c = *trace_eval.get_unchecked(i + 2).data.get_unchecked(vec_row); + row_res = row_res + column_coeffs[i] * (a.square() + b.square() - c); + (a, b) = (b, c); + } + } + + // Denominator. + // TODO(spapini): Optimized this, for the small number of columns case. + let points = std::array::from_fn(|i| { + eval_domain.at(bit_reverse_index( + (vec_row << VECS_LOG_SIZE) + i, + eval_domain.log_size(), + ) + 1) + }); + let mut shifted_xs = PackedBaseField::from_array(points.map(|p| p.x)); + for _ in 1..self.log_size { + shifted_xs = shifted_xs.square() - PackedBaseField::one(); + } + + unsafe { + accum.col.set_packed( + vec_row, + accum.col.get_packed(vec_row) * PackedQM31::broadcast(accum.random_coeff_pow) + + row_res, + ) + } + } + } + + fn mask(&self) -> Mask { + Mask(vec![vec![0]; N_COLS]) + } + + fn evaluate_constraint_quotients_at_point( + &self, + point: CirclePoint, + mask: &ColumnVec>, + evaluation_accumulator: &mut PointEvaluationAccumulator, + ) { + let constraint_zero_domain = Coset::subgroup(self.log_size); + let constraint_log_degree_bound = self.log_size + 1; + for i in 0..(N_COLS - 2) { + let numerator = mask[i][0].square() + mask[i + 1][0].square() - mask[i + 2][0]; + let denominator = coset_vanishing(constraint_zero_domain, point); + evaluation_accumulator.accumulate(constraint_log_degree_bound, numerator / denominator); + } + } +} + +#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] +#[cfg(test)] +mod tests { + use crate::commitment_scheme::blake2_hash::Blake2sHasher; + use crate::commitment_scheme::hasher::Hasher; + use crate::core::channel::{Blake2sChannel, Channel}; + use crate::core::fields::m31::BaseField; + use crate::core::fields::IntoSlice; + use crate::core::prover::prove; + use crate::examples::wide_fib::{gen_trace, WideFibAir, WideFibComponent}; + + #[test] + fn test_avx_wide_fib_prove() { + // TODO(spapini): Increase to 20, to get 1GB of trace. + const LOG_SIZE: u32 = 12; + let component = WideFibComponent { log_size: LOG_SIZE }; + let air = WideFibAir { component }; + let trace = gen_trace(LOG_SIZE as usize); + let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); + // TODO(spapini): Fix the constraints. + prove(&air, channel, trace).unwrap_err(); + } +} diff --git a/src/lib.rs b/src/lib.rs index 4ee4907df..ac21515a6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,7 +11,7 @@ )] pub mod commitment_scheme; pub mod core; -pub mod fibonacci; +pub mod examples; pub mod hash_functions; pub mod math; pub mod platform;