Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: trait StateTrie #542

Merged
merged 13 commits into from
Aug 28, 2024
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 5 additions & 6 deletions trace_decoder/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,17 @@ strum = { version = "0.26.3", features = ["derive"] }
thiserror = { workspace = true }
u4 = { workspace = true }
winnow = { workspace = true }
zk_evm_common = {workspace = true}
zk_evm_common = { workspace = true }

[dev-dependencies]
alloy = { workspace = true }
criterion = { workspace = true }
plonky2_maybe_rayon = { workspace = true }
pretty_env_logger = { workspace = true }
serde_json = { workspace = true }
prover = { workspace = true }
serde_path_to_error = { workspace = true }
plonky2_maybe_rayon = { workspace = true }
alloy = { workspace = true }
rstest = "0.21.0"

serde_json = { workspace = true }
serde_path_to_error = { workspace = true }

[[bench]]
name = "block_processing"
Expand Down
128 changes: 51 additions & 77 deletions trace_decoder/src/decoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use mpt_trie::{
nibbles::Nibbles,
partial_trie::{HashedPartialTrie, PartialTrie as _},
special_query::path_for_query,
trie_ops::TrieOpError,
utils::{IntoTrieKey as _, TriePath},
};

Expand All @@ -25,14 +26,14 @@ use crate::{
NodesUsedByTxn, ProcessedBlockTrace, ProcessedTxnInfo, StateWrite, TxnMetaState,
},
typed_mpt::{ReceiptTrie, StateTrie, StorageTrie, TransactionTrie, TrieKey},
OtherBlockData, PartialTriePreImages,
OtherBlockData, PartialTriePreImages, TryIntoExt as TryIntoBounds,
};

