From 7df206bd54781d19c133ce75b1927b78c34e5344 Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Wed, 21 Feb 2024 11:28:43 +0200 Subject: [PATCH] Save twiddles --- src/core/backend/avx512/fft.rs | 182 ++++++++++++++------------------- 1 file changed, 74 insertions(+), 108 deletions(-) diff --git a/src/core/backend/avx512/fft.rs b/src/core/backend/avx512/fft.rs index 635ec0219..e5a4a9432 100644 --- a/src/core/backend/avx512/fft.rs +++ b/src/core/backend/avx512/fft.rs @@ -1,7 +1,8 @@ use std::arch::x86_64::{ __m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_load_epi32, - _mm512_min_epu32, _mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_set1_epi32, - _mm512_set1_epi64, _mm512_srli_epi64, _mm512_store_epi32, _mm512_sub_epi32, + _mm512_min_epu32, _mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_permutexvar_epi32, + _mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64, _mm512_store_epi32, _mm512_sub_epi32, + _mm512_xor_epi32, }; /// An input to _mm512_permutex2var_epi32, and is used to interleave the even words of a @@ -68,10 +69,9 @@ pub unsafe fn ifft_lower( ) { assert!(n_fft_bits >= 1); if let Some(vecwise_twiddle_dbl) = vecwise_twiddle_dbl { - assert_eq!(vecwise_twiddle_dbl[0].len(), 1 << (n_fft_bits + 3)); - assert_eq!(vecwise_twiddle_dbl[1].len(), 1 << (n_fft_bits + 2)); - assert_eq!(vecwise_twiddle_dbl[2].len(), 1 << (n_fft_bits + 1)); - assert_eq!(vecwise_twiddle_dbl[3].len(), 1 << n_fft_bits); + assert_eq!(vecwise_twiddle_dbl[0].len(), 1 << (n_fft_bits + 2)); + assert_eq!(vecwise_twiddle_dbl[1].len(), 1 << (n_fft_bits + 1)); + assert_eq!(vecwise_twiddle_dbl[2].len(), 1 << n_fft_bits); } for h in 0..(1 << (n_total_bits - n_fft_bits)) { // TODO(spapini): @@ -84,10 +84,9 @@ pub unsafe fn ifft_lower( (val0, val1) = vecwise_ibutterflies( val0, val1, - std::array::from_fn(|i| *vecwise_twiddle_dbl[0].get_unchecked(index * 16 + i)), - std::array::from_fn(|i| *vecwise_twiddle_dbl[1].get_unchecked(index * 8 + i)), - 
std::array::from_fn(|i| *vecwise_twiddle_dbl[2].get_unchecked(index * 4 + i)), - std::array::from_fn(|i| *vecwise_twiddle_dbl[3].get_unchecked(index * 2 + i)), + std::array::from_fn(|i| *vecwise_twiddle_dbl[0].get_unchecked(index * 8 + i)), + std::array::from_fn(|i| *vecwise_twiddle_dbl[1].get_unchecked(index * 4 + i)), + std::array::from_fn(|i| *vecwise_twiddle_dbl[2].get_unchecked(index * 2 + i)), ); _mm512_store_epi32(values.add(index * 32), val0); _mm512_store_epi32(values.add(index * 32 + 16), val1); @@ -126,7 +125,7 @@ pub unsafe fn ifft_lower( /// val0, val1 are packed M31 elements. 16 M31 words at each. /// Each value is assumed to be in unreduced form, [0, P] including P. /// Returned values are in unreduced form, [0, P] including P. -/// twiddle_dbl holds 16 values, each is a *double* of a twiddle factor, in reduced form. +/// twiddle_dbl holds 16 values, each is a *double* of a twiddle factor, in unreduced form. /// # Safety /// This function is safe. pub unsafe fn avx_butterfly( @@ -176,7 +175,7 @@ pub unsafe fn avx_butterfly( /// val0 + val1, t (val0 - val1). /// val0, val1 are packed M31 elements. 16 M31 words at each. /// Each value is assumed to be in unreduced form, [0, P] including P. -/// twiddle_dbl holds 16 values, each is a *double* of a twiddle factor, in reduced form. +/// twiddle_dbl holds 16 values, each is a *double* of a twiddle factor, in unreduced form. /// # Safety /// This function is safe. pub unsafe fn avx_ibutterfly( @@ -280,20 +279,21 @@ pub unsafe fn vecwise_butterflies( /// This amounts to 4 butterfly layers, each with 16 butterflies. /// Each of the vectors represents a bit reversed evaluation. /// Each value in a vectors is in unreduced form: [0, P] including P. -/// Takes 4 twiddle arrays, one for each layer, holding the double of the corresponding twiddle. -/// The first layer (lower bit of the index) takes 16 twiddles. -/// The second layer takes 8 twiddles. -/// etc. 
+/// Takes 3 twiddle arrays, one for each layer after the first, holding the double of the +/// corresponding twiddle. +/// The first layer's twiddles (lower bit of the index) are computed from the second layer's +/// twiddles. The second layer takes 8 twiddles. +/// The third layer takes 4 twiddles. +/// The fourth layer takes 2 twiddles. /// # Safety +/// This function is safe. pub unsafe fn vecwise_ibutterflies( mut val0: __m512i, mut val1: __m512i, - twiddle0_dbl: [i32; 16], twiddle1_dbl: [i32; 8], twiddle2_dbl: [i32; 4], twiddle3_dbl: [i32; 2], ) -> (__m512i, __m512i) { - // TODO(spapini): Compute twiddle0 from twiddle1. // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly. // Each avx_ibutterfly take 2 512-bit registers, and does 16 butterflies element by element. @@ -313,9 +313,38 @@ pub unsafe fn vecwise_ibutterflies( // ifft on a // i:abcd - // The twiddles for layer 0 are packed like: - // 0 1 2 3 4 5 6 7 8 9 a b c d e f - let t: __m512i = std::mem::transmute(twiddle0_dbl); + // Start by loading the twiddles for the second layer (layer 1): + // The twiddles for layer 1 are packed like: + // 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 + let t1 = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl)); + + // The twiddles for layer 0 can be computed from the twiddles for layer 1: + // A circle coset of size 4 in bit reversed order looks like this: + // [(x, y), (-x, -y), (y, -x), (-y, x)] + // Note: This is related to the choice of M31_CIRCLE_GEN, and the fact that a quarter rotation + // is (0,-1) and not (0,1). A quarter rotation of (0,1) would give a different relation. + // The twiddles for layer 0 are the y coordinates: + // [y, -y, -x, x] + // The twiddles for layer 1 in bit reversed order are the x coordinates: + // [x, y] + // This also works for the inverses of the twiddles. 
+ + // The twiddles for layer 0 are computed like this: + // t0[4i:4i+3] = [t1[2i+1], -t1[2i+1], -t1[2i], t1[2i]] + const INDICES_FROM_T1: __m512i = unsafe { + core::mem::transmute([ + 0b0001, 0b0001, 0b0000, 0b0000, 0b0011, 0b0011, 0b0010, 0b0010, 0b0101, 0b0101, 0b0100, + 0b0100, 0b0111, 0b0111, 0b0110, 0b0110, + ]) + }; + // Xoring a double twiddle with 2^32-2 transforms it to the double of its negation. + // Note that this keeps the values as a double of a value in the range [0, P]. + const NEGATION_MASK: __m512i = unsafe { + core::mem::transmute([0i32, -2, -2, 0, 0, -2, -2, 0, 0, -2, -2, 0, 0, -2, -2, 0]) + }; + let t = _mm512_permutexvar_epi32(INDICES_FROM_T1, t1); + let t = _mm512_xor_epi32(t, NEGATION_MASK); + // Apply the permutation, resulting in indexing d:iabc. (val0, val1) = ( _mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1), @@ -325,8 +354,8 @@ pub unsafe fn vecwise_ibutterflies( // The twiddles for layer 1 are packed like: // 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 - let t = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl)); // Apply the permutation, resulting in indexing c:diab. 
+ let t = t1; (val0, val1) = ( _mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1), _mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1), @@ -592,90 +621,6 @@ mod tests { } } - #[test] - fn test_vecwise_ibutterflies() { - unsafe { - let val0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - let val1 = _mm512_setr_epi32( - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - ); - let twiddles0 = [ - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, - ]; - let twiddles1 = [48, 49, 50, 51, 52, 53, 54, 55]; - let twiddles2 = [56, 57, 58, 59]; - let twiddles3 = [60, 61]; - let twiddle0_dbl = std::array::from_fn(|i| twiddles0[i] * 2); - let twiddle1_dbl = std::array::from_fn(|i| twiddles1[i] * 2); - let twiddle2_dbl = std::array::from_fn(|i| twiddles2[i] * 2); - let twiddle3_dbl = std::array::from_fn(|i| twiddles3[i] * 2); - - let (r0, r1) = vecwise_ibutterflies( - val0, - val1, - twiddle0_dbl, - twiddle1_dbl, - twiddle2_dbl, - twiddle3_dbl, - ); - - let mut val0: [BaseField; 16] = std::mem::transmute(val0); - let mut val1: [BaseField; 16] = std::mem::transmute(val1); - let r0: [BaseField; 16] = std::mem::transmute(r0); - let r1: [BaseField; 16] = std::mem::transmute(r1); - let twiddles0: [BaseField; 16] = std::mem::transmute(twiddles0); - let twiddles1: [BaseField; 8] = std::mem::transmute(twiddles1); - let twiddles2: [BaseField; 4] = std::mem::transmute(twiddles2); - let twiddles3: [BaseField; 2] = std::mem::transmute(twiddles3); - - for i in 0..16 { - let j = i ^ 1; - if i > j { - continue; - } - let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); - ibutterfly(&mut v00, &mut v01, twiddles0[i / 2]); - ibutterfly(&mut v10, &mut v11, twiddles0[8 + i / 2]); - (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); - } - for i in 0..16 { - let j = i ^ 2; - if i > j { - continue; - } - let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); - 
ibutterfly(&mut v00, &mut v01, twiddles1[i / 4]); - ibutterfly(&mut v10, &mut v11, twiddles1[4 + i / 4]); - (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); - } - for i in 0..16 { - let j = i ^ 4; - if i > j { - continue; - } - let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); - ibutterfly(&mut v00, &mut v01, twiddles2[i / 8]); - ibutterfly(&mut v10, &mut v11, twiddles2[2 + i / 8]); - (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); - } - for i in 0..16 { - let j = i ^ 8; - if i > j { - continue; - } - let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); - ibutterfly(&mut v00, &mut v01, twiddles3[0]); - ibutterfly(&mut v10, &mut v11, twiddles3[1]); - (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); - } - // Compare - for i in 0..16 { - assert_eq!(val0[i], r0[i]); - assert_eq!(val1[i], r1[i]); - } - } - } - #[test] fn test_ifft3() { unsafe { @@ -760,6 +705,28 @@ mod tests { res } + #[test] + fn test_twiddle_relation() { + let ts = get_itwiddle_dbls(CanonicCoset::new(5).circle_domain()); + let t0 = ts[0] + .iter() + .copied() + .map(|x| BaseField::from_u32_unchecked((x as u32) / 2)) + .collect::<Vec<_>>(); + let t1 = ts[1] + .iter() + .copied() + .map(|x| BaseField::from_u32_unchecked((x as u32) / 2)) + .collect::<Vec<_>>(); + + for i in 0..t0.len() / 4 { + assert_eq!(t0[i * 4], t1[i * 2 + 1]); + assert_eq!(t0[i * 4 + 1], -t1[i * 2 + 1]); + assert_eq!(t0[i * 4 + 2], -t1[i * 2]); + assert_eq!(t0[i * 4 + 3], t1[i * 2]); + } + } + fn ref_ifft(domain: CircleDomain, mut values: Vec) -> Vec { bit_reverse(&mut values); let eval = CircleEvaluation::::new(domain, values); @@ -782,7 +749,6 @@ mod tests { let (val0, val1) = vecwise_ibutterflies( std::mem::transmute(values0), std::mem::transmute(values1), - twiddle_dbls[0].clone().try_into().unwrap(), twiddle_dbls[1].clone().try_into().unwrap(), twiddle_dbls[2].clone().try_into().unwrap(), twiddle_dbls[3].clone().try_into().unwrap(), @@ -818,7 
+784,7 @@ mod tests { unsafe { ifft_lower( std::mem::transmute(values.data.as_mut_ptr()), - Some(&twiddle_dbls[..4]), + Some(&twiddle_dbls[1..4]), &twiddle_dbls[4..], (log_size - 4) as usize, (log_size - 4) as usize,