diff --git a/src/core/backend/avx512/fft.rs b/src/core/backend/avx512/fft.rs
index 13c5f32cf..ba52b5e69 100644
--- a/src/core/backend/avx512/fft.rs
+++ b/src/core/backend/avx512/fft.rs
@@ -1,7 +1,7 @@
 use std::arch::x86_64::{
-    __m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_min_epu32,
-    _mm512_mul_epi32, _mm512_permutex2var_epi32, _mm512_set1_epi64, _mm512_srli_epi64,
-    _mm512_sub_epi32,
+    __m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_load_epi32,
+    _mm512_min_epu32, _mm512_mul_epi32, _mm512_permutex2var_epi32, _mm512_set1_epi32,
+    _mm512_set1_epi64, _mm512_srli_epi64, _mm512_store_epi32, _mm512_sub_epi32,
 };
 
 const L: __m512i = unsafe {
@@ -195,6 +195,83 @@ pub unsafe fn vecwise_ibutterflies(
     )
 }
 
+/// Runs three layers of inverse butterflies, in place, over eight 16-lane `i32` vectors.
+///
+/// Vector `k` (`k` in `0..8`) is the 16 consecutive `i32`s starting at
+/// `values.add(offset + (k << (step + 4)))`. All eight vectors are loaded, combined and then
+/// stored back to the addresses they were loaded from:
+/// - layer 0 pairs vectors `(0, 1)`, `(2, 3)`, `(4, 5)`, `(6, 7)`, using `twiddles_dbl0[0]`
+///   through `twiddles_dbl0[3]` respectively;
+/// - layer 1 pairs `(0, 2)` and `(1, 3)` using `twiddles_dbl1[0]`, then `(4, 6)` and `(5, 7)`
+///   using `twiddles_dbl1[1]`;
+/// - layer 2 pairs `(k, k + 4)` for `k` in `0..4`, all using `twiddles_dbl2[0]`.
+///
+/// Every twiddle is broadcast to all 16 lanes before it is handed to `avx_ibutterfly`. The
+/// `_dbl` suffix suggests doubled twiddles; the unit test passes `2 * twiddle`.
+///
+/// # Safety
+/// - For every `k` in `0..8`, `values.add(offset + (k << (step + 4)))` must be valid for reads
+///   and writes of 16 `i32`s and must be 64-byte aligned, because aligned 512-bit loads and
+///   stores are used.
+/// - The CPU must support the AVX-512 instructions used here.
+pub unsafe fn ifft3(
+    values: *mut i32,
+    offset: usize,
+    step: usize,
+    twiddles_dbl0: &[i32; 4],
+    twiddles_dbl1: &[i32; 2],
+    twiddles_dbl2: &[i32; 1],
+) {
+    // log2 of the distance between consecutive vectors, counted in `i32` elements.
+    let log_stride = step + 4;
+    let ptr0 = values.add(offset);
+    let ptr1 = values.add(offset + (1 << log_stride));
+    let ptr2 = values.add(offset + (2 << log_stride));
+    let ptr3 = values.add(offset + (3 << log_stride));
+    let ptr4 = values.add(offset + (4 << log_stride));
+    let ptr5 = values.add(offset + (5 << log_stride));
+    let ptr6 = values.add(offset + (6 << log_stride));
+    let ptr7 = values.add(offset + (7 << log_stride));
+
+    // load
+    let mut val0 = _mm512_load_epi32(ptr0.cast_const());
+    let mut val1 = _mm512_load_epi32(ptr1.cast_const());
+    let mut val2 = _mm512_load_epi32(ptr2.cast_const());
+    let mut val3 = _mm512_load_epi32(ptr3.cast_const());
+    let mut val4 = _mm512_load_epi32(ptr4.cast_const());
+    let mut val5 = _mm512_load_epi32(ptr5.cast_const());
+    let mut val6 = _mm512_load_epi32(ptr6.cast_const());
+    let mut val7 = _mm512_load_epi32(ptr7.cast_const());
+
+    // layer 0: neighbouring vectors, one twiddle per pair
+    (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
+    (val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1]));
+    (val4, val5) = avx_ibutterfly(val4, val5, _mm512_set1_epi32(twiddles_dbl0[2]));
+    (val6, val7) = avx_ibutterfly(val6, val7, _mm512_set1_epi32(twiddles_dbl0[3]));
+
+    // layer 1: vectors two apart, one twiddle per group of four
+    (val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0]));
+    (val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0]));
+    (val4, val6) = avx_ibutterfly(val4, val6, _mm512_set1_epi32(twiddles_dbl1[1]));
+    (val5, val7) = avx_ibutterfly(val5, val7, _mm512_set1_epi32(twiddles_dbl1[1]));
+
+    // layer 2: vectors four apart, a single shared twiddle
+    (val0, val4) = avx_ibutterfly(val0, val4, _mm512_set1_epi32(twiddles_dbl2[0]));
+    (val1, val5) = avx_ibutterfly(val1, val5, _mm512_set1_epi32(twiddles_dbl2[0]));
+    (val2, val6) = avx_ibutterfly(val2, val6, _mm512_set1_epi32(twiddles_dbl2[0]));
+    (val3, val7) = avx_ibutterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0]));
+
+    // store
+    _mm512_store_epi32(ptr0, val0);
+    _mm512_store_epi32(ptr1, val1);
+    _mm512_store_epi32(ptr2, val2);
+    _mm512_store_epi32(ptr3, val3);
+    _mm512_store_epi32(ptr4, val4);
+    _mm512_store_epi32(ptr5, val5);
+    _mm512_store_epi32(ptr6, val6);
+    _mm512_store_epi32(ptr7, val7);
+}
+
 #[cfg(test)]
 mod tests {
     use std::arch::x86_64::_mm512_setr_epi32;
@@ -431,4 +508,62 @@ mod tests {
             }
         }
     }
