From db9ada953eab439a33c4b430a5f586f851ac161b Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Thu, 7 Mar 2024 11:35:17 +0200 Subject: [PATCH] Use precomputed twiddles in avx --- benches/fft.rs | 2 +- src/core/backend/avx512/circle.rs | 55 +++++++++++++++++++++-------- src/core/backend/avx512/fft/ifft.rs | 12 +++---- src/core/backend/avx512/fft/mod.rs | 2 +- src/core/backend/avx512/fft/rfft.rs | 12 +++---- 5 files changed, 52 insertions(+), 31 deletions(-) diff --git a/benches/fft.rs b/benches/fft.rs index f3254c16c..5fada0f69 100644 --- a/benches/fft.rs +++ b/benches/fft.rs @@ -21,7 +21,7 @@ pub fn avx512_ifft(c: &mut criterion::Criterion) { // Compute. let mut values = BaseFieldVec::from_iter(values); - let twiddle_dbls = get_itwiddle_dbls(domain); + let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); c.bench_function("avx ifft", |b| { b.iter(|| unsafe { diff --git a/src/core/backend/avx512/circle.rs b/src/core/backend/avx512/circle.rs index 23072c942..3223787e3 100644 --- a/src/core/backend/avx512/circle.rs +++ b/src/core/backend/avx512/circle.rs @@ -1,4 +1,5 @@ use bytemuck::cast_slice; +use itertools::Itertools; use super::fft::{ifft, CACHED_FFT_LOG_SIZE}; use super::m31::PackedBaseField; @@ -19,7 +20,7 @@ use crate::core::poly::BitReversedOrder; // TODO(spapini): Everything is returned in redundant representation, where values can also be P. // Decide if and when it's ok and what to do if it's not. impl PolyOps for AVX512Backend { - type Twiddles = (); + type Twiddles = Vec; fn new_canonical_ordered( coset: CanonicCoset, @@ -32,23 +33,25 @@ impl PolyOps for AVX512Backend { fn interpolate( eval: CircleEvaluation, - _itwiddles: &TwiddleTree, + itwiddles: &TwiddleTree, ) -> CirclePoly { let mut values = eval.values; let log_size = values.length.ilog2(); - // TODO(spapini): Precompute twiddles. - let twiddle_dbls = ifft::get_itwiddle_dbls(eval.domain); - // TODO(spapini): Handle small cases. + let twiddle_buffer = &itwiddles.itwiddles; + let twiddles = (0..eval.domain.half_coset.log_size()) + .map(|i| { + let len = 1 << i; + &twiddle_buffer[twiddle_buffer.len() - len * 2..twiddle_buffer.len() - len] + }) + .rev() + .collect_vec(); // Safe because [PackedBaseField] is aligned on 64 bytes. unsafe { ifft::ifft( std::mem::transmute(values.data.as_mut_ptr()), - &twiddle_dbls[1..] - .iter() - .map(|x| x.as_slice()) - .collect::>(), + &twiddles, log_size as usize, ); } @@ -98,7 +101,7 @@ impl PolyOps for AVX512Backend { fn evaluate( poly: &CirclePoly, domain: CircleDomain, - _twiddles: &TwiddleTree, + twiddles: &TwiddleTree, ) -> CircleEvaluation { // TODO(spapini): Precompute twiddles. // TODO(spapini): Handle small cases. @@ -109,17 +112,24 @@ impl PolyOps for AVX512Backend { "Can only evaluate on larger domains" ); - let twiddles = rfft::get_twiddle_dbls(domain); + let twiddle_buffer = &twiddles.twiddles; + let twiddles = (0..domain.half_coset.log_size()) + .map(|i| { + let len = 1 << i; + &twiddle_buffer[twiddle_buffer.len() - len * 2..twiddle_buffer.len() - len] + }) + .rev() + .collect_vec(); // Evaluate on a big domains by evaluating on several subdomains. let log_subdomains = log_size - fft_log_size; let mut values = Vec::with_capacity(domain.size() >> VECS_LOG_SIZE); for i in 0..(1 << log_subdomains) { // The subdomain twiddles are a slice of the large domain twiddles. - let subdomain_twiddles = (1..fft_log_size) + let subdomain_twiddles = (0..(fft_log_size - 1)) .map(|layer_i| { &twiddles[layer_i] - [i << (fft_log_size - 1 - layer_i)..(i + 1) << (fft_log_size - 1 - layer_i)] + [i << (fft_log_size - 2 - layer_i)..(i + 1) << (fft_log_size - 2 - layer_i)] }) .collect::>(); @@ -150,10 +160,25 @@ impl PolyOps for AVX512Backend { } fn precompute_twiddles(coset: Coset) -> TwiddleTree { + let mut twiddles = Vec::with_capacity(coset.size()); + let mut itwiddles = Vec::with_capacity(coset.size()); + + // Optimize. + for layer in &rfft::get_twiddle_dbls(coset)[1..] { + twiddles.extend(layer); + } + twiddles.push(2); + assert_eq!(twiddles.len(), coset.size()); + for layer in &ifft::get_itwiddle_dbls(coset)[1..] { + itwiddles.extend(layer); + } + itwiddles.push(2); + assert_eq!(itwiddles.len(), coset.size()); + TwiddleTree { root_coset: coset, - twiddles: (), - itwiddles: (), + twiddles, + itwiddles, } } } diff --git a/src/core/backend/avx512/fft/ifft.rs b/src/core/backend/avx512/fft/ifft.rs index f42b449bc..853efd170 100644 --- a/src/core/backend/avx512/fft/ifft.rs +++ b/src/core/backend/avx512/fft/ifft.rs @@ -8,8 +8,8 @@ use std::arch::x86_64::{ use super::{compute_first_twiddles, EVENS_INTERLEAVE_EVENS, ODDS_INTERLEAVE_ODDS}; use crate::core::backend::avx512::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE}; use crate::core::backend::avx512::{PackedBaseField, VECS_LOG_SIZE}; +use crate::core::circle::Coset; use crate::core::fields::FieldExpOps; -use crate::core::poly::circle::CircleDomain; use crate::core::utils::bit_reverse; /// Performs an Inverse Circle Fast Fourier Transform (ICFFT) on the given values. @@ -374,9 +374,7 @@ pub unsafe fn vecwise_ibutterflies( val0.deinterleave_with(val1) } -pub fn get_itwiddle_dbls(domain: CircleDomain) -> Vec> { - let mut coset = domain.half_coset; - +pub fn get_itwiddle_dbls(mut coset: Coset) -> Vec> { let mut res = vec![]; res.push( coset @@ -643,7 +641,7 @@ mod tests { #[test] fn test_vecwise_ibutterflies() { let domain = CanonicCoset::new(5).circle_domain(); - let twiddle_dbls = get_itwiddle_dbls(domain); + let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); assert_eq!(twiddle_dbls.len(), 5); let values0: [i32; 16] = std::array::from_fn(|i| i as i32); let values1: [i32; 16] = std::array::from_fn(|i| (i + 16) as i32); @@ -681,7 +679,7 @@ mod tests { // Compute. let mut values = BaseFieldVec::from_iter(values); - let twiddle_dbls = get_itwiddle_dbls(domain); + let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); unsafe { ifft_lower_with_vecwise( @@ -709,7 +707,7 @@ mod tests { // Compute. let mut values = BaseFieldVec::from_iter(values); - let twiddle_dbls = get_itwiddle_dbls(domain); + let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); unsafe { ifft( diff --git a/src/core/backend/avx512/fft/mod.rs b/src/core/backend/avx512/fft/mod.rs index 1ce28cd5b..f740d09a2 100644 --- a/src/core/backend/avx512/fft/mod.rs +++ b/src/core/backend/avx512/fft/mod.rs @@ -107,7 +107,7 @@ mod tests { #[test] fn test_twiddle_relation() { - let ts = get_itwiddle_dbls(CanonicCoset::new(5).circle_domain()); + let ts = get_itwiddle_dbls(CanonicCoset::new(5).half_coset()); let t0 = ts[0] .iter() .copied() diff --git a/src/core/backend/avx512/fft/rfft.rs b/src/core/backend/avx512/fft/rfft.rs index c3a5e3e56..65645e27f 100644 --- a/src/core/backend/avx512/fft/rfft.rs +++ b/src/core/backend/avx512/fft/rfft.rs @@ -8,7 +8,7 @@ use std::arch::x86_64::{ use super::{compute_first_twiddles, EVENS_INTERLEAVE_EVENS, ODDS_INTERLEAVE_ODDS}; use crate::core::backend::avx512::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE}; use crate::core::backend::avx512::{PackedBaseField, VECS_LOG_SIZE}; -use crate::core::poly::circle::CircleDomain; +use crate::core::circle::Coset; use crate::core::utils::bit_reverse; /// Performs a Circle Fast Fourier Transform (ICFFT) on the given values. @@ -347,9 +347,7 @@ pub unsafe fn vecwise_butterflies( val0.interleave_with(val1) } -pub fn get_twiddle_dbls(domain: CircleDomain) -> Vec> { - let mut coset = domain.half_coset; - +pub fn get_twiddle_dbls(mut coset: Coset) -> Vec> { let mut res = vec![]; res.push(coset.iter().map(|p| (p.y.0 * 2) as i32).collect::>()); bit_reverse(res.last_mut().unwrap()); @@ -606,7 +604,7 @@ mod tests { #[test] fn test_vecwise_butterflies() { let domain = CanonicCoset::new(5).circle_domain(); - let twiddle_dbls = get_twiddle_dbls(domain); + let twiddle_dbls = get_twiddle_dbls(domain.half_coset); assert_eq!(twiddle_dbls.len(), 5); let values0: [i32; 16] = std::array::from_fn(|i| i as i32); let values1: [i32; 16] = std::array::from_fn(|i| (i + 16) as i32); @@ -648,7 +646,7 @@ mod tests { // Compute. let mut values = BaseFieldVec::from_iter(values); - let twiddle_dbls = get_twiddle_dbls(domain); + let twiddle_dbls = get_twiddle_dbls(domain.half_coset); unsafe { fft_lower_with_vecwise( @@ -676,7 +674,7 @@ mod tests { // Compute. let mut values = BaseFieldVec::from_iter(values); - let twiddle_dbls = get_twiddle_dbls(domain); + let twiddle_dbls = get_twiddle_dbls(domain.half_coset); unsafe { transpose_vecs(