Skip to content

Commit

Permalink
feat: token-header override (#360)
Browse files Browse the repository at this point in the history
--token-header <TOKEN_HEADER>
          Token header to use for both edge authorization and communication with the upstream server [env: TOKEN_HEADER=] [default: Authorization]

This is useful in complex deployment scenarios where proxies are using the Authorization header.
  • Loading branch information
gardleopard authored Dec 15, 2023
1 parent 85e1867 commit d28b049
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 3 deletions.
1 change: 1 addition & 0 deletions server/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ async fn build_edge(args: &EdgeArgs) -> EdgeResult<EdgeInfo> {
args.upstream_certificate_file.clone(),
Duration::seconds(args.upstream_request_timeout),
Duration::seconds(args.upstream_socket_timeout),
args.token_header.token_header.clone()
)
})
.map(|c| c.with_custom_client_headers(args.custom_client_headers.clone()))
Expand Down
25 changes: 25 additions & 0 deletions server/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ pub struct EdgeArgs {
/// A URL pointing to a running Redis instance. Edge will use this instance to persist feature and token data and read this back after restart. Mutually exclusive with the --backup-folder option
#[clap(flatten)]
pub redis: Option<RedisArgs>,

/// Token header to use for both edge authorization and communication with the upstream server.
#[clap(long, env, global = true, default_value = "Authorization")]
pub token_header: TokenHeader,
}

pub fn string_to_header_tuple(s: &str) -> Result<(String, String), String> {
Expand Down Expand Up @@ -207,6 +211,23 @@ pub struct HealthCheckArgs {
pub ca_certificate_file: Option<PathBuf>,
}

#[derive(Args, Debug, Clone)]
pub struct TokenHeader {
/// Token header to use for edge authorization.
#[clap(long, env, global = true, default_value = "Authorization")]
pub token_header: String,

}

impl FromStr for TokenHeader {
type Err = clap::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
let token_header = s.to_owned();
Ok(TokenHeader { token_header })
}
}

