diff --git a/src/core/backend/avx512/fft.rs b/src/core/backend/avx512/fft.rs index 09fbfb751..512bf2acb 100644 --- a/src/core/backend/avx512/fft.rs +++ b/src/core/backend/avx512/fft.rs @@ -169,8 +169,14 @@ pub unsafe fn ifft_lower_with_vecwise( ); layer += 3; } - if fft_layers - layer != 0 { - todo!() + match fft_layers - layer { + 2 => { + ifft2_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h); + } + 1 => { + ifft1_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h); + } + _ => {} } } } @@ -207,8 +213,15 @@ pub unsafe fn ifft_lower_without_vecwise( ); layer += 3; } - if fft_layers - layer != 0 { - todo!() + let fixed_layer = layer + VECS_LOG_SIZE; + match fft_layers - layer { + 2 => { + ifft2_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h); + } + 1 => { + ifft1_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h); + } + _ => {} } } } @@ -285,6 +298,53 @@ unsafe fn ifft3_loop( } } +/// Runs 2 ifft layers across the entire array. +/// Parameters: +/// values - Pointer to the entire value array, aligned to 64 bytes. +/// twiddle_dbl - The doubles of the twiddle factors for each of the 2 ifft layers. +/// loop_bits - The number of bits this loops needs to run on. +/// layer - The layer number of the first ifft layer to apply. +/// The layers `layer`, `layer + 1` are applied. +/// index - The index, iterated by the caller. +/// # Safety +unsafe fn ifft2_loop(values: *mut i32, twiddle_dbl: &[Vec], layer: usize, index: usize) { + let offset = index << (layer + 2); + for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) { + ifft2( + values, + offset + l, + layer, + std::array::from_fn(|i| { + *twiddle_dbl[0].get_unchecked((index * 2 + i) & (twiddle_dbl[0].len() - 1)) + }), + std::array::from_fn(|i| { + *twiddle_dbl[1].get_unchecked((index + i) & (twiddle_dbl[1].len() - 1)) + }), + ); + } +} + +/// Runs 1 ifft layer across the entire array. +/// Parameters: +/// values - Pointer to the entire value array, aligned to 64 bytes. +/// twiddle_dbl - The doubles of the twiddle factors for the ifft layer. +/// layer - The layer number of the ifft layer to apply. +/// index_h - The higher part of the index, iterated by the caller. +/// # Safety +unsafe fn ifft1_loop(values: *mut i32, twiddle_dbl: &[Vec], layer: usize, index: usize) { + let offset = index << (layer + 1); + for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) { + ifft1( + values, + offset + l, + layer, + std::array::from_fn(|i| { + *twiddle_dbl[0].get_unchecked((index + i) & (twiddle_dbl[0].len() - 1)) + }), + ); + } +} + /// Computes the butterfly operation for packed M31 elements. /// val0 + t val1, val0 - t val1. /// val0, val1 are packed M31 elements. 16 M31 words at each. @@ -642,6 +702,68 @@ pub unsafe fn ifft3( _mm512_store_epi32(values.add(offset + (7 << log_step)), val7); } +/// Applies 2 butterfly layers on 4 vectors of 16 M31 elements. +/// Vectorized over the 16 elements of the vectors. +/// Used for radix-4 ifft. +/// Each butterfly layer, has 2 AVX butterflies. +/// Total of 4 AVX butterflies. +/// Parameters: +/// values - Pointer to the entire value array. +/// offset - The offset of the first value in the array. +/// log_step - The log of the distance in the array, in M31 elements, between each pair of +/// values that need to be transformed. For layer i this is i - 4. +/// twiddles_dbl0/1 - The double of the twiddles for the 2 layers of butterflies. +/// Each layer has 2/1 twiddles. +/// # Safety +pub unsafe fn ifft2( + values: *mut i32, + offset: usize, + log_step: usize, + twiddles_dbl0: [i32; 2], + twiddles_dbl1: [i32; 1], +) { + // Load the 4 AVX vectors from the array. + let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const()); + let mut val2 = _mm512_load_epi32(values.add(offset + (2 << log_step)).cast_const()); + let mut val3 = _mm512_load_epi32(values.add(offset + (3 << log_step)).cast_const()); + + // Apply the first layer of butterflies. + (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); + (val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1])); + + // Apply the second layer of butterflies. + (val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0])); + (val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0])); + + // Store the 4 AVX vectors back to the array. + _mm512_store_epi32(values.add(offset + (0 << log_step)), val0); + _mm512_store_epi32(values.add(offset + (1 << log_step)), val1); + _mm512_store_epi32(values.add(offset + (2 << log_step)), val2); + _mm512_store_epi32(values.add(offset + (3 << log_step)), val3); +} + +/// Applies 1 butterfly layers on 2 vectors of 16 M31 elements. +/// Vectorized over the 16 elements of the vectors. +/// Parameters: +/// values - Pointer to the entire value array. +/// offset - The offset of the first value in the array. +/// log_step - The log of the distance in the array, in M31 elements, between each pair of +/// values that need to be transformed. For layer i this is i - 4. +/// twiddles_dbl0 - The double of the twiddles for the butterfly layer. +/// # Safety +pub unsafe fn ifft1(values: *mut i32, offset: usize, log_step: usize, twiddles_dbl0: [i32; 1]) { + // Load the 2 AVX vectors from the array. + let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const()); + + (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); + + // Store the 2 AVX vectors back to the array. + _mm512_store_epi32(values.add(offset + (0 << log_step)), val0); + _mm512_store_epi32(values.add(offset + (1 << log_step)), val1); +} + // TODO(spapini): Move these to M31 AVX. /// Adds two packed M31 elements, and reduces the result to the range [0,P]. @@ -987,7 +1109,9 @@ mod tests { } #[test] - fn test_ifft_full_11() { - run_ifft_full_test(3 + 3 + 5); + fn test_ifft_full() { + for i in 5..=5 + 3 + 3 { + run_ifft_full_test(i); + } } }