Skip to content

Commit

Permalink
avx ibutterfly
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Feb 22, 2024
1 parent 140f9ea commit 131e95e
Showing 1 changed file with 67 additions and 3 deletions.
70 changes: 67 additions & 3 deletions src/core/backend/avx512/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ pub unsafe fn avx_butterfly(
twiddle_dbl: __m512i,
) -> (__m512i, __m512i) {
let val1_e = val1;
let twiddle_dbl_e = twiddle_dbl;
let val1_o = _mm512_srli_epi64(val1, 32);
let m_e_dbl = _mm512_mul_epi32(val1_e, twiddle_dbl);
let m_o_dbl = _mm512_mul_epi32(val1_o, _mm512_srli_epi64(twiddle_dbl, 32));
let twiddle_dbl_o = _mm512_srli_epi64(twiddle_dbl, 32);
let m_e_dbl = _mm512_mul_epi32(val1_e, twiddle_dbl_e);
let m_o_dbl = _mm512_mul_epi32(val1_o, twiddle_dbl_o);

let rm_l = _mm512_srli_epi64(_mm512_permutex2var_epi32(m_e_dbl, L, m_o_dbl), 1);
let rm_h = _mm512_permutex2var_epi32(m_e_dbl, H, m_o_dbl);
Expand All @@ -46,12 +48,44 @@ pub unsafe fn avx_butterfly(
(r0, r1)
}

/// Vectorised inverse butterfly over the sixteen 32-bit lanes of each input.
///
/// Per lane this produces `r0 = val0 + val1` and `r1 = (val0 - val1) * twiddle`, each followed
/// by a single conditional correction by `P`.
///
/// `twiddle_dbl` carries *twice* the twiddle in every lane. Because a doubled 31-bit twiddle can
/// have bit 31 set, the widening products are formed with the unsigned multiply
/// (`_mm512_mul_epu32`); the signed variant would interpret such a lane as negative and corrupt
/// the high word of the product.
///
/// Returns `(r0, r1)`.
///
/// # Safety
///
/// This function is built from AVX-512 intrinsics, so the caller must ensure the executing CPU
/// supports them.
///
/// NOTE(review): the single-step corrections below presumably rely on every lane of `val0` and
/// `val1` already being reduced with respect to `P` - confirm against the callers.
pub unsafe fn avx_ibutterfly(
    val0: __m512i,
    val1: __m512i,
    twiddle_dbl: __m512i,
) -> (__m512i, __m512i) {
    // Sum with one conditional subtraction of P: when `sum - P` wraps around it becomes a huge
    // unsigned value, so the unsigned minimum keeps `sum` itself.
    let sum = _mm512_add_epi32(val0, val1);
    let r0 = _mm512_min_epu32(sum, _mm512_sub_epi32(sum, P));

    // Difference with one conditional addition of P: when `val0 - val1` wraps around, adding P
    // yields the smaller unsigned value and is selected instead.
    let diff = _mm512_sub_epi32(val0, val1);
    let r1 = _mm512_min_epu32(_mm512_add_epi32(diff, P), diff);

    // The widening multiply only reads the low 32 bits of each 64-bit element, so the
    // odd-indexed lanes are shifted down into that position first.
    let r1_odd = _mm512_srli_epi64(r1, 32);
    let twiddle_dbl_odd = _mm512_srli_epi64(twiddle_dbl, 32);

    // Unsigned 32x32 -> 64 products. Fully qualified so that no extra `use` item is required.
    let prod_even_dbl = std::arch::x86_64::_mm512_mul_epu32(r1, twiddle_dbl);
    let prod_odd_dbl = std::arch::x86_64::_mm512_mul_epu32(r1_odd, twiddle_dbl_odd);

    // NOTE(review): `L` / `H` are presumably the permutation indices that gather the low / high
    // 32-bit words of the even and odd products back into lane order - confirm against their
    // definitions. Each product equals `2 * r1 * twiddle = hi * 2^32 + lo`, so
    // `r1 * twiddle = hi * 2^31 + lo / 2`, which is congruent to `hi + lo / 2` modulo 2^31 - 1.
    // Every `lo` word is even, so the 64-bit shift by one moves no set bit across word
    // boundaries.
    let prod_lo_half =
        _mm512_srli_epi64(_mm512_permutex2var_epi32(prod_even_dbl, L, prod_odd_dbl), 1);
    let prod_hi = _mm512_permutex2var_epi32(prod_even_dbl, H, prod_odd_dbl);

    // Fold both halves together and apply one conditional subtraction of P.
    let folded = _mm512_add_epi32(prod_lo_half, prod_hi);
    let r1_times_twiddle = _mm512_min_epu32(folded, _mm512_sub_epi32(folded, P));

    (r0, r1_times_twiddle)
}

#[cfg(test)]
mod tests {
use std::arch::x86_64::_mm512_setr_epi32;

use super::*;
use crate::core::fft::butterfly;
use crate::core::fft::{butterfly, ibutterfly};
use crate::core::fields::m31::BaseField;

#[test]
Expand Down Expand Up @@ -83,4 +117,34 @@ mod tests {
}
}
}

#[test]
fn test_ibutterfly() {
    // Checks every lane of `avx_ibutterfly` against the scalar `ibutterfly` reference.
    unsafe {
        // Lane i holds i, 16 + i and 32 + i for the two values and the twiddle respectively.
        let lhs = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
        let rhs = _mm512_setr_epi32(
            16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
        );
        let tw = _mm512_setr_epi32(
            32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
        );

        // The vectorised routine takes the twiddle pre-doubled.
        let (out0, out1) = avx_ibutterfly(lhs, rhs, _mm512_add_epi32(tw, tw));

        // Reinterpret each 512-bit vector as its sixteen 32-bit lanes.
        let lhs_lanes: [u32; 16] = std::mem::transmute(lhs);
        let rhs_lanes: [u32; 16] = std::mem::transmute(rhs);
        let tw_lanes: [u32; 16] = std::mem::transmute(tw);
        let out0_lanes: [u32; 16] = std::mem::transmute(out0);
        let out1_lanes: [u32; 16] = std::mem::transmute(out1);

        let inputs = lhs_lanes.iter().zip(&rhs_lanes).zip(&tw_lanes);
        let outputs = out0_lanes.iter().zip(&out1_lanes);
        for (((&a, &b), &t), (&got0, &got1)) in inputs.zip(outputs) {
            // Scalar reference: run the field-level butterfly in place on this lane's inputs.
            let mut expected0 = BaseField::from_u32_unchecked(a);
            let mut expected1 = BaseField::from_u32_unchecked(b);
            ibutterfly(
                &mut expected0,
                &mut expected1,
                BaseField::from_u32_unchecked(t),
            );
            assert_eq!(expected0, BaseField::from_u32_unchecked(got0));
            assert_eq!(expected1, BaseField::from_u32_unchecked(got1));
        }
    }
}
}

0 comments on commit 131e95e

Please sign in to comment.