Skip to content

Commit

Permalink
trace to evals
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Dec 24, 2024
1 parent 69f144c commit 0795055
Showing 1 changed file with 52 additions and 2 deletions.
54 changes: 52 additions & 2 deletions crates/air_utils/src/trace/component_trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use rayon::prelude::*;
use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES};
use stwo_prover::core::backend::simd::SimdBackend;
use stwo_prover::core::fields::m31::M31;
use stwo_prover::core::poly::circle::CircleEvaluation;
use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;

/// A 2D Matrix of [`PackedM31`] values.
Expand Down Expand Up @@ -96,7 +96,14 @@ impl<const N: usize> ComponentTrace<N> {
}

pub fn to_evals(self) -> [CircleEvaluation<SimdBackend, M31, BitReversedOrder>; N] {
todo!()
let domain = CanonicCoset::new(self.log_size).circle_domain();
self.data.map(|column| {
let eval = BaseColumn {
data: column,
length: 1 << self.log_size,
};
CircleEvaluation::<SimdBackend, M31, BitReversedOrder>::new(domain, eval)
})
}

pub fn row_at(&self, row: usize) -> [M31; N] {
Expand Down Expand Up @@ -276,4 +283,47 @@ mod tests {

assert_eq!(expected, actual);
}

#[test]
fn test_parallel_trace() {
use rayon::iter::{IndexedParallelIterator, ParallelIterator};
use rayon::slice::ParallelSlice;

const N_COLUMNS: usize = 3;
const LOG_SIZE: u32 = 8;
let mut trace = super::IterableTrace::<N_COLUMNS>::zeroed(LOG_SIZE);
let arr = (0..1 << LOG_SIZE).map(M31::from).collect_vec();
let expected = arr
.iter()
.map(|&a| {
let b = a + M31::from(1);
let c = a.square() + b.square();
(a, b, c)
})
.multiunzip();

trace
.par_iter_mut()
.zip(arr.par_chunks(N_LANES))
.chunks(4)
.for_each(|chunk| {
chunk.into_iter().for_each(|(row, input)| {
*row[0] = PackedM31::from_array(input.try_into().unwrap());
*row[1] = *row[0] + PackedM31::broadcast(M31(1));
*row[2] = row[0].square() + row[1].square();
})
});
let actual = trace
.data
.map(|c| {
c.into_iter()
.flat_map(|packed| packed.to_array())
.collect_vec()
})
.into_iter()
.next_tuple()
.unwrap();

assert_eq!(expected, actual);
}
}

0 comments on commit 0795055

Please sign in to comment.