From b9179129d2389cd110a406537ef18a7cd9c37234 Mon Sep 17 00:00:00 2001 From: Ohad Agadi Date: Tue, 17 Dec 2024 10:46:38 +0200 Subject: [PATCH] mem squish --- .../prover/src/cairo_air/debug_tools.rs | 2 +- .../memory/memory_address_to_id/component.rs | 56 ++++++--- .../memory/memory_address_to_id/prover.rs | 116 +++++++++--------- .../prover/src/components/memory/mod.rs | 7 +- 4 files changed, 103 insertions(+), 78 deletions(-) diff --git a/stwo_cairo_prover/crates/prover/src/cairo_air/debug_tools.rs b/stwo_cairo_prover/crates/prover/src/cairo_air/debug_tools.rs index 088c65a0..b9c7e0de 100644 --- a/stwo_cairo_prover/crates/prover/src/cairo_air/debug_tools.rs +++ b/stwo_cairo_prover/crates/prover/src/cairo_air/debug_tools.rs @@ -481,7 +481,7 @@ where RelationTrackerComponent::new( tree_span_provider, memory_address_to_id::Eval { - log_n_rows: claim.memory_address_to_id.log_size, + log_size: claim.memory_address_to_id.log_size, lookup_elements: relations::MemoryAddressToId::dummy(), }, 1 << claim.memory_address_to_id.log_size, diff --git a/stwo_cairo_prover/crates/prover/src/components/memory/memory_address_to_id/component.rs b/stwo_cairo_prover/crates/prover/src/components/memory/memory_address_to_id/component.rs index 7b01cf8a..a94e088e 100644 --- a/stwo_cairo_prover/crates/prover/src/components/memory/memory_address_to_id/component.rs +++ b/stwo_cairo_prover/crates/prover/src/components/memory/memory_address_to_id/component.rs @@ -4,30 +4,52 @@ use stwo_prover::constraint_framework::{ EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry, }; use stwo_prover::core::channel::Channel; +use stwo_prover::core::fields::m31::M31; use stwo_prover::core::fields::qm31::SecureField; use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE; use stwo_prover::core::pcs::TreeVec; use crate::relations; -pub const N_ADDR_TO_ID_COLUMNS: usize = 3; +// TODO(Ohad): Address should be a preprocessed `seq`. 
+pub const N_ADDR_COLUMNS: usize = 1; + +/// Split the (ID, Multiplicity) columns into shorter chunks. This is done to improve performance +/// during the Merkle commitment and FRI, as this component is usually the tallest in the Cairo AIR. +/// +/// 1. The ID and Multiplicity vectors are split into `N_SPLIT_CHUNKS` chunks of size +/// `ids.len()`/`N_SPLIT_CHUNKS`. +/// 2. The chunks are padded with 0s to the next power of 2. +/// +/// # Example +/// ID = [id0..id10], N_SPLIT_CHUNKS = 4: +/// ID0 = [id0, id1, id2, 0] +/// ID1 = [id3, id4, id5, 0] +/// ID2 = [id6, id7, id8, 0] +/// ID3 = [id9, id10, 0, 0] +// TODO(Ohad): Change split to 8 after seq is implemented. +pub(super) const N_SPLIT_CHUNKS: usize = 4; +pub(super) const N_ID_AND_MULT_COLUMNS_PER_CHUNK: usize = 2; +pub(super) const N_TRACE_COLUMNS: usize = + N_ADDR_COLUMNS + N_SPLIT_CHUNKS * N_ID_AND_MULT_COLUMNS_PER_CHUNK; + pub type Component = FrameworkComponent; -// TODO(ShaharS): Break to repititions in order to batch the logup. #[derive(Clone)] pub struct Eval { - pub log_n_rows: u32, + // The log size of the component after split. + pub log_size: u32, pub lookup_elements: relations::MemoryAddressToId, } impl Eval { // TODO(ShaharS): use Seq column for address, and also use repititions. 
pub const fn n_columns(&self) -> usize { - N_ADDR_TO_ID_COLUMNS + N_TRACE_COLUMNS } pub fn new(claim: Claim, lookup_elements: relations::MemoryAddressToId) -> Self { Self { - log_n_rows: claim.log_size, + log_size: claim.log_size, lookup_elements, } } @@ -35,7 +57,7 @@ impl Eval { impl FrameworkEval for Eval { fn log_size(&self) -> u32 { - self.log_n_rows + self.log_size } fn max_constraint_log_degree_bound(&self) -> u32 { @@ -43,13 +65,17 @@ impl FrameworkEval for Eval { } fn evaluate(&self, mut eval: E) -> E { - let address_and_id: [E::F; 2] = std::array::from_fn(|_| eval.next_trace_mask()); - let multiplicity = eval.next_trace_mask(); - eval.add_to_relation(RelationEntry::new( - &self.lookup_elements, - E::EF::from(-multiplicity), - &address_and_id, - )); + let address = eval.next_trace_mask(); + for i in 0..N_SPLIT_CHUNKS { + let id = eval.next_trace_mask(); + let multiplicity = eval.next_trace_mask(); + let address = address.clone() + E::F::from(M31((i * (1 << self.log_size())) as u32)); + eval.add_to_relation(RelationEntry::new( + &self.lookup_elements, + E::EF::from(-multiplicity), + &[address, id], + )); + } eval.finalize_logup(); eval @@ -63,8 +89,8 @@ pub struct Claim { impl Claim { pub fn log_sizes(&self) -> TreeVec> { let preprocessed_log_sizes = vec![self.log_size]; - let trace_log_sizes = vec![self.log_size; N_ADDR_TO_ID_COLUMNS]; - let interaction_log_sizes = vec![self.log_size; SECURE_EXTENSION_DEGREE]; + let trace_log_sizes = vec![self.log_size; N_TRACE_COLUMNS]; + let interaction_log_sizes = vec![self.log_size; SECURE_EXTENSION_DEGREE * N_SPLIT_CHUNKS]; TreeVec::new(vec![ preprocessed_log_sizes, trace_log_sizes, diff --git a/stwo_cairo_prover/crates/prover/src/components/memory/memory_address_to_id/prover.rs b/stwo_cairo_prover/crates/prover/src/components/memory/memory_address_to_id/prover.rs index 5ec8cce7..e13f8d9c 100644 --- a/stwo_cairo_prover/crates/prover/src/components/memory/memory_address_to_id/prover.rs +++ 
b/stwo_cairo_prover/crates/prover/src/components/memory/memory_address_to_id/prover.rs @@ -1,9 +1,10 @@ -use itertools::Itertools; +use std::iter::zip; +use std::simd::Simd; + +use itertools::{izip, Itertools}; use stwo_prover::constraint_framework::logup::LogupTraceGenerator; use stwo_prover::constraint_framework::Relation; -use stwo_prover::core::backend::simd::column::BaseColumn; use stwo_prover::core::backend::simd::m31::{PackedBaseField, PackedM31, LOG_N_LANES, N_LANES}; -use stwo_prover::core::backend::simd::qm31::PackedSecureField; use stwo_prover::core::backend::simd::SimdBackend; use stwo_prover::core::backend::{BackendForChannel, Col, Column}; use stwo_prover::core::channel::MerkleChannel; @@ -12,8 +13,10 @@ use stwo_prover::core::pcs::TreeBuilder; use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; use stwo_prover::core::poly::BitReversedOrder; -use super::component::{Claim, InteractionClaim, N_ADDR_TO_ID_COLUMNS}; -use crate::components::memory::MEMORY_ADDRESS_BOUND; +use super::component::{Claim, InteractionClaim, N_SPLIT_CHUNKS}; +use crate::components::memory_address_to_id::component::{ + N_ID_AND_MULT_COLUMNS_PER_CHUNK, N_TRACE_COLUMNS, +}; use crate::input::mem::Memory; use crate::relations; @@ -26,14 +29,11 @@ pub struct ClaimGenerator { } impl ClaimGenerator { pub fn new(mem: &Memory) -> Self { - let mut ids = (0..mem.address_to_id.len()) + let ids = (0..mem.address_to_id.len()) .map(|addr| mem.get_raw_id(addr as u32)) .collect_vec(); - let size = ids.len().next_power_of_two(); - assert!(size <= MEMORY_ADDRESS_BOUND); - ids.resize(size, 0); - let multiplicities = vec![0; size]; + let multiplicities = vec![0; ids.len()]; Self { ids, multiplicities, @@ -69,53 +69,57 @@ impl ClaimGenerator { } pub fn write_trace( - self, + mut self, tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, MC>, ) -> (Claim, InteractionClaimGenerator) where SimdBackend: BackendForChannel, { - let size = self.ids.len(); - let mut trace = 
(0..N_ADDR_TO_ID_COLUMNS) - .map(|_| Col::::zeros(size)) - .collect_vec(); + let size = (self.ids.len() / N_SPLIT_CHUNKS).next_power_of_two(); + let n_packed_rows = size.div_ceil(N_LANES); + let mut trace: [_; N_TRACE_COLUMNS] = + std::array::from_fn(|_| Col::::zeros(size)); - let inc = PackedBaseField::from_array(std::array::from_fn(|i| { - BaseField::from_u32_unchecked((i) as u32) - })); + // Pad to a multiple of `N_LANES`. + let next_multiple_of_16 = self.ids.len().next_multiple_of(16); + self.ids.resize(next_multiple_of_16, 0); + self.multiplicities.resize(next_multiple_of_16, 0); // TODO(Ohad): avoid copy. - let packed_ids = self + let id_it = self .ids .array_chunks::() - .map(|chunk| { - PackedM31::from_array(std::array::from_fn(|i| M31::from_u32_unchecked(chunk[i]))) - }) - .collect_vec(); + .map(|&chunk| unsafe { PackedM31::from_simd_unchecked(Simd::from_array(chunk)) }); + let multiplicities_it = self + .multiplicities + .array_chunks::() + .map(|&chunk| unsafe { PackedM31::from_simd_unchecked(Simd::from_array(chunk)) }); - // Replace with seq. - for (i, id) in packed_ids.iter().enumerate() { + let inc = + PackedM31::from_array(std::array::from_fn(|i| M31::from_u32_unchecked((i) as u32))); + for i in 0..n_packed_rows { trace[0].data[i] = - PackedM31::broadcast(BaseField::from_u32_unchecked((i * N_LANES) as u32)) + inc; - trace[1].data[i] = *id; + inc + PackedM31::broadcast(M31::from_u32_unchecked((i * N_LANES) as u32)); } - // TODO(Ohad): avoid copy - trace[2] = BaseColumn::from_iter( - self.multiplicities - .clone() - .into_iter() - .map(BaseField::from_u32_unchecked), - ); + // TODO(Ohad): Replace with seq. + for (i, (id, multiplicity)) in zip(id_it, multiplicities_it).enumerate() { + let chunk_idx = i / n_packed_rows; + let i = i % n_packed_rows; + trace[1 + chunk_idx * N_ID_AND_MULT_COLUMNS_PER_CHUNK].data[i] = id; + trace[2 + chunk_idx * N_ID_AND_MULT_COLUMNS_PER_CHUNK].data[i] = multiplicity; + } // Lookup data. 
let addresses = trace[0].data.clone(); - let ids = trace[1].data.clone(); - let multiplicities = trace[2].data.clone(); + let ids: [_; N_SPLIT_CHUNKS] = + std::array::from_fn(|i| trace[1 + i * N_ID_AND_MULT_COLUMNS_PER_CHUNK].data.clone()); + let multiplicities: [_; N_SPLIT_CHUNKS] = + std::array::from_fn(|i| trace[2 + i * N_ID_AND_MULT_COLUMNS_PER_CHUNK].data.clone()); // Commit on trace. - let log_address_bound = size.checked_ilog2().unwrap(); - let domain = CanonicCoset::new(log_address_bound).circle_domain(); + let log_size = size.checked_ilog2().unwrap(); + let domain = CanonicCoset::new(log_size).circle_domain(); let trace = trace .into_iter() .map(|eval| { @@ -125,9 +129,7 @@ impl ClaimGenerator { tree_builder.extend_evals(trace); ( - Claim { - log_size: log_address_bound, - }, + Claim { log_size }, InteractionClaimGenerator { addresses, ids, @@ -139,18 +141,10 @@ impl ClaimGenerator { pub struct InteractionClaimGenerator { pub addresses: Vec, - pub ids: Vec, - pub multiplicities: Vec, + pub ids: [Vec; N_SPLIT_CHUNKS], + pub multiplicities: [Vec; N_SPLIT_CHUNKS], } impl InteractionClaimGenerator { - pub fn with_capacity(capacity: usize) -> Self { - Self { - addresses: Vec::with_capacity(capacity), - ids: Vec::with_capacity(capacity), - multiplicities: Vec::with_capacity(capacity), - } - } - pub fn write_interaction_trace( self, tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, MC>, @@ -159,17 +153,21 @@ impl InteractionClaimGenerator { where SimdBackend: BackendForChannel, { - let log_size = self.addresses.len().ilog2() + LOG_N_LANES; + let packed_size = self.addresses.len(); + let log_size = packed_size.ilog2() + LOG_N_LANES; let mut logup_gen = LogupTraceGenerator::new(log_size); - let mut col_gen = logup_gen.new_col(); - for vec_row in 0..1 << (log_size - LOG_N_LANES) { - let addr = self.addresses[vec_row]; - let id = self.ids[vec_row]; - let denom: PackedSecureField = lookup_elements.combine(&[addr, id]); - col_gen.write_frac(vec_row, 
(-self.multiplicities[vec_row]).into(), denom); + for (i, (ids, multiplicities)) in izip!(&self.ids, &self.multiplicities).enumerate() { + let mut col_gen = logup_gen.new_col(); + for (vec_row, (&addr, &id, &mult)) in + izip!(&self.addresses, ids, multiplicities).enumerate() + { + let addr = addr + PackedM31::broadcast(M31((i * (1 << log_size)) as u32)); + let denom = lookup_elements.combine(&[addr, id]); + col_gen.write_frac(vec_row, (-mult).into(), denom); + } + col_gen.finalize_col(); } - col_gen.finalize_col(); let (trace, claimed_sum) = logup_gen.finalize_last(); tree_builder.extend_evals(trace); diff --git a/stwo_cairo_prover/crates/prover/src/components/memory/mod.rs b/stwo_cairo_prover/crates/prover/src/components/memory/mod.rs index bf33535f..d863e9dc 100644 --- a/stwo_cairo_prover/crates/prover/src/components/memory/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/components/memory/mod.rs @@ -16,9 +16,10 @@ mod tests { #[test] fn test_memory_trace_prover() { + const N_ENTRIES: u64 = 10; let memory = MemoryBuilder::from_iter( MemConfig::default(), - (0..10).map(|i| MemEntry { + (0..N_ENTRIES).map(|i| MemEntry { addr: i, val: [i as u32; 8], }), @@ -30,8 +31,8 @@ mod tests { .into_iter() .map(BaseField::from) .collect_vec(); - let expected_addr_mult: [u32; 16] = [1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; - let expected_f252_mult: [u32; 16] = [2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + let expected_addr_mult: [u32; N_ENTRIES as usize] = [1, 2, 3, 0, 0, 0, 0, 0, 0, 0]; + let expected_f252_mult: [u32; N_ENTRIES as usize] = [2, 3, 0, 0, 0, 0, 0, 0, 0, 0]; address_usages.iter().for_each(|addr| { let decoded_id = memory.address_to_id[addr.0 as usize].decode();