From 77f3f2046582211eac2d913da1b047e3b303d2d9 Mon Sep 17 00:00:00 2001 From: Colin Marc Date: Fri, 18 Oct 2024 17:01:12 +0200 Subject: [PATCH 1/2] Introduce a RawSocketAddr type It wraps SockAddrStorage, but includes a length, for when a sockaddr has to be stored for a longer lifetime. --- src/backend/libc/net/msghdr.rs | 19 +++++++++++ src/backend/libc/net/syscalls.rs | 31 +++++++++++++++-- src/backend/linux_raw/net/msghdr.rs | 18 ++++++++++ src/backend/linux_raw/net/syscalls.rs | 45 +++++++++++++++++++++++-- src/net/mod.rs | 2 +- src/net/send_recv/msg.rs | 37 +++++++++++++++++++++ src/net/socket_addr_any.rs | 48 +++++++++++++++++++++++++++ tests/net/addr.rs | 15 +++++++++ 8 files changed, 209 insertions(+), 6 deletions(-) diff --git a/src/backend/libc/net/msghdr.rs b/src/backend/libc/net/msghdr.rs index d212c65a6..7f351d879 100644 --- a/src/backend/libc/net/msghdr.rs +++ b/src/backend/libc/net/msghdr.rs @@ -150,6 +150,25 @@ pub(crate) fn with_xdp_msghdr( }) } +/// Create a message header with a pre-encoded address. +pub(crate) fn with_raw_msghdr( + addr: &RawSocketAddr, + iov: &[IoSlice<'_>], + control: &mut SendAncillaryBuffer<'_, '_, '_>, + f: impl FnOnce(c::msghdr) -> R, +) -> R { + f({ + let mut h = zero_msghdr(); + h.msg_name = addr.as_ptr() as _; + h.msg_namelen = addr.namelen() as _; + h.msg_iov = iov.as_ptr() as _; + h.msg_iovlen = msg_iov_len(iov.len()); + h.msg_control = control.as_control_ptr().cast(); + h.msg_controllen = msg_control_len(control.control_len()); + h + }) +} + /// Create a zero-initialized message header struct value. #[cfg(all(unix, not(target_os = "redox")))] pub(crate) fn zero_msghdr() -> c::msghdr { diff --git a/src/backend/libc/net/syscalls.rs b/src/backend/libc/net/syscalls.rs index 3013f9922..4e812a83f 100644 --- a/src/backend/libc/net/syscalls.rs +++ b/src/backend/libc/net/syscalls.rs @@ -12,7 +12,9 @@ use crate::fd::{BorrowedFd, OwnedFd}; use crate::io; #[cfg(target_os = "linux")] use crate::net::xdp::SocketAddrXdp; -use crate::net::{SocketAddrAny, SocketAddrV4, SocketAddrV6}; +#[cfg(target_os = "linux")] +use crate::net::MMsgHdr; +use crate::net::{RawSocketAddr, SocketAddrAny, SocketAddrV4, SocketAddrV6}; use crate::utils::as_ptr; use core::mem::{size_of, MaybeUninit}; #[cfg(not(any( @@ -23,7 +25,9 @@ use core::mem::{size_of, MaybeUninit}; target_os = "wasi" )))] use { - super::msghdr::{with_noaddr_msghdr, with_recv_msghdr, with_v4_msghdr, with_v6_msghdr}, + super::msghdr::{ + with_noaddr_msghdr, with_raw_msghdr, with_recv_msghdr, with_v4_msghdr, with_v6_msghdr, + }, crate::io::{IoSlice, IoSliceMut}, crate::net::{RecvAncillaryBuffer, RecvMsgReturn, SendAncillaryBuffer}, }; @@ -455,6 +459,29 @@ pub(crate) fn sendmsg_xdp( }) } +#[cfg(not(any( + windows, + target_os = "espidf", + target_os = "redox", + target_os = "vita", + target_os = "wasi" +)))] +pub(crate) fn sendmsg_raw( + sockfd: BorrowedFd<'_>, + addr: &RawSocketAddr, + iov: &[IoSlice<'_>], + control: &mut SendAncillaryBuffer<'_, '_, '_>, + msg_flags: SendFlags, +) -> io::Result { + with_raw_msghdr(addr, iov, control, |msghdr| unsafe { + ret_send_recv(c::sendmsg( + borrowed_fd(sockfd), + &msghdr, + bitflags_bits!(msg_flags), + )) + }) +} + #[cfg(not(any( apple, windows, diff --git a/src/backend/linux_raw/net/msghdr.rs b/src/backend/linux_raw/net/msghdr.rs index 3ccce04c9..4d4bb9b98 100644 --- a/src/backend/linux_raw/net/msghdr.rs +++ b/src/backend/linux_raw/net/msghdr.rs @@ -157,6 +157,24 @@ pub(crate) fn with_xdp_msghdr( }) } +/// Create a message header with a pre-encoded address. +pub(crate) fn with_raw_msghdr( + addr: &RawSocketAddr, + iov: &[IoSlice<'_>], + control: &mut SendAncillaryBuffer<'_, '_, '_>, + f: impl FnOnce(c::msghdr) -> R, +) -> R { + f(c::msghdr { + msg_name: addr.as_ptr() as _, + msg_namelen: addr.namelen() as _, + msg_iov: iov.as_ptr() as _, + msg_iovlen: msg_iov_len(iov.len()), + msg_control: control.as_control_ptr().cast(), + msg_controllen: msg_control_len(control.control_len()), + msg_flags: 0, + }) +} + /// Create a zero-initialized message header struct value. pub(crate) fn zero_msghdr() -> c::msghdr { c::msghdr { diff --git a/src/backend/linux_raw/net/syscalls.rs b/src/backend/linux_raw/net/syscalls.rs index 4d4427a40..f06456f1e 100644 --- a/src/backend/linux_raw/net/syscalls.rs +++ b/src/backend/linux_raw/net/syscalls.rs @@ -8,7 +8,8 @@ #[cfg(target_os = "linux")] use super::msghdr::with_xdp_msghdr; use super::msghdr::{ - with_noaddr_msghdr, with_recv_msghdr, with_unix_msghdr, with_v4_msghdr, with_v6_msghdr, + with_noaddr_msghdr, with_raw_msghdr, with_recv_msghdr, with_unix_msghdr, with_v4_msghdr, + with_v6_msghdr, }; use super::read_sockaddr::{initialize_family_to_unspec, maybe_read_sockaddr_os, read_sockaddr_os}; use super::send_recv::{RecvFlags, SendFlags}; @@ -25,8 +26,9 @@ use crate::io::{self, IoSlice, IoSliceMut}; #[cfg(target_os = "linux")] use crate::net::xdp::SocketAddrXdp; use crate::net::{ - AddressFamily, Protocol, RecvAncillaryBuffer, RecvMsgReturn, SendAncillaryBuffer, Shutdown, - SocketAddrAny, SocketAddrUnix, SocketAddrV4, SocketAddrV6, SocketFlags, SocketType, + AddressFamily, Protocol, RawSocketAddr, RecvAncillaryBuffer, RecvMsgReturn, + SendAncillaryBuffer, Shutdown, SocketAddrAny, SocketAddrUnix, SocketAddrV4, SocketAddrV6, + SocketFlags, SocketType, }; use c::{sockaddr, sockaddr_in, sockaddr_in6, socklen_t}; use core::mem::MaybeUninit; @@ -439,6 +441,43 @@ pub(crate) fn sendmsg_xdp( }) } +#[cfg(not(any( + windows, + target_os = "espidf", + target_os = "redox", + target_os = "vita", + target_os = "wasi" +)))] +#[inline] +pub(crate) fn sendmsg_raw( + sockfd: BorrowedFd<'_>, + addr: &RawSocketAddr, + iov: &[IoSlice<'_>], + control: &mut SendAncillaryBuffer<'_, '_, '_>, + msg_flags: SendFlags, +) -> io::Result { + with_raw_msghdr(addr, iov, control, |msghdr| { + #[cfg(not(target_arch = "x86"))] + let result = + unsafe { ret_usize(syscall!(__NR_sendmsg, sockfd, by_ref(&msghdr), msg_flags)) }; + + #[cfg(target_arch = "x86")] + let result = unsafe { + ret_usize(syscall!( + __NR_socketcall, + x86_sys(SYS_SENDMSG), + slice_just_addr::, _>(&[ + sockfd.into(), + by_ref(&msghdr), + msg_flags.into() + ]) + )) + }; + + result + }) +} + #[inline] pub(crate) fn shutdown(fd: BorrowedFd<'_>, how: Shutdown) -> io::Result<()> { #[cfg(not(target_arch = "x86"))] diff --git a/src/net/mod.rs b/src/net/mod.rs index 7ec8bc698..7fd5c20f1 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -25,7 +25,7 @@ pub use crate::maybe_polyfill::net::{ }; pub use send_recv::*; pub use socket::*; -pub use socket_addr_any::{SocketAddrAny, SocketAddrStorage}; +pub use socket_addr_any::{RawSocketAddr, SocketAddrAny, SocketAddrStorage}; #[cfg(not(any(windows, target_os = "wasi")))] pub use socketpair::socketpair; pub use types::*; diff --git a/src/net/send_recv/msg.rs b/src/net/send_recv/msg.rs index 794485d9f..6806ab237 100644 --- a/src/net/send_recv/msg.rs +++ b/src/net/send_recv/msg.rs @@ -5,6 +5,7 @@ use crate::backend::{self, c}; use crate::fd::{AsFd, BorrowedFd, OwnedFd}; use crate::io::{self, IoSlice, IoSliceMut}; +use crate::net::RawSocketAddr; #[cfg(linux_kernel)] use crate::net::UCred; @@ -781,6 +782,42 @@ pub fn sendmsg_any( } } +/// `sendmsg(msghdr)`—Sends a message on a socket to a specific address. +/// +/// # References +/// - [POSIX] +/// - [Linux] +/// - [Apple] +/// - [FreeBSD] +/// - [NetBSD] +/// - [OpenBSD] +/// - [DragonFly BSD] +/// - [illumos] +/// +/// [POSIX]: https://pubs.opengroup.org/onlinepubs/9799919799/functions/sendmsg.html +/// [Linux]: https://man7.org/linux/man-pages/man2/sendmsg.2.html +/// [Apple]: https://developer.apple.com/library/archive/documentation/System/Conceptual/ManPages_iPhoneOS/man2/sendmsg.2.html +/// [FreeBSD]: https://man.freebsd.org/cgi/man.cgi?query=sendmsg&sektion=2 +/// [NetBSD]: https://man.netbsd.org/sendmsg.2 +/// [OpenBSD]: https://man.openbsd.org/sendmsg.2 +/// [DragonFly BSD]: https://man.dragonflybsd.org/?command=sendmsg§ion=2 +/// [illumos]: https://illumos.org/man/3SOCKET/sendmsg +#[inline] +pub fn sendmsg_raw( + socket: impl AsFd, + addr: Option<&RawSocketAddr>, + iov: &[IoSlice<'_>], + control: &mut SendAncillaryBuffer<'_, '_, '_>, + flags: SendFlags, +) -> io::Result { + match addr { + None => backend::net::syscalls::sendmsg(socket.as_fd(), iov, control, flags), + Some(addr) => { + backend::net::syscalls::sendmsg_raw(socket.as_fd(), addr, iov, control, flags) + } + } +} + /// `recvmsg(msghdr)`—Receives a message from a socket. /// /// # References diff --git a/src/net/socket_addr_any.rs b/src/net/socket_addr_any.rs index b43d09667..1882b7d02 100644 --- a/src/net/socket_addr_any.rs +++ b/src/net/socket_addr_any.rs @@ -14,9 +14,11 @@ use crate::net::xdp::SocketAddrXdp; #[cfg(unix)] use crate::net::SocketAddrUnix; use crate::net::{AddressFamily, SocketAddr, SocketAddrV4, SocketAddrV6}; +use crate::utils::{as_mut_ptr, as_ptr}; use crate::{backend, io}; #[cfg(feature = "std")] use core::fmt; +use core::mem::zeroed; pub use backend::net::addr::SocketAddrStorage; @@ -83,6 +85,23 @@ impl SocketAddrAny { } } + /// Creates a platform-specific encoding of this socket address, + /// and returns it. + pub fn to_raw(&self) -> RawSocketAddr { + let mut raw = RawSocketAddr { + storage: unsafe { zeroed() }, + len: 0, + }; + + raw.len = unsafe { self.write(raw.as_mut_ptr()) }; + raw + } + + /// Reads a platform-specific encoding of a socket address. + pub fn from_raw(raw: RawSocketAddr) -> io::Result { + unsafe { Self::read(raw.as_ptr(), raw.len) } + } + /// Writes a platform-specific encoding of this socket address to /// the memory pointed to by `storage`, and returns the number of /// bytes used. @@ -107,6 +126,35 @@ impl SocketAddrAny { } } +/// A raw sockaddr and its length. +#[repr(C)] +pub struct RawSocketAddr { + pub(crate) storage: SocketAddrStorage, + pub(crate) len: usize, +} + +impl RawSocketAddr { + /// Creates a raw encoded sockaddr from the given address. + pub fn new(addr: impl Into) -> Self { + addr.into().to_raw() + } + + /// Returns a raw pointer to the sockaddr. + pub fn as_ptr(&self) -> *const SocketAddrStorage { + as_ptr(&self.storage) + } + + /// Returns a raw mutable pointer to the sockaddr. + pub fn as_mut_ptr(&mut self) -> *mut SocketAddrStorage { + as_mut_ptr(&mut self.storage) + } + + /// Returns the length of the encoded sockaddr. + pub fn namelen(&self) -> usize { + self.len + } +} + #[cfg(feature = "std")] impl fmt::Debug for SocketAddrAny { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/tests/net/addr.rs b/tests/net/addr.rs index 81b86d74b..40542c101 100644 --- a/tests/net/addr.rs +++ b/tests/net/addr.rs @@ -13,12 +13,22 @@ fn encode_decode() { let decoded = SocketAddrAny::read(encoded.as_ptr(), len).unwrap(); assert_eq!(decoded, SocketAddrAny::V4(orig)); + let orig = SocketAddrV4::new(Ipv4Addr::new(2, 3, 5, 6), 33); + let encoded = SocketAddrAny::V4(orig).to_raw(); + let decoded = SocketAddrAny::from_raw(encoded).unwrap(); + assert_eq!(decoded, SocketAddrAny::V4(orig)); + let orig = SocketAddrV6::new(Ipv6Addr::new(2, 3, 5, 6, 8, 9, 11, 12), 33, 34, 36); let mut encoded = std::mem::MaybeUninit::::uninit(); let len = SocketAddrAny::V6(orig).write(encoded.as_mut_ptr()); let decoded = SocketAddrAny::read(encoded.as_ptr(), len).unwrap(); assert_eq!(decoded, SocketAddrAny::V6(orig)); + let orig = SocketAddrV6::new(Ipv6Addr::new(2, 3, 5, 6, 8, 9, 11, 12), 33, 34, 36); + let encoded = SocketAddrAny::V6(orig).to_raw(); + let decoded = SocketAddrAny::from_raw(encoded).unwrap(); + assert_eq!(decoded, SocketAddrAny::V6(orig)); + #[cfg(not(windows))] { let orig = SocketAddrUnix::new("/path/to/socket").unwrap(); @@ -26,6 +36,11 @@ fn encode_decode() { let len = SocketAddrAny::Unix(orig.clone()).write(encoded.as_mut_ptr()); let decoded = SocketAddrAny::read(encoded.as_ptr(), len).unwrap(); assert_eq!(decoded, SocketAddrAny::Unix(orig)); + + let orig = SocketAddrUnix::new("/path/to/socket").unwrap(); + let encoded = SocketAddrAny::Unix(orig.clone()).to_raw(); + let decoded = SocketAddrAny::from_raw(encoded).unwrap(); + assert_eq!(decoded, SocketAddrAny::Unix(orig)); } } } From e960980a1e2e72f7e4cc321f7af8b881e107b873 Mon Sep 17 00:00:00 2001 From: Colin Marc Date: Wed, 18 Sep 2024 23:15:50 +0200 Subject: [PATCH 2/2] Add support for sendmmsg(2) on linux https://man7.org/linux/man-pages/man2/sendmmsg.2.html Partially addresses #1156. Signed-off-by: Colin Marc --- src/backend/libc/net/msghdr.rs | 4 +- src/backend/libc/net/syscalls.rs | 19 ++++++ src/backend/linux_raw/c.rs | 10 +-- src/backend/linux_raw/net/msghdr.rs | 3 +- src/backend/linux_raw/net/syscalls.rs | 32 ++++++++- src/net/send_recv/msg.rs | 60 +++++++++++++++++ tests/net/v4.rs | 88 +++++++++++++++++++++++++ tests/net/v6.rs | 94 +++++++++++++++++++++++++++ 8 files changed, 301 insertions(+), 9 deletions(-) diff --git a/src/backend/libc/net/msghdr.rs b/src/backend/libc/net/msghdr.rs index 7f351d879..4c8ef8931 100644 --- a/src/backend/libc/net/msghdr.rs +++ b/src/backend/libc/net/msghdr.rs @@ -12,7 +12,9 @@ use crate::backend::net::write_sockaddr::{encode_sockaddr_v4, encode_sockaddr_v6 use crate::io::{self, IoSlice, IoSliceMut}; #[cfg(target_os = "linux")] use crate::net::xdp::SocketAddrXdp; -use crate::net::{RecvAncillaryBuffer, SendAncillaryBuffer, SocketAddrV4, SocketAddrV6}; +use crate::net::{ + RawSocketAddr, RecvAncillaryBuffer, SendAncillaryBuffer, SocketAddrV4, SocketAddrV6, +}; use crate::utils::as_ptr; use core::mem::{size_of, zeroed, MaybeUninit}; diff --git a/src/backend/libc/net/syscalls.rs b/src/backend/libc/net/syscalls.rs index 4e812a83f..466846a29 100644 --- a/src/backend/libc/net/syscalls.rs +++ b/src/backend/libc/net/syscalls.rs @@ -7,6 +7,8 @@ use super::msghdr::with_xdp_msghdr; #[cfg(target_os = "linux")] use super::write_sockaddr::encode_sockaddr_xdp; use crate::backend::c; +#[cfg(target_os = "linux")] +use crate::backend::conv::ret_u32; use crate::backend::conv::{borrowed_fd, ret, ret_owned_fd, ret_send_recv, send_recv_len}; use crate::fd::{BorrowedFd, OwnedFd}; use crate::io; @@ -482,6 +484,23 @@ pub(crate) fn sendmsg_raw( }) } +#[cfg(target_os = "linux")] +pub(crate) fn sendmmsg( + sockfd: BorrowedFd<'_>, + msgs: &mut [MMsgHdr<'_>], + flags: SendFlags, +) -> io::Result { + unsafe { + ret_u32(c::sendmmsg( + borrowed_fd(sockfd), + msgs.as_mut_ptr() as _, + msgs.len().try_into().unwrap_or(c::c_uint::MAX), + bitflags_bits!(flags), + )) + .map(|ret| ret as usize) + } +} + #[cfg(not(any( apple, windows, diff --git a/src/backend/linux_raw/c.rs b/src/backend/linux_raw/c.rs index 4035bf945..f075b0abe 100644 --- a/src/backend/linux_raw/c.rs +++ b/src/backend/linux_raw/c.rs @@ -56,12 +56,12 @@ pub(crate) use linux_raw_sys::{ general::{O_CLOEXEC as SOCK_CLOEXEC, O_NONBLOCK as SOCK_NONBLOCK}, if_ether::*, net::{ - linger, msghdr, sockaddr, sockaddr_in, sockaddr_in6, sockaddr_un, socklen_t, AF_DECnet, __kernel_sa_family_t as sa_family_t, __kernel_sockaddr_storage as sockaddr_storage, - cmsghdr, in6_addr, in_addr, ip_mreq, ip_mreq_source, ip_mreqn, ipv6_mreq, AF_APPLETALK, - AF_ASH, AF_ATMPVC, AF_ATMSVC, AF_AX25, AF_BLUETOOTH, AF_BRIDGE, AF_CAN, AF_ECONET, - AF_IEEE802154, AF_INET, AF_INET6, AF_IPX, AF_IRDA, AF_ISDN, AF_IUCV, AF_KEY, AF_LLC, - AF_NETBEUI, AF_NETLINK, AF_NETROM, AF_PACKET, AF_PHONET, AF_PPPOX, AF_RDS, AF_ROSE, + cmsghdr, in6_addr, in_addr, ip_mreq, ip_mreq_source, ip_mreqn, ipv6_mreq, linger, mmsghdr, + msghdr, sockaddr, sockaddr_in, sockaddr_in6, sockaddr_un, socklen_t, AF_DECnet, + AF_APPLETALK, AF_ASH, AF_ATMPVC, AF_ATMSVC, AF_AX25, AF_BLUETOOTH, AF_BRIDGE, AF_CAN, + AF_ECONET, AF_IEEE802154, AF_INET, AF_INET6, AF_IPX, AF_IRDA, AF_ISDN, AF_IUCV, AF_KEY, + AF_LLC, AF_NETBEUI, AF_NETLINK, AF_NETROM, AF_PACKET, AF_PHONET, AF_PPPOX, AF_RDS, AF_ROSE, AF_RXRPC, AF_SECURITY, AF_SNA, AF_TIPC, AF_UNIX, AF_UNSPEC, AF_WANPIPE, AF_X25, AF_XDP, IP6T_SO_ORIGINAL_DST, IPPROTO_FRAGMENT, IPPROTO_ICMPV6, IPPROTO_MH, IPPROTO_ROUTING, IPV6_ADD_MEMBERSHIP, IPV6_DROP_MEMBERSHIP, IPV6_FREEBIND, IPV6_MULTICAST_HOPS, diff --git a/src/backend/linux_raw/net/msghdr.rs b/src/backend/linux_raw/net/msghdr.rs index 4d4bb9b98..aa0c3efa0 100644 --- a/src/backend/linux_raw/net/msghdr.rs +++ b/src/backend/linux_raw/net/msghdr.rs @@ -9,10 +9,11 @@ use crate::backend::c; #[cfg(target_os = "linux")] use crate::backend::net::write_sockaddr::encode_sockaddr_xdp; use crate::backend::net::write_sockaddr::{encode_sockaddr_v4, encode_sockaddr_v6}; - use crate::io::{self, IoSlice, IoSliceMut}; #[cfg(target_os = "linux")] use crate::net::xdp::SocketAddrXdp; +#[cfg(target_os = "linux")] +use crate::net::RawSocketAddr; use crate::net::{RecvAncillaryBuffer, SendAncillaryBuffer, SocketAddrV4, SocketAddrV6}; use crate::utils::as_ptr; diff --git a/src/backend/linux_raw/net/syscalls.rs b/src/backend/linux_raw/net/syscalls.rs index f06456f1e..d7dc030c6 100644 --- a/src/backend/linux_raw/net/syscalls.rs +++ b/src/backend/linux_raw/net/syscalls.rs @@ -17,6 +17,8 @@ use super::send_recv::{RecvFlags, SendFlags}; use super::write_sockaddr::encode_sockaddr_xdp; use super::write_sockaddr::{encode_sockaddr_v4, encode_sockaddr_v6}; use crate::backend::c; +#[cfg(target_os = "linux")] +use crate::backend::conv::slice_mut; use crate::backend::conv::{ by_mut, by_ref, c_int, c_uint, pass_usize, ret, ret_owned_fd, ret_usize, size_of, slice, socklen_t, zero, @@ -25,6 +27,8 @@ use crate::fd::{BorrowedFd, OwnedFd}; use crate::io::{self, IoSlice, IoSliceMut}; #[cfg(target_os = "linux")] use crate::net::xdp::SocketAddrXdp; +#[cfg(target_os = "linux")] +use crate::net::MMsgHdr; use crate::net::{ AddressFamily, Protocol, RawSocketAddr, RecvAncillaryBuffer, RecvMsgReturn, SendAncillaryBuffer, Shutdown, SocketAddrAny, SocketAddrUnix, SocketAddrV4, SocketAddrV6, @@ -38,8 +42,8 @@ use { crate::backend::reg::{ArgReg, SocketArg}, linux_raw_sys::net::{ SYS_ACCEPT, SYS_ACCEPT4, SYS_BIND, SYS_CONNECT, SYS_GETPEERNAME, SYS_GETSOCKNAME, - SYS_LISTEN, SYS_RECV, SYS_RECVFROM, SYS_RECVMSG, SYS_SEND, SYS_SENDMSG, SYS_SENDTO, - SYS_SHUTDOWN, SYS_SOCKET, SYS_SOCKETPAIR, + SYS_LISTEN, SYS_RECV, SYS_RECVFROM, SYS_RECVMSG, SYS_SEND, SYS_SENDMMSG, SYS_SENDMSG, + SYS_SENDTO, SYS_SHUTDOWN, SYS_SOCKET, SYS_SOCKETPAIR, }, }; @@ -478,6 +482,30 @@ pub(crate) fn sendmsg_raw( }) } +#[cfg(target_os = "linux")] +#[inline] +pub(crate) fn sendmmsg( + sockfd: BorrowedFd<'_>, + msgs: &mut [MMsgHdr<'_>], + flags: SendFlags, +) -> io::Result { + let (msgs, len) = slice_mut(msgs); + + #[cfg(not(target_arch = "x86"))] + let result = unsafe { ret_usize(syscall!(__NR_sendmmsg, sockfd, msgs, len, flags)) }; + + #[cfg(target_arch = "x86")] + let result = unsafe { + ret_usize(syscall!( + __NR_socketcall, + x86_sys(SYS_SENDMMSG), + slice_just_addr::, _>(&[sockfd.into(), msgs, len, flags.into()]) + )) + }; + + result +} + #[inline] pub(crate) fn shutdown(fd: BorrowedFd<'_>, how: Shutdown) -> io::Result<()> { #[cfg(not(target_arch = "x86"))] diff --git a/src/net/send_recv/msg.rs b/src/net/send_recv/msg.rs index 6806ab237..0e32442c4 100644 --- a/src/net/send_recv/msg.rs +++ b/src/net/send_recv/msg.rs @@ -2,6 +2,8 @@ #![allow(unsafe_code)] +#[cfg(target_os = "linux")] +use crate::backend::net::msghdr::{with_noaddr_msghdr, with_raw_msghdr}; use crate::backend::{self, c}; use crate::fd::{AsFd, BorrowedFd, OwnedFd}; use crate::io::{self, IoSlice, IoSliceMut}; @@ -592,6 +594,48 @@ impl<'buf> Iterator for AncillaryDrain<'buf> { impl FusedIterator for AncillaryDrain<'_> {} +/// An ABI-compatible wrapper for `mmsghdr`, for sending multiple messages with +/// [sendmmsg]. +#[cfg(target_os = "linux")] +#[repr(transparent)] +pub struct MMsgHdr<'a> { + raw: c::mmsghdr, + _phantom: PhantomData<&'a mut ()>, +} + +#[cfg(target_os = "linux")] +impl<'a> MMsgHdr<'a> { + /// Constructs a new message with no destination address. + pub fn new(iov: &'a [IoSlice<'_>], control: &'a mut SendAncillaryBuffer<'_, '_, '_>) -> Self { + with_noaddr_msghdr(iov, control, Self::wrap) + } + + /// Constructs a new message to a specific address. + pub fn new_with_addr( + addr: &'a RawSocketAddr, + iov: &'a [IoSlice<'_>], + control: &'a mut SendAncillaryBuffer<'_, '_, '_>, + ) -> MMsgHdr<'a> { + with_raw_msghdr(addr, iov, control, Self::wrap) + } + + fn wrap(msg_hdr: c::msghdr) -> Self { + Self { + raw: c::mmsghdr { + msg_hdr, + msg_len: 0, + }, + _phantom: PhantomData, + } + } + + /// Returns the number of bytes sent. This will return 0 until after a + /// successful call to [sendmmsg]. + pub fn bytes_sent(&self) -> usize { + self.raw.msg_len as _ + } +} + /// `sendmsg(msghdr)`—Sends a message on a socket. /// /// # References @@ -818,6 +862,22 @@ pub fn sendmsg_raw( } } +/// `sendmmsg(msghdr)`—Sends multiple messages on a socket. +/// +/// # References +/// - [Linux] +/// +/// [Linux]: https://man7.org/linux/man-pages/man2/sendmmsg.2.html +#[inline] +#[cfg(target_os = "linux")] +pub fn sendmmsg( + socket: impl AsFd, + msgs: &mut [MMsgHdr<'_>], + flags: SendFlags, +) -> io::Result { + backend::net::syscalls::sendmmsg(socket.as_fd(), msgs, flags) +} + /// `recvmsg(msghdr)`—Receives a message from a socket. /// /// # References diff --git a/tests/net/v4.rs b/tests/net/v4.rs index d770b657b..49cadef6c 100644 --- a/tests/net/v4.rs +++ b/tests/net/v4.rs @@ -194,3 +194,91 @@ fn test_v4_msg() { client.join().unwrap(); server.join().unwrap(); } + +#[test] +#[cfg(target_os = "linux")] +fn test_v4_sendmmsg() { + crate::init(); + + use std::net::TcpStream; + + use rustix::io::IoSlice; + use rustix::net::{sendmmsg, MMsgHdr}; + + fn server(ready: Arc<(Mutex, Condvar)>) { + let connection_socket = socket(AddressFamily::INET, SocketType::STREAM, None).unwrap(); + + let name = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 0); + bind_v4(&connection_socket, &name).unwrap(); + + let who = match getsockname(&connection_socket).unwrap() { + SocketAddrAny::V4(addr) => addr, + _ => panic!(), + }; + + listen(&connection_socket, 1).unwrap(); + + { + let (lock, cvar) = &*ready; + let mut port = lock.lock().unwrap(); + *port = who.port(); + cvar.notify_all(); + } + + let mut buffer = vec![0; 13]; + let mut data_socket: TcpStream = accept(&connection_socket).unwrap().into(); + + std::io::Read::read_exact(&mut data_socket, &mut buffer).unwrap(); + assert_eq!(String::from_utf8_lossy(&buffer), "hello...world"); + } + + fn client(ready: Arc<(Mutex, Condvar)>) { + let port = { + let (lock, cvar) = &*ready; + let mut port = lock.lock().unwrap(); + while *port == 0 { + port = cvar.wait(port).unwrap(); + } + *port + }; + + let addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port); + let data_socket = socket(AddressFamily::INET, SocketType::STREAM, None).unwrap(); + connect_v4(&data_socket, &addr).unwrap(); + + let mut off = 0; + while off < 2 { + let sent = sendmmsg( + &data_socket, + &mut [ + MMsgHdr::new(&[IoSlice::new(b"hello")], &mut Default::default()), + MMsgHdr::new(&[IoSlice::new(b"...world")], &mut Default::default()), + ][off..], + SendFlags::empty(), + ) + .unwrap(); + + off += sent; + } + } + + let ready = Arc::new((Mutex::new(0_u16), Condvar::new())); + let ready_clone = Arc::clone(&ready); + + let server = thread::Builder::new() + .name("server".to_string()) + .spawn(move || { + server(ready); + }) + .unwrap(); + + let client = thread::Builder::new() + .name("client".to_string()) + .spawn(move || { + client(ready_clone); + }) + .unwrap(); + + client.join().unwrap(); + server.join().unwrap(); +} diff --git a/tests/net/v6.rs b/tests/net/v6.rs index 0d0a596c9..a95402a43 100644 --- a/tests/net/v6.rs +++ b/tests/net/v6.rs @@ -193,3 +193,97 @@ fn test_v6_msg() { client.join().unwrap(); server.join().unwrap(); } + +#[test] +#[cfg(target_os = "linux")] +fn test_v6_sendmmsg() { + crate::init(); + + use std::net::TcpStream; + + use rustix::io::IoSlice; + use rustix::net::{sendmmsg, MMsgHdr, RawSocketAddr}; + + fn server(ready: Arc<(Mutex, Condvar)>) { + let connection_socket = socket(AddressFamily::INET6, SocketType::STREAM, None).unwrap(); + + let name = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 0, 0, 0); + bind_v6(&connection_socket, &name).unwrap(); + + let who = match getsockname(&connection_socket).unwrap() { + SocketAddrAny::V6(addr) => addr, + _ => panic!(), + }; + + listen(&connection_socket, 1).unwrap(); + + { + let (lock, cvar) = &*ready; + let mut port = lock.lock().unwrap(); + *port = who.port(); + cvar.notify_all(); + } + + let mut buffer = vec![0; 13]; + let mut data_socket: TcpStream = accept(&connection_socket).unwrap().into(); + + std::io::Read::read_exact(&mut data_socket, &mut buffer).unwrap(); + assert_eq!(String::from_utf8_lossy(&buffer), "hello...world"); + } + + fn client(ready: Arc<(Mutex, Condvar)>) { + let port = { + let (lock, cvar) = &*ready; + let mut port = lock.lock().unwrap(); + while *port == 0 { + port = cvar.wait(port).unwrap(); + } + *port + }; + + let addr = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), port, 0, 0); + let data_socket = socket(AddressFamily::INET6, SocketType::STREAM, None).unwrap(); + connect_v6(&data_socket, &addr).unwrap(); + + let raw_addr = RawSocketAddr::new(addr); + + let mut off = 0; + while off < 2 { + let sent = sendmmsg( + &data_socket, + &mut [ + MMsgHdr::new_with_addr( + &raw_addr, + &[IoSlice::new(b"hello")], + &mut Default::default(), + ), + MMsgHdr::new(&[IoSlice::new(b"...world")], &mut Default::default()), + ][off..], + SendFlags::empty(), + ) + .unwrap(); + + off += sent; + } + } + + let ready = Arc::new((Mutex::new(0_u16), Condvar::new())); + let ready_clone = Arc::clone(&ready); + + let server = thread::Builder::new() + .name("server".to_string()) + .spawn(move || { + server(ready); + }) + .unwrap(); + + let client = thread::Builder::new() + .name("client".to_string()) + .spawn(move || { + client(ready_clone); + }) + .unwrap(); + + client.join().unwrap(); + server.join().unwrap(); +}