Skip to content

Commit

Permalink
Add memory checks for prover_input, as well as range_checks for prove…
Browse files Browse the repository at this point in the history
…r_input, syscalls/exceptions (#1168)

* Add memory checks for prover_input and range_checks for prover_input, syscalls and exceptions

* Replace u32 by U256, and remove extra CTLs

* Add column in ArithmeticStark to use ctl_arithmetic_base_rows for is_range_check

* Fix CTLs and circuit constraint.

* Fix CTLs
  • Loading branch information
LindaGuiga authored and wborgeaud committed Nov 14, 2023
1 parent 136e656 commit 939df63
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 12 deletions.
31 changes: 28 additions & 3 deletions evm/src/arithmetic/arithmetic_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ pub(crate) fn ctl_arithmetic_rows<F: Field>() -> TableWithColumns<F> {
// If an arithmetic operation is happening on the CPU side,
// the CTL will enforce that the reconstructed opcode value
// from the opcode bits matches.
// These opcodes are missing the syscall and prover_input opcodes,
// since `IS_RANGE_CHECK` can be associated to multiple opcodes.
// For `IS_RANGE_CHECK`, the opcodes are written in OPCODE_COL,
// and we use that column for scaling and the CTL checks.
// Note that we ensure in the STARK's constraints that the
// value in `OPCODE_COL` is 0 if `IS_RANGE_CHECK` = 0.
const COMBINED_OPS: [(usize, u8); 16] = [
(columns::IS_ADD, 0x01),
(columns::IS_MUL, 0x02),
Expand All @@ -89,16 +95,21 @@ pub(crate) fn ctl_arithmetic_rows<F: Field>() -> TableWithColumns<F> {
columns::OUTPUT_REGISTER,
];

let filter_column = Some(Column::sum(COMBINED_OPS.iter().map(|(c, _v)| *c)));
let mut filter_cols = COMBINED_OPS.to_vec();
filter_cols.push((columns::IS_RANGE_CHECK, 0x01));

let filter_column = Some(Column::sum(filter_cols.iter().map(|(c, _v)| *c)));

let mut all_combined_cols = COMBINED_OPS.to_vec();
all_combined_cols.push((columns::OPCODE_COL, 0x01));
// Create the Arithmetic Table whose columns are those of the
// operations listed in `ops` whose inputs and outputs are given
// by `regs`, where each element of `regs` is a range of columns
// corresponding to a 256-bit input or output register (also `ops`
// is used as the operation filter).
TableWithColumns::new(
Table::Arithmetic,
cpu_arith_data_link(&COMBINED_OPS, &REGISTER_MAP),
cpu_arith_data_link(&all_combined_cols, &REGISTER_MAP),
filter_column,
)
}
Expand All @@ -109,7 +120,7 @@ pub struct ArithmeticStark<F, const D: usize> {
pub f: PhantomData<F>,
}

const RANGE_MAX: usize = 1usize << 16; // Range check strict upper bound
pub(crate) const RANGE_MAX: usize = 1usize << 16; // Range check strict upper bound

impl<F: RichField, const D: usize> ArithmeticStark<F, D> {
/// Expects input in *column*-major layout
Expand Down Expand Up @@ -195,6 +206,10 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for ArithmeticSta
let lv: &[P; NUM_ARITH_COLUMNS] = vars.get_local_values().try_into().unwrap();
let nv: &[P; NUM_ARITH_COLUMNS] = vars.get_next_values().try_into().unwrap();

// Check that `OPCODE_COL` holds 0 if the operation is not a range_check.
let opcode_constraint = (P::ONES - lv[columns::IS_RANGE_CHECK]) * lv[columns::OPCODE_COL];
yield_constr.constraint(opcode_constraint);

// Check the range column: First value must be 0, last row
// must be 2^16-1, and intermediate rows must increment by 0
// or 1.
Expand Down Expand Up @@ -231,6 +246,16 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for ArithmeticSta
let nv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS] =
vars.get_next_values().try_into().unwrap();

// Check that `OPCODE_COL` holds 0 if the operation is not a range_check.
let opcode_constraint = builder.arithmetic_extension(
F::NEG_ONE,
F::ONE,
lv[columns::IS_RANGE_CHECK],
lv[columns::OPCODE_COL],
lv[columns::OPCODE_COL],
);
yield_constr.constraint(builder, opcode_constraint);

// Check the range column: First value must be 0, last row
// must be 2^16-1, and intermediate rows must increment by 0
// or 1.
Expand Down
6 changes: 4 additions & 2 deletions evm/src/arithmetic/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ pub(crate) const IS_GT: usize = IS_LT + 1;
pub(crate) const IS_BYTE: usize = IS_GT + 1;
pub(crate) const IS_SHL: usize = IS_BYTE + 1;
pub(crate) const IS_SHR: usize = IS_SHL + 1;

pub(crate) const START_SHARED_COLS: usize = IS_SHR + 1;
pub(crate) const IS_RANGE_CHECK: usize = IS_SHR + 1;
/// Column that stores the opcode if the operation is a range check.
pub(crate) const OPCODE_COL: usize = IS_RANGE_CHECK + 1;
pub(crate) const START_SHARED_COLS: usize = OPCODE_COL + 1;

/// Within the Arithmetic Unit, there are shared columns which can be
/// used by any arithmetic circuit, depending on which one is active
Expand Down
55 changes: 55 additions & 0 deletions evm/src/arithmetic/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
use ethereum_types::U256;
use plonky2::field::types::PrimeField64;

use self::columns::{
INPUT_REGISTER_0, INPUT_REGISTER_1, INPUT_REGISTER_2, OPCODE_COL, OUTPUT_REGISTER,
};
use self::utils::u256_to_array;
use crate::arithmetic::columns::IS_RANGE_CHECK;
use crate::extension_tower::BN_BASE;
use crate::util::{addmod, mulmod, submod};

Expand Down Expand Up @@ -135,6 +140,7 @@ impl TernaryOperator {
}

/// An enum representing arithmetic operations that can be either binary or ternary.
#[allow(clippy::enum_variant_names)]
#[derive(Debug)]
pub(crate) enum Operation {
BinaryOperation {
Expand All @@ -150,6 +156,13 @@ pub(crate) enum Operation {
input2: U256,
result: U256,
},
RangeCheckOperation {
input0: U256,
input1: U256,
input2: U256,
opcode: U256,
result: U256,
},
}

impl Operation {
Expand Down Expand Up @@ -195,11 +208,28 @@ impl Operation {
}
}

pub(crate) fn range_check(
input0: U256,
input1: U256,
input2: U256,
opcode: U256,
result: U256,
) -> Self {
Self::RangeCheckOperation {
input0,
input1,
input2,
opcode,
result,
}
}

/// Gets the result of an arithmetic operation.
pub(crate) fn result(&self) -> U256 {
match self {
Operation::BinaryOperation { result, .. } => *result,
Operation::TernaryOperation { result, .. } => *result,
_ => panic!("This function should not be called for range checks."),
}
}

Expand Down Expand Up @@ -228,6 +258,13 @@ impl Operation {
input2,
result,
} => ternary_op_to_rows(operator.row_filter(), input0, input1, input2, result),
Operation::RangeCheckOperation {
input0,
input1,
input2,
opcode,
result,
} => range_check_to_rows(input0, input1, input2, opcode, result),
}
}
}
Expand Down Expand Up @@ -293,3 +330,21 @@ fn binary_op_to_rows<F: PrimeField64>(
}
}
}

