Skip to content

Commit

Permalink
Make prove() generic in Backend
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 21, 2024
1 parent 9ccea91 commit 0e6f2ac
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 36 deletions.
31 changes: 18 additions & 13 deletions src/core/air/accumulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
//! Given N polynomials, sort them by size: u_0(P), ... u_{N-1}(P).
//! Given a random alpha, the combined polynomial is defined as
//! f(p) = sum_i alpha^{N-1-i} u_i (P).
use crate::core::backend::cpu::CPUCircleEvaluation;
use crate::core::backend::{Backend, CPUBackend};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
use crate::core::fields::FieldExpOps;
use crate::core::poly::circle::{CanonicCoset, CirclePoly, SecureCirclePoly};
use crate::core::fields::{FieldExpOps, FieldOps};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, SecureCirclePoly};
use crate::core::poly::BitReversedOrder;

/// Accumulates evaluations of u_i(P0) at a single point.
Expand Down Expand Up @@ -119,12 +119,17 @@ impl<B: Backend> DomainEvaluationAccumulator<B> {
}
}

impl DomainEvaluationAccumulator<CPUBackend> {
pub trait AccumulationOps: FieldOps<BaseField> + Sized {
/// Accumulates other into column:
/// column = column * alpha + other.
fn accumulate(column: &mut SecureColumn<Self>, alpha: SecureField, other: &SecureColumn<Self>);
}

impl<B: Backend> DomainEvaluationAccumulator<B> {
/// Computes f(P) as coefficients.
pub fn finalize(self) -> SecureCirclePoly {
let mut res_coeffs = SecureColumn::<CPUBackend>::zeros(1 << self.log_size());
pub fn finalize(self) -> SecureCirclePoly<B> {
let mut res_coeffs = SecureColumn::<B>::zeros(1 << self.log_size());
let res_log_size = self.log_size();
let res_size = 1 << res_log_size;

for ((log_size, values), n_cols) in self
.sub_accumulations
Expand All @@ -133,9 +138,9 @@ impl DomainEvaluationAccumulator<CPUBackend> {
.zip(self.n_cols_per_size.iter())
.skip(1)
{
let coeffs = SecureColumn::<CPUBackend> {
let coeffs = SecureColumn::<B> {
columns: values.columns.map(|c| {
CPUCircleEvaluation::<_, BitReversedOrder>::new(
CircleEvaluation::<B, BaseField, BitReversedOrder>::new(
CanonicCoset::new(log_size as u32).circle_domain(),
c,
)
Expand All @@ -146,10 +151,7 @@ impl DomainEvaluationAccumulator<CPUBackend> {
};
// Add column coefficients into result coefficients, element-wise, in-place.
let multiplier = self.random_coeff.pow(*n_cols as u128);
for i in 0..res_size {
let res_coeff = res_coeffs.at(i) * multiplier + coeffs.at(i);
res_coeffs.set(i, res_coeff);
}
B::accumulate(&mut res_coeffs, multiplier, &coeffs);
}

SecureCirclePoly(res_coeffs.columns.map(CirclePoly::new))
Expand All @@ -163,6 +165,9 @@ pub struct ColumnAccumulator<'a, B: Backend> {
}
impl<'a> ColumnAccumulator<'a, CPUBackend> {
pub fn accumulate(&mut self, index: usize, evaluation: SecureField) {
// TODO(spapini): Multiplying QM31 by QM31 is not the best way to do this.
// It's probably better to cache all the coefficient powers and multiply QM31 by M31,
// and only add in QM31.
let val = self.col.at(index) * self.random_coeff_pow + evaluation;
self.col.set(index, val);
}
Expand Down
16 changes: 8 additions & 8 deletions src/core/air/air_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ use itertools::Itertools;

use super::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use super::{Air, ComponentTrace};
use crate::core::backend::CPUBackend;
use crate::core::backend::Backend;
use crate::core::circle::CirclePoint;
use crate::core::fields::qm31::SecureField;
use crate::core::poly::circle::{CanonicCoset, CirclePoly, SecureCirclePoly};
use crate::core::prover::LOG_BLOWUP_FACTOR;
use crate::core::ComponentVec;

pub trait AirExt: Air<CPUBackend> {
pub trait AirExt<B: Backend>: Air<B> {
fn composition_log_degree_bound(&self) -> u32 {
self.components()
.iter()
Expand All @@ -30,8 +30,8 @@ pub trait AirExt: Air<CPUBackend> {
fn compute_composition_polynomial(
&self,
random_coeff: SecureField,
component_traces: &[ComponentTrace<'_, CPUBackend>],
) -> SecureCirclePoly {
component_traces: &[ComponentTrace<'_, B>],
) -> SecureCirclePoly<B> {
let mut accumulator =
DomainEvaluationAccumulator::new(random_coeff, self.composition_log_degree_bound());
zip(self.components(), component_traces).for_each(|(component, trace)| {
Expand All @@ -43,7 +43,7 @@ pub trait AirExt: Air<CPUBackend> {
fn mask_points_and_values(
&self,
point: CirclePoint<SecureField>,
component_traces: &[ComponentTrace<'_, CPUBackend>],
component_traces: &[ComponentTrace<'_, B>],
) -> (
ComponentVec<Vec<CirclePoint<SecureField>>>,
ComponentVec<Vec<SecureField>>,
Expand Down Expand Up @@ -101,8 +101,8 @@ pub trait AirExt: Air<CPUBackend> {

fn component_traces<'a>(
&'a self,
polynomials: &'a [CirclePoly<CPUBackend>],
) -> Vec<ComponentTrace<'_, CPUBackend>> {
polynomials: &'a [CirclePoly<B>],
) -> Vec<ComponentTrace<'_, B>> {
let poly_iter = &mut polynomials.iter();
self.components()
.iter()
Expand All @@ -115,4 +115,4 @@ pub trait AirExt: Air<CPUBackend> {
}
}

impl<A: Air<CPUBackend>> AirExt for A {}
impl<B: Backend, A: Air<B>> AirExt<B> for A {}
13 changes: 13 additions & 0 deletions src/core/backend/cpu/accumulation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
use super::CPUBackend;
use crate::core::air::accumulation::AccumulationOps;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;

impl AccumulationOps for CPUBackend {
fn accumulate(column: &mut SecureColumn<Self>, alpha: SecureField, other: &SecureColumn<Self>) {
for i in 0..column.len() {
let res_coeff = column.at(i) * alpha + other.at(i);
column.set(i, res_coeff);
}
}
}
1 change: 1 addition & 0 deletions src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod accumulation;
mod circle;
mod fri;
pub mod quotients;
Expand Down
11 changes: 10 additions & 1 deletion src/core/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::fmt::Debug;

pub use cpu::CPUBackend;

use super::air::accumulation::AccumulationOps;
use super::commitment_scheme::quotients::QuotientOps;
use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
Expand All @@ -13,7 +14,15 @@ pub mod avx512;
pub mod cpu;

pub trait Backend:
Copy + Clone + Debug + FieldOps<BaseField> + FieldOps<SecureField> + PolyOps + QuotientOps + FriOps
Copy
+ Clone
+ Debug
+ FieldOps<BaseField>
+ FieldOps<SecureField>
+ PolyOps
+ QuotientOps
+ FriOps
+ AccumulationOps
{
}

Expand Down
2 changes: 1 addition & 1 deletion src/core/commitment_scheme/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
mod prover;
pub mod quotients;
pub mod utils;
mod utils;
mod verifier;

pub use self::prover::{CommitmentSchemeProof, CommitmentSchemeProver};
Expand Down
12 changes: 6 additions & 6 deletions src/core/poly/circle/secure_poly.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
use std::ops::Deref;

use super::CircleDomain;
use crate::core::backend::cpu::{CPUCircleEvaluation, CPUCirclePoly};
use super::{CircleDomain, CirclePoly, PolyOps};
use crate::core::backend::cpu::CPUCircleEvaluation;
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::{SecureColumn, SECURE_EXTENSION_DEGREE};
use crate::core::fields::FieldOps;
use crate::core::poly::BitReversedOrder;

pub struct SecureCirclePoly(pub [CPUCirclePoly; SECURE_EXTENSION_DEGREE]);
pub struct SecureCirclePoly<B: FieldOps<BaseField>>(pub [CirclePoly<B>; SECURE_EXTENSION_DEGREE]);

impl SecureCirclePoly {
impl<B: PolyOps> SecureCirclePoly<B> {
pub fn eval_at_point(&self, point: CirclePoint<SecureField>) -> SecureField {
Self::eval_from_partial_evals(self.eval_columns_at_point(point))
}
Expand Down Expand Up @@ -43,8 +43,8 @@ impl SecureCirclePoly {
}
}

impl Deref for SecureCirclePoly {
type Target = [CPUCirclePoly; SECURE_EXTENSION_DEGREE];
impl<B: FieldOps<BaseField>> Deref for SecureCirclePoly<B> {
type Target = [CirclePoly<B>; SECURE_EXTENSION_DEGREE];

fn deref(&self) -> &Self::Target {
&self.0
Expand Down
14 changes: 7 additions & 7 deletions src/core/prover/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use itertools::Itertools;
use thiserror::Error;

use super::backend::Backend;
use super::commitment_scheme::{CommitmentSchemeProof, TreeVec};
use super::fri::FriVerificationError;
use super::poly::circle::{SecureCirclePoly, MAX_CIRCLE_DOMAIN_LOG_SIZE};
Expand All @@ -9,7 +10,6 @@ use super::ColumnVec;
use crate::commitment_scheme::blake2_hash::Blake2sHasher;
use crate::commitment_scheme::hasher::Hasher;
use crate::core::air::{Air, AirExt};
use crate::core::backend::cpu::CPUCircleEvaluation;
use crate::core::backend::CPUBackend;
use crate::core::channel::{Blake2sChannel, Channel as ChannelTrait};
use crate::core::circle::CirclePoint;
Expand Down Expand Up @@ -42,10 +42,10 @@ pub struct AdditionalProofData {
pub oods_quotients: Vec<CircleEvaluation<CPUBackend, SecureField, BitReversedOrder>>,
}

pub fn prove(
air: &impl Air<CPUBackend>,
pub fn prove<B: Backend>(
air: &impl Air<B>,
channel: &mut Channel,
trace: ColumnVec<CPUCircleEvaluation<BaseField, BitReversedOrder>>,
trace: ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
) -> Result<StarkProof, ProvingError> {
// Check that traces are not too big.
for (i, trace) in trace.iter().enumerate() {
Expand Down Expand Up @@ -154,11 +154,11 @@ pub fn verify(
commitment_scheme.verify_values(open_points, proof.commitment_scheme_proof, channel)
}

fn opened_values_to_mask(
air: &impl Air<CPUBackend>,
fn opened_values_to_mask<B: Backend>(
air: &impl Air<B>,
mut opened_values: TreeVec<ColumnVec<Vec<SecureField>>>,
) -> Result<(ComponentVec<Vec<SecureField>>, SecureField), ()> {
let composition_oods_values = SecureCirclePoly::eval_from_partial_evals(
let composition_oods_values = SecureCirclePoly::<B>::eval_from_partial_evals(
opened_values
.pop()
.unwrap()
Expand Down

0 comments on commit 0e6f2ac

Please sign in to comment.