Skip to content

Commit

Permalink
Add lookup step constraint. (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 authored Jul 15, 2024
2 parents fa46cdc + 64e046d commit a6a5aa0
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 44 deletions.
10 changes: 5 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion stwo_cairo_prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ edition = "2021"
itertools = "0.12.0"
num-traits = "0.2.17"
# TODO(ShaharS): take stwo version from the source repository.
stwo-prover = { git = "https://github.com/starkware-libs/stwo", rev = "2501444" }
stwo-prover = { git = "https://github.com/starkware-libs/stwo", rev = "7a0bddee" }
70 changes: 48 additions & 22 deletions stwo_cairo_prover/src/components/range_check_unit/component.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
use itertools::{zip_eq, Itertools};
use itertools::Itertools;
use num_traits::Zero;
use stwo_prover::core::air::accumulation::PointEvaluationAccumulator;
use stwo_prover::core::air::mask::fixed_mask_points;
use stwo_prover::core::air::Component;
use stwo_prover::core::backend::CpuBackend;
use stwo_prover::core::circle::{CirclePoint, Coset};
use stwo_prover::core::constraints::point_vanishing;
use stwo_prover::core::constraints::{coset_vanishing, point_excluder, point_vanishing};
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fields::secure_column::{SecureColumn, SECURE_EXTENSION_DEGREE};
use stwo_prover::core::fields::FieldExpOps;
use stwo_prover::core::pcs::TreeVec;
use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::prover::{BASE_TRACE, INTERACTION_TRACE};
use stwo_prover::core::utils::{bit_reverse_index, coset_order_to_circle_domain_order_index};
use stwo_prover::core::{ColumnVec, InteractionElements, LookupValues};
use stwo_prover::trace_generation::registry::ComponentGenerationRegistry;
use stwo_prover::trace_generation::{ComponentGen, ComponentTraceGenerator};
use stwo_prover::trace_generation::{
ComponentGen, ComponentTraceGenerator, BASE_TRACE, INTERACTION_TRACE,
};

pub const RC_Z: &str = "RangeCheckUnit_Z";
pub const RC_COMPONENT_ID: &str = "RC_UNIT";
Expand Down Expand Up @@ -59,17 +61,34 @@ impl RangeCheckUnitComponent {
lookup_values[RC_LOOKUP_VALUE_3],
);
let numerator = value - lookup_value;
let denom = point_vanishing(
constraint_zero_domain.at(constraint_zero_domain.size() - 1),
point,
);
let denom = point_vanishing(constraint_zero_domain.at(1), point);
evaluation_accumulator.accumulate(numerator / denom);
}

fn evaluate_lookup_step_constraints_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
constraint_zero_domain: Coset,
interaction_elements: &InteractionElements,
) {
let z = interaction_elements[RC_Z];
let value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0]));
let prev_value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][1]));
let numerator =
(value - prev_value) * (z - mask[BASE_TRACE][0][0]) - mask[BASE_TRACE][1][0];
let denom = coset_vanishing(constraint_zero_domain, point)
/ point_excluder(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);
}
}

impl Component for RangeCheckUnitComponent {
fn n_constraints(&self) -> usize {
2
3
}

fn max_constraint_log_degree_bound(&self) -> u32 {
Expand All @@ -87,9 +106,10 @@ impl Component for RangeCheckUnitComponent {
&self,
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
let domain = CanonicCoset::new(self.log_n_instances);
TreeVec::new(vec![
fixed_mask_points(&vec![vec![0_usize]; 2], point),
vec![vec![point]; SECURE_EXTENSION_DEGREE],
vec![vec![point, point - domain.step().into_ef()]; SECURE_EXTENSION_DEGREE],
])
}

Expand All @@ -110,6 +130,13 @@ impl Component for RangeCheckUnitComponent {
constraint_zero_domain,
lookup_values,
);
self.evaluate_lookup_step_constraints_at_point(
point,
mask,
evaluation_accumulator,
constraint_zero_domain,
interaction_elements,
);
}
}

