Skip to content

Commit

Permalink
Parallelize bit reverse m31. (#855)
Browse files Browse the repository at this point in the history
* Parallelize bit reverse m31.

* Move UnsafeMut to simd utils

Co-Authored-By: Shahar Samocha <[email protected]>
  • Loading branch information
alonh5 and shaharsamocha7 authored Oct 1, 2024
1 parent 443ae49 commit 6250403
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 36 deletions.
16 changes: 11 additions & 5 deletions crates/prover/src/core/backend/simd/bit_reverse.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
use std::array;

#[cfg(feature = "parallel")]
use rayon::prelude::*;

use super::column::{BaseColumn, SecureColumn};
use super::m31::PackedBaseField;
use super::SimdBackend;
use crate::core::backend::simd::utils::UnsafeMut;
use crate::core::backend::ColumnOps;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::utils::{bit_reverse as cpu_bit_reverse, bit_reverse_index};
use crate::parallel_iter;

// log2 of the number of elements in one packed vector — presumably 4 because a packed word
// holds 16 elements (2^4); TODO confirm against `PackedBaseField`.
const VEC_BITS: u32 = 4;

Expand Down Expand Up @@ -51,9 +56,10 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) {

let log_size = data.len().ilog2();
let a_bits = log_size - 2 * W_BITS - VEC_BITS;
let data = UnsafeMut(data);

// TODO(AlonH): when doing multithreading, do it over a.
for a in 0u32..1 << a_bits {
parallel_iter!(0u32..(1 << a_bits)).for_each(|a| {
let data = unsafe { data.get() };
for w_l in 0u32..1 << W_BITS {
let w_l_rev = w_l.reverse_bits() >> (u32::BITS - W_BITS);
for w_h in 0..w_l_rev + 1 {
Expand All @@ -68,7 +74,7 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) {
// Read first chunk.
// TODO(andrew): Think about optimizing a_bits. What does this mean?
let chunk0 = array::from_fn(|i| unsafe {
*data.get_unchecked(idx + (i << (2 * W_BITS + a_bits)))
*data.get_unchecked_mut(idx + (i << (2 * W_BITS + a_bits)))
});
let values0 = bit_reverse16(chunk0);

Expand All @@ -86,7 +92,7 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) {

// Read bit reversed chunk.
let chunk1 = array::from_fn(|i| unsafe {
*data.get_unchecked(idx_rev + (i << (2 * W_BITS + a_bits)))
*data.get_unchecked_mut(idx_rev + (i << (2 * W_BITS + a_bits)))
});
let values1 = bit_reverse16(chunk1);

Expand All @@ -99,7 +105,7 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) {
}
}
}
}
})
}

