Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use PackedBaseField in avx fft #455

Merged
merged 1 commit into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 58 additions & 75 deletions src/core/backend/avx512/fft/ifft.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
//! Inverse fft.

use std::arch::x86_64::{
__m512i, _mm512_broadcast_i32x4, _mm512_load_epi32, _mm512_mul_epu32,
_mm512_permutex2var_epi32, _mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64,
_mm512_store_epi32,
__m512i, _mm512_broadcast_i32x4, _mm512_mul_epu32, _mm512_permutex2var_epi32,
_mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64,
};

use super::{
add_mod_p, compute_first_twiddles, sub_mod_p, EVENS_CONCAT_EVENS, EVENS_INTERLEAVE_EVENS,
ODDS_CONCAT_ODDS, ODDS_INTERLEAVE_ODDS,
};
use crate::core::backend::avx512::fft::transpose_vecs;
use crate::core::backend::avx512::{MIN_FFT_LOG_SIZE, VECS_LOG_SIZE};
use super::{compute_first_twiddles, EVENS_INTERLEAVE_EVENS, ODDS_INTERLEAVE_ODDS};
use crate::core::backend::avx512::fft::{transpose_vecs, MIN_FFT_LOG_SIZE};
use crate::core::backend::avx512::{PackedBaseField, VECS_LOG_SIZE};
use crate::core::fields::FieldExpOps;
use crate::core::poly::circle::CircleDomain;
use crate::core::utils::bit_reverse;
Expand Down Expand Up @@ -160,8 +156,8 @@ unsafe fn ifft_vecwise_loop(
) {
for index_l in 0..(1 << loop_bits) {
let index = (index_h << loop_bits) + index_l;
let mut val0 = _mm512_load_epi32(values.add(index * 32).cast_const());
let mut val1 = _mm512_load_epi32(values.add(index * 32 + 16).cast_const());
let mut val0 = PackedBaseField::load(values.add(index * 32).cast_const());
let mut val1 = PackedBaseField::load(values.add(index * 32 + 16).cast_const());
(val0, val1) = vecwise_ibutterflies(
val0,
val1,
Expand All @@ -174,8 +170,8 @@ unsafe fn ifft_vecwise_loop(
val1,
_mm512_set1_epi32(*twiddle_dbl[3].get_unchecked(index)),
);
_mm512_store_epi32(values.add(index * 32), val0);
_mm512_store_epi32(values.add(index * 32 + 16), val1);
val0.store(values.add(index * 32));
val1.store(values.add(index * 32 + 16));
}
}

Expand Down Expand Up @@ -272,16 +268,16 @@ unsafe fn ifft1_loop(values: *mut i32, twiddle_dbl: &[&[i32]], layer: usize, ind
/// # Safety
/// This function is safe.
pub unsafe fn avx_ibutterfly(
val0: __m512i,
val1: __m512i,
val0: PackedBaseField,
val1: PackedBaseField,
twiddle_dbl: __m512i,
) -> (__m512i, __m512i) {
let r0 = add_mod_p(val0, val1);
let r1 = sub_mod_p(val0, val1);
) -> (PackedBaseField, PackedBaseField) {
let r0 = val0 + val1;
let r1 = val0 - val1;

// Extract the even and odd parts of r1 and twiddle_dbl, and spread as 8 64bit values.
let r1_e = r1;
let r1_o = _mm512_srli_epi64(r1, 32);
let r1_e = r1.0;
let r1_o = _mm512_srli_epi64(r1.0, 32);
let twiddle_dbl_e = twiddle_dbl;
let twiddle_dbl_o = _mm512_srli_epi64(twiddle_dbl, 32);

Expand All @@ -308,7 +304,7 @@ pub unsafe fn avx_ibutterfly(
let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, ODDS_INTERLEAVE_ODDS, prod_o_dbl);
// prod_hs - |0|prod_o_h|0|prod_e_h|

let prod = add_mod_p(prod_ls, prod_hs);
let prod = PackedBaseField(prod_ls) + PackedBaseField(prod_hs);

(r0, prod)
}
Expand All @@ -326,12 +322,12 @@ pub unsafe fn avx_ibutterfly(
/// # Safety
/// This function is safe.
pub unsafe fn vecwise_ibutterflies(
mut val0: __m512i,
mut val1: __m512i,
mut val0: PackedBaseField,
mut val1: PackedBaseField,
twiddle1_dbl: [i32; 8],
twiddle2_dbl: [i32; 4],
twiddle3_dbl: [i32; 2],
) -> (__m512i, __m512i) {
) -> (PackedBaseField, PackedBaseField) {
// TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly.

// Each avx_ibutterfly take 2 512-bit registers, and does 16 butterflies element by element.
Expand All @@ -354,44 +350,29 @@ pub unsafe fn vecwise_ibutterflies(
let (t0, t1) = compute_first_twiddles(twiddle1_dbl);

// Apply the permutation, resulting in indexing d:iabc.
(val0, val1) = (
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
);
(val0, val1) = val0.deinterleave_with(val1);
(val0, val1) = avx_ibutterfly(val0, val1, t0);

// Apply the permutation, resulting in indexing c:diab.
(val0, val1) = (
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
);
(val0, val1) = val0.deinterleave_with(val1);
(val0, val1) = avx_ibutterfly(val0, val1, t1);

// The twiddles for layer 2 are replicated in the following pattern:
// 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3
let t = _mm512_broadcast_i32x4(std::mem::transmute(twiddle2_dbl));
// Apply the permutation, resulting in indexing b:cdia.
(val0, val1) = (
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
);
(val0, val1) = val0.deinterleave_with(val1);
(val0, val1) = avx_ibutterfly(val0, val1, t);

// The twiddles for layer 3 are replicated in the following pattern:
// 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1
let t = _mm512_set1_epi64(std::mem::transmute(twiddle3_dbl));
// Apply the permutation, resulting in indexing a:bcid.
(val0, val1) = (
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
);
(val0, val1) = val0.deinterleave_with(val1);
(val0, val1) = avx_ibutterfly(val0, val1, t);

// Apply the permutation, resulting in indexing i:abcd.
(
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
)
val0.deinterleave_with(val1)
}

pub fn get_itwiddle_dbls(domain: CircleDomain) -> Vec<Vec<i32>> {
Expand Down Expand Up @@ -442,14 +423,14 @@ pub unsafe fn ifft3(
twiddles_dbl2: [i32; 1],
) {
// Load the 8 AVX vectors from the array.
let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const());
let mut val2 = _mm512_load_epi32(values.add(offset + (2 << log_step)).cast_const());
let mut val3 = _mm512_load_epi32(values.add(offset + (3 << log_step)).cast_const());
let mut val4 = _mm512_load_epi32(values.add(offset + (4 << log_step)).cast_const());
let mut val5 = _mm512_load_epi32(values.add(offset + (5 << log_step)).cast_const());
let mut val6 = _mm512_load_epi32(values.add(offset + (6 << log_step)).cast_const());
let mut val7 = _mm512_load_epi32(values.add(offset + (7 << log_step)).cast_const());
let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const());
let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const());
let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const());
let mut val4 = PackedBaseField::load(values.add(offset + (4 << log_step)).cast_const());
let mut val5 = PackedBaseField::load(values.add(offset + (5 << log_step)).cast_const());
let mut val6 = PackedBaseField::load(values.add(offset + (6 << log_step)).cast_const());
let mut val7 = PackedBaseField::load(values.add(offset + (7 << log_step)).cast_const());

// Apply the first layer of ibutterflies.
(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
Expand All @@ -470,14 +451,14 @@ pub unsafe fn ifft3(
(val3, val7) = avx_ibutterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0]));

// Store the 8 AVX vectors back to the array.
_mm512_store_epi32(values.add(offset + (0 << log_step)), val0);
_mm512_store_epi32(values.add(offset + (1 << log_step)), val1);
_mm512_store_epi32(values.add(offset + (2 << log_step)), val2);
_mm512_store_epi32(values.add(offset + (3 << log_step)), val3);
_mm512_store_epi32(values.add(offset + (4 << log_step)), val4);
_mm512_store_epi32(values.add(offset + (5 << log_step)), val5);
_mm512_store_epi32(values.add(offset + (6 << log_step)), val6);
_mm512_store_epi32(values.add(offset + (7 << log_step)), val7);
val0.store(values.add(offset + (0 << log_step)));
val1.store(values.add(offset + (1 << log_step)));
val2.store(values.add(offset + (2 << log_step)));
val3.store(values.add(offset + (3 << log_step)));
val4.store(values.add(offset + (4 << log_step)));
val5.store(values.add(offset + (5 << log_step)));
val6.store(values.add(offset + (6 << log_step)));
val7.store(values.add(offset + (7 << log_step)));
}

