From 61882545786ebfe8a66a56b6c45fc6133f40855e Mon Sep 17 00:00:00 2001 From: Dan Gohman Date: Sat, 21 Oct 2023 08:19:13 -0700 Subject: [PATCH] Make `cmsg_space!` usable in const contexts. Make `cmsg_space!` usable in const contexts, so that it can be used as a buffer size argument, and add a version of tests/net/unix.rs that uses stack-allocated buffers instead of `Vec`s. This exposes an alignment sublety, that buffers must be aligned to the needed alignment of `cmsghdr`; handle this by auto-aligning the provided buffer to the needed boundary. --- src/net/send_recv/msg.rs | 61 ++-- tests/net/main.rs | 2 + tests/net/unix.rs | 24 +- tests/net/unix_alloc.rs | 635 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 691 insertions(+), 31 deletions(-) create mode 100644 tests/net/unix_alloc.rs diff --git a/src/net/send_recv/msg.rs b/src/net/send_recv/msg.rs index 0c51666d2..d67a82699 100644 --- a/src/net/send_recv/msg.rs +++ b/src/net/send_recv/msg.rs @@ -10,7 +10,7 @@ use crate::net::UCred; use core::iter::FusedIterator; use core::marker::PhantomData; -use core::mem::{size_of, size_of_val, take}; +use core::mem::{align_of, size_of, size_of_val, take}; use core::{ptr, slice}; use super::{RecvFlags, SendFlags, SocketAddrAny, SocketAddrV4, SocketAddrV6}; @@ -40,8 +40,19 @@ macro_rules! cmsg_space { } #[doc(hidden)] -pub fn __cmsg_space(len: usize) -> usize { - unsafe { c::CMSG_SPACE(len.try_into().expect("CMSG_SPACE size overflow")) as usize } +pub const fn __cmsg_space(len: usize) -> usize { + // Add `align_of::()` so that we can align the user-provided + // `&[u8]` to the required alignment boundary. + let len = len + align_of::(); + + // Convert `len` to `u32` for `CMSG_SPACE`. This would be `try_into()` if + // we could call that in a `const fn`. + let converted_len = len as u32; + if converted_len as usize != len { + unreachable!(); // `CMSG_SPACE` size overflow + } + + unsafe { c::CMSG_SPACE(converted_len) as usize } } /// Ancillary message for [`sendmsg`], [`sendmsg_v4`], [`sendmsg_v6`], @@ -59,19 +70,11 @@ impl SendAncillaryMessage<'_, '_> { /// Get the maximum size of an ancillary message. /// /// This can be helpful in determining the size of the buffer you allocate. - pub fn size(&self) -> usize { - let total_bytes = match self { - Self::ScmRights(slice) => size_of_val(*slice), + pub const fn size(&self) -> usize { + match self { + Self::ScmRights(slice) => cmsg_space!(ScmRights(slice.len())), #[cfg(linux_kernel)] - Self::ScmCredentials(ucred) => size_of_val(ucred), - }; - - unsafe { - c::CMSG_SPACE( - total_bytes - .try_into() - .expect("size too large for CMSG_SPACE"), - ) as usize + Self::ScmCredentials(_) => cmsg_space!(ScmCredentials(1)), } } } @@ -107,15 +110,20 @@ impl<'buf> From<&'buf mut [u8]> for SendAncillaryBuffer<'buf, '_, '_> { impl Default for SendAncillaryBuffer<'_, '_, '_> { fn default() -> Self { - Self::new(&mut []) + Self { + buffer: &mut [], + length: 0, + _phantom: PhantomData, + } } } impl<'buf, 'slice, 'fd> SendAncillaryBuffer<'buf, 'slice, 'fd> { /// Create a new, empty `SendAncillaryBuffer` from a raw byte buffer. + #[inline] pub fn new(buffer: &'buf mut [u8]) -> Self { Self { - buffer, + buffer: align_for_cmsghdr(buffer), length: 0, _phantom: PhantomData, } @@ -234,15 +242,20 @@ impl<'buf> From<&'buf mut [u8]> for RecvAncillaryBuffer<'buf> { impl Default for RecvAncillaryBuffer<'_> { fn default() -> Self { - Self::new(&mut []) + Self { + buffer: &mut [], + read: 0, + length: 0, + } } } impl<'buf> RecvAncillaryBuffer<'buf> { /// Create a new, empty `RecvAncillaryBuffer` from a raw byte buffer. + #[inline] pub fn new(buffer: &'buf mut [u8]) -> Self { Self { - buffer, + buffer: align_for_cmsghdr(buffer), read: 0, length: 0, } @@ -297,6 +310,16 @@ impl Drop for RecvAncillaryBuffer<'_> { } } +/// Return a slice of `buffer` starting at the first `cmsghdr` alignment +/// boundary. +#[inline] +fn align_for_cmsghdr(buffer: &mut [u8]) -> &mut [u8] { + let align = align_of::(); + let addr = buffer.as_ptr() as usize; + let adjusted = (addr + (align - 1)) & align.wrapping_neg(); + &mut buffer[adjusted - addr..] +} + /// An iterator that drains messages from a [`RecvAncillaryBuffer`]. pub struct AncillaryDrain<'buf> { /// Inner iterator over messages. diff --git a/tests/net/main.rs b/tests/net/main.rs index 65edd00d5..e4b37878a 100644 --- a/tests/net/main.rs +++ b/tests/net/main.rs @@ -12,6 +12,8 @@ mod poll; mod sockopt; #[cfg(unix)] mod unix; +#[cfg(unix)] +mod unix_alloc; mod v4; mod v6; diff --git a/tests/net/unix.rs b/tests/net/unix.rs index 5aa06b188..eea825b27 100644 --- a/tests/net/unix.rs +++ b/tests/net/unix.rs @@ -35,7 +35,7 @@ fn server(ready: Arc<(Mutex, Condvar)>, path: &Path) { cvar.notify_all(); } - let mut buffer = vec![0; BUFFER_SIZE]; + let mut buffer = [0; BUFFER_SIZE]; 'exit: loop { let data_socket = accept(&connection_socket).unwrap(); let mut sum = 0; @@ -68,7 +68,7 @@ fn client(ready: Arc<(Mutex, Condvar)>, path: &Path, runs: &[(&[&str], i32 } let addr = SocketAddrUnix::new(path).unwrap(); - let mut buffer = vec![0; BUFFER_SIZE]; + let mut buffer = [0; BUFFER_SIZE]; for (args, sum) in runs { let data_socket = socket(AddressFamily::UNIX, SocketType::SEQPACKET, None).unwrap(); @@ -136,7 +136,7 @@ fn do_test_unix_msg(addr: SocketAddrUnix) { listen(&connection_socket, 1).unwrap(); move || { - let mut buffer = vec![0; BUFFER_SIZE]; + let mut buffer = [0; BUFFER_SIZE]; 'exit: loop { let data_socket = accept(&connection_socket).unwrap(); let mut sum = 0; @@ -173,7 +173,7 @@ fn do_test_unix_msg(addr: SocketAddrUnix) { }; let client = move || { - let mut buffer = vec![0; BUFFER_SIZE]; + let mut buffer = [0; BUFFER_SIZE]; let runs: &[(&[&str], i32)] = &[ (&["1", "2"], 3), (&["4", "77", "103"], 184), @@ -266,7 +266,7 @@ fn do_test_unix_msg_unconnected(addr: SocketAddrUnix) { bind_unix(&data_socket, &addr).unwrap(); move || { - let mut buffer = vec![0; BUFFER_SIZE]; + let mut buffer = [0; BUFFER_SIZE]; for expected_sum in runs { let mut sum = 0; loop { @@ -434,8 +434,8 @@ fn test_unix_msg_with_scm_rights() { move || { let mut pipe_end = None; - let mut buffer = vec![0; BUFFER_SIZE]; - let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(1))]; + let mut buffer = [0; BUFFER_SIZE]; + let mut cmsg_space = [0; rustix::cmsg_space!(ScmRights(1))]; 'exit: loop { let data_socket = accept(&connection_socket).unwrap(); @@ -495,7 +495,7 @@ fn test_unix_msg_with_scm_rights() { let client = move || { let addr = SocketAddrUnix::new(path).unwrap(); let (read_end, write_end) = pipe().unwrap(); - let mut buffer = vec![0; BUFFER_SIZE]; + let mut buffer = [0; BUFFER_SIZE]; let runs: &[(&[&str], i32)] = &[ (&["1", "2"], 3), (&["4", "77", "103"], 184), @@ -543,7 +543,7 @@ fn test_unix_msg_with_scm_rights() { // Format the CMSG. let we = [write_end.as_fd()]; let msg = SendAncillaryMessage::ScmRights(&we); - let mut space = vec![0; msg.size()]; + let mut space = [0; rustix::cmsg_space!(ScmRights(1))]; let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space); assert!(cmsg_buffer.push(msg)); @@ -606,7 +606,7 @@ fn test_unix_peercred() { assert_eq!(ucred.gid, getgid()); let msg = SendAncillaryMessage::ScmCredentials(ucred); - let mut space = vec![0; msg.size()]; + let mut space = [0; rustix::cmsg_space!(ScmCredentials(1))]; let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space); assert!(cmsg_buffer.push(msg)); @@ -618,10 +618,10 @@ fn test_unix_peercred() { ) .unwrap(); - let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmCredentials(1))]; + let mut cmsg_space = [0; rustix::cmsg_space!(ScmCredentials(1))]; let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space); - let mut buffer = vec![0; BUFFER_SIZE]; + let mut buffer = [0; BUFFER_SIZE]; recvmsg( &recv_sock, &mut [IoSliceMut::new(&mut buffer)], diff --git a/tests/net/unix_alloc.rs b/tests/net/unix_alloc.rs new file mode 100644 index 000000000..a69047cb5 --- /dev/null +++ b/tests/net/unix_alloc.rs @@ -0,0 +1,635 @@ +//! Like unix.rs, but uses `Vec`s for the buffers. + +// This test uses `AF_UNIX` with `SOCK_SEQPACKET` which is unsupported on macOS. +#![cfg(not(any(apple, target_os = "espidf", target_os = "redox", target_os = "wasi")))] +// This test uses `DecInt`. +#![cfg(feature = "itoa")] +#![cfg(feature = "fs")] + +use rustix::fs::{unlinkat, AtFlags, CWD}; +use rustix::io::{read, write}; +use rustix::net::{ + accept, bind_unix, connect_unix, listen, socket, AddressFamily, SocketAddrUnix, SocketType, +}; +use rustix::path::DecInt; +use std::path::Path; +use std::str::FromStr; +use std::sync::{Arc, Condvar, Mutex}; +use std::thread; + +const BUFFER_SIZE: usize = 20; + +fn server(ready: Arc<(Mutex, Condvar)>, path: &Path) { + let connection_socket = socket(AddressFamily::UNIX, SocketType::SEQPACKET, None).unwrap(); + + let name = SocketAddrUnix::new(path).unwrap(); + bind_unix(&connection_socket, &name).unwrap(); + listen(&connection_socket, 1).unwrap(); + + { + let (lock, cvar) = &*ready; + let mut started = lock.lock().unwrap(); + *started = true; + cvar.notify_all(); + } + + let mut buffer = vec![0; BUFFER_SIZE]; + 'exit: loop { + let data_socket = accept(&connection_socket).unwrap(); + let mut sum = 0; + loop { + let nread = read(&data_socket, &mut buffer).unwrap(); + + if &buffer[..nread] == b"exit" { + break 'exit; + } + if &buffer[..nread] == b"sum" { + break; + } + + sum += i32::from_str(&String::from_utf8_lossy(&buffer[..nread])).unwrap(); + } + + write(&data_socket, DecInt::new(sum).as_bytes()).unwrap(); + } + + unlinkat(CWD, path, AtFlags::empty()).unwrap(); +} + +fn client(ready: Arc<(Mutex, Condvar)>, path: &Path, runs: &[(&[&str], i32)]) { + { + let (lock, cvar) = &*ready; + let mut started = lock.lock().unwrap(); + while !*started { + started = cvar.wait(started).unwrap(); + } + } + + let addr = SocketAddrUnix::new(path).unwrap(); + let mut buffer = vec![0; BUFFER_SIZE]; + + for (args, sum) in runs { + let data_socket = socket(AddressFamily::UNIX, SocketType::SEQPACKET, None).unwrap(); + connect_unix(&data_socket, &addr).unwrap(); + + for arg in *args { + write(&data_socket, arg.as_bytes()).unwrap(); + } + write(&data_socket, b"sum").unwrap(); + + let nread = read(&data_socket, &mut buffer).unwrap(); + assert_eq!( + i32::from_str(&String::from_utf8_lossy(&buffer[..nread])).unwrap(), + *sum + ); + } + + let data_socket = socket(AddressFamily::UNIX, SocketType::SEQPACKET, None).unwrap(); + connect_unix(&data_socket, &addr).unwrap(); + write(&data_socket, b"exit").unwrap(); +} + +#[test] +fn test_unix() { + let ready = Arc::new((Mutex::new(false), Condvar::new())); + let ready_clone = Arc::clone(&ready); + + let tmp = tempfile::tempdir().unwrap(); + let path = tmp.path().join("soccer"); + let send_path = path.to_owned(); + let server = thread::Builder::new() + .name("server".to_string()) + .spawn(move || { + server(ready, &send_path); + }) + .unwrap(); + let send_path = path.to_owned(); + let client = thread::Builder::new() + .name("client".to_string()) + .spawn(move || { + client( + ready_clone, + &send_path, + &[ + (&["1", "2"], 3), + (&["4", "77", "103"], 184), + (&["5", "78", "104"], 187), + (&[], 0), + ], + ); + }) + .unwrap(); + client.join().unwrap(); + server.join().unwrap(); +} + +#[cfg(not(any(target_os = "espidf", target_os = "redox", target_os = "wasi")))] +fn do_test_unix_msg(addr: SocketAddrUnix) { + use rustix::io::{IoSlice, IoSliceMut}; + use rustix::net::{recvmsg, sendmsg, RecvFlags, SendFlags}; + + let server = { + let connection_socket = socket(AddressFamily::UNIX, SocketType::SEQPACKET, None).unwrap(); + bind_unix(&connection_socket, &addr).unwrap(); + listen(&connection_socket, 1).unwrap(); + + move || { + let mut buffer = vec![0; BUFFER_SIZE]; + 'exit: loop { + let data_socket = accept(&connection_socket).unwrap(); + let mut sum = 0; + loop { + let nread = recvmsg( + &data_socket, + &mut [IoSliceMut::new(&mut buffer)], + &mut Default::default(), + RecvFlags::empty(), + ) + .unwrap() + .bytes; + + if &buffer[..nread] == b"exit" { + break 'exit; + } + if &buffer[..nread] == b"sum" { + break; + } + + sum += i32::from_str(&String::from_utf8_lossy(&buffer[..nread])).unwrap(); + } + + let data = sum.to_string(); + sendmsg( + &data_socket, + &[IoSlice::new(data.as_bytes())], + &mut Default::default(), + SendFlags::empty(), + ) + .unwrap(); + } + } + }; + + let client = move || { + let mut buffer = vec![0; BUFFER_SIZE]; + let runs: &[(&[&str], i32)] = &[ + (&["1", "2"], 3), + (&["4", "77", "103"], 184), + (&["5", "78", "104"], 187), + (&[], 0), + ]; + + for (args, sum) in runs { + let data_socket = socket(AddressFamily::UNIX, SocketType::SEQPACKET, None).unwrap(); + connect_unix(&data_socket, &addr).unwrap(); + + for arg in *args { + sendmsg( + &data_socket, + &[IoSlice::new(arg.as_bytes())], + &mut Default::default(), + SendFlags::empty(), + ) + .unwrap(); + } + sendmsg( + &data_socket, + &[IoSlice::new(b"sum")], + &mut Default::default(), + SendFlags::empty(), + ) + .unwrap(); + + let result = recvmsg( + &data_socket, + &mut [IoSliceMut::new(&mut buffer)], + &mut Default::default(), + RecvFlags::empty(), + ) + .unwrap(); + let nread = result.bytes; + assert_eq!( + i32::from_str(&String::from_utf8_lossy(&buffer[..nread])).unwrap(), + *sum + ); + // Don't ask me why, but this was seen to fail on FreeBSD. + // `SocketAddrUnix::path()` returned `None` for some reason. + // illumos and NetBSD too. + #[cfg(not(any(target_os = "freebsd", target_os = "illumos", target_os = "netbsd")))] + assert_eq!( + Some(rustix::net::SocketAddrAny::Unix(addr.clone())), + result.address + ); + } + + let data_socket = socket(AddressFamily::UNIX, SocketType::SEQPACKET, None).unwrap(); + connect_unix(&data_socket, &addr).unwrap(); + sendmsg( + &data_socket, + &[IoSlice::new(b"exit")], + &mut Default::default(), + SendFlags::empty(), + ) + .unwrap(); + }; + + let server = thread::Builder::new() + .name("server".to_string()) + .spawn(move || { + server(); + }) + .unwrap(); + + let client = thread::Builder::new() + .name("client".to_string()) + .spawn(move || { + client(); + }) + .unwrap(); + + client.join().unwrap(); + server.join().unwrap(); +} + +/// Similar to `do_test_unix_msg` but uses an unconnected socket and +/// `sendmsg_unix` instead of `sendmsg`. +#[cfg(not(any(target_os = "espidf", target_os = "redox", target_os = "wasi")))] +fn do_test_unix_msg_unconnected(addr: SocketAddrUnix) { + use rustix::io::{IoSlice, IoSliceMut}; + use rustix::net::{recvmsg, sendmsg_unix, RecvFlags, SendFlags}; + + let server = { + let runs: &[i32] = &[3, 184, 187, 0]; + let data_socket = socket(AddressFamily::UNIX, SocketType::DGRAM, None).unwrap(); + bind_unix(&data_socket, &addr).unwrap(); + + move || { + let mut buffer = vec![0; BUFFER_SIZE]; + for expected_sum in runs { + let mut sum = 0; + loop { + let nread = recvmsg( + &data_socket, + &mut [IoSliceMut::new(&mut buffer)], + &mut Default::default(), + RecvFlags::empty(), + ) + .unwrap() + .bytes; + + assert_ne!(&buffer[..nread], b"exit"); + if &buffer[..nread] == b"sum" { + break; + } + + sum += i32::from_str(&String::from_utf8_lossy(&buffer[..nread])).unwrap(); + } + + assert_eq!(sum, *expected_sum); + } + let nread = recvmsg( + &data_socket, + &mut [IoSliceMut::new(&mut buffer)], + &mut Default::default(), + RecvFlags::empty(), + ) + .unwrap() + .bytes; + + assert_eq!(&buffer[..nread], b"exit"); + } + }; + + let client = move || { + let runs: &[&[&str]] = &[&["1", "2"], &["4", "77", "103"], &["5", "78", "104"], &[]]; + + for args in runs { + let data_socket = socket(AddressFamily::UNIX, SocketType::DGRAM, None).unwrap(); + + for arg in *args { + sendmsg_unix( + &data_socket, + &addr, + &[IoSlice::new(arg.as_bytes())], + &mut Default::default(), + SendFlags::empty(), + ) + .unwrap(); + } + sendmsg_unix( + &data_socket, + &addr, + &[IoSlice::new(b"sum")], + &mut Default::default(), + SendFlags::empty(), + ) + .unwrap(); + } + + let data_socket = socket(AddressFamily::UNIX, SocketType::DGRAM, None).unwrap(); + sendmsg_unix( + &data_socket, + &addr, + &[IoSlice::new(b"exit")], + &mut Default::default(), + SendFlags::empty(), + ) + .unwrap(); + }; + + let server = thread::Builder::new() + .name("server".to_string()) + .spawn(move || { + server(); + }) + .unwrap(); + + let client = thread::Builder::new() + .name("client".to_string()) + .spawn(move || { + client(); + }) + .unwrap(); + + client.join().unwrap(); + server.join().unwrap(); +} + +#[cfg(not(any(target_os = "espidf", target_os = "redox", target_os = "wasi")))] +#[test] +fn test_unix_msg() { + let tmpdir = tempfile::tempdir().unwrap(); + let path = tmpdir.path().join("scp_4804"); + + let name = SocketAddrUnix::new(&path).unwrap(); + do_test_unix_msg(name); + + unlinkat(CWD, path, AtFlags::empty()).unwrap(); +} + +/// Like `test_unix_msg` but tests `do_test_unix_msg_unconnected`. +#[cfg(not(any(target_os = "espidf", target_os = "redox", target_os = "wasi")))] +#[test] +fn test_unix_msg_unconnected() { + let tmpdir = tempfile::tempdir().unwrap(); + let path = tmpdir.path().join("scp_4804"); + + let name = SocketAddrUnix::new(&path).unwrap(); + do_test_unix_msg_unconnected(name); + + unlinkat(CWD, path, AtFlags::empty()).unwrap(); +} + +#[cfg(linux_kernel)] +#[test] +fn test_abstract_unix_msg() { + use std::os::unix::ffi::OsStrExt; + + let tmpdir = tempfile::tempdir().unwrap(); + let path = tmpdir.path().join("scp_4804"); + + let name = SocketAddrUnix::new_abstract_name(path.as_os_str().as_bytes()).unwrap(); + do_test_unix_msg(name); +} + +/// Like `test_abstract_unix_msg` but tests `do_test_unix_msg_unconnected`. +#[cfg(linux_kernel)] +#[test] +fn test_abstract_unix_msg_unconnected() { + use std::os::unix::ffi::OsStrExt; + + let tmpdir = tempfile::tempdir().unwrap(); + let path = tmpdir.path().join("scp_4804"); + + let name = SocketAddrUnix::new_abstract_name(path.as_os_str().as_bytes()).unwrap(); + do_test_unix_msg_unconnected(name); +} + +#[cfg(not(any(target_os = "redox", target_os = "wasi")))] +#[test] +fn test_unix_msg_with_scm_rights() { + use rustix::fd::AsFd; + use rustix::io::{IoSlice, IoSliceMut}; + use rustix::net::{ + recvmsg, sendmsg, RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags, + SendAncillaryBuffer, SendAncillaryMessage, SendFlags, + }; + use rustix::pipe::pipe; + use std::string::ToString; + + let tmpdir = tempfile::tempdir().unwrap(); + let path = tmpdir.path().join("scp_4804"); + + let server = { + let path = path.clone(); + + let connection_socket = socket(AddressFamily::UNIX, SocketType::SEQPACKET, None).unwrap(); + + let name = SocketAddrUnix::new(&path).unwrap(); + bind_unix(&connection_socket, &name).unwrap(); + listen(&connection_socket, 1).unwrap(); + + move || { + let mut pipe_end = None; + + let mut buffer = vec![0; BUFFER_SIZE]; + let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(1))]; + + 'exit: loop { + let data_socket = accept(&connection_socket).unwrap(); + let mut sum = 0; + loop { + let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space); + let nread = recvmsg( + &data_socket, + &mut [IoSliceMut::new(&mut buffer)], + &mut cmsg_buffer, + RecvFlags::empty(), + ) + .unwrap() + .bytes; + + // Read out the pipe if we got it. + if let Some(end) = cmsg_buffer + .drain() + .filter_map(|msg| match msg { + RecvAncillaryMessage::ScmRights(rights) => Some(rights), + _ => None, + }) + .flatten() + .next() + { + pipe_end = Some(end); + } + + if &buffer[..nread] == b"exit" { + break 'exit; + } + if &buffer[..nread] == b"sum" { + break; + } + + sum += i32::from_str(&String::from_utf8_lossy(&buffer[..nread])).unwrap(); + } + + let data = sum.to_string(); + sendmsg( + &data_socket, + &[IoSlice::new(data.as_bytes())], + &mut Default::default(), + SendFlags::empty(), + ) + .unwrap(); + } + + unlinkat(CWD, path, AtFlags::empty()).unwrap(); + + // Once we're done, send a message along the pipe. + let pipe = pipe_end.unwrap(); + write(&pipe, b"pipe message!").unwrap(); + } + }; + + let client = move || { + let addr = SocketAddrUnix::new(path).unwrap(); + let (read_end, write_end) = pipe().unwrap(); + let mut buffer = vec![0; BUFFER_SIZE]; + let runs: &[(&[&str], i32)] = &[ + (&["1", "2"], 3), + (&["4", "77", "103"], 184), + (&["5", "78", "104"], 187), + (&[], 0), + ]; + + for (args, sum) in runs { + let data_socket = socket(AddressFamily::UNIX, SocketType::SEQPACKET, None).unwrap(); + connect_unix(&data_socket, &addr).unwrap(); + + for arg in *args { + sendmsg( + &data_socket, + &[IoSlice::new(arg.as_bytes())], + &mut Default::default(), + SendFlags::empty(), + ) + .unwrap(); + } + sendmsg( + &data_socket, + &[IoSlice::new(b"sum")], + &mut Default::default(), + SendFlags::empty(), + ) + .unwrap(); + + let nread = recvmsg( + &data_socket, + &mut [IoSliceMut::new(&mut buffer)], + &mut Default::default(), + RecvFlags::empty(), + ) + .unwrap() + .bytes; + assert_eq!( + i32::from_str(&String::from_utf8_lossy(&buffer[..nread])).unwrap(), + *sum + ); + } + + let data_socket = socket(AddressFamily::UNIX, SocketType::SEQPACKET, None).unwrap(); + + // Format the CMSG. + let we = [write_end.as_fd()]; + let msg = SendAncillaryMessage::ScmRights(&we); + let mut space = vec![0; msg.size()]; + let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space); + assert!(cmsg_buffer.push(msg)); + + connect_unix(&data_socket, &addr).unwrap(); + sendmsg( + &data_socket, + &[IoSlice::new(b"exit")], + &mut cmsg_buffer, + SendFlags::empty(), + ) + .unwrap(); + + // Read a value from the pipe. + let mut buffer = [0u8; 13]; + read(&read_end, &mut buffer).unwrap(); + assert_eq!(&buffer, b"pipe message!".as_ref()); + }; + + let server = thread::Builder::new() + .name("server".to_string()) + .spawn(move || { + server(); + }) + .unwrap(); + + let client = thread::Builder::new() + .name("client".to_string()) + .spawn(move || { + client(); + }) + .unwrap(); + + client.join().unwrap(); + server.join().unwrap(); +} + +#[cfg(all(feature = "process", linux_kernel))] +#[test] +fn test_unix_peercred() { + use rustix::io::{IoSlice, IoSliceMut}; + use rustix::net::{ + recvmsg, sendmsg, sockopt, RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags, + SendAncillaryBuffer, SendAncillaryMessage, SendFlags, SocketFlags, + }; + use rustix::process::{getgid, getpid, getuid}; + + let (send_sock, recv_sock) = rustix::net::socketpair( + AddressFamily::UNIX, + SocketType::STREAM, + SocketFlags::CLOEXEC, + None, + ) + .unwrap(); + + sockopt::set_socket_passcred(&recv_sock, true).unwrap(); + + let ucred = sockopt::get_socket_peercred(&send_sock).unwrap(); + assert_eq!(ucred.pid, getpid()); + assert_eq!(ucred.uid, getuid()); + assert_eq!(ucred.gid, getgid()); + + let msg = SendAncillaryMessage::ScmCredentials(ucred); + let mut space = vec![0; msg.size()]; + let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space); + assert!(cmsg_buffer.push(msg)); + + sendmsg( + &send_sock, + &[IoSlice::new(b"cred")], + &mut cmsg_buffer, + SendFlags::empty(), + ) + .unwrap(); + + let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmCredentials(1))]; + let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space); + + let mut buffer = vec![0; BUFFER_SIZE]; + recvmsg( + &recv_sock, + &mut [IoSliceMut::new(&mut buffer)], + &mut cmsg_buffer, + RecvFlags::empty(), + ) + .unwrap(); + + match cmsg_buffer.drain().next().unwrap() { + RecvAncillaryMessage::ScmCredentials(ucred2) => assert_eq!(ucred2, ucred), + _ => panic!("Unexpected ancilliary message"), + }; +}