diff --git a/sqlx-mysql/src/connection/stream.rs b/sqlx-mysql/src/connection/stream.rs index e42c15d9e5..c225256da2 100644 --- a/sqlx-mysql/src/connection/stream.rs +++ b/sqlx-mysql/src/connection/stream.rs @@ -1,7 +1,7 @@ use std::collections::VecDeque; use std::ops::{Deref, DerefMut}; -use bytes::{Buf, Bytes}; +use bytes::{Buf, Bytes, BytesMut}; use crate::collation::{CharSet, Collation}; use crate::error::Error; @@ -126,9 +126,7 @@ impl MySqlStream { .write_with(Packet(payload), (self.capabilities, &mut self.sequence_id)); } - // receive the next packet from the database server - // may block (async) on more data from the server - pub(crate) async fn recv_packet(&mut self) -> Result, Error> { + async fn recv_packet_part(&mut self) -> Result { // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html // https://mariadb.com/kb/en/library/0-packet/#standard-packet @@ -142,10 +140,33 @@ impl MySqlStream { let payload: Bytes = self.socket.read(packet_size).await?; // TODO: packet compression - // TODO: packet joining + + Ok(payload) + } + + // receive the next packet from the database server + // may block (async) on more data from the server + pub(crate) async fn recv_packet(&mut self) -> Result, Error> { + let payload = self.recv_packet_part().await?; + let payload = if payload.len() < 0xFF_FF_FF { + payload + } else { + let mut final_payload = BytesMut::with_capacity(0xFF_FF_FF * 2); + final_payload.extend_from_slice(&payload); + + drop(payload); // we don't need the allocation anymore + + let mut last_read = 0xFF_FF_FF; + while last_read == 0xFF_FF_FF { + let part = self.recv_packet_part().await?; + last_read = part.len(); + final_payload.extend_from_slice(&part); + } + final_payload.into() + }; if payload - .get(0) + .first() .ok_or(err_protocol!("Packet empty"))? .eq(&0xff) { diff --git a/sqlx-mysql/src/protocol/packet.rs b/sqlx-mysql/src/protocol/packet.rs index 2d40ba8e77..9d0d46c35a 100644 --- a/sqlx-mysql/src/protocol/packet.rs +++ b/sqlx-mysql/src/protocol/packet.rs @@ -1,3 +1,4 @@ +use std::cmp::min; use std::ops::{Deref, DerefMut}; use bytes::Bytes; @@ -19,6 +20,14 @@ where buf: &mut Vec, (capabilities, sequence_id): (Capabilities, &'stream mut u8), ) { + let mut next_header = |len: u32| { + let mut buf = len.to_le_bytes(); + buf[3] = *sequence_id; + *sequence_id = sequence_id.wrapping_add(1); + + buf + }; + // reserve space to write the prefixed length let offset = buf.len(); buf.extend(&[0_u8; 4]); @@ -31,13 +40,25 @@ where let len = buf.len() - offset - 4; let header = &mut buf[offset..]; - // FIXME: Support larger packets - assert!(len < 0xFF_FF_FF); + header[..4].copy_from_slice(&next_header(min(len, 0xFF_FF_FF) as u32)); - header[..4].copy_from_slice(&(len as u32).to_le_bytes()); - header[3] = *sequence_id; + // add more packets if we need to split the data + if len >= 0xFF_FF_FF { + let rest = buf.split_off(offset + 4 + 0xFF_FF_FF); + let mut chunks = rest.chunks_exact(0xFF_FF_FF); - *sequence_id = sequence_id.wrapping_add(1); + for chunk in chunks.by_ref() { + buf.reserve(chunk.len() + 4); + buf.extend(&next_header(chunk.len() as u32)); + buf.extend(chunk); + } + + // this will also handle adding a zero sized packet if the data size is a multiple of 0xFF_FF_FF + let remainder = chunks.remainder(); + buf.reserve(remainder.len() + 4); + buf.extend(&next_header(remainder.len() as u32)); + buf.extend(remainder); + } } } diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs index ba57a01d8e..586cef2ed0 100644 --- a/tests/mysql/mysql.rs +++ b/tests/mysql/mysql.rs @@ -447,6 +447,39 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn it_can_handle_split_packets() -> anyhow::Result<()> { + // This will only take effect on new connections + new::() + .await? + .execute("SET GLOBAL max_allowed_packet = 4294967297") + .await?; + + let mut conn = new::().await?; + + conn.execute( + r#" +CREATE TEMPORARY TABLE large_table (data LONGBLOB); + "#, + ) + .await?; + + let data = vec![0x41; 0xFF_FF_FF * 2]; + + sqlx::query("INSERT INTO large_table (data) VALUES (?)") + .bind(&data) + .execute(&mut conn) + .await?; + + let ret: Vec = sqlx::query_scalar("SELECT * FROM large_table") + .fetch_one(&mut conn) + .await?; + + assert_eq!(ret, data); + + Ok(()) +} + #[sqlx_macros::test] async fn test_shrink_buffers() -> anyhow::Result<()> { // We don't really have a good way to test that `.shrink_buffers()` functions as expected