From c6ac5cc4d52141aaebca8a5e1aadfb5a6c28088d Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Tue, 20 Feb 2024 13:09:48 +0200 Subject: [PATCH] avx vecwise ibutterfly --- src/core/backend/avx512/fft.rs | 246 +++++++++++++++++++++++++++++---- 1 file changed, 216 insertions(+), 30 deletions(-) diff --git a/src/core/backend/avx512/fft.rs b/src/core/backend/avx512/fft.rs index 114c75204..cf9f419b7 100644 --- a/src/core/backend/avx512/fft.rs +++ b/src/core/backend/avx512/fft.rs @@ -1,24 +1,42 @@ use std::arch::x86_64::{ - __m512i, _mm512_add_epi32, _mm512_min_epu32, _mm512_mul_epu32, _mm512_permutex2var_epi32, - _mm512_srli_epi64, _mm512_sub_epi32, + __m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_min_epu32, + _mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_set1_epi64, _mm512_srli_epi64, + _mm512_sub_epi32, }; -/// L is an input to _mm512_permutex2var_epi32, and is used to interleave the even words of a +/// An input to _mm512_permutex2var_epi32, and is used to interleave the even words of a /// with the even words of b. -const L: __m512i = unsafe { +const EVENS_INTERLEAVE_EVENS: __m512i = unsafe { core::mem::transmute([ 0b00000, 0b10000, 0b00010, 0b10010, 0b00100, 0b10100, 0b00110, 0b10110, 0b01000, 0b11000, 0b01010, 0b11010, 0b01100, 0b11100, 0b01110, 0b11110, ]) }; -/// H is an input to _mm512_permutex2var_epi32, and is used to interleave the odd words of a +/// An input to _mm512_permutex2var_epi32, and is used to interleave the odd words of a /// with the odd words of b. -const H: __m512i = unsafe { +const ODDS_INTERLEAVE_ODDS: __m512i = unsafe { core::mem::transmute([ 0b00001, 0b10001, 0b00011, 0b10011, 0b00101, 0b10101, 0b00111, 0b10111, 0b01001, 0b11001, 0b01011, 0b11011, 0b01101, 0b11101, 0b01111, 0b11111, ]) }; + +/// An input to _mm512_permutex2var_epi32, and is used to concat the even words of a +/// with the even words of b. +const EVENS_CONCAT_EVENS: __m512i = unsafe { + core::mem::transmute([ + 0b00000, 0b00010, 0b00100, 0b00110, 0b01000, 0b01010, 0b01100, 0b01110, 0b10000, 0b10010, + 0b10100, 0b10110, 0b11000, 0b11010, 0b11100, 0b11110, + ]) +}; +/// An input to _mm512_permutex2var_epi32, and is used to concat the odd words of a +/// with the odd words of b. +const ODDS_CONCAT_ODDS: __m512i = unsafe { + core::mem::transmute([ + 0b00001, 0b00011, 0b00101, 0b00111, 0b01001, 0b01011, 0b01101, 0b01111, 0b10001, 0b10011, + 0b10101, 0b10111, 0b11001, 0b11011, 0b11101, 0b11111, + ]) +}; const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) }; /// Computes the butterfly operation for packed M31 elements. @@ -53,7 +71,7 @@ pub unsafe fn avx_butterfly( // prod_o_dbl - |0|prod_o_h|prod_o_l|0| // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: - let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, L, prod_o_dbl); + let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, prod_o_dbl); // prod_ls - |prod_o_l|0|prod_e_l|0| // Divide by 2: @@ -61,7 +79,7 @@ pub unsafe fn avx_butterfly( // prod_ls - |0|prod_o_l|0|prod_e_l| // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: - let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, H, prod_o_dbl); + let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, ODDS_INTERLEAVE_ODDS, prod_o_dbl); // prod_hs - |0|prod_o_h|0|prod_e_h| let prod = add_mod_p(prod_ls, prod_hs); @@ -105,7 +123,7 @@ pub unsafe fn avx_ibutterfly( // prod_o_dbl - |0|prod_o_h|prod_o_l|0| // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: - let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, L, prod_o_dbl); + let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, prod_o_dbl); // prod_ls - |prod_o_l|0|prod_e_l|0| // Divide by 2: @@ -113,7 +131,7 @@ pub unsafe fn avx_ibutterfly( // prod_ls - |0|prod_o_l|0|prod_e_l| // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: - let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, H, prod_o_dbl); + let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, ODDS_INTERLEAVE_ODDS, prod_o_dbl); // prod_hs - |0|prod_o_h|0|prod_e_h| let prod = add_mod_p(prod_ls, prod_hs); @@ -121,6 +139,90 @@ pub unsafe fn avx_ibutterfly( (r0, prod) } +/// Runs ifft on 2 vectors of 16 M31 elements. +/// This amounts to 4 butterfly layers, each with 16 butterflies. +/// Each of the vectors represents a bit reversed evaluation. +/// Each value in a vectors is in unreduced form: [0, P] including P. +/// Takes 4 twiddle arrays, one for each layer, holding the double of the corresponding twiddle. +/// The first layer (lower bit of the index) takes 16 twiddles. +/// The second layer takes 8 twiddles. +/// etc. +/// # Safety +pub unsafe fn vecwise_ibutterflies( + mut val0: __m512i, + mut val1: __m512i, + twiddle0_dbl: [i32; 16], + twiddle1_dbl: [i32; 8], + twiddle2_dbl: [i32; 4], + twiddle3_dbl: [i32; 2], +) -> (__m512i, __m512i) { + // TODO(spapini): Compute twiddle0 from twiddle1. + // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly. + + // Each avx_ibutterfly take 2 512-bit registers, and does 16 butterflies element by element. + // We need to permute the 512-bit registers to get the right order for the butterflies. + // Denote the index of the 16 M31 elements in register i as i:abcd. + // At each layer we apply the following permutation to the index: + // i:abcd => d:iabc + // This is how it looks like at each iteration. + // i:abcd + // d:iabc + // ifft on d + // c:diab + // ifft on c + // b:cdia + // ifft on b + // a:bcid + // ifft on a + // i:abcd + + // The twiddles for layer 0 are unique and arranged as follows: + // 0 1 2 3 4 5 6 7 8 9 a b c d e f + let t: __m512i = std::mem::transmute(twiddle0_dbl); + // Apply the permutation, resulting in indexing d:iabc. + (val0, val1) = ( + _mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1), + _mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1), + ); + (val0, val1) = avx_ibutterfly(val0, val1, t); + + // The twiddles for layer 1 are replicated in the following pattern: + // 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 + let t = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl)); + // Apply the permutation, resulting in indexing c:diab. + (val0, val1) = ( + _mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1), + _mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1), + ); + (val0, val1) = avx_ibutterfly(val0, val1, t); + + // The twiddles for layer 2 are replicated in the following pattern: + // 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 + let t = _mm512_broadcast_i32x4(std::mem::transmute(twiddle2_dbl)); + // Apply the permutation, resulting in indexing b:cdia. + (val0, val1) = ( + _mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1), + _mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1), + ); + (val0, val1) = avx_ibutterfly(val0, val1, t); + + // The twiddles for layer 3 are replicated in the following pattern: + // 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 + let t = _mm512_set1_epi64(std::mem::transmute(twiddle3_dbl)); + // Apply the permutation, resulting in indexing a:bcid. + (val0, val1) = ( + _mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1), + _mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1), + ); + (val0, val1) = avx_ibutterfly(val0, val1, t); + + // Apply the permutation, resulting in indexing i:abcd. + ( + _mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1), + _mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1), + ) +} + // TODO(spapini): Move these to M31 AVX. /// Adds two packed M31 elements, and reduces the result to the range [0,P]. @@ -170,19 +272,19 @@ mod tests { let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle); let (r0, r1) = avx_butterfly(val0, val1, twiddle_dbl); - let val0: [u32; 16] = std::mem::transmute(val0); - let val1: [u32; 16] = std::mem::transmute(val1); - let twiddle: [u32; 16] = std::mem::transmute(twiddle); - let r0: [u32; 16] = std::mem::transmute(r0); - let r1: [u32; 16] = std::mem::transmute(r1); + let val0: [BaseField; 16] = std::mem::transmute(val0); + let val1: [BaseField; 16] = std::mem::transmute(val1); + let twiddle: [BaseField; 16] = std::mem::transmute(twiddle); + let r0: [BaseField; 16] = std::mem::transmute(r0); + let r1: [BaseField; 16] = std::mem::transmute(r1); for i in 0..16 { - let mut x = BaseField::from_u32_unchecked(val0[i]); - let mut y = BaseField::from_u32_unchecked(val1[i]); - let twiddle = BaseField::from_u32_unchecked(twiddle[i]); + let mut x = val0[i]; + let mut y = val1[i]; + let twiddle = twiddle[i]; butterfly(&mut x, &mut y, twiddle); - assert_eq!(x, BaseField::from_u32_unchecked(r0[i])); - assert_eq!(y, BaseField::from_u32_unchecked(r1[i])); + assert_eq!(x, r0[i]); + assert_eq!(y, r1[i]); } } } @@ -200,19 +302,103 @@ mod tests { let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle); let (r0, r1) = avx_ibutterfly(val0, val1, twiddle_dbl); - let val0: [u32; 16] = std::mem::transmute(val0); - let val1: [u32; 16] = std::mem::transmute(val1); - let twiddle: [u32; 16] = std::mem::transmute(twiddle); - let r0: [u32; 16] = std::mem::transmute(r0); - let r1: [u32; 16] = std::mem::transmute(r1); + let val0: [BaseField; 16] = std::mem::transmute(val0); + let val1: [BaseField; 16] = std::mem::transmute(val1); + let twiddle: [BaseField; 16] = std::mem::transmute(twiddle); + let r0: [BaseField; 16] = std::mem::transmute(r0); + let r1: [BaseField; 16] = std::mem::transmute(r1); for i in 0..16 { - let mut x = BaseField::from_u32_unchecked(val0[i]); - let mut y = BaseField::from_u32_unchecked(val1[i]); - let twiddle = BaseField::from_u32_unchecked(twiddle[i]); + let mut x = val0[i]; + let mut y = val1[i]; + let twiddle = twiddle[i]; ibutterfly(&mut x, &mut y, twiddle); - assert_eq!(x, BaseField::from_u32_unchecked(r0[i])); - assert_eq!(y, BaseField::from_u32_unchecked(r1[i])); + assert_eq!(x, r0[i]); + assert_eq!(y, r1[i]); + } + } + } + + #[test] + fn test_vecwise_ibutterflies() { + unsafe { + let val0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + let val1 = _mm512_setr_epi32( + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + ); + let twiddles0 = [ + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + ]; + let twiddles1 = [48, 49, 50, 51, 52, 53, 54, 55]; + let twiddles2 = [56, 57, 58, 59]; + let twiddles3 = [60, 61]; + let twiddle0_dbl = std::array::from_fn(|i| twiddles0[i] * 2); + let twiddle1_dbl = std::array::from_fn(|i| twiddles1[i] * 2); + let twiddle2_dbl = std::array::from_fn(|i| twiddles2[i] * 2); + let twiddle3_dbl = std::array::from_fn(|i| twiddles3[i] * 2); + + let (r0, r1) = vecwise_ibutterflies( + val0, + val1, + twiddle0_dbl, + twiddle1_dbl, + twiddle2_dbl, + twiddle3_dbl, + ); + + let mut val0: [BaseField; 16] = std::mem::transmute(val0); + let mut val1: [BaseField; 16] = std::mem::transmute(val1); + let r0: [BaseField; 16] = std::mem::transmute(r0); + let r1: [BaseField; 16] = std::mem::transmute(r1); + let twiddles0: [BaseField; 16] = std::mem::transmute(twiddles0); + let twiddles1: [BaseField; 8] = std::mem::transmute(twiddles1); + let twiddles2: [BaseField; 4] = std::mem::transmute(twiddles2); + let twiddles3: [BaseField; 2] = std::mem::transmute(twiddles3); + + for i in 0..16 { + let j = i ^ 1; + if i > j { + continue; + } + let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); + ibutterfly(&mut v00, &mut v01, twiddles0[i / 2]); + ibutterfly(&mut v10, &mut v11, twiddles0[8 + i / 2]); + (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); + } + for i in 0..16 { + let j = i ^ 2; + if i > j { + continue; + } + let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); + ibutterfly(&mut v00, &mut v01, twiddles1[i / 4]); + ibutterfly(&mut v10, &mut v11, twiddles1[4 + i / 4]); + (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); + } + for i in 0..16 { + let j = i ^ 4; + if i > j { + continue; + } + let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); + ibutterfly(&mut v00, &mut v01, twiddles2[i / 8]); + ibutterfly(&mut v10, &mut v11, twiddles2[2 + i / 8]); + (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); + } + for i in 0..16 { + let j = i ^ 8; + if i > j { + continue; + } + let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); + ibutterfly(&mut v00, &mut v01, twiddles3[0]); + ibutterfly(&mut v10, &mut v11, twiddles3[1]); + (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); + } + // Compare + for i in 0..16 { + assert_eq!(val0[i], r0[i]); + assert_eq!(val1[i], r1[i]); } } }