Skip to content

Commit

Permalink
continuations: have segment iterator return a Result (#509)
Browse files Browse the repository at this point in the history
* Do not panic during segment generation

* Tweak

* Apply comments
  • Loading branch information
Nashtare authored Aug 21, 2024
1 parent 5cb6b5a commit dcdaceb
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 50 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions evm_arithmetization/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ serde = { workspace = true, features = ["derive"] }
sha2 = { workspace = true }
static_assertions = { workspace = true }
hashbrown = { workspace = true }
thiserror = { workspace = true }
tiny-keccak = { workspace = true }
serde_json = { workspace = true }
serde-big-array = { workspace = true }
Expand Down
5 changes: 3 additions & 2 deletions evm_arithmetization/src/fixed_recursive_verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1730,12 +1730,13 @@ where

let mut proofs = vec![];

for mut next_data in segment_iterator {
for segment_run in segment_iterator {
let (_, mut next_data) = segment_run.map_err(|e| anyhow::format_err!(e))?;
let proof = self.prove_segment(
all_stark,
config,
generation_inputs.trim(),
&mut next_data.1,
&mut next_data,
timing,
abort_signal.clone(),
)?;
Expand Down
7 changes: 4 additions & 3 deletions evm_arithmetization/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,9 @@ pub type BlockHeight = u64;
pub use all_stark::AllStark;
pub use fixed_recursive_verifier::AllRecursiveCircuits;
pub use generation::GenerationInputs;
use prover::GenerationSegmentData;
use prover::{GenerationSegmentData, SegmentError};
pub use starky::config::StarkConfig;

/// All data needed to prove all transaction segments.
pub type AllData = (TrimmedGenerationInputs, GenerationSegmentData);
/// Returned type from a `SegmentDataIterator`, needed to prove all segments in
/// a transaction batch.
pub type AllData = Result<(TrimmedGenerationInputs, GenerationSegmentData), SegmentError>;
51 changes: 37 additions & 14 deletions evm_arithmetization/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use crate::get_challenges::observe_public_values;
use crate::proof::{AllProof, MemCap, PublicValues, DEFAULT_CAP_LEN};
use crate::witness::memory::MemoryState;
use crate::witness::state::RegistersState;
use crate::AllData;

/// Structure holding the data needed to initialize a segment.
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -526,6 +527,12 @@ pub struct SegmentDataIterator<F: RichField> {
partial_next_data: Option<GenerationSegmentData>,
}

pub type SegmentRunResult = Option<Box<(GenerationSegmentData, Option<GenerationSegmentData>)>>;

#[derive(thiserror::Error, Debug, Serialize, Deserialize)]
#[error("{}", .0)]
pub struct SegmentError(pub String);

impl<F: RichField> SegmentDataIterator<F> {
pub fn new(inputs: &GenerationInputs, max_cpu_len_log: Option<usize>) -> Self {
debug_inputs(inputs);
Expand All @@ -548,12 +555,12 @@ impl<F: RichField> SegmentDataIterator<F> {
fn generate_next_segment(
&mut self,
partial_segment_data: Option<GenerationSegmentData>,
) -> Option<(GenerationSegmentData, Option<GenerationSegmentData>)> {
) -> Result<SegmentRunResult, SegmentError> {
// Get the (partial) current segment data, if it is provided. Otherwise,
// initialize it.
let mut segment_data = if let Some(partial) = partial_segment_data {
if partial.registers_after.program_counter == KERNEL.global_labels["halt"] {
return None;
return Ok(None);
}
self.interpreter
.get_mut_generation_state()
Expand All @@ -579,7 +586,7 @@ impl<F: RichField> SegmentDataIterator<F> {
));

segment_data.registers_after = updated_registers;
Some((segment_data, partial_segment_data))
Ok(Some(Box::new((segment_data, partial_segment_data))))
} else {
let inputs = &self.interpreter.get_generation_state().inputs;
let block = inputs.block_metadata.block_number;
Expand All @@ -592,27 +599,38 @@ impl<F: RichField> SegmentDataIterator<F> {
inputs.txn_number_before + inputs.txn_hashes.len()
),
};
panic!(
let s = format!(
"Segment generation {:?} for block {:?} ({}) failed with error {:?}",
segment_index,
block,
txn_range,
run.unwrap_err()
);
Err(SegmentError(s))
}
}
}

impl<F: RichField> Iterator for SegmentDataIterator<F> {
type Item = (TrimmedGenerationInputs, GenerationSegmentData);
type Item = AllData;

fn next(&mut self) -> Option<Self::Item> {
if let Some((data, next_data)) = self.generate_next_segment(self.partial_next_data.clone())
{
self.partial_next_data = next_data;
Some((self.interpreter.generation_state.inputs.clone(), data))
let run = self.generate_next_segment(self.partial_next_data.clone());

if let Ok(segment_run) = run {
match segment_run {
// The run was valid, but didn't not consume the payload fully.
Some(boxed) => {
let (data, next_data) = *boxed;
self.partial_next_data = next_data;
Some(Ok((self.interpreter.generation_state.inputs.clone(), data)))
}
// The payload was fully consumed.
None => None,
}
} else {
None
// The run encountered some error.
Some(Err(run.unwrap_err()))
}
}
}
Expand Down Expand Up @@ -654,11 +672,12 @@ pub mod testing {
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
{
let data_iterator = SegmentDataIterator::<F>::new(&inputs, Some(max_cpu_len_log));
let segment_data_iterator = SegmentDataIterator::<F>::new(&inputs, Some(max_cpu_len_log));
let inputs = inputs.trim();
let mut proofs = vec![];

for (_, mut next_data) in data_iterator {
for segment_run in segment_data_iterator {
let (_, mut next_data) = segment_run.map_err(|e| anyhow::format_err!(e))?;
let proof = prove(
all_stark,
config,
Expand All @@ -676,11 +695,15 @@ pub mod testing {
pub fn simulate_execution_all_segments<F>(
inputs: GenerationInputs,
max_cpu_len_log: usize,
) -> anyhow::Result<()>
) -> Result<()>
where
F: RichField,
{
let _ = SegmentDataIterator::<F>::new(&inputs, Some(max_cpu_len_log)).collect::<Vec<_>>();
for segment in SegmentDataIterator::<F>::new(&inputs, Some(max_cpu_len_log)) {
if let Err(e) = segment {
return Err(anyhow::format_err!(e));
}
}

Ok(())
}
Expand Down
7 changes: 5 additions & 2 deletions zero_bin/common/src/prover_state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use evm_arithmetization::{
generation::TrimmedGenerationInputs,
proof::AllProof,
prover::{prove, GenerationSegmentData},
AllData, AllStark, StarkConfig,
AllStark, StarkConfig,
};
use plonky2::{
field::goldilocks_field::GoldilocksField, plonk::config::PoseidonGoldilocksConfig,
Expand Down Expand Up @@ -255,7 +255,10 @@ impl ProverStateManager {
/// - If the persistence strategy is [`CircuitPersistence::Disk`] with
/// [`TableLoadStrategy::OnDemand`], the table circuits are loaded as
/// needed.
pub fn generate_segment_proof(&self, input: AllData) -> anyhow::Result<GeneratedSegmentProof> {
pub fn generate_segment_proof(
&self,
input: (TrimmedGenerationInputs, GenerationSegmentData),
) -> anyhow::Result<GeneratedSegmentProof> {
let (generation_inputs, mut segment_data) = input;

match self.persistence {
Expand Down
53 changes: 29 additions & 24 deletions zero_bin/ops/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::time::Instant;

#[cfg(not(feature = "test_only"))]
use evm_arithmetization::generation::TrimmedGenerationInputs;
use evm_arithmetization::{proof::PublicValues, AllData};
use evm_arithmetization::proof::PublicValues;
#[cfg(feature = "test_only")]
use evm_arithmetization::{prover::testing::simulate_execution_all_segments, GenerationInputs};
use paladin::{
Expand All @@ -26,36 +26,20 @@ use zero_bin_common::{debug_utils::save_inputs_to_disk, prover_state::p_state};

registry!();

#[cfg(feature = "test_only")]
#[derive(Deserialize, Serialize, RemoteExecute)]
pub struct BatchTestOnly {
pub save_inputs_on_error: bool,
}

#[cfg(feature = "test_only")]
impl Operation for BatchTestOnly {
type Input = (GenerationInputs, usize);
type Output = ();

fn execute(&self, inputs: Self::Input) -> Result<Self::Output> {
simulate_execution_all_segments::<Field>(inputs.0, inputs.1)
.map_err(|err| FatalError::from_anyhow(err, FatalStrategy::Terminate))?;

Ok(())
}
}

#[derive(Deserialize, Serialize, RemoteExecute)]
pub struct SegmentProof {
pub save_inputs_on_error: bool,
}

#[cfg(not(feature = "test_only"))]
impl Operation for SegmentProof {
type Input = AllData;
type Input = evm_arithmetization::AllData;
type Output = proof_gen::proof_types::SegmentAggregatableProof;

fn execute(&self, all_data: Self::Input) -> Result<Self::Output> {
let all_data =
all_data.map_err(|err| FatalError::from_str(&err.0, FatalStrategy::Terminate))?;

let input = all_data.0.clone();
let segment_index = all_data.1.segment_index();
let _span = SegmentProofSpan::new(&input, all_data.1.segment_index());
Expand All @@ -65,7 +49,7 @@ impl Operation for SegmentProof {
.map_err(|err| {
if let Err(write_err) = save_inputs_to_disk(
format!(
"b{}_txns_{}-{}-({})_input.json",
"b{}_txns_{}..{}-({})_input.json",
input.block_metadata.block_number,
input.txn_number_before,
input.txn_number_before + input.txn_hashes.len(),
Expand All @@ -90,10 +74,31 @@ impl Operation for SegmentProof {

#[cfg(feature = "test_only")]
impl Operation for SegmentProof {
type Input = AllData;
type Input = (GenerationInputs, usize);
type Output = ();

fn execute(&self, _all_data: Self::Input) -> Result<Self::Output> {
fn execute(&self, inputs: Self::Input) -> Result<Self::Output> {
if self.save_inputs_on_error {
simulate_execution_all_segments::<Field>(inputs.0.clone(), inputs.1).map_err(|err| {
if let Err(write_err) = save_inputs_to_disk(
format!(
"b{}_txns_{}..{}_input.json",
inputs.0.block_metadata.block_number,
inputs.0.txn_number_before,
inputs.0.txn_number_before + inputs.0.signed_txns.len(),
),
inputs.0,
) {
error!("Failed to save txn proof input to disk: {:?}", write_err);
}

FatalError::from_anyhow(err, FatalStrategy::Terminate)
})?
} else {
simulate_execution_all_segments::<Field>(inputs.0, inputs.1)
.map_err(|err| FatalError::from_anyhow(err, FatalStrategy::Terminate))?;
}

Ok(())
}
}
Expand Down
12 changes: 7 additions & 5 deletions zero_bin/prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl BlockProverInput {
) -> Result<GeneratedBlockProof> {
use std::iter::repeat;

use futures::StreamExt;
use futures::future;
use paladin::directive::{Directive, IndexedStream};

let ProverConfig {
Expand All @@ -147,7 +147,7 @@ impl BlockProverInput {
let block_generation_inputs =
trace_decoder::entrypoint(self.block_trace, self.other_data, batch_size)?;

let batch_ops = ops::BatchTestOnly {
let batch_ops = ops::SegmentProof {
save_inputs_on_error,
};

Expand All @@ -160,9 +160,11 @@ impl BlockProverInput {
&batch_ops,
);

let result = simulation.run(runtime).await?;

result.collect::<Vec<_>>().await;
simulation
.run(runtime)
.await?
.try_for_each(|_| future::ok(()))
.await?;

info!("Successfully generated witness for block {block_number}.");

Expand Down

0 comments on commit dcdaceb

Please sign in to comment.