Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom ctx override for derive macro #106

Merged
merged 10 commits into from
Oct 22, 2024
81 changes: 81 additions & 0 deletions scroll_derive/examples/derive_custom_ctx.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use scroll_derive::{Pread, Pwrite, SizeWith};

/// An example of using a method as the value for a ctx in a derive.
struct EndianDependent(Endian);
impl EndianDependent {
fn len(&self) -> usize {
match self.0 {
scroll::Endian::Little => 5,
scroll::Endian::Big => 6,
}
}
}

#[derive(Debug, PartialEq)]
struct VariableLengthData {
buf: Vec<u8>,
}

impl<'a> TryFromCtx<'a, usize> for VariableLengthData {
type Error = scroll::Error;

fn try_from_ctx(from: &'a [u8], ctx: usize) -> Result<(Self, usize), Self::Error> {
let offset = &mut 0;
let buf = from.gread_with::<&[u8]>(offset, ctx)?.to_owned();
Ok((Self { buf }, *offset))
}
}
impl<'a> TryIntoCtx<usize> for &'a VariableLengthData {
type Error = scroll::Error;
fn try_into_ctx(self, dst: &mut [u8], ctx: usize) -> Result<usize, Self::Error> {
let offset = &mut 0;
for i in 0..(ctx.min(self.buf.len())) {
dst.gwrite(self.buf[i], offset)?;
}
Ok(*offset)
}
}
impl SizeWith<usize> for VariableLengthData {
fn size_with(ctx: &usize) -> usize {
*ctx
}
}

#[derive(Debug, PartialEq, Pread, Pwrite, SizeWith)]
#[repr(C)]
struct Data {
id: u32,
timestamp: f64,
// You can fix the ctx regardless of what is passed in.
#[scroll(ctx = BE)]
arr: [u16; 2],
// You can use arbitrary expressions for the ctx.
// You have access to the `ctx` parameter of the `{pread/gread}_with` inside the expression.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is so wild hah, i have to think about usecases/what it means, or if there's some kind of hygiene issue that a different design might not have, but it's very interesting! is the fact that ctx is available within the scope of the pread something you have made available or is it a property of derive with field overrides in macros?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because the user's expression is effectively copy-pasted without hygiene and so the parameters of the trait method are accessible. It would take more work to remove this ability, since we'd have to parse for an instance of a ctx ident and actively block it.

Copy link
Contributor Author

@Easyoakland Easyoakland Oct 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also means that if only Pwrite derive is used (for example), they won't get an error using self or dst idents.

// TODO(implement) you have access to previous fields.
// TODO(check) will this break structs with fields named `ctx`?.
#[scroll(ctx = EndianDependent(ctx.clone()).len())]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one last thing I'll say: i think the example is actually better if EndianDependent is the ctx for VariableLength data, and inside of it's impl, it calls len() to get the read/write size, instead of the ctx == read/write size and relying on len() being called/passed in here.

That is clearer and more idiomatic scroll, imho, if there is such a thing :D

However, i understand that you are showing off arbitrary expressions are callable in the ctx here, which is very cool and probably worth retaining as you've written it.

Side note: I'm surprised you have to ctx.clone() ? Endian is Copy and I'd expect to not need that clone there? I also recall that effectively Ctx's generally need to be copy to be usable within pread/gread but maybe I forgot and that bound was relaxed?

Side Side note: if the ctx does need to be cloned, and you don't clone, how terrible is the error message presented to the user (at derive time)? I imagine it could be pretty bad?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

error[E0308]: mismatched types
  --> examples\derive_custom_ctx.rs:56:36
   |
56 |     #[scroll(ctx = EndianDependent(ctx).len())]
   |                    --------------- ^^^ expected `Endian`, found `&Endian`
   |                    |
   |                    arguments to this struct are incorrect
   |
note: tuple struct defined here
  --> examples\derive_custom_ctx.rs:4:8
   |
4  | struct EndianDependent(Endian);
   |        ^^^^^^^^^^^^^^^
help: consider dereferencing the borrow
   |
56 |     #[scroll(ctx = EndianDependent(*ctx).len())]

But if you follow the directions:

error[E0614]: type `Endian` cannot be dereferenced
  --> examples\derive_custom_ctx.rs:56:36
   |
56 |     #[scroll(ctx = EndianDependent(*ctx).len())]
   |                                    ^^^^

The clone works because Endian.clone() and (&Endian).clone both become Endian.
The issue is that the SizeWith derive makes this code:

fn size_with(ctx: &::scroll::Endian)

but the try_{from,into}_ctx derives make this code:

fn try_into_ctx(
        self,
        dst: &mut [u8],
        ctx: ::scroll::Endian,
    ) ->

and

fn try_from_ctx(
        src: &'a [u8],
        ctx: ::scroll::Endian,
    ) ->

Yeah, its a little hacky :)

