Skip to content

Commit

Permalink
mem squish
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Dec 17, 2024
1 parent 3cff5e1 commit b917912
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ where
RelationTrackerComponent::new(
tree_span_provider,
memory_address_to_id::Eval {
log_n_rows: claim.memory_address_to_id.log_size,
log_size: claim.memory_address_to_id.log_size,
lookup_elements: relations::MemoryAddressToId::dummy(),
},
1 << claim.memory_address_to_id.log_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,52 +4,78 @@ use stwo_prover::constraint_framework::{
EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry,
};
use stwo_prover::core::channel::Channel;
use stwo_prover::core::fields::m31::M31;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use stwo_prover::core::pcs::TreeVec;

use crate::relations;

pub const N_ADDR_TO_ID_COLUMNS: usize = 3;
// TODO(Ohad): Address should be a preprocessed `seq`.
pub const N_ADDR_COLUMNS: usize = 1;

/// Split the (ID , Multiplicity) columns to shorter chunks. This is done to improve the performance
/// during The merkle commitment and FRI, as this component is usually the tallest in the Cairo AIR.
///
/// 1. The ID and Multiplicity vectors are split to 'N_SPLIT_CHUNKS' chunks of size
/// `ids.len()`/`N_SPLIT_CHUNKS`.
/// 2. The chunks are padded with 0s to the next power of 2.
///
/// # Example
/// ID = [id0..id10], N_SPLIT_CHUNKS = 4:
/// ID0 = [id0, id1, id2, 0]
/// ID1 = [id3, id4, id5, 0]
/// ID2 = [id6, id7, id8, 0]
/// ID3 = [id9, id10, 0, 0]
// TODO(Ohad): Change split to 8 after seq is implemented.
pub(super) const N_SPLIT_CHUNKS: usize = 4;
pub(super) const N_ID_AND_MULT_COLUMNS_PER_CHUNK: usize = 2;
pub(super) const N_TRACE_COLUMNS: usize =
N_ADDR_COLUMNS + N_SPLIT_CHUNKS * N_ID_AND_MULT_COLUMNS_PER_CHUNK;

pub type Component = FrameworkComponent<Eval>;

// TODO(ShaharS): Break to repititions in order to batch the logup.
#[derive(Clone)]
pub struct Eval {
pub log_n_rows: u32,
// The log size of the component after split.
pub log_size: u32,
pub lookup_elements: relations::MemoryAddressToId,
}
impl Eval {
// TODO(ShaharS): use Seq column for address, and also use repititions.
pub const fn n_columns(&self) -> usize {
N_ADDR_TO_ID_COLUMNS
N_TRACE_COLUMNS
}

pub fn new(claim: Claim, lookup_elements: relations::MemoryAddressToId) -> Self {
Self {
log_n_rows: claim.log_size,
log_size: claim.log_size,
lookup_elements,
}
}
}

impl FrameworkEval for Eval {
fn log_size(&self) -> u32 {
self.log_n_rows
self.log_size
}

fn max_constraint_log_degree_bound(&self) -> u32 {
self.log_size() + 1
}

fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let address_and_id: [E::F; 2] = std::array::from_fn(|_| eval.next_trace_mask());
let multiplicity = eval.next_trace_mask();
eval.add_to_relation(RelationEntry::new(
&self.lookup_elements,
E::EF::from(-multiplicity),
&address_and_id,
));
let address = eval.next_trace_mask();
for i in 0..N_SPLIT_CHUNKS {
let id = eval.next_trace_mask();
let multiplicity = eval.next_trace_mask();
let address = address.clone() + E::F::from(M31((i * (1 << self.log_size())) as u32));
eval.add_to_relation(RelationEntry::new(
&self.lookup_elements,
E::EF::from(-multiplicity),
&[address, id],
));
}

eval.finalize_logup();
eval
Expand All @@ -63,8 +89,8 @@ pub struct Claim {
impl Claim {
pub fn log_sizes(&self) -> TreeVec<Vec<u32>> {
let preprocessed_log_sizes = vec![self.log_size];
let trace_log_sizes = vec![self.log_size; N_ADDR_TO_ID_COLUMNS];
let interaction_log_sizes = vec![self.log_size; SECURE_EXTENSION_DEGREE];
let trace_log_sizes = vec![self.log_size; N_TRACE_COLUMNS];
let interaction_log_sizes = vec![self.log_size; SECURE_EXTENSION_DEGREE * N_SPLIT_CHUNKS];
TreeVec::new(vec![
preprocessed_log_sizes,
trace_log_sizes,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use itertools::Itertools;
use std::iter::zip;
use std::simd::Simd;

use itertools::{izip, Itertools};
use stwo_prover::constraint_framework::logup::LogupTraceGenerator;
use stwo_prover::constraint_framework::Relation;
use stwo_prover::core::backend::simd::column::BaseColumn;
use stwo_prover::core::backend::simd::m31::{PackedBaseField, PackedM31, LOG_N_LANES, N_LANES};
use stwo_prover::core::backend::simd::qm31::PackedSecureField;
use stwo_prover::core::backend::simd::SimdBackend;
use stwo_prover::core::backend::{BackendForChannel, Col, Column};
use stwo_prover::core::channel::MerkleChannel;
Expand All @@ -12,8 +13,10 @@ use stwo_prover::core::pcs::TreeBuilder;
use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;

use super::component::{Claim, InteractionClaim, N_ADDR_TO_ID_COLUMNS};
use crate::components::memory::MEMORY_ADDRESS_BOUND;
use super::component::{Claim, InteractionClaim, N_SPLIT_CHUNKS};
use crate::components::memory_address_to_id::component::{
N_ID_AND_MULT_COLUMNS_PER_CHUNK, N_TRACE_COLUMNS,
};
use crate::input::mem::Memory;
use crate::relations;

Expand All @@ -26,14 +29,11 @@ pub struct ClaimGenerator {
}
impl ClaimGenerator {
pub fn new(mem: &Memory) -> Self {
let mut ids = (0..mem.address_to_id.len())
let ids = (0..mem.address_to_id.len())
.map(|addr| mem.get_raw_id(addr as u32))
.collect_vec();
let size = ids.len().next_power_of_two();
assert!(size <= MEMORY_ADDRESS_BOUND);
ids.resize(size, 0);

let multiplicities = vec![0; size];
let multiplicities = vec![0; ids.len()];
Self {
ids,
multiplicities,
Expand Down Expand Up @@ -69,53 +69,57 @@ impl ClaimGenerator {
}

pub fn write_trace<MC: MerkleChannel>(
self,
mut self,
tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, MC>,
) -> (Claim, InteractionClaimGenerator)
where
SimdBackend: BackendForChannel<MC>,
{
let size = self.ids.len();
let mut trace = (0..N_ADDR_TO_ID_COLUMNS)
.map(|_| Col::<SimdBackend, BaseField>::zeros(size))
.collect_vec();
let size = (self.ids.len() / N_SPLIT_CHUNKS).next_power_of_two();
let n_packed_rows = size.div_ceil(N_LANES);
let mut trace: [_; N_TRACE_COLUMNS] =
std::array::from_fn(|_| Col::<SimdBackend, M31>::zeros(size));

let inc = PackedBaseField::from_array(std::array::from_fn(|i| {
BaseField::from_u32_unchecked((i) as u32)
}));
// Pad to a multiple of `N_LANES`.
let next_multiple_of_16 = self.ids.len().next_multiple_of(16);
self.ids.resize(next_multiple_of_16, 0);
self.multiplicities.resize(next_multiple_of_16, 0);

// TODO(Ohad): avoid copy.
let packed_ids = self
let id_it = self
.ids
.array_chunks::<N_LANES>()
.map(|chunk| {
PackedM31::from_array(std::array::from_fn(|i| M31::from_u32_unchecked(chunk[i])))
})
.collect_vec();
.map(|&chunk| unsafe { PackedM31::from_simd_unchecked(Simd::from_array(chunk)) });
let multiplicities_it = self
.multiplicities
.array_chunks::<N_LANES>()
.map(|&chunk| unsafe { PackedM31::from_simd_unchecked(Simd::from_array(chunk)) });

// Replace with seq.
for (i, id) in packed_ids.iter().enumerate() {
let inc =
PackedM31::from_array(std::array::from_fn(|i| M31::from_u32_unchecked((i) as u32)));
for i in 0..n_packed_rows {
trace[0].data[i] =
PackedM31::broadcast(BaseField::from_u32_unchecked((i * N_LANES) as u32)) + inc;
trace[1].data[i] = *id;
inc + PackedM31::broadcast(M31::from_u32_unchecked((i * N_LANES) as u32));
}

// TODO(Ohad): avoid copy
trace[2] = BaseColumn::from_iter(
self.multiplicities
.clone()
.into_iter()
.map(BaseField::from_u32_unchecked),
);
// TODO(Ohad): Replace with seq.
for (i, (id, multiplicity)) in zip(id_it, multiplicities_it).enumerate() {
let chunk_idx = i / n_packed_rows;
let i = i % n_packed_rows;
trace[1 + chunk_idx * N_ID_AND_MULT_COLUMNS_PER_CHUNK].data[i] = id;
trace[2 + chunk_idx * N_ID_AND_MULT_COLUMNS_PER_CHUNK].data[i] = multiplicity;
}

// Lookup data.
let addresses = trace[0].data.clone();
let ids = trace[1].data.clone();
let multiplicities = trace[2].data.clone();
let ids: [_; N_SPLIT_CHUNKS] =
std::array::from_fn(|i| trace[1 + i * N_ID_AND_MULT_COLUMNS_PER_CHUNK].data.clone());
let multiplicities: [_; N_SPLIT_CHUNKS] =
std::array::from_fn(|i| trace[2 + i * N_ID_AND_MULT_COLUMNS_PER_CHUNK].data.clone());

// Commit on trace.
let log_address_bound = size.checked_ilog2().unwrap();
let domain = CanonicCoset::new(log_address_bound).circle_domain();
let log_size = size.checked_ilog2().unwrap();
let domain = CanonicCoset::new(log_size).circle_domain();
let trace = trace
.into_iter()
.map(|eval| {
Expand All @@ -125,9 +129,7 @@ impl ClaimGenerator {
tree_builder.extend_evals(trace);

(
Claim {
log_size: log_address_bound,
},
Claim { log_size },
InteractionClaimGenerator {
addresses,
ids,
Expand All @@ -139,18 +141,10 @@ impl ClaimGenerator {

pub struct InteractionClaimGenerator {
pub addresses: Vec<PackedM31>,
pub ids: Vec<PackedM31>,
pub multiplicities: Vec<PackedM31>,
pub ids: [Vec<PackedM31>; N_SPLIT_CHUNKS],
pub multiplicities: [Vec<PackedM31>; N_SPLIT_CHUNKS],
}
impl InteractionClaimGenerator {
pub fn with_capacity(capacity: usize) -> Self {
Self {
addresses: Vec::with_capacity(capacity),
ids: Vec::with_capacity(capacity),
multiplicities: Vec::with_capacity(capacity),
}
}

pub fn write_interaction_trace<MC: MerkleChannel>(
self,
tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, MC>,
Expand All @@ -159,17 +153,21 @@ impl InteractionClaimGenerator {
where
SimdBackend: BackendForChannel<MC>,
{
let log_size = self.addresses.len().ilog2() + LOG_N_LANES;
let packed_size = self.addresses.len();
let log_size = packed_size.ilog2() + LOG_N_LANES;
let mut logup_gen = LogupTraceGenerator::new(log_size);

let mut col_gen = logup_gen.new_col();
for vec_row in 0..1 << (log_size - LOG_N_LANES) {
let addr = self.addresses[vec_row];
let id = self.ids[vec_row];
let denom: PackedSecureField = lookup_elements.combine(&[addr, id]);
col_gen.write_frac(vec_row, (-self.multiplicities[vec_row]).into(), denom);
for (i, (ids, multiplicities)) in izip!(&self.ids, &self.multiplicities).enumerate() {
let mut col_gen = logup_gen.new_col();
for (vec_row, (&addr, &id, &mult)) in
izip!(&self.addresses, ids, multiplicities).enumerate()
{
let addr = addr + PackedM31::broadcast(M31((i * (1 << log_size)) as u32));
let denom = lookup_elements.combine(&[addr, id]);
col_gen.write_frac(vec_row, (-mult).into(), denom);
}
col_gen.finalize_col();
}
col_gen.finalize_col();

let (trace, claimed_sum) = logup_gen.finalize_last();
tree_builder.extend_evals(trace);
Expand Down
7 changes: 4 additions & 3 deletions stwo_cairo_prover/crates/prover/src/components/memory/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ mod tests {

#[test]
fn test_memory_trace_prover() {
const N_ENTRIES: u64 = 10;
let memory = MemoryBuilder::from_iter(
MemConfig::default(),
(0..10).map(|i| MemEntry {
(0..N_ENTRIES).map(|i| MemEntry {
addr: i,
val: [i as u32; 8],
}),
Expand All @@ -30,8 +31,8 @@ mod tests {
.into_iter()
.map(BaseField::from)
.collect_vec();
let expected_addr_mult: [u32; 16] = [1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
let expected_f252_mult: [u32; 16] = [2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
let expected_addr_mult: [u32; N_ENTRIES as usize] = [1, 2, 3, 0, 0, 0, 0, 0, 0, 0];
let expected_f252_mult: [u32; N_ENTRIES as usize] = [2, 3, 0, 0, 0, 0, 0, 0, 0, 0];

address_usages.iter().for_each(|addr| {
let decoded_id = memory.address_to_id[addr.0 as usize].decode();
Expand Down

0 comments on commit b917912

Please sign in to comment.