diff --git a/ssz-rs/src/merkleization/merkleize.rs b/ssz-rs/src/merkleization/merkleize.rs index 6a27e422..7efb8899 100644 --- a/ssz-rs/src/merkleization/merkleize.rs +++ b/ssz-rs/src/merkleization/merkleize.rs @@ -300,7 +300,7 @@ pub fn compute_merkle_tree(chunks: &[u8], leaf_count: usize) -> Result= 8 { + if leaf_count >= 16 { compute_merkle_tree_parallel(&mut buffer, node_count); } else { compute_merkle_tree_serial(&mut buffer, node_count); @@ -323,39 +323,115 @@ pub fn compute_merkle_tree(chunks: &[u8], leaf_count: usize) -> Result, Vec<_>) = (1..node_count).partition(|&i| { - let level = (i + 1).ilog2() as usize; - let pos_in_level = i - ((1 << level) - 1); - pos_in_level < (1 << (level - 1)) + let nodes = split_merkle_tree_nodes(node_count); + + // Create buffers for each section + let mut left_left_buf = vec![0u8; nodes.left_left.len() * BYTES_PER_CHUNK]; + let mut left_right_buf = vec![0u8; nodes.left_right.len() * BYTES_PER_CHUNK]; + let mut right_left_buf = vec![0u8; nodes.right_left.len() * BYTES_PER_CHUNK]; + let mut right_right_buf = vec![0u8; nodes.right_right.len() * BYTES_PER_CHUNK]; + + // Copy data to section buffers + copy_nodes_to_buffer(buffer, &nodes.left_left, &mut left_left_buf); + copy_nodes_to_buffer(buffer, &nodes.left_right, &mut left_right_buf); + copy_nodes_to_buffer(buffer, &nodes.right_left, &mut right_left_buf); + copy_nodes_to_buffer(buffer, &nodes.right_right, &mut right_right_buf); + + // Process all sections in parallel + rayon::scope(|s| { + s.spawn(|_| process_subtree(&mut left_left_buf, nodes.left_left.len())); + s.spawn(|_| process_subtree(&mut left_right_buf, nodes.left_right.len())); + s.spawn(|_| process_subtree(&mut right_left_buf, nodes.right_left.len())); + s.spawn(|_| process_subtree(&mut right_right_buf, nodes.right_right.len())); }); - // Create buffers for left and right subtrees - let mut left_buf = vec![0u8; left_nodes.len() * BYTES_PER_CHUNK]; - let mut right_buf = vec![0u8; right_nodes.len() * BYTES_PER_CHUNK]; - - // Copy relevant chunks to subtree buffers - copy_nodes_to_buffer(buffer, &left_nodes, &mut left_buf); - copy_nodes_to_buffer(buffer, &right_nodes, &mut right_buf); + // Copy results back + copy_buffer_to_nodes(&left_left_buf, &nodes.left_left, buffer); + copy_buffer_to_nodes(&left_right_buf, &nodes.left_right, buffer); + copy_buffer_to_nodes(&right_left_buf, &nodes.right_left, buffer); + copy_buffer_to_nodes(&right_right_buf, &nodes.right_right, buffer); - // Process left and right subtrees in parallel - rayon::join( - || process_subtree(&mut left_buf, left_nodes.len()), - || process_subtree(&mut right_buf, right_nodes.len()), + // Compute node 1 from 3 and 4 + let node1_hash = hash_chunks( + &buffer[3 * BYTES_PER_CHUNK..4 * BYTES_PER_CHUNK], + &buffer[4 * BYTES_PER_CHUNK..5 * BYTES_PER_CHUNK], ); + buffer[BYTES_PER_CHUNK..2 * BYTES_PER_CHUNK].copy_from_slice(&node1_hash); - // Copy results back to main buffer - copy_buffer_to_nodes(&left_buf, &left_nodes, buffer); - copy_buffer_to_nodes(&right_buf, &right_nodes, buffer); + // Compute node 2 from 5 and 6 + let node2_hash = hash_chunks( + &buffer[5 * BYTES_PER_CHUNK..6 * BYTES_PER_CHUNK], + &buffer[6 * BYTES_PER_CHUNK..7 * BYTES_PER_CHUNK], + ); + buffer[2 * BYTES_PER_CHUNK..3 * BYTES_PER_CHUNK].copy_from_slice(&node2_hash); - // Compute root hash + // Compute the root from 1 and 2 let root_hash = hash_chunks( &buffer[BYTES_PER_CHUNK..2 * BYTES_PER_CHUNK], &buffer[2 * BYTES_PER_CHUNK..3 * BYTES_PER_CHUNK], ); - // Store root hash at the root position buffer[..BYTES_PER_CHUNK].copy_from_slice(&root_hash); } +#[derive(Debug, PartialEq)] +struct SubtreeNodes { + left_left: Vec, + left_right: Vec, + right_left: Vec, + right_right: Vec, +} + +// Split merkle tree into 4 subtrees +fn split_merkle_tree_nodes(node_count: usize) -> SubtreeNodes { + let mut left_left = Vec::new(); + let mut left_right = Vec::new(); + let mut right_left = Vec::new(); + let mut right_right = Vec::new(); + + // Skip root node (index 0) + for i in 1..node_count { + // Current level in tree (0-based) + let level = (i + 1).ilog2() as usize; + // Position within level + let pos_in_level = i - ((1 << level) - 1); + + // Handle level 1 nodes (1,2) - skip them as they're not part of subtrees + if level == 1 { + continue; + } + + // For level 2 nodes (3,4,5,6), they become the roots of our subtrees + if level == 2 { + match pos_in_level { + 0 => left_left.push(i), // Node 3 + 1 => left_right.push(i), // Node 4 + 2 => right_left.push(i), // Node 5 + 3 => right_right.push(i), // Node 6 + _ => unreachable!(), + } + continue; + } + + // For deeper levels, assign based on their parent at level 2 + let ancestor_at_level_2 = { + let steps_up = level - 2; + let parent_pos = pos_in_level >> steps_up; + parent_pos + 3 // +3 because level 2 starts at index 3 + }; + + // Assign to appropriate subtree based on level 2 ancestor + match ancestor_at_level_2 { + 3 => left_left.push(i), + 4 => left_right.push(i), + 5 => right_left.push(i), + 6 => right_right.push(i), + _ => unreachable!(), + } + } + + SubtreeNodes { left_left, left_right, right_left, right_right } +} + // Compute the Merkle tree serially. fn compute_merkle_tree_serial(buffer: &mut [u8], node_count: usize) { for i in (1..node_count).rev().step_by(2) { @@ -590,6 +666,54 @@ mod tests { ); } + #[test] + fn test_split_merkle_tree_nodes_tree() { + // 0 + // / \ + // 1 2 + // / \ / \ + // 3 4 5 6 + // / \ / \ / \ / \ + // 7 8 9 10 11 12 13 14 + let nodes = split_merkle_tree_nodes(15); + + assert_eq!(nodes.left_left, vec![3, 7, 8]); + assert_eq!(nodes.left_right, vec![4, 9, 10]); + assert_eq!(nodes.right_left, vec![5, 11, 12]); + assert_eq!(nodes.right_right, vec![6, 13, 14]); + } + + #[test] + fn test_split_merkle_tree_nodes_16_leaves() { + // 0 + // / \ + // 1 2 + // / \ / \ + // 3 4 5 6 + // / \ / \ / \ / \ + // 7 8 9 10 11 12 13 14 + // /\ /\ /\ /\ /\ /\ /\ /\ + //15 16... ...29 30 + + let nodes = split_merkle_tree_nodes(31); + + assert_eq!(nodes.left_left, vec![3, 7, 8, 15, 16, 17, 18]); + assert_eq!(nodes.left_right, vec![4, 9, 10, 19, 20, 21, 22]); + assert_eq!(nodes.right_left, vec![5, 11, 12, 23, 24, 25, 26]); + assert_eq!(nodes.right_right, vec![6, 13, 14, 27, 28, 29, 30]); + } + + #[test] + fn test_split_nodes_empty_tree() { + // Test edge case with just root node + let nodes = split_merkle_tree_nodes(1); + + assert_eq!(nodes.left_left, Vec::::new()); + assert_eq!(nodes.left_right, Vec::::new()); + assert_eq!(nodes.right_left, Vec::::new()); + assert_eq!(nodes.right_right, Vec::::new()); + } + #[test] fn test_process_subtree_with_size_zero() { let mut buffer = vec![0u8; 0];