Skip to content

Commit

Permalink
derive subcomponentinput
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Dec 17, 2024
1 parent 4de4d1d commit 5b7c782
Show file tree
Hide file tree
Showing 38 changed files with 4,187 additions and 4,124 deletions.
20 changes: 20 additions & 0 deletions stwo_cairo_prover/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions stwo_cairo_prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ members = [
"crates/utils",
"crates/vm_runner",
"crates/prover_types",
"crates/air_structs_derive",
]
resolver = "2"

Expand Down
13 changes: 13 additions & 0 deletions stwo_cairo_prover/crates/air_structs_derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[package]
name = "air_structs_derive"
version = "0.1.0"
edition = "2021"

[lib]
proc-macro = true

[dependencies]
syn = "2.0.90"
quote = "1.0.37"
proc-macro2 = "1.0.92"
itertools = "0.13.0"
151 changes: 151 additions & 0 deletions stwo_cairo_prover/crates/air_structs_derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
use itertools::Itertools;
use proc_macro2::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Fields, Path, Type};

#[proc_macro_derive(SubComponentInputs)]
pub fn derive_sub_component_inputs(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as DeriveInput);
assert_is_struct(&input);
let name = &input.ident;
let vec_array_fields = extract_vec_array_fields(&input);

// TODO(Ohad): deprecate with_capacity.
let with_capacity_method = generate_with_capacity_method(&vec_array_fields);
let uninitialized_method = generate_uninitialized_method(&vec_array_fields);
let bit_reverse_method = generate_bit_reverse_method(&vec_array_fields);

let expanded = quote! {
impl #name {
#with_capacity_method
#uninitialized_method
#bit_reverse_method
}
};

proc_macro::TokenStream::from(expanded)
}

#[derive(Clone)]
struct VecArrayField {
name: syn::Ident,
array_length: usize,
}

fn assert_is_struct(input: &DeriveInput) {
if !matches!(input.data, Data::Struct(_)) {
panic!("SubComponentInputs can only be derived for structs");
}
}

fn extract_vec_array_fields(input: &DeriveInput) -> Vec<VecArrayField> {
let mut vec_array_fields = Vec::new();

if let Data::Struct(data_struct) = &input.data {
if let Fields::Named(fields) = &data_struct.fields {
for field in &fields.named {
// Field is an array of Vecs.
if let Type::Array(type_array) = &field.ty {
if let Type::Path(element_type) = &*type_array.elem {
// Element is a Vec.
if is_vec_type(&element_type.path) {
// Get the array length
if let syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Int(length_lit),
..
}) = type_array.len.clone()
{
vec_array_fields.push(VecArrayField {
name: field.ident.clone().unwrap(),
array_length: length_lit.base10_parse().unwrap(),
});
}
}
}
}
}
}
}

vec_array_fields
}

fn is_vec_type(path: &Path) -> bool {
path.segments.len() == 1 && path.segments.first().unwrap().ident == "Vec"
}

fn generate_with_capacity_method(vec_array_fields: &[VecArrayField]) -> TokenStream {
let field_initializations = vec_array_fields
.iter()
.map(|field| {
let field_name = &field.name;
let array_length = field.array_length;

quote! {
#field_name: [(); #array_length].map(|_| Vec::with_capacity(capacity))
}
})
.collect_vec();

quote! {
fn with_capacity(capacity: usize) -> Self {
Self {
#(#field_initializations),*
}
}
}
}

fn generate_uninitialized_method(vec_array_fields: &[VecArrayField]) -> TokenStream {
let (field_initializations, field_len_updates): (Vec<_>, Vec<_>) = vec_array_fields
.iter()
.map(|field| {
let field_name = &field.name;
let array_length = field.array_length;

(
quote! {
#field_name: [(); #array_length].map(|_| Vec::with_capacity(capacity))
},
quote! {
result.#field_name
.iter_mut()
.for_each(|v| v.set_len(capacity));
},
)
})
.unzip();

quote! {
unsafe fn uninitialized(capacity: usize) -> Self {
let mut result = Self {
#(#field_initializations),*
};

#(#field_len_updates)*

result
}
}
}

fn generate_bit_reverse_method(vec_array_fields: &[VecArrayField]) -> TokenStream {
let field_updates = vec_array_fields
.iter()
.map(|field| {
let field_name = &field.name;

quote! {
self.#field_name
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
}
})
.collect_vec();

quote! {
fn bit_reverse_coset_to_circle_domain_order(&mut self) {
#(#field_updates)*
}
}
}
1 change: 1 addition & 0 deletions stwo_cairo_prover/crates/prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ edition = "2021"
parallel = ["rayon"]

