diff --git a/evm/src/cpu/shift.rs b/evm/src/cpu/shift.rs index a424929798..0f92cbd20d 100644 --- a/evm/src/cpu/shift.rs +++ b/evm/src/cpu/shift.rs @@ -19,8 +19,8 @@ pub(crate) fn eval_packed( // Not needed here; val is the input and we're verifying that output is // val * 2^d (mod 2^256) - //let val = lv.mem_channels[0]; - //let output = lv.mem_channels[NUM_GP_CHANNELS - 1]; + // let val = lv.mem_channels[0]; + // let output = lv.mem_channels[NUM_GP_CHANNELS - 1]; let shift_table_segment = P::Scalar::from_canonical_u64(Segment::ShiftTable as u64); @@ -28,7 +28,7 @@ pub(crate) fn eval_packed( // two_exp.used is true (1) if the high limbs of the displacement are // zero and false (0) otherwise. let high_limbs_are_zero = two_exp.used; - yield_constr.constraint(is_shift * (two_exp.is_read - P::ONES)); + yield_constr.constraint(is_shift * high_limbs_are_zero * (two_exp.is_read - P::ONES)); let high_limbs_sum: P = displacement.value[1..].iter().copied().sum(); let high_limbs_sum_inv = lv.general.shift().high_limb_sum_inv; @@ -70,14 +70,20 @@ pub(crate) fn eval_ext_circuit, const D: usize>( let shift_table_segment = F::from_canonical_u64(Segment::ShiftTable as u64); + // Only lookup the shifting factor when displacement is < 2^32. + // two_exp.used is true (1) if the high limbs of the displacement are + // zero and false (0) otherwise. let high_limbs_are_zero = two_exp.used; let one = builder.one_extension(); let t = builder.sub_extension(two_exp.is_read, one); + let t = builder.mul_extension(high_limbs_are_zero, t); let t = builder.mul_extension(is_shift, t); yield_constr.constraint(builder, t); let high_limbs_sum = builder.add_many_extension(&displacement.value[1..]); let high_limbs_sum_inv = lv.general.shift().high_limb_sum_inv; + // Verify that high_limbs_are_zero = 0 implies high_limbs_sum != 0 and + // high_limbs_are_zero = 1 implies high_limbs_sum = 0. let t = builder.one_extension(); let t = builder.sub_extension(t, high_limbs_are_zero); let t = builder.mul_sub_extension(high_limbs_sum, high_limbs_sum_inv, t); @@ -87,6 +93,9 @@ pub(crate) fn eval_ext_circuit, const D: usize>( let t = builder.mul_many_extension([is_shift, high_limbs_sum, high_limbs_are_zero]); yield_constr.constraint(builder, t); + // When the shift displacement is < 2^32, constrain the two_exp + // mem_channel to be the entry corresponding to `displacement` in + // the shift table lookup (will be zero if displacement >= 256). let t = builder.mul_extension(is_shift, two_exp.addr_context); yield_constr.constraint(builder, t); let t = builder.arithmetic_extension( @@ -101,6 +110,7 @@ pub(crate) fn eval_ext_circuit, const D: usize>( let t = builder.mul_extension(is_shift, t); yield_constr.constraint(builder, t); + // Other channels must be unused for chan in &lv.mem_channels[3..NUM_GP_CHANNELS - 1] { let t = builder.mul_extension(is_shift, chan.used); yield_constr.constraint(builder, t); diff --git a/evm/src/witness/operation.rs b/evm/src/witness/operation.rs index 568fe4b181..f4dc03e806 100644 --- a/evm/src/witness/operation.rs +++ b/evm/src/witness/operation.rs @@ -497,6 +497,10 @@ fn append_shift( channel.addr_context = F::from_canonical_usize(lookup_addr.context); channel.addr_segment = F::from_canonical_usize(lookup_addr.segment); channel.addr_virtual = F::from_canonical_usize(lookup_addr.virt); + + // Extra field required by the constraints for large shifts. + let high_limb_sum = row.mem_channels[0].value[1..].iter().copied().sum::(); + row.general.shift_mut().high_limb_sum_inv = high_limb_sum.inverse(); } let operator = if is_shl {