diff --git a/stwo_cairo_prover/src/air/air.rs b/stwo_cairo_prover/src/air/air.rs new file mode 100644 index 00000000..feefe6aa --- /dev/null +++ b/stwo_cairo_prover/src/air/air.rs @@ -0,0 +1,141 @@ +use std::collections::BTreeMap; + +use stwo_prover::core::air::{Air, AirProver, Component, ComponentProver}; +use stwo_prover::core::backend::CpuBackend; +use stwo_prover::core::channel::{Blake2sChannel, Channel}; +use stwo_prover::core::fields::m31::BaseField; +use stwo_prover::core::poly::circle::CircleEvaluation; +use stwo_prover::core::poly::BitReversedOrder; +use stwo_prover::core::prover::VerificationError; +use stwo_prover::core::{ColumnVec, InteractionElements, LookupValues}; +use stwo_prover::trace_generation::registry::ComponentGenerationRegistry; +use stwo_prover::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator}; + +use crate::components::memory::component::{ + MemoryComponent, MemoryTraceGenerator, MEMORY_ADDRESS_BOUND, MEMORY_ALPHA, MEMORY_COMPONENT_ID, + MEMORY_Z, N_MEMORY_COLUMNS, +}; +use crate::components::range_check_unit::component::{ + RangeCheckUnitComponent, RangeCheckUnitTraceGenerator, N_RC_COLUMNS, RC_COMPONENT_ID, RC_Z, +}; + +struct CairoAirGenerator { + pub registry: ComponentGenerationRegistry, +} + +impl CairoAirGenerator { + pub fn new(path: String) -> Self { + let mut registry = ComponentGenerationRegistry::default(); + registry.register(MEMORY_COMPONENT_ID, MemoryTraceGenerator::new(path)); + registry.register( + RC_COMPONENT_ID, + RangeCheckUnitTraceGenerator::new(MEMORY_ADDRESS_BOUND), + ); + Self { registry } + } +} + +impl AirTraceVerifier for CairoAirGenerator { + fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements { + let elements = channel.draw_felts(3); + InteractionElements::new(BTreeMap::from_iter(vec![ + (MEMORY_ALPHA.to_string(), elements[0]), + (MEMORY_Z.to_string(), elements[1]), + (RC_Z.to_string(), elements[2]), + ])) + } +} + +impl AirTraceGenerator for CairoAirGenerator { + fn write_trace(&mut self) -> Vec> { + let mut trace = Vec::with_capacity(N_MEMORY_COLUMNS + N_RC_COLUMNS); + trace.extend(MemoryTraceGenerator::write_trace( + MEMORY_COMPONENT_ID, + &mut self.registry, + )); + trace.extend(RangeCheckUnitTraceGenerator::write_trace( + RC_COMPONENT_ID, + &mut self.registry, + )); + trace + } + + fn interact( + &self, + trace: &ColumnVec>, + elements: &InteractionElements, + ) -> Vec> { + let mut interaction_trace = Vec::new(); + let trace_iter = &mut trace.iter(); + let memory_generator = self + .registry + .get_generator::(MEMORY_COMPONENT_ID); + interaction_trace.extend( + memory_generator + .write_interaction_trace(&trace_iter.take(N_MEMORY_COLUMNS).collect(), elements), + ); + let rc_generator = self + .registry + .get_generator::(RC_COMPONENT_ID); + interaction_trace.extend( + rc_generator + .write_interaction_trace(&trace_iter.take(N_RC_COLUMNS).collect(), elements), + ); + interaction_trace + } + + fn to_air_prover(&self) -> impl AirProver { + let memory = self + .registry + .get_generator::(MEMORY_COMPONENT_ID); + let range_check_unit = self + .registry + .get_generator::(RC_COMPONENT_ID); + CairoAir { + memory: memory.component(), + range_check_unit: range_check_unit.component(), + } + } + + fn composition_log_degree_bound(&self) -> u32 { + let component_generator = self + .registry + .get_generator::(MEMORY_COMPONENT_ID); + component_generator + .component() + .max_constraint_log_degree_bound() + } +} + +#[derive(Clone)] +pub struct CairoAir { + pub memory: MemoryComponent, + pub range_check_unit: RangeCheckUnitComponent, +} + +impl Air for CairoAir { + fn components(&self) -> Vec<&dyn Component> { + vec![&self.memory, &self.range_check_unit] + } + + fn verify_lookups(&self, _lookup_values: &LookupValues) -> Result<(), VerificationError> { + Ok(()) + } +} + +impl AirProver for CairoAir { + fn prover_components(&self) -> Vec<&dyn ComponentProver> { + vec![&self.memory, &self.range_check_unit] + } +} + +impl AirTraceVerifier for CairoAir { + fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements { + let elements = channel.draw_felts(3); + InteractionElements::new(BTreeMap::from_iter(vec![ + (MEMORY_ALPHA.to_string(), elements[0]), + (MEMORY_Z.to_string(), elements[1]), + (RC_Z.to_string(), elements[2]), + ])) + } +} diff --git a/stwo_cairo_prover/src/air/mod.rs b/stwo_cairo_prover/src/air/mod.rs new file mode 100644 index 00000000..83aab11f --- /dev/null +++ b/stwo_cairo_prover/src/air/mod.rs @@ -0,0 +1 @@ +pub mod air; diff --git a/stwo_cairo_prover/src/components/memory/component.rs b/stwo_cairo_prover/src/components/memory/component.rs index 7f594ba8..5dd67303 100644 --- a/stwo_cairo_prover/src/components/memory/component.rs +++ b/stwo_cairo_prover/src/components/memory/component.rs @@ -35,6 +35,7 @@ pub const MULTIPLICITY_COLUMN: usize = N_M31_IN_FELT252 + 1; // TODO(AlonH): Make memory size configurable. pub const LOG_MEMORY_ADDRESS_BOUND: u32 = 3; pub const MEMORY_ADDRESS_BOUND: usize = 1 << LOG_MEMORY_ADDRESS_BOUND; +pub const N_MEMORY_COLUMNS: usize = N_M31_IN_FELT252 + 2; /// Addresses are continuous and start from 0. /// Values are Felt252 stored as `N_M31_IN_FELT252` M31 values (each value contain 12 bits). @@ -51,7 +52,7 @@ pub struct MemoryComponent { impl MemoryComponent { pub const fn n_columns(&self) -> usize { - N_M31_IN_FELT252 + 2 + N_MEMORY_COLUMNS } } diff --git a/stwo_cairo_prover/src/components/memory/mod.rs b/stwo_cairo_prover/src/components/memory/mod.rs index 7c60567a..15426f41 100644 --- a/stwo_cairo_prover/src/components/memory/mod.rs +++ b/stwo_cairo_prover/src/components/memory/mod.rs @@ -6,7 +6,8 @@ mod tests { use std::collections::BTreeMap; use component::{ - MemoryComponent, MemoryTraceGenerator, MEMORY_ALPHA, MEMORY_COMPONENT_ID, MEMORY_Z, + MemoryComponent, MemoryTraceGenerator, MEMORY_ADDRESS_BOUND, MEMORY_ALPHA, + MEMORY_COMPONENT_ID, MEMORY_Z, }; use stwo_prover::core::air::{Air, AirProver, Component, ComponentProver}; use stwo_prover::core::backend::CpuBackend; @@ -26,12 +27,19 @@ mod tests { }; use super::*; + use crate::components::range_check_unit::component::{ + RangeCheckUnitTraceGenerator, RC_COMPONENT_ID, + }; pub fn register_test_memory(registry: &mut ComponentGenerationRegistry) { registry.register( MEMORY_COMPONENT_ID, MemoryTraceGenerator::new("".to_string()), ); + registry.register( + RC_COMPONENT_ID, + RangeCheckUnitTraceGenerator::new(MEMORY_ADDRESS_BOUND), + ); vec![ vec![BaseField::from_u32_unchecked(0); 3], vec![BaseField::from_u32_unchecked(1); 1], diff --git a/stwo_cairo_prover/src/components/range_check_unit/component.rs b/stwo_cairo_prover/src/components/range_check_unit/component.rs index b98b7a12..e3d05e84 100644 --- a/stwo_cairo_prover/src/components/range_check_unit/component.rs +++ b/stwo_cairo_prover/src/components/range_check_unit/component.rs @@ -27,6 +27,8 @@ pub const RC_LOOKUP_VALUE_1: &str = "RC_UNIT_LOOKUP_1"; pub const RC_LOOKUP_VALUE_2: &str = "RC_UNIT_LOOKUP_2"; pub const RC_LOOKUP_VALUE_3: &str = "RC_UNIT_LOOKUP_3"; +pub const N_RC_COLUMNS: usize = 2; + #[derive(Clone)] pub struct RangeCheckUnitComponent { pub log_n_rows: u32, @@ -48,7 +50,7 @@ impl Component for RangeCheckUnitComponent { fn trace_log_degree_bounds(&self) -> TreeVec> { TreeVec::new(vec![ - vec![self.log_n_rows; 2], + vec![self.log_n_rows; N_RC_COLUMNS], vec![self.log_n_rows; SECURE_EXTENSION_DEGREE], ]) } @@ -59,7 +61,7 @@ impl Component for RangeCheckUnitComponent { ) -> TreeVec>>> { let domain = CanonicCoset::new(self.log_n_rows); TreeVec::new(vec![ - fixed_mask_points(&vec![vec![0_usize]; 2], point), + fixed_mask_points(&vec![vec![0_usize]; N_RC_COLUMNS], point), vec![vec![point, point - domain.step().into_ef()]; SECURE_EXTENSION_DEGREE], ]) } @@ -134,7 +136,7 @@ impl ComponentTraceGenerator for RangeCheckUnitTraceGenerator { registry.get_generator::(component_id); let rc_max_value = rc_unit_trace_generator.max_value; - let mut trace = vec![vec![BaseField::zero(); rc_max_value]; 2]; + let mut trace = vec![vec![BaseField::zero(); rc_max_value]; N_RC_COLUMNS]; for (i, multiplicity) in rc_unit_trace_generator.multiplicities.iter().enumerate() { // TODO(AlonH): Either create a constant column for the addresses and remove it from // here or add constraints to the column here. diff --git a/stwo_cairo_prover/src/main.rs b/stwo_cairo_prover/src/main.rs index d4f2eaff..70638826 100644 --- a/stwo_cairo_prover/src/main.rs +++ b/stwo_cairo_prover/src/main.rs @@ -1,3 +1,4 @@ +pub mod air; pub mod components; fn main() {