From 6bcf06b2e8af11b14d25238ce88111c216ffb974 Mon Sep 17 00:00:00 2001 From: Marko Atanasievski <atanmarko@users.noreply.github.com> Date: Thu, 22 Aug 2024 12:14:04 +0200 Subject: [PATCH 1/3] feat: retrieve prover input per block (#499) * feat: retrieve prover input per block * fix: cleanup * fix: into implementation * fix: nitpick * fix: review * fix: review and cleanup --- Cargo.lock | 1 + zero_bin/leader/src/client.rs | 61 ++++++++---- zero_bin/leader/src/stdio.rs | 20 ++-- zero_bin/prover/src/lib.rs | 165 +++++++++++++++---------------- zero_bin/rpc/Cargo.toml | 1 + zero_bin/rpc/src/jerigon.rs | 2 +- zero_bin/rpc/src/lib.rs | 87 +++++++--------- zero_bin/rpc/src/main.rs | 41 +++++--- zero_bin/rpc/src/native/mod.rs | 9 +- zero_bin/rpc/src/native/state.rs | 35 ++++--- 10 files changed, 231 insertions(+), 191 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 33b8be650..a75cf45ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4167,6 +4167,7 @@ dependencies = [ "evm_arithmetization", "futures", "hex", + "itertools 0.13.0", "lru", "mpt_trie", "primitive-types 0.12.2", diff --git a/zero_bin/leader/src/client.rs b/zero_bin/leader/src/client.rs index 555bc74aa..8fbcf1bd8 100644 --- a/zero_bin/leader/src/client.rs +++ b/zero_bin/leader/src/client.rs @@ -1,6 +1,8 @@ use std::io::Write; use std::path::PathBuf; +use std::sync::Arc; +use alloy::rpc::types::{BlockId, BlockNumberOrTag, BlockTransactionsKind}; use alloy::transports::http::reqwest::Url; use anyhow::Result; use paladin::runtime::Runtime; @@ -34,31 +36,52 @@ pub(crate) async fn client_main( block_interval: BlockInterval, mut params: ProofParams, ) -> Result<()> { - let cached_provider = rpc::provider::CachedProvider::new(build_http_retry_provider( - rpc_params.rpc_url.clone(), - rpc_params.backoff, - rpc_params.max_retries, + use futures::{FutureExt, StreamExt}; + + let cached_provider = Arc::new(rpc::provider::CachedProvider::new( + build_http_retry_provider( + rpc_params.rpc_url.clone(), + rpc_params.backoff, + rpc_params.max_retries, + ), )); - let prover_input = rpc::prover_input( - &cached_provider, - block_interval, - params.checkpoint_block_number.into(), - rpc_params.rpc_type, - ) - .await?; + // Grab interval checkpoint block state trie + let checkpoint_state_trie_root = cached_provider + .get_block( + params.checkpoint_block_number.into(), + BlockTransactionsKind::Hashes, + ) + .await? + .header + .state_root; + + let mut block_prover_inputs = Vec::new(); + let mut block_interval = block_interval.into_bounded_stream()?; + while let Some(block_num) = block_interval.next().await { + let block_id = BlockId::Number(BlockNumberOrTag::Number(block_num)); + // Get future of prover input for particular block. + let block_prover_input = rpc::block_prover_input( + cached_provider.clone(), + block_id, + checkpoint_state_trie_root, + rpc_params.rpc_type, + ) + .boxed(); + block_prover_inputs.push(block_prover_input); + } // If `keep_intermediate_proofs` is not set we only keep the last block // proof from the interval. It contains all the necessary information to // verify the whole sequence. 
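+    // Note: the boxed futures collected above are lazy; no prover input is
+    // fetched until `prover::prove` polls them. There, `FuturesOrdered` drives
+    // the per-block futures while oneshot channels chain each proof to the
+    // previous one.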
- let proved_blocks = prover_input - .prove( - &runtime, - params.previous_proof.take(), - params.save_inputs_on_error, - params.proof_output_dir.clone(), - ) - .await; + let proved_blocks = prover::prove( + block_prover_inputs, + &runtime, + params.previous_proof.take(), + params.save_inputs_on_error, + params.proof_output_dir.clone(), + ) + .await; runtime.close().await?; let proved_blocks = proved_blocks?; diff --git a/zero_bin/leader/src/stdio.rs b/zero_bin/leader/src/stdio.rs index 76bcd089b..403ea2a6a 100644 --- a/zero_bin/leader/src/stdio.rs +++ b/zero_bin/leader/src/stdio.rs @@ -3,7 +3,7 @@ use std::io::{Read, Write}; use anyhow::Result; use paladin::runtime::Runtime; use proof_gen::proof_types::GeneratedBlockProof; -use prover::ProverInput; +use prover::{BlockProverInput, BlockProverInputFuture}; use tracing::info; /// The main function for the stdio mode. @@ -16,13 +16,19 @@ pub(crate) async fn stdio_main( std::io::stdin().read_to_string(&mut buffer)?; let des = &mut serde_json::Deserializer::from_str(&buffer); - let prover_input = ProverInput { - blocks: serde_path_to_error::deserialize(des)?, - }; + let block_prover_inputs = serde_path_to_error::deserialize::<_, Vec<BlockProverInput>>(des)? + .into_iter() + .map(Into::into) + .collect::<Vec<BlockProverInputFuture>>(); - let proved_blocks = prover_input - .prove(&runtime, previous, save_inputs_on_error, None) - .await; + let proved_blocks = prover::prove( + block_prover_inputs, + &runtime, + previous, + save_inputs_on_error, + None, + ) + .await; runtime.close().await?; let proved_blocks = proved_blocks?; diff --git a/zero_bin/prover/src/lib.rs b/zero_bin/prover/src/lib.rs index a43c74104..a30a4d3f3 100644 --- a/zero_bin/prover/src/lib.rs +++ b/zero_bin/prover/src/lib.rs @@ -18,7 +18,20 @@ use trace_decoder::{BlockTrace, OtherBlockData}; use tracing::info; use zero_bin_common::fs::generate_block_proof_file_name; -#[derive(Debug, Deserialize, Serialize)] +pub type BlockProverInputFuture = std::pin::Pin< + Box<dyn Future<Output = std::result::Result<BlockProverInput, anyhow::Error>> + Send>, +>; + +impl From<BlockProverInput> for BlockProverInputFuture { + fn from(item: BlockProverInput) -> Self { + async fn _from(item: BlockProverInput) -> Result<BlockProverInput, anyhow::Error> { + Ok(item) + } + Box::pin(_from(item)) + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct BlockProverInput { pub block_trace: BlockTrace, pub other_data: OtherBlockData, @@ -113,91 +126,77 @@ impl BlockProverInput { } } -#[derive(Debug, Deserialize, Serialize)] -pub struct ProverInput { - pub blocks: Vec<BlockProverInput>, +/// Prove all the blocks in the input. +/// Return the list of block numbers that are proved and if the proof data +/// is not saved to disk, return the generated block proofs as well. 
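+///
+/// A minimal usage sketch (the `blocks` and `runtime` bindings are
+/// illustrative stand-ins, not part of this API):
+///
+/// ```ignore
+/// // Wrap already-fetched inputs into ready futures, then prove in order.
+/// let block_prover_inputs: Vec<BlockProverInputFuture> =
+///     blocks.into_iter().map(Into::into).collect();
+/// let proved = prove(block_prover_inputs, &runtime, None, false, None).await?;
+/// ```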
+pub async fn prove(
+    block_prover_inputs: Vec<BlockProverInputFuture>,
+    runtime: &Runtime,
+    previous_proof: Option<GeneratedBlockProof>,
+    save_inputs_on_error: bool,
+    proof_output_dir: Option<PathBuf>,
+) -> Result<Vec<(BlockNumber, Option<GeneratedBlockProof>)>> {
+    let mut prev: Option<BoxFuture<Result<GeneratedBlockProof>>> =
+        previous_proof.map(|proof| Box::pin(futures::future::ok(proof)) as BoxFuture<_>);
+
+    let mut results = FuturesOrdered::new();
+    for block_prover_input in block_prover_inputs {
+        let (tx, rx) = oneshot::channel::<GeneratedBlockProof>();
+        let proof_output_dir = proof_output_dir.clone();
+        let previous_block_proof = prev.take();
+        let fut = async move {
+            // Get the prover input data from the external source (e.g. Erigon node).
+            let block = block_prover_input.await?;
+            let block_number = block.get_block_number();
+            info!("Proving block {block_number}");
+
+            // Prove the block
+            let block_proof = block
+                .prove(runtime, previous_block_proof, save_inputs_on_error)
+                .then(move |proof| async move {
+                    let proof = proof?;
+                    let block_number = proof.b_height;
+
+                    // Write latest generated proof to disk if proof_output_dir is provided
+                    // or alternatively return proof as function result.
+                    let return_proof: Option<GeneratedBlockProof> =
+                        if let Some(output_dir) = proof_output_dir {
+                            write_proof_to_dir(output_dir, &proof).await?;
+                            None
+                        } else {
+                            Some(proof.clone())
+                        };
+
+                    if tx.send(proof).is_err() {
+                        anyhow::bail!("Failed to send proof");
+                    }
+
+                    Ok((block_number, return_proof))
+                })
+                .await?;
+
+            Ok(block_proof)
+        }
+        .boxed();
+        prev = Some(Box::pin(rx.map_err(anyhow::Error::new)));
+        results.push_back(fut);
+    }
+
+    results.try_collect().await
 }
 
-impl ProverInput {
-    /// Prove all the blocks in the input.
-    /// Return the list of block numbers that are proved and if the proof data
-    /// is not saved to disk, return the generated block proofs as well.
-    pub async fn prove(
-        self,
-        runtime: &Runtime,
-        previous_proof: Option<GeneratedBlockProof>,
-        save_inputs_on_error: bool,
-        proof_output_dir: Option<PathBuf>,
-    ) -> Result<Vec<(BlockNumber, Option<GeneratedBlockProof>)>> {
-        let mut prev: Option<BoxFuture<Result<GeneratedBlockProof>>> =
-            previous_proof.map(|proof| Box::pin(futures::future::ok(proof)) as BoxFuture<_>);
-
-        let results: FuturesOrdered<_> = self
-            .blocks
-            .into_iter()
-            .map(|block| {
-                let block_number = block.get_block_number();
-                info!("Proving block {block_number}");
-
-                let (tx, rx) = oneshot::channel::<GeneratedBlockProof>();
-
-                // Prove the block
-                let proof_output_dir = proof_output_dir.clone();
-                let fut = block
-                    .prove(runtime, prev.take(), save_inputs_on_error)
-                    .then(move |proof| async move {
-                        let proof = proof?;
-                        let block_number = proof.b_height;
-
-                        // Write latest generated proof to disk if proof_output_dir is provided
-                        let return_proof: Option<GeneratedBlockProof> =
-                            if proof_output_dir.is_some() {
-                                ProverInput::write_proof(proof_output_dir, &proof).await?;
-                                None
-                            } else {
-                                Some(proof.clone())
-                            };
-
-                        if tx.send(proof).is_err() {
-                            anyhow::bail!("Failed to send proof");
-                        }
-
-                        Ok((block_number, return_proof))
-                    })
-                    .boxed();
-
-                prev = Some(Box::pin(rx.map_err(anyhow::Error::new)));
-
-                fut
-            })
-            .collect();
+/// Write the proof to the `output_dir` directory.
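+/// Creates any missing parent directories and derives the file name from the
+/// proof's block height. Sketch of the call shape:
+///
+/// ```ignore
+/// write_proof_to_dir(PathBuf::from("proofs"), &block_proof).await?;
+/// ```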
+async fn write_proof_to_dir(output_dir: PathBuf, proof: &GeneratedBlockProof) -> Result<()> { + let proof_serialized = serde_json::to_vec(proof)?; + let block_proof_file_path = + generate_block_proof_file_name(&output_dir.to_str(), proof.b_height); - results.try_collect().await + if let Some(parent) = block_proof_file_path.parent() { + tokio::fs::create_dir_all(parent).await?; } - /// Write the proof to the disk (if `output_dir` is provided) or stdout. - pub(crate) async fn write_proof( - output_dir: Option<PathBuf>, - proof: &GeneratedBlockProof, - ) -> Result<()> { - let proof_serialized = serde_json::to_vec(proof)?; - let block_proof_file_path = - output_dir.map(|path| generate_block_proof_file_name(&path.to_str(), proof.b_height)); - match block_proof_file_path { - Some(p) => { - if let Some(parent) = p.parent() { - tokio::fs::create_dir_all(parent).await?; - } - - let mut f = tokio::fs::File::create(p).await?; - f.write_all(&proof_serialized) - .await - .context("Failed to write proof to disk") - } - None => tokio::io::stdout() - .write_all(&proof_serialized) - .await - .context("Failed to write proof to stdout"), - } - } + let mut f = tokio::fs::File::create(block_proof_file_path).await?; + f.write_all(&proof_serialized) + .await + .context("Failed to write proof to disk") } diff --git a/zero_bin/rpc/Cargo.toml b/zero_bin/rpc/Cargo.toml index 14f447cef..cbd2df11d 100644 --- a/zero_bin/rpc/Cargo.toml +++ b/zero_bin/rpc/Cargo.toml @@ -26,6 +26,7 @@ tower = { workspace = true, features = ["retry"] } trace_decoder = { workspace = true } tracing-subscriber = { workspace = true } url = { workspace = true } +itertools = {workspace = true} # Local dependencies compat = { workspace = true } diff --git a/zero_bin/rpc/src/jerigon.rs b/zero_bin/rpc/src/jerigon.rs index 470b2dffb..891421971 100644 --- a/zero_bin/rpc/src/jerigon.rs +++ b/zero_bin/rpc/src/jerigon.rs @@ -19,7 +19,7 @@ pub struct ZeroTxResult { } pub async fn block_prover_input<ProviderT, TransportT>( - cached_provider: &CachedProvider<ProviderT, TransportT>, + cached_provider: std::sync::Arc<CachedProvider<ProviderT, TransportT>>, target_block_id: BlockId, checkpoint_state_trie_root: B256, ) -> anyhow::Result<BlockProverInput> diff --git a/zero_bin/rpc/src/lib.rs b/zero_bin/rpc/src/lib.rs index 345cf8c96..cc6ddf2f1 100644 --- a/zero_bin/rpc/src/lib.rs +++ b/zero_bin/rpc/src/lib.rs @@ -1,7 +1,9 @@ +use std::sync::Arc; + use alloy::{ primitives::B256, providers::Provider, - rpc::types::eth::{BlockId, BlockNumberOrTag, BlockTransactionsKind, Withdrawal}, + rpc::types::eth::{BlockId, BlockTransactionsKind, Withdrawal}, transports::Transport, }; use anyhow::Context as _; @@ -9,9 +11,8 @@ use clap::ValueEnum; use compat::Compat; use evm_arithmetization::proof::{BlockHashes, BlockMetadata}; use futures::{StreamExt as _, TryStreamExt as _}; -use prover::ProverInput; +use prover::BlockProverInput; use trace_decoder::{BlockLevelData, OtherBlockData}; -use zero_bin_common::block_interval::BlockInterval; pub mod jerigon; pub mod native; @@ -23,56 +24,36 @@ use crate::provider::CachedProvider; const PREVIOUS_HASHES_COUNT: usize = 256; /// The RPC type. 
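+/// Roughly: `Jerigon` pulls ready-made prover input from a Jerigon node's
+/// custom trace endpoints, while `Native` rebuilds it from standard JSON-RPC
+/// calls such as `eth_getProof` (see the `jerigon` and `native` modules).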
-#[derive(ValueEnum, Clone, Debug)] +#[derive(ValueEnum, Clone, Debug, Copy)] pub enum RpcType { Jerigon, Native, } -/// Obtain the prover input for a given block interval -pub async fn prover_input<ProviderT, TransportT>( - cached_provider: &CachedProvider<ProviderT, TransportT>, - block_interval: BlockInterval, - checkpoint_block_id: BlockId, +/// Obtain the prover input for one block +pub async fn block_prover_input<ProviderT, TransportT>( + cached_provider: Arc<CachedProvider<ProviderT, TransportT>>, + block_id: BlockId, + checkpoint_state_trie_root: B256, rpc_type: RpcType, -) -> anyhow::Result<ProverInput> +) -> Result<BlockProverInput, anyhow::Error> where ProviderT: Provider<TransportT>, TransportT: Transport + Clone, { - // Grab interval checkpoint block state trie - let checkpoint_state_trie_root = cached_provider - .get_block(checkpoint_block_id, BlockTransactionsKind::Hashes) - .await? - .header - .state_root; - - let mut block_proofs = Vec::new(); - let mut block_interval = block_interval.into_bounded_stream()?; - - while let Some(block_num) = block_interval.next().await { - let block_id = BlockId::Number(BlockNumberOrTag::Number(block_num)); - let block_prover_input = match rpc_type { - RpcType::Jerigon => { - jerigon::block_prover_input(cached_provider, block_id, checkpoint_state_trie_root) - .await? - } - RpcType::Native => { - native::block_prover_input(cached_provider, block_id, checkpoint_state_trie_root) - .await? - } - }; - - block_proofs.push(block_prover_input); + match rpc_type { + RpcType::Jerigon => { + jerigon::block_prover_input(cached_provider, block_id, checkpoint_state_trie_root).await + } + RpcType::Native => { + native::block_prover_input(cached_provider, block_id, checkpoint_state_trie_root).await + } } - Ok(ProverInput { - blocks: block_proofs, - }) } /// Fetches other block data async fn fetch_other_block_data<ProviderT, TransportT>( - cached_provider: &CachedProvider<ProviderT, TransportT>, + cached_provider: Arc<CachedProvider<ProviderT, TransportT>>, target_block_id: BlockId, checkpoint_state_trie_root: B256, ) -> anyhow::Result<OtherBlockData> @@ -80,6 +61,7 @@ where ProviderT: Provider<TransportT>, TransportT: Transport + Clone, { + use itertools::Itertools; let target_block = cached_provider .get_block(target_block_id, BlockTransactionsKind::Hashes) .await?; @@ -102,28 +84,33 @@ where }) .take(PREVIOUS_HASHES_COUNT + 1) .filter(|i| *i >= 0) + .chunks(2) + .into_iter() + .map(|mut chunk| { + // We convert to tuple of (current block, optional previous block) + let first = chunk + .next() + .expect("must be valid according to itertools::Iterator::chunks definition"); + let second = chunk.next(); + (first, second) + }) .collect::<Vec<_>>(); + let concurrency = previous_block_numbers.len(); let collected_hashes = futures::stream::iter( previous_block_numbers - .chunks(2) // we get hash for previous and current block with one request - .map(|block_numbers| { + .into_iter() // we get hash for previous and current block with one request + .map(|(current_block_number, previous_block_number)| { let cached_provider = &cached_provider; - let block_num = &block_numbers[0]; - let previos_block_num = if block_numbers.len() > 1 { - Some(block_numbers[1]) - } else { - // For genesis block - None - }; + let block_num = current_block_number; async move { let block = cached_provider - .get_block((*block_num as u64).into(), BlockTransactionsKind::Hashes) + .get_block((block_num as u64).into(), BlockTransactionsKind::Hashes) .await .context("couldn't get block")?; 
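+
+                    // E.g. descending numbers [9, 8, 7, 6] pair up as
+                    // (9, Some(8)), (7, Some(6)); the parent hash in each
+                    // response covers the second entry of the pair.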
anyhow::Ok([ - (block.header.hash, Some(*block_num)), - (Some(block.header.parent_hash), previos_block_num), + (block.header.hash, Some(block_num)), + (Some(block.header.parent_hash), previous_block_number), ]) } }), diff --git a/zero_bin/rpc/src/main.rs b/zero_bin/rpc/src/main.rs index 444e89e3b..3c72ac902 100644 --- a/zero_bin/rpc/src/main.rs +++ b/zero_bin/rpc/src/main.rs @@ -1,7 +1,10 @@ -use std::{env, io}; +use std::env; +use std::sync::Arc; use alloy::rpc::types::eth::BlockId; +use alloy::rpc::types::{BlockNumberOrTag, BlockTransactionsKind}; use clap::{Parser, ValueHint}; +use futures::StreamExt; use rpc::provider::CachedProvider; use rpc::{retry::build_http_retry_provider, RpcType}; use tracing_subscriber::{prelude::*, EnvFilter}; @@ -55,22 +58,36 @@ impl Cli { checkpoint_block_number.unwrap_or((start_block - 1).into()); let block_interval = BlockInterval::Range(start_block..end_block + 1); - let cached_provider = CachedProvider::new(build_http_retry_provider( + let cached_provider = Arc::new(CachedProvider::new(build_http_retry_provider( rpc_url.clone(), backoff, max_retries, - )); + ))); - // Retrieve prover input from the Erigon node - let prover_input = rpc::prover_input( - &cached_provider, - block_interval, - checkpoint_block_number, - rpc_type, - ) - .await?; + // Grab interval checkpoint block state trie + let checkpoint_state_trie_root = cached_provider + .get_block(checkpoint_block_number, BlockTransactionsKind::Hashes) + .await? + .header + .state_root; - serde_json::to_writer_pretty(io::stdout(), &prover_input.blocks)?; + let mut block_prover_inputs = Vec::new(); + let mut block_interval = block_interval.clone().into_bounded_stream()?; + while let Some(block_num) = block_interval.next().await { + let block_id = BlockId::Number(BlockNumberOrTag::Number(block_num)); + // Get the prover input for particular block. + let result = rpc::block_prover_input( + cached_provider.clone(), + block_id, + checkpoint_state_trie_root, + rpc_type, + ) + .await?; + + block_prover_inputs.push(result); + } + + serde_json::to_writer_pretty(std::io::stdout(), &block_prover_inputs)?; } } Ok(()) diff --git a/zero_bin/rpc/src/native/mod.rs b/zero_bin/rpc/src/native/mod.rs index 892a799d6..1f61d7b26 100644 --- a/zero_bin/rpc/src/native/mod.rs +++ b/zero_bin/rpc/src/native/mod.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::sync::Arc; use alloy::{ primitives::B256, @@ -19,7 +20,7 @@ type CodeDb = HashMap<__compat_primitive_types::H256, Vec<u8>>; /// Fetches the prover input for the given BlockId. pub async fn block_prover_input<ProviderT, TransportT>( - provider: &CachedProvider<ProviderT, TransportT>, + provider: Arc<CachedProvider<ProviderT, TransportT>>, block_number: BlockId, checkpoint_state_trie_root: B256, ) -> anyhow::Result<BlockProverInput> @@ -28,8 +29,8 @@ where TransportT: Transport + Clone, { let (block_trace, other_data) = try_join!( - process_block_trace(provider, block_number), - crate::fetch_other_block_data(provider, block_number, checkpoint_state_trie_root,) + process_block_trace(provider.clone(), block_number), + crate::fetch_other_block_data(provider.clone(), block_number, checkpoint_state_trie_root,) )?; Ok(BlockProverInput { @@ -40,7 +41,7 @@ where /// Processes the block with the given block number and returns the block trace. 
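+/// The cached provider is passed behind an `Arc` so `block_prover_input`
+/// above can clone it into the two concurrent fetches inside `try_join!`
+/// without borrowing issues.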
async fn process_block_trace<ProviderT, TransportT>( - cached_provider: &CachedProvider<ProviderT, TransportT>, + cached_provider: Arc<CachedProvider<ProviderT, TransportT>>, block_number: BlockId, ) -> anyhow::Result<BlockTrace> where diff --git a/zero_bin/rpc/src/native/state.rs b/zero_bin/rpc/src/native/state.rs index 5fd9b539c..b5017b394 100644 --- a/zero_bin/rpc/src/native/state.rs +++ b/zero_bin/rpc/src/native/state.rs @@ -1,4 +1,5 @@ use std::collections::{HashMap, HashSet}; +use std::sync::Arc; use alloy::{ primitives::{keccak256, Address, StorageKey, B256, U256}, @@ -20,7 +21,7 @@ use crate::Compat; /// Processes the state witness for the given block. pub async fn process_state_witness<ProviderT, TransportT>( - cached_provider: &CachedProvider<ProviderT, TransportT>, + cached_provider: Arc<CachedProvider<ProviderT, TransportT>>, block: Block, txn_infos: &[TxnInfo], ) -> anyhow::Result<BlockTraceTriePreImages> @@ -115,7 +116,7 @@ fn insert_beacon_roots_update( async fn generate_state_witness<ProviderT, TransportT>( prev_state_root: B256, accounts_state: HashMap<Address, HashSet<StorageKey>>, - cached_provider: &CachedProvider<ProviderT, TransportT>, + cached_provider: Arc<CachedProvider<ProviderT, TransportT>>, block_number: u64, ) -> anyhow::Result<( PartialTrieBuilder<HashedPartialTrie>, @@ -164,7 +165,7 @@ where /// Fetches the proof data for the given accounts and associated storage keys. async fn fetch_proof_data<ProviderT, TransportT>( accounts_state: HashMap<Address, HashSet<StorageKey>>, - provider: &CachedProvider<ProviderT, TransportT>, + provider: Arc<CachedProvider<ProviderT, TransportT>>, block_number: u64, ) -> anyhow::Result<( Vec<(Address, EIP1186AccountProofResponse)>, @@ -177,20 +178,23 @@ where let account_proofs_fut = accounts_state .clone() .into_iter() - .map(|(address, keys)| async move { - let proof = provider - .as_provider() - .get_proof(address, keys.into_iter().collect()) - .block_id((block_number - 1).into()) - .await - .context("Failed to get proof for account")?; - anyhow::Result::Ok((address, proof)) + .map(|(address, keys)| { + let provider = provider.clone(); + async move { + let proof = provider + .as_provider() + .get_proof(address, keys.into_iter().collect()) + .block_id((block_number - 1).into()) + .await + .context("Failed to get proof for account")?; + anyhow::Result::Ok((address, proof)) + } }) .collect::<Vec<_>>(); - let next_account_proofs_fut = accounts_state - .into_iter() - .map(|(address, keys)| async move { + let next_account_proofs_fut = accounts_state.into_iter().map(|(address, keys)| { + let provider = provider.clone(); + async move { let proof = provider .as_provider() .get_proof(address, keys.into_iter().collect()) @@ -198,7 +202,8 @@ where .await .context("Failed to get proof for account")?; anyhow::Result::Ok((address, proof)) - }); + } + }); try_join( try_join_all(account_proofs_fut), From b2006bf2c9ff7b5fe10cbb170f3126093041e1ae Mon Sep 17 00:00:00 2001 From: 0xaatif <169152398+0xaatif@users.noreply.github.com> Date: Thu, 22 Aug 2024 12:06:46 +0100 Subject: [PATCH 2/3] refactor: Hash2Code (#522) * mark: 0xaatif/refactor-hash2code * refactor: Hash2Code * refactor: insert on new * refactor: StateWrite != StateWrite::default * nomerge: assert code hash * fix: clippy * fix: contract_code_accessed always contains empty vec * refactor: Hash2Code does not always contain empty vec * Revert "nomerge: assert code hash" This reverts commit 0b8f4592489754d3cf324e52d1af884e3e7fe11b. 
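
The `StateWrite != StateWrite::default()` bullet above replaces the old
hand-rolled "did anything change?" flags; a sketch of the new check (field
names as in processed_block_trace.rs, `code_hash` precomputed):

    let state_write = StateWrite {
        balance,
        nonce,
        storage_trie_change: !storage_written.is_empty(),
        code_hash,
    };
    // A write occurred iff any field differs from the derived default.
    if state_write != StateWrite::default() {
        nodes_used_by_txn.state_writes.push((hashed_addr, state_write));
    }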
--- trace_decoder/src/decoding.rs | 48 ++--- trace_decoder/src/lib.rs | 37 ++-- trace_decoder/src/processed_block_trace.rs | 223 ++++++++++----------- 3 files changed, 141 insertions(+), 167 deletions(-) diff --git a/trace_decoder/src/decoding.rs b/trace_decoder/src/decoding.rs index 1a6d2b725..aa755fadf 100644 --- a/trace_decoder/src/decoding.rs +++ b/trace_decoder/src/decoding.rs @@ -20,7 +20,7 @@ use mpt_trie::{ use crate::{ hash, processed_block_trace::{ - NodesUsedByTxn, ProcessedBlockTrace, ProcessedTxnInfo, StateTrieWrites, TxnMetaState, + NodesUsedByTxn, ProcessedBlockTrace, ProcessedTxnInfo, StateWrite, TxnMetaState, }, typed_mpt::{ReceiptTrie, StateTrie, StorageTrie, TransactionTrie, TrieKey}, OtherBlockData, PartialTriePreImages, @@ -201,15 +201,12 @@ fn update_txn_and_receipt_tries( meta: &TxnMetaState, txn_idx: usize, ) -> anyhow::Result<()> { - if meta.is_dummy() { - // This is a dummy payload, that does not mutate these tries. - return Ok(()); - } - - trie_state.txn.insert(txn_idx, meta.txn_bytes())?; - trie_state - .receipt - .insert(txn_idx, meta.receipt_node_bytes.clone())?; + if let Some(bytes) = &meta.txn_bytes { + trie_state.txn.insert(txn_idx, bytes.clone())?; + trie_state + .receipt + .insert(txn_idx, meta.receipt_node_bytes.clone())?; + } // else it's just a dummy Ok(()) } @@ -219,11 +216,11 @@ fn update_txn_and_receipt_tries( fn init_any_needed_empty_storage_tries<'a>( storage_tries: &mut HashMap<H256, StorageTrie>, accounts_with_storage: impl Iterator<Item = &'a H256>, - state_accounts_with_no_accesses_but_storage_tries: &'a HashMap<H256, H256>, + accts_with_unaccessed_storage: &HashMap<H256, H256>, ) { for h_addr in accounts_with_storage { if !storage_tries.contains_key(h_addr) { - let trie = state_accounts_with_no_accesses_but_storage_tries + let trie = accts_with_unaccessed_storage .get(h_addr) .map(|s_root| { let mut it = StorageTrie::default(); @@ -519,9 +516,7 @@ fn process_txn_info( .storage_accesses .iter() .map(|(k, _)| k), - &txn_info - .nodes_used_by_txn - .state_accounts_with_no_accesses_but_storage_tries, + &txn_info.nodes_used_by_txn.accts_with_unaccessed_storage, ); // For each non-dummy txn, we increment `txn_number_after` by 1, and // update `gas_used_after` accordingly. @@ -577,7 +572,11 @@ fn process_txn_info( receipts_root: curr_block_tries.receipt.root(), }, checkpoint_state_trie_root: extra_data.checkpoint_state_trie_root, - contract_code: txn_info.contract_code_accessed, + contract_code: txn_info + .contract_code_accessed + .into_iter() + .map(|code| (hash(&code), code)) + .collect(), block_metadata: other_data.b_data.b_meta.clone(), block_hashes: other_data.b_data.b_hashes.clone(), global_exit_roots: vec![], @@ -591,7 +590,7 @@ fn process_txn_info( Ok(gen_inputs) } -impl StateTrieWrites { +impl StateWrite { fn apply_writes_to_state_node( &self, state_node: &mut AccountRlp, @@ -678,21 +677,6 @@ fn create_trie_subset_wrapped( .context(format!("missing keys when creating {}", trie_type)) } -impl TxnMetaState { - /// Outputs a boolean indicating whether this `TxnMetaState` - /// represents a dummy payload or an actual transaction. - const fn is_dummy(&self) -> bool { - self.txn_bytes.is_none() - } - - fn txn_bytes(&self) -> Vec<u8> { - match self.txn_bytes.as_ref() { - Some(v) => v.clone(), - None => Vec::default(), - } - } -} - fn eth_to_gwei(eth: U256) -> U256 { // 1 ether = 10^9 gwei. 
eth * U256::from(10).pow(9.into()) diff --git a/trace_decoder/src/lib.rs b/trace_decoder/src/lib.rs index a59243a24..a71cd38ee 100644 --- a/trace_decoder/src/lib.rs +++ b/trace_decoder/src/lib.rs @@ -258,15 +258,6 @@ pub enum ContractCodeUsage { Write(#[serde(with = "crate::hex")] Vec<u8>), } -impl ContractCodeUsage { - fn get_code_hash(&self) -> H256 { - match self { - ContractCodeUsage::Read(hash) => *hash, - ContractCodeUsage::Write(bytes) => hash(bytes), - } - } -} - /// Other data that is needed for proof gen. #[derive(Clone, Debug, Deserialize, Serialize)] pub struct OtherBlockData { @@ -397,15 +388,17 @@ pub fn entrypoint( .map(|(addr, data)| (addr.into_hash_left_padded(), data)) .collect::<Vec<_>>(); - let code_db = { - let mut code_db = code_db.unwrap_or_default(); - if let Some(code_mappings) = pre_images.extra_code_hash_mappings { - code_db.extend(code_mappings); - } - code_db - }; - - let mut code_hash_resolver = Hash2Code::new(code_db); + // Note we discard any user-provided hashes. + let mut hash2code = code_db + .unwrap_or_default() + .into_values() + .chain( + pre_images + .extra_code_hash_mappings + .unwrap_or_default() + .into_values(), + ) + .collect::<Hash2Code>(); let last_tx_idx = txn_info.len().saturating_sub(1); @@ -430,7 +423,7 @@ pub fn entrypoint( &pre_images.tries, &all_accounts_in_pre_images, &extra_state_accesses, - &mut code_hash_resolver, + &mut hash2code, ) }) .collect::<Result<Vec<_>, _>>()?; @@ -457,8 +450,6 @@ struct PartialTriePreImages { /// Like `#[serde(with = "hex")`, but tolerates and emits leading `0x` prefixes mod hex { - use std::{borrow::Cow, fmt}; - use serde::{de::Error as _, Deserialize as _, Deserializer, Serializer}; pub fn serialize<S: Serializer, T>(data: T, serializer: S) -> Result<S::Ok, S::Error> @@ -472,9 +463,9 @@ mod hex { pub fn deserialize<'de, D: Deserializer<'de>, T>(deserializer: D) -> Result<T, D::Error> where T: hex::FromHex, - T::Error: fmt::Display, + T::Error: std::fmt::Display, { - let s = Cow::<str>::deserialize(deserializer)?; + let s = String::deserialize(deserializer)?; match s.strip_prefix("0x") { Some(rest) => T::from_hex(rest), None => T::from_hex(&*s), diff --git a/trace_decoder/src/processed_block_trace.rs b/trace_decoder/src/processed_block_trace.rs index dac816530..5dcd9f109 100644 --- a/trace_decoder/src/processed_block_trace.rs +++ b/trace_decoder/src/processed_block_trace.rs @@ -1,16 +1,13 @@ -use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; -use std::fmt::Debug; -use std::iter::once; -use anyhow::bail; +use anyhow::{bail, Context as _}; use ethereum_types::{Address, H256, U256}; use evm_arithmetization::generation::mpt::{AccountRlp, LegacyReceiptRlp}; -use zk_evm_common::{EMPTY_CODE_HASH, EMPTY_TRIE_HASH}; +use zk_evm_common::EMPTY_TRIE_HASH; -use crate::hash; use crate::typed_mpt::TrieKey; use crate::PartialTriePreImages; +use crate::{hash, TxnTrace}; use crate::{ContractCodeUsage, TxnInfo}; const FIRST_PRECOMPILE_ADDRESS: U256 = U256([1, 0, 0, 0]); @@ -32,7 +29,7 @@ pub(crate) struct ProcessedBlockTracePreImages { #[derive(Debug, Default)] pub(crate) struct ProcessedTxnInfo { pub nodes_used_by_txn: NodesUsedByTxn, - pub contract_code_accessed: HashMap<H256, Vec<u8>>, + pub contract_code_accessed: HashSet<Vec<u8>>, pub meta: TxnMetaState, } @@ -41,22 +38,34 @@ pub(crate) struct ProcessedTxnInfo { /// If there are any txns that create contracts, then they will also /// get added here as we process the deltas. 
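+///
+/// The key is always the keccak hash of the value; e.g., with some
+/// `code: Vec<u8>`:
+///
+/// ```ignore
+/// let mut hash2code = Hash2Code::new();
+/// hash2code.insert(code.clone());
+/// assert_eq!(hash2code.get(hash(&code))?, code);
+/// ```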
pub(crate) struct Hash2Code { + /// Key must always be [`hash`] of value. inner: HashMap<H256, Vec<u8>>, } impl Hash2Code { - pub fn new(inner: HashMap<H256, Vec<u8>>) -> Self { - Self { inner } + pub fn new() -> Self { + Self { + inner: HashMap::new(), + } } - fn resolve(&mut self, c_hash: &H256) -> anyhow::Result<Vec<u8>> { - match self.inner.get(c_hash) { + fn get(&mut self, hash: H256) -> anyhow::Result<Vec<u8>> { + match self.inner.get(&hash) { Some(code) => Ok(code.clone()), - None => bail!("no code for hash {}", c_hash), + None => bail!("no code for hash {}", hash), } } + fn insert(&mut self, code: Vec<u8>) { + self.inner.insert(hash(&code), code); + } +} - fn insert_code(&mut self, c_hash: H256, code: Vec<u8>) { - self.inner.insert(c_hash, code); +impl FromIterator<Vec<u8>> for Hash2Code { + fn from_iter<II: IntoIterator<Item = Vec<u8>>>(iter: II) -> Self { + let mut this = Self::new(); + for code in iter { + this.insert(code) + } + this } } @@ -69,57 +78,59 @@ impl TxnInfo { hash2code: &mut Hash2Code, ) -> anyhow::Result<ProcessedTxnInfo> { let mut nodes_used_by_txn = NodesUsedByTxn::default(); - let mut contract_code_accessed = create_empty_code_access_map(); - - for (addr, trace) in self.traces { + let mut contract_code_accessed = HashSet::from([vec![]]); // we always "access" empty code + + for ( + addr, + TxnTrace { + balance, + nonce, + storage_read, + storage_written, + code_usage, + self_destructed, + }, + ) in self.traces + { let hashed_addr = hash(addr.as_bytes()); - let storage_writes = trace.storage_written.unwrap_or_default(); - - let storage_read_keys = trace - .storage_read - .into_iter() - .flat_map(|reads| reads.into_iter()); - - let storage_write_keys = storage_writes.keys(); - let storage_access_keys = storage_read_keys.chain(storage_write_keys.copied()); - + // record storage changes + let storage_written = storage_written.unwrap_or_default(); nodes_used_by_txn.storage_accesses.push(( hashed_addr, - storage_access_keys + storage_read + .into_iter() + .flatten() + .chain(storage_written.keys().copied()) .map(|H256(bytes)| TrieKey::from_hash(hash(bytes))) .collect(), )); + nodes_used_by_txn.storage_writes.push(( + hashed_addr, + storage_written + .iter() + .map(|(k, v)| (TrieKey::from_hash(*k), rlp::encode(v).to_vec())) + .collect(), + )); - let storage_trie_change = !storage_writes.is_empty(); - let code_change = trace.code_usage.is_some(); - let state_write_occurred = trace.balance.is_some() - || trace.nonce.is_some() - || storage_trie_change - || code_change; - - if state_write_occurred { - let state_trie_writes = StateTrieWrites { - balance: trace.balance, - nonce: trace.nonce, - storage_trie_change, - code_hash: trace.code_usage.as_ref().map(|usage| usage.get_code_hash()), - }; - + // record state changes + let state_write = StateWrite { + balance, + nonce, + storage_trie_change: !storage_written.is_empty(), + code_hash: code_usage.as_ref().map(|it| match it { + ContractCodeUsage::Read(hash) => *hash, + ContractCodeUsage::Write(bytes) => hash(bytes), + }), + }; + + if state_write != StateWrite::default() { + // a write occurred nodes_used_by_txn .state_writes - .push((hashed_addr, state_trie_writes)) + .push((hashed_addr, state_write)) } - let storage_writes_vec = storage_writes - .into_iter() - .map(|(k, v)| (TrieKey::from_hash(k), rlp::encode(&v).to_vec())) - .collect(); - - nodes_used_by_txn - .storage_writes - .push((hashed_addr, storage_writes_vec)); - let is_precompile = (FIRST_PRECOMPILE_ADDRESS..LAST_PRECOMPILE_ADDRESS) 
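+                // Calls that only touch precompiled contracts don't create
+                // state-trie accesses, so they are filtered out of
+                // `state_accesses`.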
.contains(&U256::from_big_endian(&addr.0)); @@ -136,23 +147,18 @@ impl TxnInfo { nodes_used_by_txn.state_accesses.push(hashed_addr); } - if let Some(c_usage) = trace.code_usage { - match c_usage { - ContractCodeUsage::Read(c_hash) => { - if let Entry::Vacant(vacant) = contract_code_accessed.entry(c_hash) { - vacant.insert(hash2code.resolve(&c_hash)?); - } - } - ContractCodeUsage::Write(c_bytes) => { - let c_hash = hash(&c_bytes); - - contract_code_accessed.insert(c_hash, c_bytes.clone()); - hash2code.insert_code(c_hash, c_bytes); - } + match code_usage { + Some(ContractCodeUsage::Read(hash)) => { + contract_code_accessed.insert(hash2code.get(hash)?); + } + Some(ContractCodeUsage::Write(code)) => { + contract_code_accessed.insert(code.clone()); + hash2code.insert(code); } + None => {} } - if trace.self_destructed.unwrap_or_default() { + if self_destructed.unwrap_or_default() { nodes_used_by_txn.self_destructed_accounts.push(hashed_addr); } } @@ -161,78 +167,64 @@ impl TxnInfo { nodes_used_by_txn.state_accesses.push(hashed_addr); } - let accounts_with_storage_accesses: HashSet<_> = HashSet::from_iter( - nodes_used_by_txn - .storage_accesses - .iter() - .filter(|(_, slots)| !slots.is_empty()) - .map(|(addr, _)| *addr), - ); - - let all_accounts_with_non_empty_storage = all_accounts_in_pre_image + let accounts_with_storage_accesses = nodes_used_by_txn + .storage_accesses .iter() - .filter(|(_, data)| data.storage_root != EMPTY_TRIE_HASH); - - let accounts_with_storage_but_no_storage_accesses = all_accounts_with_non_empty_storage - .filter(|&(addr, _data)| !accounts_with_storage_accesses.contains(addr)) - .map(|(addr, data)| (*addr, data.storage_root)); - - nodes_used_by_txn - .state_accounts_with_no_accesses_but_storage_tries - .extend(accounts_with_storage_but_no_storage_accesses); - - let txn_bytes = match self.meta.byte_code.is_empty() { - false => Some(self.meta.byte_code), - true => None, - }; + .filter(|(_, slots)| !slots.is_empty()) + .map(|(addr, _)| *addr) + .collect::<HashSet<_>>(); - let receipt_node_bytes = - process_rlped_receipt_node_bytes(self.meta.new_receipt_trie_node_byte); - - let new_meta_state = TxnMetaState { - txn_bytes, - receipt_node_bytes, - gas_used: self.meta.gas_used, - }; + for (addr, state) in all_accounts_in_pre_image { + if state.storage_root != EMPTY_TRIE_HASH + && !accounts_with_storage_accesses.contains(addr) + { + nodes_used_by_txn + .accts_with_unaccessed_storage + .insert(*addr, state.storage_root); + } + } Ok(ProcessedTxnInfo { nodes_used_by_txn, contract_code_accessed, - meta: new_meta_state, + meta: TxnMetaState { + txn_bytes: match self.meta.byte_code.is_empty() { + false => Some(self.meta.byte_code), + true => None, + }, + receipt_node_bytes: check_receipt_bytes(self.meta.new_receipt_trie_node_byte)?, + gas_used: self.meta.gas_used, + }, }) } } -fn process_rlped_receipt_node_bytes(raw_bytes: Vec<u8>) -> Vec<u8> { - match rlp::decode::<LegacyReceiptRlp>(&raw_bytes) { - Ok(_) => raw_bytes, +fn check_receipt_bytes(bytes: Vec<u8>) -> anyhow::Result<Vec<u8>> { + match rlp::decode::<LegacyReceiptRlp>(&bytes) { + Ok(_) => Ok(bytes), Err(_) => { - // Must be non-legacy. - rlp::decode::<Vec<u8>>(&raw_bytes).unwrap() + rlp::decode(&bytes).context("couldn't decode receipt as a legacy receipt or raw bytes") } } } -fn create_empty_code_access_map() -> HashMap<H256, Vec<u8>> { - HashMap::from_iter(once((EMPTY_CODE_HASH, Vec::new()))) -} - /// Note that "*_accesses" includes writes. 
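+/// All entries are keyed by the *hashed* account address (`keccak(address)`),
+/// matching how the state trie itself is keyed.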
#[derive(Debug, Default)] pub(crate) struct NodesUsedByTxn { pub state_accesses: Vec<H256>, - pub state_writes: Vec<(H256, StateTrieWrites)>, + pub state_writes: Vec<(H256, StateWrite)>, // Note: All entries in `storage_writes` also appear in `storage_accesses`. pub storage_accesses: Vec<(H256, Vec<TrieKey>)>, #[allow(clippy::type_complexity)] pub storage_writes: Vec<(H256, Vec<(TrieKey, Vec<u8>)>)>, - pub state_accounts_with_no_accesses_but_storage_tries: HashMap<H256, H256>, + /// Hashed address -> storage root. + pub accts_with_unaccessed_storage: HashMap<H256, H256>, pub self_destructed_accounts: Vec<H256>, } -#[derive(Debug)] -pub(crate) struct StateTrieWrites { +#[derive(Debug, Default, PartialEq)] +pub(crate) struct StateWrite { pub balance: Option<U256>, pub nonce: Option<U256>, pub storage_trie_change: bool, @@ -241,7 +233,14 @@ pub(crate) struct StateTrieWrites { #[derive(Debug, Default)] pub(crate) struct TxnMetaState { + /// [`None`] if this is a dummy transaction inserted for padding. pub txn_bytes: Option<Vec<u8>>, pub receipt_node_bytes: Vec<u8>, pub gas_used: u64, } + +impl TxnMetaState { + pub fn is_dummy(&self) -> bool { + self.txn_bytes.is_none() + } +} From c7a16419d0faee01cd38ec3e8fc9b3310e94f13c Mon Sep 17 00:00:00 2001 From: BGluth <gluthb@gmail.com> Date: Thu, 22 Aug 2024 05:07:30 -0600 Subject: [PATCH 3/3] Made sub-trie errors better (#520) * Made sub-trie errors better - Now shows the path in the trie where we encountered the `hash` node. * Requested PR changes for #520 --- mpt_trie/src/debug_tools/query.rs | 6 +++--- mpt_trie/src/trie_subsets.rs | 34 +++++++++++++++++++------------ 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/mpt_trie/src/debug_tools/query.rs b/mpt_trie/src/debug_tools/query.rs index 0fb6ade96..dcfff397d 100644 --- a/mpt_trie/src/debug_tools/query.rs +++ b/mpt_trie/src/debug_tools/query.rs @@ -66,19 +66,19 @@ pub struct DebugQueryParamsBuilder { impl DebugQueryParamsBuilder { /// Defaults to `true`. - pub const fn print_key_pieces(mut self, enabled: bool) -> Self { + pub const fn include_key_pieces(mut self, enabled: bool) -> Self { self.params.include_key_piece_per_node = enabled; self } /// Defaults to `true`. - pub const fn print_node_type(mut self, enabled: bool) -> Self { + pub const fn include_node_type(mut self, enabled: bool) -> Self { self.params.include_node_type = enabled; self } /// Defaults to `false`. - pub const fn print_node_specific_values(mut self, enabled: bool) -> Self { + pub const fn include_node_specific_values(mut self, enabled: bool) -> Self { self.params.include_node_specific_values = enabled; self } diff --git a/mpt_trie/src/trie_subsets.rs b/mpt_trie/src/trie_subsets.rs index 13e9d0d9f..eadb6b3e2 100644 --- a/mpt_trie/src/trie_subsets.rs +++ b/mpt_trie/src/trie_subsets.rs @@ -12,6 +12,7 @@ use log::trace; use thiserror::Error; use crate::{ + debug_tools::query::{get_path_from_query, DebugQueryOutput, DebugQueryParamsBuilder}, nibbles::Nibbles, partial_trie::{Node, PartialTrie, WrappedNode}, trie_hashing::EncodedNode, @@ -21,13 +22,10 @@ use crate::{ /// The output type of trie_subset operations. pub type SubsetTrieResult<T> = Result<T, SubsetTrieError>; -/// Errors that may occur when creating a subset [`PartialTrie`]. +/// We encountered a `hash` node when marking nodes during sub-trie creation. #[derive(Clone, Debug, Error, Hash)] -pub enum SubsetTrieError { - #[error("Tried to mark nodes in a tracked trie for a key that does not exist! 
(Key: {0}, trie: {1})")] - /// The key does not exist in the trie. - UnexpectedKey(Nibbles, String), -} +#[error("Encountered a hash node when marking nodes to not hash when traversing a key to not hash!\nPath: {0}")] +pub struct SubsetTrieError(DebugQueryOutput); #[derive(Debug)] enum TrackedNodeIntern<N: PartialTrie> { @@ -256,8 +254,17 @@ where N: PartialTrie, K: Into<Nibbles>, { - for k in keys_involved { - mark_nodes_that_are_needed(tracked_trie, &mut k.into())?; + for mut k in keys_involved.map(|k| k.into()) { + mark_nodes_that_are_needed(tracked_trie, &mut k).map_err(|_| { + // We need to unwind back to this callsite in order to produce the actual error. + let query = DebugQueryParamsBuilder::default() + .include_node_specific_values(true) + .build(k); + + let res = get_path_from_query(&tracked_trie.info.underlying_node, query); + + SubsetTrieError(res) + })?; } Ok(create_partial_trie_subset_from_tracked_trie(tracked_trie)) @@ -270,10 +277,14 @@ where /// - For the key `0x1`, the marked nodes would be [B(0x), B(0x1)]. /// - For the key `0x12`, the marked nodes still would be [B(0x), B(0x1)]. /// - For the key `0x123`, the marked nodes would be [B(0x), B(0x1), L(0x123)]. +/// +/// Also note that we can't construct the error until we back out of this +/// recursive function. We need to know the full key that hit the hash +/// node, and that's only available at the initial call site. fn mark_nodes_that_are_needed<N: PartialTrie>( trie: &mut TrackedNode<N>, curr_nibbles: &mut Nibbles, -) -> SubsetTrieResult<()> { +) -> Result<(), ()> { trace!( "Sub-trie marking at {:x}, (type: {})", curr_nibbles, @@ -286,10 +297,7 @@ fn mark_nodes_that_are_needed<N: PartialTrie>( } TrackedNodeIntern::Hash => match curr_nibbles.is_empty() { false => { - return Err(SubsetTrieError::UnexpectedKey( - *curr_nibbles, - format!("{:?}", trie), - )); + return Err(()); } true => { trie.info.touched = true;