From 3a23730f18a00594223a052ede5a636c1e252774 Mon Sep 17 00:00:00 2001
From: Riccardo Mazzarini <me@noib3.dev>
Date: Thu, 4 Jul 2024 21:08:34 +0800
Subject: [PATCH 1/2] encode integers using leb128

---
 Cargo.toml    |   3 +-
 src/encode.rs | 224 +++++++-------------------------------------------
 2 files changed, 32 insertions(+), 195 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 61a318f..c4305e0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -17,12 +17,13 @@ features = ["serde"]
 rustdoc-args = ["--cfg", "docsrs"]
 
 [features]
-encode = ["dep:sha2"]
+encode = ["dep:sha2", "dep:varint-simd"]
 serde = ["encode", "dep:serde"]
 
 [dependencies]
 serde = { version = "1.0", optional = true }
 sha2 = { version = "0.10", optional = true }
+varint-simd = { version = "0.4", optional = true }
 
 [dev-dependencies]
 bincode = "1.3"
diff --git a/src/encode.rs b/src/encode.rs
index ffc2dff..92eb489 100644
--- a/src/encode.rs
+++ b/src/encode.rs
@@ -79,77 +79,6 @@ impl core::fmt::Display for BoolDecodeError {
     }
 }
 
-// When encoding integers we use a variable-length encoding scheme that aims to
-// minimize the number of bytes used to encode small integers, where small is
-// approx < 2 ^ 16.
-//
-// There are 3 separate branches based on how big the integer is:
-//
-// - if the integer is between 0 and 63, we can encode it with 6 bits. In this
-//   case we use the lower 6 bits of the first byte, and set the 2 highest bits
-//   to `01`;
-//
-// - if the integer is between 64 and 2 ^ 15 - 1, we can encode it with 15
-//   bits. In this case we use the lower 7 bits of the first byte and all the
-//   bits of the second byte. The highest bit of the first byte is set to `1`;
-//
-// - if the integer is greater than 2 ^ 15 - 1, we use the first byte to encode
-//   the number of bytes used to encode the integer, and then we encode the
-//   integer itself in little endian, throwing away any trailing zeros.
-//
-// With this scheme we can encode integers up to 63 with 1 byte, integers up
-// to 2 ^ 15 - 1 with 2 bytes, and integers greater than 2 ^ 15 - 1 with 3
-// or more bytes.
-
-const ENCODE_ONE_BYTE_MASK: u8 = 0b0100_0000;
-
-const ENCODE_TWO_BYTES_MASK: u8 = 0b1000_0000;
-
-const LAST_BIT_MASK: u8 = 0b1000_0000;
-
-#[inline(always)]
-fn encode_one_byte(int: u8) -> u8 {
-    debug_assert!(int < 1 << 6);
-    int | ENCODE_ONE_BYTE_MASK
-}
-
-#[inline(always)]
-fn decode_one_byte(int: u8) -> u8 {
-    int & !ENCODE_ONE_BYTE_MASK
-}
-
-#[inline(always)]
-fn encode_two_bytes(int: u16) -> (u8, u8) {
-    debug_assert!((1 << 6..1 << 15).contains(&int));
-
-    let [mut lo, mut hi] = int.to_le_bytes();
-
-    // Move the last bit of the low byte to the last bit of the high byte.
-    //
-    // We know this doesn't lose any information because the int is less than
-    // 2 ^ 15, so the last bit of the high byte is 0.
-    hi |= lo & LAST_BIT_MASK;
-
-    // Set the last bit of the low byte to 1 to indicate that this number is
-    // encoded with 2 bytes.
-    lo |= ENCODE_TWO_BYTES_MASK;
-
-    (lo, hi)
-}
-
-#[inline(always)]
-fn decode_two_bytes(mut lo: u8, mut hi: u8) -> u16 {
-    lo &= !ENCODE_TWO_BYTES_MASK;
-
-    // Move the last bit of the high byte to the last bit of the low byte.
-    lo |= hi & LAST_BIT_MASK;
-
-    // Reset the last bit of the high byte to 0.
-    hi &= !LAST_BIT_MASK;
-
-    u16::from_le_bytes([lo, hi])
-}
-
 impl_int_encode!(u16);
 impl_int_encode!(u32);
 impl_int_encode!(u64);
