Commit

clean up fft
spapinistarkware committed Feb 22, 2024
1 parent 59c04aa commit e29d282
Showing 1 changed file with 212 additions and 85 deletions.
src/core/backend/avx512/fft.rs: 297 changes (212 additions, 85 deletions)
@@ -57,30 +57,21 @@ pub unsafe fn ifft(values: *mut i32, twiddle_dbl: &[Vec<i32>], log_n_elements: usize)
assert!(log_n_elements >= 4);
if log_n_elements <= 1 {
// 16 {
- ifft_lower(
- values,
- Some(&twiddle_dbl[..3]),
- &twiddle_dbl[3..],
- log_n_elements - 4,
- log_n_elements - 4,
- );
+ ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements - 4, log_n_elements - 4);
return;
}
let log_n_vecs = log_n_elements - 4;
let log_n_fft_vecs0 = log_n_vecs / 2;
let log_n_fft_vecs1 = (log_n_vecs + 1) / 2;
- ifft_lower(
+ ifft_lower_with_vecwise(
values,
- Some(&twiddle_dbl[..3]),
- &twiddle_dbl[3..(3 + log_n_fft_vecs1)],
+ &twiddle_dbl[..(3 + log_n_fft_vecs1)],
log_n_elements - 4,
log_n_fft_vecs1,
);
// TODO(spapini): better transpose.
transpose_vecs(values, log_n_elements - 4);
- ifft_lower(
+ ifft_lower_without_vecwise(
values,
- None,
&twiddle_dbl[(3 + log_n_fft_vecs1)..],
log_n_elements - 4,
log_n_fft_vecs0,
@@ -109,72 +100,166 @@ pub unsafe fn transpose_vecs(values: *mut i32, log_n_vecs: usize) {
}

/// # Safety
- pub unsafe fn ifft_lower(
+ pub unsafe fn ifft_lower_with_vecwise(
values: *mut i32,
- vecwise_twiddle_dbl: Option<&[Vec<i32>]>,
twiddle_dbl: &[Vec<i32>],
log_n_vecs: usize,
fft_bits: usize,
) {
assert!(fft_bits >= 1);
- if let Some(vecwise_twiddle_dbl) = vecwise_twiddle_dbl {
- assert_eq!(vecwise_twiddle_dbl[0].len(), 1 << (log_n_vecs + 2));
- assert_eq!(vecwise_twiddle_dbl[1].len(), 1 << (log_n_vecs + 1));
- assert_eq!(vecwise_twiddle_dbl[2].len(), 1 << log_n_vecs);
- }
+ assert_eq!(twiddle_dbl[0].len(), 1 << (log_n_vecs + 2));
for h in 0..(1 << (log_n_vecs - fft_bits)) {
- // TODO(spapini):
- if let Some(vecwise_twiddle_dbl) = vecwise_twiddle_dbl {
- for l in 0..(1 << (fft_bits - 1)) {
- // TODO(spapini): modulo for twiddles on the iters.
- let index = (h << (fft_bits - 1)) + l;
- let mut val0 = _mm512_load_epi32(values.add(index * 32).cast_const());
- let mut val1 = _mm512_load_epi32(values.add(index * 32 + 16).cast_const());
- (val0, val1) = vecwise_ibutterflies(
- val0,
- val1,
- std::array::from_fn(|i| *vecwise_twiddle_dbl[0].get_unchecked(index * 8 + i)),
- std::array::from_fn(|i| *vecwise_twiddle_dbl[1].get_unchecked(index * 4 + i)),
- std::array::from_fn(|i| *vecwise_twiddle_dbl[2].get_unchecked(index * 2 + i)),
- );
- _mm512_store_epi32(values.add(index * 32), val0);
- _mm512_store_epi32(values.add(index * 32 + 16), val1);
- // TODO(spapini): do a fifth layer here.
+ ifft_vecwise_loop(values, twiddle_dbl, fft_bits, h);
+ for bit_i in (1..fft_bits).step_by(3) {
+ match fft_bits - bit_i {
+ 1 => {
+ ifft1_loop(values, &twiddle_dbl[3..], fft_bits, bit_i, h);
+ }
+ 2 => {
+ ifft2_loop(values, &twiddle_dbl[3..], fft_bits, bit_i, h);
+ }
+ _ => {
+ ifft3_loop(values, &twiddle_dbl[3..], fft_bits, bit_i, h);
+ }
+ }
+ }
}
}
}
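The loop above processes butterfly layers in chunks of three, and the match on fft_bits - bit_i falls back to a 2-layer or 1-layer kernel when fewer than three layers remain. A minimal scalar sketch of the same dispatch pattern (names here are illustrative only, not part of this file):

fn dispatch_layers(first_layer: usize, fft_bits: usize) {
    // Walk the layers in chunks of 3; the final chunk may hold only 1 or 2 layers.
    for bit_i in (first_layer..fft_bits).step_by(3) {
        match fft_bits - bit_i {
            1 => println!("1-layer kernel at layer {bit_i}"),
            2 => println!("2-layer kernel at layers {bit_i}..{}", bit_i + 2),
            _ => println!("3-layer kernel at layers {bit_i}..{}", bit_i + 3),
        }
    }
}

fn main() {
    // Mirrors how ifft_lower_with_vecwise covers vector layers 1..fft_bits after the
    // vecwise pass: a 3-layer kernel for layers 1..4, then a 2-layer kernel for 4..6.
    dispatch_layers(1, 6);
}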

/// # Safety
pub unsafe fn ifft_lower_without_vecwise(
values: *mut i32,
twiddle_dbl: &[Vec<i32>],
log_n_vecs: usize,
fft_bits: usize,
) {
assert!(fft_bits >= 1);
for h in 0..(1 << (log_n_vecs - fft_bits)) {
for bit_i in (0..fft_bits).step_by(3) {
- if bit_i + 3 > fft_bits {
- todo!();
- }
- for m in 0..(1 << (fft_bits - 3 - bit_i)) {
- let twid_index = (h << (fft_bits - 3 - bit_i)) + m;
- for l in 0..(1 << bit_i) {
- ifft3(
- values,
- (h << fft_bits) + (m << (bit_i + 3)) + l,
- bit_i,
- std::array::from_fn(|i| {
- *twiddle_dbl[bit_i].get_unchecked(
- (twid_index * 4 + i) & (twiddle_dbl[bit_i].len() - 1),
- )
- }),
- std::array::from_fn(|i| {
- *twiddle_dbl[bit_i + 1].get_unchecked(
- (twid_index * 2 + i) & (twiddle_dbl[bit_i + 1].len() - 1),
- )
- }),
- std::array::from_fn(|i| {
- *twiddle_dbl[bit_i + 2].get_unchecked(
- (twid_index + i) & (twiddle_dbl[bit_i + 2].len() - 1),
- )
- }),
- );
+ match fft_bits - bit_i {
+ 1 => {
+ ifft1_loop(values, twiddle_dbl, fft_bits, bit_i, h);
+ }
+ 2 => {
+ ifft2_loop(values, twiddle_dbl, fft_bits, bit_i, h);
+ }
+ _ => {
+ ifft3_loop(values, twiddle_dbl, fft_bits, bit_i, h);
+ }
+ }
}
}
}
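Reading the slices passed in ifft at the top of the file, the split between the two functions appears to be: ifft_lower_with_vecwise runs the in-vector layers plus the first batch of vector layers, transpose_vecs reorders the vectors, and ifft_lower_without_vecwise runs the remaining vector layers. A rough map of the twiddle layout that slicing implies (an inference from this diff, not a documented contract):

// twiddle_dbl[0..3]                     - in-vector (vecwise) layers
// twiddle_dbl[3 + i]                    - vector layer i
// &twiddle_dbl[..3 + log_n_fft_vecs1]   - passed to ifft_lower_with_vecwise before the transpose
// &twiddle_dbl[(3 + log_n_fft_vecs1)..] - passed to ifft_lower_without_vecwise after it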

/// # Safety
unsafe fn ifft_vecwise_loop(values: *mut i32, twiddle_dbl: &[Vec<i32>], fft_bits: usize, h: usize) {
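// Handles one block of the lowest ifft layers: each iteration loads two adjacent __m512i
// vectors (32 values), runs the vecwise inverse butterflies with twiddles from
// twiddle_dbl[0..3], applies one more avx_ibutterfly layer across the two vectors using a
// single twiddle from twiddle_dbl[3], and stores the results back in place.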
for l in 0..(1 << (fft_bits - 1)) {
let index = (h << (fft_bits - 1)) + l;
let mut val0 = _mm512_load_epi32(values.add(index * 32).cast_const());
let mut val1 = _mm512_load_epi32(values.add(index * 32 + 16).cast_const());
(val0, val1) = vecwise_ibutterflies(
val0,
val1,
std::array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 8 + i)),
std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)),
std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)),
);
(val0, val1) = avx_ibutterfly(
val0,
val1,
_mm512_set1_epi32(*twiddle_dbl[3].get_unchecked(index)),
);
_mm512_store_epi32(values.add(index * 32), val0);
_mm512_store_epi32(values.add(index * 32 + 16), val1);
}
}

/// # Safety
unsafe fn ifft3_loop(
values: *mut i32,
twiddle_dbl: &[Vec<i32>],
fft_bits: usize,
bit_i: usize,
index: usize,
) {
for m in 0..(1 << (fft_bits - 3 - bit_i)) {
let index = (index << (fft_bits - bit_i - 3)) + m;
let offset = index << (bit_i + 3);
for l in 0..(1 << bit_i) {
ifft3(
values,
offset + l,
bit_i,
std::array::from_fn(|i| {
*twiddle_dbl[bit_i]
.get_unchecked((index * 4 + i) & (twiddle_dbl[bit_i].len() - 1))
}),
std::array::from_fn(|i| {
*twiddle_dbl[bit_i + 1]
.get_unchecked((index * 2 + i) & (twiddle_dbl[bit_i + 1].len() - 1))
}),
std::array::from_fn(|i| {
*twiddle_dbl[bit_i + 2]
.get_unchecked((index + i) & (twiddle_dbl[bit_i + 2].len() - 1))
}),
);
}
}
}
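ifft3_loop above and the ifft2_loop/ifft1_loop that follow share the same twiddle addressing: each successive layer within a call reads half as many twiddles (4, 2, 1 per call in ifft3_loop), and the & (len - 1) mask wraps the twiddle index modulo the power-of-two table length. A tiny sketch of that wrap-around, with a hypothetical table length of 8:

fn main() {
    let len: usize = 8; // power of two, so `x & (len - 1)` equals `x % len`
    let index: usize = 11;
    assert_eq!(index & (len - 1), 3);
    assert_eq!(index % len, 3);
}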

/// # Safety
unsafe fn ifft2_loop(
values: *mut i32,
twiddle_dbl: &[Vec<i32>],
fft_bits: usize,
bit_i: usize,
index: usize,
) {
for m in 0..(1 << (fft_bits - 2 - bit_i)) {
let index = (index << (fft_bits - bit_i - 2)) + m;
let offset = index << (bit_i + 2);
for l in 0..(1 << bit_i) {
ifft2(
values,
offset + l,
bit_i,
std::array::from_fn(|i| {
*twiddle_dbl[bit_i]
.get_unchecked((index * 2 + i) & (twiddle_dbl[bit_i].len() - 1))
}),
std::array::from_fn(|i| {
*twiddle_dbl[bit_i + 1]
.get_unchecked((index + i) & (twiddle_dbl[bit_i + 1].len() - 1))
}),
);
}
}
}

/// # Safety
unsafe fn ifft1_loop(
values: *mut i32,
twiddle_dbl: &[Vec<i32>],
fft_bits: usize,
bit_i: usize,
index: usize,
) {
for m in 0..(1 << (fft_bits - 1 - bit_i)) {
let index = (index << (fft_bits - bit_i - 1)) + m;
let offset = index << (bit_i + 1);
for l in 0..(1 << bit_i) {
ifft1(
values,
offset + l,
bit_i,
std::array::from_fn(|i| {
*twiddle_dbl[bit_i].get_unchecked((index + i) & (twiddle_dbl[bit_i].len() - 1))
}),
);
}
}
}

/// # Safety
pub unsafe fn avx_butterfly(
val0: __m512i,
@@ -408,6 +493,48 @@ pub unsafe fn ifft3(
_mm512_store_epi32(values.add((offset + (7 << log_u32_step)) << 4), val7);
}

/// # Safety
pub unsafe fn ifft2(
values: *mut i32,
offset: usize,
log_step: usize,
twiddles_dbl0: [i32; 2],
twiddles_dbl1: [i32; 1],
) {
let log_u32_step = log_step;
// load
let mut val0 = _mm512_load_epi32(values.add((offset + (0 << log_u32_step)) << 4).cast_const());
let mut val1 = _mm512_load_epi32(values.add((offset + (1 << log_u32_step)) << 4).cast_const());
let mut val2 = _mm512_load_epi32(values.add((offset + (2 << log_u32_step)) << 4).cast_const());
let mut val3 = _mm512_load_epi32(values.add((offset + (3 << log_u32_step)) << 4).cast_const());

(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
(val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1]));

(val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0]));
(val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0]));

// store
_mm512_store_epi32(values.add((offset + (0 << log_u32_step)) << 4), val0);
_mm512_store_epi32(values.add((offset + (1 << log_u32_step)) << 4), val1);
_mm512_store_epi32(values.add((offset + (2 << log_u32_step)) << 4), val2);
_mm512_store_epi32(values.add((offset + (3 << log_u32_step)) << 4), val3);
}
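ifft2 above is the 2-layer kernel: in effect a 4-point inverse butterfly network applied independently in each of the 16 lanes of the four loaded vectors,

// first layer:  (val0, val1) with twiddles_dbl0[0], (val2, val3) with twiddles_dbl0[1]
// second layer: (val0, val2) and (val1, val3), both sharing twiddles_dbl1[0]

and ifft1 below is the corresponding single-layer kernel.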

/// # Safety
pub unsafe fn ifft1(values: *mut i32, offset: usize, log_step: usize, twiddles_dbl0: [i32; 1]) {
let log_u32_step = log_step;
// load
let mut val0 = _mm512_load_epi32(values.add((offset + (0 << log_u32_step)) << 4).cast_const());
let mut val1 = _mm512_load_epi32(values.add((offset + (1 << log_u32_step)) << 4).cast_const());

(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));

// store
_mm512_store_epi32(values.add((offset + (0 << log_u32_step)) << 4), val0);
_mm512_store_epi32(values.add((offset + (1 << log_u32_step)) << 4), val1);
}

#[cfg(test)]
mod tests {
use std::arch::x86_64::_mm512_setr_epi32;
@@ -688,29 +815,29 @@ mod tests {

#[test]
fn test_ifft_lower() {
- let log_size = 4 + 3 + 3;
- let domain = CanonicCoset::new(log_size).circle_domain();
- let values = (0..domain.size())
- .map(|i| BaseField::from_u32_unchecked(i as u32))
- .collect::<Vec<_>>();
- let expected_coeffs = ref_ifft(domain, values.clone());
-
- // Compute.
- let mut values = BaseFieldVec::from_iter(values);
- let twiddle_dbls = get_itwiddle_dbls(domain);
-
- unsafe {
- ifft_lower(
- std::mem::transmute(values.data.as_mut_ptr()),
- Some(&twiddle_dbls[1..4]),
- &twiddle_dbls[4..],
- (log_size - 4) as usize,
- (log_size - 4) as usize,
- );
+ for log_size in 5..=10 {
+ let domain = CanonicCoset::new(log_size).circle_domain();
+ let values = (0..domain.size())
+ .map(|i| BaseField::from_u32_unchecked(i as u32))
+ .collect::<Vec<_>>();
+ let expected_coeffs = ref_ifft(domain, values.clone());
+
+ // Compute.
+ let mut values = BaseFieldVec::from_iter(values);
+ let twiddle_dbls = get_itwiddle_dbls(domain);
+
+ unsafe {
+ ifft_lower_with_vecwise(
+ std::mem::transmute(values.data.as_mut_ptr()),
+ &twiddle_dbls[1..],
+ (log_size - 4) as usize,
+ (log_size - 4) as usize,
+ );
+
- // Compare.
- for i in 0..expected_coeffs.len() {
- assert_eq!(values[i], expected_coeffs[i]);
+ // Compare.
+ for i in 0..expected_coeffs.len() {
+ assert_eq!(values[i], expected_coeffs[i]);
}
}
}
}
