Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add methods on EncodedReplica to convert it to and from a byte slice #13

Merged
merged 4 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@

## [Unreleased]

### Added

- an `as_bytes()` method on `EncodedReplica` to get its underlying bytes;
- a `from_bytes()` method on `EncodedReplica` to create one from a byte slice;

### Changed

- the `EncodedReplica` struct now has a lifetime parameter tied to the
underlying buffer;

## [0.4.6] - Oct 31 2024

### Added
Expand Down
249 changes: 126 additions & 123 deletions src/encoded_replica.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use core::fmt;
use core::ops::Deref;

use sha2::{Digest, Sha256};

use crate::encode::{Decode, Encode, IntDecodeError};
use crate::encode::{Decode, Encode};
use crate::*;

/// We use this instead of a `Vec<u8>` because it's 1/3 the size on the stack.
pub(crate) type Checksum = Box<ChecksumArray>;

pub(crate) type ChecksumArray = [u8; 32];
type Checksum = [u8; 32];

const CHECKSUM_LEN: usize = core::mem::size_of::<ChecksumArray>();
const CHECKSUM_LEN: usize = core::mem::size_of::<Checksum>();

/// A [`Replica`] encoded into a compact binary format suitable for
/// transmission over the network.
Expand All @@ -18,153 +18,163 @@ const CHECKSUM_LEN: usize = core::mem::size_of::<ChecksumArray>();
/// [`decode`](Replica::decode). See the documentation of those methods for
/// more information.
#[cfg_attr(docsrs, doc(cfg(feature = "encode")))]
#[derive(Clone, PartialEq, Eq)]
pub struct EncodedReplica {
protocol_version: ProtocolVersion,
checksum: Checksum,
bytes: Box<[u8]>,
}

impl core::fmt::Debug for EncodedReplica {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
struct HexSlice<'a>(&'a [u8]);

impl core::fmt::Debug for HexSlice<'_> {
fn fmt(
&self,
f: &mut core::fmt::Formatter<'_>,
) -> core::fmt::Result {
for byte in self.0 {
write!(f, "{:02x}", byte)?;
}
Ok(())
}
}
#[derive(Clone)]
pub struct EncodedReplica<'buf> {
// The encoded bytes: either an owned buffer or a slice borrowed for
// `'buf` (see the `Bytes` enum).
bytes: Bytes<'buf>,
}

f.debug_struct("EncodedReplica")
.field("protocol_version", &self.protocol_version)
.field("checksum", &HexSlice(self.checksum()))
.finish_non_exhaustive()
}
/// The storage behind an `EncodedReplica`: either an owned boxed slice or a
/// slice borrowed for the lifetime `'a`.
#[derive(Clone)]
enum Bytes<'a> {
/// Bytes held in a boxed slice owned by this value.
Owned(Box<[u8]>),
/// Bytes borrowed from an existing slice.
Borrowed(&'a [u8]),
}

