Skip to content

Commit

Permalink
Evaluate constraints together.
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed Jul 14, 2024
1 parent 8dc79b9 commit 9f3cbe9
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 154 deletions.
91 changes: 27 additions & 64 deletions stwo_cairo_prover/src/components/range_check_unit/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use stwo_prover::core::air::accumulation::PointEvaluationAccumulator;
use stwo_prover::core::air::mask::fixed_mask_points;
use stwo_prover::core::air::Component;
use stwo_prover::core::backend::CpuBackend;
use stwo_prover::core::circle::{CirclePoint, Coset};
use stwo_prover::core::circle::CirclePoint;
use stwo_prover::core::constraints::{coset_vanishing, point_excluder, point_vanishing};
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::qm31::SecureField;
Expand Down Expand Up @@ -36,55 +36,6 @@ pub struct RangeCheckUnitTraceGenerator {
pub multiplicities: Vec<u32>,
}

impl RangeCheckUnitComponent {
fn evaluate_lookup_boundary_constraints_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
interaction_elements: &InteractionElements,
constraint_zero_domain: Coset,
lookup_values: &LookupValues,
) {
let z = interaction_elements[RC_Z];
let value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0]));
let numerator = value * (z - mask[BASE_TRACE][0][0]) - mask[BASE_TRACE][1][0];
let denom = point_vanishing(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);

let lookup_value = SecureField::from_m31(
lookup_values[RC_LOOKUP_VALUE_0],
lookup_values[RC_LOOKUP_VALUE_1],
lookup_values[RC_LOOKUP_VALUE_2],
lookup_values[RC_LOOKUP_VALUE_3],
);
let numerator = value - lookup_value;
let denom = point_vanishing(constraint_zero_domain.at(1), point);
evaluation_accumulator.accumulate(numerator / denom);
}

fn evaluate_lookup_step_constraints_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
constraint_zero_domain: Coset,
interaction_elements: &InteractionElements,
) {
let z = interaction_elements[RC_Z];
let value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0]));
let prev_value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][1]));
let numerator =
(value - prev_value) * (z - mask[BASE_TRACE][0][0]) - mask[BASE_TRACE][1][0];
let denom = coset_vanishing(constraint_zero_domain, point)
/ point_excluder(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);
}
}

impl Component for RangeCheckUnitComponent {
fn n_constraints(&self) -> usize {
3
Expand Down Expand Up @@ -120,22 +71,34 @@ impl Component for RangeCheckUnitComponent {
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) {
// First lookup point boundary constraint.
let constraint_zero_domain = CanonicCoset::new(self.log_n_instances).coset;
self.evaluate_lookup_boundary_constraints_at_point(
point,
mask,
evaluation_accumulator,
interaction_elements,
constraint_zero_domain,
lookup_values,
);
self.evaluate_lookup_step_constraints_at_point(
point,
mask,
evaluation_accumulator,
constraint_zero_domain,
interaction_elements,
let z = interaction_elements[RC_Z];
let value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0]));
let numerator = value * (z - mask[BASE_TRACE][0][0]) - mask[BASE_TRACE][1][0];
let denom = point_vanishing(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);

// Last lookup point boundary constraint.
let lookup_value = SecureField::from_m31(
lookup_values[RC_LOOKUP_VALUE_0],
lookup_values[RC_LOOKUP_VALUE_1],
lookup_values[RC_LOOKUP_VALUE_2],
lookup_values[RC_LOOKUP_VALUE_3],
);
let numerator = value - lookup_value;
let denom = point_vanishing(constraint_zero_domain.at(1), point);
evaluation_accumulator.accumulate(numerator / denom);

// Lookup step constraint.
let prev_value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][1]));
let numerator =
(value - prev_value) * (z - mask[BASE_TRACE][0][0]) - mask[BASE_TRACE][1][0];
let denom = coset_vanishing(constraint_zero_domain, point)
/ point_excluder(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);
}
}

Expand Down
145 changes: 55 additions & 90 deletions stwo_cairo_prover/src/components/range_check_unit/component_prover.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
use std::collections::BTreeMap;

