From 60f1873a21f6da15e61814ac1b776912829c9604 Mon Sep 17 00:00:00 2001 From: Jonathan Johnson Date: Tue, 4 Jan 2022 20:16:14 -0800 Subject: [PATCH] Added Transmog compatability. This commit adds support for specifying the serialization format at the time of accepting a connection, which means that ALPN protocol negotiation can be used to control which serialization format is used on an incoming connection. The next feature is the ability to switch serialization formats on a per-stream basis, after the r#type negotation. --- Cargo.toml | 4 +- examples/basic.rs | 5 +- src/error.rs | 25 ++-- src/quic/connection/connecting.rs | 18 ++- src/quic/connection/incoming.rs | 85 +++++++++--- src/quic/connection/mod.rs | 101 +++++++++++--- src/quic/connection/receiver.rs | 14 +- src/quic/connection/receiver_stream.rs | 54 +++++--- src/quic/connection/sender.rs | 105 ++++++++++----- src/quic/endpoint/builder/mod.rs | 34 ++--- src/quic/endpoint/mod.rs | 25 ++-- src/x509/mod.rs | 2 +- src/x509/private_key.rs | 2 +- tests/local-connections.rs | 179 +++++++++++++++++++++++++ 14 files changed, 519 insertions(+), 134 deletions(-) create mode 100644 tests/local-connections.rs diff --git a/Cargo.toml b/Cargo.toml index ab43b6f..0cd1fba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ trust-dns = ["trust-dns-resolver"] [dependencies] async-trait = "0.1" -bincode = "1" +transmog = "0.1.0-dev.2" bytes = "1" ct-logs = "0.9" flume = "0.10" @@ -54,6 +54,8 @@ fabruic = { path = "", features = ["rcgen", "test"] } quinn-proto = { version = "0.8", default-features = false } tokio = { version = "1", features = ["macros"] } trust-dns-proto = "0.21.0-alpha.4" +transmog-bincode = { version = "0.1.0-dev.2" } +transmog-pot = { version = "0.1.0-dev.2" } [profile.release] codegen-units = 1 diff --git a/examples/basic.rs b/examples/basic.rs index c22dea8..30a5539 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -1,6 +1,7 @@ use anyhow::{Error, Result}; use fabruic::{Endpoint, KeyPair}; use futures_util::{future, StreamExt, TryFutureExt}; +use transmog_bincode::Bincode; const SERVER_NAME: &str = "test"; /// Some random port. @@ -38,7 +39,7 @@ async fn main() -> Result<()> { index, connecting.remote_address() ); - let connection = connecting.accept::<()>().await?; + let connection = connecting.accept::<(), _>(Bincode::default()).await?; println!( "[client:{}] Successfully connected to {}", index, @@ -107,7 +108,7 @@ async fn main() -> Result<()> { // every new incoming connections is handled in it's own task connections.push( tokio::spawn(async move { - let mut connection = connecting.accept::<()>().await?; + let mut connection = connecting.accept::<(), _>(Bincode::default()).await?; println!("[server] New Connection: {}", connection.remote_address()); // start listening to new incoming streams diff --git a/src/error.rs b/src/error.rs index c3668b7..7379316 100644 --- a/src/error.rs +++ b/src/error.rs @@ -8,11 +8,10 @@ // TODO: error type is becoming too big, split it up use std::{ - fmt::{self, Debug, Formatter}, + fmt::{self, Debug, Display, Formatter}, io, }; -pub use bincode::ErrorKind; use quinn::ConnectionClose; pub use quinn::{ConnectError, ConnectionError, ReadError, WriteError}; use thiserror::Error; @@ -207,7 +206,9 @@ impl From for Connecting { reason, }) if reason.as_ref() == b"peer doesn't support any known protocol" && error_code.to_string() == "the cryptographic handshake failed: error 120" => - Self::ProtocolMismatch, + { + Self::ProtocolMismatch + } other => Self::Connection(other), } } @@ -241,6 +242,7 @@ pub enum Incoming { /// Error receiving a message from a [`Receiver`](crate::Receiver). #[derive(Debug, Error)] +#[allow(variant_size_differences)] pub enum Receiver { /// Failed to read from a [`Receiver`](crate::Receiver). #[error("Error reading from `Receiver`: {0}")] @@ -248,16 +250,17 @@ pub enum Receiver { /// Failed to [`Deserialize`](serde::Deserialize) a message from a /// [`Receiver`](crate::Receiver). #[error("Error deserializing a message from `Receiver`: {0}")] - Deserialize(#[from] ErrorKind), + Deserialize(Box), } /// Error sending a message to a [`Sender`](crate::Sender). #[derive(Debug, Error)] +#[allow(variant_size_differences)] pub enum Sender { /// Failed to [`Serialize`](serde::Serialize) a message for a /// [`Sender`](crate::Sender). #[error("Error serializing a message to `Sender`: {0}")] - Serialize(ErrorKind), + Serialize(Box), /// Failed to write to a [`Sender`](crate::Sender). #[error("Error writing to `Sender`: {0}")] Write(#[from] WriteError), @@ -266,8 +269,14 @@ pub enum Sender { Closed(#[from] AlreadyClosed), } -impl From> for Sender { - fn from(error: Box) -> Self { - Self::Serialize(*error) +impl Sender { + /// Returns a new instance after boxing `err`. + pub(crate) fn from_serialization(err: E) -> Self { + Self::Serialize(Box::new(err)) } } + +/// An error raised from serialization. +pub trait SerializationError: Display + Debug + Send + Sync + 'static {} + +impl SerializationError for T where T: Display + Debug + Send + Sync + 'static {} diff --git a/src/quic/connection/connecting.rs b/src/quic/connection/connecting.rs index dab07ba..c7c7797 100644 --- a/src/quic/connection/connecting.rs +++ b/src/quic/connection/connecting.rs @@ -4,9 +4,12 @@ use std::net::SocketAddr; use quinn::{crypto::rustls::HandshakeData, NewConnection}; -use serde::{de::DeserializeOwned, Serialize}; +use transmog::OwnedDeserializer; -use crate::{error, Connection}; +use crate::{ + error::{self, SerializationError}, + Connection, +}; /// Represent's an intermediate state to build a [`Connection`]. #[must_use = "`Connecting` does nothing unless accepted with `Connecting::accept`"] @@ -47,9 +50,12 @@ impl Connecting { /// /// # Errors /// [`error::Connecting`] if the [`Connection`] failed to be established. - pub async fn accept( - self, - ) -> Result, error::Connecting> { + pub async fn accept(self, format: F) -> Result, error::Connecting> + where + T: Send + 'static, + F: OwnedDeserializer + Clone, + F::Error: SerializationError, + { self.0 .await .map( @@ -57,7 +63,7 @@ impl Connecting { connection, bi_streams, .. - }| Connection::new(connection, bi_streams), + }| Connection::new(connection, bi_streams, format), ) .map_err(error::Connecting::from) } diff --git a/src/quic/connection/incoming.rs b/src/quic/connection/incoming.rs index 5622935..b7ba66d 100644 --- a/src/quic/connection/incoming.rs +++ b/src/quic/connection/incoming.rs @@ -4,24 +4,30 @@ use std::fmt::{self, Debug, Formatter}; use futures_util::StreamExt; use quinn::{RecvStream, SendStream}; -use serde::{de::DeserializeOwned, Serialize}; +use transmog::{Format, OwnedDeserializer}; use super::ReceiverStream; -use crate::{error, Receiver, Sender}; +use crate::{ + error::{self, SerializationError}, + Receiver, Sender, +}; /// An intermediate state to define which type to accept in this stream. See /// [`accept_stream`](Self::accept). #[must_use = "`Incoming` does nothing unless accepted with `Incoming::accept`"] -pub struct Incoming { +pub struct Incoming> { /// [`SendStream`] to build [`Sender`]. sender: SendStream, /// [`RecvStream`] to build [`Receiver`]. - receiver: ReceiverStream, + receiver: ReceiverStream, /// Requested type. r#type: Option>, } -impl Debug for Incoming { +impl Debug for Incoming +where + F: OwnedDeserializer, +{ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("Incoming") .field("sender", &self.sender) @@ -31,12 +37,16 @@ impl Debug for Incoming { } } -impl Incoming { +impl Incoming +where + F: OwnedDeserializer + Clone, + F::Error: SerializationError, +{ /// Builds a new [`Incoming`] from raw [`quinn`] types. - pub(super) fn new(sender: SendStream, receiver: RecvStream) -> Self { + pub(super) fn new(sender: SendStream, receiver: RecvStream, format: F) -> Self { Self { sender, - receiver: ReceiverStream::new(receiver), + receiver: ReceiverStream::new(receiver, format), r#type: None, } } @@ -80,12 +90,57 @@ impl Incoming { /// - [`error::Incoming::Receiver`] if receiving the type information to the /// peer failed, see [`error::Receiver`] for more details /// - [`error::Incoming::Closed`] if the stream was closed - pub async fn accept< - S: DeserializeOwned + Serialize + Send + 'static, - R: DeserializeOwned + Serialize + Send + 'static, - >( + pub async fn accept( + self, + ) -> Result<(Sender, Receiver), error::Incoming> + where + F: OwnedDeserializer + Format<'static, S> + 'static, + >::Error: SerializationError, + >::Error: SerializationError, + { + let format = self.receiver.format.clone(); + self.accept_with_format(format).await + } + + /// Accept the incoming stream with the given types. + /// + /// Use `S` and `R` to define which type this stream is sending and + /// receiving. + /// + /// # Errors + /// - [`error::Incoming::Receiver`] if receiving the type information to the + /// peer failed, see [`error::Receiver`] for more details + /// - [`error::Incoming::Closed`] if the stream was closed + pub async fn accept_raw( + self, + ) -> Result<(Sender, Receiver), error::Incoming> + where + F: OwnedDeserializer + Format<'static, S> + 'static, + >::Error: SerializationError, + >::Error: SerializationError, + { + let format = self.receiver.format.clone(); + self.accept_with_format(format).await + } + + /// Accept the incoming stream with the given types. + /// + /// Use `S` and `R` to define which type this stream is sending and + /// receiving. + /// + /// # Errors + /// - [`error::Incoming::Receiver`] if receiving the type information to the + /// peer failed, see [`error::Receiver`] for more details + /// - [`error::Incoming::Closed`] if the stream was closed + pub async fn accept_with_format( mut self, - ) -> Result<(Sender, Receiver), error::Incoming> { + format: NewFormat, + ) -> Result<(Sender, Receiver), error::Incoming> + where + NewFormat: OwnedDeserializer + Format<'static, S> + Clone + 'static, + >::Error: SerializationError, + >::Error: SerializationError, + { match self.r#type { Some(Ok(_)) => (), Some(Err(error)) => return Err(error), @@ -100,8 +155,8 @@ impl Incoming { } } - let sender = Sender::new(self.sender); - let receiver = Receiver::new(self.receiver.transmute()); + let sender = Sender::new(self.sender, format.clone()); + let receiver = Receiver::new(self.receiver.transmute(format)); Ok((sender, receiver)) } diff --git a/src/quic/connection/mod.rs b/src/quic/connection/mod.rs index 9521afc..508a10a 100644 --- a/src/quic/connection/mod.rs +++ b/src/quic/connection/mod.rs @@ -34,16 +34,23 @@ use quinn::{crypto::rustls::HandshakeData, IncomingBiStreams, VarInt}; pub use receiver::Receiver; use receiver_stream::ReceiverStream; pub use sender::Sender; -use serde::{de::DeserializeOwned, Serialize}; use stream::Stream; +use transmog::{Format, OwnedDeserializer}; use super::Task; -use crate::{error, CertificateChain}; +use crate::{ + error::{self, SerializationError}, + CertificateChain, +}; /// Represents an open connection. Receives [`Incoming`] through [`Stream`]. #[pin_project] #[derive(Clone)] -pub struct Connection { +pub struct Connection +where + T: Send + 'static, + F: OwnedDeserializer, +{ /// Initiate new connections or close socket. connection: quinn::Connection, /// Receive incoming streams. @@ -51,11 +58,17 @@ pub struct Connection { RecvStream<'static, Result<(quinn::SendStream, quinn::RecvStream), error::Connection>>, /// [`Task`] handling new incoming streams. task: Task<()>, + /// Serialization foramt. + format: F, /// Type for type negotiation for new streams. types: PhantomData, } -impl Debug for Connection { +impl Debug for Connection +where + T: Send + 'static, + F: OwnedDeserializer, +{ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("Connection") .field("connection", &self.connection) @@ -66,10 +79,19 @@ impl Debug for Connection { } } -impl Connection { +impl Connection +where + T: Send + 'static, + F: OwnedDeserializer + Clone, + F::Error: SerializationError, +{ /// Builds a new [`Connection`] from raw [`quinn`] types. #[allow(clippy::mut_mut)] // futures_util::select_biased internal usage - pub(super) fn new(connection: quinn::Connection, bi_streams: IncomingBiStreams) -> Self { + pub(super) fn new( + connection: quinn::Connection, + bi_streams: IncomingBiStreams, + format: F, + ) -> Self { // channels for passing down new `Incoming` `Connection`s let (sender, receiver) = flume::unbounded(); let receiver = receiver.into_stream(); @@ -96,6 +118,7 @@ impl Connection { connection, receiver, task, + format, types: PhantomData, } } @@ -110,19 +133,51 @@ impl Connection { /// - [`error::Stream::Open`] if opening a stream failed /// - [`error::Stream::Sender`] if sending the type information to the peer /// failed, see [`error::Sender`] for more details - pub async fn open_stream< - S: DeserializeOwned + Serialize + Send + 'static, - R: DeserializeOwned + Serialize + Send + 'static, - >( + pub async fn open_stream( + &self, + r#type: &T, + ) -> Result<(Sender, Receiver), error::Stream> + where + F: OwnedDeserializer + Format<'static, S> + Format<'static, T> + 'static, + >::Error: SerializationError, + >::Error: SerializationError, + { + let (sender, receiver) = self.connection.open_bi().await?; + + let sender = Sender::new(sender, self.format.clone()); + let receiver = Receiver::new(ReceiverStream::new(receiver, self.format.clone())); + + sender.send_any(r#type, &self.format)?; + + Ok((sender, receiver)) + } + + /// Open a stream on this [`Connection`], allowing to send data back and + /// forth. + /// + /// Use `S` and `R` to define which type this stream is sending and + /// receiving and `type` to send this information to the receiver. + /// + /// # Errors + /// - [`error::Stream::Open`] if opening a stream failed + /// - [`error::Stream::Sender`] if sending the type information to the peer + /// failed, see [`error::Sender`] for more details + pub async fn open_stream_with_format( &self, r#type: &T, - ) -> Result<(Sender, Receiver), error::Stream> { + format: NewFormat, + ) -> Result<(Sender, Receiver), error::Stream> + where + NewFormat: OwnedDeserializer + Format<'static, S> + Clone + 'static, + >::Error: SerializationError, + >::Error: SerializationError, + { let (sender, receiver) = self.connection.open_bi().await?; - let sender = Sender::new(sender); - let receiver = Receiver::new(ReceiverStream::new(receiver)); + let sender = Sender::new(sender, format.clone()); + let receiver = Receiver::new(ReceiverStream::new(receiver, format)); - sender.send_any(&r#type)?; + sender.send_any(r#type, &self.format)?; Ok((sender, receiver)) } @@ -175,8 +230,13 @@ impl Connection { } } -impl Stream for Connection { - type Item = Result, error::Connection>; +impl Stream for Connection +where + T: Send + 'static, + F: OwnedDeserializer + Clone, + F::Error: SerializationError, +{ + type Item = Result, error::Connection>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.receiver.is_terminated() { @@ -184,12 +244,17 @@ impl Stream for Connection } else { self.receiver .poll_next_unpin(cx) - .map_ok(|(sender, receiver)| Incoming::new(sender, receiver)) + .map_ok(|(sender, receiver)| Incoming::new(sender, receiver, self.format.clone())) } } } -impl FusedStream for Connection { +impl FusedStream for Connection +where + T: Send + 'static, + F: OwnedDeserializer + Clone, + F::Error: SerializationError, +{ fn is_terminated(&self) -> bool { self.receiver.is_terminated() } diff --git a/src/quic/connection/receiver.rs b/src/quic/connection/receiver.rs index 9cbccdd..f9d2fa4 100644 --- a/src/quic/connection/receiver.rs +++ b/src/quic/connection/receiver.rs @@ -7,10 +7,10 @@ use std::{ }; use futures_util::{stream::Stream, StreamExt}; -use serde::de::DeserializeOwned; +use transmog::OwnedDeserializer; use super::{ReceiverStream, Task}; -use crate::error; +use crate::error::{self, SerializationError}; /// Used to receive data from a stream. Will stop receiving message if /// deserialization failed. @@ -31,13 +31,17 @@ impl Debug for Receiver { } } -impl Receiver { +impl Receiver +where + T: Send, +{ /// Builds a new [`Receiver`] from a raw [`quinn`] type. Spawns a task that /// receives data from the stream. #[allow(clippy::mut_mut)] // futures_util::select_biased internal usage - pub(super) fn new(mut stream: ReceiverStream) -> Self + pub(super) fn new(mut stream: ReceiverStream) -> Self where - T: DeserializeOwned + Send, + F: OwnedDeserializer + 'static, + F::Error: SerializationError, { // receiver channels let (sender, receiver) = flume::unbounded(); diff --git a/src/quic/connection/receiver_stream.rs b/src/quic/connection/receiver_stream.rs index 6683647..bb9646e 100644 --- a/src/quic/connection/receiver_stream.rs +++ b/src/quic/connection/receiver_stream.rs @@ -8,7 +8,6 @@ use std::{ task::{Context, Poll}, }; -use bincode::ErrorKind; use bytes::{Buf, BufMut, BytesMut}; use futures_util::{ stream::{FusedStream, Stream}, @@ -16,13 +15,16 @@ use futures_util::{ }; use pin_project::pin_project; use quinn::{Chunk, ReadError, RecvStream, VarInt}; -use serde::de::DeserializeOwned; +use transmog::OwnedDeserializer; -use crate::error; +use crate::error::{self, SerializationError}; /// Wrapper around [`RecvStream`] providing framing and deserialization. #[pin_project] -pub(super) struct ReceiverStream { +pub(super) struct ReceiverStream +where + F: OwnedDeserializer, +{ /// Store length of the currently processing message. length: usize, /// Store incoming chunks. @@ -31,30 +33,41 @@ pub(super) struct ReceiverStream { stream: RecvStream, /// True if the stream is complete. complete: bool, + /// The deserialization format. + pub(super) format: F, /// Type to be [`Deserialize`](serde::Deserialize)d _type: PhantomData, } -impl ReceiverStream { +impl ReceiverStream +where + F: OwnedDeserializer, + F::Error: SerializationError + 'static, +{ /// Builds a new [`ReceiverStream`]. - pub(super) fn new(stream: RecvStream) -> Self { + pub(super) fn new(stream: RecvStream, format: F) -> Self { Self { length: 0, // 1480 bytes is a default MTU size configured by quinn-proto buffer: BytesMut::with_capacity(1480), stream, complete: false, + format, _type: PhantomData, } } /// Transmutes this [`ReceiverStream`] to a different message type. - pub(super) fn transmute(self) -> ReceiverStream { + pub(super) fn transmute(self, format: NewFormat) -> ReceiverStream + where + NewFormat: OwnedDeserializer, + { ReceiverStream { length: self.length, buffer: self.buffer, stream: self.stream, complete: self.complete, + format, _type: PhantomData, } } @@ -112,19 +125,20 @@ impl ReceiverStream { /// # Errors /// [`ErrorKind`] if `data` failed to be /// [`Deserialize`](serde::Deserialize)d. - fn deserialize(&mut self) -> Result, ErrorKind> { + #[allow(clippy::as_conversions, trivial_casts)] // False positive + fn deserialize(&mut self) -> Result, Box> { if let Some(length) = self.length() { if self.buffer.len() >= length { // split off the correct amount of data - let data = self.buffer.split_to(length).reader(); + let data = self.buffer.split_to(length); // reset the length self.length = 0; // deserialize message - // TODO: configure bincode, for example make it bounded - bincode::deserialize_from::<_, M>(data) + self.format + .deserialize_owned(&data) .map(Some) - .map_err(|error| *error) + .map_err(|err| Box::new(err) as Box) } else { Ok(None) } @@ -134,14 +148,18 @@ impl ReceiverStream { } } -impl Stream for ReceiverStream { +impl Stream for ReceiverStream +where + F: OwnedDeserializer, + F::Error: SerializationError, +{ type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { use futures_util::ready; // did already have enough data to return a message without polling? - if let Some(message) = self.deserialize()? { + if let Some(message) = self.deserialize().map_err(error::Receiver::Deserialize)? { // send back the message return Poll::Ready(Some(Ok(message))); } @@ -154,7 +172,7 @@ impl Stream for ReceiverStream { loop { if ready!(self.poll(cx)?).is_some() { // The stream received some data, but we may not have a full packet. - if let Some(message) = self.deserialize()? { + if let Some(message) = self.deserialize().map_err(error::Receiver::Deserialize)? { break Poll::Ready(Some(Ok(message))); } } else { @@ -166,7 +184,11 @@ impl Stream for ReceiverStream { } } -impl FusedStream for ReceiverStream { +impl FusedStream for ReceiverStream +where + F: OwnedDeserializer, + F::Error: SerializationError, +{ fn is_terminated(&self) -> bool { self.complete } diff --git a/src/quic/connection/sender.rs b/src/quic/connection/sender.rs index 80e248d..d2eefe2 100644 --- a/src/quic/connection/sender.rs +++ b/src/quic/connection/sender.rs @@ -5,16 +5,18 @@ use std::{marker::PhantomData, mem::size_of}; use bytes::{BufMut, Bytes, BytesMut}; use futures_util::StreamExt; use quinn::{SendStream, VarInt}; -use serde::Serialize; +use transmog::Format; use super::Task; -use crate::error; +use crate::error::{self, SerializationError}; /// Used to send data to a stream. #[derive(Clone, Debug)] -pub struct Sender { +pub struct Sender { /// Send [`Serialize`]d data to the sending task. sender: flume::Sender, + /// The serialization format. + format: F, /// Holds the type to [`Serialize`] too. _type: PhantomData, /// [`Task`] handle that does the sending into the stream. @@ -32,10 +34,14 @@ enum Message { Close, } -impl Sender { +impl Sender +where + F: Format<'static, T>, + F::Error: SerializationError, +{ /// Builds a new [`Sender`] from a raw [`quinn`] type. Spawns a task that /// sends data into the stream. - pub(super) fn new(mut stream_sender: SendStream) -> Self { + pub(super) fn new(mut stream_sender: SendStream, format: F) -> Self { // sender channels let (sender, receiver) = flume::unbounded(); @@ -68,6 +74,7 @@ impl Sender { Self { sender, + format, _type: PhantomData, task, } @@ -81,7 +88,7 @@ impl Sender { /// stream /// - [`error::Sender::Closed`] if the [`Sender`] is closed pub fn send(&self, data: &T) -> Result<(), error::Sender> { - self.send_any(data) + self.send_any(data, &self.format) } /// Send any `data` into the stream. This will fail on the receiving end if @@ -93,37 +100,67 @@ impl Sender { /// stream /// - [`error::Sender::Closed`] if the [`Sender`] is closed #[allow(clippy::panic_in_result_fn, clippy::unwrap_in_result)] - pub(super) fn send_any(&self, data: &A) -> Result<(), error::Sender> { + pub(super) fn send_any( + &self, + data: &A, + format: &AnyFormat, + ) -> Result<(), error::Sender> + where + AnyFormat: Format<'static, A>, + AnyFormat::Error: SerializationError, + { let mut bytes = BytesMut::new(); // get size - let len = bincode::serialized_size(&data)?; - // reserve an appropriate amount of space - bytes.reserve( - usize::try_from(len) - .expect("not a 64-bit system") - .checked_add(size_of::()) - .expect("data trying to be sent is too big"), - ); - // insert length first, this enables framing - bytes.put_u64_le(len); - - let mut bytes = bytes.writer(); - - // serialize `data` into `bytes` - bincode::serialize_into(&mut bytes, &data)?; - - // send data to task - let bytes = bytes.into_inner().freeze(); - - // make sure that our length is correct - debug_assert_eq!( - u64::try_from(bytes.len()).expect("not a 64-bit system"), - u64::try_from(size_of::()) - .expect("not a 64-bit system") - .checked_add(len) - .expect("message to long") - ); + let bytes = if let Some(len) = format + .serialized_size(data) + .map_err(error::Sender::from_serialization)? + { + // reserve an appropriate amount of space + bytes.reserve( + len.checked_add(size_of::()) + .expect("data trying to be sent is too big"), + ); + // insert length first, this enables framing + + let len = u64::try_from(len).expect("not a 64-bit system"); + bytes.put_u64_le(len); + + let mut bytes = bytes.writer(); + + // serialize `data` into `bytes` + format + .serialize_into(data, &mut bytes) + .map_err(error::Sender::from_serialization)?; + + let bytes = bytes.into_inner().freeze(); + // make sure that our length is correct + debug_assert_eq!( + u64::try_from(bytes.len()).expect("not a 64-bit system"), + u64::try_from(size_of::()) + .expect("not a 64-bit system") + .checked_add(len) + .expect("message to long") + ); + bytes + } else { + bytes.put_u64_le(0); + let mut bytes = bytes.writer(); + format + .serialize_into(data, &mut bytes) + .map_err(error::Sender::from_serialization)?; + let mut bytes = bytes.into_inner(); + let serialized_length = bytes + .len() + .checked_sub(size_of::()) + .expect("negative bytes written"); + let serialized_length = u64::try_from(serialized_length).expect("not a 64-bit system"); + bytes + .get_mut(0..8) + .expect("bytes never allocated") + .copy_from_slice(&serialized_length.to_le_bytes()); + bytes.freeze() + }; // if the sender task has been dropped, return it's error if self.sender.send(bytes).is_err() { diff --git a/src/quic/endpoint/builder/mod.rs b/src/quic/endpoint/builder/mod.rs index 87222b8..368e354 100644 --- a/src/quic/endpoint/builder/mod.rs +++ b/src/quic/endpoint/builder/mod.rs @@ -49,6 +49,7 @@ pub struct Builder { impl Default for Builder { fn default() -> Self { + // TODO configure for sane max allocations Self::new() } } @@ -505,11 +506,12 @@ impl Builder { false, ) { Ok(client) => client, - Err(error) => + Err(error) => { return Err(error::Builder { error: error.into(), builder: self, - }), + }) + } }; // build server only if we have a key-pair @@ -662,6 +664,7 @@ mod test { use futures_util::StreamExt; use quinn::ConnectionError; use quinn_proto::TransportError; + use transmog_bincode::Bincode; use trust_dns_proto::error::ProtoErrorKind; use trust_dns_resolver::error::ResolveErrorKind; @@ -713,7 +716,7 @@ mod test { None, ) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; // test receiving client on server @@ -721,7 +724,7 @@ mod test { .next() .await .expect("server dropped") - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; Ok(()) @@ -734,9 +737,10 @@ mod test { // build client let mut builder = Builder::new(); - Dangerous::set_root_certificates(&mut builder, [server_key_pair - .end_entity_certificate() - .clone()]); + Dangerous::set_root_certificates( + &mut builder, + [server_key_pair.end_entity_certificate().clone()], + ); builder.set_client_key_pair(Some(client_key_pair.clone())); let client = builder.build()?; @@ -752,7 +756,7 @@ mod test { server.local_address()?.port() )) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; // test receiving client on server @@ -760,7 +764,7 @@ mod test { .next() .await .expect("server dropped") - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; // validate client certificate @@ -819,7 +823,7 @@ mod test { ); // check protocol on `Connection` - let connection = connecting.accept::<()>().await?; + let connection = connecting.accept::<(), _>(Bincode::default()).await?; assert_eq!( protocols[0], connection.protocol().expect("no protocol found") @@ -835,7 +839,7 @@ mod test { ); // check protocol on `Connection` - let connection = connecting.accept::<()>().await?; + let connection = connecting.accept::<(), _>(Bincode::default()).await?; assert_eq!( protocols[0], connection.protocol().expect("no protocol found") @@ -867,7 +871,7 @@ mod test { None, ) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await; // check result @@ -1003,7 +1007,7 @@ mod test { assert!(endpoint .connect("https://cloudflare-quic.com:443") .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await .is_ok()); @@ -1030,7 +1034,7 @@ mod test { assert!(endpoint .connect("https://cloudflare-quic.com:443") .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await .is_ok()); @@ -1057,7 +1061,7 @@ mod test { let result = endpoint .connect("https://cloudflare-quic.com:443") .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await; // check result diff --git a/src/quic/endpoint/mod.rs b/src/quic/endpoint/mod.rs index c723e7b..ea3f13d 100644 --- a/src/quic/endpoint/mod.rs +++ b/src/quic/endpoint/mod.rs @@ -615,6 +615,7 @@ mod test { use futures_util::StreamExt; use quinn::{ConnectionClose, ConnectionError}; use quinn_proto::TransportErrorCode; + use transmog_bincode::Bincode; use super::*; use crate::KeyPair; @@ -638,13 +639,13 @@ mod test { None, ) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; let _connection = server .next() .await .expect("client dropped") - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; Ok(()) @@ -663,13 +664,13 @@ mod test { None, ) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; let _connection = server .next() .await .expect("client dropped") - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; Ok(()) @@ -712,13 +713,13 @@ mod test { let _connection = client .connect_pinned(&address, key_pair.end_entity_certificate(), None) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; let _connection = server .next() .await .expect("client dropped") - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; // closing the client/server will close all connection immediately @@ -730,7 +731,7 @@ mod test { client .connect_pinned(address, key_pair.end_entity_certificate(), None) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await, Err(error::Connecting::Connection( ConnectionError::LocallyClosed @@ -759,13 +760,13 @@ mod test { let client_connection = client .connect_pinned(&address, key_pair.end_entity_certificate(), None) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; let mut server_connection = server .next() .await .expect("client dropped") - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; // refuse new incoming connections @@ -784,7 +785,7 @@ mod test { let result = client .connect_pinned(address, key_pair.end_entity_certificate(), None) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await; assert!(matches!( result, @@ -837,13 +838,13 @@ mod test { None, ) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; let _connection = server .next() .await .expect("client dropped") - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; } diff --git a/src/x509/mod.rs b/src/x509/mod.rs index c173106..76cfe99 100644 --- a/src/x509/mod.rs +++ b/src/x509/mod.rs @@ -182,7 +182,7 @@ where #[test] fn serialize() -> anyhow::Result<()> { - use bincode::{config::DefaultOptions, Options, Serializer}; + use transmog_bincode::bincode::{self, config::DefaultOptions, Options, Serializer}; let key_pair = KeyPair::new_self_signed("test"); diff --git a/src/x509/private_key.rs b/src/x509/private_key.rs index 292db04..2884c44 100644 --- a/src/x509/private_key.rs +++ b/src/x509/private_key.rs @@ -129,7 +129,7 @@ impl Dangerous for PrivateKey { #[cfg(test)] mod test { use anyhow::Result; - use bincode::{config::DefaultOptions, Options, Serializer}; + use transmog_bincode::bincode::{self, config::DefaultOptions, Options, Serializer}; use super::*; use crate::KeyPair; diff --git a/tests/local-connections.rs b/tests/local-connections.rs new file mode 100644 index 0000000..d89bb60 --- /dev/null +++ b/tests/local-connections.rs @@ -0,0 +1,179 @@ +use anyhow::{Error, Result}; +use fabruic::{Endpoint, KeyPair}; +use futures_util::{future, StreamExt, TryFutureExt}; +use transmog::{Format, OwnedDeserializer}; +use transmog_bincode::Bincode; +use transmog_pot::Pot; + +const SERVER_NAME: &str = "test"; +const CLIENTS: usize = 100; + +async fn simulate_client_and_server(format: F) -> Result<()> +where + F: OwnedDeserializer<()> + OwnedDeserializer + Clone + 'static, + >::Error: Send + Sync + 'static, + >::Error: Send + Sync + 'static, +{ + // collect all tasks + let mut clients = Vec::with_capacity(CLIENTS); + + // generate a certificate pair + let key_pair = KeyPair::new_self_signed(SERVER_NAME); + + // build the server + // we want to do this outside to reserve the `SERVER_PORT`, otherwise spawned + // clients may take it + let mut server = Endpoint::new_server(0, key_pair.clone())?; + let address = format!("quic://{}", server.local_address()?); + + // start 100 clients + for index in 0..CLIENTS { + let address = address.clone(); + let certificate = key_pair.end_entity_certificate().clone(); + + let task_format = format.clone(); + clients.push( + tokio::spawn(async move { + // build a client + let client = Endpoint::new_client()?; + + let connecting = client.connect_pinned(address, &certificate, None).await?; + println!( + "[client:{}] Connecting to {}", + index, + connecting.remote_address() + ); + let connection = connecting.accept::<(), _>(task_format).await?; + println!( + "[client:{}] Successfully connected to {}", + index, + connection.remote_address() + ); + connection.close_incoming().await?; + + // initiate a stream + let (sender, mut receiver) = connection.open_stream::(&()).await?; + println!( + "[client:{}] Successfully opened stream to {}", + index, + connection.remote_address() + ); + + // send message + sender.send(&format!("hello from client {}", index))?; + + // start listening to new incoming messages + // in this example we know there is only 1 incoming message, so we will + // not wait for more + let message = receiver.next().await.expect("no message found")?; + println!( + "[client:{}] New message from {}: {}", + index, + connection.remote_address(), + message + ); + + // wait for stream to finish + sender.finish().await?; + receiver.finish().await?; + + // wait for client to finish cleanly + client.wait_idle().await; + println!( + "[client:{}] Successfully finished {}", + index, + client.local_address()? + ); + + Result::<_, Error>::Ok(()) + }) + .map_err(Error::from) + .and_then(future::ready), + ); + } + + // start the server + println!("[server] Listening on {}", server.local_address()?); + + // collect incoming connection tasks + let mut connections = Vec::with_capacity(CLIENTS); + + // start listening to new incoming connections + // in this example we know there is `CLIENTS` number of clients, so we will not + // wait for more + for _ in 0..CLIENTS { + let connecting = server.next().await.expect("connection failed"); + + println!( + "[server] New incoming Connection: {}", + connecting.remote_address() + ); + + let task_format = format.clone(); + // every new incoming connections is handled in it's own task + connections.push( + tokio::spawn(async move { + let mut connection = connecting.accept::<(), _>(task_format.clone()).await?; + println!("[server] New Connection: {}", connection.remote_address()); + + // start listening to new incoming streams + // in this example we know there is only 1 incoming stream, so we will not wait + // for more + let incoming = connection.next().await.expect("no stream found")?; + connection.close_incoming().await?; + println!( + "[server] New incoming stream from: {}", + connection.remote_address() + ); + + // accept stream + let (sender, mut receiver) = incoming.accept::().await?; + + // start listening to new incoming messages + // in this example we know there is only 1 incoming message, so we will not wait + // for more + let message = receiver.next().await.expect("no message found")?; + println!( + "[server] New message from {}: {}", + connection.remote_address(), + message + ); + + // respond + sender.send(&String::from("hello from server"))?; + + // wait for stream to finish + sender.finish().await?; + receiver.finish().await?; + + Result::<_, Error>::Ok(()) + }) + .map_err(Error::from) + .and_then(future::ready), + ); + } + + server.close_incoming().await?; + + // wait for all connections to finish + future::try_join_all(connections).await?; + + // wait for server to finish cleanly + server.wait_idle().await; + println!("[server] Successfully finished {}", server.local_address()?); + + future::try_join_all(clients).await?; + + Ok(()) +} + +#[tokio::test] +async fn format_without_serialized_size() -> Result<()> { + simulate_client_and_server(Pot::default()).await +} + +#[tokio::main] +#[cfg_attr(test, test)] +async fn format_with_serialized_size() -> Result<()> { + simulate_client_and_server(Bincode::default()).await +}