Skip to content

Commit

Permalink
Pass entire mask to components
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Aug 24, 2024
1 parent 573ee79 commit 62d5e0f
Show file tree
Hide file tree
Showing 19 changed files with 484 additions and 386 deletions.
114 changes: 89 additions & 25 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use std::borrow::Cow;
use std::collections::BTreeMap;
use std::ops::Deref;

use itertools::Itertools;
use tracing::{span, Level};

use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator};
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Component, ComponentProver, ComponentTrace};
use crate::core::air::{Component, ComponentProver, Trace};
use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS};
Expand All @@ -15,36 +17,89 @@ use crate::core::constraints::coset_vanishing;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::pcs::TreeVec;
use crate::core::pcs::{TreeColumnSpan, TreeVec};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
use crate::core::poly::BitReversedOrder;
use crate::core::{utils, ColumnVec, InteractionElements, LookupValues};

// TODO(andrew): Docs.
// TODO(andrew): Consider better location for this.
#[derive(Debug, Default)]
pub struct TreeColumnSpanProvider {
/// Mapping of tree index to next available column offset.
next_tree_offsets: BTreeMap<usize, usize>,
}

impl TreeColumnSpanProvider {
fn next_for_structure<T>(&mut self, structure: &TreeVec<ColumnVec<T>>) -> Vec<TreeColumnSpan> {
structure
.iter()
.enumerate()
.filter_map(|(tree_index, tree)| {
if tree.is_empty() {
return None;
}

let n_columns = tree.len();
let next_tree_offset = self.next_tree_offsets.entry(tree_index).or_default();
let col_start = *next_tree_offset;
let col_end = col_start + n_columns;
*next_tree_offset = col_end;

Some(TreeColumnSpan {
tree_index,
col_start,
col_end,
})
})
.collect()
}
}

