Skip to content

Commit

Permalink
Improve Context and add several interface functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
wa5i committed Aug 18, 2024
1 parent 8833a13 commit 12151d1
Show file tree
Hide file tree
Showing 33 changed files with 422 additions and 87 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ glob = "0.3"
serde_asn1_der = "0.8"
base64 = "0.22"
ipnetwork = "0.20"
blake2b_simd = "1.0"
derive_more = "0.99.17"
dashmap = "5.5"
tokio = "1.38"

# optional dependencies
openssl = { version = "0.10", optional = true }
Expand Down
30 changes: 19 additions & 11 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,35 @@
use std::{
any::Any,
cell::RefCell,
collections::HashMap,
sync::{Arc, Mutex},
sync::{Arc, RwLock},
};

use dashmap::DashMap;

#[derive(Debug)]
pub struct Context {
data_map: Mutex<HashMap<String, Arc<RefCell<dyn Any>>>>,
data_map: DashMap<String, Arc<dyn Any + Send + Sync>>,
data_map_mut: DashMap<String, Arc<RwLock<dyn Any + Send + Sync>>>,
}

impl Context {
pub fn new() -> Self {
Self { data_map: Mutex::new(HashMap::new()) }
Self { data_map: DashMap::new(), data_map_mut: DashMap::new() }
}

pub fn set_mut(&self, key: &str, data: Arc<RwLock<dyn Any + Send + Sync>>) {
self.data_map_mut.insert(key.to_string(), data);
}

pub fn get_mut(&self, key: &str) -> Option<Arc<RwLock<dyn Any + Send + Sync>>> {
self.data_map_mut.get(key).map(|r| Arc::clone(&r.value()))
}

pub fn set(&self, key: &str, data: Arc<RefCell<dyn Any>>) {
let mut data_map = self.data_map.lock().unwrap();
data_map.insert(key.to_string(), data);
pub fn set(&self, key: &str, data: Arc<dyn Any + Send + Sync>) {
self.data_map.insert(key.to_string(), data);
}

pub fn get(&self, key: &str) -> Option<Arc<RefCell<dyn Any>>> {
let data_map = self.data_map.lock().unwrap();
data_map.get(key).cloned()
pub fn get(&self, key: &str) -> Option<Arc<dyn Any + Send + Sync>> {
self.data_map.get(key).map(|r| Arc::clone(&*r))
}
}
36 changes: 24 additions & 12 deletions src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,36 +9,37 @@
use std::{
collections::HashMap,
sync::{Arc, Mutex, RwLock},
ops::{Deref, DerefMut},
sync::{Arc, Mutex, RwLock},
};

use as_any::Downcast;
use go_defer::defer;
use serde::{Deserialize, Serialize};
use zeroize::Zeroizing;

use crate::{
cli::config::Config,
errors::RvError,
handler::Handler,
logical::{Backend, Request, Response},
module_manager::ModuleManager,
modules::{auth::AuthModule, credential::userpass::UserPassModule, pki::PkiModule},
modules::{
auth::AuthModule,
credential::{
userpass::UserPassModule,
},
pki::PkiModule,
},
mount::MountTable,
router::Router,
shamir::{ShamirSecret, SHAMIR_OVERHEAD},
storage::{
barrier::SecurityBarrier,
barrier_aes_gcm,
barrier_view::BarrierView,
physical,
Backend as PhysicalBackend,
BackendEntry as PhysicalBackendEntry,
barrier::SecurityBarrier, barrier_aes_gcm, barrier_view::BarrierView, physical, Backend as PhysicalBackend,
BackendEntry as PhysicalBackendEntry, Storage,
},
};

use zeroize::Zeroizing;

pub type LogicalBackendNewFunc = dyn Fn(Arc<RwLock<Core>>) -> Result<Arc<dyn Backend>, RvError> + Send + Sync;

pub const SEAL_CONFIG_PATH: &str = "core/seal-config";
Expand Down Expand Up @@ -156,8 +157,11 @@ impl Core {
if seal_config.secret_shares == 1 {
init_result.secret_shares.deref_mut().push(master_key.deref().clone());
} else {
init_result.secret_shares =
ShamirSecret::split(master_key.deref().as_slice(), seal_config.secret_shares, seal_config.secret_threshold)?;
init_result.secret_shares = ShamirSecret::split(
master_key.deref().as_slice(),
seal_config.secret_shares,
seal_config.secret_threshold,
)?;
}

log::debug!("master_key: {}", hex::encode(master_key.deref()));
Expand Down Expand Up @@ -197,6 +201,14 @@ impl Core {
Ok(init_result)
}

pub fn get_system_view(&self) -> Option<Arc<BarrierView>> {
self.system_view.clone()
}

pub fn get_system_storage(&self) -> &dyn Storage {
self.system_view.as_ref().unwrap().as_storage()
}

pub fn get_logical_backend(&self, logical_type: &str) -> Result<Arc<LogicalBackendNewFunc>, RvError> {
let logical_backends = self.logical_backends.lock().unwrap();
if let Some(backend) = logical_backends.get(logical_type) {
Expand Down
13 changes: 12 additions & 1 deletion src/logical/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@ use regex::Regex;
use serde_json::{Map, Value};

use super::{path::Path, request::Request, response::Response, secret::Secret, FieldType, Backend, Operation};
use crate::errors::RvError;
use crate::{
context::Context, errors::RvError
};

type BackendOperationHandler = dyn Fn(&dyn Backend, &mut Request) -> Result<Option<Response>, RvError> + Send + Sync;

pub const CTX_KEY_BACKEND_PATH: &str = "backend.path";

#[derive(Clone)]
pub struct LogicalBackend {
pub paths: Vec<Arc<Path>>,
Expand All @@ -17,6 +21,7 @@ pub struct LogicalBackend {
pub help: String,
pub secrets: Vec<Arc<Secret>>,
pub auth_renew_handler: Option<Arc<BackendOperationHandler>>,
pub ctx: Arc<Context>,
}

impl Backend for LogicalBackend {
Expand Down Expand Up @@ -58,6 +63,10 @@ impl Backend for LogicalBackend {
Some(self.root_paths.clone())
}

fn get_ctx(&self) -> Option<Arc<Context>> {
Some(Arc::clone(&self.ctx))
}

fn handle_request(&self, req: &mut Request) -> Result<Option<Response>, RvError> {
if req.storage.is_none() {
return Err(RvError::ErrRequestNotReady);
Expand Down Expand Up @@ -86,6 +95,7 @@ impl Backend for LogicalBackend {
req.match_path = Some(path.clone());
for operation in &path.operations {
if operation.op == req.operation {
self.ctx.set(CTX_KEY_BACKEND_PATH, path.clone());
let ret = operation.handle_request(self, req);
self.clear_secret_field(req);
return ret;
Expand Down Expand Up @@ -113,6 +123,7 @@ impl LogicalBackend {
help: String::new(),
secrets: Vec::new(),
auth_renew_handler: None,
ctx: Arc::new(Context::new()),
}
}

Expand Down
153 changes: 133 additions & 20 deletions src/logical/field.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::{fmt, time::Duration};
use std::{fmt, time::Duration, collections::HashMap};

use enum_map::Enum;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use strum::{Display, EnumString};
use humantime::parse_duration;

use crate::errors::RvError;

Expand Down Expand Up @@ -39,9 +40,11 @@ pub trait FieldTrait {
fn is_int(&self) -> bool;
fn is_duration(&self) -> bool;
fn is_comma_string_slice(&self) -> bool;
fn is_map(&self) -> bool;
fn as_int(&self) -> Option<i64>;
fn as_duration(&self) -> Option<Duration>;
fn as_comma_string_slice(&self) -> Option<Vec<String>>;
fn as_map(&self) -> Option<HashMap<String, String>>;
}

impl FieldTrait for Value {
Expand All @@ -68,17 +71,15 @@ impl FieldTrait for Value {
return true;
}

let secs_str = self.as_str();
if secs_str.is_none() {
return false;
}

let secs = secs_str.unwrap().parse::<i64>().ok();
if secs.is_none() {
return false;
if let Some(secs_str) = self.as_str() {
if secs_str.parse::<u64>().ok().is_some() {
return true;
} else if parse_duration(secs_str).is_ok() {
return true;
}
}

true
false
}

fn is_comma_string_slice(&self) -> bool {
Expand Down Expand Up @@ -115,6 +116,20 @@ impl FieldTrait for Value {
false
}

fn is_map(&self) -> bool {
if self.is_object() {
return true;
}

let map_str = self.as_str();
if map_str.is_none() {
return false;
}

let map = serde_json::from_str::<Value>(map_str.unwrap());
return map.is_ok() && map.unwrap().is_object();
}

fn as_int(&self) -> Option<i64> {
let mut int = self.as_i64();
if int.is_none() {
Expand All @@ -133,19 +148,19 @@ impl FieldTrait for Value {
}

fn as_duration(&self) -> Option<Duration> {
let mut secs = self.as_u64();
if secs.is_none() {
let secs_str = self.as_str();
if secs_str.is_none() {
return None;
}
if let Some(secs) = self.as_u64() {
return Some(Duration::from_secs(secs))
}

secs = secs_str.unwrap().parse::<u64>().ok();
if secs.is_none() {
return None;
if let Some(secs_str) = self.as_str() {
if let Some(secs_int) = secs_str.parse::<u64>().ok() {
return Some(Duration::from_secs(secs_int));
} else if let Ok(ret) = parse_duration(secs_str) {
return Some(ret);
}
}
Some(Duration::from_secs(secs.unwrap()))

None
}

fn as_comma_string_slice(&self) -> Option<Vec<String>> {
Expand Down Expand Up @@ -185,6 +200,27 @@ impl FieldTrait for Value {

None
}

fn as_map(&self) -> Option<HashMap<String, String>> {
let mut ret: HashMap<String, String> = HashMap::new();
if let Some(map) = self.as_object() {
for (key, value) in map {
if !value.is_string() {
continue;
}
ret.insert(key.clone(), value.as_str().unwrap().to_string());
}

return Some(ret);
}

if let Some(value) = self.as_str() {
let map: HashMap<String, String> = serde_json::from_str(value).unwrap_or(HashMap::new());
return Some(map);
}

None
}
}

impl Field {
Expand Down Expand Up @@ -378,4 +414,81 @@ mod test {
let val = field.get_default().unwrap();
assert_eq!(val.as_comma_string_slice(), Some(val_str.iter().map(|&s| s.to_string()).collect::<Vec<String>>()));
}

#[test]
fn test_field_trait() {
let mut val = json!("45");
assert!(val.is_int());
assert_eq!(val.as_int(), Some(45));
assert!(val.is_duration());
assert_eq!(val.as_duration(), Some(Duration::from_secs(45)));

val = json!(50);
assert!(val.is_int());
assert_eq!(val.as_int(), Some(50));
assert!(val.is_duration());
assert_eq!(val.as_duration(), Some(Duration::from_secs(50)));

val = json!("45s");
assert!(!val.is_int());
assert_eq!(val.as_int(), None);
assert!(val.is_duration());
assert_eq!(val.as_duration(), Some(Duration::from_secs(45)));

val = json!("5m");
assert!(!val.is_int());
assert_eq!(val.as_int(), None);
assert!(val.is_duration());
assert_eq!(val.as_duration(), Some(Duration::from_secs(300)));

val = json!("aa, bb, cc ,dd");
assert!(val.is_comma_string_slice());
assert_eq!(val.as_comma_string_slice(), Some(vec!["aa".to_string(), "bb".to_string(), "cc".to_string(), "dd".to_string()]));

val = json!(["aaa", " bbb", "ccc " , " ddd"]);
assert!(val.is_comma_string_slice());
assert_eq!(val.as_comma_string_slice(), Some(vec!["aaa".to_string(), "bbb".to_string(), "ccc".to_string(), "ddd".to_string()]));

val = json!([11, 22, 33, 44]);
assert!(val.is_comma_string_slice());
assert_eq!(val.as_comma_string_slice(), Some(vec!["11".to_string(), "22".to_string(), "33".to_string(), "44".to_string()]));

val = json!([11, "aa22", 33, 44]);
assert!(val.is_comma_string_slice());
assert_eq!(val.as_comma_string_slice(), Some(vec!["11".to_string(), "aa22".to_string(), "33".to_string(), "44".to_string()]));

val = json!("aa, bb, cc ,dd, , 88,");
assert!(val.is_comma_string_slice());
assert_eq!(val.as_comma_string_slice(), Some(vec!["aa".to_string(), "bb".to_string(), "cc".to_string(), "dd".to_string(), "88".to_string()]));

let mut map: HashMap<String, String> = HashMap::new();
map.insert("k1".to_string(), "v1".to_string());
map.insert("k2".to_string(), "v2".to_string());

val = json!(r#"{"k1": "v1", "k2": "v2"}"#);
assert!(val.is_map());
assert_eq!(val.as_map(), Some(map.clone()));

val = serde_json::from_str(r#"{"k1": "v1", "k2": "v2"}"#).unwrap();
assert!(val.is_map());
assert_eq!(val.as_map(), Some(map.clone()));

val = serde_json::from_str(r#"{"k1": "v1", "k2": {"kk2": "vv2"}}"#).unwrap();
assert!(val.is_map());
map.remove("k2");
assert_eq!(val.as_map(), Some(map.clone()));

map.clear();
map.insert("tag1".to_string(), "production".to_string());
val = json!({
"metadata": "{ \"tag1\": \"production\" }",
"ttl": 600,
"num_uses": 50
});
assert!(val.is_object());
assert!(val.is_map());
let obj = val.as_object().unwrap();
assert!(obj["metadata"].is_map());
assert_eq!(obj["metadata"].as_map(), Some(map));
}
}
Loading

0 comments on commit 12151d1

Please sign in to comment.