Skip to content

Commit

Permalink
AVX bit reverse
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Feb 16, 2024
1 parent aa3fa2f commit 16a6c91
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 6 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ avx512 = []
codegen-units = 1
lto = true

[[bench]]
name = "bit_rev"
harness = false

[[bench]]
harness = false
name = "field"
Expand Down
48 changes: 48 additions & 0 deletions benches/bit_rev.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#![feature(iter_array_chunks)]

use criterion::Criterion;

#[cfg(target_arch = "x86_64")]
pub fn cpu_bit_rev(c: &mut criterion::Criterion) {
use prover_research::core::fields::m31::BaseField;

const SIZE: usize = 1 << 28;
let mut data: Vec<_> = (0..SIZE as u32)
.map(BaseField::from_u32_unchecked)
.collect();

c.bench_function("cpu bit_rev", |b| {
b.iter(|| {
data = prover_research::core::utils::bit_reverse(std::mem::take(&mut data));
})
});
}

#[cfg(target_arch = "x86_64")]
pub fn avx512_bit_rev(c: &mut criterion::Criterion) {
use prover_research::core::backend::avx512::bit_reverse::bit_reverse_m31;
use prover_research::core::fields::m31::BaseField;
use prover_research::platform;
if !platform::avx512_detected() {
return;
}

const SIZE: usize = 1 << 28;
let data: Vec<_> = (0..SIZE as u32)
.map(BaseField::from_u32_unchecked)
.collect();
let mut data: Vec<_> = data.into_iter().array_chunks::<16>().collect();

c.bench_function("avx bit_rev", |b| {
b.iter(|| {
bit_reverse_m31(&mut data);
})
});
}

#[cfg(target_arch = "x86_64")]
criterion::criterion_group!(
name=avx_bit_rev;
config = Criterion::default().sample_size(10);
targets=avx512_bit_rev, cpu_bit_rev);
criterion::criterion_main!(avx_bit_rev);
11 changes: 6 additions & 5 deletions benches/fri.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use prover_research::core::backend::CPUBackend;
use prover_research::core::fields::m31::BaseField;
use prover_research::core::fri::fold_line;
use prover_research::core::fri::FriOps;
use prover_research::core::poly::circle::CanonicCoset;
use prover_research::core::poly::line::{LineDomain, LineEvaluation};

fn folding_benchmark(c: &mut Criterion) {
const LOG_SIZE: u32 = 12;
let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE).coset());
let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset());
let evals = LineEvaluation::new(
domain,
vec![BaseField::from_u32_unchecked(712837213); 1 << LOG_SIZE],
vec![BaseField::from_u32_unchecked(712837213).into(); 1 << LOG_SIZE],
);
let alpha = BaseField::from_u32_unchecked(12389);
let alpha = BaseField::from_u32_unchecked(12389).into();
c.bench_function("fold_line", |b| {
b.iter(|| {
black_box(fold_line(black_box(&evals), black_box(alpha)));
black_box(CPUBackend::fold_line(black_box(&evals), black_box(alpha)));
})
});
}
Expand Down
130 changes: 130 additions & 0 deletions src/core/backend/avx512/bit_reverse.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
use std::arch::x86_64::{__m512i, _mm512_permutex2var_epi32};

use crate::core::fields::m31::BaseField;

const VEC_BITS: u32 = 4;
const W_BITS: u32 = 3;
const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS;

pub fn bit_reverse_m31(data: &mut [[BaseField; 16]]) {
assert!(data.len().is_power_of_two());
assert!(data.len().ilog2() >= MIN_LOG_SIZE);

// V W1 A W0 [V]

let data_bits = data.len().ilog2();
let a_bits = data_bits - 2 * W_BITS - VEC_BITS;
// TODO: if threading, over a.
// TODO: Go over a in an L2/L3 cache friendly way.

// Total needed cache size: 2*2^(W_BITS+VEC_BITS) = 2^15 B = 32KB.
for a in 0u32..(1 << a_bits) {
for w0 in 0u32..(1 << W_BITS) {
for w1 in 0u32..(1 << W_BITS) {
let idx = (((w1 << a_bits) | a) << W_BITS) | w0;
let idxr = idx.reverse_bits() >> (32 - (data_bits - VEC_BITS));
if idx > idxr {
continue;
}

let values0 = std::array::from_fn(|i| {
data[(idx + ((i as u32) << (2 * W_BITS + a_bits))) as usize]
});
let values0 = bit_reverse16(values0);

if idx == idxr {
// Palindrome.
for i in 0..16 {
data[(idx + ((i as u32) << (2 * W_BITS + a_bits))) as usize] =
values0[i as usize];
}
continue;
}
let values1 = std::array::from_fn(|i| {
data[(idxr + ((i as u32) << (2 * W_BITS + a_bits))) as usize]
});
let values1 = bit_reverse16(values1);

for i in 0..16 {
data[(idx + ((i as u32) << (2 * W_BITS + a_bits))) as usize] =
values1[i as usize];
data[(idxr + ((i as u32) << (2 * W_BITS + a_bits))) as usize] =
values0[i as usize];
}
}
}
}
}

