From 01261a4d6a7094742a0bee345b9fffa5ccf014ef Mon Sep 17 00:00:00 2001
From: Shahar Papini
Date: Fri, 16 Feb 2024 20:13:28 +0200
Subject: [PATCH] AVX bit reverse

---
 Cargo.toml                             |   4 +
 benches/bit_rev.rs                     |  48 +++++++++
 benches/fri.rs                         |  11 +-
 src/core/backend/avx512/bit_reverse.rs | 138 +++++++++++++++++++++++++
 src/core/backend/avx512/mod.rs         |   1 +
 src/core/backend/mod.rs                |   1 +
 src/core/utils.rs                      |  13 ++-
 src/lib.rs                             |   1 +
 8 files changed, 211 insertions(+), 6 deletions(-)
 create mode 100644 benches/bit_rev.rs
 create mode 100644 src/core/backend/avx512/bit_reverse.rs
 create mode 100644 src/core/backend/avx512/mod.rs

diff --git a/Cargo.toml b/Cargo.toml
index 60c1437e8..ddcbf49b0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -25,6 +25,10 @@ avx512 = []
 codegen-units = 1
 lto = true
 
+[[bench]]
+name = "bit_rev"
+harness = false
+
 [[bench]]
 harness = false
 name = "field"
diff --git a/benches/bit_rev.rs b/benches/bit_rev.rs
new file mode 100644
index 000000000..3ee16c0ca
--- /dev/null
+++ b/benches/bit_rev.rs
@@ -0,0 +1,48 @@
+#![feature(iter_array_chunks)]
+
+use criterion::Criterion;
+
+#[cfg(target_arch = "x86_64")]
+pub fn cpu_bit_rev(c: &mut criterion::Criterion) {
+    use prover_research::core::fields::m31::BaseField;
+
+    const SIZE: usize = 1 << 28;
+    let mut data: Vec<_> = (0..SIZE as u32)
+        .map(BaseField::from_u32_unchecked)
+        .collect();
+
+    c.bench_function("cpu bit_rev", |b| {
+        b.iter(|| {
+            data = prover_research::core::utils::bit_reverse(std::mem::take(&mut data));
+        })
+    });
+}
+
+#[cfg(target_arch = "x86_64")]
+pub fn avx512_bit_rev(c: &mut criterion::Criterion) {
+    use prover_research::core::backend::avx512::bit_reverse::bit_reverse_m31;
+    use prover_research::core::fields::m31::BaseField;
+    use prover_research::platform;
+    if !platform::avx512_detected() {
+        return;
+    }
+
+    const SIZE: usize = 1 << 28;
+    let data: Vec<_> = (0..SIZE as u32)
+        .map(BaseField::from_u32_unchecked)
+        .collect();
+    let mut data: Vec<_> = data.into_iter().array_chunks::<16>().collect();
+
+    c.bench_function("avx bit_rev", |b| {
+        b.iter(|| {
+            bit_reverse_m31(&mut data);
+        })
+    });
+}
+
+#[cfg(target_arch = "x86_64")]
+criterion::criterion_group!(
+    name=avx_bit_rev;
+    config = Criterion::default().sample_size(10);
+    targets=avx512_bit_rev, cpu_bit_rev);
+criterion::criterion_main!(avx_bit_rev);
diff --git a/benches/fri.rs b/benches/fri.rs
index eeaab6274..e47839e75 100644
--- a/benches/fri.rs
+++ b/benches/fri.rs
@@ -1,20 +1,21 @@
 use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use prover_research::core::backend::CPUBackend;
 use prover_research::core::fields::m31::BaseField;
-use prover_research::core::fri::fold_line;
+use prover_research::core::fri::FriOps;
 use prover_research::core::poly::circle::CanonicCoset;
 use prover_research::core::poly::line::{LineDomain, LineEvaluation};
 
 fn folding_benchmark(c: &mut Criterion) {
     const LOG_SIZE: u32 = 12;
-    let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE).coset());
+    let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset());
     let evals = LineEvaluation::new(
         domain,
-        vec![BaseField::from_u32_unchecked(712837213); 1 << LOG_SIZE],
+        vec![BaseField::from_u32_unchecked(712837213).into(); 1 << LOG_SIZE],
     );
-    let alpha = BaseField::from_u32_unchecked(12389);
+    let alpha = BaseField::from_u32_unchecked(12389).into();
     c.bench_function("fold_line", |b| {
         b.iter(|| {
-            black_box(fold_line(black_box(&evals), black_box(alpha)));
+            black_box(CPUBackend::fold_line(black_box(&evals), black_box(alpha)));
         })
     });
 }
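Note on the new benchmark: `cpu_bit_rev` and `avx512_bit_rev` measure the same permutation, B[i] = A[bit_reverse(i)], over 2^28 M31 values; the AVX path first packs the values into `[BaseField; 16]` chunks and skips the run entirely when `platform::avx512_detected()` is false. For reference, here is a minimal scalar sketch of the permutation being measured. `bit_reverse_sketch` is a hypothetical name written for this note; it mirrors `bit_reverse_index` from `src/core/utils.rs`, though the in-repo `bit_reverse` may differ in details (its TODO notes it is not yet cache friendly):

```rust
/// Sketch of the scalar bit-reversal permutation that `cpu_bit_rev` measures.
/// Each pair (i, rev(i)) is swapped once, where rev reverses the low
/// log2(n) bits of the index, exactly as `bit_reverse_index` does.
fn bit_reverse_sketch<T>(v: &mut [T]) {
    let n = v.len();
    assert!(n.is_power_of_two());
    let log_n = n.ilog2();
    for i in 0..n {
        let j = i.reverse_bits() >> (usize::BITS - log_n);
        if i < j {
            v.swap(i, j);
        }
    }
}
```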
diff --git a/src/core/backend/avx512/bit_reverse.rs b/src/core/backend/avx512/bit_reverse.rs
new file mode 100644
index 000000000..4bc567eb6
--- /dev/null
+++ b/src/core/backend/avx512/bit_reverse.rs
@@ -0,0 +1,138 @@
+use std::arch::x86_64::{__m512i, _mm512_permutex2var_epi32};
+
+use crate::core::fields::m31::BaseField;
+use crate::core::utils::{bit_reverse_index, IteratorMutExt};
+
+const VEC_BITS: u32 = 4;
+const W_BITS: u32 = 3;
+const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS;
+
+// TODO(spapini): Use PackedBaseField type.
+/// Bit reverses packed M31 values.
+/// Given an array A[0..2^n), computes B[i] = A[bit_reverse(i)].
+pub fn bit_reverse_m31(data: &mut [[BaseField; 16]]) {
+    assert!(data.len().is_power_of_two());
+    assert!(data.len().ilog2() >= MIN_LOG_SIZE);
+
+    // Indices in the array are of the form v_h w_h a w_l v_l, with
+    // |v_h| = |v_l| = VEC_BITS, |w_h| = |w_l| = W_BITS, |a| = n - 2*W_BITS - VEC_BITS.
+    // The loops go over a, w_l, w_h, and then swaps the 16 by 16 values at:
+    // * w_h a w_l * <-> * rev(w_h a w_l) *.
+    // These are 1 or 2 chunks of 2^W_BITS contiguous AVX512 vectors.
+
+    let log_size = data.len().ilog2();
+    let a_bits = log_size - 2 * W_BITS - VEC_BITS;
+
+    // TODO(spapini): when doing multithreading, do it over a.
+    for a in 0u32..(1 << a_bits) {
+        for w_l in 0u32..(1 << W_BITS) {
+            for w_h in 0u32..(1 << W_BITS) {
+                let idx = ((((w_h << a_bits) | a) << W_BITS) | w_l) as usize;
+                let idx_rev = bit_reverse_index(idx, log_size - VEC_BITS);
+
+                // In order to not swap twice, only swap if idx <= idx_rev.
+                if idx > idx_rev {
+                    continue;
+                }
+
+                // Read first chunk.
+                let chunk0 = std::array::from_fn(|i| data[idx + (i << (2 * W_BITS + a_bits))]);
+                let values0 = bit_reverse16(chunk0);
+
+                if idx == idx_rev {
+                    // Palindrome index. Write into the same chunk.
+                    data[idx..]
+                        .iter_mut()
+                        .step_by(1 << (2 * W_BITS + a_bits))
+                        .assign(values0);
+                    continue;
+                }
+
+                // Read bit reversed chunk.
+                let chunk1 = std::array::from_fn(|i| data[idx_rev + (i << (2 * W_BITS + a_bits))]);
+                let values1 = bit_reverse16(chunk1);
+
+                data[idx..]
+                    .iter_mut()
+                    .step_by(1 << (2 * W_BITS + a_bits))
+                    .assign(values1);
+                data[idx_rev..]
+                    .iter_mut()
+                    .step_by(1 << (2 * W_BITS + a_bits))
+                    .assign(values0);
+            }
+        }
+    }
+}
+
+fn bit_reverse16(data: [[BaseField; 16]; 16]) -> [[BaseField; 16]; 16] {
+    let mut data: [__m512i; 16] = unsafe { std::mem::transmute(data) };
+    // abcd0123 => 0abc123d
+    const L: __m512i = unsafe {
+        core::mem::transmute([
+            0b00000, 0b10000, 0b00001, 0b10001, 0b00010, 0b10010, 0b00011, 0b10011, 0b00100,
+            0b10100, 0b00101, 0b10101, 0b00110, 0b10110, 0b00111, 0b10111,
+        ])
+    };
+    const H: __m512i = unsafe {
+        core::mem::transmute([
+            0b01000, 0b11000, 0b01001, 0b11001, 0b01010, 0b11010, 0b01011, 0b11011, 0b01100,
+            0b11100, 0b01101, 0b11101, 0b01110, 0b11110, 0b01111, 0b11111,
+        ])
+    };
+    for _ in 0..4 {
+        unsafe {
+            data = [
+                _mm512_permutex2var_epi32(data[0], L, data[1]),
+                _mm512_permutex2var_epi32(data[2], L, data[3]),
+                _mm512_permutex2var_epi32(data[4], L, data[5]),
+                _mm512_permutex2var_epi32(data[6], L, data[7]),
+                _mm512_permutex2var_epi32(data[8], L, data[9]),
+                _mm512_permutex2var_epi32(data[10], L, data[11]),
+                _mm512_permutex2var_epi32(data[12], L, data[13]),
+                _mm512_permutex2var_epi32(data[14], L, data[15]),
+                _mm512_permutex2var_epi32(data[0], H, data[1]),
+                _mm512_permutex2var_epi32(data[2], H, data[3]),
+                _mm512_permutex2var_epi32(data[4], H, data[5]),
+                _mm512_permutex2var_epi32(data[6], H, data[7]),
+                _mm512_permutex2var_epi32(data[8], H, data[9]),
+                _mm512_permutex2var_epi32(data[10], H, data[11]),
+                _mm512_permutex2var_epi32(data[12], H, data[13]),
+                _mm512_permutex2var_epi32(data[14], H, data[15]),
+            ];
+        }
+    }
+    unsafe { std::mem::transmute(data) }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::bit_reverse16;
+    use crate::core::backend::avx512::bit_reverse::bit_reverse_m31;
+    use crate::core::fields::m31::BaseField;
+    use crate::core::utils::bit_reverse;
+
+    #[test]
+    fn test_bit_reverse16() {
+        let data: [u32; 256] = std::array::from_fn(|i| i as u32);
+        let expected: [u32; 256] = std::array::from_fn(|i| (i as u32).reverse_bits() >> 24);
+        unsafe {
+            let data = bit_reverse16(std::mem::transmute(data));
+            assert_eq!(std::mem::transmute::<_, [u32; 256]>(data), expected);
+        }
+    }
+
+    #[test]
+    fn test_bit_reverse() {
+        const SIZE: usize = 1 << 15;
+        let data: Vec<_> = (0..SIZE as u32)
+            .map(BaseField::from_u32_unchecked)
+            .collect();
+        let expected = bit_reverse(data.clone());
+        let mut data: Vec<_> = data.into_iter().array_chunks::<16>().collect();
+        let expected: Vec<_> = expected.into_iter().array_chunks::<16>().collect();
+
+        bit_reverse_m31(&mut data);
+        assert_eq!(data, expected);
+    }
+}
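The kernel above reverses the index bits in two layers. The outer loops handle the middle `log_size - VEC_BITS` bits: each iteration gathers 16 vectors with stride `1 << (2 * W_BITS + a_bits)` starting at `idx` and writes them back at `bit_reverse_index(idx)`. `bit_reverse16` then reverses the remaining 8 bits, 4 selecting the vector within the gather and 4 selecting the lane, with four `_mm512_permutex2var_epi32` rounds; each round realizes the `abcd0123 => 0abc123d` bit rotation noted in the comment, and applying that rotation four times reverses all 8 bits. Below is a scalar model of one round, written here only to illustrate the index movement (`interleave_round` is not part of the patch):

```rust
/// Scalar model of one bit_reverse16 round. The element at 8-bit index
/// `abcd0123` (high 4 bits = vector, low 4 bits = lane) moves to
/// index `0abc123d`. Applying this map four times reverses all 8 bits.
fn interleave_round(data: [u32; 256]) -> [u32; 256] {
    let mut out = [0u32; 256];
    for (i, &x) in data.iter().enumerate() {
        let (vec, lane) = (i >> 4, i & 0xf);
        // New vector = (top lane bit, top 3 vector bits);
        // new lane = (low 3 lane bits, bottom vector bit).
        let j = ((lane >> 3) << 7) | ((vec >> 1) << 4) | ((lane & 0b111) << 1) | (vec & 1);
        out[j] = x;
    }
    out
}
```

Iterating this map four times on the identity array `0..256` yields `i.reverse_bits() >> 24`, which is exactly the expectation checked by `test_bit_reverse16`.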
diff --git a/src/core/backend/avx512/mod.rs b/src/core/backend/avx512/mod.rs
new file mode 100644
index 000000000..0e2440cee
--- /dev/null
+++ b/src/core/backend/avx512/mod.rs
@@ -0,0 +1 @@
+pub mod bit_reverse;
diff --git a/src/core/backend/mod.rs b/src/core/backend/mod.rs
index d0d4183ab..f4331d88f 100644
--- a/src/core/backend/mod.rs
+++ b/src/core/backend/mod.rs
@@ -8,6 +8,7 @@ use super::fields::qm31::SecureField;
 use super::fields::Field;
 use super::poly::circle::PolyOps;
 
+pub mod avx512;
 pub mod cpu;
 
 pub trait Backend:
diff --git a/src/core/utils.rs b/src/core/utils.rs
index 369a018d6..b0b03ef4e 100644
--- a/src/core/utils.rs
+++ b/src/core/utils.rs
@@ -1,3 +1,14 @@
+pub trait IteratorMutExt<'a, T: 'a>: Iterator<Item = &'a mut T> {
+    fn assign(self, other: impl IntoIterator<Item = T>)
+    where
+        Self: Sized,
+    {
+        self.zip(other).for_each(|(a, b)| *a = b);
+    }
+}
+
+impl<'a, T: 'a, I: Iterator<Item = &'a mut T>> IteratorMutExt<'a, T> for I {}
+
 pub(crate) fn bit_reverse_index(i: usize, log_size: u32) -> usize {
     i.reverse_bits() >> (usize::BITS - log_size)
 }
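The `IteratorMutExt::assign` helper added above zips a mutable iterator with any source and writes element-wise; `bit_reverse_m31` uses it to scatter each 16-vector result back through a strided `step_by` view. Since the trait is exported as `pub`, a small standalone usage sketch (not from the patch):

```rust
use prover_research::core::utils::IteratorMutExt;

fn main() {
    let mut dst = [0u32; 8];
    // Write four values into every other slot, mirroring how bit_reverse_m31
    // writes chunks back with step_by(1 << (2 * W_BITS + a_bits)).
    dst.iter_mut().step_by(2).assign([1, 2, 3, 4]);
    assert_eq!(dst, [1, 0, 2, 0, 3, 0, 4, 0]);
}
```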
@@ -9,7 +20,7 @@ pub(crate) fn bit_reverse_index(i: usize, log_size: u32) -> usize {
 /// Panics if the length of the slice is not a power of two.
 // TODO(AlonH): Consider benchmarking this function.
 // TODO: Implement cache friendly implementation.
-pub(crate) fn bit_reverse<T, U: AsMut<[T]>>(mut v: U) -> U {
+pub fn bit_reverse<T, U: AsMut<[T]>>(mut v: U) -> U {
     let n = v.as_mut().len();
     assert!(n.is_power_of_two());
     let log_n = n.ilog2();
diff --git a/src/lib.rs b/src/lib.rs
index 8929b091c..1c860b95b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,5 +1,6 @@
 #![feature(
     array_chunks,
+    iter_array_chunks,
     exact_size_is_empty,
     is_sorted,
     new_uninit,
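End-to-end, the intended call pattern follows `test_bit_reverse`: pack the flat vector into 16-lane chunks (using the nightly `iter_array_chunks` feature enabled above), bit reverse in place, and the result matches the scalar `bit_reverse`. A condensed sketch, assuming a nightly toolchain and a build where the AVX512 intrinsics are usable; `cargo bench --bench bit_rev` runs the criterion comparison, and `avx512_bit_rev` exits early on machines without AVX512:

```rust
#![feature(iter_array_chunks)]

use prover_research::core::backend::avx512::bit_reverse::bit_reverse_m31;
use prover_research::core::fields::m31::BaseField;
use prover_research::core::utils::bit_reverse;

fn main() {
    // SIZE / 16 chunks must be at least 2^MIN_LOG_SIZE = 2^10 vectors.
    const SIZE: usize = 1 << 15;
    let flat: Vec<_> = (0..SIZE as u32).map(BaseField::from_u32_unchecked).collect();

    // Pack into the [BaseField; 16] layout that bit_reverse_m31 expects.
    let mut packed: Vec<[BaseField; 16]> =
        flat.clone().into_iter().array_chunks::<16>().collect();
    bit_reverse_m31(&mut packed);

    // The AVX512 and scalar paths must agree.
    let expected: Vec<[BaseField; 16]> =
        bit_reverse(flat).into_iter().array_chunks::<16>().collect();
    assert_eq!(packed, expected);
}
```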