From d28b049d0287243d32a66da86fb98a26129fc39e Mon Sep 17 00:00:00 2001 From: Gard Rimestad Date: Fri, 15 Dec 2023 14:53:51 +0100 Subject: [PATCH] feat: token-header override (#360) --token-header Token header to use for both edge authorization and communication with the upstream server [env: TOKEN_HEADER=] [default: Authorization] This is useful in complex deployment scenarios where proxies are using the Authorization header. --- server/src/builder.rs | 1 + server/src/cli.rs | 25 +++++++++++ server/src/client_api.rs | 44 ++++++++++++++++++- server/src/http/feature_refresher.rs | 8 ++++ server/src/http/unleash_client.rs | 11 ++++- server/src/main.rs | 2 + .../client_token_from_frontend_token.rs | 3 ++ server/src/tokens.rs | 7 ++- 8 files changed, 98 insertions(+), 3 deletions(-) diff --git a/server/src/builder.rs b/server/src/builder.rs index f9f26b54..dd944c74 100644 --- a/server/src/builder.rs +++ b/server/src/builder.rs @@ -144,6 +144,7 @@ async fn build_edge(args: &EdgeArgs) -> EdgeResult { args.upstream_certificate_file.clone(), Duration::seconds(args.upstream_request_timeout), Duration::seconds(args.upstream_socket_timeout), + args.token_header.token_header.clone() ) }) .map(|c| c.with_custom_client_headers(args.custom_client_headers.clone())) diff --git a/server/src/cli.rs b/server/src/cli.rs index 34542a12..e20af4e5 100644 --- a/server/src/cli.rs +++ b/server/src/cli.rs @@ -165,6 +165,10 @@ pub struct EdgeArgs { /// A URL pointing to a running Redis instance. Edge will use this instance to persist feature and token data and read this back after restart. Mutually exclusive with the --backup-folder option #[clap(flatten)] pub redis: Option, + + /// Token header to use for both edge authorization and communication with the upstream server. + #[clap(long, env, global = true, default_value = "Authorization")] + pub token_header: TokenHeader, } pub fn string_to_header_tuple(s: &str) -> Result<(String, String), String> { @@ -207,6 +211,23 @@ pub struct HealthCheckArgs { pub ca_certificate_file: Option, } +#[derive(Args, Debug, Clone)] +pub struct TokenHeader { + /// Token header to use for edge authorization. + #[clap(long, env, global = true, default_value = "Authorization")] + pub token_header: String, + +} + +impl FromStr for TokenHeader { + type Err = clap::Error; + + fn from_str(s: &str) -> Result { + let token_header = s.to_owned(); + Ok(TokenHeader { token_header }) + } +} + #[derive(Args, Debug, Clone)] pub struct ReadyCheckArgs { /// Where the instance you want to health check is running @@ -259,6 +280,10 @@ pub struct CliArgs { /// Which log format should Edge use #[clap(short, long, env, global = true, value_enum, default_value_t = LogFormat::Plain)] pub log_format: LogFormat, + + /// token header to use for edge authorization. + #[clap(long, env, global = true, default_value = "Authorization")] + pub token_header: TokenHeader, } #[derive(Args, Debug, Clone)] diff --git a/server/src/client_api.rs b/server/src/client_api.rs index 7a4925cb..31cf9411 100644 --- a/server/src/client_api.rs +++ b/server/src/client_api.rs @@ -234,7 +234,7 @@ mod tests { use super::*; use crate::auth::token_validator::TokenValidator; - use crate::cli::OfflineArgs; + use crate::cli::{OfflineArgs, TokenHeader}; use crate::http::unleash_client::UnleashClient; use crate::middleware; use crate::tests::{features_from_disk, upstream_server}; @@ -996,4 +996,46 @@ mod tests { let res = test::call_service(&local_app, request).await; assert_eq!(res.status(), StatusCode::NOT_FOUND); } + + #[tokio::test] + async fn client_features_endpoint_works_with_overridden_token_header() { + let features_cache: Arc> = Arc::new(DashMap::default()); + let token_cache: Arc> = Arc::new(DashMap::default()); + let token_header = TokenHeader::from_str("NeedsToBeTested").unwrap(); + println!("token_header: {:?}", token_header); + let app = test::init_service( + App::new() + .app_data(Data::from(features_cache.clone())) + .app_data(Data::from(token_cache.clone())) + .app_data(Data::new(token_header.clone())) + .service(web::scope("/api/client").service(get_features)), + ) + .await; + let client_features = cached_client_features(); + let example_features = features_from_disk("../examples/features.json"); + features_cache.insert("development".into(), client_features.clone()); + features_cache.insert("production".into(), example_features.clone()); + let mut production_token = EdgeToken::try_from( + "*:production.03fa5f506428fe80ed5640c351c7232e38940814d2923b08f5c05fa7".to_string(), + ) + .unwrap(); + production_token.token_type = Some(TokenType::Client); + production_token.status = TokenValidationStatus::Validated; + token_cache.insert(production_token.token.clone(), production_token.clone()); + + let request = test::TestRequest::get() + .uri("/api/client/features") + .insert_header(ContentType::json()) + .insert_header(("NeedsToBeTested", production_token.token.clone())) + .to_request(); + let res = test::call_service(&app, request).await; + assert_eq!(res.status(), StatusCode::OK); + let request = test::TestRequest::get() + .uri("/api/client/features") + .insert_header(ContentType::json()) + .insert_header(("ShouldNotWork", production_token.token.clone())) + .to_request(); + let res = test::call_service(&app, request).await; + assert_eq!(res.status(), StatusCode::FORBIDDEN); + } } diff --git a/server/src/http/feature_refresher.rs b/server/src/http/feature_refresher.rs index 5ff3cb83..23b57247 100644 --- a/server/src/http/feature_refresher.rs +++ b/server/src/http/feature_refresher.rs @@ -481,6 +481,7 @@ mod tests { None, Duration::seconds(5), Duration::seconds(5), + "Authorization".to_string(), ); let features_cache = Arc::new(DashMap::default()); let engines_cache = Arc::new(DashMap::default()); @@ -512,6 +513,7 @@ mod tests { None, Duration::seconds(5), Duration::seconds(5), + "Authorization".to_string(), ); let features_cache = Arc::new(DashMap::default()); let engines_cache = Arc::new(DashMap::default()); @@ -547,6 +549,7 @@ mod tests { None, Duration::seconds(5), Duration::seconds(5), + "Authorization".to_string(), ); let features_cache = Arc::new(DashMap::default()); let engines_cache = Arc::new(DashMap::default()); @@ -589,6 +592,7 @@ mod tests { None, Duration::seconds(5), Duration::seconds(5), + "Authorization".to_string(), ); let features_cache = Arc::new(DashMap::default()); let engines_cache = Arc::new(DashMap::default()); @@ -640,6 +644,7 @@ mod tests { None, Duration::seconds(5), Duration::seconds(5), + "Authorization".to_string(), ); let features_cache = Arc::new(DashMap::default()); let engines_cache = Arc::new(DashMap::default()); @@ -695,6 +700,7 @@ mod tests { None, Duration::seconds(5), Duration::seconds(5), + "Authorization".to_string(), ); let features_cache = Arc::new(DashMap::default()); let engines_cache = Arc::new(DashMap::default()); @@ -735,6 +741,7 @@ mod tests { None, Duration::seconds(5), Duration::seconds(5), + "Authorization".to_string(), ); let features_cache = Arc::new(DashMap::default()); let engines_cache = Arc::new(DashMap::default()); @@ -770,6 +777,7 @@ mod tests { None, Duration::seconds(5), Duration::seconds(5), + "Authorization".to_string(), ); let features_cache = Arc::new(DashMap::default()); let engines_cache = Arc::new(DashMap::default()); diff --git a/server/src/http/unleash_client.rs b/server/src/http/unleash_client.rs index 0dd2dcf4..4121b72a 100644 --- a/server/src/http/unleash_client.rs +++ b/server/src/http/unleash_client.rs @@ -72,6 +72,7 @@ pub struct UnleashClient { backing_client: Client, service_account_token: Option, custom_headers: HashMap, + token_header: String, } fn load_pkcs12(id: &ClientIdentity) -> EdgeResult { @@ -176,6 +177,7 @@ impl UnleashClient { upstream_certificate_file: Option, connect_timeout: Duration, socket_timeout: Duration, + token_header: String, ) -> Self { Self { urls: UnleashUrls::from_base_url(server_url), @@ -190,6 +192,7 @@ impl UnleashClient { .unwrap(), custom_headers: Default::default(), service_account_token: Default::default(), + token_header, } } pub fn from_url_with_service_account_token( @@ -200,6 +203,7 @@ impl UnleashClient { service_account_token: String, connect_timeout: Duration, socket_timeout: Duration, + token_header: String ) -> Self { Self { urls: UnleashUrls::from_base_url(server_url), @@ -214,6 +218,7 @@ impl UnleashClient { .unwrap(), custom_headers: Default::default(), service_account_token: Some(service_account_token), + token_header, } } @@ -235,6 +240,7 @@ impl UnleashClient { .unwrap(), custom_headers: Default::default(), service_account_token: Default::default(), + token_header: "Authorization".to_string(), }) } @@ -255,6 +261,7 @@ impl UnleashClient { .unwrap(), custom_headers: Default::default(), service_account_token: Some(sa_token.into()), + token_header: "Authorization".to_string(), }) } @@ -275,6 +282,7 @@ impl UnleashClient { .unwrap(), custom_headers: Default::default(), service_account_token: Default::default(), + token_header: "Authorization".to_string(), }) } @@ -292,8 +300,9 @@ impl UnleashClient { fn header_map(&self, api_key: Option) -> HeaderMap { let mut header_map = HeaderMap::new(); + let token_header: HeaderName= HeaderName::from_str(self.token_header.as_str()).unwrap(); if let Some(key) = api_key { - header_map.insert(header::AUTHORIZATION, key.parse().unwrap()); + header_map.insert(token_header, key.parse().unwrap()); } for (header_name, header_value) in self.custom_headers.iter() { let key = HeaderName::from_str(header_name.as_str()).unwrap(); diff --git a/server/src/main.rs b/server/src/main.rs index 9a8dd27d..1d4d93c0 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -45,6 +45,7 @@ async fn main() -> Result<(), anyhow::Error> { let schedule_args = args.clone(); let mode_arg = args.clone().mode; let http_args = args.clone().http; + let token_header = args.clone().token_header; let request_timeout = args.edge_request_timeout; let trust_proxy = args.clone().trust_proxy; let base_path = http_args.base_path.clone(); @@ -82,6 +83,7 @@ async fn main() -> Result<(), anyhow::Error> { .allow_any_method(); let mut app = App::new() .app_data(qs_config) + .app_data(web::Data::new(token_header.clone())) .app_data(web::Data::new(trust_proxy.clone())) .app_data(web::Data::new(mode_arg.clone())) .app_data(web::Data::new(connect_via.clone())) diff --git a/server/src/middleware/client_token_from_frontend_token.rs b/server/src/middleware/client_token_from_frontend_token.rs index 4b160be6..bbbfe6d5 100644 --- a/server/src/middleware/client_token_from_frontend_token.rs +++ b/server/src/middleware/client_token_from_frontend_token.rs @@ -147,6 +147,7 @@ mod tests { upstream_sa.token.to_string(), Duration::seconds(5), Duration::seconds(5), + "Authorization".to_string(), ); let arced_client = Arc::new(unleash_client); let local_features_cache: Arc> = @@ -211,6 +212,7 @@ mod tests { upstream_sa.to_string(), Duration::seconds(5), Duration::seconds(5), + "Authorization".to_string(), ); let arced_client = Arc::new(unleash_client); let local_features_cache: Arc> = @@ -273,6 +275,7 @@ mod tests { None, Duration::seconds(5), Duration::seconds(5), + "Authorization".to_string(), ); let local_features_cache: Arc> = Arc::new(DashMap::default()); diff --git a/server/src/tokens.rs b/server/src/tokens.rs index 2d87bc8c..8d71ef72 100644 --- a/server/src/tokens.rs +++ b/server/src/tokens.rs @@ -7,6 +7,7 @@ use std::collections::HashSet; use std::future::{ready, Ready}; use std::str::FromStr; +use crate::cli::TokenHeader; use crate::cli::EdgeMode; use crate::error::EdgeError; use crate::types::EdgeResult; @@ -111,7 +112,11 @@ impl FromRequest for EdgeToken { type Future = Ready>; fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future { - let value = req.headers().get("Authorization"); + let token_header = match req.app_data::>() { + Some(data) => data.clone().into_inner().token_header.clone(), + None => "Authorization".to_string(), + }; + let value = req.headers().get(token_header); if let Some(data_mode) = req.app_data::>() { let mode = data_mode.clone().into_inner(); let key = match *mode {