Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

avx ifft3 #378

Merged
merged 1 commit into from
Mar 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 128 additions & 3 deletions src/core/backend/avx512/fft.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::arch::x86_64::{
__m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_min_epu32,
_mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_set1_epi64, _mm512_srli_epi64,
_mm512_sub_epi32,
__m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_load_epi32,
_mm512_min_epu32, _mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_set1_epi32,
_mm512_set1_epi64, _mm512_srli_epi64, _mm512_store_epi32, _mm512_sub_epi32,
};

/// An input to _mm512_permutex2var_epi32, and is used to interleave the even words of a
Expand Down Expand Up @@ -294,6 +294,67 @@ pub unsafe fn vecwise_ibutterflies(
)
}

/// Applies 3 butterfly layers on 8 vectors of 16 M31 elements.
/// Vectorized over the 16 elements of the vectors.
/// Used for radix-8 ifft.
/// Each butterfly layer, has 3 AVX butterflies.
/// Total of 12 AVX butterflies.
/// Parameters:
/// values - Pointer to the entire value array.
/// offset - The offset of the first value in the array.
/// step_in_vecs - The distance in the array, in AVX vectors, between each pair of values that
/// need to be transformed. For layer i this is i-4.
/// twiddles_dbl0/1/2 - The double of the twiddles for the 3 layers of butterflies.
/// Each layer has 4/2/1 twiddles.
/// # Safety
pub unsafe fn ifft3(
values: *mut i32,
offset: usize,
step_in_vecs: usize,
twiddles_dbl0: &[i32; 4],
twiddles_dbl1: &[i32; 2],
twiddles_dbl2: &[i32; 1],
) {
let step_in_u32s = step_in_vecs + 4;
// Load the 8 AVX vectors from the array.
let mut val0 = _mm512_load_epi32(values.add(offset + (0 << step_in_u32s)).cast_const());
let mut val1 = _mm512_load_epi32(values.add(offset + (1 << step_in_u32s)).cast_const());
let mut val2 = _mm512_load_epi32(values.add(offset + (2 << step_in_u32s)).cast_const());
let mut val3 = _mm512_load_epi32(values.add(offset + (3 << step_in_u32s)).cast_const());
let mut val4 = _mm512_load_epi32(values.add(offset + (4 << step_in_u32s)).cast_const());
let mut val5 = _mm512_load_epi32(values.add(offset + (5 << step_in_u32s)).cast_const());
let mut val6 = _mm512_load_epi32(values.add(offset + (6 << step_in_u32s)).cast_const());
let mut val7 = _mm512_load_epi32(values.add(offset + (7 << step_in_u32s)).cast_const());

// Apply the first layer of butterflies.
(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
(val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1]));
(val4, val5) = avx_ibutterfly(val4, val5, _mm512_set1_epi32(twiddles_dbl0[2]));
(val6, val7) = avx_ibutterfly(val6, val7, _mm512_set1_epi32(twiddles_dbl0[3]));

// Apply the second layer of butterflies.
(val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0]));
(val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0]));
(val4, val6) = avx_ibutterfly(val4, val6, _mm512_set1_epi32(twiddles_dbl1[1]));
(val5, val7) = avx_ibutterfly(val5, val7, _mm512_set1_epi32(twiddles_dbl1[1]));

// Apply the third layer of butterflies.
(val0, val4) = avx_ibutterfly(val0, val4, _mm512_set1_epi32(twiddles_dbl2[0]));
(val1, val5) = avx_ibutterfly(val1, val5, _mm512_set1_epi32(twiddles_dbl2[0]));
(val2, val6) = avx_ibutterfly(val2, val6, _mm512_set1_epi32(twiddles_dbl2[0]));
(val3, val7) = avx_ibutterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0]));

// Store the 8 AVX vectors back to the array.
_mm512_store_epi32(values.add(offset + (0 << step_in_u32s)), val0);
_mm512_store_epi32(values.add(offset + (1 << step_in_u32s)), val1);
_mm512_store_epi32(values.add(offset + (2 << step_in_u32s)), val2);
_mm512_store_epi32(values.add(offset + (3 << step_in_u32s)), val3);
_mm512_store_epi32(values.add(offset + (4 << step_in_u32s)), val4);
_mm512_store_epi32(values.add(offset + (5 << step_in_u32s)), val5);
_mm512_store_epi32(values.add(offset + (6 << step_in_u32s)), val6);
_mm512_store_epi32(values.add(offset + (7 << step_in_u32s)), val7);
}

// TODO(spapini): Move these to M31 AVX.

/// Adds two packed M31 elements, and reduces the result to the range [0,P].
Expand Down Expand Up @@ -327,6 +388,7 @@ mod tests {
use std::arch::x86_64::_mm512_setr_epi32;

use super::*;
use crate::core::backend::avx512::m31::PackedBaseField;
use crate::core::fft::{butterfly, ibutterfly};
use crate::core::fields::m31::BaseField;

Expand Down Expand Up @@ -558,4 +620,67 @@ mod tests {
}
}
}

#[test]
fn test_ifft3() {
unsafe {
let mut values: Vec<PackedBaseField> = (0..8)
.map(|i| {
PackedBaseField::from_array(std::array::from_fn(|_| {
BaseField::from_u32_unchecked(i)
}))
})
.collect();
let twiddles0 = [32, 33, 34, 35];
let twiddles1 = [36, 37];
let twiddles2 = [38];
let twiddles0_dbl = std::array::from_fn(|i| twiddles0[i] * 2);
let twiddles1_dbl = std::array::from_fn(|i| twiddles1[i] * 2);
let twiddles2_dbl = std::array::from_fn(|i| twiddles2[i] * 2);
ifft3(
std::mem::transmute(values.as_mut_ptr()),
0,
0,
&twiddles0_dbl,
&twiddles1_dbl,
&twiddles2_dbl,
);

let expected: [u32; 8] = std::array::from_fn(|i| i as u32);
let mut expected: [BaseField; 8] = std::mem::transmute(expected);
let twiddles0: [BaseField; 4] = std::mem::transmute(twiddles0);
let twiddles1: [BaseField; 2] = std::mem::transmute(twiddles1);
let twiddles2: [BaseField; 1] = std::mem::transmute(twiddles2);
for i in 0..8 {
let j = i ^ 1;
if i > j {
continue;
}
let (mut v0, mut v1) = (expected[i], expected[j]);
ibutterfly(&mut v0, &mut v1, twiddles0[i / 2]);
(expected[i], expected[j]) = (v0, v1);
}
for i in 0..8 {
let j = i ^ 2;
if i > j {
continue;
}
let (mut v0, mut v1) = (expected[i], expected[j]);
ibutterfly(&mut v0, &mut v1, twiddles1[i / 4]);
(expected[i], expected[j]) = (v0, v1);
}
for i in 0..8 {
let j = i ^ 4;
if i > j {
continue;
}
let (mut v0, mut v1) = (expected[i], expected[j]);
ibutterfly(&mut v0, &mut v1, twiddles2[0]);
(expected[i], expected[j]) = (v0, v1);
}
for i in 0..8 {
assert_eq!(values[i].to_array()[0], expected[i]);
}
}
}
}