diff --git a/mpt_trie/src/debug_tools/common.rs b/mpt_trie/src/debug_tools/common.rs index 5eba83624..cb38d5c47 100644 --- a/mpt_trie/src/debug_tools/common.rs +++ b/mpt_trie/src/debug_tools/common.rs @@ -1,66 +1,9 @@ //! Common utilities for the debugging tools. -use std::fmt::{self, Display}; - use crate::{ - nibbles::{Nibble, Nibbles}, + nibbles::Nibbles, partial_trie::{Node, PartialTrie}, - utils::TrieNodeType, }; -#[derive(Clone, Debug, Eq, Hash, PartialEq)] -pub(super) enum PathSegment { - Empty, - Hash, - Branch(Nibble), - Extension(Nibbles), - Leaf(Nibbles), -} - -impl Display for PathSegment { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - PathSegment::Empty => write!(f, "Empty"), - PathSegment::Hash => write!(f, "Hash"), - PathSegment::Branch(nib) => write!(f, "Branch({})", nib), - PathSegment::Extension(nibs) => write!(f, "Extension({})", nibs), - PathSegment::Leaf(nibs) => write!(f, "Leaf({})", nibs), - } - } -} - -impl PathSegment { - pub(super) fn node_type(&self) -> TrieNodeType { - match self { - PathSegment::Empty => TrieNodeType::Empty, - PathSegment::Hash => TrieNodeType::Hash, - PathSegment::Branch(_) => TrieNodeType::Branch, - PathSegment::Extension(_) => TrieNodeType::Extension, - PathSegment::Leaf(_) => TrieNodeType::Leaf, - } - } - - pub(super) fn get_key_piece_from_seg_if_present(&self) -> Option { - match self { - PathSegment::Empty | PathSegment::Hash => None, - PathSegment::Branch(nib) => Some(Nibbles::from_nibble(*nib)), - PathSegment::Extension(nibs) | PathSegment::Leaf(nibs) => Some(*nibs), - } - } -} - -pub(super) fn get_segment_from_node_and_key_piece( - n: &Node, - k_piece: &Nibbles, -) -> PathSegment { - match TrieNodeType::from(n) { - TrieNodeType::Empty => PathSegment::Empty, - TrieNodeType::Hash => PathSegment::Hash, - TrieNodeType::Branch => PathSegment::Branch(k_piece.get_nibble(0)), - TrieNodeType::Extension => PathSegment::Extension(*k_piece), - TrieNodeType::Leaf => 
PathSegment::Leaf(*k_piece), - } -} - /// Get the key piece from the given node if applicable. /// /// Note that there is no specific [`Nibble`] associated with a branch like @@ -87,43 +30,3 @@ pub(super) fn get_key_piece_from_node(n: &Node) -> Nibbles { Node::Extension { nibbles, child: _ } | Node::Leaf { nibbles, value: _ } => *nibbles, } } - -#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)] -/// A vector of path segments representing a path in the trie. -pub struct NodePath(pub(super) Vec); - -impl NodePath { - pub(super) fn dup_and_append(&self, seg: PathSegment) -> Self { - let mut duped_vec = self.0.clone(); - duped_vec.push(seg); - - Self(duped_vec) - } - - pub(super) fn append(&mut self, seg: PathSegment) { - self.0.push(seg); - } - - fn write_elem(f: &mut fmt::Formatter<'_>, seg: &PathSegment) -> fmt::Result { - write!(f, "{}", seg) - } -} - -impl Display for NodePath { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let num_elems = self.0.len(); - - // For everything but the last elem. - for seg in self.0.iter().take(num_elems.saturating_sub(1)) { - Self::write_elem(f, seg)?; - write!(f, " --> ")?; - } - - // Avoid the extra `-->` for the last elem. - if let Some(seg) = self.0.last() { - Self::write_elem(f, seg)?; - } - - Ok(()) - } -} diff --git a/mpt_trie/src/debug_tools/diff.rs b/mpt_trie/src/debug_tools/diff.rs index 3ed857492..d6271cfb1 100644 --- a/mpt_trie/src/debug_tools/diff.rs +++ b/mpt_trie/src/debug_tools/diff.rs @@ -30,7 +30,8 @@ use std::{fmt::Display, ops::Deref}; use ethereum_types::H256; -use super::common::{get_key_piece_from_node, get_segment_from_node_and_key_piece, NodePath}; +use super::common::get_key_piece_from_node; +use crate::utils::{get_segment_from_node_and_key_piece, TriePath}; use crate::{ nibbles::Nibbles, partial_trie::{HashedPartialTrie, Node, PartialTrie}, @@ -83,7 +84,7 @@ pub struct DiffPoint { /// The depth of the point in both tries. pub depth: usize, /// The path of the point in both tries. 
- pub path: NodePath, + pub path: TriePath, /// The node key in both tries. pub key: Nibbles, /// The node info in the first trie. @@ -97,7 +98,7 @@ impl DiffPoint { child_a: &HashedPartialTrie, child_b: &HashedPartialTrie, parent_k: Nibbles, - path: NodePath, + path: TriePath, ) -> Self { let a_key = parent_k.merge_nibbles(&get_key_piece_from_node(child_a)); let b_key = parent_k.merge_nibbles(&get_key_piece_from_node(child_b)); @@ -222,7 +223,7 @@ impl DepthNodeDiffState { parent_k: &Nibbles, child_a: &HashedPartialTrie, child_b: &HashedPartialTrie, - path: NodePath, + path: TriePath, ) { if field .as_ref() @@ -242,7 +243,7 @@ struct DepthDiffPerCallState<'a> { curr_depth: usize, // Horribly inefficient, but these are debug tools, so I think we get a pass. - curr_path: NodePath, + curr_path: TriePath, } impl<'a> DepthDiffPerCallState<'a> { @@ -259,7 +260,7 @@ impl<'a> DepthDiffPerCallState<'a> { b, curr_key, curr_depth, - curr_path: NodePath::default(), + curr_path: TriePath::default(), } } @@ -411,7 +412,7 @@ fn get_value_from_node(n: &Node) -> Option<&Vec> { #[cfg(test)] mod tests { - use super::{create_diff_between_tries, DiffPoint, NodeInfo, NodePath}; + use super::{create_diff_between_tries, DiffPoint, NodeInfo, TriePath}; use crate::{ nibbles::Nibbles, partial_trie::{HashedPartialTrie, PartialTrie}, @@ -447,7 +448,7 @@ mod tests { let expected = DiffPoint { depth: 0, - path: NodePath(vec![]), + path: TriePath(vec![]), key: Nibbles::default(), a_info: expected_a, b_info: expected_b, diff --git a/mpt_trie/src/debug_tools/query.rs b/mpt_trie/src/debug_tools/query.rs index 8ce9ddfad..dd7d3b654 100644 --- a/mpt_trie/src/debug_tools/query.rs +++ b/mpt_trie/src/debug_tools/query.rs @@ -5,13 +5,11 @@ use std::fmt::{self, Display}; use ethereum_types::H256; -use super::common::{ - get_key_piece_from_node_pulling_from_key_for_branches, get_segment_from_node_and_key_piece, - NodePath, PathSegment, -}; +use 
super::common::get_key_piece_from_node_pulling_from_key_for_branches; use crate::{ nibbles::Nibbles, partial_trie::{Node, PartialTrie, WrappedNode}, + utils::{get_segment_from_node_and_key_piece, TriePath, TrieSegment}, }; /// Params controlling how much information is reported in the query output. @@ -159,7 +157,9 @@ fn count_non_empty_branch_children_from_mask(mask: u16) -> usize { /// of the path used for searching for a key in the trie. pub struct DebugQueryOutput { k: Nibbles, - node_path: NodePath, + + /// The nodes hit during the query. + pub node_path: TriePath, extra_node_info: Vec>, node_found: bool, params: DebugQueryParams, @@ -199,7 +199,7 @@ impl DebugQueryOutput { fn new(k: Nibbles, params: DebugQueryParams) -> Self { Self { k, - node_path: NodePath::default(), + node_path: TriePath::default(), extra_node_info: Vec::default(), node_found: false, params, @@ -219,7 +219,7 @@ impl DebugQueryOutput { // TODO: Make the output easier to read... fn fmt_node_based_on_debug_params( f: &mut fmt::Formatter<'_>, - seg: &PathSegment, + seg: &TrieSegment, extra_seg_info: &Option, params: &DebugQueryParams, ) -> fmt::Result { diff --git a/mpt_trie/src/lib.rs b/mpt_trie/src/lib.rs index c359e0430..85988e4c4 100644 --- a/mpt_trie/src/lib.rs +++ b/mpt_trie/src/lib.rs @@ -17,10 +17,11 @@ pub mod nibbles; pub mod partial_trie; +pub mod special_query; mod trie_hashing; pub mod trie_ops; pub mod trie_subsets; -mod utils; +pub mod utils; #[cfg(feature = "trie_debug")] pub mod debug_tools; diff --git a/mpt_trie/src/nibbles.rs b/mpt_trie/src/nibbles.rs index 94eb8cafc..f5d63a043 100644 --- a/mpt_trie/src/nibbles.rs +++ b/mpt_trie/src/nibbles.rs @@ -363,11 +363,11 @@ impl Nibbles { /// Appends `Nibbles` to the front. /// /// # Panics - /// Panics if appending the `Nibble` causes an overflow (total nibbles > + /// Panics if appending the `Nibbles` causes an overflow (total nibbles > /// 64). 
pub fn push_nibbles_front(&mut self, n: &Self) { let new_count = self.count + n.count; - assert!(new_count <= 64); + self.nibbles_append_safety_asserts(new_count); let shift_amt = 4 * self.count; @@ -375,6 +375,21 @@ impl Nibbles { self.packed |= n.packed << shift_amt; } + /// Appends `Nibbles` to the back. + /// + /// # Panics + /// Panics if appending the `Nibbles` causes an overflow (total nibbles > + /// 64). + pub fn push_nibbles_back(&mut self, n: &Self) { + let new_count = self.count + n.count; + self.nibbles_append_safety_asserts(new_count); + + let shift_amt = 4 * n.count; + + self.count = new_count; + self.packed = (self.packed << shift_amt) | n.packed; + } + /// Gets the nibbles at the range specified, where `0` is the next nibble. /// /// # Panics @@ -765,6 +780,10 @@ impl Nibbles { assert!(n < 16); } + fn nibbles_append_safety_asserts(&self, new_count: usize) { + assert!(new_count <= 64); + } + // TODO: REMOVE BEFORE NEXT CRATE VERSION! THIS IS A TEMP HACK! /// Converts to u256 returning an error if not possible. 
pub fn try_into_u256(&self) -> Result { @@ -787,6 +806,9 @@ mod tests { use super::{Nibble, Nibbles, ToNibbles}; use crate::nibbles::FromHexPrefixError; + const LONG_ZERO_NIBS_STR_LEN_63: &str = + "0x000000000000000000000000000000000000000000000000000000000000000"; + #[test] fn get_nibble_works() { let n = Nibbles::from(0x1234); @@ -898,6 +920,69 @@ mod tests { assert_eq!(res, expected_resulting_nibbles); } + #[test] + fn push_nibble_front_works() { + test_and_assert_nib_push_func(Nibbles::default(), 0x1, |n| n.push_nibble_front(0x1)); + test_and_assert_nib_push_func(0x1, 0x21, |n| n.push_nibble_front(0x2)); + test_and_assert_nib_push_func( + Nibbles::from_str(LONG_ZERO_NIBS_STR_LEN_63).unwrap(), + Nibbles::from_str("0x1000000000000000000000000000000000000000000000000000000000000000") + .unwrap(), + |n| n.push_nibble_front(0x1), + ); + } + + #[test] + fn push_nibble_back_works() { + test_and_assert_nib_push_func(Nibbles::default(), 0x1, |n| n.push_nibble_back(0x1)); + test_and_assert_nib_push_func(0x1, 0x12, |n| n.push_nibble_back(0x2)); + test_and_assert_nib_push_func( + Nibbles::from_str(LONG_ZERO_NIBS_STR_LEN_63).unwrap(), + Nibbles::from_str("0x0000000000000000000000000000000000000000000000000000000000000001") + .unwrap(), + |n| n.push_nibble_back(0x1), + ); + } + + #[test] + fn push_nibbles_front_works() { + test_and_assert_nib_push_func(Nibbles::default(), 0x1234, |n| { + n.push_nibbles_front(&0x1234.into()) + }); + test_and_assert_nib_push_func(0x1234, 0x5671234, |n| n.push_nibbles_front(&0x567.into())); + test_and_assert_nib_push_func( + Nibbles::from_str(LONG_ZERO_NIBS_STR_LEN_63).unwrap(), + Nibbles::from_str("0x1000000000000000000000000000000000000000000000000000000000000000") + .unwrap(), + |n| n.push_nibbles_front(&0x1.into()), + ); + } + + #[test] + fn push_nibbles_back_works() { + test_and_assert_nib_push_func(Nibbles::default(), 0x1234, |n| { + n.push_nibbles_back(&0x1234.into()) + }); + test_and_assert_nib_push_func(0x1234, 0x1234567, |n| 
n.push_nibbles_back(&0x567.into())); + test_and_assert_nib_push_func( + Nibbles::from_str(LONG_ZERO_NIBS_STR_LEN_63).unwrap(), + Nibbles::from_str("0x0000000000000000000000000000000000000000000000000000000000000001") + .unwrap(), + |n| n.push_nibbles_back(&0x1.into()), + ); + } + + fn test_and_assert_nib_push_func, E: Into>( + starting_nibs: S, + expected: E, + f: F, + ) { + let mut nibs = starting_nibs.into(); + (f)(&mut nibs); + + assert_eq!(nibs, expected.into()); + } + #[test] fn get_next_nibbles_works() { let n: Nibbles = 0x1234.into(); @@ -1179,7 +1264,7 @@ mod tests { fn nibbles_from_h256_works() { assert_eq!( format!("{:x}", Nibbles::from_h256_be(H256::from_low_u64_be(0))), - "0x0000000000000000000000000000000000000000000000000000000000000000" + "0x0000000000000000000000000000000000000000000000000000000000000000", ); assert_eq!( format!("{:x}", Nibbles::from_h256_be(H256::from_low_u64_be(2048))), diff --git a/mpt_trie/src/special_query.rs b/mpt_trie/src/special_query.rs new file mode 100644 index 000000000..503331aa0 --- /dev/null +++ b/mpt_trie/src/special_query.rs @@ -0,0 +1,156 @@ +//! Specialized queries that users of the library may need that require +//! knowledge of the private internal trie state. + +use crate::{ + nibbles::Nibbles, + partial_trie::{Node, PartialTrie, WrappedNode}, + utils::TrieSegment, +}; + +/// An iterator for a trie query. Note that this iterator is lazy. +#[derive(Debug)] +pub struct TriePathIter { + /// The next node in the trie to query with the remaining key. + curr_node: WrappedNode, + + /// The remaining part of the key as we traverse down the trie. + curr_key: Nibbles, + + // Although wrapping `curr_node` in an option might be more "Rust like", the logic is a lot + // cleaner with a bool. 
+ terminated: bool, +} + +impl Iterator for TriePathIter { + type Item = TrieSegment; + + fn next(&mut self) -> Option { + if self.terminated { + return None; + } + + match self.curr_node.as_ref() { + Node::Empty => { + self.terminated = true; + Some(TrieSegment::Empty) + } + Node::Hash(_) => { + self.terminated = true; + Some(TrieSegment::Hash) + } + Node::Branch { children, .. } => { + // Our query key has ended. Stop here. + if self.curr_key.is_empty() { + self.terminated = true; + return None; + } + + let nib = self.curr_key.pop_next_nibble_front(); + self.curr_node = children[nib as usize].clone(); + + Some(TrieSegment::Branch(nib)) + } + Node::Extension { nibbles, child } => { + match self + .curr_key + .nibbles_are_identical_up_to_smallest_count(nibbles) + { + false => { + // Only a partial match. Stop. + self.terminated = true; + None + } + true => { + pop_nibbles_clamped(&mut self.curr_key, nibbles.count); + let res = Some(TrieSegment::Extension(*nibbles)); + self.curr_node = child.clone(); + + res + } + } + } + Node::Leaf { nibbles, .. } => { + self.terminated = true; + + match self.curr_key == *nibbles { + false => None, + true => Some(TrieSegment::Leaf(*nibbles)), + } + } + } + } +} + +/// Attempts to pop `n` nibbles from the given [`Nibbles`] and "clamp" the +/// nibbles popped by not popping more nibbles than there are. +fn pop_nibbles_clamped(nibbles: &mut Nibbles, n: usize) -> Nibbles { + let n_nibs_to_pop = nibbles.count.min(n); + nibbles.pop_nibbles_front(n_nibs_to_pop) +} + +/// Returns all nodes in the trie that are traversed given a query (key). +/// +/// Note that if the key does not match the entire key of a node (eg. the +/// remaining key is `0x34` but the next key is a leaf with the key `0x3456`), +/// then the leaf will not appear in the query output. 
+pub fn path_for_query(trie: &Node, k: K) -> TriePathIter +where + K: Into, +{ + TriePathIter { + curr_node: trie.clone().into(), + curr_key: k.into(), + terminated: false, + } +} + +#[cfg(test)] +mod test { + use std::str::FromStr; + + use super::path_for_query; + use crate::{nibbles::Nibbles, testing_utils::handmade_trie_1, utils::TrieSegment}; + + #[test] + fn query_iter_works() { + let (trie, ks) = handmade_trie_1(); + + // ks --> vec![0x1234, 0x1324, 0x132400005_u64, 0x2001, 0x2002]; + let res = vec![ + vec![ + TrieSegment::Branch(1), + TrieSegment::Branch(2), + TrieSegment::Leaf(0x34.into()), + ], + vec![ + TrieSegment::Branch(1), + TrieSegment::Branch(3), + TrieSegment::Extension(0x24.into()), + ], + vec![ + TrieSegment::Branch(1), + TrieSegment::Branch(3), + TrieSegment::Extension(0x24.into()), + TrieSegment::Branch(0), + TrieSegment::Leaf(Nibbles::from_str("0x0005").unwrap()), + ], + vec![ + TrieSegment::Branch(2), + TrieSegment::Extension(Nibbles::from_str("0x00").unwrap()), + TrieSegment::Branch(0x1), + TrieSegment::Leaf(Nibbles::default()), + ], + vec![ + TrieSegment::Branch(2), + TrieSegment::Extension(Nibbles::from_str("0x00").unwrap()), + TrieSegment::Branch(0x2), + TrieSegment::Leaf(Nibbles::default()), + ], + ]; + + for (q, expected) in ks.into_iter().zip(res.into_iter()) { + let res: Vec<_> = path_for_query(&trie.node, q).collect(); + assert_eq!(res, expected) + } + } +} diff --git a/mpt_trie/src/trie_subsets.rs b/mpt_trie/src/trie_subsets.rs index 5a540d899..974b92afc 100644 --- a/mpt_trie/src/trie_subsets.rs +++ b/mpt_trie/src/trie_subsets.rs @@ -8,6 +8,7 @@ use std::sync::Arc; use ethereum_types::H256; +use log::trace; use thiserror::Error; use crate::{ @@ -206,6 +207,7 @@ impl TrackedNodeInfo { } } +// TODO: Make this interface also work with &[ ... ]... /// Create a [`PartialTrie`] subset from a base trie given an iterator of keys /// of nodes that may or may not exist in the trie. 
All nodes traversed by the /// keys will not be hashed out in the trie subset. If the key does not exist in @@ -221,6 +223,7 @@ where create_trie_subset_intern(&mut tracked_trie, keys_involved.into_iter()) } +// TODO: Make this interface also work with &[ ... ]... /// Create [`PartialTrie`] subsets from a given base `PartialTrie` given a /// iterator of keys per subset needed. See [`create_trie_subset`] for more /// info. @@ -259,42 +262,68 @@ where Ok(create_partial_trie_subset_from_tracked_trie(tracked_trie)) } +/// For a given key, mark every node that we encounter that is part of the key. +/// Note that this means non-existent keys passed into this function will mark +/// nodes to not be hashed that are part of the given key. For example: +/// - Relevant nodes in trie: [B(0x), B(0x1), L(0x123)] +/// - For the key `0x1`, the marked nodes would be [B(0x), B(0x1)]. +/// - For the key `0x12`, the marked nodes still would be [B(0x), B(0x1)]. +/// - For the key `0x123`, the marked nodes would be [B(0x), B(0x1), L(0x123)]. fn mark_nodes_that_are_needed( trie: &mut TrackedNode, curr_nibbles: &mut Nibbles, ) -> SubsetTrieResult<()> { - trie.info.touched = true; + trace!( + "Sub-trie marking at {:x}, (type: {})", + curr_nibbles, + TrieNodeType::from(trie.info.underlying_node.deref()) + ); match &mut trie.node { - TrackedNodeIntern::Empty => Ok(()), + TrackedNodeIntern::Empty => { + trie.info.touched = true; + } TrackedNodeIntern::Hash => match curr_nibbles.is_empty() { - false => Err(SubsetTrieError::UnexpectedKey( - *curr_nibbles, - format!("{:?}", trie), - )), - true => Ok(()), + false => { + return Err(SubsetTrieError::UnexpectedKey( + *curr_nibbles, + format!("{:?}", trie), + )) + } + true => { + trie.info.touched = true; + } }, // Note: If we end up supporting non-fixed sized keys, then we need to also check value. TrackedNodeIntern::Branch(children) => { + trie.info.touched = true; + // Check against branch value. 
if curr_nibbles.is_empty() { return Ok(()); } let nib = curr_nibbles.pop_next_nibble_front(); - mark_nodes_that_are_needed(&mut children[nib as usize], curr_nibbles) + return mark_nodes_that_are_needed(&mut children[nib as usize], curr_nibbles); } TrackedNodeIntern::Extension(child) => { let nibbles = trie.info.get_nibbles_expected(); let r = curr_nibbles.pop_nibbles_front(nibbles.count); - match r.nibbles_are_identical_up_to_smallest_count(nibbles) { - false => Ok(()), - true => mark_nodes_that_are_needed(child, curr_nibbles), + if r.nibbles_are_identical_up_to_smallest_count(nibbles) { + trie.info.touched = true; + return mark_nodes_that_are_needed(child, curr_nibbles); + } + } + TrackedNodeIntern::Leaf => { + let (k, _) = trie.info.get_leaf_nibbles_and_value_expected(); + if k == curr_nibbles { + trie.info.touched = true; } } - TrackedNodeIntern::Leaf => Ok(()), } + + Ok(()) } fn create_partial_trie_subset_from_tracked_trie( @@ -333,10 +362,14 @@ fn create_partial_trie_subset_from_tracked_trie( fn reset_tracked_trie_state(tracked_node: &mut TrackedNode) { match tracked_node.node { - TrackedNodeIntern::Branch(ref mut children) => { - children.iter_mut().for_each(|c| c.info.reset()) + TrackedNodeIntern::Branch(ref mut children) => children.iter_mut().for_each(|c| { + c.info.reset(); + reset_tracked_trie_state(c); + }), + TrackedNodeIntern::Extension(ref mut child) => { + child.info.reset(); + reset_tracked_trie_state(child); } - TrackedNodeIntern::Extension(ref mut child) => child.info.reset(), TrackedNodeIntern::Empty | TrackedNodeIntern::Hash | TrackedNodeIntern::Leaf => { tracked_node.info.reset() } @@ -345,17 +378,20 @@ fn reset_tracked_trie_state(tracked_node: &mut TrackedNode) { #[cfg(test)] mod tests { - use std::{collections::HashSet, iter::once}; + use std::{ + collections::{HashMap, HashSet}, + iter::once, + }; use ethereum_types::H256; use super::{create_trie_subset, create_trie_subsets}; use crate::{ nibbles::Nibbles, - 
partial_trie::{HashedPartialTrie, Node, PartialTrie}, + partial_trie::{Node, PartialTrie}, testing_utils::{ - create_trie_with_large_entry_nodes, generate_n_random_fixed_trie_value_entries, - handmade_trie_1, TrieType, + common_setup, create_trie_with_large_entry_nodes, + generate_n_random_fixed_trie_value_entries, handmade_trie_1, TrieType, }, trie_ops::ValOrHash, utils::TrieNodeType, @@ -386,36 +422,56 @@ mod tests { } } + fn get_all_nodes_in_trie(trie: &TrieType) -> Vec { + get_nodes_in_trie_intern(trie, false) + } + fn get_all_non_empty_and_hash_nodes_in_trie(trie: &TrieType) -> Vec { + get_nodes_in_trie_intern(trie, true) + } + + fn get_nodes_in_trie_intern( + trie: &TrieType, + return_on_empty_or_hash: bool, + ) -> Vec { let mut nodes = Vec::new(); - get_all_non_empty_and_non_hash_nodes_in_trie_intern(trie, Nibbles::default(), &mut nodes); + get_nodes_in_trie_intern_rec( + trie, + Nibbles::default(), + &mut nodes, + return_on_empty_or_hash, + ); nodes } - fn get_all_non_empty_and_non_hash_nodes_in_trie_intern( + fn get_nodes_in_trie_intern_rec( trie: &TrieType, mut curr_nibbles: Nibbles, nodes: &mut Vec, + return_on_empty_or_hash: bool, ) { match &trie.node { - Node::Empty | Node::Hash(_) => return, + Node::Empty | Node::Hash(_) => match return_on_empty_or_hash { + false => (), + true => return, + }, Node::Branch { children, .. } => { for (i, c) in children.iter().enumerate() { - get_all_non_empty_and_non_hash_nodes_in_trie_intern( + get_nodes_in_trie_intern_rec( c, curr_nibbles.merge_nibble(i as u8), nodes, + return_on_empty_or_hash, ) } } - Node::Extension { nibbles, child } => { - get_all_non_empty_and_non_hash_nodes_in_trie_intern( - child, - curr_nibbles.merge_nibbles(nibbles), - nodes, - ) - } + Node::Extension { nibbles, child } => get_nodes_in_trie_intern_rec( + child, + curr_nibbles.merge_nibbles(nibbles), + nodes, + return_on_empty_or_hash, + ), Node::Leaf { nibbles, .. 
} => curr_nibbles = curr_nibbles.merge_nibbles(nibbles), }; @@ -430,6 +486,8 @@ mod tests { #[test] fn empty_trie_does_not_return_err_on_query() { + common_setup(); + let trie = TrieType::default(); let nibbles: Nibbles = 0x1234.into(); let res = create_trie_subset(&trie, once(nibbles)); @@ -439,6 +497,8 @@ mod tests { #[test] fn non_existent_key_does_not_return_err() { + common_setup(); + let mut trie = TrieType::default(); trie.insert(0x1234, vec![0, 1, 2]); let res = create_trie_subset(&trie, once(0x5678)); @@ -448,7 +508,9 @@ mod tests { #[test] fn encountering_a_hash_node_returns_err() { - let trie = HashedPartialTrie::new(Node::Hash(H256::zero())); + common_setup(); + + let trie = TrieType::new(Node::Hash(H256::zero())); let res = create_trie_subset(&trie, once(0x1234)); assert!(res.is_err()) @@ -456,6 +518,8 @@ mod tests { #[test] fn single_node_trie_is_queryable() { + common_setup(); + let mut trie = TrieType::default(); trie.insert(0x1234, vec![0, 1, 2]); let trie_subset = create_trie_subset(&trie, once(0x1234)).unwrap(); @@ -465,6 +529,8 @@ mod tests { #[test] fn multi_node_trie_returns_proper_subset() { + common_setup(); + let trie = create_trie_with_large_entry_nodes(&[0x1234, 0x56, 0x12345_u64]); let trie_subset = create_trie_subset(&trie, vec![0x1234, 0x56]).unwrap(); @@ -477,6 +543,8 @@ mod tests { #[test] fn intermediate_nodes_are_included_in_subset() { + common_setup(); + let (trie, ks_nibbles) = handmade_trie_1(); let trie_subset_all = create_trie_subset(&trie, ks_nibbles.iter().cloned()).unwrap(); @@ -573,8 +641,55 @@ mod tests { ))); } + fn assert_nodes_are_leaf_nodes, I: IntoIterator>( + trie: &TrieType, + keys: I, + ) { + assert_keys_point_to_nodes_of_type( + trie, + keys.into_iter().map(|k| (k.into(), TrieNodeType::Leaf)), + ) + } + + fn assert_nodes_are_hash_nodes, I: IntoIterator>( + trie: &TrieType, + keys: I, + ) { + assert_keys_point_to_nodes_of_type( + trie, + keys.into_iter().map(|k| (k.into(), TrieNodeType::Hash)), + ) + } + + fn 
assert_keys_point_to_nodes_of_type( + trie: &TrieType, + keys: impl Iterator, + ) { + let nodes = get_all_nodes_in_trie(trie); + let keys_to_node_types: HashMap<_, _> = + HashMap::from_iter(nodes.into_iter().map(|n| (n.nibbles.reverse(), n.n_type))); + + for (k, expected_n_type) in keys { + let actual_n_type_opt = keys_to_node_types.get(&k); + + match actual_n_type_opt { + Some(actual_n_type) => { + if *actual_n_type != expected_n_type { + panic!("Expected trie node at {:x} to be a {} node but it wasn't! (found a {} node instead)", k, expected_n_type, actual_n_type) + } + } + None => panic!( + "Expected a {} node at {:x} but no node was found!", + expected_n_type, k + ), + } + } + } + #[test] fn all_leafs_of_keys_to_create_subset_are_included_in_subset_for_giant_trie() { + common_setup(); + let (_, trie_subsets, keys_of_subsets) = create_massive_trie_and_subsets(9009); for (sub_trie, ks_used) in trie_subsets.into_iter().zip(keys_of_subsets.into_iter()) { @@ -585,17 +700,38 @@ mod tests { #[test] fn hash_of_single_leaf_trie_partial_trie_matches_original_trie() { - let mut trie = TrieType::default(); - trie.insert(0x1234, vec![0]); + let trie = create_trie_with_large_entry_nodes(&[0x0]); let base_hash = trie.hash(); - let partial_trie = create_trie_subset(&trie, vec![0x1234]).unwrap(); + let partial_trie = create_trie_subset(&trie, [0x1234]).unwrap(); assert_eq!(base_hash, partial_trie.hash()); } + #[test] + fn sub_trie_that_includes_branch_but_not_children_hashes_out_children() { + common_setup(); + + let trie = create_trie_with_large_entry_nodes(&[0x1234, 0x12345, 0x12346, 0x1234f]); + let partial_trie = create_trie_subset(&trie, [0x1234f]).unwrap(); + + assert_nodes_are_hash_nodes(&partial_trie, [0x12345, 0x12346]); + } + + #[test] + fn sub_trie_for_non_existent_key_that_hits_branch_leaf_hashes_out_leaf() { + common_setup(); + + let trie = create_trie_with_large_entry_nodes(&[0x1234, 0x1234589, 0x12346]); + let partial_trie = create_trie_subset(&trie, 
[0x1234567]).unwrap(); + + // Note that `0x1234589` gets hashed at the branch slot at `0x12345`. + assert_nodes_are_hash_nodes(&partial_trie, [0x12345, 0x12346]); + } + #[test] fn hash_of_branch_partial_tries_matches_original_trie() { + common_setup(); let trie = create_trie_with_large_entry_nodes(&[0x1234, 0x56, 0x12345]); let base_hash: H256 = trie.hash(); @@ -613,6 +749,8 @@ mod tests { #[test] fn hash_of_giant_random_partial_tries_matches_original_trie() { + common_setup(); + let (base_trie, trie_subsets, _) = create_massive_trie_and_subsets(9010); let base_hash = base_trie.hash(); @@ -621,6 +759,37 @@ mod tests { .all(|p_tree| p_tree.hash() == base_hash)) } + #[test] + fn giant_random_partial_tries_hashes_leaves_correctly() { + common_setup(); + + let (base_trie, trie_subsets, leaf_keys_per_trie) = create_massive_trie_and_subsets(9011); + let all_keys: Vec = base_trie.keys().collect(); + + for (partial_trie, leaf_trie_keys) in + trie_subsets.into_iter().zip(leaf_keys_per_trie.into_iter()) + { + let leaf_keys_lookup: HashSet = leaf_trie_keys.iter().cloned().collect(); + let keys_of_hash_nodes = all_keys + .iter() + .filter(|k| !leaf_keys_lookup.contains(k)) + .cloned(); + + assert_nodes_are_leaf_nodes(&partial_trie, leaf_trie_keys); + + // We have no idea where the paths to the hashed out nodes will start in the + // trie, so the best we can do is to check that they don't exist (if we traverse + // over a `Hash` node, we return `None`.)
+ assert_all_keys_do_not_exist(&partial_trie, keys_of_hash_nodes); + } + } + + fn assert_all_keys_do_not_exist(trie: &TrieType, ks: impl Iterator) { + for k in ks { + assert!(trie.get(k).is_none()); + } + } + fn create_massive_trie_and_subsets(seed: u64) -> (TrieType, Vec, Vec>) { let trie_size = MASSIVE_TEST_NUM_SUB_TRIES * MASSIVE_TEST_NUM_SUB_TRIE_SIZE; diff --git a/mpt_trie/src/utils.rs b/mpt_trie/src/utils.rs index 9b87d8606..59400ee82 100644 --- a/mpt_trie/src/utils.rs +++ b/mpt_trie/src/utils.rs @@ -1,17 +1,36 @@ -use std::{fmt::Display, ops::BitAnd, sync::Arc}; +//! Various types and logic that don't fit well into any other module. + +use std::{ + borrow::Borrow, + fmt::{self, Display}, + ops::BitAnd, + sync::Arc, +}; use ethereum_types::{H256, U512}; use num_traits::PrimInt; -use crate::partial_trie::{Node, PartialTrie}; +use crate::{ + nibbles::{Nibble, Nibbles}, + partial_trie::{Node, PartialTrie}, +}; #[derive(Clone, Debug, Eq, Hash, PartialEq)] /// Simplified trie node type to make logging cleaner. -pub(crate) enum TrieNodeType { +pub enum TrieNodeType { + /// Empty node. Empty, + + /// Hash node. Hash, + + /// Branch node. Branch, + + /// Extension node. Extension, + + /// Leaf node. Leaf, } @@ -58,3 +77,190 @@ pub(crate) fn create_mask_of_1s(amt: usize) -> U512 { pub(crate) fn bytes_to_h256(b: &[u8; 32]) -> H256 { keccak_hash::H256::from_slice(b) } + +/// Minimal key information of "segments" (nodes) used to construct trie +/// "traces" of a trie query. Unlike [`TrieNodeType`], this type also contains +/// the key piece of the node if applicable (eg. [`Node::Empty`] & +/// [`Node::Hash`] do not have associated key pieces). +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub enum TrieSegment { + /// Empty node. + Empty, + + /// Hash node. + Hash, + + /// Branch node along with the nibble of the child taken. + Branch(Nibble), + + /// Extension node along with the key piece of the node. 
+ Extension(Nibbles), + + /// Leaf node along with the key piece of the node. + Leaf(Nibbles), +} + +/// Trait for a type that can be converted into a trie key ([`Nibbles`]). +pub trait IntoTrieKey { + /// Reconstruct the key of the type. + fn into_key(self) -> Nibbles; +} + +impl, T: Iterator> IntoTrieKey for T { + fn into_key(self) -> Nibbles { + let mut key = Nibbles::default(); + + for seg in self { + match seg.borrow() { + TrieSegment::Empty | TrieSegment::Hash => (), + TrieSegment::Branch(nib) => key.push_nibble_back(*nib), + TrieSegment::Extension(nibs) | TrieSegment::Leaf(nibs) => { + key.push_nibbles_back(nibs) + } + } + } + + key + } +} + +/// A vector of path segments representing a path in the trie. +#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)] +pub struct TriePath(pub Vec); + +impl Display for TriePath { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let num_elems = self.0.len(); + + // For everything but the last elem. + for seg in self.0.iter().take(num_elems.saturating_sub(1)) { + Self::write_elem(f, seg)?; + write!(f, " --> ")?; + } + + // Avoid the extra `-->` for the last elem. + if let Some(seg) = self.0.last() { + Self::write_elem(f, seg)?; + } + + Ok(()) + } +} + +impl IntoIterator for TriePath { + type Item = TrieSegment; + + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl From> for TriePath { + fn from(v: Vec) -> Self { + Self(v) + } +} + +impl FromIterator for TriePath { + fn from_iter>(iter: T) -> Self { + Self(Vec::from_iter(iter)) + } +} + +impl TriePath { + /// Get an iterator of the individual path segments in the [`TriePath`].
+ pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + pub(crate) fn dup_and_append(&self, seg: TrieSegment) -> Self { + let mut duped_vec = self.0.clone(); + duped_vec.push(seg); + + Self(duped_vec) + } + + pub(crate) fn append(&mut self, seg: TrieSegment) { + self.0.push(seg); + } + + fn write_elem(f: &mut fmt::Formatter<'_>, seg: &TrieSegment) -> fmt::Result { + write!(f, "{}", seg) + } +} + +impl Display for TrieSegment { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TrieSegment::Empty => write!(f, "Empty"), + TrieSegment::Hash => write!(f, "Hash"), + TrieSegment::Branch(nib) => write!(f, "Branch({})", nib), + TrieSegment::Extension(nibs) => write!(f, "Extension({})", nibs), + TrieSegment::Leaf(nibs) => write!(f, "Leaf({})", nibs), + } + } +} + +impl TrieSegment { + /// Get the node type of the [`TrieSegment`]. + pub fn node_type(&self) -> TrieNodeType { + match self { + TrieSegment::Empty => TrieNodeType::Empty, + TrieSegment::Hash => TrieNodeType::Hash, + TrieSegment::Branch(_) => TrieNodeType::Branch, + TrieSegment::Extension(_) => TrieNodeType::Extension, + TrieSegment::Leaf(_) => TrieNodeType::Leaf, + } + } + + /// Extracts the key piece used by the segment (if applicable). + pub fn get_key_piece_from_seg_if_present(&self) -> Option { + match self { + TrieSegment::Empty | TrieSegment::Hash => None, + TrieSegment::Branch(nib) => Some(Nibbles::from_nibble(*nib)), + TrieSegment::Extension(nibs) | TrieSegment::Leaf(nibs) => Some(*nibs), + } + } +} + +/// Creates a [`TrieSegment`] given a node and a key we are querying. +/// +/// This function is intended to be used during a trie query as we are +/// traversing down a trie. Depending on the current node, we pop off nibbles +/// and use these to create `TrieSegment`s. 
+pub fn get_segment_from_node_and_key_piece( + n: &Node, + k_piece: &Nibbles, +) -> TrieSegment { + match TrieNodeType::from(n) { + TrieNodeType::Empty => TrieSegment::Empty, + TrieNodeType::Hash => TrieSegment::Hash, + TrieNodeType::Branch => TrieSegment::Branch(k_piece.get_nibble(0)), + TrieNodeType::Extension => TrieSegment::Extension(*k_piece), + TrieNodeType::Leaf => TrieSegment::Leaf(*k_piece), + } +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use super::{IntoTrieKey, TriePath, TrieSegment}; + use crate::nibbles::Nibbles; + + #[test] + fn path_from_query_works() { + let query_path: TriePath = vec![ + TrieSegment::Branch(1), + TrieSegment::Branch(2), + TrieSegment::Extension(0x34.into()), + TrieSegment::Leaf(0x567.into()), + ] + .into(); + + let reconstructed_key = query_path.iter().into_key(); + assert_eq!(reconstructed_key, Nibbles::from_str("0x1234567").unwrap()); + } +}