Skip to content

Commit

Permalink
ifft_lower
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Feb 22, 2024
1 parent b5b378e commit 70d7f2d
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 34 deletions.
232 changes: 200 additions & 32 deletions src/core/backend/avx512/fft.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::arch::x86_64::{
__m512i, _mm512_add_epi32, _mm512_broadcast_i32x4, _mm512_broadcast_i64x4, _mm512_load_epi32,
_mm512_min_epu32, _mm512_mul_epi32, _mm512_permutex2var_epi32, _mm512_set1_epi32,
_mm512_min_epu32, _mm512_mul_epu32, _mm512_permutex2var_epi32, _mm512_set1_epi32,
_mm512_set1_epi64, _mm512_srli_epi64, _mm512_store_epi32, _mm512_sub_epi32,
};

Expand Down Expand Up @@ -44,6 +44,72 @@ const H2: __m512i = unsafe {
};
const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) };

// TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce
// it somewhere.

/// Applies the lower layers of an inverse FFT in place to the buffer behind `values`.
///
/// The buffer is processed as `1 << (n_total_bits - n_fft_bits)` consecutive chunks (loop
/// variable `h`). For each chunk:
///
/// 1. If `vecwise_twiddle_dbl` is `Some`, every pair of adjacent 16-lane vectors is passed
///    through `vecwise_ibutterflies`, using 16/8/4/2 twiddles per pair taken from
///    `vecwise_twiddle_dbl[0..4]` respectively.
/// 2. The remaining `n_fft_bits` layers are applied three at a time with `ifft3`, reading
///    4/2/1 twiddles per call from `twiddle_dbl[bit_i]`, `twiddle_dbl[bit_i + 1]` and
///    `twiddle_dbl[bit_i + 2]`.
///
/// # Panics
/// - If `n_fft_bits == 0`.
/// - If `vecwise_twiddle_dbl` is `Some` and has fewer than four layers, or a layer's length
///   differs from `1 << (n_fft_bits + 3 - layer)`.
/// - With `todo!()` if `n_fft_bits` is not a multiple of 3. Note that the earlier passes have
///   already modified `values` by the time this is reached.
/// - On out-of-range indexing of the outer `twiddle_dbl` slice (that index is bounds-checked;
///   the per-layer element reads are not).
///
/// # Safety
/// - `values` is read and written through 512-bit loads/stores at `i32` offsets below
///   `1 << (n_total_bits + 4)`; the caller must guarantee the allocation covers that range.
/// - The loads/stores used are the aligned `_mm512_load_epi32`/`_mm512_store_epi32`
///   intrinsics, so `values` presumably has to be 64-byte aligned - confirm against the
///   intrinsic documentation.
/// - Twiddle elements are read with `get_unchecked`. The length assertions below only cover
///   the indices reached when `n_total_bits == n_fft_bits`; for larger `n_total_bits` the
///   caller must supply layers long enough for every `h` (see the TODO on the inner loop).
/// - The executing CPU must support the AVX-512 instructions used by the butterflies.
pub unsafe fn ifft_lower(
values: *mut i32,
vecwise_twiddle_dbl: Option<&[Vec<i32>]>,
twiddle_dbl: &[Vec<i32>],
n_total_bits: usize,
n_fft_bits: usize,
) {
assert!(n_fft_bits >= 1);
// Sanity-check the four vecwise twiddle layers: each layer is half the size of the previous.
if let Some(vecwise_twiddle_dbl) = vecwise_twiddle_dbl {
assert_eq!(vecwise_twiddle_dbl[0].len(), 1 << (n_fft_bits + 3));
assert_eq!(vecwise_twiddle_dbl[1].len(), 1 << (n_fft_bits + 2));
assert_eq!(vecwise_twiddle_dbl[2].len(), 1 << (n_fft_bits + 1));
assert_eq!(vecwise_twiddle_dbl[3].len(), 1 << n_fft_bits);
}
// `h` selects one chunk of `1 << n_fft_bits` vectors (16 `i32` lanes each).
for h in 0..(1 << (n_total_bits - n_fft_bits)) {
// TODO(spapini):
if let Some(vecwise_twiddle_dbl) = vecwise_twiddle_dbl {
// Each iteration handles one pair of adjacent vectors, i.e. 32 consecutive `i32`s.
for l in 0..(1 << (n_fft_bits - 1)) {
// TODO(spapini): modulo for twiddles on the iters.
// NOTE(review): `index` grows with `h`, so for `h > 0` the unchecked twiddle reads
// below go past the lengths asserted above.
let index = (h << (n_fft_bits - 1)) + l;
let mut val0 = _mm512_load_epi32(values.add(index * 32).cast_const());
let mut val1 = _mm512_load_epi32(values.add(index * 32 + 16).cast_const());
(val0, val1) = vecwise_ibutterflies(
val0,
val1,
std::array::from_fn(|i| *vecwise_twiddle_dbl[0].get_unchecked(index * 16 + i)),
std::array::from_fn(|i| *vecwise_twiddle_dbl[1].get_unchecked(index * 8 + i)),
std::array::from_fn(|i| *vecwise_twiddle_dbl[2].get_unchecked(index * 4 + i)),
std::array::from_fn(|i| *vecwise_twiddle_dbl[3].get_unchecked(index * 2 + i)),
);
_mm512_store_epi32(values.add(index * 32), val0);
_mm512_store_epi32(values.add(index * 32 + 16), val1);
// TODO(spapini): do a fifth layer here.
}
}
// Remaining layers, three per pass; `bit_i` is the index of the first layer of the pass.
for bit_i in (0..n_fft_bits).step_by(3) {
// A trailing group of one or two layers is not implemented yet.
if bit_i + 3 > n_fft_bits {
todo!();
}
// `m` enumerates the groups of `8 << bit_i` vectors inside the chunk; every group shares
// one set of twiddles.
for m in 0..(1 << (n_fft_bits - 3 - bit_i)) {
let twid_index = (h << (n_fft_bits - 3 - bit_i)) + m;
// `l` is the position inside the group; `ifft3` combines the eight vectors that are
// `1 << bit_i` vectors apart starting at this offset.
for l in 0..(1 << bit_i) {
ifft3(
values,
(h << n_fft_bits) + (m << (bit_i + 3)) + l,
bit_i,
std::array::from_fn(|i| {
*twiddle_dbl[bit_i].get_unchecked(twid_index * 4 + i)
}),
std::array::from_fn(|i| {
*twiddle_dbl[bit_i + 1].get_unchecked(twid_index * 2 + i)
}),
std::array::from_fn(|i| {
*twiddle_dbl[bit_i + 2].get_unchecked(twid_index + i)
}),
);
}
}
}
}
}

