Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

avx vecwise ibutterfly #368

Merged
merged 1 commit into from
Mar 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 216 additions & 30 deletions src/core/backend/avx512/fft.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,42 @@
use std::arch::x86_64::{
__m512i, _mm512_add_epi32, _mm512_min_epu32, _mm512_mul_epu32, _mm512_permutex2var_epi32,
_mm512_srli_epi64, _mm512_sub_epi32,
__m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_min_epu32,
_mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_set1_epi64, _mm512_srli_epi64,
_mm512_sub_epi32,
};

/// L is an input to _mm512_permutex2var_epi32, and is used to interleave the even words of a
/// An input to _mm512_permutex2var_epi32, and is used to interleave the even words of a
/// with the even words of b.
const L: __m512i = unsafe {
const EVENS_INTERLEAVE_EVENS: __m512i = unsafe {
core::mem::transmute([
0b00000, 0b10000, 0b00010, 0b10010, 0b00100, 0b10100, 0b00110, 0b10110, 0b01000, 0b11000,
0b01010, 0b11010, 0b01100, 0b11100, 0b01110, 0b11110,
])
};
/// H is an input to _mm512_permutex2var_epi32, and is used to interleave the odd words of a
/// An input to _mm512_permutex2var_epi32, and is used to interleave the odd words of a
/// with the odd words of b.
const H: __m512i = unsafe {
const ODDS_INTERLEAVE_ODDS: __m512i = unsafe {
core::mem::transmute([
0b00001, 0b10001, 0b00011, 0b10011, 0b00101, 0b10101, 0b00111, 0b10111, 0b01001, 0b11001,
0b01011, 0b11011, 0b01101, 0b11101, 0b01111, 0b11111,
])
};

/// An input to _mm512_permutex2var_epi32, and is used to concat the even words of a
/// with the even words of b.
const EVENS_CONCAT_EVENS: __m512i = unsafe {
core::mem::transmute([
0b00000, 0b00010, 0b00100, 0b00110, 0b01000, 0b01010, 0b01100, 0b01110, 0b10000, 0b10010,
0b10100, 0b10110, 0b11000, 0b11010, 0b11100, 0b11110,
])
};
/// An input to _mm512_permutex2var_epi32, and is used to concat the odd words of a
/// with the odd words of b.
const ODDS_CONCAT_ODDS: __m512i = unsafe {
core::mem::transmute([
0b00001, 0b00011, 0b00101, 0b00111, 0b01001, 0b01011, 0b01101, 0b01111, 0b10001, 0b10011,
0b10101, 0b10111, 0b11001, 0b11011, 0b11101, 0b11111,
])
};
const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) };

/// Computes the butterfly operation for packed M31 elements.
Expand Down Expand Up @@ -53,15 +71,15 @@ pub unsafe fn avx_butterfly(
// prod_o_dbl - |0|prod_o_h|prod_o_l|0|

// Interleave the even words of prod_e_dbl with the even words of prod_o_dbl:
let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, L, prod_o_dbl);
let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, prod_o_dbl);
// prod_ls - |prod_o_l|0|prod_e_l|0|

// Divide by 2:
let prod_ls = _mm512_srli_epi64(prod_ls, 1);
// prod_ls - |0|prod_o_l|0|prod_e_l|

// Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl:
let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, H, prod_o_dbl);
let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, ODDS_INTERLEAVE_ODDS, prod_o_dbl);
// prod_hs - |0|prod_o_h|0|prod_e_h|

let prod = add_mod_p(prod_ls, prod_hs);
Expand Down Expand Up @@ -105,22 +123,106 @@ pub unsafe fn avx_ibutterfly(
// prod_o_dbl - |0|prod_o_h|prod_o_l|0|

// Interleave the even words of prod_e_dbl with the even words of prod_o_dbl:
let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, L, prod_o_dbl);
let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, prod_o_dbl);
// prod_ls - |prod_o_l|0|prod_e_l|0|

// Divide by 2:
let prod_ls = _mm512_srli_epi64(prod_ls, 1);
// prod_ls - |0|prod_o_l|0|prod_e_l|

// Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl:
let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, H, prod_o_dbl);
let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, ODDS_INTERLEAVE_ODDS, prod_o_dbl);
// prod_hs - |0|prod_o_h|0|prod_e_h|

let prod = add_mod_p(prod_ls, prod_hs);

