Skip to content

Commit

Permalink
avx vecwise ibutterfly
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Feb 26, 2024
1 parent 5191419 commit eaab17b
Showing 1 changed file with 216 additions and 30 deletions.
246 changes: 216 additions & 30 deletions src/core/backend/avx512/fft.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,42 @@
use std::arch::x86_64::{
__m512i, _mm512_add_epi32, _mm512_min_epu32, _mm512_mul_epu32, _mm512_permutex2var_epi32,
_mm512_srli_epi64, _mm512_sub_epi32,
__m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_min_epu32,
_mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_set1_epi64, _mm512_srli_epi64,
_mm512_sub_epi32,
};

/// L is an input to _mm512_permutex2var_epi32, and is used to interleave the even words of a
/// An input to _mm512_permutex2var_epi32, and is used to interleave the even words of a
/// with the even words of b.
const L: __m512i = unsafe {
const EVENS_INTERLEAVE_EVENS: __m512i = unsafe {
core::mem::transmute([
0b00000, 0b10000, 0b00010, 0b10010, 0b00100, 0b10100, 0b00110, 0b10110, 0b01000, 0b11000,
0b01010, 0b11010, 0b01100, 0b11100, 0b01110, 0b11110,
])
};
/// H is an input to _mm512_permutex2var_epi32, and is used to interleave the odd words of a
/// An input to _mm512_permutex2var_epi32, and is used to interleave the odd words of a
/// with the odd words of b.
const H: __m512i = unsafe {
const ODDS_INTERLEAVE_ODDS: __m512i = unsafe {
core::mem::transmute([
0b00001, 0b10001, 0b00011, 0b10011, 0b00101, 0b10101, 0b00111, 0b10111, 0b01001, 0b11001,
0b01011, 0b11011, 0b01101, 0b11101, 0b01111, 0b11111,
])
};

/// An input to _mm512_permutex2var_epi32, and is used to concat the even words of a
/// with the even words of b.
const EVENS_CONCAT_EVENS: __m512i = unsafe {
core::mem::transmute([
0b00000, 0b00010, 0b00100, 0b00110, 0b01000, 0b01010, 0b01100, 0b01110, 0b10000, 0b10010,
0b10100, 0b10110, 0b11000, 0b11010, 0b11100, 0b11110,
])
};
/// An input to _mm512_permutex2var_epi32, and is used to concat the odd words of a
/// with the odd words of b.
const ODDS_CONCAT_ODDS: __m512i = unsafe {
core::mem::transmute([
0b00001, 0b00011, 0b00101, 0b00111, 0b01001, 0b01011, 0b01101, 0b01111, 0b10001, 0b10011,
0b10101, 0b10111, 0b11001, 0b11011, 0b11101, 0b11111,
])
};
const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) };

/// Computes the butterfly operation for packed M31 elements.
Expand Down Expand Up @@ -53,15 +71,15 @@ pub unsafe fn avx_butterfly(
// prod_o_dbl - |0|prod_o_h|prod_o_l|0|

// Interleave the even words of prod_e_dbl with the even words of prod_o_dbl:
let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, L, prod_o_dbl);
let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, prod_o_dbl);
// prod_ls - |prod_o_l|0|prod_e_l|0|

// Divide by 2:
let prod_ls = _mm512_srli_epi64(prod_ls, 1);
// prod_ls - |0|prod_o_l|0|prod_e_l|

// Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl:
let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, H, prod_o_dbl);
let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, ODDS_INTERLEAVE_ODDS, prod_o_dbl);
// prod_hs - |0|prod_o_h|0|prod_e_h|

let prod = add_mod_p(prod_ls, prod_hs);
Expand Down Expand Up @@ -105,22 +123,106 @@ pub unsafe fn avx_ibutterfly(
// prod_o_dbl - |0|prod_o_h|prod_o_l|0|

// Interleave the even words of prod_e_dbl with the even words of prod_o_dbl:
let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, L, prod_o_dbl);
let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, prod_o_dbl);
// prod_ls - |prod_o_l|0|prod_e_l|0|

// Divide by 2:
let prod_ls = _mm512_srli_epi64(prod_ls, 1);
// prod_ls - |0|prod_o_l|0|prod_e_l|

// Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl:
let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, H, prod_o_dbl);
let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, ODDS_INTERLEAVE_ODDS, prod_o_dbl);
// prod_hs - |0|prod_o_h|0|prod_e_h|

