From 4d7d9ffa3cc5a6ba39389a28c18657f641636ce6 Mon Sep 17 00:00:00 2001 From: Linda Guiga Date: Mon, 11 Sep 2023 15:47:33 +0100 Subject: [PATCH 01/34] Constrain genesis block's state trie. --- evm/src/fixed_recursive_verifier.rs | 55 +++++++++++++++++++++++++---- evm/src/generation/mod.rs | 3 ++ evm/src/get_challenges.rs | 2 ++ evm/src/proof.rs | 35 +++++++++++++----- evm/src/recursive_verifier.rs | 6 ++++ evm/tests/add11_yml.rs | 1 + evm/tests/basic_smart_contract.rs | 1 + evm/tests/empty_txn_list.rs | 1 + evm/tests/log_opcode.rs | 5 +++ evm/tests/self_balance_gas_cost.rs | 1 + evm/tests/simple_transfer.rs | 1 + 11 files changed, 96 insertions(+), 15 deletions(-) diff --git a/evm/src/fixed_recursive_verifier.rs b/evm/src/fixed_recursive_verifier.rs index 72844db69c..4d3f4d3d60 100644 --- a/evm/src/fixed_recursive_verifier.rs +++ b/evm/src/fixed_recursive_verifier.rs @@ -640,6 +640,11 @@ where lhs: &ExtraBlockDataTarget, rhs: &ExtraBlockDataTarget, ) { + // Connect genesis state root values. + for (&limb0, &limb1) in pvs.genesis_state_root.iter().zip(&rhs.genesis_state_root) { + builder.connect(limb0, limb1); + } + // Connect the transaction number in public values to the lhs and rhs values correctly. builder.connect(pvs.txn_number_before, lhs.txn_number_before); builder.connect(pvs.txn_number_after, rhs.txn_number_after); @@ -776,6 +781,16 @@ where builder.connect(limb0, limb1); } + // Between blocks, the genesis state trie remains unchanged. + for (&limb0, limb1) in lhs + .extra_block_data + .genesis_state_root + .iter() + .zip(rhs.extra_block_data.genesis_state_root) + { + builder.connect(limb0, limb1); + } + // Connect block numbers. let one = builder.one(); let prev_block_nb = builder.sub(rhs.block_metadata.block_number, one); @@ -787,13 +802,23 @@ where // Connect intermediary values for gas_used and bloom filters to the block's final values. We only plug on the right, so there is no need to check the left-handside block. 
Self::connect_final_block_values_to_intermediary(builder, rhs); - // Chack that the genesis block number is 0. let zero = builder.zero(); let has_not_parent_block = builder.sub(one, has_parent_block.target); + // Check that the genesis block number is 0. let gen_block_constr = builder.mul(has_not_parent_block, rhs.block_metadata.block_number); builder.connect(gen_block_constr, zero); - // TODO: Check that the genesis block has a predetermined state trie root. + // Check that the genesis block has a predetermined state trie root. + for (&limb0, limb1) in rhs + .trie_roots_before + .state_root + .iter() + .zip(rhs.extra_block_data.genesis_state_root) + { + let mut constr = builder.sub(limb0, limb1); + constr = builder.mul(has_not_parent_block, constr); + builder.connect(constr, zero); + } } fn connect_final_block_values_to_intermediary( @@ -981,16 +1006,34 @@ where block_inputs .set_proof_with_pis_target(&self.block.parent_block_proof, parent_block_proof); } else { - // Initialize state_root_after and the block number for correct connection between blocks. - let state_trie_root_keys = 24..32; - let block_number_key = TrieRootsTarget::SIZE * 2 + 6; + // Initialize genesis_state_trie, state_root_after and the block number for correct connection between blocks. + // Initialize `state_root_after`. + let state_trie_root_after_keys = 24..32; let mut nonzero_pis = HashMap::new(); - for (key, &value) in state_trie_root_keys + for (key, &value) in state_trie_root_after_keys .zip_eq(&h256_limbs::(public_values.trie_roots_before.state_root)) { nonzero_pis.insert(key, value); } + + // Initialize the genesis state trie digest. 
+ let genesis_state_trie_keys = TrieRootsTarget::SIZE * 2 + + BlockMetadataTarget::SIZE + + BlockHashesTarget::BLOCK_HASHES_SIZE + ..TrieRootsTarget::SIZE * 2 + + BlockMetadataTarget::SIZE + + BlockHashesTarget::BLOCK_HASHES_SIZE + + 8; + for (key, &value) in genesis_state_trie_keys.zip_eq(&h256_limbs::( + public_values.extra_block_data.genesis_state_root, + )) { + nonzero_pis.insert(key, value); + } + + // Initialize the block number. + let block_number_key = TrieRootsTarget::SIZE * 2 + 6; nonzero_pis.insert(block_number_key, F::NEG_ONE); + block_inputs.set_proof_with_pis_target( &self.block.parent_block_proof, &cyclic_base_proof( diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 13c6670ba6..01e3209df9 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -48,6 +48,8 @@ pub struct GenerationInputs { pub tries: TrieInputs, /// Expected trie roots after the transactions are executed. pub trie_roots_after: TrieRoots, + /// State trie root of the genesis block. + pub genesis_state_trie_root: H256, /// Mapping between smart contract code hashes and the contract byte code. /// All account smart contracts that are invoked will have an entry present. 
@@ -251,6 +253,7 @@ pub fn generate_traces, const D: usize>( let txn_number_after = read_metadata(GlobalMetadata::TxnNumberAfter); let extra_block_data = ExtraBlockData { + genesis_state_root: inputs.genesis_state_trie_root, txn_number_before: inputs.txn_number_before, txn_number_after, gas_used_before: inputs.gas_used_before, diff --git a/evm/src/get_challenges.rs b/evm/src/get_challenges.rs index 59be8439d8..0afa1d8022 100644 --- a/evm/src/get_challenges.rs +++ b/evm/src/get_challenges.rs @@ -116,6 +116,7 @@ fn observe_extra_block_data< challenger: &mut Challenger, extra_data: &ExtraBlockData, ) { + challenger.observe_elements(&h256_limbs(extra_data.genesis_state_root)); challenger.observe_element(F::from_canonical_u32(extra_data.txn_number_before.as_u32())); challenger.observe_element(F::from_canonical_u32(extra_data.txn_number_after.as_u32())); challenger.observe_element(F::from_canonical_u32(extra_data.gas_used_before.as_u32())); @@ -138,6 +139,7 @@ fn observe_extra_block_data_target< ) where C::Hasher: AlgebraicHasher, { + challenger.observe_elements(&extra_data.genesis_state_root); challenger.observe_element(extra_data.txn_number_before); challenger.observe_element(extra_data.txn_number_after); challenger.observe_element(extra_data.gas_used_before); diff --git a/evm/src/proof.rs b/evm/src/proof.rs index 14f22b6791..a5bb2b3d74 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -95,6 +95,7 @@ pub struct BlockMetadata { #[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct ExtraBlockData { + pub genesis_state_root: H256, pub txn_number_before: U256, pub txn_number_after: U256, pub gas_used_before: U256, @@ -166,6 +167,7 @@ impl PublicValuesTarget { buffer.write_target_array(&cur_hash)?; let ExtraBlockDataTarget { + genesis_state_root, txn_number_before, txn_number_after, gas_used_before, @@ -173,6 +175,7 @@ impl PublicValuesTarget { block_bloom_before, block_bloom_after, } = self.extra_block_data; + 
buffer.write_target_array(&genesis_state_root)?; buffer.write_target(txn_number_before)?; buffer.write_target(txn_number_after)?; buffer.write_target(gas_used_before)?; @@ -214,6 +217,7 @@ impl PublicValuesTarget { }; let extra_block_data = ExtraBlockDataTarget { + genesis_state_root: buffer.read_target_array()?, txn_number_before: buffer.read_target()?, txn_number_after: buffer.read_target()?, gas_used_before: buffer.read_target()?, @@ -381,7 +385,7 @@ pub struct BlockMetadataTarget { } impl BlockMetadataTarget { - const SIZE: usize = 77; + pub const SIZE: usize = 77; pub fn from_public_inputs(pis: &[Target]) -> Self { let block_beneficiary = pis[0..5].try_into().unwrap(); @@ -465,7 +469,7 @@ pub struct BlockHashesTarget { } impl BlockHashesTarget { - const BLOCK_HASHES_SIZE: usize = 2056; + pub const BLOCK_HASHES_SIZE: usize = 2056; pub fn from_public_inputs(pis: &[Target]) -> Self { Self { prev_hashes: pis[0..2048].try_into().unwrap(), @@ -505,6 +509,7 @@ impl BlockHashesTarget { #[derive(Eq, PartialEq, Debug, Copy, Clone)] pub struct ExtraBlockDataTarget { + pub genesis_state_root: [Target; 8], pub txn_number_before: Target, pub txn_number_after: Target, pub gas_used_before: Target, @@ -514,17 +519,19 @@ pub struct ExtraBlockDataTarget { } impl ExtraBlockDataTarget { - const SIZE: usize = 132; + const SIZE: usize = 140; pub fn from_public_inputs(pis: &[Target]) -> Self { - let txn_number_before = pis[0]; - let txn_number_after = pis[1]; - let gas_used_before = pis[2]; - let gas_used_after = pis[3]; - let block_bloom_before = pis[4..68].try_into().unwrap(); - let block_bloom_after = pis[68..132].try_into().unwrap(); + let genesis_state_root = pis[0..8].try_into().unwrap(); + let txn_number_before = pis[8]; + let txn_number_after = pis[9]; + let gas_used_before = pis[10]; + let gas_used_after = pis[11]; + let block_bloom_before = pis[12..76].try_into().unwrap(); + let block_bloom_after = pis[76..140].try_into().unwrap(); Self { + genesis_state_root, 
txn_number_before, txn_number_after, gas_used_before, @@ -541,6 +548,13 @@ impl ExtraBlockDataTarget { ed1: Self, ) -> Self { Self { + genesis_state_root: core::array::from_fn(|i| { + builder.select( + condition, + ed0.genesis_state_root[i], + ed1.genesis_state_root[i], + ) + }), txn_number_before: builder.select( condition, ed0.txn_number_before, @@ -571,6 +585,9 @@ impl ExtraBlockDataTarget { ed0: Self, ed1: Self, ) { + for i in 0..8 { + builder.connect(ed0.genesis_state_root[i], ed1.genesis_state_root[i]); + } builder.connect(ed0.txn_number_before, ed1.txn_number_before); builder.connect(ed0.txn_number_after, ed1.txn_number_after); builder.connect(ed0.gas_used_before, ed1.gas_used_before); diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index e669f4ab35..3558dc9a78 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -808,6 +808,7 @@ pub(crate) fn add_virtual_block_hashes, const D: us pub(crate) fn add_virtual_extra_block_data, const D: usize>( builder: &mut CircuitBuilder, ) -> ExtraBlockDataTarget { + let genesis_state_root = builder.add_virtual_public_input_arr(); let txn_number_before = builder.add_virtual_public_input(); let txn_number_after = builder.add_virtual_public_input(); let gas_used_before = builder.add_virtual_public_input(); @@ -815,6 +816,7 @@ pub(crate) fn add_virtual_extra_block_data, const D let block_bloom_before: [Target; 64] = builder.add_virtual_public_input_arr(); let block_bloom_after: [Target; 64] = builder.add_virtual_public_input_arr(); ExtraBlockDataTarget { + genesis_state_root, txn_number_before, txn_number_after, gas_used_before, @@ -1070,6 +1072,10 @@ pub(crate) fn set_extra_public_values_target( F: RichField + Extendable, W: Witness, { + witness.set_target_arr( + &ed_target.genesis_state_root, + &h256_limbs::(ed.genesis_state_root), + ); witness.set_target( ed_target.txn_number_before, F::from_canonical_usize(ed.txn_number_before.as_usize()), diff --git 
a/evm/tests/add11_yml.rs b/evm/tests/add11_yml.rs index f628e94441..89ea13ee12 100644 --- a/evm/tests/add11_yml.rs +++ b/evm/tests/add11_yml.rs @@ -149,6 +149,7 @@ fn add11_yml() -> anyhow::Result<()> { trie_roots_after, contract_code, block_metadata, + genesis_state_trie_root: HashedPartialTrie::from(Node::Empty).hash(), txn_number_before: 0.into(), gas_used_before: 0.into(), gas_used_after: 0xa868u64.into(), diff --git a/evm/tests/basic_smart_contract.rs b/evm/tests/basic_smart_contract.rs index 4d0a2090b6..3118a34a49 100644 --- a/evm/tests/basic_smart_contract.rs +++ b/evm/tests/basic_smart_contract.rs @@ -165,6 +165,7 @@ fn test_basic_smart_contract() -> anyhow::Result<()> { tries: tries_before, trie_roots_after, contract_code, + genesis_state_trie_root: HashedPartialTrie::from(Node::Empty).hash(), block_metadata, txn_number_before: 0.into(), gas_used_before: 0.into(), diff --git a/evm/tests/empty_txn_list.rs b/evm/tests/empty_txn_list.rs index 806726fc9c..977f3efd3e 100644 --- a/evm/tests/empty_txn_list.rs +++ b/evm/tests/empty_txn_list.rs @@ -57,6 +57,7 @@ fn test_empty_txn_list() -> anyhow::Result<()> { }, trie_roots_after, contract_code, + genesis_state_trie_root: HashedPartialTrie::from(Node::Empty).hash(), block_metadata, txn_number_before: 0.into(), gas_used_before: 0.into(), diff --git a/evm/tests/log_opcode.rs b/evm/tests/log_opcode.rs index 271ab9456f..d86379ca50 100644 --- a/evm/tests/log_opcode.rs +++ b/evm/tests/log_opcode.rs @@ -225,6 +225,7 @@ fn test_log_opcodes() -> anyhow::Result<()> { tries: tries_before, trie_roots_after, contract_code, + genesis_state_trie_root: HashedPartialTrie::from(Node::Empty).hash(), block_metadata, txn_number_before: 0.into(), gas_used_before: 0.into(), @@ -423,6 +424,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { tries: tries_before, trie_roots_after: tries_after, contract_code, + genesis_state_trie_root: HashedPartialTrie::from(Node::Empty).hash(), block_metadata: block_metadata.clone(), txn_number_before: 
0.into(), gas_used_before: 0.into(), @@ -564,6 +566,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { tries: tries_before, trie_roots_after, contract_code, + genesis_state_trie_root: HashedPartialTrie::from(Node::Empty).hash(), block_metadata, txn_number_before: 1.into(), gas_used_before: gas_used_second, @@ -589,6 +592,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { trie_roots_before: first_public_values.trie_roots_before, trie_roots_after: public_values.trie_roots_after, extra_block_data: ExtraBlockData { + genesis_state_root: first_public_values.extra_block_data.genesis_state_root, txn_number_before: first_public_values.extra_block_data.txn_number_before, txn_number_after: public_values.extra_block_data.txn_number_after, gas_used_before: first_public_values.extra_block_data.gas_used_before, @@ -864,6 +868,7 @@ fn test_two_txn() -> anyhow::Result<()> { tries: tries_before, trie_roots_after, contract_code, + genesis_state_trie_root: HashedPartialTrie::from(Node::Empty).hash(), block_metadata, txn_number_before: 0.into(), gas_used_before: 0.into(), diff --git a/evm/tests/self_balance_gas_cost.rs b/evm/tests/self_balance_gas_cost.rs index d346164725..d0e95e1115 100644 --- a/evm/tests/self_balance_gas_cost.rs +++ b/evm/tests/self_balance_gas_cost.rs @@ -154,6 +154,7 @@ fn self_balance_gas_cost() -> anyhow::Result<()> { tries: tries_before, trie_roots_after, contract_code, + genesis_state_trie_root: HashedPartialTrie::from(Node::Empty).hash(), block_metadata, txn_number_before: 0.into(), gas_used_before: 0.into(), diff --git a/evm/tests/simple_transfer.rs b/evm/tests/simple_transfer.rs index b8c47fe9a6..5dd3b1ac23 100644 --- a/evm/tests/simple_transfer.rs +++ b/evm/tests/simple_transfer.rs @@ -135,6 +135,7 @@ fn test_simple_transfer() -> anyhow::Result<()> { tries: tries_before, trie_roots_after, contract_code, + genesis_state_trie_root: HashedPartialTrie::from(Node::Empty).hash(), block_metadata, txn_number_before: 0.into(), gas_used_before: 0.into(), 
From f65ad58a0854de3e5a65089815d771b4f9ef12d7 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 13 Feb 2023 15:58:26 +0100 Subject: [PATCH 02/34] Implement logUp --- evm/src/all_stark.rs | 28 +- evm/src/arithmetic/arithmetic_stark.rs | 45 +- evm/src/arithmetic/columns.rs | 5 +- evm/src/cpu/kernel/asm/curve/bn254/glv.asm | 97 ++++ evm/src/cpu/kernel/asm/curve/bn254/msm.asm | 73 +++ .../kernel/asm/curve/bn254/precomputation.asm | 35 ++ evm/src/cross_table_lookup.rs | 149 +++++- evm/src/fixed_recursive_verifier.rs | 6 +- evm/src/get_challenges.rs | 66 +-- evm/src/keccak/keccak_stark.rs | 3 +- evm/src/lib.rs | 1 - evm/src/lookup.rs | 327 +++++++++---- evm/src/memory/columns.rs | 7 +- evm/src/memory/memory_stark.rs | 38 +- evm/src/permutation.rs | 459 ------------------ evm/src/proof.rs | 64 +-- evm/src/prover.rs | 158 +++--- evm/src/recursive_verifier.rs | 59 ++- evm/src/stark.rs | 85 ++-- evm/src/vanishing_poly.rs | 35 +- evm/src/verifier.rs | 43 +- 21 files changed, 839 insertions(+), 944 deletions(-) create mode 100644 evm/src/cpu/kernel/asm/curve/bn254/glv.asm create mode 100644 evm/src/cpu/kernel/asm/curve/bn254/msm.asm create mode 100644 evm/src/cpu/kernel/asm/curve/bn254/precomputation.asm delete mode 100644 evm/src/permutation.rs diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index 068b0bcbf9..b7168f8571 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -51,27 +51,15 @@ impl, const D: usize> Default for AllStark { } impl, const D: usize> AllStark { - pub(crate) fn nums_permutation_zs(&self, config: &StarkConfig) -> [usize; NUM_TABLES] { + pub(crate) fn num_lookups_helper_columns(&self, config: &StarkConfig) -> [usize; NUM_TABLES] { [ - self.arithmetic_stark.num_permutation_batches(config), - self.byte_packing_stark.num_permutation_batches(config), - self.cpu_stark.num_permutation_batches(config), - self.keccak_stark.num_permutation_batches(config), - self.keccak_sponge_stark.num_permutation_batches(config), - 
self.logic_stark.num_permutation_batches(config), - self.memory_stark.num_permutation_batches(config), - ] - } - - pub(crate) fn permutation_batch_sizes(&self) -> [usize; NUM_TABLES] { - [ - self.arithmetic_stark.permutation_batch_size(), - self.byte_packing_stark.permutation_batch_size(), - self.cpu_stark.permutation_batch_size(), - self.keccak_stark.permutation_batch_size(), - self.keccak_sponge_stark.permutation_batch_size(), - self.logic_stark.permutation_batch_size(), - self.memory_stark.permutation_batch_size(), + self.arithmetic_stark.num_lookup_helper_columns(config), + self.byte_packing_stark.num_lookup_helper_columns(config), + self.cpu_stark.num_lookup_helper_columns(config), + self.keccak_stark.num_lookup_helper_columns(config), + self.keccak_sponge_stark.num_lookup_helper_columns(config), + self.logic_stark.num_lookup_helper_columns(config), + self.memory_stark.num_lookup_helper_columns(config), ] } } diff --git a/evm/src/arithmetic/arithmetic_stark.rs b/evm/src/arithmetic/arithmetic_stark.rs index 5441cf2760..9584ab884a 100644 --- a/evm/src/arithmetic/arithmetic_stark.rs +++ b/evm/src/arithmetic/arithmetic_stark.rs @@ -12,11 +12,11 @@ use plonky2::util::transpose; use static_assertions::const_assert; use crate::all_stark::Table; +use crate::arithmetic::columns::{RANGE_COUNTER, RC_FREQUENCIES, SHARED_COLS}; use crate::arithmetic::{addcy, byte, columns, divmod, modular, mul, Operation}; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cross_table_lookup::{Column, TableWithColumns}; -use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols}; -use crate::permutation::PermutationPair; +use crate::lookup::Lookup; use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; @@ -122,13 +122,12 @@ impl ArithmeticStark { cols[columns::RANGE_COUNTER][i] = F::from_canonical_usize(RANGE_MAX - 1); } - // For each column c in cols, generate the range-check - // permutations and put 
them in the corresponding range-check - // columns rc_c and rc_c+1. - for (c, rc_c) in columns::SHARED_COLS.zip(columns::RC_COLS.step_by(2)) { - let (col_perm, table_perm) = permuted_cols(&cols[c], &cols[columns::RANGE_COUNTER]); - cols[rc_c].copy_from_slice(&col_perm); - cols[rc_c + 1].copy_from_slice(&table_perm); + // Generate the frequencies column. + for col in SHARED_COLS { + for i in 0..n_rows { + let x = cols[col][i].to_canonical_u64() as usize; + cols[RC_FREQUENCIES][x] += F::ONE; + } } } @@ -178,11 +177,6 @@ impl, const D: usize> Stark for ArithmeticSta FE: FieldExtension, P: PackedField, { - // Range check all the columns - for col in columns::RC_COLS.step_by(2) { - eval_lookups(vars, yield_constr, col, col + 1); - } - let lv = vars.local_values; let nv = vars.next_values; @@ -210,11 +204,6 @@ impl, const D: usize> Stark for ArithmeticSta vars: StarkEvaluationTargets, yield_constr: &mut RecursiveConstraintConsumer, ) { - // Range check all the columns - for col in columns::RC_COLS.step_by(2) { - eval_lookups_circuit(builder, vars, yield_constr, col, col + 1); - } - let lv = vars.local_values; let nv = vars.next_values; @@ -240,18 +229,12 @@ impl, const D: usize> Stark for ArithmeticSta 3 } - fn permutation_pairs(&self) -> Vec { - const START: usize = columns::START_SHARED_COLS; - const END: usize = START + columns::NUM_SHARED_COLS; - let mut pairs = Vec::with_capacity(2 * columns::NUM_SHARED_COLS); - for (c, c_perm) in (START..END).zip_eq(columns::RC_COLS.step_by(2)) { - pairs.push(PermutationPair::singletons(c, c_perm)); - pairs.push(PermutationPair::singletons( - c_perm + 1, - columns::RANGE_COUNTER, - )); - } - pairs + fn lookups(&self) -> Vec { + vec![Lookup { + columns: SHARED_COLS.collect(), + table_column: RANGE_COUNTER, + frequencies_column: RC_FREQUENCIES, + }] } } diff --git a/evm/src/arithmetic/columns.rs b/evm/src/arithmetic/columns.rs index 48e00f8e11..f2646fc565 100644 --- a/evm/src/arithmetic/columns.rs +++ b/evm/src/arithmetic/columns.rs 
@@ -109,8 +109,7 @@ pub(crate) const MODULAR_DIV_DENOM_IS_ZERO: usize = AUX_REGISTER_2.end; // of the column and the permutation of the range. The two // permutations associated to column i will be in columns RC_COLS[2i] // and RC_COLS[2i+1]. -pub(crate) const NUM_RANGE_CHECK_COLS: usize = 1 + 2 * NUM_SHARED_COLS; pub(crate) const RANGE_COUNTER: usize = START_SHARED_COLS + NUM_SHARED_COLS; -pub(crate) const RC_COLS: Range = RANGE_COUNTER + 1..RANGE_COUNTER + 1 + 2 * NUM_SHARED_COLS; +pub(crate) const RC_FREQUENCIES: usize = RANGE_COUNTER + 1; -pub const NUM_ARITH_COLUMNS: usize = START_SHARED_COLS + NUM_SHARED_COLS + NUM_RANGE_CHECK_COLS; +pub const NUM_ARITH_COLUMNS: usize = START_SHARED_COLS + NUM_SHARED_COLS + 2; diff --git a/evm/src/cpu/kernel/asm/curve/bn254/glv.asm b/evm/src/cpu/kernel/asm/curve/bn254/glv.asm new file mode 100644 index 0000000000..c29d8f141d --- /dev/null +++ b/evm/src/cpu/kernel/asm/curve/bn254/glv.asm @@ -0,0 +1,97 @@ +// Inspired by https://github.com/AztecProtocol/weierstrudel/blob/master/huff_modules/endomorphism.huff +// See also Sage code in evm/src/cpu/kernel/tests/ecc/bn_glv_test_data +// Given scalar `k ∈ Bn254::ScalarField`, return `u, k1, k2` with `k1,k2 < 2^127` and such that +// `k = k1 - s*k2` if `u==0` otherwise `k = k1 + s*k2`, where `s` is the scalar value representing the endomorphism. +// In the comments below, N means @BN_SCALAR +// +// Z3 proof that the resulting `k1, k2` satisfy `k1>0`, `k1 < 2^127` and `|k2| < 2^127`. 
+// ```python +// from z3 import Solver, Int, Or, unsat +// q = 0x30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001 +// glv_s = 0xB3C4D79D41A917585BFC41088D8DAAA78B17EA66B99C90DD +// +// b2 = 0x89D3256894D213E3 +// b1 = -0x6F4D8248EEB859FC8211BBEB7D4F1128 +// +// g1 = 0x24CCEF014A773D2CF7A7BD9D4391EB18D +// g2 = 0x2D91D232EC7E0B3D7 +// k = Int("k") +// c1 = Int("c1") +// c2 = Int("c2") +// s = Solver() +// +// c2p = -c2 +// s.add(k < q) +// s.add(0 < k) +// s.add(c1 * (2**256) <= g2 * k) +// s.add((c1 + 1) * (2**256) > g2 * k) +// s.add(c2p * (2**256) <= g1 * k) +// s.add((c2p + 1) * (2**256) > g1 * k) +// +// q1 = c1 * b1 +// q2 = c2 * b2 +// +// k2 = q2 - q1 +// k2L = (glv_s * k2) % q +// k1 = k - k2L +// k2 = -k2 +// +// s.add(Or((k2 >= 2**127), (-k2 >= 2**127), (k1 >= 2**127), (k1 < 0))) +// +// assert s.check() == unsat +// ``` +global bn_glv_decompose: + // stack: k, retdest + PUSH @BN_SCALAR DUP1 DUP1 + // Compute c2 which is the top 256 bits of k*g1. Use asm from https://medium.com/wicketh/mathemagic-full-multiply-27650fec525d. + PUSH @U256_MAX + // stack: -1, N, N, N, k, retdest + PUSH @BN_GLV_MINUS_G1 DUP6 + // stack: k, g1, -1, N, N, N, k, retdest + MULMOD + // stack: (k * g1 % -1), N, N, N, k, retdest + PUSH @BN_GLV_MINUS_G1 DUP6 + // stack: k, g1, (k * g1 % -1), N, N, N, k, retdest + MUL + // stack: bottom = (k * g1), (k * g1 % -1), N, N, N, k, retdest + DUP1 DUP3 + // stack: (k * g1 % -1), bottom, bottom, (k * g1 % -1), N, N, N, k, retdest + LT SWAP2 SUB SUB + // stack: c2, N, N, N, k, retdest + PUSH @BN_GLV_B2 MULMOD + // stack: q2=c2*b2, N, N, k, retdest + + // Use the same trick to compute c1 = top 256 bits of g2*k. 
+ PUSH @BN_SCALAR PUSH @U256_MAX + PUSH @BN_GLV_G2 DUP7 MULMOD + PUSH @BN_GLV_G2 DUP7 MUL + DUP1 DUP3 LT + SWAP2 SUB SUB + // stack: c1, N, q2, N, N, k, retdest + PUSH @BN_GLV_B1 MULMOD + // stack: q1, q2, N, N, k, retdest + + // We compute k2 = q1 + q2 - N, but we check for underflow and return N-q1-q2 instead if there is one, + // along with a flag `underflow` set to 1 if there is an underflow, 0 otherwise. + ADD %sub_check_underflow + // stack: k2, underflow, N, k, retdest + SWAP3 PUSH @BN_SCALAR DUP5 PUSH @BN_GLV_S + // stack: s, k2, N, k, underflow, N, k2, retdest + MULMOD + // stack: s*k2, k, underflow, N, k2, retdest + // Need to return `k + s*k2` if no underflow occurs, otherwise return `k - s*k2` which is done in the `underflowed` fn. + SWAP2 DUP1 %jumpi(underflowed) + %stack (underflow, k, x, N, k2) -> (k, x, N, k2, underflow) + ADDMOD + %stack (k1, k2, underflow, retdest) -> (retdest, underflow, k1, k2) + JUMP + +underflowed: + // stack: underflow, k, s*k2, N, k2 + // Compute (k-s*k2)%N. TODO: Use SUBMOD here when ready + %stack (u, k, x, N, k2) -> (N, x, k, N, k2, u) + SUB ADDMOD + %stack (k1, k2, underflow, retdest) -> (retdest, underflow, k1, k2) + JUMP + + diff --git a/evm/src/cpu/kernel/asm/curve/bn254/msm.asm b/evm/src/cpu/kernel/asm/curve/bn254/msm.asm new file mode 100644 index 0000000000..1036228737 --- /dev/null +++ b/evm/src/cpu/kernel/asm/curve/bn254/msm.asm @@ -0,0 +1,73 @@ +// Computes the multiplication `a*G` using a standard MSM with the GLV decomposition of `a`. +// See `glv.asm` for a detailed description of the GLV decomposition. 
+global bn_msm: + // stack: retdest + PUSH 0 PUSH 0 PUSH 0 +global bn_msm_loop: + // stack: accx, accy, i, retdest + DUP3 %bn_mload_wnaf_a + // stack: w, accx, accy, i, retdest + DUP1 %jumpi(bn_msm_loop_add_a_nonzero) + POP +msm_loop_add_b: + //stack: accx, accy, i, retdest + DUP3 %bn_mload_wnaf_b + // stack: w, accx, accy, i, retdest + DUP1 %jumpi(bn_msm_loop_add_b_nonzero) + POP +msm_loop_contd: + %stack (accx, accy, i, retdest) -> (i, i, accx, accy, retdest) + // TODO: the GLV scalars for the BN curve are 127-bit, so could use 127 here. But this would require modifying `wnaf.asm`. Not sure it's worth it... + %eq_const(129) %jumpi(msm_end) + %increment + //stack: i+1, accx, accy, retdest + %stack (i, accx, accy, retdest) -> (accx, accy, bn_msm_loop, i, retdest) + %jump(bn_double) + +msm_end: + %stack (i, accx, accy, retdest) -> (retdest, accx, accy) + JUMP + +bn_msm_loop_add_a_nonzero: + %stack (w, accx, accy, i, retdest) -> (w, accx, accy, msm_loop_add_b, i, retdest) + %bn_mload_point_a + // stack: px, py, accx, accy, msm_loop_add_b, i, retdest + %jump(bn_add_valid_points) + +bn_msm_loop_add_b_nonzero: + %stack (w, accx, accy, i, retdest) -> (w, accx, accy, msm_loop_contd, i, retdest) + %bn_mload_point_b + // stack: px, py, accx, accy, msm_loop_contd, i, retdest + %jump(bn_add_valid_points) + +%macro bn_mload_wnaf_a + // stack: i + %mload_kernel(@SEGMENT_KERNEL_BN_WNAF_A) +%endmacro + +%macro bn_mload_wnaf_b + // stack: i + %mload_kernel(@SEGMENT_KERNEL_BN_WNAF_B) +%endmacro + +%macro bn_mload_point_a + // stack: w + DUP1 + %mload_kernel(@SEGMENT_KERNEL_BN_TABLE_Q) + //stack: Gy, w + SWAP1 %decrement %mload_kernel(@SEGMENT_KERNEL_BN_TABLE_Q) + //stack: Gx, Gy +%endmacro + +%macro bn_mload_point_b + // stack: w + DUP1 + %mload_kernel(@SEGMENT_KERNEL_BN_TABLE_Q) + PUSH @BN_BNEG_LOC %mload_kernel(@SEGMENT_KERNEL_BN_TABLE_Q) + %stack (bneg, Gy, w) -> (@BN_BASE, Gy, bneg, bneg, Gy, w) + SUB SWAP1 ISZERO MUL SWAP2 MUL ADD + SWAP1 %decrement 
%mload_kernel(@SEGMENT_KERNEL_BN_TABLE_Q) + //stack: Gx, Gy + PUSH @BN_GLV_BETA + MULFP254 +%endmacro diff --git a/evm/src/cpu/kernel/asm/curve/bn254/precomputation.asm b/evm/src/cpu/kernel/asm/curve/bn254/precomputation.asm new file mode 100644 index 0000000000..a8c6ada926 --- /dev/null +++ b/evm/src/cpu/kernel/asm/curve/bn254/precomputation.asm @@ -0,0 +1,35 @@ +// Precompute a table of multiples of the BN254 point `Q = (Qx, Qy)`. +// Let `(Qxi, Qyi) = i * Q`, then store in the `SEGMENT_KERNEL_BN_TABLE_Q` segment of memory the values +// `i-1 => Qxi`, `i => Qyi if i < 16 else -Qy(32-i)` for `i in range(1, 32, 2)`. +global bn_precompute_table: + // stack: Qx, Qy, retdest + PUSH precompute_table_contd DUP3 DUP3 + %jump(bn_double) +precompute_table_contd: + // stack: Qx2, Qy2, Qx, Qy, retdest + PUSH 1 +bn_precompute_table_loop: + // stack i, Qx2, Qy2, Qx, Qy, retdest + PUSH 1 DUP2 SUB + %stack (im, i, Qx2, Qy2, Qx, Qy, retdest) -> (i, Qy, im, Qx, i, Qx2, Qy2, Qx, Qy, retdest) + %mstore_kernel(@SEGMENT_KERNEL_BN_TABLE_Q) %mstore_kernel(@SEGMENT_KERNEL_BN_TABLE_Q) + // stack: i, Qx2, Qy2, Qx, Qy, retdest + DUP1 PUSH 32 SUB PUSH 1 DUP2 SUB + // stack: 31-i, 32-i, i, Qx2, Qy2, Qx, Qy, retdest + DUP7 PUSH @BN_BASE SUB + // TODO: Could maybe avoid storing Qx a second time here, not sure if it would be more efficient. 
+ %stack (Qyy, iii, ii, i, Qx2, Qy2, Qx, Qy, retdest) -> (iii, Qx, ii, Qyy, i, Qx2, Qy2, Qx, Qy, retdest) + %mstore_kernel(@SEGMENT_KERNEL_BN_TABLE_Q) %mstore_kernel(@SEGMENT_KERNEL_BN_TABLE_Q) + // stack: i, Qx2, Qy2, Qx, Qy, retdest + PUSH 2 ADD + // stack: i+2, Qx2, Qy2, Qx, Qy, retdest + DUP1 PUSH 16 LT %jumpi(precompute_table_end) + %stack (i, Qx2, Qy2, Qx, Qy, retdest) -> (Qx, Qy, Qx2, Qy2, precompute_table_loop_contd, i, Qx2, Qy2, retdest) + %jump(bn_add_valid_points) +precompute_table_loop_contd: + %stack (Qx, Qy, i, Qx2, Qy2, retdest) -> (i, Qx2, Qy2, Qx, Qy, retdest) + %jump(bn_precompute_table_loop) + +precompute_table_end: + // stack: i, Qx2, Qy2, Qx, Qy, retdest + %pop5 JUMP diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index a9b90428ca..dfac49f7cf 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -1,4 +1,5 @@ use std::borrow::Borrow; +use std::fmt::Debug; use std::iter::repeat; use anyhow::{ensure, Result}; @@ -8,15 +9,19 @@ use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; +use plonky2::iop::challenger::{Challenger, RecursiveChallenger}; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::iop::target::Target; use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::config::GenericConfig; +use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; +use plonky2::plonk::plonk_common::{ + reduce_with_powers, reduce_with_powers_circuit, reduce_with_powers_ext_circuit, +}; +use plonky2::util::serialization::{Buffer, IoResult, Read, Write}; use crate::all_stark::{Table, NUM_TABLES}; use crate::config::StarkConfig; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::permutation::{GrandProductChallenge, GrandProductChallengeSet}; use crate::proof::{StarkProofTarget, StarkProofWithMetadata}; use 
crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; @@ -217,6 +222,128 @@ impl CtlData { } } +/// Randomness for a single instance of a permutation check protocol. +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +pub(crate) struct GrandProductChallenge { + /// Randomness used to combine multiple columns into one. + pub(crate) beta: T, + /// Random offset that's added to the beta-reduced column values. + pub(crate) gamma: T, +} + +impl GrandProductChallenge { + pub(crate) fn combine<'a, FE, P, T: IntoIterator, const D2: usize>( + &self, + terms: T, + ) -> P + where + FE: FieldExtension, + P: PackedField, + T::IntoIter: DoubleEndedIterator, + { + reduce_with_powers(terms, FE::from_basefield(self.beta)) + FE::from_basefield(self.gamma) + } +} + +impl GrandProductChallenge { + pub(crate) fn combine_circuit, const D: usize>( + &self, + builder: &mut CircuitBuilder, + terms: &[ExtensionTarget], + ) -> ExtensionTarget { + let reduced = reduce_with_powers_ext_circuit(builder, terms, self.beta); + let gamma = builder.convert_to_ext(self.gamma); + builder.add_extension(reduced, gamma) + } +} + +impl GrandProductChallenge { + pub(crate) fn combine_base_circuit, const D: usize>( + &self, + builder: &mut CircuitBuilder, + terms: &[Target], + ) -> Target { + let reduced = reduce_with_powers_circuit(builder, terms, self.beta); + builder.add(reduced, self.gamma) + } +} + +/// Like `GrandProductChallenge`, but with `num_challenges` copies to boost soundness. 
+#[derive(Clone, Eq, PartialEq, Debug)] +pub(crate) struct GrandProductChallengeSet { + pub(crate) challenges: Vec>, +} + +impl GrandProductChallengeSet { + pub fn to_buffer(&self, buffer: &mut Vec) -> IoResult<()> { + buffer.write_usize(self.challenges.len())?; + for challenge in &self.challenges { + buffer.write_target(challenge.beta)?; + buffer.write_target(challenge.gamma)?; + } + Ok(()) + } + + pub fn from_buffer(buffer: &mut Buffer) -> IoResult { + let length = buffer.read_usize()?; + let mut challenges = Vec::with_capacity(length); + for _ in 0..length { + challenges.push(GrandProductChallenge { + beta: buffer.read_target()?, + gamma: buffer.read_target()?, + }); + } + + Ok(GrandProductChallengeSet { challenges }) + } +} + +fn get_grand_product_challenge>( + challenger: &mut Challenger, +) -> GrandProductChallenge { + let beta = challenger.get_challenge(); + let gamma = challenger.get_challenge(); + GrandProductChallenge { beta, gamma } +} + +pub(crate) fn get_grand_product_challenge_set>( + challenger: &mut Challenger, + num_challenges: usize, +) -> GrandProductChallengeSet { + let challenges = (0..num_challenges) + .map(|_| get_grand_product_challenge(challenger)) + .collect(); + GrandProductChallengeSet { challenges } +} + +fn get_grand_product_challenge_target< + F: RichField + Extendable, + H: AlgebraicHasher, + const D: usize, +>( + builder: &mut CircuitBuilder, + challenger: &mut RecursiveChallenger, +) -> GrandProductChallenge { + let beta = challenger.get_challenge(builder); + let gamma = challenger.get_challenge(builder); + GrandProductChallenge { beta, gamma } +} + +pub(crate) fn get_grand_product_challenge_set_target< + F: RichField + Extendable, + H: AlgebraicHasher, + const D: usize, +>( + builder: &mut CircuitBuilder, + challenger: &mut RecursiveChallenger, + num_challenges: usize, +) -> GrandProductChallengeSet { + let challenges = (0..num_challenges) + .map(|_| get_grand_product_challenge_target(builder, challenger)) + .collect(); + 
GrandProductChallengeSet { challenges } +} + pub(crate) fn cross_table_lookup_data( trace_poly_values: &[Vec>; NUM_TABLES], cross_table_lookups: &[CrossTableLookup], @@ -317,15 +444,15 @@ impl<'a, F: RichField + Extendable, const D: usize> proofs: &[StarkProofWithMetadata; NUM_TABLES], cross_table_lookups: &'a [CrossTableLookup], ctl_challenges: &'a GrandProductChallengeSet, - num_permutation_zs: &[usize; NUM_TABLES], + num_lookup_columns: &[usize; NUM_TABLES], ) -> [Vec; NUM_TABLES] { let mut ctl_zs = proofs .iter() - .zip(num_permutation_zs) - .map(|(p, &num_perms)| { + .zip(num_lookup_columns) + .map(|(p, &num_lookup)| { let openings = &p.proof.openings; - let ctl_zs = openings.permutation_ctl_zs.iter().skip(num_perms); - let ctl_zs_next = openings.permutation_ctl_zs_next.iter().skip(num_perms); + let ctl_zs = openings.auxiliary_polys.iter().skip(num_lookup); + let ctl_zs_next = openings.auxiliary_polys_next.iter().skip(num_lookup); ctl_zs.zip(ctl_zs_next) }) .collect::>(); @@ -419,15 +546,15 @@ impl<'a, F: Field, const D: usize> CtlCheckVarsTarget<'a, F, D> { proof: &StarkProofTarget, cross_table_lookups: &'a [CrossTableLookup], ctl_challenges: &'a GrandProductChallengeSet, - num_permutation_zs: usize, + num_lookup_columns: usize, ) -> Vec { let mut ctl_zs = { let openings = &proof.openings; - let ctl_zs = openings.permutation_ctl_zs.iter().skip(num_permutation_zs); + let ctl_zs = openings.auxiliary_polys.iter().skip(num_lookup_columns); let ctl_zs_next = openings - .permutation_ctl_zs_next + .auxiliary_polys_next .iter() - .skip(num_permutation_zs); + .skip(num_lookup_columns); ctl_zs.zip(ctl_zs_next) }; diff --git a/evm/src/fixed_recursive_verifier.rs b/evm/src/fixed_recursive_verifier.rs index 5c0f1f8084..33021c3952 100644 --- a/evm/src/fixed_recursive_verifier.rs +++ b/evm/src/fixed_recursive_verifier.rs @@ -32,14 +32,16 @@ use crate::arithmetic::arithmetic_stark::ArithmeticStark; use crate::byte_packing::byte_packing_stark::BytePackingStark; use 
crate::config::StarkConfig; use crate::cpu::cpu_stark::CpuStark; -use crate::cross_table_lookup::{verify_cross_table_lookups_circuit, CrossTableLookup}; +use crate::cross_table_lookup::{ + get_grand_product_challenge_set_target, verify_cross_table_lookups_circuit, CrossTableLookup, + GrandProductChallengeSet, +}; use crate::generation::GenerationInputs; use crate::get_challenges::observe_public_values_target; use crate::keccak::keccak_stark::KeccakStark; use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; use crate::logic::LogicStark; use crate::memory::memory_stark::MemoryStark; -use crate::permutation::{get_grand_product_challenge_set_target, GrandProductChallengeSet}; use crate::proof::{ BlockHashesTarget, BlockMetadataTarget, ExtraBlockDataTarget, PublicValues, PublicValuesTarget, StarkProofWithMetadata, TrieRootsTarget, diff --git a/evm/src/get_challenges.rs b/evm/src/get_challenges.rs index 59be8439d8..a12aaa9548 100644 --- a/evm/src/get_challenges.rs +++ b/evm/src/get_challenges.rs @@ -8,10 +8,7 @@ use plonky2::plonk::config::{AlgebraicHasher, GenericConfig}; use crate::all_stark::{AllStark, NUM_TABLES}; use crate::config::StarkConfig; -use crate::permutation::{ - get_grand_product_challenge_set, get_n_grand_product_challenge_sets, - get_n_grand_product_challenge_sets_target, -}; +use crate::cross_table_lookup::get_grand_product_challenge_set; use crate::proof::*; use crate::util::{h256_limbs, u256_limbs}; @@ -224,18 +221,14 @@ impl, C: GenericConfig, const D: usize> A let ctl_challenges = get_grand_product_challenge_set(&mut challenger, config.num_challenges); - let num_permutation_zs = all_stark.nums_permutation_zs(config); - let num_permutation_batch_sizes = all_stark.permutation_batch_sizes(); + let lookups = all_stark.num_lookups_helper_columns(config); AllProofChallenges { stark_challenges: core::array::from_fn(|i| { challenger.compact(); - self.stark_proofs[i].proof.get_challenges( - &mut challenger, - num_permutation_zs[i] > 0, - 
num_permutation_batch_sizes[i], - config, - ) + self.stark_proofs[i] + .proof + .get_challenges(&mut challenger, lookups[i] > 0, config) }), ctl_challenges, } @@ -258,17 +251,13 @@ impl, C: GenericConfig, const D: usize> A let ctl_challenges = get_grand_product_challenge_set(&mut challenger, config.num_challenges); - let num_permutation_zs = all_stark.nums_permutation_zs(config); - let num_permutation_batch_sizes = all_stark.permutation_batch_sizes(); + let lookups = all_stark.num_lookups_helper_columns(config); let mut challenger_states = vec![challenger.compact()]; for i in 0..NUM_TABLES { - self.stark_proofs[i].proof.get_challenges( - &mut challenger, - num_permutation_zs[i] > 0, - num_permutation_batch_sizes[i], - config, - ); + self.stark_proofs[i] + .proof + .get_challenges(&mut challenger, lookups[i] > 0, config); challenger_states.push(challenger.compact()); } @@ -288,14 +277,13 @@ where pub(crate) fn get_challenges( &self, challenger: &mut Challenger, - stark_use_permutation: bool, - stark_permutation_batch_size: usize, + stark_use_lookup: bool, config: &StarkConfig, ) -> StarkProofChallenges { let degree_bits = self.recover_degree_bits(config); let StarkProof { - permutation_ctl_zs_cap, + auxiliary_polys_cap, quotient_polys_cap, openings, opening_proof: @@ -310,15 +298,10 @@ where let num_challenges = config.num_challenges; - let permutation_challenge_sets = stark_use_permutation.then(|| { - get_n_grand_product_challenge_sets( - challenger, - num_challenges, - stark_permutation_batch_size, - ) - }); + let lookup_challenges = + stark_use_lookup.then(|| challenger.get_n_challenges(config.num_challenges)); - challenger.observe_cap(permutation_ctl_zs_cap); + challenger.observe_cap(auxiliary_polys_cap); let stark_alphas = challenger.get_n_challenges(num_challenges); @@ -328,7 +311,7 @@ where challenger.observe_openings(&openings.to_fri_openings()); StarkProofChallenges { - permutation_challenge_sets, + lookup_challenges, stark_alphas, stark_zeta, 
fri_challenges: challenger.fri_challenges::( @@ -347,15 +330,14 @@ impl StarkProofTarget { &self, builder: &mut CircuitBuilder, challenger: &mut RecursiveChallenger, - stark_use_permutation: bool, - stark_permutation_batch_size: usize, + stark_use_lookup: bool, config: &StarkConfig, ) -> StarkProofChallengesTarget where C::Hasher: AlgebraicHasher, { let StarkProofTarget { - permutation_ctl_zs_cap, + auxiliary_polys_cap: auxiliary_polys, quotient_polys_cap, openings, opening_proof: @@ -370,16 +352,10 @@ impl StarkProofTarget { let num_challenges = config.num_challenges; - let permutation_challenge_sets = stark_use_permutation.then(|| { - get_n_grand_product_challenge_sets_target( - builder, - challenger, - num_challenges, - stark_permutation_batch_size, - ) - }); + let lookup_challenges = + stark_use_lookup.then(|| challenger.get_n_challenges(builder, num_challenges)); - challenger.observe_cap(permutation_ctl_zs_cap); + challenger.observe_cap(auxiliary_polys); let stark_alphas = challenger.get_n_challenges(builder, num_challenges); @@ -389,7 +365,7 @@ impl StarkProofTarget { challenger.observe_openings(&openings.to_fri_openings(builder.zero())); StarkProofChallengesTarget { - permutation_challenge_sets, + lookup_challenges, stark_alphas, stark_zeta, fri_challenges: challenger.fri_challenges( diff --git a/evm/src/keccak/keccak_stark.rs b/evm/src/keccak/keccak_stark.rs index 74f92622fc..b4ff4b84ed 100644 --- a/evm/src/keccak/keccak_stark.rs +++ b/evm/src/keccak/keccak_stark.rs @@ -650,10 +650,9 @@ mod tests { use tiny_keccak::keccakf; use crate::config::StarkConfig; - use crate::cross_table_lookup::{CtlData, CtlZData}; + use crate::cross_table_lookup::{CtlData, CtlZData, GrandProductChallenge}; use crate::keccak::columns::reg_output_limb; use crate::keccak::keccak_stark::{KeccakStark, NUM_INPUTS, NUM_ROUNDS}; - use crate::permutation::GrandProductChallenge; use crate::prover::prove_single_table; use crate::stark_testing::{test_stark_circuit_constraints, 
test_stark_low_degree}; diff --git a/evm/src/lib.rs b/evm/src/lib.rs index ab48cda04f..93e29c7d31 100644 --- a/evm/src/lib.rs +++ b/evm/src/lib.rs @@ -23,7 +23,6 @@ pub mod keccak_sponge; pub mod logic; pub mod lookup; pub mod memory; -pub mod permutation; pub mod proof; pub mod prover; pub mod recursive_verifier; diff --git a/evm/src/lookup.rs b/evm/src/lookup.rs index d7e12bacf1..ad872a799e 100644 --- a/evm/src/lookup.rs +++ b/evm/src/lookup.rs @@ -1,127 +1,248 @@ -use std::cmp::Ordering; - use itertools::Itertools; -use plonky2::field::extension::Extendable; +use plonky2::field::batch_util::batch_add_inplace; +use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::packed::PackedField; -use plonky2::field::types::{Field, PrimeField64}; +use plonky2::field::polynomial::PolynomialValues; +use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::iop::target::Target; use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2_util::ceil_div_usize; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; -pub(crate) fn eval_lookups, const COLS: usize>( - vars: StarkEvaluationVars, - yield_constr: &mut ConstraintConsumer

, - col_permuted_input: usize, - col_permuted_table: usize, -) { - let local_perm_input = vars.local_values[col_permuted_input]; - let next_perm_table = vars.next_values[col_permuted_table]; - let next_perm_input = vars.next_values[col_permuted_input]; - - // A "vertical" diff between the local and next permuted inputs. - let diff_input_prev = next_perm_input - local_perm_input; - // A "horizontal" diff between the next permuted input and permuted table value. - let diff_input_table = next_perm_input - next_perm_table; - - yield_constr.constraint(diff_input_prev * diff_input_table); - - // This is actually constraining the first row, as per the spec, since `diff_input_table` - // is a diff of the next row's values. In the context of `constraint_last_row`, the next - // row is the first row. - yield_constr.constraint_last_row(diff_input_table); +pub struct Lookup { + /// Columns whose values should be contained in the lookup table. + /// These are the f_i(x) polynomials in the logUp paper. + pub(crate) columns: Vec, + /// Column containing the lookup table. + /// This is the t(x) polynomial in the paper. + pub(crate) table_column: usize, + /// Column containing the frequencies of `columns` in `table_column`. + /// This is the m(x) polynomial in the paper. + pub(crate) frequencies_column: usize, } -pub(crate) fn eval_lookups_circuit< - F: RichField + Extendable, - const D: usize, - const COLS: usize, ->( - builder: &mut CircuitBuilder, - vars: StarkEvaluationTargets, - yield_constr: &mut RecursiveConstraintConsumer, - col_permuted_input: usize, - col_permuted_table: usize, -) { - let local_perm_input = vars.local_values[col_permuted_input]; - let next_perm_table = vars.next_values[col_permuted_table]; - let next_perm_input = vars.next_values[col_permuted_input]; - - // A "vertical" diff between the local and next permuted inputs. 
- let diff_input_prev = builder.sub_extension(next_perm_input, local_perm_input); - // A "horizontal" diff between the next permuted input and permuted table value. - let diff_input_table = builder.sub_extension(next_perm_input, next_perm_table); - - let diff_product = builder.mul_extension(diff_input_prev, diff_input_table); - yield_constr.constraint(builder, diff_product); - - // This is actually constraining the first row, as per the spec, since `diff_input_table` - // is a diff of the next row's values. In the context of `constraint_last_row`, the next - // row is the first row. - yield_constr.constraint_last_row(builder, diff_input_table); +impl Lookup { + pub(crate) fn num_helper_columns(&self, constraint_degree: usize) -> usize { + // One helper column for each column batch of size `constraint_degree-1`, + // then one column for the inverse of `table + challenge` and one for the `Z` polynomial. + ceil_div_usize(self.columns.len(), constraint_degree - 1) + 2 + } } -/// Given an input column and a table column, generate the permuted input and permuted table columns -/// used in the Halo2 permutation argument. -pub fn permuted_cols(inputs: &[F], table: &[F]) -> (Vec, Vec) { - let n = inputs.len(); +/// logUp protocol from https://ia.cr/2022/1530 (TODO link to newer version?) +/// Compute the helper columns for the lookup argument. +/// Given columns `f0,...,fk` and a column `t`, such that `∪fi ⊆ t`, and challenges `x`, +/// this computes the helper columns `h_i = 1/(x+f_2i) + 1/(x+f_2i+1)`, `g = 1/(x+t)`, +/// and `Z(gx) = Z(x) + sum h_i(x) - m(x)g(x)` where `m` is the frequencies column. +pub(crate) fn lookup_helper_columns( + lookup: &Lookup, + trace_poly_values: &[PolynomialValues], + challenge: F, + constraint_degree: usize, +) -> Vec> { + assert_eq!( + constraint_degree, 3, + "TODO: Allow other constraint degrees." 
+ ); + let num_helper_columns = lookup.num_helper_columns(constraint_degree); + let mut helper_columns: Vec> = Vec::with_capacity(num_helper_columns); + + // For each batch of `constraint_degree-1` columns `fi`, compute `sum 1/(f_i+challenge)` and + // add it to the helper columns. + // TODO: This does one batch inversion per column. It would also be possible to do one batch inversion + // for every column, but that would require building a big vector of all the columns concatenated. + // Not sure which approach is better. + // Note: these are the h_k(x) polynomials in the paper, with a few differences: + // * Here, the first ratio m_0(x)/phi_0(x) is not included with the columns batched up to create the + // h_k polynomials; instead there's a separate helper column for it (see below). + // * Here, we use 1 instead of -1 as the numerator (and subtract later). + // * Here, for now, the batch size (l) is always constraint_degree - 1 = 2. + for mut col_inds in &lookup.columns.iter().chunks(constraint_degree - 1) { + let first = *col_inds.next().unwrap(); + // TODO: The clone could probably be avoided by using a modified version of `batch_multiplicative_inverse` + // taking `challenge` as an additional argument. + let mut column = trace_poly_values[first].values.clone(); + for x in column.iter_mut() { + *x = challenge + *x; + } + let mut acc = F::batch_multiplicative_inverse(&column); + for &ind in col_inds { + let mut column = trace_poly_values[ind].values.clone(); + for x in column.iter_mut() { + *x = challenge + *x; + } + column = F::batch_multiplicative_inverse(&column); + batch_add_inplace(&mut acc, &column); + } + helper_columns.push(acc.into()); + } - // The permuted inputs do not have to be ordered, but we found that sorting was faster than - // hash-based grouping. We also sort the table, as this helps us identify "unused" table - // elements efficiently. + // Add `1/(table+challenge)` to the helper columns. 
+ // This is 1/phi_0(x) = 1/(x + t(x)) from the paper. + // Here, we don't include m(x) in the numerator, instead multiplying it with this column later. + let mut table = trace_poly_values[lookup.table_column].values.clone(); + for x in table.iter_mut() { + *x = challenge + *x; + } + helper_columns.push(F::batch_multiplicative_inverse(&table).into()); + + // Compute the `Z` polynomial with `Z(1)=0` and `Z(gx) = Z(x) + sum h_i(x) - frequencies(x)g(x)`. + // This enforces the check from the paper, that the sum of the h_k(x) polynomials is 0 over H. + // In the paper, that sum includes m(x)/(x + t(x)) = frequencies(x)/g(x), because that was bundled + // into the h_k(x) polynomials. + let frequencies = &trace_poly_values[lookup.frequencies_column].values; + let mut z = Vec::with_capacity(frequencies.len()); + z.push(F::ZERO); + for i in 0..frequencies.len() - 1 { + let x = helper_columns[..num_helper_columns - 2] + .iter() + .map(|col| col.values[i]) + .sum::() + - frequencies[i] * helper_columns[num_helper_columns - 2].values[i]; + z.push(z[i] + x); + } + helper_columns.push(z.into()); - // To compare elements, e.g. for sorting, we first need them in canonical form. It would be - // wasteful to canonicalize in each comparison, as a single element may be involved in many - // comparisons. So we will canonicalize once upfront, then use `to_noncanonical_u64` when - // comparing elements. + helper_columns +} - let sorted_inputs = inputs - .iter() - .map(|x| x.to_canonical()) - .sorted_unstable_by_key(|x| x.to_noncanonical_u64()) - .collect_vec(); - let sorted_table = table - .iter() - .map(|x| x.to_canonical()) - .sorted_unstable_by_key(|x| x.to_noncanonical_u64()) - .collect_vec(); +pub struct LookupCheckVars +where + F: Field, + FE: FieldExtension, + P: PackedField, +{ + pub(crate) local_values: Vec

, + pub(crate) next_values: Vec

, + pub(crate) challenges: Vec, +} - let mut unused_table_inds = Vec::with_capacity(n); - let mut unused_table_vals = Vec::with_capacity(n); - let mut permuted_table = vec![F::ZERO; n]; - let mut i = 0; - let mut j = 0; - while (j < n) && (i < n) { - let input_val = sorted_inputs[i].to_noncanonical_u64(); - let table_val = sorted_table[j].to_noncanonical_u64(); - match input_val.cmp(&table_val) { - Ordering::Greater => { - unused_table_vals.push(sorted_table[j]); - j += 1; - } - Ordering::Less => { - if let Some(x) = unused_table_vals.pop() { - permuted_table[i] = x; - } else { - unused_table_inds.push(i); +/// Constraints for the logUp lookup argument. +pub(crate) fn eval_lookups_checks( + stark: &S, + lookups: &[Lookup], + vars: StarkEvaluationVars, + lookup_vars: LookupCheckVars, + yield_constr: &mut ConstraintConsumer

, +) where + F: RichField + Extendable, + FE: FieldExtension, + P: PackedField, + S: Stark, +{ + let degree = stark.constraint_degree(); + assert_eq!(degree, 3, "TODO: Allow other constraint degrees."); + let mut start = 0; + for lookup in lookups { + let num_helper_columns = lookup.num_helper_columns(degree); + for &challenge in &lookup_vars.challenges { + let challenge = FE::from_basefield(challenge); + // For each chunk, check that `h_i (x+f_2i) (x+f_2i+1) = (x+f_2i) + (x+f_2i+1)` if the chunk has length 2 + // or if it has length 1, check that `h_i * (x+f_2i) = 1`, where x is the challenge + for (j, chunk) in lookup.columns.chunks(degree - 1).enumerate() { + let mut x = lookup_vars.local_values[start + j]; + let mut y = P::ZEROS; + let fs = chunk.iter().map(|&k| vars.local_values[k]); + for f in fs { + x *= f + challenge; + y += f + challenge; + } + match chunk.len() { + 2 => yield_constr.constraint(x - y), + 1 => yield_constr.constraint(x - P::ONES), + _ => todo!("Allow other constraint degrees."), } - i += 1; - } - Ordering::Equal => { - permuted_table[i] = sorted_table[j]; - i += 1; - j += 1; } + // Check that the penultimate helper column contains `1/(table+challenge)`. + let x = lookup_vars.local_values[start + num_helper_columns - 2]; + let x = x * (vars.local_values[lookup.table_column] + challenge); + yield_constr.constraint(x - P::ONES); + + // Check the `Z` polynomial. 
+ let z = lookup_vars.local_values[start + num_helper_columns - 1]; + let next_z = lookup_vars.next_values[start + num_helper_columns - 1]; + let y = lookup_vars.local_values[start..start + num_helper_columns - 2] + .iter() + .fold(P::ZEROS, |acc, x| acc + *x) + - vars.local_values[lookup.frequencies_column] + * lookup_vars.local_values[start + num_helper_columns - 2]; + yield_constr.constraint(next_z - z - y); + start += num_helper_columns; } } +} - unused_table_vals.extend_from_slice(&sorted_table[j..n]); - unused_table_inds.extend(i..n); +pub struct LookupCheckVarsTarget { + pub(crate) local_values: Vec>, + pub(crate) next_values: Vec>, + pub(crate) challenges: Vec, +} - for (ind, val) in unused_table_inds.into_iter().zip_eq(unused_table_vals) { - permuted_table[ind] = val; +pub(crate) fn eval_lookups_checks_circuit< + F: RichField + Extendable, + S: Stark, + const D: usize, +>( + builder: &mut CircuitBuilder, + stark: &S, + vars: StarkEvaluationTargets, + lookup_vars: LookupCheckVarsTarget, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let one = builder.one_extension(); + let degree = stark.constraint_degree(); + let lookups = stark.lookups(); + assert_eq!(degree, 3, "TODO: Allow other constraint degrees."); + let mut start = 0; + for lookup in lookups { + let num_helper_columns = lookup.num_helper_columns(degree); + for &challenge in &lookup_vars.challenges { + let challenge = builder.convert_to_ext(challenge); + for (j, chunk) in lookup.columns.chunks(degree - 1).enumerate() { + let mut x = lookup_vars.local_values[start + j]; + let mut y = builder.zero_extension(); + let fs = chunk.iter().map(|&k| vars.local_values[k]); + for f in fs { + let tmp = builder.add_extension(f, challenge); + x = builder.mul_extension(x, tmp); + y = builder.add_extension(y, tmp); + } + match chunk.len() { + 2 => { + let tmp = builder.sub_extension(x, y); + yield_constr.constraint(builder, tmp) + } + 1 => { + let tmp = builder.sub_extension(x, one); + 
yield_constr.constraint(builder, tmp) + } + _ => todo!("Allow other constraint degrees."), + } + } + let x = lookup_vars.local_values[start + num_helper_columns - 2]; + let tmp = builder.add_extension(vars.local_values[lookup.table_column], challenge); + let x = builder.mul_sub_extension(x, tmp, one); + yield_constr.constraint(builder, x); + + let z = lookup_vars.local_values[start + num_helper_columns - 1]; + let next_z = lookup_vars.next_values[start + num_helper_columns - 1]; + let y = builder.add_many_extension( + &lookup_vars.local_values[start..start + num_helper_columns - 2], + ); + let tmp = builder.mul_extension( + vars.local_values[lookup.frequencies_column], + lookup_vars.local_values[start + num_helper_columns - 2], + ); + let y = builder.sub_extension(y, tmp); + let constraint = builder.sub_extension(next_z, z); + let constraint = builder.sub_extension(constraint, y); + yield_constr.constraint(builder, constraint); + start += num_helper_columns; + } } - - (sorted_inputs, permuted_table) } diff --git a/evm/src/memory/columns.rs b/evm/src/memory/columns.rs index c972dd0a07..56b121e1e2 100644 --- a/evm/src/memory/columns.rs +++ b/evm/src/memory/columns.rs @@ -31,8 +31,7 @@ pub(crate) const VIRTUAL_FIRST_CHANGE: usize = SEGMENT_FIRST_CHANGE + 1; pub(crate) const RANGE_CHECK: usize = VIRTUAL_FIRST_CHANGE + 1; // The counter column (used for the range check) starts from 0 and increments. pub(crate) const COUNTER: usize = RANGE_CHECK + 1; -// Helper columns for the permutation argument used to enforce the range check. -pub(crate) const RANGE_CHECK_PERMUTED: usize = COUNTER + 1; -pub(crate) const COUNTER_PERMUTED: usize = RANGE_CHECK_PERMUTED + 1; +// The frequencies column used in logUp. 
+pub(crate) const FREQUENCIES: usize = COUNTER + 1; -pub(crate) const NUM_COLUMNS: usize = COUNTER_PERMUTED + 1; +pub(crate) const NUM_COLUMNS: usize = FREQUENCIES + 1; diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs index 36f7566543..1935af550d 100644 --- a/evm/src/memory/memory_stark.rs +++ b/evm/src/memory/memory_stark.rs @@ -14,14 +14,13 @@ use plonky2_maybe_rayon::*; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cross_table_lookup::Column; -use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols}; +use crate::lookup::Lookup; use crate::memory::columns::{ - value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, CONTEXT_FIRST_CHANGE, COUNTER, - COUNTER_PERMUTED, FILTER, IS_READ, NUM_COLUMNS, RANGE_CHECK, RANGE_CHECK_PERMUTED, - SEGMENT_FIRST_CHANGE, TIMESTAMP, VIRTUAL_FIRST_CHANGE, + value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, CONTEXT_FIRST_CHANGE, COUNTER, FILTER, + FREQUENCIES, IS_READ, NUM_COLUMNS, RANGE_CHECK, SEGMENT_FIRST_CHANGE, TIMESTAMP, + VIRTUAL_FIRST_CHANGE, }; use crate::memory::VALUE_LIMBS; -use crate::permutation::PermutationPair; use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; use crate::witness::memory::MemoryOpKind::Read; @@ -144,10 +143,10 @@ impl, const D: usize> MemoryStark { let height = trace_col_vecs[0].len(); trace_col_vecs[COUNTER] = (0..height).map(|i| F::from_canonical_usize(i)).collect(); - let (permuted_inputs, permuted_table) = - permuted_cols(&trace_col_vecs[RANGE_CHECK], &trace_col_vecs[COUNTER]); - trace_col_vecs[RANGE_CHECK_PERMUTED] = permuted_inputs; - trace_col_vecs[COUNTER_PERMUTED] = permuted_table; + for i in 0..height { + let x = trace_col_vecs[RANGE_CHECK][i].to_canonical_u64() as usize; + trace_col_vecs[FREQUENCIES][x] += F::ONE; + } } /// This memory STARK orders rows by `(context, segment, virt, timestamp)`. 
To enforce the @@ -316,8 +315,6 @@ impl, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark usize { 3 } - fn permutation_pairs(&self) -> Vec { - vec![ - PermutationPair::singletons(RANGE_CHECK, RANGE_CHECK_PERMUTED), - PermutationPair::singletons(COUNTER, COUNTER_PERMUTED), - ] + fn lookups(&self) -> Vec { + vec![Lookup { + columns: vec![RANGE_CHECK], + table_column: COUNTER, + frequencies_column: FREQUENCIES, + }] } } diff --git a/evm/src/permutation.rs b/evm/src/permutation.rs deleted file mode 100644 index 6ce9c9435b..0000000000 --- a/evm/src/permutation.rs +++ /dev/null @@ -1,459 +0,0 @@ -//! Permutation arguments. - -use std::fmt::Debug; - -use itertools::Itertools; -use plonky2::field::batch_util::batch_multiply_inplace; -use plonky2::field::extension::{Extendable, FieldExtension}; -use plonky2::field::packed::PackedField; -use plonky2::field::polynomial::PolynomialValues; -use plonky2::field::types::Field; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::challenger::{Challenger, RecursiveChallenger}; -use plonky2::iop::ext_target::ExtensionTarget; -use plonky2::iop::target::Target; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::config::{AlgebraicHasher, Hasher}; -use plonky2::plonk::plonk_common::{ - reduce_with_powers, reduce_with_powers_circuit, reduce_with_powers_ext_circuit, -}; -use plonky2::util::reducing::{ReducingFactor, ReducingFactorTarget}; -use plonky2::util::serialization::{Buffer, IoResult, Read, Write}; -use plonky2_maybe_rayon::*; - -use crate::config::StarkConfig; -use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::stark::Stark; -use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; - -/// A pair of lists of columns, `lhs` and `rhs`, that should be permutations of one another. -/// In particular, there should exist some permutation `pi` such that for any `i`, -/// `trace[lhs[i]] = pi(trace[rhs[i]])`. 
Here `trace` denotes the trace in column-major form, so -/// `trace[col]` is a column vector. -pub struct PermutationPair { - /// Each entry contains two column indices, representing two columns which should be - /// permutations of one another. - pub column_pairs: Vec<(usize, usize)>, -} - -impl PermutationPair { - pub fn singletons(lhs: usize, rhs: usize) -> Self { - Self { - column_pairs: vec![(lhs, rhs)], - } - } -} - -/// A single instance of a permutation check protocol. -pub(crate) struct PermutationInstance<'a, T: Copy + Eq + PartialEq + Debug> { - pub(crate) pair: &'a PermutationPair, - pub(crate) challenge: GrandProductChallenge, -} - -/// Randomness for a single instance of a permutation check protocol. -#[derive(Copy, Clone, Eq, PartialEq, Debug)] -pub(crate) struct GrandProductChallenge { - /// Randomness used to combine multiple columns into one. - pub(crate) beta: T, - /// Random offset that's added to the beta-reduced column values. - pub(crate) gamma: T, -} - -impl GrandProductChallenge { - pub(crate) fn combine<'a, FE, P, T: IntoIterator, const D2: usize>( - &self, - terms: T, - ) -> P - where - FE: FieldExtension, - P: PackedField, - T::IntoIter: DoubleEndedIterator, - { - reduce_with_powers(terms, FE::from_basefield(self.beta)) + FE::from_basefield(self.gamma) - } -} - -impl GrandProductChallenge { - pub(crate) fn combine_circuit, const D: usize>( - &self, - builder: &mut CircuitBuilder, - terms: &[ExtensionTarget], - ) -> ExtensionTarget { - let reduced = reduce_with_powers_ext_circuit(builder, terms, self.beta); - let gamma = builder.convert_to_ext(self.gamma); - builder.add_extension(reduced, gamma) - } -} - -impl GrandProductChallenge { - pub(crate) fn combine_base_circuit, const D: usize>( - &self, - builder: &mut CircuitBuilder, - terms: &[Target], - ) -> Target { - let reduced = reduce_with_powers_circuit(builder, terms, self.beta); - builder.add(reduced, self.gamma) - } -} - -/// Like `PermutationChallenge`, but with `num_challenges` 
copies to boost soundness. -#[derive(Clone, Eq, PartialEq, Debug)] -pub(crate) struct GrandProductChallengeSet { - pub(crate) challenges: Vec>, -} - -impl GrandProductChallengeSet { - pub fn to_buffer(&self, buffer: &mut Vec) -> IoResult<()> { - buffer.write_usize(self.challenges.len())?; - for challenge in &self.challenges { - buffer.write_target(challenge.beta)?; - buffer.write_target(challenge.gamma)?; - } - Ok(()) - } - - pub fn from_buffer(buffer: &mut Buffer) -> IoResult { - let length = buffer.read_usize()?; - let mut challenges = Vec::with_capacity(length); - for _ in 0..length { - challenges.push(GrandProductChallenge { - beta: buffer.read_target()?, - gamma: buffer.read_target()?, - }); - } - - Ok(GrandProductChallengeSet { challenges }) - } -} - -/// Compute all Z polynomials (for permutation arguments). -pub(crate) fn compute_permutation_z_polys( - stark: &S, - config: &StarkConfig, - trace_poly_values: &[PolynomialValues], - permutation_challenge_sets: &[GrandProductChallengeSet], -) -> Vec> -where - F: RichField + Extendable, - S: Stark, -{ - let permutation_pairs = stark.permutation_pairs(); - let permutation_batches = get_permutation_batches( - &permutation_pairs, - permutation_challenge_sets, - config.num_challenges, - stark.permutation_batch_size(), - ); - - permutation_batches - .into_par_iter() - .map(|instances| compute_permutation_z_poly(&instances, trace_poly_values)) - .collect() -} - -/// Compute a single Z polynomial. 
-fn compute_permutation_z_poly( - instances: &[PermutationInstance], - trace_poly_values: &[PolynomialValues], -) -> PolynomialValues { - let degree = trace_poly_values[0].len(); - let (reduced_lhs_polys, reduced_rhs_polys): (Vec<_>, Vec<_>) = instances - .iter() - .map(|instance| permutation_reduced_polys(instance, trace_poly_values, degree)) - .unzip(); - - let numerator = poly_product_elementwise(reduced_lhs_polys.into_iter()); - let denominator = poly_product_elementwise(reduced_rhs_polys.into_iter()); - - // Compute the quotients. - let denominator_inverses = F::batch_multiplicative_inverse(&denominator.values); - let mut quotients = numerator.values; - batch_multiply_inplace(&mut quotients, &denominator_inverses); - - // Compute Z, which contains partial products of the quotients. - let mut partial_products = Vec::with_capacity(degree); - let mut acc = F::ONE; - for q in quotients { - partial_products.push(acc); - acc *= q; - } - PolynomialValues::new(partial_products) -} - -/// Computes the reduced polynomial, `\sum beta^i f_i(x) + gamma`, for both the "left" and "right" -/// sides of a given `PermutationPair`. -fn permutation_reduced_polys( - instance: &PermutationInstance, - trace_poly_values: &[PolynomialValues], - degree: usize, -) -> (PolynomialValues, PolynomialValues) { - let PermutationInstance { - pair: PermutationPair { column_pairs }, - challenge: GrandProductChallenge { beta, gamma }, - } = instance; - - let mut reduced_lhs = PolynomialValues::constant(*gamma, degree); - let mut reduced_rhs = PolynomialValues::constant(*gamma, degree); - for ((lhs, rhs), weight) in column_pairs.iter().zip(beta.powers()) { - reduced_lhs.add_assign_scaled(&trace_poly_values[*lhs], weight); - reduced_rhs.add_assign_scaled(&trace_poly_values[*rhs], weight); - } - (reduced_lhs, reduced_rhs) -} - -/// Computes the elementwise product of a set of polynomials. Assumes that the set is non-empty and -/// that each polynomial has the same length. 
-fn poly_product_elementwise( - mut polys: impl Iterator>, -) -> PolynomialValues { - let mut product = polys.next().expect("Expected at least one polynomial"); - for poly in polys { - batch_multiply_inplace(&mut product.values, &poly.values) - } - product -} - -fn get_grand_product_challenge>( - challenger: &mut Challenger, -) -> GrandProductChallenge { - let beta = challenger.get_challenge(); - let gamma = challenger.get_challenge(); - GrandProductChallenge { beta, gamma } -} - -pub(crate) fn get_grand_product_challenge_set>( - challenger: &mut Challenger, - num_challenges: usize, -) -> GrandProductChallengeSet { - let challenges = (0..num_challenges) - .map(|_| get_grand_product_challenge(challenger)) - .collect(); - GrandProductChallengeSet { challenges } -} - -pub(crate) fn get_n_grand_product_challenge_sets>( - challenger: &mut Challenger, - num_challenges: usize, - num_sets: usize, -) -> Vec> { - (0..num_sets) - .map(|_| get_grand_product_challenge_set(challenger, num_challenges)) - .collect() -} - -fn get_grand_product_challenge_target< - F: RichField + Extendable, - H: AlgebraicHasher, - const D: usize, ->( - builder: &mut CircuitBuilder, - challenger: &mut RecursiveChallenger, -) -> GrandProductChallenge { - let beta = challenger.get_challenge(builder); - let gamma = challenger.get_challenge(builder); - GrandProductChallenge { beta, gamma } -} - -pub(crate) fn get_grand_product_challenge_set_target< - F: RichField + Extendable, - H: AlgebraicHasher, - const D: usize, ->( - builder: &mut CircuitBuilder, - challenger: &mut RecursiveChallenger, - num_challenges: usize, -) -> GrandProductChallengeSet { - let challenges = (0..num_challenges) - .map(|_| get_grand_product_challenge_target(builder, challenger)) - .collect(); - GrandProductChallengeSet { challenges } -} - -pub(crate) fn get_n_grand_product_challenge_sets_target< - F: RichField + Extendable, - H: AlgebraicHasher, - const D: usize, ->( - builder: &mut CircuitBuilder, - challenger: &mut 
RecursiveChallenger, - num_challenges: usize, - num_sets: usize, -) -> Vec> { - (0..num_sets) - .map(|_| get_grand_product_challenge_set_target(builder, challenger, num_challenges)) - .collect() -} - -/// Get a list of instances of our batch-permutation argument. These are permutation arguments -/// where the same `Z(x)` polynomial is used to check more than one permutation. -/// Before batching, each permutation pair leads to `num_challenges` permutation arguments, so we -/// start with the cartesian product of `permutation_pairs` and `0..num_challenges`. Then we -/// chunk these arguments based on our batch size. -pub(crate) fn get_permutation_batches<'a, T: Copy + Eq + PartialEq + Debug>( - permutation_pairs: &'a [PermutationPair], - permutation_challenge_sets: &[GrandProductChallengeSet], - num_challenges: usize, - batch_size: usize, -) -> Vec>> { - permutation_pairs - .iter() - .cartesian_product(0..num_challenges) - .chunks(batch_size) - .into_iter() - .map(|batch| { - batch - .enumerate() - .map(|(i, (pair, chal))| { - let challenge = permutation_challenge_sets[i].challenges[chal]; - PermutationInstance { pair, challenge } - }) - .collect_vec() - }) - .collect() -} - -pub struct PermutationCheckVars -where - F: Field, - FE: FieldExtension, - P: PackedField, -{ - pub(crate) local_zs: Vec

, - pub(crate) next_zs: Vec

, - pub(crate) permutation_challenge_sets: Vec>, -} - -pub(crate) fn eval_permutation_checks( - stark: &S, - config: &StarkConfig, - vars: StarkEvaluationVars, - permutation_vars: PermutationCheckVars, - consumer: &mut ConstraintConsumer

, -) where - F: RichField + Extendable, - FE: FieldExtension, - P: PackedField, - S: Stark, -{ - let PermutationCheckVars { - local_zs, - next_zs, - permutation_challenge_sets, - } = permutation_vars; - - // Check that Z(1) = 1; - for &z in &local_zs { - consumer.constraint_first_row(z - FE::ONE); - } - - let permutation_pairs = stark.permutation_pairs(); - - let permutation_batches = get_permutation_batches( - &permutation_pairs, - &permutation_challenge_sets, - config.num_challenges, - stark.permutation_batch_size(), - ); - - // Each zs value corresponds to a permutation batch. - for (i, instances) in permutation_batches.iter().enumerate() { - // Z(gx) * down = Z x * up - let (reduced_lhs, reduced_rhs): (Vec

, Vec

) = instances - .iter() - .map(|instance| { - let PermutationInstance { - pair: PermutationPair { column_pairs }, - challenge: GrandProductChallenge { beta, gamma }, - } = instance; - let mut factor = ReducingFactor::new(*beta); - let (lhs, rhs): (Vec<_>, Vec<_>) = column_pairs - .iter() - .map(|&(i, j)| (vars.local_values[i], vars.local_values[j])) - .unzip(); - ( - factor.reduce_ext(lhs.into_iter()) + FE::from_basefield(*gamma), - factor.reduce_ext(rhs.into_iter()) + FE::from_basefield(*gamma), - ) - }) - .unzip(); - let constraint = next_zs[i] * reduced_rhs.into_iter().product::

() - - local_zs[i] * reduced_lhs.into_iter().product::

(); - consumer.constraint(constraint); - } -} - -pub struct PermutationCheckDataTarget { - pub(crate) local_zs: Vec>, - pub(crate) next_zs: Vec>, - pub(crate) permutation_challenge_sets: Vec>, -} - -pub(crate) fn eval_permutation_checks_circuit( - builder: &mut CircuitBuilder, - stark: &S, - config: &StarkConfig, - vars: StarkEvaluationTargets, - permutation_data: PermutationCheckDataTarget, - consumer: &mut RecursiveConstraintConsumer, -) where - F: RichField + Extendable, - S: Stark, - [(); S::COLUMNS]:, -{ - let PermutationCheckDataTarget { - local_zs, - next_zs, - permutation_challenge_sets, - } = permutation_data; - - let one = builder.one_extension(); - // Check that Z(1) = 1; - for &z in &local_zs { - let z_1 = builder.sub_extension(z, one); - consumer.constraint_first_row(builder, z_1); - } - - let permutation_pairs = stark.permutation_pairs(); - - let permutation_batches = get_permutation_batches( - &permutation_pairs, - &permutation_challenge_sets, - config.num_challenges, - stark.permutation_batch_size(), - ); - - // Each zs value corresponds to a permutation batch. 
- for (i, instances) in permutation_batches.iter().enumerate() { - let (reduced_lhs, reduced_rhs): (Vec>, Vec>) = - instances - .iter() - .map(|instance| { - let PermutationInstance { - pair: PermutationPair { column_pairs }, - challenge: GrandProductChallenge { beta, gamma }, - } = instance; - let beta_ext = builder.convert_to_ext(*beta); - let gamma_ext = builder.convert_to_ext(*gamma); - let mut factor = ReducingFactorTarget::new(beta_ext); - let (lhs, rhs): (Vec<_>, Vec<_>) = column_pairs - .iter() - .map(|&(i, j)| (vars.local_values[i], vars.local_values[j])) - .unzip(); - let reduced_lhs = factor.reduce(&lhs, builder); - let reduced_rhs = factor.reduce(&rhs, builder); - ( - builder.add_extension(reduced_lhs, gamma_ext), - builder.add_extension(reduced_rhs, gamma_ext), - ) - }) - .unzip(); - let reduced_lhs_product = builder.mul_many_extension(reduced_lhs); - let reduced_rhs_product = builder.mul_many_extension(reduced_rhs); - // constraint = next_zs[i] * reduced_rhs_product - local_zs[i] * reduced_lhs_product - let constraint = { - let tmp = builder.mul_extension(local_zs[i], reduced_lhs_product); - builder.mul_sub_extension(next_zs[i], reduced_rhs_product, tmp) - }; - consumer.constraint(builder, constraint) - } -} diff --git a/evm/src/proof.rs b/evm/src/proof.rs index 14f22b6791..23446ac4c8 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -18,7 +18,7 @@ use serde::{Deserialize, Serialize}; use crate::all_stark::NUM_TABLES; use crate::config::StarkConfig; -use crate::permutation::GrandProductChallengeSet; +use crate::cross_table_lookup::GrandProductChallengeSet; /// A STARK proof for each table, plus some metadata used to create recursive wrapper proofs. #[derive(Debug, Clone)] @@ -588,8 +588,8 @@ impl ExtraBlockDataTarget { pub struct StarkProof, C: GenericConfig, const D: usize> { /// Merkle cap of LDEs of trace values. pub trace_cap: MerkleCap, - /// Merkle cap of LDEs of permutation Z values. 
- pub permutation_ctl_zs_cap: MerkleCap, + /// Merkle cap of LDEs of lookup helper and CTL columns. + pub auxiliary_polys_cap: MerkleCap, /// Merkle cap of LDEs of quotient polynomial evaluations. pub quotient_polys_cap: MerkleCap, /// Purported values of each polynomial at the challenge point. @@ -630,7 +630,7 @@ impl, C: GenericConfig, const D: usize> S #[derive(Eq, PartialEq, Debug)] pub struct StarkProofTarget { pub trace_cap: MerkleCapTarget, - pub permutation_ctl_zs_cap: MerkleCapTarget, + pub auxiliary_polys_cap: MerkleCapTarget, pub quotient_polys_cap: MerkleCapTarget, pub openings: StarkOpeningSetTarget, pub opening_proof: FriProofTarget, @@ -639,7 +639,7 @@ pub struct StarkProofTarget { impl StarkProofTarget { pub fn to_buffer(&self, buffer: &mut Vec) -> IoResult<()> { buffer.write_target_merkle_cap(&self.trace_cap)?; - buffer.write_target_merkle_cap(&self.permutation_ctl_zs_cap)?; + buffer.write_target_merkle_cap(&self.auxiliary_polys_cap)?; buffer.write_target_merkle_cap(&self.quotient_polys_cap)?; buffer.write_target_fri_proof(&self.opening_proof)?; self.openings.to_buffer(buffer)?; @@ -648,14 +648,14 @@ impl StarkProofTarget { pub fn from_buffer(buffer: &mut Buffer) -> IoResult { let trace_cap = buffer.read_target_merkle_cap()?; - let permutation_ctl_zs_cap = buffer.read_target_merkle_cap()?; + let auxiliary_polys_cap = buffer.read_target_merkle_cap()?; let quotient_polys_cap = buffer.read_target_merkle_cap()?; let opening_proof = buffer.read_target_fri_proof()?; let openings = StarkOpeningSetTarget::from_buffer(buffer)?; Ok(Self { trace_cap, - permutation_ctl_zs_cap, + auxiliary_polys_cap, quotient_polys_cap, openings, opening_proof, @@ -674,8 +674,8 @@ impl StarkProofTarget { } pub(crate) struct StarkProofChallenges, const D: usize> { - /// Randomness used in any permutation arguments. - pub permutation_challenge_sets: Option>>, + /// Randomness used in lookup arguments. 
+ pub lookup_challenges: Option>, /// Random values used to combine STARK constraints. pub stark_alphas: Vec, @@ -687,7 +687,7 @@ pub(crate) struct StarkProofChallenges, const D: us } pub(crate) struct StarkProofChallengesTarget { - pub permutation_challenge_sets: Option>>, + pub lookup_challenges: Option>, pub stark_alphas: Vec, pub stark_zeta: ExtensionTarget, pub fri_challenges: FriChallengesTarget, @@ -700,10 +700,10 @@ pub struct StarkOpeningSet, const D: usize> { pub local_values: Vec, /// Openings of trace polynomials at `g * zeta`. pub next_values: Vec, - /// Openings of permutations and cross-table lookups `Z` polynomials at `zeta`. - pub permutation_ctl_zs: Vec, - /// Openings of permutations and cross-table lookups `Z` polynomials at `g * zeta`. - pub permutation_ctl_zs_next: Vec, + /// Openings of lookups and cross-table lookups `Z` polynomials at `zeta`. + pub auxiliary_polys: Vec, + /// Openings of lookups and cross-table lookups `Z` polynomials at `g * zeta`. + pub auxiliary_polys_next: Vec, /// Openings of cross-table lookups `Z` polynomials at `g^-1`. pub ctl_zs_last: Vec, /// Openings of quotient polynomials at `zeta`. 
@@ -715,10 +715,10 @@ impl, const D: usize> StarkOpeningSet { zeta: F::Extension, g: F, trace_commitment: &PolynomialBatch, - permutation_ctl_zs_commitment: &PolynomialBatch, + auxiliary_polys_commitment: &PolynomialBatch, quotient_commitment: &PolynomialBatch, degree_bits: usize, - num_permutation_zs: usize, + num_lookup_columns: usize, ) -> Self { let eval_commitment = |z: F::Extension, c: &PolynomialBatch| { c.polynomials @@ -736,12 +736,12 @@ impl, const D: usize> StarkOpeningSet { Self { local_values: eval_commitment(zeta, trace_commitment), next_values: eval_commitment(zeta_next, trace_commitment), - permutation_ctl_zs: eval_commitment(zeta, permutation_ctl_zs_commitment), - permutation_ctl_zs_next: eval_commitment(zeta_next, permutation_ctl_zs_commitment), + auxiliary_polys: eval_commitment(zeta, auxiliary_polys_commitment), + auxiliary_polys_next: eval_commitment(zeta_next, auxiliary_polys_commitment), ctl_zs_last: eval_commitment_base( F::primitive_root_of_unity(degree_bits).inverse(), - permutation_ctl_zs_commitment, - )[num_permutation_zs..] + auxiliary_polys_commitment, + )[num_lookup_columns..] 
.to_vec(), quotient_polys: eval_commitment(zeta, quotient_commitment), } @@ -752,7 +752,7 @@ impl, const D: usize> StarkOpeningSet { values: self .local_values .iter() - .chain(&self.permutation_ctl_zs) + .chain(&self.auxiliary_polys) .chain(&self.quotient_polys) .copied() .collect_vec(), @@ -761,7 +761,7 @@ impl, const D: usize> StarkOpeningSet { values: self .next_values .iter() - .chain(&self.permutation_ctl_zs_next) + .chain(&self.auxiliary_polys_next) .copied() .collect_vec(), }; @@ -785,8 +785,8 @@ impl, const D: usize> StarkOpeningSet { pub struct StarkOpeningSetTarget { pub local_values: Vec>, pub next_values: Vec>, - pub permutation_ctl_zs: Vec>, - pub permutation_ctl_zs_next: Vec>, + pub auxiliary_polys: Vec>, + pub auxiliary_polys_next: Vec>, pub ctl_zs_last: Vec, pub quotient_polys: Vec>, } @@ -795,8 +795,8 @@ impl StarkOpeningSetTarget { pub fn to_buffer(&self, buffer: &mut Vec) -> IoResult<()> { buffer.write_target_ext_vec(&self.local_values)?; buffer.write_target_ext_vec(&self.next_values)?; - buffer.write_target_ext_vec(&self.permutation_ctl_zs)?; - buffer.write_target_ext_vec(&self.permutation_ctl_zs_next)?; + buffer.write_target_ext_vec(&self.auxiliary_polys)?; + buffer.write_target_ext_vec(&self.auxiliary_polys_next)?; buffer.write_target_vec(&self.ctl_zs_last)?; buffer.write_target_ext_vec(&self.quotient_polys)?; Ok(()) @@ -805,16 +805,16 @@ impl StarkOpeningSetTarget { pub fn from_buffer(buffer: &mut Buffer) -> IoResult { let local_values = buffer.read_target_ext_vec::()?; let next_values = buffer.read_target_ext_vec::()?; - let permutation_ctl_zs = buffer.read_target_ext_vec::()?; - let permutation_ctl_zs_next = buffer.read_target_ext_vec::()?; + let auxiliary_polys = buffer.read_target_ext_vec::()?; + let auxiliary_polys_next = buffer.read_target_ext_vec::()?; let ctl_zs_last = buffer.read_target_vec()?; let quotient_polys = buffer.read_target_ext_vec::()?; Ok(Self { local_values, next_values, - permutation_ctl_zs, - permutation_ctl_zs_next, 
+ auxiliary_polys, + auxiliary_polys_next, ctl_zs_last, quotient_polys, }) @@ -825,7 +825,7 @@ impl StarkOpeningSetTarget { values: self .local_values .iter() - .chain(&self.permutation_ctl_zs) + .chain(&self.auxiliary_polys) .chain(&self.quotient_polys) .copied() .collect_vec(), @@ -834,7 +834,7 @@ impl StarkOpeningSetTarget { values: self .next_values .iter() - .chain(&self.permutation_ctl_zs_next) + .chain(&self.auxiliary_polys_next) .copied() .collect_vec(), }; diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 8f5878232b..7a8439db45 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -26,18 +26,17 @@ use crate::config::StarkConfig; use crate::constraint_consumer::ConstraintConsumer; use crate::cpu::cpu_stark::CpuStark; use crate::cpu::kernel::aggregator::KERNEL; -use crate::cross_table_lookup::{cross_table_lookup_data, CtlCheckVars, CtlData}; +use crate::cross_table_lookup::{ + cross_table_lookup_data, get_grand_product_challenge_set, CtlCheckVars, CtlData, +}; use crate::generation::outputs::GenerationOutputs; use crate::generation::{generate_traces, GenerationInputs}; use crate::get_challenges::observe_public_values; use crate::keccak::keccak_stark::KeccakStark; use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; use crate::logic::LogicStark; +use crate::lookup::{lookup_helper_columns, Lookup, LookupCheckVars}; use crate::memory::memory_stark::MemoryStark; -use crate::permutation::{ - compute_permutation_z_polys, get_grand_product_challenge_set, - get_n_grand_product_challenge_sets, GrandProductChallengeSet, PermutationCheckVars, -}; use crate::proof::{AllProof, PublicValues, StarkOpeningSet, StarkProof, StarkProofWithMetadata}; use crate::stark::Stark; use crate::vanishing_poly::eval_vanishing_poly; @@ -335,37 +334,45 @@ where let init_challenger_state = challenger.compact(); - // Permutation arguments. 
- let permutation_challenges = stark.uses_permutation_args().then(|| { - get_n_grand_product_challenge_sets( - challenger, - config.num_challenges, - stark.permutation_batch_size(), - ) - }); - let permutation_zs = permutation_challenges.as_ref().map(|challenges| { - timed!( - timing, - "compute permutation Z(x) polys", - compute_permutation_z_polys::(stark, config, trace_poly_values, challenges) - ) - }); - let num_permutation_zs = permutation_zs.as_ref().map(|v| v.len()).unwrap_or(0); + let constraint_degree = stark.constraint_degree(); + let lookup_challenges = stark + .uses_lookups() + .then(|| challenger.get_n_challenges(config.num_challenges)); + let lookups = stark.lookups(); + let lookup_helper_columns = timed!( + timing, + "compute lookup helper columns", + lookup_challenges.as_ref().map(|challenges| { + let mut columns = Vec::new(); + for lookup in &lookups { + for &challenge in challenges { + columns.extend(lookup_helper_columns( + lookup, + trace_poly_values, + challenge, + constraint_degree, + )); + } + } + columns + }) + ); + let num_lookup_columns = lookup_helper_columns.as_ref().map(|v| v.len()).unwrap_or(0); - let z_polys = match permutation_zs { + let auxiliary_polys = match lookup_helper_columns { None => ctl_data.z_polys(), - Some(mut permutation_zs) => { - permutation_zs.extend(ctl_data.z_polys()); - permutation_zs + Some(mut lookup_columns) => { + lookup_columns.extend(ctl_data.z_polys()); + lookup_columns } }; - assert!(!z_polys.is_empty(), "No CTL?"); + assert!(!auxiliary_polys.is_empty(), "No CTL?"); - let permutation_ctl_zs_commitment = timed!( + let auxiliary_polys_commitment = timed!( timing, - "compute Zs commitment", + "compute auxiliary polynomials commitment", PolynomialBatch::from_values( - z_polys, + auxiliary_polys, rate_bits, false, config.fri_config.cap_height, @@ -374,21 +381,21 @@ where ) ); - let permutation_ctl_zs_cap = permutation_ctl_zs_commitment.merkle_tree.cap.clone(); - challenger.observe_cap(&permutation_ctl_zs_cap); 
+ let auxiliary_polys_cap = auxiliary_polys_commitment.merkle_tree.cap.clone(); + challenger.observe_cap(&auxiliary_polys_cap); let alphas = challenger.get_n_challenges(config.num_challenges); if cfg!(test) { check_constraints( stark, trace_commitment, - &permutation_ctl_zs_commitment, - permutation_challenges.as_ref(), + &auxiliary_polys_commitment, + lookup_challenges.as_ref(), + &lookups, ctl_data, alphas.clone(), degree_bits, - num_permutation_zs, - config, + num_lookup_columns, ); } let quotient_polys = timed!( @@ -397,12 +404,13 @@ where compute_quotient_polys::::Packing, C, S, D>( stark, trace_commitment, - &permutation_ctl_zs_commitment, - permutation_challenges.as_ref(), + &auxiliary_polys_commitment, + lookup_challenges.as_ref(), + &lookups, ctl_data, alphas, degree_bits, - num_permutation_zs, + num_lookup_columns, config, ) ); @@ -451,16 +459,16 @@ where zeta, g, trace_commitment, - &permutation_ctl_zs_commitment, + &auxiliary_polys_commitment, "ient_commitment, degree_bits, - stark.num_permutation_batches(config), + stark.num_lookup_helper_columns(config), ); challenger.observe_openings(&openings.to_fri_openings()); let initial_merkle_trees = vec![ trace_commitment, - &permutation_ctl_zs_commitment, + &auxiliary_polys_commitment, "ient_commitment, ]; @@ -478,7 +486,7 @@ where let proof = StarkProof { trace_cap: trace_commitment.merkle_tree.cap.clone(), - permutation_ctl_zs_cap, + auxiliary_polys_cap, quotient_polys_cap, openings, opening_proof, @@ -494,12 +502,13 @@ where fn compute_quotient_polys<'a, F, P, C, S, const D: usize>( stark: &S, trace_commitment: &'a PolynomialBatch, - permutation_ctl_zs_commitment: &'a PolynomialBatch, - permutation_challenges: Option<&'a Vec>>, + auxiliary_polys_commitment: &'a PolynomialBatch, + lookup_challenges: Option<&'a Vec>, + lookups: &[Lookup], ctl_data: &CtlData, alphas: Vec, degree_bits: usize, - num_permutation_zs: usize, + num_lookup_columns: usize, config: &StarkConfig, ) -> Vec> where @@ -570,25 +579,22 @@ 
where local_values: &get_trace_values_packed(i_start), next_values: &get_trace_values_packed(i_next_start), }; - let permutation_check_vars = - permutation_challenges.map(|permutation_challenge_sets| PermutationCheckVars { - local_zs: permutation_ctl_zs_commitment.get_lde_values_packed(i_start, step) - [..num_permutation_zs] - .to_vec(), - next_zs: permutation_ctl_zs_commitment - .get_lde_values_packed(i_next_start, step)[..num_permutation_zs] - .to_vec(), - permutation_challenge_sets: permutation_challenge_sets.to_vec(), - }); + let lookup_vars = lookup_challenges.map(|challenges| LookupCheckVars { + local_values: auxiliary_polys_commitment.get_lde_values_packed(i_start, step) + [..num_lookup_columns] + .to_vec(), + next_values: auxiliary_polys_commitment.get_lde_values_packed(i_next_start, step), + challenges: challenges.to_vec(), + }); let ctl_vars = ctl_data .zs_columns .iter() .enumerate() .map(|(i, zs_columns)| CtlCheckVars:: { - local_z: permutation_ctl_zs_commitment.get_lde_values_packed(i_start, step) - [num_permutation_zs + i], - next_z: permutation_ctl_zs_commitment.get_lde_values_packed(i_next_start, step) - [num_permutation_zs + i], + local_z: auxiliary_polys_commitment.get_lde_values_packed(i_start, step) + [num_lookup_columns + i], + next_z: auxiliary_polys_commitment.get_lde_values_packed(i_next_start, step) + [num_lookup_columns + i], challenges: zs_columns.challenge, columns: &zs_columns.columns, filter_column: &zs_columns.filter_column, @@ -596,9 +602,9 @@ where .collect::>(); eval_vanishing_poly::( stark, - config, vars, - permutation_check_vars, + lookups, + lookup_vars, &ctl_vars, &mut consumer, ); @@ -631,13 +637,13 @@ where fn check_constraints<'a, F, C, S, const D: usize>( stark: &S, trace_commitment: &'a PolynomialBatch, - permutation_ctl_zs_commitment: &'a PolynomialBatch, - permutation_challenges: Option<&'a Vec>>, + auxiliary_commitment: &'a PolynomialBatch, + lookup_challenges: Option<&'a Vec>, + lookups: &[Lookup], ctl_data: &CtlData, 
alphas: Vec, degree_bits: usize, - num_permutation_zs: usize, - config: &StarkConfig, + num_lookup_columns: usize, ) where F: RichField + Extendable, C: GenericConfig, @@ -668,7 +674,7 @@ fn check_constraints<'a, F, C, S, const D: usize>( }; let trace_subgroup_evals = get_subgroup_evals(trace_commitment); - let permutation_ctl_zs_subgroup_evals = get_subgroup_evals(permutation_ctl_zs_commitment); + let auxiliary_subgroup_evals = get_subgroup_evals(auxiliary_commitment); // Last element of the subgroup. let last = F::primitive_root_of_unity(degree_bits).inverse(); @@ -692,21 +698,19 @@ fn check_constraints<'a, F, C, S, const D: usize>( local_values: trace_subgroup_evals[i].as_slice().try_into().unwrap(), next_values: trace_subgroup_evals[i_next].as_slice().try_into().unwrap(), }; - let permutation_check_vars = - permutation_challenges.map(|permutation_challenge_sets| PermutationCheckVars { - local_zs: permutation_ctl_zs_subgroup_evals[i][..num_permutation_zs].to_vec(), - next_zs: permutation_ctl_zs_subgroup_evals[i_next][..num_permutation_zs] - .to_vec(), - permutation_challenge_sets: permutation_challenge_sets.to_vec(), - }); + let lookup_vars = lookup_challenges.map(|challenges| LookupCheckVars { + local_values: auxiliary_subgroup_evals[i][..num_lookup_columns].to_vec(), + next_values: auxiliary_subgroup_evals[i_next][..num_lookup_columns].to_vec(), + challenges: challenges.to_vec(), + }); let ctl_vars = ctl_data .zs_columns .iter() .enumerate() .map(|(iii, zs_columns)| CtlCheckVars:: { - local_z: permutation_ctl_zs_subgroup_evals[i][num_permutation_zs + iii], - next_z: permutation_ctl_zs_subgroup_evals[i_next][num_permutation_zs + iii], + local_z: auxiliary_subgroup_evals[i][num_lookup_columns + iii], + next_z: auxiliary_subgroup_evals[i_next][num_lookup_columns + iii], challenges: zs_columns.challenge, columns: &zs_columns.columns, filter_column: &zs_columns.filter_column, @@ -714,9 +718,9 @@ fn check_constraints<'a, F, C, S, const D: usize>( .collect::>(); 
eval_vanishing_poly::( stark, - config, vars, - permutation_check_vars, + lookups, + lookup_vars, &ctl_vars, &mut consumer, ); diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index e669f4ab35..45cc0c485c 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -29,13 +29,13 @@ use crate::all_stark::{Table, NUM_TABLES}; use crate::config::StarkConfig; use crate::constraint_consumer::RecursiveConstraintConsumer; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; -use crate::cross_table_lookup::{verify_cross_table_lookups, CrossTableLookup, CtlCheckVarsTarget}; +use crate::cross_table_lookup::{ + get_grand_product_challenge_set, verify_cross_table_lookups, CrossTableLookup, + CtlCheckVarsTarget, GrandProductChallenge, GrandProductChallengeSet, +}; +use crate::lookup::LookupCheckVarsTarget; use crate::memory::segments::Segment; use crate::memory::VALUE_LIMBS; -use crate::permutation::{ - get_grand_product_challenge_set, GrandProductChallenge, GrandProductChallengeSet, - PermutationCheckDataTarget, -}; use crate::proof::{ BlockHashes, BlockHashesTarget, BlockMetadata, BlockMetadataTarget, ExtraBlockData, ExtraBlockDataTarget, PublicValues, PublicValuesTarget, StarkOpeningSetTarget, StarkProof, @@ -302,8 +302,7 @@ where let mut builder = CircuitBuilder::::new(circuit_config.clone()); let zero_target = builder.zero(); - let num_permutation_zs = stark.num_permutation_batches(inner_config); - let num_permutation_batch_size = stark.permutation_batch_size(); + let num_lookup_columns = stark.num_lookup_helper_columns(inner_config); let num_ctl_zs = CrossTableLookup::num_ctl_zs(cross_table_lookups, table, inner_config.num_challenges); let proof_target = @@ -331,7 +330,7 @@ where &proof_target, cross_table_lookups, &ctl_challenges_target, - num_permutation_zs, + num_lookup_columns, ); let init_challenger_state_target = @@ -343,8 +342,7 @@ where let challenges = proof_target.get_challenges::( &mut builder, &mut 
challenger, - num_permutation_zs > 0, - num_permutation_batch_size, + num_lookup_columns > 0, inner_config, ); let challenger_state = challenger.compact(&mut builder); @@ -412,8 +410,8 @@ fn verify_stark_proof_with_challenges_circuit< let StarkOpeningSetTarget { local_values, next_values, - permutation_ctl_zs, - permutation_ctl_zs_next, + auxiliary_polys, + auxiliary_polys_next, ctl_zs_last, quotient_polys, } = &proof.openings; @@ -439,14 +437,12 @@ fn verify_stark_proof_with_challenges_circuit< l_last, ); - let num_permutation_zs = stark.num_permutation_batches(inner_config); - let permutation_data = stark - .uses_permutation_args() - .then(|| PermutationCheckDataTarget { - local_zs: permutation_ctl_zs[..num_permutation_zs].to_vec(), - next_zs: permutation_ctl_zs_next[..num_permutation_zs].to_vec(), - permutation_challenge_sets: challenges.permutation_challenge_sets.clone().unwrap(), - }); + let num_lookup_columns = stark.num_lookup_helper_columns(inner_config); + let lookup_vars = stark.uses_lookups().then(|| LookupCheckVarsTarget { + local_values: auxiliary_polys[..num_lookup_columns].to_vec(), + next_values: auxiliary_polys_next[..num_lookup_columns].to_vec(), + challenges: challenges.lookup_challenges.clone().unwrap(), + }); with_context!( builder, @@ -454,9 +450,8 @@ fn verify_stark_proof_with_challenges_circuit< eval_vanishing_poly_circuit::( builder, stark, - inner_config, vars, - permutation_data, + lookup_vars, ctl_vars, &mut consumer, ) @@ -476,7 +471,7 @@ fn verify_stark_proof_with_challenges_circuit< let merkle_caps = vec![ proof.trace_cap.clone(), - proof.permutation_ctl_zs_cap.clone(), + proof.auxiliary_polys_cap.clone(), proof.quotient_polys_cap.clone(), ]; @@ -840,15 +835,15 @@ pub(crate) fn add_virtual_stark_proof< let num_leaves_per_oracle = vec![ S::COLUMNS, - stark.num_permutation_batches(config) + num_ctl_zs, + stark.num_lookup_helper_columns(config) + num_ctl_zs, stark.quotient_degree_factor() * config.num_challenges, ]; - let 
permutation_zs_cap = builder.add_virtual_cap(cap_height); + let auxiliary_polys_cap = builder.add_virtual_cap(cap_height); StarkProofTarget { trace_cap: builder.add_virtual_cap(cap_height), - permutation_ctl_zs_cap: permutation_zs_cap, + auxiliary_polys_cap, quotient_polys_cap: builder.add_virtual_cap(cap_height), openings: add_virtual_stark_opening_set::(builder, stark, num_ctl_zs, config), opening_proof: builder.add_virtual_fri_proof(&num_leaves_per_oracle, &fri_params), @@ -865,10 +860,10 @@ fn add_virtual_stark_opening_set, S: Stark, c StarkOpeningSetTarget { local_values: builder.add_virtual_extension_targets(S::COLUMNS), next_values: builder.add_virtual_extension_targets(S::COLUMNS), - permutation_ctl_zs: builder - .add_virtual_extension_targets(stark.num_permutation_batches(config) + num_ctl_zs), - permutation_ctl_zs_next: builder - .add_virtual_extension_targets(stark.num_permutation_batches(config) + num_ctl_zs), + auxiliary_polys: builder + .add_virtual_extension_targets(stark.num_lookup_helper_columns(config) + num_ctl_zs), + auxiliary_polys_next: builder + .add_virtual_extension_targets(stark.num_lookup_helper_columns(config) + num_ctl_zs), ctl_zs_last: builder.add_virtual_targets(num_ctl_zs), quotient_polys: builder .add_virtual_extension_targets(stark.quotient_degree_factor() * num_challenges), @@ -894,8 +889,8 @@ pub(crate) fn set_stark_proof_target, W, const D: ); witness.set_cap_target( - &proof_target.permutation_ctl_zs_cap, - &proof.permutation_ctl_zs_cap, + &proof_target.auxiliary_polys_cap, + &proof.auxiliary_polys_cap, ); set_fri_proof_target(witness, &proof_target.opening_proof, &proof.opening_proof); diff --git a/evm/src/stark.rs b/evm/src/stark.rs index 72cee0ad60..4e76017db7 100644 --- a/evm/src/stark.rs +++ b/evm/src/stark.rs @@ -8,15 +8,14 @@ use plonky2::fri::structure::{ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; -use 
plonky2_util::ceil_div_usize; use crate::config::StarkConfig; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::permutation::PermutationPair; +use crate::lookup::Lookup; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; const TRACE_ORACLE_INDEX: usize = 0; -const PERMUTATION_CTL_ORACLE_INDEX: usize = 1; +const AUXILIARY_ORACLE_INDEX: usize = 1; const QUOTIENT_ORACLE_INDEX: usize = 2; /// Represents a STARK system. @@ -94,20 +93,18 @@ pub trait Stark, const D: usize>: Sync { }; let trace_info = FriPolynomialInfo::from_range(TRACE_ORACLE_INDEX, 0..Self::COLUMNS); - let num_permutation_batches = self.num_permutation_batches(config); - let num_perutation_ctl_polys = num_permutation_batches + num_ctl_zs; - let permutation_ctl_oracle = FriOracleInfo { - num_polys: num_perutation_ctl_polys, + let num_lookup_columns = self.num_lookup_helper_columns(config); + let num_auxiliary_polys = num_lookup_columns + num_ctl_zs; + let auxiliary_oracle = FriOracleInfo { + num_polys: num_auxiliary_polys, blinding: false, }; - let permutation_ctl_zs_info = FriPolynomialInfo::from_range( - PERMUTATION_CTL_ORACLE_INDEX, - 0..num_perutation_ctl_polys, - ); + let auxiliary_polys_info = + FriPolynomialInfo::from_range(AUXILIARY_ORACLE_INDEX, 0..num_auxiliary_polys); let ctl_zs_info = FriPolynomialInfo::from_range( - PERMUTATION_CTL_ORACLE_INDEX, - num_permutation_batches..num_permutation_batches + num_ctl_zs, + AUXILIARY_ORACLE_INDEX, + num_lookup_columns..num_lookup_columns + num_ctl_zs, ); let num_quotient_polys = self.num_quotient_polys(config); @@ -122,21 +119,21 @@ pub trait Stark, const D: usize>: Sync { point: zeta, polynomials: [ trace_info.clone(), - permutation_ctl_zs_info.clone(), + auxiliary_polys_info.clone(), quotient_info, ] .concat(), }; let zeta_next_batch = FriBatchInfo { point: zeta.scalar_mul(g), - polynomials: [trace_info, permutation_ctl_zs_info].concat(), + polynomials: [trace_info, 
auxiliary_polys_info].concat(), }; let ctl_last_batch = FriBatchInfo { point: F::Extension::primitive_root_of_unity(degree_bits).inverse(), polynomials: ctl_zs_info, }; FriInstanceInfo { - oracles: vec![trace_oracle, permutation_ctl_oracle, quotient_oracle], + oracles: vec![trace_oracle, auxiliary_oracle, quotient_oracle], batches: vec![zeta_batch, zeta_next_batch, ctl_last_batch], } } @@ -157,20 +154,18 @@ pub trait Stark, const D: usize>: Sync { }; let trace_info = FriPolynomialInfo::from_range(TRACE_ORACLE_INDEX, 0..Self::COLUMNS); - let num_permutation_batches = self.num_permutation_batches(inner_config); - let num_perutation_ctl_polys = num_permutation_batches + num_ctl_zs; - let permutation_ctl_oracle = FriOracleInfo { - num_polys: num_perutation_ctl_polys, + let num_lookup_columns = self.num_lookup_helper_columns(inner_config); + let num_auxiliary_polys = num_lookup_columns + num_ctl_zs; + let auxiliary_oracle = FriOracleInfo { + num_polys: num_auxiliary_polys, blinding: false, }; - let permutation_ctl_zs_info = FriPolynomialInfo::from_range( - PERMUTATION_CTL_ORACLE_INDEX, - 0..num_perutation_ctl_polys, - ); + let auxiliary_polys_info = + FriPolynomialInfo::from_range(AUXILIARY_ORACLE_INDEX, 0..num_auxiliary_polys); let ctl_zs_info = FriPolynomialInfo::from_range( - PERMUTATION_CTL_ORACLE_INDEX, - num_permutation_batches..num_permutation_batches + num_ctl_zs, + AUXILIARY_ORACLE_INDEX, + num_lookup_columns..num_lookup_columns + num_ctl_zs, ); let num_quotient_polys = self.num_quotient_polys(inner_config); @@ -185,7 +180,7 @@ pub trait Stark, const D: usize>: Sync { point: zeta, polynomials: [ trace_info.clone(), - permutation_ctl_zs_info.clone(), + auxiliary_polys_info.clone(), quotient_info, ] .concat(), @@ -193,7 +188,7 @@ pub trait Stark, const D: usize>: Sync { let zeta_next = builder.mul_const_extension(g, zeta); let zeta_next_batch = FriBatchInfoTarget { point: zeta_next, - polynomials: [trace_info, permutation_ctl_zs_info].concat(), + polynomials: 
[trace_info, auxiliary_polys_info].concat(), }; let ctl_last_batch = FriBatchInfoTarget { point: builder @@ -201,38 +196,24 @@ pub trait Stark, const D: usize>: Sync { polynomials: ctl_zs_info, }; FriInstanceInfoTarget { - oracles: vec![trace_oracle, permutation_ctl_oracle, quotient_oracle], + oracles: vec![trace_oracle, auxiliary_oracle, quotient_oracle], batches: vec![zeta_batch, zeta_next_batch, ctl_last_batch], } } - /// Pairs of lists of columns that should be permutations of one another. A permutation argument - /// will be used for each such pair. Empty by default. - fn permutation_pairs(&self) -> Vec { + fn lookups(&self) -> Vec { vec![] } - fn uses_permutation_args(&self) -> bool { - !self.permutation_pairs().is_empty() - } - - /// The number of permutation argument instances that can be combined into a single constraint. - fn permutation_batch_size(&self) -> usize { - // The permutation argument constraints look like - // Z(x) \prod(...) = Z(g x) \prod(...) - // where each product has a number of terms equal to the batch size. So our batch size - // should be one less than our constraint degree, which happens to be our quotient degree. 
- self.quotient_degree_factor() - } - - fn num_permutation_instances(&self, config: &StarkConfig) -> usize { - self.permutation_pairs().len() * config.num_challenges + fn num_lookup_helper_columns(&self, config: &StarkConfig) -> usize { + self.lookups() + .iter() + .map(|lookup| lookup.num_helper_columns(self.constraint_degree())) + .sum::() + * config.num_challenges } - fn num_permutation_batches(&self, config: &StarkConfig) -> usize { - ceil_div_usize( - self.num_permutation_instances(config), - self.permutation_batch_size(), - ) + fn uses_lookups(&self) -> bool { + !self.lookups().is_empty() } } diff --git a/evm/src/vanishing_poly.rs b/evm/src/vanishing_poly.rs index 3a2da78c53..21d361674f 100644 --- a/evm/src/vanishing_poly.rs +++ b/evm/src/vanishing_poly.rs @@ -3,24 +3,23 @@ use plonky2::field::packed::PackedField; use plonky2::hash::hash_types::RichField; use plonky2::plonk::circuit_builder::CircuitBuilder; -use crate::config::StarkConfig; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cross_table_lookup::{ eval_cross_table_lookup_checks, eval_cross_table_lookup_checks_circuit, CtlCheckVars, CtlCheckVarsTarget, }; -use crate::permutation::{ - eval_permutation_checks, eval_permutation_checks_circuit, PermutationCheckDataTarget, - PermutationCheckVars, +use crate::lookup::{ + eval_lookups_checks, eval_lookups_checks_circuit, Lookup, LookupCheckVars, + LookupCheckVarsTarget, }; use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; pub(crate) fn eval_vanishing_poly( stark: &S, - config: &StarkConfig, vars: StarkEvaluationVars, - permutation_vars: Option>, + lookups: &[Lookup], + lookup_vars: Option>, ctl_vars: &[CtlCheckVars], consumer: &mut ConstraintConsumer

, ) where @@ -30,14 +29,8 @@ pub(crate) fn eval_vanishing_poly( S: Stark, { stark.eval_packed_generic(vars, consumer); - if let Some(permutation_vars) = permutation_vars { - eval_permutation_checks::( - stark, - config, - vars, - permutation_vars, - consumer, - ); + if let Some(lookup_vars) = lookup_vars { + eval_lookups_checks::(stark, lookups, vars, lookup_vars, consumer); } eval_cross_table_lookup_checks::(vars, ctl_vars, consumer); } @@ -45,9 +38,8 @@ pub(crate) fn eval_vanishing_poly( pub(crate) fn eval_vanishing_poly_circuit( builder: &mut CircuitBuilder, stark: &S, - config: &StarkConfig, vars: StarkEvaluationTargets, - permutation_data: Option>, + lookup_vars: Option>, ctl_vars: &[CtlCheckVarsTarget], consumer: &mut RecursiveConstraintConsumer, ) where @@ -56,15 +48,8 @@ pub(crate) fn eval_vanishing_poly_circuit( [(); S::COLUMNS]:, { stark.eval_ext_circuit(builder, vars, consumer); - if let Some(permutation_data) = permutation_data { - eval_permutation_checks_circuit::( - builder, - stark, - config, - vars, - permutation_data, - consumer, - ); + if let Some(lookup_vars) = lookup_vars { + eval_lookups_checks_circuit::(builder, stark, vars, lookup_vars, consumer); } eval_cross_table_lookup_checks_circuit::(builder, vars, ctl_vars, consumer); } diff --git a/evm/src/verifier.rs b/evm/src/verifier.rs index 1aa9db9782..cf1c2e3659 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -17,14 +17,14 @@ use crate::config::StarkConfig; use crate::constraint_consumer::ConstraintConsumer; use crate::cpu::cpu_stark::CpuStark; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; -use crate::cross_table_lookup::{verify_cross_table_lookups, CtlCheckVars}; +use crate::cross_table_lookup::{verify_cross_table_lookups, CtlCheckVars, GrandProductChallenge}; use crate::keccak::keccak_stark::KeccakStark; use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; use crate::logic::LogicStark; +use crate::lookup::LookupCheckVars; use 
crate::memory::memory_stark::MemoryStark; use crate::memory::segments::Segment; use crate::memory::VALUE_LIMBS; -use crate::permutation::{GrandProductChallenge, PermutationCheckVars}; use crate::proof::{ AllProof, AllProofChallenges, PublicValues, StarkOpeningSet, StarkProof, StarkProofChallenges, }; @@ -52,7 +52,7 @@ where ctl_challenges, } = all_proof.get_challenges(all_stark, config); - let nums_permutation_zs = all_stark.nums_permutation_zs(config); + let num_lookup_columns = all_stark.num_lookups_helper_columns(config); let AllStark { arithmetic_stark, @@ -69,7 +69,7 @@ where &all_proof.stark_proofs, cross_table_lookups, &ctl_challenges, - &nums_permutation_zs, + &num_lookup_columns, ); verify_stark_proof_with_challenges( @@ -306,8 +306,8 @@ where let StarkOpeningSet { local_values, next_values, - permutation_ctl_zs, - permutation_ctl_zs_next, + auxiliary_polys, + auxiliary_polys_next, ctl_zs_last, quotient_polys, } = &proof.openings; @@ -330,17 +330,18 @@ where l_0, l_last, ); - let num_permutation_zs = stark.num_permutation_batches(config); - let permutation_data = stark.uses_permutation_args().then(|| PermutationCheckVars { - local_zs: permutation_ctl_zs[..num_permutation_zs].to_vec(), - next_zs: permutation_ctl_zs_next[..num_permutation_zs].to_vec(), - permutation_challenge_sets: challenges.permutation_challenge_sets.clone().unwrap(), + let num_lookup_columns = stark.num_lookup_helper_columns(config); + let lookup_vars = stark.uses_lookups().then(|| LookupCheckVars { + local_values: auxiliary_polys[..num_lookup_columns].to_vec(), + next_values: auxiliary_polys_next[..num_lookup_columns].to_vec(), + challenges: challenges.lookup_challenges.clone().unwrap(), }); + let lookups = stark.lookups(); eval_vanishing_poly::( stark, - config, vars, - permutation_data, + &lookups, + lookup_vars, ctl_vars, &mut consumer, ); @@ -366,7 +367,7 @@ where let merkle_caps = vec![ proof.trace_cap.clone(), - proof.permutation_ctl_zs_cap.clone(), + 
proof.auxiliary_polys_cap.clone(), proof.quotient_polys_cap.clone(), ]; @@ -402,7 +403,7 @@ where { let StarkProof { trace_cap, - permutation_ctl_zs_cap, + auxiliary_polys_cap, quotient_polys_cap, openings, // The shape of the opening proof will be checked in the FRI verifier (see @@ -413,8 +414,8 @@ where let StarkOpeningSet { local_values, next_values, - permutation_ctl_zs, - permutation_ctl_zs_next, + auxiliary_polys, + auxiliary_polys_next, ctl_zs_last, quotient_polys, } = openings; @@ -422,16 +423,16 @@ where let degree_bits = proof.recover_degree_bits(config); let fri_params = config.fri_params(degree_bits); let cap_height = fri_params.config.cap_height; - let num_zs = num_ctl_zs + stark.num_permutation_batches(config); + let num_auxiliary = num_ctl_zs + stark.num_lookup_helper_columns(config); ensure!(trace_cap.height() == cap_height); - ensure!(permutation_ctl_zs_cap.height() == cap_height); + ensure!(auxiliary_polys_cap.height() == cap_height); ensure!(quotient_polys_cap.height() == cap_height); ensure!(local_values.len() == S::COLUMNS); ensure!(next_values.len() == S::COLUMNS); - ensure!(permutation_ctl_zs.len() == num_zs); - ensure!(permutation_ctl_zs_next.len() == num_zs); + ensure!(auxiliary_polys.len() == num_auxiliary); + ensure!(auxiliary_polys_next.len() == num_auxiliary); ensure!(ctl_zs_last.len() == num_ctl_zs); ensure!(quotient_polys.len() == stark.num_quotient_polys(config)); From c9c0f8b7e5f37792243081bae305a6d928ed4a16 Mon Sep 17 00:00:00 2001 From: Linda Guiga Date: Thu, 14 Sep 2023 10:57:33 +0100 Subject: [PATCH 03/34] Use CTL challenges for logUP + change comments + add assert --- evm/src/arithmetic/arithmetic_stark.rs | 6 ++ evm/src/cpu/kernel/asm/curve/bn254/glv.asm | 97 ------------------- evm/src/cpu/kernel/asm/curve/bn254/msm.asm | 73 -------------- .../kernel/asm/curve/bn254/precomputation.asm | 35 ------- evm/src/get_challenges.rs | 22 +---- evm/src/keccak/keccak_stark.rs | 9 +- evm/src/lookup.rs | 6 +- 
evm/src/memory/memory_stark.rs | 2 +- evm/src/proof.rs | 4 - evm/src/prover.rs | 20 +++- evm/src/recursive_verifier.rs | 17 ++-- evm/src/verifier.rs | 11 ++- 12 files changed, 56 insertions(+), 246 deletions(-) delete mode 100644 evm/src/cpu/kernel/asm/curve/bn254/glv.asm delete mode 100644 evm/src/cpu/kernel/asm/curve/bn254/msm.asm delete mode 100644 evm/src/cpu/kernel/asm/curve/bn254/precomputation.asm diff --git a/evm/src/arithmetic/arithmetic_stark.rs b/evm/src/arithmetic/arithmetic_stark.rs index 9584ab884a..f7269d1e70 100644 --- a/evm/src/arithmetic/arithmetic_stark.rs +++ b/evm/src/arithmetic/arithmetic_stark.rs @@ -126,6 +126,12 @@ impl ArithmeticStark { for col in SHARED_COLS { for i in 0..n_rows { let x = cols[col][i].to_canonical_u64() as usize; + assert!( + x < RANGE_MAX, + "column value {} exceeds the max range value {}", + x, + RANGE_MAX + ); cols[RC_FREQUENCIES][x] += F::ONE; } } diff --git a/evm/src/cpu/kernel/asm/curve/bn254/glv.asm b/evm/src/cpu/kernel/asm/curve/bn254/glv.asm deleted file mode 100644 index c29d8f141d..0000000000 --- a/evm/src/cpu/kernel/asm/curve/bn254/glv.asm +++ /dev/null @@ -1,97 +0,0 @@ -// Inspired by https://github.com/AztecProtocol/weierstrudel/blob/master/huff_modules/endomorphism.huff -// See also Sage code in evm/src/cpu/kernel/tests/ecc/bn_glv_test_data -// Given scalar `k ∈ Bn254::ScalarField`, return `u, k1, k2` with `k1,k2 < 2^127` and such that -// `k = k1 - s*k2` if `u==0` otherwise `k = k1 + s*k2`, where `s` is the scalar value representing the endomorphism. -// In the comments below, N means @BN_SCALAR -// -// Z3 proof that the resulting `k1, k2` satisfy `k1>0`, `k1 < 2^127` and `|k2| < 2^127`. 
-// ```python -// from z3 import Solver, Int, Or, unsat -// q = 0x30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001 -// glv_s = 0xB3C4D79D41A917585BFC41088D8DAAA78B17EA66B99C90DD -// -// b2 = 0x89D3256894D213E3 -// b1 = -0x6F4D8248EEB859FC8211BBEB7D4F1128 -// -// g1 = 0x24CCEF014A773D2CF7A7BD9D4391EB18D -// g2 = 0x2D91D232EC7E0B3D7 -// k = Int("k") -// c1 = Int("c1") -// c2 = Int("c2") -// s = Solver() -// -// c2p = -c2 -// s.add(k < q) -// s.add(0 < k) -// s.add(c1 * (2**256) <= g2 * k) -// s.add((c1 + 1) * (2**256) > g2 * k) -// s.add(c2p * (2**256) <= g1 * k) -// s.add((c2p + 1) * (2**256) > g1 * k) -// -// q1 = c1 * b1 -// q2 = c2 * b2 -// -// k2 = q2 - q1 -// k2L = (glv_s * k2) % q -// k1 = k - k2L -// k2 = -k2 -// -// s.add(Or((k2 >= 2**127), (-k2 >= 2**127), (k1 >= 2**127), (k1 < 0))) -// -// assert s.check() == unsat -// ``` -global bn_glv_decompose: - // stack: k, retdest - PUSH @BN_SCALAR DUP1 DUP1 - // Compute c2 which is the top 256 bits of k*g1. Use asm from https://medium.com/wicketh/mathemagic-full-multiply-27650fec525d. - PUSH @U256_MAX - // stack: -1, N, N, N, k, retdest - PUSH @BN_GLV_MINUS_G1 DUP6 - // stack: k, g1, -1, N, N, N, k, retdest - MULMOD - // stack: (k * g1 % -1), N, N, N, k, retdest - PUSH @BN_GLV_MINUS_G1 DUP6 - // stack: k, g1, (k * g1 % -1), N, N, N, k, retdest - MUL - // stack: bottom = (k * g1), (k * g1 % -1), N, N, N, k, retdest - DUP1 DUP3 - // stack: (k * g1 % -1), bottom, bottom, (k * g1 % -1), N, N, N, k, retdest - LT SWAP2 SUB SUB - // stack: c2, N, N, N, k, retdest - PUSH @BN_GLV_B2 MULMOD - // stack: q2=c2*b2, N, N, k, retdest - - // Use the same trick to compute c1 = top 256 bits of g2*k. 
- PUSH @BN_SCALAR PUSH @U256_MAX - PUSH @BN_GLV_G2 DUP7 MULMOD - PUSH @BN_GLV_G2 DUP7 MUL - DUP1 DUP3 LT - SWAP2 SUB SUB - // stack: c1, N, q2, N, N, k, retdest - PUSH @BN_GLV_B1 MULMOD - // stack: q1, q2, N, N, k, retdest - - // We compute k2 = q1 + q2 - N, but we check for underflow and return N-q1-q2 instead if there is one, - // along with a flag `underflow` set to 1 if there is an underflow, 0 otherwise. - ADD %sub_check_underflow - // stack: k2, underflow, N, k, retdest - SWAP3 PUSH @BN_SCALAR DUP5 PUSH @BN_GLV_S - // stack: s, k2, N, k, underflow, N, k2, retdest - MULMOD - // stack: s*k2, k, underflow, N, k2, retdest - // Need to return `k + s*k2` if no underflow occur, otherwise return `k - s*k2` which is done in the `underflowed` fn. - SWAP2 DUP1 %jumpi(underflowed) - %stack (underflow, k, x, N, k2) -> (k, x, N, k2, underflow) - ADDMOD - %stack (k1, k2, underflow, retdest) -> (retdest, underflow, k1, k2) - JUMP - -underflowed: - // stack: underflow, k, s*k2, N, k2 - // Compute (k-s*k2)%N. TODO: Use SUBMOD here when ready - %stack (u, k, x, N, k2) -> (N, x, k, N, k2, u) - SUB ADDMOD - %stack (k1, k2, underflow, retdest) -> (retdest, underflow, k1, k2) - JUMP - - diff --git a/evm/src/cpu/kernel/asm/curve/bn254/msm.asm b/evm/src/cpu/kernel/asm/curve/bn254/msm.asm deleted file mode 100644 index 1036228737..0000000000 --- a/evm/src/cpu/kernel/asm/curve/bn254/msm.asm +++ /dev/null @@ -1,73 +0,0 @@ -// Computes the multiplication `a*G` using a standard MSM with the GLV decomposition of `a`. -// see there for a detailed description. 
-global bn_msm: - // stack: retdest - PUSH 0 PUSH 0 PUSH 0 -global bn_msm_loop: - // stack: accx, accy, i, retdest - DUP3 %bn_mload_wnaf_a - // stack: w, accx, accy, i, retdest - DUP1 %jumpi(bn_msm_loop_add_a_nonzero) - POP -msm_loop_add_b: - //stack: accx, accy, i, retdest - DUP3 %bn_mload_wnaf_b - // stack: w, accx, accy, i, retdest - DUP1 %jumpi(bn_msm_loop_add_b_nonzero) - POP -msm_loop_contd: - %stack (accx, accy, i, retdest) -> (i, i, accx, accy, retdest) - // TODO: the GLV scalars for the BN curve are 127-bit, so could use 127 here. But this would require modifying `wnaf.asm`. Not sure it's worth it... - %eq_const(129) %jumpi(msm_end) - %increment - //stack: i+1, accx, accy, retdest - %stack (i, accx, accy, retdest) -> (accx, accy, bn_msm_loop, i, retdest) - %jump(bn_double) - -msm_end: - %stack (i, accx, accy, retdest) -> (retdest, accx, accy) - JUMP - -bn_msm_loop_add_a_nonzero: - %stack (w, accx, accy, i, retdest) -> (w, accx, accy, msm_loop_add_b, i, retdest) - %bn_mload_point_a - // stack: px, py, accx, accy, msm_loop_add_b, i, retdest - %jump(bn_add_valid_points) - -bn_msm_loop_add_b_nonzero: - %stack (w, accx, accy, i, retdest) -> (w, accx, accy, msm_loop_contd, i, retdest) - %bn_mload_point_b - // stack: px, py, accx, accy, msm_loop_contd, i, retdest - %jump(bn_add_valid_points) - -%macro bn_mload_wnaf_a - // stack: i - %mload_kernel(@SEGMENT_KERNEL_BN_WNAF_A) -%endmacro - -%macro bn_mload_wnaf_b - // stack: i - %mload_kernel(@SEGMENT_KERNEL_BN_WNAF_B) -%endmacro - -%macro bn_mload_point_a - // stack: w - DUP1 - %mload_kernel(@SEGMENT_KERNEL_BN_TABLE_Q) - //stack: Gy, w - SWAP1 %decrement %mload_kernel(@SEGMENT_KERNEL_BN_TABLE_Q) - //stack: Gx, Gy -%endmacro - -%macro bn_mload_point_b - // stack: w - DUP1 - %mload_kernel(@SEGMENT_KERNEL_BN_TABLE_Q) - PUSH @BN_BNEG_LOC %mload_kernel(@SEGMENT_KERNEL_BN_TABLE_Q) - %stack (bneg, Gy, w) -> (@BN_BASE, Gy, bneg, bneg, Gy, w) - SUB SWAP1 ISZERO MUL SWAP2 MUL ADD - SWAP1 %decrement 
%mload_kernel(@SEGMENT_KERNEL_BN_TABLE_Q) - //stack: Gx, Gy - PUSH @BN_GLV_BETA - MULFP254 -%endmacro diff --git a/evm/src/cpu/kernel/asm/curve/bn254/precomputation.asm b/evm/src/cpu/kernel/asm/curve/bn254/precomputation.asm deleted file mode 100644 index a8c6ada926..0000000000 --- a/evm/src/cpu/kernel/asm/curve/bn254/precomputation.asm +++ /dev/null @@ -1,35 +0,0 @@ -// Precompute a table of multiples of the BN254 point `Q = (Qx, Qy)`. -// Let `(Qxi, Qyi) = i * Q`, then store in the `SEGMENT_KERNEL_BN_TABLE_Q` segment of memory the values -// `i-1 => Qxi`, `i => Qyi if i < 16 else -Qy(32-i)` for `i in range(1, 32, 2)`. -global bn_precompute_table: - // stack: Qx, Qy, retdest - PUSH precompute_table_contd DUP3 DUP3 - %jump(bn_double) -precompute_table_contd: - // stack: Qx2, Qy2, Qx, Qy, retdest - PUSH 1 -bn_precompute_table_loop: - // stack i, Qx2, Qy2, Qx, Qy, retdest - PUSH 1 DUP2 SUB - %stack (im, i, Qx2, Qy2, Qx, Qy, retdest) -> (i, Qy, im, Qx, i, Qx2, Qy2, Qx, Qy, retdest) - %mstore_kernel(@SEGMENT_KERNEL_BN_TABLE_Q) %mstore_kernel(@SEGMENT_KERNEL_BN_TABLE_Q) - // stack: i, Qx2, Qy2, Qx, Qy, retdest - DUP1 PUSH 32 SUB PUSH 1 DUP2 SUB - // stack: 31-i, 32-i, i, Qx2, Qy2, Qx, Qy, retdest - DUP7 PUSH @BN_BASE SUB - // TODO: Could maybe avoid storing Qx a second time here, not sure if it would be more efficient. 
- %stack (Qyy, iii, ii, i, Qx2, Qy2, Qx, Qy, retdest) -> (iii, Qx, ii, Qyy, i, Qx2, Qy2, Qx, Qy, retdest) - %mstore_kernel(@SEGMENT_KERNEL_BN_TABLE_Q) %mstore_kernel(@SEGMENT_KERNEL_BN_TABLE_Q) - // stack: i, Qx2, Qy2, Qx, Qy, retdest - PUSH 2 ADD - // stack: i+2, Qx2, Qy2, Qx, Qy, retdest - DUP1 PUSH 16 LT %jumpi(precompute_table_end) - %stack (i, Qx2, Qy2, Qx, Qy, retdest) -> (Qx, Qy, Qx2, Qy2, precompute_table_loop_contd, i, Qx2, Qy2, retdest) - %jump(bn_add_valid_points) -precompute_table_loop_contd: - %stack (Qx, Qy, i, Qx2, Qy2, retdest) -> (i, Qx2, Qy2, Qx, Qy, retdest) - %jump(bn_precompute_table_loop) - -precompute_table_end: - // stack: i, Qx2, Qy2, Qx, Qy, retdest - %pop5 JUMP diff --git a/evm/src/get_challenges.rs b/evm/src/get_challenges.rs index a12aaa9548..e1f2eddb8a 100644 --- a/evm/src/get_challenges.rs +++ b/evm/src/get_challenges.rs @@ -205,11 +205,7 @@ pub(crate) fn observe_public_values_target< impl, C: GenericConfig, const D: usize> AllProof { /// Computes all Fiat-Shamir challenges used in the STARK proof. 
- pub(crate) fn get_challenges( - &self, - all_stark: &AllStark, - config: &StarkConfig, - ) -> AllProofChallenges { + pub(crate) fn get_challenges(&self, config: &StarkConfig) -> AllProofChallenges { let mut challenger = Challenger::::new(); for proof in &self.stark_proofs { @@ -221,14 +217,12 @@ impl, C: GenericConfig, const D: usize> A let ctl_challenges = get_grand_product_challenge_set(&mut challenger, config.num_challenges); - let lookups = all_stark.num_lookups_helper_columns(config); - AllProofChallenges { stark_challenges: core::array::from_fn(|i| { challenger.compact(); self.stark_proofs[i] .proof - .get_challenges(&mut challenger, lookups[i] > 0, config) + .get_challenges(&mut challenger, config) }), ctl_challenges, } @@ -257,7 +251,7 @@ impl, C: GenericConfig, const D: usize> A for i in 0..NUM_TABLES { self.stark_proofs[i] .proof - .get_challenges(&mut challenger, lookups[i] > 0, config); + .get_challenges(&mut challenger, config); challenger_states.push(challenger.compact()); } @@ -277,7 +271,6 @@ where pub(crate) fn get_challenges( &self, challenger: &mut Challenger, - stark_use_lookup: bool, config: &StarkConfig, ) -> StarkProofChallenges { let degree_bits = self.recover_degree_bits(config); @@ -298,9 +291,6 @@ where let num_challenges = config.num_challenges; - let lookup_challenges = - stark_use_lookup.then(|| challenger.get_n_challenges(config.num_challenges)); - challenger.observe_cap(auxiliary_polys_cap); let stark_alphas = challenger.get_n_challenges(num_challenges); @@ -311,7 +301,6 @@ where challenger.observe_openings(&openings.to_fri_openings()); StarkProofChallenges { - lookup_challenges, stark_alphas, stark_zeta, fri_challenges: challenger.fri_challenges::( @@ -330,7 +319,6 @@ impl StarkProofTarget { &self, builder: &mut CircuitBuilder, challenger: &mut RecursiveChallenger, - stark_use_lookup: bool, config: &StarkConfig, ) -> StarkProofChallengesTarget where @@ -352,9 +340,6 @@ impl StarkProofTarget { let num_challenges = 
config.num_challenges; - let lookup_challenges = - stark_use_lookup.then(|| challenger.get_n_challenges(builder, num_challenges)); - challenger.observe_cap(auxiliary_polys); let stark_alphas = challenger.get_n_challenges(builder, num_challenges); @@ -365,7 +350,6 @@ impl StarkProofTarget { challenger.observe_openings(&openings.to_fri_openings(builder.zero())); StarkProofChallengesTarget { - lookup_challenges, stark_alphas, stark_zeta, fri_challenges: challenger.fri_challenges( diff --git a/evm/src/keccak/keccak_stark.rs b/evm/src/keccak/keccak_stark.rs index b4ff4b84ed..e1b4bbb899 100644 --- a/evm/src/keccak/keccak_stark.rs +++ b/evm/src/keccak/keccak_stark.rs @@ -650,7 +650,9 @@ mod tests { use tiny_keccak::keccakf; use crate::config::StarkConfig; - use crate::cross_table_lookup::{CtlData, CtlZData, GrandProductChallenge}; + use crate::cross_table_lookup::{ + CtlData, CtlZData, GrandProductChallenge, GrandProductChallengeSet, + }; use crate::keccak::columns::reg_output_limb; use crate::keccak::keccak_stark::{KeccakStark, NUM_INPUTS, NUM_ROUNDS}; use crate::prover::prove_single_table; @@ -766,7 +768,7 @@ mod tests { filter_column: None, }; let ctl_data = CtlData { - zs_columns: vec![ctl_z_data; config.num_challenges], + zs_columns: vec![ctl_z_data.clone(); config.num_challenges], }; prove_single_table( @@ -775,6 +777,9 @@ mod tests { &trace_poly_values, &trace_commitments, &ctl_data, + GrandProductChallengeSet { + challenges: vec![ctl_z_data.challenge; config.num_challenges], + }, &mut Challenger::new(), &mut timing, )?; diff --git a/evm/src/lookup.rs b/evm/src/lookup.rs index ad872a799e..2aa45a9d04 100644 --- a/evm/src/lookup.rs +++ b/evm/src/lookup.rs @@ -34,7 +34,7 @@ impl Lookup { } } -/// logUp protocol from https://ia.cr/2022/1530 (TODO link to newer version?) +/// logUp protocol from https://ia.cr/2022/1530 /// Compute the helper columns for the lookup argument. 
/// Given columns `f0,...,fk` and a column `t`, such that `∪fi ⊆ t`, and challenges `x`, /// this computes the helper columns `h_i = 1/(x+f_2i) + 1/(x+f_2i+1)`, `g = 1/(x+t)`, @@ -55,10 +55,10 @@ pub(crate) fn lookup_helper_columns( // For each batch of `constraint_degree-1` columns `fi`, compute `sum 1/(f_i+challenge)` and // add it to the helper columns. // TODO: This does one batch inversion per column. It would also be possible to do one batch inversion - // for every column, but that would require building a big vector of all the columns concatenated. + // for every group of columns, but that would require building a big vector of all the columns concatenated. // Not sure which approach is better. // Note: these are the h_k(x) polynomials in the paper, with a few differences: - // * Here, the first ratio m_0(x)/phi_0(x) is not included with the columns batched up to create the + // * Here, the first ratio m_0(x)/phi_0(x) is not included with the columns batched up to create the // h_k polynomials; instead there's a separate helper column for it (see below). // * Here, we use 1 instead of -1 as the numerator (and subtract later). // * Here, for now, the batch size (l) is always constraint_degree - 1 = 2. diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs index 1935af550d..08c092a766 100644 --- a/evm/src/memory/memory_stark.rs +++ b/evm/src/memory/memory_stark.rs @@ -137,7 +137,7 @@ impl, const D: usize> MemoryStark { trace_rows } - /// Generates the `COUNTER`, `RANGE_CHECK_PERMUTED` and `COUNTER_PERMUTED` columns, given a + /// Generates the `COUNTER`, `RANGE_CHECK` and `FREQUENCIES` columns, given a /// trace in column-major form. 
fn generate_trace_col_major(trace_col_vecs: &mut [Vec]) { let height = trace_col_vecs[0].len(); diff --git a/evm/src/proof.rs b/evm/src/proof.rs index 23446ac4c8..72515f627d 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -674,9 +674,6 @@ impl StarkProofTarget { } pub(crate) struct StarkProofChallenges, const D: usize> { - /// Randomness used in lookup arguments. - pub lookup_challenges: Option>, - /// Random values used to combine STARK constraints. pub stark_alphas: Vec, @@ -687,7 +684,6 @@ pub(crate) struct StarkProofChallenges, const D: us } pub(crate) struct StarkProofChallengesTarget { - pub lookup_challenges: Option>, pub stark_alphas: Vec, pub stark_zeta: ExtensionTarget, pub fri_challenges: FriChallengesTarget, diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 7a8439db45..77ba45224f 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -28,6 +28,7 @@ use crate::cpu::cpu_stark::CpuStark; use crate::cpu::kernel::aggregator::KERNEL; use crate::cross_table_lookup::{ cross_table_lookup_data, get_grand_product_challenge_set, CtlCheckVars, CtlData, + GrandProductChallengeSet, }; use crate::generation::outputs::GenerationOutputs; use crate::generation::{generate_traces, GenerationInputs}; @@ -172,6 +173,7 @@ where trace_commitments, ctl_data_per_table, &mut challenger, + ctl_challenges.clone(), timing )? ); @@ -190,6 +192,7 @@ fn prove_with_commitments( trace_commitments: Vec>, ctl_data_per_table: [CtlData; NUM_TABLES], challenger: &mut Challenger, + ctl_challenges: GrandProductChallengeSet, timing: &mut TimingTree, ) -> Result<[StarkProofWithMetadata; NUM_TABLES]> where @@ -212,6 +215,7 @@ where &trace_poly_values[Table::Arithmetic as usize], &trace_commitments[Table::Arithmetic as usize], &ctl_data_per_table[Table::Arithmetic as usize], + ctl_challenges.clone(), challenger, timing, )? 
@@ -238,6 +242,7 @@ where &trace_poly_values[Table::Cpu as usize], &trace_commitments[Table::Cpu as usize], &ctl_data_per_table[Table::Cpu as usize], + ctl_challenges.clone(), challenger, timing, )? @@ -251,6 +256,7 @@ where &trace_poly_values[Table::Keccak as usize], &trace_commitments[Table::Keccak as usize], &ctl_data_per_table[Table::Keccak as usize], + ctl_challenges.clone(), challenger, timing, )? @@ -264,6 +270,7 @@ where &trace_poly_values[Table::KeccakSponge as usize], &trace_commitments[Table::KeccakSponge as usize], &ctl_data_per_table[Table::KeccakSponge as usize], + ctl_challenges.clone(), challenger, timing, )? @@ -277,6 +284,7 @@ where &trace_poly_values[Table::Logic as usize], &trace_commitments[Table::Logic as usize], &ctl_data_per_table[Table::Logic as usize], + ctl_challenges.clone(), challenger, timing, )? @@ -290,6 +298,7 @@ where &trace_poly_values[Table::Memory as usize], &trace_commitments[Table::Memory as usize], &ctl_data_per_table[Table::Memory as usize], + ctl_challenges, challenger, timing, )? 
@@ -313,6 +322,7 @@ pub(crate) fn prove_single_table( trace_poly_values: &[PolynomialValues], trace_commitment: &PolynomialBatch, ctl_data: &CtlData, + ctl_challenges: GrandProductChallengeSet, challenger: &mut Challenger, timing: &mut TimingTree, ) -> Result> @@ -335,9 +345,13 @@ where let init_challenger_state = challenger.compact(); let constraint_degree = stark.constraint_degree(); - let lookup_challenges = stark - .uses_lookups() - .then(|| challenger.get_n_challenges(config.num_challenges)); + let lookup_challenges = stark.uses_lookups().then(|| { + ctl_challenges + .challenges + .iter() + .map(|ch| ch.beta) + .collect::>() + }); let lookups = stark.lookups(); let lookup_helper_columns = timed!( timing, diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index 45cc0c485c..6abd61e770 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -339,12 +339,8 @@ where })); let mut challenger = RecursiveChallenger::::from_state(init_challenger_state_target); - let challenges = proof_target.get_challenges::( - &mut builder, - &mut challenger, - num_lookup_columns > 0, - inner_config, - ); + let challenges = + proof_target.get_challenges::(&mut builder, &mut challenger, inner_config); let challenger_state = challenger.compact(&mut builder); builder.register_public_inputs(challenger_state.as_ref()); @@ -438,10 +434,17 @@ fn verify_stark_proof_with_challenges_circuit< ); let num_lookup_columns = stark.num_lookup_helper_columns(inner_config); + let lookup_challenges = (num_lookup_columns > 0).then(|| { + ctl_vars + .iter() + .map(|ch| ch.challenges.beta) + .collect::>() + }); + let lookup_vars = stark.uses_lookups().then(|| LookupCheckVarsTarget { local_values: auxiliary_polys[..num_lookup_columns].to_vec(), next_values: auxiliary_polys_next[..num_lookup_columns].to_vec(), - challenges: challenges.lookup_challenges.clone().unwrap(), + challenges: lookup_challenges.unwrap(), }); with_context!( diff --git 
a/evm/src/verifier.rs b/evm/src/verifier.rs index cf1c2e3659..7195486ac2 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -50,7 +50,7 @@ where let AllProofChallenges { stark_challenges, ctl_challenges, - } = all_proof.get_challenges(all_stark, config); + } = all_proof.get_challenges(config); let num_lookup_columns = all_stark.num_lookups_helper_columns(config); @@ -331,10 +331,17 @@ where l_last, ); let num_lookup_columns = stark.num_lookup_helper_columns(config); + let lookup_challenges = (num_lookup_columns > 0).then(|| { + ctl_vars + .iter() + .map(|ch| ch.challenges.beta) + .collect::>() + }); + let lookup_vars = stark.uses_lookups().then(|| LookupCheckVars { local_values: auxiliary_polys[..num_lookup_columns].to_vec(), next_values: auxiliary_polys_next[..num_lookup_columns].to_vec(), - challenges: challenges.lookup_challenges.clone().unwrap(), + challenges: lookup_challenges.unwrap(), }); let lookups = stark.lookups(); eval_vanishing_poly::( From 17f661f90fa8520a76f318a8f264a889c036c4eb Mon Sep 17 00:00:00 2001 From: Linda Guiga Date: Thu, 14 Sep 2023 17:35:29 +0100 Subject: [PATCH 04/34] Fix BytePacking range-check. 
Fix lookup challenges --- evm/src/byte_packing/byte_packing_stark.rs | 39 ++++++++++++---------- evm/src/byte_packing/columns.rs | 5 ++- evm/src/keccak/keccak_stark.rs | 2 +- evm/src/prover.rs | 17 +++++----- evm/src/verifier.rs | 17 ++++++++-- 5 files changed, 48 insertions(+), 32 deletions(-) diff --git a/evm/src/byte_packing/byte_packing_stark.rs b/evm/src/byte_packing/byte_packing_stark.rs index f97a2b28ab..9fba11bcdd 100644 --- a/evm/src/byte_packing/byte_packing_stark.rs +++ b/evm/src/byte_packing/byte_packing_stark.rs @@ -47,11 +47,11 @@ use plonky2::util::transpose; use super::NUM_BYTES; use crate::byte_packing::columns::{ index_bytes, value_bytes, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, BYTE_INDICES_COLS, IS_READ, - NUM_COLUMNS, RANGE_COUNTER, RC_COLS, SEQUENCE_END, SEQUENCE_LEN, TIMESTAMP, + NUM_COLUMNS, RANGE_COUNTER, RC_FREQUENCIES, SEQUENCE_END, SEQUENCE_LEN, TIMESTAMP, }; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cross_table_lookup::Column; -use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols}; +use crate::lookup::Lookup; use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; use crate::witness::memory::MemoryAddress; @@ -240,11 +240,18 @@ impl, const D: usize> BytePackingStark { // For each column c in cols, generate the range-check // permutations and put them in the corresponding range-check // columns rc_c and rc_c+1. 
- for (i, rc_c) in (0..NUM_BYTES).zip(RC_COLS.step_by(2)) { - let c = value_bytes(i); - let (col_perm, table_perm) = permuted_cols(&cols[c], &cols[RANGE_COUNTER]); - cols[rc_c].copy_from_slice(&col_perm); - cols[rc_c + 1].copy_from_slice(&table_perm); + for col in 0..NUM_BYTES { + for i in 0..n_rows { + let c = value_bytes(col); + let x = cols[c][i].to_canonical_u64() as usize; + assert!( + x < BYTE_RANGE_MAX, + "column value {} exceeds the max range value {}", + x, + BYTE_RANGE_MAX + ); + cols[RC_FREQUENCIES][x] += F::ONE; + } } } @@ -291,11 +298,6 @@ impl, const D: usize> Stark for BytePackingSt FE: FieldExtension, P: PackedField, { - // Range check all the columns - for col in RC_COLS.step_by(2) { - eval_lookups(vars, yield_constr, col, col + 1); - } - let one = P::ONES; // We filter active columns by summing all the byte indices. @@ -417,11 +419,6 @@ impl, const D: usize> Stark for BytePackingSt vars: StarkEvaluationTargets, yield_constr: &mut RecursiveConstraintConsumer, ) { - // Range check all the columns - for col in RC_COLS.step_by(2) { - eval_lookups_circuit(builder, vars, yield_constr, col, col + 1); - } - // We filter active columns by summing all the byte indices. // Constraining each of them to be boolean is done later on below. 
let current_filter = builder.add_many_extension(&vars.local_values[BYTE_INDICES_COLS]); @@ -569,6 +566,14 @@ impl, const D: usize> Stark for BytePackingSt fn constraint_degree(&self) -> usize { 3 } + + fn lookups(&self) -> Vec { + vec![Lookup { + columns: (value_bytes(0)..value_bytes(NUM_BYTES)).collect(), + table_column: RANGE_COUNTER, + frequencies_column: RC_FREQUENCIES, + }] + } } #[cfg(test)] diff --git a/evm/src/byte_packing/columns.rs b/evm/src/byte_packing/columns.rs index f04f450c51..1e62b7e8a1 100644 --- a/evm/src/byte_packing/columns.rs +++ b/evm/src/byte_packing/columns.rs @@ -42,7 +42,6 @@ pub(crate) const fn value_bytes(i: usize) -> usize { // The two permutations associated to the byte in column i will be in // columns RC_COLS[2i] and RC_COLS[2i+1]. pub(crate) const RANGE_COUNTER: usize = BYTES_VALUES_START + NUM_BYTES; -pub(crate) const NUM_RANGE_CHECK_COLS: usize = 1 + 2 * NUM_BYTES; -pub(crate) const RC_COLS: Range = RANGE_COUNTER + 1..RANGE_COUNTER + NUM_RANGE_CHECK_COLS; +pub(crate) const RC_FREQUENCIES: usize = RANGE_COUNTER + 1; -pub(crate) const NUM_COLUMNS: usize = RANGE_COUNTER + NUM_RANGE_CHECK_COLS; +pub(crate) const NUM_COLUMNS: usize = RANGE_COUNTER + 2; diff --git a/evm/src/keccak/keccak_stark.rs b/evm/src/keccak/keccak_stark.rs index e1b4bbb899..329917dc4e 100644 --- a/evm/src/keccak/keccak_stark.rs +++ b/evm/src/keccak/keccak_stark.rs @@ -777,7 +777,7 @@ mod tests { &trace_poly_values, &trace_commitments, &ctl_data, - GrandProductChallengeSet { + &GrandProductChallengeSet { challenges: vec![ctl_z_data.challenge; config.num_challenges], }, &mut Challenger::new(), diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 77ba45224f..10172005ac 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -173,7 +173,7 @@ where trace_commitments, ctl_data_per_table, &mut challenger, - ctl_challenges.clone(), + &ctl_challenges, timing )? 
); @@ -192,7 +192,7 @@ fn prove_with_commitments( trace_commitments: Vec>, ctl_data_per_table: [CtlData; NUM_TABLES], challenger: &mut Challenger, - ctl_challenges: GrandProductChallengeSet, + ctl_challenges: &GrandProductChallengeSet, timing: &mut TimingTree, ) -> Result<[StarkProofWithMetadata; NUM_TABLES]> where @@ -215,7 +215,7 @@ where &trace_poly_values[Table::Arithmetic as usize], &trace_commitments[Table::Arithmetic as usize], &ctl_data_per_table[Table::Arithmetic as usize], - ctl_challenges.clone(), + ctl_challenges, challenger, timing, )? @@ -229,6 +229,7 @@ where &trace_poly_values[Table::BytePacking as usize], &trace_commitments[Table::BytePacking as usize], &ctl_data_per_table[Table::BytePacking as usize], + ctl_challenges, challenger, timing, )? @@ -242,7 +243,7 @@ where &trace_poly_values[Table::Cpu as usize], &trace_commitments[Table::Cpu as usize], &ctl_data_per_table[Table::Cpu as usize], - ctl_challenges.clone(), + ctl_challenges, challenger, timing, )? @@ -256,7 +257,7 @@ where &trace_poly_values[Table::Keccak as usize], &trace_commitments[Table::Keccak as usize], &ctl_data_per_table[Table::Keccak as usize], - ctl_challenges.clone(), + ctl_challenges, challenger, timing, )? @@ -270,7 +271,7 @@ where &trace_poly_values[Table::KeccakSponge as usize], &trace_commitments[Table::KeccakSponge as usize], &ctl_data_per_table[Table::KeccakSponge as usize], - ctl_challenges.clone(), + ctl_challenges, challenger, timing, )? @@ -284,7 +285,7 @@ where &trace_poly_values[Table::Logic as usize], &trace_commitments[Table::Logic as usize], &ctl_data_per_table[Table::Logic as usize], - ctl_challenges.clone(), + ctl_challenges, challenger, timing, )? 
@@ -322,7 +323,7 @@ pub(crate) fn prove_single_table( trace_poly_values: &[PolynomialValues], trace_commitment: &PolynomialBatch, ctl_data: &CtlData, - ctl_challenges: GrandProductChallengeSet, + ctl_challenges: &GrandProductChallengeSet, challenger: &mut Challenger, timing: &mut TimingTree, ) -> Result> diff --git a/evm/src/verifier.rs b/evm/src/verifier.rs index 7195486ac2..e4c277fd40 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -17,7 +17,9 @@ use crate::config::StarkConfig; use crate::constraint_consumer::ConstraintConsumer; use crate::cpu::cpu_stark::CpuStark; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; -use crate::cross_table_lookup::{verify_cross_table_lookups, CtlCheckVars, GrandProductChallenge}; +use crate::cross_table_lookup::{ + verify_cross_table_lookups, CtlCheckVars, GrandProductChallenge, GrandProductChallengeSet, +}; use crate::keccak::keccak_stark::KeccakStark; use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; use crate::logic::LogicStark; @@ -77,6 +79,7 @@ where &all_proof.stark_proofs[Table::Arithmetic as usize].proof, &stark_challenges[Table::Arithmetic as usize], &ctl_vars_per_table[Table::Arithmetic as usize], + &ctl_challenges, config, )?; verify_stark_proof_with_challenges( @@ -84,6 +87,7 @@ where &all_proof.stark_proofs[Table::BytePacking as usize].proof, &stark_challenges[Table::BytePacking as usize], &ctl_vars_per_table[Table::BytePacking as usize], + &ctl_challenges, config, )?; verify_stark_proof_with_challenges( @@ -91,6 +95,7 @@ where &all_proof.stark_proofs[Table::Cpu as usize].proof, &stark_challenges[Table::Cpu as usize], &ctl_vars_per_table[Table::Cpu as usize], + &ctl_challenges, config, )?; verify_stark_proof_with_challenges( @@ -98,6 +103,7 @@ where &all_proof.stark_proofs[Table::Keccak as usize].proof, &stark_challenges[Table::Keccak as usize], &ctl_vars_per_table[Table::Keccak as usize], + &ctl_challenges, config, )?; verify_stark_proof_with_challenges( @@ -105,6 +111,7 
@@ where &all_proof.stark_proofs[Table::KeccakSponge as usize].proof, &stark_challenges[Table::KeccakSponge as usize], &ctl_vars_per_table[Table::KeccakSponge as usize], + &ctl_challenges, config, )?; verify_stark_proof_with_challenges( @@ -112,6 +119,7 @@ where &all_proof.stark_proofs[Table::Logic as usize].proof, &stark_challenges[Table::Logic as usize], &ctl_vars_per_table[Table::Logic as usize], + &ctl_challenges, config, )?; verify_stark_proof_with_challenges( @@ -119,6 +127,7 @@ where &all_proof.stark_proofs[Table::Memory as usize].proof, &stark_challenges[Table::Memory as usize], &ctl_vars_per_table[Table::Memory as usize], + &ctl_challenges, config, )?; @@ -296,6 +305,7 @@ pub(crate) fn verify_stark_proof_with_challenges< proof: &StarkProof, challenges: &StarkProofChallenges, ctl_vars: &[CtlCheckVars], + ctl_challenges: &GrandProductChallengeSet, config: &StarkConfig, ) -> Result<()> where @@ -332,9 +342,10 @@ where ); let num_lookup_columns = stark.num_lookup_helper_columns(config); let lookup_challenges = (num_lookup_columns > 0).then(|| { - ctl_vars + ctl_challenges + .challenges .iter() - .map(|ch| ch.challenges.beta) + .map(|ch| ch.beta) .collect::>() }); From 9ab8a11887312160d35d59ef1f29523fd7a27be0 Mon Sep 17 00:00:00 2001 From: Linda Guiga Date: Thu, 14 Sep 2023 19:15:43 +0100 Subject: [PATCH 05/34] Remove one helper function --- evm/src/lookup.rs | 45 ++++++++++++++++++++------------------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/evm/src/lookup.rs b/evm/src/lookup.rs index 2aa45a9d04..1c25700b1f 100644 --- a/evm/src/lookup.rs +++ b/evm/src/lookup.rs @@ -30,7 +30,7 @@ impl Lookup { pub(crate) fn num_helper_columns(&self, constraint_degree: usize) -> usize { // One helper column for each column batch of size `constraint_degree-1`, // then one column for the inverse of `table + challenge` and one for the `Z` polynomial. 
- ceil_div_usize(self.columns.len(), constraint_degree - 1) + 2 + ceil_div_usize(self.columns.len(), constraint_degree - 1) + 1 } } @@ -89,7 +89,7 @@ pub(crate) fn lookup_helper_columns( for x in table.iter_mut() { *x = challenge + *x; } - helper_columns.push(F::batch_multiplicative_inverse(&table).into()); + let table_inverse: Vec = F::batch_multiplicative_inverse(&table).into(); // Compute the `Z` polynomial with `Z(1)=0` and `Z(gx) = Z(x) + sum h_i(x) - frequencies(x)g(x)`. // This enforces the check from the paper, that the sum of the h_k(x) polynomials is 0 over H. @@ -99,11 +99,11 @@ pub(crate) fn lookup_helper_columns( let mut z = Vec::with_capacity(frequencies.len()); z.push(F::ZERO); for i in 0..frequencies.len() - 1 { - let x = helper_columns[..num_helper_columns - 2] + let x = helper_columns[..num_helper_columns - 1] .iter() .map(|col| col.values[i]) .sum::() - - frequencies[i] * helper_columns[num_helper_columns - 2].values[i]; + - frequencies[i] * table_inverse[i]; z.push(z[i] + x); } helper_columns.push(z.into()); @@ -158,20 +158,17 @@ pub(crate) fn eval_lookups_checks( _ => todo!("Allow other constraint degrees."), } } - // Check that the penultimate helper column contains `1/(table+challenge)`. - let x = lookup_vars.local_values[start + num_helper_columns - 2]; - let x = x * (vars.local_values[lookup.table_column] + challenge); - yield_constr.constraint(x - P::ONES); // Check the `Z` polynomial. 
let z = lookup_vars.local_values[start + num_helper_columns - 1]; let next_z = lookup_vars.next_values[start + num_helper_columns - 1]; - let y = lookup_vars.local_values[start..start + num_helper_columns - 2] + let table_with_challenge = vars.local_values[lookup.table_column] + challenge; + let y = lookup_vars.local_values[start..start + num_helper_columns - 1] .iter() .fold(P::ZEROS, |acc, x| acc + *x) - - vars.local_values[lookup.frequencies_column] - * lookup_vars.local_values[start + num_helper_columns - 2]; - yield_constr.constraint(next_z - z - y); + * table_with_challenge + - vars.local_values[lookup.frequencies_column]; + yield_constr.constraint((next_z - z) * table_with_challenge - y); start += num_helper_columns; } } @@ -224,23 +221,21 @@ pub(crate) fn eval_lookups_checks_circuit< _ => todo!("Allow other constraint degrees."), } } - let x = lookup_vars.local_values[start + num_helper_columns - 2]; - let tmp = builder.add_extension(vars.local_values[lookup.table_column], challenge); - let x = builder.mul_sub_extension(x, tmp, one); - yield_constr.constraint(builder, x); let z = lookup_vars.local_values[start + num_helper_columns - 1]; let next_z = lookup_vars.next_values[start + num_helper_columns - 1]; - let y = builder.add_many_extension( - &lookup_vars.local_values[start..start + num_helper_columns - 2], + let table_with_challenge = + builder.add_extension(vars.local_values[lookup.table_column], challenge); + let mut y = builder.add_many_extension( + &lookup_vars.local_values[start..start + num_helper_columns - 1], ); - let tmp = builder.mul_extension( - vars.local_values[lookup.frequencies_column], - lookup_vars.local_values[start + num_helper_columns - 2], - ); - let y = builder.sub_extension(y, tmp); - let constraint = builder.sub_extension(next_z, z); - let constraint = builder.sub_extension(constraint, y); + + y = builder.mul_extension(y, table_with_challenge); + y = builder.sub_extension(y, vars.local_values[lookup.frequencies_column]); + + let 
mut constraint = builder.sub_extension(next_z, z); + constraint = builder.mul_extension(constraint, table_with_challenge); + constraint = builder.sub_extension(constraint, y); yield_constr.constraint(builder, constraint); start += num_helper_columns; } From c5af894e3fbf3c6d87a931fbdf3c82c872b64bd8 Mon Sep 17 00:00:00 2001 From: Linda Guiga Date: Thu, 14 Sep 2023 22:57:52 +0100 Subject: [PATCH 06/34] Add assert with char(F). Cleanup. Fix recursive challenges. --- evm/src/lookup.rs | 9 +++++++-- evm/src/recursive_verifier.rs | 7 +++++-- evm/src/vanishing_poly.rs | 12 +++++++++--- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/evm/src/lookup.rs b/evm/src/lookup.rs index 1c25700b1f..b886e73657 100644 --- a/evm/src/lookup.rs +++ b/evm/src/lookup.rs @@ -1,4 +1,5 @@ use itertools::Itertools; +use num_bigint::BigUint; use plonky2::field::batch_util::batch_add_inplace; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::packed::PackedField; @@ -49,6 +50,10 @@ pub(crate) fn lookup_helper_columns( constraint_degree, 3, "TODO: Allow other constraint degrees." ); + + let num_total_logup_entries = trace_poly_values[0].values.len() * lookup.columns.len(); + assert!(BigUint::from(num_total_logup_entries) < F::characteristic()); + let num_helper_columns = lookup.num_helper_columns(constraint_degree); let mut helper_columns: Vec> = Vec::with_capacity(num_helper_columns); @@ -123,7 +128,7 @@ where } /// Constraints for the logUp lookup argument. 
-pub(crate) fn eval_lookups_checks( +pub(crate) fn eval_packed_lookups_generic( stark: &S, lookups: &[Lookup], vars: StarkEvaluationVars, @@ -180,7 +185,7 @@ pub struct LookupCheckVarsTarget { pub(crate) challenges: Vec, } -pub(crate) fn eval_lookups_checks_circuit< +pub(crate) fn eval_ext_lookups_circuit< F: RichField + Extendable, S: Stark, const D: usize, diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index 6abd61e770..3539f35be4 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -352,6 +352,7 @@ where &proof_target, &challenges, &ctl_vars, + &ctl_challenges_target, inner_config, ); @@ -395,6 +396,7 @@ fn verify_stark_proof_with_challenges_circuit< proof: &StarkProofTarget, challenges: &StarkProofChallengesTarget, ctl_vars: &[CtlCheckVarsTarget], + ctl_challenges: &GrandProductChallengeSet, inner_config: &StarkConfig, ) where C::Hasher: AlgebraicHasher, @@ -435,9 +437,10 @@ fn verify_stark_proof_with_challenges_circuit< let num_lookup_columns = stark.num_lookup_helper_columns(inner_config); let lookup_challenges = (num_lookup_columns > 0).then(|| { - ctl_vars + ctl_challenges + .challenges .iter() - .map(|ch| ch.challenges.beta) + .map(|ch| ch.beta) .collect::>() }); diff --git a/evm/src/vanishing_poly.rs b/evm/src/vanishing_poly.rs index 21d361674f..a395112037 100644 --- a/evm/src/vanishing_poly.rs +++ b/evm/src/vanishing_poly.rs @@ -9,7 +9,7 @@ use crate::cross_table_lookup::{ CtlCheckVarsTarget, }; use crate::lookup::{ - eval_lookups_checks, eval_lookups_checks_circuit, Lookup, LookupCheckVars, + eval_ext_lookups_circuit, eval_packed_lookups_generic, Lookup, LookupCheckVars, LookupCheckVarsTarget, }; use crate::stark::Stark; @@ -30,7 +30,13 @@ pub(crate) fn eval_vanishing_poly( { stark.eval_packed_generic(vars, consumer); if let Some(lookup_vars) = lookup_vars { - eval_lookups_checks::(stark, lookups, vars, lookup_vars, consumer); + eval_packed_lookups_generic::( + stark, + lookups, + vars, + 
lookup_vars, + consumer, + ); } eval_cross_table_lookup_checks::(vars, ctl_vars, consumer); } @@ -49,7 +55,7 @@ pub(crate) fn eval_vanishing_poly_circuit( { stark.eval_ext_circuit(builder, vars, consumer); if let Some(lookup_vars) = lookup_vars { - eval_lookups_checks_circuit::(builder, stark, vars, lookup_vars, consumer); + eval_ext_lookups_circuit::(builder, stark, vars, lookup_vars, consumer); } eval_cross_table_lookup_checks_circuit::(builder, vars, ctl_vars, consumer); } From 7dc2a7744d0762ab2833d9e6eaa86e718d92dbcc Mon Sep 17 00:00:00 2001 From: Linda Guiga Date: Fri, 15 Sep 2023 09:00:11 +0100 Subject: [PATCH 07/34] Cleanup --- evm/src/lookup.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evm/src/lookup.rs b/evm/src/lookup.rs index b886e73657..31ba4ae281 100644 --- a/evm/src/lookup.rs +++ b/evm/src/lookup.rs @@ -94,7 +94,7 @@ pub(crate) fn lookup_helper_columns( for x in table.iter_mut() { *x = challenge + *x; } - let table_inverse: Vec = F::batch_multiplicative_inverse(&table).into(); + let table_inverse: Vec = F::batch_multiplicative_inverse(&table); // Compute the `Z` polynomial with `Z(1)=0` and `Z(gx) = Z(x) + sum h_i(x) - frequencies(x)g(x)`. // This enforces the check from the paper, that the sum of the h_k(x) polynomials is 0 over H. 
From 9697c906f23c3ede53b2b495ab0db590305a0a58 Mon Sep 17 00:00:00 2001 From: Linda Guiga Date: Fri, 15 Sep 2023 09:16:06 +0100 Subject: [PATCH 08/34] Clippy --- evm/src/arithmetic/arithmetic_stark.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/evm/src/arithmetic/arithmetic_stark.rs b/evm/src/arithmetic/arithmetic_stark.rs index f7269d1e70..58f7c400ad 100644 --- a/evm/src/arithmetic/arithmetic_stark.rs +++ b/evm/src/arithmetic/arithmetic_stark.rs @@ -1,7 +1,6 @@ use std::marker::PhantomData; use std::ops::Range; -use itertools::Itertools; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; From 66f935a7488e0bc862725102aac7a55bf6d83d64 Mon Sep 17 00:00:00 2001 From: Linda Guiga Date: Fri, 15 Sep 2023 15:34:06 +0100 Subject: [PATCH 09/34] Remove where clauses: [(); CpuStark::::COLUMNS] --- evm/src/fixed_recursive_verifier.rs | 2 -- evm/src/prover.rs | 5 ----- evm/src/verifier.rs | 2 -- 3 files changed, 9 deletions(-) diff --git a/evm/src/fixed_recursive_verifier.rs b/evm/src/fixed_recursive_verifier.rs index 33021c3952..b16fde3aa5 100644 --- a/evm/src/fixed_recursive_verifier.rs +++ b/evm/src/fixed_recursive_verifier.rs @@ -31,7 +31,6 @@ use crate::all_stark::{all_cross_table_lookups, AllStark, Table, NUM_TABLES}; use crate::arithmetic::arithmetic_stark::ArithmeticStark; use crate::byte_packing::byte_packing_stark::BytePackingStark; use crate::config::StarkConfig; -use crate::cpu::cpu_stark::CpuStark; use crate::cross_table_lookup::{ get_grand_product_challenge_set_target, verify_cross_table_lookups_circuit, CrossTableLookup, GrandProductChallengeSet, @@ -302,7 +301,6 @@ where C::Hasher: AlgebraicHasher, [(); ArithmeticStark::::COLUMNS]:, [(); BytePackingStark::::COLUMNS]:, - [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 
10172005ac..3486727d3d 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -24,7 +24,6 @@ use crate::arithmetic::arithmetic_stark::ArithmeticStark; use crate::byte_packing::byte_packing_stark::BytePackingStark; use crate::config::StarkConfig; use crate::constraint_consumer::ConstraintConsumer; -use crate::cpu::cpu_stark::CpuStark; use crate::cpu::kernel::aggregator::KERNEL; use crate::cross_table_lookup::{ cross_table_lookup_data, get_grand_product_challenge_set, CtlCheckVars, CtlData, @@ -55,7 +54,6 @@ where C: GenericConfig, [(); ArithmeticStark::::COLUMNS]:, [(); BytePackingStark::::COLUMNS]:, - [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, @@ -78,7 +76,6 @@ where C: GenericConfig, [(); ArithmeticStark::::COLUMNS]:, [(); BytePackingStark::::COLUMNS]:, - [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, @@ -107,7 +104,6 @@ where C: GenericConfig, [(); ArithmeticStark::::COLUMNS]:, [(); BytePackingStark::::COLUMNS]:, - [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, @@ -200,7 +196,6 @@ where C: GenericConfig, [(); ArithmeticStark::::COLUMNS]:, [(); BytePackingStark::::COLUMNS]:, - [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, diff --git a/evm/src/verifier.rs b/evm/src/verifier.rs index e4c277fd40..fa21e56655 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -15,7 +15,6 @@ use crate::arithmetic::arithmetic_stark::ArithmeticStark; use crate::byte_packing::byte_packing_stark::BytePackingStark; use crate::config::StarkConfig; use crate::constraint_consumer::ConstraintConsumer; -use crate::cpu::cpu_stark::CpuStark; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cross_table_lookup::{ verify_cross_table_lookups, 
CtlCheckVars, GrandProductChallenge, GrandProductChallengeSet, @@ -43,7 +42,6 @@ pub fn verify_proof, C: GenericConfig, co where [(); ArithmeticStark::::COLUMNS]:, [(); BytePackingStark::::COLUMNS]:, - [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, From ec9e6196781bd71ca4d46df5922d0a7a7179920d Mon Sep 17 00:00:00 2001 From: Robin Salen Date: Fri, 15 Sep 2023 10:57:32 -0400 Subject: [PATCH 10/34] Fix range --- evm/src/byte_packing/byte_packing_stark.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evm/src/byte_packing/byte_packing_stark.rs b/evm/src/byte_packing/byte_packing_stark.rs index 9fba11bcdd..fb941dc7b1 100644 --- a/evm/src/byte_packing/byte_packing_stark.rs +++ b/evm/src/byte_packing/byte_packing_stark.rs @@ -569,7 +569,7 @@ impl, const D: usize> Stark for BytePackingSt fn lookups(&self) -> Vec { vec![Lookup { - columns: (value_bytes(0)..value_bytes(NUM_BYTES)).collect(), + columns: (value_bytes(0)..value_bytes(0) + NUM_BYTES).collect(), table_column: RANGE_COUNTER, frequencies_column: RC_FREQUENCIES, }] From 4f0330adea5b68bfc9600adc9ae2f49689d99528 Mon Sep 17 00:00:00 2001 From: Robin Salen Date: Fri, 15 Sep 2023 11:12:14 -0400 Subject: [PATCH 11/34] Update clippy in CI --- .github/workflows/continuous-integration-workflow.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/continuous-integration-workflow.yml b/.github/workflows/continuous-integration-workflow.yml index ba2ad1bd9b..a0ac3ec727 100644 --- a/.github/workflows/continuous-integration-workflow.yml +++ b/.github/workflows/continuous-integration-workflow.yml @@ -124,5 +124,5 @@ jobs: command: clippy args: --all-features --all-targets -- -D warnings -A incomplete-features env: - CARGO_INCREMENTAL: 1 - + # Seems necessary until https://github.com/rust-lang/rust/pull/115819 is merged. 
+ CARGO_INCREMENTAL: 0 From a9b7b5a62f2ab03018d25bec7fdf99aa27c5caad Mon Sep 17 00:00:00 2001 From: Robin Salen Date: Fri, 15 Sep 2023 17:53:44 -0400 Subject: [PATCH 12/34] Revert "Remove where clauses: [(); CpuStark::::COLUMNS]" This reverts commit 66f935a7488e0bc862725102aac7a55bf6d83d64. --- evm/src/fixed_recursive_verifier.rs | 2 ++ evm/src/prover.rs | 5 +++++ evm/src/verifier.rs | 2 ++ 3 files changed, 9 insertions(+) diff --git a/evm/src/fixed_recursive_verifier.rs b/evm/src/fixed_recursive_verifier.rs index b16fde3aa5..33021c3952 100644 --- a/evm/src/fixed_recursive_verifier.rs +++ b/evm/src/fixed_recursive_verifier.rs @@ -31,6 +31,7 @@ use crate::all_stark::{all_cross_table_lookups, AllStark, Table, NUM_TABLES}; use crate::arithmetic::arithmetic_stark::ArithmeticStark; use crate::byte_packing::byte_packing_stark::BytePackingStark; use crate::config::StarkConfig; +use crate::cpu::cpu_stark::CpuStark; use crate::cross_table_lookup::{ get_grand_product_challenge_set_target, verify_cross_table_lookups_circuit, CrossTableLookup, GrandProductChallengeSet, @@ -301,6 +302,7 @@ where C::Hasher: AlgebraicHasher, [(); ArithmeticStark::::COLUMNS]:, [(); BytePackingStark::::COLUMNS]:, + [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 3486727d3d..10172005ac 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -24,6 +24,7 @@ use crate::arithmetic::arithmetic_stark::ArithmeticStark; use crate::byte_packing::byte_packing_stark::BytePackingStark; use crate::config::StarkConfig; use crate::constraint_consumer::ConstraintConsumer; +use crate::cpu::cpu_stark::CpuStark; use crate::cpu::kernel::aggregator::KERNEL; use crate::cross_table_lookup::{ cross_table_lookup_data, get_grand_product_challenge_set, CtlCheckVars, CtlData, @@ -54,6 +55,7 @@ where C: GenericConfig, [(); ArithmeticStark::::COLUMNS]:, [(); BytePackingStark::::COLUMNS]:, + 
[(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, @@ -76,6 +78,7 @@ where C: GenericConfig, [(); ArithmeticStark::::COLUMNS]:, [(); BytePackingStark::::COLUMNS]:, + [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, @@ -104,6 +107,7 @@ where C: GenericConfig, [(); ArithmeticStark::::COLUMNS]:, [(); BytePackingStark::::COLUMNS]:, + [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, @@ -196,6 +200,7 @@ where C: GenericConfig, [(); ArithmeticStark::::COLUMNS]:, [(); BytePackingStark::::COLUMNS]:, + [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, diff --git a/evm/src/verifier.rs b/evm/src/verifier.rs index fa21e56655..e4c277fd40 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -15,6 +15,7 @@ use crate::arithmetic::arithmetic_stark::ArithmeticStark; use crate::byte_packing::byte_packing_stark::BytePackingStark; use crate::config::StarkConfig; use crate::constraint_consumer::ConstraintConsumer; +use crate::cpu::cpu_stark::CpuStark; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cross_table_lookup::{ verify_cross_table_lookups, CtlCheckVars, GrandProductChallenge, GrandProductChallengeSet, @@ -42,6 +43,7 @@ pub fn verify_proof, C: GenericConfig, co where [(); ArithmeticStark::::COLUMNS]:, [(); BytePackingStark::::COLUMNS]:, + [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, From f438d45f069d615e5b00266199996e7eeeb0b193 Mon Sep 17 00:00:00 2001 From: Linda Guiga Date: Wed, 20 Sep 2023 12:45:14 -0400 Subject: [PATCH 13/34] Merge branch 'main' into 'new-logup'. 
--- evm/src/byte_packing/byte_packing_stark.rs | 81 +++++++++------------- evm/src/byte_packing/columns.rs | 9 +-- evm/src/cpu/kernel/asm/memory/syscalls.asm | 80 ++++----------------- evm/src/cross_table_lookup.rs | 4 +- evm/src/fixed_recursive_verifier.rs | 14 ++-- evm/src/proof.rs | 34 +++++++++ plonky2/src/plonk/prover.rs | 29 ++++++-- 7 files changed, 113 insertions(+), 138 deletions(-) diff --git a/evm/src/byte_packing/byte_packing_stark.rs b/evm/src/byte_packing/byte_packing_stark.rs index fb941dc7b1..3e8d35bb63 100644 --- a/evm/src/byte_packing/byte_packing_stark.rs +++ b/evm/src/byte_packing/byte_packing_stark.rs @@ -47,7 +47,7 @@ use plonky2::util::transpose; use super::NUM_BYTES; use crate::byte_packing::columns::{ index_bytes, value_bytes, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, BYTE_INDICES_COLS, IS_READ, - NUM_COLUMNS, RANGE_COUNTER, RC_FREQUENCIES, SEQUENCE_END, SEQUENCE_LEN, TIMESTAMP, + NUM_COLUMNS, RANGE_COUNTER, RC_FREQUENCIES, SEQUENCE_END, TIMESTAMP, }; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cross_table_lookup::Column; @@ -76,15 +76,16 @@ pub(crate) fn ctl_looked_data() -> Vec> { }) .collect(); - Column::singles([ - ADDR_CONTEXT, - ADDR_SEGMENT, - ADDR_VIRTUAL, - SEQUENCE_LEN, - TIMESTAMP, - ]) - .chain(outputs) - .collect() + // This will correspond to the actual sequence length when the `SEQUENCE_END` flag is on. 
+ let sequence_len: Column = Column::linear_combination( + (0..NUM_BYTES).map(|i| (index_bytes(i), F::from_canonical_usize(i + 1))), + ); + + Column::singles([ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]) + .chain([sequence_len]) + .chain(Column::singles(&[TIMESTAMP])) + .chain(outputs) + .collect() } pub fn ctl_looked_filter() -> Column { @@ -202,7 +203,6 @@ impl, const D: usize> BytePackingStark { row[ADDR_VIRTUAL] = F::from_canonical_usize(virt + bytes.len() - 1); row[TIMESTAMP] = F::from_canonical_usize(timestamp); - row[SEQUENCE_LEN] = F::from_canonical_usize(bytes.len()); for (i, &byte) in bytes.iter().rev().enumerate() { if i == bytes.len() - 1 { @@ -356,27 +356,20 @@ impl, const D: usize> Stark for BytePackingSt current_sequence_end * next_filter * (next_sequence_start - one), ); - // The remaining length of a byte sequence must decrease by one or be zero. - let current_sequence_length = vars.local_values[SEQUENCE_LEN]; + // The active position in a byte sequence must increase by one on every row + // or be one on the next row (i.e. at the start of a new sequence). let current_position = self.get_active_position(vars.local_values); let next_position = self.get_active_position(vars.next_values); - let current_remaining_length = current_sequence_length - current_position; - let next_sequence_length = vars.next_values[SEQUENCE_LEN]; - let next_remaining_length = next_sequence_length - next_position; yield_constr.constraint_transition( - current_remaining_length * (current_remaining_length - next_remaining_length - one), - ); - - // At the start of a sequence, the remaining length must be equal to the starting length minus one - yield_constr.constraint( - current_sequence_start * (current_sequence_length - current_remaining_length - one), + next_filter * (next_position - one) * (next_position - current_position - one), ); - // The remaining length on the last row must be zero. 
- yield_constr.constraint_last_row(current_remaining_length); + // The last row must be the end of a sequence or a padding row. + yield_constr.constraint_last_row(current_filter * (current_sequence_end - one)); - // If the current remaining length is zero, the end flag must be one. - yield_constr.constraint(current_remaining_length * current_sequence_end); + // If the next position is one in an active row, the current end flag must be one. + yield_constr + .constraint_transition(next_filter * current_sequence_end * (next_position - one)); // The context, segment and timestamp fields must remain unchanged throughout a byte sequence. // The virtual address must decrement by one at each step of a sequence. @@ -486,36 +479,26 @@ impl, const D: usize> Stark for BytePackingSt let constraint = builder.mul_extension(next_filter, constraint); yield_constr.constraint_transition(builder, constraint); - // The remaining length of a byte sequence must decrease by one or be zero. - let current_sequence_length = vars.local_values[SEQUENCE_LEN]; - let next_sequence_length = vars.next_values[SEQUENCE_LEN]; + // The active position in a byte sequence must increase by one on every row + // or be one on the next row (i.e. at the start of a new sequence). 
let current_position = self.get_active_position_circuit(builder, vars.local_values); let next_position = self.get_active_position_circuit(builder, vars.next_values); - let current_remaining_length = - builder.sub_extension(current_sequence_length, current_position); - let next_remaining_length = builder.sub_extension(next_sequence_length, next_position); - let length_diff = builder.sub_extension(current_remaining_length, next_remaining_length); - let constraint = builder.mul_sub_extension( - current_remaining_length, - length_diff, - current_remaining_length, - ); + let position_diff = builder.sub_extension(next_position, current_position); + let is_new_or_inactive = builder.mul_sub_extension(next_filter, next_position, next_filter); + let constraint = + builder.mul_sub_extension(is_new_or_inactive, position_diff, is_new_or_inactive); yield_constr.constraint_transition(builder, constraint); - // At the start of a sequence, the remaining length must be equal to the starting length minus one - let current_sequence_length = vars.local_values[SEQUENCE_LEN]; - let length_diff = builder.sub_extension(current_sequence_length, current_remaining_length); + // The last row must be the end of a sequence or a padding row. let constraint = - builder.mul_sub_extension(current_sequence_start, length_diff, current_sequence_start); - yield_constr.constraint(builder, constraint); + builder.mul_sub_extension(current_filter, current_sequence_end, current_filter); + yield_constr.constraint_last_row(builder, constraint); - // The remaining length on the last row must be zero. - yield_constr.constraint_last_row(builder, current_remaining_length); - - // If the current remaining length is zero, the end flag must be one. - let constraint = builder.mul_extension(current_remaining_length, current_sequence_end); - yield_constr.constraint(builder, constraint); + // If the next position is one in an active row, the current end flag must be one. 
+ let constraint = builder.mul_extension(next_filter, current_sequence_end); + let constraint = builder.mul_sub_extension(constraint, next_position, constraint); + yield_constr.constraint_transition(builder, constraint); // The context, segment and timestamp fields must remain unchanged throughout a byte sequence. // The virtual address must decrement by one at each step of a sequence. diff --git a/evm/src/byte_packing/columns.rs b/evm/src/byte_packing/columns.rs index 1e62b7e8a1..fdaa46211f 100644 --- a/evm/src/byte_packing/columns.rs +++ b/evm/src/byte_packing/columns.rs @@ -16,7 +16,8 @@ pub(crate) const fn index_bytes(i: usize) -> usize { BYTES_INDICES_START + i } -// Note: Those are used as filter for distinguishing active vs padding rows. +// Note: Those are used as filter for distinguishing active vs padding rows, +// and also to obtain the length of a sequence of bytes being processed. pub(crate) const BYTE_INDICES_COLS: Range = BYTES_INDICES_START..BYTES_INDICES_START + NUM_BYTES; @@ -25,12 +26,8 @@ pub(crate) const ADDR_SEGMENT: usize = ADDR_CONTEXT + 1; pub(crate) const ADDR_VIRTUAL: usize = ADDR_SEGMENT + 1; pub(crate) const TIMESTAMP: usize = ADDR_VIRTUAL + 1; -/// The total length of a sequence of bytes. -/// Cannot be greater than 32. -pub(crate) const SEQUENCE_LEN: usize = TIMESTAMP + 1; - // 32 byte limbs hold a total of 256 bits. 
-const BYTES_VALUES_START: usize = SEQUENCE_LEN + 1; +const BYTES_VALUES_START: usize = TIMESTAMP + 1; pub(crate) const fn value_bytes(i: usize) -> usize { debug_assert!(i < NUM_BYTES); BYTES_VALUES_START + i diff --git a/evm/src/cpu/kernel/asm/memory/syscalls.asm b/evm/src/cpu/kernel/asm/memory/syscalls.asm index 5f02382f41..3548930c36 100644 --- a/evm/src/cpu/kernel/asm/memory/syscalls.asm +++ b/evm/src/cpu/kernel/asm/memory/syscalls.asm @@ -8,41 +8,12 @@ global sys_mload: // stack: expanded_num_bytes, kexit_info, offset %update_mem_bytes // stack: kexit_info, offset - PUSH 0 // acc = 0 - // stack: acc, kexit_info, offset - DUP3 %add_const( 0) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xf8) ADD - DUP3 %add_const( 1) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xf0) ADD - DUP3 %add_const( 2) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xe8) ADD - DUP3 %add_const( 3) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xe0) ADD - DUP3 %add_const( 4) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xd8) ADD - DUP3 %add_const( 5) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xd0) ADD - DUP3 %add_const( 6) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xc8) ADD - DUP3 %add_const( 7) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xc0) ADD - DUP3 %add_const( 8) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xb8) ADD - DUP3 %add_const( 9) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xb0) ADD - DUP3 %add_const(10) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xa8) ADD - DUP3 %add_const(11) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xa0) ADD - DUP3 %add_const(12) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x98) ADD - DUP3 %add_const(13) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x90) ADD - DUP3 %add_const(14) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x88) ADD - DUP3 %add_const(15) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x80) ADD - DUP3 %add_const(16) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x78) ADD - DUP3 
%add_const(17) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x70) ADD - DUP3 %add_const(18) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x68) ADD - DUP3 %add_const(19) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x60) ADD - DUP3 %add_const(20) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x58) ADD - DUP3 %add_const(21) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x50) ADD - DUP3 %add_const(22) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x48) ADD - DUP3 %add_const(23) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x40) ADD - DUP3 %add_const(24) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x38) ADD - DUP3 %add_const(25) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x30) ADD - DUP3 %add_const(26) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x28) ADD - DUP3 %add_const(27) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x20) ADD - DUP3 %add_const(28) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x18) ADD - DUP3 %add_const(29) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x10) ADD - DUP3 %add_const(30) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x08) ADD - DUP3 %add_const(31) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x00) ADD - %stack (acc, kexit_info, offset) -> (kexit_info, acc) + %stack(kexit_info, offset) -> (offset, 32, kexit_info) + PUSH @SEGMENT_MAIN_MEMORY + GET_CONTEXT + // stack: addr: 3, len, kexit_info + MLOAD_32BYTES + %stack (value, kexit_info) -> (kexit_info, value) EXIT_KERNEL global sys_mstore: @@ -55,39 +26,12 @@ global sys_mstore: // stack: expanded_num_bytes, kexit_info, offset, value %update_mem_bytes // stack: kexit_info, offset, value - DUP3 PUSH 0 BYTE DUP3 %add_const( 0) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 1 BYTE DUP3 %add_const( 1) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 2 BYTE DUP3 %add_const( 2) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 3 BYTE DUP3 %add_const( 3) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 4 BYTE DUP3 %add_const( 4) 
%mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 5 BYTE DUP3 %add_const( 5) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 6 BYTE DUP3 %add_const( 6) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 7 BYTE DUP3 %add_const( 7) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 8 BYTE DUP3 %add_const( 8) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 9 BYTE DUP3 %add_const( 9) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 10 BYTE DUP3 %add_const(10) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 11 BYTE DUP3 %add_const(11) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 12 BYTE DUP3 %add_const(12) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 13 BYTE DUP3 %add_const(13) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 14 BYTE DUP3 %add_const(14) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 15 BYTE DUP3 %add_const(15) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 16 BYTE DUP3 %add_const(16) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 17 BYTE DUP3 %add_const(17) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 18 BYTE DUP3 %add_const(18) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 19 BYTE DUP3 %add_const(19) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 20 BYTE DUP3 %add_const(20) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 21 BYTE DUP3 %add_const(21) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 22 BYTE DUP3 %add_const(22) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 23 BYTE DUP3 %add_const(23) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 24 BYTE DUP3 %add_const(24) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 25 BYTE DUP3 %add_const(25) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 26 BYTE DUP3 %add_const(26) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 27 BYTE DUP3 %add_const(27) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 28 BYTE DUP3 %add_const(28) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 29 BYTE DUP3 %add_const(29) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 30 
BYTE DUP3 %add_const(30) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 31 BYTE DUP3 %add_const(31) %mstore_current(@SEGMENT_MAIN_MEMORY) - %stack (kexit_info, offset, value) -> (kexit_info) + %stack(kexit_info, offset, value) -> (offset, value, 32, kexit_info) + PUSH @SEGMENT_MAIN_MEMORY + GET_CONTEXT + // stack: addr: 3, value, len, kexit_info + MSTORE_32BYTES + // stack: kexit_info EXIT_KERNEL global sys_mstore8: diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index 213858a846..d6773feeb7 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -184,7 +184,7 @@ impl Column { // If we access the next row at the last row, for sanity, we consider the next row's values to be 0. // If CTLs are correctly written, the filter should be 0 in that case anyway. - if !self.next_row_linear_combination.is_empty() && row < table.len() - 1 { + if !self.next_row_linear_combination.is_empty() && row < table[0].values.len() - 1 { res += self .next_row_linear_combination .iter() @@ -624,7 +624,7 @@ pub(crate) fn eval_cross_table_lookup_checks>(); let combined = challenges.combine(evals.iter()); let local_filter = if let Some(column) = filter_column { - column.eval(vars.local_values) + column.eval_with_next(vars.local_values, vars.next_values) } else { P::ONES }; diff --git a/evm/src/fixed_recursive_verifier.rs b/evm/src/fixed_recursive_verifier.rs index 8c389c2a75..65c7687cdb 100644 --- a/evm/src/fixed_recursive_verifier.rs +++ b/evm/src/fixed_recursive_verifier.rs @@ -377,49 +377,49 @@ where let arithmetic = RecursiveCircuitsForTable::new( Table::Arithmetic, &all_stark.arithmetic_stark, - degree_bits_ranges[0].clone(), + degree_bits_ranges[Table::Arithmetic as usize].clone(), &all_stark.cross_table_lookups, stark_config, ); let byte_packing = RecursiveCircuitsForTable::new( Table::BytePacking, &all_stark.byte_packing_stark, - degree_bits_ranges[1].clone(), + degree_bits_ranges[Table::BytePacking as usize].clone(), 
&all_stark.cross_table_lookups, stark_config, ); let cpu = RecursiveCircuitsForTable::new( Table::Cpu, &all_stark.cpu_stark, - degree_bits_ranges[2].clone(), + degree_bits_ranges[Table::Cpu as usize].clone(), &all_stark.cross_table_lookups, stark_config, ); let keccak = RecursiveCircuitsForTable::new( Table::Keccak, &all_stark.keccak_stark, - degree_bits_ranges[3].clone(), + degree_bits_ranges[Table::Keccak as usize].clone(), &all_stark.cross_table_lookups, stark_config, ); let keccak_sponge = RecursiveCircuitsForTable::new( Table::KeccakSponge, &all_stark.keccak_sponge_stark, - degree_bits_ranges[4].clone(), + degree_bits_ranges[Table::KeccakSponge as usize].clone(), &all_stark.cross_table_lookups, stark_config, ); let logic = RecursiveCircuitsForTable::new( Table::Logic, &all_stark.logic_stark, - degree_bits_ranges[5].clone(), + degree_bits_ranges[Table::Logic as usize].clone(), &all_stark.cross_table_lookups, stark_config, ); let memory = RecursiveCircuitsForTable::new( Table::Memory, &all_stark.memory_stark, - degree_bits_ranges[6].clone(), + degree_bits_ranges[Table::Memory as usize].clone(), &all_stark.cross_table_lookups, stark_config, ); diff --git a/evm/src/proof.rs b/evm/src/proof.rs index 0eccc4a67d..3254359da4 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -74,32 +74,66 @@ impl Default for BlockHashes { } } +/// User-provided helper values to compute the `BLOCKHASH` opcode. +/// The proofs across consecutive blocks ensure that these values +/// are consistent (i.e. shifted by one to the left). +/// +/// When the block number is less than 256, dummy values, i.e. `H256::default()`, +/// should be used for the additional block hashes. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct BlockHashes { + /// The previous 256 hashes to the current block. The leftmost hash, i.e. `prev_hashes[0]`, + /// is the oldest, and the rightmost, i.e. `prev_hashes[255]` is the hash of the parent block. 
pub prev_hashes: Vec, + // The hash of the current block. pub cur_hash: H256, } +/// Metadata contained in a block header. Those are identical between +/// all state transition proofs within the same block. #[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct BlockMetadata { + /// The address of this block's producer. pub block_beneficiary: Address, + /// The timestamp of this block. pub block_timestamp: U256, + /// The index of this block. pub block_number: U256, + /// The difficulty (before PoS transition) of this block. pub block_difficulty: U256, + /// The gas limit of this block. It must fit in a `u32`. pub block_gaslimit: U256, + /// The chain id of this block. pub block_chain_id: U256, + /// The base fee of this block. pub block_base_fee: U256, + /// The total gas used in this block. It must fit in a `u32`. pub block_gas_used: U256, + /// The block bloom of this block, represented as the consecutive + /// 32-byte chunks of a block's final bloom filter string. pub block_bloom: [U256; 8], } +/// Additional block data that are specific to the local transaction being proven, +/// unlike `BlockMetadata`. #[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct ExtraBlockData { + /// The transaction count prior execution of the local state transition, starting + /// at 0 for the initial transaction of a block. pub txn_number_before: U256, + /// The transaction count after execution of the local state transition. pub txn_number_after: U256, + /// The accumulated gas used prior execution of the local state transition, starting + /// at 0 for the initial transaction of a block. pub gas_used_before: U256, + /// The accumulated gas used after execution of the local state transition. It should + /// match the `block_gas_used` value after execution of the last transaction in a block. 
pub gas_used_after: U256, + /// The accumulated bloom filter of this block prior execution of the local state transition, + /// starting with all zeros for the initial transaction of a block. pub block_bloom_before: [U256; 8], + /// The accumulated bloom filter after execution of the local state transition. It should + /// match the `block_bloom` value after execution of the last transaction in a block. pub block_bloom_after: [U256; 8], } diff --git a/plonky2/src/plonk/prover.rs b/plonky2/src/plonk/prover.rs index b77f7aa5ff..41aebdb1e9 100644 --- a/plonky2/src/plonk/prover.rs +++ b/plonky2/src/plonk/prover.rs @@ -113,6 +113,29 @@ pub fn prove, C: GenericConfig, const D: inputs: PartialWitness, timing: &mut TimingTree, ) -> Result> +where + C::Hasher: Hasher, + C::InnerHasher: Hasher, +{ + let partition_witness = timed!( + timing, + &format!("run {} generators", prover_data.generators.len()), + generate_partial_witness(inputs, prover_data, common_data) + ); + + prove_with_partition_witness(prover_data, common_data, partition_witness, timing) +} + +pub fn prove_with_partition_witness< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + prover_data: &ProverOnlyCircuitData, + common_data: &CommonCircuitData, + mut partition_witness: PartitionWitness, + timing: &mut TimingTree, +) -> Result> where C::Hasher: Hasher, C::InnerHasher: Hasher, @@ -123,12 +146,6 @@ where let quotient_degree = common_data.quotient_degree(); let degree = common_data.degree(); - let mut partition_witness = timed!( - timing, - &format!("run {} generators", prover_data.generators.len()), - generate_partial_witness(inputs, prover_data, common_data) - ); - set_lookup_wires(prover_data, common_data, &mut partition_witness); let public_inputs = partition_witness.get_targets(&prover_data.public_inputs); From 3983969ce9724da1b2f077eb7f475317dfe6b1f2 Mon Sep 17 00:00:00 2001 From: Linda Guiga Date: Mon, 25 Sep 2023 10:35:38 -0400 Subject: [PATCH 14/34] Use function for genesis 
block connection. --- evm/src/fixed_recursive_verifier.rs | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/evm/src/fixed_recursive_verifier.rs b/evm/src/fixed_recursive_verifier.rs index 55577fb258..0e77e57185 100644 --- a/evm/src/fixed_recursive_verifier.rs +++ b/evm/src/fixed_recursive_verifier.rs @@ -819,16 +819,28 @@ where let zero = builder.zero(); let has_not_parent_block = builder.sub(one, has_parent_block.target); + // Chack that the genesis block number is 0. let gen_block_constr = builder.mul(has_not_parent_block, rhs.block_metadata.block_number); builder.connect(gen_block_constr, zero); - // Check that the genesis block has a predetermined state trie root. - for (&limb0, limb1) in rhs + // Check that the genesis block has the predetermined state trie root in `ExtraBlockData`. + Self::connect_genesis_block(builder, rhs, has_not_parent_block); + } + + fn connect_genesis_block( + builder: &mut CircuitBuilder, + x: &PublicValuesTarget, + has_not_parent_block: Target, + ) where + F: RichField + Extendable, + { + let zero = builder.zero(); + for (&limb0, limb1) in x .trie_roots_before .state_root .iter() - .zip(rhs.extra_block_data.genesis_state_root) + .zip(x.extra_block_data.genesis_state_root) { let mut constr = builder.sub(limb0, limb1); constr = builder.mul(has_not_parent_block, constr); From 75c0e47a3007d33681620c845bf8fc15bcd04f6c Mon Sep 17 00:00:00 2001 From: Linda Guiga Date: Tue, 26 Sep 2023 12:21:29 -0400 Subject: [PATCH 15/34] Apply comments. 
--- evm/src/fixed_recursive_verifier.rs | 5 ++++- evm/src/proof.rs | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/evm/src/fixed_recursive_verifier.rs b/evm/src/fixed_recursive_verifier.rs index 0e77e57185..dffbd24573 100644 --- a/evm/src/fixed_recursive_verifier.rs +++ b/evm/src/fixed_recursive_verifier.rs @@ -659,6 +659,9 @@ where for (&limb0, &limb1) in pvs.genesis_state_root.iter().zip(&rhs.genesis_state_root) { builder.connect(limb0, limb1); } + for (&limb0, &limb1) in pvs.genesis_state_root.iter().zip(&lhs.genesis_state_root) { + builder.connect(limb0, limb1); + } // Connect the transaction number in public values to the lhs and rhs values correctly. builder.connect(pvs.txn_number_before, lhs.txn_number_before); @@ -820,7 +823,7 @@ where let zero = builder.zero(); let has_not_parent_block = builder.sub(one, has_parent_block.target); - // Chack that the genesis block number is 0. + // Check that the genesis block number is 0. let gen_block_constr = builder.mul(has_not_parent_block, rhs.block_metadata.block_number); builder.connect(gen_block_constr, zero); diff --git a/evm/src/proof.rs b/evm/src/proof.rs index 71b2feb44c..b3020fe67a 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -118,7 +118,7 @@ pub struct BlockMetadata { /// unlike `BlockMetadata`. #[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct ExtraBlockData { - /// The state trie digest of the gensis block. + /// The state trie digest of the genesis block. pub genesis_state_root: H256, /// The transaction count prior execution of the local state transition, starting /// at 0 for the initial transaction of a block. 
From acc659da07edb31237358b5794bb362c067e2cb4 Mon Sep 17 00:00:00 2001 From: Robin Salen <30937548+Nashtare@users.noreply.github.com> Date: Tue, 26 Sep 2023 14:56:18 -0400 Subject: [PATCH 16/34] Add type 1 and 2 txn for RLP encoding support (#1255) --- evm/src/generation/mpt.rs | 69 +++++++++++++++++++++++++++++++++++++-- evm/tests/log_opcode.rs | 8 +++-- 2 files changed, 72 insertions(+), 5 deletions(-) diff --git a/evm/src/generation/mpt.rs b/evm/src/generation/mpt.rs index dbc36cacbb..be99418812 100644 --- a/evm/src/generation/mpt.rs +++ b/evm/src/generation/mpt.rs @@ -6,7 +6,7 @@ use eth_trie_utils::nibbles::Nibbles; use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; use ethereum_types::{Address, BigEndianHash, H256, U256, U512}; use keccak_hash::keccak; -use rlp::PayloadInfo; +use rlp::{Decodable, DecoderError, Encodable, PayloadInfo, Rlp, RlpStream}; use rlp_derive::{RlpDecodable, RlpEncodable}; use crate::cpu::kernel::constants::trie_type::PartialTrieType; @@ -33,12 +33,46 @@ impl Default for AccountRlp { } } +#[derive(RlpEncodable, RlpDecodable, Debug)] +pub struct AccessListItemRlp { + pub address: Address, + pub storage_keys: Vec, +} + +#[derive(Debug)] +pub struct AddressOption(pub Option

); + +impl Encodable for AddressOption { + fn rlp_append(&self, s: &mut RlpStream) { + match self.0 { + None => { + s.append_empty_data(); + } + Some(value) => { + s.encoder().encode_value(&value.to_fixed_bytes()); + } + } + } +} + +impl Decodable for AddressOption { + fn decode(rlp: &Rlp) -> Result { + if rlp.is_int() && rlp.is_empty() { + return Ok(AddressOption(None)); + } + if rlp.is_data() && rlp.size() == 20 { + return Ok(AddressOption(Some(Address::decode(rlp)?))); + } + Err(DecoderError::RlpExpectedToBeData) + } +} + #[derive(RlpEncodable, RlpDecodable, Debug)] pub struct LegacyTransactionRlp { pub nonce: U256, pub gas_price: U256, pub gas: U256, - pub to: Address, + pub to: AddressOption, pub value: U256, pub data: Bytes, pub v: U256, @@ -46,6 +80,37 @@ pub struct LegacyTransactionRlp { pub s: U256, } +#[derive(RlpEncodable, RlpDecodable, Debug)] +pub struct AccessListTransactionRlp { + pub chain_id: u64, + pub nonce: U256, + pub gas_price: U256, + pub gas: U256, + pub to: AddressOption, + pub value: U256, + pub data: Bytes, + pub access_list: Vec, + pub y_parity: U256, + pub r: U256, + pub s: U256, +} + +#[derive(RlpEncodable, RlpDecodable, Debug)] +pub struct FeeMarketTransactionRlp { + pub chain_id: u64, + pub nonce: U256, + pub max_priority_fee_per_gas: U256, + pub max_fee_per_gas: U256, + pub gas: U256, + pub to: AddressOption, + pub value: U256, + pub data: Bytes, + pub access_list: Vec, + pub y_parity: U256, + pub r: U256, + pub s: U256, +} + #[derive(RlpEncodable, RlpDecodable, Debug)] pub struct LogRlp { pub address: Address, diff --git a/evm/tests/log_opcode.rs b/evm/tests/log_opcode.rs index 2742c4254f..e6821c53c4 100644 --- a/evm/tests/log_opcode.rs +++ b/evm/tests/log_opcode.rs @@ -17,7 +17,9 @@ use plonky2::util::timing::TimingTree; use plonky2_evm::all_stark::AllStark; use plonky2_evm::config::StarkConfig; use plonky2_evm::fixed_recursive_verifier::AllRecursiveCircuits; -use plonky2_evm::generation::mpt::{AccountRlp, LegacyReceiptRlp, 
LegacyTransactionRlp, LogRlp}; +use plonky2_evm::generation::mpt::{ + AccountRlp, AddressOption, LegacyReceiptRlp, LegacyTransactionRlp, LogRlp, +}; use plonky2_evm::generation::{GenerationInputs, TrieInputs}; use plonky2_evm::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, TrieRoots}; use plonky2_evm::prover::prove; @@ -631,7 +633,7 @@ fn test_txn_and_receipt_trie_hash() -> anyhow::Result<()> { nonce: 157823u64.into(), gas_price: 1000000000u64.into(), gas: 250000u64.into(), - to: hex!("7ef66b77759e12Caf3dDB3E4AFF524E577C59D8D").into(), + to: AddressOption(Some(hex!("7ef66b77759e12Caf3dDB3E4AFF524E577C59D8D").into())), value: 0u64.into(), data: hex!("e9c6c176000000000000000000000000000000000000000000000000000000000000002a0000000000000000000000000000000000000000000000000000000000bd9fe6f7af1cc94b1aef2e0fa15f1b4baefa86eb60e78fa4bd082372a0a446d197fb58") .to_vec() @@ -651,7 +653,7 @@ fn test_txn_and_receipt_trie_hash() -> anyhow::Result<()> { nonce: 157824u64.into(), gas_price: 1000000000u64.into(), gas: 250000u64.into(), - to: hex!("7ef66b77759e12Caf3dDB3E4AFF524E577C59D8D").into(), + to: AddressOption(Some(hex!("7ef66b77759e12Caf3dDB3E4AFF524E577C59D8D").into())), value: 0u64.into(), data: hex!("e9c6c176000000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000004920eaa814f7df6a2203dc0e472e8828be95957c6b329fee8e2b1bb6f044c1eb4fc243") .to_vec() From f49fbc8e9b84ecf46090e3ec2469dd31dc5da4a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alonso=20Gonz=C3=A1lez?= Date: Wed, 27 Sep 2023 16:00:16 +0200 Subject: [PATCH 17/34] Transactions trie support (#1232) * Implement transactions. 
* Fix receipts and transactions * Add some fixes * Update tests * Remove changes added for debugging purposes only * Clippy * Remove additional debug changes * Remove unused * Apply comments --------- Co-authored-by: Linda Guiga Co-authored-by: Robin Salen --- evm/src/cpu/kernel/aggregator.rs | 1 + .../cpu/kernel/asm/core/create_receipt.asm | 131 +++++----- evm/src/cpu/kernel/asm/main.asm | 24 +- .../asm/mpt/hash/hash_trie_specific.asm | 25 +- .../asm/mpt/insert/insert_trie_specific.asm | 57 ++-- .../asm/mpt/load/load_trie_specific.asm | 18 +- evm/src/cpu/kernel/asm/rlp/encode.asm | 2 +- .../kernel/asm/rlp/increment_bounded_rlp.asm | 38 +++ evm/src/cpu/kernel/asm/rlp/read_to_memory.asm | 6 +- .../cpu/kernel/asm/transactions/router.asm | 30 ++- evm/src/cpu/kernel/tests/mpt/load.rs | 45 ++++ evm/src/cpu/kernel/tests/receipt.rs | 34 ++- evm/src/generation/mpt.rs | 4 +- evm/src/witness/transition.rs | 2 +- evm/tests/add11_yml.rs | 8 +- evm/tests/basic_smart_contract.rs | 7 +- evm/tests/log_opcode.rs | 32 ++- evm/tests/many_transactions.rs | 246 ++++++++++++++++++ evm/tests/self_balance_gas_cost.rs | 8 +- evm/tests/simple_transfer.rs | 8 +- 20 files changed, 612 insertions(+), 114 deletions(-) create mode 100644 evm/src/cpu/kernel/asm/rlp/increment_bounded_rlp.asm create mode 100644 evm/tests/many_transactions.rs diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index 0c7c657999..20081bb935 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -131,6 +131,7 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/rlp/encode.asm"), include_str!("asm/rlp/encode_rlp_scalar.asm"), include_str!("asm/rlp/encode_rlp_string.asm"), + include_str!("asm/rlp/increment_bounded_rlp.asm"), include_str!("asm/rlp/num_bytes.asm"), include_str!("asm/rlp/read_to_memory.asm"), include_str!("asm/shift.asm"), diff --git a/evm/src/cpu/kernel/asm/core/create_receipt.asm b/evm/src/cpu/kernel/asm/core/create_receipt.asm 
index 586ea5bb01..fccabe0885 100644 --- a/evm/src/cpu/kernel/asm/core/create_receipt.asm +++ b/evm/src/cpu/kernel/asm/core/create_receipt.asm @@ -1,4 +1,4 @@ -// Pre-stack: status, leftover_gas, prev_cum_gas, txn_nb, retdest +// Pre-stack: status, leftover_gas, prev_cum_gas, txn_nb, num_nibbles, retdest // Post stack: new_cum_gas, txn_nb // A receipt is stored in MPT_TRIE_DATA as: // [payload_len, status, cum_gas_used, bloom, logs_payload_len, num_logs, [logs]] @@ -11,210 +11,211 @@ // - insert a new node in receipt_trie, // - set the bloom filter back to 0 global process_receipt: - // stack: status, leftover_gas, prev_cum_gas, txn_nb, retdest + // stack: status, leftover_gas, prev_cum_gas, txn_nb, num_nibbles, retdest DUP2 DUP4 - // stack: prev_cum_gas, leftover_gas, status, leftover_gas, prev_cum_gas, txn_nb, retdest + // stack: prev_cum_gas, leftover_gas, status, leftover_gas, prev_cum_gas, txn_nb, num_nibbles, retdest %compute_cumulative_gas - // stack: new_cum_gas, status, leftover_gas, prev_cum_gas, txn_nb, retdest + // stack: new_cum_gas, status, leftover_gas, prev_cum_gas, txn_nb, num_nibbles, retdest SWAP3 POP - // stack: status, leftover_gas, new_cum_gas, txn_nb, retdest + // stack: status, leftover_gas, new_cum_gas, txn_nb, num_nibbles, retdest SWAP1 POP - // stack: status, new_cum_gas, txn_nb, retdest + // stack: status, new_cum_gas, txn_nb, num_nibbles, retdest // Now, we need to check whether the transaction has failed. 
DUP1 ISZERO %jumpi(failed_receipt) process_receipt_after_status: - // stack: status, new_cum_gas, txn_nb, retdest + // stack: status, new_cum_gas, txn_nb, num_nibbles, retdest PUSH process_receipt_after_bloom %jump(logs_bloom) process_receipt_after_bloom: - // stack: status, new_cum_gas, txn_nb, retdest + // stack: status, new_cum_gas, txn_nb, num_nibbles, retdest DUP2 DUP4 - // stack: txn_nb, new_cum_gas, status, new_cum_gas, txn_nb, retdest + // stack: txn_nb, new_cum_gas, status, new_cum_gas, txn_nb, num_nibbles, retdest SWAP2 - // stack: status, new_cum_gas, txn_nb, new_cum_gas, txn_nb, retdest + // stack: status, new_cum_gas, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest // Compute the total RLP payload length of the receipt. PUSH 1 // status is always 1 byte. - // stack: payload_len, status, new_cum_gas, txn_nb, new_cum_gas, txn_nb, retdest + // stack: payload_len, status, new_cum_gas, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest DUP3 %rlp_scalar_len // cum_gas is a simple scalar. ADD - // stack: payload_len, status, new_cum_gas, txn_nb, new_cum_gas, txn_nb, retdest + // stack: payload_len, status, new_cum_gas, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest // Next is the bloom_filter, which is a 256-byte array. Its RLP encoding is // 1 + 2 + 256 bytes. %add_const(259) - // stack: payload_len, status, new_cum_gas, txn_nb, new_cum_gas, txn_nb, retdest + // stack: payload_len, status, new_cum_gas, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest // Last is the logs. %mload_global_metadata(@GLOBAL_METADATA_LOGS_PAYLOAD_LEN) %rlp_list_len ADD - // stack: payload_len, status, new_cum_gas, txn_nb, new_cum_gas, txn_nb, retdest + // stack: payload_len, status, new_cum_gas, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest // Now we can write the receipt in MPT_TRIE_DATA. 
%get_trie_data_size - // stack: receipt_ptr, payload_len, status, new_cum_gas, txn_nb, new_cum_gas, txn_nb, retdest + // stack: receipt_ptr, payload_len, status, new_cum_gas, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest // Write transaction type if necessary. RLP_RAW contains, at index 0, the current transaction type. PUSH 0 %mload_kernel(@SEGMENT_RLP_RAW) - // stack: first_txn_byte, receipt_ptr, payload_len, status, new_cum_gas, txn_nb, new_cum_gas, txn_nb, retdest + // stack: first_txn_byte, receipt_ptr, payload_len, status, new_cum_gas, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest DUP1 %eq_const(1) %jumpi(receipt_nonzero_type) DUP1 %eq_const(2) %jumpi(receipt_nonzero_type) // If we are here, we are dealing with a legacy transaction, and we do not need to write the type. POP process_receipt_after_type: - // stack: receipt_ptr, payload_len, status, new_cum_gas, txn_nb, new_cum_gas, txn_nb, retdest + // stack: receipt_ptr, payload_len, status, new_cum_gas, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest // Write payload_len. SWAP1 %append_to_trie_data - // stack: receipt_ptr, status, new_cum_gas, txn_nb, new_cum_gas, txn_nb, retdest + // stack: receipt_ptr, status, new_cum_gas, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest // Write status. SWAP1 %append_to_trie_data - // stack: receipt_ptr, new_cum_gas, txn_nb, new_cum_gas, txn_nb, retdest + // stack: receipt_ptr, new_cum_gas, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest // Write cum_gas_used. SWAP1 %append_to_trie_data - // stack: receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest // Write Bloom filter. PUSH 256 // Bloom length. PUSH 0 PUSH @SEGMENT_TXN_BLOOM PUSH 0 // Bloom memory address. %get_trie_data_size PUSH @SEGMENT_TRIE_DATA PUSH 0 // MPT dest address. 
- // stack: DST, SRC, 256, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: DST, SRC, 256, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest %memcpy - // stack: receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest // Update trie data size. %get_trie_data_size %add_const(256) %set_trie_data_size // Now we write logs. - // stack: receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest // We start with the logs payload length. %mload_global_metadata(@GLOBAL_METADATA_LOGS_PAYLOAD_LEN) %append_to_trie_data - // stack: receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest %mload_global_metadata(@GLOBAL_METADATA_LOGS_LEN) // Then the number of logs. - // stack: num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest DUP1 %append_to_trie_data PUSH 0 // Each log is written in MPT_TRIE_DATA as: // [payload_len, address, num_topics, [topics], data_len, [data]]. 
process_receipt_logs_loop: - // stack: i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest DUP2 DUP2 EQ - // stack: i == num_logs, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: i == num_logs, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest %jumpi(process_receipt_after_write) - // stack: i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest DUP1 %mload_kernel(@SEGMENT_LOGS) - // stack: log_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: log_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest // Write payload_len. DUP1 %mload_kernel(@SEGMENT_LOGS_DATA) %append_to_trie_data - // stack: log_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: log_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest // Write address. %increment - // stack: addr_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: addr_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest DUP1 %mload_kernel(@SEGMENT_LOGS_DATA) %append_to_trie_data - // stack: addr_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: addr_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest //Write num_topics. 
%increment - // stack: num_topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: num_topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest DUP1 %mload_kernel(@SEGMENT_LOGS_DATA) - // stack: num_topics, num_topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: num_topics, num_topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest DUP1 %append_to_trie_data - // stack: num_topics, num_topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: num_topics, num_topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest SWAP1 %increment SWAP1 - // stack: num_topics, topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: num_topics, topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest PUSH 0 process_receipt_topics_loop: - // stack: j, num_topics, topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: j, num_topics, topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest DUP2 DUP2 EQ - // stack: j == num_topics, j, num_topics, topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: j == num_topics, j, num_topics, topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest %jumpi(process_receipt_topics_end) - // stack: j, num_topics, topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: j, num_topics, topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest // Write j-th topic. 
DUP3 DUP2 ADD - // stack: cur_topic_ptr, j, num_topics, topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: cur_topic_ptr, j, num_topics, topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest %mload_kernel(@SEGMENT_LOGS_DATA) %append_to_trie_data - // stack: j, num_topics, topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: j, num_topics, topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest %increment %jump(process_receipt_topics_loop) process_receipt_topics_end: - // stack: num_topics, num_topics, topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: num_topics, num_topics, topics_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest POP ADD - // stack: data_len_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: data_len_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest // Write data_len DUP1 %mload_kernel(@SEGMENT_LOGS_DATA) - // stack: data_len, data_len_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: data_len, data_len_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest DUP1 %append_to_trie_data - // stack: data_len, data_len_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: data_len, data_len_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest SWAP1 %increment SWAP1 - // stack: data_len, data_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: data_len, data_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest PUSH 0 process_receipt_data_loop: - // stack: j, data_len, data_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: j, data_len, data_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, 
num_nibbles, retdest DUP2 DUP2 EQ - // stack: j == data_len, j, data_len, data_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: j == data_len, j, data_len, data_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest %jumpi(process_receipt_data_end) - // stack: j, data_len, data_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: j, data_len, data_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest // Write j-th data byte. DUP3 DUP2 ADD - // stack: cur_data_ptr, j, data_len, data_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: cur_data_ptr, j, data_len, data_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest %mload_kernel(@SEGMENT_LOGS_DATA) %append_to_trie_data - // stack: j, data_len, data_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: j, data_len, data_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest %increment %jump(process_receipt_data_loop) process_receipt_data_end: - // stack: data_len, data_len, data_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: data_len, data_len, data_ptr, i, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest %pop3 %increment %jump(process_receipt_logs_loop) process_receipt_after_write: - // stack: num_logs, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: num_logs, num_logs, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest %pop2 - // stack: receipt_ptr, txn_nb, new_cum_gas, txn_nb, retdest + // stack: receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest SWAP1 - // stack: txn_nb, receipt_ptr, new_cum_gas, txn_nb, retdest + // stack: txn_nb, receipt_ptr, new_cum_gas, txn_nb, num_nibbles, retdest + DUP5 %mpt_insert_receipt_trie - // stack: new_cum_gas, txn_nb, retdest + // stack: new_cum_gas, txn_nb, 
num_nibbles, retdest // Now, we set the Bloom filter back to 0. PUSH 0 %rep 256 - // stack: counter, new_cum_gas, txn_nb, retdest + // stack: counter, new_cum_gas, txn_nb, num_nibbles, retdest PUSH 0 DUP2 - // stack: counter, 0, counter, new_cum_gas, txn_nb, retdest + // stack: counter, 0, counter, new_cum_gas, txn_nb, num_nibbles, retdest %mstore_kernel(@SEGMENT_TXN_BLOOM) - // stack: counter, new_cum_gas, txn_nb, retdest + // stack: counter, new_cum_gas, txn_nb, num_nibbles, retdest %increment %endrep POP - // stack: new_cum_gas, txn_nb, retdest - %stack (new_cum_gas, txn_nb, retdest) -> (retdest, new_cum_gas, txn_nb) + // stack: new_cum_gas, txn_nb, num_nibbles, retdest + %stack (new_cum_gas, txn_nb, num_nibbles, retdest) -> (retdest, new_cum_gas) JUMP receipt_nonzero_type: @@ -223,16 +224,16 @@ receipt_nonzero_type: %jump(process_receipt_after_type) failed_receipt: - // stack: status, new_cum_gas, txn_nb + // stack: status, new_cum_gas, num_nibbles, txn_nb // It is the receipt of a failed transaction, so set num_logs to 0. This will also lead to Bloom filter = 0. 
PUSH 0 %mstore_global_metadata(@GLOBAL_METADATA_LOGS_LEN) - // stack: status, new_cum_gas, txn_nb + // stack: status, new_cum_gas, num_nibbles, txn_nb %jump(process_receipt_after_status) %macro process_receipt - // stack: success, leftover_gas, cur_cum_gas, txn_nb - %stack (success, leftover_gas, cur_cum_gas, txn_nb) -> (success, leftover_gas, cur_cum_gas, txn_nb, %%after) + // stack: success, leftover_gas, cur_cum_gas, txn_nb, num_nibbles + %stack (success, leftover_gas, cur_cum_gas, txn_nb, num_nibbles) -> (success, leftover_gas, cur_cum_gas, txn_nb, num_nibbles, %%after) %jump(process_receipt) %%after: %endmacro diff --git a/evm/src/cpu/kernel/asm/main.asm b/evm/src/cpu/kernel/asm/main.asm index 612b4bcecd..bd555218be 100644 --- a/evm/src/cpu/kernel/asm/main.asm +++ b/evm/src/cpu/kernel/asm/main.asm @@ -16,9 +16,18 @@ global hash_initial_tries: global start_txns: // stack: (empty) + // The special case of an empty trie (i.e. for the first transaction) + // is handled outside of the kernel. %mload_global_metadata(@GLOBAL_METADATA_TXN_NUMBER_BEFORE) + // stack: txn_nb %mload_global_metadata(@GLOBAL_METADATA_BLOCK_GAS_USED_BEFORE) // stack: init_used_gas, txn_nb + DUP2 %scalar_to_rlp + // stack: txn_counter, init_gas_used, txn_nb + DUP1 %num_bytes %mul_const(2) + // stack: num_nibbles, txn_counter, init_gas_used, txn_nb + SWAP2 + // stack: init_gas_used, txn_counter, num_nibbles, txn_nb txn_loop: // If the prover has no more txns for us to process, halt. @@ -27,21 +36,24 @@ txn_loop: // Call route_txn. When we return, continue the txn loop. 
PUSH txn_loop_after - // stack: retdest, prev_used_gas, txn_nb + // stack: retdest, prev_gas_used, txn_counter, num_nibbles, txn_nb + DUP4 DUP4 %increment_bounded_rlp + %stack (next_txn_counter, next_num_nibbles, retdest, prev_gas_used, txn_counter, num_nibbles) -> (txn_counter, num_nibbles, retdest, prev_gas_used, txn_counter, num_nibbles, next_txn_counter, next_num_nibbles) %jump(route_txn) global txn_loop_after: - // stack: success, leftover_gas, cur_cum_gas, txn_nb + // stack: success, leftover_gas, cur_cum_gas, prev_txn_counter, prev_num_nibbles, txn_counter, num_nibbles, txn_nb %process_receipt - // stack: new_cum_gas, txn_nb - SWAP1 %increment SWAP1 + // stack: new_cum_gas, txn_counter, num_nibbles, txn_nb + SWAP3 %increment SWAP3 %jump(txn_loop) global hash_final_tries: - // stack: cum_gas, txn_nb + // stack: cum_gas, txn_counter, num_nibbles, txn_nb // Check that we end up with the correct `cum_gas`, `txn_nb` and bloom filter. %mload_global_metadata(@GLOBAL_METADATA_BLOCK_GAS_USED_AFTER) %assert_eq - %mload_global_metadata(@GLOBAL_METADATA_TXN_NUMBER_AFTER) %assert_eq + DUP3 %mload_global_metadata(@GLOBAL_METADATA_TXN_NUMBER_AFTER) %assert_eq + %pop3 %check_metadata_block_bloom %mpt_hash_state_trie %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_DIGEST_AFTER) %assert_eq %mpt_hash_txn_trie %mload_global_metadata(@GLOBAL_METADATA_TXN_TRIE_DIGEST_AFTER) %assert_eq diff --git a/evm/src/cpu/kernel/asm/mpt/hash/hash_trie_specific.asm b/evm/src/cpu/kernel/asm/mpt/hash/hash_trie_specific.asm index 6b057a49a4..8935b7f3f9 100644 --- a/evm/src/cpu/kernel/asm/mpt/hash/hash_trie_specific.asm +++ b/evm/src/cpu/kernel/asm/mpt/hash/hash_trie_specific.asm @@ -97,7 +97,30 @@ global encode_account: JUMP global encode_txn: - PANIC // TODO + // stack: rlp_pos, value_ptr, retdest + + // Load the txn_rlp_len which is at the beginnig of value_ptr + DUP2 %mload_trie_data + // stack: txn_rlp_len, rlp_pos, value_ptr, retdest + SWAP2 %increment + // stack: txn_rlp_ptr=value_ptr+1, 
rlp_pos, txn_rlp_len, retdest + + %stack (txn_rlp_ptr, rlp_pos, txn_rlp_len) -> (rlp_pos, txn_rlp_len, txn_rlp_len, txn_rlp_ptr) + // Encode the txn rlp prefix + // stack: rlp_pos, txn_rlp_len, txn_rlp_len, txn_rlp_ptr, retdest + %encode_rlp_multi_byte_string_prefix + // copy txn_rlp to the new block + // stack: rlp_pos, txn_rlp_len, txn_rlp_ptr, retdest + %stack (rlp_pos, txn_rlp_len, txn_rlp_ptr) -> ( + 0, @SEGMENT_RLP_RAW, rlp_pos, // dest addr + 0, @SEGMENT_TRIE_DATA, txn_rlp_ptr, // src addr. Kernel has context 0 + txn_rlp_len, // mcpy len + txn_rlp_len, rlp_pos) + %memcpy + ADD + // stack new_rlp_pos, retdest + SWAP1 + JUMP // We assume a receipt in memory is stored as: // [payload_len, status, cum_gas_used, bloom, logs_payload_len, num_logs, [logs]]. diff --git a/evm/src/cpu/kernel/asm/mpt/insert/insert_trie_specific.asm b/evm/src/cpu/kernel/asm/mpt/insert/insert_trie_specific.asm index 457d604f2f..1bf9f6f8fb 100644 --- a/evm/src/cpu/kernel/asm/mpt/insert/insert_trie_specific.asm +++ b/evm/src/cpu/kernel/asm/mpt/insert/insert_trie_specific.asm @@ -23,15 +23,34 @@ mpt_insert_state_trie_save: %%after: %endmacro +// Insert a node in the transaction trie. 
The payload +// must be pointing to the rlp encoded txn +// Pre stack: key, txn_rlp_ptr, redest +// Post stack: (empty) +global mpt_insert_txn_trie: + // stack: key=rlp(key), num_nibbles, txn_rlp_ptr, retdest + %stack (key, num_nibbles, txn_rlp_ptr) + -> (num_nibbles, key, txn_rlp_ptr, mpt_insert_txn_trie_save) + %mload_global_metadata(@GLOBAL_METADATA_TXN_TRIE_ROOT) + // stack: txn_trie_root_ptr, num_nibbles, key, txn_rlp_ptr, mpt_insert_state_trie_save, retdest + %jump(mpt_insert) + +mpt_insert_txn_trie_save: + // stack: updated_node_ptr, retdest + %mstore_global_metadata(@GLOBAL_METADATA_TXN_TRIE_ROOT) + JUMP + +%macro mpt_insert_txn_trie + %stack (key, txn_rpl_ptr) -> (key, txn_rlp_ptr, %%after) + %jump(mpt_insert_txn_trie) +%%after: +%endmacro + global mpt_insert_receipt_trie: - // stack: scalar, value_ptr, retdest - %stack (scalar, value_ptr) - -> (scalar, value_ptr, mpt_insert_receipt_trie_save) - // The key is the RLP encoding of scalar. - %scalar_to_rlp - // stack: key, value_ptr, mpt_insert_receipt_trie_save, retdest - DUP1 - %num_bytes %mul_const(2) + // stack: num_nibbles, scalar, value_ptr, retdest + %stack (num_nibbles, scalar, value_ptr) + -> (num_nibbles, scalar, value_ptr, mpt_insert_receipt_trie_save) + // The key is the scalar, which is an RLP encoding of the transaction number // stack: num_nibbles, key, value_ptr, mpt_insert_receipt_trie_save, retdest %mload_global_metadata(@GLOBAL_METADATA_RECEIPT_TRIE_ROOT) // stack: receipt_root_ptr, num_nibbles, key, value_ptr, mpt_insert_receipt_trie_save, retdest @@ -42,27 +61,29 @@ mpt_insert_receipt_trie_save: JUMP %macro mpt_insert_receipt_trie - %stack (key, value_ptr) -> (key, value_ptr, %%after) + %stack (num_nibbles, key, value_ptr) -> (num_nibbles, key, value_ptr, %%after) %jump(mpt_insert_receipt_trie) %%after: %endmacro // Pre stack: scalar, retdest // Post stack: rlp_scalar -// We will make use of %encode_rlp_scalar, which clobbers RlpRaw. -// We're not hashing tries yet, so it's not an issue. 
global scalar_to_rlp: // stack: scalar, retdest - PUSH 0 + %mload_global_metadata(@GLOBAL_METADATA_RLP_DATA_SIZE) // stack: pos, scalar, retdest + SWAP1 DUP2 %encode_rlp_scalar - // stack: pos', retdest - // Now our rlp_encoding is in RlpRaw in the first pos' cells. - DUP1 // len of the key - PUSH 0 PUSH @SEGMENT_RLP_RAW PUSH 0 // address where we get the key from + // stack: pos', init_pos, retdest + // Now our rlp_encoding is in RlpRaw. + // Set new RlpRaw data size + DUP1 %mstore_global_metadata(@GLOBAL_METADATA_RLP_DATA_SIZE) + DUP2 DUP2 SUB // len of the key + // stack: len, pos', init_pos, retdest + DUP3 PUSH @SEGMENT_RLP_RAW PUSH 0 // address where we get the key from %mload_packing - // stack: packed_key, pos', retdest - SWAP1 POP + // stack: packed_key, pos', init_pos, retdest + SWAP2 %pop2 // stack: key, retdest SWAP1 JUMP diff --git a/evm/src/cpu/kernel/asm/mpt/load/load_trie_specific.asm b/evm/src/cpu/kernel/asm/mpt/load/load_trie_specific.asm index 9ac5177332..92471fd801 100644 --- a/evm/src/cpu/kernel/asm/mpt/load/load_trie_specific.asm +++ b/evm/src/cpu/kernel/asm/mpt/load/load_trie_specific.asm @@ -29,7 +29,23 @@ global mpt_load_state_trie_value: global mpt_load_txn_trie_value: // stack: retdest - PANIC // TODO + PROVER_INPUT(mpt) + // stack: rlp_len, retdest + // The first element is the rlp length + DUP1 %append_to_trie_data + PUSH 0 + +mpt_load_loop: + // stack: i, rlp_len, retdest + DUP2 DUP2 EQ %jumpi(mpt_load_end) + PROVER_INPUT(mpt) %append_to_trie_data + %increment + %jump(mpt_load_loop) + +mpt_load_end: + // stack: i, rlp_len, retdest + %pop2 + JUMP global mpt_load_receipt_trie_value: // stack: retdest diff --git a/evm/src/cpu/kernel/asm/rlp/encode.asm b/evm/src/cpu/kernel/asm/rlp/encode.asm index fd42fe5221..71eeaa8a96 100644 --- a/evm/src/cpu/kernel/asm/rlp/encode.asm +++ b/evm/src/cpu/kernel/asm/rlp/encode.asm @@ -40,7 +40,7 @@ global encode_rlp_fixed: %increment // increment pos // stack: pos, len, string, retdest %stack (pos, len, 
string) -> (pos, string, len, encode_rlp_fixed_finish) - // stack: context, segment, pos, string, len, encode_rlp_fixed_finish, retdest + // stack: pos, string, len, encode_rlp_fixed_finish, retdest %jump(mstore_unpacking_rlp) encode_rlp_fixed_finish: // stack: pos', retdest diff --git a/evm/src/cpu/kernel/asm/rlp/increment_bounded_rlp.asm b/evm/src/cpu/kernel/asm/rlp/increment_bounded_rlp.asm new file mode 100644 index 0000000000..2e76c20f8f --- /dev/null +++ b/evm/src/cpu/kernel/asm/rlp/increment_bounded_rlp.asm @@ -0,0 +1,38 @@ +// Increment by 1 the rlp encoded index and increment +// its number of nibbles when required. Shouldn't be +// called with rlp_index > 0x82 ff ff +global increment_bounded_rlp: + // stack: rlp_index, num_nibbles, retdest + DUP1 + %eq_const(0x80) + %jumpi(case_0x80) + DUP1 + %eq_const(0x7f) + %jumpi(case_0x7f) + DUP1 + %eq_const(0x81ff) + %jumpi(case_0x81ff) + // If rlp_index != 0x80 and rlp_index != 0x7f and rlp_index != 0x81ff + // we only need to add one and keep the number of nibbles + %increment + %stack (rlp_index, num_nibbles, retdest) -> (retdest, rlp_index, num_nibbles) + JUMP + +case_0x80: + %stack (rlp_index, num_nibbles, retdest) -> (retdest, 0x01, 2) + JUMP +case_0x7f: + %stack (rlp_index, num_nibbles, retdest) -> (retdest, 0x8180, 4) + JUMP + +case_0x81ff: + %stack (rlp_index, num_nibbles, retdest) -> (retdest, 0x820100, 6) + JUMP + + + +%macro increment_bounded_rlp + %stack (rlp_index, num_nibbles) -> (rlp_index, num_nibbles, %%after) + %jump(increment_bounded_rlp) +%%after: +%endmacro diff --git a/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm b/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm index 2d71e65a8b..85a7817522 100644 --- a/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm +++ b/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm @@ -31,6 +31,6 @@ read_rlp_to_memory_loop: read_rlp_to_memory_finish: // stack: pos, len, retdest - %pop2 - // stack: retdest - JUMP + POP + // stack: len, retdest + SWAP1 JUMP diff --git 
a/evm/src/cpu/kernel/asm/transactions/router.asm b/evm/src/cpu/kernel/asm/transactions/router.asm index 3f4ebe37db..c01216fb11 100644 --- a/evm/src/cpu/kernel/asm/transactions/router.asm +++ b/evm/src/cpu/kernel/asm/transactions/router.asm @@ -3,9 +3,12 @@ // jump to the appropriate transaction parsing method. global route_txn: - // stack: retdest + // stack: txn_counter, num_nibbles, retdest // First load transaction data into memory, where it will be parsed. PUSH read_txn_from_memory + SWAP2 SWAP1 + PUSH update_txn_trie + // stack: update_txn_trie, tx_counter, num_nibbles, read_txn_from_memory, retdest %jump(read_rlp_to_memory) // At this point, the raw txn data is in memory. @@ -34,3 +37,28 @@ read_txn_from_memory: // At this point, since it's not a type 1 or 2 transaction, // it must be a legacy (aka type 0) transaction. %jump(process_type_0_txn) + +global update_txn_trie: + // stack: txn_rlp_len, txn_counter, num_nibbles, retdest + // Copy the transaction rlp to the trie data segment. + %get_trie_data_size + // stack: value_ptr, txn_rlp_len, txn_counter, num_nibbles, retdest + SWAP1 + // First we write txn rlp length + DUP1 %append_to_trie_data + // stack: txn_rlp_len, value_ptr, txn_counter, num_nibbles, ret_dest + DUP2 %increment + // stack: rlp_start=value_ptr+1, txn_rlp_len, value_ptr, txn_counter, num_nibbles, retdest + + + // and now copy txn_rlp to the new block + %stack (rlp_start, txn_rlp_len, value_ptr, txn_counter, num_nibbles) -> ( + 0, @SEGMENT_TRIE_DATA, rlp_start, // dest addr + 0, @SEGMENT_RLP_RAW, 0, // src addr. 
Kernel has context 0 + txn_rlp_len, // mcpy len + txn_rlp_len, rlp_start, txn_counter, num_nibbles, value_ptr) + %memcpy + ADD + %set_trie_data_size + // stack: txn_counter, num_nibbles, value_ptr, retdest + %jump(mpt_insert_txn_trie) diff --git a/evm/src/cpu/kernel/tests/mpt/load.rs b/evm/src/cpu/kernel/tests/mpt/load.rs index 50a8a0ef1b..ae0bfa3bc8 100644 --- a/evm/src/cpu/kernel/tests/mpt/load.rs +++ b/evm/src/cpu/kernel/tests/mpt/load.rs @@ -1,5 +1,10 @@ +use std::str::FromStr; + use anyhow::{anyhow, Result}; +use eth_trie_utils::nibbles::Nibbles; +use eth_trie_utils::partial_trie::HashedPartialTrie; use ethereum_types::{BigEndianHash, H256, U256}; +use hex_literal::hex; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; @@ -247,3 +252,43 @@ fn load_all_mpts_ext_to_leaf() -> Result<()> { Ok(()) } + +#[test] +fn load_mpt_txn_trie() -> Result<()> { + let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; + + let txn = hex!("f860010a830186a094095e7baea6a6c7c4c2dfeb977efac326af552e89808025a04a223955b0bd3827e3740a9a427d0ea43beb5bafa44a0204bf0a3306c8219f7ba0502c32d78f233e9e7ce9f5df3b576556d5d49731e0678fd5a068cdf359557b5b").to_vec(); + + let trie_inputs = TrieInputs { + state_trie: Default::default(), + transactions_trie: HashedPartialTrie::from(Node::Leaf { + nibbles: Nibbles::from_str("0x80").unwrap(), + value: txn.clone(), + }), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + + let initial_stack = vec![0xDEADBEEFu32.into()]; + let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![]); + + let mut expected_trie_data = vec![ + 0.into(), + U256::from(PartialTrieType::Leaf as u32), + 2.into(), + 128.into(), // Nibble + 5.into(), // value_ptr + 
txn.len().into(), + ]; + expected_trie_data.extend(txn.into_iter().map(U256::from)); + let trie_data = interpreter.get_trie_data(); + + assert_eq!(trie_data, expected_trie_data); + + Ok(()) +} diff --git a/evm/src/cpu/kernel/tests/receipt.rs b/evm/src/cpu/kernel/tests/receipt.rs index b558365429..f82bbcda43 100644 --- a/evm/src/cpu/kernel/tests/receipt.rs +++ b/evm/src/cpu/kernel/tests/receipt.rs @@ -37,7 +37,15 @@ fn test_process_receipt() -> Result<()> { let expected_bloom = logs_bloom_bytes_fn(test_logs_list).to_vec(); // Set memory. - let initial_stack = vec![retdest, 0.into(), prev_cum_gas, leftover_gas, success]; + let num_nibbles = 2.into(); + let initial_stack: Vec = vec![ + retdest, + num_nibbles, + 0.into(), + prev_cum_gas, + leftover_gas, + success, + ]; let mut interpreter = Interpreter::new_with_kernel(process_receipt, initial_stack); interpreter.set_memory_segment( Segment::LogsData, @@ -119,7 +127,7 @@ fn test_receipt_encoding() -> Result<()> { // Get the expected RLP encoding. let expected_rlp = rlp::encode(&rlp::encode(&receipt_1)); - let initial_stack = vec![retdest, 0.into(), 0.into()]; + let initial_stack: Vec = vec![retdest, 0.into(), 0.into()]; let mut interpreter = Interpreter::new_with_kernel(encode_receipt, initial_stack); // Write data to memory. @@ -238,7 +246,7 @@ fn test_receipt_bloom_filter() -> Result<()> { let topic03 = 0xbd9fe6.into(); // Set logs memory and initialize TxnBloom and BlockBloom segments. - let initial_stack = vec![retdest]; + let initial_stack: Vec = vec![retdest]; let mut interpreter = Interpreter::new_with_kernel(logs_bloom, initial_stack); let mut logs = vec![ @@ -410,7 +418,7 @@ fn test_mpt_insert_receipt() -> Result<()> { receipt.extend(logs_0.clone()); // First, we load all mpts. 
- let initial_stack = vec![retdest]; + let initial_stack: Vec = vec![retdest]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); interpreter.generation_state.mpt_prover_inputs = @@ -425,7 +433,13 @@ fn test_mpt_insert_receipt() -> Result<()> { } // stack: transaction_nb, value_ptr, retdest - let initial_stack = [retdest, cur_trie_data.len().into(), 0.into()]; + let num_nibbles = 2; + let initial_stack: Vec = vec![ + retdest, + cur_trie_data.len().into(), + 0x80.into(), + num_nibbles.into(), + ]; for i in 0..initial_stack.len() { interpreter.push(initial_stack[i]); } @@ -489,7 +503,13 @@ fn test_mpt_insert_receipt() -> Result<()> { // Get updated TrieData segment. cur_trie_data = interpreter.get_memory_segment(Segment::TrieData); - let initial_stack2 = [retdest, cur_trie_data.len().into(), 1.into()]; + let num_nibbles = 2; + let initial_stack2: Vec = vec![ + retdest, + cur_trie_data.len().into(), + 0x01.into(), + num_nibbles.into(), + ]; for i in 0..initial_stack2.len() { interpreter.push(initial_stack2[i]); } @@ -528,7 +548,7 @@ fn test_bloom_two_logs() -> Result<()> { let retdest = 0xDEADBEEFu32.into(); let logs_bloom = KERNEL.global_labels["logs_bloom"]; - let initial_stack = vec![retdest]; + let initial_stack: Vec = vec![retdest]; // Set memory. 
let logs = vec![ diff --git a/evm/src/generation/mpt.rs b/evm/src/generation/mpt.rs index be99418812..1f1be2b9a0 100644 --- a/evm/src/generation/mpt.rs +++ b/evm/src/generation/mpt.rs @@ -185,7 +185,9 @@ pub(crate) fn all_mpt_prover_inputs(trie_inputs: &TrieInputs) -> Result(state: &GenerationState, op: Operation) { state.registers.context, KERNEL.offset_name(pc), op, - state.stack() + state.stack(), ); assert!(pc < KERNEL.code.len(), "Kernel PC is out of range: {}", pc); diff --git a/evm/tests/add11_yml.rs b/evm/tests/add11_yml.rs index d059620991..cb0212a388 100644 --- a/evm/tests/add11_yml.rs +++ b/evm/tests/add11_yml.rs @@ -139,9 +139,15 @@ fn add11_yml() -> anyhow::Result<()> { Nibbles::from_str("0x80").unwrap(), rlp::encode(&receipt_0).to_vec(), ); + let transactions_trie: HashedPartialTrie = Node::Leaf { + nibbles: Nibbles::from_str("0x80").unwrap(), + value: txn.to_vec(), + } + .into(); + let trie_roots_after = TrieRoots { state_root: expected_state_trie_after.hash(), - transactions_root: tries_before.transactions_trie.hash(), // TODO: Fix this when we have transactions trie. + transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; let inputs = GenerationInputs { diff --git a/evm/tests/basic_smart_contract.rs b/evm/tests/basic_smart_contract.rs index 162c5e2052..687328dcb9 100644 --- a/evm/tests/basic_smart_contract.rs +++ b/evm/tests/basic_smart_contract.rs @@ -171,10 +171,15 @@ fn test_basic_smart_contract() -> anyhow::Result<()> { Nibbles::from_str("0x80").unwrap(), rlp::encode(&receipt_0).to_vec(), ); + let transactions_trie: HashedPartialTrie = Node::Leaf { + nibbles: Nibbles::from_str("0x80").unwrap(), + value: txn.to_vec(), + } + .into(); let trie_roots_after = TrieRoots { state_root: expected_state_trie_after.hash(), - transactions_root: tries_before.transactions_trie.hash(), // TODO: Fix this when we have transactions trie. 
+ transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; let inputs = GenerationInputs { diff --git a/evm/tests/log_opcode.rs b/evm/tests/log_opcode.rs index e6821c53c4..b9d587b359 100644 --- a/evm/tests/log_opcode.rs +++ b/evm/tests/log_opcode.rs @@ -204,9 +204,15 @@ fn test_log_opcodes() -> anyhow::Result<()> { expected_state_trie_after.insert(sender_nibbles, rlp::encode(&sender_account_after).to_vec()); expected_state_trie_after.insert(to_nibbles, rlp::encode(&to_account_after).to_vec()); + let transactions_trie: HashedPartialTrie = Node::Leaf { + nibbles: Nibbles::from_str("0x80").unwrap(), + value: txn.to_vec(), + } + .into(); + let trie_roots_after = TrieRoots { state_root: expected_state_trie_after.hash(), - transactions_root: HashedPartialTrie::from(Node::Empty).hash(), + transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; let block_bloom_after = [ @@ -417,9 +423,15 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { rlp::encode(&receipt_0).to_vec(), ); + let mut transactions_trie: HashedPartialTrie = Node::Leaf { + nibbles: Nibbles::from_str("0x80").unwrap(), + value: txn.to_vec(), + } + .into(); + let tries_after = TrieRoots { state_root: expected_state_trie_after.hash(), - transactions_root: HashedPartialTrie::from(Node::Empty).hash(), + transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.clone().hash(), }; @@ -466,7 +478,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { let tries_before = TrieInputs { state_trie: state_trie_before, - transactions_trie: Node::Empty.into(), + transactions_trie: transactions_trie.clone(), receipts_trie: receipts_trie.clone(), storage_tries: vec![], }; @@ -543,9 +555,11 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { rlp::encode(&to_account_second_after).to_vec(), ); + transactions_trie.insert(Nibbles::from_str("0x01").unwrap(), txn_2.to_vec()); + let trie_roots_after = TrieRoots { state_root: expected_state_trie_after.hash(), - 
transactions_root: HashedPartialTrie::from(Node::Empty).hash(), + transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; @@ -863,9 +877,17 @@ fn test_two_txn() -> anyhow::Result<()> { rlp::encode(&receipt_1).to_vec(), ); + let mut transactions_trie: HashedPartialTrie = Node::Leaf { + nibbles: Nibbles::from_str("0x80").unwrap(), + value: txn_0.to_vec(), + } + .into(); + + transactions_trie.insert(Nibbles::from_str("0x01").unwrap(), txn_1.to_vec()); + let trie_roots_after = TrieRoots { state_root: expected_state_trie_after.hash(), - transactions_root: HashedPartialTrie::from(Node::Empty).hash(), + transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; let inputs = GenerationInputs { diff --git a/evm/tests/many_transactions.rs b/evm/tests/many_transactions.rs new file mode 100644 index 0000000000..134eb968f7 --- /dev/null +++ b/evm/tests/many_transactions.rs @@ -0,0 +1,246 @@ +#![allow(clippy::upper_case_acronyms)] + +use std::collections::HashMap; +use std::str::FromStr; +use std::time::Duration; + +use env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV}; +use eth_trie_utils::nibbles::Nibbles; +use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; +use ethereum_types::{Address, H256, U256}; +use hex_literal::hex; +use keccak_hash::keccak; +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2::plonk::config::KeccakGoldilocksConfig; +use plonky2::util::timing::TimingTree; +use plonky2_evm::all_stark::AllStark; +use plonky2_evm::config::StarkConfig; +use plonky2_evm::cpu::kernel::opcodes::{get_opcode, get_push_opcode}; +use plonky2_evm::generation::mpt::{AccountRlp, LegacyReceiptRlp}; +use plonky2_evm::generation::{GenerationInputs, TrieInputs}; +use plonky2_evm::proof::{BlockHashes, BlockMetadata, TrieRoots}; +use plonky2_evm::prover::prove; +use plonky2_evm::verifier::verify_proof; +use plonky2_evm::Node; + +type F = GoldilocksField; +const D: usize = 2; +type C = 
KeccakGoldilocksConfig; + +/// Test the validity of four transactions, where only the first one is valid and the other three abort. +#[test] +fn test_four_transactions() -> anyhow::Result<()> { + init_logger(); + + let all_stark = AllStark::::default(); + let config = StarkConfig::standard_fast_config(); + + let beneficiary = hex!("deadbeefdeadbeefdeadbeefdeadbeefdeadbeef"); + let sender = hex!("2c7536e3605d9c16a7a3d7b1898e529396a65c23"); + let to = hex!("a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0"); + + let beneficiary_state_key = keccak(beneficiary); + let sender_state_key = keccak(sender); + let to_state_key = keccak(to); + + let beneficiary_nibbles = Nibbles::from_bytes_be(beneficiary_state_key.as_bytes()).unwrap(); + let sender_nibbles = Nibbles::from_bytes_be(sender_state_key.as_bytes()).unwrap(); + let to_nibbles = Nibbles::from_bytes_be(to_state_key.as_bytes()).unwrap(); + + let push1 = get_push_opcode(1); + let add = get_opcode("ADD"); + let stop = get_opcode("STOP"); + let code = [push1, 3, push1, 4, add, stop]; + let code_gas = 3 + 3 + 3; + let code_hash = keccak(code); + + let beneficiary_account_before = AccountRlp::default(); + let sender_account_before = AccountRlp { + nonce: 5.into(), + + balance: eth_to_wei(100_000.into()), + + ..AccountRlp::default() + }; + let to_account_before = AccountRlp { + code_hash, + ..AccountRlp::default() + }; + + let state_trie_before = { + let mut children = core::array::from_fn(|_| Node::Empty.into()); + children[sender_nibbles.get_nibble(0) as usize] = Node::Leaf { + nibbles: sender_nibbles.truncate_n_nibbles_front(1), + + value: rlp::encode(&sender_account_before).to_vec(), + } + .into(); + children[to_nibbles.get_nibble(0) as usize] = Node::Leaf { + nibbles: to_nibbles.truncate_n_nibbles_front(1), + + value: rlp::encode(&to_account_before).to_vec(), + } + .into(); + Node::Branch { + children, + value: vec![], + } + } + .into(); + + let tries_before = TrieInputs { + state_trie: state_trie_before, + transactions_trie: 
Node::Empty.into(), + receipts_trie: Node::Empty.into(), + storage_tries: vec![], + }; + + // Generated using a little py-evm script. + let txn1 = hex!("f861050a8255f094a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0648242421ba02c89eb757d9deeb1f5b3859a9d4d679951ef610ac47ad4608dc142beb1b7e313a05af7e9fbab825455d36c36c7f4cfcafbeafa9a77bdff936b52afb36d4fe4bcdd"); + let txn2 = hex!("f863800a83061a8094095e7baea6a6c7c4c2dfeb977efac326af552d87830186a0801ba0ffb600e63115a7362e7811894a91d8ba4330e526f22121c994c4692035dfdfd5a06198379fcac8de3dbfac48b165df4bf88e2088f294b61efb9a65fe2281c76e16"); + let txn3 = hex!("f861800a8405f5e10094100000000000000000000000000000000000000080801ba07e09e26678ed4fac08a249ebe8ed680bf9051a5e14ad223e4b2b9d26e0208f37a05f6e3f188e3e6eab7d7d3b6568f5eac7d687b08d307d3154ccd8c87b4630509b"); + let txn4 = hex!("f866800a82520894095e7baea6a6c7c4c2dfeb977efac326af552d878711c37937e080008026a01fcd0ce88ac7600698a771f206df24b70e67981b6f107bd7c1c24ea94f113bcba00d87cc5c7afc2988e4ff200b5a0c7016b0d5498bbc692065ca983fcbbfe02555"); + + let txdata_gas = 2 * 16; + let gas_used = 21_000 + code_gas + txdata_gas; + + let value = U256::from(100u32); + + let block_metadata = BlockMetadata { + block_beneficiary: Address::from(beneficiary), + block_timestamp: 0x03e8.into(), + block_number: 1.into(), + block_difficulty: 0x020000.into(), + block_gaslimit: 0x445566u64.into(), + block_chain_id: 1.into(), + block_gas_used: gas_used.into(), + ..BlockMetadata::default() + }; + + let mut contract_code = HashMap::new(); + contract_code.insert(keccak(vec![]), vec![]); + contract_code.insert(code_hash, code.to_vec()); + + // Update trie roots after the 4 transactions. + // State trie. 
+ let expected_state_trie_after: HashedPartialTrie = { + let beneficiary_account_after = AccountRlp { + balance: beneficiary_account_before.balance + gas_used * 10, + ..beneficiary_account_before + }; + let sender_account_after = AccountRlp { + balance: sender_account_before.balance - value - gas_used * 10, + nonce: sender_account_before.nonce + 1, + ..sender_account_before + }; + let to_account_after = AccountRlp { + balance: to_account_before.balance + value, + ..to_account_before + }; + + let mut children = core::array::from_fn(|_| Node::Empty.into()); + children[beneficiary_nibbles.get_nibble(0) as usize] = Node::Leaf { + nibbles: beneficiary_nibbles.truncate_n_nibbles_front(1), + + value: rlp::encode(&beneficiary_account_after).to_vec(), + } + .into(); + children[sender_nibbles.get_nibble(0) as usize] = Node::Leaf { + nibbles: sender_nibbles.truncate_n_nibbles_front(1), + + value: rlp::encode(&sender_account_after).to_vec(), + } + .into(); + children[to_nibbles.get_nibble(0) as usize] = Node::Leaf { + nibbles: to_nibbles.truncate_n_nibbles_front(1), + + value: rlp::encode(&to_account_after).to_vec(), + } + .into(); + Node::Branch { + children, + value: vec![], + } + } + .into(); + + // Transactions trie. + let mut transactions_trie: HashedPartialTrie = Node::Leaf { + nibbles: Nibbles::from_str("0x80").unwrap(), + value: txn1.to_vec(), + } + .into(); + transactions_trie.insert(Nibbles::from_str("0x01").unwrap(), txn2.to_vec()); + transactions_trie.insert(Nibbles::from_str("0x02").unwrap(), txn3.to_vec()); + transactions_trie.insert(Nibbles::from_str("0x03").unwrap(), txn4.to_vec()); + + // Receipts trie. 
+ let mut receipts_trie = HashedPartialTrie::from(Node::Empty); + let receipt_0 = LegacyReceiptRlp { + status: true, + cum_gas_used: gas_used.into(), + bloom: [0x00; 256].to_vec().into(), + logs: vec![], + }; + let receipt_1 = LegacyReceiptRlp { + status: false, + cum_gas_used: gas_used.into(), + bloom: [0x00; 256].to_vec().into(), + logs: vec![], + }; + receipts_trie.insert( + Nibbles::from_str("0x80").unwrap(), + rlp::encode(&receipt_0).to_vec(), + ); + receipts_trie.insert( + Nibbles::from_str("0x01").unwrap(), + rlp::encode(&receipt_1).to_vec(), + ); + receipts_trie.insert( + Nibbles::from_str("0x02").unwrap(), + rlp::encode(&receipt_1).to_vec(), + ); + receipts_trie.insert( + Nibbles::from_str("0x03").unwrap(), + rlp::encode(&receipt_1).to_vec(), + ); + + let trie_roots_after = TrieRoots { + state_root: expected_state_trie_after.hash(), + transactions_root: transactions_trie.hash(), + receipts_root: receipts_trie.hash(), + }; + let inputs = GenerationInputs { + signed_txns: vec![txn1.to_vec(), txn2.to_vec(), txn3.to_vec(), txn4.to_vec()], + tries: tries_before, + trie_roots_after, + genesis_state_trie_root: HashedPartialTrie::from(Node::Empty).hash(), + contract_code, + block_metadata: block_metadata.clone(), + addresses: vec![], + block_bloom_before: [0.into(); 8], + gas_used_before: 0.into(), + gas_used_after: gas_used.into(), + txn_number_before: 0.into(), + block_bloom_after: [0.into(); 8], + block_hashes: BlockHashes { + prev_hashes: vec![H256::default(); 256], + cur_hash: H256::default(), + }, + }; + + let mut timing = TimingTree::new("prove", log::Level::Debug); + let proof = prove::(&all_stark, &config, inputs, &mut timing)?; + timing.filter(Duration::from_millis(100)).print(); + + verify_proof(&all_stark, proof, &config) +} + +fn eth_to_wei(eth: U256) -> U256 { + // 1 ether = 10^18 wei. 
+ eth * U256::from(10).pow(18.into()) +} + +fn init_logger() { + let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "info")); +} diff --git a/evm/tests/self_balance_gas_cost.rs b/evm/tests/self_balance_gas_cost.rs index ea434c5542..4492ba9af4 100644 --- a/evm/tests/self_balance_gas_cost.rs +++ b/evm/tests/self_balance_gas_cost.rs @@ -158,9 +158,15 @@ fn self_balance_gas_cost() -> anyhow::Result<()> { Nibbles::from_str("0x80").unwrap(), rlp::encode(&receipt_0).to_vec(), ); + let transactions_trie: HashedPartialTrie = Node::Leaf { + nibbles: Nibbles::from_str("0x80").unwrap(), + value: txn.to_vec(), + } + .into(); + let trie_roots_after = TrieRoots { state_root: expected_state_trie_after.hash(), - transactions_root: tries_before.transactions_trie.hash(), // TODO: Fix this when we have transactions trie. + transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; let inputs = GenerationInputs { diff --git a/evm/tests/simple_transfer.rs b/evm/tests/simple_transfer.rs index 9ec28aa487..80bee8afeb 100644 --- a/evm/tests/simple_transfer.rs +++ b/evm/tests/simple_transfer.rs @@ -55,6 +55,7 @@ fn test_simple_transfer() -> anyhow::Result<()> { value: rlp::encode(&sender_account_before).to_vec(), } .into(); + let tries_before = TrieInputs { state_trie: state_trie_before, transactions_trie: HashedPartialTrie::from(Node::Empty), @@ -125,10 +126,15 @@ fn test_simple_transfer() -> anyhow::Result<()> { Nibbles::from_str("0x80").unwrap(), rlp::encode(&receipt_0).to_vec(), ); + let transactions_trie: HashedPartialTrie = Node::Leaf { + nibbles: Nibbles::from_str("0x80").unwrap(), + value: txn.to_vec(), + } + .into(); let trie_roots_after = TrieRoots { state_root: expected_state_trie_after.hash(), - transactions_root: tries_before.transactions_trie.hash(), // TODO: Fix this when we have transactions trie. 
+ transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; let inputs = GenerationInputs { From 300059572b0dd5bb53c0060bf4dc32f8a5a9bf90 Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Fri, 29 Sep 2023 15:57:56 +0200 Subject: [PATCH 18/34] Optimize lookup builder (#1258) * Add tests with big LUTs * Optimize lookup builder * Fix comment describing optimization * Cargo fmt * Clone LookupTableGate instead of instantiating * Remove needless enumerate + improving comments --- plonky2/src/gadgets/lookup.rs | 47 +++++++++++--- plonky2/src/lookup_test.rs | 114 ++++++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 10 deletions(-) diff --git a/plonky2/src/gadgets/lookup.rs b/plonky2/src/gadgets/lookup.rs index b71b780879..826f3e2902 100644 --- a/plonky2/src/gadgets/lookup.rs +++ b/plonky2/src/gadgets/lookup.rs @@ -89,25 +89,52 @@ impl, const D: usize> CircuitBuilder { let lookups = self.get_lut_lookups(lut_index).to_owned(); - for (looking_in, looking_out) in lookups { - let gate = LookupGate::new_from_table(&self.config, lut.clone()); + let gate = LookupGate::new_from_table(&self.config, lut.clone()); + let num_slots = LookupGate::num_slots(&self.config); + + // Given the number of lookups and the number of slots for each gate, it is possible + // to compute the number of gates that will employ all their slots; such gates can + // can be instantiated with `add_gate` rather than being instantiated slot by slot + + // lookup_iter will iterate over the lookups that can be placed in fully utilized + // gates, splitting them in chunks that can be placed in the same `LookupGate` + let lookup_iter = lookups.chunks_exact(num_slots); + // `last_chunk` will contain the remainder of lookups, which cannot fill all the + // slots of a `LookupGate`; this last chunk will be processed by incrementally + // filling slots, to avoid that the `LookupGenerator` is run on unused slots + let last_chunk = lookup_iter.remainder(); + // handle 
chunks that can fill all the slots of a `LookupGate` + lookup_iter.for_each(|chunk| { + let row = self.add_gate(gate.clone(), vec![]); + for (i, (looking_in, looking_out)) in chunk.iter().enumerate() { + let gate_in = Target::wire(row, LookupGate::wire_ith_looking_inp(i)); + let gate_out = Target::wire(row, LookupGate::wire_ith_looking_out(i)); + self.connect(gate_in, *looking_in); + self.connect(gate_out, *looking_out); + } + }); + // deal with the last chunk + for (looking_in, looking_out) in last_chunk.iter() { let (gate, i) = - self.find_slot(gate, &[F::from_canonical_usize(lut_index)], &[]); + self.find_slot(gate.clone(), &[F::from_canonical_usize(lut_index)], &[]); let gate_in = Target::wire(gate, LookupGate::wire_ith_looking_inp(i)); let gate_out = Target::wire(gate, LookupGate::wire_ith_looking_out(i)); - self.connect(gate_in, looking_in); - self.connect(gate_out, looking_out); + self.connect(gate_in, *looking_in); + self.connect(gate_out, *looking_out); } // Create LUT gates. Nothing is connected to them. 
let last_lut_gate = self.num_gates(); let num_lut_entries = LookupTableGate::num_slots(&self.config); let num_lut_rows = (self.get_luts_idx_length(lut_index) - 1) / num_lut_entries + 1; - let num_lut_cells = num_lut_entries * num_lut_rows; - for _ in 0..num_lut_cells { - let gate = - LookupTableGate::new_from_table(&self.config, lut.clone(), last_lut_gate); - self.find_slot(gate, &[], &[]); + let gate = + LookupTableGate::new_from_table(&self.config, lut.clone(), last_lut_gate); + // Also instances of `LookupTableGate` can be placed with the `add_gate` function + // rather than being instantiated slot by slot; note that in this case there is no + // need to separately handle the last chunk of LUT entries that cannot fill all the + // slots of a `LookupTableGate`, as the generator already handles empty slots + for _ in 0..num_lut_rows { + self.add_gate(gate.clone(), vec![]); } let first_lut_gate = self.num_gates() - 1; diff --git a/plonky2/src/lookup_test.rs b/plonky2/src/lookup_test.rs index bca90d59e3..af85decaeb 100644 --- a/plonky2/src/lookup_test.rs +++ b/plonky2/src/lookup_test.rs @@ -467,6 +467,120 @@ pub fn test_same_luts() -> anyhow::Result<()> { Ok(()) } +#[test] +fn test_big_lut() -> anyhow::Result<()> { + use crate::field::types::Field; + use crate::iop::witness::{PartialWitness, WitnessWrite}; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + LOGGER_INITIALIZED.call_once(|| init_logger().unwrap()); + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + const LUT_SIZE: usize = u16::MAX as usize + 1; + let inputs: [u16; LUT_SIZE] = core::array::from_fn(|i| i as u16); + let lut_fn = |inp: u16| inp / 10; + let lut_index = builder.add_lookup_table_from_fn(lut_fn, &inputs); + + let initial_a = 
builder.add_virtual_target(); + let initial_b = builder.add_virtual_target(); + + let look_val_a = 51; + let look_val_b = 2; + + let output_a = builder.add_lookup_from_index(initial_a, lut_index); + let output_b = builder.add_lookup_from_index(initial_b, lut_index); + + builder.register_public_input(output_a); + builder.register_public_input(output_b); + + let data = builder.build::(); + + let mut pw = PartialWitness::new(); + + pw.set_target(initial_a, F::from_canonical_u16(look_val_a)); + pw.set_target(initial_b, F::from_canonical_u16(look_val_b)); + + let proof = data.prove(pw)?; + assert_eq!( + proof.public_inputs[0], + F::from_canonical_u16(lut_fn(look_val_a)) + ); + assert_eq!( + proof.public_inputs[1], + F::from_canonical_u16(lut_fn(look_val_b)) + ); + + data.verify(proof) +} + +#[test] +fn test_many_lookups_on_big_lut() -> anyhow::Result<()> { + use crate::field::types::Field; + use crate::iop::witness::{PartialWitness, WitnessWrite}; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + LOGGER_INITIALIZED.call_once(|| init_logger().unwrap()); + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + const LUT_SIZE: usize = u16::MAX as usize + 1; + let inputs: [u16; LUT_SIZE] = core::array::from_fn(|i| i as u16); + let lut_fn = |inp: u16| inp / 10; + let lut_index = builder.add_lookup_table_from_fn(lut_fn, &inputs); + + let inputs = (0..LUT_SIZE) + .map(|_| { + let input_target = builder.add_virtual_target(); + _ = builder.add_lookup_from_index(input_target, lut_index); + input_target + }) + .collect::>(); + + let initial_a = builder.add_virtual_target(); + let initial_b = builder.add_virtual_target(); + + let look_val_a = 51; + let look_val_b = 2; + + let output_a = 
builder.add_lookup_from_index(initial_a, lut_index); + let output_b = builder.add_lookup_from_index(initial_b, lut_index); + let sum = builder.add(output_a, output_b); + + builder.register_public_input(sum); + + let data = builder.build::(); + + let mut pw = PartialWitness::new(); + + inputs + .into_iter() + .enumerate() + .for_each(|(i, t)| pw.set_target(t, F::from_canonical_usize(i))); + pw.set_target(initial_a, F::from_canonical_u16(look_val_a)); + pw.set_target(initial_b, F::from_canonical_u16(look_val_b)); + + let proof = data.prove(pw)?; + assert_eq!( + proof.public_inputs[0], + F::from_canonical_u16(lut_fn(look_val_a) + lut_fn(look_val_b)) + ); + + data.verify(proof) +} + fn init_logger() -> anyhow::Result<()> { let mut builder = env_logger::Builder::from_default_env(); builder.format_timestamp(None); From 8afd06cfdd4db635c9523a93a0ac3be8a46875fe Mon Sep 17 00:00:00 2001 From: Robin Salen Date: Fri, 29 Sep 2023 11:24:36 -0400 Subject: [PATCH 19/34] Fix description of Range-Check columns in STARK modules --- evm/src/arithmetic/columns.rs | 7 ++----- evm/src/byte_packing/columns.rs | 7 ++----- evm/src/memory/columns.rs | 4 ++-- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/evm/src/arithmetic/columns.rs b/evm/src/arithmetic/columns.rs index f2646fc565..36eb983e0b 100644 --- a/evm/src/arithmetic/columns.rs +++ b/evm/src/arithmetic/columns.rs @@ -104,12 +104,9 @@ pub(crate) const MODULAR_AUX_INPUT_HI: Range = AUX_REGISTER_2; // Must be set to MOD_IS_ZERO for DIV operation i.e. MOD_IS_ZERO * lv[IS_DIV] pub(crate) const MODULAR_DIV_DENOM_IS_ZERO: usize = AUX_REGISTER_2.end; -// Need one column for the table, then two columns for every value -// that needs to be range checked in the trace, namely the permutation -// of the column and the permutation of the range. The two -// permutations associated to column i will be in columns RC_COLS[2i] -// and RC_COLS[2i+1]. +/// The counter column (used for the range check) starts from 0 and increments. 
pub(crate) const RANGE_COUNTER: usize = START_SHARED_COLS + NUM_SHARED_COLS; +/// The frequencies column used in logUp. pub(crate) const RC_FREQUENCIES: usize = RANGE_COUNTER + 1; pub const NUM_ARITH_COLUMNS: usize = START_SHARED_COLS + NUM_SHARED_COLS + 2; diff --git a/evm/src/byte_packing/columns.rs b/evm/src/byte_packing/columns.rs index fdaa46211f..4eff0df8f5 100644 --- a/evm/src/byte_packing/columns.rs +++ b/evm/src/byte_packing/columns.rs @@ -33,12 +33,9 @@ pub(crate) const fn value_bytes(i: usize) -> usize { BYTES_VALUES_START + i } -// We need one column for the table, then two columns for every value -// that needs to be range checked in the trace (all written bytes), -// namely the permutation of the column and the permutation of the range. -// The two permutations associated to the byte in column i will be in -// columns RC_COLS[2i] and RC_COLS[2i+1]. +/// The counter column (used for the range check) starts from 0 and increments. pub(crate) const RANGE_COUNTER: usize = BYTES_VALUES_START + NUM_BYTES; +/// The frequencies column used in logUp. pub(crate) const RC_FREQUENCIES: usize = RANGE_COUNTER + 1; pub(crate) const NUM_COLUMNS: usize = RANGE_COUNTER + 2; diff --git a/evm/src/memory/columns.rs b/evm/src/memory/columns.rs index 56b121e1e2..9a41323200 100644 --- a/evm/src/memory/columns.rs +++ b/evm/src/memory/columns.rs @@ -29,9 +29,9 @@ pub(crate) const VIRTUAL_FIRST_CHANGE: usize = SEGMENT_FIRST_CHANGE + 1; // We use a range check to enforce the ordering. pub(crate) const RANGE_CHECK: usize = VIRTUAL_FIRST_CHANGE + 1; -// The counter column (used for the range check) starts from 0 and increments. +/// The counter column (used for the range check) starts from 0 and increments. pub(crate) const COUNTER: usize = RANGE_CHECK + 1; -// The frequencies column used in logUp. +/// The frequencies column used in logUp. 
pub(crate) const FREQUENCIES: usize = COUNTER + 1; pub(crate) const NUM_COLUMNS: usize = FREQUENCIES + 1; From 0f19cd0dbc25f9f1aa8fc325ae4dd1b95ca933b3 Mon Sep 17 00:00:00 2001 From: Robin Salen <30937548+Nashtare@users.noreply.github.com> Date: Fri, 29 Sep 2023 14:47:23 -0400 Subject: [PATCH 20/34] Make gas fit in 2 limbs (#1261) * Make gas fit in 2 limbs * Fix recursive challenger * Fix indices * Add clarifying comments on ranges supported * Add mention to revert before production --- evm/src/cpu/columns/mod.rs | 4 +- evm/src/cpu/gas.rs | 42 ++++++++----- evm/src/cpu/jumps.rs | 13 ++-- evm/src/cpu/syscalls_exceptions.rs | 21 ++++--- evm/src/fixed_recursive_verifier.rs | 28 ++++++--- evm/src/generation/mod.rs | 5 +- evm/src/get_challenges.rs | 24 +++++--- evm/src/proof.rs | 95 ++++++++++++++++++----------- evm/src/recursive_verifier.rs | 72 ++++++++++++---------- evm/src/witness/operation.rs | 6 +- evm/src/witness/transition.rs | 5 +- 11 files changed, 190 insertions(+), 125 deletions(-) diff --git a/evm/src/cpu/columns/mod.rs b/evm/src/cpu/columns/mod.rs index fecc8df986..cc98fceb3f 100644 --- a/evm/src/cpu/columns/mod.rs +++ b/evm/src/cpu/columns/mod.rs @@ -58,8 +58,8 @@ pub struct CpuColumnsView { /// If CPU cycle: We're in kernel (privileged) mode. pub is_kernel_mode: T, - /// If CPU cycle: Gas counter. - pub gas: T, + /// If CPU cycle: Gas counter, split in two 32-bit limbs in little-endian order. + pub gas: [T; 2], /// If CPU cycle: flags for EVM instructions (a few cannot be shared; see the comments in /// `OpsColumnsView`). diff --git a/evm/src/cpu/gas.rs b/evm/src/cpu/gas.rs index 51f375c056..694fb0f47e 100644 --- a/evm/src/cpu/gas.rs +++ b/evm/src/cpu/gas.rs @@ -70,20 +70,26 @@ fn eval_packed_accumulate( }) .sum(); - let constr = nv.gas - (lv.gas + gas_used); + // TODO: This may cause soundness issue if the recomputed gas (as u64) overflows the field size. 
+ // This is fine as we are only using two-limbs for testing purposes (to support all cases from + // the Ethereum test suite). + // This should be changed back to a single 32-bit limb before going into production! + let gas_diff = nv.gas[1] * P::Scalar::from_canonical_u64(1 << 32) + nv.gas[0] + - (lv.gas[1] * P::Scalar::from_canonical_u64(1 << 32) + lv.gas[0]); + let constr = gas_diff - gas_used; yield_constr.constraint_transition(filter * constr); for (maybe_cost, op_flag) in izip!(SIMPLE_OPCODES.into_iter(), lv.op.into_iter()) { if let Some(cost) = maybe_cost { let cost = P::Scalar::from_canonical_u32(cost); - yield_constr.constraint_transition(op_flag * (nv.gas - lv.gas - cost)); + yield_constr.constraint_transition(op_flag * (gas_diff - cost)); } } // For jumps. let jump_gas_cost = P::Scalar::from_canonical_u32(G_MID.unwrap()) + lv.opcode_bits[0] * P::Scalar::from_canonical_u32(G_HIGH.unwrap() - G_MID.unwrap()); - yield_constr.constraint_transition(lv.op.jumps * (nv.gas - lv.gas - jump_gas_cost)); + yield_constr.constraint_transition(lv.op.jumps * (gas_diff - jump_gas_cost)); // For binary_ops. // MUL, DIV and MOD are differentiated from ADD, SUB, LT, GT and BYTE by their first and fifth bits set to 0. @@ -92,13 +98,13 @@ fn eval_packed_accumulate( + cost_filter * (P::Scalar::from_canonical_u32(G_VERYLOW.unwrap()) - P::Scalar::from_canonical_u32(G_LOW.unwrap())); - yield_constr.constraint_transition(lv.op.binary_op * (nv.gas - lv.gas - binary_op_cost)); + yield_constr.constraint_transition(lv.op.binary_op * (gas_diff - binary_op_cost)); // For ternary_ops. // SUBMOD is differentiated by its second bit set to 1. 
let ternary_op_cost = P::Scalar::from_canonical_u32(G_MID.unwrap()) - lv.opcode_bits[1] * P::Scalar::from_canonical_u32(G_MID.unwrap()); - yield_constr.constraint_transition(lv.op.ternary_op * (nv.gas - lv.gas - ternary_op_cost)); + yield_constr.constraint_transition(lv.op.ternary_op * (gas_diff - ternary_op_cost)); } fn eval_packed_init( @@ -111,7 +117,8 @@ fn eval_packed_init( // `nv` is the first row that executes an instruction. let filter = (is_cpu_cycle - P::ONES) * is_cpu_cycle_next; // Set initial gas to zero. - yield_constr.constraint_transition(filter * nv.gas); + yield_constr.constraint_transition(filter * nv.gas[0]); + yield_constr.constraint_transition(filter * nv.gas[1]); } pub fn eval_packed( @@ -154,16 +161,22 @@ fn eval_ext_circuit_accumulate, const D: usize>( }, ); - let constr = { - let t = builder.add_extension(lv.gas, gas_used); - builder.sub_extension(nv.gas, t) - }; + // TODO: This may cause soundness issue if the recomputed gas (as u64) overflows the field size. + // This is fine as we are only using two-limbs for testing purposes (to support all cases from + // the Ethereum test suite). + // This should be changed back to a single 32-bit limb before going into production! 
+ let nv_gas = + builder.mul_const_add_extension(F::from_canonical_u64(1 << 32), nv.gas[1], nv.gas[0]); + let lv_gas = + builder.mul_const_add_extension(F::from_canonical_u64(1 << 32), lv.gas[1], lv.gas[0]); + let nv_lv_diff = builder.sub_extension(nv_gas, lv_gas); + + let constr = builder.sub_extension(nv_lv_diff, gas_used); let filtered_constr = builder.mul_extension(filter, constr); yield_constr.constraint_transition(builder, filtered_constr); for (maybe_cost, op_flag) in izip!(SIMPLE_OPCODES.into_iter(), lv.op.into_iter()) { if let Some(cost) = maybe_cost { - let nv_lv_diff = builder.sub_extension(nv.gas, lv.gas); let constr = builder.arithmetic_extension( F::ONE, -F::from_canonical_u32(cost), @@ -184,7 +197,6 @@ fn eval_ext_circuit_accumulate, const D: usize>( let jump_gas_cost = builder.add_const_extension(jump_gas_cost, F::from_canonical_u32(G_MID.unwrap())); - let nv_lv_diff = builder.sub_extension(nv.gas, lv.gas); let gas_diff = builder.sub_extension(nv_lv_diff, jump_gas_cost); let constr = builder.mul_extension(filter, gas_diff); yield_constr.constraint_transition(builder, constr); @@ -204,7 +216,6 @@ fn eval_ext_circuit_accumulate, const D: usize>( let binary_op_cost = builder.add_const_extension(binary_op_cost, F::from_canonical_u32(G_LOW.unwrap())); - let nv_lv_diff = builder.sub_extension(nv.gas, lv.gas); let gas_diff = builder.sub_extension(nv_lv_diff, binary_op_cost); let constr = builder.mul_extension(filter, gas_diff); yield_constr.constraint_transition(builder, constr); @@ -219,7 +230,6 @@ fn eval_ext_circuit_accumulate, const D: usize>( let ternary_op_cost = builder.add_const_extension(ternary_op_cost, F::from_canonical_u32(G_MID.unwrap())); - let nv_lv_diff = builder.sub_extension(nv.gas, lv.gas); let gas_diff = builder.sub_extension(nv_lv_diff, ternary_op_cost); let constr = builder.mul_extension(filter, gas_diff); yield_constr.constraint_transition(builder, constr); @@ -236,7 +246,9 @@ fn eval_ext_circuit_init, const D: usize>( let 
is_cpu_cycle_next = builder.add_many_extension(COL_MAP.op.iter().map(|&col_i| nv[col_i])); let filter = builder.mul_sub_extension(is_cpu_cycle, is_cpu_cycle_next, is_cpu_cycle_next); // Set initial gas to zero. - let constr = builder.mul_extension(filter, nv.gas); + let constr = builder.mul_extension(filter, nv.gas[0]); + yield_constr.constraint_transition(builder, constr); + let constr = builder.mul_extension(filter, nv.gas[1]); yield_constr.constraint_transition(builder, constr); } diff --git a/evm/src/cpu/jumps.rs b/evm/src/cpu/jumps.rs index 62d9bdfd25..1829177384 100644 --- a/evm/src/cpu/jumps.rs +++ b/evm/src/cpu/jumps.rs @@ -23,9 +23,8 @@ pub fn eval_packed_exit_kernel( // but we trust the kernel to set them to zero). yield_constr.constraint_transition(filter * (input[0] - nv.program_counter)); yield_constr.constraint_transition(filter * (input[1] - nv.is_kernel_mode)); - yield_constr.constraint_transition(filter * (input[6] - nv.gas)); - // High limb of gas must be 0 for convenient detection of overflow. - yield_constr.constraint(filter * input[7]); + yield_constr.constraint_transition(filter * (input[6] - nv.gas[0])); + yield_constr.constraint_transition(filter * (input[7] - nv.gas[1])); } pub fn eval_ext_circuit_exit_kernel, const D: usize>( @@ -50,14 +49,14 @@ pub fn eval_ext_circuit_exit_kernel, const D: usize yield_constr.constraint_transition(builder, kernel_constr); { - let diff = builder.sub_extension(input[6], nv.gas); + let diff = builder.sub_extension(input[6], nv.gas[0]); let constr = builder.mul_extension(filter, diff); yield_constr.constraint_transition(builder, constr); } { - // High limb of gas must be 0 for convenient detection of overflow. 
- let constr = builder.mul_extension(filter, input[7]); - yield_constr.constraint(builder, constr); + let diff = builder.sub_extension(input[7], nv.gas[1]); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint_transition(builder, constr); } } diff --git a/evm/src/cpu/syscalls_exceptions.rs b/evm/src/cpu/syscalls_exceptions.rs index abc47baf73..f9ea9a0a9f 100644 --- a/evm/src/cpu/syscalls_exceptions.rs +++ b/evm/src/cpu/syscalls_exceptions.rs @@ -99,7 +99,8 @@ pub fn eval_packed( // Maintain current context yield_constr.constraint_transition(total_filter * (nv.context - lv.context)); // Reset gas counter to zero. - yield_constr.constraint_transition(total_filter * nv.gas); + yield_constr.constraint_transition(total_filter * nv.gas[0]); + yield_constr.constraint_transition(total_filter * nv.gas[1]); // This memory channel is constrained in `stack.rs`. let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value; @@ -108,9 +109,9 @@ pub fn eval_packed( yield_constr.constraint(filter_exception * (output[0] - lv.program_counter)); // Check the kernel mode, for syscalls only yield_constr.constraint(filter_syscall * (output[1] - lv.is_kernel_mode)); - yield_constr.constraint(total_filter * (output[6] - lv.gas)); - // TODO: Range check `output[6]`. - yield_constr.constraint(total_filter * output[7]); // High limb of gas is zero. + // TODO: Range check `output[6] and output[7]`. + yield_constr.constraint(total_filter * (output[6] - lv.gas[0])); + yield_constr.constraint(total_filter * (output[7] - lv.gas[1])); // Zero the rest of that register // output[1] is 0 for exceptions, but not for syscalls @@ -265,7 +266,9 @@ pub fn eval_ext_circuit, const D: usize>( } // Reset gas counter to zero. 
{ - let constr = builder.mul_extension(total_filter, nv.gas); + let constr = builder.mul_extension(total_filter, nv.gas[0]); + yield_constr.constraint_transition(builder, constr); + let constr = builder.mul_extension(total_filter, nv.gas[1]); yield_constr.constraint_transition(builder, constr); } @@ -290,15 +293,15 @@ pub fn eval_ext_circuit, const D: usize>( let constr = builder.mul_extension(filter_syscall, diff); yield_constr.constraint(builder, constr); } + // TODO: Range check `output[6]` and `output[7]. { - let diff = builder.sub_extension(output[6], lv.gas); + let diff = builder.sub_extension(output[6], lv.gas[0]); let constr = builder.mul_extension(total_filter, diff); yield_constr.constraint(builder, constr); } - // TODO: Range check `output[6]`. { - // High limb of gas is zero. - let constr = builder.mul_extension(total_filter, output[7]); + let diff = builder.sub_extension(output[7], lv.gas[1]); + let constr = builder.mul_extension(total_filter, diff); yield_constr.constraint(builder, constr); } diff --git a/evm/src/fixed_recursive_verifier.rs b/evm/src/fixed_recursive_verifier.rs index 1c0928e4a8..766b9102bb 100644 --- a/evm/src/fixed_recursive_verifier.rs +++ b/evm/src/fixed_recursive_verifier.rs @@ -655,15 +655,18 @@ where builder.connect(pvs.txn_number_before, lhs.txn_number_before); builder.connect(pvs.txn_number_after, rhs.txn_number_after); - // Connect lhs `txn_number_after`with rhs `txn_number_before`. + // Connect lhs `txn_number_after` with rhs `txn_number_before`. builder.connect(lhs.txn_number_after, rhs.txn_number_before); // Connect the gas used in public values to the lhs and rhs values correctly. 
- builder.connect(pvs.gas_used_before, lhs.gas_used_before); - builder.connect(pvs.gas_used_after, rhs.gas_used_after); + builder.connect(pvs.gas_used_before[0], lhs.gas_used_before[0]); + builder.connect(pvs.gas_used_before[1], lhs.gas_used_before[1]); + builder.connect(pvs.gas_used_after[0], rhs.gas_used_after[0]); + builder.connect(pvs.gas_used_after[1], rhs.gas_used_after[1]); - // Connect lhs `gas_used_after`with rhs `gas_used_before`. - builder.connect(lhs.gas_used_after, rhs.gas_used_before); + // Connect lhs `gas_used_after` with rhs `gas_used_before`. + builder.connect(lhs.gas_used_after[0], rhs.gas_used_before[0]); + builder.connect(lhs.gas_used_after[1], rhs.gas_used_before[1]); // Connect the `block_bloom` in public values to the lhs and rhs values correctly. for (&limb0, &limb1) in pvs.block_bloom_after.iter().zip(&rhs.block_bloom_after) { @@ -672,7 +675,7 @@ where for (&limb0, &limb1) in pvs.block_bloom_before.iter().zip(&lhs.block_bloom_before) { builder.connect(limb0, limb1); } - // Connect lhs `block_bloom_after`with rhs `block_bloom_before`. + // Connect lhs `block_bloom_after` with rhs `block_bloom_before`. for (&limb0, &limb1) in lhs.block_bloom_after.iter().zip(&rhs.block_bloom_before) { builder.connect(limb0, limb1); } @@ -846,8 +849,12 @@ where F: RichField + Extendable, { builder.connect( - x.block_metadata.block_gas_used, - x.extra_block_data.gas_used_after, + x.block_metadata.block_gas_used[0], + x.extra_block_data.gas_used_after[0], + ); + builder.connect( + x.block_metadata.block_gas_used[1], + x.extra_block_data.gas_used_after[1], ); for (&limb0, &limb1) in x @@ -867,8 +874,9 @@ where let zero = builder.constant(F::ZERO); // The initial number of transactions is 0. builder.connect(x.extra_block_data.txn_number_before, zero); - // The initial gas used is 0 - builder.connect(x.extra_block_data.gas_used_before, zero); + // The initial gas used is 0. 
+ builder.connect(x.extra_block_data.gas_used_before[0], zero); + builder.connect(x.extra_block_data.gas_used_before[1], zero); // The initial bloom filter is all zeroes. for t in x.extra_block_data.block_bloom_before { diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 628a0600bb..317326e8b6 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -304,7 +304,10 @@ fn simulate_cpu, const D: usize>( row.context = F::from_canonical_usize(state.registers.context); row.program_counter = F::from_canonical_usize(pc); row.is_kernel_mode = F::ONE; - row.gas = F::from_canonical_u64(state.registers.gas_used); + row.gas = [ + F::from_canonical_u32(state.registers.gas_used as u32), + F::from_canonical_u32((state.registers.gas_used >> 32) as u32), + ]; row.stack_len = F::from_canonical_usize(state.registers.stack_len); loop { diff --git a/evm/src/get_challenges.rs b/evm/src/get_challenges.rs index 1d0aeac9da..715c1097be 100644 --- a/evm/src/get_challenges.rs +++ b/evm/src/get_challenges.rs @@ -62,12 +62,16 @@ fn observe_block_metadata< challenger.observe_element(u256_to_u32(block_metadata.block_number)?); challenger.observe_element(u256_to_u32(block_metadata.block_difficulty)?); challenger.observe_elements(&h256_limbs::(block_metadata.block_random)); - challenger.observe_element(u256_to_u32(block_metadata.block_gaslimit)?); + let gaslimit = u256_to_u64(block_metadata.block_gaslimit)?; + challenger.observe_element(gaslimit.0); + challenger.observe_element(gaslimit.1); challenger.observe_element(u256_to_u32(block_metadata.block_chain_id)?); let basefee = u256_to_u64(block_metadata.block_base_fee)?; challenger.observe_element(basefee.0); challenger.observe_element(basefee.1); - challenger.observe_element(u256_to_u32(block_metadata.block_gas_used)?); + let gas_used = u256_to_u64(block_metadata.block_gas_used)?; + challenger.observe_element(gas_used.0); + challenger.observe_element(gas_used.1); for i in 0..8 { 
challenger.observe_elements(&u256_limbs(block_metadata.block_bloom[i])); } @@ -90,10 +94,10 @@ fn observe_block_metadata_target< challenger.observe_element(block_metadata.block_number); challenger.observe_element(block_metadata.block_difficulty); challenger.observe_elements(&block_metadata.block_random); - challenger.observe_element(block_metadata.block_gaslimit); + challenger.observe_elements(&block_metadata.block_gaslimit); challenger.observe_element(block_metadata.block_chain_id); challenger.observe_elements(&block_metadata.block_base_fee); - challenger.observe_element(block_metadata.block_gas_used); + challenger.observe_elements(&block_metadata.block_gas_used); challenger.observe_elements(&block_metadata.block_bloom); } @@ -108,8 +112,12 @@ fn observe_extra_block_data< challenger.observe_elements(&h256_limbs(extra_data.genesis_state_root)); challenger.observe_element(u256_to_u32(extra_data.txn_number_before)?); challenger.observe_element(u256_to_u32(extra_data.txn_number_after)?); - challenger.observe_element(u256_to_u32(extra_data.gas_used_before)?); - challenger.observe_element(u256_to_u32(extra_data.gas_used_after)?); + let gas_used_before = u256_to_u64(extra_data.gas_used_before)?; + challenger.observe_element(gas_used_before.0); + challenger.observe_element(gas_used_before.1); + let gas_used_after = u256_to_u64(extra_data.gas_used_after)?; + challenger.observe_element(gas_used_after.0); + challenger.observe_element(gas_used_after.1); for i in 0..8 { challenger.observe_elements(&u256_limbs(extra_data.block_bloom_before[i])); } @@ -133,8 +141,8 @@ fn observe_extra_block_data_target< challenger.observe_elements(&extra_data.genesis_state_root); challenger.observe_element(extra_data.txn_number_before); challenger.observe_element(extra_data.txn_number_after); - challenger.observe_element(extra_data.gas_used_before); - challenger.observe_element(extra_data.gas_used_after); + challenger.observe_elements(&extra_data.gas_used_before); + 
challenger.observe_elements(&extra_data.gas_used_after); challenger.observe_elements(&extra_data.block_bloom_before); challenger.observe_elements(&extra_data.block_bloom_after); } diff --git a/evm/src/proof.rs b/evm/src/proof.rs index c6d15dd1e3..fd6c4f3e9a 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -89,26 +89,31 @@ pub struct BlockHashes { pub cur_hash: H256, } +// TODO: Before going into production, `block_gas_used` and `block_gaslimit` here +// as well as `gas_used_before` / `gas_used_after` in `ExtraBlockData` should be +// updated to fit in a single 32-bit limb, as supporting 64-bit values for those +// fields is only necessary for testing purposes. /// Metadata contained in a block header. Those are identical between /// all state transition proofs within the same block. #[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct BlockMetadata { /// The address of this block's producer. pub block_beneficiary: Address, - /// The timestamp of this block. + /// The timestamp of this block. It must fit in a `u32`. pub block_timestamp: U256, - /// The index of this block. + /// The index of this block. It must fit in a `u32`. pub block_number: U256, /// The difficulty (before PoS transition) of this block. pub block_difficulty: U256, + /// The `mix_hash` value of this block. pub block_random: H256, - /// The gas limit of this block. It must fit in a `u32`. + /// The gas limit of this block. It must fit in a `u64`. pub block_gaslimit: U256, - /// The chain id of this block. + /// The chain id of this block. It must fit in a `u32`. pub block_chain_id: U256, - /// The base fee of this block. + /// The base fee of this block. It must fit in a `u64`. pub block_base_fee: U256, - /// The total gas used in this block. It must fit in a `u32`. + /// The total gas used in this block. It must fit in a `u64`. 
pub block_gas_used: U256, /// The block bloom of this block, represented as the consecutive /// 32-byte chunks of a block's final bloom filter string. @@ -191,10 +196,10 @@ impl PublicValuesTarget { buffer.write_target(block_number)?; buffer.write_target(block_difficulty)?; buffer.write_target_array(&block_random)?; - buffer.write_target(block_gaslimit)?; + buffer.write_target_array(&block_gaslimit)?; buffer.write_target(block_chain_id)?; buffer.write_target_array(&block_base_fee)?; - buffer.write_target(block_gas_used)?; + buffer.write_target_array(&block_gas_used)?; buffer.write_target_array(&block_bloom)?; let BlockHashesTarget { @@ -216,8 +221,8 @@ impl PublicValuesTarget { buffer.write_target_array(&genesis_state_root)?; buffer.write_target(txn_number_before)?; buffer.write_target(txn_number_after)?; - buffer.write_target(gas_used_before)?; - buffer.write_target(gas_used_after)?; + buffer.write_target_array(&gas_used_before)?; + buffer.write_target_array(&gas_used_after)?; buffer.write_target_array(&block_bloom_before)?; buffer.write_target_array(&block_bloom_after)?; @@ -243,10 +248,10 @@ impl PublicValuesTarget { block_number: buffer.read_target()?, block_difficulty: buffer.read_target()?, block_random: buffer.read_target_array()?, - block_gaslimit: buffer.read_target()?, + block_gaslimit: buffer.read_target_array()?, block_chain_id: buffer.read_target()?, block_base_fee: buffer.read_target_array()?, - block_gas_used: buffer.read_target()?, + block_gas_used: buffer.read_target_array()?, block_bloom: buffer.read_target_array()?, }; @@ -259,8 +264,8 @@ impl PublicValuesTarget { genesis_state_root: buffer.read_target_array()?, txn_number_before: buffer.read_target()?, txn_number_after: buffer.read_target()?, - gas_used_before: buffer.read_target()?, - gas_used_after: buffer.read_target()?, + gas_used_before: buffer.read_target_array()?, + gas_used_after: buffer.read_target_array()?, block_bloom_before: buffer.read_target_array()?, block_bloom_after: 
buffer.read_target_array()?, }; @@ -417,15 +422,15 @@ pub struct BlockMetadataTarget { pub block_number: Target, pub block_difficulty: Target, pub block_random: [Target; 8], - pub block_gaslimit: Target, + pub block_gaslimit: [Target; 2], pub block_chain_id: Target, pub block_base_fee: [Target; 2], - pub block_gas_used: Target, + pub block_gas_used: [Target; 2], pub block_bloom: [Target; 64], } impl BlockMetadataTarget { - pub const SIZE: usize = 85; + pub const SIZE: usize = 87; pub fn from_public_inputs(pis: &[Target]) -> Self { let block_beneficiary = pis[0..5].try_into().unwrap(); @@ -433,11 +438,11 @@ impl BlockMetadataTarget { let block_number = pis[6]; let block_difficulty = pis[7]; let block_random = pis[8..16].try_into().unwrap(); - let block_gaslimit = pis[16]; - let block_chain_id = pis[17]; - let block_base_fee = pis[18..20].try_into().unwrap(); - let block_gas_used = pis[20]; - let block_bloom = pis[21..85].try_into().unwrap(); + let block_gaslimit = pis[16..18].try_into().unwrap(); + let block_chain_id = pis[18]; + let block_base_fee = pis[19..21].try_into().unwrap(); + let block_gas_used = pis[21..23].try_into().unwrap(); + let block_bloom = pis[23..87].try_into().unwrap(); Self { block_beneficiary, @@ -473,12 +478,16 @@ impl BlockMetadataTarget { block_random: core::array::from_fn(|i| { builder.select(condition, bm0.block_random[i], bm1.block_random[i]) }), - block_gaslimit: builder.select(condition, bm0.block_gaslimit, bm1.block_gaslimit), + block_gaslimit: core::array::from_fn(|i| { + builder.select(condition, bm0.block_gaslimit[i], bm1.block_gaslimit[i]) + }), block_chain_id: builder.select(condition, bm0.block_chain_id, bm1.block_chain_id), block_base_fee: core::array::from_fn(|i| { builder.select(condition, bm0.block_base_fee[i], bm1.block_base_fee[i]) }), - block_gas_used: builder.select(condition, bm0.block_gas_used, bm1.block_gas_used), + block_gas_used: core::array::from_fn(|i| { + builder.select(condition, bm0.block_gas_used[i], 
bm1.block_gas_used[i]) + }), block_bloom: core::array::from_fn(|i| { builder.select(condition, bm0.block_bloom[i], bm1.block_bloom[i]) }), @@ -499,12 +508,16 @@ impl BlockMetadataTarget { for i in 0..8 { builder.connect(bm0.block_random[i], bm1.block_random[i]); } - builder.connect(bm0.block_gaslimit, bm1.block_gaslimit); + for i in 0..2 { + builder.connect(bm0.block_gaslimit[i], bm1.block_gaslimit[i]) + } builder.connect(bm0.block_chain_id, bm1.block_chain_id); for i in 0..2 { builder.connect(bm0.block_base_fee[i], bm1.block_base_fee[i]) } - builder.connect(bm0.block_gas_used, bm1.block_gas_used); + for i in 0..2 { + builder.connect(bm0.block_gas_used[i], bm1.block_gas_used[i]) + } for i in 0..64 { builder.connect(bm0.block_bloom[i], bm1.block_bloom[i]) } @@ -561,23 +574,23 @@ pub struct ExtraBlockDataTarget { pub genesis_state_root: [Target; 8], pub txn_number_before: Target, pub txn_number_after: Target, - pub gas_used_before: Target, - pub gas_used_after: Target, + pub gas_used_before: [Target; 2], + pub gas_used_after: [Target; 2], pub block_bloom_before: [Target; 64], pub block_bloom_after: [Target; 64], } impl ExtraBlockDataTarget { - const SIZE: usize = 140; + const SIZE: usize = 142; pub fn from_public_inputs(pis: &[Target]) -> Self { let genesis_state_root = pis[0..8].try_into().unwrap(); let txn_number_before = pis[8]; let txn_number_after = pis[9]; - let gas_used_before = pis[10]; - let gas_used_after = pis[11]; - let block_bloom_before = pis[12..76].try_into().unwrap(); - let block_bloom_after = pis[76..140].try_into().unwrap(); + let gas_used_before = pis[10..12].try_into().unwrap(); + let gas_used_after = pis[12..14].try_into().unwrap(); + let block_bloom_before = pis[14..78].try_into().unwrap(); + let block_bloom_after = pis[78..142].try_into().unwrap(); Self { genesis_state_root, @@ -610,8 +623,12 @@ impl ExtraBlockDataTarget { ed1.txn_number_before, ), txn_number_after: builder.select(condition, ed0.txn_number_after, ed1.txn_number_after), - 
gas_used_before: builder.select(condition, ed0.gas_used_before, ed1.gas_used_before), - gas_used_after: builder.select(condition, ed0.gas_used_after, ed1.gas_used_after), + gas_used_before: core::array::from_fn(|i| { + builder.select(condition, ed0.gas_used_before[i], ed1.gas_used_before[i]) + }), + gas_used_after: core::array::from_fn(|i| { + builder.select(condition, ed0.gas_used_after[i], ed1.gas_used_after[i]) + }), block_bloom_before: core::array::from_fn(|i| { builder.select( condition, @@ -639,8 +656,12 @@ impl ExtraBlockDataTarget { } builder.connect(ed0.txn_number_before, ed1.txn_number_before); builder.connect(ed0.txn_number_after, ed1.txn_number_after); - builder.connect(ed0.gas_used_before, ed1.gas_used_before); - builder.connect(ed1.gas_used_after, ed1.gas_used_after); + for i in 0..2 { + builder.connect(ed0.gas_used_before[i], ed1.gas_used_before[i]); + } + for i in 0..2 { + builder.connect(ed1.gas_used_after[i], ed1.gas_used_after[i]); + } for i in 0..64 { builder.connect(ed0.block_bloom_before[i], ed1.block_bloom_before[i]); } diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index 601b16687f..04259208b8 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -519,26 +519,10 @@ pub(crate) fn get_memory_extra_looking_products_circuit< GlobalMetadata::BlockDifficulty as usize, public_values.block_metadata.block_difficulty, ), - ( - GlobalMetadata::BlockGasLimit as usize, - public_values.block_metadata.block_gaslimit, - ), ( GlobalMetadata::BlockChainId as usize, public_values.block_metadata.block_chain_id, ), - ( - GlobalMetadata::BlockGasUsed as usize, - public_values.block_metadata.block_gas_used, - ), - ( - GlobalMetadata::BlockGasUsedBefore as usize, - public_values.extra_block_data.gas_used_before, - ), - ( - GlobalMetadata::BlockGasUsedAfter as usize, - public_values.extra_block_data.gas_used_after, - ), ( GlobalMetadata::TxnNumberBefore as usize, 
public_values.extra_block_data.txn_number_before, @@ -549,7 +533,10 @@ pub(crate) fn get_memory_extra_looking_products_circuit< ), ]; - let beneficiary_random_base_fee_cur_hash_fields: [(usize, &[Target]); 4] = [ + // This contains the `block_beneficiary`, `block_random`, `block_base_fee`, + // `block_gaslimit`, `block_gas_used` as well as `cur_hash`, `gas_used_before` + // and `gas_used_after`. + let block_fields_arrays: [(usize, &[Target]); 8] = [ ( GlobalMetadata::BlockBeneficiary as usize, &public_values.block_metadata.block_beneficiary, @@ -562,10 +549,26 @@ pub(crate) fn get_memory_extra_looking_products_circuit< GlobalMetadata::BlockBaseFee as usize, &public_values.block_metadata.block_base_fee, ), + ( + GlobalMetadata::BlockGasLimit as usize, + &public_values.block_metadata.block_gaslimit, + ), + ( + GlobalMetadata::BlockGasUsed as usize, + &public_values.block_metadata.block_gas_used, + ), ( GlobalMetadata::BlockCurrentHash as usize, &public_values.block_hashes.cur_hash, ), + ( + GlobalMetadata::BlockGasUsedBefore as usize, + &public_values.extra_block_data.gas_used_before, + ), + ( + GlobalMetadata::BlockGasUsedAfter as usize, + &public_values.extra_block_data.gas_used_after, + ), ]; let metadata_segment = builder.constant(F::from_canonical_u32(Segment::GlobalMetadata as u32)); @@ -581,7 +584,7 @@ pub(crate) fn get_memory_extra_looking_products_circuit< ); }); - beneficiary_random_base_fee_cur_hash_fields.map(|(field, targets)| { + block_fields_arrays.map(|(field, targets)| { product = add_data_write( builder, challenge, @@ -778,10 +781,10 @@ pub(crate) fn add_virtual_block_metadata, const D: let block_number = builder.add_virtual_public_input(); let block_difficulty = builder.add_virtual_public_input(); let block_random = builder.add_virtual_public_input_arr(); - let block_gaslimit = builder.add_virtual_public_input(); + let block_gaslimit = builder.add_virtual_public_input_arr(); let block_chain_id = builder.add_virtual_public_input(); let 
block_base_fee = builder.add_virtual_public_input_arr(); - let block_gas_used = builder.add_virtual_public_input(); + let block_gas_used = builder.add_virtual_public_input_arr(); let block_bloom = builder.add_virtual_public_input_arr(); BlockMetadataTarget { block_beneficiary, @@ -813,8 +816,8 @@ pub(crate) fn add_virtual_extra_block_data, const D let genesis_state_root = builder.add_virtual_public_input_arr(); let txn_number_before = builder.add_virtual_public_input(); let txn_number_after = builder.add_virtual_public_input(); - let gas_used_before = builder.add_virtual_public_input(); - let gas_used_after = builder.add_virtual_public_input(); + let gas_used_before = builder.add_virtual_public_input_arr(); + let gas_used_after = builder.add_virtual_public_input_arr(); let block_bloom_before: [Target; 64] = builder.add_virtual_public_input_arr(); let block_bloom_after: [Target; 64] = builder.add_virtual_public_input_arr(); ExtraBlockDataTarget { @@ -1027,10 +1030,10 @@ where &block_metadata_target.block_random, &h256_limbs(block_metadata.block_random), ); - witness.set_target( - block_metadata_target.block_gaslimit, - u256_to_u32(block_metadata.block_gaslimit)?, - ); + // Gaslimit fits in 2 limbs + let gaslimit = u256_to_u64(block_metadata.block_gaslimit)?; + witness.set_target(block_metadata_target.block_gaslimit[0], gaslimit.0); + witness.set_target(block_metadata_target.block_gaslimit[1], gaslimit.1); witness.set_target( block_metadata_target.block_chain_id, u256_to_u32(block_metadata.block_chain_id)?, @@ -1039,10 +1042,10 @@ where let basefee = u256_to_u64(block_metadata.block_base_fee)?; witness.set_target(block_metadata_target.block_base_fee[0], basefee.0); witness.set_target(block_metadata_target.block_base_fee[1], basefee.1); - witness.set_target( - block_metadata_target.block_gas_used, - u256_to_u32(block_metadata.block_gas_used)?, - ); + // Gas used fits in 2 limbs + let gas_used = u256_to_u64(block_metadata.block_gas_used)?; + 
witness.set_target(block_metadata_target.block_gas_used[0], gas_used.0); + witness.set_target(block_metadata_target.block_gas_used[1], gas_used.1); let mut block_bloom_limbs = [F::ZERO; 64]; for (i, limbs) in block_bloom_limbs.chunks_exact_mut(8).enumerate() { limbs.copy_from_slice(&u256_limbs(block_metadata.block_bloom[i])); @@ -1092,8 +1095,13 @@ where ed_target.txn_number_after, u256_to_u32(ed.txn_number_after)?, ); - witness.set_target(ed_target.gas_used_before, u256_to_u32(ed.gas_used_before)?); - witness.set_target(ed_target.gas_used_after, u256_to_u32(ed.gas_used_after)?); + // Gas used before/after fit in 2 limbs + let gas_used_before = u256_to_u64(ed.gas_used_before)?; + witness.set_target(ed_target.gas_used_before[0], gas_used_before.0); + witness.set_target(ed_target.gas_used_before[1], gas_used_before.1); + let gas_used_after = u256_to_u64(ed.gas_used_after)?; + witness.set_target(ed_target.gas_used_after[0], gas_used_after.0); + witness.set_target(ed_target.gas_used_after[1], gas_used_after.1); let block_bloom_before = ed.block_bloom_before; let mut block_bloom_limbs = [F::ZERO; 64]; diff --git a/evm/src/witness/operation.rs b/evm/src/witness/operation.rs index 2abeaea4c3..0620069f00 100644 --- a/evm/src/witness/operation.rs +++ b/evm/src/witness/operation.rs @@ -557,7 +557,7 @@ pub(crate) fn generate_syscall( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - if TryInto::::try_into(state.registers.gas_used).is_err() { + if TryInto::::try_into(state.registers.gas_used).is_err() { return Err(ProgramError::GasLimitError); } @@ -650,7 +650,7 @@ pub(crate) fn generate_exit_kernel( assert!(is_kernel_mode_val == 0 || is_kernel_mode_val == 1); let is_kernel_mode = is_kernel_mode_val != 0; let gas_used_val = kexit_info.0[3]; - if TryInto::::try_into(gas_used_val).is_err() { + if TryInto::::try_into(gas_used_val).is_err() { return Err(ProgramError::GasLimitError); } @@ -792,7 +792,7 @@ pub(crate) fn generate_exception( 
state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - if TryInto::::try_into(state.registers.gas_used).is_err() { + if TryInto::::try_into(state.registers.gas_used).is_err() { return Err(ProgramError::GasLimitError); } diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index 5ce2a0ce43..2a710f4b94 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -242,7 +242,10 @@ fn base_row(state: &mut GenerationState) -> (CpuColumnsView, u8) row.context = F::from_canonical_usize(state.registers.context); row.program_counter = F::from_canonical_usize(state.registers.program_counter); row.is_kernel_mode = F::from_bool(state.registers.is_kernel); - row.gas = F::from_canonical_u64(state.registers.gas_used); + row.gas = [ + F::from_canonical_u32(state.registers.gas_used as u32), + F::from_canonical_u32((state.registers.gas_used >> 32) as u32), + ]; row.stack_len = F::from_canonical_usize(state.registers.stack_len); let opcode = read_code_memory(state, &mut row); From cd36e96cb844f3042aa7acf9c03ea8e66735c904 Mon Sep 17 00:00:00 2001 From: Robin Salen <30937548+Nashtare@users.noreply.github.com> Date: Mon, 2 Oct 2023 09:59:45 -0400 Subject: [PATCH 21/34] Derive clone for txn RLP structs (#1264) * Derive Clone for txn rlp structs * Put txn rlp related structs behind testing module * Move module to end of file --- evm/src/generation/mpt.rs | 164 +++++++++++++++++++------------------- evm/tests/log_opcode.rs | 5 +- 2 files changed, 86 insertions(+), 83 deletions(-) diff --git a/evm/src/generation/mpt.rs b/evm/src/generation/mpt.rs index 1f1be2b9a0..f829c4e239 100644 --- a/evm/src/generation/mpt.rs +++ b/evm/src/generation/mpt.rs @@ -33,92 +33,14 @@ impl Default for AccountRlp { } } -#[derive(RlpEncodable, RlpDecodable, Debug)] -pub struct AccessListItemRlp { - pub address: Address, - pub storage_keys: Vec, -} - -#[derive(Debug)] -pub struct AddressOption(pub Option
); - -impl Encodable for AddressOption { - fn rlp_append(&self, s: &mut RlpStream) { - match self.0 { - None => { - s.append_empty_data(); - } - Some(value) => { - s.encoder().encode_value(&value.to_fixed_bytes()); - } - } - } -} - -impl Decodable for AddressOption { - fn decode(rlp: &Rlp) -> Result { - if rlp.is_int() && rlp.is_empty() { - return Ok(AddressOption(None)); - } - if rlp.is_data() && rlp.size() == 20 { - return Ok(AddressOption(Some(Address::decode(rlp)?))); - } - Err(DecoderError::RlpExpectedToBeData) - } -} - -#[derive(RlpEncodable, RlpDecodable, Debug)] -pub struct LegacyTransactionRlp { - pub nonce: U256, - pub gas_price: U256, - pub gas: U256, - pub to: AddressOption, - pub value: U256, - pub data: Bytes, - pub v: U256, - pub r: U256, - pub s: U256, -} - -#[derive(RlpEncodable, RlpDecodable, Debug)] -pub struct AccessListTransactionRlp { - pub chain_id: u64, - pub nonce: U256, - pub gas_price: U256, - pub gas: U256, - pub to: AddressOption, - pub value: U256, - pub data: Bytes, - pub access_list: Vec, - pub y_parity: U256, - pub r: U256, - pub s: U256, -} - -#[derive(RlpEncodable, RlpDecodable, Debug)] -pub struct FeeMarketTransactionRlp { - pub chain_id: u64, - pub nonce: U256, - pub max_priority_fee_per_gas: U256, - pub max_fee_per_gas: U256, - pub gas: U256, - pub to: AddressOption, - pub value: U256, - pub data: Bytes, - pub access_list: Vec, - pub y_parity: U256, - pub r: U256, - pub s: U256, -} - -#[derive(RlpEncodable, RlpDecodable, Debug)] +#[derive(RlpEncodable, RlpDecodable, Debug, Clone)] pub struct LogRlp { pub address: Address, pub topics: Vec, pub data: Bytes, } -#[derive(RlpEncodable, RlpDecodable, Debug)] +#[derive(RlpEncodable, RlpDecodable, Debug, Clone)] pub struct LegacyReceiptRlp { pub status: bool, pub cum_gas_used: U256, @@ -356,3 +278,85 @@ fn empty_nibbles() -> Nibbles { packed: U512::zero(), } } + +pub mod transaction_testing { + use super::*; + + #[derive(RlpEncodable, RlpDecodable, Debug, Clone)] + pub struct 
AccessListItemRlp { + pub address: Address, + pub storage_keys: Vec, + } + + #[derive(Debug, Clone)] + pub struct AddressOption(pub Option
); + + impl Encodable for AddressOption { + fn rlp_append(&self, s: &mut RlpStream) { + match self.0 { + None => { + s.append_empty_data(); + } + Some(value) => { + s.encoder().encode_value(&value.to_fixed_bytes()); + } + } + } + } + + impl Decodable for AddressOption { + fn decode(rlp: &Rlp) -> Result { + if rlp.is_int() && rlp.is_empty() { + return Ok(AddressOption(None)); + } + if rlp.is_data() && rlp.size() == 20 { + return Ok(AddressOption(Some(Address::decode(rlp)?))); + } + Err(DecoderError::RlpExpectedToBeData) + } + } + + #[derive(RlpEncodable, RlpDecodable, Debug, Clone)] + pub struct LegacyTransactionRlp { + pub nonce: U256, + pub gas_price: U256, + pub gas: U256, + pub to: AddressOption, + pub value: U256, + pub data: Bytes, + pub v: U256, + pub r: U256, + pub s: U256, + } + + #[derive(RlpEncodable, RlpDecodable, Debug, Clone)] + pub struct AccessListTransactionRlp { + pub chain_id: u64, + pub nonce: U256, + pub gas_price: U256, + pub gas: U256, + pub to: AddressOption, + pub value: U256, + pub data: Bytes, + pub access_list: Vec, + pub y_parity: U256, + pub r: U256, + pub s: U256, + } + + #[derive(RlpEncodable, RlpDecodable, Debug, Clone)] + pub struct FeeMarketTransactionRlp { + pub chain_id: u64, + pub nonce: U256, + pub max_priority_fee_per_gas: U256, + pub max_fee_per_gas: U256, + pub gas: U256, + pub to: AddressOption, + pub value: U256, + pub data: Bytes, + pub access_list: Vec, + pub y_parity: U256, + pub r: U256, + pub s: U256, + } +} diff --git a/evm/tests/log_opcode.rs b/evm/tests/log_opcode.rs index b9d587b359..3a0e5abf50 100644 --- a/evm/tests/log_opcode.rs +++ b/evm/tests/log_opcode.rs @@ -17,9 +17,8 @@ use plonky2::util::timing::TimingTree; use plonky2_evm::all_stark::AllStark; use plonky2_evm::config::StarkConfig; use plonky2_evm::fixed_recursive_verifier::AllRecursiveCircuits; -use plonky2_evm::generation::mpt::{ - AccountRlp, AddressOption, LegacyReceiptRlp, LegacyTransactionRlp, LogRlp, -}; +use 
plonky2_evm::generation::mpt::transaction_testing::{AddressOption, LegacyTransactionRlp}; +use plonky2_evm::generation::mpt::{AccountRlp, LegacyReceiptRlp, LogRlp}; use plonky2_evm::generation::{GenerationInputs, TrieInputs}; use plonky2_evm::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, TrieRoots}; use plonky2_evm::prover::prove; From 3ac0c4ae18b9b550ce8e70949feeff94ab132f77 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 3 Oct 2023 17:47:10 +0200 Subject: [PATCH 22/34] Fix genesis state trie root when calling `prove_root` (#1271) * Fix genesis state trie root in some tests * Just do it in tests calling prove_block --- evm/src/fixed_recursive_verifier.rs | 20 ++++++++++++++------ evm/src/generation/mod.rs | 2 +- evm/src/get_challenges.rs | 4 ++-- evm/src/proof.rs | 23 +++++++++++++---------- evm/src/recursive_verifier.rs | 8 ++++---- evm/tests/log_opcode.rs | 7 ++++--- evm/tests/many_transactions.rs | 2 +- 7 files changed, 39 insertions(+), 27 deletions(-) diff --git a/evm/src/fixed_recursive_verifier.rs b/evm/src/fixed_recursive_verifier.rs index 766b9102bb..1e76a30e4c 100644 --- a/evm/src/fixed_recursive_verifier.rs +++ b/evm/src/fixed_recursive_verifier.rs @@ -644,10 +644,18 @@ where rhs: &ExtraBlockDataTarget, ) { // Connect genesis state root values. - for (&limb0, &limb1) in pvs.genesis_state_root.iter().zip(&rhs.genesis_state_root) { + for (&limb0, &limb1) in pvs + .genesis_state_trie_root + .iter() + .zip(&rhs.genesis_state_trie_root) + { builder.connect(limb0, limb1); } - for (&limb0, &limb1) in pvs.genesis_state_root.iter().zip(&lhs.genesis_state_root) { + for (&limb0, &limb1) in pvs + .genesis_state_trie_root + .iter() + .zip(&lhs.genesis_state_trie_root) + { builder.connect(limb0, limb1); } @@ -793,9 +801,9 @@ where // Between blocks, the genesis state trie remains unchanged. 
for (&limb0, limb1) in lhs .extra_block_data - .genesis_state_root + .genesis_state_trie_root .iter() - .zip(rhs.extra_block_data.genesis_state_root) + .zip(rhs.extra_block_data.genesis_state_trie_root) { builder.connect(limb0, limb1); } @@ -834,7 +842,7 @@ where .trie_roots_before .state_root .iter() - .zip(x.extra_block_data.genesis_state_root) + .zip(x.extra_block_data.genesis_state_trie_root) { let mut constr = builder.sub(limb0, limb1); constr = builder.mul(has_not_parent_block, constr); @@ -1037,7 +1045,7 @@ where + BlockHashesTarget::BLOCK_HASHES_SIZE + 8; for (key, &value) in genesis_state_trie_keys.zip_eq(&h256_limbs::( - public_values.extra_block_data.genesis_state_root, + public_values.extra_block_data.genesis_state_trie_root, )) { nonzero_pis.insert(key, value); } diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 317326e8b6..62182cd254 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -261,7 +261,7 @@ pub fn generate_traces, const D: usize>( let txn_number_after = read_metadata(GlobalMetadata::TxnNumberAfter); let extra_block_data = ExtraBlockData { - genesis_state_root: inputs.genesis_state_trie_root, + genesis_state_trie_root: inputs.genesis_state_trie_root, txn_number_before: inputs.txn_number_before, txn_number_after, gas_used_before: inputs.gas_used_before, diff --git a/evm/src/get_challenges.rs b/evm/src/get_challenges.rs index 715c1097be..ed8ff91510 100644 --- a/evm/src/get_challenges.rs +++ b/evm/src/get_challenges.rs @@ -109,7 +109,7 @@ fn observe_extra_block_data< challenger: &mut Challenger, extra_data: &ExtraBlockData, ) -> Result<(), ProgramError> { - challenger.observe_elements(&h256_limbs(extra_data.genesis_state_root)); + challenger.observe_elements(&h256_limbs(extra_data.genesis_state_trie_root)); challenger.observe_element(u256_to_u32(extra_data.txn_number_before)?); challenger.observe_element(u256_to_u32(extra_data.txn_number_after)?); let gas_used_before = 
u256_to_u64(extra_data.gas_used_before)?; @@ -138,7 +138,7 @@ fn observe_extra_block_data_target< ) where C::Hasher: AlgebraicHasher, { - challenger.observe_elements(&extra_data.genesis_state_root); + challenger.observe_elements(&extra_data.genesis_state_trie_root); challenger.observe_element(extra_data.txn_number_before); challenger.observe_element(extra_data.txn_number_after); challenger.observe_elements(&extra_data.gas_used_before); diff --git a/evm/src/proof.rs b/evm/src/proof.rs index fd6c4f3e9a..3f744a6134 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -125,7 +125,7 @@ pub struct BlockMetadata { #[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct ExtraBlockData { /// The state trie digest of the genesis block. - pub genesis_state_root: H256, + pub genesis_state_trie_root: H256, /// The transaction count prior execution of the local state transition, starting /// at 0 for the initial transaction of a block. pub txn_number_before: U256, @@ -210,7 +210,7 @@ impl PublicValuesTarget { buffer.write_target_array(&cur_hash)?; let ExtraBlockDataTarget { - genesis_state_root, + genesis_state_trie_root: genesis_state_root, txn_number_before, txn_number_after, gas_used_before, @@ -261,7 +261,7 @@ impl PublicValuesTarget { }; let extra_block_data = ExtraBlockDataTarget { - genesis_state_root: buffer.read_target_array()?, + genesis_state_trie_root: buffer.read_target_array()?, txn_number_before: buffer.read_target()?, txn_number_after: buffer.read_target()?, gas_used_before: buffer.read_target_array()?, @@ -571,7 +571,7 @@ impl BlockHashesTarget { #[derive(Eq, PartialEq, Debug, Copy, Clone)] pub struct ExtraBlockDataTarget { - pub genesis_state_root: [Target; 8], + pub genesis_state_trie_root: [Target; 8], pub txn_number_before: Target, pub txn_number_after: Target, pub gas_used_before: [Target; 2], @@ -584,7 +584,7 @@ impl ExtraBlockDataTarget { const SIZE: usize = 142; pub fn from_public_inputs(pis: &[Target]) -> Self { - let genesis_state_root = 
pis[0..8].try_into().unwrap(); + let genesis_state_trie_root = pis[0..8].try_into().unwrap(); let txn_number_before = pis[8]; let txn_number_after = pis[9]; let gas_used_before = pis[10..12].try_into().unwrap(); @@ -593,7 +593,7 @@ impl ExtraBlockDataTarget { let block_bloom_after = pis[78..142].try_into().unwrap(); Self { - genesis_state_root, + genesis_state_trie_root, txn_number_before, txn_number_after, gas_used_before, @@ -610,11 +610,11 @@ impl ExtraBlockDataTarget { ed1: Self, ) -> Self { Self { - genesis_state_root: core::array::from_fn(|i| { + genesis_state_trie_root: core::array::from_fn(|i| { builder.select( condition, - ed0.genesis_state_root[i], - ed1.genesis_state_root[i], + ed0.genesis_state_trie_root[i], + ed1.genesis_state_trie_root[i], ) }), txn_number_before: builder.select( @@ -652,7 +652,10 @@ impl ExtraBlockDataTarget { ed1: Self, ) { for i in 0..8 { - builder.connect(ed0.genesis_state_root[i], ed1.genesis_state_root[i]); + builder.connect( + ed0.genesis_state_trie_root[i], + ed1.genesis_state_trie_root[i], + ); } builder.connect(ed0.txn_number_before, ed1.txn_number_before); builder.connect(ed0.txn_number_after, ed1.txn_number_after); diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index 04259208b8..8c99ca5d1d 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -813,7 +813,7 @@ pub(crate) fn add_virtual_block_hashes, const D: us pub(crate) fn add_virtual_extra_block_data, const D: usize>( builder: &mut CircuitBuilder, ) -> ExtraBlockDataTarget { - let genesis_state_root = builder.add_virtual_public_input_arr(); + let genesis_state_trie_root = builder.add_virtual_public_input_arr(); let txn_number_before = builder.add_virtual_public_input(); let txn_number_after = builder.add_virtual_public_input(); let gas_used_before = builder.add_virtual_public_input_arr(); @@ -821,7 +821,7 @@ pub(crate) fn add_virtual_extra_block_data, const D let block_bloom_before: [Target; 64] = 
builder.add_virtual_public_input_arr(); let block_bloom_after: [Target; 64] = builder.add_virtual_public_input_arr(); ExtraBlockDataTarget { - genesis_state_root, + genesis_state_trie_root, txn_number_before, txn_number_after, gas_used_before, @@ -1084,8 +1084,8 @@ where W: Witness, { witness.set_target_arr( - &ed_target.genesis_state_root, - &h256_limbs::(ed.genesis_state_root), + &ed_target.genesis_state_trie_root, + &h256_limbs::(ed.genesis_state_trie_root), ); witness.set_target( ed_target.txn_number_before, diff --git a/evm/tests/log_opcode.rs b/evm/tests/log_opcode.rs index 3a0e5abf50..dd7ea223e4 100644 --- a/evm/tests/log_opcode.rs +++ b/evm/tests/log_opcode.rs @@ -339,6 +339,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { to_second_nibbles, rlp::encode(&to_account_second_before).to_vec(), ); + let genesis_state_trie_root = state_trie_before.hash(); let tries_before = TrieInputs { state_trie: state_trie_before, @@ -439,7 +440,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { tries: tries_before, trie_roots_after: tries_after, contract_code, - genesis_state_trie_root: HashedPartialTrie::from(Node::Empty).hash(), + genesis_state_trie_root, block_metadata: block_metadata.clone(), txn_number_before: 0.into(), gas_used_before: 0.into(), @@ -583,7 +584,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { tries: tries_before, trie_roots_after, contract_code, - genesis_state_trie_root: HashedPartialTrie::from(Node::Empty).hash(), + genesis_state_trie_root, block_metadata, txn_number_before: 1.into(), gas_used_before: gas_used_second, @@ -609,7 +610,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { trie_roots_before: first_public_values.trie_roots_before, trie_roots_after: public_values.trie_roots_after, extra_block_data: ExtraBlockData { - genesis_state_root: first_public_values.extra_block_data.genesis_state_root, + genesis_state_trie_root, txn_number_before: first_public_values.extra_block_data.txn_number_before, txn_number_after: 
public_values.extra_block_data.txn_number_after, gas_used_before: first_public_values.extra_block_data.gas_used_before, diff --git a/evm/tests/many_transactions.rs b/evm/tests/many_transactions.rs index 134eb968f7..9678d652d3 100644 --- a/evm/tests/many_transactions.rs +++ b/evm/tests/many_transactions.rs @@ -216,7 +216,7 @@ fn test_four_transactions() -> anyhow::Result<()> { trie_roots_after, genesis_state_trie_root: HashedPartialTrie::from(Node::Empty).hash(), contract_code, - block_metadata: block_metadata.clone(), + block_metadata, addresses: vec![], block_bloom_before: [0.into(); 8], gas_used_before: 0.into(), From 571dc14f4c7b6c9f18ca62b412379c1628dff873 Mon Sep 17 00:00:00 2001 From: Robin Salen Date: Tue, 3 Oct 2023 20:14:23 -0400 Subject: [PATCH 23/34] Fix encoding for empty recipient --- evm/src/generation/mpt.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/evm/src/generation/mpt.rs b/evm/src/generation/mpt.rs index f829c4e239..20e8b30b60 100644 --- a/evm/src/generation/mpt.rs +++ b/evm/src/generation/mpt.rs @@ -282,21 +282,19 @@ fn empty_nibbles() -> Nibbles { pub mod transaction_testing { use super::*; - #[derive(RlpEncodable, RlpDecodable, Debug, Clone)] + #[derive(RlpEncodable, RlpDecodable, Debug, Clone, PartialEq, Eq)] pub struct AccessListItemRlp { pub address: Address, pub storage_keys: Vec, } - #[derive(Debug, Clone)] + #[derive(Debug, Clone, PartialEq, Eq)] pub struct AddressOption(pub Option
); impl Encodable for AddressOption { fn rlp_append(&self, s: &mut RlpStream) { match self.0 { - None => { - s.append_empty_data(); - } + None => s.encoder().encode_value(&[]), Some(value) => { s.encoder().encode_value(&value.to_fixed_bytes()); } @@ -316,7 +314,7 @@ pub mod transaction_testing { } } - #[derive(RlpEncodable, RlpDecodable, Debug, Clone)] + #[derive(RlpEncodable, RlpDecodable, Debug, Clone, PartialEq, Eq)] pub struct LegacyTransactionRlp { pub nonce: U256, pub gas_price: U256, @@ -329,7 +327,7 @@ pub mod transaction_testing { pub s: U256, } - #[derive(RlpEncodable, RlpDecodable, Debug, Clone)] + #[derive(RlpEncodable, RlpDecodable, Debug, Clone, PartialEq, Eq)] pub struct AccessListTransactionRlp { pub chain_id: u64, pub nonce: U256, @@ -344,7 +342,7 @@ pub mod transaction_testing { pub s: U256, } - #[derive(RlpEncodable, RlpDecodable, Debug, Clone)] + #[derive(RlpEncodable, RlpDecodable, Debug, Clone, PartialEq, Eq)] pub struct FeeMarketTransactionRlp { pub chain_id: u64, pub nonce: U256, From 0de6f949622a6410835bb776c40cc3cad04316a9 Mon Sep 17 00:00:00 2001 From: Linda Guiga <101227802+LindaGuiga@users.noreply.github.com> Date: Thu, 5 Oct 2023 09:56:56 -0400 Subject: [PATCH 24/34] Remove extra SHL/SHR CTL. (#1270) * Remove extra shift CTL. * Change order of inputs for the arithmetic shift operations. Add SHR test. Fix max number of bit shifts. Cleanup. 
* Fix SHR in the case shift >= 256 * Limit visibility of helper functions --- evm/src/all_stark.rs | 5 +- evm/src/arithmetic/arithmetic_stark.rs | 3 + evm/src/arithmetic/columns.rs | 2 +- evm/src/arithmetic/divmod.rs | 77 ++++-- evm/src/arithmetic/mod.rs | 33 ++- evm/src/arithmetic/modular.rs | 24 +- evm/src/arithmetic/mul.rs | 87 +++++-- evm/src/arithmetic/shift.rs | 338 +++++++++++++++++++++++++ evm/src/cpu/cpu_stark.rs | 32 +-- evm/src/witness/operation.rs | 8 +- 10 files changed, 503 insertions(+), 106 deletions(-) create mode 100644 evm/src/arithmetic/shift.rs diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index b7168f8571..e5f631e81d 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -104,10 +104,7 @@ pub(crate) fn all_cross_table_lookups() -> Vec> { fn ctl_arithmetic() -> CrossTableLookup { CrossTableLookup::new( - vec![ - cpu_stark::ctl_arithmetic_base_rows(), - cpu_stark::ctl_arithmetic_shift_rows(), - ], + vec![cpu_stark::ctl_arithmetic_base_rows()], arithmetic_stark::ctl_arithmetic_rows(), ) } diff --git a/evm/src/arithmetic/arithmetic_stark.rs b/evm/src/arithmetic/arithmetic_stark.rs index f38aab9ddb..3d281c868c 100644 --- a/evm/src/arithmetic/arithmetic_stark.rs +++ b/evm/src/arithmetic/arithmetic_stark.rs @@ -12,6 +12,7 @@ use plonky2::util::transpose; use static_assertions::const_assert; use super::columns::NUM_ARITH_COLUMNS; +use super::shift; use crate::all_stark::Table; use crate::arithmetic::columns::{RANGE_COUNTER, RC_FREQUENCIES, SHARED_COLS}; use crate::arithmetic::{addcy, byte, columns, divmod, modular, mul, Operation}; @@ -208,6 +209,7 @@ impl, const D: usize> Stark for ArithmeticSta divmod::eval_packed(lv, nv, yield_constr); modular::eval_packed(lv, nv, yield_constr); byte::eval_packed(lv, yield_constr); + shift::eval_packed_generic(lv, nv, yield_constr); } fn eval_ext_circuit( @@ -237,6 +239,7 @@ impl, const D: usize> Stark for ArithmeticSta divmod::eval_ext_circuit(builder, lv, nv, yield_constr); 
modular::eval_ext_circuit(builder, lv, nv, yield_constr); byte::eval_ext_circuit(builder, lv, yield_constr); + shift::eval_ext_circuit(builder, lv, nv, yield_constr); } fn constraint_degree(&self) -> usize { diff --git a/evm/src/arithmetic/columns.rs b/evm/src/arithmetic/columns.rs index 36eb983e0b..df2d12476b 100644 --- a/evm/src/arithmetic/columns.rs +++ b/evm/src/arithmetic/columns.rs @@ -101,7 +101,7 @@ pub(crate) const MODULAR_OUT_AUX_RED: Range = AUX_REGISTER_0; pub(crate) const MODULAR_MOD_IS_ZERO: usize = AUX_REGISTER_1.start; pub(crate) const MODULAR_AUX_INPUT_LO: Range = AUX_REGISTER_1.start + 1..AUX_REGISTER_1.end; pub(crate) const MODULAR_AUX_INPUT_HI: Range = AUX_REGISTER_2; -// Must be set to MOD_IS_ZERO for DIV operation i.e. MOD_IS_ZERO * lv[IS_DIV] +// Must be set to MOD_IS_ZERO for DIV and SHR operations i.e. MOD_IS_ZERO * (lv[IS_DIV] + lv[IS_SHR]). pub(crate) const MODULAR_DIV_DENOM_IS_ZERO: usize = AUX_REGISTER_2.end; /// The counter column (used for the range check) starts from 0 and increments. diff --git a/evm/src/arithmetic/divmod.rs b/evm/src/arithmetic/divmod.rs index 258c131f32..e143ded6dd 100644 --- a/evm/src/arithmetic/divmod.rs +++ b/evm/src/arithmetic/divmod.rs @@ -15,24 +15,19 @@ use crate::arithmetic::modular::{ use crate::arithmetic::utils::*; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -/// Generate the output and auxiliary values for modular operations. -pub(crate) fn generate( +/// Generates the output and auxiliary values for modular operations, +/// assuming the input, modular and output limbs are already set. 
+pub(crate) fn generate_divmod( lv: &mut [F], nv: &mut [F], filter: usize, - input0: U256, - input1: U256, - result: U256, + input_limbs_range: Range, + modulus_range: Range, ) { - debug_assert!(lv.len() == NUM_ARITH_COLUMNS); - - u256_to_array(&mut lv[INPUT_REGISTER_0], input0); - u256_to_array(&mut lv[INPUT_REGISTER_1], input1); - u256_to_array(&mut lv[OUTPUT_REGISTER], result); - - let input_limbs = read_value_i64_limbs::(lv, INPUT_REGISTER_0); + let input_limbs = read_value_i64_limbs::(lv, input_limbs_range); let pol_input = pol_extend(input_limbs); - let (out, quo_input) = generate_modular_op(lv, nv, filter, pol_input, INPUT_REGISTER_1); + let (out, quo_input) = generate_modular_op(lv, nv, filter, pol_input, modulus_range); + debug_assert!( &quo_input[N_LIMBS..].iter().all(|&x| x == F::ZERO), "expected top half of quo_input to be zero" @@ -62,16 +57,35 @@ pub(crate) fn generate( ); lv[AUX_INPUT_REGISTER_0].copy_from_slice(&quo_input[..N_LIMBS]); } - _ => panic!("expected filter to be IS_DIV or IS_MOD but it was {filter}"), + _ => panic!("expected filter to be IS_DIV, IS_SHR or IS_MOD but it was {filter}"), }; } +/// Generate the output and auxiliary values for modular operations. +pub(crate) fn generate( + lv: &mut [F], + nv: &mut [F], + filter: usize, + input0: U256, + input1: U256, + result: U256, +) { + debug_assert!(lv.len() == NUM_ARITH_COLUMNS); + + u256_to_array(&mut lv[INPUT_REGISTER_0], input0); + u256_to_array(&mut lv[INPUT_REGISTER_1], input1); + u256_to_array(&mut lv[OUTPUT_REGISTER], result); + + generate_divmod(lv, nv, filter, INPUT_REGISTER_0, INPUT_REGISTER_1); +} /// Verify that num = quo * den + rem and 0 <= rem < den. -fn eval_packed_divmod_helper( +pub(crate) fn eval_packed_divmod_helper( lv: &[P; NUM_ARITH_COLUMNS], nv: &[P; NUM_ARITH_COLUMNS], yield_constr: &mut ConstraintConsumer

, filter: P, + num_range: Range, + den_range: Range, quo_range: Range, rem_range: Range, ) { @@ -80,8 +94,8 @@ fn eval_packed_divmod_helper( yield_constr.constraint_last_row(filter); - let num = &lv[INPUT_REGISTER_0]; - let den = read_value(lv, INPUT_REGISTER_1); + let num = &lv[num_range]; + let den = read_value(lv, den_range); let quo = { let mut quo = [P::ZEROS; 2 * N_LIMBS]; quo[..N_LIMBS].copy_from_slice(&lv[quo_range]); @@ -104,14 +118,13 @@ pub(crate) fn eval_packed( nv: &[P; NUM_ARITH_COLUMNS], yield_constr: &mut ConstraintConsumer

, ) { - // Constrain IS_SHR independently, so that it doesn't impact the - // constraints when combining the flag with IS_DIV. - yield_constr.constraint_last_row(lv[IS_SHR]); eval_packed_divmod_helper( lv, nv, yield_constr, - lv[IS_DIV] + lv[IS_SHR], + lv[IS_DIV], + INPUT_REGISTER_0, + INPUT_REGISTER_1, OUTPUT_REGISTER, AUX_INPUT_REGISTER_0, ); @@ -120,24 +133,28 @@ pub(crate) fn eval_packed( nv, yield_constr, lv[IS_MOD], + INPUT_REGISTER_0, + INPUT_REGISTER_1, AUX_INPUT_REGISTER_0, OUTPUT_REGISTER, ); } -fn eval_ext_circuit_divmod_helper, const D: usize>( +pub(crate) fn eval_ext_circuit_divmod_helper, const D: usize>( builder: &mut CircuitBuilder, lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], nv: &[ExtensionTarget; NUM_ARITH_COLUMNS], yield_constr: &mut RecursiveConstraintConsumer, filter: ExtensionTarget, + num_range: Range, + den_range: Range, quo_range: Range, rem_range: Range, ) { yield_constr.constraint_last_row(builder, filter); - let num = &lv[INPUT_REGISTER_0]; - let den = read_value(lv, INPUT_REGISTER_1); + let num = &lv[num_range]; + let den = read_value(lv, den_range); let quo = { let zero = builder.zero_extension(); let mut quo = [zero; 2 * N_LIMBS]; @@ -164,14 +181,14 @@ pub(crate) fn eval_ext_circuit, const D: usize>( nv: &[ExtensionTarget; NUM_ARITH_COLUMNS], yield_constr: &mut RecursiveConstraintConsumer, ) { - yield_constr.constraint_last_row(builder, lv[IS_SHR]); - let div_shr_flag = builder.add_extension(lv[IS_DIV], lv[IS_SHR]); eval_ext_circuit_divmod_helper( builder, lv, nv, yield_constr, - div_shr_flag, + lv[IS_DIV], + INPUT_REGISTER_0, + INPUT_REGISTER_1, OUTPUT_REGISTER, AUX_INPUT_REGISTER_0, ); @@ -181,6 +198,8 @@ pub(crate) fn eval_ext_circuit, const D: usize>( nv, yield_constr, lv[IS_MOD], + INPUT_REGISTER_0, + INPUT_REGISTER_1, AUX_INPUT_REGISTER_0, OUTPUT_REGISTER, ); @@ -214,7 +233,7 @@ mod tests { for op in MODULAR_OPS { lv[op] = F::ZERO; } - // Deactivate the SHR flag so that a DIV operation is not triggered. 
+ // Since SHR uses the logic for DIV, `IS_SHR` should also be set to 0 here. lv[IS_SHR] = F::ZERO; let mut constraint_consumer = ConstraintConsumer::new( @@ -247,6 +266,7 @@ mod tests { for op in MODULAR_OPS { lv[op] = F::ZERO; } + // Since SHR uses the logic for DIV, `IS_SHR` should also be set to 0 here. lv[IS_SHR] = F::ZERO; lv[op_filter] = F::ONE; @@ -308,6 +328,7 @@ mod tests { for op in MODULAR_OPS { lv[op] = F::ZERO; } + // Since SHR uses the logic for DIV, `IS_SHR` should also be set to 0 here. lv[IS_SHR] = F::ZERO; lv[op_filter] = F::ONE; diff --git a/evm/src/arithmetic/mod.rs b/evm/src/arithmetic/mod.rs index bd6d56e8cb..7763e98a06 100644 --- a/evm/src/arithmetic/mod.rs +++ b/evm/src/arithmetic/mod.rs @@ -9,6 +9,7 @@ mod byte; mod divmod; mod modular; mod mul; +mod shift; mod utils; pub mod arithmetic_stark; @@ -35,15 +36,29 @@ impl BinaryOperator { pub(crate) fn result(&self, input0: U256, input1: U256) -> U256 { match self { BinaryOperator::Add => input0.overflowing_add(input1).0, - BinaryOperator::Mul | BinaryOperator::Shl => input0.overflowing_mul(input1).0, + BinaryOperator::Mul => input0.overflowing_mul(input1).0, + BinaryOperator::Shl => { + if input0 < U256::from(256usize) { + input1 << input0 + } else { + U256::zero() + } + } BinaryOperator::Sub => input0.overflowing_sub(input1).0, - BinaryOperator::Div | BinaryOperator::Shr => { + BinaryOperator::Div => { if input1.is_zero() { U256::zero() } else { input0 / input1 } } + BinaryOperator::Shr => { + if input0 < U256::from(256usize) { + input1 >> input0 + } else { + U256::zero() + } + } BinaryOperator::Mod => { if input1.is_zero() { U256::zero() @@ -238,15 +253,25 @@ fn binary_op_to_rows( addcy::generate(&mut row, op.row_filter(), input0, input1); (row, None) } - BinaryOperator::Mul | BinaryOperator::Shl => { + BinaryOperator::Mul => { mul::generate(&mut row, input0, input1); (row, None) } - BinaryOperator::Div | BinaryOperator::Mod | BinaryOperator::Shr => { + BinaryOperator::Shl => { + let mut nv 
= vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; + shift::generate(&mut row, &mut nv, true, input0, input1, result); + (row, None) + } + BinaryOperator::Div | BinaryOperator::Mod => { let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; divmod::generate(&mut row, &mut nv, op.row_filter(), input0, input1, result); (row, Some(nv)) } + BinaryOperator::Shr => { + let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; + shift::generate(&mut row, &mut nv, false, input0, input1, result); + (row, Some(nv)) + } BinaryOperator::AddFp254 | BinaryOperator::MulFp254 | BinaryOperator::SubFp254 => { ternary_op_to_rows::(op.row_filter(), input0, input1, BN_BASE, result) } diff --git a/evm/src/arithmetic/modular.rs b/evm/src/arithmetic/modular.rs index 4e540cb6b4..4e6e21a632 100644 --- a/evm/src/arithmetic/modular.rs +++ b/evm/src/arithmetic/modular.rs @@ -239,7 +239,7 @@ pub(crate) fn generate_modular_op( let mut mod_is_zero = F::ZERO; if modulus.is_zero() { - if filter == columns::IS_DIV { + if filter == columns::IS_DIV || filter == columns::IS_SHR { // set modulus = 2^256; the condition above means we know // it's zero at this point, so we can just set bit 256. modulus.set_bit(256, true); @@ -330,7 +330,7 @@ pub(crate) fn generate_modular_op( nv[MODULAR_MOD_IS_ZERO] = mod_is_zero; nv[MODULAR_OUT_AUX_RED].copy_from_slice(&out_aux_red.map(F::from_canonical_i64)); - nv[MODULAR_DIV_DENOM_IS_ZERO] = mod_is_zero * lv[IS_DIV]; + nv[MODULAR_DIV_DENOM_IS_ZERO] = mod_is_zero * (lv[IS_DIV] + lv[IS_SHR]); ( output_limbs.map(F::from_canonical_i64), @@ -392,14 +392,14 @@ pub(crate) fn check_reduced( // Verify that the output is reduced, i.e. output < modulus. 
let out_aux_red = &nv[MODULAR_OUT_AUX_RED]; // This sets is_less_than to 1 unless we get mod_is_zero when - // doing a DIV; in that case, we need is_less_than=0, since + // doing a DIV or SHR; in that case, we need is_less_than=0, since // eval_packed_generic_addcy checks // // modulus + out_aux_red == output + is_less_than*2^256 // // and we are given output = out_aux_red when modulus is zero. let mut is_less_than = [P::ZEROS; N_LIMBS]; - is_less_than[0] = P::ONES - mod_is_zero * lv[IS_DIV]; + is_less_than[0] = P::ONES - mod_is_zero * (lv[IS_DIV] + lv[IS_SHR]); // NB: output and modulus in lv while out_aux_red and // is_less_than (via mod_is_zero) depend on nv, hence the // 'is_two_row_op' argument is set to 'true'. @@ -448,13 +448,15 @@ pub(crate) fn modular_constr_poly( // modulus = 0. modulus[0] += mod_is_zero; - // Is 1 iff the operation is DIV and the denominator is zero. + // Is 1 iff the operation is DIV or SHR and the denominator is zero. let div_denom_is_zero = nv[MODULAR_DIV_DENOM_IS_ZERO]; - yield_constr.constraint_transition(filter * (mod_is_zero * lv[IS_DIV] - div_denom_is_zero)); + yield_constr.constraint_transition( + filter * (mod_is_zero * (lv[IS_DIV] + lv[IS_SHR]) - div_denom_is_zero), + ); // Needed to compensate for adding mod_is_zero to modulus above, // since the call eval_packed_generic_addcy() below subtracts modulus - // to verify in the case of a DIV. + // to verify in the case of a DIV or SHR. 
output[0] += div_denom_is_zero; check_reduced(lv, nv, yield_constr, filter, output, modulus, mod_is_zero); @@ -635,7 +637,8 @@ pub(crate) fn modular_constr_poly_ext_circuit, cons modulus[0] = builder.add_extension(modulus[0], mod_is_zero); let div_denom_is_zero = nv[MODULAR_DIV_DENOM_IS_ZERO]; - let t = builder.mul_sub_extension(mod_is_zero, lv[IS_DIV], div_denom_is_zero); + let div_shr_filter = builder.add_extension(lv[IS_DIV], lv[IS_SHR]); + let t = builder.mul_sub_extension(mod_is_zero, div_shr_filter, div_denom_is_zero); let t = builder.mul_extension(filter, t); yield_constr.constraint_transition(builder, t); output[0] = builder.add_extension(output[0], div_denom_is_zero); @@ -645,7 +648,7 @@ pub(crate) fn modular_constr_poly_ext_circuit, cons let zero = builder.zero_extension(); let mut is_less_than = [zero; N_LIMBS]; is_less_than[0] = - builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, lv[IS_DIV], one); + builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, div_shr_filter, one); eval_ext_circuit_addcy( builder, @@ -834,6 +837,7 @@ mod tests { for op in MODULAR_OPS { lv[op] = F::ZERO; } + lv[IS_SHR] = F::ZERO; lv[IS_DIV] = F::ZERO; lv[IS_MOD] = F::ZERO; @@ -867,6 +871,7 @@ mod tests { for op in MODULAR_OPS { lv[op] = F::ZERO; } + lv[IS_SHR] = F::ZERO; lv[IS_DIV] = F::ZERO; lv[IS_MOD] = F::ZERO; lv[op_filter] = F::ONE; @@ -926,6 +931,7 @@ mod tests { for op in MODULAR_OPS { lv[op] = F::ZERO; } + lv[IS_SHR] = F::ZERO; lv[IS_DIV] = F::ZERO; lv[IS_MOD] = F::ZERO; lv[op_filter] = F::ONE; diff --git a/evm/src/arithmetic/mul.rs b/evm/src/arithmetic/mul.rs index efb4d82247..c09c39d8dc 100644 --- a/evm/src/arithmetic/mul.rs +++ b/evm/src/arithmetic/mul.rs @@ -67,16 +67,8 @@ use crate::arithmetic::columns::*; use crate::arithmetic::utils::*; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -pub fn generate(lv: &mut [F], left_in: U256, right_in: U256) { - // TODO: It would probably be clearer/cleaner to read the U256 
- // into an [i64;N] and then copy that to the lv table. - u256_to_array(&mut lv[INPUT_REGISTER_0], left_in); - u256_to_array(&mut lv[INPUT_REGISTER_1], right_in); - u256_to_array(&mut lv[INPUT_REGISTER_2], U256::zero()); - - let input0 = read_value_i64_limbs(lv, INPUT_REGISTER_0); - let input1 = read_value_i64_limbs(lv, INPUT_REGISTER_1); - +/// Given the two limbs of `left_in` and `right_in`, computes `left_in * right_in`. +pub(crate) fn generate_mul(lv: &mut [F], left_in: [i64; 16], right_in: [i64; 16]) { const MASK: i64 = (1i64 << LIMB_BITS) - 1i64; // Input and output have 16-bit limbs @@ -86,7 +78,7 @@ pub fn generate(lv: &mut [F], left_in: U256, right_in: U256) { // First calculate the coefficients of a(x)*b(x) (in unreduced_prod), // then do carry propagation to obtain C = c(β) = a(β)*b(β). let mut cy = 0i64; - let mut unreduced_prod = pol_mul_lo(input0, input1); + let mut unreduced_prod = pol_mul_lo(left_in, right_in); for col in 0..N_LIMBS { let t = unreduced_prod[col] + cy; cy = t >> LIMB_BITS; @@ -115,17 +107,30 @@ pub fn generate(lv: &mut [F], left_in: U256, right_in: U256) { .copy_from_slice(&aux_limbs.map(|c| F::from_canonical_u16((c >> 16) as u16))); } -pub fn eval_packed_generic( +pub fn generate(lv: &mut [F], left_in: U256, right_in: U256) { + // TODO: It would probably be clearer/cleaner to read the U256 + // into an [i64;N] and then copy that to the lv table. + u256_to_array(&mut lv[INPUT_REGISTER_0], left_in); + u256_to_array(&mut lv[INPUT_REGISTER_1], right_in); + u256_to_array(&mut lv[INPUT_REGISTER_2], U256::zero()); + + let input0 = read_value_i64_limbs(lv, INPUT_REGISTER_0); + let input1 = read_value_i64_limbs(lv, INPUT_REGISTER_1); + + generate_mul(lv, input0, input1); +} + +pub(crate) fn eval_packed_generic_mul( lv: &[P; NUM_ARITH_COLUMNS], + filter: P, + left_in_limbs: [P; 16], + right_in_limbs: [P; 16], yield_constr: &mut ConstraintConsumer

, ) { - let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS); - - let is_mul = lv[IS_MUL] + lv[IS_SHL]; - let input0_limbs = read_value::(lv, INPUT_REGISTER_0); - let input1_limbs = read_value::(lv, INPUT_REGISTER_1); let output_limbs = read_value::(lv, OUTPUT_REGISTER); + let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS); + let aux_limbs = { // MUL_AUX_INPUT was offset by 2^20 in generation, so we undo // that here @@ -153,7 +158,7 @@ pub fn eval_packed_generic( // // s(x) = \sum_i aux_limbs[i] * x^i // - let mut constr_poly = pol_mul_lo(input0_limbs, input1_limbs); + let mut constr_poly = pol_mul_lo(left_in_limbs, right_in_limbs); pol_sub_assign(&mut constr_poly, &output_limbs); // This subtracts (x - β) * s(x) from constr_poly. @@ -164,18 +169,29 @@ pub fn eval_packed_generic( // multiplication is valid if and only if all of those // coefficients are zero. for &c in &constr_poly { - yield_constr.constraint(is_mul * c); + yield_constr.constraint(filter * c); } } -pub fn eval_ext_circuit, const D: usize>( +pub fn eval_packed_generic( + lv: &[P; NUM_ARITH_COLUMNS], + yield_constr: &mut ConstraintConsumer

, +) { + let is_mul = lv[IS_MUL]; + let input0_limbs = read_value::(lv, INPUT_REGISTER_0); + let input1_limbs = read_value::(lv, INPUT_REGISTER_1); + + eval_packed_generic_mul(lv, is_mul, input0_limbs, input1_limbs, yield_constr); +} + +pub(crate) fn eval_ext_mul_circuit, const D: usize>( builder: &mut CircuitBuilder, lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + filter: ExtensionTarget, + left_in_limbs: [ExtensionTarget; 16], + right_in_limbs: [ExtensionTarget; 16], yield_constr: &mut RecursiveConstraintConsumer, ) { - let is_mul = builder.add_extension(lv[IS_MUL], lv[IS_SHL]); - let input0_limbs = read_value::(lv, INPUT_REGISTER_0); - let input1_limbs = read_value::(lv, INPUT_REGISTER_1); let output_limbs = read_value::(lv, OUTPUT_REGISTER); let aux_limbs = { @@ -192,7 +208,7 @@ pub fn eval_ext_circuit, const D: usize>( aux_limbs }; - let mut constr_poly = pol_mul_lo_ext_circuit(builder, input0_limbs, input1_limbs); + let mut constr_poly = pol_mul_lo_ext_circuit(builder, left_in_limbs, right_in_limbs); pol_sub_assign_ext_circuit(builder, &mut constr_poly, &output_limbs); let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << LIMB_BITS)); @@ -200,11 +216,30 @@ pub fn eval_ext_circuit, const D: usize>( pol_sub_assign_ext_circuit(builder, &mut constr_poly, &rhs); for &c in &constr_poly { - let filter = builder.mul_extension(is_mul, c); + let filter = builder.mul_extension(filter, c); yield_constr.constraint(builder, filter); } } +pub fn eval_ext_circuit, const D: usize>( + builder: &mut CircuitBuilder, + lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + let is_mul = lv[IS_MUL]; + let input0_limbs = read_value::(lv, INPUT_REGISTER_0); + let input1_limbs = read_value::(lv, INPUT_REGISTER_1); + + eval_ext_mul_circuit( + builder, + lv, + is_mul, + input0_limbs, + input1_limbs, + yield_constr, + ); +} + #[cfg(test)] mod tests { use plonky2::field::goldilocks_field::GoldilocksField; @@ -229,8 +264,6 @@ 
mod tests { // if `IS_MUL == 0`, then the constraints should be met even // if all values are garbage. lv[IS_MUL] = F::ZERO; - // Deactivate the SHL flag so that a MUL operation is not triggered. - lv[IS_SHL] = F::ZERO; let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], diff --git a/evm/src/arithmetic/shift.rs b/evm/src/arithmetic/shift.rs new file mode 100644 index 0000000000..6600c01e54 --- /dev/null +++ b/evm/src/arithmetic/shift.rs @@ -0,0 +1,338 @@ +//! Support for the EVM SHL and SHR instructions. +//! +//! This crate verifies an EVM shift instruction, which takes two +//! 256-bit inputs S and A, and produces a 256-bit output C satisfying +//! +//! C = A << S (mod 2^256) for SHL or +//! C = A >> S (mod 2^256) for SHR. +//! +//! The way this computation is carried out is by providing a third input +//! B = 1 << S (mod 2^256) +//! and then computing: +//! C = A * B (mod 2^256) for SHL or +//! C = A / B (mod 2^256) for SHR +//! +//! Inputs A, S, and B, and output C, are given as arrays of 16-bit +//! limbs. For example, if the limbs of A are a[0]...a[15], then +//! +//! A = \sum_{i=0}^15 a[i] β^i, +//! +//! where β = 2^16 = 2^LIMB_BITS. To verify that A, S, B and C satisfy +//! the equations, we proceed similarly to MUL for SHL and to DIV for SHR. + +use ethereum_types::U256; +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::field::types::PrimeField64; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; + +use super::{divmod, mul}; +use crate::arithmetic::columns::*; +use crate::arithmetic::utils::*; +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; + +/// Generates a shift operation (either SHL or SHR). +/// The inputs are stored in the form `(shift, input, 1 << shift)`.
+/// NB: if `shift >= 256`, then the third register holds 0. +/// We leverage the functions in mul.rs and divmod.rs to carry out +/// the computation. +pub fn generate( + lv: &mut [F], + nv: &mut [F], + is_shl: bool, + shift: U256, + input: U256, + result: U256, +) { + // We use the multiplication logic to generate SHL + // TODO: It would probably be clearer/cleaner to read the U256 + // into an [i64;N] and then copy that to the lv table. + // The first input is the shift we need to apply. + u256_to_array(&mut lv[INPUT_REGISTER_0], shift); + // The second register holds the input which needs shifting. + u256_to_array(&mut lv[INPUT_REGISTER_1], input); + u256_to_array(&mut lv[OUTPUT_REGISTER], result); + // If `shift >= 256`, the shifted displacement is set to 0. + // Compute 1 << shift and store it in the third input register. + let shifted_displacement = if shift > U256::from(255u64) { + U256::zero() + } else { + U256::one() << shift + }; + + u256_to_array(&mut lv[INPUT_REGISTER_2], shifted_displacement); + + let input0 = read_value_i64_limbs(lv, INPUT_REGISTER_1); // input + let input1 = read_value_i64_limbs(lv, INPUT_REGISTER_2); // 1 << shift + + if is_shl { + // We generate the multiplication input0 * input1 using mul.rs. + mul::generate_mul(lv, input0, input1); + } else { + // If the operation is SHR, we compute: `input / shifted_displacement` if `shifted_displacement != 0` + // otherwise, the output is 0. We use the logic in divmod.rs to achieve that. + divmod::generate_divmod(lv, nv, IS_SHR, INPUT_REGISTER_1, INPUT_REGISTER_2); + } +} + +/// Evaluates the constraints for an SHL opcode. +/// The logic is the same as the one for MUL. The only difference is that +/// the inputs are in `INPUT_REGISTER_1` and `INPUT_REGISTER_2` instead of +/// `INPUT_REGISTER_0` and `INPUT_REGISTER_1`. +fn eval_packed_shl( + lv: &[P; NUM_ARITH_COLUMNS], + yield_constr: &mut ConstraintConsumer

, +) { + let is_shl = lv[IS_SHL]; + let input0_limbs = read_value::(lv, INPUT_REGISTER_1); + let shifted_limbs = read_value::(lv, INPUT_REGISTER_2); + + mul::eval_packed_generic_mul(lv, is_shl, input0_limbs, shifted_limbs, yield_constr); +} + +/// Evaluates the constraints for an SHR opcode. +/// The logic is the same as the one for DIV. The only difference is that +/// the inputs are in `INPUT_REGISTER_1` and `INPUT_REGISTER_2` instead of +/// `INPUT_REGISTER_0` and `INPUT_REGISTER_1`. +fn eval_packed_shr( + lv: &[P; NUM_ARITH_COLUMNS], + nv: &[P; NUM_ARITH_COLUMNS], + yield_constr: &mut ConstraintConsumer

, +) { + let quo_range = OUTPUT_REGISTER; + let rem_range = AUX_INPUT_REGISTER_0; + let filter = lv[IS_SHR]; + + divmod::eval_packed_divmod_helper( + lv, + nv, + yield_constr, + filter, + INPUT_REGISTER_1, + INPUT_REGISTER_2, + quo_range, + rem_range, + ); +} + +pub fn eval_packed_generic( + lv: &[P; NUM_ARITH_COLUMNS], + nv: &[P; NUM_ARITH_COLUMNS], + yield_constr: &mut ConstraintConsumer

, +) { + eval_packed_shl(lv, yield_constr); + eval_packed_shr(lv, nv, yield_constr); +} + +fn eval_ext_circuit_shl, const D: usize>( + builder: &mut CircuitBuilder, + lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + let is_shl = lv[IS_SHL]; + let input0_limbs = read_value::(lv, INPUT_REGISTER_1); + let shifted_limbs = read_value::(lv, INPUT_REGISTER_2); + + mul::eval_ext_mul_circuit( + builder, + lv, + is_shl, + input0_limbs, + shifted_limbs, + yield_constr, + ); +} + +fn eval_ext_circuit_shr, const D: usize>( + builder: &mut CircuitBuilder, + lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + nv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + let filter = lv[IS_SHR]; + let quo_range = OUTPUT_REGISTER; + let rem_range = AUX_INPUT_REGISTER_0; + + divmod::eval_ext_circuit_divmod_helper( + builder, + lv, + nv, + yield_constr, + filter, + INPUT_REGISTER_1, + INPUT_REGISTER_2, + quo_range, + rem_range, + ); +} + +pub fn eval_ext_circuit, const D: usize>( + builder: &mut CircuitBuilder, + lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + nv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + eval_ext_circuit_shl(builder, lv, yield_constr); + eval_ext_circuit_shr(builder, lv, nv, yield_constr); +} + +#[cfg(test)] +mod tests { + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::types::{Field, Sample}; + use rand::{Rng, SeedableRng}; + use rand_chacha::ChaCha8Rng; + + use super::*; + use crate::arithmetic::columns::NUM_ARITH_COLUMNS; + use crate::constraint_consumer::ConstraintConsumer; + + const N_RND_TESTS: usize = 1000; + + // TODO: Should be able to refactor this test to apply to all operations. 
+ #[test] + fn generate_eval_consistency_not_shift() { + type F = GoldilocksField; + + let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); + let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + let nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + + // if `IS_SHL == 0` and `IS_SHR == 0`, then the constraints should be met even + // if all values are garbage. + lv[IS_SHL] = F::ZERO; + lv[IS_SHR] = F::ZERO; + + let mut constraint_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + GoldilocksField::ONE, + GoldilocksField::ONE, + GoldilocksField::ONE, + ); + eval_packed_generic(&lv, &nv, &mut constraint_consumer); + for &acc in &constraint_consumer.constraint_accs { + assert_eq!(acc, GoldilocksField::ZERO); + } + } + + fn generate_eval_consistency_shift(is_shl: bool) { + type F = GoldilocksField; + + let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); + let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + let mut nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + + // set `IS_SHL == 1` or `IS_SHR == 1` and ensure all constraints are satisfied. + if is_shl { + lv[IS_SHL] = F::ONE; + lv[IS_SHR] = F::ZERO; + } else { + // Set `IS_DIV` to 0 in this case, since we're using the logic of DIV for SHR. 
+ lv[IS_DIV] = F::ZERO; + lv[IS_SHL] = F::ZERO; + lv[IS_SHR] = F::ONE; + } + + for _i in 0..N_RND_TESTS { + let shift = U256::from(rng.gen::()); + + let mut full_input = U256::from(0); + // set inputs to random values + for ai in INPUT_REGISTER_1 { + lv[ai] = F::from_canonical_u16(rng.gen()); + full_input = + U256::from(lv[ai].to_canonical_u64()) + full_input * U256::from(1 << 16); + } + + let output = if is_shl { + full_input << shift + } else { + full_input >> shift + }; + + generate(&mut lv, &mut nv, is_shl, shift, full_input, output); + + let mut constraint_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + GoldilocksField::ONE, + GoldilocksField::ONE, + GoldilocksField::ZERO, + ); + eval_packed_generic(&lv, &nv, &mut constraint_consumer); + for &acc in &constraint_consumer.constraint_accs { + assert_eq!(acc, GoldilocksField::ZERO); + } + } + } + + #[test] + fn generate_eval_consistency_shl() { + generate_eval_consistency_shift(true); + } + + #[test] + fn generate_eval_consistency_shr() { + generate_eval_consistency_shift(false); + } + + fn generate_eval_consistency_shift_over_256(is_shl: bool) { + type F = GoldilocksField; + + let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); + let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + let mut nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + + // set `IS_SHL == 1` or `IS_SHR == 1` and ensure all constraints are satisfied. + if is_shl { + lv[IS_SHL] = F::ONE; + lv[IS_SHR] = F::ZERO; + } else { + // Set `IS_DIV` to 0 in this case, since we're using the logic of DIV for SHR. 
+ lv[IS_DIV] = F::ZERO; + lv[IS_SHL] = F::ZERO; + lv[IS_SHR] = F::ONE; + } + + for _i in 0..N_RND_TESTS { + let mut shift = U256::from(rng.gen::()); + while shift > U256::MAX - 256 { + shift = U256::from(rng.gen::()); + } + shift += U256::from(256); + + let mut full_input = U256::from(0); + // set inputs to random values + for ai in INPUT_REGISTER_1 { + lv[ai] = F::from_canonical_u16(rng.gen()); + full_input = + U256::from(lv[ai].to_canonical_u64()) + full_input * U256::from(1 << 16); + } + + let output = 0.into(); + generate(&mut lv, &mut nv, is_shl, shift, full_input, output); + + let mut constraint_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + GoldilocksField::ONE, + GoldilocksField::ONE, + GoldilocksField::ZERO, + ); + eval_packed_generic(&lv, &nv, &mut constraint_consumer); + for &acc in &constraint_consumer.constraint_accs { + assert_eq!(acc, GoldilocksField::ZERO); + } + } + } + + #[test] + fn generate_eval_consistency_shl_over_256() { + generate_eval_consistency_shift_over_256(true); + } + + #[test] + fn generate_eval_consistency_shr_over_256() { + generate_eval_consistency_shift_over_256(false); + } +} diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index f23ff308b6..82ca5452b7 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -63,19 +63,10 @@ fn ctl_data_binops() -> Vec> { /// one output of a ternary operation. By default, ternary operations use /// the first three memory channels, and the last one for the result (binary /// operations do not use the third inputs). -/// -/// Shift operations are different, as they are simulated with `MUL` or `DIV` -/// on the arithmetic side. We first convert the shift into the multiplicand -/// (in case of `SHL`) or the divisor (in case of `SHR`), making the first memory -/// channel not directly usable. 
We overcome this by adding an offset of 1 in -/// case of shift operations, which will skip the first memory channel and use the -/// next three as ternary inputs. Because both `MUL` and `DIV` are binary operations, -/// the last memory channel used for the inputs will be safely ignored. -fn ctl_data_ternops(is_shift: bool) -> Vec> { - let offset = is_shift as usize; - let mut res = Column::singles(COL_MAP.mem_channels[offset].value).collect_vec(); - res.extend(Column::singles(COL_MAP.mem_channels[offset + 1].value)); - res.extend(Column::singles(COL_MAP.mem_channels[offset + 2].value)); +fn ctl_data_ternops() -> Vec> { + let mut res = Column::singles(COL_MAP.mem_channels[0].value).collect_vec(); + res.extend(Column::singles(COL_MAP.mem_channels[1].value)); + res.extend(Column::singles(COL_MAP.mem_channels[2].value)); res.extend(Column::singles( COL_MAP.mem_channels[NUM_GP_CHANNELS - 1].value, )); @@ -96,7 +87,7 @@ pub fn ctl_filter_logic() -> Column { pub fn ctl_arithmetic_base_rows() -> TableWithColumns { // Instead of taking single columns, we reconstruct the entire opcode value directly. let mut columns = vec![Column::le_bits(COL_MAP.opcode_bits)]; - columns.extend(ctl_data_ternops(false)); + columns.extend(ctl_data_ternops()); // Create the CPU Table whose columns are those with the three // inputs and one output of the ternary operations listed in `ops` // (also `ops` is used as the operation filter). The list of @@ -109,22 +100,11 @@ pub fn ctl_arithmetic_base_rows() -> TableWithColumns { COL_MAP.op.binary_op, COL_MAP.op.fp254_op, COL_MAP.op.ternary_op, + COL_MAP.op.shift, ])), ) } -pub fn ctl_arithmetic_shift_rows() -> TableWithColumns { - // Instead of taking single columns, we reconstruct the entire opcode value directly. 
- let mut columns = vec![Column::le_bits(COL_MAP.opcode_bits)]; - columns.extend(ctl_data_ternops(true)); - // Create the CPU Table whose columns are those with the three - // inputs and one output of the ternary operations listed in `ops` - // (also `ops` is used as the operation filter). The list of - // operations includes binary operations which will simply ignore - // the third input. - TableWithColumns::new(Table::Cpu, columns, Some(Column::single(COL_MAP.op.shift))) -} - pub fn ctl_data_byte_packing() -> Vec> { ctl_data_keccak_sponge() } diff --git a/evm/src/witness/operation.rs b/evm/src/witness/operation.rs index 0620069f00..568fe4b181 100644 --- a/evm/src/witness/operation.rs +++ b/evm/src/witness/operation.rs @@ -499,18 +499,12 @@ fn append_shift( channel.addr_virtual = F::from_canonical_usize(lookup_addr.virt); } - // Convert the shift, and log the corresponding arithmetic operation. - let input0 = if input0 > U256::from(255u64) { - U256::zero() - } else { - U256::one() << input0 - }; let operator = if is_shl { BinaryOperator::Shl } else { BinaryOperator::Shr }; - let operation = arithmetic::Operation::binary(operator, input1, input0); + let operation = arithmetic::Operation::binary(operator, input0, input1); state.traces.push_arithmetic(operation); state.traces.push_memory(log_in0); From e58d7795f87a0299aeee0eff7ab7e43eb7b76a31 Mon Sep 17 00:00:00 2001 From: Linda Guiga <101227802+LindaGuiga@users.noreply.github.com> Date: Fri, 6 Oct 2023 15:49:57 -0400 Subject: [PATCH 25/34] Remove reg_preimage columns in KeccakStark (#1279) * Remove reg_preimage columns in KeccakStark * Apply comments * Minor cleanup --- evm/src/all_stark.rs | 28 ++++- evm/src/keccak/columns.rs | 14 +-- evm/src/keccak/keccak_stark.rs | 121 ++++++++----------- evm/src/keccak_sponge/keccak_sponge_stark.rs | 12 +- evm/src/witness/traces.rs | 10 +- evm/src/witness/util.rs | 8 +- 6 files changed, 98 insertions(+), 95 deletions(-) diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs 
index e5f631e81d..079ff114c4 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -96,7 +96,8 @@ pub(crate) fn all_cross_table_lookups() -> Vec> { ctl_arithmetic(), ctl_byte_packing(), ctl_keccak_sponge(), - ctl_keccak(), + ctl_keccak_inputs(), + ctl_keccak_outputs(), ctl_logic(), ctl_memory(), ] @@ -131,16 +132,33 @@ fn ctl_byte_packing() -> CrossTableLookup { ) } -fn ctl_keccak() -> CrossTableLookup { +// We now need two different looked tables for `KeccakStark`: +// one for the inputs and one for the outputs. +// They are linked with the timestamp. +fn ctl_keccak_inputs() -> CrossTableLookup { let keccak_sponge_looking = TableWithColumns::new( Table::KeccakSponge, - keccak_sponge_stark::ctl_looking_keccak(), + keccak_sponge_stark::ctl_looking_keccak_inputs(), Some(keccak_sponge_stark::ctl_looking_keccak_filter()), ); let keccak_looked = TableWithColumns::new( Table::Keccak, - keccak_stark::ctl_data(), - Some(keccak_stark::ctl_filter()), + keccak_stark::ctl_data_inputs(), + Some(keccak_stark::ctl_filter_inputs()), + ); + CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked) +} + +fn ctl_keccak_outputs() -> CrossTableLookup { + let keccak_sponge_looking = TableWithColumns::new( + Table::KeccakSponge, + keccak_sponge_stark::ctl_looking_keccak_outputs(), + Some(keccak_sponge_stark::ctl_looking_keccak_filter()), + ); + let keccak_looked = TableWithColumns::new( + Table::Keccak, + keccak_stark::ctl_data_outputs(), + Some(keccak_stark::ctl_filter_outputs()), ); CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked) } diff --git a/evm/src/keccak/columns.rs b/evm/src/keccak/columns.rs index afd92ad9ec..d9a71af4ff 100644 --- a/evm/src/keccak/columns.rs +++ b/evm/src/keccak/columns.rs @@ -20,7 +20,7 @@ pub fn reg_input_limb(i: usize) -> Column { let y = i_u64 / 5; let x = i_u64 % 5; - let reg_low_limb = reg_preimage(x, y); + let reg_low_limb = reg_a(x, y); let is_high_limb = i % 2; Column::single(reg_low_limb + is_high_limb) } @@ -48,15 
+48,11 @@ const R: [[u8; 5]; 5] = [ [27, 20, 39, 8, 14], ]; -const START_PREIMAGE: usize = NUM_ROUNDS; -/// Registers to hold the original input to a permutation, i.e. the input to the first round. -pub(crate) const fn reg_preimage(x: usize, y: usize) -> usize { - debug_assert!(x < 5); - debug_assert!(y < 5); - START_PREIMAGE + (x * 5 + y) * 2 -} +/// Column holding the timestamp, used to link inputs and outputs +/// in the `KeccakSpongeStark`. +pub(crate) const TIMESTAMP: usize = NUM_ROUNDS; -const START_A: usize = START_PREIMAGE + 5 * 5 * 2; +const START_A: usize = TIMESTAMP + 1; pub(crate) const fn reg_a(x: usize, y: usize) -> usize { debug_assert!(x < 5); debug_assert!(y < 5); diff --git a/evm/src/keccak/keccak_stark.rs b/evm/src/keccak/keccak_stark.rs index 189380f541..2745d03302 100644 --- a/evm/src/keccak/keccak_stark.rs +++ b/evm/src/keccak/keccak_stark.rs @@ -11,13 +11,13 @@ use plonky2::plonk::plonk_common::reduce_with_powers_ext_circuit; use plonky2::timed; use plonky2::util::timing::TimingTree; +use super::columns::reg_input_limb; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cross_table_lookup::Column; use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::keccak::columns::{ reg_a, reg_a_prime, reg_a_prime_prime, reg_a_prime_prime_0_0_bit, reg_a_prime_prime_prime, - reg_b, reg_c, reg_c_prime, reg_input_limb, reg_output_limb, reg_preimage, reg_step, - NUM_COLUMNS, + reg_b, reg_c, reg_c_prime, reg_output_limb, reg_step, NUM_COLUMNS, TIMESTAMP, }; use crate::keccak::constants::{rc_value, rc_value_bit}; use crate::keccak::logic::{ @@ -33,13 +33,23 @@ pub(crate) const NUM_ROUNDS: usize = 24; /// Number of 64-bit elements in the Keccak permutation input. 
pub(crate) const NUM_INPUTS: usize = 25; -pub fn ctl_data() -> Vec> { +pub fn ctl_data_inputs() -> Vec> { let mut res: Vec<_> = (0..2 * NUM_INPUTS).map(reg_input_limb).collect(); - res.extend(Column::singles((0..2 * NUM_INPUTS).map(reg_output_limb))); + res.push(Column::single(TIMESTAMP)); res } -pub fn ctl_filter() -> Column { +pub fn ctl_data_outputs() -> Vec> { + let mut res: Vec<_> = Column::singles((0..2 * NUM_INPUTS).map(reg_output_limb)).collect(); + res.push(Column::single(TIMESTAMP)); + res +} + +pub fn ctl_filter_inputs() -> Column { + Column::single(reg_step(0)) +} + +pub fn ctl_filter_outputs() -> Column { Column::single(reg_step(NUM_ROUNDS - 1)) } @@ -53,16 +63,16 @@ impl, const D: usize> KeccakStark { /// in our lookup arguments, as those are computed after transposing to column-wise form. fn generate_trace_rows( &self, - inputs: Vec<[u64; NUM_INPUTS]>, + inputs_and_timestamps: Vec<([u64; NUM_INPUTS], usize)>, min_rows: usize, ) -> Vec<[F; NUM_COLUMNS]> { - let num_rows = (inputs.len() * NUM_ROUNDS) + let num_rows = (inputs_and_timestamps.len() * NUM_ROUNDS) .max(min_rows) .next_power_of_two(); let mut rows = Vec::with_capacity(num_rows); - for input in inputs.iter() { - let rows_for_perm = self.generate_trace_rows_for_perm(*input); + for input_and_timestamp in inputs_and_timestamps.iter() { + let rows_for_perm = self.generate_trace_rows_for_perm(*input_and_timestamp); rows.extend(rows_for_perm); } @@ -72,20 +82,19 @@ impl, const D: usize> KeccakStark { rows } - fn generate_trace_rows_for_perm(&self, input: [u64; NUM_INPUTS]) -> Vec<[F; NUM_COLUMNS]> { + fn generate_trace_rows_for_perm( + &self, + input_and_timestamp: ([u64; NUM_INPUTS], usize), + ) -> Vec<[F; NUM_COLUMNS]> { let mut rows = vec![[F::ZERO; NUM_COLUMNS]; NUM_ROUNDS]; - - // Populate the preimage for each row. + let input = input_and_timestamp.0; + let timestamp = input_and_timestamp.1; + // Set the timestamp of the current input. 
+ // It will be checked against the value in `KeccakSponge`. + // The timestamp is used to link the input and output of + // the same permutation together. for round in 0..24 { - for x in 0..5 { - for y in 0..5 { - let input_xy = input[y * 5 + x]; - let reg_preimage_lo = reg_preimage(x, y); - let reg_preimage_hi = reg_preimage_lo + 1; - rows[round][reg_preimage_lo] = F::from_canonical_u64(input_xy & 0xFFFFFFFF); - rows[round][reg_preimage_hi] = F::from_canonical_u64(input_xy >> 32); - } - } + rows[round][TIMESTAMP] = F::from_canonical_usize(timestamp); } // Populate the round input for the first round. @@ -220,7 +229,7 @@ impl, const D: usize> KeccakStark { pub fn generate_trace( &self, - inputs: Vec<[u64; NUM_INPUTS]>, + inputs: Vec<([u64; NUM_INPUTS], usize)>, min_rows: usize, timing: &mut TimingTree, ) -> Vec> { @@ -269,26 +278,14 @@ impl, const D: usize> Stark for KeccakStark(); + yield_constr.constraint( + sum_round_flags * not_final_step * (next_values[TIMESTAMP] - local_values[TIMESTAMP]), + ); // C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]). 
for x in 0..5 { @@ -454,34 +451,13 @@ impl, const D: usize> Stark for KeccakStark = (0..NUM_PERMS).map(|_| rand::random()).collect(); + let input: Vec<([u64; NUM_INPUTS], usize)> = + (0..NUM_PERMS).map(|_| (rand::random(), 0)).collect(); let mut timing = TimingTree::new("prove", log::Level::Debug); let trace_poly_values = timed!( diff --git a/evm/src/keccak_sponge/keccak_sponge_stark.rs b/evm/src/keccak_sponge/keccak_sponge_stark.rs index 2ed31c1fee..e491252ba8 100644 --- a/evm/src/keccak_sponge/keccak_sponge_stark.rs +++ b/evm/src/keccak_sponge/keccak_sponge_stark.rs @@ -47,7 +47,7 @@ pub(crate) fn ctl_looked_data() -> Vec> { .collect() } -pub(crate) fn ctl_looking_keccak() -> Vec> { +pub(crate) fn ctl_looking_keccak_inputs() -> Vec> { let cols = KECCAK_SPONGE_COL_MAP; let mut res: Vec<_> = Column::singles( [ @@ -57,6 +57,13 @@ pub(crate) fn ctl_looking_keccak() -> Vec> { .concat(), ) .collect(); + res.push(Column::single(cols.timestamp)); + + res +} + +pub(crate) fn ctl_looking_keccak_outputs() -> Vec> { + let cols = KECCAK_SPONGE_COL_MAP; // We recover the 32-bit digest limbs from their corresponding bytes, // and then append them to the rest of the updated state limbs. 
@@ -68,9 +75,10 @@ pub(crate) fn ctl_looking_keccak() -> Vec> { ) }); - res.extend(digest_u32s); + let mut res: Vec<_> = digest_u32s.collect(); res.extend(Column::singles(&cols.partial_updated_state_u32s)); + res.push(Column::single(cols.timestamp)); res } diff --git a/evm/src/witness/traces.rs b/evm/src/witness/traces.rs index c4cf832dd5..91035fc403 100644 --- a/evm/src/witness/traces.rs +++ b/evm/src/witness/traces.rs @@ -36,7 +36,7 @@ pub(crate) struct Traces { pub(crate) cpu: Vec>, pub(crate) logic_ops: Vec, pub(crate) memory_ops: Vec, - pub(crate) keccak_inputs: Vec<[u64; keccak::keccak_stark::NUM_INPUTS]>, + pub(crate) keccak_inputs: Vec<([u64; keccak::keccak_stark::NUM_INPUTS], usize)>, pub(crate) keccak_sponge_ops: Vec, } @@ -131,18 +131,18 @@ impl Traces { self.byte_packing_ops.push(op); } - pub fn push_keccak(&mut self, input: [u64; keccak::keccak_stark::NUM_INPUTS]) { - self.keccak_inputs.push(input); + pub fn push_keccak(&mut self, input: [u64; keccak::keccak_stark::NUM_INPUTS], clock: usize) { + self.keccak_inputs.push((input, clock)); } - pub fn push_keccak_bytes(&mut self, input: [u8; KECCAK_WIDTH_BYTES]) { + pub fn push_keccak_bytes(&mut self, input: [u8; KECCAK_WIDTH_BYTES], clock: usize) { let chunks = input .chunks(size_of::()) .map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap())) .collect_vec() .try_into() .unwrap(); - self.push_keccak(chunks); + self.push_keccak(chunks, clock); } pub fn push_keccak_sponge(&mut self, op: KeccakSpongeOp) { diff --git a/evm/src/witness/util.rs b/evm/src/witness/util.rs index 068a8e1113..dbe4c0ede5 100644 --- a/evm/src/witness/util.rs +++ b/evm/src/witness/util.rs @@ -229,7 +229,9 @@ pub(crate) fn keccak_sponge_log( address.increment(); } xor_into_sponge(state, &mut sponge_state, block.try_into().unwrap()); - state.traces.push_keccak_bytes(sponge_state); + state + .traces + .push_keccak_bytes(sponge_state, clock * NUM_CHANNELS); keccakf_u8s(&mut sponge_state); } @@ -254,7 +256,9 @@ pub(crate) fn 
keccak_sponge_log( final_block[KECCAK_RATE_BYTES - 1] = 0b10000000; } xor_into_sponge(state, &mut sponge_state, &final_block); - state.traces.push_keccak_bytes(sponge_state); + state + .traces + .push_keccak_bytes(sponge_state, clock * NUM_CHANNELS); state.traces.push_keccak_sponge(KeccakSpongeOp { base_address, From 8a5eed9d1cc0c48f17c68209ade5855e938d9fac Mon Sep 17 00:00:00 2001 From: Robin Salen <30937548+Nashtare@users.noreply.github.com> Date: Mon, 9 Oct 2023 07:41:30 -0400 Subject: [PATCH 26/34] Fix shift constraint (#1280) --- evm/src/cpu/shift.rs | 16 +++++++++++++--- evm/src/witness/operation.rs | 4 ++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/evm/src/cpu/shift.rs b/evm/src/cpu/shift.rs index a424929798..0f92cbd20d 100644 --- a/evm/src/cpu/shift.rs +++ b/evm/src/cpu/shift.rs @@ -19,8 +19,8 @@ pub(crate) fn eval_packed( // Not needed here; val is the input and we're verifying that output is // val * 2^d (mod 2^256) - //let val = lv.mem_channels[0]; - //let output = lv.mem_channels[NUM_GP_CHANNELS - 1]; + // let val = lv.mem_channels[0]; + // let output = lv.mem_channels[NUM_GP_CHANNELS - 1]; let shift_table_segment = P::Scalar::from_canonical_u64(Segment::ShiftTable as u64); @@ -28,7 +28,7 @@ pub(crate) fn eval_packed( // two_exp.used is true (1) if the high limbs of the displacement are // zero and false (0) otherwise. let high_limbs_are_zero = two_exp.used; - yield_constr.constraint(is_shift * (two_exp.is_read - P::ONES)); + yield_constr.constraint(is_shift * high_limbs_are_zero * (two_exp.is_read - P::ONES)); let high_limbs_sum: P = displacement.value[1..].iter().copied().sum(); let high_limbs_sum_inv = lv.general.shift().high_limb_sum_inv; @@ -70,14 +70,20 @@ pub(crate) fn eval_ext_circuit, const D: usize>( let shift_table_segment = F::from_canonical_u64(Segment::ShiftTable as u64); + // Only lookup the shifting factor when displacement is < 2^32. 
+ // two_exp.used is true (1) if the high limbs of the displacement are + // zero and false (0) otherwise. let high_limbs_are_zero = two_exp.used; let one = builder.one_extension(); let t = builder.sub_extension(two_exp.is_read, one); + let t = builder.mul_extension(high_limbs_are_zero, t); let t = builder.mul_extension(is_shift, t); yield_constr.constraint(builder, t); let high_limbs_sum = builder.add_many_extension(&displacement.value[1..]); let high_limbs_sum_inv = lv.general.shift().high_limb_sum_inv; + // Verify that high_limbs_are_zero = 0 implies high_limbs_sum != 0 and + // high_limbs_are_zero = 1 implies high_limbs_sum = 0. let t = builder.one_extension(); let t = builder.sub_extension(t, high_limbs_are_zero); let t = builder.mul_sub_extension(high_limbs_sum, high_limbs_sum_inv, t); @@ -87,6 +93,9 @@ pub(crate) fn eval_ext_circuit, const D: usize>( let t = builder.mul_many_extension([is_shift, high_limbs_sum, high_limbs_are_zero]); yield_constr.constraint(builder, t); + // When the shift displacement is < 2^32, constrain the two_exp + // mem_channel to be the entry corresponding to `displacement` in + // the shift table lookup (will be zero if displacement >= 256). 
let t = builder.mul_extension(is_shift, two_exp.addr_context); yield_constr.constraint(builder, t); let t = builder.arithmetic_extension( @@ -101,6 +110,7 @@ pub(crate) fn eval_ext_circuit, const D: usize>( let t = builder.mul_extension(is_shift, t); yield_constr.constraint(builder, t); + // Other channels must be unused for chan in &lv.mem_channels[3..NUM_GP_CHANNELS - 1] { let t = builder.mul_extension(is_shift, chan.used); yield_constr.constraint(builder, t); diff --git a/evm/src/witness/operation.rs b/evm/src/witness/operation.rs index 568fe4b181..f4dc03e806 100644 --- a/evm/src/witness/operation.rs +++ b/evm/src/witness/operation.rs @@ -497,6 +497,10 @@ fn append_shift( channel.addr_context = F::from_canonical_usize(lookup_addr.context); channel.addr_segment = F::from_canonical_usize(lookup_addr.segment); channel.addr_virtual = F::from_canonical_usize(lookup_addr.virt); + + // Extra field required by the constraints for large shifts. + let high_limb_sum = row.mem_channels[0].value[1..].iter().copied().sum::(); + row.general.shift_mut().high_limb_sum_inv = high_limb_sum.inverse(); } let operator = if is_shl { From 41a29f069b6731c4af2644337959fbad8c771c77 Mon Sep 17 00:00:00 2001 From: Robin Salen <30937548+Nashtare@users.noreply.github.com> Date: Mon, 9 Oct 2023 09:07:01 -0400 Subject: [PATCH 27/34] Remove some dead_code in EVM crate (#1281) * Remove unnecessary CpuArithmeticView. 
* Remove AllChallengerState * Remove RecursiveAllProof * Remove unused generate methods * Remove dead_code from cpu/columns * Remove todo --------- Co-authored-by: Linda Guiga --- evm/src/cpu/columns/general.rs | 17 ------- evm/src/cpu/columns/mod.rs | 3 -- evm/src/cpu/columns/ops.rs | 9 +--- evm/src/cpu/cpu_stark.rs | 14 ++---- evm/src/cpu/decode.rs | 45 ----------------- evm/src/cpu/membus.rs | 14 ------ evm/src/get_challenges.rs | 34 ------------- evm/src/proof.rs | 11 +---- evm/src/recursive_verifier.rs | 88 +++------------------------------- 9 files changed, 11 insertions(+), 224 deletions(-) diff --git a/evm/src/cpu/columns/general.rs b/evm/src/cpu/columns/general.rs index 91a35c218f..57eb16fcf8 100644 --- a/evm/src/cpu/columns/general.rs +++ b/evm/src/cpu/columns/general.rs @@ -6,7 +6,6 @@ use std::mem::{size_of, transmute}; /// operation is occurring at this row. #[derive(Clone, Copy)] pub(crate) union CpuGeneralColumnsView { - arithmetic: CpuArithmeticView, exception: CpuExceptionView, logic: CpuLogicView, jumps: CpuJumpsView, @@ -14,16 +13,6 @@ pub(crate) union CpuGeneralColumnsView { } impl CpuGeneralColumnsView { - // SAFETY: Each view is a valid interpretation of the underlying array. - pub(crate) fn arithmetic(&self) -> &CpuArithmeticView { - unsafe { &self.arithmetic } - } - - // SAFETY: Each view is a valid interpretation of the underlying array. - pub(crate) fn arithmetic_mut(&mut self) -> &mut CpuArithmeticView { - unsafe { &mut self.arithmetic } - } - // SAFETY: Each view is a valid interpretation of the underlying array. pub(crate) fn exception(&self) -> &CpuExceptionView { unsafe { &self.exception } @@ -94,12 +83,6 @@ impl BorrowMut<[T; NUM_SHARED_COLUMNS]> for CpuGeneralColumnsView { } } -#[derive(Copy, Clone)] -pub(crate) struct CpuArithmeticView { - // TODO: Add "looking" columns for the arithmetic CTL. 
- tmp: T, // temporary, to suppress errors -} - #[derive(Copy, Clone)] pub(crate) struct CpuExceptionView { // Exception code as little-endian bits. diff --git a/evm/src/cpu/columns/mod.rs b/evm/src/cpu/columns/mod.rs index cc98fceb3f..b7b4f780e0 100644 --- a/evm/src/cpu/columns/mod.rs +++ b/evm/src/cpu/columns/mod.rs @@ -1,6 +1,3 @@ -// TODO: remove when possible. -#![allow(dead_code)] - use std::borrow::{Borrow, BorrowMut}; use std::fmt::Debug; use std::mem::{size_of, transmute}; diff --git a/evm/src/cpu/columns/ops.rs b/evm/src/cpu/columns/ops.rs index d4d753f7cf..64474c9874 100644 --- a/evm/src/cpu/columns/ops.rs +++ b/evm/src/cpu/columns/ops.rs @@ -2,7 +2,7 @@ use std::borrow::{Borrow, BorrowMut}; use std::mem::{size_of, transmute}; use std::ops::{Deref, DerefMut}; -use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; +use crate::util::transmute_no_compile_time_size_checks; #[repr(C)] #[derive(Clone, Copy, Eq, PartialEq, Debug)] @@ -73,10 +73,3 @@ impl DerefMut for OpsColumnsView { unsafe { transmute(self) } } } - -const fn make_col_map() -> OpsColumnsView { - let indices_arr = indices_arr::(); - unsafe { transmute::<[usize; NUM_OPS_COLUMNS], OpsColumnsView>(indices_arr) } -} - -pub const COL_MAP: OpsColumnsView = make_col_map(); diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 82ca5452b7..a77adbcbeb 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -1,4 +1,4 @@ -use std::borrow::{Borrow, BorrowMut}; +use std::borrow::Borrow; use std::iter::repeat; use std::marker::PhantomData; @@ -9,10 +9,11 @@ use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; +use super::columns::CpuColumnsView; use super::halt; use crate::all_stark::Table; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS}; +use crate::cpu::columns::{COL_MAP, 
NUM_CPU_COLUMNS}; use crate::cpu::membus::NUM_GP_CHANNELS; use crate::cpu::{ bootstrap_kernel, contextops, control_flow, decode, dup_swap, gas, jumps, membus, memio, @@ -198,15 +199,6 @@ pub struct CpuStark { pub f: PhantomData, } -impl CpuStark { - // TODO: Remove? - pub fn generate(&self, local_values: &mut [F; NUM_CPU_COLUMNS]) { - let local_values: &mut CpuColumnsView<_> = local_values.borrow_mut(); - decode::generate(local_values); - membus::generate(local_values); - } -} - impl, const D: usize> Stark for CpuStark { type EvaluationFrame = StarkFrame where diff --git a/evm/src/cpu/decode.rs b/evm/src/cpu/decode.rs index c1c43a0bb1..ba4aa0c62d 100644 --- a/evm/src/cpu/decode.rs +++ b/evm/src/cpu/decode.rs @@ -61,51 +61,6 @@ const COMBINED_OPCODES: [usize; 6] = [ COL_MAP.op.m_op_general, ]; -pub fn generate(lv: &mut CpuColumnsView) { - let cycle_filter: F = COL_MAP.op.iter().map(|&col_i| lv[col_i]).sum(); - - // This assert is not _strictly_ necessary, but I include it as a sanity check. - assert_eq!(cycle_filter, F::ONE, "cycle_filter should be 0 or 1"); - - // Validate all opcode bits. - for bit in lv.opcode_bits.into_iter() { - assert!(bit.to_canonical_u64() <= 1); - } - let opcode = lv - .opcode_bits - .into_iter() - .enumerate() - .map(|(i, bit)| bit.to_canonical_u64() << i) - .sum::() as u8; - - let top_bits: [u8; 9] = [ - 0, - opcode & 0x80, - opcode & 0xc0, - opcode & 0xe0, - opcode & 0xf0, - opcode & 0xf8, - opcode & 0xfc, - opcode & 0xfe, - opcode, - ]; - - let kernel = lv.is_kernel_mode.to_canonical_u64(); - assert!(kernel <= 1); - let kernel = kernel != 0; - - for (oc, block_length, kernel_only, col) in OPCODES { - let available = !kernel_only || kernel; - let opcode_match = top_bits[8 - block_length] == oc; - let flag = available && opcode_match; - lv[col] = F::from_bool(flag); - } - - if opcode == 0xfb || opcode == 0xfc { - lv.op.m_op_general = F::from_bool(kernel); - } -} - /// Break up an opcode (which is 8 bits long) into its eight bits. 
const fn bits_from_opcode(opcode: u8) -> [bool; 8] { [ diff --git a/evm/src/cpu/membus.rs b/evm/src/cpu/membus.rs index bf7a03aeb5..10dc25a4ca 100644 --- a/evm/src/cpu/membus.rs +++ b/evm/src/cpu/membus.rs @@ -1,10 +1,8 @@ use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; -use plonky2::field::types::PrimeField64; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; -use super::columns::COL_MAP; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; @@ -33,18 +31,6 @@ pub mod channel_indices { /// These limitations save us numerous columns in the CPU table. pub const NUM_CHANNELS: usize = channel_indices::GP.end; -/// Calculates `lv.stack_len_bounds_aux`. Note that this must be run after decode. -pub fn generate(lv: &CpuColumnsView) { - let cycle_filter: F = COL_MAP.op.iter().map(|&col_i| lv[col_i]).sum(); - if cycle_filter != F::ZERO { - assert!(lv.is_kernel_mode.to_canonical_u64() <= 1); - } - - for channel in lv.mem_channels { - assert!(channel.used.to_canonical_u64() <= 1); - } -} - pub fn eval_packed( lv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, diff --git a/evm/src/get_challenges.rs b/evm/src/get_challenges.rs index ed8ff91510..e9e5de9360 100644 --- a/evm/src/get_challenges.rs +++ b/evm/src/get_challenges.rs @@ -6,7 +6,6 @@ use plonky2::iop::challenger::{Challenger, RecursiveChallenger}; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig}; -use crate::all_stark::{AllStark, NUM_TABLES}; use crate::config::StarkConfig; use crate::cross_table_lookup::get_grand_product_challenge_set; use crate::proof::*; @@ -234,39 +233,6 @@ impl, C: GenericConfig, const D: usize> A ctl_challenges, }) } - - #[allow(unused)] // TODO: should be used soon - pub(crate) fn get_challenger_states( - &self, - all_stark: &AllStark, - config: &StarkConfig, - ) -> AllChallengerState { - let mut challenger = Challenger::::new(); - - for proof in &self.stark_proofs { - challenger.observe_cap(&proof.proof.trace_cap); - } - - observe_public_values::(&mut challenger, &self.public_values); - - let ctl_challenges = - get_grand_product_challenge_set(&mut challenger, config.num_challenges); - - let lookups = all_stark.num_lookups_helper_columns(config); - - let mut challenger_states = vec![challenger.compact()]; - for i in 0..NUM_TABLES { - self.stark_proofs[i] - .proof - .get_challenges(&mut challenger, config); - challenger_states.push(challenger.compact()); - } - - AllChallengerState { - states: challenger_states.try_into().unwrap(), - ctl_challenges, - } - } } impl StarkProof diff --git a/evm/src/proof.rs b/evm/src/proof.rs index 3f744a6134..43561e8c8b 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -39,14 +39,6 @@ pub(crate) struct AllProofChallenges, const D: usiz pub ctl_challenges: GrandProductChallengeSet, } -#[allow(unused)] // TODO: should be used soon -pub(crate) struct AllChallengerState, H: Hasher, const D: usize> { - /// Sponge state of the challenger before starting each proof, - /// along with the final state after all proofs are done. 
This final state isn't strictly needed. - pub states: [H::Permutation; NUM_TABLES + 1], - pub ctl_challenges: GrandProductChallengeSet, -} - /// Memory values which are public. #[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct PublicValues { @@ -697,8 +689,7 @@ where C: GenericConfig, { pub(crate) init_challenger_state: >::Permutation, - // TODO: set it back to pub(crate) when cpu trace len is a public input - pub proof: StarkProof, + pub(crate) proof: StarkProof, } impl, C: GenericConfig, const D: usize> StarkProof { diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index 8c99ca5d1d..d88d847184 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -1,6 +1,6 @@ use std::fmt::Debug; -use anyhow::{ensure, Result}; +use anyhow::Result; use ethereum_types::{BigEndianHash, U256}; use plonky2::field::extension::Extendable; use plonky2::field::types::Field; @@ -10,13 +10,13 @@ use plonky2::gates::gate::GateRef; use plonky2::gates::noop::NoopGate; use plonky2::hash::hash_types::RichField; use plonky2::hash::hashing::PlonkyPermutation; -use plonky2::iop::challenger::{Challenger, RecursiveChallenger}; +use plonky2::iop::challenger::RecursiveChallenger; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::iop::target::Target; use plonky2::iop::witness::{PartialWitness, Witness, WitnessWrite}; use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, VerifierCircuitData}; -use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; +use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData}; +use plonky2::plonk::config::{AlgebraicHasher, GenericConfig}; use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; use plonky2::util::reducing::ReducingFactorTarget; use plonky2::util::serialization::{ @@ -25,13 +25,12 @@ use plonky2::util::serialization::{ use plonky2::with_context; use plonky2_util::log2_ceil; 
-use crate::all_stark::{Table, NUM_TABLES}; +use crate::all_stark::Table; use crate::config::StarkConfig; use crate::constraint_consumer::RecursiveConstraintConsumer; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cross_table_lookup::{ - get_grand_product_challenge_set, verify_cross_table_lookups, CrossTableLookup, - CtlCheckVarsTarget, GrandProductChallenge, GrandProductChallengeSet, + CrossTableLookup, CtlCheckVarsTarget, GrandProductChallenge, GrandProductChallengeSet, }; use crate::evaluation_frame::StarkEvaluationFrame; use crate::lookup::LookupCheckVarsTarget; @@ -48,15 +47,6 @@ use crate::util::{h256_limbs, u256_limbs, u256_to_u32, u256_to_u64}; use crate::vanishing_poly::eval_vanishing_poly_circuit; use crate::witness::errors::ProgramError; -/// Table-wise recursive proofs of an `AllProof`. -pub struct RecursiveAllProof< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, -> { - pub recursive_proofs: [ProofWithPublicInputs; NUM_TABLES], -} - pub(crate) struct PublicInputs> { pub(crate) trace_cap: Vec>, @@ -98,72 +88,6 @@ impl> Public } } -impl, C: GenericConfig, const D: usize> - RecursiveAllProof -{ - /// Verify every recursive proof. - pub fn verify( - self, - verifier_data: &[VerifierCircuitData; NUM_TABLES], - cross_table_lookups: Vec>, - inner_config: &StarkConfig, - ) -> Result<()> { - let pis: [_; NUM_TABLES] = core::array::from_fn(|i| { - PublicInputs::>::Permutation>::from_vec( - &self.recursive_proofs[i].public_inputs, - inner_config, - ) - }); - - let mut challenger = Challenger::::new(); - for pi in &pis { - for h in &pi.trace_cap { - challenger.observe_elements(h); - } - } - - // TODO: Observe public values if the code isn't deprecated. - - let ctl_challenges = - get_grand_product_challenge_set(&mut challenger, inner_config.num_challenges); - // Check that the correct CTL challenges are used in every proof. 
- for pi in &pis { - ensure!(ctl_challenges == pi.ctl_challenges); - } - - let state = challenger.compact(); - ensure!(state == pis[0].challenger_state_before); - // Check that the challenger state is consistent between proofs. - for i in 1..NUM_TABLES { - ensure!(pis[i].challenger_state_before == pis[i - 1].challenger_state_after); - } - - // Dummy values which will make the check fail. - // TODO: Fix this if the code isn't deprecated. - let mut extra_looking_products = Vec::new(); - for i in 0..NUM_TABLES { - extra_looking_products.push(Vec::new()); - for _ in 0..inner_config.num_challenges { - extra_looking_products[i].push(F::ONE); - } - } - - // Verify the CTL checks. - verify_cross_table_lookups::( - &cross_table_lookups, - pis.map(|p| p.ctl_zs_first), - extra_looking_products, - inner_config, - )?; - - // Verify the proofs. - for (proof, verifier_data) in self.recursive_proofs.into_iter().zip(verifier_data) { - verifier_data.verify(proof)?; - } - Ok(()) - } -} - /// Represents a circuit which recursively verifies a STARK proof. 
#[derive(Eq, PartialEq, Debug)] pub(crate) struct StarkWrapperCircuit From b4203c3d47c9164cecb589f1ef91f21e100dcb02 Mon Sep 17 00:00:00 2001 From: Robin Salen <30937548+Nashtare@users.noreply.github.com> Date: Tue, 10 Oct 2023 06:23:20 -0400 Subject: [PATCH 28/34] Make sure success is 0 in contract failure (#1283) --- evm/src/cpu/kernel/asm/core/process_txn.asm | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/process_txn.asm b/evm/src/cpu/kernel/asm/core/process_txn.asm index 6a1b1c6c2d..779acab1da 100644 --- a/evm/src/cpu/kernel/asm/core/process_txn.asm +++ b/evm/src/cpu/kernel/asm/core/process_txn.asm @@ -420,7 +420,7 @@ contract_creation_fault_3: // stack: leftover_gas', retdest, success %delete_all_touched_addresses %delete_all_selfdestructed_addresses - %stack (leftover_gas, retdest, success) -> (retdest, success, leftover_gas) + %stack (leftover_gas, retdest, success) -> (retdest, 0, leftover_gas) JUMP contract_creation_fault_3_zero_leftover: @@ -432,7 +432,7 @@ contract_creation_fault_3_zero_leftover: %pay_coinbase_and_refund_sender %delete_all_touched_addresses %delete_all_selfdestructed_addresses - %stack (leftover_gas, retdest, success) -> (retdest, success, leftover_gas) + %stack (leftover_gas, retdest, success) -> (retdest, 0, leftover_gas) JUMP contract_creation_fault_4: @@ -444,7 +444,7 @@ contract_creation_fault_4: %pay_coinbase_and_refund_sender %delete_all_touched_addresses %delete_all_selfdestructed_addresses - %stack (leftover_gas, retdest, success) -> (retdest, success, leftover_gas) + %stack (leftover_gas, retdest, success) -> (retdest, 0, leftover_gas) JUMP From 2aeecc3dd8cd2f01cb4a84b1404ef2fe76efcab2 Mon Sep 17 00:00:00 2001 From: Linda Guiga <101227802+LindaGuiga@users.noreply.github.com> Date: Wed, 11 Oct 2023 10:02:24 -0400 Subject: [PATCH 29/34] Fix failed receipt. 
(#1284) --- evm/src/cpu/kernel/asm/core/create_receipt.asm | 1 + 1 file changed, 1 insertion(+) diff --git a/evm/src/cpu/kernel/asm/core/create_receipt.asm b/evm/src/cpu/kernel/asm/core/create_receipt.asm index fccabe0885..66f48d010b 100644 --- a/evm/src/cpu/kernel/asm/core/create_receipt.asm +++ b/evm/src/cpu/kernel/asm/core/create_receipt.asm @@ -228,6 +228,7 @@ failed_receipt: // It is the receipt of a failed transaction, so set num_logs to 0. This will also lead to Bloom filter = 0. PUSH 0 %mstore_global_metadata(@GLOBAL_METADATA_LOGS_LEN) + PUSH 0 %mstore_global_metadata(@GLOBAL_METADATA_LOGS_PAYLOAD_LEN) // stack: status, new_cum_gas, num_nibbles, txn_nb %jump(process_receipt_after_status) From 49ca63ee0fe8a827e78d36c350e27d83e0fc6ea4 Mon Sep 17 00:00:00 2001 From: Hamy Ratoanina Date: Wed, 11 Oct 2023 16:23:09 +0200 Subject: [PATCH 30/34] Fix sys_blockhash (#1285) --- evm/src/cpu/kernel/asm/memory/metadata.asm | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/evm/src/cpu/kernel/asm/memory/metadata.asm b/evm/src/cpu/kernel/asm/memory/metadata.asm index c26d3d5fb9..d5b4033d56 100644 --- a/evm/src/cpu/kernel/asm/memory/metadata.asm +++ b/evm/src/cpu/kernel/asm/memory/metadata.asm @@ -241,6 +241,8 @@ global sys_blockhash: SWAP1 // stack: block_number, kexit_info %blockhash + // stack: blockhash, kexit_info + SWAP1 EXIT_KERNEL global blockhash: @@ -262,7 +264,6 @@ global blockhash: // stack: block_hash_number, retdest %mload_kernel(@SEGMENT_BLOCK_HASHES) SWAP1 JUMP - JUMP %macro blockhash // stack: block_number From 9fd0425f6781129cb09e6df94c7a584bda1c97ff Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 11 Oct 2023 16:24:00 +0200 Subject: [PATCH 31/34] Fix journal order in `sys_selfdestruct` (#1287) * Fix * Minor --- evm/src/cpu/kernel/asm/core/terminate.asm | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/terminate.asm b/evm/src/cpu/kernel/asm/core/terminate.asm index 
8910f62468..fb01f7aa23 100644 --- a/evm/src/cpu/kernel/asm/core/terminate.asm +++ b/evm/src/cpu/kernel/asm/core/terminate.asm @@ -92,25 +92,20 @@ global sys_selfdestruct: %mstore_trie_data %stack (balance, address, recipient, kexit_info) -> - (address, recipient, balance, address, recipient, recipient, balance, kexit_info) - %journal_add_account_destroyed + (address, recipient, address, recipient, balance, kexit_info) // If the recipient is the same as the address, then we're done. // Otherwise, send the balance to the recipient. - // stack: address, recipient, recipient, balance, kexit_info - EQ %jumpi(sys_selfdestruct_same_addr) - // stack: recipient, balance, kexit_info + // stack: address, recipient, address, recipient, balance, kexit_info + EQ %jumpi(sys_selfdestruct_journal_add) + %stack (address, recipient, balance, kexit_info) -> (recipient, balance, address, recipient, balance, kexit_info) %add_eth - // stack: kexit_info - %leftover_gas - // stack: leftover_gas - PUSH 1 // success - %jump(terminate_common) +sys_selfdestruct_journal_add: + // stack: address, recipient, balance, kexit_info + %journal_add_account_destroyed -sys_selfdestruct_same_addr: - // stack: recipient, balance, kexit_info - %pop2 + // stack: kexit_info %leftover_gas // stack: leftover_gas PUSH 1 // success From d7990ee137cfdce7fc2f2422d4ddf7541ec2779c Mon Sep 17 00:00:00 2001 From: Linda Guiga <101227802+LindaGuiga@users.noreply.github.com> Date: Wed, 11 Oct 2023 10:36:23 -0400 Subject: [PATCH 32/34] Add journal entry for logs (#1286) * Add journal entry for logs * Move journal labels to another file. 
* Minor cleanup --- evm/src/cpu/kernel/aggregator.rs | 1 + evm/src/cpu/kernel/asm/core/log.asm | 2 ++ evm/src/cpu/kernel/asm/journal/log.asm | 20 +++++++++++++++++++ evm/src/cpu/kernel/asm/journal/revert.asm | 1 + evm/src/cpu/kernel/constants/journal_entry.rs | 5 ++++- 5 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 evm/src/cpu/kernel/asm/journal/log.asm diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index 20081bb935..bda2ab610e 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -148,6 +148,7 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/journal/refund.asm"), include_str!("asm/journal/account_created.asm"), include_str!("asm/journal/revert.asm"), + include_str!("asm/journal/log.asm"), include_str!("asm/transactions/common_decoding.asm"), include_str!("asm/transactions/router.asm"), include_str!("asm/transactions/type_0.asm"), diff --git a/evm/src/cpu/kernel/asm/core/log.asm b/evm/src/cpu/kernel/asm/core/log.asm index 727e3fa4e0..0689d49211 100644 --- a/evm/src/cpu/kernel/asm/core/log.asm +++ b/evm/src/cpu/kernel/asm/core/log.asm @@ -188,6 +188,8 @@ log_after_topics: %rlp_list_len // stack: rlp_log_len, data_len_ptr, num_topics, data_len, data_offset, retdest %mload_global_metadata(@GLOBAL_METADATA_LOGS_PAYLOAD_LEN) + // Add payload length and logs_data_len to journal. 
+ DUP1 %mload_global_metadata(@GLOBAL_METADATA_LOGS_DATA_LEN) %journal_add_log ADD %mstore_global_metadata(@GLOBAL_METADATA_LOGS_PAYLOAD_LEN) // stack: data_len_ptr, num_topics, data_len, data_offset, retdest diff --git a/evm/src/cpu/kernel/asm/journal/log.asm b/evm/src/cpu/kernel/asm/journal/log.asm new file mode 100644 index 0000000000..0b815faef6 --- /dev/null +++ b/evm/src/cpu/kernel/asm/journal/log.asm @@ -0,0 +1,20 @@ +// struct Log { logs_data_len, logs_payload_len } + +%macro journal_add_log + %journal_add_2(@JOURNAL_ENTRY_LOG) +%endmacro + +global revert_log: + // stack: entry_type, ptr, retdest + POP + // First, reduce the number of logs. + %mload_global_metadata(@GLOBAL_METADATA_LOGS_LEN) + %decrement + %mstore_global_metadata(@GLOBAL_METADATA_LOGS_LEN) + // stack: ptr, retdest + // Second, restore payload length. + %journal_load_2 + // stack: prev_logs_data_len, prev_payload_len, retdest + %mstore_global_metadata(@GLOBAL_METADATA_LOGS_DATA_LEN) + %mstore_global_metadata(@GLOBAL_METADATA_LOGS_PAYLOAD_LEN) + JUMP diff --git a/evm/src/cpu/kernel/asm/journal/revert.asm b/evm/src/cpu/kernel/asm/journal/revert.asm index 1967239ae0..857bf612b2 100644 --- a/evm/src/cpu/kernel/asm/journal/revert.asm +++ b/evm/src/cpu/kernel/asm/journal/revert.asm @@ -16,6 +16,7 @@ DUP1 %eq_const(@JOURNAL_ENTRY_CODE_CHANGE) %jumpi(revert_code_change) DUP1 %eq_const(@JOURNAL_ENTRY_REFUND) %jumpi(revert_refund) DUP1 %eq_const(@JOURNAL_ENTRY_ACCOUNT_CREATED) %jumpi(revert_account_created) + DUP1 %eq_const(@JOURNAL_ENTRY_LOG) %jumpi(revert_log) PANIC // This should never happen. 
%%after: // stack: journal_size-1 diff --git a/evm/src/cpu/kernel/constants/journal_entry.rs b/evm/src/cpu/kernel/constants/journal_entry.rs index be5db120ca..8015ce2162 100644 --- a/evm/src/cpu/kernel/constants/journal_entry.rs +++ b/evm/src/cpu/kernel/constants/journal_entry.rs @@ -11,10 +11,11 @@ pub(crate) enum JournalEntry { CodeChange = 7, Refund = 8, AccountCreated = 9, + Log = 10, } impl JournalEntry { - pub(crate) const COUNT: usize = 10; + pub(crate) const COUNT: usize = 11; pub(crate) fn all() -> [Self; Self::COUNT] { [ @@ -28,6 +29,7 @@ impl JournalEntry { Self::CodeChange, Self::Refund, Self::AccountCreated, + Self::Log, ] } @@ -44,6 +46,7 @@ impl JournalEntry { Self::CodeChange => "JOURNAL_ENTRY_CODE_CHANGE", Self::Refund => "JOURNAL_ENTRY_REFUND", Self::AccountCreated => "JOURNAL_ENTRY_ACCOUNT_CREATED", + Self::Log => "JOURNAL_ENTRY_LOG", } } } From 762e6f07b834df04be8cd290f07465a28c392c6d Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 11 Oct 2023 18:57:17 +0200 Subject: [PATCH 33/34] Fix hash node case in `mpt_delete_branch` (#1278) * Fix * Add test * Fix test * Clippy --- evm/src/cpu/kernel/asm/mpt/delete/delete_branch.asm | 8 +++++--- evm/src/cpu/kernel/tests/mpt/delete.rs | 11 +++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/evm/src/cpu/kernel/asm/mpt/delete/delete_branch.asm b/evm/src/cpu/kernel/asm/mpt/delete/delete_branch.asm index 69c97943cc..775e4e11ed 100644 --- a/evm/src/cpu/kernel/asm/mpt/delete/delete_branch.asm +++ b/evm/src/cpu/kernel/asm/mpt/delete/delete_branch.asm @@ -77,14 +77,16 @@ loop_end: // stack: i, updated_child_ptr, first_nibble, node_payload_ptr, retdest DUP4 ADD %mload_trie_data // stack: only_child_ptr, updated_child_ptr, first_nibble, node_payload_ptr, retdest - DUP1 %mload_trie_data %eq_const(@MPT_NODE_BRANCH) %jumpi(maybe_normalize_branch_branch) + DUP1 %mload_trie_data %eq_const(@MPT_NODE_BRANCH) %jumpi(maybe_normalize_branch_branchhash) + DUP1 %mload_trie_data %eq_const(@MPT_NODE_HASH) 
%jumpi(maybe_normalize_branch_branchhash) DUP1 %mload_trie_data %eq_const(@MPT_NODE_EXTENSION) %jumpi(maybe_normalize_branch_leafext) DUP1 %mload_trie_data %eq_const(@MPT_NODE_LEAF) %jumpi(maybe_normalize_branch_leafext) PANIC // This should never happen. -// The only child of the branch node is a branch node. +// The only child of the branch node is a branch node or a hash node. // Transform the branch node into an extension node of length 1. -maybe_normalize_branch_branch: +// This assumes that the hash node does not contain a leaf or an extension node (in which case this implementation is incorrect). +maybe_normalize_branch_branchhash: // stack: only_child_ptr, updated_child_ptr, first_nibble, node_payload_ptr, retdest %get_trie_data_size // pointer to the extension node we're about to create // stack: extension_ptr, only_child_ptr, updated_child_ptr, first_nibble, node_payload_ptr, retdest diff --git a/evm/src/cpu/kernel/tests/mpt/delete.rs b/evm/src/cpu/kernel/tests/mpt/delete.rs index 42e8caf991..074eea26ef 100644 --- a/evm/src/cpu/kernel/tests/mpt/delete.rs +++ b/evm/src/cpu/kernel/tests/mpt/delete.rs @@ -36,6 +36,17 @@ fn mpt_delete_leaf_overlapping_keys() -> Result<()> { test_state_trie(state_trie, nibbles_64(0xADE), test_account_2()) } +#[test] +fn mpt_delete_branch_into_hash() -> Result<()> { + let hash = Node::Hash(H256::random()); + let state_trie = Node::Extension { + nibbles: nibbles_64(0xADF), + child: hash.into(), + } + .into(); + test_state_trie(state_trie, nibbles_64(0xADE), test_account_2()) +} + /// Note: The account's storage_root is ignored, as we can't insert a new storage_root without the /// accompanying trie data. An empty trie's storage_root is used instead. 
fn test_state_trie( From 1d60431992ab3cc90addfa12b43f851e35ad97cb Mon Sep 17 00:00:00 2001 From: Hamy Ratoanina Date: Wed, 11 Oct 2023 22:28:49 +0200 Subject: [PATCH 34/34] Store top of the stack in memory channel 0 (#1215) * Store top of the stack in memory channel 0 * Fix interpreter * Apply comments * Remove debugging code * Merge commit * Remove debugging comments * Apply comments * Fix witness generation for exceptions * Fix witness generation for exceptions (again) * Fix modfp254 constraint --- evm/src/cpu/columns/general.rs | 20 + evm/src/cpu/columns/ops.rs | 3 +- evm/src/cpu/contextops.rs | 226 ++++++------ evm/src/cpu/control_flow.rs | 8 +- evm/src/cpu/cpu_stark.rs | 12 +- evm/src/cpu/decode.rs | 5 +- evm/src/cpu/dup_swap.rs | 123 ++++--- evm/src/cpu/gas.rs | 3 +- evm/src/cpu/jumps.rs | 107 +++++- evm/src/cpu/kernel/interpreter.rs | 54 ++- evm/src/cpu/kernel/tests/signed_syscalls.rs | 4 +- evm/src/cpu/memio.rs | 178 +++++++-- evm/src/cpu/mod.rs | 2 +- evm/src/cpu/modfp254.rs | 4 +- evm/src/cpu/pc.rs | 15 +- evm/src/cpu/push0.rs | 9 +- evm/src/cpu/stack.rs | 388 +++++++++++++++----- evm/src/cpu/syscalls_exceptions.rs | 21 +- evm/src/witness/memory.rs | 12 + evm/src/witness/operation.rs | 316 +++++++++++----- evm/src/witness/state.rs | 7 + evm/src/witness/transition.rs | 96 ++++- evm/src/witness/util.rs | 125 +++++-- 23 files changed, 1249 insertions(+), 489 deletions(-) diff --git a/evm/src/cpu/columns/general.rs b/evm/src/cpu/columns/general.rs index 57eb16fcf8..d4f3447380 100644 --- a/evm/src/cpu/columns/general.rs +++ b/evm/src/cpu/columns/general.rs @@ -10,6 +10,7 @@ pub(crate) union CpuGeneralColumnsView { logic: CpuLogicView, jumps: CpuJumpsView, shift: CpuShiftView, + stack: CpuStackView, } impl CpuGeneralColumnsView { @@ -52,6 +53,16 @@ impl CpuGeneralColumnsView { pub(crate) fn shift_mut(&mut self) -> &mut CpuShiftView { unsafe { &mut self.shift } } + + // SAFETY: Each view is a valid interpretation of the underlying array. 
+ pub(crate) fn stack(&self) -> &CpuStackView { + unsafe { &self.stack } + } + + // SAFETY: Each view is a valid interpretation of the underlying array. + pub(crate) fn stack_mut(&mut self) -> &mut CpuStackView { + unsafe { &mut self.stack } + } } impl PartialEq for CpuGeneralColumnsView { @@ -110,5 +121,14 @@ pub(crate) struct CpuShiftView { pub(crate) high_limb_sum_inv: T, } +#[derive(Copy, Clone)] +pub(crate) struct CpuStackView { + // Used for conditionally enabling and disabling channels when reading the next `stack_top`. + _unused: [T; 5], + pub(crate) stack_inv: T, + pub(crate) stack_inv_aux: T, + pub(crate) stack_inv_aux_2: T, +} + // `u8` is guaranteed to have a `size_of` of 1. pub const NUM_SHARED_COLUMNS: usize = size_of::>(); diff --git a/evm/src/cpu/columns/ops.rs b/evm/src/cpu/columns/ops.rs index 64474c9874..feeb3f5f75 100644 --- a/evm/src/cpu/columns/ops.rs +++ b/evm/src/cpu/columns/ops.rs @@ -24,7 +24,8 @@ pub struct OpsColumnsView { pub push: T, pub dup: T, pub swap: T, - pub context_op: T, + pub get_context: T, + pub set_context: T, pub mstore_32bytes: T, pub mload_32bytes: T, pub exit_kernel: T, diff --git a/evm/src/cpu/contextops.rs b/evm/src/cpu/contextops.rs index 55f1482041..1683c30e56 100644 --- a/evm/src/cpu/contextops.rs +++ b/evm/src/cpu/contextops.rs @@ -8,99 +8,38 @@ use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; -use crate::cpu::membus::NUM_GP_CHANNELS; use crate::memory::segments::Segment; fn eval_packed_get( lv: &CpuColumnsView

, + nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - // If the opcode is GET_CONTEXT, then lv.opcode_bits[0] = 0 - let filter = lv.op.context_op * (P::ONES - lv.opcode_bits[0]); - let push_channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; - yield_constr.constraint(filter * (push_channel.value[0] - lv.context)); - for &limb in &push_channel.value[1..] { + let filter = lv.op.get_context; + let new_stack_top = nv.mem_channels[0].value; + yield_constr.constraint(filter * (new_stack_top[0] - lv.context)); + for &limb in &new_stack_top[1..] { yield_constr.constraint(filter * limb); } - - // Stack constraints - let channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; - yield_constr.constraint(filter * (channel.used - P::ONES)); - yield_constr.constraint(filter * channel.is_read); - - yield_constr.constraint(filter * (channel.addr_context - lv.context)); - yield_constr.constraint( - filter * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), - ); - let addr_virtual = lv.stack_len; - yield_constr.constraint(filter * (channel.addr_virtual - addr_virtual)); - - // Unused channels - for i in 0..NUM_GP_CHANNELS - 1 { - let channel = lv.mem_channels[i]; - yield_constr.constraint(filter * channel.used); - } } fn eval_ext_circuit_get, const D: usize>( builder: &mut CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let mut filter = lv.op.context_op; - let one = builder.one_extension(); - let minus = builder.sub_extension(one, lv.opcode_bits[0]); - filter = builder.mul_extension(filter, minus); - - let push_channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; + let filter = lv.op.get_context; + let new_stack_top = nv.mem_channels[0].value; { - let diff = builder.sub_extension(push_channel.value[0], lv.context); + let diff = builder.sub_extension(new_stack_top[0], lv.context); let constr = builder.mul_extension(filter, diff); yield_constr.constraint(builder, constr); } - for &limb in &push_channel.value[1..] { + for &limb in &new_stack_top[1..] 
{ let constr = builder.mul_extension(filter, limb); yield_constr.constraint(builder, constr); } - - // Stack constraints - let channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; - - { - let constr = builder.mul_sub_extension(filter, channel.used, filter); - yield_constr.constraint(builder, constr); - } - { - let constr = builder.mul_extension(filter, channel.is_read); - yield_constr.constraint(builder, constr); - } - - { - let diff = builder.sub_extension(channel.addr_context, lv.context); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); - } - { - let constr = builder.arithmetic_extension( - F::ONE, - -F::from_canonical_u64(Segment::Stack as u64), - filter, - channel.addr_segment, - filter, - ); - yield_constr.constraint(builder, constr); - } - { - let diff = builder.sub_extension(channel.addr_virtual, lv.stack_len); - let constr = builder.arithmetic_extension(F::ONE, F::ZERO, filter, diff, filter); - yield_constr.constraint(builder, constr); - } - - for i in 0..NUM_GP_CHANNELS - 1 { - let channel = lv.mem_channels[i]; - let constr = builder.mul_extension(filter, channel.used); - yield_constr.constraint(builder, constr); - } } fn eval_packed_set( @@ -108,22 +47,16 @@ fn eval_packed_set( nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - let filter = lv.op.context_op * lv.opcode_bits[0]; - let pop_channel = lv.mem_channels[0]; + let filter = lv.op.set_context; + let stack_top = lv.mem_channels[0].value; let write_old_sp_channel = lv.mem_channels[1]; let read_new_sp_channel = lv.mem_channels[2]; - let stack_segment = P::Scalar::from_canonical_u64(Segment::Stack as u64); let ctx_metadata_segment = P::Scalar::from_canonical_u64(Segment::ContextMetadata as u64); let stack_size_field = P::Scalar::from_canonical_u64(ContextMetadata::StackSize as u64); let local_sp_dec = lv.stack_len - P::ONES; - // The next row's context is read from memory channel 0. - yield_constr.constraint(filter * (pop_channel.value[0] - nv.context)); - yield_constr.constraint(filter * (pop_channel.used - P::ONES)); - yield_constr.constraint(filter * (pop_channel.is_read - P::ONES)); - yield_constr.constraint(filter * (pop_channel.addr_context - lv.context)); - yield_constr.constraint(filter * (pop_channel.addr_segment - stack_segment)); - yield_constr.constraint(filter * (pop_channel.addr_virtual - local_sp_dec)); + // The next row's context is read from stack_top. + yield_constr.constraint(filter * (stack_top[0] - nv.context)); // The old SP is decremented (since the new context was popped) and written to memory. yield_constr.constraint(filter * (write_old_sp_channel.value[0] - local_sp_dec)); @@ -144,10 +77,34 @@ fn eval_packed_set( yield_constr.constraint(filter * (read_new_sp_channel.addr_segment - ctx_metadata_segment)); yield_constr.constraint(filter * (read_new_sp_channel.addr_virtual - stack_size_field)); - // Disable unused memory channels - for &channel in &lv.mem_channels[3..] { - yield_constr.constraint(filter * channel.used); + // The next row's stack top is loaded from memory (if the stack isn't empty). 
+ yield_constr.constraint(filter * nv.mem_channels[0].used); + + let read_new_stack_top_channel = lv.mem_channels[3]; + let stack_segment = P::Scalar::from_canonical_u64(Segment::Stack as u64); + let new_filter = filter * nv.stack_len; + + for (limb_channel, limb_top) in read_new_stack_top_channel + .value + .iter() + .zip(nv.mem_channels[0].value) + { + yield_constr.constraint(new_filter * (*limb_channel - limb_top)); } + yield_constr.constraint(new_filter * (read_new_stack_top_channel.used - P::ONES)); + yield_constr.constraint(new_filter * (read_new_stack_top_channel.is_read - P::ONES)); + yield_constr.constraint(new_filter * (read_new_stack_top_channel.addr_context - nv.context)); + yield_constr.constraint(new_filter * (read_new_stack_top_channel.addr_segment - stack_segment)); + yield_constr.constraint( + new_filter * (read_new_stack_top_channel.addr_virtual - (nv.stack_len - P::ONES)), + ); + + // If the new stack is empty, disable the channel read. + yield_constr.constraint( + filter * (nv.stack_len * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux), + ); + let empty_stack_filter = filter * (lv.general.stack().stack_inv_aux - P::ONES); + yield_constr.constraint(empty_stack_filter * read_new_stack_top_channel.used); } fn eval_ext_circuit_set, const D: usize>( @@ -156,13 +113,10 @@ fn eval_ext_circuit_set, const D: usize>( nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let mut filter = lv.op.context_op; - filter = builder.mul_extension(filter, lv.opcode_bits[0]); - let pop_channel = lv.mem_channels[0]; + let filter = lv.op.set_context; + let stack_top = lv.mem_channels[0].value; let write_old_sp_channel = lv.mem_channels[1]; let read_new_sp_channel = lv.mem_channels[2]; - let stack_segment = - builder.constant_extension(F::Extension::from_canonical_u32(Segment::Stack as u32)); let ctx_metadata_segment = builder.constant_extension(F::Extension::from_canonical_u32( Segment::ContextMetadata as u32, )); @@ -172,32 
+126,9 @@ fn eval_ext_circuit_set, const D: usize>( let one = builder.one_extension(); let local_sp_dec = builder.sub_extension(lv.stack_len, one); - // The next row's context is read from memory channel 0. + // The next row's context is read from stack_top. { - let diff = builder.sub_extension(pop_channel.value[0], nv.context); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); - } - { - let constr = builder.mul_sub_extension(filter, pop_channel.used, filter); - yield_constr.constraint(builder, constr); - } - { - let constr = builder.mul_sub_extension(filter, pop_channel.is_read, filter); - yield_constr.constraint(builder, constr); - } - { - let diff = builder.sub_extension(pop_channel.addr_context, lv.context); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); - } - { - let diff = builder.sub_extension(pop_channel.addr_segment, stack_segment); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); - } - { - let diff = builder.sub_extension(pop_channel.addr_virtual, local_sp_dec); + let diff = builder.sub_extension(stack_top[0], nv.context); let constr = builder.mul_extension(filter, diff); yield_constr.constraint(builder, constr); } @@ -266,9 +197,66 @@ fn eval_ext_circuit_set, const D: usize>( yield_constr.constraint(builder, constr); } - // Disable unused memory channels - for &channel in &lv.mem_channels[3..] { - let constr = builder.mul_extension(filter, channel.used); + // The next row's stack top is loaded from memory (if the stack isn't empty). 
+ { + let constr = builder.mul_extension(filter, nv.mem_channels[0].used); + yield_constr.constraint(builder, constr); + } + + let read_new_stack_top_channel = lv.mem_channels[3]; + let stack_segment = + builder.constant_extension(F::Extension::from_canonical_u32(Segment::Stack as u32)); + + let new_filter = builder.mul_extension(filter, nv.stack_len); + + for (limb_channel, limb_top) in read_new_stack_top_channel + .value + .iter() + .zip(nv.mem_channels[0].value) + { + let diff = builder.sub_extension(*limb_channel, limb_top); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint(builder, constr); + } + { + let constr = + builder.mul_sub_extension(new_filter, read_new_stack_top_channel.used, new_filter); + yield_constr.constraint(builder, constr); + } + { + let constr = + builder.mul_sub_extension(new_filter, read_new_stack_top_channel.is_read, new_filter); + yield_constr.constraint(builder, constr); + } + { + let diff = builder.sub_extension(read_new_stack_top_channel.addr_context, nv.context); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint(builder, constr); + } + { + let diff = builder.sub_extension(read_new_stack_top_channel.addr_segment, stack_segment); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint(builder, constr); + } + { + let diff = builder.sub_extension(nv.stack_len, one); + let diff = builder.sub_extension(read_new_stack_top_channel.addr_virtual, diff); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint(builder, constr); + } + + // If the new stack is empty, disable the channel read. 
+ { + let diff = builder.mul_extension(nv.stack_len, lv.general.stack().stack_inv); + let diff = builder.sub_extension(diff, lv.general.stack().stack_inv_aux); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + + { + let empty_stack_filter = + builder.mul_sub_extension(filter, lv.general.stack().stack_inv_aux, filter); + let constr = builder.mul_extension(empty_stack_filter, read_new_stack_top_channel.used); yield_constr.constraint(builder, constr); } } @@ -278,7 +266,7 @@ pub fn eval_packed( nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - eval_packed_get(lv, yield_constr); + eval_packed_get(lv, nv, yield_constr); eval_packed_set(lv, nv, yield_constr); } @@ -288,6 +276,6 @@ pub fn eval_ext_circuit, const D: usize>( nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - eval_ext_circuit_get(builder, lv, yield_constr); + eval_ext_circuit_get(builder, lv, nv, yield_constr); eval_ext_circuit_set(builder, lv, nv, yield_constr); } diff --git a/evm/src/cpu/control_flow.rs b/evm/src/cpu/control_flow.rs index 9c17367aa2..a192ffb13f 100644 --- a/evm/src/cpu/control_flow.rs +++ b/evm/src/cpu/control_flow.rs @@ -8,7 +8,7 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::{CpuColumnsView, COL_MAP}; use crate::cpu::kernel::aggregator::KERNEL; -const NATIVE_INSTRUCTIONS: [usize; 17] = [ +const NATIVE_INSTRUCTIONS: [usize; 18] = [ COL_MAP.op.binary_op, COL_MAP.op.ternary_op, COL_MAP.op.fp254_op, @@ -19,15 +19,15 @@ const NATIVE_INSTRUCTIONS: [usize; 17] = [ COL_MAP.op.keccak_general, COL_MAP.op.prover_input, COL_MAP.op.pop, - // not JUMP (need to jump) - // not JUMPI (possible need to jump) + // not JUMPS (possible need to jump) COL_MAP.op.pc, COL_MAP.op.jumpdest, COL_MAP.op.push0, // not PUSH (need to increment by more than 1) COL_MAP.op.dup, COL_MAP.op.swap, - COL_MAP.op.context_op, + COL_MAP.op.get_context, + COL_MAP.op.set_context, // not EXIT_KERNEL (performs a jump) COL_MAP.op.m_op_general, // not SYSCALL (performs a jump) diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index a77adbcbeb..64a2db9c36 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -224,15 +224,15 @@ impl, const D: usize> Stark for CpuStark, const D: usize> Stark for CpuStark, const D: usize>( fn eval_packed_dup( n: P, lv: &CpuColumnsView

, + nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { let filter = lv.op.dup; - let in_channel = &lv.mem_channels[0]; - let out_channel = &lv.mem_channels[NUM_GP_CHANNELS - 1]; + let write_channel = &lv.mem_channels[1]; + let read_channel = &lv.mem_channels[2]; - channels_equal_packed(filter, in_channel, out_channel, yield_constr); + channels_equal_packed(filter, write_channel, &lv.mem_channels[0], yield_constr); + constrain_channel_packed(false, filter, P::ZEROS, write_channel, lv, yield_constr); - constrain_channel_packed(true, filter, n, in_channel, lv, yield_constr); - constrain_channel_packed( - false, - filter, - P::Scalar::NEG_ONE.into(), - out_channel, - lv, - yield_constr, - ); + channels_equal_packed(filter, read_channel, &nv.mem_channels[0], yield_constr); + constrain_channel_packed(true, filter, n, read_channel, lv, yield_constr); + + // Constrain nv.stack_len. + yield_constr.constraint_transition(filter * (nv.stack_len - lv.stack_len - P::ONES)); + + // TODO: Constrain unused channels? } fn eval_ext_circuit_dup, const D: usize>( builder: &mut CircuitBuilder, n: ExtensionTarget, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let neg_one = builder.constant_extension(F::NEG_ONE.into()); + let zero = builder.zero_extension(); let filter = lv.op.dup; - let in_channel = &lv.mem_channels[0]; - let out_channel = &lv.mem_channels[NUM_GP_CHANNELS - 1]; + let write_channel = &lv.mem_channels[1]; + let read_channel = &lv.mem_channels[2]; - channels_equal_ext_circuit(builder, filter, in_channel, out_channel, yield_constr); - - constrain_channel_ext_circuit(builder, true, filter, n, in_channel, lv, yield_constr); + channels_equal_ext_circuit( + builder, + filter, + write_channel, + &lv.mem_channels[0], + yield_constr, + ); constrain_channel_ext_circuit( builder, false, filter, - neg_one, - out_channel, + zero, + write_channel, lv, yield_constr, ); + + channels_equal_ext_circuit( + builder, + filter, + read_channel, + &nv.mem_channels[0], + yield_constr, + ); + 
constrain_channel_ext_circuit(builder, true, filter, n, read_channel, lv, yield_constr); + + // Constrain nv.stack_len. + let diff = builder.sub_extension(nv.stack_len, lv.stack_len); + let constr = builder.mul_sub_extension(filter, diff, filter); + yield_constr.constraint_transition(builder, constr); + + // TODO: Constrain unused channels? } fn eval_packed_swap( n: P, lv: &CpuColumnsView

, + nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { let n_plus_one = n + P::ONES; @@ -170,25 +191,27 @@ fn eval_packed_swap( let in1_channel = &lv.mem_channels[0]; let in2_channel = &lv.mem_channels[1]; - let out1_channel = &lv.mem_channels[NUM_GP_CHANNELS - 2]; - let out2_channel = &lv.mem_channels[NUM_GP_CHANNELS - 1]; + let out_channel = &lv.mem_channels[2]; - channels_equal_packed(filter, in1_channel, out1_channel, yield_constr); - channels_equal_packed(filter, in2_channel, out2_channel, yield_constr); + channels_equal_packed(filter, in1_channel, out_channel, yield_constr); + constrain_channel_packed(false, filter, n_plus_one, out_channel, lv, yield_constr); - constrain_channel_packed(true, filter, P::ZEROS, in1_channel, lv, yield_constr); + channels_equal_packed(filter, in2_channel, &nv.mem_channels[0], yield_constr); constrain_channel_packed(true, filter, n_plus_one, in2_channel, lv, yield_constr); - constrain_channel_packed(false, filter, n_plus_one, out1_channel, lv, yield_constr); - constrain_channel_packed(false, filter, P::ZEROS, out2_channel, lv, yield_constr); + + // Constrain nv.stack_len; + yield_constr.constraint(filter * (nv.stack_len - lv.stack_len)); + + // TODO: Constrain unused channels? 
} fn eval_ext_circuit_swap, const D: usize>( builder: &mut CircuitBuilder, n: ExtensionTarget, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let zero = builder.zero_extension(); let one = builder.one_extension(); let n_plus_one = builder.add_extension(n, one); @@ -196,36 +219,47 @@ fn eval_ext_circuit_swap, const D: usize>( let in1_channel = &lv.mem_channels[0]; let in2_channel = &lv.mem_channels[1]; - let out1_channel = &lv.mem_channels[NUM_GP_CHANNELS - 2]; - let out2_channel = &lv.mem_channels[NUM_GP_CHANNELS - 1]; - - channels_equal_ext_circuit(builder, filter, in1_channel, out1_channel, yield_constr); - channels_equal_ext_circuit(builder, filter, in2_channel, out2_channel, yield_constr); + let out_channel = &lv.mem_channels[2]; - constrain_channel_ext_circuit(builder, true, filter, zero, in1_channel, lv, yield_constr); + channels_equal_ext_circuit(builder, filter, in1_channel, out_channel, yield_constr); constrain_channel_ext_circuit( builder, - true, + false, filter, n_plus_one, - in2_channel, + out_channel, lv, yield_constr, ); + + channels_equal_ext_circuit( + builder, + filter, + in2_channel, + &nv.mem_channels[0], + yield_constr, + ); constrain_channel_ext_circuit( builder, - false, + true, filter, n_plus_one, - out1_channel, + in2_channel, lv, yield_constr, ); - constrain_channel_ext_circuit(builder, false, filter, zero, out2_channel, lv, yield_constr); + + // Constrain nv.stack_len. + let diff = builder.sub_extension(nv.stack_len, lv.stack_len); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + + // TODO: Constrain unused channels? } pub fn eval_packed( lv: &CpuColumnsView

, + nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { let n = lv.opcode_bits[0] @@ -233,13 +267,14 @@ pub fn eval_packed( + lv.opcode_bits[2] * P::Scalar::from_canonical_u64(4) + lv.opcode_bits[3] * P::Scalar::from_canonical_u64(8); - eval_packed_dup(n, lv, yield_constr); - eval_packed_swap(n, lv, yield_constr); + eval_packed_dup(n, lv, nv, yield_constr); + eval_packed_swap(n, lv, nv, yield_constr); } pub fn eval_ext_circuit, const D: usize>( builder: &mut CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { let n = lv.opcode_bits[..4].iter().enumerate().fold( @@ -249,6 +284,6 @@ pub fn eval_ext_circuit, const D: usize>( }, ); - eval_ext_circuit_dup(builder, n, lv, yield_constr); - eval_ext_circuit_swap(builder, n, lv, yield_constr); + eval_ext_circuit_dup(builder, n, lv, nv, yield_constr); + eval_ext_circuit_swap(builder, n, lv, nv, yield_constr); } diff --git a/evm/src/cpu/gas.rs b/evm/src/cpu/gas.rs index 694fb0f47e..1434efd93d 100644 --- a/evm/src/cpu/gas.rs +++ b/evm/src/cpu/gas.rs @@ -36,7 +36,8 @@ const SIMPLE_OPCODES: OpsColumnsView> = OpsColumnsView { push: G_VERYLOW, dup: G_VERYLOW, swap: G_VERYLOW, - context_op: KERNEL_ONLY_INSTR, + get_context: KERNEL_ONLY_INSTR, + set_context: KERNEL_ONLY_INSTR, mstore_32bytes: KERNEL_ONLY_INSTR, mload_32bytes: KERNEL_ONLY_INSTR, exit_kernel: None, diff --git a/evm/src/cpu/jumps.rs b/evm/src/cpu/jumps.rs index 1829177384..0c03e2d178 100644 --- a/evm/src/cpu/jumps.rs +++ b/evm/src/cpu/jumps.rs @@ -7,7 +7,6 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; use crate::cpu::membus::NUM_GP_CHANNELS; -use crate::cpu::stack; use crate::memory::segments::Segment; pub fn eval_packed_exit_kernel( @@ -74,8 +73,26 @@ pub fn eval_packed_jump_jumpi( let is_jumpi = filter * lv.opcode_bits[0]; // Stack constraints. 
- stack::eval_packed_one(lv, nv, is_jump, stack::JUMP_OP.unwrap(), yield_constr); - stack::eval_packed_one(lv, nv, is_jumpi, stack::JUMPI_OP.unwrap(), yield_constr); + // If (JUMP and stack_len != 1) or (JUMPI and stack_len != 2)... + let len_diff = lv.stack_len - P::ONES - lv.opcode_bits[0]; + let new_filter = len_diff * filter; + // Read an extra element. + let channel = nv.mem_channels[0]; + yield_constr.constraint_transition(new_filter * (channel.used - P::ONES)); + yield_constr.constraint_transition(new_filter * (channel.is_read - P::ONES)); + yield_constr.constraint_transition(new_filter * (channel.addr_context - nv.context)); + yield_constr.constraint_transition( + new_filter * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), + ); + let addr_virtual = nv.stack_len - P::ONES; + yield_constr.constraint_transition(new_filter * (channel.addr_virtual - addr_virtual)); + // Constrain `stack_inv_aux`. + yield_constr.constraint( + filter * (len_diff * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux), + ); + // Disable channel if stack_len == N. + let empty_stack_filter = filter * (lv.general.stack().stack_inv_aux - P::ONES); + yield_constr.constraint_transition(empty_stack_filter * channel.used); // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1. // In other words, we implement `JUMP(dst)` as `JUMPI(dst, cond=1)`. @@ -123,6 +140,12 @@ pub fn eval_packed_jump_jumpi( // Channel 1 is unused by the `JUMP` instruction. yield_constr.constraint(is_jump * lv.mem_channels[1].used); + // Update stack length. + yield_constr.constraint_transition(is_jump * (nv.stack_len - lv.stack_len + P::ONES)); + yield_constr.constraint_transition( + is_jumpi * (nv.stack_len - lv.stack_len + P::Scalar::from_canonical_u64(2)), + ); + // Finally, set the next program counter. 
let fallthrough_dst = lv.program_counter + P::ONES; let jump_dest = dst[0]; @@ -150,22 +173,55 @@ pub fn eval_ext_circuit_jump_jumpi, const D: usize> let is_jumpi = builder.mul_extension(filter, lv.opcode_bits[0]); // Stack constraints. - stack::eval_ext_circuit_one( - builder, - lv, - nv, - is_jump, - stack::JUMP_OP.unwrap(), - yield_constr, - ); - stack::eval_ext_circuit_one( - builder, - lv, - nv, - is_jumpi, - stack::JUMPI_OP.unwrap(), - yield_constr, - ); + // If (JUMP and stack_len != 1) or (JUMPI and stack_len != 2)... + let len_diff = builder.sub_extension(lv.stack_len, one_extension); + let len_diff = builder.sub_extension(len_diff, lv.opcode_bits[0]); + let new_filter = builder.mul_extension(len_diff, filter); + // Read an extra element. + let channel = nv.mem_channels[0]; + + { + let constr = builder.mul_sub_extension(new_filter, channel.used, new_filter); + yield_constr.constraint_transition(builder, constr); + } + { + let constr = builder.mul_sub_extension(new_filter, channel.is_read, new_filter); + yield_constr.constraint_transition(builder, constr); + } + { + let diff = builder.sub_extension(channel.addr_context, nv.context); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint_transition(builder, constr); + } + { + let constr = builder.arithmetic_extension( + F::ONE, + -F::from_canonical_u64(Segment::Stack as u64), + new_filter, + channel.addr_segment, + new_filter, + ); + yield_constr.constraint_transition(builder, constr); + } + { + let diff = builder.sub_extension(channel.addr_virtual, nv.stack_len); + let constr = builder.arithmetic_extension(F::ONE, F::ONE, new_filter, diff, new_filter); + yield_constr.constraint_transition(builder, constr); + } + // Constrain `stack_inv_aux`. 
+ { + let prod = builder.mul_extension(len_diff, lv.general.stack().stack_inv); + let diff = builder.sub_extension(prod, lv.general.stack().stack_inv_aux); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + // Disable channel if stack_len == N. + { + let empty_stack_filter = + builder.mul_sub_extension(filter, lv.general.stack().stack_inv_aux, filter); + let constr = builder.mul_extension(empty_stack_filter, channel.used); + yield_constr.constraint_transition(builder, constr); + } // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1. // In other words, we implement `JUMP(dst)` as `JUMPI(dst, cond=1)`. @@ -267,6 +323,19 @@ pub fn eval_ext_circuit_jump_jumpi, const D: usize> yield_constr.constraint(builder, constr); } + // Update stack length. + { + let diff = builder.sub_extension(nv.stack_len, lv.stack_len); + let constr = builder.mul_add_extension(is_jump, diff, is_jump); + yield_constr.constraint_transition(builder, constr); + } + { + let diff = builder.sub_extension(nv.stack_len, lv.stack_len); + let diff = builder.add_const_extension(diff, F::TWO); + let constr = builder.mul_extension(is_jumpi, diff); + yield_constr.constraint_transition(builder, constr); + } + // Finally, set the next program counter. 
let fallthrough_dst = builder.add_const_extension(lv.program_counter, F::ONE); let jump_dest = dst[0]; diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 8f19a0728f..315e93f16d 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -126,8 +126,14 @@ impl<'a> Interpreter<'a> { opcode_count: [0; 0x100], }; result.generation_state.registers.program_counter = initial_offset; - result.generation_state.registers.stack_len = initial_stack.len(); - *result.stack_mut() = initial_stack; + let initial_stack_len = initial_stack.len(); + result.generation_state.registers.stack_len = initial_stack_len; + if !initial_stack.is_empty() { + result.generation_state.registers.stack_top = initial_stack[initial_stack_len - 1]; + *result.stack_segment_mut() = initial_stack; + result.stack_segment_mut().truncate(initial_stack_len - 1); + } + result } @@ -262,12 +268,18 @@ impl<'a> Interpreter<'a> { self.generation_state.registers.program_counter += n; } - pub(crate) fn stack(&self) -> &[U256] { - &self.generation_state.memory.contexts[self.context].segments[Segment::Stack as usize] + pub(crate) fn stack(&self) -> Vec { + let mut stack = self.generation_state.memory.contexts[self.context].segments + [Segment::Stack as usize] .content + .clone(); + if self.stack_len() > 0 { + stack.push(self.stack_top()); + } + stack } - fn stack_mut(&mut self) -> &mut Vec { + fn stack_segment_mut(&mut self) -> &mut Vec { &mut self.generation_state.memory.contexts[self.context].segments[Segment::Stack as usize] .content } @@ -285,7 +297,11 @@ impl<'a> Interpreter<'a> { } pub(crate) fn push(&mut self, x: U256) { - self.stack_mut().push(x); + if self.stack_len() > 0 { + let top = self.stack_top(); + self.stack_segment_mut().push(top); + } + self.generation_state.registers.stack_top = x; self.generation_state.registers.stack_len += 1; } @@ -295,9 +311,17 @@ impl<'a> Interpreter<'a> { pub(crate) fn pop(&mut self) -> U256 { let result = 
stack_peek(&self.generation_state, 0); + if self.stack_len() > 1 { + let top = stack_peek(&self.generation_state, 1).unwrap(); + self.generation_state.registers.stack_top = top; + } self.generation_state.registers.stack_len -= 1; let new_len = self.stack_len(); - self.stack_mut().truncate(new_len); + if new_len > 0 { + self.stack_segment_mut().truncate(new_len - 1); + } else { + self.stack_segment_mut().truncate(0); + } result.expect("Empty stack") } @@ -1007,13 +1031,19 @@ impl<'a> Interpreter<'a> { } fn run_dup(&mut self, n: u8) { - self.push(self.stack()[self.stack_len() - n as usize]); + if n == 0 { + self.push(self.stack_top()); + } else { + self.push(stack_peek(&self.generation_state, n as usize - 1).unwrap()); + } } fn run_swap(&mut self, n: u8) -> anyhow::Result<()> { let len = self.stack_len(); ensure!(len > n as usize); - self.stack_mut().swap(len - 1, len - n as usize - 1); + let to_swap = stack_peek(&self.generation_state, n as usize).unwrap(); + self.stack_segment_mut()[len - n as usize - 1] = self.stack_top(); + self.generation_state.registers.stack_top = to_swap; Ok(()) } @@ -1084,9 +1114,13 @@ impl<'a> Interpreter<'a> { } } - fn stack_len(&self) -> usize { + pub(crate) fn stack_len(&self) -> usize { self.generation_state.registers.stack_len } + + pub(crate) fn stack_top(&self) -> U256 { + self.generation_state.registers.stack_top + } } // Computes the two's complement of the given integer. 
diff --git a/evm/src/cpu/kernel/tests/signed_syscalls.rs b/evm/src/cpu/kernel/tests/signed_syscalls.rs index 728d5565f7..93391cf635 100644 --- a/evm/src/cpu/kernel/tests/signed_syscalls.rs +++ b/evm/src/cpu/kernel/tests/signed_syscalls.rs @@ -119,8 +119,8 @@ fn run_test(fn_label: &str, expected_fn: fn(U256, U256) -> U256, opname: &str) { let stack = vec![retdest, y, x]; let mut interpreter = Interpreter::new_with_kernel(fn_label, stack); interpreter.run().unwrap(); - assert_eq!(interpreter.stack().len(), 1usize, "unexpected stack size"); - let output = interpreter.stack()[0]; + assert_eq!(interpreter.stack_len(), 1usize, "unexpected stack size"); + let output = interpreter.stack_top(); let expected_output = expected_fn(x, y); assert_eq!( output, expected_output, diff --git a/evm/src/cpu/memio.rs b/evm/src/cpu/memio.rs index aa3749cab2..f70f3fdb67 100644 --- a/evm/src/cpu/memio.rs +++ b/evm/src/cpu/memio.rs @@ -1,6 +1,7 @@ use itertools::izip; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; +use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; @@ -8,6 +9,7 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::CpuColumnsView; use crate::cpu::membus::NUM_GP_CHANNELS; use crate::cpu::stack; +use crate::memory::segments::Segment; fn get_addr(lv: &CpuColumnsView) -> (T, T, T) { let addr_context = lv.mem_channels[0].value[0]; @@ -27,18 +29,14 @@ fn eval_packed_load( let (addr_context, addr_segment, addr_virtual) = get_addr(lv); let load_channel = lv.mem_channels[3]; - let push_channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; yield_constr.constraint(filter * (load_channel.used - P::ONES)); yield_constr.constraint(filter * (load_channel.is_read - P::ONES)); yield_constr.constraint(filter * (load_channel.addr_context - addr_context)); yield_constr.constraint(filter * (load_channel.addr_segment - addr_segment)); 
yield_constr.constraint(filter * (load_channel.addr_virtual - addr_virtual)); - for (load_limb, push_limb) in izip!(load_channel.value, push_channel.value) { - yield_constr.constraint(filter * (load_limb - push_limb)); - } // Disable remaining memory channels, if any. - for &channel in &lv.mem_channels[4..NUM_GP_CHANNELS - 1] { + for &channel in &lv.mem_channels[4..NUM_GP_CHANNELS] { yield_constr.constraint(filter * channel.used); } @@ -64,7 +62,6 @@ fn eval_ext_circuit_load, const D: usize>( let (addr_context, addr_segment, addr_virtual) = get_addr(lv); let load_channel = lv.mem_channels[3]; - let push_channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; { let constr = builder.mul_sub_extension(filter, load_channel.used, filter); yield_constr.constraint(builder, constr); @@ -85,14 +82,9 @@ fn eval_ext_circuit_load, const D: usize>( let constr = builder.mul_extension(filter, diff); yield_constr.constraint(builder, constr); } - for (load_limb, push_limb) in izip!(load_channel.value, push_channel.value) { - let diff = builder.sub_extension(load_limb, push_limb); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); - } // Disable remaining memory channels, if any. - for &channel in &lv.mem_channels[4..NUM_GP_CHANNELS - 1] { + for &channel in &lv.mem_channels[4..NUM_GP_CHANNELS] { let constr = builder.mul_extension(filter, channel.used); yield_constr.constraint(builder, constr); } @@ -113,7 +105,7 @@ fn eval_packed_store( nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - let filter = lv.op.m_op_general * (P::ONES - lv.opcode_bits[0]); + let filter = lv.op.m_op_general * (lv.opcode_bits[0] - P::ONES); let (addr_context, addr_segment, addr_virtual) = get_addr(lv); @@ -133,14 +125,50 @@ fn eval_packed_store( yield_constr.constraint(filter * channel.used); } - // Stack constraints - stack::eval_packed_one( - lv, - nv, - filter, - stack::MSTORE_GENERAL_OP.unwrap(), - yield_constr, + // Stack constraints. + // Pops. + for i in 1..4 { + let channel = lv.mem_channels[i]; + + yield_constr.constraint(filter * (channel.used - P::ONES)); + yield_constr.constraint(filter * (channel.is_read - P::ONES)); + + yield_constr.constraint(filter * (channel.addr_context - lv.context)); + yield_constr.constraint( + filter * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), + ); + // Remember that the first read (`i == 1`) is for the second stack element at `stack[stack_len - 1]`. + let addr_virtual = lv.stack_len - P::Scalar::from_canonical_usize(i + 1); + yield_constr.constraint(filter * (channel.addr_virtual - addr_virtual)); + } + // Constrain `stack_inv_aux`. + let len_diff = lv.stack_len - P::Scalar::from_canonical_usize(4); + yield_constr.constraint( + lv.op.m_op_general + * (len_diff * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux), ); + // If stack_len != 4 and MSTORE, read new top of the stack in nv.mem_channels[0]. + let top_read_channel = nv.mem_channels[0]; + let is_top_read = lv.general.stack().stack_inv_aux * (P::ONES - lv.opcode_bits[0]); + // Constrain `stack_inv_aux_2`. It contains `stack_inv_aux * opcode_bits[0]`. 
+ yield_constr + .constraint(lv.op.m_op_general * (lv.general.stack().stack_inv_aux_2 - is_top_read)); + let new_filter = lv.op.m_op_general * lv.general.stack().stack_inv_aux_2; + yield_constr.constraint_transition(new_filter * (top_read_channel.used - P::ONES)); + yield_constr.constraint_transition(new_filter * (top_read_channel.is_read - P::ONES)); + yield_constr.constraint_transition(new_filter * (top_read_channel.addr_context - nv.context)); + yield_constr.constraint_transition( + new_filter + * (top_read_channel.addr_segment + - P::Scalar::from_canonical_u64(Segment::Stack as u64)), + ); + let addr_virtual = nv.stack_len - P::ONES; + yield_constr.constraint_transition(new_filter * (top_read_channel.addr_virtual - addr_virtual)); + // If stack_len == 4 or MLOAD, disable the channel. + yield_constr.constraint( + lv.op.m_op_general * (lv.general.stack().stack_inv_aux - P::ONES) * top_read_channel.used, + ); + yield_constr.constraint(lv.op.m_op_general * lv.opcode_bits[0] * top_read_channel.used); } fn eval_ext_circuit_store, const D: usize>( @@ -149,10 +177,8 @@ fn eval_ext_circuit_store, const D: usize>( nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let mut filter = lv.op.m_op_general; - let one = builder.one_extension(); - let minus = builder.sub_extension(one, lv.opcode_bits[0]); - filter = builder.mul_extension(filter, minus); + let filter = + builder.mul_sub_extension(lv.op.m_op_general, lv.opcode_bits[0], lv.op.m_op_general); let (addr_context, addr_segment, addr_virtual) = get_addr(lv); @@ -191,14 +217,102 @@ fn eval_ext_circuit_store, const D: usize>( } // Stack constraints - stack::eval_ext_circuit_one( - builder, - lv, - nv, - filter, - stack::MSTORE_GENERAL_OP.unwrap(), - yield_constr, - ); + // Pops. 
+ for i in 1..4 { + let channel = lv.mem_channels[i]; + + { + let constr = builder.mul_sub_extension(filter, channel.used, filter); + yield_constr.constraint(builder, constr); + } + { + let constr = builder.mul_sub_extension(filter, channel.is_read, filter); + yield_constr.constraint(builder, constr); + } + { + let diff = builder.sub_extension(channel.addr_context, lv.context); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + { + let diff = builder.add_const_extension( + channel.addr_segment, + -F::from_canonical_u64(Segment::Stack as u64), + ); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + // Remember that the first read (`i == 1`) is for the second stack element at `stack[stack_len - 1]`. + let addr_virtual = + builder.add_const_extension(lv.stack_len, -F::from_canonical_usize(i + 1)); + let diff = builder.sub_extension(channel.addr_virtual, addr_virtual); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + // Constrain `stack_inv_aux`. + { + let len_diff = builder.add_const_extension(lv.stack_len, -F::from_canonical_usize(4)); + let diff = builder.mul_sub_extension( + len_diff, + lv.general.stack().stack_inv, + lv.general.stack().stack_inv_aux, + ); + let constr = builder.mul_extension(lv.op.m_op_general, diff); + yield_constr.constraint(builder, constr); + } + // If stack_len != 4 and MSTORE, read new top of the stack in nv.mem_channels[0]. + let top_read_channel = nv.mem_channels[0]; + let is_top_read = builder.mul_extension(lv.general.stack().stack_inv_aux, lv.opcode_bits[0]); + let is_top_read = builder.sub_extension(lv.general.stack().stack_inv_aux, is_top_read); + // Constrain `stack_inv_aux_2`. It contains `stack_inv_aux * opcode_bits[0]`. 
+ { + let diff = builder.sub_extension(lv.general.stack().stack_inv_aux_2, is_top_read); + let constr = builder.mul_extension(lv.op.m_op_general, diff); + yield_constr.constraint(builder, constr); + } + let new_filter = builder.mul_extension(lv.op.m_op_general, lv.general.stack().stack_inv_aux_2); + { + let constr = builder.mul_sub_extension(new_filter, top_read_channel.used, new_filter); + yield_constr.constraint_transition(builder, constr); + } + { + let constr = builder.mul_sub_extension(new_filter, top_read_channel.is_read, new_filter); + yield_constr.constraint_transition(builder, constr); + } + { + let diff = builder.sub_extension(top_read_channel.addr_context, nv.context); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint_transition(builder, constr); + } + { + let diff = builder.add_const_extension( + top_read_channel.addr_segment, + -F::from_canonical_u64(Segment::Stack as u64), + ); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint_transition(builder, constr); + } + { + let addr_virtual = builder.add_const_extension(nv.stack_len, -F::ONE); + let diff = builder.sub_extension(top_read_channel.addr_virtual, addr_virtual); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint_transition(builder, constr); + } + // If stack_len == 4 or MLOAD, disable the channel. 
+ { + let diff = builder.mul_sub_extension( + lv.op.m_op_general, + lv.general.stack().stack_inv_aux, + lv.op.m_op_general, + ); + let constr = builder.mul_extension(diff, top_read_channel.used); + yield_constr.constraint(builder, constr); + } + { + let mul = builder.mul_extension(lv.op.m_op_general, lv.opcode_bits[0]); + let constr = builder.mul_extension(mul, top_read_channel.used); + yield_constr.constraint(builder, constr); + } } pub fn eval_packed( diff --git a/evm/src/cpu/mod.rs b/evm/src/cpu/mod.rs index b7312147b4..0885f644bb 100644 --- a/evm/src/cpu/mod.rs +++ b/evm/src/cpu/mod.rs @@ -16,6 +16,6 @@ mod pc; mod push0; mod shift; pub(crate) mod simple_logic; -mod stack; +pub(crate) mod stack; pub(crate) mod stack_bounds; mod syscalls_exceptions; diff --git a/evm/src/cpu/modfp254.rs b/evm/src/cpu/modfp254.rs index 86f08052ef..eed497f5d3 100644 --- a/evm/src/cpu/modfp254.rs +++ b/evm/src/cpu/modfp254.rs @@ -22,7 +22,7 @@ pub fn eval_packed( let filter = lv.op.fp254_op; // We want to use all the same logic as the usual mod operations, but without needing to read - // the modulus from the stack. We simply constrain `mem_channels[2]` to be our prime (that's + // the modulus from the stack. We simply constrain `mem_channels[1]` to be our prime (that's // where the modulus goes in the generalized operations). let channel_val = lv.mem_channels[2].value; for (channel_limb, p_limb) in izip!(channel_val, P_LIMBS) { @@ -39,7 +39,7 @@ pub fn eval_ext_circuit, const D: usize>( let filter = lv.op.fp254_op; // We want to use all the same logic as the usual mod operations, but without needing to read - // the modulus from the stack. We simply constrain `mem_channels[2]` to be our prime (that's + // the modulus from the stack. We simply constrain `mem_channels[1]` to be our prime (that's // where the modulus goes in the generalized operations). 
let channel_val = lv.mem_channels[2].value; for (channel_limb, p_limb) in izip!(channel_val, P_LIMBS) { diff --git a/evm/src/cpu/pc.rs b/evm/src/cpu/pc.rs index 26731c92c2..5271ad81aa 100644 --- a/evm/src/cpu/pc.rs +++ b/evm/src/cpu/pc.rs @@ -5,16 +5,16 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; -use crate::cpu::membus::NUM_GP_CHANNELS; pub fn eval_packed( lv: &CpuColumnsView

, + nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { let filter = lv.op.pc; - let push_value = lv.mem_channels[NUM_GP_CHANNELS - 1].value; - yield_constr.constraint(filter * (push_value[0] - lv.program_counter)); - for &limb in &push_value[1..] { + let new_stack_top = nv.mem_channels[0].value; + yield_constr.constraint(filter * (new_stack_top[0] - lv.program_counter)); + for &limb in &new_stack_top[1..] { yield_constr.constraint(filter * limb); } } @@ -22,16 +22,17 @@ pub fn eval_packed( pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { let filter = lv.op.pc; - let push_value = lv.mem_channels[NUM_GP_CHANNELS - 1].value; + let new_stack_top = nv.mem_channels[0].value; { - let diff = builder.sub_extension(push_value[0], lv.program_counter); + let diff = builder.sub_extension(new_stack_top[0], lv.program_counter); let constr = builder.mul_extension(filter, diff); yield_constr.constraint(builder, constr); } - for &limb in &push_value[1..] { + for &limb in &new_stack_top[1..] { let constr = builder.mul_extension(filter, limb); yield_constr.constraint(builder, constr); } diff --git a/evm/src/cpu/push0.rs b/evm/src/cpu/push0.rs index 30f6d0ae0f..d49446cc23 100644 --- a/evm/src/cpu/push0.rs +++ b/evm/src/cpu/push0.rs @@ -5,15 +5,14 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; -use crate::cpu::membus::NUM_GP_CHANNELS; pub fn eval_packed( lv: &CpuColumnsView

, + nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { let filter = lv.op.push0; - let push_value = lv.mem_channels[NUM_GP_CHANNELS - 1].value; - for limb in push_value { + for limb in nv.mem_channels[0].value { yield_constr.constraint(filter * limb); } } @@ -21,11 +20,11 @@ pub fn eval_packed( pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { let filter = lv.op.push0; - let push_value = lv.mem_channels[NUM_GP_CHANNELS - 1].value; - for limb in push_value { + for limb in nv.mem_channels[0].value { let constr = builder.mul_extension(filter, limb); yield_constr.constraint(builder, constr); } diff --git a/evm/src/cpu/stack.rs b/evm/src/cpu/stack.rs index 28abf077cb..31d0405cea 100644 --- a/evm/src/cpu/stack.rs +++ b/evm/src/cpu/stack.rs @@ -1,3 +1,5 @@ +use std::cmp::max; + use itertools::izip; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; @@ -13,46 +15,41 @@ use crate::memory::segments::Segment; #[derive(Clone, Copy)] pub(crate) struct StackBehavior { - num_pops: usize, - pushes: bool, + pub(crate) num_pops: usize, + pub(crate) pushes: bool, + new_top_stack_channel: Option, disable_other_channels: bool, } -const BASIC_UNARY_OP: Option = Some(StackBehavior { - num_pops: 1, - pushes: true, - disable_other_channels: true, -}); const BASIC_BINARY_OP: Option = Some(StackBehavior { num_pops: 2, pushes: true, + new_top_stack_channel: Some(NUM_GP_CHANNELS - 1), disable_other_channels: true, }); const BASIC_TERNARY_OP: Option = Some(StackBehavior { num_pops: 3, pushes: true, + new_top_stack_channel: Some(NUM_GP_CHANNELS - 1), disable_other_channels: true, }); pub(crate) const JUMP_OP: Option = Some(StackBehavior { num_pops: 1, pushes: false, + new_top_stack_channel: None, disable_other_channels: false, }); pub(crate) const JUMPI_OP: Option = Some(StackBehavior { num_pops: 2, pushes: false, + new_top_stack_channel: None, 
disable_other_channels: false, }); pub(crate) const MLOAD_GENERAL_OP: Option = Some(StackBehavior { num_pops: 3, pushes: true, - disable_other_channels: false, -}); - -pub(crate) const MSTORE_GENERAL_OP: Option = Some(StackBehavior { - num_pops: 4, - pushes: false, + new_top_stack_channel: None, disable_other_channels: false, }); @@ -61,79 +58,111 @@ pub(crate) const MSTORE_GENERAL_OP: Option = Some(StackBehavior { // propertly constrained. The same applies when `disable_other_channels` is set to `false`, // except the first `num_pops` and the last `pushes as usize` channels have their read flag and // address constrained automatically in this file. -const STACK_BEHAVIORS: OpsColumnsView> = OpsColumnsView { +pub(crate) const STACK_BEHAVIORS: OpsColumnsView> = OpsColumnsView { binary_op: BASIC_BINARY_OP, ternary_op: BASIC_TERNARY_OP, fp254_op: BASIC_BINARY_OP, eq_iszero: None, // EQ is binary, IS_ZERO is unary. logic_op: BASIC_BINARY_OP, - not: BASIC_UNARY_OP, + not: Some(StackBehavior { + num_pops: 1, + pushes: true, + new_top_stack_channel: Some(NUM_GP_CHANNELS - 1), + disable_other_channels: true, + }), shift: Some(StackBehavior { num_pops: 2, pushes: true, + new_top_stack_channel: Some(NUM_GP_CHANNELS - 1), disable_other_channels: false, }), keccak_general: Some(StackBehavior { num_pops: 4, pushes: true, + new_top_stack_channel: Some(NUM_GP_CHANNELS - 1), disable_other_channels: true, }), prover_input: None, // TODO pop: Some(StackBehavior { num_pops: 1, pushes: false, + new_top_stack_channel: None, disable_other_channels: true, }), jumps: None, // Depends on whether it's a JUMP or a JUMPI. 
pc: Some(StackBehavior { num_pops: 0, pushes: true, + new_top_stack_channel: None, disable_other_channels: true, }), jumpdest: Some(StackBehavior { num_pops: 0, pushes: false, + new_top_stack_channel: None, disable_other_channels: true, }), push0: Some(StackBehavior { num_pops: 0, pushes: true, + new_top_stack_channel: None, disable_other_channels: true, }), push: None, // TODO dup: None, swap: None, - context_op: None, // SET_CONTEXT is special since it involves the old and the new stack. - mstore_32bytes: Some(StackBehavior { - num_pops: 5, - pushes: false, - disable_other_channels: false, + get_context: Some(StackBehavior { + num_pops: 0, + pushes: true, + new_top_stack_channel: None, + disable_other_channels: true, }), + set_context: None, // SET_CONTEXT is special since it involves the old and the new stack. mload_32bytes: Some(StackBehavior { num_pops: 4, pushes: true, + new_top_stack_channel: Some(4), + disable_other_channels: false, + }), + mstore_32bytes: Some(StackBehavior { + num_pops: 5, + pushes: false, + new_top_stack_channel: None, disable_other_channels: false, }), exit_kernel: Some(StackBehavior { num_pops: 1, pushes: false, + new_top_stack_channel: None, disable_other_channels: true, }), m_op_general: None, syscall: Some(StackBehavior { num_pops: 0, pushes: true, + new_top_stack_channel: None, disable_other_channels: false, }), exception: Some(StackBehavior { num_pops: 0, pushes: true, + new_top_stack_channel: None, disable_other_channels: false, }), }; -pub(crate) const EQ_STACK_BEHAVIOR: Option = BASIC_BINARY_OP; -pub(crate) const IS_ZERO_STACK_BEHAVIOR: Option = BASIC_UNARY_OP; +pub(crate) const EQ_STACK_BEHAVIOR: Option = Some(StackBehavior { + num_pops: 2, + pushes: true, + new_top_stack_channel: Some(2), + disable_other_channels: true, +}); +pub(crate) const IS_ZERO_STACK_BEHAVIOR: Option = Some(StackBehavior { + num_pops: 1, + pushes: true, + new_top_stack_channel: Some(2), + disable_other_channels: true, +}); pub(crate) fn eval_packed_one( 
lv: &CpuColumnsView

, @@ -142,43 +171,109 @@ pub(crate) fn eval_packed_one( stack_behavior: StackBehavior, yield_constr: &mut ConstraintConsumer

, ) { - let num_operands = stack_behavior.num_pops + (stack_behavior.pushes as usize); - assert!(num_operands <= NUM_GP_CHANNELS); + // If you have pops. + if stack_behavior.num_pops > 0 { + for i in 1..stack_behavior.num_pops { + let channel = lv.mem_channels[i]; - // Pops - for i in 0..stack_behavior.num_pops { - let channel = lv.mem_channels[i]; + yield_constr.constraint(filter * (channel.used - P::ONES)); + yield_constr.constraint(filter * (channel.is_read - P::ONES)); - yield_constr.constraint(filter * (channel.used - P::ONES)); - yield_constr.constraint(filter * (channel.is_read - P::ONES)); + yield_constr.constraint(filter * (channel.addr_context - lv.context)); + yield_constr.constraint( + filter + * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), + ); + // Remember that the first read (`i == 1`) is for the second stack element at `stack[stack_len - 1]`. + let addr_virtual = lv.stack_len - P::Scalar::from_canonical_usize(i + 1); + yield_constr.constraint(filter * (channel.addr_virtual - addr_virtual)); + } - yield_constr.constraint(filter * (channel.addr_context - lv.context)); - yield_constr.constraint( - filter * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), - ); - // E.g. if `stack_len == 1` and `i == 0`, we want `add_virtual == 0`. - let addr_virtual = lv.stack_len - P::Scalar::from_canonical_usize(i + 1); - yield_constr.constraint(filter * (channel.addr_virtual - addr_virtual)); + // If you also push, you don't need to read the new top of the stack. + // If you don't: + // - if the stack isn't empty after the pops, you read the new top from an extra pop. + // - if not, the extra read is disabled. + // These are transition constraints: they don't apply to the last row. + if !stack_behavior.pushes { + // If stack_len != N... + let len_diff = lv.stack_len - P::Scalar::from_canonical_usize(stack_behavior.num_pops); + let new_filter = len_diff * filter; + // Read an extra element. 
+ let channel = nv.mem_channels[0]; + yield_constr.constraint_transition(new_filter * (channel.used - P::ONES)); + yield_constr.constraint_transition(new_filter * (channel.is_read - P::ONES)); + yield_constr.constraint_transition(new_filter * (channel.addr_context - nv.context)); + yield_constr.constraint_transition( + new_filter + * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), + ); + let addr_virtual = nv.stack_len - P::ONES; + yield_constr.constraint_transition(new_filter * (channel.addr_virtual - addr_virtual)); + // Constrain `stack_inv_aux`. + yield_constr.constraint( + filter + * (len_diff * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux), + ); + // Disable channel if stack_len == N. + let empty_stack_filter = filter * (lv.general.stack().stack_inv_aux - P::ONES); + yield_constr.constraint_transition(empty_stack_filter * channel.used); + } } - - // Pushes - if stack_behavior.pushes { + // If the op only pushes, you only need to constrain the top of the stack if the stack isn't empty. + else if stack_behavior.pushes { + // If len > 0... + let new_filter = lv.stack_len * filter; + // You write the previous top of the stack in memory, in the last channel. 
let channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; - - yield_constr.constraint(filter * (channel.used - P::ONES)); - yield_constr.constraint(filter * channel.is_read); - - yield_constr.constraint(filter * (channel.addr_context - lv.context)); + yield_constr.constraint(new_filter * (channel.used - P::ONES)); + yield_constr.constraint(new_filter * channel.is_read); + yield_constr.constraint(new_filter * (channel.addr_context - lv.context)); yield_constr.constraint( - filter * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), + new_filter + * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), ); - let addr_virtual = lv.stack_len - P::Scalar::from_canonical_usize(stack_behavior.num_pops); - yield_constr.constraint(filter * (channel.addr_virtual - addr_virtual)); + let addr_virtual = lv.stack_len - P::ONES; + yield_constr.constraint(new_filter * (channel.addr_virtual - addr_virtual)); + for (limb_ch, limb_top) in channel.value.iter().zip(lv.mem_channels[0].value.iter()) { + yield_constr.constraint(new_filter * (*limb_ch - *limb_top)); + } + // Else you disable the channel. + yield_constr.constraint( + filter + * (lv.stack_len * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux), + ); + let empty_stack_filter = filter * (lv.general.stack().stack_inv_aux - P::ONES); + yield_constr.constraint(empty_stack_filter * channel.used); + } + // If the op doesn't pop nor push, the top of the stack must not change. + else { + yield_constr.constraint(filter * nv.mem_channels[0].used); + for (limb_old, limb_new) in lv.mem_channels[0] + .value + .iter() + .zip(nv.mem_channels[0].value.iter()) + { + yield_constr.constraint(filter * (*limb_old - *limb_new)); + } + } + + // Maybe constrain next stack_top. + // These are transition constraints: they don't apply to the last row. 
+ if let Some(next_top_ch) = stack_behavior.new_top_stack_channel { + for (limb_ch, limb_top) in lv.mem_channels[next_top_ch] + .value + .iter() + .zip(nv.mem_channels[0].value.iter()) + { + yield_constr.constraint_transition(filter * (*limb_ch - *limb_top)); + } } // Unused channels if stack_behavior.disable_other_channels { - for i in stack_behavior.num_pops..NUM_GP_CHANNELS - (stack_behavior.pushes as usize) { + // The first channel contains (or not) the top of the stack and is constrained elsewhere. + for i in max(1, stack_behavior.num_pops)..NUM_GP_CHANNELS - (stack_behavior.pushes as usize) + { let channel = lv.mem_channels[i]; yield_constr.constraint(filter * channel.used); } @@ -210,94 +305,199 @@ pub(crate) fn eval_ext_circuit_one, const D: usize> stack_behavior: StackBehavior, yield_constr: &mut RecursiveConstraintConsumer, ) { - let num_operands = stack_behavior.num_pops + (stack_behavior.pushes as usize); - assert!(num_operands <= NUM_GP_CHANNELS); + // If you have pops. + if stack_behavior.num_pops > 0 { + for i in 1..stack_behavior.num_pops { + let channel = lv.mem_channels[i]; + + { + let constr = builder.mul_sub_extension(filter, channel.used, filter); + yield_constr.constraint(builder, constr); + } + { + let constr = builder.mul_sub_extension(filter, channel.is_read, filter); + yield_constr.constraint(builder, constr); + } + { + let diff = builder.sub_extension(channel.addr_context, lv.context); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + { + let constr = builder.arithmetic_extension( + F::ONE, + -F::from_canonical_u64(Segment::Stack as u64), + filter, + channel.addr_segment, + filter, + ); + yield_constr.constraint(builder, constr); + } + // Remember that the first read (`i == 1`) is for the second stack element at `stack[stack_len - 1]`. 
+ { + let diff = builder.sub_extension(channel.addr_virtual, lv.stack_len); + let constr = builder.arithmetic_extension( + F::ONE, + F::from_canonical_usize(i + 1), + filter, + diff, + filter, + ); + yield_constr.constraint(builder, constr); + } + } - // Pops - for i in 0..stack_behavior.num_pops { - let channel = lv.mem_channels[i]; + // If you also push, you don't need to read the new top of the stack. + // If you don't: + // - if the stack isn't empty after the pops, you read the new top from an extra pop. + // - if not, the extra read is disabled. + // These are transition constraints: they don't apply to the last row. + if !stack_behavior.pushes { + // If stack_len != N... + let target_num_pops = + builder.constant_extension(F::from_canonical_usize(stack_behavior.num_pops).into()); + let len_diff = builder.sub_extension(lv.stack_len, target_num_pops); + let new_filter = builder.mul_extension(filter, len_diff); + // Read an extra element. + let channel = nv.mem_channels[0]; + { + let constr = builder.mul_sub_extension(new_filter, channel.used, new_filter); + yield_constr.constraint_transition(builder, constr); + } + { + let constr = builder.mul_sub_extension(new_filter, channel.is_read, new_filter); + yield_constr.constraint_transition(builder, constr); + } + { + let diff = builder.sub_extension(channel.addr_context, nv.context); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint_transition(builder, constr); + } + { + let constr = builder.arithmetic_extension( + F::ONE, + -F::from_canonical_u64(Segment::Stack as u64), + new_filter, + channel.addr_segment, + new_filter, + ); + yield_constr.constraint_transition(builder, constr); + } + { + let diff = builder.sub_extension(channel.addr_virtual, nv.stack_len); + let constr = + builder.arithmetic_extension(F::ONE, F::ONE, new_filter, diff, new_filter); + yield_constr.constraint_transition(builder, constr); + } + // Constrain `stack_inv_aux`. 
+ { + let prod = builder.mul_extension(len_diff, lv.general.stack().stack_inv); + let diff = builder.sub_extension(prod, lv.general.stack().stack_inv_aux); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + // Disable channel if stack_len == N. + { + let empty_stack_filter = + builder.mul_sub_extension(filter, lv.general.stack().stack_inv_aux, filter); + let constr = builder.mul_extension(empty_stack_filter, channel.used); + yield_constr.constraint_transition(builder, constr); + } + } + } + // If the op only pushes, you only need to constrain the top of the stack if the stack isn't empty. + else if stack_behavior.pushes { + // If len > 0... + let new_filter = builder.mul_extension(lv.stack_len, filter); + // You write the previous top of the stack in memory, in the last channel. + let channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; { - let constr = builder.mul_sub_extension(filter, channel.used, filter); + let constr = builder.mul_sub_extension(new_filter, channel.used, new_filter); yield_constr.constraint(builder, constr); } { - let constr = builder.mul_sub_extension(filter, channel.is_read, filter); + let constr = builder.mul_extension(new_filter, channel.is_read); yield_constr.constraint(builder, constr); } { let diff = builder.sub_extension(channel.addr_context, lv.context); - let constr = builder.mul_extension(filter, diff); + let constr = builder.mul_extension(new_filter, diff); yield_constr.constraint(builder, constr); } { let constr = builder.arithmetic_extension( F::ONE, -F::from_canonical_u64(Segment::Stack as u64), - filter, + new_filter, channel.addr_segment, - filter, + new_filter, ); yield_constr.constraint(builder, constr); } { let diff = builder.sub_extension(channel.addr_virtual, lv.stack_len); - let constr = builder.arithmetic_extension( - F::ONE, - F::from_canonical_usize(i + 1), - filter, - diff, - filter, - ); + let constr = builder.arithmetic_extension(F::ONE, F::ONE, new_filter, diff, 
new_filter); yield_constr.constraint(builder, constr); } - } - - // Pushes - if stack_behavior.pushes { - let channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; - - { - let constr = builder.mul_sub_extension(filter, channel.used, filter); + for (limb_ch, limb_top) in channel.value.iter().zip(lv.mem_channels[0].value.iter()) { + let diff = builder.sub_extension(*limb_ch, *limb_top); + let constr = builder.mul_extension(new_filter, diff); yield_constr.constraint(builder, constr); } + // Else you disable the channel. { - let constr = builder.mul_extension(filter, channel.is_read); + let diff = builder.mul_extension(lv.stack_len, lv.general.stack().stack_inv); + let diff = builder.sub_extension(diff, lv.general.stack().stack_inv_aux); + let constr = builder.mul_extension(filter, diff); yield_constr.constraint(builder, constr); } - { - let diff = builder.sub_extension(channel.addr_context, lv.context); - let constr = builder.mul_extension(filter, diff); + let empty_stack_filter = + builder.mul_sub_extension(filter, lv.general.stack().stack_inv_aux, filter); + let constr = builder.mul_extension(empty_stack_filter, channel.used); yield_constr.constraint(builder, constr); } + } + // If the op doesn't pop nor push, the top of the stack must not change. 
+ else { { - let constr = builder.arithmetic_extension( - F::ONE, - -F::from_canonical_u64(Segment::Stack as u64), - filter, - channel.addr_segment, - filter, - ); + let constr = builder.mul_extension(filter, nv.mem_channels[0].used); yield_constr.constraint(builder, constr); } { - let diff = builder.sub_extension(channel.addr_virtual, lv.stack_len); - let constr = builder.arithmetic_extension( - F::ONE, - F::from_canonical_usize(stack_behavior.num_pops), - filter, - diff, - filter, - ); - yield_constr.constraint(builder, constr); + for (limb_old, limb_new) in lv.mem_channels[0] + .value + .iter() + .zip(nv.mem_channels[0].value.iter()) + { + let diff = builder.sub_extension(*limb_old, *limb_new); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + } + } + + // Maybe constrain next stack_top. + // These are transition constraints: they don't apply to the last row. + if let Some(next_top_ch) = stack_behavior.new_top_stack_channel { + for (limb_ch, limb_top) in lv.mem_channels[next_top_ch] + .value + .iter() + .zip(nv.mem_channels[0].value.iter()) + { + let diff = builder.sub_extension(*limb_ch, *limb_top); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint_transition(builder, constr); + } } // Unused channels if stack_behavior.disable_other_channels { - for i in stack_behavior.num_pops..NUM_GP_CHANNELS - (stack_behavior.pushes as usize) { + // The first channel contains (or not) the top of the stack and is constrained elsewhere. 
+ for i in max(1, stack_behavior.num_pops)..NUM_GP_CHANNELS - (stack_behavior.pushes as usize) + { let channel = lv.mem_channels[i]; let constr = builder.mul_extension(filter, channel.used); yield_constr.constraint(builder, constr); diff --git a/evm/src/cpu/syscalls_exceptions.rs b/evm/src/cpu/syscalls_exceptions.rs index f9ea9a0a9f..1437fba02b 100644 --- a/evm/src/cpu/syscalls_exceptions.rs +++ b/evm/src/cpu/syscalls_exceptions.rs @@ -64,7 +64,7 @@ pub fn eval_packed( let exc_handler_addr_start = exc_jumptable_start + exc_code * P::Scalar::from_canonical_usize(BYTES_PER_OFFSET); - for (i, channel) in lv.mem_channels[0..BYTES_PER_OFFSET].iter().enumerate() { + for (i, channel) in lv.mem_channels[1..BYTES_PER_OFFSET + 1].iter().enumerate() { yield_constr.constraint(total_filter * (channel.used - P::ONES)); yield_constr.constraint(total_filter * (channel.is_read - P::ONES)); @@ -81,13 +81,13 @@ pub fn eval_packed( } // Disable unused channels (the last channel is used to push to the stack) - for channel in &lv.mem_channels[BYTES_PER_OFFSET..NUM_GP_CHANNELS - 1] { + for channel in &lv.mem_channels[BYTES_PER_OFFSET + 1..NUM_GP_CHANNELS - 1] { yield_constr.constraint(total_filter * channel.used); } // Set program counter to the handler address // The addresses are big-endian in memory - let target = lv.mem_channels[0..BYTES_PER_OFFSET] + let target = lv.mem_channels[1..BYTES_PER_OFFSET + 1] .iter() .map(|channel| channel.value[0]) .fold(P::ZEROS, |cumul, limb| { @@ -102,9 +102,8 @@ pub fn eval_packed( yield_constr.constraint_transition(total_filter * nv.gas[0]); yield_constr.constraint_transition(total_filter * nv.gas[1]); - // This memory channel is constrained in `stack.rs`. - let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value; - // Push to stack: current PC + 1 (limb 0), kernel flag (limb 1), gas counter (limbs 6 and 7). + let output = nv.mem_channels[0].value; + // New top of the stack: current PC + 1 (limb 0), kernel flag (limb 1), gas counter (limbs 6 and 7). 
yield_constr.constraint(filter_syscall * (output[0] - (lv.program_counter + P::ONES))); yield_constr.constraint(filter_exception * (output[0] - lv.program_counter)); // Check the kernel mode, for syscalls only @@ -182,7 +181,7 @@ pub fn eval_ext_circuit, const D: usize>( exc_jumptable_start, ); - for (i, channel) in lv.mem_channels[0..BYTES_PER_OFFSET].iter().enumerate() { + for (i, channel) in lv.mem_channels[1..BYTES_PER_OFFSET + 1].iter().enumerate() { { let constr = builder.mul_sub_extension(total_filter, channel.used, total_filter); yield_constr.constraint(builder, constr); @@ -235,7 +234,7 @@ pub fn eval_ext_circuit, const D: usize>( } // Disable unused channels (the last channel is used to push to the stack) - for channel in &lv.mem_channels[BYTES_PER_OFFSET..NUM_GP_CHANNELS - 1] { + for channel in &lv.mem_channels[BYTES_PER_OFFSET + 1..NUM_GP_CHANNELS - 1] { let constr = builder.mul_extension(total_filter, channel.used); yield_constr.constraint(builder, constr); } @@ -243,7 +242,7 @@ pub fn eval_ext_circuit, const D: usize>( // Set program counter to the handler address // The addresses are big-endian in memory { - let target = lv.mem_channels[0..BYTES_PER_OFFSET] + let target = lv.mem_channels[1..BYTES_PER_OFFSET + 1] .iter() .map(|channel| channel.value[0]) .fold(builder.zero_extension(), |cumul, limb| { @@ -272,8 +271,8 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr.constraint_transition(builder, constr); } - // This memory channel is constrained in `stack.rs`. - let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value; + // New top of the stack. + let output = nv.mem_channels[0].value; // Push to stack (syscall): current PC + 1 (limb 0), kernel flag (limb 1), gas counter (limbs 6 and 7). 
{ let pc_plus_1 = builder.add_const_extension(lv.program_counter, F::ONE); diff --git a/evm/src/witness/memory.rs b/evm/src/witness/memory.rs index 3b62c94553..5d589934a0 100644 --- a/evm/src/witness/memory.rs +++ b/evm/src/witness/memory.rs @@ -88,6 +88,18 @@ pub struct MemoryOp { pub value: U256, } +pub static DUMMY_MEMOP: MemoryOp = MemoryOp { + filter: false, + timestamp: 0, + address: MemoryAddress { + context: 0, + segment: 0, + virt: 0, + }, + kind: MemoryOpKind::Read, + value: U256::zero(), +}; + impl MemoryOp { pub fn new( channel: MemoryChannel, diff --git a/evm/src/witness/operation.rs b/evm/src/witness/operation.rs index f4dc03e806..a503ab496c 100644 --- a/evm/src/witness/operation.rs +++ b/evm/src/witness/operation.rs @@ -3,7 +3,7 @@ use itertools::Itertools; use keccak_hash::keccak; use plonky2::field::types::Field; -use super::util::{byte_packing_log, byte_unpacking_log}; +use super::util::{byte_packing_log, byte_unpacking_log, push_no_write, push_with_write}; use crate::arithmetic::BinaryOperator; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; @@ -20,9 +20,10 @@ use crate::witness::errors::MemoryError::{ContextTooLarge, SegmentTooLarge, Virt use crate::witness::errors::ProgramError; use crate::witness::errors::ProgramError::MemoryError; use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKind}; +use crate::witness::operation::MemoryChannel::GeneralPurpose; use crate::witness::util::{ keccak_sponge_log, mem_read_gp_with_log_and_fill, mem_write_gp_log_and_fill, - stack_pop_with_log_and_fill, stack_push_log_and_fill, + stack_pop_with_log_and_fill, }; use crate::{arithmetic, logic}; @@ -59,14 +60,13 @@ pub(crate) fn generate_binary_logic_op( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(in0, log_in0), (in1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let [(in0, _), (in1, log_in1)] = stack_pop_with_log_and_fill::<2, 
_>(state, &mut row)?; let operation = logic::Operation::new(op, in0, in1); - let log_out = stack_push_log_and_fill(state, &mut row, operation.result)?; + + push_no_write(state, &mut row, operation.result, Some(NUM_GP_CHANNELS - 1)); state.traces.push_logic(operation); - state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) } @@ -76,10 +76,8 @@ pub(crate) fn generate_binary_arithmetic_op( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(input0, log_in0), (input1, log_in1)] = - stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let [(input0, _), (input1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; let operation = arithmetic::Operation::binary(operator, input0, input1); - let log_out = stack_push_log_and_fill(state, &mut row, operation.result())?; if operator == arithmetic::BinaryOperator::AddFp254 || operator == arithmetic::BinaryOperator::MulFp254 @@ -94,10 +92,15 @@ pub(crate) fn generate_binary_arithmetic_op( } } + push_no_write( + state, + &mut row, + operation.result(), + Some(NUM_GP_CHANNELS - 1), + ); + state.traces.push_arithmetic(operation); - state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) } @@ -107,16 +110,20 @@ pub(crate) fn generate_ternary_arithmetic_op( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(input0, log_in0), (input1, log_in1), (input2, log_in2)] = + let [(input0, _), (input1, log_in1), (input2, log_in2)] = stack_pop_with_log_and_fill::<3, _>(state, &mut row)?; let operation = arithmetic::Operation::ternary(operator, input0, input1, input2); - let log_out = stack_push_log_and_fill(state, &mut row, operation.result())?; + + push_no_write( + state, + &mut row, + operation.result(), + Some(NUM_GP_CHANNELS - 1), + ); 
state.traces.push_arithmetic(operation); - state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); state.traces.push_memory(log_in2); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) } @@ -126,7 +133,7 @@ pub(crate) fn generate_keccak_general( mut row: CpuColumnsView, ) -> Result<(), ProgramError> { row.is_keccak_sponge = F::ONE; - let [(context, log_in0), (segment, log_in1), (base_virt, log_in2), (len, log_in3)] = + let [(context, _), (segment, log_in1), (base_virt, log_in2), (len, log_in3)] = stack_pop_with_log_and_fill::<4, _>(state, &mut row)?; let len = u256_to_usize(len)?; @@ -144,15 +151,13 @@ pub(crate) fn generate_keccak_general( log::debug!("Hashing {:?}", input); let hash = keccak(&input); - let log_push = stack_push_log_and_fill(state, &mut row, hash.into_uint())?; + push_no_write(state, &mut row, hash.into_uint(), Some(NUM_GP_CHANNELS - 1)); keccak_sponge_log(state, base_address, input); - state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); state.traces.push_memory(log_in2); state.traces.push_memory(log_in3); - state.traces.push_memory(log_push); state.traces.push_cpu(row); Ok(()) } @@ -164,9 +169,7 @@ pub(crate) fn generate_prover_input( let pc = state.registers.program_counter; let input_fn = &KERNEL.prover_inputs[&pc]; let input = state.prover_input(input_fn)?; - let write = stack_push_log_and_fill(state, &mut row, input)?; - - state.traces.push_memory(write); + push_with_write(state, &mut row, input)?; state.traces.push_cpu(row); Ok(()) } @@ -175,10 +178,10 @@ pub(crate) fn generate_pop( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(_, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let [(_, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; - state.traces.push_memory(log_in); state.traces.push_cpu(row); + Ok(()) } @@ -186,7 +189,8 @@ pub(crate) fn generate_jump( state: &mut GenerationState, mut row: CpuColumnsView, ) 
-> Result<(), ProgramError> { - let [(dst, log_in0)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let [(dst, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let dst: u32 = dst .try_into() .map_err(|_| ProgramError::InvalidJumpDestination)?; @@ -216,7 +220,15 @@ pub(crate) fn generate_jump( row.general.jumps_mut().should_jump = F::ONE; row.general.jumps_mut().cond_sum_pinv = F::ONE; - state.traces.push_memory(log_in0); + let diff = row.stack_len - F::ONE; + if let Some(inv) = diff.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } else { + row.general.stack_mut().stack_inv = F::ZERO; + row.general.stack_mut().stack_inv_aux = F::ZERO; + } + state.traces.push_cpu(row); state.jump_to(dst as usize)?; Ok(()) @@ -226,7 +238,7 @@ pub(crate) fn generate_jumpi( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(dst, log_in0), (cond, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let [(dst, _), (cond, log_cond)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; let should_jump = !cond.is_zero(); if should_jump { @@ -271,8 +283,16 @@ pub(crate) fn generate_jumpi( state.traces.push_memory(jumpdest_bit_log); } - state.traces.push_memory(log_in0); - state.traces.push_memory(log_in1); + let diff = row.stack_len - F::TWO; + if let Some(inv) = diff.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } else { + row.general.stack_mut().stack_inv = F::ZERO; + row.general.stack_mut().stack_inv_aux = F::ZERO; + } + + state.traces.push_memory(log_cond); state.traces.push_cpu(row); Ok(()) } @@ -281,8 +301,7 @@ pub(crate) fn generate_pc( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let write = stack_push_log_and_fill(state, &mut row, state.registers.program_counter.into())?; - state.traces.push_memory(write); + 
push_with_write(state, &mut row, state.registers.program_counter.into())?; state.traces.push_cpu(row); Ok(()) } @@ -299,9 +318,7 @@ pub(crate) fn generate_get_context( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let ctx = state.registers.context.into(); - let write = stack_push_log_and_fill(state, &mut row, ctx)?; - state.traces.push_memory(write); + push_with_write(state, &mut row, state.registers.context.into())?; state.traces.push_cpu(row); Ok(()) } @@ -310,8 +327,10 @@ pub(crate) fn generate_set_context( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(ctx, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let [(ctx, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let sp_to_save = state.registers.stack_len.into(); + let old_ctx = state.registers.context; let new_ctx = u256_to_usize(ctx)?; @@ -347,10 +366,31 @@ pub(crate) fn generate_set_context( mem_read_gp_with_log_and_fill(2, new_sp_addr, state, &mut row) }; + // If the new stack isn't empty, read stack_top from memory. + let new_sp = new_sp.as_usize(); + if new_sp > 0 { + // Set up columns to disable the channel if it *is* empty. 
+ let new_sp_field = F::from_canonical_usize(new_sp); + if let Some(inv) = new_sp_field.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } else { + row.general.stack_mut().stack_inv = F::ZERO; + row.general.stack_mut().stack_inv_aux = F::ZERO; + } + + let new_top_addr = MemoryAddress::new(new_ctx, Segment::Stack, new_sp - 1); + let (new_top, log_read_new_top) = + mem_read_gp_with_log_and_fill(3, new_top_addr, state, &mut row); + state.registers.stack_top = new_top; + state.traces.push_memory(log_read_new_top); + } else { + row.general.stack_mut().stack_inv = F::ZERO; + row.general.stack_mut().stack_inv_aux = F::ZERO; + } + state.registers.context = new_ctx; - let new_sp = u256_to_usize(new_sp)?; state.registers.stack_len = new_sp; - state.traces.push_memory(log_in); state.traces.push_memory(log_write_old_sp); state.traces.push_memory(log_read_new_sp); state.traces.push_cpu(row); @@ -386,31 +426,76 @@ pub(crate) fn generate_push( .collect_vec(); let val = U256::from_big_endian(&bytes); - let write = stack_push_log_and_fill(state, &mut row, val)?; - - state.traces.push_memory(write); + push_with_write(state, &mut row, val)?; state.traces.push_cpu(row); Ok(()) } +// This instruction is special. The order of the operations are: +// - Write `stack_top` at `stack[stack_len - 1]` +// - Read `val` at `stack[stack_len - 1 - n]` +// - Update `stack_top` with `val` and add 1 to `stack_len` +// Since the write must happen before the read, the normal way of assigning +// GP channels doesn't work and we must handle them manually. 
pub(crate) fn generate_dup( n: u8, state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let other_addr_lo = state - .registers - .stack_len - .checked_sub(1 + (n as usize)) - .ok_or(ProgramError::StackUnderflow)?; - let other_addr = MemoryAddress::new(state.registers.context, Segment::Stack, other_addr_lo); + // Same logic as in `push_with_write`, but we use the channel GP(0) instead. + if !state.registers.is_kernel && state.registers.stack_len >= MAX_USER_STACK_SIZE { + return Err(ProgramError::StackOverflow); + } + if n as usize >= state.registers.stack_len { + return Err(ProgramError::StackUnderflow); + } + let stack_top = state.registers.stack_top; + let address = MemoryAddress::new( + state.registers.context, + Segment::Stack, + state.registers.stack_len - 1, + ); + let log_push = mem_write_gp_log_and_fill(1, address, state, &mut row, stack_top); + state.traces.push_memory(log_push); + + let other_addr = MemoryAddress::new( + state.registers.context, + Segment::Stack, + state.registers.stack_len - 1 - n as usize, + ); + + // If n = 0, we read a value that hasn't been written to memory: the corresponding write + // is buffered in the mem_ops queue, but hasn't been applied yet. 
+ let (val, log_read) = if n == 0 { + let op = MemoryOp::new( + MemoryChannel::GeneralPurpose(2), + state.traces.clock(), + other_addr, + MemoryOpKind::Read, + stack_top, + ); + + let channel = &mut row.mem_channels[2]; + assert_eq!(channel.used, F::ZERO); + channel.used = F::ONE; + channel.is_read = F::ONE; + channel.addr_context = F::from_canonical_usize(other_addr.context); + channel.addr_segment = F::from_canonical_usize(other_addr.segment); + channel.addr_virtual = F::from_canonical_usize(other_addr.virt); + let val_limbs: [u64; 4] = state.registers.stack_top.0; + for (i, limb) in val_limbs.into_iter().enumerate() { + channel.value[2 * i] = F::from_canonical_u32(limb as u32); + channel.value[2 * i + 1] = F::from_canonical_u32((limb >> 32) as u32); + } - let (val, log_in) = mem_read_gp_with_log_and_fill(0, other_addr, state, &mut row); - let log_out = stack_push_log_and_fill(state, &mut row, val)?; + (stack_top, op) + } else { + mem_read_gp_with_log_and_fill(2, other_addr, state, &mut row) + }; + push_no_write(state, &mut row, val, None); - state.traces.push_memory(log_in); - state.traces.push_memory(log_out); + state.traces.push_memory(log_read); state.traces.push_cpu(row); Ok(()) } @@ -427,15 +512,13 @@ pub(crate) fn generate_swap( .ok_or(ProgramError::StackUnderflow)?; let other_addr = MemoryAddress::new(state.registers.context, Segment::Stack, other_addr_lo); - let [(in0, log_in0)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let [(in0, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; let (in1, log_in1) = mem_read_gp_with_log_and_fill(1, other_addr, state, &mut row); - let log_out0 = mem_write_gp_log_and_fill(NUM_GP_CHANNELS - 2, other_addr, state, &mut row, in0); - let log_out1 = stack_push_log_and_fill(state, &mut row, in1)?; + let log_out0 = mem_write_gp_log_and_fill(2, other_addr, state, &mut row, in0); + push_no_write(state, &mut row, in1, None); - state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); 
state.traces.push_memory(log_out0); - state.traces.push_memory(log_out1); state.traces.push_cpu(row); Ok(()) } @@ -444,12 +527,10 @@ pub(crate) fn generate_not( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(x, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let [(x, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; let result = !x; - let log_out = stack_push_log_and_fill(state, &mut row, result)?; + push_no_write(state, &mut row, result, Some(NUM_GP_CHANNELS - 1)); - state.traces.push_memory(log_in); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) } @@ -458,18 +539,16 @@ pub(crate) fn generate_iszero( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(x, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let [(x, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; let is_zero = x.is_zero(); let result = { let t: u64 = is_zero.into(); t.into() }; - let log_out = stack_push_log_and_fill(state, &mut row, result)?; generate_pinv_diff(x, U256::zero(), &mut row); - state.traces.push_memory(log_in); - state.traces.push_memory(log_out); + push_no_write(state, &mut row, result, None); state.traces.push_cpu(row); Ok(()) } @@ -480,12 +559,9 @@ fn append_shift( is_shl: bool, input0: U256, input1: U256, - log_in0: MemoryOp, log_in1: MemoryOp, result: U256, ) -> Result<(), ProgramError> { - let log_out = stack_push_log_and_fill(state, &mut row, result)?; - const LOOKUP_CHANNEL: usize = 2; let lookup_addr = MemoryAddress::new(0, Segment::ShiftTable, input0.low_u32() as usize); if input0.bits() <= 32 { @@ -511,9 +587,8 @@ fn append_shift( let operation = arithmetic::Operation::binary(operator, input0, input1); state.traces.push_arithmetic(operation); - state.traces.push_memory(log_in0); + push_no_write(state, &mut row, result, Some(NUM_GP_CHANNELS - 1)); state.traces.push_memory(log_in1); - 
state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) } @@ -522,30 +597,28 @@ pub(crate) fn generate_shl( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(input0, log_in0), (input1, log_in1)] = - stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let [(input0, _), (input1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; let result = if input0 > U256::from(255u64) { U256::zero() } else { input1 << input0 }; - append_shift(state, row, true, input0, input1, log_in0, log_in1, result) + append_shift(state, row, true, input0, input1, log_in1, result) } pub(crate) fn generate_shr( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(input0, log_in0), (input1, log_in1)] = - stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let [(input0, _), (input1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; let result = if input0 > U256::from(255u64) { U256::zero() } else { input1 >> input0 }; - append_shift(state, row, false, input0, input1, log_in0, log_in1, result) + append_shift(state, row, false, input0, input1, log_in1, result) } pub(crate) fn generate_syscall( @@ -574,19 +647,19 @@ pub(crate) fn generate_syscall( handler_jumptable_addr + (opcode as usize) * (BYTES_PER_OFFSET as usize); assert_eq!(BYTES_PER_OFFSET, 3, "Code below assumes 3 bytes per offset"); let (handler_addr0, log_in0) = mem_read_gp_with_log_and_fill( - 0, + 1, MemoryAddress::new(0, Segment::Code, handler_addr_addr), state, &mut row, ); let (handler_addr1, log_in1) = mem_read_gp_with_log_and_fill( - 1, + 2, MemoryAddress::new(0, Segment::Code, handler_addr_addr + 1), state, &mut row, ); let (handler_addr2, log_in2) = mem_read_gp_with_log_and_fill( - 2, + 3, MemoryAddress::new(0, Segment::Code, handler_addr_addr + 2), state, &mut row, @@ -606,14 +679,13 @@ pub(crate) fn generate_syscall( state.registers.is_kernel = true; state.registers.gas_used = 0; - let 
log_out = stack_push_log_and_fill(state, &mut row, syscall_info)?; + push_with_write(state, &mut row, syscall_info)?; log::debug!("Syscall to {}", KERNEL.offset_name(new_program_counter)); state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); state.traces.push_memory(log_in2); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) @@ -623,16 +695,14 @@ pub(crate) fn generate_eq( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(in0, log_in0), (in1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let [(in0, _), (in1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; let eq = in0 == in1; let result = U256::from(u64::from(eq)); - let log_out = stack_push_log_and_fill(state, &mut row, result)?; generate_pinv_diff(in0, in1, &mut row); - state.traces.push_memory(log_in0); + push_no_write(state, &mut row, result, None); state.traces.push_memory(log_in1); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) } @@ -641,7 +711,7 @@ pub(crate) fn generate_exit_kernel( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(kexit_info, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let [(kexit_info, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; let kexit_info_u64 = kexit_info.0[0]; let program_counter = kexit_info_u64 as u32 as usize; let is_kernel_mode_val = (kexit_info_u64 >> 32) as u32; @@ -661,7 +731,6 @@ pub(crate) fn generate_exit_kernel( is_kernel_mode ); - state.traces.push_memory(log_in); state.traces.push_cpu(row); Ok(()) @@ -671,7 +740,7 @@ pub(crate) fn generate_mload_general( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(context, log_in0), (segment, log_in1), (virt, log_in2)] = + let [(context, _), (segment, log_in1), (virt, log_in2)] = stack_pop_with_log_and_fill::<3, _>(state, &mut row)?; let (val, log_read) = 
mem_read_gp_with_log_and_fill( @@ -680,14 +749,20 @@ pub(crate) fn generate_mload_general( state, &mut row, ); + push_no_write(state, &mut row, val, None); - let log_out = stack_push_log_and_fill(state, &mut row, val)?; + let diff = row.stack_len - F::from_canonical_usize(4); + if let Some(inv) = diff.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } else { + row.general.stack_mut().stack_inv = F::ZERO; + row.general.stack_mut().stack_inv_aux = F::ZERO; + } - state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); state.traces.push_memory(log_in2); state.traces.push_memory(log_read); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) } @@ -696,7 +771,7 @@ pub(crate) fn generate_mload_32bytes( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(context, log_in0), (segment, log_in1), (base_virt, log_in2), (len, log_in3)] = + let [(context, _), (segment, log_in1), (base_virt, log_in2), (len, log_in3)] = stack_pop_with_log_and_fill::<4, _>(state, &mut row)?; let len = u256_to_usize(len)?; if len > 32 { @@ -722,15 +797,13 @@ pub(crate) fn generate_mload_32bytes( .collect_vec(); let packed_int = U256::from_big_endian(&bytes); - let log_out = stack_push_log_and_fill(state, &mut row, packed_int)?; + push_no_write(state, &mut row, packed_int, Some(4)); byte_packing_log(state, base_address, bytes); - state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); state.traces.push_memory(log_in2); state.traces.push_memory(log_in3); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) } @@ -739,7 +812,7 @@ pub(crate) fn generate_mstore_general( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(context, log_in0), (segment, log_in1), (virt, log_in2), (val, log_in3)] = + let [(context, _), (segment, log_in1), (virt, log_in2), (val, log_in3)] = 
stack_pop_with_log_and_fill::<4, _>(state, &mut row)?; let address = MemoryAddress { @@ -755,12 +828,23 @@ pub(crate) fn generate_mstore_general( }; let log_write = mem_write_gp_log_and_fill(4, address, state, &mut row, val); - state.traces.push_memory(log_in0); + let diff = row.stack_len - F::from_canonical_usize(4); + if let Some(inv) = diff.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + row.general.stack_mut().stack_inv_aux_2 = F::ONE; + state.registers.is_stack_top_read = true; + } else { + row.general.stack_mut().stack_inv = F::ZERO; + row.general.stack_mut().stack_inv_aux = F::ZERO; + } + state.traces.push_memory(log_in1); state.traces.push_memory(log_in2); state.traces.push_memory(log_in3); state.traces.push_memory(log_write); state.traces.push_cpu(row); + Ok(()) } @@ -768,7 +852,7 @@ pub(crate) fn generate_mstore_32bytes( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(context, log_in0), (segment, log_in1), (base_virt, log_in2), (val, log_in3), (len, log_in4)] = + let [(context, _), (segment, log_in1), (base_virt, log_in2), (val, log_in3), (len, log_in4)] = stack_pop_with_log_and_fill::<5, _>(state, &mut row)?; let len = u256_to_usize(len)?; @@ -776,7 +860,6 @@ pub(crate) fn generate_mstore_32bytes( byte_unpacking_log(state, base_address, val, len); - state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); state.traces.push_memory(log_in2); state.traces.push_memory(log_in3); @@ -805,6 +888,36 @@ pub(crate) fn generate_exception( return Err(ProgramError::InterpreterError); } + if let Some(inv) = row.stack_len.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } + + if state.registers.is_stack_top_read { + let channel = &mut row.mem_channels[0]; + channel.used = F::ONE; + channel.is_read = F::ONE; + channel.addr_context = F::from_canonical_usize(state.registers.context); + 
channel.addr_segment = F::from_canonical_usize(Segment::Stack as usize); + channel.addr_virtual = F::from_canonical_usize(state.registers.stack_len - 1); + + let address = MemoryAddress { + context: state.registers.context, + segment: Segment::Stack as usize, + virt: state.registers.stack_len - 1, + }; + + let mem_op = MemoryOp::new( + GeneralPurpose(0), + state.traces.clock(), + address, + MemoryOpKind::Read, + state.registers.stack_top, + ); + state.traces.push_memory(mem_op); + state.registers.is_stack_top_read = false; + } + row.general.exception_mut().exc_code_bits = [ F::from_bool(exc_code & 1 != 0), F::from_bool(exc_code & 2 != 0), @@ -816,19 +929,19 @@ pub(crate) fn generate_exception( handler_jumptable_addr + (exc_code as usize) * (BYTES_PER_OFFSET as usize); assert_eq!(BYTES_PER_OFFSET, 3, "Code below assumes 3 bytes per offset"); let (handler_addr0, log_in0) = mem_read_gp_with_log_and_fill( - 0, + 1, MemoryAddress::new(0, Segment::Code, handler_addr_addr), state, &mut row, ); let (handler_addr1, log_in1) = mem_read_gp_with_log_and_fill( - 1, + 2, MemoryAddress::new(0, Segment::Code, handler_addr_addr + 1), state, &mut row, ); let (handler_addr2, log_in2) = mem_read_gp_with_log_and_fill( - 2, + 3, MemoryAddress::new(0, Segment::Code, handler_addr_addr + 2), state, &mut row, @@ -847,14 +960,13 @@ pub(crate) fn generate_exception( state.registers.is_kernel = true; state.registers.gas_used = 0; - let log_out = stack_push_log_and_fill(state, &mut row, exc_info)?; + push_with_write(state, &mut row, exc_info)?; log::debug!("Exception to {}", KERNEL.offset_name(new_program_counter)); state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); state.traces.push_memory(log_in2); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) diff --git a/evm/src/witness/state.rs b/evm/src/witness/state.rs index 3b37b01e02..406ae8567f 100644 --- a/evm/src/witness/state.rs +++ b/evm/src/witness/state.rs @@ -1,3 +1,5 @@ +use ethereum_types::U256; + 
use crate::cpu::kernel::aggregator::KERNEL; const KERNEL_CONTEXT: usize = 0; @@ -7,6 +9,9 @@ pub struct RegistersState { pub program_counter: usize, pub is_kernel: bool, pub stack_len: usize, + pub stack_top: U256, + // Indicates if you read the new stack_top from memory to set the channel accordingly. + pub is_stack_top_read: bool, pub context: usize, pub gas_used: u64, } @@ -27,6 +32,8 @@ impl Default for RegistersState { program_counter: KERNEL.global_labels["main"], is_kernel: true, stack_len: 0, + stack_top: U256::zero(), + is_stack_top_read: false, context: 0, gas_used: 0, } diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index 2a710f4b94..00030110dd 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -2,14 +2,20 @@ use anyhow::bail; use log::log_enabled; use plonky2::field::types::Field; +use super::memory::{MemoryOp, MemoryOpKind}; +use super::util::fill_channel_with_value; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::stack::{ + EQ_STACK_BEHAVIOR, IS_ZERO_STACK_BEHAVIOR, JUMPI_OP, JUMP_OP, STACK_BEHAVIORS, +}; use crate::cpu::stack_bounds::MAX_USER_STACK_SIZE; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; use crate::witness::errors::ProgramError; use crate::witness::gas::gas_to_charge; use crate::witness::memory::MemoryAddress; +use crate::witness::memory::MemoryChannel::GeneralPurpose; use crate::witness::operation::*; use crate::witness::state::RegistersState; use crate::witness::util::mem_read_code_with_log_and_fill; @@ -175,7 +181,8 @@ fn fill_op_flag(op: Operation, row: &mut CpuColumnsView) { Operation::Jump | Operation::Jumpi => &mut flags.jumps, Operation::Pc => &mut flags.pc, Operation::Jumpdest => &mut flags.jumpdest, - Operation::GetContext | Operation::SetContext => &mut flags.context_op, + Operation::GetContext => &mut flags.get_context, + Operation::SetContext => &mut flags.set_context, 
Operation::Mload32Bytes => &mut flags.mload_32bytes, Operation::Mstore32Bytes => &mut flags.mstore_32bytes, Operation::ExitKernel => &mut flags.exit_kernel, @@ -183,6 +190,52 @@ fn fill_op_flag(op: Operation, row: &mut CpuColumnsView) { } = F::ONE; } +// Equal to the number of pops if an operation pops without pushing, and `None` otherwise. +fn get_op_special_length(op: Operation) -> Option { + let behavior_opt = match op { + Operation::Push(0) => STACK_BEHAVIORS.push0, + Operation::Push(1..) => STACK_BEHAVIORS.push, + Operation::Dup(_) => STACK_BEHAVIORS.dup, + Operation::Swap(_) => STACK_BEHAVIORS.swap, + Operation::Iszero => IS_ZERO_STACK_BEHAVIOR, + Operation::Not => STACK_BEHAVIORS.not, + Operation::Syscall(_, _, _) => STACK_BEHAVIORS.syscall, + Operation::Eq => EQ_STACK_BEHAVIOR, + Operation::BinaryLogic(_) => STACK_BEHAVIORS.logic_op, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::AddFp254) + | Operation::BinaryArithmetic(arithmetic::BinaryOperator::MulFp254) + | Operation::BinaryArithmetic(arithmetic::BinaryOperator::SubFp254) => { + STACK_BEHAVIORS.fp254_op + } + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shl) + | Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr) => STACK_BEHAVIORS.shift, + Operation::BinaryArithmetic(_) => STACK_BEHAVIORS.binary_op, + Operation::TernaryArithmetic(_) => STACK_BEHAVIORS.ternary_op, + Operation::KeccakGeneral => STACK_BEHAVIORS.keccak_general, + Operation::ProverInput => STACK_BEHAVIORS.prover_input, + Operation::Pop => STACK_BEHAVIORS.pop, + Operation::Jump => JUMP_OP, + Operation::Jumpi => JUMPI_OP, + Operation::Pc => STACK_BEHAVIORS.pc, + Operation::Jumpdest => STACK_BEHAVIORS.jumpdest, + Operation::GetContext => STACK_BEHAVIORS.get_context, + Operation::SetContext => None, + Operation::Mload32Bytes => STACK_BEHAVIORS.mload_32bytes, + Operation::Mstore32Bytes => STACK_BEHAVIORS.mstore_32bytes, + Operation::ExitKernel => STACK_BEHAVIORS.exit_kernel, + Operation::MloadGeneral | 
Operation::MstoreGeneral => STACK_BEHAVIORS.m_op_general, + }; + if let Some(behavior) = behavior_opt { + if behavior.num_pops > 0 && !behavior.pushes { + Some(behavior.num_pops) + } else { + None + } + } else { + None + } +} + fn perform_op( state: &mut GenerationState, op: Operation, @@ -247,6 +300,7 @@ fn base_row(state: &mut GenerationState) -> (CpuColumnsView, u8) F::from_canonical_u32((state.registers.gas_used >> 32) as u32), ]; row.stack_len = F::from_canonical_usize(state.registers.stack_len); + fill_channel_with_value(&mut row, 0, state.registers.stack_top); let opcode = read_code_memory(state, &mut row); (row, opcode) @@ -264,6 +318,31 @@ fn try_perform_instruction(state: &mut GenerationState) -> Result<( fill_op_flag(op, &mut row); + if state.registers.is_stack_top_read { + let channel = &mut row.mem_channels[0]; + channel.used = F::ONE; + channel.is_read = F::ONE; + channel.addr_context = F::from_canonical_usize(state.registers.context); + channel.addr_segment = F::from_canonical_usize(Segment::Stack as usize); + channel.addr_virtual = F::from_canonical_usize(state.registers.stack_len - 1); + + let address = MemoryAddress { + context: state.registers.context, + segment: Segment::Stack as usize, + virt: state.registers.stack_len - 1, + }; + + let mem_op = MemoryOp::new( + GeneralPurpose(0), + state.traces.clock(), + address, + MemoryOpKind::Read, + state.registers.stack_top, + ); + state.traces.push_memory(mem_op); + state.registers.is_stack_top_read = false; + } + if state.registers.is_kernel { row.stack_len_bounds_aux = F::ZERO; } else { @@ -277,6 +356,21 @@ fn try_perform_instruction(state: &mut GenerationState) -> Result<( } } + // Might write in general CPU columns when it shouldn't, but the correct values will + // overwrite these ones during the op generation. 
+ if let Some(special_len) = get_op_special_length(op) { + let special_len = F::from_canonical_usize(special_len); + let diff = row.stack_len - special_len; + if let Some(inv) = diff.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + state.registers.is_stack_top_read = true; + } + } else if let Some(inv) = row.stack_len.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } + perform_op(state, op, row) } diff --git a/evm/src/witness/util.rs b/evm/src/witness/util.rs index dbe4c0ede5..249703614b 100644 --- a/evm/src/witness/util.rs +++ b/evm/src/witness/util.rs @@ -1,6 +1,7 @@ use ethereum_types::U256; use plonky2::field::types::Field; +use super::memory::DUMMY_MEMOP; use crate::byte_packing::byte_packing_stark::BytePackingOp; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::keccak_util::keccakf_u8s; @@ -36,6 +37,10 @@ pub(crate) fn stack_peek( if i >= state.registers.stack_len { return Err(ProgramError::StackUnderflow); } + if i == 0 { + return Ok(state.registers.stack_top); + } + Ok(state.memory.get(MemoryAddress::new( state.registers.context, Segment::Stack, @@ -53,6 +58,77 @@ pub(crate) fn current_context_peek( state.memory.get(MemoryAddress::new(context, segment, virt)) } +pub(crate) fn fill_channel_with_value(row: &mut CpuColumnsView, n: usize, val: U256) { + let channel = &mut row.mem_channels[n]; + let val_limbs: [u64; 4] = val.0; + for (i, limb) in val_limbs.into_iter().enumerate() { + channel.value[2 * i] = F::from_canonical_u32(limb as u32); + channel.value[2 * i + 1] = F::from_canonical_u32((limb >> 32) as u32); + } +} + +/// Pushes without writing in memory. This happens in opcodes where a push immediately follows a pop. +/// The pushed value may be loaded in a memory channel, without creating a memory operation. 
+pub(crate) fn push_no_write( + state: &mut GenerationState, + row: &mut CpuColumnsView, + val: U256, + channel_opt: Option, +) { + state.registers.stack_top = val; + state.registers.stack_len += 1; + + if let Some(channel) = channel_opt { + let val_limbs: [u64; 4] = val.0; + + let channel = &mut row.mem_channels[channel]; + assert_eq!(channel.used, F::ZERO); + channel.used = F::ZERO; + channel.is_read = F::ZERO; + channel.addr_context = F::from_canonical_usize(0); + channel.addr_segment = F::from_canonical_usize(0); + channel.addr_virtual = F::from_canonical_usize(0); + for (i, limb) in val_limbs.into_iter().enumerate() { + channel.value[2 * i] = F::from_canonical_u32(limb as u32); + channel.value[2 * i + 1] = F::from_canonical_u32((limb >> 32) as u32); + } + } +} + +/// Pushes and (maybe) writes the previous stack top in memory. This happens in opcodes which only push. +pub(crate) fn push_with_write( + state: &mut GenerationState, + row: &mut CpuColumnsView, + val: U256, +) -> Result<(), ProgramError> { + if !state.registers.is_kernel && state.registers.stack_len >= MAX_USER_STACK_SIZE { + return Err(ProgramError::StackOverflow); + } + + let write = if state.registers.stack_len == 0 { + None + } else { + let address = MemoryAddress::new( + state.registers.context, + Segment::Stack, + state.registers.stack_len - 1, + ); + let res = mem_write_gp_log_and_fill( + NUM_GP_CHANNELS - 1, + address, + state, + row, + state.registers.stack_top, + ); + Some(res) + }; + push_no_write(state, row, val, None); + if let Some(log) = write { + state.traces.push_memory(log); + } + Ok(()) +} + pub(crate) fn mem_read_with_log( channel: MemoryChannel, address: MemoryAddress, @@ -146,6 +222,9 @@ pub(crate) fn mem_write_gp_log_and_fill( op } +// Channel 0 already contains the top of the stack. You only need to read +// from the second popped element. +// If the resulting stack isn't empty, update `stack_top`. 
pub(crate) fn stack_pop_with_log_and_fill( state: &mut GenerationState, row: &mut CpuColumnsView, @@ -154,39 +233,33 @@ pub(crate) fn stack_pop_with_log_and_fill( return Err(ProgramError::StackUnderflow); } + let new_stack_top = if state.registers.stack_len == N { + None + } else { + Some(stack_peek(state, N)?) + }; + let result = core::array::from_fn(|i| { - let address = MemoryAddress::new( - state.registers.context, - Segment::Stack, - state.registers.stack_len - 1 - i, - ); - mem_read_gp_with_log_and_fill(i, address, state, row) + if i == 0 { + (state.registers.stack_top, DUMMY_MEMOP) + } else { + let address = MemoryAddress::new( + state.registers.context, + Segment::Stack, + state.registers.stack_len - 1 - i, + ); + + mem_read_gp_with_log_and_fill(i, address, state, row) + } }); state.registers.stack_len -= N; - Ok(result) -} - -pub(crate) fn stack_push_log_and_fill( - state: &mut GenerationState, - row: &mut CpuColumnsView, - val: U256, -) -> Result { - if !state.registers.is_kernel && state.registers.stack_len >= MAX_USER_STACK_SIZE { - return Err(ProgramError::StackOverflow); + if let Some(val) = new_stack_top { + state.registers.stack_top = val; } - let address = MemoryAddress::new( - state.registers.context, - Segment::Stack, - state.registers.stack_len, - ); - let res = mem_write_gp_log_and_fill(NUM_GP_CHANNELS - 1, address, state, row, val); - - state.registers.stack_len += 1; - - Ok(res) + Ok(result) } fn xor_into_sponge(