Skip to content

Commit

Permalink
Support ifft for intermediate sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 4, 2024
1 parent 4c0af3a commit 95260d6
Showing 1 changed file with 130 additions and 6 deletions.
136 changes: 130 additions & 6 deletions src/core/backend/avx512/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,14 @@ pub unsafe fn ifft_lower_with_vecwise(
);
layer += 3;
}
if fft_layers - layer != 0 {
todo!()
match fft_layers - layer {
2 => {
ifft2_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h);
}
1 => {
ifft1_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h);
}
_ => {}
}
}
}
Expand Down Expand Up @@ -208,8 +214,15 @@ pub unsafe fn ifft_lower_without_vecwise(
);
layer += 3;
}
if fft_layers - layer != 0 {
todo!()
let fixed_layer = layer + VECS_LOG_SIZE;
match fft_layers - layer {
2 => {
ifft2_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h);
}
1 => {
ifft1_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h);
}
_ => {}
}
}
}
Expand Down Expand Up @@ -286,6 +299,53 @@ unsafe fn ifft3_loop(
}
}

/// Runs 2 ifft layers across the entire array.
/// Parameters:
/// values - Pointer to the entire value array, aligned to 64 bytes.
/// twiddle_dbl - The doubles of the twiddle factors for each of the 2 ifft layers.
/// loop_bits - The number of bits this loops needs to run on.
/// layer - The layer number of the first ifft layer to apply.
/// The layers `layer`, `layer + 1` are applied.
/// index - The index, iterated by the caller.
/// # Safety
unsafe fn ifft2_loop(values: *mut i32, twiddle_dbl: &[Vec<i32>], layer: usize, index: usize) {
let offset = index << (layer + 2);
for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) {
ifft2(
values,
offset + l,
layer,
std::array::from_fn(|i| {
*twiddle_dbl[0].get_unchecked((index * 2 + i) & (twiddle_dbl[0].len() - 1))
}),
std::array::from_fn(|i| {
*twiddle_dbl[1].get_unchecked((index + i) & (twiddle_dbl[1].len() - 1))
}),
);
}
}

/// Runs 1 ifft layer across the entire array.
/// Parameters:
/// values - Pointer to the entire value array, aligned to 64 bytes.
/// twiddle_dbl - The doubles of the twiddle factors for the ifft layer.
/// layer - The layer number of the ifft layer to apply.
/// index_h - The higher part of the index, iterated by the caller.
/// # Safety
unsafe fn ifft1_loop(values: *mut i32, twiddle_dbl: &[Vec<i32>], layer: usize, index: usize) {
let offset = index << (layer + 1);
for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) {
ifft1(
values,
offset + l,
layer,
std::array::from_fn(|i| {
*twiddle_dbl[0].get_unchecked((index + i) & (twiddle_dbl[0].len() - 1))
}),
);
}
}

/// Computes the butterfly operation for packed M31 elements.
/// val0 + t val1, val0 - t val1.
/// val0, val1 are packed M31 elements. 16 M31 words at each.
Expand Down Expand Up @@ -643,6 +703,68 @@ pub unsafe fn ifft3(
_mm512_store_epi32(values.add(offset + (7 << log_step)), val7);
}

/// Applies 2 butterfly layers on 4 vectors of 16 M31 elements.
/// Vectorized over the 16 elements of the vectors.
/// Used for radix-4 ifft.
/// Each butterfly layer, has 2 AVX butterflies.
/// Total of 4 AVX butterflies.
/// Parameters:
/// values - Pointer to the entire value array.
/// offset - The offset of the first value in the array.
/// log_step - The log of the distance in the array, in M31 elements, between each pair of
/// values that need to be transformed. For layer i this is i - 4.
/// twiddles_dbl0/1 - The double of the twiddles for the 2 layers of butterflies.
/// Each layer has 2/1 twiddles.
/// # Safety
pub unsafe fn ifft2(
values: *mut i32,
offset: usize,
log_step: usize,
twiddles_dbl0: [i32; 2],
twiddles_dbl1: [i32; 1],
) {
// Load the 4 AVX vectors from the array.
let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const());
let mut val2 = _mm512_load_epi32(values.add(offset + (2 << log_step)).cast_const());
let mut val3 = _mm512_load_epi32(values.add(offset + (3 << log_step)).cast_const());

// Apply the first layer of butterflies.
(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
(val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1]));

// Apply the second layer of butterflies.
(val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0]));
(val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0]));

// Store the 4 AVX vectors back to the array.
_mm512_store_epi32(values.add(offset + (0 << log_step)), val0);
_mm512_store_epi32(values.add(offset + (1 << log_step)), val1);
_mm512_store_epi32(values.add(offset + (2 << log_step)), val2);
_mm512_store_epi32(values.add(offset + (3 << log_step)), val3);
}

/// Applies 1 butterfly layers on 2 vectors of 16 M31 elements.
/// Vectorized over the 16 elements of the vectors.
/// Parameters:
/// values - Pointer to the entire value array.
/// offset - The offset of the first value in the array.
/// log_step - The log of the distance in the array, in M31 elements, between each pair of
/// values that need to be transformed. For layer i this is i - 4.
/// twiddles_dbl0 - The double of the twiddles for the butterfly layer.
/// # Safety
pub unsafe fn ifft1(values: *mut i32, offset: usize, log_step: usize, twiddles_dbl0: [i32; 1]) {
// Load the 2 AVX vectors from the array.
let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const());

(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));

// Store the 2 AVX vectors back to the array.
_mm512_store_epi32(values.add(offset + (0 << log_step)), val0);
_mm512_store_epi32(values.add(offset + (1 << log_step)), val1);
}

// TODO(spapini): Move these to M31 AVX.

/// Adds two packed M31 elements, and reduces the result to the range [0,P].
Expand Down Expand Up @@ -1014,7 +1136,9 @@ mod tests {
}

#[test]
fn test_ifft_full_11() {
run_ifft_full_test(3 + 3 + 5);
fn test_ifft_full() {
for i in 5..=5 + 3 + 3 {
run_ifft_full_test(i);
}
}
}

0 comments on commit 95260d6

Please sign in to comment.