Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass entire mask to components #801

Merged
merged 1 commit into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 87 additions & 25 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use std::borrow::Cow;
use std::iter::zip;
use std::ops::Deref;

use itertools::Itertools;
use tracing::{span, Level};

use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator};
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Component, ComponentProver, ComponentTrace};
use crate::core::air::{Component, ComponentProver, Trace};
use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS};
Expand All @@ -15,36 +17,87 @@ use crate::core::constraints::coset_vanishing;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::pcs::TreeVec;
use crate::core::pcs::{TreeSubspan, TreeVec};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
use crate::core::poly::BitReversedOrder;
use crate::core::{utils, ColumnVec, InteractionElements, LookupValues};

// TODO(andrew): Docs.
// TODO(andrew): Consider better location for this.
#[derive(Debug, Default)]
pub struct TraceLocationAllocator {
/// Mapping of tree index to next available column offset.
next_tree_offsets: TreeVec<usize>,
}

impl TraceLocationAllocator {
fn next_for_structure<T>(&mut self, structure: &TreeVec<ColumnVec<T>>) -> TreeVec<TreeSubspan> {
if structure.len() > self.next_tree_offsets.len() {
self.next_tree_offsets.resize(structure.len(), 0);
}

TreeVec::new(
zip(&mut *self.next_tree_offsets, &**structure)
.enumerate()
.map(|(tree_index, (offset, cols))| {
let col_start = *offset;
let col_end = col_start + cols.len();
*offset = col_end;
TreeSubspan {
tree_index,
col_start,
col_end,
}
})
.collect(),
)
}
}

