diff --git a/benches/merkle.rs b/benches/merkle.rs
index c465f4c93..c2a82e619 100644
--- a/benches/merkle.rs
+++ b/benches/merkle.rs
@@ -7,11 +7,13 @@ pub fn cpu_merkle(c: &mut criterion::Criterion) {
use itertools::Itertools;
use num_traits::Zero;
use stwo::commitment_scheme::ops::MerkleOps;
- use stwo::core::backend::CPUBackend;
+ use stwo::core::backend::avx512::AVX512Backend;
+ use stwo::core::backend::{CPUBackend, Col};
use stwo::core::fields::m31::BaseField;
+ use stwo::platform;
const N_COLS: usize = 1 << 8;
- const LOG_SIZE: u32 = 20;
+ const LOG_SIZE: u32 = 16;
let cols = (0..N_COLS)
.map(|_| {
(0..(1 << LOG_SIZE))
@@ -30,6 +32,23 @@ pub fn cpu_merkle(c: &mut criterion::Criterion) {
CPUBackend::commit_on_layer(LOG_SIZE, None, &cols.iter().collect_vec());
})
});
+
+ if !platform::avx512_detected() {
+ return;
+ }
+    let cols = (0..N_COLS)
+        .map(|_| {
+            (0..(1 << LOG_SIZE))
+                .map(|_| BaseField::zero())
+                .collect::<Col<AVX512Backend, BaseField>>()
+        })
+        .collect::<Vec<_>>();
+
+ group.bench_function("avx merkle", |b| {
+ b.iter(|| {
+ AVX512Backend::commit_on_layer(LOG_SIZE, None, &cols.iter().collect_vec());
+ })
+ });
}
#[cfg(target_arch = "x86_64")]
diff --git a/src/commitment_scheme/blake2_merkle.rs b/src/commitment_scheme/blake2_merkle.rs
index 4185bf08e..e4b1e48df 100644
--- a/src/commitment_scheme/blake2_merkle.rs
+++ b/src/commitment_scheme/blake2_merkle.rs
@@ -1,10 +1,8 @@
-use itertools::Itertools;
use num_traits::Zero;
use super::blake2_hash::Blake2sHash;
use super::blake2s_ref::compress;
-use super::ops::{MerkleHasher, MerkleOps};
-use crate::core::backend::CPUBackend;
+use super::ops::MerkleHasher;
use crate::core::fields::m31::BaseField;
#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)]
@@ -39,23 +37,6 @@ impl MerkleHasher for Blake2sMerkleHasher {
}
}
-impl MerkleOps for CPUBackend {
- fn commit_on_layer(
- log_size: u32,
- prev_layer: Option<&Vec>,
- columns: &[&Vec],
- ) -> Vec {
- (0..(1 << log_size))
- .map(|i| {
- Blake2sMerkleHasher::hash_node(
- prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
- &columns.iter().map(|column| column[i]).collect_vec(),
- )
- })
- .collect()
- }
-}
-
#[cfg(test)]
mod tests {
use std::collections::BTreeMap;
diff --git a/src/core/backend/avx512/blake2s.rs b/src/core/backend/avx512/blake2s.rs
new file mode 100644
index 000000000..3a8a4a8b9
--- /dev/null
+++ b/src/core/backend/avx512/blake2s.rs
@@ -0,0 +1,123 @@
+use std::arch::x86_64::__m512i;
+
+use itertools::Itertools;
+
+use super::blake2s_avx::{compress16, set1, transpose_msgs, untranspose_states};
+use super::{AVX512Backend, VECS_LOG_SIZE};
+use crate::commitment_scheme::blake2_hash::Blake2sHash;
+use crate::commitment_scheme::blake2_merkle::Blake2sMerkleHasher;
+use crate::commitment_scheme::ops::MerkleOps;
+use crate::core::backend::{Col, ColumnOps};
+use crate::core::fields::m31::BaseField;
+
+impl ColumnOps<Blake2sHash> for AVX512Backend {
+    type Column = Vec<Blake2sHash>;
+
+    /// Not implemented for hash columns; calling this panics.
+    fn bit_reverse_column(_column: &mut Self::Column) {
+        unimplemented!()
+    }
+}
+
+impl MerkleOps<Blake2sMerkleHasher> for AVX512Backend {
+    /// Hashes one Merkle layer of `1 << log_size` nodes, 16 nodes per `compress16` call.
+    ///
+    /// Every node absorbs its slice of `prev_layer` (when present) and then its own entry of
+    /// every column, 16 columns per compression; a final partial chunk is zero-padded.
+    ///
+    /// A layer smaller than one vector (`log_size < VECS_LOG_SIZE`) is hashed on zero-padded
+    /// input and the surplus lanes are dropped, so the result always has `1 << log_size` hashes.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `prev_layer` or a column is too short for the requested layer.
+    fn commit_on_layer(
+        log_size: u32,
+        prev_layer: Option<&Vec<Blake2sHash>>,
+        columns: &[&Col<Self, BaseField>],
+    ) -> Vec<Blake2sHash> {
+        // Previous-layer hashes read per iteration: 16 messages of two 32-byte hashes each.
+        const PREV_HASHES_PER_ITER: usize = 1 << 5;
+
+        // Pad prev_layer if too small. The padding is bounded with `take`: collecting an
+        // unbounded `repeat` would never terminate.
+        let mut padded_buffer = vec![];
+        let prev_layer = if log_size < VECS_LOG_SIZE as u32 {
+            prev_layer.map(|prev_layer| {
+                padded_buffer = prev_layer
+                    .iter()
+                    .copied()
+                    .chain(std::iter::repeat(Blake2sHash::default()))
+                    .take(PREV_HASHES_PER_ITER)
+                    .collect_vec();
+                &padded_buffer
+            })
+        } else {
+            prev_layer
+        };
+
+        // A layer smaller than one vector still takes one iteration, so the subtraction
+        // saturates instead of underflowing.
+        // NOTE(review): assumes such short columns still hold one full packed vector in
+        // `data` — confirm against the column type.
+        let n_iters: usize = 1 << log_size.saturating_sub(VECS_LOG_SIZE as u32);
+
+        // Commit to columns.
+        let mut res = Vec::with_capacity(n_iters * 16);
+        for i in 0..n_iters {
+            // SAFETY: the all-zero bit pattern is a valid `__m512i`.
+            let mut state: [__m512i; 8] = unsafe { std::mem::zeroed() };
+            // Hash prev_layer, if exists.
+            if let Some(prev_layer) = prev_layer {
+                let children =
+                    &prev_layer[i * PREV_HASHES_PER_ITER..(i + 1) * PREV_HASHES_PER_ITER];
+                let ptr = children.as_ptr() as *const __m512i;
+                // SAFETY: `children` spans 32 hashes of 32 bytes, i.e. 16 `__m512i` values. A
+                // hash buffer is not 64-byte aligned in general, so the load must be unaligned.
+                let msgs: [__m512i; 16] =
+                    std::array::from_fn(|j| unsafe { std::ptr::read_unaligned(ptr.add(j)) });
+                // NOTE(review): the AVX-512 helpers are presumably sound only on CPUs that
+                // support AVX-512 — callers must check before using this backend.
+                state = unsafe {
+                    compress16(
+                        state,
+                        transpose_msgs(msgs),
+                        set1(0),
+                        set1(0),
+                        set1(0),
+                        set1(0),
+                    )
+                };
+            }
+
+            // Hash columns in chunks of 16.
+            let mut col_chunk_iter = columns.array_chunks();
+            for col_chunk in &mut col_chunk_iter {
+                let msgs = col_chunk.map(|column| column.data[i].0);
+                state = unsafe { compress16(state, msgs, set1(0), set1(0), set1(0), set1(0)) };
+            }
+
+            // Hash remaining columns.
+            let remainder = col_chunk_iter.remainder();
+            if !remainder.is_empty() {
+                let msgs = remainder
+                    .iter()
+                    .map(|column| column.data[i].0)
+                    .chain(std::iter::repeat(unsafe { set1(0) }))
+                    .take(16)
+                    .collect_vec()
+                    .try_into()
+                    .unwrap();
+                state = unsafe { compress16(state, msgs, set1(0), set1(0), set1(0), set1(0)) };
+            }
+            // SAFETY: eight `__m512i` and sixteen `Blake2sHash` both occupy 512 bytes.
+            let state: [Blake2sHash; 16] =
+                unsafe { std::mem::transmute(untranspose_states(state)) };
+            res.extend_from_slice(&state);
+        }
+
+        // Drop the lanes that exist only because of padding.
+        res.truncate(1 << log_size);
+        res
+    }
+}
diff --git a/src/core/backend/avx512/mod.rs b/src/core/backend/avx512/mod.rs
index 62fefbbc5..1b62962b7 100644
--- a/src/core/backend/avx512/mod.rs
+++ b/src/core/backend/avx512/mod.rs
@@ -1,4 +1,5 @@
pub mod bit_reverse;
+mod blake2s;
pub mod blake2s_avx;
pub mod circle;
pub mod cm31;
diff --git a/src/core/backend/cpu/blake2s.rs b/src/core/backend/cpu/blake2s.rs
new file mode 100644
index 000000000..b1fb44c1c
--- /dev/null
+++ b/src/core/backend/cpu/blake2s.rs
@@ -0,0 +1,34 @@
+use itertools::Itertools;
+
+use crate::commitment_scheme::blake2_hash::Blake2sHash;
+use crate::commitment_scheme::blake2_merkle::Blake2sMerkleHasher;
+use crate::commitment_scheme::ops::{MerkleHasher, MerkleOps};
+use crate::core::backend::CPUBackend;
+use crate::core::fields::m31::BaseField;
+
+impl MerkleOps<Blake2sMerkleHasher> for CPUBackend {
+    /// Hashes one Merkle layer of `1 << log_size` nodes, one node at a time.
+    ///
+    /// Node `i` is `Blake2sMerkleHasher::hash_node` applied to its children
+    /// `prev_layer[2 * i]` and `prev_layer[2 * i + 1]` (when `prev_layer` is given) and to
+    /// element `i` of every column.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `prev_layer` has fewer than `2 << log_size` hashes or any column has fewer
+    /// than `1 << log_size` elements.
+    fn commit_on_layer(
+        log_size: u32,
+        prev_layer: Option<&Vec<Blake2sHash>>,
+        columns: &[&Vec<BaseField>],
+    ) -> Vec<Blake2sHash> {
+        (0..(1 << log_size))
+            .map(|i| {
+                Blake2sMerkleHasher::hash_node(
+                    prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
+                    &columns.iter().map(|column| column[i]).collect_vec(),
+                )
+            })
+            .collect()
+    }
+}
diff --git a/src/core/backend/cpu/mod.rs b/src/core/backend/cpu/mod.rs
index 5927165d0..094e4bb54 100644
--- a/src/core/backend/cpu/mod.rs
+++ b/src/core/backend/cpu/mod.rs
@@ -1,3 +1,4 @@
+mod blake2s;
mod circle;
mod fri;
pub mod quotients;