Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support hashtree crate with feature, add Criterion benchmarks #161

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[toolchain]
channel = "1.70.0"
channel = "1.74.0"
7 changes: 7 additions & 0 deletions ssz-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ exclude = ["tests/data"]
default = ["serde", "std"]
std = ["bitvec/default", "sha2/default", "alloy-primitives/default"]
sha2-asm = ["sha2/asm"]
hashtree = ["dep:hashtree"]
serde = ["dep:serde", "alloy-primitives/serde"]

[dependencies]
bitvec = { version = "1.0.0", default-features = false, features = ["alloc"] }
ssz_rs_derive = { path = "../ssz-rs-derive", version = "0.9.0" }
sha2 = { version = "0.9.8", default-features = false }
hashtree = { version = "0.2.0", optional = true, package = "hashtree-rs" }
serde = { version = "1.0", default-features = false, features = [
"alloc",
"derive",
Expand All @@ -32,6 +34,11 @@ snap = "1.0"
project-root = "0.2.2"
serde_json = "1.0.81"
hex = "0.4.3"
criterion = { version = "0.5", features = ["html_reports"] }

[build-dependencies]
sha2 = "0.9.8"

[[bench]]
name = "merkleization"
harness = false
46 changes: 46 additions & 0 deletions ssz-rs/benches/merkleization.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use criterion::{criterion_group, criterion_main, Criterion};
use ssz_rs::{PathElement, Prove};

fn generate_id(name: &str) -> String {
let tag = if cfg!(feature = "hashtree") {
"hashtree"
} else if cfg!(feature = "sha2-asm") {
"sha256-asm"
} else {
"sha256"
};

format!("{}_{}", name, tag)
}

fn bench_merkleization(c: &mut Criterion) {
use ssz_rs::{HashTreeRoot, List};

let inner: Vec<List<u8, 1073741824>> = vec![
vec![0u8, 1u8, 2u8].try_into().unwrap(),
vec![3u8, 4u8, 5u8].try_into().unwrap(),
vec![6u8, 7u8, 8u8].try_into().unwrap(),
vec![9u8, 10u8, 11u8].try_into().unwrap(),
];

// Emulate a transactions tree
let outer: List<List<u8, 1073741824>, 1048576> = List::try_from(inner).unwrap();

c.bench_function(&generate_id("hash_tree_root"), |b| {
b.iter(|| {
let _ = outer.hash_tree_root().unwrap();
})
});

// let root = outer.hash_tree_root().unwrap();
let index = PathElement::from(1);
c.bench_function(&generate_id("generate_proof"), |b| {
b.iter(|| {
let (_proof, _witness) = outer.prove(&[index.clone()]).unwrap();
})
});
}

criterion_group!(benches, bench_merkleization,);

criterion_main!(benches);
58 changes: 58 additions & 0 deletions ssz-rs/src/merkleization/hasher.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#[cfg(feature = "hashtree")]
use std::sync::Once;

use super::BYTES_PER_CHUNK;

#[cfg(not(feature = "hashtree"))]
use ::sha2::{Digest, Sha256};

#[cfg(feature = "hashtree")]
static INIT: Once = Once::new();

#[inline]
#[cfg(feature = "hashtree")]
fn hash_chunks_hashtree(left: impl AsRef<[u8]>, right: impl AsRef<[u8]>) -> [u8; BYTES_PER_CHUNK] {
// Initialize the hashtree library (once)
INIT.call_once(|| {
hashtree::init();
});

let mut out = [0u8; BYTES_PER_CHUNK];

let mut chunks = [0u8; 2 * BYTES_PER_CHUNK];

chunks[..BYTES_PER_CHUNK].copy_from_slice(left.as_ref());
chunks[BYTES_PER_CHUNK..].copy_from_slice(right.as_ref());

// NOTE: hashtree "chunks" are 64 bytes long, not 32. That's why we
// specify "1" as the chunk count.
hashtree::hash(&mut out, &chunks, 1);

out
}

#[inline]
#[cfg(not(feature = "hashtree"))]
fn hash_chunks_sha256(left: impl AsRef<[u8]>, right: impl AsRef<[u8]>) -> [u8; BYTES_PER_CHUNK] {
let mut hasher = Sha256::new();
hasher.update(left.as_ref());
hasher.update(right.as_ref());
hasher.finalize_reset().into()
}

/// Function that hashes 2 [BYTES_PER_CHUNK] (32) len byte slices together. Depending on the feature
/// flags, this will either use:
/// - sha256 (default)
/// - sha256 with assembly support (with the "sha2-asm" feature flag)
/// - hashtree (with the "hashtree" feature flag)
#[inline]
pub fn hash_chunks(left: impl AsRef<[u8]>, right: impl AsRef<[u8]>) -> [u8; BYTES_PER_CHUNK] {
debug_assert!(left.as_ref().len() == BYTES_PER_CHUNK);
debug_assert!(right.as_ref().len() == BYTES_PER_CHUNK);

#[cfg(feature = "hashtree")]
return hash_chunks_hashtree(left, right);

#[cfg(not(feature = "hashtree"))]
return hash_chunks_sha256(left, right);
}
46 changes: 15 additions & 31 deletions ssz-rs/src/merkleization/merkleize.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
//! Support for computing Merkle trees.
use crate::{
lib::*,
merkleization::{MerkleizationError as Error, Node, BYTES_PER_CHUNK},
merkleization::{hasher::hash_chunks, MerkleizationError as Error, Node, BYTES_PER_CHUNK},
ser::Serialize,
GeneralizedIndex,
};
#[cfg(feature = "serde")]
use alloy_primitives::hex::FromHex;
use sha2::{Digest, Sha256};

// The generalized index for the root of the "decorated" type in any Merkleized type that supports
// decoration.
Expand Down Expand Up @@ -52,10 +51,8 @@ where
Ok(buffer)
}

fn hash_nodes(hasher: &mut Sha256, a: impl AsRef<[u8]>, b: impl AsRef<[u8]>, out: &mut [u8]) {
hasher.update(a);
hasher.update(b);
out.copy_from_slice(&hasher.finalize_reset());
fn hash_nodes(a: impl AsRef<[u8]>, b: impl AsRef<[u8]>, out: &mut [u8]) {
out.copy_from_slice(&hash_chunks(a, b));
}

const MAX_MERKLE_TREE_DEPTH: usize = 64;
Expand Down Expand Up @@ -108,13 +105,12 @@ fn merkleize_chunks_with_virtual_padding(chunks: &[u8], leaf_count: usize) -> Re
let depth = height - 1;
// SAFETY: index is safe while depth == leaf_count.trailing_zeros() < MAX_MERKLE_TREE_DEPTH;
// qed
return Ok(CONTEXT[depth as usize].try_into().expect("can produce a single root chunk"))
return Ok(CONTEXT[depth as usize].try_into().expect("can produce a single root chunk"));
}

let mut layer = chunks.to_vec();
// SAFETY: checked subtraction is unnecessary, as we return early when chunk_count == 0; qed
let mut last_index = chunk_count - 1;
let mut hasher = Sha256::new();
// for each layer of the tree, starting from the bottom and walking up to the root:
for k in (1..height).rev() {
// for each pair of nodes in this layer:
Expand Down Expand Up @@ -171,13 +167,11 @@ fn merkleize_chunks_with_virtual_padding(chunks: &[u8], leaf_count: usize) -> Re
// NOTE: nodes share memory here and so we can't use the `hash_nodes` utility
// as the disjunct nature is reflect in that functions type signature
// so instead we will just replicate here.
hasher.update(&left);
hasher.update(right);
left.copy_from_slice(&hasher.finalize_reset());
left.copy_from_slice(&hash_chunks(&left, right));
} else {
// SAFETY: index is safe because parent.len() % BYTES_PER_CHUNK == 0 and
// parent isn't empty; qed
hash_nodes(&mut hasher, left, right, &mut parent[..BYTES_PER_CHUNK]);
hash_nodes(left, right, &mut parent[..BYTES_PER_CHUNK]);
}
}
last_index /= 2;
Expand All @@ -198,7 +192,7 @@ pub fn merkleize(chunks: &[u8], limit: Option<usize>) -> Result<Node, Error> {
let mut leaf_count = chunk_count.next_power_of_two();
if let Some(limit) = limit {
if limit < chunk_count {
return Err(Error::InputExceedsLimit(limit))
return Err(Error::InputExceedsLimit(limit));
}
leaf_count = limit.next_power_of_two();
}
Expand All @@ -208,9 +202,8 @@ pub fn merkleize(chunks: &[u8], limit: Option<usize>) -> Result<Node, Error> {
fn mix_in_decoration(root: Node, decoration: usize) -> Node {
let decoration_data = decoration.hash_tree_root().expect("can merkleize usize");

let mut hasher = Sha256::new();
let mut output = vec![0u8; BYTES_PER_CHUNK];
hash_nodes(&mut hasher, root, decoration_data, &mut output);
hash_nodes(root, decoration_data, &mut output);
output.as_slice().try_into().expect("can extract root")
}

Expand Down Expand Up @@ -238,17 +231,13 @@ pub(crate) fn elements_to_chunks<'a, T: HashTreeRoot + 'a>(
pub struct Tree(Vec<u8>);

impl Tree {
pub fn mix_in_decoration(
&mut self,
decoration: usize,
hasher: &mut Sha256,
) -> Result<(), Error> {
pub fn mix_in_decoration(&mut self, decoration: usize) -> Result<(), Error> {
let target_node = &mut self[DECORATION_GENERALIZED_INDEX];
let decoration_node = decoration.hash_tree_root()?;
target_node.copy_from_slice(decoration_node.as_ref());
hasher.update(&self[INNER_ROOT_GENERALIZED_INDEX]);
hasher.update(&self[DECORATION_GENERALIZED_INDEX]);
self[1].copy_from_slice(&hasher.finalize_reset());
let out =
hash_chunks(&self[INNER_ROOT_GENERALIZED_INDEX], &self[DECORATION_GENERALIZED_INDEX]);
self[1].copy_from_slice(&out);
Ok(())
}

Expand Down Expand Up @@ -287,11 +276,7 @@ impl std::fmt::Debug for Tree {
// Invariant: `chunks.len() % BYTES_PER_CHUNK == 0`
// Invariant: `leaf_count.next_power_of_two() == leaf_count`
// NOTE: naive implementation, can make much more efficient
pub fn compute_merkle_tree(
hasher: &mut Sha256,
chunks: &[u8],
leaf_count: usize,
) -> Result<Tree, Error> {
pub fn compute_merkle_tree(chunks: &[u8], leaf_count: usize) -> Result<Tree, Error> {
debug_assert!(chunks.len() % BYTES_PER_CHUNK == 0);
debug_assert!(leaf_count.next_power_of_two() == leaf_count);

Expand Down Expand Up @@ -320,7 +305,7 @@ pub fn compute_merkle_tree(
// NOTE: children.len() == 2 * BYTES_PER_CHUNK
let (parent, children) = focus.split_at_mut(children_index);
let (left, right) = children.split_at(BYTES_PER_CHUNK);
hash_nodes(hasher, left, right, &mut parent[..BYTES_PER_CHUNK]);
hash_nodes(left, right, &mut parent[..BYTES_PER_CHUNK]);
}
Ok(Tree(buffer))
}
Expand All @@ -332,8 +317,7 @@ mod tests {

// Return the root of the Merklization of a binary tree formed from `chunks`.
fn merkleize_chunks(chunks: &[u8], leaf_count: usize) -> Result<Node, Error> {
let mut hasher = Sha256::new();
let tree = compute_merkle_tree(&mut hasher, chunks, leaf_count)?;
let tree = compute_merkle_tree(chunks, leaf_count)?;
let root_index = default_generalized_index();
Ok(tree[root_index].try_into().expect("can produce a single root chunk"))
}
Expand Down
1 change: 1 addition & 0 deletions ssz-rs/src/merkleization/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod generalized_index;
mod hasher;
mod merkleize;
pub mod multiproofs;
mod node;
Expand Down
28 changes: 11 additions & 17 deletions ssz-rs/src/merkleization/multiproofs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use crate::{
lib::*,
merkleization::{
generalized_index::{get_bit, get_path_length, parent, sibling},
hasher::hash_chunks,
GeneralizedIndex, MerkleizationError as Error, Node,
},
};
use sha2::{Digest, Sha256};

fn get_branch_indices(tree_index: GeneralizedIndex) -> Vec<GeneralizedIndex> {
let mut focus = sibling(tree_index);
Expand Down Expand Up @@ -52,20 +52,15 @@ pub fn calculate_merkle_root(
) -> Result<Node, Error> {
let path_length = get_path_length(index)?;
if path_length != proof.len() {
return Err(Error::InvalidProof)
return Err(Error::InvalidProof);
}
let mut result = leaf;

let mut hasher = Sha256::new();
for (i, next) in proof.iter().enumerate() {
if get_bit(index, i) {
hasher.update(next);
hasher.update(result);
} else {
hasher.update(result);
hasher.update(next);
}
result.copy_from_slice(&hasher.finalize_reset());
let out =
if get_bit(index, i) { hash_chunks(next, result) } else { hash_chunks(result, next) };

result.copy_from_slice(&out);
}
Ok(result)
}
Expand All @@ -89,11 +84,11 @@ pub fn calculate_multi_merkle_root(
indices: &[GeneralizedIndex],
) -> Result<Node, Error> {
if leaves.len() != indices.len() {
return Err(Error::InvalidProof)
return Err(Error::InvalidProof);
}
let helper_indices = get_helper_indices(indices);
if proof.len() != helper_indices.len() {
return Err(Error::InvalidProof)
return Err(Error::InvalidProof);
}

let mut objects = HashMap::new();
Expand All @@ -107,7 +102,6 @@ pub fn calculate_multi_merkle_root(
let mut keys = objects.keys().cloned().collect::<Vec<_>>();
keys.sort_by(|a, b| b.cmp(a));

let mut hasher = Sha256::new();
let mut pos = 0;
while pos < keys.len() {
let key = keys.get(pos).unwrap();
Expand All @@ -121,11 +115,11 @@ pub fn calculate_multi_merkle_root(
let left_index = sibling(right_index);
let left_input = objects.get(&left_index).expect("contains index");
let right_input = objects.get(&right_index).expect("contains index");
hasher.update(left_input);
hasher.update(right_input);

let out = hash_chunks(left_input, right_input);

let parent = objects.entry(parent_index).or_default();
parent.copy_from_slice(&hasher.finalize_reset());
parent.copy_from_slice(&out);
keys.push(parent_index);
}
pos += 1;
Expand Down
Loading
Loading