Skip to content

Commit

Permalink
Dumb down FRI
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 25, 2024
1 parent 494fd27 commit 9ab5285
Show file tree
Hide file tree
Showing 11 changed files with 286 additions and 332 deletions.
10 changes: 8 additions & 2 deletions benches/fri.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use stwo::core::backend::CPUBackend;
use stwo::core::fields::m31::BaseField;
use stwo::core::fields::qm31::SecureField;
use stwo::core::fields::secure_column::SecureColumn;
use stwo::core::fri::FriOps;
use stwo::core::poly::circle::CanonicCoset;
use stwo::core::poly::line::{LineDomain, LineEvaluation};
Expand All @@ -10,9 +12,13 @@ fn folding_benchmark(c: &mut Criterion) {
let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset());
let evals = LineEvaluation::new(
domain,
vec![BaseField::from_u32_unchecked(712837213).into(); 1 << LOG_SIZE],
SecureColumn {
columns: std::array::from_fn(|i| {
vec![BaseField::from_u32_unchecked(i as u32); 1 << LOG_SIZE]
}),
},
);
let alpha = BaseField::from_u32_unchecked(12389).into();
let alpha = SecureField::from_u32_unchecked(2213980, 2213981, 2213982, 2213983);
c.bench_function("fold_line", |b| {
b.iter(|| {
black_box(CPUBackend::fold_line(black_box(&evals), black_box(alpha)));
Expand Down
2 changes: 1 addition & 1 deletion src/core/air/accumulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ impl DomainEvaluationAccumulator<CPUBackend> {
.zip(self.n_cols_per_size.iter())
.skip(1)
{
let coeffs = SecureColumn {
let coeffs = SecureColumn::<CPUBackend> {
columns: values.columns.map(|c| {
CPUCircleEvaluation::<_, BitReversedOrder>::new(
CanonicCoset::new(log_size as u32).circle_domain(),
Expand Down
34 changes: 13 additions & 21 deletions src/core/backend/cpu/fri.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,25 @@
use std::iter::zip;

use super::CPUBackend;
use crate::core::fft::ibutterfly;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{ExtensionOf, Field, FieldExpOps};
use crate::core::fields::FieldExpOps;
use crate::core::fri::{FriOps, CIRCLE_TO_LINE_FOLD_STEP, FOLD_STEP};
use crate::core::poly::circle::CircleEvaluation;
use crate::core::poly::circle::SecureEvaluation;
use crate::core::poly::line::LineEvaluation;
use crate::core::poly::BitReversedOrder;
use crate::core::utils::bit_reverse_index;

impl FriOps for CPUBackend {
fn fold_line(
eval: &LineEvaluation<Self, SecureField, BitReversedOrder>,
alpha: SecureField,
) -> LineEvaluation<Self, SecureField, BitReversedOrder> {
fn fold_line(eval: &LineEvaluation<Self>, alpha: SecureField) -> LineEvaluation<Self> {
let n = eval.len();
assert!(n >= 2, "Evaluation too small");

let domain = eval.domain();

let folded_values = eval
.values
.into_iter()
.array_chunks()
.enumerate()
.map(|(i, &[f_x, f_neg_x])| {
.map(|(i, [f_x, f_neg_x])| {
// TODO(andrew): Inefficient. Update when domain twiddles get stored in a buffer.
let x = domain.at(bit_reverse_index(i << FOLD_STEP, domain.log_size()));

Expand All @@ -37,22 +31,20 @@ impl FriOps for CPUBackend {

LineEvaluation::new(domain.double(), folded_values)
}
fn fold_circle_into_line<F: Field>(
dst: &mut LineEvaluation<Self, SecureField, BitReversedOrder>,
src: &CircleEvaluation<Self, F, BitReversedOrder>,
fn fold_circle_into_line(
dst: &mut LineEvaluation<Self>,
src: &SecureEvaluation<Self>,
alpha: SecureField,
) where
F: ExtensionOf<BaseField>,
SecureField: ExtensionOf<F> + Field,
{
) {
assert_eq!(src.len() >> CIRCLE_TO_LINE_FOLD_STEP, dst.len());

let domain = src.domain;
let alpha_sq = alpha * alpha;

zip(&mut dst.values, src.array_chunks())
src.into_iter()
.array_chunks()
.enumerate()
.for_each(|(i, (dst, &[f_p, f_neg_p]))| {
.for_each(|(i, [f_p, f_neg_p])| {
// TODO(andrew): Inefficient. Update when domain twiddles get stored in a buffer.
let p = domain.at(bit_reverse_index(
i << CIRCLE_TO_LINE_FOLD_STEP,
Expand All @@ -64,7 +56,7 @@ impl FriOps for CPUBackend {
ibutterfly(&mut f0_px, &mut f1_px, p.y.inverse());
let f_prime = alpha * f1_px + f0_px;

*dst = *dst * alpha_sq + f_prime;
dst.values.set(i, dst.values.at(i) * alpha_sq + f_prime);
});
}
}
2 changes: 0 additions & 2 deletions src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use std::fmt::Debug;
use super::{Backend, Column, ColumnOps, FieldOps};
use crate::core::fields::Field;
use crate::core::poly::circle::{CircleEvaluation, CirclePoly};
use crate::core::poly::line::LineEvaluation;
use crate::core::utils::bit_reverse;

#[derive(Copy, Clone, Debug)]
Expand Down Expand Up @@ -50,7 +49,6 @@ impl<T: Debug + Clone + Default> Column<T> for Vec<T> {
pub type CPUCirclePoly = CirclePoly<CPUBackend>;
pub type CPUCircleEvaluation<F, EvalOrder> = CircleEvaluation<CPUBackend, F, EvalOrder>;
// TODO(spapini): Remove the EvalOrder on LineEvaluation.
pub type CPULineEvaluation<F, EvalOrder> = LineEvaluation<CPUBackend, F, EvalOrder>;

#[cfg(test)]
mod tests {
Expand Down
6 changes: 5 additions & 1 deletion src/core/commitment_scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ impl CommitmentSchemeProver {
let fri_prover = FriProver::<CPUBackend, MerkleHasher>::commit(
channel,
fri_config,
&quotients.flatten_cols_rev(),
&quotients
.flatten_cols_rev()
.into_iter()
.map(|e| e.into())
.collect_vec(),
);

// Proof of work.
Expand Down
2 changes: 1 addition & 1 deletion src/core/commitment_scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ fn eval_quotients_on_sparse_domain(
commitment_domain: CircleDomain,
point: CirclePoint<SecureField>,
value: SecureField,
) -> Result<SparseCircleEvaluation<SecureField>, VerificationError> {
) -> Result<SparseCircleEvaluation, VerificationError> {
let queried_values = &mut queried_values.into_iter();
let res = SparseCircleEvaluation::new(
query_domains
Expand Down
61 changes: 57 additions & 4 deletions src/core/fields/secure_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,11 @@ pub const SECURE_EXTENSION_DEGREE: usize =

/// An array of `SECURE_EXTENSION_DEGREE` base field columns, that represents a column of secure
/// field elements.
#[derive(Clone, Debug)]
pub struct SecureColumn<B: Backend> {
pub columns: [Col<B, BaseField>; SECURE_EXTENSION_DEGREE],
}
impl SecureColumn<CPUBackend> {
pub fn at(&self, index: usize) -> SecureField {
SecureField::from_m31_array(std::array::from_fn(|i| self.columns[i][index]))
}

pub fn set(&mut self, index: usize, value: SecureField) {
self.columns
.iter_mut()
Expand All @@ -25,6 +22,10 @@ impl SecureColumn<CPUBackend> {
}
}
impl<B: Backend> SecureColumn<B> {
pub fn at(&self, index: usize) -> SecureField {
SecureField::from_m31_array(std::array::from_fn(|i| self.columns[i].at(index)))
}

pub fn zeros(len: usize) -> Self {
Self {
columns: std::array::from_fn(|_| Col::<B, BaseField>::zeros(len)),
Expand All @@ -38,4 +39,56 @@ impl<B: Backend> SecureColumn<B> {
pub fn is_empty(&self) -> bool {
self.columns[0].is_empty()
}

pub fn to_cpu(&self) -> SecureColumn<CPUBackend> {
SecureColumn {
columns: self.columns.clone().map(|c| c.to_vec()),
}
}
}

pub struct SecureColumnIter<'a> {
column: &'a SecureColumn<CPUBackend>,
index: usize,
}
impl Iterator for SecureColumnIter<'_> {
type Item = SecureField;

fn next(&mut self) -> Option<Self::Item> {
if self.index < self.column.len() {
let value = self.column.at(self.index);
self.index += 1;
Some(value)
} else {
None
}
}
}
impl<'a> IntoIterator for &'a SecureColumn<CPUBackend> {
type Item = SecureField;
type IntoIter = SecureColumnIter<'a>;

fn into_iter(self) -> Self::IntoIter {
SecureColumnIter {
column: self,
index: 0,
}
}
}
impl FromIterator<SecureField> for SecureColumn<CPUBackend> {
fn from_iter<I: IntoIterator<Item = SecureField>>(iter: I) -> Self {
let mut columns = std::array::from_fn(|_| vec![]);
for value in iter.into_iter() {
let vals = value.to_m31_array();
for j in 0..SECURE_EXTENSION_DEGREE {
columns[j].push(vals[j]);
}
}
SecureColumn { columns }
}
}
impl From<SecureColumn<CPUBackend>> for Vec<SecureField> {
fn from(column: SecureColumn<CPUBackend>) -> Self {
column.into_iter().collect()
}
}
Loading

0 comments on commit 9ab5285

Please sign in to comment.