Skip to content

Commit

Permalink
WideFib test with AVX Backend
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 18, 2024
1 parent 12ade2f commit 74caa3e
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 6 deletions.
7 changes: 5 additions & 2 deletions src/core/air/accumulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ impl<B: Backend> DomainEvaluationAccumulator<B> {
.zip(self.n_cols_per_size.iter())
.skip(1)
{
if *n_cols == 0 {
continue;
}
let coeffs = SecureColumn::<B> {
cols: values.cols.map(|c| {
CircleEvaluation::<B, BaseField, BitReversedOrder>::new(
Expand All @@ -160,8 +163,8 @@ impl<B: Backend> DomainEvaluationAccumulator<B> {

/// An domain accumulator for polynomials of a single size.
pub struct ColumnAccumulator<'a, B: Backend> {
random_coeff_pow: SecureField,
col: &'a mut SecureColumn<B>,
pub random_coeff_pow: SecureField,
pub col: &'a mut SecureColumn<B>,
}
impl<'a> ColumnAccumulator<'a, CPUBackend> {
pub fn accumulate(&mut self, index: usize, evaluation: SecureField) {
Expand Down
2 changes: 1 addition & 1 deletion src/core/backend/avx512/accumulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::core::fields::secure::SecureColumn;
impl AccumulationOps for AVX512Backend {
fn accumulate(column: &mut SecureColumn<Self>, alpha: SecureField, other: &SecureColumn<Self>) {
let alpha = PackedQM31::broadcast(alpha);
for i in 0..column.len() {
for i in 0..column.vec_len() {
let res_coeff = column.get_vec(i) * alpha + other.get_vec(i);
column.set_vec(i, res_coeff);
}
Expand Down
3 changes: 3 additions & 0 deletions src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ impl FromIterator<BaseField> for BaseFieldVec {
}

impl SecureColumn<AVX512Backend> {
pub fn vec_len(&self) -> usize {
self.cols[0].data.len()
}
pub fn set_vec(&mut self, vec_index: usize, value: PackedQM31) {
unsafe {
*self.cols[0].data.get_unchecked_mut(vec_index) = value.a().a();
Expand Down
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions src/fibonacci/mod.rs → src/examples/fibonacci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ mod tests {
use crate::core::prover::{prove, verify};
use crate::core::queries::Queries;
use crate::core::utils::bit_reverse;
use crate::fibonacci::air::MultiFibonacciAir;
use crate::fibonacci::verify_proof;
use crate::examples::fibonacci::air::MultiFibonacciAir;
use crate::examples::fibonacci::verify_proof;
use crate::{m31, qm31};

#[test]
Expand Down
2 changes: 2 additions & 0 deletions src/examples/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod fibonacci;
pub mod wide_fib;
167 changes: 167 additions & 0 deletions src/examples/wide_fib/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
use itertools::Itertools;
use num_traits::One;

use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Air, Component, ComponentTrace, Mask};
use crate::core::backend::avx512::qm31::PackedQM31;
use crate::core::backend::avx512::{AVX512Backend, PackedBaseField, VECS_LOG_SIZE};
use crate::core::backend::{Col, Column};
use crate::core::circle::{CirclePoint, Coset};
use crate::core::constraints::coset_vanishing;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::bit_reverse_index;
use crate::core::ColumnVec;

const N_COLS: usize = 1 << 8;

pub struct WideFibAir {
component: WideFibComponent,
}
impl Air<AVX512Backend> for WideFibAir {
fn components(&self) -> Vec<&dyn Component<AVX512Backend>> {
vec![&self.component]
}
}
pub struct WideFibComponent {
pub log_size: u32,
}

pub fn gen_trace(
log_size: usize,
) -> ColumnVec<CircleEvaluation<AVX512Backend, BaseField, BitReversedOrder>> {
assert!(log_size >= VECS_LOG_SIZE);
let mut trace = (0..N_COLS)
.map(|_| Col::<AVX512Backend, BaseField>::zeros(1 << log_size))
.collect_vec();
for vec_index in 0..(1 << (log_size - VECS_LOG_SIZE)) {
let mut a = PackedBaseField::one();
let mut b = PackedBaseField::one();
trace[0].data[vec_index] = a;
trace[1].data[vec_index] = b;
trace.iter_mut().take(log_size).skip(2).for_each(|col| {
(a, b) = (b, a.square() + b.square());
col.data[vec_index] = b;
});
}
let domain = CanonicCoset::new(log_size as u32).circle_domain();
trace
.into_iter()
.map(|eval| CircleEvaluation::<AVX512Backend, _, BitReversedOrder>::new(domain, eval))
.collect_vec()
}

impl Component<AVX512Backend> for WideFibComponent {
fn max_constraint_log_degree_bound(&self) -> u32 {
self.log_size + 1
}

fn trace_log_degree_bounds(&self) -> Vec<u32> {
vec![self.log_size; N_COLS]
}

fn evaluate_constraint_quotients_on_domain(
&self,
trace: &ComponentTrace<'_, AVX512Backend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<AVX512Backend>,
) {
assert_eq!(trace.columns.len(), N_COLS);
// TODO(spapini): Steal evaluation from commitment.
let eval_domain = CanonicCoset::new(self.log_size + 1).circle_domain();
let trace_eval = trace
.columns
.iter()
.map(|poly| poly.evaluate(eval_domain))
.collect_vec();
let random_coeff = PackedQM31::broadcast(evaluation_accumulator.random_coeff);
let column_coeffs = (0..N_COLS)
.scan(PackedQM31::one(), |state, _| {
let res = *state;
*state *= random_coeff;
Some(res)
})
.collect_vec();

let constraint_log_degree_bound = self.log_size + 1;
let [accum] = evaluation_accumulator.columns([(constraint_log_degree_bound, N_COLS - 2)]);

for vec_row in 0..(1 << (eval_domain.log_size() - VECS_LOG_SIZE as u32)) {
// Numerator.
let mut row_res = PackedQM31::zero();
let mut a = trace_eval[0].data[vec_row];
let mut b = trace_eval[1].data[vec_row];
#[allow(clippy::needless_range_loop)]
for i in 0..(N_COLS - 2) {
unsafe {
let c = *trace_eval.get_unchecked(i + 2).data.get_unchecked(vec_row);
row_res = row_res + column_coeffs[i] * (a.square() + b.square() - c);
(a, b) = (b, c);
}
}

// Denominator.
// TODO(spapini): Optimized this, for the small number of columns case.
let points = std::array::from_fn(|i| {
eval_domain.at(bit_reverse_index(
(vec_row << VECS_LOG_SIZE) + i,
eval_domain.log_size(),
) + 1)
});
let mut shifted_xs = PackedBaseField::from_array(points.map(|p| p.x));
for _ in 1..self.log_size {
shifted_xs = shifted_xs.square() - PackedBaseField::one();
}

accum.col.set_vec(
vec_row,
accum.col.get_vec(vec_row) * PackedQM31::broadcast(accum.random_coeff_pow)
+ row_res,
)
}
}

fn mask(&self) -> Mask {
Mask(vec![vec![0]; N_COLS])
}

fn evaluate_constraint_quotients_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &ColumnVec<Vec<SecureField>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
) {
let constraint_zero_domain = Coset::subgroup(self.log_size);
let constraint_log_degree_bound = self.log_size + 1;
for i in 0..(N_COLS - 2) {
let numerator = mask[i][0].square() + mask[i + 1][0].square() - mask[i + 2][0];
let denominator = coset_vanishing(constraint_zero_domain, point);
evaluation_accumulator.accumulate(constraint_log_degree_bound, numerator / denominator);
}
}
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[cfg(test)]
mod tests {
use crate::commitment_scheme::blake2_hash::Blake2sHasher;
use crate::commitment_scheme::hasher::Hasher;
use crate::core::channel::{Blake2sChannel, Channel};
use crate::core::fields::m31::BaseField;
use crate::core::fields::IntoSlice;
use crate::core::prover::prove;
use crate::examples::wide_fib::{gen_trace, WideFibAir, WideFibComponent};

#[test]
fn test_avx_wide_fib_prove() {
// TODO(spapini): Increase to 20, to get 1GB of trace.
const LOG_SIZE: u32 = 16;
let component = WideFibComponent { log_size: LOG_SIZE };
let air = WideFibAir { component };
let trace = gen_trace(LOG_SIZE as usize);
let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[])));
prove(&air, channel, trace);
}
}
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)]
pub mod commitment_scheme;
pub mod core;
pub mod fibonacci;
pub mod examples;
pub mod hash_functions;
pub mod math;
pub mod platform;

0 comments on commit 74caa3e

Please sign in to comment.