Skip to content

Commit

Permalink
Implement efficient block skipping in deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Ten0 committed Nov 13, 2024
1 parent a150f1a commit c85582c
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 31 deletions.
28 changes: 15 additions & 13 deletions serde_avro_fast/src/de/deserializer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ impl<'de, R: ReadSlice<'de>> Deserializer<'de> for DatumDeserializer<'_, '_, R>
SchemaNode::String => read_length_delimited(self.state, StringVisitor(visitor)),
SchemaNode::Array(elements_schema) => visitor.visit_seq(ArraySeqAccess {
elements_schema: elements_schema.as_ref(),
block_reader: BlockReader::new(self.state, self.allowed_depth.dec()?),
block_reader: BlockReader::new(self.state, false, self.allowed_depth.dec()?),
}),
SchemaNode::Map(elements_schema) => visitor.visit_map(MapMapAccess {
elements_schema: elements_schema.as_ref(),
block_reader: BlockReader::new(self.state, self.allowed_depth.dec()?),
block_reader: BlockReader::new(self.state, false, self.allowed_depth.dec()?),
}),
SchemaNode::Union(ref union) => Self {
schema_node: read_union_discriminant(self.state, union)?,
Expand Down Expand Up @@ -283,7 +283,7 @@ impl<'de, R: ReadSlice<'de>> Deserializer<'de> for DatumDeserializer<'_, '_, R>
match *self.schema_node {
SchemaNode::Array(elements_schema) => visitor.visit_seq(ArraySeqAccess {
elements_schema: elements_schema.as_ref(),
block_reader: BlockReader::new(self.state, self.allowed_depth.dec()?),
block_reader: BlockReader::new(self.state, false, self.allowed_depth.dec()?),
}),
SchemaNode::Duration => visitor.visit_seq(DurationMapAndSeqAccess {
duration_buf: &self.state.read_const_size_buf::<12>()?,
Expand All @@ -300,7 +300,7 @@ impl<'de, R: ReadSlice<'de>> Deserializer<'de> for DatumDeserializer<'_, '_, R>
match *self.schema_node {
SchemaNode::Array(elements_schema) => visitor.visit_seq(ArraySeqAccess {
elements_schema: elements_schema.as_ref(),
block_reader: BlockReader::new(self.state, self.allowed_depth.dec()?),
block_reader: BlockReader::new(self.state, false, self.allowed_depth.dec()?),
}),
SchemaNode::Duration if len == 3 => visitor.visit_seq(DurationMapAndSeqAccess {
duration_buf: &self.state.read_const_size_buf::<12>()?,
Expand Down Expand Up @@ -417,18 +417,20 @@ impl<'de, R: ReadSlice<'de>> Deserializer<'de> for DatumDeserializer<'_, '_, R>
where
V: Visitor<'de>,
{
// The main thing we can skip here for performance is utf8 decoding of strings.
// However we still need to drive the deserializer mostly normally to properly
// advance the reader.

// TODO skip more efficiently using blocks size hints
// https://stackoverflow.com/a/42247224/3799609

// Ideally this would also specialize if we have Seek on our generic reader but
// we don't have specialization
// We can skip here for performance:
// - utf8 decoding of strings
// - block reads when serialized data provides serialized block size in bytes

match *self.schema_node {
SchemaNode::String => read_length_delimited(self.state, BytesVisitor(visitor)),
SchemaNode::Array(elements_schema) => visitor.visit_seq(ArraySeqAccess {
elements_schema: elements_schema.as_ref(),
block_reader: BlockReader::new(self.state, true, self.allowed_depth.dec()?),
}),
SchemaNode::Map(elements_schema) => visitor.visit_map(MapMapAccess {
elements_schema: elements_schema.as_ref(),
block_reader: BlockReader::new(self.state, true, self.allowed_depth.dec()?),
}),
_ => self.deserialize_any(visitor),
}
}
Expand Down
65 changes: 47 additions & 18 deletions serde_avro_fast/src/de/deserializer/types/blocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,65 @@ use super::*;

use std::num::NonZeroUsize;

fn read_block_len<'de, R>(state: &mut DeserializerState<R>) -> Result<Option<NonZeroUsize>, DeError>
fn read_block_len<'de, R>(
state: &mut DeserializerState<R>,
ignored: bool,
) -> Result<Option<NonZeroUsize>, DeError>
where
R: ReadSlice<'de>,
{
let len: i64 = state.read_varint()?;
let res;
if len < 0 {
// res = -len, properly handling i64::MIN
res = u64::from_ne_bytes(len.to_ne_bytes()).wrapping_neg();
// Drop the number of bytes in the block to properly advance the reader
// Since we don't use that value, decode as u64 instead of i64 (skip zigzag
// decoding) TODO enable fast skipping when encountering
// `deserialize_ignored_any`
let _: u64 = state.read_varint()?;
} else {
res = len as u64;
loop {
let len: i64 = state.read_varint()?;
let res;
if len < 0 {
if ignored {
// We have block length hint in the data, and we are ignoring the data, so we
// can skip the block efficiently
let block_len_in_bytes: i64 = state.read_varint()?;
let block_len_in_bytes: u64 = block_len_in_bytes.try_into().map_err(|e| {
DeError::custom(format_args!("Invalid block length in stream: {e}"))
})?;
state.skip_bytes(block_len_in_bytes)?;
continue; // Also discard next blocks if any
} else {
// res = -len, properly handling i64::MIN
res = u64::from_ne_bytes(len.to_ne_bytes()).wrapping_neg();
// Drop the number of bytes in the block to properly advance the reader
// Since we don't use that value, decode as u64 instead of i64 (skip zigzag
// decoding)
let _: u64 = state.read_varint()?;
}
} else {
res = len as u64;
}
break res
.try_into()
.map_err(|e| DeError::custom(format_args!("Invalid array length in stream: {e}")))
.map(NonZeroUsize::new);
}
res.try_into()
.map_err(|e| DeError::custom(format_args!("Invalid array length in stream: {e}")))
.map(NonZeroUsize::new)
}

/// Reading state for the sequence of blocks that encodes one array or map.
///
/// Built by `BlockReader::new`, then driven through `has_more`, which pulls a
/// new block header with `read_block_len` whenever the current block runs out.
pub(in super::super) struct BlockReader<'r, 's, R> {
/// Per-block item counter: `has_more` decrements it with `checked_sub(1)` and
/// reads the next block header once it cannot be decremented any further.
current_block_len: usize,
/// Starts at 0. Presumably the number of items read so far - its use is not
/// visible in this excerpt, TODO confirm.
n_read: usize,
/// Deserializer state that block headers are read from.
reader: &'r mut DeserializerState<'s, R>,
/// Depth budget handed over by the caller (callers pass
/// `self.allowed_depth.dec()?`).
allowed_depth: AllowedDepth,
/// Represents whether we were hinted deserialize_ignored_any. If yes, we
/// can use the block length to skip the block.
ignored: bool,
}
impl<'r, 's, R> BlockReader<'r, 's, R> {
pub(in super::super) fn new(
reader: &'r mut DeserializerState<'s, R>,
hinted_ignored: bool,
allowed_depth: AllowedDepth,
) -> Self {
Self {
reader,
current_block_len: 0,
n_read: 0,
allowed_depth,
ignored: hinted_ignored,
}
}
fn has_more<'de>(&mut self) -> Result<bool, DeError>
Expand All @@ -48,7 +69,7 @@ impl<'r, 's, R> BlockReader<'r, 's, R> {
{
self.current_block_len = match self.current_block_len.checked_sub(1) {
None => {
let new_len = read_block_len(self.reader)?;
let new_len = read_block_len(self.reader, self.ignored)?;
match new_len {
None => return Ok(false),
Some(new_len) => {
Expand Down Expand Up @@ -137,9 +158,17 @@ impl<'de, R: ReadSlice<'de>> Deserializer<'de> for StringDeserializer<'_, '_, R>
read_length_delimited(self.reader, StringVisitor(visitor))
}

/// Consumes the length-delimited string without checking that it is UTF-8.
///
/// The value is going to be discarded, so the bytes are handed to the visitor
/// through `BytesVisitor` rather than `StringVisitor`.
fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
    V: Visitor<'de>,
{
    // Skip utf-8 validation
    let raw_bytes_visitor = BytesVisitor(visitor);
    read_length_delimited(self.reader, raw_bytes_visitor)
}

// `ignored_any` is deliberately left out of this list: `deserialize_ignored_any`
// is implemented by hand on this deserializer, and forwarding it here as well
// would define the method twice.
serde::forward_to_deserialize_any! {
    bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
    bytes byte_buf option unit unit_struct newtype_struct seq tuple
    tuple_struct map struct enum identifier
}
}
28 changes: 28 additions & 0 deletions serde_avro_fast/src/de/read/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,22 @@ pub trait Read: std::io::Read + Sized + private::Sealed {
self.read_exact(&mut buf).map_err(DeError::io)?;
Ok(buf)
}
/// Skip `n_bytes` bytes from the underlying buffer
///
/// This default implementation reads at most `n_bytes` bytes from `self` and
/// discards them into an `std::io::sink()`.
///
/// # Errors
///
/// Returns an I/O error if reading fails, and a custom error if fewer than
/// `n_bytes` bytes could be consumed.
fn skip_bytes(&mut self, n_bytes: u64) -> Result<(), DeError> {
    let mut limited = <&mut Self as std::io::Read>::take(self, n_bytes);
    let mut discard = std::io::sink();
    let skipped = std::io::copy(&mut limited, &mut discard).map_err(DeError::io)?;
    if skipped != n_bytes {
        // The reader ran dry before the whole span was consumed
        return Err(DeError::custom(format_args!(
            "Expected to skip {} bytes, but only skipped {}",
            n_bytes, skipped
        )));
    }
    Ok(())
}
}

/// Abstracts reading from slices (propagating lifetime) or any other `impl
Expand Down Expand Up @@ -74,6 +90,18 @@ impl<'de> Read for SliceRead<'de> {
}
}
}
/// Skips `n_bytes` bytes by re-slicing, without reading any of them.
///
/// # Errors
///
/// Returns a custom error if `n_bytes` does not fit in a `usize`, and an
/// unexpected-EOF error if the slice holds fewer than `n_bytes` bytes.
fn skip_bytes(&mut self, n_bytes: u64) -> Result<(), DeError> {
    let offset: usize = n_bytes
        .try_into()
        .map_err(|_| DeError::custom("Invalid number of bytes to skip"))?;
    // `get` returns `None` when `offset` is past the end of the slice
    let rest = self
        .slice
        .get(offset..)
        .ok_or_else(DeError::unexpected_eof)?;
    self.slice = rest;
    Ok(())
}
}
impl<'de> ReadSlice<'de> for SliceRead<'de> {
fn read_slice<V>(&mut self, n: usize, visitor: V) -> Result<V::Value, DeError>
Expand Down
47 changes: 47 additions & 0 deletions serde_avro_fast/tests/deserialize_ignored.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use serde_avro_fast::Schema;

// Avro record schema with three `array<int>` fields: `a`, `b` and `cd`.
// `TestSkip` below only declares `a` and `cd`, so `b` has to be discarded
// while deserializing.
const SCHEMA: &str = r#"
{
"fields": [
{
"type": {"type": "array", "items": "int"},
"name": "a"
},
{
"type": {"type": "array", "items": "int"},
"name": "b"
},
{
"type": {"type": "array", "items": "int"},
"name": "cd"
}
],
"type": "record",
"name": "test_skip"
}
"#;

// Deserialization target for `SCHEMA`. Field `b` of the schema is left out on
// purpose, so that its serialized array has to be ignored (presumably through
// `deserialize_ignored_any`, as serde-derived structs do for unknown fields).
#[derive(Debug, PartialEq, Eq, serde::Deserialize)]
struct TestSkip {
a: Vec<i32>,
cd: Vec<i32>,
}

/// Checks that an ignored array field made of size-prefixed blocks is skipped
/// correctly, both from a slice and from an `impl Read`.
#[test]
fn skip_block() {
    let schema: Schema = SCHEMA.parse().unwrap();

    // Layout of the datum (varints are zigzag-encoded):
    // - `a`:  1 (count = -1, i.e. one item, byte size follows), 2 (1 byte),
    //         20 (value 10), 0 (end of array)
    // - `b`:  1, 2, 30       -> first block: 1 byte of payload
    //         1, 4, 31, 32   -> second block: 2 bytes of payload
    //         0              -> end of array
    // - `cd`: 4 (2 items, no byte size), 40 (value 20), 50 (value 25),
    //         0 (end of array)
    // - 0xFF: trailing byte that is not part of the datum
    let input: &[u8] = &[1, 2, 20, 0, 1, 2, 30, 1, 4, 31, 32, 0, 4, 40, 50, 0, 0xFF];
    let expected = TestSkip {
        a: vec![10],
        cd: vec![20, 25],
    };

    // Slice-based entry point
    let from_slice: TestSkip = serde_avro_fast::from_datum_slice(input, &schema).unwrap();
    assert_eq!(from_slice, expected);

    // Reader-based entry point
    let mut remaining = &input[..];
    let from_reader: TestSkip =
        serde_avro_fast::from_datum_reader(&mut remaining, &schema).unwrap();
    assert_eq!(from_reader, expected);
    // Also make sure that the reader stopped at the end of the block
    assert_eq!(remaining, &[0xFF]);
}

0 comments on commit c85582c

Please sign in to comment.