From 0a09f271bc7d78e08595e96c9580dc44d0ca4481 Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Wed, 21 Feb 2024 16:10:15 +0200 Subject: [PATCH] AVX fft benchmark --- Cargo.toml | 4 +++ benches/fft.rs | 47 ++++++++++++++++++++++++++++++++++ src/core/backend/avx512/fft.rs | 30 ++++++++++++++++++++++ 3 files changed, 81 insertions(+) create mode 100644 benches/fft.rs diff --git a/Cargo.toml b/Cargo.toml index 010d6c658..38612a10d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,10 @@ lto = true name = "bit_rev" harness = false +[[bench]] +name = "fft" +harness = false + [[bench]] harness = false name = "field" diff --git a/benches/fft.rs b/benches/fft.rs new file mode 100644 index 000000000..0fb0bdfc3 --- /dev/null +++ b/benches/fft.rs @@ -0,0 +1,47 @@ +#![feature(iter_array_chunks)] + +use criterion::Criterion; + +#[cfg(target_arch = "x86_64")] +pub fn avx512_ifft(c: &mut criterion::Criterion) { + use stwo::core::backend::avx512::fft::ifft; + use stwo::core::backend::avx512::BaseFieldVec; + use stwo::core::fields::m31::BaseField; + use stwo::core::poly::circle::CanonicCoset; + use stwo::platform; + if !platform::avx512_detected() { + return; + } + + const LOG_SIZE: u32 = 28; + let domain = CanonicCoset::new(LOG_SIZE).circle_domain(); + let values = (0..domain.size()) + .map(|i| BaseField::from_u32_unchecked(i as u32)) + .collect::>(); + + // Compute. + let mut values = BaseFieldVec::from_iter(values); + let twiddle_dbls = (0..(LOG_SIZE as i32 - 1)) + .map(|log_n| (0..(1 << log_n)).collect::>()) + .rev() + .collect::>(); + // TODO(spapini): When batch inverse is implemented, replace with real twiddles. + // let twiddle_dbls = get_itwiddle_dbls(domain); + + c.bench_function("avx ifft", |b| { + b.iter(|| unsafe { + ifft( + std::mem::transmute(values.data.as_mut_ptr()), + &twiddle_dbls[..], + LOG_SIZE as usize, + ); + }) + }); +} + +#[cfg(target_arch = "x86_64")] +criterion::criterion_group!( + name=avx_ifft; + config = Criterion::default().sample_size(10); + targets=avx512_ifft); +criterion::criterion_main!(avx_ifft); diff --git a/src/core/backend/avx512/fft.rs b/src/core/backend/avx512/fft.rs index cb22063cf..f7f18726c 100644 --- a/src/core/backend/avx512/fft.rs +++ b/src/core/backend/avx512/fft.rs @@ -6,8 +6,12 @@ use std::arch::x86_64::{ }; use crate::core::backend::avx512::{MIN_FFT_LOG_SIZE, VECS_LOG_SIZE}; +use crate::core::fields::Field; +use crate::core::poly::circle::CircleDomain; +use crate::core::utils::bit_reverse; /// An input to _mm512_permutex2var_epi32, and is used to interleave the even words of a +/// L is an input to _mm512_permutex2var_epi32, and is used to interleave the even words of a /// with the even words of b. const EVENS_INTERLEAVE_EVENS: __m512i = unsafe { core::mem::transmute([ @@ -530,6 +534,32 @@ pub unsafe fn vecwise_ibutterflies( ) } +pub fn get_itwiddle_dbls(domain: CircleDomain) -> Vec> { + let mut coset = domain.half_coset; + + let mut res = vec![]; + res.push( + coset + .iter() + .map(|p| (p.y.inverse().0 * 2) as i32) + .collect::>(), + ); + bit_reverse(res.last_mut().unwrap()); + for _ in 0..coset.log_size() { + res.push( + coset + .iter() + .take(coset.size() / 2) + .map(|p| (p.x.inverse().0 * 2) as i32) + .collect::>(), + ); + bit_reverse(res.last_mut().unwrap()); + coset = coset.double(); + } + + res +} + /// Applies 3 butterfly layers on 8 vectors of 16 M31 elements. /// Vectorized over the 16 elements of the vectors. /// Used for radix-8 ifft.