diff --git a/evm/src/keccak_sponge/columns.rs b/evm/src/keccak_sponge/columns.rs index 2222bc2f70..a2a555d365 100644 --- a/evm/src/keccak_sponge/columns.rs +++ b/evm/src/keccak_sponge/columns.rs @@ -5,10 +5,13 @@ use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; pub(crate) const KECCAK_WIDTH_BYTES: usize = 200; pub(crate) const KECCAK_WIDTH_U32S: usize = KECCAK_WIDTH_BYTES / 4; +pub(crate) const KECCAK_WIDTH_MINUS_DIGEST_U32S: usize = + (KECCAK_WIDTH_BYTES - KECCAK_DIGEST_BYTES) / 4; pub(crate) const KECCAK_RATE_BYTES: usize = 136; pub(crate) const KECCAK_RATE_U32S: usize = KECCAK_RATE_BYTES / 4; pub(crate) const KECCAK_CAPACITY_BYTES: usize = 64; pub(crate) const KECCAK_CAPACITY_U32S: usize = KECCAK_CAPACITY_BYTES / 4; +pub(crate) const KECCAK_DIGEST_U32S: usize = 8; pub(crate) const KECCAK_DIGEST_BYTES: usize = 32; #[repr(C)] @@ -53,7 +56,9 @@ pub(crate) struct KeccakSpongeColumnsView { /// The entire state (rate + capacity) of the sponge, encoded as 32-bit chunks, after the /// permutation is applied, minus the first limbs where the digest is extracted from. - pub partial_updated_state_u32s: [T; KECCAK_WIDTH_U32S], + /// Those missing limbs can be recomputed from their corresponding bytes stored in + /// `updated_digest_state_bytes`. + pub partial_updated_state_u32s: [T; KECCAK_WIDTH_MINUS_DIGEST_U32S], /// The first part of the state of the sponge, seen as bytes, after the permutation is applied. /// This also represents the output digest of the Keccak sponge during the squeezing phase. 
diff --git a/evm/src/keccak_sponge/keccak_sponge_stark.rs b/evm/src/keccak_sponge/keccak_sponge_stark.rs index 48e825e7e3..1b1126c04e 100644 --- a/evm/src/keccak_sponge/keccak_sponge_stark.rs +++ b/evm/src/keccak_sponge/keccak_sponge_stark.rs @@ -49,15 +49,34 @@ pub(crate) fn ctl_looked_data() -> Vec> { pub(crate) fn ctl_looking_keccak() -> Vec> { let cols = KECCAK_SPONGE_COL_MAP; - Column::singles( + let mut res: Vec<_> = Column::singles( [ cols.xored_rate_u32s.as_slice(), &cols.original_capacity_u32s, - &cols.partial_updated_state_u32s, ] .concat(), ) - .collect() + .collect(); + + // We recover the 32-bit digest limbs from their corresponding bytes, + // and then append them to the rest of the updated state limbs. + let digest_u32s = cols + .updated_digest_state_bytes + .chunks_exact(4) + .map(|c| { + Column::linear_combination( + c.iter() + .enumerate() + .map(|(i, &b)| (b, F::from_canonical_usize(1 << (8 * i)))), + ) + }) + .collect::>(); + + res.extend(digest_u32s); + + res.extend(Column::singles(&cols.partial_updated_state_u32s).collect::>()); + + res } pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { @@ -239,9 +258,21 @@ impl, const D: usize> KeccakSpongeStark { block.try_into().unwrap(), ); - sponge_state = row - .partial_updated_state_u32s - .map(|f| f.to_canonical_u64() as u32); + sponge_state[..KECCAK_DIGEST_U32S] + .iter_mut() + .zip(row.updated_digest_state_bytes.chunks_exact(4)) + .for_each(|(s, bs)| { + *s = bs + .iter() + .enumerate() + .map(|(i, b)| (b.to_canonical_u64() as u32) << (8 * i)) + .sum(); + }); + + sponge_state[KECCAK_DIGEST_U32S..] 
+ .iter_mut() + .zip(row.partial_updated_state_u32s) + .for_each(|(s, x)| *s = x.to_canonical_u64() as u32); rows.push(row.into()); already_absorbed_bytes += KECCAK_RATE_BYTES; @@ -359,23 +390,28 @@ impl, const D: usize> KeccakSpongeStark { row.xored_rate_u32s = xored_rate_u32s.map(F::from_canonical_u32); keccakf_u32s(&mut sponge_state); - row.partial_updated_state_u32s = sponge_state.map(F::from_canonical_u32); - let is_final_block = row.is_final_input_len.iter().copied().sum::() == F::ONE; - if is_final_block { - for (l, &elt) in row.partial_updated_state_u32s[..8].iter().enumerate() { - let mut cur_elt = elt; - (0..4).for_each(|i| { - row.updated_digest_state_bytes[l * 4 + i] = - F::from_canonical_u32((cur_elt.to_canonical_u64() & 0xFF) as u32); - cur_elt = F::from_canonical_u64(cur_elt.to_canonical_u64() >> 8); - }); - - let mut s = row.updated_digest_state_bytes[l * 4].to_canonical_u64(); - for i in 1..4 { - s += row.updated_digest_state_bytes[l * 4 + i].to_canonical_u64() << (8 * i); - } - assert_eq!(elt, F::from_canonical_u64(s), "not equal"); + // Store all but the first `KECCAK_DIGEST_U32S` limbs in the updated state. + // Those missing limbs will be broken down into bytes and stored separately. + row.partial_updated_state_u32s.copy_from_slice( + &sponge_state[KECCAK_DIGEST_U32S..] + .iter() + .copied() + .map(|i| F::from_canonical_u32(i)) + .collect::>(), + ); + for (l, &elt) in sponge_state[..KECCAK_DIGEST_U32S].iter().enumerate() { + let mut cur_elt = elt; + (0..4).for_each(|i| { + row.updated_digest_state_bytes[l * 4 + i] = F::from_canonical_u32(cur_elt & 0xFF); + cur_elt >>= 8; + }); + + // 32-bit limb reconstruction consistency check. 
+ let mut s = row.updated_digest_state_bytes[l * 4].to_canonical_u64(); + for i in 1..4 { + s += row.updated_digest_state_bytes[l * 4 + i].to_canonical_u64() << (8 * i); } + assert_eq!(elt as u64, s, "not equal"); } } @@ -447,26 +483,39 @@ impl, const D: usize> Stark for KeccakSpongeS ); // If this is a full-input block, the next row's "before" should match our "after" state. + for (current_bytes_after, next_before) in local_values + .updated_digest_state_bytes + .chunks_exact(4) + .zip(&next_values.original_rate_u32s[..KECCAK_DIGEST_U32S]) + { + let mut current_after = current_bytes_after[0]; + for i in 1..4 { + current_after += + current_bytes_after[i] * P::from(FE::from_canonical_usize(1 << (8 * i))); + } + yield_constr + .constraint_transition(is_full_input_block * (*next_before - current_after)); + } for (&current_after, &next_before) in local_values .partial_updated_state_u32s .iter() - .zip(next_values.original_rate_u32s.iter()) + .zip(next_values.original_rate_u32s[KECCAK_DIGEST_U32S..].iter()) { yield_constr.constraint_transition(is_full_input_block * (next_before - current_after)); } for (&current_after, &next_before) in local_values .partial_updated_state_u32s .iter() - .skip(KECCAK_RATE_U32S) + .skip(KECCAK_RATE_U32S - KECCAK_DIGEST_U32S) .zip(next_values.original_capacity_u32s.iter()) { yield_constr.constraint_transition(is_full_input_block * (next_before - current_after)); } - // If this is a full-input block, the next row's already_absorbed_bytes should be ours plus 136. + // If this is a full-input block, the next row's already_absorbed_bytes should be ours plus `KECCAK_RATE_BYTES`. 
yield_constr.constraint_transition( is_full_input_block - * (already_absorbed_bytes + P::from(FE::from_canonical_u64(136)) + * (already_absorbed_bytes + P::from(FE::from_canonical_usize(KECCAK_RATE_BYTES)) - next_values.already_absorbed_bytes), ); @@ -483,19 +532,6 @@ impl, const D: usize> Stark for KeccakSpongeS let entry_match = offset - P::from(FE::from_canonical_usize(i)); yield_constr.constraint(is_final_len * entry_match); } - - // Adding constraints for byte columns. - for (l, &elt) in local_values.partial_updated_state_u32s[..8] - .iter() - .enumerate() - { - let mut s = local_values.updated_digest_state_bytes[l * 4]; - for i in 1..4 { - s += local_values.updated_digest_state_bytes[l * 4 + i] - * P::from(FE::from_canonical_usize(1 << (8 * i))); - } - yield_constr.constraint(is_final_block * (s - elt)); - } } fn eval_ext_circuit( @@ -571,10 +607,27 @@ impl, const D: usize> Stark for KeccakSpongeS yield_constr.constraint_transition(builder, constraint); // If this is a full-input block, the next row's "before" should match our "after" state. 
+ for (current_bytes_after, next_before) in local_values + .updated_digest_state_bytes + .chunks_exact(4) + .zip(&next_values.original_rate_u32s[..KECCAK_DIGEST_U32S]) + { + let mut current_after = current_bytes_after[0]; + for i in 1..4 { + current_after = builder.mul_const_add_extension( + F::from_canonical_usize(1 << (8 * i)), + current_bytes_after[i], + current_after, + ); + } + let diff = builder.sub_extension(*next_before, current_after); + let constraint = builder.mul_extension(is_full_input_block, diff); + yield_constr.constraint_transition(builder, constraint); + } for (&current_after, &next_before) in local_values .partial_updated_state_u32s .iter() - .zip(next_values.original_rate_u32s.iter()) + .zip(next_values.original_rate_u32s[KECCAK_DIGEST_U32S..].iter()) { let diff = builder.sub_extension(next_before, current_after); let constraint = builder.mul_extension(is_full_input_block, diff); @@ -583,7 +636,7 @@ impl, const D: usize> Stark for KeccakSpongeS for (&current_after, &next_before) in local_values .partial_updated_state_u32s .iter() - .skip(KECCAK_RATE_U32S) + .skip(KECCAK_RATE_U32S - KECCAK_DIGEST_U32S) .zip(next_values.original_capacity_u32s.iter()) { let diff = builder.sub_extension(next_before, current_after); @@ -591,9 +644,11 @@ impl, const D: usize> Stark for KeccakSpongeS yield_constr.constraint_transition(builder, constraint); } - // If this is a full-input block, the next row's already_absorbed_bytes should be ours plus 136. - let absorbed_bytes = - builder.add_const_extension(already_absorbed_bytes, F::from_canonical_u64(136)); + // If this is a full-input block, the next row's already_absorbed_bytes should be ours plus `KECCAK_RATE_BYTES`. 
+ let absorbed_bytes = builder.add_const_extension( + already_absorbed_bytes, + F::from_canonical_usize(KECCAK_RATE_BYTES), + ); let absorbed_diff = builder.sub_extension(absorbed_bytes, next_values.already_absorbed_bytes); let constraint = builder.mul_extension(is_full_input_block, absorbed_diff); @@ -620,24 +675,6 @@ impl, const D: usize> Stark for KeccakSpongeS let constraint = builder.mul_extension(is_final_len, entry_match); yield_constr.constraint(builder, constraint); } - - // Adding constraints for byte columns. - for (l, &elt) in local_values.partial_updated_state_u32s[..8] - .iter() - .enumerate() - { - let mut s = local_values.updated_digest_state_bytes[l * 4]; - for i in 1..4 { - s = builder.mul_const_add_extension( - F::from_canonical_usize(1 << (8 * i)), - local_values.updated_digest_state_bytes[l * 4 + i], - s, - ); - } - let constraint = builder.sub_extension(s, elt); - let constraint = builder.mul_extension(is_final_block, constraint); - yield_constr.constraint(builder, constraint); - } } fn constraint_degree(&self) -> usize { @@ -706,9 +743,10 @@ mod tests { let rows = stark.generate_rows_for_op(op); assert_eq!(rows.len(), 1); let last_row: &KeccakSpongeColumnsView = rows.last().unwrap().borrow(); - let output = last_row.partial_updated_state_u32s[..8] + let output = last_row + .updated_digest_state_bytes .iter() - .flat_map(|x| (x.to_canonical_u64() as u32).to_le_bytes()) + .map(|x| x.to_canonical_u64() as u8) .collect_vec(); assert_eq!(output, expected_output.0);