diff --git a/src/core/backend/avx512/fft.rs b/src/core/backend/avx512/fft.rs
index ba52b5e69..194a2bb0f 100644
--- a/src/core/backend/avx512/fft.rs
+++ b/src/core/backend/avx512/fft.rs
@@ -1,6 +1,6 @@
 use std::arch::x86_64::{
     __m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_load_epi32,
-    _mm512_min_epu32, _mm512_mul_epi32, _mm512_permutex2var_epi32, _mm512_set1_epi32,
+    _mm512_min_epu32, _mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_set1_epi32,
     _mm512_set1_epi64, _mm512_srli_epi64, _mm512_store_epi32, _mm512_sub_epi32,
 };
@@ -44,6 +44,72 @@ const H2: __m512i = unsafe {
 };
 const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) };
 
+// TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce
+// it somewhere.
+
+/// # Safety
+pub unsafe fn ifft_lower(
+    values: *mut i32,
+    vecwise_twiddle_dbl: Option<&[Vec<i32>]>,
+    twiddle_dbl: &[Vec<i32>],
+    n_total_bits: usize,
+    n_fft_bits: usize,
+) {
+    assert!(n_fft_bits >= 1);
+    if let Some(vecwise_twiddle_dbl) = vecwise_twiddle_dbl {
+        assert_eq!(vecwise_twiddle_dbl[0].len(), 1 << (n_fft_bits + 3));
+        assert_eq!(vecwise_twiddle_dbl[1].len(), 1 << (n_fft_bits + 2));
+        assert_eq!(vecwise_twiddle_dbl[2].len(), 1 << (n_fft_bits + 1));
+        assert_eq!(vecwise_twiddle_dbl[3].len(), 1 << n_fft_bits);
+    }
+    for h in 0..(1 << (n_total_bits - n_fft_bits)) {
+        // TODO(spapini):
+        if let Some(vecwise_twiddle_dbl) = vecwise_twiddle_dbl {
+            for l in 0..(1 << (n_fft_bits - 1)) {
+                // TODO(spapini): modulo for twiddles on the iters.
+ let index = (h << (n_fft_bits - 1)) + l; + let mut val0 = _mm512_load_epi32(values.add(index * 32).cast_const()); + let mut val1 = _mm512_load_epi32(values.add(index * 32 + 16).cast_const()); + (val0, val1) = vecwise_ibutterflies( + val0, + val1, + std::array::from_fn(|i| *vecwise_twiddle_dbl[0].get_unchecked(index * 16 + i)), + std::array::from_fn(|i| *vecwise_twiddle_dbl[1].get_unchecked(index * 8 + i)), + std::array::from_fn(|i| *vecwise_twiddle_dbl[2].get_unchecked(index * 4 + i)), + std::array::from_fn(|i| *vecwise_twiddle_dbl[3].get_unchecked(index * 2 + i)), + ); + _mm512_store_epi32(values.add(index * 32), val0); + _mm512_store_epi32(values.add(index * 32 + 16), val1); + // TODO(spapini): do a fifth layer here. + } + } + for bit_i in (0..n_fft_bits).step_by(3) { + if bit_i + 3 > n_fft_bits { + todo!(); + } + for m in 0..(1 << (n_fft_bits - 3 - bit_i)) { + let twid_index = (h << (n_fft_bits - 3 - bit_i)) + m; + for l in 0..(1 << bit_i) { + ifft3( + values, + (h << n_fft_bits) + (m << (bit_i + 3)) + l, + bit_i, + std::array::from_fn(|i| { + *twiddle_dbl[bit_i].get_unchecked(twid_index * 4 + i) + }), + std::array::from_fn(|i| { + *twiddle_dbl[bit_i + 1].get_unchecked(twid_index * 2 + i) + }), + std::array::from_fn(|i| { + *twiddle_dbl[bit_i + 2].get_unchecked(twid_index + i) + }), + ); + } + } + } + } +} + /// # Safety pub unsafe fn avx_butterfly( val0: __m512i, @@ -54,8 +120,8 @@ pub unsafe fn avx_butterfly( let twiddle_dbl_e = twiddle_dbl; let val1_o = _mm512_srli_epi64(val1, 32); let twiddle_dbl_o = _mm512_srli_epi64(twiddle_dbl, 32); - let m_e_dbl = _mm512_mul_epi32(val1_e, twiddle_dbl_e); - let m_o_dbl = _mm512_mul_epi32(val1_o, twiddle_dbl_o); + let m_e_dbl = _mm512_mul_epu32(val1_e, twiddle_dbl_e); + let m_o_dbl = _mm512_mul_epu32(val1_o, twiddle_dbl_o); let rm_l = _mm512_srli_epi64(_mm512_permutex2var_epi32(m_e_dbl, L, m_o_dbl), 1); let rm_h = _mm512_permutex2var_epi32(m_e_dbl, H, m_o_dbl); @@ -94,8 +160,8 @@ pub unsafe fn avx_ibutterfly( let 
twiddle_dbl_e = twiddle_dbl; let r1_o = _mm512_srli_epi64(r1, 32); let twiddle_dbl_o = _mm512_srli_epi64(twiddle_dbl, 32); - let m_e_dbl = _mm512_mul_epi32(r1_e, twiddle_dbl_e); - let m_o_dbl = _mm512_mul_epi32(r1_o, twiddle_dbl_o); + let m_e_dbl = _mm512_mul_epu32(r1_e, twiddle_dbl_e); + let m_o_dbl = _mm512_mul_epu32(r1_o, twiddle_dbl_o); let rm_l = _mm512_srli_epi64(_mm512_permutex2var_epi32(m_e_dbl, L, m_o_dbl), 1); let rm_h = _mm512_permutex2var_epi32(m_e_dbl, H, m_o_dbl); @@ -199,21 +265,21 @@ pub unsafe fn vecwise_ibutterflies( pub unsafe fn ifft3( values: *mut i32, offset: usize, - step: usize, - twiddles_dbl0: &[i32; 4], - twiddles_dbl1: &[i32; 2], - twiddles_dbl2: &[i32; 1], + log_step: usize, + twiddles_dbl0: [i32; 4], + twiddles_dbl1: [i32; 2], + twiddles_dbl2: [i32; 1], ) { - let u32_step = step + 4; + let log_u32_step = log_step; // load - let mut val0 = _mm512_load_epi32(values.add(offset + (0 << u32_step)).cast_const()); - let mut val1 = _mm512_load_epi32(values.add(offset + (1 << u32_step)).cast_const()); - let mut val2 = _mm512_load_epi32(values.add(offset + (2 << u32_step)).cast_const()); - let mut val3 = _mm512_load_epi32(values.add(offset + (3 << u32_step)).cast_const()); - let mut val4 = _mm512_load_epi32(values.add(offset + (4 << u32_step)).cast_const()); - let mut val5 = _mm512_load_epi32(values.add(offset + (5 << u32_step)).cast_const()); - let mut val6 = _mm512_load_epi32(values.add(offset + (6 << u32_step)).cast_const()); - let mut val7 = _mm512_load_epi32(values.add(offset + (7 << u32_step)).cast_const()); + let mut val0 = _mm512_load_epi32(values.add((offset + (0 << log_u32_step)) << 4).cast_const()); + let mut val1 = _mm512_load_epi32(values.add((offset + (1 << log_u32_step)) << 4).cast_const()); + let mut val2 = _mm512_load_epi32(values.add((offset + (2 << log_u32_step)) << 4).cast_const()); + let mut val3 = _mm512_load_epi32(values.add((offset + (3 << log_u32_step)) << 4).cast_const()); + let mut val4 = 
_mm512_load_epi32(values.add((offset + (4 << log_u32_step)) << 4).cast_const()); + let mut val5 = _mm512_load_epi32(values.add((offset + (5 << log_u32_step)) << 4).cast_const()); + let mut val6 = _mm512_load_epi32(values.add((offset + (6 << log_u32_step)) << 4).cast_const()); + let mut val7 = _mm512_load_epi32(values.add((offset + (7 << log_u32_step)) << 4).cast_const()); (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); (val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1])); @@ -231,14 +297,14 @@ pub unsafe fn ifft3( (val3, val7) = avx_ibutterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0])); // store - _mm512_store_epi32(values.add(offset + (0 << u32_step)), val0); - _mm512_store_epi32(values.add(offset + (1 << u32_step)), val1); - _mm512_store_epi32(values.add(offset + (2 << u32_step)), val2); - _mm512_store_epi32(values.add(offset + (3 << u32_step)), val3); - _mm512_store_epi32(values.add(offset + (4 << u32_step)), val4); - _mm512_store_epi32(values.add(offset + (5 << u32_step)), val5); - _mm512_store_epi32(values.add(offset + (6 << u32_step)), val6); - _mm512_store_epi32(values.add(offset + (7 << u32_step)), val7); + _mm512_store_epi32(values.add((offset + (0 << log_u32_step)) << 4), val0); + _mm512_store_epi32(values.add((offset + (1 << log_u32_step)) << 4), val1); + _mm512_store_epi32(values.add((offset + (2 << log_u32_step)) << 4), val2); + _mm512_store_epi32(values.add((offset + (3 << log_u32_step)) << 4), val3); + _mm512_store_epi32(values.add((offset + (4 << log_u32_step)) << 4), val4); + _mm512_store_epi32(values.add((offset + (5 << log_u32_step)) << 4), val5); + _mm512_store_epi32(values.add((offset + (6 << log_u32_step)) << 4), val6); + _mm512_store_epi32(values.add((offset + (7 << log_u32_step)) << 4), val7); } #[cfg(test)] @@ -246,8 +312,13 @@ mod tests { use std::arch::x86_64::_mm512_setr_epi32; use super::*; + use crate::core::backend::avx512::BaseFieldVec; + use 
crate::core::backend::{CPUBackend, ColumnTrait};
     use crate::core::fft::{butterfly, ibutterfly};
     use crate::core::fields::m31::BaseField;
+    use crate::core::fields::Field;
+    use crate::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation};
+    use crate::core::utils::bit_reverse;
     #[test]
     fn test_butterfly() {
@@ -282,12 +353,12 @@ mod tests {
     #[test]
     fn test_ibutterfly() {
         unsafe {
-            let val0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+            let val0 = _mm512_setr_epi32(2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
             let val1 = _mm512_setr_epi32(
-                16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+                3, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
             );
             let twiddle = _mm512_setr_epi32(
-                32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+                1177558791, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
             );
             let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle);
             let (r0, r1) = avx_ibutterfly(val0, val1, twiddle_dbl);
@@ -492,9 +563,9 @@ mod tests {
                 std::mem::transmute(values.as_mut_ptr()),
                 0,
                 0,
-                &twiddles0_dbl,
-                &twiddles1_dbl,
-                &twiddles2_dbl,
+                twiddles0_dbl,
+                twiddles1_dbl,
+                twiddles2_dbl,
             );
             let actual: Vec<[BaseField; 16]> = std::mem::transmute(values);
@@ -535,4 +606,101 @@ mod tests {
             }
         }
     }
+
+    fn get_itwiddle_dbls(domain: CircleDomain) -> Vec<Vec<i32>> {
+        let mut coset = domain.half_coset;
+
+        let mut res = vec![];
+        res.push(
+            coset
+                .iter()
+                .map(|p| (p.y.inverse().0 * 2) as i32)
+                .collect::<Vec<_>>(),
+        );
+        bit_reverse(res.last_mut().unwrap());
+        for _ in 0..coset.log_size() {
+            res.push(
+                coset
+                    .iter()
+                    .take(coset.size() / 2)
+                    .map(|p| (p.x.inverse().0 * 2) as i32)
+                    .collect::<Vec<_>>(),
+            );
+            bit_reverse(res.last_mut().unwrap());
+            coset = coset.double();
+        }
+
+        res
+    }
+
+    fn ref_ifft(domain: CircleDomain, mut values: Vec<BaseField>) -> Vec<BaseField> {
+        bit_reverse(&mut values);
+        let eval = CircleEvaluation::<CPUBackend, BaseField>::new(domain, values);
+        let mut expected_coeffs = eval.interpolate().coeffs;
for x in expected_coeffs.iter_mut() {
+            *x *= BaseField::from_u32_unchecked(domain.size() as u32);
+        }
+        bit_reverse(&mut expected_coeffs);
+        expected_coeffs
+    }
+
+    #[test]
+    fn test_vecwise_ibutterflies_real() {
+        let domain = CanonicCoset::new(5).circle_domain();
+        let twiddle_dbls = get_itwiddle_dbls(domain);
+        assert_eq!(twiddle_dbls.len(), 5);
+        let values0: [i32; 16] = std::array::from_fn(|i| i as i32);
+        let values1: [i32; 16] = std::array::from_fn(|i| (i + 16) as i32);
+        let result: [BaseField; 32] = unsafe {
+            let (val0, val1) = vecwise_ibutterflies(
+                std::mem::transmute(values0),
+                std::mem::transmute(values1),
+                twiddle_dbls[0].clone().try_into().unwrap(),
+                twiddle_dbls[1].clone().try_into().unwrap(),
+                twiddle_dbls[2].clone().try_into().unwrap(),
+                twiddle_dbls[3].clone().try_into().unwrap(),
+            );
+            let (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddle_dbls[4][0]));
+            std::mem::transmute([val0, val1])
+        };
+
+        // ref.
+        let mut values = values0.to_vec();
+        values.extend_from_slice(&values1);
+        let expected = ref_ifft(domain, values.into_iter().map(BaseField::from).collect());
+
+        // Compare.
+        for i in 0..32 {
+            assert_eq!(result[i], expected[i]);
+        }
+    }
+
+    #[test]
+    fn test_ifft_lower() {
+        let log_size = 4 + 3 + 3;
+        let domain = CanonicCoset::new(log_size).circle_domain();
+        let values = (0..domain.size())
+            .map(|i| BaseField::from_u32_unchecked(i as u32))
+            .collect::<Vec<_>>();
+        let expected_coeffs = ref_ifft(domain, values.clone());
+
+        // Compute.
+        let mut values = BaseFieldVec::from_vec(values);
+        let twiddle_dbls = get_itwiddle_dbls(domain);
+
+        unsafe {
+            ifft_lower(
+                std::mem::transmute(values.data.as_mut_ptr()),
+                Some(&twiddle_dbls[..4]),
+                &twiddle_dbls[4..],
+                (log_size - 4) as usize,
+                (log_size - 4) as usize,
+            );
+
+            // Compare.
+            for i in 0..expected_coeffs.len() {
+                assert_eq!(values[i], expected_coeffs[i]);
+            }
+        }
+    }
 }
diff --git a/src/core/backend/avx512/mod.rs b/src/core/backend/avx512/mod.rs
index 1dfc2f9e2..9cea2e3f1 100644
--- a/src/core/backend/avx512/mod.rs
+++ b/src/core/backend/avx512/mod.rs
@@ -69,7 +69,7 @@ fn as_cpu_vec(values: BaseFieldVec) -> Vec<BaseField> {
 impl Index<usize> for BaseFieldVec {
     type Output = BaseField;
     fn index(&self, index: usize) -> &Self::Output {
-        &self.data[index / 8][index % 8]
+        &self.data[index / 16][index % 16]
     }
 }
diff --git a/src/core/fields/m31.rs b/src/core/fields/m31.rs
index 3f34afa73..b6b49b3d8 100644
--- a/src/core/fields/m31.rs
+++ b/src/core/fields/m31.rs
@@ -14,7 +14,7 @@ pub const P: u32 = 2147483647; // 2 ** 31 - 1
 #[repr(transparent)]
 #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Pod, Zeroable)]
-pub struct M31(u32);
+pub struct M31(pub u32);
 pub type BaseField = M31;
 impl_field!(M31, P);
diff --git a/src/core/utils.rs b/src/core/utils.rs
index bfd0a5246..18b4b7073 100644
--- a/src/core/utils.rs
+++ b/src/core/utils.rs
@@ -10,6 +10,9 @@ pub trait IteratorMutExt<'a, T: 'a>: Iterator<Item = &'a mut T> {
 impl<'a, T: 'a, I: Iterator<Item = &'a mut T>> IteratorMutExt<'a, T> for I {}
 
 pub(crate) fn bit_reverse_index(i: usize, log_size: u32) -> usize {
+    if log_size == 0 {
+        return i;
+    }
     i.reverse_bits() >> (usize::BITS - log_size)
 }