Skip to content

Commit

Permalink
Added the get_data_or_default and get_data_or_next interfaces.
Browse files Browse the repository at this point in the history
1. Added the get_data_or_default interface which retrieves a parameter from
   the request, or defaults to a predefined value if retrieval fails. This
   is distinct from the get_data interface which only fetches the parameter's
   value from the request.
2. Added the get_data_or_next interface which retrieves the data of the
   first field from the request; if it does not exist, it proceeds to
   fetch the data of the next field.
  • Loading branch information
wa5i committed May 12, 2024
1 parent 7e8c441 commit 716be26
Show file tree
Hide file tree
Showing 11 changed files with 194 additions and 130 deletions.
10 changes: 5 additions & 5 deletions src/logical/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,15 +565,15 @@ mod test {
},
{op: Operation::Write, raw_handler: |_backend: &dyn Backend, req: &mut Request| -> Result<Option<Response>, RvError> {
let array_val = req.get_data("array")?;
let array_default_val = req.get_data("array_default")?;
let array_default_val = req.get_data_or_default("array_default")?;
let bool_val = req.get_data("bool")?;
let bool_default_val = req.get_data("bool_default")?;
let bool_default_val = req.get_data_or_default("bool_default")?;
let comma_val = req.get_data("comma")?;
let comma_default_val = req.get_data("comma_default")?;
let comma_default_val = req.get_data_or_default("comma_default")?;
let map_val = req.get_data("map")?;
let map_default_val = req.get_data("map_default")?;
let map_default_val = req.get_data_or_default("map_default")?;
let duration_val = req.get_data("duration")?;
let duration_default_val = req.get_data("duration_default")?;
let duration_default_val = req.get_data_or_default("duration_default")?;
let data = json!({
"array": array_val,
"array_default": array_default_val,
Expand Down
69 changes: 57 additions & 12 deletions src/logical/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,7 @@ impl Request {
Self { operation: Operation::Renew, path: path.to_string(), auth, data, ..Default::default() }
}

pub fn get_data(&self, key: &str) -> Result<Value, RvError> {
if self.storage.is_none() || self.match_path.is_none() {
return Err(RvError::ErrRequestNotReady);
}

if self.data.is_none() && self.body.is_none() {
return Err(RvError::ErrRequestNoData);
}

fn get_data_raw(&self, key: &str, default: bool) -> Result<Value, RvError> {
let field = self.match_path.as_ref().unwrap().get_field(key);
if field.is_none() {
return Err(RvError::ErrRequestNoDataField);
Expand All @@ -95,11 +87,64 @@ impl Request {
}
}

if field.required {
return Err(RvError::ErrRequestFieldNotFound);
if default {
if field.required {
return Err(RvError::ErrRequestFieldNotFound);
}

return field.get_default();
}

return Err(RvError::ErrRequestFieldNotFound);
}

pub fn get_data(&self, key: &str) -> Result<Value, RvError> {
if self.storage.is_none() || self.match_path.is_none() {
return Err(RvError::ErrRequestNotReady);
}

if self.data.is_none() && self.body.is_none() {
return Err(RvError::ErrRequestNoData);
}

self.get_data_raw(key, false)
}

pub fn get_data_or_default(&self, key: &str) -> Result<Value, RvError> {
if self.storage.is_none() || self.match_path.is_none() {
return Err(RvError::ErrRequestNotReady);
}

if self.data.is_none() && self.body.is_none() {
return Err(RvError::ErrRequestNoData);
}

self.get_data_raw(key, true)
}

pub fn get_data_or_next(&self, keys: &[&str]) -> Result<Value, RvError> {
if self.storage.is_none() || self.match_path.is_none() {
return Err(RvError::ErrRequestNotReady);
}

if self.data.is_none() && self.body.is_none() {
return Err(RvError::ErrRequestNoData);
}

for &key in keys.iter() {
match self.get_data_raw(key, false) {
Ok(raw) => {
return Ok(raw);
},
Err(e) => {
if e != RvError::ErrRequestFieldNotFound {
return Err(e);
}
}
}
}

return field.get_default();
return Err(RvError::ErrRequestFieldNotFound);
}

//TODO: the sensitive data is still in the memory. Need to totally resolve this in `serde_json` someday.
Expand Down
36 changes: 17 additions & 19 deletions src/modules/credential/userpass/path_users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ impl UserPassBackendInner {

pub fn read_user(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let username_value = req.get_data("username")?;
let username = username_value.as_str().unwrap().to_lowercase();
let username = username_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_lowercase();

let entry = self.get_user(req, &username)?;
if entry.is_none() {
Expand All @@ -165,43 +165,41 @@ impl UserPassBackendInner {

pub fn write_user(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let username_value = req.get_data("username")?;
let username = username_value.as_str().unwrap();
let username = username_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_lowercase();

let mut user_entry = UserEntry::default();

let entry = self.get_user(req, username)?;
if entry.is_some() {
user_entry = entry.unwrap();
if let Some(entry) = self.get_user(req, &username)? {
user_entry = entry;
}

let password_value = req.get_data("password")?;
let password = password_value.as_str().unwrap();
if password != "" {
let password_hash = self.gen_password_hash(password)?;

user_entry.password_hash = password_hash;
if let Ok(password_value) = req.get_data("password") {
let password = password_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
if password != "" {
user_entry.password_hash = self.gen_password_hash(password)?;
}
}

let ttl_value = req.get_data("ttl")?;
let ttl = ttl_value.as_u64().unwrap();
let ttl_value = req.get_data_or_default("ttl")?;
let ttl = ttl_value.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?;
if ttl > 0 {
user_entry.ttl = Duration::from_secs(ttl);
}

let max_ttl_value = req.get_data("max_ttl")?;
let max_ttl = max_ttl_value.as_u64().unwrap();
let max_ttl_value = req.get_data_or_default("max_ttl")?;
let max_ttl = max_ttl_value.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?;
if max_ttl > 0 {
user_entry.max_ttl = Duration::from_secs(max_ttl);
}

self.set_user(req, username, &user_entry)?;
self.set_user(req, &username, &user_entry)?;

Ok(None)
}

pub fn delete_user(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let username_value = req.get_data("username")?;
let username = username_value.as_str().unwrap();
let username = username_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
if username == "" {
return Err(RvError::ErrRequestNoDataField);
}
Expand All @@ -218,7 +216,7 @@ impl UserPassBackendInner {

pub fn write_user_password(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let username_value = req.get_data("username")?;
let username = username_value.as_str().unwrap();
let username = username_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let mut user_entry = UserEntry::default();

Expand All @@ -228,7 +226,7 @@ impl UserPassBackendInner {
}

let password_value = req.get_data("password")?;
let password = password_value.as_str().unwrap();
let password = password_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let password_hash = self.gen_password_hash(password)?;

Expand Down
1 change: 0 additions & 1 deletion src/modules/pki/path_fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ Using "ca" or "crl" as the value fetches the appropriate information in DER enco
fields: {
"serial": {
field_type: FieldType::Str,
default: "72h",
description: "Certificate serial number, in colon- or hyphen-separated octal"
}
},
Expand Down
21 changes: 7 additions & 14 deletions src/modules/pki/path_issue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,16 @@ requested common name is allowed by the role policy.

impl PkiBackendInner {
pub fn issue_cert(&self, backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
//let role_name = req.get_data("role")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let mut common_names = Vec::new();

let common_name_value = req.get_data("common_name")?;
let common_name_value = req.get_data_or_default("common_name")?;
let common_name = common_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
if common_name != "" {
common_names.push(common_name.to_string());
}

let alt_names_value = req.get_data("alt_names");
if alt_names_value.is_ok() {
let alt_names_val = alt_names_value.unwrap();
let alt_names = alt_names_val.as_str().unwrap();
if let Ok(alt_names_value) = req.get_data("alt_names") {
let alt_names = alt_names_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
if alt_names != "" {
for v in alt_names.split(',') {
common_names.push(v.to_string());
Expand All @@ -94,10 +90,8 @@ impl PkiBackendInner {
let role_entry = role.unwrap();

let mut ip_sans = Vec::new();
let ip_sans_value = req.get_data("ip_sans");
if ip_sans_value.is_ok() {
let ip_sans_val = ip_sans_value.unwrap();
let ip_sans_str = ip_sans_val.as_str().unwrap();
if let Ok(ip_sans_value) = req.get_data("ip_sans") {
let ip_sans_str = ip_sans_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
if ip_sans_str != "" {
for v in ip_sans_str.split(',') {
ip_sans.push(v.to_string());
Expand All @@ -109,9 +103,8 @@ impl PkiBackendInner {
let not_before = SystemTime::now() - Duration::from_secs(10);
let mut not_after = not_before + parse_duration("30d").unwrap();

let ttl_value = req.get_data("ttl")?;
let ttl = ttl_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
if ttl != "" {
if let Ok(ttl_value) = req.get_data("ttl") {
let ttl = ttl_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let ttl_dur = parse_duration(ttl)?;
let req_ttl_not_after_dur = SystemTime::now() + ttl_dur;
let req_ttl_not_after =
Expand Down
20 changes: 11 additions & 9 deletions src/modules/pki/path_keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ impl PkiBackend {
pattern: r"keys/generate/(exported|internal)",
fields: {
"key_name": {
required: true,
field_type: FieldType::Str,
description: "key name"
},
"key_bits": {
required: true,
field_type: FieldType::Int,
default: 0,
description: r#"
The number of bits to use. Allowed values are 0 (universal default); with rsa
key_type: 2048 (default), 3072, or 4096; with ec key_type: 224, 256 (default),
Expand Down Expand Up @@ -213,9 +214,10 @@ impl PkiBackendInner {
pub fn generate_key(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let key_name_value = req.get_data("key_name")?;
let key_name = key_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let key_type_value = req.get_data("key_type")?;
let key_type_value = req.get_data_or_default("key_type")?;
let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let key_bits = req.get_data("key_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?;
let key_bits_value = req.get_data_or_default("key_bits")?;
let key_bits = key_bits_value.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?;

let mut export_private_key = false;
if req.path.ends_with("/exported") {
Expand Down Expand Up @@ -266,11 +268,11 @@ impl PkiBackendInner {
pub fn import_key(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let key_name_value = req.get_data("key_name")?;
let key_name = key_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let key_type_value = req.get_data("key_type")?;
let key_type_value = req.get_data_or_default("key_type")?;
let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let pem_bundle_value = req.get_data("pem_bundle")?;
let pem_bundle_value = req.get_data_or_default("pem_bundle")?;
let pem_bundle = pem_bundle_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let hex_bundle_value = req.get_data("hex_bundle")?;
let hex_bundle_value = req.get_data_or_default("hex_bundle")?;
let hex_bundle = hex_bundle_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

if pem_bundle.len() == 0 && hex_bundle.len() == 0 {
Expand Down Expand Up @@ -310,7 +312,7 @@ impl PkiBackendInner {
return Err(RvError::ErrPkiKeyBitsInvalid);
}
};
let iv_value = req.get_data("iv")?;
let iv_value = req.get_data_or_default("iv")?;
let is_iv_required = matches!(key_type, "aes-gcm" | "aes-cbc" | "sm4-gcm" | "sm4-ccm");
#[cfg(tongsuo)]
let is_valid_key_type = matches!(key_type, "aes-gcm" | "aes-cbc" | "aes-ecb" | "sm4-gcm" | "sm4-ccm");
Expand Down Expand Up @@ -391,7 +393,7 @@ impl PkiBackendInner {
pub fn key_encrypt(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let data_value = req.get_data("data")?;
let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let aad_value = req.get_data("aad")?;
let aad_value = req.get_data_or_default("aad")?;
let aad = aad_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;
Expand All @@ -412,7 +414,7 @@ impl PkiBackendInner {
pub fn key_decrypt(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let data_value = req.get_data("data")?;
let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let aad_value = req.get_data("aad")?;
let aad_value = req.get_data_or_default("aad")?;
let aad = aad_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;
Expand Down
Loading

0 comments on commit 716be26

Please sign in to comment.