Skip to content

Commit

Permalink
encode integers using leb128
Browse files Browse the repository at this point in the history
  • Loading branch information
noib3 committed Jul 4, 2024
1 parent 35ca41b commit 80136e1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 195 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ features = ["serde"]
rustdoc-args = ["--cfg", "docsrs"]

[features]
encode = ["dep:sha2"]
encode = ["dep:sha2", "dep:varint-simd"]
serde = ["encode", "dep:serde"]

[dependencies]
serde = { version = "1.0", optional = true }
sha2 = { version = "0.10", optional = true }
varint-simd = { version = "0.4", optional = true }

[dev-dependencies]
bincode = "1.3"
Expand Down
224 changes: 30 additions & 194 deletions src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,77 +79,6 @@ impl core::fmt::Display for BoolDecodeError {
}
}

// When encoding integers we use a variable-length encoding scheme that aims to
// minimize the number of bytes used to encode small integers, where small is
// approx < 2 ^ 16.
//
// There are 3 separate branches based on how big the integer is:
//
// - if the integer is between 0 and 63, we can encode it with 6 bits. In this
// case we use the lower 6 bits of the first byte, and set the 2 highest bits
// to `01`;
//
// - if the integer is between 64 and 2 ^ 15 - 1, we can encode it with 15
// bits. In this case we use the lower 7 bits of the first byte and all the
// bits of the second byte. The highest bit of the first byte is set to `1`;
//
// - if the integer is greater than 2 ^ 15 - 1, we use the first byte to encode
// the number of bytes used to encode the integer, and then we encode the
// integer itself in little endian, throwing away any trailing zeros.
//
// With this scheme we can encode integers up to 63 with 1 byte, integers up
// to 2 ^ 15 - 1 with 2 bytes, and integers greater than 2 ^ 15 - 1 with 3
// or more bytes.

/// Marker OR-ed into the single byte of a one-byte integer (a value below
/// 64), so that the byte's two highest bits read `01`.
const ENCODE_ONE_BYTE_MASK: u8 = 0b0100_0000;

/// Marker set on the first byte of a two-byte integer: its highest bit is `1`.
const ENCODE_TWO_BYTES_MASK: u8 = 0b1000_0000;

/// Selects the highest bit of a byte. The two-byte helpers use it to move the
/// low byte's top bit into (and back out of) the high byte's top bit.
const LAST_BIT_MASK: u8 = 0b1000_0000;

/// Packs an integer below 64 into a single tagged byte.
///
/// The six low bits carry the value and `ENCODE_ONE_BYTE_MASK` is OR-ed on top
/// as the one-byte marker. Debug builds assert that `int` fits in six bits.
#[inline(always)]
fn encode_one_byte(int: u8) -> u8 {
    debug_assert!(int < 1 << 6);
    ENCODE_ONE_BYTE_MASK | int
}

/// Recovers the integer stored in a one-byte encoding by clearing the
/// `ENCODE_ONE_BYTE_MASK` marker bit and keeping every other bit as is.
#[inline(always)]
fn decode_one_byte(int: u8) -> u8 {
    let payload_bits = !ENCODE_ONE_BYTE_MASK;
    payload_bits & int
}

/// Encodes an integer in `64..2^15` as a `(first, second)` pair of bytes.
///
/// The integer is split into its little-endian bytes. The top bit of the low
/// byte is relocated into the top bit of the high byte, which frees the low
/// byte's top bit to carry the two-byte marker. Debug builds assert that `int`
/// is inside the supported range.
#[inline(always)]
fn encode_two_bytes(int: u16) -> (u8, u8) {
    debug_assert!((1 << 6..1 << 15).contains(&int));

    let [low, high] = int.to_le_bytes();

    // The integer is below 2 ^ 15, so the high byte's top bit is free and can
    // hold the low byte's top bit without losing information.
    let second = high | (low & LAST_BIT_MASK);

    // The low byte's top bit now flags the value as a two-byte encoding.
    let first = low | ENCODE_TWO_BYTES_MASK;

    (first, second)
}

/// Rebuilds the integer from the two bytes of a two-byte encoding.
///
/// `lo` is the first encoded byte (its top bit is the two-byte marker) and
/// `hi` is the second one (its top bit holds the low byte's original top bit).
#[inline(always)]
fn decode_two_bytes(mut lo: u8, mut hi: u8) -> u16 {
    // The bit that was parked in the high byte's top position.
    let carried = hi & LAST_BIT_MASK;

    // Clear it from the high byte...
    hi ^= carried;

    // ...and put it back in the low byte, in place of the two-byte marker.
    lo = (lo & !ENCODE_TWO_BYTES_MASK) | carried;

    u16::from_le_bytes([lo, hi])
}

