Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jw/changing how jwt is passed around #1130

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
49f044f
wip - passing userauthcontext instead of http headers
Feb 29, 2024
5fece60
merged main
Mar 1, 2024
946c9ea
revert comment out
shopifyski Mar 1, 2024
ecf2389
fixed temp change
shopifyski Mar 1, 2024
9521d72
Merge branch 'main' into jw/changing-how-jwt-is-passed-around
shopifyski Mar 1, 2024
83e91ff
fixed failing tests
shopifyski Mar 1, 2024
b8bb476
next iteration of unit test fixing
shopifyski Mar 4, 2024
e322b1a
fixed remaining unit tests
shopifyski Mar 4, 2024
6b517ea
reverted debug
shopifyski Mar 4, 2024
e3bd67c
removed accidentally added file
shopifyski Mar 4, 2024
8f6b127
Merge branch 'main' into jw/changing-how-jwt-is-passed-around
shopifyski Mar 4, 2024
f038f39
moved str to userauthcontext conversion to from trait
shopifyski Mar 5, 2024
9bc376b
remove vague comment
shopifyski Mar 5, 2024
37c130f
cargo fmt
shopifyski Mar 5, 2024
b350ce5
cleaned up mod reimportss
shopifyski Mar 5, 2024
d3ccabc
refactored cryptic matching in replica_proxy
shopifyski Mar 6, 2024
9b353cd
marked potentially duplicate code with // todo dupe #auth
shopifyski Mar 6, 2024
e30d5a3
refactored context to custome errors
shopifyski Mar 6, 2024
cadf427
added a factory to produce empty UserAuthContext
shopifyski Mar 6, 2024
5077466
added constructors for UserAuthContext
shopifyski Mar 6, 2024
071dcf3
switched from try_into to using constructors
shopifyski Mar 6, 2024
9f94e4e
cargo fmt
shopifyski Mar 6, 2024
284b36f
added tests for failing cases in parsers
shopifyski Mar 7, 2024
f716a28
cargo fmt
shopifyski Mar 8, 2024
843c9f7
added test for non-asci error
shopifyski Mar 8, 2024
4125383
fixed log message
shopifyski Mar 13, 2024
37dc162
merged main
shopifyski Mar 13, 2024
3105025
removed unnecessary error mapping
shopifyski Mar 14, 2024
bccfe77
turned context to result
shopifyski Mar 14, 2024
37dc3af
removing dummy tokens from tests
shopifyski Mar 14, 2024
57861f6
cargo fmt + cleanup
shopifyski Mar 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions libsql-server/src/auth/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ pub enum AuthError {
JwtExpired,
#[error("The JWT is immature (not valid yet)")]
JwtImmature,
#[error("Auth string does not conform to '<scheme> <token>' form")]
AuthStringMalformed,
#[error("Expected authorization header but none given")]
AuthHeaderNotFound,
#[error("Non-ASCII auth header")]
AuthHeaderNonAscii,
#[error("Authentication failed")]
Other,
}
Expand All @@ -39,6 +45,9 @@ impl AuthError {
Self::JwtInvalid => "AUTH_JWT_INVALID",
Self::JwtExpired => "AUTH_JWT_EXPIRED",
Self::JwtImmature => "AUTH_JWT_IMMATURE",
Self::AuthStringMalformed => "AUTH_HEADER_MALFORMED",
Self::AuthHeaderNotFound => "AUTH_HEADER_NOT_FOUND",
Self::AuthHeaderNonAscii => "AUTH_HEADER_MALFORMED",
Self::Other => "AUTH_FAILED",
}
}
Expand Down
5 changes: 4 additions & 1 deletion libsql-server/src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ impl Auth {
}
}

pub fn authenticate(&self, context: UserAuthContext) -> Result<Authenticated, AuthError> {
pub fn authenticate(
&self,
context: Result<UserAuthContext, AuthError>,
) -> Result<Authenticated, AuthError> {
self.user_strategy.authenticate(context)
}
}
49 changes: 45 additions & 4 deletions libsql-server/src/auth/parsers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use anyhow::{bail, Context as _, Result};
use axum::http::HeaderValue;
use tonic::metadata::MetadataMap;