fn range_check_to_rows<F: PrimeField64>(
input0: U256,
input1: U256,
input2: U256,
opcode: U256,
result: U256,
) -> (Vec<F>, Option<Vec<F>>) {
let mut row = vec![F::ZERO; columns::NUM_ARITH_COLUMNS];
row[IS_RANGE_CHECK] = F::ONE;
row[OPCODE_COL] = F::from_canonical_u64(opcode.as_u64());
u256_to_array(&mut row[INPUT_REGISTER_0], input0);
u256_to_array(&mut row[INPUT_REGISTER_1], input1);
u256_to_array(&mut row[INPUT_REGISTER_2], input2);
u256_to_array(&mut row[OUTPUT_REGISTER], result);

(row, None)
}
7 changes: 5 additions & 2 deletions evm/src/cpu/cpu_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ fn ctl_data_binops<F: Field>() -> Vec<Column<F>> {

/// Creates the vector of `Columns` corresponding to the three inputs and
/// one output of a ternary operation. By default, ternary operations use
/// the first three memory channels, and the last one for the result (binary
/// operations do not use the third inputs).
/// the first three memory channels, and the next top of the stack for the
/// result (binary operations do not use the third inputs).
fn ctl_data_ternops<F: Field>() -> Vec<Column<F>> {
let mut res = Column::singles(COL_MAP.mem_channels[0].value).collect_vec();
res.extend(Column::singles(COL_MAP.mem_channels[1].value));
Expand Down Expand Up @@ -115,6 +115,9 @@ pub fn ctl_arithmetic_base_rows<F: Field>() -> TableWithColumns<F> {
COL_MAP.op.fp254_op,
COL_MAP.op.ternary_op,
COL_MAP.op.shift,
COL_MAP.op.prover_input,
COL_MAP.op.syscall,
COL_MAP.op.exception,
])),
)
}
Expand Down
6 changes: 6 additions & 0 deletions evm/src/cpu/stack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ pub(crate) const STACK_BEHAVIORS: OpsColumnsView<Option<StackBehavior>> = OpsCol
pushes: true,
disable_other_channels: true,
}),
prover_input: Some(StackBehavior {
num_pops: 0,
pushes: true,
disable_other_channels: true,
}),
jumps: None, // Depends on whether it's a JUMP or a JUMPI.
prover_input: None, // TODO
jumps: None, // Depends on whether it's a JUMP or a JUMPI.
pc_push0: Some(StackBehavior {
Expand Down
58 changes: 53 additions & 5 deletions evm/src/witness/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,22 @@ pub(crate) fn generate_prover_input<F: Field>(
let pc = state.registers.program_counter;
let input_fn = &KERNEL.prover_inputs[&pc];
let input = state.prover_input(input_fn)?;
let opcode = 0x49.into();
// `ArithmeticStark` range checks `mem_channels[0]`, which contains
// the top of the stack, `mem_channels[1]`, `mem_channels[2]` and
// next_row's `mem_channels[0]` which contains the next top of the stack.
// Our goal here is to range-check the input, in the next stack top.
let range_check_op = arithmetic::Operation::range_check(
state.registers.stack_top,
U256::from(0),
U256::from(0),
opcode,
input,
);

push_with_write(state, &mut row, input)?;

state.traces.push_arithmetic(range_check_op);
state.traces.push_cpu(row);
Ok(())
}
Expand Down Expand Up @@ -772,10 +787,24 @@ pub(crate) fn generate_syscall<F: Field>(
let handler_addr = (handler_addr0 << 16) + (handler_addr1 << 8) + handler_addr2;
let new_program_counter = u256_to_usize(handler_addr)?;

let gas = U256::from(state.registers.gas_used);

let syscall_info = U256::from(state.registers.program_counter + 1)
+ (U256::from(u64::from(state.registers.is_kernel)) << 32)
+ (U256::from(state.registers.gas_used) << 192);

+ (gas << 192);

// `ArithmeticStark` range checks `mem_channels[0]`, which contains
// the top of the stack, `mem_channels[1]`, `mem_channels[2]` and
// next_row's `mem_channels[0]` which contains the next top of the stack.
// Our goal here is to range-check the gas, contained in syscall_info,
// stored in the next stack top.
let range_check_op = arithmetic::Operation::range_check(
state.registers.stack_top,
handler_addr0,
handler_addr1,
U256::from(opcode),
syscall_info,
);
// Set registers before pushing to the stack; in particular, we need to set kernel mode so we
// can't incorrectly trigger a stack overflow. However, note that we have to do it _after_ we
// make `syscall_info`, which should contain the old values.
Expand All @@ -787,6 +816,7 @@ pub(crate) fn generate_syscall<F: Field>(

log::debug!("Syscall to {}", KERNEL.offset_name(new_program_counter));

state.traces.push_arithmetic(range_check_op);
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_in2);
Expand Down Expand Up @@ -1055,9 +1085,27 @@ pub(crate) fn generate_exception<F: Field>(
let handler_addr = (handler_addr0 << 16) + (handler_addr1 << 8) + handler_addr2;
let new_program_counter = u256_to_usize(handler_addr)?;

let exc_info =
U256::from(state.registers.program_counter) + (U256::from(state.registers.gas_used) << 192);
let gas = U256::from(state.registers.gas_used);

let exc_info = U256::from(state.registers.program_counter) + (gas << 192);

// Get the opcode so we can provide it to the range_check operation.
let code_context = state.registers.code_context();
let address = MemoryAddress::new(code_context, Segment::Code, state.registers.program_counter);
let opcode = state.memory.get(address);

// `ArithmeticStark` range checks `mem_channels[0]`, which contains
// the top of the stack, `mem_channels[1]`, `mem_channels[2]` and
// next_row's `mem_channels[0]` which contains the next top of the stack.
// Our goal here is to range-check the gas, contained in syscall_info,
// stored in the next stack top.
let range_check_op = arithmetic::Operation::range_check(
state.registers.stack_top,
handler_addr0,
handler_addr1,
opcode,
exc_info,
);
// Set registers before pushing to the stack; in particular, we need to set kernel mode so we
// can't incorrectly trigger a stack overflow. However, note that we have to do it _after_ we
// make `exc_info`, which should contain the old values.
Expand All @@ -1068,7 +1116,7 @@ pub(crate) fn generate_exception<F: Field>(
push_with_write(state, &mut row, exc_info)?;

log::debug!("Exception to {}", KERNEL.offset_name(new_program_counter));

state.traces.push_arithmetic(range_check_op);
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_in2);
Expand Down
1 change: 1 addition & 0 deletions evm/src/witness/traces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ impl<T: Copy> Traces<T> {
BinaryOperator::Div | BinaryOperator::Mod => 2,
_ => 1,
},
Operation::RangeCheckOperation { .. } => 1,
})
.sum(),
byte_packing_len: self.byte_packing_ops.iter().map(|op| op.bytes.len()).sum(),
Expand Down

0 comments on commit 939df63

Please sign in to comment.