#[derive(Args, Debug, Clone)]
pub struct ReadyCheckArgs {
/// Where the instance you want to health check is running
Expand Down Expand Up @@ -259,6 +280,10 @@ pub struct CliArgs {
/// Which log format should Edge use
#[clap(short, long, env, global = true, value_enum, default_value_t = LogFormat::Plain)]
pub log_format: LogFormat,

/// token header to use for edge authorization.
#[clap(long, env, global = true, default_value = "Authorization")]
pub token_header: TokenHeader,
}

#[derive(Args, Debug, Clone)]
Expand Down
44 changes: 43 additions & 1 deletion server/src/client_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ mod tests {
use super::*;

use crate::auth::token_validator::TokenValidator;
use crate::cli::OfflineArgs;
use crate::cli::{OfflineArgs, TokenHeader};
use crate::http::unleash_client::UnleashClient;
use crate::middleware;
use crate::tests::{features_from_disk, upstream_server};
Expand Down Expand Up @@ -996,4 +996,46 @@ mod tests {
let res = test::call_service(&local_app, request).await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
}

#[tokio::test]
async fn client_features_endpoint_works_with_overridden_token_header() {
let features_cache: Arc<DashMap<String, ClientFeatures>> = Arc::new(DashMap::default());
let token_cache: Arc<DashMap<String, EdgeToken>> = Arc::new(DashMap::default());
let token_header = TokenHeader::from_str("NeedsToBeTested").unwrap();
println!("token_header: {:?}", token_header);
let app = test::init_service(
App::new()
.app_data(Data::from(features_cache.clone()))
.app_data(Data::from(token_cache.clone()))
.app_data(Data::new(token_header.clone()))
.service(web::scope("/api/client").service(get_features)),
)
.await;
let client_features = cached_client_features();
let example_features = features_from_disk("../examples/features.json");
features_cache.insert("development".into(), client_features.clone());
features_cache.insert("production".into(), example_features.clone());
let mut production_token = EdgeToken::try_from(
"*:production.03fa5f506428fe80ed5640c351c7232e38940814d2923b08f5c05fa7".to_string(),
)
.unwrap();
production_token.token_type = Some(TokenType::Client);
production_token.status = TokenValidationStatus::Validated;
token_cache.insert(production_token.token.clone(), production_token.clone());

let request = test::TestRequest::get()
.uri("/api/client/features")
.insert_header(ContentType::json())
.insert_header(("NeedsToBeTested", production_token.token.clone()))
.to_request();
let res = test::call_service(&app, request).await;
assert_eq!(res.status(), StatusCode::OK);
let request = test::TestRequest::get()
.uri("/api/client/features")
.insert_header(ContentType::json())
.insert_header(("ShouldNotWork", production_token.token.clone()))
.to_request();
let res = test::call_service(&app, request).await;
assert_eq!(res.status(), StatusCode::FORBIDDEN);
}
}
8 changes: 8 additions & 0 deletions server/src/http/feature_refresher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ mod tests {
None,
Duration::seconds(5),
Duration::seconds(5),
"Authorization".to_string(),
);
let features_cache = Arc::new(DashMap::default());
let engines_cache = Arc::new(DashMap::default());
Expand Down Expand Up @@ -512,6 +513,7 @@ mod tests {
None,
Duration::seconds(5),
Duration::seconds(5),
"Authorization".to_string(),
);
let features_cache = Arc::new(DashMap::default());
let engines_cache = Arc::new(DashMap::default());
Expand Down Expand Up @@ -547,6 +549,7 @@ mod tests {
None,
Duration::seconds(5),
Duration::seconds(5),
"Authorization".to_string(),
);
let features_cache = Arc::new(DashMap::default());
let engines_cache = Arc::new(DashMap::default());
Expand Down Expand Up @@ -589,6 +592,7 @@ mod tests {
None,
Duration::seconds(5),
Duration::seconds(5),
"Authorization".to_string(),
);
let features_cache = Arc::new(DashMap::default());
let engines_cache = Arc::new(DashMap::default());
Expand Down Expand Up @@ -640,6 +644,7 @@ mod tests {
None,
Duration::seconds(5),
Duration::seconds(5),
"Authorization".to_string(),
);
let features_cache = Arc::new(DashMap::default());
let engines_cache = Arc::new(DashMap::default());
Expand Down Expand Up @@ -695,6 +700,7 @@ mod tests {
None,
Duration::seconds(5),
Duration::seconds(5),
"Authorization".to_string(),
);
let features_cache = Arc::new(DashMap::default());
let engines_cache = Arc::new(DashMap::default());
Expand Down Expand Up @@ -735,6 +741,7 @@ mod tests {
None,
Duration::seconds(5),
Duration::seconds(5),
"Authorization".to_string(),
);
let features_cache = Arc::new(DashMap::default());
let engines_cache = Arc::new(DashMap::default());
Expand Down Expand Up @@ -770,6 +777,7 @@ mod tests {
None,
Duration::seconds(5),
Duration::seconds(5),
"Authorization".to_string(),
);
let features_cache = Arc::new(DashMap::default());
let engines_cache = Arc::new(DashMap::default());
Expand Down
11 changes: 10 additions & 1 deletion server/src/http/unleash_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ pub struct UnleashClient {
backing_client: Client,
service_account_token: Option<String>,
custom_headers: HashMap<String, String>,
token_header: String,
}

fn load_pkcs12(id: &ClientIdentity) -> EdgeResult<Identity> {
Expand Down Expand Up @@ -176,6 +177,7 @@ impl UnleashClient {
upstream_certificate_file: Option<PathBuf>,
connect_timeout: Duration,
socket_timeout: Duration,
token_header: String,
) -> Self {
Self {
urls: UnleashUrls::from_base_url(server_url),
Expand All @@ -190,6 +192,7 @@ impl UnleashClient {
.unwrap(),
custom_headers: Default::default(),
service_account_token: Default::default(),
token_header,
}
}
pub fn from_url_with_service_account_token(
Expand All @@ -200,6 +203,7 @@ impl UnleashClient {
service_account_token: String,
connect_timeout: Duration,
socket_timeout: Duration,
token_header: String
) -> Self {
Self {
urls: UnleashUrls::from_base_url(server_url),
Expand All @@ -214,6 +218,7 @@ impl UnleashClient {
.unwrap(),
custom_headers: Default::default(),
service_account_token: Some(service_account_token),
token_header,
}
}

Expand All @@ -235,6 +240,7 @@ impl UnleashClient {
.unwrap(),
custom_headers: Default::default(),
service_account_token: Default::default(),
token_header: "Authorization".to_string(),
})
}

Expand All @@ -255,6 +261,7 @@ impl UnleashClient {
.unwrap(),
custom_headers: Default::default(),
service_account_token: Some(sa_token.into()),
token_header: "Authorization".to_string(),
})
}

Expand All @@ -275,6 +282,7 @@ impl UnleashClient {
.unwrap(),
custom_headers: Default::default(),
service_account_token: Default::default(),
token_header: "Authorization".to_string(),
})
}

