Skip to content

Commit

Permalink
fix: resolve PR concerns
Browse files Browse the repository at this point in the history
  • Loading branch information
KolbyML committed Dec 9, 2024
1 parent 49088d6 commit 6d1d562
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 82 deletions.
129 changes: 66 additions & 63 deletions trin-execution/src/trie_walker/filter.rs
Original file line number Diff line number Diff line change
@@ -1,120 +1,123 @@
use alloy::primitives::{B256, U256};
use alloy::primitives::U256;
use rand::{thread_rng, Rng};

use crate::utils::partial_nibble_path_to_right_padded_b256;
use crate::utils::nibbles_to_right_padded_b256;

#[derive(Debug, Clone)]
pub struct Filter {
start_prefix: U256,
end_prefix: U256,
start: U256,
end: U256,
}

impl Filter {
pub fn new_random_filter(slice_count: u16) -> Self {
/// Create a new filter that includes the whole trie
/// Slice index must be less than slice count or it will panic
pub fn new(slice_index: u16, slice_count: u16) -> Self {
assert!(
slice_index < slice_count,
"slice_index must be less than slice_count"
);

// if slice_count is 0 or 1, we want to include the whole trie
if slice_count == 0 || slice_count == 1 {
return Self {
start_prefix: U256::ZERO,
end_prefix: U256::MAX,
start: U256::ZERO,
end: U256::MAX,
};
}

let slice_size = U256::MAX / U256::from(slice_count);
let random_slice_index = thread_rng().gen_range(0..slice_count);
let slice_size = U256::MAX / U256::from(slice_count) + U256::from(1);

let start_prefix = U256::from(random_slice_index) * slice_size;
let end_prefix = if random_slice_index == slice_count - 1 {
let start = U256::from(slice_index) * slice_size;
let end = if slice_index == slice_count - 1 {
U256::MAX
} else {
start_prefix + slice_size - U256::from(1)
start + slice_size - U256::from(1)
};

Self {
start_prefix,
end_prefix,
}
Self { start, end }
}

fn partial_prefix(prefix: U256, shift_amount: usize) -> B256 {
prefix
.arithmetic_shr(shift_amount)
.wrapping_shl(shift_amount)
.into()
pub fn random(slice_count: u16) -> Self {
Self::new(thread_rng().gen_range(0..slice_count), slice_count)
}

pub fn is_included(&self, path: &[u8]) -> bool {
/// Check if a path is included in the filter
pub fn contains(&self, path: &[u8]) -> bool {
// we need to use partial prefixes to not artificially exclude paths that are not exactly
// the same length as the filter
let shift_amount = 256 - path.len() * 4;
let partial_start_prefix = Filter::partial_prefix(self.start_prefix, shift_amount);
let partial_end_prefix = Filter::partial_prefix(self.end_prefix, shift_amount);
let path = partial_nibble_path_to_right_padded_b256(path);
let partial_start_prefix = (self.start >> shift_amount) << shift_amount;
let partial_end_prefix = (self.end >> shift_amount) << shift_amount;
let path = nibbles_to_right_padded_b256(path);

(partial_start_prefix..=partial_end_prefix).contains(&path)
(partial_start_prefix..=partial_end_prefix).contains(&path.into())
}
}

#[cfg(test)]
mod tests {
use alloy::hex::FromHex;

use super::*;

#[test]
fn test_new_random_filter() {
let filter = Filter::new_random_filter(0);
assert_eq!(filter.start_prefix, U256::ZERO);
assert_eq!(filter.end_prefix, U256::MAX);
let filter = Filter::random(0);
assert_eq!(filter.start, U256::ZERO);
assert_eq!(filter.end, U256::MAX);

let filter = Filter::new_random_filter(1);
assert_eq!(filter.start_prefix, U256::ZERO);
assert_eq!(filter.end_prefix, U256::MAX);
let filter = Filter::random(1);
assert_eq!(filter.start, U256::ZERO);
assert_eq!(filter.end, U256::MAX);
}

#[test]
fn test_is_included() {
let filter = Filter {
start_prefix: partial_nibble_path_to_right_padded_b256(&[0x1, 0x5, 0x5]).into(),
end_prefix: partial_nibble_path_to_right_padded_b256(&[0x3]).into(),
start: nibbles_to_right_padded_b256(&[0x1, 0x5, 0x5]).into(),
end: nibbles_to_right_padded_b256(&[0x3]).into(),
};

assert!(!filter.is_included(&[0x0]));
assert!(filter.is_included(&[0x1]));
assert!(filter.is_included(&[0x1, 0x5]));
assert!(filter.is_included(&[0x1, 0x5, 0x5]));
assert!(!filter.is_included(&[0x1, 0x5, 0x4]));
assert!(filter.is_included(&[0x2]));
assert!(filter.is_included(&[0x3]));
assert!(!filter.is_included(&[0x4]));
assert!(!filter.contains(&[0x0]));
assert!(filter.contains(&[0x1]));
assert!(filter.contains(&[0x1, 0x5]));
assert!(filter.contains(&[0x1, 0x5, 0x5]));
assert!(!filter.contains(&[0x1, 0x5, 0x4]));
assert!(filter.contains(&[0x2]));
assert!(filter.contains(&[0x3]));
assert!(!filter.contains(&[0x3, 0x0, 0x1]));
assert!(!filter.contains(&[0x4]));
}

#[test]
fn test_partial_prefix() {
let prefix: &[u8] = &[0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8];
let prefix: U256 = partial_nibble_path_to_right_padded_b256(prefix).into();
fn test_new() {
let filter = Filter::new(0, 1);
assert_eq!(filter.start, U256::ZERO);
assert_eq!(filter.end, U256::MAX);

assert_eq!(
Filter::partial_prefix(prefix, 256),
B256::from_hex("0x0000000000000000000000000000000000000000000000000000000000000000")
.unwrap()
);
let filter = Filter::new(0, 2);
assert_eq!(filter.start, U256::ZERO);
assert_eq!(filter.end, U256::MAX / U256::from(2));

assert_eq!(
Filter::partial_prefix(prefix, 256 - 4),
B256::from_hex("0x1000000000000000000000000000000000000000000000000000000000000000")
.unwrap()
);
let filter = Filter::new(1, 2);
assert_eq!(filter.start, U256::MAX / U256::from(2) + U256::from(1));
assert_eq!(filter.end, U256::MAX);

let filter = Filter::new(0, 3);
assert_eq!(filter.start, U256::ZERO);
assert_eq!(filter.end, U256::MAX / U256::from(3));

let filter = Filter::new(1, 3);
assert_eq!(filter.start, U256::MAX / U256::from(3) + U256::from(1));
assert_eq!(
Filter::partial_prefix(prefix, 256 - 4 * 4),
B256::from_hex("0x1234000000000000000000000000000000000000000000000000000000000000")
.unwrap()
filter.end,
U256::MAX / U256::from(3) * U256::from(2) + U256::from(1)
);

let filter = Filter::new(2, 3);
assert_eq!(
Filter::partial_prefix(prefix, 256 - 5 * 4),
B256::from_hex("0x1234500000000000000000000000000000000000000000000000000000000000")
.unwrap()
filter.start,
U256::MAX / U256::from(3) * U256::from(2) + U256::from(2)
);
assert_eq!(filter.end, U256::MAX);
}
}
2 changes: 1 addition & 1 deletion trin-execution/src/trie_walker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl<DB: TrieWalkerDb> TrieWalker<DB> {
) -> anyhow::Result<()> {
// If we have a filter, we only want to include nodes that are in the filter
if let Some(filter) = &self.filter {
if !filter.is_included(&path) {
if !filter.contains(&path) {
return Ok(());
}
}
Expand Down
31 changes: 13 additions & 18 deletions trin-execution/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,5 @@
use alloy::primitives::{keccak256, Address, B256};

fn compress_nibbles(nibbles: &[u8]) -> Vec<u8> {
let mut compressed_nibbles = vec![];
for i in 0..nibbles.len() {
if i % 2 == 0 {
compressed_nibbles.push(nibbles[i] << 4);
} else {
compressed_nibbles[i / 2] |= nibbles[i];
}
}
compressed_nibbles
}

pub fn full_nibble_path_to_address_hash(key_path: &[u8]) -> B256 {
if key_path.len() != 64 {
panic!(
Expand All @@ -20,11 +8,19 @@ pub fn full_nibble_path_to_address_hash(key_path: &[u8]) -> B256 {
);
}

B256::from_slice(&compress_nibbles(key_path))
nibbles_to_right_padded_b256(key_path)
}

pub fn partial_nibble_path_to_right_padded_b256(partial_nibble_path: &[u8]) -> B256 {
B256::right_padding_from(&compress_nibbles(partial_nibble_path))
pub fn nibbles_to_right_padded_b256(nibbles: &[u8]) -> B256 {
let mut result = B256::ZERO;
for (i, nibble) in nibbles.iter().enumerate() {
if i % 2 == 0 {
result[i / 2] |= nibble << 4;
} else {
result[i / 2] |= nibble;
};
}
result
}

pub fn address_to_nibble_path(address: Address) -> Vec<u8> {
Expand All @@ -41,8 +37,7 @@ mod tests {
use revm_primitives::{keccak256, Address, B256};

use crate::utils::{
address_to_nibble_path, full_nibble_path_to_address_hash,
partial_nibble_path_to_right_padded_b256,
address_to_nibble_path, full_nibble_path_to_address_hash, nibbles_to_right_padded_b256,
};

#[test]
Expand All @@ -68,7 +63,7 @@ mod tests {
#[test]
fn test_partial_nibble_path_to_right_padded_b256() {
let partial_nibble_path = vec![0xf, 0xf, 0x0, 0x1, 0x0, 0x2, 0x0, 0x3];
let partial_path = partial_nibble_path_to_right_padded_b256(&partial_nibble_path);
let partial_path = nibbles_to_right_padded_b256(&partial_nibble_path);
assert_eq!(
partial_path,
B256::from_hex("0xff01020300000000000000000000000000000000000000000000000000000000")
Expand Down

0 comments on commit 6d1d562

Please sign in to comment.