Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

derive subcomponentinput #260

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions stwo_cairo_prover/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions stwo_cairo_prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ members = [
"crates/utils",
"crates/vm_runner",
"crates/prover_types",
"crates/air_structs_derive",
]
resolver = "2"

Expand Down
13 changes: 13 additions & 0 deletions stwo_cairo_prover/crates/air_structs_derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[package]
name = "air_structs_derive"
version = "0.1.0"
edition = "2021"

[lib]
proc-macro = true

[dependencies]
syn = "2.0.90"
quote = "1.0.37"
proc-macro2 = "1.0.92"
itertools = "0.13.0"
146 changes: 146 additions & 0 deletions stwo_cairo_prover/crates/air_structs_derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
use itertools::Itertools;
use proc_macro2::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Fields, Path, Type};

#[proc_macro_derive(SubComponentInputs)]
pub fn derive_sub_component_inputs(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as DeriveInput);
assert_is_struct(&input);
let name = &input.ident;
let vec_array_fields = extract_vec_array_fields(&input);

// TODO(Ohad): deprecate with_capacity.
let with_capacity_method = generate_with_capacity_method(&vec_array_fields);
let uninitialized_method = generate_uninitialized_method(&vec_array_fields);
let bit_reverse_method = generate_bit_reverse_method(&vec_array_fields);

let expanded = quote! {
impl #name {
#with_capacity_method
#uninitialized_method
#bit_reverse_method
}
};

proc_macro::TokenStream::from(expanded)
}

#[derive(Clone)]
struct VecArrayField {
name: syn::Ident,
array_length: usize,
}

fn assert_is_struct(input: &DeriveInput) {
if !matches!(input.data, Data::Struct(_)) {
panic!("Derive(SubComponentInputs) can only be applied to structs.");
}
}

fn extract_vec_array_fields(input: &DeriveInput) -> Vec<VecArrayField> {
let mut vec_array_fields = Vec::new();

if let Data::Struct(data_struct) = &input.data {
if let Fields::Named(fields) = &data_struct.fields {
for field in &fields.named {
// Field is an array of Vecs.
if let Type::Array(type_array) = &field.ty {
if let Type::Path(element_type) = &*type_array.elem {
// Element is a Vec.
if is_vec_type(&element_type.path) {
// Get the array length
if let syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Int(length_lit),
..
}) = type_array.len.clone()
{
vec_array_fields.push(VecArrayField {
name: field.ident.clone().unwrap(),
array_length: length_lit.base10_parse().unwrap(),
});
}
}
}
}
}
}
}

vec_array_fields
}

fn is_vec_type(path: &Path) -> bool {
path.segments.len() == 1 && path.segments.first().unwrap().ident == "Vec"
}

fn generate_with_capacity_method(vec_array_fields: &[VecArrayField]) -> TokenStream {
let field_initializations = vec_array_fields
.iter()
.map(|field| {
let field_name = &field.name;
let array_length = field.array_length;

quote! {
#field_name: [(); #array_length].map(|_| Vec::with_capacity(capacity))
}
})
.collect_vec();

quote! {
fn with_capacity(capacity: usize) -> Self {
Self {
#(#field_initializations),*
}
}
}
}

fn generate_uninitialized_method(vec_array_fields: &[VecArrayField]) -> TokenStream {
let field_initializations = vec_array_fields
.iter()
.map(|field| {
let field_name = &field.name;
let array_length = field.array_length;

quote! {
#field_name: [(); #array_length].map(|_|{
let mut v = Vec::with_capacity(capacity);
v.set_len(capacity);
v
})
}
})
.collect_vec();

quote! {
unsafe fn uninitialized(capacity: usize) -> Self {
let mut result = Self {
#(#field_initializations),*
};

result
}
}
}

fn generate_bit_reverse_method(vec_array_fields: &[VecArrayField]) -> TokenStream {
let field_updates = vec_array_fields
.iter()
.map(|field| {
let field_name = &field.name;

quote! {
self.#field_name
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
}
})
.collect_vec();

quote! {
fn bit_reverse_coset_to_circle_domain_order(&mut self) {
#(#field_updates)*
}
}
}
1 change: 1 addition & 0 deletions stwo_cairo_prover/crates/prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ edition = "2021"
parallel = ["rayon"]