/// A component defined solely in means of the constraints framework.
/// Implementing this trait introduces implementations for [Component] and [ComponentProver] for the
/// SIMD backend.
/// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for
/// the SIMD backend.
/// Note that the constraint framework only support components with columns of the same size.
pub trait FrameworkComponent {
pub trait FrameworkEval {
fn log_size(&self) -> u32;

fn max_constraint_log_degree_bound(&self) -> u32;

fn evaluate<E: EvalAtRow>(&self, eval: E) -> E;
}

impl<C: FrameworkComponent> Component for C {
pub struct FrameworkComponent<C: FrameworkEval> {
eval: C,
trace_locations: TreeVec<TreeSubspan>,
}

impl<E: FrameworkEval> FrameworkComponent<E> {
pub fn new(provider: &mut TraceLocationAllocator, eval: E) -> Self {
let eval_tree_structure = eval.evaluate(InfoEvaluator::default()).mask_offsets;
let trace_locations = provider.next_for_structure(&eval_tree_structure);
Self {
eval,
trace_locations,
}
}
}

impl<E: FrameworkEval> Component for FrameworkComponent<E> {
fn n_constraints(&self) -> usize {
self.evaluate(InfoEvaluator::default()).n_constraints
self.eval.evaluate(InfoEvaluator::default()).n_constraints
}

fn max_constraint_log_degree_bound(&self) -> u32 {
FrameworkComponent::max_constraint_log_degree_bound(self)
self.eval.max_constraint_log_degree_bound()
}

fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
TreeVec::new(
self.evaluate(InfoEvaluator::default())
self.eval
.evaluate(InfoEvaluator::default())
.mask_offsets
.iter()
.map(|tree_masks| vec![self.log_size(); tree_masks.len()])
.map(|tree_masks| vec![self.eval.log_size(); tree_masks.len()])
.collect(),
)
}
Expand All @@ -53,8 +106,8 @@ impl<C: FrameworkComponent> Component for C {
&self,
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
let info = self.evaluate(InfoEvaluator::default());
let trace_step = CanonicCoset::new(self.log_size()).step();
let info = self.eval.evaluate(InfoEvaluator::default());
let trace_step = CanonicCoset::new(self.eval.log_size()).step();
info.mask_offsets.map_cols(|col_mask| {
col_mask
.iter()
Expand All @@ -71,30 +124,32 @@ impl<C: FrameworkComponent> Component for C {
_interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
) {
self.evaluate(PointEvaluator::new(
mask.as_ref(),
self.eval.evaluate(PointEvaluator::new(
mask.sub_tree(&self.trace_locations),
evaluation_accumulator,
coset_vanishing(CanonicCoset::new(self.log_size()).coset, point).inverse(),
coset_vanishing(CanonicCoset::new(self.eval.log_size()).coset, point).inverse(),
));
}
}

impl<C: FrameworkComponent> ComponentProver<SimdBackend> for C {
impl<E: FrameworkEval> ComponentProver<SimdBackend> for FrameworkComponent<E> {
fn evaluate_constraint_quotients_on_domain(
&self,
trace: &ComponentTrace<'_, SimdBackend>,
trace: &Trace<'_, SimdBackend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<SimdBackend>,
_interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
) {
let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain();
let trace_domain = CanonicCoset::new(self.log_size());
let trace_domain = CanonicCoset::new(self.eval.log_size());

let component_polys = trace.polys.sub_tree(&self.trace_locations);
let component_evals = trace.evals.sub_tree(&self.trace_locations);

// Extend trace if necessary.
// TODO(spapini): Don't extend when eval_size < committed_size. Instead, pick a good
// subdomain.
let need_to_extend = trace
.evals
let need_to_extend = component_evals
.iter()
.flatten()
.any(|c| c.domain != eval_domain);
Expand All @@ -103,12 +158,11 @@ impl<C: FrameworkComponent> ComponentProver<SimdBackend> for C {
> = if need_to_extend {
let _span = span!(Level::INFO, "Extension").entered();
let twiddles = SimdBackend::precompute_twiddles(eval_domain.half_coset);
trace
.polys
component_polys
.as_cols_ref()
.map_cols(|col| Cow::Owned(col.evaluate_with_twiddles(eval_domain, &twiddles)))
} else {
trace.evals.as_cols_ref().map_cols(|c| Cow::Borrowed(*c))
component_evals.clone().map_cols(|c| Cow::Borrowed(*c))
};

// Denom inverses.
Expand Down Expand Up @@ -137,7 +191,7 @@ impl<C: FrameworkComponent> ComponentProver<SimdBackend> for C {
trace_domain.log_size(),
eval_domain.log_size(),
);
let row_res = self.evaluate(eval).row_res;
let row_res = self.eval.evaluate(eval).row_res;

// Finalize row.
unsafe {
Expand All @@ -150,7 +204,15 @@ impl<C: FrameworkComponent> ComponentProver<SimdBackend> for C {
}
}

fn lookup_values(&self, _trace: &ComponentTrace<'_, SimdBackend>) -> LookupValues {
fn lookup_values(&self, _trace: &Trace<'_, SimdBackend>) -> LookupValues {
LookupValues::default()
}
}

impl<E: FrameworkEval> Deref for FrameworkComponent<E> {
type Target = E;

fn deref(&self) -> &E {
&self.eval
}
}
2 changes: 1 addition & 1 deletion crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::fmt::Debug;
use std::ops::{Add, AddAssign, Mul, Neg, Sub};

pub use assert::{assert_constraints, AssertEvaluator};
pub use component::FrameworkComponent;
pub use component::{FrameworkComponent, FrameworkEval, TraceLocationAllocator};
pub use info::InfoEvaluator;
use num_traits::{One, Zero};
pub use point::PointEvaluator;
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/constraint_framework/point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ use crate::core::ColumnVec;

/// Evaluates expressions at a point out of domain.
pub struct PointEvaluator<'a> {
pub mask: TreeVec<&'a ColumnVec<Vec<SecureField>>>,
pub mask: TreeVec<ColumnVec<&'a Vec<SecureField>>>,
pub evaluation_accumulator: &'a mut PointEvaluationAccumulator,
pub col_index: Vec<usize>,
pub denom_inverse: SecureField,
}
impl<'a> PointEvaluator<'a> {
pub fn new(
mask: TreeVec<&'a ColumnVec<Vec<SecureField>>>,
mask: TreeVec<ColumnVec<&'a Vec<SecureField>>>,
evaluation_accumulator: &'a mut PointEvaluationAccumulator,
denom_inverse: SecureField,
) -> Self {
Expand Down
72 changes: 15 additions & 57 deletions crates/prover/src/core/air/components.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use itertools::{zip_eq, Itertools};
use itertools::Itertools;

use super::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use super::{Component, ComponentProver, ComponentTrace};
use crate::core::backend::{Backend, BackendForChannel};
use crate::core::channel::MerkleChannel;
use super::{Component, ComponentProver, Trace};
use crate::core::backend::Backend;
use crate::core::circle::CirclePoint;
use crate::core::fields::qm31::SecureField;
use crate::core::pcs::{CommitmentTreeProver, TreeVec};
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::SecureCirclePoly;
use crate::core::{ColumnVec, InteractionElements, LookupValues};

Expand All @@ -31,21 +30,21 @@ impl<'a> Components<'a> {
pub fn eval_composition_polynomial_at_point(
&self,
point: CirclePoint<SecureField>,
mask_values: &Vec<TreeVec<Vec<Vec<SecureField>>>>,
mask_values: &TreeVec<Vec<Vec<SecureField>>>,
random_coeff: SecureField,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) -> SecureField {
let mut evaluation_accumulator = PointEvaluationAccumulator::new(random_coeff);
zip_eq(&self.0, mask_values).for_each(|(component, mask)| {
for component in &self.0 {
component.evaluate_constraint_quotients_at_point(
point,
mask,
mask_values,
&mut evaluation_accumulator,
interaction_elements,
lookup_values,
)
});
}
evaluation_accumulator.finalize()
}

Expand All @@ -67,7 +66,7 @@ impl<'a, B: Backend> ComponentProvers<'a, B> {
pub fn compute_composition_polynomial(
&self,
random_coeff: SecureField,
component_traces: &[ComponentTrace<'_, B>],
trace: &Trace<'_, B>,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) -> SecureCirclePoly<B> {
Expand All @@ -77,63 +76,22 @@ impl<'a, B: Backend> ComponentProvers<'a, B> {
self.components().composition_log_degree_bound(),
total_constraints,
);
zip_eq(&self.0, component_traces).for_each(|(component, trace)| {
for component in &self.0 {
component.evaluate_constraint_quotients_on_domain(
trace,
&mut accumulator,
interaction_elements,
lookup_values,
)
});
}
accumulator.finalize()
}

pub fn component_traces<'b, MC: MerkleChannel>(
&'b self,
trees: &'b [CommitmentTreeProver<B, MC>],
) -> Vec<ComponentTrace<'b, B>>
where
B: BackendForChannel<MC>,
{
let mut poly_iters = trees
.iter()
.map(|tree| tree.polynomials.iter())
.collect_vec();
let mut eval_iters = trees
.iter()
.map(|tree| tree.evaluations.iter())
.collect_vec();

self.0
.iter()
.map(|component| {
let col_sizes_per_tree = component
.trace_log_degree_bounds()
.iter()
.map(|col_sizes| col_sizes.len())
.collect_vec();
let polys = col_sizes_per_tree
.iter()
.zip(poly_iters.iter_mut())
.map(|(n_columns, iter)| iter.take(*n_columns).collect_vec())
.collect_vec();
let evals = col_sizes_per_tree
.iter()
.zip(eval_iters.iter_mut())
.map(|(n_columns, iter)| iter.take(*n_columns).collect_vec())
.collect_vec();
ComponentTrace {
polys: TreeVec::new(polys),
evals: TreeVec::new(evals),
}
})
.collect_vec()
}

pub fn lookup_values(&self, component_traces: &[ComponentTrace<'_, B>]) -> LookupValues {
pub fn lookup_values(&self, trace: &Trace<'_, B>) -> LookupValues {
let mut values = LookupValues::default();
zip_eq(&self.0, component_traces)
.for_each(|(component, trace)| values.extend(component.lookup_values(trace)));
for component in &self.0 {
values.extend(component.lookup_values(trace))
}
values
}
}
20 changes: 6 additions & 14 deletions crates/prover/src/core/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,30 +62,22 @@ pub trait ComponentProver<B: Backend>: Component {
/// Accumulates quotients in `evaluation_accumulator`.
fn evaluate_constraint_quotients_on_domain(
&self,
trace: &ComponentTrace<'_, B>,
trace: &Trace<'_, B>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<B>,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
);

/// Returns the values needed to evaluate the components lookup boundary constraints.
fn lookup_values(&self, _trace: &ComponentTrace<'_, B>) -> LookupValues;
fn lookup_values(&self, _trace: &Trace<'_, B>) -> LookupValues;
}

/// A component trace is a set of polynomials for each column on that component.
/// The set of polynomials that make up the trace.
///
/// Each polynomial is stored both in a coefficients, and evaluations form (for efficiency)
pub struct ComponentTrace<'a, B: Backend> {
pub struct Trace<'a, B: Backend> {
/// Polynomials for each column.
pub polys: TreeVec<ColumnVec<&'a CirclePoly<B>>>,
/// Evaluations for each column (evaluated on the commitment domains).
/// Evaluations for each column (evaluated on their commitment domains).
pub evals: TreeVec<ColumnVec<&'a CircleEvaluation<B, BaseField, BitReversedOrder>>>,
}

impl<'a, B: Backend> ComponentTrace<'a, B> {
pub fn new(
polys: TreeVec<ColumnVec<&'a CirclePoly<B>>>,
evals: TreeVec<ColumnVec<&'a CircleEvaluation<B, BaseField, BitReversedOrder>>>,
) -> Self {
Self { polys, evals }
}
}
2 changes: 1 addition & 1 deletion crates/prover/src/core/pcs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub use self::verifier::CommitmentSchemeVerifier;
use super::fri::FriConfig;

#[derive(Copy, Debug, Clone, PartialEq, Eq)]
pub struct TreeColumnSpan {
pub struct TreeSubspan {
pub tree_index: usize,
pub col_start: usize,
pub col_end: usize,
Expand Down
Loading
Loading