Skip to content

Commit

Permalink
derive lookup data impl
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Dec 24, 2024
1 parent c699955 commit 2b0e92f
Show file tree
Hide file tree
Showing 6 changed files with 273 additions and 165 deletions.
19 changes: 19 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[workspace]
members = ["crates/prover", "crates/air_utils"]
members = ["crates/prover", "crates/air_utils", "crates/air_utils_derive"]
resolver = "2"

[workspace.package]
Expand Down
1 change: 1 addition & 0 deletions crates/air_utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ bytemuck.workspace = true
itertools.workspace = true
rayon = { version = "1.10.0", optional = false }
stwo-prover = { path = "../prover" }
stwo-air-utils-derive = { path = "../air_utils_derive" }

[lib]
bench = false
196 changes: 32 additions & 164 deletions crates/air_utils/src/trace/examle_lookup_data.rs
Original file line number Diff line number Diff line change
@@ -1,166 +1,11 @@
// TODO(Ohad): write a derive macro for this.
use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer};
use rayon::prelude::*;
use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES};
use stwo_air_utils_derive::StwoIterable;

#[derive(StwoIterable)]
pub struct LookupData {
pub lu0: Vec<[PackedM31; 2]>,
pub lu1: Vec<[PackedM31; 4]>,
}
impl LookupData {
/// # Safety
pub unsafe fn uninitialized(log_size: u32) -> Self {
let length = 1 << log_size;
let n_simd_elems = length / N_LANES;
let mut lu0 = Vec::with_capacity(n_simd_elems);
let mut lu1 = Vec::with_capacity(n_simd_elems);
lu0.set_len(n_simd_elems);
lu1.set_len(n_simd_elems);

Self { lu0, lu1 }
}

pub fn iter_mut(&mut self) -> LookupDataIterMut<'_> {
LookupDataIterMut::new(&mut self.lu0, &mut self.lu1)
}

pub fn par_iter_mut(&mut self) -> ParLookupDataIterMut<'_> {
ParLookupDataIterMut {
lu0: &mut self.lu0,
lu1: &mut self.lu1,
}
}
}

pub struct LookupDataMutChunk<'trace> {
pub lu0: &'trace mut [PackedM31; 2],
pub lu1: &'trace mut [PackedM31; 4],
}
pub struct LookupDataIterMut<'trace> {
lu0: *mut [[PackedM31; 2]],
lu1: *mut [[PackedM31; 4]],
phantom: std::marker::PhantomData<&'trace ()>,
}
impl<'trace> LookupDataIterMut<'trace> {
pub fn new(slice0: &'trace mut [[PackedM31; 2]], slice1: &'trace mut [[PackedM31; 4]]) -> Self {
Self {
lu0: slice0 as *mut _,
lu1: slice1 as *mut _,
phantom: std::marker::PhantomData,
}
}
}
impl<'trace> Iterator for LookupDataIterMut<'trace> {
type Item = LookupDataMutChunk<'trace>;

fn next(&mut self) -> Option<Self::Item> {
if self.lu0.is_empty() {
return None;
}
let item = unsafe {
let (head0, tail0) = self.lu0.split_at_mut(1);
let (head1, tail1) = self.lu1.split_at_mut(1);
self.lu0 = tail0;
self.lu1 = tail1;
LookupDataMutChunk {
lu0: &mut (*head0)[0],
lu1: &mut (*head1)[0],
}
};
Some(item)
}

fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.lu0.len();
(len, Some(len))
}
}

impl ExactSizeIterator for LookupDataIterMut<'_> {}
impl DoubleEndedIterator for LookupDataIterMut<'_> {
fn next_back(&mut self) -> Option<Self::Item> {
if self.lu0.is_empty() {
return None;
}
let item = unsafe {
let (head0, tail0) = self.lu0.split_at_mut(self.lu0.len() - 1);
let (head1, tail1) = self.lu1.split_at_mut(self.lu1.len() - 1);
self.lu0 = head0;
self.lu1 = head1;
LookupDataMutChunk {
lu0: &mut (*tail0)[0],
lu1: &mut (*tail1)[0],
}
};
Some(item)
}
}