impl_int_encode!(u16);
impl_int_encode!(u32);
impl_int_encode!(u64);
Expand Down Expand Up @@ -181,35 +110,8 @@ macro_rules! impl_int_encode {
impl Encode for $ty {
#[inline]
fn encode(&self, buf: &mut Vec<u8>) {
let int = *self;

if int < 1 << 6 {
buf.push(encode_one_byte(int as u8));
return;
} else if int < 1 << 15 {
let (first, second) = encode_two_bytes(int as u16);
buf.push(first);
buf.push(second);
return;
}

let array = int.to_le_bytes();

let num_trailing_zeros = array
.iter()
.rev()
.copied()
.take_while(|&byte| byte == 0)
.count();

let len = array.len() - num_trailing_zeros;

// Make sure that the first 2 bits are 0.
debug_assert_eq!(len & 0b1100_0000, 0);

buf.push(len as u8);

buf.extend_from_slice(&array[..len]);
let (array, len) = varint_simd::encode(*self);
buf.extend_from_slice(&array[..len as usize]);
}
}
};
Expand All @@ -226,45 +128,18 @@ macro_rules! impl_int_decode {

#[inline]
fn decode(buf: &[u8]) -> Result<($ty, &[u8]), Self::Error> {
let (&first, buf) =
buf.split_first().ok_or(IntDecodeError::EmptyBuffer)?;

if first & ENCODE_TWO_BYTES_MASK != 0 {
let lo = first;

let (&hi, buf) = buf.split_first().ok_or(
IntDecodeError::LengthLessThanPrefix {
prefix: 2,
actual: 1,
},
)?;

let int = decode_two_bytes(lo, hi) as $ty;

return Ok((int, buf));
} else if first & ENCODE_ONE_BYTE_MASK != 0 {
let int = decode_one_byte(first) as $ty;
return Ok((int, buf));
}

let len = first;

if len as usize > buf.len() {
return Err(IntDecodeError::LengthLessThanPrefix {
prefix: len,
actual: buf.len() as u8,
});
}

let mut array = [0u8; ::core::mem::size_of::<$ty>()];

let (bytes, buf) = buf.split_at(len as usize);

array[..bytes.len()].copy_from_slice(bytes);

let int = <$ty>::from_le_bytes(array);

Ok((int, buf))
let (decoded, len) = varint_simd::decode::<Self>(buf)
.map_err(IntDecodeError)?;

// TODO: this check shouldn't be necessary, `decode` should
// fail. Open an issue.
let Some(rest) = buf.get(len as usize..) else {
return Err(IntDecodeError(
varint_simd::VarIntDecodeError::NotEnoughBytes,
));
};

Ok((decoded, rest))
}
}
};
Expand All @@ -273,32 +148,12 @@ macro_rules! impl_int_decode {
use impl_int_decode;

/// An error that can occur when decoding an [`Int`].
#[cfg_attr(test, derive(PartialEq, Eq))]
pub(crate) enum IntDecodeError {
/// The buffer passed to `Int::decode` is empty. This is always an error,
/// even if the integer being decoded is zero.
EmptyBuffer,

/// The actual byte length of the buffer is less than what was specified
/// in the prefix.
LengthLessThanPrefix { prefix: u8, actual: u8 },
}
pub(crate) struct IntDecodeError(varint_simd::VarIntDecodeError);

impl Display for IntDecodeError {
#[inline]
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Self::EmptyBuffer => f.write_str(
"Int couldn't be decoded because the buffer is empty",
),
Self::LengthLessThanPrefix { prefix, actual } => {
write!(
f,
"Int couldn't be decoded because the buffer's length is \
{actual}, but the prefix specified a length of {prefix}",
)
},
}
Display::fmt(&self.0, f)
}
}

Expand Down Expand Up @@ -396,8 +251,17 @@ mod serde {
mod tests {
use super::*;

impl PartialEq for IntDecodeError {
fn eq(&self, other: &Self) -> bool {
use varint_simd::VarIntDecodeError::*;
matches!(
(&self.0, &other.0),
(Overflow, Overflow) | (NotEnoughBytes, NotEnoughBytes)
);
}
}

impl core::fmt::Debug for IntDecodeError {
#[inline]
// Debug output delegates to the `Display` impl, so both render the same text
// with the same formatter `f`.
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::Display::fmt(self, f)
}
Expand All @@ -418,31 +282,9 @@ mod tests {
}
}

/// Tests that integers are encoded using the correct number of bytes.
/// Tests the encoding-decoding roundtrip on a number of inputs.
#[test]
fn encode_int_num_bytes() {
fn expected_len(int: u64) -> u8 {
if int < 1 << 6 {
1
} else if int < 1 << 15 {
2
} else if int < 1 << 16 {
3
} else if int < 1 << 24 {
4
} else if int < 1 << 32 {
5
} else if int < 1 << 40 {
6
} else if int < 1 << 48 {
7
} else if int < 1 << 56 {
8
} else {
9
}
}

fn encode_int_roundtrip() {
let ints = (1..=8).chain([
0,
(1 << 6) - 1,
Expand All @@ -463,15 +305,9 @@ mod tests {

for int in ints {
int.encode(&mut buf);

assert_eq!(buf.len(), expected_len(int) as usize);

let (decoded, rest) = u64::decode(&buf).unwrap();

assert_eq!(int, decoded);

assert!(rest.is_empty());

buf.clear();
}
}
Expand All @@ -487,7 +323,7 @@ mod tests {

assert_eq!(
u32::decode(&buf).unwrap_err(),
IntDecodeError::EmptyBuffer
IntDecodeError(varint_simd::VarIntDecodeError::NotEnoughBytes),
);
}

Expand All @@ -503,7 +339,7 @@ mod tests {

assert_eq!(
u32::decode(&buf).unwrap_err(),
IntDecodeError::LengthLessThanPrefix { prefix: 2, actual: 1 }
IntDecodeError(varint_simd::VarIntDecodeError::NotEnoughBytes),
);
}
}

0 comments on commit 80136e1

Please sign in to comment.