/// # Safety
pub unsafe fn avx_butterfly(
val0: __m512i,
Expand All @@ -54,8 +120,8 @@ pub unsafe fn avx_butterfly(
let twiddle_dbl_e = twiddle_dbl;
let val1_o = _mm512_srli_epi64(val1, 32);
let twiddle_dbl_o = _mm512_srli_epi64(twiddle_dbl, 32);
let m_e_dbl = _mm512_mul_epi32(val1_e, twiddle_dbl_e);
let m_o_dbl = _mm512_mul_epi32(val1_o, twiddle_dbl_o);
let m_e_dbl = _mm512_mul_epu32(val1_e, twiddle_dbl_e);
let m_o_dbl = _mm512_mul_epu32(val1_o, twiddle_dbl_o);

let rm_l = _mm512_srli_epi64(_mm512_permutex2var_epi32(m_e_dbl, L, m_o_dbl), 1);
let rm_h = _mm512_permutex2var_epi32(m_e_dbl, H, m_o_dbl);
Expand Down Expand Up @@ -94,8 +160,8 @@ pub unsafe fn avx_ibutterfly(
let twiddle_dbl_e = twiddle_dbl;
let r1_o = _mm512_srli_epi64(r1, 32);
let twiddle_dbl_o = _mm512_srli_epi64(twiddle_dbl, 32);
let m_e_dbl = _mm512_mul_epi32(r1_e, twiddle_dbl_e);
let m_o_dbl = _mm512_mul_epi32(r1_o, twiddle_dbl_o);
let m_e_dbl = _mm512_mul_epu32(r1_e, twiddle_dbl_e);
let m_o_dbl = _mm512_mul_epu32(r1_o, twiddle_dbl_o);

let rm_l = _mm512_srli_epi64(_mm512_permutex2var_epi32(m_e_dbl, L, m_o_dbl), 1);
let rm_h = _mm512_permutex2var_epi32(m_e_dbl, H, m_o_dbl);
Expand Down Expand Up @@ -199,21 +265,21 @@ pub unsafe fn vecwise_ibutterflies(
pub unsafe fn ifft3(
values: *mut i32,
offset: usize,
step: usize,
twiddles_dbl0: &[i32; 4],
twiddles_dbl1: &[i32; 2],
twiddles_dbl2: &[i32; 1],
log_step: usize,
twiddles_dbl0: [i32; 4],
twiddles_dbl1: [i32; 2],
twiddles_dbl2: [i32; 1],
) {
let u32_step = step + 4;
let log_u32_step = log_step;
// load
let mut val0 = _mm512_load_epi32(values.add(offset + (0 << u32_step)).cast_const());
let mut val1 = _mm512_load_epi32(values.add(offset + (1 << u32_step)).cast_const());
let mut val2 = _mm512_load_epi32(values.add(offset + (2 << u32_step)).cast_const());
let mut val3 = _mm512_load_epi32(values.add(offset + (3 << u32_step)).cast_const());
let mut val4 = _mm512_load_epi32(values.add(offset + (4 << u32_step)).cast_const());
let mut val5 = _mm512_load_epi32(values.add(offset + (5 << u32_step)).cast_const());
let mut val6 = _mm512_load_epi32(values.add(offset + (6 << u32_step)).cast_const());
let mut val7 = _mm512_load_epi32(values.add(offset + (7 << u32_step)).cast_const());
let mut val0 = _mm512_load_epi32(values.add((offset + (0 << log_u32_step)) << 4).cast_const());
let mut val1 = _mm512_load_epi32(values.add((offset + (1 << log_u32_step)) << 4).cast_const());
let mut val2 = _mm512_load_epi32(values.add((offset + (2 << log_u32_step)) << 4).cast_const());
let mut val3 = _mm512_load_epi32(values.add((offset + (3 << log_u32_step)) << 4).cast_const());
let mut val4 = _mm512_load_epi32(values.add((offset + (4 << log_u32_step)) << 4).cast_const());
let mut val5 = _mm512_load_epi32(values.add((offset + (5 << log_u32_step)) << 4).cast_const());
let mut val6 = _mm512_load_epi32(values.add((offset + (6 << log_u32_step)) << 4).cast_const());
let mut val7 = _mm512_load_epi32(values.add((offset + (7 << log_u32_step)) << 4).cast_const());

(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
(val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1]));
Expand All @@ -231,23 +297,28 @@ pub unsafe fn ifft3(
(val3, val7) = avx_ibutterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0]));

// store
_mm512_store_epi32(values.add(offset + (0 << u32_step)), val0);
_mm512_store_epi32(values.add(offset + (1 << u32_step)), val1);
_mm512_store_epi32(values.add(offset + (2 << u32_step)), val2);
_mm512_store_epi32(values.add(offset + (3 << u32_step)), val3);
_mm512_store_epi32(values.add(offset + (4 << u32_step)), val4);
_mm512_store_epi32(values.add(offset + (5 << u32_step)), val5);
_mm512_store_epi32(values.add(offset + (6 << u32_step)), val6);
_mm512_store_epi32(values.add(offset + (7 << u32_step)), val7);
_mm512_store_epi32(values.add((offset + (0 << log_u32_step)) << 4), val0);
_mm512_store_epi32(values.add((offset + (1 << log_u32_step)) << 4), val1);
_mm512_store_epi32(values.add((offset + (2 << log_u32_step)) << 4), val2);
_mm512_store_epi32(values.add((offset + (3 << log_u32_step)) << 4), val3);
_mm512_store_epi32(values.add((offset + (4 << log_u32_step)) << 4), val4);
_mm512_store_epi32(values.add((offset + (5 << log_u32_step)) << 4), val5);
_mm512_store_epi32(values.add((offset + (6 << log_u32_step)) << 4), val6);
_mm512_store_epi32(values.add((offset + (7 << log_u32_step)) << 4), val7);
}

#[cfg(test)]
mod tests {
use std::arch::x86_64::_mm512_setr_epi32;

use super::*;
use crate::core::backend::avx512::BaseFieldVec;
use crate::core::backend::{CPUBackend, ColumnTrait};
use crate::core::fft::{butterfly, ibutterfly};
use crate::core::fields::m31::BaseField;
use crate::core::fields::Field;
use crate::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation};
use crate::core::utils::bit_reverse;

#[test]
fn test_butterfly() {
Expand Down Expand Up @@ -282,12 +353,12 @@ mod tests {
#[test]
fn test_ibutterfly() {
unsafe {
let val0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
let val0 = _mm512_setr_epi32(2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
let val1 = _mm512_setr_epi32(
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
3, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
);
let twiddle = _mm512_setr_epi32(
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
1177558791, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
);
let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle);
let (r0, r1) = avx_ibutterfly(val0, val1, twiddle_dbl);
Expand Down Expand Up @@ -492,9 +563,9 @@ mod tests {
std::mem::transmute(values.as_mut_ptr()),
0,
0,
&twiddles0_dbl,
&twiddles1_dbl,
&twiddles2_dbl,
twiddles0_dbl,
twiddles1_dbl,
twiddles2_dbl,
);

let actual: Vec<[BaseField; 16]> = std::mem::transmute(values);
Expand Down Expand Up @@ -535,4 +606,101 @@ mod tests {
}
}
}

/// Builds the per-layer inverse twiddle tables for `domain`, with every entry stored as the
/// raw value of the inverse multiplied by two and cast to `i32`.
///
/// Layer 0 holds the inverted `y` coordinates of the domain's half coset. Each following layer
/// holds the inverted `x` coordinates of the first half of the current coset, after which the
/// coset is doubled. Every layer is bit-reversed before being returned.
fn get_itwiddle_dbls(domain: CircleDomain) -> Vec<Vec<i32>> {
    let mut coset = domain.half_coset;

    // Layer 0: twiddles derived from the y coordinates.
    let mut y_layer: Vec<i32> = coset
        .iter()
        .map(|p| (p.y.inverse().0 * 2) as i32)
        .collect();
    bit_reverse(&mut y_layer);
    let mut layers = vec![y_layer];

    // One x-based layer per doubling step; the step count is fixed by the initial coset.
    for _ in 0..coset.log_size() {
        let half = coset.size() / 2;
        let mut x_layer: Vec<i32> = coset
            .iter()
            .take(half)
            .map(|p| (p.x.inverse().0 * 2) as i32)
            .collect();
        bit_reverse(&mut x_layer);
        layers.push(x_layer);
        coset = coset.double();
    }

    layers
}

/// Reference inverse FFT computed through the CPU backend.
///
/// Bit-reverses `values`, interpolates them over `domain`, multiplies every resulting
/// coefficient by `domain.size()`, and returns the coefficients bit-reversed again.
fn ref_ifft(domain: CircleDomain, mut values: Vec<BaseField>) -> Vec<BaseField> {
    // Factor applied to every coefficient after interpolation.
    let scale = BaseField::from_u32_unchecked(domain.size() as u32);

    bit_reverse(&mut values);
    let mut coeffs = CircleEvaluation::<CPUBackend, _>::new(domain, values)
        .interpolate()
        .coeffs;
    coeffs.iter_mut().for_each(|coeff| *coeff *= scale);
    bit_reverse(&mut coeffs);
    coeffs
}

/// Checks `vecwise_ibutterflies` followed by one `avx_ibutterfly` layer against `ref_ifft`
/// on a 32-element domain.
#[test]
fn test_vecwise_ibutterflies_real() {
    let domain = CanonicCoset::new(5).circle_domain();
    let twiddle_dbls = get_itwiddle_dbls(domain);
    assert_eq!(twiddle_dbls.len(), 5);

    // Inputs 0..16 and 16..32, one array per vector.
    let values0: [i32; 16] = std::array::from_fn(|lane| lane as i32);
    let values1: [i32; 16] = std::array::from_fn(|lane| (lane + 16) as i32);

    let result: [BaseField; 32] = unsafe {
        // First four layers inside the vectors.
        let (lo, hi) = vecwise_ibutterflies(
            std::mem::transmute(values0),
            std::mem::transmute(values1),
            twiddle_dbls[0].clone().try_into().unwrap(),
            twiddle_dbls[1].clone().try_into().unwrap(),
            twiddle_dbls[2].clone().try_into().unwrap(),
            twiddle_dbls[3].clone().try_into().unwrap(),
        );
        // Fifth layer across the two vectors, with a single broadcast twiddle.
        let last_twiddle = _mm512_set1_epi32(twiddle_dbls[4][0]);
        let (lo, hi) = avx_ibutterfly(lo, hi, last_twiddle);
        std::mem::transmute([lo, hi])
    };

    // Reference result over the same 32 inputs.
    let reference_input = values0
        .iter()
        .chain(values1.iter())
        .copied()
        .map(BaseField::from)
        .collect();
    let expected = ref_ifft(domain, reference_input);

    // Element-wise comparison of all 32 outputs.
    for (actual, wanted) in result.iter().zip(&expected[..32]) {
        assert_eq!(actual, wanted);
    }
}

/// Runs `ifft_lower` over a full domain of log size 10 and compares every output element
/// with `ref_ifft`.
#[test]
fn test_ifft_lower() {
    let log_size = 4 + 3 + 3;
    let domain = CanonicCoset::new(log_size).circle_domain();
    let values = (0..domain.size())
        .map(|i| BaseField::from_u32_unchecked(i as u32))
        .collect::<Vec<_>>();
    let expected_coeffs = ref_ifft(domain, values.clone());

    // Compute.
    let mut values = BaseFieldVec::from_vec(values);
    let twiddle_dbls = get_itwiddle_dbls(domain);
    // The first four layers feed the vecwise pass, the rest feed the `ifft3` passes.
    let (vecwise_layers, upper_layers) = twiddle_dbls.split_at(4);
    // Both the total and the transformed bit counts are expressed in whole vectors.
    let n_bits = (log_size - 4) as usize;

    unsafe {
        ifft_lower(
            std::mem::transmute(values.data.as_mut_ptr()),
            Some(vecwise_layers),
            upper_layers,
            n_bits,
            n_bits,
        );
    }

    // Compare.
    for (i, expected) in expected_coeffs.iter().enumerate() {
        assert_eq!(values[i], *expected);
    }
}
}
2 changes: 1 addition & 1 deletion src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ fn as_cpu_vec(values: BaseFieldVec) -> Vec<BaseField> {
impl Index<usize> for BaseFieldVec {
type Output = BaseField;
fn index(&self, index: usize) -> &Self::Output {
&self.data[index / 8][index % 8]
&self.data[index / 16][index % 16]
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/core/fields/m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub const P: u32 = 2147483647; // 2 ** 31 - 1

#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Pod, Zeroable)]
pub struct M31(u32);
pub struct M31(pub u32);
pub type BaseField = M31;

impl_field!(M31, P);
Expand Down
3 changes: 3 additions & 0 deletions src/core/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ pub trait IteratorMutExt<'a, T: 'a>: Iterator<Item = &'a mut T> {
impl<'a, T: 'a, I: Iterator<Item = &'a mut T>> IteratorMutExt<'a, T> for I {}

/// Returns the value obtained by reversing the lowest `log_size` bits of `i`.
///
/// Bits of `i` at position `log_size` and above are discarded, except when `log_size` is 0,
/// in which case `i` is returned unchanged.
pub(crate) fn bit_reverse_index(i: usize, log_size: u32) -> usize {
    match log_size {
        // A shift by the full width of `usize` would overflow, so handle this case directly.
        0 => i,
        // Reverse the whole word, then move the reversed low bits back down.
        bits => i.reverse_bits() >> (usize::BITS - bits),
    }
}

Expand Down

0 comments on commit 70d7f2d

Please sign in to comment.