/// Applies 2 ibutterfly layers on 4 vectors of 16 M31 elements.
Expand All @@ -501,10 +482,10 @@ pub unsafe fn ifft2(
twiddles_dbl1: [i32; 1],
) {
// Load the 4 AVX vectors from the array.
let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const());
let mut val2 = _mm512_load_epi32(values.add(offset + (2 << log_step)).cast_const());
let mut val3 = _mm512_load_epi32(values.add(offset + (3 << log_step)).cast_const());
let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const());
let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const());
let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const());

// Apply the first layer of butterflies.
(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
Expand All @@ -515,10 +496,10 @@ pub unsafe fn ifft2(
(val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0]));

// Store the 4 AVX vectors back to the array.
_mm512_store_epi32(values.add(offset + (0 << log_step)), val0);
_mm512_store_epi32(values.add(offset + (1 << log_step)), val1);
_mm512_store_epi32(values.add(offset + (2 << log_step)), val2);
_mm512_store_epi32(values.add(offset + (3 << log_step)), val3);
val0.store(values.add(offset + (0 << log_step)));
val1.store(values.add(offset + (1 << log_step)));
val2.store(values.add(offset + (2 << log_step)));
val3.store(values.add(offset + (3 << log_step)));
}

/// Applies 1 ibutterfly layers on 2 vectors of 16 M31 elements.
Expand All @@ -532,14 +513,14 @@ pub unsafe fn ifft2(
/// # Safety
pub unsafe fn ifft1(values: *mut i32, offset: usize, log_step: usize, twiddles_dbl0: [i32; 1]) {
// Load the 2 AVX vectors from the array.
let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const());
let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const());

