diff --git a/evm/src/arithmetic/arithmetic_stark.rs b/evm/src/arithmetic/arithmetic_stark.rs index 21dcf91985..14544654eb 100644 --- a/evm/src/arithmetic/arithmetic_stark.rs +++ b/evm/src/arithmetic/arithmetic_stark.rs @@ -63,6 +63,12 @@ pub(crate) fn ctl_arithmetic_rows() -> TableWithColumns { // If an arithmetic operation is happening on the CPU side, // the CTL will enforce that the reconstructed opcode value // from the opcode bits matches. + // These opcodes are missing the syscall and prover_input opcodes, + // since `IS_RANGE_CHECK` can be associated to multiple opcodes. + // For `IS_RANGE_CHECK`, the opcodes are written in OPCODE_COL, + // and we use that column for scaling and the CTL checks. + // Note that we ensure in the STARK's constraints that the + // value in `OPCODE_COL` is 0 if `IS_RANGE_CHECK` = 0. const COMBINED_OPS: [(usize, u8); 16] = [ (columns::IS_ADD, 0x01), (columns::IS_MUL, 0x02), @@ -89,8 +95,13 @@ pub(crate) fn ctl_arithmetic_rows() -> TableWithColumns { columns::OUTPUT_REGISTER, ]; - let filter_column = Some(Column::sum(COMBINED_OPS.iter().map(|(c, _v)| *c))); + let mut filter_cols = COMBINED_OPS.to_vec(); + filter_cols.push((columns::IS_RANGE_CHECK, 0x01)); + let filter_column = Some(Column::sum(filter_cols.iter().map(|(c, _v)| *c))); + + let mut all_combined_cols = COMBINED_OPS.to_vec(); + all_combined_cols.push((columns::OPCODE_COL, 0x01)); // Create the Arithmetic Table whose columns are those of the // operations listed in `ops` whose inputs and outputs are given // by `regs`, where each element of `regs` is a range of columns @@ -98,7 +109,7 @@ pub(crate) fn ctl_arithmetic_rows() -> TableWithColumns { // is used as the operation filter). TableWithColumns::new( Table::Arithmetic, - cpu_arith_data_link(&COMBINED_OPS, ®ISTER_MAP), + cpu_arith_data_link(&all_combined_cols, ®ISTER_MAP), filter_column, ) } @@ -109,7 +120,7 @@ pub struct ArithmeticStark { pub f: PhantomData, } -const RANGE_MAX: usize = 1usize << 16; // Range check strict upper bound +pub(crate) const RANGE_MAX: usize = 1usize << 16; // Range check strict upper bound impl ArithmeticStark { /// Expects input in *column*-major layout @@ -195,6 +206,10 @@ impl, const D: usize> Stark for ArithmeticSta let lv: &[P; NUM_ARITH_COLUMNS] = vars.get_local_values().try_into().unwrap(); let nv: &[P; NUM_ARITH_COLUMNS] = vars.get_next_values().try_into().unwrap(); + // Check that `OPCODE_COL` holds 0 if the operation is not a range_check. + let opcode_constraint = (P::ONES - lv[columns::IS_RANGE_CHECK]) * lv[columns::OPCODE_COL]; + yield_constr.constraint(opcode_constraint); + // Check the range column: First value must be 0, last row // must be 2^16-1, and intermediate rows must increment by 0 // or 1. @@ -231,6 +246,16 @@ impl, const D: usize> Stark for ArithmeticSta let nv: &[ExtensionTarget; NUM_ARITH_COLUMNS] = vars.get_next_values().try_into().unwrap(); + // Check that `OPCODE_COL` holds 0 if the operation is not a range_check. + let opcode_constraint = builder.arithmetic_extension( + F::NEG_ONE, + F::ONE, + lv[columns::IS_RANGE_CHECK], + lv[columns::OPCODE_COL], + lv[columns::OPCODE_COL], + ); + yield_constr.constraint(builder, opcode_constraint); + // Check the range column: First value must be 0, last row // must be 2^16-1, and intermediate rows must increment by 0 // or 1. diff --git a/evm/src/arithmetic/columns.rs b/evm/src/arithmetic/columns.rs index aa36b3ab71..3736445433 100644 --- a/evm/src/arithmetic/columns.rs +++ b/evm/src/arithmetic/columns.rs @@ -38,8 +38,10 @@ pub(crate) const IS_GT: usize = IS_LT + 1; pub(crate) const IS_BYTE: usize = IS_GT + 1; pub(crate) const IS_SHL: usize = IS_BYTE + 1; pub(crate) const IS_SHR: usize = IS_SHL + 1; - -pub(crate) const START_SHARED_COLS: usize = IS_SHR + 1; +pub(crate) const IS_RANGE_CHECK: usize = IS_SHR + 1; +/// Column that stores the opcode if the operation is a range check. +pub(crate) const OPCODE_COL: usize = IS_RANGE_CHECK + 1; +pub(crate) const START_SHARED_COLS: usize = OPCODE_COL + 1; /// Within the Arithmetic Unit, there are shared columns which can be /// used by any arithmetic circuit, depending on which one is active diff --git a/evm/src/arithmetic/mod.rs b/evm/src/arithmetic/mod.rs index 2699ee51c4..4b84a3510e 100644 --- a/evm/src/arithmetic/mod.rs +++ b/evm/src/arithmetic/mod.rs @@ -1,6 +1,11 @@ use ethereum_types::U256; use plonky2::field::types::PrimeField64; +use self::columns::{ + INPUT_REGISTER_0, INPUT_REGISTER_1, INPUT_REGISTER_2, OPCODE_COL, OUTPUT_REGISTER, +}; +use self::utils::u256_to_array; +use crate::arithmetic::columns::IS_RANGE_CHECK; use crate::extension_tower::BN_BASE; use crate::util::{addmod, mulmod, submod}; @@ -135,6 +140,7 @@ impl TernaryOperator { } /// An enum representing arithmetic operations that can be either binary or ternary. +#[allow(clippy::enum_variant_names)] #[derive(Debug)] pub(crate) enum Operation { BinaryOperation { @@ -150,6 +156,13 @@ pub(crate) enum Operation { input2: U256, result: U256, }, + RangeCheckOperation { + input0: U256, + input1: U256, + input2: U256, + opcode: U256, + result: U256, + }, } impl Operation { @@ -195,11 +208,28 @@ impl Operation { } } + pub(crate) fn range_check( + input0: U256, + input1: U256, + input2: U256, + opcode: U256, + result: U256, + ) -> Self { + Self::RangeCheckOperation { + input0, + input1, + input2, + opcode, + result, + } + } + /// Gets the result of an arithmetic operation. pub(crate) fn result(&self) -> U256 { match self { Operation::BinaryOperation { result, .. } => *result, Operation::TernaryOperation { result, .. } => *result, + _ => panic!("This function should not be called for range checks."), } } @@ -228,6 +258,13 @@ impl Operation { input2, result, } => ternary_op_to_rows(operator.row_filter(), input0, input1, input2, result), + Operation::RangeCheckOperation { + input0, + input1, + input2, + opcode, + result, + } => range_check_to_rows(input0, input1, input2, opcode, result), } } } @@ -293,3 +330,21 @@ fn binary_op_to_rows( } } } + +fn range_check_to_rows( + input0: U256, + input1: U256, + input2: U256, + opcode: U256, + result: U256, +) -> (Vec, Option>) { + let mut row = vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; + row[IS_RANGE_CHECK] = F::ONE; + row[OPCODE_COL] = F::from_canonical_u64(opcode.as_u64()); + u256_to_array(&mut row[INPUT_REGISTER_0], input0); + u256_to_array(&mut row[INPUT_REGISTER_1], input1); + u256_to_array(&mut row[INPUT_REGISTER_2], input2); + u256_to_array(&mut row[OUTPUT_REGISTER], result); + + (row, None) +} diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 926cd7485c..755461f345 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -62,8 +62,8 @@ fn ctl_data_binops() -> Vec> { /// Creates the vector of `Columns` corresponding to the three inputs and /// one output of a ternary operation. By default, ternary operations use -/// the first three memory channels, and the last one for the result (binary -/// operations do not use the third inputs). +/// the first three memory channels, and the next top of the stack for the +/// result (binary operations do not use the third inputs). fn ctl_data_ternops() -> Vec> { let mut res = Column::singles(COL_MAP.mem_channels[0].value).collect_vec(); res.extend(Column::singles(COL_MAP.mem_channels[1].value)); @@ -103,6 +103,9 @@ pub fn ctl_arithmetic_base_rows() -> TableWithColumns { COL_MAP.op.fp254_op, COL_MAP.op.ternary_op, COL_MAP.op.shift, + COL_MAP.op.prover_input, + COL_MAP.op.syscall, + COL_MAP.op.exception, ])), ) } diff --git a/evm/src/cpu/stack.rs b/evm/src/cpu/stack.rs index ab844432b6..dffdff42bb 100644 --- a/evm/src/cpu/stack.rs +++ b/evm/src/cpu/stack.rs @@ -91,8 +91,12 @@ pub(crate) const STACK_BEHAVIORS: OpsColumnsView> = OpsCol disable_other_channels: false, }), jumpdest_keccak_general: None, - prover_input: None, // TODO - jumps: None, // Depends on whether it's a JUMP or a JUMPI. + prover_input: Some(StackBehavior { + num_pops: 0, + pushes: true, + disable_other_channels: true, + }), + jumps: None, // Depends on whether it's a JUMP or a JUMPI. pc_push0: Some(StackBehavior { num_pops: 0, pushes: true, diff --git a/evm/src/witness/operation.rs b/evm/src/witness/operation.rs index 2777269d89..64ff85660d 100644 --- a/evm/src/witness/operation.rs +++ b/evm/src/witness/operation.rs @@ -162,7 +162,22 @@ pub(crate) fn generate_prover_input( let pc = state.registers.program_counter; let input_fn = &KERNEL.prover_inputs[&pc]; let input = state.prover_input(input_fn)?; + let opcode = 0x49.into(); + // `ArithmeticStark` range checks `mem_channels[0]`, which contains + // the top of the stack, `mem_channels[1]`, `mem_channels[2]` and + // next_row's `mem_channels[0]` which contains the next top of the stack. + // Our goal here is to range-check the input, in the next stack top. + let range_check_op = arithmetic::Operation::range_check( + state.registers.stack_top, + U256::from(0), + U256::from(0), + opcode, + input, + ); + push_with_write(state, &mut row, input)?; + + state.traces.push_arithmetic(range_check_op); state.traces.push_cpu(row); Ok(()) } @@ -700,10 +715,24 @@ pub(crate) fn generate_syscall( let handler_addr = (handler_addr0 << 16) + (handler_addr1 << 8) + handler_addr2; let new_program_counter = u256_to_usize(handler_addr)?; + let gas = U256::from(state.registers.gas_used); + let syscall_info = U256::from(state.registers.program_counter + 1) + (U256::from(u64::from(state.registers.is_kernel)) << 32) - + (U256::from(state.registers.gas_used) << 192); - + + (gas << 192); + + // `ArithmeticStark` range checks `mem_channels[0]`, which contains + // the top of the stack, `mem_channels[1]`, `mem_channels[2]` and + // next_row's `mem_channels[0]` which contains the next top of the stack. + // Our goal here is to range-check the gas, contained in syscall_info, + // stored in the next stack top. + let range_check_op = arithmetic::Operation::range_check( + state.registers.stack_top, + handler_addr0, + handler_addr1, + U256::from(opcode), + syscall_info, + ); // Set registers before pushing to the stack; in particular, we need to set kernel mode so we // can't incorrectly trigger a stack overflow. However, note that we have to do it _after_ we // make `syscall_info`, which should contain the old values. @@ -715,6 +744,7 @@ pub(crate) fn generate_syscall( log::debug!("Syscall to {}", KERNEL.offset_name(new_program_counter)); + state.traces.push_arithmetic(range_check_op); state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); state.traces.push_memory(log_in2); @@ -983,9 +1013,27 @@ pub(crate) fn generate_exception( let handler_addr = (handler_addr0 << 16) + (handler_addr1 << 8) + handler_addr2; let new_program_counter = u256_to_usize(handler_addr)?; - let exc_info = - U256::from(state.registers.program_counter) + (U256::from(state.registers.gas_used) << 192); + let gas = U256::from(state.registers.gas_used); + + let exc_info = U256::from(state.registers.program_counter) + (gas << 192); + // Get the opcode so we can provide it to the range_check operation. + let code_context = state.registers.code_context(); + let address = MemoryAddress::new(code_context, Segment::Code, state.registers.program_counter); + let opcode = state.memory.get(address); + + // `ArithmeticStark` range checks `mem_channels[0]`, which contains + // the top of the stack, `mem_channels[1]`, `mem_channels[2]` and + // next_row's `mem_channels[0]` which contains the next top of the stack. + // Our goal here is to range-check the gas, contained in syscall_info, + // stored in the next stack top. + let range_check_op = arithmetic::Operation::range_check( + state.registers.stack_top, + handler_addr0, + handler_addr1, + opcode, + exc_info, + ); // Set registers before pushing to the stack; in particular, we need to set kernel mode so we // can't incorrectly trigger a stack overflow. However, note that we have to do it _after_ we // make `exc_info`, which should contain the old values. @@ -996,7 +1044,7 @@ pub(crate) fn generate_exception( push_with_write(state, &mut row, exc_info)?; log::debug!("Exception to {}", KERNEL.offset_name(new_program_counter)); - + state.traces.push_arithmetic(range_check_op); state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); state.traces.push_memory(log_in2); diff --git a/evm/src/witness/traces.rs b/evm/src/witness/traces.rs index 91035fc403..1a2b855c28 100644 --- a/evm/src/witness/traces.rs +++ b/evm/src/witness/traces.rs @@ -66,6 +66,7 @@ impl Traces { BinaryOperator::Div | BinaryOperator::Mod => 2, _ => 1, }, + Operation::RangeCheckOperation { .. } => 1, }) .sum(), byte_packing_len: self.byte_packing_ops.iter().map(|op| op.bytes.len()).sum(),