Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove extra SHL/SHR CTL. #1270

Merged
merged 4 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions evm/src/all_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,7 @@ pub(crate) fn all_cross_table_lookups<F: Field>() -> Vec<CrossTableLookup<F>> {

fn ctl_arithmetic<F: Field>() -> CrossTableLookup<F> {
CrossTableLookup::new(
vec![
cpu_stark::ctl_arithmetic_base_rows(),
cpu_stark::ctl_arithmetic_shift_rows(),
],
vec![cpu_stark::ctl_arithmetic_base_rows()],
arithmetic_stark::ctl_arithmetic_rows(),
)
}
Expand Down
3 changes: 3 additions & 0 deletions evm/src/arithmetic/arithmetic_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use plonky2::util::transpose;
use static_assertions::const_assert;

use super::columns::NUM_ARITH_COLUMNS;
use super::shift;
use crate::all_stark::Table;
use crate::arithmetic::columns::{RANGE_COUNTER, RC_FREQUENCIES, SHARED_COLS};
use crate::arithmetic::{addcy, byte, columns, divmod, modular, mul, Operation};
Expand Down Expand Up @@ -208,6 +209,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for ArithmeticSta
divmod::eval_packed(lv, nv, yield_constr);
modular::eval_packed(lv, nv, yield_constr);
byte::eval_packed(lv, yield_constr);
shift::eval_packed_generic(lv, nv, yield_constr);
}

fn eval_ext_circuit(
Expand Down Expand Up @@ -237,6 +239,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for ArithmeticSta
divmod::eval_ext_circuit(builder, lv, nv, yield_constr);
modular::eval_ext_circuit(builder, lv, nv, yield_constr);
byte::eval_ext_circuit(builder, lv, yield_constr);
shift::eval_ext_circuit(builder, lv, nv, yield_constr);
}

