From 38e64d28600a6c62030ea1a2f8849771d6e94dbc Mon Sep 17 00:00:00 2001 From: Joshua Liebow-Feeser Date: Mon, 23 Sep 2024 13:31:23 -0700 Subject: [PATCH] Add TryFromBytes::try_read_from_{prefix,suffix} (#1738) --- src/lib.rs | 261 +++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 234 insertions(+), 27 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index f64e36b72c..0880954447 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1506,12 +1506,12 @@ pub unsafe trait TryFromBytes { /// // These are more bytes than are needed to encode a `Packet`. /// let bytes = &[0xC0, 0xC0, 240, 77, 0, 1, 2, 3, 4, 5, 6][..]; /// - /// let (packet, excess) = Packet::try_ref_from_prefix(bytes).unwrap(); + /// let (packet, suffix) = Packet::try_ref_from_prefix(bytes).unwrap(); /// /// assert_eq!(packet.mug_size, 240); /// assert_eq!(packet.temperature, 77); /// assert_eq!(packet.marshmallows, [[0, 1], [2, 3], [4, 5]]); - /// assert_eq!(excess, &[6u8][..]); + /// assert_eq!(suffix, &[6u8][..]); /// /// // These bytes are not valid instance of `Packet`. /// let bytes = &[0x10, 0xC0, 240, 77, 0, 1, 2, 3, 4, 5, 6][..]; @@ -1592,12 +1592,12 @@ pub unsafe trait TryFromBytes { /// // These are more bytes than are needed to encode a `Packet`. /// let bytes = &[0, 0xC0, 0xC0, 240, 77, 2, 3, 4, 5, 6, 7][..]; /// - /// let (excess, packet) = Packet::try_ref_from_suffix(bytes).unwrap(); + /// let (prefix, packet) = Packet::try_ref_from_suffix(bytes).unwrap(); /// /// assert_eq!(packet.mug_size, 240); /// assert_eq!(packet.temperature, 77); /// assert_eq!(packet.marshmallows, [[2, 3], [4, 5], [6, 7]]); - /// assert_eq!(excess, &[0u8][..]); + /// assert_eq!(prefix, &[0u8][..]); /// /// // These bytes are not valid instance of `Packet`. /// let bytes = &[0, 1, 2, 3, 4, 5, 6, 77, 240, 0xC0, 0x10][..]; @@ -1783,15 +1783,15 @@ pub unsafe trait TryFromBytes { /// // These are more bytes than are needed to encode a `Packet`. 
/// let bytes = &mut [0xC0, 0xC0, 240, 77, 0, 1, 2, 3, 4, 5, 6][..]; /// - /// let (packet, excess) = Packet::try_mut_from_prefix(bytes).unwrap(); + /// let (packet, suffix) = Packet::try_mut_from_prefix(bytes).unwrap(); /// /// assert_eq!(packet.mug_size, 240); /// assert_eq!(packet.temperature, 77); /// assert_eq!(packet.marshmallows, [[0, 1], [2, 3], [4, 5]]); - /// assert_eq!(excess, &[6u8][..]); + /// assert_eq!(suffix, &[6u8][..]); /// /// packet.temperature = 111; - /// excess[0] = 222; + /// suffix[0] = 222; /// /// assert_eq!(bytes, [0xC0, 0xC0, 240, 111, 0, 1, 2, 3, 4, 5, 222]); /// @@ -1878,14 +1878,14 @@ pub unsafe trait TryFromBytes { /// // These are more bytes than are needed to encode a `Packet`. /// let bytes = &mut [0, 0xC0, 0xC0, 240, 77, 2, 3, 4, 5, 6, 7][..]; /// - /// let (excess, packet) = Packet::try_mut_from_suffix(bytes).unwrap(); + /// let (prefix, packet) = Packet::try_mut_from_suffix(bytes).unwrap(); /// /// assert_eq!(packet.mug_size, 240); /// assert_eq!(packet.temperature, 77); /// assert_eq!(packet.marshmallows, [[2, 3], [4, 5], [6, 7]]); - /// assert_eq!(excess, &[0u8][..]); + /// assert_eq!(prefix, &[0u8][..]); /// - /// excess[0] = 111; + /// prefix[0] = 111; /// packet.temperature = 222; /// /// assert_eq!(bytes, [111, 0xC0, 0xC0, 240, 222, 2, 3, 4, 5, 6, 7]); @@ -1918,16 +1918,16 @@ pub unsafe trait TryFromBytes { /// # use zerocopy_derive::*; /// /// // The only valid value of this type is the byte `0xC0` - /// #[derive(TryFromBytes, KnownLayout, Immutable)] + /// #[derive(TryFromBytes)] /// #[repr(u8)] /// enum C0 { xC0 = 0xC0 } /// /// // The only valid value of this type is the bytes `0xC0C0`. 
- /// #[derive(TryFromBytes, KnownLayout, Immutable)] + /// #[derive(TryFromBytes)] /// #[repr(C)] /// struct C0C0(C0, C0); /// - /// #[derive(TryFromBytes, KnownLayout, Immutable)] + /// #[derive(TryFromBytes)] /// #[repr(C)] /// struct Packet { /// magic_number: C0C0, @@ -1984,6 +1984,173 @@ pub unsafe trait TryFromBytes { // SAFETY: We just validated that `candidate` contains a valid `Self`. Ok(unsafe { candidate.assume_init() }) } + + /// Attempts to read a `Self` from the prefix of the given `source`. + /// + /// This attempts to read a `Self` from the first `size_of::<Self>()` bytes + /// of `source`, returning that `Self` and any remaining bytes. If + /// `source.len() < size_of::<Self>()` or the bytes are not a valid instance + /// of `Self`, it returns `Err`. + /// + /// # Examples + /// + /// ``` + /// use zerocopy::TryFromBytes; + /// # use zerocopy_derive::*; + /// + /// // The only valid value of this type is the byte `0xC0` + /// #[derive(TryFromBytes)] + /// #[repr(u8)] + /// enum C0 { xC0 = 0xC0 } + /// + /// // The only valid value of this type is the bytes `0xC0C0`. + /// #[derive(TryFromBytes)] + /// #[repr(C)] + /// struct C0C0(C0, C0); + /// + /// #[derive(TryFromBytes)] + /// #[repr(C)] + /// struct Packet { + /// magic_number: C0C0, + /// mug_size: u8, + /// temperature: u8, + /// } + /// + /// // These are more bytes than are needed to encode a `Packet`. + /// let bytes = &[0xC0, 0xC0, 240, 77, 0, 1, 2, 3, 4, 5, 6][..]; + /// + /// let (packet, suffix) = Packet::try_read_from_prefix(bytes).unwrap(); + /// + /// assert_eq!(packet.mug_size, 240); + /// assert_eq!(packet.temperature, 77); + /// assert_eq!(suffix, &[0u8, 1, 2, 3, 4, 5, 6][..]); + /// + /// // These bytes are not a valid instance of `Packet`.
+ /// let bytes = &[0x10, 0xC0, 240, 77, 0, 1, 2, 3, 4, 5, 6][..]; + /// assert!(Packet::try_read_from_prefix(bytes).is_err()); + /// ``` + #[must_use = "has no side effects"] + #[inline] + fn try_read_from_prefix(source: &[u8]) -> Result<(Self, &[u8]), TryReadError<&[u8], Self>> + where + Self: Sized, + { + // Note that we have to call `is_bit_valid` on an exclusive-aliased + // pointer since we don't require `Self: Immutable`. That's why we do `let + // mut` and `Ptr::from_mut` here. See the doc comment on `is_bit_valid` + // and the implementation of `TryFromBytes` for `UnsafeCell` for more + // details. + let (mut candidate, suffix) = match MaybeUninit::<Self>::read_from_prefix(source) { + Ok(candidate) => candidate, + Err(e) => { + return Err(TryReadError::Size(e.with_dst())); + } + }; + let c_ptr = Ptr::from_mut(&mut candidate); + let c_ptr = c_ptr.transparent_wrapper_into_inner(); + // SAFETY: `c_ptr` has no uninitialized sub-ranges because it derived + // from `candidate`, which in turn derives from `source: &[u8]`. + let c_ptr = unsafe { c_ptr.assume_validity::<invariant::Initialized>() }; + + // This call may panic. If that happens, it doesn't cause any soundness + // issues, as we have not generated any invalid state which we need to + // fix before returning. + // + // Note that one panic or post-monomorphization error condition is + // calling `try_into_valid` (and thus `is_bit_valid`) with a shared + // pointer when `Self: !Immutable`. Since `c_ptr` is exclusive (it comes + // from `Ptr::from_mut`), this panic condition will not happen. + if !Self::is_bit_valid(c_ptr.forget_aligned()) { + return Err(ValidityError::new(source).into()); + } + + // SAFETY: We just validated that `candidate` contains a valid `Self`. + Ok((unsafe { candidate.assume_init() }, suffix)) + } + + /// Attempts to read a `Self` from the suffix of the given `source`. + /// + /// This attempts to read a `Self` from the last `size_of::<Self>()` bytes + /// of `source`, returning that `Self` and any preceding bytes.
If + /// `source.len() < size_of::<Self>()` or the bytes are not a valid instance + /// of `Self`, it returns `Err`. + /// + /// # Examples + /// + /// ``` + /// # #![allow(non_camel_case_types)] // For C0::xC0 + /// use zerocopy::TryFromBytes; + /// # use zerocopy_derive::*; + /// + /// // The only valid value of this type is the byte `0xC0` + /// #[derive(TryFromBytes)] + /// #[repr(u8)] + /// enum C0 { xC0 = 0xC0 } + /// + /// // The only valid value of this type is the bytes `0xC0C0`. + /// #[derive(TryFromBytes)] + /// #[repr(C)] + /// struct C0C0(C0, C0); + /// + /// #[derive(TryFromBytes)] + /// #[repr(C)] + /// struct Packet { + /// magic_number: C0C0, + /// mug_size: u8, + /// temperature: u8, + /// } + /// + /// // These are more bytes than are needed to encode a `Packet`. + /// let bytes = &[0, 1, 2, 3, 4, 5, 0xC0, 0xC0, 240, 77][..]; + /// + /// let (prefix, packet) = Packet::try_read_from_suffix(bytes).unwrap(); + /// + /// assert_eq!(packet.mug_size, 240); + /// assert_eq!(packet.temperature, 77); + /// assert_eq!(prefix, &[0u8, 1, 2, 3, 4, 5][..]); + /// + /// // These bytes are not a valid instance of `Packet`. + /// let bytes = &[0, 1, 2, 3, 4, 5, 0x10, 0xC0, 240, 77][..]; + /// assert!(Packet::try_read_from_suffix(bytes).is_err()); + /// ``` + #[must_use = "has no side effects"] + #[inline] + fn try_read_from_suffix(source: &[u8]) -> Result<(&[u8], Self), TryReadError<&[u8], Self>> + where + Self: Sized, + { + // Note that we have to call `is_bit_valid` on an exclusive-aliased + // pointer since we don't require `Self: Immutable`. That's why we do `let + // mut` and `Ptr::from_mut` here. See the doc comment on `is_bit_valid` + // and the implementation of `TryFromBytes` for `UnsafeCell` for more + // details.
+ let (prefix, mut candidate) = match MaybeUninit::<Self>::read_from_suffix(source) { + Ok(candidate) => candidate, + Err(e) => { + return Err(TryReadError::Size(e.with_dst())); + } + }; + let c_ptr = Ptr::from_mut(&mut candidate); + let c_ptr = c_ptr.transparent_wrapper_into_inner(); + // SAFETY: `c_ptr` has no uninitialized sub-ranges because it derived + // from `candidate`, which in turn derives from `source: &[u8]`. + let c_ptr = unsafe { c_ptr.assume_validity::<invariant::Initialized>() }; + + // This call may panic. If that happens, it doesn't cause any soundness + // issues, as we have not generated any invalid state which we need to + // fix before returning. + // + // Note that one panic or post-monomorphization error condition is + // calling `try_into_valid` (and thus `is_bit_valid`) with a shared + // pointer when `Self: !Immutable`. Since `c_ptr` is exclusive (it comes + // from `Ptr::from_mut`), this panic condition will not happen. + if !Self::is_bit_valid(c_ptr.forget_aligned()) { + return Err(ValidityError::new(source).into()); + } + + // SAFETY: We just validated that `candidate` contains a valid `Self`. + Ok((prefix, unsafe { candidate.assume_init() })) + } } #[inline(always)] @@ -2858,14 +3025,14 @@ pub unsafe trait FromBytes: FromZeros { /// // These are more bytes than are needed to encode a `Packet`.
/// let bytes = &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14][..]; /// - /// let (packet, excess) = Packet::ref_from_prefix(bytes).unwrap(); + /// let (packet, suffix) = Packet::ref_from_prefix(bytes).unwrap(); /// /// assert_eq!(packet.header.src_port, [0, 1]); /// assert_eq!(packet.header.dst_port, [2, 3]); /// assert_eq!(packet.header.length, [4, 5]); /// assert_eq!(packet.header.checksum, [6, 7]); /// assert_eq!(packet.body, [[8, 9], [10, 11], [12, 13]]); - /// assert_eq!(excess, &[14u8][..]); + /// assert_eq!(suffix, &[14u8][..]); /// ``` #[must_use = "has no side effects"] #[inline] @@ -3299,14 +3466,14 @@ pub unsafe trait FromBytes: FromZeros { /// // These are more bytes than are needed to encode two `Pixel`s. /// let bytes = &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9][..]; /// - /// let (pixels, rest) = <[Pixel]>::ref_from_prefix_with_elems(bytes, 2).unwrap(); + /// let (pixels, suffix) = <[Pixel]>::ref_from_prefix_with_elems(bytes, 2).unwrap(); /// /// assert_eq!(pixels, &[ /// Pixel { r: 0, g: 1, b: 2, a: 3 }, /// Pixel { r: 4, g: 5, b: 6, a: 7 }, /// ]); /// - /// assert_eq!(rest, &[8, 9]); + /// assert_eq!(suffix, &[8, 9]); /// ``` /// /// Since an explicit `count` is provided, this method supports types with @@ -3374,9 +3541,9 @@ pub unsafe trait FromBytes: FromZeros { /// // These are more bytes than are needed to encode two `Pixel`s. /// let bytes = &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9][..]; /// - /// let (rest, pixels) = <[Pixel]>::ref_from_suffix_with_elems(bytes, 2).unwrap(); + /// let (prefix, pixels) = <[Pixel]>::ref_from_suffix_with_elems(bytes, 2).unwrap(); /// - /// assert_eq!(rest, &[0, 1]); + /// assert_eq!(prefix, &[0, 1]); /// /// assert_eq!(pixels, &[ /// Pixel { r: 2, g: 3, b: 4, a: 5 }, @@ -3530,17 +3697,17 @@ pub unsafe trait FromBytes: FromZeros { /// // These are more bytes than are needed to encode two `Pixel`s. 
/// let bytes = &mut [0, 1, 2, 3, 4, 5, 6, 7, 8, 9][..]; /// - /// let (pixels, rest) = <[Pixel]>::mut_from_prefix_with_elems(bytes, 2).unwrap(); + /// let (pixels, suffix) = <[Pixel]>::mut_from_prefix_with_elems(bytes, 2).unwrap(); /// /// assert_eq!(pixels, &[ /// Pixel { r: 0, g: 1, b: 2, a: 3 }, /// Pixel { r: 4, g: 5, b: 6, a: 7 }, /// ]); /// - /// assert_eq!(rest, &[8, 9]); + /// assert_eq!(suffix, &[8, 9]); /// /// pixels[1] = Pixel { r: 0, g: 0, b: 0, a: 0 }; - /// rest.fill(1); + /// suffix.fill(1); /// /// assert_eq!(bytes, [0, 1, 2, 3, 0, 0, 0, 0, 1, 1]); /// ``` @@ -3610,16 +3777,16 @@ pub unsafe trait FromBytes: FromZeros { /// // These are more bytes than are needed to encode two `Pixel`s. /// let bytes = &mut [0, 1, 2, 3, 4, 5, 6, 7, 8, 9][..]; /// - /// let (rest, pixels) = <[Pixel]>::mut_from_suffix_with_elems(bytes, 2).unwrap(); + /// let (prefix, pixels) = <[Pixel]>::mut_from_suffix_with_elems(bytes, 2).unwrap(); /// - /// assert_eq!(rest, &[0, 1]); + /// assert_eq!(prefix, &[0, 1]); /// /// assert_eq!(pixels, &[ /// Pixel { r: 2, g: 3, b: 4, a: 5 }, /// Pixel { r: 6, g: 7, b: 8, a: 9 }, /// ]); /// - /// rest.fill(9); + /// prefix.fill(9); /// pixels[1] = Pixel { r: 0, g: 0, b: 0, a: 0 }; /// /// assert_eq!(bytes, [9, 9, 2, 3, 4, 5, 0, 0, 0, 0]); @@ -3730,13 +3897,13 @@ pub unsafe trait FromBytes: FromZeros { /// // These are more bytes than are needed to encode a `PacketHeader`. 
/// let bytes = &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9][..]; /// - /// let (header, suffix) = PacketHeader::read_from_prefix(bytes).unwrap(); + /// let (header, body) = PacketHeader::read_from_prefix(bytes).unwrap(); /// /// assert_eq!(header.src_port, [0, 1]); /// assert_eq!(header.dst_port, [2, 3]); /// assert_eq!(header.length, [4, 5]); /// assert_eq!(header.checksum, [6, 7]); - /// assert_eq!(suffix, [8, 9]); + /// assert_eq!(body, [8, 9]); /// ``` #[must_use = "has no side effects"] #[inline] @@ -5245,11 +5412,25 @@ mod tests { assert_eq!(<bool as TryFromBytes>::try_read_from_bytes(&[0]), Ok(false)); assert_eq!(<bool as TryFromBytes>::try_read_from_bytes(&[1]), Ok(true)); + assert_eq!(<bool as TryFromBytes>::try_read_from_prefix(&[0, 2]), Ok((false, &[2][..]))); + assert_eq!(<bool as TryFromBytes>::try_read_from_prefix(&[1, 2]), Ok((true, &[2][..]))); + + assert_eq!(<bool as TryFromBytes>::try_read_from_suffix(&[2, 0]), Ok((&[2][..], false))); + assert_eq!(<bool as TryFromBytes>::try_read_from_suffix(&[2, 1]), Ok((&[2][..], true))); + // If we don't pass enough bytes, it fails. assert!(matches!( <bool as TryFromBytes>::try_read_from_bytes(&[]), Err(TryReadError::Size(_)) )); + assert!(matches!( + <bool as TryFromBytes>::try_read_from_prefix(&[]), + Err(TryReadError::Size(_)) + )); + assert!(matches!( + <bool as TryFromBytes>::try_read_from_suffix(&[]), + Err(TryReadError::Size(_)) + )); // If we pass too many bytes, it fails. assert!(matches!( @@ -5262,6 +5443,14 @@ mod tests { <bool as TryFromBytes>::try_read_from_bytes(&[2]), Err(TryReadError::Validity(_)) )); + assert!(matches!( + <bool as TryFromBytes>::try_read_from_prefix(&[2, 0]), + Err(TryReadError::Validity(_)) + )); + assert!(matches!( + <bool as TryFromBytes>::try_read_from_suffix(&[0, 2]), + Err(TryReadError::Validity(_)) + )); // Reading from a misaligned buffer should still succeed.
Since `AU64`'s // alignment is 8, and since we read from two adjacent addresses one @@ -5270,6 +5459,24 @@ mod tests { let bytes: [u8; 9] = [0, 0, 0, 0, 0, 0, 0, 0, 0]; assert_eq!(<AU64 as TryFromBytes>::try_read_from_bytes(&bytes[..8]), Ok(AU64(0))); assert_eq!(<AU64 as TryFromBytes>::try_read_from_bytes(&bytes[1..9]), Ok(AU64(0))); + + assert_eq!( + <AU64 as TryFromBytes>::try_read_from_prefix(&bytes[..8]), + Ok((AU64(0), &[][..])) + ); + assert_eq!( + <AU64 as TryFromBytes>::try_read_from_prefix(&bytes[1..9]), + Ok((AU64(0), &[][..])) + ); + + assert_eq!( + <AU64 as TryFromBytes>::try_read_from_suffix(&bytes[..8]), + Ok((&[][..], AU64(0))) + ); + assert_eq!( + <AU64 as TryFromBytes>::try_read_from_suffix(&bytes[1..9]), + Ok((&[][..], AU64(0))) + ); } #[test]