Skip to content

Commit

Permalink
ifft_lower
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 3, 2024
1 parent e031117 commit 54fd21a
Show file tree
Hide file tree
Showing 4 changed files with 259 additions and 32 deletions.
284 changes: 253 additions & 31 deletions src/core/backend/avx512/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::arch::x86_64::{
_mm512_set1_epi64, _mm512_srli_epi64, _mm512_store_epi32, _mm512_sub_epi32,
};

use crate::core::backend::avx512::VECS_LOG_SIZE;

/// An input to _mm512_permutex2var_epi32, and is used to interleave the even words of a
/// with the even words of b.
const EVENS_INTERLEAVE_EVENS: __m512i = unsafe {
Expand Down Expand Up @@ -55,6 +57,128 @@ const HHALF_INTERLEAVE_HHALF: __m512i = unsafe {
};
const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) };

// TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce
// it somewhere.

/// Computes partial ifft on `2^log_size` M31 elements.
/// Parameters:
/// values - Pointer to the entire value array, aligned to 64 bytes.
/// twiddle_dbl - The doubles of the twiddle factors for each layer of the the ifft.
/// layer i holds 2^(log_size - 1 - i) twiddles.
/// log_size - The log of the number of number of M31 elements in the array.
/// fft_layers - The number of ifft layers to apply, out of log_size.
/// # Safety
/// `values` must be aligned to 64 bytes.
/// `log_size` must be at least 5.
/// `fft_layers` must be at least 5.
pub unsafe fn ifft_lower_with_vecwise(
values: *mut i32,
twiddle_dbl: &[Vec<i32>],
log_size: usize,
fft_layers: usize,
) {
const VECWISE_FFT_BITS: usize = VECS_LOG_SIZE + 1;
assert!(log_size >= VECWISE_FFT_BITS);

assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 1));

for index_h in 0..(1 << (log_size - fft_layers)) {
ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h);
for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) {
match fft_layers - layer {
1 => {
todo!()
}
2 => {
todo!()
}
_ => {
ifft3_loop(
values,
&twiddle_dbl[layer..],
fft_layers - layer - 3,
layer,
index_h,
);
}
}
}
}
}

/// Runs the 5 first ifft layers across the entire array.
/// Parameters:
/// values - Pointer to the entire value array, aligned to 64 bytes.
/// twiddle_dbl - The doubles of the twiddle factors for each of the 5 ifft layers.
/// high_bits - The number of bits this loops needs to run on.
/// index_h - The higher part of the index, iterated by the caller.
/// # Safety
unsafe fn ifft_vecwise_loop(
values: *mut i32,
twiddle_dbl: &[Vec<i32>],
loop_bits: usize,
index_h: usize,
) {
for index_l in 0..(1 << loop_bits) {
let index = (index_h << loop_bits) + index_l;
let mut val0 = _mm512_load_epi32(values.add(index * 32).cast_const());
let mut val1 = _mm512_load_epi32(values.add(index * 32 + 16).cast_const());
(val0, val1) = vecwise_ibutterflies(
val0,
val1,
std::array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 16 + i)),
std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 8 + i)),
std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 4 + i)),
std::array::from_fn(|i| *twiddle_dbl[3].get_unchecked(index * 2 + i)),
);
(val0, val1) = avx_ibutterfly(
val0,
val1,
_mm512_set1_epi32(*twiddle_dbl[4].get_unchecked(index)),
);
_mm512_store_epi32(values.add(index * 32), val0);
_mm512_store_epi32(values.add(index * 32 + 16), val1);
}
}

