From a58c1b4ff581d2673741c07e7282147722c3c4e6 Mon Sep 17 00:00:00 2001 From: Jeb Bearer Date: Fri, 29 Mar 2024 16:29:11 -0400 Subject: [PATCH 1/2] Allow APIs to bind different versions of the binary serialization format --- Cargo.lock | 37 +-- Cargo.toml | 2 +- examples/hello-world/main.rs | 5 +- examples/versions/main.rs | 5 +- src/api.rs | 429 ++++++++++++++++++----------- src/app.rs | 515 +++++++++++++++++++---------------- src/lib.rs | 4 +- src/metrics.rs | 5 +- src/middleware.rs | 175 ++++++++++++ src/request.rs | 2 +- src/route.rs | 115 ++++---- src/socket.rs | 2 +- src/status.rs | 2 +- 13 files changed, 801 insertions(+), 497 deletions(-) create mode 100644 src/middleware.rs diff --git a/Cargo.lock b/Cargo.lock index ccdd048d..7192d2b2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1059,17 +1059,6 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "212d0f5754cb6769937f4501cc0e67f4f4483c8d2c3e1e922ee9edbe4ab4c7c0" -[[package]] -name = "displaydoc" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "487585f4d0c6655fe74905e2504d8ad6908e4db67f744eb140876906c2f3175d" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.53", -] - [[package]] name = "dlv-list" version = "0.5.2" @@ -3501,7 +3490,7 @@ dependencies = [ "tracing-log", "tracing-subscriber", "url", - "versioned-binary-serialization", + "vbs", ] [[package]] @@ -4048,6 +4037,18 @@ dependencies = [ "sval_serde", ] +[[package]] +name = "vbs" +version = "0.1.4" +source = "git+https://github.com/EspressoSystems/versioned-binary-serialization.git?tag=v0.1.4#50c93e6dd650484688077542c467a134a033893c" +dependencies = [ + "anyhow", + "bincode", + "derive_more", + "serde", + "serde_with", +] + [[package]] name = "vcpkg" version = "0.2.15" @@ -4060,18 +4061,6 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" -[[package]] -name = "versioned-binary-serialization" -version = "0.1.2" -source = "git+https://github.com/EspressoSystems/versioned-binary-serialization.git?tag=0.1.2#6874f91a3c8d64acc24fe0abe4ad93c35b75eb9d" -dependencies = [ - "anyhow", - "bincode", - "displaydoc", - "serde", - "serde_with", -] - [[package]] name = "waker-fn" version = "1.1.0" diff --git a/Cargo.toml b/Cargo.toml index f7a09ccb..e149c584 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,7 +64,7 @@ tracing-futures = "0.2" tracing-log = "0.2" tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json"] } url = "2.5.0" -versioned-binary-serialization = { git = "https://github.com/EspressoSystems/versioned-binary-serialization.git", tag = "0.1.2" } +vbs = { git = "https://github.com/EspressoSystems/versioned-binary-serialization.git", tag = "v0.1.4" } # Dependencies enabled by feature `testing` async-compatibility-layer = { git = "https://github.com/EspressoSystems/async-compatibility-layer.git", tag = "1.4.2", features = ["logging-utils"], optional = true } diff --git a/examples/hello-world/main.rs b/examples/hello-world/main.rs index 7772809f..501d66da 100644 --- a/examples/hello-world/main.rs +++ b/examples/hello-world/main.rs @@ -11,10 +11,9 @@ use snafu::Snafu; use std::io; use tide_disco::{Api, App, Error, RequestError, StatusCode}; use tracing::info; -use versioned_binary_serialization::version::StaticVersion; +use vbs::version::StaticVersion; type StaticVer01 = StaticVersion<0, 1>; -const STATIC_VER: StaticVer01 = StaticVersion {}; #[derive(Clone, Debug, Deserialize, Serialize, Snafu)] enum HelloError { @@ -79,7 +78,7 @@ async fn serve(port: u16) -> io::Result<()> { .unwrap(); app.register_module("hello", api).unwrap(); - app.serve(format!("0.0.0.0:{}", port), STATIC_VER).await + app.serve(format!("0.0.0.0:{}", port)).await } #[async_std::main] diff --git a/examples/versions/main.rs b/examples/versions/main.rs index 97b780a0..888e784e 100644 --- a/examples/versions/main.rs +++ b/examples/versions/main.rs @@ -7,10 +7,9 @@ use futures::FutureExt; use std::io; use tide_disco::{error::ServerError, Api, App}; -use versioned_binary_serialization::version::StaticVersion; +use vbs::version::StaticVersion; type StaticVer01 = StaticVersion<0, 1>; -const STATIC_VER: StaticVer01 = StaticVersion {}; async fn serve(port: u16) -> io::Result<()> { let mut app = App::<_, ServerError, StaticVer01>::with_state(()); @@ -32,7 +31,7 @@ async fn serve(port: u16) -> io::Result<()> { .unwrap() .register_module("api", v2) .unwrap(); - app.serve(format!("0.0.0.0:{}", port), STATIC_VER).await + app.serve(format!("0.0.0.0:{}", port)).await } #[async_std::main] diff --git a/src/api.rs b/src/api.rs index f947016a..cea71d08 100644 --- a/src/api.rs +++ b/src/api.rs @@ -8,6 +8,7 @@ use crate::{ healthcheck::{HealthCheck, HealthStatus}, method::{Method, ReadState, WriteState}, metrics::Metrics, + middleware::{error_handler, ErrorHandler}, request::RequestParams, route::{self, *}, socket, Html, @@ -15,21 +16,26 @@ use crate::{ use async_std::sync::Arc; use async_trait::async_trait; use derivative::Derivative; -use derive_more::From; -use futures::{future::BoxFuture, stream::BoxStream}; +use futures::{ + future::{BoxFuture, FutureExt}, + stream::BoxStream, +}; use maud::{html, PreEscaped}; use semver::Version; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_with::{serde_as, DisplayFromStr}; use snafu::{OptionExt, ResultExt, Snafu}; -use std::borrow::Cow; -use std::collections::hash_map::{Entry, HashMap, IntoValues, Values}; -use std::fmt::Display; -use std::fs; -use std::ops::Index; -use std::path::{Path, PathBuf}; +use std::{ + borrow::Cow, + collections::hash_map::{Entry, HashMap, IntoValues, Values}, + fmt::Display, + fs, + marker::PhantomData, + ops::Index, + path::{Path, PathBuf}, +}; use tide::http::content::Accept; -use versioned_binary_serialization::version::StaticVersionType; +use vbs::version::StaticVersionType; /// An error encountered when parsing or constructing an [Api]. #[derive(Clone, Debug, Snafu, PartialEq, Eq)] @@ -253,64 +259,160 @@ mod meta_defaults { /// TOML file and registered as a module of an [App](crate::App). #[derive(Derivative)] #[derivative(Debug(bound = ""))] -pub struct Api { +pub struct Api { + inner: ApiInner, + _version: PhantomData, +} + +/// A version-erased description of an API. +/// +/// This type contains all the details of the API, with the version of the binary serialization +/// format type-erased and encapsulated into the route handlers. This type is used internally by +/// [`App`], to allow dynamic registration of different versions of APIs with different versions of +/// the binary format. +/// +/// It is exposed publicly and manipulated _only_ via [`Api`], which wraps this type with a static +/// format version type parameter, which provides compile-time enforcement of format version +/// consistency as the API is being constructed, until it is registered with an [`App`] and +/// type-erased. +#[derive(Derivative)] +#[derivative(Debug(bound = ""))] +pub(crate) struct ApiInner { meta: Arc, name: String, - routes: HashMap>, + routes: HashMap>, routes_by_path: HashMap>, #[derivative(Debug = "ignore")] - health_check: Option>, + health_check: HealthCheckHandler, api_version: Option, + /// Error handler encapsulating the serialization format version for errors. + /// + /// This field is optional so it can be bound late, potentially after a `map_err` changes the + /// error type. However, it will always be set after `Api::into_inner` is called. + #[derivative(Debug = "ignore")] + error_handler: Option>>, public: Option, short_description: String, long_description: String, } -impl<'a, State, Error, VER: StaticVersionType> IntoIterator for &'a Api { - type Item = &'a Route; - type IntoIter = Values<'a, String, Route>; +impl<'a, State, Error> IntoIterator for &'a ApiInner { + type Item = &'a Route; + type IntoIter = Values<'a, String, Route>; fn into_iter(self) -> Self::IntoIter { self.routes.values() } } -impl IntoIterator for Api { - type Item = Route; - type IntoIter = IntoValues>; +impl IntoIterator for ApiInner { + type Item = Route; + type IntoIter = IntoValues>; fn into_iter(self) -> Self::IntoIter { self.routes.into_values() } } -impl Index<&str> for Api { - type Output = Route; +impl Index<&str> for ApiInner { + type Output = Route; - fn index(&self, index: &str) -> &Route { + fn index(&self, index: &str) -> &Route { &self.routes[index] } } -/// Iterator for [routes_by_path](Api::routes_by_path). +/// Iterator for [routes_by_path](ApiInner::routes_by_path). /// /// This type iterates over all of the routes that have a given path. -/// [routes_by_path](Api::routes_by_path), in turn, returns an iterator over paths whose items +/// [routes_by_path](ApiInner::routes_by_path), in turn, returns an iterator over paths whose items /// contain a [RoutesWithPath] iterator. -pub struct RoutesWithPath<'a, State, Error, VER: StaticVersionType> { +pub(crate) struct RoutesWithPath<'a, State, Error> { routes: std::slice::Iter<'a, String>, - api: &'a Api, + api: &'a ApiInner, } -impl<'a, State, Error, VER: StaticVersionType> Iterator for RoutesWithPath<'a, State, Error, VER> { - type Item = &'a Route; +impl<'a, State, Error> Iterator for RoutesWithPath<'a, State, Error> { + type Item = &'a Route; fn next(&mut self) -> Option { Some(&self.api.routes[self.routes.next()?]) } } -impl Api { +impl ApiInner { + /// Iterate over groups of routes with the same path. + pub(crate) fn routes_by_path( + &self, + ) -> impl Iterator)> { + self.routes_by_path.iter().map(|(path, routes)| { + ( + path.as_str(), + RoutesWithPath { + routes: routes.iter(), + api: self, + }, + ) + }) + } + + /// Check the health status of a server with the given state. + pub(crate) async fn health(&self, req: RequestParams, state: &State) -> tide::Response { + (self.health_check)(req, state).await + } + + /// Get the version of this API. + pub(crate) fn version(&self) -> ApiVersion { + ApiVersion { + api_version: self.api_version.clone(), + spec_version: self.meta.format_version.clone(), + } + } + + pub(crate) fn public(&self) -> Option<&PathBuf> { + self.public.as_ref() + } + + pub(crate) fn set_name(&mut self, name: String) { + self.name = name; + } + + /// Compose an HTML page documenting all the routes in this API. + pub(crate) fn documentation(&self) -> Html { + html! { + (PreEscaped(self.meta.html_top + .replace("{{NAME}}", &self.name) + .replace("{{SHORT_DESCRIPTION}}", &self.short_description) + .replace("{{LONG_DESCRIPTION}}", &self.long_description) + .replace("{{VERSION}}", &match &self.api_version { + Some(version) => version.to_string(), + None => "(no version)".to_string(), + }) + .replace("{{FORMAT_VERSION}}", &self.meta.format_version.to_string()) + .replace("{{PUBLIC}}", &format!("/public/{}", self.name)))) + @for route in self.routes.values() { + (route.documentation()) + } + (PreEscaped(&self.meta.html_bottom)) + } + } + + /// The short description of this API from the specification. + pub(crate) fn short_description(&self) -> &str { + &self.short_description + } + + pub(crate) fn error_handler(&self) -> Arc> { + self.error_handler.clone().unwrap() + } +} + +impl Api +where + State: 'static, + Error: 'static, + VER: StaticVersionType + 'static, +{ /// Parse an API from a TOML specification. pub fn new(api: impl Into) -> Result { let mut api = api.into(); @@ -386,15 +488,19 @@ impl Api { }; Ok(Self { - name: meta.name.clone(), - meta, - routes, - routes_by_path, - health_check: None, - api_version: None, - public: None, - short_description, - long_description, + inner: ApiInner { + name: meta.name.clone(), + meta, + routes, + routes_by_path, + health_check: Box::new(Self::default_health_check), + api_version: None, + error_handler: None, + public: None, + short_description, + long_description, + }, + _version: Default::default(), }) } @@ -413,21 +519,6 @@ impl Api { })?) } - /// Iterate over groups of routes with the same path. - pub fn routes_by_path( - &self, - ) -> impl Iterator)> { - self.routes_by_path.iter().map(|(path, routes)| { - ( - path.as_str(), - RoutesWithPath { - routes: routes.iter(), - api: self, - }, - ) - }) - } - /// Set the API version. /// /// The version information will automatically be included in responses to `GET /version`. This @@ -440,13 +531,13 @@ impl Api { /// and may be different from the version of the Rust crate implementing the route handlers for /// the API. pub fn with_version(&mut self, version: Version) -> &mut Self { - self.api_version = Some(version); + self.inner.api_version = Some(version); self } /// Serve the contents of `dir` at the URL `/public/{{NAME}}`. pub fn with_public(&mut self, dir: PathBuf) -> &mut Self { - self.public = Some(dir); + self.inner.public = Some(dir); self } @@ -471,7 +562,7 @@ impl Api { /// ``` /// use futures::FutureExt; /// # use tide_disco::Api; - /// # use versioned_binary_serialization::version::StaticVersion; + /// # use vbs::version::StaticVersion; /// /// type State = u64; /// type StaticVer01 = StaticVersion<0, 1>; @@ -497,7 +588,7 @@ impl Api { /// use async_std::sync::Mutex; /// use futures::FutureExt; /// # use tide_disco::Api; - /// # use versioned_binary_serialization::version::StaticVersion; + /// # use vbs::version::StaticVersion; /// /// type State = Mutex; /// type StaticVer01 = StaticVersion<0, 1>; @@ -549,7 +640,11 @@ impl Api { State: 'static + Send + Sync, VER: 'static + Send + Sync, { - let route = self.routes.get_mut(name).ok_or(ApiError::UndefinedRoute)?; + let route = self + .inner + .routes + .get_mut(name) + .ok_or(ApiError::UndefinedRoute)?; if route.has_handler() { return Err(ApiError::HandlerAlreadyRegistered); } @@ -565,7 +660,7 @@ impl Api { // `set_fn_handler` only fails if the route is not an HTTP route; since we have already // checked that it is, this cannot fail. route - .set_fn_handler(handler) + .set_fn_handler(handler, VER::instance()) .unwrap_or_else(|_| panic!("unexpected failure in set_fn_handler")); Ok(self) @@ -587,7 +682,11 @@ impl Api { VER: 'static + Send + Sync + StaticVersionType, { assert!(method.is_http() && !method.is_mutable()); - let route = self.routes.get_mut(name).ok_or(ApiError::UndefinedRoute)?; + let route = self + .inner + .routes + .get_mut(name) + .ok_or(ApiError::UndefinedRoute)?; if route.method() != method { return Err(ApiError::IncorrectMethod { expected: method, @@ -600,7 +699,7 @@ impl Api { // `set_handler` only fails if the route is not an HTTP route; since we have already checked // that it is, this cannot fail. route - .set_handler(ReadHandler::from(handler)) + .set_handler(ReadHandler::<_, VER>::from(handler)) .unwrap_or_else(|_| panic!("unexpected failure in set_handler")); Ok(self) } @@ -632,7 +731,7 @@ impl Api { /// use async_std::sync::RwLock; /// use futures::FutureExt; /// # use tide_disco::Api; - /// # use versioned_binary_serialization::{Serializer, version::StaticVersion}; + /// # use vbs::{Serializer, version::StaticVersion}; /// /// type State = RwLock; /// type StaticVer01 = StaticVersion<0, 1>; @@ -686,7 +785,11 @@ impl Api { VER: 'static + Send + Sync, { assert!(method.is_http() && method.is_mutable()); - let route = self.routes.get_mut(name).ok_or(ApiError::UndefinedRoute)?; + let route = self + .inner + .routes + .get_mut(name) + .ok_or(ApiError::UndefinedRoute)?; if route.method() != method { return Err(ApiError::IncorrectMethod { expected: method, @@ -700,7 +803,7 @@ impl Api { // `set_handler` only fails if the route is not an HTTP route; since we have already checked // that it is, this cannot fail. route - .set_handler(WriteHandler::from(handler)) + .set_handler(WriteHandler::<_, VER>::from(handler)) .unwrap_or_else(|_| panic!("unexpected failure in set_handler")); Ok(self) } @@ -733,7 +836,7 @@ impl Api { /// use async_std::sync::RwLock; /// use futures::FutureExt; /// # use tide_disco::Api; - /// # use versioned_binary_serialization::version::StaticVersion; + /// # use vbs::version::StaticVersion; /// /// type State = RwLock; /// type StaticVer01 = StaticVersion<0, 1>; @@ -803,7 +906,7 @@ impl Api { /// use async_std::sync::RwLock; /// use futures::FutureExt; /// # use tide_disco::Api; - /// # use versioned_binary_serialization::version::StaticVersion; + /// # use vbs::version::StaticVersion; /// /// type State = RwLock; /// type StaticVer01 = StaticVersion<0, 1>; @@ -872,7 +975,7 @@ impl Api { /// use async_std::sync::RwLock; /// use futures::FutureExt; /// # use tide_disco::Api; - /// # use versioned_binary_serialization::version::StaticVersion; + /// # use vbs::version::StaticVersion; /// /// type State = RwLock>; /// type StaticVer01 = StaticVersion<0, 1>; @@ -944,7 +1047,7 @@ impl Api { /// ``` /// use futures::{FutureExt, SinkExt, StreamExt}; /// use tide_disco::{error::ServerError, socket::Connection, Api}; - /// # use versioned_binary_serialization::version::StaticVersion; + /// # use vbs::version::StaticVersion; /// /// # fn ex(api: &mut Api<(), ServerError, StaticVersion<0, 1>>) { /// api.socket("sum", |_req, mut conn: Connection>, _state| async move { @@ -1021,7 +1124,11 @@ impl Api { name: &str, handler: socket::Handler, ) -> Result<&mut Self, ApiError> { - let route = self.routes.get_mut(name).ok_or(ApiError::UndefinedRoute)?; + let route = self + .inner + .routes + .get_mut(name) + .ok_or(ApiError::UndefinedRoute)?; if route.method() != Method::Socket { return Err(ApiError::IncorrectMethod { expected: Method::Socket, @@ -1074,7 +1181,7 @@ impl Api { /// # use futures::FutureExt; /// # use tide_disco::{api::{Api, ApiError}, error::ServerError}; /// # use std::borrow::Cow; - /// # use versioned_binary_serialization::version::StaticVersion; + /// # use vbs::version::StaticVersion; /// use prometheus::{Counter, Registry}; /// /// struct State { @@ -1120,7 +1227,11 @@ impl Api { Error: 'static, VER: 'static + Send + Sync, { - let route = self.routes.get_mut(name).ok_or(ApiError::UndefinedRoute)?; + let route = self + .inner + .routes + .get_mut(name) + .ok_or(ApiError::UndefinedRoute)?; if route.method() != Method::Metrics { return Err(ApiError::IncorrectMethod { expected: Method::Metrics, @@ -1151,44 +1262,12 @@ impl Api { H: 'static + HealthCheck, VER: 'static + Send + Sync, { - self.health_check = Some(route::health_check_handler::<_, _, VER>(handler)); + self.inner.health_check = route::health_check_handler::<_, _, VER>(handler); self } - /// Check the health status of a server with the given state. - pub async fn health(&self, req: RequestParams, state: &State) -> tide::Response { - if let Some(handler) = &self.health_check { - handler(req, state).await - } else { - // If there is no healthcheck handler registered, just return [HealthStatus::Available] - // by default; after all, if this handler is getting hit at all, the service must be up. - route::health_check_response::<_, VER>( - &req.accept().unwrap_or_else(|_| { - // The healthcheck endpoint is not allowed to fail, so just use the default content - // type if we can't parse the Accept header. - let mut accept = Accept::new(); - accept.set_wildcard(true); - accept - }), - HealthStatus::Available, - ) - } - } - - /// Get the version of this API. - pub fn version(&self) -> ApiVersion { - ApiVersion { - api_version: self.api_version.clone(), - spec_version: self.meta.format_version.clone(), - } - } - - pub(crate) fn public(&self) -> Option<&PathBuf> { - self.public.as_ref() - } - /// Create a new [Api] which is just like this one, except has a transformed `Error` type. - pub fn map_err( + pub(crate) fn map_err( self, f: impl 'static + Clone + Send + Sync + Fn(Error) -> Error2, ) -> Api @@ -1196,52 +1275,55 @@ impl Api { Error: 'static + Send + Sync, Error2: 'static, State: 'static + Send + Sync, - VER: 'static + Send + Sync, { Api { - meta: self.meta, - name: self.name, - routes: self - .routes - .into_iter() - .map(|(name, route)| (name, route.map_err(f.clone()))) - .collect(), - routes_by_path: self.routes_by_path, - health_check: self.health_check, - api_version: self.api_version, - public: self.public, - short_description: self.short_description, - long_description: self.long_description, + inner: ApiInner { + meta: self.inner.meta, + name: self.inner.name, + routes: self + .inner + .routes + .into_iter() + .map(|(name, route)| (name, route.map_err(f.clone()))) + .collect(), + routes_by_path: self.inner.routes_by_path, + health_check: self.inner.health_check, + api_version: self.inner.api_version, + error_handler: None, + public: self.inner.public, + short_description: self.inner.short_description, + long_description: self.inner.long_description, + }, + _version: Default::default(), } } - pub(crate) fn set_name(&mut self, name: String) { - self.name = name; + pub(crate) fn into_inner(mut self) -> ApiInner + where + Error: crate::Error, + { + // This `into_inner` finalizes the error type for the API. At this point, ensure + // `error_handler` is set. + self.inner.error_handler = Some(error_handler::()); + self.inner } - /// Compose an HTML page documenting all the routes in this API. - pub fn documentation(&self) -> Html { - html! { - (PreEscaped(self.meta.html_top - .replace("{{NAME}}", &self.name) - .replace("{{SHORT_DESCRIPTION}}", &self.short_description) - .replace("{{LONG_DESCRIPTION}}", &self.long_description) - .replace("{{VERSION}}", &match &self.api_version { - Some(version) => version.to_string(), - None => "(no version)".to_string(), - }) - .replace("{{FORMAT_VERSION}}", &self.meta.format_version.to_string()) - .replace("{{PUBLIC}}", &format!("/public/{}", self.name)))) - @for route in self.routes.values() { - (route.documentation()) - } - (PreEscaped(&self.meta.html_bottom)) + fn default_health_check(req: RequestParams, _state: &State) -> BoxFuture { + async move { + // If there is no healthcheck handler registered, just return [HealthStatus::Available] + // by default; after all, if this handler is getting hit at all, the service must be up. + route::health_check_response::<_, VER>( + &req.accept().unwrap_or_else(|_| { + // The healthcheck endpoint is not allowed to fail, so just use the default + // content type if we can't parse the Accept header. + let mut accept = Accept::new(); + accept.set_wildcard(true); + accept + }), + HealthStatus::Available, + ) } - } - - /// The short description of this API from the specification. - pub fn short_description(&self) -> &str { - &self.short_description + .boxed() } } @@ -1253,13 +1335,22 @@ impl Api { // by reference, and probably partly due to my lack of creativity. In any case, writing out the // closure object and [Handler] implementation by hand seems to convince Rust that this code is // memory safe. -#[derive(From)] -struct ReadHandler { +struct ReadHandler { handler: F, + _version: PhantomData, +} + +impl From for ReadHandler { + fn from(f: F) -> Self { + Self { + handler: f, + _version: Default::default(), + } + } } #[async_trait] -impl Handler for ReadHandler +impl Handler for ReadHandler where F: 'static + Send @@ -1273,25 +1364,33 @@ where &self, req: RequestParams, state: &State, - bind_version: VER, ) -> Result> { let accept = req.accept()?; response_from_result( &accept, state.read(|state| (self.handler)(req, state)).await, - bind_version, + VER::instance(), ) } } // A manual closure that serves a similar purpose as [ReadHandler]. -#[derive(From)] -struct WriteHandler { +struct WriteHandler { handler: F, + _version: PhantomData, +} + +impl From for WriteHandler { + fn from(f: F) -> Self { + Self { + handler: f, + _version: Default::default(), + } + } } #[async_trait] -impl Handler for WriteHandler +impl Handler for WriteHandler where F: 'static + Send @@ -1305,13 +1404,12 @@ where &self, req: RequestParams, state: &State, - bind_version: VER, ) -> Result> { let accept = req.accept()?; response_from_result( &accept, state.write(|state| (self.handler)(req, state)).await, - bind_version, + VER::instance(), ) } } @@ -1338,7 +1436,7 @@ mod test { use prometheus::{Counter, Registry}; use std::borrow::Cow; use toml::toml; - use versioned_binary_serialization::{version::StaticVersion, BinarySerializer, Serializer}; + use vbs::{version::StaticVersion, BinarySerializer, Serializer}; #[cfg(windows)] use async_tungstenite::tungstenite::Error as WsError; @@ -1347,7 +1445,6 @@ mod test { type StaticVer01 = StaticVersion<0, 1>; type SerializerV01 = Serializer>; - const VER_0_1: StaticVer01 = StaticVersion {}; async fn check_stream_closed(mut conn: WebSocketStream) where @@ -1392,7 +1489,9 @@ mod test { METHOD = "SOCKET" }; { - let mut api = app.module::("mod", api_toml).unwrap(); + let mut api = app + .module::("mod", api_toml) + .unwrap(); api.socket( "echo", |_req, mut conn: Connection, _state| { @@ -1433,7 +1532,7 @@ mod test { } let port = pick_unused_port().unwrap(); let url: Url = format!("http://localhost:{}", port).parse().unwrap(); - spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1)); + spawn(app.serve(format!("0.0.0.0:{}", port))); // Create a client that accepts JSON messages. let mut conn = test_ws_client_with_headers( @@ -1535,7 +1634,9 @@ mod test { METHOD = "SOCKET" }; { - let mut api = app.module::("mod", api_toml).unwrap(); + let mut api = app + .module::("mod", api_toml) + .unwrap(); api.stream("nat", |_req, _state| iter(0..).map(Ok).boxed()) .unwrap() .stream("once", |_req, _state| once(async { Ok(0) }).boxed()) @@ -1553,7 +1654,7 @@ mod test { } let port = pick_unused_port().unwrap(); let url: Url = format!("http://localhost:{}", port).parse().unwrap(); - spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1)); + spawn(app.serve(format!("0.0.0.0:{}", port))); // Consume the `nat` stream. let mut conn = test_ws_client(url.join("mod/nat").unwrap()).await; @@ -1601,12 +1702,14 @@ mod test { PATH = ["/dummy"] }; { - let mut api = app.module::("mod", api_toml).unwrap(); + let mut api = app + .module::("mod", api_toml) + .unwrap(); api.with_health_check(|state| async move { *state }.boxed()); } let port = pick_unused_port().unwrap(); let url: Url = format!("http://localhost:{}", port).parse().unwrap(); - spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1)); + spawn(app.serve(format!("0.0.0.0:{}", port))); let client = Client::new(url).await; let res = client.get("/mod/healthcheck").send().await.unwrap(); @@ -1645,7 +1748,9 @@ mod test { METHOD = "METRICS" }; { - let mut api = app.module::("mod", api_toml).unwrap(); + let mut api = app + .module::("mod", api_toml) + .unwrap(); api.metrics("metrics", |_req, state| { async move { state.counter.inc(); @@ -1657,7 +1762,7 @@ mod test { } let port = pick_unused_port().unwrap(); let url: Url = format!("http://localhost:{port}").parse().unwrap(); - spawn(app.serve(format!("0.0.0.0:{port}"), VER_0_1)); + spawn(app.serve(format!("0.0.0.0:{port}"))); let client = Client::new(url).await; for i in 1..5 { diff --git a/src/app.rs b/src/app.rs index cc1e9f3a..cc32da37 100644 --- a/src/app.rs +++ b/src/app.rs @@ -5,12 +5,13 @@ // along with the tide-disco library. If not, see . use crate::{ - api::{Api, ApiError, ApiVersion}, + api::{Api, ApiError, ApiInner, ApiVersion}, healthcheck::{HealthCheck, HealthStatus}, http, method::Method, - request::{best_response_type, RequestParam, RequestParams}, - route::{self, health_check_response, respond_with, Handler, Route, RouteError}, + middleware::{request_params, AddErrorBody, MetricsMiddleware}, + request::RequestParams, + route::{health_check_response, respond_with, Handler, Route, RouteError}, socket::SocketError, Html, StatusCode, }; @@ -33,15 +34,16 @@ use std::{ env, fmt::Display, fs, io, + marker::PhantomData, ops::{Deref, DerefMut}, path::PathBuf, }; use tide::{ - http::{headers::HeaderValue, mime}, + http::headers::HeaderValue, security::{CorsMiddleware, Origin}, }; use tide_websockets::WebSocket; -use versioned_binary_serialization::version::StaticVersionType; +use vbs::version::StaticVersionType; pub use tide::listener::{Listener, ToListener}; @@ -51,12 +53,17 @@ pub use tide::listener::{Listener, ToListener}; /// constructing an [Api] for each module and calling [App::register_module]. Once all of the /// desired modules are registered, the app can be converted into an asynchronous server task using /// [App::serve]. +/// +/// Note that the [`App`] is bound to a binary serialization version `VER`. This format only applies +/// to application-level endpoints like `/version` and `/healthcheck`. The binary format version in +/// use by any given API module may differ, depending on the supported version of the API. #[derive(Debug)] -pub struct App { +pub struct App { // Map from base URL, major version to API. - apis: HashMap>>, - state: Arc, + pub(crate) apis: HashMap>>, + pub(crate) state: Arc, app_version: Option, + _version: PhantomData, } /// An error encountered while building an [App]. @@ -66,18 +73,14 @@ pub enum AppError { ModuleAlreadyExists, } -impl< - State: Send + Sync + 'static, - Error: 'static, - VER: Send + Sync + 'static + StaticVersionType, - > App -{ +impl App { /// Create a new [App] with a given state. pub fn with_state(state: State) -> Self { Self { apis: HashMap::new(), state: Arc::new(state), app_version: None, + _version: Default::default(), } } @@ -88,14 +91,16 @@ impl< /// handlers. When [`Module::register`] is called on the guard (or the guard is dropped), the /// module will be registered in this [`App`] as if by calling /// [`register_module`](Self::register_module). - pub fn module<'a, ModuleError>( + pub fn module<'a, ModuleError, ModuleVersion>( &'a mut self, base_url: &'a str, api: impl Into, - ) -> Result, AppError> + ) -> Result, AppError> where - Error: From, - ModuleError: 'static + Send + Sync, + Error: crate::Error + From, + VER: StaticVersionType + 'static, + ModuleError: Send + Sync + 'static, + ModuleVersion: StaticVersionType + 'static, { Ok(Module { app: self, @@ -136,16 +141,17 @@ impl< /// Note that non-breaking changes (e.g. new endpoints) can be deployed in place of an existing /// API without even incrementing the major version. The need for serving two versions of an API /// simultaneously only arises when you have breaking changes. - pub fn register_module( + pub fn register_module( &mut self, base_url: &str, - api: Api, + api: Api, ) -> Result<&mut Self, AppError> where - Error: From, - ModuleError: 'static + Send + Sync, + Error: crate::Error + From, + ModuleError: Send + Sync + 'static, + ModuleVersion: StaticVersionType + 'static, { - let mut api = api.map_err(Error::from); + let mut api = api.map_err(Error::from).into_inner(); api.set_name(base_url.to_string()); let major_version = match api.version().api_version { @@ -193,7 +199,7 @@ impl< /// is contained in the application crate, it should result in a reasonable version: /// /// ``` - /// # use versioned_binary_serialization::version::StaticVersion; + /// # use vbs::version::StaticVersion; /// # type StaticVer01 = StaticVersion<0, 1>; /// # fn ex(app: &mut tide_disco::App<(), (), StaticVer01>) { /// app.with_version(env!("CARGO_PKG_VERSION").parse().unwrap()); @@ -287,22 +293,18 @@ lazy_static! { }; } -impl< - State: Send + Sync + 'static, - Error: 'static + crate::Error, - VER: 'static + Send + Sync + StaticVersionType, - > App +impl App +where + State: Send + Sync + 'static, + Error: 'static + crate::Error, + VER: StaticVersionType + Send + Sync + 'static, { /// Serve the [App] asynchronously. - pub async fn serve>>( - self, - listener: L, - bind_version: VER, - ) -> io::Result<()> { + pub async fn serve>>(self, listener: L) -> io::Result<()> { let state = Arc::new(self); let mut server = tide::Server::with_state(state.clone()); server.with(Self::version_middleware); - server.with(add_error_body::<_, Error, VER>); + server.with(AddErrorBody::::with_version::()); server.with( CorsMiddleware::new() .allow_methods("GET, POST".parse::().unwrap()) @@ -312,7 +314,7 @@ impl< ); for (name, versions) in &state.apis { - Self::register_api(&mut server, name.clone(), versions, bind_version)?; + Self::register_api(&mut server, name.clone(), versions)?; } // Register app-level automatic routes: `healthcheck` and `version`. @@ -330,7 +332,7 @@ impl< .at("version") .get(move |req: tide::Request>| async move { let accept = RequestParams::accept_from_headers(&req)?; - respond_with(&accept, req.state().version(), bind_version) + respond_with(&accept, req.state().version(), VER::instance()) .map_err(|err| Error::from_route_error::(err).into_tide_error()) }); @@ -372,11 +374,10 @@ impl< fn register_api( server: &mut tide::Server>, prefix: String, - versions: &BTreeMap>, - bind_version: VER, + versions: &BTreeMap>, ) -> io::Result<()> { for (version, api) in versions { - Self::register_api_version(server, &prefix, *version, api, bind_version)?; + Self::register_api_version(server, &prefix, *version, api)?; } Ok(()) } @@ -385,8 +386,7 @@ impl< server: &mut tide::Server>, prefix: &String, version: u64, - api: &Api, - bind_version: VER, + api: &ApiInner, ) -> io::Result<()> { // Clippy complains if the only non-trivial operation in an `unwrap_or_else` closure is // a deref, but for `lazy_static` types, deref is an effectful operation that (in this @@ -401,6 +401,7 @@ impl< // Register routes for this API. let mut api_endpoint = server.at(&format!("/v{version}/{prefix}")); + api_endpoint.with(AddErrorBody::new(api.error_handler())); for (path, routes) in api.routes_by_path() { let mut endpoint = api_endpoint.at(path); let routes = routes.collect::>(); @@ -423,26 +424,13 @@ impl< // all endpoints registered under this pattern, so that a request to this path // with the right headers will return metrics instead of going through the // normal method-based dispatching. - Self::register_metrics( - prefix.to_owned(), - version, - &mut endpoint, - metrics_route, - bind_version, - ); + Self::register_metrics(prefix.to_owned(), version, &mut endpoint, metrics_route); } // Register the HTTP routes. for route in routes { if let Method::Http(method) = route.method() { - Self::register_route( - prefix.to_owned(), - version, - &mut endpoint, - route, - method, - bind_version, - ); + Self::register_route(prefix.to_owned(), version, &mut endpoint, route, method); } } } @@ -505,7 +493,7 @@ impl< async move { let api = &req.state().apis[&prefix][&version]; let accept = RequestParams::accept_from_headers(&req)?; - respond_with(&accept, api.version(), bind_version).map_err(|err| { + respond_with(&accept, api.version(), VER::instance()).map_err(|err| { Error::from_route_error::(err).into_tide_error() }) } @@ -519,9 +507,8 @@ impl< api: String, version: u64, endpoint: &mut tide::Route>, - route: &Route, + route: &Route, method: http::Method, - bind_version: VER, ) { let name = route.name(); endpoint.method(method, move |req: tide::Request>| { @@ -532,7 +519,7 @@ impl< let state = &*req.state().clone().state; let req = request_params(req, route.params()).await?; route - .handle(req, state, bind_version) + .handle(req, state) .await .map_err(|err| match err { RouteError::AppSpecific(err) => err, @@ -547,20 +534,14 @@ impl< api: String, version: u64, endpoint: &mut tide::Route>, - route: &Route, - bind_version: VER, + route: &Route, ) { let name = route.name(); if route.has_handler() { // If there is a metrics handler, add middleware to the endpoint to intercept the // request and respond with metrics, rather than the usual HTTP dispatching, if the // appropriate headers are set. - endpoint.with(MetricsMiddleware::new( - name.clone(), - api.clone(), - version, - bind_version, - )); + endpoint.with(MetricsMiddleware::new(name.clone(), api.clone(), version)); } // Register a catch-all HTTP handler for the route, which serves the route documentation as @@ -579,7 +560,7 @@ impl< api: String, version: u64, endpoint: &mut tide::Route>, - route: &Route, + route: &Route, ) { let name = route.name(); if route.has_handler() { @@ -627,7 +608,7 @@ impl< api: String, version: u64, endpoint: &mut tide::Route>, - route: &Route, + route: &Route, ) { let name = route.name(); endpoint.all(move |req: tide::Request>| { @@ -748,86 +729,6 @@ impl< } } -struct MetricsMiddleware { - route: String, - api: String, - api_version: u64, - ver: VER, -} - -impl MetricsMiddleware { - fn new(route: String, api: String, api_version: u64, ver: VER) -> Self { - Self { - route, - api, - api_version, - ver, - } - } -} - -impl tide::Middleware>> for MetricsMiddleware -where - State: Send + Sync + 'static, - Error: 'static + crate::Error, - VER: Send + Sync + 'static + StaticVersionType, -{ - fn handle<'a, 'b, 't>( - &'a self, - req: tide::Request>>, - next: tide::Next<'b, Arc>>, - ) -> BoxFuture<'t, tide::Result> - where - 'a: 't, - 'b: 't, - Self: 't, - { - let route = self.route.clone(); - let api = self.api.clone(); - let version = self.api_version; - let bind_version = self.ver; - async move { - if req.method() != http::Method::Get { - // Metrics only apply to GET requests. For other requests, proceed with normal - // dispatching. - return Ok(next.run(req).await); - } - // Look at the `Accept` header. If the requested content type is plaintext, we consider - // it a metrics request. Other endpoints have typed responses yielding either JSON or - // binary. - let accept = RequestParams::accept_from_headers(&req)?; - let reponse_ty = - best_response_type(&accept, &[mime::PLAIN, mime::JSON, mime::BYTE_STREAM])?; - if reponse_ty != mime::PLAIN { - return Ok(next.run(req).await); - } - // This is a metrics request, abort the rest of the dispatching chain and run the - // metrics handler. - let route = &req.state().clone().apis[&api][&version][&route]; - let state = &*req.state().clone().state; - let req = request_params(req, route.params()).await?; - route - .handle(req, state, bind_version) - .await - .map_err(|err| match err { - RouteError::AppSpecific(err) => err, - _ => Error::from_route_error(err), - }) - .map_err(|err| err.into_tide_error()) - } - .boxed() - } -} - -async fn request_params( - req: tide::Request>>, - params: &[RequestParam], -) -> Result { - RequestParams::new(req, params) - .await - .map_err(|err| Error::from_request_error(err).into_tide_error()) -} - /// The health status of an application. #[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)] pub struct AppHealth { @@ -864,43 +765,6 @@ pub struct AppVersion { pub disco_version: Version, } -/// Server middleware which automatically populates the body of error responses. -/// -/// If the response contains an error, the error is encoded into the [Error](crate::Error) type -/// (either by downcasting if the server has generated an instance of [Error](crate::Error), or by -/// converting to a [String] using [Display] if the error can not be downcasted to -/// [Error](crate::Error)). The resulting [Error](crate::Error) is then serialized and used as the -/// body of the response. -/// -/// If the response does not contain an error, it is passed through unchanged. -fn add_error_body< - T: Clone + Send + Sync + 'static, - E: crate::Error, - VER: Send + Sync + 'static + StaticVersionType, ->( - req: tide::Request, - next: tide::Next, -) -> BoxFuture { - Box::pin(async { - let accept = RequestParams::accept_from_headers(&req)?; - let mut res = next.run(req).await; - if let Some(error) = res.take_error() { - let error = E::from_server_error(error); - tracing::info!("responding with error: {}", error); - // Try to add the error to the response body using a format accepted by the client. If - // we cannot do that (for example, if the client requested a format that is incompatible - // with a serialized error) just add the error as a string using plaintext. - let (body, content_type) = route::response_body::<_, E, VER>(&accept, &error) - .unwrap_or_else(|_| (error.to_string().into(), mime::PLAIN)); - res.set_body(body); - res.set_content_type(content_type); - Ok(res) - } else { - Ok(res) - } - }) -} - /// RAII guard to ensure a module is registered after it is configured. /// /// This type allows the owner to configure an [`Api`] module via the [`Deref`] and [`DerefMut`] @@ -912,66 +776,72 @@ fn add_error_body< /// Note that if anything goes wrong during module registration (for example, there is already an /// incompatible module registered with the same name), the drop implementation may panic. To handle /// errors without panicking, call [`register`](Self::register) explicitly. -pub struct Module<'a, State, Error, ModuleError, VER: StaticVersionType> +pub struct Module<'a, State, Error, VER, ModuleError, ModuleVersion> where - State: 'static + Send + Sync, - Error: 'static + From, - ModuleError: 'static + Send + Sync, - VER: 'static + Send + Sync, + State: Send + Sync + 'static, + Error: crate::Error + From + 'static, + VER: StaticVersionType + 'static, + ModuleError: Send + Sync + 'static, + ModuleVersion: StaticVersionType + 'static, { app: &'a mut App, base_url: &'a str, // This is only an [Option] so we can [take] out of it during [drop]. - api: Option>, + api: Option>, } -impl<'a, State, Error, ModuleError, VER: StaticVersionType> Deref - for Module<'a, State, Error, ModuleError, VER> +impl<'a, State, Error, VER, ModuleError, ModuleVersion> Deref + for Module<'a, State, Error, VER, ModuleError, ModuleVersion> where - State: 'static + Send + Sync, - Error: 'static + From, - ModuleError: 'static + Send + Sync, - VER: 'static + Send + Sync, + State: Send + Sync + 'static, + Error: crate::Error + From + 'static, + VER: StaticVersionType + 'static, + ModuleError: Send + Sync + 'static, + ModuleVersion: StaticVersionType + 'static, { - type Target = Api; + type Target = Api; fn deref(&self) -> &Self::Target { self.api.as_ref().unwrap() } } -impl<'a, State, Error, ModuleError, VER: StaticVersionType> DerefMut - for Module<'a, State, Error, ModuleError, VER> +impl<'a, State, Error, VER, ModuleError, ModuleVersion> DerefMut + for Module<'a, State, Error, VER, ModuleError, ModuleVersion> where - State: 'static + Send + Sync, - Error: 'static + From, - ModuleError: 'static + Send + Sync, - VER: 'static + Send + Sync, + State: Send + Sync + 'static, + Error: crate::Error + From + 'static, + VER: StaticVersionType + 'static, + ModuleError: Send + Sync + 'static, + ModuleVersion: StaticVersionType + 'static, { fn deref_mut(&mut self) -> &mut Self::Target { self.api.as_mut().unwrap() } } -impl<'a, State, Error, ModuleError, VER: StaticVersionType> Drop - for Module<'a, State, Error, ModuleError, VER> +impl<'a, State, Error, VER, ModuleError, ModuleVersion> Drop + for Module<'a, State, Error, VER, ModuleError, ModuleVersion> where - State: 'static + Send + Sync, - Error: 'static + From, - ModuleError: 'static + Send + Sync, - VER: 'static + Send + Sync, + State: Send + Sync + 'static, + Error: crate::Error + From + 'static, + VER: StaticVersionType + 'static, + ModuleError: Send + Sync + 'static, + ModuleVersion: StaticVersionType + 'static, { fn drop(&mut self) { self.register_impl().unwrap(); } } -impl<'a, State, Error, ModuleError, VER> Module<'a, State, Error, ModuleError, VER> +impl<'a, State, Error, VER, ModuleError, ModuleVersion> + Module<'a, State, Error, VER, ModuleError, ModuleVersion> where - State: 'static + Send + Sync, - Error: 'static + From, - ModuleError: 'static + Send + Sync, - VER: 'static + Send + Sync + StaticVersionType, + State: Send + Sync + 'static, + Error: crate::Error + From + 'static, + VER: StaticVersionType + 'static, + ModuleError: Send + Sync + 'static, + ModuleVersion: StaticVersionType + 'static, { /// Register this module with the linked app. pub fn register(mut self) -> Result<(), AppError> { @@ -995,7 +865,7 @@ where mod test { use super::*; use crate::{ - error::ServerError, + error::{Error, ServerError}, metrics::Metrics, socket::Connection, testing::{setup_test, test_ws_client, Client}, @@ -1005,13 +875,19 @@ mod test { use async_tungstenite::tungstenite::Message; use futures::{FutureExt, SinkExt, StreamExt}; use portpicker::pick_unused_port; - use std::borrow::Cow; + use serde::de::DeserializeOwned; + use std::{borrow::Cow, fmt::Debug}; use toml::toml; - use versioned_binary_serialization::{version::StaticVersion, BinarySerializer, Serializer}; + use vbs::{version::StaticVersion, BinarySerializer, Serializer}; type StaticVer01 = StaticVersion<0, 1>; - type SerializerV01 = Serializer>; - const VER_0_1: StaticVer01 = StaticVersion {}; + type SerializerV01 = Serializer; + + type StaticVer02 = StaticVersion<0, 2>; + type SerializerV02 = Serializer; + + type StaticVer03 = StaticVersion<0, 3>; + type SerializerV03 = Serializer; #[derive(Clone, Copy, Debug)] struct FakeMetrics; @@ -1061,7 +937,9 @@ mod test { METHOD = "METRICS" }; { - let mut api = app.module::("mod", api_toml).unwrap(); + let mut api = app + .module::("mod", api_toml) + .unwrap(); api.get("get_test", |_req, _state| { async move { Ok(Get.to_string()) }.boxed() }) @@ -1096,7 +974,7 @@ mod test { } let port = pick_unused_port().unwrap(); let url: Url = format!("http://localhost:{}", port).parse().unwrap(); - spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1)); + spawn(app.serve(format!("0.0.0.0:{}", port))); let client = Client::new(url.clone()).await; // Regular HTTP methods. @@ -1153,7 +1031,9 @@ mod test { ":b" = "Boolean" }; { - let mut api = app.module::("mod", api_toml).unwrap(); + let mut api = app + .module::("mod", api_toml) + .unwrap(); api.get("test", |req, _state| { async move { if let Some(a) = req.opt_integer_param::<_, i32>("a")? { @@ -1168,7 +1048,7 @@ mod test { } let port = pick_unused_port().unwrap(); let url: Url = format!("http://localhost:{}", port).parse().unwrap(); - spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1)); + spawn(app.serve(format!("0.0.0.0:{}", port))); let client = Client::new(url.clone()).await; let res = client.get("mod/test/a/42").send().await.unwrap(); @@ -1217,7 +1097,9 @@ mod test { }; { - let mut v1 = app.module::("mod", v1_toml.clone()).unwrap(); + let mut v1 = app + .module::("mod", v1_toml.clone()) + .unwrap(); v1.with_version("1.0.0".parse().unwrap()) .get("deleted", |_req, _state| { async move { Ok("deleted v1") }.boxed() @@ -1234,12 +1116,16 @@ mod test { } { // Registering the same major version twice is an error. - let mut api = app.module::("mod", v1_toml).unwrap(); + let mut api = app + .module::("mod", v1_toml) + .unwrap(); api.with_version("1.1.1".parse().unwrap()); assert_eq!(api.register().unwrap_err(), AppError::ModuleAlreadyExists); } { - let mut v3 = app.module::("mod", v3_toml.clone()).unwrap(); + let mut v3 = app + .module::("mod", v3_toml.clone()) + .unwrap(); v3.with_version("3.0.0".parse().unwrap()) .get("added", |_req, _state| { async move { Ok("added v3") }.boxed() @@ -1253,7 +1139,7 @@ mod test { let port = pick_unused_port().unwrap(); let url: Url = format!("http://localhost:{}", port).parse().unwrap(); - spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1)); + spawn(app.serve(format!("0.0.0.0:{}", port))); let client = Client::new(url.clone()).await; // First check that we can call all the expected methods. @@ -1464,7 +1350,7 @@ mod test { // Test discoverability documentation when a request is for an unknown API. let mut app = App::<_, ServerError, StaticVer01>::with_state(()); - app.module::( + app.module::( "the-correct-module", toml! { route = {} @@ -1475,7 +1361,7 @@ mod test { let port = pick_unused_port().unwrap(); let url: Url = format!("http://localhost:{}", port).parse().unwrap(); - spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1)); + spawn(app.serve(format!("0.0.0.0:{}", port))); let client = Client::new(url.clone()).await; let expected_list_item = html! { @@ -1534,7 +1420,9 @@ mod test { PATH = ["/test"] }; { - let mut api = app.module::("mod", api_toml.clone()).unwrap(); + let mut api = app + .module::("mod", api_toml.clone()) + .unwrap(); api.post("test", |_req, state| { async move { *state += 1; @@ -1547,7 +1435,7 @@ mod test { let port = pick_unused_port().unwrap(); let url: Url = format!("http://localhost:{}", port).parse().unwrap(); - spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1)); + spawn(app.serve(format!("0.0.0.0:{}", port))); let client = Client::new(url.clone()).await; for i in 1..3 { @@ -1566,4 +1454,161 @@ mod test { ); } } + + #[async_std::test] + async fn test_format_versions() { + // Register two modules with different binary format versions, each in turn different from + // the app-level version. Each module has two endpoints, one which always succeeds and one + // which always fails, so we can test error serialization. + let mut app = App::<_, ServerError, StaticVer01>::with_state(()); + let api_toml = toml! { + [meta] + FORMAT_VERSION = "0.1.0" + + [route.ok] + METHOD = "GET" + PATH = ["/ok"] + + [route.err] + METHOD = "GET" + PATH = ["/err"] + }; + + fn init_api(api: &mut Api<(), ServerError, VER>) { + api.get("ok", |_req, _state| async move { Ok("ok") }.boxed()) + .unwrap() + .get("err", |_req, _state| { + async move { + Err::(ServerError::catch_all( + StatusCode::InternalServerError, + "err".into(), + )) + } + .boxed() + }) + .unwrap(); + } + + { + let mut api = app + .module::("mod02", api_toml.clone()) + .unwrap(); + init_api(&mut api); + } + { + let mut api = app + .module::("mod03", api_toml.clone()) + .unwrap(); + init_api(&mut api); + } + + let port = pick_unused_port().unwrap(); + let url: Url = format!("http://localhost:{}", port).parse().unwrap(); + spawn(app.serve(format!("0.0.0.0:{}", port))); + let client = Client::new(url.clone()).await; + + async fn get( + client: &Client, + endpoint: &str, + expected_status: StatusCode, + ) -> anyhow::Result { + tracing::info!("GET {endpoint} ->"); + let res = client + .get(endpoint) + .header("Accept", "application/octet-stream") + .send() + .await + .unwrap(); + tracing::info!(?res, "<-"); + assert_eq!(res.status(), expected_status); + let bytes = res.bytes().await.unwrap(); + S::deserialize(&bytes) + } + + #[tracing::instrument(skip(client))] + async fn check_ok( + client: &Client, + endpoint: &str, + expected: impl Debug + DeserializeOwned + Eq, + ) { + tracing::info!("checking successful deserialization"); + assert_eq!( + expected, + get::(client, endpoint, StatusCode::Ok).await.unwrap() + ); + } + + check_ok::( + &client, + "healthcheck", + AppHealth { + status: HealthStatus::Available, + modules: [ + ("mod02".into(), [(0, StatusCode::Ok)].into()), + ("mod03".into(), [(0, StatusCode::Ok)].into()), + ] + .into(), + }, + ) + .await; + check_ok::( + &client, + "version", + AppVersion { + app_version: None, + disco_version: env!("CARGO_PKG_VERSION").parse().unwrap(), + modules: [ + ( + "mod02".into(), + vec![ApiVersion { + spec_version: "0.1.0".parse().unwrap(), + api_version: None, + }], + ), + ( + "mod03".into(), + vec![ApiVersion { + spec_version: "0.1.0".parse().unwrap(), + api_version: None, + }], + ), + ] + .into(), + }, + ) + .await; + check_ok::(&client, "mod02/ok", "ok".to_string()).await; + check_ok::(&client, "mod03/ok", "ok".to_string()).await; + + #[tracing::instrument(skip(client))] + async fn check_wrong_version( + client: &Client, + endpoint: &str, + ) { + tracing::info!("checking deserialization fails with wrong version"); + get::(client, endpoint, StatusCode::Ok) + .await + .unwrap_err(); + } + + check_wrong_version::(&client, "healthcheck").await; + check_wrong_version::(&client, "version").await; + check_wrong_version::(&client, "mod02/ok").await; + check_wrong_version::(&client, "mod03/ok").await; + + #[tracing::instrument(skip(client))] + async fn check_err(client: &Client, endpoint: &str) { + tracing::info!("checking error deserialization"); + tracing::info!("checking successful deserialization"); + assert_eq!( + get::(client, endpoint, StatusCode::InternalServerError) + .await + .unwrap(), + ServerError::catch_all(StatusCode::InternalServerError, "err".into()) + ); + } + + check_err::(&client, "mod02/err").await; + check_err::(&client, "mod03/err").await; + } } diff --git a/src/lib.rs b/src/lib.rs index 108de827..a83b6df5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -286,11 +286,13 @@ pub mod healthcheck; pub mod method; pub mod metrics; pub mod request; -pub mod route; pub mod socket; pub mod status; pub mod testing; +mod middleware; +mod route; + pub use api::Api; pub use app::App; pub use error::Error; diff --git a/src/metrics.rs b/src/metrics.rs index a130aed4..2e15b3cd 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -16,7 +16,6 @@ use derive_more::From; use futures::future::{BoxFuture, FutureExt}; use prometheus::{Encoder, TextEncoder}; use std::{borrow::Cow, error::Error, fmt::Debug}; -use versioned_binary_serialization::version::StaticVersionType; pub trait Metrics { type Error: Debug + Error; @@ -55,19 +54,17 @@ impl Metrics for prometheus::Registry { pub(crate) struct Handler(F); #[async_trait] -impl route::Handler for Handler +impl route::Handler for Handler where F: 'static + Send + Sync + Fn(RequestParams, &State::State) -> BoxFuture, Error>>, T: 'static + Clone + Metrics, State: 'static + Send + Sync + ReadState, Error: 'static, - VER: 'static + Send + Sync, { async fn handle( &self, req: RequestParams, state: &State, - _: VER, ) -> Result> { let exported = state .read(|state| { diff --git a/src/middleware.rs b/src/middleware.rs new file mode 100644 index 00000000..948be979 --- /dev/null +++ b/src/middleware.rs @@ -0,0 +1,175 @@ +use crate::{ + http::{self, content::Accept}, + mime, + request::best_response_type, + route::{self, Handler, RouteError}, + App, RequestParam, RequestParams, +}; +use async_std::sync::Arc; +use futures::future::{BoxFuture, FutureExt}; +use vbs::version::StaticVersionType; + +/// A function to add error information to a response body. +/// +/// This trait is object safe, so it can be used to dynamically dispatch to different strategies for +/// serializing the error depending on the format version being used. +pub(crate) trait ErrorHandler: + Fn(&Accept, &Error, &mut tide::Response) + Send + Sync +{ +} +impl ErrorHandler for F where + F: Fn(&Accept, &Error, &mut tide::Response) + Send + Sync +{ +} + +/// Type-erase a format-specific error handler. +pub(crate) fn error_handler() -> Arc> +where + Error: crate::Error, + VER: StaticVersionType, +{ + Arc::new(|accept, error, res| { + // Try to add the error to the response body using a format accepted by the client. If we + // cannot do that (for example, if the client requested a format that is incompatible with a + // serialized error) just add the error as a string using plaintext. + let (body, content_type) = route::response_body::<_, Error, VER>(accept, &error) + .unwrap_or_else(|_| (error.to_string().into(), mime::PLAIN)); + res.set_body(body); + res.set_content_type(content_type); + }) +} + +/// Server middleware which automatically populates the body of error responses. +/// +/// If the response contains an error, the error is encoded into the [Error](crate::Error) type +/// (either by downcasting if the server has generated an instance of [Error](crate::Error), or by +/// converting to a [String] using [Display] if the error can not be downcasted to +/// [Error](crate::Error)). The resulting [Error](crate::Error) is then serialized and used as the +/// body of the response. +/// +/// If the response does not contain an error, it is passed through unchanged. +pub(crate) struct AddErrorBody { + handler: Arc>, +} + +impl AddErrorBody { + pub(crate) fn new(handler: Arc>) -> Self { + Self { handler } + } + + pub(crate) fn with_version() -> Self + where + E: crate::Error, + VER: StaticVersionType, + { + Self::new(error_handler::()) + } +} + +impl tide::Middleware for AddErrorBody +where + T: Clone + Send + Sync + 'static, + E: crate::Error, +{ + fn handle<'a, 'b, 't>( + &'a self, + req: tide::Request, + next: tide::Next<'b, T>, + ) -> BoxFuture<'t, tide::Result> + where + 'a: 't, + 'b: 't, + Self: 't, + { + async { + let accept = RequestParams::accept_from_headers(&req)?; + let mut res = next.run(req).await; + if let Some(error) = res.take_error() { + let error = E::from_server_error(error); + tracing::info!("responding with error: {}", error); + (self.handler)(&accept, &error, &mut res); + Ok(res) + } else { + Ok(res) + } + } + .boxed() + } +} + +pub(crate) struct MetricsMiddleware { + route: String, + api: String, + api_version: u64, +} + +impl MetricsMiddleware { + pub(crate) fn new(route: String, api: String, api_version: u64) -> Self { + Self { + route, + api, + api_version, + } + } +} + +impl tide::Middleware>> for MetricsMiddleware +where + State: Send + Sync + 'static, + Error: crate::Error + Send + Sync + 'static, + VER: StaticVersionType + Send + Sync + 'static, +{ + fn handle<'a, 'b, 't>( + &'a self, + req: tide::Request>>, + next: tide::Next<'b, Arc>>, + ) -> BoxFuture<'t, tide::Result> + where + 'a: 't, + 'b: 't, + Self: 't, + { + let route = self.route.clone(); + let api = self.api.clone(); + let version = self.api_version; + async move { + if req.method() != http::Method::Get { + // Metrics only apply to GET requests. For other requests, proceed with normal + // dispatching. + return Ok(next.run(req).await); + } + // Look at the `Accept` header. If the requested content type is plaintext, we consider + // it a metrics request. Other endpoints have typed responses yielding either JSON or + // binary. + let accept = RequestParams::accept_from_headers(&req)?; + let reponse_ty = + best_response_type(&accept, &[mime::PLAIN, mime::JSON, mime::BYTE_STREAM])?; + if reponse_ty != mime::PLAIN { + return Ok(next.run(req).await); + } + // This is a metrics request, abort the rest of the dispatching chain and run the + // metrics handler. + let route = &req.state().clone().apis[&api][&version][&route]; + let state = &*req.state().clone().state; + let req = request_params(req, route.params()).await?; + route + .handle(req, state) + .await + .map_err(|err| match err { + RouteError::AppSpecific(err) => err, + _ => Error::from_route_error(err), + }) + .map_err(|err| err.into_tide_error()) + } + .boxed() + } +} + +pub(crate) async fn request_params( + req: tide::Request>>, + params: &[RequestParam], +) -> Result { + RequestParams::new(req, params) + .await + .map_err(|err| Error::from_request_error(err).into_tide_error()) +} diff --git a/src/request.rs b/src/request.rs index 65c2a9b6..10965c55 100644 --- a/src/request.rs +++ b/src/request.rs @@ -13,7 +13,7 @@ use std::fmt::Display; use strum_macros::EnumString; use tagged_base64::TaggedBase64; use tide::http::{self, content::Accept, mime::Mime, Headers}; -use versioned_binary_serialization::{version::StaticVersionType, BinarySerializer, Serializer}; +use vbs::{version::StaticVersionType, BinarySerializer, Serializer}; #[derive(Clone, Debug, Snafu, Deserialize, Serialize)] pub enum RequestError { diff --git a/src/route.rs b/src/route.rs index d63712e0..06f8a262 100644 --- a/src/route.rs +++ b/src/route.rs @@ -16,17 +16,18 @@ use crate::{ use async_std::sync::Arc; use async_trait::async_trait; use derivative::Derivative; -use derive_more::From; use futures::future::{BoxFuture, FutureExt}; use maud::{html, PreEscaped}; use serde::Serialize; use snafu::{OptionExt, Snafu}; -use std::borrow::Cow; -use std::collections::HashMap; -use std::convert::Infallible; -use std::fmt::{self, Display, Formatter}; -use std::marker::PhantomData; -use std::str::FromStr; +use std::{ + borrow::Cow, + collections::HashMap, + convert::Infallible, + fmt::{self, Display, Formatter}, + marker::PhantomData, + str::FromStr, +}; use tide::{ http::{ self, @@ -36,7 +37,7 @@ use tide::{ Body, }; use tide_websockets::WebSocketConnection; -use versioned_binary_serialization::{version::StaticVersionType, BinarySerializer, Serializer}; +use vbs::{version::StaticVersionType, BinarySerializer, Serializer}; /// An error returned by a route handler. /// @@ -114,14 +115,11 @@ impl From for RouteError { /// return type of a handler function. The types which are preserved, `State` and `Error`, should be /// the same for all handlers in an API module. #[async_trait] -pub(crate) trait Handler: - 'static + Send + Sync -{ +pub(crate) trait Handler: 'static + Send + Sync { async fn handle( &self, req: RequestParams, state: &State, - bind_version: VER, ) -> Result>; } @@ -139,11 +137,16 @@ pub(crate) trait Handler: /// /// [Like many function parameters](crate#boxed-futures) in [tide_disco](crate), the handler /// function is required to return a [BoxFuture]. -#[derive(From)] -pub(crate) struct FnHandler(F); +pub(crate) struct FnHandler(F, PhantomData); + +impl From for FnHandler { + fn from(f: F) -> Self { + Self(f, Default::default()) + } +} #[async_trait] -impl Handler for FnHandler +impl Handler for FnHandler where F: 'static + Send + Sync + Fn(RequestParams, &State) -> BoxFuture<'_, Result>, T: Serialize, @@ -154,10 +157,9 @@ where &self, req: RequestParams, state: &State, - bind_version: VER, ) -> Result> { let accept = req.accept()?; - response_from_result(&accept, (self.0)(req, state).await, bind_version) + response_from_result(&accept, (self.0)(req, state).await, VER::instance()) } } @@ -171,43 +173,36 @@ pub(crate) fn response_from_result( } #[async_trait] -impl< - H: ?Sized + Handler, - State: 'static + Send + Sync, - Error, - VER: 'static + Send + Sync + StaticVersionType, - > Handler for Box +impl, State: 'static + Send + Sync, Error> Handler + for Box { async fn handle( &self, req: RequestParams, state: &State, - bind_version: VER, ) -> Result> { - (**self).handle(req, state, bind_version).await + (**self).handle(req, state).await } } -enum RouteImplementation { +enum RouteImplementation { Http { method: http::Method, - handler: Option>>, + handler: Option>>, }, Socket { handler: Option>, }, Metrics { - handler: Option>>, + handler: Option>>, }, } -impl - RouteImplementation -{ +impl RouteImplementation { fn map_err( self, f: impl 'static + Send + Sync + Fn(Error) -> Error2, - ) -> RouteImplementation + ) -> RouteImplementation where State: 'static + Send + Sync, Error: 'static + Send + Sync, @@ -217,10 +212,10 @@ impl Self::Http { method, handler } => RouteImplementation::Http { method, handler: handler.map(|h| { - let h: Box> = - Box::new( - MapErr::>, _, Error>::new(h, f), - ); + let h: Box> = + Box::new(MapErr::>, _, Error>::new( + h, f, + )); h }), }, @@ -229,10 +224,10 @@ impl }, Self::Metrics { handler } => RouteImplementation::Metrics { handler: handler.map(|h| { - let h: Box> = - Box::new( - MapErr::>, _, Error>::new(h, f), - ); + let h: Box> = + Box::new(MapErr::>, _, Error>::new( + h, f, + )); h }), }, @@ -248,14 +243,14 @@ impl /// simply returns information about the route. #[derive(Derivative)] #[derivative(Debug(bound = ""))] -pub struct Route { +pub(crate) struct Route { name: String, patterns: Vec, params: Vec, doc: String, meta: Arc, #[derivative(Debug = "ignore")] - handler: RouteImplementation, + handler: RouteImplementation, } #[derive(Clone, Copy, Debug, Snafu, PartialEq, Eq)] @@ -273,7 +268,7 @@ pub enum RouteParseError { RouteMustBeTable, } -impl Route { +impl Route { /// Parse a [Route] from a TOML specification. /// /// The specification must be a table containing at least the following keys: @@ -396,12 +391,11 @@ impl Route { pub fn map_err( self, f: impl 'static + Send + Sync + Fn(Error) -> Error2, - ) -> Route + ) -> Route where State: 'static + Send + Sync, Error: 'static + Send + Sync, Error2: 'static, - VER: 'static + Send + Sync, { Route { handler: self.handler.map_err(f), @@ -440,10 +434,10 @@ impl Route { } } -impl Route { +impl Route { pub(crate) fn set_handler( &mut self, - h: impl Handler, + h: impl Handler, ) -> Result<(), RouteError> { match &mut self.handler { RouteImplementation::Http { handler, .. } => { @@ -456,14 +450,18 @@ impl Route { } } - pub(crate) fn set_fn_handler(&mut self, handler: F) -> Result<(), RouteError> + pub(crate) fn set_fn_handler( + &mut self, + handler: F, + _: VER, + ) -> Result<(), RouteError> where F: 'static + Send + Sync + Fn(RequestParams, &State) -> BoxFuture<'_, Result>, T: Serialize, State: 'static + Send + Sync, - VER: 'static + Send + Sync, + VER: StaticVersionType + 'static, { - self.set_handler(FnHandler::from(handler)) + self.set_handler(FnHandler::::from(handler)) } pub(crate) fn set_socket_handler( @@ -490,7 +488,6 @@ impl Route { T: 'static + Clone + metrics::Metrics, State: 'static + Send + Sync + ReadState, Error: 'static, - VER: 'static + Send + Sync, { match &mut self.handler { RouteImplementation::Metrics { handler, .. } => { @@ -535,22 +532,20 @@ impl Route { } #[async_trait] -impl Handler for Route +impl Handler for Route where Error: 'static, State: 'static + Send + Sync, - VER: 'static + Send + Sync, { async fn handle( &self, req: RequestParams, state: &State, - bind_version: VER, ) -> Result> { match &self.handler { RouteImplementation::Http { handler, .. } | RouteImplementation::Metrics { handler, .. } => match handler { - Some(handler) => handler.handle(req, state, bind_version).await, + Some(handler) => handler.handle(req, state).await, None => self.default_handler(), }, RouteImplementation::Socket { .. } => Err(RouteError::IncorrectMethod { @@ -560,7 +555,7 @@ where } } -pub struct MapErr { +pub(crate) struct MapErr { handler: H, map: F, _phantom: PhantomData, @@ -577,23 +572,21 @@ impl MapErr { } #[async_trait] -impl Handler for MapErr +impl Handler for MapErr where - H: Handler, + H: Handler, F: 'static + Send + Sync + Fn(Error1) -> Error2, State: 'static + Send + Sync, Error1: 'static + Send + Sync, Error2: 'static, - VER: 'static + Send + Sync + StaticVersionType, { async fn handle( &self, req: RequestParams, state: &State, - bind_version: VER, ) -> Result> { self.handler - .handle(req, state, bind_version) + .handle(req, state) .await .map_err(|err| err.map_app_specific(&self.map)) } diff --git a/src/socket.rs b/src/socket.rs index e36540a4..2a75e174 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -29,7 +29,7 @@ use tide_websockets::{ tungstenite::protocol::frame::{coding::CloseCode, CloseFrame}, Message, WebSocketConnection, }; -use versioned_binary_serialization::{version::StaticVersionType, BinarySerializer, Serializer}; +use vbs::{version::StaticVersionType, BinarySerializer, Serializer}; /// An error returned by a socket handler. /// diff --git a/src/status.rs b/src/status.rs index 78387763..d9fbb73d 100644 --- a/src/status.rs +++ b/src/status.rs @@ -556,7 +556,7 @@ impl StatusCode { #[cfg(test)] mod test { use super::*; - use versioned_binary_serialization::{version::StaticVersion, BinarySerializer, Serializer}; + use vbs::{version::StaticVersion, BinarySerializer, Serializer}; type SerializerV01 = Serializer>; #[test] From 395e64c3d97a98ab3cd382f1cb96ac223d1deede Mon Sep 17 00:00:00 2001 From: Jeb Bearer Date: Fri, 29 Mar 2024 17:00:13 -0400 Subject: [PATCH 2/2] Fix doc tests --- src/lib.rs | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a83b6df5..ecd590d9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,7 +47,7 @@ //! # fn main() -> Result<(), tide_disco::api::ApiError> { //! use tide_disco::Api; //! use tide_disco::error::ServerError; -//! use versioned_binary_serialization::version::StaticVersion; +//! use vbs::version::StaticVersion; //! //! type State = (); //! type Error = ServerError; @@ -74,7 +74,7 @@ //! //! ```no_run //! # use tide_disco::Api; -//! # use versioned_binary_serialization::version::StaticVersion; +//! # use vbs::version::StaticVersion; //! # type StaticVer01 = StaticVersion<0, 1>; //! # fn main() -> Result<(), tide_disco::api::ApiError> { //! # let spec: toml::Value = toml::from_str(std::str::from_utf8(&std::fs::read("/path/to/api.toml").unwrap()).unwrap()).unwrap(); @@ -92,7 +92,7 @@ //! an [App]: //! //! ```no_run -//! # use versioned_binary_serialization::version::StaticVersion; +//! # use vbs::version::StaticVersion; //! # type State = (); //! # type Error = tide_disco::error::ServerError; //! # type StaticVer01 = StaticVersion<0, 1>; @@ -100,14 +100,13 @@ //! # let spec: toml::Value = toml::from_str(std::str::from_utf8(&std::fs::read("/path/to/api.toml").unwrap()).unwrap()).unwrap(); //! # let api = tide_disco::Api::::new(spec).unwrap(); //! use tide_disco::App; -//! use versioned_binary_serialization::version::StaticVersion; +//! use vbs::version::StaticVersion; //! //! type StaticVer01 = StaticVersion<0, 1>; -//! const VER_0_1: StaticVer01 = StaticVersion {}; //! //! let mut app = App::::with_state(()); //! app.register_module("api", api); -//! app.serve("http://localhost:8080", VER_0_1).await; +//! app.serve("http://localhost:8080").await; //! # } //! ``` //! @@ -209,7 +208,7 @@ //! use async_std::sync::RwLock; //! use futures::FutureExt; //! use tide_disco::Api; -//! use versioned_binary_serialization::version::StaticVersion; +//! use vbs::version::StaticVersion; //! //! type State = RwLock; //! type Error = (); @@ -231,7 +230,7 @@ //! use async_std::sync::RwLock; //! use futures::FutureExt; //! use tide_disco::{Api, RequestParams}; -//! use versioned_binary_serialization::version::StaticVersion; +//! use vbs::version::StaticVersion; //! //! type State = RwLock; //! type Error = ();