Skip to content

Commit

Permalink
avx blake
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 21, 2024
1 parent 3524e6b commit ad09eca
Show file tree
Hide file tree
Showing 6 changed files with 663 additions and 38 deletions.
156 changes: 130 additions & 26 deletions src/commitment_scheme/blake2_hash.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
use std::arch::x86_64::{__m512i, _mm512_loadu_si512};
use std::fmt;

use blake2::digest::{Update, VariableOutput};
use blake2::{Blake2s256, Blake2sVar, Digest};
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
use super::blake2s_avx::{compress16_transposed, set1, transpose_msgs, transpose_states};
use super::blake2s_ref;
use super::hasher::Hasher;

// Wrapper for the blake2s hash type.
#[derive(Clone, Copy, PartialEq, Default, Eq)]
Expand Down Expand Up @@ -63,57 +67,157 @@ impl super::hasher::Name for Blake2sHash {

impl super::hasher::Hash<u8> for Blake2sHash {}

// Wrapper for the blake2s Hashing functionalities.
#[derive(Clone, Debug, Default)]
/// Wrapper for the blake2s Hashing functionalities.
#[derive(Clone, Debug)]
pub struct Blake2sHasher {
state: Blake2s256,
state: [u32; 8],
pending_buffer: [u8; 64],
pending_len: usize,
}

impl super::hasher::Hasher for Blake2sHasher {
impl Default for Blake2sHasher {
fn default() -> Self {
Self::new()
}
}

impl Hasher for Blake2sHasher {
type Hash = Blake2sHash;
const BLOCK_SIZE: usize = 64;
const OUTPUT_SIZE: usize = 32;
type NativeType = u8;

fn new() -> Self {
Self {
state: Blake2s256::new(),
state: [0; 8],
pending_buffer: [0; 64],
pending_len: 0,
}
}

fn reset(&mut self) {
blake2::Digest::reset(&mut self.state);
*self = Self::new();
}

fn update(&mut self, data: &[u8]) {
blake2::Digest::update(&mut self.state, data);
fn update(&mut self, mut data: &[u8]) {
while self.pending_len + data.len() >= 64 {
// Fill the buffer and compress.
let (prefix, rest) = data.split_at(64 - self.pending_len);
self.pending_buffer[self.pending_len..].copy_from_slice(prefix);
data = rest;
let words =
unsafe { std::mem::transmute::<&[u8; 64], &[u32; 16]>(&self.pending_buffer) };
blake2s_ref::compress(&mut self.state, words, 0, 0, 0, 0);
self.pending_len = 0;
}
// Copy the remaining data into the buffer.
self.pending_buffer[self.pending_len..self.pending_len + data.len()].copy_from_slice(data);
self.pending_len += data.len();
}

fn finalize(self) -> Blake2sHash {
Blake2sHash(self.state.finalize().into())
fn finalize(mut self) -> Blake2sHash {
if self.pending_len != 0 {
self.update(&[0; 64]);
}
Blake2sHash(unsafe { std::mem::transmute(self.state) })
}

fn finalize_reset(&mut self) -> Blake2sHash {
Blake2sHash(self.state.finalize_reset().into())
let hash = self.clone().finalize();
self.reset();
hash
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
unsafe fn hash_many_in_place(
data: &[*const u8],
single_input_length_bytes: usize,
dst: &[*mut u8],
) {
data.iter()
.map(|p| std::slice::from_raw_parts(*p, single_input_length_bytes))
.zip(
dst.iter()
.map(|p| std::slice::from_raw_parts_mut(*p, Self::OUTPUT_SIZE)),
)
.for_each(|(input, out)| {
let mut hasher = Blake2sVar::new(Self::OUTPUT_SIZE).unwrap();
hasher.update(input);
hasher.finalize_variable(out).unwrap();
})
// TODO(spapini): this implementation assumes that dst are consecutive.
// Match that in the trait.
let mut dst = dst[0];

// Compress 16 instances at a time.
let mut data_iter = data.array_chunks::<16>();
for inputs in &mut data_iter {
let bytes = compress16(inputs, single_input_length_bytes);
std::ptr::copy_nonoverlapping(bytes.as_ptr(), dst, bytes.len());
dst = dst.add(16 * 32);
}
// Handle the remainder.
let inputs = data_iter.remainder();
if inputs.is_empty() {
return;
}
// Pad inputs with the same input address.
let remainder = inputs.len();
let inputs = inputs
.iter()
.copied()
.chain(std::iter::repeat(inputs[0]))
.take(16)
.collect::<Vec<_>>();
let bytes = compress16(&inputs.try_into().unwrap(), single_input_length_bytes);
// Store only the relevant part of the output.
std::ptr::copy_nonoverlapping(bytes.as_ptr(), dst, remainder * 32);
}
}

/// Compresses 16 equally sized inputs in parallel with the AVX-512 blake2s compression function.
///
/// Every complete 64-byte block of each input is compressed with all four trailing arguments
/// of `compress16_transposed` set to 0. A trailing partial block, if present, is zero-padded
/// to 64 bytes and compressed with the first of those arguments set to
/// `single_input_length_bytes`.
///
/// Returns the 16 final states as 16 consecutive 32-byte results.
///
/// # Safety
/// Every pointer in `inputs` must be valid for reads of `single_input_length_bytes` bytes.
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
unsafe fn compress16(inputs: &[*const u8; 16], single_input_length_bytes: usize) -> [u8; 16 * 32] {
    let mut states = [set1(0); 8];
    // Local copy of the input cursors; each one advances by 64 bytes per full block.
    let mut inputs = *inputs;

    // Compress only the *complete* 64-byte blocks.
    // Iterating over `(0..len).step_by(64)` would run ceil(len / 64) times and, whenever the
    // length is not a multiple of 64, load a full 64 bytes past the end of every input and
    // leave the cursors pointing beyond the real tail handled below.
    for _ in 0..single_input_length_bytes / 64 {
        // Load next 64 bytes of each of the 16 inputs, unaligned.
        // TODO(spapini): Figure out if aligned loading is possible.
        let chunk_per_input: [__m512i; 16] =
            inputs.map(|input| _mm512_loadu_si512(input as *const i32));
        inputs = inputs.map(|input| input.add(64));

        // Transpose the input chunks (16 chunks of 16 u32s each), to get 16 __m512i, each
        // representing 16 packed instances of a message word.
        let vectorized_chunk = transpose_msgs(chunk_per_input);

        // Compress the 16 instances.
        compress16_transposed(
            &mut states,
            &vectorized_chunk,
            set1(0),
            set1(0),
            set1(0),
            set1(0),
        );
    }

    // Handle the trailing partial block, if any.
    let remainder = single_input_length_bytes % 64;
    if remainder != 0 {
        // Copy the tail of each input into a zeroed 64-byte scratch block so that the vector
        // load never touches memory outside the caller's buffers.
        let mut words = [set1(0); 16];
        for (word, input) in words.iter_mut().zip(inputs) {
            let mut block = [0u8; 64];
            block[..remainder].copy_from_slice(std::slice::from_raw_parts(input, remainder));
            *word = _mm512_loadu_si512(block.as_ptr() as *const i32);
        }
        // Compress the 16 instances.
        // NOTE(review): the length is narrowed to `i32` here; inputs of 2 GiB or more would
        // wrap — confirm that callers never pass such sizes.
        compress16_transposed(
            &mut states,
            &transpose_msgs(words),
            set1(single_input_length_bytes as i32),
            set1(0),
            set1(0),
            set1(0),
        );
    }

    // Transpose the states, from 8 packed words, to get 16 results, each of size 32B.
    transpose_states(states)
}

#[cfg(test)]
Expand All @@ -127,7 +231,7 @@ mod tests {
let hash_a = blake2_hash::Blake2sHasher::hash(b"a");
assert_eq!(
hash_a.to_string(),
"4a0d129873403037c2cd9b9048203687f6233fb6738956e0349bd4320fec3e90"
"f2ab64ae6530f3a5d19369752cd30eadf455153c29dbf2cb70f00f73d5b41c50"
);
}

Expand All @@ -141,7 +245,7 @@ mod tests {
let out_ptrs = [out.as_mut_ptr(), unsafe { out.as_mut_ptr().add(42) }];
unsafe { Blake2sHasher::hash_many_in_place(&input_arr, 1, &out_ptrs) };

assert_eq!("4a0d129873403037c2cd9b9048203687f6233fb6738956e0349bd4320fec3e900000000000000000000004449e92c9a7657ef2d677b8ef9da46c088f13575ea887e4818fc455a2bca50000000000000000000000000000000000000000000000", hex::encode(out));
assert_eq!("8e7b8823fa9ad8fb8b6e992849c2bbfa0bb1809c1b0666996d6c622ac1df197d85230cd8a7f7d2cd23e24497ac432193e8efa81ac6688f0b64efad1c53acaccf0000000000000000000000000000000000000000000000000000000000000000", hex::encode(out));
}

#[test]
Expand Down
Loading

0 comments on commit ad09eca

Please sign in to comment.