From 231f9c948bb2aca1e819be823bce616326a7a7a2 Mon Sep 17 00:00:00 2001 From: meskill <8974488+meskill@users.noreply.github.com> Date: Thu, 28 Mar 2024 21:50:27 +0000 Subject: [PATCH] refactor(auth): change the way multiple auth providers are defined --- src/auth/basic.rs | 6 ++ src/auth/context.rs | 20 +++-- src/auth/error.rs | 12 +-- src/auth/verify.rs | 175 +++++++++++++++++++++++++----------- src/blueprint/auth.rs | 194 ++++++++++++++++++++++++++++++++++++---- src/blueprint/server.rs | 2 +- 6 files changed, 319 insertions(+), 90 deletions(-) diff --git a/src/auth/basic.rs b/src/auth/basic.rs index d2d201a08c..8685b72857 100644 --- a/src/auth/basic.rs +++ b/src/auth/basic.rs @@ -49,6 +49,12 @@ testuser2:$2y$10$wJ/mZDURcAOBIrswCAKFsO0Nk7BpHmWl/XuhF7lNm3gBAFH3ofsuu testuser3:{SHA}Y2fEjdGT1W6nsLqtJbGUVeUp9e4= "; + impl blueprint::BasicProvider { + pub fn test_value() -> Self { + Self { htpasswd: HTPASSWD_TEST.to_owned() } + } + } + pub fn create_basic_auth_request(username: &str, password: &str) -> RequestContext { let mut req_context = RequestContext::default(); diff --git a/src/auth/context.rs b/src/auth/context.rs index 571007ecb1..dc9a5597cc 100644 --- a/src/auth/context.rs +++ b/src/auth/context.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use anyhow::Result; @@ -14,8 +14,7 @@ pub struct GlobalAuthContext { #[derive(Default)] pub struct AuthContext { - // TODO: can we do without mutex? - auth_result: Mutex>>, + auth_result: RwLock>>, global_ctx: Arc, } @@ -34,21 +33,20 @@ impl GlobalAuthContext { } impl GlobalAuthContext { - pub fn new(auth: Auth) -> Self { - let verifier = AuthVerifier::new(auth); - Self { verifier } + pub fn new(auth: Option) -> Self { + Self { verifier: auth.map(AuthVerifier::from) } } } impl AuthContext { pub async fn validate(&self, request: &RequestContext) -> Result<(), Error> { - if let Some(result) = self.auth_result.lock().unwrap().as_ref() { + if let Some(result) = self.auth_result.read().unwrap().as_ref() { return result.clone(); } let result = self.global_ctx.validate(request).await; - self.auth_result.lock().unwrap().replace(result.clone()); + self.auth_result.write().unwrap().replace(result.clone()); result } @@ -70,6 +68,7 @@ mod tests { use crate::auth::basic::BasicVerifier; use crate::auth::jwt::jwt_verify::tests::{create_jwt_auth_request, JWT_VALID_TOKEN_WITH_KID}; use crate::auth::jwt::jwt_verify::JwtVerifier; + use crate::auth::verify::Verifier; use crate::blueprint; #[tokio::test] @@ -80,7 +79,10 @@ mod tests { let jwt_provider = JwtVerifier::new(jwt_options); let auth_context = GlobalAuthContext { - verifier: Some(AuthVerifier::Jwt(jwt_provider).or(AuthVerifier::Basic(basic_provider))), + verifier: Some(AuthVerifier::Any(vec![ + AuthVerifier::Single(Verifier::Basic(basic_provider)), + AuthVerifier::Single(Verifier::Jwt(jwt_provider)), + ])), }; let validation = auth_context diff --git a/src/auth/error.rs b/src/auth/error.rs index 28eaabffe3..5d76975b80 100644 --- a/src/auth/error.rs +++ b/src/auth/error.rs @@ -1,4 +1,4 @@ -#[derive(Debug, thiserror::Error, Clone, PartialEq, PartialOrd)] +#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq, PartialOrd, Ord)] pub enum Error { #[error("Haven't found auth parameters")] Missing, @@ -9,13 +9,3 @@ pub enum Error { #[error("Auth validation failed")] Invalid, } - -impl Error { - pub fn max(self, other: Self) -> Self { - if self < other { - other - } else { - self - } - } -} diff --git a/src/auth/verify.rs b/src/auth/verify.rs index f2e1511413..69f6c3d6fa 100644 --- a/src/auth/verify.rs +++ b/src/auth/verify.rs @@ -1,5 +1,7 @@ +use std::cmp::max; + use anyhow::Result; -use async_std::prelude::FutureExt; +use futures_util::future::join_all; use super::basic::BasicVerifier; use super::error::Error; @@ -12,53 +14,51 @@ pub(crate) trait Verify { async fn verify(&self, req_ctx: &RequestContext) -> Result<(), Error>; } -#[allow(clippy::large_enum_variant)] -// The difference in size is indeed significant here -// but it's quite unlikely that someone will require to store several hundreds -// of providers or more to care much -pub enum AuthVerifier { +pub enum Verifier { Basic(BasicVerifier), Jwt(JwtVerifier), - Or(Box, Box), - And(Box, Box), } -impl AuthVerifier { - pub fn new(provider: blueprint::Auth) -> Option { +pub enum AuthVerifier { + Single(Verifier), + Any(Vec), + All(Vec), +} + +impl From for Verifier { + fn from(provider: blueprint::AuthProvider) -> Self { match provider { - blueprint::Auth::Basic(options) => { - Some(AuthVerifier::Basic(BasicVerifier::new(options))) - } - blueprint::Auth::Jwt(options) => Some(AuthVerifier::Jwt(JwtVerifier::new(options))), - blueprint::Auth::And(a, b) => { - let a = Self::new(*a); - let b = Self::new(*b); - - match (a, b) { - (None, None) => None, - (Some(a), None) => Some(a), - (None, Some(b)) => Some(b), - (Some(a), Some(b)) => Some(AuthVerifier::And(Box::new(a), Box::new(b))), - } + blueprint::AuthProvider::Basic(options) => Verifier::Basic(BasicVerifier::new(options)), + blueprint::AuthProvider::Jwt(options) => Verifier::Jwt(JwtVerifier::new(options)), + } + } +} + +impl From for AuthVerifier { + fn from(provider: blueprint::Auth) -> Self { + match provider { + blueprint::Auth::Single(provider) => AuthVerifier::Single(provider.into()), + blueprint::Auth::All(providers) => { + let verifiers = providers.into_iter().map(AuthVerifier::from).collect(); + + AuthVerifier::All(verifiers) } - blueprint::Auth::Or(a, b) => { - let a = Self::new(*a); - let b = Self::new(*b); - - match (a, b) { - (None, None) => None, - (Some(a), None) => Some(a), - (None, Some(b)) => Some(b), - (Some(a), Some(b)) => Some(AuthVerifier::Or(Box::new(a), Box::new(b))), - } + blueprint::Auth::Any(providers) => { + let verifiers = providers.into_iter().map(AuthVerifier::from).collect(); + + AuthVerifier::Any(verifiers) } - blueprint::Auth::Empty => None, } } +} - #[cfg(test)] - pub fn or(self, other: AuthVerifier) -> Self { - AuthVerifier::Or(Box::new(self), Box::new(other)) +#[async_trait::async_trait] +impl Verify for Verifier { + async fn verify(&self, req_ctx: &RequestContext) -> Result<(), Error> { + match self { + Verifier::Basic(basic) => basic.verify(req_ctx).await, + Verifier::Jwt(jwt) => jwt.verify(req_ctx).await, + } } } @@ -66,25 +66,94 @@ impl AuthVerifier { impl Verify for AuthVerifier { async fn verify(&self, req_ctx: &RequestContext) -> Result<(), Error> { match self { - AuthVerifier::Basic(basic) => basic.verify(req_ctx).await, - AuthVerifier::Jwt(jwt) => jwt.verify(req_ctx).await, - AuthVerifier::Or(left, right) => { - let left_result = left.verify(req_ctx).await; - if left_result.is_err() { - right.verify(req_ctx).await - } else { - Ok(()) + AuthVerifier::Single(verifier) => verifier.verify(req_ctx).await, + AuthVerifier::All(verifiers) => { + for verifier in verifiers { + verifier.verify(req_ctx).await? } + + Ok(()) } - AuthVerifier::And(left, right) => { - let (a, b) = left.verify(req_ctx).join(right.verify(req_ctx)).await; - match (a, b) { - (Ok(_), Ok(_)) => Ok(()), - (Ok(_), Err(e)) => Err(e), - (Err(e), Ok(_)) => Err(e), - (Err(e1), Err(e2)) => Err(e1.max(e2)), + AuthVerifier::Any(verifiers) => { + let results = + join_all(verifiers.iter().map(|verifier| verifier.verify(req_ctx))).await; + + let mut error = Error::Missing; + + for result in results { + if let Err(err) = result { + error = max(error, err); + } else { + return Ok(()); + } } + + Err(error) } } } } + +#[cfg(test)] +mod tests { + use super::AuthVerifier; + use crate::auth::basic::tests::create_basic_auth_request; + use crate::auth::error::Error; + use crate::auth::jwt::jwt_verify::tests::{create_jwt_auth_request, JWT_VALID_TOKEN_WITH_KID}; + use crate::auth::verify::Verify; + use crate::blueprint::{Auth, AuthProvider, BasicProvider, JwtProvider}; + + #[tokio::test] + async fn verify_all() { + let verifier = AuthVerifier::from(Auth::All(Vec::default())); + let req_ctx = create_basic_auth_request("testuser1", "wrong-password"); + + assert_eq!(verifier.verify(&req_ctx).await, Ok(())); + + let verifier = AuthVerifier::from(Auth::All(vec![Auth::Single(AuthProvider::Basic( + BasicProvider::test_value(), + ))])); + + assert_eq!(verifier.verify(&req_ctx).await, Err(Error::Invalid)); + + let req_ctx = create_basic_auth_request("testuser1", "password123"); + + assert_eq!(verifier.verify(&req_ctx).await, Ok(())); + + let verifier = AuthVerifier::from(Auth::All(vec![ + Auth::Single(AuthProvider::Basic(BasicProvider::test_value())), + Auth::Single(AuthProvider::Jwt(JwtProvider::test_value())), + ])); + + assert_eq!(verifier.verify(&req_ctx).await, Err(Error::Missing)); + } + + #[tokio::test] + async fn verify_any() { + let verifier = AuthVerifier::from(Auth::Any(Vec::default())); + let req_ctx = create_basic_auth_request("testuser1", "wrong-password"); + + assert_eq!(verifier.verify(&req_ctx).await, Err(Error::Missing)); + + let verifier = AuthVerifier::from(Auth::Any(vec![Auth::Single(AuthProvider::Basic( + BasicProvider::test_value(), + ))])); + + assert_eq!(verifier.verify(&req_ctx).await, Err(Error::Invalid)); + + let req_ctx = create_basic_auth_request("testuser1", "password123"); + + assert_eq!(verifier.verify(&req_ctx).await, Ok(())); + + let verifier = AuthVerifier::from(Auth::Any(vec![ + Auth::Single(AuthProvider::Basic(BasicProvider::test_value())), + Auth::Single(AuthProvider::Jwt(JwtProvider::test_value())), + ])); + + assert_eq!(verifier.verify(&req_ctx).await, Ok(())); + + let req_ctx = create_jwt_auth_request(JWT_VALID_TOKEN_WITH_KID); + + assert_eq!(verifier.verify(&req_ctx).await, Ok(())); + } +} diff --git a/src/blueprint/auth.rs b/src/blueprint/auth.rs index 6e43f4bf83..64244b7b4e 100644 --- a/src/blueprint/auth.rs +++ b/src/blueprint/auth.rs @@ -6,12 +6,12 @@ use jsonwebtoken::jwk::JwkSet; use crate::config::ConfigModule; use crate::valid::Valid; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct BasicProvider { pub htpasswd: String, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct JwtProvider { pub issuer: Option, pub audiences: HashSet, @@ -19,40 +19,202 @@ pub struct JwtProvider { pub jwks: JwkSet, } -#[derive(Clone, Debug, Default)] -pub enum Auth { +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum AuthProvider { Basic(BasicProvider), Jwt(JwtProvider), - And(Box, Box), - Or(Box, Box), - #[default] - Empty, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Auth { + Single(AuthProvider), + All(Vec), + Any(Vec), } impl Auth { - pub fn make(config_module: &ConfigModule) -> Valid { - let mut auth = Auth::default(); + pub fn make(config_module: &ConfigModule) -> Valid, String> { + if !config_module.extensions.has_auth() { + return Valid::succeed(None); + } + + let mut auth = Auth::Any(Vec::new()); for htpasswd in config_module.extensions.htpasswd.iter() { - auth = auth.or(Auth::Basic(BasicProvider { + auth = auth.or(Auth::Single(AuthProvider::Basic(BasicProvider { htpasswd: htpasswd.content.clone(), - })) + }))); } for jwks in config_module.extensions.jwks.iter() { - auth = auth.or(Auth::Jwt(JwtProvider { + auth = auth.or(Auth::Single(AuthProvider::Jwt(JwtProvider { jwks: jwks.content.clone(), // TODO: read those options from link instead of using defaults issuer: Default::default(), audiences: Default::default(), optional_kid: Default::default(), - })) + }))); } - Valid::succeed(auth) + Valid::succeed(Some(auth)) + } + + pub fn and(self, other: Self) -> Self { + let v = match (self, other) { + (Auth::All(mut v1), Auth::All(mut v2)) => { + v1.append(&mut v2); + v1 + } + (Auth::All(mut v), other) | (other, Auth::All(mut v)) => { + v.push(other); + v + } + (this, other) => vec![this, other], + }; + + Auth::All(v) } pub fn or(self, other: Self) -> Self { - Auth::Or(Box::new(self), Box::new(other)) + let v = match (self, other) { + (Auth::Any(mut v1), Auth::Any(mut v2)) => { + v1.append(&mut v2); + v1 + } + (Auth::Any(mut v), other) | (other, Auth::Any(mut v)) => { + v.push(other); + v + } + (this, other) => vec![this, other], + }; + + Auth::Any(v) + } +} + +#[cfg(test)] +mod tests { + use super::{Auth, AuthProvider, BasicProvider, JwtProvider}; + + #[test] + fn test_and() { + let basic_provider_1 = AuthProvider::Basic(BasicProvider::test_value()); + let basic_provider_2 = AuthProvider::Basic(BasicProvider::test_value()); + let jwt_provider = AuthProvider::Jwt(JwtProvider::test_value()); + + assert_eq!( + Auth::Single(basic_provider_1.clone()).and(Auth::Single(basic_provider_2.clone())), + Auth::All(vec![ + Auth::Single(basic_provider_1.clone()), + Auth::Single(basic_provider_2.clone()) + ]) + ); + + assert_eq!( + Auth::All(vec![ + Auth::Single(basic_provider_1.clone()), + Auth::Single(basic_provider_2.clone()) + ]) + .and(Auth::Single(jwt_provider.clone())), + Auth::All(vec![ + Auth::Single(basic_provider_1.clone()), + Auth::Single(basic_provider_2.clone()), + Auth::Single(jwt_provider.clone()) + ]) + ); + + assert_eq!( + Auth::Single(jwt_provider.clone()).and(Auth::All(vec![ + Auth::Single(basic_provider_1.clone()), + Auth::Single(basic_provider_2.clone()) + ])), + Auth::All(vec![ + Auth::Single(basic_provider_1.clone()), + Auth::Single(basic_provider_2.clone()), + Auth::Single(jwt_provider.clone()), + ]) + ); + + assert_eq!( + Auth::Any(vec![ + Auth::Single(jwt_provider.clone()), + Auth::Single(jwt_provider.clone()) + ]) + .and(Auth::Any(vec![ + Auth::Single(basic_provider_1.clone()), + Auth::Single(basic_provider_2.clone()) + ])), + Auth::All(vec![ + Auth::Any(vec![ + Auth::Single(jwt_provider.clone()), + Auth::Single(jwt_provider.clone()) + ]), + Auth::Any(vec![ + Auth::Single(basic_provider_1.clone()), + Auth::Single(basic_provider_2.clone()) + ]) + ]) + ); + } + + #[test] + fn test_or() { + let basic_provider_1 = AuthProvider::Basic(BasicProvider { htpasswd: "1".into() }); + let basic_provider_2 = AuthProvider::Basic(BasicProvider { htpasswd: "2".into() }); + let jwt_provider = AuthProvider::Jwt(JwtProvider::test_value()); + + assert_eq!( + Auth::Single(basic_provider_1.clone()).or(Auth::Single(basic_provider_2.clone())), + Auth::Any(vec![ + Auth::Single(basic_provider_1.clone()), + Auth::Single(basic_provider_2.clone()) + ]) + ); + + assert_eq!( + Auth::Any(vec![ + Auth::Single(basic_provider_1.clone()), + Auth::Single(basic_provider_2.clone()) + ]) + .or(Auth::Single(jwt_provider.clone())), + Auth::Any(vec![ + Auth::Single(basic_provider_1.clone()), + Auth::Single(basic_provider_2.clone()), + Auth::Single(jwt_provider.clone()) + ]) + ); + + assert_eq!( + Auth::Single(jwt_provider.clone()).or(Auth::Any(vec![ + Auth::Single(basic_provider_1.clone()), + Auth::Single(basic_provider_2.clone()) + ])), + Auth::Any(vec![ + Auth::Single(basic_provider_1.clone()), + Auth::Single(basic_provider_2.clone()), + Auth::Single(jwt_provider.clone()), + ]) + ); + + assert_eq!( + Auth::All(vec![ + Auth::Single(jwt_provider.clone()), + Auth::Single(jwt_provider.clone()) + ]) + .or(Auth::All(vec![ + Auth::Single(basic_provider_1.clone()), + Auth::Single(basic_provider_2.clone()) + ])), + Auth::Any(vec![ + Auth::All(vec![ + Auth::Single(jwt_provider.clone()), + Auth::Single(jwt_provider.clone()) + ]), + Auth::All(vec![ + Auth::Single(basic_provider_1.clone()), + Auth::Single(basic_provider_2.clone()) + ]) + ]) + ); } } diff --git a/src/blueprint/server.rs b/src/blueprint/server.rs index 0214f3edd8..b62f8e63eb 100644 --- a/src/blueprint/server.rs +++ b/src/blueprint/server.rs @@ -36,7 +36,7 @@ pub struct Server { pub script: Option