diff --git a/Cargo.toml b/Cargo.toml index 4281e98..ed038d3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mauth-client" -version = "0.4.0" +version = "0.5.0" authors = ["Mason Gup "] edition = "2021" documentation = "https://docs.rs/mauth-client/" @@ -14,6 +14,9 @@ categories = ["authentication", "web-programming"] [dependencies] reqwest = { version = "0.12", features = ["json"] } +reqwest-middleware = "0.4" +reqwest-tracing = { version = "0.5.5", optional = true } +async-trait = ">= 0.1.83" url = "2" serde = { version = "1", features = ["derive"] } serde_json = "1" @@ -25,7 +28,7 @@ tokio = { version = "1", features = ["fs"] } tower = { version = "0.4", optional = true } axum = { version = ">= 0.7.2", optional = true } futures-core = { version = "0.3", optional = true } -http = { version = "1", optional = true } +http = "1" bytes = { version = "1", optional = true } thiserror = "1" mauth-core = "0.5" @@ -34,4 +37,6 @@ mauth-core = "0.5" tokio = { version = "1", features = ["rt-multi-thread", "macros"] } [features] -axum-service = ["tower", "futures-core", "axum", "http", "bytes"] +axum-service = ["tower", "futures-core", "axum", "bytes"] +tracing-otel-26 = ["reqwest-tracing/opentelemetry_0_26"] +tracing-otel-27 = ["reqwest-tracing/opentelemetry_0_27"] diff --git a/README.md b/README.md index f3530ba..a800925 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,5 @@ # mauth-client -## mauth-client - This crate allows users of the Reqwest crate for making HTTP requests to sign those requests with the MAuth protocol, and verify the responses. Usage example: @@ -9,9 +7,10 @@ the MAuth protocol, and verify the responses. Usage example: release any code to Production or deploy in a Client-accessible environment without getting approval for the full stack used through the Architecture and Security groups. -```rust +```no_run use mauth_client::MAuthInfo; use reqwest::Client; +# async fn send_request() { let mauth_info = MAuthInfo::from_default_file().unwrap(); let client = Client::new(); let mut req = client.get("https://www.example.com/").build().unwrap(); @@ -20,9 +19,9 @@ match client.execute(req).await { Err(err) => println!("Got error {}", err), Ok(response) => println!("Got validated response with body {}", response.text().await.unwrap()), } +# } ``` - The above code will read your mauth configuration from a file in `~/.mauth_config.yml` which format is: ```yaml common: &common @@ -32,8 +31,32 @@ common: &common private_key_file: ``` +The `MAuthInfo` struct also functions as a outgoing middleware using the +[`reqwest-middleware`](https://crates.io/crates/reqwest-middleware) crate for a simpler API and easier +integration with other outgoing middleware: + +```no_run +use mauth_client::MAuthInfo; +use reqwest::Client; +use reqwest_middleware::ClientBuilder; +# async fn send_request() { +let mauth_info = MAuthInfo::from_default_file().unwrap(); +let client = ClientBuilder::new(Client::new()).with(mauth_info).build(); +match client.get("https://www.example.com/").send().await { + Err(err) => println!("Got error {}", err), + Ok(response) => println!("Got validated response with body {}", response.text().await.unwrap()), +} +# } +``` + The optional `axum-service` feature provides for a Tower Layer and Service that will authenticate incoming requests via MAuth V2 or V1 and provide to the lower layers a validated app_uuid from the request via the ValidatedRequestDetails struct. -License: MIT +There are also optional features `tracing-otel-26` and `tracing-otel-27` that pair with +the `axum-service` feature to ensure that any outgoing requests for credentials that take +place in the context of an incoming web request also include the proper OpenTelemetry span +information in any requests to MAudit services. Note that it is critical to use the same +version of OpenTelemetry crates as the rest of the project - if you do not, there will be 2 +or more instances of the OpenTelemetry global information, and requests may not be traced +through properly. diff --git a/src/axum_service.rs b/src/axum_service.rs index f5fa971..04cf6ae 100644 --- a/src/axum_service.rs +++ b/src/axum_service.rs @@ -2,13 +2,9 @@ use axum::extract::Request; use futures_core::future::BoxFuture; -use mauth_core::verifier::Verifier; -use std::collections::HashMap; use std::error::Error; -use std::sync::{Arc, RwLock}; use std::task::{Context, Poll}; use tower::{Layer, Service}; -use uuid::Uuid; use crate::{ config::{ConfigFileSection, ConfigReadError}, @@ -56,11 +52,7 @@ impl Clone for MAuthValidationService { fn clone(&self) -> Self { MAuthValidationService { // unwrap is safe because we validated the config_info before constructing the layer - mauth_info: MAuthInfo::from_config_section( - &self.config_info, - Some(self.mauth_info.remote_key_store.clone()), - ) - .unwrap(), + mauth_info: MAuthInfo::from_config_section(&self.config_info).unwrap(), config_info: self.config_info.clone(), service: self.service.clone(), } @@ -72,7 +64,6 @@ impl Clone for MAuthValidationService { #[derive(Clone)] pub struct MAuthValidationLayer { config_info: ConfigFileSection, - remote_key_store: Arc>>, } impl Layer for MAuthValidationLayer { @@ -81,11 +72,7 @@ impl Layer for MAuthValidationLayer { fn layer(&self, service: S) -> Self::Service { MAuthValidationService { // unwrap is safe because we validated the config_info before constructing the layer - mauth_info: MAuthInfo::from_config_section( - &self.config_info, - Some(self.remote_key_store.clone()), - ) - .unwrap(), + mauth_info: MAuthInfo::from_config_section(&self.config_info).unwrap(), config_info: self.config_info.clone(), service, } @@ -97,24 +84,16 @@ impl MAuthValidationLayer { /// found in the default location. pub fn from_default_file() -> Result { let config_info = MAuthInfo::config_section_from_default_file()?; - let remote_key_store = Arc::new(RwLock::new(HashMap::new())); // Generate a MAuthInfo and then drop it to validate that it works, // making it safe to use `unwrap` in the service constructor. - MAuthInfo::from_config_section(&config_info, Some(remote_key_store.clone()))?; - Ok(MAuthValidationLayer { - config_info, - remote_key_store, - }) + MAuthInfo::from_config_section(&config_info)?; + Ok(MAuthValidationLayer { config_info }) } /// Construct a MAuthValidationLayer based on the configuration options in a manually /// created or parsed ConfigFileSection. pub fn from_config_section(config_info: ConfigFileSection) -> Result { - let remote_key_store = Arc::new(RwLock::new(HashMap::new())); - MAuthInfo::from_config_section(&config_info, Some(remote_key_store.clone()))?; - Ok(MAuthValidationLayer { - config_info, - remote_key_store, - }) + MAuthInfo::from_config_section(&config_info)?; + Ok(MAuthValidationLayer { config_info }) } } diff --git a/src/config.rs b/src/config.rs index 29c2cf4..4b47830 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,10 +1,10 @@ -use crate::MAuthInfo; -use mauth_core::{signer::Signer, verifier::Verifier}; +use crate::{MAuthInfo, CLIENT}; +use mauth_core::signer::Signer; +use reqwest::Client; use reqwest::Url; +use reqwest_middleware::ClientBuilder; use serde::Deserialize; -use std::collections::HashMap; use std::io; -use std::sync::{Arc, RwLock}; use thiserror::Error; use uuid::Uuid; @@ -15,7 +15,7 @@ impl MAuthInfo { /// present in the current user's home directory. Returns an enum error type that includes the /// error types of all crates used. pub fn from_default_file() -> Result { - Self::from_config_section(&Self::config_section_from_default_file()?, None) + Self::from_config_section(&Self::config_section_from_default_file()?) } pub(crate) fn config_section_from_default_file() -> Result { @@ -35,10 +35,7 @@ impl MAuthInfo { /// Construct the MAuthInfo struct based on a passed-in ConfigFileSection instance. The /// optional input_keystore is present to support internal cloning and need not be provided /// if being used outside of the crate. - pub fn from_config_section( - section: &ConfigFileSection, - input_keystore: Option>>>, - ) -> Result { + pub fn from_config_section(section: &ConfigFileSection) -> Result { let full_uri: Url = format!( "{}/mauth/{}/security_tokens/", §ion.mauth_baseurl, §ion.mauth_api_version @@ -55,15 +52,22 @@ impl MAuthInfo { return Err(ConfigReadError::NoPrivateKey); } - Ok(MAuthInfo { + let mauth_info = MAuthInfo { app_id: Uuid::parse_str(§ion.app_uuid)?, mauth_uri_base: full_uri, - remote_key_store: input_keystore - .unwrap_or_else(|| Arc::new(RwLock::new(HashMap::new()))), sign_with_v1_also: !section.v2_only_sign_requests.unwrap_or(false), allow_v1_auth: !section.v2_only_authenticate.unwrap_or(false), signer: Signer::new(section.app_uuid.clone(), pk_data.unwrap())?, - }) + }; + + CLIENT.get_or_init(|| { + let builder = ClientBuilder::new(Client::new()).with(mauth_info.clone()); + #[cfg(any(feature = "tracing-otel-26", feature = "tracing-otel-27"))] + let builder = builder.with(reqwest_tracing::TracingMiddleware::default()); + builder.build() + }); + + Ok(mauth_info) } } @@ -145,7 +149,7 @@ mod test { v2_only_sign_requests: None, v2_only_authenticate: None, }; - let load_result = MAuthInfo::from_config_section(&bad_config, None); + let load_result = MAuthInfo::from_config_section(&bad_config); assert!(matches!(load_result, Err(ConfigReadError::InvalidUri(_)))); } @@ -160,7 +164,7 @@ mod test { v2_only_sign_requests: None, v2_only_authenticate: None, }; - let load_result = MAuthInfo::from_config_section(&bad_config, None); + let load_result = MAuthInfo::from_config_section(&bad_config); assert!(matches!( load_result, Err(ConfigReadError::FileReadError(_)) @@ -180,7 +184,7 @@ mod test { v2_only_sign_requests: None, v2_only_authenticate: None, }; - let load_result = MAuthInfo::from_config_section(&bad_config, None); + let load_result = MAuthInfo::from_config_section(&bad_config); fs::remove_file(&filename).await.unwrap(); assert!(matches!( load_result, @@ -201,7 +205,7 @@ mod test { v2_only_sign_requests: None, v2_only_authenticate: None, }; - let load_result = MAuthInfo::from_config_section(&bad_config, None); + let load_result = MAuthInfo::from_config_section(&bad_config); fs::remove_file(&filename).await.unwrap(); assert!(matches!( load_result, diff --git a/src/lib.rs b/src/lib.rs index c30e5bc..7fab385 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,47 +1,12 @@ #![forbid(unsafe_code)] -//! # mauth-client -//! -//! This crate allows users of the Reqwest crate for making HTTP requests to sign those requests with -//! the MAuth protocol, and verify the responses. Usage example: -//! -//! **Note**: This crate and Rust support within Medidata is considered experimental. Do not -//! release any code to Production or deploy in a Client-accessible environment without getting -//! approval for the full stack used through the Architecture and Security groups. -//! -//! ```no_run -//! use mauth_client::MAuthInfo; -//! use reqwest::Client; -//! # async fn make_signed_request() { -//! let mauth_info = MAuthInfo::from_default_file().unwrap(); -//! let client = Client::new(); -//! let mut req = client.get("https://www.example.com/").build().unwrap(); -//! mauth_info.sign_request(&mut req); -//! match client.execute(req).await { -//! Err(err) => println!("Got error {}", err), -//! Ok(response) => println!("Got validated response with body {}", response.text().await.unwrap()), -//! } -//! # } -//! ``` -//! -//! -//! The above code will read your mauth configuration from a file in `~/.mauth_config.yml` which format is: -//! ```yaml -//! common: &common -//! mauth_baseurl: https:// -//! mauth_api_version: v1 -//! app_uuid: -//! private_key_file: -//! ``` -//! -//! The optional `axum-service` feature provides for a Tower Layer and Service that will -//! authenticate incoming requests via MAuth V2 or V1 and provide to the lower layers a -//! validated app_uuid from the request via the ValidatedRequestDetails struct. +#![doc = include_str!("../README.md")] +use ::reqwest_middleware::ClientWithMiddleware; use mauth_core::signer::Signer; use mauth_core::verifier::Verifier; use reqwest::Url; use std::collections::HashMap; -use std::sync::{Arc, RwLock}; +use std::sync::{LazyLock, OnceLock, RwLock}; use uuid::Uuid; /// This is the primary struct of this class. It contains all of the information @@ -49,16 +14,20 @@ use uuid::Uuid; /// /// Note that it contains a cache of response keys for verifying response signatures. This cache /// makes the struct non-Sync. -#[allow(dead_code)] +#[derive(Clone)] pub struct MAuthInfo { app_id: Uuid, sign_with_v1_also: bool, signer: Signer, - remote_key_store: Arc>>, mauth_uri_base: Url, allow_v1_auth: bool, } +static CLIENT: OnceLock = OnceLock::new(); + +static PUBKEY_CACHE: LazyLock>> = + LazyLock::new(|| RwLock::new(HashMap::new())); + /// Tower Service and Layer to allow Tower-integrated servers to validate incoming request #[cfg(feature = "axum-service")] pub mod axum_service; @@ -66,6 +35,7 @@ pub mod axum_service; pub mod config; #[cfg(test)] mod protocol_test_suite; +mod reqwest_middleware; /// Implementation of code to sign outgoing requests pub mod sign_outgoing; /// Implementation of code to validate incoming requests diff --git a/src/protocol_test_suite.rs b/src/protocol_test_suite.rs index 751ac07..31fa05d 100644 --- a/src/protocol_test_suite.rs +++ b/src/protocol_test_suite.rs @@ -31,7 +31,7 @@ async fn setup_mauth_info() -> (MAuthInfo, u64) { v2_only_authenticate: None, }; ( - MAuthInfo::from_config_section(&mock_config_section, None).unwrap(), + MAuthInfo::from_config_section(&mock_config_section).unwrap(), sign_config.request_time, ) } diff --git a/src/reqwest_middleware.rs b/src/reqwest_middleware.rs new file mode 100644 index 0000000..71f6c0d --- /dev/null +++ b/src/reqwest_middleware.rs @@ -0,0 +1,25 @@ +use http::Extensions; +use reqwest::{Request, Response}; +use reqwest_middleware::{Middleware, Next, Result}; + +use crate::{sign_outgoing::SigningError, MAuthInfo}; + +#[async_trait::async_trait] +impl Middleware for MAuthInfo { + #[must_use] + async fn handle( + &self, + mut req: Request, + extensions: &mut Extensions, + next: Next<'_>, + ) -> Result { + self.sign_request(&mut req)?; + next.run(req, extensions).await + } +} + +impl From for reqwest_middleware::Error { + fn from(value: SigningError) -> Self { + reqwest_middleware::Error::Middleware(value.into()) + } +} diff --git a/src/validate_incoming.rs b/src/validate_incoming.rs index 9061974..ac02f5d 100644 --- a/src/validate_incoming.rs +++ b/src/validate_incoming.rs @@ -1,7 +1,6 @@ -use crate::MAuthInfo; +use crate::{MAuthInfo, CLIENT, PUBKEY_CACHE}; use chrono::prelude::*; use mauth_core::verifier::Verifier; -use reqwest::{Client, Method, Request}; use thiserror::Error; use uuid::Uuid; @@ -184,17 +183,13 @@ impl MAuthInfo { async fn get_app_pub_key(&self, app_uuid: &Uuid) -> Option { { - let key_store = self.remote_key_store.read().unwrap(); + let key_store = PUBKEY_CACHE.read().unwrap(); if let Some(pub_key) = key_store.get(app_uuid) { return Some(pub_key.clone()); } } - let client = Client::new(); let uri = self.mauth_uri_base.join(&format!("{}", &app_uuid)).unwrap(); - let mut req = Request::new(Method::GET, uri); - // This can only error with invalid UTF8 format, which is impossible here - self.sign_request_v2(&mut req).unwrap(); - let mauth_response = client.execute(req).await; + let mauth_response = CLIENT.get().unwrap().get(uri).send().await; match mauth_response { Err(_) => None, Ok(response) => { @@ -205,7 +200,7 @@ impl MAuthInfo { .map(|st| st.to_owned()) { if let Ok(verifier) = Verifier::new(*app_uuid, pub_key_str) { - let mut key_store = self.remote_key_store.write().unwrap(); + let mut key_store = PUBKEY_CACHE.write().unwrap(); key_store.insert(*app_uuid, verifier.clone()); Some(verifier) } else {