If size_with was by value then EndianDependant(ctx) would work.

custom_ctx: VariableLengthData,
}

use scroll::{
ctx::{SizeWith, TryFromCtx, TryIntoCtx},
Endian, Pread, Pwrite, BE, LE,
};

fn main() {
let bytes = [
0xefu8, 0xbe, 0xad, 0xde, 0, 0, 0, 0, 0, 0, 224, 63, 0xad, 0xde, 0xef, 0xbe, 0xaa, 0xbb,
0xcc, 0xdd, 0xee,
];
let data: Data = bytes.pread_with(0, LE).unwrap();
println!("data: {data:?}");
assert_eq!(data.id, 0xdeadbeefu32);
assert_eq!(data.arr, [0xadde, 0xefbe]);
let mut bytes2 = vec![0; ::std::mem::size_of::<Data>()];
bytes2.pwrite_with(data, 0, LE).unwrap();
let data: Data = bytes.pread_with(0, LE).unwrap();
let data2: Data = bytes2.pread_with(0, LE).unwrap();
assert_eq!(data, data2);
// Not enough bytes because of ctx dependent length being too long.
assert!(bytes.pread_with::<Data>(0, BE).is_err())
}
116 changes: 93 additions & 23 deletions scroll_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@

extern crate proc_macro;
use proc_macro2;
use quote::quote;
use quote::{quote, ToTokens};

use proc_macro::TokenStream;

fn impl_field(ident: &proc_macro2::TokenStream, ty: &syn::Type) -> proc_macro2::TokenStream {
fn impl_field(
ident: &proc_macro2::TokenStream,
ty: &syn::Type,
custom_ctx: Option<&proc_macro2::TokenStream>,
) -> proc_macro2::TokenStream {
let default_ctx = syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream();
let ctx = custom_ctx.unwrap_or(&default_ctx);
match *ty {
syn::Type::Array(ref array) => match array.len {
syn::Expr::Lit(syn::ExprLit {
Expand All @@ -15,20 +21,63 @@ fn impl_field(ident: &proc_macro2::TokenStream, ty: &syn::Type) -> proc_macro2::
}) => {
let size = int.base10_parse::<usize>().unwrap();
quote! {
#ident: { let mut __tmp: #ty = [0u8.into(); #size]; src.gread_inout_with(offset, &mut __tmp, ctx)?; __tmp }
#ident: { let mut __tmp: #ty = [0u8.into(); #size]; src.gread_inout_with(offset, &mut __tmp, #ctx)?; __tmp }
}
}
_ => panic!("Pread derive with bad array constexpr"),
},
syn::Type::Group(ref group) => impl_field(ident, &group.elem),
syn::Type::Group(ref group) => impl_field(ident, &group.elem, custom_ctx),
_ => {
quote! {
#ident: src.gread_with::<#ty>(offset, ctx)?
#ident: src.gread_with::<#ty>(offset, #ctx)?
}
}
}
}

/// Retrieve the field attribute with given ident e.g:
/// ```ignore
/// #[attr_ident(..)]
/// field: T,
/// ```
fn get_attr<'a>(attr_ident: &str, field: &'a syn::Field) -> Option<&'a syn::Attribute> {
field
.attrs
.iter()
.filter(|attr| attr.path().is_ident(attr_ident))
.next()
}

