diff --git a/src/core/backend/avx512/fft.rs b/src/core/backend/avx512/fft.rs
index b356b9736..f4890990f 100644
--- a/src/core/backend/avx512/fft.rs
+++ b/src/core/backend/avx512/fft.rs
@@ -1,7 +1,7 @@
 use std::arch::x86_64::{
-    __m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_min_epu32,
-    _mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_set1_epi64, _mm512_srli_epi64,
-    _mm512_sub_epi32,
+    __m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_load_epi32,
+    _mm512_min_epu32, _mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_set1_epi32,
+    _mm512_set1_epi64, _mm512_srli_epi64, _mm512_store_epi32, _mm512_sub_epi32,
 };
 
 /// An input to _mm512_permutex2var_epi32, and is used to interleave the even words of a
@@ -294,6 +294,72 @@ pub unsafe fn vecwise_ibutterflies(
     )
 }
 
+/// Applies 3 butterfly layers on 8 vectors of 16 M31 elements.
+/// Vectorized over the 16 elements of the vectors.
+/// Used for radix-8 ifft.
+/// Each butterfly layer has 4 AVX butterflies.
+/// Total of 12 AVX butterflies.
+/// Parameters:
+///   values - Pointer to the entire value array.
+///   offset - The offset of the first value in the array.
+///   step_in_vecs - The log2 of the distance in the array, in AVX vectors, between each pair of
+///     values that need to be transformed. For layer i this is i-4.
+///   twiddles_dbl0/1/2 - The double of the twiddles for the 3 layers of butterflies.
+///     Each layer has 4/2/1 twiddles.
+///
+/// # Safety
+///
+/// - The CPU must support the AVX-512F instructions used by the intrinsics called here.
+/// - For every k in 0..8, `values.add(offset + (k << (step_in_vecs + 4)))` must be 64-byte
+///   aligned and valid for reading and writing 16 consecutive `i32`s.
+pub unsafe fn ifft3(
+    values: *mut i32,
+    offset: usize,
+    step_in_vecs: usize,
+    twiddles_dbl0: &[i32; 4],
+    twiddles_dbl1: &[i32; 2],
+    twiddles_dbl2: &[i32; 1],
+) {
+    let step_in_u32s = step_in_vecs + 4;
+    // Load the 8 AVX vectors from the array.
+    let mut val0 = _mm512_load_epi32(values.add(offset + (0 << step_in_u32s)).cast_const());
+    let mut val1 = _mm512_load_epi32(values.add(offset + (1 << step_in_u32s)).cast_const());
+    let mut val2 = _mm512_load_epi32(values.add(offset + (2 << step_in_u32s)).cast_const());
+    let mut val3 = _mm512_load_epi32(values.add(offset + (3 << step_in_u32s)).cast_const());
+    let mut val4 = _mm512_load_epi32(values.add(offset + (4 << step_in_u32s)).cast_const());
+    let mut val5 = _mm512_load_epi32(values.add(offset + (5 << step_in_u32s)).cast_const());
+    let mut val6 = _mm512_load_epi32(values.add(offset + (6 << step_in_u32s)).cast_const());
+    let mut val7 = _mm512_load_epi32(values.add(offset + (7 << step_in_u32s)).cast_const());
+
+    // Apply the first layer of butterflies.
+    (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
+    (val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1]));
+    (val4, val5) = avx_ibutterfly(val4, val5, _mm512_set1_epi32(twiddles_dbl0[2]));
+    (val6, val7) = avx_ibutterfly(val6, val7, _mm512_set1_epi32(twiddles_dbl0[3]));
+
+    // Apply the second layer of butterflies.
+    (val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0]));
+    (val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0]));
+    (val4, val6) = avx_ibutterfly(val4, val6, _mm512_set1_epi32(twiddles_dbl1[1]));
+    (val5, val7) = avx_ibutterfly(val5, val7, _mm512_set1_epi32(twiddles_dbl1[1]));
+
+    // Apply the third layer of butterflies.
+    (val0, val4) = avx_ibutterfly(val0, val4, _mm512_set1_epi32(twiddles_dbl2[0]));
+    (val1, val5) = avx_ibutterfly(val1, val5, _mm512_set1_epi32(twiddles_dbl2[0]));
+    (val2, val6) = avx_ibutterfly(val2, val6, _mm512_set1_epi32(twiddles_dbl2[0]));
+    (val3, val7) = avx_ibutterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0]));
+
+    // Store the 8 AVX vectors back to the array.
+    _mm512_store_epi32(values.add(offset + (0 << step_in_u32s)), val0);
+    _mm512_store_epi32(values.add(offset + (1 << step_in_u32s)), val1);
+    _mm512_store_epi32(values.add(offset + (2 << step_in_u32s)), val2);
+    _mm512_store_epi32(values.add(offset + (3 << step_in_u32s)), val3);
+    _mm512_store_epi32(values.add(offset + (4 << step_in_u32s)), val4);
+    _mm512_store_epi32(values.add(offset + (5 << step_in_u32s)), val5);
+    _mm512_store_epi32(values.add(offset + (6 << step_in_u32s)), val6);
+    _mm512_store_epi32(values.add(offset + (7 << step_in_u32s)), val7);
+}
+
 // TODO(spapini): Move these to M31 AVX.
 
 /// Adds two packed M31 elements, and reduces the result to the range [0,P].
