diff --git a/Cargo.toml b/Cargo.toml index a00990c..ff005df 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,6 +58,7 @@ ureq = "2.9" glob = "0.3" serde_asn1_der = "0.8" base64 = "0.22" +ipnetwork = "0.20" [features] storage_mysql = ["diesel", "r2d2", "r2d2-diesel"] diff --git a/src/errors.rs b/src/errors.rs index 80182d6..764eeba 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -229,6 +229,17 @@ pub enum RvError { #[error("RwLock was poisoned (writing)")] ErrRwLockWritePoison, + #[error("Some net addr parse error happened, {:?}", .source)] + AddrParseError { + #[from] + source: std::net::AddrParseError, + }, + #[error("Some ipnetwork error happened, {:?}", .source)] + IpNetworkError { + #[from] + source: ipnetwork::IpNetworkError, + }, + /// Database Errors Begin /// #[error("Database type is not support now. Please try postgressql or mysql again.")] diff --git a/src/utils/cidr.rs b/src/utils/cidr.rs new file mode 100644 index 0000000..6be607b --- /dev/null +++ b/src/utils/cidr.rs @@ -0,0 +1,309 @@ +//! This module is a Rust replica of +//! https://github.com/hashicorp/vault/blob/main/sdk/helper/cidrutil/cidr.go + +use std::{ + str::FromStr, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + collections::HashSet, +}; + +use ipnetwork::IpNetwork; + +use super::{ + sock_addr::{new_sock_addr, SockAddrType, SockAddr}, +}; + +use crate::errors::RvError; + +pub fn is_ip_addr(addr: &dyn SockAddr) -> bool { + (addr.sock_addr_type() as u8 & SockAddrType::IP as u8) != 0 +} + +pub fn remote_addr_is_ok(remote_addr: &str, bound_cidrs: &[Box]) -> bool { + if bound_cidrs.len() == 0 { + return true; + } + + if let Ok(addr) = new_sock_addr(remote_addr) { + for cidr in bound_cidrs.iter() { + if is_ip_addr(cidr.as_ref()) && cidr.contains(addr.as_ref()) { + return true; + } + } + } + + false +} + +pub fn ip_belongs_to_cidr(ip_addr: &str, cidr: &str) -> Result { + if ip_addr == "" { + return Err(RvError::ErrResponse("missing IP address".to_string())); + } + + let ip = IpAddr::from_str(ip_addr)?; + let ipnet = IpNetwork::from_str(cidr)?; + + Ok(ipnet.contains(ip)) +} + +pub fn ip_belongs_to_cidrs(ip_addr: &str, cidrs: &[&str]) -> Result { + if ip_addr == "" { + return Err(RvError::ErrResponse("missing IP address".to_string())); + } + + if cidrs.len() == 0 { + return Err(RvError::ErrResponse("missing CIDR blocks to be checked against".to_string())); + } + + for cidr in cidrs.iter() { + if ip_belongs_to_cidr(ip_addr, cidr)? { + return Ok(true); + } + } + + Ok(false) +} + +pub fn validate_cidr_string(cidr_list: &str, separator: &str) -> Result { + if cidr_list == "" { + return Err(RvError::ErrResponse("missing CIDR list that needs validation".to_string())); + } + + if separator == "" { + return Err(RvError::ErrResponse("missing separator".to_string())); + } + + let cidrs_set: HashSet<&str> = cidr_list.split(separator) + .map(|cidr| cidr.trim()) + .filter(|cidr| !cidr.is_empty()) + .collect(); + + let cidrs: Vec<&str> = cidrs_set.into_iter().collect(); + + validate_cidrs(&cidrs) +} + +pub fn validate_cidrs(cidrs: &[&str]) -> Result { + if cidrs.len() == 0 { + return Err(RvError::ErrResponse("missing CIDR blocks that needs validation".to_string())); + } + + for cidr in cidrs.iter() { + let _ = IpNetwork::from_str(cidr)?; + } + + Ok(true) +} + +pub fn subset(cidr1: &str, cidr2: &str) -> Result { + if cidr1 == "" { + return Err(RvError::ErrResponse("missing CIDR to be checked against".to_string())); + } + + if cidr2 == "" { + return Err(RvError::ErrResponse("missing CIDR that needs to be checked".to_string())); + } + + let ipnet1 = IpNetwork::from_str(cidr1)?; + let mask_len1 = ipnet1.prefix(); + + if !is_ip_addr_zero(&ipnet1.ip()) && mask_len1 == 0 { + return Err(RvError::ErrResponse("CIDR to be checked against is not in its canonical form".to_string())); + } + + let ipnet2 = IpNetwork::from_str(cidr2)?; + let mask_len2 = ipnet2.prefix(); + + if !is_ip_addr_zero(&ipnet2.ip()) && mask_len2 == 0 { + return Err(RvError::ErrResponse("CIDR that needs to be checked is not in its canonical form".to_string())); + } + + /* + * If the mask length of the CIDR that needs to be checked is smaller + * then the mask length of the CIDR to be checked against, then the + * former will encompass more IPs than the latter, and hence can't be a + * subset of the latter. + */ + if mask_len2 < mask_len1 { + return Ok(false); + } + + Ok(ipnet1.contains(ipnet2.ip())) +} + +/* + * subset_blocks checks if each CIDR block of a given set of CIDR blocks, is a + * subset of at least one CIDR block belonging to another set of CIDR blocks. + * First parameter is the set of CIDR blocks to check against and the second + * parameter is the set of CIDR blocks that needs to be checked. + */ +pub fn subset_blocks(cidr_blocks1: &[&str], cidr_blocks2: &[&str]) -> Result { + if cidr_blocks1.len() == 0 { + return Err(RvError::ErrResponse("missing CIDR blocks to be checked against".to_string())); + } + + if cidr_blocks2.len() == 0 { + return Err(RvError::ErrResponse("missing CIDR blocks that needs to be checked".to_string())); + } + + // Check if all the elements of cidr_blocks2 is a subset of at least one + // element of cidr_blocks1 + for cidr_block2 in cidr_blocks2.iter() { + let mut is_subset = false; + for cidr_block1 in cidr_blocks1.iter() { + if subset(cidr_block1, cidr_block2)? { + is_subset = true; + break; + } + } + + if !is_subset { + return Ok(false); + } + } + + Ok(true) +} + +fn is_ip_addr_zero(ip_addr: &IpAddr) -> bool { + match *ip_addr { + IpAddr::V4(addr) => addr == Ipv4Addr::UNSPECIFIED, + IpAddr::V6(addr) => addr == Ipv6Addr::UNSPECIFIED, + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_cidr_ip_belongs_to_cidr() { + let ip = "192.168.25.30"; + let cidr = "192.168.25.30/16"; + let belongs = ip_belongs_to_cidr(ip, cidr); + assert!(belongs.is_ok()); + assert!(belongs.unwrap()); + + let ip = "10.197.192.6"; + let cidr = "10.197.192.0/18"; + let belongs = ip_belongs_to_cidr(ip, cidr); + assert!(belongs.is_ok()); + assert!(belongs.unwrap()); + + let ip = "192.168.25.30"; + let cidr = "192.168.26.30/24"; + let belongs = ip_belongs_to_cidr(ip, cidr); + assert!(belongs.is_ok()); + assert!(!belongs.unwrap()); + + let ip = "192.168.25.30.100"; + let cidr = "192.168.26.30/24"; + let belongs = ip_belongs_to_cidr(ip, cidr); + assert!(belongs.is_err()); + } + + #[test] + fn test_cidr_ip_belongs_to_cidrs() { + let ip = "192.168.27.29"; + let cidrs = vec!["172.169.100.200/18", "192.168.0.0/16", "10.10.20.20/24"]; + let belongs = ip_belongs_to_cidrs(ip, &cidrs); + assert!(belongs.is_ok()); + assert!(belongs.unwrap()); + + let ip = "192.168.27.29"; + let cidrs = vec!["172.169.100.200/18", "192.168.0.0.0/16", "10.10.20.20/24"]; + let belongs = ip_belongs_to_cidrs(ip, &cidrs); + assert!(belongs.is_err()); + + let ip = "30.40.50.60"; + let cidrs = vec!["172.169.100.200/18", "192.168.0.0/16", "10.10.20.20/24"]; + let belongs = ip_belongs_to_cidrs(ip, &cidrs); + assert!(belongs.is_ok()); + assert!(!belongs.unwrap()); + } + + #[test] + fn test_cidr_validate_cidr_string() { + let cidr = "172.169.100.200/18,192.168.0.0/16,10.10.20.20/24"; + let valid = validate_cidr_string(cidr, ","); + assert!(valid.is_ok()); + assert!(valid.unwrap()); + + let cidr = "172.169.100.200,192.168.0.0/16,10.10.20.20/24"; + let valid = validate_cidr_string(cidr, ","); + assert!(valid.is_ok()); + assert!(valid.unwrap()); + + let cidr = "172.169.100.200/18,192.168.0.0.0/16,10.10.20.20/24"; + let valid = validate_cidr_string(cidr, ","); + assert!(valid.is_err()); + } + + #[test] + fn test_cidr_validate_cidrs() { + let cidrs = vec!["172.169.100.200/18", "192.168.0.0/16", "10.10.20.20/24"]; + let valid = validate_cidrs(&cidrs); + assert!(valid.is_ok()); + assert!(valid.unwrap()); + + let cidrs = vec!["172.169.100.200", "192.168.0.0/16", "10.10.20.20/24"]; + let valid = validate_cidrs(&cidrs); + assert!(valid.is_ok()); + assert!(valid.unwrap()); + + let cidrs = vec!["172.169.100.200/18", "192.168.0.0.0/16", "10.10.20.20/24"]; + let valid = validate_cidrs(&cidrs); + assert!(valid.is_err()); + } + + #[test] + fn test_cidr_subset() { + let cidr1 = "192.168.27.29/24"; + let cidr2 = "192.168.27.29/24"; + let ret = subset(cidr1, cidr2); + assert!(ret.is_ok()); + assert!(ret.unwrap()); + + let cidr1 = "192.168.27.29/16"; + let cidr2 = "192.168.27.29/24"; + let ret = subset(cidr1, cidr2); + assert!(ret.is_ok()); + assert!(ret.unwrap()); + let ret = subset(cidr2, cidr1); + assert!(ret.is_ok()); + assert!(!ret.unwrap()); + + let cidr1 = "192.168.0.128/25"; + let cidr2 = "192.168.0.0/24"; + let ret = subset(cidr1, cidr2); + assert!(ret.is_ok()); + assert!(!ret.unwrap()); + let ret = subset(cidr2, cidr1); + assert!(ret.is_ok()); + assert!(ret.unwrap()); + } + + #[test] + fn test_cidr_subset_blocks() { + let cidrs1 = vec!["192.168.27.29/16", "172.245.30.40/24", "10.20.30.40/30"]; + let cidrs2 = vec!["192.168.27.29/20", "172.245.30.40/25", "10.20.30.40/32"]; + let ret = subset_blocks(&cidrs1, &cidrs2); + assert!(ret.is_ok()); + assert!(ret.unwrap()); + + let cidrs1 = vec!["192.168.27.29/16", "172.245.30.40/25", "10.20.30.40/30"]; + let cidrs2 = vec!["192.168.27.29/20", "172.245.30.40/24", "10.20.30.40/32"]; + let ret = subset_blocks(&cidrs1, &cidrs2); + assert!(ret.is_ok()); + assert!(!ret.unwrap()); + } + + #[test] + fn test_cidr_remote_addr_is_ok() { + let addr = new_sock_addr("127.0.0.1/8"); + assert!(addr.is_ok()); + let bound_cidrs = vec![addr.unwrap()]; + assert!(!remote_addr_is_ok("123.0.0.1", &bound_cidrs)); + assert!(remote_addr_is_ok("127.0.0.1", &bound_cidrs)); + } +} diff --git a/src/utils/ip_sock_addr.rs b/src/utils/ip_sock_addr.rs new file mode 100644 index 0000000..9e1e0da --- /dev/null +++ b/src/utils/ip_sock_addr.rs @@ -0,0 +1,119 @@ +//! This module is a Rust replica of +//! https://github.com/hashicorp/go-sockaddr/blob/master/ipv4addr.go + +use std::{ + fmt, + str::FromStr, + net::SocketAddr, +}; + +use as_any::Downcast; +use ipnetwork::IpNetwork; +use serde::{Deserialize, Serialize}; + +use super::{ + sock_addr::{SockAddr, SockAddrType}, +}; + +use crate::errors::RvError; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct IpSockAddr { + pub addr: IpNetwork, + pub port: u16, +} + +impl IpSockAddr { + pub fn new(s: &str) -> Result { + if let Ok(sock_addr) = SocketAddr::from_str(s) { + return Ok(IpSockAddr { + addr: IpNetwork::from(sock_addr.ip()), + port: sock_addr.port(), + }); + } else if let Ok(ip_addr) = IpNetwork::from_str(s) { + return Ok(IpSockAddr { + addr: ip_addr, + port: 0, + }); + } + return Err(RvError::ErrResponse(format!("Unable to parse {} to an IP address:", s))); + } +} + +impl SockAddr for IpSockAddr { + fn contains(&self, other: &dyn SockAddr) -> bool { + if let Some(ip_addr) = other.downcast_ref::() { + return self.addr.contains(ip_addr.addr.ip()); + } + + false + } + + fn equal(&self, other: &dyn SockAddr) -> bool { + if let Some(ip_addr) = other.downcast_ref::() { + return self.addr == ip_addr.addr && self.port == ip_addr.port; + } + + false + } + + fn sock_addr_type(&self) -> SockAddrType { + match self.addr { + IpNetwork::V4(_) => SockAddrType::IPv4, + IpNetwork::V6(_) => SockAddrType::IPv6, + } + } +} + +impl fmt::Display for IpSockAddr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.port == 0 { + write!(f, "{}", self.addr.ip()) + } else { + write!(f, "{}:{}", self.addr.ip(), self.port) + } + } +} + +#[cfg(test)] +mod test { + use super::{ + *, super::sock_addr::{SockAddrType}, + }; + + #[test] + fn test_ip_sock_addr() { + let addr1 = IpSockAddr::new("1.1.1.1").unwrap(); + let addr2 = IpSockAddr::new("1.1.1.1").unwrap(); + let addr3 = IpSockAddr::new("2.2.2.2").unwrap(); + let addr4 = IpSockAddr::new("333.333.333.333"); + let addr5 = IpSockAddr::new("1.1.1.1:80").unwrap(); + let addr6 = IpSockAddr::new("1.1.1.1:80").unwrap(); + let addr7 = IpSockAddr::new("1.1.1.1:8080").unwrap(); + let addr8 = IpSockAddr::new("2.2.2.2:80").unwrap(); + let addr9 = IpSockAddr::new("192.168.0.0/16").unwrap(); + let addr10 = IpSockAddr::new("192.168.0.0/24").unwrap(); + let addr11 = IpSockAddr::new("192.168.0.1").unwrap(); + let addr12 = IpSockAddr::new("192.168.1.1").unwrap(); + + assert!(addr4.is_err()); + assert!(addr1.contains(&addr2)); + assert!(addr1.equal(&addr2)); + assert!(!addr1.contains(&addr3)); + assert!(!addr1.equal(&addr3)); + assert_eq!(addr1.sock_addr_type(), SockAddrType::IPv4); + assert!(addr5.contains(&addr6)); + assert!(addr5.equal(&addr6)); + assert!(!addr5.equal(&addr7)); + assert!(!addr5.equal(&addr8)); + assert!(addr9.contains(&addr10)); + assert!(addr9.contains(&addr11)); + assert!(addr9.contains(&addr12)); + assert!(!addr9.contains(&addr1)); + assert!(addr10.contains(&addr9)); + assert!(addr10.contains(&addr11)); + assert!(!addr10.contains(&addr12)); + assert!(!addr9.equal(&addr10)); + assert!(!addr9.equal(&addr11)); + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index d77120a..508cd85 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -13,6 +13,11 @@ use crate::errors::RvError; pub mod cert; pub mod key; +pub mod salt; +pub mod cidr; +pub mod sock_addr; +pub mod ip_sock_addr; +pub mod unix_sock_addr; pub fn generate_uuid() -> String { let mut buf = [0u8; 16]; diff --git a/src/utils/salt.rs b/src/utils/salt.rs new file mode 100644 index 0000000..ed5962a --- /dev/null +++ b/src/utils/salt.rs @@ -0,0 +1,208 @@ +//! This module is a Rust replica of +//! https://github.com/hashicorp/vault/blob/main/sdk/helper/salt/salt.go + +use openssl::{ + hash::{hash, MessageDigest}, + pkey::PKey, + nid::Nid, + sign::Signer, +}; +use derivative::Derivative; + +use super::{ + generate_uuid, +}; + +use crate::{ + storage::{Storage, StorageEntry}, + errors::RvError, +}; + +static DEFAULT_LOCATION: &str = "salt"; + +#[derive(Debug, Clone)] +pub struct Salt { + pub config: Config, + pub salt: String, + pub generated: bool, +} + +#[derive(Derivative)] +#[derivative(Debug, Clone)] +pub struct Config { + pub location: String, + #[derivative(Debug="ignore")] + pub hash_type: MessageDigest, + #[derivative(Debug="ignore")] + pub hmac_type: MessageDigest, +} + +impl Default for Salt { + fn default() -> Self { + Self { + salt: generate_uuid(), + generated: true, + config: Config::default(), + } + } +} + +impl Default for Config { + fn default() -> Self { + Self { + location: DEFAULT_LOCATION.to_string(), + hash_type: MessageDigest::sha256(), + hmac_type: MessageDigest::sha256(), + } + } +} + +impl Salt { + pub fn new(storage: Option<&dyn Storage>, config: Option<&Config>) -> Result { + let mut salt = Salt::default(); + if let Some(c) = config { + if salt.config.location != c.location && c.location != "" { + salt.config.location = c.location.clone(); + } + + if salt.config.hash_type != c.hash_type { + salt.config.hash_type = c.hash_type.clone(); + } + + if salt.config.hmac_type != c.hmac_type { + salt.config.hmac_type = c.hmac_type.clone(); + } + } + + if let Some(s) = storage { + if let Some(raw) = s.get(&salt.config.location)? { + salt.salt = String::from_utf8_lossy(&raw.value).to_string(); + salt.generated = false; + } else { + let entry = StorageEntry { + key: salt.config.location.clone(), + value: salt.salt.as_bytes().to_vec(), + }; + + s.put(&entry)?; + } + } + + Ok(salt) + } + + pub fn new_nonpersistent() -> Self { + let mut salt = Salt::default(); + salt.config.location = "".to_string(); + salt + } + + pub fn get_hmac(&self, data: &str) -> Result { + let pkey = PKey::hmac(self.salt.as_bytes())?; + let mut signer = Signer::new(self.config.hmac_type, &pkey)?; + signer.update(data.as_bytes())?; + let hmac = signer.sign_to_vec()?; + Ok(hex::encode(hmac.as_slice())) + } + + pub fn get_identified_hamc(&self, data: &str) -> Result { + let hmac_type = match self.config.hmac_type.type_() { + Nid::SHA256 => "hmac-sha256", + Nid::SM3 => "hmac-sm3", + Nid::MD5 => "hmac-md5", + _ => "hmac-unknown", + }; + + let hmac = self.get_hmac(data)?; + + Ok(format!("{}:{}", hmac_type, hmac)) + } + + pub fn get_hash(&self, data: &str) -> Result { + let ret = hash(self.config.hash_type, data.as_bytes())?; + let bytes = ret.to_vec(); + Ok(hex::encode(bytes.as_slice())) + } + + pub fn salt_id(&self, id: &str) -> Result { + let comb = format!("{}{}", self.salt, id); + self.get_hash(&comb) + } + + pub fn did_generate(&self) -> bool { + self.generated + } +} + +#[cfg(test)] +mod test { + use std::{collections::HashMap, env, fs, sync::Arc}; + use go_defer::defer; + use rand::{thread_rng, Rng}; + use serde_json::Value; + use super::*; + use crate::{ + storage::{ + physical, barrier_view, barrier_aes_gcm, + barrier::SecurityBarrier, + } + }; + + #[test] + fn test_salt() { + // init the storage + let dir = env::temp_dir().join("rusty_vault_test_salt"); + assert!(fs::create_dir(&dir).is_ok()); + defer! ( + assert!(fs::remove_dir_all(&dir).is_ok()); + ); + + let mut conf: HashMap = HashMap::new(); + conf.insert("path".to_string(), Value::String(dir.to_string_lossy().into_owned())); + + let mut key = vec![0u8; 32]; + thread_rng().fill(key.as_mut_slice()); + + let backend = physical::new_backend("file", &conf); + assert!(backend.is_ok()); + let backend = backend.unwrap(); + let aes_gcm_view = barrier_aes_gcm::AESGCMBarrier::new(Arc::clone(&backend)); + + let init = aes_gcm_view.init(key.as_slice()); + assert!(init.is_ok()); + + assert!(aes_gcm_view.unseal(key.as_slice()).is_ok()); + + let view = barrier_view::BarrierView::new(Arc::new(aes_gcm_view), "test"); + + //test salt + let salt = Salt::new(Some(view.as_storage()), None); + assert!(salt.is_ok()); + + let salt = salt.unwrap(); + assert!(salt.did_generate()); + + let ss = view.get(DEFAULT_LOCATION); + assert!(ss.is_ok()); + assert!(ss.unwrap().is_some()); + + let salt2 = Salt::new(Some(view.as_storage()), Some(&salt.config)); + assert!(salt2.is_ok()); + + let salt2 = salt2.unwrap(); + assert!(!salt2.did_generate()); + + assert_eq!(salt.salt, salt2.salt); + + let id = "foobarbaz"; + let sid1 = salt.salt_id(id); + let sid2 = salt2.salt_id(id); + assert!(sid1.is_ok()); + assert!(sid2.is_ok()); + + let sid1 = sid1.unwrap(); + let sid2 = sid2.unwrap(); + assert_eq!(sid1, sid2); + assert_eq!(sid1.len(), salt.config.hash_type.size()*2); + } +} diff --git a/src/utils/sock_addr.rs b/src/utils/sock_addr.rs new file mode 100644 index 0000000..63574ca --- /dev/null +++ b/src/utils/sock_addr.rs @@ -0,0 +1,144 @@ +//! This module is a Rust replica of +//! https://github.com/hashicorp/go-sockaddr/blob/master/sockaddr.go + +use std::{ + fmt, + str::FromStr, +}; + +use as_any::AsAny; +use serde::{Deserialize, Serialize}; + +use super::{ + ip_sock_addr::IpSockAddr, +}; + +use crate::errors::RvError; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SockAddrType { + Unknown = 0x0, + Unix = 0x1, + IPv4 = 0x2, + IPv6 = 0x4, + // IP is the union of IPv4 and IPv6 + IP = 0x6, +} + +pub trait SockAddr: fmt::Display + AsAny { + // contains returns true if the other SockAddr is contained within the receiver + fn contains(&self, other: &dyn SockAddr) -> bool; + + // equal allows for the comparison of two SockAddrs + fn equal(&self, other: &dyn SockAddr) -> bool; + + fn sock_addr_type(&self) -> SockAddrType; +} + +pub struct SockAddrMarshaler { + pub sock_addr: Box, +} + +impl SockAddrMarshaler { + pub fn new(sock_addr: Box) -> Self { + SockAddrMarshaler { sock_addr } + } +} + +impl fmt::Display for SockAddrMarshaler { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.sock_addr) + } +} + +impl fmt::Display for SockAddrType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let type_str = match self { + SockAddrType::IPv4 => "IPv4", + SockAddrType::IPv6 => "IPv6", + SockAddrType::Unix => "Unix", + _ => "Unknown", + }; + write!(f, "{}", type_str) + } +} + +impl FromStr for SockAddrType { + type Err = RvError; + fn from_str(s: &str) -> Result { + match s { + "IPv4" | "ipv4" => Ok(SockAddrType::IPv4), + "IPv6" | "ipv6" => Ok(SockAddrType::IPv6), + "Unix" | "UNIX" | "unix" => Ok(SockAddrType::Unix), + _ => Err(RvError::ErrResponse("invalid sockaddr type".to_string())) + } + } +} + +pub fn new_sock_addr(s: &str) -> Result, RvError> { + let ret = IpSockAddr::new(s)?; + Ok(Box::new(ret)) +} + +#[cfg(test)] +mod test { + use super::{ + *, super::{ + sock_addr::{SockAddrType}, + ip_sock_addr::IpSockAddr, + unix_sock_addr::UnixSockAddr, + }, + }; + + #[test] + fn test_sock_addr() { + let unix_addr1 = UnixSockAddr::new("/tmp/bar").unwrap(); + let unix_addr2 = UnixSockAddr::new("/tmp/bar").unwrap(); + let unix_addr3 = UnixSockAddr::new("/tmp/foo").unwrap(); + let ip_addr1 = IpSockAddr::new("1.1.1.1").unwrap(); + let ip_addr2 = IpSockAddr::new("1.1.1.1").unwrap(); + let ip_addr3 = IpSockAddr::new("2.2.2.2").unwrap(); + let ip_addr4 = IpSockAddr::new("333.333.333.333"); + let ip_addr5 = IpSockAddr::new("1.1.1.1:80").unwrap(); + let ip_addr6 = IpSockAddr::new("1.1.1.1:80").unwrap(); + let ip_addr7 = IpSockAddr::new("1.1.1.1:8080").unwrap(); + let ip_addr8 = IpSockAddr::new("2.2.2.2:80").unwrap(); + let ip_addr9 = IpSockAddr::new("192.168.0.0/16").unwrap(); + let ip_addr10 = IpSockAddr::new("192.168.0.0/24").unwrap(); + let ip_addr11 = IpSockAddr::new("192.168.0.1").unwrap(); + let ip_addr12 = IpSockAddr::new("192.168.1.1").unwrap(); + + assert!(unix_addr1.contains(&unix_addr2)); + assert!(unix_addr1.equal(&unix_addr2)); + assert!(!unix_addr1.contains(&unix_addr3)); + assert!(!unix_addr1.equal(&unix_addr3)); + assert_ne!(unix_addr1.sock_addr_type(), ip_addr1.sock_addr_type()); + + assert!(ip_addr4.is_err()); + assert!(ip_addr1.contains(&ip_addr2)); + assert!(ip_addr1.equal(&ip_addr2)); + assert!(!ip_addr1.contains(&ip_addr3)); + assert!(!ip_addr1.equal(&ip_addr3)); + assert_eq!(ip_addr1.sock_addr_type(), SockAddrType::IPv4); + assert_eq!(ip_addr1.sock_addr_type(), ip_addr2.sock_addr_type()); + assert_ne!(ip_addr1.sock_addr_type(), unix_addr2.sock_addr_type()); + assert!(ip_addr5.contains(&ip_addr6)); + assert!(ip_addr5.equal(&ip_addr6)); + assert!(!ip_addr5.equal(&ip_addr7)); + assert!(!ip_addr5.equal(&ip_addr8)); + assert!(ip_addr9.contains(&ip_addr10)); + assert!(ip_addr9.contains(&ip_addr11)); + assert!(ip_addr9.contains(&ip_addr12)); + assert!(!ip_addr9.contains(&ip_addr1)); + assert!(ip_addr10.contains(&ip_addr9)); + assert!(ip_addr10.contains(&ip_addr11)); + assert!(!ip_addr10.contains(&ip_addr12)); + assert!(!ip_addr9.equal(&ip_addr10)); + assert!(!ip_addr9.equal(&ip_addr11)); + + assert!(!ip_addr1.contains(&unix_addr1)); + assert!(!ip_addr1.equal(&unix_addr1)); + assert!(!unix_addr1.contains(&ip_addr1)); + assert!(!unix_addr1.equal(&ip_addr1)); + } +} diff --git a/src/utils/unix_sock_addr.rs b/src/utils/unix_sock_addr.rs new file mode 100644 index 0000000..c7cc8ca --- /dev/null +++ b/src/utils/unix_sock_addr.rs @@ -0,0 +1,73 @@ +//! This module is a Rust replica of +//! https://github.com/hashicorp/go-sockaddr/blob/master/unixsock.go + +use std::fmt; +use as_any::Downcast; +use serde::{Deserialize, Serialize}; + +use super::{ + sock_addr::{SockAddr, SockAddrType}, +}; + +use crate::errors::RvError; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct UnixSockAddr { + pub path: String, +} + +impl UnixSockAddr { + pub fn new(s: &str) -> Result { + Ok(Self { + path: s.to_string(), + }) + } +} + +impl SockAddr for UnixSockAddr { + fn contains(&self, other: &dyn SockAddr) -> bool { + if let Some(unix_sock) = other.downcast_ref::() { + return self.path == unix_sock.path; + } + + false + } + + fn equal(&self, other: &dyn SockAddr) -> bool { + if let Some(unix_sock) = other.downcast_ref::() { + return self.path == unix_sock.path; + } + + false + } + + fn sock_addr_type(&self) -> SockAddrType { + SockAddrType::Unix + } +} + +impl fmt::Display for UnixSockAddr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.path) + } +} + +#[cfg(test)] +mod test { + use super::{ + *, super::sock_addr::{SockAddrType}, + }; + + #[test] + fn test_unix_sock_addr() { + let addr1 = UnixSockAddr::new("/tmp/bar").unwrap(); + let addr2 = UnixSockAddr::new("/tmp/bar").unwrap(); + let addr3 = UnixSockAddr::new("/tmp/foo").unwrap(); + + assert!(addr1.contains(&addr2)); + assert!(addr1.equal(&addr2)); + assert!(!addr1.contains(&addr3)); + assert!(!addr1.equal(&addr3)); + assert_eq!(addr1.sock_addr_type(), SockAddrType::Unix); + } +}