use itertools::zip_eq;
use itertools::izip;
use num_traits::Zero;
use stwo_prover::core::air::accumulation::{ColumnAccumulator, DomainEvaluationAccumulator};
use stwo_prover::core::air::accumulation::DomainEvaluationAccumulator;
use stwo_prover::core::air::{Component, ComponentProver, ComponentTrace};
use stwo_prover::core::backend::CpuBackend;
use stwo_prover::core::circle::Coset;
use stwo_prover::core::constraints::{coset_vanishing, point_excluder};
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fields::FieldExpOps;
use stwo_prover::core::pcs::TreeVec;
use stwo_prover::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::poly::circle::CanonicCoset;
use stwo_prover::core::prover::{BASE_TRACE, INTERACTION_TRACE};
use stwo_prover::core::utils::{
bit_reverse, point_vanish_denominator_inverses, previous_bit_reversed_circle_domain_index,
Expand All @@ -39,21 +36,60 @@ impl ComponentProver<CpuBackend> for RangeCheckUnitComponent {
let [mut accum] =
evaluation_accumulator.columns([(max_constraint_degree, self.n_constraints())]);

evaluate_lookup_boundary_constraints(
trace_evals,
trace_eval_domain,
zero_domain,
&mut accum,
interaction_elements,
lookup_values,
// TODO(AlonH): Get all denominators in one loop and don't perform unnecessary inversions.
let first_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0));
let last_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(1));
let mut step_denoms = vec![];
for point in trace_eval_domain.iter() {
step_denoms.push(
coset_vanishing(zero_domain, point) / point_excluder(zero_domain.at(0), point),
);
}
bit_reverse(&mut step_denoms);
let mut step_denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
BaseField::batch_inverse(&step_denoms, &mut step_denom_inverses);
let z = interaction_elements[RC_Z];
let lookup_value = SecureField::from_m31(
lookup_values[RC_LOOKUP_VALUE_0],
lookup_values[RC_LOOKUP_VALUE_1],
lookup_values[RC_LOOKUP_VALUE_2],
lookup_values[RC_LOOKUP_VALUE_3],
);
evaluate_lookup_step_constraints(
trace_evals,
trace_eval_domain,
zero_domain,
&mut accum,
interaction_elements,
for (i, (first_point_denom_inverse, last_point_denom_inverse, step_denom_inverse)) in izip!(
first_point_denom_inverses,
last_point_denom_inverses,
step_denom_inverses,
)
.enumerate()
{
let value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][i]
}));
let prev_index = previous_bit_reversed_circle_domain_index(
i,
zero_domain.log_size,
trace_eval_domain.log_size(),
);
let prev_value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][prev_index]
}));

let first_point_numerator = accum.random_coeff_powers[2]
* (value * (z - trace_evals[BASE_TRACE][0][i]) - trace_evals[BASE_TRACE][1][i]);

let last_point_numerator = accum.random_coeff_powers[1] * (value - lookup_value);
let step_numerator = accum.random_coeff_powers[0]
* ((value - prev_value) * (z - trace_evals[BASE_TRACE][0][i])
- trace_evals[BASE_TRACE][1][i]);
accum.accumulate(
i,
first_point_numerator * first_point_denom_inverse
+ last_point_numerator * last_point_denom_inverse
+ step_numerator * step_denom_inverse,
);
}
}

fn lookup_values(&self, trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues {
Expand Down Expand Up @@ -92,74 +128,3 @@ impl ComponentProver<CpuBackend> for RangeCheckUnitComponent {
LookupValues::new(values)
}
}

fn evaluate_lookup_boundary_constraints(
trace_evals: &TreeVec<Vec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>>,
trace_eval_domain: CircleDomain,
zero_domain: Coset,
accum: &mut ColumnAccumulator<'_, CpuBackend>,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) {
let first_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0));
let last_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(1));
let z = interaction_elements[RC_Z];
let lookup_value = SecureField::from_m31(
lookup_values[RC_LOOKUP_VALUE_0],
lookup_values[RC_LOOKUP_VALUE_1],
lookup_values[RC_LOOKUP_VALUE_2],
lookup_values[RC_LOOKUP_VALUE_3],
);
for (i, (first_point_denom_inverse, last_point_denom_inverse)) in
zip_eq(first_point_denom_inverses, last_point_denom_inverses).enumerate()
{
let value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][i]
}));
let first_point_numerator = accum.random_coeff_powers[2]
* (value * (z - trace_evals[BASE_TRACE][0][i]) - trace_evals[BASE_TRACE][1][i]);
let last_point_numerator = accum.random_coeff_powers[1] * (value - lookup_value);
accum.accumulate(
i,
first_point_numerator * first_point_denom_inverse
+ last_point_numerator * last_point_denom_inverse,
);
}
}

fn evaluate_lookup_step_constraints(
trace_evals: &TreeVec<Vec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>>,
trace_eval_domain: CircleDomain,
zero_domain: Coset,
accum: &mut ColumnAccumulator<'_, CpuBackend>,
interaction_elements: &InteractionElements,
) {
let mut denoms = vec![];
for point in trace_eval_domain.iter() {
denoms.push(coset_vanishing(zero_domain, point) / point_excluder(zero_domain.at(0), point));
}
bit_reverse(&mut denoms);
let mut denom_inverses = vec![BaseField::zero(); denoms.len()];
BaseField::batch_inverse(&denoms, &mut denom_inverses);
let z = interaction_elements[RC_Z];

for (i, denom_inverse) in denom_inverses.iter().enumerate() {
let value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][i]
}));
let prev_index = previous_bit_reversed_circle_domain_index(
i,
zero_domain.log_size,
trace_eval_domain.log_size(),
);
let prev_value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][prev_index]
}));
let numerator = accum.random_coeff_powers[0]
* ((value - prev_value) * (z - trace_evals[BASE_TRACE][0][i])
- trace_evals[BASE_TRACE][1][i]);
accum.accumulate(i, numerator * *denom_inverse);
}
}

0 comments on commit 9f3cbe9

Please sign in to comment.