From 9549c832792f375b6fd06f37641b318b7f95f75e Mon Sep 17 00:00:00 2001 From: Dan Gohman Date: Thu, 28 Sep 2023 23:10:58 -0700 Subject: [PATCH] Fixes. --- src/backend/libc/net/sockopt.rs | 17 +++++-- src/backend/linux_raw/net/sockopt.rs | 16 ++++--- src/net/sockopt.rs | 10 +++-- tests/net/sockopt.rs | 66 ++++++++++++++++++---------- 4 files changed, 75 insertions(+), 34 deletions(-) diff --git a/src/backend/libc/net/sockopt.rs b/src/backend/libc/net/sockopt.rs index de884eb6d..d960790dc 100644 --- a/src/backend/libc/net/sockopt.rs +++ b/src/backend/libc/net/sockopt.rs @@ -4,6 +4,7 @@ use super::ext::{in6_addr_new, in_addr_new}; use crate::backend::c; use crate::backend::conv::{borrowed_fd, ret}; use crate::fd::BorrowedFd; +#[cfg(feature = "alloc")] #[cfg(any(linux_like, solarish, target_os = "freebsd", target_os = "fuchsia"))] use crate::ffi::CStr; use crate::io; @@ -43,6 +44,12 @@ use crate::net::{Ipv4Addr, Ipv6Addr, SocketType}; #[cfg(any(linux_kernel, target_os = "fuchsia"))] use crate::net::{SocketAddrAny, SocketAddrStorage, SocketAddrV4}; use crate::utils::as_mut_ptr; +#[cfg(feature = "alloc")] +#[cfg(any(linux_like, solarish, target_os = "freebsd", target_os = "fuchsia"))] +use alloc::borrow::ToOwned; +#[cfg(feature = "alloc")] +#[cfg(any(linux_like, solarish, target_os = "freebsd", target_os = "fuchsia"))] +use alloc::string::String; #[cfg(apple)] use c::TCP_KEEPALIVE as TCP_KEEPIDLE; #[cfg(not(any(apple, target_os = "openbsd", target_os = "haiku", target_os = "nto")))] @@ -411,7 +418,7 @@ pub(crate) fn get_socket_protocol(fd: BorrowedFd<'_>) -> io::Result) -> io::Result { getsockopt(fd, c::SOL_SOCKET, c::SO_COOKIE) @@ -770,13 +777,13 @@ pub(crate) fn get_ipv6_original_dst(fd: BorrowedFd<'_>) -> io::Result, value: u32) -> io::Result<()> { setsockopt(fd, c::IPPROTO_IPV6, c::IPV6_TCLASS, value) } -#[cfg(not(solarish))] +#[cfg(not(any(solarish, target_os = "haiku")))] #[inline] pub(crate) fn get_ipv6_tclass(fd: BorrowedFd<'_>) -> io::Result { getsockopt(fd, c::IPPROTO_IPV6, c::IPV6_TCLASS) @@ -865,6 +872,7 @@ pub(crate) fn set_tcp_congestion(fd: BorrowedFd<'_>, value: &str) -> io::Result< setsockopt_raw(fd, level, optname, value.as_ptr(), optlen) } +#[cfg(feature = "alloc")] #[cfg(any(linux_like, solarish, target_os = "freebsd", target_os = "fuchsia"))] #[inline] pub(crate) fn get_tcp_congestion(fd: BorrowedFd<'_>) -> io::Result { @@ -877,8 +885,9 @@ pub(crate) fn get_tcp_congestion(fd: BorrowedFd<'_>) -> io::Result { unsafe { let value = value.assume_init(); let slice: &[u8] = core::mem::transmute(&value[..optlen as usize]); + assert!(slice.iter().any(|b| *b == b'\0')); Ok( - core::str::from_utf8(CStr::from_bytes_until_nul(slice).unwrap().to_bytes()) + core::str::from_utf8(CStr::from_ptr(slice.as_ptr().cast()).to_bytes()) .unwrap() .to_owned(), ) diff --git a/src/backend/linux_raw/net/sockopt.rs b/src/backend/linux_raw/net/sockopt.rs index 6eb57fd75..00aa83faf 100644 --- a/src/backend/linux_raw/net/sockopt.rs +++ b/src/backend/linux_raw/net/sockopt.rs @@ -8,6 +8,7 @@ use crate::backend::c; use crate::backend::conv::{by_mut, c_uint, ret, socklen_t}; use crate::fd::BorrowedFd; +#[cfg(feature = "alloc")] use crate::ffi::CStr; use crate::io; use crate::net::sockopt::Timeout; @@ -15,6 +16,10 @@ use crate::net::{ AddressFamily, Ipv4Addr, Ipv6Addr, Protocol, RawProtocol, SocketAddrAny, SocketAddrStorage, SocketAddrV4, SocketAddrV6, SocketType, }; +#[cfg(feature = "alloc")] +use alloc::borrow::ToOwned; +#[cfg(feature = "alloc")] +use alloc::string::String; use core::mem::MaybeUninit; use core::time::Duration; use linux_raw_sys::general::{__kernel_old_timeval, __kernel_sock_timeval}; @@ -66,7 +71,6 @@ fn getsockopt_raw( } #[cfg(target_arch = "x86")] unsafe { - let mut value = MaybeUninit::::uninit(); ret(syscall!( __NR_socketcall, x86_sys(SYS_GETSOCKOPT), @@ -74,8 +78,8 @@ fn getsockopt_raw( fd.into(), c_uint(level), c_uint(optname), - (&mut value).into(), - by_mut(&mut optlen), + value.into(), + by_mut(optlen), ]) )) } @@ -119,7 +123,7 @@ fn setsockopt_raw( fd.into(), c_uint(level), c_uint(optname), - ptr, + ptr.into(), socklen_t(optlen), ]) )) @@ -749,6 +753,7 @@ pub(crate) fn set_tcp_congestion(fd: BorrowedFd<'_>, value: &str) -> io::Result< setsockopt_raw(fd, level, optname, value.as_ptr(), optlen) } +#[cfg(feature = "alloc")] #[inline] pub(crate) fn get_tcp_congestion(fd: BorrowedFd<'_>) -> io::Result { let level = c::IPPROTO_TCP; @@ -760,8 +765,9 @@ pub(crate) fn get_tcp_congestion(fd: BorrowedFd<'_>) -> io::Result { unsafe { let value = value.assume_init(); let slice: &[u8] = core::mem::transmute(&value[..optlen as usize]); + assert!(slice.iter().any(|b| *b == b'\0')); Ok( - core::str::from_utf8(CStr::from_bytes_until_nul(slice).unwrap().to_bytes()) + core::str::from_utf8(CStr::from_ptr(slice.as_ptr().cast()).to_bytes()) .unwrap() .to_owned(), ) diff --git a/src/net/sockopt.rs b/src/net/sockopt.rs index f13729452..58866196b 100644 --- a/src/net/sockopt.rs +++ b/src/net/sockopt.rs @@ -170,6 +170,9 @@ use crate::net::SocketAddrV4; use crate::net::SocketAddrV6; use crate::net::{Ipv4Addr, Ipv6Addr, SocketType}; use crate::{backend, io}; +#[cfg(feature = "alloc")] +#[cfg(any(linux_like, solarish, target_os = "freebsd", target_os = "fuchsia"))] +use alloc::string::String; use backend::c; use backend::fd::AsFd; use core::time::Duration; @@ -547,7 +550,7 @@ pub fn get_socket_protocol(fd: Fd) -> io::Result> { /// See the [module-level documentation] for more. /// /// [module-level documentation]: self#references-for-get_socket_-and-set_socket_-functions -#[cfg(linux_like)] +#[cfg(target_os = "linux")] #[inline] #[doc(alias = "SO_COOKIE")] pub fn get_socket_cookie(fd: Fd) -> io::Result { @@ -1089,7 +1092,7 @@ pub fn get_ipv6_original_dst(fd: Fd) -> io::Result { /// See the [module-level documentation] for more. /// /// [module-level documentation]: self#references-for-get_ipv6_-and-set_ipv6_-functions -#[cfg(not(solarish))] +#[cfg(not(any(solarish, target_os = "haiku")))] #[inline] #[doc(alias = "IPV6_TCLASS")] pub fn set_ipv6_tclass(fd: Fd, value: u32) -> io::Result<()> { @@ -1101,7 +1104,7 @@ pub fn set_ipv6_tclass(fd: Fd, value: u32) -> io::Result<()> { /// See the [module-level documentation] for more. /// /// [module-level documentation]: self#references-for-get_ipv6_-and-set_ipv6_-functions -#[cfg(not(solarish))] +#[cfg(not(any(solarish, target_os = "haiku")))] #[inline] #[doc(alias = "IPV6_TCLASS")] pub fn get_ipv6_tclass(fd: Fd) -> io::Result { @@ -1271,6 +1274,7 @@ pub fn set_tcp_congestion(fd: Fd, value: &str) -> io::Result<()> { /// See the [module-level documentation] for more. /// /// [module-level documentation]: self#references-for-get_tcp_-and-set_tcp_-functions +#[cfg(feature = "alloc")] #[cfg(any(linux_like, solarish, target_os = "freebsd", target_os = "fuchsia"))] #[inline] #[doc(alias = "TCP_CONGESTION")] diff --git a/tests/net/sockopt.rs b/tests/net/sockopt.rs index 5a71d84ae..95e55bf70 100644 --- a/tests/net/sockopt.rs +++ b/tests/net/sockopt.rs @@ -1,7 +1,15 @@ use rustix::fd::OwnedFd; use rustix::io; -use rustix::net::sockopt; -use rustix::net::{ipproto, AddressFamily, SocketType}; +#[cfg(any( + linux_like, + target_os = "freebsd", + target_os = "fuchsia", + target_os = "openbsd", + target_os = "redox", + target_env = "newlib" +))] +use rustix::net::ipproto; +use rustix::net::{sockopt, AddressFamily, SocketType}; use std::time::Duration; // Test `socket` socket options. @@ -11,10 +19,20 @@ fn test_sockopts_socket(s: &OwnedFd) { .unwrap() .is_none()); assert_eq!(sockopt::get_socket_type(&s).unwrap(), SocketType::STREAM); - assert_eq!( - sockopt::get_socket_protocol(&s).unwrap(), - Some(ipproto::TCP) - ); + #[cfg(any( + linux_like, + target_os = "freebsd", + target_os = "fuchsia", + target_os = "openbsd", + target_os = "redox", + target_env = "newlib" + ))] + { + assert_eq!( + sockopt::get_socket_protocol(&s).unwrap(), + Some(ipproto::TCP) + ); + } assert!(!sockopt::get_socket_reuseaddr(&s).unwrap()); #[cfg(not(windows))] assert!(!sockopt::get_socket_broadcast(&s).unwrap()); @@ -137,17 +155,20 @@ fn test_sockopts_socket(s: &OwnedFd) { // Check the initial value of SO_REUSEPORT_LB, set it, and check it. #[cfg(target_os = "freebsd")] { - assert_eq!(!sockopt::get_socket_reuseport_lb(&s).unwrap()); + assert!(!sockopt::get_socket_reuseport_lb(&s).unwrap()); sockopt::set_socket_reuseport_lb(&s, true).unwrap(); - assert_eq!(sockopt::get_socket_reuseport_lb(&s).unwrap()); + assert!(sockopt::get_socket_reuseport_lb(&s).unwrap()); } // Not much we can check with `get_socket_cookie`, but make sure we can // call it and that it returns the same value if called twice. - assert_eq!( - sockopt::get_socket_cookie(&s).unwrap(), - sockopt::get_socket_cookie(&s).unwrap() - ); + #[cfg(target_os = "linux")] + { + assert_eq!( + sockopt::get_socket_cookie(&s).unwrap(), + sockopt::get_socket_cookie(&s).unwrap() + ); + } // Check the initial value of SO_INCOMING_CPU, set it, and check it. #[cfg(target_os = "linux")] @@ -211,6 +232,7 @@ fn test_sockopts_tcp(s: &OwnedFd) { // Check the initial value of TCP_CONGESTION, set it, and check it. #[cfg(any(linux_like, solarish, target_os = "freebsd", target_os = "fuchsia"))] + #[cfg(feature = "alloc")] { assert_eq!(sockopt::get_tcp_congestion(&s).unwrap(), "cubic"); sockopt::set_tcp_congestion(&s, "reno").unwrap(); @@ -313,7 +335,7 @@ fn test_sockopts_ipv4() { // Check that we can query SO_ORIGINAL_DST. #[cfg(any(linux_kernel, target_os = "fuchsia"))] { - assert_eq!(sockopt::get_ip_original_dst(&s), Err(io::Errno::NOPROTOOPT)); + assert_eq!(sockopt::get_ip_original_dst(&s), Err(io::Errno::NOENT)); } test_sockopts_tcp(&s); @@ -343,9 +365,9 @@ fn test_sockopts_ipv6() { assert_ne!(sockopt::get_ipv6_unicast_hops(&s).unwrap(), 0); match sockopt::get_ipv6_multicast_loop(&s) { Ok(multicast_loop) => assert!(multicast_loop), - Err(rustix::io::Errno::OPNOTSUPP) => (), - Err(rustix::io::Errno::INVAL) => (), - Err(rustix::io::Errno::NOPROTOOPT) => (), + Err(io::Errno::OPNOTSUPP) => (), + Err(io::Errno::INVAL) => (), + Err(io::Errno::NOPROTOOPT) => (), Err(err) => Err(err).unwrap(), } assert_ne!(sockopt::get_ipv6_unicast_hops(&s).unwrap(), 0); @@ -355,8 +377,8 @@ fn test_sockopts_ipv6() { #[cfg(not(target_os = "netbsd"))] match sockopt::get_ipv6_multicast_hops(&s) { Ok(hops) => assert_eq!(hops, 0), - Err(rustix::io::Errno::NOPROTOOPT) => (), - Err(rustix::io::Errno::INVAL) => (), + Err(io::Errno::NOPROTOOPT) => (), + Err(io::Errno::INVAL) => (), Err(err) => Err(err).unwrap(), }; @@ -376,9 +398,9 @@ fn test_sockopts_ipv6() { Err(err) => Err(err).unwrap(), } } - Err(rustix::io::Errno::OPNOTSUPP) => (), - Err(rustix::io::Errno::INVAL) => (), - Err(rustix::io::Errno::NOPROTOOPT) => (), + Err(io::Errno::OPNOTSUPP) => (), + Err(io::Errno::INVAL) => (), + Err(io::Errno::NOPROTOOPT) => (), Err(err) => Err(err).unwrap(), } @@ -417,7 +439,7 @@ fn test_sockopts_ipv6() { } // Check the initial value of IPV6_TCLASS, set it, and check it. - #[cfg(not(solarish))] + #[cfg(not(any(solarish, target_os = "haiku")))] { assert_eq!(sockopt::get_ipv6_tclass(&s).unwrap(), 0); sockopt::set_ipv6_tclass(&s, 12).unwrap();