Expand Down Expand Up @@ -171,15 +198,15 @@ impl ComponentTraceGenerator<CpuBackend> for RangeCheckUnitTraceGenerator {
let denoms = trace[0].values.iter().map(|value| z - *value).collect_vec();
let mut denom_inverses = vec![SecureField::zero(); denoms.len()];
SecureField::batch_inverse(&denoms, &mut denom_inverses);
let logup_values = zip_eq(denom_inverses, &trace[1].values).fold(
Vec::new(),
|mut acc, (denom_inverse, multiplicity)| {
let interaction_value = last + (denom_inverse * *multiplicity);
acc.push(interaction_value);
last = interaction_value;
acc
},
);
let mut logup_values = vec![SecureField::zero(); trace[1].values.len()];
let log_size = interaction_trace_domain.log_size();
for i in 0..trace[1].values.len() {
let index = coset_order_to_circle_domain_order_index(i, log_size);
let index = bit_reverse_index(index, log_size);
let interaction_value = last + (denom_inverses[index] * trace[1].values[index]);
logup_values[index] = interaction_value;
last = interaction_value;
}
let secure_column: SecureColumn<CpuBackend> = logup_values.into_iter().collect();
secure_column
.columns
Expand Down Expand Up @@ -218,9 +245,8 @@ mod tests {
trace_sum += trace.last().unwrap().values[i]
/ (random_value - BaseField::from_u32_unchecked(i as u32));
}
let logup_sum = SecureField::from_m31_array(std::array::from_fn(|j| {
*interaction_trace[j].last().unwrap()
}));
let logup_sum =
SecureField::from_m31_array(std::array::from_fn(|j| interaction_trace[j][1]));

assert_eq!(logup_sum, trace_sum);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
use std::collections::BTreeMap;

use itertools::zip_eq;
use num_traits::Zero;
use stwo_prover::core::air::accumulation::{ColumnAccumulator, DomainEvaluationAccumulator};
use stwo_prover::core::air::{Component, ComponentProver, ComponentTrace};
use stwo_prover::core::backend::CpuBackend;
use stwo_prover::core::circle::Coset;
use stwo_prover::core::constraints::{coset_vanishing, point_excluder};
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fields::FieldExpOps;
use stwo_prover::core::pcs::TreeVec;
use stwo_prover::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::prover::{BASE_TRACE, INTERACTION_TRACE};
use stwo_prover::core::utils::point_vanish_denominator_inverses;
use stwo_prover::core::utils::{
bit_reverse, point_vanish_denominator_inverses, previous_bit_reversed_circle_domain_index,
};
use stwo_prover::core::{InteractionElements, LookupValues};
use stwo_prover::trace_generation::{BASE_TRACE, INTERACTION_TRACE};

use super::component::{
RangeCheckUnitComponent, RC_LOOKUP_VALUE_0, RC_LOOKUP_VALUE_1, RC_LOOKUP_VALUE_2,
Expand Down Expand Up @@ -42,6 +47,13 @@ impl ComponentProver<CpuBackend> for RangeCheckUnitComponent {
interaction_elements,
lookup_values,
);
evaluate_lookup_step_constraints(
trace_evals,
trace_eval_domain,
zero_domain,
&mut accum,
interaction_elements,
)
}

fn lookup_values(&self, trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues {
Expand All @@ -51,28 +63,28 @@ impl ComponentProver<CpuBackend> for RangeCheckUnitComponent {
(
RC_LOOKUP_VALUE_0.to_string(),
trace_poly[0]
.eval_at_point(domain.at(domain.size() - 1).into_ef())
.eval_at_point(domain.at(1).into_ef())
.try_into()
.unwrap(),
),
(
RC_LOOKUP_VALUE_1.to_string(),
trace_poly[1]
.eval_at_point(domain.at(domain.size() - 1).into_ef())
.eval_at_point(domain.at(1).into_ef())
.try_into()
.unwrap(),
),
(
RC_LOOKUP_VALUE_2.to_string(),
trace_poly[2]
.eval_at_point(domain.at(domain.size() - 1).into_ef())
.eval_at_point(domain.at(1).into_ef())
.try_into()
.unwrap(),
),
(
RC_LOOKUP_VALUE_3.to_string(),
trace_poly[3]
.eval_at_point(domain.at(domain.size() - 1).into_ef())
.eval_at_point(domain.at(1).into_ef())
.try_into()
.unwrap(),
),
Expand All @@ -91,10 +103,8 @@ fn evaluate_lookup_boundary_constraints(
) {
let first_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0));
let last_point_denom_inverses = point_vanish_denominator_inverses(
trace_eval_domain,
zero_domain.at(zero_domain.size() - 1),
);
let last_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(1));
let z = interaction_elements[RC_Z];
let lookup_value = SecureField::from_m31(
lookup_values[RC_LOOKUP_VALUE_0],
Expand All @@ -108,13 +118,48 @@ fn evaluate_lookup_boundary_constraints(
let value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][i]
}));
let first_point_numerator = accum.random_coeff_powers[1]
let first_point_numerator = accum.random_coeff_powers[2]
* (value * (z - trace_evals[BASE_TRACE][0][i]) - trace_evals[BASE_TRACE][1][i]);
let last_point_numerator = accum.random_coeff_powers[0] * (value - lookup_value);
let last_point_numerator = accum.random_coeff_powers[1] * (value - lookup_value);
accum.accumulate(
i,
first_point_numerator * first_point_denom_inverse
+ last_point_numerator * last_point_denom_inverse,
);
}
}

