From 63849ac751c5c2ce5dfd3bd3ec295eb88ca5ff25 Mon Sep 17 00:00:00 2001
From: Shahar Papini
Date: Wed, 21 Feb 2024 15:15:13 +0200
Subject: [PATCH] ifft_full

---
 Review notes. They sit below the three-dash marker, so they are commentary
 and do not become part of the commit message.

 * Reconstruction: this patch was recovered from a copy whose line breaks had
   been collapsed and whose angle-bracket text had been dropped. `&[Vec]` is
   restored as `&[Vec<i32>]` and `collect::>()` as `collect::<Vec<_>>()`;
   both are assumptions, confirm against the repository. Indentation and
   spacing of context lines are likewise reconstructed, so
   `git apply --ignore-whitespace` may be needed. The author's e-mail address
   is missing from the `From:` header.
 * `ifft` asserts `log_n_elements >= MIN_FFT_LOG_SIZE` (5), so the
   `log_n_elements <= 1` branch cannot be taken; per its TODO it is a
   placeholder for a CACHED_FFT_LOG_SIZE comparison. Every call therefore
   runs: `ifft_lower_with_vecwise` on the first `3 + log_n_fft_vecs1` twiddle
   layers, then `transpose_vecs`, then `ifft_lower_without_vecwise` on the
   remaining layers.
 * `ifft_lower_without_vecwise` reaches `todo!()` whenever `fft_layers % 3`
   is 1 or 2. From `ifft`, `fft_layers` is
   `(log_n_elements - VECS_LOG_SIZE) / 2`; the added test uses log size 11,
   which gives 3.
 * `transpose_vecs` swaps the 16-word vectors at index (a, b, c) and
   (c, b, a), skipping pairs with i >= j so each pair is swapped once. The
   test calls it again after `ifft` before comparing with `ref_ifft`, which
   suggests `ifft` leaves its output in transposed vector order (confirm).
 * `transpose_vecs` hard-codes `<< 4` and the test hard-codes `log_size - 4`;
   both equal VECS_LOG_SIZE (4 in mod.rs).
 * The `# Safety` sections are empty. From the code: `values` is read and
   written through `_mm512_load_epi32` / `_mm512_store_epi32` (aligned
   accesses) at offsets up to `2^log_n_vecs` vectors of 16 words, so callers
   must pass a 64-byte aligned pointer valid for that whole range.

 src/core/backend/avx512/fft.rs | 113 ++++++++++++++++++++++++++++++++-
 src/core/backend/avx512/mod.rs |   2 +
 2 files changed, 114 insertions(+), 1 deletion(-)

diff --git a/src/core/backend/avx512/fft.rs b/src/core/backend/avx512/fft.rs
index 1aedf9118..cb22063cf 100644
--- a/src/core/backend/avx512/fft.rs
+++ b/src/core/backend/avx512/fft.rs
@@ -5,7 +5,7 @@ use std::arch::x86_64::{
     _mm512_xor_epi32,
 };
 
-use crate::core::backend::avx512::VECS_LOG_SIZE;
+use crate::core::backend::avx512::{MIN_FFT_LOG_SIZE, VECS_LOG_SIZE};
 
 /// An input to _mm512_permutex2var_epi32, and is used to interleave the even words of a
 /// with the even words of b.
@@ -61,6 +61,53 @@ const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) };
 // TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce
 // it somewhere.
 
