From 6250403d014dabadcdd36a999567b23bb7d33426 Mon Sep 17 00:00:00 2001 From: Alon Haramati <91828241+alonh5@users.noreply.github.com> Date: Tue, 1 Oct 2024 10:25:34 +0300 Subject: [PATCH] Parallelize bit reverse m31. (#855) * Parallelize bit reverse m31. * Move UnsafeMut to simd utils Co-Authored-By: Shahar Samocha --- .../src/core/backend/simd/bit_reverse.rs | 16 +++++++---- crates/prover/src/core/backend/simd/circle.rs | 1 - .../prover/src/core/backend/simd/fft/ifft.rs | 6 ++--- .../prover/src/core/backend/simd/fft/mod.rs | 24 ++--------------- .../prover/src/core/backend/simd/fft/rfft.rs | 10 +++---- crates/prover/src/core/backend/simd/utils.rs | 27 +++++++++++++++++++ crates/prover/src/lib.rs | 1 + 7 files changed, 49 insertions(+), 36 deletions(-) diff --git a/crates/prover/src/core/backend/simd/bit_reverse.rs b/crates/prover/src/core/backend/simd/bit_reverse.rs index d6062d381..1c2418d7d 100644 --- a/crates/prover/src/core/backend/simd/bit_reverse.rs +++ b/crates/prover/src/core/backend/simd/bit_reverse.rs @@ -1,12 +1,17 @@ use std::array; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + use super::column::{BaseColumn, SecureColumn}; use super::m31::PackedBaseField; use super::SimdBackend; +use crate::core::backend::simd::utils::UnsafeMut; use crate::core::backend::ColumnOps; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::utils::{bit_reverse as cpu_bit_reverse, bit_reverse_index}; +use crate::parallel_iter; const VEC_BITS: u32 = 4; @@ -51,9 +56,10 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) { let log_size = data.len().ilog2(); let a_bits = log_size - 2 * W_BITS - VEC_BITS; + let data = UnsafeMut(data); - // TODO(AlonH): when doing multithreading, do it over a. 
- for a in 0u32..1 << a_bits { + parallel_iter!(0u32..(1 << a_bits)).for_each(|a| { + let data = unsafe { data.get() }; for w_l in 0u32..1 << W_BITS { let w_l_rev = w_l.reverse_bits() >> (u32::BITS - W_BITS); for w_h in 0..w_l_rev + 1 { @@ -68,7 +74,7 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) { // Read first chunk. // TODO(andrew): Think about optimizing a_bits. What does this mean? let chunk0 = array::from_fn(|i| unsafe { - *data.get_unchecked(idx + (i << (2 * W_BITS + a_bits))) + *data.get_unchecked_mut(idx + (i << (2 * W_BITS + a_bits))) }); let values0 = bit_reverse16(chunk0); @@ -86,7 +92,7 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) { // Read bit reversed chunk. let chunk1 = array::from_fn(|i| unsafe { - *data.get_unchecked(idx_rev + (i << (2 * W_BITS + a_bits))) + *data.get_unchecked_mut(idx_rev + (i << (2 * W_BITS + a_bits))) }); let values1 = bit_reverse16(chunk1); @@ -99,7 +105,7 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) { } } } - } + }) } /// Bit reverses 256 M31 values, packed in 16 words of 16 elements each. diff --git a/crates/prover/src/core/backend/simd/circle.rs b/crates/prover/src/core/backend/simd/circle.rs index d8e803f4c..a20721a4f 100644 --- a/crates/prover/src/core/backend/simd/circle.rs +++ b/crates/prover/src/core/backend/simd/circle.rs @@ -229,7 +229,6 @@ impl PolyOps for SimdBackend { domain: CircleDomain, twiddles: &TwiddleTree, ) -> CircleEvaluation { - // TODO(AlonH): Handle small cases. 
let log_size = domain.log_size(); let fft_log_size = poly.log_size(); assert!( diff --git a/crates/prover/src/core/backend/simd/fft/ifft.rs b/crates/prover/src/core/backend/simd/fft/ifft.rs index aeddfd35a..77b096d9c 100644 --- a/crates/prover/src/core/backend/simd/fft/ifft.rs +++ b/crates/prover/src/core/backend/simd/fft/ifft.rs @@ -9,8 +9,8 @@ use rayon::prelude::*; use super::{ compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, }; -use crate::core::backend::simd::fft::UnsafeMutI32; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::backend::simd::utils::UnsafeMut; use crate::core::circle::Coset; use crate::core::fields::FieldExpOps; use crate::core::utils::bit_reverse; @@ -86,7 +86,7 @@ pub unsafe fn ifft_lower_with_vecwise( assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); - let values = UnsafeMutI32(values); + let values = UnsafeMut(values); parallel_iter!(0..1 << (log_size - fft_layers)).for_each(|index_h| { let values = values.get(); ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h); @@ -138,7 +138,7 @@ pub unsafe fn ifft_lower_without_vecwise( ) { assert!(log_size >= LOG_N_LANES as usize); - let values = UnsafeMutI32(values); + let values = UnsafeMut(values); parallel_iter!(0..1 << (log_size - fft_layers - LOG_N_LANES as usize)).for_each(|index_h| { let values = values.get(); for layer in (0..fft_layers).step_by(3) { diff --git a/crates/prover/src/core/backend/simd/fft/mod.rs b/crates/prover/src/core/backend/simd/fft/mod.rs index ea2c149af..ba091b145 100644 --- a/crates/prover/src/core/backend/simd/fft/mod.rs +++ b/crates/prover/src/core/backend/simd/fft/mod.rs @@ -4,6 +4,7 @@ use std::simd::{simd_swizzle, u32x16, u32x8}; use rayon::prelude::*; use super::m31::PackedBaseField; +use super::utils::UnsafeMut; use crate::core::fields::m31::P; use crate::parallel_iter; @@ -14,27 +15,6 @@ pub const CACHED_FFT_LOG_SIZE: u32 = 16; pub const 
MIN_FFT_LOG_SIZE: u32 = 5; -// TODO(andrew): Examine usage of unsafe in SIMD FFT. -pub struct UnsafeMutI32(pub *mut u32); -impl UnsafeMutI32 { - pub fn get(&self) -> *mut u32 { - self.0 - } -} - -unsafe impl Send for UnsafeMutI32 {} -unsafe impl Sync for UnsafeMutI32 {} - -pub struct UnsafeConstI32(pub *const u32); -impl UnsafeConstI32 { - pub fn get(&self) -> *const u32 { - self.0 - } -} - -unsafe impl Send for UnsafeConstI32 {} -unsafe impl Sync for UnsafeConstI32 {} - // TODO(andrew): FFTs return a redundant representation, that can get the value P. need to deal with // it. Either: reduce before commitment or regenerate proof with new seed if redundant value // decommitted. @@ -56,7 +36,7 @@ unsafe impl Sync for UnsafeConstI32 {} pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) { let half = log_n_vecs / 2; - let values = UnsafeMutI32(values); + let values = UnsafeMut(values); parallel_iter!(0..1 << half).for_each(|a| { let values = values.get(); for b in 0..1 << (log_n_vecs & 1) { diff --git a/crates/prover/src/core/backend/simd/fft/rfft.rs b/crates/prover/src/core/backend/simd/fft/rfft.rs index 6ea1cd0ae..d28c8a00d 100644 --- a/crates/prover/src/core/backend/simd/fft/rfft.rs +++ b/crates/prover/src/core/backend/simd/fft/rfft.rs @@ -10,8 +10,8 @@ use rayon::prelude::*; use super::{ compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, }; -use crate::core::backend::simd::fft::{UnsafeConstI32, UnsafeMutI32}; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::backend::simd::utils::{UnsafeConst, UnsafeMut}; use crate::core::circle::Coset; use crate::core::utils::bit_reverse; use crate::parallel_iter; @@ -90,8 +90,8 @@ pub unsafe fn fft_lower_with_vecwise( assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); - let src = UnsafeConstI32(src); - let dst = UnsafeMutI32(dst); + let src = UnsafeConst(src); + let dst = UnsafeMut(dst); parallel_iter!(0..1 << (log_size - 
fft_layers)).for_each(|index_h| { let mut src = src.get(); let dst = dst.get(); @@ -154,8 +154,8 @@ pub unsafe fn fft_lower_without_vecwise( ) { assert!(log_size >= LOG_N_LANES as usize); - let src = UnsafeConstI32(src); - let dst = UnsafeMutI32(dst); + let src = UnsafeConst(src); + let dst = UnsafeMut(dst); parallel_iter!(0..1 << (log_size - fft_layers - LOG_N_LANES as usize)).for_each(|index_h| { let mut src = src.get(); let dst = dst.get(); diff --git a/crates/prover/src/core/backend/simd/utils.rs b/crates/prover/src/core/backend/simd/utils.rs index 87dfd2246..b5cb9e986 100644 --- a/crates/prover/src/core/backend/simd/utils.rs +++ b/crates/prover/src/core/backend/simd/utils.rs @@ -24,6 +24,33 @@ const fn parity_interleave<const N: usize>(odd: bool) -> [usize; N] { res } +// TODO(andrew): Examine usage of unsafe in SIMD FFT. +pub struct UnsafeMut<T: ?Sized>(pub *mut T); +impl<T: ?Sized> UnsafeMut<T> { + /// # Safety + /// + /// Returns a raw mutable pointer. + pub unsafe fn get(&self) -> *mut T { + self.0 + } +} + +unsafe impl<T: ?Sized> Send for UnsafeMut<T> {} +unsafe impl<T: ?Sized> Sync for UnsafeMut<T> {} + +pub struct UnsafeConst<T>(pub *const T); +impl<T> UnsafeConst<T> { + /// # Safety + /// + /// Returns a raw constant pointer. + pub unsafe fn get(&self) -> *const T { + self.0 + } +} + +unsafe impl<T> Send for UnsafeConst<T> {} +unsafe impl<T> Sync for UnsafeConst<T> {} + #[cfg(test)] mod tests { use std::simd::{u32x4, Swizzle}; diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index 1e9c3be74..49adff0f3 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -15,6 +15,7 @@ slice_first_last_chunk, slice_flatten, slice_group_by, + slice_ptr_get, stdsimd )] pub mod constraint_framework;