From 07065a83f9e8bb2ed74f815a1e15f8e95f707a10 Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Sun, 25 Feb 2024 10:15:30 +0200 Subject: [PATCH] Precompute twiddles --- src/core/backend/avx512/circle.rs | 20 +++- src/core/backend/avx512/fft/mod.rs | 6 +- src/core/backend/cpu/circle.rs | 172 +++++++++++++++++++++-------- src/core/circle.rs | 9 ++ src/core/poly/circle/evaluation.rs | 10 +- src/core/poly/circle/ops.rs | 16 ++- src/core/poly/circle/poly.rs | 12 +- src/core/poly/mod.rs | 1 + src/core/poly/twiddles.rs | 13 +++ 9 files changed, 208 insertions(+), 51 deletions(-) create mode 100644 src/core/poly/twiddles.rs diff --git a/src/core/backend/avx512/circle.rs b/src/core/backend/avx512/circle.rs index 78cde6ac3..23072c942 100644 --- a/src/core/backend/avx512/circle.rs +++ b/src/core/backend/avx512/circle.rs @@ -6,18 +6,21 @@ use super::{as_cpu_vec, AVX512Backend, VECS_LOG_SIZE}; use crate::core::backend::avx512::fft::rfft; use crate::core::backend::avx512::BaseFieldVec; use crate::core::backend::CPUBackend; -use crate::core::circle::CirclePoint; +use crate::core::circle::{CirclePoint, Coset}; use crate::core::fields::m31::BaseField; use crate::core::fields::{Col, ExtensionOf, FieldExpOps}; use crate::core::poly::circle::{ CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps, }; +use crate::core::poly::twiddles::TwiddleTree; use crate::core::poly::utils::fold; use crate::core::poly::BitReversedOrder; // TODO(spapini): Everything is returned in redundant representation, where values can also be P. // Decide if and when it's ok and what to do if it's not. impl PolyOps for AVX512Backend { + type Twiddles = (); + fn new_canonical_ordered( coset: CanonicCoset, values: Col, @@ -27,7 +30,10 @@ impl PolyOps for AVX512Backend { CircleEvaluation::new(eval.domain, Col::::from_iter(eval.values)) } - fn interpolate(eval: CircleEvaluation) -> CirclePoly { + fn interpolate( + eval: CircleEvaluation, + _itwiddles: &TwiddleTree, + ) -> CirclePoly { let mut values = eval.values; let log_size = values.length.ilog2(); @@ -70,6 +76,7 @@ impl PolyOps for AVX512Backend { } mappings.reverse(); + // If the polynomial is large, the fft does a transpose in the middle. if poly.log_size() as usize > CACHED_FFT_LOG_SIZE { let n = mappings.len(); let n0 = (n - VECS_LOG_SIZE) / 2; @@ -91,6 +98,7 @@ impl PolyOps for AVX512Backend { fn evaluate( poly: &CirclePoly, domain: CircleDomain, + _twiddles: &TwiddleTree, ) -> CircleEvaluation { // TODO(spapini): Precompute twiddles. // TODO(spapini): Handle small cases. @@ -140,6 +148,14 @@ impl PolyOps for AVX512Backend { }, ) } + + fn precompute_twiddles(coset: Coset) -> TwiddleTree { + TwiddleTree { + root_coset: coset, + twiddles: (), + itwiddles: (), + } + } } #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] diff --git a/src/core/backend/avx512/fft/mod.rs b/src/core/backend/avx512/fft/mod.rs index 32a7863be..2cce5b073 100644 --- a/src/core/backend/avx512/fft/mod.rs +++ b/src/core/backend/avx512/fft/mod.rs @@ -69,11 +69,13 @@ unsafe fn compute_first_twiddles(twiddle1_dbl: [i32; 8]) -> (__m512i, __m512i) { // 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 let t1 = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl)); - // The twiddles for layer 0 can be computed from the twiddles for layer 1: + // The twiddles for layer 0 can be computed from the twiddles for layer 1. + // Since the twiddles are bit reversed, we consider the circle domain in bit reversed order. + // Each consecutive 4 points in the bit reversed order of a coset form a circle coset of size 4. // A circle coset of size 4 in bit reversed order looks like this: // [(x, y), (-x, -y), (y, -x), (-y, x)] // Note: This is related to the choice of M31_CIRCLE_GEN, and the fact the a quarter rotation - // is (0,-1) and not (0,1). This would cause another relation. + // is (0,-1) and not (0,1). (0,1) would yield another relation. // The twiddles for layer 0 are the y coordinates: // [y, -y, -x, x] // The twiddles for layer 1 in bit reversed order are the x coordinates: diff --git a/src/core/backend/cpu/circle.rs b/src/core/backend/cpu/circle.rs index a3b637162..a6d1b74d1 100644 --- a/src/core/backend/cpu/circle.rs +++ b/src/core/backend/cpu/circle.rs @@ -1,39 +1,21 @@ use num_traits::Zero; use super::CPUBackend; -use crate::core::circle::CirclePoint; +use crate::core::circle::{CirclePoint, Coset}; use crate::core::fft::{butterfly, ibutterfly}; use crate::core::fields::m31::BaseField; use crate::core::fields::{Col, ExtensionOf, FieldExpOps, FieldOps}; use crate::core::poly::circle::{ CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps, }; +use crate::core::poly::twiddles::TwiddleTree; use crate::core::poly::utils::fold; use crate::core::poly::BitReversedOrder; use crate::core::utils::bit_reverse; -fn get_twiddles(domain: CircleDomain) -> Vec> { - let mut coset = domain.half_coset; - - let mut res = vec![]; - res.push(coset.iter().map(|p| (p.y)).collect::>()); - bit_reverse(res.last_mut().unwrap()); - for _ in 0..coset.log_size() { - res.push( - coset - .iter() - .take(coset.size() / 2) - .map(|p| (p.x)) - .collect::>(), - ); - bit_reverse(res.last_mut().unwrap()); - coset = coset.double(); - } - - res -} - impl PolyOps for CPUBackend { + type Twiddles = Vec; + fn new_canonical_ordered( coset: CanonicCoset, values: Col, @@ -52,19 +34,35 @@ impl PolyOps for CPUBackend { CircleEvaluation::new(domain, new_values) } - fn interpolate(eval: CircleEvaluation) -> CirclePoly { - let twiddles = get_twiddles(eval.domain); - + fn interpolate( + eval: CircleEvaluation, + twiddles: &TwiddleTree, + ) -> CirclePoly { let mut values = eval.values; - for (i, layer_twiddles) in twiddles.iter().enumerate() { + + assert!(eval.domain.half_coset.is_doubling_of(twiddles.root_coset)); + let line_twiddles = domain_line_twiddles_from_tree(eval.domain, &twiddles.itwiddles); + + if eval.domain.log_size() == 1 { + let (mut val0, mut val1) = (values[0], values[1]); + ibutterfly( + &mut val0, + &mut val1, + eval.domain.half_coset.initial.y.inverse(), + ); + let inv = BaseField::from_u32_unchecked(2).inverse(); + (values[0], values[1]) = (val0 * inv, val1 * inv); + return CirclePoly::new(values); + }; + + let circle_twiddles = circle_twiddles_from_line_twiddles(line_twiddles[0]); + + for (h, t) in circle_twiddles.enumerate() { + fft_layer_loop(&mut values, 0, h, t, ibutterfly); + } + for (layer, layer_twiddles) in line_twiddles.into_iter().enumerate() { for (h, &t) in layer_twiddles.iter().enumerate() { - for l in 0..(1 << i) { - let idx0 = (h << (i + 1)) + l; - let idx1 = idx0 + (1 << i); - let (mut val0, mut val1) = (values[idx0], values[idx1]); - ibutterfly(&mut val0, &mut val1, t.inverse()); - (values[idx0], values[idx1]) = (val0, val1); - } + fft_layer_loop(&mut values, layer + 1, h, t, ibutterfly); } } @@ -103,23 +101,111 @@ impl PolyOps for CPUBackend { fn evaluate( poly: &CirclePoly, domain: CircleDomain, + twiddles: &TwiddleTree, ) -> CircleEvaluation { - let twiddles = get_twiddles(domain); - let mut values = poly.extend(domain.log_size()).coeffs; - for (i, layer_twiddles) in twiddles.iter().enumerate().rev() { + + assert!(domain.half_coset.is_doubling_of(twiddles.root_coset)); + let line_twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles); + + if domain.log_size() == 1 { + let (mut val0, mut val1) = (values[0], values[1]); + butterfly(&mut val0, &mut val1, domain.half_coset.initial.y.inverse()); + return CircleEvaluation::new(domain, values); + }; + + let circle_twiddles = circle_twiddles_from_line_twiddles(line_twiddles[0]); + + for (layer, layer_twiddles) in line_twiddles.iter().enumerate().rev() { for (h, &t) in layer_twiddles.iter().enumerate() { - for l in 0..(1 << i) { - let idx0 = (h << (i + 1)) + l; - let idx1 = idx0 + (1 << i); - let (mut val0, mut val1) = (values[idx0], values[idx1]); - butterfly(&mut val0, &mut val1, t); - (values[idx0], values[idx1]) = (val0, val1); - } + fft_layer_loop(&mut values, layer + 1, h, t, butterfly); } } + for (h, t) in circle_twiddles.enumerate() { + fft_layer_loop(&mut values, 0, h, t, butterfly); + } + CircleEvaluation::new(domain, values) } + + fn precompute_twiddles(mut coset: Coset) -> TwiddleTree { + let root_coset = coset; + let mut twiddles = Vec::with_capacity(coset.size()); + for _ in 0..coset.log_size() { + let i0 = twiddles.len(); + twiddles.extend( + coset + .iter() + .take(coset.size() / 2) + .map(|p| p.x) + .collect::>(), + ); + bit_reverse(&mut twiddles[i0..]); + coset = coset.double(); + } + twiddles.push(1.into()); + + // TODO(spapini): Batch inverse. + let itwiddles = twiddles.iter().map(|&t| t.inverse()).collect(); + + TwiddleTree { + root_coset, + twiddles, + itwiddles, + } + } +} + +/// Computes the line twiddles for a [CircleDomain] from the precomputed twiddles tree. +fn domain_line_twiddles_from_tree( + domain: CircleDomain, + twiddle_buffer: &[BaseField], +) -> Vec<&[BaseField]> { + (0..domain.half_coset.log_size()) + .map(|i| { + let len = 1 << i; + &twiddle_buffer[twiddle_buffer.len() - len * 2..twiddle_buffer.len() - len] + }) + .rev() + .collect() +} + +fn fft_layer_loop( + values: &mut [BaseField], + i: usize, + h: usize, + t: BaseField, + butterfly_fn: impl Fn(&mut BaseField, &mut BaseField, BaseField), +) { + for l in 0..(1 << i) { + let idx0 = (h << (i + 1)) + l; + let idx1 = idx0 + (1 << i); + let (mut val0, mut val1) = (values[idx0], values[idx1]); + butterfly_fn(&mut val0, &mut val1, t); + (values[idx0], values[idx1]) = (val0, val1); + } +} + +/// Computes the circle twiddles layer (layer 0) from the first line twiddles layer (layer 1). +fn circle_twiddles_from_line_twiddles( + first_line_twiddles: &[BaseField], +) -> impl Iterator + '_ { + // The twiddles for layer 0 can be computed from the twiddles for layer 1. + // Since the twiddles are bit reversed, we consider the circle domain in bit reversed order. + // Each consecutive 4 points in the bit reversed order of a coset form a circle coset of size 4. + // A circle coset of size 4 in bit reversed order looks like this: + // [(x, y), (-x, -y), (y, -x), (-y, x)] + // Note: This is related to the choice of M31_CIRCLE_GEN, and the fact the a quarter rotation + // is (0,-1) and not (0,1). (0,1) would yield another relation. + // The twiddles for layer 0 are the y coordinates: + // [y, -y, -x, x] + // The twiddles for layer 1 in bit reversed order are the x coordinates: + // [x, y] + // Works also for inverse of the twiddles. + first_line_twiddles + .iter() + .array_chunks() + .flat_map(|[&x, &y]| [y, -y, -x, x]) } impl, EvalOrder> IntoIterator diff --git a/src/core/circle.rs b/src/core/circle.rs index f4a199b8e..97f7559a6 100644 --- a/src/core/circle.rs +++ b/src/core/circle.rs @@ -369,6 +369,15 @@ impl Coset { } } + pub fn repeated_double(&self, n_doubles: u32) -> Self { + (0..n_doubles).fold(*self, |coset, _| coset.double()) + } + + pub fn is_doubling_of(&self, other: Self) -> bool { + self.log_size <= other.log_size + && *self == other.repeated_double(other.log_size - self.log_size) + } + pub fn initial(&self) -> CirclePoint { self.initial } diff --git a/src/core/poly/circle/evaluation.rs b/src/core/poly/circle/evaluation.rs index b41faf0eb..f40b07aa5 100644 --- a/src/core/poly/circle/evaluation.rs +++ b/src/core/poly/circle/evaluation.rs @@ -6,6 +6,7 @@ use crate::core::backend::cpu::CPUCircleEvaluation; use crate::core::circle::{CirclePointIndex, Coset}; use crate::core::fields::m31::BaseField; use crate::core::fields::{Col, Column, ExtensionOf, FieldOps}; +use crate::core::poly::twiddles::TwiddleTree; use crate::core::poly::{BitReversedOrder, NaturalOrder}; use crate::core::utils::bit_reverse_index; @@ -76,7 +77,14 @@ impl CircleEvaluation { /// Computes a minimal [CirclePoly] that evaluates to the same values as this evaluation. pub fn interpolate(self) -> CirclePoly { - B::interpolate(self) + let coset = self.domain.half_coset; + B::interpolate(self, &B::precompute_twiddles(coset)) + } + + /// Computes a minimal [CirclePoly] that evaluates to the same values as this evaluation, using + /// precomputed twiddles. + pub fn interpolate_with_twiddles(self, twiddles: &TwiddleTree) -> CirclePoly { + B::interpolate(self, twiddles) } } diff --git a/src/core/poly/circle/ops.rs b/src/core/poly/circle/ops.rs index fc3a54a7b..ac49eb5cf 100644 --- a/src/core/poly/circle/ops.rs +++ b/src/core/poly/circle/ops.rs @@ -1,11 +1,16 @@ use super::{CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly}; -use crate::core::circle::CirclePoint; +use crate::core::circle::{CirclePoint, Coset}; use crate::core::fields::m31::BaseField; use crate::core::fields::{Col, ExtensionOf, FieldOps}; +use crate::core::poly::twiddles::TwiddleTree; use crate::core::poly::BitReversedOrder; /// Operations on BaseField polynomials. pub trait PolyOps: FieldOps + Sized { + // TODO(spapini): Use a column instead of this type. + /// The type for precomputed twiddles. + type Twiddles; + /// Creates a [CircleEvaluation] from values ordered according to [CanonicCoset]. /// Used by the [`CircleEvaluation::new_canonical_ordered()`] function. fn new_canonical_ordered( @@ -15,7 +20,10 @@ pub trait PolyOps: FieldOps + Sized { /// Computes a minimal [CirclePoly] that evaluates to the same values as this evaluation. /// Used by the [`CircleEvaluation::interpolate()`] function. - fn interpolate(eval: CircleEvaluation) -> CirclePoly; + fn interpolate( + eval: CircleEvaluation, + itwiddles: &TwiddleTree, + ) -> CirclePoly; /// Evaluates the polynomial at a single point. /// Used by the [`CirclePoly::eval_at_point()`] function. @@ -33,5 +41,9 @@ pub trait PolyOps: FieldOps + Sized { fn evaluate( poly: &CirclePoly, domain: CircleDomain, + twiddles: &TwiddleTree, ) -> CircleEvaluation; + + /// Precomputes twiddles for a given coset. + fn precompute_twiddles(coset: Coset) -> TwiddleTree; } diff --git a/src/core/poly/circle/poly.rs b/src/core/poly/circle/poly.rs index d7ea19428..e48f3cbba 100644 --- a/src/core/poly/circle/poly.rs +++ b/src/core/poly/circle/poly.rs @@ -2,6 +2,7 @@ use super::{CircleDomain, CircleEvaluation, PolyOps}; use crate::core::circle::CirclePoint; use crate::core::fields::m31::BaseField; use crate::core::fields::{Col, Column, ExtensionOf, FieldOps}; +use crate::core::poly::twiddles::TwiddleTree; use crate::core::poly::BitReversedOrder; /// A polynomial defined on a [CircleDomain]. @@ -50,7 +51,16 @@ impl CirclePoly { &self, domain: CircleDomain, ) -> CircleEvaluation { - B::evaluate(self, domain) + B::evaluate(self, domain, &B::precompute_twiddles(domain.half_coset)) + } + + /// Evaluates the polynomial at all points in the domain, using precomputed twiddles. + pub fn evaluate_with_twiddles( + &self, + domain: CircleDomain, + twiddles: &TwiddleTree, + ) -> CircleEvaluation { + B::evaluate(self, domain, twiddles) } } diff --git a/src/core/poly/mod.rs b/src/core/poly/mod.rs index bc63d50fe..301c6988b 100644 --- a/src/core/poly/mod.rs +++ b/src/core/poly/mod.rs @@ -2,6 +2,7 @@ pub mod circle; pub mod line; // TODO(spapini): Remove pub, when LinePoly moved to the backend as well, and we can move the fold // function there. +pub mod twiddles; pub mod utils; /// Bit-reversed evaluation ordering. diff --git a/src/core/poly/twiddles.rs b/src/core/poly/twiddles.rs new file mode 100644 index 000000000..b5adf8e27 --- /dev/null +++ b/src/core/poly/twiddles.rs @@ -0,0 +1,13 @@ +use super::circle::PolyOps; +use crate::core::circle::Coset; + +/// Precomputed twiddles for a specific coset tower. +/// A coset tower is every repeated doubling of a `root_coset`. +/// The largest [CircleDomain] that can be ffted using these twiddles is one with `root_coset` as +/// its `half_coset`. +pub struct TwiddleTree { + pub root_coset: Coset, + // TODO(spapini): Represent a slice, and grabbing, in a generic way + pub twiddles: B::Twiddles, + pub itwiddles: B::Twiddles, +}