diff --git a/ssz-rs/Cargo.toml b/ssz-rs/Cargo.toml index 36f11cea..bb454018 100644 --- a/ssz-rs/Cargo.toml +++ b/ssz-rs/Cargo.toml @@ -13,6 +13,7 @@ exclude = ["tests/data"] [features] default = ["serde", "std"] +hashtree = ["dep:hashtree"] std = ["bitvec/default", "sha2/default", "alloy-primitives/default"] sha2-asm = ["sha2/asm"] serde = ["dep:serde", "alloy-primitives/serde"] @@ -20,6 +21,7 @@ serde = ["dep:serde", "alloy-primitives/serde"] [dependencies] bitvec = { version = "1.0.0", default-features = false, features = ["alloc"] } rayon = "1.10" +hashtree = { version = "0.2.0", optional = true, package = "hashtree-rs" } ssz_rs_derive = { path = "../ssz-rs-derive", version = "0.9.0" } sha2 = { version = "0.9.8", default-features = false } serde = { version = "1.0", default-features = false, features = [ diff --git a/ssz-rs/src/merkleization/hasher.rs b/ssz-rs/src/merkleization/hasher.rs index 2698ef2c..fc1d9d62 100644 --- a/ssz-rs/src/merkleization/hasher.rs +++ b/ssz-rs/src/merkleization/hasher.rs @@ -1,8 +1,41 @@ +#[cfg(feature = "hashtree")] +use std::sync::Once; + use super::BYTES_PER_CHUNK; +#[cfg(not(feature = "hashtree"))] use ::sha2::{Digest, Sha256}; +#[cfg(feature = "hashtree")] +static INIT: Once = Once::new(); + +#[inline] +#[cfg(feature = "hashtree")] +fn hash_chunks_hashtree(left: impl AsRef<[u8]>, right: impl AsRef<[u8]>) -> [u8; BYTES_PER_CHUNK] { + // Initialize the hashtree library (once) + INIT.call_once(|| { + hashtree::init(); + }); + + let mut out = [0u8; BYTES_PER_CHUNK]; + + let mut chunks = [0u8; 2 * BYTES_PER_CHUNK]; + + chunks[..BYTES_PER_CHUNK].copy_from_slice(left.as_ref()); + chunks[BYTES_PER_CHUNK..].copy_from_slice(right.as_ref()); + + // NOTE: hashtree "chunks" are 64 bytes long, not 32. That's why we + // specify "1" as the chunk count. + hashtree::hash(&mut out, &chunks, 1); + + out +} + +// Hashes two chunks together using sha256 +// Defaults to sha256 +// sha256-asm can be enabled with the "sha2-asm" feature flag #[inline] +#[cfg(not(feature = "hashtree"))] fn hash_chunks_sha256(left: impl AsRef<[u8]>, right: impl AsRef<[u8]>) -> [u8; BYTES_PER_CHUNK] { let mut hasher = Sha256::new(); hasher.update(left.as_ref()); @@ -13,11 +46,16 @@ fn hash_chunks_sha256(left: impl AsRef<[u8]>, right: impl AsRef<[u8]>) -> [u8; B /// Function that hashes 2 [BYTES_PER_CHUNK] (32) len byte slices together. Depending on the feature /// flags, this will either use: /// - sha256 (default) -/// - TODO: sha256 with assembly support (with the "sha2-asm" feature flag) -/// - TODO: hashtree (with the "hashtree" feature flag) +/// - sha256 with assembly support (with the "sha2-asm" feature flag) +/// - hashtree (with the "hashtree" feature flag) #[inline] pub fn hash_chunks(left: impl AsRef<[u8]>, right: impl AsRef<[u8]>) -> [u8; BYTES_PER_CHUNK] { debug_assert!(left.as_ref().len() == BYTES_PER_CHUNK); debug_assert!(right.as_ref().len() == BYTES_PER_CHUNK); - hash_chunks_sha256(left, right) + + #[cfg(feature = "hashtree")] + return hash_chunks_hashtree(left, right); + + #[cfg(not(feature = "hashtree"))] + return hash_chunks_sha256(left, right); }