diff --git a/Cargo.toml b/Cargo.toml index 8ef85ebc..024ee548 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ include = [ [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] -targets = ["aarch64-apple-ios", "aarch64-linux-android", "x86_64-apple-darwin", "x86_64-unknown-fuchsia", "x86_64-pc-windows-msvc", "x86_64-pc-solaris", "x86_64-unknown-freebsd", "x86_64-unknown-illumos", "x86_64-unknown-linux-gnu", "x86_64-unknown-linux-musl", "x86_64-unknown-netbsd", "x86_64-unknown-redox", "armv7-linux-androideabi", "i686-linux-android"] +targets = ["aarch64-apple-ios", "aarch64-linux-android", "x86_64-apple-darwin", "x86_64-unknown-fuchsia", "x86_64-pc-windows-msvc", "x86_64-pc-windows-gnu", "x86_64-pc-solaris", "x86_64-unknown-freebsd", "x86_64-unknown-illumos", "x86_64-unknown-linux-gnu", "x86_64-unknown-linux-musl", "x86_64-unknown-netbsd", "x86_64-unknown-redox", "armv7-linux-androideabi", "i686-linux-android"] [package.metadata.playground] features = ["all"] diff --git a/src/sys/windows.rs b/src/sys/windows.rs index 11f2b7b0..729be055 100644 --- a/src/sys/windows.rs +++ b/src/sys/windows.rs @@ -20,14 +20,16 @@ use std::time::{Duration, Instant}; use std::{process, ptr, slice}; use windows_sys::Win32::Foundation::{SetHandleInformation, HANDLE, HANDLE_FLAG_INHERIT}; -#[cfg(feature = "all")] -use windows_sys::Win32::Networking::WinSock::SO_PROTOCOL_INFOW; use windows_sys::Win32::Networking::WinSock::{ self, tcp_keepalive, FIONBIO, IN6_ADDR, IN6_ADDR_0, INVALID_SOCKET, IN_ADDR, IN_ADDR_0, POLLERR, POLLHUP, POLLRDNORM, POLLWRNORM, SD_BOTH, SD_RECEIVE, SD_SEND, SIO_KEEPALIVE_VALS, SOCKET_ERROR, WSABUF, WSAEMSGSIZE, WSAESHUTDOWN, WSAPOLLFD, WSAPROTOCOL_INFOW, WSA_FLAG_NO_HANDLE_INHERIT, WSA_FLAG_OVERLAPPED, }; +#[cfg(feature = "all")] +use windows_sys::Win32::Networking::WinSock::{ + IP6T_SO_ORIGINAL_DST, SOL_IP, SO_ORIGINAL_DST, SO_PROTOCOL_INFOW, +}; use windows_sys::Win32::System::Threading::INFINITE; use crate::{MsgHdr, RecvFlags, SockAddr, TcpKeepalive, Type}; @@ -927,6 +929,52 @@ impl crate::Socket { } } + /// Get the value for the `SO_ORIGINAL_DST` option on this socket. + /// + #[cfg(feature = "all")] + #[cfg_attr(docsrs, doc(cfg(all(windows, feature = "all"))))] + pub fn original_dst(&self) -> io::Result { + unsafe { + SockAddr::try_init(|storage, len| { + syscall!( + getsockopt( + self.as_raw(), + SOL_IP as i32, + SO_ORIGINAL_DST as i32, + storage.cast(), + len, + ), + PartialEq::eq, + SOCKET_ERROR + ) + }) + } + .map(|(_, addr)| addr) + } + + /// Get the value for the `IP6T_SO_ORIGINAL_DST` option on this socket. + /// + #[cfg(feature = "all")] + #[cfg_attr(docsrs, doc(cfg(all(windows, feature = "all"))))] + pub fn original_dst_ipv6(&self) -> io::Result { + unsafe { + SockAddr::try_init(|storage, len| { + syscall!( + getsockopt( + self.as_raw(), + SOL_IP as i32, + IP6T_SO_ORIGINAL_DST as i32, + storage.cast(), + len, + ), + PartialEq::eq, + SOCKET_ERROR + ) + }) + } + .map(|(_, addr)| addr) + } + /// Returns the [`Protocol`] of this socket by checking the `SO_PROTOCOL_INFOW` /// option on this socket. /// diff --git a/tests/socket.rs b/tests/socket.rs index 2300f0ed..69df1186 100644 --- a/tests/socket.rs +++ b/tests/socket.rs @@ -42,6 +42,9 @@ use std::num::NonZeroUsize; use std::os::unix::io::AsRawFd; #[cfg(windows)] use std::os::windows::io::AsRawSocket; +#[cfg(windows)] +use windows_sys::Win32::Networking::WinSock::WSAEINVAL; + #[cfg(unix)] use std::path::Path; use std::str; @@ -1607,7 +1610,27 @@ fn original_dst() { } #[test] -#[cfg(all(feature = "all", any(target_os = "android", target_os = "linux")))] +#[cfg(all(feature = "all", target_os = "windows"))] +fn original_dst() { + let socket = Socket::new(Domain::IPV6, Type::STREAM, None).unwrap(); + match socket.original_dst() { + Ok(_) => panic!("original_dst on non-redirected socket should fail"), + Err(err) => assert_eq!(err.raw_os_error(), Some(WSAEINVAL)), + } + + // Not supported on IPv6 socket. + let socket = Socket::new(Domain::IPV6, Type::STREAM, None).unwrap(); + match socket.original_dst_ipv6() { + Ok(_) => panic!("original_dst_ipv6 on non-redirected socket should fail"), + Err(err) => assert_eq!(err.raw_os_error(), Some(WSAEINVAL)), + } +} + +#[test] +#[cfg(all( + feature = "all", + any(target_os = "android", target_os = "fuchsia", target_os = "linux") +))] fn original_dst_ipv6() { let socket = Socket::new(Domain::IPV6, Type::STREAM, None).unwrap(); match socket.original_dst_ipv6() { @@ -1623,6 +1646,23 @@ fn original_dst_ipv6() { } } +#[test] +#[cfg(all(feature = "all", target_os = "windows"))] +fn original_dst_ipv6() { + let socket = Socket::new(Domain::IPV6, Type::STREAM, None).unwrap(); + match socket.original_dst_ipv6() { + Ok(_) => panic!("original_dst_ipv6 on non-redirected socket should fail"), + Err(err) => assert_eq!(err.raw_os_error(), Some(WSAEINVAL)), + } + + // Not supported on IPv4 socket. + let socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap(); + match socket.original_dst_ipv6() { + Ok(_) => panic!("original_dst_ipv6 on non-redirected socket should fail"), + Err(err) => assert_eq!(err.raw_os_error(), Some(WSAEINVAL)), + } +} + #[test] #[cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux")))] fn tcp_congestion() {