From f2c1802f4479440a17e20a9cedb05a0d17195c9f Mon Sep 17 00:00:00 2001 From: Alon Haramati Date: Wed, 17 Jul 2024 13:16:31 +0300 Subject: [PATCH] Test memory trace. --- .../src/components/memory/component.rs | 43 ++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/stwo_cairo_prover/src/components/memory/component.rs b/stwo_cairo_prover/src/components/memory/component.rs index 7a5a7572..bcbe0462 100644 --- a/stwo_cairo_prover/src/components/memory/component.rs +++ b/stwo_cairo_prover/src/components/memory/component.rs @@ -58,7 +58,12 @@ impl MemoryComponent { impl MemoryTraceGenerator { pub fn new(_path: String) -> Self { // TODO(AlonH): change to read from file. - let values = vec![[BaseField::zero(); N_M31_IN_FELT252]; MEMORY_ADDRESS_BOUND]; + let values = (0..MEMORY_ADDRESS_BOUND) + .map(|i| { + let value = BaseField::from_u32_unchecked(i as u32); + [value; N_M31_IN_FELT252] + }) + .collect(); let multiplicities = vec![0; MEMORY_ADDRESS_BOUND]; Self { values, @@ -228,3 +233,39 @@ impl Component for MemoryComponent { evaluation_accumulator.accumulate(numerator / denom); } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::components::memory::tests::register_test_memory; + + #[test] + fn test_memory_trace() { + let mut registry = ComponentGenerationRegistry::default(); + register_test_memory(&mut registry); + let trace = MemoryTraceGenerator::write_trace(MEMORY_COMPONENT_ID, &mut registry); + let alpha = SecureField::from_u32_unchecked(1, 2, 3, 117); + let z = SecureField::from_u32_unchecked(2, 3, 4, 118); + let interaction_elements = InteractionElements::new( + [(MEMORY_ALPHA.to_string(), alpha), (MEMORY_Z.to_string(), z)].into(), + ); + let interaction_trace = registry + .get_generator::(MEMORY_COMPONENT_ID) + .write_interaction_trace(&trace.iter().collect(), &interaction_elements); + + let mut trace_sum = SecureField::zero(); + for i in 0..MEMORY_ADDRESS_BOUND { + assert_eq!(trace[0].values[i], BaseField::from_u32_unchecked(i as u32)); + trace_sum += trace.last().unwrap().values[i] + / shifted_secure_combination( + &[BaseField::from_u32_unchecked(i as u32); N_M31_IN_FELT252 + 1], + alpha, + z, + ); + } + let logup_sum = + SecureField::from_m31_array(std::array::from_fn(|j| interaction_trace[j][1])); + + assert_eq!(logup_sum, trace_sum); + } +}