struct RowProducer<'trace> {
lu0: &'trace mut [[PackedM31; 2]],
lu1: &'trace mut [[PackedM31; 4]],
}

impl<'trace> Producer for RowProducer<'trace> {
type Item = LookupDataMutChunk<'trace>;

fn split_at(self, index: usize) -> (Self, Self) {
let (lu0, rh0) = self.lu0.split_at_mut(index);
let (lu1, rh1) = self.lu1.split_at_mut(index);
(RowProducer { lu0, lu1 }, RowProducer { lu0: rh0, lu1: rh1 })
}

type IntoIter = LookupDataIterMut<'trace>;

fn into_iter(self) -> Self::IntoIter {
LookupDataIterMut::new(self.lu0, self.lu1)
}
}

pub struct ParLookupDataIterMut<'trace> {
lu0: &'trace mut [[PackedM31; 2]],
lu1: &'trace mut [[PackedM31; 4]],
}

impl<'trace> ParLookupDataIterMut<'trace> {
pub fn new(slice0: &'trace mut [[PackedM31; 2]], slice1: &'trace mut [[PackedM31; 4]]) -> Self {
Self {
lu0: slice0,
lu1: slice1,
}
}
}

impl<'trace> ParallelIterator for ParLookupDataIterMut<'trace> {
type Item = LookupDataMutChunk<'trace>;

fn drive_unindexed<D>(self, consumer: D) -> D::Result
where
D: UnindexedConsumer<Self::Item>,
{
bridge(self, consumer)
}

fn opt_len(&self) -> Option<usize> {
Some(self.len())
}
}

impl IndexedParallelIterator for ParLookupDataIterMut<'_> {
fn len(&self) -> usize {
self.lu0.len()
}

fn drive<D: Consumer<Self::Item>>(self, consumer: D) -> D::Result {
bridge(self, consumer)
}

fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
callback.callback(RowProducer {
lu0: self.lu0,
lu1: self.lu1,
})
}
lu0: Vec<[PackedM31; 2]>,
lu1: Vec<[PackedM31; 4]>,
lu2: Vec<[PackedM31; 8]>,
}

#[cfg(test)]
Expand All @@ -177,21 +22,34 @@ mod tests {
#[test]
fn test_lookup_data() {
const N_COLUMNS: usize = 5;
const LOG_SIZE: u32 = 8;
const LOG_SIZE: u32 = 12;
let mut trace = IterableTrace::<N_COLUMNS>::zeroed(LOG_SIZE);
let arr = (0..1 << LOG_SIZE).map(M31::from).collect_vec();
let mut lookup_data = unsafe { LookupData::uninitialized(LOG_SIZE) };
let expected: (Vec<_>, Vec<_>) = arr
let expected: (Vec<_>, Vec<_>, Vec<_>) = arr
.array_chunks::<N_LANES>()
.map(|x| {
let x = PackedM31::from_array(*x);
let x1 = x + PackedM31::broadcast(M31(1));
let x2 = x + x1;
let x3 = x + x1 + x2;
let x4 = x + x1 + x2 + x3;
([x, x4], [x1, x1.double(), x2, x2.double()])
(
[x, x4],
[x1, x1.double(), x2, x2.double()],
[
x3,
x3.double(),
x4,
x4.double(),
x,
x.double(),
x1,
x1.double(),
],
)
})
.unzip();
.multiunzip();

trace
.par_iter_mut()
Expand All @@ -207,6 +65,16 @@ mod tests {
*row[4] = *row[0] + *row[1] + *row[2] + *row[3];
*lookup_data.lu0 = [*row[0], *row[4]];
*lookup_data.lu1 = [*row[1], row[1].double(), *row[2], row[2].double()];
*lookup_data.lu2 = [
*row[3],
row[3].double(),
*row[4],
row[4].double(),
*row[0],
row[0].double(),
*row[1],
row[1].double(),
];
})
});

Expand Down
12 changes: 12 additions & 0 deletions crates/air_utils_derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "stwo-air-utils-derive"
version = "0.1.0"
edition = "2021"

[lib]
proc-macro = true

[dependencies]
syn = "2.0.90"
quote = "1.0.37"
itertools = "0.13.0"
Loading

0 comments on commit 2b0e92f

Please sign in to comment.