diff --git a/Cargo.lock b/Cargo.lock
index a7e009352..c87fc6628 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1035,6 +1035,16 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "stwo-air-utils"
+version = "0.1.1"
+dependencies = [
+ "bytemuck",
+ "itertools 0.12.1",
+ "rayon",
+ "stwo-prover",
+]
+
 [[package]]
 name = "stwo-prover"
 version = "0.1.1"
diff --git a/Cargo.toml b/Cargo.toml
index 0f314a496..fadd620de 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,5 +1,5 @@
 [workspace]
-members = ["crates/prover"]
+members = ["crates/prover", "crates/air_utils"]
 resolver = "2"
 
 [workspace.package]
diff --git a/crates/air_utils/Cargo.toml b/crates/air_utils/Cargo.toml
new file mode 100644
index 000000000..7d09a7eaf
--- /dev/null
+++ b/crates/air_utils/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "stwo-air-utils"
+version.workspace = true
+edition.workspace = true
+
+[dependencies]
+bytemuck.workspace = true
+itertools.workspace = true
+rayon = { version = "1.10.0", optional = false }
+stwo-prover = { path = "../prover" }
+
+[lib]
+bench = false
diff --git a/crates/air_utils/src/lib.rs b/crates/air_utils/src/lib.rs
new file mode 100644
index 000000000..dd5257d2f
--- /dev/null
+++ b/crates/air_utils/src/lib.rs
@@ -0,0 +1,2 @@
+#![feature(exact_size_is_empty, raw_slice_split, portable_simd, array_chunks)]
+pub mod trace;
diff --git a/crates/air_utils/src/trace/component_trace.rs b/crates/air_utils/src/trace/component_trace.rs
new file mode 100644
index 000000000..f14aace91
--- /dev/null
+++ b/crates/air_utils/src/trace/component_trace.rs
@@ -0,0 +1,313 @@
+use std::marker::PhantomData;
+
+use bytemuck::Zeroable;
+use itertools::Itertools;
+use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer};
+use rayon::prelude::*;
+use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES};
+use stwo_prover::core::backend::simd::SimdBackend;
+use stwo_prover::core::fields::m31::M31;
+use stwo_prover::core::poly::circle::CircleEvaluation;
+use stwo_prover::core::poly::BitReversedOrder;
+
+/// A 2D Matrix of [`PackedM31`] values.
+/// Used for generating the witness of 'Stwo' proofs.
+/// Stored as an array of `N` columns, each column is a vector of [`PackedM31`] values.
+/// Exposes an iterator over mutable references to the rows of the matrix.
+///
+/// # Example:
+///
+/// ```text
+/// Computation trace of a^2 + (a + 1)^2 for a in 0..256
+/// ```
+/// ```
+/// use stwo_air_utils::trace::component_trace::ComponentTrace;
+/// use itertools::Itertools;
+/// use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES};
+/// use stwo_prover::core::fields::m31::M31;
+/// use stwo_prover::core::fields::FieldExpOps;
+///
+/// const N_COLUMNS: usize = 3;
+/// const LOG_SIZE: u32 = 8;
+/// let mut trace = ComponentTrace::<N_COLUMNS>::zeroed(LOG_SIZE);
+/// let example_input = (0..1 << LOG_SIZE).map(M31::from).collect_vec(); // 0..256
+/// trace
+///     .iter_mut()
+///     .zip(example_input.chunks(N_LANES))
+///     .chunks(4)
+///     .into_iter()
+///     .for_each(|chunk| {
+///         chunk.into_iter().for_each(|(row, input)| {
+///             *row[0] = PackedM31::from_array(input.try_into().unwrap());
+///             *row[1] = *row[0] + PackedM31::broadcast(M31(1));
+///             *row[2] = row[0].square() + row[1].square();
+///         })
+///     });
+///
+/// let first_3_rows = (0..N_COLUMNS).map(|i| trace.row_at(i)).collect::<Vec<_>>();
+/// assert_eq!(first_3_rows, [[0,1,1], [1,2,5], [2,3,13]].map(|row| row.map(M31::from)));
+/// ```
+#[derive(Debug)]
+pub struct ComponentTrace<const N: usize> {
+    data: [Vec<PackedM31>; N],
+
+    /// Log number of non-packed rows in each column.
+    log_size: u32,
+}
+
+impl<const N: usize> ComponentTrace<N> {
+    /// Creates a trace of `N` columns with `2^log_size` rows each, all set to zero.
+    ///
+    /// # Panics
+    /// Panics if `log_size` is smaller than [`LOG_N_LANES`], because the number of packed
+    /// elements per column is computed as `1 << (log_size - LOG_N_LANES)`.
+    pub fn zeroed(log_size: u32) -> Self {
+        assert!(log_size >= LOG_N_LANES, "log_size must be at least LOG_N_LANES");
+        let n_simd_elems = 1 << (log_size - LOG_N_LANES);
+        let data = [(); N].map(|_| vec![PackedM31::zeroed(); n_simd_elems]);
+        Self { data, log_size }
+    }
+
+    /// Not implemented yet: calling this currently panics via `todo!()`.
+    ///
+    /// # Safety
+    /// The caller must ensure that the column is populated before being used.
+    #[allow(clippy::uninit_vec)]
+    pub unsafe fn uninitialized(_log_size: u32) -> Self {
+        todo!()
+    }
+
+    /// Returns the log2 of the number of non-packed rows in each column.
+    pub fn log_size(&self) -> u32 {
+        self.log_size
+    }
+
+    /// Returns a sequential iterator over the packed rows of the trace.
+    /// Each item holds one mutable [`PackedM31`] reference per column.
+    pub fn iter_mut(&mut self) -> RowIterMut<'_, N> {
+        let v = self
+            .data
+            .iter_mut()
+            .map(|column| column.as_mut_slice())
+            .collect_vec()
+            .try_into()
+            .unwrap();
+        RowIterMut::new(v)
+    }
+
+    /// Returns a rayon parallel iterator over the packed rows of the trace.
+    /// Each item holds one mutable [`PackedM31`] reference per column.
+    pub fn par_iter_mut(&mut self) -> ParRowIterMut<'_, N> {
+        let v = self
+            .data
+            .iter_mut()
+            .map(|column| column.as_mut_slice())
+            .collect_vec()
+            .try_into()
+            .unwrap();
+        ParRowIterMut::new(v)
+    }
+
+    /// Not implemented yet: calling this currently panics via `todo!()`.
+    pub fn to_evals(self) -> [CircleEvaluation<SimdBackend, M31, BitReversedOrder>; N] {
+        todo!()
+    }
+
+    /// Returns the values of all `N` columns at the non-packed row index `row`.
+    ///
+    /// # Panics
+    /// Panics if `row` is not smaller than `2^log_size`.
+    pub fn row_at(&self, row: usize) -> [M31; N] {
+        assert!(row < 1 << self.log_size);
+        let packed_row = row / N_LANES;
+        let idx_in_simd_vector = row % N_LANES;
+        self.data
+            .iter()
+            .map(|column| column[packed_row].to_array()[idx_in_simd_vector])
+            .collect_vec()
+            .try_into()
+            .unwrap()
+    }
+}
+
+/// A single packed row of a trace: one mutable [`PackedM31`] reference per column.
+pub type MutRow<'trace, const N: usize> = [&'trace mut PackedM31; N];
+
+/// An iterator over mutable references to the rows of a [`ComponentTrace`].
+pub struct RowIterMut<'trace, const N: usize> {
+    /// The not-yet-yielded part of every column. All entries have the same length.
+    v: [*mut [PackedM31]; N],
+    phantom: PhantomData<&'trace ()>,
+}
+impl<'trace, const N: usize> RowIterMut<'trace, N> {
+    /// Creates a row iterator over `N` columns.
+    ///
+    /// # Panics
+    /// Panics if the columns do not all have the same length. The unsafe splitting done while
+    /// iterating relies on this.
+    pub fn new(slice: [&'trace mut [PackedM31]; N]) -> Self {
+        assert!(
+            slice.windows(2).all(|pair| pair[0].len() == pair[1].len()),
+            "all columns must have the same length"
+        );
+        Self {
+            v: slice.map(|s| s as *mut _),
+            phantom: PhantomData,
+        }
+    }
+}
+impl<'trace, const N: usize> Iterator for RowIterMut<'trace, N> {
+    type Item = MutRow<'trace, N>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.v[0].is_empty() {
+            return None;
+        }
+        let item = std::array::from_fn(|i| unsafe {
+            // SAFETY: All columns have the same length (checked in `new`) and column 0 is not
+            // empty, so splitting at 1 is in bounds. The yielded element is removed from
+            // `self.v`, so it is never handed out twice.
+            let (head, tail) = self.v[i].split_at_mut(1);
+            self.v[i] = tail;
+            &mut (*head)[0]
+        });
+        Some(item)
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let len = self.v[0].len();
+        (len, Some(len))
+    }
+}
+impl<const N: usize> ExactSizeIterator for RowIterMut<'_, N> {}
+impl<const N: usize> DoubleEndedIterator for RowIterMut<'_, N> {
+    fn next_back(&mut self) -> Option<Self::Item> {
+        if self.v[0].is_empty() {
+            return None;
+        }
+        let item = std::array::from_fn(|i| unsafe {
+            // SAFETY: All columns have the same length (checked in `new`) and column 0 is not
+            // empty, so splitting at `len - 1` is in bounds. The yielded element is removed
+            // from `self.v`, so it is never handed out twice.
+            let (head, tail) = self.v[i].split_at_mut(self.v[i].len() - 1);
+            self.v[i] = head;
+            &mut (*tail)[0]
+        });
+        Some(item)
+    }
+}
+
+/// A rayon [`Producer`] that splits all columns at the same row index.
+struct RowProducer<'trace, const N: usize> {
+    data: [&'trace mut [PackedM31]; N],
+}
+impl<'trace, const N: usize> Producer for RowProducer<'trace, N> {
+    type Item = MutRow<'trace, N>;
+
+    fn split_at(self, index: usize) -> (Self, Self) {
+        // Start from empty slices rather than `std::mem::zeroed()`: a zeroed `&mut [_]` is a
+        // null reference, which is undefined behavior even if it is overwritten before use.
+        let mut left: [&'trace mut [PackedM31]; N] = std::array::from_fn(|_| Default::default());
+        let mut right: [&'trace mut [PackedM31]; N] = std::array::from_fn(|_| Default::default());
+        for (i, slice) in self.data.into_iter().enumerate() {
+            let (lhs, rhs) = slice.split_at_mut(index);
+            left[i] = lhs;
+            right[i] = rhs;
+        }
+        (RowProducer { data: left }, RowProducer { data: right })
+    }
+
+    type IntoIter = RowIterMut<'trace, N>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        RowIterMut::new(self.data)
+    }
+}
+
+/// A rayon parallel iterator over mutable references to the rows of a [`ComponentTrace`].
+pub struct ParRowIterMut<'trace, const N: usize> {
+    data: [&'trace mut [PackedM31]; N],
+}
+impl<'trace, const N: usize> ParRowIterMut<'trace, N> {
+    pub(super) fn new(data: [&'trace mut [PackedM31]; N]) -> Self {
+        Self { data }
+    }
+}
+impl<'trace, const N: usize> ParallelIterator for ParRowIterMut<'trace, N> {
+    type Item = MutRow<'trace, N>;
+
+    fn drive_unindexed<D>(self, consumer: D) -> D::Result
+    where
+        D: UnindexedConsumer<Self::Item>,
+    {
+        bridge(self, consumer)
+    }
+
+    fn opt_len(&self) -> Option<usize> {
+        Some(self.len())
+    }
+}
+impl<const N: usize> IndexedParallelIterator for ParRowIterMut<'_, N> {
+    fn len(&self) -> usize {
+        self.data[0].len()
+    }
+
+    fn drive<D: Consumer<Self::Item>>(self, consumer: D) -> D::Result {
+        bridge(self, consumer)
+    }
+
+    fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
+        callback.callback(RowProducer { data: self.data })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use itertools::Itertools;
+    use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES};
+    use stwo_prover::core::fields::m31::M31;
+    use stwo_prover::core::fields::FieldExpOps;
+
+    #[test]
+    fn test_parallel_trace() {
+        use rayon::iter::{IndexedParallelIterator, ParallelIterator};
+        use rayon::slice::ParallelSlice;
+
+        const N_COLUMNS: usize = 3;
+        const LOG_SIZE: u32 = 8;
+        let mut trace = super::ComponentTrace::<N_COLUMNS>::zeroed(LOG_SIZE);
+        let arr = (0..1 << LOG_SIZE).map(M31::from).collect_vec();
+        let expected = arr
+            .iter()
+            .map(|&a| {
+                let b = a + M31::from(1);
+                let c = a.square() + b.square();
+                (a, b, c)
+            })
+            .multiunzip();
+
+        trace
+            .par_iter_mut()
+            .zip(arr.par_chunks(N_LANES))
+            .chunks(4)
+            .for_each(|chunk| {
+                chunk.into_iter().for_each(|(row, input)| {
+                    *row[0] = PackedM31::from_array(input.try_into().unwrap());
+                    *row[1] = *row[0] + PackedM31::broadcast(M31(1));
+                    *row[2] = row[0].square() + row[1].square();
+                })
+            });
+        let actual = trace
+            .data
+            .map(|c| {
+                c.into_iter()
+                    .flat_map(|packed| packed.to_array())
+                    .collect_vec()
+            })
+            .into_iter()
+            .next_tuple()
+            .unwrap();
+
+        assert_eq!(expected, actual);
+    }
+}
diff --git a/crates/air_utils/src/trace/examle_lookup_data.rs b/crates/air_utils/src/trace/examle_lookup_data.rs
new file mode 100644
index 000000000..9436071bc
--- /dev/null
+++ b/crates/air_utils/src/trace/examle_lookup_data.rs
@@ -0,0 +1,261 @@
+// TODO(Ohad): write a derive macro for this.
+use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer};
+use rayon::prelude::*;
+use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES};
+
+/// Example lookup data with two lookup columns, stored as vectors of packed rows.
+pub struct LookupData {
+    pub lu0: Vec<[PackedM31; 2]>,
+    pub lu1: Vec<[PackedM31; 4]>,
+}
+impl LookupData {
+    /// Allocates `lu0` and `lu1` with `(1 << log_size) / N_LANES` uninitialized entries each.
+    ///
+    /// # Safety
+    /// The contents of `lu0` and `lu1` are uninitialized memory. The caller must write every
+    /// element before it is read.
+    // The vectors are left uninitialized on purpose; see the safety contract above.
+    #[allow(clippy::uninit_vec)]
+    pub unsafe fn uninitialized(log_size: u32) -> Self {
+        let length = 1 << log_size;
+        let n_simd_elems = length / N_LANES;
+        let mut lu0 = Vec::with_capacity(n_simd_elems);
+        let mut lu1 = Vec::with_capacity(n_simd_elems);
+        lu0.set_len(n_simd_elems);
+        lu1.set_len(n_simd_elems);
+
+        Self { lu0, lu1 }
+    }
+
+    /// Returns a sequential iterator over the rows of the lookup data.
+    ///
+    /// # Panics
+    /// Panics if `lu0` and `lu1` do not have the same length.
+    pub fn iter_mut(&mut self) -> LookupDataIterMut<'_> {
+        LookupDataIterMut::new(&mut self.lu0, &mut self.lu1)
+    }
+
+    /// Returns a rayon parallel iterator over the rows of the lookup data.
+    ///
+    /// # Panics
+    /// Panics if `lu0` and `lu1` do not have the same length.
+    pub fn par_iter_mut(&mut self) -> ParLookupDataIterMut<'_> {
+        ParLookupDataIterMut::new(&mut self.lu0, &mut self.lu1)
+    }
+}
+
+/// One row of [`LookupData`]: a mutable reference into each lookup column.
+pub struct LookupDataMutChunk<'trace> {
+    pub lu0: &'trace mut [PackedM31; 2],
+    pub lu1: &'trace mut [PackedM31; 4],
+}
+/// An iterator over mutable references to the rows of a [`LookupData`].
+pub struct LookupDataIterMut<'trace> {
+    lu0: *mut [[PackedM31; 2]],
+    lu1: *mut [[PackedM31; 4]],
+    phantom: std::marker::PhantomData<&'trace ()>,
+}
+impl<'trace> LookupDataIterMut<'trace> {
+    /// Creates a row iterator over the two lookup columns.
+    ///
+    /// # Panics
+    /// Panics if the columns do not have the same length. The unsafe splitting done while
+    /// iterating relies on this.
+    pub fn new(slice0: &'trace mut [[PackedM31; 2]], slice1: &'trace mut [[PackedM31; 4]]) -> Self {
+        assert_eq!(slice0.len(), slice1.len(), "lookup columns must have the same length");
+        Self {
+            lu0: slice0 as *mut _,
+            lu1: slice1 as *mut _,
+            phantom: std::marker::PhantomData,
+        }
+    }
+}
+impl<'trace> Iterator for LookupDataIterMut<'trace> {
+    type Item = LookupDataMutChunk<'trace>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.lu0.is_empty() {
+            return None;
+        }
+        // SAFETY: Both columns have the same length (checked in `new`) and `lu0` is not empty,
+        // so splitting at 1 is in bounds. The yielded elements are removed from `self`, so they
+        // are never handed out twice.
+        let item = unsafe {
+            let (head0, tail0) = self.lu0.split_at_mut(1);
+            let (head1, tail1) = self.lu1.split_at_mut(1);
+            self.lu0 = tail0;
+            self.lu1 = tail1;
+            LookupDataMutChunk {
+                lu0: &mut (*head0)[0],
+                lu1: &mut (*head1)[0],
+            }
+        };
+        Some(item)
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let len = self.lu0.len();
+        (len, Some(len))
+    }
+}
+
+impl ExactSizeIterator for LookupDataIterMut<'_> {}
+impl DoubleEndedIterator for LookupDataIterMut<'_> {
+    fn next_back(&mut self) -> Option<Self::Item> {
+        if self.lu0.is_empty() {
+            return None;
+        }
+        // SAFETY: Both columns have the same length (checked in `new`) and `lu0` is not empty,
+        // so splitting at `len - 1` is in bounds. The yielded elements are removed from `self`,
+        // so they are never handed out twice.
+        let item = unsafe {
+            let (head0, tail0) = self.lu0.split_at_mut(self.lu0.len() - 1);
+            let (head1, tail1) = self.lu1.split_at_mut(self.lu1.len() - 1);
+            self.lu0 = head0;
+            self.lu1 = head1;
+            LookupDataMutChunk {
+                lu0: &mut (*tail0)[0],
+                lu1: &mut (*tail1)[0],
+            }
+        };
+        Some(item)
+    }
+}
+
+/// A rayon [`Producer`] that splits both lookup columns at the same row index.
+struct RowProducer<'trace> {
+    lu0: &'trace mut [[PackedM31; 2]],
+    lu1: &'trace mut [[PackedM31; 4]],
+}
+
+impl<'trace> Producer for RowProducer<'trace> {
+    type Item = LookupDataMutChunk<'trace>;
+
+    fn split_at(self, index: usize) -> (Self, Self) {
+        let (lu0, rh0) = self.lu0.split_at_mut(index);
+        let (lu1, rh1) = self.lu1.split_at_mut(index);
+        (RowProducer { lu0, lu1 }, RowProducer { lu0: rh0, lu1: rh1 })
+    }
+
+    type IntoIter = LookupDataIterMut<'trace>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        LookupDataIterMut::new(self.lu0, self.lu1)
+    }
+}
+
+/// A rayon parallel iterator over mutable references to the rows of a [`LookupData`].
+pub struct ParLookupDataIterMut<'trace> {
+    lu0: &'trace mut [[PackedM31; 2]],
+    lu1: &'trace mut [[PackedM31; 4]],
+}
+
+impl<'trace> ParLookupDataIterMut<'trace> {
+    /// Creates a parallel row iterator over the two lookup columns.
+    ///
+    /// # Panics
+    /// Panics if the columns do not have the same length.
+    pub fn new(slice0: &'trace mut [[PackedM31; 2]], slice1: &'trace mut [[PackedM31; 4]]) -> Self {
+        assert_eq!(slice0.len(), slice1.len(), "lookup columns must have the same length");
+        Self {
+            lu0: slice0,
+            lu1: slice1,
+        }
+    }
+}
+
+impl<'trace> ParallelIterator for ParLookupDataIterMut<'trace> {
+    type Item = LookupDataMutChunk<'trace>;
+
+    fn drive_unindexed<D>(self, consumer: D) -> D::Result
+    where
+        D: UnindexedConsumer<Self::Item>,
+    {
+        bridge(self, consumer)
+    }
+
+    fn opt_len(&self) -> Option<usize> {
+        Some(self.len())
+    }
+}
+
+impl IndexedParallelIterator for ParLookupDataIterMut<'_> {
+    fn len(&self) -> usize {
+        self.lu0.len()
+    }
+
+    fn drive<D: Consumer<Self::Item>>(self, consumer: D) -> D::Result {
+        bridge(self, consumer)
+    }
+
+    fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
+        callback.callback(RowProducer {
+            lu0: self.lu0,
+            lu1: self.lu1,
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use itertools::{all, Itertools};
+    use rayon::iter::{IndexedParallelIterator, ParallelIterator};
+    use rayon::slice::ParallelSlice;
+    use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES};
+    use stwo_prover::core::fields::m31::M31;
+
+    use crate::trace::examle_lookup_data::LookupData;
+    use crate::trace::component_trace::ComponentTrace;
+
+    #[test]
+    fn test_lookup_data() {
+        const N_COLUMNS: usize = 5;
+        const LOG_SIZE: u32 = 8;
+        let mut trace = ComponentTrace::<N_COLUMNS>::zeroed(LOG_SIZE);
+        let arr = (0..1 << LOG_SIZE).map(M31::from).collect_vec();
+        let mut lookup_data = unsafe { LookupData::uninitialized(LOG_SIZE) };
+        let expected: (Vec<_>, Vec<_>) = arr
+            .array_chunks::<N_LANES>()
+            .map(|x| {
+                let x = PackedM31::from_array(*x);
+                let x1 = x + PackedM31::broadcast(M31(1));
+                let x2 = x + x1;
+                let x3 = x + x1 + x2;
+                let x4 = x + x1 + x2 + x3;
+                ([x, x4], [x1, x1.double(), x2, x2.double()])
+            })
+            .unzip();
+
+        trace
+            .par_iter_mut()
+            .zip(arr.par_chunks(N_LANES))
+            .zip(lookup_data.par_iter_mut())
+            .chunks(4)
+            .for_each(|chunk| {
+                chunk.into_iter().for_each(|((row, input), lookup_data)| {
+                    *row[0] = PackedM31::from_array(input.try_into().unwrap());
+                    *row[1] = *row[0] + PackedM31::broadcast(M31(1));
+                    *row[2] = *row[0] + *row[1];
+                    *row[3] = *row[0] + *row[1] + *row[2];
+                    *row[4] = *row[0] + *row[1] + *row[2] + *row[3];
+                    *lookup_data.lu0 = [*row[0], *row[4]];
+                    *lookup_data.lu1 = [*row[1], row[1].double(), *row[2], row[2].double()];
+                })
+            });
+
+        assert!(all(
+            lookup_data.lu0.into_iter().zip(expected.0),
+            |(actual, expected)| actual[0].to_array() == expected[0].to_array()
+                && actual[1].to_array() == expected[1].to_array()
+        ));
+        assert!(all(
+            lookup_data.lu1.into_iter().zip(expected.1),
+            |(actual, expected)| {
+                actual[0].to_array() == expected[0].to_array()
+                    && actual[1].to_array() == expected[1].to_array()
+                    && actual[2].to_array() == expected[2].to_array()
+                    && actual[3].to_array() == expected[3].to_array()
+            }
+        ));
+    }
+}
diff --git a/crates/air_utils/src/trace/mod.rs b/crates/air_utils/src/trace/mod.rs
new file mode 100644
index 000000000..4db8186ba
--- /dev/null
+++ b/crates/air_utils/src/trace/mod.rs
@@ -0,0 +1,2 @@
+pub mod component_trace;
+pub mod examle_lookup_data;