@@ -181,35 +110,8 @@ macro_rules! impl_int_encode {
         impl Encode for $ty {
             #[inline]
             fn encode(&self, buf: &mut Vec<u8>) {
-                let int = *self;
-
-                if int < 1 << 6 {
-                    buf.push(encode_one_byte(int as u8));
-                    return;
-                } else if int < 1 << 15 {
-                    let (first, second) = encode_two_bytes(int as u16);
-                    buf.push(first);
-                    buf.push(second);
-                    return;
-                }
-
-                let array = int.to_le_bytes();
-
-                let num_trailing_zeros = array
-                    .iter()
-                    .rev()
-                    .copied()
-                    .take_while(|&byte| byte == 0)
-                    .count();
-
-                let len = array.len() - num_trailing_zeros;
-
-                // Make sure that the first 2 bits are 0.
-                debug_assert_eq!(len & 0b1100_0000, 0);
-
-                buf.push(len as u8);
-
-                buf.extend_from_slice(&array[..len]);
+                let (array, len) = varint_simd::encode(*self);
+                buf.extend_from_slice(&array[..len as usize]);
             }
         }
     };
@@ -226,45 +128,18 @@ macro_rules! impl_int_decode {
 
             #[inline]
             fn decode(buf: &[u8]) -> Result<($ty, &[u8]), Self::Error> {
-                let (&first, buf) =
-                    buf.split_first().ok_or(IntDecodeError::EmptyBuffer)?;
-
-                if first & ENCODE_TWO_BYTES_MASK != 0 {
-                    let lo = first;
-
-                    let (&hi, buf) = buf.split_first().ok_or(
-                        IntDecodeError::LengthLessThanPrefix {
-                            prefix: 2,
-                            actual: 1,
-                        },
-                    )?;
-
-                    let int = decode_two_bytes(lo, hi) as $ty;
-
-                    return Ok((int, buf));
-                } else if first & ENCODE_ONE_BYTE_MASK != 0 {
-                    let int = decode_one_byte(first) as $ty;
-                    return Ok((int, buf));
-                }
-
-                let len = first;
-
-                if len as usize > buf.len() {
-                    return Err(IntDecodeError::LengthLessThanPrefix {
-                        prefix: len,
-                        actual: buf.len() as u8,
-                    });
-                }
-
-                let mut array = [0u8; ::core::mem::size_of::<$ty>()];
-
-                let (bytes, buf) = buf.split_at(len as usize);
-
-                array[..bytes.len()].copy_from_slice(bytes);
-
-                let int = <$ty>::from_le_bytes(array);
-
-                Ok((int, buf))
+                let (decoded, len) = varint_simd::decode::<Self>(buf)
+                    .map_err(IntDecodeError)?;
+
+                // TODO: this check shouldn't be necessary, `decode` should
+                // fail. Open an issue.
+                let Some(rest) = buf.get(len as usize..) else {
+                    return Err(IntDecodeError(
+                        varint_simd::VarIntDecodeError::NotEnoughBytes,
+                    ));
+                };
+
+                Ok((decoded, rest))
             }
         }
     };
@@ -273,32 +148,12 @@ macro_rules! impl_int_decode {
 use impl_int_decode;
 
 /// An error that can occur when decoding an [`Int`].
-#[cfg_attr(test, derive(PartialEq, Eq))]
-pub(crate) enum IntDecodeError {
-    /// The buffer passed to `Int::decode` is empty. This is always an error,
-    /// even if the integer being decoded is zero.
-    EmptyBuffer,
-
-    /// The actual byte length of the buffer is less than what was specified
-    /// in the prefix.
-    LengthLessThanPrefix { prefix: u8, actual: u8 },
-}
+pub(crate) struct IntDecodeError(varint_simd::VarIntDecodeError);
 
 impl Display for IntDecodeError {
     #[inline]
     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
-        match self {
-            Self::EmptyBuffer => f.write_str(
-                "Int couldn't be decoded because the buffer is empty",
-            ),
-            Self::LengthLessThanPrefix { prefix, actual } => {
-                write!(
-                    f,
-                    "Int couldn't be decoded because the buffer's length is \
-                     {actual}, but the prefix specified a length of {prefix}",
-                )
-            },
-        }
+        Display::fmt(&self.0, f)
     }
 }
 
