diff --git a/borsh-derive-internal/src/lib.rs b/borsh-derive-internal/src/lib.rs index 214d1c353..ad8412ea4 100644 --- a/borsh-derive-internal/src/lib.rs +++ b/borsh-derive-internal/src/lib.rs @@ -8,6 +8,7 @@ mod enum_discriminant_map; mod enum_ser; mod struct_de; mod struct_ser; +mod tokio; mod union_de; mod union_ser; @@ -17,3 +18,10 @@ pub use struct_de::struct_de; pub use struct_ser::struct_ser; pub use union_de::union_de; pub use union_ser::union_ser; + +pub use tokio::enum_de as tokio_enum_de; +pub use tokio::enum_ser as tokio_enum_ser; +pub use tokio::struct_de as tokio_struct_de; +pub use tokio::struct_ser as tokio_struct_ser; +pub use tokio::union_de as tokio_union_de; +pub use tokio::union_ser as tokio_union_ser; diff --git a/borsh-derive-internal/src/tokio.rs b/borsh-derive-internal/src/tokio.rs new file mode 100644 index 000000000..b8b420ba7 --- /dev/null +++ b/borsh-derive-internal/src/tokio.rs @@ -0,0 +1,13 @@ +mod enum_de; +mod enum_ser; +mod struct_de; +mod struct_ser; +mod union_de; +mod union_ser; + +pub use enum_de::enum_de; +pub use enum_ser::enum_ser; +pub use struct_de::struct_de; +pub use struct_ser::struct_ser; +pub use union_de::union_de; +pub use union_ser::union_ser; diff --git a/borsh-derive-internal/src/tokio/enum_de.rs b/borsh-derive-internal/src/tokio/enum_de.rs new file mode 100644 index 000000000..448309e0c --- /dev/null +++ b/borsh-derive-internal/src/tokio/enum_de.rs @@ -0,0 +1,113 @@ +use proc_macro2::TokenStream as TokenStream2; +use quote::quote; +use syn::{Fields, Ident, ItemEnum, WhereClause}; + +use crate::{ + attribute_helpers::{contains_initialize_with, contains_skip}, + enum_discriminant_map::discriminant_map, +}; + +pub fn enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result { + let name = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let mut where_clause = where_clause.map_or_else( + || WhereClause { + where_token: Default::default(), + predicates: 
Default::default(), + }, + Clone::clone, + ); + let init_method = contains_initialize_with(&input.attrs)?; + let mut variant_arms = TokenStream2::new(); + let discriminants = discriminant_map(&input.variants); + for variant in input.variants.iter() { + let variant_ident = &variant.ident; + let discriminant = discriminants.get(variant_ident).unwrap(); + let mut variant_header = TokenStream2::new(); + match &variant.fields { + Fields::Named(fields) => { + for field in &fields.named { + let field_name = field.ident.as_ref().unwrap(); + if contains_skip(&field.attrs) { + variant_header.extend(quote! { + #field_name: Default::default(), + }); + } else { + let field_type = &field.ty; + where_clause.predicates.push( + syn::parse2(quote! { + #field_type: #cratename::tokio::AsyncBorshDeserialize + }) + .unwrap(), + ); + + variant_header.extend(quote! { + #field_name: #cratename::tokio::AsyncBorshDeserialize::deserialize_reader(reader).await?, + }); + } + } + variant_header = quote! { { #variant_header }}; + } + Fields::Unnamed(fields) => { + for field in fields.unnamed.iter() { + if contains_skip(&field.attrs) { + variant_header.extend(quote! { Default::default(), }); + } else { + let field_type = &field.ty; + where_clause.predicates.push( + syn::parse2(quote! { + #field_type: #cratename::tokio::AsyncBorshDeserialize + }) + .unwrap(), + ); + + variant_header.extend( + quote! { #cratename::tokio::AsyncBorshDeserialize::deserialize_reader(reader).await?, }, + ); + } + } + variant_header = quote! { ( #variant_header )}; + } + Fields::Unit => {} + } + variant_arms.extend(quote! { + if variant_tag == #discriminant { #name::#variant_ident #variant_header } else + }); + } + + let init = if let Some(method_ident) = init_method { + quote! { + return_value.#method_ident(); + } + } else { + quote! {} + }; + + Ok(quote!
{ + #[async_trait::async_trait] + impl #impl_generics #cratename::tokio::de::AsyncBorshDeserialize for #name #ty_generics #where_clause { + async fn deserialize_reader(reader: &mut R) -> ::core::result::Result { + let tag = ::deserialize_reader(reader).await?; + ::deserialize_variant(reader, tag).await + } + } + + #[async_trait::async_trait] + impl #impl_generics #cratename::tokio::de::EnumExt for #name #ty_generics #where_clause { + async fn deserialize_variant( + reader: &mut R, + variant_tag: u8, + ) -> ::core::result::Result { + let mut return_value = + #variant_arms { + return Err(#cratename::maybestd::io::Error::new( + #cratename::maybestd::io::ErrorKind::InvalidInput, + #cratename::maybestd::format!("Unexpected variant tag: {:?}", variant_tag), + )) + }; + #init + Ok(return_value) + } + } + }) +} diff --git a/borsh-derive-internal/src/tokio/enum_ser.rs b/borsh-derive-internal/src/tokio/enum_ser.rs new file mode 100644 index 000000000..e16cbcb13 --- /dev/null +++ b/borsh-derive-internal/src/tokio/enum_ser.rs @@ -0,0 +1,112 @@ +use core::convert::TryFrom; + +use proc_macro2::{Span, TokenStream as TokenStream2}; +use quote::quote; +use syn::{Fields, Ident, ItemEnum, WhereClause}; + +use crate::{attribute_helpers::contains_skip, enum_discriminant_map::discriminant_map}; + +pub fn enum_ser(input: &ItemEnum, cratename: Ident) -> syn::Result { + let name = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let mut where_clause = where_clause.map_or_else( + || WhereClause { + where_token: Default::default(), + predicates: Default::default(), + }, + Clone::clone, + ); + let mut variant_idx_body = TokenStream2::new(); + let mut fields_body = TokenStream2::new(); + let discriminants = discriminant_map(&input.variants); + for variant in input.variants.iter() { + let variant_ident = &variant.ident; + let mut variant_header = TokenStream2::new(); + let mut variant_body = TokenStream2::new(); + let discriminant_value =
discriminants.get(variant_ident).unwrap(); + match &variant.fields { + Fields::Named(fields) => { + for field in &fields.named { + let field_name = field.ident.as_ref().unwrap(); + if contains_skip(&field.attrs) { + variant_header.extend(quote! { _#field_name, }); + continue; + } else { + let field_type = &field.ty; + where_clause.predicates.push( + syn::parse2(quote! { + #field_type: #cratename::tokio::ser::AsyncBorshSerialize + }) + .unwrap(), + ); + variant_header.extend(quote! { #field_name, }); + } + variant_body.extend(quote! { + #cratename::tokio::AsyncBorshSerialize::serialize(#field_name, writer).await?; + }) + } + variant_header = quote! { { #variant_header }}; + variant_idx_body.extend(quote!( + #name::#variant_ident { .. } => #discriminant_value, + )); + } + Fields::Unnamed(fields) => { + for (field_idx, field) in fields.unnamed.iter().enumerate() { + let field_idx = + u32::try_from(field_idx).expect("up to 2^32 fields are supported"); + if contains_skip(&field.attrs) { + let field_ident = + Ident::new(format!("_id{}", field_idx).as_str(), Span::call_site()); + variant_header.extend(quote! { #field_ident, }); + continue; + } else { + let field_type = &field.ty; + where_clause.predicates.push( + syn::parse2(quote! { + #field_type: #cratename::tokio::ser::AsyncBorshSerialize + }) + .unwrap(), + ); + + let field_ident = + Ident::new(format!("id{}", field_idx).as_str(), Span::call_site()); + variant_header.extend(quote! { #field_ident, }); + variant_body.extend(quote! { + #cratename::tokio::AsyncBorshSerialize::serialize(#field_ident, writer).await?; + }) + } + } + variant_header = quote! { ( #variant_header )}; + variant_idx_body.extend(quote!( + #name::#variant_ident(..) => #discriminant_value, + )); + } + Fields::Unit => { + variant_idx_body.extend(quote!( + #name::#variant_ident => #discriminant_value, + )); + } + } + fields_body.extend(quote!( + #name::#variant_ident #variant_header => { + #variant_body + } + )) + } + Ok(quote! 
{ + #[async_trait::async_trait] + impl #impl_generics #cratename::tokio::ser::AsyncBorshSerialize for #name #ty_generics #where_clause { + async fn serialize(&self, writer: &mut W) -> ::core::result::Result<(), #cratename::maybestd::io::Error> { + let variant_idx: u8 = match self { + #variant_idx_body + }; + writer.write_all(&variant_idx.to_le_bytes()).await?; + + match self { + #fields_body + } + Ok(()) + } + } + }) +} diff --git a/borsh-derive-internal/src/tokio/lib.rs b/borsh-derive-internal/src/tokio/lib.rs new file mode 100644 index 000000000..214d1c353 --- /dev/null +++ b/borsh-derive-internal/src/tokio/lib.rs @@ -0,0 +1,19 @@ +#![recursion_limit = "128"] +// TODO: re-enable this lint when we bump msrv to 1.58 +#![allow(clippy::uninlined_format_args)] + +mod attribute_helpers; +mod enum_de; +mod enum_discriminant_map; +mod enum_ser; +mod struct_de; +mod struct_ser; +mod union_de; +mod union_ser; + +pub use enum_de::enum_de; +pub use enum_ser::enum_ser; +pub use struct_de::struct_de; +pub use struct_ser::struct_ser; +pub use union_de::union_de; +pub use union_ser::union_ser; diff --git a/borsh-derive-internal/src/tokio/struct_de.rs b/borsh-derive-internal/src/tokio/struct_de.rs new file mode 100644 index 000000000..032e3f59b --- /dev/null +++ b/borsh-derive-internal/src/tokio/struct_de.rs @@ -0,0 +1,85 @@ +use proc_macro2::TokenStream as TokenStream2; +use quote::quote; +use syn::{Fields, Ident, ItemStruct, WhereClause}; + +use crate::attribute_helpers::{contains_initialize_with, contains_skip}; + +pub fn struct_de(input: &ItemStruct, cratename: Ident) -> syn::Result { + let name = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let mut where_clause = where_clause.map_or_else( + || WhereClause { + where_token: Default::default(), + predicates: Default::default(), + }, + Clone::clone, + ); + let init_method = contains_initialize_with(&input.attrs)?; + let return_value = match &input.fields { + 
Fields::Named(fields) => { + let mut body = TokenStream2::new(); + for field in &fields.named { + let field_name = field.ident.as_ref().unwrap(); + let delta = if contains_skip(&field.attrs) { + quote! { + #field_name: Default::default(), + } + } else { + let field_type = &field.ty; + where_clause.predicates.push( + syn::parse2(quote! { + #field_type: #cratename::tokio::de::AsyncBorshDeserialize + }) + .unwrap(), + ); + + quote! { + #field_name: #cratename::tokio::de::AsyncBorshDeserialize::deserialize_reader(reader).await?, + } + }; + body.extend(delta); + } + quote! { + Self { #body } + } + } + Fields::Unnamed(fields) => { + let mut body = TokenStream2::new(); + for _ in 0..fields.unnamed.len() { + let delta = quote! { + #cratename::tokio::de::AsyncBorshDeserialize::deserialize_reader(reader).await?, + }; + body.extend(delta); + } + quote! { + Self( #body ) + } + } + Fields::Unit => { + quote! { + Self {} + } + } + }; + if let Some(method_ident) = init_method { + Ok(quote! { + #[async_trait::async_trait] + impl #impl_generics #cratename::tokio::de::AsyncBorshDeserialize for #name #ty_generics #where_clause { + async fn deserialize_reader(reader: &mut R) -> ::core::result::Result { + let mut return_value = #return_value; + return_value.#method_ident(); + Ok(return_value) + } + } + }) + } else { + Ok(quote! 
{ + #[async_trait::async_trait] + impl #impl_generics #cratename::tokio::de::AsyncBorshDeserialize for #name #ty_generics #where_clause { + async fn deserialize_reader(reader: &mut R) -> ::core::result::Result { + Ok(#return_value) + } + } + }) + } +} diff --git a/borsh-derive-internal/src/tokio/struct_ser.rs b/borsh-derive-internal/src/tokio/struct_ser.rs new file mode 100644 index 000000000..042dbab0d --- /dev/null +++ b/borsh-derive-internal/src/tokio/struct_ser.rs @@ -0,0 +1,157 @@ +use core::convert::TryFrom; + +use proc_macro2::{Span, TokenStream as TokenStream2}; +use quote::quote; +use syn::{Fields, Ident, Index, ItemStruct, WhereClause}; + +use crate::attribute_helpers::contains_skip; + +pub fn struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result { + let name = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let mut where_clause = where_clause.map_or_else( + || WhereClause { + where_token: Default::default(), + predicates: Default::default(), + }, + Clone::clone, + ); + let mut body = TokenStream2::new(); + match &input.fields { + Fields::Named(fields) => { + for field in &fields.named { + if contains_skip(&field.attrs) { + continue; + } + let field_name = field.ident.as_ref().unwrap(); + let delta = quote! { + #cratename::tokio::AsyncBorshSerialize::serialize(&self.#field_name, writer).await?; + }; + body.extend(delta); + + let field_type = &field.ty; + where_clause.predicates.push( + syn::parse2(quote! { + #field_type: #cratename::tokio::ser::AsyncBorshSerialize + }) + .unwrap(), + ); + } + } + Fields::Unnamed(fields) => { + for field_idx in 0..fields.unnamed.len() { + let field_idx = Index { + index: u32::try_from(field_idx).expect("up to 2^32 fields are supported"), + span: Span::call_site(), + }; + let delta = quote! { + #cratename::tokio::AsyncBorshSerialize::serialize(&self.#field_idx, writer).await?; + }; + body.extend(delta); + } + } + Fields::Unit => {} + } + Ok(quote! 
{ + #[async_trait::async_trait] + impl #impl_generics #cratename::tokio::ser::AsyncBorshSerialize for #name #ty_generics #where_clause { + async fn serialize(&self, writer: &mut W) -> ::core::result::Result<(), #cratename::maybestd::io::Error> { + #body + Ok(()) + } + } + }) +} + +// Rustfmt removes comas. +#[rustfmt::skip] +#[cfg(test)] +mod tests { + use super::*; + + fn assert_eq(expected: TokenStream2, actual: TokenStream2) { + assert_eq!(expected.to_string(), actual.to_string()) + } + + #[test] + fn simple_struct() { + let item_struct: ItemStruct = syn::parse2(quote!{ + struct A { + x: u64, + y: String, + } + }).unwrap(); + + let actual = struct_ser(&item_struct, Ident::new("borsh", Span::call_site())).unwrap(); + let expected = quote!{ + #[async_trait::async_trait] + impl borsh::tokio::ser::AsyncBorshSerialize for A + where + u64: borsh::tokio::ser::AsyncBorshSerialize, + String: borsh::tokio::ser::AsyncBorshSerialize + { + async fn serialize(&self, writer: &mut W) -> ::core::result::Result<(), borsh::maybestd::io::Error> { + borsh::tokio::AsyncBorshSerialize::serialize(&self.x, writer).await?; + borsh::tokio::AsyncBorshSerialize::serialize(&self.y, writer).await?; + Ok(()) + } + } + }; + assert_eq(expected, actual); + } + + #[test] + fn simple_generics() { + let item_struct: ItemStruct = syn::parse2(quote!{ + struct A { + x: HashMap, + y: String, + } + }).unwrap(); + + let actual = struct_ser(&item_struct, Ident::new("borsh", Span::call_site())).unwrap(); + let expected = quote!{ + #[async_trait::async_trait] + impl borsh::tokio::ser::AsyncBorshSerialize for A + where + HashMap: borsh::tokio::ser::AsyncBorshSerialize, + String: borsh::tokio::ser::AsyncBorshSerialize + { + async fn serialize(&self, writer: &mut W) -> ::core::result::Result<(), borsh::maybestd::io::Error> { + borsh::tokio::AsyncBorshSerialize::serialize(&self.x, writer).await?; + borsh::tokio::AsyncBorshSerialize::serialize(&self.y, writer).await?; + Ok(()) + } + } + }; + assert_eq(expected,
actual); + } + + #[test] + fn bound_generics() { + let item_struct: ItemStruct = syn::parse2(quote!{ + struct A where V: Value { + x: HashMap, + y: String, + } + }).unwrap(); + + let actual = struct_ser(&item_struct, Ident::new("borsh", Span::call_site())).unwrap(); + let expected = quote!{ + #[async_trait::async_trait] + impl borsh::tokio::ser::AsyncBorshSerialize for A + where + V: Value, + HashMap: borsh::tokio::ser::AsyncBorshSerialize, + String: borsh::tokio::ser::AsyncBorshSerialize + { + async fn serialize(&self, writer: &mut W) -> ::core::result::Result<(), borsh::maybestd::io::Error> { + borsh::tokio::AsyncBorshSerialize::serialize(&self.x, writer).await?; + borsh::tokio::AsyncBorshSerialize::serialize(&self.y, writer).await?; + Ok(()) + } + } + }; + assert_eq(expected, actual); + } +} diff --git a/borsh-derive-internal/src/tokio/union_de.rs b/borsh-derive-internal/src/tokio/union_de.rs new file mode 100644 index 000000000..768b84e13 --- /dev/null +++ b/borsh-derive-internal/src/tokio/union_de.rs @@ -0,0 +1,6 @@ +use proc_macro2::TokenStream as TokenStream2; +use syn::{Ident, ItemUnion}; + +pub fn union_de(_input: &ItemUnion, _cratename: Ident) -> syn::Result { + unimplemented!() +} diff --git a/borsh-derive-internal/src/tokio/union_ser.rs b/borsh-derive-internal/src/tokio/union_ser.rs new file mode 100644 index 000000000..a86a6dc39 --- /dev/null +++ b/borsh-derive-internal/src/tokio/union_ser.rs @@ -0,0 +1,6 @@ +use proc_macro2::TokenStream as TokenStream2; +use syn::{Ident, ItemUnion}; + +pub fn union_ser(_input: &ItemUnion, _cratename: Ident) -> syn::Result { + unimplemented!() +} diff --git a/borsh-derive/src/lib.rs b/borsh-derive/src/lib.rs index 86ae04d58..b66758828 100644 --- a/borsh-derive/src/lib.rs +++ b/borsh-derive/src/lib.rs @@ -80,3 +80,49 @@ pub fn borsh_schema(input: TokenStream) -> TokenStream { Err(err) => err.to_compile_error(), }) } + +#[proc_macro_derive(AsyncBorshSerialize, attributes(borsh_skip))] +pub fn async_borsh_serialize(input:
TokenStream) -> TokenStream { + let cratename = Ident::new( + &crate_name("borsh").unwrap_or_else(|_| "borsh".to_string()), + Span::call_site(), + ); + + let res = if let Ok(input) = syn::parse::(input.clone()) { + tokio_struct_ser(&input, cratename) + } else if let Ok(input) = syn::parse::(input.clone()) { + tokio_enum_ser(&input, cratename) + } else if let Ok(input) = syn::parse::(input) { + tokio_union_ser(&input, cratename) + } else { + // Derive macros can only be defined on structs, enums, and unions. + unreachable!() + }; + TokenStream::from(match res { + Ok(res) => res, + Err(err) => err.to_compile_error(), + }) +} + +#[proc_macro_derive(AsyncBorshDeserialize, attributes(borsh_skip, borsh_init))] +pub fn async_borsh_deserialize(input: TokenStream) -> TokenStream { + let cratename = Ident::new( + &crate_name("borsh").unwrap_or_else(|_| "borsh".to_string()), + Span::call_site(), + ); + + let res = if let Ok(input) = syn::parse::(input.clone()) { + tokio_struct_de(&input, cratename) + } else if let Ok(input) = syn::parse::(input.clone()) { + tokio_enum_de(&input, cratename) + } else if let Ok(input) = syn::parse::(input) { + tokio_union_de(&input, cratename) + } else { + // Derive macros can only be defined on structs, enums, and unions. 
+ unreachable!() + }; + TokenStream::from(match res { + Ok(res) => res, + Err(err) => err.to_compile_error(), + }) +} diff --git a/borsh/Cargo.toml b/borsh/Cargo.toml index e218e4bcb..e18b6aa98 100644 --- a/borsh/Cargo.toml +++ b/borsh/Cargo.toml @@ -26,14 +26,20 @@ hashbrown = ">=0.11,<0.14" bytes = { version = "1", optional = true } bson = { version = "2", optional = true } +# tokio feature flag +async-trait = { version = "0.1.68", optional = true } +tokio = { version = "1.28.1", default-features = false, features = [ "io-util" ], optional = true } + [dev-dependencies] bytes = "1" bson = "2" # Enable the "bytes" and "bson" features in integ tests: https://github.com/rust-lang/cargo/issues/2911#issuecomment-1464060655 borsh = { path = ".", features = ["bytes", "bson"] } +tokio = { version = "1.28.1", default-features = false, features = [ "io-util", "test-util", "macros" ] } [features] -default = ["std"] +default = ["std", "tokio", "rc"] std = [] rc = [] const-generics = [] +tokio = ["dep:tokio", "dep:async-trait"] diff --git a/borsh/src/lib.rs b/borsh/src/lib.rs index cf51b18a1..0401649a3 100644 --- a/borsh/src/lib.rs +++ b/borsh/src/lib.rs @@ -12,6 +12,10 @@ pub mod schema; pub mod schema_helpers; pub mod ser; +//TODO: Needs to handle no-std and `pub use` it here. 
+#[cfg(feature = "tokio")] +pub mod tokio; + pub use de::BorshDeserialize; pub use schema::BorshSchema; pub use schema_helpers::{try_from_slice_with_schema, try_to_vec_with_schema}; diff --git a/borsh/src/tokio.rs b/borsh/src/tokio.rs new file mode 100644 index 000000000..0db8d7d46 --- /dev/null +++ b/borsh/src/tokio.rs @@ -0,0 +1,8 @@ +pub mod de; +pub mod ser; + +pub use de::{AsyncBorshDeserialize, AsyncReader}; +pub use ser::{ + helpers::{to_vec, to_writer}, + AsyncBorshSerialize, AsyncWriter, +}; diff --git a/borsh/src/tokio/de/hint.rs b/borsh/src/tokio/de/hint.rs new file mode 100644 index 000000000..7af550cef --- /dev/null +++ b/borsh/src/tokio/de/hint.rs @@ -0,0 +1,15 @@ +#[inline] +pub fn cautious(hint: u32) -> usize { + let el_size = core::mem::size_of::() as u32; + core::cmp::max(core::cmp::min(hint, 4096 / el_size), 1) as usize +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + pub fn test_cautious_u8() { + assert_eq!(cautious::(10), 10); + } +} diff --git a/borsh/src/tokio/de/mod.rs b/borsh/src/tokio/de/mod.rs new file mode 100644 index 000000000..cd8166376 --- /dev/null +++ b/borsh/src/tokio/de/mod.rs @@ -0,0 +1,863 @@ +use core::marker::PhantomData; +use core::mem::MaybeUninit; +use core::{ + convert::{TryFrom, TryInto}, + hash::{BuildHasher, Hash}, + mem::{forget, size_of}, +}; + +#[cfg(any(test, feature = "bytes"))] +use bytes::{BufMut, BytesMut}; +use tokio::io::{AsyncRead, AsyncReadExt}; + +use crate::maybestd::{ + borrow::{Borrow, Cow, ToOwned}, + boxed::Box, + collections::{BTreeMap, BTreeSet, BinaryHeap, HashMap, HashSet, LinkedList, VecDeque}, + format, + io::{Error, ErrorKind, Read, Result}, + string::{String, ToString}, + vec, + vec::Vec, +}; + +#[cfg(feature = "rc")] +use crate::maybestd::{rc::Rc, sync::Arc}; + +mod hint; + +const ERROR_NOT_ALL_BYTES_READ: &str = "Not all bytes read"; +const ERROR_UNEXPECTED_LENGTH_OF_INPUT: &str = "Unexpected length of input"; +const ERROR_OVERFLOW_ON_MACHINE_WITH_32_BIT_ISIZE: &str = "Overflow 
on machine with 32 bit isize"; +const ERROR_OVERFLOW_ON_MACHINE_WITH_32_BIT_USIZE: &str = "Overflow on machine with 32 bit usize"; +const ERROR_INVALID_ZERO_VALUE: &str = "Expected a non-zero value"; + +pub trait AsyncReader: AsyncRead + Send + Unpin {} +impl AsyncReader for R {} + +/// A data-structure that can be de-serialized from binary format by NBOR. +#[async_trait::async_trait] +pub trait AsyncBorshDeserialize: Sized { + /// Deserializes this instance from a given slice of bytes. + /// Updates the buffer to point at the remaining bytes. + async fn deserialize(buf: &mut &[u8]) -> Result { + Self::deserialize_reader(&mut *buf).await + } + + async fn deserialize_reader(reader: &mut R) -> Result; + + /// Deserialize this instance from a slice of bytes. + async fn try_from_slice(v: &[u8]) -> Result { + let mut v_mut = v; + let result = Self::deserialize(&mut v_mut).await?; + if !v_mut.is_empty() { + return Err(Error::new(ErrorKind::InvalidData, ERROR_NOT_ALL_BYTES_READ)); + } + Ok(result) + } + + //async fn try_from_reader(reader: &mut R) -> Result { + // let result = Self::deserialize_reader(reader).await?; + // let mut buf = [0u8; 1]; + // match reader.read_exact(&mut buf) { + // Err(f) if f.kind() == ErrorKind::UnexpectedEof => Ok(result), + // _ => Err(Error::new(ErrorKind::InvalidData, ERROR_NOT_ALL_BYTES_READ)), + // } + //} + + #[inline] + #[doc(hidden)] + async fn vec_from_reader( + len: u32, + reader: &mut R, + ) -> Result>> { + let _ = len; + let _ = reader; + Ok(None) + } + + #[inline] + #[doc(hidden)] + async fn array_from_reader( + reader: &mut R, + ) -> Result> { + let _ = reader; + Ok(None) + } +} + +/// Additional methods offered on enums which uses `[derive(BorshDeserialize)]`. +pub trait EnumExt: AsyncBorshDeserialize { + /// Deserialises given variant of an enum from the reader. + /// + /// This may be used to perform validation or filtering based on what + /// variant is being deserialised. 
+ /// + /// ``` + /// use borsh::BorshDeserialize; + /// use borsh::de::EnumExt as _; + /// + /// #[derive(Debug, PartialEq, Eq, BorshDeserialize)] + /// enum MyEnum { + /// Zero, + /// One(u8), + /// Many(Vec) + /// } + /// + /// #[derive(Debug, PartialEq, Eq)] + /// struct OneOrZero(MyEnum); + /// + /// impl borsh::de::BorshDeserialize for OneOrZero { + /// fn deserialize_reader( + /// reader: &mut R, + /// ) -> borsh::maybestd::io::Result { + /// use borsh::de::EnumExt; + /// let tag = u8::deserialize_reader(reader)?; + /// if tag == 2 { + /// Err(borsh::maybestd::io::Error::new( + /// borsh::maybestd::io::ErrorKind::InvalidInput, + /// "MyEnum::Many not allowed here", + /// )) + /// } else { + /// MyEnum::deserialize_variant(reader, tag).map(Self) + /// } + /// } + /// } + /// + /// let data = b"\0"; + /// assert_eq!(MyEnum::Zero, MyEnum::try_from_slice(&data[..]).unwrap()); + /// assert_eq!(MyEnum::Zero, OneOrZero::try_from_slice(&data[..]).unwrap().0); + /// + /// let data = b"\x02\0\0\0\0"; + /// assert_eq!(MyEnum::Many(Vec::new()), MyEnum::try_from_slice(&data[..]).unwrap()); + /// assert!(OneOrZero::try_from_slice(&data[..]).is_err()); + /// ``` + fn deserialize_variant(reader: &mut R, tag: u8) -> Result; +} + +fn unexpected_eof_to_unexpected_length_of_input(e: Error) -> Error { + if e.kind() == ErrorKind::UnexpectedEof { + Error::new(ErrorKind::InvalidInput, ERROR_UNEXPECTED_LENGTH_OF_INPUT) + } else { + e + } +} + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for u8 { + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + reader + .read_u8() + .await + .map_err(unexpected_eof_to_unexpected_length_of_input) + } + + #[inline] + #[doc(hidden)] + async fn vec_from_reader( + len: u32, + reader: &mut R, + ) -> Result>> { + let len: usize = len.try_into().map_err(|_| ErrorKind::InvalidInput)?; + // Avoid OOM by limiting the size of allocation. 
This makes the read + // less efficient (since we need to loop and reallocate) but it protects + // us from someone sending us [0xff, 0xff, 0xff, 0xff] and forcing us to + // allocate 4GiB of memory. + let mut vec = vec![0u8; len.min(1024 * 1024)]; + let mut pos = 0; + while pos < len { + if pos == vec.len() { + vec.resize(vec.len().saturating_mul(2).min(len), 0) + } + // TODO(mina86): Convert this to read_buf once that stabilises. + match reader.read(&mut vec.as_mut_slice()[pos..]).await? { + 0 => { + return Err(Error::new( + ErrorKind::InvalidInput, + ERROR_UNEXPECTED_LENGTH_OF_INPUT, + )) + } + read => { + pos += read; + } + } + } + Ok(Some(vec)) + } + + #[inline] + #[doc(hidden)] + async fn array_from_reader( + reader: &mut R, + ) -> Result> { + let mut arr = [0u8; N]; + reader + .read_exact(&mut arr) + .await + .map_err(unexpected_eof_to_unexpected_length_of_input)?; + Ok(Some(arr)) + } +} + +macro_rules! impl_for_integer { + ($type: ident) => { + #[async_trait::async_trait] + impl AsyncBorshDeserialize for $type { + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let mut buf = [0u8; size_of::<$type>()]; + reader + .read_exact(&mut buf) + .await + .map_err(unexpected_eof_to_unexpected_length_of_input)?; + let res = $type::from_le_bytes(buf.try_into().unwrap()); + Ok(res) + } + } + }; +} + +impl_for_integer!(i8); +impl_for_integer!(i16); +impl_for_integer!(i32); +impl_for_integer!(i64); +impl_for_integer!(i128); +impl_for_integer!(u16); +impl_for_integer!(u32); +impl_for_integer!(u64); +impl_for_integer!(u128); + +macro_rules! impl_for_nonzero_integer { + ($type: ty) => { + #[async_trait::async_trait] + impl AsyncBorshDeserialize for $type { + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + <$type>::new(AsyncBorshDeserialize::deserialize_reader(reader).await?) 
+ .ok_or_else(|| Error::new(ErrorKind::InvalidData, ERROR_INVALID_ZERO_VALUE)) + } + } + }; +} + +impl_for_nonzero_integer!(core::num::NonZeroI8); +impl_for_nonzero_integer!(core::num::NonZeroI16); +impl_for_nonzero_integer!(core::num::NonZeroI32); +impl_for_nonzero_integer!(core::num::NonZeroI64); +impl_for_nonzero_integer!(core::num::NonZeroI128); +impl_for_nonzero_integer!(core::num::NonZeroU8); +impl_for_nonzero_integer!(core::num::NonZeroU16); +impl_for_nonzero_integer!(core::num::NonZeroU32); +impl_for_nonzero_integer!(core::num::NonZeroU64); +impl_for_nonzero_integer!(core::num::NonZeroU128); +impl_for_nonzero_integer!(core::num::NonZeroUsize); + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for isize { + async fn deserialize_reader(reader: &mut R) -> Result { + let i: i64 = AsyncBorshDeserialize::deserialize_reader(reader).await?; + let i = isize::try_from(i).map_err(|_| { + Error::new( + ErrorKind::InvalidInput, + ERROR_OVERFLOW_ON_MACHINE_WITH_32_BIT_ISIZE, + ) + })?; + Ok(i) + } +} + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for usize { + async fn deserialize_reader(reader: &mut R) -> Result { + let u: u64 = AsyncBorshDeserialize::deserialize_reader(reader).await?; + let u = usize::try_from(u).map_err(|_| { + Error::new( + ErrorKind::InvalidInput, + ERROR_OVERFLOW_ON_MACHINE_WITH_32_BIT_USIZE, + ) + })?; + Ok(u) + } +} + +// Note NaNs have a portability issue. Specifically, signalling NaNs on MIPS are quiet NaNs on x86, +// and vice-versa. We disallow NaNs to avoid this issue. +macro_rules! 
impl_for_float { + ($type: ident, $int_type: ident) => { + #[async_trait::async_trait] + impl AsyncBorshDeserialize for $type { + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let mut buf = [0u8; size_of::<$type>()]; + reader + .read_exact(&mut buf) + .await + .map_err(unexpected_eof_to_unexpected_length_of_input)?; + let res = $type::from_bits($int_type::from_le_bytes(buf.try_into().unwrap())); + if res.is_nan() { + return Err(Error::new( + ErrorKind::InvalidInput, + "For portability reasons we do not allow to deserialize NaNs.", + )); + } + Ok(res) + } + } + }; +} + +impl_for_float!(f32, u32); +impl_for_float!(f64, u64); + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for bool { + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let b: u8 = AsyncBorshDeserialize::deserialize_reader(reader).await?; + if b == 0 { + Ok(false) + } else if b == 1 { + Ok(true) + } else { + let msg = format!("Invalid bool representation: {}", b); + + Err(Error::new(ErrorKind::InvalidInput, msg)) + } + } +} + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for core::ops::Range +where + T: AsyncBorshDeserialize + Send, +{ + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + Ok(Self { + start: T::deserialize_reader(reader).await?, + end: T::deserialize_reader(reader).await?, + }) + } +} + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for Option +where + T: AsyncBorshDeserialize, +{ + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let flag: u8 = AsyncBorshDeserialize::deserialize_reader(reader).await?; + if flag == 0 { + Ok(None) + } else if flag == 1 { + Ok(Some(T::deserialize_reader(reader).await?)) + } else { + let msg = format!( + "Invalid Option representation: {}. 
The first byte must be 0 or 1", + flag + ); + + Err(Error::new(ErrorKind::InvalidInput, msg)) + } + } +} + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for core::result::Result +where + T: AsyncBorshDeserialize, + E: AsyncBorshDeserialize, +{ + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let flag: u8 = AsyncBorshDeserialize::deserialize_reader(reader).await?; + if flag == 0 { + Ok(Err(E::deserialize_reader(reader).await?)) + } else if flag == 1 { + Ok(Ok(T::deserialize_reader(reader).await?)) + } else { + let msg = format!( + "Invalid Result representation: {}. The first byte must be 0 or 1", + flag + ); + + Err(Error::new(ErrorKind::InvalidInput, msg)) + } + } +} + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for String { + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + String::from_utf8(Vec::::deserialize_reader(reader).await?).map_err(|err| { + let msg = err.to_string(); + Error::new(ErrorKind::InvalidData, msg) + }) + } +} + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for Vec +where + T: AsyncBorshDeserialize + Send, +{ + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let len = u32::deserialize_reader(reader).await?; + if len == 0 { + Ok(Vec::new()) + } else if let Some(vec_bytes) = T::vec_from_reader(len, reader).await? { + Ok(vec_bytes) + } else if size_of::() == 0 { + let mut result = vec![T::deserialize_reader(reader).await?]; + + let p = result.as_mut_ptr(); + unsafe { + forget(result); + let len = len.try_into().map_err(|_| ErrorKind::InvalidInput)?; + let result = Vec::from_raw_parts(p, len, len); + Ok(result) + } + } else { + // TODO(16): return capacity allocation when we can safely do that. 
+ let mut result = Vec::with_capacity(hint::cautious::(len)); + for _ in 0..len { + result.push(T::deserialize_reader(reader).await?); + } + Ok(result) + } + } +} + +#[cfg(any(test, feature = "bytes"))] +#[async_trait::async_trait] +impl AsyncBorshDeserialize for bytes::Bytes { + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let vec = >::deserialize_reader(reader).await?; + Ok(vec.into()) + } +} + +#[cfg(any(test, feature = "bytes"))] +#[async_trait::async_trait] +impl AsyncBorshDeserialize for bytes::BytesMut { + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let len = u32::deserialize_reader(reader).await?; + let mut out = BytesMut::with_capacity(hint::cautious::(len)); + for _ in 0..len { + out.put_u8(u8::deserialize_reader(reader).await?); + } + Ok(out) + } +} + +#[cfg(any(test, feature = "bson"))] +#[async_trait::async_trait] +impl AsyncBorshDeserialize for bson::oid::ObjectId { + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let mut buf = [0u8; 12]; + reader.read_exact(&mut buf).await?; + Ok(bson::oid::ObjectId::from_bytes(buf)) + } +} + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for Cow<'_, T> +where + T: ToOwned + ?Sized, + T::Owned: AsyncBorshDeserialize, +{ + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + Ok(Cow::Owned( + AsyncBorshDeserialize::deserialize_reader(reader).await?, + )) + } +} + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for VecDeque +where + T: AsyncBorshDeserialize + Send, +{ + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let vec = >::deserialize_reader(reader).await?; + Ok(vec.into()) + } +} + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for LinkedList +where + T: AsyncBorshDeserialize + Send, +{ + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let vec = >::deserialize_reader(reader).await?; + Ok(vec.into_iter().collect::>()) + } +} + 
+#[async_trait::async_trait] +impl AsyncBorshDeserialize for BinaryHeap +where + T: AsyncBorshDeserialize + Ord + Send, +{ + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let vec = >::deserialize_reader(reader).await?; + Ok(vec.into_iter().collect::>()) + } +} + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for HashSet +where + T: AsyncBorshDeserialize + Eq + Hash + Send, + H: BuildHasher + Default, +{ + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let vec = >::deserialize_reader(reader).await?; + Ok(vec.into_iter().collect::>()) + } +} + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for HashMap +where + K: AsyncBorshDeserialize + Eq + Hash + Send, + V: AsyncBorshDeserialize + Send, + H: BuildHasher + Default + Send, +{ + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let len = u32::deserialize_reader(reader).await?; + // TODO(16): return capacity allocation when we can safely do that. + let mut result = HashMap::with_hasher(H::default()); + for _ in 0..len { + let key = K::deserialize_reader(reader).await?; + let value = V::deserialize_reader(reader).await?; + result.insert(key, value); + } + Ok(result) + } +} + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for BTreeSet +where + T: AsyncBorshDeserialize + Ord + Send, +{ + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let vec = >::deserialize_reader(reader).await?; + Ok(vec.into_iter().collect::>()) + } +} + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for BTreeMap +where + K: AsyncBorshDeserialize + Ord + core::hash::Hash + Send, + V: AsyncBorshDeserialize + Send, +{ + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let len = u32::deserialize_reader(reader).await?; + let mut result = BTreeMap::new(); + for _ in 0..len { + let key = K::deserialize_reader(reader).await?; + let value = V::deserialize_reader(reader).await?; + result.insert(key, value); + } + 
Ok(result) + } +} + +#[cfg(feature = "std")] +#[async_trait::async_trait] +impl AsyncBorshDeserialize for std::net::SocketAddr { + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let kind = u8::deserialize_reader(reader).await?; + match kind { + 0 => std::net::SocketAddrV4::deserialize_reader(reader) + .await + .map(std::net::SocketAddr::V4), + 1 => std::net::SocketAddrV6::deserialize_reader(reader) + .await + .map(std::net::SocketAddr::V6), + value => Err(Error::new( + ErrorKind::InvalidInput, + format!("Invalid SocketAddr variant: {}", value), + )), + } + } +} + +#[cfg(feature = "std")] +#[async_trait::async_trait] +impl AsyncBorshDeserialize for std::net::SocketAddrV4 { + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let ip = std::net::Ipv4Addr::deserialize_reader(reader).await?; + let port = u16::deserialize_reader(reader).await?; + Ok(std::net::SocketAddrV4::new(ip, port)) + } +} + +#[cfg(feature = "std")] +#[async_trait::async_trait] +impl AsyncBorshDeserialize for std::net::SocketAddrV6 { + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let ip = std::net::Ipv6Addr::deserialize_reader(reader).await?; + let port = u16::deserialize_reader(reader).await?; + Ok(std::net::SocketAddrV6::new(ip, port, 0, 0)) + } +} + +#[cfg(feature = "std")] +#[async_trait::async_trait] +impl AsyncBorshDeserialize for std::net::Ipv4Addr { + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let mut buf = [0u8; 4]; + reader + .read_exact(&mut buf) + .await + .map_err(unexpected_eof_to_unexpected_length_of_input)?; + Ok(std::net::Ipv4Addr::from(buf)) + } +} + +#[cfg(feature = "std")] +#[async_trait::async_trait] +impl AsyncBorshDeserialize for std::net::Ipv6Addr { + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + let mut buf = [0u8; 16]; + reader + .read_exact(&mut buf) + .await + .map_err(unexpected_eof_to_unexpected_length_of_input)?; + Ok(std::net::Ipv6Addr::from(buf)) + } 
+} + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for Box +where + U: Into> + Borrow, + T: ToOwned + ?Sized, + T::Owned: AsyncBorshDeserialize, +{ + async fn deserialize_reader(reader: &mut R) -> Result { + Ok(T::Owned::deserialize_reader(reader).await?.into()) + } +} + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for [T; N] +where + T: AsyncBorshDeserialize + Send, +{ + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + struct ArrayDropGuard<'r, T: AsyncBorshDeserialize, const N: usize, R: AsyncReader> { + buffer: [MaybeUninit; N], + init_count: usize, + reader: &'r mut R, + } + impl<'r, T: AsyncBorshDeserialize, const N: usize, R: AsyncReader> Drop + for ArrayDropGuard<'r, T, N, R> + { + fn drop(&mut self) { + let init_range = &mut self.buffer[..self.init_count]; + // SAFETY: Elements up to self.init_count have been initialized. Assumes this value + // is only incremented in `fill_buffer`, which writes the element before + // increasing the init_count. + unsafe { + core::ptr::drop_in_place(init_range as *mut _ as *mut [T]); + }; + } + } + + impl<'r, T: AsyncBorshDeserialize, const N: usize, R: AsyncReader> ArrayDropGuard<'r, T, N, R> { + unsafe fn transmute_to_array(mut self) -> [T; N] { + debug_assert_eq!(self.init_count, N); + // Set init_count to 0 so that the values do not get dropped twice. + self.init_count = 0; + // SAFETY: This cast is required because `mem::transmute` does not work with + // const generics https://github.com/rust-lang/rust/issues/61956. This + // array is guaranteed to be initialized by this point. + core::ptr::read(&self.buffer as *const _ as *const [T; N]) + } + async fn fill_buffer(&mut self) -> Result<()> { + // TODO: replace with `core::array::try_from_fn` when stabilized to avoid manually + // dropping uninitialized values through the guard drop. 
+ for elem in self.buffer.iter_mut() { + elem.write(T::deserialize_reader(self.reader).await?); + self.init_count += 1; + } + Ok(()) + } + } + + if let Some(arr) = T::array_from_reader(reader).await? { + Ok(arr) + } else { + let mut result = ArrayDropGuard { + buffer: unsafe { MaybeUninit::uninit().assume_init() }, + init_count: 0, + reader, + }; + + result.fill_buffer().await?; + + // SAFETY: The elements up to `i` have been initialized in `fill_buffer`. + Ok(unsafe { result.transmute_to_array() }) + } + } +} + +#[cfg(test)] +#[tokio::test] +async fn array_deserialization_doesnt_leak() { + use core::sync::atomic::{AtomicUsize, Ordering}; + + static DESERIALIZE_COUNT: AtomicUsize = AtomicUsize::new(0); + static DROP_COUNT: AtomicUsize = AtomicUsize::new(0); + + struct MyType(u8); + #[async_trait::async_trait] + impl AsyncBorshDeserialize for MyType { + async fn deserialize_reader(reader: &mut R) -> Result { + let val = u8::deserialize_reader(reader).await?; + let v = DESERIALIZE_COUNT.fetch_add(1, Ordering::SeqCst); + if v >= 7 { + panic!("panic in deserialize"); + } + Ok(MyType(val)) + } + } + impl Drop for MyType { + fn drop(&mut self) { + DROP_COUNT.fetch_add(1, Ordering::SeqCst); + } + } + + assert!( + <[MyType; 5] as AsyncBorshDeserialize>::deserialize(&mut &[0u8; 3][..]) + .await + .is_err() + ); + assert_eq!(DESERIALIZE_COUNT.load(Ordering::SeqCst), 3); + assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3); + + assert!( + <[MyType; 2] as AsyncBorshDeserialize>::deserialize(&mut &[0u8; 2][..]) + .await + .is_ok() + ); + assert_eq!(DESERIALIZE_COUNT.load(Ordering::SeqCst), 5); + assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 5); + + /* TODO: Find a way to catch panic from async functions + #[cfg(feature = "std")] + { + // Test that during a panic in deserialize, the values are still dropped. 
+ let result = std::panic::catch_unwind(|| { + <[MyType; 3] as AsyncBorshDeserialize>::deserialize(&mut &[0u8; 3][..]) + .await + .unwrap(); + }); + assert!(result.is_err()); + assert_eq!(DESERIALIZE_COUNT.load(Ordering::SeqCst), 8); + assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 7); // 5 because 6 panicked and was not init + } */ +} + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for () { + async fn deserialize_reader(_reader: &mut R) -> Result { + Ok(()) + } +} + +macro_rules! impl_tuple { + ($($name:ident)+) => { + #[async_trait::async_trait] + impl<$($name),+> AsyncBorshDeserialize for ($($name,)+) + where $($name: AsyncBorshDeserialize + Send,)+ + { + #[inline] + async fn deserialize_reader(reader: &mut R) -> Result { + + Ok(($($name::deserialize_reader(reader).await?,)+)) + } + } + }; +} + +impl_tuple!(T0); +impl_tuple!(T0 T1); +impl_tuple!(T0 T1 T2); +impl_tuple!(T0 T1 T2 T3); +impl_tuple!(T0 T1 T2 T3 T4); +impl_tuple!(T0 T1 T2 T3 T4 T5); +impl_tuple!(T0 T1 T2 T3 T4 T5 T6); +impl_tuple!(T0 T1 T2 T3 T4 T5 T6 T7); +impl_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8); +impl_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9); +impl_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10); +impl_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11); +impl_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12); +impl_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13); +impl_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14); +impl_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); +impl_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15 T16); +impl_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15 T16 T17); +impl_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15 T16 T17 T18); +impl_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15 T16 T17 T18 T19); + +#[cfg(feature = "rc")] +#[async_trait::async_trait] +impl AsyncBorshDeserialize for Rc +where + U: Into> + Borrow, + T: ToOwned + ?Sized, + T::Owned: AsyncBorshDeserialize, +{ + async fn 
deserialize_reader(reader: &mut R) -> Result { + Ok(T::Owned::deserialize_reader(reader).await?.into()) + } +} + +#[cfg(feature = "rc")] +#[async_trait::async_trait] +impl AsyncBorshDeserialize for Arc +where + U: Into> + Borrow, + T: ToOwned + ?Sized, + T::Owned: AsyncBorshDeserialize, +{ + async fn deserialize_reader(reader: &mut R) -> Result { + Ok(T::Owned::deserialize_reader(reader).await?.into()) + } +} + +#[async_trait::async_trait] +impl AsyncBorshDeserialize for PhantomData { + async fn deserialize_reader(_: &mut R) -> Result { + Ok(Self::default()) + } +} diff --git a/borsh/src/tokio/ser/helpers.rs b/borsh/src/tokio/ser/helpers.rs new file mode 100644 index 000000000..cbe218ca4 --- /dev/null +++ b/borsh/src/tokio/ser/helpers.rs @@ -0,0 +1,18 @@ +use super::{AsyncBorshSerialize, AsyncWriter}; +use crate::maybestd::{io::Result, vec::Vec}; + +/// Serialize an object into a vector of bytes. +pub async fn to_vec(value: &T) -> Result> +where + T: AsyncBorshSerialize + ?Sized + Sync, +{ + value.try_to_vec().await +} + +/// Serializes an object directly into a `Writer`. 
+pub async fn to_writer(mut writer: W, value: &T) -> Result<()> +where + T: AsyncBorshSerialize + ?Sized + Sync, +{ + value.serialize(&mut writer).await +} diff --git a/borsh/src/tokio/ser/mod.rs b/borsh/src/tokio/ser/mod.rs new file mode 100644 index 000000000..bdb504a13 --- /dev/null +++ b/borsh/src/tokio/ser/mod.rs @@ -0,0 +1,614 @@ +use core::convert::TryFrom; +use core::hash::BuildHasher; +use core::marker::PhantomData; + +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +use crate::maybestd::{ + borrow::{Cow, ToOwned}, + boxed::Box, + collections::{BTreeMap, BTreeSet, BinaryHeap, HashMap, HashSet, LinkedList, VecDeque}, + io::{ErrorKind, Result}, + string::String, + vec::Vec, +}; + +#[cfg(feature = "rc")] +use crate::maybestd::sync::Arc; + +pub(crate) mod helpers; + +const DEFAULT_SERIALIZER_CAPACITY: usize = 1024; + +pub trait AsyncWriter: AsyncWrite + Send + Unpin {} +impl AsyncWriter for W {} + +/// A data-structure that can be serialized into binary format by NBOR. +/// +/// ``` +/// use borsh::BorshSerialize; +/// +/// #[derive(BorshSerialize)] +/// struct MyBorshSerializableStruct { +/// value: String, +/// } +/// +/// let x = MyBorshSerializableStruct { value: "hello".to_owned() }; +/// let mut buffer: Vec = Vec::new(); +/// x.serialize(&mut buffer).unwrap(); +/// let single_serialized_buffer_len = buffer.len(); +/// +/// x.serialize(&mut buffer).unwrap(); +/// assert_eq!(buffer.len(), single_serialized_buffer_len * 2); +/// +/// let mut buffer: Vec = vec![0; 1024 + single_serialized_buffer_len]; +/// let mut buffer_slice_enough_for_the_data = &mut buffer[1024..1024 + single_serialized_buffer_len]; +/// x.serialize(&mut buffer_slice_enough_for_the_data).unwrap(); +/// ``` +#[async_trait::async_trait] +pub trait AsyncBorshSerialize { + async fn serialize(&self, writer: &mut W) -> Result<()>; + + /// Serialize this instance into a vector of bytes. 
+ async fn try_to_vec(&self) -> Result> { + let mut result = Vec::with_capacity(DEFAULT_SERIALIZER_CAPACITY); + self.serialize(&mut result).await?; + Ok(result) + } + + #[inline] + #[doc(hidden)] + fn u8_slice(slice: &[Self]) -> Option<&[u8]> + where + Self: Sized, + { + let _ = slice; + None + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for u8 { + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + writer.write_all(core::slice::from_ref(self)).await + } + + #[inline] + fn u8_slice(slice: &[Self]) -> Option<&[u8]> { + Some(slice) + } +} + +macro_rules! impl_for_integer { + ($type: ident) => { + #[async_trait::async_trait] + impl AsyncBorshSerialize for $type { + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + let bytes = self.to_le_bytes(); + writer.write_all(&bytes).await + } + } + }; +} + +impl_for_integer!(i8); +impl_for_integer!(i16); +impl_for_integer!(i32); +impl_for_integer!(i64); +impl_for_integer!(i128); +impl_for_integer!(u16); +impl_for_integer!(u32); +impl_for_integer!(u64); +impl_for_integer!(u128); + +macro_rules! 
impl_for_nonzero_integer { + ($type: ty) => { + #[async_trait::async_trait] + impl AsyncBorshSerialize for $type { + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + AsyncBorshSerialize::serialize(&self.get(), writer).await + } + } + }; +} + +impl_for_nonzero_integer!(core::num::NonZeroI8); +impl_for_nonzero_integer!(core::num::NonZeroI16); +impl_for_nonzero_integer!(core::num::NonZeroI32); +impl_for_nonzero_integer!(core::num::NonZeroI64); +impl_for_nonzero_integer!(core::num::NonZeroI128); +impl_for_nonzero_integer!(core::num::NonZeroU8); +impl_for_nonzero_integer!(core::num::NonZeroU16); +impl_for_nonzero_integer!(core::num::NonZeroU32); +impl_for_nonzero_integer!(core::num::NonZeroU64); +impl_for_nonzero_integer!(core::num::NonZeroU128); +impl_for_nonzero_integer!(core::num::NonZeroUsize); + +#[async_trait::async_trait] +impl AsyncBorshSerialize for isize { + async fn serialize(&self, writer: &mut W) -> Result<()> { + AsyncBorshSerialize::serialize(&(*self as i64), writer).await + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for usize { + async fn serialize(&self, writer: &mut W) -> Result<()> { + AsyncBorshSerialize::serialize(&(*self as u64), writer).await + } +} + +// Note NaNs have a portability issue. Specifically, signalling NaNs on MIPS are quiet NaNs on x86, +// and vice-versa. We disallow NaNs to avoid this issue. +macro_rules! impl_for_float { + ($type: ident) => { + #[async_trait::async_trait] + impl AsyncBorshSerialize for $type { + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + assert!( + !self.is_nan(), + "For portability reasons we do not allow to serialize NaNs." 
+ ); + writer.write_all(&self.to_bits().to_le_bytes()).await + } + } + }; +} + +impl_for_float!(f32); +impl_for_float!(f64); + +#[async_trait::async_trait] +impl AsyncBorshSerialize for bool { + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + (u8::from(*self)).serialize(writer).await + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for core::ops::Range +where + T: AsyncBorshSerialize + Send + Sync, +{ + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + self.start.serialize(writer).await?; + self.end.serialize(writer).await + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for Option +where + T: AsyncBorshSerialize + Sync, +{ + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + match self { + None => 0u8.serialize(writer).await, + Some(value) => { + 1u8.serialize(writer).await?; + value.serialize(writer).await + } + } + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for core::result::Result +where + T: AsyncBorshSerialize + Sync, + E: AsyncBorshSerialize + Sync, +{ + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + match self { + Err(e) => { + 0u8.serialize(writer).await?; + e.serialize(writer).await + } + Ok(v) => { + 1u8.serialize(writer).await?; + v.serialize(writer).await + } + } + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for str { + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + self.as_bytes().serialize(writer).await + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for String { + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + self.as_bytes().serialize(writer).await + } +} + +/// Helper method that is used to serialize a slice of data (without the length marker). 
+#[inline] +async fn serialize_slice( + data: &[T], + writer: &mut W, +) -> Result<()> { + if let Some(u8_slice) = T::u8_slice(data) { + writer.write_all(u8_slice).await?; + } else { + for item in data { + item.serialize(writer).await?; + } + } + Ok(()) +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for [T] +where + T: AsyncBorshSerialize + Sync, +{ + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + writer + .write_all( + &(u32::try_from(self.len()).map_err(|_| ErrorKind::InvalidInput)?).to_le_bytes(), + ) + .await?; + serialize_slice(self, writer).await + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for &T { + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + (*self).serialize(writer).await + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for Cow<'_, T> +where + T: AsyncBorshSerialize + ToOwned + ?Sized + Sync, + ::Owned: Sync, +{ + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + self.as_ref().serialize(writer).await + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for Vec +where + T: AsyncBorshSerialize + Sync, +{ + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + self.as_slice().serialize(writer).await + } +} + +#[cfg(any(test, feature = "bytes"))] +#[async_trait::async_trait] +impl AsyncBorshSerialize for bytes::Bytes { + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + self.as_ref().serialize(writer).await + } +} + +#[cfg(any(test, feature = "bytes"))] +#[async_trait::async_trait] +impl AsyncBorshSerialize for bytes::BytesMut { + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + self.as_ref().serialize(writer).await + } +} + +#[cfg(any(test, feature = "bson"))] +#[async_trait::async_trait] +impl AsyncBorshSerialize for bson::oid::ObjectId { + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + self.bytes().serialize(writer).await + } +} + 
+#[async_trait::async_trait] +impl AsyncBorshSerialize for VecDeque +where + T: AsyncBorshSerialize + Sync, +{ + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + writer + .write_all( + &(u32::try_from(self.len()).map_err(|_| ErrorKind::InvalidInput)?).to_le_bytes(), + ) + .await?; + let slices = self.as_slices(); + serialize_slice(slices.0, writer).await?; + serialize_slice(slices.1, writer).await + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for LinkedList +where + T: AsyncBorshSerialize + Sync, +{ + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + writer + .write_all( + &(u32::try_from(self.len()).map_err(|_| ErrorKind::InvalidInput)?).to_le_bytes(), + ) + .await?; + for item in self { + item.serialize(writer).await?; + } + Ok(()) + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for BinaryHeap +where + T: AsyncBorshSerialize + Sync, +{ + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + // It could have been just `self.as_slice().serialize(writer)`, but there is no + // `as_slice()` method: + // https://internals.rust-lang.org/t/should-i-add-as-slice-method-to-binaryheap/13816 + writer + .write_all( + &(u32::try_from(self.len()).map_err(|_| ErrorKind::InvalidInput)?).to_le_bytes(), + ) + .await?; + for item in self { + item.serialize(writer).await?; + } + Ok(()) + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for HashMap +where + K: AsyncBorshSerialize + PartialOrd + Sync, + V: AsyncBorshSerialize + Sync, + H: BuildHasher + Sync, +{ + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + let mut vec = self.iter().collect::>(); + vec.sort_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap()); + u32::try_from(vec.len()) + .map_err(|_| ErrorKind::InvalidInput)? 
+ .serialize(writer) + .await?; + for (key, value) in vec { + key.serialize(writer).await?; + value.serialize(writer).await?; + } + Ok(()) + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for HashSet +where + T: AsyncBorshSerialize + PartialOrd + Sync, + H: BuildHasher + Sync, +{ + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + let mut vec = self.iter().collect::>(); + vec.sort_by(|a, b| a.partial_cmp(b).unwrap()); + u32::try_from(vec.len()) + .map_err(|_| ErrorKind::InvalidInput)? + .serialize(writer) + .await?; + for item in vec { + item.serialize(writer).await?; + } + Ok(()) + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for BTreeMap +where + K: AsyncBorshSerialize + Sync, + V: AsyncBorshSerialize + Sync, +{ + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + // NOTE: BTreeMap iterates over the entries that are sorted by key, so the serialization + // result will be consistent without a need to sort the entries as we do for HashMap + // serialization. + u32::try_from(self.len()) + .map_err(|_| ErrorKind::InvalidInput)? + .serialize(writer) + .await?; + for (key, value) in self { + key.serialize(writer).await?; + value.serialize(writer).await?; + } + Ok(()) + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for BTreeSet +where + T: AsyncBorshSerialize + Sync, +{ + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + // NOTE: BTreeSet iterates over the items that are sorted, so the serialization result will + // be consistent without a need to sort the entries as we do for HashSet serialization. + u32::try_from(self.len()) + .map_err(|_| ErrorKind::InvalidInput)? 
+ .serialize(writer) + .await?; + for item in self { + item.serialize(writer).await?; + } + Ok(()) + } +} + +#[cfg(feature = "std")] +#[async_trait::async_trait] +impl AsyncBorshSerialize for std::net::SocketAddr { + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + match *self { + std::net::SocketAddr::V4(ref addr) => { + 0u8.serialize(writer).await?; + addr.serialize(writer).await + } + std::net::SocketAddr::V6(ref addr) => { + 1u8.serialize(writer).await?; + addr.serialize(writer).await + } + } + } +} + +#[cfg(feature = "std")] +#[async_trait::async_trait] +impl AsyncBorshSerialize for std::net::SocketAddrV4 { + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + self.ip().serialize(writer).await?; + self.port().serialize(writer).await + } +} + +#[cfg(feature = "std")] +#[async_trait::async_trait] +impl AsyncBorshSerialize for std::net::SocketAddrV6 { + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + self.ip().serialize(writer).await?; + self.port().serialize(writer).await + } +} + +#[cfg(feature = "std")] +#[async_trait::async_trait] +impl AsyncBorshSerialize for std::net::Ipv4Addr { + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + writer.write_all(&self.octets()).await + } +} + +#[cfg(feature = "std")] +#[async_trait::async_trait] +impl AsyncBorshSerialize for std::net::Ipv6Addr { + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + writer.write_all(&self.octets()).await + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for Box { + async fn serialize(&self, writer: &mut W) -> Result<()> { + self.as_ref().serialize(writer).await + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for [T; N] +where + T: AsyncBorshSerialize + Sync, +{ + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + if N == 0 { + return Ok(()); + } else if let Some(u8_slice) = T::u8_slice(self) { + writer.write_all(u8_slice).await?; + } 
else { + for el in self.iter() { + el.serialize(writer).await?; + } + } + Ok(()) + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for () { + async fn serialize(&self, _writer: &mut W) -> Result<()> { + Ok(()) + } +} + +macro_rules! impl_tuple { + ($($idx:tt $name:ident)+) => { + #[async_trait::async_trait] + impl<$($name),+> AsyncBorshSerialize for ($($name,)+) + where $($name: AsyncBorshSerialize + Sync + Send,)+ + { + #[inline] + async fn serialize(&self, writer: &mut W) -> Result<()> { + $(self.$idx.serialize(writer).await?;)+ + Ok(()) + } + } + }; +} + +impl_tuple!(0 T0); +impl_tuple!(0 T0 1 T1); +impl_tuple!(0 T0 1 T1 2 T2); +impl_tuple!(0 T0 1 T1 2 T2 3 T3); +impl_tuple!(0 T0 1 T1 2 T2 3 T3 4 T4); +impl_tuple!(0 T0 1 T1 2 T2 3 T3 4 T4 5 T5); +impl_tuple!(0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6); +impl_tuple!(0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7); +impl_tuple!(0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8); +impl_tuple!(0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9); +impl_tuple!(0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10); +impl_tuple!(0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11); +impl_tuple!(0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12); +impl_tuple!(0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13); +impl_tuple!(0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14); +impl_tuple!(0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15); +impl_tuple!(0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15 16 T16); +impl_tuple!(0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15 16 T16 17 T17); +impl_tuple!(0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15 16 T16 17 T17 18 T18); +impl_tuple!(0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 
T14 15 T15 16 T16 17 T17 18 T18 19 T19); + +#[cfg(feature = "rc")] +#[async_trait::async_trait] +impl AsyncBorshSerialize for Arc { + async fn serialize(&self, writer: &mut W) -> Result<()> { + (**self).serialize(writer).await + } +} + +#[async_trait::async_trait] +impl AsyncBorshSerialize for PhantomData { + async fn serialize(&self, _: &mut W) -> Result<()> { + Ok(()) + } +} diff --git a/borsh/tests/tokio_smoke.rs b/borsh/tests/tokio_smoke.rs new file mode 100644 index 000000000..95b94f3b9 --- /dev/null +++ b/borsh/tests/tokio_smoke.rs @@ -0,0 +1,21 @@ +// Smoke tests that ensure that we don't accidentally remove top-level +// re-exports in a minor release. + +use borsh::tokio::{to_vec, to_writer, AsyncBorshDeserialize}; + +#[tokio::test] +async fn test_to_vec() { + let value = 42u8; + let serialized = to_vec(&value).await.unwrap(); + let deserialized = u8::try_from_slice(&serialized).await.unwrap(); + assert_eq!(value, deserialized); +} + +#[tokio::test] +async fn test_to_writer() { + let value = 42u8; + let mut serialized = vec![0; 1]; + to_writer(&mut serialized, &value).await.unwrap(); + let deserialized = u8::try_from_slice(&serialized).await.unwrap(); + assert_eq!(value, deserialized); +} diff --git a/borsh/tests/tokio_test_arrays.rs b/borsh/tests/tokio_test_arrays.rs new file mode 100644 index 000000000..08f70020d --- /dev/null +++ b/borsh/tests/tokio_test_arrays.rs @@ -0,0 +1,72 @@ +#![allow(clippy::float_cmp)] + +use borsh::tokio::{AsyncBorshDeserialize, AsyncBorshSerialize}; +use borsh_derive::{AsyncBorshDeserialize, AsyncBorshSerialize}; + +macro_rules! test_array { + ($v: expr, $t: ty, $len: expr) => { + let buf = $v.try_to_vec().await.unwrap(); + let actual_v: [$t; $len] = AsyncBorshDeserialize::try_from_slice(&buf) + .await + .expect("failed to deserialize"); + assert_eq!($v.len(), actual_v.len()); + #[allow(clippy::reversed_empty_ranges)] + for i in 0..$len { + assert_eq!($v[i], actual_v[i]); + } + }; +} + +macro_rules! 
test_arrays { + ($test_name: ident, $el: expr, $t: ty) => { + #[tokio::test] + async fn $test_name() { + test_array!([$el; 0], $t, 0); + test_array!([$el; 1], $t, 1); + test_array!([$el; 2], $t, 2); + test_array!([$el; 3], $t, 3); + test_array!([$el; 4], $t, 4); + test_array!([$el; 8], $t, 8); + test_array!([$el; 16], $t, 16); + test_array!([$el; 32], $t, 32); + test_array!([$el; 64], $t, 64); + test_array!([$el; 65], $t, 65); + } + }; +} + +test_arrays!(test_array_u8, 100u8, u8); +test_arrays!(test_array_i8, 100i8, i8); +test_arrays!(test_array_u32, 1000000000u32, u32); +test_arrays!(test_array_u64, 1000000000000000000u64, u64); +test_arrays!( + test_array_u128, + 1000000000000000000000000000000000000u128, + u128 +); +test_arrays!(test_array_f32, 1000000000.0f32, f32); +test_arrays!(test_array_array_u8, [100u8; 32], [u8; 32]); +test_arrays!(test_array_zst, (), ()); + +#[derive(AsyncBorshDeserialize, AsyncBorshSerialize, PartialEq, Debug)] +struct CustomStruct(u8); + +#[tokio::test] +async fn test_custom_struct_array() { + let arr = [CustomStruct(0), CustomStruct(1), CustomStruct(2)]; + let serialized = arr.try_to_vec().await.unwrap(); + let deserialized: [CustomStruct; 3] = AsyncBorshDeserialize::try_from_slice(&serialized) + .await + .unwrap(); + assert_eq!(arr, deserialized); +} + +#[tokio::test] +async fn test_string_array() { + let arr = ["0".to_string(), "1".to_string(), "2".to_string()]; + let serialized = arr.try_to_vec().await.unwrap(); + let deserialized: [String; 3] = AsyncBorshDeserialize::try_from_slice(&serialized) + .await + .unwrap(); + assert_eq!(arr, deserialized); +} diff --git a/borsh/tests/tokio_test_binary_heaps.rs b/borsh/tests/tokio_test_binary_heaps.rs new file mode 100644 index 000000000..5ec6cd418 --- /dev/null +++ b/borsh/tests/tokio_test_binary_heaps.rs @@ -0,0 +1,30 @@ +use borsh::maybestd::collections::BinaryHeap; +use borsh::{BorshDeserialize, BorshSerialize}; + +macro_rules! 
test_binary_heap { + ($v: expr, $t: ty) => { + let buf = $v.try_to_vec().unwrap(); + let actual_v: BinaryHeap<$t> = + BorshDeserialize::try_from_slice(&buf).expect("failed to deserialize"); + assert_eq!(actual_v.into_vec(), $v.into_vec()); + }; +} + +macro_rules! test_binary_heaps { + ($test_name: ident, $el: expr, $t: ty) => { + #[test] + fn $test_name() { + test_binary_heap!(BinaryHeap::<$t>::new(), $t); + test_binary_heap!(vec![$el].into_iter().collect::>(), $t); + test_binary_heap!(vec![$el; 10].into_iter().collect::>(), $t); + test_binary_heap!(vec![$el; 100].into_iter().collect::>(), $t); + test_binary_heap!(vec![$el; 1000].into_iter().collect::>(), $t); + test_binary_heap!(vec![$el; 10000].into_iter().collect::>(), $t); + } + }; +} + +test_binary_heaps!(test_binary_heap_u8, 100u8, u8); +test_binary_heaps!(test_binary_heap_i8, 100i8, i8); +test_binary_heaps!(test_binary_heap_u32, 1000000000u32, u32); +test_binary_heaps!(test_binary_heap_string, "a".to_string(), String); diff --git a/borsh/tests/tokio_test_bson_object_ids.rs b/borsh/tests/tokio_test_bson_object_ids.rs new file mode 100644 index 000000000..479532332 --- /dev/null +++ b/borsh/tests/tokio_test_bson_object_ids.rs @@ -0,0 +1,19 @@ +#![allow(clippy::float_cmp)] + +use borsh::{BorshDeserialize, BorshSerialize}; +use bson::oid::ObjectId; + +#[derive(BorshDeserialize, BorshSerialize, PartialEq, Debug)] +struct StructWithObjectId(i32, ObjectId, u8); + +#[test] +fn test_object_id() { + let obj = StructWithObjectId( + 123, + ObjectId::from_bytes([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), + 33, + ); + let serialized = obj.try_to_vec().unwrap(); + let deserialized: StructWithObjectId = BorshDeserialize::try_from_slice(&serialized).unwrap(); + assert_eq!(obj, deserialized); +} diff --git a/borsh/tests/tokio_test_custom_reader.rs b/borsh/tests/tokio_test_custom_reader.rs new file mode 100644 index 000000000..0a0b2e075 --- /dev/null +++ b/borsh/tests/tokio_test_custom_reader.rs @@ -0,0 +1,140 @@ +use 
borsh::{BorshDeserialize, BorshSerialize}; + +const ERROR_NOT_ALL_BYTES_READ: &str = "Not all bytes read"; +const ERROR_UNEXPECTED_LENGTH_OF_INPUT: &str = "Unexpected length of input"; + +#[derive(BorshSerialize, BorshDeserialize, Debug)] +struct Serializable { + item1: i32, + item2: String, + item3: f64, +} + +#[test] +fn test_custom_reader() { + let s = Serializable { + item1: 100, + item2: "foo".into(), + item3: 1.2345, + }; + let bytes = s.try_to_vec().unwrap(); + let mut reader = CustomReader { + data: bytes, + read_index: 0, + }; + let de: Serializable = BorshDeserialize::deserialize_reader(&mut reader).unwrap(); + assert_eq!(de.item1, s.item1); + assert_eq!(de.item2, s.item2); + assert_eq!(de.item3, s.item3); +} + +#[test] +fn test_custom_reader_with_insufficient_data() { + let s = Serializable { + item1: 100, + item2: "foo".into(), + item3: 1.2345, + }; + let mut bytes = s.try_to_vec().unwrap(); + bytes.pop().unwrap(); + let mut reader = CustomReader { + data: bytes, + read_index: 0, + }; + assert_eq!( + ::deserialize_reader(&mut reader) + .unwrap_err() + .to_string(), + ERROR_UNEXPECTED_LENGTH_OF_INPUT + ); +} + +#[test] +fn test_custom_reader_with_too_much_data() { + let s = Serializable { + item1: 100, + item2: "foo".into(), + item3: 1.2345, + }; + let mut bytes = s.try_to_vec().unwrap(); + bytes.push(1); + let mut reader = CustomReader { + data: bytes, + read_index: 0, + }; + assert_eq!( + ::try_from_reader(&mut reader) + .unwrap_err() + .to_string(), + ERROR_NOT_ALL_BYTES_READ + ); +} + +struct CustomReader { + data: Vec, + read_index: usize, +} + +impl borsh::maybestd::io::Read for CustomReader { + fn read(&mut self, buf: &mut [u8]) -> borsh::maybestd::io::Result { + let len = buf.len().min(self.data.len() - self.read_index); + buf[0..len].copy_from_slice(&self.data[self.read_index..self.read_index + len]); + self.read_index += len; + Ok(len) + } +} + +#[test] +fn test_custom_reader_that_doesnt_fill_slices() { + let s = Serializable { + item1: 100, + 
item2: "foo".into(), + item3: 1.2345, + }; + let bytes = s.try_to_vec().unwrap(); + let mut reader = CustomReaderThatDoesntFillSlices { + data: bytes, + read_index: 0, + }; + let de: Serializable = BorshDeserialize::deserialize_reader(&mut reader).unwrap(); + assert_eq!(de.item1, s.item1); + assert_eq!(de.item2, s.item2); + assert_eq!(de.item3, s.item3); +} + +struct CustomReaderThatDoesntFillSlices { + data: Vec, + read_index: usize, +} + +impl borsh::maybestd::io::Read for CustomReaderThatDoesntFillSlices { + fn read(&mut self, buf: &mut [u8]) -> borsh::maybestd::io::Result { + let len = buf.len().min(self.data.len() - self.read_index); + let len = if len <= 1 { len } else { len / 2 }; + buf[0..len].copy_from_slice(&self.data[self.read_index..self.read_index + len]); + self.read_index += len; + Ok(len) + } +} + +#[test] +fn test_custom_reader_that_fails_preserves_error_information() { + let mut reader = CustomReaderThatFails; + let err = ::try_from_reader(&mut reader).unwrap_err(); + assert_eq!(err.to_string(), "I don't like to run"); + assert_eq!( + err.kind(), + borsh::maybestd::io::ErrorKind::ConnectionAborted + ); +} + +struct CustomReaderThatFails; + +impl borsh::maybestd::io::Read for CustomReaderThatFails { + fn read(&mut self, _buf: &mut [u8]) -> borsh::maybestd::io::Result { + Err(borsh::maybestd::io::Error::new( + borsh::maybestd::io::ErrorKind::ConnectionAborted, + "I don't like to run", + )) + } +} diff --git a/borsh/tests/tokio_test_de_errors.rs b/borsh/tests/tokio_test_de_errors.rs new file mode 100644 index 000000000..e4baae3b4 --- /dev/null +++ b/borsh/tests/tokio_test_de_errors.rs @@ -0,0 +1,202 @@ +use borsh::BorshDeserialize; + +#[derive(BorshDeserialize, Debug)] +enum A { + X, + Y, +} + +#[derive(BorshDeserialize, Debug)] +struct B { + #[allow(unused)] + x: u64, + #[allow(unused)] + y: u32, +} + +const ERROR_UNEXPECTED_LENGTH_OF_INPUT: &str = "Unexpected length of input"; +const ERROR_INVALID_ZERO_VALUE: &str = "Expected a non-zero value"; + 
+#[test] +fn test_missing_bytes() { + let bytes = vec![1, 0]; + assert_eq!( + B::try_from_slice(&bytes).unwrap_err().to_string(), + ERROR_UNEXPECTED_LENGTH_OF_INPUT + ); +} + +#[test] +fn test_invalid_enum_variant() { + let bytes = vec![123]; + assert_eq!( + A::try_from_slice(&bytes).unwrap_err().to_string(), + "Unexpected variant tag: 123" + ); +} + +#[test] +fn test_extra_bytes() { + let bytes = vec![1, 0, 0, 0, 32, 32]; + assert_eq!( + >::try_from_slice(&bytes).unwrap_err().to_string(), + "Not all bytes read" + ); +} + +#[test] +fn test_invalid_bool() { + for i in 2u8..=255 { + let bytes = [i]; + assert_eq!( + ::try_from_slice(&bytes).unwrap_err().to_string(), + format!("Invalid bool representation: {}", i) + ); + } +} + +#[test] +fn test_invalid_option() { + for i in 2u8..=255 { + let bytes = [i, 32]; + assert_eq!( + >::try_from_slice(&bytes) + .unwrap_err() + .to_string(), + format!( + "Invalid Option representation: {}. The first byte must be 0 or 1", + i + ) + ); + } +} + +#[test] +fn test_invalid_result() { + for i in 2u8..=255 { + let bytes = [i, 0]; + assert_eq!( + >::try_from_slice(&bytes) + .unwrap_err() + .to_string(), + format!( + "Invalid Result representation: {}. 
The first byte must be 0 or 1", + i + ) + ); + } +} + +#[test] +fn test_invalid_length() { + let bytes = vec![255u8; 4]; + assert_eq!( + >::try_from_slice(&bytes).unwrap_err().to_string(), + ERROR_UNEXPECTED_LENGTH_OF_INPUT + ); +} + +#[test] +fn test_invalid_length_string() { + let bytes = vec![255u8; 4]; + assert_eq!( + String::try_from_slice(&bytes).unwrap_err().to_string(), + ERROR_UNEXPECTED_LENGTH_OF_INPUT + ); +} + +#[test] +fn test_non_utf_string() { + let bytes = vec![1, 0, 0, 0, 0xC0]; + assert_eq!( + String::try_from_slice(&bytes).unwrap_err().to_string(), + "invalid utf-8 sequence of 1 bytes from index 0" + ); +} + +#[test] +fn test_nan_float() { + let bytes = vec![0, 0, 192, 127]; + assert_eq!( + f32::try_from_slice(&bytes).unwrap_err().to_string(), + "For portability reasons we do not allow to deserialize NaNs." + ); +} + +#[test] +fn test_evil_bytes_vec_with_extra() { + // Should fail to allocate given length + // test takes a really long time if read() is used instead of read_exact() + let bytes = vec![255, 255, 255, 255, 32, 32]; + assert_eq!( + >::try_from_slice(&bytes) + .unwrap_err() + .to_string(), + ERROR_UNEXPECTED_LENGTH_OF_INPUT + ); +} + +#[test] +fn test_evil_bytes_string_extra() { + // Might fail if reading too much + let bytes = vec![255, 255, 255, 255, 32, 32]; + assert_eq!( + String::try_from_slice(&bytes).unwrap_err().to_string(), + ERROR_UNEXPECTED_LENGTH_OF_INPUT + ); +} + +#[test] +fn test_zero_on_nonzero_integer_u8() { + let bytes = &[0]; + assert_eq!( + std::num::NonZeroU8::try_from_slice(bytes) + .unwrap_err() + .to_string(), + ERROR_INVALID_ZERO_VALUE + ); +} + +#[test] +fn test_zero_on_nonzero_integer_u32() { + let bytes = &[0; 4]; + assert_eq!( + std::num::NonZeroU32::try_from_slice(bytes) + .unwrap_err() + .to_string(), + ERROR_INVALID_ZERO_VALUE + ); +} + +#[test] +fn test_zero_on_nonzero_integer_i64() { + let bytes = &[0; 8]; + assert_eq!( + std::num::NonZeroI64::try_from_slice(bytes) + .unwrap_err() + .to_string(), + 
ERROR_INVALID_ZERO_VALUE + ); +} + +#[test] +fn test_zero_on_nonzero_integer_usize() { + let bytes = &[0; 8]; + assert_eq!( + std::num::NonZeroUsize::try_from_slice(bytes) + .unwrap_err() + .to_string(), + ERROR_INVALID_ZERO_VALUE + ); +} + +#[test] +fn test_zero_on_nonzero_integer_missing_byte() { + let bytes = &[0; 7]; + assert_eq!( + std::num::NonZeroUsize::try_from_slice(bytes) + .unwrap_err() + .to_string(), + ERROR_UNEXPECTED_LENGTH_OF_INPUT + ); +} diff --git a/borsh/tests/tokio_test_generic_struct.rs b/borsh/tests/tokio_test_generic_struct.rs new file mode 100644 index 000000000..f48056e1d --- /dev/null +++ b/borsh/tests/tokio_test_generic_struct.rs @@ -0,0 +1,34 @@ +use core::marker::PhantomData; + +use borsh::{BorshDeserialize, BorshSerialize}; + +#[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)] +struct A { + x: Vec, + y: String, + b: B, + pd: PhantomData, + c: std::result::Result, + d: [u64; 5], +} + +#[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)] +enum B { + X { f: Vec }, + Y(G), +} + +#[test] +fn test_generic_struct() { + let a = A:: { + x: vec!["foo".to_string(), "bar".to_string()], + pd: Default::default(), + y: "world".to_string(), + b: B::X { f: vec![1, 2] }, + c: Err("error".to_string()), + d: [0, 1, 2, 3, 4], + }; + let data = a.try_to_vec().unwrap(); + let actual_a = A::::try_from_slice(&data).unwrap(); + assert_eq!(a, actual_a); +} diff --git a/borsh/tests/tokio_test_hash_map.rs b/borsh/tests/tokio_test_hash_map.rs new file mode 100644 index 000000000..a2e4f4f78 --- /dev/null +++ b/borsh/tests/tokio_test_hash_map.rs @@ -0,0 +1,42 @@ +#[cfg(feature = "std")] +use core::hash::BuildHasher; +#[cfg(feature = "std")] +use std::collections::hash_map::{DefaultHasher, RandomState}; + +use borsh::maybestd::collections::HashMap; +use borsh::{BorshDeserialize, BorshSerialize}; + +#[test] +fn test_default_hashmap() { + let mut map = HashMap::new(); + map.insert("foo".to_string(), "bar".to_string()); + 
map.insert("one".to_string(), "two".to_string()); + + let data = map.try_to_vec().unwrap(); + let actual_map = HashMap::::try_from_slice(&data).unwrap(); + assert_eq!(map, actual_map); +} + +#[derive(Default)] +#[cfg(feature = "std")] +struct NewHasher(RandomState); + +#[cfg(feature = "std")] +impl BuildHasher for NewHasher { + type Hasher = DefaultHasher; + fn build_hasher(&self) -> DefaultHasher { + self.0.build_hasher() + } +} + +#[test] +#[cfg(feature = "std")] +fn test_generic_hash_hashmap() { + let mut map = HashMap::with_hasher(NewHasher::default()); + map.insert("foo".to_string(), "bar".to_string()); + map.insert("one".to_string(), "two".to_string()); + + let data = map.try_to_vec().unwrap(); + let actual_map = HashMap::::try_from_slice(&data).unwrap(); + assert_eq!(map, actual_map); +} diff --git a/borsh/tests/tokio_test_macro_namespace_collisions.rs b/borsh/tests/tokio_test_macro_namespace_collisions.rs new file mode 100644 index 000000000..f228ae5b9 --- /dev/null +++ b/borsh/tests/tokio_test_macro_namespace_collisions.rs @@ -0,0 +1,13 @@ +// Borsh macros should not collide with the local modules: +// https://github.com/near/borsh-rs/issues/11 +mod std {} +mod core {} + +#[derive(borsh::BorshSerialize, borsh::BorshDeserialize)] +struct A; + +#[derive(borsh::BorshSerialize, borsh::BorshDeserialize)] +enum B { + C, + D, +} diff --git a/borsh/tests/tokio_test_nonzero_integers.rs b/borsh/tests/tokio_test_nonzero_integers.rs new file mode 100644 index 000000000..a0b0a4692 --- /dev/null +++ b/borsh/tests/tokio_test_nonzero_integers.rs @@ -0,0 +1,32 @@ +use borsh::BorshDeserialize; +use std::num::*; + +#[test] +fn test_nonzero_integer_u8() { + let bytes = &[1]; + assert_eq!(NonZeroU8::try_from_slice(bytes).unwrap().get(), 1); +} + +#[test] +fn test_nonzero_integer_u32() { + let bytes = &[255, 0, 0, 0]; + assert_eq!(NonZeroU32::try_from_slice(bytes).unwrap().get(), 255); +} + +#[test] +fn test_nonzero_integer_usize() { + let bytes = &[1, 1, 0, 0, 0, 0, 0, 0]; + 
assert_eq!(NonZeroUsize::try_from_slice(bytes).unwrap().get(), 257); +} + +#[test] +fn test_nonzero_integer_i64() { + let bytes = &[255; 8]; + assert_eq!(NonZeroI64::try_from_slice(bytes).unwrap().get(), -1); +} + +#[test] +fn test_nonzero_integer_i16b() { + let bytes = &[0, 0b1000_0000]; + assert_eq!(NonZeroI16::try_from_slice(bytes).unwrap().get(), i16::MIN); +} diff --git a/borsh/tests/tokio_test_primitives.rs b/borsh/tests/tokio_test_primitives.rs new file mode 100644 index 000000000..df3dd2748 --- /dev/null +++ b/borsh/tests/tokio_test_primitives.rs @@ -0,0 +1,22 @@ +use borsh::{BorshDeserialize, BorshSerialize}; + +macro_rules! test_primitive { + ($test_name: ident, $v: expr, $t: ty) => { + #[test] + fn $test_name() { + let expected: $t = $v; + let buf = expected.try_to_vec().unwrap(); + let actual = <$t>::try_from_slice(&buf).expect("failed to deserialize"); + assert_eq!(actual, expected); + } + }; +} + +test_primitive!(test_isize_neg, -100isize, isize); +test_primitive!(test_isize_pos, 100isize, isize); +test_primitive!(test_isize_min, isize::min_value(), isize); +test_primitive!(test_isize_max, isize::max_value(), isize); + +test_primitive!(test_usize, 100usize, usize); +test_primitive!(test_usize_min, usize::min_value(), usize); +test_primitive!(test_usize_max, usize::max_value(), usize); diff --git a/borsh/tests/tokio_test_rc.rs b/borsh/tests/tokio_test_rc.rs new file mode 100644 index 000000000..38fc7e304 --- /dev/null +++ b/borsh/tests/tokio_test_rc.rs @@ -0,0 +1,21 @@ +#![cfg(feature = "rc")] + +use borsh::maybestd::rc::Rc; +use borsh::maybestd::sync::Arc; +use borsh::{BorshDeserialize, BorshSerialize}; + +#[test] +fn test_rc_roundtrip() { + let value = Rc::new(8u8); + let serialized = value.try_to_vec().unwrap(); + let deserialized = Rc::::try_from_slice(&serialized).unwrap(); + assert_eq!(value, deserialized); +} + +#[test] +fn test_arc_roundtrip() { + let value = Arc::new(8u8); + let serialized = value.try_to_vec().unwrap(); + let deserialized = 
Arc::::try_from_slice(&serialized).unwrap(); + assert_eq!(value, deserialized); +} diff --git a/borsh/tests/tokio_test_schema_enums.rs b/borsh/tests/tokio_test_schema_enums.rs new file mode 100644 index 000000000..933d7c80b --- /dev/null +++ b/borsh/tests/tokio_test_schema_enums.rs @@ -0,0 +1,209 @@ +#![allow(dead_code)] // Local structures do not have their fields used. +use borsh::maybestd::collections::HashMap; +use borsh::schema::*; +use borsh::schema_helpers::{try_from_slice_with_schema, try_to_vec_with_schema}; + +macro_rules! map( + () => { HashMap::new() }; + { $($key:expr => $value:expr),+ } => { + { + let mut m = HashMap::new(); + $( + m.insert($key.to_string(), $value); + )+ + m + } + }; +); + +#[test] +pub fn simple_enum() { + #[derive(borsh::BorshSchema)] + enum A { + Bacon, + Eggs, + } + assert_eq!("A".to_string(), A::declaration()); + let mut defs = Default::default(); + A::add_definitions_recursively(&mut defs); + assert_eq!( + map! { + "ABacon" => Definition::Struct{ fields: Fields::Empty }, + "AEggs" => Definition::Struct{ fields: Fields::Empty }, + "A" => Definition::Enum { variants: vec![("Bacon".to_string(), "ABacon".to_string()), ("Eggs".to_string(), "AEggs".to_string())]} + }, + defs + ); +} + +#[test] +pub fn single_field_enum() { + #[derive(borsh::BorshSchema)] + enum A { + Bacon, + } + assert_eq!("A".to_string(), A::declaration()); + let mut defs = Default::default(); + A::add_definitions_recursively(&mut defs); + assert_eq!( + map! 
{ + "ABacon" => Definition::Struct {fields: Fields::Empty}, + "A" => Definition::Enum { variants: vec![("Bacon".to_string(), "ABacon".to_string())]} + }, + defs + ); +} + +#[test] +pub fn complex_enum_with_schema() { + #[derive( + borsh::BorshSchema, + Default, + borsh::BorshSerialize, + borsh::BorshDeserialize, + PartialEq, + Debug, + )] + struct Tomatoes; + #[derive( + borsh::BorshSchema, + Default, + borsh::BorshSerialize, + borsh::BorshDeserialize, + PartialEq, + Debug, + )] + struct Cucumber; + #[derive( + borsh::BorshSchema, + Default, + borsh::BorshSerialize, + borsh::BorshDeserialize, + PartialEq, + Debug, + )] + struct Oil; + #[derive( + borsh::BorshSchema, + Default, + borsh::BorshSerialize, + borsh::BorshDeserialize, + PartialEq, + Debug, + )] + struct Wrapper; + #[derive( + borsh::BorshSchema, + Default, + borsh::BorshSerialize, + borsh::BorshDeserialize, + PartialEq, + Debug, + )] + struct Filling; + #[derive( + borsh::BorshSchema, borsh::BorshSerialize, borsh::BorshDeserialize, PartialEq, Debug, + )] + enum A { + Bacon, + Eggs, + Salad(Tomatoes, Cucumber, Oil), + Sausage { wrapper: Wrapper, filling: Filling }, + } + + impl Default for A { + fn default() -> Self { + A::Sausage { + wrapper: Default::default(), + filling: Default::default(), + } + } + } + // First check schema. + assert_eq!("A".to_string(), A::declaration()); + let mut defs = Default::default(); + A::add_definitions_recursively(&mut defs); + assert_eq!( + map! 
{ + "Cucumber" => Definition::Struct {fields: Fields::Empty}, + "ASalad" => Definition::Struct{ fields: Fields::UnnamedFields(vec!["Tomatoes".to_string(), "Cucumber".to_string(), "Oil".to_string()])}, + "ABacon" => Definition::Struct {fields: Fields::Empty}, + "Oil" => Definition::Struct {fields: Fields::Empty}, + "A" => Definition::Enum{ variants: vec![ + ("Bacon".to_string(), "ABacon".to_string()), + ("Eggs".to_string(), "AEggs".to_string()), + ("Salad".to_string(), "ASalad".to_string()), + ("Sausage".to_string(), "ASausage".to_string())]}, + "Wrapper" => Definition::Struct {fields: Fields::Empty}, + "Tomatoes" => Definition::Struct {fields: Fields::Empty}, + "ASausage" => Definition::Struct { fields: Fields::NamedFields(vec![ + ("wrapper".to_string(), "Wrapper".to_string()), + ("filling".to_string(), "Filling".to_string()) + ])}, + "AEggs" => Definition::Struct {fields: Fields::Empty}, + "Filling" => Definition::Struct {fields: Fields::Empty} + }, + defs + ); + // Then check that we serialize and deserialize with schema. + let obj = A::default(); + let data = try_to_vec_with_schema(&obj).unwrap(); + let obj2: A = try_from_slice_with_schema(&data).unwrap(); + assert_eq!(obj, obj2); +} + +#[test] +pub fn complex_enum_generics() { + #[derive(borsh::BorshSchema)] + struct Tomatoes; + #[derive(borsh::BorshSchema)] + struct Cucumber; + #[derive(borsh::BorshSchema)] + struct Oil; + #[derive(borsh::BorshSchema)] + struct Wrapper; + #[derive(borsh::BorshSchema)] + struct Filling; + #[derive(borsh::BorshSchema)] + enum A { + Bacon, + Eggs, + Salad(Tomatoes, C, Oil), + Sausage { wrapper: W, filling: Filling }, + } + assert_eq!( + "A".to_string(), + >::declaration() + ); + let mut defs = Default::default(); + >::add_definitions_recursively(&mut defs); + assert_eq!( + map! 
{ + "Cucumber" => Definition::Struct {fields: Fields::Empty}, + "ASalad" => Definition::Struct{ + fields: Fields::UnnamedFields(vec!["Tomatoes".to_string(), "Cucumber".to_string(), "Oil".to_string()]) + }, + "ABacon" => Definition::Struct {fields: Fields::Empty}, + "Oil" => Definition::Struct {fields: Fields::Empty}, + "A" => Definition::Enum{ + variants: vec![ + ("Bacon".to_string(), "ABacon".to_string()), + ("Eggs".to_string(), "AEggs".to_string()), + ("Salad".to_string(), "ASalad".to_string()), + ("Sausage".to_string(), "ASausage".to_string()) + ] + }, + "Wrapper" => Definition::Struct {fields: Fields::Empty}, + "Tomatoes" => Definition::Struct {fields: Fields::Empty}, + "ASausage" => Definition::Struct { + fields: Fields::NamedFields(vec![ + ("wrapper".to_string(), "Wrapper".to_string()), + ("filling".to_string(), "Filling".to_string()) + ]) + }, + "AEggs" => Definition::Struct {fields: Fields::Empty}, + "Filling" => Definition::Struct {fields: Fields::Empty} + }, + defs + ); +} diff --git a/borsh/tests/tokio_test_schema_nested.rs b/borsh/tests/tokio_test_schema_nested.rs new file mode 100644 index 000000000..ea2f181a6 --- /dev/null +++ b/borsh/tests/tokio_test_schema_nested.rs @@ -0,0 +1,83 @@ +#![allow(dead_code)] // Local structures do not have their fields used. +use borsh::maybestd::collections::HashMap; +use borsh::schema::*; + +macro_rules! map( + () => { HashMap::new() }; + { $($key:expr => $value:expr),+ } => { + { + let mut m = HashMap::new(); + $( + m.insert($key.to_string(), $value); + )+ + m + } + }; +); + +// Checks that recursive definitions work. Also checks that re-instantiations of templated types work. 
+#[test] +pub fn duplicated_instantiations() { + #[derive(borsh::BorshSchema)] + struct Tomatoes; + #[derive(borsh::BorshSchema)] + struct Cucumber; + #[derive(borsh::BorshSchema)] + struct Oil { + seeds: HashMap, + liquid: Option, + } + #[derive(borsh::BorshSchema)] + struct Wrapper { + foo: Option, + bar: Box>, + } + #[derive(borsh::BorshSchema)] + struct Filling; + #[derive(borsh::BorshSchema)] + enum A { + Bacon, + Eggs, + Salad(Tomatoes, C, Oil), + Sausage { wrapper: W, filling: Filling }, + } + assert_eq!( + "A>".to_string(), + >>::declaration() + ); + let mut defs = Default::default(); + >>::add_definitions_recursively(&mut defs); + assert_eq!( + map! { + "A>" => Definition::Enum {variants: vec![ + ("Bacon".to_string(), "ABacon>".to_string()), + ("Eggs".to_string(), "AEggs>".to_string()), + ("Salad".to_string(), "ASalad>".to_string()), + ("Sausage".to_string(), "ASausage>".to_string()) + ]}, + "A" => Definition::Enum {variants: vec![ + ("Bacon".to_string(), "ABacon".to_string()), + ("Eggs".to_string(), "AEggs".to_string()), + ("Salad".to_string(), "ASalad".to_string()), + ("Sausage".to_string(), "ASausage".to_string())]}, + "ABacon>" => Definition::Struct {fields: Fields::Empty}, + "ABacon" => Definition::Struct {fields: Fields::Empty}, + "AEggs>" => Definition::Struct {fields: Fields::Empty}, + "AEggs" => Definition::Struct {fields: Fields::Empty}, + "ASalad>" => Definition::Struct {fields: Fields::UnnamedFields(vec!["Tomatoes".to_string(), "Cucumber".to_string(), "Oil".to_string()])}, + "ASalad" => Definition::Struct { fields: Fields::UnnamedFields( vec!["Tomatoes".to_string(), "string".to_string(), "Oil".to_string() ])}, + "ASausage>" => Definition::Struct {fields: Fields::NamedFields(vec![("wrapper".to_string(), "Wrapper".to_string()), ("filling".to_string(), "Filling".to_string())])}, + "ASausage" => Definition::Struct{ fields: Fields::NamedFields(vec![("wrapper".to_string(), "string".to_string()), ("filling".to_string(), "Filling".to_string())])}, + 
"Cucumber" => Definition::Struct {fields: Fields::Empty}, + "Filling" => Definition::Struct {fields: Fields::Empty}, + "HashMap" => Definition::Sequence { elements: "Tuple".to_string()}, + "Oil" => Definition::Struct { fields: Fields::NamedFields(vec![("seeds".to_string(), "HashMap".to_string()), ("liquid".to_string(), "Option".to_string())])}, + "Option" => Definition::Enum {variants: vec![("None".to_string(), "nil".to_string()), ("Some".to_string(), "string".to_string())]}, + "Option" => Definition::Enum { variants: vec![("None".to_string(), "nil".to_string()), ("Some".to_string(), "u64".to_string())]}, + "Tomatoes" => Definition::Struct {fields: Fields::Empty}, + "Tuple" => Definition::Tuple {elements: vec!["u64".to_string(), "string".to_string()]}, + "Wrapper" => Definition::Struct{ fields: Fields::NamedFields(vec![("foo".to_string(), "Option".to_string()), ("bar".to_string(), "A".to_string())])} + }, + defs + ); +} diff --git a/borsh/tests/tokio_test_schema_primitives.rs b/borsh/tests/tokio_test_schema_primitives.rs new file mode 100644 index 000000000..bacaf1e64 --- /dev/null +++ b/borsh/tests/tokio_test_schema_primitives.rs @@ -0,0 +1,25 @@ +use borsh::schema::*; + +#[test] +fn isize_schema() { + let schema = isize::schema_container(); + assert_eq!( + schema, + BorshSchemaContainer { + declaration: "i64".to_string(), + definitions: Default::default() + } + ) +} + +#[test] +fn usize_schema() { + let schema = usize::schema_container(); + assert_eq!( + schema, + BorshSchemaContainer { + declaration: "u64".to_string(), + definitions: Default::default() + } + ) +} diff --git a/borsh/tests/tokio_test_schema_structs.rs b/borsh/tests/tokio_test_schema_structs.rs new file mode 100644 index 000000000..ea258563c --- /dev/null +++ b/borsh/tests/tokio_test_schema_structs.rs @@ -0,0 +1,155 @@ +use borsh::maybestd::collections::HashMap; +use borsh::schema::*; + +macro_rules! 
map( + () => { HashMap::new() }; + { $($key:expr => $value:expr),+ } => { + { + let mut m = HashMap::new(); + $( + m.insert($key.to_string(), $value); + )+ + m + } + }; +); + +#[test] +pub fn unit_struct() { + #[derive(borsh::BorshSchema)] + struct A; + assert_eq!("A".to_string(), A::declaration()); + let mut defs = Default::default(); + A::add_definitions_recursively(&mut defs); + assert_eq!( + map! { + "A" => Definition::Struct {fields: Fields::Empty} + }, + defs + ); +} + +#[test] +pub fn simple_struct() { + #[derive(borsh::BorshSchema)] + struct A { + _f1: u64, + _f2: String, + } + assert_eq!("A".to_string(), A::declaration()); + let mut defs = Default::default(); + A::add_definitions_recursively(&mut defs); + assert_eq!( + map! { + "A" => Definition::Struct{ fields: Fields::NamedFields(vec![ + ("_f1".to_string(), "u64".to_string()), + ("_f2".to_string(), "string".to_string()) + ])} + }, + defs + ); +} + +#[test] +pub fn boxed() { + #[derive(borsh::BorshSchema)] + struct A { + _f1: Box, + _f2: Box, + _f3: Box<[u8]>, + } + assert_eq!("A".to_string(), A::declaration()); + let mut defs = Default::default(); + A::add_definitions_recursively(&mut defs); + assert_eq!( + map! { + "Vec" => Definition::Sequence { elements: "u8".to_string() }, + "A" => Definition::Struct{ fields: Fields::NamedFields(vec![ + ("_f1".to_string(), "u64".to_string()), + ("_f2".to_string(), "string".to_string()), + ("_f3".to_string(), "Vec".to_string()) + ])} + }, + defs + ); +} + +#[test] +pub fn wrapper_struct() { + #[derive(borsh::BorshSchema)] + struct A(T); + assert_eq!("A".to_string(), >::declaration()); + let mut defs = Default::default(); + >::add_definitions_recursively(&mut defs); + assert_eq!( + map! 
{ + "A" => Definition::Struct {fields: Fields::UnnamedFields(vec!["u64".to_string()])} + }, + defs + ); +} + +#[test] +pub fn tuple_struct() { + #[derive(borsh::BorshSchema)] + struct A(u64, String); + assert_eq!("A".to_string(), A::declaration()); + let mut defs = Default::default(); + A::add_definitions_recursively(&mut defs); + assert_eq!( + map! { + "A" => Definition::Struct {fields: Fields::UnnamedFields(vec![ + "u64".to_string(), "string".to_string() + ])} + }, + defs + ); +} + +#[test] +pub fn tuple_struct_params() { + #[derive(borsh::BorshSchema)] + struct A(K, V); + assert_eq!( + "A".to_string(), + >::declaration() + ); + let mut defs = Default::default(); + >::add_definitions_recursively(&mut defs); + assert_eq!( + map! { + "A" => Definition::Struct { fields: Fields::UnnamedFields(vec![ + "u64".to_string(), "string".to_string() + ])} + }, + defs + ); +} + +#[test] +pub fn simple_generics() { + #[derive(borsh::BorshSchema)] + struct A { + _f1: HashMap, + _f2: String, + } + assert_eq!( + "A".to_string(), + >::declaration() + ); + let mut defs = Default::default(); + >::add_definitions_recursively(&mut defs); + assert_eq!( + map! { + "A" => Definition::Struct { + fields: Fields::NamedFields(vec![ + ("_f1".to_string(), "HashMap".to_string()), + ("_f2".to_string(), "string".to_string()) + ]) + }, + "HashMap" => Definition::Sequence {elements: "Tuple".to_string()}, + "Tuple" => Definition::Tuple{elements: vec!["u64".to_string(), "string".to_string()]} + }, + defs + ); +} diff --git a/borsh/tests/tokio_test_schema_tuple.rs b/borsh/tests/tokio_test_schema_tuple.rs new file mode 100644 index 000000000..c5e6367a8 --- /dev/null +++ b/borsh/tests/tokio_test_schema_tuple.rs @@ -0,0 +1,28 @@ +use borsh::maybestd::collections::HashMap; +use borsh::schema::*; + +macro_rules! 
map( + () => { HashMap::new() }; + { $($key:expr => $value:expr),+ } => { + { + let mut m = HashMap::new(); + $( + m.insert($key.to_string(), $value); + )+ + m + } + }; +); + +#[test] +fn test_unary_tuple_schema() { + assert_eq!("Tuple", <(bool,)>::declaration()); + let mut defs = Default::default(); + <(bool,)>::add_definitions_recursively(&mut defs); + assert_eq!( + map! { + "Tuple" => Definition::Tuple { elements: vec!["bool".to_string()] } + }, + defs + ); +} diff --git a/borsh/tests/tokio_test_simple_structs.rs b/borsh/tests/tokio_test_simple_structs.rs new file mode 100644 index 000000000..428fbb599 --- /dev/null +++ b/borsh/tests/tokio_test_simple_structs.rs @@ -0,0 +1,198 @@ +use borsh::maybestd::collections::{BTreeMap, BTreeSet, HashMap, HashSet, LinkedList, VecDeque}; +use borsh::{BorshDeserialize, BorshSerialize}; +use bytes::{Bytes, BytesMut}; + +#[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)] +#[borsh_init(init)] +struct A<'a> { + x: u64, + b: B, + y: f32, + z: String, + t: (String, u64), + m: HashMap, + s: HashSet, + btree_map_string: BTreeMap, + btree_set_u64: BTreeSet, + linked_list_string: LinkedList, + vec_deque_u64: VecDeque, + bytes: Bytes, + bytes_mut: BytesMut, + v: Vec, + w: Box<[u8]>, + box_str: Box, + i: [u8; 32], + u: std::result::Result, + lazy: Option, + c: std::borrow::Cow<'a, str>, + cow_arr: std::borrow::Cow<'a, [std::borrow::Cow<'a, str>]>, + range_u32: std::ops::Range, + #[borsh_skip] + skipped: Option, +} + +impl A<'_> { + pub fn init(&mut self) { + if let Some(v) = self.lazy.as_mut() { + *v *= 10; + } + } +} + +#[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)] +struct B { + x: u64, + y: i32, + c: C, +} + +#[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)] +enum C { + C1, + C2(u64), + C3(u64, u64), + C4 { x: u64, y: u64 }, + C5(D), +} + +#[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)] +struct D { + x: u64, +} + +#[derive(BorshSerialize)] +struct E<'a, 'b> { + a: &'a A<'b>, +} + 
+#[derive(BorshSerialize)] +struct F1<'a, 'b> { + aa: &'a [&'a A<'b>], +} + +#[derive(BorshDeserialize)] +struct F2<'b> { + aa: Vec>, +} + +#[derive(BorshSerialize, BorshDeserialize, PartialEq, Eq, Clone, Copy, Debug)] +enum X { + A, + B = 20, + C, + D, + E = 10, + F, +} + +#[test] +fn test_discriminant_serialization() { + let values = vec![X::A, X::B, X::C, X::D, X::E, X::F]; + for value in values { + assert_eq!(value.try_to_vec().unwrap(), [value as u8]); + } +} + +#[test] +fn test_discriminant_deserialization() { + let values = vec![X::A, X::B, X::C, X::D, X::E, X::F]; + for value in values { + assert_eq!( + ::try_from_slice(&[value as u8]).unwrap(), + value, + ); + } +} + +#[test] +#[should_panic = "Unexpected variant tag: 2"] +fn test_deserialize_invalid_discriminant() { + ::try_from_slice(&[2]).unwrap(); +} + +#[test] +fn test_simple_struct() { + let mut map: HashMap = HashMap::new(); + map.insert("test".into(), "test".into()); + let mut set: HashSet = HashSet::new(); + set.insert(std::u64::MAX); + let cow_arr = [ + std::borrow::Cow::Borrowed("Hello1"), + std::borrow::Cow::Owned("Hello2".to_string()), + ]; + let a = A { + x: 1, + b: B { + x: 2, + y: 3, + c: C::C5(D { x: 1 }), + }, + y: 4.0, + z: "123".to_string(), + t: ("Hello".to_string(), 10), + m: map.clone(), + s: set.clone(), + btree_map_string: map.clone().into_iter().collect(), + btree_set_u64: set.clone().into_iter().collect(), + linked_list_string: vec!["a".to_string(), "b".to_string()].into_iter().collect(), + vec_deque_u64: vec![1, 2, 3].into_iter().collect(), + bytes: vec![5, 4, 3, 2, 1].into(), + bytes_mut: BytesMut::from(&[1, 2, 3, 4, 5][..]), + v: vec!["qwe".to_string(), "zxc".to_string()], + w: vec![0].into_boxed_slice(), + box_str: Box::from("asd"), + i: [4u8; 32], + u: Ok("Hello".to_string()), + lazy: Some(5), + c: std::borrow::Cow::Borrowed("Hello"), + cow_arr: std::borrow::Cow::Borrowed(&cow_arr), + range_u32: 12..71, + skipped: Some(6), + }; + let encoded_a = a.try_to_vec().unwrap(); + 
let e = E { a: &a }; + let encoded_ref_a = e.try_to_vec().unwrap(); + assert_eq!(encoded_ref_a, encoded_a); + + let decoded_a = A::try_from_slice(&encoded_a).unwrap(); + let expected_a = A { + x: 1, + b: B { + x: 2, + y: 3, + c: C::C5(D { x: 1 }), + }, + y: 4.0, + z: a.z.clone(), + t: ("Hello".to_string(), 10), + m: map.clone(), + s: set.clone(), + btree_map_string: map.clone().into_iter().collect(), + btree_set_u64: set.clone().into_iter().collect(), + linked_list_string: vec!["a".to_string(), "b".to_string()].into_iter().collect(), + vec_deque_u64: vec![1, 2, 3].into_iter().collect(), + bytes: vec![5, 4, 3, 2, 1].into(), + bytes_mut: BytesMut::from(&[1, 2, 3, 4, 5][..]), + v: a.v.clone(), + w: a.w.clone(), + box_str: Box::from("asd"), + i: a.i, + u: Ok("Hello".to_string()), + lazy: Some(50), + c: std::borrow::Cow::Owned("Hello".to_string()), + cow_arr: std::borrow::Cow::Owned(vec![ + std::borrow::Cow::Borrowed("Hello1"), + std::borrow::Cow::Owned("Hello2".to_string()), + ]), + range_u32: 12..71, + skipped: None, + }; + + assert_eq!(expected_a, decoded_a); + + let f1 = F1 { aa: &[&a, &a] }; + let encoded_f1 = f1.try_to_vec().unwrap(); + let decoded_f2 = F2::try_from_slice(&encoded_f1).unwrap(); + assert_eq!(decoded_f2.aa.len(), 2); + assert!(decoded_f2.aa.iter().all(|f2_a| f2_a == &expected_a)); +} diff --git a/borsh/tests/tokio_test_strings.rs b/borsh/tests/tokio_test_strings.rs new file mode 100644 index 000000000..3d802ad9a --- /dev/null +++ b/borsh/tests/tokio_test_strings.rs @@ -0,0 +1,22 @@ +use borsh::{BorshDeserialize, BorshSerialize}; + +macro_rules! 
test_string { + ($test_name: ident, $str: expr) => { + #[test] + fn $test_name() { + let s = $str.to_string(); + let buf = s.try_to_vec().unwrap(); + let actual_s = ::try_from_slice(&buf).expect("failed to deserialize a string"); + assert_eq!(actual_s, s); + } + }; +} + +test_string!(test_empty_string, ""); +test_string!(test_a, "a"); +test_string!(test_hello_world, "hello world"); +test_string!(test_x_1024, "x".repeat(1024)); +test_string!(test_x_4096, "x".repeat(4096)); +test_string!(test_x_65535, "x".repeat(65535)); +test_string!(test_hello_1000, "hello world!".repeat(1000)); +test_string!(test_non_ascii, "💩"); diff --git a/borsh/tests/tokio_test_tuple.rs b/borsh/tests/tokio_test_tuple.rs new file mode 100644 index 000000000..c629dcfc2 --- /dev/null +++ b/borsh/tests/tokio_test_tuple.rs @@ -0,0 +1,9 @@ +use borsh::{BorshDeserialize, BorshSerialize}; + +#[test] +fn test_unary_tuple() { + let expected = (true,); + let buf = expected.try_to_vec().unwrap(); + let actual = <(bool,)>::try_from_slice(&buf).expect("failed to deserialize"); + assert_eq!(actual, expected); +} diff --git a/borsh/tests/tokio_test_vecs.rs b/borsh/tests/tokio_test_vecs.rs new file mode 100644 index 000000000..ceb01f783 --- /dev/null +++ b/borsh/tests/tokio_test_vecs.rs @@ -0,0 +1,32 @@ +use borsh::{BorshDeserialize, BorshSerialize}; + +macro_rules! test_vec { + ($v: expr, $t: ty) => { + let buf = $v.try_to_vec().unwrap(); + let actual_v: Vec<$t> = + BorshDeserialize::try_from_slice(&buf).expect("failed to deserialize"); + assert_eq!(actual_v, $v); + }; +} + +macro_rules! 
test_vecs { + ($test_name: ident, $el: expr, $t: ty) => { + #[test] + fn $test_name() { + test_vec!(Vec::<$t>::new(), $t); + test_vec!(vec![$el], $t); + test_vec!(vec![$el; 10], $t); + test_vec!(vec![$el; 100], $t); + test_vec!(vec![$el; 1000], $t); + test_vec!(vec![$el; 10000], $t); + } + }; +} + +test_vecs!(test_vec_u8, 100u8, u8); +test_vecs!(test_vec_i8, 100i8, i8); +test_vecs!(test_vec_u32, 1000000000u32, u32); +test_vecs!(test_vec_f32, 1000000000.0f32, f32); +test_vecs!(test_vec_string, "a".to_string(), String); +test_vecs!(test_vec_vec_u8, vec![100u8; 10], Vec); +test_vecs!(test_vec_vec_u32, vec![100u32; 10], Vec); diff --git a/borsh/tests/tokio_test_zero_size.rs b/borsh/tests/tokio_test_zero_size.rs new file mode 100644 index 000000000..b47943d40 --- /dev/null +++ b/borsh/tests/tokio_test_zero_size.rs @@ -0,0 +1,11 @@ +use borsh::BorshDeserialize; + +#[derive(BorshDeserialize, PartialEq, Debug)] +struct A; + +#[test] +fn test_deserialize_vector_to_many_zero_size_struct() { + let v = [0u8, 0u8, 0u8, 64u8]; + let a = Vec::::try_from_slice(&v).unwrap(); + assert_eq!(A {}, a[usize::pow(2, 30) - 1]) +}