Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Evaluate constraints together. #18

Merged
merged 1 commit into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 27 additions & 64 deletions stwo_cairo_prover/src/components/range_check_unit/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use stwo_prover::core::air::accumulation::PointEvaluationAccumulator;
use stwo_prover::core::air::mask::fixed_mask_points;
use stwo_prover::core::air::Component;
use stwo_prover::core::backend::CpuBackend;
use stwo_prover::core::circle::{CirclePoint, Coset};
use stwo_prover::core::circle::CirclePoint;
use stwo_prover::core::constraints::{coset_vanishing, point_excluder, point_vanishing};
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::qm31::SecureField;
Expand Down Expand Up @@ -37,55 +37,6 @@ pub struct RangeCheckUnitTraceGenerator {
pub multiplicities: Vec<u32>,
}

impl RangeCheckUnitComponent {
fn evaluate_lookup_boundary_constraints_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
interaction_elements: &InteractionElements,
constraint_zero_domain: Coset,
lookup_values: &LookupValues,
) {
let z = interaction_elements[RC_Z];
let value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0]));
let numerator = value * (z - mask[BASE_TRACE][0][0]) - mask[BASE_TRACE][1][0];
let denom = point_vanishing(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);

let lookup_value = SecureField::from_m31(
lookup_values[RC_LOOKUP_VALUE_0],
lookup_values[RC_LOOKUP_VALUE_1],
lookup_values[RC_LOOKUP_VALUE_2],
lookup_values[RC_LOOKUP_VALUE_3],
);
let numerator = value - lookup_value;
let denom = point_vanishing(constraint_zero_domain.at(1), point);
evaluation_accumulator.accumulate(numerator / denom);
}

fn evaluate_lookup_step_constraints_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
constraint_zero_domain: Coset,
interaction_elements: &InteractionElements,
) {
let z = interaction_elements[RC_Z];
let value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0]));
let prev_value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][1]));
let numerator =
(value - prev_value) * (z - mask[BASE_TRACE][0][0]) - mask[BASE_TRACE][1][0];
let denom = coset_vanishing(constraint_zero_domain, point)
/ point_excluder(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);
}
}

impl Component for RangeCheckUnitComponent {
fn n_constraints(&self) -> usize {
3
Expand Down Expand Up @@ -121,22 +72,34 @@ impl Component for RangeCheckUnitComponent {
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) {
// First lookup point boundary constraint.
let constraint_zero_domain = CanonicCoset::new(self.log_n_instances).coset;
self.evaluate_lookup_boundary_constraints_at_point(
point,
mask,
evaluation_accumulator,
interaction_elements,
constraint_zero_domain,
lookup_values,
);
self.evaluate_lookup_step_constraints_at_point(
point,
mask,
evaluation_accumulator,
constraint_zero_domain,
interaction_elements,
let z = interaction_elements[RC_Z];
let value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0]));
let numerator = value * (z - mask[BASE_TRACE][0][0]) - mask[BASE_TRACE][1][0];
let denom = point_vanishing(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);

// Last lookup point boundary constraint.
let lookup_value = SecureField::from_m31(
lookup_values[RC_LOOKUP_VALUE_0],
lookup_values[RC_LOOKUP_VALUE_1],
lookup_values[RC_LOOKUP_VALUE_2],
lookup_values[RC_LOOKUP_VALUE_3],
);
let numerator = value - lookup_value;
let denom = point_vanishing(constraint_zero_domain.at(1), point);
evaluation_accumulator.accumulate(numerator / denom);

// Lookup step constraint.
let prev_value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][1]));
let numerator =
(value - prev_value) * (z - mask[BASE_TRACE][0][0]) - mask[BASE_TRACE][1][0];
let denom = coset_vanishing(constraint_zero_domain, point)
/ point_excluder(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);
}
}

Expand Down
145 changes: 55 additions & 90 deletions stwo_cairo_prover/src/components/range_check_unit/component_prover.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
use std::collections::BTreeMap;

