diff --git a/benches/merkle.rs b/benches/merkle.rs
index c465f4c93..c2a82e619 100644
--- a/benches/merkle.rs
+++ b/benches/merkle.rs
@@ -7,11 +7,13 @@ pub fn cpu_merkle(c: &mut criterion::Criterion) {
use itertools::Itertools;
use num_traits::Zero;
use stwo::commitment_scheme::ops::MerkleOps;
- use stwo::core::backend::CPUBackend;
+ use stwo::core::backend::avx512::AVX512Backend;
+ use stwo::core::backend::{CPUBackend, Col};
use stwo::core::fields::m31::BaseField;
+ use stwo::platform;
const N_COLS: usize = 1 << 8;
- const LOG_SIZE: u32 = 20;
+ const LOG_SIZE: u32 = 16;
let cols = (0..N_COLS)
.map(|_| {
(0..(1 << LOG_SIZE))
@@ -30,6 +32,23 @@ pub fn cpu_merkle(c: &mut criterion::Criterion) {
CPUBackend::commit_on_layer(LOG_SIZE, None, &cols.iter().collect_vec());
})
});
+
+ if !platform::avx512_detected() {
+ return;
+ }
+ let cols = (0..N_COLS)
+ .map(|_| {
+ (0..(1 << LOG_SIZE))
+ .map(|_| BaseField::zero())
+ .collect::
>()
+ })
+ .collect::>();
+
+ group.bench_function("avx merkle", |b| {
+ b.iter(|| {
+ AVX512Backend::commit_on_layer(LOG_SIZE, None, &cols.iter().collect_vec());
+ })
+ });
}
#[cfg(target_arch = "x86_64")]
diff --git a/src/commitment_scheme/blake2_merkle.rs b/src/commitment_scheme/blake2_merkle.rs
index f12ba4479..693c762df 100644
--- a/src/commitment_scheme/blake2_merkle.rs
+++ b/src/commitment_scheme/blake2_merkle.rs
@@ -1,10 +1,8 @@
-use itertools::Itertools;
use num_traits::Zero;
-use super::blake2_hash::{Blake2sHash, Blake2sHasher};
+use super::blake2_hash::Blake2sHasher;
use super::blake2s_ref::compress;
-use super::ops::{MerkleHasher, MerkleOps};
-use crate::core::backend::CPUBackend;
+use super::ops::MerkleHasher;
use crate::core::fields::m31::BaseField;
impl MerkleHasher for Blake2sHasher {
@@ -35,23 +33,6 @@ impl MerkleHasher for Blake2sHasher {
}
}
-impl MerkleOps for CPUBackend {
- fn commit_on_layer(
- log_size: u32,
- prev_layer: Option<&Vec>,
- columns: &[&Vec],
- ) -> Vec {
- (0..(1 << log_size))
- .map(|i| {
- Blake2sHasher::hash_node(
- prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
- &columns.iter().map(|column| column[i]).collect_vec(),
- )
- })
- .collect()
- }
-}
-
#[cfg(test)]
mod tests {
use itertools::Itertools;
diff --git a/src/commitment_scheme/mod.rs b/src/commitment_scheme/mod.rs
index 689295051..3f952f491 100644
--- a/src/commitment_scheme/mod.rs
+++ b/src/commitment_scheme/mod.rs
@@ -1,6 +1,5 @@
pub mod blake2_hash;
pub mod blake2_merkle;
-#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
pub mod blake2s_avx;
pub mod blake2s_ref;
pub mod blake3_hash;
diff --git a/src/core/backend/avx512/blake2s.rs b/src/core/backend/avx512/blake2s.rs
new file mode 100644
index 000000000..398634a2d
--- /dev/null
+++ b/src/core/backend/avx512/blake2s.rs
@@ -0,0 +1,87 @@
+use std::arch::x86_64::__m512i;
+
+use itertools::Itertools;
+
+use super::{AVX512Backend, VECS_LOG_SIZE};
+use crate::commitment_scheme::blake2_hash::{Blake2sHash, Blake2sHasher};
+use crate::commitment_scheme::blake2s_avx::{compress16, set1, transpose_msgs, untranspose_states};
+use crate::commitment_scheme::ops::MerkleOps;
+use crate::core::backend::{Col, ColumnOps};
+use crate::core::fields::m31::BaseField;
+
+impl ColumnOps for AVX512Backend {
+ type Column = Vec;
+
+ fn bit_reverse_column(_column: &mut Self::Column) {
+ unimplemented!()
+ }
+}
+
+impl MerkleOps for AVX512Backend {
+ fn commit_on_layer(
+ log_size: u32,
+ prev_layer: Option<&Vec>,
+ columns: &[&Col],
+ ) -> Vec {
+ // Pad prev_layer if too small.
+ let mut padded_buffer = vec![];
+ let prev_layer = if log_size < 4 {
+ prev_layer.map(|prev_layer| {
+ padded_buffer = prev_layer
+ .iter()
+ .copied()
+ .chain(std::iter::repeat(Blake2sHash::default()))
+ .collect_vec();
+ &padded_buffer
+ })
+ } else {
+ prev_layer
+ };
+
+ // Commit to columns.
+ let mut res = Vec::with_capacity(1 << log_size);
+ for i in 0..(1 << (log_size - VECS_LOG_SIZE as u32)) {
+ let mut state: [__m512i; 8] = unsafe { std::mem::zeroed() };
+ // Hash prev_layer.
+ if let Some(prev_layer) = prev_layer {
+ let ptr = prev_layer[(i << 5)..(i << 5) + 32].as_ptr() as *const __m512i;
+ let msgs: [__m512i; 16] = std::array::from_fn(|j| unsafe { *ptr.add(j) });
+ state = unsafe {
+ compress16(
+ state,
+ transpose_msgs(msgs),
+ set1(0),
+ set1(0),
+ set1(0),
+ set1(0),
+ )
+ };
+ }
+
+ // Hash columns in chunks of 16.
+ let mut col_chunk_iter = columns.array_chunks();
+ for col_chunk in &mut col_chunk_iter {
+ let msgs = col_chunk.map(|column| column.data[i].0);
+ state = unsafe { compress16(state, msgs, set1(0), set1(0), set1(0), set1(0)) };
+ }
+
+ // Hash remaining columns.
+ let remainder = col_chunk_iter.remainder();
+ if !remainder.is_empty() {
+ let msgs = remainder
+ .iter()
+ .map(|column| column.data[i].0)
+ .chain(std::iter::repeat(unsafe { set1(0) }))
+ .take(16)
+ .collect_vec()
+ .try_into()
+ .unwrap();
+ state = unsafe { compress16(state, msgs, set1(0), set1(0), set1(0), set1(0)) };
+ }
+ let state: [Blake2sHash; 16] =
+ unsafe { std::mem::transmute(untranspose_states(state)) };
+ res.extend_from_slice(&state);
+ }
+ res
+ }
+}
diff --git a/src/core/backend/avx512/circle.rs b/src/core/backend/avx512/circle.rs
index d1f9f3385..262d02b80 100644
--- a/src/core/backend/avx512/circle.rs
+++ b/src/core/backend/avx512/circle.rs
@@ -132,7 +132,10 @@ impl PolyOps for AVX512Backend {
) -> CircleEvaluation {
// TODO(spapini): Optimize.
let eval = CPUBackend::new_canonical_ordered(coset, as_cpu_vec(values));
- CircleEvaluation::new(eval.domain, Col::::from_iter(eval.values))
+ CircleEvaluation::new(
+ eval.domain,
+ Col::::from_iter(eval.values),
+ )
}
fn interpolate(
diff --git a/src/core/backend/avx512/mod.rs b/src/core/backend/avx512/mod.rs
index e84a2c02d..df0565296 100644
--- a/src/core/backend/avx512/mod.rs
+++ b/src/core/backend/avx512/mod.rs
@@ -1,4 +1,5 @@
pub mod bit_reverse;
+mod blake2s;
pub mod circle;
pub mod cm31;
pub mod fft;
@@ -157,7 +158,7 @@ mod tests {
for i in 1..16 {
let len = 1 << i;
let mut col = Col::::from_iter((0..len).map(BaseField::from));
- B::bit_reverse_column(&mut col);
+ >::bit_reverse_column(&mut col);
assert_eq!(
col.to_vec(),
(0..len)
diff --git a/src/core/backend/cpu/blake2s.rs b/src/core/backend/cpu/blake2s.rs
new file mode 100644
index 000000000..5becdd55d
--- /dev/null
+++ b/src/core/backend/cpu/blake2s.rs
@@ -0,0 +1,23 @@
+use itertools::Itertools;
+
+use crate::commitment_scheme::blake2_hash::{Blake2sHash, Blake2sHasher};
+use crate::commitment_scheme::ops::{MerkleHasher, MerkleOps};
+use crate::core::backend::CPUBackend;
+use crate::core::fields::m31::BaseField;
+
+impl MerkleOps for CPUBackend {
+ fn commit_on_layer(
+ log_size: u32,
+ prev_layer: Option<&Vec>,
+ columns: &[&Vec],
+ ) -> Vec {
+ (0..(1 << log_size))
+ .map(|i| {
+ Blake2sHasher::hash_node(
+ prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
+ &columns.iter().map(|column| column[i]).collect_vec(),
+ )
+ })
+ .collect()
+ }
+}
diff --git a/src/core/backend/cpu/mod.rs b/src/core/backend/cpu/mod.rs
index 5927165d0..094e4bb54 100644
--- a/src/core/backend/cpu/mod.rs
+++ b/src/core/backend/cpu/mod.rs
@@ -1,3 +1,4 @@
+mod blake2s;
mod circle;
mod fri;
pub mod quotients;