From 64e046db86afbce0497748c94e758d9cdfe6c3a7 Mon Sep 17 00:00:00 2001 From: Alon Haramati Date: Thu, 11 Jul 2024 14:55:12 +0300 Subject: [PATCH] Add lookup step constraint. --- Cargo.lock | 10 +-- stwo_cairo_prover/Cargo.toml | 2 +- .../components/range_check_unit/component.rs | 70 +++++++++++++------ .../range_check_unit/component_prover.rs | 69 ++++++++++++++---- .../src/components/range_check_unit/mod.rs | 9 +-- 5 files changed, 116 insertions(+), 44 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6ea18e14..96dcf1f2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -107,9 +107,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.5.2" +version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d08263faac5cde2a4d52b513dadb80846023aade56fcd8fc99ba73ba8050e92" +checksum = "e9ec96fe9a81b5e365f9db71fe00edc4fe4ca2cc7dcb7861f0603012a7caa210" dependencies = [ "arrayref", "arrayvec", @@ -155,9 +155,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.3" +version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18e2d530f35b40a84124146478cd16f34225306a8441998836466a2e2961c950" +checksum = "324c74f2155653c90b04f25b2a47a8a631360cb908f92a772695f430c7e31052" [[package]] name = "cfg-if" @@ -553,7 +553,7 @@ dependencies = [ [[package]] name = "stwo-prover" version = "0.1.1" -source = "git+https://github.com/starkware-libs/stwo?rev=2501444#2501444dce400b96530c7c853bdad14d61d96340" +source = "git+https://github.com/starkware-libs/stwo?rev=7a0bddee#7a0bddeec1a847654dbecff5df37bf5a5891f216" dependencies = [ "blake2", "blake3", diff --git a/stwo_cairo_prover/Cargo.toml b/stwo_cairo_prover/Cargo.toml index f1e2aa3a..8ec2e947 100644 --- a/stwo_cairo_prover/Cargo.toml +++ b/stwo_cairo_prover/Cargo.toml @@ -7,4 +7,4 @@ edition = "2021" itertools = "0.12.0" num-traits = "0.2.17" # TODO(ShaharS): take stwo version from the source repository. -stwo-prover = { git = "https://github.com/starkware-libs/stwo", rev = "2501444" } +stwo-prover = { git = "https://github.com/starkware-libs/stwo", rev = "7a0bddee" } diff --git a/stwo_cairo_prover/src/components/range_check_unit/component.rs b/stwo_cairo_prover/src/components/range_check_unit/component.rs index 99b93567..98fa8234 100644 --- a/stwo_cairo_prover/src/components/range_check_unit/component.rs +++ b/stwo_cairo_prover/src/components/range_check_unit/component.rs @@ -1,11 +1,11 @@ -use itertools::{zip_eq, Itertools}; +use itertools::Itertools; use num_traits::Zero; use stwo_prover::core::air::accumulation::PointEvaluationAccumulator; use stwo_prover::core::air::mask::fixed_mask_points; use stwo_prover::core::air::Component; use stwo_prover::core::backend::CpuBackend; use stwo_prover::core::circle::{CirclePoint, Coset}; -use stwo_prover::core::constraints::point_vanishing; +use stwo_prover::core::constraints::{coset_vanishing, point_excluder, point_vanishing}; use stwo_prover::core::fields::m31::BaseField; use stwo_prover::core::fields::qm31::SecureField; use stwo_prover::core::fields::secure_column::{SecureColumn, SECURE_EXTENSION_DEGREE}; @@ -13,10 +13,12 @@ use stwo_prover::core::fields::FieldExpOps; use stwo_prover::core::pcs::TreeVec; use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; use stwo_prover::core::poly::BitReversedOrder; -use stwo_prover::core::prover::{BASE_TRACE, INTERACTION_TRACE}; +use stwo_prover::core::utils::{bit_reverse_index, coset_order_to_circle_domain_order_index}; use stwo_prover::core::{ColumnVec, InteractionElements, LookupValues}; use stwo_prover::trace_generation::registry::ComponentGenerationRegistry; -use stwo_prover::trace_generation::{ComponentGen, ComponentTraceGenerator}; +use stwo_prover::trace_generation::{ + ComponentGen, ComponentTraceGenerator, BASE_TRACE, INTERACTION_TRACE, +}; pub const RC_Z: &str = "RangeCheckUnit_Z"; pub const RC_COMPONENT_ID: &str = "RC_UNIT"; @@ -59,17 +61,34 @@ impl RangeCheckUnitComponent { lookup_values[RC_LOOKUP_VALUE_3], ); let numerator = value - lookup_value; - let denom = point_vanishing( - constraint_zero_domain.at(constraint_zero_domain.size() - 1), - point, - ); + let denom = point_vanishing(constraint_zero_domain.at(1), point); + evaluation_accumulator.accumulate(numerator / denom); + } + + fn evaluate_lookup_step_constraints_at_point( + &self, + point: CirclePoint, + mask: &TreeVec>>, + evaluation_accumulator: &mut PointEvaluationAccumulator, + constraint_zero_domain: Coset, + interaction_elements: &InteractionElements, + ) { + let z = interaction_elements[RC_Z]; + let value = + SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0])); + let prev_value = + SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][1])); + let numerator = + (value - prev_value) * (z - mask[BASE_TRACE][0][0]) - mask[BASE_TRACE][1][0]; + let denom = coset_vanishing(constraint_zero_domain, point) + / point_excluder(constraint_zero_domain.at(0), point); evaluation_accumulator.accumulate(numerator / denom); } } impl Component for RangeCheckUnitComponent { fn n_constraints(&self) -> usize { - 2 + 3 } fn max_constraint_log_degree_bound(&self) -> u32 { @@ -87,9 +106,10 @@ impl Component for RangeCheckUnitComponent { &self, point: CirclePoint, ) -> TreeVec>>> { + let domain = CanonicCoset::new(self.log_n_instances); TreeVec::new(vec![ fixed_mask_points(&vec![vec![0_usize]; 2], point), - vec![vec![point]; SECURE_EXTENSION_DEGREE], + vec![vec![point, point - domain.step().into_ef()]; SECURE_EXTENSION_DEGREE], ]) } @@ -110,6 +130,13 @@ impl Component for RangeCheckUnitComponent { constraint_zero_domain, lookup_values, ); + self.evaluate_lookup_step_constraints_at_point( + point, + mask, + evaluation_accumulator, + constraint_zero_domain, + interaction_elements, + ); } } @@ -171,15 +198,15 @@ impl ComponentTraceGenerator for RangeCheckUnitTraceGenerator { let denoms = trace[0].values.iter().map(|value| z - *value).collect_vec(); let mut denom_inverses = vec![SecureField::zero(); denoms.len()]; SecureField::batch_inverse(&denoms, &mut denom_inverses); - let logup_values = zip_eq(denom_inverses, &trace[1].values).fold( - Vec::new(), - |mut acc, (denom_inverse, multiplicity)| { - let interaction_value = last + (denom_inverse * *multiplicity); - acc.push(interaction_value); - last = interaction_value; - acc - }, - ); + let mut logup_values = vec![SecureField::zero(); trace[1].values.len()]; + let log_size = interaction_trace_domain.log_size(); + for i in 0..trace[1].values.len() { + let index = coset_order_to_circle_domain_order_index(i, log_size); + let index = bit_reverse_index(index, log_size); + let interaction_value = last + (denom_inverses[index] * trace[1].values[index]); + logup_values[index] = interaction_value; + last = interaction_value; + } let secure_column: SecureColumn = logup_values.into_iter().collect(); secure_column .columns @@ -218,9 +245,8 @@ mod tests { trace_sum += trace.last().unwrap().values[i] / (random_value - BaseField::from_u32_unchecked(i as u32)); } - let logup_sum = SecureField::from_m31_array(std::array::from_fn(|j| { - *interaction_trace[j].last().unwrap() - })); + let logup_sum = + SecureField::from_m31_array(std::array::from_fn(|j| interaction_trace[j][1])); assert_eq!(logup_sum, trace_sum); } diff --git a/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs b/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs index 33690938..04504c86 100644 --- a/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs +++ b/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs @@ -1,18 +1,23 @@ use std::collections::BTreeMap; use itertools::zip_eq; +use num_traits::Zero; use stwo_prover::core::air::accumulation::{ColumnAccumulator, DomainEvaluationAccumulator}; use stwo_prover::core::air::{Component, ComponentProver, ComponentTrace}; use stwo_prover::core::backend::CpuBackend; use stwo_prover::core::circle::Coset; +use stwo_prover::core::constraints::{coset_vanishing, point_excluder}; use stwo_prover::core::fields::m31::BaseField; use stwo_prover::core::fields::qm31::SecureField; +use stwo_prover::core::fields::FieldExpOps; use stwo_prover::core::pcs::TreeVec; use stwo_prover::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation}; use stwo_prover::core::poly::BitReversedOrder; -use stwo_prover::core::prover::{BASE_TRACE, INTERACTION_TRACE}; -use stwo_prover::core::utils::point_vanish_denominator_inverses; +use stwo_prover::core::utils::{ + bit_reverse, point_vanish_denominator_inverses, previous_bit_reversed_circle_domain_index, +}; use stwo_prover::core::{InteractionElements, LookupValues}; +use stwo_prover::trace_generation::{BASE_TRACE, INTERACTION_TRACE}; use super::component::{ RangeCheckUnitComponent, RC_LOOKUP_VALUE_0, RC_LOOKUP_VALUE_1, RC_LOOKUP_VALUE_2, @@ -42,6 +47,13 @@ impl ComponentProver for RangeCheckUnitComponent { interaction_elements, lookup_values, ); + evaluate_lookup_step_constraints( + trace_evals, + trace_eval_domain, + zero_domain, + &mut accum, + interaction_elements, + ) } fn lookup_values(&self, trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues { @@ -51,28 +63,28 @@ impl ComponentProver for RangeCheckUnitComponent { ( RC_LOOKUP_VALUE_0.to_string(), trace_poly[0] - .eval_at_point(domain.at(domain.size() - 1).into_ef()) + .eval_at_point(domain.at(1).into_ef()) .try_into() .unwrap(), ), ( RC_LOOKUP_VALUE_1.to_string(), trace_poly[1] - .eval_at_point(domain.at(domain.size() - 1).into_ef()) + .eval_at_point(domain.at(1).into_ef()) .try_into() .unwrap(), ), ( RC_LOOKUP_VALUE_2.to_string(), trace_poly[2] - .eval_at_point(domain.at(domain.size() - 1).into_ef()) + .eval_at_point(domain.at(1).into_ef()) .try_into() .unwrap(), ), ( RC_LOOKUP_VALUE_3.to_string(), trace_poly[3] - .eval_at_point(domain.at(domain.size() - 1).into_ef()) + .eval_at_point(domain.at(1).into_ef()) .try_into() .unwrap(), ), @@ -91,10 +103,8 @@ fn evaluate_lookup_boundary_constraints( ) { let first_point_denom_inverses = point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0)); - let last_point_denom_inverses = point_vanish_denominator_inverses( - trace_eval_domain, - zero_domain.at(zero_domain.size() - 1), - ); + let last_point_denom_inverses = + point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(1)); let z = interaction_elements[RC_Z]; let lookup_value = SecureField::from_m31( lookup_values[RC_LOOKUP_VALUE_0], @@ -108,9 +118,9 @@ fn evaluate_lookup_boundary_constraints( let value = SecureField::from_m31_array(std::array::from_fn(|j| { trace_evals[INTERACTION_TRACE][j][i] })); - let first_point_numerator = accum.random_coeff_powers[1] + let first_point_numerator = accum.random_coeff_powers[2] * (value * (z - trace_evals[BASE_TRACE][0][i]) - trace_evals[BASE_TRACE][1][i]); - let last_point_numerator = accum.random_coeff_powers[0] * (value - lookup_value); + let last_point_numerator = accum.random_coeff_powers[1] * (value - lookup_value); accum.accumulate( i, first_point_numerator * first_point_denom_inverse @@ -118,3 +128,38 @@ fn evaluate_lookup_boundary_constraints( ); } } + +fn evaluate_lookup_step_constraints( + trace_evals: &TreeVec>>, + trace_eval_domain: CircleDomain, + zero_domain: Coset, + accum: &mut ColumnAccumulator<'_, CpuBackend>, + interaction_elements: &InteractionElements, +) { + let mut denoms = vec![]; + for point in trace_eval_domain.iter() { + denoms.push(coset_vanishing(zero_domain, point) / point_excluder(zero_domain.at(0), point)); + } + bit_reverse(&mut denoms); + let mut denom_inverses = vec![BaseField::zero(); denoms.len()]; + BaseField::batch_inverse(&denoms, &mut denom_inverses); + let z = interaction_elements[RC_Z]; + + for (i, denom_inverse) in denom_inverses.iter().enumerate() { + let value = SecureField::from_m31_array(std::array::from_fn(|j| { + trace_evals[INTERACTION_TRACE][j][i] + })); + let prev_index = previous_bit_reversed_circle_domain_index( + i, + zero_domain.log_size, + trace_eval_domain.log_size(), + ); + let prev_value = SecureField::from_m31_array(std::array::from_fn(|j| { + trace_evals[INTERACTION_TRACE][j][prev_index] + })); + let numerator = accum.random_coeff_powers[0] + * ((value - prev_value) * (z - trace_evals[BASE_TRACE][0][i]) + - trace_evals[BASE_TRACE][1][i]); + accum.accumulate(i, numerator * *denom_inverse); + } +} diff --git a/stwo_cairo_prover/src/components/range_check_unit/mod.rs b/stwo_cairo_prover/src/components/range_check_unit/mod.rs index 30a9bbcb..3658ee41 100644 --- a/stwo_cairo_prover/src/components/range_check_unit/mod.rs +++ b/stwo_cairo_prover/src/components/range_check_unit/mod.rs @@ -14,13 +14,14 @@ mod tests { use stwo_prover::core::fields::IntoSlice; use stwo_prover::core::poly::circle::CircleEvaluation; use stwo_prover::core::poly::BitReversedOrder; - use stwo_prover::core::prover::{prove, verify, VerificationError}; + use stwo_prover::core::prover::VerificationError; use stwo_prover::core::vcs::blake2_hash::Blake2sHasher; use stwo_prover::core::vcs::hasher::Hasher; use stwo_prover::core::{ColumnVec, InteractionElements, LookupValues}; use stwo_prover::trace_generation::registry::ComponentGenerationRegistry; use stwo_prover::trace_generation::{ - AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator, + commit_and_prove, commit_and_verify, AirTraceGenerator, AirTraceVerifier, + ComponentTraceGenerator, }; use super::*; @@ -135,7 +136,7 @@ mod tests { let trace = air.write_trace(); let prover_channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); - let proof = prove::(&air, prover_channel, trace).unwrap(); + let proof = commit_and_prove::(&air, prover_channel, trace).unwrap(); let verifier_channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); @@ -145,6 +146,6 @@ mod tests { .get_generator::(RC_COMPONENT_ID) .component(), }; - verify(proof, &air, verifier_channel).unwrap(); + commit_and_verify(proof, &air, verifier_channel).unwrap(); } }