[dependencies]
air_structs_derive = { path = "../air_structs_derive" }
bytemuck.workspace = true
cairo-lang-casm.workspace = true
cairo-vm.workspace = true
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![allow(unused_parens)]
#![allow(unused_imports)]
use air_structs_derive::SubComponentInputs;
use itertools::{chain, zip_eq, Itertools};
use num_traits::{One, Zero};
use prover_types::cpu::*;
Expand Down Expand Up @@ -115,33 +116,12 @@ impl ClaimGenerator {
}
}

#[derive(SubComponentInputs)]
pub struct SubComponentInputs {
pub memory_address_to_id_inputs: [Vec<memory_address_to_id::InputType>; 1],
pub memory_id_to_big_inputs: [Vec<memory_id_to_big::InputType>; 1],
pub verify_instruction_inputs: [Vec<verify_instruction::InputType>; 1],
}
impl SubComponentInputs {
#[allow(unused_variables)]
fn with_capacity(capacity: usize) -> Self {
Self {
memory_address_to_id_inputs: [Vec::with_capacity(capacity)],
memory_id_to_big_inputs: [Vec::with_capacity(capacity)],
verify_instruction_inputs: [Vec::with_capacity(capacity)],
}
}

fn bit_reverse_coset_to_circle_domain_order(&mut self) {
self.memory_address_to_id_inputs
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
self.memory_id_to_big_inputs
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
self.verify_instruction_inputs
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
}
}

#[allow(clippy::useless_conversion)]
#[allow(unused_variables)]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![allow(unused_parens)]
#![allow(unused_imports)]
use air_structs_derive::SubComponentInputs;
use itertools::{chain, zip_eq, Itertools};
use num_traits::{One, Zero};
use prover_types::cpu::*;
Expand Down Expand Up @@ -115,33 +116,12 @@ impl ClaimGenerator {
}
}

#[derive(SubComponentInputs)]
pub struct SubComponentInputs {
pub memory_address_to_id_inputs: [Vec<memory_address_to_id::InputType>; 1],
pub memory_id_to_big_inputs: [Vec<memory_id_to_big::InputType>; 1],
pub verify_instruction_inputs: [Vec<verify_instruction::InputType>; 1],
}
impl SubComponentInputs {
#[allow(unused_variables)]
fn with_capacity(capacity: usize) -> Self {
Self {
memory_address_to_id_inputs: [Vec::with_capacity(capacity)],
memory_id_to_big_inputs: [Vec::with_capacity(capacity)],
verify_instruction_inputs: [Vec::with_capacity(capacity)],
}
}

fn bit_reverse_coset_to_circle_domain_order(&mut self) {
self.memory_address_to_id_inputs
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
self.memory_id_to_big_inputs
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
self.verify_instruction_inputs
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
}
}

#[allow(clippy::useless_conversion)]
#[allow(unused_variables)]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![allow(unused_parens)]
#![allow(unused_imports)]
use air_structs_derive::SubComponentInputs;
use itertools::{chain, zip_eq, Itertools};
use num_traits::{One, Zero};
use prover_types::cpu::*;
Expand Down Expand Up @@ -115,33 +116,12 @@ impl ClaimGenerator {
}
}

#[derive(SubComponentInputs)]
pub struct SubComponentInputs {
pub memory_address_to_id_inputs: [Vec<memory_address_to_id::InputType>; 1],
pub memory_id_to_big_inputs: [Vec<memory_id_to_big::InputType>; 1],
pub verify_instruction_inputs: [Vec<verify_instruction::InputType>; 1],
}
impl SubComponentInputs {
#[allow(unused_variables)]
fn with_capacity(capacity: usize) -> Self {
Self {
memory_address_to_id_inputs: [Vec::with_capacity(capacity)],
memory_id_to_big_inputs: [Vec::with_capacity(capacity)],
verify_instruction_inputs: [Vec::with_capacity(capacity)],
}
}

fn bit_reverse_coset_to_circle_domain_order(&mut self) {
self.memory_address_to_id_inputs
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
self.memory_id_to_big_inputs
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
self.verify_instruction_inputs
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
}
}

#[allow(clippy::useless_conversion)]
#[allow(unused_variables)]
Expand Down
Loading

0 comments on commit 5b7c782

Please sign in to comment.