diff --git a/ciborium/src/de/mod.rs b/ciborium/src/de/mod.rs index 53e04cb..799bf8c 100644 --- a/ciborium/src/de/mod.rs +++ b/ciborium/src/de/mod.rs @@ -333,6 +333,12 @@ where } } + // Longer strings require alloaction; delegate to `deserialize_string` + item @ Header::Text(_) => { + self.decoder.push(item); + self.deserialize_string(visitor) + } + header => Err(header.expected("str")), }; } @@ -371,6 +377,12 @@ where visitor.visit_bytes(&self.scratch[..len]) } + // Longer byte sequences require alloaction; delegate to `deserialize_byte_buf` + item @ Header::Bytes(_) => { + self.decoder.push(item); + self.deserialize_byte_buf(visitor) + } + Header::Array(len) => self.recurse(|me| { let access = Access(me, len); visitor.visit_seq(access) diff --git a/ciborium/tests/codec.rs b/ciborium/tests/codec.rs index bf8dfdf..90fe516 100644 --- a/ciborium/tests/codec.rs +++ b/ciborium/tests/codec.rs @@ -5,11 +5,13 @@ extern crate std; use std::collections::{BTreeMap, HashMap}; use std::convert::TryFrom; use std::fmt::Debug; +use std::io::Cursor; use ciborium::value::Value; use ciborium::{cbor, de::from_reader, ser::into_writer}; use rstest::rstest; +use serde::de::Visitor; use serde::{de::DeserializeOwned, Deserialize, Serialize}; macro_rules! val { @@ -416,3 +418,112 @@ fn byte_vec_serde_bytes_compatibility(input: Vec) { let bytes: Vec = from_reader(&buf[..]).unwrap(); assert_eq!(input, bytes); } + +// Regression test for #32 where strings and bytes longer than 4096 bytes previously failed to +// roundtrip if `deserialize_str` and `deserialize_bytes` (and not their owned equivalents) are used +// in the deserializers. + +#[derive(Clone, Debug, PartialEq, Eq)] +struct LongString { + s: String, +} + +impl Serialize for LongString { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(self.s.as_str()) + } +} + +impl<'de> Deserialize<'de> for LongString { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_str(LongStringVisitor) + } +} + +struct LongStringVisitor; + +impl<'de> Visitor<'de> for LongStringVisitor { + type Value = LongString; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "string") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + Ok(LongString { s: v.to_owned() }) + } +} + +#[test] +fn long_string_roundtrips() { + let s = String::from_utf8(vec![b'A'; 5000]).unwrap(); + let long_string = LongString { s }; + + let mut buf = vec![]; + into_writer(&long_string, Cursor::new(&mut buf)).unwrap(); + let long_string_de = from_reader(Cursor::new(&buf)).unwrap(); + + assert_eq!(long_string, long_string_de); +} + +#[derive(Clone, Debug, PartialEq, Eq)] +struct LongBytes { + v: Vec, +} + +impl Serialize for LongBytes { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_bytes(self.v.as_slice()) + } +} + +impl<'de> Deserialize<'de> for LongBytes { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_bytes(LongBytesVisitor) + } +} + +struct LongBytesVisitor; + +impl<'de> Visitor<'de> for LongBytesVisitor { + type Value = LongBytes; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "bytes") + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: serde::de::Error, + { + Ok(LongBytes { v: v.to_vec() }) + } +} + +#[test] +fn long_bytes_roundtrips() { + let long_bytes = LongBytes { + v: vec![b'A'; 5000], + }; + + let mut buf = vec![]; + into_writer(&long_bytes, Cursor::new(&mut buf)).unwrap(); + let long_bytes_de = from_reader(Cursor::new(&buf)).unwrap(); + + assert_eq!(long_bytes, long_bytes_de); +}