diff --git a/src/core/backend/avx512/fft/ifft.rs b/src/core/backend/avx512/fft/ifft.rs index 9ff691180..60ee0a8f5 100644 --- a/src/core/backend/avx512/fft/ifft.rs +++ b/src/core/backend/avx512/fft/ifft.rs @@ -79,25 +79,24 @@ pub unsafe fn ifft_lower_with_vecwise( for index_h in 0..(1 << (log_size - fft_layers)) { ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h); - let mut layer = VECWISE_FFT_BITS; - while fft_layers - layer >= 3 { - ifft3_loop( - values, - &twiddle_dbl[(layer - 1)..], - fft_layers - layer - 3, - layer, - index_h, - ); - layer += 3; - } - match fft_layers - layer { - 2 => { - ifft2_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h); - } - 1 => { - ifft1_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h); + for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) { + match fft_layers - layer { + 1 => { + ifft1_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h); + } + 2 => { + ifft2_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h); + } + _ => { + ifft3_loop( + values, + &twiddle_dbl[(layer - 1)..], + fft_layers - layer - 3, + layer, + index_h, + ); + } } - _ => {} } } } @@ -123,26 +122,25 @@ pub unsafe fn ifft_lower_without_vecwise( assert!(log_size >= VECS_LOG_SIZE); for index_h in 0..(1 << (log_size - fft_layers - VECS_LOG_SIZE)) { - let mut layer = 0; - while fft_layers - layer >= 3 { - ifft3_loop( - values, - &twiddle_dbl[layer..], - fft_layers - layer - 3, - layer + VECS_LOG_SIZE, - index_h, - ); - layer += 3; - } - let fixed_layer = layer + VECS_LOG_SIZE; - match fft_layers - layer { - 2 => { - ifft2_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h); - } - 1 => { - ifft1_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h); + for layer in (0..fft_layers).step_by(3) { + let fixed_layer = layer + VECS_LOG_SIZE; + match fft_layers - layer { + 1 => { + ifft1_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h); + } + 2 => { + ifft2_loop(values, &twiddle_dbl[layer..], 
fixed_layer, index_h); + } + _ => { + ifft3_loop( + values, + &twiddle_dbl[layer..], + fft_layers - layer - 3, + fixed_layer, + index_h, + ); + } } - _ => {} } } } @@ -422,7 +420,7 @@ pub fn get_itwiddle_dbls(domain: CircleDomain) -> Vec> { res } -/// Applies 3 butterfly layers on 8 vectors of 16 M31 elements. +/// Applies 3 ibutterfly layers on 8 vectors of 16 M31 elements. /// Vectorized over the 16 elements of the vectors. /// Used for radix-8 ifft. /// Each butterfly layer, has 3 AVX butterflies. @@ -432,7 +430,7 @@ pub fn get_itwiddle_dbls(domain: CircleDomain) -> Vec> { /// offset - The offset of the first value in the array. /// log_step - The log of the distance in the array, in M31 elements, between each pair of /// values that need to be transformed. For layer i this is i - 4. -/// twiddles_dbl0/1/2 - The double of the twiddles for the 3 layers of butterflies. +/// twiddles_dbl0/1/2 - The double of the twiddles for the 3 layers of ibutterflies. /// Each layer has 4/2/1 twiddles. /// # Safety pub unsafe fn ifft3( @@ -453,19 +451,19 @@ pub unsafe fn ifft3( let mut val6 = _mm512_load_epi32(values.add(offset + (6 << log_step)).cast_const()); let mut val7 = _mm512_load_epi32(values.add(offset + (7 << log_step)).cast_const()); - // Apply the first layer of butterflies. + // Apply the first layer of ibutterflies. (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); (val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1])); (val4, val5) = avx_ibutterfly(val4, val5, _mm512_set1_epi32(twiddles_dbl0[2])); (val6, val7) = avx_ibutterfly(val6, val7, _mm512_set1_epi32(twiddles_dbl0[3])); - // Apply the second layer of butterflies. + // Apply the second layer of ibutterflies. 
(val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0])); (val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0])); (val4, val6) = avx_ibutterfly(val4, val6, _mm512_set1_epi32(twiddles_dbl1[1])); (val5, val7) = avx_ibutterfly(val5, val7, _mm512_set1_epi32(twiddles_dbl1[1])); - // Apply the third layer of butterflies. + // Apply the third layer of ibutterflies. (val0, val4) = avx_ibutterfly(val0, val4, _mm512_set1_epi32(twiddles_dbl2[0])); (val1, val5) = avx_ibutterfly(val1, val5, _mm512_set1_epi32(twiddles_dbl2[0])); (val2, val6) = avx_ibutterfly(val2, val6, _mm512_set1_epi32(twiddles_dbl2[0])); @@ -482,17 +480,17 @@ pub unsafe fn ifft3( _mm512_store_epi32(values.add(offset + (7 << log_step)), val7); } -/// Applies 2 butterfly layers on 4 vectors of 16 M31 elements. +/// Applies 2 ibutterfly layers on 4 vectors of 16 M31 elements. /// Vectorized over the 16 elements of the vectors. /// Used for radix-4 ifft. -/// Each butterfly layer, has 2 AVX butterflies. +/// Each ibutterfly layer, has 2 AVX butterflies. /// Total of 4 AVX butterflies. /// Parameters: /// values - Pointer to the entire value array. /// offset - The offset of the first value in the array. /// log_step - The log of the distance in the array, in M31 elements, between each pair of /// values that need to be transformed. For layer i this is i - 4. -/// twiddles_dbl0/1 - The double of the twiddles for the 2 layers of butterflies. +/// twiddles_dbl0/1 - The double of the twiddles for the 2 layers of ibutterflies. /// Each layer has 2/1 twiddles. /// # Safety pub unsafe fn ifft2( @@ -523,14 +521,14 @@ pub unsafe fn ifft2( _mm512_store_epi32(values.add(offset + (3 << log_step)), val3); } -/// Applies 1 butterfly layers on 2 vectors of 16 M31 elements. +/// Applies 1 ibutterfly layers on 2 vectors of 16 M31 elements. /// Vectorized over the 16 elements of the vectors. /// Parameters: /// values - Pointer to the entire value array. 
/// offset - The offset of the first value in the array. /// log_step - The log of the distance in the array, in M31 elements, between each pair of /// values that need to be transformed. For layer i this is i - 4. -/// twiddles_dbl0 - The double of the twiddles for the butterfly layer. +/// twiddles_dbl0 - The double of the twiddles for the ibutterfly layer. /// # Safety pub unsafe fn ifft1(values: *mut i32, offset: usize, log_step: usize, twiddles_dbl0: [i32; 1]) { // Load the 2 AVX vectors from the array. @@ -695,27 +693,28 @@ mod tests { #[test] fn test_ifft_lower_with_vecwise() { - let log_size = 5 + 3 + 3; - let domain = CanonicCoset::new(log_size).circle_domain(); - let values = (0..domain.size()) - .map(|i| BaseField::from_u32_unchecked(i as u32)) - .collect::>(); - let expected_coeffs = ref_ifft(domain, values.clone()); - - // Compute. - let mut values = BaseFieldVec::from_iter(values); - let twiddle_dbls = get_itwiddle_dbls(domain); - - unsafe { - ifft_lower_with_vecwise( - std::mem::transmute(values.data.as_mut_ptr()), - &twiddle_dbls[1..], - log_size as usize, - log_size as usize, - ); - - // Compare. - assert_eq!(values.to_vec(), expected_coeffs); + for log_size in 5..12 { + let domain = CanonicCoset::new(log_size).circle_domain(); + let values = (0..domain.size()) + .map(|i| BaseField::from_u32_unchecked(i as u32)) + .collect::>(); + let expected_coeffs = ref_ifft(domain, values.clone()); + + // Compute. + let mut values = BaseFieldVec::from_iter(values); + let twiddle_dbls = get_itwiddle_dbls(domain); + + unsafe { + ifft_lower_with_vecwise( + std::mem::transmute(values.data.as_mut_ptr()), + &twiddle_dbls[1..], + log_size as usize, + log_size as usize, + ); + + // Compare. 
+ assert_eq!(values.to_vec(), expected_coeffs); + } } } @@ -748,7 +747,7 @@ mod tests { #[test] fn test_ifft_full() { - for i in 5..=5 + 3 + 3 { + for i in 5..12 { run_ifft_full_test(i); } } diff --git a/src/core/backend/avx512/fft/rfft.rs b/src/core/backend/avx512/fft/rfft.rs index 6d7703f5d..b1970147b 100644 --- a/src/core/backend/avx512/fft/rfft.rs +++ b/src/core/backend/avx512/fft/rfft.rs @@ -1,14 +1,267 @@ //! Regular (forward) fft. use std::arch::x86_64::{ - __m512i, _mm512_broadcast_i32x4, _mm512_mul_epu32, _mm512_permutex2var_epi32, - _mm512_set1_epi64, _mm512_srli_epi64, + __m512i, _mm512_broadcast_i32x4, _mm512_load_epi32, _mm512_mul_epu32, + _mm512_permutex2var_epi32, _mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64, + _mm512_store_epi32, }; use super::{ add_mod_p, compute_first_twiddles, sub_mod_p, EVENS_INTERLEAVE_EVENS, HHALF_INTERLEAVE_HHALF, LHALF_INTERLEAVE_LHALF, ODDS_INTERLEAVE_ODDS, }; +use crate::core::backend::avx512::fft::transpose_vecs; +use crate::core::backend::avx512::{MIN_FFT_LOG_SIZE, VECS_LOG_SIZE}; +use crate::core::poly::circle::CircleDomain; +use crate::core::utils::bit_reverse; + +/// Performs a Circle Fast Fourier Transform (CFFT) on the given values. +/// +/// # Safety +/// This function is unsafe because it takes a raw pointer to i32 values. +/// `values` must be aligned to 64 bytes. +/// +/// # Arguments +/// * `values`: A mutable pointer to the values on which the CFFT is to be performed. +/// * `twiddle_dbl`: A reference to the doubles of the twiddle factors. +/// * `log_n_elements`: The log of the number of elements in the `values` array. +/// +/// # Panics +/// This function will panic if `log_n_elements` is less than `MIN_FFT_LOG_SIZE`. +pub unsafe fn fft(values: *mut i32, twiddle_dbl: &[Vec], log_n_elements: usize) { + assert!(log_n_elements >= MIN_FFT_LOG_SIZE); + let log_n_vecs = log_n_elements - VECS_LOG_SIZE; + // TODO(spapini): Use CACHED_FFT_LOG_SIZE instead. 
+ if log_n_vecs <= 1 { + fft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements); + return; + } + + let fft_layers_pre_transpose = log_n_vecs.div_ceil(2); + let fft_layers_post_transpose = log_n_vecs / 2; + fft_lower_without_vecwise( + values, + &twiddle_dbl[(3 + fft_layers_pre_transpose)..], + log_n_elements, + fft_layers_post_transpose, + ); + transpose_vecs(values, log_n_vecs); + fft_lower_with_vecwise( + values, + &twiddle_dbl[..(3 + fft_layers_pre_transpose)], + log_n_elements, + fft_layers_pre_transpose + VECS_LOG_SIZE, + ); +} + +/// Computes partial fft on `2^log_size` M31 elements. +/// Parameters: +/// values - Pointer to the entire value array, aligned to 64 bytes. +/// twiddle_dbl - The doubles of the twiddle factors for each layer of the fft. +/// layer i holds 2^(log_size - 1 - i) twiddles. +/// log_size - The log of the number of M31 elements in the array. +/// fft_layers - The number of fft layers to apply, out of log_size. +/// # Safety +/// `values` must be aligned to 64 bytes. +/// `log_size` must be at least 5. +/// `fft_layers` must be at least 5. 
+pub unsafe fn fft_lower_with_vecwise( + values: *mut i32, + twiddle_dbl: &[Vec], + log_size: usize, + fft_layers: usize, +) { + const VECWISE_FFT_BITS: usize = VECS_LOG_SIZE + 1; + assert!(log_size >= VECWISE_FFT_BITS); + + assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); + + for index_h in 0..(1 << (log_size - fft_layers)) { + for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3).rev() { + match fft_layers - layer { + 1 => { + fft1_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h); + } + 2 => { + fft2_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h); + } + _ => { + fft3_loop( + values, + &twiddle_dbl[(layer - 1)..], + fft_layers - layer - 3, + layer, + index_h, + ); + } + } + } + fft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h); + } +} + +/// Computes partial fft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits +/// of the index). +/// Parameters: +/// values - Pointer to the entire value array, aligned to 64 bytes. +/// twiddle_dbl - The doubles of the twiddle factors for each layer of the the fft. +/// log_size - The log of the number of number of M31 elements in the array. +/// fft_layers - The number of fft layers to apply, out of log_size - VEC_LOG_SIZE. +/// +/// # Safety +/// `values` must be aligned to 64 bytes. +/// `log_size` must be at least 4. +/// `fft_layers` must be at least 4. 
+pub unsafe fn fft_lower_without_vecwise( + values: *mut i32, + twiddle_dbl: &[Vec], + log_size: usize, + fft_layers: usize, +) { + assert!(log_size >= VECS_LOG_SIZE); + + for index_h in 0..(1 << (log_size - fft_layers - VECS_LOG_SIZE)) { + for layer in (0..fft_layers).step_by(3).rev() { + let fixed_layer = layer + VECS_LOG_SIZE; + match fft_layers - layer { + 1 => { + fft1_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h); + } + 2 => { + fft2_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h); + } + _ => { + fft3_loop( + values, + &twiddle_dbl[layer..], + fft_layers - layer - 3, + fixed_layer, + index_h, + ); + } + } + } + } +} + +/// Runs the last 5 fft layers across the entire array. +/// Parameters: +/// values - Pointer to the entire value array, aligned to 64 bytes. +/// twiddle_dbl - The doubles of the twiddle factors for each of the 5 fft layers. +/// high_bits - The number of bits this loops needs to run on. +/// index_h - The higher part of the index, iterated by the caller. +/// # Safety +unsafe fn fft_vecwise_loop( + values: *mut i32, + twiddle_dbl: &[Vec], + loop_bits: usize, + index_h: usize, +) { + for index_l in 0..(1 << loop_bits) { + let index = (index_h << loop_bits) + index_l; + let mut val0 = _mm512_load_epi32(values.add(index * 32).cast_const()); + let mut val1 = _mm512_load_epi32(values.add(index * 32 + 16).cast_const()); + (val0, val1) = avx_butterfly( + val0, + val1, + _mm512_set1_epi32(*twiddle_dbl[3].get_unchecked(index)), + ); + (val0, val1) = vecwise_butterflies( + val0, + val1, + std::array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 8 + i)), + std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)), + std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)), + ); + _mm512_store_epi32(values.add(index * 32), val0); + _mm512_store_epi32(values.add(index * 32 + 16), val1); + } +} + +/// Runs 3 fft layers across the entire array. 
+/// Parameters: +/// values - Pointer to the entire value array, aligned to 64 bytes. +/// twiddle_dbl - The doubles of the twiddle factors for each of the 3 fft layers. +/// loop_bits - The number of bits this loops needs to run on. +/// layer - The layer number of the first fft layer to apply. +/// The layers `layer`, `layer + 1`, `layer + 2` are applied. +/// index_h - The higher part of the index, iterated by the caller. +/// # Safety +unsafe fn fft3_loop( + values: *mut i32, + twiddle_dbl: &[Vec], + loop_bits: usize, + layer: usize, + index_h: usize, +) { + for index_l in 0..(1 << loop_bits) { + let index = (index_h << loop_bits) + index_l; + let offset = index << (layer + 3); + for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) { + fft3( + values, + offset + l, + layer, + std::array::from_fn(|i| { + *twiddle_dbl[0].get_unchecked((index * 4 + i) & (twiddle_dbl[0].len() - 1)) + }), + std::array::from_fn(|i| { + *twiddle_dbl[1].get_unchecked((index * 2 + i) & (twiddle_dbl[1].len() - 1)) + }), + std::array::from_fn(|i| { + *twiddle_dbl[2].get_unchecked((index + i) & (twiddle_dbl[2].len() - 1)) + }), + ); + } + } +} + +/// Runs 2 fft layers across the entire array. +/// Parameters: +/// values - Pointer to the entire value array, aligned to 64 bytes. +/// twiddle_dbl - The doubles of the twiddle factors for each of the 2 fft layers. +/// loop_bits - The number of bits this loops needs to run on. +/// layer - The layer number of the first fft layer to apply. +/// The layers `layer`, `layer + 1` are applied. +/// index - The index, iterated by the caller. 
+/// # Safety +unsafe fn fft2_loop(values: *mut i32, twiddle_dbl: &[Vec], layer: usize, index: usize) { + let offset = index << (layer + 2); + for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) { + fft2( + values, + offset + l, + layer, + std::array::from_fn(|i| { + *twiddle_dbl[0].get_unchecked((index * 2 + i) & (twiddle_dbl[0].len() - 1)) + }), + std::array::from_fn(|i| { + *twiddle_dbl[1].get_unchecked((index + i) & (twiddle_dbl[1].len() - 1)) + }), + ); + } +} + +/// Runs 1 fft layer across the entire array. +/// Parameters: +/// values - Pointer to the entire value array, aligned to 64 bytes. +/// twiddle_dbl - The doubles of the twiddle factors for the fft layer. +/// layer - The layer number of the fft layer to apply. +/// index_h - The higher part of the index, iterated by the caller. +/// # Safety +unsafe fn fft1_loop(values: *mut i32, twiddle_dbl: &[Vec], layer: usize, index: usize) { + let offset = index << (layer + 1); + for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) { + fft1( + values, + offset + l, + layer, + std::array::from_fn(|i| { + *twiddle_dbl[0].get_unchecked((index + i) & (twiddle_dbl[0].len() - 1)) + }), + ); + } +} /// Computes the butterfly operation for packed M31 elements. /// val0 + t val1, val0 - t val1. @@ -114,15 +367,160 @@ pub unsafe fn vecwise_butterflies( ) } +pub fn get_twiddle_dbls(domain: CircleDomain) -> Vec> { + let mut coset = domain.half_coset; + + let mut res = vec![]; + res.push(coset.iter().map(|p| (p.y.0 * 2) as i32).collect::>()); + bit_reverse(res.last_mut().unwrap()); + for _ in 0..coset.log_size() { + res.push( + coset + .iter() + .take(coset.size() / 2) + .map(|p| (p.x.0 * 2) as i32) + .collect::>(), + ); + bit_reverse(res.last_mut().unwrap()); + coset = coset.double(); + } + + res +} + +/// Applies 3 butterfly layers on 8 vectors of 16 M31 elements. +/// Vectorized over the 16 elements of the vectors. +/// Used for radix-8 ifft. +/// Each butterfly layer, has 3 AVX butterflies. 
+/// Total of 12 AVX butterflies. +/// Parameters: +/// values - Pointer to the entire value array. +/// offset - The offset of the first value in the array. +/// log_step - The log of the distance in the array, in M31 elements, between each pair of +/// values that need to be transformed. For layer i this is i - 4. +/// twiddles_dbl0/1/2 - The double of the twiddles for the 3 layers of butterflies. +/// Each layer has 4/2/1 twiddles. +/// # Safety +pub unsafe fn fft3( + values: *mut i32, + offset: usize, + log_step: usize, + twiddles_dbl0: [i32; 4], + twiddles_dbl1: [i32; 2], + twiddles_dbl2: [i32; 1], +) { + // Load the 8 AVX vectors from the array. + let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const()); + let mut val2 = _mm512_load_epi32(values.add(offset + (2 << log_step)).cast_const()); + let mut val3 = _mm512_load_epi32(values.add(offset + (3 << log_step)).cast_const()); + let mut val4 = _mm512_load_epi32(values.add(offset + (4 << log_step)).cast_const()); + let mut val5 = _mm512_load_epi32(values.add(offset + (5 << log_step)).cast_const()); + let mut val6 = _mm512_load_epi32(values.add(offset + (6 << log_step)).cast_const()); + let mut val7 = _mm512_load_epi32(values.add(offset + (7 << log_step)).cast_const()); + + // Apply the third layer of butterflies. + (val0, val4) = avx_butterfly(val0, val4, _mm512_set1_epi32(twiddles_dbl2[0])); + (val1, val5) = avx_butterfly(val1, val5, _mm512_set1_epi32(twiddles_dbl2[0])); + (val2, val6) = avx_butterfly(val2, val6, _mm512_set1_epi32(twiddles_dbl2[0])); + (val3, val7) = avx_butterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0])); + + // Apply the second layer of butterflies. 
+ (val0, val2) = avx_butterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0])); + (val1, val3) = avx_butterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0])); + (val4, val6) = avx_butterfly(val4, val6, _mm512_set1_epi32(twiddles_dbl1[1])); + (val5, val7) = avx_butterfly(val5, val7, _mm512_set1_epi32(twiddles_dbl1[1])); + + // Apply the first layer of butterflies. + (val0, val1) = avx_butterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); + (val2, val3) = avx_butterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1])); + (val4, val5) = avx_butterfly(val4, val5, _mm512_set1_epi32(twiddles_dbl0[2])); + (val6, val7) = avx_butterfly(val6, val7, _mm512_set1_epi32(twiddles_dbl0[3])); + + // Store the 8 AVX vectors back to the array. + _mm512_store_epi32(values.add(offset + (0 << log_step)), val0); + _mm512_store_epi32(values.add(offset + (1 << log_step)), val1); + _mm512_store_epi32(values.add(offset + (2 << log_step)), val2); + _mm512_store_epi32(values.add(offset + (3 << log_step)), val3); + _mm512_store_epi32(values.add(offset + (4 << log_step)), val4); + _mm512_store_epi32(values.add(offset + (5 << log_step)), val5); + _mm512_store_epi32(values.add(offset + (6 << log_step)), val6); + _mm512_store_epi32(values.add(offset + (7 << log_step)), val7); +} + +/// Applies 2 butterfly layers on 4 vectors of 16 M31 elements. +/// Vectorized over the 16 elements of the vectors. +/// Used for radix-4 fft. +/// Each butterfly layer, has 2 AVX butterflies. +/// Total of 4 AVX butterflies. +/// Parameters: +/// values - Pointer to the entire value array. +/// offset - The offset of the first value in the array. +/// log_step - The log of the distance in the array, in M31 elements, between each pair of +/// values that need to be transformed. For layer i this is i - 4. +/// twiddles_dbl0/1 - The double of the twiddles for the 2 layers of butterflies. +/// Each layer has 2/1 twiddles. 
+/// # Safety +pub unsafe fn fft2( + values: *mut i32, + offset: usize, + log_step: usize, + twiddles_dbl0: [i32; 2], + twiddles_dbl1: [i32; 1], +) { + // Load the 4 AVX vectors from the array. + let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const()); + let mut val2 = _mm512_load_epi32(values.add(offset + (2 << log_step)).cast_const()); + let mut val3 = _mm512_load_epi32(values.add(offset + (3 << log_step)).cast_const()); + + // Apply the second layer of butterflies. + (val0, val2) = avx_butterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0])); + (val1, val3) = avx_butterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0])); + + // Apply the first layer of butterflies. + (val0, val1) = avx_butterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); + (val2, val3) = avx_butterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1])); + + // Store the 4 AVX vectors back to the array. + _mm512_store_epi32(values.add(offset + (0 << log_step)), val0); + _mm512_store_epi32(values.add(offset + (1 << log_step)), val1); + _mm512_store_epi32(values.add(offset + (2 << log_step)), val2); + _mm512_store_epi32(values.add(offset + (3 << log_step)), val3); +} + +/// Applies 1 butterfly layers on 2 vectors of 16 M31 elements. +/// Vectorized over the 16 elements of the vectors. +/// Parameters: +/// values - Pointer to the entire value array. +/// offset - The offset of the first value in the array. +/// log_step - The log of the distance in the array, in M31 elements, between each pair of +/// values that need to be transformed. For layer i this is i - 4. +/// twiddles_dbl0 - The double of the twiddles for the butterfly layer. +/// # Safety +pub unsafe fn fft1(values: *mut i32, offset: usize, log_step: usize, twiddles_dbl0: [i32; 1]) { + // Load the 2 AVX vectors from the array. 
+ let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const()); + + (val0, val1) = avx_butterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); + + // Store the 2 AVX vectors back to the array. + _mm512_store_epi32(values.add(offset + (0 << log_step)), val0); + _mm512_store_epi32(values.add(offset + (1 << log_step)), val1); +} + #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] #[cfg(test)] mod tests { use std::arch::x86_64::{_mm512_add_epi32, _mm512_set1_epi32, _mm512_setr_epi32}; use super::*; + use crate::core::backend::avx512::{BaseFieldVec, PackedBaseField}; use crate::core::backend::cpu::CPUCirclePoly; use crate::core::fft::butterfly; use crate::core::fields::m31::BaseField; + use crate::core::fields::Column; use crate::core::poly::circle::{CanonicCoset, CircleDomain}; use crate::core::utils::bit_reverse; @@ -156,6 +554,77 @@ mod tests { } } + #[test] + fn test_fft3() { + unsafe { + let mut values: Vec = (0..8) + .map(|i| { + PackedBaseField::from_array(std::array::from_fn(|_| { + BaseField::from_u32_unchecked(i) + })) + }) + .collect(); + let twiddles0 = [32, 33, 34, 35]; + let twiddles1 = [36, 37]; + let twiddles2 = [38]; + let twiddles0_dbl = std::array::from_fn(|i| twiddles0[i] * 2); + let twiddles1_dbl = std::array::from_fn(|i| twiddles1[i] * 2); + let twiddles2_dbl = std::array::from_fn(|i| twiddles2[i] * 2); + fft3( + std::mem::transmute(values.as_mut_ptr()), + 0, + VECS_LOG_SIZE, + twiddles0_dbl, + twiddles1_dbl, + twiddles2_dbl, + ); + + let expected: [u32; 8] = std::array::from_fn(|i| i as u32); + let mut expected: [BaseField; 8] = std::mem::transmute(expected); + let twiddles0: [BaseField; 4] = std::mem::transmute(twiddles0); + let twiddles1: [BaseField; 2] = std::mem::transmute(twiddles1); + let twiddles2: [BaseField; 1] = std::mem::transmute(twiddles2); + for i in 0..8 { + let j = i ^ 4; + if i > j { + continue; + } + let 
(mut v0, mut v1) = (expected[i], expected[j]); + butterfly(&mut v0, &mut v1, twiddles2[0]); + (expected[i], expected[j]) = (v0, v1); + } + for i in 0..8 { + let j = i ^ 2; + if i > j { + continue; + } + let (mut v0, mut v1) = (expected[i], expected[j]); + butterfly(&mut v0, &mut v1, twiddles1[i / 4]); + (expected[i], expected[j]) = (v0, v1); + } + for i in 0..8 { + let j = i ^ 1; + if i > j { + continue; + } + let (mut v0, mut v1) = (expected[i], expected[j]); + butterfly(&mut v0, &mut v1, twiddles0[i / 2]); + (expected[i], expected[j]) = (v0, v1); + } + for i in 0..8 { + assert_eq!(values[i].to_array()[0], expected[i]); + } + } + } + + fn ref_fft(domain: CircleDomain, mut values: Vec) -> Vec { + bit_reverse(&mut values); + let poly = CPUCirclePoly::new(values); + let mut expected_values = poly.evaluate(domain).values; + bit_reverse(&mut expected_values); + expected_values + } + #[test] fn test_vecwise_butterflies() { let domain = CanonicCoset::new(5).circle_domain(); @@ -190,32 +659,64 @@ mod tests { } } - fn get_twiddle_dbls(domain: CircleDomain) -> Vec> { - let mut coset = domain.half_coset; + #[test] + fn test_fft_lower() { + for log_size in 5..12 { + let domain = CanonicCoset::new(log_size).circle_domain(); + let values = (0..domain.size()) + .map(|i| BaseField::from_u32_unchecked(i as u32)) + .collect::>(); + let expected_coeffs = ref_fft(domain, values.clone()); - let mut res = vec![]; - res.push(coset.iter().map(|p| (p.y.0 * 2) as i32).collect::>()); - bit_reverse(res.last_mut().unwrap()); - for _ in 0..coset.log_size() { - res.push( - coset - .iter() - .take(coset.size() / 2) - .map(|p| (p.x.0 * 2) as i32) - .collect::>(), - ); - bit_reverse(res.last_mut().unwrap()); - coset = coset.double(); + // Compute. 
+ let mut values = BaseFieldVec::from_iter(values); + let twiddle_dbls = get_twiddle_dbls(domain); + + unsafe { + fft_lower_with_vecwise( + std::mem::transmute(values.data.as_mut_ptr()), + &twiddle_dbls[1..], + log_size as usize, + log_size as usize, + ); + + // Compare. + assert_eq!(values.to_vec(), expected_coeffs); + } } + } + + fn run_fft_full_test(log_size: u32) { + let domain = CanonicCoset::new(log_size).circle_domain(); + let values = (0..domain.size()) + .map(|i| BaseField::from_u32_unchecked(i as u32)) + .collect::>(); + let expected_coeffs = ref_fft(domain, values.clone()); + + // Compute. + let mut values = BaseFieldVec::from_iter(values); + let twiddle_dbls = get_twiddle_dbls(domain); - res + unsafe { + transpose_vecs( + std::mem::transmute(values.data.as_mut_ptr()), + (log_size - 4) as usize, + ); + fft( + std::mem::transmute(values.data.as_mut_ptr()), + &twiddle_dbls[1..], + log_size as usize, + ); + + // Compare. + assert_eq!(values.to_vec(), expected_coeffs); + } } - fn ref_fft(domain: CircleDomain, mut values: Vec) -> Vec { - bit_reverse(&mut values); - let poly = CPUCirclePoly::new(values); - let mut expected_values = poly.evaluate(domain).values; - bit_reverse(&mut expected_values); - expected_values + #[test] + fn test_fft_full() { + for i in 5..12 { + run_fft_full_test(i); + } } }