diff --git a/stwo_cairo_prover/src/components/range_check_unit/component.rs b/stwo_cairo_prover/src/components/range_check_unit/component.rs index 4a2704d1..51d53e98 100644 --- a/stwo_cairo_prover/src/components/range_check_unit/component.rs +++ b/stwo_cairo_prover/src/components/range_check_unit/component.rs @@ -20,6 +20,10 @@ use stwo_prover::trace_generation::{ComponentGen, ComponentTraceGenerator}; pub const RC_Z: &str = "RangeCheckUnit_Z"; pub const RC_COMPONENT_ID: &str = "RC_UNIT"; +pub const RC_LOOKUP_VALUE_0: &str = "RC_UNIT_LOOKUP_0"; +pub const RC_LOOKUP_VALUE_1: &str = "RC_UNIT_LOOKUP_1"; +pub const RC_LOOKUP_VALUE_2: &str = "RC_UNIT_LOOKUP_2"; +pub const RC_LOOKUP_VALUE_3: &str = "RC_UNIT_LOOKUP_3"; #[derive(Clone)] pub struct RangeCheckUnitComponent { @@ -39,6 +43,7 @@ impl RangeCheckUnitComponent { evaluation_accumulator: &mut PointEvaluationAccumulator, interaction_elements: &InteractionElements, constraint_zero_domain: Coset, + lookup_values: &LookupValues, ) { let z = interaction_elements[RC_Z]; let value = @@ -46,12 +51,25 @@ impl RangeCheckUnitComponent { let numerator = value * (z - mask[BASE_TRACE][0][0]) - mask[BASE_TRACE][1][0]; let denom = point_vanishing(constraint_zero_domain.at(0), point); evaluation_accumulator.accumulate(numerator / denom); + + let lookup_value = SecureField::from_m31( + lookup_values[RC_LOOKUP_VALUE_0], + lookup_values[RC_LOOKUP_VALUE_1], + lookup_values[RC_LOOKUP_VALUE_2], + lookup_values[RC_LOOKUP_VALUE_3], + ); + let numerator = value - lookup_value; + let denom = point_vanishing( + constraint_zero_domain.at(constraint_zero_domain.size() - 1), + point, + ); + evaluation_accumulator.accumulate(numerator / denom); } } impl Component for RangeCheckUnitComponent { fn n_constraints(&self) -> usize { - 1 + 2 } fn max_constraint_log_degree_bound(&self) -> u32 { @@ -85,7 +103,7 @@ impl Component for RangeCheckUnitComponent { mask: &TreeVec>>, evaluation_accumulator: &mut PointEvaluationAccumulator, interaction_elements: &InteractionElements, - _lookup_values: &LookupValues, + lookup_values: &LookupValues, ) { let constraint_zero_domain = CanonicCoset::new(self.log_n_instances).coset; self.evaluate_lookup_boundary_constraints_at_point( @@ -94,6 +112,7 @@ impl Component for RangeCheckUnitComponent { evaluation_accumulator, interaction_elements, constraint_zero_domain, + lookup_values, ); } } diff --git a/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs b/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs index 9c92f633..cf484593 100644 --- a/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs +++ b/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs @@ -1,3 +1,6 @@ +use std::collections::BTreeMap; + +use itertools::zip_eq; use stwo_prover::core::air::accumulation::{ColumnAccumulator, DomainEvaluationAccumulator}; use stwo_prover::core::air::{Component, ComponentProver, ComponentTrace}; use stwo_prover::core::backend::CpuBackend; @@ -11,7 +14,10 @@ use stwo_prover::core::prover::{BASE_TRACE, INTERACTION_TRACE}; use stwo_prover::core::utils::point_vanish_denominator_inverses; use stwo_prover::core::{InteractionElements, LookupValues}; -use super::component::{RangeCheckUnitComponent, RC_Z}; +use super::component::{ + RangeCheckUnitComponent, RC_LOOKUP_VALUE_0, RC_LOOKUP_VALUE_1, RC_LOOKUP_VALUE_2, + RC_LOOKUP_VALUE_3, RC_Z, +}; impl RangeCheckUnitComponent { fn evaluate_lookup_boundary_constraints( @@ -21,17 +27,35 @@ impl RangeCheckUnitComponent { zero_domain: Coset, accum: &mut ColumnAccumulator<'_, CpuBackend>, interaction_elements: &InteractionElements, + lookup_values: &LookupValues, ) { - let denom_inverses = + let first_point_denom_inverses = point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0)); + let last_point_denom_inverses = point_vanish_denominator_inverses( + trace_eval_domain, + zero_domain.at(zero_domain.size() - 1), + ); let z = interaction_elements[RC_Z]; - for (i, denom_inverse) in denom_inverses.iter().enumerate() { + let lookup_value = SecureField::from_m31( + lookup_values[RC_LOOKUP_VALUE_0], + lookup_values[RC_LOOKUP_VALUE_1], + lookup_values[RC_LOOKUP_VALUE_2], + lookup_values[RC_LOOKUP_VALUE_3], + ); + for (i, (first_point_denom_inverse, last_point_denom_inverse)) in + zip_eq(first_point_denom_inverses, last_point_denom_inverses).enumerate() + { let value = SecureField::from_m31_array(std::array::from_fn(|j| { trace_evals[INTERACTION_TRACE][j][i] })); - let numerator = accum.random_coeff_powers[0] + let first_point_numerator = accum.random_coeff_powers[1] * (value * (z - trace_evals[BASE_TRACE][0][i]) - trace_evals[BASE_TRACE][1][i]); - accum.accumulate(i, numerator * *denom_inverse); + let last_point_numerator = accum.random_coeff_powers[0] * (value - lookup_value); + accum.accumulate( + i, + first_point_numerator * first_point_denom_inverse + + last_point_numerator * last_point_denom_inverse, + ); } } } @@ -42,7 +66,7 @@ impl ComponentProver for RangeCheckUnitComponent { trace: &ComponentTrace<'_, CpuBackend>, evaluation_accumulator: &mut DomainEvaluationAccumulator, interaction_elements: &InteractionElements, - _lookup_values: &LookupValues, + lookup_values: &LookupValues, ) { let max_constraint_degree = self.max_constraint_log_degree_bound(); let trace_eval_domain = CanonicCoset::new(max_constraint_degree).circle_domain(); @@ -57,10 +81,43 @@ impl ComponentProver for RangeCheckUnitComponent { zero_domain, &mut accum, interaction_elements, + lookup_values, ); } - fn lookup_values(&self, _trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues { - LookupValues::default() + fn lookup_values(&self, trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues { + let domain = CanonicCoset::new(self.log_n_instances); + let trace_poly = &trace.polys[INTERACTION_TRACE]; + let values = BTreeMap::from_iter([ + ( + RC_LOOKUP_VALUE_0.to_string(), + trace_poly[0] + .eval_at_point(domain.at(domain.size() - 1).into_ef()) + .try_into() + .unwrap(), + ), + ( + RC_LOOKUP_VALUE_1.to_string(), + trace_poly[1] + .eval_at_point(domain.at(domain.size() - 1).into_ef()) + .try_into() + .unwrap(), + ), + ( + RC_LOOKUP_VALUE_2.to_string(), + trace_poly[2] + .eval_at_point(domain.at(domain.size() - 1).into_ef()) + .try_into() + .unwrap(), + ), + ( + RC_LOOKUP_VALUE_3.to_string(), + trace_poly[3] + .eval_at_point(domain.at(domain.size() - 1).into_ef()) + .try_into() + .unwrap(), + ), + ]); + LookupValues::new(values) } }