@@ -396,8 +251,17 @@ mod serde {
 mod tests {
     use super::*;
 
+    impl PartialEq for IntDecodeError {
+        fn eq(&self, other: &Self) -> bool {
+            use varint_simd::VarIntDecodeError::*;
+            matches!(
+                (&self.0, &other.0),
+                (Overflow, Overflow) | (NotEnoughBytes, NotEnoughBytes)
+            )
+        }
+    }
+
     impl core::fmt::Debug for IntDecodeError {
-        #[inline]
         fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
             core::fmt::Display::fmt(self, f)
         }
@@ -418,31 +282,9 @@ mod tests {
         }
     }
 
-    /// Tests that integers are encoded using the correct number of bytes.
+    /// Tests the encoding-decoding roundtrip on a number of inputs.
     #[test]
-    fn encode_int_num_bytes() {
-        fn expected_len(int: u64) -> u8 {
-            if int < 1 << 6 {
-                1
-            } else if int < 1 << 15 {
-                2
-            } else if int < 1 << 16 {
-                3
-            } else if int < 1 << 24 {
-                4
-            } else if int < 1 << 32 {
-                5
-            } else if int < 1 << 40 {
-                6
-            } else if int < 1 << 48 {
-                7
-            } else if int < 1 << 56 {
-                8
-            } else {
-                9
-            }
-        }
-
+    fn encode_int_roundtrip() {
         let ints = (1..=8).chain([
             0,
             (1 << 6) - 1,
@@ -463,15 +305,9 @@ mod tests {
 
         for int in ints {
             int.encode(&mut buf);
-
-            assert_eq!(buf.len(), expected_len(int) as usize);
-
             let (decoded, rest) = u64::decode(&buf).unwrap();
-
             assert_eq!(int, decoded);
-
             assert!(rest.is_empty());
-
             buf.clear();
         }
     }
@@ -487,7 +323,7 @@ mod tests {
 
         assert_eq!(
             u32::decode(&buf).unwrap_err(),
-            IntDecodeError::EmptyBuffer
+            IntDecodeError(varint_simd::VarIntDecodeError::NotEnoughBytes),
         );
     }
 
@@ -503,7 +339,7 @@ mod tests {
 
         assert_eq!(
             u32::decode(&buf).unwrap_err(),
-            IntDecodeError::LengthLessThanPrefix { prefix: 2, actual: 1 }
+            IntDecodeError(varint_simd::VarIntDecodeError::NotEnoughBytes),
         );
     }
 }

From 150358f1d5172f06c41731b65a15ada68a342a21 Mon Sep 17 00:00:00 2001
From: Riccardo Mazzarini <me@noib3.dev>
Date: Fri, 5 Jul 2024 20:49:31 +0800
Subject: [PATCH 2/2] write serde sizes to stdout in tests

---
 tests/serde.rs | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/tests/serde.rs b/tests/serde.rs
index 631488b..3102c0d 100644
--- a/tests/serde.rs
+++ b/tests/serde.rs
@@ -2,6 +2,8 @@ mod common;
 
 #[cfg(feature = "serde")]
 mod serde {
+    use std::io::{self, Write};
+
     use serde::de::DeserializeOwned;
     use serde::ser::Serialize;
     use traces::{ConcurrentTraceInfos, Crdt, Edit, SequentialTrace};
@@ -141,14 +143,22 @@ mod serde {
             }
         };
 
+        let mut stdout = io::stdout();
+
         let replica_size = E::encode(&replica.encode()).len();
 
-        println!("{} | Replica: {}", E::name(), printed_size(replica_size));
+        let _ = writeln!(
+            &mut stdout,
+            "{} | Replica: {}",
+            E::name(),
+            printed_size(replica_size)
+        );
 
         let total_insertions_size =
             insertions.iter().map(Vec::len).sum::<usize>();
 
-        println!(
+        let _ = writeln!(
+            &mut stdout,
             "{} | Total insertions: {}",
             E::name(),
             printed_size(total_insertions_size)
@@ -157,7 +167,8 @@ mod serde {
         let total_deletions_size =
             deletions.iter().map(Vec::len).sum::<usize>();
 
-        println!(
+        let _ = writeln!(
+            &mut stdout,
             "{} | Total deletions: {}",
             E::name(),
             printed_size(total_deletions_size)