Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

derive lookup data impl #946

Open
wants to merge 1 commit into
base: ohad/par_lookup_data
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[workspace]
members = ["crates/prover", "crates/air_utils"]
members = ["crates/prover", "crates/air_utils", "crates/air_utils_derive"]
resolver = "2"

[workspace.package]
Expand Down
1 change: 1 addition & 0 deletions crates/air_utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ bytemuck.workspace = true
itertools.workspace = true
rayon = { version = "1.10.0", optional = false }
stwo-prover = { path = "../prover" }
stwo-air-utils-derive = { path = "../air_utils_derive" }

[lib]
bench = false
194 changes: 31 additions & 163 deletions crates/air_utils/src/trace/examle_lookup_data.rs
Original file line number Diff line number Diff line change
@@ -1,166 +1,11 @@
// TODO(Ohad): write a derive macro for this.
use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer};
use rayon::prelude::*;
use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES};
use stwo_air_utils_derive::StwoIterable;

#[derive(StwoIterable)]
pub struct LookupData {
pub lu0: Vec<[PackedM31; 2]>,
pub lu1: Vec<[PackedM31; 4]>,
}
impl LookupData {
/// # Safety
pub unsafe fn uninitialized(log_size: u32) -> Self {
let length = 1 << log_size;
let n_simd_elems = length / N_LANES;
let mut lu0 = Vec::with_capacity(n_simd_elems);
let mut lu1 = Vec::with_capacity(n_simd_elems);
lu0.set_len(n_simd_elems);
lu1.set_len(n_simd_elems);

Self { lu0, lu1 }
}

pub fn iter_mut(&mut self) -> LookupDataIterMut<'_> {
LookupDataIterMut::new(&mut self.lu0, &mut self.lu1)
}

pub fn par_iter_mut(&mut self) -> ParLookupDataIterMut<'_> {
ParLookupDataIterMut {
lu0: &mut self.lu0,
lu1: &mut self.lu1,
}
}
}

pub struct LookupDataMutChunk<'trace> {
pub lu0: &'trace mut [PackedM31; 2],
pub lu1: &'trace mut [PackedM31; 4],
}
pub struct LookupDataIterMut<'trace> {
lu0: *mut [[PackedM31; 2]],
lu1: *mut [[PackedM31; 4]],
phantom: std::marker::PhantomData<&'trace ()>,
}
impl<'trace> LookupDataIterMut<'trace> {
pub fn new(slice0: &'trace mut [[PackedM31; 2]], slice1: &'trace mut [[PackedM31; 4]]) -> Self {
Self {
lu0: slice0 as *mut _,
lu1: slice1 as *mut _,
phantom: std::marker::PhantomData,
}
}
}
impl<'trace> Iterator for LookupDataIterMut<'trace> {
type Item = LookupDataMutChunk<'trace>;

fn next(&mut self) -> Option<Self::Item> {
if self.lu0.is_empty() {
return None;
}
let item = unsafe {
let (head0, tail0) = self.lu0.split_at_mut(1);
let (head1, tail1) = self.lu1.split_at_mut(1);
self.lu0 = tail0;
self.lu1 = tail1;
LookupDataMutChunk {
lu0: &mut (*head0)[0],
lu1: &mut (*head1)[0],
}
};
Some(item)
}

fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.lu0.len();
(len, Some(len))
}
}

impl ExactSizeIterator for LookupDataIterMut<'_> {}
impl DoubleEndedIterator for LookupDataIterMut<'_> {
fn next_back(&mut self) -> Option<Self::Item> {
if self.lu0.is_empty() {
return None;
}
let item = unsafe {
let (head0, tail0) = self.lu0.split_at_mut(self.lu0.len() - 1);
let (head1, tail1) = self.lu1.split_at_mut(self.lu1.len() - 1);
self.lu0 = head0;
self.lu1 = head1;
LookupDataMutChunk {
lu0: &mut (*tail0)[0],
lu1: &mut (*tail1)[0],
}
};
Some(item)
}
}