(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));

// Store the 2 AVX vectors back to the array.
_mm512_store_epi32(values.add(offset + (0 << log_step)), val0);
_mm512_store_epi32(values.add(offset + (1 << log_step)), val1);
val0.store(values.add(offset + (0 << log_step)));
val1.store(values.add(offset + (1 << log_step)));
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
Expand All @@ -559,10 +540,12 @@ mod tests {
#[test]
fn test_ibutterfly() {
unsafe {
let val0 = _mm512_setr_epi32(2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
let val1 = _mm512_setr_epi32(
let val0 = PackedBaseField(_mm512_setr_epi32(
2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
));
let val1 = PackedBaseField(_mm512_setr_epi32(
3, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
);
));
let twiddle = _mm512_setr_epi32(
1177558791, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
);
Expand Down
67 changes: 4 additions & 63 deletions src/core/backend/avx512/fft/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::arch::x86_64::{
__m512i, _mm512_add_epi32, _mm512_broadcast_i64x4, _mm512_load_epi32, _mm512_min_epu32,
_mm512_permutexvar_epi32, _mm512_store_epi32, _mm512_sub_epi32, _mm512_xor_epi32,
__m512i, _mm512_broadcast_i64x4, _mm512_load_epi32, _mm512_permutexvar_epi32,
_mm512_store_epi32, _mm512_xor_epi32,
};

pub mod ifft;
Expand All @@ -23,39 +23,8 @@ const ODDS_INTERLEAVE_ODDS: __m512i = unsafe {
])
};

/// An input to _mm512_permutex2var_epi32, and is used to concat the even words of a
/// with the even words of b.
const EVENS_CONCAT_EVENS: __m512i = unsafe {
core::mem::transmute([
0b00000, 0b00010, 0b00100, 0b00110, 0b01000, 0b01010, 0b01100, 0b01110, 0b10000, 0b10010,
0b10100, 0b10110, 0b11000, 0b11010, 0b11100, 0b11110,
])
};
/// An input to _mm512_permutex2var_epi32, and is used to concat the odd words of a
/// with the odd words of b.
const ODDS_CONCAT_ODDS: __m512i = unsafe {
core::mem::transmute([
0b00001, 0b00011, 0b00101, 0b00111, 0b01001, 0b01011, 0b01101, 0b01111, 0b10001, 0b10011,
0b10101, 0b10111, 0b11001, 0b11011, 0b11101, 0b11111,
])
};
/// An input to _mm512_permutex2var_epi32, and is used to interleave the low half of a
/// with the low half of b.
const LHALF_INTERLEAVE_LHALF: __m512i = unsafe {
core::mem::transmute([
0b00000, 0b10000, 0b00001, 0b10001, 0b00010, 0b10010, 0b00011, 0b10011, 0b00100, 0b10100,
0b00101, 0b10101, 0b00110, 0b10110, 0b00111, 0b10111,
])
};
/// An input to _mm512_permutex2var_epi32, and is used to interleave the high half of a
/// with the high half of b.
const HHALF_INTERLEAVE_HHALF: __m512i = unsafe {
core::mem::transmute([
0b01000, 0b11000, 0b01001, 0b11001, 0b01010, 0b11010, 0b01011, 0b11011, 0b01100, 0b11100,
0b01101, 0b11101, 0b01110, 0b11110, 0b01111, 0b11111,
])
};
const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) };
pub const CACHED_FFT_LOG_SIZE: usize = 16;
pub const MIN_FFT_LOG_SIZE: usize = 5;

// TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce
// it somewhere.
Expand Down Expand Up @@ -128,34 +97,6 @@ unsafe fn compute_first_twiddles(twiddle1_dbl: [i32; 8]) -> (__m512i, __m512i) {
(t0, t1)
}

// TODO(spapini): Move these to M31 AVX.

/// Adds two packed M31 elements, and reduces the result to the range [0,P].
/// Each value is assumed to be in unreduced form, [0, P] including P.
/// # Safety
/// This function is safe.
pub unsafe fn add_mod_p(a: __m512i, b: __m512i) -> __m512i {
// Add word by word. Each word is in the range [0, 2P].
let c = _mm512_add_epi32(a, b);
// Apply min(c, c-P) to each word.
// When c in [P,2P], then c-P in [0,P] which is always less than [P,2P].
// When c in [0,P-1], then c-P in [2^32-P,2^32-1] which is always greater than [0,P-1].
_mm512_min_epu32(c, _mm512_sub_epi32(c, P))
}

/// Subtracts two packed M31 elements, and reduces the result to the range [0,P].
/// Each value is assumed to be in unreduced form, [0, P] including P.
/// # Safety
/// This function is safe.
pub unsafe fn sub_mod_p(a: __m512i, b: __m512i) -> __m512i {
// Subtract word by word. Each word is in the range [-P, P].
let c = _mm512_sub_epi32(a, b);
// Apply min(c, c+P) to each word.
// When c in [0,P], then c+P in [P,2P] which is always greater than [0,P].
// When c in [2^32-P,2^32-1], then c+P in [0,P-1] which is always less than [2^32-P,2^32-1].
_mm512_min_epu32(_mm512_add_epi32(c, P), c)
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[cfg(test)]
mod tests {
Expand Down
Loading
Loading