diff --git a/benches/merkle.rs b/benches/merkle.rs
index c465f4c93..c2a82e619 100644
--- a/benches/merkle.rs
+++ b/benches/merkle.rs
@@ -7,11 +7,13 @@ pub fn cpu_merkle(c: &mut criterion::Criterion) {
     use itertools::Itertools;
     use num_traits::Zero;
     use stwo::commitment_scheme::ops::MerkleOps;
-    use stwo::core::backend::CPUBackend;
+    use stwo::core::backend::avx512::AVX512Backend;
+    use stwo::core::backend::{CPUBackend, Col};
     use stwo::core::fields::m31::BaseField;
+    use stwo::platform;
 
     const N_COLS: usize = 1 << 8;
-    const LOG_SIZE: u32 = 20;
+    const LOG_SIZE: u32 = 16;
     let cols = (0..N_COLS)
         .map(|_| {
             (0..(1 << LOG_SIZE))
@@ -30,6 +32,23 @@ pub fn cpu_merkle(c: &mut criterion::Criterion) {
             CPUBackend::commit_on_layer(LOG_SIZE, None, &cols.iter().collect_vec());
         })
     });
+
+    if !platform::avx512_detected() {
+        return;
+    }
+    let cols = (0..N_COLS)
+        .map(|_| {
+            (0..(1 << LOG_SIZE))
+                .map(|_| BaseField::zero())
+                .collect::<Col<AVX512Backend, BaseField>>()
+        })
+        .collect::<Vec<_>>();
+
+    group.bench_function("avx merkle", |b| {
+        b.iter(|| {
+            AVX512Backend::commit_on_layer(LOG_SIZE, None, &cols.iter().collect_vec());
+        })
+    });
 }
 
 #[cfg(target_arch = "x86_64")]
diff --git a/src/commitment_scheme/blake2_merkle.rs b/src/commitment_scheme/blake2_merkle.rs
index 4185bf08e..e4b1e48df 100644
--- a/src/commitment_scheme/blake2_merkle.rs
+++ b/src/commitment_scheme/blake2_merkle.rs
@@ -1,10 +1,8 @@
-use itertools::Itertools;
 use num_traits::Zero;
 
 use super::blake2_hash::Blake2sHash;
 use super::blake2s_ref::compress;
-use super::ops::{MerkleHasher, MerkleOps};
-use crate::core::backend::CPUBackend;
+use super::ops::MerkleHasher;
 use crate::core::fields::m31::BaseField;
 
 #[derive(Copy, Clone, Debug, PartialEq, Eq, Default)]
@@ -39,23 +37,6 @@ impl MerkleHasher for Blake2sMerkleHasher {
     }
 }
 
