From 5f9adcec2e61a521715a0e53a5c7c34ae803eb99 Mon Sep 17 00:00:00 2001 From: Shahar Samocha Date: Thu, 4 Apr 2024 15:12:38 +0300 Subject: [PATCH] Remove Mask struct --- src/core/air/mask.rs | 91 +++++++++++++++++++ src/core/air/mod.rs | 34 +------ src/examples/fibonacci/component.rs | 10 +- src/examples/wide_fibonacci/avx.rs | 8 +- .../wide_fibonacci/constraint_eval.rs | 8 +- 5 files changed, 106 insertions(+), 45 deletions(-) create mode 100644 src/core/air/mask.rs diff --git a/src/core/air/mask.rs b/src/core/air/mask.rs new file mode 100644 index 000000000..e2748a653 --- /dev/null +++ b/src/core/air/mask.rs @@ -0,0 +1,91 @@ +use std::collections::HashSet; +use std::vec; + +use itertools::Itertools; + +use crate::core::circle::CirclePoint; +use crate::core::fields::qm31::SecureField; +use crate::core::poly::circle::CanonicCoset; +use crate::core::ColumnVec; + +/// Mask holds a vector with an entry for each column. +/// Each entry holds a list of mask items, which are the offsets of the mask at that column. +type Mask = ColumnVec>; + +/// Returns the same point for each mask item. +/// Should be used where all the mask items has no shift from the constraint point. +pub fn fixed_mask_points( + mask: &Mask, + point: CirclePoint, +) -> ColumnVec>> { + assert_eq!( + mask.iter() + .flat_map(|mask_entry| mask_entry.iter().collect::>()) + .collect::>() + .into_iter() + .collect_vec(), + vec![&0] + ); + mask.iter() + .map(|mask_entry| mask_entry.iter().map(|_| point).collect()) + .collect() +} + +/// For each mask item returns the point shifted by the domain initial point of the column. +/// Should be used where the mask items are shifted from the constraint point. +pub fn shifted_mask_points( + mask: &Mask, + domains: &[CanonicCoset], + point: CirclePoint, +) -> ColumnVec>> { + mask.iter() + .zip(domains.iter()) + .map(|(mask_entry, domain)| { + mask_entry + .iter() + .map(|mask_item| point + domain.at(*mask_item).into_ef()) + .collect() + }) + .collect() +} + +#[cfg(test)] +mod tests { + use crate::core::air::mask::{fixed_mask_points, shifted_mask_points}; + use crate::core::circle::CirclePoint; + use crate::core::poly::circle::CanonicCoset; + + #[test] + fn test_mask_fixed_points() { + let mask = vec![vec![0], vec![0]]; + let constraint_point = CirclePoint::get_point(1234); + + let points = fixed_mask_points(&mask, constraint_point); + + assert_eq!(points.len(), 2); + assert_eq!(points[0].len(), 1); + assert_eq!(points[1].len(), 1); + assert_eq!(points[0][0], constraint_point); + assert_eq!(points[1][0], constraint_point); + } + + #[test] + fn test_mask_shifted_points() { + let mask = vec![vec![0, 1], vec![0, 1, 2]]; + let constraint_point = CirclePoint::get_point(1234); + let domains = (0..mask.len() as u32) + .map(|i| CanonicCoset::new(7 + i)) + .collect::>(); + + let points = shifted_mask_points(&mask, &domains, constraint_point); + + assert_eq!(points.len(), 2); + assert_eq!(points[0].len(), 2); + assert_eq!(points[1].len(), 3); + assert_eq!(points[0][0], constraint_point + domains[0].at(0).into_ef()); + assert_eq!(points[0][1], constraint_point + domains[0].at(1).into_ef()); + assert_eq!(points[1][0], constraint_point + domains[1].at(0).into_ef()); + assert_eq!(points[1][1], constraint_point + domains[1].at(1).into_ef()); + assert_eq!(points[1][2], constraint_point + domains[1].at(2).into_ef()); + } +} diff --git a/src/core/air/mod.rs b/src/core/air/mod.rs index fa34fa2e9..9c5efa7ae 100644 --- a/src/core/air/mod.rs +++ b/src/core/air/mod.rs @@ -1,15 +1,15 @@ use std::iter::zip; -use std::ops::Deref; use self::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use super::backend::Backend; use super::circle::CirclePoint; use super::fields::qm31::SecureField; -use super::poly::circle::{CanonicCoset, CirclePoly}; +use super::poly::circle::CirclePoly; use super::ColumnVec; pub mod accumulation; mod air_ext; +pub mod mask; pub use air_ext::AirExt; @@ -23,36 +23,6 @@ pub trait Air { fn components(&self) -> Vec<&dyn Component>; } -/// Holds the mask offsets at each column. -/// Holds a vector with an entry for each column. Each entry holds the offsets -/// of the mask at that column. -pub struct Mask(pub ColumnVec>); - -impl Mask { - pub fn to_points( - &self, - domains: &[CanonicCoset], - point: CirclePoint, - ) -> ColumnVec>> { - self.iter() - .zip(domains.iter()) - .map(|(col, domain)| { - col.iter() - .map(|i| point + domain.at(*i).into_ef()) - .collect() - }) - .collect() - } -} - -impl Deref for Mask { - type Target = ColumnVec>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - /// A component is a set of trace columns of various sizes along with a set of /// constraints on them. pub trait Component { diff --git a/src/examples/fibonacci/component.rs b/src/examples/fibonacci/component.rs index f80c097b9..27ab352c2 100644 --- a/src/examples/fibonacci/component.rs +++ b/src/examples/fibonacci/component.rs @@ -3,7 +3,8 @@ use std::ops::Div; use num_traits::One; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; -use crate::core::air::{Component, ComponentTrace, Mask}; +use crate::core::air::mask::shifted_mask_points; +use crate::core::air::{Component, ComponentTrace}; use crate::core::backend::CPUBackend; use crate::core::circle::{CirclePoint, Coset}; use crate::core::constraints::{coset_vanishing, pair_vanishing}; @@ -120,8 +121,11 @@ impl Component for FibonacciComponent { &self, point: CirclePoint, ) -> ColumnVec>> { - let fib_mask = Mask(vec![vec![0, 1, 2]]); - fib_mask.to_points(&[CanonicCoset::new(self.log_size)], point) + shifted_mask_points( + &vec![vec![0, 1, 2]], + &[CanonicCoset::new(self.log_size)], + point, + ) } fn evaluate_constraint_quotients_at_point( diff --git a/src/examples/wide_fibonacci/avx.rs b/src/examples/wide_fibonacci/avx.rs index 2b1d9c0f9..f4704cd19 100644 --- a/src/examples/wide_fibonacci/avx.rs +++ b/src/examples/wide_fibonacci/avx.rs @@ -3,7 +3,8 @@ use num_traits::{One, Zero}; use super::structs::WideFibComponent; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; -use crate::core::air::{Air, Component, ComponentTrace, Mask}; +use crate::core::air::mask::fixed_mask_points; +use crate::core::air::{Air, Component, ComponentTrace}; use crate::core::backend::avx512::qm31::PackedSecureField; use crate::core::backend::avx512::{AVX512Backend, BaseFieldVec, PackedBaseField, VECS_LOG_SIZE}; use crate::core::backend::{CPUBackend, Col, Column, ColumnOps}; @@ -132,10 +133,7 @@ impl Component for WideFibComponent { &self, point: CirclePoint, ) -> ColumnVec>> { - let mask = Mask(vec![vec![0_usize]; 256]); - mask.iter() - .map(|col| col.iter().map(|_| point).collect()) - .collect() + fixed_mask_points(&vec![vec![0_usize]; 256], point) } fn evaluate_constraint_quotients_at_point( diff --git a/src/examples/wide_fibonacci/constraint_eval.rs b/src/examples/wide_fibonacci/constraint_eval.rs index 814ef1408..8c3bc48b3 100644 --- a/src/examples/wide_fibonacci/constraint_eval.rs +++ b/src/examples/wide_fibonacci/constraint_eval.rs @@ -2,7 +2,8 @@ use num_traits::{One, Zero}; use super::structs::WideFibComponent; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; -use crate::core::air::{Component, ComponentTrace, Mask}; +use crate::core::air::mask::fixed_mask_points; +use crate::core::air::{Component, ComponentTrace}; use crate::core::backend::CPUBackend; use crate::core::circle::CirclePoint; use crate::core::constraints::coset_vanishing; @@ -30,10 +31,7 @@ impl Component for WideFibComponent { &self, point: CirclePoint, ) -> ColumnVec>> { - let mask = Mask(vec![vec![0_usize]; 256]); - mask.iter() - .map(|col| col.iter().map(|_| point).collect()) - .collect() + fixed_mask_points(&vec![vec![0_usize]; 256], point) } // TODO(ShaharS), precompute random coeff powers.