(r0, prod)
}

/// Runs ifft on 2 vectors of 16 M31 elements.
/// This amounts to 4 butterfly layers, each with 16 butterflies.
/// Each of the vectors represents a bit reversed evaluation.
/// Each value in a vectors is in unreduced form: [0, P] including P.
/// Takes 4 twiddle arrays, one for each layer, holding the double of the corresponding twiddle.
/// The first layer (lower bit of the index) takes 16 twiddles.
/// The second layer takes 8 twiddles.
/// etc.
/// # Safety
pub unsafe fn vecwise_ibutterflies(
mut val0: __m512i,
mut val1: __m512i,
twiddle0_dbl: [i32; 16],
twiddle1_dbl: [i32; 8],
twiddle2_dbl: [i32; 4],
twiddle3_dbl: [i32; 2],
) -> (__m512i, __m512i) {
// TODO(spapini): Compute twiddle0 from twiddle1.
// TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly.

// Each avx_ibutterfly take 2 512-bit registers, and does 16 butterflies element by element.
// We need to permute the 512-bit registers to get the right order for the butterflies.
// Denote the index of the 16 M31 elements in register i as i:abcd.
// At each layer we apply the following permutation to the index:
// i:abcd => d:iabc
// This is how it looks like at each iteration.
// i:abcd
// d:iabc
// ifft on d
// c:diab
// ifft on c
// b:cdia
// ifft on b
// a:bcid
// ifft on a
// i:abcd

// The twiddles for layer 0 are unique and arranged as follows:
// 0 1 2 3 4 5 6 7 8 9 a b c d e f
let t: __m512i = std::mem::transmute(twiddle0_dbl);
// Apply the permutation, resulting in indexing d:iabc.
(val0, val1) = (
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
);
(val0, val1) = avx_ibutterfly(val0, val1, t);

// The twiddles for layer 1 are replicated in the following pattern:
// 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
let t = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl));
// Apply the permutation, resulting in indexing c:diab.
(val0, val1) = (
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
);
(val0, val1) = avx_ibutterfly(val0, val1, t);

// The twiddles for layer 2 are replicated in the following pattern:
// 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3
let t = _mm512_broadcast_i32x4(std::mem::transmute(twiddle2_dbl));
// Apply the permutation, resulting in indexing b:cdia.
(val0, val1) = (
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
);
(val0, val1) = avx_ibutterfly(val0, val1, t);

// The twiddles for layer 3 are replicated in the following pattern:
// 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1
let t = _mm512_set1_epi64(std::mem::transmute(twiddle3_dbl));
// Apply the permutation, resulting in indexing a:bcid.
(val0, val1) = (
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
);
(val0, val1) = avx_ibutterfly(val0, val1, t);

// Apply the permutation, resulting in indexing i:abcd.
(
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
)
}

// TODO(spapini): Move these to M31 AVX.

/// Adds two packed M31 elements, and reduces the result to the range [0,P].
Expand Down Expand Up @@ -170,19 +272,19 @@ mod tests {
let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle);
let (r0, r1) = avx_butterfly(val0, val1, twiddle_dbl);

let val0: [u32; 16] = std::mem::transmute(val0);
let val1: [u32; 16] = std::mem::transmute(val1);
let twiddle: [u32; 16] = std::mem::transmute(twiddle);
let r0: [u32; 16] = std::mem::transmute(r0);
let r1: [u32; 16] = std::mem::transmute(r1);
let val0: [BaseField; 16] = std::mem::transmute(val0);
let val1: [BaseField; 16] = std::mem::transmute(val1);
let twiddle: [BaseField; 16] = std::mem::transmute(twiddle);
let r0: [BaseField; 16] = std::mem::transmute(r0);
let r1: [BaseField; 16] = std::mem::transmute(r1);

for i in 0..16 {
let mut x = BaseField::from_u32_unchecked(val0[i]);
let mut y = BaseField::from_u32_unchecked(val1[i]);
let twiddle = BaseField::from_u32_unchecked(twiddle[i]);
let mut x = val0[i];
let mut y = val1[i];
let twiddle = twiddle[i];
butterfly(&mut x, &mut y, twiddle);
assert_eq!(x, BaseField::from_u32_unchecked(r0[i]));
assert_eq!(y, BaseField::from_u32_unchecked(r1[i]));
assert_eq!(x, r0[i]);
assert_eq!(y, r1[i]);
}
}
}
Expand All @@ -200,19 +302,103 @@ mod tests {
let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle);
let (r0, r1) = avx_ibutterfly(val0, val1, twiddle_dbl);

