diff --git a/Cargo.lock b/Cargo.lock index 54d4d4ff7..c87fc6628 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1041,6 +1041,7 @@ version = "0.1.1" dependencies = [ "bytemuck", "itertools 0.12.1", + "rayon", "stwo-prover", ] diff --git a/crates/air_utils/Cargo.toml b/crates/air_utils/Cargo.toml index c021b76cb..7d09a7eaf 100644 --- a/crates/air_utils/Cargo.toml +++ b/crates/air_utils/Cargo.toml @@ -6,6 +6,7 @@ edition.workspace = true [dependencies] bytemuck.workspace = true itertools.workspace = true +rayon = { version = "1.10.0", optional = false } stwo-prover = { path = "../prover" } [lib] diff --git a/crates/air_utils/src/trace/component_trace.rs b/crates/air_utils/src/trace/component_trace.rs index 1fe6928f0..20a96579d 100644 --- a/crates/air_utils/src/trace/component_trace.rs +++ b/crates/air_utils/src/trace/component_trace.rs @@ -1,5 +1,3 @@ -use std::marker::PhantomData; - use bytemuck::Zeroable; use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; use stwo_prover::core::backend::simd::SimdBackend; @@ -7,6 +5,8 @@ use stwo_prover::core::fields::m31::M31; use stwo_prover::core::poly::circle::CircleEvaluation; use stwo_prover::core::poly::BitReversedOrder; +use super::row_iterator::{ParRowIterMut, RowIterMut}; + /// A 2D Matrix of [`PackedM31`] values. /// Used for generating the witness of 'Stwo' proofs. /// Stored as an array of `N` columns, each column is a vector of [`PackedM31`] values. 
@@ -74,6 +74,12 @@ impl<const N: usize> ComponentTrace<N> {
         RowIterMut::new(self.data.each_mut().map(|column| column.as_mut_slice()))
     }
 
+    /// Returns a rayon parallel iterator over the rows of the trace; each item holds one mutable
+    /// reference per column.
+    pub fn par_iter_mut(&mut self) -> ParRowIterMut<'_, N> {
+        ParRowIterMut::new(self.data.each_mut().map(|column| column.as_mut_slice()))
+    }
+
     pub fn to_evals(self) -> [CircleEvaluation<SimdBackend, M31, BitReversedOrder>; N] {
         todo!()
     }
@@ -88,40 +94,58 @@ impl<const N: usize> ComponentTrace<N> {
     }
 }
 
-pub type MutRow<'trace, const N: usize> = [&'trace mut PackedM31; N];
+#[cfg(test)]
+mod tests {
+    use itertools::Itertools;
+    use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES};
+    use stwo_prover::core::fields::m31::M31;
+    use stwo_prover::core::fields::FieldExpOps;
 
