Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make cmsg_space! usable in const contexts. #889

Merged
merged 1 commit into from
Oct 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 42 additions & 19 deletions src/net/send_recv/msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::net::UCred;

use core::iter::FusedIterator;
use core::marker::PhantomData;
use core::mem::{size_of, size_of_val, take};
use core::mem::{align_of, size_of, size_of_val, take};
use core::{ptr, slice};

use super::{RecvFlags, SendFlags, SocketAddrAny, SocketAddrV4, SocketAddrV6};
Expand Down Expand Up @@ -40,8 +40,19 @@ macro_rules! cmsg_space {
}

#[doc(hidden)]
pub fn __cmsg_space(len: usize) -> usize {
unsafe { c::CMSG_SPACE(len.try_into().expect("CMSG_SPACE size overflow")) as usize }
pub const fn __cmsg_space(len: usize) -> usize {
// Add `align_of::<c::cmsghdr>()` so that we can align the user-provided
// `&[u8]` to the required alignment boundary.
let len = len + align_of::<c::cmsghdr>();

// Convert `len` to `u32` for `CMSG_SPACE`. This would be `try_into()` if
// we could call that in a `const fn`.
let converted_len = len as u32;
if converted_len as usize != len {
unreachable!(); // `CMSG_SPACE` size overflow
}

unsafe { c::CMSG_SPACE(converted_len) as usize }
}

/// Ancillary message for [`sendmsg`], [`sendmsg_v4`], [`sendmsg_v6`],
Expand All @@ -59,19 +70,11 @@ impl SendAncillaryMessage<'_, '_> {
/// Get the maximum size of an ancillary message.
///
/// This can be helpful in determining the size of the buffer you allocate.
pub fn size(&self) -> usize {
let total_bytes = match self {
Self::ScmRights(slice) => size_of_val(*slice),
pub const fn size(&self) -> usize {
match self {
Self::ScmRights(slice) => cmsg_space!(ScmRights(slice.len())),
#[cfg(linux_kernel)]
Self::ScmCredentials(ucred) => size_of_val(ucred),
};

unsafe {
c::CMSG_SPACE(
total_bytes
.try_into()
.expect("size too large for CMSG_SPACE"),
) as usize
Self::ScmCredentials(_) => cmsg_space!(ScmCredentials(1)),
}
}
}
Expand Down Expand Up @@ -107,15 +110,20 @@ impl<'buf> From<&'buf mut [u8]> for SendAncillaryBuffer<'buf, '_, '_> {

impl Default for SendAncillaryBuffer<'_, '_, '_> {
fn default() -> Self {
Self::new(&mut [])
Self {
buffer: &mut [],
length: 0,
_phantom: PhantomData,
}
}
}

impl<'buf, 'slice, 'fd> SendAncillaryBuffer<'buf, 'slice, 'fd> {
/// Create a new, empty `SendAncillaryBuffer` from a raw byte buffer.
#[inline]
pub fn new(buffer: &'buf mut [u8]) -> Self {
Self {
buffer,
buffer: align_for_cmsghdr(buffer),
length: 0,
_phantom: PhantomData,
}
Expand Down Expand Up @@ -234,15 +242,20 @@ impl<'buf> From<&'buf mut [u8]> for RecvAncillaryBuffer<'buf> {

impl Default for RecvAncillaryBuffer<'_> {
fn default() -> Self {
Self::new(&mut [])
Self {
buffer: &mut [],
read: 0,
length: 0,
}
}
}

impl<'buf> RecvAncillaryBuffer<'buf> {
/// Create a new, empty `RecvAncillaryBuffer` from a raw byte buffer.
#[inline]
pub fn new(buffer: &'buf mut [u8]) -> Self {
Self {
buffer,
buffer: align_for_cmsghdr(buffer),
read: 0,
length: 0,
}
Expand Down Expand Up @@ -297,6 +310,16 @@ impl Drop for RecvAncillaryBuffer<'_> {
}
}

/// Return a slice of `buffer` starting at the first `cmsghdr` alignment
/// boundary.
#[inline]
fn align_for_cmsghdr(buffer: &mut [u8]) -> &mut [u8] {
let align = align_of::<c::cmsghdr>();
let addr = buffer.as_ptr() as usize;
let adjusted = (addr + (align - 1)) & align.wrapping_neg();
&mut buffer[adjusted - addr..]
}

/// An iterator that drains messages from a [`RecvAncillaryBuffer`].
pub struct AncillaryDrain<'buf> {
/// Inner iterator over messages.
Expand Down
2 changes: 2 additions & 0 deletions tests/net/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ mod poll;
mod sockopt;
#[cfg(unix)]
mod unix;
#[cfg(unix)]
mod unix_alloc;
mod v4;
mod v6;

Expand Down
24 changes: 12 additions & 12 deletions tests/net/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn server(ready: Arc<(Mutex<bool>, Condvar)>, path: &Path) {
cvar.notify_all();
}

let mut buffer = vec![0; BUFFER_SIZE];
let mut buffer = [0; BUFFER_SIZE];
'exit: loop {
let data_socket = accept(&connection_socket).unwrap();
let mut sum = 0;
Expand Down Expand Up @@ -68,7 +68,7 @@ fn client(ready: Arc<(Mutex<bool>, Condvar)>, path: &Path, runs: &[(&[&str], i32
}

let addr = SocketAddrUnix::new(path).unwrap();
let mut buffer = vec![0; BUFFER_SIZE];
let mut buffer = [0; BUFFER_SIZE];

for (args, sum) in runs {
let data_socket = socket(AddressFamily::UNIX, SocketType::SEQPACKET, None).unwrap();
Expand Down Expand Up @@ -136,7 +136,7 @@ fn do_test_unix_msg(addr: SocketAddrUnix) {
listen(&connection_socket, 1).unwrap();

move || {
let mut buffer = vec![0; BUFFER_SIZE];
let mut buffer = [0; BUFFER_SIZE];
'exit: loop {
let data_socket = accept(&connection_socket).unwrap();
let mut sum = 0;
Expand Down Expand Up @@ -173,7 +173,7 @@ fn do_test_unix_msg(addr: SocketAddrUnix) {
};

let client = move || {
let mut buffer = vec![0; BUFFER_SIZE];
let mut buffer = [0; BUFFER_SIZE];
let runs: &[(&[&str], i32)] = &[
(&["1", "2"], 3),
(&["4", "77", "103"], 184),
Expand Down Expand Up @@ -266,7 +266,7 @@ fn do_test_unix_msg_unconnected(addr: SocketAddrUnix) {
bind_unix(&data_socket, &addr).unwrap();

move || {
let mut buffer = vec![0; BUFFER_SIZE];
let mut buffer = [0; BUFFER_SIZE];
for expected_sum in runs {
let mut sum = 0;
loop {
Expand Down Expand Up @@ -434,8 +434,8 @@ fn test_unix_msg_with_scm_rights() {
move || {
let mut pipe_end = None;

let mut buffer = vec![0; BUFFER_SIZE];
let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(1))];
let mut buffer = [0; BUFFER_SIZE];
let mut cmsg_space = [0; rustix::cmsg_space!(ScmRights(1))];

'exit: loop {
let data_socket = accept(&connection_socket).unwrap();
Expand Down Expand Up @@ -495,7 +495,7 @@ fn test_unix_msg_with_scm_rights() {
let client = move || {
let addr = SocketAddrUnix::new(path).unwrap();
let (read_end, write_end) = pipe().unwrap();
let mut buffer = vec![0; BUFFER_SIZE];
let mut buffer = [0; BUFFER_SIZE];
let runs: &[(&[&str], i32)] = &[
(&["1", "2"], 3),
(&["4", "77", "103"], 184),
Expand Down Expand Up @@ -543,7 +543,7 @@ fn test_unix_msg_with_scm_rights() {
// Format the CMSG.
let we = [write_end.as_fd()];
let msg = SendAncillaryMessage::ScmRights(&we);
let mut space = vec![0; msg.size()];
let mut space = [0; rustix::cmsg_space!(ScmRights(1))];
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space);
assert!(cmsg_buffer.push(msg));

Expand Down Expand Up @@ -606,7 +606,7 @@ fn test_unix_peercred() {
assert_eq!(ucred.gid, getgid());

let msg = SendAncillaryMessage::ScmCredentials(ucred);
let mut space = vec![0; msg.size()];
let mut space = [0; rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space);
assert!(cmsg_buffer.push(msg));

Expand All @@ -618,10 +618,10 @@ fn test_unix_peercred() {
)
.unwrap();

let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_space = [0; rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);

let mut buffer = vec![0; BUFFER_SIZE];
let mut buffer = [0; BUFFER_SIZE];
recvmsg(
&recv_sock,
&mut [IoSliceMut::new(&mut buffer)],
Expand Down
Loading
Loading