From 03d13f8aa6038e2bb101244a72a9c549af9ad4ee Mon Sep 17 00:00:00 2001 From: Gregory Hill Date: Fri, 9 Oct 2020 16:12:45 +0100 Subject: [PATCH] conditional compilation for async roundtripper Signed-off-by: Gregory Hill --- Cargo.toml | 4 + src/client.rs | 262 +++++++++++++++++++++++++++++++++----------------- src/error.rs | 2 +- src/util.rs | 8 +- 4 files changed, 181 insertions(+), 95 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c54572bf..d9e204a3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,10 @@ documentation = "https://docs.rs/jsonrpc/" description = "Rust support for the JSON-RPC 2.0 protocol" keywords = [ "protocol", "json", "http", "jsonrpc" ] readme = "README.md" +edition = "2018" + +[features] +async = [] [lib] name = "jsonrpc" diff --git a/src/client.rs b/src/client.rs index cb784aab..5837d724 100644 --- a/src/client.rs +++ b/src/client.rs @@ -18,18 +18,18 @@ //! and parsing responses //! -use std::{error, io}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; +use std::{error, io}; -use serde; use base64; use http; +use serde; use serde_json; use super::{Request, Response}; -use util::HashableValue; -use error::Error; +use crate::error::Error; +use crate::util::HashableValue; /// An interface for an HTTP roundtripper that handles HTTP requests. pub trait HttpRoundTripper { @@ -38,11 +38,27 @@ pub trait HttpRoundTripper { /// The type for errors generated by the roundtripper. type Err: error::Error; - /// Make an HTTP request. In practice only POST request will be made. + /// Make a synchronous HTTP request. In practice only POST request will be made. + #[cfg(not(feature = "async"))] fn request( &self, - http::Request<&[u8]>, + _request: http::Request<&[u8]>, ) -> Result, Self::Err>; + + /// Make an asynchronous HTTP request. In practice only POST request will be made. + #[cfg(feature = "async")] + fn request<'life>( + &'life self, + _request: http::Request<&'life [u8]>, + ) -> std::pin::Pin< + Box< + dyn std::future::Future, Self::Err>> + + Send + + 'life, + >, + > + where + Self: Sync + 'life; } /// A handle to a remote JSONRPC server @@ -54,7 +70,21 @@ pub struct Client { nonce: Arc>, } -impl Client { +#[cfg(not(feature = "async"))] +macro_rules! maybe_async_fn { + ($($tokens:tt)*) => { + $($tokens)* + }; +} + +#[cfg(feature = "async")] +macro_rules! maybe_async_fn { + ($(#[$($meta:meta)*])* $vis:vis $ident:ident $($tokens:tt)*) => { + $(#[$($meta)*])* $vis async $ident $($tokens)* + }; +} + +impl Client { /// Creates a new client pub fn new( roundtripper: Rt, @@ -74,104 +104,139 @@ impl Client { } } - /// Make a request and deserialize the response - pub fn do_rpc serde::de::Deserialize<'a>>( - &self, - rpc_name: &str, - args: &[serde_json::value::Value], - ) -> Result { - let request = self.build_request(rpc_name, args); - let response = self.send_request(&request)?; + maybe_async_fn! { + /// Make a request and deserialize the response + pub fn do_rpc serde::de::Deserialize<'a>>( + &self, + rpc_name: &str, + args: &[serde_json::value::Value], + ) -> Result { + let request = self.build_request(rpc_name, args); + + #[cfg(not(feature = "async"))] + let response = self.send_request(&request)?; - Ok(response.into_result()?) + #[cfg(feature = "async")] + let response = self.send_request(&request).await?; + + Ok(response.into_result()?) + } } - /// The actual send logic used by both [send_request] and [send_batch]. - fn send_raw(&self, body: &B) -> Result - where - B: serde::ser::Serialize, - R: for<'de> serde::de::Deserialize<'de>, - { - // Build request - let request_raw = serde_json::to_vec(body)?; - - // Send request - let mut request_builder = http::Request::post(&self.url); - request_builder.header("Content-Type", "application/json-rpc"); - - // Set Authorization header - if let Some(ref user) = self.user { - let mut auth = user.clone(); - auth.push(':'); - if let Some(ref pass) = self.pass { - auth.push_str(&pass[..]); + maybe_async_fn! { + /// The actual send logic used by both [send_request] and [send_batch]. + fn send_raw(&self, body: &B) -> Result + where + B: serde::ser::Serialize, + R: for<'de> serde::de::Deserialize<'de>, + { + // Build request + let request_raw = serde_json::to_vec(body)?; + + // Send request + let mut request_builder = http::Request::post(&self.url); + request_builder.header("Content-Type", "application/json-rpc"); + + // Set Authorization header + if let Some(ref user) = self.user { + let mut auth = user.clone(); + auth.push(':'); + if let Some(ref pass) = self.pass { + auth.push_str(&pass[..]); + } + let value = format!("Basic {}", &base64::encode(auth.as_bytes())); + request_builder.header("Authorization", value); } - let value = format!("Basic {}", &base64::encode(auth.as_bytes())); - request_builder.header("Authorization", value); - } - // Errors only on invalid header or builder reuse. - let http_request = request_builder.body(&request_raw[..]).unwrap(); + // Errors only on invalid header or builder reuse. + let http_request = request_builder.body(&request_raw[..]).unwrap(); - let http_response = - self.roundtripper.request(http_request).map_err(|e| Error::Http(Box::new(e)))?; + #[cfg(not(feature = "async"))] + let http_response = self + .roundtripper + .request(http_request) + .map_err(|e| Error::Http(Box::new(e)))?; - // nb we ignore stream.status since we expect the body - // to contain information about any error - Ok(serde_json::from_reader(http_response.into_body())?) - } + #[cfg(feature = "async")] + let http_response = self + .roundtripper + .request(http_request).await + .map_err(|e| Error::Http(Box::new(e)))?; - /// Sends a request to a client - pub fn send_request(&self, request: &Request) -> Result { - let response: Response = self.send_raw(&request)?; - if response.jsonrpc != None && response.jsonrpc != Some(From::from("2.0")) { - return Err(Error::VersionMismatch); - } - if response.id != request.id { - return Err(Error::NonceMismatch); + + // nb we ignore stream.status since we expect the body + // to contain information about any error + Ok(serde_json::from_reader(http_response.into_body())?) } - Ok(response) } - /// Sends a batch of requests to the client. The return vector holds the response - /// for the request at the corresponding index. If no response was provided, it's [None]. - /// - /// Note that the requests need to have valid IDs, so it is advised to create the requests - /// with [build_request]. - pub fn send_batch(&self, requests: &[Request]) -> Result>, Error> { - if requests.len() < 1 { - return Err(Error::EmptyBatch); - } + maybe_async_fn! { + /// Sends a request to a client + pub fn send_request<'a, 'b>(&self, request: &Request<'a, 'b>) -> Result { + #[cfg(not(feature = "async"))] + let response: Response = self.send_raw(&request)?; - // If the request body is invalid JSON, the response is a single response object. - // We ignore this case since we are confident we are producing valid JSON. - let responses: Vec = self.send_raw(&requests)?; - if responses.len() > requests.len() { - return Err(Error::WrongBatchResponseSize); - } + #[cfg(feature = "async")] + let response: Response = self.send_raw(&request).await?; - // To prevent having to clone responses, we first copy all the IDs so we can reference - // them easily. IDs can only be of JSON type String or Number (or Null), so cloning - // should be inexpensive and require no allocations as Numbers are more common. - let ids: Vec = responses.iter().map(|r| r.id.clone()).collect(); - // First index responses by ID and catch duplicate IDs. - let mut resp_by_id = HashMap::new(); - for (id, resp) in ids.iter().zip(responses.into_iter()) { - if let Some(dup) = resp_by_id.insert(HashableValue(&id), resp) { - return Err(Error::BatchDuplicateResponseId(dup.id)); + if response.jsonrpc != None && response.jsonrpc != Some(From::from("2.0")) { + return Err(Error::VersionMismatch); } + if response.id != request.id { + return Err(Error::NonceMismatch); + } + Ok(response) } - // Match responses to the requests. - let results = - requests.into_iter().map(|r| resp_by_id.remove(&HashableValue(&r.id))).collect(); - - // Since we're also just producing the first duplicate ID, we can also just produce the - // first incorrect ID in case there are multiple. - if let Some(incorrect) = resp_by_id.into_iter().nth(0) { - return Err(Error::WrongBatchResponseId(incorrect.1.id)); - } + } + + maybe_async_fn! { + /// Sends a batch of requests to the client. The return vector holds the response + /// for the request at the corresponding index. If no response was provided, it's [None]. + /// + /// Note that the requests need to have valid IDs, so it is advised to create the requests + /// with [build_request]. + pub fn send_batch<'a, 'b>(&self, requests: &[Request<'a, 'b>]) -> Result>, Error> { + if requests.len() < 1 { + return Err(Error::EmptyBatch); + } + + // If the request body is invalid JSON, the response is a single response object. + // We ignore this case since we are confident we are producing valid JSON. + #[cfg(not(feature = "async"))] + let responses: Vec = self.send_raw(&requests)?; - Ok(results) + #[cfg(feature = "async")] + let responses: Vec = self.send_raw(&requests).await?; + + if responses.len() > requests.len() { + return Err(Error::WrongBatchResponseSize); + } + + // To prevent having to clone responses, we first copy all the IDs so we can reference + // them easily. IDs can only be of JSON type String or Number (or Null), so cloning + // should be inexpensive and require no allocations as Numbers are more common. + let ids: Vec = responses.iter().map(|r| r.id.clone()).collect(); + // First index responses by ID and catch duplicate IDs. + let mut resp_by_id = HashMap::new(); + for (id, resp) in ids.iter().zip(responses.into_iter()) { + if let Some(dup) = resp_by_id.insert(HashableValue(&id), resp) { + return Err(Error::BatchDuplicateResponseId(dup.id)); + } + } + // Match responses to the requests. + let results = requests + .into_iter() + .map(|r| resp_by_id.remove(&HashableValue(&r.id))) + .collect(); + + // Since we're also just producing the first duplicate ID, we can also just produce the + // first incorrect ID in case there are multiple. + if let Some(incorrect) = resp_by_id.into_iter().nth(0) { + return Err(Error::WrongBatchResponseId(incorrect.1.id)); + } + + Ok(results) + } } /// Builds a request @@ -206,12 +271,31 @@ mod tests { type ResponseBody = io::Empty; type Err = io::Error; + #[cfg(not(feature = "async"))] fn request( &self, _: http::Request<&[u8]>, ) -> Result, Self::Err> { Err(io::ErrorKind::Other.into()) } + + #[cfg(feature = "async")] + fn request<'life>( + &'life self, + request: http::Request<&[u8]>, + ) -> std::pin::Pin< + Box< + dyn std::future::Future< + Output = Result, Self::Err>, + > + Send + + 'life, + >, + > + where + Self: Sync + 'life, + { + Box::pin(async { Err(io::ErrorKind::Other.into()) }) + } } #[test] diff --git a/src/error.rs b/src/error.rs index 5ff19119..8448c6c9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -21,7 +21,7 @@ use std::{error, fmt}; use serde_json; -use Response; +use crate::Response; /// A library error #[derive(Debug)] diff --git a/src/util.rs b/src/util.rs index 482a069e..9f2aceb8 100644 --- a/src/util.rs +++ b/src/util.rs @@ -44,18 +44,18 @@ impl<'a> Hash for HashableValue<'a> { } else { n.to_string().hash(state); } - }, + } Value::String(ref s) => { "string".hash(state); s.hash(state); - }, + } Value::Array(ref v) => { "array".hash(state); v.len().hash(state); for obj in v { HashableValue(obj).hash(state); } - }, + } Value::Object(ref m) => { "object".hash(state); m.len().hash(state); @@ -116,5 +116,3 @@ mod tests { assert!(coll.contains(&m)); } } - -