Skip to content

Commit

Permalink
improv serde impl
Browse files Browse the repository at this point in the history
  • Loading branch information
kariy committed Dec 21, 2023
1 parent 082f556 commit 56cae56
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 12 deletions.
20 changes: 17 additions & 3 deletions starknet-core/src/serde/num_hex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,29 @@ pub mod u64 {
where
S: Serializer,
{
serializer.serialize_str(&format!("{value:#x}"))
if serializer.is_human_readable() {
serializer.serialize_str(&format!("{value:#x}"))
} else {
serializer.serialize_bytes(&value.to_be_bytes())
}
}

pub fn deserialize<'de, D>(deserializer: D) -> Result<u64, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_any(NumHexVisitor)
if deserializer.is_human_readable() {
deserializer.deserialize_any(NumHexVisitor)
} else {
deserializer.deserialize_bytes(NumHexVisitor)
}
}

impl<'de> Visitor<'de> for NumHexVisitor {
type Value = u64;

fn expecting(&self, formatter: &mut Formatter) -> alloc::fmt::Result {
write!(formatter, "string")
write!(formatter, "string, or an array of u8")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
Expand All @@ -33,5 +41,11 @@ pub mod u64 {
u64::from_str_radix(v.trim_start_matches("0x"), 16)
.map_err(|err| serde::de::Error::custom(format!("invalid u64 hex string: {err}")))
}

fn visit_bytes<E: serde::de::Error>(self, v: &[u8]) -> Result<Self::Value, E> {
<[u8; std::mem::size_of::<u64>()]>::try_from(v)
.map(u64::from_be_bytes)
.map_err(serde::de::Error::custom)
}
}
}
19 changes: 16 additions & 3 deletions starknet-core/src/serde/unsigned_field_element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ impl SerializeAs<FieldElement> for UfeHex {
where
S: Serializer,
{
serializer.serialize_str(&format!("{value:#x}"))
if serializer.is_human_readable() {
serializer.serialize_str(&format!("{value:#x}"))
} else {
serializer.serialize_bytes(&value.to_bytes_be())
}
}
}

Expand All @@ -32,15 +36,19 @@ impl<'de> DeserializeAs<'de, FieldElement> for UfeHex {
where
D: Deserializer<'de>,
{
deserializer.deserialize_any(UfeHexVisitor)
if deserializer.is_human_readable() {
deserializer.deserialize_any(UfeHexVisitor)
} else {
deserializer.deserialize_bytes(UfeHexVisitor)
}
}
}

impl<'de> Visitor<'de> for UfeHexVisitor {
type Value = FieldElement;

fn expecting(&self, formatter: &mut Formatter) -> alloc::fmt::Result {
write!(formatter, "string")
write!(formatter, "a hex string, or an array of u8")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
Expand All @@ -50,6 +58,11 @@ impl<'de> Visitor<'de> for UfeHexVisitor {
FieldElement::from_hex_be(v)
.map_err(|err| DeError::custom(format!("invalid hex string: {err}")))
}

fn visit_bytes<E: serde::de::Error>(self, v: &[u8]) -> Result<Self::Value, E> {
let buf = <[u8; 32]>::try_from(v).map_err(serde::de::Error::custom)?;
FieldElement::from_bytes_be(&buf).map_err(serde::de::Error::custom)
}
}

impl SerializeAs<Option<FieldElement>> for UfeHexOption {
Expand Down
20 changes: 17 additions & 3 deletions starknet-core/src/types/hash_256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@ impl Serialize for Hash256 {
where
S: serde::Serializer,
{
serializer.serialize_str(&format!("0x{}", hex::encode(self.inner)))
if serializer.is_human_readable() {
serializer.serialize_str(&format!("0x{}", hex::encode(self.inner)))
} else {
serializer.serialize_bytes(self.as_bytes())
}
}
}

Expand All @@ -87,15 +91,19 @@ impl<'de> Deserialize<'de> for Hash256 {
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_any(Hash256Visitor)
if deserializer.is_human_readable() {
deserializer.deserialize_any(Hash256Visitor)
} else {
deserializer.deserialize_bytes(Hash256Visitor)
}
}
}

impl<'de> Visitor<'de> for Hash256Visitor {
type Value = Hash256;

fn expecting(&self, formatter: &mut Formatter) -> alloc::fmt::Result {
write!(formatter, "string")
write!(formatter, "string, or an array of u8")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
Expand All @@ -105,6 +113,12 @@ impl<'de> Visitor<'de> for Hash256Visitor {
v.parse()
.map_err(|err| serde::de::Error::custom(format!("{}", err)))
}

fn visit_bytes<E: serde::de::Error>(self, v: &[u8]) -> Result<Self::Value, E> {
<[u8; HASH_256_BYTE_COUNT]>::try_from(v)
.map(Hash256::from_bytes)
.map_err(serde::de::Error::custom)
}
}

impl FromStr for Hash256 {
Expand Down
19 changes: 16 additions & 3 deletions starknet-ff/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,11 @@ mod serde_field_element {
where
S: serde::Serializer,
{
serializer.serialize_str(&ToString::to_string(&self))
if serializer.is_human_readable() {
serializer.serialize_str(&ToString::to_string(&self))
} else {
serializer.serialize_bytes(&self.to_bytes_be())
}
}
}

Expand All @@ -579,15 +583,19 @@ mod serde_field_element {
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_str(FieldElementVisitor)
if deserializer.is_human_readable() {
deserializer.deserialize_any(FieldElementVisitor)
} else {
deserializer.deserialize_bytes(FieldElementVisitor)
}
}
}

impl<'de> Visitor<'de> for FieldElementVisitor {
type Value = FieldElement;

fn expecting(&self, formatter: &mut Formatter) -> FmtResult {
write!(formatter, "string")
write!(formatter, "string, or an array of u8")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
Expand All @@ -596,6 +604,11 @@ mod serde_field_element {
{
FieldElement::from_str(v).map_err(serde::de::Error::custom)
}

fn visit_bytes<E: serde::de::Error>(self, v: &[u8]) -> Result<Self::Value, E> {
let buf = <[u8; U256_BYTE_COUNT]>::try_from(v).map_err(serde::de::Error::custom)?;
FieldElement::from_bytes_be(&buf).map_err(serde::de::Error::custom)
}
}
}

Expand Down

0 comments on commit 56cae56

Please sign in to comment.