Skip to content

Commit

Permalink
adjust sendmmsg wrapper to use stack allocated libc types (#4190)
Browse files Browse the repository at this point in the history
* adjust sendmmsg wrapper to use stack allocated libc types

* use constants for sockaddr padding

* chunk packets to handle packets.len() > MAX_IOV

* target_os = "linux" for MAX_IOV

* assume_init_drop for initialized hdrs, iovs, addrs
  • Loading branch information
cpubot authored Dec 30, 2024
1 parent 937af8b commit 7f2bbd3
Showing 1 changed file with 60 additions and 15 deletions.
75 changes: 60 additions & 15 deletions streamer/src/sendmmsg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ fn mmsghdr_for_packet(
) {
const SIZE_OF_SOCKADDR_IN: usize = mem::size_of::<sockaddr_in>();
const SIZE_OF_SOCKADDR_IN6: usize = mem::size_of::<sockaddr_in6>();
const SIZE_OF_SOCKADDR_STORAGE: usize = mem::size_of::<sockaddr_storage>();
const SOCKADDR_IN_PADDING: usize = SIZE_OF_SOCKADDR_STORAGE - SIZE_OF_SOCKADDR_IN;
const SOCKADDR_IN6_PADDING: usize = SIZE_OF_SOCKADDR_STORAGE - SIZE_OF_SOCKADDR_IN6;

iov.write(iovec {
iov_base: packet.as_ptr() as *mut libc::c_void,
Expand All @@ -76,20 +79,34 @@ fn mmsghdr_for_packet(

let msg_namelen = match dest {
SocketAddr::V4(socket_addr_v4) => {
let ptr: *mut sockaddr_in = addr.as_mut_ptr() as *mut _;
unsafe {
ptr::write(
addr.as_mut_ptr() as *mut _,
ptr,
*nix::sys::socket::SockaddrIn::from(*socket_addr_v4).as_ref(),
);
// Zero the remaining bytes after sockaddr_in
ptr::write_bytes(
(ptr as *mut u8).add(SIZE_OF_SOCKADDR_IN),
0,
SOCKADDR_IN_PADDING,
);
}
SIZE_OF_SOCKADDR_IN as socklen_t
}
SocketAddr::V6(socket_addr_v6) => {
let ptr: *mut sockaddr_in6 = addr.as_mut_ptr() as *mut _;
unsafe {
ptr::write(
addr.as_mut_ptr() as *mut _,
ptr,
*nix::sys::socket::SockaddrIn6::from(*socket_addr_v6).as_ref(),
);
// Zero the remaining bytes after sockaddr_in6
ptr::write_bytes(
(ptr as *mut u8).add(SIZE_OF_SOCKADDR_IN6),
0,
SOCKADDR_IN6_PADDING,
);
}
SIZE_OF_SOCKADDR_IN6 as socklen_t
}
Expand Down Expand Up @@ -161,27 +178,55 @@ fn sendmmsg_retry(sock: &UdpSocket, hdrs: &mut [mmsghdr]) -> Result<(), SendPkts
}

#[cfg(target_os = "linux")]
pub fn batch_send<S, T>(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError>
const MAX_IOV: usize = libc::UIO_MAXIOV as usize;

#[cfg(target_os = "linux")]
pub fn batch_send_max_iov<S, T>(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError>
where
S: Borrow<SocketAddr>,
T: AsRef<[u8]>,
{
let size = packets.len();
let mut iovs = vec![MaybeUninit::uninit(); size];
let mut addrs = vec![MaybeUninit::zeroed(); size];
let mut hdrs = vec![MaybeUninit::uninit(); size];
assert!(packets.len() <= MAX_IOV);

let mut iovs = [MaybeUninit::uninit(); MAX_IOV];
let mut addrs = [MaybeUninit::uninit(); MAX_IOV];
let mut hdrs = [MaybeUninit::uninit(); MAX_IOV];

// izip! will iterate packets.len() times, leaving hdrs, iovs, and addrs initialized only up to packets.len()
for ((pkt, dest), hdr, iov, addr) in izip!(packets, &mut hdrs, &mut iovs, &mut addrs) {
mmsghdr_for_packet(pkt.as_ref(), dest.borrow(), iov, addr, hdr);
}
// mmsghdr_for_packet() performs initialization so we can safely transmute
// the Vecs to their initialized counterparts
let _iovs = unsafe { mem::transmute::<Vec<MaybeUninit<iovec>>, Vec<iovec>>(iovs) };
let _addrs = unsafe {
mem::transmute::<Vec<MaybeUninit<sockaddr_storage>>, Vec<sockaddr_storage>>(addrs)
};
let mut hdrs = unsafe { mem::transmute::<Vec<MaybeUninit<mmsghdr>>, Vec<mmsghdr>>(hdrs) };

sendmmsg_retry(sock, &mut hdrs)
// SAFETY: The first `packets.len()` elements of `hdrs`, `iovs`, and `addrs` are
// guaranteed to be initialized by `mmsghdr_for_packet` before this loop.
let hdrs_slice =
unsafe { std::slice::from_raw_parts_mut(hdrs.as_mut_ptr() as *mut mmsghdr, packets.len()) };

let result = sendmmsg_retry(sock, hdrs_slice);

// SAFETY: The first `packets.len()` elements of `hdrs`, `iovs`, and `addrs` are
// guaranteed to be initialized by `mmsghdr_for_packet` before this loop.
for (hdr, iov, addr) in izip!(&mut hdrs, &mut iovs, &mut addrs).take(packets.len()) {
unsafe {
hdr.assume_init_drop();
iov.assume_init_drop();
addr.assume_init_drop();
}
}

result
}

#[cfg(target_os = "linux")]
pub fn batch_send<S, T>(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError>
where
S: Borrow<SocketAddr>,
T: AsRef<[u8]>,
{
for chunk in packets.chunks(MAX_IOV) {
batch_send_max_iov(sock, chunk)?;
}
Ok(())
}

pub fn multi_target_send<S, T>(
Expand Down

0 comments on commit 7f2bbd3

Please sign in to comment.