Skip to content

Commit

Permalink
Create MLE eval prover component (#804)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson authored Oct 1, 2024
1 parent 6250403 commit cd1648e
Show file tree
Hide file tree
Showing 4 changed files with 534 additions and 100 deletions.
35 changes: 20 additions & 15 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ pub struct TraceLocationAllocator {
}

impl TraceLocationAllocator {
fn next_for_structure<T>(&mut self, structure: &TreeVec<ColumnVec<T>>) -> TreeVec<TreeSubspan> {
pub fn next_for_structure<T>(
&mut self,
structure: &TreeVec<ColumnVec<T>>,
) -> TreeVec<TreeSubspan> {
if structure.len() > self.next_tree_offsets.len() {
self.next_tree_offsets.resize(structure.len(), 0);
}
Expand Down Expand Up @@ -78,14 +81,18 @@ pub struct FrameworkComponent<C: FrameworkEval> {
}

impl<E: FrameworkEval> FrameworkComponent<E> {
pub fn new(provider: &mut TraceLocationAllocator, eval: E) -> Self {
pub fn new(location_allocator: &mut TraceLocationAllocator, eval: E) -> Self {
let eval_tree_structure = eval.evaluate(InfoEvaluator::default()).mask_offsets;
let trace_locations = provider.next_for_structure(&eval_tree_structure);
let trace_locations = location_allocator.next_for_structure(&eval_tree_structure);
Self {
eval,
trace_locations,
}
}

pub fn trace_locations(&self) -> &[TreeSubspan] {
&self.trace_locations
}
}

impl<E: FrameworkEval> Component for FrameworkComponent<E> {
Expand All @@ -98,26 +105,20 @@ impl<E: FrameworkEval> Component for FrameworkComponent<E> {
}

fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
TreeVec::new(
self.eval
.evaluate(InfoEvaluator::default())
.mask_offsets
.iter()
.map(|tree_masks| vec![self.eval.log_size(); tree_masks.len()])
.collect(),
)
let InfoEvaluator { mask_offsets, .. } = self.eval.evaluate(InfoEvaluator::default());
mask_offsets.map(|tree_offsets| vec![self.eval.log_size(); tree_offsets.len()])
}

fn mask_points(
&self,
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
let info = self.eval.evaluate(InfoEvaluator::default());
let trace_step = CanonicCoset::new(self.eval.log_size()).step();
info.mask_offsets.map_cols(|col_mask| {
col_mask
let InfoEvaluator { mask_offsets, .. } = self.eval.evaluate(InfoEvaluator::default());
mask_offsets.map_cols(|col_offsets| {
col_offsets
.iter()
.map(|off| point + trace_step.mul_signed(*off).into_ef())
.map(|offset| point + trace_step.mul_signed(*offset).into_ef())
.collect()
})
}
Expand All @@ -142,6 +143,10 @@ impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponen
trace: &Trace<'_, SimdBackend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<SimdBackend>,
) {
if self.n_constraints() == 0 {
return;
}

let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain();
let trace_domain = CanonicCoset::new(self.eval.log_size());

Expand Down
1 change: 1 addition & 0 deletions crates/prover/src/core/air/accumulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::core::utils::generate_secure_powers;
/// Accumulates N evaluations of u_i(P0) at a single point.
/// Computes f(P0), the combined polynomial at that point.
/// For n accumulated evaluations, the i'th evaluation is multiplied by alpha^(N-1-i).
#[derive(Debug, Clone)]
pub struct PointEvaluationAccumulator {
random_coeff: SecureField,
accumulation: SecureField,
Expand Down
11 changes: 6 additions & 5 deletions crates/prover/src/examples/xor/gkr_lookups/accumulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ pub const MIN_LOG_BLOWUP_FACTOR: u32 = 1;
/// IOP for multilinear eval at point.
pub const MAX_MLE_N_VARIABLES: u32 = M31_CIRCLE_LOG_ORDER - MIN_LOG_BLOWUP_FACTOR;

/// Accumulates [`Mle`]s grouped by their number of variables.
/// Collection of [`Mle`]s grouped by their number of variables.
pub struct MleCollection<B: Backend> {
mles_by_n_variables: Vec<Option<Vec<DynMle<B>>>>,
}

impl<B: Backend> MleCollection<B> {
/// Appends an [`Mle`] to the collection.
/// Appends an [`Mle`] to the back of the collection.
pub fn push(&mut self, mle: impl Into<DynMle<B>>) {
let mle = mle.into();
let mles = self.mles_by_n_variables[mle.n_variables()].get_or_insert(Vec::new());
Expand All @@ -35,6 +35,7 @@ impl<B: Backend> MleCollection<B> {
impl MleCollection<SimdBackend> {
/// Performs a random linear combination of all MLEs, grouped by their number of variables.
///
/// For `n` accumulated MLEs in a group, the `i`th MLE is multiplied by `alpha^(n-1-i)`.
/// MLEs are returned in ascending order by number of variables.
pub fn random_linear_combine_by_n_variables(
self,
Expand All @@ -53,13 +54,13 @@ impl MleCollection<SimdBackend> {
/// Panics if `mles` is empty or all MLEs don't have the same number of variables.
fn mle_random_linear_combination(
mles: Vec<DynMle<SimdBackend>>,
alpha: SecureField,
random_coeff: SecureField,
) -> Mle<SimdBackend, SecureField> {
assert!(!mles.is_empty());
let n_variables = mles[0].n_variables();
assert!(mles.iter().all(|mle| mle.n_variables() == n_variables));
let alpha_powers = generate_secure_powers(alpha, mles.len()).into_iter().rev();
let mut mle_and_coeff = zip(mles, alpha_powers);
let coeff_powers = generate_secure_powers(random_coeff, mles.len());
let mut mle_and_coeff = zip(mles, coeff_powers.into_iter().rev());

// The last value can initialize the accumulator.
let (mle, coeff) = mle_and_coeff.next_back().unwrap();
Expand Down
Loading

0 comments on commit cd1648e

Please sign in to comment.