From f28ece5c7adaa8b4547af51e8c255be2e9d0d9e2 Mon Sep 17 00:00:00 2001 From: Jin Jiu Date: Fri, 19 Apr 2024 11:08:35 +0800 Subject: [PATCH 1/4] Added the implementation of the salt tool. --- src/utils/mod.rs | 1 + src/utils/salt.rs | 205 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 206 insertions(+) create mode 100644 src/utils/salt.rs diff --git a/src/utils/mod.rs b/src/utils/mod.rs index f23150c..aa2c1a9 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -10,6 +10,7 @@ use crate::errors::RvError; pub mod cert; pub mod key; +pub mod salt; pub fn generate_uuid() -> String { let mut buf = [0u8; 16]; diff --git a/src/utils/salt.rs b/src/utils/salt.rs new file mode 100644 index 0000000..438300b --- /dev/null +++ b/src/utils/salt.rs @@ -0,0 +1,205 @@ +use openssl::{ + hash::{hash, MessageDigest}, + pkey::PKey, + nid::Nid, + sign::Signer, +}; +use derivative::Derivative; + +use super::{ + generate_uuid, +}; + +use crate::{ + storage::{Storage, StorageEntry}, + errors::RvError, +}; + +static DEFAULT_LOCATION: &str = "salt"; + +#[derive(Debug, Clone)] +pub struct Salt { + pub config: Config, + pub salt: String, + pub generated: bool, +} + +#[derive(Derivative)] +#[derivative(Debug, Clone)] +pub struct Config { + pub location: String, + #[derivative(Debug="ignore")] + pub hash_type: MessageDigest, + #[derivative(Debug="ignore")] + pub hmac_type: MessageDigest, +} + +impl Default for Salt { + fn default() -> Self { + Self { + salt: generate_uuid(), + generated: true, + config: Config::default(), + } + } +} + +impl Default for Config { + fn default() -> Self { + Self { + location: DEFAULT_LOCATION.to_string(), + hash_type: MessageDigest::sha256(), + hmac_type: MessageDigest::sha256(), + } + } +} + +impl Salt { + pub fn new(storage: Option<&dyn Storage>, config: Option<&Config>) -> Result { + let mut salt = Salt::default(); + if let Some(c) = config { + if salt.config.location != c.location && c.location != "" { + salt.config.location = c.location.clone(); + } + + if salt.config.hash_type != c.hash_type { + salt.config.hash_type = c.hash_type.clone(); + } + + if salt.config.hmac_type != c.hmac_type { + salt.config.hmac_type = c.hmac_type.clone(); + } + } + + if let Some(s) = storage { + if let Some(raw) = s.get(&salt.config.location)? { + salt.salt = String::from_utf8_lossy(&raw.value).to_string(); + salt.generated = false; + } else { + let entry = StorageEntry { + key: salt.config.location.clone(), + value: salt.salt.as_bytes().to_vec(), + }; + + s.put(&entry)?; + } + } + + Ok(salt) + } + + pub fn new_nonpersistent() -> Self { + let mut salt = Salt::default(); + salt.config.location = "".to_string(); + salt + } + + pub fn get_hmac(&self, data: &str) -> Result { + let pkey = PKey::hmac(self.salt.as_bytes())?; + let mut signer = Signer::new(self.config.hmac_type, &pkey)?; + signer.update(data.as_bytes())?; + let hmac = signer.sign_to_vec()?; + Ok(hex::encode(hmac.as_slice())) + } + + pub fn get_identified_hamc(&self, data: &str) -> Result { + let hmac_type = match self.config.hmac_type.type_() { + Nid::SHA256 => "hmac-sha256", + Nid::SM3 => "hmac-sm3", + Nid::MD5 => "hmac-md5", + _ => "hmac-unknown", + }; + + let hmac = self.get_hmac(data)?; + + Ok(format!("{}:{}", hmac_type, hmac)) + } + + pub fn get_hash(&self, data: &str) -> Result { + let ret = hash(self.config.hash_type, data.as_bytes())?; + let bytes = ret.to_vec(); + Ok(hex::encode(bytes.as_slice())) + } + + pub fn salt_id(&self, id: &str) -> Result { + let comb = format!("{}{}", self.salt, id); + self.get_hash(&comb) + } + + pub fn did_generate(&self) -> bool { + self.generated + } +} + +#[cfg(test)] +mod test { + use std::{collections::HashMap, env, fs, sync::Arc}; + use go_defer::defer; + use rand::{thread_rng, Rng}; + use serde_json::Value; + use super::*; + use crate::{ + storage::{ + physical, barrier_view, barrier_aes_gcm, + barrier::SecurityBarrier, + } + }; + + #[test] + fn test_salt() { + // init the storage + let dir = env::temp_dir().join("rusty_vault_test_salt"); + assert!(fs::create_dir(&dir).is_ok()); + defer! ( + assert!(fs::remove_dir_all(&dir).is_ok()); + ); + + let mut conf: HashMap = HashMap::new(); + conf.insert("path".to_string(), Value::String(dir.to_string_lossy().into_owned())); + + let mut key = vec![0u8; 32]; + thread_rng().fill(key.as_mut_slice()); + + let backend = physical::new_backend("file", &conf); + assert!(backend.is_ok()); + let backend = backend.unwrap(); + let aes_gcm_view = barrier_aes_gcm::AESGCMBarrier::new(Arc::clone(&backend)); + + let init = aes_gcm_view.init(key.as_slice()); + assert!(init.is_ok()); + + assert!(aes_gcm_view.unseal(key.as_slice()).is_ok()); + + let view = barrier_view::BarrierView::new(Arc::new(aes_gcm_view), "test"); + + //test salt + let salt = Salt::new(Some(view.as_storage()), None); + assert!(salt.is_ok()); + + let salt = salt.unwrap(); + assert!(salt.did_generate()); + + let ss = view.get(DEFAULT_LOCATION); + assert!(ss.is_ok()); + assert!(ss.unwrap().is_some()); + + let salt2 = Salt::new(Some(view.as_storage()), Some(&salt.config)); + assert!(salt2.is_ok()); + + let salt2 = salt2.unwrap(); + assert!(!salt2.did_generate()); + + assert_eq!(salt.salt, salt2.salt); + + let id = "foobarbaz"; + let sid1 = salt.salt_id(id); + let sid2 = salt2.salt_id(id); + assert!(sid1.is_ok()); + assert!(sid2.is_ok()); + + let sid1 = sid1.unwrap(); + let sid2 = sid2.unwrap(); + assert_eq!(sid1, sid2); + assert_eq!(sid1.len(), salt.config.hash_type.size()*2); + } +} From e7407ff000611a4b1e575a4acd1c7af69ea53351 Mon Sep 17 00:00:00 2001 From: Jin Jiu Date: Fri, 19 Apr 2024 18:29:55 +0800 Subject: [PATCH 2/4] Added the implementation of the CIDR tool. --- Cargo.toml | 1 + src/errors.rs | 11 ++ src/utils/cidr.rs | 306 ++++++++++++++++++++++++++++++++++++++++++ src/utils/ipaddr.rs | 74 ++++++++++ src/utils/mod.rs | 4 + src/utils/sockaddr.rs | 78 +++++++++++ src/utils/unixsock.rs | 51 +++++++ 7 files changed, 525 insertions(+) create mode 100644 src/utils/cidr.rs create mode 100644 src/utils/ipaddr.rs create mode 100644 src/utils/sockaddr.rs create mode 100644 src/utils/unixsock.rs diff --git a/Cargo.toml b/Cargo.toml index 298044f..a85cd73 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ ureq = "2.9" glob = "0.3" serde_asn1_der = "0.8" base64 = "0.22" +ipnetwork = "0.20" [target.'cfg(unix)'.dependencies] daemonize = "0.5" diff --git a/src/errors.rs b/src/errors.rs index fc3d01c..80ca54a 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -222,6 +222,17 @@ pub enum RvError { #[error("RwLock was poisoned (writing)")] ErrRwLockWritePoison, + #[error("Some net addr parse error happened, {:?}", .source)] + AddrParseError { + #[from] + source: std::net::AddrParseError, + }, + #[error("Some ipnetwork error happened, {:?}", .source)] + IpNetworkError { + #[from] + source: ipnetwork::IpNetworkError, + }, + /// Database Errors Begin /// #[error("Database type is not support now. Please try postgressql or mysql again.")] diff --git a/src/utils/cidr.rs b/src/utils/cidr.rs new file mode 100644 index 0000000..79f59b5 --- /dev/null +++ b/src/utils/cidr.rs @@ -0,0 +1,306 @@ +use std::{ + str::FromStr, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + collections::HashSet, +}; + +use ipnetwork::IpNetwork; + +use super::{ + sockaddr::{new_sock_addr, SockAddrType, SockAddr}, +}; + +use crate::errors::RvError; + +pub fn is_ip_addr(addr: &dyn SockAddr) -> bool { + (addr.sock_addr_type() as u8 & SockAddrType::IP as u8) != 0 +} + +pub fn remote_addr_is_ok(remote_addr: &str, bound_cidrs: &[Box]) -> bool { + if bound_cidrs.len() == 0 { + return true; + } + + if let Ok(addr) = new_sock_addr(remote_addr) { + for cidr in bound_cidrs.iter() { + if is_ip_addr(cidr.as_ref()) && cidr.contains(addr.as_ref()) { + return true; + } + } + } + + false +} + +pub fn ip_belongs_to_cidr(ip_addr: &str, cidr: &str) -> Result { + if ip_addr == "" { + return Err(RvError::ErrResponse("missing IP address".to_string())); + } + + let ip = IpAddr::from_str(ip_addr)?; + let ipnet = IpNetwork::from_str(cidr)?; + + Ok(ipnet.contains(ip)) +} + +pub fn ip_belongs_to_cidrs(ip_addr: &str, cidrs: &[&str]) -> Result { + if ip_addr == "" { + return Err(RvError::ErrResponse("missing IP address".to_string())); + } + + if cidrs.len() == 0 { + return Err(RvError::ErrResponse("missing CIDR blocks to be checked against".to_string())); + } + + for cidr in cidrs.iter() { + if ip_belongs_to_cidr(ip_addr, cidr)? { + return Ok(true); + } + } + + Ok(false) +} + +pub fn validate_cidr_string(cidr_list: &str, separator: &str) -> Result { + if cidr_list == "" { + return Err(RvError::ErrResponse("missing CIDR list that needs validation".to_string())); + } + + if separator == "" { + return Err(RvError::ErrResponse("missing separator".to_string())); + } + + let cidrs_set: HashSet<&str> = cidr_list.split(separator) + .map(|cidr| cidr.trim()) + .filter(|cidr| !cidr.is_empty()) + .collect(); + + let cidrs: Vec<&str> = cidrs_set.into_iter().collect(); + + validate_cidrs(&cidrs) +} + +pub fn validate_cidrs(cidrs: &[&str]) -> Result { + if cidrs.len() == 0 { + return Err(RvError::ErrResponse("missing CIDR blocks that needs validation".to_string())); + } + + for cidr in cidrs.iter() { + let _ = IpNetwork::from_str(cidr)?; + } + + Ok(true) +} + +pub fn subset(cidr1: &str, cidr2: &str) -> Result { + if cidr1 == "" { + return Err(RvError::ErrResponse("missing CIDR to be checked against".to_string())); + } + + if cidr2 == "" { + return Err(RvError::ErrResponse("missing CIDR that needs to be checked".to_string())); + } + + let ipnet1 = IpNetwork::from_str(cidr1)?; + let mask_len1 = ipnet1.prefix(); + + if !is_ip_addr_zero(&ipnet1.ip()) && mask_len1 == 0 { + return Err(RvError::ErrResponse("CIDR to be checked against is not in its canonical form".to_string())); + } + + let ipnet2 = IpNetwork::from_str(cidr2)?; + let mask_len2 = ipnet2.prefix(); + + if !is_ip_addr_zero(&ipnet2.ip()) && mask_len2 == 0 { + return Err(RvError::ErrResponse("CIDR that needs to be checked is not in its canonical form".to_string())); + } + + /* + * If the mask length of the CIDR that needs to be checked is smaller + * then the mask length of the CIDR to be checked against, then the + * former will encompass more IPs than the latter, and hence can't be a + * subset of the latter. + */ + if mask_len2 < mask_len1 { + return Ok(false); + } + + Ok(ipnet1.contains(ipnet2.ip())) +} + +/* + * subset_blocks checks if each CIDR block of a given set of CIDR blocks, is a + * subset of at least one CIDR block belonging to another set of CIDR blocks. + * First parameter is the set of CIDR blocks to check against and the second + * parameter is the set of CIDR blocks that needs to be checked. + */ +pub fn subset_blocks(cidr_blocks1: &[&str], cidr_blocks2: &[&str]) -> Result { + if cidr_blocks1.len() == 0 { + return Err(RvError::ErrResponse("missing CIDR blocks to be checked against".to_string())); + } + + if cidr_blocks2.len() == 0 { + return Err(RvError::ErrResponse("missing CIDR blocks that needs to be checked".to_string())); + } + + // Check if all the elements of cidr_blocks2 is a subset of at least one + // element of cidr_blocks1 + for cidr_block2 in cidr_blocks2.iter() { + let mut is_subset = false; + for cidr_block1 in cidr_blocks1.iter() { + if subset(cidr_block1, cidr_block2)? { + is_subset = true; + break; + } + } + + if !is_subset { + return Ok(false); + } + } + + Ok(true) +} + +fn is_ip_addr_zero(ip_addr: &IpAddr) -> bool { + match *ip_addr { + IpAddr::V4(addr) => addr == Ipv4Addr::UNSPECIFIED, + IpAddr::V6(addr) => addr == Ipv6Addr::UNSPECIFIED, + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_cidr_ip_belongs_to_cidr() { + let ip = "192.168.25.30"; + let cidr = "192.168.25.30/16"; + let belongs = ip_belongs_to_cidr(ip, cidr); + assert!(belongs.is_ok()); + assert!(belongs.unwrap()); + + let ip = "10.197.192.6"; + let cidr = "10.197.192.0/18"; + let belongs = ip_belongs_to_cidr(ip, cidr); + assert!(belongs.is_ok()); + assert!(belongs.unwrap()); + + let ip = "192.168.25.30"; + let cidr = "192.168.26.30/24"; + let belongs = ip_belongs_to_cidr(ip, cidr); + assert!(belongs.is_ok()); + assert!(!belongs.unwrap()); + + let ip = "192.168.25.30.100"; + let cidr = "192.168.26.30/24"; + let belongs = ip_belongs_to_cidr(ip, cidr); + assert!(belongs.is_err()); + } + + #[test] + fn test_cidr_ip_belongs_to_cidrs() { + let ip = "192.168.27.29"; + let cidrs = vec!["172.169.100.200/18", "192.168.0.0/16", "10.10.20.20/24"]; + let belongs = ip_belongs_to_cidrs(ip, &cidrs); + assert!(belongs.is_ok()); + assert!(belongs.unwrap()); + + let ip = "192.168.27.29"; + let cidrs = vec!["172.169.100.200/18", "192.168.0.0.0/16", "10.10.20.20/24"]; + let belongs = ip_belongs_to_cidrs(ip, &cidrs); + assert!(belongs.is_err()); + + let ip = "30.40.50.60"; + let cidrs = vec!["172.169.100.200/18", "192.168.0.0/16", "10.10.20.20/24"]; + let belongs = ip_belongs_to_cidrs(ip, &cidrs); + assert!(belongs.is_ok()); + assert!(!belongs.unwrap()); + } + + #[test] + fn test_cidr_validate_cidr_string() { + let cidr = "172.169.100.200/18,192.168.0.0/16,10.10.20.20/24"; + let valid = validate_cidr_string(cidr, ","); + assert!(valid.is_ok()); + assert!(valid.unwrap()); + + let cidr = "172.169.100.200,192.168.0.0/16,10.10.20.20/24"; + let valid = validate_cidr_string(cidr, ","); + assert!(valid.is_ok()); + assert!(valid.unwrap()); + + let cidr = "172.169.100.200/18,192.168.0.0.0/16,10.10.20.20/24"; + let valid = validate_cidr_string(cidr, ","); + assert!(valid.is_err()); + } + + #[test] + fn test_cidr_validate_cidrs() { + let cidrs = vec!["172.169.100.200/18", "192.168.0.0/16", "10.10.20.20/24"]; + let valid = validate_cidrs(&cidrs); + assert!(valid.is_ok()); + assert!(valid.unwrap()); + + let cidrs = vec!["172.169.100.200", "192.168.0.0/16", "10.10.20.20/24"]; + let valid = validate_cidrs(&cidrs); + assert!(valid.is_ok()); + assert!(valid.unwrap()); + + let cidrs = vec!["172.169.100.200/18", "192.168.0.0.0/16", "10.10.20.20/24"]; + let valid = validate_cidrs(&cidrs); + assert!(valid.is_err()); + } + + #[test] + fn test_cidr_subset() { + let cidr1 = "192.168.27.29/24"; + let cidr2 = "192.168.27.29/24"; + let ret = subset(cidr1, cidr2); + assert!(ret.is_ok()); + assert!(ret.unwrap()); + + let cidr1 = "192.168.27.29/16"; + let cidr2 = "192.168.27.29/24"; + let ret = subset(cidr1, cidr2); + assert!(ret.is_ok()); + assert!(ret.unwrap()); + let ret = subset(cidr2, cidr1); + assert!(ret.is_ok()); + assert!(!ret.unwrap()); + + let cidr1 = "192.168.0.128/25"; + let cidr2 = "192.168.0.0/24"; + let ret = subset(cidr1, cidr2); + assert!(ret.is_ok()); + assert!(!ret.unwrap()); + let ret = subset(cidr2, cidr1); + assert!(ret.is_ok()); + assert!(ret.unwrap()); + } + + #[test] + fn test_cidr_subset_blocks() { + let cidrs1 = vec!["192.168.27.29/16", "172.245.30.40/24", "10.20.30.40/30"]; + let cidrs2 = vec!["192.168.27.29/20", "172.245.30.40/25", "10.20.30.40/32"]; + let ret = subset_blocks(&cidrs1, &cidrs2); + assert!(ret.is_ok()); + assert!(ret.unwrap()); + + let cidrs1 = vec!["192.168.27.29/16", "172.245.30.40/25", "10.20.30.40/30"]; + let cidrs2 = vec!["192.168.27.29/20", "172.245.30.40/24", "10.20.30.40/32"]; + let ret = subset_blocks(&cidrs1, &cidrs2); + assert!(ret.is_ok()); + assert!(!ret.unwrap()); + } + + #[test] + fn test_cidr_remote_addr_is_ok() { + let addr = new_sock_addr("127.0.0.1/8"); + assert!(addr.is_ok()); + let bound_cidrs = vec![addr.unwrap()]; + assert!(!remote_addr_is_ok("123.0.0.1", &bound_cidrs)); + assert!(remote_addr_is_ok("127.0.0.1", &bound_cidrs)); + } +} diff --git a/src/utils/ipaddr.rs b/src/utils/ipaddr.rs new file mode 100644 index 0000000..346e876 --- /dev/null +++ b/src/utils/ipaddr.rs @@ -0,0 +1,74 @@ +use std::{ + fmt, + str::FromStr, + net::SocketAddr, +}; + +use as_any::Downcast; +use ipnetwork::IpNetwork; +use serde::{Deserialize, Serialize}; + +use super::{ + sockaddr::{SockAddr, SockAddrType}, +}; + +use crate::errors::RvError; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct IpAddr { + pub addr: IpNetwork, + pub port: u16, +} + +impl IpAddr { + pub fn new(s: &str) -> Result { + if let Ok(sock_addr) = SocketAddr::from_str(s) { + return Ok(IpAddr { + addr: IpNetwork::from(sock_addr.ip()), + port: sock_addr.port(), + }); + } else if let Ok(ip_addr) = IpNetwork::from_str(s) { + return Ok(IpAddr { + addr: ip_addr, + port: 0, + }); + } + return Err(RvError::ErrResponse(format!("Unable to parse {} to an IP address:", s))); + } +} + +impl SockAddr for IpAddr { + fn contains(&self, other: &dyn SockAddr) -> bool { + if let Some(ip_addr) = other.downcast_ref::() { + return self.addr.contains(ip_addr.addr.ip()); + } + + false + } + + fn equal(&self, other: &dyn SockAddr) -> bool { + if let Some(ip_addr) = other.downcast_ref::() { + return self.addr == ip_addr.addr && self.port == ip_addr.port; + } + + false + } + + fn sock_addr_type(&self) -> SockAddrType { + match self.addr { + IpNetwork::V4(_) => SockAddrType::IPv4, + IpNetwork::V6(_) => SockAddrType::IPv6, + } + } +} + +impl fmt::Display for IpAddr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.port == 0 { + write!(f, "{}", self.addr.ip()) + } else { + write!(f, "{}:{}", self.addr.ip(), self.port) + } + } +} + diff --git a/src/utils/mod.rs b/src/utils/mod.rs index aa2c1a9..ff4eca2 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -11,6 +11,10 @@ use crate::errors::RvError; pub mod cert; pub mod key; pub mod salt; +pub mod cidr; +pub mod sockaddr; +pub mod ipaddr; +pub mod unixsock; pub fn generate_uuid() -> String { let mut buf = [0u8; 16]; diff --git a/src/utils/sockaddr.rs b/src/utils/sockaddr.rs new file mode 100644 index 0000000..907ef6a --- /dev/null +++ b/src/utils/sockaddr.rs @@ -0,0 +1,78 @@ +use std::{ + fmt, + str::FromStr, +}; + +use as_any::AsAny; +use serde::{Deserialize, Serialize}; + +use super::{ + ipaddr::IpAddr, +}; + +use crate::errors::RvError; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SockAddrType { + Unknown = 0x0, + Unix = 0x1, + IPv4 = 0x2, + IPv6 = 0x4, + // IP is the union of IPv4 and IPv6 + IP = 0x6, +} + +pub trait SockAddr: fmt::Display + AsAny { + // contains returns true if the other SockAddr is contained within the receiver + fn contains(&self, other: &dyn SockAddr) -> bool; + + // equal allows for the comparison of two SockAddrs + fn equal(&self, other: &dyn SockAddr) -> bool; + + fn sock_addr_type(&self) -> SockAddrType; +} + +pub struct SockAddrMarshaler { + pub sock_addr: Box, +} + +impl SockAddrMarshaler { + pub fn new(sock_addr: Box) -> Self { + SockAddrMarshaler { sock_addr } + } +} + +impl fmt::Display for SockAddrMarshaler { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.sock_addr) + } +} + +impl fmt::Display for SockAddrType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let type_str = match self { + SockAddrType::IPv4 => "IPv4", + SockAddrType::IPv6 => "IPv6", + SockAddrType::Unix => "Unix", + _ => "Unknown", + }; + write!(f, "{}", type_str) + } +} + +impl FromStr for SockAddrType { + type Err = RvError; + fn from_str(s: &str) -> Result { + match s { + "IPv4" | "ipv4" => Ok(SockAddrType::IPv4), + "IPv6" | "ipv6" => Ok(SockAddrType::IPv6), + "Unix" | "UNIX" | "unix" => Ok(SockAddrType::Unix), + _ => Err(RvError::ErrResponse("invalid sockaddr type".to_string())) + } + } +} + +pub fn new_sock_addr(s: &str) -> Result, RvError> { + let ret = IpAddr::new(s)?; + Ok(Box::new(ret)) +} diff --git a/src/utils/unixsock.rs b/src/utils/unixsock.rs new file mode 100644 index 0000000..88a1034 --- /dev/null +++ b/src/utils/unixsock.rs @@ -0,0 +1,51 @@ +use std::fmt; +use as_any::Downcast; +use serde::{Deserialize, Serialize}; + +use super::{ + sockaddr::{SockAddr, SockAddrType}, +}; + +use crate::errors::RvError; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct UnixSock { + pub path: String, +} + +impl UnixSock { + pub fn new(s: &str) -> Result { + Ok(Self { + path: s.to_string(), + }) + } +} + +impl SockAddr for UnixSock { + fn contains(&self, other: &dyn SockAddr) -> bool { + if let Some(unix_sock) = other.downcast_ref::() { + return self.path == unix_sock.path; + } + + false + } + + fn equal(&self, other: &dyn SockAddr) -> bool { + if let Some(unix_sock) = other.downcast_ref::() { + return self.path == unix_sock.path; + } + + false + } + + fn sock_addr_type(&self) -> SockAddrType { + SockAddrType::Unix + } +} + +impl fmt::Display for UnixSock { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.path) + } +} + From e5a516c36d5571d836a7e9f707ff4f530ee6980a Mon Sep 17 00:00:00 2001 From: Jin Jiu Date: Thu, 9 May 2024 11:48:37 +0800 Subject: [PATCH 3/4] Fixed issues mentioned in PR review, and added test case for the sock_addr utils. --- src/utils/cidr.rs | 5 +- src/utils/ip_sock_addr.rs | 119 +++++++++++++++++++++++++++++ src/utils/ipaddr.rs | 74 ------------------ src/utils/mod.rs | 6 +- src/utils/salt.rs | 3 + src/utils/sock_addr.rs | 144 ++++++++++++++++++++++++++++++++++++ src/utils/sockaddr.rs | 78 ------------------- src/utils/unix_sock_addr.rs | 73 ++++++++++++++++++ src/utils/unixsock.rs | 51 ------------- 9 files changed, 346 insertions(+), 207 deletions(-) create mode 100644 src/utils/ip_sock_addr.rs delete mode 100644 src/utils/ipaddr.rs create mode 100644 src/utils/sock_addr.rs delete mode 100644 src/utils/sockaddr.rs create mode 100644 src/utils/unix_sock_addr.rs delete mode 100644 src/utils/unixsock.rs diff --git a/src/utils/cidr.rs b/src/utils/cidr.rs index 79f59b5..6be607b 100644 --- a/src/utils/cidr.rs +++ b/src/utils/cidr.rs @@ -1,3 +1,6 @@ +//! This module is a Rust replica of +//! https://github.com/hashicorp/vault/blob/main/sdk/helper/cidrutil/cidr.go + use std::{ str::FromStr, net::{IpAddr, Ipv4Addr, Ipv6Addr}, @@ -7,7 +10,7 @@ use std::{ use ipnetwork::IpNetwork; use super::{ - sockaddr::{new_sock_addr, SockAddrType, SockAddr}, + sock_addr::{new_sock_addr, SockAddrType, SockAddr}, }; use crate::errors::RvError; diff --git a/src/utils/ip_sock_addr.rs b/src/utils/ip_sock_addr.rs new file mode 100644 index 0000000..9e1e0da --- /dev/null +++ b/src/utils/ip_sock_addr.rs @@ -0,0 +1,119 @@ +//! This module is a Rust replica of +//! https://github.com/hashicorp/go-sockaddr/blob/master/ipv4addr.go + +use std::{ + fmt, + str::FromStr, + net::SocketAddr, +}; + +use as_any::Downcast; +use ipnetwork::IpNetwork; +use serde::{Deserialize, Serialize}; + +use super::{ + sock_addr::{SockAddr, SockAddrType}, +}; + +use crate::errors::RvError; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct IpSockAddr { + pub addr: IpNetwork, + pub port: u16, +} + +impl IpSockAddr { + pub fn new(s: &str) -> Result { + if let Ok(sock_addr) = SocketAddr::from_str(s) { + return Ok(IpSockAddr { + addr: IpNetwork::from(sock_addr.ip()), + port: sock_addr.port(), + }); + } else if let Ok(ip_addr) = IpNetwork::from_str(s) { + return Ok(IpSockAddr { + addr: ip_addr, + port: 0, + }); + } + return Err(RvError::ErrResponse(format!("Unable to parse {} to an IP address:", s))); + } +} + +impl SockAddr for IpSockAddr { + fn contains(&self, other: &dyn SockAddr) -> bool { + if let Some(ip_addr) = other.downcast_ref::() { + return self.addr.contains(ip_addr.addr.ip()); + } + + false + } + + fn equal(&self, other: &dyn SockAddr) -> bool { + if let Some(ip_addr) = other.downcast_ref::() { + return self.addr == ip_addr.addr && self.port == ip_addr.port; + } + + false + } + + fn sock_addr_type(&self) -> SockAddrType { + match self.addr { + IpNetwork::V4(_) => SockAddrType::IPv4, + IpNetwork::V6(_) => SockAddrType::IPv6, + } + } +} + +impl fmt::Display for IpSockAddr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.port == 0 { + write!(f, "{}", self.addr.ip()) + } else { + write!(f, "{}:{}", self.addr.ip(), self.port) + } + } +} + +#[cfg(test)] +mod test { + use super::{ + *, super::sock_addr::{SockAddrType}, + }; + + #[test] + fn test_ip_sock_addr() { + let addr1 = IpSockAddr::new("1.1.1.1").unwrap(); + let addr2 = IpSockAddr::new("1.1.1.1").unwrap(); + let addr3 = IpSockAddr::new("2.2.2.2").unwrap(); + let addr4 = IpSockAddr::new("333.333.333.333"); + let addr5 = IpSockAddr::new("1.1.1.1:80").unwrap(); + let addr6 = IpSockAddr::new("1.1.1.1:80").unwrap(); + let addr7 = IpSockAddr::new("1.1.1.1:8080").unwrap(); + let addr8 = IpSockAddr::new("2.2.2.2:80").unwrap(); + let addr9 = IpSockAddr::new("192.168.0.0/16").unwrap(); + let addr10 = IpSockAddr::new("192.168.0.0/24").unwrap(); + let addr11 = IpSockAddr::new("192.168.0.1").unwrap(); + let addr12 = IpSockAddr::new("192.168.1.1").unwrap(); + + assert!(addr4.is_err()); + assert!(addr1.contains(&addr2)); + assert!(addr1.equal(&addr2)); + assert!(!addr1.contains(&addr3)); + assert!(!addr1.equal(&addr3)); + assert_eq!(addr1.sock_addr_type(), SockAddrType::IPv4); + assert!(addr5.contains(&addr6)); + assert!(addr5.equal(&addr6)); + assert!(!addr5.equal(&addr7)); + assert!(!addr5.equal(&addr8)); + assert!(addr9.contains(&addr10)); + assert!(addr9.contains(&addr11)); + assert!(addr9.contains(&addr12)); + assert!(!addr9.contains(&addr1)); + assert!(addr10.contains(&addr9)); + assert!(addr10.contains(&addr11)); + assert!(!addr10.contains(&addr12)); + assert!(!addr9.equal(&addr10)); + assert!(!addr9.equal(&addr11)); + } +} diff --git a/src/utils/ipaddr.rs b/src/utils/ipaddr.rs deleted file mode 100644 index 346e876..0000000 --- a/src/utils/ipaddr.rs +++ /dev/null @@ -1,74 +0,0 @@ -use std::{ - fmt, - str::FromStr, - net::SocketAddr, -}; - -use as_any::Downcast; -use ipnetwork::IpNetwork; -use serde::{Deserialize, Serialize}; - -use super::{ - sockaddr::{SockAddr, SockAddrType}, -}; - -use crate::errors::RvError; - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct IpAddr { - pub addr: IpNetwork, - pub port: u16, -} - -impl IpAddr { - pub fn new(s: &str) -> Result { - if let Ok(sock_addr) = SocketAddr::from_str(s) { - return Ok(IpAddr { - addr: IpNetwork::from(sock_addr.ip()), - port: sock_addr.port(), - }); - } else if let Ok(ip_addr) = IpNetwork::from_str(s) { - return Ok(IpAddr { - addr: ip_addr, - port: 0, - }); - } - return Err(RvError::ErrResponse(format!("Unable to parse {} to an IP address:", s))); - } -} - -impl SockAddr for IpAddr { - fn contains(&self, other: &dyn SockAddr) -> bool { - if let Some(ip_addr) = other.downcast_ref::() { - return self.addr.contains(ip_addr.addr.ip()); - } - - false - } - - fn equal(&self, other: &dyn SockAddr) -> bool { - if let Some(ip_addr) = other.downcast_ref::() { - return self.addr == ip_addr.addr && self.port == ip_addr.port; - } - - false - } - - fn sock_addr_type(&self) -> SockAddrType { - match self.addr { - IpNetwork::V4(_) => SockAddrType::IPv4, - IpNetwork::V6(_) => SockAddrType::IPv6, - } - } -} - -impl fmt::Display for IpAddr { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if self.port == 0 { - write!(f, "{}", self.addr.ip()) - } else { - write!(f, "{}:{}", self.addr.ip(), self.port) - } - } -} - diff --git a/src/utils/mod.rs b/src/utils/mod.rs index ff4eca2..7997a67 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -12,9 +12,9 @@ pub mod cert; pub mod key; pub mod salt; pub mod cidr; -pub mod sockaddr; -pub mod ipaddr; -pub mod unixsock; +pub mod sock_addr; +pub mod ip_sock_addr; +pub mod unix_sock_addr; pub fn generate_uuid() -> String { let mut buf = [0u8; 16]; diff --git a/src/utils/salt.rs b/src/utils/salt.rs index 438300b..ed5962a 100644 --- a/src/utils/salt.rs +++ b/src/utils/salt.rs @@ -1,3 +1,6 @@ +//! This module is a Rust replica of +//! https://github.com/hashicorp/vault/blob/main/sdk/helper/salt/salt.go + use openssl::{ hash::{hash, MessageDigest}, pkey::PKey, diff --git a/src/utils/sock_addr.rs b/src/utils/sock_addr.rs new file mode 100644 index 0000000..63574ca --- /dev/null +++ b/src/utils/sock_addr.rs @@ -0,0 +1,144 @@ +//! This module is a Rust replica of +//! https://github.com/hashicorp/go-sockaddr/blob/master/sockaddr.go + +use std::{ + fmt, + str::FromStr, +}; + +use as_any::AsAny; +use serde::{Deserialize, Serialize}; + +use super::{ + ip_sock_addr::IpSockAddr, +}; + +use crate::errors::RvError; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SockAddrType { + Unknown = 0x0, + Unix = 0x1, + IPv4 = 0x2, + IPv6 = 0x4, + // IP is the union of IPv4 and IPv6 + IP = 0x6, +} + +pub trait SockAddr: fmt::Display + AsAny { + // contains returns true if the other SockAddr is contained within the receiver + fn contains(&self, other: &dyn SockAddr) -> bool; + + // equal allows for the comparison of two SockAddrs + fn equal(&self, other: &dyn SockAddr) -> bool; + + fn sock_addr_type(&self) -> SockAddrType; +} + +pub struct SockAddrMarshaler { + pub sock_addr: Box, +} + +impl SockAddrMarshaler { + pub fn new(sock_addr: Box) -> Self { + SockAddrMarshaler { sock_addr } + } +} + +impl fmt::Display for SockAddrMarshaler { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.sock_addr) + } +} + +impl fmt::Display for SockAddrType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let type_str = match self { + SockAddrType::IPv4 => "IPv4", + SockAddrType::IPv6 => "IPv6", + SockAddrType::Unix => "Unix", + _ => "Unknown", + }; + write!(f, "{}", type_str) + } +} + +impl FromStr for SockAddrType { + type Err = RvError; + fn from_str(s: &str) -> Result { + match s { + "IPv4" | "ipv4" => Ok(SockAddrType::IPv4), + "IPv6" | "ipv6" => Ok(SockAddrType::IPv6), + "Unix" | "UNIX" | "unix" => Ok(SockAddrType::Unix), + _ => Err(RvError::ErrResponse("invalid sockaddr type".to_string())) + } + } +} + +pub fn new_sock_addr(s: &str) -> Result, RvError> { + let ret = IpSockAddr::new(s)?; + Ok(Box::new(ret)) +} + +#[cfg(test)] +mod test { + use super::{ + *, super::{ + sock_addr::{SockAddrType}, + ip_sock_addr::IpSockAddr, + unix_sock_addr::UnixSockAddr, + }, + }; + + #[test] + fn test_sock_addr() { + let unix_addr1 = UnixSockAddr::new("/tmp/bar").unwrap(); + let unix_addr2 = UnixSockAddr::new("/tmp/bar").unwrap(); + let unix_addr3 = UnixSockAddr::new("/tmp/foo").unwrap(); + let ip_addr1 = IpSockAddr::new("1.1.1.1").unwrap(); + let ip_addr2 = IpSockAddr::new("1.1.1.1").unwrap(); + let ip_addr3 = IpSockAddr::new("2.2.2.2").unwrap(); + let ip_addr4 = IpSockAddr::new("333.333.333.333"); + let ip_addr5 = IpSockAddr::new("1.1.1.1:80").unwrap(); + let ip_addr6 = IpSockAddr::new("1.1.1.1:80").unwrap(); + let ip_addr7 = IpSockAddr::new("1.1.1.1:8080").unwrap(); + let ip_addr8 = IpSockAddr::new("2.2.2.2:80").unwrap(); + let ip_addr9 = IpSockAddr::new("192.168.0.0/16").unwrap(); + let ip_addr10 = IpSockAddr::new("192.168.0.0/24").unwrap(); + let ip_addr11 = IpSockAddr::new("192.168.0.1").unwrap(); + let ip_addr12 = IpSockAddr::new("192.168.1.1").unwrap(); + + assert!(unix_addr1.contains(&unix_addr2)); + assert!(unix_addr1.equal(&unix_addr2)); + assert!(!unix_addr1.contains(&unix_addr3)); + assert!(!unix_addr1.equal(&unix_addr3)); + assert_ne!(unix_addr1.sock_addr_type(), ip_addr1.sock_addr_type()); + + assert!(ip_addr4.is_err()); + assert!(ip_addr1.contains(&ip_addr2)); + assert!(ip_addr1.equal(&ip_addr2)); + assert!(!ip_addr1.contains(&ip_addr3)); + assert!(!ip_addr1.equal(&ip_addr3)); + assert_eq!(ip_addr1.sock_addr_type(), SockAddrType::IPv4); + assert_eq!(ip_addr1.sock_addr_type(), ip_addr2.sock_addr_type()); + assert_ne!(ip_addr1.sock_addr_type(), unix_addr2.sock_addr_type()); + assert!(ip_addr5.contains(&ip_addr6)); + assert!(ip_addr5.equal(&ip_addr6)); + assert!(!ip_addr5.equal(&ip_addr7)); + assert!(!ip_addr5.equal(&ip_addr8)); + assert!(ip_addr9.contains(&ip_addr10)); + assert!(ip_addr9.contains(&ip_addr11)); + assert!(ip_addr9.contains(&ip_addr12)); + assert!(!ip_addr9.contains(&ip_addr1)); + assert!(ip_addr10.contains(&ip_addr9)); + assert!(ip_addr10.contains(&ip_addr11)); + assert!(!ip_addr10.contains(&ip_addr12)); + assert!(!ip_addr9.equal(&ip_addr10)); + assert!(!ip_addr9.equal(&ip_addr11)); + + assert!(!ip_addr1.contains(&unix_addr1)); + assert!(!ip_addr1.equal(&unix_addr1)); + assert!(!unix_addr1.contains(&ip_addr1)); + assert!(!unix_addr1.equal(&ip_addr1)); + } +} diff --git a/src/utils/sockaddr.rs b/src/utils/sockaddr.rs deleted file mode 100644 index 907ef6a..0000000 --- a/src/utils/sockaddr.rs +++ /dev/null @@ -1,78 +0,0 @@ -use std::{ - fmt, - str::FromStr, -}; - -use as_any::AsAny; -use serde::{Deserialize, Serialize}; - -use super::{ - ipaddr::IpAddr, -}; - -use crate::errors::RvError; - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum SockAddrType { - Unknown = 0x0, - Unix = 0x1, - IPv4 = 0x2, - IPv6 = 0x4, - // IP is the union of IPv4 and IPv6 - IP = 0x6, -} - -pub trait SockAddr: fmt::Display + AsAny { - // contains returns true if the other SockAddr is contained within the receiver - fn contains(&self, other: &dyn SockAddr) -> bool; - - // equal allows for the comparison of two SockAddrs - fn equal(&self, other: &dyn SockAddr) -> bool; - - fn sock_addr_type(&self) -> SockAddrType; -} - -pub struct SockAddrMarshaler { - pub sock_addr: Box, -} - -impl SockAddrMarshaler { - pub fn new(sock_addr: Box) -> Self { - SockAddrMarshaler { sock_addr } - } -} - -impl fmt::Display for SockAddrMarshaler { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.sock_addr) - } -} - -impl fmt::Display for SockAddrType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let type_str = match self { - SockAddrType::IPv4 => "IPv4", - SockAddrType::IPv6 => "IPv6", - SockAddrType::Unix => "Unix", - _ => "Unknown", - }; - write!(f, "{}", type_str) - } -} - -impl FromStr for SockAddrType { - type Err = RvError; - fn from_str(s: &str) -> Result { - match s { - "IPv4" | "ipv4" => Ok(SockAddrType::IPv4), - "IPv6" | "ipv6" => Ok(SockAddrType::IPv6), - "Unix" | "UNIX" | "unix" => Ok(SockAddrType::Unix), - _ => Err(RvError::ErrResponse("invalid sockaddr type".to_string())) - } - } -} - -pub fn new_sock_addr(s: &str) -> Result, RvError> { - let ret = IpAddr::new(s)?; - Ok(Box::new(ret)) -} diff --git a/src/utils/unix_sock_addr.rs b/src/utils/unix_sock_addr.rs new file mode 100644 index 0000000..c7cc8ca --- /dev/null +++ b/src/utils/unix_sock_addr.rs @@ -0,0 +1,73 @@ +//! This module is a Rust replica of +//! https://github.com/hashicorp/go-sockaddr/blob/master/unixsock.go + +use std::fmt; +use as_any::Downcast; +use serde::{Deserialize, Serialize}; + +use super::{ + sock_addr::{SockAddr, SockAddrType}, +}; + +use crate::errors::RvError; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct UnixSockAddr { + pub path: String, +} + +impl UnixSockAddr { + pub fn new(s: &str) -> Result { + Ok(Self { + path: s.to_string(), + }) + } +} + +impl SockAddr for UnixSockAddr { + fn contains(&self, other: &dyn SockAddr) -> bool { + if let Some(unix_sock) = other.downcast_ref::() { + return self.path == unix_sock.path; + } + + false + } + + fn equal(&self, other: &dyn SockAddr) -> bool { + if let Some(unix_sock) = other.downcast_ref::() { + return self.path == unix_sock.path; + } + + false + } + + fn sock_addr_type(&self) -> SockAddrType { + SockAddrType::Unix + } +} + +impl fmt::Display for UnixSockAddr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.path) + } +} + +#[cfg(test)] +mod test { + use super::{ + *, super::sock_addr::{SockAddrType}, + }; + + #[test] + fn test_unix_sock_addr() { + let addr1 = UnixSockAddr::new("/tmp/bar").unwrap(); + let addr2 = UnixSockAddr::new("/tmp/bar").unwrap(); + let addr3 = UnixSockAddr::new("/tmp/foo").unwrap(); + + assert!(addr1.contains(&addr2)); + assert!(addr1.equal(&addr2)); + assert!(!addr1.contains(&addr3)); + assert!(!addr1.equal(&addr3)); + assert_eq!(addr1.sock_addr_type(), SockAddrType::Unix); + } +} diff --git a/src/utils/unixsock.rs b/src/utils/unixsock.rs deleted file mode 100644 index 88a1034..0000000 --- a/src/utils/unixsock.rs +++ /dev/null @@ -1,51 +0,0 @@ -use std::fmt; -use as_any::Downcast; -use serde::{Deserialize, Serialize}; - -use super::{ - sockaddr::{SockAddr, SockAddrType}, -}; - -use crate::errors::RvError; - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct UnixSock { - pub path: String, -} - -impl UnixSock { - pub fn new(s: &str) -> Result { - Ok(Self { - path: s.to_string(), - }) - } -} - -impl SockAddr for UnixSock { - fn contains(&self, other: &dyn SockAddr) -> bool { - if let Some(unix_sock) = other.downcast_ref::() { - return self.path == unix_sock.path; - } - - false - } - - fn equal(&self, other: &dyn SockAddr) -> bool { - if let Some(unix_sock) = other.downcast_ref::() { - return self.path == unix_sock.path; - } - - false - } - - fn sock_addr_type(&self) -> SockAddrType { - SockAddrType::Unix - } -} - -impl fmt::Display for UnixSock { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.path) - } -} - From 6eac13138712e03f22e34922e755639307fca3f8 Mon Sep 17 00:00:00 2001 From: Jin Jiu Date: Mon, 6 May 2024 16:55:43 +0800 Subject: [PATCH 4/4] Switch to rust-tongsuo, supporting SM2 and SM4 algorithms. --- Cargo.toml | 9 +- build.rs | 7 + src/modules/pki/path_config_ca.rs | 2 +- src/modules/pki/path_fetch.rs | 2 +- src/modules/pki/path_issue.rs | 9 +- src/modules/pki/path_keys.rs | 79 +++++---- src/modules/pki/path_roles.rs | 95 ++++------- src/modules/pki/path_root.rs | 4 +- src/modules/pki/util.rs | 49 ++---- src/storage/physical/mock.rs | 4 +- src/utils/key.rs | 275 +++++++++++++++--------------- 11 files changed, 249 insertions(+), 286 deletions(-) create mode 100644 build.rs diff --git a/Cargo.toml b/Cargo.toml index ff005df..591410d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ RustyVault's RESTful API is designed to be fully compatible with Hashicorp Vault """ repository = "https://github.com/Tongsuo-Project/RustyVault" documentation = "https://docs.rs/rusty_vault/latest/rusty_vault/" +build = "build.rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -24,8 +25,8 @@ serde_json = "^1.0" serde_bytes = "0.11" go-defer = "^0.1" rand = "^0.8" -openssl = "0.10" -openssl-sys = "0.9.92" +openssl = { version = "0.10" } +openssl-sys = { version = "0.9" } derivative = "2.2.0" enum-map = "2.6.1" strum = { version = "0.25", features = ["derive"] } @@ -60,6 +61,10 @@ serde_asn1_der = "0.8" base64 = "0.22" ipnetwork = "0.20" +[patch.crates-io] +openssl = { git = "https://github.com/Tongsuo-Project/rust-tongsuo.git" } +openssl-sys = { git = "https://github.com/Tongsuo-Project/rust-tongsuo.git" } + [features] storage_mysql = ["diesel", "r2d2", "r2d2-diesel"] diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..2c0ca9e --- /dev/null +++ b/build.rs @@ -0,0 +1,7 @@ +use std::env; + +fn main() { + if let Ok(_) = env::var("DEP_OPENSSL_TONGSUO") { + println!("cargo:rustc-cfg=tongsuo"); + } +} diff --git a/src/modules/pki/path_config_ca.rs b/src/modules/pki/path_config_ca.rs index 56b41d8..98120de 100644 --- a/src/modules/pki/path_config_ca.rs +++ b/src/modules/pki/path_config_ca.rs @@ -46,7 +46,7 @@ For security reasons, you can only view the certificate when reading this endpoi impl PkiBackendInner { pub fn write_path_ca(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let pem_bundle_value = req.get_data("pem_bundle")?; - let pem_bundle = pem_bundle_value.as_str().unwrap(); + let pem_bundle = pem_bundle_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let items = pem::parse_many(pem_bundle)?; let mut key_found = false; diff --git a/src/modules/pki/path_fetch.rs b/src/modules/pki/path_fetch.rs index 4e2d954..53c4550 100644 --- a/src/modules/pki/path_fetch.rs +++ b/src/modules/pki/path_fetch.rs @@ -122,7 +122,7 @@ impl PkiBackendInner { pub fn read_path_fetch_cert(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let serial_number_value = req.get_data("serial")?; - let serial_number = serial_number_value.as_str().unwrap(); + let serial_number = serial_number_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let serial_number_hex = serial_number.replace(":", "-").to_lowercase(); let cert = self.fetch_cert(req, &serial_number_hex)?; let ca_bundle = self.fetch_ca_bundle(req)?; diff --git a/src/modules/pki/path_issue.rs b/src/modules/pki/path_issue.rs index 1277e7e..b5521fd 100644 --- a/src/modules/pki/path_issue.rs +++ b/src/modules/pki/path_issue.rs @@ -65,13 +65,12 @@ requested common name is allowed by the role policy. impl PkiBackendInner { pub fn issue_cert(&self, backend: &dyn Backend, req: &mut Request) -> Result, RvError> { - let role_value = req.get_data("role")?; - let role_name = role_value.as_str().unwrap(); + //let role_name = req.get_data("role")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let mut common_names = Vec::new(); let common_name_value = req.get_data("common_name")?; - let common_name = common_name_value.as_str().unwrap(); + let common_name = common_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; if common_name != "" { common_names.push(common_name.to_string()); } @@ -87,7 +86,7 @@ impl PkiBackendInner { } } - let role = self.get_role(req, &role_name)?; + let role = self.get_role(req, req.get_data("role")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; if role.is_none() { return Err(RvError::ErrPkiRoleNotFound); } @@ -111,7 +110,7 @@ impl PkiBackendInner { let mut not_after = not_before + parse_duration("30d").unwrap(); let ttl_value = req.get_data("ttl")?; - let ttl = ttl_value.as_str().unwrap(); + let ttl = ttl_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; if ttl != "" { let ttl_dur = parse_duration(ttl)?; let req_ttl_not_after_dur = SystemTime::now() + ttl_dur; diff --git a/src/modules/pki/path_keys.rs b/src/modules/pki/path_keys.rs index a2304b6..9149911 100644 --- a/src/modules/pki/path_keys.rs +++ b/src/modules/pki/path_keys.rs @@ -212,11 +212,10 @@ used for sign,verify,encrypt,decrypt. impl PkiBackendInner { pub fn generate_key(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let key_name_value = req.get_data("key_name")?; - let key_name = key_name_value.as_str().unwrap(); + let key_name = key_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let key_type_value = req.get_data("key_type")?; - let key_type = key_type_value.as_str().unwrap(); - let key_bits_value = req.get_data("key_bits")?; - let key_bits = key_bits_value.as_u64().unwrap(); + let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; + let key_bits = req.get_data("key_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; let mut export_private_key = false; if req.path.ends_with("/exported") { @@ -245,7 +244,7 @@ impl PkiBackendInner { if export_private_key { match key_type { - "rsa" | "ec" => { + "rsa" | "ec" | "sm2" => { resp_data.insert( "private_key".to_string(), Value::String(String::from_utf8_lossy(&key_bundle.key).to_string()), @@ -266,13 +265,13 @@ impl PkiBackendInner { pub fn import_key(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let key_name_value = req.get_data("key_name")?; - let key_name = key_name_value.as_str().unwrap(); + let key_name = key_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let key_type_value = req.get_data("key_type")?; - let key_type = key_type_value.as_str().unwrap(); + let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let pem_bundle_value = req.get_data("pem_bundle")?; - let pem_bundle = pem_bundle_value.as_str().unwrap(); + let pem_bundle = pem_bundle_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let hex_bundle_value = req.get_data("hex_bundle")?; - let hex_bundle = hex_bundle_value.as_str().unwrap(); + let hex_bundle = hex_bundle_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; if pem_bundle.len() == 0 && hex_bundle.len() == 0 { return Err(RvError::ErrRequestFieldNotFound); @@ -292,7 +291,7 @@ impl PkiBackendInner { let rsa = Rsa::private_key_from_pem(&key_bundle.key)?; key_bundle.bits = rsa.size() * 8; }, - "ec" => { + "ec" | "sm2" => { let ec_key = EcKey::private_key_from_pem(&key_bundle.key)?; key_bundle.bits = ec_key.group().degree(); }, @@ -312,19 +311,25 @@ impl PkiBackendInner { } }; let iv_value = req.get_data("iv")?; - match key_type { - "aes-gcm" | "aes-cbc" => { - if let Some(iv) = iv_value.as_str() { - key_bundle.iv = hex::decode(&iv)?; - } else { - return Err(RvError::ErrRequestFieldNotFound); - } - }, - "aes-ecb" => {}, - _ => { - return Err(RvError::ErrPkiKeyTypeInvalid); + let is_iv_required = matches!(key_type, "aes-gcm" | "aes-cbc" | "sm4-gcm" | "sm4-ccm"); + #[cfg(tongsuo)] + let is_valid_key_type = matches!(key_type, "aes-gcm" | "aes-cbc" | "aes-ecb" | "sm4-gcm" | "sm4-ccm"); + #[cfg(not(tongsuo))] + let is_valid_key_type = matches!(key_type, "aes-gcm" | "aes-cbc" | "aes-ecb"); + + // Check if the key type is valid, if not return an error. + if !is_valid_key_type { + return Err(RvError::ErrPkiKeyTypeInvalid); + } + + // Proceed to check IV only if required by the key type. + if is_iv_required { + if let Some(iv) = iv_value.as_str() { + key_bundle.iv = hex::decode(&iv)?; + } else { + return Err(RvError::ErrRequestFieldNotFound); } - }; + } } self.write_key(req, &key_bundle)?; @@ -343,12 +348,10 @@ impl PkiBackendInner { } pub fn key_sign(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { - let key_name_value = req.get_data("key_name")?; - let key_name = key_name_value.as_str().unwrap(); let data_value = req.get_data("data")?; - let data = data_value.as_str().unwrap(); + let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let key_bundle = self.fetch_key(req, key_name)?; + let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; let decoded_data = hex::decode(data.as_bytes())?; let result = key_bundle.sign(&decoded_data)?; @@ -364,14 +367,12 @@ impl PkiBackendInner { } pub fn key_verify(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { - let key_name_value = req.get_data("key_name")?; - let key_name = key_name_value.as_str().unwrap(); let data_value = req.get_data("data")?; - let data = data_value.as_str().unwrap(); + let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let signature_value = req.get_data("signature")?; - let signature = signature_value.as_str().unwrap(); + let signature = signature_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let key_bundle = self.fetch_key(req, key_name)?; + let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; let decoded_data = hex::decode(data.as_bytes())?; let decoded_signature = hex::decode(signature.as_bytes())?; @@ -388,14 +389,12 @@ impl PkiBackendInner { } pub fn key_encrypt(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { - let key_name_value = req.get_data("key_name")?; - let key_name = key_name_value.as_str().unwrap(); let data_value = req.get_data("data")?; - let data = data_value.as_str().unwrap(); + let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let aad_value = req.get_data("aad")?; - let aad = aad_value.as_str().unwrap(); + let aad = aad_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let key_bundle = self.fetch_key(req, key_name)?; + let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; let decoded_data = hex::decode(data.as_bytes())?; let result = key_bundle.encrypt(&decoded_data, Some(EncryptExtraData::Aad(aad.as_bytes())))?; @@ -411,14 +410,12 @@ impl PkiBackendInner { } pub fn key_decrypt(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { - let key_name_value = req.get_data("key_name")?; - let key_name = key_name_value.as_str().unwrap(); let data_value = req.get_data("data")?; - let data = data_value.as_str().unwrap(); + let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; let aad_value = req.get_data("aad")?; - let aad = aad_value.as_str().unwrap(); + let aad = aad_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; - let key_bundle = self.fetch_key(req, key_name)?; + let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; let decoded_data = hex::decode(data.as_bytes())?; let result = key_bundle.decrypt(&decoded_data, Some(EncryptExtraData::Aad(aad.as_bytes())))?; diff --git a/src/modules/pki/path_roles.rs b/src/modules/pki/path_roles.rs index 3d93234..d851772 100644 --- a/src/modules/pki/path_roles.rs +++ b/src/modules/pki/path_roles.rs @@ -318,30 +318,19 @@ impl PkiBackendInner { } pub fn read_path_role(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { - let name_vale = req.get_data("name")?; - let name = name_vale.as_str().unwrap(); - let role_entry = self.get_role(req, name)?; + let role_entry = self.get_role(req, req.get_data("name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; let data = serde_json::to_value(&role_entry)?; Ok(Some(Response::data_response(Some(data.as_object().unwrap().clone())))) } pub fn create_path_role(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { - let name_vale = req.get_data("name")?; - let name = name_vale.as_str().unwrap(); - let ttl_vale = req.get_data("ttl")?; - let ttl = { - let ttl_str = ttl_vale.as_str().unwrap(); - parse_duration(ttl_str)? - }; - let max_ttl_vale = req.get_data("max_ttl")?; - let max_ttl = { - let max_ttl_str = max_ttl_vale.as_str().unwrap(); - parse_duration(max_ttl_str)? - }; - let key_type_vale = req.get_data("key_type")?; - let key_type = key_type_vale.as_str().unwrap(); - let key_bits_vale = req.get_data("key_bits")?; - let mut key_bits = key_bits_vale.as_u64().unwrap(); + let name_value = req.get_data("name")?; + let name = name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; + let ttl = parse_duration(req.get_data("ttl")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; + let max_ttl = parse_duration(req.get_data("max_ttl")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; + let key_type_value = req.get_data("key_type")?; + let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; + let mut key_bits = req.get_data("key_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; match key_type { "rsa" => { if key_bits == 0 { @@ -366,48 +355,28 @@ impl PkiBackendInner { } } - let signature_bits_vale = req.get_data("signature_bits")?; - let signature_bits = signature_bits_vale.as_u64().unwrap(); - let allow_localhost_vale = req.get_data("allow_localhost")?; - let allow_localhost = allow_localhost_vale.as_bool().unwrap(); - let allow_bare_domain_vale = req.get_data("allow_bare_domains")?; - let allow_bare_domains = allow_bare_domain_vale.as_bool().unwrap(); - let allow_subdomains_vale = req.get_data("allow_subdomains")?; - let allow_subdomains = allow_subdomains_vale.as_bool().unwrap(); - let allow_any_name_vale = req.get_data("allow_any_name")?; - let allow_any_name = allow_any_name_vale.as_bool().unwrap(); - let allow_ip_sans_vale = req.get_data("allow_ip_sans")?; - let allow_ip_sans = allow_ip_sans_vale.as_bool().unwrap(); - let server_flag_vale = req.get_data("server_flag")?; - let server_flag = server_flag_vale.as_bool().unwrap(); - let client_flag_vale = req.get_data("client_flag")?; - let client_flag = client_flag_vale.as_bool().unwrap(); - let use_csr_sans_vale = req.get_data("use_csr_sans")?; - let use_csr_sans = use_csr_sans_vale.as_bool().unwrap(); - let use_csr_common_name_vale = req.get_data("use_csr_common_name")?; - let use_csr_common_name = use_csr_common_name_vale.as_bool().unwrap(); - let country_vale = req.get_data("country")?; - let country = country_vale.as_str().unwrap().to_string(); - let province_vale = req.get_data("province")?; - let province = province_vale.as_str().unwrap().to_string(); - let locality_vale = req.get_data("locality")?; - let locality = locality_vale.as_str().unwrap().to_string(); - let organization_vale = req.get_data("organization")?; - let organization = organization_vale.as_str().unwrap().to_string(); - let ou_vale = req.get_data("ou")?; - let ou = ou_vale.as_str().unwrap().to_string(); - let street_address_vale = req.get_data("street_address")?; - let street_address = street_address_vale.as_str().unwrap().to_string(); - let postal_code_vale = req.get_data("postal_code")?; - let postal_code = postal_code_vale.as_str().unwrap().to_string(); - let no_store_vale = req.get_data("no_store")?; - let no_store = no_store_vale.as_bool().unwrap(); - let generate_lease_vale = req.get_data("generate_lease")?; - let generate_lease = generate_lease_vale.as_bool().unwrap(); - let not_after_vale = req.get_data("not_after")?; - let not_after = not_after_vale.as_str().unwrap().to_string(); - let not_before_duration_vale = req.get_data("not_before_duration")?; - let not_before_duration = Duration::from_secs(not_before_duration_vale.as_u64().unwrap()); + let signature_bits = req.get_data("signature_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_localhost = req.get_data("allow_localhost")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_bare_domains = req.get_data("allow_bare_domains")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_subdomains = req.get_data("allow_subdomains")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_any_name = req.get_data("allow_any_name")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let allow_ip_sans = req.get_data("allow_ip_sans")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let server_flag = req.get_data("server_flag")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let client_flag = req.get_data("client_flag")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let use_csr_sans = req.get_data("use_csr_sans")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let use_csr_common_name = req.get_data("use_csr_common_name")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let country = req.get_data("country")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let province = req.get_data("province")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let locality = req.get_data("locality")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let organization = req.get_data("organization")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let ou = req.get_data("ou")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let street_address = req.get_data("street_address")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let postal_code = req.get_data("postal_code")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let no_store = req.get_data("no_store")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let generate_lease = req.get_data("generate_lease")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let not_after = req.get_data("not_after")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let not_before_duration_u64 = req.get_data("not_before_duration")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + let not_before_duration = Duration::from_secs(not_before_duration_u64); let role_entry = RoleEntry { ttl, @@ -446,8 +415,8 @@ impl PkiBackendInner { } pub fn delete_path_role(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { - let name_vale = req.get_data("name")?; - let name = name_vale.as_str().unwrap(); + let name_value = req.get_data("name")?; + let name = name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; if name == "" { return Err(RvError::ErrRequestNoDataField); } diff --git a/src/modules/pki/path_root.rs b/src/modules/pki/path_root.rs index c74f5b4..e645163 100644 --- a/src/modules/pki/path_root.rs +++ b/src/modules/pki/path_root.rs @@ -46,9 +46,7 @@ impl PkiBackend { impl PkiBackendInner { pub fn generate_root(&self, _backend: &dyn Backend, req: &mut Request) -> Result, RvError> { let mut export_private_key = false; - let exported_vale = req.get_data("exported")?; - let exported = exported_vale.as_str().unwrap(); - if exported == "exported" { + if req.get_data("exported")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)? == "exported" { export_private_key = true; } diff --git a/src/modules/pki/util.rs b/src/modules/pki/util.rs index eb7fd62..7f6bdc9 100644 --- a/src/modules/pki/util.rs +++ b/src/modules/pki/util.rs @@ -7,17 +7,12 @@ use super::path_roles::RoleEntry; use crate::{errors::RvError, logical::Request, utils::cert::Certificate}; pub fn get_role_params(req: &mut Request) -> Result { - let ttl_vale = req.get_data("ttl")?; - let ttl = { - let ttl_str = ttl_vale.as_str().unwrap(); - parse_duration(ttl_str)? - }; - let not_before_duration_vale = req.get_data("not_before_duration")?; - let not_before_duration = Duration::from_secs(not_before_duration_vale.as_u64().unwrap()); - let key_type_vale = req.get_data("key_type")?; - let key_type = key_type_vale.as_str().unwrap(); - let key_bits_vale = req.get_data("key_bits")?; - let mut key_bits = key_bits_vale.as_u64().unwrap(); + let ttl = parse_duration(req.get_data("ttl")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?; + let not_before_duration_u64 = req.get_data("not_before_duration")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + let not_before_duration = Duration::from_secs(not_before_duration_u64); + let key_type_value = req.get_data("key_type")?; + let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; + let mut key_bits = req.get_data("key_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; match key_type { "rsa" => { if key_bits == 0 { @@ -42,26 +37,16 @@ pub fn get_role_params(req: &mut Request) -> Result { } } - let signature_bits_vale = req.get_data("signature_bits")?; - let signature_bits = signature_bits_vale.as_u64().unwrap(); - let use_pss_value = req.get_data("use_pss")?; - let use_pss = use_pss_value.as_bool().unwrap(); - let country_vale = req.get_data("country")?; - let country = country_vale.as_str().unwrap().to_string(); - let province_vale = req.get_data("province")?; - let province = province_vale.as_str().unwrap().to_string(); - let locality_vale = req.get_data("locality")?; - let locality = locality_vale.as_str().unwrap().to_string(); - let organization_vale = req.get_data("organization")?; - let organization = organization_vale.as_str().unwrap().to_string(); - let ou_vale = req.get_data("ou")?; - let ou = ou_vale.as_str().unwrap().to_string(); - let street_address_vale = req.get_data("street_address")?; - let street_address = street_address_vale.as_str().unwrap().to_string(); - let postal_code_vale = req.get_data("postal_code")?; - let postal_code = postal_code_vale.as_str().unwrap().to_string(); - let not_after_vale = req.get_data("not_after")?; - let not_after = not_after_vale.as_str().unwrap().to_string(); + let signature_bits = req.get_data("signature_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?; + let use_pss = req.get_data("use_pss")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?; + let country = req.get_data("country")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let province = req.get_data("province")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let locality = req.get_data("locality")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let organization = req.get_data("organization")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let ou = req.get_data("ou")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let street_address = req.get_data("street_address")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let postal_code = req.get_data("postal_code")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); + let not_after = req.get_data("not_after")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string(); let role_entry = RoleEntry { ttl, @@ -88,7 +73,7 @@ pub fn generate_certificate(role_entry: &RoleEntry, req: &mut Request) -> Result let mut common_names = Vec::new(); let common_name_value = req.get_data("common_name")?; - let common_name = common_name_value.as_str().unwrap(); + let common_name = common_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?; if common_name != "" { common_names.push(common_name.to_string()); } diff --git a/src/storage/physical/mock.rs b/src/storage/physical/mock.rs index d568d7d..b543281 100644 --- a/src/storage/physical/mock.rs +++ b/src/storage/physical/mock.rs @@ -4,7 +4,7 @@ use super::{Backend, BackendEntry}; use crate::errors::RvError; #[derive(Default)] -pub struct MockBackend(u32); +pub struct MockBackend(()); impl Backend for MockBackend { fn list(&self, _prefix: &str) -> Result, RvError> { @@ -26,6 +26,6 @@ impl Backend for MockBackend { impl MockBackend { pub fn new() -> Self { - MockBackend(0) + MockBackend(()) } } diff --git a/src/utils/key.rs b/src/utils/key.rs index e6f8b97..9c86853 100644 --- a/src/utils/key.rs +++ b/src/utils/key.rs @@ -42,6 +42,25 @@ impl Default for KeyBundle { } } +fn cipher_from_key_type_and_bits(key_type: &str, bits: u32) -> Result { + match (key_type, bits) { + ("aes-gcm", 128) => Ok(Cipher::aes_128_gcm()), + ("aes-gcm", 192) => Ok(Cipher::aes_192_gcm()), + ("aes-gcm", 256) => Ok(Cipher::aes_256_gcm()), + ("aes-cbc", 128) => Ok(Cipher::aes_128_cbc()), + ("aes-cbc", 192) => Ok(Cipher::aes_192_cbc()), + ("aes-cbc", 256) => Ok(Cipher::aes_256_cbc()), + ("aes-ecb", 128) => Ok(Cipher::aes_128_ecb()), + ("aes-ecb", 192) => Ok(Cipher::aes_192_ecb()), + ("aes-ecb", 256) => Ok(Cipher::aes_256_ecb()), + #[cfg(tongsuo)] + ("sm4-gcm", 128) => Ok(Cipher::sm4_gcm()), + #[cfg(tongsuo)] + ("sm4-ccm", 128) => Ok(Cipher::sm4_ccm()), + _ => Err(RvError::ErrPkiKeyBitsInvalid), + } +} + impl KeyBundle { pub fn new(name: &str, key_type: &str, key_bits: u32) -> Self { Self { name: name.to_string(), key_type: key_type.to_string(), bits: key_bits, ..KeyBundle::default() } @@ -51,12 +70,13 @@ impl KeyBundle { let key_bits = self.bits; let priv_key = match self.key_type.as_str() { "rsa" => { - if key_bits != 2048 && key_bits != 3072 && key_bits != 4096 { - return Err(RvError::ErrPkiKeyBitsInvalid); + match key_bits { + 2048 | 3072 | 4096 => { + let rsa_key = Rsa::generate(key_bits)?; + PKey::from_rsa(rsa_key)?.private_key_to_pem_pkcs8()? + }, + _ => return Err(RvError::ErrPkiKeyBitsInvalid), } - let rsa_key = Rsa::generate(key_bits)?; - let pkey = PKey::from_rsa(rsa_key)?; - pkey.private_key_to_pem_pkcs8()? } "ec" => { let curve_name = match key_bits { @@ -64,30 +84,43 @@ impl KeyBundle { 256 => Nid::SECP256K1, 384 => Nid::SECP384R1, 521 => Nid::SECP521R1, - _ => { - return Err(RvError::ErrPkiKeyBitsInvalid); - } + _ => return Err(RvError::ErrPkiKeyBitsInvalid), }; let ec_group = EcGroup::from_curve_name(curve_name)?; - let ec_key = EcKey::generate(ec_group.as_ref())?; - let pkey = PKey::from_ec_key(ec_key)?; - pkey.private_key_to_pem_pkcs8()? - } - "aes-gcm" | "aes-cbc" | "aes-ecb" => { - if key_bits != 128 && key_bits != 192 && key_bits != 256 { - return Err(RvError::ErrPkiKeyBitsInvalid); + let ec_key = EcKey::generate(&ec_group)?; + PKey::from_ec_key(ec_key)?.private_key_to_pem_pkcs8()? + }, + #[cfg(tongsuo)] + "sm2" => { + self.bits = 256; + let ec_group = EcGroup::from_curve_name(Nid::SM2)?; + let ec_key = EcKey::generate(&ec_group)?; + PKey::from_ec_key(ec_key)?.private_key_to_pem_pkcs8()? + }, + "aes-gcm" | "aes-cbc" | "aes-ecb" | "sm4-gcm" | "sm4-ccm" => { + let _ = cipher_from_key_type_and_bits(self.key_type.as_str(), self.bits)?; + + #[cfg(not(tongsuo))] + if self.key_type.starts_with("sm4-") { + return Err(RvError::ErrPkiKeyTypeInvalid); } - if self.key_type.as_str() != "aes-ecb" { - let mut iv_bytes = vec![0u8; 16]; - rand_bytes(&mut iv_bytes)?; - self.iv = iv_bytes; + match self.key_type.as_str() { + "aes-ecb" => (), + "sm4-ccm" => { + self.iv = vec![0u8; 12]; + rand_bytes(&mut self.iv)?; + } + _ => { + self.iv = vec![0u8; 16]; + rand_bytes(&mut self.iv)?; + } } - let mut random_bytes = vec![0u8; (key_bits / 8) as usize]; - rand_bytes(&mut random_bytes)?; - random_bytes - } + let mut key = vec![0u8; key_bits as usize / 8]; + rand_bytes(&mut key)?; + key + }, _ => { return Err(RvError::ErrPkiKeyTypeInvalid); } @@ -99,95 +132,67 @@ impl KeyBundle { } pub fn sign(&self, data: &[u8]) -> Result, RvError> { - match self.key_type.as_str() { - "rsa" => { - let rsa = Rsa::private_key_from_pem(&self.key)?; - let pkey = PKey::from_rsa(rsa)?; - let mut signer = Signer::new(MessageDigest::sha256(), &pkey)?; - signer.set_rsa_padding(Padding::PKCS1)?; - signer.update(data)?; - return Ok(signer.sign_to_vec()?); - } - "ec" => { - let ec_key = EcKey::private_key_from_pem(&self.key)?; - let pkey = PKey::from_ec_key(ec_key)?; - let mut signer = Signer::new(MessageDigest::sha256(), &pkey)?; - signer.update(data)?; - return Ok(signer.sign_to_vec()?); - } - _ => { - return Err(RvError::ErrPkiKeyOperationInvalid); - } + let digest = match self.key_type.as_str() { + "rsa" | "ec" => MessageDigest::sha256(), + #[cfg(tongsuo)] + "sm2" => MessageDigest::sm3(), + _ => return Err(RvError::ErrPkiKeyOperationInvalid), + }; + + let pkey = PKey::private_key_from_pem(&self.key)?; + + let mut signer = Signer::new(digest, &pkey)?; + if self.key_type == "rsa" { + signer.set_rsa_padding(Padding::PKCS1)?; } + + signer.update(data)?; + signer.sign_to_vec().map_err(From::from) } pub fn verify(&self, data: &[u8], signature: &[u8]) -> Result { - match self.key_type.as_str() { - "rsa" => { - let rsa = Rsa::private_key_from_pem(&self.key)?; - let pkey = PKey::from_rsa(rsa)?; - let mut verifier = Verifier::new(MessageDigest::sha256(), &pkey)?; - verifier.set_rsa_padding(Padding::PKCS1)?; - verifier.update(data)?; - return Ok(verifier.verify(signature).unwrap_or(false)); - } - "ec" => { - let ec_key = EcKey::private_key_from_pem(&self.key)?; - let pkey = PKey::from_ec_key(ec_key)?; - let mut verifier = Verifier::new(MessageDigest::sha256(), &pkey)?; - verifier.update(data)?; - return Ok(verifier.verify(signature).unwrap_or(false)); - } - _ => { - return Err(RvError::ErrPkiKeyOperationInvalid); - } + let digest = match self.key_type.as_str() { + "rsa" | "ec" => MessageDigest::sha256(), + #[cfg(tongsuo)] + "sm2" => MessageDigest::sm3(), + _ => return Err(RvError::ErrPkiKeyOperationInvalid), + }; + + let pkey = PKey::private_key_from_pem(&self.key)?; + + let mut verifier = Verifier::new(digest, &pkey)?; + if self.key_type == "rsa" { + verifier.set_rsa_padding(Padding::PKCS1)?; } + + verifier.update(data)?; + Ok(verifier.verify(signature).unwrap_or(false)) } pub fn encrypt(&self, data: &[u8], extra: Option) -> Result, RvError> { match self.key_type.as_str() { - "aes-gcm" => { + "aes-gcm" | "sm4-gcm" | "sm4-ccm" => { + let cipher = cipher_from_key_type_and_bits(self.key_type.as_str(), self.bits)?; let aad = extra.map_or("".as_bytes(), |ex| match ex { EncryptExtraData::Aad(aad) => aad, _ => "".as_bytes(), }); - let cipher = match self.bits { - 128 => Cipher::aes_128_gcm(), - 192 => Cipher::aes_192_gcm(), - 256 => Cipher::aes_256_gcm(), - _ => { - return Err(RvError::ErrPkiKeyBitsInvalid); - } - }; let mut tag = vec![0u8; 16]; - let mut ciphertext = - encrypt_aead(cipher, &self.key, Some(&self.iv), aad, data, &mut tag)?; + let mut ciphertext = encrypt_aead( + cipher, + &self.key, + Some(&self.iv), + aad, + data, + &mut tag, + )?; ciphertext.extend_from_slice(&tag); Ok(ciphertext) } - "aes-cbc" => { - let cipher = match self.bits { - 128 => Cipher::aes_128_cbc(), - 192 => Cipher::aes_192_cbc(), - 256 => Cipher::aes_256_cbc(), - _ => { - return Err(RvError::ErrPkiKeyBitsInvalid); - } - }; - - Ok(encrypt(cipher, &self.key, Some(&self.iv), data)?) - } - "aes-ecb" => { - let cipher = match self.bits { - 128 => Cipher::aes_128_ecb(), - 192 => Cipher::aes_192_ecb(), - 256 => Cipher::aes_256_ecb(), - _ => { - return Err(RvError::ErrPkiKeyBitsInvalid); - } - }; - - Ok(encrypt(cipher, &self.key, None, data)?) + "aes-cbc" | "aes-ecb" => { + let cipher = cipher_from_key_type_and_bits(self.key_type.as_str(), self.bits)?; + let iv = if self.key_type == "aes-ecb" { None } else { Some(self.iv.as_slice()) }; + Ok(encrypt(cipher, &self.key, iv, data)?) } "rsa" => { let rsa = Rsa::private_key_from_pem(&self.key)?; @@ -209,54 +214,31 @@ impl KeyBundle { return Ok(buf); } - _ => { - return Err(RvError::ErrPkiKeyOperationInvalid); - } + _ => Err(RvError::ErrPkiKeyOperationInvalid), } } pub fn decrypt(&self, data: &[u8], extra: Option) -> Result, RvError> { + match self.key_type.as_str() { - "aes-gcm" => { + "aes-gcm" | "sm4-gcm" | "sm4-ccm" => { + let cipher = cipher_from_key_type_and_bits(self.key_type.as_str(), self.bits)?; let aad = extra.map_or("".as_bytes(), |ex| match ex { EncryptExtraData::Aad(aad) => aad, _ => "".as_bytes(), }); - let cipher = match self.bits { - 128 => Cipher::aes_128_gcm(), - 192 => Cipher::aes_192_gcm(), - 256 => Cipher::aes_256_gcm(), - _ => { - return Err(RvError::ErrPkiKeyBitsInvalid); - } - }; - let (ciphertext, tag) = data.split_at(data.len() - 16); + let tag_len = 16; + if data.len() < tag_len { + return Err(RvError::ErrPkiInternal); + } + let (ciphertext, tag) = data.split_at(data.len() - tag_len); Ok(decrypt_aead(cipher, &self.key, Some(&self.iv), aad, ciphertext, tag)?) - } - "aes-cbc" => { - let cipher = match self.bits { - 128 => Cipher::aes_128_cbc(), - 192 => Cipher::aes_192_cbc(), - 256 => Cipher::aes_256_cbc(), - _ => { - return Err(RvError::ErrPkiKeyBitsInvalid); - } - }; - - Ok(decrypt(cipher, &self.key, Some(&self.iv), data)?) - } - "aes-ecb" => { - let cipher = match self.bits { - 128 => Cipher::aes_128_ecb(), - 192 => Cipher::aes_192_ecb(), - 256 => Cipher::aes_256_ecb(), - _ => { - return Err(RvError::ErrPkiKeyBitsInvalid); - } - }; - - Ok(decrypt(cipher, &self.key, None, data)?) - } + }, + "aes-cbc" | "aes-ecb" => { + let cipher = cipher_from_key_type_and_bits(self.key_type.as_str(), self.bits)?; + let iv = if self.key_type == "aes-ecb" { None } else { Some(self.iv.as_slice()) }; + Ok(decrypt(cipher, &self.key, iv, data)?) + }, "rsa" => { let rsa = Rsa::private_key_from_pem(&self.key)?; if data.len() > rsa.size() as usize { @@ -284,9 +266,7 @@ impl KeyBundle { return Ok(buf); } - _ => { - return Err(RvError::ErrPkiKeyOperationInvalid); - } + _ => Err(RvError::ErrPkiKeyOperationInvalid), } } } @@ -348,6 +328,13 @@ mod test { test_key_sign_verify(&mut key_bundle); } + #[test] + #[cfg(tongsuo)] + fn test_sm2_key_operation() { + let mut key_bundle = KeyBundle::new("sm2", "sm2", 256); + test_key_sign_verify(&mut key_bundle); + } + #[test] fn test_aes_key_operation() { // test aes-gcm @@ -381,4 +368,20 @@ mod test { let mut key_bundle = KeyBundle::new("aes-ecb-256", "aes-ecb", 256); test_key_encrypt_decrypt(&mut key_bundle, None); } + + #[test] + #[cfg(tongsuo)] + fn test_sm4_key_operation() { + // test sm4-gcm + let mut key_bundle = KeyBundle::new("sm4-gcm-128", "sm4-gcm", 128); + test_key_encrypt_decrypt(&mut key_bundle, None); + test_key_encrypt_decrypt(&mut key_bundle, None); + test_key_encrypt_decrypt(&mut key_bundle, Some(EncryptExtraData::Aad("rusty_vault".as_bytes()))); + + // test sm4-ccm + let mut key_bundle = KeyBundle::new("sm4-ccm-128", "sm4-ccm", 128); + test_key_encrypt_decrypt(&mut key_bundle, None); + test_key_encrypt_decrypt(&mut key_bundle, None); + test_key_encrypt_decrypt(&mut key_bundle, Some(EncryptExtraData::Aad("rusty_vault".as_bytes()))); + } }