Skip to content

Commit

Permalink
refactor(auth): change the way multiple auth providers are defined
Browse files Browse the repository at this point in the history
  • Loading branch information
meskill committed Mar 28, 2024
1 parent 8f1d569 commit 231f9c9
Show file tree
Hide file tree
Showing 6 changed files with 319 additions and 90 deletions.
6 changes: 6 additions & 0 deletions src/auth/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ testuser2:$2y$10$wJ/mZDURcAOBIrswCAKFsO0Nk7BpHmWl/XuhF7lNm3gBAFH3ofsuu
testuser3:{SHA}Y2fEjdGT1W6nsLqtJbGUVeUp9e4=
";

impl blueprint::BasicProvider {
pub fn test_value() -> Self {
Self { htpasswd: HTPASSWD_TEST.to_owned() }
}
}

pub fn create_basic_auth_request(username: &str, password: &str) -> RequestContext {
let mut req_context = RequestContext::default();

Expand Down
20 changes: 11 additions & 9 deletions src/auth/context.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::{Arc, Mutex};
use std::sync::{Arc, RwLock};

use anyhow::Result;

Expand All @@ -14,8 +14,7 @@ pub struct GlobalAuthContext {

#[derive(Default)]
pub struct AuthContext {
// TODO: can we do without mutex?
auth_result: Mutex<Option<Result<(), Error>>>,
auth_result: RwLock<Option<Result<(), Error>>>,
global_ctx: Arc<GlobalAuthContext>,
}

Expand All @@ -34,21 +33,20 @@ impl GlobalAuthContext {
}

impl GlobalAuthContext {
pub fn new(auth: Auth) -> Self {
let verifier = AuthVerifier::new(auth);
Self { verifier }
pub fn new(auth: Option<Auth>) -> Self {
Self { verifier: auth.map(AuthVerifier::from) }
}
}

impl AuthContext {
pub async fn validate(&self, request: &RequestContext) -> Result<(), Error> {
if let Some(result) = self.auth_result.lock().unwrap().as_ref() {
if let Some(result) = self.auth_result.read().unwrap().as_ref() {
return result.clone();
}

let result = self.global_ctx.validate(request).await;

self.auth_result.lock().unwrap().replace(result.clone());
self.auth_result.write().unwrap().replace(result.clone());

result
}
Expand All @@ -70,6 +68,7 @@ mod tests {
use crate::auth::basic::BasicVerifier;
use crate::auth::jwt::jwt_verify::tests::{create_jwt_auth_request, JWT_VALID_TOKEN_WITH_KID};
use crate::auth::jwt::jwt_verify::JwtVerifier;
use crate::auth::verify::Verifier;
use crate::blueprint;

#[tokio::test]
Expand All @@ -80,7 +79,10 @@ mod tests {
let jwt_provider = JwtVerifier::new(jwt_options);

let auth_context = GlobalAuthContext {
verifier: Some(AuthVerifier::Jwt(jwt_provider).or(AuthVerifier::Basic(basic_provider))),
verifier: Some(AuthVerifier::Any(vec![
AuthVerifier::Single(Verifier::Basic(basic_provider)),
AuthVerifier::Single(Verifier::Jwt(jwt_provider)),
])),
};

let validation = auth_context
Expand Down
12 changes: 1 addition & 11 deletions src/auth/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#[derive(Debug, thiserror::Error, Clone, PartialEq, PartialOrd)]
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum Error {
#[error("Haven't found auth parameters")]
Missing,
Expand All @@ -9,13 +9,3 @@ pub enum Error {
#[error("Auth validation failed")]
Invalid,
}

impl Error {
pub fn max(self, other: Self) -> Self {
if self < other {
other
} else {
self
}
}
}
175 changes: 122 additions & 53 deletions src/auth/verify.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::cmp::max;

use anyhow::Result;
use async_std::prelude::FutureExt;
use futures_util::future::join_all;

use super::basic::BasicVerifier;
use super::error::Error;
Expand All @@ -12,79 +14,146 @@ pub(crate) trait Verify {
async fn verify(&self, req_ctx: &RequestContext) -> Result<(), Error>;
}

#[allow(clippy::large_enum_variant)]
// The difference in size is indeed significant here
// but it's quite unlikely that someone will require to store several hundreds
// of providers or more to care much
pub enum AuthVerifier {
pub enum Verifier {
Basic(BasicVerifier),
Jwt(JwtVerifier),
Or(Box<AuthVerifier>, Box<AuthVerifier>),
And(Box<AuthVerifier>, Box<AuthVerifier>),
}

impl AuthVerifier {
pub fn new(provider: blueprint::Auth) -> Option<Self> {
pub enum AuthVerifier {
Single(Verifier),
Any(Vec<AuthVerifier>),
All(Vec<AuthVerifier>),
}

impl From<blueprint::AuthProvider> for Verifier {
fn from(provider: blueprint::AuthProvider) -> Self {
match provider {
blueprint::Auth::Basic(options) => {
Some(AuthVerifier::Basic(BasicVerifier::new(options)))
}
blueprint::Auth::Jwt(options) => Some(AuthVerifier::Jwt(JwtVerifier::new(options))),
blueprint::Auth::And(a, b) => {
let a = Self::new(*a);
let b = Self::new(*b);

match (a, b) {
(None, None) => None,
(Some(a), None) => Some(a),
(None, Some(b)) => Some(b),
(Some(a), Some(b)) => Some(AuthVerifier::And(Box::new(a), Box::new(b))),
}
blueprint::AuthProvider::Basic(options) => Verifier::Basic(BasicVerifier::new(options)),
blueprint::AuthProvider::Jwt(options) => Verifier::Jwt(JwtVerifier::new(options)),
}
}
}

impl From<blueprint::Auth> for AuthVerifier {
fn from(provider: blueprint::Auth) -> Self {
match provider {
blueprint::Auth::Single(provider) => AuthVerifier::Single(provider.into()),
blueprint::Auth::All(providers) => {
let verifiers = providers.into_iter().map(AuthVerifier::from).collect();

AuthVerifier::All(verifiers)
}
blueprint::Auth::Or(a, b) => {
let a = Self::new(*a);
let b = Self::new(*b);

match (a, b) {
(None, None) => None,
(Some(a), None) => Some(a),
(None, Some(b)) => Some(b),
(Some(a), Some(b)) => Some(AuthVerifier::Or(Box::new(a), Box::new(b))),
}
blueprint::Auth::Any(providers) => {
let verifiers = providers.into_iter().map(AuthVerifier::from).collect();

AuthVerifier::Any(verifiers)
}
blueprint::Auth::Empty => None,
}
}
}

#[cfg(test)]
pub fn or(self, other: AuthVerifier) -> Self {
AuthVerifier::Or(Box::new(self), Box::new(other))
#[async_trait::async_trait]
impl Verify for Verifier {
async fn verify(&self, req_ctx: &RequestContext) -> Result<(), Error> {
match self {
Verifier::Basic(basic) => basic.verify(req_ctx).await,
Verifier::Jwt(jwt) => jwt.verify(req_ctx).await,
}
}
}

#[async_trait::async_trait]
impl Verify for AuthVerifier {
async fn verify(&self, req_ctx: &RequestContext) -> Result<(), Error> {
match self {
AuthVerifier::Basic(basic) => basic.verify(req_ctx).await,
AuthVerifier::Jwt(jwt) => jwt.verify(req_ctx).await,
AuthVerifier::Or(left, right) => {
let left_result = left.verify(req_ctx).await;
if left_result.is_err() {
right.verify(req_ctx).await
} else {
Ok(())
AuthVerifier::Single(verifier) => verifier.verify(req_ctx).await,
AuthVerifier::All(verifiers) => {
for verifier in verifiers {
verifier.verify(req_ctx).await?
}

Ok(())
}
AuthVerifier::And(left, right) => {
let (a, b) = left.verify(req_ctx).join(right.verify(req_ctx)).await;
match (a, b) {
(Ok(_), Ok(_)) => Ok(()),
(Ok(_), Err(e)) => Err(e),
(Err(e), Ok(_)) => Err(e),
(Err(e1), Err(e2)) => Err(e1.max(e2)),
AuthVerifier::Any(verifiers) => {
let results =
join_all(verifiers.iter().map(|verifier| verifier.verify(req_ctx))).await;

let mut error = Error::Missing;

for result in results {
if let Err(err) = result {
error = max(error, err);
} else {
return Ok(());
}
}

Err(error)
}
}
}
}

#[cfg(test)]
mod tests {
use super::AuthVerifier;
use crate::auth::basic::tests::create_basic_auth_request;
use crate::auth::error::Error;
use crate::auth::jwt::jwt_verify::tests::{create_jwt_auth_request, JWT_VALID_TOKEN_WITH_KID};
use crate::auth::verify::Verify;
use crate::blueprint::{Auth, AuthProvider, BasicProvider, JwtProvider};

#[tokio::test]
async fn verify_all() {
let verifier = AuthVerifier::from(Auth::All(Vec::default()));
let req_ctx = create_basic_auth_request("testuser1", "wrong-password");

assert_eq!(verifier.verify(&req_ctx).await, Ok(()));

let verifier = AuthVerifier::from(Auth::All(vec![Auth::Single(AuthProvider::Basic(
BasicProvider::test_value(),
))]));

assert_eq!(verifier.verify(&req_ctx).await, Err(Error::Invalid));

let req_ctx = create_basic_auth_request("testuser1", "password123");

assert_eq!(verifier.verify(&req_ctx).await, Ok(()));

let verifier = AuthVerifier::from(Auth::All(vec![
Auth::Single(AuthProvider::Basic(BasicProvider::test_value())),
Auth::Single(AuthProvider::Jwt(JwtProvider::test_value())),
]));

assert_eq!(verifier.verify(&req_ctx).await, Err(Error::Missing));
}

#[tokio::test]
async fn verify_any() {
let verifier = AuthVerifier::from(Auth::Any(Vec::default()));
let req_ctx = create_basic_auth_request("testuser1", "wrong-password");

assert_eq!(verifier.verify(&req_ctx).await, Err(Error::Missing));

let verifier = AuthVerifier::from(Auth::Any(vec![Auth::Single(AuthProvider::Basic(
BasicProvider::test_value(),
))]));

assert_eq!(verifier.verify(&req_ctx).await, Err(Error::Invalid));

let req_ctx = create_basic_auth_request("testuser1", "password123");

assert_eq!(verifier.verify(&req_ctx).await, Ok(()));

let verifier = AuthVerifier::from(Auth::Any(vec![
Auth::Single(AuthProvider::Basic(BasicProvider::test_value())),
Auth::Single(AuthProvider::Jwt(JwtProvider::test_value())),
]));

assert_eq!(verifier.verify(&req_ctx).await, Ok(()));

let req_ctx = create_jwt_auth_request(JWT_VALID_TOKEN_WITH_KID);

assert_eq!(verifier.verify(&req_ctx).await, Ok(()));
}
}
Loading

0 comments on commit 231f9c9

Please sign in to comment.