From 941457d01b8eaf6c5c4c248a3350cd29df2c7d9e Mon Sep 17 00:00:00 2001 From: Alon Haramati Date: Wed, 17 Jul 2024 13:16:31 +0300 Subject: [PATCH] Test memory trace. --- .../src/components/memory/component.rs | 43 ++++++++++++++++++- .../components/range_check_unit/component.rs | 6 +-- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/stwo_cairo_prover/src/components/memory/component.rs b/stwo_cairo_prover/src/components/memory/component.rs index 7a5a7572..a1dd83ab 100644 --- a/stwo_cairo_prover/src/components/memory/component.rs +++ b/stwo_cairo_prover/src/components/memory/component.rs @@ -58,7 +58,12 @@ impl MemoryComponent { impl MemoryTraceGenerator { pub fn new(_path: String) -> Self { // TODO(AlonH): change to read from file. - let values = vec![[BaseField::zero(); N_M31_IN_FELT252]; MEMORY_ADDRESS_BOUND]; + let values = (0..MEMORY_ADDRESS_BOUND) + .map(|i| { + let value = BaseField::from_u32_unchecked(i as u32); + [value; N_M31_IN_FELT252] + }) + .collect(); let multiplicities = vec![0; MEMORY_ADDRESS_BOUND]; Self { values, @@ -228,3 +233,39 @@ impl Component for MemoryComponent { evaluation_accumulator.accumulate(numerator / denom); } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::components::memory::tests::register_test_memory; + + #[test] + fn test_memory_trace() { + let mut registry = ComponentGenerationRegistry::default(); + register_test_memory(&mut registry); + let trace = MemoryTraceGenerator::write_trace(MEMORY_COMPONENT_ID, &mut registry); + let alpha = SecureField::from_u32_unchecked(1, 2, 3, 117); + let z = SecureField::from_u32_unchecked(2, 3, 4, 118); + let interaction_elements = InteractionElements::new( + [(MEMORY_ALPHA.to_string(), alpha), (MEMORY_Z.to_string(), z)].into(), + ); + let interaction_trace = registry + .get_generator::(MEMORY_COMPONENT_ID) + .write_interaction_trace(&trace.iter().collect(), &interaction_elements); + + let mut expected_logup_sum = SecureField::zero(); + for i in 0..MEMORY_ADDRESS_BOUND { + assert_eq!(trace[0].values[i], BaseField::from_u32_unchecked(i as u32)); + expected_logup_sum += trace.last().unwrap().values[i] + / shifted_secure_combination( + &[BaseField::from_u32_unchecked(i as u32); N_M31_IN_FELT252 + 1], + alpha, + z, + ); + } + let logup_sum = + SecureField::from_m31_array(std::array::from_fn(|j| interaction_trace[j][1])); + + assert_eq!(logup_sum, expected_logup_sum); + } +} diff --git a/stwo_cairo_prover/src/components/range_check_unit/component.rs b/stwo_cairo_prover/src/components/range_check_unit/component.rs index 1e4d8f9a..b98b7a12 100644 --- a/stwo_cairo_prover/src/components/range_check_unit/component.rs +++ b/stwo_cairo_prover/src/components/range_check_unit/component.rs @@ -202,15 +202,15 @@ mod tests { .get_generator::(RC_COMPONENT_ID) .write_interaction_trace(&trace.iter().collect(), &interaction_elements); - let mut trace_sum = SecureField::zero(); + let mut expected_logup_sum = SecureField::zero(); for i in 0..8 { assert_eq!(trace[0].values[i], BaseField::from_u32_unchecked(i as u32)); - trace_sum += trace.last().unwrap().values[i] + expected_logup_sum += trace.last().unwrap().values[i] / (random_value - BaseField::from_u32_unchecked(i as u32)); } let logup_sum = SecureField::from_m31_array(std::array::from_fn(|j| interaction_trace[j][1])); - assert_eq!(logup_sum, trace_sum); + assert_eq!(logup_sum, expected_logup_sum); } }