Skip to content

Commit

Permalink
avx merkle
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Apr 2, 2024
1 parent 5f2a9e6 commit e335956
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 24 deletions.
23 changes: 21 additions & 2 deletions benches/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ pub fn cpu_merkle(c: &mut criterion::Criterion) {
use itertools::Itertools;
use num_traits::Zero;
use stwo::commitment_scheme::ops::MerkleOps;
use stwo::core::backend::CPUBackend;
use stwo::core::backend::avx512::AVX512Backend;
use stwo::core::backend::{CPUBackend, Col};
use stwo::core::fields::m31::BaseField;
use stwo::platform;

const N_COLS: usize = 1 << 8;
const LOG_SIZE: u32 = 20;
const LOG_SIZE: u32 = 16;
let cols = (0..N_COLS)
.map(|_| {
(0..(1 << LOG_SIZE))
Expand All @@ -30,6 +32,23 @@ pub fn cpu_merkle(c: &mut criterion::Criterion) {
CPUBackend::commit_on_layer(LOG_SIZE, None, &cols.iter().collect_vec());
})
});

if !platform::avx512_detected() {
return;
}
let cols = (0..N_COLS)
.map(|_| {
(0..(1 << LOG_SIZE))
.map(|_| BaseField::zero())
.collect::<Col<AVX512Backend, BaseField>>()
})
.collect::<Vec<_>>();

group.bench_function("avx merkle", |b| {
b.iter(|| {
AVX512Backend::commit_on_layer(LOG_SIZE, None, &cols.iter().collect_vec());
})
});
}

#[cfg(target_arch = "x86_64")]
Expand Down
21 changes: 1 addition & 20 deletions src/commitment_scheme/blake2_merkle.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use itertools::Itertools;
use num_traits::Zero;

use super::blake2_hash::Blake2sHash;
use super::blake2s_ref::compress;
use super::ops::{MerkleHasher, MerkleOps};
use crate::core::backend::CPUBackend;
use super::ops::MerkleHasher;
use crate::core::fields::m31::BaseField;

#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)]
Expand Down Expand Up @@ -39,23 +37,6 @@ impl MerkleHasher for Blake2sMerkleHasher {
}
}

impl MerkleOps<Blake2sMerkleHasher> for CPUBackend {
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<Blake2sHash>>,
columns: &[&Vec<BaseField>],
) -> Vec<Blake2sHash> {
(0..(1 << log_size))
.map(|i| {
Blake2sMerkleHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column[i]).collect_vec(),
)
})
.collect()
}
}

#[cfg(test)]
mod tests {
use std::collections::BTreeMap;
Expand Down
88 changes: 88 additions & 0 deletions src/core/backend/avx512/blake2s.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
use std::arch::x86_64::__m512i;

use itertools::Itertools;

use super::blake2s_avx::{compress16, set1, transpose_msgs, untranspose_states};
use super::{AVX512Backend, VECS_LOG_SIZE};
use crate::commitment_scheme::blake2_hash::Blake2sHash;
use crate::commitment_scheme::blake2_merkle::Blake2sMerkleHasher;
use crate::commitment_scheme::ops::MerkleOps;
use crate::core::backend::{Col, ColumnOps};
use crate::core::fields::m31::BaseField;

impl ColumnOps<Blake2sHash> for AVX512Backend {
type Column = Vec<Blake2sHash>;

fn bit_reverse_column(_column: &mut Self::Column) {
unimplemented!()
}
}

impl MerkleOps<Blake2sMerkleHasher> for AVX512Backend {
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<Blake2sHash>>,
columns: &[&Col<AVX512Backend, BaseField>],
) -> Vec<Blake2sHash> {
// Pad prev_layer if too small.
let mut padded_buffer = vec![];
let prev_layer = if log_size < 4 {
prev_layer.map(|prev_layer| {
padded_buffer = prev_layer
.iter()
.copied()
.chain(std::iter::repeat(Blake2sHash::default()))
.collect_vec();
&padded_buffer
})
} else {
prev_layer
};

// Commit to columns.
let mut res = Vec::with_capacity(1 << log_size);
for i in 0..(1 << (log_size - VECS_LOG_SIZE as u32)) {
let mut state: [__m512i; 8] = unsafe { std::mem::zeroed() };
// Hash prev_layer.
if let Some(prev_layer) = prev_layer {
let ptr = prev_layer[(i << 5)..(i << 5) + 32].as_ptr() as *const __m512i;
let msgs: [__m512i; 16] = std::array::from_fn(|j| unsafe { *ptr.add(j) });
state = unsafe {
compress16(
state,
transpose_msgs(msgs),
set1(0),
set1(0),
set1(0),
set1(0),
)
};
}

// Hash columns in chunks of 16.
let mut col_chunk_iter = columns.array_chunks();
for col_chunk in &mut col_chunk_iter {
let msgs = col_chunk.map(|column| column.data[i].0);
state = unsafe { compress16(state, msgs, set1(0), set1(0), set1(0), set1(0)) };
}

// Hash remaining columns.
let remainder = col_chunk_iter.remainder();
if !remainder.is_empty() {
let msgs = remainder
.iter()
.map(|column| column.data[i].0)
.chain(std::iter::repeat(unsafe { set1(0) }))
.take(16)
.collect_vec()
.try_into()
.unwrap();
state = unsafe { compress16(state, msgs, set1(0), set1(0), set1(0), set1(0)) };
}
let state: [Blake2sHash; 16] =
unsafe { std::mem::transmute(untranspose_states(state)) };
res.extend_from_slice(&state);
}
res
}
}
5 changes: 4 additions & 1 deletion src/core/backend/avx512/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,10 @@ impl PolyOps for AVX512Backend {
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
// TODO(spapini): Optimize.
let eval = CPUBackend::new_canonical_ordered(coset, as_cpu_vec(values));
CircleEvaluation::new(eval.domain, Col::<AVX512Backend, _>::from_iter(eval.values))
CircleEvaluation::new(
eval.domain,
Col::<AVX512Backend, BaseField>::from_iter(eval.values),
)
}

fn interpolate(
Expand Down
3 changes: 2 additions & 1 deletion src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod bit_reverse;
mod blake2s;
pub mod blake2s_avx;
pub mod circle;
pub mod cm31;
Expand Down Expand Up @@ -202,7 +203,7 @@ mod tests {
for i in 1..16 {
let len = 1 << i;
let mut col = Col::<B, BaseField>::from_iter((0..len).map(BaseField::from));
B::bit_reverse_column(&mut col);
<B as ColumnOps<BaseField>>::bit_reverse_column(&mut col);
assert_eq!(
col.to_vec(),
(0..len)
Expand Down
24 changes: 24 additions & 0 deletions src/core/backend/cpu/blake2s.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use itertools::Itertools;

use crate::commitment_scheme::blake2_hash::Blake2sHash;
use crate::commitment_scheme::blake2_merkle::Blake2sMerkleHasher;
use crate::commitment_scheme::ops::{MerkleHasher, MerkleOps};
use crate::core::backend::CPUBackend;
use crate::core::fields::m31::BaseField;

impl MerkleOps<Blake2sMerkleHasher> for CPUBackend {
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<Blake2sHash>>,
columns: &[&Vec<BaseField>],
) -> Vec<Blake2sHash> {
(0..(1 << log_size))
.map(|i| {
Blake2sMerkleHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column[i]).collect_vec(),
)
})
.collect()
}
}
1 change: 1 addition & 0 deletions src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod blake2s;
mod circle;
mod fri;
pub mod quotients;
Expand Down

0 comments on commit e335956

Please sign in to comment.