impl EncodedReplica {
impl<'buf> EncodedReplica<'buf> {
/// Returns the raw bytes of the encoded replica.
#[inline]
pub(crate) fn bytes(&self) -> &[u8] {
&*self.bytes
pub fn as_bytes(&self) -> &[u8] {
    // Explicitly go through `Bytes`' `Deref` impl to reach the slice.
    &*self.bytes
}

/// Creates an `EncodedReplica` from the given bytes.
#[inline]
pub(crate) fn checksum(&self) -> &[u8] {
&*self.checksum
pub fn from_bytes(bytes: &'buf [u8]) -> Self {
bytes.into()
}

/// Copies the underlying bytes into a new `EncodedReplica` with a static
/// lifetime.
#[inline]
pub(crate) fn new(
protocol_version: ProtocolVersion,
checksum: Checksum,
bytes: Box<[u8]>,
) -> Self {
Self { protocol_version, checksum, bytes }
pub fn to_static(&self) -> EncodedReplica<'static> {
    // Both variants end up as a freshly allocated boxed slice, so the
    // returned value no longer borrows from `self`.
    let owned: Box<[u8]> = match &self.bytes {
        Bytes::Owned(boxed) => boxed.clone(),
        Bytes::Borrowed(slice) => Box::from(*slice),
    };

    EncodedReplica { bytes: Bytes::Owned(owned) }
}

#[inline]
pub(crate) fn protocol_version(&self) -> ProtocolVersion {
self.protocol_version
/// Validates the encoded bytes and decodes the replica they contain.
///
/// The buffer is read as a protocol version, followed by `CHECKSUM_LEN`
/// bytes of checksum, followed by the encoded replica itself.
///
/// # Errors
///
/// - `DecodeError::InvalidData` if the protocol version can't be decoded,
///   if fewer than `CHECKSUM_LEN` bytes follow it, or if the replica can't
///   be decoded;
/// - `DecodeError::DifferentProtocol` if the decoded version differs from
///   `crate::PROTOCOL_VERSION`;
/// - `DecodeError::ChecksumFailed` if the stored checksum doesn't match the
///   one computed over the bytes that follow it.
pub(crate) fn to_replica(
    &self,
) -> Result<<Replica as Decode>::Value, DecodeError> {
    let (encoded_on, after_version) =
        match ProtocolVersion::decode(&*self.bytes) {
            Ok(decoded) => decoded,
            Err(_) => return Err(DecodeError::InvalidData),
        };

    if encoded_on != crate::PROTOCOL_VERSION {
        return Err(DecodeError::DifferentProtocol {
            encoded_on,
            decoding_on: crate::PROTOCOL_VERSION,
        });
    }

    // Checked up front because `split_at` panics on a short buffer.
    if after_version.len() < CHECKSUM_LEN {
        return Err(DecodeError::InvalidData);
    }

    let (stored_checksum, payload) = after_version.split_at(CHECKSUM_LEN);

    if stored_checksum != checksum(payload) {
        return Err(DecodeError::ChecksumFailed);
    }

    match <Replica as Decode>::decode(payload) {
        Ok((replica, _rest)) => Ok(replica),
        Err(_) => Err(DecodeError::InvalidData),
    }
}
}

impl Encode for EncodedReplica {
impl EncodedReplica<'static> {
#[inline]
fn encode(&self, buf: &mut Vec<u8>) {
self.protocol_version.encode(buf);
buf.extend_from_slice(&*self.checksum);
(self.bytes.len() as u64).encode(buf);
buf.extend_from_slice(&*self.bytes);
/// Encodes `replica` into a new owned buffer laid out as: protocol version,
/// then `CHECKSUM_LEN` bytes of checksum, then the encoded replica.
pub(crate) fn from_replica(replica: &Replica) -> Self {
    let mut buf = Vec::new();

    crate::PROTOCOL_VERSION.encode(&mut buf);

    let checksum_start = buf.len();
    let payload_start = checksum_start + CHECKSUM_LEN;

    // Reserve room for the checksum with a placeholder; the real value can
    // only be computed once the replica has been encoded after it.
    buf.extend_from_slice(&Checksum::default());

    Encode::encode(replica, &mut buf);

    let digest = checksum(&buf[payload_start..]);
    buf[checksum_start..payload_start].copy_from_slice(&digest);

    Self { bytes: Bytes::Owned(buf.into_boxed_slice()) }
}
}

pub(crate) enum EncodedReplicaDecodeError {
Int(IntDecodeError),
Checksum { actual: usize },
Bytes { actual: usize, advertised: u64 },
impl fmt::Debug for EncodedReplica<'_> {
    /// Formats the value as an opaque struct: only the type name is
    /// written, no field is listed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("EncodedReplica");
        repr.finish_non_exhaustive()
    }
}

impl From<IntDecodeError> for EncodedReplicaDecodeError {
impl<'buf> From<&'buf [u8]> for EncodedReplica<'buf> {
#[inline]
fn from(err: IntDecodeError) -> Self {
Self::Int(err)
fn from(bytes: &'buf [u8]) -> Self {
Self { bytes: Bytes::Borrowed(bytes) }
}
}

impl core::fmt::Display for EncodedReplicaDecodeError {
impl From<Box<[u8]>> for EncodedReplica<'static> {
#[inline]
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let prefix = "Couldn't decode EncodedReplica";

match self {
EncodedReplicaDecodeError::Int(err) => {
write!(f, "{prefix}: {err}")
},

EncodedReplicaDecodeError::Checksum { actual } => {
write!(
f,
"{prefix}: need {CHECKSUM_LEN} bytes to decode checksum, \
but there are only {actual}",
)
},
fn from(bytes: Box<[u8]>) -> Self {
Self { bytes: Bytes::Owned(bytes) }
}
}

EncodedReplicaDecodeError::Bytes { actual, advertised } => {
write!(
f,
"{prefix}: {advertised} bytes were encoded, but there \
are only {actual}",
)
},
}
impl AsRef<[u8]> for EncodedReplica<'_> {
    /// Borrows the encoded bytes as a plain slice.
    #[inline]
    fn as_ref(&self) -> &[u8] {
        // `&Bytes` deref-coerces to `&[u8]`.
        &self.bytes
    }
}

impl Decode for EncodedReplica {
type Value = Self;
impl Deref for EncodedReplica<'_> {
type Target = [u8];

type Error = EncodedReplicaDecodeError;
#[inline]
fn deref(&self) -> &Self::Target {
&*self.bytes
}
}

impl PartialEq<Self> for EncodedReplica<'_> {
#[inline]
fn decode(buf: &[u8]) -> Result<(Self, &[u8]), Self::Error> {
let (protocol_version, buf) = ProtocolVersion::decode(buf)?;
fn eq(&self, other: &Self) -> bool {
*self.bytes == *other.bytes
}
}

if buf.len() < CHECKSUM_LEN {
return Err(EncodedReplicaDecodeError::Checksum {
actual: buf.len(),
});
}
// Marker impl: equality is the byte-wise comparison in the `PartialEq` impl.
impl Eq for EncodedReplica<'_> {}

let (checksum_slice, buf) = buf.split_at(CHECKSUM_LEN);
impl Encode for EncodedReplica<'_> {
    /// Appends the already-encoded bytes to `buf` verbatim.
    ///
    /// In debug builds this asserts that `buf` is empty on entry.
    #[inline]
    fn encode(&self, buf: &mut Vec<u8>) {
        debug_assert!(buf.is_empty());
        let encoded: &[u8] = &self.bytes;
        buf.extend_from_slice(encoded);
    }
}

let mut checksum = [0; CHECKSUM_LEN];
impl Decode for EncodedReplica<'static> {
type Value = Self;
type Error = core::convert::Infallible;

checksum.copy_from_slice(checksum_slice);
#[inline]
fn decode(buf: &[u8]) -> Result<(Self, &[u8]), Self::Error> {
Ok((EncodedReplica::from_bytes(buf).to_static(), &[]))
}
}

let (num_bytes, buf) = u64::decode(buf)?;
impl Deref for Bytes<'_> {
type Target = [u8];

if (buf.len() as u64) < num_bytes {
return Err(EncodedReplicaDecodeError::Bytes {
actual: buf.len(),
advertised: num_bytes,
});
#[inline]
fn deref(&self) -> &Self::Target {
match self {
Bytes::Owned(bytes) => bytes,
Bytes::Borrowed(bytes) => bytes,
}

let (bytes, buf) = buf.split_at(num_bytes as usize);

let this = Self {
protocol_version,
checksum: Box::new(checksum),
bytes: bytes.into(),
};

Ok((this, buf))
}
}

Expand Down Expand Up @@ -208,19 +218,17 @@ pub enum DecodeError {
InvalidData,
}

impl core::fmt::Display for DecodeError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
impl fmt::Display for DecodeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
DecodeError::ChecksumFailed => f.write_str("checksum failed"),

DecodeError::DifferentProtocol { encoded_on, decoding_on } => {
write!(
f,
"different protocol: encoded on {:?}, decoding on {:?}",
encoded_on, decoding_on
)
},

DecodeError::InvalidData => f.write_str("invalid data"),
}
}
Expand All @@ -230,17 +238,12 @@ impl std::error::Error for DecodeError {}