struct RowProducer<'trace> {
lu0: &'trace mut [[PackedM31; 2]],
lu1: &'trace mut [[PackedM31; 4]],
}

impl<'trace> Producer for RowProducer<'trace> {
type Item = LookupDataMutChunk<'trace>;

fn split_at(self, index: usize) -> (Self, Self) {
let (lu0, rh0) = self.lu0.split_at_mut(index);
let (lu1, rh1) = self.lu1.split_at_mut(index);
(RowProducer { lu0, lu1 }, RowProducer { lu0: rh0, lu1: rh1 })
}

type IntoIter = LookupDataIterMut<'trace>;

fn into_iter(self) -> Self::IntoIter {
LookupDataIterMut::new(self.lu0, self.lu1)
}
}

pub struct ParLookupDataIterMut<'trace> {
lu0: &'trace mut [[PackedM31; 2]],
lu1: &'trace mut [[PackedM31; 4]],
}

impl<'trace> ParLookupDataIterMut<'trace> {
pub fn new(slice0: &'trace mut [[PackedM31; 2]], slice1: &'trace mut [[PackedM31; 4]]) -> Self {
Self {
lu0: slice0,
lu1: slice1,
}
}
}

impl<'trace> ParallelIterator for ParLookupDataIterMut<'trace> {
type Item = LookupDataMutChunk<'trace>;

fn drive_unindexed<D>(self, consumer: D) -> D::Result
where
D: UnindexedConsumer<Self::Item>,
{
bridge(self, consumer)
}

fn opt_len(&self) -> Option<usize> {
Some(self.len())
}
}

impl IndexedParallelIterator for ParLookupDataIterMut<'_> {
fn len(&self) -> usize {
self.lu0.len()
}

fn drive<D: Consumer<Self::Item>>(self, consumer: D) -> D::Result {
bridge(self, consumer)
}

fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
callback.callback(RowProducer {
lu0: self.lu0,
lu1: self.lu1,
})
}
lu0: Vec<[PackedM31; 2]>,
lu1: Vec<[PackedM31; 4]>,
lu2: Vec<[PackedM31; 8]>,
}

#[cfg(test)]
Expand All @@ -181,17 +26,30 @@ mod tests {
let mut trace = ComponentTrace::<N_COLUMNS>::zeroed(LOG_SIZE);
let arr = (0..1 << LOG_SIZE).map(M31::from).collect_vec();
let mut lookup_data = unsafe { LookupData::uninitialized(LOG_SIZE) };
let expected: (Vec<_>, Vec<_>) = arr
let expected: (Vec<_>, Vec<_>, Vec<_>) = arr
.array_chunks::<N_LANES>()
.map(|x| {
let x = PackedM31::from_array(*x);
let x1 = x + PackedM31::broadcast(M31(1));
let x2 = x + x1;
let x3 = x + x1 + x2;
let x4 = x + x1 + x2 + x3;
([x, x4], [x1, x1.double(), x2, x2.double()])
(
[x, x4],
[x1, x1.double(), x2, x2.double()],
[
x3,
x3.double(),
x4,
x4.double(),
x,
x.double(),
x1,
x1.double(),
],
)
})
.unzip();
.multiunzip();

trace
.par_iter_mut()
Expand All @@ -207,6 +65,16 @@ mod tests {
*row[4] = *row[0] + *row[1] + *row[2] + *row[3];
*lookup_data.lu0 = [*row[0], *row[4]];
*lookup_data.lu1 = [*row[1], row[1].double(), *row[2], row[2].double()];
*lookup_data.lu2 = [
*row[3],
row[3].double(),
*row[4],
row[4].double(),
*row[0],
row[0].double(),
*row[1],
row[1].double(),
];
})
});

Expand Down
12 changes: 12 additions & 0 deletions crates/air_utils_derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "stwo-air-utils-derive"
version = "0.1.0"
edition = "2021"

[lib]
proc-macro = true

[dependencies]
syn = "2.0.90"
quote = "1.0.37"
itertools = "0.13.0"
Loading
Loading