diff --git a/evm/Cargo.toml b/evm/Cargo.toml index e328aa0c97..3d217131f8 100644 --- a/evm/Cargo.toml +++ b/evm/Cargo.toml @@ -37,6 +37,7 @@ static_assertions = "1.1.0" hashbrown = { version = "0.14.0" } tiny-keccak = "2.0.2" serde_json = "1.0" +smt_utils = { git = "https://github.com/0xPolygonZero/smt_utils" } [target.'cfg(not(target_env = "msvc"))'.dependencies] jemallocator = "0.5.0" diff --git a/evm/spec/tables.tex b/evm/spec/tables.tex index 92ee1d2a54..43b45eb584 100644 --- a/evm/spec/tables.tex +++ b/evm/spec/tables.tex @@ -3,6 +3,7 @@ \section{Tables} \input{tables/cpu} \input{tables/arithmetic} +\input{tables/byte-packing} \input{tables/logic} \input{tables/memory} \input{tables/keccak-f} diff --git a/evm/spec/tables/byte-packing.tex b/evm/spec/tables/byte-packing.tex new file mode 100644 index 0000000000..4b78eea118 --- /dev/null +++ b/evm/spec/tables/byte-packing.tex @@ -0,0 +1,4 @@ +\subsection{Byte Packing} +\label{byte-packing} + +TODO diff --git a/evm/spec/zkevm.pdf b/evm/spec/zkevm.pdf index f181eba624..b6f3d90599 100644 Binary files a/evm/spec/zkevm.pdf and b/evm/spec/zkevm.pdf differ diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index 079ff114c4..50d3268883 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -23,6 +23,7 @@ use crate::memory::memory_stark; use crate::memory::memory_stark::MemoryStark; use crate::stark::Stark; +/// Structure containing all STARKs and the cross-table lookups. #[derive(Clone)] pub struct AllStark<F: RichField + Extendable<D>, const D: usize> { pub arithmetic_stark: ArithmeticStark<F, D>, @@ -36,6 +37,7 @@ pub struct AllStark<F: RichField + Extendable<D>, const D: usize> { } impl<F: RichField + Extendable<D>, const D: usize> Default for AllStark<F, D> { + /// Returns an `AllStark` containing all the STARKs initialized with default values. fn default() -> Self { Self { arithmetic_stark: ArithmeticStark::default(), @@ -64,6 +66,7 @@ impl<F: RichField + Extendable<D>, const D: usize> AllStark<F, D> { } } +/// Associates STARK tables with a unique index. #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum Table { Arithmetic = 0, @@ -75,9 +78,11 @@ pub enum Table { Memory = 6, } +/// Number of STARK tables. pub(crate) const NUM_TABLES: usize = Table::Memory as usize + 1; impl Table { + /// Returns all STARK table indices. pub(crate) fn all() -> [Self; NUM_TABLES] { [ Self::Arithmetic, @@ -91,6 +96,7 @@ impl Table { } } +/// Returns all the `CrossTableLookups` used for proving the EVM. pub(crate) fn all_cross_table_lookups<F: Field>() -> Vec<CrossTableLookup<F>> { vec![ ctl_arithmetic(), @@ -103,6 +109,7 @@ pub(crate) fn all_cross_table_lookups<F: Field>() -> Vec<CrossTableLookup<F>> { } +/// `CrossTableLookup` for `ArithmeticStark`, to connect it with the `Cpu` module. fn ctl_arithmetic<F: Field>() -> CrossTableLookup<F> { CrossTableLookup::new( vec![cpu_stark::ctl_arithmetic_base_rows()], @@ -110,6 +117,7 @@ fn ctl_arithmetic<F: Field>() -> CrossTableLookup<F> { ) } +/// `CrossTableLookup` for `BytePackingStark`, to connect it with the `Cpu` module. fn ctl_byte_packing<F: Field>() -> CrossTableLookup<F> { let cpu_packing_looking = TableWithColumns::new( Table::Cpu, @@ -132,9 +140,9 @@ fn ctl_byte_packing<F: Field>() -> CrossTableLookup<F> { ) } -// We now need two different looked tables for `KeccakStark`: -// one for the inputs and one for the outputs. -// They are linked with the timestamp. +/// `CrossTableLookup` for `KeccakStark` inputs, to connect it with the `KeccakSponge` module. +/// `KeccakSpongeStark` looks into `KeccakStark` to supply the inputs of the sponge. +/// Its consistency with the 'output' CTL is ensured through a timestamp column on the `KeccakStark` side.
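The invariant these cross-table lookups enforce can be pictured as a multiset-equality check over filtered rows. The sketch below is illustrative only: none of these types exist in the codebase, and real CTLs are enforced with running products or logUp over committed columns. It only models why sharing a timestamp between the inputs and outputs CTLs ties both to the same permutation call.

use std::collections::HashMap;

// Hypothetical row digest: (timestamp, extracted column values).
type Row = (u64, Vec<u32>);

fn multiset(rows: &[Row]) -> HashMap<Row, usize> {
    let mut counts = HashMap::new();
    for row in rows {
        *counts.entry(row.clone()).or_insert(0) += 1;
    }
    counts
}

/// A CTL "holds" iff the rows selected by the looking tables' filters
/// agree, as a multiset, with the rows exposed by the looked table.
fn ctl_holds(looking: &[Row], looked: &[Row]) -> bool {
    multiset(looking) == multiset(looked)
}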
fn ctl_keccak_inputs<F: Field>() -> CrossTableLookup<F> { let keccak_sponge_looking = TableWithColumns::new( Table::KeccakSponge, @@ -149,6 +157,8 @@ fn ctl_keccak_inputs<F: Field>() -> CrossTableLookup<F> { CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked) } +/// `CrossTableLookup` for `KeccakStark` outputs, to connect it with the `KeccakSponge` module. +/// `KeccakSpongeStark` looks into `KeccakStark` to retrieve the outputs of the sponge. fn ctl_keccak_outputs<F: Field>() -> CrossTableLookup<F> { let keccak_sponge_looking = TableWithColumns::new( Table::KeccakSponge, @@ -163,6 +173,7 @@ fn ctl_keccak_outputs<F: Field>() -> CrossTableLookup<F> { CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked) } +/// `CrossTableLookup` for `KeccakSpongeStark` to connect it with the `Cpu` module. fn ctl_keccak_sponge<F: Field>() -> CrossTableLookup<F> { let cpu_looking = TableWithColumns::new( Table::Cpu, @@ -177,6 +188,7 @@ fn ctl_keccak_sponge<F: Field>() -> CrossTableLookup<F> { CrossTableLookup::new(vec![cpu_looking], keccak_sponge_looked) } +/// `CrossTableLookup` for `LogicStark` to connect it with the `Cpu` and `KeccakSponge` modules. fn ctl_logic<F: Field>() -> CrossTableLookup<F> { let cpu_looking = TableWithColumns::new( Table::Cpu, @@ -197,6 +209,7 @@ fn ctl_logic<F: Field>() -> CrossTableLookup<F> { CrossTableLookup::new(all_lookers, logic_looked) } +/// `CrossTableLookup` for `MemoryStark` to connect it with all the modules which need memory accesses. fn ctl_memory<F: Field>() -> CrossTableLookup<F> { let cpu_memory_code_read = TableWithColumns::new( Table::Cpu, diff --git a/evm/src/arithmetic/addcy.rs b/evm/src/arithmetic/addcy.rs index 3366e432ae..4f343b45d5 100644 --- a/evm/src/arithmetic/addcy.rs +++ b/evm/src/arithmetic/addcy.rs @@ -149,7 +149,7 @@ pub(crate) fn eval_packed_generic_addcy<P: PackedField>( } } -pub fn eval_packed_generic<P: PackedField>( +pub(crate) fn eval_packed_generic<P: PackedField>( lv: &[P; NUM_ARITH_COLUMNS], yield_constr: &mut ConstraintConsumer
<P>
, ) { @@ -236,7 +236,7 @@ pub(crate) fn eval_ext_circuit_addcy, const D: usiz } } -pub fn eval_ext_circuit, const D: usize>( +pub(crate) fn eval_ext_circuit, const D: usize>( builder: &mut CircuitBuilder, lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], yield_constr: &mut RecursiveConstraintConsumer, diff --git a/evm/src/arithmetic/arithmetic_stark.rs b/evm/src/arithmetic/arithmetic_stark.rs index 3d281c868c..21dcf91985 100644 --- a/evm/src/arithmetic/arithmetic_stark.rs +++ b/evm/src/arithmetic/arithmetic_stark.rs @@ -22,8 +22,8 @@ use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::lookup::Lookup; use crate::stark::Stark; -/// Link the 16-bit columns of the arithmetic table, split into groups -/// of N_LIMBS at a time in `regs`, with the corresponding 32-bit +/// Creates a vector of `Columns` to link the 16-bit columns of the arithmetic table, +/// split into groups of N_LIMBS at a time in `regs`, with the corresponding 32-bit /// columns of the CPU table. Does this for all ops in `ops`. /// /// This is done by taking pairs of columns (x, y) of the arithmetic @@ -57,7 +57,8 @@ fn cpu_arith_data_link( res } -pub fn ctl_arithmetic_rows() -> TableWithColumns { +/// Returns the `TableWithColumns` for `ArithmeticStark` rows where one of the arithmetic operations has been called. +pub(crate) fn ctl_arithmetic_rows() -> TableWithColumns { // We scale each filter flag with the associated opcode value. // If an arithmetic operation is happening on the CPU side, // the CTL will enforce that the reconstructed opcode value @@ -102,6 +103,7 @@ pub fn ctl_arithmetic_rows() -> TableWithColumns { ) } +/// Structure representing the `Arithmetic` STARK, which carries out all the arithmetic operations. #[derive(Copy, Clone, Default)] pub struct ArithmeticStark { pub f: PhantomData, @@ -204,11 +206,17 @@ impl, const D: usize> Stark for ArithmeticSta let range_max = P::Scalar::from_canonical_u64((RANGE_MAX - 1) as u64); yield_constr.constraint_last_row(rc1 - range_max); + // Evaluate constraints for the MUL operation. mul::eval_packed_generic(lv, yield_constr); + // Evaluate constraints for ADD, SUB, LT and GT operations. addcy::eval_packed_generic(lv, yield_constr); + // Evaluate constraints for DIV and MOD operations. divmod::eval_packed(lv, nv, yield_constr); + // Evaluate constraints for ADDMOD, SUBMOD, MULMOD and for FP254 modular operations. modular::eval_packed(lv, nv, yield_constr); + // Evaluate constraints for the BYTE operation. byte::eval_packed(lv, yield_constr); + // Evaluate constraints for SHL and SHR operations. shift::eval_packed_generic(lv, nv, yield_constr); } @@ -223,6 +231,9 @@ impl, const D: usize> Stark for ArithmeticSta let nv: &[ExtensionTarget; NUM_ARITH_COLUMNS] = vars.get_next_values().try_into().unwrap(); + // Check the range column: First value must be 0, last row + // must be 2^16-1, and intermediate rows must increment by 0 + // or 1. let rc1 = lv[columns::RANGE_COUNTER]; let rc2 = nv[columns::RANGE_COUNTER]; yield_constr.constraint_first_row(builder, rc1); @@ -234,11 +245,17 @@ impl, const D: usize> Stark for ArithmeticSta let t = builder.sub_extension(rc1, range_max); yield_constr.constraint_last_row(builder, t); + // Evaluate constraints for the MUL operation. mul::eval_ext_circuit(builder, lv, yield_constr); + // Evaluate constraints for ADD, SUB, LT and GT operations. addcy::eval_ext_circuit(builder, lv, yield_constr); + // Evaluate constraints for DIV and MOD operations. 
divmod::eval_ext_circuit(builder, lv, nv, yield_constr); + // Evaluate constraints for ADDMOD, SUBMOD, MULMOD and for FP254 modular operations. modular::eval_ext_circuit(builder, lv, nv, yield_constr); + // Evaluate constraints for the BYTE operation. byte::eval_ext_circuit(builder, lv, yield_constr); + // Evaluate constraints for SHL and SHR operations. shift::eval_ext_circuit(builder, lv, nv, yield_constr); } diff --git a/evm/src/arithmetic/byte.rs b/evm/src/arithmetic/byte.rs index bb8cd12122..b7381ae0fa 100644 --- a/evm/src/arithmetic/byte.rs +++ b/evm/src/arithmetic/byte.rs @@ -197,7 +197,7 @@ pub(crate) fn generate(lv: &mut [F], idx: U256, val: U256) { ); } -pub fn eval_packed( +pub(crate) fn eval_packed( lv: &[P; NUM_ARITH_COLUMNS], yield_constr: &mut ConstraintConsumer
<P>
, ) { @@ -293,7 +293,7 @@ pub fn eval_packed( } } -pub fn eval_ext_circuit, const D: usize>( +pub(crate) fn eval_ext_circuit, const D: usize>( builder: &mut CircuitBuilder, lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], yield_constr: &mut RecursiveConstraintConsumer, @@ -306,6 +306,7 @@ pub fn eval_ext_circuit, const D: usize>( let idx_decomp = &lv[AUX_INPUT_REGISTER_0]; let tree = &lv[AUX_INPUT_REGISTER_1]; + // low 5 bits of the first limb of idx: let mut idx0_lo5 = builder.zero_extension(); for i in 0..5 { let bit = idx_decomp[i]; @@ -316,6 +317,9 @@ pub fn eval_ext_circuit, const D: usize>( let scale = builder.constant_extension(scale); idx0_lo5 = builder.mul_add_extension(bit, scale, idx0_lo5); } + // Verify that idx0_hi is the high (11) bits of the first limb of + // idx (in particular idx0_hi is at most 11 bits, since idx[0] is + // at most 16 bits). let t = F::Extension::from(F::from_canonical_u64(32)); let t = builder.constant_extension(t); let t = builder.mul_add_extension(idx_decomp[5], t, idx0_lo5); @@ -323,6 +327,9 @@ pub fn eval_ext_circuit, const D: usize>( let t = builder.mul_extension(is_byte, t); yield_constr.constraint(builder, t); + // Verify the layers of the tree + // NB: Each of the bit values is negated in place to account for + // the reversed indexing. let one = builder.one_extension(); let bit = idx_decomp[4]; for i in 0..8 { @@ -362,6 +369,8 @@ pub fn eval_ext_circuit, const D: usize>( let t = builder.mul_extension(is_byte, t); yield_constr.constraint(builder, t); + // Check byte decomposition of last limb: + let base8 = F::Extension::from(F::from_canonical_u64(1 << 8)); let base8 = builder.constant_extension(base8); let lo_byte = lv[BYTE_LAST_LIMB_LO]; @@ -380,19 +389,29 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr.constraint(builder, t); let expected_out_byte = tree[15]; + // Sum all higher limbs; sum will be non-zero iff idx >= 32. let mut hi_limb_sum = lv[BYTE_IDX_DECOMP_HI]; for i in 1..N_LIMBS { hi_limb_sum = builder.add_extension(hi_limb_sum, idx[i]); } + // idx_is_large is 0 or 1 let idx_is_large = lv[BYTE_IDX_IS_LARGE]; let t = builder.mul_sub_extension(idx_is_large, idx_is_large, idx_is_large); let t = builder.mul_extension(is_byte, t); yield_constr.constraint(builder, t); + // If hi_limb_sum is nonzero, then idx_is_large must be one. let t = builder.sub_extension(idx_is_large, one); let t = builder.mul_many_extension([is_byte, hi_limb_sum, t]); yield_constr.constraint(builder, t); + // If idx_is_large is 1, then hi_limb_sum_inv must be the inverse + // of hi_limb_sum, hence hi_limb_sum is non-zero, hence idx is + // indeed "large". + // + // Otherwise, if idx_is_large is 0, then hi_limb_sum * hi_limb_sum_inv + // is zero, which is only possible if hi_limb_sum is zero, since + // hi_limb_sum_inv is non-zero. let base16 = F::from_canonical_u64(1 << 16); let hi_limb_sum_inv = builder.mul_const_add_extension( base16, @@ -414,6 +433,7 @@ pub fn eval_ext_circuit, const D: usize>( let t = builder.mul_extension(is_byte, check); yield_constr.constraint(builder, t); + // Check that the rest of the output limbs are zero for i in 1..N_LIMBS { let t = builder.mul_extension(is_byte, out[i]); yield_constr.constraint(builder, t); diff --git a/evm/src/arithmetic/columns.rs b/evm/src/arithmetic/columns.rs index df2d12476b..aa36b3ab71 100644 --- a/evm/src/arithmetic/columns.rs +++ b/evm/src/arithmetic/columns.rs @@ -109,4 +109,5 @@ pub(crate) const RANGE_COUNTER: usize = START_SHARED_COLS + NUM_SHARED_COLS; /// The frequencies column used in logUp. 
pub(crate) const RC_FREQUENCIES: usize = RANGE_COUNTER + 1; +/// Number of columns in `ArithmeticStark`. pub const NUM_ARITH_COLUMNS: usize = START_SHARED_COLS + NUM_SHARED_COLS + 2; diff --git a/evm/src/arithmetic/divmod.rs b/evm/src/arithmetic/divmod.rs index e143ded6dd..b284c829c5 100644 --- a/evm/src/arithmetic/divmod.rs +++ b/evm/src/arithmetic/divmod.rs @@ -1,3 +1,7 @@ +//! Support for EVM instructions DIV and MOD. +//! +//! The logic for verifying them is detailed in the `modular` submodule. + use std::ops::Range; use ethereum_types::U256; diff --git a/evm/src/arithmetic/mod.rs b/evm/src/arithmetic/mod.rs index 7763e98a06..2699ee51c4 100644 --- a/evm/src/arithmetic/mod.rs +++ b/evm/src/arithmetic/mod.rs @@ -15,6 +15,9 @@ mod utils; pub mod arithmetic_stark; pub(crate) mod columns; +/// An enum representing different binary operations. +/// +/// `Shl` and `Shr` are handled differently, by leveraging `Mul` and `Div` respectively. #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum BinaryOperator { Add, @@ -33,6 +36,7 @@ pub(crate) enum BinaryOperator { } impl BinaryOperator { + /// Computes the result of a binary arithmetic operation given two inputs. pub(crate) fn result(&self, input0: U256, input1: U256) -> U256 { match self { BinaryOperator::Add => input0.overflowing_add(input1).0, @@ -81,6 +85,7 @@ impl BinaryOperator { } } + /// Maps a binary arithmetic operation to its associated flag column in the trace. pub(crate) fn row_filter(&self) -> usize { match self { BinaryOperator::Add => columns::IS_ADD, @@ -100,6 +105,7 @@ impl BinaryOperator { } } +/// An enum representing different ternary operations. #[allow(clippy::enum_variant_names)] #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum TernaryOperator { @@ -109,6 +115,7 @@ pub(crate) enum TernaryOperator { } impl TernaryOperator { + /// Computes the result of a ternary arithmetic operation given three inputs. pub(crate) fn result(&self, input0: U256, input1: U256, input2: U256) -> U256 { match self { TernaryOperator::AddMod => addmod(input0, input1, input2), @@ -117,6 +124,7 @@ impl TernaryOperator { } } + /// Maps a ternary arithmetic operation to its associated flag column in the trace. pub(crate) fn row_filter(&self) -> usize { match self { TernaryOperator::AddMod => columns::IS_ADDMOD, @@ -145,7 +153,7 @@ pub(crate) enum Operation { } impl Operation { - /// Create a binary operator with given inputs. + /// Creates a binary operator with given inputs. /// /// NB: This works as you would expect, EXCEPT for SHL and SHR, /// whose inputs need a small amount of preprocessing. Specifically, @@ -170,6 +178,7 @@ impl Operation { } } + /// Creates a ternary operator with given inputs. pub(crate) fn ternary( operator: TernaryOperator, input0: U256, @@ -186,6 +195,7 @@ impl Operation { } } + /// Gets the result of an arithmetic operation. pub(crate) fn result(&self) -> U256 { match self { Operation::BinaryOperation { result, .. } => *result, @@ -222,6 +232,7 @@ impl Operation { } } +/// Converts a ternary arithmetic operation to one or two rows of the `ArithmeticStark` table. fn ternary_op_to_rows( row_filter: usize, input0: U256, @@ -239,6 +250,7 @@ fn ternary_op_to_rows( (row1, Some(row2)) } +/// Converts a binary arithmetic operation to one or two rows of the `ArithmeticStark` table. 
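As a concrete reading of `BinaryOperator::result` above: EVM arithmetic wraps modulo 2^256, which is why the implementation keeps only the `.0` of `overflowing_add`. A minimal, self-contained sketch of that dispatch shape, using a trimmed-down stand-in rather than the crate's enum:

use ethereum_types::U256;

// Two-variant stand-in for the crate's `BinaryOperator`.
enum Op {
    Add,
    Mul,
}

impl Op {
    fn result(&self, a: U256, b: U256) -> U256 {
        match self {
            // EVM semantics: wrap modulo 2^256 and discard the overflow flag.
            Op::Add => a.overflowing_add(b).0,
            Op::Mul => a.overflowing_mul(b).0,
        }
    }
}

fn main() {
    // (2^256 - 1) + 1 wraps to 0.
    assert_eq!(Op::Add.result(U256::max_value(), U256::one()), U256::zero());
}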
fn binary_op_to_rows( op: BinaryOperator, input0: U256, diff --git a/evm/src/arithmetic/modular.rs b/evm/src/arithmetic/modular.rs index 4e6e21a632..05f9cf0da7 100644 --- a/evm/src/arithmetic/modular.rs +++ b/evm/src/arithmetic/modular.rs @@ -1,5 +1,5 @@ -//! Support for the EVM modular instructions ADDMOD, MULMOD and MOD, -//! as well as DIV. +//! Support for the EVM modular instructions ADDMOD, SUBMOD, MULMOD and MOD, +//! as well as DIV and FP254 related modular instructions. //! //! This crate verifies an EVM modular instruction, which takes three //! 256-bit inputs A, B and M, and produces a 256-bit output C satisfying @@ -478,7 +478,7 @@ pub(crate) fn modular_constr_poly<P: PackedField>( let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS); let offset = P::Scalar::from_canonical_u64(AUX_COEFF_ABS_MAX as u64); // constr_poly = c(x) + q(x) * m(x) + (x - β) * s(x) let mut aux = [P::ZEROS; 2 * N_LIMBS]; for (c, i) in aux.iter_mut().zip(MODULAR_AUX_INPUT_LO) { // MODULAR_AUX_INPUT elements were offset by 2^20 in @@ -625,10 +625,13 @@ pub(crate) fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, cons ) -> [ExtensionTarget<D>; 2 * N_LIMBS] { let mod_is_zero = nv[MODULAR_MOD_IS_ZERO]; + // Check that mod_is_zero is zero or one let t = builder.mul_sub_extension(mod_is_zero, mod_is_zero, mod_is_zero); let t = builder.mul_extension(filter, t); yield_constr.constraint_transition(builder, t); + // Check that mod_is_zero is zero if modulus is not zero (they + // could both be zero) let limb_sum = builder.add_many_extension(modulus); let t = builder.mul_extension(limb_sum, mod_is_zero); let t = builder.mul_extension(filter, t); @@ -636,13 +639,19 @@ pub(crate) fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, cons modulus[0] = builder.add_extension(modulus[0], mod_is_zero); + // Is 1 iff the operation is DIV or SHR and the denominator is zero. let div_denom_is_zero = nv[MODULAR_DIV_DENOM_IS_ZERO]; let div_shr_filter = builder.add_extension(lv[IS_DIV], lv[IS_SHR]); let t = builder.mul_sub_extension(mod_is_zero, div_shr_filter, div_denom_is_zero); let t = builder.mul_extension(filter, t); yield_constr.constraint_transition(builder, t); + + // Needed to compensate for adding mod_is_zero to modulus above, + // since the call to eval_packed_generic_addcy() below subtracts modulus + // when verifying a DIV or SHR. output[0] = builder.add_extension(output[0], div_denom_is_zero); + // Verify that the output is reduced, i.e. output < modulus.
let out_aux_red = &nv[MODULAR_OUT_AUX_RED]; let one = builder.one_extension(); let zero = builder.zero_extension(); @@ -660,24 +669,31 @@ pub(crate) fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, cons &is_less_than, true, ); + // restore output[0] output[0] = builder.sub_extension(output[0], div_denom_is_zero); + // prod = q(x) * m(x) let prod = pol_mul_wide2_ext_circuit(builder, quot, modulus); + // higher order terms must be zero for &x in prod[2 * N_LIMBS..].iter() { let t = builder.mul_extension(filter, x); yield_constr.constraint_transition(builder, t); } + // constr_poly = c(x) + q(x) * m(x) let mut constr_poly: [_; 2 * N_LIMBS] = prod[0..2 * N_LIMBS].try_into().unwrap(); pol_add_assign_ext_circuit(builder, &mut constr_poly, &output); let offset = builder.constant_extension(F::Extension::from_canonical_u64(AUX_COEFF_ABS_MAX as u64)); let zero = builder.zero_extension(); + + // constr_poly = c(x) + q(x) * m(x) + (x - β) * s(x) let mut aux = [zero; 2 * N_LIMBS]; for (c, i) in aux.iter_mut().zip(MODULAR_AUX_INPUT_LO) { *c = builder.sub_extension(nv[i], offset); } + // add high 16-bits of aux input let base = F::from_canonical_u64(1u64 << LIMB_BITS); for (c, j) in aux.iter_mut().zip(MODULAR_AUX_INPUT_HI) { *c = builder.mul_const_add_extension(base, nv[j], *c); } @@ -700,10 +716,13 @@ pub(crate) fn submod_constr_poly_ext_circuit<F: RichField + Extendable<D>, const modulus: [ExtensionTarget<D>; N_LIMBS], mut quot: [ExtensionTarget<D>; 2 * N_LIMBS], ) -> [ExtensionTarget<D>; 2 * N_LIMBS] { + // quot was offset by 2^16 - 1 if it was negative; we undo that + // offset here: let (lo, hi) = quot.split_at_mut(N_LIMBS); let sign = hi[0]; let t = builder.mul_sub_extension(sign, sign, sign); let t = builder.mul_extension(filter, t); + // sign must be 1 (negative) or 0 (positive) yield_constr.constraint(builder, t); let offset = F::from_canonical_u16(u16::max_value()); for c in lo { @@ -712,6 +731,7 @@ pub(crate) fn submod_constr_poly_ext_circuit<F: RichField + Extendable<D>, const } hi[0] = builder.zero_extension(); for d in hi { + // All higher limbs must be zero let t = builder.mul_extension(filter, *d); yield_constr.constraint(builder, t); } @@ -737,8 +757,12 @@ pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>( bn254_filter, ]); + // Ensure that this operation is not the last row of the table; + // needed because we access the next row of the table in nv. yield_constr.constraint_last_row(builder, filter); + // Verify that the modulus is the BN254 modulus for the + // {ADD,MUL,SUB}FP254 operations. let modulus = read_value::<N_LIMBS, _>(lv, MODULAR_MODULUS); for (&mi, bi) in modulus.iter().zip(bn254_modulus_limbs()) { // bn254_filter * (mi - bi) @@ -760,6 +784,7 @@ pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>( let mul_filter = builder.add_extension(lv[columns::IS_MULMOD], lv[columns::IS_MULFP254]); let addmul_filter = builder.add_extension(add_filter, mul_filter); + // constr_poly has 2*N_LIMBS limbs let submod_constr_poly = submod_constr_poly_ext_circuit( lv, nv, diff --git a/evm/src/arithmetic/mul.rs b/evm/src/arithmetic/mul.rs index c09c39d8dc..01c9d5c1c0 100644 --- a/evm/src/arithmetic/mul.rs +++ b/evm/src/arithmetic/mul.rs @@ -107,7 +107,7 @@ pub(crate) fn generate_mul<F: PrimeField64>(lv: &mut [F], left_in: [i64; 16], ri .copy_from_slice(&aux_limbs.map(|c| F::from_canonical_u16((c >> 16) as u16))); } -pub fn generate<F: PrimeField64>(lv: &mut [F], left_in: U256, right_in: U256) { +pub(crate) fn generate<F: PrimeField64>(lv: &mut [F], left_in: U256, right_in: U256) { // TODO: It would probably be clearer/cleaner to read the U256 // into an [i64;N] and then copy that to the lv table.
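The `constr_poly` comments in `modular.rs` and `mul.rs` rest on one fact: an integer-coefficient polynomial that evaluates to 0 at x = β is divisible by (x - β), and the quotient is exactly the witness s(x) stored (with an offset) in the auxiliary columns. A toy, self-contained check with 2-limb inputs in base β = 2^16 (the real table uses 16 limbs, and the quotient handling there is more careful than this):

// Toy instance: A = 1 + 2β, B = 3 + β, M = 5.
// C = A*B mod M and Q = (A*B - C) / M, so d(β) = A*B - C - Q*M = 0,
// and synthetic division of d(x) by (x - β) must leave remainder 0;
// the quotient coefficients play the role of the witness s(x).
fn eval(p: &[i64], x: i64) -> i64 {
    p.iter().rev().fold(0, |acc, &c| acc * x + c)
}

fn mul_poly(a: &[i64], b: &[i64]) -> Vec<i64> {
    let mut out = vec![0i64; a.len() + b.len() - 1];
    for (i, &x) in a.iter().enumerate() {
        for (j, &y) in b.iter().enumerate() {
            out[i + j] += x * y;
        }
    }
    out
}

fn main() {
    let beta = 1i64 << 16;
    let (a, b, m) = (vec![1i64, 2], vec![3i64, 1], vec![5i64]);
    let (av, bv, mv) = (eval(&a, beta), eval(&b, beta), eval(&m, beta));
    let (cv, qv) = (av * bv % mv, av * bv / mv);

    // d(x) = a(x)b(x) - c(x) - q(x)m(x), with c and q in 16-bit limbs.
    let mut d = mul_poly(&a, &b);
    d[0] -= cv;
    let q_limbs = vec![qv & 0xffff, (qv >> 16) & 0xffff, qv >> 32];
    for (i, coeff) in mul_poly(&q_limbs, &m).into_iter().enumerate() {
        d[i] -= coeff;
    }
    assert_eq!(eval(&d, beta), 0);

    // Synthetic division by (x - β): the successive values of r are the
    // coefficients of s(x) (highest degree first); the last one is d(β) = 0.
    let mut r = 0i64;
    for &coeff in d.iter().rev() {
        r = r * beta + coeff;
    }
    assert_eq!(r, 0);
}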
u256_to_array(&mut lv[INPUT_REGISTER_0], left_in); @@ -173,7 +173,7 @@ pub(crate) fn eval_packed_generic_mul( } } -pub fn eval_packed_generic( +pub(crate) fn eval_packed_generic( lv: &[P; NUM_ARITH_COLUMNS], yield_constr: &mut ConstraintConsumer
<P>
, ) { @@ -195,6 +195,8 @@ pub(crate) fn eval_ext_mul_circuit, const D: usize> let output_limbs = read_value::(lv, OUTPUT_REGISTER); let aux_limbs = { + // MUL_AUX_INPUT was offset by 2^20 in generation, so we undo + // that here let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << LIMB_BITS)); let offset = builder.constant_extension(F::Extension::from_canonical_u64(AUX_COEFF_ABS_MAX as u64)); @@ -211,17 +213,22 @@ pub(crate) fn eval_ext_mul_circuit, const D: usize> let mut constr_poly = pol_mul_lo_ext_circuit(builder, left_in_limbs, right_in_limbs); pol_sub_assign_ext_circuit(builder, &mut constr_poly, &output_limbs); + // This subtracts (x - β) * s(x) from constr_poly. let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << LIMB_BITS)); let rhs = pol_adjoin_root_ext_circuit(builder, aux_limbs, base); pol_sub_assign_ext_circuit(builder, &mut constr_poly, &rhs); + // At this point constr_poly holds the coefficients of the + // polynomial a(x)b(x) - c(x) - (x - β)*s(x). The + // multiplication is valid if and only if all of those + // coefficients are zero. for &c in &constr_poly { let filter = builder.mul_extension(filter, c); yield_constr.constraint(builder, filter); } } -pub fn eval_ext_circuit, const D: usize>( +pub(crate) fn eval_ext_circuit, const D: usize>( builder: &mut CircuitBuilder, lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], yield_constr: &mut RecursiveConstraintConsumer, diff --git a/evm/src/arithmetic/shift.rs b/evm/src/arithmetic/shift.rs index 6600c01e54..bb83798495 100644 --- a/evm/src/arithmetic/shift.rs +++ b/evm/src/arithmetic/shift.rs @@ -38,7 +38,7 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer /// NB: if `shift >= 256`, then the third register holds 0. /// We leverage the functions in mul.rs and divmod.rs to carry out /// the computation. -pub fn generate( +pub(crate) fn generate( lv: &mut [F], nv: &mut [F], is_shl: bool, @@ -117,7 +117,7 @@ fn eval_packed_shr( ); } -pub fn eval_packed_generic( +pub(crate) fn eval_packed_generic( lv: &[P; NUM_ARITH_COLUMNS], nv: &[P; NUM_ARITH_COLUMNS], yield_constr: &mut ConstraintConsumer
<P>
, @@ -168,7 +168,7 @@ fn eval_ext_circuit_shr<F: RichField + Extendable<D>, const D: usize>( ); } -pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>( +pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>( builder: &mut CircuitBuilder<F, D>, lv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS], nv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS], diff --git a/evm/src/arithmetic/utils.rs b/evm/src/arithmetic/utils.rs index 6ea375fef3..7fadb8f14d 100644 --- a/evm/src/arithmetic/utils.rs +++ b/evm/src/arithmetic/utils.rs @@ -319,6 +319,7 @@ pub(crate) fn read_value_i64_limbs( } #[inline] +/// Turns a 64-bit integer into 4 16-bit limbs and converts them to field elements. fn u64_to_array<F: Field>(out: &mut [F], x: u64) { const_assert!(LIMB_BITS == 16); debug_assert!(out.len() == 4); @@ -329,6 +330,7 @@ fn u64_to_array<F: Field>(out: &mut [F], x: u64) { out[3] = F::from_canonical_u16((x >> 48) as u16); } +/// Turns a 256-bit integer into 16 16-bit limbs and converts them to field elements. // TODO: Refactor/replace u256_limbs in evm/src/util.rs pub(crate) fn u256_to_array<F: Field>(out: &mut [F], x: U256) { const_assert!(N_LIMBS == 16); diff --git a/evm/src/byte_packing/byte_packing_stark.rs b/evm/src/byte_packing/byte_packing_stark.rs index c28b055a81..ad485c591e 100644 --- a/evm/src/byte_packing/byte_packing_stark.rs +++ b/evm/src/byte_packing/byte_packing_stark.rs @@ -59,6 +59,8 @@ use crate::witness::memory::MemoryAddress; /// Strict upper bound for the individual bytes range-check. const BYTE_RANGE_MAX: usize = 1usize << 8; +/// Creates the vector of `Columns` for `BytePackingStark` corresponding to the final packed limbs being read/written. +/// `CpuStark` will look into these columns, as the CPU needs the output of byte packing. pub(crate) fn ctl_looked_data<F: Field>() -> Vec<Column<F>> { // Reconstruct the u32 limbs composing the final `U256` word // being read/written from the underlying byte values. For each, @@ -88,12 +90,14 @@ pub(crate) fn ctl_looked_data<F: Field>() -> Vec<Column<F>> { .collect() } +/// CTL filter for the `BytePackingStark` looked table. pub fn ctl_looked_filter<F: Field>() -> Column<F> { // The CPU table is only interested in our sequence end rows, // since those contain the final limbs of our packed int. Column::single(SEQUENCE_END) } +/// Column linear combination for the `BytePackingStark` table reading/writing the `i`th byte sequence from `MemoryStark`. pub(crate) fn ctl_looking_memory<F: Field>(i: usize) -> Vec<Column<F>> { let mut res = Column::singles([IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]).collect_vec(); @@ -212,6 +216,8 @@ impl<F: RichField + Extendable<D>, const D: usize> BytePackingStark<F, D> { row[index_bytes(i)] = F::ONE; rows.push(row); + + // Update those fields for the next row row[index_bytes(i)] = F::ZERO; row[ADDR_VIRTUAL] -= F::ONE; } diff --git a/evm/src/byte_packing/columns.rs b/evm/src/byte_packing/columns.rs index 4eff0df8f5..16bc4be5fa 100644 --- a/evm/src/byte_packing/columns.rs +++ b/evm/src/byte_packing/columns.rs @@ -11,6 +11,8 @@ pub(crate) const IS_READ: usize = 0; pub(crate) const SEQUENCE_END: usize = IS_READ + 1; pub(super) const BYTES_INDICES_START: usize = SEQUENCE_END + 1; +// There are `NUM_BYTES` columns used to represent the index of the active byte +// for a given row of a byte (un)packing operation. pub(crate) const fn index_bytes(i: usize) -> usize { debug_assert!(i < NUM_BYTES); BYTES_INDICES_START + i @@ -28,6 +30,9 @@ pub(crate) const TIMESTAMP: usize = ADDR_VIRTUAL + 1; // 32 byte limbs hold a total of 256 bits. const BYTES_VALUES_START: usize = TIMESTAMP + 1; +// There are `NUM_BYTES` columns used to store the values of the bytes +// that are being read/written for an (un)packing operation.
+// If `index_bytes(i) == 1`, then all `value_bytes(j) for j <= i` may be non-zero. pub(crate) const fn value_bytes(i: usize) -> usize { debug_assert!(i < NUM_BYTES); BYTES_VALUES_START + i @@ -38,4 +43,5 @@ pub(crate) const RANGE_COUNTER: usize = BYTES_VALUES_START + NUM_BYTES; /// The frequencies column used in logUp. pub(crate) const RC_FREQUENCIES: usize = RANGE_COUNTER + 1; +/// Number of columns in `BytePackingStark`. pub(crate) const NUM_COLUMNS: usize = RANGE_COUNTER + 2; diff --git a/evm/src/byte_packing/mod.rs b/evm/src/byte_packing/mod.rs index 7cc93374ca..3767b21ed6 100644 --- a/evm/src/byte_packing/mod.rs +++ b/evm/src/byte_packing/mod.rs @@ -6,4 +6,5 @@ pub mod byte_packing_stark; pub mod columns; +/// Maximum number of bytes being processed by a byte (un)packing operation. pub(crate) const NUM_BYTES: usize = 32; diff --git a/evm/src/cpu/bootstrap_kernel.rs b/evm/src/cpu/bootstrap_kernel.rs index 759c852aae..b04ff379dd 100644 --- a/evm/src/cpu/bootstrap_kernel.rs +++ b/evm/src/cpu/bootstrap_kernel.rs @@ -1,6 +1,7 @@ //! The initial phase of execution, where the kernel code is hashed while being written to memory. //! The hash is then checked against a precomputed kernel hash. +use ethereum_types::U256; use itertools::Itertools; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; @@ -18,6 +19,7 @@ use crate::memory::segments::Segment; use crate::witness::memory::MemoryAddress; use crate::witness::util::{keccak_sponge_log, mem_write_gp_log_and_fill}; +/// Generates the rows to bootstrap the kernel. pub(crate) fn generate_bootstrap_kernel(state: &mut GenerationState) { // Iterate through chunks of the code, such that we can write one chunk to memory per row. for chunk in &KERNEL.code.iter().enumerate().chunks(NUM_GP_CHANNELS) { @@ -52,10 +54,18 @@ pub(crate) fn generate_bootstrap_kernel(state: &mut GenerationState MemoryAddress::new(0, Segment::Code, 0), KERNEL.code.clone(), ); + state.registers.stack_top = KERNEL + .code_hash + .iter() + .enumerate() + .fold(0.into(), |acc, (i, &elt)| { + acc + (U256::from(elt) << (224 - 32 * i)) + }); state.traces.push_cpu(final_cpu_row); log::info!("Bootstrapping took {} cycles", state.traces.clock()); } +/// Evaluates the constraints for kernel bootstrapping. pub(crate) fn eval_bootstrap_kernel_packed>( local_values: &CpuColumnsView
<P>
, next_values: &CpuColumnsView
<P>
, @@ -99,6 +109,8 @@ pub(crate) fn eval_bootstrap_kernel_packed> } } +/// Circuit version of `eval_bootstrap_kernel_packed`. +/// Evaluates the constraints for kernel bootstrapping. pub(crate) fn eval_bootstrap_kernel_ext_circuit, const D: usize>( builder: &mut CircuitBuilder, local_values: &CpuColumnsView>, diff --git a/evm/src/cpu/columns/general.rs b/evm/src/cpu/columns/general.rs index d4f3447380..cfc24f51c6 100644 --- a/evm/src/cpu/columns/general.rs +++ b/evm/src/cpu/columns/general.rs @@ -14,52 +14,62 @@ pub(crate) union CpuGeneralColumnsView { } impl CpuGeneralColumnsView { - // SAFETY: Each view is a valid interpretation of the underlying array. + /// View of the columns used for exceptions: they are the exception code bits. + /// SAFETY: Each view is a valid interpretation of the underlying array. pub(crate) fn exception(&self) -> &CpuExceptionView { unsafe { &self.exception } } - // SAFETY: Each view is a valid interpretation of the underlying array. + /// Mutable view of the column required for exceptions: they are the exception code bits. + /// SAFETY: Each view is a valid interpretation of the underlying array. pub(crate) fn exception_mut(&mut self) -> &mut CpuExceptionView { unsafe { &mut self.exception } } - // SAFETY: Each view is a valid interpretation of the underlying array. + /// View of the columns required for logic operations. + /// SAFETY: Each view is a valid interpretation of the underlying array. pub(crate) fn logic(&self) -> &CpuLogicView { unsafe { &self.logic } } - // SAFETY: Each view is a valid interpretation of the underlying array. + /// Mutable view of the columns required for logic operations. + /// SAFETY: Each view is a valid interpretation of the underlying array. pub(crate) fn logic_mut(&mut self) -> &mut CpuLogicView { unsafe { &mut self.logic } } - // SAFETY: Each view is a valid interpretation of the underlying array. + /// View of the columns required for jump operations. + /// SAFETY: Each view is a valid interpretation of the underlying array. pub(crate) fn jumps(&self) -> &CpuJumpsView { unsafe { &self.jumps } } - // SAFETY: Each view is a valid interpretation of the underlying array. + /// Mutable view of the columns required for jump operations. + /// SAFETY: Each view is a valid interpretation of the underlying array. pub(crate) fn jumps_mut(&mut self) -> &mut CpuJumpsView { unsafe { &mut self.jumps } } - // SAFETY: Each view is a valid interpretation of the underlying array. + /// View of the columns required for shift operations. + /// SAFETY: Each view is a valid interpretation of the underlying array. pub(crate) fn shift(&self) -> &CpuShiftView { unsafe { &self.shift } } - // SAFETY: Each view is a valid interpretation of the underlying array. + /// Mutable view of the columns required for shift operations. + /// SAFETY: Each view is a valid interpretation of the underlying array. pub(crate) fn shift_mut(&mut self) -> &mut CpuShiftView { unsafe { &mut self.shift } } - // SAFETY: Each view is a valid interpretation of the underlying array. + /// View of the columns required for the stack top. + /// SAFETY: Each view is a valid interpretation of the underlying array. pub(crate) fn stack(&self) -> &CpuStackView { unsafe { &self.stack } } - // SAFETY: Each view is a valid interpretation of the underlying array. + /// Mutable view of the columns required for the stack top. + /// SAFETY: Each view is a valid interpretation of the underlying array. 
pub(crate) fn stack_mut(&mut self) -> &mut CpuStackView<T> { unsafe { &mut self.stack } } @@ -94,41 +104,51 @@ impl<T: Copy> BorrowMut<[T; NUM_SHARED_COLUMNS]> for CpuGeneralColumnsView<T> { } } +/// View of the first three `CpuGeneralColumns` containing exception code bits. #[derive(Copy, Clone)] pub(crate) struct CpuExceptionView<T> { - // Exception code as little-endian bits. + /// Exception code as little-endian bits. pub(crate) exc_code_bits: [T; 3], } +/// View of the `CpuGeneralColumns` storing pseudo-inverses used to prove logic operations. #[derive(Copy, Clone)] pub(crate) struct CpuLogicView<T> { - // Pseudoinverse of `(input0 - input1)`. Used prove that they are unequal. Assumes 32-bit limbs. + /// Pseudoinverse of `(input0 - input1)`. Used to prove that they are unequal. Assumes 32-bit limbs. pub(crate) diff_pinv: [T; 8], } +/// View of the first two `CpuGeneralColumns` storing a flag and a pseudoinverse used to prove jumps. #[derive(Copy, Clone)] pub(crate) struct CpuJumpsView<T> { - // A flag. + /// A flag indicating whether a jump should occur. pub(crate) should_jump: T, - // Pseudoinverse of `cond.iter().sum()`. Used to check `should_jump`. + /// Pseudoinverse of `cond.iter().sum()`. Used to check `should_jump`. pub(crate) cond_sum_pinv: T, } +/// View of the first `CpuGeneralColumns` storing a pseudoinverse used to prove shift operations. #[derive(Copy, Clone)] pub(crate) struct CpuShiftView<T> { - // For a shift amount of displacement: [T], this is the inverse of - // sum(displacement[1..]) or zero if the sum is zero. + /// For a shift amount of displacement: [T], this is the inverse of + /// sum(displacement[1..]) or zero if the sum is zero. pub(crate) high_limb_sum_inv: T, } +/// View of the last three `CpuGeneralColumns` storing the stack length pseudoinverse `stack_inv`, +/// stack_len * stack_inv and filter * stack_inv_aux when needed. #[derive(Copy, Clone)] pub(crate) struct CpuStackView<T> { // Used for conditionally enabling and disabling channels when reading the next `stack_top`. _unused: [T; 5], + /// Pseudoinverse of the stack len. pub(crate) stack_inv: T, + /// stack_inv * stack_len. pub(crate) stack_inv_aux: T, + /// Holds filter * stack_inv_aux when necessary, to reduce the degree of stack constraints. pub(crate) stack_inv_aux_2: T, } -// `u8` is guaranteed to have a `size_of` of 1. +/// Number of columns shared by all the views of `CpuGeneralColumnsView`. +/// `u8` is guaranteed to have a `size_of` of 1. pub const NUM_SHARED_COLUMNS: usize = size_of::<CpuGeneralColumnsView<u8>>(); diff --git a/evm/src/cpu/columns/mod.rs b/evm/src/cpu/columns/mod.rs index b7b4f780e0..f83147c2dc 100644 --- a/evm/src/cpu/columns/mod.rs +++ b/evm/src/cpu/columns/mod.rs @@ -12,23 +12,32 @@ use crate::memory; use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; mod general; +/// Cpu operation flags. pub(crate) mod ops; +/// 32-bit limbs of the value stored in the current memory channel. pub type MemValue<T> = [T; memory::VALUE_LIMBS]; +/// View of the columns required for one memory channel. #[repr(C)] #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct MemoryChannelView<T> { /// 1 if this row includes a memory operation in the `i`th channel of the memory bus, otherwise /// 0. pub used: T, + /// 1 if a read is performed on the `i`th channel of the memory bus, otherwise 0. pub is_read: T, + /// Context of the memory operation in the `i`th channel of the memory bus. pub addr_context: T, + /// Segment of the memory operation in the `ith` channel of the memory bus.
pub addr_segment: T, + /// Virtual address of the memory operation in the `ith` channel of the memory bus. pub addr_virtual: T, + /// Value, subdivided into 32-bit limbs, stored in the `ith` channel of the memory bus. pub value: MemValue, } +/// View of all the columns in `CpuStark`. #[repr(C)] #[derive(Clone, Copy, Eq, PartialEq, Debug)] pub struct CpuColumnsView { @@ -36,7 +45,6 @@ pub struct CpuColumnsView { pub is_bootstrap_kernel: T, /// If CPU cycle: Current context. - // TODO: this is currently unconstrained pub context: T, /// If CPU cycle: Context for code memory channel. @@ -68,13 +76,18 @@ pub struct CpuColumnsView { /// Filter. 1 iff a Keccak sponge lookup is performed on this row. pub is_keccak_sponge: T, + /// Columns shared by various operations. pub(crate) general: CpuGeneralColumnsView, + /// CPU clock. pub(crate) clock: T, + + /// Memory bus channels in the CPU. Each channel is comprised of 13 columns. pub mem_channels: [MemoryChannelView; NUM_GP_CHANNELS], } -// `u8` is guaranteed to have a `size_of` of 1. +/// Total number of columns in `CpuStark`. +/// `u8` is guaranteed to have a `size_of` of 1. pub const NUM_CPU_COLUMNS: usize = size_of::>(); impl Default for CpuColumnsView { @@ -146,4 +159,5 @@ const fn make_col_map() -> CpuColumnsView { unsafe { transmute::<[usize; NUM_CPU_COLUMNS], CpuColumnsView>(indices_arr) } } +/// Mapping between [0..NUM_CPU_COLUMNS-1] and the CPU columns. pub const COL_MAP: CpuColumnsView = make_col_map(); diff --git a/evm/src/cpu/columns/ops.rs b/evm/src/cpu/columns/ops.rs index 270b0ab871..282fa393e2 100644 --- a/evm/src/cpu/columns/ops.rs +++ b/evm/src/cpu/columns/ops.rs @@ -4,37 +4,55 @@ use std::ops::{Deref, DerefMut}; use crate::util::transmute_no_compile_time_size_checks; +/// Structure representing the flags for the various opcodes. #[repr(C)] #[derive(Clone, Copy, Eq, PartialEq, Debug)] pub struct OpsColumnsView { - pub binary_op: T, // Combines ADD, MUL, SUB, DIV, MOD, LT, GT and BYTE flags. - pub ternary_op: T, // Combines ADDMOD, MULMOD and SUBMOD flags. - pub fp254_op: T, // Combines ADD_FP254, MUL_FP254 and SUB_FP254 flags. - pub eq_iszero: T, // Combines EQ and ISZERO flags. - pub logic_op: T, // Combines AND, OR and XOR flags. - pub not: T, - pub shift: T, // Combines SHL and SHR flags. - pub keccak_general: T, + /// Combines ADD, MUL, SUB, DIV, MOD, LT, GT and BYTE flags. + pub binary_op: T, + /// Combines ADDMOD, MULMOD and SUBMOD flags. + pub ternary_op: T, + /// Combines ADD_FP254, MUL_FP254 and SUB_FP254 flags. + pub fp254_op: T, + /// Combines EQ and ISZERO flags. + pub eq_iszero: T, + /// Combines AND, OR and XOR flags. + pub logic_op: T, + /// Combines NOT and POP flags. + pub not_pop: T, + /// Combines SHL and SHR flags. + pub shift: T, + /// Combines JUMPDEST and KECCAK_GENERAL flags. + pub jumpdest_keccak_general: T, + /// Flag for PROVER_INPUT. pub prover_input: T, - pub pop: T, - pub jumps: T, // Combines JUMP and JUMPI flags. - pub pc: T, - pub jumpdest: T, - pub push0: T, + /// Combines JUMP and JUMPI flags. + pub jumps: T, + /// Flag for PUSH. pub push: T, + /// Combines DUP and SWAP flags. pub dup_swap: T, - pub get_context: T, - pub set_context: T, + /// Combines GET_CONTEXT and SET_CONTEXT flags. + pub context_op: T, + /// Flag for MSTORE_32BYTES. pub mstore_32bytes: T, + /// Flag for MLOAD_32BYTES. pub mload_32bytes: T, + /// Flag for EXIT_KERNEL. pub exit_kernel: T, + /// Combines MSTORE_GENERAL and MLOAD_GENERAL flags. 
pub m_op_general: T, + /// Combines PC and PUSH0 flags. + pub pc_push0: T, + /// Flag for syscalls. pub syscall: T, + /// Flag for exceptions. pub exception: T, } -// `u8` is guaranteed to have a `size_of` of 1. +/// Number of operation flag columns in `OpsColumnsView`. +/// `u8` is guaranteed to have a `size_of` of 1. pub const NUM_OPS_COLUMNS: usize = size_of::<OpsColumnsView<u8>>(); impl<T: Copy> From<[T; NUM_OPS_COLUMNS]> for OpsColumnsView<T> { diff --git a/evm/src/cpu/contextops.rs b/evm/src/cpu/contextops.rs index 1683c30e56..8e25d5a168 100644 --- a/evm/src/cpu/contextops.rs +++ b/evm/src/cpu/contextops.rs @@ -1,3 +1,4 @@ +use itertools::izip; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; use plonky2::field::types::Field; @@ -5,31 +6,114 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; +use super::columns::ops::OpsColumnsView; +use super::membus::NUM_GP_CHANNELS; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; use crate::memory::segments::Segment; +// If true, the instruction will keep the current context for the next row. +// If false, next row's context is handled manually. +const KEEPS_CONTEXT: OpsColumnsView<bool> = OpsColumnsView { + binary_op: true, + ternary_op: true, + fp254_op: true, + eq_iszero: true, + logic_op: true, + not_pop: true, + shift: true, + jumpdest_keccak_general: true, + prover_input: true, + jumps: true, + pc_push0: true, + push: true, + dup_swap: true, + context_op: false, + mstore_32bytes: true, + mload_32bytes: true, + exit_kernel: true, + m_op_general: true, + syscall: true, + exception: true, +}; + +fn eval_packed_keep<P: PackedField>( + lv: &CpuColumnsView
<P>
, + nv: &CpuColumnsView
<P>
, + yield_constr: &mut ConstraintConsumer
<P>
, +) { + for (op, keeps_context) in izip!(lv.op.into_iter(), KEEPS_CONTEXT.into_iter()) { + if keeps_context { + yield_constr.constraint_transition(op * (nv.context - lv.context)); + } + } + + // context_op is hybrid; we evaluate it separately. + let is_get_context = lv.op.context_op * (lv.opcode_bits[0] - P::ONES); + yield_constr.constraint_transition(is_get_context * (nv.context - lv.context)); +} + +fn eval_ext_circuit_keep, const D: usize>( + builder: &mut CircuitBuilder, + lv: &CpuColumnsView>, + nv: &CpuColumnsView>, + yield_constr: &mut RecursiveConstraintConsumer, +) { + for (op, keeps_context) in izip!(lv.op.into_iter(), KEEPS_CONTEXT.into_iter()) { + if keeps_context { + let diff = builder.sub_extension(nv.context, lv.context); + let constr = builder.mul_extension(op, diff); + yield_constr.constraint_transition(builder, constr); + } + } + + // context_op is hybrid; we evaluate it separately. + let is_get_context = + builder.mul_sub_extension(lv.op.context_op, lv.opcode_bits[0], lv.op.context_op); + let diff = builder.sub_extension(nv.context, lv.context); + let constr = builder.mul_extension(is_get_context, diff); + yield_constr.constraint_transition(builder, constr); +} + +/// Evaluates constraints for GET_CONTEXT. fn eval_packed_get( lv: &CpuColumnsView
<P>
, nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { - let filter = lv.op.get_context; + // If the opcode is GET_CONTEXT, then lv.opcode_bits[0] = 0. + let filter = lv.op.context_op * (P::ONES - lv.opcode_bits[0]); let new_stack_top = nv.mem_channels[0].value; yield_constr.constraint(filter * (new_stack_top[0] - lv.context)); for &limb in &new_stack_top[1..] { yield_constr.constraint(filter * limb); } + + // Constrain new stack length. + yield_constr.constraint(filter * (nv.stack_len - (lv.stack_len + P::ONES))); + + // Unused channels. + for i in 1..NUM_GP_CHANNELS { + if i != 3 { + let channel = lv.mem_channels[i]; + yield_constr.constraint(filter * channel.used); + } + } + yield_constr.constraint(filter * nv.mem_channels[0].used); } +/// Circuit version of `eval_packed_get`. +/// Evaluates constraints for GET_CONTEXT. fn eval_ext_circuit_get<F: RichField + Extendable<D>, const D: usize>( builder: &mut CircuitBuilder<F, D>, lv: &CpuColumnsView<ExtensionTarget<D>>, nv: &CpuColumnsView<ExtensionTarget<D>>, yield_constr: &mut RecursiveConstraintConsumer<F, D>, ) { - let filter = lv.op.get_context; + // If the opcode is GET_CONTEXT, then lv.opcode_bits[0] = 0. + let prod = builder.mul_extension(lv.op.context_op, lv.opcode_bits[0]); + let filter = builder.sub_extension(lv.op.context_op, prod); let new_stack_top = nv.mem_channels[0].value; { let diff = builder.sub_extension(new_stack_top[0], lv.context); @@ -40,14 +124,36 @@ fn eval_ext_circuit_get<F: RichField + Extendable<D>, const D: usize>( let constr = builder.mul_extension(filter, limb); yield_constr.constraint(builder, constr); } + + // Constrain new stack length. + { + let new_len = builder.add_const_extension(lv.stack_len, F::ONE); + let diff = builder.sub_extension(nv.stack_len, new_len); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + + // Unused channels. + for i in 1..NUM_GP_CHANNELS { + if i != 3 { + let channel = lv.mem_channels[i]; + let constr = builder.mul_extension(filter, channel.used); + yield_constr.constraint(builder, constr); + } + } + { + let constr = builder.mul_extension(filter, nv.mem_channels[0].used); + yield_constr.constraint(builder, constr); + } } +/// Evaluates constraints for `SET_CONTEXT`. fn eval_packed_set<P: PackedField>( lv: &CpuColumnsView
<P>
, nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { - let filter = lv.op.set_context; + let filter = lv.op.context_op * lv.opcode_bits[0]; let stack_top = lv.mem_channels[0].value; let write_old_sp_channel = lv.mem_channels[1]; let read_new_sp_channel = lv.mem_channels[2]; @@ -77,43 +183,40 @@ fn eval_packed_set( yield_constr.constraint(filter * (read_new_sp_channel.addr_segment - ctx_metadata_segment)); yield_constr.constraint(filter * (read_new_sp_channel.addr_virtual - stack_size_field)); - // The next row's stack top is loaded from memory (if the stack isn't empty). - yield_constr.constraint(filter * nv.mem_channels[0].used); - - let read_new_stack_top_channel = lv.mem_channels[3]; - let stack_segment = P::Scalar::from_canonical_u64(Segment::Stack as u64); - let new_filter = filter * nv.stack_len; - - for (limb_channel, limb_top) in read_new_stack_top_channel - .value - .iter() - .zip(nv.mem_channels[0].value) - { - yield_constr.constraint(new_filter * (*limb_channel - limb_top)); - } - yield_constr.constraint(new_filter * (read_new_stack_top_channel.used - P::ONES)); - yield_constr.constraint(new_filter * (read_new_stack_top_channel.is_read - P::ONES)); - yield_constr.constraint(new_filter * (read_new_stack_top_channel.addr_context - nv.context)); - yield_constr.constraint(new_filter * (read_new_stack_top_channel.addr_segment - stack_segment)); + // Constrain stack_inv_aux_2. + let new_top_channel = nv.mem_channels[0]; yield_constr.constraint( - new_filter * (read_new_stack_top_channel.addr_virtual - (nv.stack_len - P::ONES)), + lv.op.context_op + * (lv.general.stack().stack_inv_aux * lv.opcode_bits[0] + - lv.general.stack().stack_inv_aux_2), ); - - // If the new stack is empty, disable the channel read. + // The new top is loaded in memory channel 3, if the stack isn't empty (see eval_packed). yield_constr.constraint( - filter * (nv.stack_len * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux), + lv.op.context_op + * lv.general.stack().stack_inv_aux_2 + * (lv.mem_channels[3].value[0] - new_top_channel.value[0]), ); - let empty_stack_filter = filter * (lv.general.stack().stack_inv_aux - P::ONES); - yield_constr.constraint(empty_stack_filter * read_new_stack_top_channel.used); + for &limb in &new_top_channel.value[1..] { + yield_constr.constraint(lv.op.context_op * lv.general.stack().stack_inv_aux_2 * limb); + } + + // Unused channels. + for i in 4..NUM_GP_CHANNELS { + let channel = lv.mem_channels[i]; + yield_constr.constraint(filter * channel.used); + } + yield_constr.constraint(filter * new_top_channel.used); } +/// Circuit version of `eval_packed_set`. +/// Evaluates constraints for SET_CONTEXT. fn eval_ext_circuit_set, const D: usize>( builder: &mut CircuitBuilder, lv: &CpuColumnsView>, nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let filter = lv.op.set_context; + let filter = builder.mul_extension(lv.op.context_op, lv.opcode_bits[0]); let stack_top = lv.mem_channels[0].value; let write_old_sp_channel = lv.mem_channels[1]; let read_new_sp_channel = lv.mem_channels[2]; @@ -197,85 +300,144 @@ fn eval_ext_circuit_set, const D: usize>( yield_constr.constraint(builder, constr); } - // The next row's stack top is loaded from memory (if the stack isn't empty). 
- { - let constr = builder.mul_extension(filter, nv.mem_channels[0].used); - yield_constr.constraint(builder, constr); - } - - let read_new_stack_top_channel = lv.mem_channels[3]; - let stack_segment = - builder.constant_extension(F::Extension::from_canonical_u32(Segment::Stack as u32)); - - let new_filter = builder.mul_extension(filter, nv.stack_len); - - for (limb_channel, limb_top) in read_new_stack_top_channel - .value - .iter() - .zip(nv.mem_channels[0].value) - { - let diff = builder.sub_extension(*limb_channel, limb_top); - let constr = builder.mul_extension(new_filter, diff); - yield_constr.constraint(builder, constr); - } - { - let constr = - builder.mul_sub_extension(new_filter, read_new_stack_top_channel.used, new_filter); - yield_constr.constraint(builder, constr); - } - { - let constr = - builder.mul_sub_extension(new_filter, read_new_stack_top_channel.is_read, new_filter); - yield_constr.constraint(builder, constr); - } + // Constrain stack_inv_aux_2. + let new_top_channel = nv.mem_channels[0]; { - let diff = builder.sub_extension(read_new_stack_top_channel.addr_context, nv.context); - let constr = builder.mul_extension(new_filter, diff); + let diff = builder.mul_sub_extension( + lv.general.stack().stack_inv_aux, + lv.opcode_bits[0], + lv.general.stack().stack_inv_aux_2, + ); + let constr = builder.mul_extension(lv.op.context_op, diff); yield_constr.constraint(builder, constr); } + // The new top is loaded in memory channel 3, if the stack isn't empty (see eval_packed). { - let diff = builder.sub_extension(read_new_stack_top_channel.addr_segment, stack_segment); - let constr = builder.mul_extension(new_filter, diff); + let diff = builder.sub_extension(lv.mem_channels[3].value[0], new_top_channel.value[0]); + let prod = builder.mul_extension(lv.general.stack().stack_inv_aux_2, diff); + let constr = builder.mul_extension(lv.op.context_op, prod); yield_constr.constraint(builder, constr); } - { - let diff = builder.sub_extension(nv.stack_len, one); - let diff = builder.sub_extension(read_new_stack_top_channel.addr_virtual, diff); - let constr = builder.mul_extension(new_filter, diff); + for &limb in &new_top_channel.value[1..] { + let prod = builder.mul_extension(lv.general.stack().stack_inv_aux_2, limb); + let constr = builder.mul_extension(lv.op.context_op, prod); yield_constr.constraint(builder, constr); } - // If the new stack is empty, disable the channel read. - { - let diff = builder.mul_extension(nv.stack_len, lv.general.stack().stack_inv); - let diff = builder.sub_extension(diff, lv.general.stack().stack_inv_aux); - let constr = builder.mul_extension(filter, diff); + // Unused channels. + for i in 4..NUM_GP_CHANNELS { + let channel = lv.mem_channels[i]; + let constr = builder.mul_extension(filter, channel.used); yield_constr.constraint(builder, constr); } - { - let empty_stack_filter = - builder.mul_sub_extension(filter, lv.general.stack().stack_inv_aux, filter); - let constr = builder.mul_extension(empty_stack_filter, read_new_stack_top_channel.used); + let constr = builder.mul_extension(filter, new_top_channel.used); yield_constr.constraint(builder, constr); } } +/// Evaluates the constraints for the GET and SET opcodes. pub fn eval_packed( lv: &CpuColumnsView
<P>
, nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { + eval_packed_keep(lv, nv, yield_constr); eval_packed_get(lv, nv, yield_constr); eval_packed_set(lv, nv, yield_constr); + + // Stack constraints. + // Both operations use memory channel 3. The operations are similar enough that + // we can constrain both at the same time. + let filter = lv.op.context_op; + let channel = lv.mem_channels[3]; + // For get_context, we check if lv.stack_len is 0. For set_context, we check if nv.stack_len is 0. + // However, for get_context, we can deduce lv.stack_len from nv.stack_len since the operation only pushes. + let stack_len = nv.stack_len - (P::ONES - lv.opcode_bits[0]); + // Constrain stack_inv_aux. It's 0 if the relevant stack is empty, 1 otherwise. + yield_constr.constraint( + filter * (stack_len * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux), + ); + // Enable or disable the channel. + yield_constr.constraint(filter * (lv.general.stack().stack_inv_aux - channel.used)); + let new_filter = filter * lv.general.stack().stack_inv_aux; + // It's a write for get_context, a read for set_context. + yield_constr.constraint(new_filter * (channel.is_read - lv.opcode_bits[0])); + // In both cases, the address context is the next row's context. + yield_constr.constraint(new_filter * (channel.addr_context - nv.context)); + // Same segment for both. + yield_constr.constraint( + new_filter * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), + ); + // The address is one less than stack_len. + let addr_virtual = stack_len - P::ONES; + yield_constr.constraint(new_filter * (channel.addr_virtual - addr_virtual)); } +/// Circuit version of `eval_packed`. +/// Evaluates the constraints for the GET and SET opcodes. pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>( builder: &mut CircuitBuilder<F, D>, lv: &CpuColumnsView<ExtensionTarget<D>>, nv: &CpuColumnsView<ExtensionTarget<D>>, yield_constr: &mut RecursiveConstraintConsumer<F, D>, ) { + eval_ext_circuit_keep(builder, lv, nv, yield_constr); eval_ext_circuit_get(builder, lv, nv, yield_constr); eval_ext_circuit_set(builder, lv, nv, yield_constr); + + // Stack constraints. + // Both operations use memory channel 3. The operations are similar enough that + // we can constrain both at the same time. + let filter = lv.op.context_op; + let channel = lv.mem_channels[3]; + // For get_context, we check if lv.stack_len is 0. For set_context, we check if nv.stack_len is 0. + // However, for get_context, we can deduce lv.stack_len from nv.stack_len since the operation only pushes. + let diff = builder.add_const_extension(lv.opcode_bits[0], -F::ONE); + let stack_len = builder.add_extension(nv.stack_len, diff); + // Constrain stack_inv_aux. It's 0 if the relevant stack is empty, 1 otherwise. + { + let diff = builder.mul_sub_extension( + stack_len, + lv.general.stack().stack_inv, + lv.general.stack().stack_inv_aux, + ); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + // Enable or disable the channel. + { + let diff = builder.sub_extension(lv.general.stack().stack_inv_aux, channel.used); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + let new_filter = builder.mul_extension(filter, lv.general.stack().stack_inv_aux); + // It's a write for get_context, a read for set_context. + { + let diff = builder.sub_extension(channel.is_read, lv.opcode_bits[0]); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint(builder, constr); + } + // In both cases, the address context is the next row's context.
+ { + let diff = builder.sub_extension(channel.addr_context, nv.context); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint(builder, constr); + } + // Same segment for both. + { + let diff = builder.add_const_extension( + channel.addr_segment, + -F::from_canonical_u64(Segment::Stack as u64), + ); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint(builder, constr); + } + // The address is one less than stack_len. + { + let addr_virtual = builder.add_const_extension(stack_len, -F::ONE); + let diff = builder.sub_extension(channel.addr_virtual, addr_virtual); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint(builder, constr); + } } diff --git a/evm/src/cpu/control_flow.rs b/evm/src/cpu/control_flow.rs index 2f496b514a..adaee51123 100644 --- a/evm/src/cpu/control_flow.rs +++ b/evm/src/cpu/control_flow.rs @@ -8,42 +8,41 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::{CpuColumnsView, COL_MAP}; use crate::cpu::kernel::aggregator::KERNEL; -const NATIVE_INSTRUCTIONS: [usize; 17] = [ +const NATIVE_INSTRUCTIONS: [usize; 13] = [ COL_MAP.op.binary_op, COL_MAP.op.ternary_op, COL_MAP.op.fp254_op, COL_MAP.op.eq_iszero, COL_MAP.op.logic_op, - COL_MAP.op.not, + COL_MAP.op.not_pop, COL_MAP.op.shift, - COL_MAP.op.keccak_general, + COL_MAP.op.jumpdest_keccak_general, COL_MAP.op.prover_input, - COL_MAP.op.pop, // not JUMPS (possible need to jump) - COL_MAP.op.pc, - COL_MAP.op.jumpdest, - COL_MAP.op.push0, + COL_MAP.op.pc_push0, // not PUSH (need to increment by more than 1) COL_MAP.op.dup_swap, - COL_MAP.op.get_context, - COL_MAP.op.set_context, + COL_MAP.op.context_op, // not EXIT_KERNEL (performs a jump) COL_MAP.op.m_op_general, // not SYSCALL (performs a jump) // not exceptions (also jump) ]; +/// Returns `halt`'s program counter. pub(crate) fn get_halt_pc() -> F { let halt_pc = KERNEL.global_labels["halt"]; F::from_canonical_usize(halt_pc) } +/// Returns `main`'s program counter. pub(crate) fn get_start_pc() -> F { let start_pc = KERNEL.global_labels["main"]; F::from_canonical_usize(start_pc) } +/// Evaluates the constraints related to the flow of instructions. pub fn eval_packed_generic( lv: &CpuColumnsView
<P>, nv: &CpuColumnsView<P>
, @@ -82,6 +81,8 @@ pub fn eval_packed_generic( yield_constr.constraint_transition(is_last_noncpu_cycle * nv.stack_len); } +/// Circuit version of `eval_packed`. +/// Evaluates the constraints related to the flow of instructions. pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 64a2db9c36..926cd7485c 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -14,7 +14,6 @@ use super::halt; use crate::all_stark::Table; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::{COL_MAP, NUM_CPU_COLUMNS}; -use crate::cpu::membus::NUM_GP_CHANNELS; use crate::cpu::{ bootstrap_kernel, contextops, control_flow, decode, dup_swap, gas, jumps, membus, memio, modfp254, pc, push0, shift, simple_logic, stack, stack_bounds, syscalls_exceptions, @@ -25,6 +24,8 @@ use crate::memory::segments::Segment; use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; use crate::stark::Stark; +/// Creates the vector of `Columns` corresponding to the General Purpose channels when calling the Keccak sponge: +/// the CPU reads the output of the sponge directly from the `KeccakSpongeStark` table. pub fn ctl_data_keccak_sponge() -> Vec> { // When executing KECCAK_GENERAL, the GP memory channels are used as follows: // GP channel 0: stack[-1] = context @@ -41,26 +42,25 @@ pub fn ctl_data_keccak_sponge() -> Vec> { let timestamp = Column::linear_combination([(COL_MAP.clock, num_channels)]); let mut cols = vec![context, segment, virt, len, timestamp]; - cols.extend(COL_MAP.mem_channels[4].value.map(Column::single)); + cols.extend(Column::singles_next_row(COL_MAP.mem_channels[0].value)); cols } +/// CTL filter for a call to the Keccak sponge. pub fn ctl_filter_keccak_sponge() -> Column { Column::single(COL_MAP.is_keccak_sponge) } -/// Create the vector of Columns corresponding to the two inputs and +/// Creates the vector of `Columns` corresponding to the two inputs and /// one output of a binary operation. fn ctl_data_binops() -> Vec> { let mut res = Column::singles(COL_MAP.mem_channels[0].value).collect_vec(); res.extend(Column::singles(COL_MAP.mem_channels[1].value)); - res.extend(Column::singles( - COL_MAP.mem_channels[NUM_GP_CHANNELS - 1].value, - )); + res.extend(Column::singles_next_row(COL_MAP.mem_channels[0].value)); res } -/// Create the vector of Columns corresponding to the three inputs and +/// Creates the vector of `Columns` corresponding to the three inputs and /// one output of a ternary operation. By default, ternary operations use /// the first three memory channels, and the last one for the result (binary /// operations do not use the third inputs). @@ -68,12 +68,11 @@ fn ctl_data_ternops() -> Vec> { let mut res = Column::singles(COL_MAP.mem_channels[0].value).collect_vec(); res.extend(Column::singles(COL_MAP.mem_channels[1].value)); res.extend(Column::singles(COL_MAP.mem_channels[2].value)); - res.extend(Column::singles( - COL_MAP.mem_channels[NUM_GP_CHANNELS - 1].value, - )); + res.extend(Column::singles_next_row(COL_MAP.mem_channels[0].value)); res } +/// Creates the vector of columns corresponding to the opcode, the two inputs and the output of the logic operation. pub fn ctl_data_logic() -> Vec> { // Instead of taking single columns, we reconstruct the entire opcode value directly. 
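    // `Column::le_bits` recombines the opcode bits little-endian, i.e. opcode = sum_i 2^i * opcode_bits[i].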
let mut res = vec![Column::le_bits(COL_MAP.opcode_bits)]; @@ -81,10 +80,12 @@ pub fn ctl_data_logic() -> Vec> { res } +/// CTL filter for logic operations. pub fn ctl_filter_logic() -> Column { Column::single(COL_MAP.op.logic_op) } +/// Returns the `TableWithColumns` for the CPU rows calling arithmetic operations. pub fn ctl_arithmetic_base_rows() -> TableWithColumns { // Instead of taking single columns, we reconstruct the entire opcode value directly. let mut columns = vec![Column::le_bits(COL_MAP.opcode_bits)]; @@ -106,14 +107,18 @@ pub fn ctl_arithmetic_base_rows() -> TableWithColumns { ) } +/// Creates the vector of `Columns` corresponding to the contents of General Purpose channels when calling byte packing. +/// We use `ctl_data_keccak_sponge` because the `Columns` are the same as the ones computed for `KeccakSpongeStark`. pub fn ctl_data_byte_packing() -> Vec> { ctl_data_keccak_sponge() } +/// CTL filter for the `MLOAD_32BYTES` operation. pub fn ctl_filter_byte_packing() -> Column { Column::single(COL_MAP.op.mload_32bytes) } +/// Creates the vector of `Columns` corresponding to the contents of General Purpose channels when calling byte unpacking. pub fn ctl_data_byte_unpacking() -> Vec> { // When executing MSTORE_32BYTES, the GP memory channels are used as follows: // GP channel 0: stack[-1] = context @@ -136,11 +141,14 @@ pub fn ctl_data_byte_unpacking() -> Vec> { res } +/// CTL filter for the `MSTORE_32BYTES` operation. pub fn ctl_filter_byte_unpacking() -> Column { Column::single(COL_MAP.op.mstore_32bytes) } +/// Index of the memory channel storing code. pub const MEM_CODE_CHANNEL_IDX: usize = 0; +/// Index of the first general purpose memory channel. pub const MEM_GP_CHANNELS_IDX_START: usize = MEM_CODE_CHANNEL_IDX + 1; /// Make the time/channel column for memory lookups. @@ -150,6 +158,7 @@ fn mem_time_and_channel(channel: usize) -> Column { Column::linear_combination_with_constant([(COL_MAP.clock, scalar)], addend) } +/// Creates the vector of `Columns` corresponding to the contents of the code channel when reading code values. pub fn ctl_data_code_memory() -> Vec> { let mut cols = vec![ Column::constant(F::ONE), // is_read @@ -169,6 +178,7 @@ pub fn ctl_data_code_memory() -> Vec> { cols } +/// Creates the vector of `Columns` corresponding to the contents of General Purpose channels. pub fn ctl_data_gp_memory(channel: usize) -> Vec> { let channel_map = COL_MAP.mem_channels[channel]; let mut cols: Vec<_> = Column::singles([ @@ -186,14 +196,17 @@ pub fn ctl_data_gp_memory(channel: usize) -> Vec> { cols } +/// CTL filter for code read and write operations. pub fn ctl_filter_code_memory() -> Column { Column::sum(COL_MAP.op.iter()) } +/// CTL filter for General Purpose memory read and write operations. pub fn ctl_filter_gp_memory(channel: usize) -> Column { Column::single(COL_MAP.mem_channels[channel].used) } +/// Structure representing the CPU Stark. #[derive(Copy, Clone, Default)] pub struct CpuStark { pub f: PhantomData, @@ -207,6 +220,7 @@ impl, const D: usize> Stark for CpuStark, NUM_CPU_COLUMNS>; + /// Evaluates all CPU constraints. fn eval_packed_generic( &self, vars: &Self::EvaluationFrame, @@ -240,6 +254,8 @@ impl, const D: usize> Stark for CpuStark, diff --git a/evm/src/cpu/decode.rs b/evm/src/cpu/decode.rs index a4756684a2..77246aea8d 100644 --- a/evm/src/cpu/decode.rs +++ b/evm/src/cpu/decode.rs @@ -23,27 +23,22 @@ use crate::cpu::columns::{CpuColumnsView, COL_MAP}; /// behavior. /// Note: invalid opcodes are not represented here. 
_Any_ opcode is permitted to decode to /// `is_invalid`. The kernel then verifies that the opcode was _actually_ invalid. -const OPCODES: [(u8, usize, bool, usize); 16] = [ +const OPCODES: [(u8, usize, bool, usize); 9] = [ // (start index of block, number of top bits to check (log2), kernel-only, flag column) // ADD, MUL, SUB, DIV, MOD, LT, GT and BYTE flags are handled partly manually here, and partly through the Arithmetic table CTL. // ADDMOD, MULMOD and SUBMOD flags are handled partly manually here, and partly through the Arithmetic table CTL. // FP254 operation flags are handled partly manually here, and partly through the Arithmetic table CTL. (0x14, 1, false, COL_MAP.op.eq_iszero), // AND, OR and XOR flags are handled partly manually here, and partly through the Logic table CTL. - (0x19, 0, false, COL_MAP.op.not), + // NOT and POP are handled manually here. // SHL and SHR flags are handled partly manually here, and partly through the Logic table CTL. - (0x21, 0, true, COL_MAP.op.keccak_general), + // JUMPDEST and KECCAK_GENERAL are handled manually here. (0x49, 0, true, COL_MAP.op.prover_input), - (0x50, 0, false, COL_MAP.op.pop), - (0x56, 1, false, COL_MAP.op.jumps), // 0x56-0x57 - (0x58, 0, false, COL_MAP.op.pc), - (0x5b, 0, false, COL_MAP.op.jumpdest), - (0x5f, 0, false, COL_MAP.op.push0), + (0x56, 1, false, COL_MAP.op.jumps), // 0x56-0x57 (0x60, 5, false, COL_MAP.op.push), // 0x60-0x7f (0x80, 5, false, COL_MAP.op.dup_swap), // 0x80-0x9f (0xee, 0, true, COL_MAP.op.mstore_32bytes), - (0xf6, 0, true, COL_MAP.op.get_context), - (0xf7, 0, true, COL_MAP.op.set_context), + (0xf6, 1, true, COL_MAP.op.context_op), //0xf6-0xf7 (0xf8, 0, true, COL_MAP.op.mload_32bytes), (0xf9, 0, true, COL_MAP.op.exit_kernel), // MLOAD_GENERAL and MSTORE_GENERAL flags are handled manually here. @@ -52,13 +47,16 @@ const OPCODES: [(u8, usize, bool, usize); 16] = [ /// List of combined opcodes requiring a special handling. /// Each index in the list corresponds to an arbitrary combination /// of opcodes defined in evm/src/cpu/columns/ops.rs. -const COMBINED_OPCODES: [usize; 6] = [ +const COMBINED_OPCODES: [usize; 9] = [ COL_MAP.op.logic_op, COL_MAP.op.fp254_op, COL_MAP.op.binary_op, COL_MAP.op.ternary_op, COL_MAP.op.shift, COL_MAP.op.m_op_general, + COL_MAP.op.jumpdest_keccak_general, + COL_MAP.op.not_pop, + COL_MAP.op.pc_push0, ]; /// Break up an opcode (which is 8 bits long) into its eight bits. @@ -75,6 +73,7 @@ const fn bits_from_opcode(opcode: u8) -> [bool; 8] { ] } +/// Evaluates the constraints for opcode decoding. pub fn eval_packed_generic( lv: &CpuColumnsView
<P>, yield_constr: &mut ConstraintConsumer<P>
, @@ -147,8 +146,41 @@ pub fn eval_packed_generic<P: PackedField>(
         * (opcode - P::Scalar::from_canonical_usize(0xfc_usize))
         * lv.op.m_op_general;
     yield_constr.constraint(m_op_constr);
+
+    // Manually check lv.op.jumpdest_keccak_general.
+    // KECCAK_GENERAL is a kernel-only instruction, but not JUMPDEST.
+    // JUMPDEST is differentiated from KECCAK_GENERAL by its second bit set to 1.
+    yield_constr.constraint(
+        (P::ONES - kernel_mode) * lv.op.jumpdest_keccak_general * (P::ONES - lv.opcode_bits[1]),
+    );
+
+    // Check the JUMPDEST and KECCAK_GENERAL opcodes.
+    let jumpdest_opcode = P::Scalar::from_canonical_usize(0x5b);
+    let keccak_general_opcode = P::Scalar::from_canonical_usize(0x21);
+    let jumpdest_keccak_general_constr = (opcode - keccak_general_opcode)
+        * (opcode - jumpdest_opcode)
+        * lv.op.jumpdest_keccak_general;
+    yield_constr.constraint(jumpdest_keccak_general_constr);
+
+    // Manually check lv.op.pc_push0.
+    // Both PC and PUSH0 can be called outside of the kernel mode:
+    // there is no need to constrain them in that regard.
+    let pc_push0_constr = (opcode - P::Scalar::from_canonical_usize(0x58_usize))
+        * (opcode - P::Scalar::from_canonical_usize(0x5f_usize))
+        * lv.op.pc_push0;
+    yield_constr.constraint(pc_push0_constr);
+
+    // Manually check lv.op.not_pop.
+    // Both NOT and POP can be called outside of the kernel mode:
+    // there is no need to constrain them in that regard.
+    let not_pop_constr = (opcode - P::Scalar::from_canonical_usize(0x19_usize))
+        * (opcode - P::Scalar::from_canonical_usize(0x50_usize))
+        * lv.op.not_pop;
+    yield_constr.constraint(not_pop_constr);
 }

+/// Circuit version of `eval_packed_generic`.
+/// Evaluates the constraints for opcode decoding.
 pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
     builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
     lv: &CpuColumnsView<ExtensionTarget<D>>,
@@ -249,4 +281,55 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
     m_op_constr = builder.mul_extension(m_op_constr, lv.op.m_op_general);
     yield_constr.constraint(builder, m_op_constr);
+
+    // Manually check lv.op.jumpdest_keccak_general.
+    // KECCAK_GENERAL is a kernel-only instruction, but not JUMPDEST.
+    // JUMPDEST is differentiated from KECCAK_GENERAL by its second bit set to 1.
+    let jumpdest_opcode =
+        builder.constant_extension(F::Extension::from_canonical_usize(0x5b_usize));
+    let keccak_general_opcode =
+        builder.constant_extension(F::Extension::from_canonical_usize(0x21_usize));
+
+    // Check that KECCAK_GENERAL is kernel-only.
+    let mut kernel_general_filter = builder.sub_extension(one, lv.opcode_bits[1]);
+    kernel_general_filter =
+        builder.mul_extension(lv.op.jumpdest_keccak_general, kernel_general_filter);
+    let constr = builder.mul_extension(is_not_kernel_mode, kernel_general_filter);
+    yield_constr.constraint(builder, constr);
+
+    // Check the JUMPDEST and KECCAK_GENERAL opcodes.
+    let jumpdest_constr = builder.sub_extension(opcode, jumpdest_opcode);
+    let keccak_general_constr = builder.sub_extension(opcode, keccak_general_opcode);
+    let mut jumpdest_keccak_general_constr =
+        builder.mul_extension(jumpdest_constr, keccak_general_constr);
+    jumpdest_keccak_general_constr = builder.mul_extension(
+        jumpdest_keccak_general_constr,
+        lv.op.jumpdest_keccak_general,
+    );
+
+    yield_constr.constraint(builder, jumpdest_keccak_general_constr);
+
+    // Manually check lv.op.pc_push0.
+    // Both PC and PUSH0 can be called outside of the kernel mode:
+    // there is no need to constrain them in that regard.
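+    // As in the packed version, the product (opcode - 0x58) * (opcode - 0x5f) vanishes
+    // exactly when the opcode is PC (0x58) or PUSH0 (0x5f).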
+ let pc_opcode = builder.constant_extension(F::Extension::from_canonical_usize(0x58_usize)); + let push0_opcode = builder.constant_extension(F::Extension::from_canonical_usize(0x5f_usize)); + let pc_constr = builder.sub_extension(opcode, pc_opcode); + let push0_constr = builder.sub_extension(opcode, push0_opcode); + let mut pc_push0_constr = builder.mul_extension(pc_constr, push0_constr); + pc_push0_constr = builder.mul_extension(pc_push0_constr, lv.op.pc_push0); + yield_constr.constraint(builder, pc_push0_constr); + + // Manually check lv.op.not_pop. + // Both NOT and POP can be called outside of the kernel mode: + // there is no need to constrain them in that regard. + let not_opcode = builder.constant_extension(F::Extension::from_canonical_usize(0x19_usize)); + let pop_opcode = builder.constant_extension(F::Extension::from_canonical_usize(0x50_usize)); + + let not_constr = builder.sub_extension(opcode, not_opcode); + let pop_constr = builder.sub_extension(opcode, pop_opcode); + + let mut not_pop_constr = builder.mul_extension(not_constr, pop_constr); + not_pop_constr = builder.mul_extension(lv.op.not_pop, not_pop_constr); + yield_constr.constraint(builder, not_pop_constr); } diff --git a/evm/src/cpu/dup_swap.rs b/evm/src/cpu/dup_swap.rs index 0cc6c67c8f..fd623c7ee6 100644 --- a/evm/src/cpu/dup_swap.rs +++ b/evm/src/cpu/dup_swap.rs @@ -6,6 +6,7 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; +use super::membus::NUM_GP_CHANNELS; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::{CpuColumnsView, MemoryChannelView}; use crate::memory::segments::Segment; @@ -100,6 +101,7 @@ fn constrain_channel_ext_circuit, const D: usize>( ); yield_constr.constraint(builder, constr); } + // Top of the stack is at `addr = lv.stack_len - 1`. { let constr = builder.add_extension(channel.addr_virtual, offset); let constr = builder.sub_extension(constr, lv.stack_len); @@ -108,6 +110,7 @@ fn constrain_channel_ext_circuit, const D: usize>( } } +/// Evaluates constraints for DUP. fn eval_packed_dup( n: P, lv: &CpuColumnsView
<P>
, @@ -120,18 +123,30 @@ fn eval_packed_dup( let write_channel = &lv.mem_channels[1]; let read_channel = &lv.mem_channels[2]; + // Constrain the input and top of the stack channels to have the same value. channels_equal_packed(filter, write_channel, &lv.mem_channels[0], yield_constr); + // Constrain the output channel's addresses, `is_read` and `used` fields. constrain_channel_packed(false, filter, P::ZEROS, write_channel, lv, yield_constr); + // Constrain the output and top of the stack channels to have the same value. channels_equal_packed(filter, read_channel, &nv.mem_channels[0], yield_constr); + // Constrain the input channel's addresses, `is_read` and `used` fields. constrain_channel_packed(true, filter, n, read_channel, lv, yield_constr); // Constrain nv.stack_len. yield_constr.constraint_transition(filter * (nv.stack_len - lv.stack_len - P::ONES)); - // TODO: Constrain unused channels? + // Disable next top. + yield_constr.constraint(filter * nv.mem_channels[0].used); + + // Constrain unused channels. + for i in 3..NUM_GP_CHANNELS { + yield_constr.constraint(filter * lv.mem_channels[i].used); + } } +/// Circuit version of `eval_packed_dup`. +/// Evaluates constraints for DUP. fn eval_ext_circuit_dup, const D: usize>( builder: &mut CircuitBuilder, n: ExtensionTarget, @@ -148,6 +163,7 @@ fn eval_ext_circuit_dup, const D: usize>( let write_channel = &lv.mem_channels[1]; let read_channel = &lv.mem_channels[2]; + // Constrain the input and top of the stack channels to have the same value. channels_equal_ext_circuit( builder, filter, @@ -155,6 +171,7 @@ fn eval_ext_circuit_dup, const D: usize>( &lv.mem_channels[0], yield_constr, ); + // Constrain the output channel's addresses, `is_read` and `used` fields. constrain_channel_ext_circuit( builder, false, @@ -165,6 +182,7 @@ fn eval_ext_circuit_dup, const D: usize>( yield_constr, ); + // Constrain the output and top of the stack channels to have the same value. channels_equal_ext_circuit( builder, filter, @@ -172,16 +190,30 @@ fn eval_ext_circuit_dup, const D: usize>( &nv.mem_channels[0], yield_constr, ); + // Constrain the input channel's addresses, `is_read` and `used` fields. constrain_channel_ext_circuit(builder, true, filter, n, read_channel, lv, yield_constr); // Constrain nv.stack_len. - let diff = builder.sub_extension(nv.stack_len, lv.stack_len); - let constr = builder.mul_sub_extension(filter, diff, filter); - yield_constr.constraint_transition(builder, constr); + { + let diff = builder.sub_extension(nv.stack_len, lv.stack_len); + let constr = builder.mul_sub_extension(filter, diff, filter); + yield_constr.constraint_transition(builder, constr); + } - // TODO: Constrain unused channels? + // Disable next top. + { + let constr = builder.mul_extension(filter, nv.mem_channels[0].used); + yield_constr.constraint(builder, constr); + } + + // Constrain unused channels. + for i in 3..NUM_GP_CHANNELS { + let constr = builder.mul_extension(filter, lv.mem_channels[i].used); + yield_constr.constraint(builder, constr); + } } +/// Evaluates constraints for SWAP. fn eval_packed_swap( n: P, lv: &CpuColumnsView
<P>
, @@ -197,18 +229,31 @@ fn eval_packed_swap( let in2_channel = &lv.mem_channels[1]; let out_channel = &lv.mem_channels[2]; + // Constrain the first input channel value to be equal to the output channel value. channels_equal_packed(filter, in1_channel, out_channel, yield_constr); + // We set `is_read`, `used` and the address for the first input. The first input is + // read from the top of the stack, and is therefore not a memory read. constrain_channel_packed(false, filter, n_plus_one, out_channel, lv, yield_constr); + // Constrain the second input channel value to be equal to the new top of the stack. channels_equal_packed(filter, in2_channel, &nv.mem_channels[0], yield_constr); + // We set `is_read`, `used` and the address for the second input. constrain_channel_packed(true, filter, n_plus_one, in2_channel, lv, yield_constr); - // Constrain nv.stack_len; + // Constrain nv.stack_len. yield_constr.constraint(filter * (nv.stack_len - lv.stack_len)); - // TODO: Constrain unused channels? + // Disable next top. + yield_constr.constraint(filter * nv.mem_channels[0].used); + + // Constrain unused channels. + for i in 3..NUM_GP_CHANNELS { + yield_constr.constraint(filter * lv.mem_channels[i].used); + } } +/// Circuit version of `eval_packed_swap`. +/// Evaluates constraints for SWAP. fn eval_ext_circuit_swap, const D: usize>( builder: &mut CircuitBuilder, n: ExtensionTarget, @@ -226,7 +271,10 @@ fn eval_ext_circuit_swap, const D: usize>( let in2_channel = &lv.mem_channels[1]; let out_channel = &lv.mem_channels[2]; + // Constrain the first input channel value to be equal to the output channel value. channels_equal_ext_circuit(builder, filter, in1_channel, out_channel, yield_constr); + // We set `is_read`, `used` and the address for the first input. The first input is + // read from the top of the stack, and is therefore not a memory read. constrain_channel_ext_circuit( builder, false, @@ -237,6 +285,7 @@ fn eval_ext_circuit_swap, const D: usize>( yield_constr, ); + // Constrain the second input channel value to be equal to the new top of the stack. channels_equal_ext_circuit( builder, filter, @@ -244,6 +293,7 @@ fn eval_ext_circuit_swap, const D: usize>( &nv.mem_channels[0], yield_constr, ); + // We set `is_read`, `used` and the address for the second input. constrain_channel_ext_circuit( builder, true, @@ -259,9 +309,20 @@ fn eval_ext_circuit_swap, const D: usize>( let constr = builder.mul_extension(filter, diff); yield_constr.constraint(builder, constr); - // TODO: Constrain unused channels? + // Disable next top. + { + let constr = builder.mul_extension(filter, nv.mem_channels[0].used); + yield_constr.constraint(builder, constr); + } + + // Constrain unused channels. + for i in 3..NUM_GP_CHANNELS { + let constr = builder.mul_extension(filter, lv.mem_channels[i].used); + yield_constr.constraint(builder, constr); + } } +/// Evaluates the constraints for the DUP and SWAP opcodes. pub fn eval_packed( lv: &CpuColumnsView
<P>, nv: &CpuColumnsView<P>
, @@ -276,6 +337,8 @@ pub fn eval_packed( eval_packed_swap(n, lv, nv, yield_constr); } +/// Circuit version of `eval_packed`. +/// Evaluates the constraints for the DUP and SWAP opcodes. pub fn eval_ext_circuit, const D: usize>( builder: &mut CircuitBuilder, lv: &CpuColumnsView>, diff --git a/evm/src/cpu/gas.rs b/evm/src/cpu/gas.rs index 1a908d6df4..ed0f33b487 100644 --- a/evm/src/cpu/gas.rs +++ b/evm/src/cpu/gas.rs @@ -24,19 +24,15 @@ const SIMPLE_OPCODES: OpsColumnsView> = OpsColumnsView { fp254_op: KERNEL_ONLY_INSTR, eq_iszero: G_VERYLOW, logic_op: G_VERYLOW, - not: G_VERYLOW, + not_pop: None, // This is handled manually below shift: G_VERYLOW, - keccak_general: KERNEL_ONLY_INSTR, + jumpdest_keccak_general: None, // This is handled manually below. prover_input: KERNEL_ONLY_INSTR, - pop: G_BASE, jumps: None, // Combined flag handled separately. - pc: G_BASE, - jumpdest: G_JUMPDEST, - push0: G_BASE, + pc_push0: G_BASE, push: G_VERYLOW, dup_swap: G_VERYLOW, - get_context: KERNEL_ONLY_INSTR, - set_context: KERNEL_ONLY_INSTR, + context_op: KERNEL_ONLY_INSTR, mstore_32bytes: KERNEL_ONLY_INSTR, mload_32bytes: KERNEL_ONLY_INSTR, exit_kernel: None, @@ -105,6 +101,22 @@ fn eval_packed_accumulate( let ternary_op_cost = P::Scalar::from_canonical_u32(G_MID.unwrap()) - lv.opcode_bits[1] * P::Scalar::from_canonical_u32(G_MID.unwrap()); yield_constr.constraint_transition(lv.op.ternary_op * (gas_diff - ternary_op_cost)); + + // For NOT and POP. + // NOT is differentiated from POP by its first bit set to 1. + let not_pop_cost = (P::ONES - lv.opcode_bits[0]) + * P::Scalar::from_canonical_u32(G_BASE.unwrap()) + + lv.opcode_bits[0] * P::Scalar::from_canonical_u32(G_VERYLOW.unwrap()); + yield_constr.constraint_transition(lv.op.not_pop * (gas_diff - not_pop_cost)); + + // For JUMPDEST and KECCAK_GENERAL. + // JUMPDEST is differentiated from KECCAK_GENERAL by its second bit set to 1. + let jumpdest_keccak_general_gas_cost = lv.opcode_bits[1] + * P::Scalar::from_canonical_u32(G_JUMPDEST.unwrap()) + + (P::ONES - lv.opcode_bits[1]) * P::Scalar::from_canonical_u32(KERNEL_ONLY_INSTR.unwrap()); + yield_constr.constraint_transition( + lv.op.jumpdest_keccak_general * (gas_diff - jumpdest_keccak_general_gas_cost), + ); } fn eval_packed_init( @@ -121,6 +133,7 @@ fn eval_packed_init( yield_constr.constraint_transition(filter * nv.gas[1]); } +/// Evaluate the gas constraints for the opcodes that cost a constant gas. pub fn eval_packed( lv: &CpuColumnsView
<P>, nv: &CpuColumnsView<P>
, @@ -233,6 +246,39 @@ fn eval_ext_circuit_accumulate, const D: usize>( let gas_diff = builder.sub_extension(nv_lv_diff, ternary_op_cost); let constr = builder.mul_extension(filter, gas_diff); yield_constr.constraint_transition(builder, constr); + + // For NOT and POP. + // NOT is differentiated from POP by its first bit set to 1. + let filter = lv.op.not_pop; + let one = builder.one_extension(); + let mut not_pop_cost = + builder.mul_const_extension(F::from_canonical_u32(G_VERYLOW.unwrap()), lv.opcode_bits[0]); + let mut pop_cost = builder.sub_extension(one, lv.opcode_bits[0]); + pop_cost = builder.mul_const_extension(F::from_canonical_u32(G_BASE.unwrap()), pop_cost); + not_pop_cost = builder.add_extension(not_pop_cost, pop_cost); + + let not_pop_gas_diff = builder.sub_extension(nv_lv_diff, not_pop_cost); + let not_pop_constr = builder.mul_extension(filter, not_pop_gas_diff); + yield_constr.constraint_transition(builder, not_pop_constr); + + // For JUMPDEST and KECCAK_GENERAL. + // JUMPDEST is differentiated from KECCAK_GENERAL by its second bit set to 1. + let one = builder.one_extension(); + let filter = lv.op.jumpdest_keccak_general; + + let jumpdest_keccak_general_gas_cost = builder.arithmetic_extension( + F::from_canonical_u32(G_JUMPDEST.unwrap()) + - F::from_canonical_u32(KERNEL_ONLY_INSTR.unwrap()), + F::from_canonical_u32(KERNEL_ONLY_INSTR.unwrap()), + lv.opcode_bits[1], + one, + one, + ); + + let gas_diff = builder.sub_extension(nv_lv_diff, jumpdest_keccak_general_gas_cost); + let constr = builder.mul_extension(filter, gas_diff); + + yield_constr.constraint_transition(builder, constr); } fn eval_ext_circuit_init, const D: usize>( @@ -252,12 +298,16 @@ fn eval_ext_circuit_init, const D: usize>( yield_constr.constraint_transition(builder, constr); } +/// Circuit version of `eval_packed`. +/// Evaluate the gas constraints for the opcodes that cost a constant gas. pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { + // Evaluates the transition gas constraints. eval_ext_circuit_accumulate(builder, lv, nv, yield_constr); + // Evaluates the initial gas constraints. eval_ext_circuit_init(builder, lv, nv, yield_constr); } diff --git a/evm/src/cpu/halt.rs b/evm/src/cpu/halt.rs index 9ad34344ea..8ed9aa5bcd 100644 --- a/evm/src/cpu/halt.rs +++ b/evm/src/cpu/halt.rs @@ -11,6 +11,7 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::{CpuColumnsView, COL_MAP}; use crate::cpu::membus::NUM_GP_CHANNELS; +/// Evaluates constraints for the `halt` flag. pub fn eval_packed( lv: &CpuColumnsView
<P>, nv: &CpuColumnsView<P>
, @@ -45,6 +46,8 @@ pub fn eval_packed( yield_constr.constraint(halt_state * (lv.program_counter - halt_pc)); } +/// Circuit version of `eval_packed`. +/// Evaluates constraints for the `halt` flag. pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, diff --git a/evm/src/cpu/jumps.rs b/evm/src/cpu/jumps.rs index 0c03e2d178..b75ef494e8 100644 --- a/evm/src/cpu/jumps.rs +++ b/evm/src/cpu/jumps.rs @@ -9,6 +9,7 @@ use crate::cpu::columns::CpuColumnsView; use crate::cpu::membus::NUM_GP_CHANNELS; use crate::memory::segments::Segment; +/// Evaluates constraints for EXIT_KERNEL. pub fn eval_packed_exit_kernel( lv: &CpuColumnsView
<P>, nv: &CpuColumnsView<P>
, @@ -26,6 +27,8 @@
     yield_constr.constraint_transition(filter * (input[7] - nv.gas[1]));
 }

+/// Circuit version of `eval_packed_exit_kernel`.
+/// Evaluates constraints for EXIT_KERNEL.
 pub fn eval_ext_circuit_exit_kernel<F: RichField + Extendable<D>, const D: usize>(
     builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
     lv: &CpuColumnsView<ExtensionTarget<D>>,
@@ -59,6 +62,7 @@
 }

+/// Evaluates constraints for jump operations: JUMP and JUMPI.
 pub fn eval_packed_jump_jumpi<P: PackedField>(
     lv: &CpuColumnsView
<P>, nv: &CpuColumnsView<P>
, @@ -156,6 +160,8 @@ pub fn eval_packed_jump_jumpi<P: PackedField>(
         .constraint_transition(filter * jumps_lv.should_jump * (nv.program_counter - jump_dest));
 }

+/// Circuit version of `eval_packed_jump_jumpi`.
+/// Evaluates constraints for jump operations: JUMP and JUMPI.
 pub fn eval_ext_circuit_jump_jumpi<F: RichField + Extendable<D>, const D: usize>(
     builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
     lv: &CpuColumnsView<ExtensionTarget<D>>,
@@ -353,6 +359,7 @@ pub fn eval_ext_circuit_jump_jumpi<F: RichField + Extendable<D>, const D: usize>
     }
 }

+/// Evaluates constraints for EXIT_KERNEL, JUMP and JUMPI.
 pub fn eval_packed<P: PackedField>(
     lv: &CpuColumnsView
<P>, nv: &CpuColumnsView<P>
, @@ -362,6 +369,8 @@ pub fn eval_packed( eval_packed_jump_jumpi(lv, nv, yield_constr); } +/// Circuit version of `eval_packed`. +/// Evaluates constraints for EXIT_KERNEL, JUMP and JUMPI. pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index bda2ab610e..522da72a25 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -110,7 +110,12 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/memory/packing.asm"), include_str!("asm/memory/syscalls.asm"), include_str!("asm/memory/txn_fields.asm"), - include_str!("asm/mpt/accounts.asm"), + include_str!("asm/smt/load.asm"), + include_str!("asm/smt/hash.asm"), + include_str!("asm/smt/insert.asm"), + include_str!("asm/smt/read.asm"), + include_str!("asm/smt/utils.asm"), + include_str!("asm/smt/accounts.asm"), include_str!("asm/mpt/delete/delete.asm"), include_str!("asm/mpt/delete/delete_branch.asm"), include_str!("asm/mpt/delete/delete_extension.asm"), diff --git a/evm/src/cpu/kernel/asm/account_code.asm b/evm/src/cpu/kernel/asm/account_code.asm index ee19819837..65b2962672 100644 --- a/evm/src/cpu/kernel/asm/account_code.asm +++ b/evm/src/cpu/kernel/asm/account_code.asm @@ -24,7 +24,7 @@ extcodehash_dead: global extcodehash: // stack: address, retdest - %mpt_read_state_trie + %smt_read_state // stack: account_ptr, retdest DUP1 ISZERO %jumpi(retzero) %add_const(3) @@ -80,110 +80,6 @@ global extcodesize: // stack: extcodesize(address), retdest SWAP1 JUMP -%macro extcodecopy - // stack: address, dest_offset, offset, size - %stack (address, dest_offset, offset, size) -> (address, dest_offset, offset, size, %%after) - %jump(extcodecopy) -%%after: -%endmacro - -// Pre stack: kexit_info, address, dest_offset, offset, size -// Post stack: (empty) -global sys_extcodecopy: - %stack (kexit_info, address, dest_offset, offset, size) - -> (address, dest_offset, offset, size, kexit_info) - %u256_to_addr DUP1 %insert_accessed_addresses - // stack: cold_access, address, dest_offset, offset, size, kexit_info - PUSH @GAS_COLDACCOUNTACCESS_MINUS_WARMACCESS - MUL - PUSH @GAS_WARMACCESS - ADD - // stack: Gaccess, address, dest_offset, offset, size, kexit_info - - DUP5 - // stack: size, Gaccess, address, dest_offset, offset, size, kexit_info - ISZERO %jumpi(sys_extcodecopy_empty) - - // stack: Gaccess, address, dest_offset, offset, size, kexit_info - DUP5 %num_bytes_to_num_words %mul_const(@GAS_COPY) ADD - %stack (gas, address, dest_offset, offset, size, kexit_info) -> (gas, kexit_info, address, dest_offset, offset, size) - %charge_gas - - %stack (kexit_info, address, dest_offset, offset, size) -> (dest_offset, size, kexit_info, address, dest_offset, offset, size) - %add_or_fault - // stack: expanded_num_bytes, kexit_info, address, dest_offset, offset, size - DUP1 %ensure_reasonable_offset - %update_mem_bytes - - %stack (kexit_info, address, dest_offset, offset, size) -> (address, dest_offset, offset, size, kexit_info) - %extcodecopy - // stack: kexit_info - EXIT_KERNEL - -sys_extcodecopy_empty: - %stack (Gaccess, address, dest_offset, offset, size, kexit_info) -> (Gaccess, kexit_info) - %charge_gas - EXIT_KERNEL - - -// Pre stack: address, dest_offset, offset, size, retdest -// Post stack: (empty) -global extcodecopy: - // stack: address, dest_offset, offset, size, retdest - %stack (address, dest_offset, offset, size, retdest) - -> (address, 0, 
@SEGMENT_KERNEL_ACCOUNT_CODE, extcodecopy_contd, size, offset, dest_offset, retdest) - %jump(load_code) - -extcodecopy_contd: - // stack: code_size, size, offset, dest_offset, retdest - DUP1 DUP4 - // stack: offset, code_size, code_size, size, offset, dest_offset, retdest - GT %jumpi(extcodecopy_large_offset) - - // stack: code_size, size, offset, dest_offset, retdest - DUP3 DUP3 ADD - // stack: offset + size, code_size, size, offset, dest_offset, retdest - DUP2 GT %jumpi(extcodecopy_within_bounds) - - // stack: code_size, size, offset, dest_offset, retdest - DUP3 DUP3 ADD - // stack: offset + size, code_size, size, offset, dest_offset, retdest - SUB - // stack: extra_size = offset + size - code_size, size, offset, dest_offset, retdest - DUP1 DUP3 SUB - // stack: copy_size = size - extra_size, extra_size, size, offset, dest_offset, retdest - - // Compute the new dest_offset after actual copies, at which we will start padding with zeroes. - DUP1 DUP6 ADD - // stack: new_dest_offset, copy_size, extra_size, size, offset, dest_offset, retdest - - GET_CONTEXT - %stack (context, new_dest_offset, copy_size, extra_size, size, offset, dest_offset, retdest) -> - (context, @SEGMENT_MAIN_MEMORY, dest_offset, 0, @SEGMENT_KERNEL_ACCOUNT_CODE, offset, copy_size, extcodecopy_end, new_dest_offset, extra_size, retdest) - %jump(memcpy_bytes) - -extcodecopy_within_bounds: - // stack: code_size, size, offset, dest_offset, retdest - GET_CONTEXT - %stack (context, code_size, size, offset, dest_offset, retdest) -> - (context, @SEGMENT_MAIN_MEMORY, dest_offset, 0, @SEGMENT_KERNEL_ACCOUNT_CODE, offset, size, retdest) - %jump(memcpy_bytes) - -// Same as extcodecopy_large_offset, but without `offset` in the stack. -extcodecopy_end: - // stack: dest_offset, size, retdest - GET_CONTEXT - %stack (context, dest_offset, size, retdest) -> - (context, @SEGMENT_MAIN_MEMORY, dest_offset, size, retdest) - %jump(memset) - -extcodecopy_large_offset: - // offset is larger than the code size. So we just have to write zeros. - // stack: code_size, size, offset, dest_offset, retdest - GET_CONTEXT - %stack (context, code_size, size, offset, dest_offset, retdest) -> (context, @SEGMENT_MAIN_MEMORY, dest_offset, size, retdest) - %jump(memset) - // Loads the code at `address` into memory, at the given context and segment, starting at offset 0. // Checks that the hash of the loaded code corresponds to the `codehash` in the state trie. // Pre stack: address, ctx, segment, retdest diff --git a/evm/src/cpu/kernel/asm/balance.asm b/evm/src/cpu/kernel/asm/balance.asm index f175d027c9..1dc480b5f7 100644 --- a/evm/src/cpu/kernel/asm/balance.asm +++ b/evm/src/cpu/kernel/asm/balance.asm @@ -27,9 +27,9 @@ global sys_balance: global balance: // stack: address, retdest - %mpt_read_state_trie + %smt_read_state // stack: account_ptr, retdest - DUP1 ISZERO %jumpi(retzero) // If the account pointer is null, return 0. + // No need to consider the case where `account_ptr=0` because `trie_data[1]=0`. 
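+    // (A null account_ptr becomes balance_ptr = 1 after the %add_const below, and
+    // trie_data[1] always holds 0, so a missing account naturally reads as balance 0.)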
%add_const(1) // stack: balance_ptr, retdest %mload_trie_data diff --git a/evm/src/cpu/kernel/asm/core/call.asm b/evm/src/cpu/kernel/asm/core/call.asm index af5ab3196c..25e331cb95 100644 --- a/evm/src/cpu/kernel/asm/core/call.asm +++ b/evm/src/cpu/kernel/asm/core/call.asm @@ -388,7 +388,7 @@ call_too_deep: args_size, %%after, // count, retdest new_ctx, args_size ) - %jump(memcpy) + %jump(memcpy_bytes) %%after: %stack (new_ctx, args_size) -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_CALLDATA_SIZE, args_size) @@ -410,7 +410,7 @@ call_too_deep: n, %%after, // count, retdest kexit_info, success ) - %jump(memcpy) + %jump(memcpy_bytes) %%after: %endmacro diff --git a/evm/src/cpu/kernel/asm/core/create.asm b/evm/src/cpu/kernel/asm/core/create.asm index ddaf96de03..4756b40702 100644 --- a/evm/src/cpu/kernel/asm/core/create.asm +++ b/evm/src/cpu/kernel/asm/core/create.asm @@ -104,7 +104,7 @@ global create_common: code_len, run_constructor, new_ctx, value, address) - %jump(memcpy) + %jump(memcpy_bytes) run_constructor: // stack: new_ctx, value, address, kexit_info diff --git a/evm/src/cpu/kernel/asm/core/create_receipt.asm b/evm/src/cpu/kernel/asm/core/create_receipt.asm index ec9b1fbd21..9478b19031 100644 --- a/evm/src/cpu/kernel/asm/core/create_receipt.asm +++ b/evm/src/cpu/kernel/asm/core/create_receipt.asm @@ -82,7 +82,7 @@ process_receipt_after_type: PUSH 0 PUSH @SEGMENT_TXN_BLOOM PUSH 0 // Bloom memory address. %get_trie_data_size PUSH @SEGMENT_TRIE_DATA PUSH 0 // MPT dest address. // stack: DST, SRC, 256, receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest - %memcpy + %memcpy_bytes // stack: receipt_ptr, txn_nb, new_cum_gas, txn_nb, num_nibbles, retdest // Update trie data size. %get_trie_data_size diff --git a/evm/src/cpu/kernel/asm/core/nonce.asm b/evm/src/cpu/kernel/asm/core/nonce.asm index 48486be9e2..2a4796645c 100644 --- a/evm/src/cpu/kernel/asm/core/nonce.asm +++ b/evm/src/cpu/kernel/asm/core/nonce.asm @@ -3,7 +3,7 @@ // Post stack: (empty) global nonce: // stack: address, retdest - %mpt_read_state_trie + %smt_read_state // stack: account_ptr, retdest // The nonce is the first account field, so we deref the account pointer itself. 
// Note: We don't need to handle account_ptr=0, as trie_data[0] = 0, @@ -23,7 +23,7 @@ global nonce: global increment_nonce: // stack: address, retdest DUP1 - %mpt_read_state_trie + %smt_read_state // stack: account_ptr, address, retdest DUP1 ISZERO %jumpi(increment_nonce_no_such_account) // stack: nonce_ptr, address, retdest diff --git a/evm/src/cpu/kernel/asm/core/precompiles/id.asm b/evm/src/cpu/kernel/asm/core/precompiles/id.asm index 0aa0894fd0..83cee0d042 100644 --- a/evm/src/cpu/kernel/asm/core/precompiles/id.asm +++ b/evm/src/cpu/kernel/asm/core/precompiles/id.asm @@ -32,7 +32,7 @@ global precompile_id: ctx, @SEGMENT_CALLDATA, 0, // SRC size, id_contd // count, retdest ) - %jump(memcpy) + %jump(memcpy_bytes) id_contd: // stack: kexit_info diff --git a/evm/src/cpu/kernel/asm/core/precompiles/main.asm b/evm/src/cpu/kernel/asm/core/precompiles/main.asm index b45b46cb0d..d6cb100bdc 100644 --- a/evm/src/cpu/kernel/asm/core/precompiles/main.asm +++ b/evm/src/cpu/kernel/asm/core/precompiles/main.asm @@ -59,7 +59,7 @@ global handle_precompiles_from_eoa: %stack (calldata_size, new_ctx) -> (calldata_size, new_ctx, calldata_size) %set_new_ctx_calldata_size %stack (new_ctx, calldata_size) -> (new_ctx, @SEGMENT_CALLDATA, 0, 0, @SEGMENT_TXN_DATA, 0, calldata_size, handle_precompiles_from_eoa_finish, new_ctx) - %jump(memcpy) + %jump(memcpy_bytes) handle_precompiles_from_eoa_finish: %stack (new_ctx, addr, retdest) -> (addr, new_ctx, retdest) diff --git a/evm/src/cpu/kernel/asm/core/precompiles/rip160.asm b/evm/src/cpu/kernel/asm/core/precompiles/rip160.asm index 20ea42cb58..0e5aee8cb1 100644 --- a/evm/src/cpu/kernel/asm/core/precompiles/rip160.asm +++ b/evm/src/cpu/kernel/asm/core/precompiles/rip160.asm @@ -47,7 +47,7 @@ global precompile_rip160: PUSH @SEGMENT_KERNEL_GENERAL DUP3 - %jump(memcpy) + %jump(memcpy_bytes) rip160_contd: // stack: hash, kexit_info diff --git a/evm/src/cpu/kernel/asm/core/precompiles/sha256.asm b/evm/src/cpu/kernel/asm/core/precompiles/sha256.asm index 97cf0f026f..6dad0745ba 100644 --- a/evm/src/cpu/kernel/asm/core/precompiles/sha256.asm +++ b/evm/src/cpu/kernel/asm/core/precompiles/sha256.asm @@ -49,7 +49,7 @@ global precompile_sha256: PUSH @SEGMENT_KERNEL_GENERAL DUP3 - %jump(memcpy) + %jump(memcpy_bytes) sha256_contd: // stack: hash, kexit_info diff --git a/evm/src/cpu/kernel/asm/core/process_txn.asm b/evm/src/cpu/kernel/asm/core/process_txn.asm index 779acab1da..95a36831b4 100644 --- a/evm/src/cpu/kernel/asm/core/process_txn.asm +++ b/evm/src/cpu/kernel/asm/core/process_txn.asm @@ -163,7 +163,7 @@ global process_contract_creation_txn: PUSH 0 // DST.offset PUSH @SEGMENT_CODE // DST.segment DUP8 // DST.context = new_ctx - %jump(memcpy) + %jump(memcpy_bytes) global process_contract_creation_txn_after_code_loaded: // stack: new_ctx, address, retdest @@ -294,7 +294,7 @@ global process_message_txn_code_loaded: %stack (calldata_size, new_ctx, retdest) -> (calldata_size, new_ctx, calldata_size, retdest) %set_new_ctx_calldata_size %stack (new_ctx, calldata_size, retdest) -> (new_ctx, @SEGMENT_CALLDATA, 0, 0, @SEGMENT_TXN_DATA, 0, calldata_size, process_message_txn_code_loaded_finish, new_ctx, retdest) - %jump(memcpy) + %jump(memcpy_bytes) process_message_txn_code_loaded_finish: %enter_new_ctx diff --git a/evm/src/cpu/kernel/asm/core/terminate.asm b/evm/src/cpu/kernel/asm/core/terminate.asm index 7bb7842fe3..bdbd3e5886 100644 --- a/evm/src/cpu/kernel/asm/core/terminate.asm +++ b/evm/src/cpu/kernel/asm/core/terminate.asm @@ -45,7 +45,7 @@ return_after_gas: ctx, 
@SEGMENT_MAIN_MEMORY, offset, // SRC size, sys_return_finish, kexit_info // count, retdest, ... ) - %jump(memcpy) + %jump(memcpy_bytes) sys_return_finish: // stack: kexit_info @@ -145,7 +145,7 @@ revert_after_gas: ctx, @SEGMENT_MAIN_MEMORY, offset, // SRC size, sys_revert_finish, kexit_info // count, retdest, ... ) - %jump(memcpy) + %jump(memcpy_bytes) sys_revert_finish: %leftover_gas diff --git a/evm/src/cpu/kernel/asm/core/transfer.asm b/evm/src/cpu/kernel/asm/core/transfer.asm index 0517cf3a8f..05127061af 100644 --- a/evm/src/cpu/kernel/asm/core/transfer.asm +++ b/evm/src/cpu/kernel/asm/core/transfer.asm @@ -29,7 +29,7 @@ global transfer_eth_failure: global deduct_eth: // stack: addr, amount, retdest DUP1 %insert_touched_addresses - %mpt_read_state_trie + %smt_read_state // stack: account_ptr, amount, retdest DUP1 ISZERO %jumpi(deduct_eth_no_such_account) // If the account pointer is null, return 1. %add_const(1) @@ -65,7 +65,7 @@ global deduct_eth_insufficient_balance: global add_eth: // stack: addr, amount, retdest DUP1 %insert_touched_addresses - DUP1 %mpt_read_state_trie + DUP1 %smt_read_state // stack: account_ptr, addr, amount, retdest DUP1 ISZERO %jumpi(add_eth_new_account) // If the account pointer is null, we need to create the account. %add_const(1) @@ -90,6 +90,7 @@ global add_eth_new_account: // stack: new_account_ptr, addr, amount, retdest SWAP2 // stack: amount, addr, new_account_ptr, retdest + PUSH 0 %append_to_trie_data // key PUSH 0 %append_to_trie_data // nonce %append_to_trie_data // balance // stack: addr, new_account_ptr, retdest @@ -98,7 +99,7 @@ global add_eth_new_account: // stack: addr, new_account_ptr, retdest %addr_to_state_key // stack: key, new_account_ptr, retdest - %jump(mpt_insert_state_trie) + %jump(smt_insert_state) add_eth_new_account_zero: // stack: addr, amount, retdest diff --git a/evm/src/cpu/kernel/asm/core/util.asm b/evm/src/cpu/kernel/asm/core/util.asm index ee33ff26ca..a186520cd7 100644 --- a/evm/src/cpu/kernel/asm/core/util.asm +++ b/evm/src/cpu/kernel/asm/core/util.asm @@ -40,7 +40,7 @@ // Returns 1 if the account is empty, 0 otherwise. 
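 // (Per EIP-161, an account is empty when its nonce and balance are 0 and its code is empty.)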
%macro is_empty // stack: addr - %mpt_read_state_trie + %smt_read_state // stack: account_ptr DUP1 ISZERO %jumpi(%%false) // stack: account_ptr diff --git a/evm/src/cpu/kernel/asm/main.asm b/evm/src/cpu/kernel/asm/main.asm index bd555218be..24ed0a43d9 100644 --- a/evm/src/cpu/kernel/asm/main.asm +++ b/evm/src/cpu/kernel/asm/main.asm @@ -10,7 +10,7 @@ global main: %jump(load_all_mpts) global hash_initial_tries: - %mpt_hash_state_trie %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_DIGEST_BEFORE) %assert_eq + %smt_hash_state %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_DIGEST_BEFORE) %assert_eq %mpt_hash_txn_trie %mload_global_metadata(@GLOBAL_METADATA_TXN_TRIE_DIGEST_BEFORE) %assert_eq %mpt_hash_receipt_trie %mload_global_metadata(@GLOBAL_METADATA_RECEIPT_TRIE_DIGEST_BEFORE) %assert_eq @@ -55,7 +55,7 @@ global hash_final_tries: DUP3 %mload_global_metadata(@GLOBAL_METADATA_TXN_NUMBER_AFTER) %assert_eq %pop3 %check_metadata_block_bloom - %mpt_hash_state_trie %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_DIGEST_AFTER) %assert_eq + %smt_hash_state %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_DIGEST_AFTER) %assert_eq %mpt_hash_txn_trie %mload_global_metadata(@GLOBAL_METADATA_TXN_TRIE_DIGEST_AFTER) %assert_eq %mpt_hash_receipt_trie %mload_global_metadata(@GLOBAL_METADATA_RECEIPT_TRIE_DIGEST_AFTER) %assert_eq %jump(halt) diff --git a/evm/src/cpu/kernel/asm/memory/memcpy.asm b/evm/src/cpu/kernel/asm/memory/memcpy.asm index e737dc33ca..99532dcb1b 100644 --- a/evm/src/cpu/kernel/asm/memory/memcpy.asm +++ b/evm/src/cpu/kernel/asm/memory/memcpy.asm @@ -42,12 +42,6 @@ global memcpy: // Continue the loop. %jump(memcpy) -memcpy_finish: - // stack: DST, SRC, count, retdest - %pop7 - // stack: retdest - JUMP - %macro memcpy %stack (dst: 3, src: 3, count) -> (dst, src, count, %%after) %jump(memcpy) @@ -58,15 +52,6 @@ memcpy_finish: global memcpy_bytes: // stack: DST, SRC, count, retdest - // Handle empty case - DUP7 - // stack: count, DST, SRC, count, retdest - ISZERO - // stack: count == 0, DST, SRC, count, retdest - %jumpi(memcpy_bytes_empty) - - // stack: DST, SRC, count, retdest - // Handle small case DUP7 // stack: count, DST, SRC, count, retdest @@ -110,6 +95,15 @@ global memcpy_bytes: memcpy_bytes_finish: // stack: DST, SRC, count, retdest + // Handle empty case + DUP7 + // stack: count, DST, SRC, count, retdest + ISZERO + // stack: count == 0, DST, SRC, count, retdest + %jumpi(memcpy_finish) + + // stack: DST, SRC, count, retdest + // Copy the last chunk of `count` bytes. DUP7 DUP1 @@ -126,12 +120,8 @@ memcpy_bytes_finish: MSTORE_32BYTES // stack: DST, SRC, count, retdest - %pop7 - // stack: retdest - JUMP - -memcpy_bytes_empty: - // stack: DST, SRC, 0, retdest +memcpy_finish: + // stack: DST, SRC, count, retdest %pop7 // stack: retdest JUMP diff --git a/evm/src/cpu/kernel/asm/memory/syscalls.asm b/evm/src/cpu/kernel/asm/memory/syscalls.asm index 1820056715..3a8c16184d 100644 --- a/evm/src/cpu/kernel/asm/memory/syscalls.asm +++ b/evm/src/cpu/kernel/asm/memory/syscalls.asm @@ -70,15 +70,10 @@ calldataload_large_offset: %stack (kexit_info, i) -> (kexit_info, 0) EXIT_KERNEL -// Macro for {CALLDATA,CODE,RETURNDATA}COPY (W_copy in Yellow Paper). +// Macro for {CALLDATA, RETURNDATA}COPY (W_copy in Yellow Paper). 
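+// CODECOPY and EXTCODECOPY are handled separately (see %codecopy_after_checks below), since
+// they may copy from another context's code and handle range overflow differently.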
%macro wcopy(segment, context_metadata_size) // stack: kexit_info, dest_offset, offset, size - PUSH @GAS_VERYLOW - DUP5 - // stack: size, Gverylow, kexit_info, dest_offset, offset, size - ISZERO %jumpi(wcopy_empty) - // stack: Gverylow, kexit_info, dest_offset, offset, size - DUP5 %num_bytes_to_num_words %mul_const(@GAS_COPY) ADD %charge_gas + %wcopy_charge_gas %stack (kexit_info, dest_offset, offset, size) -> (dest_offset, size, kexit_info, dest_offset, offset, size) %add_or_fault @@ -92,54 +87,44 @@ calldataload_large_offset: // stack: offset, total_size, kexit_info, dest_offset, offset, size GT %jumpi(wcopy_large_offset) + // stack: kexit_info, dest_offset, offset, size + GET_CONTEXT PUSH $segment - %mload_context_metadata($context_metadata_size) - // stack: total_size, segment, kexit_info, dest_offset, offset, size - DUP6 DUP6 ADD - // stack: offset + size, total_size, segment, kexit_info, dest_offset, offset, size - LT %jumpi(wcopy_within_bounds) - - %mload_context_metadata($context_metadata_size) - // stack: total_size, segment, kexit_info, dest_offset, offset, size - DUP6 DUP6 ADD - // stack: offset + size, total_size, segment, kexit_info, dest_offset, offset, size - SUB // extra_size = offset + size - total_size - // stack: extra_size, segment, kexit_info, dest_offset, offset, size - DUP1 DUP7 SUB - // stack: copy_size = size - extra_size, extra_size, segment, kexit_info, dest_offset, offset, size - - // Compute the new dest_offset after actual copies, at which we will start padding with zeroes. - DUP1 DUP6 ADD - // stack: new_dest_offset, copy_size, extra_size, segment, kexit_info, dest_offset, offset, size + // stack: segment, context, kexit_info, dest_offset, offset, size + %jump(wcopy_within_bounds) +%endmacro - GET_CONTEXT - %stack (context, new_dest_offset, copy_size, extra_size, segment, kexit_info, dest_offset, offset, size) -> - (context, @SEGMENT_MAIN_MEMORY, dest_offset, context, segment, offset, copy_size, wcopy_over_range, new_dest_offset, extra_size, kexit_info) - %jump(memcpy_bytes) +%macro wcopy_charge_gas + // stack: kexit_info, dest_offset, offset, size + PUSH @GAS_VERYLOW + DUP5 + // stack: size, Gverylow, kexit_info, dest_offset, offset, size + ISZERO %jumpi(wcopy_empty) + // stack: Gverylow, kexit_info, dest_offset, offset, size + DUP5 %num_bytes_to_num_words %mul_const(@GAS_COPY) ADD %charge_gas %endmacro + +codecopy_within_bounds: + // stack: total_size, segment, src_ctx, kexit_info, dest_offset, offset, size + POP wcopy_within_bounds: - // stack: segment, kexit_info, dest_offset, offset, size + // stack: segment, src_ctx, kexit_info, dest_offset, offset, size GET_CONTEXT - %stack (context, segment, kexit_info, dest_offset, offset, size) -> - (context, @SEGMENT_MAIN_MEMORY, dest_offset, context, segment, offset, size, wcopy_after, kexit_info) + %stack (context, segment, src_ctx, kexit_info, dest_offset, offset, size) -> + (context, @SEGMENT_MAIN_MEMORY, dest_offset, src_ctx, segment, offset, size, wcopy_after, kexit_info) %jump(memcpy_bytes) - -// Same as wcopy_large_offset, but without `offset` in the stack. 
-wcopy_over_range: - // stack: dest_offset, size, kexit_info - GET_CONTEXT - %stack (context, dest_offset, size, kexit_info) -> - (context, @SEGMENT_MAIN_MEMORY, dest_offset, size, wcopy_after, kexit_info) - %jump(memset) - wcopy_empty: // stack: Gverylow, kexit_info, dest_offset, offset, size %charge_gas %stack (kexit_info, dest_offset, offset, size) -> (kexit_info) EXIT_KERNEL + +codecopy_large_offset: + // stack: total_size, src_ctx, kexit_info, dest_offset, offset, size + %pop2 wcopy_large_offset: // offset is larger than the size of the {CALLDATA,CODE,RETURNDATA}. So we just have to write zeros. // stack: kexit_info, dest_offset, offset, size @@ -152,64 +137,107 @@ wcopy_after: // stack: kexit_info EXIT_KERNEL +// Pre stack: kexit_info, dest_offset, offset, size +// Post stack: (empty) global sys_calldatacopy: %wcopy(@SEGMENT_CALLDATA, @CTX_METADATA_CALLDATA_SIZE) -global sys_codecopy: - %wcopy(@SEGMENT_CODE, @CTX_METADATA_CODE_SIZE) - -// Same as %wcopy but with overflow checks. +// Pre stack: kexit_info, dest_offset, offset, size +// Post stack: (empty) global sys_returndatacopy: + DUP4 DUP4 %add_or_fault // Overflow check + %mload_context_metadata(@CTX_METADATA_RETURNDATA_SIZE) LT %jumpi(fault_exception) // Data len check + + %wcopy(@SEGMENT_RETURNDATA, @CTX_METADATA_RETURNDATA_SIZE) + +// Pre stack: kexit_info, dest_offset, offset, size +// Post stack: (empty) +global sys_codecopy: // stack: kexit_info, dest_offset, offset, size - PUSH @GAS_VERYLOW - // stack: Gverylow, kexit_info, dest_offset, offset, size - DUP5 %num_bytes_to_num_words %mul_const(@GAS_COPY) ADD %charge_gas + %wcopy_charge_gas %stack (kexit_info, dest_offset, offset, size) -> (dest_offset, size, kexit_info, dest_offset, offset, size) %add_or_fault // stack: expanded_num_bytes, kexit_info, dest_offset, offset, size, kexit_info DUP1 %ensure_reasonable_offset %update_mem_bytes - // stack: kexit_info, dest_offset, offset, size, kexit_info - DUP4 DUP4 %add_or_fault // Overflow check - %mload_context_metadata(@CTX_METADATA_RETURNDATA_SIZE) LT %jumpi(fault_exception) // Data len check - // stack: kexit_info, dest_offset, offset, size - DUP4 - // stack: size, kexit_info, dest_offset, offset, size - ISZERO %jumpi(returndatacopy_empty) + GET_CONTEXT + %mload_context_metadata(@CTX_METADATA_CODE_SIZE) + // stack: code_size, ctx, kexit_info, dest_offset, offset, size + %codecopy_after_checks(@SEGMENT_CODE) + + +// Pre stack: kexit_info, address, dest_offset, offset, size +// Post stack: (empty) +global sys_extcodecopy: + %stack (kexit_info, address, dest_offset, offset, size) + -> (address, dest_offset, offset, size, kexit_info) + %u256_to_addr DUP1 %insert_accessed_addresses + // stack: cold_access, address, dest_offset, offset, size, kexit_info + PUSH @GAS_COLDACCOUNTACCESS_MINUS_WARMACCESS + MUL + PUSH @GAS_WARMACCESS + ADD + // stack: Gaccess, address, dest_offset, offset, size, kexit_info - %mload_context_metadata(@CTX_METADATA_RETURNDATA_SIZE) - // stack: total_size, kexit_info, dest_offset, offset, size - DUP4 - // stack: offset, total_size, kexit_info, dest_offset, offset, size - GT %jumpi(wcopy_large_offset) + DUP5 + // stack: size, Gaccess, address, dest_offset, offset, size, kexit_info + ISZERO %jumpi(sys_extcodecopy_empty) + + // stack: Gaccess, address, dest_offset, offset, size, kexit_info + DUP5 %num_bytes_to_num_words %mul_const(@GAS_COPY) ADD + %stack (gas, address, dest_offset, offset, size, kexit_info) -> (gas, kexit_info, address, dest_offset, offset, size) + %charge_gas + + %stack (kexit_info, address, 
dest_offset, offset, size) -> (dest_offset, size, kexit_info, address, dest_offset, offset, size) + %add_or_fault + // stack: expanded_num_bytes, kexit_info, address, dest_offset, offset, size + DUP1 %ensure_reasonable_offset + %update_mem_bytes + + %stack (kexit_info, address, dest_offset, offset, size) -> + (address, 0, @SEGMENT_KERNEL_ACCOUNT_CODE, extcodecopy_contd, 0, kexit_info, dest_offset, offset, size) + %jump(load_code) + +sys_extcodecopy_empty: + %stack (Gaccess, address, dest_offset, offset, size, kexit_info) -> (Gaccess, kexit_info) + %charge_gas + EXIT_KERNEL - PUSH @SEGMENT_RETURNDATA - %mload_context_metadata(@CTX_METADATA_RETURNDATA_SIZE) - // stack: total_size, returndata_segment, kexit_info, dest_offset, offset, size - DUP6 DUP6 ADD - // stack: offset + size, total_size, returndata_segment, kexit_info, dest_offset, offset, size - LT %jumpi(wcopy_within_bounds) - - %mload_context_metadata(@CTX_METADATA_RETURNDATA_SIZE) - // stack: total_size, returndata_segment, kexit_info, dest_offset, offset, size - DUP6 DUP6 ADD - // stack: offset + size, total_size, returndata_segment, kexit_info, dest_offset, offset, size +extcodecopy_contd: + // stack: code_size, src_ctx, kexit_info, dest_offset, offset, size + %codecopy_after_checks(@SEGMENT_KERNEL_ACCOUNT_CODE) + + +// The internal logic is similar to wcopy, but handles range overflow differently. +// It is used for both CODECOPY and EXTCODECOPY. +%macro codecopy_after_checks(segment) + // stack: total_size, src_ctx, kexit_info, dest_offset, offset, size + DUP1 DUP6 + // stack: offset, total_size, total_size, src_ctx, kexit_info, dest_offset, offset, size + GT %jumpi(codecopy_large_offset) + + PUSH $segment SWAP1 + // stack: total_size, segment, src_ctx, kexit_info, dest_offset, offset, size + DUP1 DUP8 DUP8 ADD + // stack: offset + size, total_size, total_size, segment, src_ctx, kexit_info, dest_offset, offset, size + LT %jumpi(codecopy_within_bounds) + + // stack: total_size, segment, src_ctx, kexit_info, dest_offset, offset, size + DUP7 DUP7 ADD + // stack: offset + size, total_size, segment, src_ctx, kexit_info, dest_offset, offset, size SUB // extra_size = offset + size - total_size - // stack: extra_size, returndata_segment, kexit_info, dest_offset, offset, size - DUP1 DUP7 SUB - // stack: copy_size = size - extra_size, extra_size, returndata_segment, kexit_info, dest_offset, offset, size + // stack: extra_size, segment, src_ctx, kexit_info, dest_offset, offset, size + DUP1 DUP8 SUB + // stack: copy_size = size - extra_size, extra_size, segment, src_ctx, kexit_info, dest_offset, offset, size // Compute the new dest_offset after actual copies, at which we will start padding with zeroes. 
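     // (That is, new_dest_offset = dest_offset + copy_size.)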
- DUP1 DUP6 ADD - // stack: new_dest_offset, copy_size, extra_size, returndata_segment, kexit_info, dest_offset, offset, size + DUP1 DUP7 ADD + // stack: new_dest_offset, copy_size, extra_size, segment, src_ctx, kexit_info, dest_offset, offset, size GET_CONTEXT - %stack (context, new_dest_offset, copy_size, extra_size, returndata_segment, kexit_info, dest_offset, offset, size) -> - (context, @SEGMENT_MAIN_MEMORY, dest_offset, context, returndata_segment, offset, copy_size, wcopy_over_range, new_dest_offset, extra_size, kexit_info) + %stack (context, new_dest_offset, copy_size, extra_size, segment, src_ctx, kexit_info, dest_offset, offset, size) -> + (context, @SEGMENT_MAIN_MEMORY, dest_offset, src_ctx, segment, offset, copy_size, wcopy_large_offset, kexit_info, new_dest_offset, offset, extra_size) %jump(memcpy_bytes) - -returndatacopy_empty: - %stack (kexit_info, dest_offset, offset, size) -> (kexit_info) - EXIT_KERNEL +%endmacro diff --git a/evm/src/cpu/kernel/asm/mpt/load/load.asm b/evm/src/cpu/kernel/asm/mpt/load/load.asm index d787074b4f..fd378685e4 100644 --- a/evm/src/cpu/kernel/asm/mpt/load/load.asm +++ b/evm/src/cpu/kernel/asm/mpt/load/load.asm @@ -1,12 +1,8 @@ // Load all partial trie data from prover inputs. global load_all_mpts: // stack: retdest - // First set @GLOBAL_METADATA_TRIE_DATA_SIZE = 1. - // We don't want it to start at 0, as we use 0 as a null pointer. - PUSH 1 - %set_trie_data_size - %load_mpt(mpt_load_state_trie_value) %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) + %load_state_smt %load_mpt(mpt_load_txn_trie_value) %mstore_global_metadata(@GLOBAL_METADATA_TXN_TRIE_ROOT) %load_mpt(mpt_load_receipt_trie_value) %mstore_global_metadata(@GLOBAL_METADATA_RECEIPT_TRIE_ROOT) diff --git a/evm/src/cpu/kernel/asm/mpt/storage/storage_read.asm b/evm/src/cpu/kernel/asm/mpt/storage/storage_read.asm index 84d8d0efc9..7b889120a0 100644 --- a/evm/src/cpu/kernel/asm/mpt/storage/storage_read.asm +++ b/evm/src/cpu/kernel/asm/mpt/storage/storage_read.asm @@ -8,10 +8,9 @@ global sload_current: %stack (slot) -> (slot, after_storage_read) %slot_to_storage_key // stack: storage_key, after_storage_read - PUSH 64 // storage_key has 64 nibbles - %current_storage_trie - // stack: storage_root_ptr, 64, storage_key, after_storage_read - %jump(mpt_read) + %current_storage_smt + // stack: storage_root_ptr, storage_key, after_storage_read + %jump(smt_read) global after_storage_read: // stack: value_ptr, retdest diff --git a/evm/src/cpu/kernel/asm/mpt/storage/storage_write.asm b/evm/src/cpu/kernel/asm/mpt/storage/storage_write.asm index 08270dfa9e..5caafb3b5f 100644 --- a/evm/src/cpu/kernel/asm/mpt/storage/storage_write.asm +++ b/evm/src/cpu/kernel/asm/mpt/storage/storage_write.asm @@ -1,4 +1,4 @@ -// Write a word to the current account's storage trie. +// Write a word to the current account's storage SMT. // // Pre stack: kexit_info, slot, value // Post stack: (empty) @@ -92,26 +92,27 @@ sstore_after_refund: DUP2 %address %journal_add_storage_change // stack: slot, value, kexit_info - // If the value is zero, delete the slot from the storage trie. + // If the value is zero, delete the slot from the storage SMT. // stack: slot, value, kexit_info DUP2 ISZERO %jumpi(sstore_delete) - // First we write the value to MPT data, and get a pointer to it. + // First we write the value to SMT data, and get a pointer to it. %get_trie_data_size // stack: value_ptr, slot, value, kexit_info + PUSH 0 %append_to_trie_data // For the key. 
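+    // (SMT leaves are laid out as (key, value); cf. `Leaf { key, val_hash }` in smt/hash.asm.)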
+    // stack: value_ptr, slot, value, kexit_info
     SWAP2
     // stack: value, slot, value_ptr, kexit_info
     %append_to_trie_data
     // stack: slot, value_ptr, kexit_info
-    // Next, call mpt_insert on the current account's storage root.
+    // Next, call smt_insert on the current account's storage root.
     %stack (slot, value_ptr) -> (slot, value_ptr, after_storage_insert)
     %slot_to_storage_key
     // stack: storage_key, value_ptr, after_storage_insert, kexit_info
-    PUSH 64 // storage_key has 64 nibbles
-    %current_storage_trie
-    // stack: storage_root_ptr, 64, storage_key, value_ptr, after_storage_insert, kexit_info
-    %jump(mpt_insert)
+    %current_storage_smt
+    // stack: storage_root_ptr, storage_key, value_ptr, after_storage_insert, kexit_info
+    %jump(smt_insert)
 after_storage_insert:
     // stack: new_storage_root_ptr, kexit_info
@@ -130,8 +131,9 @@ sstore_noop:
     %pop3
     EXIT_KERNEL
 
-// Delete the slot from the storage trie.
+// Delete the slot from the storage SMT.
 sstore_delete:
+    PANIC // TODO: Not implemented for SMT.
     // stack: slot, value, kexit_info
     SWAP1 POP
     PUSH after_storage_insert SWAP1
@@ -139,6 +141,6 @@ sstore_delete:
     %slot_to_storage_key
     // stack: storage_key, after_storage_insert, kexit_info
     PUSH 64 // storage_key has 64 nibbles
-    %current_storage_trie
+    %current_storage_smt
     // stack: storage_root_ptr, 64, storage_key, after_storage_insert, kexit_info
     %jump(mpt_delete)
diff --git a/evm/src/cpu/kernel/asm/rlp/encode_rlp_string.asm b/evm/src/cpu/kernel/asm/rlp/encode_rlp_string.asm
index 1065c61209..1c8bec9673 100644
--- a/evm/src/cpu/kernel/asm/rlp/encode_rlp_string.asm
+++ b/evm/src/cpu/kernel/asm/rlp/encode_rlp_string.asm
@@ -33,7 +33,7 @@ global encode_rlp_string_small:
     // stack: pos'', pos', ADDR: 3, len, retdest
     %stack (pos2, pos1, ADDR: 3, len, retdest)
         -> (0, @SEGMENT_RLP_RAW, pos1, ADDR, len, retdest, pos2)
-    %jump(memcpy)
+    %jump(memcpy_bytes)
 
 global encode_rlp_string_small_single_byte:
     // stack: pos, ADDR: 3, len, retdest
@@ -71,7 +71,7 @@ global encode_rlp_string_large_after_writing_len:
     // stack: pos''', pos'', ADDR: 3, len, retdest
     %stack (pos3, pos2, ADDR: 3, len, retdest)
         -> (0, @SEGMENT_RLP_RAW, pos2, ADDR, len, retdest, pos3)
-    %jump(memcpy)
+    %jump(memcpy_bytes)
 
 %macro encode_rlp_string
     %stack (pos, ADDR: 3, len) -> (pos, ADDR, len, %%after)
diff --git a/evm/src/cpu/kernel/asm/mpt/accounts.asm b/evm/src/cpu/kernel/asm/smt/accounts.asm
similarity index 91%
rename from evm/src/cpu/kernel/asm/mpt/accounts.asm
rename to evm/src/cpu/kernel/asm/smt/accounts.asm
index 0ee987b4c1..439e18783b 100644
--- a/evm/src/cpu/kernel/asm/mpt/accounts.asm
+++ b/evm/src/cpu/kernel/asm/smt/accounts.asm
@@ -1,6 +1,6 @@
 // Return a pointer to the current account's data in the state trie.
 %macro current_account_data
-    %address %mpt_read_state_trie
+    %address %smt_read_state
     // stack: account_ptr
     // account_ptr should be non-null as long as the prover provided the proper
     // Merkle data. But a bad prover may not have, and we don't want to return a
@@ -10,7 +10,7 @@
 %endmacro
 
 // Returns a pointer to the root of the storage trie associated with the current account.
-%macro current_storage_trie
+%macro current_storage_smt
     // stack: (empty)
     %current_account_data
     // stack: account_ptr
diff --git a/evm/src/cpu/kernel/asm/smt/hash.asm b/evm/src/cpu/kernel/asm/smt/hash.asm
new file mode 100644
index 0000000000..5ac4fe3c6e
--- /dev/null
+++ b/evm/src/cpu/kernel/asm/smt/hash.asm
@@ -0,0 +1,172 @@
+%macro smt_hash_state
+    PUSH %%after %jump(smt_hash_state)
+%%after:
+%endmacro
+
+// Root hash of the state SMT.
+global smt_hash_state: + // stack: retdest + PUSH 0 %mstore_kernel_general(@SMT_IS_STORAGE) // is_storage flag. + %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) + +// Root hash of SMT stored at `trie_data[ptr]`. +// Pseudocode: +// ``` +// hash( HashNode { h } ) = h +// hash( InternalNode { left, right } ) = keccak(1 || hash(left) || hash(right)) // TODO: Domain separation in capacity when using Poseidon. See https://github.com/0xPolygonZero/plonky2/pull/1315#discussion_r1374780333. +// hash( Leaf { key, val_hash } ) = keccak(0 || key || val_hash) // TODO: Domain separation in capacity when using Poseidon. +// ``` +// where `val_hash` is `keccak(nonce || balance || storage_root || code_hash)` for accounts and +// `val` for a storage value. +global smt_hash: + // stack: ptr, retdest + DUP1 + %mload_trie_data + // stack: node, node_ptr, retdest + DUP1 %eq_const(@SMT_NODE_HASH) %jumpi(smt_hash_hash) + DUP1 %eq_const(@SMT_NODE_INTERNAL) %jumpi(smt_hash_internal) + DUP1 %eq_const(@SMT_NODE_LEAF) %jumpi(smt_hash_leaf) +smt_hash_unknown_node_type: + PANIC + +smt_hash_hash: + // stack: node, node_ptr, retdest + POP + // stack: node_ptr, retdest + %increment + // stack: node_ptr+1, retdest + %mload_trie_data + // stack: hash, retdest + SWAP1 JUMP + +smt_hash_internal: + // stack: node, node_ptr, retdest + POP + // stack: node_ptr, retdest + %increment + // stack: node_ptr+1, retdest + DUP1 + %mload_trie_data + %stack (left_child_ptr, node_ptr_plus_1, retdest) -> (left_child_ptr, smt_hash_internal_after_left, node_ptr_plus_1, retdest) + %jump(smt_hash) +smt_hash_internal_after_left: + // stack: left_hash, node_ptr+1, retdest + SWAP1 %increment + // stack: node_ptr+2, left_hash, retdest + %mload_trie_data + %stack (right_child_ptr, left_hash, retdest) -> (right_child_ptr, smt_hash_internal_after_right, left_hash, retdest) + %jump(smt_hash) +smt_hash_internal_after_right: + // stack: right_hash, left_hash, retdest + %stack (right_hash) -> (0, @SEGMENT_KERNEL_GENERAL, 33, right_hash, 32) + %mstore_unpacking POP + %stack (left_hash) -> (0, @SEGMENT_KERNEL_GENERAL, 1, left_hash, 32) + %mstore_unpacking POP + // stack: retdest + // Internal node flag. 
+ PUSH 1 %mstore_kernel_general(0) + %stack () -> (0, @SEGMENT_KERNEL_GENERAL, 0, 65) + KECCAK_GENERAL + // stack: hash, retdest + SWAP1 JUMP + +smt_hash_leaf: + // stack: node, node_ptr, retdest + POP + // stack: node_ptr, retdest + %increment + // stack: node_ptr+1, retdest + %mload_trie_data + // stack: payload_ptr, retdest + %mload_kernel_general(@SMT_IS_STORAGE) + // stack: is_value, payload_ptr, retdest + %jumpi(smt_hash_leaf_value) +smt_hash_leaf_account: + // stack: payload_ptr, retdest + DUP1 %mload_trie_data + // stack: key, payload_ptr, retdest + SWAP1 %increment + // stack: payload_ptr+1, key, retdest + DUP1 %mload_trie_data + // stack: nonce, payload_ptr+1, key, retdest + SWAP1 + // stack: payload_ptr+1, nonce, key, retdest + %increment + // stack: payload_ptr+2, nonce, key, retdest + DUP1 %mload_trie_data + // stack: balance, payload_ptr+2, nonce, key, retdest + SWAP1 + // stack: payload_ptr+2, balance, nonce, key, retdest + %increment + // stack: payload_ptr+3, balance, nonce, key, retdest + DUP1 %mload_trie_data + // stack: storage_root, payload_ptr+3, balance, nonce, key, retdest + PUSH 1 %mstore_kernel_general(@SMT_IS_STORAGE) + %stack (storage_root) -> (storage_root, smt_hash_leaf_account_after_storage) + %jump(smt_hash) +smt_hash_leaf_account_after_storage: + PUSH 0 %mstore_kernel_general(@SMT_IS_STORAGE) + // stack: storage_root_hash, payload_ptr+3, balance, nonce, key, retdest + SWAP1 + // stack: payload_ptr+3, storage_root_hash, balance, nonce, key, retdest + %increment + // stack: payload_ptr+4, storage_root_hash, balance, nonce, key, retdest + %mload_trie_data + // stack: code_hash, storage_root_hash, balance, nonce, key, retdest + + // 0----7 | 8----39 | 40--------71 | 72----103 + // nonce | balance | storage_root | code_hash + + // TODO: The way we do the `mstore_unpacking`s could be optimized. See https://github.com/0xPolygonZero/plonky2/pull/1315#discussion_r1378207927. 
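+    // The 104-byte buffer assembled below is hashed to obtain val_hash, which
+    // is then hashed together with the leaf flag (0) and the key, as in the
+    // pseudocode at the top of this file.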
+    %stack (code_hash) -> (0, @SEGMENT_KERNEL_GENERAL, 72, code_hash, 32)
+    %mstore_unpacking POP
+
+    %stack (storage_root) -> (0, @SEGMENT_KERNEL_GENERAL, 40, storage_root, 32)
+    %mstore_unpacking POP
+
+    %stack (balance) -> (0, @SEGMENT_KERNEL_GENERAL, 8, balance, 32)
+    %mstore_unpacking POP
+
+    %stack (nonce) -> (0, @SEGMENT_KERNEL_GENERAL, 0, nonce)
+    %mstore_unpacking_u64_LE
+
+    // stack: key, retdest
+    %stack () -> (0, @SEGMENT_KERNEL_GENERAL, 0, 104)
+    KECCAK_GENERAL
+    // stack: hash, key, retdest
+
+    // Leaf flag
+    PUSH 0 %mstore_kernel_general(0)
+
+    %stack (hash) -> (0, @SEGMENT_KERNEL_GENERAL, 33, hash, 32)
+    %mstore_unpacking POP
+
+    %stack (key) -> (0, @SEGMENT_KERNEL_GENERAL, 1, key, 32)
+    %mstore_unpacking POP
+
+    %stack () -> (0, @SEGMENT_KERNEL_GENERAL, 0, 65)
+    KECCAK_GENERAL
+
+    SWAP1 JUMP
+
+smt_hash_leaf_value:
+    // stack: payload_ptr, retdest
+    DUP1 %mload_trie_data
+    // stack: key, payload_ptr, retdest
+    SWAP1
+    // stack: payload_ptr, key, retdest
+    %increment
+    // stack: payload_ptr+1, key, retdest
+    %mload_trie_data
+    // stack: value, key, retdest
+    PUSH 0 %mstore_kernel_general(0)
+    %stack (value) -> (0, @SEGMENT_KERNEL_GENERAL, 33, value, 32)
+    %mstore_unpacking POP
+    // stack: key, retdest
+    %stack (key) -> (0, @SEGMENT_KERNEL_GENERAL, 1, key, 32)
+    %mstore_unpacking POP
+    // stack: retdest
+    %stack () -> (0, @SEGMENT_KERNEL_GENERAL, 0, 65)
+    KECCAK_GENERAL
+    // stack: hash, retdest
+    SWAP1 JUMP
diff --git a/evm/src/cpu/kernel/asm/smt/insert.asm b/evm/src/cpu/kernel/asm/smt/insert.asm
new file mode 100644
index 0000000000..32b248c66a
--- /dev/null
+++ b/evm/src/cpu/kernel/asm/smt/insert.asm
@@ -0,0 +1,109 @@
+// Insert a key-value pair in the state SMT.
+global smt_insert_state:
+    // stack: key, new_account_ptr, retdest
+    %stack (key, new_account_ptr) -> (key, new_account_ptr, smt_insert_state_set_root)
+    %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT)
+    // stack: root_ptr, key, new_account_ptr, smt_insert_state_set_root, retdest
+    %jump(smt_insert)
+
+smt_insert_state_set_root:
+    // stack: root_ptr, retdest
+    %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT)
+    // stack: retdest
+    JUMP
+
+// Insert a key-value pair in the SMT at `trie_data[node_ptr]`.
+// `value_ptr` should point to an empty slot reserved for `rem_key`, followed by the actual value.
+// Pseudocode:
+// ```
+// insert( HashNode { h }, key, value_ptr ) = if h == 0 then Leaf { key, value_ptr } else PANIC
+// insert( InternalNode { left, right }, key, value_ptr ) = if key&1 { insert( right, key>>1, value_ptr ) } else { insert( left, key>>1, value_ptr ) }
+// insert( Leaf { key', value_ptr' }, key, value_ptr ) = {
+//     let internal = new InternalNode;
+//     insert(internal, key', value_ptr');
+//     insert(internal, key, value_ptr);
+//     return internal;}
+// ```
+global smt_insert:
+    // stack: node_ptr, key, value_ptr, retdest
+    DUP1 %mload_trie_data
+    // stack: node_type, node_ptr, key, value_ptr, retdest
+    // Increment node_ptr, so it points to the node payload instead of its type.
+    SWAP1 %increment SWAP1
+    // stack: node_type, node_payload_ptr, key, value_ptr, retdest
+
+    DUP1 %eq_const(@SMT_NODE_HASH) %jumpi(smt_insert_hash)
+    DUP1 %eq_const(@SMT_NODE_INTERNAL) %jumpi(smt_insert_internal)
+    DUP1 %eq_const(@SMT_NODE_LEAF) %jumpi(smt_insert_leaf)
+    PANIC
+
+smt_insert_hash:
+    // stack: node_type, node_payload_ptr, key, value_ptr, retdest
+    POP
+    // stack: node_payload_ptr, key, value_ptr, retdest
+    %mload_trie_data
+    // stack: hash, key, value_ptr, retdest
+    ISZERO %jumpi(smt_insert_empty)
+    PANIC // Trying to insert into a non-empty hash node.
+smt_insert_empty:
+    // stack: key, value_ptr, retdest
+    %get_trie_data_size
+    // stack: index, key, value_ptr, retdest
+    PUSH @SMT_NODE_LEAF %append_to_trie_data
+    %stack (index, key, value_ptr) -> (value_ptr, key, value_ptr, index)
+    %mstore_trie_data
+    // stack: value_ptr, index, retdest
+    %append_to_trie_data
+    // stack: index, retdest
+    SWAP1 JUMP
+
+smt_insert_internal:
+    // stack: node_type, node_payload_ptr, key, value_ptr, retdest
+    POP
+    // stack: node_payload_ptr, key, value_ptr, retdest
+    SWAP1
+    // stack: key, node_payload_ptr, value_ptr, retdest
+    %pop_bit
+    %stack (bit, key, node_payload_ptr, value_ptr, retdest) -> (bit, node_payload_ptr, node_payload_ptr, key, value_ptr, smt_insert_internal_after, retdest)
+    ADD
+    // stack: child_ptr_ptr, node_payload_ptr, key, value_ptr, smt_insert_internal_after, retdest
+    DUP1 %mload_trie_data
+    %stack (child_ptr, child_ptr_ptr, node_payload_ptr, key, value_ptr, smt_insert_internal_after) -> (child_ptr, key, value_ptr, smt_insert_internal_after, child_ptr_ptr, node_payload_ptr)
+    %jump(smt_insert)
+
+smt_insert_internal_after:
+    // stack: new_node_ptr, child_ptr_ptr, node_payload_ptr, retdest
+    SWAP1 %mstore_trie_data
+    // stack: node_payload_ptr, retdest
+    %decrement
+    SWAP1 JUMP
+
+smt_insert_leaf:
+    // stack: node_type, node_payload_ptr_ptr, key, value_ptr, retdest
+    POP
+    %stack (node_payload_ptr_ptr, key) -> (node_payload_ptr_ptr, key, node_payload_ptr_ptr, key)
+    %mload_trie_data %mload_trie_data EQ %jumpi(smt_insert_leaf_same_key)
+    // stack: node_payload_ptr_ptr, key, value_ptr, retdest
+    // We create an internal node with two empty children, and then we insert the two leaves.
+    %get_trie_data_size
+    // stack: index, node_payload_ptr_ptr, key, value_ptr, retdest
+    PUSH @SMT_NODE_INTERNAL %append_to_trie_data
+    PUSH 0 %append_to_trie_data // Empty hash node
+    PUSH 0 %append_to_trie_data // Empty hash node
+    %stack (index, node_payload_ptr_ptr, key, value_ptr) -> (index, key, value_ptr, after_first_leaf, node_payload_ptr_ptr)
+    %jump(smt_insert)
+after_first_leaf:
+    // stack: internal_ptr, node_payload_ptr_ptr, retdest
+    SWAP1
+    // stack: node_payload_ptr_ptr, internal_ptr, retdest
+    %mload_trie_data DUP1 %mload_trie_data
+    %stack (key, node_payload_ptr, internal_ptr) -> (internal_ptr, key, node_payload_ptr, after_second_leaf)
+    %jump(smt_insert)
+after_second_leaf:
+    // stack: internal_ptr, retdest
+    SWAP1 JUMP
+
+
+smt_insert_leaf_same_key:
+    // stack: node_payload_ptr, key, value_ptr, retdest
+    PANIC // Not sure if this should happen.
diff --git a/evm/src/cpu/kernel/asm/smt/load.asm b/evm/src/cpu/kernel/asm/smt/load.asm
new file mode 100644
index 0000000000..72a469f139
--- /dev/null
+++ b/evm/src/cpu/kernel/asm/smt/load.asm
@@ -0,0 +1,35 @@
+%macro load_state_smt
+    PUSH %%after %jump(load_state_smt)
+%%after:
+%endmacro
+
+// Simply copy the serialized state SMT to `TrieData`.
+// First entry is the length of the serialized data.
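+// Subsequent prover inputs are the serialized nodes themselves, copied to
+// `TrieData` one word at a time.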
+global load_state_smt:
+    // stack: retdest
+    PROVER_INPUT(smt::state)
+    // stack: len, retdest
+    %get_trie_data_size
+    // stack: i, len, retdest
+    DUP2 %mstore_global_metadata(@GLOBAL_METADATA_TRIE_DATA_SIZE)
+    // stack: i, len, retdest
+    DUP1 %add_const(2) // First two entries are [0,0] for an empty hash node.
+    %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT)
+    // stack: i, len, retdest
+    %stack (i, len) -> (len, i, i)
+    ADD SWAP1
+    // stack: i, len, retdest
+loop:
+    // stack: i, len, retdest
+    DUP2 DUP2 EQ %jumpi(loop_end)
+    // stack: i, len, retdest
+    PROVER_INPUT(smt::state)
+    DUP2
+    // stack: i, x, i, len, retdest
+    %mstore_trie_data
+    // stack: i, len, retdest
+    %increment
+    %jump(loop)
+loop_end:
+    // stack: i, len, retdest
+    %pop2 JUMP
diff --git a/evm/src/cpu/kernel/asm/smt/read.asm b/evm/src/cpu/kernel/asm/smt/read.asm
new file mode 100644
index 0000000000..58ae5690cb
--- /dev/null
+++ b/evm/src/cpu/kernel/asm/smt/read.asm
@@ -0,0 +1,81 @@
+// Given an address, return a pointer to the associated account data, which
+// consists of four words (nonce, balance, storage_root_ptr, code_hash), in the
+// state SMT. Returns null if the address is not found.
+global smt_read_state:
+    // stack: addr, retdest
+    %addr_to_state_key
+    // stack: key, retdest
+    %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) // node_ptr
+    // stack: node_ptr, key, retdest
+    %jump(smt_read)
+
+// Convenience macro to call smt_read_state and return where we left off.
+%macro smt_read_state
+    %stack (addr) -> (addr, %%after)
+    %jump(smt_read_state)
+%%after:
+%endmacro
+
+// Return the data at the given key in the SMT at `trie_data[node_ptr]`.
+// Pseudocode:
+// ```
+// read( HashNode { h }, key ) = if h == 0 then 0 else PANIC
+// read( InternalNode { left, right }, key ) = if key&1 { read( right, key>>1 ) } else { read( left, key>>1 ) }
+// read( Leaf { key', value_ptr' }, key ) = if key == key' then value_ptr' else 0
+// ```
+global smt_read:
+    // stack: node_ptr, key, retdest
+    DUP1 %mload_trie_data
+    // stack: node_type, node_ptr, key, retdest
+    // Increment node_ptr, so it points to the node payload instead of its type.
+    SWAP1 %increment SWAP1
+    // stack: node_type, node_payload_ptr, key, retdest
+
+    DUP1 %eq_const(@SMT_NODE_HASH) %jumpi(smt_read_hash)
+    DUP1 %eq_const(@SMT_NODE_INTERNAL) %jumpi(smt_read_internal)
+    DUP1 %eq_const(@SMT_NODE_LEAF) %jumpi(smt_read_leaf)
+    PANIC
+
+smt_read_hash:
+    // stack: node_type, node_payload_ptr, key, retdest
+    POP
+    // stack: node_payload_ptr, key, retdest
+    %mload_trie_data
+    // stack: hash, key, retdest
+    ISZERO %jumpi(smt_read_empty)
+    PANIC // Trying to read a non-empty hash node. Should never happen.
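+// A hash node holding 0 encodes an empty subtree, so the read misses and we
+// return the null pointer 0.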
+ +smt_read_empty: + %stack (key, retdest) -> (retdest, 0) + JUMP + +smt_read_internal: + // stack: node_type, node_payload_ptr, key, retdest + POP + // stack: node_payload_ptr, key, retdest + SWAP1 + // stack: key, node_payload_ptr, retdest + %pop_bit + %stack (bit, key, node_payload_ptr) -> (bit, node_payload_ptr, key) + ADD + // stack: child_ptr_ptr, key, retdest + %mload_trie_data + %jump(smt_read) + +smt_read_leaf: + // stack: node_type, node_payload_ptr_ptr, key, retdest + POP + // stack: node_payload_ptr_ptr, key, retdest + %mload_trie_data + %stack (node_payload_ptr, key) -> (node_payload_ptr, key, node_payload_ptr) + %mload_trie_data EQ %jumpi(smt_read_existing_leaf) // Checking if the key exists +smt_read_non_existing_leaf: + %stack (node_payload_ptr_ptr, retdest) -> (retdest, 0) + JUMP + +smt_read_existing_leaf: + // stack: node_payload_ptr, retdest + %increment // We want to point to the account values, not the key. + SWAP1 JUMP + + diff --git a/evm/src/cpu/kernel/asm/smt/utils.asm b/evm/src/cpu/kernel/asm/smt/utils.asm new file mode 100644 index 0000000000..145f61310a --- /dev/null +++ b/evm/src/cpu/kernel/asm/smt/utils.asm @@ -0,0 +1,7 @@ +%macro pop_bit + // stack: key + DUP1 %shr_const(1) + // stack: key>>1, key + SWAP1 %and_const(1) + // stack: key&1, key>>1 +%endmacro diff --git a/evm/src/cpu/kernel/asm/transactions/common_decoding.asm b/evm/src/cpu/kernel/asm/transactions/common_decoding.asm index 9b12d9c931..d4df7a6e57 100644 --- a/evm/src/cpu/kernel/asm/transactions/common_decoding.asm +++ b/evm/src/cpu/kernel/asm/transactions/common_decoding.asm @@ -115,7 +115,7 @@ PUSH @SEGMENT_TXN_DATA GET_CONTEXT // stack: DST, SRC, data_len, %%after, new_pos - %jump(memcpy) + %jump(memcpy_bytes) %%after: // stack: new_pos diff --git a/evm/src/cpu/kernel/asm/transactions/type_1.asm b/evm/src/cpu/kernel/asm/transactions/type_1.asm index f8396e50a4..68d998aeab 100644 --- a/evm/src/cpu/kernel/asm/transactions/type_1.asm +++ b/evm/src/cpu/kernel/asm/transactions/type_1.asm @@ -94,7 +94,7 @@ after_serializing_txn_data: al_len, after_serializing_access_list, rlp_pos, rlp_start, retdest) - %jump(memcpy) + %jump(memcpy_bytes) after_serializing_access_list: // stack: rlp_pos, rlp_start, retdest %mload_global_metadata(@GLOBAL_METADATA_ACCESS_LIST_RLP_LEN) ADD diff --git a/evm/src/cpu/kernel/asm/transactions/type_2.asm b/evm/src/cpu/kernel/asm/transactions/type_2.asm index 38f1980fae..c57621617f 100644 --- a/evm/src/cpu/kernel/asm/transactions/type_2.asm +++ b/evm/src/cpu/kernel/asm/transactions/type_2.asm @@ -101,7 +101,7 @@ after_serializing_txn_data: al_len, after_serializing_access_list, rlp_pos, rlp_start, retdest) - %jump(memcpy) + %jump(memcpy_bytes) after_serializing_access_list: // stack: rlp_pos, rlp_start, retdest %mload_global_metadata(@GLOBAL_METADATA_ACCESS_LIST_RLP_LEN) ADD diff --git a/evm/src/cpu/kernel/constants/mod.rs b/evm/src/cpu/kernel/constants/mod.rs index 77abde994f..99e6404440 100644 --- a/evm/src/cpu/kernel/constants/mod.rs +++ b/evm/src/cpu/kernel/constants/mod.rs @@ -6,6 +6,7 @@ use hex_literal::hex; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::constants::journal_entry::JournalEntry; +use crate::cpu::kernel::constants::smt_type::PartialSmtType; use crate::cpu::kernel::constants::trie_type::PartialTrieType; use crate::cpu::kernel::constants::txn_fields::NormalizedTxnField; use crate::memory::segments::Segment; @@ -14,6 +15,7 @@ pub(crate) mod 
context_metadata;
 mod exc_bitfields;
 pub(crate) mod global_metadata;
 pub(crate) mod journal_entry;
+pub(crate) mod smt_type;
 pub(crate) mod trie_type;
 pub(crate) mod txn_fields;
@@ -57,6 +59,8 @@ pub fn evm_constants() -> HashMap<String, U256> {
     c.insert(MAX_NONCE.0.into(), U256::from(MAX_NONCE.1));
     c.insert(CALL_STACK_LIMIT.0.into(), U256::from(CALL_STACK_LIMIT.1));
 
+    c.insert(SMT_IS_STORAGE.0.into(), U256::from(SMT_IS_STORAGE.1));
+
     for segment in Segment::all() {
         c.insert(segment.var_name().into(), (segment as u32).into());
     }
@@ -72,6 +76,9 @@ pub fn evm_constants() -> HashMap<String, U256> {
     for trie_type in PartialTrieType::all() {
         c.insert(trie_type.var_name().into(), (trie_type as u32).into());
     }
+    for smt_type in PartialSmtType::all() {
+        c.insert(smt_type.var_name().into(), (smt_type as u32).into());
+    }
     for entry in JournalEntry::all() {
         c.insert(entry.var_name().into(), (entry as u32).into());
    }
@@ -270,3 +277,6 @@ const CODE_SIZE_LIMIT: [(&str, u64); 3] = [
 
 const MAX_NONCE: (&str, u64) = ("MAX_NONCE", 0xffffffffffffffff);
 const CALL_STACK_LIMIT: (&str, u64) = ("CALL_STACK_LIMIT", 1024);
+
+// Address, in the kernel general segment, of a flag that is set to 1 when hashing storage SMTs. Used in `smt_hash`.
+const SMT_IS_STORAGE: (&str, u64) = ("SMT_IS_STORAGE", 13371337);
diff --git a/evm/src/cpu/kernel/constants/smt_type.rs b/evm/src/cpu/kernel/constants/smt_type.rs
new file mode 100644
index 0000000000..b134598bc6
--- /dev/null
+++ b/evm/src/cpu/kernel/constants/smt_type.rs
@@ -0,0 +1,23 @@
+#[derive(Copy, Clone, Debug)]
+pub(crate) enum PartialSmtType {
+    Hash = 0,
+    Internal = 1,
+    Leaf = 2,
+}
+
+impl PartialSmtType {
+    pub(crate) const COUNT: usize = 3;
+
+    pub(crate) fn all() -> [Self; Self::COUNT] {
+        [Self::Hash, Self::Internal, Self::Leaf]
+    }
+
+    /// The variable name that gets passed into kernel assembly code.
+    pub(crate) fn var_name(&self) -> &'static str {
+        match self {
+            Self::Hash => "SMT_NODE_HASH",
+            Self::Internal => "SMT_NODE_INTERNAL",
+            Self::Leaf => "SMT_NODE_LEAF",
+        }
+    }
+}
diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs
index 4ba6e9dcfe..96e2810c2d 100644
--- a/evm/src/cpu/kernel/interpreter.rs
+++ b/evm/src/cpu/kernel/interpreter.rs
@@ -188,6 +188,7 @@ impl<'a> Interpreter<'a> {
             .set(field as usize, value)
     }
 
+    #[allow(unused)]
     pub(crate) fn get_trie_data(&self) -> &[U256] {
         &self.generation_state.memory.contexts[0].segments[Segment::TrieData as usize].content
     }
@@ -426,7 +427,7 @@ impl<'a> Interpreter<'a> {
             0xf6 => self.run_get_context(),   // "GET_CONTEXT",
             0xf7 => self.run_set_context(),   // "SET_CONTEXT",
             0xf8 => self.run_mload_32bytes(), // "MLOAD_32BYTES",
-            0xf9 => todo!(),                  // "EXIT_KERNEL",
+            0xf9 => self.run_exit_kernel(),   // "EXIT_KERNEL",
             0xfa => todo!(),                  // "STATICCALL",
             0xfb => self.run_mload_general(), // "MLOAD_GENERAL",
             0xfc => self.run_mstore_general(), // "MSTORE_GENERAL",
@@ -1126,6 +1127,24 @@ impl<'a> Interpreter<'a> {
         }
     }
 
+    fn run_exit_kernel(&mut self) {
+        let kexit_info = self.pop();
+
+        let kexit_info_u64 = kexit_info.0[0];
+        let program_counter = kexit_info_u64 as u32 as usize;
+        let is_kernel_mode_val = (kexit_info_u64 >> 32) as u32;
+        assert!(is_kernel_mode_val == 0 || is_kernel_mode_val == 1);
+        let is_kernel_mode = is_kernel_mode_val != 0;
+        let gas_used_val = kexit_info.0[3];
+        if TryInto::<u32>::try_into(gas_used_val).is_err() {
+            panic!("Gas overflow");
+        }
+
+        self.generation_state.registers.program_counter = program_counter;
+        self.generation_state.registers.is_kernel = is_kernel_mode;
+        self.generation_state.registers.gas_used = gas_used_val;
+    }
+
     pub(crate) fn stack_len(&self) -> usize {
         self.generation_state.registers.stack_len
     }
diff --git a/evm/src/cpu/kernel/tests/account_code.rs b/evm/src/cpu/kernel/tests/account_code.rs
index f4c18fe603..9283af77e7 100644
--- a/evm/src/cpu/kernel/tests/account_code.rs
+++ b/evm/src/cpu/kernel/tests/account_code.rs
@@ -1,25 +1,25 @@
 use std::collections::HashMap;
 
 use anyhow::{anyhow, Result};
-use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie};
 use ethereum_types::{Address, BigEndianHash, H256, U256};
 use keccak_hash::keccak;
-use rand::{thread_rng, Rng};
+use rand::{random, thread_rng, Rng};
+use smt_utils::account::Account;
+use smt_utils::smt::Smt;
 
 use crate::cpu::kernel::aggregator::KERNEL;
+use crate::cpu::kernel::constants::context_metadata::ContextMetadata::GasLimit;
 use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
 use crate::cpu::kernel::interpreter::Interpreter;
-use crate::cpu::kernel::tests::mpt::nibbles_64;
-use crate::generation::mpt::{all_mpt_prover_inputs_reversed, AccountRlp};
+use crate::generation::mpt::{all_mpt_prover_inputs_reversed, state_smt_prover_inputs_reversed};
 use crate::memory::segments::Segment;
-use crate::Node;
 
 // Test account with a given code hash.
-fn test_account(code: &[u8]) -> AccountRlp { - AccountRlp { - nonce: U256::from(1111), +fn test_account(code: &[u8]) -> Account { + Account { + nonce: 1111, balance: U256::from(2222), - storage_root: HashedPartialTrie::from(Node::Empty).hash(), + storage_smt: Smt::empty(), code_hash: keccak(code), } } @@ -35,36 +35,32 @@ fn random_code() -> Vec { fn prepare_interpreter( interpreter: &mut Interpreter, address: Address, - account: &AccountRlp, + account: Account, ) -> Result<()> { let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; - let mpt_insert_state_trie = KERNEL.global_labels["mpt_insert_state_trie"]; - let mpt_hash_state_trie = KERNEL.global_labels["mpt_hash_state_trie"]; - let mut state_trie: HashedPartialTrie = Default::default(); + let smt_insert_state = KERNEL.global_labels["smt_insert_state"]; + let smt_hash_state = KERNEL.global_labels["smt_hash_state"]; + let mut state_smt = Smt::empty(); let trie_inputs = Default::default(); interpreter.generation_state.registers.program_counter = load_all_mpts; interpreter.push(0xDEADBEEFu32.into()); + interpreter.generation_state.state_smt_prover_inputs = + state_smt_prover_inputs_reversed(&trie_inputs); interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs) .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); - let k = nibbles_64(U256::from_big_endian( - keccak(address.to_fixed_bytes()).as_bytes(), - )); - // Next, execute mpt_insert_state_trie. - interpreter.generation_state.registers.program_counter = mpt_insert_state_trie; + let k = keccak(address.to_fixed_bytes()); + // Next, execute smt_insert_state. + interpreter.generation_state.registers.program_counter = smt_insert_state; let trie_data = interpreter.get_trie_data_mut(); - if trie_data.is_empty() { - // In the assembly we skip over 0, knowing trie_data[0] = 0 by default. - // Since we don't explicitly set it to 0, we need to do so here. - trie_data.push(0.into()); - } let value_ptr = trie_data.len(); - trie_data.push(account.nonce); + trie_data.push(U256::zero()); // For key. + trie_data.push(account.nonce.into()); trie_data.push(account.balance); // In memory, storage_root gets interpreted as a pointer to a storage trie, // so we have to ensure the pointer is valid. It's easiest to set it to 0, @@ -75,7 +71,7 @@ fn prepare_interpreter( interpreter.set_global_metadata_field(GlobalMetadata::TrieDataSize, trie_data_len); interpreter.push(0xDEADBEEFu32.into()); interpreter.push(value_ptr.into()); // value_ptr - interpreter.push(k.try_into_u256().unwrap()); // key + interpreter.push(k.into_uint()); // key interpreter.run()?; assert_eq!( @@ -86,7 +82,7 @@ fn prepare_interpreter( ); // Now, execute mpt_hash_state_trie. 
- interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; + interpreter.generation_state.registers.program_counter = smt_hash_state; interpreter.push(0xDEADBEEFu32.into()); interpreter.run()?; @@ -98,8 +94,8 @@ fn prepare_interpreter( ); let hash = H256::from_uint(&interpreter.stack()[0]); - state_trie.insert(k, rlp::encode(account).to_vec()); - let expected_state_trie_hash = state_trie.hash(); + state_smt.insert(k.into(), account.into()).unwrap(); + let expected_state_trie_hash = state_smt.root; assert_eq!(hash, expected_state_trie_hash); Ok(()) @@ -111,9 +107,9 @@ fn test_extcodesize() -> Result<()> { let account = test_account(&code); let mut interpreter = Interpreter::new_with_kernel(0, vec![]); - let address: Address = thread_rng().gen(); + let address: Address = random(); // Prepare the interpreter by inserting the account in the state trie. - prepare_interpreter(&mut interpreter, address, &account)?; + prepare_interpreter(&mut interpreter, address, account)?; let extcodesize = KERNEL.global_labels["extcodesize"]; @@ -140,9 +136,13 @@ fn test_extcodecopy() -> Result<()> { let mut interpreter = Interpreter::new_with_kernel(0, vec![]); let address: Address = thread_rng().gen(); // Prepare the interpreter by inserting the account in the state trie. - prepare_interpreter(&mut interpreter, address, &account)?; + prepare_interpreter(&mut interpreter, address, account)?; - let extcodecopy = KERNEL.global_labels["extcodecopy"]; + interpreter.generation_state.memory.contexts[interpreter.context].segments + [Segment::ContextMetadata as usize] + .set(GasLimit as usize, U256::from(1000000000000u64) << 192); + + let extcodecopy = KERNEL.global_labels["sys_extcodecopy"]; // Put random data in main memory and the `KernelAccountCode` segment for realism. let mut rng = thread_rng(); @@ -164,11 +164,11 @@ fn test_extcodecopy() -> Result<()> { interpreter.generation_state.registers.program_counter = extcodecopy; interpreter.pop(); assert!(interpreter.stack().is_empty()); - interpreter.push(0xDEADBEEFu32.into()); interpreter.push(size.into()); interpreter.push(offset.into()); interpreter.push(dest_offset.into()); interpreter.push(U256::from_big_endian(address.as_bytes())); + interpreter.push(0xDEADBEEFu32.into()); // kexit_info interpreter.generation_state.inputs.contract_code = HashMap::from([(keccak(&code), code.clone())]); interpreter.run()?; diff --git a/evm/src/cpu/kernel/tests/balance.rs b/evm/src/cpu/kernel/tests/balance.rs index 40214405c7..f83a6eb1f3 100644 --- a/evm/src/cpu/kernel/tests/balance.rs +++ b/evm/src/cpu/kernel/tests/balance.rs @@ -1,22 +1,21 @@ use anyhow::{anyhow, Result}; -use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; use ethereum_types::{Address, BigEndianHash, H256, U256}; use keccak_hash::keccak; use rand::{thread_rng, Rng}; +use smt_utils::account::Account; +use smt_utils::smt::Smt; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::interpreter::Interpreter; -use crate::cpu::kernel::tests::mpt::nibbles_64; -use crate::generation::mpt::{all_mpt_prover_inputs_reversed, AccountRlp}; -use crate::Node; +use crate::generation::mpt::all_mpt_prover_inputs_reversed; // Test account with a given code hash. 
-fn test_account(balance: U256) -> AccountRlp { - AccountRlp { - nonce: U256::from(1111), +fn test_account(balance: U256) -> Account { + Account { + nonce: 1111, balance, - storage_root: HashedPartialTrie::from(Node::Empty).hash(), + storage_smt: Smt::empty(), code_hash: H256::from_uint(&U256::from(8888)), } } @@ -26,12 +25,12 @@ fn test_account(balance: U256) -> AccountRlp { fn prepare_interpreter( interpreter: &mut Interpreter, address: Address, - account: &AccountRlp, + account: Account, ) -> Result<()> { let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; - let mpt_insert_state_trie = KERNEL.global_labels["mpt_insert_state_trie"]; - let mpt_hash_state_trie = KERNEL.global_labels["mpt_hash_state_trie"]; - let mut state_trie: HashedPartialTrie = Default::default(); + let smt_insert_state = KERNEL.global_labels["smt_insert_state"]; + let smt_hash_state = KERNEL.global_labels["smt_hash_state"]; + let mut state_smt = Smt::empty(); let trie_inputs = Default::default(); interpreter.generation_state.registers.program_counter = load_all_mpts; @@ -43,19 +42,18 @@ fn prepare_interpreter( interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); - let k = nibbles_64(U256::from_big_endian( - keccak(address.to_fixed_bytes()).as_bytes(), - )); - // Next, execute mpt_insert_state_trie. - interpreter.generation_state.registers.program_counter = mpt_insert_state_trie; + let k = keccak(address.to_fixed_bytes()); + // Next, execute smt_insert_state. + interpreter.generation_state.registers.program_counter = smt_insert_state; let trie_data = interpreter.get_trie_data_mut(); if trie_data.is_empty() { - // In the assembly we skip over 0, knowing trie_data[0] = 0 by default. + // In the assembly we skip over 0, knowing trie_data[0:4] = 0 by default. // Since we don't explicitly set it to 0, we need to do so here. - trie_data.push(0.into()); + trie_data.extend((0..4).map(|_| U256::zero())); } let value_ptr = trie_data.len(); - trie_data.push(account.nonce); + trie_data.push(U256::zero()); // For key. + trie_data.push(account.nonce.into()); trie_data.push(account.balance); // In memory, storage_root gets interpreted as a pointer to a storage trie, // so we have to ensure the pointer is valid. It's easiest to set it to 0, @@ -66,7 +64,7 @@ fn prepare_interpreter( interpreter.set_global_metadata_field(GlobalMetadata::TrieDataSize, trie_data_len); interpreter.push(0xDEADBEEFu32.into()); interpreter.push(value_ptr.into()); // value_ptr - interpreter.push(k.try_into_u256().unwrap()); // key + interpreter.push(k.into_uint()); interpreter.run()?; assert_eq!( @@ -77,7 +75,7 @@ fn prepare_interpreter( ); // Now, execute mpt_hash_state_trie. - interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; + interpreter.generation_state.registers.program_counter = smt_hash_state; interpreter.push(0xDEADBEEFu32.into()); interpreter.run()?; @@ -89,8 +87,8 @@ fn prepare_interpreter( ); let hash = H256::from_uint(&interpreter.stack()[0]); - state_trie.insert(k, rlp::encode(account).to_vec()); - let expected_state_trie_hash = state_trie.hash(); + state_smt.insert(k.into(), account.into()).unwrap(); + let expected_state_trie_hash = state_smt.root; assert_eq!(hash, expected_state_trie_hash); Ok(()) @@ -105,7 +103,7 @@ fn test_balance() -> Result<()> { let mut interpreter = Interpreter::new_with_kernel(0, vec![]); let address: Address = rng.gen(); // Prepare the interpreter by inserting the account in the state trie. 
- prepare_interpreter(&mut interpreter, address, &account)?; + prepare_interpreter(&mut interpreter, address, account)?; // Test `balance` interpreter.generation_state.registers.program_counter = KERNEL.global_labels["balance"]; diff --git a/evm/src/cpu/kernel/tests/mod.rs b/evm/src/cpu/kernel/tests/mod.rs index b66c016266..d49a22c7d8 100644 --- a/evm/src/cpu/kernel/tests/mod.rs +++ b/evm/src/cpu/kernel/tests/mod.rs @@ -10,11 +10,11 @@ mod ecc; mod exp; mod hash; mod log; -mod mpt; mod packing; mod receipt; mod rlp; mod signed_syscalls; +mod smt; mod transaction_parsing; use std::str::FromStr; diff --git a/evm/src/cpu/kernel/tests/mpt/delete.rs b/evm/src/cpu/kernel/tests/mpt/delete.rs deleted file mode 100644 index 074eea26ef..0000000000 --- a/evm/src/cpu/kernel/tests/mpt/delete.rs +++ /dev/null @@ -1,130 +0,0 @@ -use anyhow::{anyhow, Result}; -use eth_trie_utils::nibbles::Nibbles; -use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; -use ethereum_types::{BigEndianHash, H256}; - -use crate::cpu::kernel::aggregator::KERNEL; -use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; -use crate::cpu::kernel::interpreter::Interpreter; -use crate::cpu::kernel::tests::mpt::{nibbles_64, test_account_1_rlp, test_account_2}; -use crate::generation::mpt::{all_mpt_prover_inputs_reversed, AccountRlp}; -use crate::generation::TrieInputs; -use crate::Node; - -#[test] -fn mpt_delete_empty() -> Result<()> { - test_state_trie(Default::default(), nibbles_64(0xABC), test_account_2()) -} - -#[test] -fn mpt_delete_leaf_nonoverlapping_keys() -> Result<()> { - let state_trie = Node::Leaf { - nibbles: nibbles_64(0xABC), - value: test_account_1_rlp(), - } - .into(); - test_state_trie(state_trie, nibbles_64(0x123), test_account_2()) -} - -#[test] -fn mpt_delete_leaf_overlapping_keys() -> Result<()> { - let state_trie = Node::Leaf { - nibbles: nibbles_64(0xABC), - value: test_account_1_rlp(), - } - .into(); - test_state_trie(state_trie, nibbles_64(0xADE), test_account_2()) -} - -#[test] -fn mpt_delete_branch_into_hash() -> Result<()> { - let hash = Node::Hash(H256::random()); - let state_trie = Node::Extension { - nibbles: nibbles_64(0xADF), - child: hash.into(), - } - .into(); - test_state_trie(state_trie, nibbles_64(0xADE), test_account_2()) -} - -/// Note: The account's storage_root is ignored, as we can't insert a new storage_root without the -/// accompanying trie data. An empty trie's storage_root is used instead. -fn test_state_trie( - state_trie: HashedPartialTrie, - k: Nibbles, - mut account: AccountRlp, -) -> Result<()> { - assert_eq!(k.count, 64); - - // Ignore any storage_root; see documentation note. 
- account.storage_root = HashedPartialTrie::from(Node::Empty).hash(); - - let trie_inputs = TrieInputs { - state_trie: state_trie.clone(), - transactions_trie: Default::default(), - receipts_trie: Default::default(), - storage_tries: vec![], - }; - let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; - let mpt_insert_state_trie = KERNEL.global_labels["mpt_insert_state_trie"]; - let mpt_delete = KERNEL.global_labels["mpt_delete"]; - let mpt_hash_state_trie = KERNEL.global_labels["mpt_hash_state_trie"]; - - let initial_stack = vec![0xDEADBEEFu32.into()]; - let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = - all_mpt_prover_inputs_reversed(&trie_inputs).map_err(|_| anyhow!("Invalid MPT data"))?; - interpreter.run()?; - assert_eq!(interpreter.stack(), vec![]); - - // Next, execute mpt_insert_state_trie. - interpreter.generation_state.registers.program_counter = mpt_insert_state_trie; - let trie_data = interpreter.get_trie_data_mut(); - if trie_data.is_empty() { - // In the assembly we skip over 0, knowing trie_data[0] = 0 by default. - // Since we don't explicitly set it to 0, we need to do so here. - trie_data.push(0.into()); - } - let value_ptr = trie_data.len(); - trie_data.push(account.nonce); - trie_data.push(account.balance); - // In memory, storage_root gets interpreted as a pointer to a storage trie, - // so we have to ensure the pointer is valid. It's easiest to set it to 0, - // which works as an empty node, since trie_data[0] = 0 = MPT_TYPE_EMPTY. - trie_data.push(H256::zero().into_uint()); - trie_data.push(account.code_hash.into_uint()); - let trie_data_len = trie_data.len().into(); - interpreter.set_global_metadata_field(GlobalMetadata::TrieDataSize, trie_data_len); - interpreter.push(0xDEADBEEFu32.into()); - interpreter.push(value_ptr.into()); // value_ptr - interpreter.push(k.try_into_u256().unwrap()); // key - interpreter.run()?; - assert_eq!( - interpreter.stack().len(), - 0, - "Expected empty stack after insert, found {:?}", - interpreter.stack() - ); - - // Next, execute mpt_delete, deleting the account we just inserted. - let state_trie_ptr = interpreter.get_global_metadata_field(GlobalMetadata::StateTrieRoot); - interpreter.generation_state.registers.program_counter = mpt_delete; - interpreter.push(0xDEADBEEFu32.into()); - interpreter.push(k.try_into_u256().unwrap()); - interpreter.push(64.into()); - interpreter.push(state_trie_ptr); - interpreter.run()?; - let state_trie_ptr = interpreter.pop(); - interpreter.set_global_metadata_field(GlobalMetadata::StateTrieRoot, state_trie_ptr); - - // Now, execute mpt_hash_state_trie. 
- interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; - interpreter.push(0xDEADBEEFu32.into()); - interpreter.run()?; - - let state_trie_hash = H256::from_uint(&interpreter.pop()); - let expected_state_trie_hash = state_trie.hash(); - assert_eq!(state_trie_hash, expected_state_trie_hash); - - Ok(()) -} diff --git a/evm/src/cpu/kernel/tests/mpt/hash.rs b/evm/src/cpu/kernel/tests/mpt/hash.rs deleted file mode 100644 index 05077a94da..0000000000 --- a/evm/src/cpu/kernel/tests/mpt/hash.rs +++ /dev/null @@ -1,137 +0,0 @@ -use anyhow::{anyhow, Result}; -use eth_trie_utils::partial_trie::PartialTrie; -use ethereum_types::{BigEndianHash, H256}; - -use crate::cpu::kernel::aggregator::KERNEL; -use crate::cpu::kernel::interpreter::Interpreter; -use crate::cpu::kernel::tests::mpt::{extension_to_leaf, test_account_1_rlp, test_account_2_rlp}; -use crate::generation::mpt::all_mpt_prover_inputs_reversed; -use crate::generation::TrieInputs; -use crate::Node; - -// TODO: Test with short leaf. Might need to be a storage trie. - -#[test] -fn mpt_hash_empty() -> Result<()> { - let trie_inputs = TrieInputs { - state_trie: Default::default(), - transactions_trie: Default::default(), - receipts_trie: Default::default(), - storage_tries: vec![], - }; - - test_state_trie(trie_inputs) -} - -#[test] -fn mpt_hash_empty_branch() -> Result<()> { - let children = core::array::from_fn(|_| Node::Empty.into()); - let state_trie = Node::Branch { - children, - value: vec![], - } - .into(); - let trie_inputs = TrieInputs { - state_trie, - transactions_trie: Default::default(), - receipts_trie: Default::default(), - storage_tries: vec![], - }; - test_state_trie(trie_inputs) -} - -#[test] -fn mpt_hash_hash() -> Result<()> { - let hash = H256::random(); - let trie_inputs = TrieInputs { - state_trie: Node::Hash(hash).into(), - transactions_trie: Default::default(), - receipts_trie: Default::default(), - storage_tries: vec![], - }; - - test_state_trie(trie_inputs) -} - -#[test] -fn mpt_hash_leaf() -> Result<()> { - let state_trie = Node::Leaf { - nibbles: 0xABC_u64.into(), - value: test_account_1_rlp(), - } - .into(); - let trie_inputs = TrieInputs { - state_trie, - transactions_trie: Default::default(), - receipts_trie: Default::default(), - storage_tries: vec![], - }; - test_state_trie(trie_inputs) -} - -#[test] -fn mpt_hash_extension_to_leaf() -> Result<()> { - let state_trie = extension_to_leaf(test_account_1_rlp()); - let trie_inputs = TrieInputs { - state_trie, - transactions_trie: Default::default(), - receipts_trie: Default::default(), - storage_tries: vec![], - }; - test_state_trie(trie_inputs) -} - -#[test] -fn mpt_hash_branch_to_leaf() -> Result<()> { - let leaf = Node::Leaf { - nibbles: 0xABC_u64.into(), - value: test_account_2_rlp(), - } - .into(); - - let mut children = core::array::from_fn(|_| Node::Empty.into()); - children[3] = leaf; - let state_trie = Node::Branch { - children, - value: vec![], - } - .into(); - - let trie_inputs = TrieInputs { - state_trie, - transactions_trie: Default::default(), - receipts_trie: Default::default(), - storage_tries: vec![], - }; - - test_state_trie(trie_inputs) -} - -fn test_state_trie(trie_inputs: TrieInputs) -> Result<()> { - let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; - let mpt_hash_state_trie = KERNEL.global_labels["mpt_hash_state_trie"]; - - let initial_stack = vec![0xDEADBEEFu32.into()]; - let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = - 
all_mpt_prover_inputs_reversed(&trie_inputs).map_err(|_| anyhow!("Invalid MPT data"))?; - interpreter.run()?; - assert_eq!(interpreter.stack(), vec![]); - - // Now, execute mpt_hash_state_trie. - interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; - interpreter.push(0xDEADBEEFu32.into()); - interpreter.run()?; - - assert_eq!( - interpreter.stack().len(), - 1, - "Expected 1 item on stack, found {:?}", - interpreter.stack() - ); - let hash = H256::from_uint(&interpreter.stack()[0]); - let expected_state_trie_hash = trie_inputs.state_trie.hash(); - assert_eq!(hash, expected_state_trie_hash); - - Ok(()) -} diff --git a/evm/src/cpu/kernel/tests/mpt/hex_prefix.rs b/evm/src/cpu/kernel/tests/mpt/hex_prefix.rs deleted file mode 100644 index c13b812220..0000000000 --- a/evm/src/cpu/kernel/tests/mpt/hex_prefix.rs +++ /dev/null @@ -1,87 +0,0 @@ -use anyhow::Result; - -use crate::cpu::kernel::aggregator::KERNEL; -use crate::cpu::kernel::interpreter::Interpreter; - -#[test] -fn hex_prefix_even_nonterminated() -> Result<()> { - let hex_prefix = KERNEL.global_labels["hex_prefix_rlp"]; - - let retdest = 0xDEADBEEFu32.into(); - let terminated = 0.into(); - let packed_nibbles = 0xABCDEF.into(); - let num_nibbles = 6.into(); - let rlp_pos = 0.into(); - let initial_stack = vec![retdest, terminated, packed_nibbles, num_nibbles, rlp_pos]; - let mut interpreter = Interpreter::new_with_kernel(hex_prefix, initial_stack); - interpreter.run()?; - assert_eq!(interpreter.stack(), vec![5.into()]); - - assert_eq!( - interpreter.get_rlp_memory(), - vec![ - 0x80 + 4, // prefix - 0, // neither flag is set - 0xAB, - 0xCD, - 0xEF - ] - ); - - Ok(()) -} - -#[test] -fn hex_prefix_odd_terminated() -> Result<()> { - let hex_prefix = KERNEL.global_labels["hex_prefix_rlp"]; - - let retdest = 0xDEADBEEFu32.into(); - let terminated = 1.into(); - let packed_nibbles = 0xABCDE.into(); - let num_nibbles = 5.into(); - let rlp_pos = 0.into(); - let initial_stack = vec![retdest, terminated, packed_nibbles, num_nibbles, rlp_pos]; - let mut interpreter = Interpreter::new_with_kernel(hex_prefix, initial_stack); - interpreter.run()?; - assert_eq!(interpreter.stack(), vec![4.into()]); - - assert_eq!( - interpreter.get_rlp_memory(), - vec![ - 0x80 + 3, // prefix - (2 + 1) * 16 + 0xA, - 0xBC, - 0xDE, - ] - ); - - Ok(()) -} - -#[test] -fn hex_prefix_odd_terminated_tiny() -> Result<()> { - let hex_prefix = KERNEL.global_labels["hex_prefix_rlp"]; - - let retdest = 0xDEADBEEFu32.into(); - let terminated = 1.into(); - let packed_nibbles = 0xA.into(); - let num_nibbles = 1.into(); - let rlp_pos = 2.into(); - let initial_stack = vec![retdest, terminated, packed_nibbles, num_nibbles, rlp_pos]; - let mut interpreter = Interpreter::new_with_kernel(hex_prefix, initial_stack); - interpreter.run()?; - assert_eq!(interpreter.stack(), vec![3.into()]); - - assert_eq!( - interpreter.get_rlp_memory(), - vec![ - // Since rlp_pos = 2, we skipped over the first two bytes. - 0, - 0, - // No length prefix; this tiny string is its own RLP encoding. 
- (2 + 1) * 16 + 0xA, - ] - ); - - Ok(()) -} diff --git a/evm/src/cpu/kernel/tests/mpt/insert.rs b/evm/src/cpu/kernel/tests/mpt/insert.rs deleted file mode 100644 index 6fd95a30b9..0000000000 --- a/evm/src/cpu/kernel/tests/mpt/insert.rs +++ /dev/null @@ -1,230 +0,0 @@ -use anyhow::{anyhow, Result}; -use eth_trie_utils::nibbles::Nibbles; -use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; -use ethereum_types::{BigEndianHash, H256}; - -use crate::cpu::kernel::aggregator::KERNEL; -use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; -use crate::cpu::kernel::interpreter::Interpreter; -use crate::cpu::kernel::tests::mpt::{ - nibbles_64, nibbles_count, test_account_1_rlp, test_account_2, -}; -use crate::generation::mpt::{all_mpt_prover_inputs_reversed, AccountRlp}; -use crate::generation::TrieInputs; -use crate::Node; - -#[test] -fn mpt_insert_empty() -> Result<()> { - test_state_trie(Default::default(), nibbles_64(0xABC), test_account_2()) -} - -#[test] -fn mpt_insert_leaf_identical_keys() -> Result<()> { - let key = nibbles_64(0xABC); - let state_trie = Node::Leaf { - nibbles: key, - value: test_account_1_rlp(), - } - .into(); - test_state_trie(state_trie, key, test_account_2()) -} - -#[test] -fn mpt_insert_leaf_nonoverlapping_keys() -> Result<()> { - let state_trie = Node::Leaf { - nibbles: nibbles_64(0xABC), - value: test_account_1_rlp(), - } - .into(); - test_state_trie(state_trie, nibbles_64(0x123), test_account_2()) -} - -#[test] -fn mpt_insert_leaf_overlapping_keys() -> Result<()> { - let state_trie = Node::Leaf { - nibbles: nibbles_64(0xABC), - value: test_account_1_rlp(), - } - .into(); - test_state_trie(state_trie, nibbles_64(0xADE), test_account_2()) -} - -#[test] -#[ignore] // TODO: Not valid for state trie, all keys have same len. -fn mpt_insert_leaf_insert_key_extends_leaf_key() -> Result<()> { - let state_trie = Node::Leaf { - nibbles: 0xABC_u64.into(), - value: test_account_1_rlp(), - } - .into(); - test_state_trie(state_trie, nibbles_64(0xABCDE), test_account_2()) -} - -#[test] -#[ignore] // TODO: Not valid for state trie, all keys have same len. -fn mpt_insert_leaf_leaf_key_extends_insert_key() -> Result<()> { - let state_trie = Node::Leaf { - nibbles: 0xABCDE_u64.into(), - value: test_account_1_rlp(), - } - .into(); - test_state_trie(state_trie, nibbles_64(0xABC), test_account_2()) -} - -#[test] -fn mpt_insert_branch_replacing_empty_child() -> Result<()> { - let children = core::array::from_fn(|_| Node::Empty.into()); - let state_trie = Node::Branch { - children, - value: vec![], - } - .into(); - - test_state_trie(state_trie, nibbles_64(0xABC), test_account_2()) -} - -#[test] -// TODO: Not a valid test because branches state trie cannot have branch values. -// We should change it to use a different trie. -#[ignore] -fn mpt_insert_extension_nonoverlapping_keys() -> Result<()> { - // Existing keys are 0xABC, 0xABCDEF; inserted key is 0x12345. - let mut children = core::array::from_fn(|_| Node::Empty.into()); - children[0xD] = Node::Leaf { - nibbles: 0xEF_u64.into(), - value: test_account_1_rlp(), - } - .into(); - let state_trie = Node::Extension { - nibbles: 0xABC_u64.into(), - child: Node::Branch { - children, - value: test_account_1_rlp(), - } - .into(), - } - .into(); - test_state_trie(state_trie, nibbles_64(0x12345), test_account_2()) -} - -#[test] -// TODO: Not a valid test because branches state trie cannot have branch values. -// We should change it to use a different trie. 
-#[ignore] -fn mpt_insert_extension_insert_key_extends_node_key() -> Result<()> { - // Existing keys are 0xA, 0xABCD; inserted key is 0xABCDEF. - let mut children = core::array::from_fn(|_| Node::Empty.into()); - children[0xB] = Node::Leaf { - nibbles: 0xCD_u64.into(), - value: test_account_1_rlp(), - } - .into(); - let state_trie = Node::Extension { - nibbles: 0xA_u64.into(), - child: Node::Branch { - children, - value: test_account_1_rlp(), - } - .into(), - } - .into(); - test_state_trie(state_trie, nibbles_64(0xABCDEF), test_account_2()) -} - -#[test] -fn mpt_insert_branch_to_leaf_same_key() -> Result<()> { - let leaf = Node::Leaf { - nibbles: nibbles_count(0xBCD, 63), - value: test_account_1_rlp(), - } - .into(); - - let mut children = core::array::from_fn(|_| Node::Empty.into()); - children[0] = leaf; - let state_trie = Node::Branch { - children, - value: vec![], - } - .into(); - - test_state_trie(state_trie, nibbles_64(0xABCD), test_account_2()) -} - -/// Note: The account's storage_root is ignored, as we can't insert a new storage_root without the -/// accompanying trie data. An empty trie's storage_root is used instead. -fn test_state_trie( - mut state_trie: HashedPartialTrie, - k: Nibbles, - mut account: AccountRlp, -) -> Result<()> { - assert_eq!(k.count, 64); - - // Ignore any storage_root; see documentation note. - account.storage_root = HashedPartialTrie::from(Node::Empty).hash(); - - let trie_inputs = TrieInputs { - state_trie: state_trie.clone(), - transactions_trie: Default::default(), - receipts_trie: Default::default(), - storage_tries: vec![], - }; - let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; - let mpt_insert_state_trie = KERNEL.global_labels["mpt_insert_state_trie"]; - let mpt_hash_state_trie = KERNEL.global_labels["mpt_hash_state_trie"]; - - let initial_stack = vec![0xDEADBEEFu32.into()]; - let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = - all_mpt_prover_inputs_reversed(&trie_inputs).map_err(|_| anyhow!("Invalid MPT data"))?; - interpreter.run()?; - assert_eq!(interpreter.stack(), vec![]); - - // Next, execute mpt_insert_state_trie. - interpreter.generation_state.registers.program_counter = mpt_insert_state_trie; - let trie_data = interpreter.get_trie_data_mut(); - if trie_data.is_empty() { - // In the assembly we skip over 0, knowing trie_data[0] = 0 by default. - // Since we don't explicitly set it to 0, we need to do so here. - trie_data.push(0.into()); - } - let value_ptr = trie_data.len(); - trie_data.push(account.nonce); - trie_data.push(account.balance); - // In memory, storage_root gets interpreted as a pointer to a storage trie, - // so we have to ensure the pointer is valid. It's easiest to set it to 0, - // which works as an empty node, since trie_data[0] = 0 = MPT_TYPE_EMPTY. - trie_data.push(H256::zero().into_uint()); - trie_data.push(account.code_hash.into_uint()); - let trie_data_len = trie_data.len().into(); - interpreter.set_global_metadata_field(GlobalMetadata::TrieDataSize, trie_data_len); - interpreter.push(0xDEADBEEFu32.into()); - interpreter.push(value_ptr.into()); // value_ptr - interpreter.push(k.try_into_u256().unwrap()); // key - - interpreter.run()?; - assert_eq!( - interpreter.stack().len(), - 0, - "Expected empty stack after insert, found {:?}", - interpreter.stack() - ); - - // Now, execute mpt_hash_state_trie. 
- interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; - interpreter.push(0xDEADBEEFu32.into()); - interpreter.run()?; - - assert_eq!( - interpreter.stack().len(), - 1, - "Expected 1 item on stack after hashing, found {:?}", - interpreter.stack() - ); - let hash = H256::from_uint(&interpreter.stack()[0]); - - state_trie.insert(k, rlp::encode(&account).to_vec()); - let expected_state_trie_hash = state_trie.hash(); - assert_eq!(hash, expected_state_trie_hash); - - Ok(()) -} diff --git a/evm/src/cpu/kernel/tests/mpt/load.rs b/evm/src/cpu/kernel/tests/mpt/load.rs deleted file mode 100644 index ae0bfa3bc8..0000000000 --- a/evm/src/cpu/kernel/tests/mpt/load.rs +++ /dev/null @@ -1,294 +0,0 @@ -use std::str::FromStr; - -use anyhow::{anyhow, Result}; -use eth_trie_utils::nibbles::Nibbles; -use eth_trie_utils::partial_trie::HashedPartialTrie; -use ethereum_types::{BigEndianHash, H256, U256}; -use hex_literal::hex; - -use crate::cpu::kernel::aggregator::KERNEL; -use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; -use crate::cpu::kernel::constants::trie_type::PartialTrieType; -use crate::cpu::kernel::interpreter::Interpreter; -use crate::cpu::kernel::tests::mpt::{extension_to_leaf, test_account_1, test_account_1_rlp}; -use crate::generation::mpt::all_mpt_prover_inputs_reversed; -use crate::generation::TrieInputs; -use crate::Node; - -#[test] -fn load_all_mpts_empty() -> Result<()> { - let trie_inputs = TrieInputs { - state_trie: Default::default(), - transactions_trie: Default::default(), - receipts_trie: Default::default(), - storage_tries: vec![], - }; - - let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; - - let initial_stack = vec![0xDEADBEEFu32.into()]; - let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = - all_mpt_prover_inputs_reversed(&trie_inputs) - .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; - interpreter.run()?; - assert_eq!(interpreter.stack(), vec![]); - - assert_eq!(interpreter.get_trie_data(), vec![]); - - assert_eq!( - interpreter.get_global_metadata_field(GlobalMetadata::StateTrieRoot), - 0.into() - ); - assert_eq!( - interpreter.get_global_metadata_field(GlobalMetadata::TransactionTrieRoot), - 0.into() - ); - assert_eq!( - interpreter.get_global_metadata_field(GlobalMetadata::ReceiptTrieRoot), - 0.into() - ); - - Ok(()) -} - -#[test] -fn load_all_mpts_leaf() -> Result<()> { - let trie_inputs = TrieInputs { - state_trie: Node::Leaf { - nibbles: 0xABC_u64.into(), - value: test_account_1_rlp(), - } - .into(), - transactions_trie: Default::default(), - receipts_trie: Default::default(), - storage_tries: vec![], - }; - - let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; - - let initial_stack = vec![0xDEADBEEFu32.into()]; - let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = - all_mpt_prover_inputs_reversed(&trie_inputs) - .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; - interpreter.run()?; - assert_eq!(interpreter.stack(), vec![]); - - let type_leaf = U256::from(PartialTrieType::Leaf as u32); - assert_eq!( - interpreter.get_trie_data(), - vec![ - 0.into(), - type_leaf, - 3.into(), - 0xABC.into(), - 5.into(), // value ptr - test_account_1().nonce, - test_account_1().balance, - 9.into(), // pointer to storage trie root - test_account_1().code_hash.into_uint(), - // These last two elements encode the storage trie, which is a hash node. 
- (PartialTrieType::Hash as u32).into(), - test_account_1().storage_root.into_uint(), - ] - ); - - assert_eq!( - interpreter.get_global_metadata_field(GlobalMetadata::TransactionTrieRoot), - 0.into() - ); - assert_eq!( - interpreter.get_global_metadata_field(GlobalMetadata::ReceiptTrieRoot), - 0.into() - ); - - Ok(()) -} - -#[test] -fn load_all_mpts_hash() -> Result<()> { - let hash = H256::random(); - let trie_inputs = TrieInputs { - state_trie: Node::Hash(hash).into(), - transactions_trie: Default::default(), - receipts_trie: Default::default(), - storage_tries: vec![], - }; - - let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; - - let initial_stack = vec![0xDEADBEEFu32.into()]; - let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = - all_mpt_prover_inputs_reversed(&trie_inputs) - .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; - interpreter.run()?; - assert_eq!(interpreter.stack(), vec![]); - - let type_hash = U256::from(PartialTrieType::Hash as u32); - assert_eq!( - interpreter.get_trie_data(), - vec![0.into(), type_hash, hash.into_uint(),] - ); - - assert_eq!( - interpreter.get_global_metadata_field(GlobalMetadata::TransactionTrieRoot), - 0.into() - ); - assert_eq!( - interpreter.get_global_metadata_field(GlobalMetadata::ReceiptTrieRoot), - 0.into() - ); - - Ok(()) -} - -#[test] -fn load_all_mpts_empty_branch() -> Result<()> { - let children = core::array::from_fn(|_| Node::Empty.into()); - let state_trie = Node::Branch { - children, - value: vec![], - } - .into(); - let trie_inputs = TrieInputs { - state_trie, - transactions_trie: Default::default(), - receipts_trie: Default::default(), - storage_tries: vec![], - }; - - let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; - - let initial_stack = vec![0xDEADBEEFu32.into()]; - let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = - all_mpt_prover_inputs_reversed(&trie_inputs) - .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; - interpreter.run()?; - assert_eq!(interpreter.stack(), vec![]); - - let type_branch = U256::from(PartialTrieType::Branch as u32); - assert_eq!( - interpreter.get_trie_data(), - vec![ - 0.into(), // First address is unused, so that 0 can be treated as a null pointer. - type_branch, - 0.into(), // child 0 - 0.into(), // ... 
- 0.into(), - 0.into(), - 0.into(), - 0.into(), - 0.into(), - 0.into(), - 0.into(), - 0.into(), - 0.into(), - 0.into(), - 0.into(), - 0.into(), - 0.into(), - 0.into(), // child 16 - 0.into(), // value_ptr - ] - ); - - assert_eq!( - interpreter.get_global_metadata_field(GlobalMetadata::TransactionTrieRoot), - 0.into() - ); - assert_eq!( - interpreter.get_global_metadata_field(GlobalMetadata::ReceiptTrieRoot), - 0.into() - ); - - Ok(()) -} - -#[test] -fn load_all_mpts_ext_to_leaf() -> Result<()> { - let trie_inputs = TrieInputs { - state_trie: extension_to_leaf(test_account_1_rlp()), - transactions_trie: Default::default(), - receipts_trie: Default::default(), - storage_tries: vec![], - }; - - let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; - - let initial_stack = vec![0xDEADBEEFu32.into()]; - let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = - all_mpt_prover_inputs_reversed(&trie_inputs) - .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; - interpreter.run()?; - assert_eq!(interpreter.stack(), vec![]); - - let type_extension = U256::from(PartialTrieType::Extension as u32); - let type_leaf = U256::from(PartialTrieType::Leaf as u32); - assert_eq!( - interpreter.get_trie_data(), - vec![ - 0.into(), // First address is unused, so that 0 can be treated as a null pointer. - type_extension, - 3.into(), // 3 nibbles - 0xABC.into(), // key part - 5.into(), // Pointer to the leaf node immediately below. - type_leaf, - 3.into(), // 3 nibbles - 0xDEF.into(), // key part - 9.into(), // value pointer - test_account_1().nonce, - test_account_1().balance, - 13.into(), // pointer to storage trie root - test_account_1().code_hash.into_uint(), - // These last two elements encode the storage trie, which is a hash node. 
- (PartialTrieType::Hash as u32).into(), - test_account_1().storage_root.into_uint(), - ] - ); - - Ok(()) -} - -#[test] -fn load_mpt_txn_trie() -> Result<()> { - let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; - - let txn = hex!("f860010a830186a094095e7baea6a6c7c4c2dfeb977efac326af552e89808025a04a223955b0bd3827e3740a9a427d0ea43beb5bafa44a0204bf0a3306c8219f7ba0502c32d78f233e9e7ce9f5df3b576556d5d49731e0678fd5a068cdf359557b5b").to_vec(); - - let trie_inputs = TrieInputs { - state_trie: Default::default(), - transactions_trie: HashedPartialTrie::from(Node::Leaf { - nibbles: Nibbles::from_str("0x80").unwrap(), - value: txn.clone(), - }), - receipts_trie: Default::default(), - storage_tries: vec![], - }; - - let initial_stack = vec![0xDEADBEEFu32.into()]; - let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = - all_mpt_prover_inputs_reversed(&trie_inputs) - .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; - interpreter.run()?; - assert_eq!(interpreter.stack(), vec![]); - - let mut expected_trie_data = vec![ - 0.into(), - U256::from(PartialTrieType::Leaf as u32), - 2.into(), - 128.into(), // Nibble - 5.into(), // value_ptr - txn.len().into(), - ]; - expected_trie_data.extend(txn.into_iter().map(U256::from)); - let trie_data = interpreter.get_trie_data(); - - assert_eq!(trie_data, expected_trie_data); - - Ok(()) -} diff --git a/evm/src/cpu/kernel/tests/mpt/mod.rs b/evm/src/cpu/kernel/tests/mpt/mod.rs deleted file mode 100644 index 292d064af1..0000000000 --- a/evm/src/cpu/kernel/tests/mpt/mod.rs +++ /dev/null @@ -1,71 +0,0 @@ -use eth_trie_utils::nibbles::Nibbles; -use eth_trie_utils::partial_trie::HashedPartialTrie; -use ethereum_types::{BigEndianHash, H256, U256}; - -use crate::generation::mpt::AccountRlp; -use crate::Node; - -mod delete; -mod hash; -mod hex_prefix; -mod insert; -mod load; -mod read; - -pub(crate) fn nibbles_64>(v: T) -> Nibbles { - let packed: U256 = v.into(); - Nibbles { - count: 64, - packed: packed.into(), - } -} - -pub(crate) fn nibbles_count>(v: T, count: usize) -> Nibbles { - let packed: U256 = v.into(); - Nibbles { - count, - packed: packed.into(), - } -} - -pub(crate) fn test_account_1() -> AccountRlp { - AccountRlp { - nonce: U256::from(1111), - balance: U256::from(2222), - storage_root: H256::from_uint(&U256::from(3333)), - code_hash: H256::from_uint(&U256::from(4444)), - } -} - -pub(crate) fn test_account_1_rlp() -> Vec { - rlp::encode(&test_account_1()).to_vec() -} - -pub(crate) fn test_account_2() -> AccountRlp { - AccountRlp { - nonce: U256::from(5555), - balance: U256::from(6666), - storage_root: H256::from_uint(&U256::from(7777)), - code_hash: H256::from_uint(&U256::from(8888)), - } -} - -pub(crate) fn test_account_2_rlp() -> Vec { - rlp::encode(&test_account_2()).to_vec() -} - -/// A `PartialTrie` where an extension node leads to a leaf node containing an account. 
-pub(crate) fn extension_to_leaf(value: Vec) -> HashedPartialTrie { - Node::Extension { - nibbles: 0xABC_u64.into(), - child: Node::Leaf { - nibbles: Nibbles { - count: 3, - packed: 0xDEF.into(), - }, - value, - } - .into(), - } - .into() -} diff --git a/evm/src/cpu/kernel/tests/mpt/read.rs b/evm/src/cpu/kernel/tests/mpt/read.rs deleted file mode 100644 index f9ae94f03b..0000000000 --- a/evm/src/cpu/kernel/tests/mpt/read.rs +++ /dev/null @@ -1,49 +0,0 @@ -use anyhow::{anyhow, Result}; -use ethereum_types::BigEndianHash; - -use crate::cpu::kernel::aggregator::KERNEL; -use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; -use crate::cpu::kernel::interpreter::Interpreter; -use crate::cpu::kernel::tests::mpt::{extension_to_leaf, test_account_1, test_account_1_rlp}; -use crate::generation::mpt::all_mpt_prover_inputs_reversed; -use crate::generation::TrieInputs; - -#[test] -fn mpt_read() -> Result<()> { - let trie_inputs = TrieInputs { - state_trie: extension_to_leaf(test_account_1_rlp()), - transactions_trie: Default::default(), - receipts_trie: Default::default(), - storage_tries: vec![], - }; - - let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; - let mpt_read = KERNEL.global_labels["mpt_read"]; - - let initial_stack = vec![0xdeadbeefu32.into()]; - let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = - all_mpt_prover_inputs_reversed(&trie_inputs) - .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; - interpreter.run()?; - assert_eq!(interpreter.stack(), vec![]); - - // Now, execute mpt_read on the state trie. - interpreter.generation_state.registers.program_counter = mpt_read; - interpreter.push(0xdeadbeefu32.into()); - interpreter.push(0xABCDEFu64.into()); - interpreter.push(6.into()); - interpreter.push(interpreter.get_global_metadata_field(GlobalMetadata::StateTrieRoot)); - interpreter.run()?; - - assert_eq!(interpreter.stack().len(), 1); - let result_ptr = interpreter.stack()[0].as_usize(); - let result = &interpreter.get_trie_data()[result_ptr..][..4]; - assert_eq!(result[0], test_account_1().nonce); - assert_eq!(result[1], test_account_1().balance); - // result[2] is the storage root pointer. We won't check that it matches a - // particular address, since that seems like over-specifying. - assert_eq!(result[3], test_account_1().code_hash.into_uint()); - - Ok(()) -} diff --git a/evm/src/cpu/kernel/tests/smt/hash.rs b/evm/src/cpu/kernel/tests/smt/hash.rs new file mode 100644 index 0000000000..b38caecc59 --- /dev/null +++ b/evm/src/cpu/kernel/tests/smt/hash.rs @@ -0,0 +1,65 @@ +use anyhow::{anyhow, Result}; +use ethereum_types::{BigEndianHash, H256, U256}; +use rand::{thread_rng, Rng}; +use smt_utils::account::Account; +use smt_utils::smt::Smt; + +use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::interpreter::Interpreter; +use crate::generation::mpt::{all_mpt_prover_inputs_reversed, state_smt_prover_inputs_reversed}; +use crate::generation::TrieInputs; + +// TODO: Test with short leaf. Might need to be a storage trie. 
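Reviewer note: the tests below load a state SMT into the kernel, run `smt_hash_state`, and compare the resulting hash against the root computed natively by `smt_utils`. As a quick orientation, here is a minimal host-side sketch of the `smt_utils` calls these tests rely on; the API (`Smt::empty`, `Smt::new`, `insert`, `serialize`, the `root` field) is assumed solely from its usage in this diff, and `smt_usage_sketch` is a hypothetical name, not part of the change:

    use ethereum_types::U256;
    use rand::{thread_rng, Rng};
    use smt_utils::account::Account;
    use smt_utils::smt::{AccountOrValue, Smt, ValOrHash};

    fn smt_usage_sketch() {
        // An empty SMT still has a well-defined root.
        let empty_root = Smt::empty().root;

        // Build an SMT from random (key, account) pairs, as `smt_hash` does below.
        let mut rng = thread_rng();
        let mut smt =
            Smt::new((0..3).map(|_| (U256(rng.gen()).into(), Account::rand(10).into()))).unwrap();

        // Inserting a fresh account changes the root; `serialize` produces the flat
        // representation that `load_all_mpts` copies into the kernel's trie data.
        smt.insert(
            U256(rng.gen()).into(),
            ValOrHash::Val(AccountOrValue::Account(Account::rand(0))),
        )
        .unwrap();
        assert_ne!(smt.root, empty_root);
        let _trie_data = smt.serialize();
    }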
+ +#[test] +fn smt_hash_empty() -> Result<()> { + let smt = Smt::empty(); + test_state_smt(smt) +} + +#[test] +fn smt_hash() -> Result<()> { + let n = 100; + let mut rng = thread_rng(); + let rand_node = |_| (U256(rng.gen()).into(), Account::rand(10).into()); + let smt = Smt::new((0..n).map(rand_node)).unwrap(); + + test_state_smt(smt) +} + +fn test_state_smt(state_smt: Smt) -> Result<()> { + let trie_inputs = TrieInputs { + state_smt: state_smt.serialize(), + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; + let smt_hash_state = KERNEL.global_labels["smt_hash_state"]; + + let initial_stack = vec![0xDEADBEEFu32.into()]; + let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs).map_err(|_| anyhow!("Invalid MPT data"))?; + interpreter.generation_state.state_smt_prover_inputs = + state_smt_prover_inputs_reversed(&trie_inputs); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![]); + + // Now, execute smt_hash_state. + interpreter.generation_state.registers.program_counter = smt_hash_state; + interpreter.push(0xDEADBEEFu32.into()); + interpreter.run()?; + + assert_eq!( + interpreter.stack().len(), + 1, + "Expected 1 item on stack, found {:?}", + interpreter.stack() + ); + let hash = H256::from_uint(&interpreter.stack()[0]); + let expected_state_trie_hash = state_smt.root; + assert_eq!(hash, expected_state_trie_hash); + + Ok(()) +} diff --git a/evm/src/cpu/kernel/tests/smt/insert.rs b/evm/src/cpu/kernel/tests/smt/insert.rs new file mode 100644 index 0000000000..a785e2c08c --- /dev/null +++ b/evm/src/cpu/kernel/tests/smt/insert.rs @@ -0,0 +1,181 @@ +use anyhow::{anyhow, Result}; +use ethereum_types::{BigEndianHash, U256}; +use rand::{thread_rng, Rng}; +use smt_utils::account::Account; +use smt_utils::smt::{AccountOrValue, Smt, ValOrHash}; + +use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::constants::evm_constants; +use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; +use crate::cpu::kernel::interpreter::Interpreter; +use crate::generation::mpt::{all_mpt_prover_inputs_reversed, state_smt_prover_inputs_reversed}; +use crate::generation::TrieInputs; +use crate::memory::segments::Segment; + +#[test] +fn smt_insert_state() -> Result<()> { + let n = 100; + let mut rng = thread_rng(); + let rand_node = |_| { + ( + U256(rng.gen()).into(), + ValOrHash::Val(AccountOrValue::Account(Account::rand(10))), + ) + }; + let smt = Smt::new((0..n).map(rand_node)).unwrap(); + let new_key = U256(rng.gen()); + let new_account = Account::rand(0); + + test_state_smt(smt, new_key, new_account) +} + +fn test_state_smt(mut state_smt: Smt, new_key: U256, new_account: Account) -> Result<()> { + let trie_inputs = TrieInputs { + state_smt: state_smt.serialize(), + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; + + let initial_stack = vec![0xDEADBEEFu32.into()]; + let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs).map_err(|_| anyhow!("Invalid MPT data"))?; + interpreter.generation_state.state_smt_prover_inputs = + state_smt_prover_inputs_reversed(&trie_inputs); + interpreter.run()?; + 
assert_eq!(interpreter.stack(), vec![]); + + let state_root = interpreter.get_global_metadata_field(GlobalMetadata::StateTrieRoot); + let trie_data_size = interpreter.get_global_metadata_field(GlobalMetadata::TrieDataSize); + let trie_data = interpreter.get_trie_data_mut(); + trie_data.push(U256::zero()); // For key + let mut packed_account = new_account.pack_u256(); + packed_account[2] = U256::zero(); // No storage SMT. + for (i, x) in (trie_data_size.as_usize() + 1..).zip(packed_account) { + if i < trie_data.len() { + trie_data[i] = x; + } else { + trie_data.push(x); + } + } + let len = trie_data.len(); + interpreter.set_global_metadata_field(GlobalMetadata::TrieDataSize, len.into()); + + let smt_insert = KERNEL.global_labels["smt_insert"]; + // Now, execute smt_insert. + interpreter.generation_state.registers.program_counter = smt_insert; + interpreter.push(0xDEADBEEFu32.into()); + interpreter.push(trie_data_size); + interpreter.push(new_key); + interpreter.push(state_root); + interpreter.run()?; + + assert_eq!( + interpreter.stack().len(), + 1, + "Expected 1 item on stack, found {:?}", + interpreter.stack() + ); + let smt_hash = KERNEL.global_labels["smt_hash"]; + interpreter.generation_state.registers.program_counter = smt_hash; + let ptr = interpreter.stack()[0]; + interpreter.pop(); + interpreter.push(0xDEADBEEFu32.into()); + interpreter.push(ptr); + interpreter.run()?; + let hash = interpreter.pop(); + + state_smt + .insert( + new_key.into(), + ValOrHash::Val(AccountOrValue::Account(new_account)), + ) + .unwrap(); + let expected_hash = state_smt.root; + + assert_eq!(hash, expected_hash.into_uint()); + + Ok(()) +} + +#[test] +fn smt_insert_storage() -> Result<()> { + let n = 100; + let mut rng = thread_rng(); + let rand_node = |_| { + ( + U256(rng.gen()).into(), + ValOrHash::Val(AccountOrValue::Value(U256(rng.gen()))), + ) + }; + let smt = Smt::new((0..n).map(rand_node)).unwrap(); + let new_key = U256(rng.gen()); + let new_val = U256(rng.gen()); + + test_storage_smt(smt, new_key, new_val) +} + +fn test_storage_smt(mut storage_smt: Smt, new_key: U256, new_val: U256) -> Result<()> { + let initial_stack = vec![0xDEADBEEFu32.into()]; + let smt_insert = KERNEL.global_labels["smt_insert"]; + let mut interpreter = Interpreter::new_with_kernel(smt_insert, initial_stack); + let trie_data = storage_smt.serialize(); + let len = trie_data.len(); + interpreter.generation_state.memory.contexts[0].segments[Segment::TrieData as usize].content = + trie_data; + interpreter.set_global_metadata_field(GlobalMetadata::TrieDataSize, len.into()); + interpreter.set_global_metadata_field(GlobalMetadata::StateTrieRoot, 2.into()); + + let state_root = interpreter.get_global_metadata_field(GlobalMetadata::StateTrieRoot); + let trie_data_size = interpreter.get_global_metadata_field(GlobalMetadata::TrieDataSize); + let trie_data = &mut interpreter.generation_state.memory.contexts[0].segments + [Segment::TrieData as usize] + .content; + trie_data.push(U256::zero()); // For key + trie_data.push(new_val); + let len = trie_data.len(); + interpreter.set_global_metadata_field(GlobalMetadata::TrieDataSize, len.into()); + + let smt_insert = KERNEL.global_labels["smt_insert"]; + // Now, execute smt_insert. 
+ interpreter.generation_state.registers.program_counter = smt_insert; + interpreter.push(trie_data_size); + interpreter.push(new_key); + interpreter.push(state_root); + interpreter.run()?; + + assert_eq!( + interpreter.stack().len(), + 1, + "Expected 1 item on stack, found {:?}", + interpreter.stack() + ); + + let smt_hash = KERNEL.global_labels["smt_hash"]; + interpreter.generation_state.registers.program_counter = smt_hash; + interpreter.generation_state.memory.contexts[0].segments[Segment::KernelGeneral as usize] + .content + .resize(13371338, U256::zero()); + interpreter.generation_state.memory.contexts[0].segments[Segment::KernelGeneral as usize] + .content[evm_constants()["SMT_IS_STORAGE"].as_usize()] = U256::one(); // To hash storage trie. + let ptr = interpreter.stack()[0]; + interpreter.pop(); + interpreter.push(0xDEADBEEFu32.into()); + interpreter.push(ptr); + interpreter.run()?; + let hash = interpreter.pop(); + + storage_smt + .insert( + new_key.into(), + ValOrHash::Val(AccountOrValue::Value(new_val)), + ) + .unwrap(); + let expected_hash = storage_smt.root; + + assert_eq!(hash, expected_hash.into_uint()); + + Ok(()) +} diff --git a/evm/src/cpu/kernel/tests/smt/mod.rs b/evm/src/cpu/kernel/tests/smt/mod.rs new file mode 100644 index 0000000000..2d7504cdc5 --- /dev/null +++ b/evm/src/cpu/kernel/tests/smt/mod.rs @@ -0,0 +1,2 @@ +mod hash; +mod insert; diff --git a/evm/src/cpu/membus.rs b/evm/src/cpu/membus.rs index 10dc25a4ca..fe967a2e3e 100644 --- a/evm/src/cpu/membus.rs +++ b/evm/src/cpu/membus.rs @@ -9,6 +9,7 @@ use crate::cpu::columns::CpuColumnsView; /// General-purpose memory channels; they can read and write to all contexts/segments/addresses. pub const NUM_GP_CHANNELS: usize = 5; +/// Indices for code and general purpose memory channels. pub mod channel_indices { use std::ops::Range; @@ -31,6 +32,7 @@ pub mod channel_indices { /// These limitations save us numerous columns in the CPU table. pub const NUM_CHANNELS: usize = channel_indices::GP.end; +/// Evaluates constraints regarding the membus. pub fn eval_packed( lv: &CpuColumnsView
<P>, yield_constr: &mut ConstraintConsumer<P>
, @@ -47,6 +49,8 @@ pub fn eval_packed( } } +/// Circuit version of `eval_packed`. +/// Evaluates constraints regarding the membus. pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, diff --git a/evm/src/cpu/memio.rs b/evm/src/cpu/memio.rs index f70f3fdb67..71960f5ffc 100644 --- a/evm/src/cpu/memio.rs +++ b/evm/src/cpu/memio.rs @@ -18,16 +18,18 @@ fn get_addr(lv: &CpuColumnsView) -> (T, T, T) { (addr_context, addr_segment, addr_virtual) } +/// Evaluates constraints for MLOAD_GENERAL. fn eval_packed_load( lv: &CpuColumnsView
<P>, nv: &CpuColumnsView<P>, yield_constr: &mut ConstraintConsumer<P>
, ) { - // The opcode for MLOAD_GENERAL is 0xfb. If the operation is MLOAD_GENERAL, lv.opcode_bits[0] = 1 + // The opcode for MLOAD_GENERAL is 0xfb. If the operation is MLOAD_GENERAL, lv.opcode_bits[0] = 1. let filter = lv.op.m_op_general * lv.opcode_bits[0]; let (addr_context, addr_segment, addr_virtual) = get_addr(lv); + // Check that we are loading the correct value from the correct address. let load_channel = lv.mem_channels[3]; yield_constr.constraint(filter * (load_channel.used - P::ONES)); yield_constr.constraint(filter * (load_channel.is_read - P::ONES)); @@ -50,17 +52,21 @@ fn eval_packed_load( ); } +/// Circuit version for `eval_packed_load`. +/// Evaluates constraints for MLOAD_GENERAL. fn eval_ext_circuit_load, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { + // The opcode for MLOAD_GENERAL is 0xfb. If the operation is MLOAD_GENERAL, lv.opcode_bits[0] = 1. let mut filter = lv.op.m_op_general; filter = builder.mul_extension(filter, lv.opcode_bits[0]); let (addr_context, addr_segment, addr_virtual) = get_addr(lv); + // Check that we are loading the correct value from the correct channel. let load_channel = lv.mem_channels[3]; { let constr = builder.mul_sub_extension(filter, load_channel.used, filter); @@ -100,6 +106,7 @@ fn eval_ext_circuit_load, const D: usize>( ); } +/// Evaluates constraints for MSTORE_GENERAL. fn eval_packed_store( lv: &CpuColumnsView
<P>, nv: &CpuColumnsView<P>
, @@ -109,6 +116,7 @@ fn eval_packed_store( let (addr_context, addr_segment, addr_virtual) = get_addr(lv); + // Check that we are storing the correct value at the correct address. let value_channel = lv.mem_channels[3]; let store_channel = lv.mem_channels[4]; yield_constr.constraint(filter * (store_channel.used - P::ONES)); @@ -171,6 +179,8 @@ fn eval_packed_store( yield_constr.constraint(lv.op.m_op_general * lv.opcode_bits[0] * top_read_channel.used); } +/// Circuit version of `eval_packed_store`. +/// /// Evaluates constraints for MSTORE_GENERAL. fn eval_ext_circuit_store, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, @@ -182,6 +192,7 @@ fn eval_ext_circuit_store, const D: usize>( let (addr_context, addr_segment, addr_virtual) = get_addr(lv); + // Check that we are storing the correct value at the correct address. let value_channel = lv.mem_channels[3]; let store_channel = lv.mem_channels[4]; { @@ -264,7 +275,7 @@ fn eval_ext_circuit_store, const D: usize>( let top_read_channel = nv.mem_channels[0]; let is_top_read = builder.mul_extension(lv.general.stack().stack_inv_aux, lv.opcode_bits[0]); let is_top_read = builder.sub_extension(lv.general.stack().stack_inv_aux, is_top_read); - // Constrain `stack_inv_aux_2`. It contains `stack_inv_aux * opcode_bits[0]`. + // Constrain `stack_inv_aux_2`. It contains `stack_inv_aux * (1 - opcode_bits[0])`. { let diff = builder.sub_extension(lv.general.stack().stack_inv_aux_2, is_top_read); let constr = builder.mul_extension(lv.op.m_op_general, diff); @@ -315,6 +326,7 @@ fn eval_ext_circuit_store, const D: usize>( } } +/// Evaluates constraints for MLOAD_GENERAL and MSTORE_GENERAL. pub fn eval_packed( lv: &CpuColumnsView
<P>, nv: &CpuColumnsView<P>
, @@ -324,6 +336,8 @@ pub fn eval_packed( eval_packed_store(lv, nv, yield_constr); } +/// Circuit version of `eval_packed`. +/// Evaluates constraints for MLOAD_GENERAL and MSTORE_GENERAL. pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, diff --git a/evm/src/cpu/modfp254.rs b/evm/src/cpu/modfp254.rs index eed497f5d3..c5b130fa04 100644 --- a/evm/src/cpu/modfp254.rs +++ b/evm/src/cpu/modfp254.rs @@ -15,6 +15,7 @@ const P_LIMBS: [u32; 8] = [ 0xd87cfd47, 0x3c208c16, 0x6871ca8d, 0x97816a91, 0x8181585d, 0xb85045b6, 0xe131a029, 0x30644e72, ]; +/// Evaluates constraints to check the modulus in mem_channels[2]. pub fn eval_packed( lv: &CpuColumnsView
<P>, yield_constr: &mut ConstraintConsumer<P>
, @@ -31,6 +32,8 @@ pub fn eval_packed( } } +/// Circuit version of `eval_packed`. +/// Evaluates constraints to check the modulus in mem_channels[2]. pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, diff --git a/evm/src/cpu/pc.rs b/evm/src/cpu/pc.rs index 5271ad81aa..16915a53e7 100644 --- a/evm/src/cpu/pc.rs +++ b/evm/src/cpu/pc.rs @@ -6,12 +6,14 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; +/// Evaluates constraints to check that we are storing the correct PC. pub fn eval_packed( lv: &CpuColumnsView
<P>, nv: &CpuColumnsView<P>, yield_constr: &mut ConstraintConsumer<P>
, ) { - let filter = lv.op.pc; + // `PUSH0`'s opcode is odd, while `PC`'s opcode is even. + let filter = lv.op.pc_push0 * (P::ONES - lv.opcode_bits[0]); let new_stack_top = nv.mem_channels[0].value; yield_constr.constraint(filter * (new_stack_top[0] - lv.program_counter)); for &limb in &new_stack_top[1..] { @@ -19,13 +21,18 @@ pub fn eval_packed( } } +/// Circuit version of `eval_packed`. +/// Evaluates constraints to check that we are storing the correct PC. pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let filter = lv.op.pc; + // `PUSH0`'s opcode is odd, while `PC`'s opcode is even. + let one = builder.one_extension(); + let mut filter = builder.sub_extension(one, lv.opcode_bits[0]); + filter = builder.mul_extension(lv.op.pc_push0, filter); let new_stack_top = nv.mem_channels[0].value; { let diff = builder.sub_extension(new_stack_top[0], lv.program_counter); diff --git a/evm/src/cpu/push0.rs b/evm/src/cpu/push0.rs index d49446cc23..e48e77f375 100644 --- a/evm/src/cpu/push0.rs +++ b/evm/src/cpu/push0.rs @@ -6,24 +6,29 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; +/// Evaluates constraints to check that we are not pushing anything. pub fn eval_packed( lv: &CpuColumnsView
<P>, nv: &CpuColumnsView<P>, yield_constr: &mut ConstraintConsumer<P>
, ) { - let filter = lv.op.push0; + // `PUSH0`'s opcode is odd, while `PC`'s opcode is even. + let filter = lv.op.pc_push0 * lv.opcode_bits[0]; for limb in nv.mem_channels[0].value { yield_constr.constraint(filter * limb); } } +/// Circuit version of `eval_packed`. +/// Evaluates constraints to check that we are not pushing anything. pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let filter = lv.op.push0; + // `PUSH0`'s opcode is odd, while `PC`'s opcode is even. + let filter = builder.mul_extension(lv.op.pc_push0, lv.opcode_bits[0]); for limb in nv.mem_channels[0].value { let constr = builder.mul_extension(filter, limb); yield_constr.constraint(builder, constr); diff --git a/evm/src/cpu/shift.rs b/evm/src/cpu/shift.rs index 0f92cbd20d..e77762fc5a 100644 --- a/evm/src/cpu/shift.rs +++ b/evm/src/cpu/shift.rs @@ -9,6 +9,8 @@ use crate::cpu::columns::CpuColumnsView; use crate::cpu::membus::NUM_GP_CHANNELS; use crate::memory::segments::Segment; +/// Evaluates constraints for shift operations on the CPU side: +/// the shifting factor is read from memory when displacement < 2^32. pub(crate) fn eval_packed( lv: &CpuColumnsView
<P>, yield_constr: &mut ConstraintConsumer<P>
, @@ -59,6 +61,9 @@ pub(crate) fn eval_packed( // last -> last (output is the same) } +/// Circuit version. +/// Evaluates constraints for shift operations on the CPU side: +/// the shifting factor is read from memory when displacement < 2^32. pub(crate) fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, diff --git a/evm/src/cpu/simple_logic/eq_iszero.rs b/evm/src/cpu/simple_logic/eq_iszero.rs index 7be021caa6..7692bb9065 100644 --- a/evm/src/cpu/simple_logic/eq_iszero.rs +++ b/evm/src/cpu/simple_logic/eq_iszero.rs @@ -19,7 +19,11 @@ fn limbs(x: U256) -> [u32; 8] { } res } - +/// Form `diff_pinv`. +/// Let `diff = val0 - val1`. Consider `x[i] = diff[i]^-1` if `diff[i] != 0` and 0 otherwise. +/// Then `diff @ x = num_unequal_limbs`, where `@` denotes the dot product. We set +/// `diff_pinv = num_unequal_limbs^-1 * x` if `num_unequal_limbs != 0` and 0 otherwise. We have +/// `diff @ diff_pinv = 1 - equal` as desired. pub fn generate_pinv_diff(val0: U256, val1: U256, lv: &mut CpuColumnsView) { let val0_limbs = limbs(val0).map(F::from_canonical_u32); let val1_limbs = limbs(val1).map(F::from_canonical_u32); @@ -27,19 +31,8 @@ pub fn generate_pinv_diff(val0: U256, val1: U256, lv: &mut CpuColumnsV let num_unequal_limbs = izip!(val0_limbs, val1_limbs) .map(|(limb0, limb1)| (limb0 != limb1) as usize) .sum(); - let equal = num_unequal_limbs == 0; - - let output = &mut lv.mem_channels[2].value; - output[0] = F::from_bool(equal); - for limb in &mut output[1..] { - *limb = F::ZERO; - } // Form `diff_pinv`. - // Let `diff = val0 - val1`. Consider `x[i] = diff[i]^-1` if `diff[i] != 0` and 0 otherwise. - // Then `diff @ x = num_unequal_limbs`, where `@` denotes the dot product. We set - // `diff_pinv = num_unequal_limbs^-1 * x` if `num_unequal_limbs != 0` and 0 otherwise. We have - // `diff @ diff_pinv = 1 - equal` as desired. let logic = lv.general.logic_mut(); let num_unequal_limbs_inv = F::from_canonical_usize(num_unequal_limbs) .try_inverse() @@ -49,6 +42,7 @@ pub fn generate_pinv_diff(val0: U256, val1: U256, lv: &mut CpuColumnsV } } +/// Evaluates the constraints for EQ and ISZERO. pub fn eval_packed( lv: &CpuColumnsView
<P>, nv: &CpuColumnsView<P>
, @@ -57,7 +51,7 @@ pub fn eval_packed( let logic = lv.general.logic(); let input0 = lv.mem_channels[0].value; let input1 = lv.mem_channels[1].value; - let output = lv.mem_channels[2].value; + let output = nv.mem_channels[0].value; // EQ (0x14) and ISZERO (0x15) are differentiated by their first opcode bit. let eq_filter = lv.op.eq_iszero * (P::ONES - lv.opcode_bits[0]); @@ -105,6 +99,8 @@ pub fn eval_packed( ); } +/// Circuit version of `eval_packed`. +/// Evaluates the constraints for EQ and ISZERO. pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, @@ -117,7 +113,7 @@ pub fn eval_ext_circuit, const D: usize>( let logic = lv.general.logic(); let input0 = lv.mem_channels[0].value; let input1 = lv.mem_channels[1].value; - let output = lv.mem_channels[2].value; + let output = nv.mem_channels[0].value; // EQ (0x14) and ISZERO (0x15) are differentiated by their first opcode bit. let eq_filter = builder.mul_extension(lv.op.eq_iszero, lv.opcode_bits[0]); diff --git a/evm/src/cpu/simple_logic/mod.rs b/evm/src/cpu/simple_logic/mod.rs index 9b4e60b016..dfe396fd48 100644 --- a/evm/src/cpu/simple_logic/mod.rs +++ b/evm/src/cpu/simple_logic/mod.rs @@ -9,21 +9,24 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; +/// Evaluates constraints for NOT, EQ and ISZERO. pub fn eval_packed( lv: &CpuColumnsView
<P>, nv: &CpuColumnsView<P>, yield_constr: &mut ConstraintConsumer<P>
, ) { - not::eval_packed(lv, yield_constr); + not::eval_packed(lv, nv, yield_constr); eq_iszero::eval_packed(lv, nv, yield_constr); } +/// Circuit version of `eval_packed`. +/// Evaluates constraints for NOT, EQ and ISZERO. pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - not::eval_ext_circuit(builder, lv, yield_constr); + not::eval_ext_circuit(builder, lv, nv, yield_constr); eq_iszero::eval_ext_circuit(builder, lv, nv, yield_constr); } diff --git a/evm/src/cpu/simple_logic/not.rs b/evm/src/cpu/simple_logic/not.rs index 0bfaa0b71a..727e840c8c 100644 --- a/evm/src/cpu/simple_logic/not.rs +++ b/evm/src/cpu/simple_logic/not.rs @@ -6,34 +6,42 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; -use crate::cpu::membus::NUM_GP_CHANNELS; +use crate::cpu::stack; const LIMB_SIZE: usize = 32; const ALL_1_LIMB: u64 = (1 << LIMB_SIZE) - 1; +/// Evaluates constraints for NOT. pub fn eval_packed( lv: &CpuColumnsView
<P>, + nv: &CpuColumnsView<P>, yield_constr: &mut ConstraintConsumer<P>
, ) { // This is simple: just do output = 0xffffffff - input. let input = lv.mem_channels[0].value; - let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value; - let filter = lv.op.not; + let output = nv.mem_channels[0].value; + let filter = lv.op.not_pop * lv.opcode_bits[0]; for (input_limb, output_limb) in input.into_iter().zip(output) { yield_constr.constraint( filter * (output_limb + input_limb - P::Scalar::from_canonical_u64(ALL_1_LIMB)), ); } + + // Stack constraints. + stack::eval_packed_one(lv, nv, filter, stack::BASIC_UNARY_OP.unwrap(), yield_constr); } +/// Circuit version of `eval_packed`. +/// Evaluates constraints for NOT. pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { let input = lv.mem_channels[0].value; - let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value; - let filter = lv.op.not; + let output = nv.mem_channels[0].value; + let filter = builder.mul_extension(lv.op.not_pop, lv.opcode_bits[0]); for (input_limb, output_limb) in input.into_iter().zip(output) { let constr = builder.add_extension(output_limb, input_limb); let constr = builder.arithmetic_extension( @@ -45,4 +53,14 @@ pub fn eval_ext_circuit, const D: usize>( ); yield_constr.constraint(builder, constr); } + + // Stack constraints. + stack::eval_ext_circuit_one( + builder, + lv, + nv, + filter, + stack::BASIC_UNARY_OP.unwrap(), + yield_constr, + ); } diff --git a/evm/src/cpu/stack.rs b/evm/src/cpu/stack.rs index db0c480d3d..302d967e8d 100644 --- a/evm/src/cpu/stack.rs +++ b/evm/src/cpu/stack.rs @@ -13,46 +13,66 @@ use crate::cpu::columns::CpuColumnsView; use crate::cpu::membus::NUM_GP_CHANNELS; use crate::memory::segments::Segment; +/// Structure to represent opcodes' stack behaviors: +/// - number of pops +/// - whether the opcode(s) push +/// - whether unused channels should be disabled. #[derive(Clone, Copy)] pub(crate) struct StackBehavior { pub(crate) num_pops: usize, pub(crate) pushes: bool, - new_top_stack_channel: Option, disable_other_channels: bool, } +/// `StackBehavior` for unary operations. +pub(crate) const BASIC_UNARY_OP: Option = Some(StackBehavior { + num_pops: 1, + pushes: true, + disable_other_channels: true, +}); +/// `StackBehavior` for binary operations. const BASIC_BINARY_OP: Option = Some(StackBehavior { num_pops: 2, pushes: true, - new_top_stack_channel: Some(NUM_GP_CHANNELS - 1), disable_other_channels: true, }); +/// `StackBehavior` for ternary operations. const BASIC_TERNARY_OP: Option = Some(StackBehavior { num_pops: 3, pushes: true, - new_top_stack_channel: Some(NUM_GP_CHANNELS - 1), disable_other_channels: true, }); +/// `StackBehavior` for JUMP. pub(crate) const JUMP_OP: Option = Some(StackBehavior { num_pops: 1, pushes: false, - new_top_stack_channel: None, disable_other_channels: false, }); +/// `StackBehavior` for JUMPI. pub(crate) const JUMPI_OP: Option = Some(StackBehavior { num_pops: 2, pushes: false, - new_top_stack_channel: None, disable_other_channels: false, }); - +/// `StackBehavior` for MLOAD_GENERAL.
pub(crate) const MLOAD_GENERAL_OP: Option = Some(StackBehavior { num_pops: 3, pushes: true, - new_top_stack_channel: None, disable_other_channels: false, }); +pub(crate) const KECCAK_GENERAL_OP: StackBehavior = StackBehavior { + num_pops: 4, + pushes: true, + disable_other_channels: true, +}; + +pub(crate) const JUMPDEST_OP: StackBehavior = StackBehavior { + num_pops: 0, + pushes: false, + disable_other_channels: true, +}; + // AUDITORS: If the value below is `None`, then the operation must be manually checked to ensure // that every general-purpose memory channel is either disabled or has its read flag and address // propertly constrained. The same applies when `disable_other_channels` is set to `false`, @@ -64,105 +84,65 @@ pub(crate) const STACK_BEHAVIORS: OpsColumnsView> = OpsCol fp254_op: BASIC_BINARY_OP, eq_iszero: None, // EQ is binary, IS_ZERO is unary. logic_op: BASIC_BINARY_OP, - not: Some(StackBehavior { - num_pops: 1, - pushes: true, - new_top_stack_channel: Some(NUM_GP_CHANNELS - 1), - disable_other_channels: true, - }), + not_pop: None, shift: Some(StackBehavior { num_pops: 2, pushes: true, - new_top_stack_channel: Some(NUM_GP_CHANNELS - 1), disable_other_channels: false, }), - keccak_general: Some(StackBehavior { - num_pops: 4, - pushes: true, - new_top_stack_channel: Some(NUM_GP_CHANNELS - 1), - disable_other_channels: true, - }), + jumpdest_keccak_general: None, prover_input: None, // TODO - pop: Some(StackBehavior { - num_pops: 1, - pushes: false, - new_top_stack_channel: None, - disable_other_channels: true, - }), - jumps: None, // Depends on whether it's a JUMP or a JUMPI. - pc: Some(StackBehavior { - num_pops: 0, - pushes: true, - new_top_stack_channel: None, - disable_other_channels: true, - }), - jumpdest: Some(StackBehavior { - num_pops: 0, - pushes: false, - new_top_stack_channel: None, - disable_other_channels: true, - }), - push0: Some(StackBehavior { + jumps: None, // Depends on whether it's a JUMP or a JUMPI. + pc_push0: Some(StackBehavior { num_pops: 0, pushes: true, - new_top_stack_channel: None, disable_other_channels: true, }), push: None, // TODO dup_swap: None, - get_context: Some(StackBehavior { - num_pops: 0, - pushes: true, - new_top_stack_channel: None, - disable_other_channels: true, - }), - set_context: None, // SET_CONTEXT is special since it involves the old and the new stack. + context_op: None, mload_32bytes: Some(StackBehavior { num_pops: 4, pushes: true, - new_top_stack_channel: Some(4), disable_other_channels: false, }), mstore_32bytes: Some(StackBehavior { num_pops: 5, pushes: false, - new_top_stack_channel: None, disable_other_channels: false, }), exit_kernel: Some(StackBehavior { num_pops: 1, pushes: false, - new_top_stack_channel: None, disable_other_channels: true, }), m_op_general: None, syscall: Some(StackBehavior { num_pops: 0, pushes: true, - new_top_stack_channel: None, disable_other_channels: false, }), exception: Some(StackBehavior { num_pops: 0, pushes: true, - new_top_stack_channel: None, disable_other_channels: false, }), }; +/// Stack behavior for EQ. pub(crate) const EQ_STACK_BEHAVIOR: Option = Some(StackBehavior { num_pops: 2, pushes: true, - new_top_stack_channel: Some(2), disable_other_channels: true, }); +/// Stack behavior for ISZERO. pub(crate) const IS_ZERO_STACK_BEHAVIOR: Option = Some(StackBehavior { num_pops: 1, pushes: true, - new_top_stack_channel: Some(2), disable_other_channels: true, }); +/// Evaluates constraints for one `StackBehavior`. pub(crate) fn eval_packed_one( lv: &CpuColumnsView
<P>, nv: &CpuColumnsView<P>
, @@ -256,18 +236,6 @@ pub(crate) fn eval_packed_one( } } - // Maybe constrain next stack_top. - // These are transition constraints: they don't apply to the last row. - if let Some(next_top_ch) = stack_behavior.new_top_stack_channel { - for (limb_ch, limb_top) in lv.mem_channels[next_top_ch] - .value - .iter() - .zip(nv.mem_channels[0].value.iter()) - { - yield_constr.constraint_transition(filter * (*limb_ch - *limb_top)); - } - } - // Unused channels if stack_behavior.disable_other_channels { // The first channel contains (or not) the top od the stack and is constrained elsewhere. @@ -284,6 +252,7 @@ pub(crate) fn eval_packed_one( yield_constr.constraint_transition(filter * (nv.stack_len - (lv.stack_len - num_pops + push))); } +/// Evaluates constraints for all opcodes' `StackBehavior`s. pub fn eval_packed( lv: &CpuColumnsView
<P>, nv: &CpuColumnsView<P>
, @@ -294,8 +263,68 @@ pub fn eval_packed( eval_packed_one(lv, nv, op, stack_behavior, yield_constr); } } + + // Constrain stack for JUMPDEST. + let jumpdest_filter = lv.op.jumpdest_keccak_general * lv.opcode_bits[1]; + eval_packed_one(lv, nv, jumpdest_filter, JUMPDEST_OP, yield_constr); + + // Constrain stack for KECCAK_GENERAL. + let keccak_general_filter = lv.op.jumpdest_keccak_general * (P::ONES - lv.opcode_bits[1]); + eval_packed_one( + lv, + nv, + keccak_general_filter, + KECCAK_GENERAL_OP, + yield_constr, + ); + + // Stack constraints for POP. + // The only constraints POP has are stack constraints. + // Since POP and NOT are combined into one flag and they have + // different stack behaviors, POP needs special stack constraints. + // Constrain `stack_inv_aux`. + let len_diff = lv.stack_len - P::Scalar::ONES; + yield_constr.constraint( + lv.op.not_pop + * (len_diff * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux), + ); + + // If stack_len != 1 and POP, read new top of the stack in nv.mem_channels[0]. + let top_read_channel = nv.mem_channels[0]; + let is_top_read = lv.general.stack().stack_inv_aux * (P::ONES - lv.opcode_bits[0]); + + // Constrain `stack_inv_aux_2`. It contains `stack_inv_aux * (1 - opcode_bits[0])`. + yield_constr.constraint(lv.op.not_pop * (lv.general.stack().stack_inv_aux_2 - is_top_read)); + let new_filter = lv.op.not_pop * lv.general.stack().stack_inv_aux_2; + yield_constr.constraint_transition(new_filter * (top_read_channel.used - P::ONES)); + yield_constr.constraint_transition(new_filter * (top_read_channel.is_read - P::ONES)); + yield_constr.constraint_transition(new_filter * (top_read_channel.addr_context - nv.context)); + yield_constr.constraint_transition( + new_filter + * (top_read_channel.addr_segment + - P::Scalar::from_canonical_u64(Segment::Stack as u64)), + ); + let addr_virtual = nv.stack_len - P::ONES; + yield_constr.constraint_transition(new_filter * (top_read_channel.addr_virtual - addr_virtual)); + // If stack_len == 1 or NOT, disable the channel. + // If NOT or (len==1 and POP), then `stack_inv_aux_2` = 0. + yield_constr.constraint( + lv.op.not_pop * (lv.general.stack().stack_inv_aux_2 - P::ONES) * top_read_channel.used, + ); + + // Disable remaining memory channels. + for &channel in &lv.mem_channels[1..] { + yield_constr.constraint(lv.op.not_pop * (lv.opcode_bits[0] - P::ONES) * channel.used); + } + + // Constrain the new stack length for POP. + yield_constr.constraint_transition( + lv.op.not_pop * (lv.opcode_bits[0] - P::ONES) * (nv.stack_len - lv.stack_len + P::ONES), + ); } +/// Circuit version of `eval_packed_one`. +/// Evaluates constraints for one `StackBehavior`. pub(crate) fn eval_ext_circuit_one, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, @@ -478,20 +507,6 @@ pub(crate) fn eval_ext_circuit_one, const D: usize> } } - // Maybe constrain next stack_top. - // These are transition constraints: they don't apply to the last row. - if let Some(next_top_ch) = stack_behavior.new_top_stack_channel { - for (limb_ch, limb_top) in lv.mem_channels[next_top_ch] - .value - .iter() - .zip(nv.mem_channels[0].value.iter()) - { - let diff = builder.sub_extension(*limb_ch, *limb_top); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint_transition(builder, constr); - } - } - // Unused channels if stack_behavior.disable_other_channels { // The first channel contains (or not) the top od the stack and is constrained elsewhere. 
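Reviewer note: the POP constraints above (and their circuit counterpart in the next hunk) rely on an inverse-witness pattern: `stack_inv` is a prover-supplied claimed inverse of `stack_len - 1`, and `stack_inv_aux` is constrained to be their product, so it acts as a "stack has more than one element" flag (it is forced to 0 when `stack_len == 1`, and an honest prover sets it to 1 otherwise). A self-contained sketch over a generic prime field `p`, with simplified arithmetic and hypothetical names:

    /// Witness assignment for the inverse trick: returns (stack_inv, stack_inv_aux)
    /// satisfying the constraint (stack_len - 1) * stack_inv == stack_inv_aux mod p.
    fn assign_inverse_flag(stack_len: u64, p: u64) -> (u64, u64) {
        let len_diff = ((stack_len as u128 + p as u128 - 1) % p as u128) as u64;
        if len_diff == 0 {
            (0, 0) // stack_len == 1: no inverse exists, so the flag is forced to 0.
        } else {
            (mod_pow(len_diff, p - 2, p), 1) // Fermat inverse: len_diff * inv == 1.
        }
    }

    /// Square-and-multiply modular exponentiation (p assumed prime).
    fn mod_pow(mut base: u64, mut exp: u64, p: u64) -> u64 {
        let mut acc = 1u64;
        base %= p;
        while exp > 0 {
            if exp & 1 == 1 {
                acc = ((acc as u128 * base as u128) % p as u128) as u64;
            }
            base = ((base as u128 * base as u128) % p as u128) as u64;
            exp >>= 1;
        }
        acc
    }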
@@ -514,6 +529,8 @@ pub(crate) fn eval_ext_circuit_one, const D: usize> yield_constr.constraint_transition(builder, constr); } +/// Circuit version of `eval_packed`. +/// Evaluates constraints for all opcodes' `StackBehavior`s. pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, @@ -525,4 +542,99 @@ pub fn eval_ext_circuit, const D: usize>( eval_ext_circuit_one(builder, lv, nv, op, stack_behavior, yield_constr); } } + + // Constrain stack for JUMPDEST. + let jumpdest_filter = builder.mul_extension(lv.op.jumpdest_keccak_general, lv.opcode_bits[1]); + eval_ext_circuit_one(builder, lv, nv, jumpdest_filter, JUMPDEST_OP, yield_constr); + + // Constrain stack for KECCAK_GENERAL. + let one = builder.one_extension(); + let mut keccak_general_filter = builder.sub_extension(one, lv.opcode_bits[1]); + keccak_general_filter = + builder.mul_extension(lv.op.jumpdest_keccak_general, keccak_general_filter); + eval_ext_circuit_one( + builder, + lv, + nv, + keccak_general_filter, + KECCAK_GENERAL_OP, + yield_constr, + ); + + // Stack constraints for POP. + // The only constraints POP has are stack constraints. + // Since POP and NOT are combined into one flag and they have + // different stack behaviors, POP needs special stack constraints. + // Constrain `stack_inv_aux`. + { + let len_diff = builder.add_const_extension(lv.stack_len, F::NEG_ONE); + let diff = builder.mul_sub_extension( + len_diff, + lv.general.stack().stack_inv, + lv.general.stack().stack_inv_aux, + ); + let constr = builder.mul_extension(lv.op.not_pop, diff); + yield_constr.constraint(builder, constr); + } + // If stack_len != 1 and POP, read new top of the stack in nv.mem_channels[0]. + let top_read_channel = nv.mem_channels[0]; + let is_top_read = builder.mul_extension(lv.general.stack().stack_inv_aux, lv.opcode_bits[0]); + let is_top_read = builder.sub_extension(lv.general.stack().stack_inv_aux, is_top_read); + // Constrain `stack_inv_aux_2`. It contains `stack_inv_aux * (1 - opcode_bits[0])`. + { + let diff = builder.sub_extension(lv.general.stack().stack_inv_aux_2, is_top_read); + let constr = builder.mul_extension(lv.op.not_pop, diff); + yield_constr.constraint(builder, constr); + } + let new_filter = builder.mul_extension(lv.op.not_pop, lv.general.stack().stack_inv_aux_2); + { + let constr = builder.mul_sub_extension(new_filter, top_read_channel.used, new_filter); + yield_constr.constraint_transition(builder, constr); + } + { + let constr = builder.mul_sub_extension(new_filter, top_read_channel.is_read, new_filter); + yield_constr.constraint_transition(builder, constr); + } + { + let diff = builder.sub_extension(top_read_channel.addr_context, nv.context); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint_transition(builder, constr); + } + { + let diff = builder.add_const_extension( + top_read_channel.addr_segment, + -F::from_canonical_u64(Segment::Stack as u64), + ); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint_transition(builder, constr); + } + { + let addr_virtual = builder.add_const_extension(nv.stack_len, -F::ONE); + let diff = builder.sub_extension(top_read_channel.addr_virtual, addr_virtual); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint_transition(builder, constr); + } + // If stack_len == 1 or NOT, disable the channel.
+ { + let diff = builder.mul_sub_extension( + lv.op.not_pop, + lv.general.stack().stack_inv_aux_2, + lv.op.not_pop, + ); + let constr = builder.mul_extension(diff, top_read_channel.used); + yield_constr.constraint(builder, constr); + } + + // Disable remaining memory channels. + let filter = builder.mul_sub_extension(lv.op.not_pop, lv.opcode_bits[0], lv.op.not_pop); + for &channel in &lv.mem_channels[1..] { + let constr = builder.mul_extension(filter, channel.used); + yield_constr.constraint(builder, constr); + } + + // Constrain the new stack length for POP. + let diff = builder.sub_extension(nv.stack_len, lv.stack_len); + let mut constr = builder.add_const_extension(diff, F::ONES); + constr = builder.mul_extension(filter, constr); + yield_constr.constraint_transition(builder, constr); } diff --git a/evm/src/cpu/stack_bounds.rs b/evm/src/cpu/stack_bounds.rs index e66e6686b5..cc8c6aedfe 100644 --- a/evm/src/cpu/stack_bounds.rs +++ b/evm/src/cpu/stack_bounds.rs @@ -20,6 +20,7 @@ use crate::cpu::columns::CpuColumnsView; pub const MAX_USER_STACK_SIZE: usize = 1024; +/// Evaluates constraints to check for stack overflows. pub fn eval_packed( lv: &CpuColumnsView
<P>, yield_constr: &mut ConstraintConsumer<P>
, @@ -36,6 +37,8 @@ pub fn eval_packed( yield_constr.constraint(filter * (lhs - rhs)); } +/// Circuit version of `eval_packed`. +/// Evaluates constraints to check for stack overflows. pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, diff --git a/evm/src/cpu/syscalls_exceptions.rs b/evm/src/cpu/syscalls_exceptions.rs index 1437fba02b..78f1e0c1ae 100644 --- a/evm/src/cpu/syscalls_exceptions.rs +++ b/evm/src/cpu/syscalls_exceptions.rs @@ -19,6 +19,7 @@ use crate::memory::segments::Segment; const BYTES_PER_OFFSET: usize = crate::cpu::kernel::assembler::BYTES_PER_OFFSET as usize; const_assert!(BYTES_PER_OFFSET < NUM_GP_CHANNELS); // Reserve one channel for stack push +/// Evaluates constraints for syscalls and exceptions. pub fn eval_packed( lv: &CpuColumnsView
<P>, nv: &CpuColumnsView<P>
, @@ -65,6 +66,7 @@ pub fn eval_packed( exc_jumptable_start + exc_code * P::Scalar::from_canonical_usize(BYTES_PER_OFFSET); for (i, channel) in lv.mem_channels[1..BYTES_PER_OFFSET + 1].iter().enumerate() { + // Set `used` and `is_read`. yield_constr.constraint(total_filter * (channel.used - P::ONES)); yield_constr.constraint(total_filter * (channel.is_read - P::ONES)); @@ -120,6 +122,8 @@ pub fn eval_packed( } } +/// Circuit version of `eval_packed`. +/// Evaluates constraints for syscalls and exceptions. pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, @@ -182,6 +186,7 @@ pub fn eval_ext_circuit, const D: usize>( ); for (i, channel) in lv.mem_channels[1..BYTES_PER_OFFSET + 1].iter().enumerate() { + // Set `used` and `is_read`. { let constr = builder.mul_sub_extension(total_filter, channel.used, total_filter); yield_constr.constraint(builder, constr); diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index 621403f912..d5d8b6b36b 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -1,3 +1,32 @@ +//! This module provides support for cross-table lookups. +//! +//! If a STARK S_1 calls an operation that is carried out by another STARK S_2, +//! S_1 provides the inputs to S_2 and reads the output from S_2. To ensure that +//! the operation was correctly carried out, we must check that the provided inputs +//! and outputs are correctly read. Cross-table lookups carry out that check. +//! +//! To achieve this, smaller CTL tables are created on both sides: looking and looked tables. +//! In our example, we create a table S_1' comprised of columns -- or linear combinations +//! of columns -- of S_1, and rows that call operations carried out in S_2. We also create a +//! table S_2' comprised of columns -- or linear combinations of columns -- of S_2 and rows +//! that carry out the operations needed by other STARKs. Then, S_1' is a looking table for +//! the looked S_2', since we want to check that the operation outputs in S_1' are indeed in S_2'. +//! Furthermore, the concatenation of all tables looking into S_2' must be equal to S_2'. +//! +//! To achieve this, we construct, for each table, a permutation polynomial Z(x). +//! Z(x) is computed as the product of all its column combinations. +//! To check it was correctly constructed, we check: +//! - Z(gw) = Z(w) * combine(w) where combine(w) is the column combination at point w. +//! - Z(g^(n-1)) = combine(1). +//! - The verifier also checks that the product of looking table Z polynomials is equal +//! to the associated looked table Z polynomial. +//! Note that the first two checks are written that way because Z polynomials are computed +//! upside down for convenience. +//! +//! Additionally, we support cross-table lookups over two rows. The permutation principle +//! is similar, but we provide not only `local_values` but also `next_values` -- corresponding to +//! the current and next row values -- when computing the linear combinations. + use std::borrow::Borrow; use std::fmt::Debug; use std::iter::repeat; @@ -26,7 +55,10 @@ use crate::evaluation_frame::StarkEvaluationFrame; use crate::proof::{StarkProofTarget, StarkProofWithMetadata}; use crate::stark::Stark; -/// Represent a linear combination of columns. +/// Represents two linear combinations of columns, corresponding to the current and next row values.
+/// Each linear combination is represented as: +/// - a vector of `(usize, F)` corresponding to the column number and the associated multiplicand +/// - the constant of the linear combination. #[derive(Clone, Debug)] pub struct Column { linear_combination: Vec<(usize, F)>, @@ -35,6 +67,7 @@ pub struct Column { } impl Column { + /// Returns the representation of a single column in the current row. pub fn single(c: usize) -> Self { Self { linear_combination: vec![(c, F::ONE)], @@ -43,12 +76,14 @@ impl Column { } } + /// Returns multiple single columns in the current row. pub fn singles>>( cs: I, ) -> impl Iterator { cs.into_iter().map(|c| Self::single(*c.borrow())) } + /// Returns the representation of a single column in the next row. pub fn single_next_row(c: usize) -> Self { Self { linear_combination: vec![], @@ -57,12 +92,14 @@ impl Column { } } + /// Returns multiple single columns for the next row. pub fn singles_next_row>>( cs: I, ) -> impl Iterator { cs.into_iter().map(|c| Self::single_next_row(*c.borrow())) } + /// Returns a linear combination corresponding to a constant. pub fn constant(constant: F) -> Self { Self { linear_combination: vec![], @@ -71,14 +108,17 @@ impl Column { } } + /// Returns a linear combination corresponding to 0. pub fn zero() -> Self { Self::constant(F::ZERO) } + /// Returns a linear combination corresponding to 1. pub fn one() -> Self { Self::constant(F::ONE) } + /// Given an iterator of `(usize, F)` and a constant, returns the associated linear combination of columns for the current row. pub fn linear_combination_with_constant>( iter: I, constant: F, @@ -97,6 +137,7 @@ impl Column { } } + /// Given an iterator of `(usize, F)` and a constant, returns the associated linear combination of columns for the current and the next rows. pub fn linear_combination_and_next_row_with_constant>( iter: I, next_row_iter: I, @@ -124,14 +165,19 @@ impl Column { } } + /// Returns a linear combination of columns, with no additional constant. pub fn linear_combination>(iter: I) -> Self { Self::linear_combination_with_constant(iter, F::ZERO) } + /// Given an iterator of columns (c_0, ..., c_n) containing bits in little-endian order: + /// returns the representation of c_0 + 2 * c_1 + ... + 2^n * c_n. pub fn le_bits>>(cs: I) -> Self { Self::linear_combination(cs.into_iter().map(|c| *c.borrow()).zip(F::TWO.powers())) } + /// Given an iterator of columns (c_0, ..., c_n) containing bytes in little-endian order: + /// returns the representation of c_0 + 256 * c_1 + ... + 256^n * c_n. pub fn le_bytes>>(cs: I) -> Self { Self::linear_combination( cs.into_iter() @@ -140,10 +186,12 @@ impl Column { ) } + /// Given an iterator of columns, returns the representation of their sum. pub fn sum>>(cs: I) -> Self { Self::linear_combination(cs.into_iter().map(|c| *c.borrow()).zip(repeat(F::ONE))) } + /// Given the column values for the current row, returns the evaluation of the linear combination. pub fn eval(&self, v: &[P]) -> P where FE: FieldExtension, @@ -156,6 +204,7 @@ impl Column { + FE::from_basefield(self.constant) } + /// Given the column values for the current and next rows, evaluates the current and next linear combinations and returns their sum. pub fn eval_with_next(&self, v: &[P], next_v: &[P]) -> P where FE: FieldExtension, @@ -173,7 +222,7 @@ impl Column { + FE::from_basefield(self.constant) } - /// Evaluate on an row of a table given in column-major form. + /// Evaluate on a row of a table given in column-major form.
pub fn eval_table(&self, table: &[PolynomialValues], row: usize) -> F { let mut res = self .linear_combination @@ -195,6 +244,7 @@ impl Column { res } + /// Circuit version of `eval`: Given a row's targets, returns their linear combination. pub fn eval_circuit( &self, builder: &mut CircuitBuilder, @@ -217,6 +267,8 @@ impl Column { builder.inner_product_extension(F::ONE, constant, pairs) } + /// Circuit version of `eval_with_next`: + /// Given the targets of the current and next row, returns the sum of their linear combinations. pub fn eval_with_next_circuit( &self, builder: &mut CircuitBuilder, @@ -248,6 +300,9 @@ impl Column { } } +/// A `Table` with a linear combination of columns and a filter. +/// `filter_column` is used to determine the rows to select in `Table`. +/// `columns` represents linear combinations of the columns of `Table`. #[derive(Clone, Debug)] pub struct TableWithColumns { table: Table, @@ -256,6 +311,7 @@ pub struct TableWithColumns { } impl TableWithColumns { + /// Generates a new `TableWithColumns` given a `Table`, a linear combination of columns `columns` and a `filter_column`. pub fn new(table: Table, columns: Vec>, filter_column: Option>) -> Self { Self { table, @@ -265,13 +321,19 @@ impl TableWithColumns { } } +/// Cross-table lookup data consisting of the lookup table (`looked_table`) and all the tables that look into `looked_table` (`looking_tables`). +/// Each `looking_table` corresponds to a STARK's table whose rows have been filtered out and whose columns have been through a linear combination (see `eval_table`). The concatenation of those smaller tables should result in the `looked_table`. #[derive(Clone)] pub struct CrossTableLookup { + /// Column linear combinations for all tables that are looking into the current table. pub(crate) looking_tables: Vec>, + /// Column linear combination for the current table. pub(crate) looked_table: TableWithColumns, } impl CrossTableLookup { + /// Creates a new `CrossTableLookup` given some looking tables and a looked table. + /// All tables should have the same width. pub fn new( looking_tables: Vec>, looked_table: TableWithColumns, @@ -285,6 +347,8 @@ impl CrossTableLookup { } } + /// Given a `Table` t and the number of challenges, returns the number of cross-table lookup polynomials associated to t, + /// i.e. the number of looking and looked tables among all CTLs whose columns are taken from t. pub(crate) fn num_ctl_zs(ctls: &[Self], table: Table, num_challenges: usize) -> usize { let mut num_ctls = 0; for ctl in ctls { @@ -298,27 +362,35 @@ impl CrossTableLookup { /// Cross-table lookup data for one table. #[derive(Clone, Default)] pub struct CtlData { + /// Data associated with all Z(x) polynomials for one table. pub(crate) zs_columns: Vec>, } /// Cross-table lookup data associated with one Z(x) polynomial. #[derive(Clone)] pub(crate) struct CtlZData { + /// Z polynomial values. pub(crate) z: PolynomialValues, + /// Cross-table lookup challenge. pub(crate) challenge: GrandProductChallenge, + /// Column linear combination for the current table. pub(crate) columns: Vec>, + /// Filter column for the current table. It evaluates to either 1 or 0. pub(crate) filter_column: Option>, } impl CtlData { + /// Returns the number of cross-table lookup polynomials. pub fn len(&self) -> usize { self.zs_columns.len() } + /// Returns whether there are no cross-table lookups. pub fn is_empty(&self) -> bool { self.zs_columns.is_empty() } + /// Returns all the cross-table lookup polynomials.
pub fn z_polys(&self) -> Vec> { self.zs_columns .iter() @@ -449,6 +521,11 @@ pub(crate) fn get_grand_product_challenge_set_target< GrandProductChallengeSet { challenges } } +/// Generates all the cross-table lookup data for all tables. +/// - `trace_poly_values` corresponds to the trace values for all tables. +/// - `cross_table_lookups` corresponds to all the cross-table lookups, i.e. the looked and looking tables, as described in `CrossTableLookup`. +/// - `ctl_challenges` corresponds to the challenges used for CTLs. +/// For each `CrossTableLookup`, and each looking/looked table, the partial products for the CTL are computed, and added to that table's `CtlZData`. pub(crate) fn cross_table_lookup_data( trace_poly_values: &[Vec>; NUM_TABLES], cross_table_lookups: &[CrossTableLookup], @@ -499,6 +576,14 @@ pub(crate) fn cross_table_lookup_data( ctl_data_per_table } +/// Computes the cross-table lookup partial products for one table and given column linear combinations. +/// `trace` represents the trace values for the given table. +/// `columns` are all the column linear combinations to evaluate. +/// `filter_column` is a column linear combination used to determine whether a row should be selected. +/// `challenge` is a cross-table lookup challenge. +/// The initial product `p` is 1. +/// For each row, if the `filter_column` evaluates to 1, then the row is selected. All the column linear combinations are evaluated at that row. All those evaluations are combined using the challenge to get a value `v`. +/// The product is updated: `p *= v`, and is pushed to the vector of partial products. fn partial_products( trace: &[PolynomialValues], columns: &[Column], @@ -529,6 +614,7 @@ fn partial_products( res.into() } +/// Data necessary to check the cross-table lookups of a given table. #[derive(Clone)] pub struct CtlCheckVars<'a, F, FE, P, const D2: usize> where @@ -536,22 +622,29 @@ where FE: FieldExtension, P: PackedField, { + /// Evaluation of the `Z` polynomial at the point `zeta`. pub(crate) local_z: P, + /// Evaluation of the `Z` polynomial at the point `g * zeta`. pub(crate) next_z: P, + /// Cross-table lookup challenge. pub(crate) challenges: GrandProductChallenge, + /// Column linear combinations of the `CrossTableLookup`s. pub(crate) columns: &'a [Column], + /// Column linear combination that evaluates to either 1 or 0. pub(crate) filter_column: &'a Option>, } impl<'a, F: RichField + Extendable, const D: usize> CtlCheckVars<'a, F, F::Extension, F::Extension, D> { + /// Extracts the `CtlCheckVars` for each STARK. pub(crate) fn from_proofs>( proofs: &[StarkProofWithMetadata; NUM_TABLES], cross_table_lookups: &'a [CrossTableLookup], ctl_challenges: &'a GrandProductChallengeSet, num_lookup_columns: &[usize; NUM_TABLES], ) -> [Vec; NUM_TABLES] { + // Get all cross-table lookup polynomial openings for each STARK proof. let mut ctl_zs = proofs .iter() .zip(num_lookup_columns) @@ -563,6 +656,7 @@ impl<'a, F: RichField + Extendable, const D: usize> }) .collect::>(); + // Put each cross-table lookup polynomial into the correct table data: if a CTL polynomial is extracted from looking/looked table t, then we add it to the `CtlCheckVars` of table t.
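The running-product update documented for `partial_products` above can be sketched as follows (not part of the patch; toy modular arithmetic stands in for the STARK field, and a single challenge `gamma` stands in for the `GrandProductChallenge`; the real implementation also walks the rows so that the complete product lands on the first row):

    const P: u128 = (1 << 61) - 1; // toy Mersenne prime standing in for the STARK field

    // v_0 + gamma * v_1 + gamma^2 * v_2 + ... (Horner evaluation).
    fn combine(evals: &[u128], gamma: u128) -> u128 {
        evals.iter().rev().fold(0, |acc, &e| (acc * gamma + e) % P)
    }

    // One partial product is pushed per row; unselected rows leave it unchanged.
    fn partial_products(rows: &[Vec<u128>], filter: &[bool], gamma: u128) -> Vec<u128> {
        let mut p = 1u128;
        rows.iter()
            .zip(filter)
            .map(|(row, &selected)| {
                if selected {
                    p = p * combine(row, gamma) % P;
                }
                p
            })
            .collect()
    }

    fn main() {
        let rows = vec![vec![2, 3], vec![5, 7], vec![11, 13]];
        let pps = partial_products(&rows, &[true, false, true], 10);
        assert_eq!(pps[0], 32);           // combine([2, 3], 10) = 2 + 30
        assert_eq!(pps[1], 32);           // row filtered out: product unchanged
        assert_eq!(pps[2], 32 * 141 % P); // combine([11, 13], 10) = 11 + 130
    }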
let mut ctl_vars_per_table = [0; NUM_TABLES].map(|_| vec![]); for CrossTableLookup { looking_tables, @@ -595,7 +689,10 @@ impl<'a, F: RichField + Extendable, const D: usize> } } -/// CTL Z partial products are upside down: the complete product is on the first row, and +/// Checks the cross-table lookup Z polynomials for each table: +/// - Checks that the CTL `Z` partial products are correctly updated. +/// - Checks that the final value of the CTL product is the combination of all STARKs' CTL polynomials. +/// CTL `Z` partial products are upside down: the complete product is on the first row, and /// the first term is on the last row. This allows the transition constraint to be: /// Z(w) = Z(gw) * combine(w) where combine is called on the local row /// and not the next. This enables CTLs across two rows. @@ -621,6 +718,7 @@ pub(crate) fn eval_cross_table_lookup_checks { + /// Evaluation of the `Z` polynomial at the point `zeta`. pub(crate) local_z: ExtensionTarget, + /// Evaluation of the `Z` polynomial at the point `g * zeta`. pub(crate) next_z: ExtensionTarget, + /// Cross-table lookup challenge. pub(crate) challenges: GrandProductChallenge, + /// Column linear combinations of the `CrossTableLookup`s. pub(crate) columns: &'a [Column], + /// Column linear combination that evaluates to either 1 or 0. pub(crate) filter_column: &'a Option>, } impl<'a, F: Field, const D: usize> CtlCheckVarsTarget<'a, F, D> { + /// Circuit version of `from_proofs`. Extracts the `CtlCheckVarsTarget` for a given table. pub(crate) fn from_proof( table: Table, proof: &StarkProofTarget, @@ -657,6 +763,7 @@ impl<'a, F: Field, const D: usize> CtlCheckVarsTarget<'a, F, D> { ctl_challenges: &'a GrandProductChallengeSet, num_lookup_columns: usize, ) -> Vec { + // Get all cross-table lookup polynomial openings for each STARK proof. let mut ctl_zs = { let openings = &proof.openings; let ctl_zs = openings.auxiliary_polys.iter().skip(num_lookup_columns); @@ -667,6 +774,7 @@ impl<'a, F: Field, const D: usize> CtlCheckVarsTarget<'a, F, D> { ctl_zs.zip(ctl_zs_next) }; + // Put each cross-table lookup polynomial into the correct table data: if a CTL polynomial is extracted from looking/looked table t, then we add it to the `CtlCheckVarsTarget` of table t. let mut ctl_vars = vec![]; for CrossTableLookup { looking_tables, @@ -704,6 +812,13 @@ impl<'a, F: Field, const D: usize> CtlCheckVarsTarget<'a, F, D> { } } +/// Circuit version of `eval_cross_table_lookup_checks`. Checks the cross-table lookups for each table: +/// - Checks that the CTL `Z` partial products are correctly updated. +/// - Checks that the final value of the CTL product is the combination of all STARKs' CTL polynomials. +/// CTL `Z` partial products are upside down: the complete product is on the first row, and +/// the first term is on the last row. This allows the transition constraint to be: +/// Z(w) = Z(gw) * combine(w) where combine is called on the local row +/// and not the next. This enables CTLs across two rows. pub(crate) fn eval_cross_table_lookup_checks_circuit< S: Stark, F: RichField + Extendable, @@ -742,12 +857,14 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit< builder.mul_add_extension(filter, x, tmp) // filter * x + 1 - filter } + // Compute all linear combinations on the current table, and combine them using the challenge.
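The `select` helper above computes `filter * x + 1 - filter`; a plain-arithmetic sketch (not part of the patch) shows why unselected rows contribute a neutral factor of 1 to the running product:

    // filter is constrained to be 0 or 1, so no underflow occurs here.
    fn select(filter: u64, combined: u64) -> u64 {
        filter * combined + 1 - filter
    }

    fn main() {
        assert_eq!(select(1, 42), 42); // row selected: contribute the combined value
        assert_eq!(select(0, 42), 1);  // row not selected: contribute 1, leaving Z unchanged
    }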
let evals = columns .iter() .map(|c| c.eval_with_next_circuit(builder, local_values, next_values)) .collect::>(); let combined = challenges.combine_circuit(builder, &evals); + // If the filter evaluates to 1, then the previously computed combination is used. let select = select(builder, local_filter, combined); // Check value of `Z(g^(n-1))` @@ -759,6 +876,7 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit< } } +/// Verifies all cross-table lookups. pub(crate) fn verify_cross_table_lookups, const D: usize>( cross_table_lookups: &[CrossTableLookup], ctl_zs_first: [Vec; NUM_TABLES], @@ -774,15 +892,19 @@ pub(crate) fn verify_cross_table_lookups, const D: }, ) in cross_table_lookups.iter().enumerate() { + // Get elements looking into `looked_table` that are not associated to any STARK. let extra_product_vec = &ctl_extra_looking_products[looked_table.table as usize]; for c in 0..config.num_challenges { + // Compute the combination of all looking table CTL polynomial openings. let looking_zs_prod = looking_tables .iter() .map(|table| *ctl_zs_openings[table.table as usize].next().unwrap()) .product::() * extra_product_vec[c]; + // Get the looked table CTL polynomial opening. let looked_z = *ctl_zs_openings[looked_table.table as usize].next().unwrap(); + // Ensure that the combination of looking table openings is equal to the looked table opening. ensure!( looking_zs_prod == looked_z, "Cross-table lookup {:?} verification failed.", @@ -795,6 +917,7 @@ pub(crate) fn verify_cross_table_lookups, const D: Ok(()) } +/// Circuit version of `verify_cross_table_lookups`. Verifies all cross-table lookups. pub(crate) fn verify_cross_table_lookups_circuit, const D: usize>( builder: &mut CircuitBuilder, cross_table_lookups: Vec>, @@ -808,8 +931,10 @@ pub(crate) fn verify_cross_table_lookups_circuit, c looked_table, } in cross_table_lookups.into_iter() { + // Get elements looking into `looked_table` that are not associated to any STARK. let extra_product_vec = &ctl_extra_looking_products[looked_table.table as usize]; for c in 0..inner_config.num_challenges { + // Compute the combination of all looking table CTL polynomial openings. let mut looking_zs_prod = builder.mul_many( looking_tables .iter() @@ -818,7 +943,9 @@ pub(crate) fn verify_cross_table_lookups_circuit, c looking_zs_prod = builder.mul(looking_zs_prod, extra_product_vec[c]); + // Get the looked table CTL polynomial opening. let looked_z = *ctl_zs_openings[looked_table.table as usize].next().unwrap(); + // Verify that the combination of looking table openings is equal to the looked table opening. builder.connect(looked_z, looking_zs_prod); } } diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 62182cd254..c4d19dca01 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -9,6 +9,7 @@ use plonky2::hash::hash_types::RichField; use plonky2::timed; use plonky2::util::timing::TimingTree; use serde::{Deserialize, Serialize}; +use smt_utils::smt::hash_serialize_state; use GlobalMetadata::{ ReceiptTrieRootDigestAfter, ReceiptTrieRootDigestBefore, StateTrieRootDigestAfter, StateTrieRootDigestBefore, TransactionTrieRootDigestAfter, TransactionTrieRootDigestBefore, @@ -69,11 +70,11 @@ pub struct GenerationInputs { pub addresses: Vec

, } -#[derive(Clone, Debug, Deserialize, Serialize, Default)] +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct TrieInputs { - /// A partial version of the state trie prior to these transactions. It should include all nodes + /// A serialized partial version of the state SMT prior to these transactions. It should include all nodes /// that will be accessed by these transactions. - pub state_trie: HashedPartialTrie, + pub state_smt: Vec, /// A partial version of the transaction trie prior to these transactions. It should include all /// nodes that will be accessed by these transactions. @@ -88,6 +89,18 @@ pub struct TrieInputs { pub storage_tries: Vec<(H256, HashedPartialTrie)>, } +impl Default for TrieInputs { + fn default() -> Self { + Self { + // First 2 zeros are for the default empty node. The next 2 are for the current empty state trie root. + state_smt: vec![U256::zero(); 4], + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + } + } +} + fn apply_metadata_and_tries_memops, const D: usize>( state: &mut GenerationState, inputs: &GenerationInputs, @@ -124,7 +137,7 @@ fn apply_metadata_and_tries_memops, const D: usize> ), ( GlobalMetadata::StateTrieRootDigestBefore, - h2u(tries.state_trie.hash()), + h2u(hash_serialize_state(&tries.state_smt)), ), ( GlobalMetadata::TransactionTrieRootDigestBefore, diff --git a/evm/src/generation/mpt.rs b/evm/src/generation/mpt.rs index 20e8b30b60..726c20225b 100644 --- a/evm/src/generation/mpt.rs +++ b/evm/src/generation/mpt.rs @@ -1,17 +1,15 @@ -use std::collections::HashMap; use std::ops::Deref; use bytes::Bytes; -use eth_trie_utils::nibbles::Nibbles; use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; -use ethereum_types::{Address, BigEndianHash, H256, U256, U512}; +use ethereum_types::{Address, H256, U256}; use keccak_hash::keccak; use rlp::{Decodable, DecoderError, Encodable, PayloadInfo, Rlp, RlpStream}; use rlp_derive::{RlpDecodable, RlpEncodable}; use crate::cpu::kernel::constants::trie_type::PartialTrieType; use crate::generation::TrieInputs; -use crate::witness::errors::{ProgramError, ProverInputError}; +use crate::witness::errors::ProgramError; use crate::Node; #[derive(RlpEncodable, RlpDecodable, Debug)] @@ -48,6 +46,12 @@ pub struct LegacyReceiptRlp { pub logs: Vec, } +pub(crate) fn state_smt_prover_inputs_reversed(trie_inputs: &TrieInputs) -> Vec { + let mut inputs = state_smt_prover_inputs(trie_inputs); + inputs.reverse(); + inputs +} + pub(crate) fn all_mpt_prover_inputs_reversed( trie_inputs: &TrieInputs, ) -> Result, ProgramError> { @@ -86,26 +90,18 @@ pub(crate) fn parse_receipts(rlp: &[u8]) -> Result, ProgramError> { Ok(parsed_receipt) } + +pub(crate) fn state_smt_prover_inputs(trie_inputs: &TrieInputs) -> Vec { + let len = trie_inputs.state_smt.len(); + let mut v = vec![len.into()]; + v.extend(trie_inputs.state_smt.iter()); + v +} + /// Generate prover inputs for the initial MPT data, in the format expected by `mpt/load.asm`. 
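The layout produced by `state_smt_prover_inputs` above, and consumed via `pop()` after reversal, can be sketched as follows (not part of the patch; plain `u64`s stand in for `U256`):

    // Length-prefixed serialization of the state SMT, mirroring the code above.
    fn smt_prover_inputs(state_smt: &[u64]) -> Vec<u64> {
        let mut v = vec![state_smt.len() as u64]; // the length prefix is consumed first
        v.extend_from_slice(state_smt);
        v
    }

    fn main() {
        let mut inputs = smt_prover_inputs(&[7, 8, 9]);
        inputs.reverse(); // reversed so the kernel can `pop()` inputs in order
        assert_eq!(inputs.pop(), Some(3)); // first pop yields the length
        assert_eq!(inputs.pop(), Some(7)); // then the serialized nodes in order
    }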
pub(crate) fn all_mpt_prover_inputs(trie_inputs: &TrieInputs) -> Result, ProgramError> { let mut prover_inputs = vec![]; - let storage_tries_by_state_key = trie_inputs - .storage_tries - .iter() - .map(|(hashed_address, storage_trie)| { - let key = Nibbles::from_bytes_be(hashed_address.as_bytes()).unwrap(); - (key, storage_trie) - }) - .collect(); - - mpt_prover_inputs_state_trie( - &trie_inputs.state_trie, - empty_nibbles(), - &mut prover_inputs, - &storage_tries_by_state_key, - )?; - mpt_prover_inputs(&trie_inputs.transactions_trie, &mut prover_inputs, &|rlp| { let mut parsed_txn = vec![U256::from(rlp.len())]; parsed_txn.extend(rlp.iter().copied().map(U256::from)); @@ -179,106 +175,6 @@ where } } -/// Like `mpt_prover_inputs`, but for the state trie, which is a bit unique since each value -/// leads to a storage trie which we recursively traverse. -pub(crate) fn mpt_prover_inputs_state_trie( - trie: &HashedPartialTrie, - key: Nibbles, - prover_inputs: &mut Vec, - storage_tries_by_state_key: &HashMap, -) -> Result<(), ProgramError> { - prover_inputs.push((PartialTrieType::of(trie) as u32).into()); - match trie.deref() { - Node::Empty => Ok(()), - Node::Hash(h) => { - prover_inputs.push(U256::from_big_endian(h.as_bytes())); - Ok(()) - } - Node::Branch { children, value } => { - if !value.is_empty() { - return Err(ProgramError::ProverInputError( - ProverInputError::InvalidMptInput, - )); - } - prover_inputs.push(U256::zero()); // value_present = 0 - - for (i, child) in children.iter().enumerate() { - let extended_key = key.merge_nibbles(&Nibbles { - count: 1, - packed: i.into(), - }); - mpt_prover_inputs_state_trie( - child, - extended_key, - prover_inputs, - storage_tries_by_state_key, - )?; - } - - Ok(()) - } - Node::Extension { nibbles, child } => { - prover_inputs.push(nibbles.count.into()); - prover_inputs.push( - nibbles - .try_into_u256() - .map_err(|_| ProgramError::IntegerTooLarge)?, - ); - let extended_key = key.merge_nibbles(nibbles); - mpt_prover_inputs_state_trie( - child, - extended_key, - prover_inputs, - storage_tries_by_state_key, - ) - } - Node::Leaf { nibbles, value } => { - let account: AccountRlp = rlp::decode(value).map_err(|_| ProgramError::InvalidRlp)?; - let AccountRlp { - nonce, - balance, - storage_root, - code_hash, - } = account; - - let storage_hash_only = HashedPartialTrie::new(Node::Hash(storage_root)); - let merged_key = key.merge_nibbles(nibbles); - let storage_trie: &HashedPartialTrie = storage_tries_by_state_key - .get(&merged_key) - .copied() - .unwrap_or(&storage_hash_only); - - assert_eq!(storage_trie.hash(), storage_root, - "In TrieInputs, an account's storage_root didn't match the associated storage trie hash"); - - prover_inputs.push(nibbles.count.into()); - prover_inputs.push( - nibbles - .try_into_u256() - .map_err(|_| ProgramError::IntegerTooLarge)?, - ); - prover_inputs.push(nonce); - prover_inputs.push(balance); - mpt_prover_inputs(storage_trie, prover_inputs, &parse_storage_value)?; - prover_inputs.push(code_hash.into_uint()); - - Ok(()) - } - } -} - -fn parse_storage_value(value_rlp: &[u8]) -> Result, ProgramError> { - let value: U256 = rlp::decode(value_rlp).map_err(|_| ProgramError::InvalidRlp)?; - Ok(vec![value]) -} - -fn empty_nibbles() -> Nibbles { - Nibbles { - count: 0, - packed: U512::zero(), - } -} - pub mod transaction_testing { use super::*; diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 205dff7c66..cf618d1717 100644 --- a/evm/src/generation/prover_input.rs +++ 
b/evm/src/generation/prover_input.rs @@ -40,6 +40,7 @@ impl GenerationState { "sf" => self.run_sf(input_fn), "ffe" => self.run_ffe(input_fn), "mpt" => self.run_mpt(), + "smt" => self.run_smt(input_fn), "rlp" => self.run_rlp(), "current_hash" => self.run_current_hash(), "account_code" => self.run_account_code(input_fn), @@ -120,6 +121,19 @@ impl GenerationState { .ok_or(ProgramError::ProverInputError(OutOfMptData)) } + /// SMT data. + fn run_smt(&mut self, input_fn: &ProverInputFn) -> Result { + match input_fn.0[1].as_str() { + "state" => self + .state_smt_prover_inputs + .pop() + .ok_or(ProgramError::ProverInputError(OutOfSmtData)), + "transactions" => todo!(), + "receipts" => todo!(), + _ => panic!("Invalid SMT"), + } + } + /// RLP data. fn run_rlp(&mut self) -> Result { self.rlp_prover_inputs diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index aec01e1b71..79e0cfdb85 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -6,7 +6,7 @@ use plonky2::field::types::Field; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; -use crate::generation::mpt::all_mpt_prover_inputs_reversed; +use crate::generation::mpt::{all_mpt_prover_inputs_reversed, state_smt_prover_inputs_reversed}; use crate::generation::rlp::all_rlp_prover_inputs_reversed; use crate::generation::GenerationInputs; use crate::memory::segments::Segment; @@ -35,6 +35,10 @@ pub(crate) struct GenerationState { /// via `pop()`. pub(crate) mpt_prover_inputs: Vec, + /// Prover inputs containing SMT data, in reverse order so that the next input can be obtained + /// via `pop()`. + pub(crate) state_smt_prover_inputs: Vec, + /// Prover inputs containing RLP data, in reverse order so that the next input can be obtained /// via `pop()`. pub(crate) rlp_prover_inputs: Vec, @@ -53,7 +57,7 @@ pub(crate) struct GenerationState { impl GenerationState { pub(crate) fn new(inputs: GenerationInputs, kernel_code: &[u8]) -> Result { log::debug!("Input signed_txns: {:?}", &inputs.signed_txns); - log::debug!("Input state_trie: {:?}", &inputs.tries.state_trie); + log::debug!("Input state_smt: {:?}", &inputs.tries.state_smt); log::debug!( "Input transactions_trie: {:?}", &inputs.tries.transactions_trie ); @@ -61,6 +65,7 @@ impl GenerationState { log::debug!("Input receipts_trie: {:?}", &inputs.tries.receipts_trie); log::debug!("Input storage_tries: {:?}", &inputs.tries.storage_tries); log::debug!("Input contract_code: {:?}", &inputs.contract_code); + let state_smt_prover_inputs = state_smt_prover_inputs_reversed(&inputs.tries); let mpt_prover_inputs = all_mpt_prover_inputs_reversed(&inputs.tries)?; let rlp_prover_inputs = all_rlp_prover_inputs_reversed(&inputs.signed_txns); let bignum_modmul_result_limbs = Vec::new(); @@ -72,6 +77,7 @@ impl GenerationState { traces: Traces::default(), next_txn_index: 0, mpt_prover_inputs, + state_smt_prover_inputs, rlp_prover_inputs, state_key_to_address: HashMap::new(), bignum_modmul_result_limbs, diff --git a/evm/src/keccak/keccak_stark.rs b/evm/src/keccak/keccak_stark.rs index 2745d03302..19524e2d8b 100644 --- a/evm/src/keccak/keccak_stark.rs +++ b/evm/src/keccak/keccak_stark.rs @@ -33,22 +33,26 @@ pub(crate) const NUM_ROUNDS: usize = 24; /// Number of 64-bit elements in the Keccak permutation input. pub(crate) const NUM_INPUTS: usize = 25; +/// Creates the vector of `Columns` corresponding to the permutation input limbs.
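The dispatch in `run_smt` above assumes prover-input functions of the form `smt::state`, stored as their `::`-separated segments; a standalone sketch (not part of the patch; the `ProverInputFn` shape is assumed for illustration):

    struct ProverInputFn(Vec<String>); // assumed shape: the segments of e.g. "smt::state"

    fn run_smt(f: &ProverInputFn) -> Result<&'static str, String> {
        match f.0[1].as_str() {
            "state" => Ok("pop the next state SMT word"),
            "transactions" | "receipts" => Err("unimplemented".to_string()),
            other => Err(format!("invalid SMT input: {other}")),
        }
    }

    fn main() {
        let f = ProverInputFn(vec!["smt".to_string(), "state".to_string()]);
        assert!(run_smt(&f).is_ok());
    }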
pub fn ctl_data_inputs() -> Vec> { let mut res: Vec<_> = (0..2 * NUM_INPUTS).map(reg_input_limb).collect(); res.push(Column::single(TIMESTAMP)); res } +/// Creates the vector of `Columns` corresponding to the permutation output limbs. pub fn ctl_data_outputs() -> Vec> { let mut res: Vec<_> = Column::singles((0..2 * NUM_INPUTS).map(reg_output_limb)).collect(); res.push(Column::single(TIMESTAMP)); res } +/// CTL filter for the first round of the Keccak permutation. pub fn ctl_filter_inputs() -> Column { Column::single(reg_step(0)) } +/// CTL filter for the final round of the Keccak permutation. pub fn ctl_filter_outputs() -> Column { Column::single(reg_step(NUM_ROUNDS - 1)) } diff --git a/evm/src/keccak_sponge/columns.rs b/evm/src/keccak_sponge/columns.rs index 431c09e092..f10dfbfd9a 100644 --- a/evm/src/keccak_sponge/columns.rs +++ b/evm/src/keccak_sponge/columns.rs @@ -3,17 +3,27 @@ use std::mem::{size_of, transmute}; use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; +/// Total number of sponge bytes: number of rate bytes + number of capacity bytes. pub(crate) const KECCAK_WIDTH_BYTES: usize = 200; +/// Total number of 32-bit limbs in the sponge. pub(crate) const KECCAK_WIDTH_U32S: usize = KECCAK_WIDTH_BYTES / 4; +/// Number of 32-bit limbs not belonging to the digest. pub(crate) const KECCAK_WIDTH_MINUS_DIGEST_U32S: usize = (KECCAK_WIDTH_BYTES - KECCAK_DIGEST_BYTES) / 4; +/// Number of rate bytes. pub(crate) const KECCAK_RATE_BYTES: usize = 136; +/// Number of 32-bit rate limbs. pub(crate) const KECCAK_RATE_U32S: usize = KECCAK_RATE_BYTES / 4; +/// Number of capacity bytes. pub(crate) const KECCAK_CAPACITY_BYTES: usize = 64; +/// Number of 32-bit capacity limbs. pub(crate) const KECCAK_CAPACITY_U32S: usize = KECCAK_CAPACITY_BYTES / 4; +/// Number of output digest bytes used during the squeezing phase. pub(crate) const KECCAK_DIGEST_BYTES: usize = 32; +/// Number of 32-bit digest limbs. pub(crate) const KECCAK_DIGEST_U32S: usize = KECCAK_DIGEST_BYTES / 4; +/// A view of `KeccakSpongeStark`'s columns. #[repr(C)] #[derive(Eq, PartialEq, Debug)] pub(crate) struct KeccakSpongeColumnsView { @@ -21,9 +31,11 @@ pub(crate) struct KeccakSpongeColumnsView { /// not a padding byte; 0 otherwise. pub is_full_input_block: T, - // The base address at which we will read the input block. + /// The context of the base address at which we will read the input block. pub context: T, + /// The segment of the base address at which we will read the input block. pub segment: T, + /// The virtual address at which we will read the input block. pub virt: T, /// The timestamp at which inputs should be read from memory. @@ -66,6 +78,7 @@ pub(crate) struct KeccakSpongeColumnsView { } // `u8` is guaranteed to have a `size_of` of 1. +/// Number of columns in `KeccakSpongeStark`.
pub const NUM_KECCAK_SPONGE_COLUMNS: usize = size_of::>(); impl From<[T; NUM_KECCAK_SPONGE_COLUMNS]> for KeccakSpongeColumnsView { @@ -117,4 +130,5 @@ const fn make_col_map() -> KeccakSpongeColumnsView { } } +/// Map between the `KeccakSponge` columns and (0..`NUM_KECCAK_SPONGE_COLUMNS`). pub(crate) const KECCAK_SPONGE_COL_MAP: KeccakSpongeColumnsView = make_col_map(); diff --git a/evm/src/keccak_sponge/keccak_sponge_stark.rs b/evm/src/keccak_sponge/keccak_sponge_stark.rs index e491252ba8..03c1f811a0 100644 --- a/evm/src/keccak_sponge/keccak_sponge_stark.rs +++ b/evm/src/keccak_sponge/keccak_sponge_stark.rs @@ -23,6 +23,11 @@ use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; use crate::witness::memory::MemoryAddress; +/// Creates the vector of `Columns` corresponding to: +/// - the address in memory of the inputs, +/// - the length of the inputs, +/// - the timestamp at which the inputs are read from memory, +/// - the output limbs of the Keccak sponge. pub(crate) fn ctl_looked_data() -> Vec> { let cols = KECCAK_SPONGE_COL_MAP; let mut outputs = Vec::with_capacity(8); @@ -47,6 +52,9 @@ pub(crate) fn ctl_looked_data() -> Vec> { .collect() } +/// Creates the vector of `Columns` corresponding to the inputs of the Keccak sponge. +/// This is used to check that the inputs of the sponge correspond to the inputs +/// given by `KeccakStark`. pub(crate) fn ctl_looking_keccak_inputs() -> Vec> { let cols = KECCAK_SPONGE_COL_MAP; let mut res: Vec<_> = Column::singles( @@ -62,6 +70,9 @@ pub(crate) fn ctl_looking_keccak_inputs() -> Vec> { res } +/// Creates the vector of `Columns` corresponding to the outputs of the Keccak sponge. +/// This is used to check that the outputs of the sponge correspond to the outputs +/// given by `KeccakStark`. pub(crate) fn ctl_looking_keccak_outputs() -> Vec> { let cols = KECCAK_SPONGE_COL_MAP; @@ -83,6 +94,7 @@ pub(crate) fn ctl_looking_keccak_outputs() -> Vec> { res } +/// Creates the vector of `Columns` corresponding to the address and value of the byte being read from memory. pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { let cols = KECCAK_SPONGE_COL_MAP; @@ -111,12 +123,16 @@ pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { res } +/// Returns the number of `KeccakSponge` tables looking into the `LogicStark`. pub(crate) fn num_logic_ctls() -> usize { const U8S_PER_CTL: usize = 32; ceil_div_usize(KECCAK_RATE_BYTES, U8S_PER_CTL) } -/// CTL for performing the `i`th logic CTL. Since we need to do 136 byte XORs, and the logic CTL can +/// Creates the vector of `Columns` required to perform the `i`th logic CTL. +/// It comprises the `IS_XOR` flag, the two inputs and the output +/// of the XOR operation. +/// Since we need to do 136 byte XORs, and the logic CTL can /// XOR 32 bytes per CTL, there are 5 such CTLs. pub(crate) fn ctl_looking_logic(i: usize) -> Vec> { const U32S_PER_CTL: usize = 8; @@ -156,6 +172,7 @@ pub(crate) fn ctl_looking_logic(i: usize) -> Vec> { res } +/// CTL filter for the final block rows of the `KeccakSponge` table. pub(crate) fn ctl_looked_filter() -> Column { // The CPU table is only interested in our final-block rows, since those contain the final // sponge output. @@ -181,6 +198,7 @@ pub(crate) fn ctl_looking_logic_filter() -> Column { Column::sum(once(&cols.is_full_input_block).chain(&cols.is_final_input_len)) } +/// CTL filter for looking at the input and output in the Keccak table.
pub(crate) fn ctl_looking_keccak_filter() -> Column { let cols = KECCAK_SPONGE_COL_MAP; Column::sum(once(&cols.is_full_input_block).chain(&cols.is_final_input_len)) @@ -199,12 +217,14 @@ pub(crate) struct KeccakSpongeOp { pub(crate) input: Vec, } +/// Structure representing the `KeccakSponge` STARK, which handles the Keccak sponge construction. #[derive(Copy, Clone, Default)] pub struct KeccakSpongeStark { f: PhantomData, } impl, const D: usize> KeccakSpongeStark { + /// Generates the trace polynomial values for the `KeccakSponge` STARK. pub(crate) fn generate_trace( &self, operations: Vec, @@ -227,6 +247,8 @@ impl, const D: usize> KeccakSpongeStark { trace_polys } + /// Generates the trace rows given the vector of `KeccakSponge` operations. + /// The trace is padded to a power of two with all-zero rows. fn generate_trace_rows( &self, operations: Vec, @@ -237,9 +259,11 @@ impl, const D: usize> KeccakSpongeStark { .map(|op| op.input.len() / KECCAK_RATE_BYTES + 1) .sum(); let mut rows = Vec::with_capacity(base_len.max(min_rows).next_power_of_two()); + // Generate active rows. for op in operations { rows.extend(self.generate_rows_for_op(op)); } + // Pad the trace. let padded_rows = rows.len().max(min_rows).next_power_of_two(); for _ in rows.len()..padded_rows { rows.push(self.generate_padding_row()); @@ -247,6 +271,9 @@ impl, const D: usize> KeccakSpongeStark { rows } + /// Generates the rows associated with a given operation: + /// Performs a Keccak sponge permutation and fills the STARK's rows accordingly. + /// The number of rows is the number of full input chunks of size `KECCAK_RATE_BYTES`, plus one final row. fn generate_rows_for_op(&self, op: KeccakSpongeOp) -> Vec<[F; NUM_KECCAK_SPONGE_COLUMNS]> { let mut rows = Vec::with_capacity(op.input.len() / KECCAK_RATE_BYTES + 1); @@ -255,6 +282,7 @@ impl, const D: usize> KeccakSpongeStark { let mut input_blocks = op.input.chunks_exact(KECCAK_RATE_BYTES); let mut already_absorbed_bytes = 0; for block in input_blocks.by_ref() { + // We compute the updated state of the sponge. let row = self.generate_full_input_row( &op, already_absorbed_bytes, @@ -262,6 +290,9 @@ impl, const D: usize> KeccakSpongeStark { block.try_into().unwrap(), ); + // We update the state limbs for the next block absorption. + // The first `KECCAK_DIGEST_U32S` limbs are stored as bytes after the computation, + // so we recompute the corresponding `u32` and update the first state limbs. sponge_state[..KECCAK_DIGEST_U32S] .iter_mut() .zip(row.updated_digest_state_bytes.chunks_exact(4)) @@ -273,6 +304,8 @@ impl, const D: usize> KeccakSpongeStark { .sum(); }); + // The rest of the bytes are already stored in the expected form, so we can directly + // update the state with the stored values. sponge_state[KECCAK_DIGEST_U32S..] .iter_mut() .zip(row.partial_updated_state_u32s) @@ -295,6 +328,8 @@ impl, const D: usize> KeccakSpongeStark { rows } + /// Generates a row where all bytes are input bytes, not padding bytes. + /// This includes updating the sponge state with a single absorption. fn generate_full_input_row( &self, op: &KeccakSpongeOp, @@ -313,6 +348,10 @@ impl, const D: usize> KeccakSpongeStark { row } + /// Generates a row containing the last input bytes. + /// On top of computing one absorption and padding the input, + /// we indicate the last non-padding input byte by setting + /// `row.is_final_input_len[final_inputs.len()]` to 1.
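The row count described above (`op.input.len() / KECCAK_RATE_BYTES + 1`) always reserves one final row for padding, even when the input length is a multiple of the rate; a small sketch (not part of the patch):

    const KECCAK_RATE_BYTES: usize = 136;

    // Full 136-byte chunks each get a full-input row; the remainder (possibly
    // empty) plus padding goes into one final row.
    fn num_rows_for_input(input_len: usize) -> usize {
        input_len / KECCAK_RATE_BYTES + 1
    }

    fn main() {
        assert_eq!(num_rows_for_input(0), 1);   // a single, all-padding final row
        assert_eq!(num_rows_for_input(136), 2); // one full block + a final padding row
        assert_eq!(num_rows_for_input(200), 2); // one full block + a 64-byte final row
    }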
fn generate_final_row( &self, op: &KeccakSpongeOp, @@ -345,6 +384,9 @@ impl, const D: usize> KeccakSpongeStark { /// Generate fields that are common to both full-input-block rows and final-block rows. /// Also updates the sponge state with a single absorption. + /// Given a state S = R || C and a block input B, + /// - R is updated with R XOR B, + /// - S is replaced by keccakf_u32s(S). fn generate_common_fields( row: &mut KeccakSpongeColumnsView, op: &KeccakSpongeOp, diff --git a/evm/src/logic.rs b/evm/src/logic.rs index 319dfab2d0..2382897b1a 100644 --- a/evm/src/logic.rs +++ b/evm/src/logic.rs @@ -19,28 +19,34 @@ use crate::logic::columns::NUM_COLUMNS; use crate::stark::Stark; use crate::util::{limb_from_bits_le, limb_from_bits_le_recursive, trace_rows_to_poly_values}; -// Total number of bits per input/output. +/// Total number of bits per input/output. const VAL_BITS: usize = 256; -// Number of bits stored per field element. Ensure that this fits; it is not checked. +/// Number of bits stored per field element. Ensure that this fits; it is not checked. pub(crate) const PACKED_LIMB_BITS: usize = 32; -// Number of field elements needed to store each input/output at the specified packing. +/// Number of field elements needed to store each input/output at the specified packing. const PACKED_LEN: usize = ceil_div_usize(VAL_BITS, PACKED_LIMB_BITS); +/// `LogicStark` columns. pub(crate) mod columns { use std::cmp::min; use std::ops::Range; use super::{PACKED_LEN, PACKED_LIMB_BITS, VAL_BITS}; + /// 1 if this is an AND operation, 0 otherwise. pub const IS_AND: usize = 0; + /// 1 if this is an OR operation, 0 otherwise. pub const IS_OR: usize = IS_AND + 1; + /// 1 if this is a XOR operation, 0 otherwise. pub const IS_XOR: usize = IS_OR + 1; - // The inputs are decomposed into bits. + /// First input, decomposed into bits. pub const INPUT0: Range = (IS_XOR + 1)..(IS_XOR + 1) + VAL_BITS; + /// Second input, decomposed into bits. pub const INPUT1: Range = INPUT0.end..INPUT0.end + VAL_BITS; - // The result is packed in limbs of `PACKED_LIMB_BITS` bits. + /// The result is packed in limbs of `PACKED_LIMB_BITS` bits. pub const RESULT: Range = INPUT1.end..INPUT1.end + PACKED_LEN; + /// Returns the column range for each 32-bit chunk in the input. pub fn limb_bit_cols_for_input(input_bits: Range) -> impl Iterator> { (0..PACKED_LEN).map(move |i| { let start = input_bits.start + i * PACKED_LIMB_BITS; @@ -49,9 +55,11 @@ pub(crate) mod columns { }) } + /// Number of columns in `LogicStark`. pub const NUM_COLUMNS: usize = RESULT.end; } +/// Creates the vector of `Columns` corresponding to the opcode, the two inputs and the output of the logic operation. pub fn ctl_data() -> Vec> { // We scale each filter flag with the associated opcode value. // If a logic operation is happening on the CPU side, the CTL @@ -68,15 +76,18 @@ pub fn ctl_data() -> Vec> { res } +/// CTL filter for logic operations. pub fn ctl_filter() -> Column { Column::sum([columns::IS_AND, columns::IS_OR, columns::IS_XOR]) } +/// Structure representing the Logic STARK, which computes all logic operations. #[derive(Copy, Clone, Default)] pub struct LogicStark { pub f: PhantomData, } +/// Logic operations. #[derive(Copy, Clone, Debug, Eq, PartialEq)] pub(crate) enum Op { And, @@ -85,6 +96,7 @@ impl Op { + /// Returns the output of the current logic operation. pub(crate) fn result(&self, a: U256, b: U256) -> U256 { match self { Op::And => a & b, @@ -94,6 +106,8 @@ impl Op { } } +/// A logic operation over `U256` words.
It contains an operator, +/// either `AND`, `OR` or `XOR`, two inputs and the expected result. #[derive(Debug)] pub(crate) struct Operation { operator: Op, @@ -103,6 +117,8 @@ pub(crate) struct Operation { } impl Operation { + /// Computes the expected result of an operator with the two provided inputs, + /// and returns the associated logic `Operation`. pub(crate) fn new(operator: Op, input0: U256, input1: U256) -> Self { let result = operator.result(input0, input1); Operation { @@ -113,6 +129,7 @@ impl Operation { } } + /// Given an `Operation`, fills a row with the corresponding flag, inputs and output. fn into_row(self) -> [F; NUM_COLUMNS] { let Operation { operator, @@ -140,17 +157,20 @@ impl Operation { } impl LogicStark { + /// Generates the trace polynomials for `LogicStark`. pub(crate) fn generate_trace( &self, operations: Vec, min_rows: usize, timing: &mut TimingTree, ) -> Vec> { + // First, turn all provided operations into rows in `LogicStark`, and pad if necessary. let trace_rows = timed!( timing, "generate trace rows", self.generate_trace_rows(operations, min_rows) ); + // Generate the trace polynomials from the trace values. let trace_polys = timed!( timing, "convert to PolynomialValues", @@ -159,6 +179,8 @@ impl LogicStark { trace_polys } + /// Generates the `LogicStark` trace rows from the provided vector of operations. + /// The trace is padded to a power of two with all-zero rows. fn generate_trace_rows( &self, operations: Vec, diff --git a/evm/src/memory/columns.rs b/evm/src/memory/columns.rs index 9a41323200..b77a799cf8 100644 --- a/evm/src/memory/columns.rs +++ b/evm/src/memory/columns.rs @@ -5,10 +5,18 @@ use crate::memory::VALUE_LIMBS; // Columns for memory operations, ordered by (addr, timestamp). /// 1 if this is an actual memory operation, or 0 if it's a padding row. pub(crate) const FILTER: usize = 0; +/// Each memory operation is associated with a unique timestamp. +/// For a given memory operation `op_i`, its timestamp is computed as `C * N + i` +/// where `C` is the CPU clock at that time, `N` is the number of general memory channels, +/// and `i` is the index of the memory channel at which the memory operation is performed. pub(crate) const TIMESTAMP: usize = FILTER + 1; +/// 1 if this is a read operation, 0 if it is a write. pub(crate) const IS_READ: usize = TIMESTAMP + 1; +/// The execution context of this address. pub(crate) const ADDR_CONTEXT: usize = IS_READ + 1; +/// The segment of this address. pub(crate) const ADDR_SEGMENT: usize = ADDR_CONTEXT + 1; +/// The virtual address within the given context and segment. pub(crate) const ADDR_VIRTUAL: usize = ADDR_SEGMENT + 1; // Eight 32-bit limbs hold a total of 256 bits. diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs index 4a63f50a7a..74aa8787d7 100644 --- a/evm/src/memory/memory_stark.rs +++ b/evm/src/memory/memory_stark.rs @@ -27,6 +27,11 @@ use crate::stark::Stark; use crate::witness::memory::MemoryOpKind::Read; use crate::witness::memory::{MemoryAddress, MemoryOp}; +/// Creates the vector of `Columns` corresponding to: +/// - the memory operation type, +/// - the address in memory of the element being read/written, +/// - the value being read/written, +/// - the timestamp at which the element is read/written. pub fn ctl_data() -> Vec> { let mut res = Column::singles([IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]).collect_vec(); @@ -35,6 +40,7 @@ pub fn ctl_data() -> Vec> { res } +/// CTL filter for memory operations.
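The timestamp formula `C * N + i` documented above guarantees that timestamps increase with the CPU clock and, within one cycle, with the channel index; a sketch (not part of the patch; the channel count is an assumed value for illustration):

    const NUM_CHANNELS: usize = 8; // assumed channel count, for illustration only

    fn timestamp(clock: usize, channel: usize) -> usize {
        clock * NUM_CHANNELS + channel
    }

    fn main() {
        // Two operations in the same CPU cycle get distinct, increasing timestamps.
        assert!(timestamp(5, 0) < timestamp(5, 1));
        // Any operation in a later cycle comes after all operations of this cycle.
        assert!(timestamp(5, NUM_CHANNELS - 1) < timestamp(6, 0));
    }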
pub fn ctl_filter() -> Column { Column::single(FILTER) } diff --git a/evm/src/memory/mod.rs b/evm/src/memory/mod.rs index 4cdfd1be5a..c61119530f 100644 --- a/evm/src/memory/mod.rs +++ b/evm/src/memory/mod.rs @@ -1,7 +1,13 @@ +//! The Memory STARK is used to handle all memory read and write operations happening when +//! executing the EVM. Each non-dummy row of the table corresponds to a single operation, +//! and rows are ordered by the timestamp associated with each memory operation. + pub mod columns; pub mod memory_stark; pub mod segments; // TODO: Move to CPU module, now that channels have been removed from the memory table. pub(crate) const NUM_CHANNELS: usize = crate::cpu::membus::NUM_CHANNELS; +/// The number of limbs holding the value at a memory address. +/// Eight limbs of 32 bits can hold a `U256`. pub(crate) const VALUE_LIMBS: usize = 8; diff --git a/evm/src/proof.rs b/evm/src/proof.rs index 43561e8c8b..c8e7214e88 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -23,36 +23,52 @@ use crate::cross_table_lookup::GrandProductChallengeSet; /// A STARK proof for each table, plus some metadata used to create recursive wrapper proofs. #[derive(Debug, Clone)] pub struct AllProof, C: GenericConfig, const D: usize> { + /// Proofs for all the different STARK modules. pub stark_proofs: [StarkProofWithMetadata; NUM_TABLES], + /// Cross-table lookup challenges. pub(crate) ctl_challenges: GrandProductChallengeSet, + /// Public memory values used for the recursive proofs. pub public_values: PublicValues, } impl, C: GenericConfig, const D: usize> AllProof { + /// Returns the degree (i.e. the trace length) of each STARK. pub fn degree_bits(&self, config: &StarkConfig) -> [usize; NUM_TABLES] { core::array::from_fn(|i| self.stark_proofs[i].proof.recover_degree_bits(config)) } } +/// Randomness for all STARKs. pub(crate) struct AllProofChallenges, const D: usize> { + /// Randomness used in each STARK proof. pub stark_challenges: [StarkProofChallenges; NUM_TABLES], + /// Randomness used for cross-table lookups. It is shared by all STARKs. pub ctl_challenges: GrandProductChallengeSet, } /// Memory values which are public. #[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct PublicValues { + /// Trie hashes before the execution of the local state transition. pub trie_roots_before: TrieRoots, + /// Trie hashes after the execution of the local state transition. pub trie_roots_after: TrieRoots, + /// Block metadata: it remains unchanged within a block. pub block_metadata: BlockMetadata, + /// 256 previous block hashes and current block's hash. pub block_hashes: BlockHashes, + /// Extra block data that is specific to the current proof. pub extra_block_data: ExtraBlockData, } +/// Trie hashes. #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct TrieRoots { + /// State trie hash. pub state_root: H256, + /// Transaction trie hash. pub transactions_root: H256, + /// Receipts trie hash. pub receipts_root: H256, } @@ -141,14 +157,20 @@ pub struct ExtraBlockData { /// Note: All the larger integers are encoded with 32-bit limbs in little-endian order. #[derive(Eq, PartialEq, Debug)] pub struct PublicValuesTarget { + /// Trie hashes before the execution of the local state transition. pub trie_roots_before: TrieRootsTarget, + /// Trie hashes after the execution of the local state transition. pub trie_roots_after: TrieRootsTarget, + /// Block metadata: it remains unchanged within a block.
pub block_metadata: BlockMetadataTarget, + /// 256 previous block hashes and current block's hash. pub block_hashes: BlockHashesTarget, + /// Extra block data that is specific to the current proof. pub extra_block_data: ExtraBlockDataTarget, } impl PublicValuesTarget { + /// Serializes public value targets. pub fn to_buffer(&self, buffer: &mut Vec) -> IoResult<()> { let TrieRootsTarget { state_root: state_root_before, @@ -221,6 +243,7 @@ impl PublicValuesTarget { Ok(()) } + /// Deserializes public value targets. pub fn from_buffer(buffer: &mut Buffer) -> IoResult { let trie_roots_before = TrieRootsTarget { state_root: buffer.read_target_array()?, @@ -271,6 +294,9 @@ impl PublicValuesTarget { }) } + /// Extracts public value `Target`s from the given public input `Target`s. + /// Public values are always the first public inputs added to the circuit, + /// so we can start extracting at index 0. pub fn from_public_inputs(pis: &[Target]) -> Self { assert!( pis.len() @@ -308,6 +334,7 @@ impl PublicValuesTarget { } } + /// Returns the public values in `pv0` or `pv1` depending on `condition`. pub fn select, const D: usize>( builder: &mut CircuitBuilder, condition: BoolTarget, @@ -349,16 +376,24 @@ impl PublicValuesTarget { } } +/// Circuit version of `TrieRoots`. +/// `Target`s for trie hashes. Since a `Target` holds a 32-bit limb, each hash requires 8 `Target`s. #[derive(Eq, PartialEq, Debug, Copy, Clone)] pub struct TrieRootsTarget { + /// Targets for the state trie hash. pub state_root: [Target; 8], + /// Targets for the transactions trie hash. pub transactions_root: [Target; 8], + /// Targets for the receipts trie hash. pub receipts_root: [Target; 8], } impl TrieRootsTarget { + /// Number of `Target`s required for all trie hashes. pub const SIZE: usize = 24; + /// Extracts trie hash `Target`s for all tries from the provided public input `Target`s. + /// The provided `pis` should start with the trie hashes. pub fn from_public_inputs(pis: &[Target]) -> Self { let state_root = pis[0..8].try_into().unwrap(); let transactions_root = pis[8..16].try_into().unwrap(); @@ -371,6 +406,8 @@ impl TrieRootsTarget { } } + /// If `condition`, returns the trie hashes in `tr0`, + /// otherwise returns the trie hashes in `tr1`. pub fn select, const D: usize>( builder: &mut CircuitBuilder, condition: BoolTarget, @@ -394,6 +431,7 @@ impl TrieRootsTarget { } } + /// Connects the trie hashes in `tr0` and in `tr1`. pub fn connect, const D: usize>( builder: &mut CircuitBuilder, tr0: Self, @@ -407,23 +445,39 @@ impl TrieRootsTarget { } } +/// Circuit version of `BlockMetadata`. +/// Metadata contained in a block header. These are identical between +/// all state transition proofs within the same block. #[derive(Eq, PartialEq, Debug, Copy, Clone)] pub struct BlockMetadataTarget { + /// `Target`s for the address of this block's producer. pub block_beneficiary: [Target; 5], + /// `Target` for the timestamp of this block. pub block_timestamp: Target, + /// `Target` for the index of this block. pub block_number: Target, + /// `Target` for the difficulty (before PoS transition) of this block. pub block_difficulty: Target, + /// `Target`s for the `mix_hash` value of this block. pub block_random: [Target; 8], + /// `Target`s for the gas limit of this block. pub block_gaslimit: [Target; 2], + /// `Target` for the chain id of this block. pub block_chain_id: Target, + /// `Target`s for the base fee of this block. pub block_base_fee: [Target; 2], + /// `Target`s for the gas used of this block.
pub block_gas_used: [Target; 2], + /// `Target`s for the block bloom of this block. pub block_bloom: [Target; 64], } impl BlockMetadataTarget { + /// Number of `Target`s required for the block metadata. pub const SIZE: usize = 87; + /// Extracts block metadata `Target`s from the provided public input `Target`s. + /// The provided `pis` should start with the block metadata. pub fn from_public_inputs(pis: &[Target]) -> Self { let block_beneficiary = pis[0..5].try_into().unwrap(); let block_timestamp = pis[5]; @@ -450,6 +504,8 @@ impl BlockMetadataTarget { } } + /// If `condition`, returns the block metadata in `bm0`, + /// otherwise returns the block metadata in `bm1`. pub fn select, const D: usize>( builder: &mut CircuitBuilder, condition: BoolTarget, @@ -486,6 +542,7 @@ impl BlockMetadataTarget { } } + /// Connects the block metadata in `bm0` to the block metadata in `bm1`. pub fn connect, const D: usize>( builder: &mut CircuitBuilder, bm0: Self, @@ -516,14 +573,29 @@ impl BlockMetadataTarget { } } +/// Circuit version of `BlockHashes`. +/// `Target`s for the user-provided previous 256 block hashes and current block hash. +/// Each block hash requires 8 `Target`s. +/// The proofs across consecutive blocks ensure that these values +/// are consistent (i.e. shifted by eight `Target`s to the left). +/// +/// When the block number is less than 256, dummy values, i.e. `H256::default()`, +/// should be used for the additional block hashes. #[derive(Eq, PartialEq, Debug, Copy, Clone)] pub struct BlockHashesTarget { + /// `Target`s for the previous 256 hashes to the current block. The leftmost hash, i.e. `prev_hashes[0..8]`, + /// is the oldest, and the rightmost, i.e. `prev_hashes[255 * 8..256 * 8]` is the hash of the parent block. pub prev_hashes: [Target; 2048], + /// `Target`s for the hash of the current block. pub cur_hash: [Target; 8], } impl BlockHashesTarget { + /// Number of `Target`s required for previous and current block hashes. pub const BLOCK_HASHES_SIZE: usize = 2056; + + /// Extracts the previous and current block hash `Target`s from the public input `Target`s. + /// The provided `pis` should start with the block hashes. pub fn from_public_inputs(pis: &[Target]) -> Self { Self { prev_hashes: pis[0..2048].try_into().unwrap(), @@ -531,6 +603,8 @@ impl BlockHashesTarget { } } + /// If `condition`, returns the block hashes in `bm0`, + /// otherwise returns the block hashes in `bm1`. pub fn select, const D: usize>( builder: &mut CircuitBuilder, condition: BoolTarget, @@ -547,6 +621,7 @@ impl BlockHashesTarget { } } + /// Connects the block hashes in `bm0` to the block hashes in `bm1`. pub fn connect, const D: usize>( builder: &mut CircuitBuilder, bm0: Self, @@ -561,20 +636,38 @@ impl BlockHashesTarget { } } +/// Circuit version of `ExtraBlockData`. +/// Additional block data that is specific to the local transaction being proven, +/// unlike `BlockMetadata`. #[derive(Eq, PartialEq, Debug, Copy, Clone)] pub struct ExtraBlockDataTarget { + /// `Target`s for the state trie digest of the genesis block. pub genesis_state_trie_root: [Target; 8], + /// `Target` for the transaction count prior to the execution of the local state transition, starting + /// at 0 for the initial transaction of a block. pub txn_number_before: Target, + /// `Target` for the transaction count after execution of the local state transition. pub txn_number_after: Target, + /// `Target`s for the accumulated gas used prior to the execution of the local state transition, starting + /// at 0 for the initial transaction of a block.
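With the corrected indexing above, hash `i` occupies limbs `i * 8..(i + 1) * 8` of `prev_hashes`; a sketch (not part of the patch; `u32` limbs stand in for `Target`s):

    // Returns the 8-limb window holding hash i; i = 0 is the oldest hash,
    // i = 255 is the parent block's hash.
    fn hash_window(prev_hashes: &[u32; 2048], i: usize) -> &[u32] {
        &prev_hashes[i * 8..(i + 1) * 8]
    }

    fn main() {
        let prev_hashes = [0u32; 2048];
        assert_eq!(hash_window(&prev_hashes, 0).len(), 8);   // oldest hash
        assert_eq!(hash_window(&prev_hashes, 255).len(), 8); // parent block hash
    }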
pub gas_used_before: [Target; 2], + /// `Target`s for the accumulated gas used after execution of the local state transition. It should + /// match the `block_gas_used` value after execution of the last transaction in a block. pub gas_used_after: [Target; 2], + /// `Target`s for the accumulated bloom filter of this block prior to the execution of the local state transition, + /// starting with all zeros for the initial transaction of a block. pub block_bloom_before: [Target; 64], + /// `Target`s for the accumulated bloom filter after execution of the local state transition. It should + /// match the `block_bloom` value after execution of the last transaction in a block. pub block_bloom_after: [Target; 64], } impl ExtraBlockDataTarget { + /// Number of `Target`s required for the extra block data. const SIZE: usize = 142; + /// Extracts the extra block data `Target`s from the public input `Target`s. + /// The provided `pis` should start with the extra block data. pub fn from_public_inputs(pis: &[Target]) -> Self { let genesis_state_trie_root = pis[0..8].try_into().unwrap(); let txn_number_before = pis[8]; @@ -595,6 +688,8 @@ impl ExtraBlockDataTarget { } } + /// If `condition`, returns the extra block data in `ed0`, + /// otherwise returns the extra block data in `ed1`. pub fn select, const D: usize>( builder: &mut CircuitBuilder, condition: BoolTarget, @@ -638,6 +733,7 @@ impl ExtraBlockDataTarget { } } + /// Connects the extra block data in `ed0` with the extra block data in `ed1`. pub fn connect, const D: usize>( builder: &mut CircuitBuilder, ed0: Self, @@ -666,6 +762,7 @@ impl ExtraBlockDataTarget { } } +/// Merkle caps and openings that form the proof of a single STARK. #[derive(Debug, Clone)] pub struct StarkProof, C: GenericConfig, const D: usize> { /// Merkle cap of LDEs of trace values. @@ -688,7 +785,9 @@ where F: RichField + Extendable, C: GenericConfig, { + /// Initial Fiat-Shamir state. pub(crate) init_challenger_state: >::Permutation, + /// Proof for a single STARK. pub(crate) proof: StarkProof, } @@ -703,21 +802,30 @@ impl, C: GenericConfig, const D: usize> S lde_bits - config.fri_config.rate_bits } + /// Returns the number of cross-table lookup polynomials computed for the current STARK. pub fn num_ctl_zs(&self) -> usize { self.openings.ctl_zs_first.len() } } +/// Circuit version of `StarkProof`. +/// Merkle caps and openings that form the proof of a single STARK. #[derive(Eq, PartialEq, Debug)] pub struct StarkProofTarget { + /// `Target` for the Merkle cap of LDEs of trace values. pub trace_cap: MerkleCapTarget, + /// `Target` for the Merkle cap of LDEs of lookup helper and CTL columns. pub auxiliary_polys_cap: MerkleCapTarget, + /// `Target` for the Merkle cap of LDEs of quotient polynomial evaluations. pub quotient_polys_cap: MerkleCapTarget, + /// `Target`s for the purported values of each polynomial at the challenge point. pub openings: StarkOpeningSetTarget, + /// `Target`s for the batch FRI argument for all openings. pub opening_proof: FriProofTarget, } impl StarkProofTarget { + /// Serializes a STARK proof. pub fn to_buffer(&self, buffer: &mut Vec) -> IoResult<()> { buffer.write_target_merkle_cap(&self.trace_cap)?; buffer.write_target_merkle_cap(&self.auxiliary_polys_cap)?; @@ -727,6 +835,7 @@ impl StarkProofTarget { Ok(()) } + /// Deserializes a STARK proof.
pub fn from_buffer(buffer: &mut Buffer) -> IoResult { let trace_cap = buffer.read_target_merkle_cap()?; let auxiliary_polys_cap = buffer.read_target_merkle_cap()?; @@ -754,6 +863,7 @@ impl StarkProofTarget { } } +/// Randomness used for a STARK proof. pub(crate) struct StarkProofChallenges, const D: usize> { /// Random values used to combine STARK constraints. pub stark_alphas: Vec, @@ -761,12 +871,17 @@ pub(crate) struct StarkProofChallenges, const D: us /// Point at which the STARK polynomials are opened. pub stark_zeta: F::Extension, + /// Randomness used in FRI. pub fri_challenges: FriChallenges, } +/// Circuit version of `StarkProofChallenges`. pub(crate) struct StarkProofChallengesTarget { + /// `Target`s for the random values used to combine STARK constraints. pub stark_alphas: Vec, + /// `ExtensionTarget` for the point at which the STARK polynomials are opened. pub stark_zeta: ExtensionTarget, + /// `Target`s for the randomness used in FRI. pub fri_challenges: FriChallengesTarget, } @@ -788,6 +903,9 @@ pub struct StarkOpeningSet, const D: usize> { } impl, const D: usize> StarkOpeningSet { + /// Returns a `StarkOpeningSet` given all the polynomial commitments, the number of permutation `Z` polynomials, + /// the evaluation point and a generator `g`. + /// Polynomials are evaluated at point `zeta` and, if necessary, at `g * zeta`. pub fn new>( zeta: F::Extension, g: F, @@ -796,18 +914,21 @@ impl, const D: usize> StarkOpeningSet { quotient_commitment: &PolynomialBatch, num_lookup_columns: usize, ) -> Self { + // Batch evaluates polynomials on the LDE, at a point `z`. let eval_commitment = |z: F::Extension, c: &PolynomialBatch| { c.polynomials .par_iter() .map(|p| p.to_extension().eval(z)) .collect::>() }; + // Batch evaluates polynomials at a base field point `z`. let eval_commitment_base = |z: F, c: &PolynomialBatch| { c.polynomials .par_iter() .map(|p| p.eval(z)) .collect::>() }; + // `g * zeta`. let zeta_next = zeta.scalar_mul(g); Self { local_values: eval_commitment(zeta, trace_commitment), @@ -821,6 +942,8 @@ impl, const D: usize> StarkOpeningSet { } } + /// Constructs the openings required by FRI. + /// All openings but `ctl_zs_first` are grouped together. pub(crate) fn to_fri_openings(&self) -> FriOpenings { let zeta_batch = FriOpeningBatch { values: self @@ -855,17 +978,26 @@ impl, const D: usize> StarkOpeningSet { } } +/// Circuit version of `StarkOpeningSet`. +/// `Target`s for the purported values of each polynomial at the challenge point. #[derive(Eq, PartialEq, Debug)] pub struct StarkOpeningSetTarget { + /// `ExtensionTarget`s for the openings of trace polynomials at `zeta`. pub local_values: Vec>, + /// `ExtensionTarget`s for the opening of trace polynomials at `g * zeta`. pub next_values: Vec>, + /// `ExtensionTarget`s for the opening of lookup and cross-table lookup `Z` polynomials at `zeta`. pub auxiliary_polys: Vec>, + /// `ExtensionTarget`s for the opening of lookup and cross-table lookup `Z` polynomials at `g * zeta`. pub auxiliary_polys_next: Vec>, + /// `Target`s for the opening of the cross-table lookup `Z` polynomials at 1. pub ctl_zs_first: Vec, + /// `ExtensionTarget`s for the opening of quotient polynomials at `zeta`. pub quotient_polys: Vec>, } impl StarkOpeningSetTarget { + /// Serializes a STARK's opening set.
pub fn to_buffer(&self, buffer: &mut Vec) -> IoResult<()> { buffer.write_target_ext_vec(&self.local_values)?; buffer.write_target_ext_vec(&self.next_values)?; @@ -876,6 +1008,7 @@ impl StarkOpeningSetTarget { Ok(()) } + /// Deserializes a STARK's opening set. pub fn from_buffer(buffer: &mut Buffer) -> IoResult { let local_values = buffer.read_target_ext_vec::()?; let next_values = buffer.read_target_ext_vec::()?; @@ -894,6 +1027,9 @@ impl StarkOpeningSetTarget { }) } + /// Circuit version of `to_fri_openings`. + /// Constructs the `Target`s required by the circuit version of FRI. + /// All openings but `ctl_zs_first` are grouped together. pub(crate) fn to_fri_openings(&self, zero: Target) -> FriOpeningsTarget { let zeta_batch = FriOpeningBatchTarget { values: self diff --git a/evm/src/prover.rs b/evm/src/prover.rs index c5729a573f..fcf59738cc 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -90,6 +90,7 @@ where let rate_bits = config.fri_config.rate_bits; let cap_height = config.fri_config.cap_height; + // For each STARK, we compute the polynomial commitments for the polynomials interpolating its trace. let trace_commitments = timed!( timing, "compute all trace commitments", @@ -115,6 +116,7 @@ where .collect::>() ); + // Get the Merkle caps for all trace commitments and observe them. let trace_caps = trace_commitments .iter() .map(|c| c.merkle_tree.cap.clone()) @@ -127,7 +129,9 @@ where observe_public_values::(&mut challenger, &public_values) .map_err(|_| anyhow::Error::msg("Invalid conversion of public values."))?; + // Get challenges for the cross-table lookups. let ctl_challenges = get_grand_product_challenge_set(&mut challenger, config.num_challenges); + // For each STARK, compute its cross-table lookup Z polynomials and get the associated `CtlData`. let ctl_data_per_table = timed!( timing, "compute CTL data", @@ -169,6 +173,13 @@ where }) } +/// Generates a proof for each STARK. +/// At this stage, we have computed the trace polynomial commitments for the various STARKs, +/// and we have the cross-table lookup data for each table, including the associated challenges. +/// - `trace_poly_values` are the trace values for each STARK. +/// - `trace_commitments` are the trace polynomial commitments for each STARK. +/// - `ctl_data_per_table` groups all the cross-table lookup data for each STARK. +/// Each STARK uses its associated data to generate a proof. fn prove_with_commitments( all_stark: &AllStark, config: &StarkConfig, @@ -293,7 +304,10 @@ where ]) } -/// Compute proof for a single STARK table. +/// Computes a proof for a single STARK table, including: +/// - the initial state of the challenger, +/// - all the required Merkle caps, +/// - all the required polynomial and FRI argument openings. pub(crate) fn prove_single_table( stark: &S, config: &StarkConfig, @@ -350,6 +364,8 @@ where ); let num_lookup_columns = lookup_helper_columns.as_ref().map(|v| v.len()).unwrap_or(0); + // We add CTLs to the permutation arguments so that we can batch commit to + // all auxiliary polynomials. let auxiliary_polys = match lookup_helper_columns { None => ctl_data.z_polys(), Some(mut lookup_columns) => { @@ -359,6 +375,7 @@ where }; assert!(!auxiliary_polys.is_empty(), "No CTL?"); + // Get the polynomial commitments for all auxiliary polynomials. let auxiliary_polys_commitment = timed!( timing, "compute auxiliary polynomials commitment", @@ -424,6 +441,7 @@ where }) .collect() ); + // Commit to the quotient polynomials.
let quotient_commitment = timed!( timing, "compute quotient commitment", @@ -436,6 +454,7 @@ where None, ) ); + // Observe the quotient polynomials Merkle cap. let quotient_polys_cap = quotient_commitment.merkle_tree.cap.clone(); challenger.observe_cap(&quotient_polys_cap); @@ -449,6 +468,7 @@ where "Opening point is in the subgroup." ); + // Compute all openings: evaluate all committed polynomials at `zeta` and, when necessary, at `g * zeta`. let openings = StarkOpeningSet::new( zeta, g, @@ -457,6 +477,7 @@ where &quotient_commitment, stark.num_lookup_helper_columns(config), ); + // Get the FRI openings and observe them. challenger.observe_openings(&openings.to_fri_openings()); let initial_merkle_trees = vec![ @@ -563,10 +584,12 @@ where lagrange_basis_first, lagrange_basis_last, ); + // Get the local and next row evaluations for the current STARK. let vars = S::EvaluationFrame::from_values( &get_trace_values_packed(i_start), &get_trace_values_packed(i_next_start), ); + // Get the local and next row evaluations for the permutation argument, as well as the associated challenges. let lookup_vars = lookup_challenges.map(|challenges| LookupCheckVars { local_values: auxiliary_polys_commitment.get_lde_values_packed(i_start, step) [..num_lookup_columns] .to_vec(), next_values: auxiliary_polys_commitment.get_lde_values_packed(i_next_start, step), challenges: challenges.to_vec(), }); + + // Get all the data for this STARK's CTLs: + // - the local and next row evaluations for the CTL Z polynomials, + // - the associated challenges, + // - for each CTL: + // - the filter `Column`, + // - the `Column`s that form the looking/looked table. let ctl_vars = ctl_data .zs_columns .iter() @@ -588,6 +618,9 @@ where filter_column: &zs_columns.filter_column, }) .collect::>(); + + // Evaluate the polynomial combining all constraints, including those associated + // with the permutation and CTL arguments. eval_vanishing_poly::( stark, &vars, @@ -661,6 +694,7 @@ fn check_constraints<'a, F, C, S, const D: usize>( transpose(&values) }; + // Get batch evaluations of the trace, permutation and CTL polynomials over our subgroup. let trace_subgroup_evals = get_subgroup_evals(trace_commitment); let auxiliary_subgroup_evals = get_subgroup_evals(auxiliary_commitment); @@ -682,16 +716,19 @@ fn check_constraints<'a, F, C, S, const D: usize>( lagrange_basis_first, lagrange_basis_last, ); + // Get the local and next row evaluations for the current STARK's trace. let vars = S::EvaluationFrame::from_values( &trace_subgroup_evals[i], &trace_subgroup_evals[i_next], ); + // Get the local and next row evaluations for the current STARK's permutation argument. let lookup_vars = lookup_challenges.map(|challenges| LookupCheckVars { local_values: auxiliary_subgroup_evals[i][..num_lookup_columns].to_vec(), next_values: auxiliary_subgroup_evals[i_next][..num_lookup_columns].to_vec(), challenges: challenges.to_vec(), }); + // Get the local and next row evaluations for the current STARK's CTL Z polynomials. let ctl_vars = ctl_data .zs_columns .iter() @@ -704,6 +741,8 @@ fn check_constraints<'a, F, C, S, const D: usize>( filter_column: &zs_columns.filter_column, }) .collect::>(); + // Evaluate the polynomial combining all constraints, including those associated + // with the permutation and CTL arguments. eval_vanishing_poly::( stark, &vars, @@ -716,6 +755,7 @@ fn check_constraints<'a, F, C, S, const D: usize>( }) .collect::>(); + // Assert that all constraints evaluate to 0 over our subgroup.
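The debug loop that follows asserts exactly the precondition that makes the quotient commitment above well defined: a combined constraint polynomial that evaluates to 0 everywhere on the trace subgroup `H` is divisible by `Z_H(X) = X^n - 1`, and it is the quotient that gets committed. A toy sketch of that divisibility check, using hypothetical helpers over a 97-element field rather than the plonky2 API:

```rust
// Toy prime field: all arithmetic mod P (not the Goldilocks field used by plonky2).
const P: u64 = 97;

fn add(a: u64, b: u64) -> u64 {
    (a + b) % P
}

/// Divides `a` (coefficients, lowest degree first) by Z_H(X) = X^n - 1.
/// Returns (quotient, remainder); the remainder is zero iff `a` vanishes on H.
fn divide_by_z_h(mut a: Vec<u64>, n: usize) -> (Vec<u64>, Vec<u64>) {
    let d = a.len();
    let mut q = vec![0u64; d.saturating_sub(n)];
    for i in (n..d).rev() {
        // X^i = X^(i-n) * (X^n - 1) + X^(i-n)
        q[i - n] = add(q[i - n], a[i]);
        a[i - n] = add(a[i - n], a[i]);
        a[i] = 0;
    }
    a.truncate(n);
    (q, a)
}

fn main() {
    // c(X) = (X^4 - 1) * (X + 5) vanishes on the order-4 subgroup H = {1, 22, 96, 75} mod 97.
    // Coefficients of c, lowest degree first: (-5, -1, 0, 0, 5, 1) -> (92, 96, 0, 0, 5, 1) mod 97.
    let c = vec![92, 96, 0, 0, 5, 1];
    let (q, r) = divide_by_z_h(c, 4);
    assert!(r.iter().all(|&x| x == 0), "constraint must vanish on H");
    assert_eq!(q, vec![5, 1]); // quotient is X + 5
}
```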
for v in constraint_values { assert!( v.iter().all(|x| x.is_zero()), diff --git a/evm/src/vanishing_poly.rs b/evm/src/vanishing_poly.rs index 2ea6010e83..2e1adfc742 100644 --- a/evm/src/vanishing_poly.rs +++ b/evm/src/vanishing_poly.rs @@ -14,6 +14,8 @@ use crate::lookup::{ }; use crate::stark::Stark; +/// Evaluates all constraint, permutation and cross-table lookup polynomials +/// of the current STARK at the local and next values. pub(crate) fn eval_vanishing_poly( stark: &S, vars: &S::EvaluationFrame, @@ -27,8 +29,10 @@ pub(crate) fn eval_vanishing_poly( P: PackedField, S: Stark, { + // Evaluate all of the STARK's table constraints. stark.eval_packed_generic(vars, consumer); if let Some(lookup_vars) = lookup_vars { + // Evaluate the STARK constraints related to the permutation argument. eval_packed_lookups_generic::( stark, lookups, @@ -37,9 +41,13 @@ pub(crate) fn eval_vanishing_poly( consumer, ); } + // Evaluate the STARK constraints related to the cross-table lookups. eval_cross_table_lookup_checks::(vars, ctl_vars, consumer); } +/// Circuit version of `eval_vanishing_poly`. +/// Evaluates all constraint, permutation and cross-table lookup polynomials +/// of the current STARK at the local and next values. pub(crate) fn eval_vanishing_poly_circuit( builder: &mut CircuitBuilder, stark: &S, @@ -51,9 +59,12 @@ pub(crate) fn eval_vanishing_poly_circuit( F: RichField + Extendable, S: Stark, { + // Evaluate all of the STARK's table constraints. stark.eval_ext_circuit(builder, vars, consumer); if let Some(lookup_vars) = lookup_vars { + // Evaluate all of the STARK's constraints related to the permutation argument. eval_ext_lookups_circuit::(builder, stark, vars, lookup_vars, consumer); } + // Evaluate all of the STARK's constraints related to the cross-table lookups. eval_cross_table_lookup_checks_circuit::(builder, vars, ctl_vars, consumer); } diff --git a/evm/src/witness/errors.rs b/evm/src/witness/errors.rs index 8186246035..f883f8a4bd 100644 --- a/evm/src/witness/errors.rs +++ b/evm/src/witness/errors.rs @@ -30,9 +30,9 @@ pub enum MemoryError { #[derive(Debug)] pub enum ProverInputError { OutOfMptData, + OutOfSmtData, OutOfRlpData, CodeHashNotFound, - InvalidMptInput, InvalidInput, InvalidFunction, } diff --git a/evm/src/witness/operation.rs b/evm/src/witness/operation.rs index a503ab496c..2777269d89 100644 --- a/evm/src/witness/operation.rs +++ b/evm/src/witness/operation.rs @@ -55,6 +55,9 @@ pub(crate) enum Operation { MstoreGeneral, } +/// Adds a CPU row filled with the two inputs and the output of a logic operation. +/// Generates a new logic operation and adds it to the vector of operations in `LogicStark`. +/// Adds three memory read operations to `MemoryStark`: for the two inputs and the output.
pub(crate) fn generate_binary_logic_op( op: logic::Op, state: &mut GenerationState, @@ -63,7 +66,7 @@ pub(crate) fn generate_binary_logic_op( let [(in0, _), (in1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; let operation = logic::Operation::new(op, in0, in1); - push_no_write(state, &mut row, operation.result, Some(NUM_GP_CHANNELS - 1)); + push_no_write(state, operation.result); state.traces.push_logic(operation); state.traces.push_memory(log_in1); @@ -92,12 +95,7 @@ pub(crate) fn generate_binary_arithmetic_op( } } - push_no_write( - state, - &mut row, - operation.result(), - Some(NUM_GP_CHANNELS - 1), - ); + push_no_write(state, operation.result()); state.traces.push_arithmetic(operation); state.traces.push_memory(log_in1); @@ -114,12 +112,7 @@ pub(crate) fn generate_ternary_arithmetic_op( stack_pop_with_log_and_fill::<3, _>(state, &mut row)?; let operation = arithmetic::Operation::ternary(operator, input0, input1, input2); - push_no_write( - state, - &mut row, - operation.result(), - Some(NUM_GP_CHANNELS - 1), - ); + push_no_write(state, operation.result()); state.traces.push_arithmetic(operation); state.traces.push_memory(log_in1); @@ -151,7 +144,7 @@ pub(crate) fn generate_keccak_general( log::debug!("Hashing {:?}", input); let hash = keccak(&input); - push_no_write(state, &mut row, hash.into_uint(), Some(NUM_GP_CHANNELS - 1)); + push_no_write(state, hash.into_uint()); keccak_sponge_log(state, base_address, input); @@ -180,6 +173,17 @@ pub(crate) fn generate_pop( ) -> Result<(), ProgramError> { let [(_, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let diff = row.stack_len - F::from_canonical_usize(1); + if let Some(inv) = diff.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + row.general.stack_mut().stack_inv_aux_2 = F::ONE; + state.registers.is_stack_top_read = true; + } else { + row.general.stack_mut().stack_inv = F::ZERO; + row.general.stack_mut().stack_inv_aux = F::ZERO; + } + state.traces.push_cpu(row); Ok(()) @@ -318,7 +322,22 @@ pub(crate) fn generate_get_context( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - push_with_write(state, &mut row, state.registers.context.into())?; + // Same logic as push_with_write, but we have to use channel 3 for stack constraint reasons. 
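The comment above notes that GET_CONTEXT re-derives `push_with_write` by hand so that the spill lands in memory channel 3. The convention this whole diff leans on is that the stack top is cached in registers and only written to the `Stack` segment when a push would bury a non-empty stack; `push_no_write` is the degenerate case with no memory traffic at all. A miniature model of the two paths, with hypothetical standalone types rather than the real `GenerationState`:

```rust
#[derive(Default)]
struct Registers {
    stack_top: u64,
    stack_len: usize,
}

#[derive(Default)]
struct Memory {
    stack_segment: Vec<u64>,
}

// Mirrors the simplified `push_no_write` in this diff: registers only, no memory op.
fn push_no_write(regs: &mut Registers, val: u64) {
    regs.stack_top = val;
    regs.stack_len += 1;
}

fn push_with_write(regs: &mut Registers, mem: &mut Memory, val: u64) {
    if regs.stack_len > 0 {
        // Spill the previous top to the Stack segment at index stack_len - 1.
        let addr = regs.stack_len - 1;
        if mem.stack_segment.len() <= addr {
            mem.stack_segment.resize(addr + 1, 0);
        }
        mem.stack_segment[addr] = regs.stack_top;
    }
    push_no_write(regs, val);
}

fn main() {
    let (mut regs, mut mem) = (Registers::default(), Memory::default());
    push_with_write(&mut regs, &mut mem, 7); // empty stack: nothing spilled
    push_with_write(&mut regs, &mut mem, 9); // spills 7 to Stack[0]
    assert_eq!((regs.stack_top, regs.stack_len, mem.stack_segment[0]), (9, 2, 7));
}
```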
+ let write = if state.registers.stack_len == 0 { + None + } else { + let address = MemoryAddress::new( + state.registers.context, + Segment::Stack, + state.registers.stack_len - 1, + ); + let res = mem_write_gp_log_and_fill(3, address, state, &mut row, state.registers.stack_top); + Some(res) + }; + push_no_write(state, state.registers.context.into()); + if let Some(log) = write { + state.traces.push_memory(log); + } state.traces.push_cpu(row); Ok(()) } @@ -374,9 +393,11 @@ pub(crate) fn generate_set_context( if let Some(inv) = new_sp_field.try_inverse() { row.general.stack_mut().stack_inv = inv; row.general.stack_mut().stack_inv_aux = F::ONE; + row.general.stack_mut().stack_inv_aux_2 = F::ONE; } else { row.general.stack_mut().stack_inv = F::ZERO; row.general.stack_mut().stack_inv_aux = F::ZERO; + row.general.stack_mut().stack_inv_aux_2 = F::ZERO; } let new_top_addr = MemoryAddress::new(new_ctx, Segment::Stack, new_sp - 1); @@ -493,7 +514,7 @@ pub(crate) fn generate_dup( } else { mem_read_gp_with_log_and_fill(2, other_addr, state, &mut row) }; - push_no_write(state, &mut row, val, None); + push_no_write(state, val); state.traces.push_memory(log_read); state.traces.push_cpu(row); @@ -515,7 +536,7 @@ pub(crate) fn generate_swap( let [(in0, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; let (in1, log_in1) = mem_read_gp_with_log_and_fill(1, other_addr, state, &mut row); let log_out0 = mem_write_gp_log_and_fill(2, other_addr, state, &mut row, in0); - push_no_write(state, &mut row, in1, None); + push_no_write(state, in1); state.traces.push_memory(log_in1); state.traces.push_memory(log_out0); @@ -529,7 +550,18 @@ pub(crate) fn generate_not( ) -> Result<(), ProgramError> { let [(x, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; let result = !x; - push_no_write(state, &mut row, result, Some(NUM_GP_CHANNELS - 1)); + push_no_write(state, result); + + // This is necessary for the stack constraints for POP, + // since the two flags are combined. 
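The `stack_inv` / `stack_inv_aux` assignments above (for SET_CONTEXT) and below (for NOT, mirroring POP) all use the standard field-inverse trick for proving a value nonzero: the prover supplies an inverse as a witness, and the constraints enforce `d * d_inv = aux` with `aux` boolean, so `aux = 1` exactly when `d != 0`. A small self-contained sketch in a toy prime field, mirroring the `try_inverse` calls in this diff:

```rust
const P: u64 = 97; // toy prime field

fn pow(mut b: u64, mut e: u64) -> u64 {
    let mut acc = 1u64;
    while e > 0 {
        if e & 1 == 1 {
            acc = acc * b % P;
        }
        b = b * b % P;
        e >>= 1;
    }
    acc
}

/// Fermat inverse; `None` for 0, like `F::try_inverse`.
fn try_inverse(d: u64) -> Option<u64> {
    if d % P == 0 {
        None
    } else {
        Some(pow(d % P, P - 2))
    }
}

fn main() {
    for d in [0u64, 1, 41] {
        let (inv, aux) = match try_inverse(d) {
            Some(inv) => (inv, 1u64), // witness that d != 0
            None => (0, 0),
        };
        // The constraint the STARK would enforce: d * inv == aux, with aux boolean,
        // and aux == 1 exactly when d != 0.
        assert_eq!(d % P * inv % P, aux);
        assert_eq!(aux == 1, d % P != 0);
    }
}
```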
+ let diff = row.stack_len - F::from_canonical_usize(1); + if let Some(inv) = diff.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } else { + row.general.stack_mut().stack_inv = F::ZERO; + row.general.stack_mut().stack_inv_aux = F::ZERO; + } state.traces.push_cpu(row); Ok(()) @@ -548,7 +580,7 @@ pub(crate) fn generate_iszero( generate_pinv_diff(x, U256::zero(), &mut row); - push_no_write(state, &mut row, result, None); + push_no_write(state, result); state.traces.push_cpu(row); Ok(()) } @@ -587,7 +619,7 @@ fn append_shift( let operation = arithmetic::Operation::binary(operator, input0, input1); state.traces.push_arithmetic(operation); - push_no_write(state, &mut row, result, Some(NUM_GP_CHANNELS - 1)); + push_no_write(state, result); state.traces.push_memory(log_in1); state.traces.push_cpu(row); Ok(()) @@ -701,7 +733,7 @@ pub(crate) fn generate_eq( generate_pinv_diff(in0, in1, &mut row); - push_no_write(state, &mut row, result, None); + push_no_write(state, result); state.traces.push_memory(log_in1); state.traces.push_cpu(row); Ok(()) @@ -749,7 +781,7 @@ pub(crate) fn generate_mload_general( state, &mut row, ); - push_no_write(state, &mut row, val, None); + push_no_write(state, val); let diff = row.stack_len - F::from_canonical_usize(4); if let Some(inv) = diff.try_inverse() { @@ -797,7 +829,7 @@ pub(crate) fn generate_mload_32bytes( .collect_vec(); let packed_int = U256::from_big_endian(&bytes); - push_no_write(state, &mut row, packed_int, Some(4)); + push_no_write(state, packed_int); byte_packing_log(state, base_address, bytes); @@ -843,6 +875,7 @@ pub(crate) fn generate_mstore_general( state.traces.push_memory(log_in2); state.traces.push_memory(log_in3); state.traces.push_memory(log_write); + state.traces.push_cpu(row); Ok(()) diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index 312b8591f4..d0afdd0f12 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -160,11 +160,10 @@ fn decode(registers: RegistersState, opcode: u8) -> Result(op: Operation, row: &mut CpuColumnsView) { let flags = &mut row.op; *match op { - Operation::Push(0) => &mut flags.push0, Operation::Push(1..) 
=> &mut flags.push, Operation::Dup(_) | Operation::Swap(_) => &mut flags.dup_swap, Operation::Iszero | Operation::Eq => &mut flags.eq_iszero, - Operation::Not => &mut flags.not, + Operation::Not | Operation::Pop => &mut flags.not_pop, Operation::Syscall(_, _, _) => &mut flags.syscall, Operation::BinaryLogic(_) => &mut flags.logic_op, Operation::BinaryArithmetic(arithmetic::BinaryOperator::AddFp254) @@ -174,14 +173,11 @@ fn fill_op_flag(op: Operation, row: &mut CpuColumnsView) { | Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr) => &mut flags.shift, Operation::BinaryArithmetic(_) => &mut flags.binary_op, Operation::TernaryArithmetic(_) => &mut flags.ternary_op, - Operation::KeccakGeneral => &mut flags.keccak_general, + Operation::KeccakGeneral | Operation::Jumpdest => &mut flags.jumpdest_keccak_general, Operation::ProverInput => &mut flags.prover_input, - Operation::Pop => &mut flags.pop, Operation::Jump | Operation::Jumpi => &mut flags.jumps, - Operation::Pc => &mut flags.pc, - Operation::Jumpdest => &mut flags.jumpdest, - Operation::GetContext => &mut flags.get_context, - Operation::SetContext => &mut flags.set_context, + Operation::Pc | Operation::Push(0) => &mut flags.pc_push0, + Operation::GetContext | Operation::SetContext => &mut flags.context_op, Operation::Mload32Bytes => &mut flags.mload_32bytes, Operation::Mstore32Bytes => &mut flags.mstore_32bytes, Operation::ExitKernel => &mut flags.exit_kernel, @@ -192,11 +188,11 @@ fn fill_op_flag(op: Operation, row: &mut CpuColumnsView) { // Equal to the number of pops if an operation pops without pushing, and `None` otherwise. fn get_op_special_length(op: Operation) -> Option { let behavior_opt = match op { - Operation::Push(0) => STACK_BEHAVIORS.push0, + Operation::Push(0) | Operation::Pc => STACK_BEHAVIORS.pc_push0, Operation::Push(1..) => STACK_BEHAVIORS.push, Operation::Dup(_) | Operation::Swap(_) => STACK_BEHAVIORS.dup_swap, Operation::Iszero => IS_ZERO_STACK_BEHAVIOR, - Operation::Not => STACK_BEHAVIORS.not, + Operation::Not | Operation::Pop => STACK_BEHAVIORS.not_pop, Operation::Syscall(_, _, _) => STACK_BEHAVIORS.syscall, Operation::Eq => EQ_STACK_BEHAVIOR, Operation::BinaryLogic(_) => STACK_BEHAVIORS.logic_op, @@ -209,15 +205,11 @@ fn get_op_special_length(op: Operation) -> Option { | Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr) => STACK_BEHAVIORS.shift, Operation::BinaryArithmetic(_) => STACK_BEHAVIORS.binary_op, Operation::TernaryArithmetic(_) => STACK_BEHAVIORS.ternary_op, - Operation::KeccakGeneral => STACK_BEHAVIORS.keccak_general, + Operation::KeccakGeneral | Operation::Jumpdest => STACK_BEHAVIORS.jumpdest_keccak_general, Operation::ProverInput => STACK_BEHAVIORS.prover_input, - Operation::Pop => STACK_BEHAVIORS.pop, Operation::Jump => JUMP_OP, Operation::Jumpi => JUMPI_OP, - Operation::Pc => STACK_BEHAVIORS.pc, - Operation::Jumpdest => STACK_BEHAVIORS.jumpdest, - Operation::GetContext => STACK_BEHAVIORS.get_context, - Operation::SetContext => None, + Operation::GetContext | Operation::SetContext => None, Operation::Mload32Bytes => STACK_BEHAVIORS.mload_32bytes, Operation::Mstore32Bytes => STACK_BEHAVIORS.mstore_32bytes, Operation::ExitKernel => STACK_BEHAVIORS.exit_kernel, diff --git a/evm/src/witness/util.rs b/evm/src/witness/util.rs index 249703614b..a87ad50b1b 100644 --- a/evm/src/witness/util.rs +++ b/evm/src/witness/util.rs @@ -68,31 +68,9 @@ pub(crate) fn fill_channel_with_value(row: &mut CpuColumnsView, n: } /// Pushes without writing in memory. 
This happens in opcodes where a push immediately follows a pop. -/// The pushed value may be loaded in a memory channel, without creating a memory operation. -pub(crate) fn push_no_write( - state: &mut GenerationState, - row: &mut CpuColumnsView, - val: U256, - channel_opt: Option, -) { +pub(crate) fn push_no_write(state: &mut GenerationState, val: U256) { state.registers.stack_top = val; state.registers.stack_len += 1; - - if let Some(channel) = channel_opt { - let val_limbs: [u64; 4] = val.0; - - let channel = &mut row.mem_channels[channel]; - assert_eq!(channel.used, F::ZERO); - channel.used = F::ZERO; - channel.is_read = F::ZERO; - channel.addr_context = F::from_canonical_usize(0); - channel.addr_segment = F::from_canonical_usize(0); - channel.addr_virtual = F::from_canonical_usize(0); - for (i, limb) in val_limbs.into_iter().enumerate() { - channel.value[2 * i] = F::from_canonical_u32(limb as u32); - channel.value[2 * i + 1] = F::from_canonical_u32((limb >> 32) as u32); - } - } } /// Pushes and (maybe) writes the previous stack top in memory. This happens in opcodes which only push. @@ -122,7 +100,7 @@ pub(crate) fn push_with_write( ); Some(res) }; - push_no_write(state, row, val, None); + push_no_write(state, val); if let Some(log) = write { state.traces.push_memory(log); } diff --git a/evm/tests/add11_yml.rs b/evm/tests/add11_yml.rs index cb0212a388..ef38fb5c68 100644 --- a/evm/tests/add11_yml.rs +++ b/evm/tests/add11_yml.rs @@ -13,12 +13,14 @@ use plonky2::plonk::config::KeccakGoldilocksConfig; use plonky2::util::timing::TimingTree; use plonky2_evm::all_stark::AllStark; use plonky2_evm::config::StarkConfig; -use plonky2_evm::generation::mpt::{AccountRlp, LegacyReceiptRlp}; +use plonky2_evm::generation::mpt::LegacyReceiptRlp; use plonky2_evm::generation::{GenerationInputs, TrieInputs}; use plonky2_evm::proof::{BlockHashes, BlockMetadata, TrieRoots}; use plonky2_evm::prover::prove; use plonky2_evm::verifier::verify_proof; use plonky2_evm::Node; +use smt_utils::account::Account; +use smt_utils::smt::Smt; type F = GoldilocksField; const D: usize = 2; @@ -40,37 +42,40 @@ fn add11_yml() -> anyhow::Result<()> { let sender_state_key = keccak(sender); let to_hashed = keccak(to); - let beneficiary_nibbles = Nibbles::from_bytes_be(beneficiary_state_key.as_bytes()).unwrap(); - let sender_nibbles = Nibbles::from_bytes_be(sender_state_key.as_bytes()).unwrap(); - let to_nibbles = Nibbles::from_bytes_be(to_hashed.as_bytes()).unwrap(); + let beneficiary_bits = beneficiary_state_key.into(); + let sender_bits = sender_state_key.into(); + let to_bits = to_hashed.into(); let code = [0x60, 0x01, 0x60, 0x01, 0x01, 0x60, 0x00, 0x55, 0x00]; let code_hash = keccak(code); - let beneficiary_account_before = AccountRlp { - nonce: 1.into(), - ..AccountRlp::default() + let beneficiary_account_before = Account { + nonce: 1, + ..Account::default() }; - let sender_account_before = AccountRlp { + let sender_account_before = Account { balance: 0x0de0b6b3a7640000u64.into(), - ..AccountRlp::default() + ..Account::default() }; - let to_account_before = AccountRlp { + let to_account_before = Account { balance: 0x0de0b6b3a7640000u64.into(), code_hash, - ..AccountRlp::default() + ..Account::default() }; - let mut state_trie_before = HashedPartialTrie::from(Node::Empty); - state_trie_before.insert( - beneficiary_nibbles, - rlp::encode(&beneficiary_account_before).to_vec(), - ); - state_trie_before.insert(sender_nibbles, rlp::encode(&sender_account_before).to_vec()); - state_trie_before.insert(to_nibbles, 
rlp::encode(&to_account_before).to_vec()); + let mut state_smt_before = Smt::empty(); + state_smt_before + .insert(beneficiary_bits, beneficiary_account_before.into()) + .unwrap(); + state_smt_before + .insert(sender_bits, sender_account_before.into()) + .unwrap(); + state_smt_before + .insert(to_bits, to_account_before.into()) + .unwrap(); let tries_before = TrieInputs { - state_trie: state_trie_before, + state_smt: state_smt_before.serialize(), transactions_trie: Node::Empty.into(), receipts_trie: Node::Empty.into(), storage_tries: vec![(to_hashed, Node::Empty.into())], @@ -96,36 +101,34 @@ fn add11_yml() -> anyhow::Result<()> { contract_code.insert(code_hash, code.to_vec()); let expected_state_trie_after = { - let beneficiary_account_after = AccountRlp { - nonce: 1.into(), - ..AccountRlp::default() + let beneficiary_account_after = Account { + nonce: 1, + ..Account::default() }; - let sender_account_after = AccountRlp { + let sender_account_after = Account { balance: 0xde0b6b3a75be550u64.into(), - nonce: 1.into(), - ..AccountRlp::default() + nonce: 1, + ..Account::default() }; - let to_account_after = AccountRlp { + let to_account_after = Account { balance: 0xde0b6b3a76586a0u64.into(), code_hash, // Storage map: { 0 => 2 } - storage_root: HashedPartialTrie::from(Node::Leaf { - nibbles: Nibbles::from_h256_be(keccak([0u8; 32])), - value: vec![2], - }) - .hash(), - ..AccountRlp::default() + storage_smt: Smt::new([(keccak([0u8; 32]).into(), 2.into())]).unwrap(), + ..Account::default() }; - let mut expected_state_trie_after = HashedPartialTrie::from(Node::Empty); - expected_state_trie_after.insert( - beneficiary_nibbles, - rlp::encode(&beneficiary_account_after).to_vec(), - ); - expected_state_trie_after - .insert(sender_nibbles, rlp::encode(&sender_account_after).to_vec()); - expected_state_trie_after.insert(to_nibbles, rlp::encode(&to_account_after).to_vec()); - expected_state_trie_after + let mut expected_state_smt_after = Smt::empty(); + expected_state_smt_after + .insert(beneficiary_bits, beneficiary_account_after.into()) + .unwrap(); + expected_state_smt_after + .insert(sender_bits, sender_account_after.into()) + .unwrap(); + expected_state_smt_after + .insert(to_bits, to_account_after.into()) + .unwrap(); + expected_state_smt_after }; let receipt_0 = LegacyReceiptRlp { @@ -146,7 +149,7 @@ fn add11_yml() -> anyhow::Result<()> { .into(); let trie_roots_after = TrieRoots { - state_root: expected_state_trie_after.hash(), + state_root: expected_state_trie_after.root, transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; diff --git a/evm/tests/basic_smart_contract.rs b/evm/tests/basic_smart_contract.rs index 687328dcb9..a4a58901ff 100644 --- a/evm/tests/basic_smart_contract.rs +++ b/evm/tests/basic_smart_contract.rs @@ -14,12 +14,14 @@ use plonky2::util::timing::TimingTree; use plonky2_evm::all_stark::AllStark; use plonky2_evm::config::StarkConfig; use plonky2_evm::cpu::kernel::opcodes::{get_opcode, get_push_opcode}; -use plonky2_evm::generation::mpt::{AccountRlp, LegacyReceiptRlp}; +use plonky2_evm::generation::mpt::LegacyReceiptRlp; use plonky2_evm::generation::{GenerationInputs, TrieInputs}; use plonky2_evm::proof::{BlockHashes, BlockMetadata, TrieRoots}; use plonky2_evm::prover::prove; use plonky2_evm::verifier::verify_proof; use plonky2_evm::Node; +use smt_utils::account::Account; +use smt_utils::smt::Smt; type F = GoldilocksField; const D: usize = 2; @@ -42,9 +44,9 @@ fn test_basic_smart_contract() -> anyhow::Result<()> { let sender_state_key = 
keccak(sender); let to_state_key = keccak(to); - let beneficiary_nibbles = Nibbles::from_bytes_be(beneficiary_state_key.as_bytes()).unwrap(); - let sender_nibbles = Nibbles::from_bytes_be(sender_state_key.as_bytes()).unwrap(); - let to_nibbles = Nibbles::from_bytes_be(to_state_key.as_bytes()).unwrap(); + let beneficiary_bits = beneficiary_state_key.into(); + let sender_bits = sender_state_key.into(); + let to_bits = to_state_key.into(); let push1 = get_push_opcode(1); let add = get_opcode("ADD"); @@ -53,46 +55,29 @@ fn test_basic_smart_contract() -> anyhow::Result<()> { let code_gas = 3 + 3 + 3; let code_hash = keccak(code); - let beneficiary_account_before = AccountRlp { - nonce: 1.into(), - ..AccountRlp::default() + let beneficiary_account_before = Account { + nonce: 1, + ..Account::default() }; - let sender_account_before = AccountRlp { - nonce: 5.into(), + let sender_account_before = Account { + nonce: 5, balance: eth_to_wei(100_000.into()), - ..AccountRlp::default() + ..Account::default() }; - let to_account_before = AccountRlp { + let to_account_before = Account { code_hash, - ..AccountRlp::default() + ..Account::default() }; - let state_trie_before = { - let mut children = core::array::from_fn(|_| Node::Empty.into()); - children[beneficiary_nibbles.get_nibble(0) as usize] = Node::Leaf { - nibbles: beneficiary_nibbles.truncate_n_nibbles_front(1), - value: rlp::encode(&beneficiary_account_before).to_vec(), - } - .into(); - children[sender_nibbles.get_nibble(0) as usize] = Node::Leaf { - nibbles: sender_nibbles.truncate_n_nibbles_front(1), - value: rlp::encode(&sender_account_before).to_vec(), - } - .into(); - children[to_nibbles.get_nibble(0) as usize] = Node::Leaf { - nibbles: to_nibbles.truncate_n_nibbles_front(1), - value: rlp::encode(&to_account_before).to_vec(), - } - .into(); - Node::Branch { - children, - value: vec![], - } - } - .into(); + let state_smt_before = Smt::new([ + (beneficiary_bits, beneficiary_account_before.clone().into()), + (sender_bits, sender_account_before.clone().into()), + (to_bits, to_account_before.clone().into()), + ]) + .unwrap(); let tries_before = TrieInputs { - state_trie: state_trie_before, + state_smt: state_smt_before.serialize(), transactions_trie: Node::Empty.into(), receipts_trie: Node::Empty.into(), storage_tries: vec![], @@ -122,43 +107,27 @@ fn test_basic_smart_contract() -> anyhow::Result<()> { contract_code.insert(keccak(vec![]), vec![]); contract_code.insert(code_hash, code.to_vec()); - let expected_state_trie_after: HashedPartialTrie = { - let beneficiary_account_after = AccountRlp { - nonce: 1.into(), - ..AccountRlp::default() + let expected_state_smt_after = { + let beneficiary_account_after = Account { + nonce: 1, + ..Account::default() }; - let sender_account_after = AccountRlp { + let sender_account_after = Account { balance: sender_account_before.balance - value - gas_used * 10, nonce: sender_account_before.nonce + 1, ..sender_account_before }; - let to_account_after = AccountRlp { + let to_account_after = Account { balance: to_account_before.balance + value, ..to_account_before }; - - let mut children = core::array::from_fn(|_| Node::Empty.into()); - children[beneficiary_nibbles.get_nibble(0) as usize] = Node::Leaf { - nibbles: beneficiary_nibbles.truncate_n_nibbles_front(1), - value: rlp::encode(&beneficiary_account_after).to_vec(), - } - .into(); - children[sender_nibbles.get_nibble(0) as usize] = Node::Leaf { - nibbles: sender_nibbles.truncate_n_nibbles_front(1), - value: rlp::encode(&sender_account_after).to_vec(), - } - 
.into(); - children[to_nibbles.get_nibble(0) as usize] = Node::Leaf { - nibbles: to_nibbles.truncate_n_nibbles_front(1), - value: rlp::encode(&to_account_after).to_vec(), - } - .into(); - Node::Branch { - children, - value: vec![], - } - } - .into(); + Smt::new([ + (beneficiary_bits, beneficiary_account_after.into()), + (sender_bits, sender_account_after.into()), + (to_bits, to_account_after.into()), + ]) + .unwrap() + }; let receipt_0 = LegacyReceiptRlp { status: true, @@ -178,7 +147,7 @@ fn test_basic_smart_contract() -> anyhow::Result<()> { .into(); let trie_roots_after = TrieRoots { - state_root: expected_state_trie_after.hash(), + state_root: expected_state_smt_after.root, transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; diff --git a/evm/tests/empty_txn_list.rs b/evm/tests/empty_txn_list.rs index dd4e624b04..228aaea222 100644 --- a/evm/tests/empty_txn_list.rs +++ b/evm/tests/empty_txn_list.rs @@ -17,6 +17,7 @@ use plonky2_evm::fixed_recursive_verifier::AllRecursiveCircuits; use plonky2_evm::generation::{GenerationInputs, TrieInputs}; use plonky2_evm::proof::{BlockHashes, BlockMetadata, TrieRoots}; use plonky2_evm::Node; +use smt_utils::smt::Smt; type F = GoldilocksField; const D: usize = 2; @@ -33,7 +34,7 @@ fn test_empty_txn_list() -> anyhow::Result<()> { let block_metadata = BlockMetadata::default(); - let state_trie = HashedPartialTrie::from(Node::Empty); + let state_smt = Smt::empty(); let transactions_trie = HashedPartialTrie::from(Node::Empty); let receipts_trie = HashedPartialTrie::from(Node::Empty); let storage_tries = vec![]; @@ -43,21 +44,21 @@ fn test_empty_txn_list() -> anyhow::Result<()> { // No transactions, so no trie roots change. let trie_roots_after = TrieRoots { - state_root: state_trie.hash(), + state_root: state_smt.root, transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; let inputs = GenerationInputs { signed_txns: vec![], tries: TrieInputs { - state_trie, + state_smt: state_smt.serialize(), transactions_trie, receipts_trie, storage_tries, }, trie_roots_after, contract_code, - genesis_state_trie_root: HashedPartialTrie::from(Node::Empty).hash(), + genesis_state_trie_root: Smt::empty().root, block_metadata, txn_number_before: 0.into(), gas_used_before: 0.into(), diff --git a/evm/tests/log_opcode.rs b/evm/tests/log_opcode.rs index dd7ea223e4..8a41046417 100644 --- a/evm/tests/log_opcode.rs +++ b/evm/tests/log_opcode.rs @@ -18,12 +18,14 @@ use plonky2_evm::all_stark::AllStark; use plonky2_evm::config::StarkConfig; use plonky2_evm::fixed_recursive_verifier::AllRecursiveCircuits; use plonky2_evm::generation::mpt::transaction_testing::{AddressOption, LegacyTransactionRlp}; -use plonky2_evm::generation::mpt::{AccountRlp, LegacyReceiptRlp, LogRlp}; +use plonky2_evm::generation::mpt::{LegacyReceiptRlp, LogRlp}; use plonky2_evm::generation::{GenerationInputs, TrieInputs}; use plonky2_evm::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, TrieRoots}; use plonky2_evm::prover::prove; use plonky2_evm::verifier::verify_proof; use plonky2_evm::Node; +use smt_utils::account::Account; +use smt_utils::smt::Smt; type F = GoldilocksField; const D: usize = 2; @@ -47,9 +49,9 @@ fn test_log_opcodes() -> anyhow::Result<()> { let sender_state_key = keccak(sender); let to_hashed = keccak(to); - let beneficiary_nibbles = Nibbles::from_bytes_be(beneficiary_state_key.as_bytes()).unwrap(); - let sender_nibbles = Nibbles::from_bytes_be(sender_state_key.as_bytes()).unwrap(); - let to_nibbles = 
Nibbles::from_bytes_be(to_hashed.as_bytes()).unwrap(); + let beneficiary_bits = beneficiary_state_key.into(); + let sender_bits = sender_state_key.into(); + let to_bits = to_hashed.into(); // For the first code transaction code, we consider two LOG opcodes. The first deals with 0 topics and empty data. The second deals with two topics, and data of length 5, stored in memory. let code = [ @@ -68,30 +70,29 @@ fn test_log_opcodes() -> anyhow::Result<()> { let code_hash = keccak(code); // Set accounts before the transaction. - let beneficiary_account_before = AccountRlp { - nonce: 1.into(), - ..AccountRlp::default() + let beneficiary_account_before = Account { + nonce: 1, + ..Account::default() }; let sender_balance_before = 5000000000000000u64; - let sender_account_before = AccountRlp { + let sender_account_before = Account { balance: sender_balance_before.into(), - ..AccountRlp::default() + ..Account::default() }; - let to_account_before = AccountRlp { + let to_account_before = Account { balance: 9000000000u64.into(), code_hash, - ..AccountRlp::default() + ..Account::default() }; // Initialize the state trie with three accounts. - let mut state_trie_before = HashedPartialTrie::from(Node::Empty); - state_trie_before.insert( - beneficiary_nibbles, - rlp::encode(&beneficiary_account_before).to_vec(), - ); - state_trie_before.insert(sender_nibbles, rlp::encode(&sender_account_before).to_vec()); - state_trie_before.insert(to_nibbles, rlp::encode(&to_account_before).to_vec()); + let state_smt_before = Smt::new([ + (beneficiary_bits, beneficiary_account_before.clone().into()), + (sender_bits, sender_account_before.clone().into()), + (to_bits, to_account_before.clone().into()), + ]) + .unwrap(); // We now add two receipts with logs and data. This updates the receipt trie as well. let log_0 = LogRlp { @@ -121,7 +122,7 @@ fn test_log_opcodes() -> anyhow::Result<()> { ); let tries_before = TrieInputs { - state_trie: state_trie_before, + state_smt: state_smt_before.serialize(), transactions_trie: Node::Empty.into(), receipts_trie: receipts_trie.clone(), storage_tries: vec![(to_hashed, Node::Empty.into())], @@ -150,21 +151,21 @@ fn test_log_opcodes() -> anyhow::Result<()> { // Update the state and receipt tries after the transaction, so that we have the correct expected tries: // Update accounts - let beneficiary_account_after = AccountRlp { - nonce: 1.into(), - ..AccountRlp::default() + let beneficiary_account_after = Account { + nonce: 1, + ..Account::default() }; let sender_balance_after = sender_balance_before - gas_used * txn_gas_price; - let sender_account_after = AccountRlp { + let sender_account_after = Account { balance: sender_balance_after.into(), - nonce: 1.into(), - ..AccountRlp::default() + nonce: 1, + ..Account::default() }; - let to_account_after = AccountRlp { + let to_account_after = Account { balance: 9000000000u64.into(), code_hash, - ..AccountRlp::default() + ..Account::default() }; // Update the receipt trie. @@ -195,13 +196,12 @@ fn test_log_opcodes() -> anyhow::Result<()> { receipts_trie.insert(receipt_nibbles, rlp::encode(&receipt).to_vec()); // Update the state trie. 
- let mut expected_state_trie_after = HashedPartialTrie::from(Node::Empty); - expected_state_trie_after.insert( - beneficiary_nibbles, - rlp::encode(&beneficiary_account_after).to_vec(), - ); - expected_state_trie_after.insert(sender_nibbles, rlp::encode(&sender_account_after).to_vec()); - expected_state_trie_after.insert(to_nibbles, rlp::encode(&to_account_after).to_vec()); + let expected_state_smt_after = Smt::new([ + (beneficiary_bits, beneficiary_account_after.into()), + (sender_bits, sender_account_after.into()), + (to_bits, to_account_after.into()), + ]) + .unwrap(); let transactions_trie: HashedPartialTrie = Node::Leaf { nibbles: Nibbles::from_str("0x80").unwrap(), @@ -210,7 +210,7 @@ fn test_log_opcodes() -> anyhow::Result<()> { .into(); let trie_roots_after = TrieRoots { - state_root: expected_state_trie_after.hash(), + state_root: expected_state_smt_after.root, transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; @@ -255,7 +255,7 @@ fn test_log_opcodes() -> anyhow::Result<()> { // Assert that the proof leads to the correct state and receipt roots. assert_eq!( proof.public_values.trie_roots_after.state_root, - expected_state_trie_after.hash() + expected_state_smt_after.root, ); assert_eq!( @@ -303,46 +303,42 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { let to_hashed = keccak(to_first); let to_hashed_2 = keccak(to); - let beneficiary_nibbles = Nibbles::from_bytes_be(beneficiary_state_key.as_bytes()).unwrap(); - let sender_nibbles = Nibbles::from_bytes_be(sender_state_key.as_bytes()).unwrap(); - let to_nibbles = Nibbles::from_bytes_be(to_hashed.as_bytes()).unwrap(); - let to_second_nibbles = Nibbles::from_bytes_be(to_hashed_2.as_bytes()).unwrap(); + let beneficiary_bits = beneficiary_state_key.into(); + let sender_bits = sender_state_key.into(); + let to_bits = to_hashed.into(); + let to_second_bits = to_hashed_2.into(); - let beneficiary_account_before = AccountRlp { - nonce: 1.into(), - ..AccountRlp::default() + let beneficiary_account_before = Account { + nonce: 1, + ..Account::default() }; let sender_balance_before = 1000000000000000000u64.into(); - let sender_account_before = AccountRlp { + let sender_account_before = Account { balance: sender_balance_before, - ..AccountRlp::default() + ..Account::default() }; - let to_account_before = AccountRlp { - ..AccountRlp::default() + let to_account_before = Account { + ..Account::default() }; - let to_account_second_before = AccountRlp { + let to_account_second_before = Account { code_hash, - ..AccountRlp::default() + ..Account::default() }; // In the first transaction, the sender account sends `txn_value` to `to_account`. 
let gas_price = 10; let txn_value = 0xau64; - let mut state_trie_before = HashedPartialTrie::from(Node::Empty); - state_trie_before.insert( - beneficiary_nibbles, - rlp::encode(&beneficiary_account_before).to_vec(), - ); - state_trie_before.insert(sender_nibbles, rlp::encode(&sender_account_before).to_vec()); - state_trie_before.insert(to_nibbles, rlp::encode(&to_account_before).to_vec()); - state_trie_before.insert( - to_second_nibbles, - rlp::encode(&to_account_second_before).to_vec(), - ); - let genesis_state_trie_root = state_trie_before.hash(); + let state_smt_before = Smt::new([ + (beneficiary_bits, beneficiary_account_before.clone().into()), + (sender_bits, sender_account_before.clone().into()), + (to_bits, to_account_before.clone().into()), + (to_second_bits, to_account_second_before.clone().into()), + ]) + .unwrap(); + let genesis_state_trie_root = state_smt_before.root; let tries_before = TrieInputs { - state_trie: state_trie_before, + state_smt: state_smt_before.serialize(), transactions_trie: Node::Empty.into(), receipts_trie: Node::Empty.into(), storage_tries: vec![], @@ -378,37 +374,33 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { block_random: Default::default(), }; - let beneficiary_account_after = AccountRlp { - nonce: 1.into(), - ..AccountRlp::default() + let beneficiary_account_after = Account { + nonce: 1, + ..Account::default() }; let sender_balance_after = sender_balance_before - gas_price * 21000 - txn_value; - let sender_account_after = AccountRlp { + let sender_account_after = Account { balance: sender_balance_after, - nonce: 1.into(), - ..AccountRlp::default() + nonce: 1, + ..Account::default() }; - let to_account_after = AccountRlp { + let to_account_after = Account { balance: txn_value.into(), - ..AccountRlp::default() + ..Account::default() }; let mut contract_code = HashMap::new(); contract_code.insert(keccak(vec![]), vec![]); contract_code.insert(code_hash, code.to_vec()); - let mut expected_state_trie_after = HashedPartialTrie::from(Node::Empty); - expected_state_trie_after.insert( - beneficiary_nibbles, - rlp::encode(&beneficiary_account_after).to_vec(), - ); - expected_state_trie_after.insert(sender_nibbles, rlp::encode(&sender_account_after).to_vec()); - expected_state_trie_after.insert(to_nibbles, rlp::encode(&to_account_after).to_vec()); - expected_state_trie_after.insert( - to_second_nibbles, - rlp::encode(&to_account_second_before).to_vec(), - ); + let expected_state_smt_after = Smt::new([ + (beneficiary_bits, beneficiary_account_after.into()), + (sender_bits, sender_account_after.into()), + (to_bits, to_account_after.clone().into()), + (to_second_bits, to_account_second_before.clone().into()), + ]) + .unwrap(); // Compute new receipt trie. let mut receipts_trie = HashedPartialTrie::from(Node::Empty); @@ -430,7 +422,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { .into(); let tries_after = TrieRoots { - state_root: expected_state_trie_after.hash(), + state_root: expected_state_smt_after.root, transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.clone().hash(), }; @@ -474,10 +466,10 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { // Prove second transaction. In this second transaction, the code with logs is executed. 
- let state_trie_before = expected_state_trie_after; + let state_trie_before = expected_state_smt_after; let tries_before = TrieInputs { - state_trie: state_trie_before, + state_smt: state_trie_before.serialize(), transactions_trie: transactions_trie.clone(), receipts_trie: receipts_trie.clone(), storage_tries: vec![], @@ -493,26 +485,26 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { // Update the state and receipt tries after the transaction, so that we have the correct expected tries: // Update accounts. - let beneficiary_account_after = AccountRlp { - nonce: 1.into(), - ..AccountRlp::default() + let beneficiary_account_after = Account { + nonce: 1, + ..Account::default() }; let sender_balance_after = sender_balance_after - gas_used * txn_gas_price; - let sender_account_after = AccountRlp { + let sender_account_after = Account { balance: sender_balance_after, - nonce: 2.into(), - ..AccountRlp::default() + nonce: 2, + ..Account::default() }; let balance_after = to_account_after.balance; - let to_account_after = AccountRlp { + let to_account_after = Account { balance: balance_after, - ..AccountRlp::default() + ..Account::default() }; - let to_account_second_after = AccountRlp { + let to_account_second_after = Account { balance: to_account_second_before.balance, code_hash, - ..AccountRlp::default() + ..Account::default() }; // Update the receipt trie. @@ -542,23 +534,19 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { receipts_trie.insert(receipt_nibbles, rlp::encode(&receipt).to_vec()); - // Update the state trie. - let mut expected_state_trie_after = HashedPartialTrie::from(Node::Empty); - expected_state_trie_after.insert( - beneficiary_nibbles, - rlp::encode(&beneficiary_account_after).to_vec(), - ); - expected_state_trie_after.insert(sender_nibbles, rlp::encode(&sender_account_after).to_vec()); - expected_state_trie_after.insert(to_nibbles, rlp::encode(&to_account_after).to_vec()); - expected_state_trie_after.insert( - to_second_nibbles, - rlp::encode(&to_account_second_after).to_vec(), - ); + // Update the state SMT. + let expected_state_trie_after = Smt::new([ + (beneficiary_bits, beneficiary_account_after.into()), + (sender_bits, sender_account_after.into()), + (to_bits, to_account_after.into()), + (to_second_bits, to_account_second_after.into()), + ]) + .unwrap(); transactions_trie.insert(Nibbles::from_str("0x01").unwrap(), txn_2.to_vec()); let trie_roots_after = TrieRoots { - state_root: expected_state_trie_after.hash(), + state_root: expected_state_trie_after.root, transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; @@ -767,36 +755,35 @@ fn test_two_txn() -> anyhow::Result<()> { let sender_state_key = keccak(sender); let to_hashed = keccak(to); - let beneficiary_nibbles = Nibbles::from_bytes_be(beneficiary_state_key.as_bytes()).unwrap(); - let sender_nibbles = Nibbles::from_bytes_be(sender_state_key.as_bytes()).unwrap(); - let to_nibbles = Nibbles::from_bytes_be(to_hashed.as_bytes()).unwrap(); + let beneficiary_bits = beneficiary_state_key.into(); + let sender_bits = sender_state_key.into(); + let to_bits = to_hashed.into(); // Set accounts before the transaction. 
- let beneficiary_account_before = AccountRlp { - nonce: 1.into(), - ..AccountRlp::default() + let beneficiary_account_before = Account { + nonce: 1, + ..Account::default() }; let sender_balance_before = 50000000000000000u64; - let sender_account_before = AccountRlp { + let sender_account_before = Account { balance: sender_balance_before.into(), - ..AccountRlp::default() + ..Account::default() }; - let to_account_before = AccountRlp { - ..AccountRlp::default() + let to_account_before = Account { + ..Account::default() }; // Initialize the state trie with three accounts. - let mut state_trie_before = HashedPartialTrie::from(Node::Empty); - state_trie_before.insert( - beneficiary_nibbles, - rlp::encode(&beneficiary_account_before).to_vec(), - ); - state_trie_before.insert(sender_nibbles, rlp::encode(&sender_account_before).to_vec()); - state_trie_before.insert(to_nibbles, rlp::encode(&to_account_before).to_vec()); + let state_smt_before = Smt::new([ + (beneficiary_bits, beneficiary_account_before.clone().into()), + (sender_bits, sender_account_before.clone().into()), + (to_bits, to_account_before.clone().into()), + ]) + .unwrap(); let tries_before = TrieInputs { - state_trie: state_trie_before, + state_smt: state_smt_before.serialize(), transactions_trie: Node::Empty.into(), receipts_trie: Node::Empty.into(), storage_tries: vec![(to_hashed, Node::Empty.into())], @@ -825,30 +812,29 @@ fn test_two_txn() -> anyhow::Result<()> { contract_code.insert(keccak(vec![]), vec![]); // Update accounts - let beneficiary_account_after = AccountRlp { - nonce: 1.into(), - ..AccountRlp::default() + let beneficiary_account_after = Account { + nonce: 1, + ..Account::default() }; let sender_balance_after = sender_balance_before - gas_price * 21000 * 2 - txn_value * 2; - let sender_account_after = AccountRlp { + let sender_account_after = Account { balance: sender_balance_after.into(), - nonce: 2.into(), - ..AccountRlp::default() + nonce: 2, + ..Account::default() }; - let to_account_after = AccountRlp { + let to_account_after = Account { balance: (2 * txn_value).into(), - ..AccountRlp::default() + ..Account::default() }; // Update the state trie. - let mut expected_state_trie_after = HashedPartialTrie::from(Node::Empty); - expected_state_trie_after.insert( - beneficiary_nibbles, - rlp::encode(&beneficiary_account_after).to_vec(), - ); - expected_state_trie_after.insert(sender_nibbles, rlp::encode(&sender_account_after).to_vec()); - expected_state_trie_after.insert(to_nibbles, rlp::encode(&to_account_after).to_vec()); + let expected_state_smt_after = Smt::new([ + (beneficiary_bits, beneficiary_account_after.into()), + (sender_bits, sender_account_after.into()), + (to_bits, to_account_after.into()), + ]) + .unwrap(); // Compute new receipt trie. let mut receipts_trie = HashedPartialTrie::from(Node::Empty); @@ -886,7 +872,7 @@ fn test_two_txn() -> anyhow::Result<()> { transactions_trie.insert(Nibbles::from_str("0x01").unwrap(), txn_1.to_vec()); let trie_roots_after = TrieRoots { - state_root: expected_state_trie_after.hash(), + state_root: expected_state_smt_after.root, transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; @@ -916,7 +902,7 @@ fn test_two_txn() -> anyhow::Result<()> { // Assert trie roots. 
assert_eq!( proof.public_values.trie_roots_after.state_root, - expected_state_trie_after.hash() + expected_state_smt_after.root ); assert_eq!( diff --git a/evm/tests/many_transactions.rs b/evm/tests/many_transactions.rs index 9678d652d3..7d9bfae572 100644 --- a/evm/tests/many_transactions.rs +++ b/evm/tests/many_transactions.rs @@ -16,12 +16,14 @@ use plonky2::util::timing::TimingTree; use plonky2_evm::all_stark::AllStark; use plonky2_evm::config::StarkConfig; use plonky2_evm::cpu::kernel::opcodes::{get_opcode, get_push_opcode}; -use plonky2_evm::generation::mpt::{AccountRlp, LegacyReceiptRlp}; +use plonky2_evm::generation::mpt::LegacyReceiptRlp; use plonky2_evm::generation::{GenerationInputs, TrieInputs}; use plonky2_evm::proof::{BlockHashes, BlockMetadata, TrieRoots}; use plonky2_evm::prover::prove; use plonky2_evm::verifier::verify_proof; use plonky2_evm::Node; +use smt_utils::account::Account; +use smt_utils::smt::Smt; type F = GoldilocksField; const D: usize = 2; @@ -43,9 +45,9 @@ fn test_four_transactions() -> anyhow::Result<()> { let sender_state_key = keccak(sender); let to_state_key = keccak(to); - let beneficiary_nibbles = Nibbles::from_bytes_be(beneficiary_state_key.as_bytes()).unwrap(); - let sender_nibbles = Nibbles::from_bytes_be(sender_state_key.as_bytes()).unwrap(); - let to_nibbles = Nibbles::from_bytes_be(to_state_key.as_bytes()).unwrap(); + let beneficiary_bits = beneficiary_state_key.into(); + let sender_bits = sender_state_key.into(); + let to_bits = to_state_key.into(); let push1 = get_push_opcode(1); let add = get_opcode("ADD"); @@ -54,42 +56,26 @@ fn test_four_transactions() -> anyhow::Result<()> { let code_gas = 3 + 3 + 3; let code_hash = keccak(code); - let beneficiary_account_before = AccountRlp::default(); - let sender_account_before = AccountRlp { - nonce: 5.into(), - + let beneficiary_account_before = Account::default(); + let sender_account_before = Account { + nonce: 5, balance: eth_to_wei(100_000.into()), - - ..AccountRlp::default() + ..Account::default() }; - let to_account_before = AccountRlp { + let to_account_before = Account { code_hash, - ..AccountRlp::default() + ..Account::default() }; - let state_trie_before = { - let mut children = core::array::from_fn(|_| Node::Empty.into()); - children[sender_nibbles.get_nibble(0) as usize] = Node::Leaf { - nibbles: sender_nibbles.truncate_n_nibbles_front(1), - - value: rlp::encode(&sender_account_before).to_vec(), - } - .into(); - children[to_nibbles.get_nibble(0) as usize] = Node::Leaf { - nibbles: to_nibbles.truncate_n_nibbles_front(1), - - value: rlp::encode(&to_account_before).to_vec(), - } - .into(); - Node::Branch { - children, - value: vec![], - } - } - .into(); + let state_smt_before = Smt::new([ + (beneficiary_bits, beneficiary_account_before.clone().into()), + (sender_bits, sender_account_before.clone().into()), + (to_bits, to_account_before.clone().into()), + ]) + .unwrap(); let tries_before = TrieInputs { - state_trie: state_trie_before, + state_smt: state_smt_before.serialize(), transactions_trie: Node::Empty.into(), receipts_trie: Node::Empty.into(), storage_tries: vec![], @@ -123,46 +109,27 @@ fn test_four_transactions() -> anyhow::Result<()> { // Update trie roots after the 4 transactions. // State trie. 
- let expected_state_trie_after: HashedPartialTrie = { - let beneficiary_account_after = AccountRlp { + let expected_state_trie_after = { + let beneficiary_account_after = Account { balance: beneficiary_account_before.balance + gas_used * 10, ..beneficiary_account_before }; - let sender_account_after = AccountRlp { + let sender_account_after = Account { balance: sender_account_before.balance - value - gas_used * 10, nonce: sender_account_before.nonce + 1, ..sender_account_before }; - let to_account_after = AccountRlp { + let to_account_after = Account { balance: to_account_before.balance + value, ..to_account_before }; - - let mut children = core::array::from_fn(|_| Node::Empty.into()); - children[beneficiary_nibbles.get_nibble(0) as usize] = Node::Leaf { - nibbles: beneficiary_nibbles.truncate_n_nibbles_front(1), - - value: rlp::encode(&beneficiary_account_after).to_vec(), - } - .into(); - children[sender_nibbles.get_nibble(0) as usize] = Node::Leaf { - nibbles: sender_nibbles.truncate_n_nibbles_front(1), - - value: rlp::encode(&sender_account_after).to_vec(), - } - .into(); - children[to_nibbles.get_nibble(0) as usize] = Node::Leaf { - nibbles: to_nibbles.truncate_n_nibbles_front(1), - - value: rlp::encode(&to_account_after).to_vec(), - } - .into(); - Node::Branch { - children, - value: vec![], - } - } - .into(); + Smt::new([ + (beneficiary_bits, beneficiary_account_after.into()), + (sender_bits, sender_account_after.into()), + (to_bits, to_account_after.into()), + ]) + .unwrap() + }; // Transactions trie. let mut transactions_trie: HashedPartialTrie = Node::Leaf { @@ -206,7 +173,7 @@ fn test_four_transactions() -> anyhow::Result<()> { ); let trie_roots_after = TrieRoots { - state_root: expected_state_trie_after.hash(), + state_root: expected_state_trie_after.root, transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; diff --git a/evm/tests/self_balance_gas_cost.rs b/evm/tests/self_balance_gas_cost.rs index 4492ba9af4..fdcd04b05d 100644 --- a/evm/tests/self_balance_gas_cost.rs +++ b/evm/tests/self_balance_gas_cost.rs @@ -13,12 +13,14 @@ use plonky2::plonk::config::KeccakGoldilocksConfig; use plonky2::util::timing::TimingTree; use plonky2_evm::all_stark::AllStark; use plonky2_evm::config::StarkConfig; -use plonky2_evm::generation::mpt::{AccountRlp, LegacyReceiptRlp}; +use plonky2_evm::generation::mpt::LegacyReceiptRlp; use plonky2_evm::generation::{GenerationInputs, TrieInputs}; use plonky2_evm::proof::{BlockHashes, BlockMetadata, TrieRoots}; use plonky2_evm::prover::prove; use plonky2_evm::verifier::verify_proof; use plonky2_evm::Node; +use smt_utils::account::Account; +use smt_utils::smt::Smt; type F = GoldilocksField; const D: usize = 2; @@ -41,9 +43,9 @@ fn self_balance_gas_cost() -> anyhow::Result<()> { let sender_state_key = keccak(sender); let to_hashed = keccak(to); - let beneficiary_nibbles = Nibbles::from_bytes_be(beneficiary_state_key.as_bytes()).unwrap(); - let sender_nibbles = Nibbles::from_bytes_be(sender_state_key.as_bytes()).unwrap(); - let to_nibbles = Nibbles::from_bytes_be(to_hashed.as_bytes()).unwrap(); + let beneficiary_bits = beneficiary_state_key.into(); + let sender_bits = sender_state_key.into(); + let to_bits = to_hashed.into(); let code = [ 0x5a, 0x47, 0x5a, 0x90, 0x50, 0x90, 0x03, 0x60, 0x02, 0x90, 0x03, 0x60, 0x01, 0x55, 0x00, @@ -62,29 +64,28 @@ fn self_balance_gas_cost() -> anyhow::Result<()> { + 22100; // SSTORE let code_hash = keccak(code); - let beneficiary_account_before = AccountRlp { - nonce: 1.into(), - 
..AccountRlp::default() + let beneficiary_account_before = Account { + nonce: 1, + ..Account::default() }; - let sender_account_before = AccountRlp { + let sender_account_before = Account { balance: 0x3635c9adc5dea00000u128.into(), - ..AccountRlp::default() + ..Account::default() }; - let to_account_before = AccountRlp { + let to_account_before = Account { code_hash, - ..AccountRlp::default() + ..Account::default() }; - let mut state_trie_before = HashedPartialTrie::from(Node::Empty); - state_trie_before.insert( - beneficiary_nibbles, - rlp::encode(&beneficiary_account_before).to_vec(), - ); - state_trie_before.insert(sender_nibbles, rlp::encode(&sender_account_before).to_vec()); - state_trie_before.insert(to_nibbles, rlp::encode(&to_account_before).to_vec()); + let state_smt_before = Smt::new([ + (beneficiary_bits, beneficiary_account_before.clone().into()), + (sender_bits, sender_account_before.clone().into()), + (to_bits, to_account_before.clone().into()), + ]) + .unwrap(); let tries_before = TrieInputs { - state_trie: state_trie_before, + state_smt: state_smt_before.serialize(), transactions_trie: Node::Empty.into(), receipts_trie: Node::Empty.into(), storage_tries: vec![(to_hashed, Node::Empty.into())], @@ -112,39 +113,36 @@ fn self_balance_gas_cost() -> anyhow::Result<()> { contract_code.insert(code_hash, code.to_vec()); let expected_state_trie_after = { - let beneficiary_account_after = AccountRlp { - nonce: 1.into(), - ..AccountRlp::default() + let beneficiary_account_after = Account { + nonce: 1, + ..Account::default() }; - let sender_account_after = AccountRlp { + let sender_account_after = Account { balance: sender_account_before.balance - U256::from(gas_used) * U256::from(10), - nonce: 1.into(), - ..AccountRlp::default() + nonce: 1, + ..Account::default() }; - let to_account_after = AccountRlp { + let to_account_after = Account { code_hash, // Storage map: { 1 => 5 } - storage_root: HashedPartialTrie::from(Node::Leaf { - // TODO: Could do keccak(pad32(1)) - nibbles: Nibbles::from_str( - "0xb10e2d527612073b26eecdfd717e6a320cf44b4afac2b0732d9fcbe2b7fa0cf6", + storage_smt: Smt::new([( + U256::from_str( + "0xb10e2d527612073b26eecdfd717e6a320cf44b4afac2b0732d9fcbe2b7fa0cf6", // keccak(pad(1)) ) - .unwrap(), - value: vec![5], - }) - .hash(), - ..AccountRlp::default() + .unwrap() + .into(), + U256::from(5).into(), + )]) + .unwrap(), + ..Account::default() }; - let mut expected_state_trie_after = HashedPartialTrie::from(Node::Empty); - expected_state_trie_after.insert( - beneficiary_nibbles, - rlp::encode(&beneficiary_account_after).to_vec(), - ); - expected_state_trie_after - .insert(sender_nibbles, rlp::encode(&sender_account_after).to_vec()); - expected_state_trie_after.insert(to_nibbles, rlp::encode(&to_account_after).to_vec()); - expected_state_trie_after + Smt::new([ + (beneficiary_bits, beneficiary_account_after.into()), + (sender_bits, sender_account_after.into()), + (to_bits, to_account_after.into()), + ]) + .unwrap() }; let receipt_0 = LegacyReceiptRlp { @@ -165,7 +163,7 @@ fn self_balance_gas_cost() -> anyhow::Result<()> { .into(); let trie_roots_after = TrieRoots { - state_root: expected_state_trie_after.hash(), + state_root: expected_state_trie_after.root, transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; diff --git a/evm/tests/selfdestruct.rs b/evm/tests/selfdestruct.rs new file mode 100644 index 0000000000..88af15aca7 --- /dev/null +++ b/evm/tests/selfdestruct.rs @@ -0,0 +1,161 @@ +use std::str::FromStr; +use std::time::Duration; + +use 
env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV}; +use eth_trie_utils::nibbles::Nibbles; +use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; +use ethereum_types::{Address, BigEndianHash, H256, U256}; +use hex_literal::hex; +use keccak_hash::keccak; +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2::plonk::config::KeccakGoldilocksConfig; +use plonky2::util::timing::TimingTree; +use plonky2_evm::all_stark::AllStark; +use plonky2_evm::config::StarkConfig; +use plonky2_evm::generation::mpt::LegacyReceiptRlp; +use plonky2_evm::generation::{GenerationInputs, TrieInputs}; +use plonky2_evm::proof::{BlockHashes, BlockMetadata, TrieRoots}; +use plonky2_evm::prover::prove; +use plonky2_evm::verifier::verify_proof; +use plonky2_evm::Node; +use smt_utils::account::Account; +use smt_utils::bits::Bits; +use smt_utils::smt::Smt; + +type F = GoldilocksField; +const D: usize = 2; +type C = KeccakGoldilocksConfig; + +/// Test a simple selfdestruct. +#[ignore] +#[test] +fn test_selfdestruct() -> anyhow::Result<()> { + init_logger(); + + let all_stark = AllStark::::default(); + let config = StarkConfig::standard_fast_config(); + + let beneficiary = hex!("deadbeefdeadbeefdeadbeefdeadbeefdeadbeef"); + let sender = hex!("5eb96AA102a29fAB267E12A40a5bc6E9aC088759"); + let to = hex!("a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0"); + + let sender_state_key = keccak(sender); + let to_state_key = keccak(to); + + let sender_bits = Bits::from(sender_state_key); + let to_bits = Bits::from(to_state_key); + + let sender_account_before = Account { + nonce: 5, + balance: eth_to_wei(100_000.into()), + code_hash: keccak([]), + storage_smt: Smt::empty(), + }; + let code = vec![ + 0x32, // ORIGIN + 0xFF, // SELFDESTRUCT + ]; + let to_account_before = Account { + nonce: 12, + balance: eth_to_wei(10_000.into()), + code_hash: keccak(&code), + storage_smt: Smt::empty(), + }; + + let state_trie_before = Smt::new([ + (sender_bits, sender_account_before.into()), + (to_bits, to_account_before.into()), + ]) + .unwrap(); + + let tries_before = TrieInputs { + state_smt: state_trie_before.serialize(), + transactions_trie: HashedPartialTrie::from(Node::Empty), + receipts_trie: HashedPartialTrie::from(Node::Empty), + storage_tries: vec![], + }; + + // Generated using a little py-evm script. 
+    let txn = hex!("f868050a831e848094a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0880de0b6b3a76400008025a09bab8db7d72e4b42cba8b117883e16872966bae8e4570582de6ed0065e8c36a1a01256d44d982c75e0ab7a19f61ab78afa9e089d51c8686fdfbee085a5ed5d8ff8");
+
+    let block_metadata = BlockMetadata {
+        block_beneficiary: Address::from(beneficiary),
+        block_timestamp: 0x03e8.into(),
+        block_number: 1.into(),
+        block_difficulty: 0x020000.into(),
+        block_random: H256::from_uint(&0x020000.into()),
+        block_gaslimit: 0xff112233u32.into(),
+        block_chain_id: 1.into(),
+        block_base_fee: 0xa.into(),
+        block_gas_used: 26002.into(),
+        block_bloom: [0.into(); 8],
+    };
+
+    let contract_code = [(keccak(&code), code), (keccak([]), vec![])].into();
+
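+    // The sender ends with its original 100_000 ETH plus `to`'s 10_000 ETH: the
+    // 1 ETH sent by the transaction comes straight back when SELFDESTRUCT credits
+    // `to`'s whole balance to ORIGIN (the sender). Gas: 21_000 intrinsic + 2
+    // (ORIGIN) + 5_000 (SELFDESTRUCT) = 26_002 units at 10 wei each.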
+    let expected_state_trie_after = {
+        let sender_account_after = Account {
+            nonce: 6,
+            balance: eth_to_wei(110_000.into()) - 26_002 * 0xa,
+            code_hash: keccak([]),
+            storage_smt: Smt::empty(),
+        };
+        Smt::new([(sender_bits, sender_account_after.into())]).unwrap()
+    };
+
+    let receipt_0 = LegacyReceiptRlp {
+        status: true,
+        cum_gas_used: 26002.into(),
+        bloom: vec![0; 256].into(),
+        logs: vec![],
+    };
+    let mut receipts_trie = HashedPartialTrie::from(Node::Empty);
+    receipts_trie.insert(
+        Nibbles::from_str("0x80").unwrap(),
+        rlp::encode(&receipt_0).to_vec(),
+    );
+    let transactions_trie: HashedPartialTrie = Node::Leaf {
+        nibbles: Nibbles::from_str("0x80").unwrap(),
+        value: txn.to_vec(),
+    }
+    .into();
+
+    let trie_roots_after = TrieRoots {
+        state_root: expected_state_trie_after.root,
+        transactions_root: transactions_trie.hash(),
+        receipts_root: receipts_trie.hash(),
+    };
+    let inputs = GenerationInputs {
+        signed_txns: vec![txn.to_vec()],
+        tries: tries_before,
+        trie_roots_after,
+        contract_code,
+        genesis_state_trie_root: HashedPartialTrie::from(Node::Empty).hash(),
+        block_metadata,
+        txn_number_before: 0.into(),
+        gas_used_before: 0.into(),
+        gas_used_after: 26002.into(),
+        block_bloom_before: [0.into(); 8],
+        block_bloom_after: [0.into(); 8],
+        block_hashes: BlockHashes {
+            prev_hashes: vec![H256::default(); 256],
+            cur_hash: H256::default(),
+        },
+        addresses: vec![],
+    };
+
+    let mut timing = TimingTree::new("prove", log::Level::Debug);
+    let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing)?;
+    timing.filter(Duration::from_millis(100)).print();
+
+    verify_proof(&all_stark, proof, &config)
+}
+
+fn eth_to_wei(eth: U256) -> U256 {
+    // 1 ether = 10^18 wei.
+    eth * U256::from(10).pow(18.into())
+}
+
+fn init_logger() {
+    let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "info"));
+}
diff --git a/evm/tests/simple_transfer.rs b/evm/tests/simple_transfer.rs
index 80bee8afeb..b4fc49166e 100644
--- a/evm/tests/simple_transfer.rs
+++ b/evm/tests/simple_transfer.rs
@@ -13,12 +13,14 @@ use plonky2::plonk::config::KeccakGoldilocksConfig;
 use plonky2::util::timing::TimingTree;
 use plonky2_evm::all_stark::AllStark;
 use plonky2_evm::config::StarkConfig;
-use plonky2_evm::generation::mpt::{AccountRlp, LegacyReceiptRlp};
+use plonky2_evm::generation::mpt::LegacyReceiptRlp;
 use plonky2_evm::generation::{GenerationInputs, TrieInputs};
 use plonky2_evm::proof::{BlockHashes, BlockMetadata, TrieRoots};
 use plonky2_evm::prover::prove;
 use plonky2_evm::verifier::verify_proof;
 use plonky2_evm::Node;
+use smt_utils::account::Account;
+use smt_utils::smt::Smt;

 type F = GoldilocksField;
 const D: usize = 2;
@@ -39,25 +41,25 @@ fn test_simple_transfer() -> anyhow::Result<()> {
     let sender_state_key = keccak(sender);
     let to_state_key = keccak(to);

-    let sender_nibbles = Nibbles::from_bytes_be(sender_state_key.as_bytes()).unwrap();
-    let to_nibbles = Nibbles::from_bytes_be(to_state_key.as_bytes()).unwrap();
+    let sender_bits = sender_state_key.into();
+    let to_bits = to_state_key.into();

-    let sender_account_before = AccountRlp {
-        nonce: 5.into(),
+    let sender_account_before = Account {
+        nonce: 5,
         balance: eth_to_wei(100_000.into()),
-        storage_root: HashedPartialTrie::from(Node::Empty).hash(),
+        storage_smt: Smt::empty(),
         code_hash: keccak([]),
     };
-    let to_account_before = AccountRlp::default();
+    let to_account_before = Account::default();

-    let state_trie_before = Node::Leaf {
-        nibbles: sender_nibbles,
-        value: rlp::encode(&sender_account_before).to_vec(),
-    }
-    .into();
+    let state_smt_before = Smt::new([
+        (sender_bits, sender_account_before.clone().into()),
+        (to_bits, to_account_before.clone().into()),
+    ])
+    .unwrap();

     let tries_before = TrieInputs {
-        state_trie: state_trie_before,
+        state_smt: state_smt_before.serialize(),
         transactions_trie: HashedPartialTrie::from(Node::Empty),
         receipts_trie: HashedPartialTrie::from(Node::Empty),
         storage_tries: vec![],
@@ -83,36 +85,25 @@ fn test_simple_transfer() -> anyhow::Result<()> {
     let mut contract_code = HashMap::new();
     contract_code.insert(keccak(vec![]), vec![]);

-    let expected_state_trie_after: HashedPartialTrie = {
+    let expected_state_smt_after: Smt = {
         let txdata_gas = 2 * 16;
         let gas_used = 21_000 + txdata_gas;

-        let sender_account_after = AccountRlp {
+        let sender_account_after = Account {
             balance: sender_account_before.balance - value - gas_used * 10,
             nonce: sender_account_before.nonce + 1,
             ..sender_account_before
         };
-        let to_account_after = AccountRlp {
+        let to_account_after = Account {
             balance: value,
             ..to_account_before
         };

-        let mut children = core::array::from_fn(|_| Node::Empty.into());
-        children[sender_nibbles.get_nibble(0) as usize] = Node::Leaf {
-            nibbles: sender_nibbles.truncate_n_nibbles_front(1),
-            value: rlp::encode(&sender_account_after).to_vec(),
-        }
-        .into();
-        children[to_nibbles.get_nibble(0) as usize] = Node::Leaf {
-            nibbles: to_nibbles.truncate_n_nibbles_front(1),
-            value: rlp::encode(&to_account_after).to_vec(),
-        }
-        .into();
-        Node::Branch {
-            children,
-            value: vec![],
-        }
-        .into()
+        Smt::new([
+            (sender_bits, sender_account_after.clone().into()),
+            (to_bits, to_account_after.clone().into()),
+        ])
+        .unwrap()
     };

     let receipt_0 = LegacyReceiptRlp {
@@ -133,7 +124,7 @@ fn test_simple_transfer() -> anyhow::Result<()> {
         .into();

     let trie_roots_after = TrieRoots {
-        state_root: expected_state_trie_after.hash(),
+        state_root: expected_state_smt_after.root,
         transactions_root: transactions_trie.hash(),
         receipts_root: receipts_trie.hash(),
     };
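For readers following the MPT-to-SMT migration above, the recurring pattern of keying each account by keccak(address), inserting the pairs into an `Smt`, and reading back `root` can be factored out. Below is a minimal sketch, assuming the `smt_utils` API exercised in this diff (`Bits::from`, fallible `Smt::new`, the `root` field); the helper name `state_root_for` is hypothetical and not part of the patch:

```rust
use ethereum_types::H256;
use keccak_hash::keccak;
use smt_utils::account::Account;
use smt_utils::bits::Bits;
use smt_utils::smt::Smt;

/// Hypothetical helper: builds a state SMT from (address, account) pairs and
/// returns its root, mirroring the pattern used by the tests in this diff.
fn state_root_for(accounts: &[([u8; 20], Account)]) -> H256 {
    let kvs: Vec<_> = accounts
        .iter()
        .map(|(addr, account)| {
            // Accounts are keyed by keccak(address), viewed as a bit string.
            (Bits::from(keccak(addr)), account.clone().into())
        })
        .collect();
    Smt::new(kvs).unwrap().root
}
```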