Skip to content

Commit

Permalink
Merge pull request #1 from EphyraSoftware/add-windows-support
Browse files Browse the repository at this point in the history
Add Windows support
  • Loading branch information
ThetaSinner authored Nov 2, 2024
2 parents ac08483 + d112455 commit 3d74c41
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 15 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ sysinfo = { version = "0.32.0", optional = true }
[target.'cfg(target_os = "linux")'.dependencies]
procfs = "0.17"

[target.'cfg(target_os = "windows")'.dependencies]
windows = { version = "0.58", features = ["Win32_Networking", "Win32_Networking_WinSock", "Win32_NetworkManagement_IpHelper"] }

[dev-dependencies]
retry = "2.0.0"
tokio = { version = "1", features = ["time", "rt", "macros"] }
Expand Down
4 changes: 2 additions & 2 deletions src/common.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#[cfg(any(target_os = "linux", feature = "proc"))]
#[cfg(any(target_os = "linux", target_os = "windows", feature = "proc"))]
pub(crate) trait MaybeHasPid {
fn get_pid(&self) -> Option<crate::Pid>;
}

#[cfg(any(target_os = "linux", feature = "proc"))]
#[cfg(any(target_os = "linux", target_os = "windows", feature = "proc"))]
pub(crate) fn resolve_pid(maybe_has_pid: &dyn MaybeHasPid) -> crate::ProcCtlResult<crate::Pid> {
match &maybe_has_pid.get_pid() {
Some(pid) => Ok(*pid),
Expand Down
5 changes: 5 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ pub enum ProcCtlError {
#[error("process error")]
ProcessError(#[from] procfs::ProcError),

/// An error occurred while searching process information
#[cfg(target_os = "windows")]
#[error("process error")]
ProcessError(String),

/// The user made an error using the API, a more specific error message will be provided
#[error("configuration error {0}")]
ConfigurationError(String),
Expand Down
150 changes: 147 additions & 3 deletions src/port_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ impl PortQuery {

/// Execute the query
pub fn execute(&self) -> ProcCtlResult<Vec<ProtocolPort>> {
#[cfg(target_os = "linux")]
#[cfg(any(target_os = "linux", target_os = "windows"))]
let ports = list_ports_for_pid(self, crate::common::resolve_pid(self)?)?;
#[cfg(not(target_os = "linux"))]
#[cfg(not(any(target_os = "linux", target_os = "windows")))]
let ports = Vec::with_capacity(0);

if let Some(num) = &self.min_num_ports {
Expand Down Expand Up @@ -180,7 +180,151 @@ fn list_ports_for_pid(query: &PortQuery, pid: Pid) -> ProcCtlResult<Vec<Protocol
Ok(out)
}

#[cfg(any(target_os = "linux", feature = "proc"))]
#[cfg(target_os = "windows")]
fn list_ports_for_pid(query: &PortQuery, pid: Pid) -> ProcCtlResult<Vec<ProtocolPort>> {
let mut out = Vec::new();

if query.tcp_addresses {
if query.ipv4_addresses {
let mut table = load_tcp_table(windows::Win32::Networking::WinSock::AF_INET)?;
let table: &mut windows::Win32::NetworkManagement::IpHelper::MIB_TCPTABLE_OWNER_PID = unsafe {
&mut *(table.as_mut_ptr()
as *mut windows::Win32::NetworkManagement::IpHelper::MIB_TCPTABLE_OWNER_PID)
};

for i in 0..table.dwNumEntries as usize {
let row = unsafe { &*table.table.as_mut_ptr().add(i) };
if row.dwOwningPid == pid {
out.push(ProtocolPort::Tcp(row.dwLocalPort as u16));
}
}
}
if query.ipv6_addresses {
let mut table = load_tcp_table(windows::Win32::Networking::WinSock::AF_INET6)?;
let table: &mut windows::Win32::NetworkManagement::IpHelper::MIB_TCP6TABLE_OWNER_PID = unsafe {
&mut *(table.as_mut_ptr()
as *mut windows::Win32::NetworkManagement::IpHelper::MIB_TCP6TABLE_OWNER_PID)
};

for i in 0..table.dwNumEntries as usize {
let row = unsafe { &*table.table.as_mut_ptr().add(i) };
if row.dwOwningPid == pid {
out.push(ProtocolPort::Tcp(row.dwLocalPort as u16));
}
}
}
}
if query.udp_addresses {
if query.ipv4_addresses {
let mut table = load_udp_table(windows::Win32::Networking::WinSock::AF_INET)?;
let table: &mut windows::Win32::NetworkManagement::IpHelper::MIB_UDPTABLE_OWNER_PID = unsafe {
&mut *(table.as_mut_ptr()
as *mut windows::Win32::NetworkManagement::IpHelper::MIB_UDPTABLE_OWNER_PID)
};

for i in 0..table.dwNumEntries as usize {
let row = unsafe { &*table.table.as_mut_ptr().add(i) };
if row.dwOwningPid == pid {
out.push(ProtocolPort::Tcp(row.dwLocalPort as u16));
}
}
}
if query.ipv6_addresses {
let mut table = load_udp_table(windows::Win32::Networking::WinSock::AF_INET6)?;
let table: &mut windows::Win32::NetworkManagement::IpHelper::MIB_UDP6TABLE_OWNER_PID = unsafe {
&mut *(table.as_mut_ptr()
as *mut windows::Win32::NetworkManagement::IpHelper::MIB_UDP6TABLE_OWNER_PID)
};

for i in 0..table.dwNumEntries as usize {
let row = unsafe { &*table.table.as_mut_ptr().add(i) };
if row.dwOwningPid == pid {
out.push(ProtocolPort::Tcp(row.dwLocalPort as u16));
}
}
}
}

Ok(out)
}

#[cfg(target_os = "windows")]
fn load_tcp_table(
family: windows::Win32::Networking::WinSock::ADDRESS_FAMILY,
) -> ProcCtlResult<Vec<u8>> {
let mut table = Vec::<u8>::with_capacity(0);
let mut table_size: u32 = 0;
for _ in 0..3 {
let err_code = unsafe {
windows::Win32::Foundation::WIN32_ERROR(
windows::Win32::NetworkManagement::IpHelper::GetExtendedTcpTable(
Some(table.as_mut_ptr() as *mut _),
&mut table_size,
false,
family.0 as u32,
windows::Win32::NetworkManagement::IpHelper::TCP_TABLE_OWNER_PID_ALL,
0,
),
)
};

if err_code == windows::Win32::Foundation::ERROR_INSUFFICIENT_BUFFER {
table.resize(table_size as usize, 0);
continue;
} else if err_code != windows::Win32::Foundation::NO_ERROR {
return Err(ProcCtlError::ProcessError(format!(
"Failed to get TCP table: {:?}",
err_code
)));
}

return Ok(table);
}

Err(ProcCtlError::ProcessError(
"Failed to get TCP table".to_string(),
))
}

#[cfg(target_os = "windows")]
fn load_udp_table(
family: windows::Win32::Networking::WinSock::ADDRESS_FAMILY,
) -> ProcCtlResult<Vec<u8>> {
let mut table = Vec::<u8>::with_capacity(0);
let mut table_size: u32 = 0;
for _ in 0..3 {
let err_code = unsafe {
windows::Win32::Foundation::WIN32_ERROR(
windows::Win32::NetworkManagement::IpHelper::GetExtendedUdpTable(
Some(table.as_mut_ptr() as *mut _),
&mut table_size,
false,
family.0 as u32,
windows::Win32::NetworkManagement::IpHelper::UDP_TABLE_OWNER_PID,
0,
),
)
};

if err_code == windows::Win32::Foundation::ERROR_INSUFFICIENT_BUFFER {
table.resize(table_size as usize, 0);
continue;
} else if err_code != windows::Win32::Foundation::NO_ERROR {
return Err(ProcCtlError::ProcessError(format!(
"Failed to get UDP table: {:?}",
err_code
)));
}

return Ok(table);
}

Err(ProcCtlError::ProcessError(
"Failed to get UDP table".to_string(),
))
}

#[cfg(any(target_os = "linux", target_os = "windows", feature = "proc"))]
impl crate::common::MaybeHasPid for PortQuery {
fn get_pid(&self) -> Option<Pid> {
self.process_id
Expand Down
23 changes: 13 additions & 10 deletions tests/lib_test.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#[cfg(any(target_os = "linux", feature = "proc"))]
#[cfg(any(feature = "proc", target_os = "linux", target_os = "windows"))]
fn create_command_for_sample(name: &str) -> std::process::Command {
let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
.join("target")
Expand All @@ -17,24 +17,24 @@ fn create_command_for_sample(name: &str) -> std::process::Command {
std::process::Command::new(path)
}

#[cfg(any(target_os = "linux", feature = "proc"))]
#[cfg(any(feature = "proc", target_os = "linux", target_os = "windows"))]
struct DropChild(std::process::Child);

#[cfg(any(target_os = "linux", feature = "proc"))]
#[cfg(any(feature = "proc", target_os = "linux", target_os = "windows"))]
impl DropChild {
fn spawn(mut cmd: std::process::Command) -> Self {
DropChild(cmd.spawn().expect("Failed to spawn child process"))
}
}

#[cfg(any(target_os = "linux", feature = "proc"))]
#[cfg(any(feature = "proc", target_os = "linux", target_os = "windows"))]
impl Drop for DropChild {
fn drop(&mut self) {
self.0.kill().expect("Failed to kill child process");
}
}

#[cfg(any(target_os = "linux", feature = "proc"))]
#[cfg(any(feature = "proc", target_os = "linux", target_os = "windows"))]
impl std::ops::Deref for DropChild {
type Target = std::process::Child;

Expand All @@ -43,14 +43,14 @@ impl std::ops::Deref for DropChild {
}
}

#[cfg(any(target_os = "linux", feature = "proc"))]
#[cfg(any(feature = "proc", target_os = "linux", target_os = "windows"))]
impl std::ops::DerefMut for DropChild {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

#[cfg(target_os = "linux")]
#[cfg(any(target_os = "linux", target_os = "windows"))]
#[test]
fn port_query() {
use retry::delay::Fixed;
Expand All @@ -71,7 +71,7 @@ fn port_query() {
assert_eq!(1, ports.len());
}

#[cfg(target_os = "linux")]
#[cfg(any(target_os = "linux", target_os = "windows"))]
#[test]
fn port_query_which_expects_too_many_ports() {
use retry::delay::Fixed;
Expand All @@ -93,7 +93,10 @@ fn port_query_which_expects_too_many_ports() {
result.expect_err("Should have had an error about too few ports");
}

#[cfg(all(feature = "resilience", target_os = "linux"))]
#[cfg(all(
feature = "resilience",
any(target_os = "linux", target_os = "windows")
))]
#[test]
fn port_query_with_sync_retry() {
use std::time::Duration;
Expand All @@ -116,7 +119,7 @@ fn port_query_with_sync_retry() {
assert_eq!(1, ports.len());
}

#[cfg(all(feature = "async", target_os = "linux"))]
#[cfg(all(feature = "async", any(target_os = "linux", target_os = "windows")))]
#[tokio::test]
async fn port_query_with_async_retry() {
use std::time::Duration;
Expand Down

0 comments on commit 3d74c41

Please sign in to comment.