From d98224c219aee5c8bae7f155923404bd263c535c Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Wed, 21 Aug 2024 22:53:03 -0400 Subject: [PATCH] Pass entire mask to components --- .../src/constraint_framework/component.rs | 114 ++++++++++--- crates/prover/src/constraint_framework/mod.rs | 2 +- .../prover/src/constraint_framework/point.rs | 4 +- crates/prover/src/core/air/components.rs | 72 ++------ crates/prover/src/core/air/mod.rs | 14 +- crates/prover/src/core/pcs/prover.rs | 4 +- crates/prover/src/core/pcs/utils.rs | 26 +++ crates/prover/src/core/prover/mod.rs | 104 +++++------- crates/prover/src/examples/blake/air.rs | 105 +++++++----- crates/prover/src/examples/blake/round/mod.rs | 40 +++-- .../examples/blake/scheduler/constraints.rs | 95 ++++++----- .../src/examples/blake/scheduler/mod.rs | 54 +++--- .../src/examples/blake/xor_table/mod.rs | 29 ++-- crates/prover/src/examples/plonk/mod.rs | 42 +++-- crates/prover/src/examples/poseidon/mod.rs | 155 +++++++++--------- .../wide_fibonacci/constraint_eval.rs | 6 +- .../prover/src/examples/wide_fibonacci/mod.rs | 4 +- .../src/examples/wide_fibonacci/simd.rs | 6 +- crates/prover/src/trace_generation/prove.rs | 15 +- 19 files changed, 497 insertions(+), 394 deletions(-) diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index 5e20b3db7..7b9459641 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -1,11 +1,13 @@ use std::borrow::Cow; +use std::collections::BTreeMap; +use std::ops::Deref; use itertools::Itertools; use tracing::{span, Level}; use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator}; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; -use crate::core::air::{Component, ComponentProver, ComponentTrace}; +use crate::core::air::{Component, ComponentProver, Trace}; use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS}; @@ -15,36 +17,89 @@ use crate::core::constraints::coset_vanishing; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; -use crate::core::pcs::TreeVec; +use crate::core::pcs::{TreeColumnSpan, TreeVec}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; use crate::core::{utils, ColumnVec, InteractionElements, LookupValues}; +// TODO(andrew): Docs. +// TODO(andrew): Consider better location for this. +#[derive(Debug, Default)] +pub struct TreeColumnSpanProvider { + /// Mapping of tree index to next available column offset. + next_tree_offsets: BTreeMap, +} + +impl TreeColumnSpanProvider { + fn next_for_structure(&mut self, structure: &TreeVec>) -> Vec { + structure + .iter() + .enumerate() + .filter_map(|(tree_index, tree)| { + if tree.is_empty() { + return None; + } + + let n_columns = tree.len(); + let next_tree_offset = self.next_tree_offsets.entry(tree_index).or_default(); + let col_start = *next_tree_offset; + let col_end = col_start + n_columns; + *next_tree_offset = col_end; + + Some(TreeColumnSpan { + tree_index, + col_start, + col_end, + }) + }) + .collect() + } +} + /// A component defined solely in means of the constraints framework. -/// Implementing this trait introduces implementations for [Component] and [ComponentProver] for the -/// SIMD backend. +/// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for +/// the SIMD backend. /// Note that the constraint framework only support components with columns of the same size. -pub trait FrameworkComponent { +pub trait FrameworkEval { fn log_size(&self) -> u32; + fn max_constraint_log_degree_bound(&self) -> u32; + fn evaluate(&self, eval: E) -> E; } -impl Component for C { +pub struct FrameworkComponentImpl { + eval: C, + trace_locations: Vec, +} + +impl FrameworkComponentImpl { + pub fn new(provider: &mut TreeColumnSpanProvider, eval: E) -> Self { + let eval_tree_structure = eval.evaluate(InfoEvaluator::default()).mask_offsets; + let trace_locations = provider.next_for_structure(&eval_tree_structure); + Self { + eval, + trace_locations, + } + } +} + +impl Component for FrameworkComponentImpl { fn n_constraints(&self) -> usize { - self.evaluate(InfoEvaluator::default()).n_constraints + self.eval.evaluate(InfoEvaluator::default()).n_constraints } fn max_constraint_log_degree_bound(&self) -> u32 { - FrameworkComponent::max_constraint_log_degree_bound(self) + self.eval.max_constraint_log_degree_bound() } fn trace_log_degree_bounds(&self) -> TreeVec> { TreeVec::new( - self.evaluate(InfoEvaluator::default()) + self.eval + .evaluate(InfoEvaluator::default()) .mask_offsets .iter() - .map(|tree_masks| vec![self.log_size(); tree_masks.len()]) + .map(|tree_masks| vec![self.eval.log_size(); tree_masks.len()]) .collect(), ) } @@ -53,8 +108,8 @@ impl Component for C { &self, point: CirclePoint, ) -> TreeVec>>> { - let info = self.evaluate(InfoEvaluator::default()); - let trace_step = CanonicCoset::new(self.log_size()).step(); + let info = self.eval.evaluate(InfoEvaluator::default()); + let trace_step = CanonicCoset::new(self.eval.log_size()).step(); info.mask_offsets.map_cols(|col_mask| { col_mask .iter() @@ -71,30 +126,32 @@ impl Component for C { _interaction_elements: &InteractionElements, _lookup_values: &LookupValues, ) { - self.evaluate(PointEvaluator::new( - mask.as_ref(), + self.eval.evaluate(PointEvaluator::new( + mask.sub_tree(&self.trace_locations), evaluation_accumulator, - coset_vanishing(CanonicCoset::new(self.log_size()).coset, point).inverse(), + coset_vanishing(CanonicCoset::new(self.eval.log_size()).coset, point).inverse(), )); } } -impl ComponentProver for C { +impl ComponentProver for FrameworkComponentImpl { fn evaluate_constraint_quotients_on_domain( &self, - trace: &ComponentTrace<'_, SimdBackend>, + trace: &Trace<'_, SimdBackend>, evaluation_accumulator: &mut DomainEvaluationAccumulator, _interaction_elements: &InteractionElements, _lookup_values: &LookupValues, ) { let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain(); - let trace_domain = CanonicCoset::new(self.log_size()); + let trace_domain = CanonicCoset::new(self.eval.log_size()); + + let component_polys = trace.polys.sub_tree(&self.trace_locations); + let component_evals = trace.evals.sub_tree(&self.trace_locations); // Extend trace if necessary. // TODO(spapini): Don't extend when eval_size < committed_size. Instead, pick a good // subdomain. - let need_to_extend = trace - .evals + let need_to_extend = component_evals .iter() .flatten() .any(|c| c.domain != eval_domain); @@ -103,12 +160,11 @@ impl ComponentProver for C { > = if need_to_extend { let _span = span!(Level::INFO, "Extension").entered(); let twiddles = SimdBackend::precompute_twiddles(eval_domain.half_coset); - trace - .polys + component_polys .as_cols_ref() .map_cols(|col| Cow::Owned(col.evaluate_with_twiddles(eval_domain, &twiddles))) } else { - trace.evals.as_cols_ref().map_cols(|c| Cow::Borrowed(*c)) + component_evals.clone().map_cols(|c| Cow::Borrowed(*c)) }; // Denom inverses. @@ -137,7 +193,7 @@ impl ComponentProver for C { trace_domain.log_size(), eval_domain.log_size(), ); - let row_res = self.evaluate(eval).row_res; + let row_res = self.eval.evaluate(eval).row_res; // Finalize row. unsafe { @@ -150,7 +206,15 @@ impl ComponentProver for C { } } - fn lookup_values(&self, _trace: &ComponentTrace<'_, SimdBackend>) -> LookupValues { + fn lookup_values(&self, _trace: &Trace<'_, SimdBackend>) -> LookupValues { LookupValues::default() } } + +impl Deref for FrameworkComponentImpl { + type Target = E; + + fn deref(&self) -> &E { + &self.eval + } +} diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index f0d6ca9be..3087df05d 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -12,7 +12,7 @@ use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, Neg, Sub}; pub use assert::{assert_constraints, AssertEvaluator}; -pub use component::FrameworkComponent; +pub use component::{FrameworkComponentImpl, FrameworkEval, TreeColumnSpanProvider}; pub use info::InfoEvaluator; use num_traits::{One, Zero}; pub use point::PointEvaluator; diff --git a/crates/prover/src/constraint_framework/point.rs b/crates/prover/src/constraint_framework/point.rs index 5bbdb778d..6c6f72f81 100644 --- a/crates/prover/src/constraint_framework/point.rs +++ b/crates/prover/src/constraint_framework/point.rs @@ -9,14 +9,14 @@ use crate::core::ColumnVec; /// Evaluates expressions at a point out of domain. pub struct PointEvaluator<'a> { - pub mask: TreeVec<&'a ColumnVec>>, + pub mask: TreeVec>>, pub evaluation_accumulator: &'a mut PointEvaluationAccumulator, pub col_index: Vec, pub denom_inverse: SecureField, } impl<'a> PointEvaluator<'a> { pub fn new( - mask: TreeVec<&'a ColumnVec>>, + mask: TreeVec>>, evaluation_accumulator: &'a mut PointEvaluationAccumulator, denom_inverse: SecureField, ) -> Self { diff --git a/crates/prover/src/core/air/components.rs b/crates/prover/src/core/air/components.rs index 9b3d1200e..397468616 100644 --- a/crates/prover/src/core/air/components.rs +++ b/crates/prover/src/core/air/components.rs @@ -1,12 +1,11 @@ -use itertools::{zip_eq, Itertools}; +use itertools::Itertools; use super::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; -use super::{Component, ComponentProver, ComponentTrace}; -use crate::core::backend::{Backend, BackendForChannel}; -use crate::core::channel::MerkleChannel; +use super::{Component, ComponentProver, Trace}; +use crate::core::backend::Backend; use crate::core::circle::CirclePoint; use crate::core::fields::qm31::SecureField; -use crate::core::pcs::{CommitmentTreeProver, TreeVec}; +use crate::core::pcs::TreeVec; use crate::core::poly::circle::SecureCirclePoly; use crate::core::{ColumnVec, InteractionElements, LookupValues}; @@ -31,21 +30,21 @@ impl<'a> Components<'a> { pub fn eval_composition_polynomial_at_point( &self, point: CirclePoint, - mask_values: &Vec>>>, + mask_values: &TreeVec>>, random_coeff: SecureField, interaction_elements: &InteractionElements, lookup_values: &LookupValues, ) -> SecureField { let mut evaluation_accumulator = PointEvaluationAccumulator::new(random_coeff); - zip_eq(&self.0, mask_values).for_each(|(component, mask)| { + for component in &self.0 { component.evaluate_constraint_quotients_at_point( point, - mask, + mask_values, &mut evaluation_accumulator, interaction_elements, lookup_values, ) - }); + } evaluation_accumulator.finalize() } @@ -67,7 +66,7 @@ impl<'a, B: Backend> ComponentProvers<'a, B> { pub fn compute_composition_polynomial( &self, random_coeff: SecureField, - component_traces: &[ComponentTrace<'_, B>], + trace: &Trace<'_, B>, interaction_elements: &InteractionElements, lookup_values: &LookupValues, ) -> SecureCirclePoly { @@ -77,63 +76,22 @@ impl<'a, B: Backend> ComponentProvers<'a, B> { self.components().composition_log_degree_bound(), total_constraints, ); - zip_eq(&self.0, component_traces).for_each(|(component, trace)| { + for component in &self.0 { component.evaluate_constraint_quotients_on_domain( trace, &mut accumulator, interaction_elements, lookup_values, ) - }); + } accumulator.finalize() } - pub fn component_traces<'b, MC: MerkleChannel>( - &'b self, - trees: &'b [CommitmentTreeProver], - ) -> Vec> - where - B: BackendForChannel, - { - let mut poly_iters = trees - .iter() - .map(|tree| tree.polynomials.iter()) - .collect_vec(); - let mut eval_iters = trees - .iter() - .map(|tree| tree.evaluations.iter()) - .collect_vec(); - - self.0 - .iter() - .map(|component| { - let col_sizes_per_tree = component - .trace_log_degree_bounds() - .iter() - .map(|col_sizes| col_sizes.len()) - .collect_vec(); - let polys = col_sizes_per_tree - .iter() - .zip(poly_iters.iter_mut()) - .map(|(n_columns, iter)| iter.take(*n_columns).collect_vec()) - .collect_vec(); - let evals = col_sizes_per_tree - .iter() - .zip(eval_iters.iter_mut()) - .map(|(n_columns, iter)| iter.take(*n_columns).collect_vec()) - .collect_vec(); - ComponentTrace { - polys: TreeVec::new(polys), - evals: TreeVec::new(evals), - } - }) - .collect_vec() - } - - pub fn lookup_values(&self, component_traces: &[ComponentTrace<'_, B>]) -> LookupValues { + pub fn lookup_values(&self, trace: &Trace<'_, B>) -> LookupValues { let mut values = LookupValues::default(); - zip_eq(&self.0, component_traces) - .for_each(|(component, trace)| values.extend(component.lookup_values(trace))); + for component in &self.0 { + values.extend(component.lookup_values(trace)) + } values } } diff --git a/crates/prover/src/core/air/mod.rs b/crates/prover/src/core/air/mod.rs index efd2d23c5..0e2a71750 100644 --- a/crates/prover/src/core/air/mod.rs +++ b/crates/prover/src/core/air/mod.rs @@ -62,14 +62,24 @@ pub trait ComponentProver: Component { /// Accumulates quotients in `evaluation_accumulator`. fn evaluate_constraint_quotients_on_domain( &self, - trace: &ComponentTrace<'_, B>, + trace: &Trace<'_, B>, evaluation_accumulator: &mut DomainEvaluationAccumulator, interaction_elements: &InteractionElements, lookup_values: &LookupValues, ); /// Returns the values needed to evaluate the components lookup boundary constraints. - fn lookup_values(&self, _trace: &ComponentTrace<'_, B>) -> LookupValues; + fn lookup_values(&self, _trace: &Trace<'_, B>) -> LookupValues; +} + +/// The set of polynomials that make up the trace. +/// +/// Each polynomial is stored both in a coefficients, and evaluations form (for efficiency) +pub struct Trace<'a, B: Backend> { + /// Polynomials for each column. + pub polys: TreeVec>>, + /// Evaluations for each column (evaluated on their commitment domains). + pub evals: TreeVec>>, } /// A component trace is a set of polynomials for each column on that component. diff --git a/crates/prover/src/core/pcs/prover.rs b/crates/prover/src/core/pcs/prover.rs index 425b4581a..90612da33 100644 --- a/crates/prover/src/core/pcs/prover.rs +++ b/crates/prover/src/core/pcs/prover.rs @@ -66,7 +66,9 @@ impl<'a, B: BackendForChannel, MC: MerkleChannel> CommitmentSchemeProver<'a, .map(|tree| tree.polynomials.iter().collect()) } - fn evaluations(&self) -> TreeVec>> { + pub fn evaluations( + &self, + ) -> TreeVec>> { self.trees .as_ref() .map(|tree| tree.evaluations.iter().collect()) diff --git a/crates/prover/src/core/pcs/utils.rs b/crates/prover/src/core/pcs/utils.rs index bd1c6f9ca..9961f36e9 100644 --- a/crates/prover/src/core/pcs/utils.rs +++ b/crates/prover/src/core/pcs/utils.rs @@ -1,8 +1,10 @@ +use std::collections::BTreeSet; use std::ops::{Deref, DerefMut}; use itertools::zip_eq; use serde::{Deserialize, Serialize}; +use super::TreeColumnSpan; use crate::core::ColumnVec; /// A container that holds an element for each commitment tree. @@ -110,6 +112,30 @@ impl TreeVec> { } result } + + pub fn get_chunk(&self, location: TreeColumnSpan) -> Option> { + let tree = self.0.get(location.tree_index)?; + let chunk = tree.get(location.col_start..location.col_end)?; + Some(chunk.iter().collect()) + } + + /// # Panics + /// + /// If two or more locations have the same tree index. + pub fn sub_tree(&self, locations: &[TreeColumnSpan]) -> TreeVec> { + let tree_indicies: BTreeSet = locations.iter().map(|l| l.tree_index).collect(); + assert_eq!(tree_indicies.len(), locations.len()); + let max_tree_index = tree_indicies.iter().max().unwrap_or(&0); + let mut res = TreeVec(vec![Vec::new(); max_tree_index + 1]); + + for &location in locations { + // TODO(andrew): Throwing error here might be better instead. + let chunk = self.get_chunk(location).unwrap(); + res[location.tree_index] = chunk; + } + + res + } } impl<'a, T> From<&'a TreeVec>> for TreeVec> { diff --git a/crates/prover/src/core/prover/mod.rs b/crates/prover/src/core/prover/mod.rs index b4cf0c537..a924f1e4a 100644 --- a/crates/prover/src/core/prover/mod.rs +++ b/crates/prover/src/core/prover/mod.rs @@ -1,16 +1,15 @@ -use itertools::Itertools; use serde::{Deserialize, Serialize}; use thiserror::Error; use tracing::{span, Level}; -use super::air::{Component, ComponentProver, ComponentProvers, Components}; +use super::air::{Component, ComponentProver, ComponentProvers, Components, Trace}; use super::backend::BackendForChannel; use super::channel::MerkleChannel; use super::fields::secure_column::SECURE_EXTENSION_DEGREE; use super::fri::FriVerificationError; use super::pcs::{CommitmentSchemeProof, TreeVec}; use super::vcs::ops::MerkleHasher; -use super::{ColumnVec, InteractionElements, LookupValues}; +use super::{InteractionElements, LookupValues}; use crate::core::backend::CpuBackend; use crate::core::channel::Channel; use crate::core::circle::CirclePoint; @@ -41,9 +40,13 @@ pub fn prove, MC: MerkleChannel>( interaction_elements: &InteractionElements, commitment_scheme: &mut CommitmentSchemeProver<'_, B, MC>, ) -> Result, ProvingError> { + let trace = Trace { + polys: commitment_scheme.polynomials(), + evals: commitment_scheme.evaluations(), + }; + let component_provers = ComponentProvers(components.to_vec()); - let component_traces = component_provers.component_traces(&commitment_scheme.trees); - let lookup_values = component_provers.lookup_values(&component_traces); + let lookup_values = component_provers.lookup_values(&trace); // Evaluate and commit on composition polynomial. let random_coeff = channel.draw_felt(); @@ -52,7 +55,7 @@ pub fn prove, MC: MerkleChannel>( let span1 = span!(Level::INFO, "Generation").entered(); let composition_polynomial_poly = component_provers.compute_composition_polynomial( random_coeff, - &component_traces, + &trace, interaction_elements, &lookup_values, ); @@ -74,21 +77,18 @@ pub fn prove, MC: MerkleChannel>( // Prove the trace and composition OODS values, and retrieve them. let commitment_scheme_proof = commitment_scheme.prove_values(sample_points, channel); - // Evaluate composition polynomial at OODS point and check that it matches the trace OODS - // values. This is a sanity check. // TODO(spapini): Save clone. - let (trace_oods_values, composition_oods_value) = sampled_values_to_mask( - &component_provers.components(), - &commitment_scheme_proof.sampled_values, - ) - .unwrap(); + let sampled_oods_values = &commitment_scheme_proof.sampled_values; + let composition_oods_eval = extract_composition_eval(sampled_oods_values).unwrap(); - if composition_oods_value + // Evaluate composition polynomial at OODS point and check that it matches the trace OODS + // values. This is a sanity check. + if composition_oods_eval != component_provers .components() .eval_composition_polynomial_at_point( oods_point, - &trace_oods_values, + sampled_oods_values, random_coeff, interaction_elements, &lookup_values, @@ -130,18 +130,15 @@ pub fn verify( sample_points.push(vec![vec![oods_point]; SECURE_EXTENSION_DEGREE]); // TODO(spapini): Save clone. - let (trace_oods_values, composition_oods_value) = - sampled_values_to_mask(&components, &proof.commitment_scheme_proof.sampled_values) - .map_err(|_| { - VerificationError::InvalidStructure( - "Unexpected sampled_values structure".to_string(), - ) - })?; - - if composition_oods_value + let sampled_oods_values = &proof.commitment_scheme_proof.sampled_values; + let composition_oods_eval = extract_composition_eval(sampled_oods_values).map_err(|_| { + VerificationError::InvalidStructure("Unexpected sampled_values structure".to_string()) + })?; + + if composition_oods_eval != components.eval_composition_polynomial_at_point( oods_point, - &trace_oods_values, + &proof.commitment_scheme_proof.sampled_values, random_coeff, interaction_elements, &proof.lookup_values, @@ -153,41 +150,30 @@ pub fn verify( commitment_scheme.verify_values(sample_points, proof.commitment_scheme_proof, channel) } -#[allow(clippy::type_complexity)] -/// Structures the tree-wise sampled values into component-wise OODS values and a composition -/// polynomial OODS value. -fn sampled_values_to_mask( - components: &Components<'_>, - sampled_values: &TreeVec>>, -) -> Result<(Vec>>>, SecureField), InvalidOodsSampleStructure> { - let mut sampled_values = sampled_values.as_ref(); - let composition_values = sampled_values.pop().ok_or(InvalidOodsSampleStructure)?; - - let mut sample_iters = sampled_values.map(|tree_value| tree_value.iter()); - let trace_oods_values = components - .0 - .iter() - .map(|component| { - component - .mask_points(CirclePoint::zero()) - .zip(sample_iters.as_mut()) - .map(|(mask_per_tree, tree_iter)| { - tree_iter.take(mask_per_tree.len()).cloned().collect_vec() - }) - }) - .collect_vec(); - - let composition_oods_value = SecureField::from_partial_evals( - composition_values - .iter() - .flatten() - .cloned() - .collect_vec() - .try_into() - .map_err(|_| InvalidOodsSampleStructure)?, - ); +/// Extracts the composition trace evaluation from the mask. +fn extract_composition_eval( + mask: &TreeVec>>, +) -> Result { + let mut composition_cols = mask.last().into_iter().flatten(); + + let col0 = &**composition_cols.next().ok_or(InvalidOodsSampleStructure)?; + let col1 = &**composition_cols.next().ok_or(InvalidOodsSampleStructure)?; + let col2 = &**composition_cols.next().ok_or(InvalidOodsSampleStructure)?; + let col3 = &**composition_cols.next().ok_or(InvalidOodsSampleStructure)?; + + // Too many columns. + if composition_cols.next().is_some() { + return Err(InvalidOodsSampleStructure); + } + + let [eval0] = col0.try_into().map_err(|_| InvalidOodsSampleStructure)?; + let [eval1] = col1.try_into().map_err(|_| InvalidOodsSampleStructure)?; + let [eval2] = col2.try_into().map_err(|_| InvalidOodsSampleStructure)?; + let [eval3] = col3.try_into().map_err(|_| InvalidOodsSampleStructure)?; - Ok((trace_oods_values, composition_oods_value)) + Ok(SecureField::from_partial_evals([ + eval0, eval1, eval2, eval3, + ])) } /// Error when the sampled values have an invalid structure. diff --git a/crates/prover/src/examples/blake/air.rs b/crates/prover/src/examples/blake/air.rs index 809c2ccde..297e192a6 100644 --- a/crates/prover/src/examples/blake/air.rs +++ b/crates/prover/src/examples/blake/air.rs @@ -5,9 +5,10 @@ use num_traits::Zero; use serde::Serialize; use tracing::{span, Level}; -use super::round::{blake_round_info, BlakeRoundComponent}; -use super::scheduler::BlakeSchedulerComponent; -use super::xor_table::XorTableComponent; +use super::round::{blake_round_info, BlakeRoundComponent, BlakeRoundEval}; +use super::scheduler::{BlakeSchedulerComponent, BlakeSchedulerEval}; +use super::xor_table::{XorTableComponent, XorTableEval}; +use crate::constraint_framework::{FrameworkComponentImpl, TreeColumnSpanProvider}; use crate::core::air::{Component, ComponentProver}; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::SimdBackend; @@ -111,52 +112,76 @@ pub struct BlakeProof { pub struct BlakeComponents { scheduler_component: BlakeSchedulerComponent, - round_components: Vec, - xor12: XorTableComponent<12, 4>, - xor9: XorTableComponent<9, 2>, - xor8: XorTableComponent<8, 2>, - xor7: XorTableComponent<7, 2>, - xor4: XorTableComponent<4, 0>, + round_components: Vec>, + xor12: FrameworkComponentImpl>, + xor9: FrameworkComponentImpl>, + xor8: FrameworkComponentImpl>, + xor7: FrameworkComponentImpl>, + xor4: FrameworkComponentImpl>, } impl BlakeComponents { fn new(stmt0: &BlakeStatement0, all_elements: &AllElements, stmt1: &BlakeStatement1) -> Self { + let tree_span_provider = &mut TreeColumnSpanProvider::default(); Self { - scheduler_component: BlakeSchedulerComponent { - log_size: stmt0.log_size, - blake_lookup_elements: all_elements.blake_elements.clone(), - round_lookup_elements: all_elements.round_elements.clone(), - claimed_sum: stmt1.scheduler_claimed_sum, - }, + scheduler_component: BlakeSchedulerComponent::new( + tree_span_provider, + BlakeSchedulerEval { + log_size: stmt0.log_size, + blake_lookup_elements: all_elements.blake_elements.clone(), + round_lookup_elements: all_elements.round_elements.clone(), + claimed_sum: stmt1.scheduler_claimed_sum, + }, + ), round_components: ROUND_LOG_SPLIT .iter() .zip(stmt1.round_claimed_sums.clone()) - .map(|(l, claimed_sum)| BlakeRoundComponent { - log_size: stmt0.log_size + l, - xor_lookup_elements: all_elements.xor_elements.clone(), - round_lookup_elements: all_elements.round_elements.clone(), - claimed_sum, + .map(|(l, claimed_sum)| { + BlakeRoundComponent::new( + tree_span_provider, + BlakeRoundEval { + log_size: stmt0.log_size + l, + xor_lookup_elements: all_elements.xor_elements.clone(), + round_lookup_elements: all_elements.round_elements.clone(), + claimed_sum, + }, + ) }) .collect(), - xor12: XorTableComponent { - lookup_elements: all_elements.xor_elements.xor12.clone(), - claimed_sum: stmt1.xor12_claimed_sum, - }, - xor9: XorTableComponent { - lookup_elements: all_elements.xor_elements.xor9.clone(), - claimed_sum: stmt1.xor9_claimed_sum, - }, - xor8: XorTableComponent { - lookup_elements: all_elements.xor_elements.xor8.clone(), - claimed_sum: stmt1.xor8_claimed_sum, - }, - xor7: XorTableComponent { - lookup_elements: all_elements.xor_elements.xor7.clone(), - claimed_sum: stmt1.xor7_claimed_sum, - }, - xor4: XorTableComponent { - lookup_elements: all_elements.xor_elements.xor4.clone(), - claimed_sum: stmt1.xor4_claimed_sum, - }, + xor12: XorTableComponent::new( + tree_span_provider, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor12.clone(), + claimed_sum: stmt1.xor12_claimed_sum, + }, + ), + xor9: XorTableComponent::new( + tree_span_provider, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor9.clone(), + claimed_sum: stmt1.xor9_claimed_sum, + }, + ), + xor8: XorTableComponent::new( + tree_span_provider, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor8.clone(), + claimed_sum: stmt1.xor8_claimed_sum, + }, + ), + xor7: XorTableComponent::new( + tree_span_provider, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor7.clone(), + claimed_sum: stmt1.xor7_claimed_sum, + }, + ), + xor4: XorTableComponent::new( + tree_span_provider, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor4.clone(), + claimed_sum: stmt1.xor4_claimed_sum, + }, + ), } } fn components(&self) -> Vec<&dyn Component> { diff --git a/crates/prover/src/examples/blake/round/mod.rs b/crates/prover/src/examples/blake/round/mod.rs index c5123b5ce..d03397c23 100644 --- a/crates/prover/src/examples/blake/round/mod.rs +++ b/crates/prover/src/examples/blake/round/mod.rs @@ -1,34 +1,28 @@ mod constraints; mod gen; -use constraints::BlakeRoundEval; +pub use gen::{generate_interaction_trace, generate_trace, BlakeRoundInput}; use num_traits::Zero; -pub use r#gen::{generate_interaction_trace, generate_trace, BlakeRoundInput}; use super::{BlakeXorElements, N_ROUND_INPUT_FELTS}; use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; -use crate::constraint_framework::{EvalAtRow, FrameworkComponent, InfoEvaluator}; +use crate::constraint_framework::{ + EvalAtRow, FrameworkComponentImpl, FrameworkEval, InfoEvaluator, +}; use crate::core::fields::qm31::SecureField; -pub fn blake_round_info() -> InfoEvaluator { - let component = BlakeRoundComponent { - log_size: 1, - xor_lookup_elements: BlakeXorElements::dummy(), - round_lookup_elements: RoundElements::dummy(), - claimed_sum: SecureField::zero(), - }; - component.evaluate(InfoEvaluator::default()) -} +pub type BlakeRoundComponent = FrameworkComponentImpl; pub type RoundElements = LookupElements; -pub struct BlakeRoundComponent { + +pub struct BlakeRoundEval { pub log_size: u32, pub xor_lookup_elements: BlakeXorElements, pub round_lookup_elements: RoundElements, pub claimed_sum: SecureField, } -impl FrameworkComponent for BlakeRoundComponent { +impl FrameworkEval for BlakeRoundEval { fn log_size(&self) -> u32 { self.log_size } @@ -36,7 +30,7 @@ impl FrameworkComponent for BlakeRoundComponent { self.log_size + 1 } fn evaluate(&self, eval: E) -> E { - let blake_eval = BlakeRoundEval { + let blake_eval = constraints::BlakeRoundEval { eval, xor_lookup_elements: &self.xor_lookup_elements, round_lookup_elements: &self.round_lookup_elements, @@ -46,6 +40,16 @@ impl FrameworkComponent for BlakeRoundComponent { } } +pub fn blake_round_info() -> InfoEvaluator { + let component = BlakeRoundEval { + log_size: 1, + xor_lookup_elements: BlakeXorElements::dummy(), + round_lookup_elements: RoundElements::dummy(), + claimed_sum: SecureField::zero(), + }; + component.evaluate(InfoEvaluator::default()) +} + #[cfg(test)] mod tests { use std::simd::Simd; @@ -53,12 +57,12 @@ mod tests { use itertools::Itertools; use crate::constraint_framework::constant_columns::gen_is_first; - use crate::constraint_framework::FrameworkComponent; + use crate::constraint_framework::FrameworkEval; use crate::core::poly::circle::CanonicCoset; use crate::examples::blake::round::r#gen::{ generate_interaction_trace, generate_trace, BlakeRoundInput, }; - use crate::examples::blake::round::{BlakeRoundComponent, RoundElements}; + use crate::examples::blake::round::{BlakeRoundEval, RoundElements}; use crate::examples::blake::{BlakeXorElements, XorAccums}; #[test] @@ -91,7 +95,7 @@ mod tests { let trace = TreeVec::new(vec![trace, interaction_trace, vec![gen_is_first(LOG_SIZE)]]); let trace_polys = trace.map_cols(|c| c.interpolate()); - let component = BlakeRoundComponent { + let component = BlakeRoundEval { log_size: LOG_SIZE, xor_lookup_elements, round_lookup_elements, diff --git a/crates/prover/src/examples/blake/scheduler/constraints.rs b/crates/prover/src/examples/blake/scheduler/constraints.rs index 9d5057fed..63b3cf696 100644 --- a/crates/prover/src/examples/blake/scheduler/constraints.rs +++ b/crates/prover/src/examples/blake/scheduler/constraints.rs @@ -8,60 +8,57 @@ use crate::core::vcs::blake2s_ref::SIGMA; use crate::examples::blake::round::RoundElements; use crate::examples::blake::{Fu32, N_ROUNDS, STATE_SIZE}; -pub struct BlakeSchedulerEval<'a, E: EvalAtRow> { - pub eval: E, - pub blake_lookup_elements: &'a BlakeElements, - pub round_lookup_elements: &'a RoundElements, - pub logup: LogupAtRow<2, E>, -} -impl<'a, E: EvalAtRow> BlakeSchedulerEval<'a, E> { - pub fn eval(mut self) -> E { - let messages: [Fu32; STATE_SIZE] = std::array::from_fn(|_| self.next_u32()); - let states: [[Fu32; STATE_SIZE]; N_ROUNDS + 1] = - std::array::from_fn(|_| std::array::from_fn(|_| self.next_u32())); - - // Schedule. - for i in 0..N_ROUNDS { - let input_state = &states[i]; - let output_state = &states[i + 1]; - let round_messages = SIGMA[i].map(|j| messages[j as usize]); - // Use triplet in round lookup. - self.logup.push_lookup( - &mut self.eval, - E::EF::one(), - &chain![ - input_state.iter().copied().flat_map(Fu32::to_felts), - output_state.iter().copied().flat_map(Fu32::to_felts), - round_messages.iter().copied().flat_map(Fu32::to_felts) - ] - .collect_vec(), - self.round_lookup_elements, - ) - } - - let input_state = &states[0]; - let output_state = &states[N_ROUNDS]; +pub fn eval_blake_scheduler_constraints( + eval: &mut E, + blake_lookup_elements: &BlakeElements, + round_lookup_elements: &RoundElements, + mut logup: LogupAtRow<2, E>, +) { + let messages: [Fu32; STATE_SIZE] = std::array::from_fn(|_| eval_next_u32(eval)); + let states: [[Fu32; STATE_SIZE]; N_ROUNDS + 1] = + std::array::from_fn(|_| std::array::from_fn(|_| eval_next_u32(eval))); - // TODO(spapini): Support multiplicities. - // TODO(spapini): Change to -1. - self.logup.push_lookup( - &mut self.eval, - E::EF::zero(), + // Schedule. + for i in 0..N_ROUNDS { + let input_state = &states[i]; + let output_state = &states[i + 1]; + let round_messages = SIGMA[i].map(|j| messages[j as usize]); + // Use triplet in round lookup. + logup.push_lookup( + eval, + E::EF::one(), &chain![ input_state.iter().copied().flat_map(Fu32::to_felts), output_state.iter().copied().flat_map(Fu32::to_felts), - messages.iter().copied().flat_map(Fu32::to_felts) + round_messages.iter().copied().flat_map(Fu32::to_felts) ] .collect_vec(), - self.blake_lookup_elements, - ); - - self.logup.finalize(&mut self.eval); - self.eval - } - fn next_u32(&mut self) -> Fu32 { - let l = self.eval.next_trace_mask(); - let h = self.eval.next_trace_mask(); - Fu32 { l, h } + round_lookup_elements, + ) } + + let input_state = &states[0]; + let output_state = &states[N_ROUNDS]; + + // TODO(spapini): Support multiplicities. + // TODO(spapini): Change to -1. + logup.push_lookup( + eval, + E::EF::zero(), + &chain![ + input_state.iter().copied().flat_map(Fu32::to_felts), + output_state.iter().copied().flat_map(Fu32::to_felts), + messages.iter().copied().flat_map(Fu32::to_felts) + ] + .collect_vec(), + blake_lookup_elements, + ); + + logup.finalize(eval); +} + +fn eval_next_u32(eval: &mut E) -> Fu32 { + let l = eval.next_trace_mask(); + let h = eval.next_trace_mask(); + Fu32 { l, h } } diff --git a/crates/prover/src/examples/blake/scheduler/mod.rs b/crates/prover/src/examples/blake/scheduler/mod.rs index 116053246..4141f58d8 100644 --- a/crates/prover/src/examples/blake/scheduler/mod.rs +++ b/crates/prover/src/examples/blake/scheduler/mod.rs @@ -1,63 +1,67 @@ mod constraints; mod gen; -use constraints::BlakeSchedulerEval; +use constraints::eval_blake_scheduler_constraints; pub use gen::{gen_interaction_trace, gen_trace, BlakeInput}; use num_traits::Zero; use super::round::RoundElements; use super::N_ROUND_INPUT_FELTS; use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; -use crate::constraint_framework::{EvalAtRow, FrameworkComponent, InfoEvaluator}; +use crate::constraint_framework::{ + EvalAtRow, FrameworkComponentImpl, FrameworkEval, InfoEvaluator, +}; use crate::core::fields::qm31::SecureField; -pub type BlakeElements = LookupElements; +pub type BlakeSchedulerComponent = FrameworkComponentImpl; -pub fn blake_scheduler_info() -> InfoEvaluator { - let component = BlakeSchedulerComponent { - log_size: 1, - blake_lookup_elements: BlakeElements::dummy(), - round_lookup_elements: RoundElements::dummy(), - claimed_sum: SecureField::zero(), - }; - component.evaluate(InfoEvaluator::default()) -} +pub type BlakeElements = LookupElements; -pub struct BlakeSchedulerComponent { +pub struct BlakeSchedulerEval { pub log_size: u32, pub blake_lookup_elements: BlakeElements, pub round_lookup_elements: RoundElements, pub claimed_sum: SecureField, } -impl FrameworkComponent for BlakeSchedulerComponent { +impl FrameworkEval for BlakeSchedulerEval { fn log_size(&self) -> u32 { self.log_size } fn max_constraint_log_degree_bound(&self) -> u32 { self.log_size + 1 } - fn evaluate(&self, eval: E) -> E { - let blake_eval = BlakeSchedulerEval { - eval, - blake_lookup_elements: &self.blake_lookup_elements, - round_lookup_elements: &self.round_lookup_elements, - logup: LogupAtRow::new(1, self.claimed_sum, self.log_size), - }; - blake_eval.eval() + fn evaluate(&self, mut eval: E) -> E { + eval_blake_scheduler_constraints( + &mut eval, + &self.blake_lookup_elements, + &self.round_lookup_elements, + LogupAtRow::new(1, self.claimed_sum, self.log_size), + ); + eval } } +pub fn blake_scheduler_info() -> InfoEvaluator { + let component = BlakeSchedulerEval { + log_size: 1, + blake_lookup_elements: BlakeElements::dummy(), + round_lookup_elements: RoundElements::dummy(), + claimed_sum: SecureField::zero(), + }; + component.evaluate(InfoEvaluator::default()) +} + #[cfg(test)] mod tests { use std::simd::Simd; use itertools::Itertools; - use crate::constraint_framework::FrameworkComponent; + use crate::constraint_framework::FrameworkEval; use crate::core::poly::circle::CanonicCoset; use crate::examples::blake::round::RoundElements; use crate::examples::blake::scheduler::r#gen::{gen_interaction_trace, gen_trace, BlakeInput}; - use crate::examples::blake::scheduler::{BlakeElements, BlakeSchedulerComponent}; + use crate::examples::blake::scheduler::{BlakeElements, BlakeSchedulerEval}; #[test] fn test_blake_scheduler() { @@ -87,7 +91,7 @@ mod tests { let trace = TreeVec::new(vec![trace, interaction_trace]); let trace_polys = trace.map_cols(|c| c.interpolate()); - let component = BlakeSchedulerComponent { + let component = BlakeSchedulerEval { log_size: LOG_SIZE, blake_lookup_elements, round_lookup_elements, diff --git a/crates/prover/src/examples/blake/xor_table/mod.rs b/crates/prover/src/examples/blake/xor_table/mod.rs index 21c417cfd..5cd8053c0 100644 --- a/crates/prover/src/examples/blake/xor_table/mod.rs +++ b/crates/prover/src/examples/blake/xor_table/mod.rs @@ -15,20 +15,21 @@ mod gen; use std::simd::u32x16; -use constraints::XorTableEval; use itertools::Itertools; use num_traits::Zero; pub use r#gen::{generate_constant_trace, generate_interaction_trace, generate_trace}; use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; -use crate::constraint_framework::{EvalAtRow, FrameworkComponent, InfoEvaluator}; +use crate::constraint_framework::{ + EvalAtRow, FrameworkComponentImpl, FrameworkEval, InfoEvaluator, +}; use crate::core::backend::simd::column::BaseColumn; use crate::core::backend::Column; use crate::core::fields::qm31::SecureField; -use crate::core::pcs::TreeVec; +use crate::core::pcs::{TreeColumnSpan, TreeVec}; pub fn trace_sizes() -> TreeVec> { - let component = XorTableComponent:: { + let component = XorTableEval:: { lookup_elements: LookupElements::<3>::dummy(), claimed_sum: SecureField::zero(), }; @@ -83,13 +84,19 @@ impl XorAccumulator = + FrameworkComponentImpl>; + pub type XorElements = LookupElements<3>; -pub struct XorTableComponent { + +/// Evaluates the xor table. +pub struct XorTableEval { pub lookup_elements: XorElements, pub claimed_sum: SecureField, } -impl FrameworkComponent - for XorTableComponent + +impl FrameworkEval + for XorTableEval { fn log_size(&self) -> u32 { column_bits::() @@ -98,7 +105,7 @@ impl FrameworkComponent column_bits::() + 1 } fn evaluate(&self, mut eval: E) -> E { - let xor_eval = XorTableEval::<'_, _, ELEM_BITS, EXPAND_BITS> { + let xor_eval = constraints::XorTableEval::<'_, _, ELEM_BITS, EXPAND_BITS> { eval, lookup_elements: &self.lookup_elements, logup: LogupAtRow::new(1, self.claimed_sum, self.log_size()), @@ -112,12 +119,12 @@ mod tests { use std::simd::u32x16; use crate::constraint_framework::logup::LookupElements; - use crate::constraint_framework::{assert_constraints, FrameworkComponent}; + use crate::constraint_framework::{assert_constraints, FrameworkEval}; use crate::core::poly::circle::CanonicCoset; use crate::examples::blake::xor_table::r#gen::{ generate_constant_trace, generate_interaction_trace, generate_trace, }; - use crate::examples::blake::xor_table::{column_bits, XorAccumulator, XorTableComponent}; + use crate::examples::blake::xor_table::{column_bits, XorAccumulator, XorTableEval}; #[test] fn test_xor_table() { @@ -138,7 +145,7 @@ mod tests { let trace = TreeVec::new(vec![trace, interaction_trace, constant_trace]); let trace_polys = trace.map_cols(|c| c.interpolate()); - let component = XorTableComponent:: { + let component = XorTableEval:: { lookup_elements, claimed_sum, }; diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index 79a48b14f..8a6bb9113 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -3,7 +3,9 @@ use num_traits::One; use tracing::{span, Level}; use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements}; -use crate::constraint_framework::{assert_constraints, EvalAtRow, FrameworkComponent}; +use crate::constraint_framework::{ + assert_constraints, EvalAtRow, FrameworkComponentImpl, FrameworkEval, TreeColumnSpanProvider, +}; use crate::core::backend::simd::column::BaseColumn; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::qm31::PackedSecureField; @@ -12,21 +14,26 @@ use crate::core::backend::Column; use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; -use crate::core::pcs::{CommitmentSchemeProver, PcsConfig}; +use crate::core::pcs::{CommitmentSchemeProver, PcsConfig, TreeColumnSpan}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; use crate::core::prover::{prove, StarkProof}; use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; use crate::core::{ColumnVec, InteractionElements}; +pub type PlonkComponent = FrameworkComponentImpl; + #[derive(Clone)] -pub struct PlonkComponent { +pub struct PlonkEval { pub log_n_rows: u32, pub lookup_elements: LookupElements<2>, pub claimed_sum: SecureField, + pub base_trace_location: TreeColumnSpan, + pub interaction_trace_location: TreeColumnSpan, + pub constants_trace_location: TreeColumnSpan, } -impl FrameworkComponent for PlonkComponent { +impl FrameworkEval for PlonkEval { fn log_size(&self) -> u32 { self.log_n_rows } @@ -142,7 +149,10 @@ pub fn gen_interaction_trace( pub fn prove_fibonacci_plonk( log_n_rows: u32, config: PcsConfig, -) -> (PlonkComponent, StarkProof) { +) -> ( + FrameworkComponentImpl, + StarkProof, +) { assert!(log_n_rows >= LOG_N_LANES); // Prepare a fibonacci circuit. @@ -181,7 +191,7 @@ pub fn prove_fibonacci_plonk( let span = span!(Level::INFO, "Trace").entered(); let trace = gen_trace(log_n_rows, &circuit); let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(trace); + let base_trace_location = tree_builder.extend_evals(trace); tree_builder.commit(channel); span.exit(); @@ -192,14 +202,14 @@ pub fn prove_fibonacci_plonk( let span = span!(Level::INFO, "Interaction").entered(); let (trace, claimed_sum) = gen_interaction_trace(log_n_rows, &circuit, &lookup_elements); let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(trace); + let interaction_trace_location = tree_builder.extend_evals(trace); tree_builder.commit(channel); span.exit(); // Constant trace. let span = span!(Level::INFO, "Constant").entered(); let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals( + let constants_trace_location = tree_builder.extend_evals( chain!([circuit.a_wire, circuit.b_wire, circuit.c_wire, circuit.op] .into_iter() .map(|col| { @@ -214,11 +224,17 @@ pub fn prove_fibonacci_plonk( span.exit(); // Prove constraints. - let component = PlonkComponent { - log_n_rows, - lookup_elements, - claimed_sum, - }; + let component = PlonkComponent::new( + &mut TreeColumnSpanProvider::default(), + PlonkEval { + log_n_rows, + lookup_elements, + claimed_sum, + base_trace_location, + interaction_trace_location, + constants_trace_location, + }, + ); // Sanity check. Remove for production. let trace_polys = commitment_scheme diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 63af1d03f..f48b4c782 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -7,7 +7,9 @@ use num_traits::One; use tracing::{span, Level}; use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements}; -use crate::constraint_framework::{EvalAtRow, FrameworkComponent}; +use crate::constraint_framework::{ + EvalAtRow, FrameworkComponentImpl, FrameworkEval, TreeColumnSpanProvider, +}; use crate::core::backend::simd::column::BaseColumn; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::backend::simd::qm31::PackedSecureField; @@ -39,28 +41,27 @@ const EXTERNAL_ROUND_CONSTS: [[BaseField; N_STATE]; 2 * N_HALF_FULL_ROUNDS] = const INTERNAL_ROUND_CONSTS: [BaseField; N_PARTIAL_ROUNDS] = [BaseField::from_u32_unchecked(1234); N_PARTIAL_ROUNDS]; +pub type PoseidonComponent = FrameworkComponentImpl; + pub type PoseidonElements = LookupElements<{ N_STATE * 2 }>; #[derive(Clone)] -pub struct PoseidonComponent { +pub struct PoseidonEval { pub log_n_rows: u32, pub lookup_elements: PoseidonElements, pub claimed_sum: SecureField, } -impl FrameworkComponent for PoseidonComponent { +impl FrameworkEval for PoseidonEval { fn log_size(&self) -> u32 { self.log_n_rows } fn max_constraint_log_degree_bound(&self) -> u32 { self.log_n_rows + LOG_EXPAND } - fn evaluate(&self, eval: E) -> E { - let poseidon_eval = PoseidonEval { - eval, - logup: LogupAtRow::new(1, self.claimed_sum, self.log_n_rows), - lookup_elements: &self.lookup_elements, - }; - poseidon_eval.eval() + fn evaluate(&self, mut eval: E) -> E { + let logup = LogupAtRow::new(1, self.claimed_sum, self.log_n_rows); + eval_poseidon_constraints(&mut eval, logup, &self.lookup_elements); + eval } } @@ -133,67 +134,60 @@ fn pow5(x: F) -> F { x4 * x } -struct PoseidonEval<'a, E: EvalAtRow> { - eval: E, - logup: LogupAtRow<2, E>, - lookup_elements: &'a PoseidonElements, -} - -impl<'a, E: EvalAtRow> PoseidonEval<'a, E> { - fn eval(mut self) -> E { - for _ in 0..N_INSTANCES_PER_ROW { - let mut state: [_; N_STATE] = std::array::from_fn(|_| self.eval.next_trace_mask()); +pub fn eval_poseidon_constraints( + eval: &mut E, + mut logup: LogupAtRow<2, E>, + lookup_elements: &PoseidonElements, +) { + for _ in 0..N_INSTANCES_PER_ROW { + let mut state: [_; N_STATE] = std::array::from_fn(|_| eval.next_trace_mask()); - // Require state lookup. - self.logup - .push_lookup(&mut self.eval, E::EF::one(), &state, self.lookup_elements); + // Require state lookup. + logup.push_lookup(eval, E::EF::one(), &state, lookup_elements); - // 4 full rounds. - (0..N_HALF_FULL_ROUNDS).for_each(|round| { - (0..N_STATE).for_each(|i| { - state[i] += EXTERNAL_ROUND_CONSTS[round][i]; - }); - apply_external_round_matrix(&mut state); - state = std::array::from_fn(|i| pow5(state[i])); - state.iter_mut().for_each(|s| { - let m = self.eval.next_trace_mask(); - self.eval.add_constraint(*s - m); - *s = m; - }); + // 4 full rounds. + (0..N_HALF_FULL_ROUNDS).for_each(|round| { + (0..N_STATE).for_each(|i| { + state[i] += EXTERNAL_ROUND_CONSTS[round][i]; }); - - // Partial rounds. - (0..N_PARTIAL_ROUNDS).for_each(|round| { - state[0] += INTERNAL_ROUND_CONSTS[round]; - apply_internal_round_matrix(&mut state); - state[0] = pow5(state[0]); - let m = self.eval.next_trace_mask(); - self.eval.add_constraint(state[0] - m); - state[0] = m; + apply_external_round_matrix(&mut state); + state = std::array::from_fn(|i| pow5(state[i])); + state.iter_mut().for_each(|s| { + let m = eval.next_trace_mask(); + eval.add_constraint(*s - m); + *s = m; }); + }); - // 4 full rounds. - (0..N_HALF_FULL_ROUNDS).for_each(|round| { - (0..N_STATE).for_each(|i| { - state[i] += EXTERNAL_ROUND_CONSTS[round + N_HALF_FULL_ROUNDS][i]; - }); - apply_external_round_matrix(&mut state); - state = std::array::from_fn(|i| pow5(state[i])); - state.iter_mut().for_each(|s| { - let m = self.eval.next_trace_mask(); - self.eval.add_constraint(*s - m); - *s = m; - }); - }); + // Partial rounds. + (0..N_PARTIAL_ROUNDS).for_each(|round| { + state[0] += INTERNAL_ROUND_CONSTS[round]; + apply_internal_round_matrix(&mut state); + state[0] = pow5(state[0]); + let m = eval.next_trace_mask(); + eval.add_constraint(state[0] - m); + state[0] = m; + }); - // Provide state lookup. - self.logup - .push_lookup(&mut self.eval, -E::EF::one(), &state, self.lookup_elements); - } + // 4 full rounds. + (0..N_HALF_FULL_ROUNDS).for_each(|round| { + (0..N_STATE).for_each(|i| { + state[i] += EXTERNAL_ROUND_CONSTS[round + N_HALF_FULL_ROUNDS][i]; + }); + apply_external_round_matrix(&mut state); + state = std::array::from_fn(|i| pow5(state[i])); + state.iter_mut().for_each(|s| { + let m = eval.next_trace_mask(); + eval.add_constraint(*s - m); + *s = m; + }); + }); - self.logup.finalize(&mut self.eval); - self.eval + // Provide state lookup. + logup.push_lookup(eval, -E::EF::one(), &state, lookup_elements); } + + logup.finalize(eval); } pub struct LookupData { @@ -327,7 +321,10 @@ pub fn gen_interaction_trace( pub fn prove_poseidon( log_n_instances: u32, config: PcsConfig, -) -> (PoseidonComponent, StarkProof) { +) -> ( + FrameworkComponentImpl, + StarkProof, +) { assert!(log_n_instances >= N_LOG_INSTANCES_PER_ROW as u32); let log_n_rows = log_n_instances - N_LOG_INSTANCES_PER_ROW as u32; @@ -364,11 +361,14 @@ pub fn prove_poseidon( span.exit(); // Prove constraints. - let component = PoseidonComponent { - log_n_rows, - lookup_elements, - claimed_sum, - }; + let component = PoseidonComponent::new( + &mut TreeColumnSpanProvider::default(), + PoseidonEval { + log_n_rows, + lookup_elements, + claimed_sum, + }, + ); let proof = prove::( &[&component], channel, @@ -399,8 +399,8 @@ mod tests { use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; use crate::core::InteractionElements; use crate::examples::poseidon::{ - apply_internal_round_matrix, apply_m4, gen_interaction_trace, gen_trace, prove_poseidon, - PoseidonElements, PoseidonEval, + apply_internal_round_matrix, apply_m4, eval_poseidon_constraints, gen_interaction_trace, + gen_trace, prove_poseidon, PoseidonElements, }; use crate::math::matrix::{RowMajorMatrix, SquareMatrix}; @@ -467,13 +467,12 @@ mod tests { let traces = TreeVec::new(vec![trace0, trace1]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec()); - assert_constraints(&trace_polys, CanonicCoset::new(LOG_N_ROWS), |eval| { - PoseidonEval { - eval, - logup: LogupAtRow::new(1, claimed_sum, LOG_N_ROWS), - lookup_elements: &lookup_elements, - } - .eval(); + assert_constraints(&trace_polys, CanonicCoset::new(LOG_N_ROWS), |mut eval| { + eval_poseidon_constraints( + &mut eval, + LogupAtRow::new(1, claimed_sum, LOG_N_ROWS), + &lookup_elements, + ); }); } diff --git a/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs b/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs index 3aa983d09..bd3f9025f 100644 --- a/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs +++ b/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs @@ -9,7 +9,7 @@ use super::component::{ }; use super::trace_gen::write_trace_row; use crate::core::air::accumulation::{ColumnAccumulator, DomainEvaluationAccumulator}; -use crate::core::air::{AirProver, Component, ComponentProver, ComponentTrace}; +use crate::core::air::{AirProver, Component, ComponentProver, Trace}; use crate::core::backend::CpuBackend; use crate::core::channel::Channel; use crate::core::circle::Coset; @@ -256,7 +256,7 @@ impl WideFibComponent { impl ComponentProver for WideFibComponent { fn evaluate_constraint_quotients_on_domain( &self, - trace: &ComponentTrace<'_, CpuBackend>, + trace: &Trace<'_, CpuBackend>, evaluation_accumulator: &mut DomainEvaluationAccumulator, interaction_elements: &InteractionElements, lookup_values: &LookupValues, @@ -300,7 +300,7 @@ impl ComponentProver for WideFibComponent { ); } - fn lookup_values(&self, trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues { + fn lookup_values(&self, trace: &Trace<'_, CpuBackend>) -> LookupValues { let domain = CanonicCoset::new(self.log_column_size()); let trace_poly = &trace.polys[BASE_TRACE]; let values = BTreeMap::from_iter([ diff --git a/crates/prover/src/examples/wide_fibonacci/mod.rs b/crates/prover/src/examples/wide_fibonacci/mod.rs index 7976324e6..15afa3294 100644 --- a/crates/prover/src/examples/wide_fibonacci/mod.rs +++ b/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -13,7 +13,7 @@ mod tests { use super::component::{Input, WideFibAir, WideFibComponent, LOG_N_COLUMNS}; use super::constraint_eval::gen_trace; use crate::core::air::accumulation::DomainEvaluationAccumulator; - use crate::core::air::{Component, ComponentProver, ComponentTrace}; + use crate::core::air::{Component, ComponentProver, Trace}; use crate::core::backend::cpu::CpuCircleEvaluation; use crate::core::backend::CpuBackend; use crate::core::channel::Blake2sChannel; @@ -183,7 +183,7 @@ mod tests { .iter() .map(|poly| poly.evaluate(eval_domain)) .collect_vec(); - let trace = ComponentTrace { + let trace = Trace { polys: TreeVec::new(vec![ trace_polys.iter().collect_vec(), interaction_poly.iter().collect_vec(), diff --git a/crates/prover/src/examples/wide_fibonacci/simd.rs b/crates/prover/src/examples/wide_fibonacci/simd.rs index 8d12f0cd5..a7ed4dcec 100644 --- a/crates/prover/src/examples/wide_fibonacci/simd.rs +++ b/crates/prover/src/examples/wide_fibonacci/simd.rs @@ -5,7 +5,7 @@ use tracing::{span, Level}; use super::component::LOG_N_COLUMNS; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::mask::fixed_mask_points; -use crate::core::air::{Air, AirProver, Component, ComponentProver, ComponentTrace}; +use crate::core::air::{Air, AirProver, Component, ComponentProver, Trace}; use crate::core::backend::simd::column::BaseColumn; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::backend::simd::qm31::PackedSecureField; @@ -193,7 +193,7 @@ impl ComponentTraceGenerator for SimdWideFibComponent { impl ComponentProver for SimdWideFibComponent { fn evaluate_constraint_quotients_on_domain( &self, - trace: &ComponentTrace<'_, SimdBackend>, + trace: &Trace<'_, SimdBackend>, evaluation_accumulator: &mut DomainEvaluationAccumulator, _interaction_elements: &InteractionElements, _lookup_values: &LookupValues, @@ -250,7 +250,7 @@ impl ComponentProver for SimdWideFibComponent { } } - fn lookup_values(&self, _trace: &ComponentTrace<'_, SimdBackend>) -> LookupValues { + fn lookup_values(&self, _trace: &Trace<'_, SimdBackend>) -> LookupValues { LookupValues::default() } } diff --git a/crates/prover/src/trace_generation/prove.rs b/crates/prover/src/trace_generation/prove.rs index 93d6e4bf5..789d506eb 100644 --- a/crates/prover/src/trace_generation/prove.rs +++ b/crates/prover/src/trace_generation/prove.rs @@ -3,7 +3,7 @@ use thiserror::Error; use tracing::{span, Level}; use super::{AirTraceGenerator, AirTraceVerifier, BASE_TRACE, INTERACTION_TRACE}; -use crate::core::air::{Air, AirProver, ComponentProvers, Components}; +use crate::core::air::{Air, AirProver, ComponentProvers, Components, Trace}; use crate::core::backend::BackendForChannel; use crate::core::channel::{Channel, MerkleChannel}; use crate::core::fields::m31::BaseField; @@ -58,11 +58,16 @@ pub fn commit_and_prove, MC: MerkleChannel>( let (mut commitment_scheme, interaction_elements) = evaluate_and_commit_on_trace(air, channel, &twiddles, trace, config)?; + let trace = Trace { + polys: commitment_scheme.polynomials(), + evals: commitment_scheme.evaluations(), + }; + let air_prover = &air.to_air_prover(); let components = ComponentProvers(air_prover.component_provers()); channel.mix_felts( &components - .lookup_values(&components.component_traces(&commitment_scheme.trees)) + .lookup_values(&trace) .0 .values() .map(|v| SecureField::from(*v)) @@ -175,7 +180,7 @@ mod tests { use num_traits::Zero; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; - use crate::core::air::{Air, AirProver, Component, ComponentProver, ComponentTrace}; + use crate::core::air::{Air, AirProver, Component, ComponentProver, Trace}; use crate::core::backend::cpu::CpuCircleEvaluation; use crate::core::backend::CpuBackend; use crate::core::channel::Channel; @@ -310,7 +315,7 @@ mod tests { impl ComponentProver for TestComponent { fn evaluate_constraint_quotients_on_domain( &self, - _trace: &ComponentTrace<'_, CpuBackend>, + _trace: &Trace<'_, CpuBackend>, _evaluation_accumulator: &mut DomainEvaluationAccumulator, _interaction_elements: &InteractionElements, _lookup_values: &LookupValues, @@ -318,7 +323,7 @@ mod tests { // Does nothing. } - fn lookup_values(&self, _trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues { + fn lookup_values(&self, _trace: &Trace<'_, CpuBackend>) -> LookupValues { LookupValues::default() } }