-/// An iterator over mutable references to the rows of a [`ComponentTrace`].
-pub struct RowIterMut<'trace, const N: usize> {
-    v: [*mut [PackedM31]; N],
-    phantom: PhantomData<&'trace ()>,
-}
-impl<'trace, const N: usize> RowIterMut<'trace, N> {
-    pub fn new(slice: [&'trace mut [PackedM31]; N]) -> Self {
-        Self {
-            v: slice.map(|s| s as *mut _),
-            phantom: PhantomData,
-        }
-    }
-}
-impl<'trace, const N: usize> Iterator for RowIterMut<'trace, N> {
-    type Item = MutRow<'trace, N>;
+    /// Fills a 3-column trace through `par_iter_mut` and compares every column with values
+    /// computed sequentially.
+    #[test]
+    fn test_parallel_trace() {
+        use rayon::iter::{IndexedParallelIterator, ParallelIterator};
+        use rayon::slice::ParallelSlice;
 
-    fn next(&mut self) -> Option<Self::Item> {
-        if self.v[0].is_empty() {
-            return None;
-        }
-        let item = std::array::from_fn(|i| unsafe {
-            // SAFETY: The self.v contract ensures that any split_at_mut is valid.
-            let (head, tail) = self.v[i].split_at_mut(1);
-            self.v[i] = tail;
-            &mut (*head)[0]
-        });
-        Some(item)
-    }
+        const N_COLUMNS: usize = 3;
+        const LOG_SIZE: u32 = 8;
+        const CHUNK_SIZE: usize = 4;
+        let mut trace = super::ComponentTrace::<N_COLUMNS>::zeroed(LOG_SIZE);
+        let arr = (0..1 << LOG_SIZE).map(M31::from).collect_vec();
+        // Reference columns: a, b = a + 1, c = a^2 + b^2.
+        let expected = arr
+            .iter()
+            .map(|&a| {
+                let b = a + M31::from(1);
+                let c = a.square() + b.square();
+                (a, b, c)
+            })
+            .multiunzip();
+
+        // Each packed row consumes `N_LANES` consecutive inputs.
+        trace
+            .par_iter_mut()
+            .zip(arr.par_chunks(N_LANES))
+            .chunks(CHUNK_SIZE)
+            .for_each(|chunk| {
+                chunk.into_iter().for_each(|(row, input)| {
+                    *row[0] = PackedM31::from_array(input.try_into().unwrap());
+                    *row[1] = *row[0] + PackedM31::broadcast(M31(1));
+                    *row[2] = row[0].square() + row[1].square();
+                });
+            });
+        let actual = trace
+            .data
+            .map(|c| {
+                c.into_iter()
+                    .flat_map(|packed| packed.to_array())
+                    .collect_vec()
+            })
+            .into_iter()
+            .next_tuple()
+            .unwrap();
 
-    fn size_hint(&self) -> (usize, Option<usize>) {
-        let len = self.v[0].len();
-        (len, Some(len))
+        assert_eq!(expected, actual);
     }
 }
-impl<const N: usize> ExactSizeIterator for RowIterMut<'_, N> {}
diff --git a/crates/air_utils/src/trace/mod.rs b/crates/air_utils/src/trace/mod.rs
index 03a022de5..6e44c9033 100644
--- a/crates/air_utils/src/trace/mod.rs
+++ b/crates/air_utils/src/trace/mod.rs
@@ -1 +1,2 @@
 pub mod component_trace;
+mod row_iterator;
diff --git a/crates/air_utils/src/trace/row_iterator.rs b/crates/air_utils/src/trace/row_iterator.rs
new file mode 100644
index 000000000..78d03ebea
--- /dev/null
+++ b/crates/air_utils/src/trace/row_iterator.rs
@@ -0,0 +1,127 @@
+use std::marker::PhantomData;
+
+use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer};
+use rayon::prelude::*;
+use stwo_prover::core::backend::simd::m31::PackedM31;
+
+/// One row of the trace: a mutable reference into each of the `N` columns.
+pub type MutRow<'trace, const N: usize> = [&'trace mut PackedM31; N];
+
+/// An iterator over mutable references to the rows of a [`super::component_trace::ComponentTrace`].
+// TODO(Ohad): Iterating over single rows is not optimal, figure out optimal chunk size when using
+// this iterator.
+pub struct RowIterMut<'trace, const N: usize> {
+    v: [*mut [PackedM31]; N],
+    phantom: PhantomData<&'trace ()>,
+}
+impl<'trace, const N: usize> RowIterMut<'trace, N> {
+    /// Columns must have equal lengths: iteration only checks column 0 for emptiness.
+    pub fn new(slice: [&'trace mut [PackedM31]; N]) -> Self {
+        Self {
+            v: slice.map(|s| s as *mut _),
+            phantom: PhantomData,
+        }
+    }
+}
+impl<'trace, const N: usize> Iterator for RowIterMut<'trace, N> {
+    type Item = MutRow<'trace, N>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.v[0].is_empty() {
+            return None;
+        }
+        let item = std::array::from_fn(|i| unsafe {
+            // SAFETY: The self.v contract ensures that any split_at_mut is valid.
+            let (head, tail) = self.v[i].split_at_mut(1);
+            self.v[i] = tail;
+            &mut (*head)[0]
+        });
+        Some(item)
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let len = self.v[0].len();
+        (len, Some(len))
+    }
+}
+impl<const N: usize> ExactSizeIterator for RowIterMut<'_, N> {}
+impl<const N: usize> DoubleEndedIterator for RowIterMut<'_, N> {
+    fn next_back(&mut self) -> Option<Self::Item> {
+        if self.v[0].is_empty() {
+            return None;
+        }
+        let item = std::array::from_fn(|i| unsafe {
+            // SAFETY: The self.v contract ensures that any split_at_mut is valid.
+            let (head, tail) = self.v[i].split_at_mut(self.v[i].len() - 1);
+            self.v[i] = head;
+            &mut (*tail)[0]
+        });
+        Some(item)
+    }
+}
+
+/// Rayon [`Producer`] over the rows of `N` column slices; backs [`ParRowIterMut`].
+struct RowProducer<'trace, const N: usize> {
+    data: [&'trace mut [PackedM31]; N],
+}
+impl<'trace, const N: usize> Producer for RowProducer<'trace, N> {
+    type Item = MutRow<'trace, N>;
+
+    /// Splits every column at row `index`, yielding producers for the rows before and from it.
+    fn split_at(self, index: usize) -> (Self, Self) {
+        // `mem::take` swaps in an empty slice, so no `unsafe` placeholder references are needed.
+        let mut right = self.data;
+        let left = std::array::from_fn(|i| {
+            let (lhs, rhs) = std::mem::take(&mut right[i]).split_at_mut(index);
+            right[i] = rhs;
+            lhs
+        });
+        (RowProducer { data: left }, RowProducer { data: right })
+    }
+
+    type IntoIter = RowIterMut<'trace, N>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        RowIterMut::new(self.data)
+    }
+}
+
+/// A parallel iterator over mutable references to the rows of a
+/// [`super::component_trace::ComponentTrace`]. [`super::component_trace::ComponentTrace`] is an
+/// array of columns, hence iterating over rows is not trivial. Iteration is done by iterating over
+/// `N` columns in parallel.
+pub struct ParRowIterMut<'trace, const N: usize> {
+    data: [&'trace mut [PackedM31]; N],
+}
+impl<'trace, const N: usize> ParRowIterMut<'trace, N> {
+    pub(super) fn new(data: [&'trace mut [PackedM31]; N]) -> Self {
+        Self { data }
+    }
+}
+impl<'trace, const N: usize> ParallelIterator for ParRowIterMut<'trace, N> {
+    type Item = MutRow<'trace, N>;
+
+    fn drive_unindexed<D: UnindexedConsumer<Self::Item>>(self, consumer: D) -> D::Result {
+        bridge(self, consumer)
+    }
+
+    /// The length is always known exactly because the iterator is indexed.
+    fn opt_len(&self) -> Option<usize> {
+        Some(self.len())
+    }
+}
+impl<const N: usize> IndexedParallelIterator for ParRowIterMut<'_, N> {
+    /// Number of rows, read from column 0.
+    fn len(&self) -> usize {
+        self.data[0].len()
+    }
+
+    fn drive<D: Consumer<Self::Item>>(self, consumer: D) -> D::Result {
+        bridge(self, consumer)
+    }
+
+    /// Hands rayon a [`RowProducer`] that owns the column slices.
+    fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
+        callback.callback(RowProducer { data: self.data })
+    }
+}