diff --git a/src/core/backend/avx512/fft.rs b/src/core/backend/avx512/fft.rs
index 02a07c623..f9aa0d767 100644
--- a/src/core/backend/avx512/fft.rs
+++ b/src/core/backend/avx512/fft.rs
@@ -48,26 +48,82 @@ const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) };
 // TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce
 // it somewhere.
+/// # Safety
+pub unsafe fn ifft(values: *mut i32, twiddle_dbl: &[Vec<i32>], log_n_elements: usize) {
+    assert!(log_n_elements >= 4);
+    if log_n_elements <= 1 { // 16 {
+        ifft_lower(
+            values,
+            Some(&twiddle_dbl[..3]),
+            &twiddle_dbl[3..],
+            log_n_elements - 4,
+            log_n_elements - 4,
+        );
+        return;
+    }
+    let log_n_vecs = log_n_elements - 4;
+    let log_n_fft_vecs0 = log_n_vecs / 2;
+    let log_n_fft_vecs1 = (log_n_vecs + 1) / 2;
+    ifft_lower(
+        values,
+        Some(&twiddle_dbl[..3]),
+        &twiddle_dbl[3..(3 + log_n_fft_vecs1)],
+        log_n_elements - 4,
+        log_n_fft_vecs1,
+    );
+    // TODO(spapini): better transpose.
+    transpose_vecs(values, log_n_elements - 4);
+    ifft_lower(
+        values,
+        None,
+        &twiddle_dbl[(3 + log_n_fft_vecs1)..],
+        log_n_elements - 4,
+        log_n_fft_vecs0,
+    );
+}
+
+// TODO(spapini): This is inefficient. Optimize.
+/// # Safety
+pub unsafe fn transpose_vecs(values: *mut i32, log_n_vecs: usize) {
+    let half = log_n_vecs / 2;
+    for b in 0..(1 << (log_n_vecs & 1)) {
+        for a in 0..(1 << half) {
+            for c in 0..(1 << half) {
+                let i = (a << (log_n_vecs - half)) | (b << half) | c;
+                let j = (c << (log_n_vecs - half)) | (b << half) | a;
+                if i >= j {
+                    continue;
+                }
+                let val0 = _mm512_load_epi32(values.add(i << 4).cast_const());
+                let val1 = _mm512_load_epi32(values.add(j << 4).cast_const());
+                _mm512_store_epi32(values.add(i << 4), val1);
+                _mm512_store_epi32(values.add(j << 4), val0);
+            }
+        }
+    }
+}
+
 /// # Safety
 pub unsafe fn ifft_lower(
     values: *mut i32,
     vecwise_twiddle_dbl: Option<&[Vec<i32>]>,
     twiddle_dbl: &[Vec<i32>],
-    n_total_bits: usize,
-    n_fft_bits: usize,
+    log_n_vecs: usize,
+    fft_bits: usize,
 ) {
-    assert!(n_fft_bits >= 1);
+    assert!(fft_bits >= 1);
     if let Some(vecwise_twiddle_dbl) = vecwise_twiddle_dbl {
-        assert_eq!(vecwise_twiddle_dbl[0].len(), 1 << (n_fft_bits + 2));
-        assert_eq!(vecwise_twiddle_dbl[1].len(), 1 << (n_fft_bits + 1));
-        assert_eq!(vecwise_twiddle_dbl[2].len(), 1 << n_fft_bits);
+        assert_eq!(vecwise_twiddle_dbl[0].len(), 1 << (log_n_vecs + 2));
+        assert_eq!(vecwise_twiddle_dbl[1].len(), 1 << (log_n_vecs + 1));
+        assert_eq!(vecwise_twiddle_dbl[2].len(), 1 << log_n_vecs);
     }
-    for h in 0..(1 << (n_total_bits - n_fft_bits)) {
+    for h in 0..(1 << (log_n_vecs - fft_bits)) {
         // TODO(spapini):
         if let Some(vecwise_twiddle_dbl) = vecwise_twiddle_dbl {
-            for l in 0..(1 << (n_fft_bits - 1)) {
+            for l in 0..(1 << (fft_bits - 1)) {
                 // TODO(spapini): modulo for twiddles on the iters.
-                let index = (h << (n_fft_bits - 1)) + l;
+                let index = (h << (fft_bits - 1)) + l;
                 let mut val0 = _mm512_load_epi32(values.add(index * 32).cast_const());
                 let mut val1 = _mm512_load_epi32(values.add(index * 32 + 16).cast_const());
                 (val0, val1) = vecwise_ibutterflies(
@@ -82,25 +138,31 @@ pub unsafe fn ifft_lower(
                 // TODO(spapini): do a fifth layer here.
             }
         }
-        for bit_i in (0..n_fft_bits).step_by(3) {
-            if bit_i + 3 > n_fft_bits {
+        for bit_i in (0..fft_bits).step_by(3) {
+            if bit_i + 3 > fft_bits {
                 todo!();
             }
-            for m in 0..(1 << (n_fft_bits - 3 - bit_i)) {
-                let twid_index = (h << (n_fft_bits - 3 - bit_i)) + m;
+            for m in 0..(1 << (fft_bits - 3 - bit_i)) {
+                let twid_index = (h << (fft_bits - 3 - bit_i)) + m;
                 for l in 0..(1 << bit_i) {
                     ifft3(
                         values,
-                        (h << n_fft_bits) + (m << (bit_i + 3)) + l,
+                        (h << fft_bits) + (m << (bit_i + 3)) + l,
                         bit_i,
                         std::array::from_fn(|i| {
-                            *twiddle_dbl[bit_i].get_unchecked(twid_index * 4 + i)
+                            *twiddle_dbl[bit_i].get_unchecked(
+                                (twid_index * 4 + i) & (twiddle_dbl[bit_i].len() - 1),
+                            )
                         }),
                         std::array::from_fn(|i| {
-                            *twiddle_dbl[bit_i + 1].get_unchecked(twid_index * 2 + i)
+                            *twiddle_dbl[bit_i + 1].get_unchecked(
+                                (twid_index * 2 + i) & (twiddle_dbl[bit_i + 1].len() - 1),
+                            )
                         }),
                         std::array::from_fn(|i| {
-                            *twiddle_dbl[bit_i + 2].get_unchecked(twid_index + i)
+                            *twiddle_dbl[bit_i + 2].get_unchecked(
+                                (twid_index + i) & (twiddle_dbl[bit_i + 2].len() - 1),
+                            )
                         }),
                     );
                 }
@@ -649,4 +711,38 @@ mod tests {
             }
         }
     }
+
+    fn run_ifft_full_test(log_size: u32) {
+        let domain = CanonicCoset::new(log_size).circle_domain();
+        let values = (0..domain.size())
+            .map(|i| BaseField::from_u32_unchecked(i as u32))
+            .collect::<Vec<_>>();
+        let expected_coeffs = ref_ifft(domain, values.clone());
+
+        // Compute.
+        let mut values = BaseFieldVec::from_vec(values);
+        let twiddle_dbls = get_itwiddle_dbls(domain);
+
+        unsafe {
+            ifft(
+                std::mem::transmute(values.data.as_mut_ptr()),
+                &twiddle_dbls[1..],
+                log_size as usize,
+            );
+            transpose_vecs(
+                std::mem::transmute(values.data.as_mut_ptr()),
+                (log_size - 4) as usize,
+            );
+
+            // Compare.
+            for i in 0..expected_coeffs.len() {
+                assert_eq!(values[i], expected_coeffs[i]);
+            }
+        }
+    }
+
+    #[test]
+    fn test_ifft_full_10() {
+        run_ifft_full_test(3 + 3 + 4);
+    }
 }