Skip to content

Commit

Permalink
Fix sendmsg_unix's address encoding. (#885)
Browse files Browse the repository at this point in the history
When encoding the address for `sendmsg_unix`, use the `unix` field of
`SocketAddrUnix`, since the `unix` field is the `sockaddr_un` that the
OS will read.

Fixes #884.
  • Loading branch information
sunfishcode authored Oct 19, 2023
1 parent e35481c commit a59a191
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/backend/libc/net/msghdr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ pub(crate) fn with_unix_msghdr<R>(
) -> R {
f({
let mut h = zero_msghdr();
h.msg_name = as_ptr(addr) as _;
h.msg_name = as_ptr(&addr.unix) as _;
h.msg_namelen = addr.addr_len();
h.msg_iov = iov.as_ptr() as _;
h.msg_iovlen = msg_iov_len(iov.len());
Expand Down
2 changes: 1 addition & 1 deletion src/backend/linux_raw/net/msghdr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ pub(crate) fn with_unix_msghdr<R>(
f: impl FnOnce(c::msghdr) -> R,
) -> R {
f(c::msghdr {
msg_name: as_ptr(addr) as _,
msg_name: as_ptr(&addr.unix) as _,
msg_namelen: addr.addr_len() as _,
msg_iov: iov.as_ptr() as _,
msg_iovlen: msg_iov_len(iov.len()),
Expand Down
130 changes: 130 additions & 0 deletions tests/net/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,110 @@ fn do_test_unix_msg(addr: SocketAddrUnix) {
server.join().unwrap();
}

/// Similar to `do_test_unix_msg` but uses an unconnected socket and
/// `sendmsg_unix` instead of `sendmsg`.
#[cfg(not(any(target_os = "espidf", target_os = "redox", target_os = "wasi")))]
fn do_test_unix_msg_unconnected(addr: SocketAddrUnix) {
use rustix::io::{IoSlice, IoSliceMut};
use rustix::net::{recvmsg, sendmsg_unix, RecvFlags, SendFlags};

let server = {
let runs: &[i32] = &[3, 184, 187, 0];
let data_socket = socket(AddressFamily::UNIX, SocketType::DGRAM, None).unwrap();
bind_unix(&data_socket, &addr).unwrap();

move || {
let mut buffer = vec![0; BUFFER_SIZE];
for expected_sum in runs {
let mut sum = 0;
loop {
let nread = recvmsg(
&data_socket,
&mut [IoSliceMut::new(&mut buffer)],
&mut Default::default(),
RecvFlags::empty(),
)
.unwrap()
.bytes;

assert_ne!(&buffer[..nread], b"exit");
if &buffer[..nread] == b"sum" {
break;
}

sum += i32::from_str(&String::from_utf8_lossy(&buffer[..nread])).unwrap();
}

assert_eq!(sum, *expected_sum);
}
let nread = recvmsg(
&data_socket,
&mut [IoSliceMut::new(&mut buffer)],
&mut Default::default(),
RecvFlags::empty(),
)
.unwrap()
.bytes;

assert_eq!(&buffer[..nread], b"exit");
}
};

let client = move || {
let runs: &[&[&str]] = &[&["1", "2"], &["4", "77", "103"], &["5", "78", "104"], &[]];

for args in runs {
let data_socket = socket(AddressFamily::UNIX, SocketType::DGRAM, None).unwrap();

for arg in *args {
sendmsg_unix(
&data_socket,
&addr,
&[IoSlice::new(arg.as_bytes())],
&mut Default::default(),
SendFlags::empty(),
)
.unwrap();
}
sendmsg_unix(
&data_socket,
&addr,
&[IoSlice::new(b"sum")],
&mut Default::default(),
SendFlags::empty(),
)
.unwrap();
}

let data_socket = socket(AddressFamily::UNIX, SocketType::DGRAM, None).unwrap();
sendmsg_unix(
&data_socket,
&addr,
&[IoSlice::new(b"exit")],
&mut Default::default(),
SendFlags::empty(),
)
.unwrap();
};

let server = thread::Builder::new()
.name("server".to_string())
.spawn(move || {
server();
})
.unwrap();

let client = thread::Builder::new()
.name("client".to_string())
.spawn(move || {
client();
})
.unwrap();

client.join().unwrap();
server.join().unwrap();
}

#[cfg(not(any(target_os = "espidf", target_os = "redox", target_os = "wasi")))]
#[test]
fn test_unix_msg() {
Expand All @@ -265,6 +369,19 @@ fn test_unix_msg() {
unlinkat(CWD, path, AtFlags::empty()).unwrap();
}

/// Like `test_unix_msg` but tests `do_test_unix_msg_unconnected`.
#[cfg(not(any(target_os = "espidf", target_os = "redox", target_os = "wasi")))]
#[test]
fn test_unix_msg_unconnected() {
let tmpdir = tempfile::tempdir().unwrap();
let path = tmpdir.path().join("scp_4804");

let name = SocketAddrUnix::new(&path).unwrap();
do_test_unix_msg_unconnected(name);

unlinkat(CWD, path, AtFlags::empty()).unwrap();
}

#[cfg(linux_kernel)]
#[test]
fn test_abstract_unix_msg() {
Expand All @@ -277,6 +394,19 @@ fn test_abstract_unix_msg() {
do_test_unix_msg(name);
}

/// Like `test_abstract_unix_msg` but tests `do_test_unix_msg_unconnected`.
#[cfg(linux_kernel)]
#[test]
fn test_abstract_unix_msg_unconnected() {
use std::os::unix::ffi::OsStrExt;

let tmpdir = tempfile::tempdir().unwrap();
let path = tmpdir.path().join("scp_4804");

let name = SocketAddrUnix::new_abstract_name(path.as_os_str().as_bytes()).unwrap();
do_test_unix_msg_unconnected(name);
}

#[cfg(not(any(target_os = "redox", target_os = "wasi")))]
#[test]
fn test_unix_msg_with_scm_rights() {
Expand Down

0 comments on commit a59a191

Please sign in to comment.