From d2c94e76aa21cc8b4714ab45a9e7771d05cda0eb Mon Sep 17 00:00:00 2001 From: Jin Jiu Date: Tue, 9 Jul 2024 10:23:15 +0800 Subject: [PATCH] Add support for X-Vault-Token request header and bug fixes. --- src/errors.rs | 6 ++++++ src/http/mod.rs | 40 +++++++++++++++++++++++++++++++++++----- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/src/errors.rs b/src/errors.rs index 764eeba..a5d4483 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -240,6 +240,12 @@ pub enum RvError { source: ipnetwork::IpNetworkError, }, + #[error("Some actix_web http header error happened, {:?}", .source)] + ActixWebHttpHeaderError { + #[from] + source: actix_web::http::header::ToStrError, + }, + /// Database Errors Begin /// #[error("Database type is not support now. Please try postgressql or mysql again.")] diff --git a/src/http/mod.rs b/src/http/mod.rs index b8c23da..417e72d 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -11,7 +11,7 @@ use std::{ use actix_tls::accept::openssl::TlsStream; use actix_web::{ cookie::Cookie, dev::Extensions, http::StatusCode, rt::net::TcpStream, web, HttpRequest, HttpResponse, - ResponseError, + http::header, ResponseError, }; use openssl::x509::{X509Ref, X509VerifyResult, X509}; use serde::Serialize; @@ -23,6 +23,8 @@ pub mod logical; pub mod sys; pub const AUTH_COOKIE_NAME: &str = "token"; +pub const AUTH_HEADER_NAME: &str = "X-RustyVault-Token"; +pub const VAULT_AUTH_HEADER_NAME: &str = "X-Vault-Token"; #[derive(Debug, Clone)] pub struct TlsClientInfo { @@ -60,6 +62,11 @@ pub fn request_on_connect_handler(conn: &dyn Any, ext: &mut Extensions) { let socket = tls_stream.get_ref(); let mut cert_chain = None; + let peer_addr = socket.peer_addr(); + if peer_addr.is_err() { + return; + } + if let Some(cert_stack) = tls_stream.ssl().verified_chain() { let certs: Vec = cert_stack.iter().map(X509Ref::to_owned).collect(); cert_chain = Some(certs); @@ -67,7 +74,7 @@ pub fn request_on_connect_handler(conn: &dyn Any, ext: &mut Extensions) { ext.insert(Connection { bind: socket.local_addr().unwrap(), - peer: socket.peer_addr().unwrap(), + peer: peer_addr.unwrap(), ttl: socket.ttl().ok(), tls: Some(TlsClientInfo { client_cert_chain: cert_chain, @@ -75,9 +82,14 @@ pub fn request_on_connect_handler(conn: &dyn Any, ext: &mut Extensions) { }), }); } else if let Some(socket) = conn.downcast_ref::() { + let peer_addr = socket.peer_addr(); + if peer_addr.is_err() { + return; + } + ext.insert(Connection { bind: socket.local_addr().unwrap(), - peer: socket.peer_addr().unwrap(), + peer: peer_addr.unwrap(), ttl: socket.ttl().ok(), tls: None, }); @@ -108,10 +120,28 @@ impl ResponseError for RvError { } } +pub fn get_token_from_req(req: &HttpRequest) -> Result { + if let Some(token) = req.headers().get(AUTH_HEADER_NAME) { + return Ok(token.to_str()?.to_string()); + } else if let Some(vault_token) = req.headers().get(VAULT_AUTH_HEADER_NAME) { + return Ok(vault_token.to_str()?.to_string()); + } else if let Some(auth) = req.headers().get(header::AUTHORIZATION) { + if let Ok(auth_str) = auth.to_str(){ + if auth_str.starts_with("Bearer ") { + return Ok(auth_str.trim_start_matches("Bearer ").to_string()); + } + } + } else if let Some(cookie_token) = req.cookie(AUTH_COOKIE_NAME) { + return Ok(cookie_token.value().to_string()); + } + + Err(RvError::ErrResponse("missing client token".to_string())) +} + pub fn request_auth(req: &HttpRequest) -> Request { let mut r = Request::default(); - if let Some(token) = req.cookie(AUTH_COOKIE_NAME) { - r.client_token = token.value().to_string(); + if let Ok(token) = get_token_from_req(req) { + r.client_token = token; } r }