From 3030189137a892ce9e980be826649c70f4389069 Mon Sep 17 00:00:00 2001 From: Shahar Samocha Date: Sun, 14 Apr 2024 15:31:35 +0300 Subject: [PATCH] Update wide fibonacci to 256 columns --- src/examples/wide_fibonacci/avx.rs | 14 +- .../{structs.rs => component.rs} | 12 + .../wide_fibonacci/constraint_eval.rs | 429 ++---------------- src/examples/wide_fibonacci/mod.rs | 54 ++- src/examples/wide_fibonacci/trace_asserts.rs | 254 ----------- src/examples/wide_fibonacci/trace_gen.rs | 140 +----- 6 files changed, 111 insertions(+), 792 deletions(-) rename src/examples/wide_fibonacci/{structs.rs => component.rs} (50%) delete mode 100644 src/examples/wide_fibonacci/trace_asserts.rs diff --git a/src/examples/wide_fibonacci/avx.rs b/src/examples/wide_fibonacci/avx.rs index f4704cd19..db928c6d3 100644 --- a/src/examples/wide_fibonacci/avx.rs +++ b/src/examples/wide_fibonacci/avx.rs @@ -1,13 +1,13 @@ use itertools::Itertools; use num_traits::{One, Zero}; -use super::structs::WideFibComponent; +use super::component::{WideFibAir, WideFibComponent}; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::mask::fixed_mask_points; use crate::core::air::{Air, Component, ComponentTrace}; use crate::core::backend::avx512::qm31::PackedSecureField; use crate::core::backend::avx512::{AVX512Backend, BaseFieldVec, PackedBaseField, VECS_LOG_SIZE}; -use crate::core::backend::{CPUBackend, Col, Column, ColumnOps}; +use crate::core::backend::{Col, Column, ColumnOps}; use crate::core::circle::CirclePoint; use crate::core::constraints::coset_vanishing; use crate::core::fields::m31::BaseField; @@ -19,19 +19,11 @@ use crate::core::ColumnVec; const N_COLS: usize = 1 << 8; -pub struct WideFibAir { - component: WideFibComponent, -} impl Air for WideFibAir { fn components(&self) -> Vec<&dyn Component> { vec![&self.component] } } -impl Air for WideFibAir { - fn components(&self) -> Vec<&dyn Component> { - vec![&self.component] - } -} pub fn gen_trace( log_size: usize, @@ -162,7 +154,7 @@ mod tests { use crate::core::fields::IntoSlice; use crate::core::prover::{prove, verify}; use crate::examples::wide_fibonacci::avx::{gen_trace, WideFibAir}; - use crate::examples::wide_fibonacci::structs::WideFibComponent; + use crate::examples::wide_fibonacci::component::WideFibComponent; #[test] fn test_avx_wide_fib_prove() { diff --git a/src/examples/wide_fibonacci/structs.rs b/src/examples/wide_fibonacci/component.rs similarity index 50% rename from src/examples/wide_fibonacci/structs.rs rename to src/examples/wide_fibonacci/component.rs index 16cf0de14..def8063da 100644 --- a/src/examples/wide_fibonacci/structs.rs +++ b/src/examples/wide_fibonacci/component.rs @@ -1,3 +1,5 @@ +use crate::core::air::{Air, Component}; +use crate::core::backend::CPUBackend; use crate::core::fields::m31::BaseField; /// Component that computes fibonacci numbers over 64 columns. @@ -5,6 +7,16 @@ pub struct WideFibComponent { pub log_size: u32, } +pub struct WideFibAir { + pub component: WideFibComponent, +} + +impl Air for WideFibAir { + fn components(&self) -> Vec<&dyn Component> { + vec![&self.component] + } +} + // Input for the fibonacci claim. #[derive(Debug, Clone, Copy)] pub struct Input { diff --git a/src/examples/wide_fibonacci/constraint_eval.rs b/src/examples/wide_fibonacci/constraint_eval.rs index 8c3bc48b3..e2d25a5e1 100644 --- a/src/examples/wide_fibonacci/constraint_eval.rs +++ b/src/examples/wide_fibonacci/constraint_eval.rs @@ -1,22 +1,22 @@ -use num_traits::{One, Zero}; +use num_traits::Zero; -use super::structs::WideFibComponent; +use super::component::WideFibComponent; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::mask::fixed_mask_points; use crate::core::air::{Component, ComponentTrace}; -use crate::core::backend::CPUBackend; +use crate::core::backend::{CPUBackend, Column}; use crate::core::circle::CirclePoint; use crate::core::constraints::coset_vanishing; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; use crate::core::poly::circle::CanonicCoset; -use crate::core::utils::bit_reverse_index; +use crate::core::utils::bit_reverse; use crate::core::ColumnVec; impl Component for WideFibComponent { fn n_constraints(&self) -> usize { - 62 + 255 } fn max_constraint_log_degree_bound(&self) -> u32 { @@ -34,400 +34,48 @@ impl Component for WideFibComponent { fixed_mask_points(&vec![vec![0_usize]; 256], point) } - // TODO(ShaharS), precompute random coeff powers. - // TODO(ShaharS), use intermidiate value to save the computation of the squares. fn evaluate_constraint_quotients_on_domain( &self, trace: &ComponentTrace<'_, CPUBackend>, evaluation_accumulator: &mut DomainEvaluationAccumulator, ) { - let constraint_log_degree = Component::::max_constraint_log_degree_bound(self); - let n_constraints = Component::::n_constraints(self); + let max_constraint_degree = Component::::max_constraint_log_degree_bound(self); + let trace_eval_domain = CanonicCoset::new(max_constraint_degree).circle_domain(); let mut trace_evals = vec![]; - // TODO(ShaharS), Share this LDE with the commitment LDE. - for poly_index in 0..64 { + for poly_index in 0..256 { let poly = &trace.columns[poly_index]; - let trace_eval_domain = CanonicCoset::new(constraint_log_degree).circle_domain(); - trace_evals.push(poly.evaluate(trace_eval_domain).bit_reverse()); + trace_evals.push(poly.evaluate(trace_eval_domain)); } let zero_domain = CanonicCoset::new(self.log_size).coset; - let eval_domain = CanonicCoset::new(self.log_size + 1).circle_domain(); let mut denoms = vec![]; - for point in eval_domain.iter() { + for point in trace_eval_domain.iter() { denoms.push(coset_vanishing(zero_domain, point)); } - let mut denom_inverses = vec![BaseField::zero(); 1 << (constraint_log_degree)]; + bit_reverse(&mut denoms); + let mut denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)]; BaseField::batch_inverse(&denoms, &mut denom_inverses); - let mut numerators = vec![SecureField::zero(); 1 << constraint_log_degree]; - let [mut accum] = evaluation_accumulator.columns([(constraint_log_degree, n_constraints)]); - // TODO (ShaharS) Change to get the correct power of random coeff inside the loop. - let random_coeff = accum.random_coeff_powers[1]; - for (i, point_index) in eval_domain.iter_indices().enumerate() { - numerators[i] = numerators[i] * random_coeff - + (trace_evals[2].get_at(point_index) - - ((trace_evals[0].get_at(point_index) * trace_evals[0].get_at(point_index)) - + (trace_evals[1].get_at(point_index) - * trace_evals[1].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[3].get_at(point_index) - - ((trace_evals[1].get_at(point_index) * trace_evals[1].get_at(point_index)) - + (trace_evals[2].get_at(point_index) - * trace_evals[2].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[4].get_at(point_index) - - ((trace_evals[2].get_at(point_index) * trace_evals[2].get_at(point_index)) - + (trace_evals[3].get_at(point_index) - * trace_evals[3].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[5].get_at(point_index) - - ((trace_evals[3].get_at(point_index) * trace_evals[3].get_at(point_index)) - + (trace_evals[4].get_at(point_index) - * trace_evals[4].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[6].get_at(point_index) - - ((trace_evals[4].get_at(point_index) * trace_evals[4].get_at(point_index)) - + (trace_evals[5].get_at(point_index) - * trace_evals[5].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[7].get_at(point_index) - - ((trace_evals[5].get_at(point_index) * trace_evals[5].get_at(point_index)) - + (trace_evals[6].get_at(point_index) - * trace_evals[6].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[8].get_at(point_index) - - ((trace_evals[6].get_at(point_index) * trace_evals[6].get_at(point_index)) - + (trace_evals[7].get_at(point_index) - * trace_evals[7].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[9].get_at(point_index) - - ((trace_evals[7].get_at(point_index) * trace_evals[7].get_at(point_index)) - + (trace_evals[8].get_at(point_index) - * trace_evals[8].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[10].get_at(point_index) - - ((trace_evals[8].get_at(point_index) * trace_evals[8].get_at(point_index)) - + (trace_evals[9].get_at(point_index) - * trace_evals[9].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[11].get_at(point_index) - - ((trace_evals[9].get_at(point_index) * trace_evals[9].get_at(point_index)) - + (trace_evals[10].get_at(point_index) - * trace_evals[10].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[12].get_at(point_index) - - ((trace_evals[10].get_at(point_index) - * trace_evals[10].get_at(point_index)) - + (trace_evals[11].get_at(point_index) - * trace_evals[11].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[13].get_at(point_index) - - ((trace_evals[11].get_at(point_index) - * trace_evals[11].get_at(point_index)) - + (trace_evals[12].get_at(point_index) - * trace_evals[12].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[14].get_at(point_index) - - ((trace_evals[12].get_at(point_index) - * trace_evals[12].get_at(point_index)) - + (trace_evals[13].get_at(point_index) - * trace_evals[13].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[15].get_at(point_index) - - ((trace_evals[13].get_at(point_index) - * trace_evals[13].get_at(point_index)) - + (trace_evals[14].get_at(point_index) - * trace_evals[14].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[16].get_at(point_index) - - ((trace_evals[14].get_at(point_index) - * trace_evals[14].get_at(point_index)) - + (trace_evals[15].get_at(point_index) - * trace_evals[15].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[17].get_at(point_index) - - ((trace_evals[15].get_at(point_index) - * trace_evals[15].get_at(point_index)) - + (trace_evals[16].get_at(point_index) - * trace_evals[16].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[18].get_at(point_index) - - ((trace_evals[16].get_at(point_index) - * trace_evals[16].get_at(point_index)) - + (trace_evals[17].get_at(point_index) - * trace_evals[17].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[19].get_at(point_index) - - ((trace_evals[17].get_at(point_index) - * trace_evals[17].get_at(point_index)) - + (trace_evals[18].get_at(point_index) - * trace_evals[18].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[20].get_at(point_index) - - ((trace_evals[18].get_at(point_index) - * trace_evals[18].get_at(point_index)) - + (trace_evals[19].get_at(point_index) - * trace_evals[19].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[21].get_at(point_index) - - ((trace_evals[19].get_at(point_index) - * trace_evals[19].get_at(point_index)) - + (trace_evals[20].get_at(point_index) - * trace_evals[20].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[22].get_at(point_index) - - ((trace_evals[20].get_at(point_index) - * trace_evals[20].get_at(point_index)) - + (trace_evals[21].get_at(point_index) - * trace_evals[21].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[23].get_at(point_index) - - ((trace_evals[21].get_at(point_index) - * trace_evals[21].get_at(point_index)) - + (trace_evals[22].get_at(point_index) - * trace_evals[22].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[24].get_at(point_index) - - ((trace_evals[22].get_at(point_index) - * trace_evals[22].get_at(point_index)) - + (trace_evals[23].get_at(point_index) - * trace_evals[23].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[25].get_at(point_index) - - ((trace_evals[23].get_at(point_index) - * trace_evals[23].get_at(point_index)) - + (trace_evals[24].get_at(point_index) - * trace_evals[24].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[26].get_at(point_index) - - ((trace_evals[24].get_at(point_index) - * trace_evals[24].get_at(point_index)) - + (trace_evals[25].get_at(point_index) - * trace_evals[25].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[27].get_at(point_index) - - ((trace_evals[25].get_at(point_index) - * trace_evals[25].get_at(point_index)) - + (trace_evals[26].get_at(point_index) - * trace_evals[26].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[28].get_at(point_index) - - ((trace_evals[26].get_at(point_index) - * trace_evals[26].get_at(point_index)) - + (trace_evals[27].get_at(point_index) - * trace_evals[27].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[29].get_at(point_index) - - ((trace_evals[27].get_at(point_index) - * trace_evals[27].get_at(point_index)) - + (trace_evals[28].get_at(point_index) - * trace_evals[28].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[30].get_at(point_index) - - ((trace_evals[28].get_at(point_index) - * trace_evals[28].get_at(point_index)) - + (trace_evals[29].get_at(point_index) - * trace_evals[29].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[31].get_at(point_index) - - ((trace_evals[29].get_at(point_index) - * trace_evals[29].get_at(point_index)) - + (trace_evals[30].get_at(point_index) - * trace_evals[30].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[32].get_at(point_index) - - ((trace_evals[30].get_at(point_index) - * trace_evals[30].get_at(point_index)) - + (trace_evals[31].get_at(point_index) - * trace_evals[31].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[33].get_at(point_index) - - ((trace_evals[31].get_at(point_index) - * trace_evals[31].get_at(point_index)) - + (trace_evals[32].get_at(point_index) - * trace_evals[32].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[34].get_at(point_index) - - ((trace_evals[32].get_at(point_index) - * trace_evals[32].get_at(point_index)) - + (trace_evals[33].get_at(point_index) - * trace_evals[33].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[35].get_at(point_index) - - ((trace_evals[33].get_at(point_index) - * trace_evals[33].get_at(point_index)) - + (trace_evals[34].get_at(point_index) - * trace_evals[34].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[36].get_at(point_index) - - ((trace_evals[34].get_at(point_index) - * trace_evals[34].get_at(point_index)) - + (trace_evals[35].get_at(point_index) - * trace_evals[35].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[37].get_at(point_index) - - ((trace_evals[35].get_at(point_index) - * trace_evals[35].get_at(point_index)) - + (trace_evals[36].get_at(point_index) - * trace_evals[36].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[38].get_at(point_index) - - ((trace_evals[36].get_at(point_index) - * trace_evals[36].get_at(point_index)) - + (trace_evals[37].get_at(point_index) - * trace_evals[37].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[39].get_at(point_index) - - ((trace_evals[37].get_at(point_index) - * trace_evals[37].get_at(point_index)) - + (trace_evals[38].get_at(point_index) - * trace_evals[38].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[40].get_at(point_index) - - ((trace_evals[38].get_at(point_index) - * trace_evals[38].get_at(point_index)) - + (trace_evals[39].get_at(point_index) - * trace_evals[39].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[41].get_at(point_index) - - ((trace_evals[39].get_at(point_index) - * trace_evals[39].get_at(point_index)) - + (trace_evals[40].get_at(point_index) - * trace_evals[40].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[42].get_at(point_index) - - ((trace_evals[40].get_at(point_index) - * trace_evals[40].get_at(point_index)) - + (trace_evals[41].get_at(point_index) - * trace_evals[41].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[43].get_at(point_index) - - ((trace_evals[41].get_at(point_index) - * trace_evals[41].get_at(point_index)) - + (trace_evals[42].get_at(point_index) - * trace_evals[42].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[44].get_at(point_index) - - ((trace_evals[42].get_at(point_index) - * trace_evals[42].get_at(point_index)) - + (trace_evals[43].get_at(point_index) - * trace_evals[43].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[45].get_at(point_index) - - ((trace_evals[43].get_at(point_index) - * trace_evals[43].get_at(point_index)) - + (trace_evals[44].get_at(point_index) - * trace_evals[44].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[46].get_at(point_index) - - ((trace_evals[44].get_at(point_index) - * trace_evals[44].get_at(point_index)) - + (trace_evals[45].get_at(point_index) - * trace_evals[45].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[47].get_at(point_index) - - ((trace_evals[45].get_at(point_index) - * trace_evals[45].get_at(point_index)) - + (trace_evals[46].get_at(point_index) - * trace_evals[46].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[48].get_at(point_index) - - ((trace_evals[46].get_at(point_index) - * trace_evals[46].get_at(point_index)) - + (trace_evals[47].get_at(point_index) - * trace_evals[47].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[49].get_at(point_index) - - ((trace_evals[47].get_at(point_index) - * trace_evals[47].get_at(point_index)) - + (trace_evals[48].get_at(point_index) - * trace_evals[48].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[50].get_at(point_index) - - ((trace_evals[48].get_at(point_index) - * trace_evals[48].get_at(point_index)) - + (trace_evals[49].get_at(point_index) - * trace_evals[49].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[51].get_at(point_index) - - ((trace_evals[49].get_at(point_index) - * trace_evals[49].get_at(point_index)) - + (trace_evals[50].get_at(point_index) - * trace_evals[50].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[52].get_at(point_index) - - ((trace_evals[50].get_at(point_index) - * trace_evals[50].get_at(point_index)) - + (trace_evals[51].get_at(point_index) - * trace_evals[51].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[53].get_at(point_index) - - ((trace_evals[51].get_at(point_index) - * trace_evals[51].get_at(point_index)) - + (trace_evals[52].get_at(point_index) - * trace_evals[52].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[54].get_at(point_index) - - ((trace_evals[52].get_at(point_index) - * trace_evals[52].get_at(point_index)) - + (trace_evals[53].get_at(point_index) - * trace_evals[53].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[55].get_at(point_index) - - ((trace_evals[53].get_at(point_index) - * trace_evals[53].get_at(point_index)) - + (trace_evals[54].get_at(point_index) - * trace_evals[54].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[56].get_at(point_index) - - ((trace_evals[54].get_at(point_index) - * trace_evals[54].get_at(point_index)) - + (trace_evals[55].get_at(point_index) - * trace_evals[55].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[57].get_at(point_index) - - ((trace_evals[55].get_at(point_index) - * trace_evals[55].get_at(point_index)) - + (trace_evals[56].get_at(point_index) - * trace_evals[56].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[58].get_at(point_index) - - ((trace_evals[56].get_at(point_index) - * trace_evals[56].get_at(point_index)) - + (trace_evals[57].get_at(point_index) - * trace_evals[57].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[59].get_at(point_index) - - ((trace_evals[57].get_at(point_index) - * trace_evals[57].get_at(point_index)) - + (trace_evals[58].get_at(point_index) - * trace_evals[58].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[60].get_at(point_index) - - ((trace_evals[58].get_at(point_index) - * trace_evals[58].get_at(point_index)) - + (trace_evals[59].get_at(point_index) - * trace_evals[59].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[61].get_at(point_index) - - ((trace_evals[59].get_at(point_index) - * trace_evals[59].get_at(point_index)) - + (trace_evals[60].get_at(point_index) - * trace_evals[60].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[62].get_at(point_index) - - ((trace_evals[60].get_at(point_index) - * trace_evals[60].get_at(point_index)) - + (trace_evals[61].get_at(point_index) - * trace_evals[61].get_at(point_index)))); - numerators[i] = numerators[i] * random_coeff - + (trace_evals[63].get_at(point_index) - - ((trace_evals[61].get_at(point_index) - * trace_evals[61].get_at(point_index)) - + (trace_evals[62].get_at(point_index) - * trace_evals[62].get_at(point_index)))); + let mut numerators = vec![SecureField::zero(); 1 << (max_constraint_degree)]; + let [mut accum] = evaluation_accumulator.columns([( + max_constraint_degree, + Component::::n_constraints(self), + )]); + + #[allow(clippy::needless_range_loop)] + for i in 0..trace_eval_domain.size() { + // Boundary constraint. + numerators[i] += accum.random_coeff_powers[254] + * (trace_evals[0].values.at(i) - BaseField::from_u32_unchecked(1)); + + // Step constraints. + for j in 0..254 { + numerators[i] += accum.random_coeff_powers[253 - j] + * (trace_evals[j].values.at(i).square() + + trace_evals[j + 1].values.at(i).square() + - trace_evals[j + 2].values.at(i)); + } } for (i, (num, denom)) in numerators.iter().zip(denom_inverses.iter()).enumerate() { - accum.accumulate(bit_reverse_index(i, constraint_log_degree), *num * *denom); + accum.accumulate(i, *num * *denom); } } @@ -437,12 +85,15 @@ impl Component for WideFibComponent { mask: &ColumnVec>, evaluation_accumulator: &mut PointEvaluationAccumulator, ) { - let zero_domain = CanonicCoset::new(self.log_size).coset; - let denominator = coset_vanishing(zero_domain, point); - evaluation_accumulator.accumulate((mask[0][0] - SecureField::one()) / denominator); - for i in 0..(256 - 2) { + let constraint_zero_domain = CanonicCoset::new(self.log_size).coset; + let denom = coset_vanishing(constraint_zero_domain, point); + let denom_inverse = denom.inverse(); + let numerator = mask[0][0] - BaseField::from_u32_unchecked(1); + evaluation_accumulator.accumulate(numerator * denom_inverse); + + for i in 0..254 { let numerator = mask[i][0].square() + mask[i + 1][0].square() - mask[i + 2][0]; - evaluation_accumulator.accumulate(numerator / denominator); + evaluation_accumulator.accumulate(numerator * denom_inverse); } } } diff --git a/src/examples/wide_fibonacci/mod.rs b/src/examples/wide_fibonacci/mod.rs index aed1acde9..8a7ec8590 100644 --- a/src/examples/wide_fibonacci/mod.rs +++ b/src/examples/wide_fibonacci/mod.rs @@ -1,8 +1,7 @@ #[cfg(target_arch = "x86_64")] pub mod avx; +pub mod component; pub mod constraint_eval; -pub mod structs; -pub mod trace_asserts; pub mod trace_gen; #[cfg(test)] @@ -10,26 +9,40 @@ mod tests { use itertools::Itertools; use num_traits::{One, Zero}; - use super::structs::{Input, WideFibComponent}; - use super::trace_asserts::assert_constraints_on_row; + use super::component::{Input, WideFibAir, WideFibComponent}; use super::trace_gen::write_trace_row; + use crate::commitment_scheme::blake2_hash::Blake2sHasher; + use crate::commitment_scheme::hasher::Hasher; use crate::core::air::accumulation::DomainEvaluationAccumulator; use crate::core::air::{Component, ComponentTrace}; use crate::core::backend::cpu::CPUCircleEvaluation; use crate::core::backend::CPUBackend; + use crate::core::channel::{Blake2sChannel, Channel}; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::QM31; + use crate::core::fields::IntoSlice; use crate::core::poly::circle::CanonicCoset; + use crate::core::poly::BitReversedOrder; + use crate::core::prover::{prove, verify}; fn fill_trace(private_input: &[Input]) -> Vec> { let zero_vec = vec![BaseField::zero(); private_input.len()]; - let mut dst = vec![zero_vec; 64]; + let mut dst = vec![zero_vec; 256]; for (offset, input) in private_input.iter().enumerate() { write_trace_row(&mut dst, input, offset); } dst } + pub fn assert_constraints_on_row(row: &[BaseField]) { + for i in 2..row.len() { + assert_eq!( + (row[i] - (row[i - 1] * row[i - 1] + row[i - 2] * row[i - 2])), + BaseField::zero() + ); + } + } + #[test] fn test_wide_fib_trace() { let input = Input { @@ -43,7 +56,7 @@ mod tests { } #[test] - fn test_wide_fib_constraints() { + fn test_composition_is_low_degree() { let wide_fib = WideFibComponent { log_size: 7 }; let mut acc = DomainEvaluationAccumulator::new( QM31::from_u32_unchecked(1, 2, 3, 4), @@ -84,4 +97,33 @@ mod tests { assert_eq!(*coeff, BaseField::zero()); } } + + #[test] + fn test_prove() { + let wide_fib = WideFibComponent { log_size: 7 }; + let wide_fib_air = WideFibAir { + component: wide_fib, + }; + let inputs = (0..1 << wide_fib_air.component.log_size) + .map(|i| Input { + a: BaseField::one(), + b: BaseField::from_u32_unchecked(i as u32), + }) + .collect_vec(); + let trace = fill_trace(&inputs); + + let trace_domain = CanonicCoset::new(wide_fib_air.component.log_size).circle_domain(); + let trace = trace + .into_iter() + .map(|eval| CPUCircleEvaluation::::new(trace_domain, eval)) + .collect_vec(); + + let prover_channel = + &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); + let proof = prove(&wide_fib_air, prover_channel, trace).unwrap(); + + let verifier_channel = + &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); + verify(proof, &wide_fib_air, verifier_channel).unwrap(); + } } diff --git a/src/examples/wide_fibonacci/trace_asserts.rs b/src/examples/wide_fibonacci/trace_asserts.rs deleted file mode 100644 index 4916784cd..000000000 --- a/src/examples/wide_fibonacci/trace_asserts.rs +++ /dev/null @@ -1,254 +0,0 @@ -use num_traits::Zero; - -use crate::core::fields::m31::BaseField; - -pub fn assert_constraints_on_row(row: &[BaseField]) { - assert_eq!( - (row[2] - ((row[0] * row[0]) + (row[1] * row[1]))), - BaseField::zero() - ); - assert_eq!( - (row[3] - ((row[1] * row[1]) + (row[2] * row[2]))), - BaseField::zero() - ); - assert_eq!( - (row[4] - ((row[2] * row[2]) + (row[3] * row[3]))), - BaseField::zero() - ); - assert_eq!( - (row[5] - ((row[3] * row[3]) + (row[4] * row[4]))), - BaseField::zero() - ); - assert_eq!( - (row[6] - ((row[4] * row[4]) + (row[5] * row[5]))), - BaseField::zero() - ); - assert_eq!( - (row[7] - ((row[5] * row[5]) + (row[6] * row[6]))), - BaseField::zero() - ); - assert_eq!( - (row[8] - ((row[6] * row[6]) + (row[7] * row[7]))), - BaseField::zero() - ); - assert_eq!( - (row[9] - ((row[7] * row[7]) + (row[8] * row[8]))), - BaseField::zero() - ); - assert_eq!( - (row[10] - ((row[8] * row[8]) + (row[9] * row[9]))), - BaseField::zero() - ); - assert_eq!( - (row[11] - ((row[9] * row[9]) + (row[10] * row[10]))), - BaseField::zero() - ); - assert_eq!( - (row[12] - ((row[10] * row[10]) + (row[11] * row[11]))), - BaseField::zero() - ); - assert_eq!( - (row[13] - ((row[11] * row[11]) + (row[12] * row[12]))), - BaseField::zero() - ); - assert_eq!( - (row[14] - ((row[12] * row[12]) + (row[13] * row[13]))), - BaseField::zero() - ); - assert_eq!( - (row[15] - ((row[13] * row[13]) + (row[14] * row[14]))), - BaseField::zero() - ); - assert_eq!( - (row[16] - ((row[14] * row[14]) + (row[15] * row[15]))), - BaseField::zero() - ); - assert_eq!( - (row[17] - ((row[15] * row[15]) + (row[16] * row[16]))), - BaseField::zero() - ); - assert_eq!( - (row[18] - ((row[16] * row[16]) + (row[17] * row[17]))), - BaseField::zero() - ); - assert_eq!( - (row[19] - ((row[17] * row[17]) + (row[18] * row[18]))), - BaseField::zero() - ); - assert_eq!( - (row[20] - ((row[18] * row[18]) + (row[19] * row[19]))), - BaseField::zero() - ); - assert_eq!( - (row[21] - ((row[19] * row[19]) + (row[20] * row[20]))), - BaseField::zero() - ); - assert_eq!( - (row[22] - ((row[20] * row[20]) + (row[21] * row[21]))), - BaseField::zero() - ); - assert_eq!( - (row[23] - ((row[21] * row[21]) + (row[22] * row[22]))), - BaseField::zero() - ); - assert_eq!( - (row[24] - ((row[22] * row[22]) + (row[23] * row[23]))), - BaseField::zero() - ); - assert_eq!( - (row[25] - ((row[23] * row[23]) + (row[24] * row[24]))), - BaseField::zero() - ); - assert_eq!( - (row[26] - ((row[24] * row[24]) + (row[25] * row[25]))), - BaseField::zero() - ); - assert_eq!( - (row[27] - ((row[25] * row[25]) + (row[26] * row[26]))), - BaseField::zero() - ); - assert_eq!( - (row[28] - ((row[26] * row[26]) + (row[27] * row[27]))), - BaseField::zero() - ); - assert_eq!( - (row[29] - ((row[27] * row[27]) + (row[28] * row[28]))), - BaseField::zero() - ); - assert_eq!( - (row[30] - ((row[28] * row[28]) + (row[29] * row[29]))), - BaseField::zero() - ); - assert_eq!( - (row[31] - ((row[29] * row[29]) + (row[30] * row[30]))), - BaseField::zero() - ); - assert_eq!( - (row[32] - ((row[30] * row[30]) + (row[31] * row[31]))), - BaseField::zero() - ); - assert_eq!( - (row[33] - ((row[31] * row[31]) + (row[32] * row[32]))), - BaseField::zero() - ); - assert_eq!( - (row[34] - ((row[32] * row[32]) + (row[33] * row[33]))), - BaseField::zero() - ); - assert_eq!( - (row[35] - ((row[33] * row[33]) + (row[34] * row[34]))), - BaseField::zero() - ); - assert_eq!( - (row[36] - ((row[34] * row[34]) + (row[35] * row[35]))), - BaseField::zero() - ); - assert_eq!( - (row[37] - ((row[35] * row[35]) + (row[36] * row[36]))), - BaseField::zero() - ); - assert_eq!( - (row[38] - ((row[36] * row[36]) + (row[37] * row[37]))), - BaseField::zero() - ); - assert_eq!( - (row[39] - ((row[37] * row[37]) + (row[38] * row[38]))), - BaseField::zero() - ); - assert_eq!( - (row[40] - ((row[38] * row[38]) + (row[39] * row[39]))), - BaseField::zero() - ); - assert_eq!( - (row[41] - ((row[39] * row[39]) + (row[40] * row[40]))), - BaseField::zero() - ); - assert_eq!( - (row[42] - ((row[40] * row[40]) + (row[41] * row[41]))), - BaseField::zero() - ); - assert_eq!( - (row[43] - ((row[41] * row[41]) + (row[42] * row[42]))), - BaseField::zero() - ); - assert_eq!( - (row[44] - ((row[42] * row[42]) + (row[43] * row[43]))), - BaseField::zero() - ); - assert_eq!( - (row[45] - ((row[43] * row[43]) + (row[44] * row[44]))), - BaseField::zero() - ); - assert_eq!( - (row[46] - ((row[44] * row[44]) + (row[45] * row[45]))), - BaseField::zero() - ); - assert_eq!( - (row[47] - ((row[45] * row[45]) + (row[46] * row[46]))), - BaseField::zero() - ); - assert_eq!( - (row[48] - ((row[46] * row[46]) + (row[47] * row[47]))), - BaseField::zero() - ); - assert_eq!( - (row[49] - ((row[47] * row[47]) + (row[48] * row[48]))), - BaseField::zero() - ); - assert_eq!( - (row[50] - ((row[48] * row[48]) + (row[49] * row[49]))), - BaseField::zero() - ); - assert_eq!( - (row[51] - ((row[49] * row[49]) + (row[50] * row[50]))), - BaseField::zero() - ); - assert_eq!( - (row[52] - ((row[50] * row[50]) + (row[51] * row[51]))), - BaseField::zero() - ); - assert_eq!( - (row[53] - ((row[51] * row[51]) + (row[52] * row[52]))), - BaseField::zero() - ); - assert_eq!( - (row[54] - ((row[52] * row[52]) + (row[53] * row[53]))), - BaseField::zero() - ); - assert_eq!( - (row[55] - ((row[53] * row[53]) + (row[54] * row[54]))), - BaseField::zero() - ); - assert_eq!( - (row[56] - ((row[54] * row[54]) + (row[55] * row[55]))), - BaseField::zero() - ); - assert_eq!( - (row[57] - ((row[55] * row[55]) + (row[56] * row[56]))), - BaseField::zero() - ); - assert_eq!( - (row[58] - ((row[56] * row[56]) + (row[57] * row[57]))), - BaseField::zero() - ); - assert_eq!( - (row[59] - ((row[57] * row[57]) + (row[58] * row[58]))), - BaseField::zero() - ); - assert_eq!( - (row[60] - ((row[58] * row[58]) + (row[59] * row[59]))), - BaseField::zero() - ); - assert_eq!( - (row[61] - ((row[59] * row[59]) + (row[60] * row[60]))), - BaseField::zero() - ); - assert_eq!( - (row[62] - ((row[60] * row[60]) + (row[61] * row[61]))), - BaseField::zero() - ); - assert_eq!( - (row[63] - ((row[61] * row[61]) + (row[62] * row[62]))), - BaseField::zero() - ); -} diff --git a/src/examples/wide_fibonacci/trace_gen.rs b/src/examples/wide_fibonacci/trace_gen.rs index af67a35fd..2bb48b4e5 100644 --- a/src/examples/wide_fibonacci/trace_gen.rs +++ b/src/examples/wide_fibonacci/trace_gen.rs @@ -1,138 +1,14 @@ -use super::structs::Input; +use super::component::Input; use crate::core::fields::m31::BaseField; +use crate::core::fields::FieldExpOps; // TODO(ShaharS), try to make it into a for loop and use intermiddiate variables to save // computation. /// Given a private input, write the trace row for the wide Fibonacci example to dst. -pub fn write_trace_row(dst: &mut [Vec], private_input: &Input, row_offset: usize) { - let a = private_input.a; - let b = private_input.b; - let col0 = a; - dst[0][row_offset] = col0; - let col1 = b; - dst[1][row_offset] = col1; - let col2 = col0 * col0 + col1 * col1; - dst[2][row_offset] = col2; - let col3 = col1 * col1 + col2 * col2; - dst[3][row_offset] = col3; - let col4 = col2 * col2 + col3 * col3; - dst[4][row_offset] = col4; - let col5 = col3 * col3 + col4 * col4; - dst[5][row_offset] = col5; - let col6 = col4 * col4 + col5 * col5; - dst[6][row_offset] = col6; - let col7 = col5 * col5 + col6 * col6; - dst[7][row_offset] = col7; - let col8 = col6 * col6 + col7 * col7; - dst[8][row_offset] = col8; - let col9 = col7 * col7 + col8 * col8; - dst[9][row_offset] = col9; - let col10 = col8 * col8 + col9 * col9; - dst[10][row_offset] = col10; - let col11 = col9 * col9 + col10 * col10; - dst[11][row_offset] = col11; - let col12 = col10 * col10 + col11 * col11; - dst[12][row_offset] = col12; - let col13 = col11 * col11 + col12 * col12; - dst[13][row_offset] = col13; - let col14 = col12 * col12 + col13 * col13; - dst[14][row_offset] = col14; - let col15 = col13 * col13 + col14 * col14; - dst[15][row_offset] = col15; - let col16 = col14 * col14 + col15 * col15; - dst[16][row_offset] = col16; - let col17 = col15 * col15 + col16 * col16; - dst[17][row_offset] = col17; - let col18 = col16 * col16 + col17 * col17; - dst[18][row_offset] = col18; - let col19 = col17 * col17 + col18 * col18; - dst[19][row_offset] = col19; - let col20 = col18 * col18 + col19 * col19; - dst[20][row_offset] = col20; - let col21 = col19 * col19 + col20 * col20; - dst[21][row_offset] = col21; - let col22 = col20 * col20 + col21 * col21; - dst[22][row_offset] = col22; - let col23 = col21 * col21 + col22 * col22; - dst[23][row_offset] = col23; - let col24 = col22 * col22 + col23 * col23; - dst[24][row_offset] = col24; - let col25 = col23 * col23 + col24 * col24; - dst[25][row_offset] = col25; - let col26 = col24 * col24 + col25 * col25; - dst[26][row_offset] = col26; - let col27 = col25 * col25 + col26 * col26; - dst[27][row_offset] = col27; - let col28 = col26 * col26 + col27 * col27; - dst[28][row_offset] = col28; - let col29 = col27 * col27 + col28 * col28; - dst[29][row_offset] = col29; - let col30 = col28 * col28 + col29 * col29; - dst[30][row_offset] = col30; - let col31 = col29 * col29 + col30 * col30; - dst[31][row_offset] = col31; - let col32 = col30 * col30 + col31 * col31; - dst[32][row_offset] = col32; - let col33 = col31 * col31 + col32 * col32; - dst[33][row_offset] = col33; - let col34 = col32 * col32 + col33 * col33; - dst[34][row_offset] = col34; - let col35 = col33 * col33 + col34 * col34; - dst[35][row_offset] = col35; - let col36 = col34 * col34 + col35 * col35; - dst[36][row_offset] = col36; - let col37 = col35 * col35 + col36 * col36; - dst[37][row_offset] = col37; - let col38 = col36 * col36 + col37 * col37; - dst[38][row_offset] = col38; - let col39 = col37 * col37 + col38 * col38; - dst[39][row_offset] = col39; - let col40 = col38 * col38 + col39 * col39; - dst[40][row_offset] = col40; - let col41 = col39 * col39 + col40 * col40; - dst[41][row_offset] = col41; - let col42 = col40 * col40 + col41 * col41; - dst[42][row_offset] = col42; - let col43 = col41 * col41 + col42 * col42; - dst[43][row_offset] = col43; - let col44 = col42 * col42 + col43 * col43; - dst[44][row_offset] = col44; - let col45 = col43 * col43 + col44 * col44; - dst[45][row_offset] = col45; - let col46 = col44 * col44 + col45 * col45; - dst[46][row_offset] = col46; - let col47 = col45 * col45 + col46 * col46; - dst[47][row_offset] = col47; - let col48 = col46 * col46 + col47 * col47; - dst[48][row_offset] = col48; - let col49 = col47 * col47 + col48 * col48; - dst[49][row_offset] = col49; - let col50 = col48 * col48 + col49 * col49; - dst[50][row_offset] = col50; - let col51 = col49 * col49 + col50 * col50; - dst[51][row_offset] = col51; - let col52 = col50 * col50 + col51 * col51; - dst[52][row_offset] = col52; - let col53 = col51 * col51 + col52 * col52; - dst[53][row_offset] = col53; - let col54 = col52 * col52 + col53 * col53; - dst[54][row_offset] = col54; - let col55 = col53 * col53 + col54 * col54; - dst[55][row_offset] = col55; - let col56 = col54 * col54 + col55 * col55; - dst[56][row_offset] = col56; - let col57 = col55 * col55 + col56 * col56; - dst[57][row_offset] = col57; - let col58 = col56 * col56 + col57 * col57; - dst[58][row_offset] = col58; - let col59 = col57 * col57 + col58 * col58; - dst[59][row_offset] = col59; - let col60 = col58 * col58 + col59 * col59; - dst[60][row_offset] = col60; - let col61 = col59 * col59 + col60 * col60; - dst[61][row_offset] = col61; - let col62 = col60 * col60 + col61 * col61; - dst[62][row_offset] = col62; - let col63 = col61 * col61 + col62 * col62; - dst[63][row_offset] = col63; +pub fn write_trace_row(dst: &mut [Vec], private_input: &Input, row_index: usize) { + dst[0][row_index] = private_input.a; + dst[1][row_index] = private_input.b; + for i in 2..256 { + dst[i][row_index] = dst[i - 1][row_index].square() + dst[i - 2][row_index].square(); + } }