diff --git a/starknet-core/src/types/hash_256.rs b/starknet-core/src/types/hash_256.rs new file mode 100644 index 00000000..a75c9896 --- /dev/null +++ b/starknet-core/src/types/hash_256.rs @@ -0,0 +1,177 @@ +use alloc::{ + fmt::{Debug, Display, Formatter, Result as FmtResult}, + format, + str::FromStr, +}; + +use serde::{de::Visitor, Deserialize, Serialize}; +use starknet_ff::FieldElement; + +const HASH_256_BYTE_COUNT: usize = 32; + +#[derive(Clone, Copy, PartialEq, Eq)] +pub struct Hash256 { + inner: [u8; HASH_256_BYTE_COUNT], +} + +struct Hash256Visitor; + +mod errors { + use core::fmt::{Display, Formatter, Result}; + + #[derive(Debug)] + pub enum FromHexError { + UnexpectedLength, + InvalidHexString, + } + + #[derive(Debug)] + pub struct ToFieldElementError; + + #[cfg(feature = "std")] + impl std::error::Error for FromHexError {} + + #[cfg(feature = "std")] + impl std::error::Error for ToFieldElementError {} + + impl Display for FromHexError { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + match self { + Self::UnexpectedLength => { + write!(f, "unexpected length for 256-bit hash") + } + Self::InvalidHexString => { + write!(f, "invalid hex string") + } + } + } + } + + impl Display for ToFieldElementError { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!(f, "hash value out of range for FieldElement") + } + } +} +pub use errors::{FromHexError, ToFieldElementError}; + +impl Hash256 { + pub fn from_bytes(bytes: [u8; HASH_256_BYTE_COUNT]) -> Self { + Self { inner: bytes } + } + + pub fn from_hex(hex: &str) -> Result { + hex.parse() + } + + pub fn from_felt(felt: &FieldElement) -> Self { + felt.into() + } + + pub fn as_bytes(&self) -> &[u8; HASH_256_BYTE_COUNT] { + &self.inner + } +} + +impl Serialize for Hash256 { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&format!("0x{}", hex::encode(self.inner))) + } +} + +impl<'de> Deserialize<'de> for Hash256 { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(Hash256Visitor) + } +} + +impl<'de> Visitor<'de> for Hash256Visitor { + type Value = Hash256; + + fn expecting(&self, formatter: &mut Formatter) -> alloc::fmt::Result { + write!(formatter, "string") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + v.parse() + .map_err(|err| serde::de::Error::custom(format!("{}", err))) + } +} + +impl FromStr for Hash256 { + type Err = FromHexError; + + fn from_str(s: &str) -> Result { + let value = s.trim_start_matches("0x"); + + let hex_chars_len = value.len(); + let expected_hex_length = HASH_256_BYTE_COUNT * 2; + + let parsed_bytes: [u8; HASH_256_BYTE_COUNT] = if hex_chars_len == expected_hex_length { + let mut buffer = [0u8; HASH_256_BYTE_COUNT]; + hex::decode_to_slice(value, &mut buffer).map_err(|_| FromHexError::InvalidHexString)?; + buffer + } else if hex_chars_len < expected_hex_length { + let mut padded_hex = str::repeat("0", expected_hex_length - hex_chars_len); + padded_hex.push_str(value); + + let mut buffer = [0u8; HASH_256_BYTE_COUNT]; + hex::decode_to_slice(&padded_hex, &mut buffer) + .map_err(|_| FromHexError::InvalidHexString)?; + buffer + } else { + return Err(FromHexError::UnexpectedLength); + }; + + Ok(Self::from_bytes(parsed_bytes)) + } +} + +impl Debug for Hash256 { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "0x{}", hex::encode(self.inner)) + } +} + +impl Display for Hash256 { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "0x{}", hex::encode(self.inner)) + } +} + +impl From for Hash256 { + fn from(value: FieldElement) -> Self { + (&value).into() + } +} + +impl From<&FieldElement> for Hash256 { + fn from(value: &FieldElement) -> Self { + Self::from_bytes(value.to_bytes_be()) + } +} + +impl TryFrom for FieldElement { + type Error = ToFieldElementError; + + fn try_from(value: Hash256) -> Result { + (&value).try_into() + } +} + +impl TryFrom<&Hash256> for FieldElement { + type Error = ToFieldElementError; + + fn try_from(value: &Hash256) -> Result { + FieldElement::from_bytes_be(&value.inner).map_err(|_| ToFieldElementError) + } +} diff --git a/starknet-core/src/types/mod.rs b/starknet-core/src/types/mod.rs index 609a8288..18ff37aa 100644 --- a/starknet-core/src/types/mod.rs +++ b/starknet-core/src/types/mod.rs @@ -39,6 +39,9 @@ pub use codegen::{ pub mod eth_address; pub use eth_address::EthAddress; +pub mod hash_256; +pub use hash_256::Hash256; + mod execution_result; pub use execution_result::ExecutionResult; @@ -583,10 +586,8 @@ mod tests { let msg_to_l2 = l1_handler_tx.parse_msg_to_l2().unwrap(); - let expected_hash: [u8; 32] = - hex::decode("c51a543ef9563ad2545342b390b67edfcddf9886aa36846cf70382362fc5fab3") - .unwrap() - .try_into() + let expected_hash = + Hash256::from_hex("c51a543ef9563ad2545342b390b67edfcddf9886aa36846cf70382362fc5fab3") .unwrap(); assert_eq!(msg_to_l2.hash(), expected_hash); diff --git a/starknet-core/src/types/msg.rs b/starknet-core/src/types/msg.rs index da381986..da5ad5e4 100644 --- a/starknet-core/src/types/msg.rs +++ b/starknet-core/src/types/msg.rs @@ -3,7 +3,7 @@ use alloc::vec::Vec; use sha3::{Digest, Keccak256}; use starknet_ff::FieldElement; -use super::EthAddress; +use super::{EthAddress, Hash256}; #[derive(Debug, Clone)] pub struct MsgToL2 { @@ -18,7 +18,7 @@ impl MsgToL2 { /// Calculates the message hash based on the algorithm documented here: /// /// https://docs.starknet.io/documentation/architecture_and_concepts/L1-L2_Communication/messaging-mechanism/ - pub fn hash(&self) -> [u8; 32] { + pub fn hash(&self) -> Hash256 { let mut hasher = Keccak256::new(); // FromAddress @@ -47,7 +47,7 @@ impl MsgToL2 { let hash = hasher.finalize(); // Because we know hash is always 32 bytes - unsafe { *(hash[..].as_ptr() as *const [u8; 32]) } + Hash256::from_bytes(unsafe { *(hash[..].as_ptr() as *const [u8; 32]) }) } } @@ -82,10 +82,8 @@ mod tests { nonce: 775628, }; - let expected_hash: [u8; 32] = - hex::decode("c51a543ef9563ad2545342b390b67edfcddf9886aa36846cf70382362fc5fab3") - .unwrap() - .try_into() + let expected_hash = + Hash256::from_hex("c51a543ef9563ad2545342b390b67edfcddf9886aa36846cf70382362fc5fab3") .unwrap(); assert_eq!(msg.hash(), expected_hash);