/// A component defined solely in means of the constraints framework.
/// Implementing this trait introduces implementations for [Component] and [ComponentProver] for the
/// SIMD backend.
/// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for
/// the SIMD backend.
/// Note that the constraint framework only support components with columns of the same size.
pub trait FrameworkComponent {
pub trait FrameworkEval {
fn log_size(&self) -> u32;

fn max_constraint_log_degree_bound(&self) -> u32;

fn evaluate<E: EvalAtRow>(&self, eval: E) -> E;
}

impl<C: FrameworkComponent> Component for C {
pub struct FrameworkComponent<C: FrameworkEval> {
eval: C,
trace_locations: Vec<TreeColumnSpan>,
}

impl<E: FrameworkEval> FrameworkComponent<E> {
pub fn new(provider: &mut TreeColumnSpanProvider, eval: E) -> Self {
let eval_tree_structure = eval.evaluate(InfoEvaluator::default()).mask_offsets;
let trace_locations = provider.next_for_structure(&eval_tree_structure);
Self {
eval,
trace_locations,
}
}
}

impl<E: FrameworkEval> Component for FrameworkComponent<E> {
fn n_constraints(&self) -> usize {
self.evaluate(InfoEvaluator::default()).n_constraints
self.eval.evaluate(InfoEvaluator::default()).n_constraints
}

fn max_constraint_log_degree_bound(&self) -> u32 {
FrameworkComponent::max_constraint_log_degree_bound(self)
self.eval.max_constraint_log_degree_bound()
}

fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
TreeVec::new(
self.evaluate(InfoEvaluator::default())
self.eval
.evaluate(InfoEvaluator::default())
.mask_offsets
.iter()
.map(|tree_masks| vec![self.log_size(); tree_masks.len()])
.map(|tree_masks| vec![self.eval.log_size(); tree_masks.len()])
.collect(),
)
}
Expand All @@ -53,8 +108,8 @@ impl<C: FrameworkComponent> Component for C {
&self,
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
let info = self.evaluate(InfoEvaluator::default());
let trace_step = CanonicCoset::new(self.log_size()).step();
let info = self.eval.evaluate(InfoEvaluator::default());
let trace_step = CanonicCoset::new(self.eval.log_size()).step();
info.mask_offsets.map_cols(|col_mask| {
col_mask
.iter()
Expand All @@ -71,30 +126,32 @@ impl<C: FrameworkComponent> Component for C {
_interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
) {
self.evaluate(PointEvaluator::new(
mask.as_ref(),
self.eval.evaluate(PointEvaluator::new(
mask.sub_tree(&self.trace_locations),
evaluation_accumulator,
coset_vanishing(CanonicCoset::new(self.log_size()).coset, point).inverse(),
coset_vanishing(CanonicCoset::new(self.eval.log_size()).coset, point).inverse(),
));
}
}

impl<C: FrameworkComponent> ComponentProver<SimdBackend> for C {
impl<E: FrameworkEval> ComponentProver<SimdBackend> for FrameworkComponent<E> {
fn evaluate_constraint_quotients_on_domain(
&self,
trace: &ComponentTrace<'_, SimdBackend>,
trace: &Trace<'_, SimdBackend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<SimdBackend>,
_interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
) {
let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain();
let trace_domain = CanonicCoset::new(self.log_size());
let trace_domain = CanonicCoset::new(self.eval.log_size());

let component_polys = trace.polys.sub_tree(&self.trace_locations);
let component_evals = trace.evals.sub_tree(&self.trace_locations);

// Extend trace if necessary.
// TODO(spapini): Don't extend when eval_size < committed_size. Instead, pick a good
// subdomain.
let need_to_extend = trace
.evals
let need_to_extend = component_evals
.iter()
.flatten()
.any(|c| c.domain != eval_domain);
Expand All @@ -103,12 +160,11 @@ impl<C: FrameworkComponent> ComponentProver<SimdBackend> for C {
> = if need_to_extend {
let _span = span!(Level::INFO, "Extension").entered();
let twiddles = SimdBackend::precompute_twiddles(eval_domain.half_coset);
trace
.polys
component_polys
.as_cols_ref()
.map_cols(|col| Cow::Owned(col.evaluate_with_twiddles(eval_domain, &twiddles)))
} else {
trace.evals.as_cols_ref().map_cols(|c| Cow::Borrowed(*c))
component_evals.clone().map_cols(|c| Cow::Borrowed(*c))
};

// Denom inverses.
Expand Down Expand Up @@ -137,7 +193,7 @@ impl<C: FrameworkComponent> ComponentProver<SimdBackend> for C {
trace_domain.log_size(),
eval_domain.log_size(),
);
let row_res = self.evaluate(eval).row_res;
let row_res = self.eval.evaluate(eval).row_res;

// Finalize row.
unsafe {
Expand All @@ -150,7 +206,15 @@ impl<C: FrameworkComponent> ComponentProver<SimdBackend> for C {
}
}

fn lookup_values(&self, _trace: &ComponentTrace<'_, SimdBackend>) -> LookupValues {
fn lookup_values(&self, _trace: &Trace<'_, SimdBackend>) -> LookupValues {
LookupValues::default()
}
}

impl<E: FrameworkEval> Deref for FrameworkComponent<E> {
type Target = E;

fn deref(&self) -> &E {
&self.eval
}
}
2 changes: 1 addition & 1 deletion crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::fmt::Debug;
use std::ops::{Add, AddAssign, Mul, Neg, Sub};

pub use assert::{assert_constraints, AssertEvaluator};
pub use component::FrameworkComponent;
pub use component::{FrameworkComponent, FrameworkEval, TreeColumnSpanProvider};
pub use info::InfoEvaluator;
use num_traits::{One, Zero};
pub use point::PointEvaluator;
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/constraint_framework/point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ use crate::core::ColumnVec;

/// Evaluates expressions at a point out of domain.
pub struct PointEvaluator<'a> {
pub mask: TreeVec<&'a ColumnVec<Vec<SecureField>>>,
pub mask: TreeVec<ColumnVec<&'a Vec<SecureField>>>,
pub evaluation_accumulator: &'a mut PointEvaluationAccumulator,
pub col_index: Vec<usize>,
pub denom_inverse: SecureField,
}
impl<'a> PointEvaluator<'a> {
pub fn new(
mask: TreeVec<&'a ColumnVec<Vec<SecureField>>>,
mask: TreeVec<ColumnVec<&'a Vec<SecureField>>>,
evaluation_accumulator: &'a mut PointEvaluationAccumulator,
denom_inverse: SecureField,
) -> Self {
Expand Down
72 changes: 15 additions & 57 deletions crates/prover/src/core/air/components.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use itertools::{zip_eq, Itertools};
use itertools::Itertools;

use super::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use super::{Component, ComponentProver, ComponentTrace};
use crate::core::backend::{Backend, BackendForChannel};
use crate::core::channel::MerkleChannel;
use super::{Component, ComponentProver, Trace};
use crate::core::backend::Backend;
use crate::core::circle::CirclePoint;
use crate::core::fields::qm31::SecureField;
use crate::core::pcs::{CommitmentTreeProver, TreeVec};
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::SecureCirclePoly;
use crate::core::{ColumnVec, InteractionElements, LookupValues};

Expand All @@ -31,21 +30,21 @@ impl<'a> Components<'a> {
pub fn eval_composition_polynomial_at_point(
&self,
point: CirclePoint<SecureField>,
mask_values: &Vec<TreeVec<Vec<Vec<SecureField>>>>,
mask_values: &TreeVec<Vec<Vec<SecureField>>>,
random_coeff: SecureField,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) -> SecureField {
let mut evaluation_accumulator = PointEvaluationAccumulator::new(random_coeff);
zip_eq(&self.0, mask_values).for_each(|(component, mask)| {
for component in &self.0 {
component.evaluate_constraint_quotients_at_point(
point,
mask,
mask_values,
&mut evaluation_accumulator,
interaction_elements,
lookup_values,
)
});
}
evaluation_accumulator.finalize()
}

Expand All @@ -67,7 +66,7 @@ impl<'a, B: Backend> ComponentProvers<'a, B> {
pub fn compute_composition_polynomial(
&self,
random_coeff: SecureField,
component_traces: &[ComponentTrace<'_, B>],
trace: &Trace<'_, B>,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) -> SecureCirclePoly<B> {
Expand All @@ -77,63 +76,22 @@ impl<'a, B: Backend> ComponentProvers<'a, B> {
self.components().composition_log_degree_bound(),
total_constraints,
);
zip_eq(&self.0, component_traces).for_each(|(component, trace)| {
for component in &self.0 {
component.evaluate_constraint_quotients_on_domain(
trace,
&mut accumulator,
interaction_elements,
lookup_values,
)
});
}
accumulator.finalize()
}

pub fn component_traces<'b, MC: MerkleChannel>(
&'b self,
trees: &'b [CommitmentTreeProver<B, MC>],
) -> Vec<ComponentTrace<'b, B>>
where
B: BackendForChannel<MC>,
{
let mut poly_iters = trees
.iter()
.map(|tree| tree.polynomials.iter())
.collect_vec();
let mut eval_iters = trees
.iter()
.map(|tree| tree.evaluations.iter())
.collect_vec();

self.0
.iter()
.map(|component| {
let col_sizes_per_tree = component
.trace_log_degree_bounds()
.iter()
.map(|col_sizes| col_sizes.len())
.collect_vec();
let polys = col_sizes_per_tree
.iter()
.zip(poly_iters.iter_mut())
.map(|(n_columns, iter)| iter.take(*n_columns).collect_vec())
.collect_vec();
let evals = col_sizes_per_tree
.iter()
.zip(eval_iters.iter_mut())
.map(|(n_columns, iter)| iter.take(*n_columns).collect_vec())
.collect_vec();
ComponentTrace {
polys: TreeVec::new(polys),
evals: TreeVec::new(evals),
}
})
.collect_vec()
}

pub fn lookup_values(&self, component_traces: &[ComponentTrace<'_, B>]) -> LookupValues {
pub fn lookup_values(&self, trace: &Trace<'_, B>) -> LookupValues {
let mut values = LookupValues::default();
zip_eq(&self.0, component_traces)
.for_each(|(component, trace)| values.extend(component.lookup_values(trace)));
for component in &self.0 {
values.extend(component.lookup_values(trace))
}
values
}
}
14 changes: 12 additions & 2 deletions crates/prover/src/core/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,24 @@ pub trait ComponentProver<B: Backend>: Component {
/// Accumulates quotients in `evaluation_accumulator`.
fn evaluate_constraint_quotients_on_domain(
&self,
trace: &ComponentTrace<'_, B>,
trace: &Trace<'_, B>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<B>,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
);

/// Returns the values needed to evaluate the components lookup boundary constraints.
fn lookup_values(&self, _trace: &ComponentTrace<'_, B>) -> LookupValues;
fn lookup_values(&self, _trace: &Trace<'_, B>) -> LookupValues;
}

/// The set of polynomials that make up the trace.
///
/// Each polynomial is stored both in a coefficients, and evaluations form (for efficiency)
pub struct Trace<'a, B: Backend> {
/// Polynomials for each column.
pub polys: TreeVec<ColumnVec<&'a CirclePoly<B>>>,
/// Evaluations for each column (evaluated on their commitment domains).
pub evals: TreeVec<ColumnVec<&'a CircleEvaluation<B, BaseField, BitReversedOrder>>>,
}

/// A component trace is a set of polynomials for each column on that component.
Expand Down
4 changes: 3 additions & 1 deletion crates/prover/src/core/pcs/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
.map(|tree| tree.polynomials.iter().collect())
}

fn evaluations(&self) -> TreeVec<ColumnVec<&CircleEvaluation<B, BaseField, BitReversedOrder>>> {
pub fn evaluations(
&self,
) -> TreeVec<ColumnVec<&CircleEvaluation<B, BaseField, BitReversedOrder>>> {
self.trees
.as_ref()
.map(|tree| tree.evaluations.iter().collect())
Expand Down
Loading

0 comments on commit 62d5e0f

Please sign in to comment.