From 20a6d22936b7082bd2015a41323c7443f0ad4496 Mon Sep 17 00:00:00 2001 From: Shahar Papini <43779613+spapinistarkware@users.noreply.github.com> Date: Mon, 4 Mar 2024 11:35:23 +0000 Subject: [PATCH] Save twiddles (#380) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This change is [Reviewable](https://reviewable.io/reviews/starkware-libs/stwo/380) --- src/core/backend/avx512/fft.rs | 335 ++++++++++++++------------------- 1 file changed, 141 insertions(+), 194 deletions(-) diff --git a/src/core/backend/avx512/fft.rs b/src/core/backend/avx512/fft.rs index b995bccb8..16d32edc8 100644 --- a/src/core/backend/avx512/fft.rs +++ b/src/core/backend/avx512/fft.rs @@ -1,7 +1,8 @@ use std::arch::x86_64::{ __m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_load_epi32, - _mm512_min_epu32, _mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_set1_epi32, - _mm512_set1_epi64, _mm512_srli_epi64, _mm512_store_epi32, _mm512_sub_epi32, + _mm512_min_epu32, _mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_permutexvar_epi32, + _mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64, _mm512_store_epi32, _mm512_sub_epi32, + _mm512_xor_epi32, }; use crate::core::backend::avx512::VECS_LOG_SIZE; @@ -80,7 +81,7 @@ pub unsafe fn ifft_lower_with_vecwise( const VECWISE_FFT_BITS: usize = VECS_LOG_SIZE + 1; assert!(log_size >= VECWISE_FFT_BITS); - assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 1)); + assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); for index_h in 0..(1 << (log_size - fft_layers)) { ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h); @@ -88,7 +89,7 @@ pub unsafe fn ifft_lower_with_vecwise( while fft_layers - layer >= 3 { ifft3_loop( values, - &twiddle_dbl[layer..], + &twiddle_dbl[(layer - 1)..], fft_layers - layer - 3, layer, index_h, @@ -121,15 +122,14 @@ unsafe fn ifft_vecwise_loop( (val0, val1) = vecwise_ibutterflies( val0, val1, - std::array::from_fn(|i| 
*twiddle_dbl[0].get_unchecked(index * 16 + i)), - std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 8 + i)), - std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 4 + i)), - std::array::from_fn(|i| *twiddle_dbl[3].get_unchecked(index * 2 + i)), + std::array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 8 + i)), + std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)), + std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)), ); (val0, val1) = avx_ibutterfly( val0, val1, - _mm512_set1_epi32(*twiddle_dbl[4].get_unchecked(index)), + _mm512_set1_epi32(*twiddle_dbl[3].get_unchecked(index)), ); _mm512_store_epi32(values.add(index * 32), val0); _mm512_store_epi32(values.add(index * 32 + 16), val1); @@ -179,7 +179,7 @@ unsafe fn ifft3_loop( /// val0, val1 are packed M31 elements. 16 M31 words at each. /// Each value is assumed to be in unreduced form, [0, P] including P. /// Returned values are in unreduced form, [0, P] including P. -/// twiddle_dbl holds 16 values, each is a *double* of a twiddle factor, in reduced form. +/// twiddle_dbl holds 16 values, each is a *double* of a twiddle factor, in unreduced form. /// # Safety /// This function is safe. pub unsafe fn avx_butterfly( @@ -229,7 +229,7 @@ pub unsafe fn avx_butterfly( /// val0 + val1, t (val0 - val1). /// val0, val1 are packed M31 elements. 16 M31 words at each. /// Each value is assumed to be in unreduced form, [0, P] including P. -/// twiddle_dbl holds 16 values, each is a *double* of a twiddle factor, in reduced form. +/// twiddle_dbl holds 16 values, each is a *double* of a twiddle factor, in unreduced form. /// # Safety /// This function is safe. 
pub unsafe fn avx_ibutterfly( @@ -286,7 +286,6 @@ pub unsafe fn avx_ibutterfly( pub unsafe fn vecwise_butterflies( mut val0: __m512i, mut val1: __m512i, - twiddle0_dbl: [i32; 16], twiddle1_dbl: [i32; 8], twiddle2_dbl: [i32; 4], twiddle3_dbl: [i32; 2], @@ -309,19 +308,18 @@ pub unsafe fn vecwise_butterflies( ); (val0, val1) = avx_butterfly(val0, val1, t); - let t = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl)); + let (t0, t1) = compute_first_twiddles(twiddle1_dbl); (val0, val1) = ( _mm512_permutex2var_epi32(val0, LHALF_INTERLEAVE_LHALF, val1), _mm512_permutex2var_epi32(val0, HHALF_INTERLEAVE_HHALF, val1), ); - (val0, val1) = avx_butterfly(val0, val1, t); + (val0, val1) = avx_butterfly(val0, val1, t1); - let t: __m512i = std::mem::transmute(twiddle0_dbl); (val0, val1) = ( _mm512_permutex2var_epi32(val0, LHALF_INTERLEAVE_LHALF, val1), _mm512_permutex2var_epi32(val0, HHALF_INTERLEAVE_HHALF, val1), ); - (val0, val1) = avx_butterfly(val0, val1, t); + (val0, val1) = avx_butterfly(val0, val1, t0); ( _mm512_permutex2var_epi32(val0, LHALF_INTERLEAVE_LHALF, val1), @@ -333,20 +331,21 @@ pub unsafe fn vecwise_butterflies( /// This amounts to 4 butterfly layers, each with 16 butterflies. /// Each of the vectors represents a bit reversed evaluation. /// Each value in a vectors is in unreduced form: [0, P] including P. -/// Takes 4 twiddle arrays, one for each layer, holding the double of the corresponding twiddle. -/// The first layer (lower bit of the index) takes 16 twiddles. -/// The second layer takes 8 twiddles. -/// etc. +/// Takes 3 twiddle arrays, one for each layer after the first, holding the double of the +/// corresponding twiddle. +/// The first layer's twiddles (lower bit of the index) are computed from the second layer's +/// twiddles. The second layer takes 8 twiddles. +/// The third layer takes 4 twiddles. +/// The fourth layer takes 2 twiddles. /// # Safety +/// This function is safe. 
pub unsafe fn vecwise_ibutterflies( mut val0: __m512i, mut val1: __m512i, - twiddle0_dbl: [i32; 16], twiddle1_dbl: [i32; 8], twiddle2_dbl: [i32; 4], twiddle3_dbl: [i32; 2], ) -> (__m512i, __m512i) { - // TODO(spapini): Compute twiddle0 from twiddle1. // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly. // Each avx_ibutterfly take 2 512-bit registers, and does 16 butterflies element by element. @@ -366,25 +365,21 @@ pub unsafe fn vecwise_ibutterflies( // ifft on a // i:abcd - // The twiddles for layer 0 are unique and arranged as follows: - // 0 1 2 3 4 5 6 7 8 9 a b c d e f - let t: __m512i = std::mem::transmute(twiddle0_dbl); + let (t0, t1) = compute_first_twiddles(twiddle1_dbl); + // Apply the permutation, resulting in indexing d:iabc. (val0, val1) = ( _mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1), _mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1), ); - (val0, val1) = avx_ibutterfly(val0, val1, t); + (val0, val1) = avx_ibutterfly(val0, val1, t0); - // The twiddles for layer 1 are replicated in the following pattern: - // 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 - let t = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl)); // Apply the permutation, resulting in indexing c:diab. (val0, val1) = ( _mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1), _mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1), ); - (val0, val1) = avx_ibutterfly(val0, val1, t); + (val0, val1) = avx_ibutterfly(val0, val1, t1); // The twiddles for layer 2 are replicated in the following pattern: // 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 @@ -413,6 +408,43 @@ pub unsafe fn vecwise_ibutterflies( ) } +/// Computes the twiddles for the first fft layer from the second, and loads both to AVX registers. +/// Returns the twiddles for the first layer and the twiddles for the second layer. 
+/// # Safety +unsafe fn compute_first_twiddles(twiddle1_dbl: [i32; 8]) -> (__m512i, __m512i) { + // Start by loading the twiddles for the second layer (layer 1): + // The twiddles for layer 1 are replicated in the following pattern: + // 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 + let t1 = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl)); + + // The twiddles for layer 0 can be computed from the twiddles for layer 1: + // A circle coset of size 4 in bit reversed order looks like this: + // [(x, y), (-x, -y), (y, -x), (-y, x)] + // Note: This is related to the choice of M31_CIRCLE_GEN, and the fact that a quarter rotation + // is (0,-1) and not (0,1). This would cause another relation. + // The twiddles for layer 0 are the y coordinates: + // [y, -y, -x, x] + // The twiddles for layer 1 in bit reversed order are the x coordinates: + // [x, y] + // Works also for inverse of the twiddles. + + // The twiddles for layer 0 are computed like this: + // t0[4i:4i+3] = [t1[2i+1], -t1[2i+1], -t1[2i], t1[2i]] + const INDICES_FROM_T1: __m512i = unsafe { + core::mem::transmute([ + 0b0001, 0b0001, 0b0000, 0b0000, 0b0011, 0b0011, 0b0010, 0b0010, 0b0101, 0b0101, 0b0100, + 0b0100, 0b0111, 0b0111, 0b0110, 0b0110, + ]) + }; + // Xoring a double twiddle with 2^32-2 transforms it to the double of its negation. + // Note that this keeps the values as a double of a value in the range [0, P]. + const NEGATION_MASK: __m512i = unsafe { + core::mem::transmute([0i32, -2, -2, 0, 0, -2, -2, 0, 0, -2, -2, 0, 0, -2, -2, 0]) + }; + let t0 = _mm512_xor_epi32(_mm512_permutexvar_epi32(INDICES_FROM_T1, t1), NEGATION_MASK); + (t0, t1) +} + /// Applies 3 butterfly layers on 8 vectors of 16 M31 elements. /// Vectorized over the 16 elements of the vectors. /// Used for radix-8 ifft.
@@ -508,11 +540,11 @@ mod tests { use super::*; use crate::core::backend::avx512::m31::PackedBaseField; use crate::core::backend::avx512::BaseFieldVec; - use crate::core::backend::CPUBackend; + use crate::core::backend::cpu::{CPUCircleEvaluation, CPUCirclePoly}; use crate::core::fft::{butterfly, ibutterfly}; use crate::core::fields::m31::BaseField; use crate::core::fields::{Column, Field}; - use crate::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation}; + use crate::core::poly::circle::{CanonicCoset, CircleDomain}; use crate::core::utils::bit_reverse; #[test] @@ -576,171 +608,36 @@ mod tests { } #[test] - fn test_vecwise_butterflies() { - unsafe { - let val0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - let val1 = _mm512_setr_epi32( - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - ); - let twiddles0 = [ - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, - ]; - let twiddles1 = [48, 49, 50, 51, 52, 53, 54, 55]; - let twiddles2 = [56, 57, 58, 59]; - let twiddles3 = [60, 61]; - let twiddle0_dbl = std::array::from_fn(|i| twiddles0[i] * 2); - let twiddle1_dbl = std::array::from_fn(|i| twiddles1[i] * 2); - let twiddle2_dbl = std::array::from_fn(|i| twiddles2[i] * 2); - let twiddle3_dbl = std::array::from_fn(|i| twiddles3[i] * 2); - - let (r0, r1) = vecwise_butterflies( - val0, - val1, - twiddle0_dbl, - twiddle1_dbl, - twiddle2_dbl, - twiddle3_dbl, - ); - - let mut val0: [BaseField; 16] = std::mem::transmute(val0); - let mut val1: [BaseField; 16] = std::mem::transmute(val1); - let r0: [BaseField; 16] = std::mem::transmute(r0); - let r1: [BaseField; 16] = std::mem::transmute(r1); - let twiddles0: [BaseField; 16] = std::mem::transmute(twiddles0); - let twiddles1: [BaseField; 8] = std::mem::transmute(twiddles1); - let twiddles2: [BaseField; 4] = std::mem::transmute(twiddles2); - let twiddles3: [BaseField; 2] = std::mem::transmute(twiddles3); - - for i in 0..16 { - let j = i ^ 8; - if i > j { - 
continue; - } - let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); - butterfly(&mut v00, &mut v01, twiddles3[0]); - butterfly(&mut v10, &mut v11, twiddles3[1]); - (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); - } - for i in 0..16 { - let j = i ^ 4; - if i > j { - continue; - } - let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); - butterfly(&mut v00, &mut v01, twiddles2[i / 8]); - butterfly(&mut v10, &mut v11, twiddles2[2 + i / 8]); - (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); - } - for i in 0..16 { - let j = i ^ 2; - if i > j { - continue; - } - let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); - butterfly(&mut v00, &mut v01, twiddles1[i / 4]); - butterfly(&mut v10, &mut v11, twiddles1[4 + i / 4]); - (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); - } - for i in 0..16 { - let j = i ^ 1; - if i > j { - continue; - } - let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); - butterfly(&mut v00, &mut v01, twiddles0[i / 2]); - butterfly(&mut v10, &mut v11, twiddles0[8 + i / 2]); - (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); - } - - // Compare - for i in 0..16 { - assert_eq!(val0[i], r0[i]); - assert_eq!(val1[i], r1[i]); - } - } - } - - #[test] - fn test_vecwise_ibutterflies() { - unsafe { - let val0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - let val1 = _mm512_setr_epi32( - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + fn test_vecwise_butterflies_real() { + let domain = CanonicCoset::new(5).circle_domain(); + let twiddle_dbls = get_twiddle_dbls(domain); + assert_eq!(twiddle_dbls.len(), 5); + let values0: [i32; 16] = std::array::from_fn(|i| i as i32); + let values1: [i32; 16] = std::array::from_fn(|i| (i + 16) as i32); + let result: [BaseField; 32] = unsafe { + let (val0, val1) = avx_butterfly( + std::mem::transmute(values0), + 
std::mem::transmute(values1), + _mm512_set1_epi32(twiddle_dbls[4][0]), ); - let twiddles0 = [ - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, - ]; - let twiddles1 = [48, 49, 50, 51, 52, 53, 54, 55]; - let twiddles2 = [56, 57, 58, 59]; - let twiddles3 = [60, 61]; - let twiddle0_dbl = std::array::from_fn(|i| twiddles0[i] * 2); - let twiddle1_dbl = std::array::from_fn(|i| twiddles1[i] * 2); - let twiddle2_dbl = std::array::from_fn(|i| twiddles2[i] * 2); - let twiddle3_dbl = std::array::from_fn(|i| twiddles3[i] * 2); - - let (r0, r1) = vecwise_ibutterflies( + let (val0, val1) = vecwise_butterflies( val0, val1, - twiddle0_dbl, - twiddle1_dbl, - twiddle2_dbl, - twiddle3_dbl, + twiddle_dbls[1].clone().try_into().unwrap(), + twiddle_dbls[2].clone().try_into().unwrap(), + twiddle_dbls[3].clone().try_into().unwrap(), ); + std::mem::transmute([val0, val1]) + }; - let mut val0: [BaseField; 16] = std::mem::transmute(val0); - let mut val1: [BaseField; 16] = std::mem::transmute(val1); - let r0: [BaseField; 16] = std::mem::transmute(r0); - let r1: [BaseField; 16] = std::mem::transmute(r1); - let twiddles0: [BaseField; 16] = std::mem::transmute(twiddles0); - let twiddles1: [BaseField; 8] = std::mem::transmute(twiddles1); - let twiddles2: [BaseField; 4] = std::mem::transmute(twiddles2); - let twiddles3: [BaseField; 2] = std::mem::transmute(twiddles3); + // ref. 
+ let mut values = values0.to_vec(); + values.extend_from_slice(&values1); + let expected = ref_fft(domain, values.into_iter().map(BaseField::from).collect()); - for i in 0..16 { - let j = i ^ 1; - if i > j { - continue; - } - let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); - ibutterfly(&mut v00, &mut v01, twiddles0[i / 2]); - ibutterfly(&mut v10, &mut v11, twiddles0[8 + i / 2]); - (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); - } - for i in 0..16 { - let j = i ^ 2; - if i > j { - continue; - } - let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); - ibutterfly(&mut v00, &mut v01, twiddles1[i / 4]); - ibutterfly(&mut v10, &mut v11, twiddles1[4 + i / 4]); - (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); - } - for i in 0..16 { - let j = i ^ 4; - if i > j { - continue; - } - let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); - ibutterfly(&mut v00, &mut v01, twiddles2[i / 8]); - ibutterfly(&mut v10, &mut v11, twiddles2[2 + i / 8]); - (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); - } - for i in 0..16 { - let j = i ^ 8; - if i > j { - continue; - } - let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]); - ibutterfly(&mut v00, &mut v01, twiddles3[0]); - ibutterfly(&mut v10, &mut v11, twiddles3[1]); - (val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11); - } - // Compare - for i in 0..16 { - assert_eq!(val0[i], r0[i]); - assert_eq!(val1[i], r1[i]); - } + // Compare. 
+ for i in 0..32 { + assert_eq!(result[i], expected[i]); } } @@ -807,6 +704,27 @@ mod tests { } } + fn get_twiddle_dbls(domain: CircleDomain) -> Vec> { + let mut coset = domain.half_coset; + + let mut res = vec![]; + res.push(coset.iter().map(|p| (p.y.0 * 2) as i32).collect::>()); + bit_reverse(res.last_mut().unwrap()); + for _ in 0..coset.log_size() { + res.push( + coset + .iter() + .take(coset.size() / 2) + .map(|p| (p.x.0 * 2) as i32) + .collect::>(), + ); + bit_reverse(res.last_mut().unwrap()); + coset = coset.double(); + } + + res + } + fn get_itwiddle_dbls(domain: CircleDomain) -> Vec> { let mut coset = domain.half_coset; @@ -833,9 +751,39 @@ mod tests { res } + #[test] + fn test_twiddle_relation() { + let ts = get_itwiddle_dbls(CanonicCoset::new(5).circle_domain()); + let t0 = ts[0] + .iter() + .copied() + .map(|x| BaseField::from_u32_unchecked((x as u32) / 2)) + .collect::>(); + let t1 = ts[1] + .iter() + .copied() + .map(|x| BaseField::from_u32_unchecked((x as u32) / 2)) + .collect::>(); + + for i in 0..t0.len() / 4 { + assert_eq!(t0[i * 4], t1[i * 2 + 1]); + assert_eq!(t0[i * 4 + 1], -t1[i * 2 + 1]); + assert_eq!(t0[i * 4 + 2], -t1[i * 2]); + assert_eq!(t0[i * 4 + 3], t1[i * 2]); + } + } + + fn ref_fft(domain: CircleDomain, mut values: Vec) -> Vec { + bit_reverse(&mut values); + let poly = CPUCirclePoly::new(values); + let mut expected_values = poly.evaluate(domain).values; + bit_reverse(&mut expected_values); + expected_values + } + fn ref_ifft(domain: CircleDomain, mut values: Vec) -> Vec { bit_reverse(&mut values); - let eval = CircleEvaluation::::new(domain, values); + let eval = CPUCircleEvaluation::new(domain, values); let mut expected_coeffs = eval.interpolate().coeffs; for x in expected_coeffs.iter_mut() { *x *= BaseField::from_u32_unchecked(domain.size() as u32); @@ -855,7 +803,6 @@ mod tests { let (val0, val1) = vecwise_ibutterflies( std::mem::transmute(values0), std::mem::transmute(values1), - twiddle_dbls[0].clone().try_into().unwrap(), 
twiddle_dbls[1].clone().try_into().unwrap(), twiddle_dbls[2].clone().try_into().unwrap(), twiddle_dbls[3].clone().try_into().unwrap(), @@ -891,7 +838,7 @@ mod tests { unsafe { ifft_lower_with_vecwise( std::mem::transmute(values.data.as_mut_ptr()), - &twiddle_dbls, + &twiddle_dbls[1..], log_size as usize, log_size as usize, );