Remove reg_preimage columns in KeccakStark (#1279)
* Remove reg_preimage columns in KeccakStark

* Apply comments

* Minor cleanup
LindaGuiga authored Oct 6, 2023
1 parent 0de6f94 commit e58d779
Showing 6 changed files with 98 additions and 95 deletions.
28 changes: 23 additions & 5 deletions evm/src/all_stark.rs
@@ -96,7 +96,8 @@ pub(crate) fn all_cross_table_lookups<F: Field>() -> Vec<CrossTableLookup<F>> {
ctl_arithmetic(),
ctl_byte_packing(),
ctl_keccak_sponge(),
ctl_keccak(),
ctl_keccak_inputs(),
ctl_keccak_outputs(),
ctl_logic(),
ctl_memory(),
]
@@ -131,16 +132,33 @@ fn ctl_byte_packing<F: Field>() -> CrossTableLookup<F> {
)
}

fn ctl_keccak<F: Field>() -> CrossTableLookup<F> {
// We now need two different looked tables for `KeccakStark`:
// one for the inputs and one for the outputs.
// They are linked with the timestamp.
fn ctl_keccak_inputs<F: Field>() -> CrossTableLookup<F> {
let keccak_sponge_looking = TableWithColumns::new(
Table::KeccakSponge,
keccak_sponge_stark::ctl_looking_keccak(),
keccak_sponge_stark::ctl_looking_keccak_inputs(),
Some(keccak_sponge_stark::ctl_looking_keccak_filter()),
);
let keccak_looked = TableWithColumns::new(
Table::Keccak,
keccak_stark::ctl_data(),
Some(keccak_stark::ctl_filter()),
keccak_stark::ctl_data_inputs(),
Some(keccak_stark::ctl_filter_inputs()),
);
CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked)
}

fn ctl_keccak_outputs<F: Field>() -> CrossTableLookup<F> {
let keccak_sponge_looking = TableWithColumns::new(
Table::KeccakSponge,
keccak_sponge_stark::ctl_looking_keccak_outputs(),
Some(keccak_sponge_stark::ctl_looking_keccak_filter()),
);
let keccak_looked = TableWithColumns::new(
Table::Keccak,
keccak_stark::ctl_data_outputs(),
Some(keccak_stark::ctl_filter_outputs()),
);
CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked)
}
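
A standalone sketch of how the two lookups fit together may help here (a conceptual illustration with made-up values and plain `Vec`s, not the actual `CrossTableLookup` machinery): the sponge table exposes `(input limbs, timestamp)` and `(output limbs, timestamp)` tuples, and each set must match, as a multiset, the tuples read from the Keccak table's first-round and last-round rows, so the shared timestamp ties a permutation's output back to its own input.

```rust
// Conceptual sketch only: the real linking is enforced by the
// `CrossTableLookup` argument above, not by explicit multiset checks.
use std::collections::HashMap;

type Row = Vec<u64>;

/// Count tuples so two collections can be compared as multisets.
fn multiset(rows: impl IntoIterator<Item = Row>) -> HashMap<Row, usize> {
    let mut counts = HashMap::new();
    for row in rows {
        *counts.entry(row).or_insert(0usize) += 1;
    }
    counts
}

fn main() {
    // (input limbs, output limbs, timestamp) for two toy permutations.
    let keccak_perms = vec![
        (vec![1u64, 2], vec![9u64, 8], 100u64),
        (vec![3, 4], vec![7, 6], 200),
    ];

    // What the Keccak table exposes: inputs on first-round rows and outputs
    // on last-round rows, each joined with the timestamp
    // (cf. `ctl_data_inputs` / `ctl_data_outputs`).
    let looked_inputs = keccak_perms
        .iter()
        .map(|(inp, _, t)| [inp.clone(), vec![*t]].concat());
    let looked_outputs = keccak_perms
        .iter()
        .map(|(_, out, t)| [out.clone(), vec![*t]].concat());

    // What the sponge table claims via `ctl_looking_keccak_{inputs,outputs}`.
    let looking_inputs = vec![vec![1, 2, 100], vec![3, 4, 200]];
    let looking_outputs = vec![vec![9, 8, 100], vec![7, 6, 200]];

    // Linking by timestamp: both multisets must agree, so an output can only
    // be claimed together with the timestamp of its own input.
    assert_eq!(multiset(looking_inputs), multiset(looked_inputs));
    assert_eq!(multiset(looking_outputs), multiset(looked_outputs));
}
```
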
14 changes: 5 additions & 9 deletions evm/src/keccak/columns.rs
@@ -20,7 +20,7 @@ pub fn reg_input_limb<F: Field>(i: usize) -> Column<F> {
let y = i_u64 / 5;
let x = i_u64 % 5;

let reg_low_limb = reg_preimage(x, y);
let reg_low_limb = reg_a(x, y);
let is_high_limb = i % 2;
Column::single(reg_low_limb + is_high_limb)
}
@@ -48,15 +48,11 @@ const R: [[u8; 5]; 5] = [
[27, 20, 39, 8, 14],
];

const START_PREIMAGE: usize = NUM_ROUNDS;
/// Registers to hold the original input to a permutation, i.e. the input to the first round.
pub(crate) const fn reg_preimage(x: usize, y: usize) -> usize {
debug_assert!(x < 5);
debug_assert!(y < 5);
START_PREIMAGE + (x * 5 + y) * 2
}
/// Column holding the timestamp, used to link inputs and outputs
/// in the `KeccakSpongeStark`.
pub(crate) const TIMESTAMP: usize = NUM_ROUNDS;

const START_A: usize = START_PREIMAGE + 5 * 5 * 2;
const START_A: usize = TIMESTAMP + 1;
pub(crate) const fn reg_a(x: usize, y: usize) -> usize {
debug_assert!(x < 5);
debug_assert!(y < 5);
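
With the preimage registers gone, the slot right after the 24 round flags is a single `TIMESTAMP` column, and the `A` registers follow it. Below is a quick sketch of the resulting index arithmetic; the body of `reg_a` is collapsed in the diff above, so the `(x * 5 + y) * 2` stride is an assumption carried over from the removed `reg_preimage`.

```rust
// Index arithmetic implied by the new layout in columns.rs.
const NUM_ROUNDS: usize = 24;         // round-flag columns 0..24
const TIMESTAMP: usize = NUM_ROUNDS;  // column 24
const START_A: usize = TIMESTAMP + 1; // A registers start at column 25

/// Low 32-bit limb of A[x, y]; the high limb sits in the next column.
/// (Formula assumed to mirror the removed `reg_preimage`.)
const fn reg_a(x: usize, y: usize) -> usize {
    START_A + (x * 5 + y) * 2
}

fn main() {
    assert_eq!(reg_a(0, 0), 25); // low limb of A[0, 0]
    assert_eq!(reg_a(4, 4), 73); // low limb of A[4, 4]; its high limb is 74
    // The old layout kept 5 * 5 * 2 = 50 preimage columns in this slot, so a
    // single timestamp column shifts START_A down by 49.
}
```
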
121 changes: 49 additions & 72 deletions evm/src/keccak/keccak_stark.rs
@@ -11,13 +11,13 @@ use plonky2::plonk::plonk_common::reduce_with_powers_ext_circuit;
use plonky2::timed;
use plonky2::util::timing::TimingTree;

use super::columns::reg_input_limb;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cross_table_lookup::Column;
use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame};
use crate::keccak::columns::{
reg_a, reg_a_prime, reg_a_prime_prime, reg_a_prime_prime_0_0_bit, reg_a_prime_prime_prime,
reg_b, reg_c, reg_c_prime, reg_input_limb, reg_output_limb, reg_preimage, reg_step,
NUM_COLUMNS,
reg_b, reg_c, reg_c_prime, reg_output_limb, reg_step, NUM_COLUMNS, TIMESTAMP,
};
use crate::keccak::constants::{rc_value, rc_value_bit};
use crate::keccak::logic::{
@@ -33,13 +33,23 @@ pub(crate) const NUM_ROUNDS: usize = 24;
/// Number of 64-bit elements in the Keccak permutation input.
pub(crate) const NUM_INPUTS: usize = 25;

pub fn ctl_data<F: Field>() -> Vec<Column<F>> {
pub fn ctl_data_inputs<F: Field>() -> Vec<Column<F>> {
let mut res: Vec<_> = (0..2 * NUM_INPUTS).map(reg_input_limb).collect();
res.extend(Column::singles((0..2 * NUM_INPUTS).map(reg_output_limb)));
res.push(Column::single(TIMESTAMP));
res
}

pub fn ctl_filter<F: Field>() -> Column<F> {
pub fn ctl_data_outputs<F: Field>() -> Vec<Column<F>> {
let mut res: Vec<_> = Column::singles((0..2 * NUM_INPUTS).map(reg_output_limb)).collect();
res.push(Column::single(TIMESTAMP));
res
}

pub fn ctl_filter_inputs<F: Field>() -> Column<F> {
Column::single(reg_step(0))
}

pub fn ctl_filter_outputs<F: Field>() -> Column<F> {
Column::single(reg_step(NUM_ROUNDS - 1))
}
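
The two filters select opposite ends of each permutation's 24-row block: the `reg_step(i)` columns act as a one-hot round indicator (and are all zero on padding rows), so `reg_step(0)` fires on the row that carries the input and `reg_step(NUM_ROUNDS - 1)` on the row that carries the output. A toy illustration with plain integers in place of `Column` values:

```rust
// Toy model of the round flags over two consecutive permutations (48 rows).
const NUM_ROUNDS: usize = 24;

/// `reg_step(i)` modelled as a one-hot flag: 1 on round `i`'s row, else 0.
fn step_flag(row: usize, i: usize) -> u64 {
    u64::from(row % NUM_ROUNDS == i)
}

fn main() {
    let input_rows: Vec<usize> = (0..2 * NUM_ROUNDS)
        .filter(|&row| step_flag(row, 0) == 1)
        .collect();
    let output_rows: Vec<usize> = (0..2 * NUM_ROUNDS)
        .filter(|&row| step_flag(row, NUM_ROUNDS - 1) == 1)
        .collect();

    // ctl_filter_inputs fires on the first row of each permutation,
    // ctl_filter_outputs on its last row.
    assert_eq!(input_rows, vec![0, 24]);
    assert_eq!(output_rows, vec![23, 47]);
}
```
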

Expand All @@ -53,16 +63,16 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
/// in our lookup arguments, as those are computed after transposing to column-wise form.
fn generate_trace_rows(
&self,
inputs: Vec<[u64; NUM_INPUTS]>,
inputs_and_timestamps: Vec<([u64; NUM_INPUTS], usize)>,
min_rows: usize,
) -> Vec<[F; NUM_COLUMNS]> {
let num_rows = (inputs.len() * NUM_ROUNDS)
let num_rows = (inputs_and_timestamps.len() * NUM_ROUNDS)
.max(min_rows)
.next_power_of_two();

let mut rows = Vec::with_capacity(num_rows);
for input in inputs.iter() {
let rows_for_perm = self.generate_trace_rows_for_perm(*input);
for input_and_timestamp in inputs_and_timestamps.iter() {
let rows_for_perm = self.generate_trace_rows_for_perm(*input_and_timestamp);
rows.extend(rows_for_perm);
}

@@ -72,20 +82,19 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
rows
}

fn generate_trace_rows_for_perm(&self, input: [u64; NUM_INPUTS]) -> Vec<[F; NUM_COLUMNS]> {
fn generate_trace_rows_for_perm(
&self,
input_and_timestamp: ([u64; NUM_INPUTS], usize),
) -> Vec<[F; NUM_COLUMNS]> {
let mut rows = vec![[F::ZERO; NUM_COLUMNS]; NUM_ROUNDS];

// Populate the preimage for each row.
let input = input_and_timestamp.0;
let timestamp = input_and_timestamp.1;
// Set the timestamp of the current input.
// It will be checked against the value in `KeccakSponge`.
// The timestamp is used to link the input and output of
// the same permutation together.
for round in 0..24 {
for x in 0..5 {
for y in 0..5 {
let input_xy = input[y * 5 + x];
let reg_preimage_lo = reg_preimage(x, y);
let reg_preimage_hi = reg_preimage_lo + 1;
rows[round][reg_preimage_lo] = F::from_canonical_u64(input_xy & 0xFFFFFFFF);
rows[round][reg_preimage_hi] = F::from_canonical_u64(input_xy >> 32);
}
}
rows[round][TIMESTAMP] = F::from_canonical_usize(timestamp);
}
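// A single shared timestamp per permutation replaces the old per-round
// preimage copies: the two CTLs read it on the first and last rows, and the
// transition constraints keep it constant across the rounds in between.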

// Populate the round input for the first round.
@@ -220,7 +229,7 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {

pub fn generate_trace(
&self,
inputs: Vec<[u64; NUM_INPUTS]>,
inputs: Vec<([u64; NUM_INPUTS], usize)>,
min_rows: usize,
timing: &mut TimingTree,
) -> Vec<PolynomialValues<F>> {
@@ -269,26 +278,14 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let not_final_step = P::ONES - final_step;
yield_constr.constraint(not_final_step * filter);

// If this is not the final step, the local and next preimages must match.
// Also, if this is the first step, the preimage must match A.
let is_first_step = local_values[reg_step(0)];
for x in 0..5 {
for y in 0..5 {
let reg_preimage_lo = reg_preimage(x, y);
let reg_preimage_hi = reg_preimage_lo + 1;
let diff_lo = local_values[reg_preimage_lo] - next_values[reg_preimage_lo];
let diff_hi = local_values[reg_preimage_hi] - next_values[reg_preimage_hi];
yield_constr.constraint_transition(not_final_step * diff_lo);
yield_constr.constraint_transition(not_final_step * diff_hi);

let reg_a_lo = reg_a(x, y);
let reg_a_hi = reg_a_lo + 1;
let diff_lo = local_values[reg_preimage_lo] - local_values[reg_a_lo];
let diff_hi = local_values[reg_preimage_hi] - local_values[reg_a_hi];
yield_constr.constraint(is_first_step * diff_lo);
yield_constr.constraint(is_first_step * diff_hi);
}
}
// If this is not the final step or a padding row,
// the local and next timestamps must match.
let sum_round_flags = (0..NUM_ROUNDS)
.map(|i| local_values[reg_step(i)])
.sum::<P>();
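// The product below is
//   (sum_i step_i) * (1 - step_{NUM_ROUNDS - 1}) * (T_next - T_local),
// so the timestamp is forced to stay constant across a permutation's rows,
// while padding rows (all round flags zero) and the final round are exempt.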
yield_constr.constraint(
sum_round_flags * not_final_step * (next_values[TIMESTAMP] - local_values[TIMESTAMP]),
);

// C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]).
for x in 0..5 {
@@ -454,34 +451,13 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let constraint = builder.mul_extension(not_final_step, filter);
yield_constr.constraint(builder, constraint);

// If this is not the final step, the local and next preimages must match.
// Also, if this is the first step, the preimage must match A.
let is_first_step = local_values[reg_step(0)];
for x in 0..5 {
for y in 0..5 {
let reg_preimage_lo = reg_preimage(x, y);
let reg_preimage_hi = reg_preimage_lo + 1;
let diff = builder
.sub_extension(local_values[reg_preimage_lo], next_values[reg_preimage_lo]);
let constraint = builder.mul_extension(not_final_step, diff);
yield_constr.constraint_transition(builder, constraint);
let diff = builder
.sub_extension(local_values[reg_preimage_hi], next_values[reg_preimage_hi]);
let constraint = builder.mul_extension(not_final_step, diff);
yield_constr.constraint_transition(builder, constraint);

let reg_a_lo = reg_a(x, y);
let reg_a_hi = reg_a_lo + 1;
let diff_lo =
builder.sub_extension(local_values[reg_preimage_lo], local_values[reg_a_lo]);
let constraint = builder.mul_extension(is_first_step, diff_lo);
yield_constr.constraint(builder, constraint);
let diff_hi =
builder.sub_extension(local_values[reg_preimage_hi], local_values[reg_a_hi]);
let constraint = builder.mul_extension(is_first_step, diff_hi);
yield_constr.constraint(builder, constraint);
}
}
// If this is not the final step or a padding row,
// the local and next timestamps must match.
let sum_round_flags =
builder.add_many_extension((0..NUM_ROUNDS).map(|i| local_values[reg_step(i)]));
let diff = builder.sub_extension(next_values[TIMESTAMP], local_values[TIMESTAMP]);
let constr = builder.mul_many_extension([sum_round_flags, not_final_step, diff]);
yield_constr.constraint(builder, constr);

// C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]).
for x in 0..5 {
@@ -699,7 +675,7 @@ mod tests {
f: Default::default(),
};

let rows = stark.generate_trace_rows(vec![input], 8);
let rows = stark.generate_trace_rows(vec![(input, 0)], 8);
let last_row = rows[NUM_ROUNDS - 1];
let output = (0..NUM_INPUTS)
.map(|i| {
Expand Down Expand Up @@ -732,7 +708,8 @@ mod tests {

init_logger();

let input: Vec<[u64; NUM_INPUTS]> = (0..NUM_PERMS).map(|_| rand::random()).collect();
let input: Vec<([u64; NUM_INPUTS], usize)> =
(0..NUM_PERMS).map(|_| (rand::random(), 0)).collect();

let mut timing = TimingTree::new("prove", log::Level::Debug);
let trace_poly_values = timed!(
12 changes: 10 additions & 2 deletions evm/src/keccak_sponge/keccak_sponge_stark.rs
@@ -47,7 +47,7 @@ pub(crate) fn ctl_looked_data<F: Field>() -> Vec<Column<F>> {
.collect()
}

pub(crate) fn ctl_looking_keccak<F: Field>() -> Vec<Column<F>> {
pub(crate) fn ctl_looking_keccak_inputs<F: Field>() -> Vec<Column<F>> {
let cols = KECCAK_SPONGE_COL_MAP;
let mut res: Vec<_> = Column::singles(
[
@@ -57,6 +57,13 @@ pub(crate) fn ctl_looking_keccak<F: Field>() -> Vec<Column<F>> {
.concat(),
)
.collect();
res.push(Column::single(cols.timestamp));

res
}

pub(crate) fn ctl_looking_keccak_outputs<F: Field>() -> Vec<Column<F>> {
let cols = KECCAK_SPONGE_COL_MAP;

// We recover the 32-bit digest limbs from their corresponding bytes,
// and then append them to the rest of the updated state limbs.
@@ -68,9 +75,10 @@ pub(crate) fn ctl_looking_keccak<F: Field>() -> Vec<Column<F>> {
)
});

res.extend(digest_u32s);
let mut res: Vec<_> = digest_u32s.collect();

res.extend(Column::singles(&cols.partial_updated_state_u32s));
res.push(Column::single(cols.timestamp));

res
}
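
Because the sponge stores its digest as bytes, the looking columns for the outputs have to recombine those bytes into 32-bit limbs before they can match `reg_output_limb` on the Keccak side. Here is a standalone sketch of that recombination with plain integers (byte order assumed little-endian here; the real code expresses the same thing as `Column` linear combinations over the byte columns):

```rust
// Conceptual sketch: rebuild 32-bit digest limbs from their bytes.
fn digest_bytes_to_u32_limbs(digest_bytes: &[u8]) -> Vec<u32> {
    digest_bytes
        .chunks_exact(4)
        .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap()))
        .collect()
}

fn main() {
    // A 32-byte Keccak-256 digest becomes eight 32-bit limbs.
    let digest: Vec<u8> = (0u8..32).collect();
    let limbs = digest_bytes_to_u32_limbs(&digest);
    assert_eq!(limbs.len(), 8);
    assert_eq!(limbs[0], u32::from_le_bytes([0, 1, 2, 3]));
}
```
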
10 changes: 5 additions & 5 deletions evm/src/witness/traces.rs
@@ -36,7 +36,7 @@ pub(crate) struct Traces<T: Copy> {
pub(crate) cpu: Vec<CpuColumnsView<T>>,
pub(crate) logic_ops: Vec<logic::Operation>,
pub(crate) memory_ops: Vec<MemoryOp>,
pub(crate) keccak_inputs: Vec<[u64; keccak::keccak_stark::NUM_INPUTS]>,
pub(crate) keccak_inputs: Vec<([u64; keccak::keccak_stark::NUM_INPUTS], usize)>,
pub(crate) keccak_sponge_ops: Vec<KeccakSpongeOp>,
}

@@ -131,18 +131,18 @@ impl<T: Copy> Traces<T> {
self.byte_packing_ops.push(op);
}

pub fn push_keccak(&mut self, input: [u64; keccak::keccak_stark::NUM_INPUTS]) {
self.keccak_inputs.push(input);
pub fn push_keccak(&mut self, input: [u64; keccak::keccak_stark::NUM_INPUTS], clock: usize) {
self.keccak_inputs.push((input, clock));
}

pub fn push_keccak_bytes(&mut self, input: [u8; KECCAK_WIDTH_BYTES]) {
pub fn push_keccak_bytes(&mut self, input: [u8; KECCAK_WIDTH_BYTES], clock: usize) {
let chunks = input
.chunks(size_of::<u64>())
.map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap()))
.collect_vec()
.try_into()
.unwrap();
self.push_keccak(chunks);
self.push_keccak(chunks, clock);
}

pub fn push_keccak_sponge(&mut self, op: KeccakSpongeOp) {
8 changes: 6 additions & 2 deletions evm/src/witness/util.rs
@@ -229,7 +229,9 @@ pub(crate) fn keccak_sponge_log<F: Field>(
address.increment();
}
xor_into_sponge(state, &mut sponge_state, block.try_into().unwrap());
state.traces.push_keccak_bytes(sponge_state);
state
.traces
.push_keccak_bytes(sponge_state, clock * NUM_CHANNELS);
keccakf_u8s(&mut sponge_state);
}

@@ -254,7 +256,9 @@
final_block[KECCAK_RATE_BYTES - 1] = 0b10000000;
}
xor_into_sponge(state, &mut sponge_state, &final_block);
state.traces.push_keccak_bytes(sponge_state);
state
.traces
.push_keccak_bytes(sponge_state, clock * NUM_CHANNELS);

state.traces.push_keccak_sponge(KeccakSpongeOp {
base_address,
