Skip to content

Commit

Permalink
CommitmentSchemeProver generic in backend
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 21, 2024
1 parent 2ffa947 commit 9ccea91
Show file tree
Hide file tree
Showing 15 changed files with 52 additions and 78 deletions.
2 changes: 1 addition & 1 deletion src/core/backend/avx512/bit_reverse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,6 @@ mod tests {
let mut data: BaseFieldVec = data.into_iter().collect();

bit_reverse_m31(&mut data.data[..]);
assert_eq!(data.to_vec(), expected);
assert_eq!(data.to_cpu(), expected);
}
}
8 changes: 4 additions & 4 deletions src/core/backend/avx512/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ mod tests {
);
let poly = evaluation.clone().interpolate();
let evaluation2 = poly.evaluate(domain);
assert_eq!(evaluation.values.to_vec(), evaluation2.values.to_vec());
assert_eq!(evaluation.values.to_cpu(), evaluation2.values.to_cpu());
}
}

Expand All @@ -369,8 +369,8 @@ mod tests {
let evaluation2 = poly.evaluate(domain_ext);
let poly2 = evaluation2.interpolate();
assert_eq!(
poly.extend(log_size + 3).coeffs.to_vec(),
poly2.coeffs.to_vec()
poly.extend(log_size + 3).coeffs.to_cpu(),
poly2.coeffs.to_cpu()
);
}
}
Expand Down Expand Up @@ -411,7 +411,7 @@ mod tests {
.extend(log_size + 2)
.evaluate(CanonicCoset::new(log_size + 2).circle_domain());

assert_eq!(eval0.values.to_vec(), eval1.values.to_vec());
assert_eq!(eval0.values.to_cpu(), eval1.values.to_cpu());
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/core/backend/avx512/fft/ifft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ mod tests {
);

// Compare.
assert_eq!(values.to_vec(), expected_coeffs);
assert_eq!(values.to_cpu(), expected_coeffs);
}
}
}
Expand Down Expand Up @@ -718,7 +718,7 @@ mod tests {
);

// Compare.
assert_eq!(values.to_vec(), expected_coeffs);
assert_eq!(values.to_cpu(), expected_coeffs);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/core/backend/avx512/fft/rfft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ mod tests {
);

// Compare.
assert_eq!(values.to_vec(), expected_coeffs);
assert_eq!(values.to_cpu(), expected_coeffs);
}
}
}
Expand Down Expand Up @@ -690,7 +690,7 @@ mod tests {
);

// Compare.
assert_eq!(values.to_vec(), expected_coeffs);
assert_eq!(values.to_cpu(), expected_coeffs);
}
}

Expand Down
35 changes: 3 additions & 32 deletions src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,13 @@ pub mod qm31;
pub mod quotients;

use bytemuck::{cast_slice, cast_slice_mut, Pod, Zeroable};
use itertools::izip;
use num_traits::Zero;

