diff --git a/src/commitment_scheme/blake2_hash.rs b/src/commitment_scheme/blake2_hash.rs
index 8b29bb96a..fb9dec73b 100644
--- a/src/commitment_scheme/blake2_hash.rs
+++ b/src/commitment_scheme/blake2_hash.rs
@@ -1,7 +1,11 @@
+#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
+use std::arch::x86_64::{__m512i, _mm512_loadu_si512};
 use std::fmt;

-use blake2::digest::{Update, VariableOutput};
-use blake2::{Blake2s256, Blake2sVar, Digest};
+#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
+use super::blake2s_avx::{compress16_transposed, set1, transpose_msgs, transpose_states};
+use super::blake2s_ref;
+use super::hasher::Hasher;

 // Wrapper for the blake2s hash type.
 #[derive(Clone, Copy, PartialEq, Default, Eq)]
@@ -63,13 +67,21 @@ impl super::hasher::Name for Blake2sHash {

 impl super::hasher::Hash for Blake2sHash {}

-// Wrapper for the blake2s Hashing functionalities.
-#[derive(Clone, Debug, Default)]
+/// Wrapper for the blake2s Hashing functionalities.
+#[derive(Clone, Debug)]
 pub struct Blake2sHasher {
-    state: Blake2s256,
+    state: [u32; 8],
+    pending_buffer: [u8; 64],
+    pending_len: usize,
 }

-impl super::hasher::Hasher for Blake2sHasher {
+impl Default for Blake2sHasher {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl Hasher for Blake2sHasher {
     type Hash = Blake2sHash;
     const BLOCK_SIZE: usize = 64;
     const OUTPUT_SIZE: usize = 32;
@@ -77,43 +89,135 @@ impl super::hasher::Hasher for Blake2sHasher {

     fn new() -> Self {
         Self {
-            state: Blake2s256::new(),
+            state: [0; 8],
+            pending_buffer: [0; 64],
+            pending_len: 0,
         }
     }

     fn reset(&mut self) {
-        blake2::Digest::reset(&mut self.state);
+        *self = Self::new();
     }

-    fn update(&mut self, data: &[u8]) {
-        blake2::Digest::update(&mut self.state, data);
+    fn update(&mut self, mut data: &[u8]) {
+        while self.pending_len + data.len() >= 64 {
+            // Fill the buffer and compress.
+            let (prefix, rest) = data.split_at(64 - self.pending_len);
+            self.pending_buffer[self.pending_len..].copy_from_slice(prefix);
+            data = rest;
+            let words =
+                unsafe { std::mem::transmute::<&[u8; 64], &[u32; 16]>(&self.pending_buffer) };
+            blake2s_ref::compress(&mut self.state, words, 0, 0, 0, 0);
+            self.pending_len = 0;
+        }
+        // Copy the remaining data into the buffer.
+        self.pending_buffer[self.pending_len..self.pending_len + data.len()].copy_from_slice(data);
+        self.pending_len += data.len();
     }

-    fn finalize(self) -> Blake2sHash {
-        Blake2sHash(self.state.finalize().into())
+    fn finalize(mut self) -> Blake2sHash {
+        if self.pending_len != 0 {
+            self.update(&[0; 64]);
+        }
+        Blake2sHash(unsafe { std::mem::transmute(self.state) })
     }

     fn finalize_reset(&mut self) -> Blake2sHash {
-        Blake2sHash(self.state.finalize_reset().into())
+        let hash = self.clone().finalize();
+        self.reset();
+        hash
     }

+    #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
     unsafe fn hash_many_in_place(
         data: &[*const u8],
         single_input_length_bytes: usize,
         dst: &[*mut u8],
     ) {
-        data.iter()
-            .map(|p| std::slice::from_raw_parts(*p, single_input_length_bytes))
-            .zip(
-                dst.iter()
-                    .map(|p| std::slice::from_raw_parts_mut(*p, Self::OUTPUT_SIZE)),
-            )
-            .for_each(|(input, out)| {
-                let mut hasher = Blake2sVar::new(Self::OUTPUT_SIZE).unwrap();
-                hasher.update(input);
-                hasher.finalize_variable(out).unwrap();
-            })
+        // TODO(spapini): this implementation assumes that dst are consecutive.
+        // Match that in the trait.
+        let mut dst = dst[0];
+
+        // Compress 16 instances at a time.
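+        // An AVX512 register (`__m512i`) holds 16 u32 lanes, so one SIMD instruction
+        // advances the same step of 16 independent blake2s states in parallel.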
+        let mut data_iter = data.array_chunks::<16>();
+        for inputs in &mut data_iter {
+            let bytes = compress16(inputs, single_input_length_bytes);
+            std::ptr::copy_nonoverlapping(bytes.as_ptr(), dst, bytes.len());
+            dst = dst.add(16 * 32);
+        }
+        // Handle the remainder.
+        let inputs = data_iter.remainder();
+        if inputs.is_empty() {
+            return;
+        }
+        // Pad inputs with the same input address.
+        let remainder = inputs.len();
+        let inputs = inputs
+            .iter()
+            .copied()
+            .chain(std::iter::repeat(inputs[0]))
+            .take(16)
+            .collect::<Vec<_>>();
+        let bytes = compress16(&inputs.try_into().unwrap(), single_input_length_bytes);
+        // Store only the relevant part of the output.
+        std::ptr::copy_nonoverlapping(bytes.as_ptr(), dst, remainder * 32);
+    }
+}
+
+/// Compress 16 blake2s instances.
+/// # Safety
+/// Inputs must be of the size `single_input_length_bytes`.
+#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
+unsafe fn compress16(inputs: &[*const u8; 16], single_input_length_bytes: usize) -> [u8; 16 * 32] {
+    let mut states = [set1(0); 8];
+    let mut inputs = *inputs;
+
+    // Compress the full 64-byte chunks. Iterating over all of
+    // `single_input_length_bytes` here would over-read past the end of each input
+    // when the length is not a multiple of 64; the trailing partial chunk is
+    // handled separately below.
+    for _ in 0..single_input_length_bytes / 64 {
+        // Load next 64 bytes of each of the 16 inputs, unaligned.
+        // TODO(spapini): Figure out if aligned loading is possible.
+        let chunk_per_input: [__m512i; 16] =
+            inputs.map(|input| _mm512_loadu_si512(input as *const i32));
+        inputs = inputs.map(|input| input.add(64));
+
+        // Transpose the input chunks (16 chunks of 16 u32s each), to get 16 __m512i, each
+        // representing 16 packed instances of a message word.
+        let vectorized_chunk = transpose_msgs(chunk_per_input);
+
+        // Compress the 16 instances.
+        compress16_transposed(
+            &mut states,
+            &vectorized_chunk,
+            set1(0),
+            set1(0),
+            set1(0),
+            set1(0),
+        );
+    }
+
+    // Handle the remainder.
+    let remainder = single_input_length_bytes % 64;
+    if remainder != 0 {
+        // Load the remainder of each input, padded with 0s.
+        let mut words = [set1(0); 16];
+        for (i, input) in inputs.into_iter().enumerate() {
+            let mut word = [0; 64];
+            word[..remainder].copy_from_slice(std::slice::from_raw_parts(input, remainder));
+            words[i] = _mm512_loadu_si512(word.as_ptr() as *const i32);
+        }
+        // Compress the 16 instances.
+        compress16_transposed(
+            &mut states,
+            &transpose_msgs(words),
+            set1(single_input_length_bytes as i32),
+            set1(0),
+            set1(0),
+            set1(0),
+        );
+    }
+
+    // Transpose the states, from 8 packed words, to get 16 results, each of size 32B.
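+    // (Before this transpose, states[i] holds word i of all 16 instances; after it,
+    // each instance's 8 words are contiguous, giving 16 consecutive 32B digests.)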
+    transpose_states(states)
+}

 #[cfg(test)]
@@ -127,7 +231,7 @@ mod tests {
         let hash_a = blake2_hash::Blake2sHasher::hash(b"a");
         assert_eq!(
             hash_a.to_string(),
-            "4a0d129873403037c2cd9b9048203687f6233fb6738956e0349bd4320fec3e90"
+            "f2ab64ae6530f3a5d19369752cd30eadf455153c29dbf2cb70f00f73d5b41c50"
         );
     }

@@ -141,7 +245,7 @@ mod tests {
         let out_ptrs = [out.as_mut_ptr(), unsafe { out.as_mut_ptr().add(42) }];
         unsafe { Blake2sHasher::hash_many_in_place(&input_arr, 1, &out_ptrs) };

-        assert_eq!("4a0d129873403037c2cd9b9048203687f6233fb6738956e0349bd4320fec3e900000000000000000000004449e92c9a7657ef2d677b8ef9da46c088f13575ea887e4818fc455a2bca50000000000000000000000000000000000000000000000", hex::encode(out));
+        assert_eq!("8e7b8823fa9ad8fb8b6e992849c2bbfa0bb1809c1b0666996d6c622ac1df197d85230cd8a7f7d2cd23e24497ac432193e8efa81ac6688f0b64efad1c53acaccf0000000000000000000000000000000000000000000000000000000000000000", hex::encode(out));
     }

     #[test]
diff --git a/src/commitment_scheme/blake2s_avx.rs b/src/commitment_scheme/blake2s_avx.rs
new file mode 100644
index 000000000..fc5bc376e
--- /dev/null
+++ b/src/commitment_scheme/blake2s_avx.rs
@@ -0,0 +1,300 @@
+//! An AVX512 implementation of the BLAKE2s compression function.
+//! Based on https://github.com/oconnor663/blake2_simd/blob/master/blake2s/src/avx2.rs.

+use std::arch::x86_64::{
+    __m512i, _mm512_add_epi32, _mm512_or_si512, _mm512_permutex2var_epi32, _mm512_set1_epi32,
+    _mm512_slli_epi32, _mm512_srli_epi32, _mm512_xor_si512,
+};

+const IV: [u32; 8] = [
+    0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
+];

+const SIGMA: [[u8; 16]; 10] = [
+    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
+    [14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3],
+    [11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4],
+    [7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8],
+    [9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13],
+    [2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9],
+    [12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11],
+    [13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10],
+    [6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5],
+    [10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0],
+];

+/// # Safety
+#[inline(always)]
+pub unsafe fn set1(iv: i32) -> __m512i {
+    _mm512_set1_epi32(iv)
+}

+#[inline(always)]
+unsafe fn add(a: __m512i, b: __m512i) -> __m512i {
+    _mm512_add_epi32(a, b)
+}

+#[inline(always)]
+unsafe fn xor(a: __m512i, b: __m512i) -> __m512i {
+    _mm512_xor_si512(a, b)
+}

+#[inline(always)]
+unsafe fn rot16(x: __m512i) -> __m512i {
+    _mm512_or_si512(_mm512_srli_epi32(x, 16), _mm512_slli_epi32(x, 32 - 16))
+}

+#[inline(always)]
+unsafe fn rot12(x: __m512i) -> __m512i {
+    _mm512_or_si512(_mm512_srli_epi32(x, 12), _mm512_slli_epi32(x, 32 - 12))
+}

+#[inline(always)]
+unsafe fn rot8(x: __m512i) -> __m512i {
+    _mm512_or_si512(_mm512_srli_epi32(x, 8), _mm512_slli_epi32(x, 32 - 8))
+}

+#[inline(always)]
+unsafe fn rot7(x: __m512i) -> __m512i {
+    _mm512_or_si512(_mm512_srli_epi32(x, 7), _mm512_slli_epi32(x, 32 - 7))
+}

+#[inline(always)]
+unsafe fn round(v: &mut [__m512i; 16], m: &[__m512i; 16], r: usize) {
+    v[0] = add(v[0], m[SIGMA[r][0] as usize]);
+    v[1] = add(v[1], m[SIGMA[r][2] as usize]);
+    v[2] = add(v[2], m[SIGMA[r][4] as usize]);
+    v[3] = add(v[3], m[SIGMA[r][6] as usize]);
+    v[0] = add(v[0], v[4]);
+    v[1] = add(v[1], v[5]);
+    v[2] = add(v[2], v[6]);
+    v[3] = add(v[3], v[7]);
+    v[12] = xor(v[12], v[0]);
+    v[13] = xor(v[13], v[1]);
+    v[14] = xor(v[14], v[2]);
+    v[15] = xor(v[15], v[3]);
+    v[12] = rot16(v[12]);
+    v[13] = rot16(v[13]);
+    v[14] = rot16(v[14]);
+    v[15] = rot16(v[15]);
+    v[8] = add(v[8], v[12]);
+    v[9] = add(v[9], v[13]);
+    v[10] = add(v[10], v[14]);
+    v[11] = add(v[11], v[15]);
+    v[4] = xor(v[4], v[8]);
+    v[5] = xor(v[5], v[9]);
+    v[6] = xor(v[6], v[10]);
+    v[7] = xor(v[7], v[11]);
+    v[4] = rot12(v[4]);
+    v[5] = rot12(v[5]);
+    v[6] = rot12(v[6]);
+    v[7] = rot12(v[7]);
+    v[0] = add(v[0], m[SIGMA[r][1] as usize]);
+    v[1] = add(v[1], m[SIGMA[r][3] as usize]);
+    v[2] = add(v[2], m[SIGMA[r][5] as usize]);
+    v[3] = add(v[3], m[SIGMA[r][7] as usize]);
+    v[0] = add(v[0], v[4]);
+    v[1] = add(v[1], v[5]);
+    v[2] = add(v[2], v[6]);
+    v[3] = add(v[3], v[7]);
+    v[12] = xor(v[12], v[0]);
+    v[13] = xor(v[13], v[1]);
+    v[14] = xor(v[14], v[2]);
+    v[15] = xor(v[15], v[3]);
+    v[12] = rot8(v[12]);
+    v[13] = rot8(v[13]);
+    v[14] = rot8(v[14]);
+    v[15] = rot8(v[15]);
+    v[8] = add(v[8], v[12]);
+    v[9] = add(v[9], v[13]);
+    v[10] = add(v[10], v[14]);
+    v[11] = add(v[11], v[15]);
+    v[4] = xor(v[4], v[8]);
+    v[5] = xor(v[5], v[9]);
+    v[6] = xor(v[6], v[10]);
+    v[7] = xor(v[7], v[11]);
+    v[4] = rot7(v[4]);
+    v[5] = rot7(v[5]);
+    v[6] = rot7(v[6]);
+    v[7] = rot7(v[7]);

+    v[0] = add(v[0], m[SIGMA[r][8] as usize]);
+    v[1] = add(v[1], m[SIGMA[r][10] as usize]);
+    v[2] = add(v[2], m[SIGMA[r][12] as usize]);
+    v[3] = add(v[3], m[SIGMA[r][14] as usize]);
+    v[0] = add(v[0], v[5]);
+    v[1] = add(v[1], v[6]);
+    v[2] = add(v[2], v[7]);
+    v[3] = add(v[3], v[4]);
+    v[15] = xor(v[15], v[0]);
+    v[12] = xor(v[12], v[1]);
+    v[13] = xor(v[13], v[2]);
+    v[14] = xor(v[14], v[3]);
+    v[15] = rot16(v[15]);
+    v[12] = rot16(v[12]);
+    v[13] = rot16(v[13]);
+    v[14] = rot16(v[14]);
+    v[10] = add(v[10], v[15]);
+    v[11] = add(v[11], v[12]);
+    v[8] = add(v[8], v[13]);
+    v[9] = add(v[9], v[14]);
+    v[5] = xor(v[5], v[10]);
+    v[6] = xor(v[6], v[11]);
+    v[7] = xor(v[7], v[8]);
+    v[4] = xor(v[4], v[9]);
+    v[5] = rot12(v[5]);
+    v[6] = rot12(v[6]);
+    v[7] = rot12(v[7]);
+    v[4] = rot12(v[4]);
+    v[0] = add(v[0], m[SIGMA[r][9] as usize]);
+    v[1] = add(v[1], m[SIGMA[r][11] as usize]);
+    v[2] = add(v[2], m[SIGMA[r][13] as usize]);
+    v[3] = add(v[3], m[SIGMA[r][15] as usize]);
+    v[0] = add(v[0], v[5]);
+    v[1] = add(v[1], v[6]);
+    v[2] = add(v[2], v[7]);
+    v[3] = add(v[3], v[4]);
+    v[15] = xor(v[15], v[0]);
+    v[12] = xor(v[12], v[1]);
+    v[13] = xor(v[13], v[2]);
+    v[14] = xor(v[14], v[3]);
+    v[15] = rot8(v[15]);
+    v[12] = rot8(v[12]);
+    v[13] = rot8(v[13]);
+    v[14] = rot8(v[14]);
+    v[10] = add(v[10], v[15]);
+    v[11] = add(v[11], v[12]);
+    v[8] = add(v[8], v[13]);
+    v[9] = add(v[9], v[14]);
+    v[5] = xor(v[5], v[10]);
+    v[6] = xor(v[6], v[11]);
+    v[7] = xor(v[7], v[8]);
+    v[4] = xor(v[4], v[9]);
+    v[5] = rot7(v[5]);
+    v[6] = rot7(v[6]);
+    v[7] = rot7(v[7]);
+    v[4] = rot7(v[4]);
+}

+/// Transposes input chunks (16 chunks of 16 u32s each), to get 16 __m512i, each
+/// representing 16 packed instances of a message word.
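+///
+/// Each pass below applies the permutation `abcd:0123 => 3abc:d012` on the combined
+/// (vector index : lane index) bits, i.e. it rotates the 8-bit index by one position;
+/// four passes swap the two 4-bit halves, which is exactly the 16x16 u32 transpose.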
+/// # Safety
+pub unsafe fn transpose_msgs(mut data: [__m512i; 16]) -> [__m512i; 16] {
+    const EVENS_CONCAT_EVENS: __m512i = unsafe {
+        core::mem::transmute([
+            0b00000, 0b00010, 0b00100, 0b00110, 0b01000, 0b01010, 0b01100, 0b01110, 0b10000,
+            0b10010, 0b10100, 0b10110, 0b11000, 0b11010, 0b11100, 0b11110,
+        ])
+    };
+    const ODDS_CONCAT_ODDS: __m512i = unsafe {
+        core::mem::transmute([
+            0b00001, 0b00011, 0b00101, 0b00111, 0b01001, 0b01011, 0b01101, 0b01111, 0b10001,
+            0b10011, 0b10101, 0b10111, 0b11001, 0b11011, 0b11101, 0b11111,
+        ])
+    };
+    // Transpose by applying 4 times the index permutation:
+    //   abcd:0123 => 3abc:d012
+    for _ in 0..4 {
+        data = [
+            _mm512_permutex2var_epi32(data[0], EVENS_CONCAT_EVENS, data[1]),
+            _mm512_permutex2var_epi32(data[2], EVENS_CONCAT_EVENS, data[3]),
+            _mm512_permutex2var_epi32(data[4], EVENS_CONCAT_EVENS, data[5]),
+            _mm512_permutex2var_epi32(data[6], EVENS_CONCAT_EVENS, data[7]),
+            _mm512_permutex2var_epi32(data[8], EVENS_CONCAT_EVENS, data[9]),
+            _mm512_permutex2var_epi32(data[10], EVENS_CONCAT_EVENS, data[11]),
+            _mm512_permutex2var_epi32(data[12], EVENS_CONCAT_EVENS, data[13]),
+            _mm512_permutex2var_epi32(data[14], EVENS_CONCAT_EVENS, data[15]),
+            _mm512_permutex2var_epi32(data[0], ODDS_CONCAT_ODDS, data[1]),
+            _mm512_permutex2var_epi32(data[2], ODDS_CONCAT_ODDS, data[3]),
+            _mm512_permutex2var_epi32(data[4], ODDS_CONCAT_ODDS, data[5]),
+            _mm512_permutex2var_epi32(data[6], ODDS_CONCAT_ODDS, data[7]),
+            _mm512_permutex2var_epi32(data[8], ODDS_CONCAT_ODDS, data[9]),
+            _mm512_permutex2var_epi32(data[10], ODDS_CONCAT_ODDS, data[11]),
+            _mm512_permutex2var_epi32(data[12], ODDS_CONCAT_ODDS, data[13]),
+            _mm512_permutex2var_epi32(data[14], ODDS_CONCAT_ODDS, data[15]),
+        ];
+    }
+    data
+}

+/// Transposes states, from 8 packed words, to get 16 results, each of size 32B.
+/// # Safety
+pub unsafe fn transpose_states(mut states: [__m512i; 8]) -> [u8; 16 * 32] {
+    // Transpose by applying 3 times the index permutation:
+    //   012:abcd => 12a:bcd0
+    const LHALF_INTERLEAVE_LHALF: __m512i = unsafe {
+        core::mem::transmute([
+            0b00000, 0b10000, 0b00001, 0b10001, 0b00010, 0b10010, 0b00011, 0b10011, 0b00100,
+            0b10100, 0b00101, 0b10101, 0b00110, 0b10110, 0b00111, 0b10111,
+        ])
+    };
+    const HHALF_INTERLEAVE_HHALF: __m512i = unsafe {
+        core::mem::transmute([
+            0b01000, 0b11000, 0b01001, 0b11001, 0b01010, 0b11010, 0b01011, 0b11011, 0b01100,
+            0b11100, 0b01101, 0b11101, 0b01110, 0b11110, 0b01111, 0b11111,
+        ])
+    };
+    for _ in 0..3 {
+        states = [
+            _mm512_permutex2var_epi32(states[0], LHALF_INTERLEAVE_LHALF, states[4]),
+            _mm512_permutex2var_epi32(states[0], HHALF_INTERLEAVE_HHALF, states[4]),
+            _mm512_permutex2var_epi32(states[1], LHALF_INTERLEAVE_LHALF, states[5]),
+            _mm512_permutex2var_epi32(states[1], HHALF_INTERLEAVE_HHALF, states[5]),
+            _mm512_permutex2var_epi32(states[2], LHALF_INTERLEAVE_LHALF, states[6]),
+            _mm512_permutex2var_epi32(states[2], HHALF_INTERLEAVE_HHALF, states[6]),
+            _mm512_permutex2var_epi32(states[3], LHALF_INTERLEAVE_LHALF, states[7]),
+            _mm512_permutex2var_epi32(states[3], HHALF_INTERLEAVE_HHALF, states[7]),
+        ];
+    }
+    std::mem::transmute(states)
+}

+/// Compress 16 blake2s instances.
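+///
+/// `h_vecs[i]` holds word `i` of all 16 states (transposed form). `count_low` /
+/// `count_high` carry the blake2s byte-counter halves `t0`/`t1`, and `lastblock` /
+/// `lastnode` the finalization flags `f0`/`f1`, one u32 lane per instance.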
+/// # Safety
+pub unsafe fn compress16_transposed(
+    h_vecs: &mut [__m512i; 8],
+    msg_vecs: &[__m512i; 16],
+    count_low: __m512i,
+    count_high: __m512i,
+    lastblock: __m512i,
+    lastnode: __m512i,
+) {
+    let mut v = [
+        h_vecs[0],
+        h_vecs[1],
+        h_vecs[2],
+        h_vecs[3],
+        h_vecs[4],
+        h_vecs[5],
+        h_vecs[6],
+        h_vecs[7],
+        set1(IV[0] as i32),
+        set1(IV[1] as i32),
+        set1(IV[2] as i32),
+        set1(IV[3] as i32),
+        xor(set1(IV[4] as i32), count_low),
+        xor(set1(IV[5] as i32), count_high),
+        xor(set1(IV[6] as i32), lastblock),
+        xor(set1(IV[7] as i32), lastnode),
+    ];

+    round(&mut v, msg_vecs, 0);
+    round(&mut v, msg_vecs, 1);
+    round(&mut v, msg_vecs, 2);
+    round(&mut v, msg_vecs, 3);
+    round(&mut v, msg_vecs, 4);
+    round(&mut v, msg_vecs, 5);
+    round(&mut v, msg_vecs, 6);
+    round(&mut v, msg_vecs, 7);
+    round(&mut v, msg_vecs, 8);
+    round(&mut v, msg_vecs, 9);

+    h_vecs[0] = xor(xor(h_vecs[0], v[0]), v[8]);
+    h_vecs[1] = xor(xor(h_vecs[1], v[1]), v[9]);
+    h_vecs[2] = xor(xor(h_vecs[2], v[2]), v[10]);
+    h_vecs[3] = xor(xor(h_vecs[3], v[3]), v[11]);
+    h_vecs[4] = xor(xor(h_vecs[4], v[4]), v[12]);
+    h_vecs[5] = xor(xor(h_vecs[5], v[5]), v[13]);
+    h_vecs[6] = xor(xor(h_vecs[6], v[6]), v[14]);
+    h_vecs[7] = xor(xor(h_vecs[7], v[7]), v[15]);
+}
diff --git a/src/commitment_scheme/blake2s_ref.rs b/src/commitment_scheme/blake2s_ref.rs
new file mode 100644
index 000000000..1f3853a55
--- /dev/null
+++ b/src/commitment_scheme/blake2s_ref.rs
@@ -0,0 +1,215 @@
+//! A portable, scalar reference implementation of the BLAKE2s compression function.
+//! Based on https://github.com/oconnor663/blake2_simd/blob/master/blake2s/src/avx2.rs.

+pub const IV: [u32; 8] = [
+    0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
+];

+pub const SIGMA: [[u8; 16]; 10] = [
+    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
+    [14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3],
+    [11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4],
+    [7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8],
+    [9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13],
+    [2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9],
+    [12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11],
+    [13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10],
+    [6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5],
+    [10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0],
+];

+#[inline(always)]
+fn add(a: u32, b: u32) -> u32 {
+    a.wrapping_add(b)
+}

+#[inline(always)]
+fn xor(a: u32, b: u32) -> u32 {
+    a ^ b
+}

+#[inline(always)]
+fn rot16(x: u32) -> u32 {
+    (x >> 16) | (x << (32 - 16))
+}

+#[inline(always)]
+fn rot12(x: u32) -> u32 {
+    (x >> 12) | (x << (32 - 12))
+}

+#[inline(always)]
+fn rot8(x: u32) -> u32 {
+    (x >> 8) | (x << (32 - 8))
+}

+#[inline(always)]
+fn rot7(x: u32) -> u32 {
+    (x >> 7) | (x << (32 - 7))
+}

+#[inline(always)]
+fn round(v: &mut [u32; 16], m: &[u32; 16], r: usize) {
+    v[0] = add(v[0], m[SIGMA[r][0] as usize]);
+    v[1] = add(v[1], m[SIGMA[r][2] as usize]);
+    v[2] = add(v[2], m[SIGMA[r][4] as usize]);
+    v[3] = add(v[3], m[SIGMA[r][6] as usize]);
+    v[0] = add(v[0], v[4]);
+    v[1] = add(v[1], v[5]);
+    v[2] = add(v[2], v[6]);
+    v[3] = add(v[3], v[7]);
+    v[12] = xor(v[12], v[0]);
+    v[13] = xor(v[13], v[1]);
+    v[14] = xor(v[14], v[2]);
+    v[15] = xor(v[15], v[3]);
+    v[12] = rot16(v[12]);
+    v[13] = rot16(v[13]);
+    v[14] = rot16(v[14]);
+    v[15] = rot16(v[15]);
+    v[8] = add(v[8], v[12]);
+    v[9] = add(v[9], v[13]);
+    v[10] = add(v[10], v[14]);
+    v[11] = add(v[11], v[15]);
+    v[4] = xor(v[4], v[8]);
+    v[5] = xor(v[5], v[9]);
+    v[6] = xor(v[6], v[10]);
+    v[7] = xor(v[7], v[11]);
+    v[4] = rot12(v[4]);
+    v[5] = rot12(v[5]);
+    v[6] = rot12(v[6]);
+    v[7] = rot12(v[7]);
+    v[0] = add(v[0], m[SIGMA[r][1] as usize]);
+    v[1] = add(v[1], m[SIGMA[r][3] as usize]);
+    v[2] = add(v[2], m[SIGMA[r][5] as usize]);
+    v[3] = add(v[3], m[SIGMA[r][7] as usize]);
+    v[0] = add(v[0], v[4]);
+    v[1] = add(v[1], v[5]);
+    v[2] = add(v[2], v[6]);
+    v[3] = add(v[3], v[7]);
+    v[12] = xor(v[12], v[0]);
+    v[13] = xor(v[13], v[1]);
+    v[14] = xor(v[14], v[2]);
+    v[15] = xor(v[15], v[3]);
+    v[12] = rot8(v[12]);
+    v[13] = rot8(v[13]);
+    v[14] = rot8(v[14]);
+    v[15] = rot8(v[15]);
+    v[8] = add(v[8], v[12]);
+    v[9] = add(v[9], v[13]);
+    v[10] = add(v[10], v[14]);
+    v[11] = add(v[11], v[15]);
+    v[4] = xor(v[4], v[8]);
+    v[5] = xor(v[5], v[9]);
+    v[6] = xor(v[6], v[10]);
+    v[7] = xor(v[7], v[11]);
+    v[4] = rot7(v[4]);
+    v[5] = rot7(v[5]);
+    v[6] = rot7(v[6]);
+    v[7] = rot7(v[7]);

+    v[0] = add(v[0], m[SIGMA[r][8] as usize]);
+    v[1] = add(v[1], m[SIGMA[r][10] as usize]);
+    v[2] = add(v[2], m[SIGMA[r][12] as usize]);
+    v[3] = add(v[3], m[SIGMA[r][14] as usize]);
+    v[0] = add(v[0], v[5]);
+    v[1] = add(v[1], v[6]);
+    v[2] = add(v[2], v[7]);
+    v[3] = add(v[3], v[4]);
+    v[15] = xor(v[15], v[0]);
+    v[12] = xor(v[12], v[1]);
+    v[13] = xor(v[13], v[2]);
+    v[14] = xor(v[14], v[3]);
+    v[15] = rot16(v[15]);
+    v[12] = rot16(v[12]);
+    v[13] = rot16(v[13]);
+    v[14] = rot16(v[14]);
+    v[10] = add(v[10], v[15]);
+    v[11] = add(v[11], v[12]);
+    v[8] = add(v[8], v[13]);
+    v[9] = add(v[9], v[14]);
+    v[5] = xor(v[5], v[10]);
+    v[6] = xor(v[6], v[11]);
+    v[7] = xor(v[7], v[8]);
+    v[4] = xor(v[4], v[9]);
+    v[5] = rot12(v[5]);
+    v[6] = rot12(v[6]);
+    v[7] = rot12(v[7]);
+    v[4] = rot12(v[4]);
+    v[0] = add(v[0], m[SIGMA[r][9] as usize]);
+    v[1] = add(v[1], m[SIGMA[r][11] as usize]);
+    v[2] = add(v[2], m[SIGMA[r][13] as usize]);
+    v[3] = add(v[3], m[SIGMA[r][15] as usize]);
+    v[0] = add(v[0], v[5]);
+    v[1] = add(v[1], v[6]);
+    v[2] = add(v[2], v[7]);
+    v[3] = add(v[3], v[4]);
+    v[15] = xor(v[15], v[0]);
+    v[12] = xor(v[12], v[1]);
+    v[13] = xor(v[13], v[2]);
+    v[14] = xor(v[14], v[3]);
+    v[15] = rot8(v[15]);
+    v[12] = rot8(v[12]);
+    v[13] = rot8(v[13]);
+    v[14] = rot8(v[14]);
+    v[10] = add(v[10], v[15]);
+    v[11] = add(v[11], v[12]);
+    v[8] = add(v[8], v[13]);
+    v[9] = add(v[9], v[14]);
+    v[5] = xor(v[5], v[10]);
+    v[6] = xor(v[6], v[11]);
+    v[7] = xor(v[7], v[8]);
+    v[4] = xor(v[4], v[9]);
+    v[5] = rot7(v[5]);
+    v[6] = rot7(v[6]);
+    v[7] = rot7(v[7]);
+    v[4] = rot7(v[4]);
+}

+/// Performs a Blake2s compression.
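+///
+/// `h_vecs` is the 8-word chaining state, updated in place; `count_low`/`count_high`
+/// are the two halves of the byte counter `t`, and `lastblock`/`lastnode` are the
+/// finalization flags `f0`/`f1` (all-ones when set, per the BLAKE2 specification).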
+pub fn compress(
+    h_vecs: &mut [u32; 8],
+    msg_vecs: &[u32; 16],
+    count_low: u32,
+    count_high: u32,
+    lastblock: u32,
+    lastnode: u32,
+) {
+    let mut v = [
+        h_vecs[0],
+        h_vecs[1],
+        h_vecs[2],
+        h_vecs[3],
+        h_vecs[4],
+        h_vecs[5],
+        h_vecs[6],
+        h_vecs[7],
+        IV[0],
+        IV[1],
+        IV[2],
+        IV[3],
+        xor(IV[4], count_low),
+        xor(IV[5], count_high),
+        xor(IV[6], lastblock),
+        xor(IV[7], lastnode),
+    ];

+    round(&mut v, msg_vecs, 0);
+    round(&mut v, msg_vecs, 1);
+    round(&mut v, msg_vecs, 2);
+    round(&mut v, msg_vecs, 3);
+    round(&mut v, msg_vecs, 4);
+    round(&mut v, msg_vecs, 5);
+    round(&mut v, msg_vecs, 6);
+    round(&mut v, msg_vecs, 7);
+    round(&mut v, msg_vecs, 8);
+    round(&mut v, msg_vecs, 9);

+    h_vecs[0] = xor(xor(h_vecs[0], v[0]), v[8]);
+    h_vecs[1] = xor(xor(h_vecs[1], v[1]), v[9]);
+    h_vecs[2] = xor(xor(h_vecs[2], v[2]), v[10]);
+    h_vecs[3] = xor(xor(h_vecs[3], v[3]), v[11]);
+    h_vecs[4] = xor(xor(h_vecs[4], v[4]), v[12]);
+    h_vecs[5] = xor(xor(h_vecs[5], v[5]), v[13]);
+    h_vecs[6] = xor(xor(h_vecs[6], v[6]), v[14]);
+    h_vecs[7] = xor(xor(h_vecs[7], v[7]), v[15]);
+}
diff --git a/src/commitment_scheme/hasher.rs b/src/commitment_scheme/hasher.rs
index caf527a19..5b871017b 100644
--- a/src/commitment_scheme/hasher.rs
+++ b/src/commitment_scheme/hasher.rs
@@ -24,7 +24,7 @@ pub trait Name {
 pub trait Hasher: Sized + Default {
     type Hash: Hash;
-    type NativeType: Sized + Eq;
+    type NativeType: Sized + Eq + Clone;

     // Input size of the compression function.
     const BLOCK_SIZE: usize;
@@ -65,7 +65,19 @@
         data: &[*const Self::NativeType],
         single_input_length_bytes: usize,
         dst: &[*mut Self::NativeType],
-    );
+    ) {
+        data.iter()
+            .map(|p| std::slice::from_raw_parts(*p, single_input_length_bytes))
+            .zip(
+                dst.iter()
+                    .map(|p| std::slice::from_raw_parts_mut(*p, Self::OUTPUT_SIZE)),
+            )
+            .for_each(|(input, out)| {
+                let mut hasher = Self::new();
+                hasher.update(input);
+                out.clone_from_slice(&hasher.finalize().into());
+            })
+    }
 }

 pub trait Hash:
diff --git a/src/commitment_scheme/mod.rs b/src/commitment_scheme/mod.rs
index 68aa9653a..688f2c8e1 100644
--- a/src/commitment_scheme/mod.rs
+++ b/src/commitment_scheme/mod.rs
@@ -1,4 +1,7 @@
 pub mod blake2_hash;
+#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
+pub mod blake2s_avx;
+pub mod blake2s_ref;
 pub mod blake3_hash;
 pub mod hasher;
 pub mod merkle_decommitment;
diff --git a/src/core/proof_of_work.rs b/src/core/proof_of_work.rs
index 8484db026..14c555f5c 100644
--- a/src/core/proof_of_work.rs
+++ b/src/core/proof_of_work.rs
@@ -96,19 +96,10 @@ mod tests {
     use crate::core::channel::{Blake2sChannel, Channel};
     use crate::core::proof_of_work::{ProofOfWork, ProofOfWorkProof};

-    #[test]
-    fn test_verify_proof_of_work_success() {
-        let mut channel = Blake2sChannel::new(Blake2sHash::from(vec![0; 32]));
-        let proof_of_work_prover = ProofOfWork { n_bits: 11 };
-        let proof = ProofOfWorkProof { nonce: 133 };
-
-        proof_of_work_prover.verify(&mut channel, &proof).unwrap();
-    }
-
     #[test]
     fn test_verify_proof_of_work_fail() {
         let mut channel = Blake2sChannel::new(Blake2sHash::from(vec![0; 32]));
-        let proof_of_work_prover = ProofOfWork { n_bits: 1 };
+        let proof_of_work_prover = ProofOfWork { n_bits: 2 };
         let invalid_proof = ProofOfWorkProof { nonce: 0 };

         proof_of_work_prover