From 1724620c517da818f87c12315029e48c140519fd Mon Sep 17 00:00:00 2001 From: Robin Salen <30937548+Nashtare@users.noreply.github.com> Date: Sat, 24 Aug 2024 13:57:13 -0400 Subject: [PATCH] Move segments related stuff into their own module (#529) --- .../src/fixed_recursive_verifier.rs | 6 +- evm_arithmetization/src/generation/mod.rs | 3 +- .../src/generation/segments.rs | 192 +++++++++++++++++ evm_arithmetization/src/generation/state.rs | 2 +- evm_arithmetization/src/lib.rs | 3 +- evm_arithmetization/src/prover.rs | 196 +----------------- proof_gen/src/proof_gen.rs | 4 +- zero_bin/common/src/prover_state/mod.rs | 7 +- zero_bin/prover/src/lib.rs | 2 +- 9 files changed, 215 insertions(+), 200 deletions(-) create mode 100644 evm_arithmetization/src/generation/segments.rs diff --git a/evm_arithmetization/src/fixed_recursive_verifier.rs b/evm_arithmetization/src/fixed_recursive_verifier.rs index 5a318b0fb..a892e924a 100644 --- a/evm_arithmetization/src/fixed_recursive_verifier.rs +++ b/evm_arithmetization/src/fixed_recursive_verifier.rs @@ -37,6 +37,7 @@ use starky::stark::Stark; use crate::all_stark::{all_cross_table_lookups, AllStark, Table, NUM_TABLES}; use crate::cpu::kernel::aggregator::KERNEL; +use crate::generation::segments::{GenerationSegmentData, SegmentDataIterator, SegmentError}; use crate::generation::{GenerationInputs, TrimmedGenerationInputs}; use crate::get_challenges::observe_public_values_target; use crate::proof::{ @@ -44,7 +45,7 @@ use crate::proof::{ FinalPublicValues, MemCapTarget, PublicValues, PublicValuesTarget, RegistersDataTarget, TrieRoots, TrieRootsTarget, DEFAULT_CAP_LEN, TARGET_HASH_SIZE, }; -use crate::prover::{check_abort_signal, prove, GenerationSegmentData, SegmentDataIterator}; +use crate::prover::{check_abort_signal, prove}; use crate::recursive_verifier::{ add_common_recursion_gates, add_virtual_public_values, get_memory_extra_looking_sum_circuit, recursive_stark_circuit, set_public_value_targets, PlonkWrapperCircuit, PublicInputs, @@ -1731,7 +1732,8 @@ where let mut proofs = vec![]; for segment_run in segment_iterator { - let (_, mut next_data) = segment_run.map_err(|e| anyhow::format_err!(e))?; + let (_, mut next_data) = + segment_run.map_err(|e: SegmentError| anyhow::format_err!(e))?; let proof = self.prove_segment( all_stark, config, diff --git a/evm_arithmetization/src/generation/mod.rs b/evm_arithmetization/src/generation/mod.rs index 161ceda4c..9f1bd05c8 100644 --- a/evm_arithmetization/src/generation/mod.rs +++ b/evm_arithmetization/src/generation/mod.rs @@ -11,6 +11,7 @@ use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::timed; use plonky2::util::timing::TimingTree; +use segments::GenerationSegmentData; use serde::{Deserialize, Serialize}; use starky::config::StarkConfig; use GlobalMetadata::{ @@ -28,7 +29,6 @@ use crate::memory::segments::{Segment, PREINITIALIZED_SEGMENTS_INDICES}; use crate::proof::{ BlockHashes, BlockMetadata, ExtraBlockData, MemCap, PublicValues, RegistersData, TrieRoots, }; -use crate::prover::GenerationSegmentData; use crate::util::{h2u, u256_to_usize}; use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryState}; use crate::witness::state::RegistersState; @@ -37,6 +37,7 @@ pub(crate) mod linked_list; pub mod mpt; pub(crate) mod prover_input; pub(crate) mod rlp; +pub(crate) mod segments; pub(crate) mod state; pub(crate) mod trie_extractor; diff --git a/evm_arithmetization/src/generation/segments.rs b/evm_arithmetization/src/generation/segments.rs new file mode 100644 index 000000000..424ace9ad --- /dev/null +++ b/evm_arithmetization/src/generation/segments.rs @@ -0,0 +1,192 @@ +//! Module defining the logic around proof segmentation into chunks, +//! which allows what is commonly known as zk-continuations. + +use anyhow::Result; +use plonky2::hash::hash_types::RichField; +use serde::{Deserialize, Serialize}; + +use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::interpreter::{set_registers_and_run, ExtraSegmentData, Interpreter}; +use crate::generation::state::State; +use crate::generation::{debug_inputs, GenerationInputs}; +use crate::witness::memory::MemoryState; +use crate::witness::state::RegistersState; +use crate::AllData; + +/// Structure holding the data needed to initialize a segment. +#[derive(Clone, Default, Debug, Serialize, Deserialize)] +pub struct GenerationSegmentData { + /// Indicates the position of this segment in a sequence of + /// executions for a larger payload. + pub(crate) segment_index: usize, + /// Registers at the start of the segment execution. + pub(crate) registers_before: RegistersState, + /// Registers at the end of the segment execution. + pub(crate) registers_after: RegistersState, + /// Memory at the start of the segment execution. + pub(crate) memory: MemoryState, + /// Extra data required to initialize a segment. + pub(crate) extra_data: ExtraSegmentData, + /// Log of the maximal cpu length. + pub(crate) max_cpu_len_log: Option, +} + +impl GenerationSegmentData { + /// Retrieves the index of this segment. + pub fn segment_index(&self) -> usize { + self.segment_index + } +} + +/// Builds a new `GenerationSegmentData`. +#[allow(clippy::unwrap_or_default)] +fn build_segment_data( + segment_index: usize, + registers_before: Option, + registers_after: Option, + memory: Option, + interpreter: &Interpreter, +) -> GenerationSegmentData { + GenerationSegmentData { + segment_index, + registers_before: registers_before.unwrap_or(RegistersState::new()), + registers_after: registers_after.unwrap_or(RegistersState::new()), + memory: memory.unwrap_or(MemoryState { + preinitialized_segments: interpreter + .generation_state + .memory + .preinitialized_segments + .clone(), + ..Default::default() + }), + max_cpu_len_log: interpreter.get_max_cpu_len_log(), + extra_data: ExtraSegmentData { + bignum_modmul_result_limbs: interpreter + .generation_state + .bignum_modmul_result_limbs + .clone(), + rlp_prover_inputs: interpreter.generation_state.rlp_prover_inputs.clone(), + withdrawal_prover_inputs: interpreter + .generation_state + .withdrawal_prover_inputs + .clone(), + ger_prover_inputs: interpreter.generation_state.ger_prover_inputs.clone(), + trie_root_ptrs: interpreter.generation_state.trie_root_ptrs.clone(), + jumpdest_table: interpreter.generation_state.jumpdest_table.clone(), + next_txn_index: interpreter.generation_state.next_txn_index, + }, + } +} + +pub struct SegmentDataIterator { + interpreter: Interpreter, + partial_next_data: Option, +} + +pub type SegmentRunResult = Option)>>; + +#[derive(thiserror::Error, Debug, Serialize, Deserialize)] +#[error("{}", .0)] +pub struct SegmentError(pub String); + +impl SegmentDataIterator { + pub fn new(inputs: &GenerationInputs, max_cpu_len_log: Option) -> Self { + debug_inputs(inputs); + + let interpreter = Interpreter::::new_with_generation_inputs( + KERNEL.global_labels["init"], + vec![], + inputs, + max_cpu_len_log, + ); + + Self { + interpreter, + partial_next_data: None, + } + } + + /// Returns the data for the current segment, as well as the data -- except + /// registers_after -- for the next segment. + fn generate_next_segment( + &mut self, + partial_segment_data: Option, + ) -> Result { + // Get the (partial) current segment data, if it is provided. Otherwise, + // initialize it. + let mut segment_data = if let Some(partial) = partial_segment_data { + if partial.registers_after.program_counter == KERNEL.global_labels["halt"] { + return Ok(None); + } + self.interpreter + .get_mut_generation_state() + .set_segment_data(&partial); + self.interpreter.generation_state.memory = partial.memory.clone(); + partial + } else { + build_segment_data(0, None, None, None, &self.interpreter) + }; + + let segment_index = segment_data.segment_index; + + // Run the interpreter to get `registers_after` and the partial data for the + // next segment. + let run = set_registers_and_run(segment_data.registers_after, &mut self.interpreter); + if let Ok((updated_registers, mem_after)) = run { + let partial_segment_data = Some(build_segment_data( + segment_index + 1, + Some(updated_registers), + Some(updated_registers), + mem_after, + &self.interpreter, + )); + + segment_data.registers_after = updated_registers; + Ok(Some(Box::new((segment_data, partial_segment_data)))) + } else { + let inputs = &self.interpreter.get_generation_state().inputs; + let block = inputs.block_metadata.block_number; + let txn_range = match inputs.txn_hashes.len() { + 0 => "Dummy".to_string(), + 1 => format!("{:?}", inputs.txn_number_before), + _ => format!( + "{:?}_{:?}", + inputs.txn_number_before, + inputs.txn_number_before + inputs.txn_hashes.len() + ), + }; + let s = format!( + "Segment generation {:?} for block {:?} ({}) failed with error {:?}", + segment_index, + block, + txn_range, + run.unwrap_err() + ); + Err(SegmentError(s)) + } + } +} + +impl Iterator for SegmentDataIterator { + type Item = AllData; + + fn next(&mut self) -> Option { + let run = self.generate_next_segment(self.partial_next_data.clone()); + + if let Ok(segment_run) = run { + match segment_run { + // The run was valid, but didn't not consume the payload fully. + Some(boxed) => { + let (data, next_data) = *boxed; + self.partial_next_data = next_data; + Some(Ok((self.interpreter.generation_state.inputs.clone(), data))) + } + // The payload was fully consumed. + None => None, + } + } else { + // The run encountered some error. + Some(Err(run.unwrap_err())) + } + } +} diff --git a/evm_arithmetization/src/generation/state.rs b/evm_arithmetization/src/generation/state.rs index b5defe364..96865806a 100644 --- a/evm_arithmetization/src/generation/state.rs +++ b/evm_arithmetization/src/generation/state.rs @@ -9,6 +9,7 @@ use log::Level; use plonky2::field::types::Field; use super::mpt::TrieRootPtrs; +use super::segments::GenerationSegmentData; use super::{TrieInputs, TrimmedGenerationInputs, NUM_EXTRA_CYCLES_AFTER}; use crate::byte_packing::byte_packing_stark::BytePackingOp; use crate::cpu::kernel::aggregator::KERNEL; @@ -21,7 +22,6 @@ use crate::generation::GenerationInputs; use crate::keccak_sponge::columns::KECCAK_WIDTH_BYTES; use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeOp; use crate::memory::segments::Segment; -use crate::prover::GenerationSegmentData; use crate::util::u256_to_usize; use crate::witness::errors::ProgramError; use crate::witness::memory::MemoryChannel::GeneralPurpose; diff --git a/evm_arithmetization/src/lib.rs b/evm_arithmetization/src/lib.rs index 5af08fbfe..b76953311 100644 --- a/evm_arithmetization/src/lib.rs +++ b/evm_arithmetization/src/lib.rs @@ -212,6 +212,7 @@ pub mod extension_tower; pub mod testing_utils; pub mod util; +use generation::segments::SegmentError; use generation::TrimmedGenerationInputs; use mpt_trie::partial_trie::HashedPartialTrie; @@ -223,8 +224,8 @@ pub type BlockHeight = u64; pub use all_stark::AllStark; pub use fixed_recursive_verifier::AllRecursiveCircuits; +pub use generation::segments::{GenerationSegmentData, SegmentDataIterator}; pub use generation::GenerationInputs; -use prover::{GenerationSegmentData, SegmentError}; pub use starky::config::StarkConfig; /// Returned type from a `SegmentDataIterator`, needed to prove all segments in diff --git a/evm_arithmetization/src/prover.rs b/evm_arithmetization/src/prover.rs index 746e1926e..8cef8d647 100644 --- a/evm_arithmetization/src/prover.rs +++ b/evm_arithmetization/src/prover.rs @@ -13,7 +13,6 @@ use plonky2::iop::challenger::Challenger; use plonky2::plonk::config::{GenericConfig, GenericHashOut}; use plonky2::timed; use plonky2::util::timing::TimingTree; -use serde::{Deserialize, Serialize}; use starky::config::StarkConfig; use starky::cross_table_lookup::{get_ctl_data, CtlData}; use starky::lookup::GrandProductChallengeSet; @@ -23,39 +22,10 @@ use starky::stark::Stark; use crate::all_stark::{AllStark, Table, NUM_TABLES}; use crate::cpu::kernel::aggregator::KERNEL; -use crate::cpu::kernel::interpreter::{set_registers_and_run, ExtraSegmentData, Interpreter}; -use crate::generation::state::State; -use crate::generation::{debug_inputs, generate_traces, GenerationInputs, TrimmedGenerationInputs}; +use crate::generation::segments::GenerationSegmentData; +use crate::generation::{generate_traces, GenerationInputs, TrimmedGenerationInputs}; use crate::get_challenges::observe_public_values; use crate::proof::{AllProof, MemCap, PublicValues, DEFAULT_CAP_LEN}; -use crate::witness::memory::MemoryState; -use crate::witness::state::RegistersState; -use crate::AllData; - -/// Structure holding the data needed to initialize a segment. -#[derive(Clone, Default, Debug, Serialize, Deserialize)] -pub struct GenerationSegmentData { - /// Indicates the position of this segment in a sequence of - /// executions for a larger payload. - pub(crate) segment_index: usize, - /// Registers at the start of the segment execution. - pub(crate) registers_before: RegistersState, - /// Registers at the end of the segment execution. - pub(crate) registers_after: RegistersState, - /// Memory at the start of the segment execution. - pub(crate) memory: MemoryState, - /// Extra data required to initialize a segment. - pub(crate) extra_data: ExtraSegmentData, - /// Log of the maximal cpu length. - pub(crate) max_cpu_len_log: Option, -} - -impl GenerationSegmentData { - /// Retrieves the index of this segment. - pub fn segment_index(&self) -> usize { - self.segment_index - } -} /// Generate traces, then create all STARK proofs. pub fn prove( @@ -482,165 +452,16 @@ pub fn check_abort_signal(abort_signal: Option>) -> Result<()> { Ok(()) } -/// Builds a new `GenerationSegmentData`. -#[allow(clippy::unwrap_or_default)] -fn build_segment_data( - segment_index: usize, - registers_before: Option, - registers_after: Option, - memory: Option, - interpreter: &Interpreter, -) -> GenerationSegmentData { - GenerationSegmentData { - segment_index, - registers_before: registers_before.unwrap_or(RegistersState::new()), - registers_after: registers_after.unwrap_or(RegistersState::new()), - memory: memory.unwrap_or(MemoryState { - preinitialized_segments: interpreter - .generation_state - .memory - .preinitialized_segments - .clone(), - ..Default::default() - }), - max_cpu_len_log: interpreter.get_max_cpu_len_log(), - extra_data: ExtraSegmentData { - bignum_modmul_result_limbs: interpreter - .generation_state - .bignum_modmul_result_limbs - .clone(), - rlp_prover_inputs: interpreter.generation_state.rlp_prover_inputs.clone(), - withdrawal_prover_inputs: interpreter - .generation_state - .withdrawal_prover_inputs - .clone(), - ger_prover_inputs: interpreter.generation_state.ger_prover_inputs.clone(), - trie_root_ptrs: interpreter.generation_state.trie_root_ptrs.clone(), - jumpdest_table: interpreter.generation_state.jumpdest_table.clone(), - next_txn_index: interpreter.generation_state.next_txn_index, - }, - } -} - -pub struct SegmentDataIterator { - interpreter: Interpreter, - partial_next_data: Option, -} - -pub type SegmentRunResult = Option)>>; - -#[derive(thiserror::Error, Debug, Serialize, Deserialize)] -#[error("{}", .0)] -pub struct SegmentError(pub String); - -impl SegmentDataIterator { - pub fn new(inputs: &GenerationInputs, max_cpu_len_log: Option) -> Self { - debug_inputs(inputs); - - let interpreter = Interpreter::::new_with_generation_inputs( - KERNEL.global_labels["init"], - vec![], - inputs, - max_cpu_len_log, - ); - - Self { - interpreter, - partial_next_data: None, - } - } - - /// Returns the data for the current segment, as well as the data -- except - /// registers_after -- for the next segment. - fn generate_next_segment( - &mut self, - partial_segment_data: Option, - ) -> Result { - // Get the (partial) current segment data, if it is provided. Otherwise, - // initialize it. - let mut segment_data = if let Some(partial) = partial_segment_data { - if partial.registers_after.program_counter == KERNEL.global_labels["halt"] { - return Ok(None); - } - self.interpreter - .get_mut_generation_state() - .set_segment_data(&partial); - self.interpreter.generation_state.memory = partial.memory.clone(); - partial - } else { - build_segment_data(0, None, None, None, &self.interpreter) - }; - - let segment_index = segment_data.segment_index; - - // Run the interpreter to get `registers_after` and the partial data for the - // next segment. - let run = set_registers_and_run(segment_data.registers_after, &mut self.interpreter); - if let Ok((updated_registers, mem_after)) = run { - let partial_segment_data = Some(build_segment_data( - segment_index + 1, - Some(updated_registers), - Some(updated_registers), - mem_after, - &self.interpreter, - )); - - segment_data.registers_after = updated_registers; - Ok(Some(Box::new((segment_data, partial_segment_data)))) - } else { - let inputs = &self.interpreter.get_generation_state().inputs; - let block = inputs.block_metadata.block_number; - let txn_range = match inputs.txn_hashes.len() { - 0 => "Dummy".to_string(), - 1 => format!("{:?}", inputs.txn_number_before), - _ => format!( - "{:?}_{:?}", - inputs.txn_number_before, - inputs.txn_number_before + inputs.txn_hashes.len() - ), - }; - let s = format!( - "Segment generation {:?} for block {:?} ({}) failed with error {:?}", - segment_index, - block, - txn_range, - run.unwrap_err() - ); - Err(SegmentError(s)) - } - } -} - -impl Iterator for SegmentDataIterator { - type Item = AllData; - - fn next(&mut self) -> Option { - let run = self.generate_next_segment(self.partial_next_data.clone()); - - if let Ok(segment_run) = run { - match segment_run { - // The run was valid, but didn't not consume the payload fully. - Some(boxed) => { - let (data, next_data) = *boxed; - self.partial_next_data = next_data; - Some(Ok((self.interpreter.generation_state.inputs.clone(), data))) - } - // The payload was fully consumed. - None => None, - } - } else { - // The run encountered some error. - Some(Err(run.unwrap_err())) - } - } -} - /// A utility module designed to test witness generation externally. pub mod testing { use super::*; use crate::{ cpu::kernel::interpreter::Interpreter, - generation::{output_debug_tries, state::State}, + generation::{ + output_debug_tries, + segments::{SegmentDataIterator, SegmentError}, + state::State, + }, }; /// Simulates the zkEVM CPU execution. @@ -677,7 +498,8 @@ pub mod testing { let mut proofs = vec![]; for segment_run in segment_data_iterator { - let (_, mut next_data) = segment_run.map_err(|e| anyhow::format_err!(e))?; + let (_, mut next_data) = + segment_run.map_err(|e: SegmentError| anyhow::format_err!(e))?; let proof = prove( all_stark, config, diff --git a/proof_gen/src/proof_gen.rs b/proof_gen/src/proof_gen.rs index 916bc4e52..2e2e39ae2 100644 --- a/proof_gen/src/proof_gen.rs +++ b/proof_gen/src/proof_gen.rs @@ -4,8 +4,8 @@ use std::sync::{atomic::AtomicBool, Arc}; use evm_arithmetization::{ - fixed_recursive_verifier::ProverOutputData, generation::TrimmedGenerationInputs, - prover::GenerationSegmentData, AllStark, StarkConfig, + fixed_recursive_verifier::ProverOutputData, generation::TrimmedGenerationInputs, AllStark, + GenerationSegmentData, StarkConfig, }; use hashbrown::HashMap; use plonky2::{ diff --git a/zero_bin/common/src/prover_state/mod.rs b/zero_bin/common/src/prover_state/mod.rs index 638ca20bb..b1673e5f6 100644 --- a/zero_bin/common/src/prover_state/mod.rs +++ b/zero_bin/common/src/prover_state/mod.rs @@ -15,11 +15,8 @@ use std::{fmt::Display, sync::OnceLock}; use clap::ValueEnum; use evm_arithmetization::{ - fixed_recursive_verifier::ProverOutputData, - generation::TrimmedGenerationInputs, - proof::AllProof, - prover::{prove, GenerationSegmentData}, - AllStark, StarkConfig, + fixed_recursive_verifier::ProverOutputData, generation::TrimmedGenerationInputs, + proof::AllProof, prover::prove, AllStark, GenerationSegmentData, StarkConfig, }; use plonky2::{ field::goldilocks_field::GoldilocksField, plonk::config::PoseidonGoldilocksConfig, diff --git a/zero_bin/prover/src/lib.rs b/zero_bin/prover/src/lib.rs index 7207cbcaf..46472bc7f 100644 --- a/zero_bin/prover/src/lib.rs +++ b/zero_bin/prover/src/lib.rs @@ -55,7 +55,7 @@ impl BlockProverInput { prover_config: ProverConfig, ) -> Result { use anyhow::Context as _; - use evm_arithmetization::prover::SegmentDataIterator; + use evm_arithmetization::SegmentDataIterator; use futures::{stream::FuturesUnordered, FutureExt}; use paladin::directive::{Directive, IndexedStream};