Skip to content

Commit

Permalink
AVX fft benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 4, 2024
1 parent 01b161c commit 8590a89
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 27 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ lto = true
name = "bit_rev"
harness = false

[[bench]]
name = "fft"
harness = false

[[bench]]
harness = false
name = "field"
Expand Down
45 changes: 45 additions & 0 deletions benches/fft.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#![feature(iter_array_chunks)]

use criterion::Criterion;

/// Criterion benchmark for the AVX512 inverse FFT on a column of `domain.size()` elements,
/// where `domain` is the circle domain of the canonic coset with log size `LOG_SIZE`.
///
/// The benchmark is skipped (nothing is registered) when AVX512 is not detected at runtime.
/// The transform runs in place on the same buffer on every iteration, so only the first
/// iteration sees the initial `0, 1, 2, ...` input.
pub fn avx512_ifft(c: &mut criterion::Criterion) {
    use stwo::core::backend::avx512::fft::ifft;
    use stwo::core::backend::avx512::BaseFieldVec;
    use stwo::core::fields::m31::BaseField;
    use stwo::core::poly::circle::CanonicCoset;
    use stwo::platform;

    const LOG_SIZE: u32 = 28;

    // Nothing to measure on machines without AVX512 support.
    if !platform::avx512_detected() {
        return;
    }

    let domain = CanonicCoset::new(LOG_SIZE).circle_domain();

    // Input column: the i-th element is the field element built from `i`.
    let elements: Vec<_> = (0..domain.size())
        .map(|i| BaseField::from_u32_unchecked(i as u32))
        .collect();
    let mut column = BaseFieldVec::from_iter(elements);

    // Placeholder twiddles: one layer per `log_n` from `LOG_SIZE - 2` down to `0`, where the
    // layer for `log_n` is simply `0..2^log_n`. Only the layer sizes are meaningful here.
    // TODO(spapini): When batch inverse is implemented, replace with real twiddles.
    // let twiddle_dbls = get_itwiddle_dbls(domain);
    let placeholder_twiddles: Vec<_> = (0..(LOG_SIZE as i32 - 1))
        .rev()
        .map(|log_n| (0..(1 << log_n)).collect::<Vec<_>>())
        .collect();

    c.bench_function("avx ifft", |bencher| {
        // NOTE(review): the raw pointer into `column.data` is reinterpreted via `transmute`
        // into whatever pointer type `ifft` expects. Presumably `ifft` requires the buffer
        // to hold 2^LOG_SIZE elements and the twiddle layers to have the sizes built above;
        // confirm against `ifft`'s safety contract.
        bencher.iter(|| unsafe {
            ifft(
                std::mem::transmute(column.data.as_mut_ptr()),
                &placeholder_twiddles[..],
                LOG_SIZE as usize,
            );
        })
    });
}

// Registers `avx512_ifft` under the `avx_ifft` group, using 10 samples per benchmark,
// and generates the benchmark binary's entry point for that group.
criterion::criterion_group!(
    name = avx_ifft;
    config = Criterion::default().sample_size(10);
    targets = avx512_ifft
);
criterion::criterion_main!(avx_ifft);
57 changes: 30 additions & 27 deletions src/core/backend/avx512/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ use std::arch::x86_64::{
};

use crate::core::backend::avx512::{MIN_FFT_LOG_SIZE, VECS_LOG_SIZE};
use crate::core::fields::Field;
use crate::core::poly::circle::CircleDomain;
use crate::core::utils::bit_reverse;

/// An input to `_mm512_permutex2var_epi32`, used to interleave the even words of `a`
/// with the even words of `b`.
Expand Down Expand Up @@ -553,6 +556,32 @@ unsafe fn compute_first_twiddles(twiddle1_dbl: [i32; 8]) -> (__m512i, __m512i) {
(t0, t1)
}

/// Builds the layers of doubled inverse twiddles for `domain`.
///
/// The returned vector contains, in order:
/// 1. One layer with `2 * y^-1` for every point of `domain.half_coset`.
/// 2. One layer per repeated doubling of that coset (`log_size` layers in total), each with
///    `2 * x^-1` for the first half of the current coset's points.
///
/// Every layer is stored in bit-reversed order, and each doubled value is reinterpreted
/// as an `i32`.
pub fn get_itwiddle_dbls(domain: CircleDomain) -> Vec<Vec<i32>> {
    let mut coset = domain.half_coset;
    // Fixed up front: the coset shrinks as it is doubled inside the loop below.
    let n_x_layers = coset.log_size();

    // First layer: doubled inverses of the y coordinates over the whole half coset.
    let mut y_layer = coset
        .iter()
        .map(|p| (p.y.inverse().0 * 2) as i32)
        .collect::<Vec<_>>();
    bit_reverse(&mut y_layer);
    let mut layers = vec![y_layer];

    // Remaining layers: doubled inverses of the x coordinates over the first half of each
    // successively doubled coset.
    for _ in 0..n_x_layers {
        let half_size = coset.size() / 2;
        let mut x_layer = coset
            .iter()
            .take(half_size)
            .map(|p| (p.x.inverse().0 * 2) as i32)
            .collect::<Vec<_>>();
        bit_reverse(&mut x_layer);
        layers.push(x_layer);
        coset = coset.double();
    }

    layers
}

/// Applies 3 butterfly layers on 8 vectors of 16 M31 elements.
/// Vectorized over the 16 elements of the vectors.
/// Used for radix-8 ifft.
Expand Down Expand Up @@ -652,7 +681,7 @@ mod tests {
use crate::core::backend::cpu::{CPUCircleEvaluation, CPUCirclePoly};
use crate::core::fft::{butterfly, ibutterfly};
use crate::core::fields::m31::BaseField;
use crate::core::fields::{Column, Field};
use crate::core::fields::Column;
use crate::core::poly::circle::{CanonicCoset, CircleDomain};
use crate::core::utils::bit_reverse;

Expand Down Expand Up @@ -834,32 +863,6 @@ mod tests {
res
}

/// Test helper: computes the bit-reversed layers of doubled inverse twiddles for `domain`.
///
/// The first layer holds `2 * y^-1` for all points of `domain.half_coset`; it is followed by
/// one layer per doubling of the coset, holding `2 * x^-1` for the first half of its points.
fn get_itwiddle_dbls(domain: CircleDomain) -> Vec<Vec<i32>> {
    let mut coset = domain.half_coset;
    let mut res = Vec::new();

    // y-coordinate layer over the full half coset.
    let mut y_dbls: Vec<i32> = coset
        .iter()
        .map(|point| (point.y.inverse().0 * 2) as i32)
        .collect();
    bit_reverse(&mut y_dbls);
    res.push(y_dbls);

    // x-coordinate layers; the range bound is taken from the coset before any doubling.
    for _ in 0..coset.log_size() {
        let mut x_dbls: Vec<i32> = coset
            .iter()
            .take(coset.size() / 2)
            .map(|point| (point.x.inverse().0 * 2) as i32)
            .collect();
        bit_reverse(&mut x_dbls);
        res.push(x_dbls);
        coset = coset.double();
    }

    res
}

#[test]
fn test_twiddle_relation() {
let ts = get_itwiddle_dbls(CanonicCoset::new(5).circle_domain());
Expand Down

0 comments on commit 8590a89

Please sign in to comment.