Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create blake example that uses GKR for lookups #807

Draft
wants to merge 3 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 33 additions & 16 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::ops::Deref;
use itertools::Itertools;
use tracing::{span, Level};

use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator};
use super::{EvalAtRow, EvalAtRowWithMle, InfoEvaluator, PointEvaluator, SimdDomainEvaluator};
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Component, ComponentProver, Trace};
use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords;
Expand All @@ -31,7 +31,10 @@ pub struct TraceLocationAllocator {
}

impl TraceLocationAllocator {
fn next_for_structure<T>(&mut self, structure: &TreeVec<ColumnVec<T>>) -> TreeVec<TreeSubspan> {
pub fn next_for_structure<T>(
&mut self,
structure: &TreeVec<ColumnVec<T>>,
) -> TreeVec<TreeSubspan> {
if structure.len() > self.next_tree_offsets.len() {
self.next_tree_offsets.resize(structure.len(), 0);
}
Expand All @@ -54,6 +57,18 @@ impl TraceLocationAllocator {
}
}

/// A component defined solely in means of the constraints framework.
/// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for
/// the SIMD backend.
/// Note that the constraint framework only support components with columns of the same size.
pub trait FrameworkEvalWithMle {
fn log_size(&self) -> u32;

fn max_constraint_log_degree_bound(&self) -> u32;

fn evaluate<E: EvalAtRowWithMle>(&self, eval: E) -> E;
}

/// A component defined solely in means of the constraints framework.
/// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for
/// the SIMD backend.
Expand All @@ -72,14 +87,18 @@ pub struct FrameworkComponent<C: FrameworkEval> {
}

impl<E: FrameworkEval> FrameworkComponent<E> {
pub fn new(provider: &mut TraceLocationAllocator, eval: E) -> Self {
pub fn new(location_allocator: &mut TraceLocationAllocator, eval: E) -> Self {
let eval_tree_structure = eval.evaluate(InfoEvaluator::default()).mask_offsets;
let trace_locations = provider.next_for_structure(&eval_tree_structure);
let trace_locations = location_allocator.next_for_structure(&eval_tree_structure);
Self {
eval,
trace_locations,
}
}

pub fn trace_locations(&self) -> &[TreeSubspan] {
&self.trace_locations
}
}

impl<E: FrameworkEval> Component for FrameworkComponent<E> {
Expand All @@ -92,26 +111,20 @@ impl<E: FrameworkEval> Component for FrameworkComponent<E> {
}

fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
TreeVec::new(
self.eval
.evaluate(InfoEvaluator::default())
.mask_offsets
.iter()
.map(|tree_masks| vec![self.eval.log_size(); tree_masks.len()])
.collect(),
)
let InfoEvaluator { mask_offsets, .. } = self.eval.evaluate(InfoEvaluator::default());
mask_offsets.map(|tree_offsets| vec![self.eval.log_size(); tree_offsets.len()])
}

fn mask_points(
&self,
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
let info = self.eval.evaluate(InfoEvaluator::default());
let trace_step = CanonicCoset::new(self.eval.log_size()).step();
info.mask_offsets.map_cols(|col_mask| {
col_mask
let InfoEvaluator { mask_offsets, .. } = self.eval.evaluate(InfoEvaluator::default());
mask_offsets.map_cols(|col_offsets| {
col_offsets
.iter()
.map(|off| point + trace_step.mul_signed(*off).into_ef())
.map(|offset| point + trace_step.mul_signed(*offset).into_ef())
.collect()
})
}
Expand All @@ -136,6 +149,10 @@ impl<E: FrameworkEval> ComponentProver<SimdBackend> for FrameworkComponent<E> {
trace: &Trace<'_, SimdBackend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<SimdBackend>,
) {
if self.n_constraints() == 0 {
return;
}

let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain();
let trace_domain = CanonicCoset::new(self.eval.log_size());

Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl<const N: usize> LookupElements<N> {
}
pub fn combine<F: Copy, EF>(&self, values: &[F]) -> EF
where
EF: Copy + Zero + From<F> + From<SecureField> + Mul<F, Output = EF> + Sub<EF, Output = EF>,
EF: Copy + Zero + From<F> + From<SecureField> + Mul<F, Output = EF> + Sub<Output = EF>,
{
zip_eq(values, self.alpha_powers).fold(EF::zero(), |acc, (&value, power)| {
acc + EF::from(power) * value
Expand Down
4 changes: 4 additions & 0 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,7 @@ pub trait EvalAtRow {
/// Combines 4 base field values into a single extension field value.
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF;
}

trait EvalAtRowWithMle: EvalAtRow {
fn add_mle_coeff_col_eval(&mut self, eval: Self::EF);
}
53 changes: 53 additions & 0 deletions crates/prover/src/constraint_framework/point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,56 @@ impl<'a> EvalAtRow for PointEvaluator<'a> {
SecureField::from_partial_evals(values)
}
}

// /// Evaluates expressions at a point out of domain.
// pub struct MleCoeffColEvalAccumulator<'a> {
// pub mask: TreeVec<ColumnVec<&'a Vec<SecureField>>>,
// pub evaluation_accumulator: &'a mut PointEvaluationAccumulator,
// pub col_index: Vec<usize>,
// pub denom_inverse: SecureField,
// }
// impl<'a> MleCoeffColEvalAccumulator<'a> {
// pub fn new(
// mask: TreeVec<ColumnVec<&'a Vec<SecureField>>>,
// evaluation_accumulator: &'a mut PointEvaluationAccumulator,
// denom_inverse: SecureField,
// ) -> Self {
// let col_index = vec![0; mask.len()];
// Self {
// mask,
// evaluation_accumulator,
// col_index,
// denom_inverse,
// }
// }
// }
// impl<'a> EvalAtRow for MleCoeffColEvalAccumulator<'a> {
// type F = SecureField;
// type EF = SecureField;

// fn next_interaction_mask<const N: usize>(
// &mut self,
// interaction: usize,
// _offsets: [isize; N],
// ) -> [Self::F; N] {
// let col_index = self.col_index[interaction];
// self.col_index[interaction] += 1;
// let mask = self.mask[interaction][col_index].clone();
// assert_eq!(mask.len(), N);
// mask.try_into().unwrap()
// }
// fn add_constraint<G>(&mut self, constraint: G)
// where
// Self::EF: Mul<G, Output = Self::EF>,
// {
// }
// fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
// SecureField::from_partial_evals(values)
// }
// }

// impl<'a> EvalAtRowWithMle for MleCoeffColEvalAccumulator<'a> {
// fn add_mle_coeff_col_eval(&mut self, eval: Self::EF) {
// todo!()
// }
// }
1 change: 1 addition & 0 deletions crates/prover/src/core/air/accumulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::core::utils::generate_secure_powers;
/// Accumulates N evaluations of u_i(P0) at a single point.
/// Computes f(P0), the combined polynomial at that point.
/// For n accumulated evaluations, the i'th evaluation is multiplied by alpha^(N-1-i).
#[derive(Debug, Clone)]
pub struct PointEvaluationAccumulator {
random_coeff: SecureField,
accumulation: SecureField,
Expand Down
4 changes: 4 additions & 0 deletions crates/prover/src/core/backend/cpu/lookups/gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ mod tests {
let GkrArtifact {
ood_point: r,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?;

Expand Down Expand Up @@ -354,6 +355,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

Expand Down Expand Up @@ -391,6 +393,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

Expand Down Expand Up @@ -427,6 +430,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

Expand Down
4 changes: 4 additions & 0 deletions crates/prover/src/core/backend/simd/lookups/gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?;

Expand Down Expand Up @@ -590,6 +591,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

Expand Down Expand Up @@ -629,6 +631,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

Expand Down Expand Up @@ -666,6 +669,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

Expand Down
13 changes: 12 additions & 1 deletion crates/prover/src/core/lookups/gkr_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use itertools::Itertools;
use num_traits::{One, Zero};
use thiserror::Error;

use super::gkr_verifier::{GkrArtifact, GkrBatchProof, GkrMask};
use super::gkr_verifier::{Gate, GkrArtifact, GkrBatchProof, GkrMask};
use super::mle::{Mle, MleOps};
use super::sumcheck::MultivariatePolyOracle;
use super::utils::{eq, random_linear_combination, UnivariatePoly};
Expand Down Expand Up @@ -409,6 +409,16 @@ pub fn prove_batch<B: GkrOps>(
.collect_vec();
let n_layers = *n_layers_by_instance.iter().max().unwrap();

let gate_by_instance = input_layer_by_instance
.iter()
.map(|l| match l {
Layer::GrandProduct(_) => Gate::GrandProduct,
Layer::LogUpGeneric { .. }
| Layer::LogUpMultiplicities { .. }
| Layer::LogUpSingles { .. } => Gate::LogUp,
})
.collect();

// Evaluate all instance circuits and collect the layer values.
let mut layers_by_instance = input_layer_by_instance
.into_iter()
Expand Down Expand Up @@ -502,6 +512,7 @@ pub fn prove_batch<B: GkrOps>(

let artifact = GkrArtifact {
ood_point,
gate_by_instance,
claims_to_verify_by_instance,
n_variables_by_instance: n_layers_by_instance,
};
Expand Down
Loading
Loading