From 6bcf06b2e8af11b14d25238ce88111c216ffb974 Mon Sep 17 00:00:00 2001
From: Marko Atanasievski <atanmarko@users.noreply.github.com>
Date: Thu, 22 Aug 2024 12:14:04 +0200
Subject: [PATCH 1/3] feat: retrieve prover input per block (#499)

* feat: retrieve prover input per block

* fix: cleanup

* fix: into implementation

* fix: nitpick

* fix: review

* fix: review and cleanup
---
 Cargo.lock                       |   1 +
 zero_bin/leader/src/client.rs    |  61 ++++++++----
 zero_bin/leader/src/stdio.rs     |  20 ++--
 zero_bin/prover/src/lib.rs       | 165 +++++++++++++++----------------
 zero_bin/rpc/Cargo.toml          |   1 +
 zero_bin/rpc/src/jerigon.rs      |   2 +-
 zero_bin/rpc/src/lib.rs          |  87 +++++++---------
 zero_bin/rpc/src/main.rs         |  41 +++++---
 zero_bin/rpc/src/native/mod.rs   |   9 +-
 zero_bin/rpc/src/native/state.rs |  35 ++++---
 10 files changed, 231 insertions(+), 191 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 33b8be650..a75cf45ad 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4167,6 +4167,7 @@ dependencies = [
  "evm_arithmetization",
  "futures",
  "hex",
+ "itertools 0.13.0",
  "lru",
  "mpt_trie",
  "primitive-types 0.12.2",
diff --git a/zero_bin/leader/src/client.rs b/zero_bin/leader/src/client.rs
index 555bc74aa..8fbcf1bd8 100644
--- a/zero_bin/leader/src/client.rs
+++ b/zero_bin/leader/src/client.rs
@@ -1,6 +1,8 @@
 use std::io::Write;
 use std::path::PathBuf;
+use std::sync::Arc;
 
+use alloy::rpc::types::{BlockId, BlockNumberOrTag, BlockTransactionsKind};
 use alloy::transports::http::reqwest::Url;
 use anyhow::Result;
 use paladin::runtime::Runtime;
@@ -34,31 +36,52 @@ pub(crate) async fn client_main(
     block_interval: BlockInterval,
     mut params: ProofParams,
 ) -> Result<()> {
-    let cached_provider = rpc::provider::CachedProvider::new(build_http_retry_provider(
-        rpc_params.rpc_url.clone(),
-        rpc_params.backoff,
-        rpc_params.max_retries,
+    use futures::{FutureExt, StreamExt};
+
+    let cached_provider = Arc::new(rpc::provider::CachedProvider::new(
+        build_http_retry_provider(
+            rpc_params.rpc_url.clone(),
+            rpc_params.backoff,
+            rpc_params.max_retries,
+        ),
     ));
 
-    let prover_input = rpc::prover_input(
-        &cached_provider,
-        block_interval,
-        params.checkpoint_block_number.into(),
-        rpc_params.rpc_type,
-    )
-    .await?;
+    // Grab interval checkpoint block state trie
+    let checkpoint_state_trie_root = cached_provider
+        .get_block(
+            params.checkpoint_block_number.into(),
+            BlockTransactionsKind::Hashes,
+        )
+        .await?
+        .header
+        .state_root;
+
+    let mut block_prover_inputs = Vec::new();
+    let mut block_interval = block_interval.into_bounded_stream()?;
+    while let Some(block_num) = block_interval.next().await {
+        let block_id = BlockId::Number(BlockNumberOrTag::Number(block_num));
+        // Get future of prover input for particular block.
+        let block_prover_input = rpc::block_prover_input(
+            cached_provider.clone(),
+            block_id,
+            checkpoint_state_trie_root,
+            rpc_params.rpc_type,
+        )
+        .boxed();
+        block_prover_inputs.push(block_prover_input);
+    }
 
     // If `keep_intermediate_proofs` is not set we only keep the last block
     // proof from the interval. It contains all the necessary information to
     // verify the whole sequence.
-    let proved_blocks = prover_input
-        .prove(
-            &runtime,
-            params.previous_proof.take(),
-            params.save_inputs_on_error,
-            params.proof_output_dir.clone(),
-        )
-        .await;
+    let proved_blocks = prover::prove(
+        block_prover_inputs,
+        &runtime,
+        params.previous_proof.take(),
+        params.save_inputs_on_error,
+        params.proof_output_dir.clone(),
+    )
+    .await;
     runtime.close().await?;
     let proved_blocks = proved_blocks?;
 
diff --git a/zero_bin/leader/src/stdio.rs b/zero_bin/leader/src/stdio.rs
index 76bcd089b..403ea2a6a 100644
--- a/zero_bin/leader/src/stdio.rs
+++ b/zero_bin/leader/src/stdio.rs
@@ -3,7 +3,7 @@ use std::io::{Read, Write};
 use anyhow::Result;
 use paladin::runtime::Runtime;
 use proof_gen::proof_types::GeneratedBlockProof;
-use prover::ProverInput;
+use prover::{BlockProverInput, BlockProverInputFuture};
 use tracing::info;
 
 /// The main function for the stdio mode.
@@ -16,13 +16,19 @@ pub(crate) async fn stdio_main(
     std::io::stdin().read_to_string(&mut buffer)?;
 
     let des = &mut serde_json::Deserializer::from_str(&buffer);
-    let prover_input = ProverInput {
-        blocks: serde_path_to_error::deserialize(des)?,
-    };
+    let block_prover_inputs = serde_path_to_error::deserialize::<_, Vec<BlockProverInput>>(des)?
+        .into_iter()
+        .map(Into::into)
+        .collect::<Vec<BlockProverInputFuture>>();
 
-    let proved_blocks = prover_input
-        .prove(&runtime, previous, save_inputs_on_error, None)
-        .await;
+    let proved_blocks = prover::prove(
+        block_prover_inputs,
+        &runtime,
+        previous,
+        save_inputs_on_error,
+        None,
+    )
+    .await;
     runtime.close().await?;
     let proved_blocks = proved_blocks?;
 
diff --git a/zero_bin/prover/src/lib.rs b/zero_bin/prover/src/lib.rs
index a43c74104..a30a4d3f3 100644
--- a/zero_bin/prover/src/lib.rs
+++ b/zero_bin/prover/src/lib.rs
@@ -18,7 +18,20 @@ use trace_decoder::{BlockTrace, OtherBlockData};
 use tracing::info;
 use zero_bin_common::fs::generate_block_proof_file_name;
 
-#[derive(Debug, Deserialize, Serialize)]
+pub type BlockProverInputFuture = std::pin::Pin<
+    Box<dyn Future<Output = std::result::Result<BlockProverInput, anyhow::Error>> + Send>,
+>;
+
+impl From<BlockProverInput> for BlockProverInputFuture {
+    fn from(item: BlockProverInput) -> Self {
+        async fn _from(item: BlockProverInput) -> Result<BlockProverInput, anyhow::Error> {
+            Ok(item)
+        }
+        Box::pin(_from(item))
+    }
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
 pub struct BlockProverInput {
     pub block_trace: BlockTrace,
     pub other_data: OtherBlockData,
@@ -113,91 +126,77 @@ impl BlockProverInput {
     }
 }
 
-#[derive(Debug, Deserialize, Serialize)]
-pub struct ProverInput {
-    pub blocks: Vec<BlockProverInput>,
+/// Prove all the blocks in the input.
+/// Return the list of block numbers that are proved and if the proof data
+/// is not saved to disk, return the generated block proofs as well.
+pub async fn prove(
+    block_prover_inputs: Vec<BlockProverInputFuture>,
+    runtime: &Runtime,
+    previous_proof: Option<GeneratedBlockProof>,
+    save_inputs_on_error: bool,
+    proof_output_dir: Option<PathBuf>,
+) -> Result<Vec<(BlockNumber, Option<GeneratedBlockProof>)>> {
+    let mut prev: Option<BoxFuture<Result<GeneratedBlockProof>>> =
+        previous_proof.map(|proof| Box::pin(futures::future::ok(proof)) as BoxFuture<_>);
+
+    let mut results = FuturesOrdered::new();
+    for block_prover_input in block_prover_inputs {
+        let (tx, rx) = oneshot::channel::<GeneratedBlockProof>();
+        let proof_output_dir = proof_output_dir.clone();
+        let previos_block_proof = prev.take();
+        let fut = async move {
+            // Get the prover input data from the external source (e.g. Erigon node).
+            let block = block_prover_input.await?;
+            let block_number = block.get_block_number();
+            info!("Proving block {block_number}");
+
+            // Prove the block
+            let block_proof = block
+                .prove(runtime, previos_block_proof, save_inputs_on_error)
+                .then(move |proof| async move {
+                    let proof = proof?;
+                    let block_number = proof.b_height;
+
+                    // Write latest generated proof to disk if proof_output_dir is provided
+                    // or alternatively return proof as function result.
+                    let return_proof: Option<GeneratedBlockProof> =
+                        if let Some(output_dir) = proof_output_dir {
+                            write_proof_to_dir(output_dir, &proof).await?;
+                            None
+                        } else {
+                            Some(proof.clone())
+                        };
+
+                    if tx.send(proof).is_err() {
+                        anyhow::bail!("Failed to send proof");
+                    }
+
+                    Ok((block_number, return_proof))
+                })
+                .await?;
+
+            Ok(block_proof)
+        }
+        .boxed();
+        prev = Some(Box::pin(rx.map_err(anyhow::Error::new)));
+        results.push_back(fut);
+    }
+
+    results.try_collect().await
 }
 
-impl ProverInput {
-    /// Prove all the blocks in the input.
-    /// Return the list of block numbers that are proved and if the proof data
-    /// is not saved to disk, return the generated block proofs as well.
-    pub async fn prove(
-        self,
-        runtime: &Runtime,
-        previous_proof: Option<GeneratedBlockProof>,
-        save_inputs_on_error: bool,
-        proof_output_dir: Option<PathBuf>,
-    ) -> Result<Vec<(BlockNumber, Option<GeneratedBlockProof>)>> {
-        let mut prev: Option<BoxFuture<Result<GeneratedBlockProof>>> =
-            previous_proof.map(|proof| Box::pin(futures::future::ok(proof)) as BoxFuture<_>);
-
-        let results: FuturesOrdered<_> = self
-            .blocks
-            .into_iter()
-            .map(|block| {
-                let block_number = block.get_block_number();
-                info!("Proving block {block_number}");
-
-                let (tx, rx) = oneshot::channel::<GeneratedBlockProof>();
-
-                // Prove the block
-                let proof_output_dir = proof_output_dir.clone();
-                let fut = block
-                    .prove(runtime, prev.take(), save_inputs_on_error)
-                    .then(move |proof| async move {
-                        let proof = proof?;
-                        let block_number = proof.b_height;
-
-                        // Write latest generated proof to disk if proof_output_dir is provided
-                        let return_proof: Option<GeneratedBlockProof> =
-                            if proof_output_dir.is_some() {
-                                ProverInput::write_proof(proof_output_dir, &proof).await?;
-                                None
-                            } else {
-                                Some(proof.clone())
-                            };
-
-                        if tx.send(proof).is_err() {
-                            anyhow::bail!("Failed to send proof");
-                        }
-
-                        Ok((block_number, return_proof))
-                    })
-                    .boxed();
-
-                prev = Some(Box::pin(rx.map_err(anyhow::Error::new)));
-
-                fut
-            })
-            .collect();
+/// Write the proof to the `output_dir` directory.
+async fn write_proof_to_dir(output_dir: PathBuf, proof: &GeneratedBlockProof) -> Result<()> {
+    let proof_serialized = serde_json::to_vec(proof)?;
+    let block_proof_file_path =
+        generate_block_proof_file_name(&output_dir.to_str(), proof.b_height);
 
-        results.try_collect().await
+    if let Some(parent) = block_proof_file_path.parent() {
+        tokio::fs::create_dir_all(parent).await?;
     }
 
-    /// Write the proof to the disk (if `output_dir` is provided) or stdout.
-    pub(crate) async fn write_proof(
-        output_dir: Option<PathBuf>,
-        proof: &GeneratedBlockProof,
-    ) -> Result<()> {
-        let proof_serialized = serde_json::to_vec(proof)?;
-        let block_proof_file_path =
-            output_dir.map(|path| generate_block_proof_file_name(&path.to_str(), proof.b_height));
-        match block_proof_file_path {
-            Some(p) => {
-                if let Some(parent) = p.parent() {
-                    tokio::fs::create_dir_all(parent).await?;
-                }
-
-                let mut f = tokio::fs::File::create(p).await?;
-                f.write_all(&proof_serialized)
-                    .await
-                    .context("Failed to write proof to disk")
-            }
-            None => tokio::io::stdout()
-                .write_all(&proof_serialized)
-                .await
-                .context("Failed to write proof to stdout"),
-        }
-    }
+    let mut f = tokio::fs::File::create(block_proof_file_path).await?;
+    f.write_all(&proof_serialized)
+        .await
+        .context("Failed to write proof to disk")
 }
diff --git a/zero_bin/rpc/Cargo.toml b/zero_bin/rpc/Cargo.toml
index 14f447cef..cbd2df11d 100644
--- a/zero_bin/rpc/Cargo.toml
+++ b/zero_bin/rpc/Cargo.toml
@@ -26,6 +26,7 @@ tower = { workspace = true, features = ["retry"] }
 trace_decoder = { workspace = true }
 tracing-subscriber = { workspace = true }
 url = { workspace = true }
+itertools = {workspace = true}
 
 # Local dependencies
 compat = { workspace = true }
diff --git a/zero_bin/rpc/src/jerigon.rs b/zero_bin/rpc/src/jerigon.rs
index 470b2dffb..891421971 100644
--- a/zero_bin/rpc/src/jerigon.rs
+++ b/zero_bin/rpc/src/jerigon.rs
@@ -19,7 +19,7 @@ pub struct ZeroTxResult {
 }
 
 pub async fn block_prover_input<ProviderT, TransportT>(
-    cached_provider: &CachedProvider<ProviderT, TransportT>,
+    cached_provider: std::sync::Arc<CachedProvider<ProviderT, TransportT>>,
     target_block_id: BlockId,
     checkpoint_state_trie_root: B256,
 ) -> anyhow::Result<BlockProverInput>
diff --git a/zero_bin/rpc/src/lib.rs b/zero_bin/rpc/src/lib.rs
index 345cf8c96..cc6ddf2f1 100644
--- a/zero_bin/rpc/src/lib.rs
+++ b/zero_bin/rpc/src/lib.rs
@@ -1,7 +1,9 @@
+use std::sync::Arc;
+
 use alloy::{
     primitives::B256,
     providers::Provider,
-    rpc::types::eth::{BlockId, BlockNumberOrTag, BlockTransactionsKind, Withdrawal},
+    rpc::types::eth::{BlockId, BlockTransactionsKind, Withdrawal},
     transports::Transport,
 };
 use anyhow::Context as _;
@@ -9,9 +11,8 @@ use clap::ValueEnum;
 use compat::Compat;
 use evm_arithmetization::proof::{BlockHashes, BlockMetadata};
 use futures::{StreamExt as _, TryStreamExt as _};
-use prover::ProverInput;
+use prover::BlockProverInput;
 use trace_decoder::{BlockLevelData, OtherBlockData};
-use zero_bin_common::block_interval::BlockInterval;
 
 pub mod jerigon;
 pub mod native;
@@ -23,56 +24,36 @@ use crate::provider::CachedProvider;
 const PREVIOUS_HASHES_COUNT: usize = 256;
 
 /// The RPC type.
-#[derive(ValueEnum, Clone, Debug)]
+#[derive(ValueEnum, Clone, Debug, Copy)]
 pub enum RpcType {
     Jerigon,
     Native,
 }
 
-/// Obtain the prover input for a given block interval
-pub async fn prover_input<ProviderT, TransportT>(
-    cached_provider: &CachedProvider<ProviderT, TransportT>,
-    block_interval: BlockInterval,
-    checkpoint_block_id: BlockId,
+/// Obtain the prover input for one block
+pub async fn block_prover_input<ProviderT, TransportT>(
+    cached_provider: Arc<CachedProvider<ProviderT, TransportT>>,
+    block_id: BlockId,
+    checkpoint_state_trie_root: B256,
     rpc_type: RpcType,
-) -> anyhow::Result<ProverInput>
+) -> Result<BlockProverInput, anyhow::Error>
 where
     ProviderT: Provider<TransportT>,
     TransportT: Transport + Clone,
 {
-    // Grab interval checkpoint block state trie
-    let checkpoint_state_trie_root = cached_provider
-        .get_block(checkpoint_block_id, BlockTransactionsKind::Hashes)
-        .await?
-        .header
-        .state_root;
-
-    let mut block_proofs = Vec::new();
-    let mut block_interval = block_interval.into_bounded_stream()?;
-
-    while let Some(block_num) = block_interval.next().await {
-        let block_id = BlockId::Number(BlockNumberOrTag::Number(block_num));
-        let block_prover_input = match rpc_type {
-            RpcType::Jerigon => {
-                jerigon::block_prover_input(cached_provider, block_id, checkpoint_state_trie_root)
-                    .await?
-            }
-            RpcType::Native => {
-                native::block_prover_input(cached_provider, block_id, checkpoint_state_trie_root)
-                    .await?
-            }
-        };
-
-        block_proofs.push(block_prover_input);
+    match rpc_type {
+        RpcType::Jerigon => {
+            jerigon::block_prover_input(cached_provider, block_id, checkpoint_state_trie_root).await
+        }
+        RpcType::Native => {
+            native::block_prover_input(cached_provider, block_id, checkpoint_state_trie_root).await
+        }
     }
-    Ok(ProverInput {
-        blocks: block_proofs,
-    })
 }
 
 /// Fetches other block data
 async fn fetch_other_block_data<ProviderT, TransportT>(
-    cached_provider: &CachedProvider<ProviderT, TransportT>,
+    cached_provider: Arc<CachedProvider<ProviderT, TransportT>>,
     target_block_id: BlockId,
     checkpoint_state_trie_root: B256,
 ) -> anyhow::Result<OtherBlockData>
@@ -80,6 +61,7 @@ where
     ProviderT: Provider<TransportT>,
     TransportT: Transport + Clone,
 {
+    use itertools::Itertools;
     let target_block = cached_provider
         .get_block(target_block_id, BlockTransactionsKind::Hashes)
         .await?;
@@ -102,28 +84,33 @@ where
         })
         .take(PREVIOUS_HASHES_COUNT + 1)
         .filter(|i| *i >= 0)
+        .chunks(2)
+        .into_iter()
+        .map(|mut chunk| {
+            // We convert to tuple of (current block, optional previous block)
+            let first = chunk
+                .next()
+                .expect("must be valid according to itertools::Iterator::chunks definition");
+            let second = chunk.next();
+            (first, second)
+        })
         .collect::<Vec<_>>();
+
     let concurrency = previous_block_numbers.len();
     let collected_hashes = futures::stream::iter(
         previous_block_numbers
-            .chunks(2) // we get hash for previous and current block with one request
-            .map(|block_numbers| {
+            .into_iter() // we get hash for previous and current block with one request
+            .map(|(current_block_number, previous_block_number)| {
                 let cached_provider = &cached_provider;
-                let block_num = &block_numbers[0];
-                let previos_block_num = if block_numbers.len() > 1 {
-                    Some(block_numbers[1])
-                } else {
-                    // For genesis block
-                    None
-                };
+                let block_num = current_block_number;
                 async move {
                     let block = cached_provider
-                        .get_block((*block_num as u64).into(), BlockTransactionsKind::Hashes)
+                        .get_block((block_num as u64).into(), BlockTransactionsKind::Hashes)
                         .await
                         .context("couldn't get block")?;
                     anyhow::Ok([
-                        (block.header.hash, Some(*block_num)),
-                        (Some(block.header.parent_hash), previos_block_num),
+                        (block.header.hash, Some(block_num)),
+                        (Some(block.header.parent_hash), previous_block_number),
                     ])
                 }
             }),
diff --git a/zero_bin/rpc/src/main.rs b/zero_bin/rpc/src/main.rs
index 444e89e3b..3c72ac902 100644
--- a/zero_bin/rpc/src/main.rs
+++ b/zero_bin/rpc/src/main.rs
@@ -1,7 +1,10 @@
-use std::{env, io};
+use std::env;
+use std::sync::Arc;
 
 use alloy::rpc::types::eth::BlockId;
+use alloy::rpc::types::{BlockNumberOrTag, BlockTransactionsKind};
 use clap::{Parser, ValueHint};
+use futures::StreamExt;
 use rpc::provider::CachedProvider;
 use rpc::{retry::build_http_retry_provider, RpcType};
 use tracing_subscriber::{prelude::*, EnvFilter};
@@ -55,22 +58,36 @@ impl Cli {
                     checkpoint_block_number.unwrap_or((start_block - 1).into());
                 let block_interval = BlockInterval::Range(start_block..end_block + 1);
 
-                let cached_provider = CachedProvider::new(build_http_retry_provider(
+                let cached_provider = Arc::new(CachedProvider::new(build_http_retry_provider(
                     rpc_url.clone(),
                     backoff,
                     max_retries,
-                ));
+                )));
 
-                // Retrieve prover input from the Erigon node
-                let prover_input = rpc::prover_input(
-                    &cached_provider,
-                    block_interval,
-                    checkpoint_block_number,
-                    rpc_type,
-                )
-                .await?;
+                // Grab interval checkpoint block state trie
+                let checkpoint_state_trie_root = cached_provider
+                    .get_block(checkpoint_block_number, BlockTransactionsKind::Hashes)
+                    .await?
+                    .header
+                    .state_root;
 
-                serde_json::to_writer_pretty(io::stdout(), &prover_input.blocks)?;
+                let mut block_prover_inputs = Vec::new();
+                let mut block_interval = block_interval.clone().into_bounded_stream()?;
+                while let Some(block_num) = block_interval.next().await {
+                    let block_id = BlockId::Number(BlockNumberOrTag::Number(block_num));
+                    // Get the prover input for particular block.
+                    let result = rpc::block_prover_input(
+                        cached_provider.clone(),
+                        block_id,
+                        checkpoint_state_trie_root,
+                        rpc_type,
+                    )
+                    .await?;
+
+                    block_prover_inputs.push(result);
+                }
+
+                serde_json::to_writer_pretty(std::io::stdout(), &block_prover_inputs)?;
             }
         }
         Ok(())
diff --git a/zero_bin/rpc/src/native/mod.rs b/zero_bin/rpc/src/native/mod.rs
index 892a799d6..1f61d7b26 100644
--- a/zero_bin/rpc/src/native/mod.rs
+++ b/zero_bin/rpc/src/native/mod.rs
@@ -1,4 +1,5 @@
 use std::collections::HashMap;
+use std::sync::Arc;
 
 use alloy::{
     primitives::B256,
@@ -19,7 +20,7 @@ type CodeDb = HashMap<__compat_primitive_types::H256, Vec<u8>>;
 
 /// Fetches the prover input for the given BlockId.
 pub async fn block_prover_input<ProviderT, TransportT>(
-    provider: &CachedProvider<ProviderT, TransportT>,
+    provider: Arc<CachedProvider<ProviderT, TransportT>>,
     block_number: BlockId,
     checkpoint_state_trie_root: B256,
 ) -> anyhow::Result<BlockProverInput>
@@ -28,8 +29,8 @@ where
     TransportT: Transport + Clone,
 {
     let (block_trace, other_data) = try_join!(
-        process_block_trace(provider, block_number),
-        crate::fetch_other_block_data(provider, block_number, checkpoint_state_trie_root,)
+        process_block_trace(provider.clone(), block_number),
+        crate::fetch_other_block_data(provider.clone(), block_number, checkpoint_state_trie_root,)
     )?;
 
     Ok(BlockProverInput {
@@ -40,7 +41,7 @@ where
 
 /// Processes the block with the given block number and returns the block trace.
 async fn process_block_trace<ProviderT, TransportT>(
-    cached_provider: &CachedProvider<ProviderT, TransportT>,
+    cached_provider: Arc<CachedProvider<ProviderT, TransportT>>,
     block_number: BlockId,
 ) -> anyhow::Result<BlockTrace>
 where
diff --git a/zero_bin/rpc/src/native/state.rs b/zero_bin/rpc/src/native/state.rs
index 5fd9b539c..b5017b394 100644
--- a/zero_bin/rpc/src/native/state.rs
+++ b/zero_bin/rpc/src/native/state.rs
@@ -1,4 +1,5 @@
 use std::collections::{HashMap, HashSet};
+use std::sync::Arc;
 
 use alloy::{
     primitives::{keccak256, Address, StorageKey, B256, U256},
@@ -20,7 +21,7 @@ use crate::Compat;
 
 /// Processes the state witness for the given block.
 pub async fn process_state_witness<ProviderT, TransportT>(
-    cached_provider: &CachedProvider<ProviderT, TransportT>,
+    cached_provider: Arc<CachedProvider<ProviderT, TransportT>>,
     block: Block,
     txn_infos: &[TxnInfo],
 ) -> anyhow::Result<BlockTraceTriePreImages>
@@ -115,7 +116,7 @@ fn insert_beacon_roots_update(
 async fn generate_state_witness<ProviderT, TransportT>(
     prev_state_root: B256,
     accounts_state: HashMap<Address, HashSet<StorageKey>>,
-    cached_provider: &CachedProvider<ProviderT, TransportT>,
+    cached_provider: Arc<CachedProvider<ProviderT, TransportT>>,
     block_number: u64,
 ) -> anyhow::Result<(
     PartialTrieBuilder<HashedPartialTrie>,
@@ -164,7 +165,7 @@ where
 /// Fetches the proof data for the given accounts and associated storage keys.
 async fn fetch_proof_data<ProviderT, TransportT>(
     accounts_state: HashMap<Address, HashSet<StorageKey>>,
-    provider: &CachedProvider<ProviderT, TransportT>,
+    provider: Arc<CachedProvider<ProviderT, TransportT>>,
     block_number: u64,
 ) -> anyhow::Result<(
     Vec<(Address, EIP1186AccountProofResponse)>,
@@ -177,20 +178,23 @@ where
     let account_proofs_fut = accounts_state
         .clone()
         .into_iter()
-        .map(|(address, keys)| async move {
-            let proof = provider
-                .as_provider()
-                .get_proof(address, keys.into_iter().collect())
-                .block_id((block_number - 1).into())
-                .await
-                .context("Failed to get proof for account")?;
-            anyhow::Result::Ok((address, proof))
+        .map(|(address, keys)| {
+            let provider = provider.clone();
+            async move {
+                let proof = provider
+                    .as_provider()
+                    .get_proof(address, keys.into_iter().collect())
+                    .block_id((block_number - 1).into())
+                    .await
+                    .context("Failed to get proof for account")?;
+                anyhow::Result::Ok((address, proof))
+            }
         })
         .collect::<Vec<_>>();
 
-    let next_account_proofs_fut = accounts_state
-        .into_iter()
-        .map(|(address, keys)| async move {
+    let next_account_proofs_fut = accounts_state.into_iter().map(|(address, keys)| {
+        let provider = provider.clone();
+        async move {
             let proof = provider
                 .as_provider()
                 .get_proof(address, keys.into_iter().collect())
@@ -198,7 +202,8 @@ where
                 .await
                 .context("Failed to get proof for account")?;
             anyhow::Result::Ok((address, proof))
-        });
+        }
+    });
 
     try_join(
         try_join_all(account_proofs_fut),

From b2006bf2c9ff7b5fe10cbb170f3126093041e1ae Mon Sep 17 00:00:00 2001
From: 0xaatif <169152398+0xaatif@users.noreply.github.com>
Date: Thu, 22 Aug 2024 12:06:46 +0100
Subject: [PATCH 2/3] refactor: Hash2Code (#522)

* mark: 0xaatif/refactor-hash2code

* refactor: Hash2Code

* refactor: insert on new

* refactor: StateWrite != StateWrite::default

* nomerge: assert code hash

* fix: clippy

* fix: contract_code_accessed always contains empty vec

* refactor: Hash2Code does not always contain empty vec

* Revert "nomerge: assert code hash"

This reverts commit 0b8f4592489754d3cf324e52d1af884e3e7fe11b.
---
 trace_decoder/src/decoding.rs              |  48 ++---
 trace_decoder/src/lib.rs                   |  37 ++--
 trace_decoder/src/processed_block_trace.rs | 223 ++++++++++-----------
 3 files changed, 141 insertions(+), 167 deletions(-)

diff --git a/trace_decoder/src/decoding.rs b/trace_decoder/src/decoding.rs
index 1a6d2b725..aa755fadf 100644
--- a/trace_decoder/src/decoding.rs
+++ b/trace_decoder/src/decoding.rs
@@ -20,7 +20,7 @@ use mpt_trie::{
 use crate::{
     hash,
     processed_block_trace::{
-        NodesUsedByTxn, ProcessedBlockTrace, ProcessedTxnInfo, StateTrieWrites, TxnMetaState,
+        NodesUsedByTxn, ProcessedBlockTrace, ProcessedTxnInfo, StateWrite, TxnMetaState,
     },
     typed_mpt::{ReceiptTrie, StateTrie, StorageTrie, TransactionTrie, TrieKey},
     OtherBlockData, PartialTriePreImages,
@@ -201,15 +201,12 @@ fn update_txn_and_receipt_tries(
     meta: &TxnMetaState,
     txn_idx: usize,
 ) -> anyhow::Result<()> {
-    if meta.is_dummy() {
-        // This is a dummy payload, that does not mutate these tries.
-        return Ok(());
-    }
-
-    trie_state.txn.insert(txn_idx, meta.txn_bytes())?;
-    trie_state
-        .receipt
-        .insert(txn_idx, meta.receipt_node_bytes.clone())?;
+    if let Some(bytes) = &meta.txn_bytes {
+        trie_state.txn.insert(txn_idx, bytes.clone())?;
+        trie_state
+            .receipt
+            .insert(txn_idx, meta.receipt_node_bytes.clone())?;
+    } // else it's just a dummy
     Ok(())
 }
 
@@ -219,11 +216,11 @@ fn update_txn_and_receipt_tries(
 fn init_any_needed_empty_storage_tries<'a>(
     storage_tries: &mut HashMap<H256, StorageTrie>,
     accounts_with_storage: impl Iterator<Item = &'a H256>,
-    state_accounts_with_no_accesses_but_storage_tries: &'a HashMap<H256, H256>,
+    accts_with_unaccessed_storage: &HashMap<H256, H256>,
 ) {
     for h_addr in accounts_with_storage {
         if !storage_tries.contains_key(h_addr) {
-            let trie = state_accounts_with_no_accesses_but_storage_tries
+            let trie = accts_with_unaccessed_storage
                 .get(h_addr)
                 .map(|s_root| {
                     let mut it = StorageTrie::default();
@@ -519,9 +516,7 @@ fn process_txn_info(
             .storage_accesses
             .iter()
             .map(|(k, _)| k),
-        &txn_info
-            .nodes_used_by_txn
-            .state_accounts_with_no_accesses_but_storage_tries,
+        &txn_info.nodes_used_by_txn.accts_with_unaccessed_storage,
     );
     // For each non-dummy txn, we increment `txn_number_after` by 1, and
     // update `gas_used_after` accordingly.
@@ -577,7 +572,11 @@ fn process_txn_info(
             receipts_root: curr_block_tries.receipt.root(),
         },
         checkpoint_state_trie_root: extra_data.checkpoint_state_trie_root,
-        contract_code: txn_info.contract_code_accessed,
+        contract_code: txn_info
+            .contract_code_accessed
+            .into_iter()
+            .map(|code| (hash(&code), code))
+            .collect(),
         block_metadata: other_data.b_data.b_meta.clone(),
         block_hashes: other_data.b_data.b_hashes.clone(),
         global_exit_roots: vec![],
@@ -591,7 +590,7 @@ fn process_txn_info(
     Ok(gen_inputs)
 }
 
-impl StateTrieWrites {
+impl StateWrite {
     fn apply_writes_to_state_node(
         &self,
         state_node: &mut AccountRlp,
@@ -678,21 +677,6 @@ fn create_trie_subset_wrapped(
     .context(format!("missing keys when creating {}", trie_type))
 }
 
-impl TxnMetaState {
-    /// Outputs a boolean indicating whether this `TxnMetaState`
-    /// represents a dummy payload or an actual transaction.
-    const fn is_dummy(&self) -> bool {
-        self.txn_bytes.is_none()
-    }
-
-    fn txn_bytes(&self) -> Vec<u8> {
-        match self.txn_bytes.as_ref() {
-            Some(v) => v.clone(),
-            None => Vec::default(),
-        }
-    }
-}
-
 fn eth_to_gwei(eth: U256) -> U256 {
     // 1 ether = 10^9 gwei.
     eth * U256::from(10).pow(9.into())
diff --git a/trace_decoder/src/lib.rs b/trace_decoder/src/lib.rs
index a59243a24..a71cd38ee 100644
--- a/trace_decoder/src/lib.rs
+++ b/trace_decoder/src/lib.rs
@@ -258,15 +258,6 @@ pub enum ContractCodeUsage {
     Write(#[serde(with = "crate::hex")] Vec<u8>),
 }
 
-impl ContractCodeUsage {
-    fn get_code_hash(&self) -> H256 {
-        match self {
-            ContractCodeUsage::Read(hash) => *hash,
-            ContractCodeUsage::Write(bytes) => hash(bytes),
-        }
-    }
-}
-
 /// Other data that is needed for proof gen.
 #[derive(Clone, Debug, Deserialize, Serialize)]
 pub struct OtherBlockData {
@@ -397,15 +388,17 @@ pub fn entrypoint(
         .map(|(addr, data)| (addr.into_hash_left_padded(), data))
         .collect::<Vec<_>>();
 
-    let code_db = {
-        let mut code_db = code_db.unwrap_or_default();
-        if let Some(code_mappings) = pre_images.extra_code_hash_mappings {
-            code_db.extend(code_mappings);
-        }
-        code_db
-    };
-
-    let mut code_hash_resolver = Hash2Code::new(code_db);
+    // Note we discard any user-provided hashes.
+    let mut hash2code = code_db
+        .unwrap_or_default()
+        .into_values()
+        .chain(
+            pre_images
+                .extra_code_hash_mappings
+                .unwrap_or_default()
+                .into_values(),
+        )
+        .collect::<Hash2Code>();
 
     let last_tx_idx = txn_info.len().saturating_sub(1);
 
@@ -430,7 +423,7 @@ pub fn entrypoint(
                 &pre_images.tries,
                 &all_accounts_in_pre_images,
                 &extra_state_accesses,
-                &mut code_hash_resolver,
+                &mut hash2code,
             )
         })
         .collect::<Result<Vec<_>, _>>()?;
@@ -457,8 +450,6 @@ struct PartialTriePreImages {
 
 /// Like `#[serde(with = "hex")`, but tolerates and emits leading `0x` prefixes
 mod hex {
-    use std::{borrow::Cow, fmt};
-
     use serde::{de::Error as _, Deserialize as _, Deserializer, Serializer};
 
     pub fn serialize<S: Serializer, T>(data: T, serializer: S) -> Result<S::Ok, S::Error>
@@ -472,9 +463,9 @@ mod hex {
     pub fn deserialize<'de, D: Deserializer<'de>, T>(deserializer: D) -> Result<T, D::Error>
     where
         T: hex::FromHex,
-        T::Error: fmt::Display,
+        T::Error: std::fmt::Display,
     {
-        let s = Cow::<str>::deserialize(deserializer)?;
+        let s = String::deserialize(deserializer)?;
         match s.strip_prefix("0x") {
             Some(rest) => T::from_hex(rest),
             None => T::from_hex(&*s),
diff --git a/trace_decoder/src/processed_block_trace.rs b/trace_decoder/src/processed_block_trace.rs
index dac816530..5dcd9f109 100644
--- a/trace_decoder/src/processed_block_trace.rs
+++ b/trace_decoder/src/processed_block_trace.rs
@@ -1,16 +1,13 @@
-use std::collections::hash_map::Entry;
 use std::collections::{HashMap, HashSet};
-use std::fmt::Debug;
-use std::iter::once;
 
-use anyhow::bail;
+use anyhow::{bail, Context as _};
 use ethereum_types::{Address, H256, U256};
 use evm_arithmetization::generation::mpt::{AccountRlp, LegacyReceiptRlp};
-use zk_evm_common::{EMPTY_CODE_HASH, EMPTY_TRIE_HASH};
+use zk_evm_common::EMPTY_TRIE_HASH;
 
-use crate::hash;
 use crate::typed_mpt::TrieKey;
 use crate::PartialTriePreImages;
+use crate::{hash, TxnTrace};
 use crate::{ContractCodeUsage, TxnInfo};
 
 const FIRST_PRECOMPILE_ADDRESS: U256 = U256([1, 0, 0, 0]);
@@ -32,7 +29,7 @@ pub(crate) struct ProcessedBlockTracePreImages {
 #[derive(Debug, Default)]
 pub(crate) struct ProcessedTxnInfo {
     pub nodes_used_by_txn: NodesUsedByTxn,
-    pub contract_code_accessed: HashMap<H256, Vec<u8>>,
+    pub contract_code_accessed: HashSet<Vec<u8>>,
     pub meta: TxnMetaState,
 }
 
@@ -41,22 +38,34 @@ pub(crate) struct ProcessedTxnInfo {
 /// If there are any txns that create contracts, then they will also
 /// get added here as we process the deltas.
 pub(crate) struct Hash2Code {
+    /// Key must always be [`hash`] of value.
     inner: HashMap<H256, Vec<u8>>,
 }
 
 impl Hash2Code {
-    pub fn new(inner: HashMap<H256, Vec<u8>>) -> Self {
-        Self { inner }
+    pub fn new() -> Self {
+        Self {
+            inner: HashMap::new(),
+        }
     }
-    fn resolve(&mut self, c_hash: &H256) -> anyhow::Result<Vec<u8>> {
-        match self.inner.get(c_hash) {
+    fn get(&mut self, hash: H256) -> anyhow::Result<Vec<u8>> {
+        match self.inner.get(&hash) {
             Some(code) => Ok(code.clone()),
-            None => bail!("no code for hash {}", c_hash),
+            None => bail!("no code for hash {}", hash),
         }
     }
+    fn insert(&mut self, code: Vec<u8>) {
+        self.inner.insert(hash(&code), code);
+    }
+}
 
-    fn insert_code(&mut self, c_hash: H256, code: Vec<u8>) {
-        self.inner.insert(c_hash, code);
+impl FromIterator<Vec<u8>> for Hash2Code {
+    fn from_iter<II: IntoIterator<Item = Vec<u8>>>(iter: II) -> Self {
+        let mut this = Self::new();
+        for code in iter {
+            this.insert(code)
+        }
+        this
     }
 }
 
@@ -69,57 +78,59 @@ impl TxnInfo {
         hash2code: &mut Hash2Code,
     ) -> anyhow::Result<ProcessedTxnInfo> {
         let mut nodes_used_by_txn = NodesUsedByTxn::default();
-        let mut contract_code_accessed = create_empty_code_access_map();
-
-        for (addr, trace) in self.traces {
+        let mut contract_code_accessed = HashSet::from([vec![]]); // we always "access" empty code
+
+        for (
+            addr,
+            TxnTrace {
+                balance,
+                nonce,
+                storage_read,
+                storage_written,
+                code_usage,
+                self_destructed,
+            },
+        ) in self.traces
+        {
             let hashed_addr = hash(addr.as_bytes());
 
-            let storage_writes = trace.storage_written.unwrap_or_default();
-
-            let storage_read_keys = trace
-                .storage_read
-                .into_iter()
-                .flat_map(|reads| reads.into_iter());
-
-            let storage_write_keys = storage_writes.keys();
-            let storage_access_keys = storage_read_keys.chain(storage_write_keys.copied());
-
+            // record storage changes
+            let storage_written = storage_written.unwrap_or_default();
             nodes_used_by_txn.storage_accesses.push((
                 hashed_addr,
-                storage_access_keys
+                storage_read
+                    .into_iter()
+                    .flatten()
+                    .chain(storage_written.keys().copied())
                     .map(|H256(bytes)| TrieKey::from_hash(hash(bytes)))
                     .collect(),
             ));
+            nodes_used_by_txn.storage_writes.push((
+                hashed_addr,
+                storage_written
+                    .iter()
+                    .map(|(k, v)| (TrieKey::from_hash(*k), rlp::encode(v).to_vec()))
+                    .collect(),
+            ));
 
-            let storage_trie_change = !storage_writes.is_empty();
-            let code_change = trace.code_usage.is_some();
-            let state_write_occurred = trace.balance.is_some()
-                || trace.nonce.is_some()
-                || storage_trie_change
-                || code_change;
-
-            if state_write_occurred {
-                let state_trie_writes = StateTrieWrites {
-                    balance: trace.balance,
-                    nonce: trace.nonce,
-                    storage_trie_change,
-                    code_hash: trace.code_usage.as_ref().map(|usage| usage.get_code_hash()),
-                };
-
+            // record state changes
+            let state_write = StateWrite {
+                balance,
+                nonce,
+                storage_trie_change: !storage_written.is_empty(),
+                code_hash: code_usage.as_ref().map(|it| match it {
+                    ContractCodeUsage::Read(hash) => *hash,
+                    ContractCodeUsage::Write(bytes) => hash(bytes),
+                }),
+            };
+
+            if state_write != StateWrite::default() {
+                // a write occurred
                 nodes_used_by_txn
                     .state_writes
-                    .push((hashed_addr, state_trie_writes))
+                    .push((hashed_addr, state_write))
             }
 
-            let storage_writes_vec = storage_writes
-                .into_iter()
-                .map(|(k, v)| (TrieKey::from_hash(k), rlp::encode(&v).to_vec()))
-                .collect();
-
-            nodes_used_by_txn
-                .storage_writes
-                .push((hashed_addr, storage_writes_vec));
-
             let is_precompile = (FIRST_PRECOMPILE_ADDRESS..LAST_PRECOMPILE_ADDRESS)
                 .contains(&U256::from_big_endian(&addr.0));
 
@@ -136,23 +147,18 @@ impl TxnInfo {
                 nodes_used_by_txn.state_accesses.push(hashed_addr);
             }
 
-            if let Some(c_usage) = trace.code_usage {
-                match c_usage {
-                    ContractCodeUsage::Read(c_hash) => {
-                        if let Entry::Vacant(vacant) = contract_code_accessed.entry(c_hash) {
-                            vacant.insert(hash2code.resolve(&c_hash)?);
-                        }
-                    }
-                    ContractCodeUsage::Write(c_bytes) => {
-                        let c_hash = hash(&c_bytes);
-
-                        contract_code_accessed.insert(c_hash, c_bytes.clone());
-                        hash2code.insert_code(c_hash, c_bytes);
-                    }
+            match code_usage {
+                Some(ContractCodeUsage::Read(hash)) => {
+                    contract_code_accessed.insert(hash2code.get(hash)?);
+                }
+                Some(ContractCodeUsage::Write(code)) => {
+                    contract_code_accessed.insert(code.clone());
+                    hash2code.insert(code);
                 }
+                None => {}
             }
 
-            if trace.self_destructed.unwrap_or_default() {
+            if self_destructed.unwrap_or_default() {
                 nodes_used_by_txn.self_destructed_accounts.push(hashed_addr);
             }
         }
@@ -161,78 +167,64 @@ impl TxnInfo {
             nodes_used_by_txn.state_accesses.push(hashed_addr);
         }
 
-        let accounts_with_storage_accesses: HashSet<_> = HashSet::from_iter(
-            nodes_used_by_txn
-                .storage_accesses
-                .iter()
-                .filter(|(_, slots)| !slots.is_empty())
-                .map(|(addr, _)| *addr),
-        );
-
-        let all_accounts_with_non_empty_storage = all_accounts_in_pre_image
+        let accounts_with_storage_accesses = nodes_used_by_txn
+            .storage_accesses
             .iter()
-            .filter(|(_, data)| data.storage_root != EMPTY_TRIE_HASH);
-
-        let accounts_with_storage_but_no_storage_accesses = all_accounts_with_non_empty_storage
-            .filter(|&(addr, _data)| !accounts_with_storage_accesses.contains(addr))
-            .map(|(addr, data)| (*addr, data.storage_root));
-
-        nodes_used_by_txn
-            .state_accounts_with_no_accesses_but_storage_tries
-            .extend(accounts_with_storage_but_no_storage_accesses);
-
-        let txn_bytes = match self.meta.byte_code.is_empty() {
-            false => Some(self.meta.byte_code),
-            true => None,
-        };
+            .filter(|(_, slots)| !slots.is_empty())
+            .map(|(addr, _)| *addr)
+            .collect::<HashSet<_>>();
 
-        let receipt_node_bytes =
-            process_rlped_receipt_node_bytes(self.meta.new_receipt_trie_node_byte);
-
-        let new_meta_state = TxnMetaState {
-            txn_bytes,
-            receipt_node_bytes,
-            gas_used: self.meta.gas_used,
-        };
+        for (addr, state) in all_accounts_in_pre_image {
+            if state.storage_root != EMPTY_TRIE_HASH
+                && !accounts_with_storage_accesses.contains(addr)
+            {
+                nodes_used_by_txn
+                    .accts_with_unaccessed_storage
+                    .insert(*addr, state.storage_root);
+            }
+        }
 
         Ok(ProcessedTxnInfo {
             nodes_used_by_txn,
             contract_code_accessed,
-            meta: new_meta_state,
+            meta: TxnMetaState {
+                txn_bytes: match self.meta.byte_code.is_empty() {
+                    false => Some(self.meta.byte_code),
+                    true => None,
+                },
+                receipt_node_bytes: check_receipt_bytes(self.meta.new_receipt_trie_node_byte)?,
+                gas_used: self.meta.gas_used,
+            },
         })
     }
 }
 
-fn process_rlped_receipt_node_bytes(raw_bytes: Vec<u8>) -> Vec<u8> {
-    match rlp::decode::<LegacyReceiptRlp>(&raw_bytes) {
-        Ok(_) => raw_bytes,
+fn check_receipt_bytes(bytes: Vec<u8>) -> anyhow::Result<Vec<u8>> {
+    match rlp::decode::<LegacyReceiptRlp>(&bytes) {
+        Ok(_) => Ok(bytes),
         Err(_) => {
-            // Must be non-legacy.
-            rlp::decode::<Vec<u8>>(&raw_bytes).unwrap()
+            rlp::decode(&bytes).context("couldn't decode receipt as a legacy receipt or raw bytes")
         }
     }
 }
 
-fn create_empty_code_access_map() -> HashMap<H256, Vec<u8>> {
-    HashMap::from_iter(once((EMPTY_CODE_HASH, Vec::new())))
-}
-
 /// Note that "*_accesses" includes writes.
 #[derive(Debug, Default)]
 pub(crate) struct NodesUsedByTxn {
     pub state_accesses: Vec<H256>,
-    pub state_writes: Vec<(H256, StateTrieWrites)>,
+    pub state_writes: Vec<(H256, StateWrite)>,
 
     // Note: All entries in `storage_writes` also appear in `storage_accesses`.
     pub storage_accesses: Vec<(H256, Vec<TrieKey>)>,
     #[allow(clippy::type_complexity)]
     pub storage_writes: Vec<(H256, Vec<(TrieKey, Vec<u8>)>)>,
-    pub state_accounts_with_no_accesses_but_storage_tries: HashMap<H256, H256>,
+    /// Hashed address -> storage root.
+    pub accts_with_unaccessed_storage: HashMap<H256, H256>,
     pub self_destructed_accounts: Vec<H256>,
 }
 
-#[derive(Debug)]
-pub(crate) struct StateTrieWrites {
+#[derive(Debug, Default, PartialEq)]
+pub(crate) struct StateWrite {
     pub balance: Option<U256>,
     pub nonce: Option<U256>,
     pub storage_trie_change: bool,
@@ -241,7 +233,14 @@ pub(crate) struct StateTrieWrites {
 
 #[derive(Debug, Default)]
 pub(crate) struct TxnMetaState {
+    /// [`None`] if this is a dummy transaction inserted for padding.
     pub txn_bytes: Option<Vec<u8>>,
     pub receipt_node_bytes: Vec<u8>,
     pub gas_used: u64,
 }
+
+impl TxnMetaState {
+    pub fn is_dummy(&self) -> bool {
+        self.txn_bytes.is_none()
+    }
+}

From c7a16419d0faee01cd38ec3e8fc9b3310e94f13c Mon Sep 17 00:00:00 2001
From: BGluth <gluthb@gmail.com>
Date: Thu, 22 Aug 2024 05:07:30 -0600
Subject: [PATCH 3/3] Made sub-trie errors better (#520)

* Made sub-trie errors better

- Now shows the path in the trie where we encountered the `hash` node.

* Requested PR changes for #520
---
 mpt_trie/src/debug_tools/query.rs |  6 +++---
 mpt_trie/src/trie_subsets.rs      | 34 +++++++++++++++++++------------
 2 files changed, 24 insertions(+), 16 deletions(-)

diff --git a/mpt_trie/src/debug_tools/query.rs b/mpt_trie/src/debug_tools/query.rs
index 0fb6ade96..dcfff397d 100644
--- a/mpt_trie/src/debug_tools/query.rs
+++ b/mpt_trie/src/debug_tools/query.rs
@@ -66,19 +66,19 @@ pub struct DebugQueryParamsBuilder {
 
 impl DebugQueryParamsBuilder {
     /// Defaults to `true`.
-    pub const fn print_key_pieces(mut self, enabled: bool) -> Self {
+    pub const fn include_key_pieces(mut self, enabled: bool) -> Self {
         self.params.include_key_piece_per_node = enabled;
         self
     }
 
     /// Defaults to `true`.
-    pub const fn print_node_type(mut self, enabled: bool) -> Self {
+    pub const fn include_node_type(mut self, enabled: bool) -> Self {
         self.params.include_node_type = enabled;
         self
     }
 
     /// Defaults to `false`.
-    pub const fn print_node_specific_values(mut self, enabled: bool) -> Self {
+    pub const fn include_node_specific_values(mut self, enabled: bool) -> Self {
         self.params.include_node_specific_values = enabled;
         self
     }
diff --git a/mpt_trie/src/trie_subsets.rs b/mpt_trie/src/trie_subsets.rs
index 13e9d0d9f..eadb6b3e2 100644
--- a/mpt_trie/src/trie_subsets.rs
+++ b/mpt_trie/src/trie_subsets.rs
@@ -12,6 +12,7 @@ use log::trace;
 use thiserror::Error;
 
 use crate::{
+    debug_tools::query::{get_path_from_query, DebugQueryOutput, DebugQueryParamsBuilder},
     nibbles::Nibbles,
     partial_trie::{Node, PartialTrie, WrappedNode},
     trie_hashing::EncodedNode,
@@ -21,13 +22,10 @@ use crate::{
 /// The output type of trie_subset operations.
 pub type SubsetTrieResult<T> = Result<T, SubsetTrieError>;
 
-/// Errors that may occur when creating a subset [`PartialTrie`].
+/// We encountered a `hash` node when marking nodes during sub-trie creation.
 #[derive(Clone, Debug, Error, Hash)]
-pub enum SubsetTrieError {
-    #[error("Tried to mark nodes in a tracked trie for a key that does not exist! (Key: {0}, trie: {1})")]
-    /// The key does not exist in the trie.
-    UnexpectedKey(Nibbles, String),
-}
+#[error("Encountered a hash node when marking nodes to not hash when traversing a key to not hash!\nPath: {0}")]
+pub struct SubsetTrieError(DebugQueryOutput);
 
 #[derive(Debug)]
 enum TrackedNodeIntern<N: PartialTrie> {
@@ -256,8 +254,17 @@ where
     N: PartialTrie,
     K: Into<Nibbles>,
 {
-    for k in keys_involved {
-        mark_nodes_that_are_needed(tracked_trie, &mut k.into())?;
+    for mut k in keys_involved.map(|k| k.into()) {
+        mark_nodes_that_are_needed(tracked_trie, &mut k).map_err(|_| {
+            // We need to unwind back to this callsite in order to produce the actual error.
+            let query = DebugQueryParamsBuilder::default()
+                .include_node_specific_values(true)
+                .build(k);
+
+            let res = get_path_from_query(&tracked_trie.info.underlying_node, query);
+
+            SubsetTrieError(res)
+        })?;
     }
 
     Ok(create_partial_trie_subset_from_tracked_trie(tracked_trie))
@@ -270,10 +277,14 @@ where
 /// - For the key `0x1`, the marked nodes would be [B(0x), B(0x1)].
 /// - For the key `0x12`, the marked nodes still would be [B(0x), B(0x1)].
 /// - For the key `0x123`, the marked nodes would be [B(0x), B(0x1), L(0x123)].
+///
+/// Also note that we can't construct the error until we back out of this
+/// recursive function. We need to know the full key that hit the hash
+/// node, and that's only available at the initial call site.
 fn mark_nodes_that_are_needed<N: PartialTrie>(
     trie: &mut TrackedNode<N>,
     curr_nibbles: &mut Nibbles,
-) -> SubsetTrieResult<()> {
+) -> Result<(), ()> {
     trace!(
         "Sub-trie marking at {:x}, (type: {})",
         curr_nibbles,
@@ -286,10 +297,7 @@ fn mark_nodes_that_are_needed<N: PartialTrie>(
         }
         TrackedNodeIntern::Hash => match curr_nibbles.is_empty() {
             false => {
-                return Err(SubsetTrieError::UnexpectedKey(
-                    *curr_nibbles,
-                    format!("{:?}", trie),
-                ));
+                return Err(());
             }
             true => {
                 trie.info.touched = true;