Skip to content

Commit

Permalink
ifft_full
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 4, 2024
1 parent 5615ce4 commit 63849ac
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 1 deletion.
113 changes: 112 additions & 1 deletion src/core/backend/avx512/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::arch::x86_64::{
_mm512_xor_epi32,
};

use crate::core::backend::avx512::VECS_LOG_SIZE;
use crate::core::backend::avx512::{MIN_FFT_LOG_SIZE, VECS_LOG_SIZE};

/// An input to _mm512_permutex2var_epi32, and is used to interleave the even words of a
/// with the even words of b.
Expand Down Expand Up @@ -61,6 +61,53 @@ const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) };
// TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce
// it somewhere.

/// # Safety
pub unsafe fn ifft(values: *mut i32, twiddle_dbl: &[Vec<i32>], log_n_elements: usize) {
assert!(log_n_elements >= MIN_FFT_LOG_SIZE);
let log_n_vecs = log_n_elements - VECS_LOG_SIZE;
// TODO(spapini): Use CACHED_FFT_LOG_SIZE instead.
if log_n_elements <= 1 {
ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements);
return;
}
let log_n_fft_vecs0 = log_n_vecs / 2;
let log_n_fft_vecs1 = (log_n_vecs + 1) / 2;
ifft_lower_with_vecwise(
values,
&twiddle_dbl[..(3 + log_n_fft_vecs1)],
log_n_elements,
log_n_fft_vecs1 + VECS_LOG_SIZE,
);
transpose_vecs(values, log_n_vecs);
ifft_lower_without_vecwise(
values,
&twiddle_dbl[(3 + log_n_fft_vecs1)..],
log_n_elements,
log_n_fft_vecs0,
);
}

// TODO(spapini): This is inefficient. Optimize.
/// # Safety
pub unsafe fn transpose_vecs(values: *mut i32, log_n_vecs: usize) {
let half = log_n_vecs / 2;
for b in 0..(1 << (log_n_vecs & 1)) {
for a in 0..(1 << half) {
for c in 0..(1 << half) {
let i = (a << (log_n_vecs - half)) | (b << half) | c;
let j = (c << (log_n_vecs - half)) | (b << half) | a;
if i >= j {
continue;
}
let val0 = _mm512_load_epi32(values.add(i << 4).cast_const());
let val1 = _mm512_load_epi32(values.add(j << 4).cast_const());
_mm512_store_epi32(values.add(i << 4), val1);
_mm512_store_epi32(values.add(j << 4), val0);
}
}
}
}

/// Computes partial ifft on `2^log_size` M31 elements.
/// Parameters:
/// values - Pointer to the entire value array, aligned to 64 bytes.
Expand Down Expand Up @@ -102,6 +149,38 @@ pub unsafe fn ifft_lower_with_vecwise(
}
}

/// # Safety
pub unsafe fn ifft_lower_without_vecwise(
values: *mut i32,
twiddle_dbl: &[Vec<i32>],
log_size: usize,
fft_layers: usize,
) {
assert!(log_size >= VECS_LOG_SIZE);

for index_h in 0..(1 << (log_size - fft_layers - VECS_LOG_SIZE)) {
for layer in (0..fft_layers).step_by(3) {
match fft_layers - layer {
1 => {
todo!()
}
2 => {
todo!()
}
_ => {
ifft3_loop(
values,
&twiddle_dbl[layer..],
fft_layers - layer - 3,
layer + VECS_LOG_SIZE,
index_h,
);
}
}
}
}
}

/// Runs the first 5 ifft layers across the entire array.
/// Parameters:
/// values - Pointer to the entire value array, aligned to 64 bytes.
Expand Down Expand Up @@ -853,4 +932,36 @@ mod tests {
assert_eq!(values.to_vec(), expected_coeffs);
}
}

fn run_ifft_full_test(log_size: u32) {
let domain = CanonicCoset::new(log_size).circle_domain();
let values = (0..domain.size())
.map(|i| BaseField::from_u32_unchecked(i as u32))
.collect::<Vec<_>>();
let expected_coeffs = ref_ifft(domain, values.clone());

// Compute.
let mut values = BaseFieldVec::from_iter(values);
let twiddle_dbls = get_itwiddle_dbls(domain);

unsafe {
ifft(
std::mem::transmute(values.data.as_mut_ptr()),
&twiddle_dbls[1..],
log_size as usize,
);
transpose_vecs(
std::mem::transmute(values.data.as_mut_ptr()),
(log_size - 4) as usize,
);

// Compare.
assert_eq!(values.to_vec(), expected_coeffs);
}
}

#[test]
fn test_ifft_full_11() {
run_ifft_full_test(3 + 3 + 5);
}
}
2 changes: 2 additions & 0 deletions src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ use crate::core::fields::{Column, FieldOps};
use crate::core::utils;

const VECS_LOG_SIZE: usize = 4;
pub const CACHED_FFT_LOG_SIZE: usize = 16;
pub const MIN_FFT_LOG_SIZE: usize = 5;

#[derive(Copy, Clone, Debug)]
pub struct AVX512Backend;
Expand Down

0 comments on commit 63849ac

Please sign in to comment.