Skip to content

Commit

Permalink
Use PackedBaseField in avx fft
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 11, 2024
1 parent 41f6f0a commit 1768412
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 157 deletions.
125 changes: 63 additions & 62 deletions src/core/backend/avx512/fft/ifft.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
//! Inverse fft.
use std::arch::x86_64::{
__m512i, _mm512_broadcast_i32x4, _mm512_load_epi32, _mm512_mul_epu32,
_mm512_permutex2var_epi32, _mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64,
_mm512_store_epi32,
__m512i, _mm512_broadcast_i32x4, _mm512_mul_epu32, _mm512_permutex2var_epi32,
_mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64,
};

use super::{
add_mod_p, compute_first_twiddles, sub_mod_p, EVENS_CONCAT_EVENS, EVENS_INTERLEAVE_EVENS,
ODDS_CONCAT_ODDS, ODDS_INTERLEAVE_ODDS,
compute_first_twiddles, EVENS_CONCAT_EVENS, EVENS_INTERLEAVE_EVENS, ODDS_CONCAT_ODDS,
ODDS_INTERLEAVE_ODDS,
};
use crate::core::backend::avx512::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
use crate::core::backend::avx512::VECS_LOG_SIZE;
use crate::core::backend::avx512::{PackedBaseField, VECS_LOG_SIZE};
use crate::core::circle::Coset;
use crate::core::fields::FieldExpOps;
use crate::core::utils::bit_reverse;
Expand Down Expand Up @@ -159,8 +158,8 @@ unsafe fn ifft_vecwise_loop(
) {
for index_l in 0..(1 << loop_bits) {
let index = (index_h << loop_bits) + index_l;
let mut val0 = _mm512_load_epi32(values.add(index * 32).cast_const());
let mut val1 = _mm512_load_epi32(values.add(index * 32 + 16).cast_const());
let mut val0 = PackedBaseField::load(values.add(index * 32).cast_const());
let mut val1 = PackedBaseField::load(values.add(index * 32 + 16).cast_const());
(val0, val1) = vecwise_ibutterflies(
val0,
val1,
Expand All @@ -173,8 +172,8 @@ unsafe fn ifft_vecwise_loop(
val1,
_mm512_set1_epi32(*twiddle_dbl[3].get_unchecked(index)),
);
_mm512_store_epi32(values.add(index * 32), val0);
_mm512_store_epi32(values.add(index * 32 + 16), val1);
val0.store(values.add(index * 32));
val1.store(values.add(index * 32 + 16));
}
}

Expand Down Expand Up @@ -271,16 +270,16 @@ unsafe fn ifft1_loop(values: *mut i32, twiddle_dbl: &[&[i32]], layer: usize, ind
/// # Safety
/// This function is safe.
pub unsafe fn avx_ibutterfly(
val0: __m512i,
val1: __m512i,
val0: PackedBaseField,
val1: PackedBaseField,
twiddle_dbl: __m512i,
) -> (__m512i, __m512i) {
let r0 = add_mod_p(val0, val1);
let r1 = sub_mod_p(val0, val1);
) -> (PackedBaseField, PackedBaseField) {
let r0 = val0 + val1;
let r1 = val0 - val1;

// Extract the even and odd parts of r1 and twiddle_dbl, and spread as 8 64bit values.
let r1_e = r1;
let r1_o = _mm512_srli_epi64(r1, 32);
let r1_e = r1.0;
let r1_o = _mm512_srli_epi64(r1.0, 32);
let twiddle_dbl_e = twiddle_dbl;
let twiddle_dbl_o = _mm512_srli_epi64(twiddle_dbl, 32);

Expand All @@ -307,7 +306,7 @@ pub unsafe fn avx_ibutterfly(
let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, ODDS_INTERLEAVE_ODDS, prod_o_dbl);
// prod_hs - |0|prod_o_h|0|prod_e_h|

let prod = add_mod_p(prod_ls, prod_hs);
let prod = PackedBaseField(prod_ls) + PackedBaseField(prod_hs);

(r0, prod)
}
Expand All @@ -325,12 +324,12 @@ pub unsafe fn avx_ibutterfly(
/// # Safety
/// This function is safe.
pub unsafe fn vecwise_ibutterflies(
mut val0: __m512i,
mut val1: __m512i,
mut val0: PackedBaseField,
mut val1: PackedBaseField,
twiddle1_dbl: [i32; 8],
twiddle2_dbl: [i32; 4],
twiddle3_dbl: [i32; 2],
) -> (__m512i, __m512i) {
) -> (PackedBaseField, PackedBaseField) {
// TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly.

// Each avx_ibutterfly take 2 512-bit registers, and does 16 butterflies element by element.
Expand All @@ -354,15 +353,15 @@ pub unsafe fn vecwise_ibutterflies(

// Apply the permutation, resulting in indexing d:iabc.
(val0, val1) = (
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
val0.permute_with(EVENS_CONCAT_EVENS, val1),
val0.permute_with(ODDS_CONCAT_ODDS, val1),
);
(val0, val1) = avx_ibutterfly(val0, val1, t0);

// Apply the permutation, resulting in indexing c:diab.
(val0, val1) = (
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
val0.permute_with(EVENS_CONCAT_EVENS, val1),
val0.permute_with(ODDS_CONCAT_ODDS, val1),
);
(val0, val1) = avx_ibutterfly(val0, val1, t1);

Expand All @@ -371,8 +370,8 @@ pub unsafe fn vecwise_ibutterflies(
let t = _mm512_broadcast_i32x4(std::mem::transmute(twiddle2_dbl));
// Apply the permutation, resulting in indexing b:cdia.
(val0, val1) = (
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
val0.permute_with(EVENS_CONCAT_EVENS, val1),
val0.permute_with(ODDS_CONCAT_ODDS, val1),
);
(val0, val1) = avx_ibutterfly(val0, val1, t);

Expand All @@ -381,15 +380,15 @@ pub unsafe fn vecwise_ibutterflies(
let t = _mm512_set1_epi64(std::mem::transmute(twiddle3_dbl));
// Apply the permutation, resulting in indexing a:bcid.
(val0, val1) = (
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
val0.permute_with(EVENS_CONCAT_EVENS, val1),
val0.permute_with(ODDS_CONCAT_ODDS, val1),
);
(val0, val1) = avx_ibutterfly(val0, val1, t);

// Apply the permutation, resulting in indexing i:abcd.
(
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
val0.permute_with(EVENS_CONCAT_EVENS, val1),
val0.permute_with(ODDS_CONCAT_ODDS, val1),
)
}

Expand Down Expand Up @@ -439,14 +438,14 @@ pub unsafe fn ifft3(
twiddles_dbl2: [i32; 1],
) {
// Load the 8 AVX vectors from the array.
let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const());
let mut val2 = _mm512_load_epi32(values.add(offset + (2 << log_step)).cast_const());
let mut val3 = _mm512_load_epi32(values.add(offset + (3 << log_step)).cast_const());
let mut val4 = _mm512_load_epi32(values.add(offset + (4 << log_step)).cast_const());
let mut val5 = _mm512_load_epi32(values.add(offset + (5 << log_step)).cast_const());
let mut val6 = _mm512_load_epi32(values.add(offset + (6 << log_step)).cast_const());
let mut val7 = _mm512_load_epi32(values.add(offset + (7 << log_step)).cast_const());
let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const());
let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const());
let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const());
let mut val4 = PackedBaseField::load(values.add(offset + (4 << log_step)).cast_const());
let mut val5 = PackedBaseField::load(values.add(offset + (5 << log_step)).cast_const());
let mut val6 = PackedBaseField::load(values.add(offset + (6 << log_step)).cast_const());
let mut val7 = PackedBaseField::load(values.add(offset + (7 << log_step)).cast_const());

// Apply the first layer of ibutterflies.
(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
Expand All @@ -467,14 +466,14 @@ pub unsafe fn ifft3(
(val3, val7) = avx_ibutterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0]));

// Store the 8 AVX vectors back to the array.
_mm512_store_epi32(values.add(offset + (0 << log_step)), val0);
_mm512_store_epi32(values.add(offset + (1 << log_step)), val1);
_mm512_store_epi32(values.add(offset + (2 << log_step)), val2);
_mm512_store_epi32(values.add(offset + (3 << log_step)), val3);
_mm512_store_epi32(values.add(offset + (4 << log_step)), val4);
_mm512_store_epi32(values.add(offset + (5 << log_step)), val5);
_mm512_store_epi32(values.add(offset + (6 << log_step)), val6);
_mm512_store_epi32(values.add(offset + (7 << log_step)), val7);
val0.store(values.add(offset + (0 << log_step)));
val1.store(values.add(offset + (1 << log_step)));
val2.store(values.add(offset + (2 << log_step)));
val3.store(values.add(offset + (3 << log_step)));
val4.store(values.add(offset + (4 << log_step)));
val5.store(values.add(offset + (5 << log_step)));
val6.store(values.add(offset + (6 << log_step)));
val7.store(values.add(offset + (7 << log_step)));
}

/// Applies 2 ibutterfly layers on 4 vectors of 16 M31 elements.
Expand All @@ -498,10 +497,10 @@ pub unsafe fn ifft2(
twiddles_dbl1: [i32; 1],
) {
// Load the 4 AVX vectors from the array.
let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const());
let mut val2 = _mm512_load_epi32(values.add(offset + (2 << log_step)).cast_const());
let mut val3 = _mm512_load_epi32(values.add(offset + (3 << log_step)).cast_const());
let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const());
let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const());
let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const());