+/// # Safety
+pub unsafe fn ifft(values: *mut i32, twiddle_dbl: &[Vec<i32>], log_n_elements: usize) {
+    assert!(log_n_elements >= MIN_FFT_LOG_SIZE);
+    let log_n_vecs = log_n_elements - VECS_LOG_SIZE;
+    // TODO(spapini): Use CACHED_FFT_LOG_SIZE instead.
+    if log_n_elements <= 1 {
+        ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements);
+        return;
+    }
+    let log_n_fft_vecs0 = log_n_vecs / 2;
+    let log_n_fft_vecs1 = (log_n_vecs + 1) / 2;
+    ifft_lower_with_vecwise(
+        values,
+        &twiddle_dbl[..(3 + log_n_fft_vecs1)],
+        log_n_elements,
+        log_n_fft_vecs1 + VECS_LOG_SIZE,
+    );
+    transpose_vecs(values, log_n_vecs);
+    ifft_lower_without_vecwise(
+        values,
+        &twiddle_dbl[(3 + log_n_fft_vecs1)..],
+        log_n_elements,
+        log_n_fft_vecs0,
+    );
+}
+
+// TODO(spapini): This is inefficient. Optimize.
+/// # Safety
+pub unsafe fn transpose_vecs(values: *mut i32, log_n_vecs: usize) {
+    let half = log_n_vecs / 2;
+    for b in 0..(1 << (log_n_vecs & 1)) {
+        for a in 0..(1 << half) {
+            for c in 0..(1 << half) {
+                let i = (a << (log_n_vecs - half)) | (b << half) | c;
+                let j = (c << (log_n_vecs - half)) | (b << half) | a;
+                if i >= j {
+                    continue;
+                }
+                let val0 = _mm512_load_epi32(values.add(i << 4).cast_const());
+                let val1 = _mm512_load_epi32(values.add(j << 4).cast_const());
+                _mm512_store_epi32(values.add(i << 4), val1);
+                _mm512_store_epi32(values.add(j << 4), val0);
+            }
+        }
+    }
+}
+
 /// Computes partial ifft on `2^log_size` M31 elements.
 /// Parameters:
 ///   values - Pointer to the entire value array, aligned to 64 bytes.
@@ -102,6 +149,38 @@ pub unsafe fn ifft_lower_with_vecwise(
     }
 }
 
+/// # Safety
+pub unsafe fn ifft_lower_without_vecwise(
+    values: *mut i32,
+    twiddle_dbl: &[Vec<i32>],
+    log_size: usize,
+    fft_layers: usize,
+) {
+    assert!(log_size >= VECS_LOG_SIZE);
+
+    for index_h in 0..(1 << (log_size - fft_layers - VECS_LOG_SIZE)) {
+        for layer in (0..fft_layers).step_by(3) {
+            match fft_layers - layer {
+                1 => {
+                    todo!()
+                }
+                2 => {
+                    todo!()
+                }
+                _ => {
+                    ifft3_loop(
+                        values,
+                        &twiddle_dbl[layer..],
+                        fft_layers - layer - 3,
+                        layer + VECS_LOG_SIZE,
+                        index_h,
+                    );
+                }
+            }
+        }
+    }
+}
+
 /// Runs the first 5 ifft layers across the entire array.
 /// Parameters:
 ///   values - Pointer to the entire value array, aligned to 64 bytes.
@@ -853,4 +932,36 @@ mod tests {
             assert_eq!(values.to_vec(), expected_coeffs);
         }
     }
+
+    fn run_ifft_full_test(log_size: u32) {
+        let domain = CanonicCoset::new(log_size).circle_domain();
+        let values = (0..domain.size())
+            .map(|i| BaseField::from_u32_unchecked(i as u32))
+            .collect::<Vec<_>>();
+        let expected_coeffs = ref_ifft(domain, values.clone());
+
+        // Compute.
+        let mut values = BaseFieldVec::from_iter(values);
+        let twiddle_dbls = get_itwiddle_dbls(domain);
+
+        unsafe {
+            ifft(
+                std::mem::transmute(values.data.as_mut_ptr()),
+                &twiddle_dbls[1..],
+                log_size as usize,
+            );
+            transpose_vecs(
+                std::mem::transmute(values.data.as_mut_ptr()),
+                (log_size - 4) as usize,
+            );
+
+            // Compare.
+            assert_eq!(values.to_vec(), expected_coeffs);
+        }
+    }
+
+    #[test]
+    fn test_ifft_full_11() {
+        run_ifft_full_test(3 + 3 + 5);
+    }
 }
diff --git a/src/core/backend/avx512/mod.rs b/src/core/backend/avx512/mod.rs
index e7034ad0b..285eedbee 100644
--- a/src/core/backend/avx512/mod.rs
+++ b/src/core/backend/avx512/mod.rs
@@ -15,6 +15,8 @@ use crate::core::fields::{Column, FieldOps};
 use crate::core::utils;
 
 const VECS_LOG_SIZE: usize = 4;
+pub const CACHED_FFT_LOG_SIZE: usize = 16;
+pub const MIN_FFT_LOG_SIZE: usize = 5;
 
 #[derive(Copy, Clone, Debug)]
 pub struct AVX512Backend;