diff --git a/Cargo.lock b/Cargo.lock index 8666a4693..cca61bf21 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6874,6 +6874,7 @@ dependencies = [ "lazy_static", "parking_lot 0.11.2", "prometheus_exporter", + "rand", "rayon", "reqwest", "revm", diff --git a/portal-bridge/src/bridge/state.rs b/portal-bridge/src/bridge/state.rs index acad77145..900fefddd 100644 --- a/portal-bridge/src/bridge/state.rs +++ b/portal-bridge/src/bridge/state.rs @@ -360,7 +360,7 @@ impl StateBridge { let root_hash = evm_db.trie.lock().root_hash()?; let mut content_idx = 0; - let state_walker = TrieWalker::new(root_hash, evm_db.trie.lock().db.clone())?; + let state_walker = TrieWalker::new(root_hash, evm_db.trie.lock().db.clone(), None)?; for account_proof in state_walker { // gossip the account self.gossip_account(&account_proof, block_hash, content_idx) @@ -426,7 +426,7 @@ impl StateBridge { let account_db = AccountDB::new(address_hash, evm_db.db.clone()); let trie = EthTrie::from(Arc::new(account_db), account.storage_root)?.db; - let storage_walker = TrieWalker::new(account.storage_root, trie)?; + let storage_walker = TrieWalker::new(account.storage_root, trie, None)?; for storage_proof in storage_walker { self.gossip_storage( &account_proof, diff --git a/trin-execution/Cargo.toml b/trin-execution/Cargo.toml index 1989dd052..f24c3570b 100644 --- a/trin-execution/Cargo.toml +++ b/trin-execution/Cargo.toml @@ -27,6 +27,7 @@ jsonrpsee = { workspace = true, features = ["async-client", "client", "macros", lazy_static.workspace = true parking_lot.workspace = true prometheus_exporter.workspace = true +rand.workspace = true rayon = "1.10.0" reqwest = { workspace = true, features = ["stream"] } revm.workspace = true diff --git a/trin-execution/src/trie_walker/filter.rs b/trin-execution/src/trie_walker/filter.rs new file mode 100644 index 000000000..073a04ca0 --- /dev/null +++ b/trin-execution/src/trie_walker/filter.rs @@ -0,0 +1,72 @@ +use alloy::primitives::{B256, U256}; +use rand::{thread_rng, Rng}; + +use crate::utils::partial_nibble_path_to_right_padded_b256; + +#[derive(Debug, Clone)] +pub struct Filter { + start_prefix: B256, + end_prefix: B256, +} + +impl Filter { + pub fn new_random_filter(slice_count: u16) -> Self { + // if slice_count is 0 or 1, we want to include the whole trie + if slice_count == 0 || slice_count == 1 { + return Self { + start_prefix: B256::ZERO, + end_prefix: B256::from(U256::MAX), + }; + } + + let slice_size = U256::MAX / U256::from(slice_count); + let random_slice_index = thread_rng().gen_range(0..slice_count); + + let start_prefix = U256::from(random_slice_index) * slice_size; + let end_prefix = if random_slice_index == slice_count - 1 { + U256::MAX + } else { + start_prefix + slice_size - U256::from(1) + }; + + Self { + start_prefix: B256::from(start_prefix), + end_prefix: B256::from(end_prefix), + } + } + + pub fn is_included(&self, path: &[u8]) -> bool { + let path = partial_nibble_path_to_right_padded_b256(path); + (self.start_prefix..=self.end_prefix).contains(&path) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_new_random_filter() { + let filter = Filter::new_random_filter(0); + assert_eq!(filter.start_prefix, B256::ZERO); + assert_eq!(filter.end_prefix, B256::from(U256::MAX)); + + let filter = Filter::new_random_filter(1); + assert_eq!(filter.start_prefix, B256::ZERO); + assert_eq!(filter.end_prefix, B256::from(U256::MAX)); + } + + #[test] + fn test_is_included() { + let filter = Filter { + start_prefix: partial_nibble_path_to_right_padded_b256(&[0x1]), + end_prefix: partial_nibble_path_to_right_padded_b256(&[0x3]), + }; + + assert!(!filter.is_included(&[0x00])); + assert!(filter.is_included(&[0x01])); + assert!(filter.is_included(&[0x02])); + assert!(filter.is_included(&[0x03])); + assert!(!filter.is_included(&[0x04])); + } +} diff --git a/trin-execution/src/trie_walker/mod.rs b/trin-execution/src/trie_walker/mod.rs index e8e0dad57..84a2979ad 100644 --- a/trin-execution/src/trie_walker/mod.rs +++ b/trin-execution/src/trie_walker/mod.rs @@ -1,4 +1,5 @@ pub mod db; +pub mod filter; use std::sync::Arc; @@ -6,6 +7,7 @@ use alloy::primitives::{Bytes, B256}; use anyhow::{anyhow, Ok}; use db::TrieWalkerDb; use eth_trie::{decode_node, node::Node}; +use filter::Filter; use crate::types::trie_proof::TrieProof; @@ -21,10 +23,13 @@ pub struct TrieWalker { is_partial_trie: bool, trie: Arc, stack: Vec, + + /// You can filter what slice of the trie you want to walk + filter: Option, } impl TrieWalker { - pub fn new(root_hash: B256, trie: Arc) -> anyhow::Result { + pub fn new(root_hash: B256, trie: Arc, filter: Option) -> anyhow::Result { let root_node_trie = match trie.get(root_hash.as_slice())? { Some(root_node_trie) => root_node_trie, None => return Err(anyhow!("Root node not found in the database")), @@ -38,6 +43,7 @@ impl TrieWalker { is_partial_trie: false, trie, stack: vec![root_proof], + filter, }) } @@ -52,6 +58,7 @@ impl TrieWalker { is_partial_trie: true, trie: Arc::new(trie), stack: vec![], + filter: None, }); } }; @@ -65,6 +72,7 @@ impl TrieWalker { is_partial_trie: true, trie: Arc::new(trie), stack: vec![root_proof], + filter: None, }) } @@ -74,6 +82,13 @@ impl TrieWalker { partial_proof: Vec, path: Vec, ) -> anyhow::Result<()> { + // If we have a filter, we only want to include nodes that are in the filter + if let Some(filter) = &self.filter { + if !filter.is_included(&path) { + return Ok(()); + } + } + // We only need to process hash nodes, because if the node isn't a hash node then none of // its children is if let Node::Hash(hash) = node { @@ -191,7 +206,7 @@ mod tests { } let root_hash = trie.root_hash().unwrap(); - let walker = TrieWalker::new(root_hash, trie.db.clone()).unwrap(); + let walker = TrieWalker::new(root_hash, trie.db.clone(), None).unwrap(); let mut count = 0; let mut leaf_count = 0; for proof in walker { diff --git a/trin-execution/src/utils.rs b/trin-execution/src/utils.rs index 91966abde..8012afc5d 100644 --- a/trin-execution/src/utils.rs +++ b/trin-execution/src/utils.rs @@ -1,5 +1,17 @@ use alloy::primitives::{keccak256, Address, B256}; +fn compress_nibbles(nibbles: &[u8]) -> Vec { + let mut compressed_nibbles = vec![]; + for i in 0..nibbles.len() { + if i % 2 == 0 { + compressed_nibbles.push(nibbles[i] << 4); + } else { + compressed_nibbles[i / 2] |= nibbles[i]; + } + } + compressed_nibbles +} + pub fn full_nibble_path_to_address_hash(key_path: &[u8]) -> B256 { if key_path.len() != 64 { panic!( @@ -8,15 +20,11 @@ pub fn full_nibble_path_to_address_hash(key_path: &[u8]) -> B256 { ); } - let mut raw_address_hash = vec![]; - for i in 0..key_path.len() { - if i % 2 == 0 { - raw_address_hash.push(key_path[i] << 4); - } else { - raw_address_hash[i / 2] |= key_path[i]; - } - } - B256::from_slice(&raw_address_hash) + B256::from_slice(&compress_nibbles(key_path)) +} + +pub fn partial_nibble_path_to_right_padded_b256(partial_nibble_path: &[u8]) -> B256 { + B256::right_padding_from(&compress_nibbles(partial_nibble_path)) } pub fn address_to_nibble_path(address: Address) -> Vec { @@ -28,10 +36,14 @@ pub fn address_to_nibble_path(address: Address) -> Vec { #[cfg(test)] mod tests { + use alloy::hex::FromHex; use eth_trie::nibbles::Nibbles as EthNibbles; - use revm_primitives::{keccak256, Address}; + use revm_primitives::{keccak256, Address, B256}; - use crate::utils::{address_to_nibble_path, full_nibble_path_to_address_hash}; + use crate::utils::{ + address_to_nibble_path, full_nibble_path_to_address_hash, + partial_nibble_path_to_right_padded_b256, + }; #[test] fn test_eth_trie_and_ethportalapi_nibbles() { @@ -52,4 +64,15 @@ mod tests { let generated_address_hash = full_nibble_path_to_address_hash(&path); assert_eq!(address_hash, generated_address_hash); } + + #[test] + fn test_partial_nibble_path_to_right_padded_b256() { + let partial_nibble_path = vec![0xf, 0xf, 0x0, 0x1, 0x0, 0x2, 0x0, 0x3]; + let partial_path = partial_nibble_path_to_right_padded_b256(&partial_nibble_path); + assert_eq!( + partial_path, + B256::from_hex("0xff01020300000000000000000000000000000000000000000000000000000000") + .unwrap() + ); + } }