use super::UserAuthContext;

pub fn parse_http_basic_auth_arg(arg: &str) -> Result<Option<String>> {
if arg == "always" {
return Ok(Some("always".to_string()));
Expand Down Expand Up @@ -34,11 +36,12 @@ pub fn parse_jwt_key(data: &str) -> Result<jsonwebtoken::DecodingKey> {
}
}

pub(crate) fn parse_grpc_auth_header(metadata: &MetadataMap) -> Option<HeaderValue> {
pub(crate) fn parse_grpc_auth_header(metadata: &MetadataMap) -> Result<UserAuthContext, AuthError> {
metadata
.get(GRPC_AUTH_HEADER)
.map(|v| v.to_bytes().expect("Auth should always be ASCII"))
.map(|v| HeaderValue::from_maybe_shared(v).expect("Should already be valid header"))
.ok_or(AuthError::AuthHeaderNotFound)
.and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii))
.and_then(|t| UserAuthContext::from_auth_str(t))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
.and_then(|t| UserAuthContext::from_auth_str(t))
.and_then(UserAuthContext::from_auth_str)

btw if the input type is the same as the closure this works as well!

}

pub fn parse_http_auth_header<'a>(
Expand Down Expand Up @@ -71,7 +74,45 @@ mod tests {

use crate::auth::{parse_http_auth_header, AuthError};

use super::parse_http_basic_auth_arg;
use super::{parse_grpc_auth_header, parse_http_basic_auth_arg};

#[test]
fn parse_grpc_auth_header_returns_valid_context() {
let mut map = tonic::metadata::MetadataMap::new();
map.insert("x-authorization", "bearer 123".parse().unwrap());
let context = parse_grpc_auth_header(&map).unwrap();
assert_eq!(context.scheme().as_ref().unwrap(), "bearer");
assert_eq!(context.token().as_ref().unwrap(), "123");
}

#[test]
fn parse_grpc_auth_header_error_no_header() {
let map = tonic::metadata::MetadataMap::new();
let result = parse_grpc_auth_header(&map);
assert_eq!(
result.unwrap_err().to_string(),
"Expected authorization header but none given"
);
}

#[test]
fn parse_grpc_auth_header_error_non_ascii() {
let mut map = tonic::metadata::MetadataMap::new();
map.insert("x-authorization", "bearer I❤NY".parse().unwrap());
let result = parse_grpc_auth_header(&map);
assert_eq!(result.unwrap_err().to_string(), "Non-ASCII auth header")
}

#[test]
fn parse_grpc_auth_header_error_malformed_auth_str() {
let mut map = tonic::metadata::MetadataMap::new();
map.insert("x-authorization", "bearer123".parse().unwrap());
let result = parse_grpc_auth_header(&map);
assert_eq!(
result.unwrap_err().to_string(),
"Auth string does not conform to '<scheme> <token>' form"
)
}

#[test]
fn parse_http_auth_header_returns_auth_header_param_when_valid() {
Expand Down
9 changes: 5 additions & 4 deletions libsql-server/src/auth/user_auth_strategies/disabled.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use crate::auth::{AuthError, Authenticated};
pub struct Disabled {}

impl UserAuthStrategy for Disabled {
fn authenticate(&self, _context: UserAuthContext) -> Result<Authenticated, AuthError> {
fn authenticate(
&self,
_context: Result<UserAuthContext, AuthError>,
) -> Result<Authenticated, AuthError> {
tracing::trace!("executing disabled auth");
Ok(Authenticated::FullAccess)
}
Expand All @@ -23,9 +26,7 @@ mod tests {
#[test]
fn authenticates() {
let strategy = Disabled::new();
let context = UserAuthContext {
user_credential: None,
};
let context = Ok(UserAuthContext::empty());

assert!(matches!(
strategy.authenticate(context).unwrap(),
Expand Down
32 changes: 14 additions & 18 deletions libsql-server/src/auth/user_auth_strategies/http_basic.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::auth::{parse_http_auth_header, AuthError, Authenticated};
use crate::auth::{AuthError, Authenticated};

use super::{UserAuthContext, UserAuthStrategy};

Expand All @@ -7,17 +7,22 @@ pub struct HttpBasic {
}

impl UserAuthStrategy for HttpBasic {
fn authenticate(&self, context: UserAuthContext) -> Result<Authenticated, AuthError> {
fn authenticate(
&self,
context: Result<UserAuthContext, AuthError>,
) -> Result<Authenticated, AuthError> {
tracing::trace!("executing http basic auth");

let param = parse_http_auth_header("basic", &context.user_credential)?;

// NOTE: this naive comparison may leak information about the `expected_value`
// using a timing attack
let actual_value = param.trim_end_matches('=');
let expected_value = self.credential.trim_end_matches('=');

if actual_value == expected_value {
let creds_match = match context?.token {
Some(s) => s.contains(expected_value),
None => expected_value.is_empty(),
};

if creds_match {
return Ok(Authenticated::FullAccess);
}

Expand All @@ -33,8 +38,6 @@ impl HttpBasic {

#[cfg(test)]
mod tests {
use axum::http::HeaderValue;

use super::*;

const CREDENTIAL: &str = "d29qdGVrOnRoZWJlYXI=";
Expand All @@ -45,9 +48,7 @@ mod tests {

#[test]
fn authenticates_with_valid_credential() {
let context = UserAuthContext {
user_credential: HeaderValue::from_str(&format!("Basic {CREDENTIAL}")).ok(),
};
let context = Ok(UserAuthContext::basic(CREDENTIAL));

assert!(matches!(
strategy().authenticate(context).unwrap(),
Expand All @@ -58,10 +59,7 @@ mod tests {
#[test]
fn authenticates_with_valid_trimmed_credential() {
let credential = CREDENTIAL.trim_end_matches('=');

let context = UserAuthContext {
user_credential: HeaderValue::from_str(&format!("Basic {credential}")).ok(),
};
let context = Ok(UserAuthContext::basic(credential));

assert!(matches!(
strategy().authenticate(context).unwrap(),
Expand All @@ -71,9 +69,7 @@ mod tests {

#[test]
fn errors_when_credentials_do_not_match() {
let context = UserAuthContext {
user_credential: HeaderValue::from_str("Basic abc").ok(),
};
let context = Ok(UserAuthContext::basic("abc"));

assert_eq!(
strategy().authenticate(context).unwrap_err(),
Expand Down
49 changes: 26 additions & 23 deletions libsql-server/src/auth/user_auth_strategies/jwt.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use chrono::{DateTime, Utc};

use crate::{
auth::{
authenticated::LegacyAuth, parse_http_auth_header, AuthError, Authenticated, Authorized,
Permission,
},
auth::{authenticated::LegacyAuth, AuthError, Authenticated, Authorized, Permission},
namespace::NamespaceName,
};

Expand All @@ -15,10 +12,27 @@ pub struct Jwt {
}

impl UserAuthStrategy for Jwt {
fn authenticate(&self, context: UserAuthContext) -> Result<Authenticated, AuthError> {
fn authenticate(
&self,
context: Result<UserAuthContext, AuthError>,
) -> Result<Authenticated, AuthError> {
tracing::trace!("executing jwt auth");
let param = parse_http_auth_header("bearer", &context.user_credential)?;
validate_jwt(&self.key, param)

let ctx = context?;

let UserAuthContext {
scheme: Some(scheme),
token: Some(token),
} = ctx
else {
return Err(AuthError::HttpAuthHeaderInvalid);
};

if !scheme.eq_ignore_ascii_case("bearer") {
return Err(AuthError::HttpAuthHeaderUnsupportedScheme);
}

return validate_jwt(&self.key, &token);
}
}

Expand Down Expand Up @@ -104,7 +118,6 @@ fn validate_jwt(
mod tests {
use std::time::Duration;

use axum::http::HeaderValue;
use jsonwebtoken::{DecodingKey, EncodingKey};
use ring::signature::{Ed25519KeyPair, KeyPair};
use serde::Serialize;
Expand Down Expand Up @@ -142,9 +155,7 @@ mod tests {
};
let token = encode(&token, &enc);

let context = UserAuthContext {
user_credential: HeaderValue::from_str(&format!("Bearer {token}")).ok(),
};
let context = Ok(UserAuthContext::bearer(token.as_str()));

assert!(matches!(
strategy(dec).authenticate(context).unwrap(),
Expand All @@ -166,9 +177,7 @@ mod tests {
};
let token = encode(&token, &enc);

let context = UserAuthContext {
user_credential: HeaderValue::from_str(&format!("Bearer {token}")).ok(),
};
let context = Ok(UserAuthContext::bearer(token.as_str()));

let Authenticated::Legacy(a) = strategy(dec).authenticate(context).unwrap() else {
panic!()
Expand All @@ -181,9 +190,7 @@ mod tests {
#[test]
fn errors_when_jwt_token_invalid() {
let (_enc, dec) = key_pair();
let context = UserAuthContext {
user_credential: HeaderValue::from_str("Bearer abc").ok(),
};
let context = Ok(UserAuthContext::bearer("abc"));

assert_eq!(
strategy(dec).authenticate(context).unwrap_err(),
Expand All @@ -203,9 +210,7 @@ mod tests {

let token = encode(&token, &enc);

let context = UserAuthContext {
user_credential: HeaderValue::from_str(&format!("Bearer {token}")).ok(),
};
let context = Ok(UserAuthContext::bearer(token.as_str()));

assert_eq!(
strategy(dec).authenticate(context).unwrap_err(),
Expand All @@ -227,9 +232,7 @@ mod tests {

let token = encode(&token, &enc);

let context = UserAuthContext {
user_credential: HeaderValue::from_str(&format!("Bearer {token}")).ok(),
};
let context = Ok(UserAuthContext::bearer(token.as_str()));

let Authenticated::Authorized(a) = strategy(dec).authenticate(context).unwrap() else {
panic!()
Expand Down
68 changes: 62 additions & 6 deletions libsql-server/src/auth/user_auth_strategies/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,73 @@ pub mod disabled;
pub mod http_basic;
pub mod jwt;

use axum::http::HeaderValue;
pub use disabled::*;
pub use http_basic::*;
pub use jwt::*;
pub use disabled::Disabled;
pub use http_basic::HttpBasic;
pub use jwt::Jwt;

use super::{AuthError, Authenticated};

#[derive(Debug)]
pub struct UserAuthContext {
pub user_credential: Option<HeaderValue>,
scheme: Option<String>,
token: Option<String>,
}

impl UserAuthContext {
pub fn scheme(&self) -> &Option<String> {
&self.scheme
}

pub fn token(&self) -> &Option<String> {
&self.token
}

pub fn empty() -> UserAuthContext {
UserAuthContext {
scheme: None,
token: None,
}
}

pub fn basic(creds: &str) -> UserAuthContext {
UserAuthContext {
scheme: Some("Basic".into()),
token: Some(creds.into()),
}
}

pub fn bearer(token: &str) -> UserAuthContext {
UserAuthContext {
scheme: Some("Bearer".into()),
token: Some(token.into()),
}
}

pub fn bearer_opt(token: Option<String>) -> UserAuthContext {
UserAuthContext {
scheme: Some("Bearer".into()),
token: token,
shopifyski marked this conversation as resolved.
Show resolved Hide resolved
}
}

pub fn new(scheme: &str, token: &str) -> UserAuthContext {
UserAuthContext {
scheme: Some(scheme.into()),
token: Some(token.into()),
}
}

pub fn from_auth_str(auth_string: &str) -> Result<Self, AuthError> {
let (scheme, token) = auth_string
.split_once(' ')
shopifyski marked this conversation as resolved.
Show resolved Hide resolved
.ok_or(AuthError::AuthStringMalformed)?;
Ok(UserAuthContext::new(scheme, token))
}
}

pub trait UserAuthStrategy: Sync + Send {
fn authenticate(&self, context: UserAuthContext) -> Result<Authenticated, AuthError>;
fn authenticate(
&self,
context: Result<UserAuthContext, AuthError>,
) -> Result<Authenticated, AuthError>;
}
Loading
Loading