Skip to content

Commit

Permalink
Update stwo and fix changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed Jul 8, 2024
1 parent 7af2aaf commit 21185a0
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 58 deletions.
46 changes: 23 additions & 23 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 30 additions & 35 deletions stwo_cairo_prover/src/components/range_check_unit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@ use itertools::{zip_eq, Itertools};
use num_traits::Zero;

use stwo_prover::core::air::accumulation::PointEvaluationAccumulator;
use stwo_prover::core::air::{Component, ComponentTraceWriter};
use stwo_prover::core::air::Component;
use stwo_prover::core::backend::CpuBackend;
use stwo_prover::core::circle::CirclePoint;
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fields::secure_column::SecureColumn;
use stwo_prover::core::fields::FieldExpOps;
use stwo_prover::core::pcs::TreeVec;
use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::LookupValues;
use stwo_prover::core::{ColumnVec, InteractionElements};

use stwo_prover::trace_generation::registry::ComponentGenerationRegistry;
use stwo_prover::trace_generation::{ComponentGen, TraceGenerator};
use stwo_prover::trace_generation::ComponentGen;
use stwo_prover::trace_generation::ComponentTraceGenerator;

pub const RC_Z: &str = "RangeCheckUnit_Z";
pub const RC_COMPONENT_ID: &str = "RC_UNIT";
Expand Down Expand Up @@ -55,16 +57,13 @@ impl Component for RangeCheckUnitComponent {
fn evaluate_constraint_quotients_at_point(
&self,
_point: CirclePoint<SecureField>,
_mask: &ColumnVec<Vec<SecureField>>,
_mask: &TreeVec<Vec<Vec<SecureField>>>,
_evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
) {
unimplemented!()
}

fn interaction_element_ids(&self) -> Vec<String> {
unimplemented!()
}
}

impl RangeCheckUnitTraceGenerator {
Expand All @@ -79,7 +78,7 @@ impl RangeCheckUnitTraceGenerator {

impl ComponentGen for RangeCheckUnitTraceGenerator {}

impl TraceGenerator<CpuBackend> for RangeCheckUnitTraceGenerator {
impl ComponentTraceGenerator<CpuBackend> for RangeCheckUnitTraceGenerator {
type Component = RangeCheckUnitComponent;
type Inputs = Vec<BaseField>;

Expand Down Expand Up @@ -113,14 +112,6 @@ impl TraceGenerator<CpuBackend> for RangeCheckUnitTraceGenerator {
.collect_vec()
}

fn component(&self) -> RangeCheckUnitComponent {
RangeCheckUnitComponent {
log_n_instances: self.max_value.checked_ilog2().unwrap() as usize,
}
}
}

impl ComponentTraceWriter<CpuBackend> for RangeCheckUnitTraceGenerator {
fn write_interaction_trace(
&self,
trace: &ColumnVec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
Expand All @@ -129,21 +120,28 @@ impl ComponentTraceWriter<CpuBackend> for RangeCheckUnitTraceGenerator {
let interaction_trace_domain = trace[0].domain;
let z = elements[RC_Z];

let mut last = BaseField::zero();
let mut last = SecureField::zero();
let interaction_values = zip_eq(&trace[0].values, &trace[1].values).fold(
Vec::new(),
|mut acc, (trace_value, multiplicity)| {
let interaction_value = last + (z - *trace_value).inverse() * *multiplicity;
let interaction_value = last + ((z - *trace_value).inverse() * *multiplicity);
acc.push(interaction_value);
last = interaction_value;
acc
},
);
let secure_column: SecureColumn<CpuBackend> = interaction_values.into_iter().collect();
secure_column
.columns
.into_iter()
.map(|eval| CircleEvaluation::new(interaction_trace_domain, eval))
.collect_vec()
}

vec![CircleEvaluation::new(
interaction_trace_domain,
interaction_values,
)]
fn component(&self) -> RangeCheckUnitComponent {
RangeCheckUnitComponent {
log_n_instances: self.max_value.checked_ilog2().unwrap() as usize,
}
}
}

Expand All @@ -154,7 +152,6 @@ mod tests {
#[test]
fn test_rc_unit_trace() {
let mut registry = ComponentGenerationRegistry::default();
let random_seed: usize = 117;
registry.register(RC_COMPONENT_ID, RangeCheckUnitTraceGenerator::new(8));
let inputs = vec![
vec![BaseField::from_u32_unchecked(0); 3],
Expand All @@ -174,25 +171,23 @@ mod tests {
.add_inputs(&inputs);

let trace = RangeCheckUnitTraceGenerator::write_trace(RC_COMPONENT_ID, &mut registry);
let interaction_elements = InteractionElements::new(
[(
RC_Z.to_string(),
BaseField::from_u32_unchecked(random_seed as u32),
)]
.into(),
);
let random_value = SecureField::from_u32_unchecked(1, 2, 3, 117);
let interaction_elements =
InteractionElements::new([(RC_Z.to_string(), random_value)].into());
let interaction_trace = registry
.get_generator::<RangeCheckUnitTraceGenerator>(RC_COMPONENT_ID)
.write_interaction_trace(&trace.iter().collect(), &interaction_elements);

let mut trace_sum = BaseField::zero();
let mut trace_sum = SecureField::zero();
for i in 0..8 {
assert_eq!(trace[0].values[i], BaseField::from_u32_unchecked(i as u32));
trace_sum += trace.last().unwrap().values[i]
/ BaseField::from_u32_unchecked((random_seed - i) as u32);
/ (random_value - BaseField::from_u32_unchecked(i as u32));
}
let logup_sum = interaction_trace.first().unwrap().last().unwrap();
let logup_sum = SecureField::from_m31_array(std::array::from_fn(|j| {
*interaction_trace[j].last().unwrap()
}));

assert_eq!(*logup_sum, trace_sum);
assert_eq!(logup_sum, trace_sum);
}
}

0 comments on commit 21185a0

Please sign in to comment.