From 06dd0af68bcdce847cdbe540cd730255da9adbbc Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Wed, 21 Feb 2024 09:14:50 +0200 Subject: [PATCH] ifft_lower --- src/core/backend/avx512/fft.rs | 225 ++++++++++++++++++++++++++++----- src/core/fields/m31.rs | 2 +- src/core/utils.rs | 3 + 3 files changed, 199 insertions(+), 31 deletions(-) diff --git a/src/core/backend/avx512/fft.rs b/src/core/backend/avx512/fft.rs index f4890990f..2dce8c14b 100644 --- a/src/core/backend/avx512/fft.rs +++ b/src/core/backend/avx512/fft.rs @@ -55,6 +55,72 @@ const HHALF_INTERLEAVE_HHALF: __m512i = unsafe { }; const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) }; +// TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce +// it somewhere. + +/// # Safety +pub unsafe fn ifft_lower( + values: *mut i32, + vecwise_twiddle_dbl: Option<&[Vec]>, + twiddle_dbl: &[Vec], + n_total_bits: usize, + n_fft_bits: usize, +) { + assert!(n_fft_bits >= 1); + if let Some(vecwise_twiddle_dbl) = vecwise_twiddle_dbl { + assert_eq!(vecwise_twiddle_dbl[0].len(), 1 << (n_fft_bits + 3)); + assert_eq!(vecwise_twiddle_dbl[1].len(), 1 << (n_fft_bits + 2)); + assert_eq!(vecwise_twiddle_dbl[2].len(), 1 << (n_fft_bits + 1)); + assert_eq!(vecwise_twiddle_dbl[3].len(), 1 << n_fft_bits); + } + for h in 0..(1 << (n_total_bits - n_fft_bits)) { + // TODO(spapini): + if let Some(vecwise_twiddle_dbl) = vecwise_twiddle_dbl { + for l in 0..(1 << (n_fft_bits - 1)) { + // TODO(spapini): modulo for twiddles on the iters. + let index = (h << (n_fft_bits - 1)) + l; + let mut val0 = _mm512_load_epi32(values.add(index * 32).cast_const()); + let mut val1 = _mm512_load_epi32(values.add(index * 32 + 16).cast_const()); + (val0, val1) = vecwise_ibutterflies( + val0, + val1, + std::array::from_fn(|i| *vecwise_twiddle_dbl[0].get_unchecked(index * 16 + i)), + std::array::from_fn(|i| *vecwise_twiddle_dbl[1].get_unchecked(index * 8 + i)), + std::array::from_fn(|i| *vecwise_twiddle_dbl[2].get_unchecked(index * 4 + i)), + std::array::from_fn(|i| *vecwise_twiddle_dbl[3].get_unchecked(index * 2 + i)), + ); + _mm512_store_epi32(values.add(index * 32), val0); + _mm512_store_epi32(values.add(index * 32 + 16), val1); + // TODO(spapini): do a fifth layer here. + } + } + for bit_i in (0..n_fft_bits).step_by(3) { + if bit_i + 3 > n_fft_bits { + todo!(); + } + for m in 0..(1 << (n_fft_bits - 3 - bit_i)) { + let twid_index = (h << (n_fft_bits - 3 - bit_i)) + m; + for l in 0..(1 << bit_i) { + ifft3( + values, + (h << n_fft_bits) + (m << (bit_i + 3)) + l, + bit_i, + std::array::from_fn(|i| { + *twiddle_dbl[bit_i].get_unchecked(twid_index * 4 + i) + }), + std::array::from_fn(|i| { + *twiddle_dbl[bit_i + 1].get_unchecked(twid_index * 2 + i) + }), + std::array::from_fn(|i| { + *twiddle_dbl[bit_i + 2].get_unchecked(twid_index + i) + }), + ); + } + } + } + } +} + /// Computes the butterfly operation for packed M31 elements. /// val0 + t val1, val0 - t val1. /// val0, val1 are packed M31 elements. 16 M31 words at each. @@ -121,7 +187,7 @@ pub unsafe fn avx_ibutterfly( let r0 = add_mod_p(val0, val1); let r1 = sub_mod_p(val0, val1); - // Extract the even and odd parts of r1 and twiddle_dbl, and spread as 8 64bit values. + // Extract the even and odd parts of r1 and twiddle_m_e_dbldbl, and spread as 8 64bit values. let r1_e = r1; let r1_o = _mm512_srli_epi64(r1, 32); let twiddle_dbl_e = twiddle_dbl; @@ -302,8 +368,8 @@ pub unsafe fn vecwise_ibutterflies( /// Parameters: /// values - Pointer to the entire value array. /// offset - The offset of the first value in the array. -/// step_in_vecs - The distance in the array, in AVX vectors, between each pair of values that -/// need to be transformed. For layer i this is i-4. +/// log_step - The log of the distance in the array, in AVX vectors, between each pair of +/// values that need to be transformed. For layer i this is i - 4. /// twiddles_dbl0/1/2 - The double of the twiddles for the 3 layers of butterflies. /// Each layer has 4/2/1 twiddles. /// @@ -311,21 +377,20 @@ pub unsafe fn vecwise_ibutterflies( pub unsafe fn ifft3( values: *mut i32, offset: usize, - step_in_vecs: usize, - twiddles_dbl0: &[i32; 4], - twiddles_dbl1: &[i32; 2], - twiddles_dbl2: &[i32; 1], + log_step: usize, + twiddles_dbl0: [i32; 4], + twiddles_dbl1: [i32; 2], + twiddles_dbl2: [i32; 1], ) { - let step_in_u32s = step_in_vecs + 4; // Load the 8 AVX vectors from the array. - let mut val0 = _mm512_load_epi32(values.add(offset + (0 << step_in_u32s)).cast_const()); - let mut val1 = _mm512_load_epi32(values.add(offset + (1 << step_in_u32s)).cast_const()); - let mut val2 = _mm512_load_epi32(values.add(offset + (2 << step_in_u32s)).cast_const()); - let mut val3 = _mm512_load_epi32(values.add(offset + (3 << step_in_u32s)).cast_const()); - let mut val4 = _mm512_load_epi32(values.add(offset + (4 << step_in_u32s)).cast_const()); - let mut val5 = _mm512_load_epi32(values.add(offset + (5 << step_in_u32s)).cast_const()); - let mut val6 = _mm512_load_epi32(values.add(offset + (6 << step_in_u32s)).cast_const()); - let mut val7 = _mm512_load_epi32(values.add(offset + (7 << step_in_u32s)).cast_const()); + let mut val0 = _mm512_load_epi32(values.add((offset + (0 << log_step)) << 4).cast_const()); + let mut val1 = _mm512_load_epi32(values.add((offset + (1 << log_step)) << 4).cast_const()); + let mut val2 = _mm512_load_epi32(values.add((offset + (2 << log_step)) << 4).cast_const()); + let mut val3 = _mm512_load_epi32(values.add((offset + (3 << log_step)) << 4).cast_const()); + let mut val4 = _mm512_load_epi32(values.add((offset + (4 << log_step)) << 4).cast_const()); + let mut val5 = _mm512_load_epi32(values.add((offset + (5 << log_step)) << 4).cast_const()); + let mut val6 = _mm512_load_epi32(values.add((offset + (6 << log_step)) << 4).cast_const()); + let mut val7 = _mm512_load_epi32(values.add((offset + (7 << log_step)) << 4).cast_const()); // Apply the first layer of butterflies. (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); @@ -346,14 +411,14 @@ pub unsafe fn ifft3( (val3, val7) = avx_ibutterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0])); // Store the 8 AVX vectors back to the array. - _mm512_store_epi32(values.add(offset + (0 << step_in_u32s)), val0); - _mm512_store_epi32(values.add(offset + (1 << step_in_u32s)), val1); - _mm512_store_epi32(values.add(offset + (2 << step_in_u32s)), val2); - _mm512_store_epi32(values.add(offset + (3 << step_in_u32s)), val3); - _mm512_store_epi32(values.add(offset + (4 << step_in_u32s)), val4); - _mm512_store_epi32(values.add(offset + (5 << step_in_u32s)), val5); - _mm512_store_epi32(values.add(offset + (6 << step_in_u32s)), val6); - _mm512_store_epi32(values.add(offset + (7 << step_in_u32s)), val7); + _mm512_store_epi32(values.add((offset + (0 << log_step)) << 4), val0); + _mm512_store_epi32(values.add((offset + (1 << log_step)) << 4), val1); + _mm512_store_epi32(values.add((offset + (2 << log_step)) << 4), val2); + _mm512_store_epi32(values.add((offset + (3 << log_step)) << 4), val3); + _mm512_store_epi32(values.add((offset + (4 << log_step)) << 4), val4); + _mm512_store_epi32(values.add((offset + (5 << log_step)) << 4), val5); + _mm512_store_epi32(values.add((offset + (6 << log_step)) << 4), val6); + _mm512_store_epi32(values.add((offset + (7 << log_step)) << 4), val7); } // TODO(spapini): Move these to M31 AVX. @@ -390,8 +455,13 @@ mod tests { use super::*; use crate::core::backend::avx512::m31::PackedBaseField; + use crate::core::backend::avx512::BaseFieldVec; + use crate::core::backend::CPUBackend; use crate::core::fft::{butterfly, ibutterfly}; use crate::core::fields::m31::BaseField; + use crate::core::fields::{Column, Field}; + use crate::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation}; + use crate::core::utils::bit_reverse; #[test] fn test_butterfly() { @@ -426,12 +496,12 @@ mod tests { #[test] fn test_ibutterfly() { unsafe { - let val0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + let val0 = _mm512_setr_epi32(2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); let val1 = _mm512_setr_epi32( - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 3, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, ); let twiddle = _mm512_setr_epi32( - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 1177558791, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, ); let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle); let (r0, r1) = avx_ibutterfly(val0, val1, twiddle_dbl); @@ -642,9 +712,9 @@ mod tests { std::mem::transmute(values.as_mut_ptr()), 0, 0, - &twiddles0_dbl, - &twiddles1_dbl, - &twiddles2_dbl, + twiddles0_dbl, + twiddles1_dbl, + twiddles2_dbl, ); let expected: [u32; 8] = std::array::from_fn(|i| i as u32); @@ -684,4 +754,99 @@ mod tests { } } } + + fn get_itwiddle_dbls(domain: CircleDomain) -> Vec> { + let mut coset = domain.half_coset; + + let mut res = vec![]; + res.push( + coset + .iter() + .map(|p| (p.y.inverse().0 * 2) as i32) + .collect::>(), + ); + bit_reverse(res.last_mut().unwrap()); + for _ in 0..coset.log_size() { + res.push( + coset + .iter() + .take(coset.size() / 2) + .map(|p| (p.x.inverse().0 * 2) as i32) + .collect::>(), + ); + bit_reverse(res.last_mut().unwrap()); + coset = coset.double(); + } + + res + } + + fn ref_ifft(domain: CircleDomain, mut values: Vec) -> Vec { + bit_reverse(&mut values); + let eval = CircleEvaluation::::new(domain, values); + let mut expected_coeffs = eval.interpolate().coeffs; + for x in expected_coeffs.iter_mut() { + *x *= BaseField::from_u32_unchecked(domain.size() as u32); + } + bit_reverse(&mut expected_coeffs); + expected_coeffs + } + + #[test] + fn test_vecwise_ibutterflies_real() { + let domain = CanonicCoset::new(5).circle_domain(); + let twiddle_dbls = get_itwiddle_dbls(domain); + assert_eq!(twiddle_dbls.len(), 5); + let values0: [i32; 16] = std::array::from_fn(|i| i as i32); + let values1: [i32; 16] = std::array::from_fn(|i| (i + 16) as i32); + let result: [BaseField; 32] = unsafe { + let (val0, val1) = vecwise_ibutterflies( + std::mem::transmute(values0), + std::mem::transmute(values1), + twiddle_dbls[0].clone().try_into().unwrap(), + twiddle_dbls[1].clone().try_into().unwrap(), + twiddle_dbls[2].clone().try_into().unwrap(), + twiddle_dbls[3].clone().try_into().unwrap(), + ); + let (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddle_dbls[4][0])); + std::mem::transmute([val0, val1]) + }; + + // ref. + let mut values = values0.to_vec(); + values.extend_from_slice(&values1); + let expected = ref_ifft(domain, values.into_iter().map(BaseField::from).collect()); + + // Compare. + for i in 0..32 { + assert_eq!(result[i], expected[i]); + } + } + + #[test] + fn test_ifft_lower() { + let log_size = 4 + 3 + 3; + let domain = CanonicCoset::new(log_size).circle_domain(); + let values = (0..domain.size()) + .map(|i| BaseField::from_u32_unchecked(i as u32)) + .collect::>(); + let expected_coeffs = ref_ifft(domain, values.clone()); + + // Compute. + let mut values = BaseFieldVec::from_iter(values); + let twiddle_dbls = get_itwiddle_dbls(domain); + + unsafe { + ifft_lower( + std::mem::transmute(values.data.as_mut_ptr()), + Some(&twiddle_dbls[..4]), + &twiddle_dbls[4..], + (log_size - 4) as usize, + (log_size - 4) as usize, + ); + + // Compare. + assert_eq!(values.to_vec(), expected_coeffs); + } + } } diff --git a/src/core/fields/m31.rs b/src/core/fields/m31.rs index 3f34afa73..b6b49b3d8 100644 --- a/src/core/fields/m31.rs +++ b/src/core/fields/m31.rs @@ -14,7 +14,7 @@ pub const P: u32 = 2147483647; // 2 ** 31 - 1 #[repr(transparent)] #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Pod, Zeroable)] -pub struct M31(u32); +pub struct M31(pub u32); pub type BaseField = M31; impl_field!(M31, P); diff --git a/src/core/utils.rs b/src/core/utils.rs index c781e1ae3..8a6483e36 100644 --- a/src/core/utils.rs +++ b/src/core/utils.rs @@ -10,6 +10,9 @@ pub trait IteratorMutExt<'a, T: 'a>: Iterator { impl<'a, T: 'a, I: Iterator> IteratorMutExt<'a, T> for I {} pub(crate) fn bit_reverse_index(i: usize, log_size: u32) -> usize { + if log_size == 0 { + return i; + } i.reverse_bits() >> (usize::BITS - log_size) }