let prod = add_mod_p(prod_ls, prod_hs);

(r0, prod)
}

/// Runs ifft on 2 vectors of 16 M31 elements.
/// This amounts to 4 butterfly layers, each with 16 butterflies.
/// Each of the vectors represents a bit reversed evaluation.
/// Each value in a vectors is in unreduced form: [0, P] including P.
/// Takes 4 twiddle arrays, one for each layer, holding the double of the corresponding twiddle.
/// The first layer (lower bit of the index) takes 16 twiddles.
/// The second layer takes 8 twiddles.
/// etc.
/// # Safety
pub unsafe fn vecwise_ibutterflies(
mut val0: __m512i,
mut val1: __m512i,
twiddle0_dbl: [i32; 16],
twiddle1_dbl: [i32; 8],
twiddle2_dbl: [i32; 4],
twiddle3_dbl: [i32; 2],
) -> (__m512i, __m512i) {
// TODO(spapini): Compute twiddle0 from twiddle1.
// TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly.

// Each avx_ibutterfly take 2 512-bit registers, and does 16 butterflies element by element.
// We need to permute the 512-bit registers to get the right order for the butterflies.
// Denote the index of the 16 M31 elements in register i as i:abcd.
// At each layer we apply the following permutation to the index:
// i:abcd => d:iabc
// This is how it looks like at each iteration.
// i:abcd
// d:iabc
// ifft on d
// c:diab
// ifft on c
// b:cdia
// ifft on b
// a:bcid
// ifft on a
// i:abcd

// The twiddles for layer 0 are packed like:
// 0 1 2 3 4 5 6 7 8 9 a b c d e f
let t: __m512i = std::mem::transmute(twiddle0_dbl);
// Apply i:abcd => d:iabc
(val0, val1) = (
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
);
(val0, val1) = avx_ibutterfly(val0, val1, t);

// The twiddles for layer 1 are packed like:
// 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
let t = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl));
// Apply i:abcd => d:iabc
(val0, val1) = (
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
);
(val0, val1) = avx_ibutterfly(val0, val1, t);

// The twiddles for layer 2 are packed like:
// 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3
let t = _mm512_broadcast_i32x4(std::mem::transmute(twiddle2_dbl));
// Apply i:abcd => d:iabc
(val0, val1) = (
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
);
(val0, val1) = avx_ibutterfly(val0, val1, t);

// The twiddles for layer 3 are packed like:
// 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1
let t = _mm512_set1_epi64(std::mem::transmute(twiddle3_dbl));
// Apply i:abcd => d:iabc
(val0, val1) = (
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
);
(val0, val1) = avx_ibutterfly(val0, val1, t);

// Apply i:abcd => d:iabc
(
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
)
}

// TODO(spapini): Move these to M31 AVX.

/// Adds two packed M31 elements, and reduces the result to the range [0,P].
Expand Down Expand Up @@ -170,19 +272,19 @@ mod tests {
let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle);
let (r0, r1) = avx_butterfly(val0, val1, twiddle_dbl);

let val0: [u32; 16] = std::mem::transmute(val0);
let val1: [u32; 16] = std::mem::transmute(val1);
let twiddle: [u32; 16] = std::mem::transmute(twiddle);
let r0: [u32; 16] = std::mem::transmute(r0);
let r1: [u32; 16] = std::mem::transmute(r1);
let val0: [BaseField; 16] = std::mem::transmute(val0);
let val1: [BaseField; 16] = std::mem::transmute(val1);
let twiddle: [BaseField; 16] = std::mem::transmute(twiddle);
let r0: [BaseField; 16] = std::mem::transmute(r0);
let r1: [BaseField; 16] = std::mem::transmute(r1);

for i in 0..16 {
let mut x = BaseField::from_u32_unchecked(val0[i]);
let mut y = BaseField::from_u32_unchecked(val1[i]);
let twiddle = BaseField::from_u32_unchecked(twiddle[i]);
let mut x = val0[i];
let mut y = val1[i];
let twiddle = twiddle[i];
butterfly(&mut x, &mut y, twiddle);
assert_eq!(x, BaseField::from_u32_unchecked(r0[i]));
assert_eq!(y, BaseField::from_u32_unchecked(r1[i]));
assert_eq!(x, r0[i]);
assert_eq!(y, r1[i]);
}
}
}
Expand All @@ -200,19 +302,103 @@ mod tests {
let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle);
let (r0, r1) = avx_ibutterfly(val0, val1, twiddle_dbl);