Expand All @@ -292,8 +300,9 @@ impl UnleashClient {

fn header_map(&self, api_key: Option<String>) -> HeaderMap {
let mut header_map = HeaderMap::new();
let token_header: HeaderName= HeaderName::from_str(self.token_header.as_str()).unwrap();
if let Some(key) = api_key {
header_map.insert(header::AUTHORIZATION, key.parse().unwrap());
header_map.insert(token_header, key.parse().unwrap());
}
for (header_name, header_value) in self.custom_headers.iter() {
let key = HeaderName::from_str(header_name.as_str()).unwrap();
Expand Down
2 changes: 2 additions & 0 deletions server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ async fn main() -> Result<(), anyhow::Error> {
let schedule_args = args.clone();
let mode_arg = args.clone().mode;
let http_args = args.clone().http;
let token_header = args.clone().token_header;
let request_timeout = args.edge_request_timeout;
let trust_proxy = args.clone().trust_proxy;
let base_path = http_args.base_path.clone();
Expand Down Expand Up @@ -82,6 +83,7 @@ async fn main() -> Result<(), anyhow::Error> {
.allow_any_method();
let mut app = App::new()
.app_data(qs_config)
.app_data(web::Data::new(token_header.clone()))
.app_data(web::Data::new(trust_proxy.clone()))
.app_data(web::Data::new(mode_arg.clone()))
.app_data(web::Data::new(connect_via.clone()))
Expand Down
3 changes: 3 additions & 0 deletions server/src/middleware/client_token_from_frontend_token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ mod tests {
upstream_sa.token.to_string(),
Duration::seconds(5),
Duration::seconds(5),
"Authorization".to_string(),
);
let arced_client = Arc::new(unleash_client);
let local_features_cache: Arc<DashMap<String, ClientFeatures>> =
Expand Down Expand Up @@ -211,6 +212,7 @@ mod tests {
upstream_sa.to_string(),
Duration::seconds(5),
Duration::seconds(5),
"Authorization".to_string(),
);
let arced_client = Arc::new(unleash_client);
let local_features_cache: Arc<DashMap<String, ClientFeatures>> =
Expand Down Expand Up @@ -273,6 +275,7 @@ mod tests {
None,
Duration::seconds(5),
Duration::seconds(5),
"Authorization".to_string(),
);
let local_features_cache: Arc<DashMap<String, ClientFeatures>> =
Arc::new(DashMap::default());
Expand Down
7 changes: 6 additions & 1 deletion server/src/tokens.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::collections::HashSet;
use std::future::{ready, Ready};
use std::str::FromStr;

use crate::cli::TokenHeader;
use crate::cli::EdgeMode;
use crate::error::EdgeError;
use crate::types::EdgeResult;
Expand Down Expand Up @@ -111,7 +112,11 @@ impl FromRequest for EdgeToken {
type Future = Ready<EdgeResult<Self>>;

fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
let value = req.headers().get("Authorization");
let token_header = match req.app_data::<Data<TokenHeader>>() {
Some(data) => data.clone().into_inner().token_header.clone(),
None => "Authorization".to_string(),
};
let value = req.headers().get(token_header);
if let Some(data_mode) = req.app_data::<Data<EdgeMode>>() {
let mode = data_mode.clone().into_inner();
let key = match *mode {
Expand Down

0 comments on commit d28b049

Please sign in to comment.