Skip to content

Commit

Permalink
fft without copying (#535)
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware authored Mar 25, 2024
1 parent bcb5732 commit 3846f82
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 71 deletions.
43 changes: 42 additions & 1 deletion benches/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

use criterion::Criterion;
use stwo::core::backend::avx512::fft::ifft::get_itwiddle_dbls;
use stwo::core::backend::avx512::fft::rfft::get_twiddle_dbls;
use stwo::core::backend::avx512::PackedBaseField;

pub fn avx512_ifft(c: &mut criterion::Criterion) {
use stwo::core::backend::avx512::fft::ifft;
Expand Down Expand Up @@ -37,8 +39,47 @@ pub fn avx512_ifft(c: &mut criterion::Criterion) {
});
}

/// Benchmarks `rfft::fft` from a coefficient buffer into a freshly allocated
/// destination buffer, on a circle domain of `2^20` points.
///
/// The benchmark is skipped (the function returns without registering it) when
/// AVX-512 is not detected on the running machine.
pub fn avx512_rfft(c: &mut criterion::Criterion) {
    use stwo::core::backend::avx512::fft::rfft;
    use stwo::core::backend::avx512::BaseFieldVec;
    use stwo::core::fields::m31::BaseField;
    use stwo::core::poly::circle::CanonicCoset;
    use stwo::platform;
    if !platform::avx512_detected() {
        return;
    }

    const LOG_SIZE: u32 = 20;
    let domain = CanonicCoset::new(LOG_SIZE).circle_domain();
    let values = (0..domain.size())
        .map(|i| BaseField::from_u32_unchecked(i as u32))
        .collect::<Vec<_>>();

    // Setup: pack the input and precompute the twiddles once, outside the timed loop.
    let values = BaseFieldVec::from_iter(values);
    let twiddle_dbls = get_twiddle_dbls(domain.half_coset);
    // The slice view over the twiddles is loop-invariant, so build it here rather
    // than inside `b.iter`, where its allocation would be measured on every iteration.
    let twiddle_slices = twiddle_dbls
        .iter()
        .map(|x| x.as_slice())
        .collect::<Vec<_>>();

    c.bench_function("avx rfft 20bit", |b| {
        // SAFETY (assumed, TODO confirm against `rfft::fft`'s contract): the source and
        // destination buffers both hold `values.data.len()` packed elements, and `fft`
        // is expected to write every destination element before any of them is read.
        b.iter(|| unsafe {
            // The destination is deliberately left uninitialized: it is only used as
            // the output buffer of the transform below.
            let mut target = Vec::<PackedBaseField>::with_capacity(values.data.len());
            // Allowed because the uninitialized contents are never read here.
            #[allow(clippy::uninit_vec)]
            target.set_len(values.data.len());

            rfft::fft(
                std::mem::transmute(values.data.as_ptr()),
                std::mem::transmute(target.as_mut_ptr()),
                &twiddle_slices,
                LOG_SIZE as usize,
            );
        })
    });
}

criterion::criterion_group!(
name=avx_ifft;
config = Criterion::default().sample_size(10);
targets=avx512_ifft);
targets=avx512_ifft, avx512_rfft);
criterion::criterion_main!(avx_ifft);
13 changes: 9 additions & 4 deletions src/core/backend/avx512/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,14 @@ impl PolyOps for AVX512Backend {

// Evaluate on a big domain by evaluating on several subdomains.
let log_subdomains = log_size - fft_log_size;

// Allocate the destination buffer without initializing it.
let mut values = Vec::with_capacity(domain.size() >> VECS_LOG_SIZE);
#[allow(clippy::uninit_vec)]
unsafe {
values.set_len(domain.size() >> VECS_LOG_SIZE)
};

for i in 0..(1 << log_subdomains) {
// The subdomain twiddles are a slice of the large domain twiddles.
let subdomain_twiddles = (0..(fft_log_size - 1))
Expand All @@ -223,12 +230,10 @@ impl PolyOps for AVX512Backend {
})
.collect::<Vec<_>>();

// Copy the coefficients of the polynomial to the values vector.
values.extend_from_slice(&poly.coeffs.data);

// FFT inplace on the values chunk.
// FFT from the coefficients buffer to the values chunk.
unsafe {
rfft::fft(
std::mem::transmute(poly.coeffs.data.as_ptr()),
std::mem::transmute(
values[i << (fft_log_size - VECS_LOG_SIZE)
..(i + 1) << (fft_log_size - VECS_LOG_SIZE)]
Expand Down
Loading

0 comments on commit 3846f82

Please sign in to comment.