From 02997ad14aba656d473e1e670323a03c0c402337 Mon Sep 17 00:00:00 2001 From: Shahar Papini <43779613+spapinistarkware@users.noreply.github.com> Date: Mon, 4 Mar 2024 09:52:46 +0000 Subject: [PATCH] avx vecwise butterflies (#377) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This change is [Reviewable](https://reviewable.io/reviews/starkware-libs/stwo/377) --- src/core/backend/avx512/fft.rs | 156 +++++++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) diff --git a/src/core/backend/avx512/fft.rs b/src/core/backend/avx512/fft.rs index cf9f419b7..b356b9736 100644 --- a/src/core/backend/avx512/fft.rs +++ b/src/core/backend/avx512/fft.rs @@ -37,6 +37,22 @@ const ODDS_CONCAT_ODDS: __m512i = unsafe { 0b10101, 0b10111, 0b11001, 0b11011, 0b11101, 0b11111, ]) }; +/// An input to _mm512_permutex2var_epi32, and is used to interleave the low half of a +/// with the low half of b. +const LHALF_INTERLEAVE_LHALF: __m512i = unsafe { + core::mem::transmute([ + 0b00000, 0b10000, 0b00001, 0b10001, 0b00010, 0b10010, 0b00011, 0b10011, 0b00100, 0b10100, + 0b00101, 0b10101, 0b00110, 0b10110, 0b00111, 0b10111, + ]) +}; +/// An input to _mm512_permutex2var_epi32, and is used to interleave the high half of a +/// with the high half of b. +const HHALF_INTERLEAVE_HHALF: __m512i = unsafe { + core::mem::transmute([ + 0b01000, 0b11000, 0b01001, 0b11001, 0b01010, 0b11010, 0b01011, 0b11011, 0b01100, 0b11100, + 0b01101, 0b11101, 0b01110, 0b11110, 0b01111, 0b11111, + ]) +}; const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) }; /// Computes the butterfly operation for packed M31 elements. @@ -139,6 +155,61 @@ pub unsafe fn avx_ibutterfly( (r0, prod) } +/// Runs fft on 2 vectors of 16 M31 elements. +/// This amounts to 4 butterfly layers, each with 16 butterflies. +/// Each of the vectors represents natural ordered polynomial coefficeint. +/// Each value in a vectors is in unreduced form: [0, P] including P. +/// Takes 4 twiddle arrays, one for each layer, holding the double of the corresponding twiddle. +/// The first layer (higher bit of the index) takes 2 twiddles. +/// The second layer takes 4 twiddles. +/// etc. +/// # Safety +pub unsafe fn vecwise_butterflies( + mut val0: __m512i, + mut val1: __m512i, + twiddle0_dbl: [i32; 16], + twiddle1_dbl: [i32; 8], + twiddle2_dbl: [i32; 4], + twiddle3_dbl: [i32; 2], +) -> (__m512i, __m512i) { + // TODO(spapini): Compute twiddle0 from twiddle1. + // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly. + // The implementation is the exact reverse of vecwise_ibutterflies(). + // See the comments in its body for more info. + let t = _mm512_set1_epi64(std::mem::transmute(twiddle3_dbl)); + (val0, val1) = ( + _mm512_permutex2var_epi32(val0, LHALF_INTERLEAVE_LHALF, val1), + _mm512_permutex2var_epi32(val0, HHALF_INTERLEAVE_HHALF, val1), + ); + (val0, val1) = avx_butterfly(val0, val1, t); + + let t = _mm512_broadcast_i32x4(std::mem::transmute(twiddle2_dbl)); + (val0, val1) = ( + _mm512_permutex2var_epi32(val0, LHALF_INTERLEAVE_LHALF, val1), + _mm512_permutex2var_epi32(val0, HHALF_INTERLEAVE_HHALF, val1), + ); + (val0, val1) = avx_butterfly(val0, val1, t); + + let t = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl)); + (val0, val1) = ( + _mm512_permutex2var_epi32(val0, LHALF_INTERLEAVE_LHALF, val1), + _mm512_permutex2var_epi32(val0, HHALF_INTERLEAVE_HHALF, val1), + ); + (val0, val1) = avx_butterfly(val0, val1, t); + + let t: __m512i = std::mem::transmute(twiddle0_dbl); + (val0, val1) = ( + _mm512_permutex2var_epi32(val0, LHALF_INTERLEAVE_LHALF, val1), + _mm512_permutex2var_epi32(val0, HHALF_INTERLEAVE_HHALF, val1), + ); + (val0, val1) = avx_butterfly(val0, val1, t); + + ( + _mm512_permutex2var_epi32(val0, LHALF_INTERLEAVE_LHALF, val1), + _mm512_permutex2var_epi32(val0, HHALF_INTERLEAVE_HHALF, val1), + ) +} + /// Runs ifft on 2 vectors of 16 M31 elements. /// This amounts to 4 butterfly layers, each with 16 butterflies. /// Each of the vectors represents a bit reversed evaluation. @@ -319,6 +390,91 @@ mod tests { } } + #[test] + fn test_vecwise_butterflies() { + unsafe { + let val0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + let val1 = _mm512_setr_epi32( + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + ); + let twiddles0 = [ + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + ]; + let twiddles1 = [48, 49, 50, 51, 52, 53, 54, 55]; + let twiddles2 = [56, 57, 58, 59]; + let twiddles3 = [60, 61]; + let twiddle0_dbl = std::array::from_fn(|i| twiddles0[i] * 2); + let twiddle1_dbl = std::array::from_fn(|i| twiddles1[i] * 2); + let twiddle2_dbl = std::array::from_fn(|i| twiddles2[i] * 2); + let twiddle3_dbl = std::array::from_fn(|i| twiddles3[i] * 2); + + let (r0, r1) = vecwise_butterflies( + val0, + val1, + twiddle0_dbl, + twiddle1_dbl, + twiddle2_dbl, + twiddle3_dbl, + ); + + let mut val0: [BaseField; 16] = std::mem::transmute(val0); + let mut val1: [BaseField; 16] = std::mem::transmute(val1); + let r0: [BaseField; 16] = std::mem::transmute(r0); + let r1: [BaseField; 16] = std::mem::transmute(r1); + let twiddles0: [BaseField; 16] = std::mem::transmute(twiddles0); + let twiddles1: [BaseField; 8] = std::mem::transmute(twiddles1); + let twiddles2: [BaseField; 4] = std::mem::transmute(twiddles2); + let twiddles3: [BaseField; 2] = std::mem::transmute(twiddles3); + + for i in 0..16 { + let j = i ^ 8; + if i > j { + continue; + } + let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); + butterfly(&mut v00, &mut v01, twiddles3[0]); + butterfly(&mut v10, &mut v11, twiddles3[1]); + (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); + } + for i in 0..16 { + let j = i ^ 4; + if i > j { + continue; + } + let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); + butterfly(&mut v00, &mut v01, twiddles2[i / 8]); + butterfly(&mut v10, &mut v11, twiddles2[2 + i / 8]); + (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); + } + for i in 0..16 { + let j = i ^ 2; + if i > j { + continue; + } + let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); + butterfly(&mut v00, &mut v01, twiddles1[i / 4]); + butterfly(&mut v10, &mut v11, twiddles1[4 + i / 4]); + (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); + } + for i in 0..16 { + let j = i ^ 1; + if i > j { + continue; + } + let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); + butterfly(&mut v00, &mut v01, twiddles0[i / 2]); + butterfly(&mut v10, &mut v11, twiddles0[8 + i / 2]); + (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); + } + + // Compare + for i in 0..16 { + assert_eq!(val0[i], r0[i]); + assert_eq!(val1[i], r1[i]); + } + } + } + #[test] fn test_vecwise_ibutterflies() { unsafe {