// Apply the first layer of butterflies.
(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
Expand All @@ -512,10 +511,10 @@ pub unsafe fn ifft2(
(val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0]));

// Store the 4 AVX vectors back to the array.
_mm512_store_epi32(values.add(offset + (0 << log_step)), val0);
_mm512_store_epi32(values.add(offset + (1 << log_step)), val1);
_mm512_store_epi32(values.add(offset + (2 << log_step)), val2);
_mm512_store_epi32(values.add(offset + (3 << log_step)), val3);
val0.store(values.add(offset + (0 << log_step)));
val1.store(values.add(offset + (1 << log_step)));
val2.store(values.add(offset + (2 << log_step)));
val3.store(values.add(offset + (3 << log_step)));
}

/// Applies 1 ibutterfly layers on 2 vectors of 16 M31 elements.
Expand All @@ -529,14 +528,14 @@ pub unsafe fn ifft2(
/// # Safety
pub unsafe fn ifft1(values: *mut i32, offset: usize, log_step: usize, twiddles_dbl0: [i32; 1]) {
// Load the 2 AVX vectors from the array.
let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const());
let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const());

(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));

// Store the 2 AVX vectors back to the array.
_mm512_store_epi32(values.add(offset + (0 << log_step)), val0);
_mm512_store_epi32(values.add(offset + (1 << log_step)), val1);
val0.store(values.add(offset + (0 << log_step)));
val1.store(values.add(offset + (1 << log_step)));
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
Expand All @@ -556,10 +555,12 @@ mod tests {
#[test]
fn test_ibutterfly() {
unsafe {
let val0 = _mm512_setr_epi32(2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
let val1 = _mm512_setr_epi32(
let val0 = PackedBaseField(_mm512_setr_epi32(
2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
));
let val1 = PackedBaseField(_mm512_setr_epi32(
3, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
);
));
let twiddle = _mm512_setr_epi32(
1177558791, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
);
Expand Down
33 changes: 2 additions & 31 deletions src/core/backend/avx512/fft/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::arch::x86_64::{
__m512i, _mm512_add_epi32, _mm512_broadcast_i64x4, _mm512_load_epi32, _mm512_min_epu32,
_mm512_permutexvar_epi32, _mm512_store_epi32, _mm512_sub_epi32, _mm512_xor_epi32,
__m512i, _mm512_broadcast_i64x4, _mm512_load_epi32, _mm512_permutexvar_epi32,
_mm512_store_epi32, _mm512_xor_epi32,
};

pub mod ifft;
Expand Down Expand Up @@ -55,7 +55,6 @@ const HHALF_INTERLEAVE_HHALF: __m512i = unsafe {
0b01101, 0b11101, 0b01110, 0b11110, 0b01111, 0b11111,
])
};
const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) };

