Skip to content

Commit

Permalink
Save twiddles
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 3, 2024
1 parent 83b4507 commit 7df206b
Showing 1 changed file with 74 additions and 108 deletions.
182 changes: 74 additions & 108 deletions src/core/backend/avx512/fft.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::arch::x86_64::{
__m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_load_epi32,
_mm512_min_epu32, _mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_set1_epi32,
_mm512_set1_epi64, _mm512_srli_epi64, _mm512_store_epi32, _mm512_sub_epi32,
_mm512_min_epu32, _mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_permutexvar_epi32,
_mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64, _mm512_store_epi32, _mm512_sub_epi32,
_mm512_xor_epi32,
};

/// An input to _mm512_permutex2var_epi32, and is used to interleave the even words of a
Expand Down Expand Up @@ -68,10 +69,9 @@ pub unsafe fn ifft_lower(
) {
assert!(n_fft_bits >= 1);
if let Some(vecwise_twiddle_dbl) = vecwise_twiddle_dbl {
assert_eq!(vecwise_twiddle_dbl[0].len(), 1 << (n_fft_bits + 3));
assert_eq!(vecwise_twiddle_dbl[1].len(), 1 << (n_fft_bits + 2));
assert_eq!(vecwise_twiddle_dbl[2].len(), 1 << (n_fft_bits + 1));
assert_eq!(vecwise_twiddle_dbl[3].len(), 1 << n_fft_bits);
assert_eq!(vecwise_twiddle_dbl[0].len(), 1 << (n_fft_bits + 2));
assert_eq!(vecwise_twiddle_dbl[1].len(), 1 << (n_fft_bits + 1));
assert_eq!(vecwise_twiddle_dbl[2].len(), 1 << n_fft_bits);
}
for h in 0..(1 << (n_total_bits - n_fft_bits)) {
// TODO(spapini):
Expand All @@ -84,10 +84,9 @@ pub unsafe fn ifft_lower(
(val0, val1) = vecwise_ibutterflies(
val0,
val1,
std::array::from_fn(|i| *vecwise_twiddle_dbl[0].get_unchecked(index * 16 + i)),
std::array::from_fn(|i| *vecwise_twiddle_dbl[1].get_unchecked(index * 8 + i)),
std::array::from_fn(|i| *vecwise_twiddle_dbl[2].get_unchecked(index * 4 + i)),
std::array::from_fn(|i| *vecwise_twiddle_dbl[3].get_unchecked(index * 2 + i)),
std::array::from_fn(|i| *vecwise_twiddle_dbl[0].get_unchecked(index * 8 + i)),
std::array::from_fn(|i| *vecwise_twiddle_dbl[1].get_unchecked(index * 4 + i)),
std::array::from_fn(|i| *vecwise_twiddle_dbl[2].get_unchecked(index * 2 + i)),
);
_mm512_store_epi32(values.add(index * 32), val0);
_mm512_store_epi32(values.add(index * 32 + 16), val1);
Expand Down Expand Up @@ -126,7 +125,7 @@ pub unsafe fn ifft_lower(
/// val0, val1 are packed M31 elements. 16 M31 words at each.
/// Each value is assumed to be in unreduced form, [0, P] including P.
/// Returned values are in unreduced form, [0, P] including P.
/// twiddle_dbl holds 16 values, each is a *double* of a twiddle factor, in reduced form.
/// twiddle_dbl holds 16 values, each is a *double* of a twiddle factor, in unreduced form.
/// # Safety
/// This function is safe.
pub unsafe fn avx_butterfly(
Expand Down Expand Up @@ -176,7 +175,7 @@ pub unsafe fn avx_butterfly(
/// val0 + val1, t (val0 - val1).
/// val0, val1 are packed M31 elements. 16 M31 words at each.
/// Each value is assumed to be in unreduced form, [0, P] including P.
/// twiddle_dbl holds 16 values, each is a *double* of a twiddle factor, in reduced form.
/// twiddle_dbl holds 16 values, each is a *double* of a twiddle factor, in unreduced form.
/// # Safety
/// This function is safe.
pub unsafe fn avx_ibutterfly(
Expand Down Expand Up @@ -280,20 +279,21 @@ pub unsafe fn vecwise_butterflies(
/// This amounts to 4 butterfly layers, each with 16 butterflies.
/// Each of the vectors represents a bit reversed evaluation.
/// Each value in a vectors is in unreduced form: [0, P] including P.
/// Takes 4 twiddle arrays, one for each layer, holding the double of the corresponding twiddle.
/// The first layer (lower bit of the index) takes 16 twiddles.
/// The second layer takes 8 twiddles.
/// etc.
/// Takes 3 twiddle arrays, one for each layer after the first, holding the double of the
/// corresponding twiddle.
/// The first layer's twiddles (lower bit of the index) are computed from the second layer's
/// twiddles. The second layer takes 8 twiddles.
/// The third layer takes 4 twiddles.
/// The fourth layer takes 2 twiddles.
/// # Safety
/// This function is safe.
pub unsafe fn vecwise_ibutterflies(
mut val0: __m512i,
mut val1: __m512i,
twiddle0_dbl: [i32; 16],
twiddle1_dbl: [i32; 8],
twiddle2_dbl: [i32; 4],
twiddle3_dbl: [i32; 2],
) -> (__m512i, __m512i) {
// TODO(spapini): Compute twiddle0 from twiddle1.
// TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly.

// Each avx_ibutterfly take 2 512-bit registers, and does 16 butterflies element by element.
Expand All @@ -313,9 +313,38 @@ pub unsafe fn vecwise_ibutterflies(
// ifft on a
// i:abcd

// The twiddles for layer 0 are packed like:
// 0 1 2 3 4 5 6 7 8 9 a b c d e f
let t: __m512i = std::mem::transmute(twiddle0_dbl);
// Start by loading the twiddles for the second layer (layer 1):
// The twiddles for layer 1 are packed like:
// 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
let t1 = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl));

// The twiddles for layer 0 can be computed from the twiddles for layer 1:
// A circle coset of size 4 in bit reversed order looks like this:
// [(x, y), (-x, -y), (y, -x), (-y, x)]
// Note: This is related to the choice of M31_CIRCLE_GEN, and the fact the a quarter rotation
// is (0,-1) and not (0,1). This would cause another relation.
// The twiddles for layer 0 are the y coordinates:
// [y, -y, -x, x]
// The twiddles for layer 1 in bit reversed order are the x coordinates:
// [x, y]
// Works also for inverse of the twiddles.

// The twiddles for layer 0 are computed like this:
// t0[4i:4i+3] = [t1[2i+1], -t1[2i+1], -t1[2i], t1[2i]]
const INDICES_FROM_T1: __m512i = unsafe {
core::mem::transmute([
0b0001, 0b0001, 0b0000, 0b0000, 0b0011, 0b0011, 0b0010, 0b0010, 0b0101, 0b0101, 0b0100,
0b0100, 0b0111, 0b0111, 0b0110, 0b0110,
])
};
// Xoring a double twiddle with 2^32-2 transforms it to the double of it negation.
// Note that this keeps the values as a double of a value in the range [0, P].
const NEGATION_MASK: __m512i = unsafe {
core::mem::transmute([0i32, -2, -2, 0, 0, -2, -2, 0, 0, -2, -2, 0, 0, -2, -2, 0])
};
let t = _mm512_permutexvar_epi32(INDICES_FROM_T1, t1);
let t = _mm512_xor_epi32(t, NEGATION_MASK);

// Apply the permutation, resulting in indexing d:iabc.
(val0, val1) = (
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
Expand All @@ -325,8 +354,8 @@ pub unsafe fn vecwise_ibutterflies(

// The twiddles for layer 1 are packed like:
// 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
let t = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl));
// Apply the permutation, resulting in indexing c:diab.
let t = t1;
(val0, val1) = (
_mm512_permutex2var_epi32(val0, EVENS_CONCAT_EVENS, val1),
_mm512_permutex2var_epi32(val0, ODDS_CONCAT_ODDS, val1),
Expand Down Expand Up @@ -592,90 +621,6 @@ mod tests {
}
}

#[test]
fn test_vecwise_ibutterflies() {
unsafe {
let val0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
let val1 = _mm512_setr_epi32(
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
);
let twiddles0 = [
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
];
let twiddles1 = [48, 49, 50, 51, 52, 53, 54, 55];
let twiddles2 = [56, 57, 58, 59];
let twiddles3 = [60, 61];
let twiddle0_dbl = std::array::from_fn(|i| twiddles0[i] * 2);
let twiddle1_dbl = std::array::from_fn(|i| twiddles1[i] * 2);
let twiddle2_dbl = std::array::from_fn(|i| twiddles2[i] * 2);
let twiddle3_dbl = std::array::from_fn(|i| twiddles3[i] * 2);

let (r0, r1) = vecwise_ibutterflies(
val0,
val1,
twiddle0_dbl,
twiddle1_dbl,
twiddle2_dbl,
twiddle3_dbl,
);

let mut val0: [BaseField; 16] = std::mem::transmute(val0);
let mut val1: [BaseField; 16] = std::mem::transmute(val1);
let r0: [BaseField; 16] = std::mem::transmute(r0);
let r1: [BaseField; 16] = std::mem::transmute(r1);
let twiddles0: [BaseField; 16] = std::mem::transmute(twiddles0);
let twiddles1: [BaseField; 8] = std::mem::transmute(twiddles1);
let twiddles2: [BaseField; 4] = std::mem::transmute(twiddles2);
let twiddles3: [BaseField; 2] = std::mem::transmute(twiddles3);

for i in 0..16 {
let j = i ^ 1;
if i > j {
continue;
}
let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]);
ibutterfly(&mut v00, &mut v01, twiddles0[i / 2]);
ibutterfly(&mut v10, &mut v11, twiddles0[8 + i / 2]);
(val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11);
}
for i in 0..16 {
let j = i ^ 2;
if i > j {
continue;
}
let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]);
ibutterfly(&mut v00, &mut v01, twiddles1[i / 4]);
ibutterfly(&mut v10, &mut v11, twiddles1[4 + i / 4]);
(val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11);
}
for i in 0..16 {
let j = i ^ 4;
if i > j {
continue;
}
let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]);
ibutterfly(&mut v00, &mut v01, twiddles2[i / 8]);
ibutterfly(&mut v10, &mut v11, twiddles2[2 + i / 8]);
(val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11);
}
for i in 0..16 {
let j = i ^ 8;
if i > j {
continue;
}
let (mut v00, mut v01, mut v10, mut v11) = (val0[i], val0[j], val1[i], val1[j]);
ibutterfly(&mut v00, &mut v01, twiddles3[0]);
ibutterfly(&mut v10, &mut v11, twiddles3[1]);
(val0[i], val0[j], val1[i], val1[j]) = (v00, v01, v10, v11);
}
// Compare
for i in 0..16 {
assert_eq!(val0[i], r0[i]);
assert_eq!(val1[i], r1[i]);
}
}
}

#[test]
fn test_ifft3() {
unsafe {
Expand Down Expand Up @@ -760,6 +705,28 @@ mod tests {
res
}

#[test]
fn test_twiddle_relation() {
let ts = get_itwiddle_dbls(CanonicCoset::new(5).circle_domain());
let t0 = ts[0]
.iter()
.copied()
.map(|x| BaseField::from_u32_unchecked((x as u32) / 2))
.collect::<Vec<_>>();
let t1 = ts[1]
.iter()
.copied()
.map(|x| BaseField::from_u32_unchecked((x as u32) / 2))
.collect::<Vec<_>>();

for i in 0..t0.len() / 4 {
assert_eq!(t0[i * 4], t1[i * 2 + 1]);
assert_eq!(t0[i * 4 + 1], -t1[i * 2 + 1]);
assert_eq!(t0[i * 4 + 2], -t1[i * 2]);
assert_eq!(t0[i * 4 + 3], t1[i * 2]);
}
}

fn ref_ifft(domain: CircleDomain, mut values: Vec<BaseField>) -> Vec<BaseField> {
bit_reverse(&mut values);
let eval = CircleEvaluation::<CPUBackend, _>::new(domain, values);
Expand All @@ -782,7 +749,6 @@ mod tests {
let (val0, val1) = vecwise_ibutterflies(
std::mem::transmute(values0),
std::mem::transmute(values1),
twiddle_dbls[0].clone().try_into().unwrap(),
twiddle_dbls[1].clone().try_into().unwrap(),
twiddle_dbls[2].clone().try_into().unwrap(),
twiddle_dbls[3].clone().try_into().unwrap(),
Expand Down Expand Up @@ -818,7 +784,7 @@ mod tests {
unsafe {
ifft_lower(
std::mem::transmute(values.data.as_mut_ptr()),
Some(&twiddle_dbls[..4]),
Some(&twiddle_dbls[1..4]),
&twiddle_dbls[4..],
(log_size - 4) as usize,
(log_size - 4) as usize,
Expand Down

0 comments on commit 7df206b

Please sign in to comment.