diff --git a/src/cli/command/server.rs b/src/cli/command/server.rs index a7774b22..60e641e8 100644 --- a/src/cli/command/server.rs +++ b/src/cli/command/server.rs @@ -1,12 +1,19 @@ use std::{ default::Default, env, fs, + fs::File, + io::Read, path::Path, sync::{Arc, RwLock}, }; use actix_web::{middleware, web, App, HttpResponse, HttpServer}; +use anyhow::format_err; use clap::ArgMatches; +use openssl::{ + ssl::{SslAcceptor, SslFiletype, SslMethod, SslOptions, SslVerifyMode, SslVersion}, + x509::X509, +}; use sysexits::ExitCode; use crate::{ @@ -37,7 +44,7 @@ pub fn main(config_path: &str) -> Result<(), RvError> { let (_, storage) = config.storage.iter().next().unwrap(); let (_, listener) = config.listener.iter().next().unwrap(); - let addr = listener.address.clone(); + let listener = listener.clone(); let mut work_dir = WORK_DIR_PATH_DEFAULT.to_string(); if !config.work_dir.is_empty() { @@ -122,9 +129,56 @@ pub fn main(config_path: &str) -> Result<(), RvError> { }) .on_connect(http::request_on_connect_handler); - log::info!("start listen, addr: {}", addr); + log::info!( + "start listen, addr: {}, tls_disable: {}, tls_disable_client_certs: {}", + listener.address, + listener.tls_disable, + listener.tls_disable_client_certs + ); + + if listener.tls_disable { + http_server = http_server.bind(listener.address)?; + } else { + let cert_file: &Path = Path::new(&listener.tls_cert_file); + let key_file: &Path = Path::new(&listener.tls_key_file); + + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls())?; + builder + .set_private_key_file(key_file, SslFiletype::PEM) + .map_err(|err| format_err!("unable to read proxy key {} - {}", key_file.display(), err))?; + builder + .set_certificate_chain_file(cert_file) + .map_err(|err| format_err!("unable to read proxy cert {} - {}", cert_file.display(), err))?; + builder.check_private_key()?; + + builder.set_min_proto_version(Some(listener.tls_min_version))?; + builder.set_max_proto_version(Some(listener.tls_max_version))?; + + log::info!("tls_cipher_suites: {}", listener.tls_cipher_suites); + builder.set_cipher_list(&listener.tls_cipher_suites)?; + + if listener.tls_max_version == SslVersion::TLS1_3 { + builder.clear_options(SslOptions::NO_TLSV1_3); + builder.set_ciphersuites("TLS_AES_128_GCM_SHA256:TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256")?; + } + + if listener.tls_require_and_verify_client_cert { + builder.set_verify_callback(SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT, move |p, _x| { + return p; + }); + + if listener.tls_client_ca_file.len() > 0 { + let mut client_ca_file = File::open(&listener.tls_client_ca_file)?; + let mut client_ca_file_bytes = Vec::new(); + client_ca_file.read_to_end(&mut client_ca_file_bytes)?; + let client_ca_x509 = X509::from_pem(&client_ca_file_bytes)?; - http_server = http_server.bind(addr)?; + builder.add_client_ca(client_ca_x509.as_ref())?; + } + } + + http_server = http_server.bind_openssl(listener.address, builder)?; + } log::info!("rusty_vault server starts, waiting for request..."); diff --git a/src/cli/config.rs b/src/cli/config.rs index 1e0cd979..54f15ed8 100644 --- a/src/cli/config.rs +++ b/src/cli/config.rs @@ -1,6 +1,10 @@ -use std::{collections::HashMap, fs, path::Path}; +use std::{collections::HashMap, fmt, fs, path::Path}; -use serde::{Deserialize, Deserializer, Serialize}; +use openssl::ssl::SslVersion; +use serde::{ + de::{self, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; use serde_json::Value; use crate::errors::RvError; @@ -34,8 +38,32 @@ pub struct Listener { #[serde(default)] pub ltype: String, pub address: String, - #[serde(default, deserialize_with = "parse_bool_string")] + #[serde(default = "default_bool_true", deserialize_with = "parse_bool_string")] pub tls_disable: bool, + #[serde(default)] + pub tls_cert_file: String, + #[serde(default)] + pub tls_key_file: String, + #[serde(default)] + pub tls_client_ca_file: String, + #[serde(default, deserialize_with = "parse_bool_string")] + pub tls_disable_client_certs: bool, + #[serde(default, deserialize_with = "parse_bool_string")] + pub tls_require_and_verify_client_cert: bool, + #[serde( + default = "default_tls_min_version", + serialize_with = "serialize_tls_version", + deserialize_with = "deserialize_tls_version" + )] + pub tls_min_version: SslVersion, + #[serde( + default = "default_tls_max_version", + serialize_with = "serialize_tls_version", + deserialize_with = "deserialize_tls_version" + )] + pub tls_max_version: SslVersion, + #[serde(default = "default_tls_cipher_suites")] + pub tls_cipher_suites: String, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -46,6 +74,10 @@ pub struct Storage { pub config: HashMap, } +fn default_bool_true() -> bool { + true +} + fn parse_bool_string<'de, D>(deserializer: D) -> Result where D: Deserializer<'de>, @@ -62,6 +94,61 @@ where } } +fn default_tls_min_version() -> SslVersion { + SslVersion::TLS1_2 +} + +fn default_tls_max_version() -> SslVersion { + SslVersion::TLS1_3 +} + +fn default_tls_cipher_suites() -> String { + "HIGH:!PSK:!SRP:!3DES".to_string() +} + +fn serialize_tls_version(version: &SslVersion, serializer: S) -> Result +where + S: Serializer, +{ + match *version { + SslVersion::TLS1 => serializer.serialize_str("tls10"), + SslVersion::TLS1_1 => serializer.serialize_str("tls11"), + SslVersion::TLS1_2 => serializer.serialize_str("tls12"), + SslVersion::TLS1_3 => serializer.serialize_str("tls13"), + _ => unreachable!("unexpected SSL/TLS version: {:?}", version), + } +} + +fn deserialize_tls_version<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + struct TlsVersionVisitor; + + impl<'de> Visitor<'de> for TlsVersionVisitor { + type Value = SslVersion; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string representing an SSL version") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match value { + "tls10" => Ok(SslVersion::TLS1), + "tls11" => Ok(SslVersion::TLS1_1), + "tls12" => Ok(SslVersion::TLS1_2), + "tls13" => Ok(SslVersion::TLS1_3), + _ => Err(E::custom(format!("unexpected SSL/TLS version: {}", value))), + } + } + } + + deserializer.deserialize_str(TlsVersionVisitor) +} + fn validate_storage<'de, D>(deserializer: D) -> Result, D::Error> where D: serde::Deserializer<'de>, @@ -81,15 +168,29 @@ fn validate_listener<'de, D>(deserializer: D) -> Result, { - let listener: HashMap = Deserialize::deserialize(deserializer)?; + let listeners: HashMap = Deserialize::deserialize(deserializer)?; - for key in listener.keys() { + for (key, listener) in &listeners { if key != "tcp" { return Err(serde::de::Error::custom("Invalid listener key")); } + + if !listener.tls_disable && (listener.tls_cert_file.len() == 0 || listener.tls_key_file.len() == 0) { + return Err(serde::de::Error::custom( + "when tls_disable is false, tls_cert_file and tls_key_file must be configured", + )); + } + + if !listener.tls_disable { + if listener.tls_require_and_verify_client_cert && listener.tls_disable_client_certs { + return Err(serde::de::Error::custom( + "'tls_disable_client_certs' and 'tls_require_and_verify_client_cert' are mutually exclusive", + )); + } + } } - Ok(listener) + Ok(listeners) } impl Config { @@ -169,10 +270,12 @@ fn load_config_file(path: &str) -> Result { if path.ends_with(".hcl") { let mut config: Config = hcl::from_reader(file)?; set_config_type_field(&mut config)?; + check_config(&config)?; Ok(config) } else if path.ends_with(".json") { let mut config: Config = serde_json::from_reader(file)?; set_config_type_field(&mut config)?; + check_config(&config)?; Ok(config) } else { return Err(RvError::ErrConfigPathInvalid); @@ -185,6 +288,18 @@ fn set_config_type_field(config: &mut Config) -> Result<(), RvError> { Ok(()) } +fn check_config(config: &Config) -> Result<(), RvError> { + if config.storage.len() != 1 { + return Err(RvError::ErrConfigStorageNotFound); + } + + if config.listener.len() != 1 { + return Err(RvError::ErrConfigListenerNotFound); + } + + Ok(()) +} + #[cfg(test)] mod test { use std::{env, fs, io::prelude::*}; @@ -214,7 +329,7 @@ mod test { let file_path = dir.join("config.hcl"); let path = file_path.to_str().unwrap_or("config.hcl"); - let hcl_config = r#" + let hcl_config_str = r#" storage "file" { path = "./vault/data" } @@ -229,14 +344,14 @@ mod test { pid_file = "/tmp/rusty_vault.pid" "#; - assert!(write_file(path, hcl_config).is_ok()); + assert!(write_file(path, hcl_config_str).is_ok()); let config = load_config(path); assert!(config.is_ok()); let hcl_config = config.unwrap(); println!("hcl config: {:?}", hcl_config); - let json_config = r#"{ + let json_config_str = r#"{ "storage": { "file": { "path": "./vault/data" @@ -255,7 +370,7 @@ mod test { let file_path = dir.join("config.json"); let path = file_path.to_str().unwrap_or("config.json"); - assert!(write_file(path, json_config).is_ok()); + assert!(write_file(path, json_config_str).is_ok()); let config = load_config(path); assert!(config.is_ok()); @@ -270,6 +385,28 @@ mod test { assert!(json_config_value.is_ok()); let json_config_value: Value = json_config_value.unwrap(); assert_eq!(hcl_config_value, json_config_value); + + assert_eq!(json_config.listener.len(), 1); + assert_eq!(json_config.storage.len(), 1); + assert_eq!(json_config.api_addr.as_str(), "http://127.0.0.1:8200"); + assert_eq!(json_config.log_format.as_str(), "{date} {req.path}"); + assert_eq!(json_config.log_level.as_str(), "debug"); + assert_eq!(json_config.pid_file.as_str(), "/tmp/rusty_vault.pid"); + assert_eq!(json_config.work_dir.as_str(), ""); + assert_eq!(json_config.daemon, false); + assert_eq!(json_config.daemon_user.as_str(), ""); + assert_eq!(json_config.daemon_group.as_str(), ""); + + let (_, listener) = json_config.listener.iter().next().unwrap(); + assert!(listener.tls_disable); + assert_eq!(listener.ltype.as_str(), "tcp"); + assert_eq!(listener.address.as_str(), "127.0.0.1:8200"); + + let (_, storage) = json_config.storage.iter().next().unwrap(); + assert_eq!(storage.stype.as_str(), "file"); + assert_eq!(storage.config.len(), 1); + let (_, path) = storage.config.iter().next().unwrap(); + assert_eq!(path.as_str(), Some("./vault/data")); } #[test] @@ -283,14 +420,14 @@ mod test { let file_path = dir.join("config1.hcl"); let path = file_path.to_str().unwrap_or("config1.hcl"); - let hcl_config = r#" + let hcl_config_str = r#" storage "file" { path = "./vault/data" } listener "tcp" { address = "127.0.0.1:8200" - tls_disable = "false" + tls_disable = "true" } api_addr = "http://127.0.0.1:8200" @@ -299,12 +436,12 @@ mod test { pid_file = "/tmp/rusty_vault.pid" "#; - assert!(write_file(path, hcl_config).is_ok()); + assert!(write_file(path, hcl_config_str).is_ok()); let file_path = dir.join("config2.hcl"); let path = file_path.to_str().unwrap_or("config2.hcl"); - let hcl_config = r#" + let hcl_config_str = r#" storage "file" { address = "127.0.0.1:8899" } @@ -317,11 +454,84 @@ mod test { log_level = "info" "#; - assert!(write_file(path, hcl_config).is_ok()); + assert!(write_file(path, hcl_config_str).is_ok()); let config = load_config(dir.to_str().unwrap()); + println!("config: {:?}", config); + assert!(config.is_ok()); + let hcl_config = config.unwrap(); + println!("hcl config: {:?}", hcl_config); + + let (_, listener) = hcl_config.listener.iter().next().unwrap(); + assert!(listener.tls_disable); + } + + #[test] + fn test_load_config_tls() { + let dir = env::temp_dir().join("rusty_vault_tls_config_test"); + assert!(fs::create_dir(&dir).is_ok()); + defer! ( + assert!(fs::remove_dir_all(&dir).is_ok()); + ); + + let file_path = dir.join("config.hcl"); + let path = file_path.to_str().unwrap_or("config.hcl"); + + let hcl_config_str = r#" + storage "file" { + path = "./vault/data" + } + + listener "tcp" { + address = "127.0.0.1:8200" + tls_disable = false + tls_cert_file = "./cert/test.crt" + tls_key_file = "./cert/test.key" + tls_client_ca_file = "./cert/ca.pem" + tls_min_version = "tls12" + tls_max_version = "tls13" + } + + api_addr = "http://127.0.0.1:8200" + log_level = "debug" + log_format = "{date} {req.path}" + pid_file = "/tmp/rusty_vault.pid" + "#; + + assert!(write_file(path, hcl_config_str).is_ok()); + + let config = load_config(path); assert!(config.is_ok()); let hcl_config = config.unwrap(); println!("hcl config: {:?}", hcl_config); + + assert_eq!(hcl_config.listener.len(), 1); + assert_eq!(hcl_config.storage.len(), 1); + assert_eq!(hcl_config.api_addr.as_str(), "http://127.0.0.1:8200"); + assert_eq!(hcl_config.log_format.as_str(), "{date} {req.path}"); + assert_eq!(hcl_config.log_level.as_str(), "debug"); + assert_eq!(hcl_config.pid_file.as_str(), "/tmp/rusty_vault.pid"); + assert_eq!(hcl_config.work_dir.as_str(), ""); + assert_eq!(hcl_config.daemon, false); + assert_eq!(hcl_config.daemon_user.as_str(), ""); + assert_eq!(hcl_config.daemon_group.as_str(), ""); + + let (_, listener) = hcl_config.listener.iter().next().unwrap(); + assert_eq!(listener.ltype.as_str(), "tcp"); + assert_eq!(listener.address.as_str(), "127.0.0.1:8200"); + assert_eq!(listener.tls_disable, false); + assert_eq!(listener.tls_cert_file.as_str(), "./cert/test.crt"); + assert_eq!(listener.tls_key_file.as_str(), "./cert/test.key"); + assert_eq!(listener.tls_client_ca_file.as_str(), "./cert/ca.pem"); + assert_eq!(listener.tls_disable_client_certs, false); + assert_eq!(listener.tls_require_and_verify_client_cert, false); + assert_eq!(listener.tls_min_version, SslVersion::TLS1_2); + assert_eq!(listener.tls_max_version, SslVersion::TLS1_3); + + let (_, storage) = hcl_config.storage.iter().next().unwrap(); + assert_eq!(storage.stype.as_str(), "file"); + assert_eq!(storage.config.len(), 1); + let (_, path) = storage.config.iter().next().unwrap(); + assert_eq!(path.as_str(), Some("./vault/data")); } } diff --git a/src/errors.rs b/src/errors.rs index e9f29a9c..446e0ffc 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -141,6 +141,8 @@ pub enum RvError { ErrPkiRoleNotFound, #[error("PKI internal error.")] ErrPkiInternal, + #[error("Credentail is invalid.")] + ErrCredentailInvalid, #[error("Some IO error happened, {:?}", .source)] IO { #[from] @@ -308,6 +310,7 @@ impl PartialEq for RvError { | (RvError::ErrPkiCertNotFound, RvError::ErrPkiCertNotFound) | (RvError::ErrPkiRoleNotFound, RvError::ErrPkiRoleNotFound) | (RvError::ErrPkiInternal, RvError::ErrPkiInternal) + | (RvError::ErrCredentailInvalid, RvError::ErrCredentailInvalid) | (RvError::ErrUnknown, RvError::ErrUnknown) => true, _ => false, }