feat(mysql): support packet splitting (#2665)
* Writing split packets

* Reading split packets

* Add tests for packet splitting

* Fix test for packet splitting
tk2217 authored Oct 11, 2023
1 parent b138705 commit 5ebe296
Showing 3 changed files with 86 additions and 11 deletions.
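
The change implements the wire-protocol rule that a single MySQL packet carries at most 0xFF_FF_FF payload bytes behind its 4-byte header (3-byte little-endian length plus 1-byte sequence id). A payload of 0xFF_FF_FF bytes or more is sent as consecutive packets and is terminated by a packet whose payload is shorter than 0xFF_FF_FF, zero-length when the total is an exact multiple. Below is a minimal standalone sketch of that framing rule, not code from this commit; `frame_payload` is an illustrative name.

```rust
/// Illustrative only: frame a payload into MySQL packets, splitting at the
/// 0xFF_FF_FF-byte limit. `frame_payload` is not part of sqlx or this commit.
fn frame_payload(payload: &[u8], sequence_id: &mut u8) -> Vec<u8> {
    const MAX_PACKET: usize = 0xFF_FF_FF; // maximum payload bytes per packet

    let mut out = Vec::with_capacity(payload.len() + 4);
    let mut chunks = payload.chunks(MAX_PACKET);

    loop {
        // `chunks()` never yields an empty slice, so an empty payload or an
        // exact multiple of MAX_PACKET still gets a zero-length packet here.
        let chunk: &[u8] = chunks.next().unwrap_or(&[]);

        // 4-byte header: 3-byte little-endian length + 1-byte sequence id
        // (the length always fits in 3 bytes because chunk.len() <= MAX_PACKET)
        let mut header = (chunk.len() as u32).to_le_bytes();
        header[3] = *sequence_id;
        *sequence_id = sequence_id.wrapping_add(1);

        out.extend_from_slice(&header);
        out.extend_from_slice(chunk);

        // a packet shorter than MAX_PACKET marks the end of the payload
        if chunk.len() < MAX_PACKET {
            break;
        }
    }

    out
}
```

The committed code below applies the same idea in place on both sides of the connection: the write side patches headers into the already-encoded buffer, and the read side concatenates parts until it sees a short one.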
33 changes: 27 additions & 6 deletions sqlx-mysql/src/connection/stream.rs
@@ -1,7 +1,7 @@
use std::collections::VecDeque;
use std::ops::{Deref, DerefMut};

use bytes::{Buf, Bytes};
use bytes::{Buf, Bytes, BytesMut};

use crate::collation::{CharSet, Collation};
use crate::error::Error;
@@ -126,9 +126,7 @@ impl<S: Socket> MySqlStream<S> {
.write_with(Packet(payload), (self.capabilities, &mut self.sequence_id));
}

// receive the next packet from the database server
// may block (async) on more data from the server
pub(crate) async fn recv_packet(&mut self) -> Result<Packet<Bytes>, Error> {
async fn recv_packet_part(&mut self) -> Result<Bytes, Error> {
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html
// https://mariadb.com/kb/en/library/0-packet/#standard-packet

@@ -142,10 +140,33 @@
let payload: Bytes = self.socket.read(packet_size).await?;

// TODO: packet compression
// TODO: packet joining

Ok(payload)
}

// receive the next packet from the database server
// may block (async) on more data from the server
pub(crate) async fn recv_packet(&mut self) -> Result<Packet<Bytes>, Error> {
let payload = self.recv_packet_part().await?;
let payload = if payload.len() < 0xFF_FF_FF {
payload
} else {
let mut final_payload = BytesMut::with_capacity(0xFF_FF_FF * 2);
final_payload.extend_from_slice(&payload);

drop(payload); // we don't need the allocation anymore

let mut last_read = 0xFF_FF_FF;
while last_read == 0xFF_FF_FF {
let part = self.recv_packet_part().await?;
last_read = part.len();
final_payload.extend_from_slice(&part);
}
final_payload.into()
};

if payload
.get(0)
.first()
.ok_or(err_protocol!("Packet empty"))?
.eq(&0xff)
{
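
The read side above splits the old `recv_packet` into `recv_packet_part`, which reads a single wire packet, and a new `recv_packet`, which keeps appending parts for as long as each part is exactly 0xFF_FF_FF bytes; a first packet below that size takes the fast path and avoids the extra `BytesMut` allocation. A sketch of the joining rule in isolation, where the `read_part` closure is a hypothetical stand-in for one `recv_packet_part()` call:

```rust
/// Illustrative only: join split packet payloads the way `recv_packet` does.
fn join_parts(mut read_part: impl FnMut() -> Vec<u8>) -> Vec<u8> {
    const MAX_PACKET: usize = 0xFF_FF_FF;

    let mut payload = Vec::new();
    loop {
        let part = read_part();
        let done = part.len() < MAX_PACKET; // a short part terminates the payload
        payload.extend_from_slice(&part);
        if done {
            break;
        }
    }
    payload
}

fn main() {
    // A maximum-size part followed by a 5-byte part joins into one payload.
    let mut parts = vec![vec![0u8; 0xFF_FF_FF], vec![1, 2, 3, 4, 5]].into_iter();
    let payload = join_parts(|| parts.next().expect("test data exhausted"));
    assert_eq!(payload.len(), 0xFF_FF_FF + 5);
}
```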
31 changes: 26 additions & 5 deletions sqlx-mysql/src/protocol/packet.rs
@@ -1,3 +1,4 @@
use std::cmp::min;
use std::ops::{Deref, DerefMut};

use bytes::Bytes;
@@ -19,6 +20,14 @@ where
buf: &mut Vec<u8>,
(capabilities, sequence_id): (Capabilities, &'stream mut u8),
) {
let mut next_header = |len: u32| {
let mut buf = len.to_le_bytes();
buf[3] = *sequence_id;
*sequence_id = sequence_id.wrapping_add(1);

buf
};

// reserve space to write the prefixed length
let offset = buf.len();
buf.extend(&[0_u8; 4]);
@@ -31,13 +40,25 @@
let len = buf.len() - offset - 4;
let header = &mut buf[offset..];

// FIXME: Support larger packets
assert!(len < 0xFF_FF_FF);
header[..4].copy_from_slice(&next_header(min(len, 0xFF_FF_FF) as u32));

header[..4].copy_from_slice(&(len as u32).to_le_bytes());
header[3] = *sequence_id;
// add more packets if we need to split the data
if len >= 0xFF_FF_FF {
let rest = buf.split_off(offset + 4 + 0xFF_FF_FF);
let mut chunks = rest.chunks_exact(0xFF_FF_FF);

*sequence_id = sequence_id.wrapping_add(1);
for chunk in chunks.by_ref() {
buf.reserve(chunk.len() + 4);
buf.extend(&next_header(chunk.len() as u32));
buf.extend(chunk);
}

// this will also handle adding a zero sized packet if the data size is a multiple of 0xFF_FF_FF
let remainder = chunks.remainder();
buf.reserve(remainder.len() + 4);
buf.extend(&next_header(remainder.len() as u32));
buf.extend(remainder);
}
}
}

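
On the write side above, only the first 0xFF_FF_FF bytes keep the header slot that was reserved up front; everything past that is taken out with `split_off`, re-chunked with `chunks_exact`, and re-appended with a fresh header per chunk. `chunks_exact` is what makes the terminating packet fall out naturally: its remainder is the final short chunk, and it is empty exactly when the protocol requires a zero-length packet. A scaled-down illustration (chunk size 4 standing in for 0xFF_FF_FF):

```rust
fn main() {
    // Length is an exact multiple of the chunk size: the remainder is empty,
    // which is precisely the case that needs a zero-length terminating packet.
    let rest = [0u8; 8];
    let mut chunks = rest.chunks_exact(4);
    assert_eq!(chunks.by_ref().count(), 2); // two full chunks, one packet each
    assert!(chunks.remainder().is_empty()); // empty remainder -> zero-length packet

    // Length is not a multiple: the remainder becomes the short final packet.
    let rest = [0u8; 10];
    let mut chunks = rest.chunks_exact(4);
    assert_eq!(chunks.by_ref().count(), 2);
    assert_eq!(chunks.remainder().len(), 2);
}
```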
33 changes: 33 additions & 0 deletions tests/mysql/mysql.rs
@@ -447,6 +447,39 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> {
Ok(())
}

#[sqlx_macros::test]
async fn it_can_handle_split_packets() -> anyhow::Result<()> {
// This will only take effect on new connections
new::<MySql>()
.await?
.execute("SET GLOBAL max_allowed_packet = 4294967297")
.await?;

let mut conn = new::<MySql>().await?;

conn.execute(
r#"
CREATE TEMPORARY TABLE large_table (data LONGBLOB);
"#,
)
.await?;

let data = vec![0x41; 0xFF_FF_FF * 2];

sqlx::query("INSERT INTO large_table (data) VALUES (?)")
.bind(&data)
.execute(&mut conn)
.await?;

let ret: Vec<u8> = sqlx::query_scalar("SELECT * FROM large_table")
.fetch_one(&mut conn)
.await?;

assert_eq!(ret, data);

Ok(())
}

#[sqlx_macros::test]
async fn test_shrink_buffers() -> anyhow::Result<()> {
// We don't really have a good way to test that `.shrink_buffers()` functions as expected