From 0fb1f67afd1d6c48691e2349e1ea56874f2126f4 Mon Sep 17 00:00:00 2001 From: Ohad Agadi Date: Mon, 23 Dec 2024 15:52:44 +0200 Subject: [PATCH] derive lookup data impl --- Cargo.lock | 19 ++ Cargo.toml | 2 +- crates/air_utils/Cargo.toml | 1 + .../air_utils/src/trace/examle_lookup_data.rs | 196 +++-------------- crates/air_utils_derive/Cargo.toml | 12 + crates/air_utils_derive/src/lib.rs | 208 ++++++++++++++++++ 6 files changed, 273 insertions(+), 165 deletions(-) create mode 100644 crates/air_utils_derive/Cargo.toml create mode 100644 crates/air_utils_derive/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index c87fc6628..7ba7290e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -593,6 +593,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.14" @@ -1042,9 +1051,19 @@ dependencies = [ "bytemuck", "itertools 0.12.1", "rayon", + "stwo-air-utils-derive", "stwo-prover", ] +[[package]] +name = "stwo-air-utils-derive" +version = "0.1.0" +dependencies = [ + "itertools 0.13.0", + "quote", + "syn 2.0.90", +] + [[package]] name = "stwo-prover" version = "0.1.1" diff --git a/Cargo.toml b/Cargo.toml index fadd620de..d4bb782ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["crates/prover", "crates/air_utils"] +members = ["crates/prover", "crates/air_utils", "crates/air_utils_derive"] resolver = "2" [workspace.package] diff --git a/crates/air_utils/Cargo.toml b/crates/air_utils/Cargo.toml index 7d09a7eaf..4463faf8b 100644 --- a/crates/air_utils/Cargo.toml +++ b/crates/air_utils/Cargo.toml @@ -8,6 +8,7 @@ bytemuck.workspace = true itertools.workspace = true rayon = { version = "1.10.0", optional = false } stwo-prover = { path = "../prover" } +stwo-air-utils-derive = { path = "../air_utils_derive" } [lib] bench = false diff --git a/crates/air_utils/src/trace/examle_lookup_data.rs b/crates/air_utils/src/trace/examle_lookup_data.rs index c78c9ae08..80492882d 100644 --- a/crates/air_utils/src/trace/examle_lookup_data.rs +++ b/crates/air_utils/src/trace/examle_lookup_data.rs @@ -1,166 +1,11 @@ // TODO(Ohad): write a derive macro for this. -use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer}; -use rayon::prelude::*; -use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES}; +use stwo_air_utils_derive::StwoIterable; +#[derive(StwoIterable)] pub struct LookupData { - pub lu0: Vec<[PackedM31; 2]>, - pub lu1: Vec<[PackedM31; 4]>, -} -impl LookupData { - /// # Safety - pub unsafe fn uninitialized(log_size: u32) -> Self { - let length = 1 << log_size; - let n_simd_elems = length / N_LANES; - let mut lu0 = Vec::with_capacity(n_simd_elems); - let mut lu1 = Vec::with_capacity(n_simd_elems); - lu0.set_len(n_simd_elems); - lu1.set_len(n_simd_elems); - - Self { lu0, lu1 } - } - - pub fn iter_mut(&mut self) -> LookupDataIterMut<'_> { - LookupDataIterMut::new(&mut self.lu0, &mut self.lu1) - } - - pub fn par_iter_mut(&mut self) -> ParLookupDataIterMut<'_> { - ParLookupDataIterMut { - lu0: &mut self.lu0, - lu1: &mut self.lu1, - } - } -} - -pub struct LookupDataMutChunk<'trace> { - pub lu0: &'trace mut [PackedM31; 2], - pub lu1: &'trace mut [PackedM31; 4], -} -pub struct LookupDataIterMut<'trace> { - lu0: *mut [[PackedM31; 2]], - lu1: *mut [[PackedM31; 4]], - phantom: std::marker::PhantomData<&'trace ()>, -} -impl<'trace> LookupDataIterMut<'trace> { - pub fn new(slice0: &'trace mut [[PackedM31; 2]], slice1: &'trace mut [[PackedM31; 4]]) -> Self { - Self { - lu0: slice0 as *mut _, - lu1: slice1 as *mut _, - phantom: std::marker::PhantomData, - } - } -} -impl<'trace> Iterator for LookupDataIterMut<'trace> { - type Item = LookupDataMutChunk<'trace>; - - fn next(&mut self) -> Option { - if self.lu0.is_empty() { - return None; - } - let item = unsafe { - let (head0, tail0) = self.lu0.split_at_mut(1); - let (head1, tail1) = self.lu1.split_at_mut(1); - self.lu0 = tail0; - self.lu1 = tail1; - LookupDataMutChunk { - lu0: &mut (*head0)[0], - lu1: &mut (*head1)[0], - } - }; - Some(item) - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.lu0.len(); - (len, Some(len)) - } -} - -impl ExactSizeIterator for LookupDataIterMut<'_> {} -impl DoubleEndedIterator for LookupDataIterMut<'_> { - fn next_back(&mut self) -> Option { - if self.lu0.is_empty() { - return None; - } - let item = unsafe { - let (head0, tail0) = self.lu0.split_at_mut(self.lu0.len() - 1); - let (head1, tail1) = self.lu1.split_at_mut(self.lu1.len() - 1); - self.lu0 = head0; - self.lu1 = head1; - LookupDataMutChunk { - lu0: &mut (*tail0)[0], - lu1: &mut (*tail1)[0], - } - }; - Some(item) - } -} - -struct RowProducer<'trace> { - lu0: &'trace mut [[PackedM31; 2]], - lu1: &'trace mut [[PackedM31; 4]], -} - -impl<'trace> Producer for RowProducer<'trace> { - type Item = LookupDataMutChunk<'trace>; - - fn split_at(self, index: usize) -> (Self, Self) { - let (lu0, rh0) = self.lu0.split_at_mut(index); - let (lu1, rh1) = self.lu1.split_at_mut(index); - (RowProducer { lu0, lu1 }, RowProducer { lu0: rh0, lu1: rh1 }) - } - - type IntoIter = LookupDataIterMut<'trace>; - - fn into_iter(self) -> Self::IntoIter { - LookupDataIterMut::new(self.lu0, self.lu1) - } -} - -pub struct ParLookupDataIterMut<'trace> { - lu0: &'trace mut [[PackedM31; 2]], - lu1: &'trace mut [[PackedM31; 4]], -} - -impl<'trace> ParLookupDataIterMut<'trace> { - pub fn new(slice0: &'trace mut [[PackedM31; 2]], slice1: &'trace mut [[PackedM31; 4]]) -> Self { - Self { - lu0: slice0, - lu1: slice1, - } - } -} - -impl<'trace> ParallelIterator for ParLookupDataIterMut<'trace> { - type Item = LookupDataMutChunk<'trace>; - - fn drive_unindexed(self, consumer: D) -> D::Result - where - D: UnindexedConsumer, - { - bridge(self, consumer) - } - - fn opt_len(&self) -> Option { - Some(self.len()) - } -} - -impl IndexedParallelIterator for ParLookupDataIterMut<'_> { - fn len(&self) -> usize { - self.lu0.len() - } - - fn drive>(self, consumer: D) -> D::Result { - bridge(self, consumer) - } - - fn with_producer>(self, callback: CB) -> CB::Output { - callback.callback(RowProducer { - lu0: self.lu0, - lu1: self.lu1, - }) - } + lu0: Vec<[PackedM31; 2]>, + lu1: Vec<[PackedM31; 4]>, + lu2: Vec<[PackedM31; 8]>, } #[cfg(test)] @@ -177,11 +22,11 @@ mod tests { #[test] fn test_lookup_data() { const N_COLUMNS: usize = 5; - const LOG_SIZE: u32 = 8; + const LOG_SIZE: u32 = 12; let mut trace = IterableTrace::::zeroed(LOG_SIZE); let arr = (0..1 << LOG_SIZE).map(M31::from).collect_vec(); let mut lookup_data = unsafe { LookupData::uninitialized(LOG_SIZE) }; - let expected: (Vec<_>, Vec<_>) = arr + let expected: (Vec<_>, Vec<_>, Vec<_>) = arr .array_chunks::() .map(|x| { let x = PackedM31::from_array(*x); @@ -189,9 +34,22 @@ mod tests { let x2 = x + x1; let x3 = x + x1 + x2; let x4 = x + x1 + x2 + x3; - ([x, x4], [x1, x1.double(), x2, x2.double()]) + ( + [x, x4], + [x1, x1.double(), x2, x2.double()], + [ + x3, + x3.double(), + x4, + x4.double(), + x, + x.double(), + x1, + x1.double(), + ], + ) }) - .unzip(); + .multiunzip(); trace .par_iter_mut() @@ -207,6 +65,16 @@ mod tests { *row[4] = *row[0] + *row[1] + *row[2] + *row[3]; *lookup_data.lu0 = [*row[0], *row[4]]; *lookup_data.lu1 = [*row[1], row[1].double(), *row[2], row[2].double()]; + *lookup_data.lu2 = [ + *row[3], + row[3].double(), + *row[4], + row[4].double(), + *row[0], + row[0].double(), + *row[1], + row[1].double(), + ]; }) }); diff --git a/crates/air_utils_derive/Cargo.toml b/crates/air_utils_derive/Cargo.toml new file mode 100644 index 000000000..1493d2cf8 --- /dev/null +++ b/crates/air_utils_derive/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "stwo-air-utils-derive" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +syn = "2.0.90" +quote = "1.0.37" +itertools = "0.13.0" diff --git a/crates/air_utils_derive/src/lib.rs b/crates/air_utils_derive/src/lib.rs new file mode 100644 index 000000000..6fdcdc9f3 --- /dev/null +++ b/crates/air_utils_derive/src/lib.rs @@ -0,0 +1,208 @@ +use itertools::Itertools; +use quote::{format_ident, quote}; +use syn::{parse_macro_input, Data, DeriveInput, Fields, Type}; + +#[proc_macro_derive(StwoIterable)] +pub fn derive_stwo_iterable(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let struct_name = &input.ident; + let input = match input.data { + Data::Struct(data_struct) => data_struct, + _ => panic!("Expected struct"), + }; + + let fields = match input.fields { + Fields::Named(fields) => fields.named, + _ => panic!("Expected named fields"), + }; + + let mut field_names = vec![]; + let mut array_sizes = vec![]; + + for field in fields { + field_names.push(field.ident.unwrap()); + + // Extract array size. + if let Type::Path(type_path) = field.ty { + if let Some(last_segment) = type_path.path.segments.last() { + if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments { + if let Some(syn::GenericArgument::Type(Type::Array(array))) = args.args.first() + { + array_sizes.push(array.len.clone()); + } + } + } + } + } + + let field_names_head = field_names + .iter() + .map(|f| format_ident!("head_{}", f)) + .collect_vec(); + let field_names_tail = field_names + .iter() + .map(|f| format_ident!("tail_{}", f)) + .collect_vec(); + let first_field = field_names.first().unwrap(); + + let mut_chunk_name = format_ident!("{}MutChunk", struct_name); + let iter_mut_name = format_ident!("{}IterMut", struct_name); + let row_producer_name = format_ident!("{}RowProducer", struct_name); + let par_iter_mut_name = format_ident!("Par{}IterMut", struct_name); + + let expansions = quote! { + use stwo_prover::core::backend::simd::m31::PackedM31; + use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer}; + use stwo_prover::core::backend::simd::m31::N_LANES; + use rayon::prelude::*; + + impl #struct_name { + /// # Safety + /// The caller must ensure that the trace is populated before being used. + pub unsafe fn uninitialized(log_size: u32) -> Self { + let length = 1 << log_size; + let n_simd_elems = length / N_LANES; + #( + let mut #field_names = Vec::with_capacity(n_simd_elems); + #field_names.set_len(n_simd_elems); + )* + Self { #(#field_names),* } + } + + pub fn iter_mut(&mut self) -> LookupDataIterMut<'_> { + LookupDataIterMut::new(#(&mut self.#field_names),*) + } + + pub fn par_iter_mut(&mut self) -> ParLookupDataIterMut<'_> { + ParLookupDataIterMut { #(#field_names: &mut self.#field_names),* } + } + } + + pub struct #mut_chunk_name<'trace> { + #(pub #field_names: &'trace mut [PackedM31; #array_sizes],)* + } + + pub struct #iter_mut_name<'trace> { + #(#field_names: *mut [[PackedM31; #array_sizes]],)* + phantom: std::marker::PhantomData<&'trace ()>, + } + + impl<'trace> #iter_mut_name<'trace> { + pub fn new(#(#field_names: &'trace mut [[PackedM31; #array_sizes]],)*) -> Self { + Self { + #(#field_names: #field_names as *mut _,)* + phantom: std::marker::PhantomData, + } + } + } + + impl<'trace> Iterator for #iter_mut_name<'trace> { + type Item = #mut_chunk_name<'trace>; + fn next(&mut self) -> Option { + if self.#first_field.is_empty() { + return None; + } + let item = unsafe { + #( + let (#field_names_head, #field_names_tail) = (*self.#field_names).split_at_mut(1); + self.#field_names = #field_names_tail; + )* + #mut_chunk_name { + #(#field_names: &mut (*#field_names_head)[0],)* + } + }; + Some(item) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.#first_field.len(); + (len, Some(len)) + } + } + + impl ExactSizeIterator for #iter_mut_name<'_> {} + + impl<'trace> DoubleEndedIterator for #iter_mut_name<'trace> { + fn next_back(&mut self) -> Option { + if self.#first_field.is_empty() { + return None; + } + let item = unsafe { + #( + let (#field_names_head, #field_names_tail) = (*self.#field_names) + .split_at_mut(self.#field_names.len() - 1); + self.#field_names = #field_names_head; + )* + #mut_chunk_name { + #(#field_names: &mut (*#field_names_tail)[0],)* + } + }; + Some(item) + } + } + + pub struct #row_producer_name<'trace> { + #(#field_names: &'trace mut [[PackedM31; #array_sizes]],)* + } + + impl<'trace> Producer for #row_producer_name<'trace> { + type Item = #mut_chunk_name<'trace>; + type IntoIter = #iter_mut_name<'trace>; + + fn split_at(self, index: usize) -> (Self, Self) { + #( + let (#field_names, #field_names_tail) = self.#field_names.split_at_mut(index); + )* + ( + Self { #(#field_names,)* }, + Self { #(#field_names: #field_names_tail,)* } + ) + } + + fn into_iter(self) -> Self::IntoIter { + #iter_mut_name::new(#(self.#field_names),*) + } + } + + pub struct #par_iter_mut_name<'trace> { + #(#field_names: &'trace mut [[PackedM31; #array_sizes]],)* + } + + impl<'trace> #par_iter_mut_name<'trace> { + pub fn new(#(#field_names: &'trace mut [[PackedM31; #array_sizes]],)*) -> Self { + Self { #(#field_names,)* } + } + } + + impl<'trace> ParallelIterator for #par_iter_mut_name<'trace> { + type Item = #mut_chunk_name<'trace>; + + fn drive_unindexed(self, consumer: D) -> D::Result + where + D: UnindexedConsumer, + { + bridge(self, consumer) + } + + fn opt_len(&self) -> Option { + Some(self.len()) + } + } + + impl IndexedParallelIterator for #par_iter_mut_name<'_> { + fn len(&self) -> usize { + self.#first_field.len() + } + + fn drive>(self, consumer: D) -> D::Result { + bridge(self, consumer) + } + + fn with_producer>(self, callback: CB) -> CB::Output { + callback.callback(#row_producer_name { #(#field_names: self.#field_names),* }) + } + } + }; + + proc_macro::TokenStream::from(expansions) +}