From dcdaceb627d004f9743ebe04f6ee8a627da7bb7c Mon Sep 17 00:00:00 2001 From: Robin Salen <30937548+Nashtare@users.noreply.github.com> Date: Wed, 21 Aug 2024 08:28:53 -0400 Subject: [PATCH] continuations: have segment iterator return a `Result` (#509) * Do not panic during segment generation * Tweak * Apply comments --- Cargo.lock | 1 + evm_arithmetization/Cargo.toml | 1 + .../src/fixed_recursive_verifier.rs | 5 +- evm_arithmetization/src/lib.rs | 7 +-- evm_arithmetization/src/prover.rs | 51 +++++++++++++----- zero_bin/common/src/prover_state/mod.rs | 7 ++- zero_bin/ops/src/lib.rs | 53 ++++++++++--------- zero_bin/prover/src/lib.rs | 12 +++-- 8 files changed, 87 insertions(+), 50 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9ea8c012e..706845c8e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2029,6 +2029,7 @@ dependencies = [ "sha2", "starky", "static_assertions", + "thiserror", "tiny-keccak", "zk_evm_proc_macro", ] diff --git a/evm_arithmetization/Cargo.toml b/evm_arithmetization/Cargo.toml index 8a912e8db..d69f2a140 100644 --- a/evm_arithmetization/Cargo.toml +++ b/evm_arithmetization/Cargo.toml @@ -41,6 +41,7 @@ serde = { workspace = true, features = ["derive"] } sha2 = { workspace = true } static_assertions = { workspace = true } hashbrown = { workspace = true } +thiserror = { workspace = true } tiny-keccak = { workspace = true } serde_json = { workspace = true } serde-big-array = { workspace = true } diff --git a/evm_arithmetization/src/fixed_recursive_verifier.rs b/evm_arithmetization/src/fixed_recursive_verifier.rs index aa7bf4d9d..5a318b0fb 100644 --- a/evm_arithmetization/src/fixed_recursive_verifier.rs +++ b/evm_arithmetization/src/fixed_recursive_verifier.rs @@ -1730,12 +1730,13 @@ where let mut proofs = vec![]; - for mut next_data in segment_iterator { + for segment_run in segment_iterator { + let (_, mut next_data) = segment_run.map_err(|e| anyhow::format_err!(e))?; let proof = self.prove_segment( all_stark, config, generation_inputs.trim(), - &mut next_data.1, + &mut next_data, timing, abort_signal.clone(), )?; diff --git a/evm_arithmetization/src/lib.rs b/evm_arithmetization/src/lib.rs index 82719648c..5af08fbfe 100644 --- a/evm_arithmetization/src/lib.rs +++ b/evm_arithmetization/src/lib.rs @@ -224,8 +224,9 @@ pub type BlockHeight = u64; pub use all_stark::AllStark; pub use fixed_recursive_verifier::AllRecursiveCircuits; pub use generation::GenerationInputs; -use prover::GenerationSegmentData; +use prover::{GenerationSegmentData, SegmentError}; pub use starky::config::StarkConfig; -/// All data needed to prove all transaction segments. -pub type AllData = (TrimmedGenerationInputs, GenerationSegmentData); +/// Returned type from a `SegmentDataIterator`, needed to prove all segments in +/// a transaction batch. +pub type AllData = Result<(TrimmedGenerationInputs, GenerationSegmentData), SegmentError>; diff --git a/evm_arithmetization/src/prover.rs b/evm_arithmetization/src/prover.rs index 97af417c9..746e1926e 100644 --- a/evm_arithmetization/src/prover.rs +++ b/evm_arithmetization/src/prover.rs @@ -30,6 +30,7 @@ use crate::get_challenges::observe_public_values; use crate::proof::{AllProof, MemCap, PublicValues, DEFAULT_CAP_LEN}; use crate::witness::memory::MemoryState; use crate::witness::state::RegistersState; +use crate::AllData; /// Structure holding the data needed to initialize a segment. #[derive(Clone, Default, Debug, Serialize, Deserialize)] @@ -526,6 +527,12 @@ pub struct SegmentDataIterator { partial_next_data: Option, } +pub type SegmentRunResult = Option)>>; + +#[derive(thiserror::Error, Debug, Serialize, Deserialize)] +#[error("{}", .0)] +pub struct SegmentError(pub String); + impl SegmentDataIterator { pub fn new(inputs: &GenerationInputs, max_cpu_len_log: Option) -> Self { debug_inputs(inputs); @@ -548,12 +555,12 @@ impl SegmentDataIterator { fn generate_next_segment( &mut self, partial_segment_data: Option, - ) -> Option<(GenerationSegmentData, Option)> { + ) -> Result { // Get the (partial) current segment data, if it is provided. Otherwise, // initialize it. let mut segment_data = if let Some(partial) = partial_segment_data { if partial.registers_after.program_counter == KERNEL.global_labels["halt"] { - return None; + return Ok(None); } self.interpreter .get_mut_generation_state() @@ -579,7 +586,7 @@ impl SegmentDataIterator { )); segment_data.registers_after = updated_registers; - Some((segment_data, partial_segment_data)) + Ok(Some(Box::new((segment_data, partial_segment_data)))) } else { let inputs = &self.interpreter.get_generation_state().inputs; let block = inputs.block_metadata.block_number; @@ -592,27 +599,38 @@ impl SegmentDataIterator { inputs.txn_number_before + inputs.txn_hashes.len() ), }; - panic!( + let s = format!( "Segment generation {:?} for block {:?} ({}) failed with error {:?}", segment_index, block, txn_range, run.unwrap_err() ); + Err(SegmentError(s)) } } } impl Iterator for SegmentDataIterator { - type Item = (TrimmedGenerationInputs, GenerationSegmentData); + type Item = AllData; fn next(&mut self) -> Option { - if let Some((data, next_data)) = self.generate_next_segment(self.partial_next_data.clone()) - { - self.partial_next_data = next_data; - Some((self.interpreter.generation_state.inputs.clone(), data)) + let run = self.generate_next_segment(self.partial_next_data.clone()); + + if let Ok(segment_run) = run { + match segment_run { + // The run was valid, but didn't not consume the payload fully. + Some(boxed) => { + let (data, next_data) = *boxed; + self.partial_next_data = next_data; + Some(Ok((self.interpreter.generation_state.inputs.clone(), data))) + } + // The payload was fully consumed. + None => None, + } } else { - None + // The run encountered some error. + Some(Err(run.unwrap_err())) } } } @@ -654,11 +672,12 @@ pub mod testing { F: RichField + Extendable, C: GenericConfig, { - let data_iterator = SegmentDataIterator::::new(&inputs, Some(max_cpu_len_log)); + let segment_data_iterator = SegmentDataIterator::::new(&inputs, Some(max_cpu_len_log)); let inputs = inputs.trim(); let mut proofs = vec![]; - for (_, mut next_data) in data_iterator { + for segment_run in segment_data_iterator { + let (_, mut next_data) = segment_run.map_err(|e| anyhow::format_err!(e))?; let proof = prove( all_stark, config, @@ -676,11 +695,15 @@ pub mod testing { pub fn simulate_execution_all_segments( inputs: GenerationInputs, max_cpu_len_log: usize, - ) -> anyhow::Result<()> + ) -> Result<()> where F: RichField, { - let _ = SegmentDataIterator::::new(&inputs, Some(max_cpu_len_log)).collect::>(); + for segment in SegmentDataIterator::::new(&inputs, Some(max_cpu_len_log)) { + if let Err(e) = segment { + return Err(anyhow::format_err!(e)); + } + } Ok(()) } diff --git a/zero_bin/common/src/prover_state/mod.rs b/zero_bin/common/src/prover_state/mod.rs index cdbe6a4b2..638ca20bb 100644 --- a/zero_bin/common/src/prover_state/mod.rs +++ b/zero_bin/common/src/prover_state/mod.rs @@ -19,7 +19,7 @@ use evm_arithmetization::{ generation::TrimmedGenerationInputs, proof::AllProof, prover::{prove, GenerationSegmentData}, - AllData, AllStark, StarkConfig, + AllStark, StarkConfig, }; use plonky2::{ field::goldilocks_field::GoldilocksField, plonk::config::PoseidonGoldilocksConfig, @@ -255,7 +255,10 @@ impl ProverStateManager { /// - If the persistence strategy is [`CircuitPersistence::Disk`] with /// [`TableLoadStrategy::OnDemand`], the table circuits are loaded as /// needed. - pub fn generate_segment_proof(&self, input: AllData) -> anyhow::Result { + pub fn generate_segment_proof( + &self, + input: (TrimmedGenerationInputs, GenerationSegmentData), + ) -> anyhow::Result { let (generation_inputs, mut segment_data) = input; match self.persistence { diff --git a/zero_bin/ops/src/lib.rs b/zero_bin/ops/src/lib.rs index 979508527..286ae4bca 100644 --- a/zero_bin/ops/src/lib.rs +++ b/zero_bin/ops/src/lib.rs @@ -3,7 +3,7 @@ use std::time::Instant; #[cfg(not(feature = "test_only"))] use evm_arithmetization::generation::TrimmedGenerationInputs; -use evm_arithmetization::{proof::PublicValues, AllData}; +use evm_arithmetization::proof::PublicValues; #[cfg(feature = "test_only")] use evm_arithmetization::{prover::testing::simulate_execution_all_segments, GenerationInputs}; use paladin::{ @@ -26,25 +26,6 @@ use zero_bin_common::{debug_utils::save_inputs_to_disk, prover_state::p_state}; registry!(); -#[cfg(feature = "test_only")] -#[derive(Deserialize, Serialize, RemoteExecute)] -pub struct BatchTestOnly { - pub save_inputs_on_error: bool, -} - -#[cfg(feature = "test_only")] -impl Operation for BatchTestOnly { - type Input = (GenerationInputs, usize); - type Output = (); - - fn execute(&self, inputs: Self::Input) -> Result { - simulate_execution_all_segments::(inputs.0, inputs.1) - .map_err(|err| FatalError::from_anyhow(err, FatalStrategy::Terminate))?; - - Ok(()) - } -} - #[derive(Deserialize, Serialize, RemoteExecute)] pub struct SegmentProof { pub save_inputs_on_error: bool, @@ -52,10 +33,13 @@ pub struct SegmentProof { #[cfg(not(feature = "test_only"))] impl Operation for SegmentProof { - type Input = AllData; + type Input = evm_arithmetization::AllData; type Output = proof_gen::proof_types::SegmentAggregatableProof; fn execute(&self, all_data: Self::Input) -> Result { + let all_data = + all_data.map_err(|err| FatalError::from_str(&err.0, FatalStrategy::Terminate))?; + let input = all_data.0.clone(); let segment_index = all_data.1.segment_index(); let _span = SegmentProofSpan::new(&input, all_data.1.segment_index()); @@ -65,7 +49,7 @@ impl Operation for SegmentProof { .map_err(|err| { if let Err(write_err) = save_inputs_to_disk( format!( - "b{}_txns_{}-{}-({})_input.json", + "b{}_txns_{}..{}-({})_input.json", input.block_metadata.block_number, input.txn_number_before, input.txn_number_before + input.txn_hashes.len(), @@ -90,10 +74,31 @@ impl Operation for SegmentProof { #[cfg(feature = "test_only")] impl Operation for SegmentProof { - type Input = AllData; + type Input = (GenerationInputs, usize); type Output = (); - fn execute(&self, _all_data: Self::Input) -> Result { + fn execute(&self, inputs: Self::Input) -> Result { + if self.save_inputs_on_error { + simulate_execution_all_segments::(inputs.0.clone(), inputs.1).map_err(|err| { + if let Err(write_err) = save_inputs_to_disk( + format!( + "b{}_txns_{}..{}_input.json", + inputs.0.block_metadata.block_number, + inputs.0.txn_number_before, + inputs.0.txn_number_before + inputs.0.signed_txns.len(), + ), + inputs.0, + ) { + error!("Failed to save txn proof input to disk: {:?}", write_err); + } + + FatalError::from_anyhow(err, FatalStrategy::Terminate) + })? + } else { + simulate_execution_all_segments::(inputs.0, inputs.1) + .map_err(|err| FatalError::from_anyhow(err, FatalStrategy::Terminate))?; + } + Ok(()) } } diff --git a/zero_bin/prover/src/lib.rs b/zero_bin/prover/src/lib.rs index e5f664afe..f96039edf 100644 --- a/zero_bin/prover/src/lib.rs +++ b/zero_bin/prover/src/lib.rs @@ -132,7 +132,7 @@ impl BlockProverInput { ) -> Result { use std::iter::repeat; - use futures::StreamExt; + use futures::future; use paladin::directive::{Directive, IndexedStream}; let ProverConfig { @@ -147,7 +147,7 @@ impl BlockProverInput { let block_generation_inputs = trace_decoder::entrypoint(self.block_trace, self.other_data, batch_size)?; - let batch_ops = ops::BatchTestOnly { + let batch_ops = ops::SegmentProof { save_inputs_on_error, }; @@ -160,9 +160,11 @@ impl BlockProverInput { &batch_ops, ); - let result = simulation.run(runtime).await?; - - result.collect::>().await; + simulation + .run(runtime) + .await? + .try_for_each(|_| future::ok(())) + .await?; info!("Successfully generated witness for block {block_number}.");