From 11b884137edeeaa653affa400680de4091f742fe Mon Sep 17 00:00:00 2001 From: Ohad Agadi Date: Mon, 23 Dec 2024 11:48:07 +0200 Subject: [PATCH] iterable trace --- Cargo.lock | 9 ++ Cargo.toml | 2 +- crates/air_utils/Cargo.toml | 12 ++ crates/air_utils/src/lib.rs | 2 + crates/air_utils/src/trace/component_trace.rs | 127 ++++++++++++++++++ crates/air_utils/src/trace/mod.rs | 1 + 6 files changed, 152 insertions(+), 1 deletion(-) create mode 100644 crates/air_utils/Cargo.toml create mode 100644 crates/air_utils/src/lib.rs create mode 100644 crates/air_utils/src/trace/component_trace.rs create mode 100644 crates/air_utils/src/trace/mod.rs diff --git a/Cargo.lock b/Cargo.lock index a7e009352..54d4d4ff7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1035,6 +1035,15 @@ dependencies = [ "serde", ] +[[package]] +name = "stwo-air-utils" +version = "0.1.1" +dependencies = [ + "bytemuck", + "itertools 0.12.1", + "stwo-prover", +] + [[package]] name = "stwo-prover" version = "0.1.1" diff --git a/Cargo.toml b/Cargo.toml index 0f314a496..fadd620de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["crates/prover"] +members = ["crates/prover", "crates/air_utils"] resolver = "2" [workspace.package] diff --git a/crates/air_utils/Cargo.toml b/crates/air_utils/Cargo.toml new file mode 100644 index 000000000..c021b76cb --- /dev/null +++ b/crates/air_utils/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "stwo-air-utils" +version.workspace = true +edition.workspace = true + +[dependencies] +bytemuck.workspace = true +itertools.workspace = true +stwo-prover = { path = "../prover" } + +[lib] +bench = false diff --git a/crates/air_utils/src/lib.rs b/crates/air_utils/src/lib.rs new file mode 100644 index 000000000..8603c2cee --- /dev/null +++ b/crates/air_utils/src/lib.rs @@ -0,0 +1,2 @@ +#![feature(exact_size_is_empty, raw_slice_split, portable_simd)] +pub mod trace; diff --git a/crates/air_utils/src/trace/component_trace.rs b/crates/air_utils/src/trace/component_trace.rs new file mode 100644 index 000000000..1fe6928f0 --- /dev/null +++ b/crates/air_utils/src/trace/component_trace.rs @@ -0,0 +1,127 @@ +use std::marker::PhantomData; + +use bytemuck::Zeroable; +use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; +use stwo_prover::core::backend::simd::SimdBackend; +use stwo_prover::core::fields::m31::M31; +use stwo_prover::core::poly::circle::CircleEvaluation; +use stwo_prover::core::poly::BitReversedOrder; + +/// A 2D Matrix of [`PackedM31`] values. +/// Used for generating the witness of 'Stwo' proofs. +/// Stored as an array of `N` columns, each column is a vector of [`PackedM31`] values. +/// Exposes an iterator over mutable references to the rows of the matrix. +/// +/// # Example: +/// +/// ```text +/// Computation trace of a^2 + (a + 1)^2 for a in 0..256 +/// ``` +/// ``` +/// use stwo_air_utils::trace::component_trace::ComponentTrace; +/// use itertools::Itertools; +/// use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES}; +/// use stwo_prover::core::fields::m31::M31; +/// use stwo_prover::core::fields::FieldExpOps; +/// +/// const N_COLUMNS: usize = 3; +/// const LOG_SIZE: u32 = 8; +/// let mut trace = ComponentTrace::::zeroed(LOG_SIZE); +/// let example_input = (0..1 << LOG_SIZE).map(M31::from).collect_vec(); // 0..256 +/// trace +/// .iter_mut() +/// .zip(example_input.chunks(N_LANES)) +/// .chunks(4) +/// .into_iter() +/// .for_each(|chunk| { +/// chunk.into_iter().for_each(|(row, input)| { +/// *row[0] = PackedM31::from_array(input.try_into().unwrap()); +/// *row[1] = *row[0] + PackedM31::broadcast(M31(1)); +/// *row[2] = row[0].square() + row[1].square(); +/// }) +/// }); +/// +/// let first_3_rows = (0..N_COLUMNS).map(|i| trace.row_at(i)).collect::>(); +/// assert_eq!(first_3_rows, [[0,1,1], [1,2,5], [2,3,13]].map(|row| row.map(M31::from))); +/// ``` +#[derive(Debug)] +pub struct ComponentTrace { + data: [Vec; N], + + /// Log number of non-packed rows in each column. + log_size: u32, +} + +impl ComponentTrace { + pub fn zeroed(log_size: u32) -> Self { + let n_simd_elems = 1 << (log_size - LOG_N_LANES); + let data = [(); N].map(|_| vec![PackedM31::zeroed(); n_simd_elems]); + Self { data, log_size } + } + + /// # Safety + /// The caller must ensure that the column is populated before being used. + #[allow(clippy::uninit_vec)] + pub unsafe fn uninitialized(_log_size: u32) -> Self { + todo!() + } + + pub fn log_size(&self) -> u32 { + self.log_size + } + + pub fn iter_mut(&mut self) -> RowIterMut<'_, N> { + RowIterMut::new(self.data.each_mut().map(|column| column.as_mut_slice())) + } + + pub fn to_evals(self) -> [CircleEvaluation; N] { + todo!() + } + + pub fn row_at(&self, row: usize) -> [M31; N] { + assert!(row < 1 << self.log_size); + let packed_row = row / N_LANES; + let idx_in_simd_vector = row % N_LANES; + self.data + .each_ref() + .map(|column| column[packed_row].to_array()[idx_in_simd_vector]) + } +} + +pub type MutRow<'trace, const N: usize> = [&'trace mut PackedM31; N]; + +/// An iterator over mutable references to the rows of a [`ComponentTrace`]. +pub struct RowIterMut<'trace, const N: usize> { + v: [*mut [PackedM31]; N], + phantom: PhantomData<&'trace ()>, +} +impl<'trace, const N: usize> RowIterMut<'trace, N> { + pub fn new(slice: [&'trace mut [PackedM31]; N]) -> Self { + Self { + v: slice.map(|s| s as *mut _), + phantom: PhantomData, + } + } +} +impl<'trace, const N: usize> Iterator for RowIterMut<'trace, N> { + type Item = MutRow<'trace, N>; + + fn next(&mut self) -> Option { + if self.v[0].is_empty() { + return None; + } + let item = std::array::from_fn(|i| unsafe { + // SAFETY: The self.v contract ensures that any split_at_mut is valid. + let (head, tail) = self.v[i].split_at_mut(1); + self.v[i] = tail; + &mut (*head)[0] + }); + Some(item) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.v[0].len(); + (len, Some(len)) + } +} +impl ExactSizeIterator for RowIterMut<'_, N> {} diff --git a/crates/air_utils/src/trace/mod.rs b/crates/air_utils/src/trace/mod.rs new file mode 100644 index 000000000..03a022de5 --- /dev/null +++ b/crates/air_utils/src/trace/mod.rs @@ -0,0 +1 @@ +pub mod component_trace;