Remove redundant Keccak sponge cols (#1233)
* Rename columns in KeccakSponge for clarity

* Remove redundant columns

* Apply comments
Nashtare authored Sep 14, 2023
1 parent 06bc73f commit 19220b2
Showing 2 changed files with 114 additions and 61 deletions.
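
The gist of the change: the sponge trace previously stored all 50 post-permutation u32 limbs plus a separate byte decomposition of the 8 digest limbs, so those 8 limbs appeared twice. The diff below keeps only the 42 non-digest limbs as u32 columns and rebuilds the digest limbs from their byte columns wherever they are needed. A minimal standalone sketch of that little-endian round trip (plain u8/u32 arithmetic rather than the field elements used in the actual code; not code from the PR):

    // Sketch: split a 32-bit limb into its 4 little-endian bytes, as trace generation
    // now does for the first KECCAK_DIGEST_U32S limbs of the post-permutation state.
    fn decompose_le(limb: u32) -> [u8; 4] {
        limb.to_le_bytes()
    }

    // Rebuild the limb as b0 + b1*2^8 + b2*2^16 + b3*2^24; the CTL columns and the
    // STARK constraints apply the same weights to the byte columns.
    fn recompose_le(bytes: [u8; 4]) -> u32 {
        bytes
            .iter()
            .enumerate()
            .map(|(i, &b)| (b as u32) << (8 * i))
            .sum()
    }

    fn main() {
        let limb = 0xDEAD_BEEFu32;
        assert_eq!(recompose_le(decompose_le(limb)), limb);
    }
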
15 changes: 11 additions & 4 deletions evm/src/keccak_sponge/columns.rs
@@ -5,11 +5,14 @@ use crate::util::{indices_arr, transmute_no_compile_time_size_checks};

pub(crate) const KECCAK_WIDTH_BYTES: usize = 200;
pub(crate) const KECCAK_WIDTH_U32S: usize = KECCAK_WIDTH_BYTES / 4;
pub(crate) const KECCAK_WIDTH_MINUS_DIGEST_U32S: usize =
(KECCAK_WIDTH_BYTES - KECCAK_DIGEST_BYTES) / 4;
pub(crate) const KECCAK_RATE_BYTES: usize = 136;
pub(crate) const KECCAK_RATE_U32S: usize = KECCAK_RATE_BYTES / 4;
pub(crate) const KECCAK_CAPACITY_BYTES: usize = 64;
pub(crate) const KECCAK_CAPACITY_U32S: usize = KECCAK_CAPACITY_BYTES / 4;
pub(crate) const KECCAK_DIGEST_BYTES: usize = 32;
pub(crate) const KECCAK_DIGEST_U32S: usize = KECCAK_DIGEST_BYTES / 4;

#[repr(C)]
#[derive(Eq, PartialEq, Debug)]
@@ -52,10 +55,14 @@ pub(crate) struct KeccakSpongeColumnsView<T: Copy> {
pub xored_rate_u32s: [T; KECCAK_RATE_U32S],

/// The entire state (rate + capacity) of the sponge, encoded as 32-bit chunks, after the
/// permutation is applied.
pub updated_state_u32s: [T; KECCAK_WIDTH_U32S],

pub updated_state_bytes: [T; KECCAK_DIGEST_BYTES],
/// permutation is applied, minus the first limbs, from which the digest is extracted.
/// Those missing limbs can be recomputed from their corresponding bytes stored in
/// `updated_digest_state_bytes`.
pub partial_updated_state_u32s: [T; KECCAK_WIDTH_MINUS_DIGEST_U32S],

/// The first part of the state of the sponge, seen as bytes, after the permutation is applied.
/// This also represents the output digest of the Keccak sponge during the squeezing phase.
pub updated_digest_state_bytes: [T; KECCAK_DIGEST_BYTES],
}

// `u8` is guaranteed to have a `size_of` of 1.
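
A back-of-the-envelope check of the sizes implied by the constants above (editorial, not part of the diff): the full state is 200 / 4 = 50 limbs and the digest is 32 / 4 = 8 limbs, so partial_updated_state_u32s holds (200 - 32) / 4 = 42 limbs. The old field pair used 50 + 32 = 82 columns per row; the new pair uses 42 + 32 = 74, i.e. the 8 redundant digest limbs are gone.

    // Hypothetical standalone check; the names mirror the constants in columns.rs.
    const KECCAK_WIDTH_U32S: usize = 200 / 4; // 50
    const KECCAK_DIGEST_U32S: usize = 32 / 4; // 8
    const KECCAK_WIDTH_MINUS_DIGEST_U32S: usize = (200 - 32) / 4; // 42

    // updated_state_u32s + updated_state_bytes vs. partial_updated_state_u32s + updated_digest_state_bytes.
    const OLD_STATE_COLS: usize = KECCAK_WIDTH_U32S + 32;
    const NEW_STATE_COLS: usize = KECCAK_WIDTH_MINUS_DIGEST_U32S + 32;

    fn main() {
        assert_eq!(OLD_STATE_COLS - NEW_STATE_COLS, KECCAK_DIGEST_U32S); // 8 columns removed per row
    }
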
160 changes: 103 additions & 57 deletions evm/src/keccak_sponge/keccak_sponge_stark.rs
@@ -28,7 +28,7 @@ pub(crate) fn ctl_looked_data<F: Field>() -> Vec<Column<F>> {
let mut outputs = Vec::with_capacity(8);
for i in (0..8).rev() {
let cur_col = Column::linear_combination(
cols.updated_state_bytes[i * 4..(i + 1) * 4]
cols.updated_digest_state_bytes[i * 4..(i + 1) * 4]
.iter()
.enumerate()
.map(|(j, &c)| (c, F::from_canonical_u64(1 << (24 - 8 * j)))),
@@ -49,15 +49,30 @@ pub(crate) fn ctl_looked_data<F: Field>() -> Vec<Column<F>> {

pub(crate) fn ctl_looking_keccak<F: Field>() -> Vec<Column<F>> {
let cols = KECCAK_SPONGE_COL_MAP;
Column::singles(
let mut res: Vec<_> = Column::singles(
[
cols.xored_rate_u32s.as_slice(),
&cols.original_capacity_u32s,
&cols.updated_state_u32s,
]
.concat(),
)
.collect()
.collect();

// We recover the 32-bit digest limbs from their corresponding bytes,
// and then append them to the rest of the updated state limbs.
let digest_u32s = cols.updated_digest_state_bytes.chunks_exact(4).map(|c| {
Column::linear_combination(
c.iter()
.enumerate()
.map(|(i, &b)| (b, F::from_canonical_usize(1 << (8 * i)))),
)
});

res.extend(digest_u32s);

res.extend(Column::singles(&cols.partial_updated_state_u32s));

res
}

pub(crate) fn ctl_looking_memory<F: Field>(i: usize) -> Vec<Column<F>> {
@@ -239,7 +254,21 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakSpongeStark<F, D> {
block.try_into().unwrap(),
);

sponge_state = row.updated_state_u32s.map(|f| f.to_canonical_u64() as u32);
sponge_state[..KECCAK_DIGEST_U32S]
.iter_mut()
.zip(row.updated_digest_state_bytes.chunks_exact(4))
.for_each(|(s, bs)| {
*s = bs
.iter()
.enumerate()
.map(|(i, b)| (b.to_canonical_u64() as u32) << (8 * i))
.sum();
});

sponge_state[KECCAK_DIGEST_U32S..]
.iter_mut()
.zip(row.partial_updated_state_u32s)
.for_each(|(s, x)| *s = x.to_canonical_u64() as u32);

rows.push(row.into());
already_absorbed_bytes += KECCAK_RATE_BYTES;
@@ -357,24 +386,33 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakSpongeStark<F, D> {
row.xored_rate_u32s = xored_rate_u32s.map(F::from_canonical_u32);

keccakf_u32s(&mut sponge_state);
row.updated_state_u32s = sponge_state.map(F::from_canonical_u32);
let is_final_block = row.is_final_input_len.iter().copied().sum::<F>() == F::ONE;
if is_final_block {
for (l, &elt) in row.updated_state_u32s[..8].iter().enumerate() {
// Store all but the first `KECCAK_DIGEST_U32S` limbs in the updated state.
// Those missing limbs will be broken down into bytes and stored separately.
row.partial_updated_state_u32s.copy_from_slice(
&sponge_state[KECCAK_DIGEST_U32S..]
.iter()
.copied()
.map(|i| F::from_canonical_u32(i))
.collect::<Vec<_>>(),
);
sponge_state[..KECCAK_DIGEST_U32S]
.iter()
.enumerate()
.for_each(|(l, &elt)| {
let mut cur_elt = elt;
(0..4).for_each(|i| {
row.updated_state_bytes[l * 4 + i] =
F::from_canonical_u32((cur_elt.to_canonical_u64() & 0xFF) as u32);
cur_elt = F::from_canonical_u64(cur_elt.to_canonical_u64() >> 8);
row.updated_digest_state_bytes[l * 4 + i] =
F::from_canonical_u32(cur_elt & 0xFF);
cur_elt >>= 8;
});

let mut s = row.updated_state_bytes[l * 4].to_canonical_u64();
// 32-bit limb reconstruction consistency check.
let mut s = row.updated_digest_state_bytes[l * 4].to_canonical_u64();
for i in 1..4 {
s += row.updated_state_bytes[l * 4 + i].to_canonical_u64() << (8 * i);
s += row.updated_digest_state_bytes[l * 4 + i].to_canonical_u64() << (8 * i);
}
assert_eq!(elt, F::from_canonical_u64(s), "not equal");
}
}
assert_eq!(elt as u64, s, "not equal");
})
}

fn generate_padding_row(&self) -> [F; NUM_KECCAK_SPONGE_COLUMNS] {
@@ -445,26 +483,39 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakSpongeS
);

// If this is a full-input block, the next row's "before" should match our "after" state.
for (current_bytes_after, next_before) in local_values
.updated_digest_state_bytes
.chunks_exact(4)
.zip(&next_values.original_rate_u32s[..KECCAK_DIGEST_U32S])
{
let mut current_after = current_bytes_after[0];
for i in 1..4 {
current_after +=
current_bytes_after[i] * P::from(FE::from_canonical_usize(1 << (8 * i)));
}
yield_constr
.constraint_transition(is_full_input_block * (*next_before - current_after));
}
for (&current_after, &next_before) in local_values
.updated_state_u32s
.partial_updated_state_u32s
.iter()
.zip(next_values.original_rate_u32s.iter())
.zip(next_values.original_rate_u32s[KECCAK_DIGEST_U32S..].iter())
{
yield_constr.constraint_transition(is_full_input_block * (next_before - current_after));
}
for (&current_after, &next_before) in local_values
.updated_state_u32s
.partial_updated_state_u32s
.iter()
.skip(KECCAK_RATE_U32S)
.skip(KECCAK_RATE_U32S - KECCAK_DIGEST_U32S)
.zip(next_values.original_capacity_u32s.iter())
{
yield_constr.constraint_transition(is_full_input_block * (next_before - current_after));
}

// If this is a full-input block, the next row's already_absorbed_bytes should be ours plus 136.
// If this is a full-input block, the next row's already_absorbed_bytes should be ours plus `KECCAK_RATE_BYTES`.
yield_constr.constraint_transition(
is_full_input_block
* (already_absorbed_bytes + P::from(FE::from_canonical_u64(136))
* (already_absorbed_bytes + P::from(FE::from_canonical_usize(KECCAK_RATE_BYTES))
- next_values.already_absorbed_bytes),
);

@@ -481,16 +532,6 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakSpongeS
let entry_match = offset - P::from(FE::from_canonical_usize(i));
yield_constr.constraint(is_final_len * entry_match);
}

// Adding constraints for byte columns.
for (l, &elt) in local_values.updated_state_u32s[..8].iter().enumerate() {
let mut s = local_values.updated_state_bytes[l * 4];
for i in 1..4 {
s += local_values.updated_state_bytes[l * 4 + i]
* P::from(FE::from_canonical_usize(1 << (8 * i)));
}
yield_constr.constraint(is_final_block * (s - elt));
}
}

fn eval_ext_circuit(
@@ -566,29 +607,48 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakSpongeS
yield_constr.constraint_transition(builder, constraint);

// If this is a full-input block, the next row's "before" should match our "after" state.
for (current_bytes_after, next_before) in local_values
.updated_digest_state_bytes
.chunks_exact(4)
.zip(&next_values.original_rate_u32s[..KECCAK_DIGEST_U32S])
{
let mut current_after = current_bytes_after[0];
for i in 1..4 {
current_after = builder.mul_const_add_extension(
F::from_canonical_usize(1 << (8 * i)),
current_bytes_after[i],
current_after,
);
}
let diff = builder.sub_extension(*next_before, current_after);
let constraint = builder.mul_extension(is_full_input_block, diff);
yield_constr.constraint_transition(builder, constraint);
}
for (&current_after, &next_before) in local_values
.updated_state_u32s
.partial_updated_state_u32s
.iter()
.zip(next_values.original_rate_u32s.iter())
.zip(next_values.original_rate_u32s[KECCAK_DIGEST_U32S..].iter())
{
let diff = builder.sub_extension(next_before, current_after);
let constraint = builder.mul_extension(is_full_input_block, diff);
yield_constr.constraint_transition(builder, constraint);
}
for (&current_after, &next_before) in local_values
.updated_state_u32s
.partial_updated_state_u32s
.iter()
.skip(KECCAK_RATE_U32S)
.skip(KECCAK_RATE_U32S - KECCAK_DIGEST_U32S)
.zip(next_values.original_capacity_u32s.iter())
{
let diff = builder.sub_extension(next_before, current_after);
let constraint = builder.mul_extension(is_full_input_block, diff);
yield_constr.constraint_transition(builder, constraint);
}

// If this is a full-input block, the next row's already_absorbed_bytes should be ours plus 136.
let absorbed_bytes =
builder.add_const_extension(already_absorbed_bytes, F::from_canonical_u64(136));
// If this is a full-input block, the next row's already_absorbed_bytes should be ours plus `KECCAK_RATE_BYTES`.
let absorbed_bytes = builder.add_const_extension(
already_absorbed_bytes,
F::from_canonical_usize(KECCAK_RATE_BYTES),
);
let absorbed_diff =
builder.sub_extension(absorbed_bytes, next_values.already_absorbed_bytes);
let constraint = builder.mul_extension(is_full_input_block, absorbed_diff);
@@ -615,21 +675,6 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakSpongeS
let constraint = builder.mul_extension(is_final_len, entry_match);
yield_constr.constraint(builder, constraint);
}

// Adding constraints for byte columns.
for (l, &elt) in local_values.updated_state_u32s[..8].iter().enumerate() {
let mut s = local_values.updated_state_bytes[l * 4];
for i in 1..4 {
s = builder.mul_const_add_extension(
F::from_canonical_usize(1 << (8 * i)),
local_values.updated_state_bytes[l * 4 + i],
s,
);
}
let constraint = builder.sub_extension(s, elt);
let constraint = builder.mul_extension(is_final_block, constraint);
yield_constr.constraint(builder, constraint);
}
}

fn constraint_degree(&self) -> usize {
@@ -698,9 +743,10 @@ mod tests {
let rows = stark.generate_rows_for_op(op);
assert_eq!(rows.len(), 1);
let last_row: &KeccakSpongeColumnsView<F> = rows.last().unwrap().borrow();
let output = last_row.updated_state_u32s[..8]
let output = last_row
.updated_digest_state_bytes
.iter()
.flat_map(|x| (x.to_canonical_u64() as u32).to_le_bytes())
.map(|x| x.to_canonical_u64() as u8)
.collect_vec();

assert_eq!(output, expected_output.0);
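
To recap the constraint side of the change: ctl_looking_keccak and the transition constraints both rebuild each digest limb from its four byte columns with the same little-endian weights, and the transition constraints then compare the rebuilt limb against the first rate limbs of the next row whenever is_full_input_block is set, so the eight removed limb columns are never needed. A hedged sketch of the relation being enforced, over plain integers rather than packed field or extension elements (function and argument names here are illustrative, not from the code):

    // For each digest limb l (when the full-input-block filter is on), the relation is:
    //   next_original_rate[l] == sum over i in 0..4 of digest_bytes[4*l + i] * 2^(8*i)
    // In the STARK this is expressed as filter * (next_before - reconstructed) == 0.
    fn digest_limb_constraints_hold(
        is_full_input_block: bool,
        digest_bytes: &[u64; 32],
        next_original_rate: &[u64; 8],
    ) -> bool {
        (0..8).all(|l| {
            let reconstructed: u64 = (0..4).map(|i| digest_bytes[4 * l + i] << (8 * i)).sum();
            !is_full_input_block || next_original_rate[l] == reconstructed
        })
    }

    fn main() {
        // Toy example: first limb 0x01020304, stored as little-endian bytes 04 03 02 01.
        let mut digest_bytes = [0u64; 32];
        digest_bytes[..4].copy_from_slice(&[0x04, 0x03, 0x02, 0x01]);
        let mut next_rate = [0u64; 8];
        next_rate[0] = 0x0102_0304;
        assert!(digest_limb_constraints_hold(true, &digest_bytes, &next_rate));
    }
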
