From 86e955f8890efed5f1d69aa50befae4174c8cbf1 Mon Sep 17 00:00:00 2001 From: Alon Haramati Date: Tue, 9 Jul 2024 13:25:45 +0300 Subject: [PATCH] Implement ComponentProver for RangeCheckUnitComponent. --- Cargo.lock | 30 ++-- .../component.rs} | 39 ++--- .../range_check_unit/component_prover.rs | 66 ++++++++ .../src/components/range_check_unit/mod.rs | 146 ++++++++++++++++++ 4 files changed, 240 insertions(+), 41 deletions(-) rename stwo_cairo_prover/src/components/{range_check_unit.rs => range_check_unit/component.rs} (85%) create mode 100644 stwo_cairo_prover/src/components/range_check_unit/component_prover.rs create mode 100644 stwo_cairo_prover/src/components/range_check_unit/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 01161934..cb2eb1cb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -150,14 +150,14 @@ checksum = "1ee891b04274a59bd38b412188e24b849617b2e45a0fd8d057deb63e7403761b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.69", + "syn 2.0.70", ] [[package]] name = "cc" -version = "1.0.105" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5208975e568d83b6b05cc0a063c8e7e9acc2b43bee6da15616a5b73e109d7437" +checksum = "eaff6f8ce506b9773fa786672d63fc7a191ffea1be33f72bbd4aeacefca9ffc8" [[package]] name = "cfg-if" @@ -238,7 +238,7 @@ dependencies = [ "enum-ordinalize", "proc-macro2", "quote", - "syn 2.0.69", + "syn 2.0.70", ] [[package]] @@ -264,7 +264,7 @@ checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" dependencies = [ "proc-macro2", "quote", - "syn 2.0.69", + "syn 2.0.70", ] [[package]] @@ -482,7 +482,7 @@ checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" dependencies = [ "proc-macro2", "quote", - "syn 2.0.69", + "syn 2.0.70", ] [[package]] @@ -524,7 +524,7 @@ checksum = "bbc159a1934c7be9761c237333a57febe060ace2bc9e3b337a59a37af206d19f" dependencies = [ "starknet-curve", "starknet-ff", - "syn 2.0.69", + "syn 2.0.70", ] [[package]] @@ -553,7 +553,7 @@ dependencies = [ [[package]] name = "stwo-prover" version = "0.1.1" -source = "git+https://github.com/starkware-libs/stwo?branch=dev#8866c2815a5107372bd78356f304f0570daf20aa" +source = "git+https://github.com/starkware-libs/stwo?branch=dev#defcfe244fdf81bb29dcfb7077f1d9caf22e5b0f" dependencies = [ "blake2", "blake3", @@ -599,9 +599,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.69" +version = "2.0.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "201fcda3845c23e8212cd466bfebf0bd20694490fc0356ae8e428e0824a915a6" +checksum = "2f0209b68b3613b093e0ec905354eccaedcfe83b8cb37cbdeae64026c3064c16" dependencies = [ "proc-macro2", "quote", @@ -625,7 +625,7 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", - "syn 2.0.69", + "syn 2.0.70", ] [[package]] @@ -647,7 +647,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.69", + "syn 2.0.70", ] [[package]] @@ -704,7 +704,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.69", + "syn 2.0.70", "wasm-bindgen-shared", ] @@ -726,7 +726,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.69", + "syn 2.0.70", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -754,5 +754,5 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.69", + "syn 2.0.70", ] diff --git a/stwo_cairo_prover/src/components/range_check_unit.rs b/stwo_cairo_prover/src/components/range_check_unit/component.rs similarity index 85% rename from stwo_cairo_prover/src/components/range_check_unit.rs rename to stwo_cairo_prover/src/components/range_check_unit/component.rs index 9101e2f0..4a2704d1 100644 --- a/stwo_cairo_prover/src/components/range_check_unit.rs +++ b/stwo_cairo_prover/src/components/range_check_unit/component.rs @@ -1,6 +1,5 @@ use itertools::{zip_eq, Itertools}; use num_traits::Zero; - use stwo_prover::core::air::accumulation::PointEvaluationAccumulator; use stwo_prover::core::air::mask::fixed_mask_points; use stwo_prover::core::air::Component; @@ -14,15 +13,15 @@ use stwo_prover::core::fields::FieldExpOps; use stwo_prover::core::pcs::TreeVec; use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; use stwo_prover::core::poly::BitReversedOrder; -use stwo_prover::core::LookupValues; -use stwo_prover::core::{ColumnVec, InteractionElements}; +use stwo_prover::core::prover::{BASE_TRACE, INTERACTION_TRACE}; +use stwo_prover::core::{ColumnVec, InteractionElements, LookupValues}; use stwo_prover::trace_generation::registry::ComponentGenerationRegistry; -use stwo_prover::trace_generation::ComponentGen; -use stwo_prover::trace_generation::ComponentTraceGenerator; +use stwo_prover::trace_generation::{ComponentGen, ComponentTraceGenerator}; pub const RC_Z: &str = "RangeCheckUnit_Z"; pub const RC_COMPONENT_ID: &str = "RC_UNIT"; +#[derive(Clone)] pub struct RangeCheckUnitComponent { pub log_n_instances: u32, } @@ -42,8 +41,9 @@ impl RangeCheckUnitComponent { constraint_zero_domain: Coset, ) { let z = interaction_elements[RC_Z]; - let value = SecureField::from_partial_evals(std::array::from_fn(|i| mask[1][i][0])); - let numerator = value * (z - mask[0][0][0]) - mask[0][1][0]; + let value = + SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0])); + let numerator = value * (z - mask[BASE_TRACE][0][0]) - mask[BASE_TRACE][1][0]; let denom = point_vanishing(constraint_zero_domain.at(0), point); evaluation_accumulator.accumulate(numerator / denom); } @@ -73,7 +73,10 @@ impl Component for RangeCheckUnitComponent { &self, point: CirclePoint, ) -> TreeVec>>> { - TreeVec::new(vec![fixed_mask_points(&vec![vec![0_usize]], point)]) + TreeVec::new(vec![ + fixed_mask_points(&vec![vec![0_usize]; 2], point), + vec![vec![point]; SECURE_EXTENSION_DEGREE], + ]) } fn evaluate_constraint_quotients_at_point( @@ -180,28 +183,12 @@ impl ComponentTraceGenerator for RangeCheckUnitTraceGenerator { #[cfg(test)] mod tests { use super::*; + use crate::components::range_check_unit::tests::register_test_rc; #[test] fn test_rc_unit_trace() { let mut registry = ComponentGenerationRegistry::default(); - registry.register(RC_COMPONENT_ID, RangeCheckUnitTraceGenerator::new(8)); - let inputs = vec![ - vec![BaseField::from_u32_unchecked(0); 3], - vec![BaseField::from_u32_unchecked(1); 1], - vec![BaseField::from_u32_unchecked(2); 2], - vec![BaseField::from_u32_unchecked(3); 5], - vec![BaseField::from_u32_unchecked(4); 10], - vec![BaseField::from_u32_unchecked(5); 1], - vec![BaseField::from_u32_unchecked(6); 0], - vec![BaseField::from_u32_unchecked(7); 1], - ] - .into_iter() - .flatten() - .collect_vec(); - registry - .get_generator_mut::(RC_COMPONENT_ID) - .add_inputs(&inputs); - + register_test_rc(&mut registry); let trace = RangeCheckUnitTraceGenerator::write_trace(RC_COMPONENT_ID, &mut registry); let random_value = SecureField::from_u32_unchecked(1, 2, 3, 117); let interaction_elements = diff --git a/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs b/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs new file mode 100644 index 00000000..9c92f633 --- /dev/null +++ b/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs @@ -0,0 +1,66 @@ +use stwo_prover::core::air::accumulation::{ColumnAccumulator, DomainEvaluationAccumulator}; +use stwo_prover::core::air::{Component, ComponentProver, ComponentTrace}; +use stwo_prover::core::backend::CpuBackend; +use stwo_prover::core::circle::Coset; +use stwo_prover::core::fields::m31::BaseField; +use stwo_prover::core::fields::qm31::SecureField; +use stwo_prover::core::pcs::TreeVec; +use stwo_prover::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation}; +use stwo_prover::core::poly::BitReversedOrder; +use stwo_prover::core::prover::{BASE_TRACE, INTERACTION_TRACE}; +use stwo_prover::core::utils::point_vanish_denominator_inverses; +use stwo_prover::core::{InteractionElements, LookupValues}; + +use super::component::{RangeCheckUnitComponent, RC_Z}; + +impl RangeCheckUnitComponent { + fn evaluate_lookup_boundary_constraints( + &self, + trace_evals: &TreeVec>>, + trace_eval_domain: CircleDomain, + zero_domain: Coset, + accum: &mut ColumnAccumulator<'_, CpuBackend>, + interaction_elements: &InteractionElements, + ) { + let denom_inverses = + point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0)); + let z = interaction_elements[RC_Z]; + for (i, denom_inverse) in denom_inverses.iter().enumerate() { + let value = SecureField::from_m31_array(std::array::from_fn(|j| { + trace_evals[INTERACTION_TRACE][j][i] + })); + let numerator = accum.random_coeff_powers[0] + * (value * (z - trace_evals[BASE_TRACE][0][i]) - trace_evals[BASE_TRACE][1][i]); + accum.accumulate(i, numerator * *denom_inverse); + } + } +} + +impl ComponentProver for RangeCheckUnitComponent { + fn evaluate_constraint_quotients_on_domain( + &self, + trace: &ComponentTrace<'_, CpuBackend>, + evaluation_accumulator: &mut DomainEvaluationAccumulator, + interaction_elements: &InteractionElements, + _lookup_values: &LookupValues, + ) { + let max_constraint_degree = self.max_constraint_log_degree_bound(); + let trace_eval_domain = CanonicCoset::new(max_constraint_degree).circle_domain(); + let trace_evals = &trace.evals; + let zero_domain = CanonicCoset::new(self.log_n_instances).coset; + let [mut accum] = + evaluation_accumulator.columns([(max_constraint_degree, self.n_constraints())]); + + self.evaluate_lookup_boundary_constraints( + trace_evals, + trace_eval_domain, + zero_domain, + &mut accum, + interaction_elements, + ); + } + + fn lookup_values(&self, _trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues { + LookupValues::default() + } +} diff --git a/stwo_cairo_prover/src/components/range_check_unit/mod.rs b/stwo_cairo_prover/src/components/range_check_unit/mod.rs new file mode 100644 index 00000000..7c5a1534 --- /dev/null +++ b/stwo_cairo_prover/src/components/range_check_unit/mod.rs @@ -0,0 +1,146 @@ +pub mod component; +pub mod component_prover; + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + + use component::{RangeCheckUnitComponent, RangeCheckUnitTraceGenerator, RC_COMPONENT_ID, RC_Z}; + use itertools::Itertools; + use stwo_prover::core::air::{Air, AirProver, Component, ComponentProver}; + use stwo_prover::core::backend::CpuBackend; + use stwo_prover::core::channel::{Blake2sChannel, Channel}; + use stwo_prover::core::fields::m31::BaseField; + use stwo_prover::core::fields::IntoSlice; + use stwo_prover::core::poly::circle::CircleEvaluation; + use stwo_prover::core::poly::BitReversedOrder; + use stwo_prover::core::prover::{prove, verify}; + use stwo_prover::core::vcs::blake2_hash::Blake2sHasher; + use stwo_prover::core::vcs::hasher::Hasher; + use stwo_prover::core::{ColumnVec, InteractionElements}; + use stwo_prover::trace_generation::registry::ComponentGenerationRegistry; + use stwo_prover::trace_generation::{ + AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator, + }; + + use super::*; + + pub fn register_test_rc(registry: &mut ComponentGenerationRegistry) { + registry.register(RC_COMPONENT_ID, RangeCheckUnitTraceGenerator::new(8)); + let inputs = vec![ + vec![BaseField::from_u32_unchecked(0); 3], + vec![BaseField::from_u32_unchecked(1); 1], + vec![BaseField::from_u32_unchecked(2); 2], + vec![BaseField::from_u32_unchecked(3); 5], + vec![BaseField::from_u32_unchecked(4); 10], + vec![BaseField::from_u32_unchecked(5); 1], + vec![BaseField::from_u32_unchecked(6); 0], + vec![BaseField::from_u32_unchecked(7); 1], + ] + .into_iter() + .flatten() + .collect_vec(); + registry + .get_generator_mut::(RC_COMPONENT_ID) + .add_inputs(&inputs); + } + + struct TestAirGenerator { + pub registry: ComponentGenerationRegistry, + } + + impl TestAirGenerator { + pub fn new() -> Self { + let mut registry = ComponentGenerationRegistry::default(); + register_test_rc(&mut registry); + Self { registry } + } + } + + impl AirTraceVerifier for TestAirGenerator { + fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements { + let element = channel.draw_felts(1)[0]; + InteractionElements::new(BTreeMap::from_iter(vec![(RC_Z.to_string(), element)])) + } + } + + impl AirTraceGenerator for TestAirGenerator { + fn write_trace( + &mut self, + ) -> Vec> { + RangeCheckUnitTraceGenerator::write_trace(RC_COMPONENT_ID, &mut self.registry) + } + + fn interact( + &self, + trace: &ColumnVec>, + elements: &InteractionElements, + ) -> Vec> { + let component_generator = self + .registry + .get_generator::(RC_COMPONENT_ID); + component_generator.write_interaction_trace(&trace.iter().collect(), elements) + } + + fn to_air_prover(&self) -> impl AirProver { + let component_generator = self + .registry + .get_generator::(RC_COMPONENT_ID); + TestAir { + component: component_generator.component(), + } + } + + fn composition_log_degree_bound(&self) -> u32 { + let component_generator = self + .registry + .get_generator::(RC_COMPONENT_ID); + component_generator + .component() + .max_constraint_log_degree_bound() + } + } + + #[derive(Clone)] + pub struct TestAir { + pub component: RangeCheckUnitComponent, + } + + impl Air for TestAir { + fn components(&self) -> Vec<&dyn Component> { + vec![&self.component] + } + } + + impl AirProver for TestAir { + fn prover_components(&self) -> Vec<&dyn ComponentProver> { + vec![&self.component] + } + } + + impl AirTraceVerifier for TestAir { + fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements { + let element = channel.draw_felts(1)[0]; + InteractionElements::new(BTreeMap::from_iter(vec![(RC_Z.to_string(), element)])) + } + } + + #[test] + fn test_rc_constraints() { + let mut air = TestAirGenerator::new(); + let trace = air.write_trace(); + let prover_channel = + &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); + let proof = prove::(&air, prover_channel, trace).unwrap(); + + let verifier_channel = + &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); + let air = TestAir { + component: air + .registry + .get_generator::(RC_COMPONENT_ID) + .component(), + }; + verify(proof, &air, verifier_channel).unwrap(); + } +}