diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 9bbf05402..8baf0d083 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -7,6 +7,7 @@ mod info; pub mod logup; mod point; pub mod preprocessed_columns; +pub mod relation_tracker; mod simd_domain; use std::array; diff --git a/crates/prover/src/constraint_framework/relation_tracker.rs b/crates/prover/src/constraint_framework/relation_tracker.rs new file mode 100644 index 000000000..df3996d63 --- /dev/null +++ b/crates/prover/src/constraint_framework/relation_tracker.rs @@ -0,0 +1,189 @@ +use std::fmt::Debug; + +use itertools::Itertools; +use num_traits::Zero; + +use super::logup::LogupSums; +use super::{ + EvalAtRow, FrameworkEval, InfoEvaluator, Relation, RelationEntry, TraceLocationAllocator, + INTERACTION_TRACE_IDX, +}; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::very_packed_m31::LOG_N_VERY_PACKED_ELEMS; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::fields::m31::{BaseField, M31}; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use crate::core::lookups::utils::Fraction; +use crate::core::pcs::{TreeSubspan, TreeVec}; +use crate::core::poly::circle::CircleEvaluation; +use crate::core::poly::BitReversedOrder; +use crate::core::utils::{ + bit_reverse_index, coset_index_to_circle_domain_index, offset_bit_reversed_circle_domain_index, +}; + +#[derive(Debug)] +pub struct RelationTrackerEntry { + pub relation: String, + pub mult: M31, + pub values: Vec, +} + +pub struct RelationTrackerComponent { + eval: E, + trace_locations: TreeVec, + n_rows: usize, +} +impl RelationTrackerComponent { + pub fn new(location_allocator: &mut TraceLocationAllocator, eval: E, n_rows: usize) -> Self { + let info = eval.evaluate(InfoEvaluator::new( + eval.log_size(), + vec![], + LogupSums::default(), + )); + let mut mask_offsets = info.mask_offsets; + mask_offsets.drain(INTERACTION_TRACE_IDX..); + let trace_locations = location_allocator.next_for_structure(&mask_offsets); + Self { + eval, + trace_locations, + n_rows, + } + } + + pub fn entries( + self, + trace: &TreeVec>>, + ) -> Vec { + let log_size = self.eval.log_size(); + + // Deref the sub-tree. Only copies the references. + let sub_tree = trace + .sub_tree(&self.trace_locations) + .map(|vec| vec.into_iter().copied().collect_vec()); + let mut entries = vec![]; + + for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + let evaluator = + RelationTrackerEvaluator::new(&sub_tree, vec_row, log_size, self.n_rows); + entries.extend(self.eval.evaluate(evaluator).entries()); + } + entries + } +} + +/// Aggregates relation entries. +// TODO(Ohad): write a summarize function, test. +pub struct RelationTrackerEvaluator<'a> { + entries: Vec, + pub trace_eval: + &'a TreeVec>>, + pub column_index_per_interaction: Vec, + pub vec_row: usize, + pub domain_log_size: u32, + pub n_rows: usize, +} +impl<'a> RelationTrackerEvaluator<'a> { + pub fn new( + trace_eval: &'a TreeVec>>, + vec_row: usize, + domain_log_size: u32, + n_rows: usize, + ) -> Self { + Self { + entries: vec![], + trace_eval, + column_index_per_interaction: vec![0; trace_eval.len()], + vec_row, + domain_log_size, + n_rows, + } + } + + pub fn entries(self) -> Vec { + self.entries + } +} +impl<'a> EvalAtRow for RelationTrackerEvaluator<'a> { + type F = PackedBaseField; + type EF = PackedSecureField; + + // TODO(Ohad): Add debug boundary checks. + fn next_interaction_mask( + &mut self, + interaction: usize, + offsets: [isize; N], + ) -> [Self::F; N] { + assert_ne!(interaction, INTERACTION_TRACE_IDX); + let col_index = self.column_index_per_interaction[interaction]; + self.column_index_per_interaction[interaction] += 1; + offsets.map(|off| { + // If the offset is 0, we can just return the value directly from this row. + if off == 0 { + unsafe { + let col = &self + .trace_eval + .get_unchecked(interaction) + .get_unchecked(col_index) + .values; + return *col.data.get_unchecked(self.vec_row); + }; + } + // Otherwise, we need to look up the value at the offset. + // Since the domain is bit-reversed circle domain ordered, we need to look up the value + // at the bit-reversed natural order index at an offset. + PackedBaseField::from_array(std::array::from_fn(|i| { + let row_index = offset_bit_reversed_circle_domain_index( + (self.vec_row << (LOG_N_LANES + LOG_N_VERY_PACKED_ELEMS)) + i, + self.domain_log_size, + self.domain_log_size, + off, + ); + self.trace_eval[interaction][col_index].at(row_index) + })) + }) + } + fn add_constraint(&mut self, _constraint: G) {} + + fn combine_ef(_values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { + PackedSecureField::zero() + } + + fn write_logup_frac(&mut self, _fraction: Fraction) {} + + fn finalize_logup(&mut self) {} + + fn add_to_relation>( + &mut self, + entries: &[RelationEntry<'_, Self::F, Self::EF, R>], + ) { + for entry in entries { + let relation = entry.relation.get_name().to_owned(); + let values = entry.values.iter().map(|v| v.to_array()).collect_vec(); + let mult = entry.multiplicity.to_array(); + + // Unpack SIMD. + for j in 0..N_LANES { + // Skip padded values. + let cannonical_index = bit_reverse_index( + coset_index_to_circle_domain_index( + (self.vec_row << LOG_N_LANES) + j, + self.domain_log_size, + ), + self.domain_log_size, + ); + if cannonical_index >= self.n_rows { + continue; + } + let values = values.iter().map(|v| v[j]).collect_vec(); + let mult = mult[j].to_m31_array()[0]; + self.entries.push(RelationTrackerEntry { + relation: relation.clone(), + mult, + values, + }); + } + } + } +} diff --git a/crates/prover/src/core/pcs/utils.rs b/crates/prover/src/core/pcs/utils.rs index 36ef3a198..73c624f81 100644 --- a/crates/prover/src/core/pcs/utils.rs +++ b/crates/prover/src/core/pcs/utils.rs @@ -41,6 +41,13 @@ impl<'a, T> From<&'a TreeVec> for TreeVec<&'a T> { } } +/// Converts `&TreeVec<&Vec>` to `TreeVec>`. +impl<'a, T> From<&'a TreeVec<&'a Vec>> for TreeVec> { + fn from(val: &'a TreeVec<&'a Vec>) -> Self { + TreeVec(val.iter().map(|vec| vec.iter().collect()).collect()) + } +} + impl Deref for TreeVec { type Target = Vec; fn deref(&self) -> &Self::Target {