Skip to content

Commit

Permalink
ifft_full
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 3, 2024
1 parent 19779ce commit ded50d7
Showing 1 changed file with 113 additions and 17 deletions.
130 changes: 113 additions & 17 deletions src/core/backend/avx512/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,26 +59,82 @@ const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) };
// TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce
// it somewhere.

/// # Safety
pub unsafe fn ifft(values: *mut i32, twiddle_dbl: &[Vec<i32>], log_n_elements: usize) {
assert!(log_n_elements >= 4);
if log_n_elements <= 1 {
// 16 {
ifft_lower(
values,
Some(&twiddle_dbl[..3]),
&twiddle_dbl[3..],
log_n_elements - 4,
log_n_elements - 4,
);
return;
}
let log_n_vecs = log_n_elements - 4;
let log_n_fft_vecs0 = log_n_vecs / 2;
let log_n_fft_vecs1 = (log_n_vecs + 1) / 2;
ifft_lower(
values,
Some(&twiddle_dbl[..3]),
&twiddle_dbl[3..(3 + log_n_fft_vecs1)],
log_n_elements - 4,
log_n_fft_vecs1,
);
// TODO(spapini): better transpose.
transpose_vecs(values, log_n_elements - 4);
ifft_lower(
values,
None,
&twiddle_dbl[(3 + log_n_fft_vecs1)..],
log_n_elements - 4,
log_n_fft_vecs0,
);
}

// TODO(spapini): This is inefficient. Optimize.
/// # Safety
pub unsafe fn transpose_vecs(values: *mut i32, log_n_vecs: usize) {
let half = log_n_vecs / 2;
for b in 0..(1 << (log_n_vecs & 1)) {
for a in 0..(1 << half) {
for c in 0..(1 << half) {
let i = (a << (log_n_vecs - half)) | (b << half) | c;
let j = (c << (log_n_vecs - half)) | (b << half) | a;
if i >= j {
continue;
}
let val0 = _mm512_load_epi32(values.add(i << 4).cast_const());
let val1 = _mm512_load_epi32(values.add(j << 4).cast_const());
_mm512_store_epi32(values.add(i << 4), val1);
_mm512_store_epi32(values.add(j << 4), val0);
}
}
}
}

/// # Safety
pub unsafe fn ifft_lower(
values: *mut i32,
vecwise_twiddle_dbl: Option<&[Vec<i32>]>,
twiddle_dbl: &[Vec<i32>],
n_total_bits: usize,
n_fft_bits: usize,
log_n_vecs: usize,
fft_bits: usize,
) {
assert!(n_fft_bits >= 1);
assert!(fft_bits >= 1);
if let Some(vecwise_twiddle_dbl) = vecwise_twiddle_dbl {
assert_eq!(vecwise_twiddle_dbl[0].len(), 1 << (n_fft_bits + 2));
assert_eq!(vecwise_twiddle_dbl[1].len(), 1 << (n_fft_bits + 1));
assert_eq!(vecwise_twiddle_dbl[2].len(), 1 << n_fft_bits);
assert_eq!(vecwise_twiddle_dbl[0].len(), 1 << (log_n_vecs + 2));
assert_eq!(vecwise_twiddle_dbl[1].len(), 1 << (log_n_vecs + 1));
assert_eq!(vecwise_twiddle_dbl[2].len(), 1 << log_n_vecs);
}
for h in 0..(1 << (n_total_bits - n_fft_bits)) {
for h in 0..(1 << (log_n_vecs - fft_bits)) {
// TODO(spapini):
if let Some(vecwise_twiddle_dbl) = vecwise_twiddle_dbl {
for l in 0..(1 << (n_fft_bits - 1)) {
for l in 0..(1 << (fft_bits - 1)) {
// TODO(spapini): modulo for twiddles on the iters.
let index = (h << (n_fft_bits - 1)) + l;
let index = (h << (fft_bits - 1)) + l;
let mut val0 = _mm512_load_epi32(values.add(index * 32).cast_const());
let mut val1 = _mm512_load_epi32(values.add(index * 32 + 16).cast_const());
(val0, val1) = vecwise_ibutterflies(
Expand All @@ -93,25 +149,31 @@ pub unsafe fn ifft_lower(
// TODO(spapini): do a fifth layer here.
}
}
for bit_i in (0..n_fft_bits).step_by(3) {
if bit_i + 3 > n_fft_bits {
for bit_i in (0..fft_bits).step_by(3) {
if bit_i + 3 > fft_bits {
todo!();
}
for m in 0..(1 << (n_fft_bits - 3 - bit_i)) {
let twid_index = (h << (n_fft_bits - 3 - bit_i)) + m;
for m in 0..(1 << (fft_bits - 3 - bit_i)) {
let twid_index = (h << (fft_bits - 3 - bit_i)) + m;
for l in 0..(1 << bit_i) {
ifft3(
values,
(h << n_fft_bits) + (m << (bit_i + 3)) + l,
(h << fft_bits) + (m << (bit_i + 3)) + l,
bit_i,
std::array::from_fn(|i| {
*twiddle_dbl[bit_i].get_unchecked(twid_index * 4 + i)
*twiddle_dbl[bit_i].get_unchecked(
(twid_index * 4 + i) & (twiddle_dbl[bit_i].len() - 1),
)
}),
std::array::from_fn(|i| {
*twiddle_dbl[bit_i + 1].get_unchecked(twid_index * 2 + i)
*twiddle_dbl[bit_i + 1].get_unchecked(
(twid_index * 2 + i) & (twiddle_dbl[bit_i + 1].len() - 1),
)
}),
std::array::from_fn(|i| {
*twiddle_dbl[bit_i + 2].get_unchecked(twid_index + i)
*twiddle_dbl[bit_i + 2].get_unchecked(
(twid_index + i) & (twiddle_dbl[bit_i + 2].len() - 1),
)
}),
);
}
Expand Down Expand Up @@ -782,4 +844,38 @@ mod tests {
}
}
}

fn run_ifft_full_test(log_size: u32) {
let domain = CanonicCoset::new(log_size).circle_domain();
let values = (0..domain.size())
.map(|i| BaseField::from_u32_unchecked(i as u32))
.collect::<Vec<_>>();
let expected_coeffs = ref_ifft(domain, values.clone());

// Compute.
let mut values = BaseFieldVec::from_iter(values);
let twiddle_dbls = get_itwiddle_dbls(domain);

unsafe {
ifft(
std::mem::transmute(values.data.as_mut_ptr()),
&twiddle_dbls[1..],
log_size as usize,
);
transpose_vecs(
std::mem::transmute(values.data.as_mut_ptr()),
(log_size - 4) as usize,
);

// Compare.
for i in 0..expected_coeffs.len() {
assert_eq!(values[i], expected_coeffs[i]);
}
}
}

#[test]
fn test_ifft_full_10() {
run_ifft_full_test(3 + 3 + 4);
}
}

0 comments on commit ded50d7

Please sign in to comment.