Skip to content

Commit

Permalink
avx ifft3
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 4, 2024
1 parent 9709cd8 commit 9afa4c2
Showing 1 changed file with 128 additions and 3 deletions.
131 changes: 128 additions & 3 deletions src/core/backend/avx512/fft.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::arch::x86_64::{
__m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_min_epu32,
_mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_set1_epi64, _mm512_srli_epi64,
_mm512_sub_epi32,
__m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_load_epi32,
_mm512_min_epu32, _mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_set1_epi32,
_mm512_set1_epi64, _mm512_srli_epi64, _mm512_store_epi32, _mm512_sub_epi32,
};

/// An input to _mm512_permutex2var_epi32, and is used to interleave the even words of a
Expand Down Expand Up @@ -294,6 +294,67 @@ pub unsafe fn vecwise_ibutterflies(
)
}

/// Applies 3 butterfly layers on 8 vectors of 16 M31 elements.
/// Vectorized over the 16 elements of the vectors.
/// Used for radix-8 ifft.
/// Each butterfly layer, has 3 AVX butterflies.
/// Total of 12 AVX butterflies.
/// Parameters:
/// values - Pointer to the entire value array.
/// offset - The offset of the first value in the array.
/// step_in_vecs - The distance in the array, in AVX vectors, between each pair of values that
/// need to be transformed. For layer i this is i-4.
/// twiddles_dbl0/1/2 - The double of the twiddles for the 3 layers of butterflies.
/// Each layer has 4/2/1 twiddles.
/// # Safety
pub unsafe fn ifft3(
values: *mut i32,
offset: usize,
step_in_vecs: usize,
twiddles_dbl0: &[i32; 4],
twiddles_dbl1: &[i32; 2],
twiddles_dbl2: &[i32; 1],
) {
let step_in_u32s = step_in_vecs + 4;
// Load the 8 AVX vectors from the array.
let mut val0 = _mm512_load_epi32(values.add(offset + (0 << step_in_u32s)).cast_const());
let mut val1 = _mm512_load_epi32(values.add(offset + (1 << step_in_u32s)).cast_const());
let mut val2 = _mm512_load_epi32(values.add(offset + (2 << step_in_u32s)).cast_const());
let mut val3 = _mm512_load_epi32(values.add(offset + (3 << step_in_u32s)).cast_const());
let mut val4 = _mm512_load_epi32(values.add(offset + (4 << step_in_u32s)).cast_const());
let mut val5 = _mm512_load_epi32(values.add(offset + (5 << step_in_u32s)).cast_const());
let mut val6 = _mm512_load_epi32(values.add(offset + (6 << step_in_u32s)).cast_const());
let mut val7 = _mm512_load_epi32(values.add(offset + (7 << step_in_u32s)).cast_const());

// Apply the first layer of butterflies.
(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
(val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1]));
(val4, val5) = avx_ibutterfly(val4, val5, _mm512_set1_epi32(twiddles_dbl0[2]));
(val6, val7) = avx_ibutterfly(val6, val7, _mm512_set1_epi32(twiddles_dbl0[3]));

// Apply the second layer of butterflies.
(val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0]));
(val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0]));
(val4, val6) = avx_ibutterfly(val4, val6, _mm512_set1_epi32(twiddles_dbl1[1]));
(val5, val7) = avx_ibutterfly(val5, val7, _mm512_set1_epi32(twiddles_dbl1[1]));

// Apply the third layer of butterflies.
(val0, val4) = avx_ibutterfly(val0, val4, _mm512_set1_epi32(twiddles_dbl2[0]));
(val1, val5) = avx_ibutterfly(val1, val5, _mm512_set1_epi32(twiddles_dbl2[0]));
(val2, val6) = avx_ibutterfly(val2, val6, _mm512_set1_epi32(twiddles_dbl2[0]));
(val3, val7) = avx_ibutterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0]));

// Store the 8 AVX vectors back to the array.
_mm512_store_epi32(values.add(offset + (0 << step_in_u32s)), val0);
_mm512_store_epi32(values.add(offset + (1 << step_in_u32s)), val1);
_mm512_store_epi32(values.add(offset + (2 << step_in_u32s)), val2);
_mm512_store_epi32(values.add(offset + (3 << step_in_u32s)), val3);
_mm512_store_epi32(values.add(offset + (4 << step_in_u32s)), val4);
_mm512_store_epi32(values.add(offset + (5 << step_in_u32s)), val5);
_mm512_store_epi32(values.add(offset + (6 << step_in_u32s)), val6);
_mm512_store_epi32(values.add(offset + (7 << step_in_u32s)), val7);
}

// TODO(spapini): Move these to M31 AVX.

/// Adds two packed M31 elements, and reduces the result to the range [0,P].
Expand Down Expand Up @@ -327,6 +388,7 @@ mod tests {
use std::arch::x86_64::_mm512_setr_epi32;

use super::*;
use crate::core::backend::avx512::m31::PackedBaseField;
use crate::core::fft::{butterfly, ibutterfly};
use crate::core::fields::m31::BaseField;

Expand Down Expand Up @@ -558,4 +620,67 @@ mod tests {
}
}
}

#[test]
fn test_ifft3() {
unsafe {
let mut values: Vec<PackedBaseField> = (0..8)
.map(|i| {
PackedBaseField::from_array(std::array::from_fn(|_| {
BaseField::from_u32_unchecked(i)
}))
})
.collect();
let twiddles0 = [32, 33, 34, 35];
let twiddles1 = [36, 37];
let twiddles2 = [38];
let twiddles0_dbl = std::array::from_fn(|i| twiddles0[i] * 2);
let twiddles1_dbl = std::array::from_fn(|i| twiddles1[i] * 2);
let twiddles2_dbl = std::array::from_fn(|i| twiddles2[i] * 2);
ifft3(
std::mem::transmute(values.as_mut_ptr()),
0,
0,
&twiddles0_dbl,
&twiddles1_dbl,
&twiddles2_dbl,
);

let expected: [u32; 8] = std::array::from_fn(|i| i as u32);
let mut expected: [BaseField; 8] = std::mem::transmute(expected);
let twiddles0: [BaseField; 4] = std::mem::transmute(twiddles0);
let twiddles1: [BaseField; 2] = std::mem::transmute(twiddles1);
let twiddles2: [BaseField; 1] = std::mem::transmute(twiddles2);
for i in 0..8 {
let j = i ^ 1;
if i > j {
continue;
}
let (mut v0, mut v1) = (expected[i], expected[j]);
ibutterfly(&mut v0, &mut v1, twiddles0[i / 2]);
(expected[i], expected[j]) = (v0, v1);
}
for i in 0..8 {
let j = i ^ 2;
if i > j {
continue;
}
let (mut v0, mut v1) = (expected[i], expected[j]);
ibutterfly(&mut v0, &mut v1, twiddles1[i / 4]);
(expected[i], expected[j]) = (v0, v1);
}
for i in 0..8 {
let j = i ^ 4;
if i > j {
continue;
}
let (mut v0, mut v1) = (expected[i], expected[j]);
ibutterfly(&mut v0, &mut v1, twiddles2[0]);
(expected[i], expected[j]) = (v0, v1);
}
for i in 0..8 {
assert_eq!(values[i].to_array()[0], expected[i]);
}
}
}
}

0 comments on commit 9afa4c2

Please sign in to comment.