Skip to content

Commit

Permalink
Precompute twiddles
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 9, 2024
1 parent 710a538 commit 86da951
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 48 deletions.
15 changes: 14 additions & 1 deletion src/core/backend/avx512/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ use super::{as_cpu_vec, AVX512Backend, VECS_LOG_SIZE};
use crate::core::backend::avx512::fft::rfft;
use crate::core::backend::avx512::BaseFieldVec;
use crate::core::backend::CPUBackend;
use crate::core::circle::CirclePoint;
use crate::core::circle::{CirclePoint, Coset};
use crate::core::fields::m31::BaseField;
use crate::core::fields::{Col, ExtensionOf, FieldExpOps};
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
};
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::utils::fold;
use crate::core::poly::BitReversedOrder;

Expand All @@ -29,6 +30,7 @@ impl PolyOps<BaseField> for AVX512Backend {

fn interpolate(
eval: CircleEvaluation<Self, BaseField, BitReversedOrder>,
_itwiddles: &TwiddleTree<Self, BaseField>,
) -> CirclePoly<Self, BaseField> {
let mut values = eval.values;
let log_size = values.length.ilog2();
Expand Down Expand Up @@ -86,6 +88,7 @@ impl PolyOps<BaseField> for AVX512Backend {
fn evaluate(
poly: &CirclePoly<Self, BaseField>,
domain: CircleDomain,
_twiddles: &TwiddleTree<Self, BaseField>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
// TODO(spapini): Precompute twiddles.
// TODO(spapini): Handle small cases.
Expand Down Expand Up @@ -141,6 +144,16 @@ impl PolyOps<BaseField> for AVX512Backend {
poly.evaluate(CanonicCoset::new(log_size).circle_domain())
.interpolate()
}

type Twiddles = ();

fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self, BaseField> {
TwiddleTree {
root_coset: coset,
twiddles: (),
itwiddles: (),
}
}
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
Expand Down
160 changes: 117 additions & 43 deletions src/core/backend/cpu/circle.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,16 @@
use super::CPUBackend;
use crate::core::circle::CirclePoint;
use crate::core::circle::{CirclePoint, Coset};
use crate::core::fft::{butterfly, ibutterfly};
use crate::core::fields::m31::BaseField;
use crate::core::fields::{Col, ExtensionOf, FieldExpOps, FieldOps};
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
};
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::utils::fold;
use crate::core::poly::BitReversedOrder;
use crate::core::utils::bit_reverse;

fn get_twiddles(domain: CircleDomain) -> Vec<Vec<BaseField>> {
let mut coset = domain.half_coset;

let mut res = vec![];
res.push(coset.iter().map(|p| (p.y)).collect::<Vec<_>>());
bit_reverse(res.last_mut().unwrap());
for _ in 0..coset.log_size() {
res.push(
coset
.iter()
.take(coset.size() / 2)
.map(|p| (p.x))
.collect::<Vec<_>>(),
);
bit_reverse(res.last_mut().unwrap());
coset = coset.double();
}

res
}

impl<F: ExtensionOf<BaseField>> PolyOps<F> for CPUBackend {
fn new_canonical_ordered(
coset: CanonicCoset,
Expand All @@ -50,19 +30,56 @@ impl<F: ExtensionOf<BaseField>> PolyOps<F> for CPUBackend {
CircleEvaluation::new(domain, new_values)
}

fn interpolate(eval: CircleEvaluation<Self, F, BitReversedOrder>) -> CirclePoly<Self, F> {
let twiddles = get_twiddles(eval.domain);

fn interpolate(
eval: CircleEvaluation<Self, F, BitReversedOrder>,
twiddles: &TwiddleTree<Self, F>,
) -> CirclePoly<Self, F> {
let mut values = eval.values;
for (i, layer_twiddles) in twiddles.iter().enumerate() {

let twiddle_buffer = &twiddles.itwiddles;
let mut x_twiddles = (0..eval.domain.half_coset.log_size())
.map(|i| {
let len = 1 << i;
&twiddle_buffer[twiddle_buffer.len() - len * 2..twiddle_buffer.len() - len]
})
.rev()
.peekable();

if eval.domain.log_size() == 1 {
let (mut val0, mut val1) = (values[0], values[1]);
ibutterfly(
&mut val0,
&mut val1,
eval.domain.half_coset.initial.y.inverse(),
);
let inv = BaseField::from_u32_unchecked(2).inverse();
(values[0], values[1]) = (val0 * inv, val1 * inv);
return CirclePoly::new(values);
};

// [x,y] => [y,-y,-x,x]
let y_twiddles = x_twiddles
.peek()
.unwrap()
.array_chunks()
.flat_map(|&[x, y]| [y, -y, -x, x]);

let ifft_loop = |values: &mut [F], i: usize, h: usize, t: BaseField| {
for l in 0..(1 << i) {
let idx0 = (h << (i + 1)) + l;
let idx1 = idx0 + (1 << i);
let (mut val0, mut val1) = (values[idx0], values[idx1]);
ibutterfly(&mut val0, &mut val1, t);
(values[idx0], values[idx1]) = (val0, val1);
}
};

for (h, t) in y_twiddles.enumerate() {
ifft_loop(&mut values, 0, h, t);
}
for (i, layer_twiddles) in x_twiddles.enumerate() {
for (h, &t) in layer_twiddles.iter().enumerate() {
for l in 0..(1 << i) {
let idx0 = (h << (i + 1)) + l;
let idx1 = idx0 + (1 << i);
let (mut val0, mut val1) = (values[idx0], values[idx1]);
ibutterfly(&mut val0, &mut val1, t.inverse());
(values[idx0], values[idx1]) = (val0, val1);
}
ifft_loop(&mut values, i + 1, h, t);
}
}

Expand Down Expand Up @@ -98,23 +115,80 @@ impl<F: ExtensionOf<BaseField>> PolyOps<F> for CPUBackend {
fn evaluate(
poly: &CirclePoly<Self, F>,
domain: CircleDomain,
twiddles: &TwiddleTree<Self, F>,
) -> CircleEvaluation<Self, F, BitReversedOrder> {
let twiddles = get_twiddles(domain);

let mut values = poly.extend(domain.log_size()).coeffs;
for (i, layer_twiddles) in twiddles.iter().enumerate().rev() {

let twiddle_buffer = &twiddles.twiddles;
let mut x_twiddles = (0..domain.half_coset.log_size())
.map(|i| {
let len = 1 << i;
&twiddle_buffer[twiddle_buffer.len() - len * 2..twiddle_buffer.len() - len]
})
.rev()
.peekable();

if domain.log_size() == 1 {
let (mut val0, mut val1) = (values[0], values[1]);
butterfly(&mut val0, &mut val1, domain.half_coset.initial.y.inverse());
return CircleEvaluation::new(domain, values);
};

// [x,y] => [y,-y,-x,x]
let y_twiddles = x_twiddles
.peek()
.unwrap()
.array_chunks()
.flat_map(|&[x, y]| [y, -y, -x, x]);

let fft_loop = |values: &mut [F], i: usize, h: usize, t: BaseField| {
for l in 0..(1 << i) {
let idx0 = (h << (i + 1)) + l;
let idx1 = idx0 + (1 << i);
let (mut val0, mut val1) = (values[idx0], values[idx1]);
butterfly(&mut val0, &mut val1, t);
(values[idx0], values[idx1]) = (val0, val1);
}
};

for (i, layer_twiddles) in x_twiddles.enumerate().rev() {
for (h, &t) in layer_twiddles.iter().enumerate() {
for l in 0..(1 << i) {
let idx0 = (h << (i + 1)) + l;
let idx1 = idx0 + (1 << i);
let (mut val0, mut val1) = (values[idx0], values[idx1]);
butterfly(&mut val0, &mut val1, t);
(values[idx0], values[idx1]) = (val0, val1);
}
fft_loop(&mut values, i + 1, h, t);
}
}
for (h, t) in y_twiddles.enumerate() {
fft_loop(&mut values, 0, h, t);
}

CircleEvaluation::new(domain, values)
}

type Twiddles = Vec<BaseField>;
fn precompute_twiddles(mut coset: Coset) -> TwiddleTree<Self, F> {
let mut twiddles = Vec::with_capacity(coset.size());
for _ in 0..coset.log_size() {
let i0 = twiddles.len();
twiddles.extend(
coset
.iter()
.take(coset.size() / 2)
.map(|p| p.x)
.collect::<Vec<_>>(),
);
bit_reverse(&mut twiddles[i0..]);
coset = coset.double();
}
twiddles.push(1.into());

// TODO(spapini): Batch inverse.
let itwiddles = twiddles.iter().map(|&t| t.inverse()).collect();

TwiddleTree {
root_coset: coset,
twiddles,
itwiddles,
}
}
}

impl<F: ExtensionOf<BaseField>, EvalOrder> IntoIterator
Expand Down
4 changes: 4 additions & 0 deletions src/core/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,10 @@ impl Coset {
}
}

pub fn repeated_double(&self, n_doubles: u32) -> Self {
(0..n_doubles).fold(*self, |c, _| c.double())
}

pub fn initial(&self) -> CirclePoint<M31> {
self.initial
}
Expand Down
11 changes: 10 additions & 1 deletion src/core/poly/circle/evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::core::backend::cpu::CPUCircleEvaluation;
use crate::core::circle::{CirclePointIndex, Coset};
use crate::core::fields::m31::BaseField;
use crate::core::fields::{Col, Column, ExtensionOf, FieldOps};
use crate::core::poly::twiddles::TwiddleBank;
use crate::core::poly::{BitReversedOrder, NaturalOrder};
use crate::core::utils::bit_reverse_index;

Expand Down Expand Up @@ -76,7 +77,15 @@ impl<B: PolyOps<F>, F: ExtensionOf<BaseField>> CircleEvaluation<B, F, BitReverse

/// Computes a minimal [CirclePoly] that evaluates to the same values as this evaluation.
pub fn interpolate(self) -> CirclePoly<B, F> {
B::interpolate(self)
let coset = self.domain.half_coset;
B::interpolate(self, &B::precompute_twiddles(coset))
}

/// Computes a minimal [CirclePoly] that evaluates to the same values as this evaluation, using
/// preconditioned twiddles.
pub fn interpolate_with_twiddles(self, twiddles: &TwiddleBank<B, F>) -> CirclePoly<B, F> {
let coset = self.domain.half_coset;
B::interpolate(self, twiddles.get_tree(coset))
}

pub fn bit_reverse(mut self) -> CircleEvaluation<B, F, NaturalOrder> {
Expand Down
13 changes: 11 additions & 2 deletions src/core/poly/circle/ops.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use super::{CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly};
use crate::core::circle::CirclePoint;
use crate::core::circle::{CirclePoint, Coset};
use crate::core::fields::m31::BaseField;
use crate::core::fields::{Col, ExtensionOf, FieldOps};
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::BitReversedOrder;

pub trait PolyOps<F: ExtensionOf<BaseField>>: FieldOps<F> + Sized {
Expand All @@ -14,7 +15,10 @@ pub trait PolyOps<F: ExtensionOf<BaseField>>: FieldOps<F> + Sized {

/// Computes a minimal [CirclePoly] that evaluates to the same values as this evaluation.
/// Used by the [`CircleEvaluation::interpolate()`] function.
fn interpolate(eval: CircleEvaluation<Self, F, BitReversedOrder>) -> CirclePoly<Self, F>;
fn interpolate(
eval: CircleEvaluation<Self, F, BitReversedOrder>,
itwiddles: &TwiddleTree<Self, F>,
) -> CirclePoly<Self, F>;

/// Evaluates the polynomial at a single point.
/// Used by the [`CirclePoly::eval_at_point()`] function.
Expand All @@ -29,5 +33,10 @@ pub trait PolyOps<F: ExtensionOf<BaseField>>: FieldOps<F> + Sized {
fn evaluate(
poly: &CirclePoly<Self, F>,
domain: CircleDomain,
twiddles: &TwiddleTree<Self, F>,
) -> CircleEvaluation<Self, F, BitReversedOrder>;

type Twiddles;
/// Precomputes twiddles for a given coset.
fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self, F>;
}
12 changes: 11 additions & 1 deletion src/core/poly/circle/poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::{CircleDomain, CircleEvaluation, PolyOps};
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::fields::{Col, Column, ExtensionOf, FieldOps};
use crate::core::poly::twiddles::TwiddleBank;
use crate::core::poly::BitReversedOrder;

/// A polynomial defined on a [CircleDomain].
Expand Down Expand Up @@ -47,7 +48,16 @@ impl<F: ExtensionOf<BaseField>, B: PolyOps<F>> CirclePoly<B, F> {

/// Evaluates the polynomial at all points in the domain.
pub fn evaluate(&self, domain: CircleDomain) -> CircleEvaluation<B, F, BitReversedOrder> {
B::evaluate(self, domain)
B::evaluate(self, domain, &B::precompute_twiddles(domain.half_coset))
}

/// Evaluates the polynomial at all points in the domain, using precomputed twiddles.
pub fn evaluate_with_twiddles(
&self,
domain: CircleDomain,
twiddles: &TwiddleBank<B, F>,
) -> CircleEvaluation<B, F, BitReversedOrder> {
B::evaluate(self, domain, twiddles.get_tree(domain.half_coset))
}
}

Expand Down
1 change: 1 addition & 0 deletions src/core/poly/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod circle;
pub mod line;
// TODO(spapini): Remove pub, when LinePoly moved to the backend as well, and we can move the fold
// function there.
pub mod twiddles;
pub mod utils;

/// Bit-reversed evaluation ordering.
Expand Down
37 changes: 37 additions & 0 deletions src/core/poly/twiddles.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use super::circle::PolyOps;
use crate::core::circle::Coset;
use crate::core::fields::m31::BaseField;
use crate::core::fields::ExtensionOf;

// TODO(spapini): If we decide to only have rectangular components, there will only be a single tree
// and we don't need a bank.
/// A bank holding precomputed [TwiddleTree]s for different coset towers.
/// A coset tower is every repeated doubling of a root coset
pub struct TwiddleBank<B: PolyOps<F>, F: ExtensionOf<BaseField>> {
trees: Vec<TwiddleTree<B, F>>,
}
impl<B: PolyOps<F>, F: ExtensionOf<BaseField>> TwiddleBank<B, F> {
pub fn get_tree(&self, coset: Coset) -> &TwiddleTree<B, F> {
self.trees
.iter()
.find(|t| {
t.root_coset.log_size() >= coset.log_size()
&& t.root_coset
.repeated_double(t.root_coset.log_size() - coset.log_size())
== coset
})
.expect("No precomputed twiddles found.")
}
pub fn add_tree(&mut self, coset: Coset) {
self.trees.push(B::precompute_twiddles(coset));
}
}

/// Precomputed twiddles for a specific coset tower.
/// A coset tower is every repeated doubling of a root coset
pub struct TwiddleTree<B: PolyOps<F>, F: ExtensionOf<BaseField>> {
pub root_coset: Coset,
// TODO(spapini): Represent a slice, and grabbing, in a generic way
pub twiddles: B::Twiddles,
pub itwiddles: B::Twiddles,
}

0 comments on commit 86da951

Please sign in to comment.