From 44bd1361eb4b1cdccaec83bcbb3cd160038b578c Mon Sep 17 00:00:00 2001 From: Jin Jiu Date: Fri, 5 Jul 2024 21:59:33 +0800 Subject: [PATCH] cargo fmt --- src/cli/command/server.rs | 8 +- src/core.rs | 8 +- src/http/logical.rs | 2 +- src/lib.rs | 4 +- src/logical/backend.rs | 16 +- src/logical/connection.rs | 5 +- src/logical/field.rs | 64 ++-- src/logical/mod.rs | 6 +- src/logical/path.rs | 12 +- src/logical/request.rs | 9 +- src/logical/response.rs | 19 +- src/modules/auth/token_store.rs | 3 +- src/modules/credential/userpass/path_login.rs | 3 +- src/modules/credential/userpass/path_users.rs | 3 +- src/modules/kv/mod.rs | 3 +- src/modules/pki/path_config_ca.rs | 3 +- src/modules/pki/path_config_crl.rs | 3 +- src/modules/pki/path_fetch.rs | 3 +- src/modules/pki/path_issue.rs | 3 +- src/modules/pki/path_keys.rs | 23 +- src/modules/pki/path_revoke.rs | 3 +- src/modules/pki/path_roles.rs | 78 +++-- src/modules/pki/path_root.rs | 3 +- src/modules/pki/util.rs | 12 +- src/modules/system/mod.rs | 3 +- src/shamir.rs | 5 +- src/storage/barrier.rs | 3 +- src/storage/barrier_aes_gcm.rs | 11 +- src/storage/mod.rs | 7 +- src/storage/mysql/mod.rs | 16 +- src/storage/mysql/mysql_backend.rs | 20 +- src/storage/physical/file.rs | 6 +- src/storage/physical/mock.rs | 6 +- src/utils/cidr.rs | 15 +- src/utils/crypto.rs | 6 + src/utils/ip_sock_addr.rs | 25 +- src/utils/key.rs | 42 +-- src/utils/locks.rs | 166 ++++++++++ src/utils/ocsp.rs | 30 ++ src/utils/policy.rs | 96 ++++++ src/utils/salt.rs | 37 +-- src/utils/sock_addr.rs | 30 +- src/utils/string.rs | 69 +++++ src/utils/token_util.rs | 283 ++++++++++++++++++ src/utils/unix_sock_addr.rs | 20 +- 45 files changed, 916 insertions(+), 276 deletions(-) create mode 100644 src/utils/crypto.rs create mode 100644 src/utils/locks.rs create mode 100644 src/utils/ocsp.rs create mode 100644 src/utils/policy.rs create mode 100644 src/utils/string.rs create mode 100644 src/utils/token_util.rs diff --git a/src/cli/command/server.rs b/src/cli/command/server.rs index d3020d1..bd52781 100644 --- a/src/cli/command/server.rs +++ b/src/cli/command/server.rs @@ -17,12 +17,8 @@ use openssl::{ use sysexits::ExitCode; use crate::{ - cli::config, - core::Core, - errors::RvError, - http, - storage, - EXIT_CODE_INSUFFICIENT_PARAMS, EXIT_CODE_LOAD_CONFIG_FAILURE, EXIT_CODE_OK, + cli::config, core::Core, errors::RvError, http, storage, EXIT_CODE_INSUFFICIENT_PARAMS, + EXIT_CODE_LOAD_CONFIG_FAILURE, EXIT_CODE_OK, }; pub const WORK_DIR_PATH_DEFAULT: &str = "/tmp/rusty_vault"; diff --git a/src/core.rs b/src/core.rs index 2a4df2b..b63933a 100644 --- a/src/core.rs +++ b/src/core.rs @@ -24,13 +24,7 @@ use crate::{ handler::Handler, logical::{Backend, Request, Response}, module_manager::ModuleManager, - modules::{ - auth::AuthModule, - credential::{ - userpass::UserPassModule, - }, - pki::PkiModule, - }, + modules::{auth::AuthModule, credential::userpass::UserPassModule, pki::PkiModule}, mount::MountTable, router::Router, shamir::{ShamirSecret, SHAMIR_OVERHEAD}, diff --git a/src/http/logical.rs b/src/http/logical.rs index e005974..bd7fc52 100644 --- a/src/http/logical.rs +++ b/src/http/logical.rs @@ -18,7 +18,7 @@ use crate::{ core::Core, errors::RvError, http::{request_auth, response_error, response_json_ok, response_ok, Connection}, - logical::{Operation, Connection as ReqConnection, Response}, + logical::{Connection as ReqConnection, Operation, Response}, }; #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/src/lib.rs b/src/lib.rs index c7723dc..6a89c88 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,11 +35,11 @@ pub mod module_manager; pub mod modules; pub mod mount; pub mod router; +#[cfg(feature = "storage_mysql")] +pub mod schema; pub mod shamir; pub mod storage; pub mod utils; -#[cfg(feature = "storage_mysql")] -pub mod schema; /// Exit ok pub const EXIT_CODE_OK: sysexits::ExitCode = sysexits::ExitCode::Ok; diff --git a/src/logical/backend.rs b/src/logical/backend.rs index e8809b5..72e67de 100644 --- a/src/logical/backend.rs +++ b/src/logical/backend.rs @@ -3,10 +3,8 @@ use std::{collections::HashMap, sync::Arc}; use regex::Regex; use serde_json::{Map, Value}; -use super::{path::Path, request::Request, response::Response, secret::Secret, FieldType, Backend, Operation}; -use crate::{ - context::Context, errors::RvError -}; +use super::{path::Path, request::Request, response::Response, secret::Secret, Backend, FieldType, Operation}; +use crate::{context::Context, errors::RvError}; type BackendOperationHandler = dyn Fn(&dyn Backend, &mut Request) -> Result, RvError> + Send + Sync; @@ -262,9 +260,8 @@ mod test { use super::*; use crate::{ - logical::{Field, field::FieldTrait, FieldType, PathOperation}, - new_fields, new_fields_internal, new_path, new_path_internal, new_secret, new_secret_internal, - storage, + logical::{field::FieldTrait, Field, FieldType, PathOperation}, + new_fields, new_fields_internal, new_path, new_path_internal, new_secret, new_secret_internal, storage, }; struct MyTest; @@ -440,7 +437,10 @@ mod test { "mytype": 1, "mypath": "/pp", "mypassword": "123qwe", - }).as_object().unwrap().clone(); + }) + .as_object() + .unwrap() + .clone(); req.body = Some(body); req.storage = Some(Arc::new(barrier)); assert!(logical_backend.handle_request(&mut req).is_ok()); diff --git a/src/logical/connection.rs b/src/logical/connection.rs index 0941373..6779aeb 100644 --- a/src/logical/connection.rs +++ b/src/logical/connection.rs @@ -7,9 +7,6 @@ pub struct Connection { impl Default for Connection { fn default() -> Self { - Self { - peer_addr: String::new(), - peer_tls_cert: None, - } + Self { peer_addr: String::new(), peer_tls_cert: None } } } diff --git a/src/logical/field.rs b/src/logical/field.rs index 8fefeb8..2800b6c 100644 --- a/src/logical/field.rs +++ b/src/logical/field.rs @@ -1,10 +1,10 @@ -use std::{fmt, time::Duration, collections::HashMap}; +use std::{collections::HashMap, fmt, time::Duration}; use enum_map::Enum; +use humantime::parse_duration; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use strum::{Display, EnumString}; -use humantime::parse_duration; use crate::errors::RvError; @@ -149,7 +149,7 @@ impl FieldTrait for Value { fn as_duration(&self) -> Option { if let Some(secs) = self.as_u64() { - return Some(Duration::from_secs(secs)) + return Some(Duration::from_secs(secs)); } if let Some(secs_str) = self.as_str() { @@ -225,12 +225,7 @@ impl FieldTrait for Value { impl Field { pub fn new() -> Self { - Self { - required: false, - field_type: FieldType::Str, - default: json!(null), - description: String::new(), - } + Self { required: false, field_type: FieldType::Str, default: json!(null), description: String::new() } } pub fn check_data_type(&self, data: &Value) -> bool { @@ -250,22 +245,22 @@ impl Field { match &self.field_type { FieldType::SecretStr | FieldType::Str => { return Ok(json!("")); - }, + } FieldType::Int => { return Ok(json!(0)); - }, + } FieldType::Bool => { return Ok(json!(false)); - }, + } FieldType::Array => { return Ok(json!([])); - }, + } FieldType::Map => { return Ok(serde_json::from_str("{}")?); - }, + } FieldType::DurationSecond => { return Ok(json!(0)); - }, + } FieldType::CommaStringSlice => { return Ok(json!([])); } @@ -279,21 +274,21 @@ impl Field { } return Err(RvError::ErrRustDowncastFailed); - }, + } FieldType::Int => { if self.default.is_i64() { return Ok(self.default.clone()); } return Err(RvError::ErrRustDowncastFailed); - }, + } FieldType::Bool => { if self.default.is_boolean() { return Ok(self.default.clone()); } return Err(RvError::ErrRustDowncastFailed); - }, + } FieldType::Array => { if self.default.is_array() { return Ok(self.default.clone()); @@ -306,7 +301,7 @@ impl Field { } return Err(RvError::ErrRustDowncastFailed); - }, + } FieldType::Map => { if self.default.is_object() { return Ok(self.default.clone()); @@ -319,14 +314,14 @@ impl Field { } return Err(RvError::ErrRustDowncastFailed); - }, + } FieldType::DurationSecond => { if self.default.is_duration() { return Ok(self.default.clone()); } return Err(RvError::ErrRustDowncastFailed); - }, + } FieldType::CommaStringSlice => { if self.default.is_comma_string_slice() { return Ok(self.default.clone()); @@ -443,23 +438,38 @@ mod test { val = json!("aa, bb, cc ,dd"); assert!(val.is_comma_string_slice()); - assert_eq!(val.as_comma_string_slice(), Some(vec!["aa".to_string(), "bb".to_string(), "cc".to_string(), "dd".to_string()])); + assert_eq!( + val.as_comma_string_slice(), + Some(vec!["aa".to_string(), "bb".to_string(), "cc".to_string(), "dd".to_string()]) + ); - val = json!(["aaa", " bbb", "ccc " , " ddd"]); + val = json!(["aaa", " bbb", "ccc ", " ddd"]); assert!(val.is_comma_string_slice()); - assert_eq!(val.as_comma_string_slice(), Some(vec!["aaa".to_string(), "bbb".to_string(), "ccc".to_string(), "ddd".to_string()])); + assert_eq!( + val.as_comma_string_slice(), + Some(vec!["aaa".to_string(), "bbb".to_string(), "ccc".to_string(), "ddd".to_string()]) + ); val = json!([11, 22, 33, 44]); assert!(val.is_comma_string_slice()); - assert_eq!(val.as_comma_string_slice(), Some(vec!["11".to_string(), "22".to_string(), "33".to_string(), "44".to_string()])); + assert_eq!( + val.as_comma_string_slice(), + Some(vec!["11".to_string(), "22".to_string(), "33".to_string(), "44".to_string()]) + ); val = json!([11, "aa22", 33, 44]); assert!(val.is_comma_string_slice()); - assert_eq!(val.as_comma_string_slice(), Some(vec!["11".to_string(), "aa22".to_string(), "33".to_string(), "44".to_string()])); + assert_eq!( + val.as_comma_string_slice(), + Some(vec!["11".to_string(), "aa22".to_string(), "33".to_string(), "44".to_string()]) + ); val = json!("aa, bb, cc ,dd, , 88,"); assert!(val.is_comma_string_slice()); - assert_eq!(val.as_comma_string_slice(), Some(vec!["aa".to_string(), "bb".to_string(), "cc".to_string(), "dd".to_string(), "88".to_string()])); + assert_eq!( + val.as_comma_string_slice(), + Some(vec!["aa".to_string(), "bb".to_string(), "cc".to_string(), "dd".to_string(), "88".to_string()]) + ); let mut map: HashMap = HashMap::new(); map.insert("k1".to_string(), "v1".to_string()); diff --git a/src/logical/mod.rs b/src/logical/mod.rs index 20fd0dc..89527cc 100644 --- a/src/logical/mod.rs +++ b/src/logical/mod.rs @@ -17,9 +17,7 @@ use enum_map::Enum; use serde::{Deserialize, Serialize}; use strum::{Display, EnumString}; -use crate::{ - context::Context, errors::RvError, -}; +use crate::{context::Context, errors::RvError}; pub mod auth; pub mod backend; @@ -33,13 +31,13 @@ pub mod secret; pub use auth::Auth; pub use backend::{LogicalBackend, CTX_KEY_BACKEND_PATH}; +pub use connection::Connection; pub use field::{Field, FieldType}; pub use lease::Lease; pub use path::{Path, PathOperation}; pub use request::Request; pub use response::Response; pub use secret::{Secret, SecretData}; -pub use connection::Connection; #[derive(Eq, PartialEq, Copy, Clone, Debug, EnumString, Display, Enum, Serialize, Deserialize)] pub enum Operation { diff --git a/src/logical/path.rs b/src/logical/path.rs index 998f7f4..b3f61b2 100644 --- a/src/logical/path.rs +++ b/src/logical/path.rs @@ -1,9 +1,7 @@ use std::{collections::HashMap, fmt, sync::Arc}; use super::{request::Request, response::Response, Backend, Field, Operation}; -use crate::{ - context::Context, errors::RvError -}; +use crate::{context::Context, errors::RvError}; type PathOperationHandler = dyn Fn(&dyn Backend, &mut Request) -> Result, RvError> + Send + Sync; @@ -30,7 +28,13 @@ impl fmt::Debug for PathOperation { impl Path { pub fn new(pattern: &str) -> Self { - Self { ctx: Arc::new(Context::new()), pattern: pattern.to_string(), fields: HashMap::new(), operations: Vec::new(), help: String::new() } + Self { + ctx: Arc::new(Context::new()), + pattern: pattern.to_string(), + fields: HashMap::new(), + operations: Vec::new(), + help: String::new(), + } } pub fn get_field(&self, key: &str) -> Option> { diff --git a/src/logical/request.rs b/src/logical/request.rs index 0dde765..af66778 100644 --- a/src/logical/request.rs +++ b/src/logical/request.rs @@ -151,14 +151,13 @@ impl Request { } pub fn get_data_as_str(&self, key: &str) -> Result { - self.get_data(key)? - .as_str() - .ok_or(RvError::ErrRequestFieldInvalid) - .and_then(|s| if s.trim().is_empty() { + self.get_data(key)?.as_str().ok_or(RvError::ErrRequestFieldInvalid).and_then(|s| { + if s.trim().is_empty() { Err(RvError::ErrResponse(format!("missing {}", key))) } else { Ok(s.trim().to_string()) - }) + } + }) } pub fn get_field_default_or_zero(&self, key: &str) -> Result { diff --git a/src/logical/response.rs b/src/logical/response.rs index d1af04a..e4c7f11 100644 --- a/src/logical/response.rs +++ b/src/logical/response.rs @@ -1,12 +1,12 @@ use std::collections::HashMap; +use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use serde_json::{json, Map, Value}; -use lazy_static::lazy_static; use crate::{ errors::RvError, - logical::{Auth, secret::SecretData}, + logical::{secret::SecretData, Auth}, }; lazy_static! { @@ -32,7 +32,15 @@ pub struct Response { impl Default for Response { fn default() -> Self { - Response { request_id: String::new(), headers: None, data: None, auth: None, secret: None, redirect: String::new(), warnings: Vec::new(), } + Response { + request_id: String::new(), + headers: None, + data: None, + auth: None, + secret: None, + redirect: String::new(), + warnings: Vec::new(), + } } } @@ -98,7 +106,10 @@ impl Response { let mut data: Map = json!({ HTTP_CONTENT_TYPE.to_string(): "application/json", HTTP_STATUS_CODE.to_string(): code, - }).as_object().unwrap().clone(); + }) + .as_object() + .unwrap() + .clone(); if let Some(response) = resp { let raw_body = serde_json::to_value(response).unwrap(); diff --git a/src/modules/auth/token_store.rs b/src/modules/auth/token_store.rs index 71c4d42..dfeeae4 100644 --- a/src/modules/auth/token_store.rs +++ b/src/modules/auth/token_store.rs @@ -11,8 +11,9 @@ use super::{ AUTH_ROUTER_PREFIX, }; use crate::{ + context::Context, core::Core, - context::Context, errors::RvError, + errors::RvError, handler::Handler, logical::{ Auth, Backend, Field, FieldType, Lease, LogicalBackend, Operation, Path, PathOperation, Request, Response, diff --git a/src/modules/credential/userpass/path_login.rs b/src/modules/credential/userpass/path_login.rs index b16dc32..651b811 100644 --- a/src/modules/credential/userpass/path_login.rs +++ b/src/modules/credential/userpass/path_login.rs @@ -2,7 +2,8 @@ use std::{collections::HashMap, sync::Arc}; use super::{UserPassBackend, UserPassBackendInner}; use crate::{ - context::Context, errors::RvError, + context::Context, + errors::RvError, logical::{Auth, Backend, Field, FieldType, Lease, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, }; diff --git a/src/modules/credential/userpass/path_users.rs b/src/modules/credential/userpass/path_users.rs index b936e32..885a2fa 100644 --- a/src/modules/credential/userpass/path_users.rs +++ b/src/modules/credential/userpass/path_users.rs @@ -4,7 +4,8 @@ use serde::{Deserialize, Serialize}; use super::{UserPassBackend, UserPassBackendInner}; use crate::{ - context::Context, errors::RvError, + context::Context, + errors::RvError, logical::{Backend, Field, FieldType, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, storage::StorageEntry, diff --git a/src/modules/kv/mod.rs b/src/modules/kv/mod.rs index 316e17c..13a7be5 100644 --- a/src/modules/kv/mod.rs +++ b/src/modules/kv/mod.rs @@ -12,8 +12,9 @@ use humantime::parse_duration; use serde_json::{Map, Value}; use crate::{ + context::Context, core::Core, - context::Context, errors::RvError, + errors::RvError, logical::{ secret::Secret, Backend, Field, FieldType, LogicalBackend, Operation, Path, PathOperation, Request, Response, }, diff --git a/src/modules/pki/path_config_ca.rs b/src/modules/pki/path_config_ca.rs index 0e6fffa..cbd8a72 100644 --- a/src/modules/pki/path_config_ca.rs +++ b/src/modules/pki/path_config_ca.rs @@ -8,7 +8,8 @@ use pem; use super::{PkiBackend, PkiBackendInner}; use crate::{ - context::Context, errors::RvError, + context::Context, + errors::RvError, logical::{Backend, Field, FieldType, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, storage::StorageEntry, diff --git a/src/modules/pki/path_config_crl.rs b/src/modules/pki/path_config_crl.rs index 1112c0f..dbc1fd8 100644 --- a/src/modules/pki/path_config_crl.rs +++ b/src/modules/pki/path_config_crl.rs @@ -2,7 +2,8 @@ use std::{collections::HashMap, sync::Arc}; use super::{PkiBackend, PkiBackendInner}; use crate::{ - context::Context, errors::RvError, + context::Context, + errors::RvError, logical::{Backend, Field, FieldType, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, }; diff --git a/src/modules/pki/path_fetch.rs b/src/modules/pki/path_fetch.rs index 5312335..290c19a 100644 --- a/src/modules/pki/path_fetch.rs +++ b/src/modules/pki/path_fetch.rs @@ -5,7 +5,8 @@ use serde_json::json; use super::{PkiBackend, PkiBackendInner}; use crate::{ - context::Context, errors::RvError, + context::Context, + errors::RvError, logical::{Backend, Field, FieldType, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, storage::StorageEntry, diff --git a/src/modules/pki/path_issue.rs b/src/modules/pki/path_issue.rs index 136c453..3e38e9f 100644 --- a/src/modules/pki/path_issue.rs +++ b/src/modules/pki/path_issue.rs @@ -10,7 +10,8 @@ use serde_json::{json, Map, Value}; use super::{PkiBackend, PkiBackendInner}; use crate::{ - context::Context, errors::RvError, + context::Context, + errors::RvError, logical::{Backend, Field, FieldType, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, utils, utils::cert, diff --git a/src/modules/pki/path_keys.rs b/src/modules/pki/path_keys.rs index 7ff650a..7fd2ad7 100644 --- a/src/modules/pki/path_keys.rs +++ b/src/modules/pki/path_keys.rs @@ -5,11 +5,12 @@ use serde_json::{json, Value}; use super::{PkiBackend, PkiBackendInner}; use crate::{ - context::Context, errors::RvError, + context::Context, + errors::RvError, logical::{Backend, Field, FieldType, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, storage::StorageEntry, - utils::key::{KeyBundle, EncryptExtraData}, + utils::key::{EncryptExtraData, KeyBundle}, }; const PKI_CONFIG_KEY_PREFIX: &str = "config/key/"; @@ -292,11 +293,11 @@ impl PkiBackendInner { "rsa" => { let rsa = Rsa::private_key_from_pem(&key_bundle.key)?; key_bundle.bits = rsa.size() * 8; - }, + } "ec" | "sm2" => { let ec_key = EcKey::private_key_from_pem(&key_bundle.key)?; key_bundle.bits = ec_key.group().degree(); - }, + } _ => { return Err(RvError::ErrPkiKeyTypeInvalid); } @@ -307,7 +308,7 @@ impl PkiBackendInner { key_bundle.key = hex::decode(&hex_bundle)?; key_bundle.bits = (key_bundle.key.len() as u32) * 8; match key_bundle.bits { - 128 | 192 | 256 => {}, + 128 | 192 | 256 => {} _ => { return Err(RvError::ErrPkiKeyBitsInvalid); } @@ -353,7 +354,8 @@ impl PkiBackendInner { let data_value = req.get_data("data")?; let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; + let key_bundle = + self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; let decoded_data = hex::decode(data.as_bytes())?; let result = key_bundle.sign(&decoded_data)?; @@ -374,7 +376,8 @@ impl PkiBackendInner { let signature_value = req.get_data("signature")?; let signature = signature_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; + let key_bundle = + self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; let decoded_data = hex::decode(data.as_bytes())?; let decoded_signature = hex::decode(signature.as_bytes())?; @@ -396,7 +399,8 @@ impl PkiBackendInner { let aad_value = req.get_data_or_default("aad")?; let aad = aad_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; + let key_bundle = + self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; let decoded_data = hex::decode(data.as_bytes())?; let result = key_bundle.encrypt(&decoded_data, Some(EncryptExtraData::Aad(aad.as_bytes())))?; @@ -417,7 +421,8 @@ impl PkiBackendInner { let aad_value = req.get_data_or_default("aad")?; let aad = aad_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; + let key_bundle = + self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; let decoded_data = hex::decode(data.as_bytes())?; let result = key_bundle.decrypt(&decoded_data, Some(EncryptExtraData::Aad(aad.as_bytes())))?; diff --git a/src/modules/pki/path_revoke.rs b/src/modules/pki/path_revoke.rs index 5bbcc41..4961fd3 100644 --- a/src/modules/pki/path_revoke.rs +++ b/src/modules/pki/path_revoke.rs @@ -2,7 +2,8 @@ use std::{collections::HashMap, sync::Arc}; use super::{PkiBackend, PkiBackendInner}; use crate::{ - context::Context, errors::RvError, + context::Context, + errors::RvError, logical::{Backend, Field, FieldType, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, }; diff --git a/src/modules/pki/path_roles.rs b/src/modules/pki/path_roles.rs index da6f728..700e73f 100644 --- a/src/modules/pki/path_roles.rs +++ b/src/modules/pki/path_roles.rs @@ -3,9 +3,10 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; use humantime::parse_duration; use serde::{Deserialize, Serialize}; -use super::{PkiBackend, PkiBackendInner, util::DEFAULT_MAX_TTL}; +use super::{util::DEFAULT_MAX_TTL, PkiBackend, PkiBackendInner}; use crate::{ - context::Context, errors::RvError, + context::Context, + errors::RvError, logical::{Backend, Field, FieldType, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, storage::StorageEntry, @@ -103,13 +104,13 @@ max_ttl, whichever is shorter."# "max_ttl": { field_type: FieldType::Str, description: r#" -The maximum allowed lease duration. If not set, defaults to the system maximum lease TTL."# + The maximum allowed lease duration. If not set, defaults to the system maximum lease TTL."# }, "use_pss": { field_type: FieldType::Bool, default: false, description: r#" -Whether or not to use PSS signatures when using a RSA key-type issuer. Defaults to false."# + Whether or not to use PSS signatures when using a RSA key-type issuer. Defaults to false."# }, "allow_localhost": { field_type: FieldType::Bool, @@ -154,30 +155,30 @@ See the documentation for more information."# field_type: FieldType::Bool, default: true, description: r#" -If set, IP Subject Alternative Names are allowed. Any valid IP is accepted and No authorization checking is performed."# + If set, IP Subject Alternative Names are allowed. Any valid IP is accepted and No authorization checking is performed."# }, "server_flag": { field_type: FieldType::Bool, default: true, description: r#" -If set, certificates are flagged for server auth use. defaults to true. See also RFC 5280 Section 4.2.1.12."# + If set, certificates are flagged for server auth use. defaults to true. See also RFC 5280 Section 4.2.1.12."# }, "client_flag": { field_type: FieldType::Bool, default: true, description: r#" -If set, certificates are flagged for client auth use. defaults to true. See also RFC 5280 Section 4.2.1.12."# + If set, certificates are flagged for client auth use. defaults to true. See also RFC 5280 Section 4.2.1.12."# }, "code_signing_flag": { field_type: FieldType::Bool, description: r#" -If set, certificates are flagged for code signing use. defaults to false. See also RFC 5280 Section 4.2.1.12."# + If set, certificates are flagged for code signing use. defaults to false. See also RFC 5280 Section 4.2.1.12."# }, "key_type": { field_type: FieldType::Str, default: "rsa", description: r#" -The type of key to use; defaults to RSA. "rsa" "ec", "ed25519" and "any" are the only valid values."# + The type of key to use; defaults to RSA. "rsa" "ec", "ed25519" and "any" are the only valid values."# }, "key_bits": { field_type: FieldType::Int, @@ -212,43 +213,43 @@ The value format should be given in UTC format YYYY-MM-ddTHH:MM:SSZ."# required: false, field_type: FieldType::Str, description: r#" -If set, OU (OrganizationalUnit) will be set to this value in certificates issued by this role."# + If set, OU (OrganizationalUnit) will be set to this value in certificates issued by this role."# }, "organization": { required: false, field_type: FieldType::Str, description: r#" -If set, O (Organization) will be set to this value in certificates issued by this role."# + If set, O (Organization) will be set to this value in certificates issued by this role."# }, "country": { required: false, field_type: FieldType::Str, description: r#" -If set, Country will be set to this value in certificates issued by this role."# + If set, Country will be set to this value in certificates issued by this role."# }, "locality": { required: false, field_type: FieldType::Str, description: r#" -If set, Locality will be set to this value in certificates issued by this role."# + If set, Locality will be set to this value in certificates issued by this role."# }, "province": { required: false, field_type: FieldType::Str, description: r#" -If set, Province will be set to this value in certificates issued by this role."# + If set, Province will be set to this value in certificates issued by this role."# }, "street_address": { required: false, field_type: FieldType::Str, description: r#" -If set, Street Address will be set to this value."# + If set, Street Address will be set to this value."# }, "postal_code": { required: false, field_type: FieldType::Str, description: r#" -If set, Postal Code will be set to this value."# + If set, Postal Code will be set to this value."# }, "use_csr_common_name": { field_type: FieldType::Bool, @@ -364,27 +365,42 @@ impl PkiBackendInner { } } - let signature_bits = req.get_data_or_default("signature_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; - let allow_localhost = req.get_data_or_default("allow_localhost")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let allow_bare_domains = req.get_data_or_default("allow_bare_domains")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let allow_subdomains = req.get_data_or_default("allow_subdomains")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let allow_any_name = req.get_data_or_default("allow_any_name")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let allow_ip_sans = req.get_data_or_default("allow_ip_sans")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let signature_bits = + req.get_data_or_default("signature_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_localhost = + req.get_data_or_default("allow_localhost")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_bare_domains = + req.get_data_or_default("allow_bare_domains")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_subdomains = + req.get_data_or_default("allow_subdomains")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_any_name = + req.get_data_or_default("allow_any_name")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_ip_sans = + req.get_data_or_default("allow_ip_sans")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; let server_flag = req.get_data_or_default("server_flag")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; let client_flag = req.get_data_or_default("client_flag")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; let use_csr_sans = req.get_data_or_default("use_csr_sans")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let use_csr_common_name = req.get_data_or_default("use_csr_common_name")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let use_csr_common_name = + req.get_data_or_default("use_csr_common_name")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; let country = req.get_data_or_default("country")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let province = req.get_data_or_default("province")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let locality = req.get_data_or_default("locality")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let organization = req.get_data_or_default("organization")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let province = + req.get_data_or_default("province")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let locality = + req.get_data_or_default("locality")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let organization = + req.get_data_or_default("organization")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); let ou = req.get_data_or_default("ou")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let street_address = req.get_data_or_default("street_address")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let postal_code = req.get_data_or_default("postal_code")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let street_address = + req.get_data_or_default("street_address")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let postal_code = + req.get_data_or_default("postal_code")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); let no_store = req.get_data_or_default("no_store")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let generate_lease = req.get_data_or_default("generate_lease")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let not_after = req.get_data_or_default("not_after")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let not_before_duration_u64 = req.get_data_or_default("not_before_duration")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + let generate_lease = + req.get_data_or_default("generate_lease")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let not_after = + req.get_data_or_default("not_after")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let not_before_duration_u64 = + req.get_data_or_default("not_before_duration")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; let not_before_duration = Duration::from_secs(not_before_duration_u64); let role_entry = RoleEntry { diff --git a/src/modules/pki/path_root.rs b/src/modules/pki/path_root.rs index 56a636f..6757a1f 100644 --- a/src/modules/pki/path_root.rs +++ b/src/modules/pki/path_root.rs @@ -4,7 +4,8 @@ use serde_json::{json, Value}; use super::{field, util, PkiBackend, PkiBackendInner}; use crate::{ - context::Context, errors::RvError, + context::Context, + errors::RvError, logical::{Backend, Operation, Path, PathOperation, Request, Response}, new_path, new_path_internal, utils, }; diff --git a/src/modules/pki/util.rs b/src/modules/pki/util.rs index 05c56b9..c08aa42 100644 --- a/src/modules/pki/util.rs +++ b/src/modules/pki/util.rs @@ -16,7 +16,8 @@ pub fn get_role_params(req: &mut Request) -> Result { ttl = parse_duration(ttl_str)?; } } - let not_before_duration_u64 = req.get_data_or_default("not_before_duration")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + let not_before_duration_u64 = + req.get_data_or_default("not_before_duration")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; let not_before_duration = Duration::from_secs(not_before_duration_u64); let key_type_value = req.get_data_or_default("key_type")?; let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; @@ -50,10 +51,13 @@ pub fn get_role_params(req: &mut Request) -> Result { let country = req.get_data_or_default("country")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); let province = req.get_data_or_default("province")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); let locality = req.get_data_or_default("locality")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let organization = req.get_data_or_default("organization")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let organization = + req.get_data_or_default("organization")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); let ou = req.get_data_or_default("ou")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let street_address = req.get_data_or_default("street_address")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let postal_code = req.get_data_or_default("postal_code")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let street_address = + req.get_data_or_default("street_address")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let postal_code = + req.get_data_or_default("postal_code")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); let not_after = req.get_data_or_default("not_after")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); let role_entry = RoleEntry { diff --git a/src/modules/system/mod.rs b/src/modules/system/mod.rs index 6c75d84..b35945a 100644 --- a/src/modules/system/mod.rs +++ b/src/modules/system/mod.rs @@ -11,8 +11,9 @@ use as_any::Downcast; use serde_json::{from_value, json, Map, Value}; use crate::{ + context::Context, core::Core, - context::Context, errors::RvError, + errors::RvError, logical::{Backend, Field, FieldType, LogicalBackend, Operation, Path, PathOperation, Request, Response}, modules::{auth::AuthModule, Module}, mount::MountEntry, diff --git a/src/shamir.rs b/src/shamir.rs index 348ee52..d35d3aa 100644 --- a/src/shamir.rs +++ b/src/shamir.rs @@ -27,11 +27,12 @@ //! OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE //! SOFTWARE. +use std::ops::DerefMut; + use rand::{thread_rng, RngCore}; +use zeroize::Zeroizing; use crate::errors::RvError; -use zeroize::Zeroizing; -use std::ops::DerefMut; static GF256_EXP: [u8; 256] = [ 0x01, 0xe5, 0x4c, 0xb5, 0xfb, 0x9f, 0xfc, 0x12, 0x03, 0x34, 0xd4, 0xc4, 0x16, 0xba, 0x1f, 0x36, 0x05, 0x5c, 0x67, diff --git a/src/storage/barrier.rs b/src/storage/barrier.rs index 9e5db4f..82122f1 100644 --- a/src/storage/barrier.rs +++ b/src/storage/barrier.rs @@ -5,9 +5,10 @@ //! It usually means a different symmetric encryption algorithm is going to be supported, //! if a new barrier is under development. +use zeroize::Zeroizing; + use super::Storage; use crate::errors::RvError; -use zeroize::Zeroizing; pub const BARRIER_INIT_PATH: &str = "barrier/init"; diff --git a/src/storage/barrier_aes_gcm.rs b/src/storage/barrier_aes_gcm.rs index 2aede80..bd29809 100644 --- a/src/storage/barrier_aes_gcm.rs +++ b/src/storage/barrier_aes_gcm.rs @@ -1,8 +1,10 @@ //! This is the implementation of aes-gcm barrier, which uses aes-gcm block cipher to encrypt or //! decrypt data before writing or reading data to or from specific storage backend. -use std::sync::{Arc, RwLock}; -use std::ops::{Deref, DerefMut}; +use std::{ + ops::{Deref, DerefMut}, + sync::{Arc, RwLock}, +}; use openssl::{ cipher::{Cipher, CipherRef}, @@ -10,14 +12,13 @@ use openssl::{ }; use rand::{thread_rng, Rng}; use serde::{Deserialize, Serialize}; +use zeroize::{Zeroize, Zeroizing}; use super::{ barrier::{SecurityBarrier, BARRIER_INIT_PATH}, - Backend, BackendEntry, - Storage, StorageEntry, + Backend, BackendEntry, Storage, StorageEntry, }; use crate::errors::RvError; -use zeroize::{Zeroize, Zeroizing}; const EPOCH_SIZE: usize = 4; const KEY_EPOCH: u8 = 1; diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 7af8040..d815746 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -15,6 +15,7 @@ //! Different strage types are all as sub-module of this module. use std::{collections::HashMap, sync::Arc}; + use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -23,9 +24,9 @@ use crate::errors::RvError; pub mod barrier; pub mod barrier_aes_gcm; pub mod barrier_view; -pub mod physical; #[cfg(feature = "storage_mysql")] pub mod mysql; +pub mod physical; /// A trait that abstracts core methods for all storage barrier types. pub trait Storage: Send + Sync { @@ -78,12 +79,12 @@ pub fn new_backend(t: &str, conf: &HashMap) -> Result { let backend = physical::file::FileBackend::new(conf)?; Ok(Arc::new(backend)) - }, + } #[cfg(feature = "storage_mysql")] "mysql" => { let backend = mysql::mysql_backend::MysqlBackend::new(conf)?; Ok(Arc::new(backend)) - }, + } "mock" => Ok(Arc::new(physical::mock::MockBackend::new())), _ => Err(RvError::ErrPhysicalTypeInvalid), } diff --git a/src/storage/mysql/mod.rs b/src/storage/mysql/mod.rs index 348b102..0a431af 100644 --- a/src/storage/mysql/mod.rs +++ b/src/storage/mysql/mod.rs @@ -2,11 +2,13 @@ use std::collections::HashMap; -use diesel::mysql::MysqlConnection; -use diesel::r2d2::{self, ConnectionManager}; +use diesel::{ + mysql::MysqlConnection, + r2d2::{self, ConnectionManager}, +}; +use serde_json::Value; use crate::errors::RvError; -use serde_json::Value; type MysqlDbPool = r2d2::Pool>; @@ -15,8 +17,8 @@ pub mod mysql_backend; pub fn new(conf: &HashMap) -> Result { let pool = establish_mysql_connection(conf); match pool { - Ok(pool)=> Ok(pool), - Err(e)=> Err(e), + Ok(pool) => Ok(pool), + Err(e) => Err(e), } } @@ -47,7 +49,7 @@ fn establish_mysql_connection(conf: &HashMap) -> Result { log::error!("Error: {:?}", e); Err(RvError::ErrConnectionPoolCreate { source: (e) }) - }, + } } } @@ -65,7 +67,7 @@ mod test { conf.insert("password".to_string(), Value::String("password".to_string())); let pool = establish_mysql_connection(&conf); - + assert!(pool.is_ok()); } } diff --git a/src/storage/mysql/mysql_backend.rs b/src/storage/mysql/mysql_backend.rs index 13147dd..b7df815 100644 --- a/src/storage/mysql/mysql_backend.rs +++ b/src/storage/mysql/mysql_backend.rs @@ -3,22 +3,21 @@ use std::{ sync::{Arc, Mutex}, }; -use diesel::prelude::*; -use diesel::{r2d2::ConnectionManager, MysqlConnection}; +use diesel::{prelude::*, r2d2::ConnectionManager, MysqlConnection}; use r2d2::Pool; use serde::Deserialize; use serde_json::Value; -use crate::schema::vault; -use crate::schema::vault::dsl::*; +use super::new; use crate::{ errors::RvError, - schema::vault::vault_key, + schema::{ + vault, + vault::{dsl::*, vault_key}, + }, storage::{Backend, BackendEntry}, }; -use super::new; - pub struct MysqlBackend { pool: Arc>>>, } @@ -49,7 +48,7 @@ impl Backend for MysqlBackend { let key = key.trim_start_matches(prefix); match key.find('/') { Some(i) => { - let key = &key[0..i+1]; + let key = &key[0..i + 1]; if !keys.contains(&key.to_string()) { keys.push(key.to_string()); } @@ -127,13 +126,12 @@ impl MysqlBackend { #[cfg(test)] mod test { - use serde_json::Value; use std::collections::HashMap; - use crate::storage::test::test_backend; - use crate::storage::test::test_backend_list_prefix; + use serde_json::Value; use super::MysqlBackend; + use crate::storage::test::{test_backend, test_backend_list_prefix}; #[test] fn test_mysql_backend() { diff --git a/src/storage/physical/file.rs b/src/storage/physical/file.rs index f848a4f..1252d4b 100644 --- a/src/storage/physical/file.rs +++ b/src/storage/physical/file.rs @@ -8,8 +8,10 @@ use std::{ use serde_json::Value; -use crate::storage::{Backend, BackendEntry}; -use crate::errors::RvError; +use crate::{ + errors::RvError, + storage::{Backend, BackendEntry}, +}; #[derive(Debug)] pub struct FileBackend { diff --git a/src/storage/physical/mock.rs b/src/storage/physical/mock.rs index 3ce18aa..9a44e00 100644 --- a/src/storage/physical/mock.rs +++ b/src/storage/physical/mock.rs @@ -1,7 +1,9 @@ use std::default::Default; -use crate::storage::{Backend, BackendEntry}; -use crate::errors::RvError; +use crate::{ + errors::RvError, + storage::{Backend, BackendEntry}, +}; #[derive(Default)] pub struct MockBackend(()); diff --git a/src/utils/cidr.rs b/src/utils/cidr.rs index b3d83e5..104a408 100644 --- a/src/utils/cidr.rs +++ b/src/utils/cidr.rs @@ -2,17 +2,14 @@ //! use std::{ - str::FromStr, - net::{IpAddr, Ipv4Addr, Ipv6Addr}, collections::HashSet, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + str::FromStr, }; use ipnetwork::IpNetwork; -use super::{ - sock_addr::{new_sock_addr, SockAddrType, SockAddr}, -}; - +use super::sock_addr::{new_sock_addr, SockAddr, SockAddrType}; use crate::errors::RvError; pub fn is_ip_addr(addr: &dyn SockAddr) -> bool { @@ -73,10 +70,8 @@ pub fn validate_cidr_string(cidr_list: &str, separator: &str) -> Result = cidr_list.split(separator) - .map(|cidr| cidr.trim()) - .filter(|cidr| !cidr.is_empty()) - .collect(); + let cidrs_set: HashSet<&str> = + cidr_list.split(separator).map(|cidr| cidr.trim()).filter(|cidr| !cidr.is_empty()).collect(); let cidrs: Vec<&str> = cidrs_set.into_iter().collect(); diff --git a/src/utils/crypto.rs b/src/utils/crypto.rs new file mode 100644 index 0000000..4662a10 --- /dev/null +++ b/src/utils/crypto.rs @@ -0,0 +1,6 @@ +use blake2b_simd::Params; + +pub fn blake2b256_hash(key: &str) -> Vec { + let hash = Params::new().hash_length(32).to_state().update(key.as_bytes()).finalize(); + hash.as_bytes().to_vec() +} diff --git a/src/utils/ip_sock_addr.rs b/src/utils/ip_sock_addr.rs index c5bbdf7..0a35b0a 100644 --- a/src/utils/ip_sock_addr.rs +++ b/src/utils/ip_sock_addr.rs @@ -1,20 +1,13 @@ //! This module is a Rust replica of //! -use std::{ - fmt, - str::FromStr, - net::SocketAddr, -}; +use std::{fmt, net::SocketAddr, str::FromStr}; use as_any::Downcast; use ipnetwork::IpNetwork; use serde::{Deserialize, Serialize}; -use super::{ - sock_addr::{SockAddr, SockAddrType}, -}; - +use super::sock_addr::{SockAddr, SockAddrType}; use crate::errors::RvError; #[derive(Clone, Debug, Serialize, Deserialize)] @@ -26,15 +19,9 @@ pub struct IpSockAddr { impl IpSockAddr { pub fn new(s: &str) -> Result { if let Ok(sock_addr) = SocketAddr::from_str(s) { - return Ok(IpSockAddr { - addr: IpNetwork::from(sock_addr.ip()), - port: sock_addr.port(), - }); + return Ok(IpSockAddr { addr: IpNetwork::from(sock_addr.ip()), port: sock_addr.port() }); } else if let Ok(ip_addr) = IpNetwork::from_str(s) { - return Ok(IpSockAddr { - addr: ip_addr, - port: 0, - }); + return Ok(IpSockAddr { addr: ip_addr, port: 0 }); } return Err(RvError::ErrResponse(format!("Unable to parse {} to an IP address:", s))); } @@ -85,9 +72,7 @@ impl fmt::Display for IpSockAddr { #[cfg(test)] mod test { - use super::{ - *, super::sock_addr::{SockAddrType}, - }; + use super::{super::sock_addr::SockAddrType, *}; #[test] fn test_ip_sock_addr() { diff --git a/src/utils/key.rs b/src/utils/key.rs index 8e20a8d..9a13662 100644 --- a/src/utils/key.rs +++ b/src/utils/key.rs @@ -48,7 +48,7 @@ fn key_bits_default(key_type: &str) -> u32 { "ec" | "sm2" => 256, "aes-gcm" | "aes-cbc" | "aes-ecb" | "sm4-gcm" | "sm4-ccm" => 256, _ => 0, - } + }; } // TODO: this function needs to be refactored to use crypto adaptors. @@ -73,26 +73,20 @@ fn cipher_from_key_type_and_bits(key_type: &str, bits: u32) -> Result Self { - let bits = if key_bits == 0 { - key_bits_default(key_type) - } else { - key_bits - }; - Self { name: name.to_string(), key_type: key_type.to_string(), bits: bits, ..KeyBundle::default() } + let bits = if key_bits == 0 { key_bits_default(key_type) } else { key_bits }; + Self { name: name.to_string(), key_type: key_type.to_string(), bits, ..KeyBundle::default() } } pub fn generate(&mut self) -> Result<(), RvError> { let key_bits = self.bits; let priv_key = match self.key_type.as_str() { - "rsa" => { - match key_bits { - 2048 | 3072 | 4096 => { - let rsa_key = Rsa::generate(key_bits)?; - PKey::from_rsa(rsa_key)?.private_key_to_pem_pkcs8()? - }, - _ => return Err(RvError::ErrPkiKeyBitsInvalid), + "rsa" => match key_bits { + 2048 | 3072 | 4096 => { + let rsa_key = Rsa::generate(key_bits)?; + PKey::from_rsa(rsa_key)?.private_key_to_pem_pkcs8()? } - } + _ => return Err(RvError::ErrPkiKeyBitsInvalid), + }, "ec" => { let curve_name = match key_bits { 224 => Nid::SECP224R1, @@ -111,7 +105,7 @@ impl KeyBundle { let ec_group = EcGroup::from_curve_name(Nid::SM2)?; let ec_key = EcKey::generate(&ec_group)?; PKey::from_ec_key(ec_key)?.private_key_to_pem_pkcs8()? - }, + } "aes-gcm" | "aes-cbc" | "aes-ecb" | "sm4-gcm" | "sm4-ccm" => { let _ = cipher_from_key_type_and_bits(self.key_type.as_str(), self.bits)?; @@ -135,7 +129,7 @@ impl KeyBundle { let mut key = vec![0u8; key_bits as usize / 8]; rand_bytes(&mut key)?; key - }, + } _ => { return Err(RvError::ErrPkiKeyTypeInvalid); } @@ -193,14 +187,7 @@ impl KeyBundle { _ => "".as_bytes(), }); let mut tag = vec![0u8; 16]; - let mut ciphertext = encrypt_aead( - cipher, - &self.key, - Some(&self.iv), - aad, - data, - &mut tag, - )?; + let mut ciphertext = encrypt_aead(cipher, &self.key, Some(&self.iv), aad, data, &mut tag)?; ciphertext.extend_from_slice(&tag); Ok(ciphertext) } @@ -234,7 +221,6 @@ impl KeyBundle { } pub fn decrypt(&self, data: &[u8], extra: Option) -> Result, RvError> { - match self.key_type.as_str() { "aes-gcm" | "sm4-gcm" | "sm4-ccm" => { let cipher = cipher_from_key_type_and_bits(self.key_type.as_str(), self.bits)?; @@ -248,12 +234,12 @@ impl KeyBundle { } let (ciphertext, tag) = data.split_at(data.len() - tag_len); Ok(decrypt_aead(cipher, &self.key, Some(&self.iv), aad, ciphertext, tag)?) - }, + } "aes-cbc" | "aes-ecb" => { let cipher = cipher_from_key_type_and_bits(self.key_type.as_str(), self.bits)?; let iv = if self.key_type == "aes-ecb" { None } else { Some(self.iv.as_slice()) }; Ok(decrypt(cipher, &self.key, iv, data)?) - }, + } "rsa" => { let rsa = Rsa::private_key_from_pem(&self.key)?; if data.len() > rsa.size() as usize { diff --git a/src/utils/locks.rs b/src/utils/locks.rs new file mode 100644 index 0000000..2cb5c1e --- /dev/null +++ b/src/utils/locks.rs @@ -0,0 +1,166 @@ +//! This module is a Rust replica of +//! https://github.com/hashicorp/vault/blob/main/sdk/helper/locksutil/locks.go + +use std::sync::{Arc, RwLock}; + +use super::crypto::blake2b256_hash; + +static LOCK_COUNT: usize = 256; + +#[derive(Debug)] +pub struct LockEntry { + pub lock: RwLock, +} + +#[derive(Debug)] +pub struct Locks { + pub locks: Vec>, +} + +impl Locks { + pub fn new() -> Self { + let mut locks = Self { locks: Vec::with_capacity(LOCK_COUNT) }; + + for _ in 0..LOCK_COUNT { + locks.locks.push(Arc::new(LockEntry { lock: RwLock::new(0) })); + } + + locks + } + + pub fn get_lock(&self, key: &str) -> Arc { + let index: usize = blake2b256_hash(key)[0].into(); + Arc::clone(&self.locks[index]) + } +} + +#[cfg(test)] +mod test { + use std::{ + thread::{self, sleep}, + time::Duration, + }; + + use super::*; + + struct MyTestData { + lock: Locks, + num: RwLock, + } + + fn write_case(data: Arc) -> u32 { + let lock_entry = data.lock.get_lock("test"); + let _locked = lock_entry.lock.write().unwrap(); + sleep(Duration::from_secs(5)); + let mut num = data.num.write().unwrap(); + *num = *num * 2; + return *num; + } + + fn read_case(data: Arc) -> u32 { + let lock_entry = data.lock.get_lock("test"); + let _locked = lock_entry.lock.read().unwrap(); + let num = data.num.read().unwrap(); + return *num; + } + + #[test] + fn test_locks_writer_reader() { + let data = Arc::new(MyTestData { lock: Locks::new(), num: RwLock::new(11) }); + + let data_writer = Arc::clone(&data); + let data_reader = Arc::clone(&data); + + let writer = thread::spawn(move || { + let num = write_case(data_writer); + assert_eq!(num, 22); + }); + + sleep(Duration::from_secs(1)); + + let reader = thread::spawn(move || { + let num = read_case(data_reader); + assert_eq!(num, 22); + }); + + writer.join().unwrap(); + sleep(Duration::from_secs(1)); + reader.join().unwrap(); + + assert_eq!(*data.num.read().unwrap(), 22); + } + + #[test] + fn test_locks_reader_writer() { + let data = Arc::new(MyTestData { lock: Locks::new(), num: RwLock::new(11) }); + + let data_writer = Arc::clone(&data); + let data_reader = Arc::clone(&data); + + let reader = thread::spawn(move || { + let num = read_case(data_reader); + assert_eq!(num, 11); + }); + + sleep(Duration::from_secs(1)); + + let writer = thread::spawn(move || { + let num = write_case(data_writer); + assert_eq!(num, 22); + }); + + reader.join().unwrap(); + writer.join().unwrap(); + + assert_eq!(*data.num.read().unwrap(), 22); + } + + #[test] + fn test_locks_writer_writer() { + let data = Arc::new(MyTestData { lock: Locks::new(), num: RwLock::new(11) }); + + let data_writer1 = Arc::clone(&data); + let data_writer2 = Arc::clone(&data); + + let writer1 = thread::spawn(move || { + let num = write_case(data_writer1); + assert_eq!(num, 22); + }); + + sleep(Duration::from_secs(1)); + + let writer2 = thread::spawn(move || { + let num = write_case(data_writer2); + assert_eq!(num, 44); + }); + + writer1.join().unwrap(); + writer2.join().unwrap(); + + assert_eq!(*data.num.read().unwrap(), 44); + } + + #[test] + fn test_locks_reader_reader() { + let data = Arc::new(MyTestData { lock: Locks::new(), num: RwLock::new(11) }); + + let data_reader1 = Arc::clone(&data); + let data_reader2 = Arc::clone(&data); + + let reader1 = thread::spawn(move || { + let num = read_case(data_reader1); + assert_eq!(num, 11); + }); + + sleep(Duration::from_secs(1)); + + let reader2 = thread::spawn(move || { + let num = read_case(data_reader2); + assert_eq!(num, 11); + }); + + reader1.join().unwrap(); + reader2.join().unwrap(); + assert_eq!(*data.num.read().unwrap(), 11); + } +} diff --git a/src/utils/ocsp.rs b/src/utils/ocsp.rs new file mode 100644 index 0000000..1e769b8 --- /dev/null +++ b/src/utils/ocsp.rs @@ -0,0 +1,30 @@ +use openssl::{ + x509::X509, +}; + +#[repr(u32)] +pub enum FailureMode { + OcspFailOpenNotSet = 0, + FailOpenTrue = 1, + FailOpenFalse = 2, +} + +pub struct OcspConfig { + pub enable: bool, + pub extra_ca: Vec, + pub servers_override: Vec, + pub failure_mode: FailureMode, + pub query_all_servers: bool, +} + +impl Default for OcspConfig { + fn default() -> Self { + OcspConfig { + enable: false, + extra_ca: Vec::new(), + servers_override: Vec::new(), + failure_mode: FailureMode::OcspFailOpenNotSet, + query_all_servers: false, + } + } +} diff --git a/src/utils/policy.rs b/src/utils/policy.rs new file mode 100644 index 0000000..b616d15 --- /dev/null +++ b/src/utils/policy.rs @@ -0,0 +1,96 @@ +//! This module is a Rust replica of +//! https://github.com/hashicorp/vault/blob/main/sdk/helper/policyutil/policyutil.go + +use super::string::remove_duplicates; + +// sanitize_policies performs the common input validation tasks +// which are performed on the list of policies across RustyVault. +// The resulting collection will have no duplicate elements. +// If 'root' policy was present in the list of policies, then +// all other policies will be ignored, the result will contain +// just the 'root'. In cases where 'root' is not present, if +// 'default' policy is not already present, it will be added +// if add_default is set to true. +pub fn sanitize_policies(policies: &mut Vec, add_default: bool) { + let mut default_found = false; + for p in policies.iter() { + let q = p.trim().to_lowercase(); + if q.is_empty() { + continue; + } + + // If 'root' policy is present, ignore all other policies. + if q == "root" { + policies.clear(); + policies.push("root".to_string()); + default_found = true; + break; + } + if q == "default" { + default_found = true; + } + } + + // Always add 'default' except only if the policies contain 'root'. + if add_default && (!default_found || policies.is_empty()) { + policies.push("default".to_string()); + } + + remove_duplicates(policies, false, true) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_policy() { + let mut policies1 = vec![ + String::from("Root"), + String::from("root"), + String::from("Admin"), + String::from("Default"), + String::from(""), + String::from("Admin"), + ]; + + sanitize_policies(&mut policies1, true); + assert_eq!(policies1, vec!["root".to_string()]); + + let mut policies2 = vec![ + String::from("rooot"), + String::from("Admin"), + String::from("Default"), + String::from(""), + String::from("Admin"), + ]; + + sanitize_policies(&mut policies2, true); + assert_eq!(policies2, vec!["admin".to_string(), "default".to_string(), "rooot".to_string()]); + + let mut policies3 = vec![String::from("rooot"), String::from("Admin"), String::from(""), String::from("Admin")]; + + sanitize_policies(&mut policies3, true); + assert_eq!(policies3, vec!["admin".to_string(), "default".to_string(), "rooot".to_string()]); + + let mut policies4 = vec![String::from("")]; + + sanitize_policies(&mut policies4, true); + assert_eq!(policies4, vec!["default".to_string()]); + + let mut policies5 = Vec::new(); + + sanitize_policies(&mut policies5, true); + assert_eq!(policies5, vec!["default".to_string()]); + + let mut policies6 = Vec::new(); + + sanitize_policies(&mut policies6, false); + assert_eq!(policies6.len(), 0); + + let mut policies7 = vec![String::from("rooot"), String::from("Admin"), String::from(""), String::from("Admin")]; + + sanitize_policies(&mut policies7, false); + assert_eq!(policies7, vec!["admin".to_string(), "rooot".to_string()]); + } +} diff --git a/src/utils/salt.rs b/src/utils/salt.rs index 313bad0..c6e8ae1 100644 --- a/src/utils/salt.rs +++ b/src/utils/salt.rs @@ -1,21 +1,18 @@ //! This module is a Rust replica of //! +use derivative::Derivative; use openssl::{ hash::{hash, MessageDigest}, - pkey::PKey, nid::Nid, + pkey::PKey, sign::Signer, }; -use derivative::Derivative; - -use super::{ - generate_uuid, -}; +use super::generate_uuid; use crate::{ - storage::{Storage, StorageEntry}, errors::RvError, + storage::{Storage, StorageEntry}, }; static DEFAULT_LOCATION: &str = "salt"; @@ -31,19 +28,15 @@ pub struct Salt { #[derivative(Debug, Clone)] pub struct Config { pub location: String, - #[derivative(Debug="ignore")] + #[derivative(Debug = "ignore")] pub hash_type: MessageDigest, - #[derivative(Debug="ignore")] + #[derivative(Debug = "ignore")] pub hmac_type: MessageDigest, } impl Default for Salt { fn default() -> Self { - Self { - salt: generate_uuid(), - generated: true, - config: Config::default(), - } + Self { salt: generate_uuid(), generated: true, config: Config::default() } } } @@ -79,10 +72,7 @@ impl Salt { salt.salt = String::from_utf8_lossy(&raw.value).to_string(); salt.generated = false; } else { - let entry = StorageEntry { - key: salt.config.location.clone(), - value: salt.salt.as_bytes().to_vec(), - }; + let entry = StorageEntry { key: salt.config.location.clone(), value: salt.salt.as_bytes().to_vec() }; s.put(&entry)?; } @@ -137,16 +127,13 @@ impl Salt { #[cfg(test)] mod test { use std::{collections::HashMap, env, fs, sync::Arc}; + use go_defer::defer; use rand::{thread_rng, Rng}; use serde_json::Value; + use super::*; - use crate::{ - storage::{ - barrier_view, barrier_aes_gcm, - barrier::SecurityBarrier, - } - }; + use crate::storage::{barrier::SecurityBarrier, barrier_aes_gcm, barrier_view}; #[test] fn test_salt() { @@ -204,6 +191,6 @@ mod test { let sid1 = sid1.unwrap(); let sid2 = sid2.unwrap(); assert_eq!(sid1, sid2); - assert_eq!(sid1.len(), salt.config.hash_type.size()*2); + assert_eq!(sid1.len(), salt.config.hash_type.size() * 2); } } diff --git a/src/utils/sock_addr.rs b/src/utils/sock_addr.rs index 57b9856..779c5ab 100644 --- a/src/utils/sock_addr.rs +++ b/src/utils/sock_addr.rs @@ -1,19 +1,12 @@ //! This module is a Rust replica of //! -use std::{ - fmt, - str::FromStr, -}; +use std::{fmt, str::FromStr}; use as_any::AsAny; -use serde::{Deserialize, Serialize, Deserializer, Serializer}; - -use super::{ - ip_sock_addr::IpSockAddr, - unix_sock_addr::UnixSockAddr, -}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use super::{ip_sock_addr::IpSockAddr, unix_sock_addr::UnixSockAddr}; use crate::errors::RvError; pub trait CloneBox { @@ -61,9 +54,7 @@ impl SockAddrMarshaler { pub fn from_str(s: &str) -> Result { let sock_addr = new_sock_addr(s)?; - Ok(SockAddrMarshaler { - sock_addr: sock_addr, - }) + Ok(SockAddrMarshaler { sock_addr }) } } @@ -96,9 +87,7 @@ impl<'de> Deserialize<'de> for SockAddrMarshaler { { let s = String::deserialize(deserializer)?; let sock_addr = new_sock_addr(&s).map_err(serde::de::Error::custom)?; - Ok(SockAddrMarshaler { - sock_addr: sock_addr, - }) + Ok(SockAddrMarshaler { sock_addr }) } } @@ -121,7 +110,7 @@ impl FromStr for SockAddrType { "IPv4" | "ipv4" => Ok(SockAddrType::IPv4), "IPv6" | "ipv6" => Ok(SockAddrType::IPv6), "Unix" | "UNIX" | "unix" => Ok(SockAddrType::Unix), - _ => Err(RvError::ErrResponse("invalid sockaddr type".to_string())) + _ => Err(RvError::ErrResponse("invalid sockaddr type".to_string())), } } } @@ -141,11 +130,8 @@ pub fn new_sock_addr(s: &str) -> Result, RvError> { #[cfg(test)] mod test { use super::{ - *, super::{ - sock_addr::{SockAddrType}, - ip_sock_addr::IpSockAddr, - unix_sock_addr::UnixSockAddr, - }, + super::{ip_sock_addr::IpSockAddr, sock_addr::SockAddrType, unix_sock_addr::UnixSockAddr}, + *, }; #[test] diff --git a/src/utils/string.rs b/src/utils/string.rs new file mode 100644 index 0000000..3b57328 --- /dev/null +++ b/src/utils/string.rs @@ -0,0 +1,69 @@ +use std::collections::HashSet; + +pub fn remove_duplicates(strings: &mut Vec, stable: bool, lowercase: bool) { + if stable { + let mut seen = HashSet::new(); + let mut i = 0; + while i < strings.len() { + if lowercase { + strings[i].make_ascii_lowercase(); + } + if strings[i].trim().is_empty() || !seen.insert(strings[i].clone()) { + strings.remove(i); + } else { + i += 1; + } + } + } else { + if lowercase { + strings.iter_mut().for_each(|s| s.make_ascii_lowercase()); + } + strings.retain(|s| !s.trim().is_empty()); + strings.sort(); + strings.dedup(); + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_string() { + let strings = vec![ + String::from("Orange"), + String::from("Apple"), + String::from("banana"), + String::from(""), + String::from("banana"), + String::from(""), + String::from(""), + String::from(""), + String::from("orange"), + String::from(""), + ]; + + let mut strings1 = strings.clone(); + let mut strings2 = strings.clone(); + let mut strings3 = strings.clone(); + let mut strings4 = strings.clone(); + + remove_duplicates(&mut strings1, true, true); + assert_eq!(strings1, vec!["orange".to_string(), "apple".to_string(), "banana".to_string()]); + + remove_duplicates(&mut strings2, true, false); + assert_eq!( + strings2, + vec!["Orange".to_string(), "Apple".to_string(), "banana".to_string(), "orange".to_string()] + ); + + remove_duplicates(&mut strings3, false, true); + assert_eq!(strings3, vec!["apple".to_string(), "banana".to_string(), "orange".to_string()]); + + remove_duplicates(&mut strings4, false, false); + assert_eq!( + strings4, + vec!["Apple".to_string(), "Orange".to_string(), "banana".to_string(), "orange".to_string()] + ); + } +} diff --git a/src/utils/token_util.rs b/src/utils/token_util.rs new file mode 100644 index 0000000..7596584 --- /dev/null +++ b/src/utils/token_util.rs @@ -0,0 +1,283 @@ +use std::{collections::HashMap, sync::Arc, time::Duration}; + +use serde::{Deserialize, Serialize}; +use serde_json::{json, Map, Value}; + +use crate::{ + errors::RvError, + logical::{field::FieldTrait, Auth, Field, FieldType, Request}, + new_fields, new_fields_internal, + utils::{deserialize_duration, serialize_duration, sock_addr::SockAddrMarshaler}, +}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenParams { + pub token_type: String, + #[serde(serialize_with = "serialize_duration", deserialize_with = "deserialize_duration")] + pub token_ttl: Duration, + #[serde(serialize_with = "serialize_duration", deserialize_with = "deserialize_duration")] + pub token_max_ttl: Duration, + #[serde(serialize_with = "serialize_duration", deserialize_with = "deserialize_duration")] + pub token_explicit_max_ttl: Duration, + #[serde(serialize_with = "serialize_duration", deserialize_with = "deserialize_duration")] + pub token_period: Duration, + pub token_no_default_policy: bool, + pub token_num_uses: u64, + pub token_policies: Vec, + pub token_bound_cidrs: Vec, +} + +impl Default for TokenParams { + fn default() -> Self { + TokenParams { + token_type: String::new(), + token_ttl: Duration::from_secs(0), + token_max_ttl: Duration::from_secs(0), + token_explicit_max_ttl: Duration::from_secs(0), + token_period: Duration::from_secs(0), + token_no_default_policy: false, + token_num_uses: 0, + token_policies: Vec::new(), + token_bound_cidrs: Vec::new(), + } + } +} + +pub fn token_fields() -> HashMap> { + let fields = new_fields!({ + "token_type": { + field_type: FieldType::Str, + default: "default", + description: "The type of token to generate, service or batch" + }, + "token_ttl": { + field_type: FieldType::DurationSecond, + description: "The initial ttl of the token to generate" + }, + "token_max_ttl": { + field_type: FieldType::DurationSecond, + description: "The maximum lifetime of the generated token" + }, + "token_explicit_max_ttl": { + field_type: FieldType::DurationSecond, + description: r#"If set, tokens created via this role carry an explicit maximum TTL. +During renewal, the current maximum TTL values of the role and the mount are not checked for changes, +and any updates to these values will have no effect on the token being renewed."# + }, + "token_period": { + field_type: FieldType::DurationSecond, + description: r#"If set, tokens created via this role will have no max lifetime; +instead, their renewal period will be fixed to this value. This takes an integer number of seconds, +or a string duration (e.g. "24h")."# + }, + "token_no_default_policy": { + field_type: FieldType::Bool, + description: "If true, the 'default' policy will not automatically be added to generated tokens" + }, + "token_policies": { + field_type: FieldType::CommaStringSlice, + description: "Comma-separated list of policies" + }, + "token_bound_cidrs": { + field_type: FieldType::CommaStringSlice, + required: false, + description: r#"Comma separated string or JSON list of CIDR blocks. If set, specifies the blocks of IP addresses which are allowed to use the generated token."# + }, + "token_num_uses": { + field_type: FieldType::Int, + description: "The maximum number of times a token may be used, a value of zero means unlimited" + } + }); + + fields +} + +impl TokenParams { + pub fn new(token_type: &str) -> Self { + Self { token_type: token_type.to_string(), ..TokenParams::default() } + } + + pub fn parse_token_fields(&mut self, req: &Request) -> Result<(), RvError> { + if let Ok(ttl_value) = req.get_data("token_ttl") { + self.token_ttl = ttl_value.as_duration().ok_or(RvError::ErrRequestFieldInvalid)?; + } + + if let Ok(max_ttl_value) = req.get_data("token_max_ttl") { + self.token_max_ttl = max_ttl_value.as_duration().ok_or(RvError::ErrRequestFieldInvalid)?; + } + + if let Ok(explicit_max_ttl_value) = req.get_data("token_explicit_max_ttl") { + self.token_explicit_max_ttl = + explicit_max_ttl_value.as_duration().ok_or(RvError::ErrRequestFieldInvalid)?; + } + + if let Ok(period_value) = req.get_data("token_period") { + self.token_period = period_value.as_duration().ok_or(RvError::ErrRequestFieldInvalid)?; + } + + if let Ok(no_default_policy_value) = req.get_data("token_no_default_policy") { + self.token_no_default_policy = no_default_policy_value.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + } + + if let Ok(num_uses_value) = req.get_data("token_num_uses") { + self.token_num_uses = num_uses_value.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + } + + println!("111"); + if let Ok(type_value) = req.get_data_or_default("token_type") { + let token_type = type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + self.token_type = match token_type.as_str() { + "" => "default".to_string(), + "default-service" => "service".to_string(), + "default-batch" => "batch".to_string(), + _ => token_type.clone(), + }; + + match self.token_type.as_str() { + "default" | "service" | "batch" => {} + _ => { + return Err(RvError::ErrRequestFieldInvalid); + } + }; + } + println!("222"); + + if let Ok(policies_value) = req.get_data("token_policies") { + self.token_policies = policies_value.as_comma_string_slice().ok_or(RvError::ErrRequestFieldInvalid)?; + } + + if let Ok(token_bound_cidrs_value) = req.get_data("token_bound_cidrs") { + let token_bound_cidrs = + token_bound_cidrs_value.as_comma_string_slice().ok_or(RvError::ErrRequestFieldInvalid)?; + self.token_bound_cidrs = token_bound_cidrs + .iter() + .map(|s| SockAddrMarshaler::from_str(s)) + .collect::, _>>()?; + } + + Ok(()) + } + + pub fn populate_token_data(&self, data: &mut Map) { + data.insert("token_type".to_string(), json!(self.token_type.clone())); + data.insert("token_ttl".to_string(), json!(self.token_ttl.as_secs())); + data.insert("token_max_ttl".to_string(), json!(self.token_max_ttl.as_secs())); + data.insert("token_explicit_max_ttl".to_string(), json!(self.token_explicit_max_ttl.as_secs())); + data.insert("token_period".to_string(), json!(self.token_period.as_secs())); + data.insert("token_no_default_policy".to_string(), json!(self.token_no_default_policy)); + data.insert("token_num_uses".to_string(), json!(self.token_num_uses)); + data.insert("token_policies".to_string(), json!(self.token_policies)); + data.insert("token_bound_cidrs".to_string(), json!(self.token_bound_cidrs)); + } + + pub fn populate_token_auth(&self, auth: &mut Auth) { + auth.ttl = self.token_ttl; + auth.max_ttl = self.token_max_ttl; + auth.policies = self.token_policies.clone(); + auth.renewable = true; + } +} + +#[cfg(test)] +mod test { + use std::{collections::HashMap, env, fs, sync::Arc}; + + use go_defer::defer; + use serde_json::json; + + use super::*; + use crate::{ + logical::{Operation, Path}, + storage::{self, barrier_aes_gcm::AESGCMBarrier}, + }; + + #[test] + fn test_token_util() { + let dir = env::temp_dir().join("rusty_vault_test_token_util"); + assert!(fs::create_dir(&dir).is_ok()); + defer! ( + assert!(fs::remove_dir_all(&dir).is_ok()); + ); + + let mut conf: HashMap = HashMap::new(); + conf.insert("path".to_string(), Value::String(dir.to_string_lossy().into_owned())); + + let backend = storage::new_backend("file", &conf).unwrap(); + + let barrier = AESGCMBarrier::new(Arc::clone(&backend)); + + let token_fields = token_fields(); + let mut path = Path::new("/"); + path.fields = token_fields; + + let mut req = Request::new("/"); + req.operation = Operation::Write; + req.storage = Some(Arc::new(barrier)); + req.match_path = Some(Arc::new(path)); + + req.path = "/2/foo/goo".to_string(); + + let req_body = json!({ + "token_type": "default", + "token_ttl": "60", + "token_max_ttl": 600, + "token_explicit_max_ttl": 800, + "token_no_default_policy": true, + "token_num_uses": 100, + "token_policies": "aa,bb,cc", + "token_bound_cidrs": ["192.168.1.1:8080","10.0.0.1:80"], + }); + req.body = Some(req_body.as_object().unwrap().clone()); + + let mut token_params = TokenParams::new("tt1"); + let ret = token_params.parse_token_fields(&req); + println!("ret: {:?}", ret); + assert!(ret.is_ok()); + println!("token_params: {:?}", token_params); + + let mut token_params_map: Map = Map::new(); + token_params.populate_token_data(&mut token_params_map); + println!("token_params_map: {:?}", token_params_map); + + assert_eq!(req_body["token_type"], token_params_map["token_type"]); + assert_eq!(req_body["token_ttl"].as_int(), token_params_map["token_ttl"].as_int()); + assert_eq!(req_body["token_max_ttl"].as_int(), token_params_map["token_max_ttl"].as_int()); + assert_eq!(req_body["token_explicit_max_ttl"].as_int(), token_params_map["token_explicit_max_ttl"].as_int()); + assert_eq!(req_body["token_no_default_policy"], token_params_map["token_no_default_policy"]); + assert_eq!(req_body["token_num_uses"].as_int(), token_params_map["token_num_uses"].as_int()); + let token_policies = token_params_map["token_policies"] + .as_array() + .map(|vec| vec.iter().filter_map(|val| val.as_str().map(|s| s.to_string())).collect()); + let token_bound_cidrs = token_params_map["token_bound_cidrs"] + .as_array() + .map(|vec| vec.iter().filter_map(|val| val.as_str().map(|s| s.to_string())).collect()); + assert_eq!(req_body["token_policies"].as_comma_string_slice(), token_policies); + assert_eq!(req_body["token_bound_cidrs"].as_comma_string_slice(), token_bound_cidrs); + + let req_body = json!({ + "token_type": "service", + "token_ttl": "60", + "token_max_ttl": 600, + "token_explicit_max_ttl": 800, + "token_no_default_policy": true, + "token_num_uses": 100, + }); + req.body = Some(req_body.as_object().unwrap().clone()); + + let mut token_params = TokenParams::new("tt2"); + let ret = token_params.parse_token_fields(&req); + assert!(ret.is_ok()); + println!("token_params: {:?}", token_params); + + let mut token_params_map: Map = Map::new(); + token_params.populate_token_data(&mut token_params_map); + println!("token_params_map: {:?}", token_params_map); + + assert_eq!(req_body["token_type"], token_params_map["token_type"]); + assert_eq!(req_body["token_ttl"].as_int(), token_params_map["token_ttl"].as_int()); + assert_eq!(req_body["token_max_ttl"].as_int(), token_params_map["token_max_ttl"].as_int()); + assert_eq!(req_body["token_explicit_max_ttl"].as_int(), token_params_map["token_explicit_max_ttl"].as_int()); + assert_eq!(req_body["token_no_default_policy"], token_params_map["token_no_default_policy"]); + assert_eq!(req_body["token_num_uses"].as_int(), token_params_map["token_num_uses"].as_int()); + } +} diff --git a/src/utils/unix_sock_addr.rs b/src/utils/unix_sock_addr.rs index 27eaa2b..4e2b3ce 100644 --- a/src/utils/unix_sock_addr.rs +++ b/src/utils/unix_sock_addr.rs @@ -2,13 +2,11 @@ //! use std::fmt; + use as_any::Downcast; use serde::{Deserialize, Serialize}; -use super::{ - sock_addr::{SockAddr, SockAddrType}, -}; - +use super::sock_addr::{SockAddr, SockAddrType}; use crate::errors::RvError; #[derive(Clone, Debug, Serialize, Deserialize)] @@ -20,11 +18,13 @@ impl UnixSockAddr { pub fn new(s: &str) -> Result { // Check to make sure the string begins with either a '.' or '/', or contains a '/'. if s.len() > 1 && (s[0..1].contains('.') || s[0..1].contains('/') || s.contains('/')) { - Ok(Self { - path: s.to_string(), - }) + Ok(Self { path: s.to_string() }) } else { - Err(RvError::ErrResponse(format!("Unable to convert {} to a UNIX Socke, make sure the string begins with either a '.' or '/', or contains a '/'", s))) + Err(RvError::ErrResponse(format!( + "Unable to convert {} to a UNIX Socke, make sure the string begins with either a '.' or '/', or \ + contains a '/'", + s + ))) } } } @@ -59,9 +59,7 @@ impl fmt::Display for UnixSockAddr { #[cfg(test)] mod test { - use super::{ - *, super::sock_addr::{SockAddrType}, - }; + use super::{super::sock_addr::SockAddrType, *}; #[test] fn test_unix_sock_addr() {