Skip to content

Commit

Permalink
mem squish
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Dec 10, 2024
1 parent 9809441 commit a1bbaad
Show file tree
Hide file tree
Showing 4 changed files with 332 additions and 75 deletions.
2 changes: 1 addition & 1 deletion stwo_cairo_prover/crates/prover/src/cairo_air/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ pub fn track_cairo_relations(
RelationTrackerComponent::new(
tree_span_provider,
memory_address_to_id::Eval {
log_n_rows: claim.memory_address_to_id.log_size,
log_size: claim.memory_address_to_id.log_size,
lookup_elements: relations::MemoryAddressToId::dummy(),
},
1 << claim.memory_address_to_id.log_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,69 @@ use stwo_prover::constraint_framework::{
EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry,
};
use stwo_prover::core::channel::Channel;
use stwo_prover::core::fields::m31::M31;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use stwo_prover::core::pcs::TreeVec;

use crate::relations;

pub const N_ADDR_TO_ID_COLUMNS: usize = 3;
// TODO(Ohad): Address should be a preprocessed `seq`.
pub const N_ADDR_COLUMNS: usize = 1;

// Split the (ID , Multiplicity) columns to shorter chunks. This is done to improve the performance
// during The merkle commitment and FRI, as this component is usually the tallest in the Cairo AIR.
// TODO(Ohad): Change split to 8 after seq is implemented. NOTE: it is possible to split further
// with an expansion trick similar to the one used in XOR. Investigate if it is worth it.
pub(super) const LOG_SPLIT_SIZE: u32 = 2;
pub(super) const SPLIT_SIZE: usize = 1 << LOG_SPLIT_SIZE;
pub(super) const N_ID_AND_MULT_COLUMNS_PER_CHUNK: usize = 2;
pub(super) const N_TRACE_COLUMNS: usize =
N_ADDR_COLUMNS + SPLIT_SIZE * N_ID_AND_MULT_COLUMNS_PER_CHUNK;

pub type Component = FrameworkComponent<Eval>;

// TODO(ShaharS): Break to repititions in order to batch the logup.
#[derive(Clone)]
pub struct Eval {
pub log_n_rows: u32,
// The log size of the component after split.
pub log_size: u32,
pub lookup_elements: relations::MemoryAddressToId,
}
impl Eval {
// TODO(ShaharS): use Seq column for address, and also use repititions.
pub const fn n_columns(&self) -> usize {
N_ADDR_TO_ID_COLUMNS
N_TRACE_COLUMNS
}

pub fn new(claim: Claim, lookup_elements: relations::MemoryAddressToId) -> Self {
Self {
log_n_rows: claim.log_size,
log_size: claim.log_size,
lookup_elements,
}
}
}

impl FrameworkEval for Eval {
fn log_size(&self) -> u32 {
self.log_n_rows
self.log_size
}

fn max_constraint_log_degree_bound(&self) -> u32 {
self.log_size() + 1
}

fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let address_and_id: [E::F; 2] = std::array::from_fn(|_| eval.next_trace_mask());
let multiplicity = eval.next_trace_mask();
eval.add_to_relation(&[RelationEntry::new(
&self.lookup_elements,
E::EF::from(-multiplicity),
&address_and_id,
)]);
let address = eval.next_trace_mask();
for i in 0..SPLIT_SIZE {
let id = eval.next_trace_mask();
let multiplicity = eval.next_trace_mask();
let address = address.clone() + E::F::from(M31((i * (1 << self.log_size())) as u32));
eval.add_to_relation(&[RelationEntry::new(
&self.lookup_elements,
E::EF::from(-multiplicity),
&[address, id],
)]);
}

eval.finalize_logup();
eval
Expand All @@ -62,8 +79,8 @@ pub struct Claim {
impl Claim {
pub fn log_sizes(&self) -> TreeVec<Vec<u32>> {
let preprocessed_log_sizes = vec![self.log_size];
let trace_log_sizes = vec![self.log_size; N_ADDR_TO_ID_COLUMNS];
let interaction_log_sizes = vec![self.log_size; SECURE_EXTENSION_DEGREE];
let trace_log_sizes = vec![self.log_size; N_TRACE_COLUMNS];
let interaction_log_sizes = vec![self.log_size; SECURE_EXTENSION_DEGREE * SPLIT_SIZE];
TreeVec::new(vec![
preprocessed_log_sizes,
trace_log_sizes,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use itertools::Itertools;
use std::iter::zip;
use std::simd::Simd;

use itertools::{izip, Itertools};
use stwo_prover::constraint_framework::logup::LogupTraceGenerator;
use stwo_prover::constraint_framework::Relation;
use stwo_prover::core::backend::simd::column::BaseColumn;
use stwo_prover::core::backend::simd::m31::{PackedBaseField, PackedM31, LOG_N_LANES, N_LANES};
use stwo_prover::core::backend::simd::qm31::PackedSecureField;
use stwo_prover::core::backend::simd::SimdBackend;
use stwo_prover::core::backend::{Col, Column};
use stwo_prover::core::fields::m31::{BaseField, M31};
Expand All @@ -12,8 +13,10 @@ use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;

use super::component::{Claim, InteractionClaim, N_ADDR_TO_ID_COLUMNS};
use crate::components::memory::MEMORY_ADDRESS_BOUND;
use super::component::{Claim, InteractionClaim, SPLIT_SIZE};
use crate::components::memory_address_to_id::component::{
N_ID_AND_MULT_COLUMNS_PER_CHUNK, N_TRACE_COLUMNS,
};
use crate::input::mem::Memory;
use crate::relations;

Expand All @@ -26,14 +29,11 @@ pub struct ClaimGenerator {
}
impl ClaimGenerator {
pub fn new(mem: &Memory) -> Self {
let mut ids = (0..mem.address_to_id.len())
let ids = (0..mem.address_to_id.len())
.map(|addr| mem.get_raw_id(addr as u32))
.collect_vec();
let size = ids.len().next_power_of_two();
assert!(size <= MEMORY_ADDRESS_BOUND);
ids.resize(size, 0);

let multiplicities = vec![0; size];
let multiplicities = vec![0; ids.len()];
Self {
ids,
multiplicities,
Expand Down Expand Up @@ -69,60 +69,62 @@ impl ClaimGenerator {
}

pub fn write_trace(
self,
mut self,
tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>,
) -> (Claim, InteractionClaimGenerator) {
let size = self.ids.len();
let mut trace = (0..N_ADDR_TO_ID_COLUMNS)
.map(|_| Col::<SimdBackend, BaseField>::zeros(size))
.collect_vec();
let size = (self.ids.len() / SPLIT_SIZE).next_power_of_two();
let packed_size = size.div_ceil(N_LANES);
let mut trace: [_; N_TRACE_COLUMNS] =
std::array::from_fn(|_| Col::<SimdBackend, M31>::zeros(size));

let inc = PackedBaseField::from_array(std::array::from_fn(|i| {
BaseField::from_u32_unchecked((i) as u32)
}));
// Pad to a multiple of `N_LANES`.
let next_multiple_of_16 = ((self.ids.len() + 15) >> 4) << 4;
self.ids.resize(next_multiple_of_16, 0);
self.multiplicities.resize(next_multiple_of_16, 0);

// TODO(Ohad): avoid copy.
let packed_ids = self
let id_it = self
.ids
.array_chunks::<N_LANES>()
.map(|chunk| {
PackedM31::from_array(std::array::from_fn(|i| M31::from_u32_unchecked(chunk[i])))
})
.collect_vec();
.map(|&chunk| unsafe { PackedM31::from_simd_unchecked(Simd::from_array(chunk)) });
let multiplicities_it = self
.multiplicities
.array_chunks::<N_LANES>()
.map(|&chunk| unsafe { PackedM31::from_simd_unchecked(Simd::from_array(chunk)) });

// Replace with seq.
for (i, id) in packed_ids.iter().enumerate() {
let inc =
PackedM31::from_array(std::array::from_fn(|i| M31::from_u32_unchecked((i) as u32)));
for i in 0..packed_size {
trace[0].data[i] =
PackedM31::broadcast(BaseField::from_u32_unchecked((i * N_LANES) as u32)) + inc;
trace[1].data[i] = *id;
inc + PackedM31::broadcast(M31::from_u32_unchecked((i * N_LANES) as u32));
}

// TODO(Ohad): avoid copy
trace[2] = BaseColumn::from_iter(
self.multiplicities
.clone()
.into_iter()
.map(BaseField::from_u32_unchecked),
);
// TODO(Ohad): Replace with seq.
for (i, (id, multiplicity)) in zip(id_it, multiplicities_it).enumerate() {
let chunk_idx = i / packed_size;
let i = i % packed_size;
trace[1 + chunk_idx * N_ID_AND_MULT_COLUMNS_PER_CHUNK].data[i] = id;
trace[2 + chunk_idx * N_ID_AND_MULT_COLUMNS_PER_CHUNK].data[i] = multiplicity;
}

// Lookup data.
let addresses = trace[0].data.clone();
let ids = trace[1].data.clone();
let multiplicities = trace[2].data.clone();
let ids: [_; SPLIT_SIZE] =
std::array::from_fn(|i| trace[1 + i * N_ID_AND_MULT_COLUMNS_PER_CHUNK].data.clone());
let multiplicities: [_; SPLIT_SIZE] =
std::array::from_fn(|i| trace[2 + i * N_ID_AND_MULT_COLUMNS_PER_CHUNK].data.clone());

// Commit on trace.
let log_address_bound = size.checked_ilog2().unwrap();
let domain = CanonicCoset::new(log_address_bound).circle_domain();
let log_size = size.checked_ilog2().unwrap();
let domain = CanonicCoset::new(log_size).circle_domain();
let trace = trace
.into_iter()
.map(|eval| CircleEvaluation::<SimdBackend, _, BitReversedOrder>::new(domain, eval))
.collect_vec();
tree_builder.extend_evals(trace);

(
Claim {
log_size: log_address_bound,
},
Claim { log_size },
InteractionClaimGenerator {
addresses,
ids,
Expand All @@ -134,34 +136,30 @@ impl ClaimGenerator {

pub struct InteractionClaimGenerator {
pub addresses: Vec<PackedM31>,
pub ids: Vec<PackedM31>,
pub multiplicities: Vec<PackedM31>,
pub ids: [Vec<PackedM31>; SPLIT_SIZE],
pub multiplicities: [Vec<PackedM31>; SPLIT_SIZE],
}
impl InteractionClaimGenerator {
pub fn with_capacity(capacity: usize) -> Self {
Self {
addresses: Vec::with_capacity(capacity),
ids: Vec::with_capacity(capacity),
multiplicities: Vec::with_capacity(capacity),
}
}

pub fn write_interaction_trace(
self,
tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>,
lookup_elements: &relations::MemoryAddressToId,
) -> InteractionClaim {
let log_size = self.addresses.len().ilog2() + LOG_N_LANES;
let packed_size = self.addresses.len();
let log_size = packed_size.ilog2() + LOG_N_LANES;
let mut logup_gen = LogupTraceGenerator::new(log_size);

let mut col_gen = logup_gen.new_col();
for vec_row in 0..1 << (log_size - LOG_N_LANES) {
let addr = self.addresses[vec_row];
let id = self.ids[vec_row];
let denom: PackedSecureField = lookup_elements.combine(&[addr, id]);
col_gen.write_frac(vec_row, (-self.multiplicities[vec_row]).into(), denom);
for (i, (ids, multiplicities)) in izip!(&self.ids, &self.multiplicities).enumerate() {
let mut col_gen = logup_gen.new_col();
for (vec_row, (&addr, &id, &mult)) in
izip!(&self.addresses, ids, multiplicities).enumerate()
{
let addr = addr + PackedM31::broadcast(M31((i * (1 << log_size)) as u32));
let denom = lookup_elements.combine(&[addr, id]);
col_gen.write_frac(vec_row, (-mult).into(), denom);
}
col_gen.finalize_col();
}
col_gen.finalize_col();

let (trace, claimed_sum) = logup_gen.finalize_last();
tree_builder.extend_evals(trace);
Expand Down
Loading

0 comments on commit a1bbaad

Please sign in to comment.