/// Runs 3 ifft layers across the entire array.
/// Parameters:
/// values - Pointer to the entire value array, aligned to 64 bytes.
/// twiddle_dbl - The doubles of the twiddle factors for each of the 3 ifft layers.
/// loop_bits - The number of bits this loops needs to run on.
/// layer - The layer number of the first ifft layer to apply.
/// The layers `layer`, `layer + 1`, `layer + 2` are applied.
/// index_h - The higher part of the index, iterated by the caller.
/// # Safety
unsafe fn ifft3_loop(
values: *mut i32,
twiddle_dbl: &[Vec<i32>],
loop_bits: usize,
layer: usize,
index_h: usize,
) {
for m in 0..(1 << loop_bits) {
let index = (index_h << loop_bits) + m;
let offset = index << (layer + 3);
for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) {
ifft3(
values,
offset + l,
layer,
std::array::from_fn(|i| {
*twiddle_dbl[0].get_unchecked((index * 4 + i) & (twiddle_dbl[0].len() - 1))
}),
std::array::from_fn(|i| {
*twiddle_dbl[1].get_unchecked((index * 2 + i) & (twiddle_dbl[1].len() - 1))
}),
std::array::from_fn(|i| {
*twiddle_dbl[2].get_unchecked((index + i) & (twiddle_dbl[2].len() - 1))
}),
);
}
}
}