+
+    #[test]
+    fn test_ifft3() {
+        // `ifft3` uses aligned 512-bit loads and stores, so the buffer is forced to 64-byte
+        // alignment; a `Vec<[i32; 16]>` would only guarantee the 4-byte alignment of `i32`.
+        #[repr(C, align(64))]
+        struct Aligned([[i32; 16]; 8]);
+
+        // Vector `k` holds the value `k` in all 16 lanes.
+        let mut values = Aligned(std::array::from_fn(|k| [k as i32; 16]));
+        let twiddles0: [i32; 4] = [32, 33, 34, 35];
+        let twiddles1: [i32; 2] = [36, 37];
+        let twiddles2: [i32; 1] = [38];
+        let twiddles0_dbl = twiddles0.map(|t| t * 2);
+        let twiddles1_dbl = twiddles1.map(|t| t * 2);
+        let twiddles2_dbl = twiddles2.map(|t| t * 2);
+
+        // SAFETY: `values` is 64-byte aligned and holds 8 * 16 `i32`s, which is exactly the
+        // range `ifft3` touches for `offset = 0`, `step = 0`. The transmutes assume, as this
+        // test always has, that `BaseField` has the layout of a 32-bit integer; only arrays
+        // are transmuted because, unlike `Vec`, their layout is guaranteed.
+        unsafe {
+            ifft3(
+                values.0.as_mut_ptr().cast::<i32>(),
+                0,
+                0,
+                &twiddles0_dbl,
+                &twiddles1_dbl,
+                &twiddles2_dbl,
+            );
+
+            let actual: [[BaseField; 16]; 8] = std::mem::transmute(values.0);
+            let mut expected: [BaseField; 8] =
+                std::mem::transmute(std::array::from_fn::<u32, 8, _>(|i| i as u32));
+            let twiddles0: [BaseField; 4] = std::mem::transmute(twiddles0);
+            let twiddles1: [BaseField; 2] = std::mem::transmute(twiddles1);
+            let twiddles2: [BaseField; 1] = std::mem::transmute(twiddles2);
+
+            // Reference result: in layer `l`, element `i` is paired with `i ^ (1 << l)` and
+            // each pair uses twiddle `i >> (l + 1)` of that layer.
+            let layers: [&[BaseField]; 3] = [&twiddles0, &twiddles1, &twiddles2];
+            for (layer, layer_twiddles) in layers.iter().enumerate() {
+                let bit = 1usize << layer;
+                // Visit each pair once, from its lower index.
+                for i in (0..8usize).filter(|&i| i & bit == 0) {
+                    let j = i | bit;
+                    let (mut v0, mut v1) = (expected[i], expected[j]);
+                    ibutterfly(&mut v0, &mut v1, layer_twiddles[i >> (layer + 1)]);
+                    (expected[i], expected[j]) = (v0, v1);
+                }
+            }
+
+            // All lanes of vector `k` started equal, so every lane must equal `expected[k]`.
+            for (lanes, want) in actual.iter().zip(expected) {
+                assert_eq!(*lanes, [want; 16]);
+            }
+        }
+    }
 }