diff --git a/Cargo.lock b/Cargo.lock
index 706845c8e..18e708b0a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4180,6 +4180,7 @@ dependencies = [
  "evm_arithmetization",
  "futures",
  "hex",
+ "itertools 0.13.0",
  "lru",
  "mpt_trie",
  "primitive-types 0.12.2",
diff --git a/mpt_trie/src/debug_tools/query.rs b/mpt_trie/src/debug_tools/query.rs
index 0fb6ade96..dcfff397d 100644
--- a/mpt_trie/src/debug_tools/query.rs
+++ b/mpt_trie/src/debug_tools/query.rs
@@ -66,19 +66,19 @@ pub struct DebugQueryParamsBuilder {
 
 impl DebugQueryParamsBuilder {
     /// Defaults to `true`.
-    pub const fn print_key_pieces(mut self, enabled: bool) -> Self {
+    pub const fn include_key_pieces(mut self, enabled: bool) -> Self {
        self.params.include_key_piece_per_node = enabled;
        self
     }
 
     /// Defaults to `true`.
-    pub const fn print_node_type(mut self, enabled: bool) -> Self {
+    pub const fn include_node_type(mut self, enabled: bool) -> Self {
        self.params.include_node_type = enabled;
        self
     }
 
     /// Defaults to `false`.
-    pub const fn print_node_specific_values(mut self, enabled: bool) -> Self {
+    pub const fn include_node_specific_values(mut self, enabled: bool) -> Self {
        self.params.include_node_specific_values = enabled;
        self
     }
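Reviewer note: the `print_*` → `include_*` renames make the setters read like the `include_*` params they toggle. A minimal usage sketch — hypothetical, with the builder's `Default` impl, the `build(key)` signature, and `get_path_from_query` assumed from the hunks in this PR:

```rust
use mpt_trie::debug_tools::query::{get_path_from_query, DebugQueryParamsBuilder};
use mpt_trie::nibbles::Nibbles;
use mpt_trie::partial_trie::{Node, PartialTrie};

// Sketch only: argument types are assumed from the callsite added in
// `trie_subsets.rs` below (`get_path_from_query(&node, query)`).
fn debug_path<N: PartialTrie>(node: &Node<N>, key: Nibbles) -> String {
    let query = DebugQueryParamsBuilder::default()
        .include_key_pieces(true) // renamed from `print_key_pieces`
        .include_node_type(true) // renamed from `print_node_type`
        .include_node_specific_values(true) // renamed from `print_node_specific_values`
        .build(key);
    // `DebugQueryOutput` implements `Display`; the new error type relies on that.
    get_path_from_query(node, query).to_string()
}
```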
diff --git a/mpt_trie/src/trie_subsets.rs b/mpt_trie/src/trie_subsets.rs
index 13e9d0d9f..eadb6b3e2 100644
--- a/mpt_trie/src/trie_subsets.rs
+++ b/mpt_trie/src/trie_subsets.rs
@@ -12,6 +12,7 @@ use log::trace;
 use thiserror::Error;
 
 use crate::{
+    debug_tools::query::{get_path_from_query, DebugQueryOutput, DebugQueryParamsBuilder},
     nibbles::Nibbles,
     partial_trie::{Node, PartialTrie, WrappedNode},
     trie_hashing::EncodedNode,
@@ -21,13 +22,10 @@ use crate::{
 /// The output type of trie_subset operations.
 pub type SubsetTrieResult<T> = Result<T, SubsetTrieError>;
 
-/// Errors that may occur when creating a subset [`PartialTrie`].
+/// We encountered a `hash` node when marking nodes during sub-trie creation.
 #[derive(Clone, Debug, Error, Hash)]
-pub enum SubsetTrieError {
-    #[error("Tried to mark nodes in a tracked trie for a key that does not exist! (Key: {0}, trie: {1})")]
-    /// The key does not exist in the trie.
-    UnexpectedKey(Nibbles, String),
-}
+#[error("Encountered a hash node when marking nodes during sub-trie creation!\nPath: {0}")]
+pub struct SubsetTrieError(DebugQueryOutput);
 
 #[derive(Debug)]
 enum TrackedNodeIntern<N: PartialTrie> {
@@ -256,8 +254,17 @@ where
     N: PartialTrie,
     K: Into<Nibbles>,
 {
-    for k in keys_involved {
-        mark_nodes_that_are_needed(tracked_trie, &mut k.into())?;
+    for mut k in keys_involved.map(|k| k.into()) {
+        mark_nodes_that_are_needed(tracked_trie, &mut k).map_err(|_| {
+            // We need to unwind back to this callsite in order to produce the actual error.
+            let query = DebugQueryParamsBuilder::default()
+                .include_node_specific_values(true)
+                .build(k);
+
+            let res = get_path_from_query(&tracked_trie.info.underlying_node, query);
+
+            SubsetTrieError(res)
+        })?;
     }
 
     Ok(create_partial_trie_subset_from_tracked_trie(tracked_trie))
@@ -270,10 +277,14 @@ where
 /// - For the key `0x1`, the marked nodes would be [B(0x), B(0x1)].
 /// - For the key `0x12`, the marked nodes still would be [B(0x), B(0x1)].
 /// - For the key `0x123`, the marked nodes would be [B(0x), B(0x1), L(0x123)].
+///
+/// Also note that we can't construct the error until we back out of this
+/// recursive function. We need to know the full key that hit the hash
+/// node, and that's only available at the initial call site.
 fn mark_nodes_that_are_needed<N: PartialTrie>(
     trie: &mut TrackedNode<N>,
     curr_nibbles: &mut Nibbles,
-) -> SubsetTrieResult<()> {
+) -> Result<(), ()> {
     trace!(
         "Sub-trie marking at {:x}, (type: {})",
         curr_nibbles,
@@ -286,10 +297,7 @@ fn mark_nodes_that_are_needed(
         }
         TrackedNodeIntern::Hash => match curr_nibbles.is_empty() {
             false => {
-                return Err(SubsetTrieError::UnexpectedKey(
-                    *curr_nibbles,
-                    format!("{:?}", trie),
-                ));
+                return Err(());
             }
             true => {
                 trie.info.touched = true;
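Reviewer note: the bare `Result<(), ()>` looks odd in isolation; the point (per the new doc comment) is that the recursive walker only sees the *remaining* key suffix, so it merely signals failure and the callsite, which still owns the full key, rebuilds the rich error. A self-contained sketch of the same shape, with hypothetical names:

```rust
#[derive(Debug)]
struct PathError {
    full_key: String, // stands in for the `DebugQueryOutput` path report
}

// The recursive marker can only signal *that* it failed...
fn mark(remaining: &str) -> Result<(), ()> {
    let mut chars = remaining.chars();
    match chars.next() {
        Some('h') => Err(()), // pretend 'h' is a hash node
        Some(_) => mark(chars.as_str()),
        None => Ok(()),
    }
}

// ...and only this frame still owns the full key, so the real error is built here.
fn mark_keys(full_key: &str) -> Result<(), PathError> {
    mark(full_key).map_err(|()| PathError {
        full_key: full_key.to_owned(),
    })
}

fn main() {
    assert!(mark_keys("abc").is_ok());
    assert!(mark_keys("abhc").is_err());
}
```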
diff --git a/trace_decoder/src/decoding.rs b/trace_decoder/src/decoding.rs
index d9b750cc1..ce669a581 100644
--- a/trace_decoder/src/decoding.rs
+++ b/trace_decoder/src/decoding.rs
@@ -20,7 +20,7 @@ use mpt_trie::{
 use crate::{
     hash,
     processed_block_trace::{
-        NodesUsedByTxn, ProcessedBlockTrace, ProcessedTxnInfo, StateTrieWrites, TxnMetaState,
+        NodesUsedByTxn, ProcessedBlockTrace, ProcessedTxnInfo, StateWrite, TxnMetaState,
     },
     typed_mpt::{ReceiptTrie, StateTrie, StorageTrie, TransactionTrie, TrieKey},
     OtherBlockData, PartialTriePreImages,
@@ -202,15 +202,12 @@ fn update_txn_and_receipt_tries(
     meta: &TxnMetaState,
     txn_idx: usize,
 ) -> anyhow::Result<()> {
-    if meta.is_dummy() {
-        // This is a dummy payload, that does not mutate these tries.
-        return Ok(());
-    }
-
-    trie_state.txn.insert(txn_idx, meta.txn_bytes())?;
-    trie_state
-        .receipt
-        .insert(txn_idx, meta.receipt_node_bytes.clone())?;
+    if let Some(bytes) = &meta.txn_bytes {
+        trie_state.txn.insert(txn_idx, bytes.clone())?;
+        trie_state
+            .receipt
+            .insert(txn_idx, meta.receipt_node_bytes.clone())?;
+    } // else it's just a dummy
 
     Ok(())
 }
@@ -220,11 +217,11 @@ fn update_txn_and_receipt_tries(
 fn init_any_needed_empty_storage_tries<'a>(
     storage_tries: &mut HashMap<H256, StorageTrie>,
     accounts_with_storage: impl Iterator<Item = &'a H256>,
-    state_accounts_with_no_accesses_but_storage_tries: &'a HashMap<H256, H256>,
+    accts_with_unaccessed_storage: &HashMap<H256, H256>,
 ) {
     for h_addr in accounts_with_storage {
         if !storage_tries.contains_key(h_addr) {
-            let trie = state_accounts_with_no_accesses_but_storage_tries
+            let trie = accts_with_unaccessed_storage
                 .get(h_addr)
                 .map(|s_root| {
                     let mut it = StorageTrie::default();
@@ -537,9 +534,7 @@ fn process_txn_info(
     init_any_needed_empty_storage_tries(
         &mut curr_block_tries.storage,
         txn_info.nodes_used_by_txn.storage_accesses.keys(),
-        &txn_info
-            .nodes_used_by_txn
-            .state_accounts_with_no_accesses_but_storage_tries,
+        &txn_info.nodes_used_by_txn.accts_with_unaccessed_storage,
     );
 
     // For each non-dummy txn, we increment `txn_number_after` and
@@ -594,8 +589,7 @@ fn process_txn_info(
             signed_txns: txn_info
                 .meta
                 .iter()
-                .filter(|t| t.txn_bytes.is_some())
-                .map(|tx| tx.txn_bytes())
+                .filter_map(|t| t.txn_bytes.clone())
                 .collect::<Vec<_>>(),
             withdrawals: Vec::default(), /* Only ever set in a dummy txn at the end of
                                           * the block (see `[add_withdrawals_to_txns]`
@@ -607,7 +601,11 @@ fn process_txn_info(
                 receipts_root: curr_block_tries.receipt.root(),
             },
             checkpoint_state_trie_root: extra_data.checkpoint_state_trie_root,
-            contract_code: txn_info.contract_code_accessed,
+            contract_code: txn_info
+                .contract_code_accessed
+                .into_iter()
+                .map(|code| (hash(&code), code))
+                .collect(),
             block_metadata: other_data.b_data.b_meta.clone(),
             block_hashes: other_data.b_data.b_hashes.clone(),
             global_exit_roots: vec![],
@@ -621,7 +619,7 @@ fn process_txn_info(
     Ok(gen_inputs)
 }
 
-impl StateTrieWrites {
+impl StateWrite {
     fn apply_writes_to_state_node(
         &self,
         state_node: &mut AccountRlp,
@@ -708,21 +706,6 @@ fn create_trie_subset_wrapped(
         .context(format!("missing keys when creating {}", trie_type))
 }
 
-impl TxnMetaState {
-    /// Outputs a boolean indicating whether this `TxnMetaState`
-    /// represents a dummy payload or an actual transaction.
-    const fn is_dummy(&self) -> bool {
-        self.txn_bytes.is_none()
-    }
-
-    fn txn_bytes(&self) -> Vec<u8> {
-        match self.txn_bytes.as_ref() {
-            Some(v) => v.clone(),
-            None => Vec::default(),
-        }
-    }
-}
-
 fn eth_to_gwei(eth: U256) -> U256 {
     // 1 ether = 10^9 gwei.
     eth * U256::from(10).pow(9.into())
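Reviewer note: `contract_code` now derives each key from its value at the single point of conversion, so the map can no longer hold a mismatched (hash, code) pair. A sketch of the idea with a stand-in hash (the real code uses the crate's keccak-based `hash` helper):

```rust
use std::collections::hash_map::DefaultHasher;
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};

// Stand-in for the crate's keccak-based `hash` helper.
fn hash(code: &[u8]) -> u64 {
    let mut h = DefaultHasher::new();
    code.hash(&mut h);
    h.finish()
}

fn main() {
    // The processed trace now carries only the set of code blobs...
    let contract_code_accessed: HashSet<Vec<u8>> = HashSet::from([vec![], vec![0x60, 0x00]]);

    // ...and the hash -> code map is derived at the edge, where it's consumed.
    let contract_code: HashMap<u64, Vec<u8>> = contract_code_accessed
        .into_iter()
        .map(|code| (hash(&code), code))
        .collect();

    for (h, code) in &contract_code {
        assert_eq!(*h, hash(code)); // the invariant holds by construction
    }
}
```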
diff --git a/trace_decoder/src/lib.rs b/trace_decoder/src/lib.rs
index 2bc34cdf4..652d25ba4 100644
--- a/trace_decoder/src/lib.rs
+++ b/trace_decoder/src/lib.rs
@@ -258,15 +258,6 @@ pub enum ContractCodeUsage {
     Write(#[serde(with = "crate::hex")] Vec<u8>),
 }
 
-impl ContractCodeUsage {
-    fn get_code_hash(&self) -> H256 {
-        match self {
-            ContractCodeUsage::Read(hash) => *hash,
-            ContractCodeUsage::Write(bytes) => hash(bytes),
-        }
-    }
-}
-
 /// Other data that is needed for proof gen.
 #[derive(Clone, Debug, Deserialize, Serialize)]
 pub struct OtherBlockData {
@@ -398,15 +389,17 @@ pub fn entrypoint(
         .map(|(addr, data)| (addr.into_hash_left_padded(), data))
         .collect::<Vec<_>>();
 
-    let code_db = {
-        let mut code_db = code_db.unwrap_or_default();
-        if let Some(code_mappings) = pre_images.extra_code_hash_mappings {
-            code_db.extend(code_mappings);
-        }
-        code_db
-    };
-
-    let mut code_hash_resolver = Hash2Code::new(code_db);
+    // Note: we discard any user-provided hashes; keys are re-derived from the code itself.
+    let mut hash2code = code_db
+        .unwrap_or_default()
+        .into_values()
+        .chain(
+            pre_images
+                .extra_code_hash_mappings
+                .unwrap_or_default()
+                .into_values(),
+        )
+        .collect::<Hash2Code>();
 
     let last_tx_idx = txn_info.len().saturating_sub(1) / batch_size;
 
@@ -432,7 +425,7 @@ pub fn entrypoint(
                 &pre_images.tries,
                 &all_accounts_in_pre_images,
                 &extra_state_accesses,
-                &mut code_hash_resolver,
+                &mut hash2code,
             )
         })
        .collect::<Result<Vec<_>, _>>()?;
@@ -460,8 +453,6 @@ struct PartialTriePreImages {
 /// Like `#[serde(with = "hex")`, but tolerates and emits leading `0x` prefixes
 mod hex {
-    use std::{borrow::Cow, fmt};
-
     use serde::{de::Error as _, Deserialize as _, Deserializer, Serializer};
 
     pub fn serialize<S: Serializer, T>(data: T, serializer: S) -> Result<S::Ok, S::Error>
@@ -475,9 +466,9 @@ mod hex {
     pub fn deserialize<'de, D: Deserializer<'de>, T>(deserializer: D) -> Result<T, D::Error>
     where
         T: hex::FromHex,
-        T::Error: fmt::Display,
+        T::Error: std::fmt::Display,
     {
-        let s = Cow::<str>::deserialize(deserializer)?;
+        let s = String::deserialize(deserializer)?;
         match s.strip_prefix("0x") {
             Some(rest) => T::from_hex(rest),
             None => T::from_hex(&*s),
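Reviewer note: the `Cow<str>` → `String` swap trades a possible borrow for simpler code; either way, the `0x`-prefix handling is the interesting part. A standalone sketch of the tolerant decode path, using the `hex` crate directly (the real module plugs this into serde):

```rust
// Accepts both "0xdead" and "dead", like the `deserialize` above.
fn decode_tolerant(s: &str) -> Result<Vec<u8>, hex::FromHexError> {
    match s.strip_prefix("0x") {
        Some(rest) => hex::decode(rest),
        None => hex::decode(s),
    }
}

fn main() {
    assert_eq!(decode_tolerant("0xdead").unwrap(), vec![0xde, 0xad]);
    assert_eq!(decode_tolerant("dead").unwrap(), vec![0xde, 0xad]);
    assert!(decode_tolerant("0xzz").is_err());
}
```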
diff --git a/trace_decoder/src/processed_block_trace.rs b/trace_decoder/src/processed_block_trace.rs
index c27d89305..6472b18b0 100644
--- a/trace_decoder/src/processed_block_trace.rs
+++ b/trace_decoder/src/processed_block_trace.rs
@@ -1,17 +1,14 @@
-use std::collections::hash_map::Entry;
 use std::collections::{BTreeSet, HashMap, HashSet};
-use std::fmt::Debug;
-use std::iter::once;
 
-use anyhow::bail;
+use anyhow::{bail, Context as _};
 use ethereum_types::{Address, H256, U256};
 use evm_arithmetization::generation::mpt::{AccountRlp, LegacyReceiptRlp};
 use itertools::Itertools;
-use zk_evm_common::{EMPTY_CODE_HASH, EMPTY_TRIE_HASH};
+use zk_evm_common::EMPTY_TRIE_HASH;
 
-use crate::hash;
 use crate::typed_mpt::TrieKey;
 use crate::PartialTriePreImages;
+use crate::{hash, TxnTrace};
 use crate::{ContractCodeUsage, TxnInfo};
 
 const FIRST_PRECOMPILE_ADDRESS: U256 = U256([1, 0, 0, 0]);
@@ -33,7 +30,7 @@ pub(crate) struct ProcessedBlockTracePreImages {
 #[derive(Debug, Default)]
 pub(crate) struct ProcessedTxnInfo {
     pub nodes_used_by_txn: NodesUsedByTxn,
-    pub contract_code_accessed: HashMap<H256, Vec<u8>>,
+    pub contract_code_accessed: HashSet<Vec<u8>>,
     pub meta: Vec<TxnMetaState>,
 }
 
@@ -42,22 +39,34 @@ pub(crate) struct ProcessedTxnInfo {
 /// If there are any txns that create contracts, then they will also
 /// get added here as we process the deltas.
 pub(crate) struct Hash2Code {
+    /// Key must always be [`hash`] of value.
     inner: HashMap<H256, Vec<u8>>,
 }
 
 impl Hash2Code {
-    pub fn new(inner: HashMap<H256, Vec<u8>>) -> Self {
-        Self { inner }
+    pub fn new() -> Self {
+        Self {
+            inner: HashMap::new(),
+        }
     }
-    fn resolve(&mut self, c_hash: &H256) -> anyhow::Result<Vec<u8>> {
-        match self.inner.get(c_hash) {
+    fn get(&mut self, hash: H256) -> anyhow::Result<Vec<u8>> {
+        match self.inner.get(&hash) {
             Some(code) => Ok(code.clone()),
-            None => bail!("no code for hash {}", c_hash),
+            None => bail!("no code for hash {}", hash),
         }
     }
+    fn insert(&mut self, code: Vec<u8>) {
+        self.inner.insert(hash(&code), code);
+    }
+}
 
-    fn insert_code(&mut self, c_hash: H256, code: Vec<u8>) {
-        self.inner.insert(c_hash, code);
+impl FromIterator<Vec<u8>> for Hash2Code {
+    fn from_iter<II: IntoIterator<Item = Vec<u8>>>(iter: II) -> Self {
+        let mut this = Self::new();
+        for code in iter {
+            this.insert(code)
+        }
+        this
     }
 }
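Reviewer note: because `Hash2Code` can now only be built from code blobs (via `FromIterator<Vec<u8>>`) and `insert` re-derives the key, a caller-supplied wrong hash can no longer poison the map — which is exactly what the "we discard any user-provided hashes" note in `entrypoint` relies on. A miniature, self-contained version (stand-in hash; the real code uses keccak):

```rust
use std::collections::HashMap;

// Stand-in for the crate's keccak-based `hash` helper (FNV-1a, just for the sketch).
fn hash(code: &[u8]) -> u64 {
    code.iter().fold(1469598103934665603u64, |h, b| {
        (h ^ *b as u64).wrapping_mul(1099511628211)
    })
}

struct Hash2Code {
    inner: HashMap<u64, Vec<u8>>, // key is always `hash` of value
}

impl Hash2Code {
    fn insert(&mut self, code: Vec<u8>) {
        self.inner.insert(hash(&code), code);
    }
}

impl FromIterator<Vec<u8>> for Hash2Code {
    fn from_iter<II: IntoIterator<Item = Vec<u8>>>(iter: II) -> Self {
        let mut this = Hash2Code { inner: HashMap::new() };
        for code in iter {
            this.insert(code);
        }
        this
    }
}

fn main() {
    // A user-supplied map with a *wrong* key for the second blob:
    let code_db: HashMap<u64, Vec<u8>> = HashMap::from([(1, vec![0xaa]), (999, vec![0xbb])]);
    // Only the values survive; keys are recomputed on the way in.
    let h2c: Hash2Code = code_db.into_values().collect();
    assert!(h2c.inner.contains_key(&hash(&[0xbb])));
    assert!(!h2c.inner.contains_key(&999));
}
```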
@@ -70,7 +79,7 @@ impl TxnInfo {
         hash2code: &mut Hash2Code,
     ) -> anyhow::Result<ProcessedTxnInfo> {
         let mut nodes_used_by_txn = NodesUsedByTxn::default();
-        let mut contract_code_accessed = create_empty_code_access_map();
+        let mut contract_code_accessed = HashSet::from([vec![]]); // we always "access" empty code
         let mut meta = Vec::with_capacity(tx_infos.len());
 
         let all_accounts: BTreeSet<H256> =
 
         for txn in tx_infos.iter() {
             let mut created_accounts = BTreeSet::new();
 
-            for (addr, trace) in txn.traces.iter() {
+            for (
+                addr,
+                TxnTrace {
+                    balance,
+                    nonce,
+                    storage_read,
+                    storage_written,
+                    code_usage,
+                    self_destructed,
+                },
+            ) in txn.traces.iter()
+            {
                 let hashed_addr = hash(addr.as_bytes());
 
-                let storage_writes = trace.storage_written.clone().unwrap_or_default();
+                // record storage changes
+                let storage_written = storage_written.clone().unwrap_or_default();
 
-                let storage_read_keys = trace
-                    .storage_read
+                let storage_read_keys = storage_read
                     .clone()
                     .into_iter()
                     .flat_map(|reads| reads.into_iter());
 
-                let storage_write_keys = storage_writes.keys();
-                let storage_access_keys = storage_read_keys.chain(storage_write_keys.copied());
+                let storage_written_keys = storage_written.keys();
+                let storage_access_keys = storage_read_keys.chain(storage_written_keys.copied());
 
                 if let Some(storage) = nodes_used_by_txn.storage_accesses.get_mut(&hashed_addr) {
                     storage.extend(
@@ -108,14 +128,20 @@ impl TxnInfo {
                     );
                 };
 
-                let storage_trie_change = !storage_writes.is_empty();
-                let code_change = trace.code_usage.is_some();
-                let state_write_occurred = trace.balance.is_some()
-                    || trace.nonce.is_some()
-                    || storage_trie_change
-                    || code_change;
+                // record state changes
+                let state_write = StateWrite {
+                    balance: *balance,
+                    nonce: *nonce,
+                    storage_trie_change: !storage_written.is_empty(),
+                    code_hash: code_usage.as_ref().map(|it| match it {
+                        ContractCodeUsage::Read(hash) => *hash,
+                        ContractCodeUsage::Write(bytes) => hash(bytes),
+                    }),
+                };
+
+                if state_write != StateWrite::default() {
+                    // a write occurred
 
-                if state_write_occurred {
                     // Account creations are flagged to handle reverts.
                     if !all_accounts.contains(&hashed_addr) {
                         created_accounts.insert(hashed_addr);
@@ -129,38 +155,31 @@ impl TxnInfo {
                         .self_destructed_accounts
                         .remove(&hashed_addr);
 
-                    if let Some(state_trie_writes) =
+                    if let Some(existing_state_write) =
                         nodes_used_by_txn.state_writes.get_mut(&hashed_addr)
                     {
                         // The entry already exists, so we update only the relevant fields.
-                        if trace.balance.is_some() {
-                            state_trie_writes.balance = trace.balance;
+                        if state_write.balance.is_some() {
+                            existing_state_write.balance = state_write.balance;
                         }
-                        if trace.nonce.is_some() {
-                            state_trie_writes.nonce = trace.nonce;
+                        if state_write.nonce.is_some() {
+                            existing_state_write.nonce = state_write.nonce;
                         }
-                        if storage_trie_change {
-                            state_trie_writes.storage_trie_change = storage_trie_change;
+                        if state_write.storage_trie_change {
+                            existing_state_write.storage_trie_change =
+                                state_write.storage_trie_change;
                         }
-                        if code_change {
-                            state_trie_writes.code_hash =
-                                trace.code_usage.as_ref().map(|usage| usage.get_code_hash());
+                        if state_write.code_hash.is_some() {
+                            existing_state_write.code_hash = state_write.code_hash;
                         }
                     } else {
-                        let state_trie_writes = StateTrieWrites {
-                            balance: trace.balance,
-                            nonce: trace.nonce,
-                            storage_trie_change,
-                            code_hash: trace.code_usage.as_ref().map(|usage| usage.get_code_hash()),
-                        };
-
                         nodes_used_by_txn
                             .state_writes
-                            .insert(hashed_addr, state_trie_writes);
+                            .insert(hashed_addr, state_write);
                     }
                 }
 
-                for (k, v) in storage_writes.into_iter() {
+                for (k, v) in storage_written.into_iter() {
                     if let Some(storage) = nodes_used_by_txn.storage_writes.get_mut(&hashed_addr) {
                         storage.insert(TrieKey::from_hash(k), rlp::encode(&v).to_vec());
                     } else {
@@ -187,23 +206,18 @@ impl TxnInfo {
                     nodes_used_by_txn.state_accesses.insert(hashed_addr);
                 }
 
-                if let Some(c_usage) = &trace.code_usage {
-                    match c_usage {
-                        ContractCodeUsage::Read(c_hash) => {
-                            if let Entry::Vacant(vacant) = contract_code_accessed.entry(*c_hash) {
-                                vacant.insert(hash2code.resolve(c_hash)?);
-                            }
-                        }
-                        ContractCodeUsage::Write(c_bytes) => {
-                            let c_hash = hash(c_bytes);
-
-                            contract_code_accessed.insert(c_hash, c_bytes.clone());
-                            hash2code.insert_code(c_hash, c_bytes.clone());
-                        }
+                match code_usage {
+                    Some(ContractCodeUsage::Read(hash)) => {
+                        contract_code_accessed.insert(hash2code.get(*hash)?);
+                    }
+                    Some(ContractCodeUsage::Write(code)) => {
+                        contract_code_accessed.insert(code.clone());
+                        hash2code.insert(code.to_vec());
                     }
+                    None => {}
                 }
 
-                if trace.self_destructed.unwrap_or_default() {
+                if self_destructed.unwrap_or_default() {
                     nodes_used_by_txn
                         .self_destructed_accounts
                         .insert(hashed_addr);
@@ -214,13 +228,12 @@ impl TxnInfo {
                 nodes_used_by_txn.state_accesses.insert(hashed_addr);
             }
 
-            let accounts_with_storage_accesses: HashSet<_> = HashSet::from_iter(
-                nodes_used_by_txn
-                    .storage_accesses
-                    .iter()
-                    .filter(|(_, slots)| !slots.is_empty())
-                    .map(|(addr, _)| *addr),
-            );
+            let accounts_with_storage_accesses = nodes_used_by_txn
+                .storage_accesses
+                .iter()
+                .filter(|(_, slots)| !slots.is_empty())
+                .map(|(addr, _)| *addr)
+                .collect::<HashSet<_>>();
 
             let all_accounts_with_non_empty_storage = all_accounts_in_pre_image
                 .iter()
@@ -231,20 +244,17 @@ impl TxnInfo {
                 .map(|(addr, data)| (*addr, data.storage_root));
 
             nodes_used_by_txn
-                .state_accounts_with_no_accesses_but_storage_tries
+                .accts_with_unaccessed_storage
                 .extend(accounts_with_storage_but_no_storage_accesses);
 
-            let txn_bytes = match txn.meta.byte_code.is_empty() {
-                false => Some(txn.meta.byte_code.clone()),
-                true => None,
-            };
-
-            let receipt_node_bytes =
-                process_rlped_receipt_node_bytes(txn.meta.new_receipt_trie_node_byte.clone());
-
             meta.push(TxnMetaState {
-                txn_bytes,
-                receipt_node_bytes,
+                txn_bytes: match txn.meta.byte_code.is_empty() {
+                    false => Some(txn.meta.byte_code.clone()),
+                    true => None,
+                },
+                receipt_node_bytes: check_receipt_bytes(
+                    txn.meta.new_receipt_trie_node_byte.clone(),
+                )?,
                 gas_used: txn.meta.gas_used,
                 created_accounts,
             });
@@ -258,35 +268,31 @@ impl TxnInfo {
     }
 }
 
-fn process_rlped_receipt_node_bytes(raw_bytes: Vec<u8>) -> Vec<u8> {
-    match rlp::decode::<LegacyReceiptRlp>(&raw_bytes) {
-        Ok(_) => raw_bytes,
+fn check_receipt_bytes(bytes: Vec<u8>) -> anyhow::Result<Vec<u8>> {
+    match rlp::decode::<LegacyReceiptRlp>(&bytes) {
+        Ok(_) => Ok(bytes),
         Err(_) => {
-            // Must be non-legacy.
-            rlp::decode::<Vec<u8>>(&raw_bytes).unwrap()
+            rlp::decode(&bytes).context("couldn't decode receipt as a legacy receipt or raw bytes")
         }
     }
 }
 
-fn create_empty_code_access_map() -> HashMap<H256, Vec<u8>> {
-    HashMap::from_iter(once((EMPTY_CODE_HASH, Vec::new())))
-}
-
 /// Note that "*_accesses" includes writes.
 #[derive(Debug, Default)]
 pub(crate) struct NodesUsedByTxn {
     pub state_accesses: HashSet<H256>,
-    pub state_writes: HashMap<H256, StateTrieWrites>,
+    pub state_writes: HashMap<H256, StateWrite>,
     // Note: All entries in `storage_writes` also appear in `storage_accesses`.
     pub storage_accesses: HashMap<H256, Vec<TrieKey>>,
     pub storage_writes: HashMap<H256, HashMap<TrieKey, Vec<u8>>>,
-    pub state_accounts_with_no_accesses_but_storage_tries: HashMap<H256, H256>,
+    /// Hashed address -> storage root.
+    pub accts_with_unaccessed_storage: HashMap<H256, H256>,
     pub self_destructed_accounts: HashSet<H256>,
 }
 
-#[derive(Debug)]
-pub(crate) struct StateTrieWrites {
+#[derive(Debug, Default, PartialEq)]
+pub(crate) struct StateWrite {
     pub balance: Option<U256>,
     pub nonce: Option<U256>,
     pub storage_trie_change: bool,
@@ -295,6 +301,7 @@ pub(crate) struct StateWrite {
 
 #[derive(Debug, Default)]
 pub(crate) struct TxnMetaState {
+    /// [`None`] if this is a dummy transaction inserted for padding.
     pub txn_bytes: Option<Vec<u8>>,
     pub receipt_node_bytes: Vec<u8>,
     pub gas_used: u64,
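Reviewer note: replacing the four `*_occurred` booleans with `state_write != StateWrite::default()` is why `StateWrite` now derives `Default` and `PartialEq` — "no write happened" becomes representable as the default value. A compact illustration (narrower integer types stand in for `U256`/`H256`):

```rust
#[derive(Debug, Default, PartialEq)]
struct StateWrite {
    balance: Option<u64>, // U256 in the real code
    nonce: Option<u64>,
    storage_trie_change: bool,
    code_hash: Option<u64>, // H256 in the real code
}

fn main() {
    let untouched = StateWrite::default();
    let touched = StateWrite {
        nonce: Some(7),
        ..Default::default()
    };
    // One comparison replaces `balance.is_some() || nonce.is_some() || ...`.
    assert!(untouched == StateWrite::default());
    assert!(touched != StateWrite::default());
}
```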
diff --git a/zero_bin/leader/src/client.rs b/zero_bin/leader/src/client.rs
index 9510b1d75..ecf8a969c 100644
--- a/zero_bin/leader/src/client.rs
+++ b/zero_bin/leader/src/client.rs
@@ -1,6 +1,8 @@
 use std::io::Write;
 use std::path::PathBuf;
+use std::sync::Arc;
 
+use alloy::rpc::types::{BlockId, BlockNumberOrTag, BlockTransactionsKind};
 use alloy::transports::http::reqwest::Url;
 use anyhow::Result;
 use paladin::runtime::Runtime;
@@ -35,31 +37,52 @@ pub(crate) async fn client_main(
     block_interval: BlockInterval,
     mut params: ProofParams,
 ) -> Result<()> {
-    let cached_provider = rpc::provider::CachedProvider::new(build_http_retry_provider(
-        rpc_params.rpc_url.clone(),
-        rpc_params.backoff,
-        rpc_params.max_retries,
+    use futures::{FutureExt, StreamExt};
+
+    let cached_provider = Arc::new(rpc::provider::CachedProvider::new(
+        build_http_retry_provider(
+            rpc_params.rpc_url.clone(),
+            rpc_params.backoff,
+            rpc_params.max_retries,
+        ),
     ));
 
-    let prover_input = rpc::prover_input(
-        &cached_provider,
-        block_interval,
-        params.checkpoint_block_number.into(),
-        rpc_params.rpc_type,
-    )
-    .await?;
+    // Grab the state trie root of the interval's checkpoint block.
+    let checkpoint_state_trie_root = cached_provider
+        .get_block(
+            params.checkpoint_block_number.into(),
+            BlockTransactionsKind::Hashes,
+        )
+        .await?
+        .header
+        .state_root;
+
+    let mut block_prover_inputs = Vec::new();
+    let mut block_interval = block_interval.into_bounded_stream()?;
+    while let Some(block_num) = block_interval.next().await {
+        let block_id = BlockId::Number(BlockNumberOrTag::Number(block_num));
+        // Queue up a future computing the prover input for this particular block.
+        let block_prover_input = rpc::block_prover_input(
+            cached_provider.clone(),
+            block_id,
+            checkpoint_state_trie_root,
+            rpc_params.rpc_type,
+        )
+        .boxed();
+        block_prover_inputs.push(block_prover_input);
+    }
 
     // If `keep_intermediate_proofs` is not set we only keep the last block
     // proof from the interval. It contains all the necessary information to
     // verify the whole sequence.
-    let proved_blocks = prover_input
-        .prove(
-            &runtime,
-            params.previous_proof.take(),
-            params.prover_config,
-            params.proof_output_dir.clone(),
-        )
-        .await;
+    let proved_blocks = prover::prove(
+        block_prover_inputs,
+        &runtime,
+        params.previous_proof.take(),
+        params.prover_config,
+        params.proof_output_dir.clone(),
+    )
+    .await;
     runtime.close().await?;
     let proved_blocks = proved_blocks?;
diff --git a/zero_bin/leader/src/stdio.rs b/zero_bin/leader/src/stdio.rs
index 0acaf88ad..88dd20aac 100644
--- a/zero_bin/leader/src/stdio.rs
+++ b/zero_bin/leader/src/stdio.rs
@@ -3,7 +3,7 @@ use std::io::{Read, Write};
 use anyhow::Result;
 use paladin::runtime::Runtime;
 use proof_gen::proof_types::GeneratedBlockProof;
-use prover::{ProverConfig, ProverInput};
+use prover::{BlockProverInput, BlockProverInputFuture, ProverConfig};
 use tracing::info;
 
 /// The main function for the stdio mode.
@@ -16,13 +16,13 @@ pub(crate) async fn stdio_main(
     std::io::stdin().read_to_string(&mut buffer)?;
 
     let des = &mut serde_json::Deserializer::from_str(&buffer);
-    let prover_input = ProverInput {
-        blocks: serde_path_to_error::deserialize(des)?,
-    };
+    let block_prover_inputs = serde_path_to_error::deserialize::<_, Vec<BlockProverInput>>(des)?
+        .into_iter()
+        .map(Into::into)
+        .collect::<Vec<BlockProverInputFuture>>();
 
-    let proved_blocks = prover_input
-        .prove(&runtime, previous, prover_config, None)
-        .await;
+    let proved_blocks =
+        prover::prove(block_prover_inputs, &runtime, previous, prover_config, None).await;
 
     runtime.close().await?;
     let proved_blocks = proved_blocks?;
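Reviewer note: stdio's inputs are already in memory, while the RPC path fetches them lazily; both now feed `prover::prove` as `BlockProverInputFuture`s. The `From` impl just wraps a ready value in an immediately-resolved boxed future. A standalone sketch of that shape (tokio and anyhow assumed as dependencies; names hypothetical):

```rust
use std::future::Future;
use std::pin::Pin;

type InputFuture = Pin<Box<dyn Future<Output = anyhow::Result<String>> + Send>>;

// Wrap an already-available input so it has the same shape as a lazy RPC fetch.
fn ready_input(item: String) -> InputFuture {
    Box::pin(async move { Ok(item) })
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let fut: InputFuture = ready_input("block 42".to_owned());
    assert_eq!(fut.await?, "block 42");
    Ok(())
}
```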
diff --git a/zero_bin/prover/src/lib.rs b/zero_bin/prover/src/lib.rs
index d5d62f430..117bbffd9 100644
--- a/zero_bin/prover/src/lib.rs
+++ b/zero_bin/prover/src/lib.rs
@@ -24,7 +24,20 @@ pub struct ProverConfig {
     pub test_only: bool,
 }
 
-#[derive(Debug, Deserialize, Serialize)]
+pub type BlockProverInputFuture =
+    std::pin::Pin<Box<dyn Future<Output = Result<BlockProverInput>> + Send>>;
+
+impl From<BlockProverInput> for BlockProverInputFuture {
+    fn from(item: BlockProverInput) -> Self {
+        async fn _from(item: BlockProverInput) -> Result<BlockProverInput> {
+            Ok(item)
+        }
+        Box::pin(_from(item))
+    }
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
 pub struct BlockProverInput {
     pub block_trace: BlockTrace,
     pub other_data: OtherBlockData,
@@ -185,105 +198,104 @@ impl BlockProverInput {
     }
 }
 
-#[derive(Debug, Deserialize, Serialize)]
-pub struct ProverInput {
-    pub blocks: Vec<BlockProverInput>,
-}
+/// Prove all the blocks in the input, or simulate their execution depending on
+/// the selected prover configuration. Returns the list of proven block numbers,
+/// along with the generated block proofs for any blocks whose proof data was
+/// not saved to disk.
+pub async fn prove(
+    block_prover_inputs: Vec<BlockProverInputFuture>,
+    runtime: &Runtime,
+    previous_proof: Option<GeneratedBlockProof>,
+    prover_config: ProverConfig,
+    proof_output_dir: Option<PathBuf>,
+) -> Result<Vec<(BlockNumber, Option<GeneratedBlockProof>)>> {
+    let mut prev: Option<BoxFuture<Result<GeneratedBlockProof>>> =
+        previous_proof.map(|proof| Box::pin(futures::future::ok(proof)) as BoxFuture<_>);
+
+    let mut results = FuturesOrdered::new();
+    for block_prover_input in block_prover_inputs {
+        let (tx, rx) = oneshot::channel::<GeneratedBlockProof>();
+        let proof_output_dir = proof_output_dir.clone();
+        let previous_block_proof = prev.take();
+        let fut = async move {
+            // Get the prover input data from the external source (e.g. Erigon node).
+            let block = block_prover_input.await?;
+            let block_number = block.get_block_number();
+            info!("Proving block {block_number}");
+
+            // Prove the block
+            let block_proof = if prover_config.test_only {
+                block
+                    .prove_test(runtime, previous_block_proof, prover_config)
+                    .then(move |proof| async move {
+                        let proof = proof?;
+                        let block_number = proof.b_height;
+
+                        // Write the latest generated proof to disk if `proof_output_dir`
+                        // is provided, or return it in the function result otherwise.
+                        let return_proof: Option<GeneratedBlockProof> =
+                            if let Some(output_dir) = proof_output_dir {
+                                write_proof_to_dir(output_dir, &proof).await?;
+                                None
+                            } else {
+                                Some(proof.clone())
+                            };
+
+                        if tx.send(proof).is_err() {
+                            anyhow::bail!("Failed to send proof");
+                        }
+
+                        Ok((block_number, return_proof))
+                    })
+                    .await?
+            } else {
+                block
+                    .prove(runtime, previous_block_proof, prover_config)
+                    .then(move |proof| async move {
+                        let proof = proof?;
+                        let block_number = proof.b_height;
+
+                        // Write the latest generated proof to disk if `proof_output_dir`
+                        // is provided, or return it in the function result otherwise.
+                        let return_proof: Option<GeneratedBlockProof> =
+                            if let Some(output_dir) = proof_output_dir {
+                                write_proof_to_dir(output_dir, &proof).await?;
+                                None
+                            } else {
+                                Some(proof.clone())
+                            };
+
+                        if tx.send(proof).is_err() {
+                            anyhow::bail!("Failed to send proof");
+                        }
+
+                        Ok((block_number, return_proof))
+                    })
+                    .await?
+            };
+
+            Ok(block_proof)
+        }
+        .boxed();
+        prev = Some(Box::pin(rx.map_err(anyhow::Error::new)));
+        results.push_back(fut);
+    }
+
+    results.try_collect().await
+}
 
-impl ProverInput {
-    /// Prove all the blocks in the input.
-    /// Return the list of block numbers that are proved and if the proof data
-    /// is not saved to disk, return the generated block proofs as well.
-    pub async fn prove(
-        self,
-        runtime: &Runtime,
-        previous_proof: Option<GeneratedBlockProof>,
-        prover_config: ProverConfig,
-        proof_output_dir: Option<PathBuf>,
-    ) -> Result<Vec<(BlockNumber, Option<GeneratedBlockProof>)>> {
-        let mut prev: Option<BoxFuture<Result<GeneratedBlockProof>>> =
-            previous_proof.map(|proof| Box::pin(futures::future::ok(proof)) as BoxFuture<_>);
-
-        let results: FuturesOrdered<_> = self
-            .blocks
-            .into_iter()
-            .map(|block| {
-                let (tx, rx) = oneshot::channel::<GeneratedBlockProof>();
-
-                // Prove the block
-                let proof_output_dir = proof_output_dir.clone();
-                let fut = if prover_config.test_only {
-                    block
-                        .prove_test(runtime, prev.take(), prover_config)
-                        .then(move |proof| async move {
-                            let proof = proof?;
-                            let block_number = proof.b_height;
-
-                            if tx.send(proof).is_err() {
-                                anyhow::bail!("Failed to send proof");
-                            }
-
-                            // We ignore the returned dummy proof in test-only mode.
-                            Ok((block_number, None))
-                        })
-                        .boxed()
-                } else {
-                    block
-                        .prove(runtime, prev.take(), prover_config)
-                        .then(move |proof| async move {
-                            let proof = proof?;
-                            let block_number = proof.b_height;
-
-                            // Write latest generated proof to disk if proof_output_dir is provided.
-                            let return_proof: Option<GeneratedBlockProof> =
-                                if proof_output_dir.is_some() {
-                                    ProverInput::write_proof(proof_output_dir, &proof).await?;
-                                    None
-                                } else {
-                                    Some(proof.clone())
-                                };
-
-                            if tx.send(proof).is_err() {
-                                anyhow::bail!("Failed to send proof");
-                            }
-
-                            Ok((block_number, return_proof))
-                        })
-                        .boxed()
-                };
-
-                prev = Some(Box::pin(rx.map_err(anyhow::Error::new)));
-
-                fut
-            })
-            .collect();
+/// Write the proof to the `output_dir` directory.
+async fn write_proof_to_dir(output_dir: PathBuf, proof: &GeneratedBlockProof) -> Result<()> {
+    let proof_serialized = serde_json::to_vec(proof)?;
+    let block_proof_file_path =
+        generate_block_proof_file_name(&output_dir.to_str(), proof.b_height);
 
-        results.try_collect().await
+    if let Some(parent) = block_proof_file_path.parent() {
+        tokio::fs::create_dir_all(parent).await?;
     }
 
-    /// Write the proof to the disk (if `output_dir` is provided) or stdout.
-    pub(crate) async fn write_proof(
-        output_dir: Option<PathBuf>,
-        proof: &GeneratedBlockProof,
-    ) -> Result<()> {
-        let proof_serialized = serde_json::to_vec(proof)?;
-        let block_proof_file_path =
-            output_dir.map(|path| generate_block_proof_file_name(&path.to_str(), proof.b_height));
-        match block_proof_file_path {
-            Some(p) => {
-                if let Some(parent) = p.parent() {
-                    tokio::fs::create_dir_all(parent).await?;
-                }
-
-                let mut f = tokio::fs::File::create(p).await?;
-                f.write_all(&proof_serialized)
-                    .await
-                    .context("Failed to write proof to disk")
-            }
-            None => tokio::io::stdout()
-                .write_all(&proof_serialized)
-                .await
-                .context("Failed to write proof to stdout"),
-        }
-    }
+    let mut f = tokio::fs::File::create(block_proof_file_path).await?;
+    f.write_all(&proof_serialized)
+        .await
+        .context("Failed to write proof to disk")
 }
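Reviewer note: the chaining in `prove` is worth spelling out — each block's task receives the previous proof through a `oneshot` receiver and publishes its own through a sender, so all tasks can be queued up front while still proving strictly in order. A reduced, runnable model of the same wiring (plain numbers stand in for proofs; futures/tokio/anyhow assumed as dependencies):

```rust
use futures::{future::BoxFuture, FutureExt, TryFutureExt};
use tokio::sync::oneshot;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // `prev` always holds the *previous* block's proof, as a future.
    let mut prev: Option<BoxFuture<'static, anyhow::Result<u64>>> = None;
    let mut tasks = Vec::new();

    for block in 1u64..=3 {
        let (tx, rx) = oneshot::channel::<u64>();
        let previous = prev.take();
        tasks.push(async move {
            // Wait for the previous block's proof (if any) before this one.
            let base = match previous {
                Some(p) => p.await?,
                None => 0,
            };
            let proof = base + block; // stand-in for the real proving step
            tx.send(proof)
                .map_err(|_| anyhow::anyhow!("failed to send proof"))?;
            anyhow::Ok(proof)
        });
        // The next iteration's task will wait on this receiver.
        prev = Some(rx.map_err(anyhow::Error::new).boxed());
    }

    // All tasks are queued up front; ordering comes purely from the channels.
    let proofs = futures::future::try_join_all(tasks).await?;
    assert_eq!(proofs, vec![1, 3, 6]);
    Ok(())
}
```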
diff --git a/zero_bin/rpc/Cargo.toml b/zero_bin/rpc/Cargo.toml
index 14f447cef..cbd2df11d 100644
--- a/zero_bin/rpc/Cargo.toml
+++ b/zero_bin/rpc/Cargo.toml
@@ -26,6 +26,7 @@ tower = { workspace = true, features = ["retry"] }
 trace_decoder = { workspace = true }
 tracing-subscriber = { workspace = true }
 url = { workspace = true }
+itertools = { workspace = true }
 
 # Local dependencies
 compat = { workspace = true }
diff --git a/zero_bin/rpc/src/jerigon.rs b/zero_bin/rpc/src/jerigon.rs
index 470b2dffb..891421971 100644
--- a/zero_bin/rpc/src/jerigon.rs
+++ b/zero_bin/rpc/src/jerigon.rs
@@ -19,7 +19,7 @@ pub struct ZeroTxResult {
 }
 
 pub async fn block_prover_input<ProviderT, TransportT>(
-    cached_provider: &CachedProvider<ProviderT, TransportT>,
+    cached_provider: std::sync::Arc<CachedProvider<ProviderT, TransportT>>,
     target_block_id: BlockId,
     checkpoint_state_trie_root: B256,
 ) -> anyhow::Result<BlockProverInput>
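Reviewer note: the `&CachedProvider` → `Arc<CachedProvider>` switch across these RPC entry points is what lets each per-block future own its provider handle instead of borrowing it from the caller (a borrow would tie the whole batch of boxed futures to one lifetime). A generic sketch of that ownership pattern, with hypothetical types:

```rust
use std::sync::Arc;

struct CachedProvider {
    url: String, // stand-in for the real provider state
}

impl CachedProvider {
    async fn get_block(&self, number: u64) -> String {
        format!("{}#{}", self.url, number)
    }
}

#[tokio::main]
async fn main() {
    let provider = Arc::new(CachedProvider {
        url: "http://localhost:8545".into(),
    });

    let mut handles = Vec::new();
    for n in 0..3u64 {
        // Each spawned task owns a clone of the handle: no lifetime ties to `main`.
        let provider = Arc::clone(&provider);
        handles.push(tokio::spawn(async move { provider.get_block(n).await }));
    }
    for h in handles {
        println!("{}", h.await.unwrap());
    }
}
```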
diff --git a/zero_bin/rpc/src/lib.rs b/zero_bin/rpc/src/lib.rs
index 345cf8c96..cc6ddf2f1 100644
--- a/zero_bin/rpc/src/lib.rs
+++ b/zero_bin/rpc/src/lib.rs
@@ -1,7 +1,9 @@
+use std::sync::Arc;
+
 use alloy::{
     primitives::B256,
     providers::Provider,
-    rpc::types::eth::{BlockId, BlockNumberOrTag, BlockTransactionsKind, Withdrawal},
+    rpc::types::eth::{BlockId, BlockTransactionsKind, Withdrawal},
     transports::Transport,
 };
 use anyhow::Context as _;
@@ -9,9 +11,8 @@ use clap::ValueEnum;
 use compat::Compat;
 use evm_arithmetization::proof::{BlockHashes, BlockMetadata};
 use futures::{StreamExt as _, TryStreamExt as _};
-use prover::ProverInput;
+use prover::BlockProverInput;
 use trace_decoder::{BlockLevelData, OtherBlockData};
-use zero_bin_common::block_interval::BlockInterval;
 
 pub mod jerigon;
 pub mod native;
@@ -23,56 +24,36 @@ use crate::provider::CachedProvider;
 const PREVIOUS_HASHES_COUNT: usize = 256;
 
 /// The RPC type.
-#[derive(ValueEnum, Clone, Debug)]
+#[derive(ValueEnum, Clone, Debug, Copy)]
 pub enum RpcType {
     Jerigon,
     Native,
 }
 
-/// Obtain the prover input for a given block interval
-pub async fn prover_input<ProviderT, TransportT>(
-    cached_provider: &CachedProvider<ProviderT, TransportT>,
-    block_interval: BlockInterval,
-    checkpoint_block_id: BlockId,
+/// Obtain the prover input for one block
+pub async fn block_prover_input<ProviderT, TransportT>(
+    cached_provider: Arc<CachedProvider<ProviderT, TransportT>>,
+    block_id: BlockId,
+    checkpoint_state_trie_root: B256,
     rpc_type: RpcType,
-) -> anyhow::Result<ProverInput>
+) -> Result<BlockProverInput, anyhow::Error>
 where
     ProviderT: Provider<TransportT>,
     TransportT: Transport + Clone,
 {
-    // Grab interval checkpoint block state trie
-    let checkpoint_state_trie_root = cached_provider
-        .get_block(checkpoint_block_id, BlockTransactionsKind::Hashes)
-        .await?
-        .header
-        .state_root;
-
-    let mut block_proofs = Vec::new();
-    let mut block_interval = block_interval.into_bounded_stream()?;
-
-    while let Some(block_num) = block_interval.next().await {
-        let block_id = BlockId::Number(BlockNumberOrTag::Number(block_num));
-        let block_prover_input = match rpc_type {
-            RpcType::Jerigon => {
-                jerigon::block_prover_input(cached_provider, block_id, checkpoint_state_trie_root)
-                    .await?
-            }
-            RpcType::Native => {
-                native::block_prover_input(cached_provider, block_id, checkpoint_state_trie_root)
-                    .await?
-            }
-        };
-
-        block_proofs.push(block_prover_input);
+    match rpc_type {
+        RpcType::Jerigon => {
+            jerigon::block_prover_input(cached_provider, block_id, checkpoint_state_trie_root).await
+        }
+        RpcType::Native => {
+            native::block_prover_input(cached_provider, block_id, checkpoint_state_trie_root).await
+        }
     }
-    Ok(ProverInput {
-        blocks: block_proofs,
-    })
 }
 
 /// Fetches other block data
 async fn fetch_other_block_data<ProviderT, TransportT>(
-    cached_provider: &CachedProvider<ProviderT, TransportT>,
+    cached_provider: Arc<CachedProvider<ProviderT, TransportT>>,
     target_block_id: BlockId,
     checkpoint_state_trie_root: B256,
 ) -> anyhow::Result<OtherBlockData>
@@ -80,6 +61,7 @@ where
     ProviderT: Provider<TransportT>,
     TransportT: Transport + Clone,
 {
+    use itertools::Itertools;
     let target_block = cached_provider
         .get_block(target_block_id, BlockTransactionsKind::Hashes)
         .await?;
@@ -102,28 +84,33 @@ where
         })
         .take(PREVIOUS_HASHES_COUNT + 1)
         .filter(|i| *i >= 0)
+        .chunks(2)
+        .into_iter()
+        .map(|mut chunk| {
+            // Convert to a tuple of (current block, optional previous block).
+            let first = chunk
+                .next()
+                .expect("chunks(2) never yields an empty chunk");
+            let second = chunk.next();
+            (first, second)
+        })
         .collect::<Vec<_>>();
 
+    let concurrency = previous_block_numbers.len();
     let collected_hashes = futures::stream::iter(
         previous_block_numbers
-            .chunks(2) // we get hash for previous and current block with one request
-            .map(|block_numbers| {
+            .into_iter() // we get hashes for the previous and current block with one request
+            .map(|(current_block_number, previous_block_number)| {
                 let cached_provider = &cached_provider;
-                let block_num = &block_numbers[0];
-                let previos_block_num = if block_numbers.len() > 1 {
-                    Some(block_numbers[1])
-                } else {
-                    // For genesis block
-                    None
-                };
+                let block_num = current_block_number;
                 async move {
                     let block = cached_provider
-                        .get_block((*block_num as u64).into(), BlockTransactionsKind::Hashes)
+                        .get_block((block_num as u64).into(), BlockTransactionsKind::Hashes)
                         .await
                         .context("couldn't get block")?;
                     anyhow::Ok([
-                        (block.header.hash, Some(*block_num)),
-                        (Some(block.header.parent_hash), previos_block_num),
+                        (block.header.hash, Some(block_num)),
+                        (Some(block.header.parent_hash), previous_block_number),
                     ])
                 }
             }),
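Reviewer note: the new `chunks(2)` pipeline is why `itertools` lands in `zero_bin/rpc`'s `Cargo.toml` — it pairs each fetched block with the optional block before it, so one request covers two hashes, and the genesis block naturally gets `None`. The pairing logic is verifiable in isolation:

```rust
use itertools::Itertools;

fn main() {
    let block_numbers: Vec<i64> = vec![5, 4, 3, 2, 1];

    // Pair (current block, optional previous block), two blocks per RPC request.
    let pairs = block_numbers
        .into_iter()
        .chunks(2)
        .into_iter()
        .map(|mut chunk| {
            let first = chunk.next().expect("chunks(2) never yields an empty chunk");
            let second = chunk.next();
            (first, second)
        })
        .collect::<Vec<_>>();

    assert_eq!(pairs, vec![(5, Some(4)), (3, Some(2)), (1, None)]);
}
```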
diff --git a/zero_bin/rpc/src/main.rs b/zero_bin/rpc/src/main.rs
index 444e89e3b..3c72ac902 100644
--- a/zero_bin/rpc/src/main.rs
+++ b/zero_bin/rpc/src/main.rs
@@ -1,7 +1,10 @@
-use std::{env, io};
+use std::env;
+use std::sync::Arc;
 
 use alloy::rpc::types::eth::BlockId;
+use alloy::rpc::types::{BlockNumberOrTag, BlockTransactionsKind};
 use clap::{Parser, ValueHint};
+use futures::StreamExt;
 use rpc::provider::CachedProvider;
 use rpc::{retry::build_http_retry_provider, RpcType};
 use tracing_subscriber::{prelude::*, EnvFilter};
@@ -55,22 +58,36 @@ impl Cli {
                 checkpoint_block_number.unwrap_or((start_block - 1).into());
             let block_interval = BlockInterval::Range(start_block..end_block + 1);
 
-            let cached_provider = CachedProvider::new(build_http_retry_provider(
+            let cached_provider = Arc::new(CachedProvider::new(build_http_retry_provider(
                 rpc_url.clone(),
                 backoff,
                 max_retries,
-            ));
+            )));
 
-            // Retrieve prover input from the Erigon node
-            let prover_input = rpc::prover_input(
-                &cached_provider,
-                block_interval,
-                checkpoint_block_number,
-                rpc_type,
-            )
-            .await?;
+            // Grab the state trie root of the interval's checkpoint block.
+            let checkpoint_state_trie_root = cached_provider
+                .get_block(checkpoint_block_number, BlockTransactionsKind::Hashes)
+                .await?
+                .header
+                .state_root;
 
-            serde_json::to_writer_pretty(io::stdout(), &prover_input.blocks)?;
+            let mut block_prover_inputs = Vec::new();
+            let mut block_interval = block_interval.clone().into_bounded_stream()?;
+            while let Some(block_num) = block_interval.next().await {
+                let block_id = BlockId::Number(BlockNumberOrTag::Number(block_num));
+                // Get the prover input for this particular block.
+                let result = rpc::block_prover_input(
+                    cached_provider.clone(),
+                    block_id,
+                    checkpoint_state_trie_root,
+                    rpc_type,
+                )
+                .await?;
+
+                block_prover_inputs.push(result);
+            }
+
+            serde_json::to_writer_pretty(std::io::stdout(), &block_prover_inputs)?;
         }
     }
     Ok(())
diff --git a/zero_bin/rpc/src/native/mod.rs b/zero_bin/rpc/src/native/mod.rs
index 892a799d6..1f61d7b26 100644
--- a/zero_bin/rpc/src/native/mod.rs
+++ b/zero_bin/rpc/src/native/mod.rs
@@ -1,4 +1,5 @@
 use std::collections::HashMap;
+use std::sync::Arc;
 
 use alloy::{
     primitives::B256,
@@ -19,7 +20,7 @@ type CodeDb = HashMap<__compat_primitive_types::H256, Vec<u8>>;
 
 /// Fetches the prover input for the given BlockId.
 pub async fn block_prover_input<ProviderT, TransportT>(
-    provider: &CachedProvider<ProviderT, TransportT>,
+    provider: Arc<CachedProvider<ProviderT, TransportT>>,
     block_number: BlockId,
     checkpoint_state_trie_root: B256,
 ) -> anyhow::Result<BlockProverInput>
@@ -28,8 +29,8 @@ where
     TransportT: Transport + Clone,
 {
     let (block_trace, other_data) = try_join!(
-        process_block_trace(provider, block_number),
-        crate::fetch_other_block_data(provider, block_number, checkpoint_state_trie_root,)
+        process_block_trace(provider.clone(), block_number),
+        crate::fetch_other_block_data(provider.clone(), block_number, checkpoint_state_trie_root,)
     )?;
 
     Ok(BlockProverInput {
@@ -40,7 +41,7 @@ where
 
 /// Processes the block with the given block number and returns the block trace.
 async fn process_block_trace<ProviderT, TransportT>(
-    cached_provider: &CachedProvider<ProviderT, TransportT>,
+    cached_provider: Arc<CachedProvider<ProviderT, TransportT>>,
     block_number: BlockId,
 ) -> anyhow::Result<BlockTrace>
 where
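Reviewer note: two recurring shapes in the native-path changes — `try_join!` runs the trace fetch and the auxiliary-data fetch concurrently, and every closure clones its `Arc` *before* the `async move` block so each future owns its handle (which is also what the `state.rs` hunks below do). A combined, self-contained sketch with hypothetical fetchers:

```rust
use std::sync::Arc;

async fn fetch_trace(provider: Arc<String>) -> anyhow::Result<String> {
    Ok(format!("trace via {provider}"))
}

async fn fetch_other_data(provider: Arc<String>) -> anyhow::Result<String> {
    Ok(format!("other data via {provider}"))
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let provider = Arc::new("rpc-handle".to_string());

    // Both fetches run concurrently, each owning its own clone of the handle.
    let (trace, other) = tokio::try_join!(
        fetch_trace(provider.clone()),
        fetch_other_data(provider.clone()),
    )?;
    println!("{trace} / {other}");

    // The per-account pattern from `state.rs`: clone *outside* `async move`,
    // so the future doesn't borrow `provider` from the enclosing scope.
    let futs = (0..3u64).map(|n| {
        let provider = provider.clone();
        async move { format!("{provider}:{n}") }
    });
    let results = futures::future::join_all(futs).await;
    assert_eq!(results, ["rpc-handle:0", "rpc-handle:1", "rpc-handle:2"]);
    Ok(())
}
```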
diff --git a/zero_bin/rpc/src/native/state.rs b/zero_bin/rpc/src/native/state.rs
index 5fd9b539c..b5017b394 100644
--- a/zero_bin/rpc/src/native/state.rs
+++ b/zero_bin/rpc/src/native/state.rs
@@ -1,4 +1,5 @@
 use std::collections::{HashMap, HashSet};
+use std::sync::Arc;
 
 use alloy::{
     primitives::{keccak256, Address, StorageKey, B256, U256},
@@ -20,7 +21,7 @@ use crate::Compat;
 
 /// Processes the state witness for the given block.
 pub async fn process_state_witness<ProviderT, TransportT>(
-    cached_provider: &CachedProvider<ProviderT, TransportT>,
+    cached_provider: Arc<CachedProvider<ProviderT, TransportT>>,
     block: Block,
     txn_infos: &[TxnInfo],
 ) -> anyhow::Result
@@ -115,7 +116,7 @@ fn insert_beacon_roots_update(
 async fn generate_state_witness<ProviderT, TransportT>(
     prev_state_root: B256,
     accounts_state: HashMap<Address, HashSet<StorageKey>>,
-    cached_provider: &CachedProvider<ProviderT, TransportT>,
+    cached_provider: Arc<CachedProvider<ProviderT, TransportT>>,
     block_number: u64,
 ) -> anyhow::Result<(
     PartialTrieBuilder<HashedPartialTrie>,
@@ -164,7 +165,7 @@ where
 /// Fetches the proof data for the given accounts and associated storage keys.
 async fn fetch_proof_data<ProviderT, TransportT>(
     accounts_state: HashMap<Address, HashSet<StorageKey>>,
-    provider: &CachedProvider<ProviderT, TransportT>,
+    provider: Arc<CachedProvider<ProviderT, TransportT>>,
     block_number: u64,
 ) -> anyhow::Result<(
     Vec<(Address, EIP1186AccountProofResponse)>,
@@ -177,20 +178,23 @@ where
     let account_proofs_fut = accounts_state
         .clone()
         .into_iter()
-        .map(|(address, keys)| async move {
-            let proof = provider
-                .as_provider()
-                .get_proof(address, keys.into_iter().collect())
-                .block_id((block_number - 1).into())
-                .await
-                .context("Failed to get proof for account")?;
-            anyhow::Result::Ok((address, proof))
+        .map(|(address, keys)| {
+            let provider = provider.clone();
+            async move {
+                let proof = provider
+                    .as_provider()
+                    .get_proof(address, keys.into_iter().collect())
+                    .block_id((block_number - 1).into())
+                    .await
+                    .context("Failed to get proof for account")?;
+                anyhow::Result::Ok((address, proof))
+            }
         })
        .collect::<Vec<_>>();
 
-    let next_account_proofs_fut = accounts_state
-        .into_iter()
-        .map(|(address, keys)| async move {
+    let next_account_proofs_fut = accounts_state.into_iter().map(|(address, keys)| {
+        let provider = provider.clone();
+        async move {
             let proof = provider
                 .as_provider()
                 .get_proof(address, keys.into_iter().collect())
@@ -198,7 +202,8 @@ where
                 .await
                 .context("Failed to get proof for account")?;
             anyhow::Result::Ok((address, proof))
-        });
+        }
+    });
 
     try_join(
         try_join_all(account_proofs_fut),