Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove reg_preimage columns in KeccakStark #1279

Merged
merged 3 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions evm/src/all_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ pub(crate) fn all_cross_table_lookups<F: Field>() -> Vec<CrossTableLookup<F>> {
ctl_arithmetic(),
ctl_byte_packing(),
ctl_keccak_sponge(),
ctl_keccak(),
ctl_keccak_inputs(),
ctl_keccak_outputs(),
ctl_logic(),
ctl_memory(),
]
Expand Down Expand Up @@ -131,16 +132,33 @@ fn ctl_byte_packing<F: Field>() -> CrossTableLookup<F> {
)
}

fn ctl_keccak<F: Field>() -> CrossTableLookup<F> {
// We now need two different looked tables for `KeccakStark`:
// one for the inputs and one for the outputs.
// They are linked with the timestamp.
fn ctl_keccak_inputs<F: Field>() -> CrossTableLookup<F> {
let keccak_sponge_looking = TableWithColumns::new(
Table::KeccakSponge,
keccak_sponge_stark::ctl_looking_keccak(),
keccak_sponge_stark::ctl_looking_keccak_inputs(),
Some(keccak_sponge_stark::ctl_looking_keccak_filter()),
);
let keccak_looked = TableWithColumns::new(
Table::Keccak,
keccak_stark::ctl_data(),
Some(keccak_stark::ctl_filter()),
keccak_stark::ctl_data_inputs(),
Some(keccak_stark::ctl_filter_inputs()),
);
CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked)
}

fn ctl_keccak_outputs<F: Field>() -> CrossTableLookup<F> {
let keccak_sponge_looking = TableWithColumns::new(
Table::KeccakSponge,
keccak_sponge_stark::ctl_looking_keccak_outputs(),
Some(keccak_sponge_stark::ctl_looking_keccak_filter()),
);
let keccak_looked = TableWithColumns::new(
Table::Keccak,
keccak_stark::ctl_data_outputs(),
Some(keccak_stark::ctl_filter_outputs()),
);
CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked)
}
Expand Down
14 changes: 5 additions & 9 deletions evm/src/keccak/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub fn reg_input_limb<F: Field>(i: usize) -> Column<F> {
let y = i_u64 / 5;
let x = i_u64 % 5;

let reg_low_limb = reg_preimage(x, y);
let reg_low_limb = reg_a(x, y);
let is_high_limb = i % 2;
Column::single(reg_low_limb + is_high_limb)
}
Expand Down Expand Up @@ -48,15 +48,11 @@ const R: [[u8; 5]; 5] = [
[27, 20, 39, 8, 14],
];

const START_PREIMAGE: usize = NUM_ROUNDS;
/// Registers to hold the original input to a permutation, i.e. the input to the first round.
pub(crate) const fn reg_preimage(x: usize, y: usize) -> usize {
debug_assert!(x < 5);
debug_assert!(y < 5);
START_PREIMAGE + (x * 5 + y) * 2
}
/// Column holding the timestamp, used to link inputs and outputs
/// in the `KeccakSpongeStark`.
pub(crate) const TIMESTAMP: usize = NUM_ROUNDS;

