Skip to content

Commit

Permalink
perf: process 4 subtrees in parallel instead of 2
Browse files Browse the repository at this point in the history
  • Loading branch information
estensen committed Dec 5, 2024
1 parent 843d070 commit 74acc3d
Showing 1 changed file with 145 additions and 21 deletions.
166 changes: 145 additions & 21 deletions ssz-rs/src/merkleization/merkleize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ pub fn compute_merkle_tree(chunks: &[u8], leaf_count: usize) -> Result<Tree, Err
// Copy input chunks to leaf positions
buffer[leaf_start..leaf_start + chunks.len()].copy_from_slice(chunks);

if leaf_count >= 8 {
if leaf_count >= 16 {
compute_merkle_tree_parallel(&mut buffer, node_count);
} else {
compute_merkle_tree_serial(&mut buffer, node_count);
Expand All @@ -323,39 +323,115 @@ pub fn compute_merkle_tree(chunks: &[u8], leaf_count: usize) -> Result<Tree, Err

// Split into four subtrees rooted at nodes 3, 4, 5 and 6, hash them in parallel,
// then combine the results into nodes 1 and 2 and finally the root (node 0).
fn compute_merkle_tree_parallel(buffer: &mut [u8], node_count: usize) {
let (left_nodes, right_nodes): (Vec<_>, Vec<_>) = (1..node_count).partition(|&i| {
let level = (i + 1).ilog2() as usize;
let pos_in_level = i - ((1 << level) - 1);
pos_in_level < (1 << (level - 1))
let nodes = split_merkle_tree_nodes(node_count);

// Create buffers for each section
let mut left_left_buf = vec![0u8; nodes.left_left.len() * BYTES_PER_CHUNK];
let mut left_right_buf = vec![0u8; nodes.left_right.len() * BYTES_PER_CHUNK];
let mut right_left_buf = vec![0u8; nodes.right_left.len() * BYTES_PER_CHUNK];
let mut right_right_buf = vec![0u8; nodes.right_right.len() * BYTES_PER_CHUNK];

// Copy data to section buffers
copy_nodes_to_buffer(buffer, &nodes.left_left, &mut left_left_buf);
copy_nodes_to_buffer(buffer, &nodes.left_right, &mut left_right_buf);
copy_nodes_to_buffer(buffer, &nodes.right_left, &mut right_left_buf);
copy_nodes_to_buffer(buffer, &nodes.right_right, &mut right_right_buf);

// Process all sections in parallel
rayon::scope(|s| {
s.spawn(|_| process_subtree(&mut left_left_buf, nodes.left_left.len()));
s.spawn(|_| process_subtree(&mut left_right_buf, nodes.left_right.len()));
s.spawn(|_| process_subtree(&mut right_left_buf, nodes.right_left.len()));
s.spawn(|_| process_subtree(&mut right_right_buf, nodes.right_right.len()));
});

// Create buffers for left and right subtrees
let mut left_buf = vec![0u8; left_nodes.len() * BYTES_PER_CHUNK];
let mut right_buf = vec![0u8; right_nodes.len() * BYTES_PER_CHUNK];

// Copy relevant chunks to subtree buffers
copy_nodes_to_buffer(buffer, &left_nodes, &mut left_buf);
copy_nodes_to_buffer(buffer, &right_nodes, &mut right_buf);
// Copy results back
copy_buffer_to_nodes(&left_left_buf, &nodes.left_left, buffer);
copy_buffer_to_nodes(&left_right_buf, &nodes.left_right, buffer);
copy_buffer_to_nodes(&right_left_buf, &nodes.right_left, buffer);
copy_buffer_to_nodes(&right_right_buf, &nodes.right_right, buffer);

// Process left and right subtrees in parallel
rayon::join(
|| process_subtree(&mut left_buf, left_nodes.len()),
|| process_subtree(&mut right_buf, right_nodes.len()),
// Compute node 1 from 3 and 4
let node1_hash = hash_chunks(
&buffer[3 * BYTES_PER_CHUNK..4 * BYTES_PER_CHUNK],
&buffer[4 * BYTES_PER_CHUNK..5 * BYTES_PER_CHUNK],
);
buffer[BYTES_PER_CHUNK..2 * BYTES_PER_CHUNK].copy_from_slice(&node1_hash);

// Copy results back to main buffer
copy_buffer_to_nodes(&left_buf, &left_nodes, buffer);
copy_buffer_to_nodes(&right_buf, &right_nodes, buffer);
// Compute node 2 from 5 and 6
let node2_hash = hash_chunks(
&buffer[5 * BYTES_PER_CHUNK..6 * BYTES_PER_CHUNK],
&buffer[6 * BYTES_PER_CHUNK..7 * BYTES_PER_CHUNK],
);
buffer[2 * BYTES_PER_CHUNK..3 * BYTES_PER_CHUNK].copy_from_slice(&node2_hash);

// Compute root hash
// Compute the root from 1 and 2
let root_hash = hash_chunks(
&buffer[BYTES_PER_CHUNK..2 * BYTES_PER_CHUNK],
&buffer[2 * BYTES_PER_CHUNK..3 * BYTES_PER_CHUNK],
);
// Store root hash at the root position
buffer[..BYTES_PER_CHUNK].copy_from_slice(&root_hash);
}

/// Node indices of a Merkle tree grouped by the level-2 node (index 3, 4, 5 or 6)
/// they descend from, as produced by `split_merkle_tree_nodes`.
///
/// Each list contains the level-2 node itself plus every deeper node beneath it.
/// The root (index 0) and the level-1 nodes (indices 1 and 2) appear in none of them.
#[derive(Debug, PartialEq)]
struct SubtreeNodes {
    /// Node 3 and its descendants.
    left_left: Vec<usize>,
    /// Node 4 and its descendants.
    left_right: Vec<usize>,
    /// Node 5 and its descendants.
    right_left: Vec<usize>,
    /// Node 6 and its descendants.
    right_right: Vec<usize>,
}

/// Groups the node indices of a Merkle tree into the four subtrees rooted at level 2.
///
/// Nodes are numbered in level order starting from the root at index 0, so level 1
/// holds indices 1-2 and level 2 holds indices 3-6. Every index in
/// `3..node_count` is placed in the list of the level-2 node it descends from
/// (a level-2 node counts as its own ancestor). Indices 0, 1 and 2 belong to no
/// subtree and are omitted.
///
/// Each returned list is in ascending index order. If `node_count` is 3 or less,
/// all four lists are empty.
fn split_merkle_tree_nodes(node_count: usize) -> SubtreeNodes {
    // Index of the first level-2 node; everything before it sits above the subtrees.
    const FIRST_SUBTREE_ROOT: usize = 3;

    // One bucket per level-2 node, ordered left to right (nodes 3, 4, 5, 6).
    let mut quadrants: [Vec<usize>; 4] = Default::default();

    for index in FIRST_SUBTREE_ROOT..node_count {
        // Depth of this node below the root (root = 0).
        let depth = (index + 1).ilog2() as usize;
        // Zero-based position of the node among the nodes of its own level.
        let offset = index + 1 - (1usize << depth);
        // Each step up the tree halves the position, so discarding `depth - 2`
        // low bits yields the position (0..=3) of the ancestor on level 2.
        let quadrant = offset >> (depth - 2);
        quadrants[quadrant].push(index);
    }

    let [left_left, left_right, right_left, right_right] = quadrants;
    SubtreeNodes { left_left, left_right, right_left, right_right }
}

// Compute the Merkle tree serially.
fn compute_merkle_tree_serial(buffer: &mut [u8], node_count: usize) {
for i in (1..node_count).rev().step_by(2) {
Expand Down Expand Up @@ -590,6 +666,54 @@ mod tests {
);
}

#[test]
fn test_split_merkle_tree_nodes_tree() {
    // Complete tree with 8 leaves (15 nodes):
    //
    //              0
    //          /       \
    //        1           2
    //      /   \       /   \
    //     3     4     5     6
    //    / \   / \   / \   / \
    //   7   8 9  10 11 12 13 14
    //
    // Each level-2 node (3..=6) is grouped with its two leaf children.
    let split = split_merkle_tree_nodes(15);

    assert_eq!(split.left_left, [3usize, 7, 8]);
    assert_eq!(split.left_right, [4usize, 9, 10]);
    assert_eq!(split.right_left, [5usize, 11, 12]);
    assert_eq!(split.right_right, [6usize, 13, 14]);
}

#[test]
fn test_split_merkle_tree_nodes_16_leaves() {
    // Complete tree with 16 leaves (31 nodes):
    //
    //   level 0:  0
    //   level 1:  1, 2
    //   level 2:  3, 4, 5, 6
    //   level 3:  7 ..= 14
    //   level 4:  15 ..= 30
    //
    // Every subtree holds its level-2 root, two level-3 nodes and four leaves.
    let split = split_merkle_tree_nodes(31);

    let expected: [[usize; 7]; 4] = [
        [3, 7, 8, 15, 16, 17, 18],
        [4, 9, 10, 19, 20, 21, 22],
        [5, 11, 12, 23, 24, 25, 26],
        [6, 13, 14, 27, 28, 29, 30],
    ];

    assert_eq!(split.left_left, expected[0]);
    assert_eq!(split.left_right, expected[1]);
    assert_eq!(split.right_left, expected[2]);
    assert_eq!(split.right_right, expected[3]);
}

#[test]
fn test_split_nodes_empty_tree() {
    // A tree consisting only of the root has no level-2 nodes,
    // so every subtree list must come back empty.
    let split = split_merkle_tree_nodes(1);

    assert!(split.left_left.is_empty());
    assert!(split.left_right.is_empty());
    assert!(split.right_left.is_empty());
    assert!(split.right_right.is_empty());
}

#[test]
fn test_process_subtree_with_size_zero() {
let mut buffer = vec![0u8; 0];
Expand Down

0 comments on commit 74acc3d

Please sign in to comment.