From 4997f653095ce535d1da46b2ebcf569cc9f6f544 Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Tue, 20 Feb 2024 13:40:07 +0200 Subject: [PATCH] avx vecwise butterflies --- src/core/backend/avx512/fft.rs | 142 +++++++++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) diff --git a/src/core/backend/avx512/fft.rs b/src/core/backend/avx512/fft.rs index 02e610d9e..13c5f32cf 100644 --- a/src/core/backend/avx512/fft.rs +++ b/src/core/backend/avx512/fft.rs @@ -29,6 +29,19 @@ const H1: __m512i = unsafe { 0b10101, 0b10111, 0b11001, 0b11011, 0b11101, 0b11111, ]) }; + +const L2: __m512i = unsafe { + core::mem::transmute([ + 0b00000, 0b10000, 0b00001, 0b10001, 0b00010, 0b10010, 0b00011, 0b10011, 0b00100, 0b10100, + 0b00101, 0b10101, 0b00110, 0b10110, 0b00111, 0b10111, + ]) +}; +const H2: __m512i = unsafe { + core::mem::transmute([ + 0b01000, 0b11000, 0b01001, 0b11001, 0b01010, 0b11010, 0b01011, 0b11011, 0b01100, 0b11100, + 0b01101, 0b11101, 0b01110, 0b11110, 0b01111, 0b11111, + ]) +}; const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) }; /// # Safety @@ -94,6 +107,50 @@ pub unsafe fn avx_ibutterfly( (r0, rrm) } +/// # Safety +pub unsafe fn vecwise_butterflies( + mut val0: __m512i, + mut val1: __m512i, + twiddle0_dbl: [i32; 16], + twiddle1_dbl: [i32; 8], + twiddle2_dbl: [i32; 4], + twiddle3_dbl: [i32; 2], +) -> (__m512i, __m512i) { + // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly. + let t = _mm512_set1_epi64(std::mem::transmute(twiddle3_dbl)); + (val0, val1) = ( + _mm512_permutex2var_epi32(val0, L2, val1), + _mm512_permutex2var_epi32(val0, H2, val1), + ); + (val0, val1) = avx_butterfly(val0, val1, t); + + let t = _mm512_broadcast_i32x4(std::mem::transmute(twiddle2_dbl)); + (val0, val1) = ( + _mm512_permutex2var_epi32(val0, L2, val1), + _mm512_permutex2var_epi32(val0, H2, val1), + ); + (val0, val1) = avx_butterfly(val0, val1, t); + + let t = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl)); + (val0, val1) = ( + _mm512_permutex2var_epi32(val0, L2, val1), + _mm512_permutex2var_epi32(val0, H2, val1), + ); + (val0, val1) = avx_butterfly(val0, val1, t); + + let t: __m512i = std::mem::transmute(twiddle0_dbl); + (val0, val1) = ( + _mm512_permutex2var_epi32(val0, L2, val1), + _mm512_permutex2var_epi32(val0, H2, val1), + ); + (val0, val1) = avx_butterfly(val0, val1, t); + + ( + _mm512_permutex2var_epi32(val0, L2, val1), + _mm512_permutex2var_epi32(val0, H2, val1), + ) +} + /// # Safety pub unsafe fn vecwise_ibutterflies( mut val0: __m512i, @@ -206,6 +263,91 @@ mod tests { } } + #[test] + fn test_vecwise_butterflies() { + unsafe { + let val0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + let val1 = _mm512_setr_epi32( + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + ); + let twiddles0 = [ + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + ]; + let twiddles1 = [48, 49, 50, 51, 52, 53, 54, 55]; + let twiddles2 = [56, 57, 58, 59]; + let twiddles3 = [60, 61]; + let twiddle0_dbl = std::array::from_fn(|i| twiddles0[i] * 2); + let twiddle1_dbl = std::array::from_fn(|i| twiddles1[i] * 2); + let twiddle2_dbl = std::array::from_fn(|i| twiddles2[i] * 2); + let twiddle3_dbl = std::array::from_fn(|i| twiddles3[i] * 2); + + let (r0, r1) = vecwise_butterflies( + val0, + val1, + twiddle0_dbl, + twiddle1_dbl, + twiddle2_dbl, + twiddle3_dbl, + ); + + let mut val0: [BaseField; 16] = std::mem::transmute(val0); + let mut val1: [BaseField; 16] = std::mem::transmute(val1); + let r0: [BaseField; 16] = std::mem::transmute(r0); + let r1: [BaseField; 16] = std::mem::transmute(r1); + let twiddles0: [BaseField; 16] = std::mem::transmute(twiddles0); + let twiddles1: [BaseField; 8] = std::mem::transmute(twiddles1); + let twiddles2: [BaseField; 4] = std::mem::transmute(twiddles2); + let twiddles3: [BaseField; 2] = std::mem::transmute(twiddles3); + + for i in 0..16 { + let j = i ^ 8; + if i > j { + continue; + } + let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); + butterfly(&mut v00, &mut v01, twiddles3[0]); + butterfly(&mut v10, &mut v11, twiddles3[1]); + (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); + } + for i in 0..16 { + let j = i ^ 4; + if i > j { + continue; + } + let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); + butterfly(&mut v00, &mut v01, twiddles2[i / 8]); + butterfly(&mut v10, &mut v11, twiddles2[2 + i / 8]); + (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); + } + for i in 0..16 { + let j = i ^ 2; + if i > j { + continue; + } + let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); + butterfly(&mut v00, &mut v01, twiddles1[i / 4]); + butterfly(&mut v10, &mut v11, twiddles1[4 + i / 4]); + (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); + } + for i in 0..16 { + let j = i ^ 1; + if i > j { + continue; + } + let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); + butterfly(&mut v00, &mut v01, twiddles0[i / 2]); + butterfly(&mut v10, &mut v11, twiddles0[8 + i / 2]); + (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); + } + + // Compare + for i in 0..16 { + assert_eq!(val0[i], r0[i]); + assert_eq!(val1[i], r1[i]); + } + } + } + #[test] fn test_vecwise_ibutterflies() { unsafe {