
Commit 32cd659

Looks like a sweetspot
ogxd committed Nov 12, 2024
1 parent b16db6c commit 32cd659
Showing 3 changed files with 10 additions and 103 deletions.
104 changes: 4 additions & 100 deletions src/gxhash/mod.rs
@@ -73,13 +73,6 @@ use core::arch::x86_64::*;
 #[inline(always)]
 pub(crate) unsafe fn gxhash(input: &[u8], seed: State) -> State {
 
-    // gxhash v4 flow:
-    // - Start with state = seed
-    // - Read partial first, regardless of input length. This way we don't duplicate the code for partial read, and skip the 0 len check.
-    // - Incorporate the partial read into the state
-    // - Then loop over the input in blocks of 128 bits
-    // - Finalize the hash with aes instructions for better diffusion and avalanche effect
-
     let mut ptr = input.as_ptr() as *const State; // Do we need to check if valid slice?
 
     let len = input.len();
@@ -97,22 +90,13 @@ pub(crate) unsafe fn gxhash(input: &[u8], seed: State) -> State {
             break 'p0;
         } else if lzcnt >= 60 {
             break 'p1;
-        } else if lzcnt >= 56 {
+        } else if lzcnt >= 55 {
             break 'p2;
         }
 
-        // If we have less than 56 leading zeroes, it means length is at least 256 bits, or two vectors
-        let end_address = ptr.add((whole_vector_count / 2) * 2) as usize;
-        let mut lane1 = state;
-        let mut lane2 = state;
-        while (ptr as usize) < end_address {
-            crate::gxhash::load_unaligned!(ptr, v0, v1);
-            lane1 = aes_encrypt(lane1, v0);
-            lane2 = aes_encrypt(lane2, v1);
-        }
-        // Merge lanes
-        state = aes_encrypt(lane1, lane2);
-        whole_vector_count = whole_vector_count % 2;
+        state = compress_8(ptr, whole_vector_count, state, len);
+
+        whole_vector_count %= 8;
     }
 
     let end_address = ptr.add(whole_vector_count) as usize;
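The threshold tweak above is the heart of the commit: for a non-zero 64-bit length, lzcnt >= 64 - k is equivalent to len < 2^k, so moving the test from 56 to 55 raises the cutover into the 8-wide compress_8 loop from 256 bytes to 512 bytes of input, the "sweetspot" of the commit message. A standalone sketch of the arithmetic (not part of the commit; what each of the 'p0/'p1/'p2 labels handles is not shown in this hunk):

    // For a non-zero u64 length, (leading_zeros >= 64 - k) <=> (len < 1 << k).
    fn takes_compress_8_path(len: u64) -> bool {
        len.leading_zeros() < 55 // i.e. len >= 512 bytes
    }

    fn main() {
        assert!(!takes_compress_8_path(511)); // 511 has 55 leading zeros
        assert!(takes_compress_8_path(512));  // 512 has only 54
        // Under the previous `lzcnt >= 56` test the boundary sat at 256 bytes.
    }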
@@ -131,86 +115,6 @@ pub(crate) unsafe fn gxhash(input: &[u8], seed: State) -> State {
     return finalize(state);
 }
 
-#[inline(always)]
-pub(crate) unsafe fn compress_all(input: &[u8]) -> State {
-
-    let len = input.len();
-
-    if len == 0 {
-        return create_empty();
-    }
-
-    let mut ptr = input.as_ptr() as *const State;
-
-    if len <= VECTOR_SIZE {
-        // Input fits on a single SIMD vector, however we might read beyond the input message
-        // Thus we need this safe method that checks if it can safely read beyond or must copy
-        return get_partial(ptr, len);
-    }
-
-    let mut hash_vector: State;
-    let end = ptr as usize + len;
-
-    let extra_bytes_count = len % VECTOR_SIZE;
-    if extra_bytes_count == 0 {
-        load_unaligned!(ptr, v0);
-        hash_vector = v0;
-    } else {
-        // If the input length does not match the length of a whole number of SIMD vectors,
-        // it means we'll need to read a partial vector. We can start with the partial vector first,
-        // so that we can safely read beyond since we expect the following bytes to still be part of
-        // the input
-        hash_vector = get_partial_unsafe(ptr, extra_bytes_count);
-        ptr = ptr.cast::<u8>().add(extra_bytes_count).cast();
-    }
-
-    load_unaligned!(ptr, v0);
-
-    if len > VECTOR_SIZE * 2 {
-        // Fast path when input length > 32 and <= 48
-        load_unaligned!(ptr, v);
-        v0 = aes_encrypt(v0, v);
-
-        if len > VECTOR_SIZE * 3 {
-            // Fast path when input length > 48 and <= 64
-            load_unaligned!(ptr, v);
-            v0 = aes_encrypt(v0, v);
-
-            if len > VECTOR_SIZE * 4 {
-                // Input message is large and we can use the high ILP loop
-                hash_vector = compress_many(ptr, end, hash_vector, len);
-            }
-        }
-    }
-
-    return aes_encrypt_last(hash_vector,
-        aes_encrypt(aes_encrypt(v0, ld(KEYS.as_ptr())), ld(KEYS.as_ptr().offset(4))));
-}
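For reference, the removed compress_all reads in the order its comments describe: the ragged len % 16 chunk first, after which the pointer advances so that only whole 16-byte vectors remain and full unaligned loads can never run past the end of the input. A minimal standalone sketch of that read plan, assuming VECTOR_SIZE = 16 (the helper name is hypothetical):

    const VECTOR_SIZE: usize = 16;

    // Returns (bytes for the masked partial read, full vectors read after it).
    fn read_plan(len: usize) -> (usize, usize) {
        let partial_bytes = len % VECTOR_SIZE;
        let whole_vectors = (len - partial_bytes) / VECTOR_SIZE;
        (partial_bytes, whole_vectors)
    }

    fn main() {
        assert_eq!(read_plan(50), (2, 3)); // 2 ragged bytes, then 3 full vectors
        assert_eq!(read_plan(64), (0, 4)); // exact multiple: no partial read
    }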
-
-#[inline(always)]
-unsafe fn compress_many(mut ptr: *const State, end: usize, hash_vector: State, len: usize) -> State {
-
-    const UNROLL_FACTOR: usize = 8;
-
-    let remaining_bytes = end - ptr as usize;
-
-    let unrollable_blocks_count: usize = remaining_bytes / (VECTOR_SIZE * UNROLL_FACTOR) * UNROLL_FACTOR;
-
-    let remaining_bytes = remaining_bytes - unrollable_blocks_count * VECTOR_SIZE;
-    let end_address = ptr.add(remaining_bytes / VECTOR_SIZE) as usize;
-
-    // Process first individual blocks until we have a whole number of 8 blocks
-    let mut hash_vector = hash_vector;
-    while (ptr as usize) < end_address {
-        load_unaligned!(ptr, v0);
-        hash_vector = aes_encrypt(hash_vector, v0);
-    }
-
-    // Process the remaining n * 8 blocks
-    // This part may use 128-bit or 256-bit
-    compress_8(ptr, end, hash_vector, len)
-}
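The index arithmetic in the removed compress_many is easy to misread: it first sizes the region the 8-wide unrolled loop can cover, then derives how many single blocks must be consumed before that loop starts. A worked standalone example (helper name hypothetical):

    const VECTOR_SIZE: usize = 16;
    const UNROLL_FACTOR: usize = 8;

    // Returns (single blocks processed one at a time, blocks left to the 8-wide loop).
    fn split_blocks(remaining_bytes: usize) -> (usize, usize) {
        let unrollable = remaining_bytes / (VECTOR_SIZE * UNROLL_FACTOR) * UNROLL_FACTOR;
        let single = (remaining_bytes - unrollable * VECTOR_SIZE) / VECTOR_SIZE;
        (single, unrollable)
    }

    fn main() {
        // 300 bytes: 2 single 16-byte blocks, then 16 blocks in two 8-wide passes
        // (the 12 ragged bytes were already consumed by the earlier partial read).
        assert_eq!(split_blocks(300), (2, 16));
    }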
-
 #[cfg(test)]
 mod tests {
 
7 changes: 5 additions & 2 deletions src/gxhash/platform/x86.rs
@@ -69,8 +69,10 @@ pub unsafe fn ld(array: *const u32) -> State
 }
 
 #[cfg(not(feature = "hybrid"))]
-#[inline(always)]
-pub unsafe fn compress_8(mut ptr: *const State, end_address: usize, hash_vector: State, len: usize) -> State {
+#[inline(never)]
+pub unsafe fn compress_8(mut ptr: *const State, whole_vector_count: usize, hash_vector: State, len: usize) -> State {
 
+    let end_address = ptr.add((whole_vector_count / 8) * 8) as usize;
+
     // Disambiguation vectors
     let mut t1: State = create_empty();
@@ -105,6 +107,7 @@ pub unsafe fn compress_8(mut ptr: *const State, whole_vector_count: usize, hash_vector:
     let len_vec = _mm_set1_epi32(len as i32);
     lane1 = _mm_add_epi8(lane1, len_vec);
     lane2 = _mm_add_epi8(lane2, len_vec);
+
     // Merge lanes
     aes_encrypt(lane1, lane2)
 }
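The visible tail of compress_8 folds the input length into both lanes before a single AES round merges them, so that identical blocks hashed under different lengths still diverge. A self-contained sketch of just that finalization, assuming aes_encrypt wraps _mm_aesenc_si128 as it does elsewhere in the x86 backend:

    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Minimal sketch of the lane merge above; x86_64 with AES-NI assumed.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "aes")]
    unsafe fn merge_lanes(lane1: __m128i, lane2: __m128i, len: usize) -> __m128i {
        // Folding the length in makes equal blocks at different lengths differ.
        let len_vec = _mm_set1_epi32(len as i32);
        let lane1 = _mm_add_epi8(lane1, len_vec);
        let lane2 = _mm_add_epi8(lane2, len_vec);
        // One AES round mixes the two lanes into a single state.
        _mm_aesenc_si128(lane1, lane2)
    }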
2 changes: 1 addition & 1 deletion src/hasher.rs
@@ -113,7 +113,7 @@ impl Hasher for GxHasher {
     #[inline]
     fn write(&mut self, bytes: &[u8]) {
         // Improvement: only compress at this stage and finalize in finish
-        self.state = unsafe { aes_encrypt_last(compress_all(bytes), aes_encrypt(self.state, ld(KEYS.as_ptr()))) };
+        // self.state = unsafe { aes_encrypt_last(compress_all(bytes), aes_encrypt(self.state, ld(KEYS.as_ptr()))) };
     }
 
     write!(write_u8, u8, load_u8);
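With compress_all removed from mod.rs, this write body could no longer compile; commenting it out leaves Hasher::write a no-op, which reads like a temporary state while the new compress_8 path is benchmarked.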
