From 6eac13138712e03f22e34922e755639307fca3f8 Mon Sep 17 00:00:00 2001 From: Jin Jiu Date: Mon, 6 May 2024 16:55:43 +0800 Subject: [PATCH] Switch to rust-tongsuo, supporting SM2 and SM4 algorithms. --- Cargo.toml | 9 +- build.rs | 7 + src/modules/pki/path_config_ca.rs | 2 +- src/modules/pki/path_fetch.rs | 2 +- src/modules/pki/path_issue.rs | 9 +- src/modules/pki/path_keys.rs | 79 +++++---- src/modules/pki/path_roles.rs | 95 ++++------- src/modules/pki/path_root.rs | 4 +- src/modules/pki/util.rs | 49 ++---- src/storage/physical/mock.rs | 4 +- src/utils/key.rs | 275 +++++++++++++++--------------- 11 files changed, 249 insertions(+), 286 deletions(-) create mode 100644 build.rs diff --git a/Cargo.toml b/Cargo.toml index ff005df..591410d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ RustyVault's RESTful API is designed to be fully compatible with Hashicorp Vault """ repository = "https://github.com/Tongsuo-Project/RustyVault" documentation = "https://docs.rs/rusty_vault/latest/rusty_vault/" +build = "build.rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -24,8 +25,8 @@ serde_json = "^1.0" serde_bytes = "0.11" go-defer = "^0.1" rand = "^0.8" -openssl = "0.10" -openssl-sys = "0.9.92" +openssl = { version = "0.10" } +openssl-sys = { version = "0.9" } derivative = "2.2.0" enum-map = "2.6.1" strum = { version = "0.25", features = ["derive"] } @@ -60,6 +61,10 @@ serde_asn1_der = "0.8" base64 = "0.22" ipnetwork = "0.20" +[patch.crates-io] +openssl = { git = "https://github.com/Tongsuo-Project/rust-tongsuo.git" } +openssl-sys = { git = "https://github.com/Tongsuo-Project/rust-tongsuo.git" } + [features] storage_mysql = ["diesel", "r2d2", "r2d2-diesel"] diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..2c0ca9e --- /dev/null +++ b/build.rs @@ -0,0 +1,7 @@ +use std::env; + +fn main() { + if let Ok(_) = env::var("DEP_OPENSSL_TONGSUO") { + println!("cargo:rustc-cfg=tongsuo"); + } +} diff --git a/src/modules/pki/path_config_ca.rs b/src/modules/pki/path_config_ca.rs index 56b41d8..98120de 100644 --- a/src/modules/pki/path_config_ca.rs +++ b/src/modules/pki/path_config_ca.rs @@ -46,7 +46,7 @@ For security reasons, you can only view the certificate when reading this endpoi impl PkiBackendInner { pub fn write_path_ca(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let pem_bundle_value = req.get_data("pem_bundle")?; - let pem_bundle = pem_bundle_value.as_str().unwrap(); + let pem_bundle = pem_bundle_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let items = pem::parse_many(pem_bundle)?; let mut key_found = false; diff --git a/src/modules/pki/path_fetch.rs b/src/modules/pki/path_fetch.rs index 4e2d954..53c4550 100644 --- a/src/modules/pki/path_fetch.rs +++ b/src/modules/pki/path_fetch.rs @@ -122,7 +122,7 @@ impl PkiBackendInner { pub fn read_path_fetch_cert(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let serial_number_value = req.get_data("serial")?; - let serial_number = serial_number_value.as_str().unwrap(); + let serial_number = serial_number_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let serial_number_hex = serial_number.replace(":", "-").to_lowercase(); let cert = self.fetch_cert(req, &serial_number_hex)?; let ca_bundle = self.fetch_ca_bundle(req)?; diff --git a/src/modules/pki/path_issue.rs b/src/modules/pki/path_issue.rs index 1277e7e..b5521fd 100644 --- a/src/modules/pki/path_issue.rs +++ b/src/modules/pki/path_issue.rs @@ -65,13 +65,12 @@ requested common name is allowed by the role policy. impl PkiBackendInner { pub fn issue_cert(&self, backend: &dyn Backend, req: &mut Request) -> Result, RvError> { - let role_value = req.get_data("role")?; - let role_name = role_value.as_str().unwrap(); + //let role_name = req.get_data("role")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let mut common_names = Vec::new(); let common_name_value = req.get_data("common_name")?; - let common_name = common_name_value.as_str().unwrap(); + let common_name = common_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; if common_name != "" { common_names.push(common_name.to_string()); } @@ -87,7 +86,7 @@ impl PkiBackendInner { } } - let role = self.get_role(req, &role_name)?; + let role = self.get_role(req, req.get_data("role")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; if role.is_none() { return Err(RvError::ErrPkiRoleNotFound); } @@ -111,7 +110,7 @@ impl PkiBackendInner { let mut not_after = not_before + parse_duration("30d").unwrap(); let ttl_value = req.get_data("ttl")?; - let ttl = ttl_value.as_str().unwrap(); + let ttl = ttl_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; if ttl != "" { let ttl_dur = parse_duration(ttl)?; let req_ttl_not_after_dur = SystemTime::now() + ttl_dur; diff --git a/src/modules/pki/path_keys.rs b/src/modules/pki/path_keys.rs index a2304b6..9149911 100644 --- a/src/modules/pki/path_keys.rs +++ b/src/modules/pki/path_keys.rs @@ -212,11 +212,10 @@ used for sign,verify,encrypt,decrypt. impl PkiBackendInner { pub fn generate_key(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let key_name_value = req.get_data("key_name")?; - let key_name = key_name_value.as_str().unwrap(); + let key_name = key_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let key_type_value = req.get_data("key_type")?; - let key_type = key_type_value.as_str().unwrap(); - let key_bits_value = req.get_data("key_bits")?; - let key_bits = key_bits_value.as_u64().unwrap(); + let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; + let key_bits = req.get_data("key_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; let mut export_private_key = false; if req.path.ends_with("/exported") { @@ -245,7 +244,7 @@ impl PkiBackendInner { if export_private_key { match key_type { - "rsa" | "ec" => { + "rsa" | "ec" | "sm2" => { resp_data.insert( "private_key".to_string(), Value::String(String::from_utf8_lossy(&key_bundle.key).to_string()), @@ -266,13 +265,13 @@ impl PkiBackendInner { pub fn import_key(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let key_name_value = req.get_data("key_name")?; - let key_name = key_name_value.as_str().unwrap(); + let key_name = key_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let key_type_value = req.get_data("key_type")?; - let key_type = key_type_value.as_str().unwrap(); + let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let pem_bundle_value = req.get_data("pem_bundle")?; - let pem_bundle = pem_bundle_value.as_str().unwrap(); + let pem_bundle = pem_bundle_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let hex_bundle_value = req.get_data("hex_bundle")?; - let hex_bundle = hex_bundle_value.as_str().unwrap(); + let hex_bundle = hex_bundle_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; if pem_bundle.len() == 0 && hex_bundle.len() == 0 { return Err(RvError::ErrRequestFieldNotFound); @@ -292,7 +291,7 @@ impl PkiBackendInner { let rsa = Rsa::private_key_from_pem(&key_bundle.key)?; key_bundle.bits = rsa.size() * 8; }, - "ec" => { + "ec" | "sm2" => { let ec_key = EcKey::private_key_from_pem(&key_bundle.key)?; key_bundle.bits = ec_key.group().degree(); }, @@ -312,19 +311,25 @@ impl PkiBackendInner { } }; let iv_value = req.get_data("iv")?; - match key_type { - "aes-gcm" | "aes-cbc" => { - if let Some(iv) = iv_value.as_str() { - key_bundle.iv = hex::decode(&iv)?; - } else { - return Err(RvError::ErrRequestFieldNotFound); - } - }, - "aes-ecb" => {}, - _ => { - return Err(RvError::ErrPkiKeyTypeInvalid); + let is_iv_required = matches!(key_type, "aes-gcm" | "aes-cbc" | "sm4-gcm" | "sm4-ccm"); + #[cfg(tongsuo)] + let is_valid_key_type = matches!(key_type, "aes-gcm" | "aes-cbc" | "aes-ecb" | "sm4-gcm" | "sm4-ccm"); + #[cfg(not(tongsuo))] + let is_valid_key_type = matches!(key_type, "aes-gcm" | "aes-cbc" | "aes-ecb"); + + // Check if the key type is valid, if not return an error. + if !is_valid_key_type { + return Err(RvError::ErrPkiKeyTypeInvalid); + } + + // Proceed to check IV only if required by the key type. + if is_iv_required { + if let Some(iv) = iv_value.as_str() { + key_bundle.iv = hex::decode(&iv)?; + } else { + return Err(RvError::ErrRequestFieldNotFound); } - }; + } } self.write_key(req, &key_bundle)?; @@ -343,12 +348,10 @@ impl PkiBackendInner { } pub fn key_sign(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { - let key_name_value = req.get_data("key_name")?; - let key_name = key_name_value.as_str().unwrap(); let data_value = req.get_data("data")?; - let data = data_value.as_str().unwrap(); + let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let key_bundle = self.fetch_key(req, key_name)?; + let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; let decoded_data = hex::decode(data.as_bytes())?; let result = key_bundle.sign(&decoded_data)?; @@ -364,14 +367,12 @@ impl PkiBackendInner { } pub fn key_verify(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { - let key_name_value = req.get_data("key_name")?; - let key_name = key_name_value.as_str().unwrap(); let data_value = req.get_data("data")?; - let data = data_value.as_str().unwrap(); + let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let signature_value = req.get_data("signature")?; - let signature = signature_value.as_str().unwrap(); + let signature = signature_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let key_bundle = self.fetch_key(req, key_name)?; + let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; let decoded_data = hex::decode(data.as_bytes())?; let decoded_signature = hex::decode(signature.as_bytes())?; @@ -388,14 +389,12 @@ impl PkiBackendInner { } pub fn key_encrypt(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { - let key_name_value = req.get_data("key_name")?; - let key_name = key_name_value.as_str().unwrap(); let data_value = req.get_data("data")?; - let data = data_value.as_str().unwrap(); + let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let aad_value = req.get_data("aad")?; - let aad = aad_value.as_str().unwrap(); + let aad = aad_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let key_bundle = self.fetch_key(req, key_name)?; + let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; let decoded_data = hex::decode(data.as_bytes())?; let result = key_bundle.encrypt(&decoded_data, Some(EncryptExtraData::Aad(aad.as_bytes())))?; @@ -411,14 +410,12 @@ impl PkiBackendInner { } pub fn key_decrypt(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { - let key_name_value = req.get_data("key_name")?; - let key_name = key_name_value.as_str().unwrap(); let data_value = req.get_data("data")?; - let data = data_value.as_str().unwrap(); + let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let aad_value = req.get_data("aad")?; - let aad = aad_value.as_str().unwrap(); + let aad = aad_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let key_bundle = self.fetch_key(req, key_name)?; + let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; let decoded_data = hex::decode(data.as_bytes())?; let result = key_bundle.decrypt(&decoded_data, Some(EncryptExtraData::Aad(aad.as_bytes())))?; diff --git a/src/modules/pki/path_roles.rs b/src/modules/pki/path_roles.rs index 3d93234..d851772 100644 --- a/src/modules/pki/path_roles.rs +++ b/src/modules/pki/path_roles.rs @@ -318,30 +318,19 @@ impl PkiBackendInner { } pub fn read_path_role(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { - let name_vale = req.get_data("name")?; - let name = name_vale.as_str().unwrap(); - let role_entry = self.get_role(req, name)?; + let role_entry = self.get_role(req, req.get_data("name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; let data = serde_json::to_value(&role_entry)?; Ok(Some(Response::data_response(Some(data.as_object().unwrap().clone())))) } pub fn create_path_role(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { - let name_vale = req.get_data("name")?; - let name = name_vale.as_str().unwrap(); - let ttl_vale = req.get_data("ttl")?; - let ttl = { - let ttl_str = ttl_vale.as_str().unwrap(); - parse_duration(ttl_str)? - }; - let max_ttl_vale = req.get_data("max_ttl")?; - let max_ttl = { - let max_ttl_str = max_ttl_vale.as_str().unwrap(); - parse_duration(max_ttl_str)? - }; - let key_type_vale = req.get_data("key_type")?; - let key_type = key_type_vale.as_str().unwrap(); - let key_bits_vale = req.get_data("key_bits")?; - let mut key_bits = key_bits_vale.as_u64().unwrap(); + let name_value = req.get_data("name")?; + let name = name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; + let ttl = parse_duration(req.get_data("ttl")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; + let max_ttl = parse_duration(req.get_data("max_ttl")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; + let key_type_value = req.get_data("key_type")?; + let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; + let mut key_bits = req.get_data("key_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; match key_type { "rsa" => { if key_bits == 0 { @@ -366,48 +355,28 @@ impl PkiBackendInner { } } - let signature_bits_vale = req.get_data("signature_bits")?; - let signature_bits = signature_bits_vale.as_u64().unwrap(); - let allow_localhost_vale = req.get_data("allow_localhost")?; - let allow_localhost = allow_localhost_vale.as_bool().unwrap(); - let allow_bare_domain_vale = req.get_data("allow_bare_domains")?; - let allow_bare_domains = allow_bare_domain_vale.as_bool().unwrap(); - let allow_subdomains_vale = req.get_data("allow_subdomains")?; - let allow_subdomains = allow_subdomains_vale.as_bool().unwrap(); - let allow_any_name_vale = req.get_data("allow_any_name")?; - let allow_any_name = allow_any_name_vale.as_bool().unwrap(); - let allow_ip_sans_vale = req.get_data("allow_ip_sans")?; - let allow_ip_sans = allow_ip_sans_vale.as_bool().unwrap(); - let server_flag_vale = req.get_data("server_flag")?; - let server_flag = server_flag_vale.as_bool().unwrap(); - let client_flag_vale = req.get_data("client_flag")?; - let client_flag = client_flag_vale.as_bool().unwrap(); - let use_csr_sans_vale = req.get_data("use_csr_sans")?; - let use_csr_sans = use_csr_sans_vale.as_bool().unwrap(); - let use_csr_common_name_vale = req.get_data("use_csr_common_name")?; - let use_csr_common_name = use_csr_common_name_vale.as_bool().unwrap(); - let country_vale = req.get_data("country")?; - let country = country_vale.as_str().unwrap().to_string(); - let province_vale = req.get_data("province")?; - let province = province_vale.as_str().unwrap().to_string(); - let locality_vale = req.get_data("locality")?; - let locality = locality_vale.as_str().unwrap().to_string(); - let organization_vale = req.get_data("organization")?; - let organization = organization_vale.as_str().unwrap().to_string(); - let ou_vale = req.get_data("ou")?; - let ou = ou_vale.as_str().unwrap().to_string(); - let street_address_vale = req.get_data("street_address")?; - let street_address = street_address_vale.as_str().unwrap().to_string(); - let postal_code_vale = req.get_data("postal_code")?; - let postal_code = postal_code_vale.as_str().unwrap().to_string(); - let no_store_vale = req.get_data("no_store")?; - let no_store = no_store_vale.as_bool().unwrap(); - let generate_lease_vale = req.get_data("generate_lease")?; - let generate_lease = generate_lease_vale.as_bool().unwrap(); - let not_after_vale = req.get_data("not_after")?; - let not_after = not_after_vale.as_str().unwrap().to_string(); - let not_before_duration_vale = req.get_data("not_before_duration")?; - let not_before_duration = Duration::from_secs(not_before_duration_vale.as_u64().unwrap()); + let signature_bits = req.get_data("signature_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_localhost = req.get_data("allow_localhost")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_bare_domains = req.get_data("allow_bare_domains")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_subdomains = req.get_data("allow_subdomains")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_any_name = req.get_data("allow_any_name")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_ip_sans = req.get_data("allow_ip_sans")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let server_flag = req.get_data("server_flag")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let client_flag = req.get_data("client_flag")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let use_csr_sans = req.get_data("use_csr_sans")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let use_csr_common_name = req.get_data("use_csr_common_name")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let country = req.get_data("country")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let province = req.get_data("province")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let locality = req.get_data("locality")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let organization = req.get_data("organization")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let ou = req.get_data("ou")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let street_address = req.get_data("street_address")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let postal_code = req.get_data("postal_code")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let no_store = req.get_data("no_store")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let generate_lease = req.get_data("generate_lease")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let not_after = req.get_data("not_after")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let not_before_duration_u64 = req.get_data("not_before_duration")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + let not_before_duration = Duration::from_secs(not_before_duration_u64); let role_entry = RoleEntry { ttl, @@ -446,8 +415,8 @@ impl PkiBackendInner { } pub fn delete_path_role(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { - let name_vale = req.get_data("name")?; - let name = name_vale.as_str().unwrap(); + let name_value = req.get_data("name")?; + let name = name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; if name == "" { return Err(RvError::ErrRequestNoDataField); } diff --git a/src/modules/pki/path_root.rs b/src/modules/pki/path_root.rs index c74f5b4..e645163 100644 --- a/src/modules/pki/path_root.rs +++ b/src/modules/pki/path_root.rs @@ -46,9 +46,7 @@ impl PkiBackend { impl PkiBackendInner { pub fn generate_root(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let mut export_private_key = false; - let exported_vale = req.get_data("exported")?; - let exported = exported_vale.as_str().unwrap(); - if exported == "exported" { + if req.get_data("exported")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)? == "exported" { export_private_key = true; } diff --git a/src/modules/pki/util.rs b/src/modules/pki/util.rs index eb7fd62..7f6bdc9 100644 --- a/src/modules/pki/util.rs +++ b/src/modules/pki/util.rs @@ -7,17 +7,12 @@ use super::path_roles::RoleEntry; use crate::{errors::RvError, logical::Request, utils::cert::Certificate}; pub fn get_role_params(req: &mut Request) -> Result { - let ttl_vale = req.get_data("ttl")?; - let ttl = { - let ttl_str = ttl_vale.as_str().unwrap(); - parse_duration(ttl_str)? - }; - let not_before_duration_vale = req.get_data("not_before_duration")?; - let not_before_duration = Duration::from_secs(not_before_duration_vale.as_u64().unwrap()); - let key_type_vale = req.get_data("key_type")?; - let key_type = key_type_vale.as_str().unwrap(); - let key_bits_vale = req.get_data("key_bits")?; - let mut key_bits = key_bits_vale.as_u64().unwrap(); + let ttl = parse_duration(req.get_data("ttl")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; + let not_before_duration_u64 = req.get_data("not_before_duration")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + let not_before_duration = Duration::from_secs(not_before_duration_u64); + let key_type_value = req.get_data("key_type")?; + let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; + let mut key_bits = req.get_data("key_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; match key_type { "rsa" => { if key_bits == 0 { @@ -42,26 +37,16 @@ pub fn get_role_params(req: &mut Request) -> Result { } } - let signature_bits_vale = req.get_data("signature_bits")?; - let signature_bits = signature_bits_vale.as_u64().unwrap(); - let use_pss_value = req.get_data("use_pss")?; - let use_pss = use_pss_value.as_bool().unwrap(); - let country_vale = req.get_data("country")?; - let country = country_vale.as_str().unwrap().to_string(); - let province_vale = req.get_data("province")?; - let province = province_vale.as_str().unwrap().to_string(); - let locality_vale = req.get_data("locality")?; - let locality = locality_vale.as_str().unwrap().to_string(); - let organization_vale = req.get_data("organization")?; - let organization = organization_vale.as_str().unwrap().to_string(); - let ou_vale = req.get_data("ou")?; - let ou = ou_vale.as_str().unwrap().to_string(); - let street_address_vale = req.get_data("street_address")?; - let street_address = street_address_vale.as_str().unwrap().to_string(); - let postal_code_vale = req.get_data("postal_code")?; - let postal_code = postal_code_vale.as_str().unwrap().to_string(); - let not_after_vale = req.get_data("not_after")?; - let not_after = not_after_vale.as_str().unwrap().to_string(); + let signature_bits = req.get_data("signature_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + let use_pss = req.get_data("use_pss")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let country = req.get_data("country")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let province = req.get_data("province")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let locality = req.get_data("locality")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let organization = req.get_data("organization")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let ou = req.get_data("ou")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let street_address = req.get_data("street_address")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let postal_code = req.get_data("postal_code")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let not_after = req.get_data("not_after")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); let role_entry = RoleEntry { ttl, @@ -88,7 +73,7 @@ pub fn generate_certificate(role_entry: &RoleEntry, req: &mut Request) -> Result let mut common_names = Vec::new(); let common_name_value = req.get_data("common_name")?; - let common_name = common_name_value.as_str().unwrap(); + let common_name = common_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; if common_name != "" { common_names.push(common_name.to_string()); } diff --git a/src/storage/physical/mock.rs b/src/storage/physical/mock.rs index d568d7d..b543281 100644 --- a/src/storage/physical/mock.rs +++ b/src/storage/physical/mock.rs @@ -4,7 +4,7 @@ use super::{Backend, BackendEntry}; use crate::errors::RvError; #[derive(Default)] -pub struct MockBackend(u32); +pub struct MockBackend(()); impl Backend for MockBackend { fn list(&self, _prefix: &str) -> Result, RvError> { @@ -26,6 +26,6 @@ impl Backend for MockBackend { impl MockBackend { pub fn new() -> Self { - MockBackend(0) + MockBackend(()) } } diff --git a/src/utils/key.rs b/src/utils/key.rs index e6f8b97..9c86853 100644 --- a/src/utils/key.rs +++ b/src/utils/key.rs @@ -42,6 +42,25 @@ impl Default for KeyBundle { } } +fn cipher_from_key_type_and_bits(key_type: &str, bits: u32) -> Result { + match (key_type, bits) { + ("aes-gcm", 128) => Ok(Cipher::aes_128_gcm()), + ("aes-gcm", 192) => Ok(Cipher::aes_192_gcm()), + ("aes-gcm", 256) => Ok(Cipher::aes_256_gcm()), + ("aes-cbc", 128) => Ok(Cipher::aes_128_cbc()), + ("aes-cbc", 192) => Ok(Cipher::aes_192_cbc()), + ("aes-cbc", 256) => Ok(Cipher::aes_256_cbc()), + ("aes-ecb", 128) => Ok(Cipher::aes_128_ecb()), + ("aes-ecb", 192) => Ok(Cipher::aes_192_ecb()), + ("aes-ecb", 256) => Ok(Cipher::aes_256_ecb()), + #[cfg(tongsuo)] + ("sm4-gcm", 128) => Ok(Cipher::sm4_gcm()), + #[cfg(tongsuo)] + ("sm4-ccm", 128) => Ok(Cipher::sm4_ccm()), + _ => Err(RvError::ErrPkiKeyBitsInvalid), + } +} + impl KeyBundle { pub fn new(name: &str, key_type: &str, key_bits: u32) -> Self { Self { name: name.to_string(), key_type: key_type.to_string(), bits: key_bits, ..KeyBundle::default() } @@ -51,12 +70,13 @@ impl KeyBundle { let key_bits = self.bits; let priv_key = match self.key_type.as_str() { "rsa" => { - if key_bits != 2048 && key_bits != 3072 && key_bits != 4096 { - return Err(RvError::ErrPkiKeyBitsInvalid); + match key_bits { + 2048 | 3072 | 4096 => { + let rsa_key = Rsa::generate(key_bits)?; + PKey::from_rsa(rsa_key)?.private_key_to_pem_pkcs8()? + }, + _ => return Err(RvError::ErrPkiKeyBitsInvalid), } - let rsa_key = Rsa::generate(key_bits)?; - let pkey = PKey::from_rsa(rsa_key)?; - pkey.private_key_to_pem_pkcs8()? } "ec" => { let curve_name = match key_bits { @@ -64,30 +84,43 @@ impl KeyBundle { 256 => Nid::SECP256K1, 384 => Nid::SECP384R1, 521 => Nid::SECP521R1, - _ => { - return Err(RvError::ErrPkiKeyBitsInvalid); - } + _ => return Err(RvError::ErrPkiKeyBitsInvalid), }; let ec_group = EcGroup::from_curve_name(curve_name)?; - let ec_key = EcKey::generate(ec_group.as_ref())?; - let pkey = PKey::from_ec_key(ec_key)?; - pkey.private_key_to_pem_pkcs8()? - } - "aes-gcm" | "aes-cbc" | "aes-ecb" => { - if key_bits != 128 && key_bits != 192 && key_bits != 256 { - return Err(RvError::ErrPkiKeyBitsInvalid); + let ec_key = EcKey::generate(&ec_group)?; + PKey::from_ec_key(ec_key)?.private_key_to_pem_pkcs8()? + }, + #[cfg(tongsuo)] + "sm2" => { + self.bits = 256; + let ec_group = EcGroup::from_curve_name(Nid::SM2)?; + let ec_key = EcKey::generate(&ec_group)?; + PKey::from_ec_key(ec_key)?.private_key_to_pem_pkcs8()? + }, + "aes-gcm" | "aes-cbc" | "aes-ecb" | "sm4-gcm" | "sm4-ccm" => { + let _ = cipher_from_key_type_and_bits(self.key_type.as_str(), self.bits)?; + + #[cfg(not(tongsuo))] + if self.key_type.starts_with("sm4-") { + return Err(RvError::ErrPkiKeyTypeInvalid); } - if self.key_type.as_str() != "aes-ecb" { - let mut iv_bytes = vec![0u8; 16]; - rand_bytes(&mut iv_bytes)?; - self.iv = iv_bytes; + match self.key_type.as_str() { + "aes-ecb" => (), + "sm4-ccm" => { + self.iv = vec![0u8; 12]; + rand_bytes(&mut self.iv)?; + } + _ => { + self.iv = vec![0u8; 16]; + rand_bytes(&mut self.iv)?; + } } - let mut random_bytes = vec![0u8; (key_bits / 8) as usize]; - rand_bytes(&mut random_bytes)?; - random_bytes - } + let mut key = vec![0u8; key_bits as usize / 8]; + rand_bytes(&mut key)?; + key + }, _ => { return Err(RvError::ErrPkiKeyTypeInvalid); } @@ -99,95 +132,67 @@ impl KeyBundle { } pub fn sign(&self, data: &[u8]) -> Result, RvError> { - match self.key_type.as_str() { - "rsa" => { - let rsa = Rsa::private_key_from_pem(&self.key)?; - let pkey = PKey::from_rsa(rsa)?; - let mut signer = Signer::new(MessageDigest::sha256(), &pkey)?; - signer.set_rsa_padding(Padding::PKCS1)?; - signer.update(data)?; - return Ok(signer.sign_to_vec()?); - } - "ec" => { - let ec_key = EcKey::private_key_from_pem(&self.key)?; - let pkey = PKey::from_ec_key(ec_key)?; - let mut signer = Signer::new(MessageDigest::sha256(), &pkey)?; - signer.update(data)?; - return Ok(signer.sign_to_vec()?); - } - _ => { - return Err(RvError::ErrPkiKeyOperationInvalid); - } + let digest = match self.key_type.as_str() { + "rsa" | "ec" => MessageDigest::sha256(), + #[cfg(tongsuo)] + "sm2" => MessageDigest::sm3(), + _ => return Err(RvError::ErrPkiKeyOperationInvalid), + }; + + let pkey = PKey::private_key_from_pem(&self.key)?; + + let mut signer = Signer::new(digest, &pkey)?; + if self.key_type == "rsa" { + signer.set_rsa_padding(Padding::PKCS1)?; } + + signer.update(data)?; + signer.sign_to_vec().map_err(From::from) } pub fn verify(&self, data: &[u8], signature: &[u8]) -> Result { - match self.key_type.as_str() { - "rsa" => { - let rsa = Rsa::private_key_from_pem(&self.key)?; - let pkey = PKey::from_rsa(rsa)?; - let mut verifier = Verifier::new(MessageDigest::sha256(), &pkey)?; - verifier.set_rsa_padding(Padding::PKCS1)?; - verifier.update(data)?; - return Ok(verifier.verify(signature).unwrap_or(false)); - } - "ec" => { - let ec_key = EcKey::private_key_from_pem(&self.key)?; - let pkey = PKey::from_ec_key(ec_key)?; - let mut verifier = Verifier::new(MessageDigest::sha256(), &pkey)?; - verifier.update(data)?; - return Ok(verifier.verify(signature).unwrap_or(false)); - } - _ => { - return Err(RvError::ErrPkiKeyOperationInvalid); - } + let digest = match self.key_type.as_str() { + "rsa" | "ec" => MessageDigest::sha256(), + #[cfg(tongsuo)] + "sm2" => MessageDigest::sm3(), + _ => return Err(RvError::ErrPkiKeyOperationInvalid), + }; + + let pkey = PKey::private_key_from_pem(&self.key)?; + + let mut verifier = Verifier::new(digest, &pkey)?; + if self.key_type == "rsa" { + verifier.set_rsa_padding(Padding::PKCS1)?; } + + verifier.update(data)?; + Ok(verifier.verify(signature).unwrap_or(false)) } pub fn encrypt(&self, data: &[u8], extra: Option) -> Result, RvError> { match self.key_type.as_str() { - "aes-gcm" => { + "aes-gcm" | "sm4-gcm" | "sm4-ccm" => { + let cipher = cipher_from_key_type_and_bits(self.key_type.as_str(), self.bits)?; let aad = extra.map_or("".as_bytes(), |ex| match ex { EncryptExtraData::Aad(aad) => aad, _ => "".as_bytes(), }); - let cipher = match self.bits { - 128 => Cipher::aes_128_gcm(), - 192 => Cipher::aes_192_gcm(), - 256 => Cipher::aes_256_gcm(), - _ => { - return Err(RvError::ErrPkiKeyBitsInvalid); - } - }; let mut tag = vec![0u8; 16]; - let mut ciphertext = - encrypt_aead(cipher, &self.key, Some(&self.iv), aad, data, &mut tag)?; + let mut ciphertext = encrypt_aead( + cipher, + &self.key, + Some(&self.iv), + aad, + data, + &mut tag, + )?; ciphertext.extend_from_slice(&tag); Ok(ciphertext) } - "aes-cbc" => { - let cipher = match self.bits { - 128 => Cipher::aes_128_cbc(), - 192 => Cipher::aes_192_cbc(), - 256 => Cipher::aes_256_cbc(), - _ => { - return Err(RvError::ErrPkiKeyBitsInvalid); - } - }; - - Ok(encrypt(cipher, &self.key, Some(&self.iv), data)?) - } - "aes-ecb" => { - let cipher = match self.bits { - 128 => Cipher::aes_128_ecb(), - 192 => Cipher::aes_192_ecb(), - 256 => Cipher::aes_256_ecb(), - _ => { - return Err(RvError::ErrPkiKeyBitsInvalid); - } - }; - - Ok(encrypt(cipher, &self.key, None, data)?) + "aes-cbc" | "aes-ecb" => { + let cipher = cipher_from_key_type_and_bits(self.key_type.as_str(), self.bits)?; + let iv = if self.key_type == "aes-ecb" { None } else { Some(self.iv.as_slice()) }; + Ok(encrypt(cipher, &self.key, iv, data)?) } "rsa" => { let rsa = Rsa::private_key_from_pem(&self.key)?; @@ -209,54 +214,31 @@ impl KeyBundle { return Ok(buf); } - _ => { - return Err(RvError::ErrPkiKeyOperationInvalid); - } + _ => Err(RvError::ErrPkiKeyOperationInvalid), } } pub fn decrypt(&self, data: &[u8], extra: Option) -> Result, RvError> { + match self.key_type.as_str() { - "aes-gcm" => { + "aes-gcm" | "sm4-gcm" | "sm4-ccm" => { + let cipher = cipher_from_key_type_and_bits(self.key_type.as_str(), self.bits)?; let aad = extra.map_or("".as_bytes(), |ex| match ex { EncryptExtraData::Aad(aad) => aad, _ => "".as_bytes(), }); - let cipher = match self.bits { - 128 => Cipher::aes_128_gcm(), - 192 => Cipher::aes_192_gcm(), - 256 => Cipher::aes_256_gcm(), - _ => { - return Err(RvError::ErrPkiKeyBitsInvalid); - } - }; - let (ciphertext, tag) = data.split_at(data.len() - 16); + let tag_len = 16; + if data.len() < tag_len { + return Err(RvError::ErrPkiInternal); + } + let (ciphertext, tag) = data.split_at(data.len() - tag_len); Ok(decrypt_aead(cipher, &self.key, Some(&self.iv), aad, ciphertext, tag)?) - } - "aes-cbc" => { - let cipher = match self.bits { - 128 => Cipher::aes_128_cbc(), - 192 => Cipher::aes_192_cbc(), - 256 => Cipher::aes_256_cbc(), - _ => { - return Err(RvError::ErrPkiKeyBitsInvalid); - } - }; - - Ok(decrypt(cipher, &self.key, Some(&self.iv), data)?) - } - "aes-ecb" => { - let cipher = match self.bits { - 128 => Cipher::aes_128_ecb(), - 192 => Cipher::aes_192_ecb(), - 256 => Cipher::aes_256_ecb(), - _ => { - return Err(RvError::ErrPkiKeyBitsInvalid); - } - }; - - Ok(decrypt(cipher, &self.key, None, data)?) - } + }, + "aes-cbc" | "aes-ecb" => { + let cipher = cipher_from_key_type_and_bits(self.key_type.as_str(), self.bits)?; + let iv = if self.key_type == "aes-ecb" { None } else { Some(self.iv.as_slice()) }; + Ok(decrypt(cipher, &self.key, iv, data)?) + }, "rsa" => { let rsa = Rsa::private_key_from_pem(&self.key)?; if data.len() > rsa.size() as usize { @@ -284,9 +266,7 @@ impl KeyBundle { return Ok(buf); } - _ => { - return Err(RvError::ErrPkiKeyOperationInvalid); - } + _ => Err(RvError::ErrPkiKeyOperationInvalid), } } } @@ -348,6 +328,13 @@ mod test { test_key_sign_verify(&mut key_bundle); } + #[test] + #[cfg(tongsuo)] + fn test_sm2_key_operation() { + let mut key_bundle = KeyBundle::new("sm2", "sm2", 256); + test_key_sign_verify(&mut key_bundle); + } + #[test] fn test_aes_key_operation() { // test aes-gcm @@ -381,4 +368,20 @@ mod test { let mut key_bundle = KeyBundle::new("aes-ecb-256", "aes-ecb", 256); test_key_encrypt_decrypt(&mut key_bundle, None); } + + #[test] + #[cfg(tongsuo)] + fn test_sm4_key_operation() { + // test sm4-gcm + let mut key_bundle = KeyBundle::new("sm4-gcm-128", "sm4-gcm", 128); + test_key_encrypt_decrypt(&mut key_bundle, None); + test_key_encrypt_decrypt(&mut key_bundle, None); + test_key_encrypt_decrypt(&mut key_bundle, Some(EncryptExtraData::Aad("rusty_vault".as_bytes()))); + + // test sm4-ccm + let mut key_bundle = KeyBundle::new("sm4-ccm-128", "sm4-ccm", 128); + test_key_encrypt_decrypt(&mut key_bundle, None); + test_key_encrypt_decrypt(&mut key_bundle, None); + test_key_encrypt_decrypt(&mut key_bundle, Some(EncryptExtraData::Aad("rusty_vault".as_bytes()))); + } }