diff --git a/src/logical/backend.rs b/src/logical/backend.rs index 8d574eb..5e5c21b 100644 --- a/src/logical/backend.rs +++ b/src/logical/backend.rs @@ -565,15 +565,15 @@ mod test { }, {op: Operation::Write, raw_handler: |_backend: &dyn Backend, req: &mut Request| -> Result, RvError> { let array_val = req.get_data("array")?; - let array_default_val = req.get_data("array_default")?; + let array_default_val = req.get_data_or_default("array_default")?; let bool_val = req.get_data("bool")?; - let bool_default_val = req.get_data("bool_default")?; + let bool_default_val = req.get_data_or_default("bool_default")?; let comma_val = req.get_data("comma")?; - let comma_default_val = req.get_data("comma_default")?; + let comma_default_val = req.get_data_or_default("comma_default")?; let map_val = req.get_data("map")?; - let map_default_val = req.get_data("map_default")?; + let map_default_val = req.get_data_or_default("map_default")?; let duration_val = req.get_data("duration")?; - let duration_default_val = req.get_data("duration_default")?; + let duration_default_val = req.get_data_or_default("duration_default")?; let data = json!({ "array": array_val, "array_default": array_default_val, diff --git a/src/logical/request.rs b/src/logical/request.rs index e130932..a24dc45 100644 --- a/src/logical/request.rs +++ b/src/logical/request.rs @@ -62,15 +62,7 @@ impl Request { Self { operation: Operation::Renew, path: path.to_string(), auth, data, ..Default::default() } } - pub fn get_data(&self, key: &str) -> Result { - if self.storage.is_none() || self.match_path.is_none() { - return Err(RvError::ErrRequestNotReady); - } - - if self.data.is_none() && self.body.is_none() { - return Err(RvError::ErrRequestNoData); - } - + fn get_data_raw(&self, key: &str, default: bool) -> Result { let field = self.match_path.as_ref().unwrap().get_field(key); if field.is_none() { return Err(RvError::ErrRequestNoDataField); @@ -95,11 +87,64 @@ impl Request { } } - if field.required { - return Err(RvError::ErrRequestFieldNotFound); + if default { + if field.required { + return Err(RvError::ErrRequestFieldNotFound); + } + + return field.get_default(); + } + + return Err(RvError::ErrRequestFieldNotFound); + } + + pub fn get_data(&self, key: &str) -> Result { + if self.storage.is_none() || self.match_path.is_none() { + return Err(RvError::ErrRequestNotReady); + } + + if self.data.is_none() && self.body.is_none() { + return Err(RvError::ErrRequestNoData); + } + + self.get_data_raw(key, false) + } + + pub fn get_data_or_default(&self, key: &str) -> Result { + if self.storage.is_none() || self.match_path.is_none() { + return Err(RvError::ErrRequestNotReady); + } + + if self.data.is_none() && self.body.is_none() { + return Err(RvError::ErrRequestNoData); + } + + self.get_data_raw(key, true) + } + + pub fn get_data_or_next(&self, keys: &[&str]) -> Result { + if self.storage.is_none() || self.match_path.is_none() { + return Err(RvError::ErrRequestNotReady); + } + + if self.data.is_none() && self.body.is_none() { + return Err(RvError::ErrRequestNoData); + } + + for &key in keys.iter() { + match self.get_data_raw(key, false) { + Ok(raw) => { + return Ok(raw); + }, + Err(e) => { + if e != RvError::ErrRequestFieldNotFound { + return Err(e); + } + } + } } - return field.get_default(); + return Err(RvError::ErrRequestFieldNotFound); } //TODO: the sensitive data is still in the memory. Need to totally resolve this in `serde_json` someday. diff --git a/src/modules/credential/userpass/path_users.rs b/src/modules/credential/userpass/path_users.rs index 4d97e74..ba9ebd1 100644 --- a/src/modules/credential/userpass/path_users.rs +++ b/src/modules/credential/userpass/path_users.rs @@ -149,7 +149,7 @@ impl UserPassBackendInner { pub fn read_user(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let username_value = req.get_data("username")?; - let username = username_value.as_str().unwrap().to_lowercase(); + let username = username_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_lowercase(); let entry = self.get_user(req, &username)?; if entry.is_none() { @@ -165,43 +165,41 @@ impl UserPassBackendInner { pub fn write_user(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let username_value = req.get_data("username")?; - let username = username_value.as_str().unwrap(); + let username = username_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_lowercase(); let mut user_entry = UserEntry::default(); - let entry = self.get_user(req, username)?; - if entry.is_some() { - user_entry = entry.unwrap(); + if let Some(entry) = self.get_user(req, &username)? { + user_entry = entry; } - let password_value = req.get_data("password")?; - let password = password_value.as_str().unwrap(); - if password != "" { - let password_hash = self.gen_password_hash(password)?; - - user_entry.password_hash = password_hash; + if let Ok(password_value) = req.get_data("password") { + let password = password_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; + if password != "" { + user_entry.password_hash = self.gen_password_hash(password)?; + } } - let ttl_value = req.get_data("ttl")?; - let ttl = ttl_value.as_u64().unwrap(); + let ttl_value = req.get_data_or_default("ttl")?; + let ttl = ttl_value.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; if ttl > 0 { user_entry.ttl = Duration::from_secs(ttl); } - let max_ttl_value = req.get_data("max_ttl")?; - let max_ttl = max_ttl_value.as_u64().unwrap(); + let max_ttl_value = req.get_data_or_default("max_ttl")?; + let max_ttl = max_ttl_value.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; if max_ttl > 0 { user_entry.max_ttl = Duration::from_secs(max_ttl); } - self.set_user(req, username, &user_entry)?; + self.set_user(req, &username, &user_entry)?; Ok(None) } pub fn delete_user(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let username_value = req.get_data("username")?; - let username = username_value.as_str().unwrap(); + let username = username_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; if username == "" { return Err(RvError::ErrRequestNoDataField); } @@ -218,7 +216,7 @@ impl UserPassBackendInner { pub fn write_user_password(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let username_value = req.get_data("username")?; - let username = username_value.as_str().unwrap(); + let username = username_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let mut user_entry = UserEntry::default(); @@ -228,7 +226,7 @@ impl UserPassBackendInner { } let password_value = req.get_data("password")?; - let password = password_value.as_str().unwrap(); + let password = password_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let password_hash = self.gen_password_hash(password)?; diff --git a/src/modules/pki/path_fetch.rs b/src/modules/pki/path_fetch.rs index 53c4550..c4932f9 100644 --- a/src/modules/pki/path_fetch.rs +++ b/src/modules/pki/path_fetch.rs @@ -55,7 +55,6 @@ Using "ca" or "crl" as the value fetches the appropriate information in DER enco fields: { "serial": { field_type: FieldType::Str, - default: "72h", description: "Certificate serial number, in colon- or hyphen-separated octal" } }, diff --git a/src/modules/pki/path_issue.rs b/src/modules/pki/path_issue.rs index b5521fd..9c203d6 100644 --- a/src/modules/pki/path_issue.rs +++ b/src/modules/pki/path_issue.rs @@ -65,20 +65,16 @@ requested common name is allowed by the role policy. impl PkiBackendInner { pub fn issue_cert(&self, backend: &dyn Backend, req: &mut Request) -> Result, RvError> { - //let role_name = req.get_data("role")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let mut common_names = Vec::new(); - let common_name_value = req.get_data("common_name")?; + let common_name_value = req.get_data_or_default("common_name")?; let common_name = common_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; if common_name != "" { common_names.push(common_name.to_string()); } - let alt_names_value = req.get_data("alt_names"); - if alt_names_value.is_ok() { - let alt_names_val = alt_names_value.unwrap(); - let alt_names = alt_names_val.as_str().unwrap(); + if let Ok(alt_names_value) = req.get_data("alt_names") { + let alt_names = alt_names_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; if alt_names != "" { for v in alt_names.split(',') { common_names.push(v.to_string()); @@ -94,10 +90,8 @@ impl PkiBackendInner { let role_entry = role.unwrap(); let mut ip_sans = Vec::new(); - let ip_sans_value = req.get_data("ip_sans"); - if ip_sans_value.is_ok() { - let ip_sans_val = ip_sans_value.unwrap(); - let ip_sans_str = ip_sans_val.as_str().unwrap(); + if let Ok(ip_sans_value) = req.get_data("ip_sans") { + let ip_sans_str = ip_sans_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; if ip_sans_str != "" { for v in ip_sans_str.split(',') { ip_sans.push(v.to_string()); @@ -109,9 +103,8 @@ impl PkiBackendInner { let not_before = SystemTime::now() - Duration::from_secs(10); let mut not_after = not_before + parse_duration("30d").unwrap(); - let ttl_value = req.get_data("ttl")?; - let ttl = ttl_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - if ttl != "" { + if let Ok(ttl_value) = req.get_data("ttl") { + let ttl = ttl_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let ttl_dur = parse_duration(ttl)?; let req_ttl_not_after_dur = SystemTime::now() + ttl_dur; let req_ttl_not_after = diff --git a/src/modules/pki/path_keys.rs b/src/modules/pki/path_keys.rs index 9149911..734810c 100644 --- a/src/modules/pki/path_keys.rs +++ b/src/modules/pki/path_keys.rs @@ -22,12 +22,13 @@ impl PkiBackend { pattern: r"keys/generate/(exported|internal)", fields: { "key_name": { + required: true, field_type: FieldType::Str, description: "key name" }, "key_bits": { - required: true, field_type: FieldType::Int, + default: 0, description: r#" The number of bits to use. Allowed values are 0 (universal default); with rsa key_type: 2048 (default), 3072, or 4096; with ec key_type: 224, 256 (default), @@ -213,9 +214,10 @@ impl PkiBackendInner { pub fn generate_key(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let key_name_value = req.get_data("key_name")?; let key_name = key_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let key_type_value = req.get_data("key_type")?; + let key_type_value = req.get_data_or_default("key_type")?; let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let key_bits = req.get_data("key_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + let key_bits_value = req.get_data_or_default("key_bits")?; + let key_bits = key_bits_value.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; let mut export_private_key = false; if req.path.ends_with("/exported") { @@ -266,11 +268,11 @@ impl PkiBackendInner { pub fn import_key(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let key_name_value = req.get_data("key_name")?; let key_name = key_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let key_type_value = req.get_data("key_type")?; + let key_type_value = req.get_data_or_default("key_type")?; let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let pem_bundle_value = req.get_data("pem_bundle")?; + let pem_bundle_value = req.get_data_or_default("pem_bundle")?; let pem_bundle = pem_bundle_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let hex_bundle_value = req.get_data("hex_bundle")?; + let hex_bundle_value = req.get_data_or_default("hex_bundle")?; let hex_bundle = hex_bundle_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; if pem_bundle.len() == 0 && hex_bundle.len() == 0 { @@ -310,7 +312,7 @@ impl PkiBackendInner { return Err(RvError::ErrPkiKeyBitsInvalid); } }; - let iv_value = req.get_data("iv")?; + let iv_value = req.get_data_or_default("iv")?; let is_iv_required = matches!(key_type, "aes-gcm" | "aes-cbc" | "sm4-gcm" | "sm4-ccm"); #[cfg(tongsuo)] let is_valid_key_type = matches!(key_type, "aes-gcm" | "aes-cbc" | "aes-ecb" | "sm4-gcm" | "sm4-ccm"); @@ -391,7 +393,7 @@ impl PkiBackendInner { pub fn key_encrypt(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let data_value = req.get_data("data")?; let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let aad_value = req.get_data("aad")?; + let aad_value = req.get_data_or_default("aad")?; let aad = aad_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; @@ -412,7 +414,7 @@ impl PkiBackendInner { pub fn key_decrypt(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let data_value = req.get_data("data")?; let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let aad_value = req.get_data("aad")?; + let aad_value = req.get_data_or_default("aad")?; let aad = aad_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; diff --git a/src/modules/pki/path_roles.rs b/src/modules/pki/path_roles.rs index d851772..59b1927 100644 --- a/src/modules/pki/path_roles.rs +++ b/src/modules/pki/path_roles.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; use humantime::parse_duration; use serde::{Deserialize, Serialize}; -use super::{PkiBackend, PkiBackendInner}; +use super::{PkiBackend, PkiBackendInner, util::DEFAULT_MAX_TTL}; use crate::{ errors::RvError, logical::{Backend, Field, FieldType, Operation, Path, PathOperation, Request, Response}, @@ -12,8 +12,6 @@ use crate::{ utils::{deserialize_duration, serialize_duration}, }; -const DEFAULT_MAX_TTL: Duration = Duration::from_secs(365 * 24 * 60 * 60 as u64); - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RoleEntry { #[serde(serialize_with = "serialize_duration", deserialize_with = "deserialize_duration")] @@ -104,15 +102,14 @@ max_ttl, whichever is shorter."# }, "max_ttl": { field_type: FieldType::Str, - required: true, description: r#" - The maximum allowed lease duration. If not set, defaults to the system maximum lease TTL."# +The maximum allowed lease duration. If not set, defaults to the system maximum lease TTL."# }, "use_pss": { field_type: FieldType::Bool, default: false, description: r#" - Whether or not to use PSS signatures when using a RSA key-type issuer. Defaults to false."# +Whether or not to use PSS signatures when using a RSA key-type issuer. Defaults to false."# }, "allow_localhost": { field_type: FieldType::Bool, @@ -157,30 +154,30 @@ See the documentation for more information."# field_type: FieldType::Bool, default: true, description: r#" - If set, IP Subject Alternative Names are allowed. Any valid IP is accepted and No authorization checking is performed."# +If set, IP Subject Alternative Names are allowed. Any valid IP is accepted and No authorization checking is performed."# }, "server_flag": { field_type: FieldType::Bool, default: true, description: r#" - If set, certificates are flagged for server auth use. defaults to true. See also RFC 5280 Section 4.2.1.12."# +If set, certificates are flagged for server auth use. defaults to true. See also RFC 5280 Section 4.2.1.12."# }, "client_flag": { field_type: FieldType::Bool, default: true, description: r#" - If set, certificates are flagged for client auth use. defaults to true. See also RFC 5280 Section 4.2.1.12."# +If set, certificates are flagged for client auth use. defaults to true. See also RFC 5280 Section 4.2.1.12."# }, "code_signing_flag": { field_type: FieldType::Bool, description: r#" - If set, certificates are flagged for code signing use. defaults to false. See also RFC 5280 Section 4.2.1.12."# +If set, certificates are flagged for code signing use. defaults to false. See also RFC 5280 Section 4.2.1.12."# }, "key_type": { field_type: FieldType::Str, default: "rsa", description: r#" - The type of key to use; defaults to RSA. "rsa" "ec", "ed25519" and "any" are the only valid values."# +The type of key to use; defaults to RSA. "rsa" "ec", "ed25519" and "any" are the only valid values."# }, "key_bits": { field_type: FieldType::Int, @@ -215,43 +212,43 @@ The value format should be given in UTC format YYYY-MM-ddTHH:MM:SSZ."# required: false, field_type: FieldType::Str, description: r#" - If set, OU (OrganizationalUnit) will be set to this value in certificates issued by this role."# +If set, OU (OrganizationalUnit) will be set to this value in certificates issued by this role."# }, "organization": { required: false, field_type: FieldType::Str, description: r#" - If set, O (Organization) will be set to this value in certificates issued by this role."# +If set, O (Organization) will be set to this value in certificates issued by this role."# }, "country": { required: false, field_type: FieldType::Str, description: r#" - If set, Country will be set to this value in certificates issued by this role."# +If set, Country will be set to this value in certificates issued by this role."# }, "locality": { required: false, field_type: FieldType::Str, description: r#" - If set, Locality will be set to this value in certificates issued by this role."# +If set, Locality will be set to this value in certificates issued by this role."# }, "province": { required: false, field_type: FieldType::Str, description: r#" - If set, Province will be set to this value in certificates issued by this role."# +If set, Province will be set to this value in certificates issued by this role."# }, "street_address": { required: false, field_type: FieldType::Str, description: r#" - If set, Street Address will be set to this value."# +If set, Street Address will be set to this value."# }, "postal_code": { required: false, field_type: FieldType::Str, description: r#" - If set, Postal Code will be set to this value."# +If set, Postal Code will be set to this value."# }, "use_csr_common_name": { field_type: FieldType::Bool, @@ -326,11 +323,23 @@ impl PkiBackendInner { pub fn create_path_role(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let name_value = req.get_data("name")?; let name = name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let ttl = parse_duration(req.get_data("ttl")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; - let max_ttl = parse_duration(req.get_data("max_ttl")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; - let key_type_value = req.get_data("key_type")?; + let mut ttl = DEFAULT_MAX_TTL; + if let Ok(ttl_value) = req.get_data("ttl") { + let ttl_str = ttl_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; + if ttl_str != "" { + ttl = parse_duration(ttl_str)?; + } + } + let mut max_ttl = DEFAULT_MAX_TTL; + if let Ok(max_ttl_value) = req.get_data("max_ttl") { + let max_ttl_str = max_ttl_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; + if max_ttl_str != "" { + max_ttl = parse_duration(max_ttl_str)?; + } + } + let key_type_value = req.get_data_or_default("key_type")?; let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let mut key_bits = req.get_data("key_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + let mut key_bits = req.get_data_or_default("key_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; match key_type { "rsa" => { if key_bits == 0 { @@ -355,27 +364,27 @@ impl PkiBackendInner { } } - let signature_bits = req.get_data("signature_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; - let allow_localhost = req.get_data("allow_localhost")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let allow_bare_domains = req.get_data("allow_bare_domains")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let allow_subdomains = req.get_data("allow_subdomains")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let allow_any_name = req.get_data("allow_any_name")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let allow_ip_sans = req.get_data("allow_ip_sans")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let server_flag = req.get_data("server_flag")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let client_flag = req.get_data("client_flag")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let use_csr_sans = req.get_data("use_csr_sans")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let use_csr_common_name = req.get_data("use_csr_common_name")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let country = req.get_data("country")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let province = req.get_data("province")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let locality = req.get_data("locality")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let organization = req.get_data("organization")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let ou = req.get_data("ou")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let street_address = req.get_data("street_address")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let postal_code = req.get_data("postal_code")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let no_store = req.get_data("no_store")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let generate_lease = req.get_data("generate_lease")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let not_after = req.get_data("not_after")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let not_before_duration_u64 = req.get_data("not_before_duration")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + let signature_bits = req.get_data_or_default("signature_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_localhost = req.get_data_or_default("allow_localhost")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_bare_domains = req.get_data_or_default("allow_bare_domains")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_subdomains = req.get_data_or_default("allow_subdomains")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_any_name = req.get_data_or_default("allow_any_name")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_ip_sans = req.get_data_or_default("allow_ip_sans")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let server_flag = req.get_data_or_default("server_flag")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let client_flag = req.get_data_or_default("client_flag")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let use_csr_sans = req.get_data_or_default("use_csr_sans")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let use_csr_common_name = req.get_data_or_default("use_csr_common_name")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let country = req.get_data_or_default("country")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let province = req.get_data_or_default("province")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let locality = req.get_data_or_default("locality")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let organization = req.get_data_or_default("organization")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let ou = req.get_data_or_default("ou")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let street_address = req.get_data_or_default("street_address")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let postal_code = req.get_data_or_default("postal_code")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let no_store = req.get_data_or_default("no_store")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let generate_lease = req.get_data_or_default("generate_lease")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let not_after = req.get_data_or_default("not_after")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let not_before_duration_u64 = req.get_data_or_default("not_before_duration")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; let not_before_duration = Duration::from_secs(not_before_duration_u64); let role_entry = RoleEntry { diff --git a/src/modules/pki/path_root.rs b/src/modules/pki/path_root.rs index e645163..1d7806a 100644 --- a/src/modules/pki/path_root.rs +++ b/src/modules/pki/path_root.rs @@ -46,7 +46,7 @@ impl PkiBackend { impl PkiBackendInner { pub fn generate_root(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let mut export_private_key = false; - if req.get_data("exported")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)? == "exported" { + if req.get_data_or_default("exported")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)? == "exported" { export_private_key = true; } diff --git a/src/modules/pki/util.rs b/src/modules/pki/util.rs index 7f6bdc9..05c56b9 100644 --- a/src/modules/pki/util.rs +++ b/src/modules/pki/util.rs @@ -6,13 +6,21 @@ use openssl::x509::X509NameBuilder; use super::path_roles::RoleEntry; use crate::{errors::RvError, logical::Request, utils::cert::Certificate}; +pub const DEFAULT_MAX_TTL: Duration = Duration::from_secs(365 * 24 * 60 * 60 as u64); + pub fn get_role_params(req: &mut Request) -> Result { - let ttl = parse_duration(req.get_data("ttl")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; - let not_before_duration_u64 = req.get_data("not_before_duration")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + let mut ttl = DEFAULT_MAX_TTL; + if let Ok(ttl_value) = req.get_data("ttl") { + let ttl_str = ttl_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; + if ttl_str != "" { + ttl = parse_duration(ttl_str)?; + } + } + let not_before_duration_u64 = req.get_data_or_default("not_before_duration")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; let not_before_duration = Duration::from_secs(not_before_duration_u64); - let key_type_value = req.get_data("key_type")?; + let key_type_value = req.get_data_or_default("key_type")?; let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let mut key_bits = req.get_data("key_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + let mut key_bits = req.get_data_or_default("key_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; match key_type { "rsa" => { if key_bits == 0 { @@ -37,16 +45,16 @@ pub fn get_role_params(req: &mut Request) -> Result { } } - let signature_bits = req.get_data("signature_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; - let use_pss = req.get_data("use_pss")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; - let country = req.get_data("country")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let province = req.get_data("province")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let locality = req.get_data("locality")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let organization = req.get_data("organization")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let ou = req.get_data("ou")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let street_address = req.get_data("street_address")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let postal_code = req.get_data("postal_code")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); - let not_after = req.get_data("not_after")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let signature_bits = req.get_data_or_default("signature_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + let use_pss = req.get_data_or_default("use_pss")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let country = req.get_data_or_default("country")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let province = req.get_data_or_default("province")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let locality = req.get_data_or_default("locality")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let organization = req.get_data_or_default("organization")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let ou = req.get_data_or_default("ou")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let street_address = req.get_data_or_default("street_address")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let postal_code = req.get_data_or_default("postal_code")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let not_after = req.get_data_or_default("not_after")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); let role_entry = RoleEntry { ttl, @@ -72,16 +80,14 @@ pub fn get_role_params(req: &mut Request) -> Result { pub fn generate_certificate(role_entry: &RoleEntry, req: &mut Request) -> Result { let mut common_names = Vec::new(); - let common_name_value = req.get_data("common_name")?; + let common_name_value = req.get_data_or_default("common_name")?; let common_name = common_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; if common_name != "" { common_names.push(common_name.to_string()); } - let alt_names_value = req.get_data("alt_names"); - if alt_names_value.is_ok() { - let alt_names_val = alt_names_value.unwrap(); - let alt_names = alt_names_val.as_str().unwrap(); + if let Ok(alt_names_value) = req.get_data("alt_names") { + let alt_names = alt_names_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; if alt_names != "" { for v in alt_names.split(',') { common_names.push(v.to_string()); @@ -90,10 +96,8 @@ pub fn generate_certificate(role_entry: &RoleEntry, req: &mut Request) -> Result } let mut ip_sans = Vec::new(); - let ip_sans_value = req.get_data("ip_sans"); - if ip_sans_value.is_ok() { - let ip_sans_val = ip_sans_value.unwrap(); - let ip_sans_str = ip_sans_val.as_str().unwrap(); + if let Ok(ip_sans_value) = req.get_data("ip_sans") { + let ip_sans_str = ip_sans_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; if ip_sans_str != "" { for v in ip_sans_str.split(',') { ip_sans.push(v.to_string()); diff --git a/src/modules/system/mod.rs b/src/modules/system/mod.rs index f55450b..9dc448d 100644 --- a/src/modules/system/mod.rs +++ b/src/modules/system/mod.rs @@ -289,7 +289,7 @@ impl SystemBackendInner { pub fn handle_mount(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let path = req.get_data("path")?; let logical_type = req.get_data("type")?; - let description = req.get_data("description")?; + let description = req.get_data_or_default("description")?; let path = path.as_str().unwrap(); let logical_type = logical_type.as_str().unwrap(); @@ -382,7 +382,7 @@ impl SystemBackendInner { pub fn handle_auth_enable(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let path = req.get_data("path")?; let logical_type = req.get_data("type")?; - let description = req.get_data("description")?; + let description = req.get_data_or_default("description")?; let path = path.as_str().unwrap(); let logical_type = logical_type.as_str().unwrap(); diff --git a/src/utils/key.rs b/src/utils/key.rs index 9c86853..0348345 100644 --- a/src/utils/key.rs +++ b/src/utils/key.rs @@ -42,6 +42,15 @@ impl Default for KeyBundle { } } +fn key_bits_default(key_type: &str) -> u32 { + return match key_type { + "rsa" => 2048, + "ec" | "sm2" => 256, + "aes-gcm" | "aes-cbc" | "aes-ecb" | "sm4-gcm" | "sm4-ccm" => 256, + _ => 0, + } +} + fn cipher_from_key_type_and_bits(key_type: &str, bits: u32) -> Result { match (key_type, bits) { ("aes-gcm", 128) => Ok(Cipher::aes_128_gcm()), @@ -63,7 +72,12 @@ fn cipher_from_key_type_and_bits(key_type: &str, bits: u32) -> Result Self { - Self { name: name.to_string(), key_type: key_type.to_string(), bits: key_bits, ..KeyBundle::default() } + let bits = if key_bits == 0 { + key_bits_default(key_type) + } else { + key_bits + }; + Self { name: name.to_string(), key_type: key_type.to_string(), bits: bits, ..KeyBundle::default() } } pub fn generate(&mut self) -> Result<(), RvError> {