#[inline(always)]
pub(crate) fn checksum(bytes: &[u8]) -> Checksum {
Box::new(checksum_array(bytes))
}

#[inline(always)]
pub(crate) fn checksum_array(bytes: &[u8]) -> ChecksumArray {
let checksum = Sha256::digest(bytes);
*checksum.as_ref()
}

#[cfg(feature = "serde")]
mod serde {
crate::encode::impl_serialize!(super::EncodedReplica);
crate::encode::impl_deserialize!(super::EncodedReplica);
crate::encode::impl_serialize!(super::EncodedReplica<'_>);
crate::encode::impl_deserialize!(super::EncodedReplica<'static>);
}
4 changes: 1 addition & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,6 @@ use backlog::Backlog;
pub use backlog::{BackloggedDeletions, BackloggedInsertions};
pub use deletion::Deletion;
#[cfg(feature = "encode")]
use encoded_replica::{checksum, checksum_array};
#[cfg(feature = "encode")]
pub use encoded_replica::{DecodeError, EncodedReplica};
use gtree::{Gtree, LeafIdx};
pub use insertion::Insertion;
Expand Down Expand Up @@ -273,4 +271,4 @@ pub type Length = usize;
///
/// See [`ProtocolVersion`] for more infos.
#[cfg(feature = "encode")]
const PROTOCOL_VERSION: ProtocolVersion = 2;
const PROTOCOL_VERSION: ProtocolVersion = 3;
Loading
Loading