Skip to content

Commit

Permalink
avx ifft3
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 3, 2024
1 parent 971615d commit ddccc6a
Showing 1 changed file with 107 additions and 3 deletions.
110 changes: 107 additions & 3 deletions src/core/backend/avx512/fft.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::arch::x86_64::{
__m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_min_epu32,
_mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_set1_epi64, _mm512_srli_epi64,
_mm512_sub_epi32,
__m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_load_epi32,
_mm512_min_epu32, _mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_set1_epi32,
_mm512_set1_epi64, _mm512_srli_epi64, _mm512_store_epi32, _mm512_sub_epi32,
};

/// An input to _mm512_permutex2var_epi32, and is used to interleave the even words of a
Expand Down Expand Up @@ -294,6 +294,52 @@ pub unsafe fn vecwise_ibutterflies(
)
}

/// # Safety
pub unsafe fn ifft3(
values: *mut i32,
offset: usize,
step: usize,
twiddles_dbl0: &[i32; 4],
twiddles_dbl1: &[i32; 2],
twiddles_dbl2: &[i32; 1],
) {
let u32_step = step + 4;
// load
let mut val0 = _mm512_load_epi32(values.add(offset + (0 << u32_step)).cast_const());
let mut val1 = _mm512_load_epi32(values.add(offset + (1 << u32_step)).cast_const());
let mut val2 = _mm512_load_epi32(values.add(offset + (2 << u32_step)).cast_const());
let mut val3 = _mm512_load_epi32(values.add(offset + (3 << u32_step)).cast_const());
let mut val4 = _mm512_load_epi32(values.add(offset + (4 << u32_step)).cast_const());
let mut val5 = _mm512_load_epi32(values.add(offset + (5 << u32_step)).cast_const());
let mut val6 = _mm512_load_epi32(values.add(offset + (6 << u32_step)).cast_const());
let mut val7 = _mm512_load_epi32(values.add(offset + (7 << u32_step)).cast_const());

(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
(val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1]));
(val4, val5) = avx_ibutterfly(val4, val5, _mm512_set1_epi32(twiddles_dbl0[2]));
(val6, val7) = avx_ibutterfly(val6, val7, _mm512_set1_epi32(twiddles_dbl0[3]));

(val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0]));
(val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0]));
(val4, val6) = avx_ibutterfly(val4, val6, _mm512_set1_epi32(twiddles_dbl1[1]));
(val5, val7) = avx_ibutterfly(val5, val7, _mm512_set1_epi32(twiddles_dbl1[1]));

(val0, val4) = avx_ibutterfly(val0, val4, _mm512_set1_epi32(twiddles_dbl2[0]));
(val1, val5) = avx_ibutterfly(val1, val5, _mm512_set1_epi32(twiddles_dbl2[0]));
(val2, val6) = avx_ibutterfly(val2, val6, _mm512_set1_epi32(twiddles_dbl2[0]));
(val3, val7) = avx_ibutterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0]));

// store
_mm512_store_epi32(values.add(offset + (0 << u32_step)), val0);
_mm512_store_epi32(values.add(offset + (1 << u32_step)), val1);
_mm512_store_epi32(values.add(offset + (2 << u32_step)), val2);
_mm512_store_epi32(values.add(offset + (3 << u32_step)), val3);
_mm512_store_epi32(values.add(offset + (4 << u32_step)), val4);
_mm512_store_epi32(values.add(offset + (5 << u32_step)), val5);
_mm512_store_epi32(values.add(offset + (6 << u32_step)), val6);
_mm512_store_epi32(values.add(offset + (7 << u32_step)), val7);
}

// TODO(spapini): Move these to M31 AVX.

/// Adds two packed M31 elements, and reduces the result to the range [0,P].
Expand Down Expand Up @@ -558,4 +604,62 @@ mod tests {
}
}
}

/// Checks `ifft3` against a scalar reference built from `ibutterfly`.
///
/// Vector `i` is filled with the value `i` in every lane, so after the transform every lane
/// of vector `i` must equal the `i`-th output of the scalar 8-point computation.
#[test]
fn test_ifft3() {
    /// Backing storage for eight 16-lane vectors.
    ///
    /// `ifft3` uses aligned 512-bit loads and stores, which require 64-byte alignment; a plain
    /// `Vec<[i32; 16]>` only guarantees 4-byte alignment.
    #[repr(C, align(64))]
    struct AlignedValues([[i32; 16]; 8]);

    let mut values = AlignedValues(std::array::from_fn(|i| [i as i32; 16]));
    let twiddles0: [i32; 4] = [32, 33, 34, 35];
    let twiddles1: [i32; 2] = [36, 37];
    let twiddles2: [i32; 1] = [38];
    let twiddles0_dbl = twiddles0.map(|t| t * 2);
    let twiddles1_dbl = twiddles1.map(|t| t * 2);
    let twiddles2_dbl = twiddles2.map(|t| t * 2);

    // SAFETY: `values` is 64-byte aligned and holds 8 * 16 contiguous `i32`s. With
    // `offset == 0` and `step == 0`, `ifft3` touches exactly the element ranges
    // `[16 * i, 16 * i + 16)` for `i` in `0..8`, all inside the buffer and each starting on a
    // 64-byte boundary. The test must run on a CPU with AVX-512F.
    unsafe {
        ifft3(
            values.0.as_mut_ptr().cast::<i32>(),
            0,
            0,
            &twiddles0_dbl,
            &twiddles1_dbl,
            &twiddles2_dbl,
        );
    }

    let inputs: [u32; 8] = std::array::from_fn(|i| i as u32);
    // SAFETY: these are fixed-size array transmutes, so the sizes are checked at compile time.
    // As in the original test, this assumes `BaseField` is a plain 32-bit wrapper for which
    // these bit patterns are valid values - TODO confirm against the `BaseField` definition.
    let (actual, mut expected, twiddles0, twiddles1, twiddles2) = unsafe {
        let actual: [[BaseField; 16]; 8] = std::mem::transmute(values.0);
        let expected: [BaseField; 8] = std::mem::transmute(inputs);
        let twiddles0: [BaseField; 4] = std::mem::transmute(twiddles0);
        let twiddles1: [BaseField; 2] = std::mem::transmute(twiddles1);
        let twiddles2: [BaseField; 1] = std::mem::transmute(twiddles2);
        (actual, expected, twiddles0, twiddles1, twiddles2)
    };

    // Scalar reference: layer `l` pairs index `i` with `i | (1 << l)` and uses twiddle
    // `i >> (l + 1)` of that layer.
    let layers: [&[BaseField]; 3] = [&twiddles0, &twiddles1, &twiddles2];
    for (layer, twiddles) in layers.iter().enumerate() {
        let bit = 1usize << layer;
        for i in (0..8usize).filter(|i| i & bit == 0) {
            let j = i | bit;
            let (mut v0, mut v1) = (expected[i], expected[j]);
            ibutterfly(&mut v0, &mut v1, twiddles[i >> (layer + 1)]);
            (expected[i], expected[j]) = (v0, v1);
        }
    }

    // Every lane of a vector started with the same value, so every lane must match.
    for (i, (row, want)) in actual.iter().zip(expected.iter()).enumerate() {
        for (lane, got) in row.iter().enumerate() {
            assert_eq!(got, want, "mismatch at vector {i}, lane {lane}");
        }
    }
}
}

0 comments on commit ddccc6a

Please sign in to comment.