diff --git a/src/core/backend/avx512/fft.rs b/src/core/backend/avx512/fft.rs
index 0ec75e671..394020771 100644
--- a/src/core/backend/avx512/fft.rs
+++ b/src/core/backend/avx512/fft.rs
@@ -57,30 +57,21 @@ pub unsafe fn ifft(values: *mut i32, twiddle_dbl: &[Vec<i32>], log_n_elements: u
     assert!(log_n_elements >= 4);
     if log_n_elements <= 1 << 4 {
         // 16
-        ifft_lower(
-            values,
-            Some(&twiddle_dbl[..3]),
-            &twiddle_dbl[3..],
-            log_n_elements - 4,
-            log_n_elements - 4,
-        );
+        ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements - 4, log_n_elements - 4);
         return;
     }
     let log_n_vecs = log_n_elements - 4;
     let log_n_fft_vecs0 = log_n_vecs / 2;
     let log_n_fft_vecs1 = (log_n_vecs + 1) / 2;
-    ifft_lower(
+    ifft_lower_with_vecwise(
         values,
-        Some(&twiddle_dbl[..3]),
-        &twiddle_dbl[3..(3 + log_n_fft_vecs1)],
+        &twiddle_dbl[..(3 + log_n_fft_vecs1)],
         log_n_elements - 4,
         log_n_fft_vecs1,
     );
-    // TODO(spapini): better transpose.
     transpose_vecs(values, log_n_elements - 4);
-    ifft_lower(
+    ifft_lower_without_vecwise(
        values,
-        None,
         &twiddle_dbl[(3 + log_n_fft_vecs1)..],
         log_n_elements - 4,
         log_n_fft_vecs0,
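Reviewer note: the two lower halves now have explicit twiddle ownership. `ifft_lower_with_vecwise` receives the full slice (`twiddle_dbl[..3]` feeds `vecwise_ibutterflies`, `twiddle_dbl[3]` feeds the new fifth layer, and the tail is forwarded as `&twiddle_dbl[3..]`), while `ifft_lower_without_vecwise` only ever sees the post-vecwise layers. For orientation, a minimal sketch of driving the entry point, modeled on `test_ifft_lower` below; the `get_itwiddle_dbls` helper, the `[1..]` offset, and the transmute are all taken from that test, not new API:

```rust
let domain = CanonicCoset::new(log_size).circle_domain();
let twiddle_dbls = get_itwiddle_dbls(domain);
let mut values = BaseFieldVec::from_iter(values);
unsafe {
    // values.data holds 16-lane AVX words; the cast exposes the packed i32
    // layout that ifft() expects.
    ifft(
        std::mem::transmute(values.data.as_mut_ptr()),
        &twiddle_dbls[1..],
        log_size as usize,
    );
}
```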
@@ -109,72 +100,166 @@ pub unsafe fn transpose_vecs(values: *mut i32, log_n_vecs: usize) {
 }
 
 /// # Safety
-pub unsafe fn ifft_lower(
+pub unsafe fn ifft_lower_with_vecwise(
     values: *mut i32,
-    vecwise_twiddle_dbl: Option<&[Vec<i32>]>,
     twiddle_dbl: &[Vec<i32>],
     log_n_vecs: usize,
     fft_bits: usize,
 ) {
     assert!(fft_bits >= 1);
-    if let Some(vecwise_twiddle_dbl) = vecwise_twiddle_dbl {
-        assert_eq!(vecwise_twiddle_dbl[0].len(), 1 << (log_n_vecs + 2));
-        assert_eq!(vecwise_twiddle_dbl[1].len(), 1 << (log_n_vecs + 1));
-        assert_eq!(vecwise_twiddle_dbl[2].len(), 1 << log_n_vecs);
-    }
+    assert_eq!(twiddle_dbl[0].len(), 1 << (log_n_vecs + 2));
     for h in 0..(1 << (log_n_vecs - fft_bits)) {
-        // TODO(spapini):
-        if let Some(vecwise_twiddle_dbl) = vecwise_twiddle_dbl {
-            for l in 0..(1 << (fft_bits - 1)) {
-                // TODO(spapini): modulo for twiddles on the iters.
-                let index = (h << (fft_bits - 1)) + l;
-                let mut val0 = _mm512_load_epi32(values.add(index * 32).cast_const());
-                let mut val1 = _mm512_load_epi32(values.add(index * 32 + 16).cast_const());
-                (val0, val1) = vecwise_ibutterflies(
-                    val0,
-                    val1,
-                    std::array::from_fn(|i| *vecwise_twiddle_dbl[0].get_unchecked(index * 8 + i)),
-                    std::array::from_fn(|i| *vecwise_twiddle_dbl[1].get_unchecked(index * 4 + i)),
-                    std::array::from_fn(|i| *vecwise_twiddle_dbl[2].get_unchecked(index * 2 + i)),
-                );
-                _mm512_store_epi32(values.add(index * 32), val0);
-                _mm512_store_epi32(values.add(index * 32 + 16), val1);
-                // TODO(spapini): do a fifth layer here.
+        ifft_vecwise_loop(values, twiddle_dbl, fft_bits, h);
+        for bit_i in (1..fft_bits).step_by(3) {
+            match fft_bits - bit_i {
+                1 => {
+                    ifft1_loop(values, &twiddle_dbl[3..], fft_bits, bit_i, h);
+                }
+                2 => {
+                    ifft2_loop(values, &twiddle_dbl[3..], fft_bits, bit_i, h);
+                }
+                _ => {
+                    ifft3_loop(values, &twiddle_dbl[3..], fft_bits, bit_i, h);
+                }
             }
         }
+    }
+}
+
+/// # Safety
+pub unsafe fn ifft_lower_without_vecwise(
+    values: *mut i32,
+    twiddle_dbl: &[Vec<i32>],
+    log_n_vecs: usize,
+    fft_bits: usize,
+) {
+    assert!(fft_bits >= 1);
+    for h in 0..(1 << (log_n_vecs - fft_bits)) {
         for bit_i in (0..fft_bits).step_by(3) {
-            if bit_i + 3 > fft_bits {
-                todo!();
-            }
-            for m in 0..(1 << (fft_bits - 3 - bit_i)) {
-                let twid_index = (h << (fft_bits - 3 - bit_i)) + m;
-                for l in 0..(1 << bit_i) {
-                    ifft3(
-                        values,
-                        (h << fft_bits) + (m << (bit_i + 3)) + l,
-                        bit_i,
-                        std::array::from_fn(|i| {
-                            *twiddle_dbl[bit_i].get_unchecked(
-                                (twid_index * 4 + i) & (twiddle_dbl[bit_i].len() - 1),
-                            )
-                        }),
-                        std::array::from_fn(|i| {
-                            *twiddle_dbl[bit_i + 1].get_unchecked(
-                                (twid_index * 2 + i) & (twiddle_dbl[bit_i + 1].len() - 1),
-                            )
-                        }),
-                        std::array::from_fn(|i| {
-                            *twiddle_dbl[bit_i + 2].get_unchecked(
-                                (twid_index + i) & (twiddle_dbl[bit_i + 2].len() - 1),
-                            )
-                        }),
-                    );
+            match fft_bits - bit_i {
+                1 => {
+                    ifft1_loop(values, twiddle_dbl, fft_bits, bit_i, h);
+                }
+                2 => {
+                    ifft2_loop(values, twiddle_dbl, fft_bits, bit_i, h);
+                }
+                _ => {
+                    ifft3_loop(values, twiddle_dbl, fft_bits, bit_i, h);
                 }
             }
         }
     }
 }
 
+/// # Safety
+unsafe fn ifft_vecwise_loop(values: *mut i32, twiddle_dbl: &[Vec<i32>], fft_bits: usize, h: usize) {
+    for l in 0..(1 << (fft_bits - 1)) {
+        let index = (h << (fft_bits - 1)) + l;
+        let mut val0 = _mm512_load_epi32(values.add(index * 32).cast_const());
+        let mut val1 = _mm512_load_epi32(values.add(index * 32 + 16).cast_const());
+        (val0, val1) = vecwise_ibutterflies(
+            val0,
+            val1,
+            std::array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 8 + i)),
+            std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)),
+            std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)),
+        );
+        (val0, val1) = avx_ibutterfly(
+            val0,
+            val1,
+            _mm512_set1_epi32(*twiddle_dbl[3].get_unchecked(index)),
+        );
+        _mm512_store_epi32(values.add(index * 32), val0);
+        _mm512_store_epi32(values.add(index * 32 + 16), val1);
+    }
+}
+
+/// # Safety
+unsafe fn ifft3_loop(
+    values: *mut i32,
+    twiddle_dbl: &[Vec<i32>],
+    fft_bits: usize,
+    bit_i: usize,
+    index: usize,
+) {
+    for m in 0..(1 << (fft_bits - 3 - bit_i)) {
+        let index = (index << (fft_bits - bit_i - 3)) + m;
+        let offset = index << (bit_i + 3);
+        for l in 0..(1 << bit_i) {
+            ifft3(
+                values,
+                offset + l,
+                bit_i,
+                std::array::from_fn(|i| {
+                    *twiddle_dbl[bit_i]
+                        .get_unchecked((index * 4 + i) & (twiddle_dbl[bit_i].len() - 1))
+                }),
+                std::array::from_fn(|i| {
+                    *twiddle_dbl[bit_i + 1]
+                        .get_unchecked((index * 2 + i) & (twiddle_dbl[bit_i + 1].len() - 1))
+                }),
+                std::array::from_fn(|i| {
+                    *twiddle_dbl[bit_i + 2]
+                        .get_unchecked((index + i) & (twiddle_dbl[bit_i + 2].len() - 1))
+                }),
+            );
+        }
+    }
+}
+
+/// # Safety
+unsafe fn ifft2_loop(
+    values: *mut i32,
+    twiddle_dbl: &[Vec<i32>],
+    fft_bits: usize,
+    bit_i: usize,
+    index: usize,
+) {
+    for m in 0..(1 << (fft_bits - 2 - bit_i)) {
+        let index = (index << (fft_bits - bit_i - 2)) + m;
+        let offset = index << (bit_i + 2);
+        for l in 0..(1 << bit_i) {
+            ifft2(
+                values,
+                offset + l,
+                bit_i,
+                std::array::from_fn(|i| {
+                    *twiddle_dbl[bit_i]
+                        .get_unchecked((index * 2 + i) & (twiddle_dbl[bit_i].len() - 1))
+                }),
+                std::array::from_fn(|i| {
+                    *twiddle_dbl[bit_i + 1]
+                        .get_unchecked((index + i) & (twiddle_dbl[bit_i + 1].len() - 1))
+                }),
+            );
+        }
+    }
+}
+
+/// # Safety
+unsafe fn ifft1_loop(
+    values: *mut i32,
+    twiddle_dbl: &[Vec<i32>],
+    fft_bits: usize,
+    bit_i: usize,
+    index: usize,
+) {
+    for m in 0..(1 << (fft_bits - 1 - bit_i)) {
+        let index = (index << (fft_bits - bit_i - 1)) + m;
+        let offset = index << (bit_i + 1);
+        for l in 0..(1 << bit_i) {
+            ifft1(
+                values,
+                offset + l,
+                bit_i,
+                std::array::from_fn(|i| {
+                    *twiddle_dbl[bit_i].get_unchecked((index + i) & (twiddle_dbl[bit_i].len() - 1))
+                }),
+            );
+        }
+    }
+}
+
 /// # Safety
 pub unsafe fn avx_butterfly(
     val0: __m512i,
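The old body inlined the vecwise pass and hit `todo!()` whenever the remaining layer count was not a multiple of three. The new shape: `ifft_vecwise_loop` consumes the four in-register layers plus a fifth layer (resolving the old `// TODO(spapini): do a fifth layer here.`), which is why the vector-level loop starts at `bit_i = 1` here and at `0` in `ifft_lower_without_vecwise`; the `match fft_bits - bit_i` then peels three layers per `ifft3_loop` step and falls back to a two- or one-layer tail. A runnable sketch of the resulting schedule (illustrative helper, not part of the diff):

```rust
/// Mirrors the dispatch in ifft_lower_with_vecwise: which kernel runs at each
/// step for a given number of vector-level fft bits.
fn schedule(fft_bits: usize) -> Vec<&'static str> {
    let mut steps = vec!["ifft_vecwise_loop"];
    for bit_i in (1..fft_bits).step_by(3) {
        steps.push(match fft_bits - bit_i {
            1 => "ifft1_loop",
            2 => "ifft2_loop",
            _ => "ifft3_loop",
        });
    }
    steps
}
// schedule(6) == ["ifft_vecwise_loop", "ifft3_loop", "ifft2_loop"]
// schedule(7) == ["ifft_vecwise_loop", "ifft3_loop", "ifft3_loop"]
```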
@@ -408,6 +493,48 @@ pub unsafe fn ifft3(
     _mm512_store_epi32(values.add((offset + (7 << log_u32_step)) << 4), val7);
 }
 
+/// # Safety
+pub unsafe fn ifft2(
+    values: *mut i32,
+    offset: usize,
+    log_step: usize,
+    twiddles_dbl0: [i32; 2],
+    twiddles_dbl1: [i32; 1],
+) {
+    let log_u32_step = log_step;
+    // load
+    let mut val0 = _mm512_load_epi32(values.add((offset + (0 << log_u32_step)) << 4).cast_const());
+    let mut val1 = _mm512_load_epi32(values.add((offset + (1 << log_u32_step)) << 4).cast_const());
+    let mut val2 = _mm512_load_epi32(values.add((offset + (2 << log_u32_step)) << 4).cast_const());
+    let mut val3 = _mm512_load_epi32(values.add((offset + (3 << log_u32_step)) << 4).cast_const());
+
+    (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
+    (val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1]));
+
+    (val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0]));
+    (val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0]));
+
+    // store
+    _mm512_store_epi32(values.add((offset + (0 << log_u32_step)) << 4), val0);
+    _mm512_store_epi32(values.add((offset + (1 << log_u32_step)) << 4), val1);
+    _mm512_store_epi32(values.add((offset + (2 << log_u32_step)) << 4), val2);
+    _mm512_store_epi32(values.add((offset + (3 << log_u32_step)) << 4), val3);
+}
+
+/// # Safety
+pub unsafe fn ifft1(values: *mut i32, offset: usize, log_step: usize, twiddles_dbl0: [i32; 1]) {
+    let log_u32_step = log_step;
+    // load
+    let mut val0 = _mm512_load_epi32(values.add((offset + (0 << log_u32_step)) << 4).cast_const());
+    let mut val1 = _mm512_load_epi32(values.add((offset + (1 << log_u32_step)) << 4).cast_const());
+
+    (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
+
+    // store
+    _mm512_store_epi32(values.add((offset + (0 << log_u32_step)) << 4), val0);
+    _mm512_store_epi32(values.add((offset + (1 << log_u32_step)) << 4), val1);
+}
+
 #[cfg(test)]
 mod tests {
     use std::arch::x86_64::_mm512_setr_epi32;
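`ifft2` and `ifft1` are the new tail kernels: a two-layer (radix-4) pass over four vectors and a single-layer pass over two. Note that both second-layer butterflies in `ifft2` use `twiddles_dbl1[0]`; each deeper layer has half as many distinct twiddles, so this is not a copy-paste slip. As a reading aid, a scalar model of what one `avx_ibutterfly` lane is assumed to compute; the butterfly itself is outside this diff, and the doubled-twiddle trick is elided:

```rust
/// Hypothetical scalar stand-in for one lane of avx_ibutterfly, assuming the
/// usual inverse butterfly: (v0, v1) -> (v0 + v1, (v0 - v1) * itwid).
fn ibutterfly_lane(v0: u32, v1: u32, itwid: u32) -> (u32, u32) {
    const P: u64 = (1 << 31) - 1; // M31 modulus
    let sum = (v0 as u64 + v1 as u64) % P;
    let diff = (v0 as u64 + P - v1 as u64) % P;
    (sum as u32, (diff * itwid as u64 % P) as u32)
}
```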
@@ -688,29 +815,29 @@ mod tests {
     #[test]
     fn test_ifft_lower() {
-        let log_size = 4 + 3 + 3;
-        let domain = CanonicCoset::new(log_size).circle_domain();
-        let values = (0..domain.size())
-            .map(|i| BaseField::from_u32_unchecked(i as u32))
-            .collect::<Vec<_>>();
-        let expected_coeffs = ref_ifft(domain, values.clone());
-
-        // Compute.
-        let mut values = BaseFieldVec::from_iter(values);
-        let twiddle_dbls = get_itwiddle_dbls(domain);
-
-        unsafe {
-            ifft_lower(
-                std::mem::transmute(values.data.as_mut_ptr()),
-                Some(&twiddle_dbls[1..4]),
-                &twiddle_dbls[4..],
-                (log_size - 4) as usize,
-                (log_size - 4) as usize,
-            );
+        for log_size in 5..=10 {
+            let domain = CanonicCoset::new(log_size).circle_domain();
+            let values = (0..domain.size())
+                .map(|i| BaseField::from_u32_unchecked(i as u32))
+                .collect::<Vec<_>>();
+            let expected_coeffs = ref_ifft(domain, values.clone());
+
+            // Compute.
+            let mut values = BaseFieldVec::from_iter(values);
+            let twiddle_dbls = get_itwiddle_dbls(domain);
+
+            unsafe {
+                ifft_lower_with_vecwise(
+                    std::mem::transmute(values.data.as_mut_ptr()),
+                    &twiddle_dbls[1..],
+                    (log_size - 4) as usize,
+                    (log_size - 4) as usize,
+                );
-            // Compare.
-            for i in 0..expected_coeffs.len() {
-                assert_eq!(values[i], expected_coeffs[i]);
+                // Compare.
+                for i in 0..expected_coeffs.len() {
+                    assert_eq!(values[i], expected_coeffs[i]);
+                }
             }
         }
     }
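The test now sweeps sizes instead of pinning `log_size = 4 + 3 + 3`, which is what actually exercises the new dispatch: with `fft_bits = log_size - 4`, the range `5..=10` covers the empty tail and every match arm. Derived from the `(1..fft_bits).step_by(3)` loop in `ifft_lower_with_vecwise`, for the record:

```rust
// log_size = 5  -> fft_bits = 1: vecwise pass only, the tail loop never runs
// log_size = 6  -> fft_bits = 2: `1 =>` arm (ifft1_loop)
// log_size = 7  -> fft_bits = 3: `2 =>` arm (ifft2_loop)
// log_size = 8  -> fft_bits = 4: `_ =>` arm (ifft3_loop)
// log_size = 9  -> fft_bits = 5: ifft3_loop, then ifft1_loop
// log_size = 10 -> fft_bits = 6: ifft3_loop, then ifft2_loop
```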