Skip to content

Commit

Permalink
Merge pull request #60 from wa5i/sm4
Browse files Browse the repository at this point in the history
Switch to rust-tongsuo, supporting SM2 and SM4 algorithms.
  • Loading branch information
InfoHunter authored May 10, 2024
2 parents 5afb067 + 6eac131 commit 7e8c441
Show file tree
Hide file tree
Showing 11 changed files with 249 additions and 286 deletions.
9 changes: 7 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ RustyVault's RESTful API is designed to be fully compatible with Hashicorp Vault
"""
repository = "https://github.com/Tongsuo-Project/RustyVault"
documentation = "https://docs.rs/rusty_vault/latest/rusty_vault/"
build = "build.rs"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand All @@ -24,8 +25,8 @@ serde_json = "^1.0"
serde_bytes = "0.11"
go-defer = "^0.1"
rand = "^0.8"
openssl = "0.10"
openssl-sys = "0.9.92"
openssl = { version = "0.10" }
openssl-sys = { version = "0.9" }
derivative = "2.2.0"
enum-map = "2.6.1"
strum = { version = "0.25", features = ["derive"] }
Expand Down Expand Up @@ -60,6 +61,10 @@ serde_asn1_der = "0.8"
base64 = "0.22"
ipnetwork = "0.20"

[patch.crates-io]
openssl = { git = "https://github.com/Tongsuo-Project/rust-tongsuo.git" }
openssl-sys = { git = "https://github.com/Tongsuo-Project/rust-tongsuo.git" }

[features]
storage_mysql = ["diesel", "r2d2", "r2d2-diesel"]

Expand Down
7 changes: 7 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
use std::env;

fn main() {
if let Ok(_) = env::var("DEP_OPENSSL_TONGSUO") {
println!("cargo:rustc-cfg=tongsuo");
}
}
2 changes: 1 addition & 1 deletion src/modules/pki/path_config_ca.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ For security reasons, you can only view the certificate when reading this endpoi
impl PkiBackendInner {
pub fn write_path_ca(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let pem_bundle_value = req.get_data("pem_bundle")?;
let pem_bundle = pem_bundle_value.as_str().unwrap();
let pem_bundle = pem_bundle_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let items = pem::parse_many(pem_bundle)?;
let mut key_found = false;
Expand Down
2 changes: 1 addition & 1 deletion src/modules/pki/path_fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl PkiBackendInner {

pub fn read_path_fetch_cert(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let serial_number_value = req.get_data("serial")?;
let serial_number = serial_number_value.as_str().unwrap();
let serial_number = serial_number_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let serial_number_hex = serial_number.replace(":", "-").to_lowercase();
let cert = self.fetch_cert(req, &serial_number_hex)?;
let ca_bundle = self.fetch_ca_bundle(req)?;
Expand Down
9 changes: 4 additions & 5 deletions src/modules/pki/path_issue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,12 @@ requested common name is allowed by the role policy.

impl PkiBackendInner {
pub fn issue_cert(&self, backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let role_value = req.get_data("role")?;
let role_name = role_value.as_str().unwrap();
//let role_name = req.get_data("role")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let mut common_names = Vec::new();

let common_name_value = req.get_data("common_name")?;
let common_name = common_name_value.as_str().unwrap();
let common_name = common_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
if common_name != "" {
common_names.push(common_name.to_string());
}
Expand All @@ -87,7 +86,7 @@ impl PkiBackendInner {
}
}

let role = self.get_role(req, &role_name)?;
let role = self.get_role(req, req.get_data("role")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;
if role.is_none() {
return Err(RvError::ErrPkiRoleNotFound);
}
Expand All @@ -111,7 +110,7 @@ impl PkiBackendInner {
let mut not_after = not_before + parse_duration("30d").unwrap();

let ttl_value = req.get_data("ttl")?;
let ttl = ttl_value.as_str().unwrap();
let ttl = ttl_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
if ttl != "" {
let ttl_dur = parse_duration(ttl)?;
let req_ttl_not_after_dur = SystemTime::now() + ttl_dur;
Expand Down
79 changes: 38 additions & 41 deletions src/modules/pki/path_keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,10 @@ used for sign,verify,encrypt,decrypt.
impl PkiBackendInner {
pub fn generate_key(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let key_name_value = req.get_data("key_name")?;
let key_name = key_name_value.as_str().unwrap();
let key_name = key_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let key_type_value = req.get_data("key_type")?;
let key_type = key_type_value.as_str().unwrap();
let key_bits_value = req.get_data("key_bits")?;
let key_bits = key_bits_value.as_u64().unwrap();
let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let key_bits = req.get_data("key_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?;

let mut export_private_key = false;
if req.path.ends_with("/exported") {
Expand Down Expand Up @@ -245,7 +244,7 @@ impl PkiBackendInner {

if export_private_key {
match key_type {
"rsa" | "ec" => {
"rsa" | "ec" | "sm2" => {
resp_data.insert(
"private_key".to_string(),
Value::String(String::from_utf8_lossy(&key_bundle.key).to_string()),
Expand All @@ -266,13 +265,13 @@ impl PkiBackendInner {

pub fn import_key(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let key_name_value = req.get_data("key_name")?;
let key_name = key_name_value.as_str().unwrap();
let key_name = key_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let key_type_value = req.get_data("key_type")?;
let key_type = key_type_value.as_str().unwrap();
let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let pem_bundle_value = req.get_data("pem_bundle")?;
let pem_bundle = pem_bundle_value.as_str().unwrap();
let pem_bundle = pem_bundle_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let hex_bundle_value = req.get_data("hex_bundle")?;
let hex_bundle = hex_bundle_value.as_str().unwrap();
let hex_bundle = hex_bundle_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

if pem_bundle.len() == 0 && hex_bundle.len() == 0 {
return Err(RvError::ErrRequestFieldNotFound);
Expand All @@ -292,7 +291,7 @@ impl PkiBackendInner {
let rsa = Rsa::private_key_from_pem(&key_bundle.key)?;
key_bundle.bits = rsa.size() * 8;
},
"ec" => {
"ec" | "sm2" => {
let ec_key = EcKey::private_key_from_pem(&key_bundle.key)?;
key_bundle.bits = ec_key.group().degree();
},
Expand All @@ -312,19 +311,25 @@ impl PkiBackendInner {
}
};
let iv_value = req.get_data("iv")?;
match key_type {
"aes-gcm" | "aes-cbc" => {
if let Some(iv) = iv_value.as_str() {
key_bundle.iv = hex::decode(&iv)?;
} else {
return Err(RvError::ErrRequestFieldNotFound);
}
},
"aes-ecb" => {},
_ => {
return Err(RvError::ErrPkiKeyTypeInvalid);
let is_iv_required = matches!(key_type, "aes-gcm" | "aes-cbc" | "sm4-gcm" | "sm4-ccm");
#[cfg(tongsuo)]
let is_valid_key_type = matches!(key_type, "aes-gcm" | "aes-cbc" | "aes-ecb" | "sm4-gcm" | "sm4-ccm");
#[cfg(not(tongsuo))]
let is_valid_key_type = matches!(key_type, "aes-gcm" | "aes-cbc" | "aes-ecb");

// Check if the key type is valid, if not return an error.
if !is_valid_key_type {
return Err(RvError::ErrPkiKeyTypeInvalid);
}

// Proceed to check IV only if required by the key type.
if is_iv_required {
if let Some(iv) = iv_value.as_str() {
key_bundle.iv = hex::decode(&iv)?;
} else {
return Err(RvError::ErrRequestFieldNotFound);
}
};
}
}

self.write_key(req, &key_bundle)?;
Expand All @@ -343,12 +348,10 @@ impl PkiBackendInner {
}

pub fn key_sign(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let key_name_value = req.get_data("key_name")?;
let key_name = key_name_value.as_str().unwrap();
let data_value = req.get_data("data")?;
let data = data_value.as_str().unwrap();
let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let key_bundle = self.fetch_key(req, key_name)?;
let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;

let decoded_data = hex::decode(data.as_bytes())?;
let result = key_bundle.sign(&decoded_data)?;
Expand All @@ -364,14 +367,12 @@ impl PkiBackendInner {
}

pub fn key_verify(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let key_name_value = req.get_data("key_name")?;
let key_name = key_name_value.as_str().unwrap();
let data_value = req.get_data("data")?;
let data = data_value.as_str().unwrap();
let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let signature_value = req.get_data("signature")?;
let signature = signature_value.as_str().unwrap();
let signature = signature_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let key_bundle = self.fetch_key(req, key_name)?;
let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;

let decoded_data = hex::decode(data.as_bytes())?;
let decoded_signature = hex::decode(signature.as_bytes())?;
Expand All @@ -388,14 +389,12 @@ impl PkiBackendInner {
}

pub fn key_encrypt(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let key_name_value = req.get_data("key_name")?;
let key_name = key_name_value.as_str().unwrap();
let data_value = req.get_data("data")?;
let data = data_value.as_str().unwrap();
let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let aad_value = req.get_data("aad")?;
let aad = aad_value.as_str().unwrap();
let aad = aad_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let key_bundle = self.fetch_key(req, key_name)?;
let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;

let decoded_data = hex::decode(data.as_bytes())?;
let result = key_bundle.encrypt(&decoded_data, Some(EncryptExtraData::Aad(aad.as_bytes())))?;
Expand All @@ -411,14 +410,12 @@ impl PkiBackendInner {
}

pub fn key_decrypt(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let key_name_value = req.get_data("key_name")?;
let key_name = key_name_value.as_str().unwrap();
let data_value = req.get_data("data")?;
let data = data_value.as_str().unwrap();
let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let aad_value = req.get_data("aad")?;
let aad = aad_value.as_str().unwrap();
let aad = aad_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let key_bundle = self.fetch_key(req, key_name)?;
let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;

let decoded_data = hex::decode(data.as_bytes())?;
let result = key_bundle.decrypt(&decoded_data, Some(EncryptExtraData::Aad(aad.as_bytes())))?;
Expand Down
95 changes: 32 additions & 63 deletions src/modules/pki/path_roles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,30 +318,19 @@ impl PkiBackendInner {
}

pub fn read_path_role(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let name_vale = req.get_data("name")?;
let name = name_vale.as_str().unwrap();
let role_entry = self.get_role(req, name)?;
let role_entry = self.get_role(req, req.get_data("name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;
let data = serde_json::to_value(&role_entry)?;
Ok(Some(Response::data_response(Some(data.as_object().unwrap().clone()))))
}

pub fn create_path_role(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let name_vale = req.get_data("name")?;
let name = name_vale.as_str().unwrap();
let ttl_vale = req.get_data("ttl")?;
let ttl = {
let ttl_str = ttl_vale.as_str().unwrap();
parse_duration(ttl_str)?
};
let max_ttl_vale = req.get_data("max_ttl")?;
let max_ttl = {
let max_ttl_str = max_ttl_vale.as_str().unwrap();
parse_duration(max_ttl_str)?
};
let key_type_vale = req.get_data("key_type")?;
let key_type = key_type_vale.as_str().unwrap();
let key_bits_vale = req.get_data("key_bits")?;
let mut key_bits = key_bits_vale.as_u64().unwrap();
let name_value = req.get_data("name")?;
let name = name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let ttl = parse_duration(req.get_data("ttl")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;
let max_ttl = parse_duration(req.get_data("max_ttl")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;
let key_type_value = req.get_data("key_type")?;
let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let mut key_bits = req.get_data("key_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?;
match key_type {
"rsa" => {
if key_bits == 0 {
Expand All @@ -366,48 +355,28 @@ impl PkiBackendInner {
}
}

let signature_bits_vale = req.get_data("signature_bits")?;
let signature_bits = signature_bits_vale.as_u64().unwrap();
let allow_localhost_vale = req.get_data("allow_localhost")?;
let allow_localhost = allow_localhost_vale.as_bool().unwrap();
let allow_bare_domain_vale = req.get_data("allow_bare_domains")?;
let allow_bare_domains = allow_bare_domain_vale.as_bool().unwrap();
let allow_subdomains_vale = req.get_data("allow_subdomains")?;
let allow_subdomains = allow_subdomains_vale.as_bool().unwrap();
let allow_any_name_vale = req.get_data("allow_any_name")?;
let allow_any_name = allow_any_name_vale.as_bool().unwrap();
let allow_ip_sans_vale = req.get_data("allow_ip_sans")?;
let allow_ip_sans = allow_ip_sans_vale.as_bool().unwrap();
let server_flag_vale = req.get_data("server_flag")?;
let server_flag = server_flag_vale.as_bool().unwrap();
let client_flag_vale = req.get_data("client_flag")?;
let client_flag = client_flag_vale.as_bool().unwrap();
let use_csr_sans_vale = req.get_data("use_csr_sans")?;
let use_csr_sans = use_csr_sans_vale.as_bool().unwrap();
let use_csr_common_name_vale = req.get_data("use_csr_common_name")?;
let use_csr_common_name = use_csr_common_name_vale.as_bool().unwrap();
let country_vale = req.get_data("country")?;
let country = country_vale.as_str().unwrap().to_string();
let province_vale = req.get_data("province")?;
let province = province_vale.as_str().unwrap().to_string();
let locality_vale = req.get_data("locality")?;
let locality = locality_vale.as_str().unwrap().to_string();
let organization_vale = req.get_data("organization")?;
let organization = organization_vale.as_str().unwrap().to_string();
let ou_vale = req.get_data("ou")?;
let ou = ou_vale.as_str().unwrap().to_string();
let street_address_vale = req.get_data("street_address")?;
let street_address = street_address_vale.as_str().unwrap().to_string();
let postal_code_vale = req.get_data("postal_code")?;
let postal_code = postal_code_vale.as_str().unwrap().to_string();
let no_store_vale = req.get_data("no_store")?;
let no_store = no_store_vale.as_bool().unwrap();
let generate_lease_vale = req.get_data("generate_lease")?;
let generate_lease = generate_lease_vale.as_bool().unwrap();
let not_after_vale = req.get_data("not_after")?;
let not_after = not_after_vale.as_str().unwrap().to_string();
let not_before_duration_vale = req.get_data("not_before_duration")?;
let not_before_duration = Duration::from_secs(not_before_duration_vale.as_u64().unwrap());
let signature_bits = req.get_data("signature_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?;
let allow_localhost = req.get_data("allow_localhost")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let allow_bare_domains = req.get_data("allow_bare_domains")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let allow_subdomains = req.get_data("allow_subdomains")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let allow_any_name = req.get_data("allow_any_name")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let allow_ip_sans = req.get_data("allow_ip_sans")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let server_flag = req.get_data("server_flag")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let client_flag = req.get_data("client_flag")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let use_csr_sans = req.get_data("use_csr_sans")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let use_csr_common_name = req.get_data("use_csr_common_name")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let country = req.get_data("country")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string();
let province = req.get_data("province")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string();
let locality = req.get_data("locality")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string();
let organization = req.get_data("organization")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string();
let ou = req.get_data("ou")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string();
let street_address = req.get_data("street_address")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string();
let postal_code = req.get_data("postal_code")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string();
let no_store = req.get_data("no_store")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let generate_lease = req.get_data("generate_lease")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let not_after = req.get_data("not_after")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string();
let not_before_duration_u64 = req.get_data("not_before_duration")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?;
let not_before_duration = Duration::from_secs(not_before_duration_u64);

let role_entry = RoleEntry {
ttl,
Expand Down Expand Up @@ -446,8 +415,8 @@ impl PkiBackendInner {
}

pub fn delete_path_role(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let name_vale = req.get_data("name")?;
let name = name_vale.as_str().unwrap();
let name_value = req.get_data("name")?;
let name = name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
if name == "" {
return Err(RvError::ErrRequestNoDataField);
}
Expand Down
4 changes: 1 addition & 3 deletions src/modules/pki/path_root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ impl PkiBackend {
impl PkiBackendInner {
pub fn generate_root(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let mut export_private_key = false;
let exported_vale = req.get_data("exported")?;
let exported = exported_vale.as_str().unwrap();
if exported == "exported" {
if req.get_data("exported")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)? == "exported" {
export_private_key = true;
}

Expand Down
Loading

0 comments on commit 7e8c441

Please sign in to comment.