diff --git a/Cargo.lock b/Cargo.lock index 72e5daf..d7eef6c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -713,6 +713,18 @@ dependencies = [ "getrandom", ] +[[package]] +name = "nested_enum_utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f256ef99e7ac37428ef98c89bef9d84b590172de4bbfbe81b68a4cd3abadb32" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -870,6 +882,15 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "proc-macro-crate" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" +dependencies = [ + "toml_edit", +] + [[package]] name = "proc-macro2" version = "1.0.86" @@ -881,7 +902,7 @@ dependencies = [ [[package]] name = "quic-rpc" -version = "0.14.0" +version = "0.15.0" dependencies = [ "anyhow", "async-stream", @@ -898,12 +919,14 @@ dependencies = [ "hex", "hyper", "iroh-quinn", + "nested_enum_utils", "pin-project", "proc-macro2", "rcgen", "serde", "slab", "tempfile", + "testresult", "thousands", "tokio", "tokio-serde", @@ -914,7 +937,7 @@ dependencies = [ [[package]] name = "quic-rpc-derive" -version = "0.14.0" +version = "0.15.0" dependencies = [ "derive_more", "proc-macro2", @@ -1382,6 +1405,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "testresult" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "614b328ff036a4ef882c61570f72918f7e9c5bee1da33f8e7f91e01daee7e56c" + [[package]] name = "thiserror" version = "1.0.63" diff --git a/Cargo.toml b/Cargo.toml index 2454b4e..6471559 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "quic-rpc" -version = "0.14.0" +version = "0.15.0" edition = "2021" authors = ["RĂ¼diger Klaehn ", "n0 team"] keywords = ["api", "protocol", "network", "rpc"] @@ -49,12 +49,13 @@ tracing-subscriber = "0.3.16" tempfile = "3.5.0" proc-macro2 = "1.0.66" futures-buffered = "0.2.4" +testresult = "0.4.1" +nested_enum_utils = "0.1.0" [features] hyper-transport = ["dep:flume", "dep:hyper", "dep:bincode", "dep:bytes", "dep:tokio-serde", "dep:tokio-util"] quinn-transport = ["dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "dep:tokio-util"] flume-transport = ["dep:flume"] -combined-transport = [] macros = [] default = ["flume-transport"] diff --git a/examples/errors.rs b/examples/errors.rs index ffedb1e..fabf864 100644 --- a/examples/errors.rs +++ b/examples/errors.rs @@ -55,8 +55,8 @@ impl Fs { #[tokio::main] async fn main() -> anyhow::Result<()> { let fs = Fs; - let (server, client) = quic_rpc::transport::flume::service_connection::(1); - let client = RpcClient::new(client); + let (server, client) = quic_rpc::transport::flume::channel(1); + let client = RpcClient::::new(client); let server = RpcServer::new(server); let handle = tokio::task::spawn(async move { for _ in 0..1 { diff --git a/examples/macro.rs b/examples/macro.rs index e626816..88e76d2 100644 --- a/examples/macro.rs +++ b/examples/macro.rs @@ -105,7 +105,7 @@ create_store_dispatch!(Store, dispatch_store_request); #[tokio::main] async fn main() -> anyhow::Result<()> { - let (server, client) = flume::service_connection::(1); + let (server, client) = flume::channel(1); let server_handle = tokio::task::spawn(async move { let target = Store; run_server_loop(StoreService, server, target, dispatch_store_request).await diff --git a/examples/modularize.rs b/examples/modularize.rs index 1b19a2d..e574c1b 100644 --- a/examples/modularize.rs +++ b/examples/modularize.rs @@ -10,7 +10,7 @@ use anyhow::Result; use futures_lite::StreamExt; use futures_util::SinkExt; -use quic_rpc::{transport::flume, RpcClient, RpcServer, ServiceConnection, ServiceEndpoint}; +use quic_rpc::{client::BoxedConnector, transport::flume, Listener, RpcClient, RpcServer}; use tracing::warn; use app::AppService; @@ -19,19 +19,19 @@ use app::AppService; async fn main() -> Result<()> { // Spawn an inmemory connection. // Could use quic equally (all code in this example is generic over the transport) - let (server_conn, client_conn) = flume::service_connection::(1); + let (server_conn, client_conn) = flume::channel(1); // spawn the server let handler = app::Handler::default(); tokio::task::spawn(run_server(server_conn, handler)); // run a client demo - client_demo(client_conn).await?; + client_demo(BoxedConnector::::new(client_conn)).await?; Ok(()) } -async fn run_server>(server_conn: C, handler: app::Handler) { +async fn run_server>(server_conn: C, handler: app::Handler) { let server = RpcServer::::new(server_conn); loop { let Ok(accepting) = server.accept().await else { @@ -50,8 +50,8 @@ async fn run_server>(server_conn: C, handler: app } } } -pub async fn client_demo>(conn: C) -> Result<()> { - let rpc_client = RpcClient::new(conn); +pub async fn client_demo(conn: BoxedConnector) -> Result<()> { + let rpc_client = RpcClient::::new(conn); let client = app::Client::new(rpc_client.clone()); // call a method from the top-level app client @@ -99,15 +99,12 @@ mod app { //! //! It could also easily compose services from other crates or internal modules. + use super::iroh; use anyhow::Result; use derive_more::{From, TryInto}; - use quic_rpc::{ - message::RpcMsg, server::RpcChannel, RpcClient, Service, ServiceConnection, ServiceEndpoint, - }; + use quic_rpc::{message::RpcMsg, server::RpcChannel, Listener, RpcClient, Service}; use serde::{Deserialize, Serialize}; - use super::iroh; - #[derive(Debug, Serialize, Deserialize, From, TryInto)] pub enum Request { Iroh(iroh::Request), @@ -153,13 +150,17 @@ mod app { } impl Handler { - pub async fn handle_rpc_request>( + pub async fn handle_rpc_request>( self, req: Request, - chan: RpcChannel, + chan: RpcChannel, ) -> Result<()> { match req { - Request::Iroh(req) => self.iroh.handle_rpc_request(req, chan.map()).await?, + Request::Iroh(req) => { + self.iroh + .handle_rpc_request(req, chan.map().boxed()) + .await? + } Request::AppVersion(req) => chan.rpc(req, self, Self::on_version).await?, }; Ok(()) @@ -171,20 +172,16 @@ mod app { } #[derive(Debug, Clone)] - pub struct Client> { - pub iroh: iroh::Client, - client: RpcClient, + pub struct Client { + pub iroh: iroh::Client, + client: RpcClient, } - impl Client - where - S: Service, - C: ServiceConnection, - { - pub fn new(client: RpcClient) -> Self { + impl Client { + pub fn new(client: RpcClient) -> Self { Self { - iroh: iroh::Client::new(client.clone().map()), - client, + client: client.clone(), + iroh: iroh::Client::new(client.map().boxed()), } } @@ -202,7 +199,7 @@ mod iroh { use anyhow::Result; use derive_more::{From, TryInto}; - use quic_rpc::{server::RpcChannel, RpcClient, Service, ServiceConnection, ServiceEndpoint}; + use quic_rpc::{server::RpcChannel, RpcClient, Service}; use serde::{Deserialize, Serialize}; use super::{calc, clock}; @@ -233,38 +230,38 @@ mod iroh { } impl Handler { - pub async fn handle_rpc_request( + pub async fn handle_rpc_request( self, req: Request, - chan: RpcChannel, - ) -> Result<()> - where - S: Service, - E: ServiceEndpoint, - { + chan: RpcChannel, + ) -> Result<()> { match req { - Request::Calc(req) => self.calc.handle_rpc_request(req, chan.map()).await?, - Request::Clock(req) => self.clock.handle_rpc_request(req, chan.map()).await?, + Request::Calc(req) => { + self.calc + .handle_rpc_request(req, chan.map().boxed()) + .await? + } + Request::Clock(req) => { + self.clock + .handle_rpc_request(req, chan.map().boxed()) + .await? + } } Ok(()) } } #[derive(Debug, Clone)] - pub struct Client { - pub calc: calc::Client, - pub clock: clock::Client, + pub struct Client { + pub calc: calc::Client, + pub clock: clock::Client, } - impl Client - where - S: Service, - C: ServiceConnection, - { - pub fn new(client: RpcClient) -> Self { + impl Client { + pub fn new(client: RpcClient) -> Self { Self { - calc: calc::Client::new(client.clone().map()), - clock: clock::Client::new(client.clone().map()), + calc: calc::Client::new(client.clone().map().boxed()), + clock: clock::Client::new(client.clone().map().boxed()), } } } @@ -280,7 +277,7 @@ mod calc { use quic_rpc::{ message::{ClientStreaming, ClientStreamingMsg, Msg, RpcMsg}, server::RpcChannel, - RpcClient, Service, ServiceConnection, ServiceEndpoint, + RpcClient, Service, }; use serde::{Deserialize, Serialize}; use std::fmt::Debug; @@ -337,15 +334,11 @@ mod calc { pub struct Handler; impl Handler { - pub async fn handle_rpc_request( + pub async fn handle_rpc_request( self, req: Request, - chan: RpcChannel, - ) -> Result<()> - where - S: Service, - E: ServiceEndpoint, - { + chan: RpcChannel, + ) -> Result<()> { match req { Request::Add(req) => chan.rpc(req, self, Self::on_add).await?, Request::Sum(req) => chan.client_streaming(req, self, Self::on_sum).await?, @@ -373,16 +366,12 @@ mod calc { } #[derive(Debug, Clone)] - pub struct Client { - client: RpcClient, + pub struct Client { + client: RpcClient, } - impl Client - where - C: ServiceConnection, - S: Service, - { - pub fn new(client: RpcClient) -> Self { + impl Client { + pub fn new(client: RpcClient) -> Self { Self { client } } pub async fn add(&self, a: i64, b: i64) -> anyhow::Result { @@ -403,7 +392,7 @@ mod clock { use quic_rpc::{ message::{Msg, ServerStreaming, ServerStreamingMsg}, server::RpcChannel, - RpcClient, Service, ServiceConnection, ServiceEndpoint, + RpcClient, Service, }; use serde::{Deserialize, Serialize}; use std::{ @@ -475,15 +464,11 @@ mod clock { h } - pub async fn handle_rpc_request( + pub async fn handle_rpc_request( self, req: Request, - chan: RpcChannel, - ) -> Result<()> - where - S: Service, - E: ServiceEndpoint, - { + chan: RpcChannel, + ) -> Result<()> { match req { Request::Tick(req) => chan.server_streaming(req, self, Self::on_tick).await?, } @@ -517,16 +502,12 @@ mod clock { } #[derive(Debug, Clone)] - pub struct Client { - client: RpcClient, + pub struct Client { + client: RpcClient, } - impl Client - where - C: ServiceConnection, - S: Service, - { - pub fn new(client: RpcClient) -> Self { + impl Client { + pub fn new(client: RpcClient) -> Self { Self { client } } pub async fn tick(&self) -> Result>> { diff --git a/examples/split/client/src/main.rs b/examples/split/client/src/main.rs index 1d4c5c9..f68406d 100644 --- a/examples/split/client/src/main.rs +++ b/examples/split/client/src/main.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use anyhow::Result; use futures::sink::SinkExt; use futures::stream::StreamExt; -use quic_rpc::transport::quinn::QuinnConnection; +use quic_rpc::transport::quinn::QuinnConnector; use quic_rpc::RpcClient; use quinn::crypto::rustls::QuicClientConfig; use quinn::{ClientConfig, Endpoint}; @@ -19,7 +19,7 @@ async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt::init(); let server_addr: SocketAddr = "127.0.0.1:12345".parse()?; let endpoint = make_insecure_client_endpoint("0.0.0.0:0".parse()?)?; - let client = QuinnConnection::new(endpoint, server_addr, "localhost".to_string()); + let client = QuinnConnector::new(endpoint, server_addr, "localhost".to_string()); let client = RpcClient::new(client); // let mut client = ComputeClient(client); diff --git a/examples/split/server/src/main.rs b/examples/split/server/src/main.rs index db46ec7..ec28ad0 100644 --- a/examples/split/server/src/main.rs +++ b/examples/split/server/src/main.rs @@ -1,7 +1,7 @@ use async_stream::stream; use futures::stream::{Stream, StreamExt}; use quic_rpc::server::run_server_loop; -use quic_rpc::transport::quinn::QuinnServerEndpoint; +use quic_rpc::transport::quinn::QuinnListener; use quinn::{Endpoint, ServerConfig}; use std::net::SocketAddr; use std::sync::Arc; @@ -62,7 +62,7 @@ async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt::init(); let server_addr: SocketAddr = "127.0.0.1:12345".parse()?; let (server, _server_certs) = make_server_endpoint(server_addr)?; - let channel = QuinnServerEndpoint::new(server)?; + let channel = QuinnListener::new(server)?; let target = Compute; run_server_loop( ComputeService, diff --git a/examples/store.rs b/examples/store.rs index 8f8e0f8..6339218 100644 --- a/examples/store.rs +++ b/examples/store.rs @@ -5,7 +5,7 @@ use futures_lite::{Stream, StreamExt}; use futures_util::SinkExt; use quic_rpc::{ server::RpcServerError, - transport::{flume, Connection, ServerEndpoint}, + transport::{flume, Connector}, *, }; use serde::{Deserialize, Serialize}; @@ -162,7 +162,7 @@ impl Store { #[tokio::main] async fn main() -> anyhow::Result<()> { - async fn server_future>( + async fn server_future>( server: RpcServer, ) -> result::Result<(), RpcServerError> { let s = server; @@ -184,7 +184,7 @@ async fn main() -> anyhow::Result<()> { } } - let (server, client) = flume::service_connection::(1); + let (server, client) = flume::channel(1); let client = RpcClient::::new(client); let server = RpcServer::::new(server); let server_handle = tokio::task::spawn(server_future(server)); @@ -231,13 +231,14 @@ async fn main() -> anyhow::Result<()> { } async fn _main_unsugared() -> anyhow::Result<()> { + use transport::Listener; #[derive(Clone, Debug)] struct Service; impl crate::Service for Service { type Req = u64; type Res = String; } - let (server, client) = flume::service_connection::(1); + let (server, client) = flume::channel::(1); let to_string_service = tokio::spawn(async move { let (mut send, mut recv) = server.accept().await?; while let Some(item) = recv.next().await { diff --git a/quic-rpc-derive/Cargo.toml b/quic-rpc-derive/Cargo.toml index 98dea91..131985e 100644 --- a/quic-rpc-derive/Cargo.toml +++ b/quic-rpc-derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "quic-rpc-derive" -version = "0.14.0" +version = "0.15.0" edition = "2021" authors = ["RĂ¼diger Klaehn "] keywords = ["api", "protocol", "network", "rpc", "macro"] @@ -16,7 +16,7 @@ proc-macro = true syn = { version = "1.0", features = ["full"] } quote = "1.0" proc-macro2 = "1.0" -quic-rpc = { version = "0.14", path = ".." } +quic-rpc = { version = "0.15", path = ".." } [dev-dependencies] derive_more = "1.0.0-beta.6" diff --git a/src/client.rs b/src/client.rs index bda3c3f..a88da1f 100644 --- a/src/client.rs +++ b/src/client.rs @@ -2,8 +2,8 @@ //! //! The main entry point is [RpcClient]. use crate::{ - map::{ChainedMapper, MapService, Mapper}, - Service, ServiceConnection, + transport::{boxed::BoxableConnector, mapped::MappedConnector, StreamTypes}, + Connector, Service, }; use futures_lite::Stream; use futures_sink::Sink; @@ -13,38 +13,38 @@ use std::{ fmt::Debug, marker::PhantomData, pin::Pin, - sync::Arc, task::{Context, Poll}, }; /// Type alias for a boxed connection to a specific service -pub type BoxedServiceConnection = - crate::transport::boxed::Connection<::Res, ::Req>; +/// +/// This is a convenience type alias for a boxed connection to a specific service. +pub type BoxedConnector = + crate::transport::boxed::BoxedConnector<::Res, ::Req>; /// Sync version of `future::stream::BoxStream`. pub type BoxStreamSync<'a, T> = Pin + Send + Sync + 'a>>; /// A client for a specific service /// -/// This is a wrapper around a [ServiceConnection] that serves as the entry point +/// This is a wrapper around a [`Connector`] that serves as the entry point /// for the client DSL. /// /// Type parameters: /// /// `S` is the service type that determines what interactions this client supports. -/// `SC` is the service type that is compatible with the connection. -/// `C` is the substream source. +/// `C` is the connector that determines the transport. #[derive(Debug)] -pub struct RpcClient, SC = S> { +pub struct RpcClient> { pub(crate) source: C, - pub(crate) map: Arc>, + pub(crate) _p: PhantomData, } -impl Clone for RpcClient { +impl Clone for RpcClient { fn clone(&self) -> Self { Self { source: self.source.clone(), - map: Arc::clone(&self.map), + _p: PhantomData, } } } @@ -53,23 +53,25 @@ impl Clone for RpcClient { /// that support it, [crate::message::ClientStreaming] and [crate::message::BidiStreaming]. #[pin_project] #[derive(Debug)] -pub struct UpdateSink( - #[pin] pub C::SendSink, - pub PhantomData, - pub Arc>, -) +pub struct UpdateSink(#[pin] pub C::SendSink, PhantomData) where - SC: Service, - S: Service, - C: ServiceConnection, - T: Into; + C: StreamTypes; -impl Sink for UpdateSink +impl UpdateSink where - SC: Service, - S: Service, - C: ServiceConnection, - T: Into, + C: StreamTypes, + T: Into, +{ + /// Create a new update sink + pub fn new(sink: C::SendSink) -> Self { + Self(sink, PhantomData) + } +} + +impl Sink for UpdateSink +where + C: StreamTypes, + T: Into, { type Error = C::SendError; @@ -78,7 +80,7 @@ where } fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { - let req = self.2.req_into_outer(item.into()); + let req = item.into(); self.project().0.start_send(req) } @@ -91,34 +93,29 @@ where } } -impl RpcClient +impl RpcClient where S: Service, - C: ServiceConnection, + C: Connector, { /// Create a new rpc client for a specific [Service] given a compatible - /// [ServiceConnection]. + /// [Connector]. /// /// This is where a generic typed connection is converted into a client for a specific service. /// - /// When creating a new client, the outer service type `S` and the inner - /// service type `SC` that is compatible with the underlying connection will - /// be identical. - /// /// You can get a client for a nested service by calling [map](RpcClient::map). pub fn new(source: C) -> Self { Self { source, - map: Arc::new(Mapper::new()), + _p: PhantomData, } } } -impl RpcClient +impl RpcClient where S: Service, - SC: Service, - C: ServiceConnection, + C: Connector, { /// Get the underlying connection pub fn into_inner(self) -> C { @@ -134,25 +131,28 @@ where /// Where SNext is the new service to map to and S is the current inner service. /// /// This method can be chained infintely. - pub fn map(self) -> RpcClient + pub fn map(self) -> RpcClient> where SNext: Service, - SNext::Req: Into + TryFrom, - SNext::Res: Into + TryFrom, + S::Req: From, + SNext::Res: TryFrom, { - let map = ChainedMapper::new(self.map); - RpcClient { - source: self.source, - map: Arc::new(map), - } + RpcClient::new(self.source.map::()) + } + + /// box + pub fn boxed(self) -> RpcClient> + where + C: BoxableConnector, + { + RpcClient::new(self.source.boxed()) } } -impl AsRef for RpcClient +impl AsRef for RpcClient where S: Service, - SC: Service, - C: ServiceConnection, + C: Connector, { fn as_ref(&self) -> &C { &self.source diff --git a/src/lib.rs b/src/lib.rs index 71d09f9..89ebc29 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,7 +49,7 @@ //! } //! //! // create a transport channel, here a memory channel for testing -//! let (server, client) = quic_rpc::transport::flume::service_connection::(1); +//! let (server, client) = quic_rpc::transport::flume::channel(1); //! //! // client side //! // create the rpc client given the channel and the service type @@ -93,7 +93,6 @@ #![deny(rustdoc::broken_intra_doc_links)] use serde::{de::DeserializeOwned, Serialize}; use std::fmt::{Debug, Display}; -use transport::{Connection, ServerEndpoint}; pub mod client; pub mod message; pub mod server; @@ -102,7 +101,6 @@ pub use client::RpcClient; pub use server::RpcServer; #[cfg(feature = "macros")] mod macros; -mod map; pub mod pattern; @@ -127,9 +125,13 @@ impl RpcMessage for T where /// /// We don't require them to implement [std::error::Error] so we can use /// anyhow::Error as an error type. -pub trait RpcError: Debug + Display + Send + Sync + Unpin + 'static {} +/// +/// Instead we require them to implement `Into`, which is available +/// both for any type that implements [std::error::Error] and anyhow itself. +pub trait RpcError: Debug + Display + Into + Send + Sync + Unpin + 'static {} -impl RpcError for T where T: Debug + Display + Send + Sync + Unpin + 'static {} +impl RpcError for T where T: Debug + Display + Into + Send + Sync + Unpin + 'static +{} /// A service /// @@ -157,21 +159,20 @@ pub trait Service: Send + Sync + Debug + Clone + 'static { type Res: RpcMessage; } -/// A connection to a specific service on a specific remote machine -/// -/// This is just a trait alias for a [Connection] with the right types. +/// A connector to a specific service /// -/// This can be used to create a [RpcClient] that can be used to send requests. -pub trait ServiceConnection: Connection {} +/// This is just a trait alias for a [`transport::Connector`] with the right types. It is used +/// to make it easier to specify the bounds of a connector that matches a specific +/// service. +pub trait Connector: transport::Connector {} -impl, S: Service> ServiceConnection for T {} +impl, S: Service> Connector for T {} -/// A server endpoint for a specific service -/// -/// This is just a trait alias for a [ServerEndpoint] with the right types. +/// A listener for a specific service /// -/// This can be used to create a [RpcServer] that can be used to handle -/// requests. -pub trait ServiceEndpoint: ServerEndpoint {} +/// This is just a trait alias for a [`transport::Listener`] with the right types. It is used +/// to make it easier to specify the bounds of a listener that matches a specific +/// service. +pub trait Listener: transport::Listener {} -impl, S: Service> ServiceEndpoint for T {} +impl, S: Service> Listener for T {} diff --git a/src/macros.rs b/src/macros.rs index 208dae0..03cf475 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -198,8 +198,8 @@ macro_rules! __derive_create_dispatch { #[macro_export] macro_rules! $create_dispatch { ($target:ident, $handler:ident) => { - pub async fn $handler>( - mut chan: $crate::server::RpcChannel<$service, C, $service>, + pub async fn $handler>( + mut chan: $crate::server::RpcChannel<$service, C>, msg: <$service as $crate::Service>::Req, target: $target, ) -> Result<(), $crate::server::RpcServerError> { @@ -435,9 +435,9 @@ macro_rules! __derive_create_client{ macro_rules! $create_client { ($struct:ident) => { #[derive(::std::clone::Clone, ::std::fmt::Debug)] - pub struct $struct>(pub $crate::client::RpcClient<$service, C>); + pub struct $struct>(pub $crate::client::RpcClient<$service, C>); - impl> $struct { + impl> $struct { $( $crate::__rpc_method!($m_pattern, $service, $m_name, $m_input, $m_output, $m_update); )* diff --git a/src/map.rs b/src/map.rs deleted file mode 100644 index 597ca93..0000000 --- a/src/map.rs +++ /dev/null @@ -1,145 +0,0 @@ -use std::{marker::PhantomData, sync::Arc}; - -use crate::Service; - -/// Convert requests and responses between an outer and an inner service. -/// -/// An "outer" service has request and response enums which wrap the requests and responses of an -/// "inner" service. This trait is implemented on the [`Mapper`] and [`ChainedMapper`] structs -/// to convert the requests and responses between the outer and inner services. -pub trait MapService: - std::fmt::Debug + Send + Sync + 'static -{ - /// Convert an inner request into the outer request. - fn req_into_outer(&self, req: SInner::Req) -> SOuter::Req; - - /// Convert an inner response into the outer response. - fn res_into_outer(&self, res: SInner::Res) -> SOuter::Res; - - /// Try to convert the outer request into the inner request. - /// - /// Returns an error if the request is not of the variant of the inner service. - fn req_try_into_inner(&self, req: SOuter::Req) -> Result; - - /// Try to convert the outer response into the inner request. - /// - /// Returns an error if the response is not of the variant of the inner service. - fn res_try_into_inner(&self, res: SOuter::Res) -> Result; -} - -/// Zero-sized struct to map between two services. -#[derive(Debug)] -pub struct Mapper(PhantomData, PhantomData); - -impl Mapper -where - SOuter: Service, - SInner: Service, - SInner::Req: Into + TryFrom, - SInner::Res: Into + TryFrom, -{ - /// Create a new mapper between `SOuter` and `SInner` services. - /// - /// This method is availalbe if the required bounds to convert between the outer and inner - /// request and response enums are met: - /// `SInner::Req: Into + TryFrom` - /// `SInner::Res: Into + TryFrom` - pub fn new() -> Self { - Self(PhantomData, PhantomData) - } -} - -impl MapService for Mapper -where - SOuter: Service, - SInner: Service, - SInner::Req: Into + TryFrom, - SInner::Res: Into + TryFrom, -{ - fn req_into_outer(&self, req: SInner::Req) -> SOuter::Req { - req.into() - } - - fn res_into_outer(&self, res: SInner::Res) -> SOuter::Res { - res.into() - } - - fn req_try_into_inner(&self, req: SOuter::Req) -> Result { - req.try_into().map_err(|_| ()) - } - - fn res_try_into_inner(&self, res: SOuter::Res) -> Result { - res.try_into().map_err(|_| ()) - } -} - -/// Map between an outer and an inner service with any number of intermediate services. -/// -/// This uses an `Arc` to contain an unlimited chain of [`Mapper`]s. -#[derive(Debug)] -pub struct ChainedMapper -where - SOuter: Service, - SMid: Service, - SInner: Service, - SInner::Req: Into + TryFrom, -{ - map1: Arc>, - map2: Mapper, -} - -impl ChainedMapper -where - SOuter: Service, - SMid: Service, - SInner: Service, - SInner::Req: Into + TryFrom, - SInner::Res: Into + TryFrom, -{ - /// Create a new [`ChainedMapper`] by appending a service `SInner` to the existing `dyn - /// MapService`. - /// - /// Usage example: - /// ```ignore - /// // S1 is a Service and impls the Into and TryFrom traits to map to S2 - /// // S2 is a Service and impls the Into and TryFrom traits to map to S3 - /// // S3 is also a Service - /// - /// let mapper: Mapper = Mapper::new(); - /// let mapper: Arc> = Arc::new(mapper); - /// let chained_mapper: ChainedMapper = ChainedMapper::new(mapper); - /// ``` - pub fn new(map1: Arc>) -> Self { - Self { - map1, - map2: Mapper::new(), - } - } -} - -impl MapService for ChainedMapper -where - SOuter: Service, - SMid: Service, - SInner: Service, - SInner::Req: Into + TryFrom, - SInner::Res: Into + TryFrom, -{ - fn req_into_outer(&self, req: SInner::Req) -> SOuter::Req { - let req = self.map2.req_into_outer(req); - self.map1.req_into_outer(req) - } - fn res_into_outer(&self, res: SInner::Res) -> SOuter::Res { - let res = self.map2.res_into_outer(res); - self.map1.res_into_outer(res) - } - fn req_try_into_inner(&self, req: SOuter::Req) -> Result { - let req = self.map1.req_try_into_inner(req)?; - self.map2.req_try_into_inner(req) - } - - fn res_try_into_inner(&self, res: SOuter::Res) -> Result { - let res = self.map1.res_try_into_inner(res)?; - self.map2.res_try_into_inner(res) - } -} diff --git a/src/pattern/bidi_streaming.rs b/src/pattern/bidi_streaming.rs index d7ec853..53e6275 100644 --- a/src/pattern/bidi_streaming.rs +++ b/src/pattern/bidi_streaming.rs @@ -7,16 +7,14 @@ use crate::{ client::{BoxStreamSync, UpdateSink}, message::{InteractionPattern, Msg}, server::{race2, RpcChannel, RpcServerError, UpdateStream}, - transport::ConnectionErrors, - RpcClient, Service, ServiceConnection, ServiceEndpoint, + transport::{ConnectionErrors, Connector, StreamTypes}, + RpcClient, Service, }; use std::{ error, fmt::{self, Debug}, - marker::PhantomData, result, - sync::Arc, }; /// Bidirectional streaming interaction pattern @@ -75,11 +73,10 @@ impl fmt::Display for ItemError { impl error::Error for ItemError {} -impl RpcClient +impl RpcClient where - SC: Service, - C: ServiceConnection, S: Service, + C: Connector, { /// Bidi call to the server, request opens a stream, response is a stream pub async fn bidi( @@ -87,7 +84,7 @@ where msg: M, ) -> result::Result< ( - UpdateSink, + UpdateSink, BoxStreamSync<'static, result::Result>>, ), Error, @@ -95,28 +92,21 @@ where where M: BidiStreamingMsg, { - let msg = self.map.req_into_outer(msg.into()); + let msg = msg.into(); let (mut send, recv) = self.source.open().await.map_err(Error::Open)?; send.send(msg).await.map_err(Error::::Send)?; - let send = UpdateSink(send, PhantomData, Arc::clone(&self.map)); - let map = Arc::clone(&self.map); + let send = UpdateSink::new(send); let recv = Box::pin(recv.map(move |x| match x { - Ok(x) => { - let x = map - .res_try_into_inner(x) - .map_err(|_| ItemError::DowncastError)?; - M::Response::try_from(x).map_err(|_| ItemError::DowncastError) - } + Ok(msg) => M::Response::try_from(msg).map_err(|_| ItemError::DowncastError), Err(e) => Err(ItemError::RecvError(e)), })); Ok((send, recv)) } } -impl RpcChannel +impl RpcChannel where - SC: Service, - C: ServiceEndpoint, + C: StreamTypes, S: Service, { /// handle the message M using the given function on the target object @@ -130,20 +120,20 @@ where ) -> result::Result<(), RpcServerError> where M: BidiStreamingMsg, - F: FnOnce(T, M, UpdateStream) -> Str + Send + 'static, + F: FnOnce(T, M, UpdateStream) -> Str + Send + 'static, Str: Stream + Send + 'static, T: Send + 'static, { let Self { mut send, recv, .. } = self; // downcast the updates - let (updates, read_error) = UpdateStream::new(recv, Arc::clone(&self.map)); + let (updates, read_error) = UpdateStream::new(recv); // get the response let responses = f(target, req, updates); race2(read_error.map(Err), async move { tokio::pin!(responses); while let Some(response) = responses.next().await { // turn into a S::Res so we can send it - let response = self.map.res_into_outer(response.into()); + let response = response.into(); // send it and return the error if any send.send(response) .await diff --git a/src/pattern/client_streaming.rs b/src/pattern/client_streaming.rs index 074023c..729b8c7 100644 --- a/src/pattern/client_streaming.rs +++ b/src/pattern/client_streaming.rs @@ -7,16 +7,14 @@ use crate::{ client::UpdateSink, message::{InteractionPattern, Msg}, server::{race2, RpcChannel, RpcServerError, UpdateStream}, - transport::ConnectionErrors, - RpcClient, Service, ServiceConnection, ServiceEndpoint, + transport::{ConnectionErrors, StreamTypes}, + Connector, RpcClient, Service, }; use std::{ error, fmt::{self, Debug}, - marker::PhantomData, result, - sync::Arc, }; /// Client streaming interaction pattern @@ -77,11 +75,10 @@ impl fmt::Display for ItemError { impl error::Error for ItemError {} -impl RpcClient +impl RpcClient where S: Service, - SC: Service, - C: ServiceConnection, + C: Connector, { /// Call to the server that allows the client to stream, single response pub async fn client_streaming( @@ -89,7 +86,7 @@ where msg: M, ) -> result::Result< ( - UpdateSink, + UpdateSink, Boxed>>, ), Error, @@ -97,21 +94,15 @@ where where M: ClientStreamingMsg, { - let msg = self.map.req_into_outer(msg.into()); + let msg = msg.into(); let (mut send, mut recv) = self.source.open().await.map_err(Error::Open)?; send.send(msg).map_err(Error::Send).await?; - let send = UpdateSink::(send, PhantomData, Arc::clone(&self.map)); - let map = Arc::clone(&self.map); + let send = UpdateSink::::new(send); let recv = async move { let item = recv.next().await.ok_or(ItemError::EarlyClose)?; match item { - Ok(x) => { - let x = map - .res_try_into_inner(x) - .map_err(|_| ItemError::DowncastError)?; - M::Response::try_from(x).map_err(|_| ItemError::DowncastError) - } + Ok(msg) => M::Response::try_from(msg).map_err(|_| ItemError::DowncastError), Err(e) => Err(ItemError::RecvError(e)), } } @@ -120,11 +111,10 @@ where } } -impl RpcChannel +impl RpcChannel where S: Service, - SC: Service, - C: ServiceEndpoint, + C: StreamTypes, { /// handle the message M using the given function on the target object /// @@ -137,17 +127,17 @@ where ) -> result::Result<(), RpcServerError> where M: ClientStreamingMsg, - F: FnOnce(T, M, UpdateStream) -> Fut + Send + 'static, + F: FnOnce(T, M, UpdateStream) -> Fut + Send + 'static, Fut: Future + Send + 'static, T: Send + 'static, { let Self { mut send, recv, .. } = self; - let (updates, read_error) = UpdateStream::new(recv, Arc::clone(&self.map)); + let (updates, read_error) = UpdateStream::new(recv); race2(read_error.map(Err), async move { // get the response let res = f(target, req, updates).await; // turn into a S::Res so we can send it - let res = self.map.res_into_outer(res.into()); + let res = res.into(); // send it and return the error if any send.send(res).await.map_err(RpcServerError::SendError) }) diff --git a/src/pattern/rpc.rs b/src/pattern/rpc.rs index 8782c25..9337113 100644 --- a/src/pattern/rpc.rs +++ b/src/pattern/rpc.rs @@ -6,8 +6,8 @@ use futures_util::{FutureExt, SinkExt}; use crate::{ message::{InteractionPattern, Msg}, server::{race2, RpcChannel, RpcServerError}, - transport::ConnectionErrors, - RpcClient, Service, ServiceConnection, ServiceEndpoint, + transport::{ConnectionErrors, StreamTypes}, + Connector, RpcClient, Service, }; use std::{ @@ -62,18 +62,17 @@ impl fmt::Display for Error { impl error::Error for Error {} -impl RpcClient +impl RpcClient where S: Service, - SC: Service, - C: ServiceConnection, + C: Connector, { /// RPC call to the server, single request, single response pub async fn rpc(&self, msg: M) -> result::Result> where M: RpcMsg, { - let msg = self.map.req_into_outer(msg.into()); + let msg = msg.into(); let (mut send, mut recv) = self.source.open().await.map_err(Error::Open)?; send.send(msg).await.map_err(Error::::Send)?; let res = recv @@ -83,19 +82,14 @@ where .map_err(Error::::RecvError)?; // keep send alive until we have the answer drop(send); - let res = self - .map - .res_try_into_inner(res) - .map_err(|_| Error::DowncastError)?; M::Response::try_from(res).map_err(|_| Error::DowncastError) } } -impl RpcChannel +impl RpcChannel where S: Service, - SC: Service, - C: ServiceEndpoint, + C: StreamTypes, { /// handle the message of type `M` using the given function on the target object /// @@ -124,7 +118,7 @@ where // get the response let res = f(target, req).await; // turn into a S::Res so we can send it - let res = self.map.res_into_outer(res.into()); + let res = res.into(); // send it and return the error if any send.send(res).await.map_err(RpcServerError::SendError) }) diff --git a/src/pattern/server_streaming.rs b/src/pattern/server_streaming.rs index 679d033..d1889e4 100644 --- a/src/pattern/server_streaming.rs +++ b/src/pattern/server_streaming.rs @@ -7,15 +7,14 @@ use crate::{ client::{BoxStreamSync, DeferDrop}, message::{InteractionPattern, Msg}, server::{race2, RpcChannel, RpcServerError}, - transport::ConnectionErrors, - RpcClient, Service, ServiceConnection, ServiceEndpoint, + transport::{ConnectionErrors, Connector, StreamTypes}, + RpcClient, Service, }; use std::{ error, fmt::{self, Debug}, result, - sync::Arc, }; /// Server streaming interaction pattern @@ -42,13 +41,13 @@ pub enum Error { Send(C::SendError), } -impl fmt::Display for Error { +impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Debug::fmt(self, f) } } -impl error::Error for Error {} +impl error::Error for Error {} /// Client error when handling responses from a server streaming request #[derive(Debug)] @@ -67,10 +66,9 @@ impl fmt::Display for ItemError { impl error::Error for ItemError {} -impl RpcClient +impl RpcClient where - SC: Service, - C: ServiceConnection, + C: crate::Connector, S: Service, { /// Bidi call to the server, request opens a stream, response is a stream @@ -81,17 +79,11 @@ where where M: ServerStreamingMsg, { - let msg = self.map.req_into_outer(msg.into()); + let msg = msg.into(); let (mut send, recv) = self.source.open().await.map_err(Error::Open)?; send.send(msg).map_err(Error::::Send).await?; - let map = Arc::clone(&self.map); let recv = recv.map(move |x| match x { - Ok(x) => { - let x = map - .res_try_into_inner(x) - .map_err(|_| ItemError::DowncastError)?; - M::Response::try_from(x).map_err(|_| ItemError::DowncastError) - } + Ok(msg) => M::Response::try_from(msg).map_err(|_| ItemError::DowncastError), Err(e) => Err(ItemError::RecvError(e)), }); // keep send alive so the request on the server side does not get cancelled @@ -100,11 +92,10 @@ where } } -impl RpcChannel +impl RpcChannel where S: Service, - SC: Service, - C: ServiceEndpoint, + C: StreamTypes, { /// handle the message M using the given function on the target object /// @@ -135,7 +126,7 @@ where tokio::pin!(responses); while let Some(response) = responses.next().await { // turn into a S::Res so we can send it - let response = self.map.res_into_outer(response.into()); + let response = response.into(); // send it and return the error if any send.send(response) .await diff --git a/src/pattern/try_server_streaming.rs b/src/pattern/try_server_streaming.rs index 3b2950c..46d705d 100644 --- a/src/pattern/try_server_streaming.rs +++ b/src/pattern/try_server_streaming.rs @@ -8,15 +8,14 @@ use crate::{ client::{BoxStreamSync, DeferDrop}, message::{InteractionPattern, Msg}, server::{race2, RpcChannel, RpcServerError}, - transport::ConnectionErrors, - RpcClient, Service, ServiceConnection, ServiceEndpoint, + transport::{self, ConnectionErrors, StreamTypes}, + Connector, RpcClient, Service, }; use std::{ error, fmt::{self, Debug}, result, - sync::Arc, }; /// A guard message to indicate that the stream has been created. @@ -54,7 +53,7 @@ where /// care about the exact nature of the error, but if you want to handle /// application errors differently, you can match on this enum. #[derive(Debug)] -pub enum Error { +pub enum Error { /// Unable to open a substream at all Open(C::OpenError), /// Unable to send the request to the server @@ -69,13 +68,13 @@ pub enum Error { Application(E), } -impl fmt::Display for Error { +impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Debug::fmt(self, f) } } -impl error::Error for Error {} +impl error::Error for Error {} /// Client error when handling responses from a server streaming request. /// @@ -98,10 +97,9 @@ impl fmt::Display for ItemError { impl error::Error for ItemError {} -impl RpcChannel +impl RpcChannel where - SC: Service, - C: ServiceEndpoint, + C: StreamTypes, S: Service, { /// handle the message M using the given function on the target object @@ -138,7 +136,7 @@ where let responses = match f(target, req).await { Ok(responses) => { // turn into a S::Res so we can send it - let response = self.map.res_into_outer(Ok(StreamCreated).into()); + let response = Ok(StreamCreated).into(); // send it and return the error if any send.send(response) .await @@ -147,7 +145,7 @@ where } Err(cause) => { // turn into a S::Res so we can send it - let response = self.map.res_into_outer(Err(cause).into()); + let response = Err(cause).into(); // send it and return the error if any send.send(response) .await @@ -158,7 +156,7 @@ where tokio::pin!(responses); while let Some(response) = responses.next().await { // turn into a S::Res so we can send it - let response = self.map.res_into_outer(response.into()); + let response = response.into(); // send it and return the error if any send.send(response) .await @@ -170,10 +168,9 @@ where } } -impl RpcClient +impl RpcClient where - SC: Service, - C: ServiceConnection, + C: Connector, S: Service, { /// Bidi call to the server, request opens a stream, response is a stream @@ -189,23 +186,18 @@ where Result: Into + TryFrom, Result: Into + TryFrom, { - let msg = self.map.req_into_outer(msg.into()); + let msg = msg.into(); let (mut send, mut recv) = self.source.open().await.map_err(Error::Open)?; send.send(msg).map_err(Error::Send).await?; - let map = Arc::clone(&self.map); let Some(initial) = recv.next().await else { return Err(Error::EarlyClose); }; let initial = initial.map_err(Error::Recv)?; // initial response - let initial = map - .res_try_into_inner(initial) - .map_err(|_| Error::Downcast)?; let initial = >::try_from(initial) .map_err(|_| Error::Downcast)?; let _ = initial.map_err(Error::Application)?; let recv = recv.map(move |x| { let x = x.map_err(ItemError::Recv)?; - let x = map.res_try_into_inner(x).map_err(|_| ItemError::Downcast)?; let x = >::try_from(x) .map_err(|_| ItemError::Downcast)?; let x = x.map_err(ItemError::Application)?; diff --git a/src/server.rs b/src/server.rs index 0185584..1d35425 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,11 +2,16 @@ //! //! The main entry point is [RpcServer] use crate::{ - map::{ChainedMapper, MapService, Mapper}, - transport::ConnectionErrors, - Service, ServiceEndpoint, + transport::{ + self, + boxed::BoxableListener, + mapped::{ErrorOrMapError, MappedRecvStream, MappedSendSink, MappedStreamTypes}, + ConnectionErrors, StreamTypes, + }, + Listener, RpcMessage, Service, }; use futures_lite::{Future, Stream, StreamExt}; +use futures_util::{SinkExt, TryStreamExt}; use pin_project::pin_project; use std::{ error, @@ -14,53 +19,77 @@ use std::{ marker::PhantomData, pin::Pin, result, - sync::Arc, task::{self, Poll}, }; use tokio::sync::oneshot; +/// Stream types on the server side +/// +/// On the server side, we receive requests and send responses. +/// On the client side, we send requests and receive responses. +pub trait ChannelTypes: transport::StreamTypes {} + +impl, S: Service> ChannelTypes for T {} + +/// Type alias for when you want to require a boxed channel +pub type BoxedChannelTypes = crate::transport::boxed::BoxedStreamTypes< + ::Req, + ::Res, +>; + /// Type alias for a service endpoint -pub type BoxedServiceEndpoint = - crate::transport::boxed::ServerEndpoint<::Req, ::Res>; +pub type BoxedListener = + crate::transport::boxed::BoxedListener<::Req, ::Res>; /// A server for a specific service. /// -/// This is a wrapper around a [ServiceEndpoint] that serves as the entry point for the server DSL. +/// This is a wrapper around a [`Listener`] that serves as the entry point for the server DSL. /// /// Type parameters: /// /// `S` is the service type. /// `C` is the channel type. #[derive(Debug)] -pub struct RpcServer> { +pub struct RpcServer> { /// The channel on which new requests arrive. /// /// Each new request is a receiver and channel pair on which messages for this request /// are received and responses sent. source: C, - p: PhantomData, + _p: PhantomData, } impl Clone for RpcServer { fn clone(&self) -> Self { Self { source: self.source.clone(), - p: PhantomData, + _p: PhantomData, } } } -impl> RpcServer { +impl> RpcServer { /// Create a new rpc server for a specific service for a [Service] given a compatible - /// [ServiceEndpoint]. + /// [Listener]. /// /// This is where a generic typed endpoint is converted into a server for a specific service. pub fn new(source: C) -> Self { Self { source, - p: PhantomData, + _p: PhantomData, } } + + /// Box the transport for the service. + /// + /// The boxed transport is the default for the `C` type parameter, so by boxing we can avoid + /// having to specify the type parameter. + pub fn boxed(self) -> RpcServer> + where + C: BoxableListener, + { + RpcServer::new(self.source.boxed()) + } } /// A channel for requests and responses for a specific service. @@ -74,40 +103,43 @@ impl> RpcServer { /// Type parameters: /// /// `S` is the service type. -/// `SC` is the service type that is compatible with the connection. /// `C` is the service endpoint from which the channel was created. #[derive(Debug)] -pub struct RpcChannel = BoxedServiceEndpoint, SC: Service = S> -{ +pub struct RpcChannel = BoxedChannelTypes> { /// Sink to send responses to the client. pub send: C::SendSink, /// Stream to receive requests from the client. pub recv: C::RecvStream, - /// Mapper to map between S and S2 - pub map: Arc>, + + pub(crate) _p: PhantomData, } -impl RpcChannel +impl RpcChannel where S: Service, - C: ServiceEndpoint, + C: StreamTypes, { /// Create a new RPC channel. pub fn new(send: C::SendSink, recv: C::RecvStream) -> Self { Self { send, recv, - map: Arc::new(Mapper::new()), + _p: PhantomData, } } -} -impl RpcChannel -where - S: Service, - SC: Service, - C: ServiceEndpoint, -{ + /// Convert this channel into a boxed channel. + pub fn boxed(self) -> RpcChannel> + where + C::SendError: Into + Send + Sync + 'static, + C::RecvError: Into + Send + Sync + 'static, + { + let send = + transport::boxed::SendSink::boxed(Box::new(self.send.sink_map_err(|e| e.into()))); + let recv = transport::boxed::RecvStream::boxed(Box::new(self.recv.map_err(|e| e.into()))); + RpcChannel::new(send, recv) + } + /// Map this channel's service into an inner service. /// /// This method is available if the required bounds are upheld: @@ -117,28 +149,27 @@ where /// Where SNext is the new service to map to and S is the current inner service. /// /// This method can be chained infintely. - pub fn map(self) -> RpcChannel + pub fn map(self) -> RpcChannel> where SNext: Service, - SNext::Req: Into + TryFrom, - SNext::Res: Into + TryFrom, + SNext::Req: TryFrom, + S::Res: From, { - let map = ChainedMapper::new(self.map); - RpcChannel { - send: self.send, - recv: self.recv, - map: Arc::new(map), - } + RpcChannel::new( + MappedSendSink::new(self.send), + MappedRecvStream::new(self.recv), + ) } } /// The result of accepting a new connection. -pub struct Accepting> { +pub struct Accepting> { send: C::SendSink, recv: C::RecvStream, + _p: PhantomData, } -impl> Accepting { +impl> Accepting { /// Read the first message from the client. /// /// The return value is a tuple of `(request, channel)`. Here `request` is the @@ -148,10 +179,8 @@ impl> Accepting { /// /// Often sink and stream will wrap an an underlying byte stream. In this case you can /// call into_inner() on them to get it back to perform byte level reads and writes. - pub async fn read_first( - self, - ) -> result::Result<(S::Req, RpcChannel), RpcServerError> { - let Accepting { send, mut recv } = self; + pub async fn read_first(self) -> result::Result<(S::Req, RpcChannel), RpcServerError> { + let Accepting { send, mut recv, .. } = self; // get the first message from the client. This will tell us what it wants to do. let request: S::Req = recv .next() @@ -160,16 +189,20 @@ impl> Accepting { .ok_or(RpcServerError::EarlyClose)? // recv error .map_err(RpcServerError::RecvError)?; - Ok((request, RpcChannel::new(send, recv))) + Ok((request, RpcChannel::::new(send, recv))) } } -impl> RpcServer { +impl> RpcServer { /// Accepts a new channel from a client. The result is an [Accepting] object that /// can be used to read the first request. pub async fn accept(&self) -> result::Result, RpcServerError> { let (send, recv) = self.source.accept().await.map_err(RpcServerError::Accept)?; - Ok(Accepting { send, recv }) + Ok(Accepting { + send, + recv, + _p: PhantomData, + }) } /// Get the underlying service endpoint @@ -178,7 +211,7 @@ impl> RpcServer { } } -impl> AsRef for RpcServer { +impl> AsRef for RpcServer { fn as_ref(&self) -> &C { &self.source } @@ -190,40 +223,30 @@ impl> AsRef for RpcServer { /// cause a termination of the RPC call. #[pin_project] #[derive(Debug)] -pub struct UpdateStream( +pub struct UpdateStream( #[pin] C::RecvStream, Option>>, PhantomData, - Arc>, ) where - SC: Service, - S: Service, - C: ServiceEndpoint; + C: StreamTypes; -impl UpdateStream +impl UpdateStream where - SC: Service, - S: Service, - C: ServiceEndpoint, - T: TryFrom, + C: StreamTypes, + T: TryFrom, { - pub(crate) fn new( - recv: C::RecvStream, - map: Arc>, - ) -> (Self, UnwrapToPending>) { + pub(crate) fn new(recv: C::RecvStream) -> (Self, UnwrapToPending>) { let (error_send, error_recv) = oneshot::channel(); let error_recv = UnwrapToPending(error_recv); - (Self(recv, Some(error_send), PhantomData, map), error_recv) + (Self(recv, Some(error_send), PhantomData), error_recv) } } -impl Stream for UpdateStream +impl Stream for UpdateStream where - SC: Service, - S: Service, - C: ServiceEndpoint, - T: TryFrom, + C: StreamTypes, + T: TryFrom, { type Item = T; @@ -232,8 +255,7 @@ where match Pin::new(&mut this.0).poll_next(cx) { Poll::Ready(Some(msg)) => match msg { Ok(msg) => { - let msg = this.3.req_try_into_inner(msg); - let msg = msg.and_then(|msg| T::try_from(msg).map_err(|_cause| ())); + let msg = T::try_from(msg).map_err(|_cause| ()); match msg { Ok(msg) => Poll::Ready(Some(msg)), Err(_cause) => { @@ -262,7 +284,7 @@ where /// Server error. All server DSL methods return a `Result` with this error type. pub enum RpcServerError { /// Unable to open a new channel - Accept(C::OpenError), + Accept(C::AcceptError), /// Recv side for a channel was closed before getting the first message EarlyClose, /// Got an unexpected first message, e.g. an update message @@ -275,6 +297,45 @@ pub enum RpcServerError { UnexpectedUpdateMessage, } +impl + RpcServerError> +{ + /// For a mapped connection, map the error back to the original error type + pub fn map_back(self) -> RpcServerError { + match self { + RpcServerError::EarlyClose => RpcServerError::EarlyClose, + RpcServerError::UnexpectedStartMessage => RpcServerError::UnexpectedStartMessage, + RpcServerError::UnexpectedUpdateMessage => RpcServerError::UnexpectedUpdateMessage, + RpcServerError::SendError(x) => RpcServerError::SendError(x), + RpcServerError::Accept(x) => RpcServerError::Accept(x), + RpcServerError::RecvError(ErrorOrMapError::Inner(x)) => RpcServerError::RecvError(x), + RpcServerError::RecvError(ErrorOrMapError::Conversion) => { + RpcServerError::UnexpectedUpdateMessage + } + } + } +} + +impl RpcServerError { + /// Convert into a different error type provided the send, recv and accept errors can be converted + pub fn errors_into(self) -> RpcServerError + where + T: ConnectionErrors, + C::SendError: Into, + C::RecvError: Into, + C::AcceptError: Into, + { + match self { + RpcServerError::EarlyClose => RpcServerError::EarlyClose, + RpcServerError::UnexpectedStartMessage => RpcServerError::UnexpectedStartMessage, + RpcServerError::UnexpectedUpdateMessage => RpcServerError::UnexpectedUpdateMessage, + RpcServerError::SendError(x) => RpcServerError::SendError(x.into()), + RpcServerError::Accept(x) => RpcServerError::Accept(x.into()), + RpcServerError::RecvError(x) => RpcServerError::RecvError(x.into()), + } + } +} + impl fmt::Debug for RpcServerError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -329,9 +390,9 @@ pub async fn run_server_loop( ) -> Result<(), RpcServerError> where S: Service, - C: ServiceEndpoint, + C: Listener, T: Clone + Send + 'static, - F: FnMut(RpcChannel, S::Req, T) -> Fut + Send + 'static, + F: FnMut(RpcChannel, S::Req, T) -> Fut + Send + 'static, Fut: Future>> + Send + 'static, { let server: RpcServer = RpcServer::::new(conn); diff --git a/src/transport/boxed.rs b/src/transport/boxed.rs index ced3829..3304bd1 100644 --- a/src/transport/boxed.rs +++ b/src/transport/boxed.rs @@ -8,18 +8,17 @@ use std::{ use futures_lite::FutureExt; use futures_sink::Sink; -#[cfg(feature = "quinn-transport")] -use futures_util::TryStreamExt; -use futures_util::{future::BoxFuture, SinkExt, Stream, StreamExt}; +use futures_util::{future::BoxFuture, SinkExt, Stream, StreamExt, TryStreamExt}; use pin_project::pin_project; use std::future::Future; use crate::RpcMessage; -use super::{ConnectionCommon, ConnectionErrors}; +use super::{ConnectionErrors, StreamTypes}; type BoxedFuture<'a, T> = Pin + Send + Sync + 'a>>; enum SendSinkInner { + #[cfg(feature = "flume-transport")] Direct(::flume::r#async::SendSink<'static, T>), Boxed(Pin + Send + Sync + 'static>>), } @@ -39,6 +38,7 @@ impl SendSink { } /// Create a new send sink from a direct flume send sink + #[cfg(feature = "flume-transport")] pub(crate) fn direct(sink: ::flume::r#async::SendSink<'static, T>) -> Self { Self(SendSinkInner::Direct(sink)) } @@ -52,6 +52,7 @@ impl Sink for SendSink { cx: &mut std::task::Context<'_>, ) -> Poll> { match self.project().0 { + #[cfg(feature = "flume-transport")] SendSinkInner::Direct(sink) => sink.poll_ready_unpin(cx).map_err(anyhow::Error::from), SendSinkInner::Boxed(sink) => sink.poll_ready_unpin(cx).map_err(anyhow::Error::from), } @@ -59,6 +60,7 @@ impl Sink for SendSink { fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> { match self.project().0 { + #[cfg(feature = "flume-transport")] SendSinkInner::Direct(sink) => sink.start_send_unpin(item).map_err(anyhow::Error::from), SendSinkInner::Boxed(sink) => sink.start_send_unpin(item).map_err(anyhow::Error::from), } @@ -69,6 +71,7 @@ impl Sink for SendSink { cx: &mut Context<'_>, ) -> Poll> { match self.project().0 { + #[cfg(feature = "flume-transport")] SendSinkInner::Direct(sink) => sink.poll_flush_unpin(cx).map_err(anyhow::Error::from), SendSinkInner::Boxed(sink) => sink.poll_flush_unpin(cx).map_err(anyhow::Error::from), } @@ -79,6 +82,7 @@ impl Sink for SendSink { cx: &mut Context<'_>, ) -> Poll> { match self.project().0 { + #[cfg(feature = "flume-transport")] SendSinkInner::Direct(sink) => sink.poll_close_unpin(cx).map_err(anyhow::Error::from), SendSinkInner::Boxed(sink) => sink.poll_close_unpin(cx).map_err(anyhow::Error::from), } @@ -86,6 +90,7 @@ impl Sink for SendSink { } enum RecvStreamInner { + #[cfg(feature = "flume-transport")] Direct(::flume::r#async::RecvStream<'static, T>), Boxed(Pin> + Send + Sync + 'static>>), } @@ -106,6 +111,7 @@ impl RecvStream { } /// Create a new receive stream from a direct flume receive stream + #[cfg(feature = "flume-transport")] pub(crate) fn direct(stream: ::flume::r#async::RecvStream<'static, T>) -> Self { Self(RecvStreamInner::Direct(stream)) } @@ -116,6 +122,7 @@ impl Stream for RecvStream { fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project().0 { + #[cfg(feature = "flume-transport")] RecvStreamInner::Direct(stream) => match stream.poll_next_unpin(cx) { Poll::Ready(Some(item)) => Poll::Ready(Some(Ok(item))), Poll::Ready(None) => Poll::Ready(None), @@ -128,7 +135,8 @@ impl Stream for RecvStream { enum OpenFutureInner<'a, In: RpcMessage, Out: RpcMessage> { /// A direct future (todo) - Direct(super::flume::OpenBiFuture), + #[cfg(feature = "flume-transport")] + Direct(super::flume::OpenFuture), /// A boxed future Boxed(BoxFuture<'a, anyhow::Result<(SendSink, RecvStream)>>), } @@ -138,13 +146,14 @@ enum OpenFutureInner<'a, In: RpcMessage, Out: RpcMessage> { pub struct OpenFuture<'a, In: RpcMessage, Out: RpcMessage>(OpenFutureInner<'a, In, Out>); impl<'a, In: RpcMessage, Out: RpcMessage> OpenFuture<'a, In, Out> { - fn direct(f: super::flume::OpenBiFuture) -> Self { + #[cfg(feature = "flume-transport")] + fn direct(f: super::flume::OpenFuture) -> Self { Self(OpenFutureInner::Direct(f)) } /// Create a new boxed future pub fn boxed( - f: impl Future, RecvStream)>> + Send + Sync + 'a, + f: impl Future, RecvStream)>> + Send + 'a, ) -> Self { Self(OpenFutureInner::Boxed(Box::pin(f))) } @@ -155,6 +164,7 @@ impl<'a, In: RpcMessage, Out: RpcMessage> Future for OpenFuture<'a, In, Out> { fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { match self.project().0 { + #[cfg(feature = "flume-transport")] OpenFutureInner::Direct(f) => f .poll(cx) .map_ok(|(send, recv)| (SendSink::direct(send.0), RecvStream::direct(recv.0))) @@ -166,7 +176,8 @@ impl<'a, In: RpcMessage, Out: RpcMessage> Future for OpenFuture<'a, In, Out> { enum AcceptFutureInner<'a, In: RpcMessage, Out: RpcMessage> { /// A direct future - Direct(super::flume::AcceptBiFuture), + #[cfg(feature = "flume-transport")] + Direct(super::flume::AcceptFuture), /// A boxed future Boxed(BoxedFuture<'a, anyhow::Result<(SendSink, RecvStream)>>), } @@ -176,7 +187,8 @@ enum AcceptFutureInner<'a, In: RpcMessage, Out: RpcMessage> { pub struct AcceptFuture<'a, In: RpcMessage, Out: RpcMessage>(AcceptFutureInner<'a, In, Out>); impl<'a, In: RpcMessage, Out: RpcMessage> AcceptFuture<'a, In, Out> { - fn direct(f: super::flume::AcceptBiFuture) -> Self { + #[cfg(feature = "flume-transport")] + fn direct(f: super::flume::AcceptFuture) -> Self { Self(AcceptFutureInner::Direct(f)) } @@ -193,6 +205,7 @@ impl<'a, In: RpcMessage, Out: RpcMessage> Future for AcceptFuture<'a, In, Out> { fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { match self.project().0 { + #[cfg(feature = "flume-transport")] AcceptFutureInner::Direct(f) => f .poll(cx) .map_ok(|(send, recv)| (SendSink::direct(send.0), RecvStream::direct(recv.0))) @@ -202,57 +215,84 @@ impl<'a, In: RpcMessage, Out: RpcMessage> Future for AcceptFuture<'a, In, Out> { } } -/// A boxable connection -pub trait BoxableConnection: - Debug + Send + Sync + 'static -{ +/// A boxable connector +pub trait BoxableConnector: Debug + Send + Sync + 'static { /// Clone the connection and box it - fn clone_box(&self) -> Box>; + fn clone_box(&self) -> Box>; /// Open a channel to the remote che fn open_boxed(&self) -> OpenFuture; } -/// A boxed connection +/// A boxed connector #[derive(Debug)] -pub struct Connection(Box>); +pub struct BoxedConnector(Box>); -impl Connection { - /// Wrap a boxable server endpoint into a box, transforming all the types to concrete types - pub fn new(x: impl BoxableConnection) -> Self { +impl BoxedConnector { + /// Wrap a boxable connector into a box, transforming all the types to concrete types + pub fn new(x: impl BoxableConnector) -> Self { Self(Box::new(x)) } } -impl Clone for Connection { +impl Clone for BoxedConnector { fn clone(&self) -> Self { Self(self.0.clone_box()) } } -impl ConnectionCommon for Connection { +impl StreamTypes for BoxedConnector { + type In = In; + type Out = Out; type RecvStream = RecvStream; type SendSink = SendSink; } -impl ConnectionErrors for Connection { - type OpenError = anyhow::Error; +impl ConnectionErrors for BoxedConnector { type SendError = anyhow::Error; type RecvError = anyhow::Error; + type OpenError = anyhow::Error; + type AcceptError = anyhow::Error; } -impl super::Connection for Connection { +impl super::Connector for BoxedConnector { async fn open(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> { self.0.open_boxed().await } } -/// A boxable server endpoint -pub trait BoxableServerEndpoint: - Debug + Send + Sync + 'static -{ - /// Clone the server endpoint and box it - fn clone_box(&self) -> Box>; +/// Stream types for boxed streams +#[derive(Debug)] +pub struct BoxedStreamTypes { + _p: std::marker::PhantomData<(In, Out)>, +} + +impl Clone for BoxedStreamTypes { + fn clone(&self) -> Self { + Self { + _p: std::marker::PhantomData, + } + } +} + +impl ConnectionErrors for BoxedStreamTypes { + type SendError = anyhow::Error; + type RecvError = anyhow::Error; + type OpenError = anyhow::Error; + type AcceptError = anyhow::Error; +} + +impl StreamTypes for BoxedStreamTypes { + type In = In; + type Out = Out; + type RecvStream = RecvStream; + type SendSink = SendSink; +} + +/// A boxable listener +pub trait BoxableListener: Debug + Send + Sync + 'static { + /// Clone the listener and box it + fn clone_box(&self) -> Box>; /// Accept a channel from a remote client fn accept_bi_boxed(&self) -> AcceptFuture; @@ -261,38 +301,41 @@ pub trait BoxableServerEndpoint: fn local_addr(&self) -> &[super::LocalAddr]; } -/// A boxed server endpoint +/// A boxed listener #[derive(Debug)] -pub struct ServerEndpoint(Box>); +pub struct BoxedListener(Box>); -impl ServerEndpoint { - /// Wrap a boxable server endpoint into a box, transforming all the types to concrete types - pub fn new(x: impl BoxableServerEndpoint) -> Self { +impl BoxedListener { + /// Wrap a boxable listener into a box, transforming all the types to concrete types + pub fn new(x: impl BoxableListener) -> Self { Self(Box::new(x)) } } -impl Clone for ServerEndpoint { +impl Clone for BoxedListener { fn clone(&self) -> Self { Self(self.0.clone_box()) } } -impl ConnectionCommon for ServerEndpoint { +impl StreamTypes for BoxedListener { + type In = In; + type Out = Out; type RecvStream = RecvStream; type SendSink = SendSink; } -impl ConnectionErrors for ServerEndpoint { - type OpenError = anyhow::Error; +impl ConnectionErrors for BoxedListener { type SendError = anyhow::Error; type RecvError = anyhow::Error; + type OpenError = anyhow::Error; + type AcceptError = anyhow::Error; } -impl super::ServerEndpoint for ServerEndpoint { +impl super::Listener for BoxedListener { fn accept( &self, - ) -> impl Future> + Send + ) -> impl Future> + Send { self.0.accept_bi_boxed() } @@ -301,18 +344,27 @@ impl super::ServerEndpoint for ServerE self.0.local_addr() } } +impl BoxableConnector for BoxedConnector { + fn clone_box(&self) -> Box> { + Box::new(self.clone()) + } + + fn open_boxed(&self) -> OpenFuture { + OpenFuture::boxed(crate::transport::Connector::open(self)) + } +} #[cfg(feature = "quinn-transport")] -impl BoxableConnection - for super::quinn::QuinnConnection +impl BoxableConnector + for super::quinn::QuinnConnector { - fn clone_box(&self) -> Box> { + fn clone_box(&self) -> Box> { Box::new(self.clone()) } fn open_boxed(&self) -> OpenFuture { let f = Box::pin(async move { - let (send, recv) = super::Connection::open(self).await?; + let (send, recv) = super::Connector::open(self).await?; // map the error types to anyhow let send = send.sink_map_err(anyhow::Error::from); let recv = recv.map_err(anyhow::Error::from); @@ -324,16 +376,16 @@ impl BoxableConnection } #[cfg(feature = "quinn-transport")] -impl BoxableServerEndpoint - for super::quinn::QuinnServerEndpoint +impl BoxableListener + for super::quinn::QuinnListener { - fn clone_box(&self) -> Box> { + fn clone_box(&self) -> Box> { Box::new(self.clone()) } fn accept_bi_boxed(&self) -> AcceptFuture { let f = async move { - let (send, recv) = super::ServerEndpoint::accept(self).await?; + let (send, recv) = super::Listener::accept(self).await?; let send = send.sink_map_err(anyhow::Error::from); let recv = recv.map_err(anyhow::Error::from); anyhow::Ok((SendSink::boxed(send), RecvStream::boxed(recv))) @@ -342,37 +394,65 @@ impl BoxableServerEndpoint } fn local_addr(&self) -> &[super::LocalAddr] { - super::ServerEndpoint::local_addr(self) + super::Listener::local_addr(self) } } #[cfg(feature = "flume-transport")] -impl BoxableConnection - for super::flume::FlumeConnection +impl BoxableConnector + for super::flume::FlumeConnector { - fn clone_box(&self) -> Box> { + fn clone_box(&self) -> Box> { Box::new(self.clone()) } fn open_boxed(&self) -> OpenFuture { - OpenFuture::direct(super::Connection::open(self)) + OpenFuture::direct(super::Connector::open(self)) } } #[cfg(feature = "flume-transport")] -impl BoxableServerEndpoint - for super::flume::FlumeServerEndpoint +impl BoxableListener + for super::flume::FlumeListener { - fn clone_box(&self) -> Box> { + fn clone_box(&self) -> Box> { Box::new(self.clone()) } fn accept_bi_boxed(&self) -> AcceptFuture { - AcceptFuture::direct(super::ServerEndpoint::accept(self)) + AcceptFuture::direct(super::Listener::accept(self)) } fn local_addr(&self) -> &[super::LocalAddr] { - super::ServerEndpoint::local_addr(self) + super::Listener::local_addr(self) + } +} + +impl BoxableConnector for super::mapped::MappedConnector +where + In: RpcMessage, + Out: RpcMessage, + C: super::Connector, + C::Out: From, + In: TryFrom, + C::SendError: Into, + C::RecvError: Into, + C::OpenError: Into, +{ + fn clone_box(&self) -> Box> { + Box::new(self.clone()) + } + + fn open_boxed(&self) -> OpenFuture { + let f = Box::pin(async move { + let (send, recv) = super::Connector::open(self).await.map_err(|e| e.into())?; + // map the error types to anyhow + let send = send.sink_map_err(|e| e.into()); + let recv = recv.map_err(|e| e.into()); + // return the boxed streams + anyhow::Ok((SendSink::boxed(send), RecvStream::boxed(recv))) + }); + OpenFuture::boxed(f) } } @@ -394,11 +474,11 @@ mod tests { use futures_lite::StreamExt; use futures_util::SinkExt; - use crate::transport::{Connection, ServerEndpoint}; + use crate::transport::{Connector, Listener}; - let (server, client) = crate::transport::flume::service_connection::(1); - let server = super::ServerEndpoint::new(server); - let client = super::Connection::new(client); + let (server, client) = crate::transport::flume::channel(1); + let server = super::BoxedListener::new(server); + let client = super::BoxedConnector::new(client); // spawn echo server tokio::spawn(async move { while let Ok((mut send, mut recv)) = server.accept().await { diff --git a/src/transport/combined.rs b/src/transport/combined.rs index 8c27805..1829a3f 100644 --- a/src/transport/combined.rs +++ b/src/transport/combined.rs @@ -1,83 +1,51 @@ //! Transport that combines two other transports -use super::{Connection, ConnectionCommon, ConnectionErrors, LocalAddr, ServerEndpoint}; -use crate::RpcMessage; +use super::{ConnectionErrors, Connector, Listener, LocalAddr, StreamTypes}; use futures_lite::Stream; use futures_sink::Sink; use pin_project::pin_project; use std::{ error, fmt, fmt::Debug, - marker::PhantomData, pin::Pin, task::{Context, Poll}, }; /// A connection that combines two other connections -pub struct CombinedConnection { +#[derive(Debug, Clone)] +pub struct CombinedConnector { /// First connection pub a: Option, /// Second connection pub b: Option, - /// Phantom data so we can have `S` as type parameters - _p: PhantomData<(In, Out)>, } -impl, B: Connection, In, Out> CombinedConnection { +impl> CombinedConnector { /// Create a combined connection from two other connections /// /// It will always use the first connection that is not `None`. pub fn new(a: Option, b: Option) -> Self { - Self { - a, - b, - _p: PhantomData, - } - } -} -impl Clone - for CombinedConnection -{ - fn clone(&self) -> Self { - Self { - a: self.a.clone(), - b: self.b.clone(), - _p: PhantomData, - } - } -} - -impl Debug - for CombinedConnection -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("CombinedConnection") - .field("a", &self.a) - .field("b", &self.b) - .finish() + Self { a, b } } } /// An endpoint that combines two other endpoints -pub struct CombinedServerEndpoint { +#[derive(Debug, Clone)] +pub struct CombinedListener { /// First endpoint pub a: Option, /// Second endpoint pub b: Option, /// Local addresses from all endpoints local_addr: Vec, - /// Phantom data so we can have `S` as type parameters - _p: PhantomData<(In, Out)>, } -impl, B: ServerEndpoint, In: RpcMessage, Out: RpcMessage> - CombinedServerEndpoint -{ - /// Create a combined server endpoint from two other server endpoints +impl> CombinedListener { + /// Create a combined listener from two other listeners /// /// When listening for incoming connections with - /// [crate::ServerEndpoint::accept], all configured channels will be listened on, + /// [`Listener::accept`], all configured channels will be listened on, /// and the first to receive a connection will be used. If no channels are configured, - /// accept_bi will not throw an error but wait forever. + /// accept will not throw an error but just wait forever. pub fn new(a: Option, b: Option) -> Self { let mut local_addr = Vec::with_capacity(2); if let Some(a) = &a { @@ -86,12 +54,7 @@ impl, B: ServerEndpoint, In: RpcMessage, Out if let Some(b) = &b { local_addr.extend(b.local_addr().iter().cloned()) }; - Self { - a, - b, - local_addr, - _p: PhantomData, - } + Self { a, b, local_addr } } /// Get back the inner endpoints @@ -100,51 +63,16 @@ impl, B: ServerEndpoint, In: RpcMessage, Out } } -impl Clone - for CombinedServerEndpoint -{ - fn clone(&self) -> Self { - Self { - a: self.a.clone(), - b: self.b.clone(), - local_addr: self.local_addr.clone(), - _p: PhantomData, - } - } -} - -impl Debug - for CombinedServerEndpoint -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("CombinedServerEndpoint") - .field("a", &self.a) - .field("b", &self.b) - .finish() - } -} - /// Send sink for combined channels #[pin_project(project = SendSinkProj)] -pub enum SendSink< - A: ConnectionCommon, - B: ConnectionCommon, - In: RpcMessage, - Out: RpcMessage, -> { +pub enum SendSink { /// A variant A(#[pin] A::SendSink), /// B variant B(#[pin] B::SendSink), } -impl< - A: ConnectionCommon, - B: ConnectionCommon, - In: RpcMessage, - Out: RpcMessage, - > Sink for SendSink -{ +impl> Sink for SendSink { type Error = self::SendError; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -154,7 +82,7 @@ impl< } } - fn start_send(self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> { + fn start_send(self: Pin<&mut Self>, item: A::Out) -> Result<(), Self::Error> { match self.project() { SendSinkProj::A(sink) => sink.start_send(item).map_err(Self::Error::A), SendSinkProj::B(sink) => sink.start_send(item).map_err(Self::Error::B), @@ -178,26 +106,15 @@ impl< /// RecvStream for combined channels #[pin_project(project = ResStreamProj)] -pub enum RecvStream< - A: ConnectionCommon, - B: ConnectionCommon, - In: RpcMessage, - Out: RpcMessage, -> { +pub enum RecvStream { /// A variant A(#[pin] A::RecvStream), /// B variant B(#[pin] B::RecvStream), } -impl< - A: ConnectionCommon, - B: ConnectionCommon, - In: RpcMessage, - Out: RpcMessage, - > Stream for RecvStream -{ - type Item = Result>; +impl> Stream for RecvStream { + type Item = Result>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { @@ -241,9 +158,9 @@ impl fmt::Display for RecvError impl error::Error for RecvError {} -/// OpenBiError for combined channels +/// OpenError for combined channels #[derive(Debug)] -pub enum OpenBiError { +pub enum OpenError { /// A variant A(A::OpenError), /// B variant @@ -252,86 +169,80 @@ pub enum OpenBiError { NoChannel, } -impl fmt::Display for OpenBiError { +impl fmt::Display for OpenError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Debug::fmt(self, f) } } -impl error::Error for OpenBiError {} +impl error::Error for OpenError {} -/// AcceptBiError for combined channels +/// AcceptError for combined channels #[derive(Debug)] -pub enum AcceptBiError { +pub enum AcceptError { /// A variant - A(A::OpenError), + A(A::AcceptError), /// B variant - B(B::OpenError), + B(B::AcceptError), } -impl fmt::Display for AcceptBiError { +impl fmt::Display for AcceptError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Debug::fmt(self, f) } } -impl error::Error for AcceptBiError {} +impl error::Error for AcceptError {} -impl ConnectionErrors - for CombinedConnection -{ +impl ConnectionErrors for CombinedConnector { type SendError = self::SendError; type RecvError = self::RecvError; - type OpenError = self::OpenBiError; + type OpenError = self::OpenError; + type AcceptError = self::AcceptError; } -impl, B: Connection, In: RpcMessage, Out: RpcMessage> - ConnectionCommon for CombinedConnection -{ - type RecvStream = self::RecvStream; - type SendSink = self::SendSink; +impl> StreamTypes for CombinedConnector { + type In = A::In; + type Out = A::Out; + type RecvStream = self::RecvStream; + type SendSink = self::SendSink; } -impl, B: Connection, In: RpcMessage, Out: RpcMessage> - Connection for CombinedConnection -{ +impl> Connector for CombinedConnector { async fn open(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> { let this = self.clone(); // try a first, then b if let Some(a) = this.a { - let (send, recv) = a.open().await.map_err(OpenBiError::A)?; + let (send, recv) = a.open().await.map_err(OpenError::A)?; Ok((SendSink::A(send), RecvStream::A(recv))) } else if let Some(b) = this.b { - let (send, recv) = b.open().await.map_err(OpenBiError::B)?; + let (send, recv) = b.open().await.map_err(OpenError::B)?; Ok((SendSink::B(send), RecvStream::B(recv))) } else { - Err(OpenBiError::NoChannel) + Err(OpenError::NoChannel) } } } -impl ConnectionErrors - for CombinedServerEndpoint -{ +impl ConnectionErrors for CombinedListener { type SendError = self::SendError; type RecvError = self::RecvError; - type OpenError = self::AcceptBiError; + type OpenError = self::OpenError; + type AcceptError = self::AcceptError; } -impl, B: ServerEndpoint, In: RpcMessage, Out: RpcMessage> - ConnectionCommon for CombinedServerEndpoint -{ - type RecvStream = self::RecvStream; - type SendSink = self::SendSink; +impl> StreamTypes for CombinedListener { + type In = A::In; + type Out = A::Out; + type RecvStream = self::RecvStream; + type SendSink = self::SendSink; } -impl, B: ServerEndpoint, In: RpcMessage, Out: RpcMessage> - ServerEndpoint for CombinedServerEndpoint -{ - async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> { +impl> Listener for CombinedListener { + async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::AcceptError> { let a_fut = async { if let Some(a) = &self.a { - let (send, recv) = a.accept().await.map_err(AcceptBiError::A)?; + let (send, recv) = a.accept().await.map_err(AcceptError::A)?; Ok((SendSink::A(send), RecvStream::A(recv))) } else { std::future::pending().await @@ -339,7 +250,7 @@ impl, B: ServerEndpoint, In: RpcMessage, Out }; let b_fut = async { if let Some(b) = &self.b { - let (send, recv) = b.accept().await.map_err(AcceptBiError::B)?; + let (send, recv) = b.accept().await.map_err(AcceptError::B)?; Ok((SendSink::B(send), RecvStream::B(recv))) } else { std::future::pending().await @@ -360,24 +271,20 @@ impl, B: ServerEndpoint, In: RpcMessage, Out } #[cfg(test)] +#[cfg(feature = "flume-transport")] mod tests { - use crate::{ - transport::{ - combined::{self, OpenBiError}, - flume, - }, - Connection, + use crate::transport::{ + combined::{self, OpenError}, + flume, Connector, }; #[tokio::test] async fn open_empty_channel() { - let channel = combined::CombinedConnection::< - flume::FlumeConnection<(), ()>, - flume::FlumeConnection<(), ()>, - (), - (), + let channel = combined::CombinedConnector::< + flume::FlumeConnector<(), ()>, + flume::FlumeConnector<(), ()>, >::new(None, None); let res = channel.open().await; - assert!(matches!(res, Err(OpenBiError::NoChannel))); + assert!(matches!(res, Err(OpenError::NoChannel))); } } diff --git a/src/transport/flume.rs b/src/transport/flume.rs index f12e0f6..2671953 100644 --- a/src/transport/flume.rs +++ b/src/transport/flume.rs @@ -5,13 +5,13 @@ use futures_lite::{Future, Stream}; use futures_sink::Sink; use crate::{ - transport::{Connection, ConnectionErrors, LocalAddr, ServerEndpoint}, + transport::{ConnectionErrors, Connector, Listener, LocalAddr}, RpcMessage, }; use core::fmt; use std::{error, fmt::Display, marker::PhantomData, pin::Pin, result, task::Poll}; -use super::ConnectionCommon; +use super::StreamTypes; /// Error when receiving from a channel /// @@ -97,15 +97,15 @@ impl Stream for RecvStream { impl error::Error for RecvError {} -/// A flume based server endpoint. +/// A flume based listener. /// -/// Created using [connection]. -pub struct FlumeServerEndpoint { +/// Created using [channel]. +pub struct FlumeListener { #[allow(clippy::type_complexity)] stream: flume::Receiver<(SendSink, RecvStream)>, } -impl Clone for FlumeServerEndpoint { +impl Clone for FlumeListener { fn clone(&self) -> Self { Self { stream: self.stream.clone(), @@ -113,37 +113,36 @@ impl Clone for FlumeServerEndpoint { } } -impl fmt::Debug for FlumeServerEndpoint { +impl fmt::Debug for FlumeListener { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("FlumeServerEndpoint") + f.debug_struct("FlumeListener") .field("stream", &self.stream) .finish() } } -impl ConnectionErrors for FlumeServerEndpoint { +impl ConnectionErrors for FlumeListener { type SendError = self::SendError; - type RecvError = self::RecvError; - - type OpenError = self::AcceptBiError; + type OpenError = self::OpenError; + type AcceptError = self::AcceptError; } type Socket = (self::SendSink, self::RecvStream); -/// Future returned by [FlumeConnection::open] -pub struct OpenBiFuture { +/// Future returned by [FlumeConnector::open] +pub struct OpenFuture { inner: flume::r#async::SendFut<'static, Socket>, res: Option>, } -impl fmt::Debug for OpenBiFuture { +impl fmt::Debug for OpenFuture { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("OpenBiFuture").finish() + f.debug_struct("OpenFuture").finish() } } -impl OpenBiFuture { +impl OpenFuture { fn new(inner: flume::r#async::SendFut<'static, Socket>, res: Socket) -> Self { Self { inner, @@ -152,8 +151,8 @@ impl OpenBiFuture { } } -impl Future for OpenBiFuture { - type Output = result::Result, self::OpenBiError>; +impl Future for OpenFuture { + type Output = result::Result, self::OpenError>; fn poll( mut self: Pin<&mut Self>, @@ -165,45 +164,47 @@ impl Future for OpenBiFuture { .take() .map(|x| Poll::Ready(Ok(x))) .unwrap_or(Poll::Pending), - Poll::Ready(Err(_)) => Poll::Ready(Err(self::OpenBiError::RemoteDropped)), + Poll::Ready(Err(_)) => Poll::Ready(Err(self::OpenError::RemoteDropped)), Poll::Pending => Poll::Pending, } } } -/// Future returned by [FlumeServerEndpoint::accept] -pub struct AcceptBiFuture { +/// Future returned by [FlumeListener::accept] +pub struct AcceptFuture { wrapped: flume::r#async::RecvFut<'static, (SendSink, RecvStream)>, _p: PhantomData<(In, Out)>, } -impl fmt::Debug for AcceptBiFuture { +impl fmt::Debug for AcceptFuture { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("AcceptBiFuture").finish() + f.debug_struct("AcceptFuture").finish() } } -impl Future for AcceptBiFuture { - type Output = result::Result<(SendSink, RecvStream), AcceptBiError>; +impl Future for AcceptFuture { + type Output = result::Result<(SendSink, RecvStream), AcceptError>; fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { match Pin::new(&mut self.wrapped).poll(cx) { Poll::Ready(Ok((send, recv))) => Poll::Ready(Ok((send, recv))), - Poll::Ready(Err(_)) => Poll::Ready(Err(AcceptBiError::RemoteDropped)), + Poll::Ready(Err(_)) => Poll::Ready(Err(AcceptError::RemoteDropped)), Poll::Pending => Poll::Pending, } } } -impl ConnectionCommon for FlumeServerEndpoint { +impl StreamTypes for FlumeListener { + type In = In; + type Out = Out; type SendSink = SendSink; type RecvStream = RecvStream; } -impl ServerEndpoint for FlumeServerEndpoint { +impl Listener for FlumeListener { #[allow(refining_impl_trait)] - fn accept(&self) -> AcceptBiFuture { - AcceptBiFuture { + fn accept(&self) -> AcceptFuture { + AcceptFuture { wrapped: self.stream.clone().into_recv_async(), _p: PhantomData, } @@ -214,22 +215,23 @@ impl ServerEndpoint for FlumeServerEnd } } -impl ConnectionErrors for FlumeConnection { +impl ConnectionErrors for FlumeConnector { type SendError = self::SendError; - type RecvError = self::RecvError; - - type OpenError = self::OpenBiError; + type OpenError = self::OpenError; + type AcceptError = self::AcceptError; } -impl ConnectionCommon for FlumeConnection { +impl StreamTypes for FlumeConnector { + type In = In; + type Out = Out; type SendSink = SendSink; type RecvStream = RecvStream; } -impl Connection for FlumeConnection { +impl Connector for FlumeConnector { #[allow(refining_impl_trait)] - fn open(&self) -> OpenBiFuture { + fn open(&self) -> OpenFuture { let (local_send, remote_recv) = flume::bounded::(128); let (remote_send, local_recv) = flume::bounded::(128); let remote_chan = ( @@ -240,19 +242,19 @@ impl Connection for FlumeConnection { +/// Created using [channel]. +pub struct FlumeConnector { #[allow(clippy::type_complexity)] sink: flume::Sender<(SendSink, RecvStream)>, } -impl Clone for FlumeConnection { +impl Clone for FlumeConnector { fn clone(&self) -> Self { Self { sink: self.sink.clone(), @@ -260,7 +262,7 @@ impl Clone for FlumeConnection { } } -impl fmt::Debug for FlumeConnection { +impl fmt::Debug for FlumeConnector { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FlumeClientChannel") .field("sink", &self.sink) @@ -268,22 +270,22 @@ impl fmt::Debug for FlumeConnection { } } -/// AcceptBiError for mem channels. +/// AcceptError for mem channels. /// /// There is not much that can go wrong with mem channels. #[derive(Debug)] -pub enum AcceptBiError { +pub enum AcceptError { /// The remote side of the channel was dropped RemoteDropped, } -impl fmt::Display for AcceptBiError { +impl fmt::Display for AcceptError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Debug::fmt(self, f) } } -impl error::Error for AcceptBiError {} +impl error::Error for AcceptError {} /// SendError for mem channels. /// @@ -302,20 +304,20 @@ impl Display for SendError { impl std::error::Error for SendError {} -/// OpenBiError for mem channels. +/// OpenError for mem channels. #[derive(Debug)] -pub enum OpenBiError { +pub enum OpenError { /// The remote side of the channel was dropped RemoteDropped, } -impl Display for OpenBiError { +impl Display for OpenError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Debug::fmt(self, f) } } -impl std::error::Error for OpenBiError {} +impl std::error::Error for OpenError {} /// CreateChannelError for mem channels. /// @@ -332,23 +334,12 @@ impl Display for CreateChannelError { impl std::error::Error for CreateChannelError {} -/// Create a flume server endpoint and a connected flume client channel. +/// Create a flume listener and a connected flume connector. /// /// `buffer` the size of the buffer for each channel. Keep this at a low value to get backpressure -pub fn connection( +pub fn channel( buffer: usize, -) -> (FlumeServerEndpoint, FlumeConnection) { +) -> (FlumeListener, FlumeConnector) { let (sink, stream) = flume::bounded(buffer); - (FlumeServerEndpoint { stream }, FlumeConnection { sink }) -} - -/// Create a flume server endpoint and a connected flume client channel for a specific service. -#[allow(clippy::type_complexity)] -pub fn service_connection( - buffer: usize, -) -> ( - FlumeServerEndpoint, - FlumeConnection, -) { - connection(buffer) + (FlumeListener { stream }, FlumeConnector { sink }) } diff --git a/src/transport/hyper.rs b/src/transport/hyper.rs index 42be44b..131c576 100644 --- a/src/transport/hyper.rs +++ b/src/transport/hyper.rs @@ -6,7 +6,7 @@ use std::{ sync::Arc, task::Poll, }; -use crate::transport::{Connection, ConnectionErrors, LocalAddr, ServerEndpoint}; +use crate::transport::{ConnectionErrors, Connector, Listener, LocalAddr, StreamTypes}; use crate::RpcMessage; use bytes::Bytes; use flume::{Receiver, Sender}; @@ -22,8 +22,6 @@ use tokio::sync::mpsc; use tokio::task::JoinHandle; use tracing::{debug, event, trace, Level}; -use super::ConnectionCommon; - struct HyperConnectionInner { client: Box, config: Arc, @@ -31,12 +29,12 @@ struct HyperConnectionInner { } /// Hyper based connection to a server -pub struct HyperConnection { +pub struct HyperConnector { inner: Arc, _p: PhantomData<(In, Out)>, } -impl Clone for HyperConnection { +impl Clone for HyperConnector { fn clone(&self) -> Self { Self { inner: self.inner.clone(), @@ -56,7 +54,7 @@ impl Requester for Client { } } -impl HyperConnection { +impl HyperConnector { /// create a client given an uri and the default configuration pub fn new(uri: Uri) -> Self { Self::with_config(uri, ChannelConfig::default()) @@ -93,7 +91,7 @@ impl HyperConnection { } } -impl fmt::Debug for HyperConnection { +impl fmt::Debug for HyperConnector { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ClientChannel") .field("uri", &self.inner.uri) @@ -164,7 +162,7 @@ impl Default for ChannelConfig { } } -/// A server endpoint using a hyper server +/// A listener using a hyper server /// /// Each request made by the any client connection this channel will yield a `(recv, send)` /// pair which allows receiving the request and sending the response. Both these are @@ -173,7 +171,7 @@ impl Default for ChannelConfig { /// Creating this spawns a tokio task which runs the server, once dropped this task is shut /// down: no new connections will be accepted and existing channels will stop. #[derive(Debug)] -pub struct HyperServerEndpoint { +pub struct HyperListener { /// The channel. channel: Receiver>, /// The configuration. @@ -192,7 +190,7 @@ pub struct HyperServerEndpoint { _p: PhantomData<(In, Out)>, } -impl HyperServerEndpoint { +impl HyperListener { /// Creates a server listening on the [`SocketAddr`], with the default configuration. pub fn serve(addr: &SocketAddr) -> hyper::Result { Self::serve_with_config(addr, Default::default()) @@ -365,7 +363,7 @@ fn spawn_recv_forwarder( // This does not want or need RpcMessage to be clone but still want to clone the // ServerChannel and it's containing channels itself. The derive macro can't cope with this // so this needs to be written by hand. -impl Clone for HyperServerEndpoint { +impl Clone for HyperListener { fn clone(&self) -> Self { Self { channel: self.channel.clone(), @@ -536,9 +534,9 @@ impl fmt::Display for RecvError { impl error::Error for RecvError {} -/// OpenBiError for hyper channels. +/// OpenError for hyper channels. #[derive(Debug)] -pub enum OpenBiError { +pub enum OpenError { /// Hyper http error HyperHttp(hyper::http::Error), /// Generic hyper error @@ -547,59 +545,62 @@ pub enum OpenBiError { RemoteDropped, } -impl fmt::Display for OpenBiError { +impl fmt::Display for OpenError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Debug::fmt(self, f) } } -impl std::error::Error for OpenBiError {} +impl std::error::Error for OpenError {} -/// AcceptBiError for hyper channels. +/// AcceptError for hyper channels. /// /// There is not much that can go wrong with hyper channels. #[derive(Debug)] -pub enum AcceptBiError { +pub enum AcceptError { /// Hyper error Hyper(hyper::http::Error), /// The remote side of the channel was dropped RemoteDropped, } -impl fmt::Display for AcceptBiError { +impl fmt::Display for AcceptError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Debug::fmt(self, f) } } -impl error::Error for AcceptBiError {} +impl error::Error for AcceptError {} -impl ConnectionErrors for HyperConnection { +impl ConnectionErrors for HyperConnector { type SendError = self::SendError; type RecvError = self::RecvError; - type OpenError = OpenBiError; + type OpenError = OpenError; + + type AcceptError = AcceptError; } -impl ConnectionCommon for HyperConnection { +impl StreamTypes for HyperConnector { + type In = In; + type Out = Out; type RecvStream = self::RecvStream; - type SendSink = self::SendSink; } -impl Connection for HyperConnection { +impl Connector for HyperConnector { async fn open(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> { let (out_tx, out_rx) = flume::bounded::>(32); let req: Request = Request::post(&self.inner.uri) .body(Body::wrap_stream(out_rx.into_stream())) - .map_err(OpenBiError::HyperHttp)?; + .map_err(OpenError::HyperHttp)?; let res = self .inner .client .request(req) .await - .map_err(OpenBiError::Hyper)?; + .map_err(OpenError::Hyper)?; let (in_tx, in_rx) = flume::bounded::>(32); spawn_recv_forwarder(res.into_body(), in_tx); @@ -609,30 +610,31 @@ impl Connection for HyperConnection ConnectionErrors for HyperServerEndpoint { +impl ConnectionErrors for HyperListener { type SendError = self::SendError; - type RecvError = self::RecvError; - - type OpenError = AcceptBiError; + type OpenError = AcceptError; + type AcceptError = AcceptError; } -impl ConnectionCommon for HyperServerEndpoint { +impl StreamTypes for HyperListener { + type In = In; + type Out = Out; type RecvStream = self::RecvStream; type SendSink = self::SendSink; } -impl ServerEndpoint for HyperServerEndpoint { +impl Listener for HyperListener { fn local_addr(&self) -> &[LocalAddr] { &self.local_addr } - async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptBiError> { + async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> { let (recv, send) = self .channel .recv_async() .await - .map_err(|_| AcceptBiError::RemoteDropped)?; + .map_err(|_| AcceptError::RemoteDropped)?; Ok(( SendSink::new(send, self.config.clone()), RecvStream::new(recv), diff --git a/src/transport/mapped.rs b/src/transport/mapped.rs new file mode 100644 index 0000000..fbdee41 --- /dev/null +++ b/src/transport/mapped.rs @@ -0,0 +1,326 @@ +//! Transport with mapped input and output types. +use std::{ + fmt::{Debug, Display}, + marker::PhantomData, + task::{Context, Poll}, +}; + +use futures_lite::{Stream, StreamExt}; +use futures_util::SinkExt; +use pin_project::pin_project; + +use crate::{RpcError, RpcMessage}; + +use super::{ConnectionErrors, Connector, StreamTypes}; + +/// A connection that maps input and output types +#[derive(Debug)] +pub struct MappedConnector { + inner: C, + _p: std::marker::PhantomData<(In, Out)>, +} + +impl MappedConnector +where + C: Connector, + In: TryFrom, + C::Out: From, +{ + /// Create a new mapped connection + pub fn new(inner: C) -> Self { + Self { + inner, + _p: std::marker::PhantomData, + } + } +} + +impl Clone for MappedConnector +where + C: Clone, +{ + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + _p: std::marker::PhantomData, + } + } +} + +impl ConnectionErrors for MappedConnector +where + In: RpcMessage, + Out: RpcMessage, + C: ConnectionErrors, +{ + type RecvError = ErrorOrMapError; + type SendError = C::SendError; + type OpenError = C::OpenError; + type AcceptError = C::AcceptError; +} + +impl StreamTypes for MappedConnector +where + C: StreamTypes, + In: RpcMessage, + Out: RpcMessage, + In: TryFrom, + C::Out: From, +{ + type In = In; + type Out = Out; + type RecvStream = MappedRecvStream; + type SendSink = MappedSendSink; +} + +impl Connector for MappedConnector +where + C: Connector, + In: RpcMessage, + Out: RpcMessage, + In: TryFrom, + C::Out: From, +{ + fn open( + &self, + ) -> impl std::future::Future> + + Send { + let inner = self.inner.open(); + async move { + let (send, recv) = inner.await?; + Ok((MappedSendSink::new(send), MappedRecvStream::new(recv))) + } + } +} + +/// A combinator that maps a stream of incoming messages to a different type +#[pin_project] +pub struct MappedRecvStream { + inner: S, + _p: std::marker::PhantomData, +} + +impl MappedRecvStream { + /// Create a new mapped receive stream + pub fn new(inner: S) -> Self { + Self { + inner, + _p: std::marker::PhantomData, + } + } +} + +/// Error mapping an incoming message to the inner type +#[derive(Debug)] +pub enum ErrorOrMapError { + /// Error from the inner stream + Inner(E), + /// Conversion error + Conversion, +} + +impl std::error::Error for ErrorOrMapError {} + +impl Display for ErrorOrMapError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ErrorOrMapError::Inner(e) => write!(f, "Inner error: {}", e), + ErrorOrMapError::Conversion => write!(f, "Conversion error"), + } + } +} + +impl Stream for MappedRecvStream +where + S: Stream> + Unpin, + In: TryFrom, + E: RpcError, +{ + type Item = Result>; + + fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context) -> Poll> { + match self.project().inner.poll_next(cx) { + Poll::Ready(Some(Ok(item))) => { + let item = item.try_into().map_err(|_| ErrorOrMapError::Conversion); + Poll::Ready(Some(item)) + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(ErrorOrMapError::Inner(e)))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +/// A sink that maps outgoing messages to a different type +/// +/// The conversion to the underlying message type always succeeds, so this +/// is relatively simple. +#[pin_project] +pub struct MappedSendSink { + inner: S, + _p: std::marker::PhantomData<(Out, OutS)>, +} + +impl MappedSendSink { + /// Create a new mapped send sink + pub fn new(inner: S) -> Self { + Self { + inner, + _p: std::marker::PhantomData, + } + } +} + +impl futures_sink::Sink for MappedSendSink +where + S: futures_sink::Sink + Unpin, + Out: Into, +{ + type Error = S::Error; + + fn poll_ready( + self: std::pin::Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + self.project().inner.poll_ready_unpin(cx) + } + + fn start_send(self: std::pin::Pin<&mut Self>, item: Out) -> Result<(), Self::Error> { + self.project().inner.start_send_unpin(item.into()) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + self.project().inner.poll_flush_unpin(cx) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + self.project().inner.poll_close_unpin(cx) + } +} + +/// Connection types for a mapped connection +pub struct MappedStreamTypes(PhantomData<(In, Out, C)>); + +impl Debug for MappedStreamTypes { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MappedConnectionTypes").finish() + } +} + +impl Clone for MappedStreamTypes { + fn clone(&self) -> Self { + Self(PhantomData) + } +} + +impl ConnectionErrors for MappedStreamTypes +where + In: RpcMessage, + Out: RpcMessage, + C: ConnectionErrors, +{ + type RecvError = ErrorOrMapError; + type SendError = C::SendError; + type OpenError = C::OpenError; + type AcceptError = C::AcceptError; +} + +impl StreamTypes for MappedStreamTypes +where + C: StreamTypes, + In: RpcMessage, + Out: RpcMessage, + In: TryFrom, + C::Out: From, +{ + type In = In; + type Out = Out; + type RecvStream = MappedRecvStream; + type SendSink = MappedSendSink; +} + +#[cfg(test)] +#[cfg(feature = "flume-transport")] +mod tests { + + use crate::{ + server::{BoxedChannelTypes, RpcChannel}, + transport::Listener, + RpcClient, RpcServer, + }; + use serde::{Deserialize, Serialize}; + use testresult::TestResult; + + use super::*; + + #[derive(Debug, Clone, Serialize, Deserialize, derive_more::From, derive_more::TryInto)] + enum Request { + A(u64), + B(String), + } + + #[derive(Debug, Clone, Serialize, Deserialize, derive_more::From, derive_more::TryInto)] + enum Response { + A(u64), + B(String), + } + + #[derive(Debug, Clone)] + struct FullService; + + impl crate::Service for FullService { + type Req = Request; + type Res = Response; + } + + #[derive(Debug, Clone)] + struct SubService; + + impl crate::Service for SubService { + type Req = String; + type Res = String; + } + + #[tokio::test] + #[ignore] + async fn smoke() -> TestResult<()> { + async fn handle_sub_request( + _req: String, + _chan: RpcChannel>, + ) -> anyhow::Result<()> { + Ok(()) + } + // create a listener / connector pair. Type will be inferred + let (s, c) = crate::transport::flume::channel(32); + // wrap the server in a RpcServer, this is where the service type is specified + let server = RpcServer::::new(s.clone()); + // when using a boxed transport, we can omit the transport type and use the default + let _server_boxed: RpcServer = RpcServer::::new(s.boxed()); + // create a client in a RpcClient, this is where the service type is specified + let client = RpcClient::::new(c); + // when using a boxed transport, we can omit the transport type and use the default + let _boxed_client = client.clone().boxed(); + // map the client to a sub-service + let _sub_client: RpcClient = client.clone().map::(); + // when using a boxed transport, we can omit the transport type and use the default + let _sub_client_boxed: RpcClient = client.clone().map::().boxed(); + // we can not map the service to a sub-service, since we need the first message to determine which sub-service to use + while let Ok(accepting) = server.accept().await { + let (msg, chan) = accepting.read_first().await?; + match msg { + Request::A(_x) => todo!(), + Request::B(x) => { + // but we can map the channel to the sub-service, once we know which one to use + handle_sub_request(x, chan.map::().boxed()).await? + } + } + } + Ok(()) + } +} diff --git a/src/transport/misc/mod.rs b/src/transport/misc/mod.rs index 84d6de0..756d6c3 100644 --- a/src/transport/misc/mod.rs +++ b/src/transport/misc/mod.rs @@ -1,36 +1,48 @@ //! Miscellaneous transport utilities - use futures_lite::stream; use futures_sink::Sink; use crate::{ - transport::{ConnectionErrors, ServerEndpoint}, + transport::{ConnectionErrors, Listener}, RpcMessage, }; use std::convert::Infallible; -use super::ConnectionCommon; +use super::StreamTypes; -/// A dummy server endpoint that does nothing +/// A dummy listener that does nothing /// /// This can be useful as a default if you want to configure -/// an optional server endpoint. -#[derive(Debug, Clone, Default)] -pub struct DummyServerEndpoint; +/// an optional listener. +#[derive(Debug, Default)] +pub struct DummyListener { + _p: std::marker::PhantomData<(In, Out)>, +} -impl ConnectionErrors for DummyServerEndpoint { - type OpenError = Infallible; +impl Clone for DummyListener { + fn clone(&self) -> Self { + Self { + _p: std::marker::PhantomData, + } + } +} + +impl ConnectionErrors for DummyListener { type RecvError = Infallible; type SendError = Infallible; + type OpenError = Infallible; + type AcceptError = Infallible; } -impl ConnectionCommon for DummyServerEndpoint { +impl StreamTypes for DummyListener { + type In = In; + type Out = Out; type RecvStream = stream::Pending>; type SendSink = Box + Unpin + Send + Sync>; } -impl ServerEndpoint for DummyServerEndpoint { - async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> { +impl Listener for DummyListener { + async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::AcceptError> { futures_lite::future::pending().await } diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 79ae4b0..95c17c8 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -1,4 +1,4 @@ -//! Transports for quic-rpc +//! Built in transports for quic-rpc //! //! There are two sides to a transport, a server side where connections are //! accepted and a client side where connections are initiated. @@ -8,87 +8,123 @@ //! //! In the transport module, the message types are referred to as `In` and `Out`. //! -//! A [`Connection`] can be used to *open* bidirectional typed channels using -//! [`Connection::open`]. A [`ServerEndpoint`] can be used to *accept* bidirectional +//! A [`Connector`] can be used to *open* bidirectional typed channels using +//! [`Connector::open`]. A [`Listener`] can be used to *accept* bidirectional //! typed channels from any of the currently opened connections to clients, using -//! [`ServerEndpoint::accept`]. +//! [`Listener::accept`]. //! //! In both cases, the result is a tuple of a send side and a receive side. These -//! types are defined by implementing the [`ConnectionCommon`] trait. +//! types are defined by implementing the [`StreamTypes`] trait. //! //! Errors for both sides are defined by implementing the [`ConnectionErrors`] trait. +use boxed::{BoxableConnector, BoxableListener, BoxedConnector, BoxedListener}; use futures_lite::{Future, Stream}; use futures_sink::Sink; +use mapped::MappedConnector; -use crate::RpcError; +use crate::{RpcError, RpcMessage}; use std::{ fmt::{self, Debug, Display}, net::SocketAddr, }; -#[cfg(feature = "flume-transport")] + pub mod boxed; -#[cfg(feature = "combined-transport")] pub mod combined; #[cfg(feature = "flume-transport")] pub mod flume; #[cfg(feature = "hyper-transport")] pub mod hyper; +pub mod mapped; +pub mod misc; #[cfg(feature = "quinn-transport")] pub mod quinn; -pub mod misc; - #[cfg(any(feature = "quinn-transport", feature = "hyper-transport"))] mod util; -/// Errors that can happen when creating and using a [`Connection`] or [`ServerEndpoint`]. +/// Errors that can happen when creating and using a [`Connector`] or [`Listener`]. pub trait ConnectionErrors: Debug + Clone + Send + Sync + 'static { - /// Error when opening or accepting a channel - type OpenError: RpcError; /// Error when sending a message via a channel type SendError: RpcError; /// Error when receiving a message via a channel type RecvError: RpcError; + /// Error when opening a channel + type OpenError: RpcError; + /// Error when accepting a channel + type AcceptError: RpcError; } -/// Types that are common to both [`Connection`] and [`ServerEndpoint`]. +/// Types that are common to both [`Connector`] and [`Listener`]. /// /// Having this as a separate trait is useful when writing generic code that works with both. -pub trait ConnectionCommon: ConnectionErrors { +pub trait StreamTypes: ConnectionErrors { + /// The type of messages that can be received on the channel + type In: RpcMessage; + /// The type of messages that can be sent on the channel + type Out: RpcMessage; /// Receive side of a bidirectional typed channel - type RecvStream: Stream> + Send + Sync + Unpin + 'static; + type RecvStream: Stream> + + Send + + Sync + + Unpin + + 'static; /// Send side of a bidirectional typed channel - type SendSink: Sink + Send + Sync + Unpin + 'static; + type SendSink: Sink + Send + Sync + Unpin + 'static; } /// A connection to a specific remote machine /// -/// A connection can be used to open bidirectional typed channels using [`Connection::open`]. -pub trait Connection: ConnectionCommon { +/// A connection can be used to open bidirectional typed channels using [`Connector::open`]. +pub trait Connector: StreamTypes { /// Open a channel to the remote che fn open( &self, ) -> impl Future> + Send; + + /// Map the input and output types of this connection + fn map(self) -> MappedConnector + where + In1: TryFrom, + Self::Out: From, + { + MappedConnector::new(self) + } + + /// Box the connection + fn boxed(self) -> BoxedConnector + where + Self: BoxableConnector + Sized + 'static, + { + self::BoxedConnector::new(self) + } } -/// A server endpoint that listens for connections +/// A listener that listens for connections /// -/// A server endpoint can be used to accept bidirectional typed channels from any of the -/// currently opened connections to clients, using [`ServerEndpoint::accept`]. -pub trait ServerEndpoint: ConnectionCommon { +/// A listener can be used to accept bidirectional typed channels from any of the +/// currently opened connections to clients, using [`Listener::accept`]. +pub trait Listener: StreamTypes { /// Accept a new typed bidirectional channel on any of the connections we /// have currently opened. fn accept( &self, - ) -> impl Future> + Send; + ) -> impl Future> + Send; /// The local addresses this endpoint is bound to. fn local_addr(&self) -> &[LocalAddr]; + + /// Box the listener + fn boxed(self) -> BoxedListener + where + Self: BoxableListener + Sized + 'static, + { + BoxedListener::new(self) + } } -/// The kinds of local addresses a [ServerEndpoint] can be bound to. +/// The kinds of local addresses a [Listener] can be bound to. /// -/// Returned by [ServerEndpoint::local_addr]. +/// Returned by [Listener::local_addr]. /// /// [`Display`]: fmt::Display #[derive(Debug, Clone)] diff --git a/src/transport/quinn.rs b/src/transport/quinn.rs index 577551b..db5c3bb 100644 --- a/src/transport/quinn.rs +++ b/src/transport/quinn.rs @@ -1,6 +1,6 @@ //! QUIC transport implementation based on [quinn](https://crates.io/crates/quinn) use crate::{ - transport::{Connection, ConnectionErrors, LocalAddr, ServerEndpoint}, + transport::{ConnectionErrors, Connector, Listener, LocalAddr}, RpcMessage, }; use futures_lite::{Future, Stream, StreamExt}; @@ -18,28 +18,28 @@ use tracing::{debug_span, Instrument}; use super::{ util::{FramedBincodeRead, FramedBincodeWrite}, - ConnectionCommon, + StreamTypes, }; const MAX_FRAME_LENGTH: usize = 1024 * 1024 * 16; #[derive(Debug)] -struct ServerEndpointInner { +struct ListenerInner { endpoint: Option, task: Option>, local_addr: [LocalAddr; 1], receiver: flume::Receiver, } -impl Drop for ServerEndpointInner { +impl Drop for ListenerInner { fn drop(&mut self) { - tracing::debug!("Dropping server endpoint"); + tracing::debug!("Dropping listener"); if let Some(endpoint) = self.endpoint.take() { - endpoint.close(0u32.into(), b"server endpoint dropped"); + endpoint.close(0u32.into(), b"Listener dropped"); if let Ok(handle) = tokio::runtime::Handle::try_current() { // spawn a task to wait for the endpoint to notify peers that it is closing - let span = debug_span!("closing server endpoint"); + let span = debug_span!("closing listener"); handle.spawn( async move { endpoint.wait_idle().await; @@ -54,14 +54,14 @@ impl Drop for ServerEndpointInner { } } -/// A server endpoint using a quinn connection +/// A listener using a quinn connection #[derive(Debug)] -pub struct QuinnServerEndpoint { - inner: Arc, - _phantom: PhantomData<(In, Out)>, +pub struct QuinnListener { + inner: Arc, + _p: PhantomData<(In, Out)>, } -impl QuinnServerEndpoint { +impl QuinnListener { /// handles RPC requests from a connection /// /// to cleanly shutdown the handler, drop the receiver side of the sender. @@ -122,13 +122,13 @@ impl QuinnServerEndpoint { let (sender, receiver) = flume::bounded(16); let task = tokio::spawn(Self::endpoint_handler(endpoint.clone(), sender)); Ok(Self { - inner: Arc::new(ServerEndpointInner { + inner: Arc::new(ListenerInner { endpoint: Some(endpoint), task: Some(task), local_addr: [LocalAddr::Socket(local_addr)], receiver, }), - _phantom: PhantomData, + _p: PhantomData, }) } @@ -148,13 +148,13 @@ impl QuinnServerEndpoint { } }); Self { - inner: Arc::new(ServerEndpointInner { + inner: Arc::new(ListenerInner { endpoint: None, task: Some(task), local_addr: [LocalAddr::Socket(local_addr)], receiver, }), - _phantom: PhantomData, + _p: PhantomData, } } @@ -167,41 +167,42 @@ impl QuinnServerEndpoint { local_addr: SocketAddr, ) -> Self { Self { - inner: Arc::new(ServerEndpointInner { + inner: Arc::new(ListenerInner { endpoint: None, task: None, local_addr: [LocalAddr::Socket(local_addr)], receiver, }), - _phantom: PhantomData, + _p: PhantomData, } } } -impl Clone for QuinnServerEndpoint { +impl Clone for QuinnListener { fn clone(&self) -> Self { Self { inner: self.inner.clone(), - _phantom: PhantomData, + _p: PhantomData, } } } -impl ConnectionErrors for QuinnServerEndpoint { +impl ConnectionErrors for QuinnListener { type SendError = io::Error; - type RecvError = io::Error; - type OpenError = quinn::ConnectionError; + type AcceptError = quinn::ConnectionError; } -impl ConnectionCommon for QuinnServerEndpoint { +impl StreamTypes for QuinnListener { + type In = In; + type Out = Out; type SendSink = self::SendSink; type RecvStream = self::RecvStream; } -impl ServerEndpoint for QuinnServerEndpoint { - async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptBiError> { +impl Listener for QuinnListener { + async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> { let (send, recv) = self .inner .receiver @@ -254,12 +255,12 @@ impl Drop for ClientConnectionInner { } /// A connection using a quinn connection -pub struct QuinnConnection { +pub struct QuinnConnector { inner: Arc, - _phantom: PhantomData<(In, Out)>, + _p: PhantomData<(In, Out)>, } -impl QuinnConnection { +impl QuinnConnector { async fn single_connection_handler_inner( connection: quinn::Connection, requests: flume::Receiver>>, @@ -441,7 +442,7 @@ impl QuinnConnection { task: Some(task), sender, }), - _phantom: PhantomData, + _p: PhantomData, } } @@ -460,7 +461,7 @@ impl QuinnConnection { task: Some(task), sender, }), - _phantom: PhantomData, + _p: PhantomData, } } } @@ -601,7 +602,7 @@ impl<'a, T> Stream for Receiver<'a, T> { } } -impl fmt::Debug for QuinnConnection { +impl fmt::Debug for QuinnConnector { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ClientChannel") .field("inner", &self.inner) @@ -609,29 +610,30 @@ impl fmt::Debug for QuinnConnection { } } -impl Clone for QuinnConnection { +impl Clone for QuinnConnector { fn clone(&self) -> Self { Self { inner: self.inner.clone(), - _phantom: PhantomData, + _p: PhantomData, } } } -impl ConnectionErrors for QuinnConnection { +impl ConnectionErrors for QuinnConnector { type SendError = io::Error; - type RecvError = io::Error; - type OpenError = quinn::ConnectionError; + type AcceptError = quinn::ConnectionError; } -impl ConnectionCommon for QuinnConnection { +impl StreamTypes for QuinnConnector { + type In = In; + type Out = Out; type SendSink = self::SendSink; type RecvStream = self::RecvStream; } -impl Connection for QuinnConnection { +impl Connector for QuinnConnector { async fn open(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> { let (sender, receiver) = oneshot::channel(); self.inner @@ -743,10 +745,10 @@ impl Stream for RecvStream { } /// Error for open. Currently just a quinn::ConnectionError -pub type OpenBiError = quinn::ConnectionError; +pub type OpenError = quinn::ConnectionError; /// Error for accept. Currently just a quinn::ConnectionError -pub type AcceptBiError = quinn::ConnectionError; +pub type AcceptError = quinn::ConnectionError; /// CreateChannelError for quinn channels. #[derive(Debug, Clone)] diff --git a/tests/flume.rs b/tests/flume.rs index 54f211e..d0b14c8 100644 --- a/tests/flume.rs +++ b/tests/flume.rs @@ -11,7 +11,7 @@ use quic_rpc::{ #[tokio::test] async fn flume_channel_bench() -> anyhow::Result<()> { tracing_subscriber::fmt::try_init().ok(); - let (server, client) = flume::service_connection::(1); + let (server, client) = flume::channel(1); let server = RpcServer::::new(server); let server_handle = tokio::task::spawn(ComputeService::server(server)); @@ -60,7 +60,7 @@ async fn flume_channel_mapped_bench() -> anyhow::Result<()> { type Req = InnerRequest; type Res = InnerResponse; } - let (server, client) = flume::service_connection::(1); + let (server, client) = flume::channel(1); let server = RpcServer::::new(server); let server_handle: tokio::task::JoinHandle>> = @@ -73,8 +73,8 @@ async fn flume_channel_mapped_bench() -> anyhow::Result<()> { let req: OuterRequest = req; match req { OuterRequest::Inner(InnerRequest::Compute(req)) => { - let chan: RpcChannel = chan.map(); - let chan: RpcChannel = chan.map(); + let chan: RpcChannel = chan.map(); + let chan: RpcChannel = chan.map(); ComputeService::handle_rpc_request(service, req, chan).await } } @@ -83,8 +83,8 @@ async fn flume_channel_mapped_bench() -> anyhow::Result<()> { }); let client = RpcClient::::new(client); - let client: RpcClient = client.map(); - let client: RpcClient = client.map(); + let client: RpcClient = client.map(); + let client: RpcClient = client.map(); bench(client, 1000000).await?; // dropping the client will cause the server to terminate match server_handle.await? { @@ -98,7 +98,7 @@ async fn flume_channel_mapped_bench() -> anyhow::Result<()> { #[tokio::test] async fn flume_channel_smoke() -> anyhow::Result<()> { tracing_subscriber::fmt::try_init().ok(); - let (server, client) = flume::service_connection::(1); + let (server, client) = flume::channel(1); let server = RpcServer::::new(server); let server_handle = tokio::task::spawn(ComputeService::server(server)); diff --git a/tests/hyper.rs b/tests/hyper.rs index d085e97..0e5766d 100644 --- a/tests/hyper.rs +++ b/tests/hyper.rs @@ -7,7 +7,7 @@ use flume::Receiver; use quic_rpc::{ declare_rpc, server::RpcServerError, - transport::hyper::{self, HyperConnection, HyperServerEndpoint, RecvError}, + transport::hyper::{self, HyperConnector, HyperListener, RecvError}, RpcClient, RpcServer, Service, }; use serde::{Deserialize, Serialize}; @@ -18,7 +18,7 @@ use math::*; mod util; fn run_server(addr: &SocketAddr) -> JoinHandle> { - let channel = HyperServerEndpoint::serve(addr).unwrap(); + let channel = HyperListener::serve(addr).unwrap(); let server = RpcServer::new(channel); tokio::spawn(async move { loop { @@ -38,7 +38,7 @@ enum TestResponse { NoDeser(NoDeser), } -type SC = HyperServerEndpoint; +type SC = HyperListener; /// request that can be too big #[derive(Debug, Serialize, Deserialize)] @@ -134,7 +134,7 @@ async fn hyper_channel_bench() -> anyhow::Result<()> { let addr: SocketAddr = "127.0.0.1:3000".parse()?; let uri: Uri = "http://127.0.0.1:3000".parse()?; let server_handle = run_server(&addr); - let client = HyperConnection::new(uri); + let client = HyperConnector::new(uri); let client = RpcClient::new(client); bench(client, 50000).await?; println!("terminating server"); @@ -148,7 +148,7 @@ async fn hyper_channel_smoke() -> anyhow::Result<()> { let addr: SocketAddr = "127.0.0.1:3001".parse()?; let uri: Uri = "http://127.0.0.1:3001".parse()?; let server_handle = run_server(&addr); - let client = HyperConnection::new(uri); + let client = HyperConnector::new(uri); smoke_test(client).await?; server_handle.abort(); let _ = server_handle.await; @@ -171,7 +171,7 @@ async fn hyper_channel_errors() -> anyhow::Result<()> { JoinHandle>, Receiver>>, ) { - let channel = HyperServerEndpoint::serve(addr).unwrap(); + let channel = HyperListener::serve(addr).unwrap(); let server = RpcServer::new(channel); let (res_tx, res_rx) = flume::unbounded(); let handle = tokio::spawn(async move { @@ -214,7 +214,7 @@ async fn hyper_channel_errors() -> anyhow::Result<()> { let addr: SocketAddr = "127.0.0.1:3002".parse()?; let uri: Uri = "http://127.0.0.1:3002".parse()?; let (server_handle, server_results) = run_test_server(&addr); - let client = HyperConnection::new(uri); + let client = HyperConnector::new(uri); let client = RpcClient::new(client); macro_rules! assert_matches { diff --git a/tests/math.rs b/tests/math.rs index 94794f3..af224fe 100644 --- a/tests/math.rs +++ b/tests/math.rs @@ -15,7 +15,8 @@ use quic_rpc::{ ServerStreaming, ServerStreamingMsg, }, server::{RpcChannel, RpcServerError}, - RpcClient, RpcServer, Service, ServiceConnection, ServiceEndpoint, + transport::StreamTypes, + Connector, Listener, RpcClient, RpcServer, Service, }; use serde::{Deserialize, Serialize}; use std::{ @@ -160,7 +161,7 @@ impl ComputeService { } } - pub async fn server>( + pub async fn server>( server: RpcServer, ) -> result::Result<(), RpcServerError> { let s = server; @@ -172,14 +173,13 @@ impl ComputeService { } } - pub async fn handle_rpc_request( + pub async fn handle_rpc_request( service: ComputeService, req: ComputeRequest, - chan: RpcChannel, + chan: RpcChannel, ) -> Result<(), RpcServerError> where - S: Service, - E: ServiceEndpoint, + E: StreamTypes, { use ComputeRequest::*; #[rustfmt::skip] @@ -195,7 +195,7 @@ impl ComputeService { } /// Runs the service until `count` requests have been received. - pub async fn server_bounded>( + pub async fn server_bounded>( server: RpcServer, count: usize, ) -> result::Result, RpcServerError> { @@ -226,7 +226,7 @@ impl ComputeService { Ok(s) } - pub async fn server_par>( + pub async fn server_par>( server: RpcServer, parallelism: usize, ) -> result::Result<(), RpcServerError> { @@ -267,7 +267,7 @@ impl ComputeService { } } -pub async fn smoke_test>(client: C) -> anyhow::Result<()> { +pub async fn smoke_test>(client: C) -> anyhow::Result<()> { let client = RpcClient::::new(client); // a rpc call tracing::debug!("calling rpc S(1234)"); @@ -316,11 +316,10 @@ fn clear_line() { print!("\r{}\r", " ".repeat(80)); } -pub async fn bench(client: RpcClient, n: u64) -> anyhow::Result<()> +pub async fn bench(client: RpcClient, n: u64) -> anyhow::Result<()> where C::SendError: std::error::Error, - S: Service, - C: ServiceConnection, + C: Connector, { // individual RPCs { diff --git a/tests/quinn.rs b/tests/quinn.rs index b28c6da..be445b7 100644 --- a/tests/quinn.rs +++ b/tests/quinn.rs @@ -114,7 +114,7 @@ pub fn make_endpoints(port: u16) -> anyhow::Result { fn run_server(server: quinn::Endpoint) -> JoinHandle> { tokio::task::spawn(async move { - let connection = transport::quinn::QuinnServerEndpoint::new(server)?; + let connection = transport::quinn::QuinnListener::new(server)?; let server = RpcServer::new(connection); ComputeService::server(server).await?; anyhow::Ok(()) @@ -133,7 +133,7 @@ async fn quinn_channel_bench() -> anyhow::Result<()> { tracing::debug!("Starting server"); let server_handle = run_server(server); tracing::debug!("Starting client"); - let client = transport::quinn::QuinnConnection::new(client, server_addr, "localhost".into()); + let client = transport::quinn::QuinnConnector::new(client, server_addr, "localhost".into()); let client = RpcClient::new(client); tracing::debug!("Starting benchmark"); bench(client, 50000).await?; @@ -151,7 +151,7 @@ async fn quinn_channel_smoke() -> anyhow::Result<()> { } = make_endpoints(12346)?; let server_handle = run_server(server); let client_connection = - transport::quinn::QuinnConnection::new(client, server_addr, "localhost".into()); + transport::quinn::QuinnConnector::new(client, server_addr, "localhost".into()); smoke_test(client_connection).await?; server_handle.abort(); Ok(()) @@ -172,7 +172,7 @@ async fn server_away_and_back() -> anyhow::Result<()> { // create the RPC client let client = make_client_endpoint("0.0.0.0:0".parse()?, &[&server_cert])?; let client_connection = - transport::quinn::QuinnConnection::new(client, server_addr, "localhost".into()); + transport::quinn::QuinnConnector::new(client, server_addr, "localhost".into()); let client = RpcClient::new(client_connection); // send a request. No server available so it should fail @@ -180,7 +180,7 @@ async fn server_away_and_back() -> anyhow::Result<()> { // create the RPC Server let server = Endpoint::server(server_config.clone(), server_addr)?; - let connection = transport::quinn::QuinnServerEndpoint::new(server)?; + let connection = transport::quinn::QuinnListener::new(server)?; let server = RpcServer::new(connection); let server_handle = tokio::task::spawn(ComputeService::server_bounded(server, 1)); @@ -195,7 +195,7 @@ async fn server_away_and_back() -> anyhow::Result<()> { // make the server run again let server = Endpoint::server(server_config, server_addr)?; - let connection = transport::quinn::QuinnServerEndpoint::new(server)?; + let connection = transport::quinn::QuinnListener::new(server)?; let server = RpcServer::new(connection); let server_handle = tokio::task::spawn(ComputeService::server_bounded(server, 5)); diff --git a/tests/slow_math.rs b/tests/slow_math.rs index c453c68..858614a 100644 --- a/tests/slow_math.rs +++ b/tests/slow_math.rs @@ -15,7 +15,7 @@ use quic_rpc::{ ServerStreaming, ServerStreamingMsg, }, server::RpcServerError, - RpcServer, Service, ServiceEndpoint, + Listener, RpcServer, Service, }; #[derive(Debug, Clone)] @@ -107,7 +107,7 @@ impl ComputeService { } } - pub async fn server>( + pub async fn server>( server: RpcServer, ) -> result::Result<(), RpcServerError> { let s = server; diff --git a/tests/try.rs b/tests/try.rs index 72ae87d..b11f633 100644 --- a/tests/try.rs +++ b/tests/try.rs @@ -72,7 +72,7 @@ impl Handler { #[tokio::test] async fn try_server_streaming() -> anyhow::Result<()> { tracing_subscriber::fmt::try_init().ok(); - let (server, client) = flume::service_connection::(1); + let (server, client) = flume::channel(1); let server = RpcServer::::new(server); let server_handle = tokio::task::spawn(async move { diff --git a/tests/util.rs b/tests/util.rs index 428aa76..cd946e4 100644 --- a/tests/util.rs +++ b/tests/util.rs @@ -1,8 +1,8 @@ use anyhow::Context; -use quic_rpc::{server::RpcServerError, transport::Connection, RpcMessage}; +use quic_rpc::{server::RpcServerError, transport::Connector}; #[allow(unused)] -pub async fn check_termination_anyhow>( +pub async fn check_termination_anyhow( server_handle: tokio::task::JoinHandle>, ) -> anyhow::Result<()> { // dropping the client will cause the server to terminate