From d11245581a3925aedb7c12ac1f31d392237663d0 Mon Sep 17 00:00:00 2001 From: ThetaSinner Date: Sat, 2 Nov 2024 03:51:04 +0000 Subject: [PATCH] Add Windows support --- Cargo.toml | 3 + src/common.rs | 4 +- src/error.rs | 5 ++ src/port_query.rs | 150 +++++++++++++++++++++++++++++++++++++++++++++- tests/lib_test.rs | 23 +++---- 5 files changed, 170 insertions(+), 15 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b124f9a..6fa4393 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,9 @@ sysinfo = { version = "0.32.0", optional = true } [target.'cfg(target_os = "linux")'.dependencies] procfs = "0.17" +[target.'cfg(target_os = "windows")'.dependencies] +windows = { version = "0.58", features = ["Win32_Networking", "Win32_Networking_WinSock", "Win32_NetworkManagement_IpHelper"] } + [dev-dependencies] retry = "2.0.0" tokio = { version = "1", features = ["time", "rt", "macros"] } diff --git a/src/common.rs b/src/common.rs index 3f42fe6..a3f9c72 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,9 +1,9 @@ -#[cfg(any(target_os = "linux", feature = "proc"))] +#[cfg(any(target_os = "linux", target_os = "windows", feature = "proc"))] pub(crate) trait MaybeHasPid { fn get_pid(&self) -> Option; } -#[cfg(any(target_os = "linux", feature = "proc"))] +#[cfg(any(target_os = "linux", target_os = "windows", feature = "proc"))] pub(crate) fn resolve_pid(maybe_has_pid: &dyn MaybeHasPid) -> crate::ProcCtlResult { match &maybe_has_pid.get_pid() { Some(pid) => Ok(*pid), diff --git a/src/error.rs b/src/error.rs index 37f88d5..fe73780 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,6 +12,11 @@ pub enum ProcCtlError { #[error("process error")] ProcessError(#[from] procfs::ProcError), + /// An error occurred while searching process information + #[cfg(target_os = "windows")] + #[error("process error")] + ProcessError(String), + /// The user made an error using the API, a more specific error message will be provided #[error("configuration error {0}")] ConfigurationError(String), diff --git a/src/port_query.rs b/src/port_query.rs index 970f7ea..fda02e5 100644 --- a/src/port_query.rs +++ b/src/port_query.rs @@ -77,9 +77,9 @@ impl PortQuery { /// Execute the query pub fn execute(&self) -> ProcCtlResult> { - #[cfg(target_os = "linux")] + #[cfg(any(target_os = "linux", target_os = "windows"))] let ports = list_ports_for_pid(self, crate::common::resolve_pid(self)?)?; - #[cfg(not(target_os = "linux"))] + #[cfg(not(any(target_os = "linux", target_os = "windows")))] let ports = Vec::with_capacity(0); if let Some(num) = &self.min_num_ports { @@ -180,7 +180,151 @@ fn list_ports_for_pid(query: &PortQuery, pid: Pid) -> ProcCtlResult ProcCtlResult> { + let mut out = Vec::new(); + + if query.tcp_addresses { + if query.ipv4_addresses { + let mut table = load_tcp_table(windows::Win32::Networking::WinSock::AF_INET)?; + let table: &mut windows::Win32::NetworkManagement::IpHelper::MIB_TCPTABLE_OWNER_PID = unsafe { + &mut *(table.as_mut_ptr() + as *mut windows::Win32::NetworkManagement::IpHelper::MIB_TCPTABLE_OWNER_PID) + }; + + for i in 0..table.dwNumEntries as usize { + let row = unsafe { &*table.table.as_mut_ptr().add(i) }; + if row.dwOwningPid == pid { + out.push(ProtocolPort::Tcp(row.dwLocalPort as u16)); + } + } + } + if query.ipv6_addresses { + let mut table = load_tcp_table(windows::Win32::Networking::WinSock::AF_INET6)?; + let table: &mut windows::Win32::NetworkManagement::IpHelper::MIB_TCP6TABLE_OWNER_PID = unsafe { + &mut *(table.as_mut_ptr() + as *mut windows::Win32::NetworkManagement::IpHelper::MIB_TCP6TABLE_OWNER_PID) + }; + + for i in 0..table.dwNumEntries as usize { + let row = unsafe { &*table.table.as_mut_ptr().add(i) }; + if row.dwOwningPid == pid { + out.push(ProtocolPort::Tcp(row.dwLocalPort as u16)); + } + } + } + } + if query.udp_addresses { + if query.ipv4_addresses { + let mut table = load_udp_table(windows::Win32::Networking::WinSock::AF_INET)?; + let table: &mut windows::Win32::NetworkManagement::IpHelper::MIB_UDPTABLE_OWNER_PID = unsafe { + &mut *(table.as_mut_ptr() + as *mut windows::Win32::NetworkManagement::IpHelper::MIB_UDPTABLE_OWNER_PID) + }; + + for i in 0..table.dwNumEntries as usize { + let row = unsafe { &*table.table.as_mut_ptr().add(i) }; + if row.dwOwningPid == pid { + out.push(ProtocolPort::Tcp(row.dwLocalPort as u16)); + } + } + } + if query.ipv6_addresses { + let mut table = load_udp_table(windows::Win32::Networking::WinSock::AF_INET6)?; + let table: &mut windows::Win32::NetworkManagement::IpHelper::MIB_UDP6TABLE_OWNER_PID = unsafe { + &mut *(table.as_mut_ptr() + as *mut windows::Win32::NetworkManagement::IpHelper::MIB_UDP6TABLE_OWNER_PID) + }; + + for i in 0..table.dwNumEntries as usize { + let row = unsafe { &*table.table.as_mut_ptr().add(i) }; + if row.dwOwningPid == pid { + out.push(ProtocolPort::Tcp(row.dwLocalPort as u16)); + } + } + } + } + + Ok(out) +} + +#[cfg(target_os = "windows")] +fn load_tcp_table( + family: windows::Win32::Networking::WinSock::ADDRESS_FAMILY, +) -> ProcCtlResult> { + let mut table = Vec::::with_capacity(0); + let mut table_size: u32 = 0; + for _ in 0..3 { + let err_code = unsafe { + windows::Win32::Foundation::WIN32_ERROR( + windows::Win32::NetworkManagement::IpHelper::GetExtendedTcpTable( + Some(table.as_mut_ptr() as *mut _), + &mut table_size, + false, + family.0 as u32, + windows::Win32::NetworkManagement::IpHelper::TCP_TABLE_OWNER_PID_ALL, + 0, + ), + ) + }; + + if err_code == windows::Win32::Foundation::ERROR_INSUFFICIENT_BUFFER { + table.resize(table_size as usize, 0); + continue; + } else if err_code != windows::Win32::Foundation::NO_ERROR { + return Err(ProcCtlError::ProcessError(format!( + "Failed to get TCP table: {:?}", + err_code + ))); + } + + return Ok(table); + } + + Err(ProcCtlError::ProcessError( + "Failed to get TCP table".to_string(), + )) +} + +#[cfg(target_os = "windows")] +fn load_udp_table( + family: windows::Win32::Networking::WinSock::ADDRESS_FAMILY, +) -> ProcCtlResult> { + let mut table = Vec::::with_capacity(0); + let mut table_size: u32 = 0; + for _ in 0..3 { + let err_code = unsafe { + windows::Win32::Foundation::WIN32_ERROR( + windows::Win32::NetworkManagement::IpHelper::GetExtendedUdpTable( + Some(table.as_mut_ptr() as *mut _), + &mut table_size, + false, + family.0 as u32, + windows::Win32::NetworkManagement::IpHelper::UDP_TABLE_OWNER_PID, + 0, + ), + ) + }; + + if err_code == windows::Win32::Foundation::ERROR_INSUFFICIENT_BUFFER { + table.resize(table_size as usize, 0); + continue; + } else if err_code != windows::Win32::Foundation::NO_ERROR { + return Err(ProcCtlError::ProcessError(format!( + "Failed to get UDP table: {:?}", + err_code + ))); + } + + return Ok(table); + } + + Err(ProcCtlError::ProcessError( + "Failed to get UDP table".to_string(), + )) +} + +#[cfg(any(target_os = "linux", target_os = "windows", feature = "proc"))] impl crate::common::MaybeHasPid for PortQuery { fn get_pid(&self) -> Option { self.process_id diff --git a/tests/lib_test.rs b/tests/lib_test.rs index 5a2bd51..b764114 100644 --- a/tests/lib_test.rs +++ b/tests/lib_test.rs @@ -1,4 +1,4 @@ -#[cfg(any(target_os = "linux", feature = "proc"))] +#[cfg(any(feature = "proc", target_os = "linux", target_os = "windows"))] fn create_command_for_sample(name: &str) -> std::process::Command { let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) .join("target") @@ -17,24 +17,24 @@ fn create_command_for_sample(name: &str) -> std::process::Command { std::process::Command::new(path) } -#[cfg(any(target_os = "linux", feature = "proc"))] +#[cfg(any(feature = "proc", target_os = "linux", target_os = "windows"))] struct DropChild(std::process::Child); -#[cfg(any(target_os = "linux", feature = "proc"))] +#[cfg(any(feature = "proc", target_os = "linux", target_os = "windows"))] impl DropChild { fn spawn(mut cmd: std::process::Command) -> Self { DropChild(cmd.spawn().expect("Failed to spawn child process")) } } -#[cfg(any(target_os = "linux", feature = "proc"))] +#[cfg(any(feature = "proc", target_os = "linux", target_os = "windows"))] impl Drop for DropChild { fn drop(&mut self) { self.0.kill().expect("Failed to kill child process"); } } -#[cfg(any(target_os = "linux", feature = "proc"))] +#[cfg(any(feature = "proc", target_os = "linux", target_os = "windows"))] impl std::ops::Deref for DropChild { type Target = std::process::Child; @@ -43,14 +43,14 @@ impl std::ops::Deref for DropChild { } } -#[cfg(any(target_os = "linux", feature = "proc"))] +#[cfg(any(feature = "proc", target_os = "linux", target_os = "windows"))] impl std::ops::DerefMut for DropChild { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } -#[cfg(target_os = "linux")] +#[cfg(any(target_os = "linux", target_os = "windows"))] #[test] fn port_query() { use retry::delay::Fixed; @@ -71,7 +71,7 @@ fn port_query() { assert_eq!(1, ports.len()); } -#[cfg(target_os = "linux")] +#[cfg(any(target_os = "linux", target_os = "windows"))] #[test] fn port_query_which_expects_too_many_ports() { use retry::delay::Fixed; @@ -93,7 +93,10 @@ fn port_query_which_expects_too_many_ports() { result.expect_err("Should have had an error about too few ports"); } -#[cfg(all(feature = "resilience", target_os = "linux"))] +#[cfg(all( + feature = "resilience", + any(target_os = "linux", target_os = "windows") +))] #[test] fn port_query_with_sync_retry() { use std::time::Duration; @@ -116,7 +119,7 @@ fn port_query_with_sync_retry() { assert_eq!(1, ports.len()); } -#[cfg(all(feature = "async", target_os = "linux"))] +#[cfg(all(feature = "async", any(target_os = "linux", target_os = "windows")))] #[tokio::test] async fn port_query_with_async_retry() { use std::time::Duration;