Skip to content

Commit

Permalink
Added range check to memory. (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti authored Sep 8, 2024
2 parents 5ebb0ca + 70a2955 commit c390067
Show file tree
Hide file tree
Showing 11 changed files with 249 additions and 31 deletions.
51 changes: 47 additions & 4 deletions stwo_cairo_prover/crates/prover/src/cairo_air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ use crate::components::range_check_builtin::component::{
RangeCheckBuiltinInteractionClaim,
};
use crate::components::range_check_builtin::prover::RangeCheckBuiltinClaimProver;
use crate::components::range_check_unit::component::{
RangeCheckClaim, RangeCheckInteractionClaim, RangeCheckUnitComponent, RangeCheckUnitEval,
};
use crate::components::range_check_unit::component_prover::RangeCheckClaimProver;
use crate::components::range_check_unit::RangeCheckElements;
use crate::components::ret_opcode::component::{
RetOpcodeClaim, RetOpcodeComponent, RetOpcodeEval, RetOpcodeInteractionClaim,
};
Expand All @@ -34,6 +39,11 @@ use crate::felt::split_f252;
use crate::input::instructions::VmState;
use crate::input::CairoInput;

const RC9_LOG_MAX: u32 = 9;
const RC9_LOG_REPS: u32 = 1;
const RC9_LOG_HEIGHT: u32 = RC9_LOG_MAX - RC9_LOG_REPS;
const RC9_REPS: usize = 1 << RC9_LOG_REPS;

pub struct CairoProof<H: MerkleHasher> {
pub claim: CairoClaim,
pub interaction_claim: CairoInteractionClaim,
Expand All @@ -49,6 +59,7 @@ pub struct CairoClaim {
pub ret: Vec<RetOpcodeClaim>,
pub range_check_builtin: RangeCheckBuiltinClaim,
pub memory: MemoryClaim,
pub range_check9: RangeCheckClaim<RC9_REPS>,
// ...
}

Expand All @@ -65,18 +76,21 @@ impl CairoClaim {
self.ret.iter().map(|c| c.log_sizes()),
[self.range_check_builtin.log_sizes()],
[self.memory.log_sizes()],
[self.range_check9.log_sizes()]
))
}
}

pub struct CairoInteractionElements {
memory_lookup: MemoryLookupElements,
range9_lookup: RangeCheckElements,
// ...
}
impl CairoInteractionElements {
pub fn draw(channel: &mut impl Channel) -> CairoInteractionElements {
CairoInteractionElements {
memory_lookup: MemoryLookupElements::draw(channel),
range9_lookup: RangeCheckElements::draw(channel),
}
}
}
Expand All @@ -85,6 +99,7 @@ pub struct CairoInteractionClaim {
pub ret: Vec<RetOpcodeInteractionClaim>,
pub range_check_builtin: RangeCheckBuiltinInteractionClaim,
pub memory: MemoryInteractionClaim,
pub range_check9: RangeCheckInteractionClaim<RC9_REPS>,
// ...
}

Expand Down Expand Up @@ -121,6 +136,7 @@ pub fn lookup_sum_valid(
})
.sum::<SecureField>();
// TODO: include initial and final state.
sum += interaction_claim.range_check9.claimed_sum;
sum += interaction_claim.ret[0].claimed_sum;
sum += interaction_claim.range_check_builtin.claimed_sum;
sum += interaction_claim.memory.claimed_sum;
Expand All @@ -131,6 +147,7 @@ pub struct CairoComponents {
ret: Vec<RetOpcodeComponent>,
range_check_builtin: RangeCheckBuiltinComponent,
memory: MemoryComponent,
range_check9: RangeCheckUnitComponent<RC9_REPS>,
// ...
}

Expand Down Expand Up @@ -170,14 +187,23 @@ impl CairoComponents {
MemoryEval::new(
cairo_claim.memory.clone(),
interaction_elements.memory_lookup.clone(),
interaction_elements.range9_lookup.clone(),
interaction_claim.memory.clone(),
),
);

let range_check9_component = RangeCheckUnitComponent::new(
tree_span_provider,
RangeCheckUnitEval {
log_n_rows: RC9_LOG_HEIGHT,
lookup_elements: interaction_elements.range9_lookup.clone(),
claimed_sum: interaction_claim.range_check9.claimed_sum,
},
);
Self {
ret: ret_components,
range_check_builtin: range_check_builtin_component,
memory: memory_component,
range_check9: range_check9_component,
}
}

