diff --git a/Cargo.toml b/Cargo.toml index ab43b6f..0cd1fba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ trust-dns = ["trust-dns-resolver"] [dependencies] async-trait = "0.1" -bincode = "1" +transmog = "0.1.0-dev.2" bytes = "1" ct-logs = "0.9" flume = "0.10" @@ -54,6 +54,8 @@ fabruic = { path = "", features = ["rcgen", "test"] } quinn-proto = { version = "0.8", default-features = false } tokio = { version = "1", features = ["macros"] } trust-dns-proto = "0.21.0-alpha.4" +transmog-bincode = { version = "0.1.0-dev.2" } +transmog-pot = { version = "0.1.0-dev.2" } [profile.release] codegen-units = 1 diff --git a/examples/basic.rs b/examples/basic.rs index c22dea8..30a5539 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -1,6 +1,7 @@ use anyhow::{Error, Result}; use fabruic::{Endpoint, KeyPair}; use futures_util::{future, StreamExt, TryFutureExt}; +use transmog_bincode::Bincode; const SERVER_NAME: &str = "test"; /// Some random port. @@ -38,7 +39,7 @@ async fn main() -> Result<()> { index, connecting.remote_address() ); - let connection = connecting.accept::<()>().await?; + let connection = connecting.accept::<(), _>(Bincode::default()).await?; println!( "[client:{}] Successfully connected to {}", index, @@ -107,7 +108,7 @@ async fn main() -> Result<()> { // every new incoming connections is handled in it's own task connections.push( tokio::spawn(async move { - let mut connection = connecting.accept::<()>().await?; + let mut connection = connecting.accept::<(), _>(Bincode::default()).await?; println!("[server] New Connection: {}", connection.remote_address()); // start listening to new incoming streams diff --git a/src/error.rs b/src/error.rs index c3668b7..7379316 100644 --- a/src/error.rs +++ b/src/error.rs @@ -8,11 +8,10 @@ // TODO: error type is becoming too big, split it up use std::{ - fmt::{self, Debug, Formatter}, + fmt::{self, Debug, Display, Formatter}, io, }; -pub use bincode::ErrorKind; use quinn::ConnectionClose; pub use quinn::{ConnectError, ConnectionError, ReadError, WriteError}; use thiserror::Error; @@ -207,7 +206,9 @@ impl From for Connecting { reason, }) if reason.as_ref() == b"peer doesn't support any known protocol" && error_code.to_string() == "the cryptographic handshake failed: error 120" => - Self::ProtocolMismatch, + { + Self::ProtocolMismatch + } other => Self::Connection(other), } } @@ -241,6 +242,7 @@ pub enum Incoming { /// Error receiving a message from a [`Receiver`](crate::Receiver). #[derive(Debug, Error)] +#[allow(variant_size_differences)] pub enum Receiver { /// Failed to read from a [`Receiver`](crate::Receiver). #[error("Error reading from `Receiver`: {0}")] @@ -248,16 +250,17 @@ pub enum Receiver { /// Failed to [`Deserialize`](serde::Deserialize) a message from a /// [`Receiver`](crate::Receiver). #[error("Error deserializing a message from `Receiver`: {0}")] - Deserialize(#[from] ErrorKind), + Deserialize(Box), } /// Error sending a message to a [`Sender`](crate::Sender). #[derive(Debug, Error)] +#[allow(variant_size_differences)] pub enum Sender { /// Failed to [`Serialize`](serde::Serialize) a message for a /// [`Sender`](crate::Sender). #[error("Error serializing a message to `Sender`: {0}")] - Serialize(ErrorKind), + Serialize(Box), /// Failed to write to a [`Sender`](crate::Sender). #[error("Error writing to `Sender`: {0}")] Write(#[from] WriteError), @@ -266,8 +269,14 @@ pub enum Sender { Closed(#[from] AlreadyClosed), } -impl From> for Sender { - fn from(error: Box) -> Self { - Self::Serialize(*error) +impl Sender { + /// Returns a new instance after boxing `err`. + pub(crate) fn from_serialization(err: E) -> Self { + Self::Serialize(Box::new(err)) } } + +/// An error raised from serialization. +pub trait SerializationError: Display + Debug + Send + Sync + 'static {} + +impl SerializationError for T where T: Display + Debug + Send + Sync + 'static {} diff --git a/src/quic/connection/connecting.rs b/src/quic/connection/connecting.rs index dab07ba..c7c7797 100644 --- a/src/quic/connection/connecting.rs +++ b/src/quic/connection/connecting.rs @@ -4,9 +4,12 @@ use std::net::SocketAddr; use quinn::{crypto::rustls::HandshakeData, NewConnection}; -use serde::{de::DeserializeOwned, Serialize}; +use transmog::OwnedDeserializer; -use crate::{error, Connection}; +use crate::{ + error::{self, SerializationError}, + Connection, +}; /// Represent's an intermediate state to build a [`Connection`]. #[must_use = "`Connecting` does nothing unless accepted with `Connecting::accept`"] @@ -47,9 +50,12 @@ impl Connecting { /// /// # Errors /// [`error::Connecting`] if the [`Connection`] failed to be established. - pub async fn accept( - self, - ) -> Result, error::Connecting> { + pub async fn accept(self, format: F) -> Result, error::Connecting> + where + T: Send + 'static, + F: OwnedDeserializer + Clone, + F::Error: SerializationError, + { self.0 .await .map( @@ -57,7 +63,7 @@ impl Connecting { connection, bi_streams, .. - }| Connection::new(connection, bi_streams), + }| Connection::new(connection, bi_streams, format), ) .map_err(error::Connecting::from) } diff --git a/src/quic/connection/incoming.rs b/src/quic/connection/incoming.rs index 5622935..b7ba66d 100644 --- a/src/quic/connection/incoming.rs +++ b/src/quic/connection/incoming.rs @@ -4,24 +4,30 @@ use std::fmt::{self, Debug, Formatter}; use futures_util::StreamExt; use quinn::{RecvStream, SendStream}; -use serde::{de::DeserializeOwned, Serialize}; +use transmog::{Format, OwnedDeserializer}; use super::ReceiverStream; -use crate::{error, Receiver, Sender}; +use crate::{ + error::{self, SerializationError}, + Receiver, Sender, +}; /// An intermediate state to define which type to accept in this stream. See /// [`accept_stream`](Self::accept). #[must_use = "`Incoming` does nothing unless accepted with `Incoming::accept`"] -pub struct Incoming { +pub struct Incoming> { /// [`SendStream`] to build [`Sender`]. sender: SendStream, /// [`RecvStream`] to build [`Receiver`]. - receiver: ReceiverStream, + receiver: ReceiverStream, /// Requested type. r#type: Option>, } -impl Debug for Incoming { +impl Debug for Incoming +where + F: OwnedDeserializer, +{ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("Incoming") .field("sender", &self.sender) @@ -31,12 +37,16 @@ impl Debug for Incoming { } } -impl Incoming { +impl Incoming +where + F: OwnedDeserializer + Clone, + F::Error: SerializationError, +{ /// Builds a new [`Incoming`] from raw [`quinn`] types. - pub(super) fn new(sender: SendStream, receiver: RecvStream) -> Self { + pub(super) fn new(sender: SendStream, receiver: RecvStream, format: F) -> Self { Self { sender, - receiver: ReceiverStream::new(receiver), + receiver: ReceiverStream::new(receiver, format), r#type: None, } } @@ -80,12 +90,57 @@ impl Incoming { /// - [`error::Incoming::Receiver`] if receiving the type information to the /// peer failed, see [`error::Receiver`] for more details /// - [`error::Incoming::Closed`] if the stream was closed - pub async fn accept< - S: DeserializeOwned + Serialize + Send + 'static, - R: DeserializeOwned + Serialize + Send + 'static, - >( + pub async fn accept( + self, + ) -> Result<(Sender, Receiver), error::Incoming> + where + F: OwnedDeserializer + Format<'static, S> + 'static, + >::Error: SerializationError, + >::Error: SerializationError, + { + let format = self.receiver.format.clone(); + self.accept_with_format(format).await + } + + /// Accept the incoming stream with the given types. + /// + /// Use `S` and `R` to define which type this stream is sending and + /// receiving. + /// + /// # Errors + /// - [`error::Incoming::Receiver`] if receiving the type information to the + /// peer failed, see [`error::Receiver`] for more details + /// - [`error::Incoming::Closed`] if the stream was closed + pub async fn accept_raw( + self, + ) -> Result<(Sender, Receiver), error::Incoming> + where + F: OwnedDeserializer + Format<'static, S> + 'static, + >::Error: SerializationError, + >::Error: SerializationError, + { + let format = self.receiver.format.clone(); + self.accept_with_format(format).await + } + + /// Accept the incoming stream with the given types. + /// + /// Use `S` and `R` to define which type this stream is sending and + /// receiving. + /// + /// # Errors + /// - [`error::Incoming::Receiver`] if receiving the type information to the + /// peer failed, see [`error::Receiver`] for more details + /// - [`error::Incoming::Closed`] if the stream was closed + pub async fn accept_with_format( mut self, - ) -> Result<(Sender, Receiver), error::Incoming> { + format: NewFormat, + ) -> Result<(Sender, Receiver), error::Incoming> + where + NewFormat: OwnedDeserializer + Format<'static, S> + Clone + 'static, + >::Error: SerializationError, + >::Error: SerializationError, + { match self.r#type { Some(Ok(_)) => (), Some(Err(error)) => return Err(error), @@ -100,8 +155,8 @@ impl Incoming { } } - let sender = Sender::new(self.sender); - let receiver = Receiver::new(self.receiver.transmute()); + let sender = Sender::new(self.sender, format.clone()); + let receiver = Receiver::new(self.receiver.transmute(format)); Ok((sender, receiver)) } diff --git a/src/quic/connection/mod.rs b/src/quic/connection/mod.rs index 9521afc..508a10a 100644 --- a/src/quic/connection/mod.rs +++ b/src/quic/connection/mod.rs @@ -34,16 +34,23 @@ use quinn::{crypto::rustls::HandshakeData, IncomingBiStreams, VarInt}; pub use receiver::Receiver; use receiver_stream::ReceiverStream; pub use sender::Sender; -use serde::{de::DeserializeOwned, Serialize}; use stream::Stream; +use transmog::{Format, OwnedDeserializer}; use super::Task; -use crate::{error, CertificateChain}; +use crate::{ + error::{self, SerializationError}, + CertificateChain, +}; /// Represents an open connection. Receives [`Incoming`] through [`Stream`]. #[pin_project] #[derive(Clone)] -pub struct Connection { +pub struct Connection +where + T: Send + 'static, + F: OwnedDeserializer, +{ /// Initiate new connections or close socket. connection: quinn::Connection, /// Receive incoming streams. @@ -51,11 +58,17 @@ pub struct Connection { RecvStream<'static, Result<(quinn::SendStream, quinn::RecvStream), error::Connection>>, /// [`Task`] handling new incoming streams. task: Task<()>, + /// Serialization foramt. + format: F, /// Type for type negotiation for new streams. types: PhantomData, } -impl Debug for Connection { +impl Debug for Connection +where + T: Send + 'static, + F: OwnedDeserializer, +{ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("Connection") .field("connection", &self.connection) @@ -66,10 +79,19 @@ impl Debug for Connection { } } -impl Connection { +impl Connection +where + T: Send + 'static, + F: OwnedDeserializer + Clone, + F::Error: SerializationError, +{ /// Builds a new [`Connection`] from raw [`quinn`] types. #[allow(clippy::mut_mut)] // futures_util::select_biased internal usage - pub(super) fn new(connection: quinn::Connection, bi_streams: IncomingBiStreams) -> Self { + pub(super) fn new( + connection: quinn::Connection, + bi_streams: IncomingBiStreams, + format: F, + ) -> Self { // channels for passing down new `Incoming` `Connection`s let (sender, receiver) = flume::unbounded(); let receiver = receiver.into_stream(); @@ -96,6 +118,7 @@ impl Connection { connection, receiver, task, + format, types: PhantomData, } } @@ -110,19 +133,51 @@ impl Connection { /// - [`error::Stream::Open`] if opening a stream failed /// - [`error::Stream::Sender`] if sending the type information to the peer /// failed, see [`error::Sender`] for more details - pub async fn open_stream< - S: DeserializeOwned + Serialize + Send + 'static, - R: DeserializeOwned + Serialize + Send + 'static, - >( + pub async fn open_stream( + &self, + r#type: &T, + ) -> Result<(Sender, Receiver), error::Stream> + where + F: OwnedDeserializer + Format<'static, S> + Format<'static, T> + 'static, + >::Error: SerializationError, + >::Error: SerializationError, + { + let (sender, receiver) = self.connection.open_bi().await?; + + let sender = Sender::new(sender, self.format.clone()); + let receiver = Receiver::new(ReceiverStream::new(receiver, self.format.clone())); + + sender.send_any(r#type, &self.format)?; + + Ok((sender, receiver)) + } + + /// Open a stream on this [`Connection`], allowing to send data back and + /// forth. + /// + /// Use `S` and `R` to define which type this stream is sending and + /// receiving and `type` to send this information to the receiver. + /// + /// # Errors + /// - [`error::Stream::Open`] if opening a stream failed + /// - [`error::Stream::Sender`] if sending the type information to the peer + /// failed, see [`error::Sender`] for more details + pub async fn open_stream_with_format( &self, r#type: &T, - ) -> Result<(Sender, Receiver), error::Stream> { + format: NewFormat, + ) -> Result<(Sender, Receiver), error::Stream> + where + NewFormat: OwnedDeserializer + Format<'static, S> + Clone + 'static, + >::Error: SerializationError, + >::Error: SerializationError, + { let (sender, receiver) = self.connection.open_bi().await?; - let sender = Sender::new(sender); - let receiver = Receiver::new(ReceiverStream::new(receiver)); + let sender = Sender::new(sender, format.clone()); + let receiver = Receiver::new(ReceiverStream::new(receiver, format)); - sender.send_any(&r#type)?; + sender.send_any(r#type, &self.format)?; Ok((sender, receiver)) } @@ -175,8 +230,13 @@ impl Connection { } } -impl Stream for Connection { - type Item = Result, error::Connection>; +impl Stream for Connection +where + T: Send + 'static, + F: OwnedDeserializer + Clone, + F::Error: SerializationError, +{ + type Item = Result, error::Connection>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.receiver.is_terminated() { @@ -184,12 +244,17 @@ impl Stream for Connection } else { self.receiver .poll_next_unpin(cx) - .map_ok(|(sender, receiver)| Incoming::new(sender, receiver)) + .map_ok(|(sender, receiver)| Incoming::new(sender, receiver, self.format.clone())) } } } -impl FusedStream for Connection { +impl FusedStream for Connection +where + T: Send + 'static, + F: OwnedDeserializer + Clone, + F::Error: SerializationError, +{ fn is_terminated(&self) -> bool { self.receiver.is_terminated() } diff --git a/src/quic/connection/receiver.rs b/src/quic/connection/receiver.rs index 9cbccdd..f9d2fa4 100644 --- a/src/quic/connection/receiver.rs +++ b/src/quic/connection/receiver.rs @@ -7,10 +7,10 @@ use std::{ }; use futures_util::{stream::Stream, StreamExt}; -use serde::de::DeserializeOwned; +use transmog::OwnedDeserializer; use super::{ReceiverStream, Task}; -use crate::error; +use crate::error::{self, SerializationError}; /// Used to receive data from a stream. Will stop receiving message if /// deserialization failed. @@ -31,13 +31,17 @@ impl Debug for Receiver { } } -impl Receiver { +impl Receiver +where + T: Send, +{ /// Builds a new [`Receiver`] from a raw [`quinn`] type. Spawns a task that /// receives data from the stream. #[allow(clippy::mut_mut)] // futures_util::select_biased internal usage - pub(super) fn new(mut stream: ReceiverStream) -> Self + pub(super) fn new(mut stream: ReceiverStream) -> Self where - T: DeserializeOwned + Send, + F: OwnedDeserializer + 'static, + F::Error: SerializationError, { // receiver channels let (sender, receiver) = flume::unbounded(); diff --git a/src/quic/connection/receiver_stream.rs b/src/quic/connection/receiver_stream.rs index 6683647..bb9646e 100644 --- a/src/quic/connection/receiver_stream.rs +++ b/src/quic/connection/receiver_stream.rs @@ -8,7 +8,6 @@ use std::{ task::{Context, Poll}, }; -use bincode::ErrorKind; use bytes::{Buf, BufMut, BytesMut}; use futures_util::{ stream::{FusedStream, Stream}, @@ -16,13 +15,16 @@ use futures_util::{ }; use pin_project::pin_project; use quinn::{Chunk, ReadError, RecvStream, VarInt}; -use serde::de::DeserializeOwned; +use transmog::OwnedDeserializer; -use crate::error; +use crate::error::{self, SerializationError}; /// Wrapper around [`RecvStream`] providing framing and deserialization. #[pin_project] -pub(super) struct ReceiverStream { +pub(super) struct ReceiverStream +where + F: OwnedDeserializer, +{ /// Store length of the currently processing message. length: usize, /// Store incoming chunks. @@ -31,30 +33,41 @@ pub(super) struct ReceiverStream { stream: RecvStream, /// True if the stream is complete. complete: bool, + /// The deserialization format. + pub(super) format: F, /// Type to be [`Deserialize`](serde::Deserialize)d _type: PhantomData, } -impl ReceiverStream { +impl ReceiverStream +where + F: OwnedDeserializer, + F::Error: SerializationError + 'static, +{ /// Builds a new [`ReceiverStream`]. - pub(super) fn new(stream: RecvStream) -> Self { + pub(super) fn new(stream: RecvStream, format: F) -> Self { Self { length: 0, // 1480 bytes is a default MTU size configured by quinn-proto buffer: BytesMut::with_capacity(1480), stream, complete: false, + format, _type: PhantomData, } } /// Transmutes this [`ReceiverStream`] to a different message type. - pub(super) fn transmute(self) -> ReceiverStream { + pub(super) fn transmute(self, format: NewFormat) -> ReceiverStream + where + NewFormat: OwnedDeserializer, + { ReceiverStream { length: self.length, buffer: self.buffer, stream: self.stream, complete: self.complete, + format, _type: PhantomData, } } @@ -112,19 +125,20 @@ impl ReceiverStream { /// # Errors /// [`ErrorKind`] if `data` failed to be /// [`Deserialize`](serde::Deserialize)d. - fn deserialize(&mut self) -> Result, ErrorKind> { + #[allow(clippy::as_conversions, trivial_casts)] // False positive + fn deserialize(&mut self) -> Result, Box> { if let Some(length) = self.length() { if self.buffer.len() >= length { // split off the correct amount of data - let data = self.buffer.split_to(length).reader(); + let data = self.buffer.split_to(length); // reset the length self.length = 0; // deserialize message - // TODO: configure bincode, for example make it bounded - bincode::deserialize_from::<_, M>(data) + self.format + .deserialize_owned(&data) .map(Some) - .map_err(|error| *error) + .map_err(|err| Box::new(err) as Box) } else { Ok(None) } @@ -134,14 +148,18 @@ impl ReceiverStream { } } -impl Stream for ReceiverStream { +impl Stream for ReceiverStream +where + F: OwnedDeserializer, + F::Error: SerializationError, +{ type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { use futures_util::ready; // did already have enough data to return a message without polling? - if let Some(message) = self.deserialize()? { + if let Some(message) = self.deserialize().map_err(error::Receiver::Deserialize)? { // send back the message return Poll::Ready(Some(Ok(message))); } @@ -154,7 +172,7 @@ impl Stream for ReceiverStream { loop { if ready!(self.poll(cx)?).is_some() { // The stream received some data, but we may not have a full packet. - if let Some(message) = self.deserialize()? { + if let Some(message) = self.deserialize().map_err(error::Receiver::Deserialize)? { break Poll::Ready(Some(Ok(message))); } } else { @@ -166,7 +184,11 @@ impl Stream for ReceiverStream { } } -impl FusedStream for ReceiverStream { +impl FusedStream for ReceiverStream +where + F: OwnedDeserializer, + F::Error: SerializationError, +{ fn is_terminated(&self) -> bool { self.complete } diff --git a/src/quic/connection/sender.rs b/src/quic/connection/sender.rs index 80e248d..d2eefe2 100644 --- a/src/quic/connection/sender.rs +++ b/src/quic/connection/sender.rs @@ -5,16 +5,18 @@ use std::{marker::PhantomData, mem::size_of}; use bytes::{BufMut, Bytes, BytesMut}; use futures_util::StreamExt; use quinn::{SendStream, VarInt}; -use serde::Serialize; +use transmog::Format; use super::Task; -use crate::error; +use crate::error::{self, SerializationError}; /// Used to send data to a stream. #[derive(Clone, Debug)] -pub struct Sender { +pub struct Sender { /// Send [`Serialize`]d data to the sending task. sender: flume::Sender, + /// The serialization format. + format: F, /// Holds the type to [`Serialize`] too. _type: PhantomData, /// [`Task`] handle that does the sending into the stream. @@ -32,10 +34,14 @@ enum Message { Close, } -impl Sender { +impl Sender +where + F: Format<'static, T>, + F::Error: SerializationError, +{ /// Builds a new [`Sender`] from a raw [`quinn`] type. Spawns a task that /// sends data into the stream. - pub(super) fn new(mut stream_sender: SendStream) -> Self { + pub(super) fn new(mut stream_sender: SendStream, format: F) -> Self { // sender channels let (sender, receiver) = flume::unbounded(); @@ -68,6 +74,7 @@ impl Sender { Self { sender, + format, _type: PhantomData, task, } @@ -81,7 +88,7 @@ impl Sender { /// stream /// - [`error::Sender::Closed`] if the [`Sender`] is closed pub fn send(&self, data: &T) -> Result<(), error::Sender> { - self.send_any(data) + self.send_any(data, &self.format) } /// Send any `data` into the stream. This will fail on the receiving end if @@ -93,37 +100,67 @@ impl Sender { /// stream /// - [`error::Sender::Closed`] if the [`Sender`] is closed #[allow(clippy::panic_in_result_fn, clippy::unwrap_in_result)] - pub(super) fn send_any(&self, data: &A) -> Result<(), error::Sender> { + pub(super) fn send_any( + &self, + data: &A, + format: &AnyFormat, + ) -> Result<(), error::Sender> + where + AnyFormat: Format<'static, A>, + AnyFormat::Error: SerializationError, + { let mut bytes = BytesMut::new(); // get size - let len = bincode::serialized_size(&data)?; - // reserve an appropriate amount of space - bytes.reserve( - usize::try_from(len) - .expect("not a 64-bit system") - .checked_add(size_of::()) - .expect("data trying to be sent is too big"), - ); - // insert length first, this enables framing - bytes.put_u64_le(len); - - let mut bytes = bytes.writer(); - - // serialize `data` into `bytes` - bincode::serialize_into(&mut bytes, &data)?; - - // send data to task - let bytes = bytes.into_inner().freeze(); - - // make sure that our length is correct - debug_assert_eq!( - u64::try_from(bytes.len()).expect("not a 64-bit system"), - u64::try_from(size_of::()) - .expect("not a 64-bit system") - .checked_add(len) - .expect("message to long") - ); + let bytes = if let Some(len) = format + .serialized_size(data) + .map_err(error::Sender::from_serialization)? + { + // reserve an appropriate amount of space + bytes.reserve( + len.checked_add(size_of::()) + .expect("data trying to be sent is too big"), + ); + // insert length first, this enables framing + + let len = u64::try_from(len).expect("not a 64-bit system"); + bytes.put_u64_le(len); + + let mut bytes = bytes.writer(); + + // serialize `data` into `bytes` + format + .serialize_into(data, &mut bytes) + .map_err(error::Sender::from_serialization)?; + + let bytes = bytes.into_inner().freeze(); + // make sure that our length is correct + debug_assert_eq!( + u64::try_from(bytes.len()).expect("not a 64-bit system"), + u64::try_from(size_of::()) + .expect("not a 64-bit system") + .checked_add(len) + .expect("message to long") + ); + bytes + } else { + bytes.put_u64_le(0); + let mut bytes = bytes.writer(); + format + .serialize_into(data, &mut bytes) + .map_err(error::Sender::from_serialization)?; + let mut bytes = bytes.into_inner(); + let serialized_length = bytes + .len() + .checked_sub(size_of::()) + .expect("negative bytes written"); + let serialized_length = u64::try_from(serialized_length).expect("not a 64-bit system"); + bytes + .get_mut(0..8) + .expect("bytes never allocated") + .copy_from_slice(&serialized_length.to_le_bytes()); + bytes.freeze() + }; // if the sender task has been dropped, return it's error if self.sender.send(bytes).is_err() { diff --git a/src/quic/endpoint/builder/mod.rs b/src/quic/endpoint/builder/mod.rs index 87222b8..368e354 100644 --- a/src/quic/endpoint/builder/mod.rs +++ b/src/quic/endpoint/builder/mod.rs @@ -49,6 +49,7 @@ pub struct Builder { impl Default for Builder { fn default() -> Self { + // TODO configure for sane max allocations Self::new() } } @@ -505,11 +506,12 @@ impl Builder { false, ) { Ok(client) => client, - Err(error) => + Err(error) => { return Err(error::Builder { error: error.into(), builder: self, - }), + }) + } }; // build server only if we have a key-pair @@ -662,6 +664,7 @@ mod test { use futures_util::StreamExt; use quinn::ConnectionError; use quinn_proto::TransportError; + use transmog_bincode::Bincode; use trust_dns_proto::error::ProtoErrorKind; use trust_dns_resolver::error::ResolveErrorKind; @@ -713,7 +716,7 @@ mod test { None, ) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; // test receiving client on server @@ -721,7 +724,7 @@ mod test { .next() .await .expect("server dropped") - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; Ok(()) @@ -734,9 +737,10 @@ mod test { // build client let mut builder = Builder::new(); - Dangerous::set_root_certificates(&mut builder, [server_key_pair - .end_entity_certificate() - .clone()]); + Dangerous::set_root_certificates( + &mut builder, + [server_key_pair.end_entity_certificate().clone()], + ); builder.set_client_key_pair(Some(client_key_pair.clone())); let client = builder.build()?; @@ -752,7 +756,7 @@ mod test { server.local_address()?.port() )) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; // test receiving client on server @@ -760,7 +764,7 @@ mod test { .next() .await .expect("server dropped") - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; // validate client certificate @@ -819,7 +823,7 @@ mod test { ); // check protocol on `Connection` - let connection = connecting.accept::<()>().await?; + let connection = connecting.accept::<(), _>(Bincode::default()).await?; assert_eq!( protocols[0], connection.protocol().expect("no protocol found") @@ -835,7 +839,7 @@ mod test { ); // check protocol on `Connection` - let connection = connecting.accept::<()>().await?; + let connection = connecting.accept::<(), _>(Bincode::default()).await?; assert_eq!( protocols[0], connection.protocol().expect("no protocol found") @@ -867,7 +871,7 @@ mod test { None, ) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await; // check result @@ -1003,7 +1007,7 @@ mod test { assert!(endpoint .connect("https://cloudflare-quic.com:443") .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await .is_ok()); @@ -1030,7 +1034,7 @@ mod test { assert!(endpoint .connect("https://cloudflare-quic.com:443") .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await .is_ok()); @@ -1057,7 +1061,7 @@ mod test { let result = endpoint .connect("https://cloudflare-quic.com:443") .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await; // check result diff --git a/src/quic/endpoint/mod.rs b/src/quic/endpoint/mod.rs index c723e7b..ea3f13d 100644 --- a/src/quic/endpoint/mod.rs +++ b/src/quic/endpoint/mod.rs @@ -615,6 +615,7 @@ mod test { use futures_util::StreamExt; use quinn::{ConnectionClose, ConnectionError}; use quinn_proto::TransportErrorCode; + use transmog_bincode::Bincode; use super::*; use crate::KeyPair; @@ -638,13 +639,13 @@ mod test { None, ) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; let _connection = server .next() .await .expect("client dropped") - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; Ok(()) @@ -663,13 +664,13 @@ mod test { None, ) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; let _connection = server .next() .await .expect("client dropped") - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; Ok(()) @@ -712,13 +713,13 @@ mod test { let _connection = client .connect_pinned(&address, key_pair.end_entity_certificate(), None) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; let _connection = server .next() .await .expect("client dropped") - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; // closing the client/server will close all connection immediately @@ -730,7 +731,7 @@ mod test { client .connect_pinned(address, key_pair.end_entity_certificate(), None) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await, Err(error::Connecting::Connection( ConnectionError::LocallyClosed @@ -759,13 +760,13 @@ mod test { let client_connection = client .connect_pinned(&address, key_pair.end_entity_certificate(), None) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; let mut server_connection = server .next() .await .expect("client dropped") - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; // refuse new incoming connections @@ -784,7 +785,7 @@ mod test { let result = client .connect_pinned(address, key_pair.end_entity_certificate(), None) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await; assert!(matches!( result, @@ -837,13 +838,13 @@ mod test { None, ) .await? - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; let _connection = server .next() .await .expect("client dropped") - .accept::<()>() + .accept::<(), _>(Bincode::default()) .await?; } diff --git a/src/x509/mod.rs b/src/x509/mod.rs index c173106..76cfe99 100644 --- a/src/x509/mod.rs +++ b/src/x509/mod.rs @@ -182,7 +182,7 @@ where #[test] fn serialize() -> anyhow::Result<()> { - use bincode::{config::DefaultOptions, Options, Serializer}; + use transmog_bincode::bincode::{self, config::DefaultOptions, Options, Serializer}; let key_pair = KeyPair::new_self_signed("test"); diff --git a/src/x509/private_key.rs b/src/x509/private_key.rs index 292db04..2884c44 100644 --- a/src/x509/private_key.rs +++ b/src/x509/private_key.rs @@ -129,7 +129,7 @@ impl Dangerous for PrivateKey { #[cfg(test)] mod test { use anyhow::Result; - use bincode::{config::DefaultOptions, Options, Serializer}; + use transmog_bincode::bincode::{self, config::DefaultOptions, Options, Serializer}; use super::*; use crate::KeyPair; diff --git a/tests/local-connections.rs b/tests/local-connections.rs new file mode 100644 index 0000000..d89bb60 --- /dev/null +++ b/tests/local-connections.rs @@ -0,0 +1,179 @@ +use anyhow::{Error, Result}; +use fabruic::{Endpoint, KeyPair}; +use futures_util::{future, StreamExt, TryFutureExt}; +use transmog::{Format, OwnedDeserializer}; +use transmog_bincode::Bincode; +use transmog_pot::Pot; + +const SERVER_NAME: &str = "test"; +const CLIENTS: usize = 100; + +async fn simulate_client_and_server(format: F) -> Result<()> +where + F: OwnedDeserializer<()> + OwnedDeserializer + Clone + 'static, + >::Error: Send + Sync + 'static, + >::Error: Send + Sync + 'static, +{ + // collect all tasks + let mut clients = Vec::with_capacity(CLIENTS); + + // generate a certificate pair + let key_pair = KeyPair::new_self_signed(SERVER_NAME); + + // build the server + // we want to do this outside to reserve the `SERVER_PORT`, otherwise spawned + // clients may take it + let mut server = Endpoint::new_server(0, key_pair.clone())?; + let address = format!("quic://{}", server.local_address()?); + + // start 100 clients + for index in 0..CLIENTS { + let address = address.clone(); + let certificate = key_pair.end_entity_certificate().clone(); + + let task_format = format.clone(); + clients.push( + tokio::spawn(async move { + // build a client + let client = Endpoint::new_client()?; + + let connecting = client.connect_pinned(address, &certificate, None).await?; + println!( + "[client:{}] Connecting to {}", + index, + connecting.remote_address() + ); + let connection = connecting.accept::<(), _>(task_format).await?; + println!( + "[client:{}] Successfully connected to {}", + index, + connection.remote_address() + ); + connection.close_incoming().await?; + + // initiate a stream + let (sender, mut receiver) = connection.open_stream::(&()).await?; + println!( + "[client:{}] Successfully opened stream to {}", + index, + connection.remote_address() + ); + + // send message + sender.send(&format!("hello from client {}", index))?; + + // start listening to new incoming messages + // in this example we know there is only 1 incoming message, so we will + // not wait for more + let message = receiver.next().await.expect("no message found")?; + println!( + "[client:{}] New message from {}: {}", + index, + connection.remote_address(), + message + ); + + // wait for stream to finish + sender.finish().await?; + receiver.finish().await?; + + // wait for client to finish cleanly + client.wait_idle().await; + println!( + "[client:{}] Successfully finished {}", + index, + client.local_address()? + ); + + Result::<_, Error>::Ok(()) + }) + .map_err(Error::from) + .and_then(future::ready), + ); + } + + // start the server + println!("[server] Listening on {}", server.local_address()?); + + // collect incoming connection tasks + let mut connections = Vec::with_capacity(CLIENTS); + + // start listening to new incoming connections + // in this example we know there is `CLIENTS` number of clients, so we will not + // wait for more + for _ in 0..CLIENTS { + let connecting = server.next().await.expect("connection failed"); + + println!( + "[server] New incoming Connection: {}", + connecting.remote_address() + ); + + let task_format = format.clone(); + // every new incoming connections is handled in it's own task + connections.push( + tokio::spawn(async move { + let mut connection = connecting.accept::<(), _>(task_format.clone()).await?; + println!("[server] New Connection: {}", connection.remote_address()); + + // start listening to new incoming streams + // in this example we know there is only 1 incoming stream, so we will not wait + // for more + let incoming = connection.next().await.expect("no stream found")?; + connection.close_incoming().await?; + println!( + "[server] New incoming stream from: {}", + connection.remote_address() + ); + + // accept stream + let (sender, mut receiver) = incoming.accept::().await?; + + // start listening to new incoming messages + // in this example we know there is only 1 incoming message, so we will not wait + // for more + let message = receiver.next().await.expect("no message found")?; + println!( + "[server] New message from {}: {}", + connection.remote_address(), + message + ); + + // respond + sender.send(&String::from("hello from server"))?; + + // wait for stream to finish + sender.finish().await?; + receiver.finish().await?; + + Result::<_, Error>::Ok(()) + }) + .map_err(Error::from) + .and_then(future::ready), + ); + } + + server.close_incoming().await?; + + // wait for all connections to finish + future::try_join_all(connections).await?; + + // wait for server to finish cleanly + server.wait_idle().await; + println!("[server] Successfully finished {}", server.local_address()?); + + future::try_join_all(clients).await?; + + Ok(()) +} + +#[tokio::test] +async fn format_without_serialized_size() -> Result<()> { + simulate_client_and_server(Pot::default()).await +} + +#[tokio::main] +#[cfg_attr(test, test)] +async fn format_with_serialized_size() -> Result<()> { + simulate_client_and_server(Bincode::default()).await +}