diff --git a/Cargo.toml b/Cargo.toml index b055ccb..025680e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,6 +58,10 @@ glob = "0.3" serde_asn1_der = "0.8" base64 = "0.22" ipnetwork = "0.20" +blake2b_simd = "1.0" +derive_more = "0.99.17" +dashmap = "5.5" +tokio = "1.38" # optional dependencies openssl = { version = "0.10", optional = true } diff --git a/src/context.rs b/src/context.rs index 3758da7..06ddb93 100644 --- a/src/context.rs +++ b/src/context.rs @@ -3,27 +3,35 @@ use std::{ any::Any, - cell::RefCell, - collections::HashMap, - sync::{Arc, Mutex}, + sync::{Arc, RwLock}, }; +use dashmap::DashMap; + +#[derive(Debug)] pub struct Context { - data_map: Mutex>>>, + data_map: DashMap>, + data_map_mut: DashMap>>, } impl Context { pub fn new() -> Self { - Self { data_map: Mutex::new(HashMap::new()) } + Self { data_map: DashMap::new(), data_map_mut: DashMap::new() } + } + + pub fn set_mut(&self, key: &str, data: Arc>) { + self.data_map_mut.insert(key.to_string(), data); + } + + pub fn get_mut(&self, key: &str) -> Option>> { + self.data_map_mut.get(key).map(|r| Arc::clone(&r.value())) } - pub fn set(&self, key: &str, data: Arc>) { - let mut data_map = self.data_map.lock().unwrap(); - data_map.insert(key.to_string(), data); + pub fn set(&self, key: &str, data: Arc) { + self.data_map.insert(key.to_string(), data); } - pub fn get(&self, key: &str) -> Option>> { - let data_map = self.data_map.lock().unwrap(); - data_map.get(key).cloned() + pub fn get(&self, key: &str) -> Option> { + self.data_map.get(key).map(|r| Arc::clone(&*r)) } } diff --git a/src/core.rs b/src/core.rs index 37d2cc6..2a4df2b 100644 --- a/src/core.rs +++ b/src/core.rs @@ -9,13 +9,14 @@ use std::{ collections::HashMap, - sync::{Arc, Mutex, RwLock}, ops::{Deref, DerefMut}, + sync::{Arc, Mutex, RwLock}, }; use as_any::Downcast; use go_defer::defer; use serde::{Deserialize, Serialize}; +use zeroize::Zeroizing; use crate::{ cli::config::Config, @@ -23,22 +24,22 @@ use crate::{ handler::Handler, logical::{Backend, Request, Response}, module_manager::ModuleManager, - modules::{auth::AuthModule, credential::userpass::UserPassModule, pki::PkiModule}, + modules::{ + auth::AuthModule, + credential::{ + userpass::UserPassModule, + }, + pki::PkiModule, + }, mount::MountTable, router::Router, shamir::{ShamirSecret, SHAMIR_OVERHEAD}, storage::{ - barrier::SecurityBarrier, - barrier_aes_gcm, - barrier_view::BarrierView, - physical, - Backend as PhysicalBackend, - BackendEntry as PhysicalBackendEntry, + barrier::SecurityBarrier, barrier_aes_gcm, barrier_view::BarrierView, physical, Backend as PhysicalBackend, + BackendEntry as PhysicalBackendEntry, Storage, }, }; -use zeroize::Zeroizing; - pub type LogicalBackendNewFunc = dyn Fn(Arc>) -> Result, RvError> + Send + Sync; pub const SEAL_CONFIG_PATH: &str = "core/seal-config"; @@ -156,8 +157,11 @@ impl Core { if seal_config.secret_shares == 1 { init_result.secret_shares.deref_mut().push(master_key.deref().clone()); } else { - init_result.secret_shares = - ShamirSecret::split(master_key.deref().as_slice(), seal_config.secret_shares, seal_config.secret_threshold)?; + init_result.secret_shares = ShamirSecret::split( + master_key.deref().as_slice(), + seal_config.secret_shares, + seal_config.secret_threshold, + )?; } log::debug!("master_key: {}", hex::encode(master_key.deref())); @@ -197,6 +201,14 @@ impl Core { Ok(init_result) } + pub fn get_system_view(&self) -> Option> { + self.system_view.clone() + } + + pub fn get_system_storage(&self) -> &dyn Storage { + self.system_view.as_ref().unwrap().as_storage() + } + pub fn get_logical_backend(&self, logical_type: &str) -> Result, RvError> { let logical_backends = self.logical_backends.lock().unwrap(); if let Some(backend) = logical_backends.get(logical_type) { diff --git a/src/logical/backend.rs b/src/logical/backend.rs index 3fcfaaa..e8809b5 100644 --- a/src/logical/backend.rs +++ b/src/logical/backend.rs @@ -4,10 +4,14 @@ use regex::Regex; use serde_json::{Map, Value}; use super::{path::Path, request::Request, response::Response, secret::Secret, FieldType, Backend, Operation}; -use crate::errors::RvError; +use crate::{ + context::Context, errors::RvError +}; type BackendOperationHandler = dyn Fn(&dyn Backend, &mut Request) -> Result, RvError> + Send + Sync; +pub const CTX_KEY_BACKEND_PATH: &str = "backend.path"; + #[derive(Clone)] pub struct LogicalBackend { pub paths: Vec>, @@ -17,6 +21,7 @@ pub struct LogicalBackend { pub help: String, pub secrets: Vec>, pub auth_renew_handler: Option>, + pub ctx: Arc, } impl Backend for LogicalBackend { @@ -58,6 +63,10 @@ impl Backend for LogicalBackend { Some(self.root_paths.clone()) } + fn get_ctx(&self) -> Option> { + Some(Arc::clone(&self.ctx)) + } + fn handle_request(&self, req: &mut Request) -> Result, RvError> { if req.storage.is_none() { return Err(RvError::ErrRequestNotReady); @@ -86,6 +95,7 @@ impl Backend for LogicalBackend { req.match_path = Some(path.clone()); for operation in &path.operations { if operation.op == req.operation { + self.ctx.set(CTX_KEY_BACKEND_PATH, path.clone()); let ret = operation.handle_request(self, req); self.clear_secret_field(req); return ret; @@ -113,6 +123,7 @@ impl LogicalBackend { help: String::new(), secrets: Vec::new(), auth_renew_handler: None, + ctx: Arc::new(Context::new()), } } diff --git a/src/logical/field.rs b/src/logical/field.rs index f2bd91f..8fefeb8 100644 --- a/src/logical/field.rs +++ b/src/logical/field.rs @@ -1,9 +1,10 @@ -use std::{fmt, time::Duration}; +use std::{fmt, time::Duration, collections::HashMap}; use enum_map::Enum; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use strum::{Display, EnumString}; +use humantime::parse_duration; use crate::errors::RvError; @@ -39,9 +40,11 @@ pub trait FieldTrait { fn is_int(&self) -> bool; fn is_duration(&self) -> bool; fn is_comma_string_slice(&self) -> bool; + fn is_map(&self) -> bool; fn as_int(&self) -> Option; fn as_duration(&self) -> Option; fn as_comma_string_slice(&self) -> Option>; + fn as_map(&self) -> Option>; } impl FieldTrait for Value { @@ -68,17 +71,15 @@ impl FieldTrait for Value { return true; } - let secs_str = self.as_str(); - if secs_str.is_none() { - return false; - } - - let secs = secs_str.unwrap().parse::().ok(); - if secs.is_none() { - return false; + if let Some(secs_str) = self.as_str() { + if secs_str.parse::().ok().is_some() { + return true; + } else if parse_duration(secs_str).is_ok() { + return true; + } } - true + false } fn is_comma_string_slice(&self) -> bool { @@ -115,6 +116,20 @@ impl FieldTrait for Value { false } + fn is_map(&self) -> bool { + if self.is_object() { + return true; + } + + let map_str = self.as_str(); + if map_str.is_none() { + return false; + } + + let map = serde_json::from_str::(map_str.unwrap()); + return map.is_ok() && map.unwrap().is_object(); + } + fn as_int(&self) -> Option { let mut int = self.as_i64(); if int.is_none() { @@ -133,19 +148,19 @@ impl FieldTrait for Value { } fn as_duration(&self) -> Option { - let mut secs = self.as_u64(); - if secs.is_none() { - let secs_str = self.as_str(); - if secs_str.is_none() { - return None; - } + if let Some(secs) = self.as_u64() { + return Some(Duration::from_secs(secs)) + } - secs = secs_str.unwrap().parse::().ok(); - if secs.is_none() { - return None; + if let Some(secs_str) = self.as_str() { + if let Some(secs_int) = secs_str.parse::().ok() { + return Some(Duration::from_secs(secs_int)); + } else if let Ok(ret) = parse_duration(secs_str) { + return Some(ret); } } - Some(Duration::from_secs(secs.unwrap())) + + None } fn as_comma_string_slice(&self) -> Option> { @@ -185,6 +200,27 @@ impl FieldTrait for Value { None } + + fn as_map(&self) -> Option> { + let mut ret: HashMap = HashMap::new(); + if let Some(map) = self.as_object() { + for (key, value) in map { + if !value.is_string() { + continue; + } + ret.insert(key.clone(), value.as_str().unwrap().to_string()); + } + + return Some(ret); + } + + if let Some(value) = self.as_str() { + let map: HashMap = serde_json::from_str(value).unwrap_or(HashMap::new()); + return Some(map); + } + + None + } } impl Field { @@ -378,4 +414,81 @@ mod test { let val = field.get_default().unwrap(); assert_eq!(val.as_comma_string_slice(), Some(val_str.iter().map(|&s| s.to_string()).collect::>())); } + + #[test] + fn test_field_trait() { + let mut val = json!("45"); + assert!(val.is_int()); + assert_eq!(val.as_int(), Some(45)); + assert!(val.is_duration()); + assert_eq!(val.as_duration(), Some(Duration::from_secs(45))); + + val = json!(50); + assert!(val.is_int()); + assert_eq!(val.as_int(), Some(50)); + assert!(val.is_duration()); + assert_eq!(val.as_duration(), Some(Duration::from_secs(50))); + + val = json!("45s"); + assert!(!val.is_int()); + assert_eq!(val.as_int(), None); + assert!(val.is_duration()); + assert_eq!(val.as_duration(), Some(Duration::from_secs(45))); + + val = json!("5m"); + assert!(!val.is_int()); + assert_eq!(val.as_int(), None); + assert!(val.is_duration()); + assert_eq!(val.as_duration(), Some(Duration::from_secs(300))); + + val = json!("aa, bb, cc ,dd"); + assert!(val.is_comma_string_slice()); + assert_eq!(val.as_comma_string_slice(), Some(vec!["aa".to_string(), "bb".to_string(), "cc".to_string(), "dd".to_string()])); + + val = json!(["aaa", " bbb", "ccc " , " ddd"]); + assert!(val.is_comma_string_slice()); + assert_eq!(val.as_comma_string_slice(), Some(vec!["aaa".to_string(), "bbb".to_string(), "ccc".to_string(), "ddd".to_string()])); + + val = json!([11, 22, 33, 44]); + assert!(val.is_comma_string_slice()); + assert_eq!(val.as_comma_string_slice(), Some(vec!["11".to_string(), "22".to_string(), "33".to_string(), "44".to_string()])); + + val = json!([11, "aa22", 33, 44]); + assert!(val.is_comma_string_slice()); + assert_eq!(val.as_comma_string_slice(), Some(vec!["11".to_string(), "aa22".to_string(), "33".to_string(), "44".to_string()])); + + val = json!("aa, bb, cc ,dd, , 88,"); + assert!(val.is_comma_string_slice()); + assert_eq!(val.as_comma_string_slice(), Some(vec!["aa".to_string(), "bb".to_string(), "cc".to_string(), "dd".to_string(), "88".to_string()])); + + let mut map: HashMap = HashMap::new(); + map.insert("k1".to_string(), "v1".to_string()); + map.insert("k2".to_string(), "v2".to_string()); + + val = json!(r#"{"k1": "v1", "k2": "v2"}"#); + assert!(val.is_map()); + assert_eq!(val.as_map(), Some(map.clone())); + + val = serde_json::from_str(r#"{"k1": "v1", "k2": "v2"}"#).unwrap(); + assert!(val.is_map()); + assert_eq!(val.as_map(), Some(map.clone())); + + val = serde_json::from_str(r#"{"k1": "v1", "k2": {"kk2": "vv2"}}"#).unwrap(); + assert!(val.is_map()); + map.remove("k2"); + assert_eq!(val.as_map(), Some(map.clone())); + + map.clear(); + map.insert("tag1".to_string(), "production".to_string()); + val = json!({ + "metadata": "{ \"tag1\": \"production\" }", + "ttl": 600, + "num_uses": 50 + }); + assert!(val.is_object()); + assert!(val.is_map()); + let obj = val.as_object().unwrap(); + assert!(obj["metadata"].is_map()); + assert_eq!(obj["metadata"].as_map(), Some(map)); + } } diff --git a/src/logical/lease.rs b/src/logical/lease.rs index aa7eebd..bdc40b0 100644 --- a/src/logical/lease.rs +++ b/src/logical/lease.rs @@ -6,11 +6,11 @@ use serde::{Deserialize, Serialize}; pub struct Lease { #[serde(rename = "lease")] pub ttl: Duration, + #[serde(skip)] pub max_ttl: Duration, pub renewable: bool, #[serde(skip)] pub increment: Duration, - //pub issue_time: SystemTime, #[serde(skip)] pub issue_time: Option, } @@ -22,7 +22,6 @@ impl Default for Lease { max_ttl: Duration::new(0, 0), renewable: true, increment: Duration::new(0, 0), - //issue_time: SystemTime::now(), issue_time: Some(SystemTime::now()), } } @@ -46,7 +45,6 @@ impl Lease { } pub fn expiration_time(&self) -> SystemTime { - //self.issue_time + self.max_ttl if self.issue_time.is_some() { self.issue_time.unwrap() + self.ttl } else { diff --git a/src/logical/mod.rs b/src/logical/mod.rs index 72e89ae..20fd0dc 100644 --- a/src/logical/mod.rs +++ b/src/logical/mod.rs @@ -17,7 +17,9 @@ use enum_map::Enum; use serde::{Deserialize, Serialize}; use strum::{Display, EnumString}; -use crate::errors::RvError; +use crate::{ + context::Context, errors::RvError, +}; pub mod auth; pub mod backend; @@ -30,7 +32,7 @@ pub mod response; pub mod secret; pub use auth::Auth; -pub use backend::LogicalBackend; +pub use backend::{LogicalBackend, CTX_KEY_BACKEND_PATH}; pub use field::{Field, FieldType}; pub use lease::Lease; pub use path::{Path, PathOperation}; @@ -65,6 +67,7 @@ pub trait Backend: Send + Sync { fn cleanup(&self) -> Result<(), RvError>; fn get_unauth_paths(&self) -> Option>>; fn get_root_paths(&self) -> Option>>; + fn get_ctx(&self) -> Option>; fn handle_request(&self, req: &mut Request) -> Result, RvError>; fn secret(&self, key: &str) -> Option<&Arc>; } diff --git a/src/logical/path.rs b/src/logical/path.rs index 7846702..998f7f4 100644 --- a/src/logical/path.rs +++ b/src/logical/path.rs @@ -1,12 +1,15 @@ use std::{collections::HashMap, fmt, sync::Arc}; use super::{request::Request, response::Response, Backend, Field, Operation}; -use crate::errors::RvError; +use crate::{ + context::Context, errors::RvError +}; type PathOperationHandler = dyn Fn(&dyn Backend, &mut Request) -> Result, RvError> + Send + Sync; #[derive(Debug, Clone)] pub struct Path { + pub ctx: Arc, pub pattern: String, pub fields: HashMap>, pub operations: Vec, @@ -27,7 +30,7 @@ impl fmt::Debug for PathOperation { impl Path { pub fn new(pattern: &str) -> Self { - Self { pattern: pattern.to_string(), fields: HashMap::new(), operations: Vec::new(), help: String::new() } + Self { ctx: Arc::new(Context::new()), pattern: pattern.to_string(), fields: HashMap::new(), operations: Vec::new(), help: String::new() } } pub fn get_field(&self, key: &str) -> Option> { @@ -165,6 +168,7 @@ macro_rules! new_path_internal { ({ $($tt:tt)+ }) => { { let mut path = Path { + ctx: Arc::new(Context::new()), pattern: String::new(), fields: HashMap::new(), operations: Vec::new(), diff --git a/src/logical/request.rs b/src/logical/request.rs index a24dc45..0dde765 100644 --- a/src/logical/request.rs +++ b/src/logical/request.rs @@ -1,6 +1,7 @@ use std::{collections::HashMap, sync::Arc}; use serde_json::{Map, Value}; +use tokio::task::JoinHandle; use super::{Operation, Path}; use crate::{ @@ -23,6 +24,7 @@ pub struct Request { pub connection: Option, pub secret: Option, pub auth: Option, + pub tasks: Vec>, } impl Default for Request { @@ -41,6 +43,7 @@ impl Default for Request { connection: None, secret: None, auth: None, + tasks: Vec::new(), } } } @@ -135,7 +138,7 @@ impl Request { match self.get_data_raw(key, false) { Ok(raw) => { return Ok(raw); - }, + } Err(e) => { if e != RvError::ErrRequestFieldNotFound { return Err(e); @@ -147,6 +150,22 @@ impl Request { return Err(RvError::ErrRequestFieldNotFound); } + pub fn get_data_as_str(&self, key: &str) -> Result { + self.get_data(key)? + .as_str() + .ok_or(RvError::ErrRequestFieldInvalid) + .and_then(|s| if s.trim().is_empty() { + Err(RvError::ErrResponse(format!("missing {}", key))) + } else { + Ok(s.trim().to_string()) + }) + } + + pub fn get_field_default_or_zero(&self, key: &str) -> Result { + let field = self.match_path.as_ref().unwrap().get_field(key).ok_or(RvError::ErrRequestNoDataField)?; + field.get_default() + } + //TODO: the sensitive data is still in the memory. Need to totally resolve this in `serde_json` someday. pub fn clear_data(&mut self, key: &str) { if self.data.is_some() { @@ -197,4 +216,8 @@ impl Request { self.storage.as_ref().unwrap().delete(key) } + + pub fn add_task(&mut self, task: JoinHandle<()>) { + self.tasks.push(task); + } } diff --git a/src/logical/response.rs b/src/logical/response.rs index d812f6a..d1af04a 100644 --- a/src/logical/response.rs +++ b/src/logical/response.rs @@ -1,21 +1,38 @@ use std::collections::HashMap; +use serde::{Deserialize, Serialize}; use serde_json::{json, Map, Value}; +use lazy_static::lazy_static; -use crate::logical::{secret::SecretData, Auth}; +use crate::{ + errors::RvError, + logical::{Auth, secret::SecretData}, +}; -#[derive(Debug, Clone)] +lazy_static! { + static ref HTTP_RAW_BODY: &'static str = "http_raw_body"; + static ref HTTP_CONTENT_TYPE: &'static str = "http_content_type"; + static ref HTTP_STATUS_CODE: &'static str = "http_status_code"; +} + +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Response { + #[serde(default)] + pub request_id: String, + #[serde(skip)] pub headers: Option>, pub data: Option>, pub auth: Option, pub secret: Option, pub redirect: String, + // warnings allow operations or backends to return warnings in response + // to user actions without failing the action outright. + pub warnings: Vec, } impl Default for Response { fn default() -> Self { - Response { headers: None, data: None, auth: None, secret: None, redirect: String::new() } + Response { request_id: String::new(), headers: None, data: None, auth: None, secret: None, redirect: String::new(), warnings: Vec::new(), } } } @@ -75,4 +92,33 @@ impl Response { ); resp } + + pub fn respond_with_status_code(resp: Option, code: u8) -> Self { + let mut ret = Response::new(); + let mut data: Map = json!({ + HTTP_CONTENT_TYPE.to_string(): "application/json", + HTTP_STATUS_CODE.to_string(): code, + }).as_object().unwrap().clone(); + + if let Some(response) = resp { + let raw_body = serde_json::to_value(response).unwrap(); + data.insert(HTTP_RAW_BODY.to_string(), raw_body); + } + + ret.data = Some(data); + + ret + } + + pub fn add_warning(&mut self, warning: &str) { + self.warnings.push(warning.to_string()); + } + + pub fn to_string(&self) -> Result { + Ok(serde_json::to_string(self)?) + } + + pub fn set_request_id(&mut self, id: &str) { + self.request_id = id.to_string() + } } diff --git a/src/logical/secret.rs b/src/logical/secret.rs index 8d7cc69..aefe5bc 100644 --- a/src/logical/secret.rs +++ b/src/logical/secret.rs @@ -14,8 +14,10 @@ type SecretOperationHandler = dyn Fn(&dyn Backend, &mut Request) -> Result, } diff --git a/src/modules/auth/token_store.rs b/src/modules/auth/token_store.rs index 885e2bd..71c4d42 100644 --- a/src/modules/auth/token_store.rs +++ b/src/modules/auth/token_store.rs @@ -12,7 +12,7 @@ use super::{ }; use crate::{ core::Core, - errors::RvError, + context::Context, errors::RvError, handler::Handler, logical::{ Auth, Backend, Field, FieldType, Lease, LogicalBackend, Operation, Path, PathOperation, Request, Response, diff --git a/src/modules/credential/userpass/mod.rs b/src/modules/credential/userpass/mod.rs index 8233228..4b6eaae 100644 --- a/src/modules/credential/userpass/mod.rs +++ b/src/modules/credential/userpass/mod.rs @@ -259,6 +259,7 @@ mod test { #[test] fn test_userpass_module() { let dir = env::temp_dir().join("rusty_vault_credential_userpass_module"); + let _ = fs::remove_dir_all(&dir).is_ok(); assert!(fs::create_dir(&dir).is_ok()); defer! ( assert!(fs::remove_dir_all(&dir).is_ok()); diff --git a/src/modules/credential/userpass/path_login.rs b/src/modules/credential/userpass/path_login.rs index 744d432..b16dc32 100644 --- a/src/modules/credential/userpass/path_login.rs +++ b/src/modules/credential/userpass/path_login.rs @@ -2,7 +2,7 @@ use std::{collections::HashMap, sync::Arc}; use super::{UserPassBackend, UserPassBackendInner}; use crate::{ - errors::RvError, + context::Context, errors::RvError, logical::{Auth, Backend, Field, FieldType, Lease, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, }; diff --git a/src/modules/credential/userpass/path_users.rs b/src/modules/credential/userpass/path_users.rs index ba9ebd1..b936e32 100644 --- a/src/modules/credential/userpass/path_users.rs +++ b/src/modules/credential/userpass/path_users.rs @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize}; use super::{UserPassBackend, UserPassBackendInner}; use crate::{ - errors::RvError, + context::Context, errors::RvError, logical::{Backend, Field, FieldType, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, storage::StorageEntry, diff --git a/src/modules/kv/mod.rs b/src/modules/kv/mod.rs index 441027a..316e17c 100644 --- a/src/modules/kv/mod.rs +++ b/src/modules/kv/mod.rs @@ -13,7 +13,7 @@ use serde_json::{Map, Value}; use crate::{ core::Core, - errors::RvError, + context::Context, errors::RvError, logical::{ secret::Secret, Backend, Field, FieldType, LogicalBackend, Operation, Path, PathOperation, Request, Response, }, diff --git a/src/modules/pki/path_config_ca.rs b/src/modules/pki/path_config_ca.rs index 98120de..0e6fffa 100644 --- a/src/modules/pki/path_config_ca.rs +++ b/src/modules/pki/path_config_ca.rs @@ -8,7 +8,7 @@ use pem; use super::{PkiBackend, PkiBackendInner}; use crate::{ - errors::RvError, + context::Context, errors::RvError, logical::{Backend, Field, FieldType, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, storage::StorageEntry, diff --git a/src/modules/pki/path_config_crl.rs b/src/modules/pki/path_config_crl.rs index 4e7d23b..1112c0f 100644 --- a/src/modules/pki/path_config_crl.rs +++ b/src/modules/pki/path_config_crl.rs @@ -2,7 +2,7 @@ use std::{collections::HashMap, sync::Arc}; use super::{PkiBackend, PkiBackendInner}; use crate::{ - errors::RvError, + context::Context, errors::RvError, logical::{Backend, Field, FieldType, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, }; diff --git a/src/modules/pki/path_fetch.rs b/src/modules/pki/path_fetch.rs index c4932f9..5312335 100644 --- a/src/modules/pki/path_fetch.rs +++ b/src/modules/pki/path_fetch.rs @@ -5,7 +5,7 @@ use serde_json::json; use super::{PkiBackend, PkiBackendInner}; use crate::{ - errors::RvError, + context::Context, errors::RvError, logical::{Backend, Field, FieldType, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, storage::StorageEntry, diff --git a/src/modules/pki/path_issue.rs b/src/modules/pki/path_issue.rs index 9c203d6..136c453 100644 --- a/src/modules/pki/path_issue.rs +++ b/src/modules/pki/path_issue.rs @@ -10,7 +10,7 @@ use serde_json::{json, Map, Value}; use super::{PkiBackend, PkiBackendInner}; use crate::{ - errors::RvError, + context::Context, errors::RvError, logical::{Backend, Field, FieldType, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, utils, utils::cert, diff --git a/src/modules/pki/path_keys.rs b/src/modules/pki/path_keys.rs index f290727..7ff650a 100644 --- a/src/modules/pki/path_keys.rs +++ b/src/modules/pki/path_keys.rs @@ -5,7 +5,7 @@ use serde_json::{json, Value}; use super::{PkiBackend, PkiBackendInner}; use crate::{ - errors::RvError, + context::Context, errors::RvError, logical::{Backend, Field, FieldType, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, storage::StorageEntry, diff --git a/src/modules/pki/path_revoke.rs b/src/modules/pki/path_revoke.rs index 7423c9a..5bbcc41 100644 --- a/src/modules/pki/path_revoke.rs +++ b/src/modules/pki/path_revoke.rs @@ -2,7 +2,7 @@ use std::{collections::HashMap, sync::Arc}; use super::{PkiBackend, PkiBackendInner}; use crate::{ - errors::RvError, + context::Context, errors::RvError, logical::{Backend, Field, FieldType, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, }; diff --git a/src/modules/pki/path_roles.rs b/src/modules/pki/path_roles.rs index 59b1927..da6f728 100644 --- a/src/modules/pki/path_roles.rs +++ b/src/modules/pki/path_roles.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use super::{PkiBackend, PkiBackendInner, util::DEFAULT_MAX_TTL}; use crate::{ - errors::RvError, + context::Context, errors::RvError, logical::{Backend, Field, FieldType, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, storage::StorageEntry, diff --git a/src/modules/pki/path_root.rs b/src/modules/pki/path_root.rs index 1d7806a..56a636f 100644 --- a/src/modules/pki/path_root.rs +++ b/src/modules/pki/path_root.rs @@ -4,7 +4,7 @@ use serde_json::{json, Value}; use super::{field, util, PkiBackend, PkiBackendInner}; use crate::{ - errors::RvError, + context::Context, errors::RvError, logical::{Backend, Operation, Path, PathOperation, Request, Response}, new_path, new_path_internal, utils, }; diff --git a/src/modules/system/mod.rs b/src/modules/system/mod.rs index 9dc448d..6c75d84 100644 --- a/src/modules/system/mod.rs +++ b/src/modules/system/mod.rs @@ -12,7 +12,7 @@ use serde_json::{from_value, json, Map, Value}; use crate::{ core::Core, - errors::RvError, + context::Context, errors::RvError, logical::{Backend, Field, FieldType, LogicalBackend, Operation, Path, PathOperation, Request, Response}, modules::{auth::AuthModule, Module}, mount::MountEntry, diff --git a/src/storage/barrier_aes_gcm.rs b/src/storage/barrier_aes_gcm.rs index d29673d..2aede80 100644 --- a/src/storage/barrier_aes_gcm.rs +++ b/src/storage/barrier_aes_gcm.rs @@ -451,6 +451,7 @@ mod test { #[test] fn test_barriew_storage_api() { let dir = env::temp_dir().join("rusty_vault_test_barriew_storage_api"); + let _ = fs::remove_dir_all(&dir); assert!(fs::create_dir(&dir).is_ok()); defer! ( assert!(fs::remove_dir_all(&dir).is_ok()); diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 02099ca..7af8040 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -28,7 +28,7 @@ pub mod physical; pub mod mysql; /// A trait that abstracts core methods for all storage barrier types. -pub trait Storage { +pub trait Storage: Send + Sync { fn list(&self, prefix: &str) -> Result, RvError>; fn get(&self, key: &str) -> Result, RvError>; fn put(&self, entry: &StorageEntry) -> Result<(), RvError>; diff --git a/src/storage/physical/file.rs b/src/storage/physical/file.rs index 9b2b15e..f848a4f 100644 --- a/src/storage/physical/file.rs +++ b/src/storage/physical/file.rs @@ -147,6 +147,7 @@ mod test { #[test] fn test_file_backend() { let dir = env::temp_dir().join("rusty_vault"); + let _ = fs::remove_dir_all(&dir); assert!(fs::create_dir(&dir).is_ok()); defer! ( assert!(fs::remove_dir_all(&dir).is_ok()); diff --git a/src/utils/ip_sock_addr.rs b/src/utils/ip_sock_addr.rs index e404732..c5bbdf7 100644 --- a/src/utils/ip_sock_addr.rs +++ b/src/utils/ip_sock_addr.rs @@ -38,6 +38,10 @@ impl IpSockAddr { } return Err(RvError::ErrResponse(format!("Unable to parse {} to an IP address:", s))); } + + pub fn to_string(&self) -> String { + format!("{}", self) + } } impl SockAddr for IpSockAddr { @@ -67,11 +71,15 @@ impl SockAddr for IpSockAddr { impl fmt::Display for IpSockAddr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if self.port == 0 { - write!(f, "{}", self.addr.ip()) - } else { - write!(f, "{}:{}", self.addr.ip(), self.port) + if self.port != 0 { + return write!(f, "{}:{}", self.addr.ip(), self.port); } + + if self.addr.prefix() == 32 { + return write!(f, "{}", self.addr.ip()); + } + + write!(f, "{}/{}", self.addr.ip(), self.addr.prefix()) } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 508cd85..54849ce 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -12,11 +12,16 @@ use serde::{Deserialize, Deserializer, Serializer}; use crate::errors::RvError; pub mod cert; +pub mod cidr; +pub mod crypto; +pub mod ip_sock_addr; pub mod key; +pub mod locks; +pub mod policy; pub mod salt; -pub mod cidr; pub mod sock_addr; -pub mod ip_sock_addr; +pub mod string; +pub mod token_util; pub mod unix_sock_addr; pub fn generate_uuid() -> String { @@ -105,11 +110,8 @@ pub fn asn1time_to_timestamp(time_str: &str) -> Result { pub fn hex_encode_with_colon(bytes: &[u8]) -> String { let hex_str = hex::encode(bytes); - let split_hex: Vec = hex_str - .as_bytes() - .chunks(2) - .map(|chunk| String::from_utf8(chunk.to_vec()).unwrap()) - .collect(); + let split_hex: Vec = + hex_str.as_bytes().chunks(2).map(|chunk| String::from_utf8(chunk.to_vec()).unwrap()).collect(); split_hex.join(":") } diff --git a/src/utils/salt.rs b/src/utils/salt.rs index 65e21a7..313bad0 100644 --- a/src/utils/salt.rs +++ b/src/utils/salt.rs @@ -152,6 +152,7 @@ mod test { fn test_salt() { // init the storage let dir = env::temp_dir().join("rusty_vault_test_salt"); + let _ = fs::remove_dir_all(&dir); assert!(fs::create_dir(&dir).is_ok()); defer! ( assert!(fs::remove_dir_all(&dir).is_ok()); diff --git a/src/utils/sock_addr.rs b/src/utils/sock_addr.rs index a73c7ef..57b9856 100644 --- a/src/utils/sock_addr.rs +++ b/src/utils/sock_addr.rs @@ -7,14 +7,28 @@ use std::{ }; use as_any::AsAny; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Serialize, Deserializer, Serializer}; use super::{ ip_sock_addr::IpSockAddr, + unix_sock_addr::UnixSockAddr, }; use crate::errors::RvError; +pub trait CloneBox { + fn clone_box(&self) -> Box; +} + +impl CloneBox for T +where + T: 'static + SockAddr + Clone, +{ + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum SockAddrType { Unknown = 0x0, @@ -25,7 +39,7 @@ pub enum SockAddrType { IP = 0x6, } -pub trait SockAddr: fmt::Display + AsAny { +pub trait SockAddr: fmt::Display + AsAny + fmt::Debug + CloneBox { // contains returns true if the other SockAddr is contained within the receiver fn contains(&self, other: &dyn SockAddr) -> bool; @@ -35,6 +49,7 @@ pub trait SockAddr: fmt::Display + AsAny { fn sock_addr_type(&self) -> SockAddrType; } +#[derive(Debug, Clone)] pub struct SockAddrMarshaler { pub sock_addr: Box, } @@ -43,6 +58,19 @@ impl SockAddrMarshaler { pub fn new(sock_addr: Box) -> Self { SockAddrMarshaler { sock_addr } } + + pub fn from_str(s: &str) -> Result { + let sock_addr = new_sock_addr(s)?; + Ok(SockAddrMarshaler { + sock_addr: sock_addr, + }) + } +} + +impl Clone for Box { + fn clone(&self) -> Box { + self.clone_box() + } } impl fmt::Display for SockAddrMarshaler { @@ -51,6 +79,29 @@ impl fmt::Display for SockAddrMarshaler { } } +impl Serialize for SockAddrMarshaler { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let sock_addr_str = self.sock_addr.to_string(); + serializer.serialize_str(&sock_addr_str) + } +} + +impl<'de> Deserialize<'de> for SockAddrMarshaler { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + let sock_addr = new_sock_addr(&s).map_err(serde::de::Error::custom)?; + Ok(SockAddrMarshaler { + sock_addr: sock_addr, + }) + } +} + impl fmt::Display for SockAddrType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let type_str = match self { @@ -76,8 +127,15 @@ impl FromStr for SockAddrType { } pub fn new_sock_addr(s: &str) -> Result, RvError> { - let ret = IpSockAddr::new(s)?; - Ok(Box::new(ret)) + if let Ok(ip) = IpSockAddr::new(s) { + return Ok(Box::new(ip)); + } + + if let Ok(ip) = UnixSockAddr::new(s) { + return Ok(Box::new(ip)); + } + + Err(RvError::ErrResponse(format!("Unable to convert {} to an IPv4 or IPv6 address, or a UNIX Socket", s))) } #[cfg(test)] @@ -140,5 +198,39 @@ mod test { assert!(!ip_addr1.equal(&unix_addr1)); assert!(!unix_addr1.contains(&ip_addr1)); assert!(!unix_addr1.equal(&ip_addr1)); + + let sock_addr1 = new_sock_addr("1.1.1.1").unwrap(); + let sock_addr2 = new_sock_addr("1.1.1.1").unwrap(); + let sock_addr3 = new_sock_addr("2.2.2.2").unwrap(); + let sock_addr4 = new_sock_addr("333.333.333.333"); + let sock_addr5 = new_sock_addr("1.1.1.1:80").unwrap(); + let sock_addr6 = new_sock_addr("1.1.1.1:80").unwrap(); + let sock_addr7 = new_sock_addr("1.1.1.1:8080").unwrap(); + let sock_addr8 = new_sock_addr("2.2.2.2:80").unwrap(); + let sock_addr9 = new_sock_addr("192.168.0.0/16").unwrap(); + let sock_addr10 = new_sock_addr("192.168.0.0/24").unwrap(); + let sock_addr11 = new_sock_addr("192.168.0.1").unwrap(); + let sock_addr12 = new_sock_addr("192.168.1.1").unwrap(); + assert!(sock_addr1.equal(sock_addr2.as_ref())); + assert!(sock_addr1.contains(sock_addr2.as_ref())); + assert!(!sock_addr1.contains(sock_addr3.as_ref())); + assert!(!sock_addr1.equal(sock_addr3.as_ref())); + assert_eq!(sock_addr1.sock_addr_type(), SockAddrType::IPv4); + assert_eq!(sock_addr1.sock_addr_type(), sock_addr2.sock_addr_type()); + assert_ne!(sock_addr1.sock_addr_type(), unix_addr2.sock_addr_type()); + assert!(sock_addr4.is_err()); + assert!(sock_addr5.contains(sock_addr6.as_ref())); + assert!(sock_addr5.equal(sock_addr6.as_ref())); + assert!(!sock_addr5.equal(sock_addr7.as_ref())); + assert!(!sock_addr5.equal(sock_addr8.as_ref())); + assert!(sock_addr9.contains(sock_addr10.as_ref())); + assert!(sock_addr9.contains(sock_addr11.as_ref())); + assert!(sock_addr9.contains(sock_addr12.as_ref())); + assert!(!sock_addr9.contains(sock_addr1.as_ref())); + assert!(sock_addr10.contains(sock_addr9.as_ref())); + assert!(sock_addr10.contains(sock_addr11.as_ref())); + assert!(!sock_addr10.contains(sock_addr12.as_ref())); + assert!(!sock_addr9.equal(sock_addr10.as_ref())); + assert!(!sock_addr9.equal(sock_addr11.as_ref())); } } diff --git a/src/utils/unix_sock_addr.rs b/src/utils/unix_sock_addr.rs index d63b99d..27eaa2b 100644 --- a/src/utils/unix_sock_addr.rs +++ b/src/utils/unix_sock_addr.rs @@ -18,9 +18,14 @@ pub struct UnixSockAddr { impl UnixSockAddr { pub fn new(s: &str) -> Result { - Ok(Self { - path: s.to_string(), - }) + // Check to make sure the string begins with either a '.' or '/', or contains a '/'. + if s.len() > 1 && (s[0..1].contains('.') || s[0..1].contains('/') || s.contains('/')) { + Ok(Self { + path: s.to_string(), + }) + } else { + Err(RvError::ErrResponse(format!("Unable to convert {} to a UNIX Socke, make sure the string begins with either a '.' or '/', or contains a '/'", s))) + } } }