Skip to content

Commit

Permalink
par trace
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Dec 24, 2024
1 parent 11b8841 commit eda26b7
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 0 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/air_utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ edition.workspace = true
[dependencies]
bytemuck.workspace = true
itertools.workspace = true
rayon = { version = "1.10.0", optional = false }
stwo-prover = { path = "../prover" }

[lib]
Expand Down
133 changes: 133 additions & 0 deletions crates/air_utils/src/trace/component_trace.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::marker::PhantomData;

use bytemuck::Zeroable;
use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer};
use rayon::prelude::*;
use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES};
use stwo_prover::core::backend::simd::SimdBackend;
use stwo_prover::core::fields::m31::M31;
Expand Down Expand Up @@ -74,6 +76,10 @@ impl<const N: usize> ComponentTrace<N> {
RowIterMut::new(self.data.each_mut().map(|column| column.as_mut_slice()))
}

pub fn par_iter_mut(&mut self) -> ParRowIterMut<'_, N> {
ParRowIterMut::new(self.data.each_mut().map(|column| column.as_mut_slice()))
}

pub fn to_evals(self) -> [CircleEvaluation<SimdBackend, M31, BitReversedOrder>; N] {
todo!()
}
Expand Down Expand Up @@ -125,3 +131,130 @@ impl<'trace, const N: usize> Iterator for RowIterMut<'trace, N> {
}
}
impl<const N: usize> ExactSizeIterator for RowIterMut<'_, N> {}
impl<const N: usize> DoubleEndedIterator for RowIterMut<'_, N> {
fn next_back(&mut self) -> Option<Self::Item> {
if self.v[0].is_empty() {
return None;
}
let item = std::array::from_fn(|i| unsafe {
// SAFETY: The self.v contract ensures that any split_at_mut is valid.
let (head, tail) = self.v[i].split_at_mut(self.v[i].len() - 1);
self.v[i] = head;
&mut (*tail)[0]
});
Some(item)
}
}

struct RowProducer<'trace, const N: usize> {
data: [&'trace mut [PackedM31]; N],
}
impl<'trace, const N: usize> Producer for RowProducer<'trace, N> {
type Item = MutRow<'trace, N>;

fn split_at(self, index: usize) -> (Self, Self) {
let mut left: [_; N] = unsafe { std::mem::zeroed() };
let mut right: [_; N] = unsafe { std::mem::zeroed() };
for (i, slice) in self.data.into_iter().enumerate() {
let (lhs, rhs) = slice.split_at_mut(index);
left[i] = lhs;
right[i] = rhs;
}
(RowProducer { data: left }, RowProducer { data: right })
}

type IntoIter = RowIterMut<'trace, N>;

fn into_iter(self) -> Self::IntoIter {
RowIterMut {
v: self.data.map(|s| s as *mut _),
phantom: PhantomData,
}
}
}

/// A parallel iterator over mutable references to the rows of a [`ComponentTrace`].
/// [`ComponentTrace`] is an array of columns, hence iterating over rows is not trivial.
pub struct ParRowIterMut<'trace, const N: usize> {
data: [&'trace mut [PackedM31]; N],
}
impl<'trace, const N: usize> ParRowIterMut<'trace, N> {
pub(super) fn new(data: [&'trace mut [PackedM31]; N]) -> Self {
Self { data }
}
}
impl<'trace, const N: usize> ParallelIterator for ParRowIterMut<'trace, N> {
type Item = MutRow<'trace, N>;

fn drive_unindexed<D>(self, consumer: D) -> D::Result
where
D: UnindexedConsumer<Self::Item>,
{
bridge(self, consumer)
}

fn opt_len(&self) -> Option<usize> {
Some(self.len())
}
}
impl<const N: usize> IndexedParallelIterator for ParRowIterMut<'_, N> {
fn len(&self) -> usize {
self.data[0].len()
}

fn drive<D: Consumer<Self::Item>>(self, consumer: D) -> D::Result {
bridge(self, consumer)
}

fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
callback.callback(RowProducer { data: self.data })
}
}

#[cfg(test)]
mod tests {
use itertools::Itertools;
use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES};
use stwo_prover::core::fields::m31::M31;
use stwo_prover::core::fields::FieldExpOps;

#[test]
fn test_parallel_trace() {
use rayon::iter::{IndexedParallelIterator, ParallelIterator};
use rayon::slice::ParallelSlice;

const N_COLUMNS: usize = 3;
const LOG_SIZE: u32 = 8;
let mut trace = super::ComponentTrace::<N_COLUMNS>::zeroed(LOG_SIZE);
let arr = (0..1 << LOG_SIZE).map(M31::from).collect_vec();
let expected = arr
.iter()
.map(|&a| {
let b = a + M31::from(1);
let c = a.square() + b.square();
(a, b, c)
})
.multiunzip();

trace
.par_iter_mut()
.zip(arr.par_chunks(N_LANES))
.for_each(|(row, input)| {
*row[0] = PackedM31::from_array(input.try_into().unwrap());
*row[1] = *row[0] + PackedM31::broadcast(M31(1));
*row[2] = row[0].square() + row[1].square();
});
let actual = trace
.data
.map(|c| {
c.into_iter()
.flat_map(|packed| packed.to_array())
.collect_vec()
})
.into_iter()
.next_tuple()
.unwrap();

assert_eq!(expected, actual);
}
}

0 comments on commit eda26b7

Please sign in to comment.