/// Bit reverses 256 M31 values, packed in 16 words of 16 elements each.
Expand Down
1 change: 0 additions & 1 deletion crates/prover/src/core/backend/simd/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,6 @@ impl PolyOps for SimdBackend {
domain: CircleDomain,
twiddles: &TwiddleTree<Self>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
// TODO(AlonH): Handle small cases.
let log_size = domain.log_size();
let fft_log_size = poly.log_size();
assert!(
Expand Down
6 changes: 3 additions & 3 deletions crates/prover/src/core/backend/simd/fft/ifft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use rayon::prelude::*;
use super::{
compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE,
};
use crate::core::backend::simd::fft::UnsafeMutI32;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::utils::UnsafeMut;
use crate::core::circle::Coset;
use crate::core::fields::FieldExpOps;
use crate::core::utils::bit_reverse;
Expand Down Expand Up @@ -86,7 +86,7 @@ pub unsafe fn ifft_lower_with_vecwise(

assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2));

let values = UnsafeMutI32(values);
let values = UnsafeMut(values);
parallel_iter!(0..1 << (log_size - fft_layers)).for_each(|index_h| {
let values = values.get();
ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h);
Expand Down Expand Up @@ -138,7 +138,7 @@ pub unsafe fn ifft_lower_without_vecwise(
) {
assert!(log_size >= LOG_N_LANES as usize);

let values = UnsafeMutI32(values);
let values = UnsafeMut(values);
parallel_iter!(0..1 << (log_size - fft_layers - LOG_N_LANES as usize)).for_each(|index_h| {
let values = values.get();
for layer in (0..fft_layers).step_by(3) {
Expand Down
24 changes: 2 additions & 22 deletions crates/prover/src/core/backend/simd/fft/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::simd::{simd_swizzle, u32x16, u32x8};
use rayon::prelude::*;

use super::m31::PackedBaseField;
use super::utils::UnsafeMut;
use crate::core::fields::m31::P;
use crate::parallel_iter;

Expand All @@ -14,27 +15,6 @@ pub const CACHED_FFT_LOG_SIZE: u32 = 16;

pub const MIN_FFT_LOG_SIZE: u32 = 5;

// TODO(andrew): Examine usage of unsafe in SIMD FFT.
/// Newtype over a raw `*mut u32` that is unconditionally marked [`Send`] and [`Sync`].
///
/// Raw pointers are neither `Send` nor `Sync` on their own; wrapping one in this type lets it be
/// captured by closures that run on other threads. Note that, despite the `I32` suffix, the
/// pointee type is `u32`.
pub struct UnsafeMutI32(pub *mut u32);

impl UnsafeMutI32 {
    /// Returns a copy of the wrapped raw pointer.
    ///
    /// Obtaining the pointer is safe; dereferencing it is not, and it is the caller's job to make
    /// sure concurrent users never touch the same memory at the same time.
    pub fn get(&self) -> *mut u32 {
        let Self(raw) = *self;
        raw
    }
}

// SAFETY: not checked by the compiler. Every user of the wrapper must guarantee that accesses
// made through the pointer from different threads do not race.
unsafe impl Sync for UnsafeMutI32 {}
unsafe impl Send for UnsafeMutI32 {}

/// Newtype over a raw `*const u32` that is unconditionally marked [`Send`] and [`Sync`].
///
/// Raw pointers are neither `Send` nor `Sync` on their own; wrapping one in this type lets it be
/// captured by closures that run on other threads. Note that, despite the `I32` suffix, the
/// pointee type is `u32`.
pub struct UnsafeConstI32(pub *const u32);

impl UnsafeConstI32 {
    /// Returns a copy of the wrapped raw pointer.
    ///
    /// Obtaining the pointer is safe; dereferencing it is not, and the caller must make sure the
    /// pointee is valid and not being written concurrently.
    pub fn get(&self) -> *const u32 {
        let Self(raw) = *self;
        raw
    }
}

// SAFETY: not checked by the compiler. Every user of the wrapper must guarantee that the pointee
// outlives all threads that read through the pointer and is not mutated while they do so.
unsafe impl Sync for UnsafeConstI32 {}
unsafe impl Send for UnsafeConstI32 {}

// TODO(andrew): FFTs return a redundant representation, that can get the value P. need to deal with
// it. Either: reduce before commitment or regenerate proof with new seed if redundant value
// decommitted.
Expand All @@ -56,7 +36,7 @@ unsafe impl Sync for UnsafeConstI32 {}
pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) {
let half = log_n_vecs / 2;

let values = UnsafeMutI32(values);
let values = UnsafeMut(values);
parallel_iter!(0..1 << half).for_each(|a| {
let values = values.get();
for b in 0..1 << (log_n_vecs & 1) {
Expand Down
10 changes: 5 additions & 5 deletions crates/prover/src/core/backend/simd/fft/rfft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use rayon::prelude::*;
use super::{
compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE,
};
use crate::core::backend::simd::fft::{UnsafeConstI32, UnsafeMutI32};
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::utils::{UnsafeConst, UnsafeMut};
use crate::core::circle::Coset;
use crate::core::utils::bit_reverse;
use crate::parallel_iter;
Expand Down Expand Up @@ -90,8 +90,8 @@ pub unsafe fn fft_lower_with_vecwise(

assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2));

let src = UnsafeConstI32(src);
let dst = UnsafeMutI32(dst);
let src = UnsafeConst(src);
let dst = UnsafeMut(dst);
parallel_iter!(0..1 << (log_size - fft_layers)).for_each(|index_h| {
let mut src = src.get();
let dst = dst.get();
Expand Down Expand Up @@ -154,8 +154,8 @@ pub unsafe fn fft_lower_without_vecwise(
) {
assert!(log_size >= LOG_N_LANES as usize);

let src = UnsafeConstI32(src);
let dst = UnsafeMutI32(dst);
let src = UnsafeConst(src);
let dst = UnsafeMut(dst);
parallel_iter!(0..1 << (log_size - fft_layers - LOG_N_LANES as usize)).for_each(|index_h| {
let mut src = src.get();
let dst = dst.get();
Expand Down
27 changes: 27 additions & 0 deletions crates/prover/src/core/backend/simd/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,33 @@ const fn parity_interleave<const N: usize>(odd: bool) -> [usize; N] {
res
}

// TODO(andrew): Examine usage of unsafe in SIMD FFT.
/// Newtype over a raw `*mut T` that is unconditionally marked [`Send`] and [`Sync`].
///
/// Raw pointers are neither `Send` nor `Sync` on their own; wrapping one in this type lets it be
/// captured by closures that run on other threads. The pointee may be unsized (e.g. `[T]`).
pub struct UnsafeMut<T: ?Sized>(pub *mut T);

impl<T: ?Sized> UnsafeMut<T> {
    /// Returns a copy of the wrapped raw mutable pointer.
    ///
    /// # Safety
    ///
    /// The returned pointer can be obtained from several threads at once. The caller must make
    /// sure the pointee is still valid and that accesses made through the pointer never race
    /// with one another.
    pub unsafe fn get(&self) -> *mut T {
        let Self(raw) = self;
        *raw
    }
}

// SAFETY: not checked by the compiler. Every user of the wrapper must guarantee that accesses
// made through the pointer from different threads do not race.
unsafe impl<T: ?Sized> Sync for UnsafeMut<T> {}
unsafe impl<T: ?Sized> Send for UnsafeMut<T> {}

/// Newtype over a raw `*const T` that is unconditionally marked [`Send`] and [`Sync`].
///
/// Raw pointers are neither `Send` nor `Sync` on their own; wrapping one in this type lets it be
/// captured by closures that run on other threads. Like `UnsafeMut`, the pointee may be unsized
/// (e.g. `[T]`), so slice pointers can be wrapped directly.
pub struct UnsafeConst<T: ?Sized>(pub *const T);

impl<T: ?Sized> UnsafeConst<T> {
    /// Returns a copy of the wrapped raw constant pointer.
    ///
    /// # Safety
    ///
    /// The returned pointer can be obtained from several threads at once. The caller must make
    /// sure the pointee is still valid and is not being mutated while it is read through the
    /// pointer.
    pub unsafe fn get(&self) -> *const T {
        self.0
    }
}

// SAFETY: not checked by the compiler. Every user of the wrapper must guarantee that the pointee
// outlives all threads that read through the pointer and is not mutated while they do so.
unsafe impl<T: ?Sized> Send for UnsafeConst<T> {}
unsafe impl<T: ?Sized> Sync for UnsafeConst<T> {}

#[cfg(test)]
mod tests {
use std::simd::{u32x4, Swizzle};
Expand Down
1 change: 1 addition & 0 deletions crates/prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
slice_first_last_chunk,
slice_flatten,
slice_group_by,
slice_ptr_get,
stdsimd
)]
pub mod constraint_framework;
Expand Down

0 comments on commit 6250403

Please sign in to comment.