const START_A: usize = START_PREIMAGE + 5 * 5 * 2;
const START_A: usize = TIMESTAMP + 1;
pub(crate) const fn reg_a(x: usize, y: usize) -> usize {
debug_assert!(x < 5);
debug_assert!(y < 5);
Expand Down
121 changes: 49 additions & 72 deletions evm/src/keccak/keccak_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ use plonky2::plonk::plonk_common::reduce_with_powers_ext_circuit;
use plonky2::timed;
use plonky2::util::timing::TimingTree;

use super::columns::reg_input_limb;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cross_table_lookup::Column;
use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame};
use crate::keccak::columns::{
reg_a, reg_a_prime, reg_a_prime_prime, reg_a_prime_prime_0_0_bit, reg_a_prime_prime_prime,
reg_b, reg_c, reg_c_prime, reg_input_limb, reg_output_limb, reg_preimage, reg_step,
NUM_COLUMNS,
reg_b, reg_c, reg_c_prime, reg_output_limb, reg_step, NUM_COLUMNS, TIMESTAMP,
};
use crate::keccak::constants::{rc_value, rc_value_bit};
use crate::keccak::logic::{
Expand All @@ -33,13 +33,23 @@ pub(crate) const NUM_ROUNDS: usize = 24;
/// Number of 64-bit elements in the Keccak permutation input.
pub(crate) const NUM_INPUTS: usize = 25;

pub fn ctl_data<F: Field>() -> Vec<Column<F>> {
pub fn ctl_data_inputs<F: Field>() -> Vec<Column<F>> {
let mut res: Vec<_> = (0..2 * NUM_INPUTS).map(reg_input_limb).collect();
res.extend(Column::singles((0..2 * NUM_INPUTS).map(reg_output_limb)));
res.push(Column::single(TIMESTAMP));
res
}

pub fn ctl_filter<F: Field>() -> Column<F> {
pub fn ctl_data_outputs<F: Field>() -> Vec<Column<F>> {
let mut res: Vec<_> = Column::singles((0..2 * NUM_INPUTS).map(reg_output_limb)).collect();
res.push(Column::single(TIMESTAMP));
res
}

pub fn ctl_filter_inputs<F: Field>() -> Column<F> {
Column::single(reg_step(0))
}

pub fn ctl_filter_outputs<F: Field>() -> Column<F> {
Column::single(reg_step(NUM_ROUNDS - 1))
}

Expand All @@ -53,16 +63,16 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
/// in our lookup arguments, as those are computed after transposing to column-wise form.
fn generate_trace_rows(
&self,
inputs: Vec<[u64; NUM_INPUTS]>,
inputs_and_timestamps: Vec<([u64; NUM_INPUTS], usize)>,
min_rows: usize,
) -> Vec<[F; NUM_COLUMNS]> {
let num_rows = (inputs.len() * NUM_ROUNDS)
let num_rows = (inputs_and_timestamps.len() * NUM_ROUNDS)
.max(min_rows)
.next_power_of_two();

let mut rows = Vec::with_capacity(num_rows);
for input in inputs.iter() {
let rows_for_perm = self.generate_trace_rows_for_perm(*input);
for input_and_timestamp in inputs_and_timestamps.iter() {
let rows_for_perm = self.generate_trace_rows_for_perm(*input_and_timestamp);
rows.extend(rows_for_perm);
}

Expand All @@ -72,20 +82,19 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
rows
}

fn generate_trace_rows_for_perm(&self, input: [u64; NUM_INPUTS]) -> Vec<[F; NUM_COLUMNS]> {
fn generate_trace_rows_for_perm(
&self,
input_and_timestamp: ([u64; NUM_INPUTS], usize),
) -> Vec<[F; NUM_COLUMNS]> {
let mut rows = vec![[F::ZERO; NUM_COLUMNS]; NUM_ROUNDS];

// Populate the preimage for each row.
let input = input_and_timestamp.0;
let timestamp = input_and_timestamp.1;
// Set the timestamp of the current input.
// It will be checked against the value in `KeccakSponge`.
// The timestamp is used to link the input and output of
// the same permutation together.
for round in 0..24 {
for x in 0..5 {
for y in 0..5 {
let input_xy = input[y * 5 + x];
let reg_preimage_lo = reg_preimage(x, y);
let reg_preimage_hi = reg_preimage_lo + 1;
rows[round][reg_preimage_lo] = F::from_canonical_u64(input_xy & 0xFFFFFFFF);
rows[round][reg_preimage_hi] = F::from_canonical_u64(input_xy >> 32);
}
}
rows[round][TIMESTAMP] = F::from_canonical_usize(timestamp);
}

// Populate the round input for the first round.
Expand Down Expand Up @@ -220,7 +229,7 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {

pub fn generate_trace(
&self,
inputs: Vec<[u64; NUM_INPUTS]>,
inputs: Vec<([u64; NUM_INPUTS], usize)>,
min_rows: usize,
timing: &mut TimingTree,
) -> Vec<PolynomialValues<F>> {
Expand Down Expand Up @@ -269,26 +278,14 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let not_final_step = P::ONES - final_step;
yield_constr.constraint(not_final_step * filter);

// If this is not the final step, the local and next preimages must match.
// Also, if this is the first step, the preimage must match A.
let is_first_step = local_values[reg_step(0)];
for x in 0..5 {
for y in 0..5 {
let reg_preimage_lo = reg_preimage(x, y);
let reg_preimage_hi = reg_preimage_lo + 1;
let diff_lo = local_values[reg_preimage_lo] - next_values[reg_preimage_lo];
let diff_hi = local_values[reg_preimage_hi] - next_values[reg_preimage_hi];
yield_constr.constraint_transition(not_final_step * diff_lo);
yield_constr.constraint_transition(not_final_step * diff_hi);

let reg_a_lo = reg_a(x, y);
let reg_a_hi = reg_a_lo + 1;
let diff_lo = local_values[reg_preimage_lo] - local_values[reg_a_lo];
let diff_hi = local_values[reg_preimage_hi] - local_values[reg_a_hi];
yield_constr.constraint(is_first_step * diff_lo);
yield_constr.constraint(is_first_step * diff_hi);
}
}
// If this is not the final step or a padding row,
// the local and next timestamps must match.
let sum_round_flags = (0..NUM_ROUNDS)
.map(|i| local_values[reg_step(i)])
.sum::<P>();
yield_constr.constraint(
sum_round_flags * not_final_step * (next_values[TIMESTAMP] - local_values[TIMESTAMP]),
);

// C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]).
for x in 0..5 {
Expand Down Expand Up @@ -454,34 +451,13 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let constraint = builder.mul_extension(not_final_step, filter);
yield_constr.constraint(builder, constraint);

// If this is not the final step, the local and next preimages must match.
// Also, if this is the first step, the preimage must match A.
let is_first_step = local_values[reg_step(0)];
for x in 0..5 {
for y in 0..5 {
let reg_preimage_lo = reg_preimage(x, y);
let reg_preimage_hi = reg_preimage_lo + 1;
let diff = builder
.sub_extension(local_values[reg_preimage_lo], next_values[reg_preimage_lo]);
let constraint = builder.mul_extension(not_final_step, diff);
yield_constr.constraint_transition(builder, constraint);
let diff = builder
.sub_extension(local_values[reg_preimage_hi], next_values[reg_preimage_hi]);
let constraint = builder.mul_extension(not_final_step, diff);
yield_constr.constraint_transition(builder, constraint);

let reg_a_lo = reg_a(x, y);
let reg_a_hi = reg_a_lo + 1;
let diff_lo =
builder.sub_extension(local_values[reg_preimage_lo], local_values[reg_a_lo]);
let constraint = builder.mul_extension(is_first_step, diff_lo);
yield_constr.constraint(builder, constraint);
let diff_hi =
builder.sub_extension(local_values[reg_preimage_hi], local_values[reg_a_hi]);
let constraint = builder.mul_extension(is_first_step, diff_hi);
yield_constr.constraint(builder, constraint);
}
}
// If this is not the final step or a padding row,
// the local and next timestamps must match.
let sum_round_flags =
builder.add_many_extension((0..NUM_ROUNDS).map(|i| local_values[reg_step(i)]));
let diff = builder.sub_extension(next_values[TIMESTAMP], local_values[TIMESTAMP]);
let constr = builder.mul_many_extension([sum_round_flags, not_final_step, diff]);
yield_constr.constraint(builder, constr);

// C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]).
for x in 0..5 {
Expand Down Expand Up @@ -699,7 +675,7 @@ mod tests {
f: Default::default(),
};

let rows = stark.generate_trace_rows(vec![input], 8);
let rows = stark.generate_trace_rows(vec![(input, 0)], 8);
let last_row = rows[NUM_ROUNDS - 1];
let output = (0..NUM_INPUTS)
.map(|i| {
Expand Down Expand Up @@ -732,7 +708,8 @@ mod tests {

init_logger();

let input: Vec<[u64; NUM_INPUTS]> = (0..NUM_PERMS).map(|_| rand::random()).collect();
let input: Vec<([u64; NUM_INPUTS], usize)> =
(0..NUM_PERMS).map(|_| (rand::random(), 0)).collect();

let mut timing = TimingTree::new("prove", log::Level::Debug);
let trace_poly_values = timed!(
Expand Down
12 changes: 10 additions & 2 deletions evm/src/keccak_sponge/keccak_sponge_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub(crate) fn ctl_looked_data<F: Field>() -> Vec<Column<F>> {
.collect()
}

pub(crate) fn ctl_looking_keccak<F: Field>() -> Vec<Column<F>> {
pub(crate) fn ctl_looking_keccak_inputs<F: Field>() -> Vec<Column<F>> {
let cols = KECCAK_SPONGE_COL_MAP;
let mut res: Vec<_> = Column::singles(
[
Expand All @@ -57,6 +57,13 @@ pub(crate) fn ctl_looking_keccak<F: Field>() -> Vec<Column<F>> {
.concat(),
)
.collect();
res.push(Column::single(cols.timestamp));

res
}

pub(crate) fn ctl_looking_keccak_outputs<F: Field>() -> Vec<Column<F>> {
let cols = KECCAK_SPONGE_COL_MAP;

// We recover the 32-bit digest limbs from their corresponding bytes,
// and then append them to the rest of the updated state limbs.
Expand All @@ -68,9 +75,10 @@ pub(crate) fn ctl_looking_keccak<F: Field>() -> Vec<Column<F>> {
)
});

res.extend(digest_u32s);
let mut res: Vec<_> = digest_u32s.collect();

res.extend(Column::singles(&cols.partial_updated_state_u32s));
res.push(Column::single(cols.timestamp));

res
}
Expand Down
10 changes: 5 additions & 5 deletions evm/src/witness/traces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub(crate) struct Traces<T: Copy> {
pub(crate) cpu: Vec<CpuColumnsView<T>>,
pub(crate) logic_ops: Vec<logic::Operation>,
pub(crate) memory_ops: Vec<MemoryOp>,
pub(crate) keccak_inputs: Vec<[u64; keccak::keccak_stark::NUM_INPUTS]>,
pub(crate) keccak_inputs: Vec<([u64; keccak::keccak_stark::NUM_INPUTS], usize)>,
pub(crate) keccak_sponge_ops: Vec<KeccakSpongeOp>,
}

Expand Down Expand Up @@ -131,18 +131,18 @@ impl<T: Copy> Traces<T> {
self.byte_packing_ops.push(op);
}

pub fn push_keccak(&mut self, input: [u64; keccak::keccak_stark::NUM_INPUTS]) {
self.keccak_inputs.push(input);
pub fn push_keccak(&mut self, input: [u64; keccak::keccak_stark::NUM_INPUTS], clock: usize) {
self.keccak_inputs.push((input, clock));
}

pub fn push_keccak_bytes(&mut self, input: [u8; KECCAK_WIDTH_BYTES]) {
pub fn push_keccak_bytes(&mut self, input: [u8; KECCAK_WIDTH_BYTES], clock: usize) {
let chunks = input
.chunks(size_of::<u64>())
.map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap()))
.collect_vec()
.try_into()
.unwrap();
self.push_keccak(chunks);
self.push_keccak(chunks, clock);
}

pub fn push_keccak_sponge(&mut self, op: KeccakSpongeOp) {
Expand Down
8 changes: 6 additions & 2 deletions evm/src/witness/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,9 @@ pub(crate) fn keccak_sponge_log<F: Field>(
address.increment();
}
xor_into_sponge(state, &mut sponge_state, block.try_into().unwrap());
state.traces.push_keccak_bytes(sponge_state);
state
.traces
.push_keccak_bytes(sponge_state, clock * NUM_CHANNELS);
keccakf_u8s(&mut sponge_state);
}

Expand All @@ -254,7 +256,9 @@ pub(crate) fn keccak_sponge_log<F: Field>(
final_block[KECCAK_RATE_BYTES - 1] = 0b10000000;
}
xor_into_sponge(state, &mut sponge_state, &final_block);
state.traces.push_keccak_bytes(sponge_state);
state
.traces
.push_keccak_bytes(sponge_state, clock * NUM_CHANNELS);

state.traces.push_keccak_sponge(KeccakSpongeOp {
base_address,
Expand Down
Loading