diff --git a/src/core/backend/avx512/fft.rs b/src/core/backend/avx512/fft.rs index 17eed4e6e..02e610d9e 100644 --- a/src/core/backend/avx512/fft.rs +++ b/src/core/backend/avx512/fft.rs @@ -1,6 +1,7 @@ use std::arch::x86_64::{ - __m512i, _mm512_add_epi32, _mm512_min_epu32, _mm512_mul_epi32, _mm512_permutex2var_epi32, - _mm512_srli_epi64, _mm512_sub_epi32, + __m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_min_epu32, + _mm512_mul_epi32, _mm512_permutex2var_epi32, _mm512_set1_epi64, _mm512_srli_epi64, + _mm512_sub_epi32, }; const L: __m512i = unsafe { @@ -15,6 +16,19 @@ const H: __m512i = unsafe { 0b01011, 0b11011, 0b01101, 0b11101, 0b01111, 0b11111, ]) }; + +const L1: __m512i = unsafe { + core::mem::transmute([ + 0b00000, 0b00010, 0b00100, 0b00110, 0b01000, 0b01010, 0b01100, 0b01110, 0b10000, 0b10010, + 0b10100, 0b10110, 0b11000, 0b11010, 0b11100, 0b11110, + ]) +}; +const H1: __m512i = unsafe { + core::mem::transmute([ + 0b00001, 0b00011, 0b00101, 0b00111, 0b01001, 0b01011, 0b01101, 0b01111, 0b10001, 0b10011, + 0b10101, 0b10111, 0b11001, 0b11011, 0b11101, 0b11111, + ]) +}; const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) }; /// # Safety @@ -80,6 +94,50 @@ pub unsafe fn avx_ibutterfly( (r0, rrm) } +/// # Safety +pub unsafe fn vecwise_ibutterflies( + mut val0: __m512i, + mut val1: __m512i, + twiddle0_dbl: [i32; 16], + twiddle1_dbl: [i32; 8], + twiddle2_dbl: [i32; 4], + twiddle3_dbl: [i32; 2], +) -> (__m512i, __m512i) { + // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly. + let t: __m512i = std::mem::transmute(twiddle0_dbl); + (val0, val1) = ( + _mm512_permutex2var_epi32(val0, L1, val1), + _mm512_permutex2var_epi32(val0, H1, val1), + ); + (val0, val1) = avx_ibutterfly(val0, val1, t); + + let t = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl)); + (val0, val1) = ( + _mm512_permutex2var_epi32(val0, L1, val1), + _mm512_permutex2var_epi32(val0, H1, val1), + ); + (val0, val1) = avx_ibutterfly(val0, val1, t); + + let t = _mm512_broadcast_i32x4(std::mem::transmute(twiddle2_dbl)); + (val0, val1) = ( + _mm512_permutex2var_epi32(val0, L1, val1), + _mm512_permutex2var_epi32(val0, H1, val1), + ); + (val0, val1) = avx_ibutterfly(val0, val1, t); + + let t = _mm512_set1_epi64(std::mem::transmute(twiddle3_dbl)); + (val0, val1) = ( + _mm512_permutex2var_epi32(val0, L1, val1), + _mm512_permutex2var_epi32(val0, H1, val1), + ); + (val0, val1) = avx_ibutterfly(val0, val1, t); + + ( + _mm512_permutex2var_epi32(val0, L1, val1), + _mm512_permutex2var_epi32(val0, H1, val1), + ) +} + #[cfg(test)] mod tests { use std::arch::x86_64::_mm512_setr_epi32; @@ -101,19 +159,19 @@ mod tests { let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle); let (r0, r1) = avx_butterfly(val0, val1, twiddle_dbl); - let val0: [u32; 16] = std::mem::transmute(val0); - let val1: [u32; 16] = std::mem::transmute(val1); - let twiddle: [u32; 16] = std::mem::transmute(twiddle); - let r0: [u32; 16] = std::mem::transmute(r0); - let r1: [u32; 16] = std::mem::transmute(r1); + let val0: [BaseField; 16] = std::mem::transmute(val0); + let val1: [BaseField; 16] = std::mem::transmute(val1); + let twiddle: [BaseField; 16] = std::mem::transmute(twiddle); + let r0: [BaseField; 16] = std::mem::transmute(r0); + let r1: [BaseField; 16] = std::mem::transmute(r1); for i in 0..16 { - let mut x = BaseField::from_u32_unchecked(val0[i]); - let mut y = BaseField::from_u32_unchecked(val1[i]); - let twiddle = BaseField::from_u32_unchecked(twiddle[i]); + let mut x = val0[i]; + let mut y = val1[i]; + let twiddle = twiddle[i]; butterfly(&mut x, &mut y, twiddle); - assert_eq!(x, BaseField::from_u32_unchecked(r0[i])); - assert_eq!(y, BaseField::from_u32_unchecked(r1[i])); + assert_eq!(x, r0[i]); + assert_eq!(y, r1[i]); } } } @@ -131,19 +189,103 @@ mod tests { let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle); let (r0, r1) = avx_ibutterfly(val0, val1, twiddle_dbl); - let val0: [u32; 16] = std::mem::transmute(val0); - let val1: [u32; 16] = std::mem::transmute(val1); - let twiddle: [u32; 16] = std::mem::transmute(twiddle); - let r0: [u32; 16] = std::mem::transmute(r0); - let r1: [u32; 16] = std::mem::transmute(r1); + let val0: [BaseField; 16] = std::mem::transmute(val0); + let val1: [BaseField; 16] = std::mem::transmute(val1); + let twiddle: [BaseField; 16] = std::mem::transmute(twiddle); + let r0: [BaseField; 16] = std::mem::transmute(r0); + let r1: [BaseField; 16] = std::mem::transmute(r1); for i in 0..16 { - let mut x = BaseField::from_u32_unchecked(val0[i]); - let mut y = BaseField::from_u32_unchecked(val1[i]); - let twiddle = BaseField::from_u32_unchecked(twiddle[i]); + let mut x = val0[i]; + let mut y = val1[i]; + let twiddle = twiddle[i]; ibutterfly(&mut x, &mut y, twiddle); - assert_eq!(x, BaseField::from_u32_unchecked(r0[i])); - assert_eq!(y, BaseField::from_u32_unchecked(r1[i])); + assert_eq!(x, r0[i]); + assert_eq!(y, r1[i]); + } + } + } + + #[test] + fn test_vecwise_ibutterflies() { + unsafe { + let val0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + let val1 = _mm512_setr_epi32( + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + ); + let twiddles0 = [ + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + ]; + let twiddles1 = [48, 49, 50, 51, 52, 53, 54, 55]; + let twiddles2 = [56, 57, 58, 59]; + let twiddles3 = [60, 61]; + let twiddle0_dbl = std::array::from_fn(|i| twiddles0[i] * 2); + let twiddle1_dbl = std::array::from_fn(|i| twiddles1[i] * 2); + let twiddle2_dbl = std::array::from_fn(|i| twiddles2[i] * 2); + let twiddle3_dbl = std::array::from_fn(|i| twiddles3[i] * 2); + + let (r0, r1) = vecwise_ibutterflies( + val0, + val1, + twiddle0_dbl, + twiddle1_dbl, + twiddle2_dbl, + twiddle3_dbl, + ); + + let mut val0: [BaseField; 16] = std::mem::transmute(val0); + let mut val1: [BaseField; 16] = std::mem::transmute(val1); + let r0: [BaseField; 16] = std::mem::transmute(r0); + let r1: [BaseField; 16] = std::mem::transmute(r1); + let twiddles0: [BaseField; 16] = std::mem::transmute(twiddles0); + let twiddles1: [BaseField; 8] = std::mem::transmute(twiddles1); + let twiddles2: [BaseField; 4] = std::mem::transmute(twiddles2); + let twiddles3: [BaseField; 2] = std::mem::transmute(twiddles3); + + for i in 0..16 { + let j = i ^ 1; + if i > j { + continue; + } + let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); + ibutterfly(&mut v00, &mut v01, twiddles0[i / 2]); + ibutterfly(&mut v10, &mut v11, twiddles0[8 + i / 2]); + (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); + } + for i in 0..16 { + let j = i ^ 2; + if i > j { + continue; + } + let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); + ibutterfly(&mut v00, &mut v01, twiddles1[i / 4]); + ibutterfly(&mut v10, &mut v11, twiddles1[4 + i / 4]); + (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); + } + for i in 0..16 { + let j = i ^ 4; + if i > j { + continue; + } + let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); + ibutterfly(&mut v00, &mut v01, twiddles2[i / 8]); + ibutterfly(&mut v10, &mut v11, twiddles2[2 + i / 8]); + (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); + } + for i in 0..16 { + let j = i ^ 8; + if i > j { + continue; + } + let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); + ibutterfly(&mut v00, &mut v01, twiddles3[0]); + ibutterfly(&mut v10, &mut v11, twiddles3[1]); + (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); + } + // Compare + for i in 0..16 { + assert_eq!(val0[i], r0[i]); + assert_eq!(val1[i], r1[i]); } } }