/// Gets the `TokenStream` for the custom ctx set in the `ctx` attribute. e.g. `expr` in the following
/// ```ignore
/// #[scroll(ctx = expr)]
/// field: T,
/// ```
fn custom_ctx(field: &syn::Field) -> Option<proc_macro2::TokenStream> {
get_attr("scroll", field).and_then(|x| {
// parsed #[scroll..]
// `expr` is `None` if the `ctx` key is not used.
let mut expr = None;
let res = x.parse_nested_meta(|meta| {
// parsed #[scroll(..)]
if meta.path.is_ident("ctx") {
// parsed #[scroll(ctx..)]
let value = meta.value()?; // parsed #[scroll(ctx = ..)]
expr = Some(value.parse::<syn::Expr>()?.into_token_stream()); // parsed #[scroll(ctx = expr)]
return Ok(());
}
Err(meta.error(match meta.path.get_ident() {
Some(ident) => format!("unrecognized attribute: {ident}"),
None => "unrecognized and invalid attribute".to_owned(),
}))
});
match res {
Ok(()) => expr,
Err(e) => Some(e.into_compile_error()),
}
})
}

fn impl_struct(
name: &syn::Ident,
fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
Expand All @@ -43,7 +92,9 @@ fn impl_struct(
quote! {#t}
});
let ty = &f.ty;
impl_field(ident, ty)
// parse the `expr` out of #[scroll(ctx = expr)]
let custom_ctx = custom_ctx(f);
impl_field(ident, ty, custom_ctx.as_ref())
})
.collect();

Expand Down Expand Up @@ -104,14 +155,20 @@ fn impl_try_from_ctx(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
}
}

#[proc_macro_derive(Pread)]
#[proc_macro_derive(Pread, attributes(scroll))]
pub fn derive_pread(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let gen = impl_try_from_ctx(&ast);
gen.into()
}