[dependencies]
air_structs_derive = { path = "../air_structs_derive" }
bytemuck.workspace = true
cairo-lang-casm.workspace = true
cairo-vm.workspace = true
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![allow(unused_parens)]
#![allow(unused_imports)]
use air_structs_derive::SubComponentInputs;
use itertools::{chain, zip_eq, Itertools};
use num_traits::{One, Zero};
use prover_types::cpu::*;
Expand Down Expand Up @@ -115,33 +116,12 @@ impl ClaimGenerator {
}
}

#[derive(SubComponentInputs)]
pub struct SubComponentInputs {
pub memory_address_to_id_inputs: [Vec<memory_address_to_id::InputType>; 1],
pub memory_id_to_big_inputs: [Vec<memory_id_to_big::InputType>; 1],
pub verify_instruction_inputs: [Vec<verify_instruction::InputType>; 1],
}
impl SubComponentInputs {
#[allow(unused_variables)]
fn with_capacity(capacity: usize) -> Self {
Self {
memory_address_to_id_inputs: [Vec::with_capacity(capacity)],
memory_id_to_big_inputs: [Vec::with_capacity(capacity)],
verify_instruction_inputs: [Vec::with_capacity(capacity)],
}
}

fn bit_reverse_coset_to_circle_domain_order(&mut self) {
self.memory_address_to_id_inputs
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
self.memory_id_to_big_inputs
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
self.verify_instruction_inputs
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
}
}

#[allow(clippy::useless_conversion)]
#[allow(unused_variables)]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![allow(unused_parens)]
#![allow(unused_imports)]
use air_structs_derive::SubComponentInputs;
use itertools::{chain, zip_eq, Itertools};
use num_traits::{One, Zero};
use prover_types::cpu::*;
Expand Down Expand Up @@ -115,33 +116,12 @@ impl ClaimGenerator {
}
}

#[derive(SubComponentInputs)]
pub struct SubComponentInputs {
pub memory_address_to_id_inputs: [Vec<memory_address_to_id::InputType>; 1],
pub memory_id_to_big_inputs: [Vec<memory_id_to_big::InputType>; 1],
pub verify_instruction_inputs: [Vec<verify_instruction::InputType>; 1],
}
impl SubComponentInputs {
#[allow(unused_variables)]
fn with_capacity(capacity: usize) -> Self {
Self {
memory_address_to_id_inputs: [Vec::with_capacity(capacity)],
memory_id_to_big_inputs: [Vec::with_capacity(capacity)],
verify_instruction_inputs: [Vec::with_capacity(capacity)],
}
}

fn bit_reverse_coset_to_circle_domain_order(&mut self) {
self.memory_address_to_id_inputs
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
self.memory_id_to_big_inputs
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
self.verify_instruction_inputs
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
}
}

#[allow(clippy::useless_conversion)]
#[allow(unused_variables)]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![allow(unused_parens)]
#![allow(unused_imports)]
use air_structs_derive::SubComponentInputs;
use itertools::{chain, zip_eq, Itertools};
use num_traits::{One, Zero};
use prover_types::cpu::*;
Expand Down Expand Up @@ -115,33 +116,12 @@ impl ClaimGenerator {
}
}

#[derive(SubComponentInputs)]
pub struct SubComponentInputs {
pub memory_address_to_id_inputs: [Vec<memory_address_to_id::InputType>; 1],
pub memory_id_to_big_inputs: [Vec<memory_id_to_big::InputType>; 1],
pub verify_instruction_inputs: [Vec<verify_instruction::InputType>; 1],
}
impl SubComponentInputs {
#[allow(unused_variables)]
fn with_capacity(capacity: usize) -> Self {
Self {
memory_address_to_id_inputs: [Vec::with_capacity(capacity)],
memory_id_to_big_inputs: [Vec::with_capacity(capacity)],
verify_instruction_inputs: [Vec::with_capacity(capacity)],
}
}

fn bit_reverse_coset_to_circle_domain_order(&mut self) {
self.memory_address_to_id_inputs
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
self.memory_id_to_big_inputs
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
self.verify_instruction_inputs
.iter_mut()
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec));
}
}

#[allow(clippy::useless_conversion)]
#[allow(unused_variables)]
Expand Down
Loading
Loading