diff --git a/Cargo.lock b/Cargo.lock index 54d4d4ff7..c87fc6628 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1041,6 +1041,7 @@ version = "0.1.1" dependencies = [ "bytemuck", "itertools 0.12.1", + "rayon", "stwo-prover", ] diff --git a/crates/air_utils/Cargo.toml b/crates/air_utils/Cargo.toml index c021b76cb..7d09a7eaf 100644 --- a/crates/air_utils/Cargo.toml +++ b/crates/air_utils/Cargo.toml @@ -6,6 +6,7 @@ edition.workspace = true [dependencies] bytemuck.workspace = true itertools.workspace = true +rayon = { version = "1.10.0", optional = false } stwo-prover = { path = "../prover" } [lib] diff --git a/crates/air_utils/src/trace/iterable_trace.rs b/crates/air_utils/src/trace/iterable_trace.rs index f2b700f26..4e7394471 100644 --- a/crates/air_utils/src/trace/iterable_trace.rs +++ b/crates/air_utils/src/trace/iterable_trace.rs @@ -2,6 +2,8 @@ use std::marker::PhantomData; use bytemuck::{cast_slice, Zeroable}; use itertools::Itertools; +use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer}; +use rayon::prelude::*; use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES}; use stwo_prover::core::backend::simd::SimdBackend; use stwo_prover::core::fields::m31::M31; @@ -83,6 +85,17 @@ impl IterableTrace { RowIterMut::new(v) } + pub fn par_iter_mut(&mut self) -> ParRowIterMut<'_, N> { + let v = self + .data + .iter_mut() + .map(|column| column.as_mut_slice()) + .collect_vec() + .try_into() + .unwrap(); + ParRowIterMut::new(v) + } + pub fn to_evals(self) -> [CircleEvaluation; N] { todo!() } @@ -144,6 +157,83 @@ impl<'trace, const N: usize> Iterator for RowIterMut<'trace, N> { } } impl ExactSizeIterator for RowIterMut<'_, N> {} +impl DoubleEndedIterator for RowIterMut<'_, N> { + fn next_back(&mut self) -> Option { + if self.v[0].is_empty() { + return None; + } + let item = std::array::from_fn(|i| unsafe { + // SAFETY: The self.v contract ensures that any split_at_mut is valid. + let (head, tail) = self.v[i].split_at_mut(self.v[i].len() - 1); + self.v[i] = head; + &mut (*tail)[0] + }); + Some(item) + } +} + +struct RowProducer<'trace, const N: usize> { + data: [&'trace mut [PackedM31]; N], +} +impl<'trace, const N: usize> Producer for RowProducer<'trace, N> { + type Item = MutRow<'trace, N>; + + fn split_at(self, index: usize) -> (Self, Self) { + let mut left: [_; N] = unsafe { std::mem::zeroed() }; + let mut right: [_; N] = unsafe { std::mem::zeroed() }; + for (i, slice) in self.data.into_iter().enumerate() { + let (lhs, rhs) = slice.split_at_mut(index); + left[i] = lhs; + right[i] = rhs; + } + (RowProducer { data: left }, RowProducer { data: right }) + } + + type IntoIter = RowIterMut<'trace, N>; + + fn into_iter(self) -> Self::IntoIter { + RowIterMut { + v: self.data.map(|s| s as *mut _), + phantom: PhantomData, + } + } +} + +pub struct ParRowIterMut<'trace, const N: usize> { + data: [&'trace mut [PackedM31]; N], +} +impl<'trace, const N: usize> ParRowIterMut<'trace, N> { + pub(super) fn new(data: [&'trace mut [PackedM31]; N]) -> Self { + Self { data } + } +} +impl<'trace, const N: usize> ParallelIterator for ParRowIterMut<'trace, N> { + type Item = MutRow<'trace, N>; + + fn drive_unindexed(self, consumer: D) -> D::Result + where + D: UnindexedConsumer, + { + bridge(self, consumer) + } + + fn opt_len(&self) -> Option { + Some(self.len()) + } +} +impl IndexedParallelIterator for ParRowIterMut<'_, N> { + fn len(&self) -> usize { + self.data[0].len() + } + + fn drive>(self, consumer: D) -> D::Result { + bridge(self, consumer) + } + + fn with_producer>(self, callback: CB) -> CB::Output { + callback.callback(RowProducer { data: self.data }) + } +} #[cfg(test)] mod tests { @@ -192,4 +282,47 @@ mod tests { assert_eq!(expected, actual); } + + #[test] + fn test_parallel_trace() { + use rayon::iter::{IndexedParallelIterator, ParallelIterator}; + use rayon::slice::ParallelSlice; + + const N_COLUMNS: usize = 3; + const LOG_SIZE: u32 = 8; + let mut trace = super::IterableTrace::::zeroed(LOG_SIZE); + let arr = (0..1 << LOG_SIZE).map(M31::from).collect_vec(); + let expected = arr + .iter() + .map(|&a| { + let b = a + M31::from(1); + let c = a.square() + b.square(); + (a, b, c) + }) + .multiunzip(); + + trace + .par_iter_mut() + .zip(arr.par_chunks(N_LANES)) + .chunks(4) + .for_each(|chunk| { + chunk.into_iter().for_each(|(row, input)| { + *row[0] = PackedM31::from_array(input.try_into().unwrap()); + *row[1] = *row[0] + PackedM31::broadcast(M31(1)); + *row[2] = row[0].square() + row[1].square(); + }) + }); + let actual = trace + .data + .map(|c| { + c.into_iter() + .flat_map(|packed| packed.to_array()) + .collect_vec() + }) + .into_iter() + .next_tuple() + .unwrap(); + + assert_eq!(expected, actual); + } }