@@ -327,6 +393,7 @@ mod tests {
     use std::arch::x86_64::_mm512_setr_epi32;
 
     use super::*;
+    use crate::core::backend::avx512::m31::PackedBaseField;
     use crate::core::fft::{butterfly, ibutterfly};
     use crate::core::fields::m31::BaseField;
 
@@ -558,4 +625,67 @@ mod tests {
             }
         }
     }
+
+    #[test]
+    fn test_ifft3() {
+        unsafe {
+            let mut values: Vec<PackedBaseField> = (0..8)
+                .map(|i| {
+                    PackedBaseField::from_array(std::array::from_fn(|_| {
+                        BaseField::from_u32_unchecked(i)
+                    }))
+                })
+                .collect();
+            let twiddles0 = [32, 33, 34, 35];
+            let twiddles1 = [36, 37];
+            let twiddles2 = [38];
+            let twiddles0_dbl = std::array::from_fn(|i| twiddles0[i] * 2);
+            let twiddles1_dbl = std::array::from_fn(|i| twiddles1[i] * 2);
+            let twiddles2_dbl = std::array::from_fn(|i| twiddles2[i] * 2);
+            ifft3(
+                values.as_mut_ptr().cast::<i32>(),
+                0,
+                0,
+                &twiddles0_dbl,
+                &twiddles1_dbl,
+                &twiddles2_dbl,
+            );
+
+            let expected: [u32; 8] = std::array::from_fn(|i| i as u32);
+            let mut expected: [BaseField; 8] = std::mem::transmute(expected);
+            let twiddles0: [BaseField; 4] = std::mem::transmute(twiddles0);
+            let twiddles1: [BaseField; 2] = std::mem::transmute(twiddles1);
+            let twiddles2: [BaseField; 1] = std::mem::transmute(twiddles2);
+            for i in 0..8 {
+                let j = i ^ 1;
+                if i > j {
+                    continue;
+                }
+                let (mut v0, mut v1) = (expected[i], expected[j]);
+                ibutterfly(&mut v0, &mut v1, twiddles0[i / 2]);
+                (expected[i], expected[j]) = (v0, v1);
+            }
+            for i in 0..8 {
+                let j = i ^ 2;
+                if i > j {
+                    continue;
+                }
+                let (mut v0, mut v1) = (expected[i], expected[j]);
+                ibutterfly(&mut v0, &mut v1, twiddles1[i / 4]);
+                (expected[i], expected[j]) = (v0, v1);
+            }
+            for i in 0..8 {
+                let j = i ^ 4;
+                if i > j {
+                    continue;
+                }
+                let (mut v0, mut v1) = (expected[i], expected[j]);
+                ibutterfly(&mut v0, &mut v1, twiddles2[0]);
+                (expected[i], expected[j]) = (v0, v1);
+            }
+            for i in 0..8 {
+                assert_eq!(values[i].to_array()[0], expected[i]);
+            }
+        }
+    }
 }