diff --git a/src/core/backend/avx512/fft.rs b/src/core/backend/avx512/fft.rs index f4890990f..60535cfde 100644 --- a/src/core/backend/avx512/fft.rs +++ b/src/core/backend/avx512/fft.rs @@ -4,6 +4,8 @@ use std::arch::x86_64::{ _mm512_set1_epi64, _mm512_srli_epi64, _mm512_store_epi32, _mm512_sub_epi32, }; +use crate::core::backend::avx512::VECS_LOG_SIZE; + /// An input to _mm512_permutex2var_epi32, and is used to interleave the even words of a /// with the even words of b. const EVENS_INTERLEAVE_EVENS: __m512i = unsafe { @@ -55,6 +57,128 @@ const HHALF_INTERLEAVE_HHALF: __m512i = unsafe { }; const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) }; +// TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce +// it somewhere. + +/// Computes partial ifft on `2^log_size` M31 elements. +/// Parameters: +/// values - Pointer to the entire value array, aligned to 64 bytes. +/// twiddle_dbl - The doubles of the twiddle factors for each layer of the the ifft. +/// layer i holds 2^(log_size - 1 - i) twiddles. +/// log_size - The log of the number of number of M31 elements in the array. +/// fft_layers - The number of ifft layers to apply, out of log_size. +/// # Safety +/// `values` must be aligned to 64 bytes. +/// `log_size` must be at least 5. +/// `fft_layers` must be at least 5. +pub unsafe fn ifft_lower_with_vecwise( + values: *mut i32, + twiddle_dbl: &[Vec], + log_size: usize, + fft_layers: usize, +) { + const VECWISE_FFT_BITS: usize = VECS_LOG_SIZE + 1; + assert!(log_size >= VECWISE_FFT_BITS); + + assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 1)); + + for index_h in 0..(1 << (log_size - fft_layers)) { + ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h); + for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) { + match fft_layers - layer { + 1 => { + todo!() + } + 2 => { + todo!() + } + _ => { + ifft3_loop( + values, + &twiddle_dbl[layer..], + fft_layers - layer - 3, + layer, + index_h, + ); + } + } + } + } +} + +/// Runs the 5 first ifft layers across the entire array. +/// Parameters: +/// values - Pointer to the entire value array, aligned to 64 bytes. +/// twiddle_dbl - The doubles of the twiddle factors for each of the 5 ifft layers. +/// high_bits - The number of bits this loops needs to run on. +/// index_h - The higher part of the index, iterated by the caller. +/// # Safety +unsafe fn ifft_vecwise_loop( + values: *mut i32, + twiddle_dbl: &[Vec], + loop_bits: usize, + index_h: usize, +) { + for index_l in 0..(1 << loop_bits) { + let index = (index_h << loop_bits) + index_l; + let mut val0 = _mm512_load_epi32(values.add(index * 32).cast_const()); + let mut val1 = _mm512_load_epi32(values.add(index * 32 + 16).cast_const()); + (val0, val1) = vecwise_ibutterflies( + val0, + val1, + std::array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 16 + i)), + std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 8 + i)), + std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 4 + i)), + std::array::from_fn(|i| *twiddle_dbl[3].get_unchecked(index * 2 + i)), + ); + (val0, val1) = avx_ibutterfly( + val0, + val1, + _mm512_set1_epi32(*twiddle_dbl[4].get_unchecked(index)), + ); + _mm512_store_epi32(values.add(index * 32), val0); + _mm512_store_epi32(values.add(index * 32 + 16), val1); + } +} + +/// Runs 3 ifft layers across the entire array. +/// Parameters: +/// values - Pointer to the entire value array, aligned to 64 bytes. +/// twiddle_dbl - The doubles of the twiddle factors for each of the 3 ifft layers. +/// loop_bits - The number of bits this loops needs to run on. +/// layer - The layer number of the first ifft layer to apply. +/// The layers `layer`, `layer + 1`, `layer + 2` are applied. +/// index_h - The higher part of the index, iterated by the caller. +/// # Safety +unsafe fn ifft3_loop( + values: *mut i32, + twiddle_dbl: &[Vec], + loop_bits: usize, + layer: usize, + index_h: usize, +) { + for m in 0..(1 << loop_bits) { + let index = (index_h << loop_bits) + m; + let offset = index << (layer + 3); + for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) { + ifft3( + values, + offset + l, + layer, + std::array::from_fn(|i| { + *twiddle_dbl[0].get_unchecked((index * 4 + i) & (twiddle_dbl[0].len() - 1)) + }), + std::array::from_fn(|i| { + *twiddle_dbl[1].get_unchecked((index * 2 + i) & (twiddle_dbl[1].len() - 1)) + }), + std::array::from_fn(|i| { + *twiddle_dbl[2].get_unchecked((index + i) & (twiddle_dbl[2].len() - 1)) + }), + ); + } + } +} + /// Computes the butterfly operation for packed M31 elements. /// val0 + t val1, val0 - t val1. /// val0, val1 are packed M31 elements. 16 M31 words at each. @@ -121,7 +245,7 @@ pub unsafe fn avx_ibutterfly( let r0 = add_mod_p(val0, val1); let r1 = sub_mod_p(val0, val1); - // Extract the even and odd parts of r1 and twiddle_dbl, and spread as 8 64bit values. + // Extract the even and odd parts of r1 and twiddle_m_e_dbldbl, and spread as 8 64bit values. let r1_e = r1; let r1_o = _mm512_srli_epi64(r1, 32); let twiddle_dbl_e = twiddle_dbl; @@ -302,8 +426,8 @@ pub unsafe fn vecwise_ibutterflies( /// Parameters: /// values - Pointer to the entire value array. /// offset - The offset of the first value in the array. -/// step_in_vecs - The distance in the array, in AVX vectors, between each pair of values that -/// need to be transformed. For layer i this is i-4. +/// log_step - The log of the distance in the array, in M31 elements, between each pair of +/// values that need to be transformed. For layer i this is i - 4. /// twiddles_dbl0/1/2 - The double of the twiddles for the 3 layers of butterflies. /// Each layer has 4/2/1 twiddles. /// @@ -311,21 +435,20 @@ pub unsafe fn vecwise_ibutterflies( pub unsafe fn ifft3( values: *mut i32, offset: usize, - step_in_vecs: usize, - twiddles_dbl0: &[i32; 4], - twiddles_dbl1: &[i32; 2], - twiddles_dbl2: &[i32; 1], + log_step: usize, + twiddles_dbl0: [i32; 4], + twiddles_dbl1: [i32; 2], + twiddles_dbl2: [i32; 1], ) { - let step_in_u32s = step_in_vecs + 4; // Load the 8 AVX vectors from the array. - let mut val0 = _mm512_load_epi32(values.add(offset + (0 << step_in_u32s)).cast_const()); - let mut val1 = _mm512_load_epi32(values.add(offset + (1 << step_in_u32s)).cast_const()); - let mut val2 = _mm512_load_epi32(values.add(offset + (2 << step_in_u32s)).cast_const()); - let mut val3 = _mm512_load_epi32(values.add(offset + (3 << step_in_u32s)).cast_const()); - let mut val4 = _mm512_load_epi32(values.add(offset + (4 << step_in_u32s)).cast_const()); - let mut val5 = _mm512_load_epi32(values.add(offset + (5 << step_in_u32s)).cast_const()); - let mut val6 = _mm512_load_epi32(values.add(offset + (6 << step_in_u32s)).cast_const()); - let mut val7 = _mm512_load_epi32(values.add(offset + (7 << step_in_u32s)).cast_const()); + let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const()); + let mut val2 = _mm512_load_epi32(values.add(offset + (2 << log_step)).cast_const()); + let mut val3 = _mm512_load_epi32(values.add(offset + (3 << log_step)).cast_const()); + let mut val4 = _mm512_load_epi32(values.add(offset + (4 << log_step)).cast_const()); + let mut val5 = _mm512_load_epi32(values.add(offset + (5 << log_step)).cast_const()); + let mut val6 = _mm512_load_epi32(values.add(offset + (6 << log_step)).cast_const()); + let mut val7 = _mm512_load_epi32(values.add(offset + (7 << log_step)).cast_const()); // Apply the first layer of butterflies. (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); @@ -346,14 +469,14 @@ pub unsafe fn ifft3( (val3, val7) = avx_ibutterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0])); // Store the 8 AVX vectors back to the array. - _mm512_store_epi32(values.add(offset + (0 << step_in_u32s)), val0); - _mm512_store_epi32(values.add(offset + (1 << step_in_u32s)), val1); - _mm512_store_epi32(values.add(offset + (2 << step_in_u32s)), val2); - _mm512_store_epi32(values.add(offset + (3 << step_in_u32s)), val3); - _mm512_store_epi32(values.add(offset + (4 << step_in_u32s)), val4); - _mm512_store_epi32(values.add(offset + (5 << step_in_u32s)), val5); - _mm512_store_epi32(values.add(offset + (6 << step_in_u32s)), val6); - _mm512_store_epi32(values.add(offset + (7 << step_in_u32s)), val7); + _mm512_store_epi32(values.add(offset + (0 << log_step)), val0); + _mm512_store_epi32(values.add(offset + (1 << log_step)), val1); + _mm512_store_epi32(values.add(offset + (2 << log_step)), val2); + _mm512_store_epi32(values.add(offset + (3 << log_step)), val3); + _mm512_store_epi32(values.add(offset + (4 << log_step)), val4); + _mm512_store_epi32(values.add(offset + (5 << log_step)), val5); + _mm512_store_epi32(values.add(offset + (6 << log_step)), val6); + _mm512_store_epi32(values.add(offset + (7 << log_step)), val7); } // TODO(spapini): Move these to M31 AVX. @@ -390,8 +513,13 @@ mod tests { use super::*; use crate::core::backend::avx512::m31::PackedBaseField; + use crate::core::backend::avx512::BaseFieldVec; + use crate::core::backend::CPUBackend; use crate::core::fft::{butterfly, ibutterfly}; use crate::core::fields::m31::BaseField; + use crate::core::fields::{Column, Field}; + use crate::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation}; + use crate::core::utils::bit_reverse; #[test] fn test_butterfly() { @@ -426,12 +554,12 @@ mod tests { #[test] fn test_ibutterfly() { unsafe { - let val0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + let val0 = _mm512_setr_epi32(2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); let val1 = _mm512_setr_epi32( - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 3, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, ); let twiddle = _mm512_setr_epi32( - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 1177558791, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, ); let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle); let (r0, r1) = avx_ibutterfly(val0, val1, twiddle_dbl); @@ -641,10 +769,10 @@ mod tests { ifft3( std::mem::transmute(values.as_mut_ptr()), 0, - 0, - &twiddles0_dbl, - &twiddles1_dbl, - &twiddles2_dbl, + VECS_LOG_SIZE, + twiddles0_dbl, + twiddles1_dbl, + twiddles2_dbl, ); let expected: [u32; 8] = std::array::from_fn(|i| i as u32); @@ -684,4 +812,98 @@ mod tests { } } } + + fn get_itwiddle_dbls(domain: CircleDomain) -> Vec> { + let mut coset = domain.half_coset; + + let mut res = vec![]; + res.push( + coset + .iter() + .map(|p| (p.y.inverse().0 * 2) as i32) + .collect::>(), + ); + bit_reverse(res.last_mut().unwrap()); + for _ in 0..coset.log_size() { + res.push( + coset + .iter() + .take(coset.size() / 2) + .map(|p| (p.x.inverse().0 * 2) as i32) + .collect::>(), + ); + bit_reverse(res.last_mut().unwrap()); + coset = coset.double(); + } + + res + } + + fn ref_ifft(domain: CircleDomain, mut values: Vec) -> Vec { + bit_reverse(&mut values); + let eval = CircleEvaluation::::new(domain, values); + let mut expected_coeffs = eval.interpolate().coeffs; + for x in expected_coeffs.iter_mut() { + *x *= BaseField::from_u32_unchecked(domain.size() as u32); + } + bit_reverse(&mut expected_coeffs); + expected_coeffs + } + + #[test] + fn test_vecwise_ibutterflies_real() { + let domain = CanonicCoset::new(5).circle_domain(); + let twiddle_dbls = get_itwiddle_dbls(domain); + assert_eq!(twiddle_dbls.len(), 5); + let values0: [i32; 16] = std::array::from_fn(|i| i as i32); + let values1: [i32; 16] = std::array::from_fn(|i| (i + 16) as i32); + let result: [BaseField; 32] = unsafe { + let (val0, val1) = vecwise_ibutterflies( + std::mem::transmute(values0), + std::mem::transmute(values1), + twiddle_dbls[0].clone().try_into().unwrap(), + twiddle_dbls[1].clone().try_into().unwrap(), + twiddle_dbls[2].clone().try_into().unwrap(), + twiddle_dbls[3].clone().try_into().unwrap(), + ); + let (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddle_dbls[4][0])); + std::mem::transmute([val0, val1]) + }; + + // ref. + let mut values = values0.to_vec(); + values.extend_from_slice(&values1); + let expected = ref_ifft(domain, values.into_iter().map(BaseField::from).collect()); + + // Compare. + for i in 0..32 { + assert_eq!(result[i], expected[i]); + } + } + + #[test] + fn test_ifft_lower_with_vecwise() { + let log_size = 5 + 3 + 3; + let domain = CanonicCoset::new(log_size).circle_domain(); + let values = (0..domain.size()) + .map(|i| BaseField::from_u32_unchecked(i as u32)) + .collect::>(); + let expected_coeffs = ref_ifft(domain, values.clone()); + + // Compute. + let mut values = BaseFieldVec::from_iter(values); + let twiddle_dbls = get_itwiddle_dbls(domain); + + unsafe { + ifft_lower_with_vecwise( + std::mem::transmute(values.data.as_mut_ptr()), + &twiddle_dbls, + log_size as usize, + log_size as usize, + ); + + // Compare. + assert_eq!(values.to_vec(), expected_coeffs); + } + } } diff --git a/src/core/backend/avx512/mod.rs b/src/core/backend/avx512/mod.rs index b3e808b31..327afb6c1 100644 --- a/src/core/backend/avx512/mod.rs +++ b/src/core/backend/avx512/mod.rs @@ -14,6 +14,8 @@ use crate::core::fields::m31::BaseField; use crate::core::fields::{Column, FieldOps}; use crate::core::utils; +const VECS_LOG_SIZE: usize = 4; + #[derive(Copy, Clone, Debug)] pub struct AVX512Backend; diff --git a/src/core/fields/m31.rs b/src/core/fields/m31.rs index 3f34afa73..b6b49b3d8 100644 --- a/src/core/fields/m31.rs +++ b/src/core/fields/m31.rs @@ -14,7 +14,7 @@ pub const P: u32 = 2147483647; // 2 ** 31 - 1 #[repr(transparent)] #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Pod, Zeroable)] -pub struct M31(u32); +pub struct M31(pub u32); pub type BaseField = M31; impl_field!(M31, P); diff --git a/src/core/utils.rs b/src/core/utils.rs index c781e1ae3..8a6483e36 100644 --- a/src/core/utils.rs +++ b/src/core/utils.rs @@ -10,6 +10,9 @@ pub trait IteratorMutExt<'a, T: 'a>: Iterator { impl<'a, T: 'a, I: Iterator> IteratorMutExt<'a, T> for I {} pub(crate) fn bit_reverse_index(i: usize, log_size: u32) -> usize { + if log_size == 0 { + return i; + } i.reverse_bits() >> (usize::BITS - log_size) }