fn constraint_degree(&self) -> usize {
Expand Down
2 changes: 1 addition & 1 deletion evm/src/arithmetic/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ pub(crate) const MODULAR_OUT_AUX_RED: Range<usize> = AUX_REGISTER_0;
pub(crate) const MODULAR_MOD_IS_ZERO: usize = AUX_REGISTER_1.start;
pub(crate) const MODULAR_AUX_INPUT_LO: Range<usize> = AUX_REGISTER_1.start + 1..AUX_REGISTER_1.end;
pub(crate) const MODULAR_AUX_INPUT_HI: Range<usize> = AUX_REGISTER_2;
// Must be set to MOD_IS_ZERO for DIV operation i.e. MOD_IS_ZERO * lv[IS_DIV]
// Must be set to MOD_IS_ZERO for DIV and SHR operations i.e. MOD_IS_ZERO * (lv[IS_DIV] + lv[IS_SHR]).
pub(crate) const MODULAR_DIV_DENOM_IS_ZERO: usize = AUX_REGISTER_2.end;

/// The counter column (used for the range check) starts from 0 and increments.
Expand Down
77 changes: 49 additions & 28 deletions evm/src/arithmetic/divmod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,19 @@ use crate::arithmetic::modular::{
use crate::arithmetic::utils::*;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};

/// Generate the output and auxiliary values for modular operations.
pub(crate) fn generate<F: PrimeField64>(
/// Generates the output and auxiliary values for modular operations,
/// assuming the input, modular and output limbs are already set.
pub(crate) fn generate_divmod<F: PrimeField64>(
lv: &mut [F],
nv: &mut [F],
filter: usize,
input0: U256,
input1: U256,
result: U256,
input_limbs_range: Range<usize>,
modulus_range: Range<usize>,
) {
debug_assert!(lv.len() == NUM_ARITH_COLUMNS);

u256_to_array(&mut lv[INPUT_REGISTER_0], input0);
u256_to_array(&mut lv[INPUT_REGISTER_1], input1);
u256_to_array(&mut lv[OUTPUT_REGISTER], result);

let input_limbs = read_value_i64_limbs::<N_LIMBS, _>(lv, INPUT_REGISTER_0);
let input_limbs = read_value_i64_limbs::<N_LIMBS, _>(lv, input_limbs_range);
let pol_input = pol_extend(input_limbs);
let (out, quo_input) = generate_modular_op(lv, nv, filter, pol_input, INPUT_REGISTER_1);
let (out, quo_input) = generate_modular_op(lv, nv, filter, pol_input, modulus_range);

debug_assert!(
&quo_input[N_LIMBS..].iter().all(|&x| x == F::ZERO),
"expected top half of quo_input to be zero"
Expand Down Expand Up @@ -62,16 +57,35 @@ pub(crate) fn generate<F: PrimeField64>(
);
lv[AUX_INPUT_REGISTER_0].copy_from_slice(&quo_input[..N_LIMBS]);
}
_ => panic!("expected filter to be IS_DIV or IS_MOD but it was {filter}"),
_ => panic!("expected filter to be IS_DIV, IS_SHR or IS_MOD but it was {filter}"),
};
}
/// Generate the output and auxiliary values for modular operations.
pub(crate) fn generate<F: PrimeField64>(
lv: &mut [F],
nv: &mut [F],
filter: usize,
input0: U256,
input1: U256,
result: U256,
) {
debug_assert!(lv.len() == NUM_ARITH_COLUMNS);

u256_to_array(&mut lv[INPUT_REGISTER_0], input0);
u256_to_array(&mut lv[INPUT_REGISTER_1], input1);
u256_to_array(&mut lv[OUTPUT_REGISTER], result);

generate_divmod(lv, nv, filter, INPUT_REGISTER_0, INPUT_REGISTER_1);
}

/// Verify that num = quo * den + rem and 0 <= rem < den.
fn eval_packed_divmod_helper<P: PackedField>(
pub(crate) fn eval_packed_divmod_helper<P: PackedField>(
lv: &[P; NUM_ARITH_COLUMNS],
nv: &[P; NUM_ARITH_COLUMNS],
yield_constr: &mut ConstraintConsumer<P>,
filter: P,
num_range: Range<usize>,
den_range: Range<usize>,
quo_range: Range<usize>,
rem_range: Range<usize>,
) {
Expand All @@ -80,8 +94,8 @@ fn eval_packed_divmod_helper<P: PackedField>(

yield_constr.constraint_last_row(filter);

let num = &lv[INPUT_REGISTER_0];
let den = read_value(lv, INPUT_REGISTER_1);
let num = &lv[num_range];
let den = read_value(lv, den_range);
let quo = {
let mut quo = [P::ZEROS; 2 * N_LIMBS];
quo[..N_LIMBS].copy_from_slice(&lv[quo_range]);
Expand All @@ -104,14 +118,13 @@ pub(crate) fn eval_packed<P: PackedField>(
nv: &[P; NUM_ARITH_COLUMNS],
yield_constr: &mut ConstraintConsumer<P>,
) {
// Constrain IS_SHR independently, so that it doesn't impact the
// constraints when combining the flag with IS_DIV.
yield_constr.constraint_last_row(lv[IS_SHR]);
eval_packed_divmod_helper(
lv,
nv,
yield_constr,
lv[IS_DIV] + lv[IS_SHR],
lv[IS_DIV],
INPUT_REGISTER_0,
INPUT_REGISTER_1,
OUTPUT_REGISTER,
AUX_INPUT_REGISTER_0,
);
Expand All @@ -120,24 +133,28 @@ pub(crate) fn eval_packed<P: PackedField>(
nv,
yield_constr,
lv[IS_MOD],
INPUT_REGISTER_0,
INPUT_REGISTER_1,
AUX_INPUT_REGISTER_0,
OUTPUT_REGISTER,
);
}

fn eval_ext_circuit_divmod_helper<F: RichField + Extendable<D>, const D: usize>(
pub(crate) fn eval_ext_circuit_divmod_helper<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
lv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
nv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
filter: ExtensionTarget<D>,
num_range: Range<usize>,
den_range: Range<usize>,
quo_range: Range<usize>,
rem_range: Range<usize>,
) {
yield_constr.constraint_last_row(builder, filter);

let num = &lv[INPUT_REGISTER_0];
let den = read_value(lv, INPUT_REGISTER_1);
let num = &lv[num_range];
let den = read_value(lv, den_range);
let quo = {
let zero = builder.zero_extension();
let mut quo = [zero; 2 * N_LIMBS];
Expand All @@ -164,14 +181,14 @@ pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
nv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
yield_constr.constraint_last_row(builder, lv[IS_SHR]);
let div_shr_flag = builder.add_extension(lv[IS_DIV], lv[IS_SHR]);
eval_ext_circuit_divmod_helper(
builder,
lv,
nv,
yield_constr,
div_shr_flag,
lv[IS_DIV],
INPUT_REGISTER_0,
INPUT_REGISTER_1,
OUTPUT_REGISTER,
AUX_INPUT_REGISTER_0,
);
Expand All @@ -181,6 +198,8 @@ pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
nv,
yield_constr,
lv[IS_MOD],
INPUT_REGISTER_0,
INPUT_REGISTER_1,
AUX_INPUT_REGISTER_0,
OUTPUT_REGISTER,
);
Expand Down Expand Up @@ -214,7 +233,7 @@ mod tests {
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
// Deactivate the SHR flag so that a DIV operation is not triggered.
// Since SHR uses the logic for DIV, `IS_SHR` should also be set to 0 here.
lv[IS_SHR] = F::ZERO;

let mut constraint_consumer = ConstraintConsumer::new(
Expand Down Expand Up @@ -247,6 +266,7 @@ mod tests {
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
// Since SHR uses the logic for DIV, `IS_SHR` should also be set to 0 here.
lv[IS_SHR] = F::ZERO;
lv[op_filter] = F::ONE;

Expand Down Expand Up @@ -308,6 +328,7 @@ mod tests {
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
// Since SHR uses the logic for DIV, `IS_SHR` should also be set to 0 here.
lv[IS_SHR] = F::ZERO;
lv[op_filter] = F::ONE;

Expand Down
33 changes: 29 additions & 4 deletions evm/src/arithmetic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mod byte;
mod divmod;
mod modular;
mod mul;
mod shift;
mod utils;

pub mod arithmetic_stark;
Expand All @@ -35,15 +36,29 @@ impl BinaryOperator {
pub(crate) fn result(&self, input0: U256, input1: U256) -> U256 {
match self {
BinaryOperator::Add => input0.overflowing_add(input1).0,
BinaryOperator::Mul | BinaryOperator::Shl => input0.overflowing_mul(input1).0,
BinaryOperator::Mul => input0.overflowing_mul(input1).0,
BinaryOperator::Shl => {
if input0 < U256::from(256usize) {
input1 << input0
} else {
U256::zero()
}
}
BinaryOperator::Sub => input0.overflowing_sub(input1).0,
BinaryOperator::Div | BinaryOperator::Shr => {
BinaryOperator::Div => {
if input1.is_zero() {
U256::zero()
} else {
input0 / input1
}
}
BinaryOperator::Shr => {
if input0 < U256::from(256usize) {
input1 >> input0
} else {
U256::zero()
}
}
BinaryOperator::Mod => {
if input1.is_zero() {
U256::zero()
Expand Down Expand Up @@ -238,15 +253,25 @@ fn binary_op_to_rows<F: PrimeField64>(
addcy::generate(&mut row, op.row_filter(), input0, input1);
(row, None)
}
BinaryOperator::Mul | BinaryOperator::Shl => {
BinaryOperator::Mul => {
mul::generate(&mut row, input0, input1);
(row, None)
}
BinaryOperator::Div | BinaryOperator::Mod | BinaryOperator::Shr => {
BinaryOperator::Shl => {
let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS];
shift::generate(&mut row, &mut nv, true, input0, input1, result);
unzvfu marked this conversation as resolved.
Show resolved Hide resolved
(row, None)
}
BinaryOperator::Div | BinaryOperator::Mod => {
let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS];
divmod::generate(&mut row, &mut nv, op.row_filter(), input0, input1, result);
(row, Some(nv))
}
BinaryOperator::Shr => {
let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS];
shift::generate(&mut row, &mut nv, false, input0, input1, result);
(row, Some(nv))
}
BinaryOperator::AddFp254 | BinaryOperator::MulFp254 | BinaryOperator::SubFp254 => {
ternary_op_to_rows::<F>(op.row_filter(), input0, input1, BN_BASE, result)
}
Expand Down
24 changes: 15 additions & 9 deletions evm/src/arithmetic/modular.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ pub(crate) fn generate_modular_op<F: PrimeField64>(

let mut mod_is_zero = F::ZERO;
if modulus.is_zero() {
if filter == columns::IS_DIV {
if filter == columns::IS_DIV || filter == columns::IS_SHR {
// set modulus = 2^256; the condition above means we know
// it's zero at this point, so we can just set bit 256.
modulus.set_bit(256, true);
Expand Down Expand Up @@ -330,7 +330,7 @@ pub(crate) fn generate_modular_op<F: PrimeField64>(

nv[MODULAR_MOD_IS_ZERO] = mod_is_zero;
nv[MODULAR_OUT_AUX_RED].copy_from_slice(&out_aux_red.map(F::from_canonical_i64));
nv[MODULAR_DIV_DENOM_IS_ZERO] = mod_is_zero * lv[IS_DIV];
nv[MODULAR_DIV_DENOM_IS_ZERO] = mod_is_zero * (lv[IS_DIV] + lv[IS_SHR]);

(
output_limbs.map(F::from_canonical_i64),
Expand Down Expand Up @@ -392,14 +392,14 @@ pub(crate) fn check_reduced<P: PackedField>(
// Verify that the output is reduced, i.e. output < modulus.
let out_aux_red = &nv[MODULAR_OUT_AUX_RED];
// This sets is_less_than to 1 unless we get mod_is_zero when
// doing a DIV; in that case, we need is_less_than=0, since
// doing a DIV or SHR; in that case, we need is_less_than=0, since
// eval_packed_generic_addcy checks
//
// modulus + out_aux_red == output + is_less_than*2^256
//
// and we are given output = out_aux_red when modulus is zero.
let mut is_less_than = [P::ZEROS; N_LIMBS];
is_less_than[0] = P::ONES - mod_is_zero * lv[IS_DIV];
is_less_than[0] = P::ONES - mod_is_zero * (lv[IS_DIV] + lv[IS_SHR]);
// NB: output and modulus in lv while out_aux_red and
// is_less_than (via mod_is_zero) depend on nv, hence the
// 'is_two_row_op' argument is set to 'true'.
Expand Down Expand Up @@ -448,13 +448,15 @@ pub(crate) fn modular_constr_poly<P: PackedField>(
// modulus = 0.
modulus[0] += mod_is_zero;

// Is 1 iff the operation is DIV and the denominator is zero.
// Is 1 iff the operation is DIV or SHR and the denominator is zero.
let div_denom_is_zero = nv[MODULAR_DIV_DENOM_IS_ZERO];
yield_constr.constraint_transition(filter * (mod_is_zero * lv[IS_DIV] - div_denom_is_zero));
yield_constr.constraint_transition(
filter * (mod_is_zero * (lv[IS_DIV] + lv[IS_SHR]) - div_denom_is_zero),
);

// Needed to compensate for adding mod_is_zero to modulus above,
// since the call eval_packed_generic_addcy() below subtracts modulus
// to verify in the case of a DIV.
// to verify in the case of a DIV or SHR.
output[0] += div_denom_is_zero;

check_reduced(lv, nv, yield_constr, filter, output, modulus, mod_is_zero);
Expand Down Expand Up @@ -635,7 +637,8 @@ pub(crate) fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, cons
modulus[0] = builder.add_extension(modulus[0], mod_is_zero);

let div_denom_is_zero = nv[MODULAR_DIV_DENOM_IS_ZERO];
let t = builder.mul_sub_extension(mod_is_zero, lv[IS_DIV], div_denom_is_zero);
let div_shr_filter = builder.add_extension(lv[IS_DIV], lv[IS_SHR]);
let t = builder.mul_sub_extension(mod_is_zero, div_shr_filter, div_denom_is_zero);
let t = builder.mul_extension(filter, t);
yield_constr.constraint_transition(builder, t);
output[0] = builder.add_extension(output[0], div_denom_is_zero);
Expand All @@ -645,7 +648,7 @@ pub(crate) fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, cons
let zero = builder.zero_extension();
let mut is_less_than = [zero; N_LIMBS];
is_less_than[0] =
builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, lv[IS_DIV], one);
builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, div_shr_filter, one);

eval_ext_circuit_addcy(
builder,
Expand Down Expand Up @@ -834,6 +837,7 @@ mod tests {
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
lv[IS_SHR] = F::ZERO;
lv[IS_DIV] = F::ZERO;
lv[IS_MOD] = F::ZERO;

Expand Down Expand Up @@ -867,6 +871,7 @@ mod tests {
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
lv[IS_SHR] = F::ZERO;
lv[IS_DIV] = F::ZERO;
lv[IS_MOD] = F::ZERO;
lv[op_filter] = F::ONE;
Expand Down Expand Up @@ -926,6 +931,7 @@ mod tests {
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
lv[IS_SHR] = F::ZERO;
lv[IS_DIV] = F::ZERO;
lv[IS_MOD] = F::ZERO;
lv[op_filter] = F::ONE;
Expand Down
Loading
Loading