/// Computes the butterfly operation for packed M31 elements.
/// val0 + t val1, val0 - t val1.
/// val0, val1 are packed M31 elements. 16 M31 words at each.
Expand Down Expand Up @@ -121,7 +245,7 @@ pub unsafe fn avx_ibutterfly(
let r0 = add_mod_p(val0, val1);
let r1 = sub_mod_p(val0, val1);

// Extract the even and odd parts of r1 and twiddle_dbl, and spread as 8 64bit values.
// Extract the even and odd parts of r1 and twiddle_m_e_dbldbl, and spread as 8 64bit values.
let r1_e = r1;
let r1_o = _mm512_srli_epi64(r1, 32);
let twiddle_dbl_e = twiddle_dbl;
Expand Down Expand Up @@ -302,30 +426,29 @@ pub unsafe fn vecwise_ibutterflies(
/// Parameters:
/// values - Pointer to the entire value array.
/// offset - The offset of the first value in the array.
/// step_in_vecs - The distance in the array, in AVX vectors, between each pair of values that
/// need to be transformed. For layer i this is i-4.
/// log_step - The log of the distance in the array, in M31 elements, between each pair of
/// values that need to be transformed. For layer i this is i - 4.
/// twiddles_dbl0/1/2 - The double of the twiddles for the 3 layers of butterflies.
/// Each layer has 4/2/1 twiddles.
///
/// # Safety
pub unsafe fn ifft3(
values: *mut i32,
offset: usize,
step_in_vecs: usize,
twiddles_dbl0: &[i32; 4],
twiddles_dbl1: &[i32; 2],
twiddles_dbl2: &[i32; 1],
log_step: usize,
twiddles_dbl0: [i32; 4],
twiddles_dbl1: [i32; 2],
twiddles_dbl2: [i32; 1],
) {
let step_in_u32s = step_in_vecs + 4;
// Load the 8 AVX vectors from the array.
let mut val0 = _mm512_load_epi32(values.add(offset + (0 << step_in_u32s)).cast_const());
let mut val1 = _mm512_load_epi32(values.add(offset + (1 << step_in_u32s)).cast_const());
let mut val2 = _mm512_load_epi32(values.add(offset + (2 << step_in_u32s)).cast_const());
let mut val3 = _mm512_load_epi32(values.add(offset + (3 << step_in_u32s)).cast_const());
let mut val4 = _mm512_load_epi32(values.add(offset + (4 << step_in_u32s)).cast_const());
let mut val5 = _mm512_load_epi32(values.add(offset + (5 << step_in_u32s)).cast_const());
let mut val6 = _mm512_load_epi32(values.add(offset + (6 << step_in_u32s)).cast_const());
let mut val7 = _mm512_load_epi32(values.add(offset + (7 << step_in_u32s)).cast_const());
let mut val0 = _mm512_load_epi32(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = _mm512_load_epi32(values.add(offset + (1 << log_step)).cast_const());
let mut val2 = _mm512_load_epi32(values.add(offset + (2 << log_step)).cast_const());
let mut val3 = _mm512_load_epi32(values.add(offset + (3 << log_step)).cast_const());
let mut val4 = _mm512_load_epi32(values.add(offset + (4 << log_step)).cast_const());
let mut val5 = _mm512_load_epi32(values.add(offset + (5 << log_step)).cast_const());
let mut val6 = _mm512_load_epi32(values.add(offset + (6 << log_step)).cast_const());
let mut val7 = _mm512_load_epi32(values.add(offset + (7 << log_step)).cast_const());

// Apply the first layer of butterflies.
(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
Expand All @@ -346,14 +469,14 @@ pub unsafe fn ifft3(
(val3, val7) = avx_ibutterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0]));

// Store the 8 AVX vectors back to the array.
_mm512_store_epi32(values.add(offset + (0 << step_in_u32s)), val0);
_mm512_store_epi32(values.add(offset + (1 << step_in_u32s)), val1);
_mm512_store_epi32(values.add(offset + (2 << step_in_u32s)), val2);
_mm512_store_epi32(values.add(offset + (3 << step_in_u32s)), val3);
_mm512_store_epi32(values.add(offset + (4 << step_in_u32s)), val4);
_mm512_store_epi32(values.add(offset + (5 << step_in_u32s)), val5);
_mm512_store_epi32(values.add(offset + (6 << step_in_u32s)), val6);
_mm512_store_epi32(values.add(offset + (7 << step_in_u32s)), val7);
_mm512_store_epi32(values.add(offset + (0 << log_step)), val0);
_mm512_store_epi32(values.add(offset + (1 << log_step)), val1);
_mm512_store_epi32(values.add(offset + (2 << log_step)), val2);
_mm512_store_epi32(values.add(offset + (3 << log_step)), val3);
_mm512_store_epi32(values.add(offset + (4 << log_step)), val4);
_mm512_store_epi32(values.add(offset + (5 << log_step)), val5);
_mm512_store_epi32(values.add(offset + (6 << log_step)), val6);
_mm512_store_epi32(values.add(offset + (7 << log_step)), val7);
}

// TODO(spapini): Move these to M31 AVX.
Expand Down Expand Up @@ -390,8 +513,13 @@ mod tests {

use super::*;
use crate::core::backend::avx512::m31::PackedBaseField;
use crate::core::backend::avx512::BaseFieldVec;
use crate::core::backend::CPUBackend;
use crate::core::fft::{butterfly, ibutterfly};
use crate::core::fields::m31::BaseField;
use crate::core::fields::{Column, Field};
use crate::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation};
use crate::core::utils::bit_reverse;

#[test]
fn test_butterfly() {
Expand Down Expand Up @@ -426,12 +554,12 @@ mod tests {
#[test]
fn test_ibutterfly() {
unsafe {
let val0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
let val0 = _mm512_setr_epi32(2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
let val1 = _mm512_setr_epi32(
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
3, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
);
let twiddle = _mm512_setr_epi32(
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
1177558791, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
);
let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle);
let (r0, r1) = avx_ibutterfly(val0, val1, twiddle_dbl);
Expand Down Expand Up @@ -641,10 +769,10 @@ mod tests {
ifft3(
std::mem::transmute(values.as_mut_ptr()),
0,
0,
&twiddles0_dbl,
&twiddles1_dbl,
&twiddles2_dbl,
VECS_LOG_SIZE,
twiddles0_dbl,
twiddles1_dbl,
twiddles2_dbl,
);

let expected: [u32; 8] = std::array::from_fn(|i| i as u32);
Expand Down Expand Up @@ -684,4 +812,98 @@ mod tests {
}
}
}

fn get_itwiddle_dbls(domain: CircleDomain) -> Vec<Vec<i32>> {
let mut coset = domain.half_coset;

let mut res = vec![];
res.push(
coset
.iter()
.map(|p| (p.y.inverse().0 * 2) as i32)
.collect::<Vec<_>>(),
);
bit_reverse(res.last_mut().unwrap());
for _ in 0..coset.log_size() {
res.push(
coset
.iter()
.take(coset.size() / 2)
.map(|p| (p.x.inverse().0 * 2) as i32)
.collect::<Vec<_>>(),
);
bit_reverse(res.last_mut().unwrap());
coset = coset.double();
}

res
}

fn ref_ifft(domain: CircleDomain, mut values: Vec<BaseField>) -> Vec<BaseField> {
bit_reverse(&mut values);
let eval = CircleEvaluation::<CPUBackend, _>::new(domain, values);
let mut expected_coeffs = eval.interpolate().coeffs;
for x in expected_coeffs.iter_mut() {
*x *= BaseField::from_u32_unchecked(domain.size() as u32);
}
bit_reverse(&mut expected_coeffs);
expected_coeffs
}

#[test]
fn test_vecwise_ibutterflies_real() {
let domain = CanonicCoset::new(5).circle_domain();
let twiddle_dbls = get_itwiddle_dbls(domain);
assert_eq!(twiddle_dbls.len(), 5);
let values0: [i32; 16] = std::array::from_fn(|i| i as i32);
let values1: [i32; 16] = std::array::from_fn(|i| (i + 16) as i32);
let result: [BaseField; 32] = unsafe {
let (val0, val1) = vecwise_ibutterflies(
std::mem::transmute(values0),
std::mem::transmute(values1),
twiddle_dbls[0].clone().try_into().unwrap(),
twiddle_dbls[1].clone().try_into().unwrap(),
twiddle_dbls[2].clone().try_into().unwrap(),
twiddle_dbls[3].clone().try_into().unwrap(),
);
let (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddle_dbls[4][0]));
std::mem::transmute([val0, val1])
};

// ref.
let mut values = values0.to_vec();
values.extend_from_slice(&values1);
let expected = ref_ifft(domain, values.into_iter().map(BaseField::from).collect());

// Compare.
for i in 0..32 {
assert_eq!(result[i], expected[i]);
}
}

#[test]
fn test_ifft_lower_with_vecwise() {
let log_size = 5 + 3 + 3;
let domain = CanonicCoset::new(log_size).circle_domain();
let values = (0..domain.size())
.map(|i| BaseField::from_u32_unchecked(i as u32))
.collect::<Vec<_>>();
let expected_coeffs = ref_ifft(domain, values.clone());

// Compute.
let mut values = BaseFieldVec::from_iter(values);
let twiddle_dbls = get_itwiddle_dbls(domain);

unsafe {
ifft_lower_with_vecwise(
std::mem::transmute(values.data.as_mut_ptr()),
&twiddle_dbls,
log_size as usize,
log_size as usize,
);

// Compare.
assert_eq!(values.to_vec(), expected_coeffs);
}
}
}
2 changes: 2 additions & 0 deletions src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use crate::core::fields::m31::BaseField;
use crate::core::fields::{Column, FieldOps};
use crate::core::utils;

const VECS_LOG_SIZE: usize = 4;

#[derive(Copy, Clone, Debug)]
pub struct AVX512Backend;

Expand Down
2 changes: 1 addition & 1 deletion src/core/fields/m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub const P: u32 = 2147483647; // 2 ** 31 - 1

#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Pod, Zeroable)]
pub struct M31(u32);
pub struct M31(pub u32);
pub type BaseField = M31;

impl_field!(M31, P);
Expand Down
3 changes: 3 additions & 0 deletions src/core/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ pub trait IteratorMutExt<'a, T: 'a>: Iterator<Item = &'a mut T> {
impl<'a, T: 'a, I: Iterator<Item = &'a mut T>> IteratorMutExt<'a, T> for I {}

pub(crate) fn bit_reverse_index(i: usize, log_size: u32) -> usize {
if log_size == 0 {
return i;
}
i.reverse_bits() >> (usize::BITS - log_size)
}

Expand Down

0 comments on commit 54fd21a

Please sign in to comment.