Skip to content

Commit

Permalink
Merge pull request #77 from ZettaScaleLabs/expand-enum-support
Browse files Browse the repository at this point in the history
Support multiple and named fields in `#[repr(C, u*)]` enums
  • Loading branch information
p-avital authored Jun 23, 2024
2 parents 31dc42a + 70a20d9 commit 3569122
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 65 deletions.
152 changes: 112 additions & 40 deletions stabby-macros/src/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,33 @@ impl syn::parse::Parse for Repr {
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FullRepr {
repr: Option<Repr>,
is_c: bool,
}
impl syn::parse::Parse for FullRepr {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut this = FullRepr {
repr: None,
is_c: false,
};
while !input.is_empty() {
match input.parse()? {
Repr::C => this.is_c = true,
repr => match this.repr {
None => this.repr = Some(repr),
Some(r) if repr == r => {}
_ => return Err(input.error("Determinants may only have one representation. You can use `#[repr(C, u8)]` to use a u8 as determinant while ensuring all variants have their data in C layout.")),
},
}
if !input.is_empty() {
let _: syn::token::Comma = input.parse()?;
}
}
Ok(this)
}
}

pub fn stabby(
attrs: Vec<Attribute>,
Expand All @@ -63,7 +90,7 @@ pub fn stabby(
) -> TokenStream {
let st = crate::tl_mod();
let unbound_generics = &generics.params;
let mut repr = None;
let mut repr: Option<FullRepr> = None;
let repr_ident = quote::format_ident!("repr");
let mut new_attrs = Vec::with_capacity(attrs.len());
for a in attrs {
Expand All @@ -77,53 +104,91 @@ pub fn stabby(
new_attrs.push(a)
}
}
if matches!(
repr,
Some(FullRepr {
repr: Some(Repr::Stabby),
is_c: true
})
) {
panic!("#[repr(C)] and #[repr(stabby)] connot be combined")
}
if data.variants.is_empty() {
todo!("empty enums are not supported by stabby YET")
}
let mut layout = quote!(());
let DataEnum { variants, .. } = &data;
let mut has_non_empty_fields = false;
let unit = syn::parse2(quote!(())).unwrap();
let mut report = Vec::new();
let mut report = crate::Report::r#enum(ident.to_string(), 0);
for variant in variants {
match &variant.fields {
syn::Fields::Named(f) if matches!(repr, Some(FullRepr { is_c: true, .. })) => {
has_non_empty_fields = true;
let mut variant_report = crate::Report::r#struct(variant.ident.to_string(), 0);
let mut variant_layout = quote!(());
for f in &f.named {
let ty = &f.ty;
variant_layout = quote!(#st::FieldPair<#variant_layout, #ty>);
variant_report.add_field(f.ident.as_ref().unwrap().to_string(), ty);
}
variant_layout = quote!(#st::Struct<#variant_layout>);
layout = quote!(#st::Union<#layout, core::mem::ManuallyDrop<#variant_layout>>);
report.add_field(variant.ident.to_string(), variant_report);
}
syn::Fields::Named(_) => {
panic!("stabby does not support named fields in enum variants")
panic!("stabby only supports named fields in #[repr(C, u*)] enums");
}
syn::Fields::Unnamed(f) => {
assert_eq!(
f.unnamed.len(),
1,
"stabby only supports one field per enum variant"
);
has_non_empty_fields = true;
let f = f.unnamed.first().unwrap();
let ty = &f.ty;
layout = quote!(#st::Union<#layout, core::mem::ManuallyDrop<#ty>>);
report.push((variant.ident.to_string(), ty));
if f.unnamed.len() != 1 && matches!(repr, Some(FullRepr { is_c: true, .. })) {
has_non_empty_fields = true;
let mut variant_report = crate::Report::r#struct(variant.ident.to_string(), 0);
let mut variant_layout = quote!(());
for (n, f) in f.unnamed.iter().enumerate() {
let ty = &f.ty;
variant_layout = quote!(#st::FieldPair<#variant_layout, #ty>);
variant_report.add_field(n.to_string(), ty);
}
variant_layout = quote!(#st::Struct<#variant_layout>);
layout = quote!(#st::Union<#layout, core::mem::ManuallyDrop<#variant_layout>>);
report.add_field(variant.ident.to_string(), variant_report);
} else {
assert_eq!(
f.unnamed.len(),
1,
"stabby only supports multiple fields per enum variant in #[repr(C, u*)] enums"
);
has_non_empty_fields = true;
let f = f.unnamed.first().unwrap();
let ty = &f.ty;
layout = quote!(#st::Union<#layout, core::mem::ManuallyDrop<#ty>>);
report.add_field(variant.ident.to_string(), ty);
}
}
syn::Fields::Unit => {
report.push((variant.ident.to_string(), &unit));
report.add_field(variant.ident.to_string(), &unit);
}
}
}
let report = crate::report(&report);
let repr = match repr {
let reprstr = match repr
.as_ref()
.and_then(|r| if r.is_c { Some(Repr::C) } else { r.repr })
{
None | Some(Repr::Stabby) => {
if !has_non_empty_fields {
panic!("Your enum doesn't have any field with values: use #[repr(C)] or #[repr(u*)] instead")
panic!("Your enum doesn't have any field with values: use #[repr(C)] or (preferably) #[repr(u*)] instead")
}
return repr_stabby(
&new_attrs,
&vis,
&ident,
&generics,
data,
data.clone(),
report,
repr.is_none(),
);
}
Some(Repr::C) => "u8",
Some(Repr::C) => "u8", // TODO: Remove support for `#[repr(C)]` alone on the next API-breaking release
Some(Repr::U8) => "u8",
Some(Repr::U16) => "u16",
Some(Repr::U32) => "u32",
Expand All @@ -135,13 +200,24 @@ pub fn stabby(
Some(Repr::I64) => "i64",
Some(Repr::Isize) => "isize",
};
let reprid = quote::format_ident!("{}", repr);
let reprid = quote::format_ident!("{}", reprstr);
let reprattr = if repr.map_or(false, |r| r.is_c) {
quote!(#[repr(C, #reprid)])
} else {
quote!(#[repr(#reprid)])
};
layout = quote!(#st::Tuple<#reprid, #layout>);
let sident = format!("{ident}");
let (report, report_bounds) = report;
report.tyty = quote!(#st::report::TyTy::Enum(#st::str::Str::new(#reprstr)));
let report_bounds = report.bounds();
let size_bug = format!(
"{ident}'s size was mis-evaluated by stabby, this is a definitely a bug and may cause UB, please file an issue"
);
let align_bug = format!(
"{ident}'s align was mis-evaluated by stabby, this is a definitely a bug and may cause UB, please file an issue"
);
quote! {
#(#new_attrs)*
#[repr(#reprid)]
#reprattr
#vis enum #ident #generics {
#variants
}
Expand All @@ -154,14 +230,16 @@ pub fn stabby(
type Align = <#layout as #st::IStable>::Align;
type HasExactlyOneNiche = #st::B0;
type ContainsIndirections = <#layout as #st::IStable>::ContainsIndirections;
const REPORT: &'static #st::report::TypeReport = & #st::report::TypeReport {
name: #st::str::Str::new(#sident),
module: #st::str::Str::new(core::module_path!()),
fields: unsafe {#st::StableLike::new(#report)},
version: 0,
tyty: #st::report::TyTy::Enum(#st::str::Str::new(#repr)),
const REPORT: &'static #st::report::TypeReport = & #report;
const ID: u64 ={
if core::mem::size_of::<Self>() != <<Self as #st::IStable>::Size as #st::Unsigned>::USIZE {
panic!(#size_bug)
}
if core::mem::align_of::<Self>() != <<Self as #st::IStable>::Align as #st::Unsigned>::USIZE {
panic!(#align_bug)
}
#st::report::gen_id(Self::REPORT)
};
const ID: u64 = #st::report::gen_id(Self::REPORT);
}
}
}
Expand Down Expand Up @@ -263,7 +341,7 @@ pub fn repr_stabby(
ident: &Ident,
generics: &Generics,
data: DataEnum,
report: (TokenStream, TokenStream),
mut report: crate::Report,
check: bool,
) -> TokenStream {
let st = crate::tl_mod();
Expand Down Expand Up @@ -375,8 +453,8 @@ pub fn repr_stabby(
let bounds2 = generics.where_clause.as_ref().map(|c| &c.predicates);
let bounds = quote!(#bounds #bounds2);

let sident = format!("{ident}");
let (report, report_bounds) = report;
report.tyty = quote!(#st::report::TyTy::Enum(#st::str::Str::new("stabby")));
let report_bounds = report.bounds();
let enum_as_struct = quote! {
#(#attrs)*
#vis struct #ident #generics (#result) where #report_bounds #bounds;
Expand Down Expand Up @@ -438,13 +516,7 @@ pub fn repr_stabby(
type Align = <#layout as #st::IStable>::Align;
type HasExactlyOneNiche = #st::B0;
type ContainsIndirections = <#layout as #st::IStable>::ContainsIndirections;
const REPORT: &'static #st::report::TypeReport = & #st::report::TypeReport {
name: #st::str::Str::new(#sident),
module: #st::str::Str::new(core::module_path!()),
fields: unsafe {#st::StableLike::new(#report)},
version: 0,
tyty: #st::report::TyTy::Enum(#st::str::Str::new("stabby")),
};
const REPORT: &'static #st::report::TypeReport = & #report;
const ID: u64 = #st::report::gen_id(Self::REPORT);
}
#[automatically_derived]
Expand Down
108 changes: 107 additions & 1 deletion stabby-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::collections::HashSet;

use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::quote;
use quote::{quote, ToTokens};
use syn::{parse::Parser, DeriveInput, TypeParamBound};

#[allow(dead_code)]
Expand Down Expand Up @@ -253,6 +253,112 @@ pub fn gen_closures_impl(_: TokenStream) -> TokenStream {
gen_closures::gen_closures().into()
}

enum Type<'a> {
Syn(&'a syn::Type),
Report(Report<'a>),
}
impl<'a> From<&'a syn::Type> for Type<'a> {
fn from(value: &'a syn::Type) -> Self {
Self::Syn(value)
}
}
impl<'a> From<Report<'a>> for Type<'a> {
fn from(value: Report<'a>) -> Self {
Self::Report(value)
}
}
pub(crate) struct Report<'a> {
name: String,
fields: Vec<(String, Type<'a>)>,
version: u32,
pub tyty: proc_macro2::TokenStream,
}
impl<'a> Report<'a> {
pub fn r#struct(name: impl Into<String>, version: u32) -> Self {
let st = crate::tl_mod();
Self {
name: name.into(),
fields: Vec::new(),
version,
tyty: quote!(#st::report::TyTy::Struct),
}
}
pub fn r#enum(name: impl Into<String>, version: u32) -> Self {
let st = crate::tl_mod();
Self {
name: name.into(),
fields: Vec::new(),
version,
tyty: quote!(#st::report::TyTy::Struct),
}
}
pub fn add_field(&mut self, name: String, ty: impl Into<Type<'a>>) {
self.fields.push((name, ty.into()));
}
fn __bounds(
&self,
bounded_types: &mut HashSet<&'a syn::Type>,
mut report_bounds: proc_macro2::TokenStream,
st: &proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
for (_, ty) in self.fields.iter() {
match ty {
Type::Syn(ty) => {
if bounded_types.insert(*ty) {
report_bounds = quote!(#ty: #st::IStable, #report_bounds);
}
}
Type::Report(report) => {
report_bounds = report.__bounds(bounded_types, report_bounds, st)
}
}
}
report_bounds
}
pub fn bounds(&self) -> proc_macro2::TokenStream {
let st = crate::tl_mod();
let mut bounded_types = HashSet::new();
self.__bounds(&mut bounded_types, quote!(), &st)
}
}
impl ToTokens for Report<'_> {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
let st = crate::tl_mod();
let mut fields = quote!(None);
for (name, ty) in &self.fields {
fields = match ty {
Type::Syn(ty) => quote! {
Some(& #st::report::FieldReport {
name: #st::str::Str::new(#name),
ty: <#ty as #st::IStable>::REPORT,
next_field: #st::StableLike::new(#fields)
})
},
Type::Report(re) => quote! {
Some(& #st::report::FieldReport {
name: #st::str::Str::new(#name),
ty: &#re,
next_field: #st::StableLike::new(#fields)
})
},
}
}
let Self {
name,
version,
tyty,
..
} = self;
tokens.extend(quote!(#st::report::TypeReport {
name: #st::str::Str::new(#name),
module: #st::str::Str::new(core::module_path!()),
fields: unsafe{#st::StableLike::new(#fields)},
version: #version,
tyty: #tyty,
}));
}
}

pub(crate) fn report(
fields: &[(String, &syn::Type)],
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
Expand Down
Loading

0 comments on commit 3569122

Please sign in to comment.