Skip to content

Commit

Permalink
clean up fft
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 4, 2024
1 parent ba639e3 commit a89c843
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions src/core/backend/avx512/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,48 @@ pub unsafe fn ifft3(
_mm512_store_epi32(values.add(offset + (7 << log_step)), val7);
}

/// Runs two layers of inverse-FFT butterflies, in place, over four 512-bit vectors.
///
/// Vector `k` (for `k` in `0..4`) is the 16 consecutive `i32`s that start at element index
/// `(offset + (k << log_step)) << 4` of `values`. The first layer pairs vectors `(0, 1)` and
/// `(2, 3)`, each with its own entry of `twiddles_dbl0`; the second layer pairs `(0, 2)` and
/// `(1, 3)`, both with `twiddles_dbl1[0]`. Each twiddle is broadcast to all 16 lanes.
///
/// NOTE(review): the `_dbl` suffix suggests the twiddles are expected pre-doubled — confirm
/// against `avx_ibutterfly`.
///
/// # Safety
/// For every `k` in `0..4`, `values.add((offset + (k << log_step)) << 4)` must be 64-byte
/// aligned and valid for reading and writing 16 `i32`s, because aligned AVX-512 load/store
/// intrinsics are used. The executing CPU must support those AVX-512 instructions.
pub unsafe fn ifft2(
    values: *mut i32,
    offset: usize,
    log_step: usize,
    twiddles_dbl0: [i32; 2],
    twiddles_dbl1: [i32; 1],
) {
    // Resolve the address of each vector once; the loads and the stores share them.
    let ptr0 = values.add((offset + (0 << log_step)) << 4);
    let ptr1 = values.add((offset + (1 << log_step)) << 4);
    let ptr2 = values.add((offset + (2 << log_step)) << 4);
    let ptr3 = values.add((offset + (3 << log_step)) << 4);

    // Load the four vectors.
    let mut v0 = _mm512_load_epi32(ptr0.cast_const());
    let mut v1 = _mm512_load_epi32(ptr1.cast_const());
    let mut v2 = _mm512_load_epi32(ptr2.cast_const());
    let mut v3 = _mm512_load_epi32(ptr3.cast_const());

    // First layer: neighbouring vectors, one twiddle per pair.
    let layer0_twiddle_lo = _mm512_set1_epi32(twiddles_dbl0[0]);
    let layer0_twiddle_hi = _mm512_set1_epi32(twiddles_dbl0[1]);
    (v0, v1) = avx_ibutterfly(v0, v1, layer0_twiddle_lo);
    (v2, v3) = avx_ibutterfly(v2, v3, layer0_twiddle_hi);

    // Second layer: vectors two apart, a single twiddle shared by both pairs.
    let layer1_twiddle = _mm512_set1_epi32(twiddles_dbl1[0]);
    (v0, v2) = avx_ibutterfly(v0, v2, layer1_twiddle);
    (v1, v3) = avx_ibutterfly(v1, v3, layer1_twiddle);

    // Write the results back to where they were read from.
    _mm512_store_epi32(ptr0, v0);
    _mm512_store_epi32(ptr1, v1);
    _mm512_store_epi32(ptr2, v2);
    _mm512_store_epi32(ptr3, v3);
}

/// Runs a single inverse-FFT butterfly, in place, over two 512-bit vectors.
///
/// Vector `k` (for `k` in `0..2`) is the 16 consecutive `i32`s that start at element index
/// `(offset + (k << log_step)) << 4` of `values`. The pair is combined by `avx_ibutterfly`
/// using `twiddles_dbl0[0]` broadcast to all 16 lanes.
///
/// NOTE(review): the `_dbl` suffix suggests the twiddle is expected pre-doubled — confirm
/// against `avx_ibutterfly`.
///
/// # Safety
/// For every `k` in `0..2`, `values.add((offset + (k << log_step)) << 4)` must be 64-byte
/// aligned and valid for reading and writing 16 `i32`s, because aligned AVX-512 load/store
/// intrinsics are used. The executing CPU must support those AVX-512 instructions.
pub unsafe fn ifft1(values: *mut i32, offset: usize, log_step: usize, twiddles_dbl0: [i32; 1]) {
    // Addresses of the two vectors, shared by the load and the store.
    let lo_ptr = values.add((offset + (0 << log_step)) << 4);
    let hi_ptr = values.add((offset + (1 << log_step)) << 4);

    let lo = _mm512_load_epi32(lo_ptr.cast_const());
    let hi = _mm512_load_epi32(hi_ptr.cast_const());

    let twiddle = _mm512_set1_epi32(twiddles_dbl0[0]);
    let (lo, hi) = avx_ibutterfly(lo, hi, twiddle);

    _mm512_store_epi32(lo_ptr, lo);
    _mm512_store_epi32(hi_ptr, hi);
}

// TODO(spapini): Move these to M31 AVX.

/// Adds two packed M31 elements, and reduces the result to the range [0,P].
Expand Down

0 comments on commit a89c843

Please sign in to comment.