let val0: [u32; 16] = std::mem::transmute(val0);
let val1: [u32; 16] = std::mem::transmute(val1);
let twiddle: [u32; 16] = std::mem::transmute(twiddle);
let r0: [u32; 16] = std::mem::transmute(r0);
let r1: [u32; 16] = std::mem::transmute(r1);
let val0: [BaseField; 16] = std::mem::transmute(val0);
let val1: [BaseField; 16] = std::mem::transmute(val1);
let twiddle: [BaseField; 16] = std::mem::transmute(twiddle);
let r0: [BaseField; 16] = std::mem::transmute(r0);
let r1: [BaseField; 16] = std::mem::transmute(r1);

for i in 0..16 {
let mut x = BaseField::from_u32_unchecked(val0[i]);
let mut y = BaseField::from_u32_unchecked(val1[i]);
let twiddle = BaseField::from_u32_unchecked(twiddle[i]);
let mut x = val0[i];
let mut y = val1[i];
let twiddle = twiddle[i];
ibutterfly(&mut x, &mut y, twiddle);
assert_eq!(x, BaseField::from_u32_unchecked(r0[i]));
assert_eq!(y, BaseField::from_u32_unchecked(r1[i]));
assert_eq!(x, r0[i]);
assert_eq!(y, r1[i]);
}
}
}

#[test]
fn test_vecwise_ibutterflies() {
unsafe {
let val0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
let val1 = _mm512_setr_epi32(
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
);
let twiddles0 = [
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
];
let twiddles1 = [48, 49, 50, 51, 52, 53, 54, 55];
let twiddles2 = [56, 57, 58, 59];
let twiddles3 = [60, 61];
let twiddle0_dbl = std::array::from_fn(|i| twiddles0[i] * 2);
let twiddle1_dbl = std::array::from_fn(|i| twiddles1[i] * 2);
let twiddle2_dbl = std::array::from_fn(|i| twiddles2[i] * 2);
let twiddle3_dbl = std::array::from_fn(|i| twiddles3[i] * 2);

let (r0, r1) = vecwise_ibutterflies(
val0,
val1,
twiddle0_dbl,
twiddle1_dbl,
twiddle2_dbl,
twiddle3_dbl,
);

let mut val0: [BaseField; 16] = std::mem::transmute(val0);
let mut val1: [BaseField; 16] = std::mem::transmute(val1);
let r0: [BaseField; 16] = std::mem::transmute(r0);
let r1: [BaseField; 16] = std::mem::transmute(r1);
let twiddles0: [BaseField; 16] = std::mem::transmute(twiddles0);
let twiddles1: [BaseField; 8] = std::mem::transmute(twiddles1);
let twiddles2: [BaseField; 4] = std::mem::transmute(twiddles2);
let twiddles3: [BaseField; 2] = std::mem::transmute(twiddles3);

for i in 0..16 {
let j = i ^ 1;
if i > j {
continue;
}
let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]);
ibutterfly(&mut v00, &mut v01, twiddles0[i / 2]);
ibutterfly(&mut v10, &mut v11, twiddles0[8 + i / 2]);
(val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11);
}
for i in 0..16 {
let j = i ^ 2;
if i > j {
continue;
}
let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]);
ibutterfly(&mut v00, &mut v01, twiddles1[i / 4]);
ibutterfly(&mut v10, &mut v11, twiddles1[4 + i / 4]);
(val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11);
}
for i in 0..16 {
let j = i ^ 4;
if i > j {
continue;
}
let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]);
ibutterfly(&mut v00, &mut v01, twiddles2[i / 8]);
ibutterfly(&mut v10, &mut v11, twiddles2[2 + i / 8]);
(val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11);
}
for i in 0..16 {
let j = i ^ 8;
if i > j {
continue;
}
let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]);
ibutterfly(&mut v00, &mut v01, twiddles3[0]);
ibutterfly(&mut v10, &mut v11, twiddles3[1]);
(val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11);
}
// Compare
for i in 0..16 {
assert_eq!(val0[i], r0[i]);
assert_eq!(val1[i], r1[i]);
}
}
}
Expand Down