From 9ea4aba996f3dd4c859f4001f62e67178577301f Mon Sep 17 00:00:00 2001 From: Jonathan LEI Date: Wed, 1 Nov 2023 05:28:54 +0000 Subject: [PATCH] perf: drop useless allocation in deserialization --- starknet-core/src/serde/byte_array.rs | 29 ++++++-- starknet-core/src/serde/num_hex.rs | 28 +++++-- .../src/serde/unsigned_field_element.rs | 74 +++++++++++++++---- starknet-core/src/types/eth_address.rs | 25 +++++-- starknet-ff/src/lib.rs | 30 +++++++- 5 files changed, 148 insertions(+), 38 deletions(-) diff --git a/starknet-core/src/serde/byte_array.rs b/starknet-core/src/serde/byte_array.rs index 79d9411e..d5ddd470 100644 --- a/starknet-core/src/serde/byte_array.rs +++ b/starknet-core/src/serde/byte_array.rs @@ -1,8 +1,10 @@ pub mod base64 { - use alloc::{format, string::String, vec::Vec}; + use alloc::{fmt::Formatter, format, vec::Vec}; use base64::{engine::general_purpose::STANDARD, Engine}; - use serde::{Deserialize, Deserializer, Serializer}; + use serde::{de::Visitor, Deserializer, Serializer}; + + struct Base64Visitor; pub fn serialize(value: T, serializer: S) -> Result where @@ -16,12 +18,23 @@ pub mod base64 { where D: Deserializer<'de>, { - let value = String::deserialize(deserializer)?; - match STANDARD.decode(value) { - Ok(value) => Ok(value), - Err(err) => Err(serde::de::Error::custom(format!( - "invalid base64 string: {err}" - ))), + deserializer.deserialize_any(Base64Visitor) + } + + impl<'de> Visitor<'de> for Base64Visitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut Formatter) -> alloc::fmt::Result { + write!(formatter, "string") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + STANDARD + .decode(v) + .map_err(|err| serde::de::Error::custom(format!("invalid base64 string: {err}"))) } } } diff --git a/starknet-core/src/serde/num_hex.rs b/starknet-core/src/serde/num_hex.rs index 9aaf9076..985b9053 100644 --- a/starknet-core/src/serde/num_hex.rs +++ b/starknet-core/src/serde/num_hex.rs @@ -1,7 +1,9 @@ pub mod u64 { - use alloc::{format, string::String}; + use alloc::{fmt::Formatter, format}; - use serde::{Deserialize, Deserializer, Serializer}; + use serde::{de::Visitor, Deserializer, Serializer}; + + struct NumHexVisitor; pub fn serialize(value: &u64, serializer: S) -> Result where @@ -14,12 +16,22 @@ pub mod u64 { where D: Deserializer<'de>, { - let value = String::deserialize(deserializer)?; - match u64::from_str_radix(value.trim_start_matches("0x"), 16) { - Ok(value) => Ok(value), - Err(err) => Err(serde::de::Error::custom(format!( - "invalid u64 hex string: {err}" - ))), + deserializer.deserialize_any(NumHexVisitor) + } + + impl<'de> Visitor<'de> for NumHexVisitor { + type Value = u64; + + fn expecting(&self, formatter: &mut Formatter) -> alloc::fmt::Result { + write!(formatter, "string") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + u64::from_str_radix(v.trim_start_matches("0x"), 16) + .map_err(|err| serde::de::Error::custom(format!("invalid u64 hex string: {err}"))) } } } diff --git a/starknet-core/src/serde/unsigned_field_element.rs b/starknet-core/src/serde/unsigned_field_element.rs index a130366c..1a4d8a53 100644 --- a/starknet-core/src/serde/unsigned_field_element.rs +++ b/starknet-core/src/serde/unsigned_field_element.rs @@ -1,6 +1,9 @@ -use alloc::{format, string::String}; +use alloc::{fmt::Formatter, format}; -use serde::{de::Error as DeError, Deserialize, Deserializer, Serializer}; +use serde::{ + de::{Error as DeError, Visitor}, + Deserializer, Serializer, +}; use serde_with::{DeserializeAs, SerializeAs}; use crate::types::FieldElement; @@ -11,6 +14,10 @@ pub struct UfeHexOption; pub struct UfePendingBlockHash; +struct UfeHexVisitor; +struct UfeHexOptionVisitor; +struct UfePendingBlockHashVisitor; + impl SerializeAs for UfeHex { fn serialize_as(value: &FieldElement, serializer: S) -> Result where @@ -25,11 +32,23 @@ impl<'de> DeserializeAs<'de, FieldElement> for UfeHex { where D: Deserializer<'de>, { - let value = String::deserialize(deserializer)?; - match FieldElement::from_hex_be(&value) { - Ok(value) => Ok(value), - Err(err) => Err(DeError::custom(format!("invalid hex string: {err}"))), - } + deserializer.deserialize_any(UfeHexVisitor) + } +} + +impl<'de> Visitor<'de> for UfeHexVisitor { + type Value = FieldElement; + + fn expecting(&self, formatter: &mut Formatter) -> alloc::fmt::Result { + write!(formatter, "string") + } + + fn visit_str(self, v: &str) -> Result + where + E: DeError, + { + FieldElement::from_hex_be(v) + .map_err(|err| DeError::custom(format!("invalid hex string: {err}"))) } } @@ -50,10 +69,24 @@ impl<'de> DeserializeAs<'de, Option> for UfeHexOption { where D: Deserializer<'de>, { - let value = String::deserialize(deserializer)?; - match value.as_str() { + deserializer.deserialize_any(UfeHexOptionVisitor) + } +} + +impl<'de> Visitor<'de> for UfeHexOptionVisitor { + type Value = Option; + + fn expecting(&self, formatter: &mut Formatter) -> alloc::fmt::Result { + write!(formatter, "string") + } + + fn visit_str(self, v: &str) -> Result + where + E: DeError, + { + match v { "" => Ok(None), - _ => match FieldElement::from_hex_be(&value) { + _ => match FieldElement::from_hex_be(v) { Ok(value) => Ok(Some(value)), Err(err) => Err(DeError::custom(format!("invalid hex string: {err}"))), }, @@ -79,11 +112,25 @@ impl<'de> DeserializeAs<'de, Option> for UfePendingBlockHash { where D: Deserializer<'de>, { - let value = String::deserialize(deserializer)?; - if value.is_empty() || value == "pending" || value == "None" { + deserializer.deserialize_any(UfePendingBlockHashVisitor) + } +} + +impl<'de> Visitor<'de> for UfePendingBlockHashVisitor { + type Value = Option; + + fn expecting(&self, formatter: &mut Formatter) -> alloc::fmt::Result { + write!(formatter, "string") + } + + fn visit_str(self, v: &str) -> Result + where + E: DeError, + { + if v.is_empty() || v == "pending" || v == "None" { Ok(None) } else { - match FieldElement::from_hex_be(&value) { + match FieldElement::from_hex_be(v) { Ok(value) => Ok(Some(value)), Err(err) => Err(DeError::custom(format!("invalid hex string: {err}"))), } @@ -95,6 +142,7 @@ impl<'de> DeserializeAs<'de, Option> for UfePendingBlockHash { mod tests { use super::*; + use serde::Deserialize; use serde_with::serde_as; #[serde_as] diff --git a/starknet-core/src/types/eth_address.rs b/starknet-core/src/types/eth_address.rs index 53d38ad4..21d483cd 100644 --- a/starknet-core/src/types/eth_address.rs +++ b/starknet-core/src/types/eth_address.rs @@ -1,7 +1,7 @@ -use alloc::{format, string::String}; +use alloc::{fmt::Formatter, format}; use core::str::FromStr; -use serde::{Deserialize, Serialize}; +use serde::{de::Visitor, Deserialize, Serialize}; use starknet_ff::FieldElement; // 0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF @@ -17,6 +17,8 @@ pub struct EthAddress { inner: [u8; 20], } +struct EthAddressVisitor; + mod errors { use core::fmt::{Display, Formatter, Result}; @@ -84,9 +86,22 @@ impl<'de> Deserialize<'de> for EthAddress { where D: serde::Deserializer<'de>, { - let value = String::deserialize(deserializer)?; - value - .parse() + deserializer.deserialize_any(EthAddressVisitor) + } +} + +impl<'de> Visitor<'de> for EthAddressVisitor { + type Value = EthAddress; + + fn expecting(&self, formatter: &mut Formatter) -> alloc::fmt::Result { + write!(formatter, "string") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + v.parse() .map_err(|err| serde::de::Error::custom(format!("{}", err))) } } diff --git a/starknet-ff/src/lib.rs b/starknet-ff/src/lib.rs index 4a0ff4f4..d3032cf3 100644 --- a/starknet-ff/src/lib.rs +++ b/starknet-ff/src/lib.rs @@ -552,10 +552,18 @@ impl fmt::UpperHex for FieldElement { #[cfg(feature = "serde")] mod serde_field_element { + #[cfg(feature = "std")] + use core::fmt::{Formatter, Result as FmtResult}; + use super::*; #[cfg(not(feature = "std"))] - use alloc::string::{String, ToString}; - use serde::{Deserialize, Serialize}; + use alloc::{ + fmt::{Formatter, Result as FmtResult}, + string::ToString, + }; + use serde::{de::Visitor, Deserialize, Serialize}; + + struct FieldElementVisitor; impl Serialize for FieldElement { fn serialize(&self, serializer: S) -> Result @@ -571,8 +579,22 @@ mod serde_field_element { where D: serde::Deserializer<'de>, { - let value = String::deserialize(deserializer)?; - Self::from_str(&value).map_err(serde::de::Error::custom) + deserializer.deserialize_any(FieldElementVisitor) + } + } + + impl<'de> Visitor<'de> for FieldElementVisitor { + type Value = FieldElement; + + fn expecting(&self, formatter: &mut Formatter) -> FmtResult { + write!(formatter, "string") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + FieldElement::from_str(v).map_err(serde::de::Error::custom) } } }