From b4d9e29cee240240f58fd4bff028ac95849de5dc Mon Sep 17 00:00:00 2001 From: Alon Haramati Date: Tue, 16 Jul 2024 14:56:15 +0300 Subject: [PATCH] Implement memory component prover. --- .../src/components/memory/component.rs | 76 +++++++-- .../src/components/memory/component_prover.rs | 137 +++++++++++++++ .../src/components/memory/mod.rs | 161 ++++++++++++++++++ 3 files changed, 358 insertions(+), 16 deletions(-) create mode 100644 stwo_cairo_prover/src/components/memory/component_prover.rs diff --git a/stwo_cairo_prover/src/components/memory/component.rs b/stwo_cairo_prover/src/components/memory/component.rs index 9230c968..2c04ddc7 100644 --- a/stwo_cairo_prover/src/components/memory/component.rs +++ b/stwo_cairo_prover/src/components/memory/component.rs @@ -5,7 +5,8 @@ use stwo_prover::core::air::mask::fixed_mask_points; use stwo_prover::core::air::Component; use stwo_prover::core::backend::CpuBackend; use stwo_prover::core::circle::CirclePoint; -use stwo_prover::core::fields::m31::{BaseField, M31}; +use stwo_prover::core::constraints::{coset_vanishing, point_excluder, point_vanishing}; +use stwo_prover::core::fields::m31::BaseField; use stwo_prover::core::fields::qm31::SecureField; use stwo_prover::core::fields::secure_column::{SecureColumn, SECURE_EXTENSION_DEGREE}; use stwo_prover::core::fields::FieldExpOps; @@ -17,24 +18,33 @@ use stwo_prover::core::utils::{ }; use stwo_prover::core::{ColumnVec, InteractionElements, LookupValues}; use stwo_prover::trace_generation::registry::ComponentGenerationRegistry; -use stwo_prover::trace_generation::{ComponentGen, ComponentTraceGenerator}; +use stwo_prover::trace_generation::{ + ComponentGen, ComponentTraceGenerator, BASE_TRACE, INTERACTION_TRACE, +}; pub const MEMORY_ALPHA: &str = "MEMORY_ALPHA"; pub const MEMORY_Z: &str = "MEMORY_Z"; - -const N_M31_IN_FELT252: usize = 21; -const MULTIPLICITY_COLUMN: usize = 22; -const LOG_MEMORY_ADDRESS_BOUND: u32 = 20; -const MEMORY_ADDRESS_BOUND: usize = 1 << LOG_MEMORY_ADDRESS_BOUND; +pub const MEMORY_COMPONENT_ID: &str = "MEMORY"; +pub const MEMORY_LOOKUP_VALUE_0: &str = "MEMORY_LOOKUP_0"; +pub const MEMORY_LOOKUP_VALUE_1: &str = "MEMORY_LOOKUP_1"; +pub const MEMORY_LOOKUP_VALUE_2: &str = "MEMORY_LOOKUP_2"; +pub const MEMORY_LOOKUP_VALUE_3: &str = "MEMORY_LOOKUP_3"; + +pub const N_M31_IN_FELT252: usize = 21; +pub const MULTIPLICITY_COLUMN: usize = 22; +// TODO(AlonH): Make memory size configurable. +pub const LOG_MEMORY_ADDRESS_BOUND: u32 = 3; +pub const MEMORY_ADDRESS_BOUND: usize = 1 << LOG_MEMORY_ADDRESS_BOUND; /// Addresses are continuous and start from 0. /// Values are Felt252 stored as `N_M31_IN_FELT252` M31 values (each value contain 12 bits). pub struct MemoryTraceGenerator { // TODO(AlonH): Consider to change values to be Felt252. - pub values: Vec<[M31; N_M31_IN_FELT252]>, + pub values: Vec<[BaseField; N_M31_IN_FELT252]>, pub multiplicities: Vec, } +#[derive(Clone)] pub struct MemoryComponent { pub log_n_rows: u32, } @@ -48,7 +58,7 @@ impl MemoryComponent { impl MemoryTraceGenerator { pub fn new(_path: String) -> Self { // TODO(AlonH): change to read from file. - let values = vec![[M31::zero(); N_M31_IN_FELT252]; MEMORY_ADDRESS_BOUND]; + let values = vec![[BaseField::zero(); N_M31_IN_FELT252]; MEMORY_ADDRESS_BOUND]; let multiplicities = vec![0; MEMORY_ADDRESS_BOUND]; Self { values, @@ -61,7 +71,7 @@ impl ComponentGen for MemoryTraceGenerator {} impl ComponentTraceGenerator for MemoryTraceGenerator { type Component = MemoryComponent; - type Inputs = M31; + type Inputs = BaseField; fn add_inputs(&mut self, inputs: &Self::Inputs) { let input = inputs.0 as usize; @@ -172,12 +182,46 @@ impl Component for MemoryComponent { fn evaluate_constraint_quotients_at_point( &self, - _point: CirclePoint, - _mask: &TreeVec>>, - _evaluation_accumulator: &mut PointEvaluationAccumulator, - _interaction_elements: &InteractionElements, - _lookup_values: &LookupValues, + point: CirclePoint, + mask: &TreeVec>>, + evaluation_accumulator: &mut PointEvaluationAccumulator, + interaction_elements: &InteractionElements, + lookup_values: &LookupValues, ) { - todo!() + // First lookup point boundary constraint. + let constraint_zero_domain = CanonicCoset::new(self.log_n_rows).coset; + let (alpha, z) = ( + interaction_elements[MEMORY_ALPHA], + interaction_elements[MEMORY_Z], + ); + let value = + SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0])); + let address_and_value: [SecureField; N_M31_IN_FELT252 + 1] = + std::array::from_fn(|i| mask[BASE_TRACE][i][0]); + let numerator = value * shifted_secure_combination(&address_and_value, alpha, z) + - mask[BASE_TRACE][MULTIPLICITY_COLUMN][0]; + let denom = point_vanishing(constraint_zero_domain.at(0), point); + evaluation_accumulator.accumulate(numerator / denom); + + // Last lookup point boundary constraint. + let lookup_value = SecureField::from_m31( + lookup_values[MEMORY_LOOKUP_VALUE_0], + lookup_values[MEMORY_LOOKUP_VALUE_1], + lookup_values[MEMORY_LOOKUP_VALUE_2], + lookup_values[MEMORY_LOOKUP_VALUE_3], + ); + let numerator = value - lookup_value; + let denom = point_vanishing(constraint_zero_domain.at(1), point); + evaluation_accumulator.accumulate(numerator / denom); + + // Lookup step constraint. + let prev_value = + SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][1])); + let numerator = (value - prev_value) + * shifted_secure_combination(&address_and_value, alpha, z) + - mask[BASE_TRACE][22][0]; + let denom = coset_vanishing(constraint_zero_domain, point) + / point_excluder(constraint_zero_domain.at(0), point); + evaluation_accumulator.accumulate(numerator / denom); } } diff --git a/stwo_cairo_prover/src/components/memory/component_prover.rs b/stwo_cairo_prover/src/components/memory/component_prover.rs new file mode 100644 index 00000000..814bf6fd --- /dev/null +++ b/stwo_cairo_prover/src/components/memory/component_prover.rs @@ -0,0 +1,137 @@ +use std::collections::BTreeMap; + +use itertools::izip; +use num_traits::Zero; +use stwo_prover::core::air::accumulation::DomainEvaluationAccumulator; +use stwo_prover::core::air::{Component, ComponentProver, ComponentTrace}; +use stwo_prover::core::backend::CpuBackend; +use stwo_prover::core::constraints::{coset_vanishing, point_excluder}; +use stwo_prover::core::fields::m31::BaseField; +use stwo_prover::core::fields::qm31::SecureField; +use stwo_prover::core::fields::FieldExpOps; +use stwo_prover::core::poly::circle::CanonicCoset; +use stwo_prover::core::utils::{ + bit_reverse, point_vanish_denominator_inverses, previous_bit_reversed_circle_domain_index, + shifted_secure_combination, +}; +use stwo_prover::core::{InteractionElements, LookupValues}; +use stwo_prover::trace_generation::{BASE_TRACE, INTERACTION_TRACE}; + +use super::component::{ + MemoryComponent, MEMORY_ALPHA, MEMORY_LOOKUP_VALUE_0, MEMORY_LOOKUP_VALUE_1, + MEMORY_LOOKUP_VALUE_2, MEMORY_LOOKUP_VALUE_3, MEMORY_Z, MULTIPLICITY_COLUMN, N_M31_IN_FELT252, +}; + +impl ComponentProver for MemoryComponent { + fn evaluate_constraint_quotients_on_domain( + &self, + trace: &ComponentTrace<'_, CpuBackend>, + evaluation_accumulator: &mut DomainEvaluationAccumulator, + interaction_elements: &InteractionElements, + lookup_values: &LookupValues, + ) { + let max_constraint_degree = self.max_constraint_log_degree_bound(); + let trace_eval_domain = CanonicCoset::new(max_constraint_degree).circle_domain(); + let trace_evals = &trace.evals; + let zero_domain = CanonicCoset::new(self.log_n_rows).coset; + let [mut accum] = + evaluation_accumulator.columns([(max_constraint_degree, self.n_constraints())]); + + // TODO(AlonH): Get all denominators in one loop and don't perform unnecessary inversions. + let first_point_denom_inverses = + point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0)); + let last_point_denom_inverses = + point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(1)); + let mut step_denoms = vec![]; + for point in trace_eval_domain.iter() { + step_denoms.push( + coset_vanishing(zero_domain, point) / point_excluder(zero_domain.at(0), point), + ); + } + bit_reverse(&mut step_denoms); + let mut step_denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)]; + BaseField::batch_inverse(&step_denoms, &mut step_denom_inverses); + let (alpha, z) = ( + interaction_elements[MEMORY_ALPHA], + interaction_elements[MEMORY_Z], + ); + let lookup_value = SecureField::from_m31( + lookup_values[MEMORY_LOOKUP_VALUE_0], + lookup_values[MEMORY_LOOKUP_VALUE_1], + lookup_values[MEMORY_LOOKUP_VALUE_2], + lookup_values[MEMORY_LOOKUP_VALUE_3], + ); + for (i, (first_point_denom_inverse, last_point_denom_inverse, step_denom_inverse)) in izip!( + first_point_denom_inverses, + last_point_denom_inverses, + step_denom_inverses, + ) + .enumerate() + { + let value = SecureField::from_m31_array(std::array::from_fn(|j| { + trace_evals[INTERACTION_TRACE][j][i] + })); + let prev_index = previous_bit_reversed_circle_domain_index( + i, + zero_domain.log_size, + trace_eval_domain.log_size(), + ); + let prev_value = SecureField::from_m31_array(std::array::from_fn(|j| { + trace_evals[INTERACTION_TRACE][j][prev_index] + })); + let address_and_value: [BaseField; N_M31_IN_FELT252 + 1] = + std::array::from_fn(|j| trace_evals[BASE_TRACE][j][i]); + + let first_point_numerator = accum.random_coeff_powers[2] + * (value * shifted_secure_combination(&address_and_value, alpha, z) + - trace_evals[BASE_TRACE][MULTIPLICITY_COLUMN][i]); + + let last_point_numerator = accum.random_coeff_powers[1] * (value - lookup_value); + let step_numerator = accum.random_coeff_powers[0] + * ((value - prev_value) * shifted_secure_combination(&address_and_value, alpha, z) + - trace_evals[BASE_TRACE][MULTIPLICITY_COLUMN][i]); + accum.accumulate( + i, + first_point_numerator * first_point_denom_inverse + + last_point_numerator * last_point_denom_inverse + + step_numerator * step_denom_inverse, + ); + } + } + + fn lookup_values(&self, trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues { + let domain = CanonicCoset::new(self.log_n_rows); + let trace_poly = &trace.polys[INTERACTION_TRACE]; + let values = BTreeMap::from_iter([ + ( + MEMORY_LOOKUP_VALUE_0.to_string(), + trace_poly[0] + .eval_at_point(domain.at(1).into_ef()) + .try_into() + .unwrap(), + ), + ( + MEMORY_LOOKUP_VALUE_1.to_string(), + trace_poly[1] + .eval_at_point(domain.at(1).into_ef()) + .try_into() + .unwrap(), + ), + ( + MEMORY_LOOKUP_VALUE_2.to_string(), + trace_poly[2] + .eval_at_point(domain.at(1).into_ef()) + .try_into() + .unwrap(), + ), + ( + MEMORY_LOOKUP_VALUE_3.to_string(), + trace_poly[3] + .eval_at_point(domain.at(1).into_ef()) + .try_into() + .unwrap(), + ), + ]); + LookupValues::new(values) + } +} diff --git a/stwo_cairo_prover/src/components/memory/mod.rs b/stwo_cairo_prover/src/components/memory/mod.rs index 9cea807e..7c60567a 100644 --- a/stwo_cairo_prover/src/components/memory/mod.rs +++ b/stwo_cairo_prover/src/components/memory/mod.rs @@ -1 +1,162 @@ pub mod component; +pub mod component_prover; + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + + use component::{ + MemoryComponent, MemoryTraceGenerator, MEMORY_ALPHA, MEMORY_COMPONENT_ID, MEMORY_Z, + }; + use stwo_prover::core::air::{Air, AirProver, Component, ComponentProver}; + use stwo_prover::core::backend::CpuBackend; + use stwo_prover::core::channel::{Blake2sChannel, Channel}; + use stwo_prover::core::fields::m31::BaseField; + use stwo_prover::core::fields::IntoSlice; + use stwo_prover::core::poly::circle::CircleEvaluation; + use stwo_prover::core::poly::BitReversedOrder; + use stwo_prover::core::prover::VerificationError; + use stwo_prover::core::vcs::blake2_hash::Blake2sHasher; + use stwo_prover::core::vcs::hasher::Hasher; + use stwo_prover::core::{ColumnVec, InteractionElements, LookupValues}; + use stwo_prover::trace_generation::registry::ComponentGenerationRegistry; + use stwo_prover::trace_generation::{ + commit_and_prove, commit_and_verify, AirTraceGenerator, AirTraceVerifier, + ComponentTraceGenerator, + }; + + use super::*; + + pub fn register_test_memory(registry: &mut ComponentGenerationRegistry) { + registry.register( + MEMORY_COMPONENT_ID, + MemoryTraceGenerator::new("".to_string()), + ); + vec![ + vec![BaseField::from_u32_unchecked(0); 3], + vec![BaseField::from_u32_unchecked(1); 1], + vec![BaseField::from_u32_unchecked(2); 2], + vec![BaseField::from_u32_unchecked(3); 5], + vec![BaseField::from_u32_unchecked(4); 10], + vec![BaseField::from_u32_unchecked(5); 1], + vec![BaseField::from_u32_unchecked(6); 0], + vec![BaseField::from_u32_unchecked(7); 1], + ] + .into_iter() + .flatten() + .for_each(|input| { + registry + .get_generator_mut::(MEMORY_COMPONENT_ID) + .add_inputs(&input); + }); + } + + struct TestAirGenerator { + pub registry: ComponentGenerationRegistry, + } + + impl TestAirGenerator { + pub fn new() -> Self { + let mut registry = ComponentGenerationRegistry::default(); + register_test_memory(&mut registry); + Self { registry } + } + } + + impl AirTraceVerifier for TestAirGenerator { + fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements { + let elements = channel.draw_felts(2); + InteractionElements::new(BTreeMap::from_iter(vec![ + (MEMORY_ALPHA.to_string(), elements[0]), + (MEMORY_Z.to_string(), elements[1]), + ])) + } + } + + impl AirTraceGenerator for TestAirGenerator { + fn write_trace( + &mut self, + ) -> Vec> { + MemoryTraceGenerator::write_trace(MEMORY_COMPONENT_ID, &mut self.registry) + } + + fn interact( + &self, + trace: &ColumnVec>, + elements: &InteractionElements, + ) -> Vec> { + let component_generator = self + .registry + .get_generator::(MEMORY_COMPONENT_ID); + component_generator.write_interaction_trace(&trace.iter().collect(), elements) + } + + fn to_air_prover(&self) -> impl AirProver { + let component_generator = self + .registry + .get_generator::(MEMORY_COMPONENT_ID); + TestAir { + component: component_generator.component(), + } + } + + fn composition_log_degree_bound(&self) -> u32 { + let component_generator = self + .registry + .get_generator::(MEMORY_COMPONENT_ID); + component_generator + .component() + .max_constraint_log_degree_bound() + } + } + + #[derive(Clone)] + pub struct TestAir { + pub component: MemoryComponent, + } + + impl Air for TestAir { + fn components(&self) -> Vec<&dyn Component> { + vec![&self.component] + } + + fn verify_lookups(&self, _lookup_values: &LookupValues) -> Result<(), VerificationError> { + Ok(()) + } + } + + impl AirProver for TestAir { + fn prover_components(&self) -> Vec<&dyn ComponentProver> { + vec![&self.component] + } + } + + impl AirTraceVerifier for TestAir { + fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements { + let elements = channel.draw_felts(2); + InteractionElements::new(BTreeMap::from_iter(vec![ + (MEMORY_ALPHA.to_string(), elements[0]), + (MEMORY_Z.to_string(), elements[1]), + ])) + } + } + + #[test] + fn test_memory_constraints() { + let mut air = TestAirGenerator::new(); + let trace = air.write_trace(); + let prover_channel = + &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); + let proof = commit_and_prove::(&air, prover_channel, trace).unwrap(); + + let verifier_channel = + &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); + let air = TestAir { + component: air + .registry + .get_generator::(MEMORY_COMPONENT_ID) + .component(), + }; + commit_and_verify(proof, &air, verifier_channel).unwrap(); + } +}