-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4de4d1d
commit 5b7c782
Showing
38 changed files
with
4,187 additions
and
4,124 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
[package] | ||
name = "air_structs_derive" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
[lib] | ||
proc-macro = true | ||
|
||
[dependencies] | ||
syn = "2.0.90" | ||
quote = "1.0.37" | ||
proc-macro2 = "1.0.92" | ||
itertools = "0.13.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
use itertools::Itertools; | ||
use proc_macro2::TokenStream; | ||
use quote::quote; | ||
use syn::{parse_macro_input, Data, DeriveInput, Fields, Path, Type}; | ||
|
||
#[proc_macro_derive(SubComponentInputs)] | ||
pub fn derive_sub_component_inputs(input: proc_macro::TokenStream) -> proc_macro::TokenStream { | ||
let input = parse_macro_input!(input as DeriveInput); | ||
assert_is_struct(&input); | ||
let name = &input.ident; | ||
let vec_array_fields = extract_vec_array_fields(&input); | ||
|
||
// TODO(Ohad): deprecate with_capacity. | ||
let with_capacity_method = generate_with_capacity_method(&vec_array_fields); | ||
let uninitialized_method = generate_uninitialized_method(&vec_array_fields); | ||
let bit_reverse_method = generate_bit_reverse_method(&vec_array_fields); | ||
|
||
let expanded = quote! { | ||
impl #name { | ||
#with_capacity_method | ||
#uninitialized_method | ||
#bit_reverse_method | ||
} | ||
}; | ||
|
||
proc_macro::TokenStream::from(expanded) | ||
} | ||
|
||
#[derive(Clone)] | ||
struct VecArrayField { | ||
name: syn::Ident, | ||
array_length: usize, | ||
} | ||
|
||
fn assert_is_struct(input: &DeriveInput) { | ||
if !matches!(input.data, Data::Struct(_)) { | ||
panic!("SubComponentInputs can only be derived for structs"); | ||
} | ||
} | ||
|
||
fn extract_vec_array_fields(input: &DeriveInput) -> Vec<VecArrayField> { | ||
let mut vec_array_fields = Vec::new(); | ||
|
||
if let Data::Struct(data_struct) = &input.data { | ||
if let Fields::Named(fields) = &data_struct.fields { | ||
for field in &fields.named { | ||
// Field is an array of Vecs. | ||
if let Type::Array(type_array) = &field.ty { | ||
if let Type::Path(element_type) = &*type_array.elem { | ||
// Element is a Vec. | ||
if is_vec_type(&element_type.path) { | ||
// Get the array length | ||
if let syn::Expr::Lit(syn::ExprLit { | ||
lit: syn::Lit::Int(length_lit), | ||
.. | ||
}) = type_array.len.clone() | ||
{ | ||
vec_array_fields.push(VecArrayField { | ||
name: field.ident.clone().unwrap(), | ||
array_length: length_lit.base10_parse().unwrap(), | ||
}); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
vec_array_fields | ||
} | ||
|
||
fn is_vec_type(path: &Path) -> bool { | ||
path.segments.len() == 1 && path.segments.first().unwrap().ident == "Vec" | ||
} | ||
|
||
fn generate_with_capacity_method(vec_array_fields: &[VecArrayField]) -> TokenStream { | ||
let field_initializations = vec_array_fields | ||
.iter() | ||
.map(|field| { | ||
let field_name = &field.name; | ||
let array_length = field.array_length; | ||
|
||
quote! { | ||
#field_name: [(); #array_length].map(|_| Vec::with_capacity(capacity)) | ||
} | ||
}) | ||
.collect_vec(); | ||
|
||
quote! { | ||
fn with_capacity(capacity: usize) -> Self { | ||
Self { | ||
#(#field_initializations),* | ||
} | ||
} | ||
} | ||
} | ||
|
||
fn generate_uninitialized_method(vec_array_fields: &[VecArrayField]) -> TokenStream { | ||
let (field_initializations, field_len_updates): (Vec<_>, Vec<_>) = vec_array_fields | ||
.iter() | ||
.map(|field| { | ||
let field_name = &field.name; | ||
let array_length = field.array_length; | ||
|
||
( | ||
quote! { | ||
#field_name: [(); #array_length].map(|_| Vec::with_capacity(capacity)) | ||
}, | ||
quote! { | ||
result.#field_name | ||
.iter_mut() | ||
.for_each(|v| v.set_len(capacity)); | ||
}, | ||
) | ||
}) | ||
.unzip(); | ||
|
||
quote! { | ||
unsafe fn uninitialized(capacity: usize) -> Self { | ||
let mut result = Self { | ||
#(#field_initializations),* | ||
}; | ||
|
||
#(#field_len_updates)* | ||
|
||
result | ||
} | ||
} | ||
} | ||
|
||
fn generate_bit_reverse_method(vec_array_fields: &[VecArrayField]) -> TokenStream { | ||
let field_updates = vec_array_fields | ||
.iter() | ||
.map(|field| { | ||
let field_name = &field.name; | ||
|
||
quote! { | ||
self.#field_name | ||
.iter_mut() | ||
.for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec)); | ||
} | ||
}) | ||
.collect_vec(); | ||
|
||
quote! { | ||
fn bit_reverse_coset_to_circle_domain_order(&mut self) { | ||
#(#field_updates)* | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.