#[allow(dead_code)]
fn bit_reverse16(data: [[BaseField; 16]; 16]) -> [[BaseField; 16]; 16] {
let mut data: [__m512i; 16] = unsafe { std::mem::transmute(data) };
// abcd0123 => 0abc123d
const L: __m512i = unsafe {
core::mem::transmute([
0b00000, 0b10000, 0b00001, 0b10001, 0b00010, 0b10010, 0b00011, 0b10011, 0b00100,
0b10100, 0b00101, 0b10101, 0b00110, 0b10110, 0b00111, 0b10111,
])
};
const H: __m512i = unsafe {
core::mem::transmute([
0b01000, 0b11000, 0b01001, 0b11001, 0b01010, 0b11010, 0b01011, 0b11011, 0b01100,
0b11100, 0b01101, 0b11101, 0b01110, 0b11110, 0b01111, 0b11111,
])
};
for _ in 0..4 {
unsafe {
data = [
_mm512_permutex2var_epi32(data[0], L, data[1]),
_mm512_permutex2var_epi32(data[2], L, data[3]),
_mm512_permutex2var_epi32(data[4], L, data[5]),
_mm512_permutex2var_epi32(data[6], L, data[7]),
_mm512_permutex2var_epi32(data[8], L, data[9]),
_mm512_permutex2var_epi32(data[10], L, data[11]),
_mm512_permutex2var_epi32(data[12], L, data[13]),
_mm512_permutex2var_epi32(data[14], L, data[15]),
_mm512_permutex2var_epi32(data[0], H, data[1]),
_mm512_permutex2var_epi32(data[2], H, data[3]),
_mm512_permutex2var_epi32(data[4], H, data[5]),
_mm512_permutex2var_epi32(data[6], H, data[7]),
_mm512_permutex2var_epi32(data[8], H, data[9]),
_mm512_permutex2var_epi32(data[10], H, data[11]),
_mm512_permutex2var_epi32(data[12], H, data[13]),
_mm512_permutex2var_epi32(data[14], H, data[15]),
];
}
}
unsafe { std::mem::transmute(data) }
}

#[cfg(test)]
mod tests {
use super::bit_reverse16;
use crate::core::backend::avx512::bit_reverse::bit_reverse_m31;
use crate::core::fields::m31::BaseField;
use crate::core::utils::bit_reverse;

#[test]
fn test_bit_reverse16() {
let data: [u32; 256] = std::array::from_fn(|i| i as u32);
let expected: [u32; 256] = std::array::from_fn(|i| (i as u32).reverse_bits() >> 24);
unsafe {
let data = bit_reverse16(std::mem::transmute(data));
assert_eq!(std::mem::transmute::<_, [u32; 256]>(data), expected);
}
}

#[test]
fn test_bit_reverse() {
const SIZE: usize = 1 << 15;
let data: Vec<_> = (0..SIZE as u32)
.map(BaseField::from_u32_unchecked)
.collect();
let expected = bit_reverse(data.clone());
let mut data: Vec<_> = data.into_iter().array_chunks::<16>().collect();
let expected: Vec<_> = expected.into_iter().array_chunks::<16>().collect();

bit_reverse_m31(&mut data);
assert_eq!(data, expected);
}
}
1 change: 1 addition & 0 deletions src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod bit_reverse;
2 changes: 2 additions & 0 deletions src/core/backend/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
pub mod avx512;
pub mod cpu;

use std::ops::Index;

pub use cpu::CPUBackend;
Expand Down
2 changes: 1 addition & 1 deletion src/core/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub(crate) fn bit_reverse_index(i: usize, log_size: u32) -> usize {
/// Panics if the length of the slice is not a power of two.
// TODO(AlonH): Consider benchmarking this function.
// TODO: Implement cache friendly implementation.
pub(crate) fn bit_reverse<T, U: AsMut<[T]>>(mut v: U) -> U {
pub fn bit_reverse<T, U: AsMut<[T]>>(mut v: U) -> U {
let n = v.as_mut().len();
assert!(n.is_power_of_two());
let log_n = n.ilog2();
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![feature(
array_chunks,
iter_array_chunks,
exact_size_is_empty,
is_sorted,
new_uninit,
Expand Down

0 comments on commit 16a6c91

Please sign in to comment.