diff --git a/stwo_cairo_prover/src/components/range_check_unit/component.rs b/stwo_cairo_prover/src/components/range_check_unit/component.rs index 5614e933..3b2dd8ed 100644 --- a/stwo_cairo_prover/src/components/range_check_unit/component.rs +++ b/stwo_cairo_prover/src/components/range_check_unit/component.rs @@ -4,7 +4,7 @@ use stwo_prover::core::air::accumulation::PointEvaluationAccumulator; use stwo_prover::core::air::mask::fixed_mask_points; use stwo_prover::core::air::Component; use stwo_prover::core::backend::CpuBackend; -use stwo_prover::core::circle::{CirclePoint, Coset}; +use stwo_prover::core::circle::CirclePoint; use stwo_prover::core::constraints::{coset_vanishing, point_excluder, point_vanishing}; use stwo_prover::core::fields::m31::BaseField; use stwo_prover::core::fields::qm31::SecureField; @@ -36,55 +36,6 @@ pub struct RangeCheckUnitTraceGenerator { pub multiplicities: Vec, } -impl RangeCheckUnitComponent { - fn evaluate_lookup_boundary_constraints_at_point( - &self, - point: CirclePoint, - mask: &TreeVec>>, - evaluation_accumulator: &mut PointEvaluationAccumulator, - interaction_elements: &InteractionElements, - constraint_zero_domain: Coset, - lookup_values: &LookupValues, - ) { - let z = interaction_elements[RC_Z]; - let value = - SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0])); - let numerator = value * (z - mask[BASE_TRACE][0][0]) - mask[BASE_TRACE][1][0]; - let denom = point_vanishing(constraint_zero_domain.at(0), point); - evaluation_accumulator.accumulate(numerator / denom); - - let lookup_value = SecureField::from_m31( - lookup_values[RC_LOOKUP_VALUE_0], - lookup_values[RC_LOOKUP_VALUE_1], - lookup_values[RC_LOOKUP_VALUE_2], - lookup_values[RC_LOOKUP_VALUE_3], - ); - let numerator = value - lookup_value; - let denom = point_vanishing(constraint_zero_domain.at(1), point); - evaluation_accumulator.accumulate(numerator / denom); - } - - fn evaluate_lookup_step_constraints_at_point( - &self, - point: CirclePoint, - mask: &TreeVec>>, - evaluation_accumulator: &mut PointEvaluationAccumulator, - constraint_zero_domain: Coset, - interaction_elements: &InteractionElements, - ) { - let z = interaction_elements[RC_Z]; - let value = - SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0])); - let prev_value = - SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][1])); - let numerator = - (value - prev_value) * (z - mask[BASE_TRACE][0][0]) - mask[BASE_TRACE][1][0]; - let denom = coset_vanishing(constraint_zero_domain, point) - / point_excluder(constraint_zero_domain.at(0), point); - evaluation_accumulator.accumulate(numerator / denom); - } -} - impl Component for RangeCheckUnitComponent { fn n_constraints(&self) -> usize { 3 @@ -120,22 +71,34 @@ impl Component for RangeCheckUnitComponent { interaction_elements: &InteractionElements, lookup_values: &LookupValues, ) { + // First lookup point boundary constraint. let constraint_zero_domain = CanonicCoset::new(self.log_n_instances).coset; - self.evaluate_lookup_boundary_constraints_at_point( - point, - mask, - evaluation_accumulator, - interaction_elements, - constraint_zero_domain, - lookup_values, - ); - self.evaluate_lookup_step_constraints_at_point( - point, - mask, - evaluation_accumulator, - constraint_zero_domain, - interaction_elements, + let z = interaction_elements[RC_Z]; + let value = + SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0])); + let numerator = value * (z - mask[BASE_TRACE][0][0]) - mask[BASE_TRACE][1][0]; + let denom = point_vanishing(constraint_zero_domain.at(0), point); + evaluation_accumulator.accumulate(numerator / denom); + + // Last lookup point boundary constraint. + let lookup_value = SecureField::from_m31( + lookup_values[RC_LOOKUP_VALUE_0], + lookup_values[RC_LOOKUP_VALUE_1], + lookup_values[RC_LOOKUP_VALUE_2], + lookup_values[RC_LOOKUP_VALUE_3], ); + let numerator = value - lookup_value; + let denom = point_vanishing(constraint_zero_domain.at(1), point); + evaluation_accumulator.accumulate(numerator / denom); + + // Lookup step constraint. + let prev_value = + SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][1])); + let numerator = + (value - prev_value) * (z - mask[BASE_TRACE][0][0]) - mask[BASE_TRACE][1][0]; + let denom = coset_vanishing(constraint_zero_domain, point) + / point_excluder(constraint_zero_domain.at(0), point); + evaluation_accumulator.accumulate(numerator / denom); } } diff --git a/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs b/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs index 53d5b563..cf62d389 100644 --- a/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs +++ b/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs @@ -1,18 +1,15 @@ use std::collections::BTreeMap; -use itertools::zip_eq; +use itertools::izip; use num_traits::Zero; -use stwo_prover::core::air::accumulation::{ColumnAccumulator, DomainEvaluationAccumulator}; +use stwo_prover::core::air::accumulation::DomainEvaluationAccumulator; use stwo_prover::core::air::{Component, ComponentProver, ComponentTrace}; use stwo_prover::core::backend::CpuBackend; -use stwo_prover::core::circle::Coset; use stwo_prover::core::constraints::{coset_vanishing, point_excluder}; use stwo_prover::core::fields::m31::BaseField; use stwo_prover::core::fields::qm31::SecureField; use stwo_prover::core::fields::FieldExpOps; -use stwo_prover::core::pcs::TreeVec; -use stwo_prover::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation}; -use stwo_prover::core::poly::BitReversedOrder; +use stwo_prover::core::poly::circle::CanonicCoset; use stwo_prover::core::prover::{BASE_TRACE, INTERACTION_TRACE}; use stwo_prover::core::utils::{ bit_reverse, point_vanish_denominator_inverses, previous_bit_reversed_circle_domain_index, @@ -39,21 +36,60 @@ impl ComponentProver for RangeCheckUnitComponent { let [mut accum] = evaluation_accumulator.columns([(max_constraint_degree, self.n_constraints())]); - evaluate_lookup_boundary_constraints( - trace_evals, - trace_eval_domain, - zero_domain, - &mut accum, - interaction_elements, - lookup_values, + // TODO(AlonH): Get all denominators in one loop and don't perform unnecessary inversions. + let first_point_denom_inverses = + point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0)); + let last_point_denom_inverses = + point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(1)); + let mut step_denoms = vec![]; + for point in trace_eval_domain.iter() { + step_denoms.push( + coset_vanishing(zero_domain, point) / point_excluder(zero_domain.at(0), point), + ); + } + bit_reverse(&mut step_denoms); + let mut step_denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)]; + BaseField::batch_inverse(&step_denoms, &mut step_denom_inverses); + let z = interaction_elements[RC_Z]; + let lookup_value = SecureField::from_m31( + lookup_values[RC_LOOKUP_VALUE_0], + lookup_values[RC_LOOKUP_VALUE_1], + lookup_values[RC_LOOKUP_VALUE_2], + lookup_values[RC_LOOKUP_VALUE_3], ); - evaluate_lookup_step_constraints( - trace_evals, - trace_eval_domain, - zero_domain, - &mut accum, - interaction_elements, + for (i, (first_point_denom_inverse, last_point_denom_inverse, step_denom_inverse)) in izip!( + first_point_denom_inverses, + last_point_denom_inverses, + step_denom_inverses, ) + .enumerate() + { + let value = SecureField::from_m31_array(std::array::from_fn(|j| { + trace_evals[INTERACTION_TRACE][j][i] + })); + let prev_index = previous_bit_reversed_circle_domain_index( + i, + zero_domain.log_size, + trace_eval_domain.log_size(), + ); + let prev_value = SecureField::from_m31_array(std::array::from_fn(|j| { + trace_evals[INTERACTION_TRACE][j][prev_index] + })); + + let first_point_numerator = accum.random_coeff_powers[2] + * (value * (z - trace_evals[BASE_TRACE][0][i]) - trace_evals[BASE_TRACE][1][i]); + + let last_point_numerator = accum.random_coeff_powers[1] * (value - lookup_value); + let step_numerator = accum.random_coeff_powers[0] + * ((value - prev_value) * (z - trace_evals[BASE_TRACE][0][i]) + - trace_evals[BASE_TRACE][1][i]); + accum.accumulate( + i, + first_point_numerator * first_point_denom_inverse + + last_point_numerator * last_point_denom_inverse + + step_numerator * step_denom_inverse, + ); + } } fn lookup_values(&self, trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues { @@ -92,74 +128,3 @@ impl ComponentProver for RangeCheckUnitComponent { LookupValues::new(values) } } - -fn evaluate_lookup_boundary_constraints( - trace_evals: &TreeVec>>, - trace_eval_domain: CircleDomain, - zero_domain: Coset, - accum: &mut ColumnAccumulator<'_, CpuBackend>, - interaction_elements: &InteractionElements, - lookup_values: &LookupValues, -) { - let first_point_denom_inverses = - point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0)); - let last_point_denom_inverses = - point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(1)); - let z = interaction_elements[RC_Z]; - let lookup_value = SecureField::from_m31( - lookup_values[RC_LOOKUP_VALUE_0], - lookup_values[RC_LOOKUP_VALUE_1], - lookup_values[RC_LOOKUP_VALUE_2], - lookup_values[RC_LOOKUP_VALUE_3], - ); - for (i, (first_point_denom_inverse, last_point_denom_inverse)) in - zip_eq(first_point_denom_inverses, last_point_denom_inverses).enumerate() - { - let value = SecureField::from_m31_array(std::array::from_fn(|j| { - trace_evals[INTERACTION_TRACE][j][i] - })); - let first_point_numerator = accum.random_coeff_powers[2] - * (value * (z - trace_evals[BASE_TRACE][0][i]) - trace_evals[BASE_TRACE][1][i]); - let last_point_numerator = accum.random_coeff_powers[1] * (value - lookup_value); - accum.accumulate( - i, - first_point_numerator * first_point_denom_inverse - + last_point_numerator * last_point_denom_inverse, - ); - } -} - -fn evaluate_lookup_step_constraints( - trace_evals: &TreeVec>>, - trace_eval_domain: CircleDomain, - zero_domain: Coset, - accum: &mut ColumnAccumulator<'_, CpuBackend>, - interaction_elements: &InteractionElements, -) { - let mut denoms = vec![]; - for point in trace_eval_domain.iter() { - denoms.push(coset_vanishing(zero_domain, point) / point_excluder(zero_domain.at(0), point)); - } - bit_reverse(&mut denoms); - let mut denom_inverses = vec![BaseField::zero(); denoms.len()]; - BaseField::batch_inverse(&denoms, &mut denom_inverses); - let z = interaction_elements[RC_Z]; - - for (i, denom_inverse) in denom_inverses.iter().enumerate() { - let value = SecureField::from_m31_array(std::array::from_fn(|j| { - trace_evals[INTERACTION_TRACE][j][i] - })); - let prev_index = previous_bit_reversed_circle_domain_index( - i, - zero_domain.log_size, - trace_eval_domain.log_size(), - ); - let prev_value = SecureField::from_m31_array(std::array::from_fn(|j| { - trace_evals[INTERACTION_TRACE][j][prev_index] - })); - let numerator = accum.random_coeff_powers[0] - * ((value - prev_value) * (z - trace_evals[BASE_TRACE][0][i]) - - trace_evals[BASE_TRACE][1][i]); - accum.accumulate(i, numerator * *denom_inverse); - } -}