diff --git a/Cargo.toml b/Cargo.toml index 60c1437e8..ddcbf49b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,10 @@ avx512 = [] codegen-units = 1 lto = true +[[bench]] +name = "bit_rev" +harness = false + [[bench]] harness = false name = "field" diff --git a/benches/bit_rev.rs b/benches/bit_rev.rs new file mode 100644 index 000000000..3ee16c0ca --- /dev/null +++ b/benches/bit_rev.rs @@ -0,0 +1,48 @@ +#![feature(iter_array_chunks)] + +use criterion::Criterion; + +#[cfg(target_arch = "x86_64")] +pub fn cpu_bit_rev(c: &mut criterion::Criterion) { + use prover_research::core::fields::m31::BaseField; + + const SIZE: usize = 1 << 28; + let mut data: Vec<_> = (0..SIZE as u32) + .map(BaseField::from_u32_unchecked) + .collect(); + + c.bench_function("cpu bit_rev", |b| { + b.iter(|| { + data = prover_research::core::utils::bit_reverse(std::mem::take(&mut data)); + }) + }); +} + +#[cfg(target_arch = "x86_64")] +pub fn avx512_bit_rev(c: &mut criterion::Criterion) { + use prover_research::core::backend::avx512::bit_reverse::bit_reverse_m31; + use prover_research::core::fields::m31::BaseField; + use prover_research::platform; + if !platform::avx512_detected() { + return; + } + + const SIZE: usize = 1 << 28; + let data: Vec<_> = (0..SIZE as u32) + .map(BaseField::from_u32_unchecked) + .collect(); + let mut data: Vec<_> = data.into_iter().array_chunks::<16>().collect(); + + c.bench_function("avx bit_rev", |b| { + b.iter(|| { + bit_reverse_m31(&mut data); + }) + }); +} + +#[cfg(target_arch = "x86_64")] +criterion::criterion_group!( + name=avx_bit_rev; + config = Criterion::default().sample_size(10); + targets=avx512_bit_rev, cpu_bit_rev); +criterion::criterion_main!(avx_bit_rev); diff --git a/benches/fri.rs b/benches/fri.rs index eeaab6274..e47839e75 100644 --- a/benches/fri.rs +++ b/benches/fri.rs @@ -1,20 +1,21 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use prover_research::core::backend::CPUBackend; use prover_research::core::fields::m31::BaseField; -use prover_research::core::fri::fold_line; +use prover_research::core::fri::FriOps; use prover_research::core::poly::circle::CanonicCoset; use prover_research::core::poly::line::{LineDomain, LineEvaluation}; fn folding_benchmark(c: &mut Criterion) { const LOG_SIZE: u32 = 12; - let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE).coset()); + let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset()); let evals = LineEvaluation::new( domain, - vec![BaseField::from_u32_unchecked(712837213); 1 << LOG_SIZE], + vec![BaseField::from_u32_unchecked(712837213).into(); 1 << LOG_SIZE], ); - let alpha = BaseField::from_u32_unchecked(12389); + let alpha = BaseField::from_u32_unchecked(12389).into(); c.bench_function("fold_line", |b| { b.iter(|| { - black_box(fold_line(black_box(&evals), black_box(alpha))); + black_box(CPUBackend::fold_line(black_box(&evals), black_box(alpha))); }) }); } diff --git a/src/core/backend/avx512/bit_reverse.rs b/src/core/backend/avx512/bit_reverse.rs new file mode 100644 index 000000000..56b8e411b --- /dev/null +++ b/src/core/backend/avx512/bit_reverse.rs @@ -0,0 +1,130 @@ +use std::arch::x86_64::{__m512i, _mm512_permutex2var_epi32}; + +use crate::core::fields::m31::BaseField; + +const VEC_BITS: u32 = 4; +const W_BITS: u32 = 3; +const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS; + +pub fn bit_reverse_m31(data: &mut [[BaseField; 16]]) { + assert!(data.len().is_power_of_two()); + assert!(data.len().ilog2() >= MIN_LOG_SIZE); + + // V W1 A W0 [V] + + let data_bits = data.len().ilog2(); + let a_bits = data_bits - 2 * W_BITS - VEC_BITS; + // TODO: if threading, over a. + // TODO: Go over a in an L2/L3 cache friendly way. + + // Total needed cache size: 2*2^(W_BITS+VEC_BITS) = 2^15 B = 32KB. + for a in 0u32..(1 << a_bits) { + for w0 in 0u32..(1 << W_BITS) { + for w1 in 0u32..(1 << W_BITS) { + let idx = (((w1 << a_bits) | a) << W_BITS) | w0; + let idxr = idx.reverse_bits() >> (32 - (data_bits - VEC_BITS)); + if idx > idxr { + continue; + } + + let values0 = std::array::from_fn(|i| { + data[(idx + ((i as u32) << (2 * W_BITS + a_bits))) as usize] + }); + let values0 = bit_reverse16(values0); + + if idx == idxr { + // Palindrome. + for i in 0..16 { + data[(idx + ((i as u32) << (2 * W_BITS + a_bits))) as usize] = + values0[i as usize]; + } + continue; + } + let values1 = std::array::from_fn(|i| { + data[(idxr + ((i as u32) << (2 * W_BITS + a_bits))) as usize] + }); + let values1 = bit_reverse16(values1); + + for i in 0..16 { + data[(idx + ((i as u32) << (2 * W_BITS + a_bits))) as usize] = + values1[i as usize]; + data[(idxr + ((i as u32) << (2 * W_BITS + a_bits))) as usize] = + values0[i as usize]; + } + } + } + } +} + +#[allow(dead_code)] +fn bit_reverse16(data: [[BaseField; 16]; 16]) -> [[BaseField; 16]; 16] { + let mut data: [__m512i; 16] = unsafe { std::mem::transmute(data) }; + // abcd0123 => 0abc123d + const L: __m512i = unsafe { + core::mem::transmute([ + 0b00000, 0b10000, 0b00001, 0b10001, 0b00010, 0b10010, 0b00011, 0b10011, 0b00100, + 0b10100, 0b00101, 0b10101, 0b00110, 0b10110, 0b00111, 0b10111, + ]) + }; + const H: __m512i = unsafe { + core::mem::transmute([ + 0b01000, 0b11000, 0b01001, 0b11001, 0b01010, 0b11010, 0b01011, 0b11011, 0b01100, + 0b11100, 0b01101, 0b11101, 0b01110, 0b11110, 0b01111, 0b11111, + ]) + }; + for _ in 0..4 { + unsafe { + data = [ + _mm512_permutex2var_epi32(data[0], L, data[1]), + _mm512_permutex2var_epi32(data[2], L, data[3]), + _mm512_permutex2var_epi32(data[4], L, data[5]), + _mm512_permutex2var_epi32(data[6], L, data[7]), + _mm512_permutex2var_epi32(data[8], L, data[9]), + _mm512_permutex2var_epi32(data[10], L, data[11]), + _mm512_permutex2var_epi32(data[12], L, data[13]), + _mm512_permutex2var_epi32(data[14], L, data[15]), + _mm512_permutex2var_epi32(data[0], H, data[1]), + _mm512_permutex2var_epi32(data[2], H, data[3]), + _mm512_permutex2var_epi32(data[4], H, data[5]), + _mm512_permutex2var_epi32(data[6], H, data[7]), + _mm512_permutex2var_epi32(data[8], H, data[9]), + _mm512_permutex2var_epi32(data[10], H, data[11]), + _mm512_permutex2var_epi32(data[12], H, data[13]), + _mm512_permutex2var_epi32(data[14], H, data[15]), + ]; + } + } + unsafe { std::mem::transmute(data) } +} + +#[cfg(test)] +mod tests { + use super::bit_reverse16; + use crate::core::backend::avx512::bit_reverse::bit_reverse_m31; + use crate::core::fields::m31::BaseField; + use crate::core::utils::bit_reverse; + + #[test] + fn test_bit_reverse16() { + let data: [u32; 256] = std::array::from_fn(|i| i as u32); + let expected: [u32; 256] = std::array::from_fn(|i| (i as u32).reverse_bits() >> 24); + unsafe { + let data = bit_reverse16(std::mem::transmute(data)); + assert_eq!(std::mem::transmute::<_, [u32; 256]>(data), expected); + } + } + + #[test] + fn test_bit_reverse() { + const SIZE: usize = 1 << 15; + let data: Vec<_> = (0..SIZE as u32) + .map(BaseField::from_u32_unchecked) + .collect(); + let expected = bit_reverse(data.clone()); + let mut data: Vec<_> = data.into_iter().array_chunks::<16>().collect(); + let expected: Vec<_> = expected.into_iter().array_chunks::<16>().collect(); + + bit_reverse_m31(&mut data); + assert_eq!(data, expected); + } +} diff --git a/src/core/backend/avx512/mod.rs b/src/core/backend/avx512/mod.rs new file mode 100644 index 000000000..0e2440cee --- /dev/null +++ b/src/core/backend/avx512/mod.rs @@ -0,0 +1 @@ +pub mod bit_reverse; diff --git a/src/core/backend/mod.rs b/src/core/backend/mod.rs index e55029a95..ff28ebf42 100644 --- a/src/core/backend/mod.rs +++ b/src/core/backend/mod.rs @@ -8,6 +8,7 @@ use super::fields::qm31::SecureField; use super::fields::Field; use super::poly::circle::PolyOps; +pub mod avx512; pub mod cpu; pub trait Backend: diff --git a/src/core/utils.rs b/src/core/utils.rs index 369a018d6..0888557c5 100644 --- a/src/core/utils.rs +++ b/src/core/utils.rs @@ -9,7 +9,7 @@ pub(crate) fn bit_reverse_index(i: usize, log_size: u32) -> usize { /// Panics if the length of the slice is not a power of two. // TODO(AlonH): Consider benchmarking this function. // TODO: Implement cache friendly implementation. -pub(crate) fn bit_reverse>(mut v: U) -> U { +pub fn bit_reverse>(mut v: U) -> U { let n = v.as_mut().len(); assert!(n.is_power_of_two()); let log_n = n.ilog2(); diff --git a/src/lib.rs b/src/lib.rs index 8929b091c..1c860b95b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ #![feature( array_chunks, + iter_array_chunks, exact_size_is_empty, is_sorted, new_uninit,