diff --git a/src/core/backend/avx512/fft.rs b/src/core/backend/avx512/fft.rs index 20e8b79dc..114c75204 100644 --- a/src/core/backend/avx512/fft.rs +++ b/src/core/backend/avx512/fft.rs @@ -41,7 +41,7 @@ pub unsafe fn avx_butterfly( let twiddle_dbl_e = twiddle_dbl; let twiddle_dbl_o = _mm512_srli_epi64(twiddle_dbl, 32); - // To compute prod = val1 * twiddle start by multipling + // To compute prod = val1 * twiddle start by multiplying // val1_e/o by twiddle_dbl_e/o. let prod_e_dbl = _mm512_mul_epu32(val1_e, twiddle_dbl_e); let prod_o_dbl = _mm512_mul_epu32(val1_o, twiddle_dbl_o); @@ -72,6 +72,55 @@ pub unsafe fn avx_butterfly( (r0, r1) } +/// Computes the ibutterfly operation for packed M31 elements. +/// Returns val0 + val1, t * (val0 - val1), where t is the twiddle factor. +/// val0, val1 are packed M31 elements. 16 M31 words in each. +/// Each value is assumed to be in unreduced form, [0, P] including P. +/// twiddle_dbl holds 16 values, each is a *double* of a twiddle factor, in reduced form. +/// # Safety +/// This function is safe. +pub unsafe fn avx_ibutterfly( + val0: __m512i, + val1: __m512i, + twiddle_dbl: __m512i, +) -> (__m512i, __m512i) { + let r0 = add_mod_p(val0, val1); + let r1 = sub_mod_p(val0, val1); + + // Extract the even and odd parts of r1 and twiddle_dbl, and spread as 8 64-bit values. + let r1_e = r1; + let r1_o = _mm512_srli_epi64(r1, 32); + let twiddle_dbl_e = twiddle_dbl; + let twiddle_dbl_o = _mm512_srli_epi64(twiddle_dbl, 32); + + // To compute prod = r1 * twiddle start by multiplying + // r1_e/o by twiddle_dbl_e/o. + let prod_e_dbl = _mm512_mul_epu32(r1_e, twiddle_dbl_e); + let prod_o_dbl = _mm512_mul_epu32(r1_o, twiddle_dbl_o); + + // The result of a multiplication holds r1*twiddle_dbl as a 64-bit value. 
+ // Each 64-bit word looks like this: + // 1 31 31 1 + // prod_e_dbl - |0|prod_e_h|prod_e_l|0| + // prod_o_dbl - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: + let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, L, prod_o_dbl); + // prod_ls - |prod_o_l|0|prod_e_l|0| + + // Divide by 2: + let prod_ls = _mm512_srli_epi64(prod_ls, 1); + // prod_ls - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: + let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, H, prod_o_dbl); + // prod_hs - |0|prod_o_h|0|prod_e_h| + + let prod = add_mod_p(prod_ls, prod_hs); + + (r0, prod) +} + // TODO(spapini): Move these to M31 AVX. /// Adds two packed M31 elements, and reduces the result to the range [0,P]. @@ -105,7 +154,7 @@ mod tests { use std::arch::x86_64::_mm512_setr_epi32; use super::*; - use crate::core::fft::butterfly; + use crate::core::fft::{butterfly, ibutterfly}; use crate::core::fields::m31::BaseField; #[test] @@ -137,4 +186,34 @@ mod tests { } } } + + #[test] + fn test_ibutterfly() { + unsafe { + let val0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + let val1 = _mm512_setr_epi32( + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + ); + let twiddle = _mm512_setr_epi32( + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + ); + let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle); + let (r0, r1) = avx_ibutterfly(val0, val1, twiddle_dbl); + + let val0: [u32; 16] = std::mem::transmute(val0); + let val1: [u32; 16] = std::mem::transmute(val1); + let twiddle: [u32; 16] = std::mem::transmute(twiddle); + let r0: [u32; 16] = std::mem::transmute(r0); + let r1: [u32; 16] = std::mem::transmute(r1); + + for i in 0..16 { + let mut x = BaseField::from_u32_unchecked(val0[i]); + let mut y = BaseField::from_u32_unchecked(val1[i]); + let twiddle = BaseField::from_u32_unchecked(twiddle[i]); + ibutterfly(&mut x, &mut y, 
twiddle); + assert_eq!(x, BaseField::from_u32_unchecked(r0[i])); + assert_eq!(y, BaseField::from_u32_unchecked(r1[i])); + } + } + } }