use self::bit_reverse::bit_reverse_m31;
use self::cm31::PackedCM31;
pub use self::m31::{PackedBaseField, K_BLOCK_SIZE};
use self::qm31::PackedQM31;
use super::{Column, ColumnOps};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
use crate::core::fields::{FieldExpOps, FieldOps};
use crate::core::utils;
Expand Down Expand Up @@ -79,7 +76,7 @@ impl Column<BaseField> for BaseFieldVec {
length: len,
}
}
fn to_vec(&self) -> Vec<BaseField> {
fn to_cpu(&self) -> Vec<BaseField> {
self.data
.iter()
.flat_map(|x| x.to_array())
Expand Down Expand Up @@ -133,21 +130,6 @@ impl FromIterator<BaseField> for BaseFieldVec {
}

impl SecureColumn<AVX512Backend> {
pub fn at(&self, vec_index: usize) -> PackedQM31 {
unsafe {
PackedQM31([
PackedCM31([
*self.columns[0].data.get_unchecked(vec_index),
*self.columns[1].data.get_unchecked(vec_index),
]),
PackedCM31([
*self.columns[2].data.get_unchecked(vec_index),
*self.columns[3].data.get_unchecked(vec_index),
]),
])
}
}

pub fn set(&mut self, vec_index: usize, value: PackedQM31) {
unsafe {
*self.columns[0].data.get_unchecked_mut(vec_index) = value.a().a();
Expand All @@ -156,17 +138,6 @@ impl SecureColumn<AVX512Backend> {
*self.columns[3].data.get_unchecked_mut(vec_index) = value.b().b();
}
}

pub fn to_cpu(&self) -> Vec<SecureField> {
izip!(
self.columns[0].to_vec(),
self.columns[1].to_vec(),
self.columns[2].to_vec(),
self.columns[3].to_vec(),
)
.map(|(a, b, c, d)| SecureField::from_m31_array([a, b, c, d]))
.collect()
}
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
Expand All @@ -186,7 +157,7 @@ mod tests {
for i in 0..100 {
let col = Col::<B, BaseField>::from_iter((0..i).map(BaseField::from));
assert_eq!(
col.to_vec(),
col.to_cpu(),
(0..i).map(BaseField::from).collect::<Vec<_>>()
);
for j in 0..i {
Expand All @@ -202,7 +173,7 @@ mod tests {
let mut col = Col::<B, BaseField>::from_iter((0..len).map(BaseField::from));
B::bit_reverse_column(&mut col);
assert_eq!(
col.to_vec(),
col.to_cpu(),
(0..len)
.map(|x| BaseField::from(utils::bit_reverse_index(x, i as u32)))
.collect::<Vec<_>>()
Expand Down
2 changes: 1 addition & 1 deletion src/core/backend/avx512/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ mod tests {
.map(|c| {
CircleEvaluation::<CPUBackend, _, BitReversedOrder>::new(
c.domain,
c.values.to_vec(),
c.values.to_cpu(),
)
})
.collect::<Vec<_>>();
Expand Down
2 changes: 1 addition & 1 deletion src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl<T: Debug + Clone + Default> Column<T> for Vec<T> {
fn zeros(len: usize) -> Self {
vec![T::default(); len]
}
fn to_vec(&self) -> Vec<T> {
fn to_cpu(&self) -> Vec<T> {
self.clone()
}
fn len(&self) -> usize {
Expand Down
6 changes: 4 additions & 2 deletions src/core/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@ use std::fmt::Debug;

pub use cpu::CPUBackend;

use super::commitment_scheme::quotients::QuotientOps;
use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
use super::fields::FieldOps;
use super::fri::FriOps;
use super::poly::circle::PolyOps;

pub mod avx512;
pub mod cpu;

pub trait Backend:
Copy + Clone + Debug + FieldOps<BaseField> + FieldOps<SecureField> + PolyOps
Copy + Clone + Debug + FieldOps<BaseField> + FieldOps<SecureField> + PolyOps + QuotientOps + FriOps
{
}

Expand All @@ -27,7 +29,7 @@ pub trait Column<T>: Clone + Debug + FromIterator<T> {
/// Creates a new column of zeros with the given length.
fn zeros(len: usize) -> Self;
/// Returns a cpu vector of the column.
fn to_vec(&self) -> Vec<T>;
fn to_cpu(&self) -> Vec<T>;
/// Returns the length of the column.
fn len(&self) -> usize;
/// Returns true if the column is empty.
Expand Down
37 changes: 19 additions & 18 deletions src/core/commitment_scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use std::ops::Deref;

use itertools::Itertools;

use super::super::backend::cpu::{CPUCircleEvaluation, CPUCirclePoly};
use super::super::backend::CPUBackend;
use super::super::channel::Blake2sChannel;
use super::super::circle::CirclePoint;
Expand All @@ -28,27 +27,28 @@ use super::utils::TreeVec;
use crate::commitment_scheme::blake2_hash::{Blake2sHash, Blake2sHasher};
use crate::commitment_scheme::merkle_decommitment::MerkleDecommitment;
use crate::commitment_scheme::merkle_tree::MerkleTree;
use crate::core::backend::{Backend, Column};
use crate::core::channel::Channel;
use crate::core::poly::circle::SecureEvaluation;
use crate::core::poly::circle::{CircleEvaluation, CirclePoly, SecureEvaluation};

type MerkleHasher = Blake2sHasher;
type ProofChannel = Blake2sChannel;

/// The prover side of a FRI polynomial commitment scheme. See [self].
pub struct CommitmentSchemeProver {
pub trees: TreeVec<CommitmentTreeProver>,
pub struct CommitmentSchemeProver<B: Backend> {
pub trees: TreeVec<CommitmentTreeProver<B>>,
pub log_blowup_factor: u32,
}

impl CommitmentSchemeProver {
impl<B: Backend> CommitmentSchemeProver<B> {
pub fn new(log_blowup_factor: u32) -> Self {
CommitmentSchemeProver {
trees: TreeVec::<CommitmentTreeProver>::default(),
trees: TreeVec::default(),
log_blowup_factor,
}
}

pub fn commit(&mut self, polynomials: ColumnVec<CPUCirclePoly>, channel: &mut ProofChannel) {
pub fn commit(&mut self, polynomials: ColumnVec<CirclePoly<B>>, channel: &mut ProofChannel) {
let tree = CommitmentTreeProver::new(polynomials, self.log_blowup_factor, channel);
self.trees.push(tree);
}
Expand All @@ -57,13 +57,13 @@ impl CommitmentSchemeProver {
self.trees.as_ref().map(|tree| tree.root())
}

pub fn polynomials(&self) -> TreeVec<ColumnVec<&CPUCirclePoly>> {
pub fn polynomials(&self) -> TreeVec<ColumnVec<&CirclePoly<B>>> {
self.trees
.as_ref()
.map(|tree| tree.polynomials.iter().collect())
}

fn evaluations(&self) -> TreeVec<ColumnVec<&CPUCircleEvaluation<BaseField, BitReversedOrder>>> {
fn evaluations(&self) -> TreeVec<ColumnVec<&CircleEvaluation<B, BaseField, BitReversedOrder>>> {
self.trees
.as_ref()
.map(|tree| tree.evaluations.iter().collect())
Expand Down Expand Up @@ -101,7 +101,7 @@ impl CommitmentSchemeProver {
// SecureColumn.
let quotients = quotients
.into_iter()
.map(SecureEvaluation::to_cpu)
.map(SecureEvaluation::to_cpu_circle_eval)
.collect_vec();

// Run FRI commitment phase on the oods quotients.
Expand Down Expand Up @@ -146,16 +146,16 @@ pub struct CommitmentSchemeProof {

/// Prover data for a single commitment tree in a commitment scheme. The commitment scheme allows to
/// commit on a set of polynomials at a time. This corresponds to such a set.
pub struct CommitmentTreeProver {
pub polynomials: ColumnVec<CPUCirclePoly>,
pub evaluations: ColumnVec<CPUCircleEvaluation<BaseField, BitReversedOrder>>,
pub struct CommitmentTreeProver<B: Backend> {
pub polynomials: ColumnVec<CirclePoly<B>>,
pub evaluations: ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
// TODO(AlonH): Change to mixed degree merkle and remove values clone.
commitment: MerkleTree<BaseField, MerkleHasher>,
}

impl CommitmentTreeProver {
impl<B: Backend> CommitmentTreeProver<B> {
fn new(
polynomials: ColumnVec<CPUCirclePoly>,
polynomials: ColumnVec<CirclePoly<B>>,
log_blowup_factor: u32,
channel: &mut ProofChannel,
) -> Self {
Expand All @@ -167,10 +167,11 @@ impl CommitmentTreeProver {
)
})
.collect_vec();
// TODO(spapini): Remove to_cpu() when Merkle support different backends.
let commitment = MerkleTree::<BaseField, MerkleHasher>::commit(
evaluations
.iter()
.map(|eval| eval.values.clone())
.map(|eval| eval.values.to_cpu())
.collect_vec(),
);
channel.mix_digest(commitment.root());
Expand All @@ -194,14 +195,14 @@ impl CommitmentTreeProver {
let values = self
.evaluations
.iter()
.map(|c| queries.iter().map(|p| c[*p]).collect())
.map(|c| queries.iter().map(|p| c.at(*p)).collect())
.collect();
let decommitment = self.commitment.generate_decommitment(queries);
(values, decommitment)
}
}

impl Deref for CommitmentTreeProver {
impl<B: Backend> Deref for CommitmentTreeProver<B> {
type Target = MerkleTree<BaseField, MerkleHasher>;

fn deref(&self) -> &Self::Target {
Expand Down
2 changes: 1 addition & 1 deletion src/core/commitment_scheme/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ pub fn fri_answers_for_log_size(
random_coeff,
&batched_openings,
)
.to_cpu()
.to_cpu_circle_eval()
})
.collect(),
);
Expand Down
18 changes: 9 additions & 9 deletions src/core/fields/secure_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,12 @@ pub struct SecureColumn<B: FieldOps<BaseField>> {
pub columns: [Col<B, BaseField>; SECURE_EXTENSION_DEGREE],
}
impl SecureColumn<CPUBackend> {
pub fn at(&self, index: usize) -> SecureField {
SecureField::from_m31_array(std::array::from_fn(|i| self.columns[i][index]))
}

pub fn set(&mut self, index: usize, value: SecureField) {
self.columns
.iter_mut()
.map(|c| &mut c[index])
.assign(value.to_m31_array());
}

// TODO(spapini): Remove when we no longer use CircleEvaluation<SecureField>.
pub fn to_cpu(&self) -> Vec<SecureField> {
(0..self.len()).map(|i| self.at(i)).collect()
}
}
impl<B: FieldOps<BaseField>> SecureColumn<B> {
pub fn zeros(len: usize) -> Self {
Expand All @@ -43,4 +34,13 @@ impl<B: FieldOps<BaseField>> SecureColumn<B> {
pub fn is_empty(&self) -> bool {
self.columns[0].is_empty()
}

pub fn at(&self, index: usize) -> SecureField {
SecureField::from_m31_array(std::array::from_fn(|i| self.columns[i].at(index)))
}

// TODO(spapini): Remove when we no longer use CircleEvaluation<SecureField>.
pub fn to_cpu(&self) -> Vec<SecureField> {
(0..self.len()).map(|i| self.at(i)).collect()
}
}
2 changes: 1 addition & 1 deletion src/core/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ impl<B: FriOps, H: Hasher<NativeType = u8>> FriLayerProver<B, H> {
fn new(evaluation: LineEvaluation<B, SecureField, BitReversedOrder>) -> Self {
// TODO(spapini): Commit on slice.
// TODO(spapini): Merkle tree in backend.
let merkle_tree = MerkleTree::commit(vec![evaluation.values.to_vec()]);
let merkle_tree = MerkleTree::commit(vec![evaluation.values.to_cpu()]);
#[allow(unreachable_code)]
FriLayerProver {
evaluation,
Expand Down
1 change: 1 addition & 0 deletions src/core/poly/circle/evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ impl<B: FieldOps<F>, F: ExtensionOf<BaseField>, EvalOrder> CircleEvaluation<B, F
// Note: The concrete implementation of the poly operations is in the specific backend used.
// For example, the CPU backend implementation is in `src/core/backend/cpu/poly.rs`.
impl<F: ExtensionOf<BaseField>, B: FieldOps<F>> CircleEvaluation<B, F, NaturalOrder> {
// TODO(spapini): Remove. Is this even used.
pub fn get_at(&self, point_index: CirclePointIndex) -> F {
self.values
.at(self.domain.find(point_index).expect("Not in domain"))
Expand Down
Loading

0 comments on commit 9ccea91

Please sign in to comment.