Skip to content

Commit

Permalink
Use MaybeUninit properly in streamer::recv_mmsg() (#3348)
Browse files Browse the repository at this point in the history
recv_mmsg() currently abuses the MaybeUninit API by calling
assume_init() on items that have not actually been initialized. This
creates the possibility for UB

To avoid potential UB, this change leaves items as MaybeUninits and
uses the MaybeUninit API to initialize/access/drop the appropriate
items
  • Loading branch information
steviez authored Oct 31, 2024
1 parent 4994762 commit af03b1d
Showing 1 changed file with 56 additions and 16 deletions.
72 changes: 56 additions & 16 deletions streamer/src/recvmmsg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@ use {
#[cfg(target_os = "linux")]
use {
itertools::izip,
libc::{iovec, mmsghdr, sockaddr_storage, socklen_t, AF_INET, AF_INET6, MSG_WAITFORONE},
libc::{
iovec, mmsghdr, msghdr, sockaddr_storage, socklen_t, AF_INET, AF_INET6, MSG_WAITFORONE,
},
std::{
mem,
mem::{self, MaybeUninit},
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
os::unix::io::AsRawFd,
ptr,
},
};

Expand Down Expand Up @@ -78,16 +81,19 @@ fn cast_socket_addr(addr: &sockaddr_storage, hdr: &mmsghdr) -> Option<SocketAddr
}

#[cfg(target_os = "linux")]
#[allow(clippy::uninit_assumed_init)]
pub fn recv_mmsg(sock: &UdpSocket, packets: &mut [Packet]) -> io::Result</*num packets:*/ usize> {
// Should never hit this, but bail if the caller didn't provide any Packets
// to receive into
if packets.is_empty() {
return Ok(0);
}
// Assert that there are no leftovers in packets.
debug_assert!(packets.iter().all(|pkt| pkt.meta() == &Meta::default()));
const SOCKADDR_STORAGE_SIZE: usize = mem::size_of::<sockaddr_storage>();

let mut hdrs: [mmsghdr; NUM_RCVMMSGS] = unsafe { mem::zeroed() };
let iovs = mem::MaybeUninit::<[iovec; NUM_RCVMMSGS]>::uninit();
let mut iovs: [iovec; NUM_RCVMMSGS] = unsafe { iovs.assume_init() };
let mut addrs: [sockaddr_storage; NUM_RCVMMSGS] = unsafe { mem::zeroed() };
let mut iovs = [MaybeUninit::uninit(); NUM_RCVMMSGS];
let mut addrs = [MaybeUninit::zeroed(); NUM_RCVMMSGS];
let mut hdrs = [MaybeUninit::uninit(); NUM_RCVMMSGS];

let sock_fd = sock.as_raw_fd();
let count = cmp::min(iovs.len(), packets.len());
Expand All @@ -96,15 +102,25 @@ pub fn recv_mmsg(sock: &UdpSocket, packets: &mut [Packet]) -> io::Result</*num p
izip!(packets.iter_mut(), &mut hdrs, &mut iovs, &mut addrs).take(count)
{
let buffer = packet.buffer_mut();
*iov = iovec {
iov.write(iovec {
iov_base: buffer.as_mut_ptr() as *mut libc::c_void,
iov_len: buffer.len(),
};
hdr.msg_hdr.msg_name = addr as *mut _ as *mut _;
hdr.msg_hdr.msg_namelen = SOCKADDR_STORAGE_SIZE as socklen_t;
hdr.msg_hdr.msg_iov = iov;
hdr.msg_hdr.msg_iovlen = 1;
});

hdr.write(mmsghdr {
msg_len: 0,
msg_hdr: msghdr {
msg_name: addr.as_mut_ptr() as *mut _,
msg_namelen: SOCKADDR_STORAGE_SIZE as socklen_t,
msg_iov: iov.as_mut_ptr(),
msg_iovlen: 1,
msg_control: ptr::null::<libc::c_void>() as *mut _,
msg_controllen: 0,
msg_flags: 0,
},
});
}

let mut ts = libc::timespec {
tv_sec: 1,
tv_nsec: 0,
Expand All @@ -114,7 +130,7 @@ pub fn recv_mmsg(sock: &UdpSocket, packets: &mut [Packet]) -> io::Result</*num p
let nrecv = unsafe {
libc::recvmmsg(
sock_fd,
&mut hdrs[0],
hdrs[0].assume_init_mut(),
count as u32,
MSG_WAITFORONE.try_into().unwrap(),
&mut ts,
Expand All @@ -126,11 +142,35 @@ pub fn recv_mmsg(sock: &UdpSocket, packets: &mut [Packet]) -> io::Result</*num p
usize::try_from(nrecv).unwrap()
};
for (addr, hdr, pkt) in izip!(addrs, hdrs, packets.iter_mut()).take(nrecv) {
pkt.meta_mut().size = hdr.msg_len as usize;
if let Some(addr) = cast_socket_addr(&addr, &hdr) {
// SAFETY: We initialized `count` elements of `hdrs` above. `count` is
// passed to recvmmsg() as the limit of messages that can be read. So,
// `nrevc <= count` which means we initialized this `hdr` and
// recvmmsg() will have updated it appropriately
let hdr_ref = unsafe { hdr.assume_init_ref() };
// SAFETY: Similar to above, we initialized this `addr` and recvmmsg()
// will have populated it
let addr_ref = unsafe { addr.assume_init_ref() };
pkt.meta_mut().size = hdr_ref.msg_len as usize;
if let Some(addr) = cast_socket_addr(addr_ref, hdr_ref) {
pkt.meta_mut().set_socket_addr(&addr);
}
}

for (iov, addr, hdr) in izip!(&mut iovs, &mut addrs, &mut hdrs).take(count) {
// SAFETY: We initialized `count` elements of each array above
//
// It may be that `packets.len() != NUM_RCVMMSGS`; thus, some elements
// in `iovs` / `addrs` / `hdrs` may not get initialized. So, we must
// manually drop `count` elements from each array instead of being able
// to convert [MaybeUninit<T>] to [T] and letting `Drop` do the work
// for us when these items go out of scope at the end of the function
unsafe {
iov.assume_init_drop();
addr.assume_init_drop();
hdr.assume_init_drop();
}
}

Ok(nrecv)
}

Expand Down

0 comments on commit af03b1d

Please sign in to comment.