diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index b0f3056cf6..068b0bcbf9 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -6,6 +6,7 @@ use plonky2::hash::hash_types::RichField; use crate::arithmetic::arithmetic_stark; use crate::arithmetic::arithmetic_stark::ArithmeticStark; +use crate::byte_packing::byte_packing_stark::{self, BytePackingStark}; use crate::config::StarkConfig; use crate::cpu::cpu_stark; use crate::cpu::cpu_stark::CpuStark; @@ -25,6 +26,7 @@ use crate::stark::Stark; #[derive(Clone)] pub struct AllStark, const D: usize> { pub arithmetic_stark: ArithmeticStark, + pub byte_packing_stark: BytePackingStark, pub cpu_stark: CpuStark, pub keccak_stark: KeccakStark, pub keccak_sponge_stark: KeccakSpongeStark, @@ -37,6 +39,7 @@ impl, const D: usize> Default for AllStark { fn default() -> Self { Self { arithmetic_stark: ArithmeticStark::default(), + byte_packing_stark: BytePackingStark::default(), cpu_stark: CpuStark::default(), keccak_stark: KeccakStark::default(), keccak_sponge_stark: KeccakSpongeStark::default(), @@ -51,6 +54,7 @@ impl, const D: usize> AllStark { pub(crate) fn nums_permutation_zs(&self, config: &StarkConfig) -> [usize; NUM_TABLES] { [ self.arithmetic_stark.num_permutation_batches(config), + self.byte_packing_stark.num_permutation_batches(config), self.cpu_stark.num_permutation_batches(config), self.keccak_stark.num_permutation_batches(config), self.keccak_sponge_stark.num_permutation_batches(config), @@ -62,6 +66,7 @@ impl, const D: usize> AllStark { pub(crate) fn permutation_batch_sizes(&self) -> [usize; NUM_TABLES] { [ self.arithmetic_stark.permutation_batch_size(), + self.byte_packing_stark.permutation_batch_size(), self.cpu_stark.permutation_batch_size(), self.keccak_stark.permutation_batch_size(), self.keccak_sponge_stark.permutation_batch_size(), @@ -74,11 +79,12 @@ impl, const D: usize> AllStark { #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum Table { Arithmetic = 0, - Cpu = 1, - Keccak = 2, - KeccakSponge = 3, - Logic = 4, - Memory = 5, + BytePacking = 1, + Cpu = 2, + Keccak = 3, + KeccakSponge = 4, + Logic = 5, + Memory = 6, } pub(crate) const NUM_TABLES: usize = Table::Memory as usize + 1; @@ -87,6 +93,7 @@ impl Table { pub(crate) fn all() -> [Self; NUM_TABLES] { [ Self::Arithmetic, + Self::BytePacking, Self::Cpu, Self::Keccak, Self::KeccakSponge, @@ -99,6 +106,7 @@ impl Table { pub(crate) fn all_cross_table_lookups() -> Vec> { vec![ ctl_arithmetic(), + ctl_byte_packing(), ctl_keccak_sponge(), ctl_keccak(), ctl_logic(), @@ -116,6 +124,28 @@ fn ctl_arithmetic() -> CrossTableLookup { ) } +fn ctl_byte_packing() -> CrossTableLookup { + let cpu_packing_looking = TableWithColumns::new( + Table::Cpu, + cpu_stark::ctl_data_byte_packing(), + Some(cpu_stark::ctl_filter_byte_packing()), + ); + let cpu_unpacking_looking = TableWithColumns::new( + Table::Cpu, + cpu_stark::ctl_data_byte_unpacking(), + Some(cpu_stark::ctl_filter_byte_unpacking()), + ); + let byte_packing_looked = TableWithColumns::new( + Table::BytePacking, + byte_packing_stark::ctl_looked_data(), + Some(byte_packing_stark::ctl_looked_filter()), + ); + CrossTableLookup::new( + vec![cpu_packing_looking, cpu_unpacking_looking], + byte_packing_looked, + ) +} + fn ctl_keccak() -> CrossTableLookup { let keccak_sponge_looking = TableWithColumns::new( Table::KeccakSponge, @@ -184,9 +214,17 @@ fn ctl_memory() -> CrossTableLookup { Some(keccak_sponge_stark::ctl_looking_memory_filter(i)), ) }); + let byte_packing_ops = (0..32).map(|i| { + TableWithColumns::new( + 
+            Table::BytePacking,
+            byte_packing_stark::ctl_looking_memory(i),
+            Some(byte_packing_stark::ctl_looking_memory_filter(i)),
+        )
+    });
     let all_lookers = iter::once(cpu_memory_code_read)
         .chain(cpu_memory_gp_ops)
         .chain(keccak_sponge_reads)
+        .chain(byte_packing_ops)
         .collect();
     let memory_looked = TableWithColumns::new(
         Table::Memory,
diff --git a/evm/src/byte_packing/byte_packing_stark.rs b/evm/src/byte_packing/byte_packing_stark.rs
new file mode 100644
index 0000000000..f97a2b28ab
--- /dev/null
+++ b/evm/src/byte_packing/byte_packing_stark.rs
@@ -0,0 +1,607 @@
+//! This module enforces the correctness of reading and writing sequences
+//! of bytes in Big-Endian ordering from and to the memory.
+//!
+//! The trace layout consists of `N` consecutive rows for an `N`-byte sequence,
+//! with the byte values being cumulatively written to the trace as they are
+//! processed.
+//!
+//! At row `i` of such a group (starting from 0), the `i`-th byte flag is activated
+//! (to indicate which byte is being processed), but all bytes with index
+//! 0 to `i` may have non-zero values, as they have already been processed.
+//!
+//! The length of a sequence is stored within each group of rows corresponding to that
+//! sequence in a dedicated `SEQUENCE_LEN` column. At any row `i`, the remaining length
+//! of the sequence being processed is retrieved from that column and the active byte flag
+//! as:
+//!
+//!    remaining_length = sequence_length - \sum_{i=0}^{31} b[i] * (i + 1)
+//!
+//! where `b[i]` is the `i`-th byte flag.
+//!
+//! Because of the discrepancy in endianness between the different tables, the byte sequences
+//! are written to the trace in the reverse of the order in which they are provided.
+//! As such, the memory virtual address for a group of rows corresponding to a sequence starts
+//! at the final virtual address, corresponding to the final byte being read/written, and
+//! is decremented at each step.
+//!
+//! Note that, when writing a sequence of bytes to memory, both the `U256` value and the
+//! corresponding sequence length are read from the stack. Because of the endianness
+//! discrepancy mentioned above, we first convert the value to a byte sequence in Little-Endian,
+//! then resize the sequence to prune unneeded zeros before reversing the sequence order.
+//! This means that the higher-order bytes are thrown away in the process if the value
+//! does not fit in `length` bytes, and as a result a different value will be stored in memory.
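As a concrete illustration of the write-path conversion described in the last paragraph of this doc comment, here is a minimal standalone sketch. It uses the `ethereum_types::U256` API that the rest of this patch relies on; the helper name `value_to_be_bytes` is illustrative and not part of the patch:

```rust
use ethereum_types::U256;

// Convert a U256 `value` into the `len`-byte Big-Endian sequence that gets
// written to memory: encode in Little-Endian, truncate to `len` bytes
// (dropping higher-order bytes if the value does not fit), then reverse.
fn value_to_be_bytes(value: U256, len: usize) -> Vec<u8> {
    let mut bytes = vec![0u8; 32];
    value.to_little_endian(&mut bytes);
    bytes.resize(len, 0); // keep only the `len` low-order bytes
    bytes.reverse(); // Big-Endian order, as stored in memory
    bytes
}

fn main() {
    let v = U256::from(0x0102030405u64);
    // The value fits in 5 bytes, so nothing is lost.
    assert_eq!(value_to_be_bytes(v, 5), vec![0x01, 0x02, 0x03, 0x04, 0x05]);
    // Truncation: only the two low-order bytes survive.
    assert_eq!(value_to_be_bytes(v, 2), vec![0x04, 0x05]);
}
```

This mirrors the logic of `run_mstore_32bytes` and `byte_unpacking_log` further down in the patch.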
+ +use std::marker::PhantomData; + +use itertools::Itertools; +use plonky2::field::extension::{Extendable, FieldExtension}; +use plonky2::field::packed::PackedField; +use plonky2::field::polynomial::PolynomialValues; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::timed; +use plonky2::util::timing::TimingTree; +use plonky2::util::transpose; + +use super::NUM_BYTES; +use crate::byte_packing::columns::{ + index_bytes, value_bytes, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, BYTE_INDICES_COLS, IS_READ, + NUM_COLUMNS, RANGE_COUNTER, RC_COLS, SEQUENCE_END, SEQUENCE_LEN, TIMESTAMP, +}; +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cross_table_lookup::Column; +use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols}; +use crate::stark::Stark; +use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; +use crate::witness::memory::MemoryAddress; + +/// Strict upper bound for the individual bytes range-check. +const BYTE_RANGE_MAX: usize = 1usize << 8; + +pub(crate) fn ctl_looked_data() -> Vec> { + // Reconstruct the u32 limbs composing the final `U256` word + // being read/written from the underlying byte values. For each, + // we pack 4 consecutive bytes and shift them accordingly to + // obtain the corresponding limb. + let outputs: Vec> = (0..8) + .map(|i| { + let range = (value_bytes(i * 4)..value_bytes(i * 4) + 4).collect_vec(); + Column::linear_combination( + range + .iter() + .enumerate() + .map(|(j, &c)| (c, F::from_canonical_u64(1 << (8 * j)))), + ) + }) + .collect(); + + Column::singles([ + ADDR_CONTEXT, + ADDR_SEGMENT, + ADDR_VIRTUAL, + SEQUENCE_LEN, + TIMESTAMP, + ]) + .chain(outputs) + .collect() +} + +pub fn ctl_looked_filter() -> Column { + // The CPU table is only interested in our sequence end rows, + // since those contain the final limbs of our packed int. + Column::single(SEQUENCE_END) +} + +pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { + let mut res = + Column::singles([IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]).collect_vec(); + + // The i'th input byte being read/written. + res.push(Column::single(value_bytes(i))); + + // Since we're reading a single byte, the higher limbs must be zero. + res.extend((1..8).map(|_| Column::zero())); + + res.push(Column::single(TIMESTAMP)); + + res +} + +/// CTL filter for reading/writing the `i`th byte of the byte sequence from/to memory. +pub(crate) fn ctl_looking_memory_filter(i: usize) -> Column { + Column::single(index_bytes(i)) +} + +/// Information about a byte packing operation needed for witness generation. +#[derive(Clone, Debug)] +pub(crate) struct BytePackingOp { + /// Whether this is a read (packing) or write (unpacking) operation. + pub(crate) is_read: bool, + + /// The base address at which inputs are read/written. + pub(crate) base_address: MemoryAddress, + + /// The timestamp at which inputs are read/written. + pub(crate) timestamp: usize, + + /// The byte sequence that was read/written. + /// Its length is required to be at most 32. + pub(crate) bytes: Vec, +} + +#[derive(Copy, Clone, Default)] +pub struct BytePackingStark { + pub(crate) f: PhantomData, +} + +impl, const D: usize> BytePackingStark { + pub(crate) fn generate_trace( + &self, + ops: Vec, + min_rows: usize, + timing: &mut TimingTree, + ) -> Vec> { + // Generate most of the trace in row-major form. 
+ let trace_rows = timed!( + timing, + "generate trace rows", + self.generate_trace_rows(ops, min_rows) + ); + let trace_row_vecs: Vec<_> = trace_rows.into_iter().map(|row| row.to_vec()).collect(); + + let mut trace_cols = transpose(&trace_row_vecs); + self.generate_range_checks(&mut trace_cols); + + trace_cols.into_iter().map(PolynomialValues::new).collect() + } + + fn generate_trace_rows( + &self, + ops: Vec, + min_rows: usize, + ) -> Vec<[F; NUM_COLUMNS]> { + let base_len: usize = ops.iter().map(|op| op.bytes.len()).sum(); + let num_rows = core::cmp::max(base_len.max(BYTE_RANGE_MAX), min_rows).next_power_of_two(); + let mut rows = Vec::with_capacity(num_rows); + + for op in ops { + rows.extend(self.generate_rows_for_op(op)); + } + + for _ in rows.len()..num_rows { + rows.push(self.generate_padding_row()); + } + + rows + } + + fn generate_rows_for_op(&self, op: BytePackingOp) -> Vec<[F; NUM_COLUMNS]> { + let BytePackingOp { + is_read, + base_address, + timestamp, + bytes, + } = op; + + let MemoryAddress { + context, + segment, + virt, + } = base_address; + + let mut rows = Vec::with_capacity(bytes.len()); + let mut row = [F::ZERO; NUM_COLUMNS]; + row[IS_READ] = F::from_bool(is_read); + + row[ADDR_CONTEXT] = F::from_canonical_usize(context); + row[ADDR_SEGMENT] = F::from_canonical_usize(segment); + // Because of the endianness, we start by the final virtual address value + // and decrement it at each step. Similarly, we process the byte sequence + // in reverse order. + row[ADDR_VIRTUAL] = F::from_canonical_usize(virt + bytes.len() - 1); + + row[TIMESTAMP] = F::from_canonical_usize(timestamp); + row[SEQUENCE_LEN] = F::from_canonical_usize(bytes.len()); + + for (i, &byte) in bytes.iter().rev().enumerate() { + if i == bytes.len() - 1 { + row[SEQUENCE_END] = F::ONE; + } + row[value_bytes(i)] = F::from_canonical_u8(byte); + row[index_bytes(i)] = F::ONE; + + rows.push(row.into()); + row[index_bytes(i)] = F::ZERO; + row[ADDR_VIRTUAL] -= F::ONE; + } + + rows + } + + fn generate_padding_row(&self) -> [F; NUM_COLUMNS] { + [F::ZERO; NUM_COLUMNS] + } + + /// Expects input in *column*-major layout + fn generate_range_checks(&self, cols: &mut Vec>) { + debug_assert!(cols.len() == NUM_COLUMNS); + + let n_rows = cols[0].len(); + debug_assert!(cols.iter().all(|col| col.len() == n_rows)); + + for i in 0..BYTE_RANGE_MAX { + cols[RANGE_COUNTER][i] = F::from_canonical_usize(i); + } + for i in BYTE_RANGE_MAX..n_rows { + cols[RANGE_COUNTER][i] = F::from_canonical_usize(BYTE_RANGE_MAX - 1); + } + + // For each column c in cols, generate the range-check + // permutations and put them in the corresponding range-check + // columns rc_c and rc_c+1. + for (i, rc_c) in (0..NUM_BYTES).zip(RC_COLS.step_by(2)) { + let c = value_bytes(i); + let (col_perm, table_perm) = permuted_cols(&cols[c], &cols[RANGE_COUNTER]); + cols[rc_c].copy_from_slice(&col_perm); + cols[rc_c + 1].copy_from_slice(&table_perm); + } + } + + /// There is only one `i` for which `vars.local_values[index_bytes(i)]` is non-zero, + /// and `i+1` is the current position: + fn get_active_position(&self, row: &[P; NUM_COLUMNS]) -> P + where + FE: FieldExtension, + P: PackedField, + { + (0..NUM_BYTES) + .map(|i| row[index_bytes(i)] * P::Scalar::from_canonical_usize(i + 1)) + .sum() + } + + /// Recursive version of `get_active_position`. 
+    fn get_active_position_circuit(
+        &self,
+        builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
+        row: &[ExtensionTarget<D>; NUM_COLUMNS],
+    ) -> ExtensionTarget<D> {
+        let mut current_position = row[index_bytes(0)];
+
+        for i in 1..NUM_BYTES {
+            current_position = builder.mul_const_add_extension(
+                F::from_canonical_usize(i + 1),
+                row[index_bytes(i)],
+                current_position,
+            );
+        }
+
+        current_position
+    }
+}
+
+impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for BytePackingStark<F, D> {
+    const COLUMNS: usize = NUM_COLUMNS;
+
+    fn eval_packed_generic<FE, P, const D2: usize>(
+        &self,
+        vars: StarkEvaluationVars<FE, P, { Self::COLUMNS }>,
+        yield_constr: &mut ConstraintConsumer<P>,
+    ) where
+        FE: FieldExtension<D2, BaseField = F>,
+        P: PackedField<Scalar = FE>,
+    {
+        // Range check all the columns
+        for col in RC_COLS.step_by(2) {
+            eval_lookups(vars, yield_constr, col, col + 1);
+        }
+
+        let one = P::ONES;
+
+        // We filter active columns by summing all the byte indices.
+        // Constraining each of them to be boolean is done later on below.
+        let current_filter = vars.local_values[BYTE_INDICES_COLS]
+            .iter()
+            .copied()
+            .sum::<P>();
+        yield_constr.constraint(current_filter * (current_filter - one));
+
+        // The filter column must start with one.
+        yield_constr.constraint_first_row(current_filter - one);
+
+        // The is_read flag must be boolean.
+        let current_is_read = vars.local_values[IS_READ];
+        yield_constr.constraint(current_is_read * (current_is_read - one));
+
+        // Each byte index must be boolean.
+        for i in 0..NUM_BYTES {
+            let idx_i = vars.local_values[index_bytes(i)];
+            yield_constr.constraint(idx_i * (idx_i - one));
+        }
+
+        // The sequence start flag column must start with one.
+        let current_sequence_start = vars.local_values[index_bytes(0)];
+        yield_constr.constraint_first_row(current_sequence_start - one);
+
+        // The sequence end flag must be boolean.
+        let current_sequence_end = vars.local_values[SEQUENCE_END];
+        yield_constr.constraint(current_sequence_end * (current_sequence_end - one));
+
+        // If filter is off, all flags and byte indices must be off.
+        let byte_indices = vars.local_values[BYTE_INDICES_COLS]
+            .iter()
+            .copied()
+            .sum::<P>();
+        yield_constr.constraint(
+            (current_filter - one) * (current_is_read + current_sequence_end + byte_indices),
+        );
+
+        // Only padding rows have their filter turned off.
+        let next_filter = vars.next_values[BYTE_INDICES_COLS]
+            .iter()
+            .copied()
+            .sum::<P>();
+        yield_constr.constraint_transition(next_filter * (next_filter - current_filter));
+
+        // Unless the current sequence end flag is activated, the is_read filter must remain unchanged.
+        let next_is_read = vars.next_values[IS_READ];
+        yield_constr
+            .constraint_transition((current_sequence_end - one) * (next_is_read - current_is_read));
+
+        // If the sequence end flag is activated, the next row must be a new sequence or filter must be off.
+        let next_sequence_start = vars.next_values[index_bytes(0)];
+        yield_constr.constraint_transition(
+            current_sequence_end * next_filter * (next_sequence_start - one),
+        );
+
+        // The remaining length of a byte sequence must decrease by one or be zero.
+        let current_sequence_length = vars.local_values[SEQUENCE_LEN];
+        let current_position = self.get_active_position(vars.local_values);
+        let next_position = self.get_active_position(vars.next_values);
+        let current_remaining_length = current_sequence_length - current_position;
+        let next_sequence_length = vars.next_values[SEQUENCE_LEN];
+        let next_remaining_length = next_sequence_length - next_position;
+        yield_constr.constraint_transition(
+            current_remaining_length * (current_remaining_length - next_remaining_length - one),
+        );
+
+        // At the start of a sequence, the remaining length must be equal to the starting length minus one.
+        yield_constr.constraint(
+            current_sequence_start * (current_sequence_length - current_remaining_length - one),
+        );
+
+        // The remaining length on the last row must be zero.
+        yield_constr.constraint_last_row(current_remaining_length);
+
+        // If the current remaining length is zero, the end flag must be one.
+        yield_constr.constraint(current_remaining_length * current_sequence_end);
+
+        // The context, segment and timestamp fields must remain unchanged throughout a byte sequence.
+        // The virtual address must decrement by one at each step of a sequence.
+        let current_context = vars.local_values[ADDR_CONTEXT];
+        let next_context = vars.next_values[ADDR_CONTEXT];
+        let current_segment = vars.local_values[ADDR_SEGMENT];
+        let next_segment = vars.next_values[ADDR_SEGMENT];
+        let current_virtual = vars.local_values[ADDR_VIRTUAL];
+        let next_virtual = vars.next_values[ADDR_VIRTUAL];
+        let current_timestamp = vars.local_values[TIMESTAMP];
+        let next_timestamp = vars.next_values[TIMESTAMP];
+        yield_constr.constraint_transition(
+            next_filter * (next_sequence_start - one) * (next_context - current_context),
+        );
+        yield_constr.constraint_transition(
+            next_filter * (next_sequence_start - one) * (next_segment - current_segment),
+        );
+        yield_constr.constraint_transition(
+            next_filter * (next_sequence_start - one) * (next_timestamp - current_timestamp),
+        );
+        yield_constr.constraint_transition(
+            next_filter * (next_sequence_start - one) * (current_virtual - next_virtual - one),
+        );
+
+        // If not at the end of a sequence, each next byte must equal the current one
+        // when reading through the sequence, or the next byte index must be one.
+ for i in 0..NUM_BYTES { + let current_byte = vars.local_values[value_bytes(i)]; + let next_byte = vars.next_values[value_bytes(i)]; + let next_byte_index = vars.next_values[index_bytes(i)]; + yield_constr.constraint_transition( + (current_sequence_end - one) * (next_byte_index - one) * (next_byte - current_byte), + ); + } + } + + fn eval_ext_circuit( + &self, + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, + ) { + // Range check all the columns + for col in RC_COLS.step_by(2) { + eval_lookups_circuit(builder, vars, yield_constr, col, col + 1); + } + + // We filter active columns by summing all the byte indices. + // Constraining each of them to be boolean is done later on below. + let current_filter = builder.add_many_extension(&vars.local_values[BYTE_INDICES_COLS]); + let constraint = builder.mul_sub_extension(current_filter, current_filter, current_filter); + yield_constr.constraint(builder, constraint); + + // The filter column must start by one. + let constraint = builder.add_const_extension(current_filter, F::NEG_ONE); + yield_constr.constraint_first_row(builder, constraint); + + // The is_read flag must be boolean. + let current_is_read = vars.local_values[IS_READ]; + let constraint = + builder.mul_sub_extension(current_is_read, current_is_read, current_is_read); + yield_constr.constraint(builder, constraint); + + // Each byte index must be boolean. + for i in 0..NUM_BYTES { + let idx_i = vars.local_values[index_bytes(i)]; + let constraint = builder.mul_sub_extension(idx_i, idx_i, idx_i); + yield_constr.constraint(builder, constraint); + } + + // The sequence start flag column must start by one. + let current_sequence_start = vars.local_values[index_bytes(0)]; + let constraint = builder.add_const_extension(current_sequence_start, F::NEG_ONE); + yield_constr.constraint_first_row(builder, constraint); + + // The sequence end flag must be boolean + let current_sequence_end = vars.local_values[SEQUENCE_END]; + let constraint = builder.mul_sub_extension( + current_sequence_end, + current_sequence_end, + current_sequence_end, + ); + yield_constr.constraint(builder, constraint); + + // If filter is off, all flags and byte indices must be off. + let byte_indices = builder.add_many_extension(&vars.local_values[BYTE_INDICES_COLS]); + let constraint = builder.add_extension(current_sequence_end, byte_indices); + let constraint = builder.add_extension(constraint, current_is_read); + let constraint = builder.mul_sub_extension(constraint, current_filter, constraint); + yield_constr.constraint(builder, constraint); + + // Only padding rows have their filter turned off. + let next_filter = builder.add_many_extension(&vars.next_values[BYTE_INDICES_COLS]); + let constraint = builder.sub_extension(next_filter, current_filter); + let constraint = builder.mul_extension(next_filter, constraint); + yield_constr.constraint_transition(builder, constraint); + + // Unless the current sequence end flag is activated, the is_read filter must remain unchanged. + let next_is_read = vars.next_values[IS_READ]; + let diff_is_read = builder.sub_extension(next_is_read, current_is_read); + let constraint = + builder.mul_sub_extension(diff_is_read, current_sequence_end, diff_is_read); + yield_constr.constraint_transition(builder, constraint); + + // If the sequence end flag is activated, the next row must be a new sequence or filter must be off. 
+ let next_sequence_start = vars.next_values[index_bytes(0)]; + let constraint = builder.mul_sub_extension( + current_sequence_end, + next_sequence_start, + current_sequence_end, + ); + let constraint = builder.mul_extension(next_filter, constraint); + yield_constr.constraint_transition(builder, constraint); + + // The remaining length of a byte sequence must decrease by one or be zero. + let current_sequence_length = vars.local_values[SEQUENCE_LEN]; + let next_sequence_length = vars.next_values[SEQUENCE_LEN]; + let current_position = self.get_active_position_circuit(builder, vars.local_values); + let next_position = self.get_active_position_circuit(builder, vars.next_values); + + let current_remaining_length = + builder.sub_extension(current_sequence_length, current_position); + let next_remaining_length = builder.sub_extension(next_sequence_length, next_position); + let length_diff = builder.sub_extension(current_remaining_length, next_remaining_length); + let constraint = builder.mul_sub_extension( + current_remaining_length, + length_diff, + current_remaining_length, + ); + yield_constr.constraint_transition(builder, constraint); + + // At the start of a sequence, the remaining length must be equal to the starting length minus one + let current_sequence_length = vars.local_values[SEQUENCE_LEN]; + let length_diff = builder.sub_extension(current_sequence_length, current_remaining_length); + let constraint = + builder.mul_sub_extension(current_sequence_start, length_diff, current_sequence_start); + yield_constr.constraint(builder, constraint); + + // The remaining length on the last row must be zero. + yield_constr.constraint_last_row(builder, current_remaining_length); + + // If the current remaining length is zero, the end flag must be one. + let constraint = builder.mul_extension(current_remaining_length, current_sequence_end); + yield_constr.constraint(builder, constraint); + + // The context, segment and timestamp fields must remain unchanged throughout a byte sequence. + // The virtual address must decrement by one at each step of a sequence. 
+ let current_context = vars.local_values[ADDR_CONTEXT]; + let next_context = vars.next_values[ADDR_CONTEXT]; + let current_segment = vars.local_values[ADDR_SEGMENT]; + let next_segment = vars.next_values[ADDR_SEGMENT]; + let current_virtual = vars.local_values[ADDR_VIRTUAL]; + let next_virtual = vars.next_values[ADDR_VIRTUAL]; + let current_timestamp = vars.local_values[TIMESTAMP]; + let next_timestamp = vars.next_values[TIMESTAMP]; + let addr_filter = builder.mul_sub_extension(next_filter, next_sequence_start, next_filter); + { + let constraint = builder.sub_extension(next_context, current_context); + let constraint = builder.mul_extension(addr_filter, constraint); + yield_constr.constraint_transition(builder, constraint); + } + { + let constraint = builder.sub_extension(next_segment, current_segment); + let constraint = builder.mul_extension(addr_filter, constraint); + yield_constr.constraint_transition(builder, constraint); + } + { + let constraint = builder.sub_extension(next_timestamp, current_timestamp); + let constraint = builder.mul_extension(addr_filter, constraint); + yield_constr.constraint_transition(builder, constraint); + } + { + let constraint = builder.sub_extension(current_virtual, next_virtual); + let constraint = builder.mul_sub_extension(addr_filter, constraint, addr_filter); + yield_constr.constraint_transition(builder, constraint); + } + + // If not at the end of a sequence, each next byte must equal the current one + // when reading through the sequence, or the next byte index must be one. + for i in 0..NUM_BYTES { + let current_byte = vars.local_values[value_bytes(i)]; + let next_byte = vars.next_values[value_bytes(i)]; + let next_byte_index = vars.next_values[index_bytes(i)]; + let byte_diff = builder.sub_extension(next_byte, current_byte); + let constraint = builder.mul_sub_extension(byte_diff, next_byte_index, byte_diff); + let constraint = + builder.mul_sub_extension(constraint, current_sequence_end, constraint); + yield_constr.constraint_transition(builder, constraint); + } + } + + fn constraint_degree(&self) -> usize { + 3 + } +} + +#[cfg(test)] +pub(crate) mod tests { + use anyhow::Result; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + use crate::byte_packing::byte_packing_stark::BytePackingStark; + use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; + + #[test] + fn test_stark_degree() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = BytePackingStark; + + let stark = S { + f: Default::default(), + }; + test_stark_low_degree(stark) + } + + #[test] + fn test_stark_circuit() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = BytePackingStark; + + let stark = S { + f: Default::default(), + }; + test_stark_circuit_constraints::(stark) + } +} diff --git a/evm/src/byte_packing/columns.rs b/evm/src/byte_packing/columns.rs new file mode 100644 index 0000000000..f04f450c51 --- /dev/null +++ b/evm/src/byte_packing/columns.rs @@ -0,0 +1,48 @@ +//! Byte packing registers. + +use core::ops::Range; + +use crate::byte_packing::NUM_BYTES; + +/// 1 if this is a READ operation, and 0 if this is a WRITE operation. +pub(crate) const IS_READ: usize = 0; +/// 1 if this is the end of a sequence of bytes. +/// This is also used as filter for the CTL. 
+pub(crate) const SEQUENCE_END: usize = IS_READ + 1; + +pub(super) const BYTES_INDICES_START: usize = SEQUENCE_END + 1; +pub(crate) const fn index_bytes(i: usize) -> usize { + debug_assert!(i < NUM_BYTES); + BYTES_INDICES_START + i +} + +// Note: Those are used as filter for distinguishing active vs padding rows. +pub(crate) const BYTE_INDICES_COLS: Range = + BYTES_INDICES_START..BYTES_INDICES_START + NUM_BYTES; + +pub(crate) const ADDR_CONTEXT: usize = BYTES_INDICES_START + NUM_BYTES; +pub(crate) const ADDR_SEGMENT: usize = ADDR_CONTEXT + 1; +pub(crate) const ADDR_VIRTUAL: usize = ADDR_SEGMENT + 1; +pub(crate) const TIMESTAMP: usize = ADDR_VIRTUAL + 1; + +/// The total length of a sequence of bytes. +/// Cannot be greater than 32. +pub(crate) const SEQUENCE_LEN: usize = TIMESTAMP + 1; + +// 32 byte limbs hold a total of 256 bits. +const BYTES_VALUES_START: usize = SEQUENCE_LEN + 1; +pub(crate) const fn value_bytes(i: usize) -> usize { + debug_assert!(i < NUM_BYTES); + BYTES_VALUES_START + i +} + +// We need one column for the table, then two columns for every value +// that needs to be range checked in the trace (all written bytes), +// namely the permutation of the column and the permutation of the range. +// The two permutations associated to the byte in column i will be in +// columns RC_COLS[2i] and RC_COLS[2i+1]. +pub(crate) const RANGE_COUNTER: usize = BYTES_VALUES_START + NUM_BYTES; +pub(crate) const NUM_RANGE_CHECK_COLS: usize = 1 + 2 * NUM_BYTES; +pub(crate) const RC_COLS: Range = RANGE_COUNTER + 1..RANGE_COUNTER + NUM_RANGE_CHECK_COLS; + +pub(crate) const NUM_COLUMNS: usize = RANGE_COUNTER + NUM_RANGE_CHECK_COLS; diff --git a/evm/src/byte_packing/mod.rs b/evm/src/byte_packing/mod.rs new file mode 100644 index 0000000000..7cc93374ca --- /dev/null +++ b/evm/src/byte_packing/mod.rs @@ -0,0 +1,9 @@ +//! Byte packing / unpacking unit for the EVM. +//! +//! This module handles reading / writing to memory byte sequences of +//! length at most 32 in Big-Endian ordering. 
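As a cross-check of the remaining-length bookkeeping documented in `byte_packing_stark.rs` above, here is a minimal sketch; the `remaining_length` helper is hypothetical and exists only to illustrate the formula:

```rust
// With one-hot byte flags b[0..32] (flag i is set on the row processing
// byte i), the STARK tracks:
//   remaining_length = sequence_length - sum_{i=0}^{31} b[i] * (i + 1)
fn remaining_length(sequence_length: usize, byte_flags: [bool; 32]) -> usize {
    let position: usize = byte_flags
        .iter()
        .enumerate()
        .map(|(i, &b)| usize::from(b) * (i + 1))
        .sum();
    sequence_length - position
}

fn main() {
    // A 3-byte sequence: the remaining length goes 2, 1, 0 across its rows,
    // hitting zero exactly on the row where SEQUENCE_END is set.
    for row in 0..3 {
        let mut flags = [false; 32];
        flags[row] = true;
        assert_eq!(remaining_length(3, flags), 2 - row);
    }
}
```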
+ +pub mod byte_packing_stark; +pub mod columns; + +pub(crate) const NUM_BYTES: usize = 32; diff --git a/evm/src/cpu/columns/ops.rs b/evm/src/cpu/columns/ops.rs index 81d8414af6..6c68a18305 100644 --- a/evm/src/cpu/columns/ops.rs +++ b/evm/src/cpu/columns/ops.rs @@ -41,6 +41,8 @@ pub struct OpsColumnsView { pub dup: T, pub swap: T, pub context_op: T, + pub mstore_32bytes: T, + pub mload_32bytes: T, pub exit_kernel: T, // TODO: combine MLOAD_GENERAL and MSTORE_GENERAL into one flag pub mload_general: T, diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 7fd0c76fcc..25e7cc6ba0 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -153,6 +153,40 @@ pub fn ctl_arithmetic_shift_rows() -> TableWithColumns { ) } +pub fn ctl_data_byte_packing() -> Vec> { + ctl_data_keccak_sponge() +} + +pub fn ctl_filter_byte_packing() -> Column { + Column::single(COL_MAP.op.mload_32bytes) +} + +pub fn ctl_data_byte_unpacking() -> Vec> { + // When executing MSTORE_32BYTES, the GP memory channels are used as follows: + // GP channel 0: stack[-1] = context + // GP channel 1: stack[-2] = segment + // GP channel 2: stack[-3] = virt + // GP channel 3: stack[-4] = val + // GP channel 4: stack[-5] = len + let context = Column::single(COL_MAP.mem_channels[0].value[0]); + let segment = Column::single(COL_MAP.mem_channels[1].value[0]); + let virt = Column::single(COL_MAP.mem_channels[2].value[0]); + let val = Column::singles(COL_MAP.mem_channels[3].value); + let len = Column::single(COL_MAP.mem_channels[4].value[0]); + + let num_channels = F::from_canonical_usize(NUM_CHANNELS); + let timestamp = Column::linear_combination([(COL_MAP.clock, num_channels)]); + + let mut res = vec![context, segment, virt, len, timestamp]; + res.extend(val); + + res +} + +pub fn ctl_filter_byte_unpacking() -> Column { + Column::single(COL_MAP.op.mstore_32bytes) +} + pub const MEM_CODE_CHANNEL_IDX: usize = 0; pub const MEM_GP_CHANNELS_IDX_START: usize = MEM_CODE_CHANNEL_IDX + 1; diff --git a/evm/src/cpu/decode.rs b/evm/src/cpu/decode.rs index e21000f8e2..9a9c572387 100644 --- a/evm/src/cpu/decode.rs +++ b/evm/src/cpu/decode.rs @@ -22,7 +22,7 @@ use crate::cpu::columns::{CpuColumnsView, COL_MAP}; /// behavior. /// Note: invalid opcodes are not represented here. _Any_ opcode is permitted to decode to /// `is_invalid`. The kernel then verifies that the opcode was _actually_ invalid. 
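For reference, each entry of the `OPCODES` table below has the shape `(start, bits, kernel_only, flag)` and matches a power-of-two block of opcodes. A small sketch of that matching rule, as I read the table (the helper is hypothetical, not code from this patch):

```rust
// `bits` is the number of low opcode bits left unchecked, so an entry
// covers the 2^bits opcodes sharing the same top (8 - bits) bits.
fn entry_matches(start: u8, bits: u32, opcode: u8) -> bool {
    opcode >> bits == start >> bits
}

fn main() {
    assert!(entry_matches(0x60, 5, 0x7f)); // PUSH block: 0x60-0x7f
    assert!(entry_matches(0xee, 0, 0xee)); // MSTORE_32BYTES: exactly 0xee
    assert!(!entry_matches(0xee, 0, 0xf8)); // MLOAD_32BYTES has its own entry
}
```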
-const OPCODES: [(u8, usize, bool, usize); 31] = [ +const OPCODES: [(u8, usize, bool, usize); 33] = [ // (start index of block, number of top bits to check (log2), kernel-only, flag column) (0x01, 0, false, COL_MAP.op.add), (0x02, 0, false, COL_MAP.op.mul), @@ -49,10 +49,12 @@ const OPCODES: [(u8, usize, bool, usize); 31] = [ (0x58, 0, false, COL_MAP.op.pc), (0x5b, 0, false, COL_MAP.op.jumpdest), (0x5f, 0, false, COL_MAP.op.push0), - (0x60, 5, false, COL_MAP.op.push), // 0x60-0x7f - (0x80, 4, false, COL_MAP.op.dup), // 0x80-0x8f - (0x90, 4, false, COL_MAP.op.swap), // 0x90-0x9f + (0x60, 5, false, COL_MAP.op.push), // 0x60-0x7f + (0x80, 4, false, COL_MAP.op.dup), // 0x80-0x8f + (0x90, 4, false, COL_MAP.op.swap), // 0x90-0x9f + (0xee, 0, true, COL_MAP.op.mstore_32bytes), (0xf6, 1, true, COL_MAP.op.context_op), // 0xf6-0xf7 + (0xf8, 0, true, COL_MAP.op.mload_32bytes), (0xf9, 0, true, COL_MAP.op.exit_kernel), (0xfb, 0, true, COL_MAP.op.mload_general), (0xfc, 0, true, COL_MAP.op.mstore_general), diff --git a/evm/src/cpu/gas.rs b/evm/src/cpu/gas.rs index 616900052f..e967c07ece 100644 --- a/evm/src/cpu/gas.rs +++ b/evm/src/cpu/gas.rs @@ -49,6 +49,8 @@ const SIMPLE_OPCODES: OpsColumnsView> = OpsColumnsView { dup: G_VERYLOW, swap: G_VERYLOW, context_op: KERNEL_ONLY_INSTR, + mstore_32bytes: KERNEL_ONLY_INSTR, + mload_32bytes: KERNEL_ONLY_INSTR, exit_kernel: None, mload_general: KERNEL_ONLY_INSTR, mstore_general: KERNEL_ONLY_INSTR, diff --git a/evm/src/cpu/kernel/asm/memory/packing.asm b/evm/src/cpu/kernel/asm/memory/packing.asm index 0f8023352c..81ab31236e 100644 --- a/evm/src/cpu/kernel/asm/memory/packing.asm +++ b/evm/src/cpu/kernel/asm/memory/packing.asm @@ -7,40 +7,10 @@ // NOTE: addr: 3 denotes a (context, segment, virtual) tuple global mload_packing: // stack: addr: 3, len, retdest - DUP3 DUP3 DUP3 MLOAD_GENERAL DUP5 %eq_const(1) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(1) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(2) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(2) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(3) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(3) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(4) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(4) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(5) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(5) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(6) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(6) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(7) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(7) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(8) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(8) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(9) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(9) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(10) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(10) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(11) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(11) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(12) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(12) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(13) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(13) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(14) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(14) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(15) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(15) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(16) %jumpi(mload_packing_return) 
%shl_const(8) - DUP4 %add_const(16) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(17) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(17) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(18) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(18) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(19) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(19) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(20) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(20) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(21) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(21) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(22) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(22) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(23) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(23) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(24) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(24) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(25) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(25) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(26) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(26) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(27) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(27) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(28) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(28) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(29) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(29) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(30) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(30) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(31) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(31) DUP4 DUP4 MLOAD_GENERAL ADD -mload_packing_return: - %stack (packed_value, addr: 3, len, retdest) -> (retdest, packed_value) + MLOAD_32BYTES + // stack: packed_value, retdest + SWAP1 + // stack: retdest, packed_value JUMP %macro mload_packing @@ -72,40 +42,12 @@ global mload_packing_u64_LE: // Post stack: offset' global mstore_unpacking: // stack: context, segment, offset, value, len, retdest - // We will enumerate i in (32 - len)..32. - // That way BYTE(i, value) will give us the bytes we want. - DUP5 // len - PUSH 32 - SUB - -mstore_unpacking_loop: - // stack: i, context, segment, offset, value, len, retdest - // If i == 32, finish. - DUP1 - %eq_const(32) - %jumpi(mstore_unpacking_finish) - - // stack: i, context, segment, offset, value, len, retdest - DUP5 // value - DUP2 // i - BYTE - // stack: value[i], i, context, segment, offset, value, len, retdest - DUP5 DUP5 DUP5 // context, segment, offset - // stack: context, segment, offset, value[i], i, context, segment, offset, value, len, retdest - MSTORE_GENERAL - // stack: i, context, segment, offset, value, len, retdest - - // Increment offset. - SWAP3 %increment SWAP3 - // Increment i. 
- %increment - - %jump(mstore_unpacking_loop) - -mstore_unpacking_finish: - // stack: i, context, segment, offset, value, len, retdest - %pop3 - %stack (offset, value, len, retdest) -> (retdest, offset) + %stack(context, segment, offset, value, len, retdest) -> (context, segment, offset, value, len, len, offset, retdest) + // stack: context, segment, offset, value, len, len, offset, retdest + MSTORE_32BYTES + // stack: len, offset, retdest + ADD SWAP1 + // stack: retdest, offset' JUMP %macro mstore_unpacking diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 98ea3cc217..6039c51505 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -391,6 +391,7 @@ impl<'a> Interpreter<'a> { self.stack(), self.get_kernel_general_memory() ), // "PANIC", + 0xee => self.run_mstore_32bytes(), // "MSTORE_32BYTES", 0xf0 => todo!(), // "CREATE", 0xf1 => todo!(), // "CALL", 0xf2 => todo!(), // "CALLCODE", @@ -399,6 +400,7 @@ impl<'a> Interpreter<'a> { 0xf5 => todo!(), // "CREATE2", 0xf6 => self.run_get_context(), // "GET_CONTEXT", 0xf7 => self.run_set_context(), // "SET_CONTEXT", + 0xf8 => self.run_mload_32bytes(), // "MLOAD_32BYTES", 0xf9 => todo!(), // "EXIT_KERNEL", 0xfa => todo!(), // "STATICCALL", 0xfb => self.run_mload_general(), // "MLOAD_GENERAL", @@ -1024,8 +1026,7 @@ impl<'a> Interpreter<'a> { fn run_mload_general(&mut self) { let context = self.pop().as_usize(); let segment = Segment::all()[self.pop().as_usize()]; - let offset_u256 = self.pop(); - let offset = offset_u256.as_usize(); + let offset = self.pop().as_usize(); let value = self .generation_state .memory @@ -1034,6 +1035,23 @@ impl<'a> Interpreter<'a> { self.push(value); } + fn run_mload_32bytes(&mut self) { + let context = self.pop().as_usize(); + let segment = Segment::all()[self.pop().as_usize()]; + let offset = self.pop().as_usize(); + let len = self.pop().as_usize(); + let bytes: Vec = (0..len) + .map(|i| { + self.generation_state + .memory + .mload_general(context, segment, offset + i) + .as_u32() as u8 + }) + .collect(); + let value = U256::from_big_endian(&bytes); + self.push(value); + } + fn run_mstore_general(&mut self) { let context = self.pop().as_usize(); let segment = Segment::all()[self.pop().as_usize()]; @@ -1044,6 +1062,25 @@ impl<'a> Interpreter<'a> { .mstore_general(context, segment, offset, value); } + fn run_mstore_32bytes(&mut self) { + let context = self.pop().as_usize(); + let segment = Segment::all()[self.pop().as_usize()]; + let offset = self.pop().as_usize(); + let value = self.pop(); + let len = self.pop().as_usize(); + + let mut bytes = vec![0; 32]; + value.to_little_endian(&mut bytes); + bytes.resize(len, 0); + bytes.reverse(); + + for (i, &byte) in bytes.iter().enumerate() { + self.generation_state + .memory + .mstore_general(context, segment, offset + i, byte.into()); + } + } + fn stack_len(&self) -> usize { self.generation_state.registers.stack_len } @@ -1270,6 +1307,7 @@ fn get_mnemonic(opcode: u8) -> &'static str { 0xa3 => "LOG3", 0xa4 => "LOG4", 0xa5 => "PANIC", + 0xee => "MSTORE_32BYTES", 0xf0 => "CREATE", 0xf1 => "CALL", 0xf2 => "CALLCODE", @@ -1278,6 +1316,7 @@ fn get_mnemonic(opcode: u8) -> &'static str { 0xf5 => "CREATE2", 0xf6 => "GET_CONTEXT", 0xf7 => "SET_CONTEXT", + 0xf8 => "MLOAD_32BYTES", 0xf9 => "EXIT_KERNEL", 0xfa => "STATICCALL", 0xfb => "MLOAD_GENERAL", diff --git a/evm/src/cpu/kernel/opcodes.rs b/evm/src/cpu/kernel/opcodes.rs index 09c493e04f..2503a92e74 100644 --- a/evm/src/cpu/kernel/opcodes.rs +++ 
b/evm/src/cpu/kernel/opcodes.rs @@ -113,6 +113,7 @@ pub fn get_opcode(mnemonic: &str) -> u8 { "LOG3" => 0xa3, "LOG4" => 0xa4, "PANIC" => 0xa5, + "MSTORE_32BYTES" => 0xee, "CREATE" => 0xf0, "CALL" => 0xf1, "CALLCODE" => 0xf2, @@ -121,6 +122,7 @@ pub fn get_opcode(mnemonic: &str) -> u8 { "CREATE2" => 0xf5, "GET_CONTEXT" => 0xf6, "SET_CONTEXT" => 0xf7, + "MLOAD_32BYTES" => 0xf8, "EXIT_KERNEL" => 0xf9, "STATICCALL" => 0xfa, "MLOAD_GENERAL" => 0xfb, diff --git a/evm/src/cpu/stack.rs b/evm/src/cpu/stack.rs index 8ffc152d4c..cfeaa1b0b5 100644 --- a/evm/src/cpu/stack.rs +++ b/evm/src/cpu/stack.rs @@ -108,6 +108,16 @@ const STACK_BEHAVIORS: OpsColumnsView> = OpsColumnsView { dup: None, swap: None, context_op: None, // SET_CONTEXT is special since it involves the old and the new stack. + mstore_32bytes: Some(StackBehavior { + num_pops: 5, + pushes: false, + disable_other_channels: false, + }), + mload_32bytes: Some(StackBehavior { + num_pops: 4, + pushes: true, + disable_other_channels: false, + }), exit_kernel: Some(StackBehavior { num_pops: 1, pushes: false, diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index 8f481325c6..a9b90428ca 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -536,10 +536,13 @@ pub(crate) fn verify_cross_table_lookups, const D: config: &StarkConfig, ) -> Result<()> { let mut ctl_zs_openings = ctl_zs_lasts.iter().map(|v| v.iter()).collect::>(); - for CrossTableLookup { - looking_tables, - looked_table, - } in cross_table_lookups.iter() + for ( + index, + CrossTableLookup { + looking_tables, + looked_table, + }, + ) in cross_table_lookups.iter().enumerate() { let extra_product_vec = &ctl_extra_looking_products[looked_table.table as usize]; for c in 0..config.num_challenges { @@ -552,7 +555,8 @@ pub(crate) fn verify_cross_table_lookups, const D: let looked_z = *ctl_zs_openings[looked_table.table as usize].next().unwrap(); ensure!( looking_zs_prod == looked_z, - "Cross-table lookup verification failed." 
+ "Cross-table lookup {:?} verification failed.", + index ); } } diff --git a/evm/src/fixed_recursive_verifier.rs b/evm/src/fixed_recursive_verifier.rs index 72844db69c..5c0f1f8084 100644 --- a/evm/src/fixed_recursive_verifier.rs +++ b/evm/src/fixed_recursive_verifier.rs @@ -29,6 +29,7 @@ use plonky2_util::log2_ceil; use crate::all_stark::{all_cross_table_lookups, AllStark, Table, NUM_TABLES}; use crate::arithmetic::arithmetic_stark::ArithmeticStark; +use crate::byte_packing::byte_packing_stark::BytePackingStark; use crate::config::StarkConfig; use crate::cpu::cpu_stark::CpuStark; use crate::cross_table_lookup::{verify_cross_table_lookups_circuit, CrossTableLookup}; @@ -298,6 +299,7 @@ where C: GenericConfig + 'static, C::Hasher: AlgebraicHasher, [(); ArithmeticStark::::COLUMNS]:, + [(); BytePackingStark::::COLUMNS]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, @@ -378,43 +380,58 @@ where &all_stark.cross_table_lookups, stark_config, ); + let byte_packing = RecursiveCircuitsForTable::new( + Table::BytePacking, + &all_stark.byte_packing_stark, + degree_bits_ranges[1].clone(), + &all_stark.cross_table_lookups, + stark_config, + ); let cpu = RecursiveCircuitsForTable::new( Table::Cpu, &all_stark.cpu_stark, - degree_bits_ranges[1].clone(), + degree_bits_ranges[2].clone(), &all_stark.cross_table_lookups, stark_config, ); let keccak = RecursiveCircuitsForTable::new( Table::Keccak, &all_stark.keccak_stark, - degree_bits_ranges[2].clone(), + degree_bits_ranges[3].clone(), &all_stark.cross_table_lookups, stark_config, ); let keccak_sponge = RecursiveCircuitsForTable::new( Table::KeccakSponge, &all_stark.keccak_sponge_stark, - degree_bits_ranges[3].clone(), + degree_bits_ranges[4].clone(), &all_stark.cross_table_lookups, stark_config, ); let logic = RecursiveCircuitsForTable::new( Table::Logic, &all_stark.logic_stark, - degree_bits_ranges[4].clone(), + degree_bits_ranges[5].clone(), &all_stark.cross_table_lookups, stark_config, ); let memory = RecursiveCircuitsForTable::new( Table::Memory, &all_stark.memory_stark, - degree_bits_ranges[5].clone(), + degree_bits_ranges[6].clone(), &all_stark.cross_table_lookups, stark_config, ); - let by_table = [arithmetic, cpu, keccak, keccak_sponge, logic, memory]; + let by_table = [ + arithmetic, + byte_packing, + cpu, + keccak, + keccak_sponge, + logic, + memory, + ]; let root = Self::create_root_circuit(&by_table, stark_config); let aggregation = Self::create_aggregation_circuit(&root); let block = Self::create_block_circuit(&aggregation); @@ -489,13 +506,13 @@ where } } - // Extra products to add to the looked last value - // Arithmetic, KeccakSponge, Keccak, Logic + // Extra products to add to the looked last value. + // Only necessary for the Memory values. let mut extra_looking_products = - vec![vec![builder.constant(F::ONE); stark_config.num_challenges]; NUM_TABLES - 1]; + vec![vec![builder.one(); stark_config.num_challenges]; NUM_TABLES]; // Memory - let memory_looking_products = (0..stark_config.num_challenges) + extra_looking_products[Table::Memory as usize] = (0..stark_config.num_challenges) .map(|c| { get_memory_extra_looking_products_circuit( &mut builder, @@ -504,7 +521,6 @@ where ) }) .collect_vec(); - extra_looking_products.push(memory_looking_products); // Verify the CTL checks. 
verify_cross_table_lookups_circuit::( diff --git a/evm/src/lib.rs b/evm/src/lib.rs index 29ad6738fc..ab48cda04f 100644 --- a/evm/src/lib.rs +++ b/evm/src/lib.rs @@ -8,6 +8,7 @@ pub mod all_stark; pub mod arithmetic; +pub mod byte_packing; pub mod config; pub mod constraint_consumer; pub mod cpu; diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 31be89e701..8f5878232b 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -21,6 +21,7 @@ use plonky2_util::{log2_ceil, log2_strict}; use crate::all_stark::{AllStark, Table, NUM_TABLES}; use crate::arithmetic::arithmetic_stark::ArithmeticStark; +use crate::byte_packing::byte_packing_stark::BytePackingStark; use crate::config::StarkConfig; use crate::constraint_consumer::ConstraintConsumer; use crate::cpu::cpu_stark::CpuStark; @@ -53,6 +54,7 @@ where F: RichField + Extendable, C: GenericConfig, [(); ArithmeticStark::::COLUMNS]:, + [(); BytePackingStark::::COLUMNS]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, @@ -75,6 +77,7 @@ where F: RichField + Extendable, C: GenericConfig, [(); ArithmeticStark::::COLUMNS]:, + [(); BytePackingStark::::COLUMNS]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, @@ -103,6 +106,7 @@ where F: RichField + Extendable, C: GenericConfig, [(); ArithmeticStark::::COLUMNS]:, + [(); BytePackingStark::::COLUMNS]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, @@ -193,6 +197,7 @@ where F: RichField + Extendable, C: GenericConfig, [(); ArithmeticStark::::COLUMNS]:, + [(); BytePackingStark::::COLUMNS]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, @@ -212,6 +217,19 @@ where timing, )? ); + let byte_packing_proof = timed!( + timing, + "prove byte packing STARK", + prove_single_table( + &all_stark.byte_packing_stark, + config, + &trace_poly_values[Table::BytePacking as usize], + &trace_commitments[Table::BytePacking as usize], + &ctl_data_per_table[Table::BytePacking as usize], + challenger, + timing, + )? + ); let cpu_proof = timed!( timing, "prove CPU STARK", @@ -277,8 +295,10 @@ where timing, )? 
); + Ok([ arithmetic_proof, + byte_packing_proof, cpu_proof, keccak_proof, keccak_sponge_proof, diff --git a/evm/src/verifier.rs b/evm/src/verifier.rs index 49225f1426..1aa9db9782 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -2,6 +2,7 @@ use std::any::type_name; use anyhow::{ensure, Result}; use ethereum_types::U256; +use itertools::Itertools; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::types::Field; use plonky2::fri::verifier::verify_fri_proof; @@ -11,6 +12,7 @@ use plonky2::plonk::plonk_common::reduce_with_powers; use crate::all_stark::{AllStark, Table, NUM_TABLES}; use crate::arithmetic::arithmetic_stark::ArithmeticStark; +use crate::byte_packing::byte_packing_stark::BytePackingStark; use crate::config::StarkConfig; use crate::constraint_consumer::ConstraintConsumer; use crate::cpu::cpu_stark::CpuStark; @@ -38,6 +40,7 @@ pub fn verify_proof, C: GenericConfig, co ) -> Result<()> where [(); ArithmeticStark::::COLUMNS]:, + [(); BytePackingStark::::COLUMNS]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, @@ -53,6 +56,7 @@ where let AllStark { arithmetic_stark, + byte_packing_stark, cpu_stark, keccak_stark, keccak_sponge_stark, @@ -75,6 +79,13 @@ where &ctl_vars_per_table[Table::Arithmetic as usize], config, )?; + verify_stark_proof_with_challenges( + byte_packing_stark, + &all_proof.stark_proofs[Table::BytePacking as usize].proof, + &stark_challenges[Table::BytePacking as usize], + &ctl_vars_per_table[Table::BytePacking as usize], + config, + )?; verify_stark_proof_with_challenges( cpu_stark, &all_proof.stark_proofs[Table::Cpu as usize].proof, @@ -96,13 +107,6 @@ where &ctl_vars_per_table[Table::KeccakSponge as usize], config, )?; - verify_stark_proof_with_challenges( - memory_stark, - &all_proof.stark_proofs[Table::Memory as usize].proof, - &stark_challenges[Table::Memory as usize], - &ctl_vars_per_table[Table::Memory as usize], - config, - )?; verify_stark_proof_with_challenges( logic_stark, &all_proof.stark_proofs[Table::Logic as usize].proof, @@ -110,21 +114,24 @@ where &ctl_vars_per_table[Table::Logic as usize], config, )?; + verify_stark_proof_with_challenges( + memory_stark, + &all_proof.stark_proofs[Table::Memory as usize].proof, + &stark_challenges[Table::Memory as usize], + &ctl_vars_per_table[Table::Memory as usize], + config, + )?; let public_values = all_proof.public_values; - // Extra products to add to the looked last value - // Arithmetic, KeccakSponge, Keccak, Logic - let mut extra_looking_products = vec![vec![F::ONE; config.num_challenges]; NUM_TABLES - 1]; + // Extra products to add to the looked last value. + // Only necessary for the Memory values. 
+ let mut extra_looking_products = vec![vec![F::ONE; config.num_challenges]; NUM_TABLES]; // Memory - extra_looking_products.push(Vec::new()); - for c in 0..config.num_challenges { - extra_looking_products[Table::Memory as usize].push(get_memory_extra_looking_products( - &public_values, - ctl_challenges.challenges[c], - )); - } + extra_looking_products[Table::Memory as usize] = (0..config.num_challenges) + .map(|i| get_memory_extra_looking_products(&public_values, ctl_challenges.challenges[i])) + .collect_vec(); verify_cross_table_lookups::( cross_table_lookups, diff --git a/evm/src/witness/gas.rs b/evm/src/witness/gas.rs index 4c7947bb6b..3a46c04439 100644 --- a/evm/src/witness/gas.rs +++ b/evm/src/witness/gas.rs @@ -44,6 +44,8 @@ pub(crate) fn gas_to_charge(op: Operation) -> u64 { Swap(_) => G_VERYLOW, GetContext => KERNEL_ONLY_INSTR, SetContext => KERNEL_ONLY_INSTR, + Mload32Bytes => KERNEL_ONLY_INSTR, + Mstore32Bytes => KERNEL_ONLY_INSTR, ExitKernel => KERNEL_ONLY_INSTR, MloadGeneral => KERNEL_ONLY_INSTR, MstoreGeneral => KERNEL_ONLY_INSTR, diff --git a/evm/src/witness/operation.rs b/evm/src/witness/operation.rs index 13619b96a7..7d07576d30 100644 --- a/evm/src/witness/operation.rs +++ b/evm/src/witness/operation.rs @@ -3,6 +3,7 @@ use itertools::Itertools; use keccak_hash::keccak; use plonky2::field::types::Field; +use super::util::{byte_packing_log, byte_unpacking_log}; use crate::arithmetic::BinaryOperator; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; @@ -47,6 +48,8 @@ pub(crate) enum Operation { Swap(u8), GetContext, SetContext, + Mload32Bytes, + Mstore32Bytes, ExitKernel, MloadGeneral, MstoreGeneral, @@ -686,6 +689,45 @@ pub(crate) fn generate_mload_general( Ok(()) } +pub(crate) fn generate_mload_32bytes( + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let [(context, log_in0), (segment, log_in1), (base_virt, log_in2), (len, log_in3)] = + stack_pop_with_log_and_fill::<4, _>(state, &mut row)?; + let len = len.as_usize(); + + let base_address = MemoryAddress::new_u256s(context, segment, base_virt)?; + if usize::MAX - base_address.virt < len { + return Err(ProgramError::MemoryError(VirtTooLarge { + virt: base_address.virt.into(), + })); + } + let bytes = (0..len) + .map(|i| { + let address = MemoryAddress { + virt: base_address.virt + i, + ..base_address + }; + let val = state.memory.get(address); + val.as_u32() as u8 + }) + .collect_vec(); + + let packed_int = U256::from_big_endian(&bytes); + let log_out = stack_push_log_and_fill(state, &mut row, packed_int)?; + + byte_packing_log(state, base_address, bytes); + + state.traces.push_memory(log_in0); + state.traces.push_memory(log_in1); + state.traces.push_memory(log_in2); + state.traces.push_memory(log_in3); + state.traces.push_memory(log_out); + state.traces.push_cpu(row); + Ok(()) +} + pub(crate) fn generate_mstore_general( state: &mut GenerationState, mut row: CpuColumnsView, @@ -715,6 +757,27 @@ pub(crate) fn generate_mstore_general( Ok(()) } +pub(crate) fn generate_mstore_32bytes( + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let [(context, log_in0), (segment, log_in1), (base_virt, log_in2), (val, log_in3), (len, log_in4)] = + stack_pop_with_log_and_fill::<5, _>(state, &mut row)?; + let len = len.as_usize(); + + let base_address = MemoryAddress::new_u256s(context, segment, base_virt)?; + + byte_unpacking_log(state, base_address, val, len); + + state.traces.push_memory(log_in0); + 
state.traces.push_memory(log_in1); + state.traces.push_memory(log_in2); + state.traces.push_memory(log_in3); + state.traces.push_memory(log_in4); + state.traces.push_cpu(row); + Ok(()) +} + pub(crate) fn generate_exception( exc_code: u8, state: &mut GenerationState, diff --git a/evm/src/witness/traces.rs b/evm/src/witness/traces.rs index 2cc1c5000c..c4cf832dd5 100644 --- a/evm/src/witness/traces.rs +++ b/evm/src/witness/traces.rs @@ -9,6 +9,7 @@ use plonky2::util::timing::TimingTree; use crate::all_stark::{AllStark, NUM_TABLES}; use crate::arithmetic::{BinaryOperator, Operation}; +use crate::byte_packing::byte_packing_stark::BytePackingOp; use crate::config::StarkConfig; use crate::cpu::columns::CpuColumnsView; use crate::keccak_sponge::columns::KECCAK_WIDTH_BYTES; @@ -20,6 +21,7 @@ use crate::{arithmetic, keccak, keccak_sponge, logic}; #[derive(Clone, Copy, Debug)] pub struct TraceCheckpoint { pub(self) arithmetic_len: usize, + pub(self) byte_packing_len: usize, pub(self) cpu_len: usize, pub(self) keccak_len: usize, pub(self) keccak_sponge_len: usize, @@ -30,6 +32,7 @@ pub struct TraceCheckpoint { #[derive(Debug)] pub(crate) struct Traces { pub(crate) arithmetic_ops: Vec, + pub(crate) byte_packing_ops: Vec, pub(crate) cpu: Vec>, pub(crate) logic_ops: Vec, pub(crate) memory_ops: Vec, @@ -41,6 +44,7 @@ impl Traces { pub fn new() -> Self { Traces { arithmetic_ops: vec![], + byte_packing_ops: vec![], cpu: vec![], logic_ops: vec![], memory_ops: vec![], @@ -64,6 +68,7 @@ impl Traces { }, }) .sum(), + byte_packing_len: self.byte_packing_ops.iter().map(|op| op.bytes.len()).sum(), cpu_len: self.cpu.len(), keccak_len: self.keccak_inputs.len() * keccak::keccak_stark::NUM_ROUNDS, keccak_sponge_len: self @@ -82,6 +87,7 @@ impl Traces { pub fn checkpoint(&self) -> TraceCheckpoint { TraceCheckpoint { arithmetic_len: self.arithmetic_ops.len(), + byte_packing_len: self.byte_packing_ops.len(), cpu_len: self.cpu.len(), keccak_len: self.keccak_inputs.len(), keccak_sponge_len: self.keccak_sponge_ops.len(), @@ -92,6 +98,7 @@ impl Traces { pub fn rollback(&mut self, checkpoint: TraceCheckpoint) { self.arithmetic_ops.truncate(checkpoint.arithmetic_len); + self.byte_packing_ops.truncate(checkpoint.byte_packing_len); self.cpu.truncate(checkpoint.cpu_len); self.keccak_inputs.truncate(checkpoint.keccak_len); self.keccak_sponge_ops @@ -120,6 +127,10 @@ impl Traces { self.memory_ops.push(op); } + pub fn push_byte_packing(&mut self, op: BytePackingOp) { + self.byte_packing_ops.push(op); + } + pub fn push_keccak(&mut self, input: [u64; keccak::keccak_stark::NUM_INPUTS]) { self.keccak_inputs.push(input); } @@ -154,6 +165,7 @@ impl Traces { let cap_elements = config.fri_config.num_cap_elements(); let Traces { arithmetic_ops, + byte_packing_ops, cpu, logic_ops, memory_ops, @@ -166,7 +178,13 @@ impl Traces { "generate arithmetic trace", all_stark.arithmetic_stark.generate_trace(arithmetic_ops) ); - + let byte_packing_trace = timed!( + timing, + "generate byte packing trace", + all_stark + .byte_packing_stark + .generate_trace(byte_packing_ops, cap_elements, timing) + ); let cpu_rows = cpu.into_iter().map(|x| x.into()).collect(); let cpu_trace = trace_rows_to_poly_values(cpu_rows); let keccak_trace = timed!( @@ -198,6 +216,7 @@ impl Traces { [ arithmetic_trace, + byte_packing_trace, cpu_trace, keccak_trace, keccak_sponge_trace, diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index 3ee8d4f562..6e279cdf7a 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -128,6 +128,7 
@@ fn decode(registers: RegistersState, opcode: u8) -> Result<Operation, ProgramError> {
+        (0xee, true) => Ok(Operation::Mstore32Bytes),
         (0xf0, _) => Ok(Operation::Syscall(opcode, 3, false)), // CREATE
         (0xf1, _) => Ok(Operation::Syscall(opcode, 7, false)), // CALL
         (0xf2, _) => Ok(Operation::Syscall(opcode, 7, false)), // CALLCODE
@@ -136,6 +137,7 @@ fn decode(registers: RegistersState, opcode: u8) -> Result<Operation, ProgramError> {
         (0xf5, _) => Ok(Operation::Syscall(opcode, 4, false)), // CREATE2
         (0xf6, true) => Ok(Operation::GetContext),
         (0xf7, true) => Ok(Operation::SetContext),
+        (0xf8, true) => Ok(Operation::Mload32Bytes),
         (0xf9, true) => Ok(Operation::ExitKernel),
         (0xfa, _) => Ok(Operation::Syscall(opcode, 6, false)), // STATICCALL
         (0xfb, true) => Ok(Operation::MloadGeneral),
@@ -183,6 +185,8 @@ fn fill_op_flag(op: Operation, row: &mut CpuColumnsView<F>) {
         Operation::Pc => &mut flags.pc,
         Operation::Jumpdest => &mut flags.jumpdest,
         Operation::GetContext | Operation::SetContext => &mut flags.context_op,
+        Operation::Mload32Bytes => &mut flags.mload_32bytes,
+        Operation::Mstore32Bytes => &mut flags.mstore_32bytes,
         Operation::ExitKernel => &mut flags.exit_kernel,
         Operation::MloadGeneral => &mut flags.mload_general,
         Operation::MstoreGeneral => &mut flags.mstore_general,
@@ -220,6 +224,8 @@ fn perform_op(
         Operation::Jumpdest => generate_jumpdest(state, row)?,
         Operation::GetContext => generate_get_context(state, row)?,
         Operation::SetContext => generate_set_context(state, row)?,
+        Operation::Mload32Bytes => generate_mload_32bytes(state, row)?,
+        Operation::Mstore32Bytes => generate_mstore_32bytes(state, row)?,
         Operation::ExitKernel => generate_exit_kernel(state, row)?,
         Operation::MloadGeneral => generate_mload_general(state, row)?,
         Operation::MstoreGeneral => generate_mstore_general(state, row)?,
diff --git a/evm/src/witness/util.rs b/evm/src/witness/util.rs
index 0e2b36608b..944886141f 100644
--- a/evm/src/witness/util.rs
+++ b/evm/src/witness/util.rs
@@ -1,6 +1,7 @@
 use ethereum_types::U256;
 use plonky2::field::types::Field;
 
+use crate::byte_packing::byte_packing_stark::BytePackingOp;
 use crate::cpu::columns::CpuColumnsView;
 use crate::cpu::kernel::keccak_util::keccakf_u8s;
 use crate::cpu::membus::{NUM_CHANNELS, NUM_GP_CHANNELS};
@@ -258,3 +259,63 @@ pub(crate) fn keccak_sponge_log(
         input,
     });
 }
+
+pub(crate) fn byte_packing_log<F: Field>(
+    state: &mut GenerationState<F>,
+    base_address: MemoryAddress,
+    bytes: Vec<u8>,
+) {
+    let clock = state.traces.clock();
+
+    let mut address = base_address;
+    for &byte in &bytes {
+        state.traces.push_memory(MemoryOp::new(
+            MemoryChannel::Code,
+            clock,
+            address,
+            MemoryOpKind::Read,
+            byte.into(),
+        ));
+        address.increment();
+    }
+
+    state.traces.push_byte_packing(BytePackingOp {
+        is_read: true,
+        base_address,
+        timestamp: clock * NUM_CHANNELS,
+        bytes,
+    });
+}
+
+pub(crate) fn byte_unpacking_log<F: Field>(
+    state: &mut GenerationState<F>,
+    base_address: MemoryAddress,
+    val: U256,
+    len: usize,
+) {
+    let clock = state.traces.clock();
+
+    let mut bytes = vec![0; 32];
+    val.to_little_endian(&mut bytes);
+    bytes.resize(len, 0);
+    bytes.reverse();
+
+    let mut address = base_address;
+    for &byte in &bytes {
+        state.traces.push_memory(MemoryOp::new(
+            MemoryChannel::Code,
+            clock,
+            address,
+            MemoryOpKind::Write,
+            byte.into(),
+        ));
+        address.increment();
+    }
+
+    state.traces.push_byte_packing(BytePackingOp {
+        is_read: false,
+        base_address,
+        timestamp: clock * NUM_CHANNELS,
+        bytes,
+    });
+}
diff --git a/evm/tests/empty_txn_list.rs b/evm/tests/empty_txn_list.rs
index 806726fc9c..8bca4accd4 100644
--- a/evm/tests/empty_txn_list.rs
+++ b/evm/tests/empty_txn_list.rs
@@ -72,7 +72,7 @@ fn test_empty_txn_list() -> anyhow::Result<()> {
     let all_circuits = AllRecursiveCircuits::<F, C, D>::new(
         &all_stark,
-        &[16..17, 15..16, 14..15, 9..10, 12..13, 18..19], // Minimal ranges to prove an empty list
+        &[16..17, 10..11, 15..16, 14..15, 9..10, 12..13, 18..19], // Minimal ranges to prove an empty list
         &config,
     );
diff --git a/evm/tests/log_opcode.rs b/evm/tests/log_opcode.rs
index 271ab9456f..16d83bd06e 100644
--- a/evm/tests/log_opcode.rs
+++ b/evm/tests/log_opcode.rs
@@ -439,7 +439,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> {
     // Preprocess all circuits.
     let all_circuits = AllRecursiveCircuits::<F, C, D>::new(
         &all_stark,
-        &[16..17, 17..19, 14..15, 9..11, 12..13, 20..21],
+        &[16..17, 11..13, 17..19, 14..15, 9..11, 12..13, 19..21],
         &config,
     );
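Finally, the `degree_bits_ranges` arrays in both tests gain a new entry because they are indexed by the new `Table` ordering, with `BytePacking` in slot 1. A small sketch of the correspondence, using the table order from `all_stark.rs` and the range values from the `empty_txn_list` test above:

```rust
use std::ops::Range;

// One degree-bits range per STARK table, in the new Table ordering.
const TABLES: [&str; 7] = [
    "Arithmetic", "BytePacking", "Cpu", "Keccak", "KeccakSponge", "Logic", "Memory",
];

fn main() {
    let degree_bits_ranges: [Range<usize>; 7] =
        [16..17, 10..11, 15..16, 14..15, 9..10, 12..13, 18..19];
    for (table, range) in TABLES.iter().zip(degree_bits_ranges) {
        println!("{table}: {range:?}"); // e.g. "BytePacking: 10..11"
    }
}
```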