pub const CACHED_FFT_LOG_SIZE: usize = 16;
pub const MIN_FFT_LOG_SIZE: usize = 5;
Expand Down Expand Up @@ -131,34 +130,6 @@ unsafe fn compute_first_twiddles(twiddle1_dbl: [i32; 8]) -> (__m512i, __m512i) {
(t0, t1)
}

// TODO(spapini): Move these to M31 AVX.

/// Adds two packed M31 elements, and reduces the result to the range [0,P].
/// Each value is assumed to be in unreduced form, [0, P] including P.
/// # Safety
/// This function is safe.
pub unsafe fn add_mod_p(a: __m512i, b: __m512i) -> __m512i {
// Add word by word. Each word is in the range [0, 2P].
let c = _mm512_add_epi32(a, b);
// Apply min(c, c-P) to each word.
// When c in [P,2P], then c-P in [0,P] which is always less than [P,2P].
// When c in [0,P-1], then c-P in [2^32-P,2^32-1] which is always greater than [0,P-1].
_mm512_min_epu32(c, _mm512_sub_epi32(c, P))
}

/// Subtracts two packed M31 elements, and reduces the result to the range [0,P].
/// Each value is assumed to be in unreduced form, [0, P] including P.
/// # Safety
/// This function is safe.
pub unsafe fn sub_mod_p(a: __m512i, b: __m512i) -> __m512i {
// Subtract word by word. Each word is in the range [-P, P].
let c = _mm512_sub_epi32(a, b);
// Apply min(c, c+P) to each word.
// When c in [0,P], then c+P in [P,2P] which is always greater than [0,P].
// When c in [2^32-P,2^32-1], then c+P in [0,P-1] which is always less than [2^32-P,2^32-1].
_mm512_min_epu32(_mm512_add_epi32(c, P), c)
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[cfg(test)]
mod tests {
Expand Down
Loading

0 comments on commit 1768412

Please sign in to comment.