-impl MerkleOps<Blake2sMerkleHasher> for CPUBackend {
-    fn commit_on_layer(
-        log_size: u32,
-        prev_layer: Option<&Vec<Blake2sHash>>,
-        columns: &[&Vec<BaseField>],
-    ) -> Vec<Blake2sHash> {
-        (0..(1 << log_size))
-            .map(|i| {
-                Blake2sMerkleHasher::hash_node(
-                    prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
-                    &columns.iter().map(|column| column[i]).collect_vec(),
-                )
-            })
-            .collect()
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use std::collections::BTreeMap;
diff --git a/src/core/backend/avx512/blake2s.rs b/src/core/backend/avx512/blake2s.rs
new file mode 100644
index 000000000..23329c52c
--- /dev/null
+++ b/src/core/backend/avx512/blake2s.rs
@@ -0,0 +1,88 @@
+use std::arch::x86_64::__m512i;
+
+use itertools::Itertools;
+
+use super::blake2s_avx::{compress16, set1, transpose_msgs, untranspose_states};
+use super::{AVX512Backend, VECS_LOG_SIZE};
+use crate::commitment_scheme::blake2_hash::Blake2sHash;
+use crate::commitment_scheme::blake2_merkle::Blake2sMerkleHasher;
+use crate::commitment_scheme::ops::MerkleOps;
+use crate::core::backend::{Col, ColumnOps};
+use crate::core::fields::m31::BaseField;
+
+impl ColumnOps<Blake2sHash> for AVX512Backend {
+    type Column = Vec<Blake2sHash>;
+
+    fn bit_reverse_column(_column: &mut Self::Column) {
+        unimplemented!()
+    }
+}
+
+impl MerkleOps<Blake2sMerkleHasher> for AVX512Backend {
+    fn commit_on_layer(
+        log_size: u32,
+        prev_layer: Option<&Vec<Blake2sHash>>,
+        columns: &[&Col<Self, BaseField>],
+    ) -> Vec<Blake2sHash> {
+        // Pad prev_layer if too small.
+        let mut padded_buffer = vec![];
+        let prev_layer = if log_size < 4 {
+            prev_layer.map(|prev_layer| {
+                padded_buffer = prev_layer
+                    .iter()
+                    .copied()
+                    .chain(std::iter::repeat(Blake2sHash::default()))
+                    .collect_vec();
+                &padded_buffer
+            })
+        } else {
+            prev_layer
+        };
+
+        // Commit to columns.
+        let mut res = Vec::with_capacity(1 << log_size);
+        for i in 0..(1 << (log_size - VECS_LOG_SIZE as u32)) {
+            let mut state: [__m512i; 8] = unsafe { std::mem::zeroed() };
+            // Hash prev_layer.
+            if let Some(prev_layer) = prev_layer {
+                let ptr = prev_layer[(i << 5)..(i << 5) + 32].as_ptr() as *const __m512i;
+                let msgs: [__m512i; 16] = std::array::from_fn(|j| unsafe { *ptr.add(j) });
+                state = unsafe {
+                    compress16(
+                        state,
+                        transpose_msgs(msgs),
+                        set1(0),
+                        set1(0),
+                        set1(0),
+                        set1(0),
+                    )
+                };
+            }
+
+            // Hash columns in chunks of 16.
+            let mut col_chunk_iter = columns.array_chunks();
+            for col_chunk in &mut col_chunk_iter {
+                let msgs = col_chunk.map(|column| column.data[i].0);
+                state = unsafe { compress16(state, msgs, set1(0), set1(0), set1(0), set1(0)) };
+            }
+
+            // Hash remaining columns.
+            let remainder = col_chunk_iter.remainder();
+            if !remainder.is_empty() {
+                let msgs = remainder
+                    .iter()
+                    .map(|column| column.data[i].0)
+                    .chain(std::iter::repeat(unsafe { set1(0) }))
+                    .take(16)
+                    .collect_vec()
+                    .try_into()
+                    .unwrap();
+                state = unsafe { compress16(state, msgs, set1(0), set1(0), set1(0), set1(0)) };
+            }
+            let state: [Blake2sHash; 16] =
+                unsafe { std::mem::transmute(untranspose_states(state)) };
+            res.extend_from_slice(&state);
+        }
+        res
+    }
+}
diff --git a/src/core/backend/avx512/circle.rs b/src/core/backend/avx512/circle.rs
index b67ba2e60..d9dbc8d23 100644
--- a/src/core/backend/avx512/circle.rs
+++ b/src/core/backend/avx512/circle.rs
@@ -132,7 +132,10 @@ impl PolyOps for AVX512Backend {
     ) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
         // TODO(spapini): Optimize.
         let eval = CPUBackend::new_canonical_ordered(coset, as_cpu_vec(values));
-        CircleEvaluation::new(eval.domain, Col::<AVX512Backend, BaseField>::from_iter(eval.values))
+        CircleEvaluation::new(
+            eval.domain,
+            Col::<AVX512Backend, BaseField>::from_iter(eval.values),
+        )
     }
 
     fn interpolate(
diff --git a/src/core/backend/avx512/mod.rs b/src/core/backend/avx512/mod.rs
index 4e1ce749b..016bae5ac 100644
--- a/src/core/backend/avx512/mod.rs
+++ b/src/core/backend/avx512/mod.rs
@@ -1,4 +1,5 @@
 pub mod bit_reverse;
+mod blake2s;
 pub mod blake2s_avx;
 pub mod circle;
 pub mod cm31;
@@ -202,7 +203,7 @@ mod tests {
         for i in 1..16 {
             let len = 1 << i;
             let mut col = Col::<B, BaseField>::from_iter((0..len).map(BaseField::from));
-            B::bit_reverse_column(&mut col);
+            <B as ColumnOps<BaseField>>::bit_reverse_column(&mut col);
             assert_eq!(
                 col.to_vec(),
                 (0..len)
diff --git a/src/core/backend/cpu/blake2s.rs b/src/core/backend/cpu/blake2s.rs
new file mode 100644
index 000000000..b1fb44c1c
--- /dev/null
+++ b/src/core/backend/cpu/blake2s.rs
@@ -0,0 +1,24 @@
+use itertools::Itertools;
+
+use crate::commitment_scheme::blake2_hash::Blake2sHash;
+use crate::commitment_scheme::blake2_merkle::Blake2sMerkleHasher;
+use crate::commitment_scheme::ops::{MerkleHasher, MerkleOps};
+use crate::core::backend::CPUBackend;
+use crate::core::fields::m31::BaseField;
+
+impl MerkleOps<Blake2sMerkleHasher> for CPUBackend {
+    fn commit_on_layer(
+        log_size: u32,
+        prev_layer: Option<&Vec<Blake2sHash>>,
+        columns: &[&Vec<BaseField>],
+    ) -> Vec<Blake2sHash> {
+        (0..(1 << log_size))
+            .map(|i| {
+                Blake2sMerkleHasher::hash_node(
+                    prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
+                    &columns.iter().map(|column| column[i]).collect_vec(),
+                )
+            })
+            .collect()
+    }
+}
diff --git a/src/core/backend/cpu/mod.rs b/src/core/backend/cpu/mod.rs
index 5927165d0..094e4bb54 100644
--- a/src/core/backend/cpu/mod.rs
+++ b/src/core/backend/cpu/mod.rs
@@ -1,3 +1,4 @@
+mod blake2s;
 mod circle;
 mod fri;
 pub mod quotients;