From 36e62a134dccc2e2fe17452a41a84d467965640a Mon Sep 17 00:00:00 2001 From: Vivek Pandya Date: Fri, 26 Jan 2024 20:40:45 +0530 Subject: [PATCH] Use usize::BITS and wrapping_shr in reverse_index_bits_in_place_small (#1478) --- util/src/lib.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/util/src/lib.rs b/util/src/lib.rs index 6c8b2ed586..613bf6ef5a 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -110,9 +110,12 @@ fn reverse_index_bits_large(arr: &[T], n_power: usize) -> Vec { unsafe fn reverse_index_bits_in_place_small(arr: &mut [T], lb_n: usize) { if lb_n <= 6 { // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them lb_n-bit reverses. - let dst_shr_amt = 6 - lb_n; + let dst_shr_amt = 6 - lb_n as u32; for src in 0..arr.len() { - let dst = (BIT_REVERSE_6BIT[src] as usize) >> dst_shr_amt; + // `wrapping_shr` handles the case when `arr.len() == 1`. In that case `src == 0`, so + // `src.reverse_bits() == 0`. `usize::wrapping_shr` by 64 is a no-op, but it gives the + // correct result. + let dst = (BIT_REVERSE_6BIT[src] as usize).wrapping_shr(dst_shr_amt); if src < dst { swap(arr.get_unchecked_mut(src), arr.get_unchecked_mut(dst)); } @@ -121,11 +124,14 @@ unsafe fn reverse_index_bits_in_place_small(arr: &mut [T], lb_n: usize) { // LLVM does not know that it does not need to reverse src at each iteration (which is // expensive on x86). We take advantage of the fact that the low bits of dst change rarely and the high // bits of dst are dependent only on the low bits of src. - let dst_lo_shr_amt = 64 - (lb_n - 6); + let dst_lo_shr_amt = usize::BITS - (lb_n - 6) as u32; let dst_hi_shl_amt = lb_n - 6; for src_chunk in 0..(arr.len() >> 6) { let src_hi = src_chunk << 6; - let dst_lo = src_chunk.reverse_bits() >> dst_lo_shr_amt; + // `wrapping_shr` handles the case when `arr.len() == 1`. In that case `src == 0`, so + // `src.reverse_bits() == 0`. `usize::wrapping_shr` by 64 is a no-op, but it gives the + // correct result. + let dst_lo = src_chunk.reverse_bits().wrapping_shr(dst_lo_shr_amt); for src_lo in 0..(1 << 6) { let dst_hi = (BIT_REVERSE_6BIT[src_lo] as usize) << dst_hi_shl_amt; let src = src_hi + src_lo;