fn impl_pwrite_field(ident: &proc_macro2::TokenStream, ty: &syn::Type) -> proc_macro2::TokenStream {
fn impl_pwrite_field(
ident: &proc_macro2::TokenStream,
ty: &syn::Type,
custom_ctx: Option<&proc_macro2::TokenStream>,
) -> proc_macro2::TokenStream {
let default_ctx = syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream();
let ctx = custom_ctx.unwrap_or(&default_ctx);
match ty {
syn::Type::Array(ref array) => match array.len {
syn::Expr::Lit(syn::ExprLit {
Expand All @@ -121,24 +178,24 @@ fn impl_pwrite_field(ident: &proc_macro2::TokenStream, ty: &syn::Type) -> proc_m
let size = int.base10_parse::<usize>().unwrap();
quote! {
for i in 0..#size {
dst.gwrite_with(&self.#ident[i], offset, ctx)?;
dst.gwrite_with(&self.#ident[i], offset, #ctx)?;
}
}
}
_ => panic!("Pwrite derive with bad array constexpr"),
},
syn::Type::Group(group) => impl_pwrite_field(ident, &group.elem),
syn::Type::Group(group) => impl_pwrite_field(ident, &group.elem, custom_ctx),
syn::Type::Reference(reference) => match *reference.elem {
syn::Type::Slice(_) => quote! {
dst.gwrite_with(self.#ident, offset, ())?
},
_ => quote! {
dst.gwrite_with(self.#ident, offset, ctx)?
dst.gwrite_with(self.#ident, offset, #ctx)?
},
},
_ => {
quote! {
dst.gwrite_with(&self.#ident, offset, ctx)?
dst.gwrite_with(&self.#ident, offset, #ctx)?
}
}
}
Expand All @@ -158,7 +215,8 @@ fn impl_try_into_ctx(
quote! {#t}
});
let ty = &f.ty;
impl_pwrite_field(ident, ty)
let custom_ctx = custom_ctx(f);
impl_pwrite_field(ident, ty, custom_ctx.as_ref())
})
.collect();

Expand Down Expand Up @@ -249,7 +307,7 @@ fn impl_pwrite(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
}
}

#[proc_macro_derive(Pwrite)]
#[proc_macro_derive(Pwrite, attributes(scroll))]
pub fn derive_pwrite(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let gen = impl_pwrite(&ast);
Expand All @@ -265,6 +323,10 @@ fn size_with(
.iter()
.map(|f| {
let ty = &f.ty;
let custom_ctx = custom_ctx(f).map(|x| quote! {&#x});
let default_ctx =
syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream();
let ctx = custom_ctx.unwrap_or(default_ctx);
match *ty {
syn::Type::Array(ref array) => {
let elem = &array.elem;
Expand All @@ -275,15 +337,15 @@ fn size_with(
}) => {
let size = int.base10_parse::<usize>().unwrap();
quote! {
(#size * <#elem>::size_with(ctx))
(#size * <#elem>::size_with(#ctx))
}
}
_ => panic!("Pread derive with bad array constexpr"),
}
}
_ => {
quote! {
<#ty>::size_with(ctx)
<#ty>::size_with(#ctx)
}
}
}
Expand Down Expand Up @@ -341,7 +403,7 @@ fn impl_size_with(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
}
}

#[proc_macro_derive(SizeWith)]
#[proc_macro_derive(SizeWith, attributes(scroll))]
pub fn derive_sizewith(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let gen = impl_size_with(&ast);
Expand All @@ -356,6 +418,10 @@ fn impl_cread_struct(
let items: Vec<_> = fields.iter().enumerate().map(|(i, f)| {
let ident = &f.ident.as_ref().map(|i|quote!{#i}).unwrap_or({let t = proc_macro2::Literal::usize_unsuffixed(i); quote!{#t}});
let ty = &f.ty;
let custom_ctx = custom_ctx(f);
let default_ctx =
syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream();
let ctx = custom_ctx.unwrap_or(default_ctx);
match *ty {
syn::Type::Array(ref array) => {
let arrty = &array.elem;
Expand All @@ -367,7 +433,7 @@ fn impl_cread_struct(
#ident: {
let mut __tmp: #ty = [0u8.into(); #size];
for i in 0..__tmp.len() {
__tmp[i] = src.cread_with(*offset, ctx);
__tmp[i] = src.cread_with(*offset, #ctx);
*offset += #incr;
}
__tmp
Expand All @@ -380,7 +446,7 @@ fn impl_cread_struct(
_ => {
let size = quote! { ::scroll::export::mem::size_of::<#ty>() };
quote! {
#ident: { let res = src.cread_with::<#ty>(*offset, ctx); *offset += #size; res }
#ident: { let res = src.cread_with::<#ty>(*offset, #ctx); *offset += #size; res }
}
}
}
Expand Down Expand Up @@ -440,7 +506,7 @@ fn impl_from_ctx(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
}
}

#[proc_macro_derive(IOread)]
#[proc_macro_derive(IOread, attributes(scroll))]
pub fn derive_ioread(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let gen = impl_from_ctx(&ast);
Expand All @@ -462,20 +528,24 @@ fn impl_into_ctx(
});
let ty = &f.ty;
let size = quote! { ::scroll::export::mem::size_of::<#ty>() };
let custom_ctx = custom_ctx(f);
let default_ctx =
syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream();
let ctx = custom_ctx.unwrap_or(default_ctx);
match *ty {
syn::Type::Array(ref array) => {
let arrty = &array.elem;
quote! {
let size = ::scroll::export::mem::size_of::<#arrty>();
for i in 0..self.#ident.len() {
dst.cwrite_with(self.#ident[i], *offset, ctx);
dst.cwrite_with(self.#ident[i], *offset, #ctx);
*offset += size;
}
}
}
_ => {
quote! {
dst.cwrite_with(self.#ident, *offset, ctx);
dst.cwrite_with(self.#ident, *offset, #ctx);
*offset += #size;
}
}
Expand Down Expand Up @@ -544,7 +614,7 @@ fn impl_iowrite(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
}
}

#[proc_macro_derive(IOwrite)]
#[proc_macro_derive(IOwrite, attributes(scroll))]
pub fn derive_iowrite(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let gen = impl_iowrite(&ast);
Expand Down
Loading
Loading