diff --git a/trin-execution/src/trie_walker/filter.rs b/trin-execution/src/trie_walker/filter.rs index 40b897f94..717a0aa89 100644 --- a/trin-execution/src/trie_walker/filter.rs +++ b/trin-execution/src/trie_walker/filter.rs @@ -1,120 +1,123 @@ -use alloy::primitives::{B256, U256}; +use alloy::primitives::U256; use rand::{thread_rng, Rng}; -use crate::utils::partial_nibble_path_to_right_padded_b256; +use crate::utils::nibbles_to_right_padded_b256; #[derive(Debug, Clone)] pub struct Filter { - start_prefix: U256, - end_prefix: U256, + start: U256, + end: U256, } impl Filter { - pub fn new_random_filter(slice_count: u16) -> Self { + /// Create a new filter that includes the whole trie + /// Slice index must be less than slice count or it will panic + pub fn new(slice_index: u16, slice_count: u16) -> Self { + assert!( + slice_index < slice_count, + "slice_index must be less than slice_count" + ); + // if slice_count is 0 or 1, we want to include the whole trie if slice_count == 0 || slice_count == 1 { return Self { - start_prefix: U256::ZERO, - end_prefix: U256::MAX, + start: U256::ZERO, + end: U256::MAX, }; } - let slice_size = U256::MAX / U256::from(slice_count); - let random_slice_index = thread_rng().gen_range(0..slice_count); + let slice_size = U256::MAX / U256::from(slice_count) + U256::from(1); - let start_prefix = U256::from(random_slice_index) * slice_size; - let end_prefix = if random_slice_index == slice_count - 1 { + let start = U256::from(slice_index) * slice_size; + let end = if slice_index == slice_count - 1 { U256::MAX } else { - start_prefix + slice_size - U256::from(1) + start + slice_size - U256::from(1) }; - Self { - start_prefix, - end_prefix, - } + Self { start, end } } - fn partial_prefix(prefix: U256, shift_amount: usize) -> B256 { - prefix - .arithmetic_shr(shift_amount) - .wrapping_shl(shift_amount) - .into() + pub fn random(slice_count: u16) -> Self { + Self::new(thread_rng().gen_range(0..slice_count), slice_count) } - pub fn is_included(&self, path: &[u8]) -> bool { + /// Check if a path is included in the filter + pub fn contains(&self, path: &[u8]) -> bool { // we need to use partial prefixes to not artificially exclude paths that are not exactly // the same length as the filter let shift_amount = 256 - path.len() * 4; - let partial_start_prefix = Filter::partial_prefix(self.start_prefix, shift_amount); - let partial_end_prefix = Filter::partial_prefix(self.end_prefix, shift_amount); - let path = partial_nibble_path_to_right_padded_b256(path); + let partial_start_prefix = (self.start >> shift_amount) << shift_amount; + let partial_end_prefix = (self.end >> shift_amount) << shift_amount; + let path = nibbles_to_right_padded_b256(path); - (partial_start_prefix..=partial_end_prefix).contains(&path) + (partial_start_prefix..=partial_end_prefix).contains(&path.into()) } } #[cfg(test)] mod tests { - use alloy::hex::FromHex; - use super::*; #[test] fn test_new_random_filter() { - let filter = Filter::new_random_filter(0); - assert_eq!(filter.start_prefix, U256::ZERO); - assert_eq!(filter.end_prefix, U256::MAX); + let filter = Filter::random(0); + assert_eq!(filter.start, U256::ZERO); + assert_eq!(filter.end, U256::MAX); - let filter = Filter::new_random_filter(1); - assert_eq!(filter.start_prefix, U256::ZERO); - assert_eq!(filter.end_prefix, U256::MAX); + let filter = Filter::random(1); + assert_eq!(filter.start, U256::ZERO); + assert_eq!(filter.end, U256::MAX); } #[test] fn test_is_included() { let filter = Filter { - start_prefix: partial_nibble_path_to_right_padded_b256(&[0x1, 0x5, 0x5]).into(), - end_prefix: partial_nibble_path_to_right_padded_b256(&[0x3]).into(), + start: nibbles_to_right_padded_b256(&[0x1, 0x5, 0x5]).into(), + end: nibbles_to_right_padded_b256(&[0x3]).into(), }; - assert!(!filter.is_included(&[0x0])); - assert!(filter.is_included(&[0x1])); - assert!(filter.is_included(&[0x1, 0x5])); - assert!(filter.is_included(&[0x1, 0x5, 0x5])); - assert!(!filter.is_included(&[0x1, 0x5, 0x4])); - assert!(filter.is_included(&[0x2])); - assert!(filter.is_included(&[0x3])); - assert!(!filter.is_included(&[0x4])); + assert!(!filter.contains(&[0x0])); + assert!(filter.contains(&[0x1])); + assert!(filter.contains(&[0x1, 0x5])); + assert!(filter.contains(&[0x1, 0x5, 0x5])); + assert!(!filter.contains(&[0x1, 0x5, 0x4])); + assert!(filter.contains(&[0x2])); + assert!(filter.contains(&[0x3])); + assert!(!filter.contains(&[0x3, 0x0, 0x1])); + assert!(!filter.contains(&[0x4])); } #[test] - fn test_partial_prefix() { - let prefix: &[u8] = &[0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]; - let prefix: U256 = partial_nibble_path_to_right_padded_b256(prefix).into(); + fn test_new() { + let filter = Filter::new(0, 1); + assert_eq!(filter.start, U256::ZERO); + assert_eq!(filter.end, U256::MAX); - assert_eq!( - Filter::partial_prefix(prefix, 256), - B256::from_hex("0x0000000000000000000000000000000000000000000000000000000000000000") - .unwrap() - ); + let filter = Filter::new(0, 2); + assert_eq!(filter.start, U256::ZERO); + assert_eq!(filter.end, U256::MAX / U256::from(2)); - assert_eq!( - Filter::partial_prefix(prefix, 256 - 4), - B256::from_hex("0x1000000000000000000000000000000000000000000000000000000000000000") - .unwrap() - ); + let filter = Filter::new(1, 2); + assert_eq!(filter.start, U256::MAX / U256::from(2) + U256::from(1)); + assert_eq!(filter.end, U256::MAX); + + let filter = Filter::new(0, 3); + assert_eq!(filter.start, U256::ZERO); + assert_eq!(filter.end, U256::MAX / U256::from(3)); + let filter = Filter::new(1, 3); + assert_eq!(filter.start, U256::MAX / U256::from(3) + U256::from(1)); assert_eq!( - Filter::partial_prefix(prefix, 256 - 4 * 4), - B256::from_hex("0x1234000000000000000000000000000000000000000000000000000000000000") - .unwrap() + filter.end, + U256::MAX / U256::from(3) * U256::from(2) + U256::from(1) ); + let filter = Filter::new(2, 3); assert_eq!( - Filter::partial_prefix(prefix, 256 - 5 * 4), - B256::from_hex("0x1234500000000000000000000000000000000000000000000000000000000000") - .unwrap() + filter.start, + U256::MAX / U256::from(3) * U256::from(2) + U256::from(2) ); + assert_eq!(filter.end, U256::MAX); } } diff --git a/trin-execution/src/trie_walker/mod.rs b/trin-execution/src/trie_walker/mod.rs index 84a2979ad..654762f0d 100644 --- a/trin-execution/src/trie_walker/mod.rs +++ b/trin-execution/src/trie_walker/mod.rs @@ -84,7 +84,7 @@ impl TrieWalker { ) -> anyhow::Result<()> { // If we have a filter, we only want to include nodes that are in the filter if let Some(filter) = &self.filter { - if !filter.is_included(&path) { + if !filter.contains(&path) { return Ok(()); } } diff --git a/trin-execution/src/utils.rs b/trin-execution/src/utils.rs index 8012afc5d..332e4b9c6 100644 --- a/trin-execution/src/utils.rs +++ b/trin-execution/src/utils.rs @@ -1,17 +1,5 @@ use alloy::primitives::{keccak256, Address, B256}; -fn compress_nibbles(nibbles: &[u8]) -> Vec { - let mut compressed_nibbles = vec![]; - for i in 0..nibbles.len() { - if i % 2 == 0 { - compressed_nibbles.push(nibbles[i] << 4); - } else { - compressed_nibbles[i / 2] |= nibbles[i]; - } - } - compressed_nibbles -} - pub fn full_nibble_path_to_address_hash(key_path: &[u8]) -> B256 { if key_path.len() != 64 { panic!( @@ -20,11 +8,19 @@ pub fn full_nibble_path_to_address_hash(key_path: &[u8]) -> B256 { ); } - B256::from_slice(&compress_nibbles(key_path)) + nibbles_to_right_padded_b256(key_path) } -pub fn partial_nibble_path_to_right_padded_b256(partial_nibble_path: &[u8]) -> B256 { - B256::right_padding_from(&compress_nibbles(partial_nibble_path)) +pub fn nibbles_to_right_padded_b256(nibbles: &[u8]) -> B256 { + let mut result = B256::ZERO; + for (i, nibble) in nibbles.iter().enumerate() { + if i % 2 == 0 { + result[i / 2] |= nibble << 4; + } else { + result[i / 2] |= nibble; + }; + } + result } pub fn address_to_nibble_path(address: Address) -> Vec { @@ -41,8 +37,7 @@ mod tests { use revm_primitives::{keccak256, Address, B256}; use crate::utils::{ - address_to_nibble_path, full_nibble_path_to_address_hash, - partial_nibble_path_to_right_padded_b256, + address_to_nibble_path, full_nibble_path_to_address_hash, nibbles_to_right_padded_b256, }; #[test] @@ -68,7 +63,7 @@ mod tests { #[test] fn test_partial_nibble_path_to_right_padded_b256() { let partial_nibble_path = vec![0xf, 0xf, 0x0, 0x1, 0x0, 0x2, 0x0, 0x3]; - let partial_path = partial_nibble_path_to_right_padded_b256(&partial_nibble_path); + let partial_path = nibbles_to_right_padded_b256(&partial_nibble_path); assert_eq!( partial_path, B256::from_hex("0xff01020300000000000000000000000000000000000000000000000000000000")