diff --git a/Cargo.lock b/Cargo.lock index 4c02649..29d6a67 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3220,6 +3220,7 @@ dependencies = [ "time", "tokio", "tokio-serde", + "tokio-stream", "tokio-util", "tracing", "tracing-subscriber", @@ -4466,9 +4467,9 @@ dependencies = [ [[package]] name = "tokio-stream" -version = "0.1.16" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" dependencies = [ "futures-core", "pin-project-lite", diff --git a/Cargo.toml b/Cargo.toml index 1c86857..04ae7c3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,7 @@ rustls = { version = "0.23", default-features = false, features = ["ring"], opti slab = "0.4.9" # iroh-quinn smallvec = "1.13.2" time = "0.3.36" # serde +tokio-stream = "0.1.17" [dev-dependencies] anyhow = "1.0.73" diff --git a/examples/errors.rs b/examples/errors.rs index a8e015e..5daa5e9 100644 --- a/examples/errors.rs +++ b/examples/errors.rs @@ -58,7 +58,7 @@ async fn main() -> anyhow::Result<()> { let fs = Fs; let (server, client) = quic_rpc::transport::flume::channel(1); let client = RpcClient::::new(client); - let server = RpcServer::new(server); + let mut server = RpcServer::new(server); let handle = tokio::task::spawn(async move { for _ in 0..1 { let (req, chan) = server.accept().await?.read_first().await?; diff --git a/examples/store.rs b/examples/store.rs index b99edea..106ad40 100644 --- a/examples/store.rs +++ b/examples/store.rs @@ -166,7 +166,7 @@ async fn main() -> anyhow::Result<()> { async fn server_future>( server: RpcServer, ) -> result::Result<(), RpcServerError> { - let s = server; + let mut s = server; let store = Store; loop { let (req, chan) = s.accept().await?.read_first().await?; @@ -239,7 +239,7 @@ async fn _main_unsugared() -> anyhow::Result<()> { type Req = u64; type Res = String; } - let (server, client) = flume::channel::(1); + let (mut server, client) = flume::channel::(1); let to_string_service = tokio::spawn(async move { let (mut send, mut recv) = server.accept().await?; while let Some(item) = recv.next().await { diff --git a/src/server.rs b/src/server.rs index 47d1f70..a9bffc2 100644 --- a/src/server.rs +++ b/src/server.rs @@ -64,15 +64,6 @@ pub struct RpcServer> { _p: PhantomData, } -impl Clone for RpcServer { - fn clone(&self) -> Self { - Self { - source: self.source.clone(), - _p: PhantomData, - } - } -} - impl> RpcServer { /// Create a new rpc server for a specific service for a [Service] given a compatible /// [Listener]. @@ -201,7 +192,7 @@ impl> Accepting { impl> RpcServer { /// Accepts a new channel from a client. The result is an [Accepting] object that /// can be used to read the first request. - pub async fn accept(&self) -> result::Result, RpcServerError> { + pub async fn accept(&mut self) -> result::Result, RpcServerError> { let (send, recv) = self.source.accept().await.map_err(RpcServerError::Accept)?; Ok(Accepting { send, @@ -220,7 +211,7 @@ impl> RpcServer { /// Each request will be handled in a separate task. /// /// It is the caller's responsibility to poll the returned future to drive the server. - pub async fn accept_loop(self, handler: Fun) + pub async fn accept_loop(mut self, handler: Fun) where S: Service, C: Listener, @@ -462,7 +453,7 @@ where F: FnMut(RpcChannel, S::Req, T) -> Fut + Send + 'static, Fut: Future>> + Send + 'static, { - let server: RpcServer = RpcServer::::new(conn); + let mut server: RpcServer = RpcServer::::new(conn); loop { let (req, chan) = server.accept().await?.read_first().await?; let target = target.clone(); diff --git a/src/transport/boxed.rs b/src/transport/boxed.rs index 8f4a886..7bd4abb 100644 --- a/src/transport/boxed.rs +++ b/src/transport/boxed.rs @@ -290,11 +290,8 @@ impl StreamTypes for BoxedStreamTypes /// A boxable listener pub trait BoxableListener: Debug + Send + Sync + 'static { - /// Clone the listener and box it - fn clone_box(&self) -> Box>; - /// Accept a channel from a remote client - fn accept_bi_boxed(&self) -> AcceptFuture; + fn accept_bi_boxed(&mut self) -> AcceptFuture; /// Get the local address fn local_addr(&self) -> &[super::LocalAddr]; @@ -311,12 +308,6 @@ impl BoxedListener { } } -impl Clone for BoxedListener { - fn clone(&self) -> Self { - Self(self.0.clone_box()) - } -} - impl StreamTypes for BoxedListener { type In = In; type Out = Out; @@ -333,7 +324,7 @@ impl ConnectionErrors for BoxedListener super::Listener for BoxedListener { fn accept( - &self, + &mut self, ) -> impl Future> + Send { self.0.accept_bi_boxed() @@ -378,11 +369,7 @@ impl BoxableConnector impl BoxableListener for super::quinn::QuinnListener { - fn clone_box(&self) -> Box> { - Box::new(self.clone()) - } - - fn accept_bi_boxed(&self) -> AcceptFuture { + fn accept_bi_boxed(&mut self) -> AcceptFuture { let f = async move { let (send, recv) = super::Listener::accept(self).await?; let send = send.sink_map_err(anyhow::Error::from); @@ -422,11 +409,7 @@ impl BoxableConnector impl BoxableListener for super::iroh::IrohListener { - fn clone_box(&self) -> Box> { - Box::new(self.clone()) - } - - fn accept_bi_boxed(&self) -> AcceptFuture { + fn accept_bi_boxed(&mut self) -> AcceptFuture { let f = async move { let (send, recv) = super::Listener::accept(self).await?; let send = send.sink_map_err(anyhow::Error::from); @@ -458,11 +441,7 @@ impl BoxableConnector impl BoxableListener for super::flume::FlumeListener { - fn clone_box(&self) -> Box> { - Box::new(self.clone()) - } - - fn accept_bi_boxed(&self) -> AcceptFuture { + fn accept_bi_boxed(&mut self) -> AcceptFuture { AcceptFuture::direct(super::Listener::accept(self)) } @@ -520,7 +499,7 @@ mod tests { use crate::transport::{Connector, Listener}; let (server, client) = crate::transport::flume::channel(1); - let server = super::BoxedListener::new(server); + let mut server = super::BoxedListener::new(server); let client = super::BoxedConnector::new(client); // spawn echo server tokio::spawn(async move { diff --git a/src/transport/combined.rs b/src/transport/combined.rs index 60e6843..737c13c 100644 --- a/src/transport/combined.rs +++ b/src/transport/combined.rs @@ -241,9 +241,9 @@ impl> StreamTypes for Combine } impl> Listener for CombinedListener { - async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::AcceptError> { + async fn accept(&mut self) -> Result<(Self::SendSink, Self::RecvStream), Self::AcceptError> { let a_fut = async { - if let Some(a) = &self.a { + if let Some(a) = &mut self.a { let (send, recv) = a.accept().await.map_err(AcceptError::A)?; Ok((SendSink::A(send), RecvStream::A(recv))) } else { @@ -251,7 +251,7 @@ impl> Listener for CombinedLi } }; let b_fut = async { - if let Some(b) = &self.b { + if let Some(b) = &mut self.b { let (send, recv) = b.accept().await.map_err(AcceptError::B)?; Ok((SendSink::B(send), RecvStream::B(recv))) } else { diff --git a/src/transport/flume.rs b/src/transport/flume.rs index 8db6d03..55fc09f 100644 --- a/src/transport/flume.rs +++ b/src/transport/flume.rs @@ -203,7 +203,7 @@ impl StreamTypes for FlumeListener { impl Listener for FlumeListener { #[allow(refining_impl_trait)] - fn accept(&self) -> AcceptFuture { + fn accept(&mut self) -> AcceptFuture { AcceptFuture { wrapped: self.stream.clone().into_recv_async(), _p: PhantomData, diff --git a/src/transport/hyper.rs b/src/transport/hyper.rs index 3ef4190..6f993f1 100644 --- a/src/transport/hyper.rs +++ b/src/transport/hyper.rs @@ -631,7 +631,7 @@ impl Listener for HyperListener { &self.local_addr } - async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> { + async fn accept(&mut self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> { let (recv, send) = self .channel .recv_async() diff --git a/src/transport/iroh.rs b/src/transport/iroh.rs index c084698..c6479ce 100644 --- a/src/transport/iroh.rs +++ b/src/transport/iroh.rs @@ -283,7 +283,7 @@ impl StreamTypes for IrohListener { } impl Listener for IrohListener { - async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> { + async fn accept(&mut self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> { let (send, recv) = self .inner .receiver diff --git a/src/transport/mapped.rs b/src/transport/mapped.rs index 649ac1f..d774ee8 100644 --- a/src/transport/mapped.rs +++ b/src/transport/mapped.rs @@ -298,7 +298,7 @@ mod tests { // create a listener / connector pair. Type will be inferred let (s, c) = crate::transport::flume::channel(32); // wrap the server in a RpcServer, this is where the service type is specified - let server = RpcServer::::new(s.clone()); + let mut server = RpcServer::::new(s.clone()); // when using a boxed transport, we can omit the transport type and use the default let _server_boxed: RpcServer = RpcServer::::new(s.boxed()); // create a client in a RpcClient, this is where the service type is specified diff --git a/src/transport/misc/mod.rs b/src/transport/misc/mod.rs index 59bd19c..5d31e6b 100644 --- a/src/transport/misc/mod.rs +++ b/src/transport/misc/mod.rs @@ -42,7 +42,7 @@ impl StreamTypes for DummyListener { } impl Listener for DummyListener { - async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::AcceptError> { + async fn accept(&mut self) -> Result<(Self::SendSink, Self::RecvStream), Self::AcceptError> { futures_lite::future::pending().await } diff --git a/src/transport/mod.rs b/src/transport/mod.rs index ed82fed..720aecb 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -41,12 +41,13 @@ pub mod mapped; pub mod misc; #[cfg(feature = "quinn-transport")] pub mod quinn; +pub mod tokio; #[cfg(any(feature = "quinn-transport", feature = "iroh-transport"))] mod util; /// Errors that can happen when creating and using a [`Connector`] or [`Listener`]. -pub trait ConnectionErrors: Debug + Clone + Send + Sync + 'static { +pub trait ConnectionErrors: Debug + Send + Sync + 'static { /// Error when sending a message via a channel type SendError: RpcError; /// Error when receiving a message via a channel @@ -78,7 +79,7 @@ pub trait StreamTypes: ConnectionErrors { /// A connection to a specific remote machine /// /// A connection can be used to open bidirectional typed channels using [`Connector::open`]. -pub trait Connector: StreamTypes { +pub trait Connector: StreamTypes + Clone { /// Open a channel to the remote che fn open( &self, @@ -110,7 +111,7 @@ pub trait Listener: StreamTypes { /// Accept a new typed bidirectional channel on any of the connections we /// have currently opened. fn accept( - &self, + &mut self, ) -> impl Future> + Send; /// The local addresses this endpoint is bound to. diff --git a/src/transport/quinn.rs b/src/transport/quinn.rs index 89a69e6..f4e1528 100644 --- a/src/transport/quinn.rs +++ b/src/transport/quinn.rs @@ -207,7 +207,7 @@ impl StreamTypes for QuinnListener { } impl Listener for QuinnListener { - async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> { + async fn accept(&mut self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> { let (send, recv) = self .inner .receiver diff --git a/src/transport/tokio.rs b/src/transport/tokio.rs new file mode 100644 index 0000000..c1b6624 --- /dev/null +++ b/src/transport/tokio.rs @@ -0,0 +1,326 @@ +//! Memory transport implementation using [tokio::sync::mpsc::channel]. +use core::fmt; +use std::{error, fmt::Display, pin::Pin, result, task::Poll}; + +use futures_lite::{Future, Stream}; +use futures_sink::Sink; +use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::PollSender; + +use super::StreamTypes; +use crate::{ + transport::{ConnectionErrors, Connector, Listener, LocalAddr}, + RpcMessage, +}; + +/// Error when receiving from a channel +/// +/// This type has zero inhabitants, so it is always safe to unwrap a result with this error type. +#[derive(Debug)] +pub enum RecvError {} + +impl fmt::Display for RecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +/// Sink for memory channels +pub struct SendSink(pub(crate) tokio_util::sync::PollSender); + +impl fmt::Debug for SendSink { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SendSink").finish() + } +} + +impl Sink for SendSink { + type Error = self::SendError; + + fn poll_ready( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0) + .poll_ready(cx) + .map_err(|_| SendError::ReceiverDropped) + } + + fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + Pin::new(&mut self.0) + .start_send(item) + .map_err(|_| SendError::ReceiverDropped) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0) + .poll_flush(cx) + .map_err(|_| SendError::ReceiverDropped) + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0) + .poll_close(cx) + .map_err(|_| SendError::ReceiverDropped) + } +} + +/// Stream for memory channels +pub struct RecvStream(pub(crate) tokio_stream::wrappers::ReceiverStream); + +impl fmt::Debug for RecvStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RecvStream").finish() + } +} + +impl Stream for RecvStream { + type Item = result::Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + match Pin::new(&mut self.0).poll_next(cx) { + Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl error::Error for RecvError {} + +/// A flume based listener. +/// +/// Created using [channel]. +pub struct MemListener { + #[allow(clippy::type_complexity)] + stream: tokio::sync::mpsc::Receiver<(SendSink, RecvStream)>, +} + +impl fmt::Debug for MemListener { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FlumeListener") + .field("stream", &self.stream) + .finish() + } +} + +impl ConnectionErrors for MemListener { + type SendError = self::SendError; + type RecvError = self::RecvError; + type OpenError = self::OpenError; + type AcceptError = self::AcceptError; +} + +type Socket = (self::SendSink, self::RecvStream); + +/// Future returned by [FlumeConnection::open] +pub struct OpenFuture { + inner: PollSender>, + send: Option>, + res: Option>, +} + +impl fmt::Debug for OpenFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OpenBiFuture").finish() + } +} + +impl OpenFuture { + fn new( + inner: PollSender>, + send: Socket, + res: Socket, + ) -> Self { + Self { + inner, + send: Some(send), + res: Some(res), + } + } +} + +impl Future for OpenFuture { + type Output = result::Result, self::OpenError>; + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + match Pin::new(&mut self.inner).poll_reserve(cx) { + Poll::Ready(Ok(())) => { + let Some(item) = self.send.take() else { + return Poll::Pending; + }; + let Ok(_) = self.inner.send_item(item) else { + return Poll::Ready(Err(self::OpenError::RemoteDropped)); + }; + self.res + .take() + .map(|x| Poll::Ready(Ok(x))) + .unwrap_or(Poll::Pending) + } + Poll::Ready(Err(_)) => Poll::Ready(Err(self::OpenError::RemoteDropped)), + Poll::Pending => Poll::Pending, + } + } +} + +impl StreamTypes for MemListener { + type In = In; + type Out = Out; + type SendSink = SendSink; + type RecvStream = RecvStream; +} + +impl Listener for MemListener { + async fn accept(&mut self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> { + match self.stream.recv().await { + Some((send, recv)) => Ok((send, recv)), + None => Err(AcceptError::RemoteDropped), + } + } + + fn local_addr(&self) -> &[LocalAddr] { + &[LocalAddr::Mem] + } +} + +impl ConnectionErrors for MemConnector { + type SendError = self::SendError; + type RecvError = self::RecvError; + type OpenError = self::OpenError; + type AcceptError = self::AcceptError; +} + +impl StreamTypes for MemConnector { + type In = In; + type Out = Out; + type SendSink = SendSink; + type RecvStream = RecvStream; +} + +impl Connector for MemConnector { + #[allow(refining_impl_trait)] + fn open(&self) -> OpenFuture { + let (local_send, remote_recv) = tokio::sync::mpsc::channel::(128); + let (remote_send, local_recv) = tokio::sync::mpsc::channel::(128); + let remote_chan = ( + SendSink(PollSender::new(remote_send)), + RecvStream(ReceiverStream::new(remote_recv)), + ); + let local_chan = ( + SendSink(PollSender::new(local_send)), + RecvStream(ReceiverStream::new(local_recv)), + ); + let sender = PollSender::new(self.sink.clone()); + OpenFuture::new(sender, remote_chan, local_chan) + } +} + +/// A flume based connector. +/// +/// Created using [channel]. +pub struct MemConnector { + #[allow(clippy::type_complexity)] + sink: tokio::sync::mpsc::Sender<(SendSink, RecvStream)>, +} + +impl Clone for MemConnector { + fn clone(&self) -> Self { + Self { + sink: self.sink.clone(), + } + } +} + +impl fmt::Debug for MemConnector { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MemConnector") + .field("sink", &self.sink) + .finish() + } +} + +/// AcceptError for mem channels. +/// +/// There is not much that can go wrong with mem channels. +#[derive(Debug)] +pub enum AcceptError { + /// The remote side of the channel was dropped + RemoteDropped, +} + +impl fmt::Display for AcceptError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl error::Error for AcceptError {} + +/// SendError for mem channels. +/// +/// There is not much that can go wrong with mem channels. +#[derive(Debug)] +pub enum SendError { + /// Receiver was dropped + ReceiverDropped, +} + +impl Display for SendError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl std::error::Error for SendError {} + +/// OpenError for mem channels. +#[derive(Debug)] +pub enum OpenError { + /// The remote side of the channel was dropped + RemoteDropped, +} + +impl Display for OpenError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl std::error::Error for OpenError {} + +/// CreateChannelError for mem channels. +/// +/// You can always create a mem channel, so there is no possible error. +/// Nevertheless we need a type for it. +#[derive(Debug, Clone, Copy)] +pub enum CreateChannelError {} + +impl Display for CreateChannelError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl std::error::Error for CreateChannelError {} + +/// Create a flume listener and a connected flume connector. +/// +/// `buffer` the size of the buffer for each channel. Keep this at a low value to get backpressure +pub fn channel( + buffer: usize, +) -> (MemListener, MemConnector) { + let (sink, stream) = tokio::sync::mpsc::channel(buffer); + (MemListener { stream }, MemConnector { sink }) +} diff --git a/tests/flume.rs b/tests/flume.rs index 34fc900..cd00dd0 100644 --- a/tests/flume.rs +++ b/tests/flume.rs @@ -58,7 +58,7 @@ async fn flume_channel_mapped_bench() -> anyhow::Result<()> { } let (server, client) = flume::channel(1); - let server = RpcServer::::new(server); + let mut server = RpcServer::::new(server); let server_handle: tokio::task::JoinHandle>> = tokio::task::spawn(async move { let service = ComputeService; diff --git a/tests/hyper.rs b/tests/hyper.rs index fa67144..66119aa 100644 --- a/tests/hyper.rs +++ b/tests/hyper.rs @@ -163,7 +163,7 @@ async fn hyper_channel_errors() -> anyhow::Result<()> { Receiver>>, ) { let channel = HyperListener::serve(addr).unwrap(); - let server = RpcServer::new(channel); + let mut server = RpcServer::new(channel); let (res_tx, res_rx) = flume::unbounded(); let handle = tokio::spawn(async move { loop { diff --git a/tests/math.rs b/tests/math.rs index b628c52..b476c78 100644 --- a/tests/math.rs +++ b/tests/math.rs @@ -197,7 +197,7 @@ impl ComputeService { count: usize, ) -> result::Result, RpcServerError> { tracing::info!(%count, "server running"); - let s = server; + let mut s = server; let mut received = 0; let service = ComputeService; while received < count { @@ -222,46 +222,6 @@ impl ComputeService { tracing::info!(%count, "server finished"); Ok(s) } - - pub async fn server_par>( - server: RpcServer, - parallelism: usize, - ) -> result::Result<(), RpcServerError> { - let s = server.clone(); - let s2 = s.clone(); - let service = ComputeService; - let request_stream = stream! { - loop { - yield s2.accept().await?.read_first().await; - } - }; - let process_stream = request_stream.map(move |r| { - let service = service.clone(); - async move { - let (req, chan) = r?; - use ComputeRequest::*; - #[rustfmt::skip] - match req { - Sqr(msg) => chan.rpc(msg, service, ComputeService::sqr).await, - Sum(msg) => chan.client_streaming(msg, service, ComputeService::sum).await, - Fibonacci(msg) => chan.server_streaming(msg, service, ComputeService::fibonacci).await, - Multiply(msg) => chan.bidi_streaming(msg, service, ComputeService::multiply).await, - SumUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?, - MultiplyUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?, - }?; - Ok::<_, RpcServerError>(()) - } - }); - process_stream - .buffered_unordered(parallelism) - .for_each(|x| { - if let Err(e) = x { - eprintln!("error: {e:?}"); - } - }) - .await; - Ok(()) - } } pub async fn smoke_test>(client: C) -> anyhow::Result<()> { diff --git a/tests/slow_math.rs b/tests/slow_math.rs index 2060a9a..178974e 100644 --- a/tests/slow_math.rs +++ b/tests/slow_math.rs @@ -111,7 +111,7 @@ impl ComputeService { pub async fn server>( server: RpcServer, ) -> result::Result<(), RpcServerError> { - let s = server; + let mut s = server; let service = ComputeService; loop { let (req, chan) = s.accept().await?.read_first().await?; diff --git a/tests/tokio.rs b/tests/tokio.rs new file mode 100644 index 0000000..ce78f7c --- /dev/null +++ b/tests/tokio.rs @@ -0,0 +1,102 @@ +#![allow(non_local_definitions)] +mod math; +use math::*; +use quic_rpc::{ + server::{RpcChannel, RpcServerError}, + transport::tokio as tkio, + RpcClient, RpcServer, Service, +}; +use tokio_util::task::AbortOnDropHandle; + +#[tokio::test] +async fn tokio_channel_bench() -> anyhow::Result<()> { + tracing_subscriber::fmt::try_init().ok(); + let (server, client) = tkio::channel(1); + + let server = RpcServer::::new(server); + let _server_handle = AbortOnDropHandle::new(tokio::spawn(ComputeService::server(server))); + let client = RpcClient::::new(client); + bench(client, 1000000).await?; + Ok(()) +} + +#[tokio::test] +async fn tokio_channel_mapped_bench() -> anyhow::Result<()> { + use derive_more::{From, TryInto}; + use serde::{Deserialize, Serialize}; + + tracing_subscriber::fmt::try_init().ok(); + + #[derive(Debug, Serialize, Deserialize, From, TryInto)] + enum OuterRequest { + Inner(InnerRequest), + } + #[derive(Debug, Serialize, Deserialize, From, TryInto)] + enum InnerRequest { + Compute(ComputeRequest), + } + #[derive(Debug, Serialize, Deserialize, From, TryInto)] + enum OuterResponse { + Inner(InnerResponse), + } + #[derive(Debug, Serialize, Deserialize, From, TryInto)] + enum InnerResponse { + Compute(ComputeResponse), + } + #[derive(Debug, Clone)] + struct OuterService; + impl Service for OuterService { + type Req = OuterRequest; + type Res = OuterResponse; + } + #[derive(Debug, Clone)] + struct InnerService; + impl Service for InnerService { + type Req = InnerRequest; + type Res = InnerResponse; + } + let (server, client) = tkio::channel(1); + + let mut server = RpcServer::::new(server); + let server_handle: tokio::task::JoinHandle>> = + tokio::task::spawn(async move { + let service = ComputeService; + loop { + let (req, chan) = server.accept().await?.read_first().await?; + let service = service.clone(); + tokio::spawn(async move { + let req: OuterRequest = req; + match req { + OuterRequest::Inner(InnerRequest::Compute(req)) => { + let chan: RpcChannel = chan.map(); + let chan: RpcChannel = chan.map(); + ComputeService::handle_rpc_request(service, req, chan).await + } + } + }); + } + }); + + let client = RpcClient::::new(client); + let client: RpcClient = client.map(); + let client: RpcClient = client.map(); + bench(client, 1000000).await?; + // dropping the client will cause the server to terminate + match server_handle.await? { + Err(RpcServerError::Accept(_)) => {} + e => panic!("unexpected termination result {e:?}"), + } + Ok(()) +} + +/// simple happy path test for all 4 patterns +#[tokio::test] +async fn tokio_channel_smoke() -> anyhow::Result<()> { + tracing_subscriber::fmt::try_init().ok(); + let (server, client) = tkio::channel(1); + + let server = RpcServer::::new(server); + let _server_handle = AbortOnDropHandle::new(tokio::spawn(ComputeService::server(server))); + smoke_test(client).await?; + Ok(()) +} diff --git a/tests/try.rs b/tests/try.rs index b11f633..2740c80 100644 --- a/tests/try.rs +++ b/tests/try.rs @@ -74,7 +74,7 @@ async fn try_server_streaming() -> anyhow::Result<()> { tracing_subscriber::fmt::try_init().ok(); let (server, client) = flume::channel(1); - let server = RpcServer::::new(server); + let mut server = RpcServer::::new(server); let server_handle = tokio::task::spawn(async move { loop { let (req, chan) = server.accept().await?.read_first().await?;