let val0: [u32; 16] = std::mem::transmute(val0);
let val1: [u32; 16] = std::mem::transmute(val1);
let twiddle: [u32; 16] = std::mem::transmute(twiddle);
let r0: [u32; 16] = std::mem::transmute(r0);
let r1: [u32; 16] = std::mem::transmute(r1);
let val0: [BaseField; 16] = std::mem::transmute(val0);
let val1: [BaseField; 16] = std::mem::transmute(val1);
let twiddle: [BaseField; 16] = std::mem::transmute(twiddle);
let r0: [BaseField; 16] = std::mem::transmute(r0);
let r1: [BaseField; 16] = std::mem::transmute(r1);

for i in 0..16 {
let mut x = BaseField::from_u32_unchecked(val0[i]);
let mut y = BaseField::from_u32_unchecked(val1[i]);
let twiddle = BaseField::from_u32_unchecked(twiddle[i]);
let mut x = val0[i];
let mut y = val1[i];
let twiddle = twiddle[i];
ibutterfly(&mut x, &mut y, twiddle);
assert_eq!(x, BaseField::from_u32_unchecked(r0[i]));
assert_eq!(y, BaseField::from_u32_unchecked(r1[i]));
assert_eq!(x, r0[i]);
assert_eq!(y, r1[i]);
}
}
}

#[test]
fn test_vecwise_ibutterflies() {
unsafe {
let val0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
let val1 = _mm512_setr_epi32(
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
);
let twiddles0 = [
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
];
let twiddles1 = [48, 49, 50, 51, 52, 53, 54, 55];
let twiddles2 = [56, 57, 58, 59];
let twiddles3 = [60, 61];
let twiddle0_dbl = std::array::from_fn(|i| twiddles0[i] * 2);
let twiddle1_dbl = std::array::from_fn(|i| twiddles1[i] * 2);
let twiddle2_dbl = std::array::from_fn(|i| twiddles2[i] * 2);
let twiddle3_dbl = std::array::from_fn(|i| twiddles3[i] * 2);

let (r0, r1) = vecwise_ibutterflies(
val0,
val1,
twiddle0_dbl,
twiddle1_dbl,
twiddle2_dbl,
twiddle3_dbl,
);

let mut val0: [BaseField; 16] = std::mem::transmute(val0);
let mut val1: [BaseField; 16] = std::mem::transmute(val1);
let r0: [BaseField; 16] = std::mem::transmute(r0);
let r1: [BaseField; 16] = std::mem::transmute(r1);
let twiddles0: [BaseField; 16] = std::mem::transmute(twiddles0);
let twiddles1: [BaseField; 8] = std::mem::transmute(twiddles1);
let twiddles2: [BaseField; 4] = std::mem::transmute(twiddles2);
let twiddles3: [BaseField; 2] = std::mem::transmute(twiddles3);

for i in 0..16 {
let j = i ^ 1;
if i > j {
continue;
}
let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]);
ibutterfly(&mut v00, &mut v01, twiddles0[i / 2]);
ibutterfly(&mut v10, &mut v11, twiddles0[8 + i / 2]);
(val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11);
}
for i in 0..16 {
let j = i ^ 2;
if i > j {
continue;
}
let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]);
ibutterfly(&mut v00, &mut v01, twiddles1[i / 4]);
ibutterfly(&mut v10, &mut v11, twiddles1[4 + i / 4]);
(val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11);
}
for i in 0..16 {
let j = i ^ 4;
if i > j {
continue;
}
let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]);
ibutterfly(&mut v00, &mut v01, twiddles2[i / 8]);
ibutterfly(&mut v10, &mut v11, twiddles2[2 + i / 8]);
(val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11);
}
for i in 0..16 {
let j = i ^ 8;
if i > j {
continue;
}
let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]);
ibutterfly(&mut v00, &mut v01, twiddles3[0]);
ibutterfly(&mut v10, &mut v11, twiddles3[1]);
(val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11);
}
// Compare
for i in 0..16 {
assert_eq!(val0[i], r0[i]);
assert_eq!(val1[i], r1[i]);
}
}
}
Expand Down

0 comments on commit eaab17b

Please sign in to comment.