diff --git a/src/core/backend/avx512/fft/ifft.rs b/src/core/backend/avx512/fft/ifft.rs index 7142f5a84..ea26cd88b 100644 --- a/src/core/backend/avx512/fft/ifft.rs +++ b/src/core/backend/avx512/fft/ifft.rs @@ -1,17 +1,13 @@ //! Inverse fft. use std::arch::x86_64::{ - __m512i, _mm512_broadcast_i32x4, _mm512_load_epi32, _mm512_mul_epu32, - _mm512_permutex2var_epi32, _mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64, - _mm512_store_epi32, + __m512i, _mm512_broadcast_i32x4, _mm512_mul_epu32, _mm512_permutex2var_epi32, + _mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64, }; -use super::{ - add_mod_p, compute_first_twiddles, sub_mod_p, EVENS_CONCAT_EVENS, EVENS_INTERLEAVE_EVENS, - ODDS_CONCAT_ODDS, ODDS_INTERLEAVE_ODDS, -}; -use crate::core::backend::avx512::fft::transpose_vecs; -use crate::core::backend::avx512::{MIN_FFT_LOG_SIZE, VECS_LOG_SIZE}; +use super::{compute_first_twiddles, EVENS_INTERLEAVE_EVENS, ODDS_INTERLEAVE_ODDS}; +use crate::core::backend::avx512::fft::{transpose_vecs, MIN_FFT_LOG_SIZE}; +use crate::core::backend::avx512::{PackedBaseField, VECS_LOG_SIZE}; use crate::core::fields::FieldExpOps; use crate::core::poly::circle::CircleDomain; use crate::core::utils::bit_reverse; @@ -160,8 +156,8 @@ unsafe fn ifft_vecwise_loop( ) { for index_l in 0..(1 << loop_bits) { let index = (index_h << loop_bits) + index_l; - let mut val0 = _mm512_load_epi32(values.add(index * 32).cast_const()); - let mut val1 = _mm512_load_epi32(values.add(index * 32 + 16).cast_const()); + let mut val0 = PackedBaseField::load(values.add(index * 32).cast_const()); + let mut val1 = PackedBaseField::load(values.add(index * 32 + 16).cast_const()); (val0, val1) = vecwise_ibutterflies( val0, val1, @@ -174,8 +170,8 @@ unsafe fn ifft_vecwise_loop( val1, _mm512_set1_epi32(*twiddle_dbl[3].get_unchecked(index)), ); - _mm512_store_epi32(values.add(index * 32), val0); - _mm512_store_epi32(values.add(index * 32 + 16), val1); + val0.store(values.add(index * 32)); + 
val1.store(values.add(index * 32 + 16)); } } @@ -272,16 +268,16 @@ unsafe fn ifft1_loop(values: *mut i32, twiddle_dbl: &[&[i32]], layer: usize, ind /// # Safety /// This function is safe. pub unsafe fn avx_ibutterfly( - val0: __m512i, - val1: __m512i, + val0: PackedBaseField, + val1: PackedBaseField, twiddle_dbl: __m512i, -) -> (__m512i, __m512i) { - let r0 = add_mod_p(val0, val1); - let r1 = sub_mod_p(val0, val1); +) -> (PackedBaseField, PackedBaseField) { + let r0 = val0 + val1; + let r1 = val0 - val1; // Extract the even and odd parts of r1 and twiddle_dbl, and spread as 8 64bit values. - let r1_e = r1; - let r1_o = _mm512_srli_epi64(r1, 32); + let r1_e = r1.0; + let r1_o = _mm512_srli_epi64(r1.0, 32); let twiddle_dbl_e = twiddle_dbl; let twiddle_dbl_o = _mm512_srli_epi64(twiddle_dbl, 32); @@ -308,7 +304,7 @@ pub unsafe fn avx_ibutterfly( let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, ODDS_INTERLEAVE_ODDS, prod_o_dbl); // prod_hs - |0|prod_o_h|0|prod_e_h| - let prod = add_mod_p(prod_ls, prod_hs); + let prod = PackedBaseField(prod_ls) + PackedBaseField(prod_hs); (r0, prod) } @@ -326,12 +322,12 @@ pub unsafe fn avx_ibutterfly( /// # Safety /// This function is safe. pub unsafe fn vecwise_ibutterflies( - mut val0: __m512i, - mut val1: __m512i, + mut val0: PackedBaseField, + mut val1: PackedBaseField, twiddle1_dbl: [i32; 8], twiddle2_dbl: [i32; 4], twiddle3_dbl: [i32; 2], -) -> (__m512i, __m512i) { +) -> (PackedBaseField, PackedBaseField) { // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly. // Each avx_ibutterfly take 2 512-bit registers, and does 16 butterflies element by element. @@ -354,44 +350,29 @@ pub unsafe fn vecwise_ibutterflies( let (t0, t1) = compute_first_twiddles(twiddle1_dbl); // Apply the permutation, resulting in indexing d:iabc. 
- (val0, val1) = ( - _mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1), - _mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1), - ); + (val0, val1) = val0.deinterleave_with(val1); (val0, val1) = avx_ibutterfly(val0, val1, t0); // Apply the permutation, resulting in indexing c:diab. - (val0, val1) = ( - _mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1), - _mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1), - ); + (val0, val1) = val0.deinterleave_with(val1); (val0, val1) = avx_ibutterfly(val0, val1, t1); // The twiddles for layer 2 are replicated in the following pattern: // 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 let t = _mm512_broadcast_i32x4(std::mem::transmute(twiddle2_dbl)); // Apply the permutation, resulting in indexing b:cdia. - (val0, val1) = ( - _mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1), - _mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1), - ); + (val0, val1) = val0.deinterleave_with(val1); (val0, val1) = avx_ibutterfly(val0, val1, t); // The twiddles for layer 3 are replicated in the following pattern: // 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 let t = _mm512_set1_epi64(std::mem::transmute(twiddle3_dbl)); // Apply the permutation, resulting in indexing a:bcid. - (val0, val1) = ( - _mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1), - _mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1), - ); + (val0, val1) = val0.deinterleave_with(val1); (val0, val1) = avx_ibutterfly(val0, val1, t); // Apply the permutation, resulting in indexing i:abcd. - ( - _mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1), - _mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1), - ) + val0.deinterleave_with(val1) } pub fn get_itwiddle_dbls(domain: CircleDomain) -> Vec> { @@ -442,14 +423,14 @@ pub unsafe fn ifft3( twiddles_dbl2: [i32; 1], ) { // Load the 8 AVX vectors from the array. 
- let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const()); - let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const()); - let mut val2 = _mm512_load_epi32(values.add(offset + (2 << log_step)).cast_const()); - let mut val3 = _mm512_load_epi32(values.add(offset + (3 << log_step)).cast_const()); - let mut val4 = _mm512_load_epi32(values.add(offset + (4 << log_step)).cast_const()); - let mut val5 = _mm512_load_epi32(values.add(offset + (5 << log_step)).cast_const()); - let mut val6 = _mm512_load_epi32(values.add(offset + (6 << log_step)).cast_const()); - let mut val7 = _mm512_load_epi32(values.add(offset + (7 << log_step)).cast_const()); + let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); + let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const()); + let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const()); + let mut val4 = PackedBaseField::load(values.add(offset + (4 << log_step)).cast_const()); + let mut val5 = PackedBaseField::load(values.add(offset + (5 << log_step)).cast_const()); + let mut val6 = PackedBaseField::load(values.add(offset + (6 << log_step)).cast_const()); + let mut val7 = PackedBaseField::load(values.add(offset + (7 << log_step)).cast_const()); // Apply the first layer of ibutterflies. (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); @@ -470,14 +451,14 @@ pub unsafe fn ifft3( (val3, val7) = avx_ibutterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0])); // Store the 8 AVX vectors back to the array. 
- _mm512_store_epi32(values.add(offset + (0 << log_step)), val0); - _mm512_store_epi32(values.add(offset + (1 << log_step)), val1); - _mm512_store_epi32(values.add(offset + (2 << log_step)), val2); - _mm512_store_epi32(values.add(offset + (3 << log_step)), val3); - _mm512_store_epi32(values.add(offset + (4 << log_step)), val4); - _mm512_store_epi32(values.add(offset + (5 << log_step)), val5); - _mm512_store_epi32(values.add(offset + (6 << log_step)), val6); - _mm512_store_epi32(values.add(offset + (7 << log_step)), val7); + val0.store(values.add(offset + (0 << log_step))); + val1.store(values.add(offset + (1 << log_step))); + val2.store(values.add(offset + (2 << log_step))); + val3.store(values.add(offset + (3 << log_step))); + val4.store(values.add(offset + (4 << log_step))); + val5.store(values.add(offset + (5 << log_step))); + val6.store(values.add(offset + (6 << log_step))); + val7.store(values.add(offset + (7 << log_step))); } /// Applies 2 ibutterfly layers on 4 vectors of 16 M31 elements. @@ -501,10 +482,10 @@ pub unsafe fn ifft2( twiddles_dbl1: [i32; 1], ) { // Load the 4 AVX vectors from the array. - let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const()); - let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const()); - let mut val2 = _mm512_load_epi32(values.add(offset + (2 << log_step)).cast_const()); - let mut val3 = _mm512_load_epi32(values.add(offset + (3 << log_step)).cast_const()); + let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); + let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const()); + let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const()); // Apply the first layer of butterflies. 
(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); @@ -515,10 +496,10 @@ pub unsafe fn ifft2( (val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0])); // Store the 4 AVX vectors back to the array. - _mm512_store_epi32(values.add(offset + (0 << log_step)), val0); - _mm512_store_epi32(values.add(offset + (1 << log_step)), val1); - _mm512_store_epi32(values.add(offset + (2 << log_step)), val2); - _mm512_store_epi32(values.add(offset + (3 << log_step)), val3); + val0.store(values.add(offset + (0 << log_step))); + val1.store(values.add(offset + (1 << log_step))); + val2.store(values.add(offset + (2 << log_step))); + val3.store(values.add(offset + (3 << log_step))); } /// Applies 1 ibutterfly layers on 2 vectors of 16 M31 elements. @@ -532,14 +513,14 @@ pub unsafe fn ifft2( /// # Safety pub unsafe fn ifft1(values: *mut i32, offset: usize, log_step: usize, twiddles_dbl0: [i32; 1]) { // Load the 2 AVX vectors from the array. - let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const()); - let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const()); + let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); // Store the 2 AVX vectors back to the array. 
- _mm512_store_epi32(values.add(offset + (0 << log_step)), val0); - _mm512_store_epi32(values.add(offset + (1 << log_step)), val1); + val0.store(values.add(offset + (0 << log_step))); + val1.store(values.add(offset + (1 << log_step))); } #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] @@ -559,10 +540,12 @@ mod tests { #[test] fn test_ibutterfly() { unsafe { - let val0 = _mm512_setr_epi32(2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - let val1 = _mm512_setr_epi32( + let val0 = PackedBaseField(_mm512_setr_epi32( + 2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + )); + let val1 = PackedBaseField(_mm512_setr_epi32( 3, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - ); + )); let twiddle = _mm512_setr_epi32( 1177558791, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, ); diff --git a/src/core/backend/avx512/fft/mod.rs b/src/core/backend/avx512/fft/mod.rs index 5482713b7..32a7863be 100644 --- a/src/core/backend/avx512/fft/mod.rs +++ b/src/core/backend/avx512/fft/mod.rs @@ -1,6 +1,6 @@ use std::arch::x86_64::{ - __m512i, _mm512_add_epi32, _mm512_broadcast_i64x4, _mm512_load_epi32, _mm512_min_epu32, - _mm512_permutexvar_epi32, _mm512_store_epi32, _mm512_sub_epi32, _mm512_xor_epi32, + __m512i, _mm512_broadcast_i64x4, _mm512_load_epi32, _mm512_permutexvar_epi32, + _mm512_store_epi32, _mm512_xor_epi32, }; pub mod ifft; @@ -23,39 +23,8 @@ const ODDS_INTERLEAVE_ODDS: __m512i = unsafe { ]) }; -/// An input to _mm512_permutex2var_epi32, and is used to concat the even words of a -/// with the even words of b. -const EVENS_CONCAT_EVENS: __m512i = unsafe { - core::mem::transmute([ - 0b00000, 0b00010, 0b00100, 0b00110, 0b01000, 0b01010, 0b01100, 0b01110, 0b10000, 0b10010, - 0b10100, 0b10110, 0b11000, 0b11010, 0b11100, 0b11110, - ]) -}; -/// An input to _mm512_permutex2var_epi32, and is used to concat the odd words of a -/// with the odd words of b. 
-const ODDS_CONCAT_ODDS: __m512i = unsafe { - core::mem::transmute([ - 0b00001, 0b00011, 0b00101, 0b00111, 0b01001, 0b01011, 0b01101, 0b01111, 0b10001, 0b10011, - 0b10101, 0b10111, 0b11001, 0b11011, 0b11101, 0b11111, - ]) -}; -/// An input to _mm512_permutex2var_epi32, and is used to interleave the low half of a -/// with the low half of b. -const LHALF_INTERLEAVE_LHALF: __m512i = unsafe { - core::mem::transmute([ - 0b00000, 0b10000, 0b00001, 0b10001, 0b00010, 0b10010, 0b00011, 0b10011, 0b00100, 0b10100, - 0b00101, 0b10101, 0b00110, 0b10110, 0b00111, 0b10111, - ]) -}; -/// An input to _mm512_permutex2var_epi32, and is used to interleave the high half of a -/// with the high half of b. -const HHALF_INTERLEAVE_HHALF: __m512i = unsafe { - core::mem::transmute([ - 0b01000, 0b11000, 0b01001, 0b11001, 0b01010, 0b11010, 0b01011, 0b11011, 0b01100, 0b11100, - 0b01101, 0b11101, 0b01110, 0b11110, 0b01111, 0b11111, - ]) -}; -const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) }; +pub const CACHED_FFT_LOG_SIZE: usize = 16; +pub const MIN_FFT_LOG_SIZE: usize = 5; // TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce // it somewhere. @@ -128,34 +97,6 @@ unsafe fn compute_first_twiddles(twiddle1_dbl: [i32; 8]) -> (__m512i, __m512i) { (t0, t1) } -// TODO(spapini): Move these to M31 AVX. - -/// Adds two packed M31 elements, and reduces the result to the range [0,P]. -/// Each value is assumed to be in unreduced form, [0, P] including P. -/// # Safety -/// This function is safe. -pub unsafe fn add_mod_p(a: __m512i, b: __m512i) -> __m512i { - // Add word by word. Each word is in the range [0, 2P]. - let c = _mm512_add_epi32(a, b); - // Apply min(c, c-P) to each word. - // When c in [P,2P], then c-P in [0,P] which is always less than [P,2P]. - // When c in [0,P-1], then c-P in [2^32-P,2^32-1] which is always greater than [0,P-1]. 
- _mm512_min_epu32(c, _mm512_sub_epi32(c, P)) -} - -/// Subtracts two packed M31 elements, and reduces the result to the range [0,P]. -/// Each value is assumed to be in unreduced form, [0, P] including P. -/// # Safety -/// This function is safe. -pub unsafe fn sub_mod_p(a: __m512i, b: __m512i) -> __m512i { - // Subtract word by word. Each word is in the range [-P, P]. - let c = _mm512_sub_epi32(a, b); - // Apply min(c, c+P) to each word. - // When c in [0,P], then c+P in [P,2P] which is always greater than [0,P]. - // When c in [2^32-P,2^32-1], then c+P in [0,P-1] which is always less than [2^32-P,2^32-1]. - _mm512_min_epu32(_mm512_add_epi32(c, P), c) -} - #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] #[cfg(test)] mod tests { diff --git a/src/core/backend/avx512/fft/rfft.rs b/src/core/backend/avx512/fft/rfft.rs index 985ec3280..d4a9ac12a 100644 --- a/src/core/backend/avx512/fft/rfft.rs +++ b/src/core/backend/avx512/fft/rfft.rs @@ -1,17 +1,13 @@ //! Regular (forward) fft. 
use std::arch::x86_64::{ - __m512i, _mm512_broadcast_i32x4, _mm512_load_epi32, _mm512_mul_epu32, - _mm512_permutex2var_epi32, _mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64, - _mm512_store_epi32, + __m512i, _mm512_broadcast_i32x4, _mm512_mul_epu32, _mm512_permutex2var_epi32, + _mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64, }; -use super::{ - add_mod_p, compute_first_twiddles, sub_mod_p, EVENS_INTERLEAVE_EVENS, HHALF_INTERLEAVE_HHALF, - LHALF_INTERLEAVE_LHALF, ODDS_INTERLEAVE_ODDS, -}; -use crate::core::backend::avx512::fft::transpose_vecs; -use crate::core::backend::avx512::{MIN_FFT_LOG_SIZE, VECS_LOG_SIZE}; +use super::{compute_first_twiddles, EVENS_INTERLEAVE_EVENS, ODDS_INTERLEAVE_ODDS}; +use crate::core::backend::avx512::fft::{transpose_vecs, MIN_FFT_LOG_SIZE}; +use crate::core::backend::avx512::{PackedBaseField, VECS_LOG_SIZE}; use crate::core::poly::circle::CircleDomain; use crate::core::utils::bit_reverse; @@ -159,8 +155,8 @@ unsafe fn fft_vecwise_loop( ) { for index_l in 0..(1 << loop_bits) { let index = (index_h << loop_bits) + index_l; - let mut val0 = _mm512_load_epi32(values.add(index * 32).cast_const()); - let mut val1 = _mm512_load_epi32(values.add(index * 32 + 16).cast_const()); + let mut val0 = PackedBaseField::load(values.add(index * 32).cast_const()); + let mut val1 = PackedBaseField::load(values.add(index * 32 + 16).cast_const()); (val0, val1) = avx_butterfly( val0, val1, @@ -173,8 +169,8 @@ unsafe fn fft_vecwise_loop( std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)), std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)), ); - _mm512_store_epi32(values.add(index * 32), val0); - _mm512_store_epi32(values.add(index * 32 + 16), val1); + val0.store(values.add(index * 32)); + val1.store(values.add(index * 32 + 16)); } } @@ -272,14 +268,14 @@ unsafe fn fft1_loop(values: *mut i32, twiddle_dbl: &[&[i32]], layer: usize, inde /// # Safety /// This function is safe. 
pub unsafe fn avx_butterfly( - val0: __m512i, - val1: __m512i, + val0: PackedBaseField, + val1: PackedBaseField, twiddle_dbl: __m512i, -) -> (__m512i, __m512i) { +) -> (PackedBaseField, PackedBaseField) { // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of val0. - let val1_e = val1; + let val1_e = val1.0; // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of val0. - let val1_o = _mm512_srli_epi64(val1, 32); + let val1_o = _mm512_srli_epi64(val1.0, 32); let twiddle_dbl_e = twiddle_dbl; let twiddle_dbl_o = _mm512_srli_epi64(twiddle_dbl, 32); @@ -306,10 +302,10 @@ pub unsafe fn avx_butterfly( let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, ODDS_INTERLEAVE_ODDS, prod_o_dbl); // prod_hs - |0|prod_o_h|0|prod_e_h| - let prod = add_mod_p(prod_ls, prod_hs); + let prod = PackedBaseField(prod_ls) + PackedBaseField(prod_hs); - let r0 = add_mod_p(val0, prod); - let r1 = sub_mod_p(val0, prod); + let r0 = val0 + prod; + let r1 = val0 - prod; (r0, r1) } @@ -324,47 +320,32 @@ pub unsafe fn avx_butterfly( /// etc. /// # Safety pub unsafe fn vecwise_butterflies( - mut val0: __m512i, - mut val1: __m512i, + mut val0: PackedBaseField, + mut val1: PackedBaseField, twiddle1_dbl: [i32; 8], twiddle2_dbl: [i32; 4], twiddle3_dbl: [i32; 2], -) -> (__m512i, __m512i) { +) -> (PackedBaseField, PackedBaseField) { // TODO(spapini): Compute twiddle0 from twiddle1. // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly. // The implementation is the exact reverse of vecwise_ibutterflies(). // See the comments in its body for more info. 
let t = _mm512_set1_epi64(std::mem::transmute(twiddle3_dbl)); - (val0, val1) = ( - _mm512_permutex2var_epi32(val0, LHALF_INTERLEAVE_LHALF, val1), - _mm512_permutex2var_epi32(val0, HHALF_INTERLEAVE_HHALF, val1), - ); + (val0, val1) = val0.interleave_with(val1); (val0, val1) = avx_butterfly(val0, val1, t); let t = _mm512_broadcast_i32x4(std::mem::transmute(twiddle2_dbl)); - (val0, val1) = ( - _mm512_permutex2var_epi32(val0, LHALF_INTERLEAVE_LHALF, val1), - _mm512_permutex2var_epi32(val0, HHALF_INTERLEAVE_HHALF, val1), - ); + (val0, val1) = val0.interleave_with(val1); (val0, val1) = avx_butterfly(val0, val1, t); let (t0, t1) = compute_first_twiddles(twiddle1_dbl); - (val0, val1) = ( - _mm512_permutex2var_epi32(val0, LHALF_INTERLEAVE_LHALF, val1), - _mm512_permutex2var_epi32(val0, HHALF_INTERLEAVE_HHALF, val1), - ); + (val0, val1) = val0.interleave_with(val1); (val0, val1) = avx_butterfly(val0, val1, t1); - (val0, val1) = ( - _mm512_permutex2var_epi32(val0, LHALF_INTERLEAVE_LHALF, val1), - _mm512_permutex2var_epi32(val0, HHALF_INTERLEAVE_HHALF, val1), - ); + (val0, val1) = val0.interleave_with(val1); (val0, val1) = avx_butterfly(val0, val1, t0); - ( - _mm512_permutex2var_epi32(val0, LHALF_INTERLEAVE_LHALF, val1), - _mm512_permutex2var_epi32(val0, HHALF_INTERLEAVE_HHALF, val1), - ) + val0.interleave_with(val1) } pub fn get_twiddle_dbls(domain: CircleDomain) -> Vec> { @@ -410,14 +391,14 @@ pub unsafe fn fft3( twiddles_dbl2: [i32; 1], ) { // Load the 8 AVX vectors from the array. 
- let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const()); - let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const()); - let mut val2 = _mm512_load_epi32(values.add(offset + (2 << log_step)).cast_const()); - let mut val3 = _mm512_load_epi32(values.add(offset + (3 << log_step)).cast_const()); - let mut val4 = _mm512_load_epi32(values.add(offset + (4 << log_step)).cast_const()); - let mut val5 = _mm512_load_epi32(values.add(offset + (5 << log_step)).cast_const()); - let mut val6 = _mm512_load_epi32(values.add(offset + (6 << log_step)).cast_const()); - let mut val7 = _mm512_load_epi32(values.add(offset + (7 << log_step)).cast_const()); + let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); + let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const()); + let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const()); + let mut val4 = PackedBaseField::load(values.add(offset + (4 << log_step)).cast_const()); + let mut val5 = PackedBaseField::load(values.add(offset + (5 << log_step)).cast_const()); + let mut val6 = PackedBaseField::load(values.add(offset + (6 << log_step)).cast_const()); + let mut val7 = PackedBaseField::load(values.add(offset + (7 << log_step)).cast_const()); // Apply the third layer of butterflies. (val0, val4) = avx_butterfly(val0, val4, _mm512_set1_epi32(twiddles_dbl2[0])); @@ -438,14 +419,14 @@ pub unsafe fn fft3( (val6, val7) = avx_butterfly(val6, val7, _mm512_set1_epi32(twiddles_dbl0[3])); // Store the 8 AVX vectors back to the array. 
- _mm512_store_epi32(values.add(offset + (0 << log_step)), val0); - _mm512_store_epi32(values.add(offset + (1 << log_step)), val1); - _mm512_store_epi32(values.add(offset + (2 << log_step)), val2); - _mm512_store_epi32(values.add(offset + (3 << log_step)), val3); - _mm512_store_epi32(values.add(offset + (4 << log_step)), val4); - _mm512_store_epi32(values.add(offset + (5 << log_step)), val5); - _mm512_store_epi32(values.add(offset + (6 << log_step)), val6); - _mm512_store_epi32(values.add(offset + (7 << log_step)), val7); + val0.store(values.add(offset + (0 << log_step))); + val1.store(values.add(offset + (1 << log_step))); + val2.store(values.add(offset + (2 << log_step))); + val3.store(values.add(offset + (3 << log_step))); + val4.store(values.add(offset + (4 << log_step))); + val5.store(values.add(offset + (5 << log_step))); + val6.store(values.add(offset + (6 << log_step))); + val7.store(values.add(offset + (7 << log_step))); } /// Applies 2 butterfly layers on 4 vectors of 16 M31 elements. @@ -469,10 +450,10 @@ pub unsafe fn fft2( twiddles_dbl1: [i32; 1], ) { // Load the 4 AVX vectors from the array. - let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const()); - let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const()); - let mut val2 = _mm512_load_epi32(values.add(offset + (2 << log_step)).cast_const()); - let mut val3 = _mm512_load_epi32(values.add(offset + (3 << log_step)).cast_const()); + let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); + let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const()); + let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const()); // Apply the second layer of butterflies. 
(val0, val2) = avx_butterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0])); @@ -483,10 +464,10 @@ pub unsafe fn fft2( (val2, val3) = avx_butterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1])); // Store the 4 AVX vectors back to the array. - _mm512_store_epi32(values.add(offset + (0 << log_step)), val0); - _mm512_store_epi32(values.add(offset + (1 << log_step)), val1); - _mm512_store_epi32(values.add(offset + (2 << log_step)), val2); - _mm512_store_epi32(values.add(offset + (3 << log_step)), val3); + val0.store(values.add(offset + (0 << log_step))); + val1.store(values.add(offset + (1 << log_step))); + val2.store(values.add(offset + (2 << log_step))); + val3.store(values.add(offset + (3 << log_step))); } /// Applies 1 butterfly layers on 2 vectors of 16 M31 elements. @@ -500,14 +481,14 @@ pub unsafe fn fft2( /// # Safety pub unsafe fn fft1(values: *mut i32, offset: usize, log_step: usize, twiddles_dbl0: [i32; 1]) { // Load the 2 AVX vectors from the array. - let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const()); - let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const()); + let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); (val0, val1) = avx_butterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); // Store the 2 AVX vectors back to the array. 
- _mm512_store_epi32(values.add(offset + (0 << log_step)), val0); - _mm512_store_epi32(values.add(offset + (1 << log_step)), val1); + val0.store(values.add(offset + (0 << log_step))); + val1.store(values.add(offset + (1 << log_step))); } #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] @@ -526,10 +507,12 @@ mod tests { #[test] fn test_butterfly() { unsafe { - let val0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - let val1 = _mm512_setr_epi32( + let val0 = PackedBaseField(_mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + )); + let val1 = PackedBaseField(_mm512_setr_epi32( 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - ); + )); let twiddle = _mm512_setr_epi32( 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, ); diff --git a/src/core/backend/avx512/m31.rs b/src/core/backend/avx512/m31.rs index ba8a36874..cbb1a13b7 100644 --- a/src/core/backend/avx512/m31.rs +++ b/src/core/backend/avx512/m31.rs @@ -2,7 +2,7 @@ use core::arch::x86_64::{ __m512i, _mm512_add_epi32, _mm512_min_epu32, _mm512_mul_epu32, _mm512_srli_epi64, _mm512_sub_epi32, }; -use std::arch::x86_64::_mm512_permutex2var_epi32; +use std::arch::x86_64::{_mm512_load_epi32, _mm512_permutex2var_epi32, _mm512_store_epi32}; use std::fmt::Display; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; @@ -14,11 +14,45 @@ use crate::core::fields::FieldExpOps; pub const K_BLOCK_SIZE: usize = 16; pub const M512P: __m512i = unsafe { core::mem::transmute([P; K_BLOCK_SIZE]) }; +/// An input to _mm512_permutex2var_epi32, and is used to interleave the low half of a +/// with the low half of b. 
+pub const LHALF_INTERLEAVE_LHALF: __m512i = unsafe { + core::mem::transmute([ + 0b00000, 0b10000, 0b00001, 0b10001, 0b00010, 0b10010, 0b00011, 0b10011, 0b00100, 0b10100, + 0b00101, 0b10101, 0b00110, 0b10110, 0b00111, 0b10111, + ]) +}; +/// An input to _mm512_permutex2var_epi32, and is used to interleave the high half of a +/// with the high half of b. +pub const HHALF_INTERLEAVE_HHALF: __m512i = unsafe { + core::mem::transmute([ + 0b01000, 0b11000, 0b01001, 0b11001, 0b01010, 0b11010, 0b01011, 0b11011, 0b01100, 0b11100, + 0b01101, 0b11101, 0b01110, 0b11110, 0b01111, 0b11111, + ]) +}; + +/// An input to _mm512_permutex2var_epi32, and is used to concat the even words of a +/// with the even words of b. +pub const EVENS_CONCAT_EVENS: __m512i = unsafe { + core::mem::transmute([ + 0b00000, 0b00010, 0b00100, 0b00110, 0b01000, 0b01010, 0b01100, 0b01110, 0b10000, 0b10010, + 0b10100, 0b10110, 0b11000, 0b11010, 0b11100, 0b11110, + ]) +}; +/// An input to _mm512_permutex2var_epi32, and is used to concat the odd words of a +/// with the odd words of b. +pub const ODDS_CONCAT_ODDS: __m512i = unsafe { + core::mem::transmute([ + 0b00001, 0b00011, 0b00101, 0b00111, 0b01001, 0b01011, 0b01101, 0b01111, 0b10001, 0b10011, + 0b10101, 0b10111, 0b11001, 0b11011, 0b11101, 0b11111, + ]) +}; + /// AVX512 implementation of M31. /// Stores 16 M31 elements in a single 512-bit register. /// Each M31 element is unreduced in the range [0, P]. #[derive(Copy, Clone, Debug)] -pub struct PackedBaseField(__m512i); +pub struct PackedBaseField(pub __m512i); impl PackedBaseField { pub fn from_array(v: [M31; K_BLOCK_SIZE]) -> PackedBaseField { @@ -37,6 +71,42 @@ impl PackedBaseField { pub fn reduce(self) -> PackedBaseField { Self(unsafe { _mm512_min_epu32(self.0, _mm512_sub_epi32(self.0, M512P)) }) } + + /// Interleaves self with other. + /// Returns the result as two packed M31 elements. 
+    pub fn interleave_with(self, other: Self) -> (Self, Self) {
+        (
+            Self(unsafe { _mm512_permutex2var_epi32(self.0, LHALF_INTERLEAVE_LHALF, other.0) }),
+            Self(unsafe { _mm512_permutex2var_epi32(self.0, HHALF_INTERLEAVE_HHALF, other.0) }),
+        )
+    }
+
+    /// Deinterleaves self with other.
+    /// Done by concatenating the even words of self with the even words of other, and the odd words of self with the odd words of other.
+    /// The inverse of [Self::interleave_with].
+    /// Returns the result as two packed M31 elements.
+    pub fn deinterleave_with(self, other: Self) -> (Self, Self) {
+        (
+            Self(unsafe { _mm512_permutex2var_epi32(self.0, EVENS_CONCAT_EVENS, other.0) }),
+            Self(unsafe { _mm512_permutex2var_epi32(self.0, ODDS_CONCAT_ODDS, other.0) }),
+        )
+    }
+
+    /// # Safety
+    ///
+    /// This function is unsafe because it performs a load from a raw pointer. The pointer must be
+    /// valid and aligned to 64 bytes.
+    pub unsafe fn load(ptr: *const i32) -> Self {
+        Self(_mm512_load_epi32(ptr))
+    }
+
+    /// # Safety
+    ///
+    /// This function is unsafe because it performs a store through a raw pointer. The pointer must
+    /// be valid and aligned to 64 bytes.
+    pub unsafe fn store(self, ptr: *mut i32) {
+        _mm512_store_epi32(ptr, self.0);
+    }
 }
 
 impl Display for PackedBaseField {