Expand All @@ -188,6 +214,7 @@ impl CairoComponents {
}
vec.push(&self.range_check_builtin);
vec.push(&self.memory);
vec.push(&self.range_check9);
vec
}

Expand All @@ -198,6 +225,7 @@ impl CairoComponents {
}
vec.push(&self.range_check_builtin);
vec.push(&self.memory);
vec.push(&self.range_check9);
vec
}
}
Expand Down Expand Up @@ -231,20 +259,28 @@ pub fn prove_cairo(input: CairoInput) -> CairoProof<Blake2sMerkleHasher> {
let range_check_builtin_trace_generator =
RangeCheckBuiltinClaimProver::new(input.range_check_builtin);
let mut memory_trace_generator = MemoryClaimProver::new(input.mem);
let mut range_check9_trace_generator = RangeCheckClaimProver::<RC9_LOG_HEIGHT, RC9_REPS> {
multiplicities: input
.range_check9
.to_2d_simd_vec::<RC9_LOG_HEIGHT, RC9_REPS>(),
};

// Add public memory.
for addr in &input.public_mem_addresses {
memory_trace_generator.add_inputs(M31::from_u32_unchecked(*addr));
}

let mut tree_builder = commitment_scheme.tree_builder();

let (ret_claim, ret_interaction_prover) =
ret_trace_generator.write_trace(&mut tree_builder, &mut memory_trace_generator);
let (range_check_builtin_claim, range_check_builtin_interaction_prover) =
range_check_builtin_trace_generator
.write_trace(&mut tree_builder, &mut memory_trace_generator);
let (memory_claim, memory_interaction_prover) =
memory_trace_generator.write_trace(&mut tree_builder);
let (range_check9_claim, range_check9_interaction_prover) =
range_check9_trace_generator.write_trace(&mut tree_builder);

// Commit to the claim and the trace.
let claim = CairoClaim {
Expand All @@ -254,6 +290,7 @@ pub fn prove_cairo(input: CairoInput) -> CairoProof<Blake2sMerkleHasher> {
ret: vec![ret_claim],
range_check_builtin: range_check_builtin_claim.clone(),
memory: memory_claim.clone(),
range_check9: range_check9_claim.clone(),
};
claim.mix_into(channel);
tree_builder.commit(channel);
Expand All @@ -267,14 +304,21 @@ pub fn prove_cairo(input: CairoInput) -> CairoProof<Blake2sMerkleHasher> {
.write_interaction_trace(&mut tree_builder, &interaction_elements.memory_lookup);
let range_check_builtin_interaction_claim = range_check_builtin_interaction_prover
.write_interaction_trace(&mut tree_builder, &interaction_elements.memory_lookup);
let memory_interaction_claim = memory_interaction_prover
.write_interaction_trace(&mut tree_builder, &interaction_elements.memory_lookup);
let memory_interaction_claim = memory_interaction_prover.write_interaction_trace(
&mut tree_builder,
&interaction_elements.memory_lookup,
&interaction_elements.range9_lookup,
);

let range_check9_interaction_claim = range_check9_interaction_prover
.write_interaction_trace(&mut tree_builder, &interaction_elements.range9_lookup);

// Commit to the interaction claim and the interaction trace.
let interaction_claim = CairoInteractionClaim {
ret: vec![ret_interaction_claim.clone()],
range_check_builtin: range_check_builtin_interaction_claim.clone(),
memory: memory_interaction_claim.clone(),
range_check9: range_check9_interaction_claim.clone(),
};
debug_assert!(lookup_sum_valid(
&claim,
Expand Down Expand Up @@ -375,7 +419,6 @@ mod tests {
#[test]
fn test_cairo_air() {
let cairo_proof = prove_cairo(test_input());

verify_cairo(cairo_proof).unwrap();
}
}
22 changes: 20 additions & 2 deletions stwo_cairo_prover/crates/prover/src/components/memory/component.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use num_traits::One;
use stwo_prover::constraint_framework::logup::LogupAtRow;
use stwo_prover::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval};
use stwo_prover::core::channel::Channel;
Expand All @@ -6,13 +7,16 @@ use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use stwo_prover::core::pcs::TreeVec;

use super::MemoryLookupElements;
use crate::components::range_check_unit::RangeCheckElements;

pub const N_M31_IN_FELT252: usize = 28;
pub const MULTIPLICITY_COLUMN_OFFSET: usize = N_M31_IN_FELT252 + 1;
// TODO(AlonH): Make memory size configurable.
pub const N_MEMORY_COLUMNS: usize = N_M31_IN_FELT252 + 2;
pub const LOG_MEMORY_ADDRESS_BOUND: u32 = 20;
pub const MEMORY_ADDRESS_BOUND: usize = 1 << LOG_MEMORY_ADDRESS_BOUND;
pub const MEMORY_ADDRESS_SIZE: usize = 1;

pub type MemoryComponent = FrameworkComponent<MemoryEval>;

/// Addresses are continuous and start from 0.
Expand All @@ -21,6 +25,7 @@ pub type MemoryComponent = FrameworkComponent<MemoryEval>;
pub struct MemoryEval {
pub log_n_rows: u32,
pub lookup_elements: MemoryLookupElements,
pub range9_lookup_elements: RangeCheckElements,
pub claimed_sum: QM31,
}
impl MemoryEval {
Expand All @@ -30,11 +35,13 @@ impl MemoryEval {
pub fn new(
claim: MemoryClaim,
lookup_elements: MemoryLookupElements,
range9_lookup_elements: RangeCheckElements,
interaction_claim: MemoryInteractionClaim,
) -> Self {
Self {
log_n_rows: claim.log_address_bound,
lookup_elements,
range9_lookup_elements,
claimed_sum: interaction_claim.claimed_sum,
}
}
Expand All @@ -61,7 +68,17 @@ impl FrameworkEval for MemoryEval {
&address_and_value,
&self.lookup_elements,
);
// TODO(Ohad): add range check lookup constraint.

// Range check elements.
for value_limb in address_and_value.iter().skip(MEMORY_ADDRESS_SIZE) {
logup.push_lookup(
&mut eval,
E::EF::one(),
&[*value_limb],
&self.range9_lookup_elements,
);
}

logup.finalize(&mut eval);

eval
Expand All @@ -75,7 +92,8 @@ pub struct MemoryClaim {
impl MemoryClaim {
pub fn log_sizes(&self) -> TreeVec<Vec<u32>> {
let interaction_0_log_size = vec![self.log_address_bound; N_M31_IN_FELT252 + 2];
let interaction_1_log_size = vec![self.log_address_bound; SECURE_EXTENSION_DEGREE];
let interaction_1_log_size =
vec![self.log_address_bound; SECURE_EXTENSION_DEGREE * (N_M31_IN_FELT252 + 1)];
TreeVec::new(vec![interaction_0_log_size, interaction_1_log_size])
}

Expand Down
23 changes: 20 additions & 3 deletions stwo_cairo_prover/crates/prover/src/components/memory/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;

use super::component::{
MemoryClaim, MemoryInteractionClaim, MEMORY_ADDRESS_BOUND, MULTIPLICITY_COLUMN_OFFSET,
N_M31_IN_FELT252,
MemoryClaim, MemoryInteractionClaim, MEMORY_ADDRESS_BOUND, MEMORY_ADDRESS_SIZE,
MULTIPLICITY_COLUMN_OFFSET, N_M31_IN_FELT252,
};
use super::MemoryLookupElements;
use crate::components::range_check_unit::RangeCheckElements;
use crate::components::MIN_SIMD_TRACE_LENGTH;
use crate::felt::split_f252_simd;
use crate::input::mem::{Memory, MemoryValue};
Expand Down Expand Up @@ -141,6 +142,7 @@ impl InteractionClaimProver {
&self,
tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>,
lookup_elements: &MemoryLookupElements,
range9_lookup_elements: &RangeCheckElements,
) -> MemoryInteractionClaim {
let log_size = self.addresses_and_values[0].len().ilog2() + LOG_N_LANES;
let mut logup_gen = LogupTraceGenerator::new(log_size);
Expand All @@ -154,6 +156,19 @@ impl InteractionClaimProver {
col_gen.write_frac(vec_row, (-self.multiplicities[vec_row]).into(), denom);
}
col_gen.finalize_col();

for value_col in self.addresses_and_values.iter().skip(MEMORY_ADDRESS_SIZE) {
let mut col_gen = logup_gen.new_col();
for (vec_row, value) in value_col.iter().enumerate() {
// TOOD(alont) Add 2-batching.
col_gen.write_frac(
vec_row,
PackedQM31::broadcast(M31(1).into()),
range9_lookup_elements.combine(&[PackedQM31::from(*value)]),
);
}
col_gen.finalize_col();
}
let (trace, claimed_sum) = logup_gen.finalize();
tree_builder.extend_evals(trace);

Expand All @@ -170,6 +185,7 @@ mod tests {

use crate::components::memory::component::N_M31_IN_FELT252;
use crate::input::mem::{MemConfig, MemoryBuilder};
use crate::input::range_check_unit::RangeCheckUnitInput;

#[test]
fn test_deduce_output_simd() {
Expand All @@ -181,7 +197,8 @@ mod tests {
.map(|v| std::array::from_fn(|i| if i == 0 { v } else { M31::zero() }));

// Create memory.
let mut mem = MemoryBuilder::new(MemConfig::default());
let mut range_check9 = RangeCheckUnitInput::new();
let mut mem = MemoryBuilder::new(MemConfig::default(), &mut range_check9);
for a in &addr {
let arr = std::array::from_fn(|i| if i == 0 { *a } else { 0 });
mem.set(*a as u64, mem.value_from_felt252(arr));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use stwo_prover::constraint_framework::logup::LogupAtRow;
use stwo_prover::constraint_framework::{EvalAtRow, FrameworkEval};
use stwo_prover::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval};
use stwo_prover::core::channel::Channel;
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::qm31::SecureField;
Expand All @@ -8,14 +8,17 @@ use stwo_prover::core::pcs::TreeVec;

use super::RangeCheckElements;

pub type RangeCheckUnitComponent<const N_REPETITIONS: usize> =
FrameworkComponent<RangeCheckUnitEval<N_REPETITIONS>>;

#[derive(Clone)]
pub struct RangeCheckUnitComponent<const N_REPETITIONS: usize> {
pub struct RangeCheckUnitEval<const N_REPETITIONS: usize> {
pub log_n_rows: u32,
pub lookup_elements: RangeCheckElements,
pub claimed_sum: SecureField,
}

impl<const N_REPETITIONS: usize> FrameworkEval for RangeCheckUnitComponent<N_REPETITIONS> {
impl<const N_REPETITIONS: usize> FrameworkEval for RangeCheckUnitEval<N_REPETITIONS> {
fn log_size(&self) -> u32 {
self.log_n_rows
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ impl<const RC_LOG_HEIGHT: u32, const N_REPETITIONS: usize>
mod tests {
use itertools::Itertools;
use rand::Rng;
use stwo_prover::constraint_framework::FrameworkEval;
use stwo_prover::constraint_framework::{FrameworkEval, TraceLocationAllocator};
use stwo_prover::core::backend::simd::column::BaseColumn;
use stwo_prover::core::backend::simd::SimdBackend;
use stwo_prover::core::backend::Column;
Expand All @@ -186,7 +186,9 @@ mod tests {
use stwo_prover::core::poly::circle::{CanonicCoset, PolyOps};

use super::RangeCheckClaimProver;
use crate::components::range_check_unit::component::RangeCheckUnitComponent;
use crate::components::range_check_unit::component::{
RangeCheckUnitComponent, RangeCheckUnitEval,
};
use crate::components::range_check_unit::RangeCheckElements;

#[test]
Expand Down Expand Up @@ -253,11 +255,15 @@ mod tests {
interaction_claim_prover.write_interaction_trace(&mut tree_builder, &lookup_elements);
tree_builder.commit(channel);

let component = RangeCheckUnitComponent::<{ N_REPS as usize }> {
log_n_rows: LOG_HEIGHT,
lookup_elements,
claimed_sum: interaction_claim.claimed_sum,
};
let tree_span_provider = &mut TraceLocationAllocator::default();
let component = RangeCheckUnitComponent::new(
tree_span_provider,
RangeCheckUnitEval::<{ N_REPS as usize }> {
log_n_rows: LOG_HEIGHT,
lookup_elements,
claimed_sum: interaction_claim.claimed_sum,
},
);

let trace_polys = commitment_scheme
.trees
Expand Down
Loading

0 comments on commit c390067

Please sign in to comment.