use itertools::zip_eq;
use itertools::izip;
use num_traits::Zero;
use stwo_prover::core::air::accumulation::{ColumnAccumulator, DomainEvaluationAccumulator};
use stwo_prover::core::air::accumulation::DomainEvaluationAccumulator;
use stwo_prover::core::air::{Component, ComponentProver, ComponentTrace};
use stwo_prover::core::backend::CpuBackend;
use stwo_prover::core::circle::Coset;
use stwo_prover::core::constraints::{coset_vanishing, point_excluder};
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fields::FieldExpOps;
use stwo_prover::core::pcs::TreeVec;
use stwo_prover::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::poly::circle::CanonicCoset;
use stwo_prover::core::utils::{
bit_reverse, point_vanish_denominator_inverses, previous_bit_reversed_circle_domain_index,
};
Expand All @@ -39,21 +36,60 @@ impl ComponentProver<CpuBackend> for RangeCheckUnitComponent {
let [mut accum] =
evaluation_accumulator.columns([(max_constraint_degree, self.n_constraints())]);

evaluate_lookup_boundary_constraints(
trace_evals,
trace_eval_domain,
zero_domain,
&mut accum,
interaction_elements,
lookup_values,
// TODO(AlonH): Get all denominators in one loop and don't perform unnecessary inversions.
let first_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0));
let last_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(1));
let mut step_denoms = vec![];
for point in trace_eval_domain.iter() {
step_denoms.push(
coset_vanishing(zero_domain, point) / point_excluder(zero_domain.at(0), point),
);
}
bit_reverse(&mut step_denoms);
let mut step_denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
BaseField::batch_inverse(&step_denoms, &mut step_denom_inverses);
let z = interaction_elements[RC_Z];
let lookup_value = SecureField::from_m31(
lookup_values[RC_LOOKUP_VALUE_0],
lookup_values[RC_LOOKUP_VALUE_1],
lookup_values[RC_LOOKUP_VALUE_2],
lookup_values[RC_LOOKUP_VALUE_3],
);
evaluate_lookup_step_constraints(
trace_evals,
trace_eval_domain,
zero_domain,
&mut accum,
interaction_elements,
for (i, (first_point_denom_inverse, last_point_denom_inverse, step_denom_inverse)) in izip!(
first_point_denom_inverses,
last_point_denom_inverses,
step_denom_inverses,
)
.enumerate()
{
let value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][i]
}));
let prev_index = previous_bit_reversed_circle_domain_index(
i,
zero_domain.log_size,
trace_eval_domain.log_size(),
);
let prev_value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][prev_index]
}));

let first_point_numerator = accum.random_coeff_powers[2]
* (value * (z - trace_evals[BASE_TRACE][0][i]) - trace_evals[BASE_TRACE][1][i]);

let last_point_numerator = accum.random_coeff_powers[1] * (value - lookup_value);
let step_numerator = accum.random_coeff_powers[0]
* ((value - prev_value) * (z - trace_evals[BASE_TRACE][0][i])
- trace_evals[BASE_TRACE][1][i]);
accum.accumulate(
i,
first_point_numerator * first_point_denom_inverse
+ last_point_numerator * last_point_denom_inverse
+ step_numerator * step_denom_inverse,
);
}
}

fn lookup_values(&self, trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues {
Expand Down Expand Up @@ -92,74 +128,3 @@ impl ComponentProver<CpuBackend> for RangeCheckUnitComponent {
LookupValues::new(values)
}
}

fn evaluate_lookup_boundary_constraints(
trace_evals: &TreeVec<Vec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>>,
trace_eval_domain: CircleDomain,
zero_domain: Coset,
accum: &mut ColumnAccumulator<'_, CpuBackend>,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) {
let first_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0));
let last_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(1));
let z = interaction_elements[RC_Z];
let lookup_value = SecureField::from_m31(
lookup_values[RC_LOOKUP_VALUE_0],
lookup_values[RC_LOOKUP_VALUE_1],
lookup_values[RC_LOOKUP_VALUE_2],
lookup_values[RC_LOOKUP_VALUE_3],
);
for (i, (first_point_denom_inverse, last_point_denom_inverse)) in
zip_eq(first_point_denom_inverses, last_point_denom_inverses).enumerate()
{
let value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][i]
}));
let first_point_numerator = accum.random_coeff_powers[2]
* (value * (z - trace_evals[BASE_TRACE][0][i]) - trace_evals[BASE_TRACE][1][i]);
let last_point_numerator = accum.random_coeff_powers[1] * (value - lookup_value);
accum.accumulate(
i,
first_point_numerator * first_point_denom_inverse
+ last_point_numerator * last_point_denom_inverse,
);
}
}

fn evaluate_lookup_step_constraints(
trace_evals: &TreeVec<Vec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>>,
trace_eval_domain: CircleDomain,
zero_domain: Coset,
accum: &mut ColumnAccumulator<'_, CpuBackend>,
interaction_elements: &InteractionElements,
) {
let mut denoms = vec![];
for point in trace_eval_domain.iter() {
denoms.push(coset_vanishing(zero_domain, point) / point_excluder(zero_domain.at(0), point));
}
bit_reverse(&mut denoms);
let mut denom_inverses = vec![BaseField::zero(); denoms.len()];
BaseField::batch_inverse(&denoms, &mut denom_inverses);
let z = interaction_elements[RC_Z];

for (i, denom_inverse) in denom_inverses.iter().enumerate() {
let value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][i]
}));
let prev_index = previous_bit_reversed_circle_domain_index(
i,
zero_domain.log_size,
trace_eval_domain.log_size(),
);
let prev_value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][prev_index]
}));
let numerator = accum.random_coeff_powers[0]
* ((value - prev_value) * (z - trace_evals[BASE_TRACE][0][i])
- trace_evals[BASE_TRACE][1][i]);
accum.accumulate(i, numerator * *denom_inverse);
}
}
Loading