fn evaluate_lookup_step_constraints(
trace_evals: &TreeVec<Vec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>>,
trace_eval_domain: CircleDomain,
zero_domain: Coset,
accum: &mut ColumnAccumulator<'_, CpuBackend>,
interaction_elements: &InteractionElements,
) {
let mut denoms = vec![];
for point in trace_eval_domain.iter() {
denoms.push(coset_vanishing(zero_domain, point) / point_excluder(zero_domain.at(0), point));
}
bit_reverse(&mut denoms);
let mut denom_inverses = vec![BaseField::zero(); denoms.len()];
BaseField::batch_inverse(&denoms, &mut denom_inverses);
let z = interaction_elements[RC_Z];

for (i, denom_inverse) in denom_inverses.iter().enumerate() {
let value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][i]
}));
let prev_index = previous_bit_reversed_circle_domain_index(
i,
zero_domain.log_size,
trace_eval_domain.log_size(),
);
let prev_value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][prev_index]
}));
let numerator = accum.random_coeff_powers[0]
* ((value - prev_value) * (z - trace_evals[BASE_TRACE][0][i])
- trace_evals[BASE_TRACE][1][i]);
accum.accumulate(i, numerator * *denom_inverse);
}
}
9 changes: 5 additions & 4 deletions stwo_cairo_prover/src/components/range_check_unit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ mod tests {
use stwo_prover::core::fields::IntoSlice;
use stwo_prover::core::poly::circle::CircleEvaluation;
use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::prover::{prove, verify, VerificationError};
use stwo_prover::core::prover::VerificationError;
use stwo_prover::core::vcs::blake2_hash::Blake2sHasher;
use stwo_prover::core::vcs::hasher::Hasher;
use stwo_prover::core::{ColumnVec, InteractionElements, LookupValues};
use stwo_prover::trace_generation::registry::ComponentGenerationRegistry;
use stwo_prover::trace_generation::{
AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator,
commit_and_prove, commit_and_verify, AirTraceGenerator, AirTraceVerifier,
ComponentTraceGenerator,
};

use super::*;
Expand Down Expand Up @@ -135,7 +136,7 @@ mod tests {
let trace = air.write_trace();
let prover_channel =
&mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[])));
let proof = prove::<CpuBackend>(&air, prover_channel, trace).unwrap();
let proof = commit_and_prove::<CpuBackend>(&air, prover_channel, trace).unwrap();

let verifier_channel =
&mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[])));
Expand All @@ -145,6 +146,6 @@ mod tests {
.get_generator::<RangeCheckUnitTraceGenerator>(RC_COMPONENT_ID)
.component(),
};
verify(proof, &air, verifier_channel).unwrap();
commit_and_verify(proof, &air, verifier_channel).unwrap();
}
}

0 comments on commit a6a5aa0

Please sign in to comment.