Skip to content

Commit

Permalink
fix TBinaryUnsafeInputProtocol get_bytes error
Browse files Browse the repository at this point in the history
  • Loading branch information
Millione committed Oct 15, 2024
1 parent 9982f34 commit 3140e17
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions pilota/src/thrift/shmipc/binary_unsafe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,13 @@ impl<'a> TBinaryUnsafeInputProtocol<'a> {
pub fn index(&self) -> usize {
self.index
}

#[doc(hidden)]
fn advance(&mut self, len: usize) {
self.trans.advance(len);
self.buf.advance(len);
self.index -= len;
}
}

impl<'a> TLengthProtocol for TBinaryUnsafeInputProtocol<'a> {
Expand Down Expand Up @@ -615,6 +622,7 @@ impl<'a> TInputProtocol for TBinaryUnsafeInputProtocol<'a> {
let name = self.read_faststr()?;

let sequence_number = self.read_i32()?;
self.advance(self.index);
Ok(TMessageIdentifier::new(name, message_type, sequence_number))
}

Expand Down Expand Up @@ -668,8 +676,9 @@ impl<'a> TInputProtocol for TBinaryUnsafeInputProtocol<'a> {
#[inline]
fn read_bytes(&mut self) -> Result<Bytes, ThriftException> {
let len = self.read_i32()?;
let val = Bytes::copy_from_slice(&self.buf[self.index..self.index + len as usize]);
self.index += len as usize;
self.advance(self.index);
let val = Bytes::copy_from_slice(&self.trans.split_to(len as usize));
self.buf = unsafe { slice::from_raw_parts(self.trans.as_ptr(), self.trans.len()) };
Ok(val)
}

Expand All @@ -681,11 +690,10 @@ impl<'a> TInputProtocol for TBinaryUnsafeInputProtocol<'a> {
) -> Result<Bytes, ThriftException> {
if ptr.is_none() {
len -= self.index;
self.trans.advance(self.index);
self.advance(self.index);
}
self.index = 0;
let val = Bytes::copy_from_slice(&self.buf[self.index..self.index + len]);
self.trans.advance(len);
let val = Bytes::copy_from_slice(&self.trans.split_to(len));
self.buf = unsafe { slice::from_raw_parts(self.trans.as_ptr(), self.trans.len()) };

Ok(val)
Expand Down Expand Up @@ -770,11 +778,10 @@ impl<'a> TInputProtocol for TBinaryUnsafeInputProtocol<'a> {
fn read_faststr(&mut self) -> Result<FastStr, ThriftException> {
unsafe {
let len = self.read_i32().unwrap_unchecked() as usize;
let val = FastStr::new(str::from_utf8_unchecked(
self.buf.get_unchecked(self.index..self.index + len),
));
self.index += len;
Ok(val)
self.advance(self.index);
let bytes = Bytes::copy_from_slice(&self.trans.split_to(len));
self.buf = slice::from_raw_parts(self.trans.as_ptr(), self.trans.len());
Ok(FastStr::from_bytes_unchecked(bytes))
}
}

Expand Down Expand Up @@ -827,8 +834,9 @@ impl<'a> TInputProtocol for TBinaryUnsafeInputProtocol<'a> {
#[inline]
fn read_bytes_vec(&mut self) -> Result<Vec<u8>, ThriftException> {
let len = self.read_i32()? as usize;
let val = self.buf[self.index..self.index + len].to_vec();
self.index += len;
self.advance(self.index);
let val = self.trans.split_to(len).into();
self.buf = unsafe { slice::from_raw_parts(self.trans.as_ptr(), self.trans.len()) };
Ok(val)
}

Expand All @@ -841,8 +849,7 @@ impl<'a> TInputProtocol for TBinaryUnsafeInputProtocol<'a> {
fn skip(&mut self, field_type: TType) -> Result<usize, ThriftException> {
debug_assert!(self.index >= FIELD_BEGIN_LEN);

self.trans.advance(self.index - FIELD_BEGIN_LEN);
self.index = FIELD_BEGIN_LEN;
self.advance(self.index - FIELD_BEGIN_LEN);
self.buf = unsafe { slice::from_raw_parts(self.trans.as_ptr(), self.trans.len()) };

self.skip_till_depth(field_type, crate::thrift::MAXIMUM_SKIP_DEPTH)
Expand Down

0 comments on commit 3140e17

Please sign in to comment.