From 2e99afbbd69aba6fa0708b1b1962f9b38348aa14 Mon Sep 17 00:00:00 2001 From: Ohad <137686240+ohad-starkware@users.noreply.github.com> Date: Mon, 25 Nov 2024 13:08:16 +0200 Subject: [PATCH] relation tracker --- crates/prover/src/constraint_framework/mod.rs | 1 + .../constraint_framework/relation_tracker.rs | 196 ++++++++++++++++++ crates/prover/src/core/pcs/utils.rs | 7 + .../src/examples/state_machine/components.rs | 43 +++- .../prover/src/examples/state_machine/mod.rs | 21 +- 5 files changed, 259 insertions(+), 9 deletions(-) create mode 100644 crates/prover/src/constraint_framework/relation_tracker.rs diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 9bbf05402..8baf0d083 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -7,6 +7,7 @@ mod info; pub mod logup; mod point; pub mod preprocessed_columns; +pub mod relation_tracker; mod simd_domain; use std::array; diff --git a/crates/prover/src/constraint_framework/relation_tracker.rs b/crates/prover/src/constraint_framework/relation_tracker.rs new file mode 100644 index 000000000..3d573e3d0 --- /dev/null +++ b/crates/prover/src/constraint_framework/relation_tracker.rs @@ -0,0 +1,196 @@ +use itertools::Itertools; +use serde::Serialize; + +use super::logup::LogupSums; +use super::{ + EvalAtRow, FrameworkEval, InfoEvaluator, Relation, RelationEntry, TraceLocationAllocator, + INTERACTION_TRACE_IDX, +}; +use crate::core::backend::simd::column::VeryPackedBaseColumn; +use crate::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; +use crate::core::backend::simd::qm31::PackedQM31; +use crate::core::backend::simd::very_packed_m31::{ + VeryPackedBaseField, VeryPackedSecureField, LOG_N_VERY_PACKED_ELEMS, N_VERY_PACKED_ELEMS, +}; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::fields::m31::BaseField; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use crate::core::lookups::utils::Fraction; +use crate::core::pcs::{TreeSubspan, TreeVec}; +use crate::core::poly::circle::CircleEvaluation; +use crate::core::poly::BitReversedOrder; +use crate::core::utils::offset_bit_reversed_circle_domain_index; + +#[derive(Serialize, Debug)] +pub struct RelationTrackerEntry { + pub relation: String, + pub mult: u32, + pub values: Vec, +} + +pub trait RelationTracker { + fn entries( + self, + trace: &TreeVec>>, + ) -> Vec; +} + +pub struct RelationTrackerComponent { + eval: E, + trace_locations: TreeVec, +} +impl RelationTrackerComponent { + pub fn new(location_allocator: &mut TraceLocationAllocator, eval: E) -> Self { + let info = eval.evaluate(InfoEvaluator::new( + eval.log_size(), + vec![], + LogupSums::default(), + )); + let mut mask_offsets = info.mask_offsets; + mask_offsets.drain(INTERACTION_TRACE_IDX..); + let trace_locations = location_allocator.next_for_structure(&mask_offsets); + Self { + eval, + trace_locations, + } + } +} +impl RelationTracker for RelationTrackerComponent { + fn entries( + self, + trace: &TreeVec>>, + ) -> Vec { + let log_size = self.eval.log_size(); + + // Deref the sub-tree. Only copies the references. + let sub_tree = trace + .sub_tree(&self.trace_locations) + .map(|vec| vec.into_iter().copied().collect_vec()); + let mut entries = vec![]; + + for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + let evaluator = RelationTrackerEvaluator::new(&sub_tree, vec_row, log_size); + entries.extend(self.eval.evaluate(evaluator).summarize()); + } + entries + } +} + +/// Aggregates relation entries. +/// TODO(Ohad): write a summarize function. +pub struct RelationTrackerEvaluator<'a> { + entries: Vec, + pub trace_eval: + &'a TreeVec>>, + pub column_index_per_interaction: Vec, + /// The row index of the simd-vector row to evaluate the constraints at. + pub vec_row: usize, + pub domain_log_size: u32, +} +impl<'a> RelationTrackerEvaluator<'a> { + pub fn new( + trace_eval: &'a TreeVec>>, + vec_row: usize, + domain_log_size: u32, + ) -> Self { + Self { + entries: vec![], + trace_eval, + column_index_per_interaction: vec![0; trace_eval.len()], + vec_row, + domain_log_size, + } + } + + pub fn summarize(self) -> Vec { + self.entries + } +} +impl<'a> EvalAtRow for RelationTrackerEvaluator<'a> { + type F = VeryPackedBaseField; + type EF = VeryPackedSecureField; + + // TODO(Ohad): Add debug boundary checks. + fn next_interaction_mask( + &mut self, + interaction: usize, + offsets: [isize; N], + ) -> [Self::F; N] { + assert_ne!(interaction, INTERACTION_TRACE_IDX); + let col_index = self.column_index_per_interaction[interaction]; + self.column_index_per_interaction[interaction] += 1; + offsets.map(|off| { + // If the offset is 0, we can just return the value directly from this row. + if off == 0 { + unsafe { + let col = &self + .trace_eval + .get_unchecked(interaction) + .get_unchecked(col_index) + .values; + let very_packed_col = VeryPackedBaseColumn::transform_under_ref(col); + return *very_packed_col.data.get_unchecked(self.vec_row); + }; + } + // Otherwise, we need to look up the value at the offset. + // Since the domain is bit-reversed circle domain ordered, we need to look up the value + // at the bit-reversed natural order index at an offset. + VeryPackedBaseField::from_array(std::array::from_fn(|i| { + let row_index = offset_bit_reversed_circle_domain_index( + (self.vec_row << (LOG_N_LANES + LOG_N_VERY_PACKED_ELEMS)) + i, + self.domain_log_size, + self.domain_log_size, + off, + ); + self.trace_eval[interaction][col_index].at(row_index) + })) + }) + } + fn add_constraint(&mut self, _constraint: G) {} + + fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { + VeryPackedSecureField::from_very_packed_m31s(values) + } + + fn write_logup_frac(&mut self, _fraction: Fraction) {} + + fn finalize_logup(&mut self) {} + + fn add_to_relation>( + &mut self, + entries: &[RelationEntry<'_, Self::F, Self::EF, R>], + ) { + for entry in entries { + let relation = entry.relation.get_name().to_owned(); + // Unpack VeryPacked. + let values: [Vec; N_VERY_PACKED_ELEMS] = std::array::from_fn(|i| { + entry + .values + .iter() + .map(|vectorized_value| vectorized_value.0[i]) + .collect() + }); + let mults: [PackedQM31; N_VERY_PACKED_ELEMS] = + std::array::from_fn(|i| entry.multiplicity.0[i]); + + for i in 0..N_VERY_PACKED_ELEMS { + let values = values + .iter() + .map(|v| v[i].into_simd().to_array()) + .collect_vec(); + let mult = mults[i].to_array(); + // Unpack SIMD. + for j in 0..N_LANES { + let values = values.iter().map(|v| v[j]).collect_vec(); + let mult = mult[j].to_m31_array()[0].0; + self.entries.push(RelationTrackerEntry { + relation: relation.clone(), + mult, + values, + }); + } + } + } + } +} diff --git a/crates/prover/src/core/pcs/utils.rs b/crates/prover/src/core/pcs/utils.rs index 36ef3a198..73c624f81 100644 --- a/crates/prover/src/core/pcs/utils.rs +++ b/crates/prover/src/core/pcs/utils.rs @@ -41,6 +41,13 @@ impl<'a, T> From<&'a TreeVec> for TreeVec<&'a T> { } } +/// Converts `&TreeVec<&Vec>` to `TreeVec>`. +impl<'a, T> From<&'a TreeVec<&'a Vec>> for TreeVec> { + fn from(val: &'a TreeVec<&'a Vec>) -> Self { + TreeVec(val.iter().map(|vec| vec.iter().collect()).collect()) + } +} + impl Deref for TreeVec { type Target = Vec; fn deref(&self) -> &Self::Target { diff --git a/crates/prover/src/examples/state_machine/components.rs b/crates/prover/src/examples/state_machine/components.rs index 2451eef23..af808a67c 100644 --- a/crates/prover/src/examples/state_machine/components.rs +++ b/crates/prover/src/examples/state_machine/components.rs @@ -1,16 +1,21 @@ use num_traits::{One, Zero}; use crate::constraint_framework::logup::ClaimedPrefixSum; +use crate::constraint_framework::relation_tracker::{ + RelationTracker, RelationTrackerComponent, RelationTrackerEntry, +}; use crate::constraint_framework::{ relation, EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator, RelationEntry, - PREPROCESSED_TRACE_IDX, + TraceLocationAllocator, PREPROCESSED_TRACE_IDX, }; use crate::core::air::{Component, ComponentProver}; use crate::core::backend::simd::SimdBackend; use crate::core::channel::Channel; -use crate::core::fields::m31::M31; +use crate::core::fields::m31::{BaseField, M31}; use crate::core::fields::qm31::{SecureField, QM31}; use crate::core::pcs::TreeVec; +use crate::core::poly::circle::CircleEvaluation; +use crate::core::poly::BitReversedOrder; use crate::core::prover::StarkProof; use crate::core::vcs::ops::MerkleHasher; @@ -124,6 +129,40 @@ impl StateMachineComponents { } } +pub fn track_state_machine_relations( + trace: &TreeVec<&Vec>>, + log_n_rows: u32, +) -> Vec { + let tree_span_provider = &mut TraceLocationAllocator::default(); + let mut entries = vec![]; + entries.extend( + RelationTrackerComponent::new( + tree_span_provider, + StateTransitionEval::<0> { + log_n_rows, + lookup_elements: StateMachineElements::dummy(), + total_sum: QM31::zero(), + claimed_sum: (QM31::zero(), 0), + }, + ) + .entries(&trace.into()), + ); + entries.extend( + RelationTrackerComponent::new( + tree_span_provider, + StateTransitionEval::<1> { + log_n_rows, + lookup_elements: StateMachineElements::dummy(), + total_sum: QM31::zero(), + claimed_sum: (QM31::zero(), 0), + }, + ) + .entries(&trace.into()), + ); + + entries +} + pub struct StateMachineProof { pub public_input: [State; 2], // Initial and final state. pub stmt0: StateMachineStatement0, diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 787394dd3..44caac446 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -3,9 +3,9 @@ pub mod components; pub mod gen; use components::{ - State, StateMachineComponents, StateMachineElements, StateMachineOp0Component, - StateMachineOp1Component, StateMachineProof, StateMachineStatement0, StateMachineStatement1, - StateTransitionEval, + track_state_machine_relations, State, StateMachineComponents, StateMachineElements, + StateMachineOp0Component, StateMachineOp1Component, StateMachineProof, StateMachineStatement0, + StateMachineStatement1, StateTransitionEval, }; use gen::{gen_interaction_trace, gen_trace}; use itertools::{chain, Itertools}; @@ -19,7 +19,7 @@ use crate::core::backend::simd::SimdBackend; use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::M31; use crate::core::fields::qm31::QM31; -use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig}; +use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec}; use crate::core::poly::circle::{CanonicCoset, PolyOps}; use crate::core::prover::{prove, verify, VerificationError}; use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; @@ -62,14 +62,21 @@ pub fn prove_state_machine( ]; // Preprocessed trace. - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(gen_preprocessed_columns(preprocessed_columns.iter())); - tree_builder.commit(channel); + let preprocessed_trace = gen_preprocessed_columns(preprocessed_columns.iter()); // Trace. let trace_op0 = gen_trace(x_axis_log_rows, initial_state, 0); let trace_op1 = gen_trace(y_axis_log_rows, intermediate_state, 1); + let trace = chain![trace_op0.clone(), trace_op1.clone()].collect_vec(); + + let _ = track_state_machine_relations(&TreeVec(vec![&preprocessed_trace, &trace]), log_n_rows); + + // Commitments. + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(preprocessed_trace); + tree_builder.commit(channel); + let stmt0 = StateMachineStatement0 { n: x_axis_log_rows, m: y_axis_log_rows,