/// The current state of all tries as we process txn deltas. These are mutated
/// after every txn we process in the trace.
#[derive(Clone, Debug, Default)]
struct PartialTrieState {
state: StateTrie,
struct PartialTrieState<StateTrieT> {
state: StateTrieT,
storage: HashMap<H256, StorageTrie>,
txn: TransactionTrie,
receipt: ReceiptTrie,
Expand Down Expand Up @@ -113,7 +114,7 @@ pub fn into_txn_proof_gen_ir(
/// need to update the storage of the beacon block root contract.
// See <https://eips.ethereum.org/EIPS/eip-4788>.
fn update_beacon_block_root_contract_storage(
trie_state: &mut PartialTrieState,
trie_state: &mut PartialTrieState<impl StateTrie>,
delta_out: &mut TrieDeltaApplicationOutput,
nodes_used: &mut NodesUsedByTxn,
block_data: &BlockMetadata,
Expand Down Expand Up @@ -207,7 +208,7 @@ fn update_beacon_block_root_contract_storage(
}

fn update_txn_and_receipt_tries(
trie_state: &mut PartialTrieState,
trie_state: &mut PartialTrieState<impl StateTrie>,
meta: &TxnMetaState,
txn_idx: usize,
) -> anyhow::Result<()> {
Expand Down Expand Up @@ -246,20 +247,19 @@ fn init_any_needed_empty_storage_tries<'a>(
}

fn create_minimal_partial_tries_needed_by_txn(
curr_block_tries: &PartialTrieState,
curr_block_tries: &PartialTrieState<impl StateTrie + Clone + TryIntoBounds<HashedPartialTrie>>,
nodes_used_by_txn: &NodesUsedByTxn,
txn_range: Range<usize>,
delta_application_out: TrieDeltaApplicationOutput,
) -> anyhow::Result<TrieInputs> {
let state_trie = create_minimal_state_partial_trie(
&curr_block_tries.state,
nodes_used_by_txn.state_accesses.iter().map(hash),
delta_application_out
.additional_state_trie_paths_to_not_hash
.into_iter(),
)?
.as_hashed_partial_trie()
.clone();
let mut state_trie = curr_block_tries.state.clone();
state_trie.trim_to(
nodes_used_by_txn
.state_accesses
.iter()
.map(|it| TrieKey::from_address(*it))
.chain(delta_application_out.additional_state_trie_paths_to_not_hash),
)?;

let txn_keys = txn_range.map(TrieKey::from_txn_ix);

Expand All @@ -282,15 +282,15 @@ fn create_minimal_partial_tries_needed_by_txn(
)?;

Ok(TrieInputs {
state_trie,
state_trie: state_trie.try_into()?,
transactions_trie,
receipts_trie,
storage_tries,
})
}

fn apply_deltas_to_trie_state(
trie_state: &mut PartialTrieState,
trie_state: &mut PartialTrieState<impl StateTrie>,
deltas: &NodesUsedByTxn,
meta: &[TxnMetaState],
) -> anyhow::Result<TrieDeltaApplicationOutput> {
Expand Down Expand Up @@ -360,12 +360,7 @@ fn apply_deltas_to_trie_state(

if !receipt.status {
// The transaction failed, hence any created account should be removed.
if let Some(remaining_account_key) =
delete_node_and_report_remaining_key_if_branch_collapsed(
trie_state.state.as_mut_hashed_partial_trie_unchecked(),
&TrieKey::from_hash(hash(addr)),
)?
{
if let Some(remaining_account_key) = trie_state.state.reporting_remove(*addr)? {
out.additional_state_trie_paths_to_not_hash
.push(remaining_account_key);
trie_state.storage.remove(&hash(addr));
Expand All @@ -379,12 +374,7 @@ fn apply_deltas_to_trie_state(
for addr in deltas.self_destructed_accounts.iter() {
trie_state.storage.remove(&hash(addr));

if let Some(remaining_account_key) =
delete_node_and_report_remaining_key_if_branch_collapsed(
trie_state.state.as_mut_hashed_partial_trie_unchecked(),
&TrieKey::from_hash(hash(addr)),
)?
{
if let Some(remaining_account_key) = trie_state.state.reporting_remove(*addr)? {
out.additional_state_trie_paths_to_not_hash
.push(remaining_account_key);
}
Expand All @@ -400,13 +390,14 @@ fn get_trie_trace(trie: &HashedPartialTrie, k: &Nibbles) -> TriePath {
/// If a branch collapse occurred after a delete, then we must ensure that
/// the other single child that remains also is not hashed when passed into
/// plonky2. Returns the key to the remaining child if a collapse occurred.
fn delete_node_and_report_remaining_key_if_branch_collapsed(
pub fn delete_node_and_report_remaining_key_if_branch_collapsed(
trie: &mut HashedPartialTrie,
delete_k: &TrieKey,
) -> anyhow::Result<Option<TrieKey>> {
let old_trace = get_trie_trace(trie, &delete_k.into_nibbles());
trie.delete(delete_k.into_nibbles())?;
let new_trace = get_trie_trace(trie, &delete_k.into_nibbles());
key: &TrieKey,
) -> Result<Option<TrieKey>, TrieOpError> {
let key = key.into_nibbles();
let old_trace = get_trie_trace(trie, &key);
trie.delete(key)?;
let new_trace = get_trie_trace(trie, &key);
Ok(
node_deletion_resulted_in_a_branch_collapse(&old_trace, &new_trace)
.map(TrieKey::from_nibbles),
Expand Down Expand Up @@ -441,7 +432,9 @@ fn node_deletion_resulted_in_a_branch_collapse(
/// The withdrawals are always in the final ir payload.
fn add_withdrawals_to_txns(
txn_ir: &mut [GenerationInputs],
final_trie_state: &mut PartialTrieState,
final_trie_state: &mut PartialTrieState<
impl StateTrie + Clone + TryIntoBounds<HashedPartialTrie>,
>,
mut withdrawals: Vec<(Address, U256)>,
) -> anyhow::Result<()> {
// Scale withdrawals amounts.
Expand All @@ -460,25 +453,22 @@ fn add_withdrawals_to_txns(
.expect("We cannot have an empty list of payloads.");

if last_inputs.signed_txns.is_empty() {
// This is a dummy payload, hence it does not contain yet
// state accesses to the withdrawal addresses.
let withdrawal_addrs = withdrawals_with_hashed_addrs_iter().map(|(_, h_addr, _)| h_addr);

let additional_paths = if last_inputs.txn_number_before == 0.into() {
// We need to include the beacon roots contract as this payload is at the
// start of the block execution.
vec![TrieKey::from_hash(BEACON_ROOTS_CONTRACT_ADDRESS_HASHED)]
} else {
vec![]
};

last_inputs.tries.state_trie = create_minimal_state_partial_trie(
&final_trie_state.state,
withdrawal_addrs,
additional_paths,
)?
.as_hashed_partial_trie()
.clone();
let mut state_trie = final_trie_state.state.clone();
state_trie.trim_to(
// This is a dummy payload, hence it does not contain yet
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// This is a dummy payload, hence it does not contain yet
// This is a dummy payload, hence it does not yet contain

// state accesses to the withdrawal addresses.
withdrawals
.iter()
.map(|(addr, _)| *addr)
.chain(match last_inputs.txn_number_before == 0.into() {
// We need to include the beacon roots contract as this payload is at the
// start of the block execution.
true => Some(BEACON_ROOTS_CONTRACT_ADDRESS),
false => None,
})
.map(TrieKey::from_address),
)?;
last_inputs.tries.state_trie = state_trie.try_into()?;
}

update_trie_state_from_withdrawals(
Expand All @@ -487,7 +477,7 @@ fn add_withdrawals_to_txns(
)?;

last_inputs.withdrawals = withdrawals;
last_inputs.trie_roots_after.state_root = final_trie_state.state.root();
last_inputs.trie_roots_after.state_root = final_trie_state.state.clone().try_into()?.hash();

Ok(())
}
Expand All @@ -496,7 +486,7 @@ fn add_withdrawals_to_txns(
/// our local trie state.
fn update_trie_state_from_withdrawals<'a>(
withdrawals: impl IntoIterator<Item = (Address, H256, U256)> + 'a,
state: &mut StateTrie,
state: &mut impl StateTrie,
) -> anyhow::Result<()> {
for (addr, h_addr, amt) in withdrawals {
let mut acc_data = state.get_by_address(addr).context(format!(
Expand All @@ -520,7 +510,9 @@ fn process_txn_info(
txn_range: Range<usize>,
is_initial_payload: bool,
txn_info: ProcessedTxnInfo,
curr_block_tries: &mut PartialTrieState,
curr_block_tries: &mut PartialTrieState<
impl StateTrie + Clone + TryIntoBounds<HashedPartialTrie>,
>,
extra_data: &mut ExtraBlockData,
other_data: &OtherBlockData,
) -> anyhow::Result<GenerationInputs> {
Expand Down Expand Up @@ -595,7 +587,7 @@ fn process_txn_info(
* for more info). */
tries,
trie_roots_after: TrieRoots {
state_root: curr_block_tries.state.root(),
state_root: curr_block_tries.state.clone().try_into()?.hash(),
transactions_root: curr_block_tries.txn.root(),
receipts_root: curr_block_tries.receipt.root(),
},
Expand Down Expand Up @@ -645,22 +637,6 @@ impl StateWrite {
}
}

fn create_minimal_state_partial_trie(
state_trie: &StateTrie,
state_accesses: impl IntoIterator<Item = H256>,
additional_state_trie_paths_to_not_hash: impl IntoIterator<Item = TrieKey>,
) -> anyhow::Result<StateTrie> {
create_trie_subset_wrapped(
state_trie.as_hashed_partial_trie(),
state_accesses
.into_iter()
.map(TrieKey::from_hash)
.chain(additional_state_trie_paths_to_not_hash),
TrieType::State,
)
.map(StateTrie::from_hashed_partial_trie_unchecked)
}

// TODO!!!: We really need to be appending the empty storage tries to the base
// trie somewhere else! This is a big hack!
fn create_minimal_storage_partial_tries<'a>(
Expand Down Expand Up @@ -714,11 +690,9 @@ fn eth_to_gwei(eth: U256) -> U256 {
const ZERO_STORAGE_SLOT_VAL_RLPED: [u8; 1] = [128];

/// Aid for error context.
/// Covers all Ethereum trie types (see <https://ethereum.github.io/yellowpaper/paper.pdf> for details).
#[derive(Debug, strum::Display)]
#[allow(missing_docs)]
enum TrieType {
State,
Storage,
Receipt,
Txn,
Expand Down
35 changes: 22 additions & 13 deletions trace_decoder/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ use keccak_hash::H256;
use mpt_trie::partial_trie::{HashedPartialTrie, OnOrphanedHashNode};
use processed_block_trace::ProcessedTxnInfo;
use serde::{Deserialize, Serialize};
use typed_mpt::{StateTrie, StorageTrie, TrieKey};
use typed_mpt::{StateMpt, StateTrie as _, StorageTrie, TrieKey};

/// Core payload needed to generate proof for a block.
/// Additional data retrievable from the blockchain node (using standard ETH RPC
Expand Down Expand Up @@ -311,7 +311,7 @@ pub fn entrypoint(
}) => ProcessedBlockTracePreImages {
tries: PartialTriePreImages {
state: state.items().try_fold(
StateTrie::new(OnOrphanedHashNode::Reject),
StateMpt::new(OnOrphanedHashNode::Reject),
|mut acc, (nibbles, hash_or_val)| {
let path = TrieKey::from_nibbles(nibbles);
match hash_or_val {
Expand Down Expand Up @@ -367,10 +367,7 @@ pub fn entrypoint(
ProcessedBlockTracePreImages {
tries: PartialTriePreImages {
state,
storage: storage
.into_iter()
.map(|(path, trie)| (path.into_hash_left_padded(), trie))
.collect(),
storage: storage.into_iter().collect(),
},
extra_code_hash_mappings: match code.is_empty() {
true => None,
Expand All @@ -384,12 +381,7 @@ pub fn entrypoint(
}
};

let all_accounts_in_pre_images = pre_images
.tries
.state
.iter()
.map(|(addr, data)| (addr.into_hash_left_padded(), data))
.collect::<Vec<_>>();
let all_accounts_in_pre_images = pre_images.tries.state.iter().collect::<Vec<_>>();

// Note we discard any user-provided hashes.
let mut hash2code = code_db
Expand Down Expand Up @@ -449,7 +441,7 @@ pub fn entrypoint(

#[derive(Debug, Default)]
struct PartialTriePreImages {
pub state: StateTrie,
pub state: StateMpt,
pub storage: HashMap<H256, StorageTrie>,
}

Expand Down Expand Up @@ -479,6 +471,23 @@ mod hex {
}
}

trait TryIntoExt<T> {
type Error: std::error::Error + Send + Sync + 'static;
fn try_into(self) -> Result<T, Self::Error>;
}

impl<ThisT, T, E> TryIntoExt<T> for ThisT
where
ThisT: TryInto<T, Error = E>,
E: std::error::Error + Send + Sync + 'static,
{
type Error = ThisT::Error;

fn try_into(self) -> Result<T, Self::Error> {
TryInto::try_into(self)
}
}

#[cfg(test)]
#[derive(serde::Deserialize)]
struct Case {
Expand Down
2 changes: 1 addition & 1 deletion trace_decoder/src/processed_block_trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use evm_arithmetization::generation::mpt::{AccountRlp, LegacyReceiptRlp};
use itertools::Itertools;
use zk_evm_common::EMPTY_TRIE_HASH;

use crate::typed_mpt::TrieKey;
use crate::typed_mpt::{StateTrie as _, TrieKey};
use crate::PartialTriePreImages;
use crate::{hash, TxnTrace};
use crate::{ContractCodeUsage, TxnInfo};
Expand Down
Loading
Loading